├── common
│   ├── __init__.py
│   ├── pvm_buffer.py
│   └── utils.py
├── scripts
│   ├── dmc_series.sh
│   ├── atari_series.sh
│   ├── atari_series_5m.sh
│   ├── atari_wp_series.sh
│   ├── atari_100k_base.sh
│   ├── robosuite_series.sh
│   ├── dmc_test.sh
│   ├── robosuite_test.sh
│   ├── atari_test.sh
│   ├── atari_100k_ns.sh
│   ├── dmc_20.sh
│   ├── dmc_30.sh
│   ├── dmc_50.sh
│   ├── atari_100k_20x20.sh
│   ├── atari_100k_30x30.sh
│   ├── atari_100k_50x50.sh
│   ├── atari_100k_wp20_20x20.sh
│   ├── atari_100k_wp20_30x30.sh
│   ├── atari_100k_wp20_50x50.sh
│   ├── atari_100k_5m.sh
│   ├── atari_100k_wp20_peripheral_only.sh
│   ├── atari_100k_20x20_5m.sh
│   ├── atari_100k_30x30_5m.sh
│   └── atari_100k_50x50_5m.sh
├── _doc
│   └── media
│       └── sugarl_formulation.png
├── .gitignore
├── active_rl_env.yaml
├── README.md
└── agent
    ├── dqn_atari_base.py
    ├── dqn_atari_wp_random.py
    ├── dqn_atari_wp_base_peripheral.py
    ├── dqn_atari_wp_single_policy.py
    ├── dqn_atari_single_policy.py
    ├── sac_atari_base.py
    ├── drqv2_dmc_base.py
    ├── dqn_atari_sugarl.py
    └── dqn_atari_wp_sugarl.py
/common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/scripts/dmc_series.sh:
--------------------------------------------------------------------------------
1 | ./scripts/dmc_20.sh $1 ;
2 | ./scripts/dmc_30.sh $1 ;
3 | ./scripts/dmc_50.sh $1 ;
--------------------------------------------------------------------------------
/_doc/media/sugarl_formulation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elicassion/sugarl/HEAD/_doc/media/sugarl_formulation.png
--------------------------------------------------------------------------------
/scripts/atari_series.sh:
--------------------------------------------------------------------------------
1 | ./scripts/atari_100k_20x20.sh $1 ;
2 | ./scripts/atari_100k_30x30.sh $1 ;
3 | ./scripts/atari_100k_50x50.sh $1 ;
--------------------------------------------------------------------------------
/scripts/atari_series_5m.sh:
--------------------------------------------------------------------------------
1 | ./scripts/atari_100k_20x20_5m.sh $1 ;
2 | ./scripts/atari_100k_30x30_5m.sh $1 ;
3 | ./scripts/atari_100k_50x50_5m.sh $1 ;
--------------------------------------------------------------------------------
/scripts/atari_wp_series.sh:
--------------------------------------------------------------------------------
1 | ./scripts/atari_100k_wp20_20x20.sh $1 ;
2 | ./scripts/atari_100k_wp20_30x30.sh $1 ;
3 | ./scripts/atari_100k_wp20_50x50.sh $1 ;
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | */__pycache__
2 | *.pyc
3 | *.txt
4 | runs
5 | runs/*
6 | logs
7 | logs/*
8 | recordings
9 | recordings/*
10 | trained_models
11 | trained_models/*
12 | */.ipynb_checkpoints
13 | *.egg-info
14 |
--------------------------------------------------------------------------------
/active_rl_env.yaml:
--------------------------------------------------------------------------------
1 | name: arl
2 |
3 | channels:
4 | - defaults
5 | - pytorch
6 | - nvidia
7 |
8 | dependencies:
9 | - python=3.9
10 | - pip
11 |
12 | # pytorch
13 | - pytorch==1.13.1
14 | - torchvision
15 | - torchaudio
16 | - pytorch-cuda=11.6
17 |
18 | - pip:
19 | - requests
20 | - joblib
21 | - psutil
22 | - h5py
23 | - lxml
24 | - colorama
25 |
26 | - jupyter # Used to show the notebook
27 | - jupyterlab
28 |
29 | - scipy
30 | - matplotlib # Used for vis
31 | - tqdm
32 |
33 | - mujoco<3.0
34 | - mujoco-py>=2.1.2.14
35 | - atari-py
36 | - dm-control==1.0.11
37 | - gymnasium>=0.28.1,<1.0.0
38 |
39 | - numpy<2.0.0
40 | - pandas
41 | - einops
42 | - opencv-python
43 |
44 | - tensorboard
45 | - tqdm
46 |
47 | - pybullet
48 | - ptflops # get model computation info
49 |
50 | - git+https://github.com/elicassion/active-gym.git
51 |
--------------------------------------------------------------------------------
/scripts/atari_100k_base.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # atari 100k settings
5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon
6 |
7 | expname="atari_100k_base"
8 | totaltimesteps="1000000"
9 | buffersize="100000"
10 | learningstarts="5000"
11 |
12 | mkdir -p logs/${expname}
13 | mkdir -p recordings/${expname}
14 | mkdir -p trained_models/${expname}
15 |
16 | for i in ${!envlist[@]}
17 | do
18 | gpuid=$(( $i % $numgpus ))
19 | (
20 | for seed in 0 1 2 3 4
21 | do
22 | echo "Atari100k GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}"
23 | # sleep 5
24 | basename=$(basename $1)
25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \
27 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
28 | --learning-starts $learningstarts \
29 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
30 | done
31 | ) &
32 | done
33 | wait
--------------------------------------------------------------------------------
/scripts/robosuite_series.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 | # numgpus=6
4 | gpustart=0
5 |
6 | # robosuite 1m settings
7 | tasklist=(Wipe Stack NutAssemblySquare Door Lift)
8 |
9 | expname="robosuite"
10 | totaltimesteps="100000"
11 | buffersize="100000"
12 | learningstarts="2000"
13 |
14 | mkdir -p logs/${expname}
15 | mkdir -p recordings/${expname}
16 | mkdir -p trained_models/${expname}
17 |
18 | jobnum=$(( 0 + gpustart ))
19 |
20 | for i in ${!tasklist[@]}
21 | do
22 | for seed in 0 1 2
23 | do
24 | gpuid=$(( $jobnum % $numgpus ))
25 | echo "${expname} GPU: ${gpuid} Env: ${tasklist[$i]} Seed: ${seed} ${1}"
26 | # sleep 5
27 | basename=$(basename $1)
28 | echo "========" >> logs/${expname}/${tasklist[$i]}__${basename}__${seed}.txt
29 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --task-name ${tasklist[$i]} \
30 | --seed $seed --exp-name $expname \
31 | --capture-video \
32 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
33 | --learning-starts $learningstarts \
34 | --eval-num 10 \
35 | ${@:3} >> logs/${expname}/${tasklist[$i]}__${basename}__${seed}.txt &
36 | ((jobnum=jobnum+1))
37 | done
38 | done
39 | wait
40 |
--------------------------------------------------------------------------------
/scripts/dmc_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # dmc 1m settings
5 | domainlist=(reacher) # reacher walker humanoid cheetah dog)
6 | tasklist=(easy) # hard walk walk run fetch)
7 |
8 | expname="dmc_test"
9 | totaltimesteps="1000"
10 | buffersize="500"
11 | learningstarts="300"
12 |
13 | mkdir -p logs/${expname}
14 | mkdir -p recordings/${expname}
15 | mkdir -p trained_models/${expname}
16 |
17 | for i in ${!domainlist[@]}
18 | do
19 | gpuid=$(( $i % $numgpus ))
20 | (
21 | for seed in 0
22 | do
23 | echo "${expname} GPU: ${gpuid} Env: ${domainlist[$i]}-${tasklist[$i]} Seed: ${seed} ${1}"
24 | # sleep 5
25 | basename=$(basename $1)
26 | echo "========" #>> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt
27 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --domain-name ${domainlist[$i]} --task-name ${tasklist[$i]} \
28 | --seed $seed --exp-name $expname \
29 | --fov-size 20 \
30 | --capture-video \
31 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
32 | --learning-starts $learningstarts \
33 | ${@:3} #>> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt
34 | done
35 | ) &
36 | done
37 | wait
--------------------------------------------------------------------------------
/scripts/robosuite_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | cudagpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 | # gpulist=($(seq 0 1 $cudagpus))
4 | gpulist=(3)
5 |
6 | numgpus=${#gpulist[@]}
7 | # robosuite 1m settings
8 | tasklist=(Door)
9 |
10 | expname="robosuite_test"
11 | totaltimesteps="1000"
12 | buffersize="500"
13 | learningstarts="300"
14 |
15 | mkdir -p logs/${expname}
16 | mkdir -p recordings/${expname}
17 | mkdir -p trained_models/${expname}
18 |
19 | for i in ${!tasklist[@]}
20 | do
21 | gpuindex=$(( $i % $numgpus ))
22 | gpuid=${gpulist[$gpuindex]}
23 | (
24 | for seed in 0
25 | do
26 | echo "${expname} GPU: ${gpuid} Env: ${tasklist[$i]} Seed: ${seed} ${1}"
27 | # sleep 5
28 | basename=$(basename $1)
29 | echo "========" # >> logs/${expname}/${tasklist[$i]}__${basename}__${seed}.txt
30 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --task-name ${tasklist[$i]} \
31 | --seed $seed --exp-name $expname \
32 | --capture-video \
33 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
34 | --learning-starts $learningstarts \
35 | --eval-num 1 \
36 | ${@:3} # >> logs/${expname}/${tasklist[$i]}__${basename}__${seed}.txt
37 | done
38 | ) &
39 | done
40 | wait
--------------------------------------------------------------------------------
/scripts/atari_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=1
3 |
4 | # atari 100k settings
5 | envlist=(alien amidar) #assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down #pong qbert seaquest zaxxon
6 |
7 | expname="test"
8 | totaltimesteps="1000"
9 | buffersize="500"
10 | learningstarts="300"
11 |
12 | mkdir -p logs/${expname}
13 | mkdir -p recordings/${expname}
14 | mkdir -p trained_models/${expname}
15 |
16 | for i in ${!envlist[@]}
17 | do
18 | gpuid=$(( $i % $numgpus ))
19 | (
20 | for seed in 0
21 | do
22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}"
23 | # sleep 5
24 | basename=$(basename $1)
25 | # echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
26 | CUDA_VISIBLE_DEVICES=3 python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \
27 | --fov-size 20 \
28 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
29 | --learning-starts $learningstarts \
30 | --capture-video \
31 | ${@:2} #>> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
32 | done
33 | ) &
34 | done
35 | wait
--------------------------------------------------------------------------------
/scripts/atari_100k_ns.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # atari 100k settings
5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon
6 |
7 | expname="atari_100k"
8 | totaltimesteps="1000000"
9 | buffersize="100000"
10 | learningstarts="80000"
11 |
12 | mkdir -p logs/${expname}
13 | mkdir -p recordings/${expname}
14 | mkdir -p trained_models/${expname}
15 |
16 | for i in ${!envlist[@]}
17 | do
18 | gpuid=$(( $i % $numgpus ))
19 | (
20 | for seed in 0 1 2 3 4
21 | do
22 | echo "Atari100k GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}"
23 | # sleep 5
24 | basename=$(basename $1)
25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \
27 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
28 | --capture-video \
29 | --clip-reward \
30 | --learning-starts $learningstarts \
31 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
32 | done
33 | ) &
34 | done
35 | wait
--------------------------------------------------------------------------------
/scripts/dmc_20.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # dmc 1m settings
5 | domainlist=(ball_in_cup cartpole cheetah dog fish walker)
6 | tasklist=(catch swingup run fetch swim walk)
7 |
8 | expname="dmc_20"
9 | totaltimesteps="100000"
10 | buffersize="100000"
11 | learningstarts="2000"
12 |
13 | mkdir -p logs/${expname}
14 | mkdir -p recordings/${expname}
15 | mkdir -p trained_models/${expname}
16 |
17 | jobnum=0
18 |
19 | for i in ${!domainlist[@]}
20 | do
21 | for seed in 0 1 2 3 4
22 | do
23 | gpuid=$(( $jobnum % $numgpus ))
24 | echo "${expname} GPU: ${gpuid} Env: ${domainlist[$i]}-${tasklist[$i]} Seed: ${seed} ${1}"
25 | # sleep 5
26 | basename=$(basename $1)
27 | echo "========" >> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt
28 | MUJOCO_EGL_DEVICE_ID=$gpuid CUDA_VISIBLE_DEVICES=$gpuid python $1 --domain-name ${domainlist[$i]} --task-name ${tasklist[$i]} \
29 | --seed $seed --exp-name $expname \
30 | --fov-size 20 \
31 | --capture-video \
32 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
33 | --learning-starts $learningstarts \
34 | --eval-frequency 10000 \
35 | ${@:3} >> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt &
36 | ((jobnum=jobnum+1))
37 | done
38 | done
39 | wait
--------------------------------------------------------------------------------
/scripts/dmc_30.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # dmc 1m settings
5 | domainlist=(ball_in_cup cartpole cheetah dog fish walker)
6 | tasklist=(catch swingup run fetch swim walk)
7 |
8 | expname="dmc_30"
9 | totaltimesteps="100000"
10 | buffersize="100000"
11 | learningstarts="2000"
12 |
13 | mkdir -p logs/${expname}
14 | mkdir -p recordings/${expname}
15 | mkdir -p trained_models/${expname}
16 |
17 | jobnum=0
18 |
19 | for i in ${!domainlist[@]}
20 | do
21 | for seed in 0 1 2 3 4
22 | do
23 | gpuid=$(( $jobnum % $numgpus ))
24 | echo "${expname} GPU: ${gpuid} Env: ${domainlist[$i]}-${tasklist[$i]} Seed: ${seed} ${1}"
25 | # sleep 5
26 | basename=$(basename $1)
27 | echo "========" >> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt
28 | MUJOCO_EGL_DEVICE_ID=$gpuid CUDA_VISIBLE_DEVICES=$gpuid python $1 --domain-name ${domainlist[$i]} --task-name ${tasklist[$i]} \
29 | --seed $seed --exp-name $expname \
30 | --fov-size 30 \
31 | --capture-video \
32 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
33 | --learning-starts $learningstarts \
34 | --eval-frequency 10000 \
35 | ${@:3} >> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt &
36 | ((jobnum=jobnum+1))
37 | done
38 | done
39 | wait
--------------------------------------------------------------------------------
/scripts/dmc_50.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # dmc 1m settings
5 | domainlist=(ball_in_cup cartpole cheetah dog fish walker)
6 | tasklist=(catch swingup run fetch swim walk)
7 |
8 | expname="dmc_50"
9 | totaltimesteps="100000"
10 | buffersize="100000"
11 | learningstarts="2000"
12 |
13 | mkdir -p logs/${expname}
14 | mkdir -p recordings/${expname}
15 | mkdir -p trained_models/${expname}
16 |
17 | jobnum=0
18 |
19 | for i in ${!domainlist[@]}
20 | do
21 | for seed in 0 1 2 3 4
22 | do
23 | gpuid=$(( $jobnum % $numgpus ))
24 | echo "${expname} GPU: ${gpuid} Env: ${domainlist[$i]}-${tasklist[$i]} Seed: ${seed} ${1}"
25 | # sleep 5
26 | basename=$(basename $1)
27 | echo "========" >> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt
28 | MUJOCO_EGL_DEVICE_ID=$gpuid CUDA_VISIBLE_DEVICES=$gpuid python $1 --domain-name ${domainlist[$i]} --task-name ${tasklist[$i]} \
29 | --seed $seed --exp-name $expname \
30 | --fov-size 50 \
31 | --capture-video \
32 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
33 | --learning-starts $learningstarts \
34 | --eval-frequency 10000 \
35 | ${@:3} >> logs/${expname}/${domainlist[$i]}-${tasklist[$i]}__${basename}__${seed}.txt &
36 | ((jobnum=jobnum+1))
37 | done
38 | done
39 | wait
--------------------------------------------------------------------------------
/scripts/atari_100k_20x20.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # atari 100k settings
5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon
6 |
7 | expname="atari_100k_20x20"
8 | totaltimesteps="1000000"
9 | buffersize="100000"
10 | learningstarts="80000"
11 |
12 | mkdir -p logs/${expname}
13 | mkdir -p recordings/${expname}
14 | mkdir -p trained_models/${expname}
15 |
16 | for i in ${!envlist[@]}
17 | do
18 | gpuid=$(( $i % $numgpus ))
19 | (
20 | for seed in 0 1 2 3 4
21 | do
22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}"
23 | # sleep 5
24 | basename=$(basename $1)
25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \
27 | --fov-size 20 \
28 | --clip-reward \
29 | --capture-video \
30 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
31 | --learning-starts $learningstarts \
32 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
33 | done
34 | ) &
35 | done
36 | wait
--------------------------------------------------------------------------------
/scripts/atari_100k_30x30.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # atari 100k settings
5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon
6 |
7 | expname="atari_100k_30x30"
8 | totaltimesteps="1000000"
9 | buffersize="100000"
10 | learningstarts="80000"
11 |
12 | mkdir -p logs/${expname}
13 | mkdir -p recordings/${expname}
14 | mkdir -p trained_models/${expname}
15 |
16 | for i in ${!envlist[@]}
17 | do
18 | gpuid=$(( $i % $numgpus ))
19 | (
20 | for seed in 0 1 2 3 4
21 | do
22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}"
23 | # sleep 5
24 | basename=$(basename $1)
25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \
27 | --fov-size 30 \
28 | --clip-reward \
29 | --capture-video \
30 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
31 | --learning-starts $learningstarts \
32 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
33 | done
34 | ) &
35 | done
36 | wait
--------------------------------------------------------------------------------
/scripts/atari_100k_50x50.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # atari 100k settings
5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon
6 |
7 | expname="atari_100k_50x50"
8 | totaltimesteps="1000000"
9 | buffersize="100000"
10 | learningstarts="80000"
11 |
12 | mkdir -p logs/${expname}
13 | mkdir -p recordings/${expname}
14 | mkdir -p trained_models/${expname}
15 |
16 | for i in ${!envlist[@]}
17 | do
18 | gpuid=$(( $i % $numgpus ))
19 | (
20 | for seed in 0 1 2 3 4
21 | do
22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}"
23 | # sleep 5
24 | basename=$(basename $1)
25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \
27 | --fov-size 50 \
28 | --clip-reward \
29 | --capture-video \
30 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
31 | --learning-starts $learningstarts \
32 | ${@:3} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
33 | done
34 | ) &
35 | done
36 | wait
--------------------------------------------------------------------------------
/scripts/atari_100k_wp20_20x20.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # atari 100k settings
5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon
6 |
7 | expname="atari_100k_wp20_20x20"
8 | totaltimesteps="1000000"
9 | buffersize="100000"
10 | learningstarts="80000"
11 |
12 | mkdir -p logs/${expname}
13 | mkdir -p recordings/${expname}
14 | mkdir -p trained_models/${expname}
15 |
16 | for i in ${!envlist[@]}
17 | do
18 | gpuid=$(( $i % $numgpus ))
19 | (
20 | for seed in 0 1 2 3 4
21 | do
22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}"
23 | # sleep 5
24 | basename=$(basename $1)
25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \
27 | --fov-size 20 \
28 | --peripheral-res 20 \
29 | --clip-reward \
30 | --capture-video \
31 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
32 | --learning-starts $learningstarts \
33 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
34 | done
35 | ) &
36 | done
37 | wait
--------------------------------------------------------------------------------
/scripts/atari_100k_wp20_30x30.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # atari 100k settings
5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon
6 |
7 | expname="atari_100k_wp20_30x30"
8 | totaltimesteps="1000000"
9 | buffersize="100000"
10 | learningstarts="80000"
11 |
12 | mkdir -p logs/${expname}
13 | mkdir -p recordings/${expname}
14 | mkdir -p trained_models/${expname}
15 |
16 | for i in ${!envlist[@]}
17 | do
18 | gpuid=$(( $i % $numgpus ))
19 | (
20 | for seed in 0 1 2 3 4
21 | do
22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}"
23 | # sleep 5
24 | basename=$(basename $1)
25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \
27 | --fov-size 30 \
28 | --peripheral-res 20 \
29 | --clip-reward \
30 | --capture-video \
31 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
32 | --learning-starts $learningstarts \
33 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
34 | done
35 | ) &
36 | done
37 | wait
--------------------------------------------------------------------------------
/scripts/atari_100k_wp20_50x50.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # atari 100k settings
5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon
6 |
7 | expname="atari_100k_wp20_50x50"
8 | totaltimesteps="1000000"
9 | buffersize="100000"
10 | learningstarts="80000"
11 |
12 | mkdir -p logs/${expname}
13 | mkdir -p recordings/${expname}
14 | mkdir -p trained_models/${expname}
15 |
16 | for i in ${!envlist[@]}
17 | do
18 | gpuid=$(( $i % $numgpus ))
19 | (
20 | for seed in 0 1 2 3 4
21 | do
22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}"
23 | # sleep 5
24 | basename=$(basename $1)
25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \
27 | --fov-size 50 \
28 | --peripheral-res 20 \
29 | --clip-reward \
30 | --capture-video \
31 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
32 | --learning-starts $learningstarts \
33 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
34 | done
35 | ) &
36 | done
37 | wait
--------------------------------------------------------------------------------
/scripts/atari_100k_5m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # atari 100k settings
5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon
6 |
7 | expname="atari_100k_5m"
8 | totaltimesteps="5000000"
9 | buffersize="100000"
10 | learningstarts="80000"
11 |
12 | mkdir -p logs/${expname}
13 | mkdir -p recordings/${expname}
14 | mkdir -p trained_models/${expname}
15 |
16 | for i in ${!envlist[@]}
17 | do
18 | gpuid=$(( $i % $numgpus ))
19 | (
20 | for seed in 0 1 2 3 4
21 | do
22 | echo "Atari100k GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}"
23 | # sleep 5
24 | basename=$(basename $1)
25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \
27 | --exploration-fraction 0.02 \
28 | --eval-frequency 1000000 \
29 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
30 | --capture-video \
31 | --clip-reward \
32 | --learning-starts $learningstarts \
33 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
34 | done
35 | ) &
36 | done
37 | wait
--------------------------------------------------------------------------------
/scripts/atari_100k_wp20_peripheral_only.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # atari 100k settings
5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon
6 |
7 | expname="atari_100k_wp20_peripheral_only"
8 | totaltimesteps="1000000"
9 | buffersize="100000"
10 | learningstarts="80000"
11 |
12 | mkdir -p logs/${expname}
13 | mkdir -p recordings/${expname}
14 | mkdir -p trained_models/${expname}
15 |
16 | for i in ${!envlist[@]}
17 | do
18 | gpuid=$(( $i % $numgpus ))
19 | (
20 | for seed in 0 1 2 3 4
21 | do
22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}"
23 | # sleep 5
24 | basename=$(basename $1)
25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \
27 | --fov-size 0 \
28 | --peripheral-res 20 \
29 | --clip-reward \
30 | --capture-video \
31 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
32 | --learning-starts $learningstarts \
33 | ${@:3} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
34 | done
35 | ) &
36 | done
37 | wait
--------------------------------------------------------------------------------
/scripts/atari_100k_20x20_5m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # atari 100k settings
5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon
6 |
7 | expname="atari_100k_20x20_5m"
8 | totaltimesteps="5000000"
9 | buffersize="100000"
10 | learningstarts="80000"
11 |
12 | mkdir -p logs/${expname}
13 | mkdir -p recordings/${expname}
14 | mkdir -p trained_models/${expname}
15 |
16 | for i in ${!envlist[@]}
17 | do
18 | gpuid=$(( $i % $numgpus ))
19 | (
20 | for seed in 0 1 2 3 4
21 | do
22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}"
23 | # sleep 5
24 | basename=$(basename $1)
25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \
27 | --exploration-fraction 0.02 \
28 | --eval-frequency 1000000 \
29 | --fov-size 20 \
30 | --clip-reward \
31 | --capture-video \
32 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
33 | --learning-starts $learningstarts \
34 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
35 | done
36 | ) &
37 | done
38 | wait
--------------------------------------------------------------------------------
/scripts/atari_100k_30x30_5m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # atari 100k settings
5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon
6 |
7 | expname="atari_100k_30x30_5m"
8 | totaltimesteps="5000000"
9 | buffersize="100000"
10 | learningstarts="80000"
11 |
12 | mkdir -p logs/${expname}
13 | mkdir -p recordings/${expname}
14 | mkdir -p trained_models/${expname}
15 |
16 | for i in ${!envlist[@]}
17 | do
18 | gpuid=$(( $i % $numgpus ))
19 | (
20 | for seed in 0 1 2 3 4
21 | do
22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}"
23 | # sleep 5
24 | basename=$(basename $1)
25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \
27 | --exploration-fraction 0.02 \
28 | --eval-frequency 1000000 \
29 | --fov-size 30 \
30 | --clip-reward \
31 | --capture-video \
32 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
33 | --learning-starts $learningstarts \
34 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
35 | done
36 | ) &
37 | done
38 | wait
--------------------------------------------------------------------------------
/scripts/atari_100k_50x50_5m.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | numgpus=${2:-$(nvidia-smi --list-gpus | wc -l)}
3 |
4 | # atari 100k settings
5 | envlist=(alien amidar assault asterix bank_heist battle_zone boxing breakout chopper_command crazy_climber demon_attack freeway frostbite gopher hero jamesbond kangaroo krull kung_fu_master ms_pacman pong private_eye qbert road_runner seaquest up_n_down) #pong qbert seaquest zaxxon
6 |
7 | expname="atari_100k_50x50_5m"
8 | totaltimesteps="5000000"
9 | buffersize="100000"
10 | learningstarts="80000"
11 |
12 | mkdir -p logs/${expname}
13 | mkdir -p recordings/${expname}
14 | mkdir -p trained_models/${expname}
15 |
16 | for i in ${!envlist[@]}
17 | do
18 | gpuid=$(( $i % $numgpus ))
19 | (
20 | for seed in 0 1 2 3 4
21 | do
22 | echo "${expname} GPU: ${gpuid} Env: ${envlist[$i]} Seed: ${seed} ${1}"
23 | # sleep 5
24 | basename=$(basename $1)
25 | echo "========" >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
26 | CUDA_VISIBLE_DEVICES=$gpuid python $1 --env ${envlist[$i]} --seed $seed --exp-name ${expname} \
27 | --exploration-fraction 0.02 \
28 | --eval-frequency 1000000 \
29 | --fov-size 50 \
30 | --clip-reward \
31 | --capture-video \
32 | --total-timesteps $totaltimesteps --buffer-size $buffersize \
33 | --learning-starts $learningstarts \
34 | ${@:2} >> logs/${expname}/${envlist[$i]}__${basename}__${seed}.txt
35 | done
36 | ) &
37 | done
38 | wait
--------------------------------------------------------------------------------
/common/pvm_buffer.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | from typing import Tuple
3 | import copy
4 |
5 | import numpy as np
6 |
7 | class PVMBuffer:
8 |
9 | def __init__(self, max_len: int, obs_size: Tuple, fov_loc_size: Tuple = None) -> None:
10 | self.max_len = max_len
11 | self.obs_size = obs_size
12 | self.fov_loc_size = fov_loc_size
13 | self.buffer = None
14 | self.fov_loc_buffer = None
15 | self.init_pvm_buffer()
16 |
17 |
18 | def init_pvm_buffer(self) -> None:
19 | self.buffer = deque([], maxlen=self.max_len)
20 | self.fov_loc_buffer = deque([], maxlen=self.max_len)
21 | for _ in range(self.max_len):
22 | self.buffer.append(np.zeros(self.obs_size, dtype=np.float32))
23 | if self.fov_loc_size is not None: # (1, 2) or (1, 2, 3)
24 | self.fov_loc_buffer.append(np.zeros(self.fov_loc_size, dtype=np.float32))
25 |
26 |
27 | def append(self, x, fov_loc=None) -> None:
28 | self.buffer.append(x)
29 | if fov_loc is not None:
30 | self.fov_loc_buffer.append(fov_loc)
31 |
32 | def copy(self):
33 | return copy.deepcopy(self)
34 |
35 | def get_obs(self, mode="stack_max") -> np.ndarray:
36 | if mode == "stack_max":
37 | return np.amax(np.stack(self.buffer, axis=1), axis=1) # stack to [B, T, C, H, W], then max over T -> [B, C, H, W]
38 | elif mode == "stack_mean":
39 | return np.mean(np.stack(self.buffer, axis=1), axis=1, keepdims=True) # leading dim is batch dim [B, 1, C, H, W]
40 | elif mode == "stack":
41 | # print ([x.shape for x in self.buffer])
42 | return np.stack(self.buffer, axis=1) # [B, T, C, H, W]
43 | elif mode == "stack_channel":
44 | return np.concatenate(self.buffer, axis=1) # [B, T*C, H, W]
45 | else:
46 | raise NotImplementedError
47 |
48 | def get_fov_locs(self, return_mask=False, relative_transform=False) -> np.ndarray:
49 | # print ([x.shape for x in self.fov_loc_buffer])
50 | transforms = []
51 | if relative_transform:
52 | for t in range(len(self.fov_loc_buffer)):
53 | transforms.append(np.zeros((self.fov_loc_buffer[t].shape[0], 2, 3), dtype=np.float32)) # [B, 2, 3] B=1 usually
54 | for b in range(len(self.fov_loc_buffer[t])):
55 | extrinsic = self.fov_loc_buffer[t][b]
56 | target_extrinsic = self.fov_loc_buffer[t][-1]
57 | # print (extrinsic, extrinsic.shape)
58 | if np.linalg.det(extrinsic):
59 | extrinsic_inv = np.linalg.inv(extrinsic)
60 | transform = np.dot(target_extrinsic, extrinsic_inv)
61 | # 4x4 transformation matrix
62 | R = transform[:3, :3] # Extract the rotation
63 | tr = transform[:3, 3] # Extract the translation
64 | H = R + np.outer(tr, np.array([0, 0, 1], dtype=np.float32))
65 | # Assuming H is the 3x3 homography matrix
66 | A = H / H[2, 2]
67 | affine = A[:2, :]
68 | transforms[t][b] = affine
69 | else:
70 | H = np.identity(3, dtype=np.float32)
71 | A = H / H[2, 2]
72 | affine = A[:2, :]
73 | transforms[t][b] = affine
74 | # print (transforms, [x.shape for x in transforms])
75 | return np.stack(transforms, axis=1) # [B, T, 2, 3]
76 | else:
77 | return np.stack(self.fov_loc_buffer, axis=1)
78 | #[B, T, *fov_locs_size], maybe [B, T, 2] for 2d or [B, T, 4, 4] for 3d
79 |
--------------------------------------------------------------------------------
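A quick usage sketch for `PVMBuffer` above (the shapes are illustrative assumptions: one env with four stacked 84x84 frames); the buffer keeps the last `max_len` observations and fuses them in `get_obs`:

```
# hypothetical usage of common/pvm_buffer.py, run from the repo root
import numpy as np
from common.pvm_buffer import PVMBuffer

pvm = PVMBuffer(max_len=3, obs_size=(1, 4, 84, 84))  # persistence-of-vision memory over 3 steps
for t in range(5):
    obs = np.random.rand(1, 4, 84, 84).astype(np.float32)  # stand-in for an env observation
    pvm.append(obs)

fused = pvm.get_obs(mode="stack_max")        # element-wise max over time -> (1, 4, 84, 84)
stacked = pvm.get_obs(mode="stack_channel")  # concatenate along channels -> (1, 12, 84, 84)
print(fused.shape, stacked.shape)
```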
/README.md:
--------------------------------------------------------------------------------
1 | # SUGARL
2 | Code for NeurIPS 2023 paper **Active Vision Reinforcement Learning under Limited Visual Observability**, by [Jinghuan Shang](https://www.cs.stonybrook.edu/~jishang) and [Michael S. Ryoo](http://michaelryoo.com/).
3 |
4 | We propose Sensorimotor Understanding Guided Active Reinforcement Learning (SUGARL) to solve ActiveVision-RL tasks.
5 | We also introduce [Active-Gym](https://github.com/elicassion/active-gym), a convenient library that modifies existing RL environments for ActiveVision-RL, with a Gymnasium-like interface.
6 |
7 | [[Paper]](https://arxiv.org/abs/2306.00975) [[Project Page]](https://elicassion.github.io/sugarl/sugarl.html) [[Active-Gym]](https://github.com/elicassion/active-gym)
8 |
9 |
10 |
11 | ## Dependency
12 | ```
13 | conda env create -f active_rl_env.yaml
14 | ```
15 | We highlight [Active-Gym](https://github.com/elicassion/active-gym), developed by us, which adds ActiveVision-RL support to many existing environments.
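As a quick check of the environment side, an Atari environment can be constructed the same way the agents do it (see `make_env` in `agent/dqn_atari_base.py`); a minimal sketch, assuming Active-Gym is installed via the yaml above:

```
from active_gym import AtariBaseEnv, AtariEnvArgs

env_args = AtariEnvArgs(game="breakout", seed=0, obs_size=(84, 84),
                        frame_stack=4, action_repeat=4)
env = AtariBaseEnv(env_args)
obs, info = env.reset()
obs, reward, done, truncated, info = env.step(env.action_space.sample())
```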
16 |
17 |
18 | ## Usage
19 | - General format:
20 | ```
21 | cd sugarl # make sure you are under the root dir of this repo
22 | bash ./scripts/<script_name>.sh agent/<agent_name>.py
23 | ```
24 |
25 | - Reproduce our experiments:
26 | ```
27 | cd sugarl # make sure you are under the root dir of this repo
28 | bash ./scripts/robosuite_series.sh agent/<agent_name>.py
29 | bash ./scripts/atari_series.sh agent/<agent_name>.py
30 | bash ./scripts/atari_series_5m.sh agent/<agent_name>.py
31 | bash ./scripts/atari_wp_series.sh agent/<agent_name>.py
32 | bash ./scripts/dmc_series.sh agent/<agent_name>.py
33 | ```
34 | For example, to run SUGARL-DQN on Atari
35 | ```
36 | bash ./scripts/atari_series.sh agent/dqn_atari_sugarl.py
37 | ```
38 |
39 | - Sanity checks: these run through the whole pipeline with only a tiny amount of training to catch bugs quickly
40 | ```
41 | cd sugarl # make sure you are under the root dir of this repo
42 | bash ./scripts/atari_test.sh agent/<agent_name>.py
43 | bash ./scripts/dmc_test.sh agent/<agent_name>.py
44 | bash ./scripts/robosuite_test.sh agent/<agent_name>.py
45 | ```
46 |
47 | All experiment scripts automatically scale all tasks across your GPUs. Please modify the GPU behavior (`CUDA_VISIBLE_DEVICES=`) in the script if
48 | - you want to run jobs on certain GPUs only (see the sketch below)
49 | - VRAM or RAM is not sufficient for running all jobs at once
50 |
51 | In the provided scripts, the 26 Atari games run in parallel, with the seeds of each game executed sequentially. The 6 DMC environments x 5 seeds all run in parallel. Please check the available RAM and VRAM on your machine before starting.
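If you only want to use certain GPUs, a minimal sketch of one way to adapt the launch scripts, mirroring the `gpulist` pattern already used in `scripts/robosuite_test.sh` (the GPU ids below are illustrative):

```
# near the top of the script: list the GPU ids you want to use
gpulist=(1 3)
numgpus=${#gpulist[@]}

# inside the launch loop: index into gpulist instead of using the raw counter
gpuid=${gpulist[$(( i % numgpus ))]}
```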
52 |
53 | ### Notes
54 | **Naming**:
55 |
56 | All agents are under `agent/`, with the name format `<base_algorithm>_<env>_<method>.py`. Each file is a self-contained entry point for the whole training process. We support DQN, SAC, and DrQv2 as base algorithms.
57 |
58 | All experiment scripts are under `scripts/`, with the name format `<env>_<setting>.sh`.
59 | Please ensure that the env and setting match the agent when launching jobs.
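For example, an Atari setting script paired with an Atari agent:
```
bash ./scripts/atari_100k_20x20.sh agent/dqn_atari_sugarl.py
```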
60 |
61 | **Resource requirement reference (SUGARL)**:
62 |
63 | - Atari:
64 | for each game with `100k` replay buffer: `~18G` RAM, `<2G` VRAM
65 |
66 | - DMC:
67 | for each task with `100k` replay buffer: `~18G` RAM, `<3G` VRAM
68 |
69 | - Robosuite:
70 | for each task with `100k` replay buffer: `~54G` RAM, `4.2G` VRAM
71 |
72 | **Coding style**:
73 |
74 | We follow the coding style of [clean-rl](https://github.com/vwxyzjn/cleanrl) so that modifications to one agent do not affect the others. This does introduce a lot of redundancy, but it makes arranging experiments and evolving the algorithm much easier.
75 |
76 | ## Citation
77 | Please consider citing us if you find this repo helpful.
78 | ```
79 | @article{shang2023active,
80 | title={Active Reinforcement Learning under Limited Visual Observability},
81 | author={Jinghuan Shang and Michael S. Ryoo},
82 | journal={arXiv preprint},
83 | year={2023},
84 | eprint={2306.00975},
85 | }
86 | ```
87 |
88 | ## Acknowledgement
89 | We thank the authors of [clean-rl](https://github.com/vwxyzjn/cleanrl) for their implementation, which our code builds on.
90 |
--------------------------------------------------------------------------------
/agent/dqn_atari_base.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import random
5 | import time
6 | from distutils.util import strtobool
7 |
8 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__))))
9 | os.environ["OMP_NUM_THREADS"] = "1"
10 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
11 | import warnings
12 | warnings.filterwarnings("ignore", category=UserWarning)
13 |
14 | import gymnasium as gym
15 | import numpy as np
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | import torch.optim as optim
20 | from common.buffer import ReplayBuffer
21 | from common.utils import get_timestr, seed_everything
22 | from torch.utils.tensorboard import SummaryWriter
23 |
24 | from active_gym import AtariBaseEnv, AtariEnvArgs
25 |
26 |
27 | def parse_args():
28 | # fmt: off
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__),
31 | help="the name of this experiment")
32 | parser.add_argument("--seed", type=int, default=1,
33 | help="seed of the experiment")
34 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
35 | help="if toggled, cuda will be enabled by default")
36 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
37 | help="whether to capture videos of the agent performances (check out `videos` folder)")
38 |
39 | # env setting
40 | parser.add_argument("--env", type=str, default="breakout",
41 | help="the id of the environment")
42 | parser.add_argument("--env-num", type=int, default=1,
43 | help="# envs in parallel")
44 | parser.add_argument("--frame-stack", type=int, default=4,
45 | help="frame stack #")
46 | parser.add_argument("--action-repeat", type=int, default=4,
47 | help="action repeat #")
48 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True)
49 |
50 | # Algorithm specific arguments
51 | parser.add_argument("--total-timesteps", type=int, default=3000000,
52 | help="total timesteps of the experiments")
53 | parser.add_argument("--learning-rate", type=float, default=1e-4,
54 | help="the learning rate of the optimizer")
55 | parser.add_argument("--buffer-size", type=int, default=500000,
56 | help="the replay memory buffer size")
57 | parser.add_argument("--gamma", type=float, default=0.99,
58 | help="the discount factor gamma")
59 | parser.add_argument("--target-network-frequency", type=int, default=1000,
60 | help="the timesteps it takes to update the target network")
61 | parser.add_argument("--batch-size", type=int, default=32,
62 | help="the batch size of samples drawn from the replay memory")
63 | parser.add_argument("--start-e", type=float, default=1,
64 | help="the starting epsilon for exploration")
65 | parser.add_argument("--end-e", type=float, default=0.01,
66 | help="the ending epsilon for exploration")
67 | parser.add_argument("--exploration-fraction", type=float, default=0.10,
68 | help="the fraction of `total-timesteps` it takes to go from start-e to end-e")
69 | parser.add_argument("--learning-starts", type=int, default=80000,
70 | help="timestep to start learning")
71 | parser.add_argument("--train-frequency", type=int, default=4,
72 | help="the frequency of training")
73 |
74 | # eval args
75 | parser.add_argument("--eval-frequency", type=int, default=-1,
76 | help="eval frequency. default -1 is eval at the end.")
77 | parser.add_argument("--eval-num", type=int, default=10,
78 | help="the number of episodes per evaluation")
79 | args = parser.parse_args()
80 | # fmt: on
81 | return args
82 |
83 |
84 | def make_env(env_name, seed, **kwargs):
85 | def thunk():
86 | env_args = AtariEnvArgs(
87 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs
88 | )
89 | env = AtariBaseEnv(env_args)
90 | env.action_space.seed(seed)
91 | env.observation_space.seed(seed)
92 | return env
93 |
94 | return thunk
95 |
96 |
97 | # ALGO LOGIC: initialize agent here:
98 | class QNetwork(nn.Module):
99 | def __init__(self, env):
100 | super().__init__()
101 | self.network = nn.Sequential(
102 | nn.Conv2d(4, 32, 8, stride=4),
103 | nn.ReLU(),
104 | nn.Conv2d(32, 64, 4, stride=2),
105 | nn.ReLU(),
106 | nn.Conv2d(64, 64, 3, stride=1),
107 | nn.ReLU(),
108 | nn.Flatten(),
109 | nn.Linear(3136, 512),
110 | nn.ReLU(),
111 | nn.Linear(512, env.single_action_space.n),
112 | )
113 |
114 | def forward(self, x):
115 | return self.network(x)
116 |
117 |
118 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
119 | slope = (end_e - start_e) / duration
120 | return max(slope * t + start_e, end_e)
121 |
122 |
123 | if __name__ == "__main__":
124 | args = parse_args()
125 | args.env = args.env.lower()
126 | run_name = f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}"
127 | run_dir = os.path.join("runs", args.exp_name)
128 | if not os.path.exists(run_dir):
129 | os.makedirs(run_dir, exist_ok=True)
130 |
131 | writer = SummaryWriter(os.path.join(run_dir, run_name))
132 | writer.add_text(
133 | "hyperparameters",
134 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
135 | )
136 |
137 | # TRY NOT TO MODIFY: seeding
138 | seed_everything(args.seed)
139 |
140 |
141 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
142 |
143 | # env setup
144 | envs = []
145 | for i in range(args.env_num):
146 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat, clip_reward=args.clip_reward))
147 | # envs = gym.vector.AsyncVectorEnv(envs)
148 | envs = gym.vector.SyncVectorEnv(envs)
149 | # async 4 100s - 24k
150 | # async 2 47s - 50k
151 | # async 1 60s - 50k
152 | # sync 1 50s - 50k
153 | # sync 2 46s - 50k
154 |
155 | q_network = QNetwork(envs).to(device)
156 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
157 | target_network = QNetwork(envs).to(device)
158 | target_network.load_state_dict(q_network.state_dict())
159 |
160 | rb = ReplayBuffer(
161 | args.buffer_size,
162 | envs.single_observation_space,
163 | envs.single_action_space,
164 | device,
165 | n_envs=envs.num_envs,
166 | optimize_memory_usage=True,
167 | handle_timeout_termination=False,
168 | )
169 | start_time = time.time()
170 |
171 | # TRY NOT TO MODIFY: start the game
172 | obs, _ = envs.reset()
173 | global_transitions = 0
174 | while global_transitions < args.total_timesteps:
175 | # ALGO LOGIC: put action logic here
176 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_transitions)
177 | if random.random() < epsilon:
178 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
179 | else:
180 | q_values = q_network(torch.from_numpy(obs).to(device))
181 | actions = torch.argmax(q_values, dim=1).cpu().numpy()
182 |
183 | # TRY NOT TO MODIFY: execute the game and log data.
184 | next_obs, rewards, dones, _, infos = envs.step(actions)
185 | # print (global_step, infos)
186 |
187 | # TRY NOT TO MODIFY: record rewards for plotting purposes
188 | if "final_info" in infos:
189 | for idx, d in enumerate(dones):
190 | if d:
191 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]")
192 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions)
193 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions)
194 | writer.add_scalar("charts/epsilon", epsilon, global_transitions)
195 | break
196 |
197 | # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
198 | real_next_obs = next_obs
199 | for idx, d in enumerate(dones):
200 | if d:
201 | real_next_obs[idx] = infos["final_observation"][idx]
202 | rb.add(obs, real_next_obs, actions, rewards, dones, {})
203 |
204 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
205 | obs = next_obs
206 |
207 | # INC total transitions
208 | global_transitions += args.env_num
209 |
210 | # ALGO LOGIC: training.
211 | if global_transitions > args.learning_starts:
212 | if global_transitions % args.train_frequency == 0:
213 | data = rb.sample(args.batch_size // args.env_num) # counter-balance the true global transitions used for training
214 | with torch.no_grad():
215 | target_max, _ = target_network(data.next_observations).max(dim=1)
216 | td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
217 | old_val = q_network(data.observations).gather(1, data.actions).squeeze()
218 | loss = F.mse_loss(td_target, old_val)
219 |
220 | if global_transitions % 100 == 0:
221 | writer.add_scalar("losses/td_loss", loss.item(), global_transitions)
222 | writer.add_scalar("losses/q_values", old_val.mean().item(), global_transitions)
223 | # print("SPS:", int(global_step / (time.time() - start_time)))
224 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions)
225 |
226 | # optimize the model
227 | optimizer.zero_grad()
228 | loss.backward()
229 | optimizer.step()
230 |
231 | # update the target network
232 | if (global_transitions // args.env_num) % args.target_network_frequency == 0:
233 | target_network.load_state_dict(q_network.state_dict())
234 |
235 | # evaluation
236 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \
237 | (global_transitions >= args.total_timesteps):
238 | q_network.eval()
239 |
240 | eval_episodic_returns, eval_episodic_lengths = [], []
241 |
242 | for eval_ep in range(args.eval_num):
243 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat, clip_reward=args.clip_reward, training=False, record=args.capture_video)]
244 | eval_env = gym.vector.SyncVectorEnv(eval_env)
245 | obs_eval, _ = eval_env.reset()
246 | done = False
247 | while not done:
248 | q_values = q_network(torch.from_numpy(obs_eval).to(device))
249 | actions = torch.argmax(q_values, dim=1).cpu().numpy()
250 | next_obs_eval, rewards, dones, _, infos = eval_env.step(actions)
251 | obs_eval = next_obs_eval
252 | done = dones[0]
253 | if done:
254 | eval_episodic_returns.append(infos['final_info'][0]["reward"])
255 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"])
256 |
257 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions)
258 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions)
259 | # writer.add_scalar("charts/eval_episodic_length", np.mean(eval_episodic_lengths), global_transitions)
260 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]")
261 |
262 | q_network.train()
263 |
264 |
265 |
266 | envs.close()
267 | eval_env.close()
268 | writer.close()
--------------------------------------------------------------------------------
/agent/dqn_atari_wp_random.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import random
5 | import time
6 | from itertools import product
7 | from distutils.util import strtobool
8 |
9 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__))))
10 | os.environ["OMP_NUM_THREADS"] = "1"
11 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
12 | import warnings
13 | warnings.filterwarnings("ignore", category=UserWarning)
14 |
15 | import gymnasium as gym
16 | from gymnasium.spaces import Discrete, Dict
17 | import numpy as np
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | import torch.optim as optim
22 | from torchvision.transforms import Resize
23 |
24 | from common.buffer import ReplayBuffer
25 | from common.utils import get_timestr, seed_everything
26 | from torch.utils.tensorboard import SummaryWriter
27 |
28 | from active_gym import AtariFixedFovealPeripheralEnv, AtariEnvArgs
29 |
30 |
31 | def parse_args():
32 | # fmt: off
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument("--exp-name", type=str, default=os.path.splitext(os.path.basename(__file__))[0],
35 | help="the name of this experiment")
36 | parser.add_argument("--seed", type=int, default=1,
37 | help="seed of the experiment")
38 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
39 | help="if toggled, cuda will be enabled by default")
40 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
41 | help="whether to capture videos of the agent performances (check out `videos` folder)")
42 |
43 | # env setting
44 | parser.add_argument("--env", type=str, default="breakout",
45 | help="the id of the environment")
46 | parser.add_argument("--env-num", type=int, default=1,
47 | help="# envs in parallel")
48 | parser.add_argument("--frame-stack", type=int, default=4,
49 | help="frame stack #")
50 | parser.add_argument("--action-repeat", type=int, default=4,
51 | help="action repeat #")
52 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True)
53 |
54 | # fov setting
55 | parser.add_argument("--fov-size", type=int, default=50)
56 | parser.add_argument("--fov-init-loc", type=int, default=0)
57 | parser.add_argument("--sensory-action-mode", type=str, default="absolute")
58 | parser.add_argument("--sensory-action-space", type=int, default=10) # ignored when sensory_action_mode="relative"
59 | parser.add_argument("--resize-to-full", default=False, action="store_true")
60 | parser.add_argument("--peripheral-res", type=int, default=20)
61 | # for discrete observ action
62 | parser.add_argument("--sensory-action-x-size", type=int, default=4)
63 | parser.add_argument("--sensory-action-y-size", type=int, default=4)
64 |
65 | # Algorithm specific arguments
66 | parser.add_argument("--total-timesteps", type=int, default=3000000,
67 | help="total timesteps of the experiments")
68 | parser.add_argument("--learning-rate", type=float, default=1e-4,
69 | help="the learning rate of the optimizer")
70 | parser.add_argument("--buffer-size", type=int, default=500000,
71 | help="the replay memory buffer size")
72 | parser.add_argument("--gamma", type=float, default=0.99,
73 | help="the discount factor gamma")
74 | parser.add_argument("--target-network-frequency", type=int, default=1000,
75 | help="the timesteps it takes to update the target network")
76 | parser.add_argument("--batch-size", type=int, default=32,
77 | help="the batch size of samples drawn from the replay memory")
78 | parser.add_argument("--start-e", type=float, default=1,
79 | help="the starting epsilon for exploration")
80 | parser.add_argument("--end-e", type=float, default=0.01,
81 | help="the ending epsilon for exploration")
82 | parser.add_argument("--exploration-fraction", type=float, default=0.10,
83 | help="the fraction of `total-timesteps` it takes to go from start-e to end-e")
84 | parser.add_argument("--learning-starts", type=int, default=80000,
85 | help="timestep to start learning")
86 | parser.add_argument("--train-frequency", type=int, default=4,
87 | help="the frequency of training")
88 |
89 | # eval args
90 | parser.add_argument("--eval-frequency", type=int, default=-1,
91 | help="eval frequency. default -1 is eval at the end.")
92 | parser.add_argument("--eval-num", type=int, default=10,
93 | help="the number of episodes per evaluation")
94 | args = parser.parse_args()
95 | # fmt: on
96 | return args
97 |
98 |
99 | def make_env(env_name, seed, **kwargs):
100 | def thunk():
101 | env_args = AtariEnvArgs(
102 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs
103 | )
104 | env = AtariFixedFovealPeripheralEnv(env_args)
105 | env.action_space.seed(seed)
106 | env.observation_space.seed(seed)
107 | return env
108 |
109 | return thunk
110 |
111 |
112 | # ALGO LOGIC: initialize agent here:
113 | class QNetwork(nn.Module):
114 | def __init__(self, env):
115 | super().__init__()
116 | if isinstance(env.single_action_space, Discrete):
117 | action_space_size = env.single_action_space.n
118 | elif isinstance(env.single_action_space, Dict):
119 | action_space_size = env.single_action_space["motor_action"].n
120 | self.network = nn.Sequential(
121 | nn.Conv2d(4, 32, 8, stride=4),
122 | nn.ReLU(),
123 | nn.Conv2d(32, 64, 4, stride=2),
124 | nn.ReLU(),
125 | nn.Conv2d(64, 64, 3, stride=1),
126 | nn.ReLU(),
127 | nn.Flatten(),
128 |             nn.Linear(3136, 512),  # 64 * 7 * 7 = 3136 for an 84x84 input
129 | nn.ReLU(),
130 | nn.Linear(512, action_space_size),
131 | )
132 |
133 | def forward(self, x):
134 | return self.network(x)
135 |
136 |
137 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
138 | slope = (end_e - start_e) / duration
139 | return max(slope * t + start_e, end_e)
140 |
141 |
142 | if __name__ == "__main__":
143 | args = parse_args()
144 | args.env = args.env.lower()
145 | run_name = f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}"
146 | run_dir = os.path.join("runs", args.exp_name)
147 | if not os.path.exists(run_dir):
148 | os.makedirs(run_dir, exist_ok=True)
149 |
150 | writer = SummaryWriter(os.path.join(run_dir, run_name))
151 | writer.add_text(
152 | "hyperparameters",
153 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
154 | )
155 |
156 | # TRY NOT TO MODIFY: seeding
157 | seed_everything(args.seed)
158 |
159 |
160 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
161 |
162 | # env setup
163 | envs = []
164 | for i in range(args.env_num):
165 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat,
166 | fov_size=(args.fov_size, args.fov_size),
167 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc),
168 | peripheral_res=(args.peripheral_res, args.peripheral_res),
169 | sensory_action_mode=args.sensory_action_mode,
170 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space),
171 | resize_to_full=args.resize_to_full,
172 | clip_reward=args.clip_reward))
173 | # envs = gym.vector.AsyncVectorEnv(envs)
174 | envs = gym.vector.SyncVectorEnv(envs)
175 |
176 | resize = Resize((84, 84))
177 |
178 | # get a discrete observ action space
179 | OBSERVATION_SIZE = (84, 84)
180 | observ_x_max, observ_y_max = OBSERVATION_SIZE[0]-args.fov_size, OBSERVATION_SIZE[1]-args.fov_size
181 | sensory_action_step = (observ_x_max//args.sensory_action_x_size,
182 | observ_y_max//args.sensory_action_y_size)
183 | sensory_action_x_set = list(range(0, observ_x_max, sensory_action_step[0]))[:args.sensory_action_x_size]
184 | sensory_action_y_set = list(range(0, observ_y_max, sensory_action_step[1]))[:args.sensory_action_y_size]
185 | sensory_action_set = list(product(sensory_action_x_set, sensory_action_y_set))
186 |
187 | q_network = QNetwork(envs).to(device)
188 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
189 | target_network = QNetwork(envs).to(device)
190 | target_network.load_state_dict(q_network.state_dict())
191 |
192 | rb = ReplayBuffer(
193 | args.buffer_size,
194 | envs.single_observation_space,
195 | envs.single_action_space["motor_action"],
196 | device,
197 | n_envs=envs.num_envs,
198 | optimize_memory_usage=True,
199 | handle_timeout_termination=False,
200 | )
201 | start_time = time.time()
202 |
203 | # TRY NOT TO MODIFY: start the game
204 | obs, _ = envs.reset()
205 | global_transitions = 0
206 | while global_transitions < args.total_timesteps:
207 | # ALGO LOGIC: put action logic here
208 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_transitions)
209 | if random.random() < epsilon:
210 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
211 | motor_actions = np.array([actions[0]["motor_action"]])
212 | else:
213 | q_values = q_network(resize(torch.from_numpy(obs)).to(device))
214 | motor_actions = torch.argmax(q_values, dim=1).cpu().numpy()
215 |
216 | # TRY NOT TO MODIFY: execute the game and log data.
217 | next_obs, rewards, dones, _, infos = envs.step({"motor_action": motor_actions,
218 | "sensory_action": [sensory_action_set[random.randint(0, len(sensory_action_set)-1)]] })
219 | # print (global_step, infos)
220 |
221 | # TRY NOT TO MODIFY: record rewards for plotting purposes
222 | if "final_info" in infos:
223 | for idx, d in enumerate(dones):
224 | if d:
225 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]")
226 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions)
227 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions)
228 | writer.add_scalar("charts/epsilon", epsilon, global_transitions)
229 | break
230 |
231 |         # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
232 | real_next_obs = next_obs
233 | for idx, d in enumerate(dones):
234 | if d:
235 | real_next_obs[idx] = infos["final_observation"][idx]
236 | rb.add(obs, real_next_obs, motor_actions, rewards, dones, {})
237 |
238 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
239 | obs = next_obs
240 |
241 | # INC total transitions
242 | global_transitions += args.env_num
243 |
244 | # ALGO LOGIC: training.
245 | if global_transitions > args.learning_starts:
246 | if global_transitions % args.train_frequency == 0:
247 | data = rb.sample(args.batch_size // args.env_num) # counter-balance the true global transitions used for training
248 | with torch.no_grad():
249 | target_max, _ = target_network(resize(data.next_observations)).max(dim=1)
250 | td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
251 | old_val = q_network(resize(data.observations)).gather(1, data.actions).squeeze()
252 | loss = F.mse_loss(td_target, old_val)
253 |
254 | if global_transitions % 100 == 0:
255 | writer.add_scalar("losses/td_loss", loss, global_transitions)
256 | writer.add_scalar("losses/q_values", old_val.mean().item(), global_transitions)
257 | # print("SPS:", int(global_step / (time.time() - start_time)))
258 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions)
259 |
260 | # optimize the model
261 | optimizer.zero_grad()
262 | loss.backward()
263 | optimizer.step()
264 |
265 | # update the target network
266 | if (global_transitions // args.env_num) % args.target_network_frequency == 0:
267 | target_network.load_state_dict(q_network.state_dict())
268 |
269 | # evaluation
270 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \
271 | (global_transitions >= args.total_timesteps):
272 | q_network.eval()
273 |
274 | eval_episodic_returns, eval_episodic_lengths = [], []
275 |
276 | for eval_ep in range(args.eval_num):
277 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat,
278 | fov_size=(args.fov_size, args.fov_size),
279 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc),
280 | peripheral_res=(args.peripheral_res, args.peripheral_res),
281 | sensory_action_mode=args.sensory_action_mode,
282 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space),
283 | resize_to_full=args.resize_to_full,
284 | clip_reward=args.clip_reward,
285 | training=False)]
286 | eval_env = gym.vector.SyncVectorEnv(eval_env)
287 | obs, _ = eval_env.reset()
288 | done = False
289 | while not done:
290 | q_values = q_network(resize(torch.from_numpy(obs)).to(device))
291 | motor_actions = torch.argmax(q_values, dim=1).cpu().numpy()
292 | next_obs, rewards, dones, _, infos = eval_env.step({"motor_action": motor_actions,
293 | "sensory_action": [sensory_action_set[random.randint(0, len(sensory_action_set)-1)]]})
294 | obs = next_obs
295 | done = dones[0]
296 | if done:
297 | eval_episodic_returns.append(infos['final_info'][0]["reward"])
298 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"])
299 |
300 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions)
301 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions)
302 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]")
303 |
304 | q_network.train()
305 |
306 |
307 |
308 | envs.close()
309 | eval_env.close()
310 | writer.close()
--------------------------------------------------------------------------------
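
Note: with this script's defaults (--fov-size 50 and a 4x4 sensory grid over 84x84 observations), the discrete sensory-action set built above works out to 16 candidate fixation corners, one of which is sampled uniformly at random each step. A minimal standalone sketch of that computation (names mirror the script but are purely illustrative):

    from itertools import product

    OBSERVATION_SIZE = (84, 84)
    fov_size = 50                                   # default --fov-size
    sensory_action_x_size = sensory_action_y_size = 4

    observ_x_max = OBSERVATION_SIZE[0] - fov_size   # 34
    observ_y_max = OBSERVATION_SIZE[1] - fov_size   # 34
    step_x = observ_x_max // sensory_action_x_size  # 8
    step_y = observ_y_max // sensory_action_y_size  # 8
    xs = list(range(0, observ_x_max, step_x))[:sensory_action_x_size]  # [0, 8, 16, 24]
    ys = list(range(0, observ_y_max, step_y))[:sensory_action_y_size]  # [0, 8, 16, 24]
    sensory_action_set = list(product(xs, ys))
    print(len(sensory_action_set))                  # 16 fixation corners
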
/common/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Borrow from stable-baselines3
3 | Due to dependencies incompability, we cherry-pick codes here
4 | """
5 | import os, random, re
6 | from datetime import datetime
7 | import warnings
8 | from typing import Dict, Tuple, Union
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | from torch.nn import functional as F
14 | from torch import distributions as pyd
15 | from torch.distributions.utils import _standard_normal
16 |
17 | from gymnasium import spaces
18 |
19 | def seed_everything(seed):
20 | random.seed(seed)
21 | os.environ['PYTHONHASHSEED']=str(seed)
22 | np.random.seed(seed)
23 | torch.manual_seed(seed)
24 | torch.cuda.manual_seed_all(seed)
25 | torch.backends.cudnn.benchmark = False
26 | torch.backends.cudnn.deterministic = True
27 |
28 | def is_image_space_channels_first(observation_space: spaces.Box) -> bool:
29 | """
30 | Check if an image observation space (see ``is_image_space``)
31 | is channels-first (CxHxW, True) or channels-last (HxWxC, False).
32 |     Use a heuristic that the channel dimension is the smallest of the three.
33 |     If the second dimension is smallest, a warning is issued and the space is treated as channels-last.
34 | :param observation_space:
35 | :return: True if observation space is channels-first image, False if channels-last.
36 | """
37 | smallest_dimension = np.argmin(observation_space.shape).item()
38 | if smallest_dimension == 1:
39 | warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.")
40 | return smallest_dimension == 0
41 |
42 |
43 | def is_image_space(
44 | observation_space: spaces.Space,
45 | check_channels: bool = False,
46 | normalized_image: bool = False,
47 | ) -> bool:
48 | """
49 |     Check if an observation space has the shape, limits and dtype
50 | of a valid image.
51 | The check is conservative, so that it returns False if there is a doubt.
52 | Valid images: RGB, RGBD, GrayScale with values in [0, 255]
53 | :param observation_space:
54 |     :param check_channels: Whether or not to check the number of channels.
55 | e.g., with frame-stacking, the observation space may have more channels than expected.
56 | :param normalized_image: Whether to assume that the image is already normalized
57 | or not (this disables dtype and bounds checks): when True, it only checks that
58 | the space is a Box and has 3 dimensions.
59 | Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
60 | :return:
61 | """
62 | check_dtype = check_bounds = not normalized_image
63 | if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3:
64 | # Check the type
65 | if check_dtype and observation_space.dtype != np.uint8:
66 | return False
67 |
68 | # Check the value range
69 | incorrect_bounds = np.any(observation_space.low != 0) or np.any(observation_space.high != 255)
70 | if check_bounds and incorrect_bounds:
71 | return False
72 |
73 | # Skip channels check
74 | if not check_channels:
75 | return True
76 | # Check the number of channels
77 | if is_image_space_channels_first(observation_space):
78 | n_channels = observation_space.shape[0]
79 | else:
80 | n_channels = observation_space.shape[-1]
81 | # GrayScale, RGB, RGBD
82 | return n_channels in [1, 3, 4]
83 | return False
84 |
85 |
86 |
87 | def preprocess_obs(
88 | obs: torch.Tensor,
89 | observation_space: spaces.Space,
90 | normalize_images: bool = True,
91 | ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
92 | """
93 |     Preprocess observation to be fed to a neural network.
94 |     For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]).
95 |     For discrete observations, it creates a one-hot vector.
96 | :param obs: Observation
97 | :param observation_space:
98 | :param normalize_images: Whether to normalize images or not
99 | (True by default)
100 | :return:
101 | """
102 | if isinstance(observation_space, spaces.Box):
103 | if normalize_images and is_image_space(observation_space):
104 | return obs.float() / 255.0
105 | return obs.float()
106 |
107 | elif isinstance(observation_space, spaces.Discrete):
108 | # One hot encoding and convert to float to avoid errors
109 | return F.one_hot(obs.long(), num_classes=observation_space.n).float()
110 |
111 | elif isinstance(observation_space, spaces.MultiDiscrete):
112 | # Tensor concatenation of one hot encodings of each Categorical sub-space
113 | return torch.cat(
114 | [
115 | F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float()
116 | for idx, obs_ in enumerate(torch.split(obs.long(), 1, dim=1))
117 | ],
118 | dim=-1,
119 | ).view(obs.shape[0], sum(observation_space.nvec))
120 |
121 | elif isinstance(observation_space, spaces.MultiBinary):
122 | return obs.float()
123 |
124 | elif isinstance(observation_space, spaces.Dict):
125 | # Do not modify by reference the original observation
126 | assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
127 | preprocessed_obs = {}
128 | for key, _obs in obs.items():
129 | preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
130 | return preprocessed_obs
131 |
132 | else:
133 | raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")
134 |
135 |
136 | def get_obs_shape(
137 | observation_space: spaces.Space,
138 | ) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]:
139 | """
140 | Get the shape of the observation (useful for the buffers).
141 | :param observation_space:
142 | :return:
143 | """
144 | if isinstance(observation_space, spaces.Box):
145 | return observation_space.shape
146 | elif isinstance(observation_space, spaces.Discrete):
147 | # Observation is an int
148 | return (1,)
149 | elif isinstance(observation_space, spaces.MultiDiscrete):
150 | # Number of discrete features
151 | return (int(len(observation_space.nvec)),)
152 | elif isinstance(observation_space, spaces.MultiBinary):
153 | # Number of binary features
154 | if type(observation_space.n) in [tuple, list, np.ndarray]:
155 | return tuple(observation_space.n)
156 | else:
157 | return (int(observation_space.n),)
158 | elif isinstance(observation_space, spaces.Dict):
159 | return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} # type: ignore[misc]
160 |
161 | else:
162 | raise NotImplementedError(f"{observation_space} observation space is not supported")
163 |
164 |
165 | def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
166 | """
167 | Get the dimension of the observation space when flattened.
168 | It does not apply to image observation space.
169 | Used by the ``FlattenExtractor`` to compute the input shape.
170 | :param observation_space:
171 | :return:
172 | """
173 | # See issue https://github.com/openai/gym/issues/1915
174 | # it may be a problem for Dict/Tuple spaces too...
175 | if isinstance(observation_space, spaces.MultiDiscrete):
176 | return sum(observation_space.nvec)
177 | else:
178 | # Use Gym internal method
179 | return spaces.utils.flatdim(observation_space)
180 |
181 |
182 | def get_action_dim(action_space: spaces.Space) -> int:
183 | """
184 | Get the dimension of the action space.
185 | :param action_space:
186 | :return:
187 | """
188 | if isinstance(action_space, spaces.Box):
189 | return int(np.prod(action_space.shape))
190 | elif isinstance(action_space, spaces.Discrete):
191 | # Action is an int
192 | return 1
193 | elif isinstance(action_space, spaces.MultiDiscrete):
194 | # Number of discrete actions
195 | return int(len(action_space.nvec))
196 | elif isinstance(action_space, spaces.MultiBinary):
197 | # Number of binary actions
198 | return int(action_space.n)
199 | elif isinstance(action_space, spaces.Dict):
200 | return get_action_dim(action_space["motor_action"])
201 | else:
202 | raise NotImplementedError(f"{action_space} action space is not supported")
203 |
204 |
205 | def check_for_nested_spaces(obs_space: spaces.Space):
206 | """
207 | Make sure the observation space does not have nested spaces (Dicts/Tuples inside Dicts/Tuples).
208 | If so, raise an Exception informing that there is no support for this.
209 | :param obs_space: an observation space
210 | :return:
211 | """
212 | if isinstance(obs_space, (spaces.Dict, spaces.Tuple)):
213 | sub_spaces = obs_space.spaces.values() if isinstance(obs_space, spaces.Dict) else obs_space.spaces
214 | for sub_space in sub_spaces:
215 | if isinstance(sub_space, (spaces.Dict, spaces.Tuple)):
216 | raise NotImplementedError(
217 | "Nested observation spaces are not supported (Tuple/Dict space inside Tuple/Dict space)."
218 | )
219 |
220 |
221 | def get_device(device: Union[torch.device, str] = "auto") -> torch.device:
222 | """
223 | Retrieve PyTorch device.
224 | It checks that the requested device is available first.
225 | For now, it supports only cpu and cuda.
226 | By default, it tries to use the gpu.
227 |     :param device: One of 'auto', 'cuda', 'cpu'
228 |     :return: Supported PyTorch device
229 | """
230 | # Cuda by default
231 | if device == "auto":
232 | device = "cuda"
233 | # Force conversion to torch.device
234 | device = torch.device(device)
235 |
236 | # Cuda not available
237 | if device.type == torch.device("cuda").type and not torch.cuda.is_available():
238 | return torch.device("cpu")
239 |
240 | return device
241 |
242 |
243 | def get_timestr() -> str:
244 | current_datetime = datetime.now()
245 | return current_datetime.strftime("%m-%d-%H-%M-%S")
246 |
247 |
248 | def get_spatial_emb_indices(loc: np.ndarray,
249 | full_img_size=(4, 84, 84),
250 | img_size=(4, 21, 21),
251 | patch_size=(7, 7)) -> np.ndarray:
252 | # loc (2,)
253 | _, H, W = full_img_size
254 | _, h, w = img_size
255 | p1, p2 = patch_size
256 |
257 | st_x = loc[0] // p1
258 | st_y = loc[1] // p2
259 |
260 | ed_x = (loc[0] + h) // p1
261 | ed_y = (loc[1] + w) // p2
262 |
263 | ix, iy = np.meshgrid(np.arange(st_x, ed_x, dtype=np.int64),
264 | np.arange(st_y, ed_y, dtype=np.int64), indexing="ij")
265 |
266 | # print (ix, iy)
267 |     indices = (ix * H // p1 + iy).reshape(-1)
268 |
269 |     return indices
270 |
271 | def get_spatial_emb_mask(loc,
272 | mask,
273 | full_img_size=(4, 84, 84),
274 | img_size=(4, 21, 21),
275 | patch_size=(7, 7),
276 | latent_dim=144) -> np.ndarray:
277 | B, T, _ = loc.size()
278 | # return torch.randn_like()
279 | loc = loc.reshape(-1, 2)
280 | _, H, W = full_img_size
281 | _, h, w = img_size
282 | p1, p2 = patch_size
283 | num_tokens = h*w//p1//p2
284 | # print ("num_tokens", num_tokens)
285 |
286 | st_x = loc[..., 0] // p1
287 | st_y = loc[..., 1] // p2
288 |
289 | ed_x = (loc[..., 0] + h) // p1
290 | ed_y = (loc[..., 1] + w) // p2
291 |
292 | # mask = np.zeros(((32*6, H//p1, W//p2, latent_dim)), dtype=np.bool_)
293 | # mask = torch.zeros((32*6, H//p1, W//p2, latent_dim), dtype=torch.bool)
294 | mask[:] = False
295 | for i in range(B*T):
296 | # print (self.spatial_emb[0, st_x[i]:ed_x[i], st_y[i]:ed_y[i]].size())
297 | mask[i, st_x[i]:ed_x[i], st_y[i]:ed_y[i]] = True
298 | return mask[:B*T]
299 |
300 | def weight_init_drq(m):
301 | if isinstance(m, nn.Linear):
302 | nn.init.orthogonal_(m.weight.data)
303 | if hasattr(m.bias, 'data'):
304 | m.bias.data.fill_(0.0)
305 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
306 | gain = nn.init.calculate_gain('relu')
307 | nn.init.orthogonal_(m.weight.data, gain)
308 | if hasattr(m.bias, 'data'):
309 | m.bias.data.fill_(0.0)
310 |
311 |
312 | def soft_update_params(net, target_net, tau):
313 | for param, target_param in zip(net.parameters(), target_net.parameters()):
314 | target_param.data.copy_(tau * param.data +
315 | (1 - tau) * target_param.data)
316 |
317 | class TruncatedNormal(pyd.Normal):
318 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
319 | super().__init__(loc, scale, validate_args=False)
320 | self.low = low
321 | self.high = high
322 | self.eps = eps
323 |
324 | def _clamp(self, x):
325 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
326 | x = x - x.detach() + clamped_x.detach()
327 | return x
328 |
329 | def sample(self, clip=None, sample_shape=torch.Size()):
330 | shape = self._extended_shape(sample_shape)
331 | eps = _standard_normal(shape,
332 | dtype=self.loc.dtype,
333 | device=self.loc.device)
334 | eps *= self.scale
335 | if clip is not None:
336 | eps = torch.clamp(eps, -clip, clip)
337 | x = self.loc + eps
338 | return self._clamp(x)
339 |
340 |
341 | def schedule_drq(schdl, step):
342 | try:
343 | return float(schdl)
344 | except ValueError:
345 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl)
346 | if match:
347 | init, final, duration = [float(g) for g in match.groups()]
348 | mix = np.clip(step / duration, 0.0, 1.0)
349 | return (1.0 - mix) * init + mix * final
350 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl)
351 | if match:
352 | init, final1, duration1, final2, duration2 = [
353 | float(g) for g in match.groups()
354 | ]
355 | if step <= duration1:
356 | mix = np.clip(step / duration1, 0.0, 1.0)
357 | return (1.0 - mix) * init + mix * final1
358 | else:
359 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0)
360 | return (1.0 - mix) * final1 + mix * final2
361 | raise NotImplementedError(schdl)
362 |
363 | def get_sugarl_reward_scale_robosuite(task_name) -> float:
364 | if task_name == "Lift":
365 | sugarl_reward_scale = 150/500
366 | elif task_name == "ToolHang":
367 | sugarl_reward_scale = 100/500
368 | else:
369 | sugarl_reward_scale = 100/500
370 | return sugarl_reward_scale
371 |
372 |
373 | def get_sugarl_reward_scale_dmc(domain_name, task_name) -> float:
374 | if domain_name == "ball_in_cup" and task_name == "catch":
375 | sugarl_reward_scale = 320/500
376 | elif domain_name == "cartpole" and task_name == "swingup":
377 | sugarl_reward_scale = 380/500
378 | elif domain_name == "cheetah" and task_name == "run":
379 | sugarl_reward_scale = 245/500
380 | elif domain_name == "dog" and task_name == "fetch":
381 | sugarl_reward_scale = 4.5/500
382 | elif domain_name == "finger" and task_name == "spin":
383 | sugarl_reward_scale = 290/500
384 | elif domain_name == "fish" and task_name == "swim":
385 | sugarl_reward_scale = 64/500
386 | elif domain_name == "reacher" and task_name == "easy":
387 | sugarl_reward_scale = 200/500
388 | elif domain_name == "walker" and task_name == "walk":
389 | sugarl_reward_scale = 290/500
390 | else:
391 | return 1.
392 |
393 | return sugarl_reward_scale
394 |
395 | def get_sugarl_reward_scale_atari(game) -> float:
396 | base_scale = 4.0
397 | sugarl_reward_scale = 1/200
398 | if game in ["alien", "assault", "asterix", "battle_zone", "seaquest", "qbert", "private_eye", "road_runner"]:
399 | sugarl_reward_scale = 1/100
400 | elif game in ["kangaroo", "krull", "chopper_command", "demon_attack"]:
401 | sugarl_reward_scale = 1/200
402 | elif game in ["up_n_down", "frostbite", "ms_pacman", "amidar", "gopher", "boxing"]:
403 | sugarl_reward_scale = 1/50
404 | elif game in ["hero", "jamesbond", "kung_fu_master"]:
405 | sugarl_reward_scale = 1/25
406 | elif game in ["crazy_climber"]:
407 | sugarl_reward_scale = 1/20
408 | elif game in ["freeway"]:
409 | sugarl_reward_scale = 1/1600
410 | elif game in ["pong"]:
411 | sugarl_reward_scale = 1/800
412 | elif game in ["bank_heist"]:
413 | sugarl_reward_scale = 1/250
414 | elif game in ["breakout"]:
415 | sugarl_reward_scale = 1/35
416 | sugarl_reward_scale = sugarl_reward_scale * base_scale
417 | return sugarl_reward_scale
418 |
--------------------------------------------------------------------------------
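
Note: `schedule_drq` above accepts either a plain number or a spec string such as `linear(init,final,duration)`; `get_sugarl_reward_scale_atari` multiplies a per-game scale by `base_scale = 4.0`. A small usage sketch (assumes it is run from the repository root so that `common` is importable):

    from common.utils import schedule_drq, get_sugarl_reward_scale_atari

    print(schedule_drq("0.1", step=0))                   # constant schedule -> 0.1
    print(schedule_drq("linear(1.0,0.1,100)", step=0))   # start of the ramp -> 1.0
    print(schedule_drq("linear(1.0,0.1,100)", step=50))  # halfway -> 0.55
    print(schedule_drq("linear(1.0,0.1,100)", step=200)) # past the end, clipped -> 0.1

    print(get_sugarl_reward_scale_atari("breakout"))     # (1/35) * 4.0 ~= 0.114
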
/agent/dqn_atari_wp_base_peripheral.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import random
5 | import time
6 | from collections import deque
7 | from itertools import product
8 | from distutils.util import strtobool
9 |
10 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__))))
11 | os.environ["OMP_NUM_THREADS"] = "1"
12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
13 | import warnings
14 | warnings.filterwarnings("ignore", category=UserWarning)
15 |
16 | import gymnasium as gym
17 | from gymnasium.spaces import Discrete, Dict, Box
18 | import numpy as np
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | import torch.optim as optim
23 | from torchvision.transforms import Resize
24 |
25 | from common.buffer import ReplayBuffer
26 | from common.pvm_buffer import PVMBuffer
27 | from common.utils import get_timestr, seed_everything
28 | from torch.utils.tensorboard import SummaryWriter
29 |
30 | from active_gym import AtariFixedFovealPeripheralEnv, AtariEnvArgs
31 |
32 |
33 |
34 | def parse_args():
35 | # fmt: off
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
38 | help="the name of this experiment")
39 | parser.add_argument("--seed", type=int, default=1,
40 | help="seed of the experiment")
41 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
42 | help="if toggled, cuda will be enabled by default")
43 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
44 | help="whether to capture videos of the agent performances (check out `videos` folder)")
45 |
46 | # env setting
47 | parser.add_argument("--env", type=str, default="breakout",
48 | help="the id of the environment")
49 | parser.add_argument("--env-num", type=int, default=1,
50 | help="# envs in parallel")
51 | parser.add_argument("--frame-stack", type=int, default=4,
52 | help="frame stack #")
53 | parser.add_argument("--action-repeat", type=int, default=4,
54 | help="action repeat #")
55 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True)
56 |
57 | # fov setting
58 | parser.add_argument("--fov-size", type=int, default=0)
59 | parser.add_argument("--fov-init-loc", type=int, default=0)
60 | parser.add_argument("--sensory-action-mode", type=str, default="absolute")
61 | parser.add_argument("--sensory-action-space", type=int, default=10) # ignored when sensory_action_mode="relative"
62 | parser.add_argument("--resize-to-full", default=False, action="store_true")
63 | parser.add_argument("--peripheral-res", type=int, default=20)
64 | # for discrete observ action
65 | parser.add_argument("--sensory-action-x-size", type=int, default=4)
66 | parser.add_argument("--sensory-action-y-size", type=int, default=4)
67 | # pvm setting
68 | parser.add_argument("--pvm-stack", type=int, default=1)
69 |
70 | # Algorithm specific arguments
71 | parser.add_argument("--total-timesteps", type=int, default=3000000,
72 | help="total timesteps of the experiments")
73 | parser.add_argument("--learning-rate", type=float, default=1e-4,
74 | help="the learning rate of the optimizer")
75 | parser.add_argument("--buffer-size", type=int, default=500000,
76 | help="the replay memory buffer size")
77 | parser.add_argument("--gamma", type=float, default=0.99,
78 | help="the discount factor gamma")
79 | parser.add_argument("--target-network-frequency", type=int, default=1000,
80 | help="the timesteps it takes to update the target network")
81 | parser.add_argument("--batch-size", type=int, default=32,
82 | help="the batch size of sample from the reply memory")
83 | parser.add_argument("--start-e", type=float, default=1,
84 | help="the starting epsilon for exploration")
85 | parser.add_argument("--end-e", type=float, default=0.01,
86 | help="the ending epsilon for exploration")
87 | parser.add_argument("--exploration-fraction", type=float, default=0.10,
88 | help="the fraction of `total-timesteps` it takes from start-e to go end-e")
89 | parser.add_argument("--learning-starts", type=int, default=80000,
90 | help="timestep to start learning")
91 | parser.add_argument("--train-frequency", type=int, default=4,
92 | help="the frequency of training")
93 |
94 | # eval args
95 | parser.add_argument("--eval-frequency", type=int, default=-1,
96 | help="eval frequency. default -1 is eval at the end.")
97 | parser.add_argument("--eval-num", type=int, default=10,
98 | help="eval frequency. default -1 is eval at the end.")
99 | args = parser.parse_args()
100 | # fmt: on
101 | return args
102 |
103 |
104 | def make_env(env_name, seed, **kwargs):
105 | def thunk():
106 | env_args = AtariEnvArgs(
107 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs
108 | )
109 | env = AtariFixedFovealPeripheralEnv(env_args)
110 | env.action_space.seed(seed)
111 | env.observation_space.seed(seed)
112 | return env
113 |
114 | return thunk
115 |
116 |
117 | # ALGO LOGIC: initialize agent here:
118 | class QNetwork(nn.Module):
119 | def __init__(self, env):
120 | super().__init__()
121 | if isinstance(env.single_action_space, Discrete):
122 | action_space_size = env.single_action_space.n
123 | elif isinstance(env.single_action_space, Dict):
124 | action_space_size = env.single_action_space["motor_action"].n
125 | self.network = nn.Sequential(
126 | nn.Conv2d(4, 32, 8, stride=4),
127 | nn.ReLU(),
128 | nn.Conv2d(32, 64, 4, stride=2),
129 | nn.ReLU(),
130 | nn.Conv2d(64, 64, 3, stride=1),
131 | nn.ReLU(),
132 | nn.Flatten(),
133 | nn.Linear(3136, 512),
134 | nn.ReLU(),
135 | nn.Linear(512, action_space_size),
136 | )
137 |
138 | def forward(self, x):
139 | return self.network(x)
140 |
141 |
142 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
143 | slope = (end_e - start_e) / duration
144 | return max(slope * t + start_e, end_e)
145 |
146 |
147 |
148 |
149 |
150 | if __name__ == "__main__":
151 | args = parse_args()
152 | args.env = args.env.lower()
153 | run_name = f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}"
154 | run_dir = os.path.join("runs", args.exp_name)
155 | if not os.path.exists(run_dir):
156 | os.makedirs(run_dir, exist_ok=True)
157 |
158 | writer = SummaryWriter(os.path.join(run_dir, run_name))
159 | writer.add_text(
160 | "hyperparameters",
161 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
162 | )
163 |
164 | # TRY NOT TO MODIFY: seeding
165 | seed_everything(args.seed)
166 |
167 |
168 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
169 |
170 | # env setup
171 | envs = []
172 | for i in range(args.env_num):
173 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat,
174 | fov_size=(args.fov_size, args.fov_size),
175 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc),
176 | peripheral_res=(args.peripheral_res, args.peripheral_res),
177 | sensory_action_mode=args.sensory_action_mode,
178 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space),
179 | resize_to_full=args.resize_to_full,
180 | clip_reward=args.clip_reward,
181 | mask_out=True))
182 | # envs = gym.vector.AsyncVectorEnv(envs)
183 | envs = gym.vector.SyncVectorEnv(envs)
184 |
185 | resize = Resize((84, 84))
186 |
187 | # get a discrete observ action space
188 | OBSERVATION_SIZE = (84, 84)
189 | sensory_action_set = [(0, 0)]
190 |
191 | q_network = QNetwork(envs).to(device)
192 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
193 | target_network = QNetwork(envs).to(device)
194 | target_network.load_state_dict(q_network.state_dict())
195 |
196 | rb = ReplayBuffer(
197 | args.buffer_size,
198 | envs.single_observation_space,
199 | envs.single_action_space["motor_action"],
200 | device,
201 | n_envs=envs.num_envs,
202 | optimize_memory_usage=True,
203 | handle_timeout_termination=False,
204 | )
205 | start_time = time.time()
206 |
207 | # TRY NOT TO MODIFY: start the game
208 | obs, _ = envs.reset()
209 | global_transitions = 0
210 | fov_idx = random.randint(0, len(sensory_action_set)-1)
211 | pvm_buffer = PVMBuffer(args.pvm_stack, (envs.num_envs, args.frame_stack,)+OBSERVATION_SIZE)
212 |
213 | while global_transitions < args.total_timesteps:
214 | pvm_buffer.append(obs)
215 | pvm_obs = pvm_buffer.get_obs(mode="stack_max")
216 | # ALGO LOGIC: put action logic here
217 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_transitions)
218 | if random.random() < epsilon:
219 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
220 | motor_actions = np.array([actions[0]["motor_action"]])
221 | sensory_actions = sensory_action_set[random.randint(0, len(sensory_action_set)-1)]
222 | else:
223 | q_values = q_network(resize(torch.from_numpy(pvm_obs)).to(device))
224 | motor_actions = torch.argmax(q_values, dim=1).cpu().numpy()
225 | sensory_actions = sensory_action_set[fov_idx]
226 |
227 | # TRY NOT TO MODIFY: execute the game and log data.
228 | next_obs, rewards, dones, _, infos = envs.step({"motor_action": motor_actions,
229 | "sensory_action": [sensory_actions] })
230 | fov_idx = (fov_idx + 1) % len(sensory_action_set)
231 | # print (global_step, infos)
232 |
233 | # TRY NOT TO MODIFY: record rewards for plotting purposes
234 | if "final_info" in infos:
235 | for idx, d in enumerate(dones):
236 | if d:
237 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]")
238 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions)
239 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions)
240 | writer.add_scalar("charts/epsilon", epsilon, global_transitions)
241 | break
242 |
243 |
244 |         # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
245 | real_next_obs = next_obs
246 | for idx, d in enumerate(dones):
247 | if d:
248 | real_next_obs[idx] = infos["final_observation"][idx]
249 | fov_idx = random.randint(0, len(sensory_action_set)-1)
250 | pvm_buffer_copy = pvm_buffer.copy()
251 | pvm_buffer_copy.append(real_next_obs)
252 | real_next_pvm_obs = pvm_buffer_copy.get_obs(mode="stack_max")
253 | rb.add(pvm_obs, real_next_pvm_obs, motor_actions, rewards, dones, {})
254 |
255 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
256 | obs = next_obs
257 |
258 | # INC total transitions
259 | global_transitions += args.env_num
260 |
261 |
262 |         obs_backup = obs # backup obs; restored after an eval round
263 | # ALGO LOGIC: training.
264 | if global_transitions > args.learning_starts:
265 | if global_transitions % args.train_frequency == 0:
266 | data = rb.sample(args.batch_size // args.env_num) # counter-balance the true global transitions used for training
267 | with torch.no_grad():
268 | target_max, _ = target_network(resize(data.next_observations)).max(dim=1)
269 | td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
270 | old_val = q_network(resize(data.observations)).gather(1, data.actions).squeeze()
271 | loss = F.mse_loss(td_target, old_val)
272 |
273 | if global_transitions % 100 == 0:
274 | writer.add_scalar("losses/td_loss", loss, global_transitions)
275 | writer.add_scalar("losses/q_values", old_val.mean().item(), global_transitions)
276 | # print("SPS:", int(global_step / (time.time() - start_time)))
277 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions)
278 |
279 | # optimize the model
280 | optimizer.zero_grad()
281 | loss.backward()
282 | optimizer.step()
283 |
284 | # update the target network
285 | if (global_transitions // args.env_num) % args.target_network_frequency == 0:
286 | target_network.load_state_dict(q_network.state_dict())
287 |
288 | # evaluation
289 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \
290 | (global_transitions >= args.total_timesteps):
291 | q_network.eval()
292 |
293 | eval_episodic_returns, eval_episodic_lengths = [], []
294 |
295 | for eval_ep in range(args.eval_num):
296 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat,
297 | fov_size=(args.fov_size, args.fov_size),
298 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc),
299 | peripheral_res=(args.peripheral_res, args.peripheral_res),
300 | sensory_action_mode=args.sensory_action_mode,
301 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space),
302 | resize_to_full=args.resize_to_full,
303 | clip_reward=args.clip_reward,
304 | training=False,
305 | mask_out=True,
306 | record=args.capture_video)]
307 | eval_env = gym.vector.SyncVectorEnv(eval_env)
308 | obs, _ = eval_env.reset()
309 | done = False
310 | fov_idx = random.randint(0, len(sensory_action_set)-1)
311 | pvm_buffer = PVMBuffer(args.pvm_stack, (eval_env.num_envs, args.frame_stack,)+OBSERVATION_SIZE)
312 | while not done:
313 | pvm_buffer.append(obs)
314 | pvm_obs = pvm_buffer.get_obs(mode="stack_max")
315 | q_values = q_network(resize(torch.from_numpy(pvm_obs)).to(device))
316 | motor_actions = torch.argmax(q_values, dim=1).cpu().numpy()
317 | next_obs, rewards, dones, _, infos = eval_env.step({"motor_action": motor_actions,
318 | "sensory_action": [sensory_action_set[fov_idx]]})
319 | obs = next_obs
320 | fov_idx = (fov_idx + 1) % len(sensory_action_set)
321 | done = dones[0]
322 | if done:
323 | eval_episodic_returns.append(infos['final_info'][0]["reward"])
324 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"])
325 | fov_idx = random.randint(0, len(sensory_action_set)-1)
326 | if args.capture_video:
327 | record_file_dir = os.path.join("recordings", args.exp_name, os.path.basename(__file__).rstrip(".py"))
328 | os.makedirs(record_file_dir, exist_ok=True)
329 | record_file_fn = f"seed{args.seed}_step{global_transitions:07d}_record.pt"
330 | eval_env.envs[0].save_record_to_file(os.path.join(record_file_dir, record_file_fn))
331 |
332 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions)
333 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions)
334 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]")
335 |
336 | q_network.train()
337 | obs = obs_backup # restore obs if eval occurs
338 |
339 |
340 | envs.close()
341 | eval_env.close()
342 | writer.close()
343 |
--------------------------------------------------------------------------------
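
Note: the script above is the peripheral-only baseline: `--fov-size` defaults to 0, the sensory action is pinned to (0, 0), and the policy therefore only ever sees the masked, low-resolution peripheral view. A minimal sketch of the matching single-environment setup, using the same arguments the script forwards to active_gym (the printed observation shape is the expected one, not verified here):

    from active_gym import AtariFixedFovealPeripheralEnv, AtariEnvArgs

    env_args = AtariEnvArgs(
        game="breakout", seed=0, obs_size=(84, 84),
        frame_stack=4, action_repeat=4,
        fov_size=(0, 0),                  # no foveal patch
        fov_init_loc=(0, 0),
        peripheral_res=(20, 20),          # default --peripheral-res
        sensory_action_mode="absolute",
        sensory_action_space=(-10, 10),
        resize_to_full=False,
        clip_reward=False,
        mask_out=True,
    )
    env = AtariFixedFovealPeripheralEnv(env_args)
    obs, info = env.reset()
    print(obs.shape)                      # expected (4, 84, 84) frame-stacked observation
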
/agent/dqn_atari_wp_single_policy.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import random
5 | import time
6 | from collections import deque
7 | from itertools import product
8 | from distutils.util import strtobool
9 |
10 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__))))
11 | os.environ["OMP_NUM_THREADS"] = "1"
12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
13 | import warnings
14 | warnings.filterwarnings("ignore", category=UserWarning)
15 |
16 | import gymnasium as gym
17 | from gymnasium.spaces import Discrete, Dict, Box
18 | import numpy as np
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | import torch.optim as optim
23 | from torchvision.transforms import Resize
24 |
25 | from common.buffer import ReplayBuffer
26 | from common.pvm_buffer import PVMBuffer
27 | from common.utils import get_timestr, seed_everything
28 | from torch.utils.tensorboard import SummaryWriter
29 |
30 | from active_gym import AtariFixedFovealEnv, AtariEnvArgs
31 |
32 |
33 |
34 | def parse_args():
35 | # fmt: off
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
38 | help="the name of this experiment")
39 | parser.add_argument("--seed", type=int, default=1,
40 | help="seed of the experiment")
41 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
42 | help="if toggled, cuda will be enabled by default")
43 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
44 | help="whether to capture videos of the agent performances (check out `videos` folder)")
45 |
46 | # env setting
47 | parser.add_argument("--env", type=str, default="breakout",
48 | help="the id of the environment")
49 | parser.add_argument("--env-num", type=int, default=1,
50 | help="# envs in parallel")
51 | parser.add_argument("--frame-stack", type=int, default=4,
52 | help="frame stack #")
53 | parser.add_argument("--action-repeat", type=int, default=4,
54 | help="action repeat #")
55 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True)
56 |
57 | # fov setting
58 | parser.add_argument("--fov-size", type=int, default=50)
59 | parser.add_argument("--fov-init-loc", type=int, default=0)
60 | parser.add_argument("--sensory-action-mode", type=str, default="relative")
61 | parser.add_argument("--sensory-action-space", type=int, default=10) # ignored when sensory_action_mode="relative"
62 | parser.add_argument("--resize-to-full", default=False, action="store_true")
63 | # for discrete observ action
64 | parser.add_argument("--sensory-action-x-size", type=int, default=4)
65 | parser.add_argument("--sensory-action-y-size", type=int, default=4)
66 | # pvm setting
67 | parser.add_argument("--pvm-stack", type=int, default=3)
68 |
69 | # Algorithm specific arguments
70 | parser.add_argument("--total-timesteps", type=int, default=3000000,
71 | help="total timesteps of the experiments")
72 | parser.add_argument("--learning-rate", type=float, default=1e-4,
73 | help="the learning rate of the optimizer")
74 | parser.add_argument("--buffer-size", type=int, default=500000,
75 | help="the replay memory buffer size")
76 | parser.add_argument("--gamma", type=float, default=0.99,
77 | help="the discount factor gamma")
78 | parser.add_argument("--target-network-frequency", type=int, default=1000,
79 | help="the timesteps it takes to update the target network")
80 | parser.add_argument("--batch-size", type=int, default=32,
81 | help="the batch size of sample from the reply memory")
82 | parser.add_argument("--start-e", type=float, default=1,
83 | help="the starting epsilon for exploration")
84 | parser.add_argument("--end-e", type=float, default=0.01,
85 | help="the ending epsilon for exploration")
86 | parser.add_argument("--exploration-fraction", type=float, default=0.10,
87 | help="the fraction of `total-timesteps` it takes from start-e to go end-e")
88 | parser.add_argument("--learning-starts", type=int, default=80000,
89 | help="timestep to start learning")
90 | parser.add_argument("--train-frequency", type=int, default=4,
91 | help="the frequency of training")
92 |
93 | # eval args
94 | parser.add_argument("--eval-frequency", type=int, default=-1,
95 | help="eval frequency. default -1 is eval at the end.")
96 | parser.add_argument("--eval-num", type=int, default=10,
97 | help="eval frequency. default -1 is eval at the end.")
98 | args = parser.parse_args()
99 | # fmt: on
100 | return args
101 |
102 |
103 | def make_env(env_name, seed, **kwargs):
104 | def thunk():
105 | env_args = AtariEnvArgs(
106 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs
107 | )
108 | env = AtariFixedFovealEnv(env_args)
109 | env.action_space.seed(seed)
110 | env.observation_space.seed(seed)
111 | return env
112 |
113 | return thunk
114 |
115 |
116 | # ALGO LOGIC: initialize agent here:
117 | class QNetwork(nn.Module):
118 | def __init__(self, env, override_action_set=None):
119 | super().__init__()
120 | if override_action_set:
121 | action_space_size = override_action_set.n
122 | else:
123 | if isinstance(env.single_action_space, Discrete):
124 | action_space_size = env.single_action_space.n
125 | elif isinstance(env.single_action_space, Dict):
126 | action_space_size = env.single_action_space["motor_action"].n
127 | self.network = nn.Sequential(
128 | nn.Conv2d(4, 32, 8, stride=4),
129 | nn.ReLU(),
130 | nn.Conv2d(32, 64, 4, stride=2),
131 | nn.ReLU(),
132 | nn.Conv2d(64, 64, 3, stride=1),
133 | nn.ReLU(),
134 | nn.Flatten(),
135 | nn.Linear(3136, 512),
136 | nn.ReLU(),
137 | nn.Linear(512, action_space_size),
138 | )
139 |
140 | def forward(self, x):
141 | return self.network(x)
142 |
143 |
144 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
145 | slope = (end_e - start_e) / duration
146 | return max(slope * t + start_e, end_e)
147 |
148 |
149 |
150 |
151 |
152 | if __name__ == "__main__":
153 | args = parse_args()
154 | args.env = args.env.lower()
155 | run_name = f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}"
156 | run_dir = os.path.join("runs", args.exp_name)
157 | if not os.path.exists(run_dir):
158 | os.makedirs(run_dir, exist_ok=True)
159 |
160 | writer = SummaryWriter(os.path.join(run_dir, run_name))
161 | writer.add_text(
162 | "hyperparameters",
163 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
164 | )
165 |
166 | # TRY NOT TO MODIFY: seeding
167 | seed_everything(args.seed)
168 |
169 |
170 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
171 |
172 | # get a discrete observ action space
173 | OBSERVATION_SIZE = (84, 84)
174 | observ_x_max, observ_y_max = OBSERVATION_SIZE[0]-args.fov_size, OBSERVATION_SIZE[1]-args.fov_size
175 | sensory_action_step = (observ_x_max//args.sensory_action_x_size,
176 | observ_y_max//args.sensory_action_y_size)
177 | sensory_action_set = [(-sensory_action_step[0], 0),
178 | (sensory_action_step[0], 0),
179 | (0, 0),
180 | (0, -sensory_action_step[1]),
181 | (0, sensory_action_step[1])]
182 |
183 | # env setup
184 | envs = []
185 | for i in range(args.env_num):
186 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat,
187 | fov_size=(args.fov_size, args.fov_size),
188 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc),
189 | sensory_action_mode=args.sensory_action_mode,
190 | sensory_action_space=(-max(sensory_action_step), max(sensory_action_step)),
191 | resize_to_full=args.resize_to_full,
192 | clip_reward=args.clip_reward,
193 | mask_out=True))
194 | # envs = gym.vector.AsyncVectorEnv(envs)
195 | envs = gym.vector.SyncVectorEnv(envs)
196 |
197 | resize = Resize((84, 84))
198 |
199 | # make motor sensory joint action space
200 | motor_action_set = list(range(envs.single_action_space["motor_action"].n))
201 | motor_sensory_joint_action_set = []
202 | for ma in motor_action_set:
203 | for sa in sensory_action_set:
204 | motor_sensory_joint_action_set.append((ma, *sa))
205 | motor_sensory_joint_action_space = Discrete(len(motor_sensory_joint_action_set), seed=args.seed)
206 |
207 |     # make a helper to split the joint action into motor and sensory parts
208 | def seperate_motor_sensory_joint_action(msas: np.ndarray):
209 | mas, sas = [], []
210 | for msa in msas:
211 | msa = motor_sensory_joint_action_set[msa]
212 | mas.append(msa[0])
213 | sas.append(msa[1:])
214 | mas = np.array(mas)
215 | sas = np.array(sas)
216 | return mas, sas
217 |
218 |
219 |
220 | q_network = QNetwork(envs, motor_sensory_joint_action_space).to(device)
221 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
222 | target_network = QNetwork(envs, motor_sensory_joint_action_space).to(device)
223 | target_network.load_state_dict(q_network.state_dict())
224 |
225 | rb = ReplayBuffer(
226 | args.buffer_size,
227 | envs.single_observation_space,
228 | motor_sensory_joint_action_space,
229 | device,
230 | n_envs=envs.num_envs,
231 | optimize_memory_usage=True,
232 | handle_timeout_termination=False,
233 | )
234 | start_time = time.time()
235 |
236 | # TRY NOT TO MODIFY: start the game
237 | obs, _ = envs.reset()
238 | global_transitions = 0
239 | pvm_buffer = PVMBuffer(args.pvm_stack, (envs.num_envs, args.frame_stack,)+OBSERVATION_SIZE)
240 |
241 | while global_transitions < args.total_timesteps:
242 | pvm_buffer.append(obs)
243 | pvm_obs = pvm_buffer.get_obs(mode="stack_max")
244 | # ALGO LOGIC: put action logic here
245 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_transitions)
246 | if random.random() < epsilon:
247 | actions = np.array([motor_sensory_joint_action_space.sample() for _ in range(envs.num_envs)])
248 | else:
249 | q_values = q_network(resize(torch.from_numpy(pvm_obs)).to(device))
250 | actions = torch.argmax(q_values, dim=1).cpu().numpy()
251 |
252 | # print (actions)
253 | motor_actions, sensory_actions = seperate_motor_sensory_joint_action(actions)
254 |
255 | # TRY NOT TO MODIFY: execute the game and log data.
256 | next_obs, rewards, dones, _, infos = envs.step({"motor_action": motor_actions,
257 | "sensory_action": sensory_actions })
258 |
259 | # TRY NOT TO MODIFY: record rewards for plotting purposes
260 | if "final_info" in infos:
261 | for idx, d in enumerate(dones):
262 | if d:
263 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]")
264 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions)
265 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions)
266 | writer.add_scalar("charts/epsilon", epsilon, global_transitions)
267 | break
268 |
269 |         # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
270 | real_next_obs = next_obs
271 | for idx, d in enumerate(dones):
272 | if d:
273 | real_next_obs[idx] = infos["final_observation"][idx]
274 |                 fov_idx = random.randint(0, len(sensory_action_set)-1)  # note: fov_idx is not used elsewhere in this single-policy variant
275 | pvm_buffer_copy = pvm_buffer.copy()
276 | pvm_buffer_copy.append(real_next_obs)
277 | real_next_pvm_obs = pvm_buffer_copy.get_obs(mode="stack_max")
278 | rb.add(pvm_obs, real_next_pvm_obs, actions, rewards, dones, {})
279 |
280 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
281 | obs = next_obs
282 |
283 | # INC total transitions
284 | global_transitions += args.env_num
285 |
286 |
287 |         obs_backup = obs # backup obs; restored after an eval round
288 | # ALGO LOGIC: training.
289 | if global_transitions > args.learning_starts:
290 | if global_transitions % args.train_frequency == 0:
291 | data = rb.sample(args.batch_size // args.env_num) # counter-balance the true global transitions used for training
292 | with torch.no_grad():
293 | target_max, _ = target_network(resize(data.next_observations)).max(dim=1)
294 | td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
295 | old_val = q_network(resize(data.observations)).gather(1, data.actions).squeeze()
296 | loss = F.mse_loss(td_target, old_val)
297 |
298 | if global_transitions % 100 == 0:
299 | writer.add_scalar("losses/td_loss", loss, global_transitions)
300 | writer.add_scalar("losses/q_values", old_val.mean().item(), global_transitions)
301 | # print("SPS:", int(global_step / (time.time() - start_time)))
302 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions)
303 |
304 | # optimize the model
305 | optimizer.zero_grad()
306 | loss.backward()
307 | optimizer.step()
308 |
309 | # update the target network
310 | if (global_transitions // args.env_num) % args.target_network_frequency == 0:
311 | target_network.load_state_dict(q_network.state_dict())
312 |
313 | # evaluation
314 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \
315 | (global_transitions >= args.total_timesteps):
316 | q_network.eval()
317 |
318 | eval_episodic_returns, eval_episodic_lengths = [], []
319 |
320 | for eval_ep in range(args.eval_num):
321 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat,
322 | fov_size=(args.fov_size, args.fov_size),
323 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc),
324 | sensory_action_mode=args.sensory_action_mode,
325 | sensory_action_space=(-max(sensory_action_step), max(sensory_action_step)),
326 | resize_to_full=args.resize_to_full,
327 | clip_reward=args.clip_reward,
328 | training=False,
329 | mask_out=True)]
330 | eval_env = gym.vector.SyncVectorEnv(eval_env)
331 | obs, _ = eval_env.reset()
332 | done = False
333 | pvm_buffer = PVMBuffer(args.pvm_stack, (eval_env.num_envs, args.frame_stack,)+OBSERVATION_SIZE)
334 | while not done:
335 | pvm_buffer.append(obs)
336 | pvm_obs = pvm_buffer.get_obs(mode="stack_max")
337 | q_values = q_network(resize(torch.from_numpy(pvm_obs)).to(device))
338 | actions = torch.argmax(q_values, dim=1).cpu().numpy()
339 | motor_actions, sensory_actions = seperate_motor_sensory_joint_action(actions)
340 | next_obs, rewards, dones, _, infos = eval_env.step({"motor_action": motor_actions,
341 | "sensory_action": sensory_actions})
342 | obs = next_obs
343 | done = dones[0]
344 | if done:
345 | eval_episodic_returns.append(infos['final_info'][0]["reward"])
346 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"])
347 |
348 | if args.capture_video:
349 | record_file_dir = os.path.join("recordings", args.exp_name, os.path.basename(__file__).rstrip(".py"))
350 | os.makedirs(record_file_dir, exist_ok=True)
351 | record_file_fn = f"seed{args.seed}_step{global_transitions:07d}_eval{eval_ep:02d}_record.pt"
352 | eval_env.envs[0].save_record_to_file(os.path.join(record_file_dir, record_file_fn))
353 | if eval_ep == 0:
354 | model_file_dir = os.path.join("trained_models", args.exp_name, os.path.basename(__file__).rstrip(".py"), args.env)
355 | os.makedirs(model_file_dir, exist_ok=True)
356 | model_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_model.pt"
357 | torch.save({"sfn": None, "q": q_network.state_dict()}, os.path.join(model_file_dir, model_fn))
358 |
359 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions)
360 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions)
361 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]")
362 |
363 | q_network.train()
364 | obs = obs_backup # restore obs if eval occurs
365 |
366 |
367 | envs.close()
368 | eval_env.close()
369 | writer.close()
--------------------------------------------------------------------------------
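
Note: the single-policy variant above folds the (motor, sensory) choice into one flat Q-head by enumerating every pair; decoding an index back into the two actions is a table lookup. A tiny self-contained illustration with 4 motor actions (e.g. Breakout) and the 5 relative sensory moves the script builds (step 8 under the default fov/grid sizes):

    import numpy as np

    motor_action_set = list(range(4))
    step = (8, 8)                                       # sensory_action_step with the defaults
    sensory_action_set = [(-step[0], 0), (step[0], 0), (0, 0),
                          (0, -step[1]), (0, step[1])]

    joint_set = [(ma, *sa) for ma in motor_action_set for sa in sensory_action_set]
    print(len(joint_set))                               # 20 = 4 motor x 5 sensory; size of the flat Q-head

    def split_joint(joint_indices: np.ndarray):
        mas = np.array([joint_set[i][0] for i in joint_indices])
        sas = np.array([joint_set[i][1:] for i in joint_indices])
        return mas, sas

    mas, sas = split_joint(np.array([0, 7, 19]))
    print(mas)                                          # motor indices 0, 1, 3
    print(sas)                                          # relative moves (-8, 0), (0, 0), (0, 8)
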
/agent/dqn_atari_single_policy.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import random
5 | import time
6 | from collections import deque
7 | from itertools import product
8 | from distutils.util import strtobool
9 |
10 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__))))
11 | os.environ["OMP_NUM_THREADS"] = "1"
12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
13 | import warnings
14 | warnings.filterwarnings("ignore", category=UserWarning)
15 |
16 | import gymnasium as gym
17 | from gymnasium.spaces import Discrete, Dict, Box
18 | import numpy as np
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | import torch.optim as optim
23 | from torchvision.transforms import Resize
24 |
25 | from common.buffer import ReplayBuffer
26 | from common.pvm_buffer import PVMBuffer
27 | from common.utils import get_timestr, seed_everything
28 | from torch.utils.tensorboard import SummaryWriter
29 |
30 | from active_gym import AtariFixedFovealEnv, AtariEnvArgs
31 |
32 |
33 |
34 | def parse_args():
35 | # fmt: off
36 | parser = argparse.ArgumentParser()
37 |     parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).replace(".py", ""),  # replace() drops the suffix; rstrip(".py") would also strip the trailing "y" of "...policy"
38 | help="the name of this experiment")
39 | parser.add_argument("--seed", type=int, default=1,
40 | help="seed of the experiment")
41 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
42 | help="if toggled, cuda will be enabled by default")
43 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
44 | help="whether to capture videos of the agent performances (check out `videos` folder)")
45 |
46 | # env setting
47 | parser.add_argument("--env", type=str, default="breakout",
48 | help="the id of the environment")
49 | parser.add_argument("--env-num", type=int, default=1,
50 | help="# envs in parallel")
51 | parser.add_argument("--frame-stack", type=int, default=4,
52 | help="frame stack #")
53 | parser.add_argument("--action-repeat", type=int, default=4,
54 | help="action repeat #")
55 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True)
56 |
57 | # fov setting
58 | parser.add_argument("--fov-size", type=int, default=50)
59 | parser.add_argument("--fov-init-loc", type=int, default=0)
60 | parser.add_argument("--sensory-action-mode", type=str, default="relative")
61 | parser.add_argument("--sensory-action-space", type=int, default=10) # ignored when sensory_action_mode="relative"
62 | parser.add_argument("--resize-to-full", default=False, action="store_true")
63 |     # for the discrete sensory (gaze) action grid
64 | parser.add_argument("--sensory-action-x-size", type=int, default=4)
65 | parser.add_argument("--sensory-action-y-size", type=int, default=4)
66 | # pvm setting
67 | parser.add_argument("--pvm-stack", type=int, default=3)
68 |
69 | # Algorithm specific arguments
70 | parser.add_argument("--total-timesteps", type=int, default=3000000,
71 | help="total timesteps of the experiments")
72 | parser.add_argument("--learning-rate", type=float, default=1e-4,
73 | help="the learning rate of the optimizer")
74 | parser.add_argument("--buffer-size", type=int, default=500000,
75 | help="the replay memory buffer size")
76 | parser.add_argument("--gamma", type=float, default=0.99,
77 | help="the discount factor gamma")
78 | parser.add_argument("--target-network-frequency", type=int, default=1000,
79 | help="the timesteps it takes to update the target network")
80 | parser.add_argument("--batch-size", type=int, default=32,
81 |         help="the batch size of samples from the replay memory")
82 | parser.add_argument("--start-e", type=float, default=1,
83 | help="the starting epsilon for exploration")
84 | parser.add_argument("--end-e", type=float, default=0.01,
85 | help="the ending epsilon for exploration")
86 | parser.add_argument("--exploration-fraction", type=float, default=0.10,
87 |         help="the fraction of `total-timesteps` it takes to go from start-e to end-e")
88 | parser.add_argument("--learning-starts", type=int, default=80000,
89 | help="timestep to start learning")
90 | parser.add_argument("--train-frequency", type=int, default=4,
91 | help="the frequency of training")
92 |
93 | # eval args
94 | parser.add_argument("--eval-frequency", type=int, default=-1,
95 | help="eval frequency. default -1 is eval at the end.")
96 | parser.add_argument("--eval-num", type=int, default=10,
97 |         help="the number of evaluation episodes to run")
98 | args = parser.parse_args()
99 | # fmt: on
100 | return args
101 |
102 |
103 | def make_env(env_name, seed, **kwargs):
104 | def thunk():
105 | env_args = AtariEnvArgs(
106 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs
107 | )
108 | env = AtariFixedFovealEnv(env_args)
109 | env.action_space.seed(seed)
110 | env.observation_space.seed(seed)
111 | return env
112 |
113 | return thunk
114 |
115 |
116 | # ALGO LOGIC: initialize agent here:
117 | class QNetwork(nn.Module):
118 | def __init__(self, env, override_action_set=None):
119 | super().__init__()
120 | if override_action_set:
121 | action_space_size = override_action_set.n
122 | else:
123 | if isinstance(env.single_action_space, Discrete):
124 | action_space_size = env.single_action_space.n
125 | elif isinstance(env.single_action_space, Dict):
126 | action_space_size = env.single_action_space["motor_action"].n
127 | self.network = nn.Sequential(
128 | nn.Conv2d(4, 32, 8, stride=4),
129 | nn.ReLU(),
130 | nn.Conv2d(32, 64, 4, stride=2),
131 | nn.ReLU(),
132 | nn.Conv2d(64, 64, 3, stride=1),
133 | nn.ReLU(),
134 | nn.Flatten(),
135 | nn.Linear(3136, 512),
136 | nn.ReLU(),
137 | nn.Linear(512, action_space_size),
138 | )
139 |
140 | def forward(self, x):
141 | return self.network(x)
142 |
143 |
144 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
145 | slope = (end_e - start_e) / duration
146 | return max(slope * t + start_e, end_e)
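# A worked example of the schedule above, using this file's defaults (start_e=1, end_e=0.01,
# duration = exploration_fraction * total_timesteps = 0.10 * 3e6 = 300000):
#   t = 0       -> epsilon = 1.0
#   t = 150000  -> epsilon = 1 + (0.01 - 1) / 300000 * 150000 ≈ 0.505
#   t >= 300000 -> the max() clamps epsilon at end_e = 0.01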
147 |
148 |
149 |
150 |
151 |
152 | if __name__ == "__main__":
153 | args = parse_args()
154 | args.env = args.env.lower()
155 | run_name = f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}"
156 | run_dir = os.path.join("runs", args.exp_name)
157 | if not os.path.exists(run_dir):
158 | os.makedirs(run_dir, exist_ok=True)
159 |
160 | writer = SummaryWriter(os.path.join(run_dir, run_name))
161 | writer.add_text(
162 | "hyperparameters",
163 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
164 | )
165 |
166 | # TRY NOT TO MODIFY: seeding
167 | seed_everything(args.seed)
168 |
169 |
170 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
171 |
172 |     # build a discrete set of sensory (gaze) movements
173 | OBSERVATION_SIZE = (84, 84)
174 | observ_x_max, observ_y_max = OBSERVATION_SIZE[0]-args.fov_size, OBSERVATION_SIZE[1]-args.fov_size
175 | sensory_action_step = (observ_x_max//args.sensory_action_x_size,
176 | observ_y_max//args.sensory_action_y_size)
177 | sensory_action_set = [(-sensory_action_step[0], 0),
178 | (sensory_action_step[0], 0),
179 | (0, 0),
180 | (0, -sensory_action_step[1]),
181 | (0, sensory_action_step[1])]
182 |
183 | # env setup
184 | envs = []
185 | for i in range(args.env_num):
186 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat,
187 | fov_size=(args.fov_size, args.fov_size),
188 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc),
189 | sensory_action_mode=args.sensory_action_mode,
190 | sensory_action_space=(-max(sensory_action_step), max(sensory_action_step)),
191 | resize_to_full=args.resize_to_full,
192 | clip_reward=args.clip_reward,
193 | mask_out=True))
194 | # envs = gym.vector.AsyncVectorEnv(envs)
195 | envs = gym.vector.SyncVectorEnv(envs)
196 |
197 | resize = Resize((84, 84))
198 |
199 |     # make the joint motor-sensory action space (Cartesian product of the two sets)
200 | motor_action_set = list(range(envs.single_action_space["motor_action"].n))
201 | motor_sensory_joint_action_set = []
202 | for ma in motor_action_set:
203 | for sa in sensory_action_set:
204 | motor_sensory_joint_action_set.append((ma, *sa))
205 | motor_sensory_joint_action_space = Discrete(len(motor_sensory_joint_action_set), seed=args.seed)
206 |
207 |     # make a method to separate a joint action index back into motor and sensory actions
208 | def seperate_motor_sensory_joint_action(msas: np.ndarray):
209 | mas, sas = [], []
210 | for msa in msas:
211 | msa = motor_sensory_joint_action_set[msa]
212 | mas.append(msa[0])
213 | sas.append(msa[1:])
214 | mas = np.array(mas)
215 | sas = np.array(sas)
216 | return mas, sas
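# Worked example (sizes are illustrative; the motor action count depends on the Atari game):
# with 4 motor actions and the 5 sensory moves defined above, the joint set has 4 * 5 = 20
# entries laid out motor-major, so joint index 7 decodes to motor action 7 // 5 = 1 and
# sensory move index 7 % 5 = 2, i.e. the (0, 0) "stay" move.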
217 |
218 |
219 |
220 | q_network = QNetwork(envs, motor_sensory_joint_action_space).to(device)
221 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
222 | target_network = QNetwork(envs, motor_sensory_joint_action_space).to(device)
223 | target_network.load_state_dict(q_network.state_dict())
224 |
225 | rb = ReplayBuffer(
226 | args.buffer_size,
227 | envs.single_observation_space,
228 | motor_sensory_joint_action_space,
229 | device,
230 | n_envs=envs.num_envs,
231 | optimize_memory_usage=True,
232 | handle_timeout_termination=False,
233 | )
234 | start_time = time.time()
235 |
236 | # TRY NOT TO MODIFY: start the game
237 | obs, _ = envs.reset()
238 | global_transitions = 0
239 | pvm_buffer = PVMBuffer(args.pvm_stack, (envs.num_envs, args.frame_stack,)+OBSERVATION_SIZE)
240 |
241 | while global_transitions < args.total_timesteps:
242 | pvm_buffer.append(obs)
243 | pvm_obs = pvm_buffer.get_obs(mode="stack_max")
244 | # ALGO LOGIC: put action logic here
245 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_transitions)
246 | if random.random() < epsilon:
247 | actions = np.array([motor_sensory_joint_action_space.sample() for _ in range(envs.num_envs)])
248 | else:
249 | q_values = q_network(resize(torch.from_numpy(pvm_obs)).to(device))
250 | actions = torch.argmax(q_values, dim=1).cpu().numpy()
251 |
252 | # print (actions)
253 | motor_actions, sensory_actions = seperate_motor_sensory_joint_action(actions)
254 |
255 | # TRY NOT TO MODIFY: execute the game and log data.
256 | next_obs, rewards, dones, _, infos = envs.step({"motor_action": motor_actions,
257 | "sensory_action": sensory_actions })
258 |
259 | # TRY NOT TO MODIFY: record rewards for plotting purposes
260 | if "final_info" in infos:
261 | for idx, d in enumerate(dones):
262 | if d:
263 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]")
264 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions)
265 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions)
266 | writer.add_scalar("charts/epsilon", epsilon, global_transitions)
267 | break
268 |
269 |         # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
270 |         real_next_obs = next_obs.copy()  # copy so the terminal-observation patch below does not leak into next_obs / obs
271 | for idx, d in enumerate(dones):
272 | if d:
273 | real_next_obs[idx] = infos["final_observation"][idx]
274 | fov_idx = random.randint(0, len(sensory_action_set)-1)
275 | pvm_buffer_copy = pvm_buffer.copy()
276 | pvm_buffer_copy.append(real_next_obs)
277 | real_next_pvm_obs = pvm_buffer_copy.get_obs(mode="stack_max")
278 | rb.add(pvm_obs, real_next_pvm_obs, actions, rewards, dones, {})
279 |
280 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
281 | obs = next_obs
282 |
283 | # INC total transitions
284 | global_transitions += args.env_num
285 |
286 |
287 |         obs_backup = obs  # back up obs; the eval loop below reuses the name
288 | # ALGO LOGIC: training.
289 | if global_transitions > args.learning_starts:
290 | if global_transitions % args.train_frequency == 0:
291 |                 data = rb.sample(args.batch_size // args.env_num)  # scale the batch by env_num so the data-to-update ratio matches the single-env setting
292 | with torch.no_grad():
293 | target_max, _ = target_network(resize(data.next_observations)).max(dim=1)
294 | td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
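# This is the standard DQN target y = r + gamma * max_a' Q_target(s', a'), with the
# (1 - done) factor cutting off bootstrapping at terminal transitions; here a' ranges
# over the flattened joint motor-sensory actions.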
295 | old_val = q_network(resize(data.observations)).gather(1, data.actions).squeeze()
296 | loss = F.mse_loss(td_target, old_val)
297 |
298 | if global_transitions % 100 == 0:
299 | writer.add_scalar("losses/td_loss", loss, global_transitions)
300 | writer.add_scalar("losses/q_values", old_val.mean().item(), global_transitions)
301 | # print("SPS:", int(global_step / (time.time() - start_time)))
302 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions)
303 |
304 | # optimize the model
305 | optimizer.zero_grad()
306 | loss.backward()
307 | optimizer.step()
308 |
309 | # update the target network
310 | if (global_transitions // args.env_num) % args.target_network_frequency == 0:
311 | target_network.load_state_dict(q_network.state_dict())
312 |
313 | # evaluation
314 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \
315 | (global_transitions >= args.total_timesteps):
316 | q_network.eval()
317 |
318 | eval_episodic_returns, eval_episodic_lengths = [], []
319 |
320 | for eval_ep in range(args.eval_num):
321 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat,
322 | fov_size=(args.fov_size, args.fov_size),
323 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc),
324 | sensory_action_mode=args.sensory_action_mode,
325 | sensory_action_space=(-max(sensory_action_step), max(sensory_action_step)),
326 | resize_to_full=args.resize_to_full,
327 | clip_reward=args.clip_reward,
328 | training=False,
329 | mask_out=True)]
330 | eval_env = gym.vector.SyncVectorEnv(eval_env)
331 | obs, _ = eval_env.reset()
332 | done = False
333 | pvm_buffer = PVMBuffer(args.pvm_stack, (eval_env.num_envs, args.frame_stack,)+OBSERVATION_SIZE)
334 | while not done:
335 | pvm_buffer.append(obs)
336 | pvm_obs = pvm_buffer.get_obs(mode="stack_max")
337 | q_values = q_network(resize(torch.from_numpy(pvm_obs)).to(device))
338 | actions = torch.argmax(q_values, dim=1).cpu().numpy()
339 | motor_actions, sensory_actions = seperate_motor_sensory_joint_action(actions)
340 | next_obs, rewards, dones, _, infos = eval_env.step({"motor_action": motor_actions,
341 | "sensory_action": sensory_actions})
342 | obs = next_obs
343 | done = dones[0]
344 | if done:
345 | eval_episodic_returns.append(infos['final_info'][0]["reward"])
346 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"])
347 |
348 | if args.capture_video:
349 |                     record_file_dir = os.path.join("recordings", args.exp_name, os.path.basename(__file__).replace(".py", ""))
350 | os.makedirs(record_file_dir, exist_ok=True)
351 | record_file_fn = f"seed{args.seed}_step{global_transitions:07d}_eval{eval_ep:02d}_record.pt"
352 | eval_env.envs[0].save_record_to_file(os.path.join(record_file_dir, record_file_fn))
353 | if eval_ep == 0:
354 |                     model_file_dir = os.path.join("trained_models", args.exp_name, os.path.basename(__file__).replace(".py", ""), args.env)
355 | os.makedirs(model_file_dir, exist_ok=True)
356 | model_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_model.pt"
357 | torch.save({"sfn": None, "q": q_network.state_dict()}, os.path.join(model_file_dir, model_fn))
358 |
359 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions)
360 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions)
361 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]")
362 |
363 | q_network.train()
364 | obs = obs_backup # restore obs if eval occurs
365 |
366 |
367 | envs.close()
368 | eval_env.close()
369 | writer.close()
370 |
--------------------------------------------------------------------------------
/agent/sac_atari_base.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import random
5 | import time
6 | from distutils.util import strtobool
7 |
8 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__))))
9 | os.environ["OMP_NUM_THREADS"] = "1"
10 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
11 | import warnings
12 | warnings.filterwarnings("ignore", category=UserWarning)
13 |
14 | import gymnasium as gym
15 | import numpy as np
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | import torch.optim as optim
20 | from torch.distributions.categorical import Categorical
21 | from common.buffer import ReplayBuffer
22 | from common.utils import get_timestr, seed_everything
23 | from torch.utils.tensorboard import SummaryWriter
24 |
25 | from active_gym import AtariBaseEnv, AtariEnvArgs
26 |
27 |
28 | def parse_args():
29 | # fmt: off
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__),
32 | help="the name of this experiment")
33 | parser.add_argument("--seed", type=int, default=1,
34 | help="seed of the experiment")
35 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
36 | help="if toggled, cuda will be enabled by default")
37 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
38 | help="whether to capture videos of the agent performances (check out `videos` folder)")
39 |
40 | # env setting
41 | parser.add_argument("--env", type=str, default="breakout",
42 | help="the id of the environment")
43 | parser.add_argument("--env-num", type=int, default=1,
44 | help="# envs in parallel")
45 | parser.add_argument("--frame-stack", type=int, default=4,
46 | help="frame stack #")
47 | parser.add_argument("--action-repeat", type=int, default=4,
48 | help="action repeat #")
49 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True)
50 |
51 | # Algorithm specific arguments
52 | parser.add_argument("--total-timesteps", type=int, default=5000000,
53 | help="total timesteps of the experiments")
54 | parser.add_argument("--buffer-size", type=int, default=int(1e6),
55 | help="the replay memory buffer size") # smaller than in original paper but evaluation is done only for 100k steps anyway
56 | parser.add_argument("--gamma", type=float, default=0.99,
57 | help="the discount factor gamma")
58 | parser.add_argument("--tau", type=float, default=1.0,
59 | help="target smoothing coefficient (default: 1)") # Default is 1 to perform replacement update
60 | parser.add_argument("--batch-size", type=int, default=64,
61 |         help="the batch size of samples from the replay memory")
62 |     parser.add_argument("--learning-starts", type=int, default=20000,
63 | help="timestep to start learning")
64 | parser.add_argument("--policy-lr", type=float, default=3e-4,
65 | help="the learning rate of the policy network optimizer")
66 | parser.add_argument("--q-lr", type=float, default=3e-4,
67 | help="the learning rate of the Q network network optimizer")
68 | parser.add_argument("--update-frequency", type=int, default=4,
69 | help="the frequency of training updates")
70 | parser.add_argument("--target-network-frequency", type=int, default=8000,
71 | help="the frequency of updates for the target networks")
72 | parser.add_argument("--alpha", type=float, default=0.2,
73 | help="Entropy regularization coefficient.")
74 | parser.add_argument("--autotune", type=lambda x:bool(strtobool(x)), default=True, nargs="?", const=True,
75 | help="automatic tuning of the entropy coefficient")
76 | parser.add_argument("--target-entropy-scale", type=float, default=0.89,
77 | help="coefficient for scaling the autotune entropy target")
78 |
79 | # eval args
80 | parser.add_argument("--eval-frequency", type=int, default=-1,
81 | help="eval frequency. default -1 is eval at the end.")
82 | parser.add_argument("--eval-num", type=int, default=10,
83 |         help="the number of evaluation episodes to run")
84 |
85 | args = parser.parse_args()
86 | # fmt: on
87 | return args
88 |
89 |
90 | def make_env(env_name, seed, **kwargs):
91 | def thunk():
92 | env_args = AtariEnvArgs(
93 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs
94 | )
95 | env = AtariBaseEnv(env_args)
96 | env.action_space.seed(seed)
97 | env.observation_space.seed(seed)
98 | return env
99 |
100 | return thunk
101 |
102 |
103 | def layer_init(layer, bias_const=0.0):
104 | nn.init.kaiming_normal_(layer.weight)
105 | torch.nn.init.constant_(layer.bias, bias_const)
106 | return layer
107 |
108 |
109 | # ALGO LOGIC: initialize agent here:
110 | # NOTE: Sharing a CNN encoder between Actor and Critics is not recommended for SAC without stopping actor gradients
111 | # See the SAC+AE paper https://arxiv.org/abs/1910.01741 for more info
112 | # TL;DR The actor's gradients mess up the representation when using a joint encoder
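# A minimal sketch (not what this file does -- it keeps fully separate networks) of how a
# shared encoder could be used safely, assuming hypothetical `shared_encoder` / `*_head` modules:
#
#   features = shared_encoder(obs)             # critic losses update the encoder
#   q1, q2 = qf1_head(features), qf2_head(features)
#   logits = actor_head(features.detach())     # detach stops actor gradients at the encoder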
113 | class SoftQNetwork(nn.Module):
114 | def __init__(self, envs):
115 | super().__init__()
116 | obs_shape = envs.single_observation_space.shape
117 | self.conv = nn.Sequential(
118 | layer_init(nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4)),
119 | nn.ReLU(),
120 | layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
121 | nn.ReLU(),
122 | layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
123 | nn.Flatten(),
124 | )
125 |
126 | with torch.inference_mode():
127 | output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]
128 |
129 | self.fc1 = layer_init(nn.Linear(output_dim, 512))
130 | self.fc_q = layer_init(nn.Linear(512, envs.single_action_space.n))
131 |
132 | def forward(self, x):
133 | x = F.relu(self.conv(x))
134 | x = F.relu(self.fc1(x))
135 | q_vals = self.fc_q(x)
136 | return q_vals
137 |
138 |
139 | class Actor(nn.Module):
140 | def __init__(self, envs):
141 | super().__init__()
142 | obs_shape = envs.single_observation_space.shape
143 | self.conv = nn.Sequential(
144 | layer_init(nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4)),
145 | nn.ReLU(),
146 | layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
147 | nn.ReLU(),
148 | layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
149 | nn.Flatten(),
150 | )
151 |
152 | with torch.inference_mode():
153 | output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]
154 |
155 | self.fc1 = layer_init(nn.Linear(output_dim, 512))
156 | self.fc_logits = layer_init(nn.Linear(512, envs.single_action_space.n))
157 |
158 | def forward(self, x):
159 | x = F.relu(self.conv(x))
160 | x = F.relu(self.fc1(x))
161 | logits = self.fc_logits(x)
162 |
163 | return logits
164 |
165 | def get_action(self, x):
166 | logits = self(x)
167 | policy_dist = Categorical(logits=logits)
168 | action = policy_dist.sample()
169 | # Action probabilities for calculating the adapted soft-Q loss
170 | action_probs = policy_dist.probs
171 | log_prob = F.log_softmax(logits, dim=1)
172 | return action, log_prob, action_probs
173 |
174 |
175 | if __name__ == "__main__":
176 | args = parse_args()
177 | args.env = args.env.lower()
178 | run_name = f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}"
179 | run_dir = os.path.join("runs", args.exp_name)
180 | if not os.path.exists(run_dir):
181 | os.makedirs(run_dir, exist_ok=True)
182 |
183 | writer = SummaryWriter(os.path.join(run_dir, run_name))
184 | writer.add_text(
185 | "hyperparameters",
186 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
187 | )
188 |
189 | # TRY NOT TO MODIFY: seeding
190 | seed_everything(args.seed)
191 |
192 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
193 |
194 | envs = []
195 | for i in range(args.env_num):
196 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat, clip_reward=args.clip_reward))
197 | # envs = gym.vector.AsyncVectorEnv(envs)
198 | envs = gym.vector.SyncVectorEnv(envs)
199 |
200 | actor = Actor(envs).to(device)
201 | qf1 = SoftQNetwork(envs).to(device)
202 | qf2 = SoftQNetwork(envs).to(device)
203 | qf1_target = SoftQNetwork(envs).to(device)
204 | qf2_target = SoftQNetwork(envs).to(device)
205 | qf1_target.load_state_dict(qf1.state_dict())
206 | qf2_target.load_state_dict(qf2.state_dict())
207 | # TRY NOT TO MODIFY: eps=1e-4 increases numerical stability
208 | q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr, eps=1e-4)
209 | actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr, eps=1e-4)
210 |
211 | # Automatic entropy tuning
212 | if args.autotune:
213 | target_entropy = -args.target_entropy_scale * torch.log(1 / torch.tensor(envs.single_action_space.n))
214 | log_alpha = torch.zeros(1, requires_grad=True, device=device)
215 | alpha = log_alpha.exp().item()
216 | a_optimizer = optim.Adam([log_alpha], lr=args.q_lr, eps=1e-4)
217 | else:
218 | alpha = args.alpha
219 |
220 | rb = ReplayBuffer(
221 | args.buffer_size,
222 | envs.single_observation_space,
223 | envs.single_action_space,
224 | device,
225 | n_envs=envs.num_envs,
226 | optimize_memory_usage=True,
227 | handle_timeout_termination=False,
228 | )
229 | start_time = time.time()
230 |
231 | # TRY NOT TO MODIFY: start the game
232 | obs, infos = envs.reset()
233 | global_transitions = 0
234 | while global_transitions < args.total_timesteps:
235 | # ALGO LOGIC: put action logic here
236 | if global_transitions < args.learning_starts:
237 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
238 | else:
239 | actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
240 | actions = actions.detach().cpu().numpy()
241 |
242 | # TRY NOT TO MODIFY: execute the game and log data.
243 | next_obs, rewards, dones, _, infos = envs.step(actions)
244 |
245 | # TRY NOT TO MODIFY: record rewards for plotting purposes
246 | if "final_info" in infos:
247 | for idx, d in enumerate(dones):
248 | if d:
249 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]")
250 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions)
251 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions)
252 | break
253 |
254 |         # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
255 | real_next_obs = next_obs.copy()
256 | for idx, d in enumerate(dones):
257 | if d:
258 | real_next_obs[idx] = infos["final_observation"][idx]
259 | rb.add(obs, real_next_obs, actions, rewards, dones, infos)
260 |
261 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
262 | obs = next_obs
263 |
264 | # INC total transitions
265 | global_transitions += args.env_num
266 |
267 | # ALGO LOGIC: training.
268 | if global_transitions > args.learning_starts:
269 | if global_transitions % args.update_frequency == 0:
270 | data = rb.sample(args.batch_size)
271 | # CRITIC training
272 | with torch.no_grad():
273 | _, next_state_log_pi, next_state_action_probs = actor.get_action(data.next_observations)
274 | qf1_next_target = qf1_target(data.next_observations)
275 | qf2_next_target = qf2_target(data.next_observations)
276 | # we can use the action probabilities instead of MC sampling to estimate the expectation
277 | min_qf_next_target = next_state_action_probs * (
278 | torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
279 | )
280 | # adapt Q-target for discrete Q-function
281 | min_qf_next_target = min_qf_next_target.sum(dim=1)
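# i.e. the soft state value V(s') = sum_a pi(a|s') * (min_i Q_i_target(s', a) - alpha * log pi(a|s')),
# computed exactly from the categorical policy instead of by sampling actions.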
282 | next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target)
283 |
284 | # use Q-values only for the taken actions
285 | qf1_values = qf1(data.observations)
286 | qf2_values = qf2(data.observations)
287 | qf1_a_values = qf1_values.gather(1, data.actions.long()).view(-1)
288 | qf2_a_values = qf2_values.gather(1, data.actions.long()).view(-1)
289 | qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
290 | qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
291 | qf_loss = qf1_loss + qf2_loss
292 |
293 | q_optimizer.zero_grad()
294 | qf_loss.backward()
295 | q_optimizer.step()
296 |
297 | # ACTOR training
298 | _, log_pi, action_probs = actor.get_action(data.observations)
299 | with torch.no_grad():
300 | qf1_values = qf1(data.observations)
301 | qf2_values = qf2(data.observations)
302 | min_qf_values = torch.min(qf1_values, qf2_values)
303 | # no need for reparameterization, the expectation can be calculated for discrete actions
304 | actor_loss = (action_probs * ((alpha * log_pi) - min_qf_values)).mean()
305 |
306 | actor_optimizer.zero_grad()
307 | actor_loss.backward()
308 | actor_optimizer.step()
309 |
310 | if args.autotune:
311 | # re-use action probabilities for temperature loss
312 | alpha_loss = (action_probs.detach() * (-log_alpha * (log_pi + target_entropy).detach())).mean()
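# Discrete-action form of the temperature objective J(alpha) = E_{a~pi}[-alpha * (log pi(a|s) + target_entropy)]
# (optimized through log_alpha): alpha increases when the policy's entropy drops below
# target_entropy and decreases otherwise.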
313 |
314 | a_optimizer.zero_grad()
315 | alpha_loss.backward()
316 | a_optimizer.step()
317 | alpha = log_alpha.exp().item()
318 |
319 | # update the target networks
320 | if global_transitions % args.target_network_frequency == 0:
321 | for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
322 | target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
323 | for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
324 | target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
325 |
326 | if global_transitions % 100 == 0:
327 | writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_transitions)
328 | writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_transitions)
329 | writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_transitions)
330 | writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_transitions)
331 | writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_transitions)
332 | writer.add_scalar("losses/actor_loss", actor_loss.item(), global_transitions)
333 | writer.add_scalar("losses/alpha", alpha, global_transitions)
334 | # print("SPS:", int(global_transitions / (time.time() - start_time)))
335 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions)
336 | if args.autotune:
337 | writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_transitions)
338 |
339 | # evaluation
340 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \
341 | (global_transitions >= args.total_timesteps):
342 | qf1.eval()
343 | qf2.eval()
344 | actor.eval()
345 |
346 | eval_episodic_returns, eval_episodic_lengths = [], []
347 |
348 | for eval_ep in range(args.eval_num):
349 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat, clip_reward=args.clip_reward, training=False, record=args.capture_video)]
350 | eval_env = gym.vector.SyncVectorEnv(eval_env)
351 | obs, _ = eval_env.reset()
352 | done = False
353 | while not done:
354 | actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
355 | actions = actions.detach().cpu().numpy()
356 | next_obs, rewards, dones, _, infos = eval_env.step(actions)
357 | obs = next_obs
358 | done = dones[0]
359 | if done:
360 | eval_episodic_returns.append(infos['final_info'][0]["reward"])
361 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"])
362 | if args.capture_video:
363 |                         record_file_dir = os.path.join("recordings", args.exp_name, os.path.basename(__file__).replace(".py", ""), args.env)
364 | os.makedirs(record_file_dir, exist_ok=True)
365 | record_file_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_eval{eval_ep:02d}_record.pt"
366 | eval_env.envs[0].save_record_to_file(os.path.join(record_file_dir, record_file_fn))
367 |                     model_file_dir = os.path.join("trained_models", args.exp_name, os.path.basename(__file__).replace(".py", ""), args.env)
368 | os.makedirs(model_file_dir, exist_ok=True)
369 |                     model_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_eval{eval_ep:02d}_model.pt"
370 | torch.save({"sfn": None, "qf1": qf1.state_dict(), "qf2": qf2.state_dict(), "actor": actor.state_dict()}, os.path.join(model_file_dir, model_fn))
371 |
372 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions)
373 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions)
374 | # writer.add_scalar("charts/eval_episodic_length", np.mean(), global_transitions)
375 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]")
376 |
377 | qf1.train()
378 | qf2.train()
379 | actor.train()
380 |
381 | envs.close()
382 | eval_env.close()
383 | writer.close()
--------------------------------------------------------------------------------
/agent/drqv2_dmc_base.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import random
5 | import time
6 | from itertools import product
7 | from distutils.util import strtobool
8 |
9 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__))))
10 | os.environ["OMP_NUM_THREADS"] = "1"
11 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
12 | import warnings
13 | warnings.filterwarnings("ignore", category=UserWarning)
14 |
15 | import gymnasium as gym
16 | from gymnasium.spaces import Discrete, Dict
17 | import numpy as np
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | import torch.optim as optim
22 | from torchvision.transforms import Resize
23 | from torch.utils.tensorboard import SummaryWriter
24 |
25 | from common.buffer import NstepRewardReplayBuffer
26 | from common.pvm_buffer import PVMBuffer
27 | from common.utils import (
28 | get_timestr,
29 | schedule_drq,
30 | seed_everything,
31 | soft_update_params,
32 | weight_init_drq,
33 | TruncatedNormal
34 | )
35 |
36 | from active_gym import DMCBaseEnv, DMCEnvArgs
37 |
38 |
39 | def parse_args():
40 | # fmt: off
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__),
43 | help="the name of this experiment")
44 | parser.add_argument("--seed", type=int, default=1,
45 | help="seed of the experiment")
46 | parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
47 | help="if toggled, `torch.backends.cudnn.deterministic=False`")
48 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
49 | help="if toggled, cuda will be enabled by default")
50 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
51 | help="whether to capture videos of the agent performances")
52 |
53 | # env setting
54 | parser.add_argument("--domain-name", type=str, default="walker",
55 | help="the name of the dmc domain")
56 | parser.add_argument("--task-name", type=str, default="walk",
57 | help="the name of the dmc task")
58 | parser.add_argument("--env-num", type=int, default=1,
59 | help="# envs in parallel")
60 | parser.add_argument("--frame-stack", type=int, default=3,
61 | help="frame stack #")
62 | parser.add_argument("--action-repeat", type=int, default=2,
63 | help="action repeat #") # i.e. frame skip
64 |     parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True) # rewards are usually left unclipped for DMC
65 |
66 | # fov setting
67 | parser.add_argument("--fov-size", type=int, default=50)
68 | parser.add_argument("--fov-init-loc", type=int, default=0)
69 | parser.add_argument("--sensory-action-mode", type=str, default="absolute")
70 | parser.add_argument("--sensory-action-space", type=int, default=10) # ignored when sensory_action_mode="relative"
71 | parser.add_argument("--resize-to-full", default=False, action="store_true")
72 |     # for the discrete sensory (gaze) action grid
73 | parser.add_argument("--sensory-action-x-size", type=int, default=4)
74 | parser.add_argument("--sensory-action-y-size", type=int, default=4)
75 | # pvm setting
76 | parser.add_argument("--pvm-stack", type=int, default=3)
77 |
78 | # Algorithm specific arguments
79 | parser.add_argument("--total-timesteps", type=int, default=1000000,
80 | help="total timesteps of the experiments")
81 | parser.add_argument("--buffer-size", type=int, default=int(1e6),
82 | help="the replay memory buffer size")
83 | parser.add_argument("--gamma", type=float, default=0.99,
84 | help="the discount factor gamma")
85 | parser.add_argument("--tau", type=float, default=0.01,
86 | help="target smoothing coefficient (default: 0.01)")
87 | parser.add_argument("--batch-size", type=int, default=256,
88 |         help="the batch size of samples from the replay memory")
89 | parser.add_argument("--learning-starts", type=int, default=2000,
90 | help="timestep to start learning")
91 | parser.add_argument("--lr", type=float, default=1e-4,
92 | help="the learning rate of drq")
93 | parser.add_argument("--update-frequency", type=int, default=2,
94 | help="update frequency of drq")
95 | parser.add_argument("--stddev-clip", type=float, default=0.3)
96 | parser.add_argument("--stddev-schedule", type=str, default="linear(1.0,0.1,50000)")
97 | parser.add_argument("--feature-dim", type=int, default=50)
98 | parser.add_argument("--hidden-dim", type=int, default=1024)
99 | parser.add_argument("--n-step-reward", type=int, default=3)
100 |
101 | # eval args
102 | parser.add_argument("--eval-frequency", type=int, default=-1,
103 | help="eval frequency. default -1 is eval at the end.")
104 | parser.add_argument("--eval-num", type=int, default=10,
105 | help="eval episodes")
106 |
107 | args = parser.parse_args()
108 | # fmt: on
109 | return args
110 |
111 |
112 | def make_env(domain_name, task_name, seed, **kwargs):
113 | def thunk():
114 | env_args = DMCEnvArgs(
115 | domain_name=domain_name, task_name=task_name, seed=seed, obs_size=(84, 84), **kwargs
116 | )
117 | env = DMCBaseEnv(env_args)
118 | env.action_space.seed(seed)
119 | env.observation_space.seed(seed)
120 | return env
121 | return thunk
122 |
123 | class RandomShiftsAug(nn.Module):
124 | def __init__(self, pad):
125 | super().__init__()
126 | self.pad = pad
127 |
128 | def forward(self, x):
129 | n, c, h, w = x.size()
130 | assert h == w
131 | padding = tuple([self.pad] * 4)
132 | x = F.pad(x, padding, 'replicate')
133 | eps = 1.0 / (h + 2 * self.pad)
134 | arange = torch.linspace(-1.0 + eps,
135 | 1.0 - eps,
136 | h + 2 * self.pad,
137 | device=x.device,
138 | dtype=x.dtype)[:h]
139 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
140 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
141 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
142 |
143 | shift = torch.randint(0,
144 | 2 * self.pad + 1,
145 | size=(n, 1, 1, 2),
146 | device=x.device,
147 | dtype=x.dtype)
148 | shift *= 2.0 / (h + 2 * self.pad)
149 |
150 | grid = base_grid + shift
151 | return F.grid_sample(x,
152 | grid,
153 | padding_mode='zeros',
154 | align_corners=False)
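# RandomShiftsAug is DrQ-v2's random-shift image augmentation: replicate-pad the frame by
# `pad` pixels, then bilinearly resample it on a grid translated by a per-sample integer
# offset drawn from [-pad, pad], which amounts to a random crop back to the original HxW.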
155 |
156 |
157 | class Encoder(nn.Module):
158 | def __init__(self, obs_shape):
159 | super().__init__()
160 |
161 | assert len(obs_shape) == 3
162 | self.repr_dim = 32 * 35 * 35
163 |
164 | self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2),
165 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
166 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
167 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
168 | nn.ReLU())
169 |
170 | self.apply(weight_init_drq)
171 |
172 | def forward(self, obs):
173 | obs = obs - 0.5 # /255 is done by env
174 | h = self.convnet(obs)
175 | h = h.view(h.shape[0], -1)
176 | return h
177 |
178 |
179 | class Actor(nn.Module):
180 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
181 | super().__init__()
182 |
183 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
184 | nn.LayerNorm(feature_dim), nn.Tanh())
185 |
186 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
187 | nn.ReLU(inplace=True),
188 | nn.Linear(hidden_dim, hidden_dim),
189 | nn.ReLU(inplace=True),
190 | nn.Linear(hidden_dim, action_shape[0]))
191 |
192 | self.apply(weight_init_drq)
193 |
194 | def forward(self, obs, std):
195 | h = self.trunk(obs)
196 |
197 | mu = self.policy(h)
198 | mu = torch.tanh(mu)
199 | std = torch.ones_like(mu) * std
200 |
201 | dist = TruncatedNormal(mu, std)
202 | return dist
203 |
204 |
205 | class Critic(nn.Module):
206 | def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
207 | super().__init__()
208 |
209 | self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
210 | nn.LayerNorm(feature_dim), nn.Tanh())
211 |
212 | self.Q1 = nn.Sequential(
213 | nn.Linear(feature_dim + action_shape[0], hidden_dim),
214 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
215 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))
216 |
217 | self.Q2 = nn.Sequential(
218 | nn.Linear(feature_dim + action_shape[0], hidden_dim),
219 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
220 | nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))
221 |
222 | self.apply(weight_init_drq)
223 |
224 | def forward(self, obs, action):
225 | h = self.trunk(obs)
226 | h_action = torch.cat([h, action], dim=-1)
227 | q1 = self.Q1(h_action)
228 | q2 = self.Q2(h_action)
229 |
230 | return q1, q2
231 |
232 |
233 | class DrQV2Agent:
234 | def __init__(self, obs_shape, action_shape, device, lr, feature_dim,
235 | hidden_dim, critic_target_tau, learning_starts,
236 | update_every_steps, stddev_schedule, stddev_clip):
237 | self.device = device
238 | self.critic_target_tau = critic_target_tau
239 | self.update_every_steps = update_every_steps
240 | self.learning_starts = learning_starts
241 | self.stddev_schedule = stddev_schedule
242 | self.stddev_clip = stddev_clip
243 |
244 | # models
245 | self.encoder = Encoder(obs_shape).to(device)
246 | self.actor = Actor(self.encoder.repr_dim, action_shape, feature_dim,
247 | hidden_dim).to(device)
248 |
249 | self.critic = Critic(self.encoder.repr_dim, action_shape, feature_dim,
250 | hidden_dim).to(device)
251 | self.critic_target = Critic(self.encoder.repr_dim, action_shape,
252 | feature_dim, hidden_dim).to(device)
253 | self.critic_target.load_state_dict(self.critic.state_dict())
254 |
255 | # optimizers
256 | self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=lr)
257 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
258 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)
259 |
260 | # data augmentation
261 | self.aug = RandomShiftsAug(pad=4)
262 |
263 | self.train()
264 | self.critic_target.train()
265 |
266 | def train(self, training=True):
267 | self.training = training
268 | self.encoder.train(training)
269 | self.actor.train(training)
270 | self.critic.train(training)
271 |
272 | def eval(self):
273 | self.train(False)
274 |
275 | def act(self, obs, step, eval_mode=None) -> np.ndarray:
276 | obs = self.encoder(obs)
277 | stddev = schedule_drq(self.stddev_schedule, step)
278 | dist = self.actor(obs, stddev)
279 |
280 | # auto eval
281 | if eval_mode is None:
282 | eval_mode = not self.training
283 |
284 | if eval_mode:
285 | action = dist.mean
286 | else:
287 | action = dist.sample(clip=None)
288 | if step < self.learning_starts:
289 | action.uniform_(-1.0, 1.0)
290 | return action.detach().cpu().numpy()
291 |
292 | def update_critic(self, obs, action, reward, discount, next_obs, step):
293 | metrics = dict()
294 |
295 | with torch.no_grad():
296 | stddev = schedule_drq(self.stddev_schedule, step)
297 | dist = self.actor(next_obs, stddev)
298 | next_action = dist.sample(clip=self.stddev_clip)
299 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
300 | target_V = torch.min(target_Q1, target_Q2)
301 | target_Q = reward + (discount * target_V)
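# Clipped double-Q TD target: y = R + discount * min_i Q_target_i(s', a'), with a' sampled from the
# actor under the clipped-stddev schedule. R and `discount` are assumed to arrive n-step-accumulated
# from the NstepRewardReplayBuffer configured with n_step_reward and gamma in the main script.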
302 |
303 | Q1, Q2 = self.critic(obs, action)
304 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
305 |
306 | # optimize encoder and critic
307 | self.encoder_opt.zero_grad(set_to_none=True)
308 | self.critic_opt.zero_grad(set_to_none=True)
309 | critic_loss.backward()
310 | self.critic_opt.step()
311 | self.encoder_opt.step()
312 |
313 | return metrics
314 |
315 | def update_actor(self, obs, step):
316 | metrics = dict()
317 |
318 | stddev = schedule_drq(self.stddev_schedule, step)
319 | dist = self.actor(obs, stddev)
320 | action = dist.sample(clip=self.stddev_clip)
321 | log_prob = dist.log_prob(action).sum(-1, keepdim=True)
322 | Q1, Q2 = self.critic(obs, action)
323 | Q = torch.min(Q1, Q2)
324 |
325 | actor_loss = -Q.mean()
326 |
327 | # optimize actor
328 | self.actor_opt.zero_grad(set_to_none=True)
329 | actor_loss.backward()
330 | self.actor_opt.step()
331 |
332 | return metrics
333 |
334 | def update(self, batch, step):
335 | metrics = dict()
336 |
337 | if step % self.update_every_steps != 0:
338 | return metrics
339 |
340 | obs = batch.observations
341 | action = batch.actions
342 | next_obs = batch.next_observations
343 | reward = batch.rewards
344 | discount = batch.discounts
345 |
346 | # augment
347 | obs = self.aug(obs.float())
348 | next_obs = self.aug(next_obs.float())
349 | # encode
350 | obs = self.encoder(obs)
351 | with torch.no_grad():
352 | next_obs = self.encoder(next_obs)
353 |
354 |
355 | self.update_critic(obs, action, reward, discount, next_obs, step)
356 |
357 | self.update_actor(obs.detach(), step)
358 |
359 | # update critic target
360 | soft_update_params(self.critic, self.critic_target,
361 | self.critic_target_tau)
362 |
363 | return metrics
364 |
365 | def save_agent(self):
366 | agent = {
367 | "encoder": self.encoder.state_dict(),
368 | "critic": self.critic.state_dict(),
369 | "critic_target": self.critic_target.state_dict(),
370 | "actor": self.actor.state_dict()
371 | }
372 | return agent
373 |
374 |
375 | if __name__ == "__main__":
376 | args = parse_args()
377 | args.domain_name = args.domain_name.lower()
378 | args.task_name = args.task_name.lower()
379 | run_name = f"{args.domain_name}-{args.task_name}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}"
380 | run_dir = os.path.join("runs", args.exp_name)
381 | if not os.path.exists(run_dir):
382 | os.makedirs(run_dir, exist_ok=True)
383 |
384 | writer = SummaryWriter(os.path.join(run_dir, run_name))
385 | writer.add_text(
386 | "hyperparameters",
387 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
388 | )
389 |
390 | # TRY NOT TO MODIFY: seeding
391 | seed_everything(args.seed)
392 |
393 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
394 |
395 | # env setup
396 | envs = []
397 | for i in range(args.env_num):
398 | envs.append(make_env(args.domain_name,
399 | args.task_name,
400 | args.seed+i,
401 | frame_stack=args.frame_stack,
402 | action_repeat=args.action_repeat,
403 | clip_reward=False))
404 | envs = gym.vector.SyncVectorEnv(envs)
405 |
406 | drq_agent = DrQV2Agent(
407 | envs.single_observation_space.shape,
408 | envs.single_action_space.shape,
409 | device=device,
410 | lr=args.lr,
411 | feature_dim=args.feature_dim,
412 | hidden_dim=args.hidden_dim,
413 | critic_target_tau=args.tau,
414 | learning_starts=args.learning_starts,
415 | update_every_steps=args.update_frequency,
416 | stddev_schedule=args.stddev_schedule,
417 | stddev_clip=args.stddev_clip
418 | )
419 |
420 | rb = NstepRewardReplayBuffer(
421 | n_step_reward=args.n_step_reward,
422 | gamma=args.gamma,
423 | buffer_size=args.buffer_size,
424 | observation_space=envs.single_observation_space,
425 | action_space=envs.single_action_space,
426 | device=device,
427 | handle_timeout_termination=True,
428 | )
429 | start_time = time.time()
430 |
431 | # TRY NOT TO MODIFY: start the game
432 | obs, infos = envs.reset()
433 | global_transitions = 0
434 | while global_transitions < args.total_timesteps:
435 | actions = drq_agent.act(torch.Tensor(obs).to(device), global_transitions)
436 | next_obs, rewards, dones, _, infos = envs.step(actions)
437 |
438 | if "final_info" in infos:
439 | for idx, d in enumerate(dones):
440 | if d:
441 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]")
442 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions)
443 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions)
444 | break
445 |
446 | real_next_obs = next_obs.copy()
447 | for idx, d in enumerate(dones):
448 | if d:
449 | real_next_obs[idx] = infos["final_observation"][idx]
450 | rb.add(obs, real_next_obs, actions, rewards, dones, [{}])
451 |
452 | obs = next_obs
453 | global_transitions += args.env_num
454 |
455 | # ALGO LOGIC: training.
456 | if global_transitions > args.learning_starts:
457 | data = rb.sample(args.batch_size)
458 | drq_agent.update(data, global_transitions)
459 |
460 | if global_transitions % 100 == 0:
461 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions)
462 |
463 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \
464 | (global_transitions >= args.total_timesteps):
465 | drq_agent.eval()
466 |
467 | eval_episodic_returns, eval_episodic_lengths = [], []
468 | for eval_ep in range(args.eval_num):
469 | eval_env=[make_env(args.domain_name,
470 | args.task_name,
471 | args.seed+eval_ep,
472 | frame_stack=args.frame_stack,
473 | action_repeat=args.action_repeat,
474 | clip_reward=False)]
475 | eval_env = gym.vector.SyncVectorEnv(eval_env)
476 | eval_obs, infos = eval_env.reset()
477 | done = False
478 | while not done:
479 | actions = drq_agent.act(torch.Tensor(eval_obs).to(device), step=global_transitions)
480 | eval_next_obs, rewards, dones, _, infos = eval_env.step(actions)
481 | eval_obs = eval_next_obs
482 | done = dones[0]
483 | if done:
484 | eval_episodic_returns.append(infos['final_info'][0]["reward"])
485 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"])
486 | if args.capture_video:
487 | record_file_dir = os.path.join("recordings", args.exp_name, os.path.basename(__file__).replace(".py", ""), f"{args.domain_name}-{args.task_name}")
488 | os.makedirs(record_file_dir, exist_ok=True)
489 | record_file_fn = f"{args.domain_name}-{args.task_name}_seed{args.seed}_step{global_transitions:07d}_eval{eval_ep:02d}_record.pt"
490 | eval_env.envs[0].save_record_to_file(os.path.join(record_file_dir, record_file_fn))
491 | if global_transitions >= args.total_timesteps and eval_ep == 0:
492 | model_file_dir = os.path.join("trained_models", args.exp_name, os.path.basename(__file__).replace(".py", ""), f"{args.domain_name}-{args.task_name}")
493 | os.makedirs(model_file_dir, exist_ok=True)
494 | model_fn = f"{args.domain_name}-{args.task_name}_seed{args.seed}_model.pt"
495 | torch.save({"sfn": None, "agent": drq_agent.save_agent()}, os.path.join(model_file_dir, model_fn))
496 |
497 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions)
498 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions)
499 | # writer.add_scalar("charts/eval_episodic_length", np.mean(), global_transitions)
500 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]")
501 |
502 | drq_agent.train()
503 |
504 | envs.close()
505 | writer.close()
--------------------------------------------------------------------------------
/agent/dqn_atari_sugarl.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import random
5 | import time
6 | from itertools import product
7 | from distutils.util import strtobool
8 |
9 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__))))
10 | os.environ["OMP_NUM_THREADS"] = "1"
11 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
12 | import warnings
13 | warnings.filterwarnings("ignore", category=UserWarning)
14 |
15 | import gymnasium as gym
16 | from gymnasium.spaces import Discrete, Dict
17 | import numpy as np
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | import torch.optim as optim
22 | from torchvision.transforms import Resize
23 |
24 | from common.buffer import DoubleActionReplayBuffer
25 | from common.pvm_buffer import PVMBuffer
26 | from common.utils import get_timestr, seed_everything, get_sugarl_reward_scale_atari
27 | from torch.utils.tensorboard import SummaryWriter
28 |
29 | from active_gym.atari_env import AtariFixedFovealEnv, AtariEnvArgs
30 |
31 |
32 | def parse_args():
33 | # fmt: off
34 | parser = argparse.ArgumentParser()
35 |     parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).replace(".py", ""),  # replace() drops the suffix exactly; rstrip(".py") strips characters, not a suffix
36 | help="the name of this experiment")
37 | parser.add_argument("--seed", type=int, default=1,
38 | help="seed of the experiment")
39 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
40 | help="if toggled, cuda will be enabled by default")
41 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
42 | help="whether to capture videos of the agent performances (check out `videos` folder)")
43 |
44 | # env setting
45 | parser.add_argument("--env", type=str, default="breakout",
46 | help="the id of the environment")
47 | parser.add_argument("--env-num", type=int, default=1,
48 | help="# envs in parallel")
49 | parser.add_argument("--frame-stack", type=int, default=4,
50 | help="frame stack #")
51 | parser.add_argument("--action-repeat", type=int, default=4,
52 | help="action repeat #")
53 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True)
54 |
55 | # fov setting
56 | parser.add_argument("--fov-size", type=int, default=50)
57 | parser.add_argument("--fov-init-loc", type=int, default=0)
58 | parser.add_argument("--sensory-action-mode", type=str, default="absolute")
59 | parser.add_argument("--sensory-action-space", type=int, default=10) # ignored when sensory_action_mode="relative"
60 | parser.add_argument("--resize-to-full", default=False, action="store_true")
61 |     # for the discrete sensory (gaze) action grid
62 | parser.add_argument("--sensory-action-x-size", type=int, default=4)
63 | parser.add_argument("--sensory-action-y-size", type=int, default=4)
64 | # pvm setting
65 | parser.add_argument("--pvm-stack", type=int, default=3)
66 |
67 | # Algorithm specific arguments
68 | parser.add_argument("--total-timesteps", type=int, default=3000000,
69 | help="total timesteps of the experiments")
70 | parser.add_argument("--learning-rate", type=float, default=1e-4,
71 | help="the learning rate of the optimizer")
72 | parser.add_argument("--buffer-size", type=int, default=500000,
73 | help="the replay memory buffer size")
74 | parser.add_argument("--gamma", type=float, default=0.99,
75 | help="the discount factor gamma")
76 | parser.add_argument("--target-network-frequency", type=int, default=1000,
77 | help="the timesteps it takes to update the target network")
78 | parser.add_argument("--batch-size", type=int, default=32,
79 |         help="the batch size of samples from the replay memory")
80 | parser.add_argument("--start-e", type=float, default=1,
81 | help="the starting epsilon for exploration")
82 | parser.add_argument("--end-e", type=float, default=0.01,
83 | help="the ending epsilon for exploration")
84 | parser.add_argument("--exploration-fraction", type=float, default=0.10,
85 |         help="the fraction of `total-timesteps` it takes to go from start-e to end-e")
86 | parser.add_argument("--learning-starts", type=int, default=80000,
87 | help="timestep to start learning")
88 | parser.add_argument("--train-frequency", type=int, default=4,
89 | help="the frequency of training")
90 |
91 | # eval args
92 | parser.add_argument("--eval-frequency", type=int, default=-1,
93 | help="eval frequency. default -1 is eval at the end.")
94 | parser.add_argument("--eval-num", type=int, default=10,
95 |         help="the number of evaluation episodes to run")
96 | args = parser.parse_args()
97 | # fmt: on
98 | return args
99 |
100 |
101 | def make_env(env_name, seed, **kwargs):
102 | def thunk():
103 | env_args = AtariEnvArgs(
104 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs
105 | )
106 | env = AtariFixedFovealEnv(env_args)
107 | env.action_space.seed(seed)
108 | env.observation_space.seed(seed)
109 | return env
110 |
111 | return thunk
112 |
113 |
114 | # ALGO LOGIC: initialize agent here:
115 | class QNetwork(nn.Module):
116 | def __init__(self, env, sensory_action_set=None):
117 | super().__init__()
118 | if isinstance(env.single_action_space, Discrete):
119 | motor_action_space_size = env.single_action_space.n
120 | sensory_action_space_size = None
121 | elif isinstance(env.single_action_space, Dict):
122 | motor_action_space_size = env.single_action_space["motor_action"].n
123 | if sensory_action_set is not None:
124 | sensory_action_space_size = len(sensory_action_set)
125 | else:
126 | sensory_action_space_size = env.single_action_space["sensory_action"].n
127 | self.backbone = nn.Sequential(
128 | nn.Conv2d(4, 32, 8, stride=4),
129 | nn.ReLU(),
130 | nn.Conv2d(32, 64, 4, stride=2),
131 | nn.ReLU(),
132 | nn.Conv2d(64, 64, 3, stride=1),
133 | nn.ReLU(),
134 | nn.Flatten(),
135 | nn.Linear(3136, 512),
136 | nn.ReLU(),
137 | )
138 |
139 | self.motor_action_head = nn.Linear(512, motor_action_space_size)
140 | self.sensory_action_head = None
141 | if sensory_action_space_size is not None:
142 | self.sensory_action_head = nn.Linear(512, sensory_action_space_size)
143 |
144 |
145 | def forward(self, x):
146 | x = self.backbone(x)
147 | motor_action = self.motor_action_head(x)
148 | sensory_action = None
149 | if self.sensory_action_head:
150 | sensory_action = self.sensory_action_head(x)
151 | return motor_action, sensory_action
152 |
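    | # SelfPredictionNetwork (SFN): an inverse-dynamics-style classifier. It takes the current and
    | # next PVM observations concatenated along the channel axis (4 + 4 = 8 channels) and predicts
    | # which motor action was taken, trained with cross-entropy against the stored motor actions.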
153 | class SelfPredictionNetwork(nn.Module):
154 | def __init__(self, env, sensory_action_set=None):
155 | super().__init__()
156 | if isinstance(env.single_action_space, Discrete):
157 | motor_action_space_size = env.single_action_space.n
158 | sensory_action_space_size = None
159 | elif isinstance(env.single_action_space, Dict):
160 | motor_action_space_size = env.single_action_space["motor_action"].n
161 | if sensory_action_set is not None:
162 | sensory_action_space_size = len(sensory_action_set)
163 | else:
164 | sensory_action_space_size = env.single_action_space["sensory_action"].n
165 |
166 | self.backbone = nn.Sequential(
167 | nn.Conv2d(8, 32, 8, stride=4),
168 | nn.ReLU(),
169 | nn.Conv2d(32, 64, 4, stride=2),
170 | nn.ReLU(),
171 | nn.Conv2d(64, 64, 3, stride=1),
172 | nn.ReLU(),
173 | nn.Flatten(),
174 | nn.Linear(3136, 512),
175 | nn.ReLU(),
176 | )
177 |
178 | self.head = nn.Sequential(
179 | nn.Linear(512, motor_action_space_size),
180 | )
181 |
182 | self.loss = nn.CrossEntropyLoss()
183 |
184 | def get_loss(self, x, target) -> torch.Tensor:
185 | return self.loss(x, target)
186 |
187 |
188 | def forward(self, x):
189 | x = self.backbone(x)
190 | x = self.head(x)
191 | return x
192 |
193 |
194 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
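    |     # linearly anneal epsilon from start_e to end_e over `duration` steps, then hold at end_e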
195 | slope = (end_e - start_e) / duration
196 | return max(slope * t + start_e, end_e)
197 |
198 |
199 | if __name__ == "__main__":
200 | args = parse_args()
201 | args.env = args.env.lower()
202 | run_name = f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}"
203 | run_dir = os.path.join("runs", args.exp_name)
204 | if not os.path.exists(run_dir):
205 | os.makedirs(run_dir, exist_ok=True)
206 |
207 | writer = SummaryWriter(os.path.join(run_dir, run_name))
208 | writer.add_text(
209 | "hyperparameters",
210 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
211 | )
212 |
213 | # TRY NOT TO MODIFY: seeding
214 | seed_everything(args.seed)
215 |
216 |
217 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
218 |
219 | # env setup
220 | envs = []
221 | for i in range(args.env_num):
222 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat,
223 | fov_size=(args.fov_size, args.fov_size),
224 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc),
225 | sensory_action_mode=args.sensory_action_mode,
226 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space),
227 | resize_to_full=args.resize_to_full,
228 | clip_reward=args.clip_reward,
229 | mask_out=True))
230 | # envs = gym.vector.AsyncVectorEnv(envs)
231 | envs = gym.vector.SyncVectorEnv(envs)
232 |
233 | sugarl_r_scale = get_sugarl_reward_scale_atari(args.env)
234 |
235 | resize = Resize((84, 84))
236 |
237 | # get a discrete observ action space
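    |     # valid fovea locations (top-left corners) range over [0, 84 - fov_size] in each axis; a
    |     # sensory_action_x_size x sensory_action_y_size grid of them forms the discrete sensory action set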
238 | OBSERVATION_SIZE = (84, 84)
239 | observ_x_max, observ_y_max = OBSERVATION_SIZE[0]-args.fov_size, OBSERVATION_SIZE[1]-args.fov_size
240 | sensory_action_step = (observ_x_max//args.sensory_action_x_size,
241 | observ_y_max//args.sensory_action_y_size)
242 | sensory_action_x_set = list(range(0, observ_x_max, sensory_action_step[0]))[:args.sensory_action_x_size]
243 | sensory_action_y_set = list(range(0, observ_y_max, sensory_action_step[1]))[:args.sensory_action_y_size]
244 | sensory_action_set = [np.array(a) for a in list(product(sensory_action_x_set, sensory_action_y_set))]
245 |
246 | q_network = QNetwork(envs, sensory_action_set=sensory_action_set).to(device)
247 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
248 | target_network = QNetwork(envs, sensory_action_set=sensory_action_set).to(device)
249 | target_network.load_state_dict(q_network.state_dict())
250 |
251 | sfn = SelfPredictionNetwork(envs, sensory_action_set=sensory_action_set).to(device)
252 | sfn_optimizer = optim.Adam(sfn.parameters(), lr=args.learning_rate)
253 |
254 | rb = DoubleActionReplayBuffer(
255 | args.buffer_size,
256 | envs.single_observation_space,
257 | envs.single_action_space["motor_action"],
258 | Discrete(len(sensory_action_set)),
259 | device,
260 | n_envs=envs.num_envs,
261 | optimize_memory_usage=True,
262 | handle_timeout_termination=False,
263 | )
264 | start_time = time.time()
265 |
266 | # TRY NOT TO MODIFY: start the game
267 | obs, infos = envs.reset()
268 | global_transitions = 0
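    |     # persistence-of-vision memory: keeps the last pvm_stack observations and fuses them into a
    |     # single agent observation via get_obs(mode="stack_max") (presumably a pixel-wise max over the stack)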
269 | pvm_buffer = PVMBuffer(args.pvm_stack, (envs.num_envs, args.frame_stack,)+OBSERVATION_SIZE)
270 |
271 | while global_transitions < args.total_timesteps:
272 | pvm_buffer.append(obs)
273 | pvm_obs = pvm_buffer.get_obs(mode="stack_max")
274 | # ALGO LOGIC: put action logic here
275 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_transitions)
276 | if random.random() < epsilon:
277 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
278 | motor_actions = np.array([actions[0]["motor_action"]])
279 | sensory_actions = np.array([random.randint(0, len(sensory_action_set)-1)])
280 | else:
281 | motor_q_values, sensory_q_values = q_network(resize(torch.from_numpy(pvm_obs)).to(device))
282 | motor_actions = torch.argmax(motor_q_values, dim=1).cpu().numpy()
283 | sensory_actions = torch.argmax(sensory_q_values, dim=1).cpu().numpy()
284 |
285 | # TRY NOT TO MODIFY: execute the game and log data.
286 | next_obs, rewards, dones, _, infos = envs.step({"motor_action": motor_actions,
287 | "sensory_action": [sensory_action_set[a] for a in sensory_actions] })
288 | # print (global_step, infos)
289 |
290 | # TRY NOT TO MODIFY: record rewards for plotting purposes
291 | if "final_info" in infos:
292 | for idx, d in enumerate(dones):
293 | if d:
294 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]")
295 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions)
296 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions)
297 | writer.add_scalar("charts/epsilon", epsilon, global_transitions)
298 | break
299 |
300 |         # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
301 | real_next_obs = next_obs
302 | for idx, d in enumerate(dones):
303 | if d:
304 | real_next_obs[idx] = infos["final_observation"][idx]
305 | pvm_buffer_copy = pvm_buffer.copy()
306 | pvm_buffer_copy.append(real_next_obs)
307 | real_next_pvm_obs = pvm_buffer_copy.get_obs(mode="stack_max")
308 | rb.add(pvm_obs, real_next_pvm_obs, motor_actions, sensory_actions, rewards, dones, {})
309 |
310 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
311 | obs = next_obs
312 |
313 | # INC total transitions
314 | global_transitions += args.env_num
315 |
316 |         obs_backup = obs  # back up obs so it can be restored after evaluation
317 |
318 | if global_transitions < args.batch_size:
319 | continue
320 |
321 | # Training
322 | if global_transitions % args.train_frequency == 0:
323 | data = rb.sample(args.batch_size // args.env_num) # counter-balance the true global transitions used for training
324 |
325 | # sfn learning
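    |             # train the SFN to recover the taken motor action from (next_obs, obs); its softmax
    |             # probability of the true action (observ_r below) is the observability signal that
    |             # shapes the SUGARL reward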
326 | concat_observation = torch.concat([data.next_observations, data.observations], dim=1) # concat at dimension T
327 | pred_motor_actions = sfn(resize(concat_observation))
328 | # print (pred_motor_actions.size(), data.motor_actions.size())
329 | sfn_loss = sfn.get_loss(pred_motor_actions, data.motor_actions.flatten())
330 | sfn_optimizer.zero_grad()
331 | sfn_loss.backward()
332 | sfn_optimizer.step()
333 |             observ_r = F.softmax(pred_motor_actions, dim=1).gather(1, data.motor_actions).squeeze().detach() # in [0, 1]
334 |
335 | # Q network learning
336 | if global_transitions > args.learning_starts:
337 | with torch.no_grad():
338 | motor_target, sensory_target = target_network(resize(data.next_observations))
339 | motor_target_max, _ = motor_target.max(dim=1)
340 | sensory_target_max, _ = sensory_target.max(dim=1)
341 | # scale step-wise reward with observ_r
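    |                     # SUGARL target: subtract (1 - observ_r) * sugarl_r_scale from the environment reward so
    |                     # poorly-observed transitions contribute less; original_td_target (without the penalty)
    |                     # is logged for comparison only (observ_r_adjusted below is computed but not used in td_target)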
342 | observ_r_adjusted = observ_r.clone()
343 | observ_r_adjusted[data.rewards.flatten() > 0] = 1 - observ_r_adjusted[data.rewards.flatten() > 0]
344 | td_target = data.rewards.flatten() - (1 - observ_r) * sugarl_r_scale + args.gamma * (motor_target_max+sensory_target_max) * (1 - data.dones.flatten())
345 | original_td_target = data.rewards.flatten() + args.gamma * (motor_target_max+sensory_target_max) * (1 - data.dones.flatten())
346 |
347 | old_motor_q_val, old_sensory_q_val = q_network(resize(data.observations))
348 | old_motor_val = old_motor_q_val.gather(1, data.motor_actions).squeeze()
349 | old_sensory_val = old_sensory_q_val.gather(1, data.sensory_actions).squeeze()
350 | old_val = old_motor_val + old_sensory_val
351 |
352 | loss = F.mse_loss(td_target, old_val)
353 |
354 | if global_transitions % 100 == 0:
355 | writer.add_scalar("losses/td_loss", loss, global_transitions)
356 | writer.add_scalar("losses/q_values", old_val.mean().item(), global_transitions)
357 | writer.add_scalar("losses/motor_q_values", old_motor_val.mean().item(), global_transitions)
358 | writer.add_scalar("losses/sensor_q_values", old_sensory_val.mean().item(), global_transitions)
359 | # print("SPS:", int(global_step / (time.time() - start_time)))
360 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions)
361 |
362 | writer.add_scalar("losses/sfn_loss", sfn_loss.item(), global_transitions)
363 | writer.add_scalar("losses/observ_r", observ_r.mean().item(), global_transitions)
364 | writer.add_scalar("losses/original_td_target", original_td_target.mean().item(), global_transitions)
365 | writer.add_scalar("losses/sugarl_r_scaled_td_target", td_target.mean().item(), global_transitions)
366 |
367 | # optimize the model
368 | optimizer.zero_grad()
369 | loss.backward()
370 | optimizer.step()
371 |
372 | # update the target network
373 | if (global_transitions // args.env_num) % args.target_network_frequency == 0:
374 | target_network.load_state_dict(q_network.state_dict())
375 |
376 | # evaluation
377 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \
378 | (global_transitions >= args.total_timesteps):
379 | q_network.eval()
380 | sfn.eval()
381 |
382 | eval_episodic_returns, eval_episodic_lengths = [], []
383 |
384 | for eval_ep in range(args.eval_num):
385 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat,
386 | fov_size=(args.fov_size, args.fov_size),
387 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc),
388 | sensory_action_mode=args.sensory_action_mode,
389 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space),
390 | resize_to_full=args.resize_to_full,
391 | clip_reward=args.clip_reward,
392 | mask_out=True,
393 | training=False,
394 | record=args.capture_video)]
395 | eval_env = gym.vector.SyncVectorEnv(eval_env)
396 | obs_eval, _ = eval_env.reset()
397 | done = False
398 | pvm_buffer_eval = PVMBuffer(args.pvm_stack, (eval_env.num_envs, args.frame_stack,)+OBSERVATION_SIZE)
399 | while not done:
400 | pvm_buffer_eval.append(obs_eval)
401 | pvm_obs_eval = pvm_buffer_eval.get_obs(mode="stack_max")
402 | motor_q_values, sensory_q_values = q_network(resize(torch.from_numpy(pvm_obs_eval)).to(device))
403 | motor_actions = torch.argmax(motor_q_values, dim=1).cpu().numpy()
404 | sensory_actions = torch.argmax(sensory_q_values, dim=1).cpu().numpy()
405 | next_obs_eval, rewards, dones, _, infos = eval_env.step({"motor_action": motor_actions,
406 | "sensory_action": [sensory_action_set[a] for a in sensory_actions]})
407 | obs_eval = next_obs_eval
408 | done = dones[0]
409 | if done:
410 | eval_episodic_returns.append(infos['final_info'][0]["reward"])
411 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"])
412 | if args.capture_video:
413 | record_file_dir = os.path.join("recordings", args.exp_name, os.path.basename(__file__).rstrip(".py"), args.env)
414 | os.makedirs(record_file_dir, exist_ok=True)
415 | record_file_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_eval{eval_ep:02d}_record.pt"
416 | eval_env.envs[0].save_record_to_file(os.path.join(record_file_dir, record_file_fn))
417 | if eval_ep == 0:
418 | model_file_dir = os.path.join("trained_models", args.exp_name, os.path.basename(__file__).rstrip(".py"), args.env)
419 | os.makedirs(model_file_dir, exist_ok=True)
420 | model_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_model.pt"
421 | torch.save({"sfn": sfn.state_dict(), "q": q_network.state_dict()}, os.path.join(model_file_dir, model_fn))
422 |
423 |
424 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions)
425 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions)
426 | # writer.add_scalar("charts/eval_episodic_length", np.mean(), global_transitions)
427 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]")
428 |
429 | q_network.train()
430 | sfn.train()
431 |
432 | obs = obs_backup # restore obs if eval occurs
433 |
434 |
435 |
436 | envs.close()
437 | eval_env.close()
438 | writer.close()
--------------------------------------------------------------------------------
/agent/dqn_atari_wp_sugarl.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | import os.path as osp
4 | import random
5 | import time
6 | from itertools import product
7 | from distutils.util import strtobool
8 |
9 | sys.path.append(osp.dirname(osp.dirname(osp.realpath(__file__))))
10 | os.environ["OMP_NUM_THREADS"] = "1"
11 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
12 | import warnings
13 | warnings.filterwarnings("ignore", category=UserWarning)
14 |
15 | import gymnasium as gym
16 | from gymnasium.spaces import Discrete, Dict
17 | import numpy as np
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | import torch.optim as optim
22 | from torchvision.transforms import Resize
23 |
24 | from common.buffer import DoubleActionReplayBuffer
25 | from common.pvm_buffer import PVMBuffer
26 | from common.utils import get_timestr, seed_everything, get_sugarl_reward_scale_atari
27 | from torch.utils.tensorboard import SummaryWriter
28 |
29 | from active_gym import AtariFixedFovealPeripheralEnv, AtariEnvArgs
30 |
31 |
32 | def parse_args():
33 | # fmt: off
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
36 | help="the name of this experiment")
37 | parser.add_argument("--seed", type=int, default=1,
38 | help="seed of the experiment")
39 | parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
40 | help="if toggled, cuda will be enabled by default")
41 | parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
42 | help="whether to capture videos of the agent performances (check out `videos` folder)")
43 |
44 | # env setting
45 | parser.add_argument("--env", type=str, default="breakout",
46 | help="the id of the environment")
47 | parser.add_argument("--env-num", type=int, default=1,
48 | help="# envs in parallel")
49 |     parser.add_argument("--frame-stack", type=int, default=4,
50 |         help="the number of frames to stack")
51 |     parser.add_argument("--action-repeat", type=int, default=4,
52 |         help="the number of times each action is repeated")
53 | parser.add_argument("--clip-reward", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True)
54 |
55 | # fov setting
56 | parser.add_argument("--fov-size", type=int, default=50)
57 | parser.add_argument("--fov-init-loc", type=int, default=0)
58 | parser.add_argument("--sensory-action-mode", type=str, default="absolute")
59 | parser.add_argument("--sensory-action-space", type=int, default=10) # ignored when sensory_action_mode="relative"
60 | parser.add_argument("--resize-to-full", default=False, action="store_true")
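    |     # resolution of the low-resolution peripheral view used by AtariFixedFovealPeripheralEnv
    |     # (presumably the full frame downsampled to peripheral_res x peripheral_res around the sharp fovea)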
61 | parser.add_argument("--peripheral-res", type=int, default=20)
62 | # for discrete observ action
63 | parser.add_argument("--sensory-action-x-size", type=int, default=4)
64 | parser.add_argument("--sensory-action-y-size", type=int, default=4)
65 | # pvm setting
66 | parser.add_argument("--pvm-stack", type=int, default=3)
67 |
68 |
69 | # Algorithm specific arguments
70 | parser.add_argument("--total-timesteps", type=int, default=3000000,
71 | help="total timesteps of the experiments")
72 | parser.add_argument("--learning-rate", type=float, default=1e-4,
73 | help="the learning rate of the optimizer")
74 | parser.add_argument("--buffer-size", type=int, default=500000,
75 | help="the replay memory buffer size")
76 | parser.add_argument("--gamma", type=float, default=0.99,
77 | help="the discount factor gamma")
78 | parser.add_argument("--target-network-frequency", type=int, default=1000,
79 | help="the timesteps it takes to update the target network")
80 | parser.add_argument("--batch-size", type=int, default=32,
81 |         help="the batch size of samples drawn from the replay memory")
82 | parser.add_argument("--start-e", type=float, default=1,
83 | help="the starting epsilon for exploration")
84 | parser.add_argument("--end-e", type=float, default=0.01,
85 | help="the ending epsilon for exploration")
86 | parser.add_argument("--exploration-fraction", type=float, default=0.10,
87 |         help="the fraction of `total-timesteps` it takes to go from start-e to end-e")
88 | parser.add_argument("--learning-starts", type=int, default=80000,
89 | help="timestep to start learning")
90 | parser.add_argument("--train-frequency", type=int, default=4,
91 | help="the frequency of training")
92 |
93 | # eval args
94 |     parser.add_argument("--eval-frequency", type=int, default=-1,
95 |         help="evaluation frequency; the default -1 evaluates only at the end of training")
96 |     parser.add_argument("--eval-num", type=int, default=10,
97 |         help="the number of evaluation episodes")
98 | args = parser.parse_args()
99 | # fmt: on
100 | return args
101 |
102 |
103 | def make_env(env_name, seed, **kwargs):
104 | def thunk():
105 | env_args = AtariEnvArgs(
106 | game=env_name, seed=seed, obs_size=(84, 84), **kwargs
107 | )
108 | env = AtariFixedFovealPeripheralEnv(env_args)
109 | env.action_space.seed(seed)
110 | env.observation_space.seed(seed)
111 | return env
112 |
113 | return thunk
114 |
115 |
116 | # ALGO LOGIC: initialize agent here:
117 | class QNetwork(nn.Module):
118 | def __init__(self, env, sensory_action_set=None):
119 | super().__init__()
120 | if isinstance(env.single_action_space, Discrete):
121 | motor_action_space_size = env.single_action_space.n
122 | sensory_action_space_size = None
123 | elif isinstance(env.single_action_space, Dict):
124 | motor_action_space_size = env.single_action_space["motor_action"].n
125 | if sensory_action_set is not None:
126 | sensory_action_space_size = len(sensory_action_set)
127 | else:
128 | sensory_action_space_size = env.single_action_space["sensory_action"].n
129 | self.backbone = nn.Sequential(
130 | nn.Conv2d(4, 32, 8, stride=4),
131 | nn.ReLU(),
132 | nn.Conv2d(32, 64, 4, stride=2),
133 | nn.ReLU(),
134 | nn.Conv2d(64, 64, 3, stride=1),
135 | nn.ReLU(),
136 | nn.Flatten(),
137 | nn.Linear(3136, 512),
138 | nn.ReLU(),
139 | )
140 |
141 | self.motor_action_head = nn.Linear(512, motor_action_space_size)
142 | self.sensory_action_head = None
143 | if sensory_action_space_size is not None:
144 | self.sensory_action_head = nn.Linear(512, sensory_action_space_size)
145 |
146 |
147 | def forward(self, x):
148 | x = self.backbone(x)
149 | motor_action = self.motor_action_head(x)
150 | sensory_action = None
151 | if self.sensory_action_head:
152 | sensory_action = self.sensory_action_head(x)
153 | return motor_action, sensory_action
154 |
155 | class SelfPredictionNetwork(nn.Module):
156 | def __init__(self, env, sensory_action_set=None):
157 | super().__init__()
158 | if isinstance(env.single_action_space, Discrete):
159 | motor_action_space_size = env.single_action_space.n
160 | sensory_action_space_size = None
161 | elif isinstance(env.single_action_space, Dict):
162 | motor_action_space_size = env.single_action_space["motor_action"].n
163 | if sensory_action_set is not None:
164 | sensory_action_space_size = len(sensory_action_set)
165 | else:
166 | sensory_action_space_size = env.single_action_space["sensory_action"].n
167 |
168 | self.backbone = nn.Sequential(
169 | nn.Conv2d(8, 32, 8, stride=4),
170 | nn.ReLU(),
171 | nn.Conv2d(32, 64, 4, stride=2),
172 | nn.ReLU(),
173 | nn.Conv2d(64, 64, 3, stride=1),
174 | nn.ReLU(),
175 | nn.Flatten(),
176 | nn.Linear(3136, 512),
177 | nn.ReLU(),
178 | )
179 |
180 | self.head = nn.Sequential(
181 | nn.Linear(512, motor_action_space_size),
182 | )
183 |
184 | self.loss = nn.CrossEntropyLoss()
185 |
186 | def get_loss(self, x, target) -> torch.Tensor:
187 | return self.loss(x, target)
188 |
189 |
190 | def forward(self, x):
191 | x = self.backbone(x)
192 | x = self.head(x)
193 | return x
194 |
195 |
196 | def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
197 | slope = (end_e - start_e) / duration
198 | return max(slope * t + start_e, end_e)
199 |
200 |
201 | if __name__ == "__main__":
202 | args = parse_args()
203 | args.env = args.env.lower()
204 | run_name = f"{args.env}__{os.path.basename(__file__)}__{args.seed}__{get_timestr()}"
205 | run_dir = os.path.join("runs", args.exp_name)
206 | if not os.path.exists(run_dir):
207 | os.makedirs(run_dir, exist_ok=True)
208 |
209 | writer = SummaryWriter(os.path.join(run_dir, run_name))
210 | writer.add_text(
211 | "hyperparameters",
212 | "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
213 | )
214 |
215 | # TRY NOT TO MODIFY: seeding
216 | seed_everything(args.seed)
217 |
218 |
219 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
220 |
221 | # env setup
222 | envs = []
223 | for i in range(args.env_num):
224 | envs.append(make_env(args.env, args.seed+i, frame_stack=args.frame_stack, action_repeat=args.action_repeat,
225 | fov_size=(args.fov_size, args.fov_size),
226 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc),
227 | peripheral_res=(args.peripheral_res, args.peripheral_res),
228 | sensory_action_mode=args.sensory_action_mode,
229 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space),
230 | resize_to_full=args.resize_to_full,
231 | clip_reward=args.clip_reward,
232 | mask_out=True))
233 | # envs = gym.vector.AsyncVectorEnv(envs)
234 | envs = gym.vector.SyncVectorEnv(envs)
235 |
236 | sugarl_r_scale = get_sugarl_reward_scale_atari(args.env)
237 |
238 | resize = Resize((84, 84))
239 |
240 | # get a discrete observ action space
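    |     # as in the foveal-only script: a sensory_action_x_size x sensory_action_y_size grid of
    |     # fovea locations over [0, 84 - fov_size] forms the discrete sensory action set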
241 | OBSERVATION_SIZE = (84, 84)
242 | observ_x_max, observ_y_max = OBSERVATION_SIZE[0]-args.fov_size, OBSERVATION_SIZE[1]-args.fov_size
243 | sensory_action_step = (observ_x_max//args.sensory_action_x_size,
244 | observ_y_max//args.sensory_action_y_size)
245 | sensory_action_x_set = list(range(0, observ_x_max, sensory_action_step[0]))[:args.sensory_action_x_size]
246 | sensory_action_y_set = list(range(0, observ_y_max, sensory_action_step[1]))[:args.sensory_action_y_size]
247 | sensory_action_set = [np.array(a) for a in list(product(sensory_action_x_set, sensory_action_y_set))]
248 |
249 | q_network = QNetwork(envs, sensory_action_set=sensory_action_set).to(device)
250 | optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
251 | target_network = QNetwork(envs, sensory_action_set=sensory_action_set).to(device)
252 | target_network.load_state_dict(q_network.state_dict())
253 |
254 | sfn = SelfPredictionNetwork(envs, sensory_action_set=sensory_action_set).to(device)
255 | sfn_optimizer = optim.Adam(sfn.parameters(), lr=args.learning_rate)
256 |
257 | rb = DoubleActionReplayBuffer(
258 | args.buffer_size,
259 | envs.single_observation_space,
260 | envs.single_action_space["motor_action"],
261 | Discrete(len(sensory_action_set)),
262 | device,
263 | n_envs=envs.num_envs,
264 | optimize_memory_usage=True,
265 | handle_timeout_termination=False,
266 | )
267 | start_time = time.time()
268 |
269 | # TRY NOT TO MODIFY: start the game
270 | obs, infos = envs.reset()
271 | global_transitions = 0
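    |     # persistence-of-vision memory over the last pvm_stack observations, fused via get_obs(mode="stack_max")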
272 | pvm_buffer = PVMBuffer(args.pvm_stack, (envs.num_envs, args.frame_stack,)+OBSERVATION_SIZE)
273 |
274 | while global_transitions < args.total_timesteps:
275 | pvm_buffer.append(obs)
276 | pvm_obs = pvm_buffer.get_obs(mode="stack_max")
277 | # ALGO LOGIC: put action logic here
278 | epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_transitions)
279 | if random.random() < epsilon:
280 | actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
281 | motor_actions = np.array([actions[0]["motor_action"]])
282 | sensory_actions = np.array([random.randint(0, len(sensory_action_set)-1)])
283 | else:
284 | motor_q_values, sensory_q_values = q_network(resize(torch.from_numpy(pvm_obs)).to(device))
285 | motor_actions = torch.argmax(motor_q_values, dim=1).cpu().numpy()
286 | sensory_actions = torch.argmax(sensory_q_values, dim=1).cpu().numpy()
287 |
288 | # TRY NOT TO MODIFY: execute the game and log data.
289 | next_obs, rewards, dones, _, infos = envs.step({"motor_action": motor_actions,
290 | "sensory_action": [sensory_action_set[a] for a in sensory_actions] })
291 | # print (global_step, infos)
292 |
293 | # TRY NOT TO MODIFY: record rewards for plotting purposes
294 | if "final_info" in infos:
295 | for idx, d in enumerate(dones):
296 | if d:
297 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [R: {infos['final_info'][idx]['reward']:.2f}]")
298 | writer.add_scalar("charts/episodic_return", infos['final_info'][idx]["reward"], global_transitions)
299 | writer.add_scalar("charts/episodic_length", infos['final_info'][idx]["ep_len"], global_transitions)
300 | writer.add_scalar("charts/epsilon", epsilon, global_transitions)
301 | break
302 |
303 |         # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
304 | real_next_obs = next_obs
305 | for idx, d in enumerate(dones):
306 | if d:
307 | real_next_obs[idx] = infos["final_observation"][idx]
308 | pvm_buffer_copy = pvm_buffer.copy()
309 | pvm_buffer_copy.append(real_next_obs)
310 | real_next_pvm_obs = pvm_buffer_copy.get_obs(mode="stack_max")
311 | rb.add(pvm_obs, real_next_pvm_obs, motor_actions, sensory_actions, rewards, dones, {})
312 |
313 | # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
314 | obs = next_obs
315 |
316 | # INC total transitions
317 | global_transitions += args.env_num
318 |
319 |         obs_backup = obs  # back up obs so it can be restored after evaluation
320 |
321 | if global_transitions < args.batch_size:
322 | continue
323 |
324 | # Training
325 | if global_transitions % args.train_frequency == 0:
326 | data = rb.sample(args.batch_size // args.env_num) # counter-balance the true global transitions used for training
327 |
328 | # sfn learning
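    |             # the SFN predicts the taken motor action from (next_obs, obs); its softmax probability
    |             # of the true action becomes observ_r, the observability signal for the SUGARL reward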
329 | concat_observation = torch.concat([data.next_observations, data.observations], dim=1) # concat at dimension T
330 | pred_motor_actions = sfn(resize(concat_observation))
331 | # print (pred_motor_actions.size(), data.motor_actions.size())
332 | sfn_loss = sfn.get_loss(pred_motor_actions, data.motor_actions.flatten())
333 | sfn_optimizer.zero_grad()
334 | sfn_loss.backward()
335 | sfn_optimizer.step()
336 |             observ_r = F.softmax(pred_motor_actions, dim=1).gather(1, data.motor_actions).squeeze().detach() # in [0, 1]
337 |
338 | # Q network learning
339 | if global_transitions > args.learning_starts:
340 | with torch.no_grad():
341 | motor_target, sensory_target = target_network(resize(data.next_observations))
342 | motor_target_max, _ = motor_target.max(dim=1)
343 | sensory_target_max, _ = sensory_target.max(dim=1)
344 | # scale step-wise reward with observ_r
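    |                     # SUGARL target: subtract (1 - observ_r) * sugarl_r_scale from the environment reward;
    |                     # original_td_target (without the penalty) is logged for comparison only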
345 | observ_r_adjusted = observ_r.clone()
346 | observ_r_adjusted[data.rewards.flatten() > 0] = 1 - observ_r_adjusted[data.rewards.flatten() > 0]
347 | td_target = data.rewards.flatten() - (1 - observ_r) * sugarl_r_scale + args.gamma * (motor_target_max+sensory_target_max) * (1 - data.dones.flatten())
348 | original_td_target = data.rewards.flatten() + args.gamma * (motor_target_max+sensory_target_max) * (1 - data.dones.flatten())
349 |
350 | old_motor_q_value, old_sensory_q_val = q_network(resize(data.observations))
351 | old_motor_val = old_motor_q_value.gather(1, data.motor_actions).squeeze()
352 | old_sensory_val = old_sensory_q_val.gather(1, data.sensory_actions).squeeze()
353 | old_val = old_motor_val + old_sensory_val
354 |
355 | loss = F.mse_loss(td_target, old_val)
356 |
357 | if global_transitions % 100 == 0:
358 | writer.add_scalar("losses/td_loss", loss, global_transitions)
359 | writer.add_scalar("losses/q_values", old_val.mean().item(), global_transitions)
360 | writer.add_scalar("losses/motor_q_values", old_motor_val.mean().item(), global_transitions)
361 | writer.add_scalar("losses/action_q_values", old_sensory_val.mean().item(), global_transitions)
362 | # print("SPS:", int(global_step / (time.time() - start_time)))
363 | writer.add_scalar("charts/SPS", int(global_transitions / (time.time() - start_time)), global_transitions)
364 |
365 | writer.add_scalar("losses/sfn_loss", sfn_loss.item(), global_transitions)
366 | writer.add_scalar("losses/observ_r", observ_r.mean().item(), global_transitions)
367 | writer.add_scalar("losses/original_td_target", original_td_target.mean().item(), global_transitions)
368 | writer.add_scalar("losses/sugarl_r_scaled_td_target", td_target.mean().item(), global_transitions)
369 |
370 | # optimize the model
371 | optimizer.zero_grad()
372 | loss.backward()
373 | optimizer.step()
374 |
375 | # update the target network
376 | if (global_transitions // args.env_num) % args.target_network_frequency == 0:
377 | target_network.load_state_dict(q_network.state_dict())
378 |
379 | # evaluation
380 | if (global_transitions % args.eval_frequency == 0 and args.eval_frequency > 0) or \
381 | (global_transitions >= args.total_timesteps):
382 | q_network.eval()
383 | sfn.eval()
384 |
385 | eval_episodic_returns, eval_episodic_lengths = [], []
386 |
387 | for eval_ep in range(args.eval_num):
388 | eval_env = [make_env(args.env, args.seed+eval_ep, frame_stack=args.frame_stack, action_repeat=args.action_repeat,
389 | fov_size=(args.fov_size, args.fov_size),
390 | fov_init_loc=(args.fov_init_loc, args.fov_init_loc),
391 | peripheral_res=(args.peripheral_res, args.peripheral_res),
392 | sensory_action_mode=args.sensory_action_mode,
393 | sensory_action_space=(-args.sensory_action_space, args.sensory_action_space),
394 | resize_to_full=args.resize_to_full,
395 | clip_reward=args.clip_reward,
396 | mask_out=True,
397 | training=False,
398 | record=args.capture_video)]
399 | eval_env = gym.vector.SyncVectorEnv(eval_env)
400 | obs_eval, _ = eval_env.reset()
401 | done = False
402 | pvm_buffer_eval = PVMBuffer(args.pvm_stack, (eval_env.num_envs, args.frame_stack,)+OBSERVATION_SIZE)
403 | while not done:
404 | pvm_buffer_eval.append(obs_eval)
405 | pvm_obs_eval = pvm_buffer_eval.get_obs(mode="stack_max")
406 | motor_q_values, sensory_q_values = q_network(resize(torch.from_numpy(pvm_obs_eval)).to(device))
407 | motor_actions = torch.argmax(motor_q_values, dim=1).cpu().numpy()
408 | sensory_actions = torch.argmax(sensory_q_values, dim=1).cpu().numpy()
409 | next_obs_eval, rewards, dones, _, infos = eval_env.step({"motor_action": motor_actions,
410 | "sensory_action": [sensory_action_set[a] for a in sensory_actions]})
411 | obs_eval = next_obs_eval
412 | done = dones[0]
413 | if done:
414 | eval_episodic_returns.append(infos['final_info'][0]["reward"])
415 | eval_episodic_lengths.append(infos['final_info'][0]["ep_len"])
416 | if args.capture_video:
417 | record_file_dir = os.path.join("recordings", args.exp_name, os.path.basename(__file__).rstrip(".py"), args.env)
418 | os.makedirs(record_file_dir, exist_ok=True)
419 | record_file_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_eval{eval_ep:02d}_record.pt"
420 | eval_env.envs[0].save_record_to_file(os.path.join(record_file_dir, record_file_fn))
421 | if eval_ep == 0:
422 | model_file_dir = os.path.join("trained_models", args.exp_name, os.path.basename(__file__).rstrip(".py"), args.env)
423 | os.makedirs(model_file_dir, exist_ok=True)
424 | model_fn = f"{args.env}_seed{args.seed}_step{global_transitions:07d}_model.pt"
425 | torch.save({"sfn": sfn.state_dict(), "q": q_network.state_dict()}, os.path.join(model_file_dir, model_fn))
426 |
427 |
428 | writer.add_scalar("charts/eval_episodic_return", np.mean(eval_episodic_returns), global_transitions)
429 | writer.add_scalar("charts/eval_episodic_return_std", np.std(eval_episodic_returns), global_transitions)
430 | # writer.add_scalar("charts/eval_episodic_length", np.mean(), global_transitions)
431 | print(f"[T: {time.time()-start_time:.2f}] [N: {global_transitions:07,d}] [Eval R: {np.mean(eval_episodic_returns):.2f}+/-{np.std(eval_episodic_returns):.2f}] [R list: {','.join([str(r) for r in eval_episodic_returns])}]")
432 |
433 | q_network.train()
434 | sfn.train()
435 |
436 | obs = obs_backup # restore obs if eval occurs
437 |
438 |
439 |
440 | envs.close()
441 | eval_env.close()
442 | writer.close()
--------------------------------------------------------------------------------