├── src ├── __init__.py ├── scrl.py ├── rlpyt_buffer.py ├── agent.py ├── efficient_sampling_offline_dataset.py ├── rlpyt_atari_env.py ├── offline_dataset.py ├── algos.py ├── utils.py ├── rlpyt_utils.py ├── models.py └── networks.py ├── requirements.txt ├── scripts ├── experiments │ ├── bc_finetune.sh │ ├── sgim_finetune.sh │ ├── sgiw_finetune.sh │ ├── sgiml_finetune.sh │ ├── sgim_pretrain.sh │ ├── bc_pretrain.sh │ ├── sgiw_pretrain.sh │ └── sgiml_pretrain.sh ├── download_replay_dataset.sh ├── run.sh ├── config.yaml └── run.py ├── LICENSE ├── .gitignore └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id='atari-v0', 5 | entry_point='src.envs:AtariEnv', 6 | ) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym[atari]==0.17.3 2 | wandb==0.10.30 3 | opencv-python 4 | recordclass 5 | matplotlib 6 | git+https://github.com/kornia/kornia 7 | numpy 8 | pyprind 9 | -e git+https://github.com/astooke/rlpyt.git@b32d589d12d31ba3c8a9cfb7a3c85c6e350b2904#egg=rlpyt 10 | hydra-core==1.0.6 11 | tqdm 12 | -------------------------------------------------------------------------------- /scripts/experiments/bc_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | declare -A map=( ["pong"]="Pong" ["breakout"]="Breakout" ["up_n_down"]="UpNDown" ["kangaroo"]="Kangaroo" ["bank_heist"]="BankHeist" ["assault"]="Assault" ["boxing"]="Boxing" ["battle_zone"]="BattleZone" ["frostbite"]="Frostbite" ["crazy_climber"]="CrazyClimber" ["chopper_command"]="ChopperCommand" ["demon_attack"]="DemonAttack" ["alien"]="Alien" ["kung_fu_master"]="KungFuMaster" ["qbert"]="Qbert" ["ms_pacman"]="MsPacman" ["hero"]="Hero" ["seaquest"]="Seaquest" ["jamesbond"]="Jamesbond" ["amidar"]="Amidar" ["asterix"]="Asterix" ["private_eye"]="PrivateEye" ["gopher"]="Gopher" ["krull"]="Krull" ["freeway"]="Freeway" ["road_runner"]="RoadRunner" ) 3 | export game=$1 4 | shift 5 | export seed=$1 6 | 7 | python -m scripts.run public=True env.game=$game seed=$seed num_logs=10 \ 8 | model_load=bc_${game}_resnet_${seed} \ 9 | model_folder=./ \ 10 | algo.encoder_lr=0.000001 \ 11 | algo.q_l1_lr=0.00003\ 12 | algo.clip_grad_norm=-1 \ 13 | algo.clip_model_grad_norm=-1 14 | -------------------------------------------------------------------------------- /scripts/experiments/sgim_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | declare -A map=( ["pong"]="Pong" ["breakout"]="Breakout" ["up_n_down"]="UpNDown" ["kangaroo"]="Kangaroo" ["bank_heist"]="BankHeist" ["assault"]="Assault" ["boxing"]="Boxing" ["battle_zone"]="BattleZone" ["frostbite"]="Frostbite" ["crazy_climber"]="CrazyClimber" ["chopper_command"]="ChopperCommand" ["demon_attack"]="DemonAttack" ["alien"]="Alien" ["kung_fu_master"]="KungFuMaster" ["qbert"]="Qbert" ["ms_pacman"]="MsPacman" ["hero"]="Hero" ["seaquest"]="Seaquest" ["jamesbond"]="Jamesbond" ["amidar"]="Amidar" ["asterix"]="Asterix" ["private_eye"]="PrivateEye" ["gopher"]="Gopher" ["krull"]="Krull" ["freeway"]="Freeway" ["road_runner"]="RoadRunner" ) 3 | export game=$1 4 | shift 5 | export seed=$1 6 | 7 | python -m scripts.run public=True env.game=$game seed=$seed num_logs=10 \ 8 | model_load=sgim_${game}_resnet_${seed} \ 9 | model_folder=./ \ 10 | 
algo.encoder_lr=0.000001 \ 11 | algo.q_l1_lr=0.00003\ 12 | algo.clip_grad_norm=-1 \ 13 | algo.clip_model_grad_norm=-1 -------------------------------------------------------------------------------- /scripts/experiments/sgiw_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | declare -A map=( ["pong"]="Pong" ["breakout"]="Breakout" ["up_n_down"]="UpNDown" ["kangaroo"]="Kangaroo" ["bank_heist"]="BankHeist" ["assault"]="Assault" ["boxing"]="Boxing" ["battle_zone"]="BattleZone" ["frostbite"]="Frostbite" ["crazy_climber"]="CrazyClimber" ["chopper_command"]="ChopperCommand" ["demon_attack"]="DemonAttack" ["alien"]="Alien" ["kung_fu_master"]="KungFuMaster" ["qbert"]="Qbert" ["ms_pacman"]="MsPacman" ["hero"]="Hero" ["seaquest"]="Seaquest" ["jamesbond"]="Jamesbond" ["amidar"]="Amidar" ["asterix"]="Asterix" ["private_eye"]="PrivateEye" ["gopher"]="Gopher" ["krull"]="Krull" ["freeway"]="Freeway" ["road_runner"]="RoadRunner" ) 3 | export game=$1 4 | shift 5 | export seed=$1 6 | 7 | python -m scripts.run public=True env.game=$game seed=$seed num_logs=10 \ 8 | model_load=sgiw_${game}_resnet_${seed} \ 9 | model_folder=./ \ 10 | algo.encoder_lr=0.000001 \ 11 | algo.q_l1_lr=0.00003\ 12 | algo.clip_grad_norm=-1 \ 13 | algo.clip_model_grad_norm=-1 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Ankesh Anand 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /scripts/experiments/sgiml_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | declare -A map=( ["pong"]="Pong" ["breakout"]="Breakout" ["up_n_down"]="UpNDown" ["kangaroo"]="Kangaroo" ["bank_heist"]="BankHeist" ["assault"]="Assault" ["boxing"]="Boxing" ["battle_zone"]="BattleZone" ["frostbite"]="Frostbite" ["crazy_climber"]="CrazyClimber" ["chopper_command"]="ChopperCommand" ["demon_attack"]="DemonAttack" ["alien"]="Alien" ["kung_fu_master"]="KungFuMaster" ["qbert"]="Qbert" ["ms_pacman"]="MsPacman" ["hero"]="Hero" ["seaquest"]="Seaquest" ["jamesbond"]="Jamesbond" ["amidar"]="Amidar" ["asterix"]="Asterix" ["private_eye"]="PrivateEye" ["gopher"]="Gopher" ["krull"]="Krull" ["freeway"]="Freeway" ["road_runner"]="RoadRunner" ) 3 | export game=$1 4 | shift 5 | export seed=$1 6 | 7 | python -m scripts.run public=True env.game=$game seed=$seed num_logs=10 \ 8 | model_load=sgiml_${game}_resnet_${seed} \ 9 | model_folder=./ \ 10 | agent.model_kwargs.blocks_per_group=5 agent.model_kwargs.expand_ratio=4 \ 11 | agent.model_kwargs.cnn_scale_factor=1.5 \ 12 | algo.encoder_lr=0.000001 \ 13 | algo.q_l1_lr=0.00003\ 14 | algo.clip_grad_norm=-1 \ 15 | algo.clip_model_grad_norm=-1 16 | -------------------------------------------------------------------------------- /scripts/experiments/sgim_pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | declare -A map=( ["pong"]="Pong" ["breakout"]="Breakout" ["up_n_down"]="UpNDown" ["kangaroo"]="Kangaroo" ["bank_heist"]="BankHeist" ["assault"]="Assault" ["boxing"]="Boxing" ["battle_zone"]="BattleZone" ["frostbite"]="Frostbite" ["crazy_climber"]="CrazyClimber" ["chopper_command"]="ChopperCommand" ["demon_attack"]="DemonAttack" ["alien"]="Alien" ["kung_fu_master"]="KungFuMaster" ["qbert"]="Qbert" ["ms_pacman"]="MsPacman" ["hero"]="Hero" ["seaquest"]="Seaquest" ["jamesbond"]="Jamesbond" ["amidar"]="Amidar" ["asterix"]="Asterix" ["private_eye"]="PrivateEye" ["gopher"]="Gopher" ["krull"]="Krull" ["freeway"]="Freeway" ["road_runner"]="RoadRunner" ) 3 | export game=$1 4 | shift 5 | export seed=$1 6 | 7 | python -m scripts.run public=True model_folder=./ offline.runner.save_every=2500 \ 8 | env.game=pong seed=1 offline_model_save=sgim_${game}_resnet_${seed} \ 9 | offline.runner.epochs=20 offline.runner.dataloader.games=[${map[${game}]}] \ 10 | offline.runner.no_eval=1 \ 11 | +offline.algo.goal_weight=1 \ 12 | +offline.algo.inverse_model_weight=1 \ 13 | +offline.algo.spr_weight=1 \ 14 | +offline.algo.target_update_tau=0.01 \ 15 | +offline.agent.model_kwargs.momentum_tau=0.01 \ 16 | do_online=False \ 17 | algo.batch_size=256 \ 18 | +offline.agent.model_kwargs.noisy_nets_std=0 \ 19 | offline.runner.dataloader.dataset_on_disk=True \ 20 | offline.runner.dataloader.samples=1000000 \ 21 | offline.runner.dataloader.checkpoints='[1,25,50]' \ 22 | offline.runner.dataloader.num_workers=2 \ 23 | offline.runner.dataloader.data_path=./data/ \ 24 | offline.runner.dataloader.tmp_data_path=./ 25 | -------------------------------------------------------------------------------- /scripts/experiments/bc_pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | declare -A map=( ["pong"]="Pong" ["breakout"]="Breakout" ["up_n_down"]="UpNDown" ["kangaroo"]="Kangaroo" ["bank_heist"]="BankHeist" ["assault"]="Assault" ["boxing"]="Boxing" 
["battle_zone"]="BattleZone" ["frostbite"]="Frostbite" ["crazy_climber"]="CrazyClimber" ["chopper_command"]="ChopperCommand" ["demon_attack"]="DemonAttack" ["alien"]="Alien" ["kung_fu_master"]="KungFuMaster" ["qbert"]="Qbert" ["ms_pacman"]="MsPacman" ["hero"]="Hero" ["seaquest"]="Seaquest" ["jamesbond"]="Jamesbond" ["amidar"]="Amidar" ["asterix"]="Asterix" ["private_eye"]="PrivateEye" ["gopher"]="Gopher" ["krull"]="Krull" ["freeway"]="Freeway" ["road_runner"]="RoadRunner" ) 3 | export game=$1 4 | shift 5 | export seed=$1 6 | 7 | python -m scripts.run public=True model_folder=./ offline.runner.save_every=2500 \ 8 | env.game=pong seed=1 offline_model_save=bc_${game}_resnet_${seed} \ 9 | offline.runner.epochs=20 offline.runner.dataloader.games=[${map[${game}]}] \ 10 | offline.runner.no_eval=1 \ 11 | +offline.algo.bc_weight=1 \ 12 | +offline.algo.goal_weight=0 \ 13 | +offline.algo.inverse_model_weight=0 \ 14 | +offline.algo.spr_weight=0 \ 15 | +offline.algo.target_update_tau=0.01 \ 16 | +offline.agent.model_kwargs.momentum_tau=0.01 \ 17 | +offline.algo.jumps=0 \ 18 | do_online=False \ 19 | algo.batch_size=256 \ 20 | +offline.agent.model_kwargs.noisy_nets_std=0 \ 21 | offline.runner.dataloader.dataset_on_disk=True \ 22 | offline.runner.dataloader.samples=1000000 \ 23 | offline.runner.dataloader.checkpoints='[1,25,50]' \ 24 | offline.runner.dataloader.num_workers=2 \ 25 | offline.runner.dataloader.data_path=./data/ \ 26 | offline.runner.dataloader.tmp_data_path=./ -------------------------------------------------------------------------------- /scripts/experiments/sgiw_pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | declare -A map=( ["pong"]="Pong" ["breakout"]="Breakout" ["up_n_down"]="UpNDown" ["kangaroo"]="Kangaroo" ["bank_heist"]="BankHeist" ["assault"]="Assault" ["boxing"]="Boxing" ["battle_zone"]="BattleZone" ["frostbite"]="Frostbite" ["crazy_climber"]="CrazyClimber" ["chopper_command"]="ChopperCommand" ["demon_attack"]="DemonAttack" ["alien"]="Alien" ["kung_fu_master"]="KungFuMaster" ["qbert"]="Qbert" ["ms_pacman"]="MsPacman" ["hero"]="Hero" ["seaquest"]="Seaquest" ["jamesbond"]="Jamesbond" ["amidar"]="Amidar" ["asterix"]="Asterix" ["private_eye"]="PrivateEye" ["gopher"]="Gopher" ["krull"]="Krull" ["freeway"]="Freeway" ["road_runner"]="RoadRunner" ) 3 | export game=$1 4 | shift 5 | export seed=$1 6 | 7 | 8 | #201, 301, 401, 501 denote data from different seeds in the offline dqn dataset 9 | python -m scripts.run public=True model_folder=./ offline.runner.save_every=2500 \ 10 | env.game=pong seed=1 offline_model_save=sgiw_${game}_resnet_${seed} \ 11 | offline.runner.epochs=20 offline.runner.dataloader.games=[${map[${game}]}] \ 12 | offline.runner.no_eval=1 \ 13 | +offline.algo.goal_weight=1 \ 14 | +offline.algo.inverse_model_weight=1 \ 15 | +offline.algo.spr_weight=1 \ 16 | +offline.algo.target_update_tau=0.01 \ 17 | +offline.agent.model_kwargs.momentum_tau=0.01 \ 18 | do_online=False \ 19 | algo.batch_size=256 \ 20 | +offline.agent.model_kwargs.noisy_nets_std=0 \ 21 | offline.runner.dataloader.dataset_on_disk=True \ 22 | offline.runner.dataloader.samples=1000000 \ 23 | offline.runner.dataloader.checkpoints='[1,201,301,401,501]' \ 24 | offline.runner.dataloader.num_workers=2 \ 25 | offline.runner.dataloader.data_path=./data/ \ 26 | offline.runner.dataloader.tmp_data_path=./ 27 | -------------------------------------------------------------------------------- /scripts/experiments/sgiml_pretrain.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | declare -A map=( ["pong"]="Pong" ["breakout"]="Breakout" ["up_n_down"]="UpNDown" ["kangaroo"]="Kangaroo" ["bank_heist"]="BankHeist" ["assault"]="Assault" ["boxing"]="Boxing" ["battle_zone"]="BattleZone" ["frostbite"]="Frostbite" ["crazy_climber"]="CrazyClimber" ["chopper_command"]="ChopperCommand" ["demon_attack"]="DemonAttack" ["alien"]="Alien" ["kung_fu_master"]="KungFuMaster" ["qbert"]="Qbert" ["ms_pacman"]="MsPacman" ["hero"]="Hero" ["seaquest"]="Seaquest" ["jamesbond"]="Jamesbond" ["amidar"]="Amidar" ["asterix"]="Asterix" ["private_eye"]="PrivateEye" ["gopher"]="Gopher" ["krull"]="Krull" ["freeway"]="Freeway" ["road_runner"]="RoadRunner" ) 3 | export game=$1 4 | shift 5 | export seed=$1 6 | 7 | python -m scripts.run public=True model_folder=./ offline.runner.save_every=2500 \ 8 | env.game=pong seed=1 offline_model_save=sgiml_${game}_resnet_${seed} \ 9 | agent.model_kwargs.blocks_per_group=5 agent.model_kwargs.expand_ratio=4 \ 10 | agent.model_kwargs.cnn_scale_factor=1.5 \ 11 | offline.runner.epochs=10 offline.runner.dataloader.games=[${map[${game}]}] \ 12 | offline.runner.no_eval=1 \ 13 | +offline.algo.goal_weight=1 \ 14 | +offline.algo.inverse_model_weight=1 \ 15 | +offline.algo.spr_weight=1 \ 16 | +offline.algo.target_update_tau=0.01 \ 17 | +offline.agent.model_kwargs.momentum_tau=0.01 \ 18 | do_online=False \ 19 | algo.batch_size=256 \ 20 | +offline.agent.model_kwargs.noisy_nets_std=0 \ 21 | offline.runner.dataloader.dataset_on_disk=True \ 22 | offline.runner.dataloader.samples=1000000 \ 23 | offline.runner.dataloader.checkpoints='[1,5,10,15,20,25]' \ 24 | offline.runner.dataloader.num_workers=2 \ 25 | offline.runner.dataloader.data_path=./data/ \ 26 | offline.runner.dataloader.tmp_data_path=./ 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | tests 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | .idea 108 | *.pt 109 | tmp/ 110 | wandb 111 | philly/ 112 | results/ 113 | 114 | # VS Code 115 | .vscode/ 116 | 117 | # Hydra 118 | .hydra/ 119 | 120 | # vim 121 | *.swp 122 | -------------------------------------------------------------------------------- /scripts/download_replay_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | games='AirRaid Alien Amidar Assault Asterix Asteroids Atlantis BankHeist BattleZone BeamRider Berzerk Bowling Boxing Breakout Carnival Centipede ChopperCommand CrazyClimber DemonAttack DoubleDunk ElevatorAction Enduro FishingDerby Freeway Frostbite Gopher Gravitar Hero IceHockey Jamesbond JourneyEscape Kangaroo Krull KungFuMaster MontezumaRevenge MsPacman NameThisGame Phoenix Pitfall Pong Pooyan PrivateEye Qbert Riverraid RoadRunner Robotank Seaquest Skiing Solaris SpaceInvaders StarGunner Tennis TimePilot Tutankham UpNDown Venture VideoPinball WizardOfWor YarsRevenge Zaxxon' 3 | ckpts='1 5 10 15 20 25 35 50' 4 | runs='1 2 3 4 5' 5 | files='action observation reward terminal' 6 | export data_dir=$1 7 | 8 | echo "Missing Files:" 9 | for g in ${games[@]}; do 10 | for f in ${files[@]}; do 11 | for c in ${ckpts[@]}; do 12 | if [ ! -f "${data_dir}/${g}/${f}_${c}.gz" ]; then 13 | echo "${data_dir}/${g}/${f}_${c}.gz" 14 | fi; 15 | done; 16 | for r in ${runs[@]}; do 17 | if [ ! -f "${data_dir}/${g}/${f}_${r}01.gz" ]; then 18 | echo "${data_dir}/${g}/${f}_${r}01.gz" 19 | fi; 20 | done; 21 | done; 22 | done; 23 | 24 | # https://stackoverflow.com/a/226724 25 | echo "Do you wish to download missing files?" 26 | select yn in "Yes" "No"; do 27 | case $yn in 28 | Yes ) break;; 29 | No ) exit;; 30 | esac 31 | done 32 | 33 | for g in ${games[@]}; do 34 | mkdir -p "${data_dir}/${g}" 35 | for f in ${files[@]}; do 36 | for c in ${ckpts[@]}; do 37 | if [ ! -f "${data_dir}/${g}/${f}_${c}.gz" ]; then 38 | gsutil cp "gs://atari-replay-datasets/dqn/${g}/1/replay_logs/\$store\$_${f}_ckpt.${c}.gz" "${data_dir}/${g}/${f}_${c}.gz" 39 | fi; 40 | done; 41 | for r in ${runs[@]}; do 42 | if [ ! 
-f "${data_dir}/${g}/${f}_${r}01.gz" ]; then 43 | gsutil cp "gs://atari-replay-datasets/dqn/${g}/${r}/replay_logs/\$store\$_${f}_ckpt.1.gz" "${data_dir}/${g}/${f}_${r}01.gz" 44 | fi; 45 | done; 46 | done; 47 | done; 48 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source $HOME/.bashrc 3 | module load python/3.8.2 4 | [[ -d '/project/' ]] && module load StdEnv/2020 gcc/9.3.0 cuda/11.0 opencv/4.5.1 5 | if [[ $(hostname) != *"blg"* && $(hostname) != *"cdr"* && $(hostname) != *"gra"* ]]; then 6 | module load python/3.7/cuda/11.0/cudnn/8.0/pytorch/1.7.0 7 | fi 8 | 9 | echo Running on $HOSTNAME 10 | nvidia-smi 11 | date 12 | 13 | PROJECT_DIR=$HOME/transfer-ssl-rl 14 | if [[ $(hostname) == *"cedar"* && ${USER} != 'schwarzm' ]]; then 15 | PROJECT_DIR=${SCRATCH}/transfer-ssl-rl 16 | elif [[ ${USER} == 'schwarzm' && $(hostname) == *"cedar"* ]]; then 17 | PROJECT_DIR=${SCRATCH}/Github/transfer-ssl-rl/ 18 | elif [[ ${USER} == 'schwarzm' ]]; then 19 | PROJECT_DIR=${HOME}/Github/transfer-ssl-rl/ 20 | fi 21 | 22 | cd ${PROJECT_DIR} 23 | mkdir -p models 24 | mkdir -p ${SLURM_TMPDIR}/atari 25 | export TMP_DATA_DIR=${SLURM_TMPDIR}/atari 26 | 27 | #Set up virtualenv 28 | echo 'Installing dependencies' 29 | python -m venv ${SLURM_TMPDIR}/env 30 | source ${SLURM_TMPDIR}/env/bin/activate 31 | if [[ $(hostname) == *"blg"* || $(hostname) == *"cedar"* || $(hostname) == *"gra"* ]]; then 32 | pip install --no-index -U pip 33 | pip install --no-index --find-links=/scratch/schwarzm/wheels_38 -r scripts/requirements_cc.txt 34 | else 35 | pip install -U pip 36 | pip install -r requirements.txt 37 | fi 38 | #conda activate pytorch 39 | #export PATH=~/anaconda3/envs/pytorch/bin:~/miniconda3/envs/pytorch/bin:$PATH 40 | 41 | # Set default data directories for reading and writing 42 | if [[ -d '/network/' ]]; then # mc 43 | export DATA_DIR=/network/tmp1/rajkuman/atari/ 44 | export USER_DATA_DIR=/network/tmp1/${USER}/atari/ 45 | python -m atari_py.import_roms Roms/ # need to manually load on MC for some reason 46 | elif [[ $(hostname) == *"cedar"* ]]; then # cc 47 | export DATA_DIR=/project/rrg-bengioy-ad/rajkuman/atari 48 | export USER_DATA_DIR=/project/rrg-bengioy-ad/${USER}/atari 49 | echo 'Setting W&B to offline' 50 | wandb off 51 | export WANDB_MODE=dryrun 52 | elif [[ $(hostname) == *"blg"* ]]; then # cc 53 | export DATA_DIR=/project/rrg-bengioy-ad/rajkuman/atari 54 | export USER_DATA_DIR=/project/rrg-bengioy-ad/${USER}/atari 55 | 56 | echo 'Setting W&B to offline' 57 | wandb off 58 | export WANDB_MODE=dryrun 59 | elif [[ $(hostname) == *"gra"* ]]; then # cc 60 | export DATA_DIR=/project/rrg-bengioy-ad/rajkuman/atari 61 | export USER_DATA_DIR=/project/def-bengioy/${USER}/atari 62 | 63 | echo 'Setting W&B to offline' 64 | wandb off 65 | export WANDB_MODE=dryrun 66 | fi 67 | 68 | mkdir -p ${USER_DATA_DIR} 69 | 70 | echo 'Starting experiment' 71 | python -u -m scripts.run wandb.dir=$PROJECT_DIR "$@" 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pretraining Representations For Data-Efficient Reinforcement Learning 2 | 3 | *Max Schwarzer, Nitarshan Rajkumar, Michael Noukhovitch, Ankesh Anand, Laurent Charlin, Devon Hjelm, Philip Bachman & Aaron Courville* 4 | 5 | This repo provides code for implementing SGI. 
6 | 7 | * [📦 Install ](#install) -- Install relevant dependencies and the project 8 | * [🔧 Usage ](#usage) -- Commands to run different experiments from the paper 9 | 10 | ## Install 11 | To install the requirements, follow these steps: 12 | ```bash 13 | # PyTorch 14 | export LANG=C.UTF-8 15 | # Install requirements 16 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html 17 | pip install -r requirements.txt 18 | 19 | # Finally, install the project 20 | pip install --user -e . 21 | ``` 22 | 23 | ## Usage: 24 | The default branch for the latest and stable changes is `release`. 25 | 26 | To run SGI: 27 | 1. Use the helper script to download and parse checkpoints from the [DQN Replay Dataset](https://research.google/tools/datasets/dqn-replay/); this requires [gsutil](https://cloud.google.com/storage/docs/gsutil_install#install) to be installed. You may want to modify the script to download fewer checkpoints from fewer games, as otherwise this requires significant storage. 28 | * Or substitute your own pre-training data! The codebase expects a series of .gz files, one each for observations, actions and terminals. 29 | ```bash 30 | bash scripts/download_replay_dataset.sh $DATA_DIR 31 | ``` 32 | 2. To pretrain with SGI: 33 | ```bash 34 | python -m scripts.run public=True model_folder=./ offline.runner.save_every=2500 \ 35 | env.game=pong seed=1 offline_model_save={your model name} \ 36 | offline.runner.epochs=10 offline.runner.dataloader.games=[Pong] \ 37 | offline.runner.no_eval=1 \ 38 | +offline.algo.goal_weight=1 \ 39 | +offline.algo.inverse_model_weight=1 \ 40 | +offline.algo.spr_weight=1 \ 41 | +offline.algo.target_update_tau=0.01 \ 42 | +offline.agent.model_kwargs.momentum_tau=0.01 \ 43 | do_online=False \ 44 | algo.batch_size=256 \ 45 | +offline.agent.model_kwargs.noisy_nets_std=0 \ 46 | offline.runner.dataloader.dataset_on_disk=True \ 47 | offline.runner.dataloader.samples=1000000 \ 48 | offline.runner.dataloader.checkpoints='{your checkpoints}' \ 49 | offline.runner.dataloader.num_workers=2 \ 50 | offline.runner.dataloader.data_path={your data dir} \ 51 | offline.runner.dataloader.tmp_data_path=./ 52 | ``` 53 | 3. To fine-tune with SGI: 54 | ```bash 55 | python -m scripts.run public=True env.game=pong seed=1 num_logs=10 \ 56 | model_load={your_model_name} model_folder=./ \ 57 | algo.encoder_lr=0.000001 algo.q_l1_lr=0.00003 algo.clip_grad_norm=-1 algo.clip_model_grad_norm=-1 58 | ``` 59 | 60 | When reporting scores, we average across 10 fine-tuning seeds. 61 | 62 | `./scripts/experiments` contains a number of example configurations, including for SGI-M, SGI-M/L and SGI-W, for both pre-training and fine-tuning. 63 | Each of these scripts can be launched by providing a game and seed, e.g., `./scripts/experiments/sgim_pretrain.sh pong 1`. These scripts are provided primarily to illustrate the hyperparameters used for different experiments; you will likely need to modify the arguments in these scripts to point to your data and model directories. 64 | 65 | Data for SGI-R and SGI-E is not included due to its size, but can be re-generated locally. Contact us for details. 66 | 67 | ## What does each file do? 68 | 69 | . 70 | ├── scripts 71 | │ ├── run.py # The main runner script to launch jobs. 72 | │ ├── config.yaml # The hydra configuration file, listing hyperparameters and options. 73 | | ├── download_replay_dataset.sh # Helper script to download the DQN replay dataset. 
74 | | └── experiments # Configurations for various experiments done by SGI. 75 | | 76 | ├── src 77 | │ ├── agent.py # Implements the Agent API for action selection 78 | │ ├── algos.py # Distributional RL loss and optimization 79 | │ ├── models.py # Forward passes, network initialization. 80 | │ ├── networks.py # Network architecture and forward passes. 81 | │ ├── offline_dataset.py # Dataloader for offline data. 82 | │ ├── scrl.py # Utils for SGI's goal-conditioned RL objective. 83 | │ ├── rlpyt_atari_env.py # Slightly modified Atari env from rlpyt 84 | │ ├── rlpyt_utils.py # Utility methods that we use to extend rlpyt's functionality 85 | │ └── utils.py # Command line arguments and helper functions 86 | │ 87 | └── requirements.txt # Dependencies 88 | -------------------------------------------------------------------------------- /src/scrl.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def exp_distance(x, y, t): 7 | dist = norm_dist(x, y, t) 8 | return 1 - torch.exp(-dist) 9 | 10 | 11 | def norm_dist(x, y, t): 12 | x = F.normalize(x, p=2, dim=-1, eps=1e-3) 13 | y = F.normalize(y, p=2, dim=-1, eps=1e-3) 14 | dist = (x - y).pow(2).sum(-1) 15 | return t*dist 16 | 17 | 18 | def calculate_returns(states, 19 | goal, 20 | distance, 21 | gamma, 22 | nonterminal, 23 | distance_scale, 24 | reward_scale=10., 25 | all_to_all=False): 26 | """ 27 | :param states: (batch, jumps, dim) 28 | :param goal: (batch, dim) 29 | :param distance: distance function (state X state X scale -> R). 30 | :param gamma: rl discount gamma in [0, 1] 31 | :param nonterminal: 1 - done, (batch, jumps). 32 | :return: returns: discounted sum of rewards up to t, (batch, jumps); 33 | has shape (batch, batch, jumps) if all_to_all enabled 34 | """ 35 | nonterminal = nonterminal.transpose(0, 1) 36 | 37 | if all_to_all: 38 | states = states.unsqueeze(1) 39 | goal = goal.unsqueeze(0) 40 | nonterminal = nonterminal.unsqueeze(1) 41 | 42 | goal = goal.unsqueeze(-2) 43 | distances = distance(states, goal, distance_scale) 44 | deltas = distances[..., 0:-1] - distances[..., 1:] 45 | 46 | cum_discounts = nonterminal * gamma 47 | cum_discounts = cum_discounts.cumprod(-1) 48 | 49 | discounted_rewards = reward_scale*deltas*cum_discounts 50 | returns = discounted_rewards.cumsum(-1) 51 | 52 | if all_to_all: 53 | returns = returns.flatten(0, 1) 54 | 55 | return returns.transpose(0, 1) 56 | 57 | 58 | # Possible sampling schemes: 59 | # 1. Contrastive: sometimes sample future states in other trajectories 60 | # Advantage: goal is guaranteed to be a "legal" latent 61 | # Can just implement as taking other goals in batch. 62 | # Disadvantage: Might avoid some beneficial exploration in latent space. 63 | # 2. Purely HER: 64 | # Downside: likely not to be diverse enough 65 | # 3. HER+noise: Sample as HER but add noise. 66 | # Adds diversity, maybe not enough. 67 | # 4. At random: sample random normalized vectors. 68 | # Problem: Mostly unreachable, maybe not valid latents. 69 | # Preferred solution: mixture of all methods. 70 | def sample_goals(future_observations, encoder): 71 | """ 72 | :param future_observations: (batch, jumps, c, h, w) 73 | :param encoder: map from observations to latents. 74 | :return: goals, (batch, dim).
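    Note: as the sampling-scheme discussion above suggests, these HER-style goals are
    intended to be mixed with the other schemes -- see sample_goals_random, add_noise and
    permute_goals later in this file.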
75 | """ 76 | future_observations = future_observations.flatten(2, 3) 77 | target_time_steps = torch.randint(1, future_observations.shape[0], 78 | future_observations.shape[1:2], 79 | device=future_observations.device) 80 | target_time_steps = target_time_steps[None, :, None, None, None].expand(-1, -1, *future_observations.shape[2:]) 81 | 82 | target_obs = torch.gather(future_observations, 0, target_time_steps) 83 | 84 | goals = encoder(target_obs) 85 | return goals 86 | 87 | 88 | def sample_goals_random(batch_size, dim, device): 89 | goals = F.relu(torch.randn((batch_size, dim), device=device, dtype=torch.float)) 90 | return F.normalize(goals, dim=-1, eps=1e-3) 91 | 92 | 93 | def add_noise(goals, noise_weight=1): 94 | noise = F.normalize(torch.randn_like(goals), dim=-1, eps=1e-3) 95 | weights = torch.rand((goals.shape[0], 1), device=goals.device, dtype=goals.dtype)*noise_weight 96 | 97 | goals = weights * noise + (1 - weights)*goals 98 | return F.normalize(goals, dim=-1, eps=1e-3) 99 | 100 | 101 | def permute_goals(goals, permute_probability=0.2): 102 | """ 103 | :param goals: (batch, dim) matrix of goal states 104 | :param permute_probability: p in [0, 1] of permuting goals. 105 | :return: (batch, dim) permuted goals. 106 | """ 107 | if permute_probability <= 0: 108 | return goals 109 | original_indices = torch.arange(0, goals.shape[0], device=goals.device, dtype=torch.long) 110 | indices = torch.randint_like(original_indices, 0, goals.shape[0]) 111 | 112 | permute_mask = torch.rand(indices.shape[0], device=goals.device) < permute_probability 113 | permute_mask = permute_mask.long() 114 | 115 | new_indices = permute_mask*indices + (1 - permute_mask)*original_indices 116 | 117 | goals = torch.gather(goals, 0, new_indices.unsqueeze(-1).expand(-1, goals.size(-1))) 118 | 119 | return goals 120 | 121 | 122 | -------------------------------------------------------------------------------- /scripts/config.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ./ 4 | 5 | experiment_group: ~ 6 | num_logs: 10 7 | seed: 1 8 | model_folder: ./ 9 | model_load: ~ # filename to load a pretrained model from 10 | load_last: True 11 | offline_model_save: ~ # filename to save pretrain model to 12 | do_online: true 13 | full_action_set: false 14 | save_by_epoch: True 15 | save_last_only: True 16 | public: False 17 | 18 | agent: 19 | eps_eval: 0.001 20 | eps_final: 0 21 | eps_init: 1 22 | repeat_random_lambda: 0 23 | softmax_policy: False 24 | model_kwargs: 25 | aug_prob: 1 26 | augmentation: [shift, intensity] # [none, rrc, affine, crop, blur, shift, intensity] 27 | target_augmentation: [shift, intensity] 28 | eval_augmentation: [none] 29 | dqn_hidden_size: 512 30 | dropout: 0 31 | dueling: 1 32 | dynamics_blocks: 0 33 | goal_n_step: ${algo.goal_n_step} 34 | imagesize: 84 35 | jumps: ${algo.jumps} 36 | momentum_tau: 0.01 37 | encoder: resnet 38 | noisy_nets_std: 0.5 39 | resblock: inverted 40 | freeze_encoder: False 41 | expand_ratio: 2 42 | cnn_scale_factor: 1 43 | blocks_per_group: 3 # 5 for sgi-l 44 | ln_for_rl_head: false 45 | noisy_nets: 1 46 | norm_type: bn 47 | predictor: linear 48 | projection: q_l1 49 | share_l1: False 50 | q_l1_type: [value, advantage] # [noisy, value, advantage, relu] 51 | renormalize: max 52 | residual_tm: 0 53 | inverse_model: ${algo.inverse_model_weight} 54 | rl: ${algo.rl_weight} 55 | bc: ${algo.bc_weight} 56 | bc_from_values: True 57 | goal_rl: ${algo.goal_weight} 58 | goal_all_to_all: ${algo.goal_all_to_all} 59 | 
conv_goal: ${algo.conv_goal} 60 | spr: 1 61 | goal_conditioning_type: [goal_only,film] 62 | load_head_to: 1 63 | load_compat_mode: True 64 | algo: 65 | rl_weight: 1 66 | spr_weight: 5 67 | goal_weight: 0 68 | bc_weight: 0 69 | inverse_model_weight: 0 70 | discount: 0.99 71 | batch_size: 32 72 | offline: ${runner.epochs} 73 | clip_grad_norm: 10 74 | clip_model_grad_norm: 10 75 | eps_steps: 2001 76 | jumps: 5 77 | learning_rate: 0.0001 78 | encoder_lr: ~ 79 | q_l1_lr: ~ 80 | dynamics_model_lr: ~ 81 | min_steps_learn: 2000 82 | n_step_return: 10 83 | optim_kwargs: 84 | eps: 0.00015 85 | pri_alpha: 0.5 86 | pri_beta_steps: 100000 87 | prioritized_replay: 1 88 | replay_ratio: 64 89 | target_update_interval: 1 90 | target_update_tau: 1 91 | goal_permute_prob: 0.2 92 | goal_noise_weight: 0.5 93 | goal_reward_scale: 10. 94 | goal_dist: exp 95 | goal_n_step: 1 96 | goal_window: 50 97 | goal_all_to_all: False 98 | conv_goal: True 99 | data_writer_args: 100 | game: ${env.game} 101 | data_dir: None 102 | save_data: False 103 | checkpoint_size: 1000000 104 | imagesize: [84,84] 105 | save_name: random 106 | mmap: False 107 | context: 108 | log_dir: logs 109 | run_ID: 0 110 | log_params: 111 | game: ${env.game} 112 | name: ${env.game} 113 | snapshot_mode: last 114 | override_prefix: true 115 | env: 116 | game: ms_pacman 117 | grayscale: 1 118 | imagesize: 84 119 | num_img_obs: 4 120 | seed: ${seed} 121 | full_action_set: ${full_action_set} 122 | eval_env: ${env} 123 | runner: 124 | affinity: 125 | cuda_idx: 0 126 | final_eval_only: 1 127 | no_eval: 0 128 | n_steps: 100000 129 | seed: ${seed} 130 | epochs: 0 131 | save_every: ~ 132 | dataloader: ${offline.runner.dataloader} 133 | sampler: 134 | batch_B: 1 135 | batch_T: 1 136 | env_kwargs: ${env} 137 | eval_env_kwargs: ${eval_env} 138 | eval_max_steps: 2800000 # 28k is just a safe ceiling 139 | eval_max_trajectories: 100 140 | eval_n_envs: 100 141 | max_decorrelation_steps: 0 142 | wandb: 143 | dir: '' 144 | entity: '' 145 | project: SGI 146 | tags: [] 147 | 148 | # Offline training 149 | # Can be overridden via CLI arguments to run.py 150 | offline: 151 | agent: 152 | model_kwargs: 153 | freeze_encoder: false 154 | algo: 155 | min_steps_learn: 0 156 | rl_weight: 0 157 | runner: 158 | epochs: 0 159 | save_every: 5000 160 | no_eval: 1 161 | dataloader: 162 | data_path: ./data/ 163 | tmp_data_path: ./data/ 164 | games: [MsPacman] 165 | checkpoints: [1,25,50] 166 | frames: ${env.num_img_obs} 167 | samples: 1000000 168 | jumps: ${algo.jumps} 169 | n_step_return: ${algo.n_step_return} 170 | discount: ${algo.discount} 171 | dataset_on_gpu: false 172 | dataset_on_disk: false 173 | batch_size: ${algo.batch_size} 174 | full_action_set: ${full_action_set} 175 | num_workers: 0 176 | pin_memory: false 177 | prefetch_factor: 2 178 | group_read_factor: 0 179 | shuffle_checkpoints: False 180 | -------------------------------------------------------------------------------- /src/rlpyt_buffer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | import torch 4 | 5 | from rlpyt.replays.sequence.prioritized import SamplesFromReplayPri 6 | 7 | from rlpyt.replays.sequence.n_step import SamplesFromReplay 8 | from rlpyt.replays.sequence.frame import AsyncPrioritizedSequenceReplayFrameBuffer, \ 9 | AsyncUniformSequenceReplayFrameBuffer, PrioritizedSequenceReplayFrameBuffer 10 | from rlpyt.utils.buffer import torchify_buffer, numpify_buffer 11 | from rlpyt.utils.collections import 
namedarraytuple 12 | from rlpyt.utils.misc import extract_sequences 13 | import traceback 14 | 15 | PrioritizedSamples = namedarraytuple("PrioritizedSamples", 16 | ["samples", "priorities"]) 17 | SamplesToBuffer = namedarraytuple("SamplesToBuffer", 18 | ["observation", "action", "reward", "done", "policy_probs", "value"]) 19 | SamplesFromReplayExt = namedarraytuple("SamplesFromReplayPriExt", 20 | SamplesFromReplay._fields + ("values", "age")) 21 | SamplesFromReplayPriExt = namedarraytuple("SamplesFromReplayPriExt", 22 | SamplesFromReplayPri._fields + ("values", "age")) 23 | EPS = 1e-6 24 | 25 | 26 | def samples_to_buffer(observation, action, reward, done, policy_probs, value, priorities=None): 27 | samples = SamplesToBuffer( 28 | observation=observation, 29 | action=action, 30 | reward=reward, 31 | done=done, 32 | policy_probs=policy_probs, 33 | value=value 34 | ) 35 | if priorities is not None: 36 | return PrioritizedSamples(samples=samples, 37 | priorities=priorities) 38 | else: 39 | return samples 40 | 41 | def sanitize_batch(batch): 42 | has_dones, inds = torch.max(batch.done, 0) 43 | for i, (has_done, ind) in enumerate(zip(has_dones, inds)): 44 | if not has_done: 45 | continue 46 | batch.all_observation[ind+1:, i] = batch.all_observation[ind, i] 47 | batch.all_reward[ind+1:, i] = 0 48 | batch.return_[ind+1:, i] = 0 49 | batch.done_n[ind+1:, i] = True 50 | return batch 51 | 52 | 53 | class AsyncUniformSequenceReplayFrameBufferExtended(AsyncUniformSequenceReplayFrameBuffer): 54 | """ 55 | Extends AsyncPrioritizedSequenceReplayFrameBuffer to return policy_logits and values too during sampling. 56 | """ 57 | def sample_batch(self, batch_B): 58 | while True: 59 | try: 60 | self._async_pull() # Updates from writers. 61 | batch_T = self.batch_T 62 | T_idxs, B_idxs = self.sample_idxs(batch_B, batch_T) 63 | sampled_indices = True 64 | if self.rnn_state_interval > 1: 65 | T_idxs = T_idxs * self.rnn_state_interval 66 | batch = self.extract_batch(T_idxs, B_idxs, self.batch_T) 67 | 68 | except Exception as _: 69 | print("FAILED TO LOAD BATCH") 70 | if sampled_indices: 71 | print("B_idxs:", B_idxs, flush=True) 72 | print("T_idxs:", T_idxs, flush=True) 73 | print("Batch_T:", self.batch_T, flush=True) 74 | print("Buffer T:", self.T, flush=True) 75 | 76 | elapsed_iters = self.t + self.T - T_idxs % self.T 77 | # elapsed_samples = self.B*(elapsed_iters) 78 | # values = torch.from_numpy(extract_sequences(self.samples.value, T_idxs, B_idxs, self.batch_T+self.n_step_return+1)) 79 | # batch = SamplesFromReplayExt(*batch, values=values, age=elapsed_samples) 80 | if self.batch_T > 1: 81 | batch = sanitize_batch(batch) 82 | return batch 83 | 84 | 85 | class AsyncPrioritizedSequenceReplayFrameBufferExtended(AsyncPrioritizedSequenceReplayFrameBuffer): 86 | """ 87 | Extends AsyncPrioritizedSequenceReplayFrameBuffer to return policy_logits and values too during sampling. 88 | """ 89 | def sample_batch(self, batch_B, batch_T=None): 90 | if batch_T is None: 91 | batch_T = self.batch_T 92 | while True: 93 | try: 94 | self._async_pull() # Updates from writers. 
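                    # Sample (T, B) start indices from the priority tree in proportion to the
                    # stored priorities; the same priorities are reused further down to form
                    # importance weights, is_weights = (1 / (priority + 1e-5)) ** beta,
                    # normalized by their max so the largest weight is 1.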
95 | (T_idxs, B_idxs), priorities = self.priority_tree.sample( 96 | batch_B, unique=self.unique) 97 | sampled_indices = True 98 | if self.rnn_state_interval > 1: 99 | T_idxs = T_idxs * self.rnn_state_interval 100 | 101 | batch = self.extract_batch(T_idxs, B_idxs, batch_T) 102 | 103 | except Exception as _: 104 | print("FAILED TO LOAD BATCH") 105 | traceback.print_exc() 106 | if sampled_indices: 107 | print("B_idxs:", B_idxs, flush=True) 108 | print("T_idxs:", T_idxs, flush=True) 109 | print("Batch_T:", batch_T, flush=True) 110 | print("Buffer T:", self.T, flush=True) 111 | 112 | is_weights = (1. / (priorities + 1e-5)) ** self.beta 113 | is_weights /= max(is_weights) # Normalize. 114 | is_weights = torchify_buffer(is_weights).float() 115 | 116 | # elapsed_iters = self.t + self.T - T_idxs % self.T 117 | # elapsed_samples = self.B*(elapsed_iters) 118 | # values = torch.from_numpy(extract_sequences(self.samples.value, T_idxs, B_idxs, batch_T+self.n_step_return+1)) 119 | batch = SamplesFromReplayPri(*batch, is_weights=is_weights,) 120 | if self.batch_T > 1: 121 | batch = sanitize_batch(batch) 122 | return batch 123 | -------------------------------------------------------------------------------- /scripts/run.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, Optional 3 | import os 4 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 5 | 6 | import hydra 7 | import numpy as np 8 | from omegaconf import DictConfig, OmegaConf 9 | import torch 10 | import wandb 11 | 12 | from rlpyt.envs.atari.atari_env import AtariTrajInfo 13 | from rlpyt.experiments.configs.atari.dqn.atari_dqn import configs 14 | from rlpyt.samplers.serial.sampler import SerialSampler 15 | from rlpyt.utils.logging.context import logger_context 16 | 17 | from src.agent import SPRAgent 18 | from src.algos import SPRCategoricalDQN 19 | from src.models import SPRCatDqnModel 20 | from src.rlpyt_atari_env import AtariEnv 21 | from src.rlpyt_utils import OnlineEval, OfflineEval, OneToOneSerialEvalCollector, SerialSampler 22 | from src.utils import get_last_save, save_model_fn 23 | 24 | 25 | @hydra.main(config_name="config") 26 | def main(args: DictConfig): 27 | args = OmegaConf.merge(configs['ernbw'], args) 28 | print(OmegaConf.to_yaml(args)) 29 | 30 | # Offline pretraining 31 | state_dict = None 32 | start_itr = 0 33 | if args.model_load is not None: 34 | if args.load_last: 35 | state_dict, start_itr = get_last_save(f'{args.model_folder}/{args.model_load}') 36 | print("Loading checkpoint from {}".format(start_itr)) 37 | else: 38 | state_dict = torch.load(Path(f'{args.model_folder}/{args.model_load}.pt')) 39 | elif args.offline_model_save is not None: 40 | try: 41 | state_dict, start_itr = get_last_save(f'{args.model_folder}/{args.offline_model_save}_{args.seed}') 42 | print("Loaded in old model to resume") 43 | except Exception as e: 44 | print(e) 45 | print("Could not load existing pretraining model; not continuing") 46 | state_dict = None 47 | start_itr = 0 48 | 49 | if state_dict is not None: 50 | if "model" in state_dict: 51 | model_state_dict = state_dict["model"] 52 | optim_state_dict = state_dict["optim"] 53 | else: 54 | model_state_dict = state_dict 55 | optim_state_dict = None 56 | else: 57 | model_state_dict = optim_state_dict = None 58 | 59 | if args.offline.runner.epochs > 0: 60 | print("Offline pretraining") 61 | offline_args = OmegaConf.merge(args, args.offline) 62 | print(OmegaConf.to_yaml(offline_args, resolve=True)) 63 | 
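        # offline_args is the base config with the `offline:` block merged on top, so offline
        # runner/algo/agent settings override their online counterparts here. k_step, set a few
        # lines below, is the window of future steps each dataloader sample must contain:
        # jumps + 1 (raised to goal_window when goal_weight > 0 and goal_window is larger),
        # plus n_step_return for the return targets.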
config: Dict[str, Any] = OmegaConf.to_container(offline_args, resolve=True) 64 | 65 | dl_kwargs = config["runner"]["dataloader"] 66 | dl_kwargs['data_path'] = Path(dl_kwargs['data_path']) 67 | k_step_base = dl_kwargs['jumps']+1 68 | if config["algo"]["goal_weight"] > 0: 69 | k_step_base = max(k_step_base, config["algo"]['goal_window']) 70 | dl_kwargs['k_step'] = k_step_base + dl_kwargs['n_step_return'] 71 | 72 | if args.offline_model_save is not None: 73 | save_fn = save_model_fn(args.model_folder, args.offline_model_save, args.seed, args.save_by_epoch, args.save_last_only) 74 | else: 75 | save_fn = None 76 | 77 | config["algo"]["min_steps_learn"] = 0 78 | agent, _, _, _ = train(config, save_fn=save_fn, 79 | offline=True, 80 | state_dict=model_state_dict, 81 | optim_state_dict=optim_state_dict, 82 | start_itr=start_itr) 83 | 84 | model_state_dict = agent.model.state_dict() 85 | 86 | if args.runner.n_steps > 0 and args.do_online: 87 | print("Online training") 88 | print(OmegaConf.to_yaml(args, resolve=True)) 89 | config: Dict[str, Any] = OmegaConf.to_container(args, resolve=True) 90 | config["runner"]["log_interval_steps"] = args.runner.n_steps // args.num_logs 91 | if args.offline.runner.epochs > 0: 92 | config["model_load"] = config["offline_model_save"] 93 | _, _, _, _ = train(config, offline=False, 94 | state_dict=model_state_dict, 95 | optim_state_dict=optim_state_dict) 96 | 97 | 98 | def train(config: Dict[str, Any], *, offline: bool, 99 | state_dict: Optional[Dict[str, torch.Tensor]], 100 | optim_state_dict: Optional[Dict[str, torch.Tensor]], 101 | save_fn=None, 102 | start_itr=0): 103 | if config["public"]: 104 | wandb.init(config=config, project="SGI", group="offline" if offline else "online", reinit=True, anonymous="allow") 105 | else: 106 | wandb.init(config=config, **config["wandb"], group="offline" if offline else "online", reinit=True) 107 | np.random.seed(config["seed"]) 108 | torch.manual_seed(config["seed"]) 109 | 110 | if state_dict is not None: 111 | print("Initializing with pretrained model") 112 | config["agent"]["model_kwargs"]["state_dict"] = state_dict 113 | if offline and optim_state_dict is not None: 114 | print("Initializing optimizer with previous settings") 115 | config["algo"]["initial_optim_state_dict"] = optim_state_dict 116 | 117 | algo = SPRCategoricalDQN(**config["algo"]) # Run with defaults. 
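    # Standard rlpyt wiring: the algo owns the losses and optimizer, the agent wraps an
    # SPRCatDqnModel for action selection, and the SerialSampler steps AtariEnv instances;
    # an OfflineEval (pretraining) or OnlineEval (fine-tuning) runner then drives training.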
118 | agent = SPRAgent(ModelCls=SPRCatDqnModel, **config["agent"]) 119 | sampler = SerialSampler( 120 | EnvCls=AtariEnv, 121 | TrajInfoCls=AtariTrajInfo, # default traj info + GameScore 122 | eval_CollectorCls=OneToOneSerialEvalCollector, 123 | **config["sampler"], 124 | ) 125 | runner_type = OfflineEval if offline else OnlineEval 126 | runner = runner_type(algo=algo, agent=agent, sampler=sampler, save_fn=save_fn, start_itr=start_itr, **config["runner"]) 127 | 128 | with logger_context(**config["context"]): 129 | runner.train() 130 | 131 | return algo, agent, sampler, runner 132 | 133 | 134 | if __name__ == "__main__": 135 | main() 136 | -------------------------------------------------------------------------------- /src/agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from torch.distributions.categorical import Categorical 5 | from rlpyt.agents.dqn.atari.atari_catdqn_agent import AtariCatDqnAgent 6 | from rlpyt.utils.buffer import buffer_to 7 | from rlpyt.utils.collections import namedarraytuple 8 | AgentInputs = namedarraytuple("AgentInputs", 9 | ["observation", "prev_action", "prev_reward"]) 10 | AgentInfo = namedarraytuple("AgentInfo", "p") 11 | AgentStep = namedarraytuple("AgentStep", ["action", "agent_info"]) 12 | 13 | 14 | class SPRAgent(AtariCatDqnAgent): 15 | """Agent for Categorical DQN algorithm with search.""" 16 | 17 | def __init__(self, eval=False, repeat_random_lambda=0, softmax_policy=False, **kwargs): 18 | """Standard init, and set the number of probability atoms (bins).""" 19 | super().__init__(**kwargs) 20 | self.eval = eval 21 | self.repeat_random_lambda = repeat_random_lambda 22 | self.softmax_policy = softmax_policy 23 | 24 | def __call__(self, observation, prev_action, prev_reward, goal=None, train=False): 25 | """Returns Q-values for states/observations (with grad).""" 26 | if train: 27 | model_inputs = buffer_to((observation, prev_action, 28 | prev_reward, goal), 29 | device=self.device) 30 | return self.model(*model_inputs, train=train) 31 | else: 32 | device = observation.device 33 | prev_action = self.distribution.to_onehot(prev_action) 34 | model_inputs = buffer_to((observation, prev_action, 35 | prev_reward, goal), 36 | device=self.device) 37 | return self.model(*model_inputs).to(device) 38 | 39 | def target(self, observation, prev_action, prev_reward, goal=None): 40 | """Returns the target Q-values for states/observations.""" 41 | prev_action = self.distribution.to_onehot(prev_action) 42 | model_inputs = buffer_to((observation, prev_action, prev_reward, goal), 43 | device=self.device) 44 | target_q = self.target_model(*model_inputs) 45 | return target_q 46 | 47 | def initialize(self, 48 | env_spaces, 49 | share_memory=False, 50 | global_B=1, 51 | env_ranks=None): 52 | super().initialize(env_spaces, share_memory, global_B, env_ranks) 53 | # Overwrite distribution. 54 | self.search = SPRActionSelection(self.model, self.distribution, repeat_random_lambda=self.repeat_random_lambda, 55 | softmax_policy=self.softmax_policy) 56 | 57 | def to_device(self, cuda_idx=None): 58 | """Moves the model to the specified cuda device, if not ``None``. If 59 | sharing memory, instantiates a new model to preserve the shared (CPU) 60 | model. Agents with additional model components (beyond 61 | ``self.model``) for action-selection or for use during training should 62 | extend this method to move those to the device, as well. 
63 | 64 | Typically called in the runner during startup. 65 | """ 66 | super().to_device(cuda_idx) 67 | self.search.to_device(cuda_idx) 68 | self.search.network = self.model 69 | 70 | def eval_mode(self, itr): 71 | """Extend method to set epsilon for evaluation, using 1 for 72 | pre-training eval.""" 73 | super().eval_mode(itr) 74 | self.search.epsilon = self.distribution.epsilon 75 | self.search.network.head.set_sampling(False) 76 | self.itr = itr 77 | 78 | def sample_mode(self, itr): 79 | """Extend method to set epsilon for sampling (including annealing).""" 80 | super().sample_mode(itr) 81 | self.search.epsilon = self.distribution.epsilon 82 | self.search.network.head.set_sampling(True) 83 | self.itr = itr 84 | 85 | def train_mode(self, itr): 86 | super().train_mode(itr) 87 | self.search.network.head.set_sampling(True) 88 | self.itr = itr 89 | 90 | @torch.no_grad() 91 | def step(self, observation, prev_action, prev_reward): 92 | """Compute the discrete distribution for the Q-value for each 93 | action for each state/observation (no grad).""" 94 | action, p = self.search.run(observation.to(self.search.device)) 95 | p = p.cpu() 96 | action = action.cpu() 97 | 98 | agent_info = AgentInfo(p=p) 99 | action, agent_info = buffer_to((action, agent_info), device="cpu") 100 | return AgentStep(action=action, agent_info=agent_info) 101 | 102 | 103 | class SPRActionSelection(torch.nn.Module): 104 | def __init__(self, network, distribution, repeat_random_lambda=0, device="cpu", softmax_policy=False): 105 | super().__init__() 106 | self.network = network 107 | self.epsilon = distribution._epsilon 108 | self.device = device 109 | self.first_call = True 110 | self.softmax_policy = softmax_policy 111 | 112 | self.repeat_random_lambda = repeat_random_lambda 113 | self.repeats_remaining = 0 114 | 115 | def sample_random_action(self, high, size, device): 116 | if self.repeat_random_lambda == 0: 117 | return torch.randint(low=0, high=high, size=size, device=device) 118 | elif self.repeats_remaining == 0: 119 | self.random_action = torch.randint(low=0, high=high, size=size, device=device) 120 | self.repeats_remaining = np.random.geometric(self.repeat_random_lambda) 121 | else: 122 | self.repeats_remaining -= 1 123 | 124 | return self.random_action 125 | 126 | def to_device(self, idx): 127 | self.device = idx 128 | 129 | @torch.no_grad() 130 | def run(self, obs): 131 | while len(obs.shape) <= 4: 132 | obs.unsqueeze_(0) 133 | obs = obs.to(self.device).float() / 255. 134 | 135 | # Don't even bother with the network if all actions will be random. 136 | if self.epsilon == 1: 137 | action = self.sample_random_action(high=self.network.num_actions, size=(obs.shape[0],), device=obs.device) 138 | value = torch.zeros(obs.shape[0], self.network.num_actions) 139 | else: 140 | value = self.network.select_action(obs) 141 | action = self.select_action(value) 142 | 143 | # Stupid, stupid hack because rlpyt does _not_ handle batch_b=1 well. 
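        # On the very first call the singleton batch dimension is squeezed away so that
        # rlpyt's initial batch_B=1 collection sees a scalar action; later calls leave the
        # batch dimension intact, and the value tensor is squeezed unconditionally below.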
144 | if self.first_call: 145 | action = action.squeeze() 146 | self.first_call = False 147 | return action, value.squeeze() 148 | 149 | def select_action(self, value): 150 | """Input can be shaped [T,B,Q] or [B,Q], and vector epsilon of length 151 | B will apply across the Batch dimension (same epsilon for all T).""" 152 | if self.softmax_policy: 153 | arg_select = Categorical(probs=F.softmax(value, -1)).sample() 154 | else: 155 | arg_select = torch.argmax(value, dim=-1) 156 | mask = torch.rand(arg_select.shape, device=value.device) < self.epsilon 157 | arg_rand = self.sample_random_action(high=value.shape[-1], size=(mask.sum(),), device=value.device) 158 | arg_select[mask] = arg_rand 159 | return arg_select 160 | -------------------------------------------------------------------------------- /src/efficient_sampling_offline_dataset.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | from pathlib import Path 3 | import re 4 | from typing import List, Tuple 5 | 6 | import numpy as np 7 | from rlpyt.utils.collections import namedarraytuple 8 | import torch 9 | from torch.utils.data import DataLoader, Dataset 10 | import os 11 | 12 | from .rlpyt_atari_env import AtariEnv 13 | 14 | from torch._six import int_classes as _int_classes 15 | from torch import Tensor 16 | 17 | from typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized 18 | 19 | T_co = TypeVar('T_co', covariant=True) 20 | 21 | 22 | OfflineSamples = namedarraytuple("OfflineSamples", ["all_observation", "all_action", "all_reward", "done", "done_n", "return_"]) 23 | 24 | class DQNReplayDataset(Dataset): 25 | def __init__(self, data_path: Path, tmp_data_path: Path, game: str, checkpoint: int, frames: int, k_step: int, max_size: int, full_action_set: bool, dataset_on_gpu: bool, dataset_on_disk: bool) -> None: 26 | data = [] 27 | self.dataset_on_disk = dataset_on_disk 28 | assert not (dataset_on_disk and dataset_on_gpu) 29 | for filetype in ['reward', 'action', 'terminal', 'observation']: 30 | filename = Path(data_path / f'{game}/{filetype}_{checkpoint}.gz') 31 | print(f'Loading {filename}') 32 | 33 | # There's no point in putting rewards, actions and terminals on disk. 34 | # They're tiny and it'll just cause more I/O. 35 | on_disk = dataset_on_disk and filetype == "observation" 36 | 37 | g = gzip.GzipFile(filename=filename) 38 | data__ = np.load(g) 39 | if filetype == "reward": 40 | self.has_parallel_envs = len(data__.shape) > 1 41 | if self.has_parallel_envs: 42 | self.n_envs = data__.shape[1] 43 | else: 44 | self.n_envs = 1 45 | if not self.has_parallel_envs: 46 | data__ = np.expand_dims(data__, 1) 47 | 48 | data___ = np.copy(data__[:max_size]) 49 | print(f'Using {data___.size * data___.itemsize} bytes') 50 | if not on_disk: 51 | del data__ 52 | data_ = torch.from_numpy(data___) 53 | else: 54 | new_filename = os.path.join(tmp_data_path, Path(os.path.basename(filename)[:-3]+".npy")) 55 | print("Stored on disk at {}".format(new_filename)) 56 | np.save(new_filename, data___,) 57 | del data___ 58 | del data__ 59 | data_ = np.load(new_filename, mmap_mode="r+") 60 | 61 | if (filetype == 'action') and full_action_set: 62 | action_mapping = dict(zip(data_.unique().numpy(), 63 | AtariEnv(re.sub(r'(? 
int: 82 | return self.effective_size*self.n_envs 83 | 84 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 85 | batch_ind = index // self.effective_size 86 | time_ind = index % self.effective_size 87 | sl = slice(time_ind, time_ind+self.f+self.k) 88 | if self.dataset_on_disk: 89 | obs = torch.from_numpy(self.observations[sl, batch_ind]) 90 | else: 91 | obs = (self.observations[sl, batch_ind]) 92 | return tuple([obs, 93 | self.actions[sl, batch_ind], 94 | self.rewards[sl, batch_ind], 95 | self.terminal[sl, batch_ind], 96 | ]) 97 | 98 | 99 | class MultiDQNReplayDataset(Dataset): 100 | def __init__(self, data_path: Path, tmp_data_path: Path, games: List[str], checkpoints: List[int], frames: int, k_step: int, max_size: int, full_action_set: bool, dataset_on_gpu: bool, dataset_on_disk: bool) -> None: 101 | self.games = [DQNReplayDataset(data_path, tmp_data_path, game, ckpt, frames, k_step, max_size, full_action_set, dataset_on_gpu, dataset_on_disk) for ckpt in checkpoints for game in games] 102 | 103 | def __len__(self) -> int: 104 | return len(self.games) * len(self.games[0]) 105 | 106 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 107 | game_index = index % len(self.games) 108 | index = index // len(self.games) 109 | return self.games[game_index][index] 110 | 111 | 112 | def sanitize_batch(batch: OfflineSamples) -> OfflineSamples: 113 | has_dones, inds = torch.max(batch.done, 0) 114 | for i, (has_done, ind) in enumerate(zip(has_dones, inds)): 115 | if not has_done: 116 | continue 117 | batch.all_observation[ind+1:, i] = batch.all_observation[ind, i] 118 | batch.all_reward[ind+1:, i] = 0 119 | batch.return_[ind+1:, i] = 0 120 | batch.done_n[ind+1:, i] = True 121 | return batch 122 | 123 | 124 | def discount_return_n_step(reward, done, n_step, discount, return_dest=None, 125 | done_n_dest=None, do_truncated=False): 126 | """Time-major inputs, optional other dimension: [T], [T,B], etc. Computes 127 | n-step discounted returns within the timeframe of the of given rewards. If 128 | `do_truncated==False`, then only compute at time-steps with full n-step 129 | future rewards are provided (i.e. not at last n-steps--output shape will 130 | change!). Returns n-step returns as well as n-step done signals, which is 131 | True if `done=True` at any future time before the n-step target bootstrap 132 | would apply (bootstrap in the algo, not here).""" 133 | rlen = reward.shape[0] 134 | if not do_truncated: 135 | rlen -= (n_step - 1) 136 | return_ = torch.zeros( 137 | (rlen,) + reward.shape[1:], dtype=reward.dtype, device=reward.device) 138 | done_n = torch.zeros( 139 | (rlen,) + reward.shape[1:], dtype=done.dtype, device=done.device) 140 | return_[:] = reward[:rlen] # 1-step return is current reward. 141 | done_n[:] = done[:rlen] # True at time t if done any time by t + n - 1 142 | 143 | done_dtype = done.dtype 144 | done_n = done_n.type(reward.dtype) 145 | done = done.type(reward.dtype) 146 | 147 | if n_step > 1: 148 | if do_truncated: 149 | for n in range(1, n_step): 150 | return_[:-n] += (discount ** n) * reward[n:n + rlen] * (1 - done_n[:-n]) 151 | done_n[:-n] = torch.max(done_n[:-n], done[n:n + rlen]) 152 | else: 153 | for n in range(1, n_step): 154 | return_ += (discount ** n) * reward[n:n + rlen] * (1 - done_n) 155 | done_n = torch.max(done_n, done[n:n + rlen]) # Supports tensors. 
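    # For intuition, a small worked example (do_truncated=False): with reward = [1, 0, 2, 0, 0],
    # no dones, n_step = 3 and discount = 0.9, rlen = 5 - 2 = 3 and
    #   return_[0] = 1 + 0.9*0 + 0.81*2 = 2.62
    #   return_[1] = 0 + 0.9*2 + 0.81*0 = 1.80
    #   return_[2] = 2 + 0.9*0 + 0.81*0 = 2.00
    # A done inside a window stops strictly-later rewards in that window from contributing
    # and marks the corresponding done_n entry True.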
156 | done_n = done_n.type(done_dtype) 157 | return return_, done_n 158 | 159 | def get_offline_dataloaders( 160 | *, 161 | data_path: Path, 162 | tmp_data_path: Path, 163 | games: List[str], 164 | checkpoints: List[int], 165 | frames: int, 166 | k_step: int, 167 | n_step_return: int, 168 | discount: float, 169 | samples: int, 170 | test_game: str, 171 | test_samples: int, 172 | dataset_on_gpu: bool, 173 | dataset_on_disk: bool, 174 | batch_size: int, 175 | full_action_set: bool, 176 | num_workers: int, 177 | pin_memory: bool, 178 | prefetch_factor: int, 179 | **kwargs, 180 | ) -> Tuple[DataLoader, DataLoader, DataLoader]: 181 | def collate(batch): 182 | #batch = list(filter(lambda x: not x[3].any(), batch)) # filter samples with a terminal state 183 | observation, action, reward, done = torch.utils.data.dataloader.default_collate(batch) 184 | observation = torch.einsum('bthw->tbhw', observation).unsqueeze(2).repeat(1, 1, frames, 1, 1) 185 | for i in range(1, frames): 186 | observation[:, :, i] = observation[:, :, i].roll(-i, 0) 187 | observation = observation[:-frames].unsqueeze(3) # tbfchw 188 | action = torch.einsum('bt->tb', action)[frames-1:].long() 189 | reward = torch.einsum('bt->tb', reward)[frames:] 190 | done = torch.einsum('bt->tb', done)[frames:].bool() 191 | return_, done_n = discount_return_n_step(reward, done, n_step_return, discount) 192 | return sanitize_batch(OfflineSamples(observation, action, reward, done[:-n_step_return], done_n, return_)) 193 | 194 | dataset = MultiDQNReplayDataset(data_path, tmp_data_path, games, checkpoints, frames, k_step, samples, full_action_set, dataset_on_gpu, dataset_on_disk) 195 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory, collate_fn=collate, drop_last=True, prefetch_factor=prefetch_factor) 196 | 197 | #test_dataset = DQNReplayDataset(data_path, test_game, frames, k_step, test_samples, dataset_on_gpu) 198 | #test_dataloader = DataLoader(test_dataset, batch_size=100, shuffle=True, num_workers=num_workers, pin_memory=pin_memory, collate_fn=collate, drop_last=True) 199 | 200 | #random_dataset = DQNReplayDataset(data_path, test_game, frames, k_step, test_samples, dataset_on_gpu) 201 | #random_dataset.observations.random_(0, 255) 202 | #random_dataset.actions.random_(random_dataset.actions.min(), random_dataset.actions.max()) 203 | #random_dataloader = DataLoader(random_dataset, batch_size=100, shuffle=False, num_workers=num_workers, pin_memory=pin_memory, collate_fn=collate, drop_last=True) 204 | 205 | return dataloader, None, None #test_dataloader, random_dataloader 206 | 207 | 208 | class CacheEfficientSampler(torch.utils.data.Sampler): 209 | 210 | def __init__(self, num_blocks, block_len, num_repeats=20): 211 | self.num_blocks = num_blocks 212 | self.block_len = block_len # For now, assume all have same length 213 | self.num_repeats = num_repeats 214 | 215 | def num_samples(self) -> int: 216 | # dataset size might change at runtime 217 | return self.block_len*self.num_blocks 218 | 219 | def __iter__(self): 220 | n = self.num_samples() 221 | if self.generator is None: 222 | generator = torch.Generator() 223 | generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item())) 224 | else: 225 | generator = self.generator 226 | 227 | self.block_ids = [np.arange(self.num_blocks)] * (self.block_len // self.num_repeats) 228 | 229 | blocks = torch.randperm(n//self.num_repeats, generator=generator) % self.num_blocks 230 | 231 | subsamplers = 
[torch.utils.data.SubsetRandomSampler(torch.arange(i*self.block_len, (i+1)*self.block_len), generator=generator) for i in range(len(self.num_blocks))] 232 | 233 | for block in blocks: 234 | for i in range(self.num_repeats): 235 | yield from subsamplers[block] 236 | 237 | def __len__(self): 238 | return self.num_samples 239 | -------------------------------------------------------------------------------- /src/rlpyt_atari_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modifies the default rlpyt AtariEnv to be closer to DeepMind's setup, 3 | tries to follow Kaixin/Rainbow's env for the most part. 4 | """ 5 | import numpy as np 6 | import os 7 | import atari_py 8 | import cv2 9 | from collections import namedtuple 10 | from gym.utils import seeding 11 | 12 | from rlpyt.envs.base import Env, EnvStep 13 | from rlpyt.spaces.int_box import IntBox 14 | from rlpyt.utils.quick_args import save__init__args 15 | from rlpyt.samplers.collections import TrajInfo 16 | 17 | 18 | EnvInfo = namedtuple("EnvInfo", ["game_score", "traj_done"]) 19 | 20 | 21 | class AtariTrajInfo(TrajInfo): 22 | """TrajInfo class for use with Atari Env, to store raw game score separate 23 | from clipped reward signal.""" 24 | 25 | def __init__(self, **kwargs): 26 | super().__init__(**kwargs) 27 | self.GameScore = 0 28 | 29 | def step(self, observation, action, reward, done, agent_info, env_info): 30 | super().step(observation, action, reward, done, agent_info, env_info) 31 | self.GameScore += getattr(env_info, "game_score", 0) 32 | 33 | 34 | class AtariEnv(Env): 35 | """An efficient implementation of the classic Atari RL envrionment using the 36 | Arcade Learning Environment (ALE). 37 | 38 | Output `env_info` includes: 39 | * `game_score`: raw game score, separate from reward clipping. 40 | * `traj_done`: special signal which signals game-over or timeout, so that sampler doesn't reset the environment when ``done==True`` but ``traj_done==False``, which can happen when ``episodic_lives==True``. 41 | 42 | Always performs 2-frame max to avoid flickering (this is pretty fast). 43 | 44 | Screen size downsampling is done by cropping two rows and then 45 | downsampling by 2x using `cv2`: (210, 160) --> (80, 104). Downsampling by 46 | 2x is much faster than the old scheme to (84, 84), and the (80, 104) shape 47 | is fairly convenient for convolution filter parameters which don't cut off 48 | edges. 49 | 50 | The action space is an `IntBox` for the number of actions. The observation 51 | space is an `IntBox` with ``dtype=uint8`` to save memory; conversion to float 52 | should happen inside the agent's model's ``forward()`` method. 53 | 54 | (See the file for implementation details.) 55 | 56 | 57 | Args: 58 | game (str): game name 59 | frame_skip (int): frames per step (>=1) 60 | num_img_obs (int): number of frames in observation (>=1) 61 | clip_reward (bool): if ``True``, clip reward to np.sign(reward) 62 | episodic_lives (bool): if ``True``, output ``done=True`` but ``env_info[traj_done]=False`` when a life is lost 63 | max_start_noops (int): upper limit for random number of noop actions after reset 64 | repeat_action_probability (0-1): probability for sticky actions 65 | horizon (int): max number of steps before timeout / ``traj_done=True`` 66 | """ 67 | 68 | def __init__(self, 69 | game="pong", 70 | frame_skip=4, # Frames per step (>=1). 71 | num_img_obs=4, # Number of (past) frames in observation (>=1). 
72 | clip_reward=True, 73 | episodic_lives=True, 74 | max_start_noops=30, 75 | repeat_action_probability=0., 76 | horizon=27000, 77 | stack_actions=0, 78 | grayscale=True, 79 | imagesize=84, 80 | seed=42, 81 | id=0, 82 | full_action_set=False, 83 | ): 84 | save__init__args(locals(), underscore=True) 85 | # ALE 86 | game_path = atari_py.get_game_path(game) 87 | if not os.path.exists(game_path): 88 | raise IOError("You asked for game {} but path {} does not " 89 | " exist".format(game, game_path)) 90 | self.ale = atari_py.ALEInterface() 91 | self.seed(seed, id) 92 | self.ale.setFloat(b'repeat_action_probability', repeat_action_probability) 93 | self.ale.loadROM(game_path) 94 | 95 | # Spaces 96 | self.stack_actions = stack_actions 97 | self.full_action_set = full_action_set 98 | self._action_set = self.ale.getMinimalActionSet() if not full_action_set else self.ale.getLegalActionSet() 99 | self._action_space = IntBox(low=0, high=len(self._action_set)) 100 | self.channels = 1 if grayscale else 3 101 | self.grayscale = grayscale 102 | self.imagesize = imagesize 103 | if self.stack_actions: self.channels += 1 104 | obs_shape = (num_img_obs, self.channels, imagesize, imagesize) 105 | self._observation_space = IntBox(low=0, high=255, shape=obs_shape, 106 | dtype="uint8") 107 | self._max_frame = self.ale.getScreenGrayscale() if self.grayscale \ 108 | else self.ale.getScreenRGB() 109 | self._raw_frame_1 = self._max_frame.copy() 110 | self._raw_frame_2 = self._max_frame.copy() 111 | self._obs = np.zeros(shape=obs_shape, dtype="uint8") 112 | 113 | # Settings 114 | self._has_fire = "FIRE" in self.get_action_meanings() 115 | self._has_up = "UP" in self.get_action_meanings() 116 | self._horizon = int(horizon) 117 | self.reset() 118 | 119 | def seed(self, seed=None, id=0): 120 | _, seed1 = seeding.np_random(seed) 121 | if id > 0: 122 | seed = seed*100 + id 123 | self.np_random, _ = seeding.np_random(seed) 124 | # Derive a random seed. This gets passed as a uint, but gets 125 | # checked as an int elsewhere, so we need to keep it below 126 | # 2**31. 127 | seed2 = seeding.hash_seed(seed1 + 1) % 2**31 128 | # Empirically, we need to seed before loading the ROM. 129 | self.ale.setInt(b'random_seed', seed2) 130 | 131 | def reset(self): 132 | """Performs hard reset of ALE game.""" 133 | self.ale.reset_game() 134 | self._reset_obs() 135 | self._life_reset() 136 | if self._max_start_noops > 0: 137 | for _ in range(self.np_random.randint(1, self._max_start_noops + 1)): 138 | self.ale.act(0) 139 | if self._check_life(): 140 | self.reset() 141 | self._update_obs(0) # (don't bother to populate any frame history) 142 | self._step_counter = 0 143 | return self.get_obs() 144 | 145 | def step(self, action): 146 | a = self._action_set[action] 147 | game_score = np.array(0., dtype="float32") 148 | for _ in range(self._frame_skip - 1): 149 | game_score += self.ale.act(a) 150 | self._get_screen(1) 151 | game_score += self.ale.act(a) 152 | lost_life = self._check_life() # Advances from lost_life state. 153 | if lost_life and self._episodic_lives: 154 | self._reset_obs() # Internal reset. 
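            # With episodic_lives, losing a life clears the frame stack so the next
            # observation does not bleed across lives; `done` below is then reported
            # True while env_info.traj_done stays tied to the real game-over signal,
            # which lets the sampler skip a hard reset mid-game.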
155 | self._update_obs(action) 156 | reward = np.sign(game_score) if self._clip_reward else game_score 157 | game_over = self.ale.game_over() or self._step_counter >= self.horizon 158 | done = game_over or (self._episodic_lives and lost_life) 159 | info = EnvInfo(game_score=game_score, traj_done=game_over) 160 | self._step_counter += 1 161 | return EnvStep(self.get_obs(), reward, done, info) 162 | 163 | def render(self, wait=10, show_full_obs=False): 164 | """Shows game screen via cv2, with option to show all frames in observation.""" 165 | img = self.get_obs() 166 | if show_full_obs: 167 | shape = img.shape 168 | img = img.reshape(shape[0] * shape[1], shape[2]) 169 | else: 170 | img = img[-1] 171 | cv2.imshow(self._game, img) 172 | cv2.waitKey(wait) 173 | 174 | def get_obs(self): 175 | return self._obs.copy() 176 | 177 | ########################################################################### 178 | # Helpers 179 | 180 | def _get_screen(self, frame=1): 181 | frame = self._raw_frame_1 if frame == 1 else self._raw_frame_2 182 | if self.grayscale: 183 | self.ale.getScreenGrayscale(frame) 184 | else: 185 | self.ale.getScreenRGB(frame) 186 | 187 | def _update_obs(self, action): 188 | """Max of last two frames; crop two rows; downsample by 2x.""" 189 | self._get_screen(2) 190 | np.maximum(self._raw_frame_1, self._raw_frame_2, self._max_frame) 191 | img = cv2.resize(self._max_frame, (self.imagesize, self.imagesize), cv2.INTER_LINEAR) 192 | if len(img.shape) == 2: 193 | img = img[np.newaxis] 194 | else: 195 | img = np.transpose(img, (2, 0, 1)) 196 | if self.stack_actions: 197 | action = int(255.*action/self._action_space.n) 198 | action = np.ones_like(img[:1])*action 199 | img = np.concatenate([img, action], 0) 200 | # NOTE: order OLDEST to NEWEST should match use in frame-wise buffer. 
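        # The observation buffer is a fixed-length FIFO over the last num_img_obs
        # processed frames: dropping self._obs[0] and appending the newest frame keeps
        # index -1 as the most recent frame (render() shows that frame when
        # show_full_obs=False; get_obs() hands the agent a copy of the whole stack).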
201 | self._obs = np.concatenate([self._obs[1:], img[np.newaxis]]) 202 | 203 | def _reset_obs(self): 204 | self._obs[:] = 0 205 | self._max_frame[:] = 0 206 | self._raw_frame_1[:] = 0 207 | self._raw_frame_2[:] = 0 208 | 209 | def _check_life(self): 210 | lives = self.ale.lives() 211 | lost_life = (lives < self._lives) and (lives > 0) 212 | if lost_life: 213 | self._life_reset() 214 | return lost_life 215 | 216 | def _life_reset(self): 217 | self.ale.act(0) 218 | self._lives = self.ale.lives() 219 | 220 | ########################################################################### 221 | # Properties 222 | 223 | @property 224 | def game(self): 225 | return self._game 226 | 227 | @property 228 | def frame_skip(self): 229 | return self._frame_skip 230 | 231 | @property 232 | def num_img_obs(self): 233 | return self._num_img_obs 234 | 235 | @property 236 | def clip_reward(self): 237 | return self._clip_reward 238 | 239 | @property 240 | def max_start_noops(self): 241 | return self._max_start_noops 242 | 243 | @property 244 | def episodic_lives(self): 245 | return self._episodic_lives 246 | 247 | @property 248 | def repeat_action_probability(self): 249 | return self._repeat_action_probability 250 | 251 | @property 252 | def horizon(self): 253 | return self._horizon 254 | 255 | def get_action_meanings(self): 256 | return [ACTION_MEANING[i] for i in self._action_set] 257 | 258 | 259 | ACTION_MEANING = { 260 | 0: "NOOP", 261 | 1: "FIRE", 262 | 2: "UP", 263 | 3: "RIGHT", 264 | 4: "LEFT", 265 | 5: "DOWN", 266 | 6: "UPRIGHT", 267 | 7: "UPLEFT", 268 | 8: "DOWNRIGHT", 269 | 9: "DOWNLEFT", 270 | 10: "UPFIRE", 271 | 11: "RIGHTFIRE", 272 | 12: "LEFTFIRE", 273 | 13: "DOWNFIRE", 274 | 14: "UPRIGHTFIRE", 275 | 15: "UPLEFTFIRE", 276 | 16: "DOWNRIGHTFIRE", 277 | 17: "DOWNLEFTFIRE", 278 | } 279 | 280 | ACTION_INDEX = {v: k for k, v in ACTION_MEANING.items()} 281 | -------------------------------------------------------------------------------- /src/offline_dataset.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | from pathlib import Path 3 | import re 4 | from typing import List, Tuple 5 | 6 | import numpy as np 7 | from rlpyt.utils.collections import namedarraytuple 8 | import torch 9 | from torch.utils.data import DataLoader, Dataset 10 | import os 11 | 12 | from itertools import zip_longest 13 | from .rlpyt_atari_env import AtariEnv 14 | from src.utils import discount_return_n_step 15 | 16 | 17 | OfflineSamples = namedarraytuple("OfflineSamples", ["all_observation", "all_action", "all_reward", "return_", "done", "done_n", "init_rnn_state", "is_weights"]) 18 | 19 | class DQNReplayDataset(Dataset): 20 | def __init__(self, data_path: Path, 21 | tmp_data_path: Path, 22 | game: str, 23 | checkpoint: int, 24 | frames: int, 25 | k_step: int, 26 | max_size: int, 27 | full_action_set: bool, 28 | dataset_on_gpu: bool, 29 | dataset_on_disk: bool, 30 | load_reward: bool = False) -> None: 31 | data = [] 32 | self.dataset_on_disk = dataset_on_disk 33 | self.load_reward = load_reward 34 | assert not (dataset_on_disk and dataset_on_gpu) 35 | filetypes = ['reward', 'action', 'terminal', 'observation'] 36 | if not load_reward: 37 | filetypes = filetypes[1:] 38 | for i, filetype in enumerate(filetypes): 39 | filename = Path(data_path / f'{game}/{filetype}_{checkpoint}.gz') 40 | print(f'Loading {filename}') 41 | 42 | # There's no point in putting rewards, actions or terminals on disk. 43 | # They're tiny and it'll just cause more I/O. 
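            # Rough scale (an estimate, not measured here): a DQN Replay checkpoint
            # holds ~1M transitions, so 84x84 uint8 observations run to several GB,
            # while the scalar reward/action/terminal streams are only a few MB,
            # which is why only observations are ever memmapped to disk.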
44 | on_disk = dataset_on_disk and filetype == "observation" 45 | 46 | g = gzip.GzipFile(filename=filename) 47 | data__ = np.load(g) 48 | if i == 0: 49 | self.has_parallel_envs = len(data__.shape) > 1 50 | if self.has_parallel_envs: 51 | self.n_envs = data__.shape[1] 52 | else: 53 | self.n_envs = 1 54 | if not self.has_parallel_envs: 55 | data__ = np.expand_dims(data__, 1) 56 | 57 | data___ = np.copy(data__[:max_size]) 58 | print(f'Using {data___.size * data___.itemsize} bytes') 59 | if not on_disk: 60 | del data__ 61 | data_ = torch.from_numpy(data___) 62 | else: 63 | new_filename = os.path.join(tmp_data_path, Path(os.path.basename(filename)[:-3]+".npy")) 64 | print("Stored on disk at {}".format(new_filename)) 65 | np.save(new_filename, data___,) 66 | del data___ 67 | del data__ 68 | data_ = np.load(new_filename, mmap_mode="r+") 69 | 70 | if (filetype == 'action') and full_action_set: 71 | action_mapping = dict(zip(data_.unique().numpy(), 72 | AtariEnv(re.sub(r'(? int: 88 | return self.effective_size*self.n_envs 89 | 90 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 91 | batch_ind = index // self.effective_size 92 | time_ind = index % self.effective_size 93 | sl = slice(time_ind, time_ind+self.f+self.k) 94 | if self.dataset_on_disk: 95 | obs = torch.from_numpy(self.observation[sl, batch_ind]) 96 | else: 97 | obs = (self.observation[sl, batch_ind]) 98 | 99 | # Buffer reading code still expects us to return rewards even when not used 100 | if self.load_reward: 101 | rewards = self.reward[sl, batch_ind] 102 | else: 103 | rewards = torch.zeros_like(self.terminal[sl, batch_ind]).float() 104 | return tuple([obs, 105 | self.action[sl, batch_ind], 106 | rewards, 107 | self.terminal[sl, batch_ind], 108 | ]) 109 | 110 | 111 | class MultiDQNReplayDataset(Dataset): 112 | def __init__(self, data_path: Path, 113 | tmp_data_path: 114 | Path, games: List[str], 115 | checkpoints: List[int], 116 | frames: int, 117 | k_step: int, 118 | max_size: int, 119 | full_action_set: bool, 120 | dataset_on_gpu: bool, 121 | dataset_on_disk: bool) -> None: 122 | self.games = [DQNReplayDataset(data_path, 123 | tmp_data_path, 124 | game, 125 | ckpt, 126 | frames, 127 | k_step, 128 | max_size, 129 | full_action_set, 130 | dataset_on_gpu, 131 | dataset_on_disk) for ckpt in checkpoints for game in games] 132 | 133 | self.num_blocks = len(self.games) 134 | self.block_len = len(self.games[0]) 135 | 136 | def __len__(self) -> int: 137 | return len(self.games) * len(self.games[0]) 138 | 139 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 140 | game_index = index % len(self.games) 141 | index = index // len(self.games) 142 | return self.games[game_index][index] 143 | 144 | 145 | def sanitize_batch(batch: OfflineSamples) -> OfflineSamples: 146 | has_dones, inds = torch.max(batch.done, 0) 147 | for i, (has_done, ind) in enumerate(zip(has_dones, inds)): 148 | if not has_done: 149 | continue 150 | batch.all_observation[ind+1:, i] = batch.all_observation[ind, i] 151 | batch.all_reward[ind+1:, i] = 0 152 | batch.return_[ind+1:, i] = 0 153 | batch.done_n[ind+1:, i] = True 154 | return batch 155 | 156 | 157 | def get_offline_dataloaders( 158 | *, 159 | data_path: Path, 160 | tmp_data_path: Path, 161 | games: List[str], 162 | checkpoints: List[int], 163 | frames: int, 164 | k_step: int, 165 | n_step_return: int, 166 | discount: float, 167 | samples: int, 168 | dataset_on_gpu: bool, 169 | dataset_on_disk: bool, 170 | batch_size: int, 171 | full_action_set: 
bool, 172 | num_workers: int, 173 | pin_memory: bool, 174 | prefetch_factor: int, 175 | group_read_factor: int=0, 176 | shuffle_checkpoints: bool=False, 177 | **kwargs, 178 | ) -> Tuple[DataLoader, DataLoader, DataLoader]: 179 | def collate(batch): 180 | observation, action, reward, done = torch.utils.data.dataloader.default_collate(batch) 181 | observation = torch.einsum('bthw->tbhw', observation).unsqueeze(2).repeat(1, 1, frames, 1, 1) 182 | for i in range(1, frames): 183 | observation[:, :, i] = observation[:, :, i].roll(-i, 0) 184 | observation = observation[:-frames].unsqueeze(3) # tbfchw 185 | action = torch.einsum('bt->tb', action)[frames-2:-2].long() 186 | reward = torch.einsum('bt->tb', reward)[frames-2:-2] 187 | reward = torch.nan_to_num(reward).sign() # Apparently possible, somehow. 188 | done = torch.einsum('bt->tb', done)[frames:].bool() 189 | return_, done_n = discount_return_n_step(reward[1:], done, n_step_return, discount) 190 | is_weights = torch.ones(observation.shape[1]).to(reward) 191 | return sanitize_batch(OfflineSamples(observation, action, reward, return_, done[:-n_step_return], done_n, None, is_weights)) 192 | 193 | dataset = MultiDQNReplayDataset(data_path, tmp_data_path, games, checkpoints, frames, k_step, samples, full_action_set, dataset_on_gpu, dataset_on_disk) 194 | 195 | if shuffle_checkpoints: 196 | data = get_from_dataloaders(dataset.games) 197 | shuffled_data = shuffle_batch_dim(*data) 198 | assign_to_dataloaders(dataset.games, *shuffled_data) 199 | 200 | if group_read_factor != 0: 201 | sampler = CacheEfficientSampler(dataset.num_blocks, dataset.block_len, group_read_factor) 202 | dataloader = DataLoader(dataset, batch_size=batch_size, 203 | sampler=sampler, 204 | num_workers=num_workers, 205 | pin_memory=pin_memory, 206 | collate_fn=collate, 207 | drop_last=True, 208 | prefetch_factor=prefetch_factor) 209 | else: 210 | dataloader = DataLoader(dataset, batch_size=batch_size, 211 | shuffle=True, 212 | num_workers=num_workers, 213 | pin_memory=pin_memory, 214 | collate_fn=collate, 215 | drop_last=True, 216 | prefetch_factor=prefetch_factor) 217 | 218 | return dataloader, None, None 219 | 220 | 221 | class CacheEfficientSampler(torch.utils.data.Sampler): 222 | def __init__(self, num_blocks, block_len, num_repeats=20, generator=None): 223 | self.num_blocks = num_blocks 224 | self.block_len = block_len # For now, assume all have same length 225 | self.num_repeats = num_repeats 226 | self.generator = generator 227 | if self.num_repeats == "all": 228 | self.num_repeats = block_len 229 | 230 | def num_samples(self) -> int: 231 | # dataset size might change at runtime 232 | return self.block_len * self.num_blocks 233 | 234 | def __iter__(self): 235 | n = self.num_samples() 236 | if self.generator is None: 237 | generator = torch.Generator() 238 | generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item())) 239 | else: 240 | generator = self.generator 241 | 242 | self.block_ids = [np.arange(self.num_blocks)] * (self.block_len // self.num_repeats) 243 | blocks = torch.randperm(n // self.num_repeats, generator=generator) % self.num_blocks 244 | intra_orders = [torch.randperm(self.block_len, generator=generator) + self.block_len * i for i in 245 | range(self.num_blocks)] 246 | intra_orders = [i.tolist() for i in intra_orders] 247 | 248 | indices = [] 249 | block_counts = [0] * self.num_blocks 250 | 251 | for block in blocks: 252 | indices += intra_orders[block][ 253 | (block_counts[block] * self.num_repeats):(block_counts[block] + 1) * 
self.num_repeats] 254 | block_counts[block] += 1 255 | 256 | return iter(indices) 257 | 258 | def __len__(self): 259 | return self.num_samples() 260 | 261 | 262 | def shuffle_by_trajectory(): 263 | raise NotImplementedError 264 | 265 | 266 | def grouper(iterable, n, fillvalue=None): 267 | "Collect data into fixed-length chunks or blocks" 268 | # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx" 269 | args = [iter(iterable)] * n 270 | return zip_longest(*args, fillvalue=fillvalue) 271 | 272 | 273 | def shuffle_batch_dim(observations, 274 | rewards, 275 | actions, 276 | dones, 277 | obs_on_disk=True, 278 | chunk_num=1, 279 | ): 280 | """ 281 | :param observations: (T, B, *) obs tensor, optionally mmap 282 | :param rewards: (T, B) rewards tensor 283 | :param actions: (T, B, *) actions tensor 284 | :param dones: (T, B) termination tensor 285 | :param obs_on_disk: Store observations on disk. Generally true if using 286 | more than ~3M transitions 287 | :return: 288 | """ 289 | batch_dim = observations[0].shape[1] 290 | num_sources = len(observations) 291 | batch_allocations = [np.sort((np.arange(batch_dim) + i) % num_sources) for i in range(num_sources)] 292 | 293 | shuffled_observations, shuffled_rewards, shuffled_actions, shuffled_dones = [], [], [], [] 294 | 295 | checkpoints = list(range(num_sources)) 296 | for sources, shuffled, filetype in zip([observations, rewards, actions, dones], 297 | [shuffled_observations, shuffled_rewards, shuffled_actions, shuffled_dones], 298 | ["observations", "rewards", "actions", "dones"]): 299 | ind_counters = [0]*num_sources 300 | 301 | for start in checkpoints[::chunk_num]: 302 | chunk = checkpoints[start:start+chunk_num] 303 | chunk_arrays = [] 304 | for i in chunk: 305 | if isinstance(sources[0], torch.Tensor): 306 | new_array = torch.zeros_like(sources[0]) 307 | else: 308 | new_array = np.zeros(sources[0].shape, dtype=sources[0].dtype) 309 | chunk_arrays.append(new_array) 310 | for source, allocation in zip(sources, batch_allocations): 311 | print(chunk, ind_counters) 312 | for i, new_array in zip(chunk, chunk_arrays): 313 | mapped_to_us = [b for b, dest in enumerate(allocation) if dest == i] 314 | new_array[:, ind_counters[i]:ind_counters[i]+len(mapped_to_us)] = source[:, mapped_to_us[0]:mapped_to_us[-1]+1] 315 | ind_counters[i] += len(mapped_to_us) 316 | 317 | for i, new_array in zip(chunk, chunk_arrays): 318 | if filetype == "observations" and obs_on_disk: 319 | filename = observations[i].filename.replace(".npy", "_shuffled.npy") 320 | print("Stored shuffled obs on disk at {}".format(filename)) 321 | np.save(filename, new_array) 322 | del new_array 323 | new_array = np.load(filename, mmap_mode="r+") 324 | shuffled.append(new_array) 325 | 326 | return shuffled_observations, shuffled_rewards, shuffled_actions, shuffled_dones 327 | 328 | 329 | def get_from_dataloaders(dataloaders): 330 | observations = [dataloader.observations for dataloader in dataloaders] 331 | rewards = [dataloader.rewards for dataloader in dataloaders] 332 | actions = [dataloader.actions for dataloader in dataloaders] 333 | dones = [dataloader.terminal for dataloader in dataloaders] 334 | 335 | return observations, rewards, actions, dones 336 | 337 | 338 | def assign_to_dataloaders(dataloaders, observations, rewards, actions, dones): 339 | for dl, obs, rew, act, done in zip(dataloaders, observations, rewards, actions, dones): 340 | dl.observations = obs 341 | dl.rewards = rew 342 | dl.actions = act 343 | dl.terminal = done 344 | 345 | 346 | 
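The loader above is driven entirely by keyword arguments (in this repository they appear to be supplied through the Hydra config under scripts/). The sketch below is illustrative only: the paths, game, checkpoint and sizes are placeholder assumptions rather than values from the actual configs, but it shows how get_offline_dataloaders is called and the time-major shapes its collate function yields.

    from pathlib import Path
    from src.offline_dataset import get_offline_dataloaders

    # All values below are hypothetical placeholders, chosen only for illustration.
    loader, _, _ = get_offline_dataloaders(
        data_path=Path("./data"),      # holds <game>/<filetype>_<checkpoint>.gz files
        tmp_data_path=Path("/tmp"),    # where memmapped .npy copies land when on disk
        games=["pong"],
        checkpoints=[1],
        frames=4,                      # frame-stack length
        k_step=20,                     # time window loaded per sample
        n_step_return=10,
        discount=0.99,
        samples=100_000,               # max transitions kept per checkpoint
        dataset_on_gpu=False,
        dataset_on_disk=True,
        batch_size=32,
        full_action_set=False,
        num_workers=2,
        pin_memory=True,
        prefetch_factor=2,
    )

    batch = next(iter(loader))         # an OfflineSamples namedarraytuple
    # batch.all_observation is roughly [k_step, 32, 4, 1, 84, 84] uint8, time-major;
    # batch.return_ and batch.done_n are the n-step quantities built in collate
    # by discount_return_n_step.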
-------------------------------------------------------------------------------- /src/algos.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import time 6 | 7 | import wandb 8 | from scipy import stats 9 | import numpy as np 10 | from src.scrl import * 11 | from src.utils import select_at_indexes, find_weight_norm, DataWriter, sanity_check_gcrl, to_categorical, from_categorical 12 | from rlpyt.utils.collections import namedarraytuple 13 | from collections import namedtuple 14 | from src.utils import discount_return_n_step 15 | from rlpyt.algos.dqn.cat_dqn import CategoricalDQN 16 | from src.rlpyt_buffer import AsyncPrioritizedSequenceReplayFrameBufferExtended, \ 17 | AsyncUniformSequenceReplayFrameBufferExtended 18 | from rlpyt.replays.sequence.prioritized import SamplesFromReplayPri 19 | SamplesToBuffer = namedarraytuple("SamplesToBuffer", 20 | ["observation", "action", "reward", "done"]) 21 | ModelSamplesToBuffer = namedarraytuple("SamplesToBuffer", 22 | ["observation", "action", "reward", "done", "value"]) 23 | 24 | OptInfo = namedtuple("OptInfo", ["loss", "gradNorm", "tdAbsErr"]) 25 | ModelOptInfo = namedtuple("OptInfo", ["loss", "gradNorm", 26 | "tdAbsErr", 27 | "GoalLoss", 28 | "GoalError", 29 | "modelGradNorm", 30 | "T0SPRLoss", 31 | "InverseModelLoss", 32 | "RewardLoss", 33 | "BCLoss", 34 | "SampleTime", 35 | "ForwardTime", 36 | "CNNWeightNorm", 37 | "ModelSPRLoss", 38 | "Diversity"]) 39 | 40 | EPS = 1e-6 # (NaN-guard) 41 | 42 | 43 | class SPRCategoricalDQN(CategoricalDQN): 44 | """Distributional DQN with fixed probability bins for the Q-value of each 45 | action, a.k.a. categorical.""" 46 | 47 | def __init__(self, 48 | rl_weight=1., 49 | spr_weight=1., 50 | inverse_model_weight=1, 51 | goal_n_step=1, 52 | goal_window=50, 53 | goal_weight=1., 54 | goal_permute_prob=0.2, 55 | goal_noise_weight=0.5, 56 | goal_reward_scale=5., 57 | goal_all_to_all=False, 58 | conv_goal=False, 59 | clip_model_grad_norm=10., 60 | goal_dist="exp", 61 | jumps=0, 62 | offline=False, 63 | bc_weight=0, 64 | encoder_lr: Optional[float] = None, 65 | dynamics_model_lr: Optional[float] = None, 66 | q_l1_lr: Optional[float] = None, 67 | data_writer_args={"save_data": False}, 68 | **kwargs): 69 | super().__init__(**kwargs) 70 | self.opt_info_fields = tuple(f for f in ModelOptInfo._fields) # copy 71 | self.spr_weight = spr_weight 72 | self.inverse_model_weight = inverse_model_weight 73 | self.clip_model_grad_norm = clip_model_grad_norm 74 | self.goal_window = goal_window 75 | self.goal_n_step = goal_n_step 76 | self.goal_permute_prob = goal_permute_prob 77 | self.goal_reward_scale = goal_reward_scale 78 | self.goal_noise_weight = goal_noise_weight 79 | self.goal_all_to_all = goal_all_to_all 80 | self.offline = offline 81 | self.conv_goal = conv_goal 82 | 83 | self.bc_weight = bc_weight 84 | 85 | if "exp" in goal_dist.lower(): 86 | self.goal_distance = exp_distance 87 | else: 88 | self.goal_distance = norm_dist 89 | 90 | self.rl_weight = rl_weight 91 | self.goal_weight = goal_weight 92 | self.jumps = jumps 93 | 94 | self.save_data = data_writer_args["save_data"] 95 | if self.save_data: 96 | self.data_writer = DataWriter(**data_writer_args) 97 | 98 | self.encoder_lr = encoder_lr if encoder_lr is not None else self.learning_rate 99 | self.dynamics_model_lr = dynamics_model_lr if dynamics_model_lr is not None else self.learning_rate 100 | self.q_l1_lr = q_l1_lr if q_l1_lr is not None else 
self.learning_rate 101 | 102 | def initialize_replay_buffer(self, examples, batch_spec, async_=False): 103 | example_to_buffer = ModelSamplesToBuffer( 104 | observation=examples["observation"], 105 | action=examples["action"], 106 | reward=examples["reward"], 107 | done=examples["done"], 108 | value=examples["agent_info"].p, 109 | ) 110 | replay_kwargs = dict( 111 | example=example_to_buffer, 112 | size=self.replay_size, 113 | B=batch_spec.B, 114 | batch_T=max(self.jumps+1, self.goal_window) if self.goal_weight > 0 else self.jumps+1, 115 | discount=self.discount, 116 | n_step_return=self.n_step_return, 117 | rnn_state_interval=0, 118 | ) 119 | 120 | if self.prioritized_replay: 121 | replay_kwargs['alpha'] = self.pri_alpha 122 | replay_kwargs['beta'] = self.pri_beta_init 123 | # replay_kwargs["input_priorities"] = self.input_priorities 124 | buffer = AsyncPrioritizedSequenceReplayFrameBufferExtended(**replay_kwargs) 125 | else: 126 | buffer = AsyncUniformSequenceReplayFrameBufferExtended(**replay_kwargs) 127 | 128 | self.replay_buffer = buffer 129 | 130 | def optim_initialize(self, rank=0): 131 | """Called in initilize or by async runner after forking sampler.""" 132 | self.rank = rank 133 | try: 134 | # We're probably dealing with DDP 135 | self.model = self.agent.model.module 136 | except: 137 | self.model = self.agent.model 138 | 139 | # Split into (optionally) three groups for separate LRs. 140 | conv_params, dynamics_model_params, q_l1_params, other_params = self.model.list_params() 141 | self.optimizer = self.OptimCls([ 142 | {'params': conv_params, 'lr': self.encoder_lr}, 143 | {'params': q_l1_params, 'lr': self.q_l1_lr}, 144 | {'params': dynamics_model_params, 'lr': self.dynamics_model_lr}, 145 | {'params': other_params, 'lr': self.learning_rate} 146 | ], 147 | **self.optim_kwargs) 148 | 149 | if self.initial_optim_state_dict is not None: 150 | self.optimizer.load_state_dict(self.initial_optim_state_dict) 151 | if self.prioritized_replay: 152 | self.pri_beta_itr = max(1, self.pri_beta_steps // self.sampler_bs) 153 | 154 | def samples_to_buffer(self, samples): 155 | """Defines how to add data from sampler into the replay buffer. Called 156 | in optimize_agent() if samples are provided to that method. In 157 | asynchronous mode, will be called in the memory_copier process.""" 158 | return ModelSamplesToBuffer( 159 | observation=samples.env.observation, 160 | action=samples.agent.action, 161 | reward=samples.env.reward, 162 | done=samples.env.done, 163 | value=samples.agent.agent_info.p, 164 | ) 165 | 166 | def sample_batch(self): 167 | if not self.offline: 168 | samples = self.replay_buffer.sample_batch(self.batch_size) 169 | return samples 170 | else: 171 | return self.sample_offline_dataset() 172 | 173 | def sample_offline_dataset(self): 174 | try: 175 | samples = next(self.offline_dataloader) 176 | except Exception as e: 177 | self.offline_dataloader = iter(self.offline_dataset) 178 | samples = next(self.offline_dataloader) 179 | return samples 180 | 181 | def optimize_agent(self, itr, samples=None, sampler_itr=None, offline_samples=None): 182 | """ 183 | Extracts the needed fields from input samples and stores them in the 184 | replay buffer. Then samples from the replay buffer to train the agent 185 | by gradient updates (with the number of updates determined by replay 186 | ratio, sampler batch size, and training batch size). If using prioritized 187 | replay, updates the priorities for sampled training batches. 
188 | """ 189 | itr = itr if sampler_itr is None else sampler_itr # Async uses sampler_itr.= 190 | if samples is not None: 191 | if self.save_data: 192 | self.data_writer.write(samples) 193 | samples_to_buffer = self.samples_to_buffer(samples) 194 | self.replay_buffer.append_samples(samples_to_buffer) 195 | opt_info = ModelOptInfo(*([] for _ in range(len(ModelOptInfo._fields)))) 196 | if not self.offline and itr < self.min_itr_learn: 197 | return opt_info 198 | for _ in range(1 if self.offline else self.updates_per_optimize): 199 | start = time.time() 200 | samples_from_replay = self.sample_batch() 201 | 202 | end = time.time() 203 | sample_time = end - start 204 | 205 | forward_time = time.time() 206 | rl_loss, td_abs_errors, goal_loss,\ 207 | t0_spr_loss, model_spr_loss, \ 208 | diversity, inverse_model_loss, bc_loss, \ 209 | goal_abs_errors \ 210 | = self.loss(samples_from_replay, self.offline) 211 | forward_time = time.time() - forward_time 212 | 213 | total_loss = self.rl_weight * rl_loss 214 | total_loss += self.spr_weight * model_spr_loss 215 | total_loss += self.goal_weight * goal_loss 216 | total_loss += self.inverse_model_weight * inverse_model_loss 217 | total_loss += self.bc_weight * bc_loss 218 | 219 | self.optimizer.zero_grad() 220 | total_loss.backward() 221 | stem_params, model_params = self.model.split_stem_model_params() 222 | if self.clip_grad_norm > 0: 223 | grad_norm = torch.nn.utils.clip_grad_norm_(stem_params, 224 | self.clip_grad_norm) 225 | else: 226 | grad_norm = 0 227 | if self.clip_model_grad_norm > 0: 228 | model_grad_norm = torch.nn.utils.clip_grad_norm_(model_params, 229 | self.clip_model_grad_norm) 230 | else: 231 | model_grad_norm = 0 232 | 233 | cnn_weight_norm = find_weight_norm(self.model.conv.parameters()) 234 | 235 | self.optimizer.step() 236 | 237 | if not self.offline and self.prioritized_replay: 238 | self.replay_buffer.update_batch_priorities(td_abs_errors) 239 | opt_info.loss.append(rl_loss.item()) 240 | opt_info.gradNorm.append(torch.tensor(grad_norm).item()) # grad_norm is a float sometimes, so wrap in tensor 241 | opt_info.GoalLoss.append(goal_loss.item()) 242 | opt_info.modelGradNorm.append(torch.tensor(model_grad_norm).item()) 243 | opt_info.T0SPRLoss.append(t0_spr_loss.item()) 244 | opt_info.InverseModelLoss.append(inverse_model_loss.item()) 245 | opt_info.BCLoss.append(bc_loss.item()) 246 | opt_info.CNNWeightNorm.append(cnn_weight_norm.item()) 247 | opt_info.SampleTime.append(sample_time) 248 | opt_info.ForwardTime.append(forward_time) 249 | opt_info.Diversity.append(diversity.item()) 250 | opt_info.ModelSPRLoss.append(model_spr_loss.item()) 251 | opt_info.tdAbsErr.extend(td_abs_errors[::8].cpu().numpy()) # Downsample. 252 | opt_info.GoalError.extend(goal_abs_errors[::8].cpu().numpy()) # Downsample. 253 | self.update_counter += 1 254 | if self.update_counter % self.target_update_interval == 0: 255 | self.agent.update_target(self.target_update_tau) 256 | self.update_itr_hyperparams(itr) 257 | return opt_info 258 | 259 | def rl_loss(self, log_pred_ps, observations, goals, actions, rewards, nonterminals, returns, index, n_step): 260 | """ 261 | Computes the Distributional Q-learning loss, based on projecting the 262 | discounted rewards + target Q-distribution into the current Q-domain, 263 | with cross-entropy loss. 264 | 265 | Returns loss and KL-divergence-errors for use in prioritization. 
266 | """ 267 | delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1) 268 | z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms, device=log_pred_ps.device) 269 | # Make 2-D tensor of contracted z_domain for each data point, 270 | # with zeros where next value should not be added. 271 | next_z = z * (self.discount ** n_step) # [P'] 272 | next_z = torch.ger(nonterminals[index], next_z) # [B,P'] 273 | ret = returns[index].unsqueeze(-1) # [B,1] 274 | 275 | next_z = torch.clamp(ret + next_z, self.V_min, self.V_max) # [B,P'] 276 | 277 | z_bc = z.view(1, -1, 1) # [1,P,1] 278 | next_z_bc = next_z.unsqueeze(-2) # [B,1,P'] 279 | abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z 280 | projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1) # Most 0. 281 | # projection_coeffs is a 3-D tensor: [B,P,P'] 282 | # dim-0: independent data entries 283 | # dim-1: base_z atoms (remains after projection) 284 | # dim-2: next_z atoms (summed in projection) 285 | 286 | with torch.no_grad(): 287 | target_ps = self.agent.target(observations[index + n_step], 288 | actions[index + n_step], 289 | rewards[index + n_step], 290 | goals) # [B,A,P'] 291 | if self.double_dqn: 292 | next_ps = self.agent(observations[index + n_step], 293 | actions[index + n_step], 294 | rewards[index + n_step], 295 | goals) # [B,A,P'] 296 | next_qs = torch.tensordot(next_ps, z, dims=1) # [B,A] 297 | next_a = torch.argmax(next_qs, dim=-1) # [B] 298 | else: 299 | target_qs = torch.tensordot(target_ps, z, dims=1) # [B,A] 300 | next_a = torch.argmax(target_qs, dim=-1) # [B] 301 | target_p_unproj = select_at_indexes(next_a, target_ps) # [B,P'] 302 | target_p_unproj = target_p_unproj.unsqueeze(1) # [B,1,P'] 303 | target_p = (target_p_unproj * projection_coeffs).sum(-1) # [B,P] 304 | p = select_at_indexes(actions[index + 1].squeeze(-1), log_pred_ps) # [B,P] 305 | # p = torch.clamp(p, EPS, 1) # NaN-guard. 306 | losses = -torch.sum(target_p * p, dim=-1) # Cross-entropy. 307 | 308 | target_p = torch.clamp(target_p, EPS, 1) 309 | KL_div = torch.sum(target_p * 310 | (torch.log(target_p) - p.detach()), dim=-1) 311 | KL_div = torch.clamp(KL_div, EPS, 1 / EPS) # Avoid <0 from NaN-guard. 312 | 313 | return losses, KL_div.detach() 314 | 315 | @torch.no_grad() 316 | def sample_goals(self, observation): 317 | proj_latents, latents = sample_goals(observation, self.model.encode_targets) 318 | if self.conv_goal: 319 | goals = latents.squeeze(0) 320 | else: 321 | goals = proj_latents.squeeze(0) 322 | 323 | goals = add_noise(goals, self.goal_noise_weight) 324 | goals = permute_goals(goals, self.goal_permute_prob) 325 | 326 | goals = self.model.renormalize(goals) 327 | 328 | return goals 329 | 330 | def loss(self, samples, offline=False): 331 | if self.model.noisy: 332 | self.model.head.reset_noise() 333 | self.agent.target_model.head.reset_noise() 334 | 335 | observations = samples.all_observation.to(self.agent.device) 336 | actions = samples.all_action.to(self.agent.device) 337 | rewards = samples.all_reward.to(self.agent.device) 338 | # rewards = torch.nan_to_num(rewards) # Apparently possible, somehow. 339 | dones = samples.done.to(self.agent.device) 340 | done_ns = samples.done_n.to(self.agent.device) 341 | nonterminals = 1. - torch.sign(torch.cumsum(dones, 0)).float() 342 | nonterminals_n = 1. 
- torch.sign(torch.cumsum(done_ns, 0)).float() 343 | 344 | if self.goal_weight > 0: 345 | goals = self.sample_goals(observations[1:self.goal_window]) 346 | else: 347 | goals = None 348 | 349 | log_pred_ps, goal_log_pred_ps, spr_loss, latents, proj_latents, diversity, inverse_model_loss, bc_preds\ 350 | = self.agent(observations, actions, rewards, goals, train=True) # [B,A,P] 351 | 352 | if self.rl_weight > 0: 353 | returns = samples.return_.to(self.agent.device) 354 | rl_loss, KL = self.rl_loss(log_pred_ps[:self.batch_size], observations[:, :self.batch_size], 355 | None, actions[:, :self.batch_size], 356 | rewards[:, :self.batch_size], nonterminals_n[:, :self.batch_size], 357 | returns[:, :self.batch_size], 0, self.n_step_return) 358 | else: 359 | rl_loss = torch.zeros_like(spr_loss[0][:self.batch_size]) 360 | KL = torch.zeros_like(rl_loss) 361 | 362 | if self.bc_weight > 0: 363 | log_pred_actions = F.log_softmax(bc_preds, -1) 364 | targets = actions[1] 365 | bc_loss = F.nll_loss(log_pred_actions, targets) 366 | else: 367 | bc_loss = torch.zeros_like(rl_loss) 368 | 369 | if self.goal_weight > 0: 370 | if self.conv_goal: 371 | goal_latents = latents[:, :self.goal_n_step+1] 372 | else: 373 | goal_latents = proj_latents[:, :self.goal_n_step+1] 374 | goal_returns = calculate_returns(goal_latents, 375 | goals, 376 | self.goal_distance, 377 | self.discount, 378 | nonterminals[:self.goal_n_step], 379 | distance_scale=5., 380 | reward_scale=self.goal_reward_scale, 381 | all_to_all=self.goal_all_to_all) 382 | 383 | if self.goal_all_to_all: 384 | goal_nonterminals = nonterminals[None, None, self.goal_n_step] 385 | goal_nonterminals = goal_nonterminals.expand(-1, goal_nonterminals.shape[-1], -1).flatten(-2, -1) 386 | goal_actions = actions[:, None, :].expand(-1, actions.shape[-1], -1).flatten(-2, -1) 387 | else: 388 | goal_nonterminals = nonterminals[None, self.goal_n_step] 389 | goal_actions = actions 390 | 391 | goal_loss, goal_KL = self.rl_loss(goal_log_pred_ps, observations, 392 | goals, goal_actions, 393 | rewards, 394 | goal_nonterminals, 395 | goal_returns, 396 | 0, 397 | self.goal_n_step) 398 | 399 | if self.goal_all_to_all: 400 | bs = actions.shape[1] 401 | goal_loss = goal_loss.view(bs, bs).mean(1) 402 | goal_KL = goal_KL.view(bs, bs).mean(1) 403 | 404 | else: 405 | goal_loss = goal_KL = torch.zeros_like(spr_loss[0]) 406 | 407 | spr_loss = spr_loss*nonterminals[:self.jumps+1] 408 | if self.jumps > 0: 409 | model_spr_loss = spr_loss[1:].mean(0) 410 | t0_spr_loss = spr_loss[0] 411 | else: 412 | t0_spr_loss = spr_loss[0] 413 | model_spr_loss = torch.zeros_like(spr_loss) 414 | t0_spr_loss = t0_spr_loss 415 | model_spr_loss = model_spr_loss 416 | if not offline and self.prioritized_replay: 417 | weights = samples.is_weights.to(rl_loss.device) 418 | t0_spr_loss = t0_spr_loss * weights 419 | model_spr_loss = model_spr_loss * weights 420 | goal_loss = goal_loss*weights 421 | bc_loss = bc_loss*weights 422 | inverse_model_loss = inverse_model_loss*weights 423 | rl_loss = rl_loss * weights 424 | 425 | return rl_loss.mean(), KL, \ 426 | goal_loss.mean(), \ 427 | t0_spr_loss.mean(), \ 428 | model_spr_loss.mean(), \ 429 | diversity, \ 430 | inverse_model_loss.mean(),\ 431 | bc_loss.mean(), \ 432 | goal_KL 433 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia.augmentation import RandomAffine,\ 3 | RandomCrop,\ 4 | CenterCrop, \ 5 | 
RandomResizedCrop 6 | from kornia.filters import GaussianBlur2d 7 | from torch import nn 8 | import numpy as np 9 | import glob 10 | import gzip 11 | import shutil 12 | from pathlib import Path 13 | import os 14 | EPS = 1e-6 15 | 16 | 17 | def count_parameters(model): 18 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 19 | 20 | 21 | def select_at_indexes(indexes, tensor): 22 | """Returns the contents of ``tensor`` at the multi-dimensional integer 23 | array ``indexes``. Leading dimensions of ``tensor`` must match the 24 | dimensions of ``indexes``. 25 | """ 26 | dim = len(indexes.shape) 27 | assert indexes.shape == tensor.shape[:dim] 28 | num = indexes.numel() 29 | t_flat = tensor.view((num,) + tensor.shape[dim:]) 30 | s_flat = t_flat[torch.arange(num, device=tensor.device), indexes.view(-1)] 31 | return s_flat.view(tensor.shape[:dim] + tensor.shape[dim + 1:]) 32 | 33 | 34 | def get_augmentation(augmentation, imagesize): 35 | if isinstance(augmentation, str): 36 | augmentation = augmentation.split("_") 37 | transforms = [] 38 | for aug in augmentation: 39 | if aug == "affine": 40 | transformation = RandomAffine(5, (.14, .14), (.9, 1.1), (-5, 5)) 41 | elif aug == "rrc": 42 | transformation = RandomResizedCrop((imagesize, imagesize), (0.8, 1)) 43 | elif aug == "blur": 44 | transformation = GaussianBlur2d((5, 5), (1.5, 1.5)) 45 | elif aug == "shift" or aug == "crop": 46 | transformation = nn.Sequential(nn.ReplicationPad2d(4), RandomCrop((84, 84))) 47 | elif aug == "intensity": 48 | transformation = Intensity(scale=0.05) 49 | elif aug == "none": 50 | continue 51 | else: 52 | raise NotImplementedError() 53 | transforms.append(transformation) 54 | 55 | return transforms 56 | 57 | 58 | class Intensity(nn.Module): 59 | def __init__(self, scale): 60 | super().__init__() 61 | self.scale = scale 62 | 63 | def forward(self, x): 64 | r = torch.randn((x.size(0), 1, 1, 1), device=x.device) 65 | noise = 1.0 + (self.scale * r.clamp(-2.0, 2.0)) 66 | return x * noise 67 | 68 | 69 | def maybe_transform(image, transform, p=0.8): 70 | processed_images = transform(image) 71 | if p >= 1: 72 | return processed_images 73 | else: 74 | mask = torch.rand((processed_images.shape[0], 1, 1, 1), 75 | device=processed_images.device) 76 | mask = (mask < p).float() 77 | processed_images = mask * processed_images + (1 - mask) * image 78 | return processed_images 79 | 80 | 81 | def renormalize(tensor, first_dim=-3): 82 | if first_dim < 0: 83 | first_dim = len(tensor.shape) + first_dim 84 | flat_tensor = tensor.view(*tensor.shape[:first_dim], -1) 85 | max = torch.max(flat_tensor, first_dim, keepdim=True).values 86 | min = torch.min(flat_tensor, first_dim, keepdim=True).values 87 | flat_tensor = (flat_tensor - min)/(max - min) 88 | 89 | return flat_tensor.view(*tensor.shape) 90 | 91 | 92 | def to_categorical(value, limit=300): 93 | value = value.float() # Avoid any fp16 shenanigans 94 | value = value.clamp(-limit, limit) 95 | distribution = torch.zeros(value.shape[0], (limit*2+1), device=value.device) 96 | lower = value.floor().long() + limit 97 | upper = value.ceil().long() + limit 98 | upper_weight = value % 1 99 | lower_weight = 1 - upper_weight 100 | distribution.scatter_add_(-1, lower.unsqueeze(-1), lower_weight.unsqueeze(-1)) 101 | distribution.scatter_add_(-1, upper.unsqueeze(-1), upper_weight.unsqueeze(-1)) 102 | return distribution 103 | 104 | 105 | def from_categorical(distribution, limit=300, logits=True): 106 | distribution = distribution.float() # Avoid any fp16 shenanigans 107 | if logits: 
108 | distribution = torch.softmax(distribution, -1) 109 | num_atoms = distribution.shape[-1] 110 | weights = torch.linspace(-limit, limit, num_atoms, device=distribution.device).float() 111 | return distribution @ weights 112 | 113 | 114 | def extract_epoch(filename): 115 | """ 116 | Get the epoch from a model save string formatted as name_Epoch:{seed}.pt 117 | :param str: Model save name 118 | :return: epoch (int) 119 | """ 120 | 121 | if "epoch" not in filename.lower(): 122 | return 0 123 | 124 | epoch = int(filename.lower().split("epoch_")[-1].replace(".pt", "")) 125 | return epoch 126 | 127 | 128 | def get_last_save(base_pattern, retry=True): 129 | files = glob.glob(base_pattern+"*.pt") 130 | epochs = [extract_epoch(path) for path in files] 131 | 132 | inds = np.argsort(-np.array(epochs)) 133 | for ind in inds: 134 | try: 135 | print("Attempting to load {}".format(files[ind])) 136 | state_dict = torch.load(Path(files[ind])) 137 | epoch = epochs[ind] 138 | return state_dict, epoch 139 | except Exception as e: 140 | if retry: 141 | print("Loading failed: {}".format(e)) 142 | else: 143 | raise e 144 | 145 | 146 | def delete_all_but_last(base_pattern, num_to_keep=3): 147 | files = glob.glob(base_pattern+"*.pt") 148 | epochs = [extract_epoch(path) for path in files] 149 | 150 | order = np.argsort(np.array(epochs)) 151 | 152 | for i in order[:-num_to_keep]: 153 | os.remove(files[i]) 154 | print("Deleted old save {}".format(files[i])) 155 | 156 | 157 | def save_model_fn(folder, model_save, seed, use_epoch=True, save_only_last=False): 158 | def save_model(model, optim, epoch): 159 | if use_epoch: 160 | path = Path(f'{folder}/{model_save}_{seed}_epoch_{epoch}.pt') 161 | else: 162 | path = Path(f'{folder}/{model_save}_{seed}.pt') 163 | 164 | torch.save({"model": model, "optim": optim}, path) 165 | print("Saved model at {}".format(path)) 166 | 167 | if save_only_last: 168 | 169 | delete_all_but_last(f'{folder}/{model_save}_{seed}') 170 | 171 | return save_model 172 | 173 | 174 | def find_weight_norm(parameters, norm_type=1.0) -> torch.Tensor: 175 | r"""Finds the norm of an iterable of parameters. 176 | 177 | The norm is computed over all parameterse together, as if they were 178 | concatenated into a single vector. 179 | 180 | Arguments: 181 | parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a 182 | single Tensor to find norms of 183 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for 184 | infinity norm. 185 | 186 | Returns: 187 | Total norm of the parameters (viewed as a single vector). 188 | """ 189 | if isinstance(parameters, torch.Tensor): 190 | parameters = [parameters] 191 | parameters = [p for p in parameters if p is not None] 192 | norm_type = float(norm_type) 193 | if len(parameters) == 0: 194 | return torch.tensor(0.) 
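    # Note: everything below operates on p.data.detach(), so this helper is purely for
    # logging (e.g. the CNNWeightNorm metric tracked in src/algos.py) and contributes
    # nothing to the autograd graph. For any finite norm_type, the group-wise
    # norm-of-norms computed here equals the norm of all parameters concatenated.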
195 | device = parameters[0].device 196 | if norm_type == np.inf: 197 | total_norm = max(p.data.detach().abs().max().to(device) for p in parameters) 198 | else: 199 | total_norm = torch.norm(torch.stack([torch.norm(p.data.detach(), norm_type).to(device) for p in parameters]), norm_type) 200 | return total_norm 201 | 202 | 203 | def minimal_quantile_loss(pred_values, target_values, taus, kappa=1.0): 204 | if len(pred_values.shape) == 3: 205 | output_shape = pred_values.shape[:2] 206 | target_values = target_values.expand_as(pred_values) 207 | pred_values = pred_values.flatten(0, 1) 208 | target_values = target_values.flatten(0, 1) 209 | else: 210 | output_shape = pred_values.shape[:1] 211 | 212 | if pred_values.shape[0] != taus.shape[0]: 213 | # somebody has added states along the batch dimension, 214 | # probably to do multiple timesteps' losses simultaneously. 215 | # Since the standard in this codebase is to put time on dimension 1 and 216 | # then flatten 0 and 1, we can do the same here to get the right shape. 217 | expansion_factor = pred_values.shape[0]//taus.shape[0] 218 | taus = taus.unsqueeze(1).expand(-1, expansion_factor, -1,).flatten(0, 1) 219 | 220 | td_errors = pred_values.unsqueeze(-1) - target_values.unsqueeze(1) 221 | assert not taus.requires_grad 222 | batch_size, N, N_dash = td_errors.shape 223 | 224 | # Calculate huber loss element-wisely. 225 | element_wise_huber_loss = calculate_huber_loss(td_errors, kappa) 226 | assert element_wise_huber_loss.shape == ( 227 | batch_size, N, N_dash) 228 | 229 | # Calculate quantile huber loss element-wisely. 230 | element_wise_quantile_huber_loss = torch.abs( 231 | taus[..., None] - (td_errors.detach() < 0).float() 232 | ) * element_wise_huber_loss / kappa 233 | assert element_wise_quantile_huber_loss.shape == ( 234 | batch_size, N, N_dash) 235 | 236 | # Quantile huber loss. 237 | batch_quantile_huber_loss = element_wise_quantile_huber_loss.sum( 238 | dim=1).mean(dim=1, keepdim=True) 239 | assert batch_quantile_huber_loss.shape == (batch_size, 1) 240 | 241 | loss = batch_quantile_huber_loss.squeeze(1) 242 | 243 | # Just use the regular loss as the error for PER, at least for now. 
244 | return loss.view(*output_shape), loss.detach().view(*output_shape) 245 | 246 | 247 | def scalar_backup(n, returns, nonterminal, qs, discount, select_action=False, selection_values=None): 248 | """ 249 | :param qs: q estimates 250 | :param n: n-step 251 | :param nonterminal: 252 | :param returns: Returns, already scaled by discount/nonterminal 253 | :param discount: discount in [0, 1] 254 | :return: 255 | """ 256 | if select_action: 257 | if selection_values is None: 258 | selection_values = qs 259 | next_a = selection_values.mean(-1).argmax(-1) 260 | qs = select_at_indexes(next_a, qs) 261 | while len(returns.shape) < len(qs.shape): 262 | returns = returns.unsqueeze(-1) 263 | while len(nonterminal.shape) < len(qs.shape): 264 | nonterminal = nonterminal.unsqueeze(-1) 265 | discount = discount ** n 266 | qs = nonterminal*qs*discount + returns 267 | return qs 268 | 269 | 270 | def calculate_huber_loss(td_errors, kappa=1.0): 271 | return torch.where( 272 | td_errors.abs() <= kappa, 273 | 0.5 * td_errors.pow(2), 274 | kappa * (td_errors.abs() - 0.5 * kappa)) 275 | 276 | 277 | def c51_backup(n_step, 278 | returns, 279 | nonterminal, 280 | target_ps, 281 | select_action=False, 282 | V_max=10., 283 | V_min=10., 284 | n_atoms=51, 285 | discount=0.99, 286 | selection_values=None): 287 | 288 | z = torch.linspace(V_min, V_max, n_atoms, device=target_ps.device) 289 | 290 | if select_action: 291 | if selection_values is None: 292 | selection_values = target_ps 293 | target_qs = torch.tensordot(selection_values, z, dims=1) # [B,A] 294 | next_a = torch.argmax(target_qs, dim=-1) # [B] 295 | target_ps = select_at_indexes(next_a.to(target_ps.device), target_ps) # [B,P'] 296 | 297 | delta_z = (V_max - V_min) / (n_atoms - 1) 298 | # Make 2-D tensor of contracted z_domain for each data point, 299 | # with zeros where next value should not be added. 300 | next_z = z * (discount ** n_step) # [P'] 301 | next_z = nonterminal.unsqueeze(-1)*next_z.unsqueeze(-2) # [B,P'] 302 | ret = returns.unsqueeze(-1) # [B,1] 303 | 304 | num_extra_dims = len(ret.shape) - len(next_z.shape) 305 | next_z = next_z.view(*([1]*num_extra_dims), *next_z.shape) 306 | 307 | next_z = torch.clamp(ret + next_z, V_min, V_max) # [B,P'] 308 | 309 | z_bc = z.view(*([1]*num_extra_dims), 1, -1, 1) # [1,P,1] 310 | next_z_bc = next_z.unsqueeze(-2) # [B,1,P'] 311 | abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z 312 | projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1) # Most 0. 
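    # Standard C51 projection: each backed-up atom spreads its probability linearly
    # between the two nearest support atoms. For example, with 51 atoms spanning
    # [-10, 10] (delta_z = 0.4), a backed-up value of 0.1 puts weight 0.75 on the
    # atom at 0.0 and 0.25 on the atom at 0.4.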
313 | 314 | # projection_coeffs is a 3-D tensor: [B,P,P'] 315 | # dim-0: independent data entries 316 | # dim-1: base_z atoms (remains after projection) 317 | # dim-2: next_z atoms (summed in projection) 318 | 319 | target_ps = target_ps.unsqueeze(-2) # [B,1,P'] 320 | if not select_action and len(projection_coeffs.shape) != len(target_ps.shape): 321 | projection_coeffs = projection_coeffs.unsqueeze(-3) 322 | target_p = (target_ps * projection_coeffs).sum(-1) # [B,P] 323 | target_p = torch.clamp(target_p, EPS, 1) 324 | return target_p 325 | 326 | 327 | class DataWriter: 328 | def __init__(self, 329 | save_data=True, 330 | data_dir="/project/rrg-bengioy-ad/schwarzm/atari", 331 | save_name="", 332 | checkpoint_size=1000000, 333 | game="Pong", 334 | imagesize=(84, 84), 335 | mmap=True): 336 | 337 | self.save_name = save_name 338 | self.save_data = save_data 339 | if not self.save_data: 340 | return 341 | 342 | self.pointer = 0 343 | self.checkpoint = 0 344 | self.checkpoint_size = checkpoint_size 345 | self.imagesize = imagesize 346 | self.dir = Path(data_dir) / game.replace("_", " ").title().replace(" ", "") 347 | os.makedirs(self.dir, exist_ok=True) 348 | self.mmap = mmap 349 | self.reset() 350 | 351 | def reset(self): 352 | self.pointer = 0 353 | obs_data = np.zeros((self.checkpoint_size, *self.imagesize), dtype=np.uint8) 354 | action_data = np.zeros((self.checkpoint_size,), dtype=np.int32) 355 | reward_data = np.zeros((self.checkpoint_size,), dtype=np.float32) 356 | terminal_data = np.zeros((self.checkpoint_size,), dtype=np.uint8) 357 | 358 | self.arrays = [] 359 | self.filenames = [] 360 | 361 | for data, filetype in [(obs_data, 'observation'), 362 | (action_data, 'action'), 363 | (reward_data, 'reward'), 364 | (terminal_data, 'terminal')]: 365 | filename = Path(self.dir / f'{filetype}_{self.checkpoint}{self.save_name}.npy') 366 | if self.mmap: 367 | np.save(filename, data) 368 | data_ = np.memmap(filename, mode="w+", dtype=data.dtype, shape=data.shape,) 369 | del data 370 | else: 371 | data_ = data 372 | self.arrays.append(data_) 373 | self.filenames.append(filename) 374 | 375 | def save(self): 376 | for data, filename in zip(self.arrays, self.filenames): 377 | if not self.mmap: 378 | np.save(filename, data) 379 | del data # Flushes memmap 380 | with open(filename, 'rb') as f_in: 381 | new_filename = os.path.join(self.dir, Path(os.path.basename(filename)[:-4]+".gz")) 382 | with gzip.open(new_filename, 'wb') as f_out: 383 | shutil.copyfileobj(f_in, f_out) 384 | 385 | os.remove(filename) 386 | 387 | def write(self, samples): 388 | if not self.save_data: 389 | return 390 | 391 | self.arrays[0][self.pointer] = samples.env.observation[0, 0, -1, 0] 392 | self.arrays[1][self.pointer] = samples.agent.action 393 | self.arrays[2][self.pointer] = samples.env.reward 394 | self.arrays[3][self.pointer] = samples.env.done 395 | 396 | self.pointer += 1 397 | if self.pointer == self.checkpoint_size: 398 | self.checkpoint += 1 399 | self.save() 400 | self.reset() 401 | 402 | 403 | def update_state_dict_compat(osd, nsd): 404 | updated_osd = {k.replace("head.advantage", "head.goal_advantage"). 405 | replace("head.value", "head.goal_value"). 406 | replace("head.secondary_advantage_head", "head.rl_advantage"). 
407 | replace("head.secondary_value_head", "head.rl_value") 408 | : v for k, v in osd.items()} 409 | filtered_osd = {k: v for k, v in updated_osd.items() if k in nsd} 410 | missing_items = [k for k, v in updated_osd.items() if k not in nsd] 411 | if len(missing_items) > 0: 412 | print("Could not load into new model: {}".format(missing_items)) 413 | nsd.update(filtered_osd) 414 | return nsd 415 | 416 | 417 | def calculate_true_values(states, 418 | goal, 419 | distance, 420 | gamma, 421 | final_value, 422 | nonterminal, 423 | distance_scale, 424 | reward_scale=10., 425 | all_to_all=False): 426 | """ 427 | :param states: (batch, jumps, dim) 428 | :param goal: (batch, dim) 429 | :param distance: distance function (state X state X scale -> R). 430 | :param gamma: rl discount gamma in [0, 1] 431 | :param nonterminal: 1 - done, (batch, jumps). 432 | :return: returns: discounted sum of rewards up to t, (batch, jumps); 433 | has shape (batch, batch, jumps) if all_to_all enabled 434 | """ 435 | nonterminal = nonterminal.transpose(0, 1) 436 | 437 | if all_to_all: 438 | states = states.unsqueeze(1) 439 | goal = goal.unsqueeze(0) 440 | nonterminal = nonterminal.unsqueeze(1) 441 | 442 | goal = goal.unsqueeze(-2) 443 | distances = distance(states, goal, distance_scale) 444 | deltas = distances[..., 0:-1] - distances[..., 1:] 445 | deltas = deltas*reward_scale 446 | 447 | final_values = torch.zeros_like(deltas) 448 | # final_values[..., -1] = final_value 449 | # import ipdb; ipdb.set_trace() 450 | for i in reversed(range(final_values.shape[1]-1)): 451 | final_values[..., i] = deltas[..., i] + gamma*nonterminal[..., i]*final_values[..., i+1] 452 | 453 | if all_to_all: 454 | final_values = final_values.flatten(0, 1) 455 | 456 | return final_values.transpose(0, 1) 457 | 458 | 459 | @torch.no_grad() 460 | def sanity_check_gcrl(states, 461 | nonterminal, 462 | actions, 463 | distance, 464 | gamma, 465 | distance_scale, 466 | reward_scale, 467 | network, 468 | window=50, 469 | conv_goal=True 470 | ): 471 | reps = network.encode_targets(states.flatten(2, 3)) 472 | goal_latents = (reps[1] if conv_goal else reps[0]) 473 | goal = goal_latents[window] 474 | 475 | input_latents = reps[1].view(*reps[1].shape[:-1], -1, 7, 7) 476 | input_latents = input_latents[:-1] 477 | spatial_goal = goal.unsqueeze(0) 478 | spatial_goal = spatial_goal.view(*spatial_goal.shape[:-1], -1, 7, 7).expand_as(input_latents) 479 | pred_values = network.head_forward(input_latents.flatten(0, 1), None, None, spatial_goal.flatten(0, 1)) 480 | pred_values = pred_values.view(input_latents.shape[0], input_latents.shape[1], *pred_values.shape[1:]) 481 | 482 | actions = actions.contiguous() 483 | pred_values = pred_values.contiguous() 484 | pred_values = select_at_indexes(actions[:-1], pred_values) 485 | pred_values = from_categorical(pred_values, limit=10, logits=False) 486 | 487 | returns = calculate_true_values(goal_latents.transpose(0, 1), 488 | goal, 489 | distance, 490 | gamma, 491 | pred_values[-1], 492 | nonterminal[:window], 493 | distance_scale, 494 | reward_scale) 495 | 496 | return pred_values, returns 497 | 498 | 499 | def discount_return_n_step(reward, done, n_step, discount, return_dest=None, 500 | done_n_dest=None, do_truncated=False): 501 | """Time-major inputs, optional other dimension: [T], [T,B], etc. Computes 502 | n-step discounted returns within the timeframe of the of given rewards. If 503 | `do_truncated==False`, then only compute at time-steps with full n-step 504 | future rewards are provided (i.e. 
not at last n-steps--output shape will 505 | change!). Returns n-step returns as well as n-step done signals, which is 506 | True if `done=True` at any future time before the n-step target bootstrap 507 | would apply (bootstrap in the algo, not here).""" 508 | rlen = reward.shape[0] 509 | if not do_truncated: 510 | rlen -= (n_step - 1) 511 | return_ = torch.zeros( 512 | (rlen,) + reward.shape[1:], dtype=reward.dtype, device=reward.device) 513 | done_n = torch.zeros( 514 | (rlen,) + reward.shape[1:], dtype=done.dtype, device=done.device) 515 | return_[:] = reward[:rlen].float() # 1-step return is current reward. 516 | done_n[:] = done[:rlen].float() # True at time t if done any time by t + n - 1 517 | 518 | done_dtype = done.dtype 519 | done_n = done_n.type(reward.dtype) 520 | done = done.type(reward.dtype) 521 | 522 | if n_step > 1: 523 | if do_truncated: 524 | for n in range(1, n_step): 525 | return_[:-n] += (discount ** n) * reward[n:n + rlen] * (1 - done_n[:-n]) 526 | done_n[:-n] = torch.max(done_n[:-n], done[n:n + rlen]) 527 | else: 528 | for n in range(1, n_step): 529 | return_ += (discount ** n) * reward[n:n + rlen] * (1 - done_n) 530 | done_n = torch.max(done_n, done[n:n + rlen]) # Supports tensors. 531 | done_n = done_n.type(done_dtype) 532 | return return_, done_n 533 | 534 | -------------------------------------------------------------------------------- /src/rlpyt_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from rlpyt.samplers.base import BaseSampler 3 | from rlpyt.samplers.buffer import build_samples_buffer 4 | from rlpyt.samplers.parallel.cpu.collectors import CpuResetCollector 5 | from rlpyt.samplers.serial.collectors import SerialEvalCollector 6 | from rlpyt.utils.buffer import buffer_from_example, torchify_buffer, numpify_buffer 7 | from rlpyt.utils.logging import logger 8 | from rlpyt.utils.quick_args import save__init__args 9 | from rlpyt.utils.seed import set_seed 10 | from rlpyt.runners.minibatch_rl import MinibatchRlEval 11 | from src.offline_dataset import get_offline_dataloaders 12 | from rlpyt.utils.prog_bar import ProgBarCounter 13 | 14 | import wandb 15 | import psutil 16 | from tqdm import tqdm, trange 17 | 18 | import torch 19 | import numpy as np 20 | import time 21 | 22 | 23 | atari_human_scores = dict( 24 | alien=7127.7, amidar=1719.5, assault=742.0, asterix=8503.3, 25 | bank_heist=753.1, battle_zone=37187.5, boxing=12.1, 26 | breakout=30.5, chopper_command=7387.8, crazy_climber=35829.4, 27 | demon_attack=1971.0, freeway=29.6, frostbite=4334.7, 28 | gopher=2412.5, hero=30826.4, jamesbond=302.8, kangaroo=3035.0, 29 | krull=2665.5, kung_fu_master=22736.3, ms_pacman=6951.6, pong=14.6, 30 | private_eye=69571.3, qbert=13455.0, road_runner=7845.0, 31 | seaquest=42054.7, up_n_down=11693.2 32 | ) 33 | 34 | atari_der_scores = dict( 35 | alien=739.9, amidar=188.6, assault=431.2, asterix=470.8, 36 | bank_heist=51.0, battle_zone=10124.6, boxing=0.2, 37 | breakout=1.9, chopper_command=861.8, crazy_climber=16185.3, 38 | demon_attack=508, freeway=27.9, frostbite=866.8, 39 | gopher=349.5, hero=6857.0, jamesbond=301.6, 40 | kangaroo=779.3, krull=2851.5, kung_fu_master=14346.1, 41 | ms_pacman=1204.1, pong=-19.3, private_eye=97.8, qbert=1152.9, 42 | road_runner=9600.0, seaquest=354.1, up_n_down=2877.4, 43 | ) 44 | 45 | atari_spr_scores = dict( 46 | alien=919.6, amidar=159.6, assault=699.5, asterix=983.5, 47 | bank_heist=370.1, battle_zone=14472.0, boxing=30.5, 48 | breakout=15.6, 
chopper_command=1130.0, crazy_climber=36659.8, 49 | demon_attack=636.4, freeway=24.6, frostbite=1811.0, 50 | gopher=593.4, hero=5602.8, jamesbond=378.7, 51 | kangaroo=3876.0, krull=3810.3, kung_fu_master=14135.8, 52 | ms_pacman=1205.3, pong=-3.8, private_eye=20.2, qbert=791.8, 53 | road_runner=13062.4, seaquest=603.8, up_n_down=7307.8, 54 | ) 55 | 56 | atari_nature_scores = dict( 57 | alien=3069, amidar=739.5, assault=3359, 58 | asterix=6012, bank_heist=429.7, battle_zone=26300., 59 | boxing=71.8, breakout=401.2, chopper_command=6687., 60 | crazy_climber=114103, demon_attack=9711., freeway=30.3, 61 | frostbite=328.3, gopher=8520., hero=19950., jamesbond=576.7, 62 | kangaroo=6740., krull=3805., kung_fu_master=23270., 63 | ms_pacman=2311., pong=18.9, private_eye=1788., 64 | qbert=10596., road_runner=18257., seaquest=5286., up_n_down=8456. 65 | ) 66 | 67 | atari_random_scores = dict( 68 | alien=227.8, amidar=5.8, assault=222.4, 69 | asterix=210.0, bank_heist=14.2, battle_zone=2360.0, 70 | boxing=0.1, breakout=1.7, chopper_command=811.0, 71 | crazy_climber=10780.5, demon_attack=152.1, freeway=0.0, 72 | frostbite=65.2, gopher=257.6, hero=1027.0, jamesbond=29.0, 73 | kangaroo=52.0, krull=1598.0, kung_fu_master=258.5, 74 | ms_pacman=307.3, pong=-20.7, private_eye=24.9, 75 | qbert=163.9, road_runner=11.5, seaquest=68.4, up_n_down=533.4 76 | ) 77 | 78 | atari_offline_scores = { 79 | 'air_raid': 8438.86630859375, 80 | 'alien': 2766.808740234375, 81 | 'amidar': 1556.9634033203124, 82 | 'assault': 1946.0983642578126, 83 | 'asterix': 4131.7666015625, 84 | 'asteroids': 988.1867919921875, 85 | 'atlantis': 944228.0, 86 | 'bank_heist': 907.7182373046875, 87 | 'battle_zone': 26458.991015625, 88 | 'beam_rider': 6453.26220703125, 89 | 'berzerk': 5934.23671875, 90 | 'bowling': 39.969451141357425, 91 | 'boxing': 84.11411743164062, 92 | 'breakout': 157.86087036132812, 93 | 'carnival': 5339.45888671875, 94 | 'centipede': 3972.48896484375, 95 | 'chopper_command': 3678.1458984375, 96 | 'crazy_climber': 118080.240625, 97 | 'demon_attack': 6517.02294921875, 98 | 'double_dunk': -1.2223684310913085, 99 | 'elevator_action': 1056.0, 100 | 'enduro': 1016.2788940429688, 101 | 'fishing_derby': 18.566691207885743, 102 | 'freeway': 26.761290740966796, 103 | 'frostbite': 1643.6466918945312, 104 | 'gopher': 8240.9982421875, 105 | 'gravitar': 310.55962524414065, 106 | 'hero': 16233.5439453125, 107 | 'ice_hockey': -4.018936491012573, 108 | 'jamesbond': 777.7283569335938, 109 | 'journey_escape': -1838.3529296875, 110 | 'kangaroo': 14125.109765625, 111 | 'krull': 7238.50810546875, 112 | 'kung_fu_master': 26637.877734375, 113 | 'montezuma_revenge': 2.6229507446289064, 114 | 'ms_pacman': 4171.52939453125, 115 | 'name_this_game': 8645.0869140625, 116 | 'phoenix': 5122.29873046875, 117 | 'pitfall': -2.578418827056885, 118 | 'pong': 18.253971099853516, 119 | 'pooyan': 4135.323828125, 120 | 'private_eye': 1415.1702465057374, 121 | 'qbert': 12275.1263671875, 122 | 'riverraid': 12798.88203125, 123 | 'road_runner': 47880.48203125, 124 | 'robotank': 63.44000015258789, 125 | 'seaquest': 3233.4708984375, 126 | 'skiing': -18856.73046875, 127 | 'solaris': 2041.66669921875, 128 | 'space_invaders': 2044.6254638671876, 129 | 'star_gunner': 55103.8390625, 130 | 'tennis': 0.0, 131 | 'time_pilot': 4160.50830078125, 132 | 'tutankham': 189.23845520019532, 133 | 'up_n_down': 15677.91884765625, 134 | 'venture': 60.28846340179443, 135 | 'video_pinball': 335055.6875, 136 | 'wizard_of_wor': 1787.789697265625, 137 | 'yars_revenge': 26762.979296875, 138 | 
'zaxxon': 4681.930334472656 139 | } 140 | 141 | 142 | def maybe_update_summary(key, value): 143 | if key not in wandb.run.summary.keys(): 144 | wandb.run.summary[key] = value 145 | else: 146 | wandb.run.summary[key] = max(value, wandb.run.summary[key]) 147 | 148 | 149 | class MinibatchRlEvalWandb(MinibatchRlEval): 150 | 151 | def __init__(self, final_eval_only=False, no_eval=False, freeze_encoder=False, 152 | linear_only=False, save_fn=None, start_itr=0, save_every=None, 153 | *args, **kwargs): 154 | super().__init__(*args, **kwargs) 155 | self.final_eval_only = final_eval_only 156 | self.no_eval = no_eval 157 | self.freeze_encoder = freeze_encoder 158 | self.linear_only = linear_only 159 | 160 | def log_diagnostics(self, itr, eval_traj_infos, eval_time): 161 | cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size 162 | self.wandb_info = {'cum_steps': cum_steps} 163 | super().log_diagnostics(itr, eval_traj_infos, eval_time) 164 | wandb.log(self.wandb_info) 165 | 166 | def startup(self): 167 | """ 168 | Sets hardware affinities, initializes the following: 1) sampler (which 169 | should initialize the agent), 2) agent device and data-parallel wrapper (if applicable), 170 | 3) algorithm, 4) logger. 171 | """ 172 | p = psutil.Process() 173 | try: 174 | if (self.affinity.get("master_cpus", None) is not None and 175 | self.affinity.get("set_affinity", True)): 176 | p.cpu_affinity(self.affinity["master_cpus"]) 177 | cpu_affin = p.cpu_affinity() 178 | except AttributeError: 179 | cpu_affin = "UNAVAILABLE MacOS" 180 | logger.log(f"Runner {getattr(self, 'rank', '')} master CPU affinity: " 181 | f"{cpu_affin}.") 182 | if self.affinity.get("master_torch_threads", None) is not None: 183 | torch.set_num_threads(self.affinity["master_torch_threads"]) 184 | logger.log(f"Runner {getattr(self, 'rank', '')} master Torch threads: " 185 | f"{torch.get_num_threads()}.") 186 | set_seed(self.seed) 187 | torch.backends.cudnn.deterministic = True 188 | torch.backends.cudnn.benchmark = False 189 | # try: 190 | # torch.set_deterministic(True) 191 | # except: 192 | # print("Not doing torch.set_deterministic(True); please update Torch") 193 | 194 | self.rank = rank = getattr(self, "rank", 0) 195 | self.world_size = world_size = getattr(self, "world_size", 1) 196 | examples = self.sampler.initialize( 197 | agent=self.agent, # Agent gets initialized in sampler. 198 | affinity=self.affinity, 199 | seed=self.seed + 1, 200 | bootstrap_value=getattr(self.algo, "bootstrap_value", False), 201 | traj_info_kwargs=self.get_traj_info_kwargs(), 202 | rank=rank, 203 | world_size=world_size, 204 | ) 205 | self.itr_batch_size = self.sampler.batch_spec.size * world_size 206 | n_itr = self.get_n_itr() 207 | if torch.cuda.is_available(): 208 | self.agent.to_device(self.affinity.get("cuda_idx", None)) 209 | if world_size > 1: 210 | self.agent.data_parallel() 211 | self.algo.initialize( 212 | agent=self.agent, 213 | n_itr=n_itr, 214 | batch_spec=self.sampler.batch_spec, 215 | mid_batch_reset=self.sampler.mid_batch_reset, 216 | examples=examples, 217 | world_size=world_size, 218 | rank=rank, 219 | ) 220 | self.initialize_logging() 221 | return n_itr 222 | 223 | def _log_infos(self, traj_infos=None): 224 | """ 225 | Writes trajectory info and optimizer info into csv via the logger. 226 | Resets stored optimizer info. Also dumps the model's parameters to dist 227 | if save_fn was provided. 
228 | """ 229 | if traj_infos is None: 230 | traj_infos = self._traj_infos 231 | if traj_infos: 232 | for k in traj_infos[0]: 233 | if not k.startswith("_"): 234 | values = [info[k] for info in traj_infos] 235 | logger.record_tabular_misc_stat(k, 236 | values) 237 | 238 | wandb.run.summary[k] = np.average(values) 239 | self.wandb_info[k + "Average"] = np.average(values) 240 | self.wandb_info[k + "Std"] = np.std(values) 241 | self.wandb_info[k + "Min"] = np.min(values) 242 | self.wandb_info[k + "Max"] = np.max(values) 243 | self.wandb_info[k + "Median"] = np.median(values) 244 | if k == 'GameScore': 245 | game = self.sampler.env_kwargs['game'] 246 | random_score = atari_random_scores[game] 247 | der_score = atari_der_scores[game] 248 | spr_score = atari_spr_scores[game] 249 | nature_score = atari_nature_scores[game] 250 | human_score = atari_human_scores[game] 251 | offline_score = atari_offline_scores[game] 252 | normalized_score = (np.average(values) - random_score) / (human_score - random_score) 253 | der_normalized_score = (np.average(values) - random_score) / (der_score - random_score) 254 | spr_normalized_score = (np.average(values) - random_score) / (spr_score - random_score) 255 | nature_normalized_score = (np.average(values) - random_score) / (nature_score - random_score) 256 | offline_normalized_score = (np.average(values) - random_score) / (offline_score - random_score) 257 | self.wandb_info[k + "Normalized"] = normalized_score 258 | self.wandb_info[k + "DERNormalized"] = der_normalized_score 259 | self.wandb_info[k + "SPRNormalized"] = spr_normalized_score 260 | self.wandb_info[k + "NatureNormalized"] = nature_normalized_score 261 | self.wandb_info[k + "OfflineNormalized"] = offline_normalized_score 262 | 263 | maybe_update_summary(k+"Best", np.average(values)) 264 | maybe_update_summary(k+"NormalizedBest", normalized_score) 265 | maybe_update_summary(k+"DERNormalizedBest", der_normalized_score) 266 | maybe_update_summary(k+"SPRNormalizedBest", spr_normalized_score) 267 | maybe_update_summary(k+"NatureNormalizedBest", nature_normalized_score) 268 | maybe_update_summary(k+"OfflineNormalizedBest", offline_normalized_score) 269 | 270 | if self._opt_infos: 271 | for k, v in self._opt_infos.items(): 272 | logger.record_tabular_misc_stat(k, v) 273 | self.wandb_info[k] = np.average(v) 274 | wandb.run.summary[k] = np.average(v) 275 | self._opt_infos = {k: list() for k in self._opt_infos} # (reset) 276 | 277 | def evaluate_agent(self, itr): 278 | """ 279 | Record offline evaluation of agent performance, by ``sampler.evaluate_agent()``. 280 | """ 281 | if itr > 0: 282 | self.pbar.stop() 283 | 284 | if self.final_eval_only: 285 | eval = itr == 0 or itr >= self.n_itr - 1 286 | else: 287 | eval = itr == 0 or itr >= self.min_itr_learn - 1 288 | if eval and not self.no_eval: 289 | logger.log("Evaluating agent...") 290 | self.agent.eval_mode(itr) # Might be agent in sampler. 
291 | eval_time = -time.time() 292 | traj_infos = self.sampler.evaluate_agent(itr) 293 | eval_time += time.time() 294 | else: 295 | traj_infos = [] 296 | eval_time = 0.0 297 | logger.log("Evaluation runs complete.") 298 | return traj_infos, eval_time 299 | 300 | def train(self): 301 | raise NotImplementedError 302 | 303 | 304 | class OnlineEval(MinibatchRlEvalWandb): 305 | def __init__(self, epochs, dataloader, use_offline_data=False, *args, **kwargs): 306 | super().__init__(*args, **kwargs) 307 | 308 | def train(self): 309 | n_itr = self.startup() 310 | wandb.watch(self.agent.model) 311 | self.n_itr = n_itr 312 | with logger.prefix(f"itr #0 "): 313 | eval_traj_infos, eval_time = self.evaluate_agent(0) 314 | self.log_diagnostics(0, eval_traj_infos, eval_time) 315 | for itr in range(self.n_itr): 316 | logger.set_iteration(itr) 317 | with logger.prefix(f"itr #{itr} "): 318 | just_logged = False 319 | self.agent.sample_mode(itr) 320 | samples, traj_infos = self.sampler.obtain_samples(itr) 321 | self.agent.train_mode(itr) 322 | opt_info = self.algo.optimize_agent(itr, samples) 323 | self.store_diagnostics(itr, traj_infos, opt_info) 324 | if (itr + 1) % self.log_interval_itrs == 0: 325 | eval_traj_infos, eval_time = self.evaluate_agent(itr) 326 | self.log_diagnostics(itr, eval_traj_infos, eval_time) 327 | just_logged = True 328 | if not just_logged: 329 | eval_traj_infos, eval_time = self.evaluate_agent(itr) 330 | self.log_diagnostics(itr, eval_traj_infos, eval_time) 331 | self.shutdown() 332 | 333 | 334 | class OfflineEval(MinibatchRlEvalWandb): 335 | def __init__(self, epochs, dataloader, save_fn=None, start_itr=0, save_every=None, use_offline_data=True, *args, **kwargs): 336 | super().__init__(*args, **kwargs) 337 | self.epochs = epochs 338 | self.itr = start_itr 339 | self.save_fn = save_fn 340 | self.dataloader = get_offline_dataloaders(**dataloader)[0] 341 | self.algo.offline_dataloader = iter(self.dataloader) 342 | self.algo.offline_dataset = self.dataloader 343 | self.save_every = save_every 344 | self.log_interval_itrs = self.save_every 345 | self.pbar = ProgBarCounter(self.save_every) 346 | assert use_offline_data, "Cannot pre-train without offline dataset" 347 | 348 | def get_n_itr(self): 349 | return self.save_every 350 | 351 | def train(self): 352 | self.n_itr = self.startup() 353 | 354 | batches_per_epoch = len(self.dataloader) 355 | self.total_iters = self.epochs * batches_per_epoch 356 | if self.save_every is None: 357 | self.save_every = batches_per_epoch 358 | else: 359 | self.save_every = self.save_every 360 | 361 | with logger.prefix(f"itr #0 "): 362 | eval_traj_infos, eval_time = self.evaluate_agent(0) 363 | self.log_diagnostics(0, eval_traj_infos, eval_time) 364 | done = self.itr > self.total_iters 365 | while not done: 366 | self.itr = self.itr + 1 367 | logger.set_iteration(self.itr) 368 | with logger.prefix(f"itr #{self.itr} "): 369 | self.agent.train_mode(self.itr) 370 | opt_info = self.algo.optimize_agent(self.itr) 371 | self.store_diagnostics(self.itr, [], opt_info) 372 | if self.itr == self.total_iters or self.itr % self.save_every == 0: 373 | eval_traj_infos, eval_time = self.evaluate_agent(self.itr) 374 | if self.save_fn is not None: 375 | self.save_fn(self.agent.model.state_dict(), self.algo.optimizer.state_dict(), self.itr) 376 | self.log_diagnostics(self.itr, eval_traj_infos, eval_time) 377 | if self.itr > self.total_iters: 378 | self.shutdown() 379 | return 380 | 381 | self.shutdown() 382 | 383 | 384 | def delete_ind_from_tensor(tensor, ind): 385 | tensor = 
torch.cat([tensor[:ind], tensor[ind+1:]], 0) 386 | return tensor 387 | 388 | 389 | def delete_ind_from_array(array, ind): 390 | tensor = np.concatenate([array[:ind], array[ind+1:]], 0) 391 | return tensor 392 | 393 | 394 | class OneToOneSerialEvalCollector(SerialEvalCollector): 395 | def collect_evaluation(self, itr): 396 | assert self.max_trajectories == len(self.envs) 397 | traj_infos = [self.TrajInfoCls() for _ in range(len(self.envs))] 398 | completed_traj_infos = list() 399 | observations = list() 400 | for env in self.envs: 401 | observations.append(env.reset()) 402 | observation = buffer_from_example(observations[0], len(self.envs)) 403 | for b, o in enumerate(observations): 404 | observation[b] = o 405 | action = buffer_from_example(self.envs[0].action_space.null_value(), 406 | len(self.envs)) 407 | reward = np.zeros(len(self.envs), dtype="float32") 408 | obs_pyt, act_pyt, rew_pyt = torchify_buffer((observation, action, reward)) 409 | self.agent.reset() 410 | self.agent.eval_mode(itr) 411 | live_envs = list(range(len(self.envs))) 412 | for t in range(self.max_T): 413 | act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt, rew_pyt) 414 | action = numpify_buffer(act_pyt) 415 | 416 | b = 0 417 | while b < len(live_envs): # don't want to do a for loop since live envs changes over time 418 | env_id = live_envs[b] 419 | o, r, d, env_info = self.envs[env_id].step(action[b]) 420 | traj_infos[env_id].step(observation[b], 421 | action[b], r, d, 422 | agent_info[b], env_info) 423 | if getattr(env_info, "traj_done", d): 424 | completed_traj_infos.append(traj_infos[env_id].terminate(o)) 425 | observation = delete_ind_from_array(observation, b) 426 | reward = delete_ind_from_array(reward, b) 427 | action = delete_ind_from_array(action, b) 428 | obs_pyt, act_pyt, rew_pyt = torchify_buffer((observation, action, reward)) 429 | 430 | del live_envs[b] 431 | b -= 1 # live_envs[b] is now the next env, so go back one. 432 | else: 433 | observation[b] = o 434 | reward[b] = r 435 | 436 | b += 1 437 | 438 | if (self.max_trajectories is not None and 439 | len(completed_traj_infos) >= self.max_trajectories): 440 | logger.log("Evaluation reached max num trajectories " 441 | f"({self.max_trajectories}).") 442 | return completed_traj_infos 443 | 444 | if t == self.max_T - 1: 445 | logger.log("Evaluation reached max num time steps " 446 | f"({self.max_T}).") 447 | return completed_traj_infos 448 | 449 | 450 | class SerialSampler(BaseSampler): 451 | """The simplest sampler; no parallelism, everything occurs in same, master 452 | Python process. This can be easier for debugging (e.g. can use 453 | ``breakpoint()`` in master process) and might be fast enough for 454 | experiment purposes. Should be used with collectors which generate the 455 | agent's actions internally, i.e. CPU-based collectors but not GPU-based 456 | ones. 457 | NOTE: We modify this class from rlpyt to pass an id to EnvCls when creating 458 | environments. 459 | """ 460 | 461 | def __init__(self, *args, CollectorCls=CpuResetCollector, 462 | eval_CollectorCls=SerialEvalCollector, **kwargs): 463 | super().__init__(*args, CollectorCls=CollectorCls, 464 | eval_CollectorCls=eval_CollectorCls, **kwargs) 465 | 466 | def initialize( 467 | self, 468 | agent, 469 | affinity=None, 470 | seed=None, 471 | bootstrap_value=False, 472 | traj_info_kwargs=None, 473 | rank=0, 474 | world_size=1, 475 | ): 476 | """Store the input arguments. Instantiate the specified number of environment 477 | instances (``batch_B``). 
Initialize the agent, and pre-allocate a memory buffer 478 | to hold the samples collected in each batch. Applies ``traj_info_kwargs`` settings 479 | to the `TrajInfoCls` by direct class attribute assignment. Instantiates the Collector 480 | and, if applicable, the evaluation Collector. 481 | 482 | Returns a structure of inidividual examples for data fields such as `observation`, 483 | `action`, etc, which can be used to allocate a replay buffer. 484 | """ 485 | B = self.batch_spec.B 486 | envs = [self.EnvCls(id=i, **self.env_kwargs) for i in range(B)] 487 | global_B = B * world_size 488 | env_ranks = list(range(rank * B, (rank + 1) * B)) 489 | agent.initialize(envs[0].spaces, share_memory=False, 490 | global_B=global_B, env_ranks=env_ranks) 491 | samples_pyt, samples_np, examples = build_samples_buffer(agent, envs[0], 492 | self.batch_spec, bootstrap_value, agent_shared=False, 493 | env_shared=False, subprocess=False) 494 | if traj_info_kwargs: 495 | for k, v in traj_info_kwargs.items(): 496 | setattr(self.TrajInfoCls, "_" + k, v) # Avoid passing at init. 497 | collector = self.CollectorCls( 498 | rank=0, 499 | envs=envs, 500 | samples_np=samples_np, 501 | batch_T=self.batch_spec.T, 502 | TrajInfoCls=self.TrajInfoCls, 503 | agent=agent, 504 | global_B=global_B, 505 | env_ranks=env_ranks, # Might get applied redundantly to agent. 506 | ) 507 | if self.eval_n_envs > 0: # May do evaluation. 508 | eval_envs = [self.EnvCls(id=i, **self.eval_env_kwargs) 509 | for i in range(self.eval_n_envs)] 510 | eval_CollectorCls = self.eval_CollectorCls or SerialEvalCollector 511 | self.eval_collector = eval_CollectorCls( 512 | envs=eval_envs, 513 | agent=agent, 514 | TrajInfoCls=self.TrajInfoCls, 515 | max_T=self.eval_max_steps // self.eval_n_envs, 516 | max_trajectories=self.eval_max_trajectories, 517 | ) 518 | 519 | agent_inputs, traj_infos = collector.start_envs( 520 | self.max_decorrelation_steps) 521 | collector.start_agent() 522 | 523 | self.agent = agent 524 | self.samples_pyt = samples_pyt 525 | self.samples_np = samples_np 526 | self.collector = collector 527 | self.agent_inputs = agent_inputs 528 | self.traj_infos = traj_infos 529 | logger.log("Serial Sampler initialized.") 530 | return examples 531 | 532 | def obtain_samples(self, itr): 533 | """Call the collector to execute a batch of agent-environment interactions. 534 | Return data in torch tensors, and a list of trajectory-info objects from 535 | episodes which ended. 536 | """ 537 | # self.samples_np[:] = 0 # Unnecessary and may take time. 
538 |         agent_inputs, traj_infos, completed_infos = self.collector.collect_batch(
539 |             self.agent_inputs, self.traj_infos, itr)
540 |         self.collector.reset_if_needed(agent_inputs)
541 |         self.agent_inputs = agent_inputs
542 |         self.traj_infos = traj_infos
543 |         return self.samples_pyt, completed_infos
544 |
545 |     def evaluate_agent(self, itr):
546 |         """Call the evaluation collector to execute agent-environment interactions."""
547 |         return self.eval_collector.collect_evaluation(itr)
548 |
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import numpy as np
5 | from contextlib import nullcontext
6 |
7 | from rlpyt.models.utils import update_state_dict
8 | from rlpyt.utils.tensor import select_at_indexes
9 | from rlpyt.utils.tensor import infer_leading_dims, restore_leading_dims
10 | from src.utils import count_parameters, get_augmentation, from_categorical, find_weight_norm, update_state_dict_compat
11 | from src.networks import *
12 |
13 | import copy
14 | EPS = 1e-6  # (NaN-guard)
15 |
16 |
17 | class SPRCatDqnModel(torch.nn.Module):
18 |     """2D convolutional network feeding into MLP with ``n_atoms`` outputs
19 |     per action, representing a discrete probability distribution of Q-values."""
20 |
21 |     def __init__(
22 |             self,
23 |             image_shape,
24 |             output_size,
25 |             n_atoms,
26 |             dueling,
27 |             jumps,
28 |             spr,
29 |             augmentation,
30 |             target_augmentation,
31 |             eval_augmentation,
32 |             dynamics_blocks,
33 |             norm_type,
34 |             noisy_nets,
35 |             aug_prob,
36 |             projection,
37 |             imagesize,
38 |             dqn_hidden_size,
39 |             momentum_tau,
40 |             renormalize,
41 |             q_l1_type,
42 |             dropout,
43 |             predictor,
44 |             rl,
45 |             bc,
46 |             bc_from_values,
47 |             goal_rl,
48 |             goal_n_step,
49 |             noisy_nets_std,
50 |             residual_tm,
51 |             inverse_model,
52 |             encoder,
53 |             goal_conditioning_type,
54 |             resblock="inverted",
55 |             expand_ratio=2,
56 |             use_maxpool=False,
57 |             channels=None,  # None uses default.
58 | kernel_sizes=None, 59 | strides=None, 60 | paddings=None, 61 | framestack=4, 62 | freeze_encoder=False, 63 | share_l1=False, 64 | cnn_scale_factor=1, 65 | blocks_per_group=3, 66 | ln_for_rl_head=False, 67 | state_dict=None, 68 | conv_goal=True, 69 | goal_all_to_all=False, 70 | load_head_to=1, 71 | load_compat_mode=False, 72 | ): 73 | """Instantiates the neural network according to arguments; network defaults 74 | stored within this method.""" 75 | super().__init__() 76 | 77 | self.noisy = noisy_nets 78 | self.aug_prob = aug_prob 79 | self.projection_type = projection 80 | 81 | self.dqn_hidden_size = dqn_hidden_size 82 | 83 | if resblock == "inverted": 84 | resblock = InvertedResidual 85 | else: 86 | resblock = Residual 87 | 88 | self.transforms = get_augmentation(augmentation, imagesize) 89 | self.target_transforms = get_augmentation(target_augmentation, imagesize) 90 | self.eval_transforms = get_augmentation(eval_augmentation, imagesize) 91 | 92 | self.dueling = dueling 93 | f, c = image_shape[:2] 94 | in_channels = np.prod(image_shape[:2]) 95 | if encoder == "resnet": 96 | self.conv = ResnetCNN(in_channels, 97 | depths=[int(32*cnn_scale_factor), 98 | int(64*cnn_scale_factor), 99 | int(64*cnn_scale_factor)], 100 | strides=[3, 2, 2], 101 | norm_type=norm_type, 102 | blocks_per_group=blocks_per_group, 103 | resblock=resblock, 104 | expand_ratio=expand_ratio,) 105 | 106 | elif encoder.lower() == "normednature": 107 | self.conv = Conv2dModel( 108 | in_channels=in_channels, 109 | channels=[int(32*cnn_scale_factor), 110 | int(64*cnn_scale_factor), 111 | int(64*cnn_scale_factor)], 112 | kernel_sizes=[8, 4, 3], 113 | strides=[4, 2, 1], 114 | paddings=[0, 0, 0], 115 | use_maxpool=False, 116 | dropout=dropout, 117 | norm_type=norm_type, 118 | ) 119 | else: 120 | self.conv = Conv2dModel( 121 | in_channels=in_channels, 122 | channels=[int(32*cnn_scale_factor), 123 | int(64*cnn_scale_factor), 124 | int(64*cnn_scale_factor)], 125 | kernel_sizes=[8, 4, 3], 126 | strides=[4, 2, 1], 127 | paddings=[0, 0, 0], 128 | use_maxpool=False, 129 | dropout=dropout, 130 | ) 131 | 132 | fake_input = torch.zeros(1, f*c, imagesize, imagesize) 133 | fake_output = self.conv(fake_input) 134 | self.latent_shape = fake_output.shape[1:] 135 | self.hidden_size = fake_output.shape[1] 136 | self.pixels = fake_output.shape[-1]*fake_output.shape[-2] 137 | print("Spatial latent size is {}".format(fake_output.shape[1:])) 138 | 139 | self.renormalize = init_normalization(self.hidden_size, renormalize) 140 | 141 | self.jumps = jumps 142 | self.rl = rl 143 | self.bc = bc 144 | self.bc_from_values = bc_from_values 145 | self.goal_n_step = goal_n_step 146 | self.use_spr = spr 147 | self.target_augmentation = target_augmentation 148 | self.eval_augmentation = eval_augmentation 149 | self.num_actions = output_size 150 | 151 | self.head = GoalConditionedDuelingHead(self.hidden_size, 152 | output_size, 153 | hidden_size=self.dqn_hidden_size, 154 | pixels=self.pixels, 155 | noisy=self.noisy, 156 | conv_goals=conv_goal, 157 | goal_all_to_all=goal_all_to_all, 158 | share_l1=share_l1, 159 | n_atoms=n_atoms, 160 | ln_for_dqn=ln_for_rl_head, 161 | conditioning_type=goal_conditioning_type, 162 | std_init=noisy_nets_std) 163 | 164 | # Gotta initialize this no matter what or the state dict won't load 165 | self.dynamics_model = TransitionModel(channels=self.hidden_size, 166 | num_actions=output_size, 167 | hidden_size=self.hidden_size, 168 | blocks=dynamics_blocks, 169 | norm_type=norm_type, 170 | resblock=resblock, 171 | 
expand_ratio=expand_ratio, 172 | renormalize=self.renormalize, 173 | residual=residual_tm) 174 | 175 | self.momentum_tau = momentum_tau 176 | if self.projection_type == "mlp": 177 | self.projection = nn.Sequential( 178 | nn.Flatten(-3, -1), 179 | nn.Linear(self.pixels*self.hidden_size, 512), 180 | TransposedBN1D(512), 181 | nn.ReLU(), 182 | nn.Linear(512, 256) 183 | ) 184 | self.target_projection = self.projection 185 | projection_size = 256 186 | elif self.projection_type == "q_l1": 187 | if goal_rl: 188 | layers = [self.head.goal_linears[0], self.head.goal_linears[2]] 189 | else: 190 | layers = [self.head.rl_linears[0], self.head.rl_linears[2]] 191 | self.projection = QL1Head(layers, dueling=dueling, type=q_l1_type) 192 | projection_size = self.projection.out_features 193 | else: 194 | projection_size = self.pixels*self.hidden_size 195 | 196 | self.target_projection = self.projection 197 | self.target_projection = copy.deepcopy(self.target_projection) 198 | self.target_encoder = copy.deepcopy(self.conv) 199 | for param in (list(self.target_encoder.parameters()) + 200 | list(self.target_projection.parameters())): 201 | param.requires_grad = False 202 | 203 | if self.bc and not self.bc_from_values: 204 | self.bc_head = nn.Sequential(nn.ReLU(), 205 | nn.Linear(projection_size, output_size)) 206 | 207 | # Gotta initialize this no matter what or the state dict won't load 208 | if predictor == "mlp": 209 | self.predictor = nn.Sequential( 210 | nn.Linear(projection_size, projection_size*2), 211 | TransposedBN1D(projection_size*2), 212 | nn.ReLU(), 213 | nn.Linear(projection_size*2, projection_size) 214 | ) 215 | elif predictor == "linear": 216 | self.predictor = nn.Sequential( 217 | nn.Linear(projection_size, projection_size), 218 | ) 219 | elif predictor == "none": 220 | self.predictor = nn.Identity() 221 | 222 | self.use_inverse_model = inverse_model 223 | # Gotta initialize this no matter what or the state dict won't load 224 | self.inverse_model = InverseModelHead(projection_size, 225 | output_size,) 226 | 227 | print("Initialized model with {} parameters; CNN has {}.".format(count_parameters(self), count_parameters(self.conv))) 228 | print("Initialized CNN weight norm is {}".format(find_weight_norm(self.conv.parameters()).item())) 229 | 230 | if state_dict is not None: 231 | if load_compat_mode: 232 | state_dict = update_state_dict_compat(state_dict, self.state_dict()) 233 | self.load_state_dict(state_dict) 234 | print("Loaded CNN weight norm is {}".format(find_weight_norm(self.conv.parameters()).item())) 235 | if rl: 236 | self.head.copy_base_params(up_to=load_head_to) 237 | self.head.reset_noise_params() 238 | 239 | self.frozen_encoder = freeze_encoder 240 | if self.frozen_encoder: 241 | self.freeze_encoder() 242 | 243 | def set_sampling(self, sampling): 244 | if self.noisy: 245 | self.head.set_sampling(sampling) 246 | 247 | def freeze_encoder(self): 248 | print("Freezing CNN") 249 | for param in self.conv.parameters(): 250 | param.requires_grad = False 251 | 252 | def spr_loss(self, f_x1s, f_x2s): 253 | f_x1 = F.normalize(f_x1s.float(), p=2., dim=-1, eps=1e-3) 254 | f_x2 = F.normalize(f_x2s.float(), p=2., dim=-1, eps=1e-3) 255 | loss = F.mse_loss(f_x1, f_x2, reduction="none").sum(-1).mean(0) 256 | return loss 257 | 258 | def do_spr_loss(self, pred_latents, targets, observation): 259 | pred_latents = self.predictor(pred_latents) 260 | 261 | targets = targets.view(-1, observation.shape[1], 262 | self.jumps+1, 263 | targets.shape[-1]).transpose(1, 2) 264 | latents = pred_latents.view(-1, 
observation.shape[1], 265 | self.jumps+1, 266 | pred_latents.shape[-1]).transpose(1, 2) 267 | 268 | spr_loss = self.spr_loss(latents, targets).view(-1, observation.shape[1]) # split to batch, jumps 269 | 270 | return spr_loss 271 | 272 | @torch.no_grad() 273 | def calculate_diversity(self, global_latents, observation): 274 | global_latents = global_latents.view(observation.shape[1], self.jumps+1, global_latents.shape[-1])[:, 0] 275 | # shape is jumps, bs, dim 276 | global_latents = F.normalize(global_latents, p=2., dim=-1, eps=1e-3) 277 | cos_sim = torch.matmul(global_latents, global_latents.transpose(0, 1)) 278 | mask = 1 - (torch.eye(cos_sim.shape[0], device=cos_sim.device, dtype=torch.float)) 279 | 280 | cos_sim = cos_sim*mask 281 | offset = cos_sim.shape[-1]/(cos_sim.shape[-1] - 1) 282 | cos_sim = cos_sim.mean()*offset 283 | 284 | return cos_sim 285 | 286 | def apply_transforms(self, transforms, image): 287 | for transform in transforms: 288 | image = maybe_transform(image, transform, p=self.aug_prob) 289 | return image 290 | 291 | @torch.no_grad() 292 | def transform(self, images, transforms, augment=False): 293 | images = images.float()/255. if images.dtype == torch.uint8 else images 294 | if augment: 295 | flat_images = images.reshape(-1, *images.shape[-3:]) 296 | processed_images = self.apply_transforms(transforms, 297 | flat_images) 298 | processed_images = processed_images.view(*images.shape[:-3], 299 | *processed_images.shape[1:]) 300 | return processed_images 301 | else: 302 | return images 303 | 304 | def split_stem_model_params(self): 305 | stem_params = list(self.conv.parameters()) + list(self.head.parameters()) 306 | model_params = self.dynamics_model.parameters() 307 | 308 | return stem_params, model_params 309 | 310 | def sort_params(self, params_dict): 311 | return [params_dict[k] for k in sorted(params_dict.keys())] 312 | 313 | def list_params(self): 314 | all_parameters = {k: v for k, v in self.named_parameters()} 315 | conv_params = {k: v for k, v in all_parameters.items() if k.startswith("conv")} 316 | dynamics_model_params = {k: v for k, v in all_parameters.items() if k.startswith("dynamics_model")} 317 | 318 | q_l1_params = {k: v for k, v in all_parameters.items() 319 | if (k.startswith("head.goal_value.0") 320 | or k.startswith("head.goal_advantage.0") 321 | or k.startswith("head.rl_value.0") 322 | or k.startswith("head.rl_advantage.0"))} 323 | 324 | other_params = {k: v for k, v in all_parameters.items() if not 325 | (k.startswith("target") 326 | or k in conv_params.keys() 327 | or k in dynamics_model_params.keys() 328 | or k in q_l1_params.keys())} 329 | 330 | return self.sort_params(conv_params), \ 331 | self.sort_params(dynamics_model_params), \ 332 | self.sort_params(q_l1_params), \ 333 | self.sort_params(other_params) 334 | 335 | def stem_forward(self, img, prev_action=None, prev_reward=None): 336 | """Returns the normalized output of convolutional layers.""" 337 | # Infer (presence of) leading dimensions: [T,B], [B], or []. 338 | lead_dim, T, B, img_shape = infer_leading_dims(img, 3) 339 | 340 | with torch.no_grad() if self.frozen_encoder else nullcontext(): 341 | conv_out = self.conv(img.view(T * B, *img_shape)) # Fold if T dimension. 
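            # self.renormalize is whatever init_normalization(..., renormalize) returned
            # at construction time; with the "max" option that is src.utils.renormalize,
            # which (following the SPR/MuZero recipe) rescales each latent into a bounded
            # range so encoder and transition-model outputs stay on a comparable scale.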
342 | conv_out = self.renormalize(conv_out) 343 | return conv_out 344 | 345 | def head_forward(self, 346 | conv_out, 347 | prev_action, 348 | prev_reward, 349 | goal=None, 350 | logits=False): 351 | lead_dim, T, B, img_shape = infer_leading_dims(conv_out, 3) 352 | p = self.head(conv_out, goal) 353 | 354 | if logits: 355 | p = F.log_softmax(p, dim=-1) 356 | else: 357 | p = F.softmax(p, dim=-1) 358 | 359 | # Restore leading dimensions: [T,B], [B], or [], as input. 360 | p = restore_leading_dims(p, lead_dim, T, B) 361 | return p 362 | 363 | @torch.no_grad() 364 | def encode_targets(self, target_images, project=True): 365 | target_images = self.transform(target_images, self.transforms, True) 366 | target_latents = self.target_encoder(target_images.flatten(0, 1)) 367 | target_latents = self.renormalize(target_latents) 368 | if project: 369 | proj_latents = self.target_projection(target_latents) 370 | proj_latents = proj_latents.view(target_images.shape[0], 371 | target_images.shape[1], 372 | -1) 373 | return proj_latents, target_latents.view(target_images.shape[0], 374 | target_images.shape[1], 375 | -1) 376 | else: 377 | return target_latents.view(target_images.shape[0], 378 | target_images.shape[1], 379 | -1) 380 | 381 | def encode_online(self, images, project=True): 382 | images = self.transform(images, self.transforms, True) 383 | latents = self.conv(images.flatten(0, 1)) 384 | latents = self.renormalize(latents) 385 | if project: 386 | proj_latents = self.projection(latents) 387 | proj_latents = proj_latents.view(images.shape[0], 388 | images.shape[1], 389 | -1) 390 | return proj_latents, latents.view(images.shape[0], 391 | images.shape[1], 392 | -1) 393 | else: 394 | return latents.view(images.shape[0], 395 | images.shape[1], 396 | -1) 397 | 398 | def forward(self, 399 | observation, 400 | prev_action, 401 | prev_reward, 402 | goal=None, 403 | train=False, 404 | eval=False): 405 | """ 406 | For convenience reasons with DistributedDataParallel the forward method 407 | has been split into two cases, one for training and one for eval. 
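        In the training branch the encoder output is unrolled through the dynamics
        model for `jumps` steps and the method returns the RL (and, if a goal is
        given, goal-conditioned) log-probabilities together with the SPR,
        inverse-model and behavioral-cloning auxiliary outputs; in the eval branch
        it simply encodes the (optionally augmented) observation and returns the
        softmaxed distributional Q output.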
408 | """ 409 | if train: 410 | pred_latents = [] 411 | input_obs = observation[0].flatten(1, 2) 412 | input_obs = self.transform(input_obs, self.transforms, augment=True) 413 | latent = self.stem_forward(input_obs, 414 | prev_action[0], 415 | prev_reward[0]) 416 | if self.rl or self.bc_from_values: 417 | log_pred_ps = self.head_forward(latent, 418 | prev_action[0], 419 | prev_reward[0], 420 | goal=None, 421 | logits=True) 422 | else: 423 | log_pred_ps = None 424 | 425 | if goal is not None: 426 | goal_log_pred_ps = self.head_forward(latent, 427 | prev_action[0], 428 | prev_reward[0], 429 | goal=goal, 430 | logits=True) 431 | else: 432 | goal_log_pred_ps = None 433 | 434 | pred_latents.append(latent) 435 | if self.jumps > 0: 436 | for j in range(1, self.jumps + 1): 437 | latent = self.step(latent, prev_action[j]) 438 | pred_latents.append(latent) 439 | 440 | with torch.no_grad(): 441 | to_encode = max(self.jumps+1, self.goal_n_step) 442 | target_images = observation[:to_encode].transpose(0, 1).flatten(2, 3) 443 | target_proj, target_latents = self.encode_targets(target_images, project=True) 444 | 445 | pred_latents = torch.stack(pred_latents, 1) 446 | proj_latents = self.projection(pred_latents) 447 | if self.use_spr: 448 | spr_loss = self.do_spr_loss(proj_latents.flatten(0, 1), 449 | target_proj.flatten(0, 1), 450 | observation) 451 | else: 452 | spr_loss = torch.zeros((self.jumps + 1, observation.shape[1]), device=latent.device) 453 | 454 | if self.bc: 455 | if self.bc_from_values: 456 | bc_preds = from_categorical(log_pred_ps.exp(), limit=10, logits=False) 457 | 458 | if self.bc and not self.bc_from_values: 459 | bc_preds = self.bc_head(proj_latents[:, 0]) 460 | else: 461 | bc_preds = None 462 | 463 | if self.use_inverse_model: 464 | stack = torch.cat([proj_latents[:, :-1], target_proj.view(*proj_latents.shape)[:, 1:]], -1) 465 | pred_actions = self.inverse_model(stack.flatten(0, 1)) 466 | pred_actions = pred_actions.view(stack.shape[0], stack.shape[1], *pred_actions.shape[1:]) 467 | pred_actions = pred_actions.transpose(0, 1) 468 | inv_model_loss = F.cross_entropy(pred_actions.flatten(0, 1), 469 | prev_action[1:self.jumps + 1].flatten(0, 1), reduction="none") 470 | inv_model_loss = inv_model_loss.view(*pred_actions.shape[:-1]).mean(0) 471 | else: 472 | inv_model_loss = torch.zeros_like(spr_loss).mean(0) 473 | 474 | diversity = self.calculate_diversity(proj_latents, observation) 475 | update_state_dict(self.target_encoder, 476 | self.conv.state_dict(), 477 | self.momentum_tau) 478 | update_state_dict(self.target_projection, 479 | self.projection.state_dict(), 480 | self.momentum_tau) 481 | 482 | return log_pred_ps,\ 483 | goal_log_pred_ps,\ 484 | spr_loss, \ 485 | target_latents, \ 486 | target_proj, \ 487 | diversity, \ 488 | inv_model_loss, \ 489 | bc_preds, 490 | 491 | else: 492 | observation = observation.flatten(-4, -3) 493 | 494 | transforms = self.eval_transforms if eval else self.target_transforms 495 | img = self.transform(observation, transforms, len(transforms) > 0) 496 | 497 | # Infer (presence of) leading dimensions: [T,B], [B], or []. 498 | lead_dim, T, B, img_shape = infer_leading_dims(img, 3) 499 | 500 | conv_out = self.conv(img.view(T * B, *img_shape)) # Fold if T dimension. 501 | conv_out = self.renormalize(conv_out) 502 | p = self.head(conv_out, goal) 503 | 504 | p = F.softmax(p, dim=-1) 505 | 506 | # Restore leading dimensions: [T,B], [B], or [], as input. 
507 | p = restore_leading_dims(p, lead_dim, T, B) 508 | 509 | return p 510 | 511 | def select_action(self, obs): 512 | if self.bc_from_values or not self.bc: 513 | value = self.forward(obs, None, None, train=False, eval=True) 514 | value = from_categorical(value, logits=False, limit=10) 515 | else: 516 | observation = obs.flatten(-4, -3) 517 | img = self.transform(observation, self.eval_transforms, len(self.eval_transforms) > 0) 518 | lead_dim, T, B, img_shape = infer_leading_dims(img, 3) 519 | conv_out = self.conv(img.view(T * B, *img_shape)) # Fold if T dimension. 520 | conv_out = self.renormalize(conv_out) 521 | proj = self.projection(conv_out) 522 | value = self.bc_head(proj) 523 | value = restore_leading_dims(value, lead_dim, T, B) 524 | return value 525 | 526 | def step(self, state, action): 527 | next_state = self.dynamics_model(state, action) 528 | return next_state 529 | 530 | 531 | class QL1Head(nn.Module): 532 | def __init__(self, layers, dueling=False, type=""): 533 | super().__init__() 534 | self.noisy = "noisy" in type 535 | self.dueling = dueling 536 | self.relu = "relu" in type 537 | 538 | self.encoders = nn.ModuleList(layers) 539 | self.out_features = sum([encoder.out_features for encoder in self.encoders]) 540 | 541 | def forward(self, x): 542 | x = x.flatten(-3, -1) 543 | representations = [] 544 | for encoder in self.encoders: 545 | encoder.noise_override = self.noisy 546 | representations.append(encoder(x)) 547 | encoder.noise_override = None 548 | representation = torch.cat(representations, -1) 549 | if self.relu: 550 | representation = F.relu(representation) 551 | 552 | return representation 553 | 554 | 555 | def maybe_transform(image, transform, p=0.8): 556 | processed_images = transform(image) 557 | if p >= 1: 558 | return processed_images 559 | else: 560 | mask = torch.rand((processed_images.shape[0], 1, 1, 1), 561 | device=processed_images.device) 562 | mask = (mask < p).float() 563 | processed_images = mask * processed_images + (1 - mask) * image 564 | return processed_images 565 | 566 | -------------------------------------------------------------------------------- /src/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from src.utils import renormalize 6 | from rlpyt.models.utils import scale_grad 7 | from rlpyt.utils.tensor import infer_leading_dims, restore_leading_dims 8 | import copy 9 | 10 | 11 | def fixup_init(layer, num_layers): 12 | nn.init.normal_(layer.weight, mean=0, std=np.sqrt( 13 | 2 / (layer.weight.shape[0] * np.prod(layer.weight.shape[2:]))) * num_layers ** (-0.25)) 14 | 15 | 16 | class InvertedResidual(nn.Module): 17 | def __init__(self, in_channels, out_channels, stride, expand_ratio, 18 | norm_type, num_layers=1, groups=-1, 19 | drop_prob=0., bias=True): 20 | super(InvertedResidual, self).__init__() 21 | assert stride in [1, 2, 3] 22 | self.drop_prob = drop_prob 23 | 24 | hidden_dim = round(in_channels * expand_ratio) 25 | 26 | if groups <= 0: 27 | groups = hidden_dim 28 | 29 | conv = nn.Conv2d 30 | 31 | if stride != 1: 32 | self.downsample = nn.Conv2d(in_channels, out_channels, stride, stride) 33 | nn.init.normal_(self.downsample.weight, mean=0, std= 34 | np.sqrt(2 / (self.downsample.weight.shape[0] * 35 | np.prod(self.downsample.weight.shape[2:])))) 36 | else: 37 | self.downsample = False 38 | 39 | if expand_ratio == 1: 40 | conv1 = conv(hidden_dim, hidden_dim, 3, stride, 1, groups=groups, 
bias=bias) 41 | conv2 = conv(hidden_dim, out_channels, 1, 1, 0, bias=bias) 42 | fixup_init(conv1, num_layers) 43 | fixup_init(conv2, num_layers) 44 | self.conv = nn.Sequential( 45 | # dw 46 | conv1, 47 | init_normalization(hidden_dim, norm_type), 48 | nn.ReLU(inplace=True), 49 | # pw-linear 50 | conv2, 51 | init_normalization(out_channels, norm_type), 52 | ) 53 | nn.init.constant_(self.conv[-1].weight, 0) 54 | else: 55 | conv1 = conv(in_channels, hidden_dim, 1, 1, 0, bias=bias) 56 | conv2 = conv(hidden_dim, hidden_dim, 3, stride, 1, groups=groups, bias=bias) 57 | conv3 = conv(hidden_dim, out_channels, 1, 1, 0, bias=bias) 58 | fixup_init(conv1, num_layers) 59 | fixup_init(conv2, num_layers) 60 | fixup_init(conv3, num_layers) 61 | self.conv = nn.Sequential( 62 | # pw 63 | conv1, 64 | init_normalization(hidden_dim, norm_type), 65 | nn.ReLU(inplace=True), 66 | # dw 67 | conv2, 68 | init_normalization(hidden_dim, norm_type), 69 | nn.ReLU(inplace=True), 70 | # pw-linear 71 | conv3, 72 | init_normalization(out_channels, norm_type) 73 | ) 74 | if norm_type != "none": 75 | nn.init.constant_(self.conv[-1].weight, 0) 76 | 77 | def forward(self, x): 78 | if self.downsample: 79 | identity = self.downsample(x) 80 | else: 81 | identity = x 82 | if self.training and np.random.uniform() < self.drop_prob: 83 | return identity 84 | else: 85 | return identity + self.conv(x) 86 | 87 | 88 | class Residual(InvertedResidual): 89 | def __init__(self, *args, **kwargs): 90 | super().__init__(*args, **kwargs, groups=1) 91 | 92 | 93 | class ResnetCNN(nn.Module): 94 | def __init__(self, input_channels, 95 | depths=(16, 32, 64), 96 | strides=(3, 2, 2), 97 | blocks_per_group=3, 98 | norm_type="bn", 99 | resblock=InvertedResidual, 100 | expand_ratio=2,): 101 | super(ResnetCNN, self).__init__() 102 | self.depths = [input_channels] + depths 103 | self.resblock = resblock 104 | self.expand_ratio = expand_ratio 105 | self.blocks_per_group = blocks_per_group 106 | self.layers = [] 107 | self.norm_type = norm_type 108 | self.num_layers = self.blocks_per_group*len(depths) 109 | for i in range(len(depths)): 110 | self.layers.append(self._make_layer(self.depths[i], 111 | self.depths[i+1], 112 | strides[i], 113 | )) 114 | self.layers = nn.Sequential(*self.layers) 115 | self.train() 116 | 117 | def _make_layer(self, in_channels, depth, stride,): 118 | 119 | blocks = [self.resblock(in_channels, depth, 120 | expand_ratio=self.expand_ratio, 121 | stride=stride, 122 | norm_type=self.norm_type, 123 | num_layers=self.num_layers,)] 124 | 125 | for i in range(1, self.blocks_per_group): 126 | blocks.append(self.resblock(depth, depth, 127 | expand_ratio=self.expand_ratio, 128 | stride=1, 129 | norm_type=self.norm_type, 130 | num_layers=self.num_layers,)) 131 | 132 | return nn.Sequential(*blocks) 133 | 134 | @property 135 | def local_layer_depth(self): 136 | return self.depths[-2] 137 | 138 | def forward(self, inputs): 139 | return self.layers(inputs) 140 | 141 | 142 | class TransitionModel(nn.Module): 143 | def __init__(self, 144 | channels, 145 | num_actions, 146 | args=None, 147 | blocks=0, 148 | hidden_size=256, 149 | norm_type="bn", 150 | renormalize=True, 151 | resblock=InvertedResidual, 152 | expand_ratio=2, 153 | residual=False): 154 | super().__init__() 155 | self.hidden_size = hidden_size 156 | self.num_actions = num_actions 157 | self.args = args 158 | self.renormalize = renormalize 159 | 160 | self.residual = residual 161 | conv = nn.Conv2d 162 | self.initial_layer = nn.Sequential(conv(channels+num_actions, hidden_size, 3, 1, 1), 
163 | nn.ReLU(), init_normalization(hidden_size, norm_type)) 164 | self.final_layer = nn.Conv2d(hidden_size, channels, 3, 1, 1) 165 | resblocks = [] 166 | 167 | for i in range(blocks): 168 | resblocks.append(resblock(hidden_size, 169 | hidden_size, 170 | stride=1, 171 | norm_type=norm_type, 172 | expand_ratio=expand_ratio, 173 | num_layers=blocks)) 174 | self.resnet = nn.Sequential(*resblocks) 175 | if self.residual: 176 | nn.init.constant_(self.final_layer.weight, 0) 177 | self.train() 178 | 179 | def forward(self, x, action, blocks=True): 180 | batch_range = torch.arange(action.shape[0], device=action.device) 181 | action_onehot = torch.zeros(action.shape[0], 182 | self.num_actions, 183 | x.shape[-2], 184 | x.shape[-1], 185 | device=action.device) 186 | action_onehot[batch_range, action, :, :] = 1 187 | stacked_image = torch.cat([x, action_onehot], 1) 188 | next_state = self.initial_layer(stacked_image) 189 | if blocks: 190 | next_state = self.resnet(next_state) 191 | next_state = self.final_layer(next_state) 192 | if self.residual: 193 | next_state = next_state + x 194 | next_state = F.relu(next_state) 195 | next_state = self.renormalize(next_state) 196 | return next_state 197 | 198 | 199 | def init_normalization(channels, type="bn", affine=True, one_d=False): 200 | assert type in ["bn", "ln", "in", "gn", "max", "none", None] 201 | if type == "bn": 202 | if one_d: 203 | return nn.BatchNorm1d(channels, affine=affine) 204 | else: 205 | return nn.BatchNorm2d(channels, affine=affine) 206 | elif type == "ln": 207 | if one_d: 208 | return nn.LayerNorm(channels, elementwise_affine=affine) 209 | else: 210 | return nn.GroupNorm(1, channels, affine=affine) 211 | elif type == "in": 212 | return nn.GroupNorm(channels, channels, affine=affine) 213 | elif type == "gn": 214 | groups = max(min(32, channels//4), 1) 215 | return nn.GroupNorm(groups, channels, affine=affine) 216 | elif type == "max": 217 | if not one_d: 218 | return renormalize 219 | else: 220 | return lambda x: renormalize(x, -1) 221 | elif type == "none" or type is None: 222 | return nn.Identity() 223 | 224 | 225 | class NoisyLinear(nn.Module): 226 | def __init__(self, in_features, out_features, std_init=0.1, bias=True): 227 | super(NoisyLinear, self).__init__() 228 | self.bias = bias 229 | self.in_features = in_features 230 | self.out_features = out_features 231 | self.std_init = std_init 232 | self.sampling = True 233 | self.noise_override = None 234 | self.weight_mu = nn.Parameter(torch.empty(out_features, in_features)) 235 | self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features)) 236 | self.bias_mu = nn.Parameter(torch.empty(out_features), requires_grad=bias) 237 | self.bias_sigma = nn.Parameter(torch.empty(out_features), requires_grad=bias) 238 | self.register_buffer('bias_epsilon', torch.empty(out_features)) 239 | self.register_buffer('weight_epsilon', torch.empty(out_features, in_features)) 240 | self.register_buffer('old_bias_epsilon', torch.empty(out_features)) 241 | self.register_buffer('old_weight_epsilon', torch.empty(out_features, in_features)) 242 | self.reset_parameters() 243 | self.reset_noise() 244 | self.use_old_noise = False 245 | 246 | def reset_noise_parameters(self): 247 | self.weight_sigma.data.fill_(self.std_init / np.sqrt(self.in_features)) 248 | if self.bias: 249 | self.bias_sigma.data.fill_(self.std_init / np.sqrt(self.out_features)) 250 | else: 251 | self.bias_sigma.fill_(0) 252 | 253 | def reset_parameters(self): 254 | mu_range = 1 / np.sqrt(self.in_features) 255 | 
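        # Factorised-Gaussian NoisyNet initialization (Fortunato et al., 2018):
        # the deterministic weights are drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)),
        # reset_noise_parameters() sets the noise scales to std_init / sqrt(fan_in)
        # (std_init / sqrt(fan_out) for the bias), and _scale_noise() produces the
        # factorised sign(x) * sqrt(|x|) noise.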
self.weight_mu.data.uniform_(-mu_range, mu_range) 256 | if self.bias: 257 | self.bias_mu.data.uniform_(-mu_range, mu_range) 258 | else: 259 | self.bias_mu.fill_(0) 260 | 261 | self.reset_noise_parameters() 262 | 263 | def _scale_noise(self, size): 264 | x = torch.randn(size) 265 | return x.sign().mul_(x.abs().sqrt_()) 266 | 267 | def reset_noise(self): 268 | self.old_bias_epsilon.copy_(self.bias_epsilon) 269 | self.old_weight_epsilon.copy_(self.weight_epsilon) 270 | epsilon_in = self._scale_noise(self.in_features) 271 | epsilon_out = self._scale_noise(self.out_features) 272 | self.weight_epsilon.copy_(epsilon_out.ger(epsilon_in)) 273 | self.bias_epsilon.copy_(epsilon_out) 274 | 275 | def forward(self, input): 276 | # Self.training alone isn't a good-enough check, since we may need to 277 | # activate .eval() during sampling even when we want to use noise 278 | # (due to batchnorm, dropout, or similar). 279 | # The extra "sampling" flag serves to override this behavior and causes 280 | # noise to be used even when .eval() has been called. 281 | use_noise = (self.training or self.sampling) if self.noise_override is None else self.noise_override 282 | if use_noise: 283 | weight_eps = self.old_weight_epsilon if self.use_old_noise else self.weight_epsilon 284 | bias_eps = self.old_bias_epsilon if self.use_old_noise else self.bias_epsilon 285 | 286 | return F.linear(input, self.weight_mu + self.weight_sigma * weight_eps, 287 | self.bias_mu + self.bias_sigma * bias_eps) 288 | else: 289 | return F.linear(input, self.weight_mu, self.bias_mu) 290 | 291 | 292 | class Conv2dModel(torch.nn.Module): 293 | """2-D Convolutional model component, with option for max-pooling vs 294 | downsampling for strides > 1. Requires number of input channels, but 295 | not input shape. Uses ``torch.nn.Conv2d``. 296 | """ 297 | 298 | def __init__( 299 | self, 300 | in_channels, 301 | channels, 302 | kernel_sizes, 303 | strides, 304 | paddings=None, 305 | nonlinearity=torch.nn.ReLU, # Module, not Functional. 306 | use_maxpool=False, # if True: convs use stride 1, maxpool downsample. 307 | head_sizes=None, # Put an MLP head on top. 308 | dropout=0., 309 | norm_type="none", 310 | ): 311 | super().__init__() 312 | if paddings is None: 313 | paddings = [0 for _ in range(len(channels))] 314 | assert len(channels) == len(kernel_sizes) == len(strides) == len(paddings) 315 | in_channels = [in_channels] + channels[:-1] 316 | ones = [1 for _ in range(len(strides))] 317 | if use_maxpool: 318 | maxp_strides = strides 319 | strides = ones 320 | else: 321 | maxp_strides = ones 322 | conv_layers = [torch.nn.Conv2d(in_channels=ic, out_channels=oc, 323 | kernel_size=k, stride=s, padding=p) for (ic, oc, k, s, p) in 324 | zip(in_channels, channels, kernel_sizes, strides, paddings)] 325 | sequence = list() 326 | for conv_layer, maxp_stride, oc in zip(conv_layers, maxp_strides, channels): 327 | sequence.extend([conv_layer, init_normalization(oc, norm_type), nonlinearity()]) 328 | if dropout > 0: 329 | sequence.append(nn.Dropout(dropout)) 330 | if maxp_stride > 1: 331 | sequence.append(torch.nn.MaxPool2d(maxp_stride)) # No padding. 
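        # When use_maxpool=True the convolutions all run at stride 1 and each stage is
        # instead downsampled by a MaxPool2d whose kernel and stride equal the requested
        # stride; otherwise maxp_stride is 1 and the convolutions themselves do the striding.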
332 | self.conv = torch.nn.Sequential(*sequence) 333 | 334 | def forward(self, input): 335 | """Computes the convolution stack on the input; assumes correct shape 336 | already: [B,C,H,W].""" 337 | return self.conv(input) 338 | 339 | 340 | class DQNDistributionalDuelingHeadModel(torch.nn.Module): 341 | """An MLP head with optional noisy layers which reshapes output to [B, output_size, n_atoms].""" 342 | 343 | def __init__(self, 344 | input_channels, 345 | output_size, 346 | pixels=30, 347 | n_atoms=51, 348 | hidden_size=256, 349 | grad_scale=2 ** (-1 / 2), 350 | noisy=0, 351 | std_init=0.1): 352 | super().__init__() 353 | if noisy: 354 | self.linears = [NoisyLinear(pixels * input_channels, hidden_size, std_init=std_init), 355 | NoisyLinear(hidden_size, output_size * n_atoms, std_init=std_init), 356 | NoisyLinear(pixels * input_channels, hidden_size, std_init=std_init), 357 | NoisyLinear(hidden_size, n_atoms, std_init=std_init) 358 | ] 359 | else: 360 | self.linears = [nn.Linear(pixels * input_channels, hidden_size), 361 | nn.Linear(hidden_size, output_size * n_atoms), 362 | nn.Linear(pixels * input_channels, hidden_size), 363 | nn.Linear(hidden_size, n_atoms) 364 | ] 365 | self.advantage_layers = [nn.Flatten(-3, -1), 366 | self.linears[0], 367 | nn.ReLU(), 368 | self.linears[1]] 369 | self.value_layers = [nn.Flatten(-3, -1), 370 | self.linears[2], 371 | nn.ReLU(), 372 | self.linears[3]] 373 | self.advantage_net = nn.Sequential(*self.advantage_layers) 374 | self.advantage_bias = torch.nn.Parameter(torch.zeros(n_atoms), requires_grad=True) 375 | self.value_net = nn.Sequential(*self.value_layers) 376 | self._grad_scale = grad_scale 377 | self._output_size = output_size 378 | self._n_atoms = n_atoms 379 | 380 | def forward(self, input, old_noise=False): 381 | [setattr(module, "use_old_noise", old_noise) for module in self.modules()] 382 | x = scale_grad(input, self._grad_scale) 383 | advantage = self.advantage(x) 384 | value = self.value_net(x).view(-1, 1, self._n_atoms) 385 | return value + (advantage - advantage.mean(dim=1, keepdim=True)) 386 | 387 | def advantage(self, input): 388 | x = self.advantage_net(input) 389 | x = x.view(-1, self._output_size, self._n_atoms) 390 | return x + self.advantage_bias 391 | 392 | def reset_noise(self): 393 | for module in self.linears: 394 | module.reset_noise() 395 | 396 | def set_sampling(self, sampling): 397 | for module in self.linears: 398 | module.sampling = sampling 399 | 400 | 401 | class GoalConditioning(nn.Module): 402 | def __init__(self, 403 | pixels=49, 404 | feature_dim=64, 405 | dqn_hidden_size=256, 406 | conv=True, 407 | film=True, 408 | goal_only_conditioning=False, 409 | n_heads=2): 410 | """ 411 | The basic idea: cat the online and goal states as feature maps, 412 | and then run a two-layer CNN on it. We then run this through a flatten 413 | and an MLP to get FiLM weights, which we use in the DQN head. 
414 | """ 415 | super().__init__() 416 | output_size = n_heads * dqn_hidden_size 417 | output_size = output_size * 2 if film else output_size 418 | input_dim = feature_dim if goal_only_conditioning else feature_dim * 2 419 | self.film = film 420 | self.goal_only_conditioning = goal_only_conditioning 421 | self.n_heads = n_heads 422 | 423 | self.conv = conv 424 | if conv: 425 | self.network = nn.Sequential( 426 | nn.Conv2d(input_dim, feature_dim * 2, kernel_size=3, padding=1), 427 | nn.ReLU(), 428 | nn.Conv2d(feature_dim * 2, feature_dim * 2, kernel_size=3, padding=1), 429 | nn.ReLU(), 430 | nn.Flatten(-3, -1), 431 | nn.Linear(pixels * feature_dim * 2, output_size) 432 | ) 433 | else: 434 | self.network = nn.Sequential( 435 | nn.Linear(dqn_hidden_size*2, dqn_hidden_size*2), 436 | nn.ReLU(), 437 | nn.Linear(dqn_hidden_size*2, dqn_hidden_size*2), 438 | nn.ReLU(), 439 | nn.Linear(dqn_hidden_size*2, output_size) 440 | ) 441 | 442 | def forward(self, states, goals): 443 | if self.conv: 444 | goals = goals.view(*states.shape) 445 | 446 | if not self.goal_only_conditioning: 447 | if self.conv: 448 | goals = F.normalize(goals, dim=(-1, -2, -3), p=2, eps=1e-5) 449 | states = F.normalize(states, dim=(-1, -2, -3), p=2, eps=1e-5) 450 | else: 451 | goals = F.normalize(goals, dim=(-1), p=2, eps=1e-5) 452 | states = F.normalize(states, dim=(-1), p=2, eps=1e-5) 453 | input = torch.cat([states, goals], -3) 454 | else: 455 | input = goals 456 | 457 | output = self.network(input) 458 | 459 | if self.film: 460 | # Split the output into heads and biases/scales 461 | output = output.view(*output.shape[:-1], self.n_heads, 2, -1) 462 | else: 463 | output = output.view(*output.shape[:-1], self.n_heads, -1) 464 | return output 465 | 466 | 467 | class GoalConditionedDuelingHead(torch.nn.Module): 468 | """An MLP head with optional noisy layers which reshapes output to [B, output_size, n_atoms].""" 469 | 470 | """ 471 | For goal conditioning, we have a few options. 472 | First, we can concatenate the goal after the first linear. This is the simplest 473 | solution, but would almost certainly require at least one more linear before 474 | the output layer, for practical reasons of allowing diversity. 475 | 476 | Alternatively, we could use FiLM or similar. We could compute FiLM weights 477 | with one or two layers from the concatenation of goal and state, and then 478 | apply these to the state before using the final linear. 479 | Could also do a residual connection: state/goal -> two layers -> state delta. 480 | 481 | Of course, we could define goals as being in the convolutional feature map 482 | latent space, but that probably sucks (and has very variable size between architectures). 
483 | """ 484 | 485 | def __init__(self, 486 | input_channels, 487 | output_size, 488 | pixels=30, 489 | n_atoms=51, 490 | hidden_size=512, 491 | grad_scale=2 ** (-1 / 2), 492 | noisy=0, 493 | std_init=0.1, 494 | ln_for_dqn=True, 495 | conv_goals=True, 496 | conditioning_type=["goal_only", "film"], 497 | share_l1=False, 498 | goal_all_to_all=False): 499 | super().__init__() 500 | 501 | self.goal_conditioner = GoalConditioning( 502 | pixels=pixels, 503 | feature_dim=input_channels, 504 | dqn_hidden_size=hidden_size, 505 | goal_only_conditioning="goal_only" in conditioning_type, 506 | film="film" in conditioning_type, 507 | n_heads=2, 508 | conv=conv_goals, 509 | ) 510 | self.conditioning_style = "film" if "film" in conditioning_type else \ 511 | "sum" if "sum" in conditioning_type else "product" 512 | 513 | self.goal_all_to_all = goal_all_to_all 514 | 515 | if noisy: 516 | self.goal_linears = [NoisyLinear(pixels * input_channels, hidden_size, std_init=std_init), 517 | NoisyLinear(hidden_size, output_size * n_atoms, std_init=std_init), 518 | NoisyLinear(pixels * input_channels, hidden_size, std_init=std_init), 519 | NoisyLinear(hidden_size, n_atoms, std_init=std_init), 520 | ] 521 | self.rl_linears = [NoisyLinear(pixels * input_channels, hidden_size, std_init=std_init), 522 | NoisyLinear(hidden_size, output_size * n_atoms, std_init=std_init), 523 | NoisyLinear(pixels * input_channels, hidden_size, std_init=std_init), 524 | NoisyLinear(hidden_size, n_atoms, std_init=std_init), 525 | ] 526 | else: 527 | self.goal_linears = [nn.Linear(pixels * input_channels, hidden_size), 528 | nn.Linear(hidden_size, output_size * n_atoms), 529 | nn.Linear(pixels * input_channels, hidden_size), 530 | nn.Linear(hidden_size, n_atoms), 531 | ] 532 | self.rl_linears = [nn.Linear(pixels * input_channels, hidden_size), 533 | nn.Linear(hidden_size, output_size * n_atoms), 534 | nn.Linear(pixels * input_channels, hidden_size), 535 | nn.Linear(hidden_size, n_atoms), 536 | ] 537 | 538 | if share_l1: 539 | self.rl_linears[0] = self.goal_linears[0] 540 | self.rl_linears[2] = self.goal_linears[2] 541 | 542 | self.goal_advantage_layers = [self.goal_linears[0], 543 | nn.ReLU(), 544 | nn.LayerNorm(hidden_size, elementwise_affine=False) 545 | if ln_for_dqn else nn.Identity(), 546 | self.goal_linears[1]] 547 | self.goal_value_layers = [self.goal_linears[2], 548 | nn.ReLU(), 549 | nn.LayerNorm(hidden_size, elementwise_affine=False) 550 | if ln_for_dqn else nn.Identity(), 551 | self.goal_linears[3]] 552 | self.advantage_bias = torch.nn.Parameter(torch.zeros(n_atoms), requires_grad=True) 553 | self.rl_advantage_bias = torch.nn.Parameter(torch.zeros(n_atoms), requires_grad=True) 554 | self.goal_value = nn.Sequential(*self.goal_value_layers) 555 | self.goal_advantage = nn.Sequential(*self.goal_advantage_layers) 556 | self._grad_scale = grad_scale 557 | self._output_size = output_size 558 | self._n_atoms = n_atoms 559 | self.noisy = noisy 560 | 561 | self.rl_advantage = nn.Sequential( 562 | self.rl_linears[0], 563 | nn.ReLU(), 564 | nn.LayerNorm(hidden_size) if ln_for_dqn else nn.Identity(), 565 | self.rl_linears[1], 566 | ) 567 | self.rl_value = nn.Sequential( 568 | self.rl_linears[2], 569 | nn.ReLU(), 570 | nn.LayerNorm(hidden_size) if ln_for_dqn else nn.Identity(), 571 | self.rl_linears[3], 572 | ) 573 | 574 | def forward(self, input, goal): 575 | if goal is None: 576 | return self.regular_forward(input) 577 | 578 | x = scale_grad(input, self._grad_scale) 579 | x = x.flatten(-3, -1) 580 | advantage_hidden = 
self.goal_advantage[0:3](x) 581 | value_hidden = self.goal_value[0:3](x) 582 | 583 | goal_conditioning = self.goal_conditioner(input, goal) 584 | 585 | if self.goal_all_to_all: 586 | goal_conditioning = goal_conditioning.unsqueeze(0) 587 | advantage_hidden = advantage_hidden.unsqueeze(1) 588 | value_hidden = value_hidden.unsqueeze(1) 589 | 590 | if self.conditioning_style == "film": 591 | advantage_biases = goal_conditioning[..., 0, 0, :] 592 | advantage_scales = goal_conditioning[..., 0, 1, :] 593 | value_biases = goal_conditioning[..., 1, 0, :] 594 | value_scales = goal_conditioning[..., 1, 1, :] 595 | advantage_hidden = advantage_hidden * advantage_scales + advantage_biases 596 | value_hidden = value_hidden * value_scales + value_biases 597 | elif self.conditioning_style == "product": 598 | advantage_hidden = advantage_hidden * goal_conditioning[..., 0] 599 | value_hidden = value_hidden * goal_conditioning[..., 1] 600 | elif self.conditioning_style == "sum": 601 | advantage_hidden = advantage_hidden + goal_conditioning[..., 0] 602 | value_hidden = value_hidden + goal_conditioning[..., 1] 603 | 604 | if self.goal_all_to_all: 605 | advantage_hidden = advantage_hidden.flatten(0, 1) 606 | value_hidden = value_hidden.flatten(0, 1) 607 | 608 | advantage = self.goal_advantage[-2:](advantage_hidden) 609 | advantage = advantage.view(-1, self._output_size, self._n_atoms) + self.advantage_bias 610 | value = self.goal_value[-2:](value_hidden).view(-1, 1, self._n_atoms) 611 | return value + (advantage - advantage.mean(dim=-2, keepdim=True)) 612 | 613 | def regular_forward(self, input): 614 | x = scale_grad(input, self._grad_scale) 615 | x = x.flatten(-3, -1) 616 | advantage = self.rl_advantage(x) 617 | advantage = advantage.view(-1, self._output_size, self._n_atoms) + self.rl_advantage_bias 618 | value = self.rl_value(x).view(-1, 1, self._n_atoms) 619 | return value + (advantage - advantage.mean(dim=-2, keepdim=True)) 620 | 621 | def copy_base_params(self, up_to=1): 622 | if up_to == 0: 623 | return 624 | self.rl_value[0:up_to].load_state_dict(self.goal_value[0:up_to].state_dict()) 625 | self.rl_advantage[0:up_to].load_state_dict(self.goal_advantage[0:up_to].state_dict()) 626 | 627 | def reset_noise(self): 628 | for module in self.goal_linears: 629 | module.reset_noise() 630 | for module in self.rl_linears: 631 | module.reset_noise() 632 | 633 | def reset_noise_params(self): 634 | for module in self.goal_linears: 635 | module.reset_noise_parameters() 636 | for module in self.rl_linears: 637 | module.reset_noise_parameters() 638 | 639 | def set_sampling(self, sampling): 640 | for module in self.goal_linears: 641 | module.sampling = sampling 642 | for module in self.rl_linears: 643 | module.sampling = sampling 644 | 645 | 646 | class TransposedBN1D(nn.BatchNorm1d): 647 | def forward(self, x): 648 | x_flat = x.view(-1, x.shape[-1]) 649 | if self.training and x_flat.shape[0] == 1: 650 | return x 651 | x_flat = super().forward(x_flat) 652 | return x_flat.view(*x.shape) 653 | 654 | 655 | class InverseModelHead(nn.Module): 656 | def __init__(self, 657 | input_channels, 658 | num_actions=18,): 659 | super().__init__() 660 | layers = [nn.Linear(input_channels*2, 256), 661 | nn.ReLU(), 662 | nn.Linear(256, num_actions)] 663 | self.network = nn.Sequential(*layers) 664 | self.train() 665 | 666 | def forward(self, x): 667 | return self.network(x) 668 | 669 | --------------------------------------------------------------------------------
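The modules above are easiest to sanity-check with small usage sketches. NoisyLinear keeps two sets of factorized-Gaussian noise buffers and exposes a sampling flag plus an optional noise_override, as its forward() comment explains. The sketch below shows how those switches interact; the import path (src.networks) and the sizes are assumptions for illustration only.

# Usage sketch for NoisyLinear (assumed import path; sizes are illustrative).
import torch
from src.networks import NoisyLinear

layer = NoisyLinear(in_features=64, out_features=18, std_init=0.1)
x = torch.randn(32, 64)

# sampling defaults to True, so noise is still applied after .eval(),
# exactly as the comment in forward() describes.
layer.eval()
noisy_out = layer(x)

# Fully deterministic weights require the explicit override.
layer.noise_override = False
clean_out = layer(x)
layer.noise_override = None       # back to the default behaviour

# reset_noise() stashes the current sample in the old_*_epsilon buffers
# before drawing a fresh one, so the previous sample can be replayed.
layer.reset_noise()
layer.use_old_noise = True
replayed_out = layer(x)           # uses the same noise sample as noisy_out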
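Conv2dModel builds its stack from parallel lists of channels, kernel sizes, strides, and paddings, optionally replacing strided convolutions with max-pooling. A small instantiation sketch follows, using the classic 84x84 Atari encoder shape as an illustrative configuration (not one read from the repo's config files):

# Sketch of a Conv2dModel encoder (configuration values are illustrative).
import torch
from src.networks import Conv2dModel

encoder = Conv2dModel(
    in_channels=4,                  # e.g. a stack of 4 grayscale frames
    channels=[32, 64, 64],
    kernel_sizes=[8, 4, 3],
    strides=[4, 2, 1],
    paddings=[0, 0, 0],
    norm_type="bn",                 # resolved through init_normalization
)
frames = torch.randn(8, 4, 84, 84)  # forward() assumes [B, C, H, W]
features = encoder(frames)
print(features.shape)               # torch.Size([8, 64, 7, 7])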
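DQNDistributionalDuelingHeadModel flattens a spatial feature map, runs separate advantage and value MLPs, and combines them into per-action atom logits of shape [B, output_size, n_atoms]. The sketch below checks that shape and shows one common way to collapse the atoms into Q-values; the atom support range is an assumption for illustration, not a value taken from the repo's config.

# Shape sketch for the distributional dueling head (sizes are illustrative).
import torch
import torch.nn.functional as F
from src.networks import DQNDistributionalDuelingHeadModel

head = DQNDistributionalDuelingHeadModel(
    input_channels=64, output_size=6, pixels=49, n_atoms=51, noisy=0)
spatial_features = torch.randn(8, 64, 7, 7)   # 7 * 7 = 49 pixels
logits = head(spatial_features)               # [B, n_actions, n_atoms]
print(logits.shape)                           # torch.Size([8, 6, 51])

# Collapsing atoms into Q-values, C51-style (support range assumed):
support = torch.linspace(-10.0, 10.0, 51)
q_values = (F.softmax(logits, dim=-1) * support).sum(-1)   # [8, 6]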
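For the goal-conditioned head, GoalConditioning emits one (bias, scale) pair per dueling stream ([B, n_heads=2, 2, hidden] when FiLM is enabled), and GoalConditionedDuelingHead applies them as hidden * scale + bias before the output layers, falling back to the plain RL streams when goal is None. A sketch under assumed sizes, passing the goal as a second feature map of the same shape as the state:

# Sketch of FiLM-style goal conditioning in the dueling head
# (sizes are illustrative; the goal is just another feature map here).
import torch
from src.networks import GoalConditionedDuelingHead

head = GoalConditionedDuelingHead(
    input_channels=64, output_size=6, pixels=49, hidden_size=512,
    noisy=0, conv_goals=True, conditioning_type=["goal_only", "film"])
states = torch.randn(8, 64, 7, 7)
goals = torch.randn(8, 64, 7, 7)      # reshaped to states' shape internally

goal_logits = head(states, goals)     # goal-conditioned streams -> [8, 6, 51]
rl_logits = head(states, None)        # regular_forward()        -> [8, 6, 51]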