├── common ├── __init__.py ├── pvm_buffer.py └── utils.py ├── scripts ├── dmc_series.sh ├── atari_series.sh ├── atari_series_5m.sh ├── atari_wp_series.sh ├── atari_100k_base.sh ├── robosuite_series.sh ├── dmc_test.sh ├── robosuite_test.sh ├── atari_test.sh ├── atari_100k_ns.sh ├── dmc_20.sh ├── dmc_30.sh ├── dmc_50.sh ├── atari_100k_20x20.sh ├── atari_100k_30x30.sh ├── atari_100k_50x50.sh ├── atari_100k_wp20_20x20.sh ├── atari_100k_wp20_30x30.sh ├── atari_100k_wp20_50x50.sh ├── atari_100k_5m.sh ├── atari_100k_wp20_peripheral_only.sh ├── atari_100k_20x20_5m.sh ├── atari_100k_30x30_5m.sh └── atari_100k_50x50_5m.sh ├── _doc └── media │ └── sugarl_formulation.png ├── .gitignore ├── active_rl_env.yaml ├── README.md └── agent ├── dqn_atari_base.py ├── dqn_atari_wp_random.py ├── dqn_atari_wp_base_peripheral.py ├── dqn_atari_wp_single_policy.py ├── dqn_atari_single_policy.py ├── sac_atari_base.py ├── drqv2_dmc_base.py ├── dqn_atari_sugarl.py └── dqn_atari_wp_sugarl.py /common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/dmc_series.sh: -------------------------------------------------------------------------------- 1 | ./scripts/dmc_20.sh $1 ; 2 | ./scripts/dmc_30.sh $1 ; 3 | ./scripts/dmc_50.sh $1 ; -------------------------------------------------------------------------------- /_doc/media/sugarl_formulation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elicassion/sugarl/HEAD/_doc/media/sugarl_formulation.png -------------------------------------------------------------------------------- /scripts/atari_series.sh: -------------------------------------------------------------------------------- 1 | ./scripts/atari_100k_20x20.sh $1 ; 2 | ./scripts/atari_100k_30x30.sh $1 ; 3 | ./scripts/atari_100k_50x50.sh $1 ; -------------------------------------------------------------------------------- /scripts/atari_series_5m.sh: -------------------------------------------------------------------------------- 1 | ./scripts/atari_100k_20x20_5m.sh $1 ; 2 | ./scripts/atari_100k_30x30_5m.sh $1 ; 3 | ./scripts/atari_100k_50x50_5m.sh $1 ; -------------------------------------------------------------------------------- /scripts/atari_wp_series.sh: -------------------------------------------------------------------------------- 1 | ./scripts/atari_100k_wp20_20x20.sh $1 ; 2 | ./scripts/atari_100k_wp20_30x30.sh $1 ; 3 | ./scripts/atari_100k_wp20_50x50.sh $1 ; 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | */__pycache__ 2 | *.pyc 3 | *.txt 4 | runs 5 | runs/* 6 | logs 7 | logs/* 8 | recordings 9 | recordings/* 10 | trained_models 11 | trained_models/* 12 | */.ipynb_checkpoints 13 | *.egg-info 14 | -------------------------------------------------------------------------------- /active_rl_env.yaml: -------------------------------------------------------------------------------- 1 | name: arl 2 | 3 | channels: 4 | - defaults 5 | - pytorch 6 | - nvidia 7 | 8 | dependencies: 9 | - python=3.9 10 | - pip 11 | 12 | # pytorch 13 | - pytorch==1.13.1 14 | - torchvision 15 | - torchaudio 16 | - pytorch-cuda=11.6 17 | 18 | - pip: 19 | - requests 20 | - joblib 21 | - psutil 22 | - h5py 23 | - lxml 24 | - colorama 25 | 26 | - jupyter # Used to show the notebook 27 | - jupyterlab 28 | 29 | 
- scipy 30 | - matplotlib # Used for vis 31 | - tqdm 32 | 33 | - mujoco<3.0 34 | - mujoco-py>=2.1.2.14 35 | - atari-py 36 | - dm-control==1.0.11 37 | - gymnasium>=0.28.1,<1.0.0 38 | 39 | - numpy<2.0.0 40 | - pandas 41 | - einops 42 | - opencv-python 43 | 44 | - tensorboard 45 | - tqdm 46 | 47 | - pybullet 48 | - ptflops # get model computation info 49 | 50 | - git+https://github.com/elicassion/active-gym.git 51 | -------------------------------------------------------------------------------- /scripts/atari_100k_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # atari 100k settings 5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon 6 | 7 | expname="atari_100k_base" 8 | totaltimesteps="1000000" 9 | buffersize="100000" 10 | learningstarts="5000" 11 | 12 | mkdir -p logs/${expname} 13 | mkdir -p recordings/${expname} 14 | mkdir -p trained_models/${expname} 15 | 16 | for i in ${!envlist[@]} 17 | do 18 | gpuid=$(( $i % $numgpus )) 19 | ( 20 | for seed in 0 1 2 3 4 21 | do 22 | echo "Atari100k GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}" 23 | # sleep 5 24 | basename=$(basename $1) 25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \ 27 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 28 | --learning-starts $learningstarts \ 29 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 30 | done 31 | ) & 32 | done 33 | wait -------------------------------------------------------------------------------- /scripts/robosuite_series.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | # numgpus=6 4 | gpustart=0 5 | 6 | # robosuite 1m settings 7 | tasklist=(Wipe Stack NutAssemblySquare Door Lift) 8 | 9 | expname="robosuite" 10 | totaltimesteps="100000" 11 | buffersize="100000" 12 | learningstarts="2000" 13 | 14 | mkdir -p logs/${expname} 15 | mkdir -p recordings/${expname} 16 | mkdir -p trained_models/${expname} 17 | 18 | jobnum=0+$gpustart 19 | 20 | for i in ${!tasklist[@]} 21 | do 22 | for seed in 0 1 2 23 | do 24 | gpuid=$(( $jobnum % $numgpus )) 25 | echo "${expname} GPU: ${gpuid} Env: ${tasklist[$i]} Seed: ${seed} ${1}" 26 | # sleep 5 27 | basename=$(basename $1) 28 | echo "========" >> logs/${expname}/${tasklist[$i]}__${basename}__${seed}.txt 29 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --task-name ${tasklist[$i]} \ 30 | --seed $seed --exp-name $expname \ 31 | --capture-video \ 32 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 33 | --learning-starts $learningstarts \ 34 | --eval-num 10 \ 35 | ${@:3} >> logs/${expname}/${tasklist[$i]}__${basename}__${seed}.txt & 36 | ((jobnum=jobnum+1)) 37 | done 38 | done 39 | wait 40 | -------------------------------------------------------------------------------- /scripts/dmc_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # dmc 1m settings 5 | domainlist=(reacher) # reacher walker humanoid cheetah dog) 6 | 
tasklist=(easy) # hard walk walk run fetch) 7 | 8 | expname="dmc_test" 9 | totaltimesteps="1000" 10 | buffersize="500" 11 | learningstarts="300" 12 | 13 | mkdir -p logs/${expname} 14 | mkdir -p recordings/${expname} 15 | mkdir -p trained_models/${expname} 16 | 17 | for i in ${!domainlist[@]} 18 | do 19 | gpuid=$(( $i % $numgpus )) 20 | ( 21 | for seed in 0 22 | do 23 | echo "${expname} GPU: ${gpuid} Env: ${domainlist[$i]}-${tasklist[$i]} Seed: ${seed} ${1}" 24 | # sleep 5 25 | basename=$(basename $1) 26 | echo "========" #>> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt 27 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --domain-name ${domainlist[$i]} --task-name ${tasklist[$i]} \ 28 | --seed $seed --exp-name $expname \ 29 | --fov-size 20 \ 30 | --capture-video \ 31 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 32 | --learning-starts $learningstarts \ 33 | ${@:3} #>> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt 34 | done 35 | ) & 36 | done 37 | wait -------------------------------------------------------------------------------- /scripts/robosuite_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cudagpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | # gpulist=($(seq 0 1 $cudagpus)) 4 | gpulist=(3) 5 | 6 | numgpus=${#gpulist[@]} 7 | # robosuite 1m settings 8 | tasklist=(Door) 9 | 10 | expname="robosuite_test" 11 | totaltimesteps="1000" 12 | buffersize="500" 13 | learningstarts="300" 14 | 15 | mkdir -p logs/${expname} 16 | mkdir -p recordings/${expname} 17 | mkdir -p trained_models/${expname} 18 | 19 | for i in ${!tasklist[@]} 20 | do 21 | gpuindex=$(( $i % $numgpus )) 22 | gpuid=${gpulist[$gpuindex]} 23 | ( 24 | for seed in 0 25 | do 26 | echo "${expname} GPU: ${gpuid} Env: ${tasklist[$i]} Seed: ${seed} ${1}" 27 | # sleep 5 28 | basename=$(basename $1) 29 | echo "========" # >> logs/${expname}/${tasklist[$i]}__${basename}__${seed}.txt 30 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --task-name ${tasklist[$i]} \ 31 | --seed $seed --exp-name $expname \ 32 | --capture-video \ 33 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 34 | --learning-starts $learningstarts \ 35 | --eval-num 1 \ 36 | ${@:3} # >> logs/${expname}/${tasklist[$i]}__${basename}__${seed}.txt 37 | done 38 | ) & 39 | done 40 | wait -------------------------------------------------------------------------------- /scripts/atari_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=1 3 | 4 | # atari 100k settings 5 | envlist=(alien amidar) #assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down #pong qbert seaquest zaxxon 6 | 7 | expname="test" 8 | totaltimesteps="1000" 9 | buffersize="500" 10 | learningstarts="300" 11 | 12 | mkdir -p logs/${expname} 13 | mkdir -p recordings/${expname} 14 | mkdir -p trained_models/${expname} 15 | 16 | for i in ${!envlist[@]} 17 | do 18 | gpuid=$(( $i % $numgpus )) 19 | ( 20 | for seed in 0 21 | do 22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}" 23 | # sleep 5 24 | basename=$(basename $1) 25 | # echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \ 27 | --fov-size 20 \ 28 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 29
| --learning-starts $learningstarts \ 30 | --capture-video \ 31 | ${@:2} #>> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 32 | done 33 | ) & 34 | done 35 | wait -------------------------------------------------------------------------------- /scripts/atari_100k_ns.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # atari 100k settings 5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon 6 | 7 | expname="atari_100k" 8 | totaltimesteps="1000000" 9 | buffersize="100000" 10 | learningstarts="80000" 11 | 12 | mkdir -p logs/${expname} 13 | mkdir -p recordings/${expname} 14 | mkdir -p trained_models/${expname} 15 | 16 | for i in ${!envlist[@]} 17 | do 18 | gpuid=$(( $i % $numgpus )) 19 | ( 20 | for seed in 0 1 2 3 4 21 | do 22 | echo "Atari100k GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}" 23 | # sleep 5 24 | basename=$(basename $1) 25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \ 27 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 28 | --capture-video \ 29 | --clip-reward \ 30 | --learning-starts $learningstarts \ 31 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 32 | done 33 | ) & 34 | done 35 | wait -------------------------------------------------------------------------------- /scripts/dmc_20.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # dmc 1m settings 5 | domainlist=(ball_in_cup cartpole cheetah dog fish walker) 6 | tasklist=(catch swingup run fetch swim walk) 7 | 8 | expname="dmc_20" 9 | totaltimesteps="100000" 10 | buffersize="100000" 11 | learningstarts="2000" 12 | 13 | mkdir -p logs/${expname} 14 | mkdir -p recordings/${expname} 15 | mkdir -p trained_models/${expname} 16 | 17 | jobnum=0 18 | 19 | for i in ${!domainlist[@]} 20 | do 21 | for seed in 0 1 2 3 4 22 | do 23 | gpuid=$(( $jobnum % $numgpus )) 24 | echo "${expname} GPU: ${gpuid} Env: ${domainlist[$i]}-${tasklist[$i]} Seed: ${seed} ${1}" 25 | # sleep 5 26 | basename=$(basename $1) 27 | echo "========" >> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt 28 | MUJOCO_EGL_DEVICE_ID=$gpuid CUDA_VISIBLE_DEVICES=$gpuid python $1 --domain-name ${domainlist[$i]} --task-name ${tasklist[$i]} \ 29 | --seed $seed --exp-name $expname \ 30 | --fov-size 20 \ 31 | --capture-video \ 32 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 33 | --learning-starts $learningstarts \ 34 | --eval-frequency 10000 \ 35 | ${@:3} >> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt & 36 | ((jobnum=jobnum+1)) 37 | done 38 | done 39 | wait -------------------------------------------------------------------------------- /scripts/dmc_30.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # dmc 1m settings 5 | domainlist=(ball_in_cup cartpole cheetah dog fish walker) 6 | tasklist=(catch swingup run fetch swim walk) 7 | 8 | expname="dmc_30" 9 | 
totaltimesteps="100000" 10 | buffersize="100000" 11 | learningstarts="2000" 12 | 13 | mkdir -p logs/${expname} 14 | mkdir -p recordings/${expname} 15 | mkdir -p trained_models/${expname} 16 | 17 | jobnum=0 18 | 19 | for i in ${!domainlist[@]} 20 | do 21 | for seed in 0 1 2 3 4 22 | do 23 | gpuid=$(( $jobnum % $numgpus )) 24 | echo "${expname} GPU: ${gpuid} Env: ${domainlist[$i]}-${tasklist[$i]} Seed: ${seed} ${1}" 25 | # sleep 5 26 | basename=$(basename $1) 27 | echo "========" >> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt 28 | MUJOCO_EGL_DEVICE_ID=$gpuid CUDA_VISIBLE_DEVICES=$gpuid python $1 --domain-name ${domainlist[$i]} --task-name ${tasklist[$i]} \ 29 | --seed $seed --exp-name $expname \ 30 | --fov-size 30 \ 31 | --capture-video \ 32 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 33 | --learning-starts $learningstarts \ 34 | --eval-frequency 10000 \ 35 | ${@:3} >> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt & 36 | ((jobnum=jobnum+1)) 37 | done 38 | done 39 | wait -------------------------------------------------------------------------------- /scripts/dmc_50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # dmc 1m settings 5 | domainlist=(ball_in_cup cartpole cheetah dog fish walker) 6 | tasklist=(catch swingup run fetch swim walk) 7 | 8 | expname="dmc_50" 9 | totaltimesteps="100000" 10 | buffersize="100000" 11 | learningstarts="2000" 12 | 13 | mkdir -p logs/${expname} 14 | mkdir -p recordings/${expname} 15 | mkdir -p trained_models/${expname} 16 | 17 | jobnum=0 18 | 19 | for i in ${!domainlist[@]} 20 | do 21 | for seed in 0 1 2 3 4 22 | do 23 | gpuid=$(( $jobnum % $numgpus )) 24 | echo "${expname} GPU: ${gpuid} Env: ${domainlist[$i]}-${tasklist[$i]} Seed: ${seed} ${1}" 25 | # sleep 5 26 | basename=$(basename $1) 27 | echo "========" >> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt 28 | MUJOCO_EGL_DEVICE_ID=$gpuid CUDA_VISIBLE_DEVICES=$gpuid python $1 --domain-name ${domainlist[$i]} --task-name ${tasklist[$i]} \ 29 | --seed $seed --exp-name $expname \ 30 | --fov-size 50 \ 31 | --capture-video \ 32 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 33 | --learning-starts $learningstarts \ 34 | --eval-frequency 10000 \ 35 | ${@:3} >> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt & 36 | ((jobnum=jobnum+1)) 37 | done 38 | done 39 | wait -------------------------------------------------------------------------------- /scripts/atari_100k_20x20.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # atari 100k settings 5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon 6 | 7 | expname="atari_100k_20x20" 8 | totaltimesteps="1000000" 9 | buffersize="100000" 10 | learningstarts="80000" 11 | 12 | mkdir -p logs/${expname} 13 | mkdir -p recordings/${expname} 14 | mkdir -p trained_models/${expname} 15 | 16 | for i in ${!envlist[@]} 17 | do 18 | gpuid=$(( $i % $numgpus )) 19 | ( 20 | for seed in 0 1 2 3 4 21 | do 22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}" 23 | # sleep 
5 24 | basename=$(basename $1) 25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \ 27 | --fov-size 20 \ 28 | --clip-reward \ 29 | --capture-video \ 30 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 31 | --learning-starts $learningstarts \ 32 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 33 | done 34 | ) & 35 | done 36 | wait -------------------------------------------------------------------------------- /scripts/atari_100k_30x30.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # atari 100k settings 5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon 6 | 7 | expname="atari_100k_30x30" 8 | totaltimesteps="1000000" 9 | buffersize="100000" 10 | learningstarts="80000" 11 | 12 | mkdir -p logs/${expname} 13 | mkdir -p recordings/${expname} 14 | mkdir -p trained_models/${expname} 15 | 16 | for i in ${!envlist[@]} 17 | do 18 | gpuid=$(( $i % $numgpus )) 19 | ( 20 | for seed in 0 1 2 3 4 21 | do 22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}" 23 | # sleep 5 24 | basename=$(basename $1) 25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \ 27 | --fov-size 30 \ 28 | --clip-reward \ 29 | --capture-video \ 30 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 31 | --learning-starts $learningstarts \ 32 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 33 | done 34 | ) & 35 | done 36 | wait -------------------------------------------------------------------------------- /scripts/atari_100k_50x50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # atari 100k settings 5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon 6 | 7 | expname="atari_100k_50x50" 8 | totaltimesteps="1000000" 9 | buffersize="100000" 10 | learningstarts="80000" 11 | 12 | mkdir -p logs/${expname} 13 | mkdir -p recordings/${expname} 14 | mkdir -p trained_models/${expname} 15 | 16 | for i in ${!envlist[@]} 17 | do 18 | gpuid=$(( $i % $numgpus )) 19 | ( 20 | for seed in 0 1 2 3 4 21 | do 22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}" 23 | # sleep 5 24 | basename=$(basename $1) 25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \ 27 | --fov-size 50 \ 28 | --clip-reward \ 29 | --capture-video \ 30 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 31 | --learning-starts $learningstarts \ 32 | ${@:3} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 33 | done 34 | ) & 35 | done 36 | wait 
-------------------------------------------------------------------------------- /scripts/atari_100k_wp20_20x20.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # atari 100k settings 5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon 6 | 7 | expname="atari_100k_wp20_20x20" 8 | totaltimesteps="1000000" 9 | buffersize="100000" 10 | learningstarts="80000" 11 | 12 | mkdir -p logs/${expname} 13 | mkdir -p recordings/${expname} 14 | mkdir -p trained_models/${expname} 15 | 16 | for i in ${!envlist[@]} 17 | do 18 | gpuid=$(( $i % $numgpus )) 19 | ( 20 | for seed in 0 1 2 3 4 21 | do 22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}" 23 | # sleep 5 24 | basename=$(basename $1) 25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \ 27 | --fov-size 20 \ 28 | --peripheral-res 20 \ 29 | --clip-reward \ 30 | --capture-video \ 31 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 32 | --learning-starts $learningstarts \ 33 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 34 | done 35 | ) & 36 | done 37 | wait -------------------------------------------------------------------------------- /scripts/atari_100k_wp20_30x30.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # atari 100k settings 5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon 6 | 7 | expname="atari_100k_wp20_30x30" 8 | totaltimesteps="1000000" 9 | buffersize="100000" 10 | learningstarts="80000" 11 | 12 | mkdir -p logs/${expname} 13 | mkdir -p recordings/${expname} 14 | mkdir -p trained_models/${expname} 15 | 16 | for i in ${!envlist[@]} 17 | do 18 | gpuid=$(( $i % $numgpus )) 19 | ( 20 | for seed in 0 1 2 3 4 21 | do 22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}" 23 | # sleep 5 24 | basename=$(basename $1) 25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \ 27 | --fov-size 30 \ 28 | --peripheral-res 20 \ 29 | --clip-reward \ 30 | --capture-video \ 31 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 32 | --learning-starts $learningstarts \ 33 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 34 | done 35 | ) & 36 | done 37 | wait -------------------------------------------------------------------------------- /scripts/atari_100k_wp20_50x50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # atari 100k settings 5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo 
krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon 6 | 7 | expname="atari_100k_wp20_50x50" 8 | totaltimesteps="1000000" 9 | buffersize="100000" 10 | learningstarts="80000" 11 | 12 | mkdir -p logs/${expname} 13 | mkdir -p recordings/${expname} 14 | mkdir -p trained_models/${expname} 15 | 16 | for i in ${!envlist[@]} 17 | do 18 | gpuid=$(( $i % $numgpus )) 19 | ( 20 | for seed in 0 1 2 3 4 21 | do 22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}" 23 | # sleep 5 24 | basename=$(basename $1) 25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \ 27 | --fov-size 50 \ 28 | --peripheral-res 20 \ 29 | --clip-reward \ 30 | --capture-video \ 31 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 32 | --learning-starts $learningstarts \ 33 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 34 | done 35 | ) & 36 | done 37 | wait -------------------------------------------------------------------------------- /scripts/atari_100k_5m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # atari 100k settings 5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon 6 | 7 | expname="atari_100k_5m" 8 | totaltimesteps="5000000" 9 | buffersize="100000" 10 | learningstarts="80000" 11 | 12 | mkdir -p logs/${expname} 13 | mkdir -p recordings/${expname} 14 | mkdir -p trained_models/${expname} 15 | 16 | for i in ${!envlist[@]} 17 | do 18 | gpuid=$(( $i % $numgpus )) 19 | ( 20 | for seed in 0 1 2 3 4 21 | do 22 | echo "Atari100k GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}" 23 | # sleep 5 24 | basename=$(basename $1) 25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \ 27 | --exploration-fraction 0.02 \ 28 | --eval-frequency 1000000 \ 29 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 30 | --capture-video \ 31 | --clip-reward \ 32 | --learning-starts $learningstarts \ 33 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 34 | done 35 | ) & 36 | done 37 | wait -------------------------------------------------------------------------------- /scripts/atari_100k_wp20_peripheral_only.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # atari 100k settings 5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon 6 | 7 | expname="atari_100k_wp20_peripheral_only" 8 | totaltimesteps="1000000" 9 | buffersize="100000" 10 | learningstarts="80000" 11 | 12 | mkdir -p logs/${expname} 13 | mkdir -p recordings/${expname} 14 | mkdir -p trained_models/${expname} 15 | 16 | for i in ${!envlist[@]} 17 | do 18 | gpuid=$(( $i % $numgpus )) 19 | ( 20 | for seed in 0 
1 2 3 4 21 | do 22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}" 23 | # sleep 5 24 | basename=$(basename $1) 25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \ 27 | --fov-size 0 \ 28 | --peripheral-res 20 \ 29 | --clip-reward \ 30 | --capture-video \ 31 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 32 | --learning-starts $learningstarts \ 33 | ${@:3} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 34 | done 35 | ) & 36 | done 37 | wait -------------------------------------------------------------------------------- /scripts/atari_100k_20x20_5m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # atari 100k settings 5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon 6 | 7 | expname="atari_100k_20x20_5m" 8 | totaltimesteps="5000000" 9 | buffersize="100000" 10 | learningstarts="80000" 11 | 12 | mkdir -p logs/${expname} 13 | mkdir -p recordings/${expname} 14 | mkdir -p trained_models/${expname} 15 | 16 | for i in ${!envlist[@]} 17 | do 18 | gpuid=$(( $i % $numgpus )) 19 | ( 20 | for seed in 0 1 2 3 4 21 | do 22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}" 23 | # sleep 5 24 | basename=$(basename $1) 25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \ 27 | --exploration-fraction 0.02 \ 28 | --eval-frequency 1000000 \ 29 | --fov-size 20 \ 30 | --clip-reward \ 31 | --capture-video \ 32 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 33 | --learning-starts $learningstarts \ 34 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 35 | done 36 | ) & 37 | done 38 | wait -------------------------------------------------------------------------------- /scripts/atari_100k_30x30_5m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # atari 100k settings 5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon 6 | 7 | expname="atari_100k_30x30_5m" 8 | totaltimesteps="5000000" 9 | buffersize="100000" 10 | learningstarts="80000" 11 | 12 | mkdir -p logs/${expname} 13 | mkdir -p recordings/${expname} 14 | mkdir -p trained_models/${expname} 15 | 16 | for i in ${!envlist[@]} 17 | do 18 | gpuid=$(( $i % $numgpus )) 19 | ( 20 | for seed in 0 1 2 3 4 21 | do 22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}" 23 | # sleep 5 24 | basename=$(basename $1) 25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \ 27 | --exploration-fraction 0.02 \ 28 | --eval-frequency 1000000 \ 29 | --fov-size 30 \ 30 | --clip-reward \ 31 | 
--capture-video \ 32 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 33 | --learning-starts $learningstarts \ 34 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 35 | done 36 | ) & 37 | done 38 | wait -------------------------------------------------------------------------------- /scripts/atari_100k_50x50_5m.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)} 3 | 4 | # atari 100k settings 5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon 6 | 7 | expname="atari_100k_50x50_5m" 8 | totaltimesteps="5000000" 9 | buffersize="100000" 10 | learningstarts="80000" 11 | 12 | mkdir -p logs/${expname} 13 | mkdir -p recordings/${expname} 14 | mkdir -p trained_models/${expname} 15 | 16 | for i in ${!envlist[@]} 17 | do 18 | gpuid=$(( $i % $numgpus )) 19 | ( 20 | for seed in 0 1 2 3 4 21 | do 22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}" 23 | # sleep 5 24 | basename=$(basename $1) 25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \ 27 | --exploration-fraction 0.02 \ 28 | --eval-frequency 1000000 \ 29 | --fov-size 50 \ 30 | --clip-reward \ 31 | --capture-video \ 32 | --total-timesteps $totaltimesteps --buffer-size $buffersize \ 33 | --learning-starts $learningstarts \ 34 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt 35 | done 36 | ) & 37 | done 38 | wait -------------------------------------------------------------------------------- /common/pvm_buffer.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from typing import Tuple 3 | import copy 4 | 5 | import numpy as np 6 | 7 | class PVMBuffer: 8 | 9 | def __init__(self, max_len: int, obs_size: Tuple, fov_loc_size: Tuple = None) -> None: 10 | self.max_len = max_len 11 | self.obs_size = obs_size 12 | self.fov_loc_size = fov_loc_size 13 | self.buffer = None 14 | self.fov_loc_buffer = None 15 | self.init_pvm_buffer() 16 | 17 | 18 | def init_pvm_buffer(self) -> None: 19 | self.buffer = deque([], maxlen=self.max_len) 20 | self.fov_loc_buffer = deque([], maxlen=self.max_len) 21 | for _ in range(self.max_len): 22 | self.buffer.append(np.zeros(self.obs_size, dtype=np.float32)) 23 | if self.fov_loc_size is not None: # (1, 2) or (1, 2, 3) 24 | self.fov_loc_buffer.append(np.zeros(self.fov_loc_size, dtype=np.float32)) 25 | 26 | 27 | def append(self, x, fov_loc=None) -> None: 28 | self.buffer.append(x) 29 | if fov_loc is not None: 30 | self.fov_loc_buffer.append(fov_loc) 31 | 32 | def copy(self): 33 | return copy.deepcopy(self) 34 | 35 | def get_obs(self, mode="stack_max") -> np.ndarray: 36 | if mode == "stack_max": 37 | return np.amax(np.stack(self.buffer, axis=1), axis=1) # leading dim is batch dim [B, 1, C, H, W] 38 | elif mode == "stack_mean": 39 | return np.mean(np.stack(self.buffer, axis=1), axis=1, keepdims=True) # leading dim is batch dim [B, 1, C, H, W] 40 | elif mode == "stack": 41 | # print ([x.shape for x in self.buffer]) 42 | return np.stack(self.buffer, axis=1) # [B, T, C, H, W] 43 | elif mode == "stack_channel": 44 | return 
np.concatenate(self.buffer, axis=1) # [B, T*C, H, W] 45 | else: 46 | raise NotImplementedError 47 | 48 | def get_fov_locs(self, return_mask=False, relative_transform=False) -> np.ndarray: 49 | # print ([x.shape for x in self.fov_loc_buffer]) 50 | transforms = [] 51 | if relative_transform: 52 | for t in range(len(self.fov_loc_buffer)): 53 | transforms.append(np.zeros((self.fov_loc_buffer[t].shape[0], 2, 3), dtype=np.float32)) # [B, 2, 3] B=1 usually 54 | for b in range(len(self.fov_loc_buffer[t])): 55 | extrinsic = self.fov_loc_buffer[t][b] 56 | target_extrinsic = self.fov_loc_buffer[t][-1] 57 | # print (extrinsic, extrinsic.shape) 58 | if np.linalg.det(extrinsic): 59 | extrinsic_inv = np.linalg.inv(extrinsic) 60 | transform = np.dot(target_extrinsic, extrinsic_inv) 61 | # 4x4 transformation matrix 62 | R = transform[:3, :3] # Extract the rotation 63 | tr = transform[:3, 3] # Extract the translation 64 | H = R + np.outer(tr, np.array([0, 0, 1], dtype=np.float32)) 65 | # Assuming H is the 3x3 homography matrix 66 | A = H / H[2, 2] 67 | affine = A[:2, :] 68 | transforms[t][b] = affine 69 | else: 70 | H = np.identity(3, dtype=np.float32) 71 | A = H / H[2, 2] 72 | affine = A[:2, :] 73 | transforms[t][b] = affine 74 | # print (transforms, [x.shape for x in transforms]) 75 | return np.stack(transforms, axis=1) # [B, T, 2, 3] 76 | else: 77 | return np.stack(self.fov_loc_buffer, axis=1) 78 | #[B, T, *fov_locs_size], maybe [B, T, 2] for 2d or [B, T, 4, 4] for 3d 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SUGARL 2 | Code for NeurIPS 2023 paper **Active Vision Reinforcement Learning with Limited Visual Observability**, by [Jinghuan Shang](https://www.cs.stonybrook.edu/~jishang) and [Michael S. Ryoo](http://michaelryoo.com/). 3 | 4 | We propose Sensorimotor Understanding Guided Active Reinforcement Learning (SUGARL) to solve ActiveVision-RL tasks. 5 | We also introduce [Active-Gym](https://github.com/elicassion/active-gym), a convenient library that modifies existing RL environments for ActiveVision-RL, with Gymnasium-like interface. 6 | 7 | [[Paper]](https://arxiv.org/abs/2306.00975) [[Project Page]](https://elicassion.github.io/sugarl/sugarl.html) [[Active-Gym]](https://github.com/elicassion/active-gym) 8 | 9 | 10 | 11 | ## Dependency 12 | ``` 13 | conda env create -f active_rl_env.yaml 14 | ``` 15 | We highlight [Active-Gym](https://github.com/elicassion/active-gym) developed by us to support Active-RL setting for many environments. 
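
After the environment is created, activate it before launching any scripts. The environment name `arl` is the one defined in `active_rl_env.yaml`:
```
conda activate arl
```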
16 | 17 | ## Usage 18 | - General format: 19 | ``` 20 | cd sugarl # make sure you are under the root dir of this repo 21 | bash ./scripts/<experiment_script>.sh agent/<agent_file>.py 22 | ``` 23 | 24 | - Reproduce our experiments: 25 | ``` 26 | cd sugarl # make sure you are under the root dir of this repo 27 | bash ./scripts/robosuite_series.sh agent/<agent_file>.py 28 | bash ./scripts/atari_series.sh agent/<agent_file>.py 29 | bash ./scripts/atari_series_5m.sh agent/<agent_file>.py 30 | bash ./scripts/atari_wp_series.sh agent/<agent_file>.py 31 | bash ./scripts/dmc_series.sh agent/<agent_file>.py 32 | ``` 33 | For example, to run SUGARL-DQN on Atari: 34 | ``` 35 | bash ./scripts/atari_series.sh agent/dqn_atari_sugarl.py 36 | ``` 37 | 38 | - Sanity checks: these run through the whole pipeline with only a tiny amount of training to catch bugs 39 | ``` 40 | cd sugarl # make sure you are under the root dir of this repo 41 | bash ./scripts/atari_test.sh agent/<agent_file>.py 42 | bash ./scripts/dmc_test.sh agent/<agent_file>.py 43 | bash ./scripts/robosuite_test.sh agent/<agent_file>.py 44 | ``` 45 | 46 | 47 | All experiment scripts automatically scale all tasks to your GPUs. Please modify the GPU behavior (`CUDA_VISIBLE_DEVICES=`) in the scripts if 48 | - you want to run jobs only on certain GPUs 49 | - either VRAM or RAM is not sufficient for scaling all jobs 50 | 51 | In the provided scripts, all 26 Atari games run in parallel, with the seeds for each game executed sequentially. The 6 DMC environments x 5 seeds all run in parallel. Please check the available RAM and VRAM on your machine before starting. 52 | 53 | ### Notes 54 | **Naming**: 55 | 56 | All agents are under `agent/`, with the name format `<base_algorithm>_<env>_<method>.py`. Each file is a self-contained entry point for the whole training process. We support DQN, SAC, and DrQ as base algorithms. 57 | 58 | All experiment scripts are under `scripts/`, with the format `<env>_<setting>.sh`. 59 | Please ensure that the env and setting match the agent when launching jobs. 60 | 61 | **Resource requirement reference (SUGARL)**: 62 | 63 | - Atari: 64 | for each game with `100k` replay buffer: `~18G` RAM, `<2G` VRAM 65 | 66 | - DMC: 67 | for each task with `100k` replay buffer: `~18G` RAM, `<3G` VRAM 68 | 69 | - Robosuite: 70 | for each task with `100k` replay buffer: `~54G` RAM, `4.2G` VRAM 71 | 72 | **Coding style**: 73 | 74 | We follow the coding style of [clean-rl](https://github.com/vwxyzjn/cleanrl) so that modifications to one agent do not affect the others. This does introduce a lot of redundancy, but it makes arranging experiments and evolving the algorithm much easier. 75 | 76 | ## Citation 77 | Please consider citing us if you find this repo helpful. 78 | ``` 79 | @article{shang2023active, 80 | title={Active Reinforcement Learning under Limited Visual Observability}, 81 | author={Jinghuan Shang and Michael S. Ryoo}, 82 | journal={arXiv preprint}, 83 | year={2023}, 84 | eprint={2306.00975}, 85 | } 86 | ``` 87 | 88 | ## Acknowledgement 89 | We thank the authors of [clean-rl](https://github.com/vwxyzjn/cleanrl) for their implementation.
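
## Logs and results
As a reference for where outputs go (paths taken from the provided scripts and agents): each launcher script redirects stdout to `logs/<exp_name>/<env>__<agent>__<seed>.txt` and pre-creates `recordings/<exp_name>/` and `trained_models/<exp_name>/`, while the agents write TensorBoard event files under `runs/<exp_name>/`. A typical way to inspect training curves, assuming the default `exp-name` set inside `atari_100k_20x20.sh`, is:
```
tensorboard --logdir runs/atari_100k_20x20
```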
90 | -------------------------------------------------------------------------------- /agent/dqn_atari_base.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import random 5 | import time 6 | from distutils.util import strtobool 7 | 8 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__)))) 9 | os.environ["OMP_NUM_THREADS"] = "1" 10 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 11 | import warnings 12 | warnings.filterwarnings("ignore", category=UserWarning) 13 | 14 | import gymnasium as gym 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import torch.optim as optim 20 | from common.buffer import ReplayBuffer 21 | from common.utils import get_timestr, seed_everything 22 | from torch.utils.tensorboard import SummaryWriter 23 | 24 | from active_gym import AtariBaseEnv, AtariEnvArgs 25 | 26 | 27 | def parse_args(): 28 | # fmt: off 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__), 31 | help="the name of this experiment") 32 | parser.add_argument("--seed", type=int, default=1, 33 | help="seed of the experiment") 34 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 35 | help="if toggled, cuda will be enabled by default") 36 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, 37 | help="whether to capture videos of the agent performances (check out `videos` folder)") 38 | 39 | # env setting 40 | parser.add_argument("--env", type=str, default="breakout", 41 | help="the id of the environment") 42 | parser.add_argument("--env-num", type=int, default=1, 43 | help="# envs in parallel") 44 | parser.add_argument("--frame-stack", type=int, default=4, 45 | help="frame stack #") 46 | parser.add_argument("--action-repeat", type=int, default=4, 47 | help="action repeat #") 48 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True) 49 | 50 | # Algorithm specific arguments 51 | parser.add_argument("--total-timesteps", type=int, default=3000000, 52 | help="total timesteps of the experiments") 53 | parser.add_argument("--learning-rate", type=float, default=1e-4, 54 | help="the learning rate of the optimizer") 55 | parser.add_argument("--buffer-size", type=int, default=500000, 56 | help="the replay memory buffer size") 57 | parser.add_argument("--gamma", type=float, default=0.99, 58 | help="the discount factor gamma") 59 | parser.add_argument("--target-network-frequency", type=int, default=1000, 60 | help="the timesteps it takes to update the target network") 61 | parser.add_argument("--batch-size", type=int, default=32, 62 | help="the batch size of sample from the reply memory") 63 | parser.add_argument("--start-e", type=float, default=1, 64 | help="the starting epsilon for exploration") 65 | parser.add_argument("--end-e", type=float, default=0.01, 66 | help="the ending epsilon for exploration") 67 | parser.add_argument("--exploration-fraction", type=float, default=0.10, 68 | help="the fraction of `total-timesteps` it takes from start-e to go end-e") 69 | parser.add_argument("--learning-starts", type=int, default=80000, 70 | help="timestep to start learning") 71 | parser.add_argument("--train-frequency", type=int, default=4, 72 | help="the frequency of training") 73 | 74 | # eval args 75 | 
parser.add_argument("--eval-frequency", type=int, default=-1, 76 | help="eval frequency. default -1 is eval at the end.") 77 | parser.add_argument("--eval-num", type=int, default=10, 78 | help="eval frequency. default -1 is eval at the end.") 79 | args = parser.parse_args() 80 | # fmt: on 81 | return args 82 | 83 | 84 | def make_env(env_name, seed, **kwargs): 85 | def thunk(): 86 | env_args = AtariEnvArgs( 87 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs 88 | ) 89 | env = AtariBaseEnv(env_args) 90 | env.action_space.seed(seed) 91 | env.observation_space.seed(seed) 92 | return env 93 | 94 | return thunk 95 | 96 | 97 | # ALGO LOGIC: initialize agent here: 98 | class QNetwork(nn.Module): 99 | def __init__(self, env): 100 | super().__init__() 101 | self.network = nn.Sequential( 102 | nn.Conv2d(4, 32, 8, stride=4), 103 | nn.ReLU(), 104 | nn.Conv2d(32, 64, 4, stride=2), 105 | nn.ReLU(), 106 | nn.Conv2d(64, 64, 3, stride=1), 107 | nn.ReLU(), 108 | nn.Flatten(), 109 | nn.Linear(3136, 512), 110 | nn.ReLU(), 111 | nn.Linear(512, env.single_action_space.n), 112 | ) 113 | 114 | def forward(self, x): 115 | return self.network(x) 116 | 117 | 118 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int): 119 | slope = (end_e - start_e) / duration 120 | return max(slope * t + start_e, end_e) 121 | 122 | 123 | if __name__ == "__main__": 124 | args = parse_args() 125 | args.env = args.env.lower() 126 | run_name = f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}" 127 | run_dir = os.path.join("runs", args.exp_name) 128 | if not os.path.exists(run_dir): 129 | os.makedirs(run_dir, exist_ok=True) 130 | 131 | writer = SummaryWriter(os.path.join(run_dir, run_name)) 132 | writer.add_text( 133 | "hyperparameters", 134 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 135 | ) 136 | 137 | # TRY NOT TO MODIFY: seeding 138 | seed_everything(args.seed) 139 | 140 | 141 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 142 | 143 | # env setup 144 | envs = [] 145 | for i in range(args.env_num): 146 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat, clip_reward=args.clip_reward)) 147 | # envs = gym.vector.AsyncVectorEnv(envs) 148 | envs = gym.vector.SyncVectorEnv(envs) 149 | # async 4 100s - 24k 150 | # async 2 47s - 50k 151 | # async 1 60s - 50k 152 | # sync 1 50s - 50k 153 | # sync 2 46s - 50k 154 | 155 | q_network = QNetwork(envs).to(device) 156 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate) 157 | target_network = QNetwork(envs).to(device) 158 | target_network.load_state_dict(q_network.state_dict()) 159 | 160 | rb = ReplayBuffer( 161 | args.buffer_size, 162 | envs.single_observation_space, 163 | envs.single_action_space, 164 | device, 165 | n_envs=envs.num_envs, 166 | optimize_memory_usage=True, 167 | handle_timeout_termination=False, 168 | ) 169 | start_time = time.time() 170 | 171 | # TRY NOT TO MODIFY: start the game 172 | obs, _ = envs.reset() 173 | global_transitions = 0 174 | while global_transitions < args.total_timesteps: 175 | # ALGO LOGIC: put action logic here 176 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_transitions) 177 | if random.random() < epsilon: 178 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 179 | else: 180 | q_values = q_network(torch.from_numpy(obs).to(device)) 181 | 
actions = torch.argmax(q_values, dim=1).cpu().numpy() 182 | 183 | # TRY NOT TO MODIFY: execute the game and log data. 184 | next_obs, rewards, dones, _, infos = envs.step(actions) 185 | # print (global_step, infos) 186 | 187 | # TRY NOT TO MODIFY: record rewards for plotting purposes 188 | if "final_info" in infos: 189 | for idx, d in enumerate(dones): 190 | if d: 191 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]") 192 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions) 193 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions) 194 | writer.add_scalar("charts/epsilon", epsilon, global_transitions) 195 | break 196 | 197 | # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` 198 | real_next_obs = next_obs 199 | for idx, d in enumerate(dones): 200 | if d: 201 | real_next_obs[idx] = infos["final_observation"][idx] 202 | rb.add(obs, real_next_obs, actions, rewards, dones, {}) 203 | 204 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 205 | obs = next_obs 206 | 207 | # INC total transitions 208 | global_transitions += args.env_num 209 | 210 | # ALGO LOGIC: training. 211 | if global_transitions > args.learning_starts: 212 | if global_transitions % args.train_frequency == 0: 213 | data = rb.sample(args.batch_size // args.env_num) # counter-balance the true global transitions used for training 214 | with torch.no_grad(): 215 | target_max, _ = target_network(data.next_observations).max(dim=1) 216 | td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten()) 217 | old_val = q_network(data.observations).gather(1, data.actions).squeeze() 218 | loss = F.mse_loss(td_target, old_val) 219 | 220 | if global_transitions % 100 == 0: 221 | writer.add_scalar("losses/td_loss", loss.item(), global_transitions) 222 | writer.add_scalar("losses/q_values", old_val.mean().item(), global_transitions) 223 | # print("SPS:", int(global_step / (time.time() - start_time))) 224 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions) 225 | 226 | # optimize the model 227 | optimizer.zero_grad() 228 | loss.backward() 229 | optimizer.step() 230 | 231 | # update the target network 232 | if (global_transitions // args.env_num) % args.target_network_frequency == 0: 233 | target_network.load_state_dict(q_network.state_dict()) 234 | 235 | # evaluation 236 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \ 237 | (global_transitions >= args.total_timesteps): 238 | q_network.eval() 239 | 240 | eval_episodic_returns, eval_episodic_lengths = [], [] 241 | 242 | for eval_ep in range(args.eval_num): 243 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat, clip_reward=args.clip_reward, training=False, record=args.capture_video)] 244 | eval_env = gym.vector.SyncVectorEnv(eval_env) 245 | obs_eval, _ = eval_env.reset() 246 | done = False 247 | while not done: 248 | q_values = q_network(torch.from_numpy(obs_eval).to(device)) 249 | actions = torch.argmax(q_values, dim=1).cpu().numpy() 250 | next_obs_eval, rewards, dones, _, infos = eval_env.step(actions) 251 | obs_eval = next_obs_eval 252 | done = dones[0] 253 | if done: 254 | eval_episodic_returns.append(infos['final_info'][0]["reward"]) 255 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"]) 256 | 257 | 
writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions) 258 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions) 259 | # writer.add_scalar("charts/eval_episodic_length", np.mean(), global_transitions) 260 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]") 261 | 262 | q_network.train() 263 | 264 | 265 | 266 | envs.close() 267 | eval_env.close() 268 | writer.close() -------------------------------------------------------------------------------- /agent/dqn_atari_wp_random.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import random 5 | import time 6 | from itertools import product 7 | from distutils.util import strtobool 8 | 9 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__)))) 10 | os.environ["OMP_NUM_THREADS"] = "1" 11 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 12 | import warnings 13 | warnings.filterwarnings("ignore", category=UserWarning) 14 | 15 | import gymnasium as gym 16 | from gymnasium.spaces import Discrete, Dict 17 | import numpy as np 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.optim as optim 22 | from torchvision.transforms import Resize 23 | 24 | from common.buffer import ReplayBuffer 25 | from common.utils import get_timestr, seed_everything 26 | from torch.utils.tensorboard import SummaryWriter 27 | 28 | from active_gym import AtariFixedFovealPeripheralEnv, AtariEnvArgs 29 | 30 | 31 | def parse_args(): 32 | # fmt: off 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), 35 | help="the name of this experiment") 36 | parser.add_argument("--seed", type=int, default=1, 37 | help="seed of the experiment") 38 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 39 | help="if toggled, cuda will be enabled by default") 40 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, 41 | help="whether to capture videos of the agent performances (check out `videos` folder)") 42 | 43 | # env setting 44 | parser.add_argument("--env", type=str, default="breakout", 45 | help="the id of the environment") 46 | parser.add_argument("--env-num", type=int, default=1, 47 | help="# envs in parallel") 48 | parser.add_argument("--frame-stack", type=int, default=4, 49 | help="frame stack #") 50 | parser.add_argument("--action-repeat", type=int, default=4, 51 | help="action repeat #") 52 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True) 53 | 54 | # fov setting 55 | parser.add_argument("--fov-size", type=int, default=50) 56 | parser.add_argument("--fov-init-loc", type=int, default=0) 57 | parser.add_argument("--sensory-action-mode", type=str, default="absolute") 58 | parser.add_argument("--sensory-action-space", type=int, default=10) # ignored when sensory_action_mode="relative" 59 | parser.add_argument("--resize-to-full", default=False, action="store_true") 60 | parser.add_argument("--peripheral-res", type=int, default=20) 61 | # for discrete observ action 62 | parser.add_argument("--sensory-action-x-size", 
type=int, default=4) 63 | parser.add_argument("--sensory-action-y-size", type=int, default=4) 64 | 65 | # Algorithm specific arguments 66 | parser.add_argument("--total-timesteps", type=int, default=3000000, 67 | help="total timesteps of the experiments") 68 | parser.add_argument("--learning-rate", type=float, default=1e-4, 69 | help="the learning rate of the optimizer") 70 | parser.add_argument("--buffer-size", type=int, default=500000, 71 | help="the replay memory buffer size") 72 | parser.add_argument("--gamma", type=float, default=0.99, 73 | help="the discount factor gamma") 74 | parser.add_argument("--target-network-frequency", type=int, default=1000, 75 | help="the timesteps it takes to update the target network") 76 | parser.add_argument("--batch-size", type=int, default=32, 77 | help="the batch size of sample from the reply memory") 78 | parser.add_argument("--start-e", type=float, default=1, 79 | help="the starting epsilon for exploration") 80 | parser.add_argument("--end-e", type=float, default=0.01, 81 | help="the ending epsilon for exploration") 82 | parser.add_argument("--exploration-fraction", type=float, default=0.10, 83 | help="the fraction of `total-timesteps` it takes from start-e to go end-e") 84 | parser.add_argument("--learning-starts", type=int, default=80000, 85 | help="timestep to start learning") 86 | parser.add_argument("--train-frequency", type=int, default=4, 87 | help="the frequency of training") 88 | 89 | # eval args 90 | parser.add_argument("--eval-frequency", type=int, default=-1, 91 | help="eval frequency. default -1 is eval at the end.") 92 | parser.add_argument("--eval-num", type=int, default=10, 93 | help="eval frequency. default -1 is eval at the end.") 94 | args = parser.parse_args() 95 | # fmt: on 96 | return args 97 | 98 | 99 | def make_env(env_name, seed, **kwargs): 100 | def thunk(): 101 | env_args = AtariEnvArgs( 102 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs 103 | ) 104 | env = AtariFixedFovealPeripheralEnv(env_args) 105 | env.action_space.seed(seed) 106 | env.observation_space.seed(seed) 107 | return env 108 | 109 | return thunk 110 | 111 | 112 | # ALGO LOGIC: initialize agent here: 113 | class QNetwork(nn.Module): 114 | def __init__(self, env): 115 | super().__init__() 116 | if isinstance(env.single_action_space, Discrete): 117 | action_space_size = env.single_action_space.n 118 | elif isinstance(env.single_action_space, Dict): 119 | action_space_size = env.single_action_space["motor_action"].n 120 | self.network = nn.Sequential( 121 | nn.Conv2d(4, 32, 8, stride=4), 122 | nn.ReLU(), 123 | nn.Conv2d(32, 64, 4, stride=2), 124 | nn.ReLU(), 125 | nn.Conv2d(64, 64, 3, stride=1), 126 | nn.ReLU(), 127 | nn.Flatten(), 128 | nn.Linear(3136, 512), 129 | nn.ReLU(), 130 | nn.Linear(512, action_space_size), 131 | ) 132 | 133 | def forward(self, x): 134 | return self.network(x) 135 | 136 | 137 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int): 138 | slope = (end_e - start_e) / duration 139 | return max(slope * t + start_e, end_e) 140 | 141 | 142 | if __name__ == "__main__": 143 | args = parse_args() 144 | args.env = args.env.lower() 145 | run_name = f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}" 146 | run_dir = os.path.join("runs", args.exp_name) 147 | if not os.path.exists(run_dir): 148 | os.makedirs(run_dir, exist_ok=True) 149 | 150 | writer = SummaryWriter(os.path.join(run_dir, run_name)) 151 | writer.add_text( 152 | "hyperparameters", 153 | "|param|value|\n|-|-|\n%s" % 
("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 154 | ) 155 | 156 | # TRY NOT TO MODIFY: seeding 157 | seed_everything(args.seed) 158 | 159 | 160 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 161 | 162 | # env setup 163 | envs = [] 164 | for i in range(args.env_num): 165 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat, 166 | fov_size=(args.fov_size, args.fov_size), 167 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc), 168 | peripheral_res=(args.peripheral_res, args.peripheral_res), 169 | sensory_action_mode=args.sensory_action_mode, 170 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space), 171 | resize_to_full=args.resize_to_full, 172 | clip_reward=args.clip_reward)) 173 | # envs = gym.vector.AsyncVectorEnv(envs) 174 | envs = gym.vector.SyncVectorEnv(envs) 175 | 176 | resize = Resize((84, 84)) 177 | 178 | # get a discrete observ action space 179 | OBSERVATION_SIZE = (84, 84) 180 | observ_x_max, observ_y_max = OBSERVATION_SIZE[0]-args.fov_size, OBSERVATION_SIZE[1]-args.fov_size 181 | sensory_action_step = (observ_x_max//args.sensory_action_x_size, 182 | observ_y_max//args.sensory_action_y_size) 183 | sensory_action_x_set = list(range(0, observ_x_max, sensory_action_step[0]))[:args.sensory_action_x_size] 184 | sensory_action_y_set = list(range(0, observ_y_max, sensory_action_step[1]))[:args.sensory_action_y_size] 185 | sensory_action_set = list(product(sensory_action_x_set, sensory_action_y_set)) 186 | 187 | q_network = QNetwork(envs).to(device) 188 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate) 189 | target_network = QNetwork(envs).to(device) 190 | target_network.load_state_dict(q_network.state_dict()) 191 | 192 | rb = ReplayBuffer( 193 | args.buffer_size, 194 | envs.single_observation_space, 195 | envs.single_action_space["motor_action"], 196 | device, 197 | n_envs=envs.num_envs, 198 | optimize_memory_usage=True, 199 | handle_timeout_termination=False, 200 | ) 201 | start_time = time.time() 202 | 203 | # TRY NOT TO MODIFY: start the game 204 | obs, _ = envs.reset() 205 | global_transitions = 0 206 | while global_transitions < args.total_timesteps: 207 | # ALGO LOGIC: put action logic here 208 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_transitions) 209 | if random.random() < epsilon: 210 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 211 | motor_actions = np.array([actions[0]["motor_action"]]) 212 | else: 213 | q_values = q_network(resize(torch.from_numpy(obs)).to(device)) 214 | motor_actions = torch.argmax(q_values, dim=1).cpu().numpy() 215 | 216 | # TRY NOT TO MODIFY: execute the game and log data. 
217 | next_obs, rewards, dones, _, infos = envs.step({"motor_action": motor_actions, 218 | "sensory_action": [sensory_action_set[random.randint(0, len(sensory_action_set)-1)]] }) 219 | # print (global_step, infos) 220 | 221 | # TRY NOT TO MODIFY: record rewards for plotting purposes 222 | if "final_info" in infos: 223 | for idx, d in enumerate(dones): 224 | if d: 225 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]") 226 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions) 227 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions) 228 | writer.add_scalar("charts/epsilon", epsilon, global_transitions) 229 | break 230 | 231 | # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` 232 | real_next_obs = next_obs 233 | for idx, d in enumerate(dones): 234 | if d: 235 | real_next_obs[idx] = infos["final_observation"][idx] 236 | rb.add(obs, real_next_obs, motor_actions, rewards, dones, {}) 237 | 238 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 239 | obs = next_obs 240 | 241 | # INC total transitions 242 | global_transitions += args.env_num 243 | 244 | # ALGO LOGIC: training. 245 | if global_transitions > args.learning_starts: 246 | if global_transitions % args.train_frequency == 0: 247 | data = rb.sample(args.batch_size // args.env_num) # counter-balance the true global transitions used for training 248 | with torch.no_grad(): 249 | target_max, _ = target_network(resize(data.next_observations)).max(dim=1) 250 | td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten()) 251 | old_val = q_network(resize(data.observations)).gather(1, data.actions).squeeze() 252 | loss = F.mse_loss(td_target, old_val) 253 | 254 | if global_transitions % 100 == 0: 255 | writer.add_scalar("losses/td_loss", loss, global_transitions) 256 | writer.add_scalar("losses/q_values", old_val.mean().item(), global_transitions) 257 | # print("SPS:", int(global_step / (time.time() - start_time))) 258 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions) 259 | 260 | # optimize the model 261 | optimizer.zero_grad() 262 | loss.backward() 263 | optimizer.step() 264 | 265 | # update the target network 266 | if (global_transitions // args.env_num) % args.target_network_frequency == 0: 267 | target_network.load_state_dict(q_network.state_dict()) 268 | 269 | # evaluation 270 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \ 271 | (global_transitions >= args.total_timesteps): 272 | q_network.eval() 273 | 274 | eval_episodic_returns, eval_episodic_lengths = [], [] 275 | 276 | for eval_ep in range(args.eval_num): 277 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat, 278 | fov_size=(args.fov_size, args.fov_size), 279 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc), 280 | peripheral_res=(args.peripheral_res, args.peripheral_res), 281 | sensory_action_mode=args.sensory_action_mode, 282 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space), 283 | resize_to_full=args.resize_to_full, 284 | clip_reward=args.clip_reward, 285 | training=False)] 286 | eval_env = gym.vector.SyncVectorEnv(eval_env) 287 | obs, _ = eval_env.reset() 288 | done = False 289 | while not done: 290 | q_values = q_network(resize(torch.from_numpy(obs)).to(device)) 291 | 
motor_actions = torch.argmax(q_values, dim=1).cpu().numpy() 292 | next_obs, rewards, dones, _, infos = eval_env.step({"motor_action": motor_actions, 293 | "sensory_action": [sensory_action_set[random.randint(0, len(sensory_action_set)-1)]]}) 294 | obs = next_obs 295 | done = dones[0] 296 | if done: 297 | eval_episodic_returns.append(infos['final_info'][0]["reward"]) 298 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"]) 299 | 300 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions) 301 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions) 302 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]") 303 | 304 | q_network.train() 305 | 306 | 307 | 308 | envs.close() 309 | eval_env.close() 310 | writer.close() -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Borrow from stable-baselines3 3 | Due to dependencies incompability, we cherry-pick codes here 4 | """ 5 | import os, random, re 6 | from datetime import datetime 7 | import warnings 8 | from typing import Dict, Tuple, Union 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn import functional as F 14 | from torch import distributions as pyd 15 | from torch.distributions.utils import _standard_normal 16 | 17 | from gymnasium import spaces 18 | 19 | def seed_everything(seed): 20 | random.seed(seed) 21 | os.environ['PYTHONHASHSEED']=str(seed) 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | torch.backends.cudnn.benchmark = False 26 | torch.backends.cudnn.deterministic = True 27 | 28 | def is_image_space_channels_first(observation_space: spaces.Box) -> bool: 29 | """ 30 | Check if an image observation space (see ``is_image_space``) 31 | is channels-first (CxHxW, True) or channels-last (HxWxC, False). 32 | Use a heuristic that channel dimension is the smallest of the three. 33 | If second dimension is smallest, raise an exception (no support). 34 | :param observation_space: 35 | :return: True if observation space is channels-first image, False if channels-last. 36 | """ 37 | smallest_dimension = np.argmin(observation_space.shape).item() 38 | if smallest_dimension == 1: 39 | warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.") 40 | return smallest_dimension == 0 41 | 42 | 43 | def is_image_space( 44 | observation_space: spaces.Space, 45 | check_channels: bool = False, 46 | normalized_image: bool = False, 47 | ) -> bool: 48 | """ 49 | Check if a observation space has the shape, limits and dtype 50 | of a valid image. 51 | The check is conservative, so that it returns False if there is a doubt. 52 | Valid images: RGB, RGBD, GrayScale with values in [0, 255] 53 | :param observation_space: 54 | :param check_channels: Whether to do or not the check for the number of channels. 55 | e.g., with frame-stacking, the observation space may have more channels than expected. 56 | :param normalized_image: Whether to assume that the image is already normalized 57 | or not (this disables dtype and bounds checks): when True, it only checks that 58 | the space is a Box and has 3 dimensions. 
59 | Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]). 60 | :return: 61 | """ 62 | check_dtype = check_bounds = not normalized_image 63 | if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3: 64 | # Check the type 65 | if check_dtype and observation_space.dtype != np.uint8: 66 | return False 67 | 68 | # Check the value range 69 | incorrect_bounds = np.any(observation_space.low != 0) or np.any(observation_space.high != 255) 70 | if check_bounds and incorrect_bounds: 71 | return False 72 | 73 | # Skip channels check 74 | if not check_channels: 75 | return True 76 | # Check the number of channels 77 | if is_image_space_channels_first(observation_space): 78 | n_channels = observation_space.shape[0] 79 | else: 80 | n_channels = observation_space.shape[-1] 81 | # GrayScale, RGB, RGBD 82 | return n_channels in [1, 3, 4] 83 | return False 84 | 85 | 86 | 87 | def preprocess_obs( 88 | obs: torch.Tensor, 89 | observation_space: spaces.Space, 90 | normalize_images: bool = True, 91 | ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: 92 | """ 93 | Preprocess observation to be to a neural network. 94 | For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]) 95 | For discrete observations, it create a one hot vector. 96 | :param obs: Observation 97 | :param observation_space: 98 | :param normalize_images: Whether to normalize images or not 99 | (True by default) 100 | :return: 101 | """ 102 | if isinstance(observation_space, spaces.Box): 103 | if normalize_images and is_image_space(observation_space): 104 | return obs.float() / 255.0 105 | return obs.float() 106 | 107 | elif isinstance(observation_space, spaces.Discrete): 108 | # One hot encoding and convert to float to avoid errors 109 | return F.one_hot(obs.long(), num_classes=observation_space.n).float() 110 | 111 | elif isinstance(observation_space, spaces.MultiDiscrete): 112 | # Tensor concatenation of one hot encodings of each Categorical sub-space 113 | return torch.cat( 114 | [ 115 | F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float() 116 | for idx, obs_ in enumerate(torch.split(obs.long(), 1, dim=1)) 117 | ], 118 | dim=-1, 119 | ).view(obs.shape[0], sum(observation_space.nvec)) 120 | 121 | elif isinstance(observation_space, spaces.MultiBinary): 122 | return obs.float() 123 | 124 | elif isinstance(observation_space, spaces.Dict): 125 | # Do not modify by reference the original observation 126 | assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}" 127 | preprocessed_obs = {} 128 | for key, _obs in obs.items(): 129 | preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images) 130 | return preprocessed_obs 131 | 132 | else: 133 | raise NotImplementedError(f"Preprocessing not implemented for {observation_space}") 134 | 135 | 136 | def get_obs_shape( 137 | observation_space: spaces.Space, 138 | ) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]: 139 | """ 140 | Get the shape of the observation (useful for the buffers). 
141 | :param observation_space: 142 | :return: 143 | """ 144 | if isinstance(observation_space, spaces.Box): 145 | return observation_space.shape 146 | elif isinstance(observation_space, spaces.Discrete): 147 | # Observation is an int 148 | return (1,) 149 | elif isinstance(observation_space, spaces.MultiDiscrete): 150 | # Number of discrete features 151 | return (int(len(observation_space.nvec)),) 152 | elif isinstance(observation_space, spaces.MultiBinary): 153 | # Number of binary features 154 | if type(observation_space.n) in [tuple, list, np.ndarray]: 155 | return tuple(observation_space.n) 156 | else: 157 | return (int(observation_space.n),) 158 | elif isinstance(observation_space, spaces.Dict): 159 | return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} # type: ignore[misc] 160 | 161 | else: 162 | raise NotImplementedError(f"{observation_space} observation space is not supported") 163 | 164 | 165 | def get_flattened_obs_dim(observation_space: spaces.Space) -> int: 166 | """ 167 | Get the dimension of the observation space when flattened. 168 | It does not apply to image observation space. 169 | Used by the ``FlattenExtractor`` to compute the input shape. 170 | :param observation_space: 171 | :return: 172 | """ 173 | # See issue https://github.com/openai/gym/issues/1915 174 | # it may be a problem for Dict/Tuple spaces too... 175 | if isinstance(observation_space, spaces.MultiDiscrete): 176 | return sum(observation_space.nvec) 177 | else: 178 | # Use Gym internal method 179 | return spaces.utils.flatdim(observation_space) 180 | 181 | 182 | def get_action_dim(action_space: spaces.Space) -> int: 183 | """ 184 | Get the dimension of the action space. 185 | :param action_space: 186 | :return: 187 | """ 188 | if isinstance(action_space, spaces.Box): 189 | return int(np.prod(action_space.shape)) 190 | elif isinstance(action_space, spaces.Discrete): 191 | # Action is an int 192 | return 1 193 | elif isinstance(action_space, spaces.MultiDiscrete): 194 | # Number of discrete actions 195 | return int(len(action_space.nvec)) 196 | elif isinstance(action_space, spaces.MultiBinary): 197 | # Number of binary actions 198 | return int(action_space.n) 199 | elif isinstance(action_space, spaces.Dict): 200 | return get_action_dim(action_space["motor_action"]) 201 | else: 202 | raise NotImplementedError(f"{action_space} action space is not supported") 203 | 204 | 205 | def check_for_nested_spaces(obs_space: spaces.Space): 206 | """ 207 | Make sure the observation space does not have nested spaces (Dicts/Tuples inside Dicts/Tuples). 208 | If so, raise an Exception informing that there is no support for this. 209 | :param obs_space: an observation space 210 | :return: 211 | """ 212 | if isinstance(obs_space, (spaces.Dict, spaces.Tuple)): 213 | sub_spaces = obs_space.spaces.values() if isinstance(obs_space, spaces.Dict) else obs_space.spaces 214 | for sub_space in sub_spaces: 215 | if isinstance(sub_space, (spaces.Dict, spaces.Tuple)): 216 | raise NotImplementedError( 217 | "Nested observation spaces are not supported (Tuple/Dict space inside Tuple/Dict space)." 218 | ) 219 | 220 | 221 | def get_device(device: Union[torch.device, str] = "auto") -> torch.device: 222 | """ 223 | Retrieve PyTorch device. 224 | It checks that the requested device is available first. 225 | For now, it supports only cpu and cuda. 226 | By default, it tries to use the gpu. 
227 | :param device: One for 'auto', 'cuda', 'cpu' 228 | :return: Supported Pytorch device 229 | """ 230 | # Cuda by default 231 | if device == "auto": 232 | device = "cuda" 233 | # Force conversion to torch.device 234 | device = torch.device(device) 235 | 236 | # Cuda not available 237 | if device.type == torch.device("cuda").type and not torch.cuda.is_available(): 238 | return torch.device("cpu") 239 | 240 | return device 241 | 242 | 243 | def get_timestr() -> str: 244 | current_datetime = datetime.now() 245 | return current_datetime.strftime("%m-%d-%H-%M-%S") 246 | 247 | 248 | def get_spatial_emb_indices(loc: np.ndarray, 249 | full_img_size=(4, 84, 84), 250 | img_size=(4, 21, 21), 251 | patch_size=(7, 7)) -> np.ndarray: 252 | # loc (2,) 253 | _, H, W = full_img_size 254 | _, h, w = img_size 255 | p1, p2 = patch_size 256 | 257 | st_x = loc[0] // p1 258 | st_y = loc[1] // p2 259 | 260 | ed_x = (loc[0] + h) // p1 261 | ed_y = (loc[1] + w) // p2 262 | 263 | ix, iy = np.meshgrid(np.arange(st_x, ed_x, dtype=np.int64), 264 | np.arange(st_y, ed_y, dtype=np.int64), indexing="ij") 265 | 266 | # print (ix, iy) 267 | indicies = (ix * H // p1 + iy).reshape(-1) 268 | 269 | return indicies 270 | 271 | def get_spatial_emb_mask(loc, 272 | mask, 273 | full_img_size=(4, 84, 84), 274 | img_size=(4, 21, 21), 275 | patch_size=(7, 7), 276 | latent_dim=144) -> np.ndarray: 277 | B, T, _ = loc.size() 278 | # return torch.randn_like() 279 | loc = loc.reshape(-1, 2) 280 | _, H, W = full_img_size 281 | _, h, w = img_size 282 | p1, p2 = patch_size 283 | num_tokens = h*w//p1//p2 284 | # print ("num_tokens", num_tokens) 285 | 286 | st_x = loc[..., 0] // p1 287 | st_y = loc[..., 1] // p2 288 | 289 | ed_x = (loc[..., 0] + h) // p1 290 | ed_y = (loc[..., 1] + w) // p2 291 | 292 | # mask = np.zeros(((32*6, H//p1, W//p2, latent_dim)), dtype=np.bool_) 293 | # mask = torch.zeros((32*6, H//p1, W//p2, latent_dim), dtype=torch.bool) 294 | mask[:] = False 295 | for i in range(B*T): 296 | # print (self.spatial_emb[0, st_x[i]:ed_x[i], st_y[i]:ed_y[i]].size()) 297 | mask[i, st_x[i]:ed_x[i], st_y[i]:ed_y[i]] = True 298 | return mask[:B*T] 299 | 300 | def weight_init_drq(m): 301 | if isinstance(m, nn.Linear): 302 | nn.init.orthogonal_(m.weight.data) 303 | if hasattr(m.bias, 'data'): 304 | m.bias.data.fill_(0.0) 305 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 306 | gain = nn.init.calculate_gain('relu') 307 | nn.init.orthogonal_(m.weight.data, gain) 308 | if hasattr(m.bias, 'data'): 309 | m.bias.data.fill_(0.0) 310 | 311 | 312 | def soft_update_params(net, target_net, tau): 313 | for param, target_param in zip(net.parameters(), target_net.parameters()): 314 | target_param.data.copy_(tau * param.data + 315 | (1 - tau) * target_param.data) 316 | 317 | class TruncatedNormal(pyd.Normal): 318 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): 319 | super().__init__(loc, scale, validate_args=False) 320 | self.low = low 321 | self.high = high 322 | self.eps = eps 323 | 324 | def _clamp(self, x): 325 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) 326 | x = x - x.detach() + clamped_x.detach() 327 | return x 328 | 329 | def sample(self, clip=None, sample_shape=torch.Size()): 330 | shape = self._extended_shape(sample_shape) 331 | eps = _standard_normal(shape, 332 | dtype=self.loc.dtype, 333 | device=self.loc.device) 334 | eps *= self.scale 335 | if clip is not None: 336 | eps = torch.clamp(eps, -clip, clip) 337 | x = self.loc + eps 338 | return self._clamp(x) 339 | 340 | 341 | def 
schedule_drq(schdl, step): 342 | try: 343 | return float(schdl) 344 | except ValueError: 345 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl) 346 | if match: 347 | init, final, duration = [float(g) for g in match.groups()] 348 | mix = np.clip(step / duration, 0.0, 1.0) 349 | return (1.0 - mix) * init + mix * final 350 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl) 351 | if match: 352 | init, final1, duration1, final2, duration2 = [ 353 | float(g) for g in match.groups() 354 | ] 355 | if step <= duration1: 356 | mix = np.clip(step / duration1, 0.0, 1.0) 357 | return (1.0 - mix) * init + mix * final1 358 | else: 359 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0) 360 | return (1.0 - mix) * final1 + mix * final2 361 | raise NotImplementedError(schdl) 362 | 363 | def get_sugarl_reward_scale_robosuite(task_name) -> float: 364 | if task_name == "Lift": 365 | sugarl_reward_scale = 150/500 366 | elif task_name == "ToolHang": 367 | sugarl_reward_scale = 100/500 368 | else: 369 | sugarl_reward_scale = 100/500 370 | return sugarl_reward_scale 371 | 372 | 373 | def get_sugarl_reward_scale_dmc(domain_name, task_name) -> float: 374 | if domain_name == "ball_in_cup" and task_name == "catch": 375 | sugarl_reward_scale = 320/500 376 | elif domain_name == "cartpole" and task_name == "swingup": 377 | sugarl_reward_scale = 380/500 378 | elif domain_name == "cheetah" and task_name == "run": 379 | sugarl_reward_scale = 245/500 380 | elif domain_name == "dog" and task_name == "fetch": 381 | sugarl_reward_scale = 4.5/500 382 | elif domain_name == "finger" and task_name == "spin": 383 | sugarl_reward_scale = 290/500 384 | elif domain_name == "fish" and task_name == "swim": 385 | sugarl_reward_scale = 64/500 386 | elif domain_name == "reacher" and task_name == "easy": 387 | sugarl_reward_scale = 200/500 388 | elif domain_name == "walker" and task_name == "walk": 389 | sugarl_reward_scale = 290/500 390 | else: 391 | return 1. 
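# NOTE (annotation, not part of the original script): the per-task constants in
# these get_sugarl_reward_scale_* helpers appear to be rough typical episode
# returns divided by 500, presumably so the SUGARL reward term has a comparable
# magnitude across tasks; they are treated here as empirically chosen values.
# Example, read directly from the table above:
#   get_sugarl_reward_scale_dmc("walker", "walk")  # returns 290/500 == 0.58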
392 | 393 | return sugarl_reward_scale 394 | 395 | def get_sugarl_reward_scale_atari(game) -> float: 396 | base_scale = 4.0 397 | sugarl_reward_scale = 1/200 398 | if game in ["alien", "assault", "asterix", "battle_zone", "seaquest", "qbert", "private_eye", "road_runner"]: 399 | sugarl_reward_scale = 1/100 400 | elif game in ["kangaroo", "krull", "chopper_command", "demon_attack"]: 401 | sugarl_reward_scale = 1/200 402 | elif game in ["up_n_down", "frostbite", "ms_pacman", "amidar", "gopher", "boxing"]: 403 | sugarl_reward_scale = 1/50 404 | elif game in ["hero", "jamesbond", "kung_fu_master"]: 405 | sugarl_reward_scale = 1/25 406 | elif game in ["crazy_climber"]: 407 | sugarl_reward_scale = 1/20 408 | elif game in ["freeway"]: 409 | sugarl_reward_scale = 1/1600 410 | elif game in ["pong"]: 411 | sugarl_reward_scale = 1/800 412 | elif game in ["bank_heist"]: 413 | sugarl_reward_scale = 1/250 414 | elif game in ["breakout"]: 415 | sugarl_reward_scale = 1/35 416 | sugarl_reward_scale = sugarl_reward_scale * base_scale 417 | return sugarl_reward_scale 418 | -------------------------------------------------------------------------------- /agent/dqn_atari_wp_base_peripheral.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import random 5 | import time 6 | from collections import deque 7 | from itertools import product 8 | from distutils.util import strtobool 9 | 10 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__)))) 11 | os.environ["OMP_NUM_THREADS"] = "1" 12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 13 | import warnings 14 | warnings.filterwarnings("ignore", category=UserWarning) 15 | 16 | import gymnasium as gym 17 | from gymnasium.spaces import Discrete, Dict, Box 18 | import numpy as np 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | import torch.optim as optim 23 | from torchvision.transforms import Resize 24 | 25 | from common.buffer import ReplayBuffer 26 | from common.pvm_buffer import PVMBuffer 27 | from common.utils import get_timestr, seed_everything 28 | from torch.utils.tensorboard import SummaryWriter 29 | 30 | from active_gym import AtariFixedFovealPeripheralEnv, AtariEnvArgs 31 | 32 | 33 | 34 | def parse_args(): 35 | # fmt: off 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), 38 | help="the name of this experiment") 39 | parser.add_argument("--seed", type=int, default=1, 40 | help="seed of the experiment") 41 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 42 | help="if toggled, cuda will be enabled by default") 43 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, 44 | help="whether to capture videos of the agent performances (check out `videos` folder)") 45 | 46 | # env setting 47 | parser.add_argument("--env", type=str, default="breakout", 48 | help="the id of the environment") 49 | parser.add_argument("--env-num", type=int, default=1, 50 | help="# envs in parallel") 51 | parser.add_argument("--frame-stack", type=int, default=4, 52 | help="frame stack #") 53 | parser.add_argument("--action-repeat", type=int, default=4, 54 | help="action repeat #") 55 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True) 56 | 57 | # fov setting 58 | 
parser.add_argument("--fov-size", type=int, default=0) 59 | parser.add_argument("--fov-init-loc", type=int, default=0) 60 | parser.add_argument("--sensory-action-mode", type=str, default="absolute") 61 | parser.add_argument("--sensory-action-space", type=int, default=10) # ignored when sensory_action_mode="relative" 62 | parser.add_argument("--resize-to-full", default=False, action="store_true") 63 | parser.add_argument("--peripheral-res", type=int, default=20) 64 | # for discrete observ action 65 | parser.add_argument("--sensory-action-x-size", type=int, default=4) 66 | parser.add_argument("--sensory-action-y-size", type=int, default=4) 67 | # pvm setting 68 | parser.add_argument("--pvm-stack", type=int, default=1) 69 | 70 | # Algorithm specific arguments 71 | parser.add_argument("--total-timesteps", type=int, default=3000000, 72 | help="total timesteps of the experiments") 73 | parser.add_argument("--learning-rate", type=float, default=1e-4, 74 | help="the learning rate of the optimizer") 75 | parser.add_argument("--buffer-size", type=int, default=500000, 76 | help="the replay memory buffer size") 77 | parser.add_argument("--gamma", type=float, default=0.99, 78 | help="the discount factor gamma") 79 | parser.add_argument("--target-network-frequency", type=int, default=1000, 80 | help="the timesteps it takes to update the target network") 81 | parser.add_argument("--batch-size", type=int, default=32, 82 | help="the batch size of sample from the reply memory") 83 | parser.add_argument("--start-e", type=float, default=1, 84 | help="the starting epsilon for exploration") 85 | parser.add_argument("--end-e", type=float, default=0.01, 86 | help="the ending epsilon for exploration") 87 | parser.add_argument("--exploration-fraction", type=float, default=0.10, 88 | help="the fraction of `total-timesteps` it takes from start-e to go end-e") 89 | parser.add_argument("--learning-starts", type=int, default=80000, 90 | help="timestep to start learning") 91 | parser.add_argument("--train-frequency", type=int, default=4, 92 | help="the frequency of training") 93 | 94 | # eval args 95 | parser.add_argument("--eval-frequency", type=int, default=-1, 96 | help="eval frequency. default -1 is eval at the end.") 97 | parser.add_argument("--eval-num", type=int, default=10, 98 | help="eval frequency. 
default -1 is eval at the end.") 99 | args = parser.parse_args() 100 | # fmt: on 101 | return args 102 | 103 | 104 | def make_env(env_name, seed, **kwargs): 105 | def thunk(): 106 | env_args = AtariEnvArgs( 107 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs 108 | ) 109 | env = AtariFixedFovealPeripheralEnv(env_args) 110 | env.action_space.seed(seed) 111 | env.observation_space.seed(seed) 112 | return env 113 | 114 | return thunk 115 | 116 | 117 | # ALGO LOGIC: initialize agent here: 118 | class QNetwork(nn.Module): 119 | def __init__(self, env): 120 | super().__init__() 121 | if isinstance(env.single_action_space, Discrete): 122 | action_space_size = env.single_action_space.n 123 | elif isinstance(env.single_action_space, Dict): 124 | action_space_size = env.single_action_space["motor_action"].n 125 | self.network = nn.Sequential( 126 | nn.Conv2d(4, 32, 8, stride=4), 127 | nn.ReLU(), 128 | nn.Conv2d(32, 64, 4, stride=2), 129 | nn.ReLU(), 130 | nn.Conv2d(64, 64, 3, stride=1), 131 | nn.ReLU(), 132 | nn.Flatten(), 133 | nn.Linear(3136, 512), 134 | nn.ReLU(), 135 | nn.Linear(512, action_space_size), 136 | ) 137 | 138 | def forward(self, x): 139 | return self.network(x) 140 | 141 | 142 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int): 143 | slope = (end_e - start_e) / duration 144 | return max(slope * t + start_e, end_e) 145 | 146 | 147 | 148 | 149 | 150 | if __name__ == "__main__": 151 | args = parse_args() 152 | args.env = args.env.lower() 153 | run_name = f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}" 154 | run_dir = os.path.join("runs", args.exp_name) 155 | if not os.path.exists(run_dir): 156 | os.makedirs(run_dir, exist_ok=True) 157 | 158 | writer = SummaryWriter(os.path.join(run_dir, run_name)) 159 | writer.add_text( 160 | "hyperparameters", 161 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 162 | ) 163 | 164 | # TRY NOT TO MODIFY: seeding 165 | seed_everything(args.seed) 166 | 167 | 168 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 169 | 170 | # env setup 171 | envs = [] 172 | for i in range(args.env_num): 173 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat, 174 | fov_size=(args.fov_size, args.fov_size), 175 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc), 176 | peripheral_res=(args.peripheral_res, args.peripheral_res), 177 | sensory_action_mode=args.sensory_action_mode, 178 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space), 179 | resize_to_full=args.resize_to_full, 180 | clip_reward=args.clip_reward, 181 | mask_out=True)) 182 | # envs = gym.vector.AsyncVectorEnv(envs) 183 | envs = gym.vector.SyncVectorEnv(envs) 184 | 185 | resize = Resize((84, 84)) 186 | 187 | # get a discrete observ action space 188 | OBSERVATION_SIZE = (84, 84) 189 | sensory_action_set = [(0, 0)] 190 | 191 | q_network = QNetwork(envs).to(device) 192 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate) 193 | target_network = QNetwork(envs).to(device) 194 | target_network.load_state_dict(q_network.state_dict()) 195 | 196 | rb = ReplayBuffer( 197 | args.buffer_size, 198 | envs.single_observation_space, 199 | envs.single_action_space["motor_action"], 200 | device, 201 | n_envs=envs.num_envs, 202 | optimize_memory_usage=True, 203 | handle_timeout_termination=False, 204 | ) 205 | start_time = time.time() 206 | 207 | # TRY NOT TO MODIFY: start the game 208 
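# NOTE (annotation, not part of the original script): `PVMBuffer` below keeps the
# most recent `args.pvm_stack` observations and `get_obs(mode="stack_max")`
# aggregates them, presumably by an element-wise max over the stacked frames,
# into a single persistence-of-vision style observation; that aggregate is what
# is fed to the Q-network and written to the replay buffer. A minimal usage
# sketch under that assumption:
#   pvm = PVMBuffer(args.pvm_stack, (envs.num_envs, args.frame_stack) + OBSERVATION_SIZE)
#   pvm.append(obs)                      # push the newest frame stack
#   agg = pvm.get_obs(mode="stack_max")  # aggregated view, same shape as obs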
| obs, _ = envs.reset() 209 | global_transitions = 0 210 | fov_idx = random.randint(0, len(sensory_action_set)-1) 211 | pvm_buffer = PVMBuffer(args.pvm_stack, (envs.num_envs, args.frame_stack,)+OBSERVATION_SIZE) 212 | 213 | while global_transitions < args.total_timesteps: 214 | pvm_buffer.append(obs) 215 | pvm_obs = pvm_buffer.get_obs(mode="stack_max") 216 | # ALGO LOGIC: put action logic here 217 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_transitions) 218 | if random.random() < epsilon: 219 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 220 | motor_actions = np.array([actions[0]["motor_action"]]) 221 | sensory_actions = sensory_action_set[random.randint(0, len(sensory_action_set)-1)] 222 | else: 223 | q_values = q_network(resize(torch.from_numpy(pvm_obs)).to(device)) 224 | motor_actions = torch.argmax(q_values, dim=1).cpu().numpy() 225 | sensory_actions = sensory_action_set[fov_idx] 226 | 227 | # TRY NOT TO MODIFY: execute the game and log data. 228 | next_obs, rewards, dones, _, infos = envs.step({"motor_action": motor_actions, 229 | "sensory_action": [sensory_actions] }) 230 | fov_idx = (fov_idx + 1) % len(sensory_action_set) 231 | # print (global_step, infos) 232 | 233 | # TRY NOT TO MODIFY: record rewards for plotting purposes 234 | if "final_info" in infos: 235 | for idx, d in enumerate(dones): 236 | if d: 237 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]") 238 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions) 239 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions) 240 | writer.add_scalar("charts/epsilon", epsilon, global_transitions) 241 | break 242 | 243 | 244 | # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` 245 | real_next_obs = next_obs 246 | for idx, d in enumerate(dones): 247 | if d: 248 | real_next_obs[idx] = infos["final_observation"][idx] 249 | fov_idx = random.randint(0, len(sensory_action_set)-1) 250 | pvm_buffer_copy = pvm_buffer.copy() 251 | pvm_buffer_copy.append(real_next_obs) 252 | real_next_pvm_obs = pvm_buffer_copy.get_obs(mode="stack_max") 253 | rb.add(pvm_obs, real_next_pvm_obs, motor_actions, rewards, dones, {}) 254 | 255 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 256 | obs = next_obs 257 | 258 | # INC total transitions 259 | global_transitions += args.env_num 260 | 261 | 262 | obs_backup = obs # back obs 263 | # ALGO LOGIC: training. 
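# NOTE (annotation, not part of the original script): the block below is a
# standard DQN update on a replay sample `data`:
#   td_target = r + gamma * max_a' Q_target(s', a') * (1 - done)
# followed by an MSE regression of Q(s, a) onto that target. For instance, with
# gamma = 0.99, r = 1.0, done = 0 and max_a' Q_target(s', a') = 2.0, the target
# is 1.0 + 0.99 * 2.0 = 2.98. The sample size `args.batch_size // args.env_num`
# appears intended to keep the ratio of replayed samples to collected
# transitions the same as in the single-env case, since `global_transitions`
# advances by `args.env_num` per loop iteration.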
264 | if global_transitions > args.learning_starts: 265 | if global_transitions % args.train_frequency == 0: 266 | data = rb.sample(args.batch_size // args.env_num) # counter-balance the true global transitions used for training 267 | with torch.no_grad(): 268 | target_max, _ = target_network(resize(data.next_observations)).max(dim=1) 269 | td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten()) 270 | old_val = q_network(resize(data.observations)).gather(1, data.actions).squeeze() 271 | loss = F.mse_loss(td_target, old_val) 272 | 273 | if global_transitions % 100 == 0: 274 | writer.add_scalar("losses/td_loss", loss, global_transitions) 275 | writer.add_scalar("losses/q_values", old_val.mean().item(), global_transitions) 276 | # print("SPS:", int(global_step / (time.time() - start_time))) 277 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions) 278 | 279 | # optimize the model 280 | optimizer.zero_grad() 281 | loss.backward() 282 | optimizer.step() 283 | 284 | # update the target network 285 | if (global_transitions // args.env_num) % args.target_network_frequency == 0: 286 | target_network.load_state_dict(q_network.state_dict()) 287 | 288 | # evaluation 289 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \ 290 | (global_transitions >= args.total_timesteps): 291 | q_network.eval() 292 | 293 | eval_episodic_returns, eval_episodic_lengths = [], [] 294 | 295 | for eval_ep in range(args.eval_num): 296 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat, 297 | fov_size=(args.fov_size, args.fov_size), 298 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc), 299 | peripheral_res=(args.peripheral_res, args.peripheral_res), 300 | sensory_action_mode=args.sensory_action_mode, 301 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space), 302 | resize_to_full=args.resize_to_full, 303 | clip_reward=args.clip_reward, 304 | training=False, 305 | mask_out=True, 306 | record=args.capture_video)] 307 | eval_env = gym.vector.SyncVectorEnv(eval_env) 308 | obs, _ = eval_env.reset() 309 | done = False 310 | fov_idx = random.randint(0, len(sensory_action_set)-1) 311 | pvm_buffer = PVMBuffer(args.pvm_stack, (eval_env.num_envs, args.frame_stack,)+OBSERVATION_SIZE) 312 | while not done: 313 | pvm_buffer.append(obs) 314 | pvm_obs = pvm_buffer.get_obs(mode="stack_max") 315 | q_values = q_network(resize(torch.from_numpy(pvm_obs)).to(device)) 316 | motor_actions = torch.argmax(q_values, dim=1).cpu().numpy() 317 | next_obs, rewards, dones, _, infos = eval_env.step({"motor_action": motor_actions, 318 | "sensory_action": [sensory_action_set[fov_idx]]}) 319 | obs = next_obs 320 | fov_idx = (fov_idx + 1) % len(sensory_action_set) 321 | done = dones[0] 322 | if done: 323 | eval_episodic_returns.append(infos['final_info'][0]["reward"]) 324 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"]) 325 | fov_idx = random.randint(0, len(sensory_action_set)-1) 326 | if args.capture_video: 327 | record_file_dir = os.path.join("recordings", args.exp_name, os.path.basename(__file__).rstrip(".py")) 328 | os.makedirs(record_file_dir, exist_ok=True) 329 | record_file_fn = f"seed{args.seed}_step{global_transitions:07d}_record.pt" 330 | eval_env.envs[0].save_record_to_file(os.path.join(record_file_dir, record_file_fn)) 331 | 332 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), 
global_transitions) 333 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions) 334 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]") 335 | 336 | q_network.train() 337 | obs = obs_backup # restore obs if eval occurs 338 | 339 | 340 | envs.close() 341 | eval_env.close() 342 | writer.close() 343 | -------------------------------------------------------------------------------- /agent/dqn_atari_wp_single_policy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import random 5 | import time 6 | from collections import deque 7 | from itertools import product 8 | from distutils.util import strtobool 9 | 10 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__)))) 11 | os.environ["OMP_NUM_THREADS"] = "1" 12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 13 | import warnings 14 | warnings.filterwarnings("ignore", category=UserWarning) 15 | 16 | import gymnasium as gym 17 | from gymnasium.spaces import Discrete, Dict, Box 18 | import numpy as np 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | import torch.optim as optim 23 | from torchvision.transforms import Resize 24 | 25 | from common.buffer import ReplayBuffer 26 | from common.pvm_buffer import PVMBuffer 27 | from common.utils import get_timestr, seed_everything 28 | from torch.utils.tensorboard import SummaryWriter 29 | 30 | from active_gym import AtariFixedFovealEnv, AtariEnvArgs 31 | 32 | 33 | 34 | def parse_args(): 35 | # fmt: off 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), 38 | help="the name of this experiment") 39 | parser.add_argument("--seed", type=int, default=1, 40 | help="seed of the experiment") 41 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 42 | help="if toggled, cuda will be enabled by default") 43 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, 44 | help="whether to capture videos of the agent performances (check out `videos` folder)") 45 | 46 | # env setting 47 | parser.add_argument("--env", type=str, default="breakout", 48 | help="the id of the environment") 49 | parser.add_argument("--env-num", type=int, default=1, 50 | help="# envs in parallel") 51 | parser.add_argument("--frame-stack", type=int, default=4, 52 | help="frame stack #") 53 | parser.add_argument("--action-repeat", type=int, default=4, 54 | help="action repeat #") 55 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True) 56 | 57 | # fov setting 58 | parser.add_argument("--fov-size", type=int, default=50) 59 | parser.add_argument("--fov-init-loc", type=int, default=0) 60 | parser.add_argument("--sensory-action-mode", type=str, default="relative") 61 | parser.add_argument("--sensory-action-space", type=int, default=10) # ignored when sensory_action_mode="relative" 62 | parser.add_argument("--resize-to-full", default=False, action="store_true") 63 | # for discrete observ action 64 | parser.add_argument("--sensory-action-x-size", type=int, default=4) 65 | parser.add_argument("--sensory-action-y-size", type=int, default=4) 66 | # 
pvm setting 67 | parser.add_argument("--pvm-stack", type=int, default=3) 68 | 69 | # Algorithm specific arguments 70 | parser.add_argument("--total-timesteps", type=int, default=3000000, 71 | help="total timesteps of the experiments") 72 | parser.add_argument("--learning-rate", type=float, default=1e-4, 73 | help="the learning rate of the optimizer") 74 | parser.add_argument("--buffer-size", type=int, default=500000, 75 | help="the replay memory buffer size") 76 | parser.add_argument("--gamma", type=float, default=0.99, 77 | help="the discount factor gamma") 78 | parser.add_argument("--target-network-frequency", type=int, default=1000, 79 | help="the timesteps it takes to update the target network") 80 | parser.add_argument("--batch-size", type=int, default=32, 81 | help="the batch size of sample from the reply memory") 82 | parser.add_argument("--start-e", type=float, default=1, 83 | help="the starting epsilon for exploration") 84 | parser.add_argument("--end-e", type=float, default=0.01, 85 | help="the ending epsilon for exploration") 86 | parser.add_argument("--exploration-fraction", type=float, default=0.10, 87 | help="the fraction of `total-timesteps` it takes from start-e to go end-e") 88 | parser.add_argument("--learning-starts", type=int, default=80000, 89 | help="timestep to start learning") 90 | parser.add_argument("--train-frequency", type=int, default=4, 91 | help="the frequency of training") 92 | 93 | # eval args 94 | parser.add_argument("--eval-frequency", type=int, default=-1, 95 | help="eval frequency. default -1 is eval at the end.") 96 | parser.add_argument("--eval-num", type=int, default=10, 97 | help="eval frequency. default -1 is eval at the end.") 98 | args = parser.parse_args() 99 | # fmt: on 100 | return args 101 | 102 | 103 | def make_env(env_name, seed, **kwargs): 104 | def thunk(): 105 | env_args = AtariEnvArgs( 106 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs 107 | ) 108 | env = AtariFixedFovealEnv(env_args) 109 | env.action_space.seed(seed) 110 | env.observation_space.seed(seed) 111 | return env 112 | 113 | return thunk 114 | 115 | 116 | # ALGO LOGIC: initialize agent here: 117 | class QNetwork(nn.Module): 118 | def __init__(self, env, override_action_set=None): 119 | super().__init__() 120 | if override_action_set: 121 | action_space_size = override_action_set.n 122 | else: 123 | if isinstance(env.single_action_space, Discrete): 124 | action_space_size = env.single_action_space.n 125 | elif isinstance(env.single_action_space, Dict): 126 | action_space_size = env.single_action_space["motor_action"].n 127 | self.network = nn.Sequential( 128 | nn.Conv2d(4, 32, 8, stride=4), 129 | nn.ReLU(), 130 | nn.Conv2d(32, 64, 4, stride=2), 131 | nn.ReLU(), 132 | nn.Conv2d(64, 64, 3, stride=1), 133 | nn.ReLU(), 134 | nn.Flatten(), 135 | nn.Linear(3136, 512), 136 | nn.ReLU(), 137 | nn.Linear(512, action_space_size), 138 | ) 139 | 140 | def forward(self, x): 141 | return self.network(x) 142 | 143 | 144 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int): 145 | slope = (end_e - start_e) / duration 146 | return max(slope * t + start_e, end_e) 147 | 148 | 149 | 150 | 151 | 152 | if __name__ == "__main__": 153 | args = parse_args() 154 | args.env = args.env.lower() 155 | run_name = f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}" 156 | run_dir = os.path.join("runs", args.exp_name) 157 | if not os.path.exists(run_dir): 158 | os.makedirs(run_dir, exist_ok=True) 159 | 160 | writer = SummaryWriter(os.path.join(run_dir, 
run_name)) 161 | writer.add_text( 162 | "hyperparameters", 163 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 164 | ) 165 | 166 | # TRY NOT TO MODIFY: seeding 167 | seed_everything(args.seed) 168 | 169 | 170 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 171 | 172 | # get a discrete observ action space 173 | OBSERVATION_SIZE = (84, 84) 174 | observ_x_max, observ_y_max = OBSERVATION_SIZE[0]-args.fov_size, OBSERVATION_SIZE[1]-args.fov_size 175 | sensory_action_step = (observ_x_max//args.sensory_action_x_size, 176 | observ_y_max//args.sensory_action_y_size) 177 | sensory_action_set = [(-sensory_action_step[0], 0), 178 | (sensory_action_step[0], 0), 179 | (0, 0), 180 | (0, -sensory_action_step[1]), 181 | (0, sensory_action_step[1])] 182 | 183 | # env setup 184 | envs = [] 185 | for i in range(args.env_num): 186 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat, 187 | fov_size=(args.fov_size, args.fov_size), 188 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc), 189 | sensory_action_mode=args.sensory_action_mode, 190 | sensory_action_space=(-max(sensory_action_step), max(sensory_action_step)), 191 | resize_to_full=args.resize_to_full, 192 | clip_reward=args.clip_reward, 193 | mask_out=True)) 194 | # envs = gym.vector.AsyncVectorEnv(envs) 195 | envs = gym.vector.SyncVectorEnv(envs) 196 | 197 | resize = Resize((84, 84)) 198 | 199 | # make motor sensory joint action space 200 | motor_action_set = list(range(envs.single_action_space["motor_action"].n)) 201 | motor_sensory_joint_action_set = [] 202 | for ma in motor_action_set: 203 | for sa in sensory_action_set: 204 | motor_sensory_joint_action_set.append((ma, *sa)) 205 | motor_sensory_joint_action_space = Discrete(len(motor_sensory_joint_action_set), seed=args.seed) 206 | 207 | # make a method to seperate joint action 208 | def seperate_motor_sensory_joint_action(msas: np.ndarray): 209 | mas, sas = [], [] 210 | for msa in msas: 211 | msa = motor_sensory_joint_action_set[msa] 212 | mas.append(msa[0]) 213 | sas.append(msa[1:]) 214 | mas = np.array(mas) 215 | sas = np.array(sas) 216 | return mas, sas 217 | 218 | 219 | 220 | q_network = QNetwork(envs, motor_sensory_joint_action_space).to(device) 221 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate) 222 | target_network = QNetwork(envs, motor_sensory_joint_action_space).to(device) 223 | target_network.load_state_dict(q_network.state_dict()) 224 | 225 | rb = ReplayBuffer( 226 | args.buffer_size, 227 | envs.single_observation_space, 228 | motor_sensory_joint_action_space, 229 | device, 230 | n_envs=envs.num_envs, 231 | optimize_memory_usage=True, 232 | handle_timeout_termination=False, 233 | ) 234 | start_time = time.time() 235 | 236 | # TRY NOT TO MODIFY: start the game 237 | obs, _ = envs.reset() 238 | global_transitions = 0 239 | pvm_buffer = PVMBuffer(args.pvm_stack, (envs.num_envs, args.frame_stack,)+OBSERVATION_SIZE) 240 | 241 | while global_transitions < args.total_timesteps: 242 | pvm_buffer.append(obs) 243 | pvm_obs = pvm_buffer.get_obs(mode="stack_max") 244 | # ALGO LOGIC: put action logic here 245 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_transitions) 246 | if random.random() < epsilon: 247 | actions = np.array([motor_sensory_joint_action_space.sample() for _ in range(envs.num_envs)]) 248 | else: 249 | q_values = 
q_network(resize(torch.from_numpy(pvm_obs)).to(device)) 250 | actions = torch.argmax(q_values, dim=1).cpu().numpy() 251 | 252 | # print (actions) 253 | motor_actions, sensory_actions = seperate_motor_sensory_joint_action(actions) 254 | 255 | # TRY NOT TO MODIFY: execute the game and log data. 256 | next_obs, rewards, dones, _, infos = envs.step({"motor_action": motor_actions, 257 | "sensory_action": sensory_actions }) 258 | 259 | # TRY NOT TO MODIFY: record rewards for plotting purposes 260 | if "final_info" in infos: 261 | for idx, d in enumerate(dones): 262 | if d: 263 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]") 264 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions) 265 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions) 266 | writer.add_scalar("charts/epsilon", epsilon, global_transitions) 267 | break 268 | 269 | # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` 270 | real_next_obs = next_obs 271 | for idx, d in enumerate(dones): 272 | if d: 273 | real_next_obs[idx] = infos["final_observation"][idx] 274 | fov_idx = random.randint(0, len(sensory_action_set)-1) 275 | pvm_buffer_copy = pvm_buffer.copy() 276 | pvm_buffer_copy.append(real_next_obs) 277 | real_next_pvm_obs = pvm_buffer_copy.get_obs(mode="stack_max") 278 | rb.add(pvm_obs, real_next_pvm_obs, actions, rewards, dones, {}) 279 | 280 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 281 | obs = next_obs 282 | 283 | # INC total transitions 284 | global_transitions += args.env_num 285 | 286 | 287 | obs_backup = obs # back obs 288 | # ALGO LOGIC: training. 289 | if global_transitions > args.learning_starts: 290 | if global_transitions % args.train_frequency == 0: 291 | data = rb.sample(args.batch_size // args.env_num) # counter-balance the true global transitions used for training 292 | with torch.no_grad(): 293 | target_max, _ = target_network(resize(data.next_observations)).max(dim=1) 294 | td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten()) 295 | old_val = q_network(resize(data.observations)).gather(1, data.actions).squeeze() 296 | loss = F.mse_loss(td_target, old_val) 297 | 298 | if global_transitions % 100 == 0: 299 | writer.add_scalar("losses/td_loss", loss, global_transitions) 300 | writer.add_scalar("losses/q_values", old_val.mean().item(), global_transitions) 301 | # print("SPS:", int(global_step / (time.time() - start_time))) 302 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions) 303 | 304 | # optimize the model 305 | optimizer.zero_grad() 306 | loss.backward() 307 | optimizer.step() 308 | 309 | # update the target network 310 | if (global_transitions // args.env_num) % args.target_network_frequency == 0: 311 | target_network.load_state_dict(q_network.state_dict()) 312 | 313 | # evaluation 314 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \ 315 | (global_transitions >= args.total_timesteps): 316 | q_network.eval() 317 | 318 | eval_episodic_returns, eval_episodic_lengths = [], [] 319 | 320 | for eval_ep in range(args.eval_num): 321 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat, 322 | fov_size=(args.fov_size, args.fov_size), 323 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc), 324 | 
sensory_action_mode=args.sensory_action_mode, 325 | sensory_action_space=(-max(sensory_action_step), max(sensory_action_step)), 326 | resize_to_full=args.resize_to_full, 327 | clip_reward=args.clip_reward, 328 | training=False, 329 | mask_out=True)] 330 | eval_env = gym.vector.SyncVectorEnv(eval_env) 331 | obs, _ = eval_env.reset() 332 | done = False 333 | pvm_buffer = PVMBuffer(args.pvm_stack, (eval_env.num_envs, args.frame_stack,)+OBSERVATION_SIZE) 334 | while not done: 335 | pvm_buffer.append(obs) 336 | pvm_obs = pvm_buffer.get_obs(mode="stack_max") 337 | q_values = q_network(resize(torch.from_numpy(pvm_obs)).to(device)) 338 | actions = torch.argmax(q_values, dim=1).cpu().numpy() 339 | motor_actions, sensory_actions = seperate_motor_sensory_joint_action(actions) 340 | next_obs, rewards, dones, _, infos = eval_env.step({"motor_action": motor_actions, 341 | "sensory_action": sensory_actions}) 342 | obs = next_obs 343 | done = dones[0] 344 | if done: 345 | eval_episodic_returns.append(infos['final_info'][0]["reward"]) 346 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"]) 347 | 348 | if args.capture_video: 349 | record_file_dir = os.path.join("recordings", args.exp_name, os.path.basename(__file__).rstrip(".py")) 350 | os.makedirs(record_file_dir, exist_ok=True) 351 | record_file_fn = f"seed{args.seed}_step{global_transitions:07d}_eval{eval_ep:02d}_record.pt" 352 | eval_env.envs[0].save_record_to_file(os.path.join(record_file_dir, record_file_fn)) 353 | if eval_ep == 0: 354 | model_file_dir = os.path.join("trained_models", args.exp_name, os.path.basename(__file__).rstrip(".py"), args.env) 355 | os.makedirs(model_file_dir, exist_ok=True) 356 | model_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_model.pt" 357 | torch.save({"sfn": None, "q": q_network.state_dict()}, os.path.join(model_file_dir, model_fn)) 358 | 359 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions) 360 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions) 361 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]") 362 | 363 | q_network.train() 364 | obs = obs_backup # restore obs if eval occurs 365 | 366 | 367 | envs.close() 368 | eval_env.close() 369 | writer.close() -------------------------------------------------------------------------------- /agent/dqn_atari_single_policy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import random 5 | import time 6 | from collections import deque 7 | from itertools import product 8 | from distutils.util import strtobool 9 | 10 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__)))) 11 | os.environ["OMP_NUM_THREADS"] = "1" 12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 13 | import warnings 14 | warnings.filterwarnings("ignore", category=UserWarning) 15 | 16 | import gymnasium as gym 17 | from gymnasium.spaces import Discrete, Dict, Box 18 | import numpy as np 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | import torch.optim as optim 23 | from torchvision.transforms import Resize 24 | 25 | from common.buffer import ReplayBuffer 26 | from common.pvm_buffer import PVMBuffer 27 | from common.utils import get_timestr, seed_everything 28 | 
from torch.utils.tensorboard import SummaryWriter 29 | 30 | from active_gym import AtariFixedFovealEnv, AtariEnvArgs 31 | 32 | 33 | 34 | def parse_args(): 35 | # fmt: off 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), 38 | help="the name of this experiment") 39 | parser.add_argument("--seed", type=int, default=1, 40 | help="seed of the experiment") 41 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 42 | help="if toggled, cuda will be enabled by default") 43 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, 44 | help="whether to capture videos of the agent performances (check out `videos` folder)") 45 | 46 | # env setting 47 | parser.add_argument("--env", type=str, default="breakout", 48 | help="the id of the environment") 49 | parser.add_argument("--env-num", type=int, default=1, 50 | help="# envs in parallel") 51 | parser.add_argument("--frame-stack", type=int, default=4, 52 | help="frame stack #") 53 | parser.add_argument("--action-repeat", type=int, default=4, 54 | help="action repeat #") 55 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True) 56 | 57 | # fov setting 58 | parser.add_argument("--fov-size", type=int, default=50) 59 | parser.add_argument("--fov-init-loc", type=int, default=0) 60 | parser.add_argument("--sensory-action-mode", type=str, default="relative") 61 | parser.add_argument("--sensory-action-space", type=int, default=10) # ignored when sensory_action_mode="relative" 62 | parser.add_argument("--resize-to-full", default=False, action="store_true") 63 | # for discrete observ action 64 | parser.add_argument("--sensory-action-x-size", type=int, default=4) 65 | parser.add_argument("--sensory-action-y-size", type=int, default=4) 66 | # pvm setting 67 | parser.add_argument("--pvm-stack", type=int, default=3) 68 | 69 | # Algorithm specific arguments 70 | parser.add_argument("--total-timesteps", type=int, default=3000000, 71 | help="total timesteps of the experiments") 72 | parser.add_argument("--learning-rate", type=float, default=1e-4, 73 | help="the learning rate of the optimizer") 74 | parser.add_argument("--buffer-size", type=int, default=500000, 75 | help="the replay memory buffer size") 76 | parser.add_argument("--gamma", type=float, default=0.99, 77 | help="the discount factor gamma") 78 | parser.add_argument("--target-network-frequency", type=int, default=1000, 79 | help="the timesteps it takes to update the target network") 80 | parser.add_argument("--batch-size", type=int, default=32, 81 | help="the batch size of sample from the reply memory") 82 | parser.add_argument("--start-e", type=float, default=1, 83 | help="the starting epsilon for exploration") 84 | parser.add_argument("--end-e", type=float, default=0.01, 85 | help="the ending epsilon for exploration") 86 | parser.add_argument("--exploration-fraction", type=float, default=0.10, 87 | help="the fraction of `total-timesteps` it takes from start-e to go end-e") 88 | parser.add_argument("--learning-starts", type=int, default=80000, 89 | help="timestep to start learning") 90 | parser.add_argument("--train-frequency", type=int, default=4, 91 | help="the frequency of training") 92 | 93 | # eval args 94 | parser.add_argument("--eval-frequency", type=int, default=-1, 95 | help="eval frequency. 
default -1 is eval at the end.") 96 | parser.add_argument("--eval-num", type=int, default=10, 97 | help="eval frequency. default -1 is eval at the end.") 98 | args = parser.parse_args() 99 | # fmt: on 100 | return args 101 | 102 | 103 | def make_env(env_name, seed, **kwargs): 104 | def thunk(): 105 | env_args = AtariEnvArgs( 106 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs 107 | ) 108 | env = AtariFixedFovealEnv(env_args) 109 | env.action_space.seed(seed) 110 | env.observation_space.seed(seed) 111 | return env 112 | 113 | return thunk 114 | 115 | 116 | # ALGO LOGIC: initialize agent here: 117 | class QNetwork(nn.Module): 118 | def __init__(self, env, override_action_set=None): 119 | super().__init__() 120 | if override_action_set: 121 | action_space_size = override_action_set.n 122 | else: 123 | if isinstance(env.single_action_space, Discrete): 124 | action_space_size = env.single_action_space.n 125 | elif isinstance(env.single_action_space, Dict): 126 | action_space_size = env.single_action_space["motor_action"].n 127 | self.network = nn.Sequential( 128 | nn.Conv2d(4, 32, 8, stride=4), 129 | nn.ReLU(), 130 | nn.Conv2d(32, 64, 4, stride=2), 131 | nn.ReLU(), 132 | nn.Conv2d(64, 64, 3, stride=1), 133 | nn.ReLU(), 134 | nn.Flatten(), 135 | nn.Linear(3136, 512), 136 | nn.ReLU(), 137 | nn.Linear(512, action_space_size), 138 | ) 139 | 140 | def forward(self, x): 141 | return self.network(x) 142 | 143 | 144 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int): 145 | slope = (end_e - start_e) / duration 146 | return max(slope * t + start_e, end_e) 147 | 148 | 149 | 150 | 151 | 152 | if __name__ == "__main__": 153 | args = parse_args() 154 | args.env = args.env.lower() 155 | run_name = f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}" 156 | run_dir = os.path.join("runs", args.exp_name) 157 | if not os.path.exists(run_dir): 158 | os.makedirs(run_dir, exist_ok=True) 159 | 160 | writer = SummaryWriter(os.path.join(run_dir, run_name)) 161 | writer.add_text( 162 | "hyperparameters", 163 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 164 | ) 165 | 166 | # TRY NOT TO MODIFY: seeding 167 | seed_everything(args.seed) 168 | 169 | 170 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 171 | 172 | # get a discrete observ action space 173 | OBSERVATION_SIZE = (84, 84) 174 | observ_x_max, observ_y_max = OBSERVATION_SIZE[0]-args.fov_size, OBSERVATION_SIZE[1]-args.fov_size 175 | sensory_action_step = (observ_x_max//args.sensory_action_x_size, 176 | observ_y_max//args.sensory_action_y_size) 177 | sensory_action_set = [(-sensory_action_step[0], 0), 178 | (sensory_action_step[0], 0), 179 | (0, 0), 180 | (0, -sensory_action_step[1]), 181 | (0, sensory_action_step[1])] 182 | 183 | # env setup 184 | envs = [] 185 | for i in range(args.env_num): 186 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat, 187 | fov_size=(args.fov_size, args.fov_size), 188 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc), 189 | sensory_action_mode=args.sensory_action_mode, 190 | sensory_action_space=(-max(sensory_action_step), max(sensory_action_step)), 191 | resize_to_full=args.resize_to_full, 192 | clip_reward=args.clip_reward, 193 | mask_out=True)) 194 | # envs = gym.vector.AsyncVectorEnv(envs) 195 | envs = gym.vector.SyncVectorEnv(envs) 196 | 197 | resize = Resize((84, 84)) 198 | 199 | # make motor sensory joint action space 200 | 
motor_action_set = list(range(envs.single_action_space["motor_action"].n)) 201 | motor_sensory_joint_action_set = [] 202 | for ma in motor_action_set: 203 | for sa in sensory_action_set: 204 | motor_sensory_joint_action_set.append((ma, *sa)) 205 | motor_sensory_joint_action_space = Discrete(len(motor_sensory_joint_action_set), seed=args.seed) 206 | 207 | # make a method to seperate joint action 208 | def seperate_motor_sensory_joint_action(msas: np.ndarray): 209 | mas, sas = [], [] 210 | for msa in msas: 211 | msa = motor_sensory_joint_action_set[msa] 212 | mas.append(msa[0]) 213 | sas.append(msa[1:]) 214 | mas = np.array(mas) 215 | sas = np.array(sas) 216 | return mas, sas 217 | 218 | 219 | 220 | q_network = QNetwork(envs, motor_sensory_joint_action_space).to(device) 221 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate) 222 | target_network = QNetwork(envs, motor_sensory_joint_action_space).to(device) 223 | target_network.load_state_dict(q_network.state_dict()) 224 | 225 | rb = ReplayBuffer( 226 | args.buffer_size, 227 | envs.single_observation_space, 228 | motor_sensory_joint_action_space, 229 | device, 230 | n_envs=envs.num_envs, 231 | optimize_memory_usage=True, 232 | handle_timeout_termination=False, 233 | ) 234 | start_time = time.time() 235 | 236 | # TRY NOT TO MODIFY: start the game 237 | obs, _ = envs.reset() 238 | global_transitions = 0 239 | pvm_buffer = PVMBuffer(args.pvm_stack, (envs.num_envs, args.frame_stack,)+OBSERVATION_SIZE) 240 | 241 | while global_transitions < args.total_timesteps: 242 | pvm_buffer.append(obs) 243 | pvm_obs = pvm_buffer.get_obs(mode="stack_max") 244 | # ALGO LOGIC: put action logic here 245 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_transitions) 246 | if random.random() < epsilon: 247 | actions = np.array([motor_sensory_joint_action_space.sample() for _ in range(envs.num_envs)]) 248 | else: 249 | q_values = q_network(resize(torch.from_numpy(pvm_obs)).to(device)) 250 | actions = torch.argmax(q_values, dim=1).cpu().numpy() 251 | 252 | # print (actions) 253 | motor_actions, sensory_actions = seperate_motor_sensory_joint_action(actions) 254 | 255 | # TRY NOT TO MODIFY: execute the game and log data. 
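[Editor's note] The joint motor–sensory action space built above is simply the Cartesian product of the discrete motor actions and the five relative fovea offsets, with one flat index per pair; `seperate_motor_sensory_joint_action` inverts that mapping. A minimal standalone sketch of the encode/decode round trip, using toy sizes and hypothetical `toy_*` names that are not part of this file:

from itertools import product

toy_motor_set = list(range(4))                                   # e.g. 4 Atari motor actions
toy_sensory_set = [(-8, 0), (8, 0), (0, 0), (0, -8), (0, 8)]     # relative fovea offsets
toy_joint_set = [(ma, *sa) for ma, sa in product(toy_motor_set, toy_sensory_set)]

def decode_joint_action(joint_idx):
    # Split one flat index over the joint set back into (motor_action, (dx, dy)).
    ma, dx, dy = toy_joint_set[joint_idx]
    return ma, (dx, dy)

assert len(toy_joint_set) == 4 * 5
assert decode_joint_action(7) == (1, (0, 0))                     # motor action 1, "stay" offset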
256 | next_obs, rewards, dones, _, infos = envs.step({"motor_action": motor_actions, 257 | "sensory_action": sensory_actions }) 258 | 259 | # TRY NOT TO MODIFY: record rewards for plotting purposes 260 | if "final_info" in infos: 261 | for idx, d in enumerate(dones): 262 | if d: 263 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]") 264 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions) 265 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions) 266 | writer.add_scalar("charts/epsilon", epsilon, global_transitions) 267 | break 268 | 269 | # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` 270 | real_next_obs = next_obs 271 | for idx, d in enumerate(dones): 272 | if d: 273 | real_next_obs[idx] = infos["final_observation"][idx] 274 | fov_idx = random.randint(0, len(sensory_action_set)-1) 275 | pvm_buffer_copy = pvm_buffer.copy() 276 | pvm_buffer_copy.append(real_next_obs) 277 | real_next_pvm_obs = pvm_buffer_copy.get_obs(mode="stack_max") 278 | rb.add(pvm_obs, real_next_pvm_obs, actions, rewards, dones, {}) 279 | 280 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 281 | obs = next_obs 282 | 283 | # INC total transitions 284 | global_transitions += args.env_num 285 | 286 | 287 | obs_backup = obs # back obs 288 | # ALGO LOGIC: training. 289 | if global_transitions > args.learning_starts: 290 | if global_transitions % args.train_frequency == 0: 291 | data = rb.sample(args.batch_size // args.env_num) # counter-balance the true global transitions used for training 292 | with torch.no_grad(): 293 | target_max, _ = target_network(resize(data.next_observations)).max(dim=1) 294 | td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten()) 295 | old_val = q_network(resize(data.observations)).gather(1, data.actions).squeeze() 296 | loss = F.mse_loss(td_target, old_val) 297 | 298 | if global_transitions % 100 == 0: 299 | writer.add_scalar("losses/td_loss", loss, global_transitions) 300 | writer.add_scalar("losses/q_values", old_val.mean().item(), global_transitions) 301 | # print("SPS:", int(global_step / (time.time() - start_time))) 302 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions) 303 | 304 | # optimize the model 305 | optimizer.zero_grad() 306 | loss.backward() 307 | optimizer.step() 308 | 309 | # update the target network 310 | if (global_transitions // args.env_num) % args.target_network_frequency == 0: 311 | target_network.load_state_dict(q_network.state_dict()) 312 | 313 | # evaluation 314 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \ 315 | (global_transitions >= args.total_timesteps): 316 | q_network.eval() 317 | 318 | eval_episodic_returns, eval_episodic_lengths = [], [] 319 | 320 | for eval_ep in range(args.eval_num): 321 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat, 322 | fov_size=(args.fov_size, args.fov_size), 323 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc), 324 | sensory_action_mode=args.sensory_action_mode, 325 | sensory_action_space=(-max(sensory_action_step), max(sensory_action_step)), 326 | resize_to_full=args.resize_to_full, 327 | clip_reward=args.clip_reward, 328 | training=False, 329 | mask_out=True)] 330 | eval_env = gym.vector.SyncVectorEnv(eval_env) 331 | obs, _ = 
eval_env.reset() 332 | done = False 333 | pvm_buffer = PVMBuffer(args.pvm_stack, (eval_env.num_envs, args.frame_stack,)+OBSERVATION_SIZE) 334 | while not done: 335 | pvm_buffer.append(obs) 336 | pvm_obs = pvm_buffer.get_obs(mode="stack_max") 337 | q_values = q_network(resize(torch.from_numpy(pvm_obs)).to(device)) 338 | actions = torch.argmax(q_values, dim=1).cpu().numpy() 339 | motor_actions, sensory_actions = seperate_motor_sensory_joint_action(actions) 340 | next_obs, rewards, dones, _, infos = eval_env.step({"motor_action": motor_actions, 341 | "sensory_action": sensory_actions}) 342 | obs = next_obs 343 | done = dones[0] 344 | if done: 345 | eval_episodic_returns.append(infos['final_info'][0]["reward"]) 346 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"]) 347 | 348 | if args.capture_video: 349 | record_file_dir = os.path.join("recordings", args.exp_name, os.path.basename(__file__).rstrip(".py")) 350 | os.makedirs(record_file_dir, exist_ok=True) 351 | record_file_fn = f"seed{args.seed}_step{global_transitions:07d}_eval{eval_ep:02d}_record.pt" 352 | eval_env.envs[0].save_record_to_file(os.path.join(record_file_dir, record_file_fn)) 353 | if eval_ep == 0: 354 | model_file_dir = os.path.join("trained_models", args.exp_name, os.path.basename(__file__).rstrip(".py"), args.env) 355 | os.makedirs(model_file_dir, exist_ok=True) 356 | model_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_model.pt" 357 | torch.save({"sfn": None, "q": q_network.state_dict()}, os.path.join(model_file_dir, model_fn)) 358 | 359 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions) 360 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions) 361 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]") 362 | 363 | q_network.train() 364 | obs = obs_backup # restore obs if eval occurs 365 | 366 | 367 | envs.close() 368 | eval_env.close() 369 | writer.close() 370 | -------------------------------------------------------------------------------- /agent/sac_atari_base.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import random 5 | import time 6 | from distutils.util import strtobool 7 | 8 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__)))) 9 | os.environ["OMP_NUM_THREADS"] = "1" 10 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 11 | import warnings 12 | warnings.filterwarnings("ignore", category=UserWarning) 13 | 14 | import gymnasium as gym 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import torch.optim as optim 20 | from torch.distributions.categorical import Categorical 21 | from common.buffer import ReplayBuffer 22 | from common.utils import get_timestr, seed_everything 23 | from torch.utils.tensorboard import SummaryWriter 24 | 25 | from active_gym import AtariBaseEnv, AtariEnvArgs 26 | 27 | 28 | def parse_args(): 29 | # fmt: off 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__), 32 | help="the name of this experiment") 33 | parser.add_argument("--seed", type=int, default=1, 34 | help="seed of the experiment") 35 | parser.add_argument("--cuda", type=lambda x: 
bool(strtobool(x)), default=True, nargs="?", const=True, 36 | help="if toggled, cuda will be enabled by default") 37 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, 38 | help="whether to capture videos of the agent performances (check out `videos` folder)") 39 | 40 | # env setting 41 | parser.add_argument("--env", type=str, default="breakout", 42 | help="the id of the environment") 43 | parser.add_argument("--env-num", type=int, default=1, 44 | help="# envs in parallel") 45 | parser.add_argument("--frame-stack", type=int, default=4, 46 | help="frame stack #") 47 | parser.add_argument("--action-repeat", type=int, default=4, 48 | help="action repeat #") 49 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True) 50 | 51 | # Algorithm specific arguments 52 | parser.add_argument("--total-timesteps", type=int, default=5000000, 53 | help="total timesteps of the experiments") 54 | parser.add_argument("--buffer-size", type=int, default=int(1e6), 55 | help="the replay memory buffer size") # smaller than in original paper but evaluation is done only for 100k steps anyway 56 | parser.add_argument("--gamma", type=float, default=0.99, 57 | help="the discount factor gamma") 58 | parser.add_argument("--tau", type=float, default=1.0, 59 | help="target smoothing coefficient (default: 1)") # Default is 1 to perform replacement update 60 | parser.add_argument("--batch-size", type=int, default=64, 61 | help="the batch size of sample from the reply memory") 62 | parser.add_argument("--learning-starts", type=int, default=2e4, 63 | help="timestep to start learning") 64 | parser.add_argument("--policy-lr", type=float, default=3e-4, 65 | help="the learning rate of the policy network optimizer") 66 | parser.add_argument("--q-lr", type=float, default=3e-4, 67 | help="the learning rate of the Q network network optimizer") 68 | parser.add_argument("--update-frequency", type=int, default=4, 69 | help="the frequency of training updates") 70 | parser.add_argument("--target-network-frequency", type=int, default=8000, 71 | help="the frequency of updates for the target networks") 72 | parser.add_argument("--alpha", type=float, default=0.2, 73 | help="Entropy regularization coefficient.") 74 | parser.add_argument("--autotune", type=lambda x:bool(strtobool(x)), default=True, nargs="?", const=True, 75 | help="automatic tuning of the entropy coefficient") 76 | parser.add_argument("--target-entropy-scale", type=float, default=0.89, 77 | help="coefficient for scaling the autotune entropy target") 78 | 79 | # eval args 80 | parser.add_argument("--eval-frequency", type=int, default=-1, 81 | help="eval frequency. default -1 is eval at the end.") 82 | parser.add_argument("--eval-num", type=int, default=10, 83 | help="eval frequency. 
default -1 is eval at the end.") 84 | 85 | args = parser.parse_args() 86 | # fmt: on 87 | return args 88 | 89 | 90 | def make_env(env_name, seed, **kwargs): 91 | def thunk(): 92 | env_args = AtariEnvArgs( 93 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs 94 | ) 95 | env = AtariBaseEnv(env_args) 96 | env.action_space.seed(seed) 97 | env.observation_space.seed(seed) 98 | return env 99 | 100 | return thunk 101 | 102 | 103 | def layer_init(layer, bias_const=0.0): 104 | nn.init.kaiming_normal_(layer.weight) 105 | torch.nn.init.constant_(layer.bias, bias_const) 106 | return layer 107 | 108 | 109 | # ALGO LOGIC: initialize agent here: 110 | # NOTE: Sharing a CNN encoder between Actor and Critics is not recommended for SAC without stopping actor gradients 111 | # See the SAC+AE paper https://arxiv.org/abs/1910.01741 for more info 112 | # TL;DR The actor's gradients mess up the representation when using a joint encoder 113 | class SoftQNetwork(nn.Module): 114 | def __init__(self, envs): 115 | super().__init__() 116 | obs_shape = envs.single_observation_space.shape 117 | self.conv = nn.Sequential( 118 | layer_init(nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4)), 119 | nn.ReLU(), 120 | layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)), 121 | nn.ReLU(), 122 | layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)), 123 | nn.Flatten(), 124 | ) 125 | 126 | with torch.inference_mode(): 127 | output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1] 128 | 129 | self.fc1 = layer_init(nn.Linear(output_dim, 512)) 130 | self.fc_q = layer_init(nn.Linear(512, envs.single_action_space.n)) 131 | 132 | def forward(self, x): 133 | x = F.relu(self.conv(x)) 134 | x = F.relu(self.fc1(x)) 135 | q_vals = self.fc_q(x) 136 | return q_vals 137 | 138 | 139 | class Actor(nn.Module): 140 | def __init__(self, envs): 141 | super().__init__() 142 | obs_shape = envs.single_observation_space.shape 143 | self.conv = nn.Sequential( 144 | layer_init(nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4)), 145 | nn.ReLU(), 146 | layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)), 147 | nn.ReLU(), 148 | layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)), 149 | nn.Flatten(), 150 | ) 151 | 152 | with torch.inference_mode(): 153 | output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1] 154 | 155 | self.fc1 = layer_init(nn.Linear(output_dim, 512)) 156 | self.fc_logits = layer_init(nn.Linear(512, envs.single_action_space.n)) 157 | 158 | def forward(self, x): 159 | x = F.relu(self.conv(x)) 160 | x = F.relu(self.fc1(x)) 161 | logits = self.fc_logits(x) 162 | 163 | return logits 164 | 165 | def get_action(self, x): 166 | logits = self(x) 167 | policy_dist = Categorical(logits=logits) 168 | action = policy_dist.sample() 169 | # Action probabilities for calculating the adapted soft-Q loss 170 | action_probs = policy_dist.probs 171 | log_prob = F.log_softmax(logits, dim=1) 172 | return action, log_prob, action_probs 173 | 174 | 175 | if __name__ == "__main__": 176 | args = parse_args() 177 | args.env = args.env.lower() 178 | run_name = f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}" 179 | run_dir = os.path.join("runs", args.exp_name) 180 | if not os.path.exists(run_dir): 181 | os.makedirs(run_dir, exist_ok=True) 182 | 183 | writer = SummaryWriter(os.path.join(run_dir, run_name)) 184 | writer.add_text( 185 | "hyperparameters", 186 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 187 | ) 188 | 189 | # TRY NOT TO MODIFY: seeding 190 | 
seed_everything(args.seed) 191 | 192 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 193 | 194 | envs = [] 195 | for i in range(args.env_num): 196 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat, clip_reward=args.clip_reward)) 197 | # envs = gym.vector.AsyncVectorEnv(envs) 198 | envs = gym.vector.SyncVectorEnv(envs) 199 | 200 | actor = Actor(envs).to(device) 201 | qf1 = SoftQNetwork(envs).to(device) 202 | qf2 = SoftQNetwork(envs).to(device) 203 | qf1_target = SoftQNetwork(envs).to(device) 204 | qf2_target = SoftQNetwork(envs).to(device) 205 | qf1_target.load_state_dict(qf1.state_dict()) 206 | qf2_target.load_state_dict(qf2.state_dict()) 207 | # TRY NOT TO MODIFY: eps=1e-4 increases numerical stability 208 | q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr, eps=1e-4) 209 | actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr, eps=1e-4) 210 | 211 | # Automatic entropy tuning 212 | if args.autotune: 213 | target_entropy = -args.target_entropy_scale * torch.log(1 / torch.tensor(envs.single_action_space.n)) 214 | log_alpha = torch.zeros(1, requires_grad=True, device=device) 215 | alpha = log_alpha.exp().item() 216 | a_optimizer = optim.Adam([log_alpha], lr=args.q_lr, eps=1e-4) 217 | else: 218 | alpha = args.alpha 219 | 220 | rb = ReplayBuffer( 221 | args.buffer_size, 222 | envs.single_observation_space, 223 | envs.single_action_space, 224 | device, 225 | n_envs=envs.num_envs, 226 | optimize_memory_usage=True, 227 | handle_timeout_termination=False, 228 | ) 229 | start_time = time.time() 230 | 231 | # TRY NOT TO MODIFY: start the game 232 | obs, infos = envs.reset() 233 | global_transitions = 0 234 | while global_transitions < args.total_timesteps: 235 | # ALGO LOGIC: put action logic here 236 | if global_transitions < args.learning_starts: 237 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 238 | else: 239 | actions, _, _ = actor.get_action(torch.Tensor(obs).to(device)) 240 | actions = actions.detach().cpu().numpy() 241 | 242 | # TRY NOT TO MODIFY: execute the game and log data. 243 | next_obs, rewards, dones, _, infos = envs.step(actions) 244 | 245 | # TRY NOT TO MODIFY: record rewards for plotting purposes 246 | if "final_info" in infos: 247 | for idx, d in enumerate(dones): 248 | if d: 249 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]") 250 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions) 251 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions) 252 | break 253 | 254 | # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` 255 | real_next_obs = next_obs.copy() 256 | for idx, d in enumerate(dones): 257 | if d: 258 | real_next_obs[idx] = infos["final_observation"][idx] 259 | rb.add(obs, real_next_obs, actions, rewards, dones, infos) 260 | 261 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 262 | obs = next_obs 263 | 264 | # INC total transitions 265 | global_transitions += args.env_num 266 | 267 | # ALGO LOGIC: training. 
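[Editor's note] A minimal sketch of the discrete-SAC value target that the critic update below relies on. Because the action space is discrete, the soft state value is an exact expectation over the policy's action probabilities, V(s') = sum_a pi(a|s') * (min(Q1, Q2)(s', a) - alpha * log pi(a|s')), rather than a Monte Carlo sample. The toy tensors stand in for actor and target-critic outputs and are not part of this file:

import torch

probs = torch.tensor([[0.7, 0.2, 0.1]])        # pi(.|s') for one state with 3 actions
log_probs = probs.log()
q1_t = torch.tensor([[1.0, 0.5, -0.2]])        # target Q-heads evaluated at s'
q2_t = torch.tensor([[0.9, 0.6, 0.0]])
alpha_toy, gamma_toy = 0.2, 0.99

soft_v = (probs * (torch.min(q1_t, q2_t) - alpha_toy * log_probs)).sum(dim=1)
td_target = torch.tensor([0.5]) + gamma_toy * (1.0 - 0.0) * soft_v   # r + gamma * (1 - done) * V(s')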
268 | if global_transitions > args.learning_starts: 269 | if global_transitions % args.update_frequency == 0: 270 | data = rb.sample(args.batch_size) 271 | # CRITIC training 272 | with torch.no_grad(): 273 | _, next_state_log_pi, next_state_action_probs = actor.get_action(data.next_observations) 274 | qf1_next_target = qf1_target(data.next_observations) 275 | qf2_next_target = qf2_target(data.next_observations) 276 | # we can use the action probabilities instead of MC sampling to estimate the expectation 277 | min_qf_next_target = next_state_action_probs * ( 278 | torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi 279 | ) 280 | # adapt Q-target for discrete Q-function 281 | min_qf_next_target = min_qf_next_target.sum(dim=1) 282 | next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target) 283 | 284 | # use Q-values only for the taken actions 285 | qf1_values = qf1(data.observations) 286 | qf2_values = qf2(data.observations) 287 | qf1_a_values = qf1_values.gather(1, data.actions.long()).view(-1) 288 | qf2_a_values = qf2_values.gather(1, data.actions.long()).view(-1) 289 | qf1_loss = F.mse_loss(qf1_a_values, next_q_value) 290 | qf2_loss = F.mse_loss(qf2_a_values, next_q_value) 291 | qf_loss = qf1_loss + qf2_loss 292 | 293 | q_optimizer.zero_grad() 294 | qf_loss.backward() 295 | q_optimizer.step() 296 | 297 | # ACTOR training 298 | _, log_pi, action_probs = actor.get_action(data.observations) 299 | with torch.no_grad(): 300 | qf1_values = qf1(data.observations) 301 | qf2_values = qf2(data.observations) 302 | min_qf_values = torch.min(qf1_values, qf2_values) 303 | # no need for reparameterization, the expectation can be calculated for discrete actions 304 | actor_loss = (action_probs * ((alpha * log_pi) - min_qf_values)).mean() 305 | 306 | actor_optimizer.zero_grad() 307 | actor_loss.backward() 308 | actor_optimizer.step() 309 | 310 | if args.autotune: 311 | # re-use action probabilities for temperature loss 312 | alpha_loss = (action_probs.detach() * (-log_alpha * (log_pi + target_entropy).detach())).mean() 313 | 314 | a_optimizer.zero_grad() 315 | alpha_loss.backward() 316 | a_optimizer.step() 317 | alpha = log_alpha.exp().item() 318 | 319 | # update the target networks 320 | if global_transitions % args.target_network_frequency == 0: 321 | for param, target_param in zip(qf1.parameters(), qf1_target.parameters()): 322 | target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) 323 | for param, target_param in zip(qf2.parameters(), qf2_target.parameters()): 324 | target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) 325 | 326 | if global_transitions % 100 == 0: 327 | writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_transitions) 328 | writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_transitions) 329 | writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_transitions) 330 | writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_transitions) 331 | writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_transitions) 332 | writer.add_scalar("losses/actor_loss", actor_loss.item(), global_transitions) 333 | writer.add_scalar("losses/alpha", alpha, global_transitions) 334 | # print("SPS:", int(global_transitions / (time.time() - start_time))) 335 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions) 336 | if args.autotune: 337 | 
writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_transitions) 338 | 339 | # evaluation 340 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \ 341 | (global_transitions >= args.total_timesteps): 342 | qf1.eval() 343 | qf2.eval() 344 | actor.eval() 345 | 346 | eval_episodic_returns, eval_episodic_lengths = [], [] 347 | 348 | for eval_ep in range(args.eval_num): 349 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat, clip_reward=args.clip_reward, training=False, record=args.capture_video)] 350 | eval_env = gym.vector.SyncVectorEnv(eval_env) 351 | obs, _ = eval_env.reset() 352 | done = False 353 | while not done: 354 | actions, _, _ = actor.get_action(torch.Tensor(obs).to(device)) 355 | actions = actions.detach().cpu().numpy() 356 | next_obs, rewards, dones, _, infos = eval_env.step(actions) 357 | obs = next_obs 358 | done = dones[0] 359 | if done: 360 | eval_episodic_returns.append(infos['final_info'][0]["reward"]) 361 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"]) 362 | if args.capture_video: 363 | record_file_dir = os.path.join("recordings", args.exp_name, os.path.basename(__file__).rstrip(".py"), args.env) 364 | os.makedirs(record_file_dir, exist_ok=True) 365 | record_file_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_eval{eval_ep:02d}_record.pt" 366 | eval_env.envs[0].save_record_to_file(os.path.join(record_file_dir, record_file_fn)) 367 | model_file_dir = os.path.join("trained_models", args.exp_name, os.path.basename(__file__).rstrip(".py"), args.env) 368 | os.makedirs(model_file_dir, exist_ok=True) 369 | model_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_eval{eval_ep:02d}_record.pt" 370 | torch.save({"sfn": None, "qf1": qf1.state_dict(), "qf2": qf2.state_dict(), "actor": actor.state_dict()}, os.path.join(model_file_dir, model_fn)) 371 | 372 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions) 373 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions) 374 | # writer.add_scalar("charts/eval_episodic_length", np.mean(), global_transitions) 375 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]") 376 | 377 | qf1.train() 378 | qf2.train() 379 | actor.train() 380 | 381 | envs.close() 382 | eval_env.close() 383 | writer.close() -------------------------------------------------------------------------------- /agent/drqv2_dmc_base.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import random 5 | import time 6 | from itertools import product 7 | from distutils.util import strtobool 8 | 9 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__)))) 10 | os.environ["OMP_NUM_THREADS"] = "1" 11 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 12 | import warnings 13 | warnings.filterwarnings("ignore", category=UserWarning) 14 | 15 | import gymnasium as gym 16 | from gymnasium.spaces import Discrete, Dict 17 | import numpy as np 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.optim as optim 22 | from torchvision.transforms import Resize 23 | from torch.utils.tensorboard import SummaryWriter 24 | 25 | from 
common.buffer import NstepRewardReplayBuffer 26 | from common.pvm_buffer import PVMBuffer 27 | from common.utils import ( 28 | get_timestr, 29 | schedule_drq, 30 | seed_everything, 31 | soft_update_params, 32 | weight_init_drq, 33 | TruncatedNormal 34 | ) 35 | 36 | from active_gym import DMCBaseEnv, DMCEnvArgs 37 | 38 | 39 | def parse_args(): 40 | # fmt: off 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__), 43 | help="the name of this experiment") 44 | parser.add_argument("--seed", type=int, default=1, 45 | help="seed of the experiment") 46 | parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 47 | help="if toggled, `torch.backends.cudnn.deterministic=False`") 48 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 49 | help="if toggled, cuda will be enabled by default") 50 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, 51 | help="whether to capture videos of the agent performances") 52 | 53 | # env setting 54 | parser.add_argument("--domain-name", type=str, default="walker", 55 | help="the name of the dmc domain") 56 | parser.add_argument("--task-name", type=str, default="walk", 57 | help="the name of the dmc task") 58 | parser.add_argument("--env-num", type=int, default=1, 59 | help="# envs in parallel") 60 | parser.add_argument("--frame-stack", type=int, default=3, 61 | help="frame stack #") 62 | parser.add_argument("--action-repeat", type=int, default=2, 63 | help="action repeat #") # i.e. frame skip 64 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True) # dmc we may not clip 65 | 66 | # fov setting 67 | parser.add_argument("--fov-size", type=int, default=50) 68 | parser.add_argument("--fov-init-loc", type=int, default=0) 69 | parser.add_argument("--sensory-action-mode", type=str, default="absolute") 70 | parser.add_argument("--sensory-action-space", type=int, default=10) # ignored when sensory_action_mode="relative" 71 | parser.add_argument("--resize-to-full", default=False, action="store_true") 72 | # for discrete observ action 73 | parser.add_argument("--sensory-action-x-size", type=int, default=4) 74 | parser.add_argument("--sensory-action-y-size", type=int, default=4) 75 | # pvm setting 76 | parser.add_argument("--pvm-stack", type=int, default=3) 77 | 78 | # Algorithm specific arguments 79 | parser.add_argument("--total-timesteps", type=int, default=1000000, 80 | help="total timesteps of the experiments") 81 | parser.add_argument("--buffer-size", type=int, default=int(1e6), 82 | help="the replay memory buffer size") 83 | parser.add_argument("--gamma", type=float, default=0.99, 84 | help="the discount factor gamma") 85 | parser.add_argument("--tau", type=float, default=0.01, 86 | help="target smoothing coefficient (default: 0.01)") 87 | parser.add_argument("--batch-size", type=int, default=256, 88 | help="the batch size of sample from the reply memory") 89 | parser.add_argument("--learning-starts", type=int, default=2000, 90 | help="timestep to start learning") 91 | parser.add_argument("--lr", type=float, default=1e-4, 92 | help="the learning rate of drq") 93 | parser.add_argument("--update-frequency", type=int, default=2, 94 | help="update frequency of drq") 95 | parser.add_argument("--stddev-clip", type=float, default=0.3) 96 | parser.add_argument("--stddev-schedule", 
type=str, default="linear(1.0,0.1,50000)") 97 | parser.add_argument("--feature-dim", type=int, default=50) 98 | parser.add_argument("--hidden-dim", type=int, default=1024) 99 | parser.add_argument("--n-step-reward", type=int, default=3) 100 | 101 | # eval args 102 | parser.add_argument("--eval-frequency", type=int, default=-1, 103 | help="eval frequency. default -1 is eval at the end.") 104 | parser.add_argument("--eval-num", type=int, default=10, 105 | help="eval episodes") 106 | 107 | args = parser.parse_args() 108 | # fmt: on 109 | return args 110 | 111 | 112 | def make_env(domain_name, task_name, seed, **kwargs): 113 | def thunk(): 114 | env_args = DMCEnvArgs( 115 | domain_name=domain_name, task_name=task_name, seed=seed, obs_size=(84, 84), **kwargs 116 | ) 117 | env = DMCBaseEnv(env_args) 118 | env.action_space.seed(seed) 119 | env.observation_space.seed(seed) 120 | return env 121 | return thunk 122 | 123 | class RandomShiftsAug(nn.Module): 124 | def __init__(self, pad): 125 | super().__init__() 126 | self.pad = pad 127 | 128 | def forward(self, x): 129 | n, c, h, w = x.size() 130 | assert h == w 131 | padding = tuple([self.pad] * 4) 132 | x = F.pad(x, padding, 'replicate') 133 | eps = 1.0 / (h + 2 * self.pad) 134 | arange = torch.linspace(-1.0 + eps, 135 | 1.0 - eps, 136 | h + 2 * self.pad, 137 | device=x.device, 138 | dtype=x.dtype)[:h] 139 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) 140 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 141 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 142 | 143 | shift = torch.randint(0, 144 | 2 * self.pad + 1, 145 | size=(n, 1, 1, 2), 146 | device=x.device, 147 | dtype=x.dtype) 148 | shift *= 2.0 / (h + 2 * self.pad) 149 | 150 | grid = base_grid + shift 151 | return F.grid_sample(x, 152 | grid, 153 | padding_mode='zeros', 154 | align_corners=False) 155 | 156 | 157 | class Encoder(nn.Module): 158 | def __init__(self, obs_shape): 159 | super().__init__() 160 | 161 | assert len(obs_shape) == 3 162 | self.repr_dim = 32 * 35 * 35 163 | 164 | self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2), 165 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 166 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 167 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 168 | nn.ReLU()) 169 | 170 | self.apply(weight_init_drq) 171 | 172 | def forward(self, obs): 173 | obs = obs - 0.5 # /255 is done by env 174 | h = self.convnet(obs) 175 | h = h.view(h.shape[0], -1) 176 | return h 177 | 178 | 179 | class Actor(nn.Module): 180 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim): 181 | super().__init__() 182 | 183 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim), 184 | nn.LayerNorm(feature_dim), nn.Tanh()) 185 | 186 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim), 187 | nn.ReLU(inplace=True), 188 | nn.Linear(hidden_dim, hidden_dim), 189 | nn.ReLU(inplace=True), 190 | nn.Linear(hidden_dim, action_shape[0])) 191 | 192 | self.apply(weight_init_drq) 193 | 194 | def forward(self, obs, std): 195 | h = self.trunk(obs) 196 | 197 | mu = self.policy(h) 198 | mu = torch.tanh(mu) 199 | std = torch.ones_like(mu) * std 200 | 201 | dist = TruncatedNormal(mu, std) 202 | return dist 203 | 204 | 205 | class Critic(nn.Module): 206 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim): 207 | super().__init__() 208 | 209 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim), 210 | nn.LayerNorm(feature_dim), nn.Tanh()) 211 | 212 | self.Q1 = nn.Sequential( 213 | nn.Linear(feature_dim + 
action_shape[0], hidden_dim), 214 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim), 215 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1)) 216 | 217 | self.Q2 = nn.Sequential( 218 | nn.Linear(feature_dim + action_shape[0], hidden_dim), 219 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim), 220 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1)) 221 | 222 | self.apply(weight_init_drq) 223 | 224 | def forward(self, obs, action): 225 | h = self.trunk(obs) 226 | h_action = torch.cat([h, action], dim=-1) 227 | q1 = self.Q1(h_action) 228 | q2 = self.Q2(h_action) 229 | 230 | return q1, q2 231 | 232 | 233 | class DrQV2Agent: 234 | def __init__(self, obs_shape, action_shape, device, lr, feature_dim, 235 | hidden_dim, critic_target_tau, learning_starts, 236 | update_every_steps, stddev_schedule, stddev_clip): 237 | self.device = device 238 | self.critic_target_tau = critic_target_tau 239 | self.update_every_steps = update_every_steps 240 | self.learning_starts = learning_starts 241 | self.stddev_schedule = stddev_schedule 242 | self.stddev_clip = stddev_clip 243 | 244 | # models 245 | self.encoder = Encoder(obs_shape).to(device) 246 | self.actor = Actor(self.encoder.repr_dim, action_shape, feature_dim, 247 | hidden_dim).to(device) 248 | 249 | self.critic = Critic(self.encoder.repr_dim, action_shape, feature_dim, 250 | hidden_dim).to(device) 251 | self.critic_target = Critic(self.encoder.repr_dim, action_shape, 252 | feature_dim, hidden_dim).to(device) 253 | self.critic_target.load_state_dict(self.critic.state_dict()) 254 | 255 | # optimizers 256 | self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=lr) 257 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr) 258 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr) 259 | 260 | # data augmentation 261 | self.aug = RandomShiftsAug(pad=4) 262 | 263 | self.train() 264 | self.critic_target.train() 265 | 266 | def train(self, training=True): 267 | self.training = training 268 | self.encoder.train(training) 269 | self.actor.train(training) 270 | self.critic.train(training) 271 | 272 | def eval(self): 273 | self.train(False) 274 | 275 | def act(self, obs, step, eval_mode=None) -> np.ndarray: 276 | obs = self.encoder(obs) 277 | stddev = schedule_drq(self.stddev_schedule, step) 278 | dist = self.actor(obs, stddev) 279 | 280 | # auto eval 281 | if eval_mode is None: 282 | eval_mode = not self.training 283 | 284 | if eval_mode: 285 | action = dist.mean 286 | else: 287 | action = dist.sample(clip=None) 288 | if step < self.learning_starts: 289 | action.uniform_(-1.0, 1.0) 290 | return action.detach().cpu().numpy() 291 | 292 | def update_critic(self, obs, action, reward, discount, next_obs, step): 293 | metrics = dict() 294 | 295 | with torch.no_grad(): 296 | stddev = schedule_drq(self.stddev_schedule, step) 297 | dist = self.actor(next_obs, stddev) 298 | next_action = dist.sample(clip=self.stddev_clip) 299 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action) 300 | target_V = torch.min(target_Q1, target_Q2) 301 | target_Q = reward + (discount * target_V) 302 | 303 | Q1, Q2 = self.critic(obs, action) 304 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) 305 | 306 | # optimize encoder and critic 307 | self.encoder_opt.zero_grad(set_to_none=True) 308 | self.critic_opt.zero_grad(set_to_none=True) 309 | critic_loss.backward() 310 | self.critic_opt.step() 311 | self.encoder_opt.step() 312 | 313 | return metrics 314 | 315 | def update_actor(self, obs, step): 316 | metrics = 
dict() 317 | 318 | stddev = schedule_drq(self.stddev_schedule, step) 319 | dist = self.actor(obs, stddev) 320 | action = dist.sample(clip=self.stddev_clip) 321 | log_prob = dist.log_prob(action).sum(-1, keepdim=True) 322 | Q1, Q2 = self.critic(obs, action) 323 | Q = torch.min(Q1, Q2) 324 | 325 | actor_loss = -Q.mean() 326 | 327 | # optimize actor 328 | self.actor_opt.zero_grad(set_to_none=True) 329 | actor_loss.backward() 330 | self.actor_opt.step() 331 | 332 | return metrics 333 | 334 | def update(self, batch, step): 335 | metrics = dict() 336 | 337 | if step % self.update_every_steps != 0: 338 | return metrics 339 | 340 | obs = batch.observations 341 | action = batch.actions 342 | next_obs = batch.next_observations 343 | reward = batch.rewards 344 | discount = batch.discounts 345 | 346 | # augment 347 | obs = self.aug(obs.float()) 348 | next_obs = self.aug(next_obs.float()) 349 | # encode 350 | obs = self.encoder(obs) 351 | with torch.no_grad(): 352 | next_obs = self.encoder(next_obs) 353 | 354 | 355 | self.update_critic(obs, action, reward, discount, next_obs, step) 356 | 357 | self.update_actor(obs.detach(), step) 358 | 359 | # update critic target 360 | soft_update_params(self.critic, self.critic_target, 361 | self.critic_target_tau) 362 | 363 | return metrics 364 | 365 | def save_agent(self): 366 | agent = { 367 | "encoder": self.encoder.state_dict(), 368 | "critic": self.critic.state_dict(), 369 | "critic_target": self.critic_target.state_dict(), 370 | "actor": self.actor.state_dict() 371 | } 372 | return agent 373 | 374 | 375 | if __name__ == "__main__": 376 | args = parse_args() 377 | args.domain_name = args.domain_name.lower() 378 | args.task_name = args.task_name.lower() 379 | run_name = f"{args.domain_name}-{args.task_name}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}" 380 | run_dir = os.path.join("runs", args.exp_name) 381 | if not os.path.exists(run_dir): 382 | os.makedirs(run_dir, exist_ok=True) 383 | 384 | writer = SummaryWriter(os.path.join(run_dir, run_name)) 385 | writer.add_text( 386 | "hyperparameters", 387 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 388 | ) 389 | 390 | # TRY NOT TO MODIFY: seeding 391 | seed_everything(args.seed) 392 | 393 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 394 | 395 | # env setup 396 | envs = [] 397 | for i in range(args.env_num): 398 | envs.append(make_env(args.domain_name, 399 | args.task_name, 400 | args.seed+i, 401 | frame_stack=args.frame_stack, 402 | action_repeat=args.action_repeat, 403 | clip_reward=False)) 404 | envs = gym.vector.SyncVectorEnv(envs) 405 | 406 | drq_agent = DrQV2Agent( 407 | envs.single_observation_space.shape, 408 | envs.single_action_space.shape, 409 | device=device, 410 | lr=args.lr, 411 | feature_dim=args.feature_dim, 412 | hidden_dim=args.hidden_dim, 413 | critic_target_tau=args.tau, 414 | learning_starts=args.learning_starts, 415 | update_every_steps=args.update_frequency, 416 | stddev_schedule=args.stddev_schedule, 417 | stddev_clip=args.stddev_clip 418 | ) 419 | 420 | rb = NstepRewardReplayBuffer( 421 | n_step_reward=args.n_step_reward, 422 | gamma=args.gamma, 423 | buffer_size=args.buffer_size, 424 | observation_space=envs.single_observation_space, 425 | action_space=envs.single_action_space, 426 | device=device, 427 | handle_timeout_termination=True, 428 | ) 429 | start_time = time.time() 430 | 431 | # TRY NOT TO MODIFY: start the game 432 | obs, infos = envs.reset() 433 | global_transitions = 
0 434 | while global_transitions < args.total_timesteps: 435 | actions = drq_agent.act(torch.Tensor(obs).to(device), global_transitions) 436 | next_obs, rewards, dones, _, infos = envs.step(actions) 437 | 438 | if "final_info" in infos: 439 | for idx, d in enumerate(dones): 440 | if d: 441 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]") 442 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions) 443 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions) 444 | break 445 | 446 | real_next_obs = next_obs.copy() 447 | for idx, d in enumerate(dones): 448 | if d: 449 | real_next_obs[idx] = infos["final_observation"][idx] 450 | rb.add(obs, real_next_obs, actions, rewards, dones, [{}]) 451 | 452 | obs = next_obs 453 | global_transitions += args.env_num 454 | 455 | # ALGO LOGIC: training. 456 | if global_transitions > args.learning_starts: 457 | data = rb.sample(args.batch_size) 458 | drq_agent.update(data, global_transitions) 459 | 460 | if global_transitions % 100 == 0: 461 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions) 462 | 463 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \ 464 | (global_transitions >= args.total_timesteps): 465 | drq_agent.eval() 466 | 467 | eval_episodic_returns, eval_episodic_lengths = [], [] 468 | for eval_ep in range(args.eval_num): 469 | eval_env=[make_env(args.domain_name, 470 | args.task_name, 471 | args.seed+eval_ep, 472 | frame_stack=args.frame_stack, 473 | action_repeat=args.action_repeat, 474 | clip_reward=False)] 475 | eval_env = gym.vector.SyncVectorEnv(eval_env) 476 | eval_obs, infos = eval_env.reset() 477 | done = False 478 | while not done: 479 | actions = drq_agent.act(torch.Tensor(eval_obs).to(device), step=global_transitions) 480 | eval_next_obs, rewards, dones, _, infos = eval_env.step(actions) 481 | eval_obs = eval_next_obs 482 | done = dones[0] 483 | if done: 484 | eval_episodic_returns.append(infos['final_info'][0]["reward"]) 485 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"]) 486 | if args.capture_video: 487 | record_file_dir = os.path.join("recordings", args.exp_name, os.path.basename(__file__).replace(".py", ""), f"{args.domain_name}-{args.task_name}") 488 | os.makedirs(record_file_dir, exist_ok=True) 489 | record_file_fn = f"{args.domain_name}-{args.task_name}_seed{args.seed}_step{global_transitions:07d}_eval{eval_ep:02d}_record.pt" 490 | eval_env.envs[0].save_record_to_file(os.path.join(record_file_dir, record_file_fn)) 491 | if global_transitions >= args.total_timesteps and eval_ep == 0: 492 | model_file_dir = os.path.join("trained_models", args.exp_name, os.path.basename(__file__).replace(".py", ""), f"{args.domain_name}-{args.task_name}") 493 | os.makedirs(model_file_dir, exist_ok=True) 494 | model_fn = f"{args.domain_name}-{args.task_name}_seed{args.seed}_model.pt" 495 | torch.save({"sfn": None, "agent": drq_agent.save_agent()}, os.path.join(model_file_dir, model_fn)) 496 | 497 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions) 498 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions) 499 | # writer.add_scalar("charts/eval_episodic_length", np.mean(), global_transitions) 500 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: 
{np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]") 501 | 502 | drq_agent.train() 503 | 504 | envs.close() 505 | writer.close() -------------------------------------------------------------------------------- /agent/dqn_atari_sugarl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import random 5 | import time 6 | from itertools import product 7 | from distutils.util import strtobool 8 | 9 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__)))) 10 | os.environ["OMP_NUM_THREADS"] = "1" 11 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 12 | import warnings 13 | warnings.filterwarnings("ignore", category=UserWarning) 14 | 15 | import gymnasium as gym 16 | from gymnasium.spaces import Discrete, Dict 17 | import numpy as np 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.optim as optim 22 | from torchvision.transforms import Resize 23 | 24 | from common.buffer import DoubleActionReplayBuffer 25 | from common.pvm_buffer import PVMBuffer 26 | from common.utils import get_timestr, seed_everything, get_sugarl_reward_scale_atari 27 | from torch.utils.tensorboard import SummaryWriter 28 | 29 | from active_gym.atari_env import AtariFixedFovealEnv, AtariEnvArgs 30 | 31 | 32 | def parse_args(): 33 | # fmt: off 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), 36 | help="the name of this experiment") 37 | parser.add_argument("--seed", type=int, default=1, 38 | help="seed of the experiment") 39 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 40 | help="if toggled, cuda will be enabled by default") 41 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, 42 | help="whether to capture videos of the agent performances (check out `videos` folder)") 43 | 44 | # env setting 45 | parser.add_argument("--env", type=str, default="breakout", 46 | help="the id of the environment") 47 | parser.add_argument("--env-num", type=int, default=1, 48 | help="# envs in parallel") 49 | parser.add_argument("--frame-stack", type=int, default=4, 50 | help="frame stack #") 51 | parser.add_argument("--action-repeat", type=int, default=4, 52 | help="action repeat #") 53 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True) 54 | 55 | # fov setting 56 | parser.add_argument("--fov-size", type=int, default=50) 57 | parser.add_argument("--fov-init-loc", type=int, default=0) 58 | parser.add_argument("--sensory-action-mode", type=str, default="absolute") 59 | parser.add_argument("--sensory-action-space", type=int, default=10) # ignored when sensory_action_mode="relative" 60 | parser.add_argument("--resize-to-full", default=False, action="store_true") 61 | # for discrete observ action 62 | parser.add_argument("--sensory-action-x-size", type=int, default=4) 63 | parser.add_argument("--sensory-action-y-size", type=int, default=4) 64 | # pvm setting 65 | parser.add_argument("--pvm-stack", type=int, default=3) 66 | 67 | # Algorithm specific arguments 68 | parser.add_argument("--total-timesteps", type=int, default=3000000, 69 | help="total timesteps of the experiments") 70 | parser.add_argument("--learning-rate", type=float, default=1e-4, 
71 | help="the learning rate of the optimizer") 72 | parser.add_argument("--buffer-size", type=int, default=500000, 73 | help="the replay memory buffer size") 74 | parser.add_argument("--gamma", type=float, default=0.99, 75 | help="the discount factor gamma") 76 | parser.add_argument("--target-network-frequency", type=int, default=1000, 77 | help="the timesteps it takes to update the target network") 78 | parser.add_argument("--batch-size", type=int, default=32, 79 | help="the batch size of sample from the reply memory") 80 | parser.add_argument("--start-e", type=float, default=1, 81 | help="the starting epsilon for exploration") 82 | parser.add_argument("--end-e", type=float, default=0.01, 83 | help="the ending epsilon for exploration") 84 | parser.add_argument("--exploration-fraction", type=float, default=0.10, 85 | help="the fraction of `total-timesteps` it takes from start-e to go end-e") 86 | parser.add_argument("--learning-starts", type=int, default=80000, 87 | help="timestep to start learning") 88 | parser.add_argument("--train-frequency", type=int, default=4, 89 | help="the frequency of training") 90 | 91 | # eval args 92 | parser.add_argument("--eval-frequency", type=int, default=-1, 93 | help="eval frequency. default -1 is eval at the end.") 94 | parser.add_argument("--eval-num", type=int, default=10, 95 | help="eval frequency. default -1 is eval at the end.") 96 | args = parser.parse_args() 97 | # fmt: on 98 | return args 99 | 100 | 101 | def make_env(env_name, seed, **kwargs): 102 | def thunk(): 103 | env_args = AtariEnvArgs( 104 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs 105 | ) 106 | env = AtariFixedFovealEnv(env_args) 107 | env.action_space.seed(seed) 108 | env.observation_space.seed(seed) 109 | return env 110 | 111 | return thunk 112 | 113 | 114 | # ALGO LOGIC: initialize agent here: 115 | class QNetwork(nn.Module): 116 | def __init__(self, env, sensory_action_set=None): 117 | super().__init__() 118 | if isinstance(env.single_action_space, Discrete): 119 | motor_action_space_size = env.single_action_space.n 120 | sensory_action_space_size = None 121 | elif isinstance(env.single_action_space, Dict): 122 | motor_action_space_size = env.single_action_space["motor_action"].n 123 | if sensory_action_set is not None: 124 | sensory_action_space_size = len(sensory_action_set) 125 | else: 126 | sensory_action_space_size = env.single_action_space["sensory_action"].n 127 | self.backbone = nn.Sequential( 128 | nn.Conv2d(4, 32, 8, stride=4), 129 | nn.ReLU(), 130 | nn.Conv2d(32, 64, 4, stride=2), 131 | nn.ReLU(), 132 | nn.Conv2d(64, 64, 3, stride=1), 133 | nn.ReLU(), 134 | nn.Flatten(), 135 | nn.Linear(3136, 512), 136 | nn.ReLU(), 137 | ) 138 | 139 | self.motor_action_head = nn.Linear(512, motor_action_space_size) 140 | self.sensory_action_head = None 141 | if sensory_action_space_size is not None: 142 | self.sensory_action_head = nn.Linear(512, sensory_action_space_size) 143 | 144 | 145 | def forward(self, x): 146 | x = self.backbone(x) 147 | motor_action = self.motor_action_head(x) 148 | sensory_action = None 149 | if self.sensory_action_head: 150 | sensory_action = self.sensory_action_head(x) 151 | return motor_action, sensory_action 152 | 153 | class SelfPredictionNetwork(nn.Module): 154 | def __init__(self, env, sensory_action_set=None): 155 | super().__init__() 156 | if isinstance(env.single_action_space, Discrete): 157 | motor_action_space_size = env.single_action_space.n 158 | sensory_action_space_size = None 159 | elif isinstance(env.single_action_space, Dict): 160 
| motor_action_space_size = env.single_action_space["motor_action"].n 161 | if sensory_action_set is not None: 162 | sensory_action_space_size = len(sensory_action_set) 163 | else: 164 | sensory_action_space_size = env.single_action_space["sensory_action"].n 165 | 166 | self.backbone = nn.Sequential( 167 | nn.Conv2d(8, 32, 8, stride=4), 168 | nn.ReLU(), 169 | nn.Conv2d(32, 64, 4, stride=2), 170 | nn.ReLU(), 171 | nn.Conv2d(64, 64, 3, stride=1), 172 | nn.ReLU(), 173 | nn.Flatten(), 174 | nn.Linear(3136, 512), 175 | nn.ReLU(), 176 | ) 177 | 178 | self.head = nn.Sequential( 179 | nn.Linear(512, motor_action_space_size), 180 | ) 181 | 182 | self.loss = nn.CrossEntropyLoss() 183 | 184 | def get_loss(self, x, target) -> torch.Tensor: 185 | return self.loss(x, target) 186 | 187 | 188 | def forward(self, x): 189 | x = self.backbone(x) 190 | x = self.head(x) 191 | return x 192 | 193 | 194 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int): 195 | slope = (end_e - start_e) / duration 196 | return max(slope * t + start_e, end_e) 197 | 198 | 199 | if __name__ == "__main__": 200 | args = parse_args() 201 | args.env = args.env.lower() 202 | run_name = f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}" 203 | run_dir = os.path.join("runs", args.exp_name) 204 | if not os.path.exists(run_dir): 205 | os.makedirs(run_dir, exist_ok=True) 206 | 207 | writer = SummaryWriter(os.path.join(run_dir, run_name)) 208 | writer.add_text( 209 | "hyperparameters", 210 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 211 | ) 212 | 213 | # TRY NOT TO MODIFY: seeding 214 | seed_everything(args.seed) 215 | 216 | 217 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 218 | 219 | # env setup 220 | envs = [] 221 | for i in range(args.env_num): 222 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat, 223 | fov_size=(args.fov_size, args.fov_size), 224 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc), 225 | sensory_action_mode=args.sensory_action_mode, 226 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space), 227 | resize_to_full=args.resize_to_full, 228 | clip_reward=args.clip_reward, 229 | mask_out=True)) 230 | # envs = gym.vector.AsyncVectorEnv(envs) 231 | envs = gym.vector.SyncVectorEnv(envs) 232 | 233 | sugarl_r_scale = get_sugarl_reward_scale_atari(args.env) 234 | 235 | resize = Resize((84, 84)) 236 | 237 | # get a discrete observ action space 238 | OBSERVATION_SIZE = (84, 84) 239 | observ_x_max, observ_y_max = OBSERVATION_SIZE[0]-args.fov_size, OBSERVATION_SIZE[1]-args.fov_size 240 | sensory_action_step = (observ_x_max//args.sensory_action_x_size, 241 | observ_y_max//args.sensory_action_y_size) 242 | sensory_action_x_set = list(range(0, observ_x_max, sensory_action_step[0]))[:args.sensory_action_x_size] 243 | sensory_action_y_set = list(range(0, observ_y_max, sensory_action_step[1]))[:args.sensory_action_y_size] 244 | sensory_action_set = [np.array(a) for a in list(product(sensory_action_x_set, sensory_action_y_set))] 245 | 246 | q_network = QNetwork(envs, sensory_action_set=sensory_action_set).to(device) 247 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate) 248 | target_network = QNetwork(envs, sensory_action_set=sensory_action_set).to(device) 249 | target_network.load_state_dict(q_network.state_dict()) 250 | 251 | sfn = SelfPredictionNetwork(envs, 
sensory_action_set=sensory_action_set).to(device) 252 | sfn_optimizer = optim.Adam(sfn.parameters(), lr=args.learning_rate) 253 | 254 | rb = DoubleActionReplayBuffer( 255 | args.buffer_size, 256 | envs.single_observation_space, 257 | envs.single_action_space["motor_action"], 258 | Discrete(len(sensory_action_set)), 259 | device, 260 | n_envs=envs.num_envs, 261 | optimize_memory_usage=True, 262 | handle_timeout_termination=False, 263 | ) 264 | start_time = time.time() 265 | 266 | # TRY NOT TO MODIFY: start the game 267 | obs, infos = envs.reset() 268 | global_transitions = 0 269 | pvm_buffer = PVMBuffer(args.pvm_stack, (envs.num_envs, args.frame_stack,)+OBSERVATION_SIZE) 270 | 271 | while global_transitions < args.total_timesteps: 272 | pvm_buffer.append(obs) 273 | pvm_obs = pvm_buffer.get_obs(mode="stack_max") 274 | # ALGO LOGIC: put action logic here 275 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_transitions) 276 | if random.random() < epsilon: 277 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 278 | motor_actions = np.array([actions[0]["motor_action"]]) 279 | sensory_actions = np.array([random.randint(0, len(sensory_action_set)-1)]) 280 | else: 281 | motor_q_values, sensory_q_values = q_network(resize(torch.from_numpy(pvm_obs)).to(device)) 282 | motor_actions = torch.argmax(motor_q_values, dim=1).cpu().numpy() 283 | sensory_actions = torch.argmax(sensory_q_values, dim=1).cpu().numpy() 284 | 285 | # TRY NOT TO MODIFY: execute the game and log data. 286 | next_obs, rewards, dones, _, infos = envs.step({"motor_action": motor_actions, 287 | "sensory_action": [sensory_action_set[a] for a in sensory_actions] }) 288 | # print (global_step, infos) 289 | 290 | # TRY NOT TO MODIFY: record rewards for plotting purposes 291 | if "final_info" in infos: 292 | for idx, d in enumerate(dones): 293 | if d: 294 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]") 295 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions) 296 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions) 297 | writer.add_scalar("charts/epsilon", epsilon, global_transitions) 298 | break 299 | 300 | # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation` 301 | real_next_obs = next_obs 302 | for idx, d in enumerate(dones): 303 | if d: 304 | real_next_obs[idx] = infos["final_observation"][idx] 305 | pvm_buffer_copy = pvm_buffer.copy() 306 | pvm_buffer_copy.append(real_next_obs) 307 | real_next_pvm_obs = pvm_buffer_copy.get_obs(mode="stack_max") 308 | rb.add(pvm_obs, real_next_pvm_obs, motor_actions, sensory_actions, rewards, dones, {}) 309 | 310 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 311 | obs = next_obs 312 | 313 | # INC total transitions 314 | global_transitions += args.env_num 315 | 316 | obs_backup = obs # backup obs (restored after eval) 317 | 318 | if global_transitions < args.batch_size: 319 | continue 320 | 321 | # Training 322 | if global_transitions % args.train_frequency == 0: 323 | data = rb.sample(args.batch_size // args.env_num) # counter-balance the true global transitions used for training 324 | 325 | # sfn learning 326 | concat_observation = torch.concat([data.next_observations, data.observations], dim=1) # concat at dimension T 327 | pred_motor_actions = sfn(resize(concat_observation)) 328 | # print (pred_motor_actions.size(),
data.motor_actions.size()) 329 | sfn_loss = sfn.get_loss(pred_motor_actions, data.motor_actions.flatten()) 330 | sfn_optimizer.zero_grad() 331 | sfn_loss.backward() 332 | sfn_optimizer.step() 333 | observ_r = F.softmax(pred_motor_actions, dim=1).gather(1, data.motor_actions).squeeze().detach() # probability of the taken motor action, in [0, 1] 334 | 335 | # Q network learning 336 | if global_transitions > args.learning_starts: 337 | with torch.no_grad(): 338 | motor_target, sensory_target = target_network(resize(data.next_observations)) 339 | motor_target_max, _ = motor_target.max(dim=1) 340 | sensory_target_max, _ = sensory_target.max(dim=1) 341 | # scale step-wise reward with observ_r: subtract the sugarl penalty (1 - observ_r) * sugarl_r_scale from the reward 342 | observ_r_adjusted = observ_r.clone() 343 | observ_r_adjusted[data.rewards.flatten() > 0] = 1 - observ_r_adjusted[data.rewards.flatten() > 0] 344 | td_target = data.rewards.flatten() - (1 - observ_r) * sugarl_r_scale + args.gamma * (motor_target_max+sensory_target_max) * (1 - data.dones.flatten()) 345 | original_td_target = data.rewards.flatten() + args.gamma * (motor_target_max+sensory_target_max) * (1 - data.dones.flatten()) 346 | 347 | old_motor_q_val, old_sensory_q_val = q_network(resize(data.observations)) 348 | old_motor_val = old_motor_q_val.gather(1, data.motor_actions).squeeze() 349 | old_sensory_val = old_sensory_q_val.gather(1, data.sensory_actions).squeeze() 350 | old_val = old_motor_val + old_sensory_val 351 | 352 | loss = F.mse_loss(td_target, old_val) 353 | 354 | if global_transitions % 100 == 0: 355 | writer.add_scalar("losses/td_loss", loss, global_transitions) 356 | writer.add_scalar("losses/q_values", old_val.mean().item(), global_transitions) 357 | writer.add_scalar("losses/motor_q_values", old_motor_val.mean().item(), global_transitions) 358 | writer.add_scalar("losses/sensory_q_values", old_sensory_val.mean().item(), global_transitions) 359 | # print("SPS:", int(global_step / (time.time() - start_time))) 360 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions) 361 | 362 | writer.add_scalar("losses/sfn_loss", sfn_loss.item(), global_transitions) 363 | writer.add_scalar("losses/observ_r", observ_r.mean().item(), global_transitions) 364 | writer.add_scalar("losses/original_td_target", original_td_target.mean().item(), global_transitions) 365 | writer.add_scalar("losses/sugarl_r_scaled_td_target", td_target.mean().item(), global_transitions) 366 | 367 | # optimize the model 368 | optimizer.zero_grad() 369 | loss.backward() 370 | optimizer.step() 371 | 372 | # update the target network 373 | if (global_transitions // args.env_num) % args.target_network_frequency == 0: 374 | target_network.load_state_dict(q_network.state_dict()) 375 | 376 | # evaluation 377 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \ 378 | (global_transitions >= args.total_timesteps): 379 | q_network.eval() 380 | sfn.eval() 381 | 382 | eval_episodic_returns, eval_episodic_lengths = [], [] 383 | 384 | for eval_ep in range(args.eval_num): 385 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat, 386 | fov_size=(args.fov_size, args.fov_size), 387 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc), 388 | sensory_action_mode=args.sensory_action_mode, 389 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space), 390 | resize_to_full=args.resize_to_full, 391 | clip_reward=args.clip_reward, 392 | mask_out=True, 393 | training=False, 394 | record=args.capture_video)] 395 | eval_env =
gym.vector.SyncVectorEnv(eval_env) 396 | obs_eval, _ = eval_env.reset() 397 | done = False 398 | pvm_buffer_eval = PVMBuffer(args.pvm_stack, (eval_env.num_envs, args.frame_stack,)+OBSERVATION_SIZE) 399 | while not done: 400 | pvm_buffer_eval.append(obs_eval) 401 | pvm_obs_eval = pvm_buffer_eval.get_obs(mode="stack_max") 402 | motor_q_values, sensory_q_values = q_network(resize(torch.from_numpy(pvm_obs_eval)).to(device)) 403 | motor_actions = torch.argmax(motor_q_values, dim=1).cpu().numpy() 404 | sensory_actions = torch.argmax(sensory_q_values, dim=1).cpu().numpy() 405 | next_obs_eval, rewards, dones, _, infos = eval_env.step({"motor_action": motor_actions, 406 | "sensory_action": [sensory_action_set[a] for a in sensory_actions]}) 407 | obs_eval = next_obs_eval 408 | done = dones[0] 409 | if done: 410 | eval_episodic_returns.append(infos['final_info'][0]["reward"]) 411 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"]) 412 | if args.capture_video: 413 | record_file_dir = os.path.join("recordings", args.exp_name, os.path.basename(__file__).rstrip(".py"), args.env) 414 | os.makedirs(record_file_dir, exist_ok=True) 415 | record_file_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_eval{eval_ep:02d}_record.pt" 416 | eval_env.envs[0].save_record_to_file(os.path.join(record_file_dir, record_file_fn)) 417 | if eval_ep == 0: 418 | model_file_dir = os.path.join("trained_models", args.exp_name, os.path.basename(__file__).rstrip(".py"), args.env) 419 | os.makedirs(model_file_dir, exist_ok=True) 420 | model_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_model.pt" 421 | torch.save({"sfn": sfn.state_dict(), "q": q_network.state_dict()}, os.path.join(model_file_dir, model_fn)) 422 | 423 | 424 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions) 425 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions) 426 | # writer.add_scalar("charts/eval_episodic_length", np.mean(), global_transitions) 427 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]") 428 | 429 | q_network.train() 430 | sfn.train() 431 | 432 | obs = obs_backup # restore obs if eval occurs 433 | 434 | 435 | 436 | envs.close() 437 | eval_env.close() 438 | writer.close() -------------------------------------------------------------------------------- /agent/dqn_atari_wp_sugarl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import os.path as osp 4 | import random 5 | import time 6 | from itertools import product 7 | from distutils.util import strtobool 8 | 9 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__)))) 10 | os.environ["OMP_NUM_THREADS"] = "1" 11 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 12 | import warnings 13 | warnings.filterwarnings("ignore", category=UserWarning) 14 | 15 | import gymnasium as gym 16 | from gymnasium.spaces import Discrete, Dict 17 | import numpy as np 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.optim as optim 22 | from torchvision.transforms import Resize 23 | 24 | from common.buffer import DoubleActionReplayBuffer 25 | from common.pvm_buffer import PVMBuffer 26 | from common.utils import get_timestr, seed_everything, get_sugarl_reward_scale_atari 27 | from 
torch.utils.tensorboard import SummaryWriter 28 | 29 | from active_gym import AtariFixedFovealPeripheralEnv, AtariEnvArgs 30 | 31 | 32 | def parse_args(): 33 | # fmt: off 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), 36 | help="the name of this experiment") 37 | parser.add_argument("--seed", type=int, default=1, 38 | help="seed of the experiment") 39 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, 40 | help="if toggled, cuda will be enabled by default") 41 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, 42 | help="whether to capture videos of the agent performances (check out `videos` folder)") 43 | 44 | # env setting 45 | parser.add_argument("--env", type=str, default="breakout", 46 | help="the id of the environment") 47 | parser.add_argument("--env-num", type=int, default=1, 48 | help="# envs in parallel") 49 | parser.add_argument("--frame-stack", type=int, default=4, 50 | help="frame stack #") 51 | parser.add_argument("--action-repeat", type=int, default=4, 52 | help="action repeat #") 53 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True) 54 | 55 | # fov setting 56 | parser.add_argument("--fov-size", type=int, default=50) 57 | parser.add_argument("--fov-init-loc", type=int, default=0) 58 | parser.add_argument("--sensory-action-mode", type=str, default="absolute") 59 | parser.add_argument("--sensory-action-space", type=int, default=10) # ignored when sensory_action_mode="relative" 60 | parser.add_argument("--resize-to-full", default=False, action="store_true") 61 | parser.add_argument("--peripheral-res", type=int, default=20) 62 | # for discrete observ action 63 | parser.add_argument("--sensory-action-x-size", type=int, default=4) 64 | parser.add_argument("--sensory-action-y-size", type=int, default=4) 65 | # pvm setting 66 | parser.add_argument("--pvm-stack", type=int, default=3) 67 | 68 | 69 | # Algorithm specific arguments 70 | parser.add_argument("--total-timesteps", type=int, default=3000000, 71 | help="total timesteps of the experiments") 72 | parser.add_argument("--learning-rate", type=float, default=1e-4, 73 | help="the learning rate of the optimizer") 74 | parser.add_argument("--buffer-size", type=int, default=500000, 75 | help="the replay memory buffer size") 76 | parser.add_argument("--gamma", type=float, default=0.99, 77 | help="the discount factor gamma") 78 | parser.add_argument("--target-network-frequency", type=int, default=1000, 79 | help="the timesteps it takes to update the target network") 80 | parser.add_argument("--batch-size", type=int, default=32, 81 | help="the batch size of samples from the replay memory") 82 | parser.add_argument("--start-e", type=float, default=1, 83 | help="the starting epsilon for exploration") 84 | parser.add_argument("--end-e", type=float, default=0.01, 85 | help="the ending epsilon for exploration") 86 | parser.add_argument("--exploration-fraction", type=float, default=0.10, 87 | help="the fraction of `total-timesteps` it takes to go from start-e to end-e") 88 | parser.add_argument("--learning-starts", type=int, default=80000, 89 | help="timestep to start learning") 90 | parser.add_argument("--train-frequency", type=int, default=4, 91 | help="the frequency of training") 92 | 93 | # eval args 94 | parser.add_argument("--eval-frequency", type=int, default=-1, 95 | help="eval
frequency. default -1 is eval at the end.") 96 | parser.add_argument("--eval-num", type=int, default=10, 97 | help="the number of evaluation episodes") 98 | args = parser.parse_args() 99 | # fmt: on 100 | return args 101 | 102 | 103 | def make_env(env_name, seed, **kwargs): 104 | def thunk(): 105 | env_args = AtariEnvArgs( 106 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs 107 | ) 108 | env = AtariFixedFovealPeripheralEnv(env_args) 109 | env.action_space.seed(seed) 110 | env.observation_space.seed(seed) 111 | return env 112 | 113 | return thunk 114 | 115 | 116 | # ALGO LOGIC: initialize agent here: 117 | class QNetwork(nn.Module): 118 | def __init__(self, env, sensory_action_set=None): 119 | super().__init__() 120 | if isinstance(env.single_action_space, Discrete): 121 | motor_action_space_size = env.single_action_space.n 122 | sensory_action_space_size = None 123 | elif isinstance(env.single_action_space, Dict): 124 | motor_action_space_size = env.single_action_space["motor_action"].n 125 | if sensory_action_set is not None: 126 | sensory_action_space_size = len(sensory_action_set) 127 | else: 128 | sensory_action_space_size = env.single_action_space["sensory_action"].n 129 | self.backbone = nn.Sequential( 130 | nn.Conv2d(4, 32, 8, stride=4), 131 | nn.ReLU(), 132 | nn.Conv2d(32, 64, 4, stride=2), 133 | nn.ReLU(), 134 | nn.Conv2d(64, 64, 3, stride=1), 135 | nn.ReLU(), 136 | nn.Flatten(), 137 | nn.Linear(3136, 512), 138 | nn.ReLU(), 139 | ) 140 | 141 | self.motor_action_head = nn.Linear(512, motor_action_space_size) 142 | self.sensory_action_head = None 143 | if sensory_action_space_size is not None: 144 | self.sensory_action_head = nn.Linear(512, sensory_action_space_size) 145 | 146 | 147 | def forward(self, x): 148 | x = self.backbone(x) 149 | motor_action = self.motor_action_head(x) 150 | sensory_action = None 151 | if self.sensory_action_head: 152 | sensory_action = self.sensory_action_head(x) 153 | return motor_action, sensory_action 154 | 155 | class SelfPredictionNetwork(nn.Module): 156 | def __init__(self, env, sensory_action_set=None): 157 | super().__init__() 158 | if isinstance(env.single_action_space, Discrete): 159 | motor_action_space_size = env.single_action_space.n 160 | sensory_action_space_size = None 161 | elif isinstance(env.single_action_space, Dict): 162 | motor_action_space_size = env.single_action_space["motor_action"].n 163 | if sensory_action_set is not None: 164 | sensory_action_space_size = len(sensory_action_set) 165 | else: 166 | sensory_action_space_size = env.single_action_space["sensory_action"].n 167 | 168 | self.backbone = nn.Sequential( 169 | nn.Conv2d(8, 32, 8, stride=4), 170 | nn.ReLU(), 171 | nn.Conv2d(32, 64, 4, stride=2), 172 | nn.ReLU(), 173 | nn.Conv2d(64, 64, 3, stride=1), 174 | nn.ReLU(), 175 | nn.Flatten(), 176 | nn.Linear(3136, 512), 177 | nn.ReLU(), 178 | ) 179 | 180 | self.head = nn.Sequential( 181 | nn.Linear(512, motor_action_space_size), 182 | ) 183 | 184 | self.loss = nn.CrossEntropyLoss() 185 | 186 | def get_loss(self, x, target) -> torch.Tensor: 187 | return self.loss(x, target) 188 | 189 | 190 | def forward(self, x): 191 | x = self.backbone(x) 192 | x = self.head(x) 193 | return x 194 | 195 | 196 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int): 197 | slope = (end_e - start_e) / duration 198 | return max(slope * t + start_e, end_e) 199 | 200 | 201 | if __name__ == "__main__": 202 | args = parse_args() 203 | args.env = args.env.lower() 204 | run_name =
f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}" 205 | run_dir = os.path.join("runs", args.exp_name) 206 | if not os.path.exists(run_dir): 207 | os.makedirs(run_dir, exist_ok=True) 208 | 209 | writer = SummaryWriter(os.path.join(run_dir, run_name)) 210 | writer.add_text( 211 | "hyperparameters", 212 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), 213 | ) 214 | 215 | # TRY NOT TO MODIFY: seeding 216 | seed_everything(args.seed) 217 | 218 | 219 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 220 | 221 | # env setup 222 | envs = [] 223 | for i in range(args.env_num): 224 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat, 225 | fov_size=(args.fov_size, args.fov_size), 226 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc), 227 | peripheral_res=(args.peripheral_res, args.peripheral_res), 228 | sensory_action_mode=args.sensory_action_mode, 229 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space), 230 | resize_to_full=args.resize_to_full, 231 | clip_reward=args.clip_reward, 232 | mask_out=True)) 233 | # envs = gym.vector.AsyncVectorEnv(envs) 234 | envs = gym.vector.SyncVectorEnv(envs) 235 | 236 | sugarl_r_scale = get_sugarl_reward_scale_atari(args.env) 237 | 238 | resize = Resize((84, 84)) 239 | 240 | # get a discrete observ action space 241 | OBSERVATION_SIZE = (84, 84) 242 | observ_x_max, observ_y_max = OBSERVATION_SIZE[0]-args.fov_size, OBSERVATION_SIZE[1]-args.fov_size 243 | sensory_action_step = (observ_x_max//args.sensory_action_x_size, 244 | observ_y_max//args.sensory_action_y_size) 245 | sensory_action_x_set = list(range(0, observ_x_max, sensory_action_step[0]))[:args.sensory_action_x_size] 246 | sensory_action_y_set = list(range(0, observ_y_max, sensory_action_step[1]))[:args.sensory_action_y_size] 247 | sensory_action_set = [np.array(a) for a in list(product(sensory_action_x_set, sensory_action_y_set))] 248 | 249 | q_network = QNetwork(envs, sensory_action_set=sensory_action_set).to(device) 250 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate) 251 | target_network = QNetwork(envs, sensory_action_set=sensory_action_set).to(device) 252 | target_network.load_state_dict(q_network.state_dict()) 253 | 254 | sfn = SelfPredictionNetwork(envs, sensory_action_set=sensory_action_set).to(device) 255 | sfn_optimizer = optim.Adam(sfn.parameters(), lr=args.learning_rate) 256 | 257 | rb = DoubleActionReplayBuffer( 258 | args.buffer_size, 259 | envs.single_observation_space, 260 | envs.single_action_space["motor_action"], 261 | Discrete(len(sensory_action_set)), 262 | device, 263 | n_envs=envs.num_envs, 264 | optimize_memory_usage=True, 265 | handle_timeout_termination=False, 266 | ) 267 | start_time = time.time() 268 | 269 | # TRY NOT TO MODIFY: start the game 270 | obs, infos = envs.reset() 271 | global_transitions = 0 272 | pvm_buffer = PVMBuffer(args.pvm_stack, (envs.num_envs, args.frame_stack,)+OBSERVATION_SIZE) 273 | 274 | while global_transitions < args.total_timesteps: 275 | pvm_buffer.append(obs) 276 | pvm_obs = pvm_buffer.get_obs(mode="stack_max") 277 | # ALGO LOGIC: put action logic here 278 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_transitions) 279 | if random.random() < epsilon: 280 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) 281 | motor_actions = 
np.array([actions[0]["motor_action"]]) 282 | sensory_actions = np.array([random.randint(0, len(sensory_action_set)-1)]) 283 | else: 284 | motor_q_values, sensory_q_values = q_network(resize(torch.from_numpy(pvm_obs)).to(device)) 285 | motor_actions = torch.argmax(motor_q_values, dim=1).cpu().numpy() 286 | sensory_actions = torch.argmax(sensory_q_values, dim=1).cpu().numpy() 287 | 288 | # TRY NOT TO MODIFY: execute the game and log data. 289 | next_obs, rewards, dones, _, infos = envs.step({"motor_action": motor_actions, 290 | "sensory_action": [sensory_action_set[a] for a in sensory_actions] }) 291 | # print (global_step, infos) 292 | 293 | # TRY NOT TO MODIFY: record rewards for plotting purposes 294 | if "final_info" in infos: 295 | for idx, d in enumerate(dones): 296 | if d: 297 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]") 298 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions) 299 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions) 300 | writer.add_scalar("charts/epsilon", epsilon, global_transitions) 301 | break 302 | 303 | # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation` 304 | real_next_obs = next_obs 305 | for idx, d in enumerate(dones): 306 | if d: 307 | real_next_obs[idx] = infos["final_observation"][idx] 308 | pvm_buffer_copy = pvm_buffer.copy() 309 | pvm_buffer_copy.append(real_next_obs) 310 | real_next_pvm_obs = pvm_buffer_copy.get_obs(mode="stack_max") 311 | rb.add(pvm_obs, real_next_pvm_obs, motor_actions, sensory_actions, rewards, dones, {}) 312 | 313 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 314 | obs = next_obs 315 | 316 | # INC total transitions 317 | global_transitions += args.env_num 318 | 319 | obs_backup = obs # backup obs (restored after eval) 320 | 321 | if global_transitions < args.batch_size: 322 | continue 323 | 324 | # Training 325 | if global_transitions % args.train_frequency == 0: 326 | data = rb.sample(args.batch_size // args.env_num) # counter-balance the true global transitions used for training 327 | 328 | # sfn learning 329 | concat_observation = torch.concat([data.next_observations, data.observations], dim=1) # concat at dimension T 330 | pred_motor_actions = sfn(resize(concat_observation)) 331 | # print (pred_motor_actions.size(), data.motor_actions.size()) 332 | sfn_loss = sfn.get_loss(pred_motor_actions, data.motor_actions.flatten()) 333 | sfn_optimizer.zero_grad() 334 | sfn_loss.backward() 335 | sfn_optimizer.step() 336 | observ_r = F.softmax(pred_motor_actions, dim=1).gather(1, data.motor_actions).squeeze().detach() # probability of the taken motor action, in [0, 1] 337 | 338 | # Q network learning 339 | if global_transitions > args.learning_starts: 340 | with torch.no_grad(): 341 | motor_target, sensory_target = target_network(resize(data.next_observations)) 342 | motor_target_max, _ = motor_target.max(dim=1) 343 | sensory_target_max, _ = sensory_target.max(dim=1) 344 | # scale step-wise reward with observ_r: subtract the sugarl penalty (1 - observ_r) * sugarl_r_scale from the reward 345 | observ_r_adjusted = observ_r.clone() 346 | observ_r_adjusted[data.rewards.flatten() > 0] = 1 - observ_r_adjusted[data.rewards.flatten() > 0] 347 | td_target = data.rewards.flatten() - (1 - observ_r) * sugarl_r_scale + args.gamma * (motor_target_max+sensory_target_max) * (1 - data.dones.flatten()) 348 | original_td_target = data.rewards.flatten() + args.gamma * (motor_target_max+sensory_target_max) * (1 - data.dones.flatten()) 349 | 350 | old_motor_q_value, old_sensory_q_val =
q_network(resize(data.observations)) 351 | old_motor_val = old_motor_q_value.gather(1, data.motor_actions).squeeze() 352 | old_sensory_val = old_sensory_q_val.gather(1, data.sensory_actions).squeeze() 353 | old_val = old_motor_val + old_sensory_val 354 | 355 | loss = F.mse_loss(td_target, old_val) 356 | 357 | if global_transitions % 100 == 0: 358 | writer.add_scalar("losses/td_loss", loss, global_transitions) 359 | writer.add_scalar("losses/q_values", old_val.mean().item(), global_transitions) 360 | writer.add_scalar("losses/motor_q_values", old_motor_val.mean().item(), global_transitions) 361 | writer.add_scalar("losses/sensory_q_values", old_sensory_val.mean().item(), global_transitions) 362 | # print("SPS:", int(global_step / (time.time() - start_time))) 363 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions) 364 | 365 | writer.add_scalar("losses/sfn_loss", sfn_loss.item(), global_transitions) 366 | writer.add_scalar("losses/observ_r", observ_r.mean().item(), global_transitions) 367 | writer.add_scalar("losses/original_td_target", original_td_target.mean().item(), global_transitions) 368 | writer.add_scalar("losses/sugarl_r_scaled_td_target", td_target.mean().item(), global_transitions) 369 | 370 | # optimize the model 371 | optimizer.zero_grad() 372 | loss.backward() 373 | optimizer.step() 374 | 375 | # update the target network 376 | if (global_transitions // args.env_num) % args.target_network_frequency == 0: 377 | target_network.load_state_dict(q_network.state_dict()) 378 | 379 | # evaluation 380 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \ 381 | (global_transitions >= args.total_timesteps): 382 | q_network.eval() 383 | sfn.eval() 384 | 385 | eval_episodic_returns, eval_episodic_lengths = [], [] 386 | 387 | for eval_ep in range(args.eval_num): 388 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat, 389 | fov_size=(args.fov_size, args.fov_size), 390 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc), 391 | peripheral_res=(args.peripheral_res, args.peripheral_res), 392 | sensory_action_mode=args.sensory_action_mode, 393 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space), 394 | resize_to_full=args.resize_to_full, 395 | clip_reward=args.clip_reward, 396 | mask_out=True, 397 | training=False, 398 | record=args.capture_video)] 399 | eval_env = gym.vector.SyncVectorEnv(eval_env) 400 | obs_eval, _ = eval_env.reset() 401 | done = False 402 | pvm_buffer_eval = PVMBuffer(args.pvm_stack, (eval_env.num_envs, args.frame_stack,)+OBSERVATION_SIZE) 403 | while not done: 404 | pvm_buffer_eval.append(obs_eval) 405 | pvm_obs_eval = pvm_buffer_eval.get_obs(mode="stack_max") 406 | motor_q_values, sensory_q_values = q_network(resize(torch.from_numpy(pvm_obs_eval)).to(device)) 407 | motor_actions = torch.argmax(motor_q_values, dim=1).cpu().numpy() 408 | sensory_actions = torch.argmax(sensory_q_values, dim=1).cpu().numpy() 409 | next_obs_eval, rewards, dones, _, infos = eval_env.step({"motor_action": motor_actions, 410 | "sensory_action": [sensory_action_set[a] for a in sensory_actions]}) 411 | obs_eval = next_obs_eval 412 | done = dones[0] 413 | if done: 414 | eval_episodic_returns.append(infos['final_info'][0]["reward"]) 415 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"]) 416 | if args.capture_video: 417 | record_file_dir = os.path.join("recordings", args.exp_name,
os.path.basename(__file__).rstrip(".py"), args.env) 418 | os.makedirs(record_file_dir, exist_ok=True) 419 | record_file_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_eval{eval_ep:02d}_record.pt" 420 | eval_env.envs[0].save_record_to_file(os.path.join(record_file_dir, record_file_fn)) 421 | if eval_ep == 0: 422 | model_file_dir = os.path.join("trained_models", args.exp_name, os.path.basename(__file__).rstrip(".py"), args.env) 423 | os.makedirs(model_file_dir, exist_ok=True) 424 | model_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_model.pt" 425 | torch.save({"sfn": sfn.state_dict(), "q": q_network.state_dict()}, os.path.join(model_file_dir, model_fn)) 426 | 427 | 428 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions) 429 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions) 430 | # writer.add_scalar("charts/eval_episodic_length", np.mean(), global_transitions) 431 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]") 432 | 433 | q_network.train() 434 | sfn.train() 435 | 436 | obs = obs_backup # restore obs if eval occurs 437 | 438 | 439 | 440 | envs.close() 441 | eval_env.close() 442 | writer.close() --------------------------------------------------------------------------------