├── .gitignore
├── README.md
├── configs
│   ├── HandModelTSLAdjusted.yaml
│   ├── Model-Based-RL.yml
│   ├── half_cheetah.yaml
│   ├── hopper.yaml
│   ├── inverted_double_pendulum.yaml
│   ├── inverted_pendulum.yaml
│   ├── leg.yaml
│   ├── swimmer.yaml
│   └── walker2d.yaml
├── main.py
├── model
│   ├── __init__.py
│   ├── archs
│   │   ├── __init__.py
│   │   └── basic.py
│   ├── blocks
│   │   ├── __init__.py
│   │   ├── mujoco.py
│   │   ├── policy
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── build.py
│   │   │   ├── deterministic.py
│   │   │   ├── dynamics.py
│   │   │   ├── stochastic.py
│   │   │   ├── strategies.py
│   │   │   ├── trajopt.py
│   │   │   └── wrappers
│   │   │       ├── __init__.py
│   │   │       └── zmus.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── build.py
│   │       └── functions.py
│   ├── build.py
│   ├── config
│   │   ├── __init__.py
│   │   └── defaults.py
│   ├── engine
│   │   ├── __init__.py
│   │   ├── dynamics_model_trainer.py
│   │   ├── landscape_plot.py
│   │   ├── tester.py
│   │   ├── trainer.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── build.py
│   │       └── experience_replay.py
│   └── layers
│       ├── __init__.py
│       └── feed_forward.py
├── mujoco
│   ├── __init__.py
│   ├── assets
│   │   ├── Geometry
│   │   │   ├── bofoot.stl
│   │   │   ├── calcn_r.stl
│   │   │   ├── capitate_rFB.stl
│   │   │   ├── clavicle_rFB.stl
│   │   │   ├── femur_r.stl
│   │   │   ├── fibula_r.stl
│   │   │   ├── foot.stl
│   │   │   ├── hamate_rFB.stl
│   │   │   ├── humerus_rFB.stl
│   │   │   ├── indexDistal_rFB.stl
│   │   │   ├── indexMetacarpal_rFB.stl
│   │   │   ├── indexMid_rFB.stl
│   │   │   ├── indexProx_rFB.stl
│   │   │   ├── l_pelvis.stl
│   │   │   ├── lunate_rFB.stl
│   │   │   ├── middleDistal_rFB.stl
│   │   │   ├── middleMetacarpal_rFB.stl
│   │   │   ├── middleMid_rFB.stl
│   │   │   ├── middleProx_rFB.stl
│   │   │   ├── pat.stl
│   │   │   ├── patella_r.stl
│   │   │   ├── pelvis.stl
│   │   │   ├── pinkyDistal_rFB.stl
│   │   │   ├── pinkyMetacarpal_rFB.stl
│   │   │   ├── pinkyMid_rFB.stl
│   │   │   ├── pinkyProx_rFB.stl
│   │   │   ├── pisiform_rFB.stl
│   │   │   ├── radius_rFB.stl
│   │   │   ├── ribcageFB.stl
│   │   │   ├── ringDistal_rFB.stl
│   │   │   ├── ringMetacarpal_rFB.stl
│   │   │   ├── ringMid_rFB.stl
│   │   │   ├── ringProx_rFB.stl
│   │   │   ├── sacrum.stl
│   │   │   ├── scaphoid_rFB.stl
│   │   │   ├── scapula_rFB.stl
│   │   │   ├── talus.stl
│   │   │   ├── talus_r.stl
│   │   │   ├── thoracic10FB.stl
│   │   │   ├── thoracic11FB.stl
│   │   │   ├── thoracic12FB.stl
│   │   │   ├── thoracic1FB.stl
│   │   │   ├── thoracic2FB.stl
│   │   │   ├── thoracic3FB.stl
│   │   │   ├── thoracic4FB.stl
│   │   │   ├── thoracic5FB.stl
│   │   │   ├── thoracic6FB.stl
│   │   │   ├── thoracic7FB.stl
│   │   │   ├── thoracic8FB.stl
│   │   │   ├── thoracic9FB.stl
│   │   │   ├── thumbDistal_rFB.stl
│   │   │   ├── thumbMid_rFB.stl
│   │   │   ├── thumbProx_rFB.stl
│   │   │   ├── tibia_r.stl
│   │   │   ├── toes_r.stl
│   │   │   ├── trapezium_rFB.stl
│   │   │   ├── trapezoid_rFB.stl
│   │   │   ├── triquetral_rFB.stl
│   │   │   └── ulna_rFB.stl
│   │   ├── HandModelTSLAdjusted_converted.xml
│   │   ├── half_cheetah.xml
│   │   ├── hopper.xml
│   │   ├── inverted_double_pendulum.xml
│   │   ├── inverted_pendulum.xml
│   │   ├── leg6dof9musc_converted.xml
│   │   ├── swimmer.xml
│   │   └── walker2d.xml
│   ├── build.py
│   ├── envs
│   │   ├── HandModelTSLAdjusted.py
│   │   ├── __init__.py
│   │   ├── half_cheetah.py
│   │   ├── hopper.py
│   │   ├── inverted_double_pendulum.py
│   │   ├── inverted_pendulum.py
│   │   ├── leg.py
│   │   ├── swimmer.py
│   │   └── walker2d.py
│   └── utils
│       ├── __init__.py
│       ├── backward.py
│       ├── forward.py
│       └── wrappers
│           ├── __init__.py
│           ├── etc.py
│           └── mj_block.py
├── requirements.txt
├── solver
│   ├── __init__.py
│   └── build.py
├── tests
│   ├── __init__.py
│   └── test_gradients.py
└── utils
├── __init__.py
├── index.py
├── logger.py
└── visdom_plots.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__
3 | output/
4 | MUJOCO_LOG.TXT
5 | .pytest_cache
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Model-based Reinforcement Learning
2 |
3 | Directly back-propagate into your policy network from model Jacobians computed in MuJoCo via finite differences.
4 |
5 | To backprop into a stochastic policy when the model is unknown, one has to rely on the REINFORCE theorem and estimate gradients by sampling the environment. These estimates usually have high variance, which is why baselines and value/advantage functions were introduced. Another way to backpropagate into the policy network is the “reparameterization trick”, as in VAEs, but it requires upstream gradients and hence a known model. Policy gradients computed with the reparameterization trick are often much lower in variance, so one can do without baselines and value networks. This project puts it all together: a computation graph of policy and dynamics, upstream gradients from MuJoCo dynamics and rewards, the reparameterization trick, and optimization.
6 |
7 | ### Vanilla Computation Graph
8 | ```txt
9 | +----------+S0+----------+ +----------+S1+----------+
10 | | | | |
11 | | +------+ A0 +---v----+ + +------+ A1 +---v----+
12 | S0+------>+Policy+---+--->+Dynamics+---+---+S1+-->+Policy+---+--->+Dynamics+--->S2 ...
13 | | +------+ | +--------+ | + +------+ | +--------+ |
14 | | | | | | |
15 | | +--v---+ | | +--v---+ |
16 | +---+S0+---->+Reward+<-----S1-----+ +---+S1+---->+Reward+<-----S2------+
17 | +------+ +------+
18 | ```
19 |
20 | ### Results
21 |
22 |
23 |
24 |
25 |
26 |
27 | ### This repo contains:
28 | * Finite-difference calculation of MuJoCo dynamics jacobians in `mujoco-py`
29 | * MuJoCo dynamics as a PyTorch Operation (i.e. forward and backward pass)
30 | * Reward function PyTorch Operation
31 | * Flexible design to wire up your own meta computation graph
32 | * Trajectory Optimization module alongside Policy Networks
33 | * Flexible design to define your own environment in `gym`
34 | * Fancy logger and monitoring
35 |
36 | ### Dependencies
37 | Python3.6:
38 | * `torch`
39 | * `mujoco-py`
40 | * `gym`
41 | * `numpy`
42 | * `visdom`
43 |
44 | Other:
45 | * Tested w/ `mujoco200`
46 |
47 | ### Usage
48 | For latest changes:
49 | ```bash
50 | git clone -b development git@github.com:MahanFathi/Model-Based-RL.git
51 | ```
52 | Run:
53 | ```bash
54 | python3 main.py --config-file ./configs/inverted_pendulum.yaml
55 | ```
56 |
--------------------------------------------------------------------------------
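The two gradient estimators contrasted in the README can be sketched in a few lines of PyTorch. The toy below is illustrative only (it is not part of the repository); `reward_fn` is a hypothetical stand-in for an arbitrary differentiable reward:

```python
import torch

# Toy 1-D Gaussian "policy" with a learnable mean.
mean = torch.tensor([0.5], requires_grad=True)
std = torch.tensor([0.3])
reward_fn = lambda a: -(a - 1.0) ** 2   # stand-in differentiable reward

dist = torch.distributions.Normal(mean, std)

# Score-function (REINFORCE) estimator: the sample carries no gradient,
# so the reward only enters as a weight on the log-probability.
a = dist.sample()
(-dist.log_prob(a) * reward_fn(a).detach()).sum().backward()
print("score-function grad:", mean.grad)

# Reparameterization trick: rsample() keeps the sample differentiable
# w.r.t. the mean, so the reward's own gradient flows back -- which
# requires a known, differentiable model/reward, as the README notes.
mean.grad = None
(-reward_fn(dist.rsample())).sum().backward()
print("reparameterized grad:", mean.grad)
```
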
/configs/HandModelTSLAdjusted.yaml:
--------------------------------------------------------------------------------
1 | MUJOCO:
2 | ENV: 'HandModelTSLAdjustedEnv'
3 | CLIP_ACTIONS: false
4 | MODEL:
5 | META_ARCHITECTURE: 'Basic'
6 | BATCH_SIZE: 1
7 | EPOCHS: 1000
8 | NSTEPS_FOR_BACKWARD: 2
9 | FRAME_SKIP: 20
10 | TIMESTEP: 0.002
11 | POLICY:
12 | ARCH: 'Perttu'
13 | METHOD: 'R-VO'
14 | LAYERS:
15 | - 128
16 | - 128
17 | GAMMA: 0.99
18 | MAX_HORIZON_STEPS: 50
19 | INITIAL_SD: 0.2
20 | INITIAL_ACTION_MEAN: 0.0
21 | INITIAL_ACTION_SD: 0.0
22 | SOLVER:
23 | BASE_LR: 0.05
24 | WEIGHT_DECAY: 0.0
25 | BIAS_LR_FACTOR: 1
26 | STD_LR_FACTOR: 1.0
27 | ADAM_BETAS: (0.9, 0.999)
28 | LOG:
29 | TESTING:
30 | ENABLED: True
31 | ITER_PERIOD: 1
32 | RECORD_VIDEO: true
--------------------------------------------------------------------------------
/configs/Model-Based-RL.yml:
--------------------------------------------------------------------------------
1 | name: Model-Based-RL
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - anaconda
6 | - https://conda.anaconda.org/conda-forge/
7 | - defaults
8 | dependencies:
9 | - _libgcc_mutex=0.1=main
10 | - asn1crypto=1.2.0=py37_0
11 | - attrs=19.3.0=py_0
12 | - backcall=0.1.0=py_0
13 | - blas=1.0=mkl
14 | - bleach=3.1.0=py_0
15 | - bzip2=1.0.8=h7b6447c_0
16 | - ca-certificates=2019.11.28=hecc5488_0
17 | - certifi=2019.11.28=py37_0
18 | - cffi=1.12.3=py37h8022711_0
19 | - chardet=3.0.4=py37_1003
20 | - click=7.0=py_0
21 | - cryptography=2.7=py37h72c5cf5_0
22 | - cudatoolkit=10.1.168=0
23 | - curl=7.65.3=hbc83047_0
24 | - cycler=0.10.0=py37_0
25 | - cython=0.29.13=py37he6710b0_0
26 | - dbus=1.13.6=h746ee38_0
27 | - defusedxml=0.6.0=py_0
28 | - entrypoints=0.3=py37_1000
29 | - expat=2.2.6=he6710b0_0
30 | - ffmpeg=4.2=h167e202_0
31 | - fontconfig=2.13.0=h9420a91_0
32 | - freetype=2.9.1=h8a8886c_1
33 | - future=0.17.1=py37_0
34 | - glib=2.56.2=hd408876_0
35 | - gmp=6.1.2=hf484d3e_1000
36 | - gnutls=3.6.5=hd3a4fd2_1002
37 | - gst-plugins-base=1.14.0=hbbd80ab_1
38 | - gstreamer=1.14.0=hb453b48_1
39 | - hdf4=4.2.13=h3ca952b_2
40 | - hdf5=1.10.4=hb1b8bf9_0
41 | - icu=58.2=h9c2bf20_1
42 | - idna=2.8=py37_1000
43 | - imageio=2.6.1=py37_0
44 | - imageio-ffmpeg=0.3.0=py_0
45 | - importlib_metadata=1.2.0=py37_0
46 | - intel-openmp=2019.4=243
47 | - ipykernel=5.1.3=py37h5ca1d4c_0
48 | - ipython_genutils=0.2.0=py_1
49 | - ipywidgets=7.5.1=py_0
50 | - jedi=0.15.1=py37_0
51 | - jinja2=2.10.3=py_0
52 | - jpeg=9b=h024ee3a_2
53 | - json5=0.8.5=py_0
54 | - jsoncpp=1.8.4=hfd86e86_0
55 | - jsonpatch=1.24=py_0
56 | - jsonpointer=2.0=py_0
57 | - jsonschema=3.2.0=py37_0
58 | - jupyter=1.0.0=py_2
59 | - jupyter_client=5.3.3=py37_1
60 | - jupyter_console=5.2.0=py37_1
61 | - jupyter_core=4.6.1=py37_0
62 | - jupyterlab=1.2.3=py_0
63 | - jupyterlab_server=1.0.6=py_0
64 | - kiwisolver=1.1.0=py37he6710b0_0
65 | - krb5=1.16.1=h173b8e3_7
66 | - lame=3.100=h14c3975_1001
67 | - libcurl=7.65.3=h20c2e04_0
68 | - libedit=3.1.20181209=hc058e9b_0
69 | - libffi=3.2.1=hd88cf55_4
70 | - libgcc-ng=9.1.0=hdf63c60_0
71 | - libgfortran-ng=7.3.0=hdf63c60_0
72 | - libiconv=1.15=h516909a_1005
73 | - libnetcdf=4.6.1=h11d0813_2
74 | - libogg=1.3.2=h7b6447c_0
75 | - libpng=1.6.37=hbc83047_0
76 | - libprotobuf=3.10.1=h8b12597_0
77 | - libsodium=1.0.17=h516909a_0
78 | - libssh2=1.8.2=h1ba5d50_0
79 | - libstdcxx-ng=9.1.0=hdf63c60_0
80 | - libtheora=1.1.1=h5ab3b9f_1
81 | - libtiff=4.0.10=h2733197_2
82 | - libuuid=1.0.3=h1bed415_2
83 | - libvorbis=1.3.6=h7b6447c_0
84 | - libxcb=1.13=h1bed415_1
85 | - libxml2=2.9.9=hea5a465_1
86 | - lockfile=0.12.2=py_1
87 | - lz4-c=1.8.1.2=h14c3975_0
88 | - markdown=3.1.1=py_0
89 | - markupsafe=1.1.1=py37h516909a_0
90 | - matplotlib=3.1.1=py37h5429711_0
91 | - mistune=0.8.4=py37h516909a_1000
92 | - mkl=2019.4=243
93 | - mkl-service=2.3.0=py37he904b0f_0
94 | - mkl_fft=1.0.14=py37ha843d7b_0
95 | - mkl_random=1.1.0=py37hd6b4f25_0
96 | - more-itertools=8.0.2=py_0
97 | - nbconvert=5.6.1=py37_0
98 | - nbformat=4.4.0=py_1
99 | - ncurses=6.1=he6710b0_1
100 | - nettle=3.4.1=h1bed415_1002
101 | - ninja=1.9.0=h6bb024c_0
102 | - notebook=6.0.1=py37_0
103 | - numpy=1.17.2=py37haad9e8e_0
104 | - numpy-base=1.17.2=py37hde5b4d6_0
105 | - olefile=0.46=py_0
106 | - openh264=1.8.0=hdbcaa40_1000
107 | - openssl=1.1.1d=h516909a_0
108 | - pandas=0.25.1=py37he6710b0_0
109 | - pandoc=2.8.1=0
110 | - pandocfilters=1.4.2=py_1
111 | - parso=0.5.1=py_0
112 | - patchelf=0.9=he6710b0_3
113 | - patsy=0.5.1=py37_0
114 | - pcre=8.43=he6710b0_0
115 | - pexpect=4.7.0=py37_0
116 | - pillow=6.2.0=py37h34e0f95_0
117 | - pip=19.2.3=py37_0
118 | - prometheus_client=0.7.1=py_0
119 | - prompt_toolkit=3.0.2=py_0
120 | - protobuf=3.10.1=py37he1b5a44_0
121 | - ptyprocess=0.6.0=py_1001
122 | - pycparser=2.19=py37_1
123 | - pyopenssl=19.0.0=py37_0
124 | - pyparsing=2.4.2=py_0
125 | - pyqt=5.9.2=py37h05f1152_2
126 | - pyrsistent=0.15.6=py37h516909a_0
127 | - pysocks=1.7.1=py37_0
128 | - python=3.7.4=h265db76_1
129 | - python-dateutil=2.8.0=py37_0
130 | - pytorch=1.3.0=py3.7_cuda10.1.243_cudnn7.6.3_0
131 | - pytz=2019.2=py_0
132 | - pyyaml=5.1.2=py37h516909a_0
133 | - pyzmq=18.1.0=py37h1768529_0
134 | - qt=5.9.7=h5867ecd_1
135 | - qtconsole=4.6.0=py_0
136 | - readline=7.0=h7b6447c_5
137 | - requests=2.22.0=py37_1
138 | - scipy=1.3.1=py37h7c811a0_0
139 | - seaborn=0.9.0=pyh91ea838_1
140 | - send2trash=1.5.0=py_0
141 | - setuptools=41.2.0=py37_0
142 | - sip=4.19.8=py37hf484d3e_0
143 | - six=1.12.0=py37_0
144 | - sk-video=1.1.10=pyh24bf2e0_4
145 | - sqlite=3.29.0=h7b6447c_0
146 | - statsmodels=0.10.1=py37hdd07704_0
147 | - tbb=2019.4=hfd86e86_0
148 | - tensorboard=1.14.0=py37_0
149 | - terminado=0.8.3=py37_0
150 | - testpath=0.4.4=py_0
151 | - tk=8.6.8=hbc83047_0
152 | - torchfile=0.1.0=py_0
153 | - tornado=6.0.3=py37h7b6447c_0
154 | - traitlets=4.3.3=py37_0
155 | - urllib3=1.25.6=py37_0
156 | - visdom=0.1.8.9=0
157 | - vtk=8.2.0=py37haa4764d_200
158 | - webencodings=0.5.1=py_1
159 | - websocket-client=0.56.0=py37_0
160 | - werkzeug=0.16.0=py_0
161 | - wheel=0.33.6=py37_0
162 | - widgetsnbextension=3.5.1=py37_0
163 | - x264=1!152.20180806=h14c3975_0
164 | - xarray=0.14.1=py_1
165 | - xmltodict=0.12.0=py_0
166 | - xorg-fixesproto=5.0=h14c3975_1002
167 | - xorg-kbproto=1.0.7=h14c3975_1002
168 | - xorg-libx11=1.6.9=h516909a_0
169 | - xorg-libxcursor=1.2.0=h516909a_0
170 | - xorg-libxext=1.3.4=h516909a_0
171 | - xorg-libxfixes=5.0.3=h516909a_1004
172 | - xorg-libxinerama=1.1.4=hf484d3e_1000
173 | - xorg-libxrandr=1.5.2=h516909a_1
174 | - xorg-libxrender=0.9.10=h516909a_1002
175 | - xorg-randrproto=1.5.0=h14c3975_1001
176 | - xorg-renderproto=0.11.1=h14c3975_1002
177 | - xorg-xextproto=7.3.0=h14c3975_1002
178 | - xorg-xproto=7.0.31=h14c3975_1007
179 | - xz=5.2.4=h14c3975_4
180 | - yacs=0.1.6=py_0
181 | - yaml=0.1.7=h14c3975_1001
182 | - zeromq=4.3.2=he1b5a44_2
183 | - zipp=0.6.0=py_0
184 | - zlib=1.2.11=h7b6447c_3
185 | - zstd=1.3.7=h0b5b093_0
186 | - pip:
187 | - admesh==0.98.9
188 | - cloudpickle==1.2.2
189 | - cma==2.7.0
190 | - decorator==4.4.0
191 | - glfw==1.8.3
192 | - gym==0.15.3
193 | - ipdb==0.12.2
194 | - ipython==7.8.0
195 | - ipython-genutils==0.2.0
196 | - mujoco-py==2.0.2.5
197 | - pickleshare==0.7.5
198 | - prompt-toolkit==2.0.10
199 | - pyglet==1.3.2
200 | - pygments==2.4.2
201 | - pyquaternion==0.9.5
202 | - python-graphviz==0.13
203 | - torchviz==0.0.1
204 | - wcwidth==0.1.7
205 |
--------------------------------------------------------------------------------
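Assuming a standard Anaconda/Miniconda setup, this environment would typically be created with `conda env create -f configs/Model-Based-RL.yml` and activated with `conda activate Model-Based-RL`; a separate MuJoCo 2.0 (`mujoco200`) installation is still required for `mujoco-py` to compile.
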
/configs/half_cheetah.yaml:
--------------------------------------------------------------------------------
1 | MUJOCO:
2 | ENV: 'HalfCheetahEnv'
3 | CLIP_ACTIONS: false
4 | MODEL:
5 | META_ARCHITECTURE: 'Basic'
6 | BATCH_SIZE: 1
7 | FRAME_SKIP: 10
8 | EPOCHS: 5001
9 | RANDOM_SEED: 40
10 | POLICY:
11 | ARCH: 'Perttu'
12 | METHOD: 'Grad'
13 | NETWORK: False
14 | MAX_HORIZON_STEPS: 40
15 | LAYERS:
16 | - 128
17 | - 128
18 | GAMMA: 0.99
19 | INITIAL_LOG_SD: -0.5
20 | INITIAL_SD: 0.5
21 | INITIAL_ACTION_MEAN: 0.0
22 | INITIAL_ACTION_SD: 0.1
23 | GRAD_WEIGHTS: 'average'
24 | SOLVER:
25 | OPTIMIZER: 'adam'
26 | BASE_LR: 0.001
27 | WEIGHT_DECAY: 0.0
28 | BIAS_LR_FACTOR: 1
29 | STD_LR_FACTOR: 1.0
30 | ADAM_BETAS: (0.9, 0.999)
31 | LOG:
32 | TESTING:
33 | ITER_PERIOD: 200
34 | RECORD_VIDEO: True
35 |
--------------------------------------------------------------------------------
/configs/hopper.yaml:
--------------------------------------------------------------------------------
1 | MUJOCO:
2 | ENV: 'HopperEnv'
3 | CLIP_ACTIONS: false
4 | MODEL:
5 | META_ARCHITECTURE: 'Basic'
6 | BATCH_SIZE: 1
7 | FRAME_SKIP: 20
8 | EPOCHS: 5001
9 | RANDOM_SEED: 50
10 | POLICY:
11 | ARCH: 'Perttu'
12 | METHOD: 'Grad'
13 | NETWORK: False
14 | MAX_HORIZON_STEPS: 40
15 | LAYERS:
16 | - 128
17 | - 128
18 | GAMMA: 0.99
19 | INITIAL_LOG_SD: -0.5
20 | INITIAL_SD: 0.5
21 | INITIAL_ACTION_MEAN: 0.0
22 | INITIAL_ACTION_SD: 0.1
23 | GRAD_WEIGHTS: 'average'
24 | SOLVER:
25 | OPTIMIZER: 'adam'
26 | BASE_LR: 0.001
27 | WEIGHT_DECAY: 0.0
28 | BIAS_LR_FACTOR: 1
29 | STD_LR_FACTOR: 1.0
30 | ADAM_BETAS: (0.9, 0.999)
31 | LOG:
32 | TESTING:
33 | ITER_PERIOD: 200
34 | RECORD_VIDEO: True
35 |
--------------------------------------------------------------------------------
/configs/inverted_double_pendulum.yaml:
--------------------------------------------------------------------------------
1 | MUJOCO:
2 | ENV: 'InvertedDoublePendulumEnv'
3 | CLIP_ACTIONS: false
4 | MODEL:
5 | META_ARCHITECTURE: 'Basic'
6 | BATCH_SIZE: 8
7 | FRAME_SKIP: 4
8 | RANDOM_SEED: 30
9 | EPOCHS: 5001
10 | POLICY:
11 | ARCH: 'Perttu'
12 | METHOD: 'R-VO'
13 | NETWORK: False
14 | MAX_HORIZON_STEPS: 50
15 | LAYERS:
16 | - 128
17 | - 128
18 | GAMMA: 0.99
19 | INITIAL_LOG_SD: -0.5
20 | INITIAL_SD: 0.2
21 | INITIAL_ACTION_MEAN: 0.0
22 | INITIAL_ACTION_SD: 0.1
23 | GRAD_WEIGHTS: "average"
24 | SOLVER:
25 | OPTIMIZER: 'adam'
26 | BASE_LR: 0.001
27 | WEIGHT_DECAY: 0.0
28 | BIAS_LR_FACTOR: 1
29 | STD_LR_FACTOR: 1.0
30 | ADAM_BETAS: (0.9, 0.999)
31 | LOG:
32 | TESTING:
33 | ITER_PERIOD: 100
34 | RECORD_VIDEO: True
35 |
--------------------------------------------------------------------------------
/configs/inverted_pendulum.yaml:
--------------------------------------------------------------------------------
1 | MUJOCO:
2 | ENV: 'InvertedPendulumEnv'
3 | CLIP_ACTIONS: false
4 | MODEL:
5 | META_ARCHITECTURE: 'Basic'
6 | BATCH_SIZE: 8
7 | FRAME_SKIP: 2
8 | EPOCHS: 1001
9 | RANDOM_SEED: 10
10 | POLICY:
11 | ARCH: 'Perttu'
12 | METHOD: 'R-VO'
13 | NETWORK: False
14 | MAX_HORIZON_STEPS: 50
15 | LAYERS:
16 | - 128
17 | - 128
18 | GAMMA: 0.99
19 | INITIAL_LOG_SD: -0.5
20 | INITIAL_SD: 0.1
21 | INITIAL_ACTION_MEAN: 0.0
22 | INITIAL_ACTION_SD: 0.0
23 | GRAD_WEIGHTS: 'average'
24 | SOLVER:
25 | OPTIMIZER: 'adam'
26 | BASE_LR: 0.001
27 | WEIGHT_DECAY: 0.0
28 | BIAS_LR_FACTOR: 1
29 | STD_LR_FACTOR: 1.0
30 | ADAM_BETAS: (0.9, 0.999)
31 | LOG:
32 | TESTING:
33 | ENABLED: True
34 | ITER_PERIOD: 200
35 | RECORD_VIDEO: True
--------------------------------------------------------------------------------
/configs/leg.yaml:
--------------------------------------------------------------------------------
1 | MUJOCO:
2 | ENV: 'LegEnv'
3 | CLIP_ACTIONS: false
4 | MODEL:
5 | META_ARCHITECTURE: 'Basic'
6 | BATCH_SIZE: 32
7 | EPOCHS: 100
8 | POLICY:
9 | ARCH: 'Perttu'
10 | METHOD: 'VO'
11 | LAYERS:
12 | - 128
13 | - 128
14 | GAMMA: 0.99
15 | MAX_HORIZON_STEPS: 228
16 | INITIAL_LOG_SD: -1.0
17 | INITIAL_ACTION_MEAN: 0.0
18 | INITIAL_ACTION_SD: 0.1
19 | SOLVER:
20 | BASE_LR: 0.1
21 | WEIGHT_DECAY: 0.0
22 | BIAS_LR_FACTOR: 1
23 | STD_LR_FACTOR: 1.0
24 | ADAM_BETAS: (0.9, 0.999)
25 | LOG:
26 | TESTING:
27 | ENABLED: False
--------------------------------------------------------------------------------
/configs/swimmer.yaml:
--------------------------------------------------------------------------------
1 | MUJOCO:
2 | ENV: 'SwimmerEnv'
3 | CLIP_ACTIONS: false
4 | MODEL:
5 | META_ARCHITECTURE: 'Basic'
6 | BATCH_SIZE: 1
7 | FRAME_SKIP: 10
8 | EPOCHS: 5001
9 | RANDOM_SEED: 20
10 | POLICY:
11 | ARCH: 'Perttu'
12 | METHOD: 'Grad'
13 | NETWORK: False
14 | MAX_HORIZON_STEPS: 40
15 | LAYERS:
16 | - 128
17 | - 128
18 | GAMMA: 0.99
19 | INITIAL_LOG_SD: -0.5
20 | INITIAL_SD: 0.2
21 | INITIAL_ACTION_MEAN: 0.0
22 | INITIAL_ACTION_SD: 0.1
23 | GRAD_WEIGHTS: 'average'
24 | SOLVER:
25 | OPTIMIZER: 'adam'
26 | BASE_LR: 0.001
27 | WEIGHT_DECAY: 0.0
28 | BIAS_LR_FACTOR: 1
29 | STD_LR_FACTOR: 1.0
30 | ADAM_BETAS: (0.9, 0.999)
31 | LOG:
32 | TESTING:
33 | ITER_PERIOD: 200
34 | RECORD_VIDEO: True
35 |
--------------------------------------------------------------------------------
/configs/walker2d.yaml:
--------------------------------------------------------------------------------
1 | MUJOCO:
2 | ENV: 'Walker2dEnv'
3 | CLIP_ACTIONS: false
4 | MODEL:
5 | META_ARCHITECTURE: 'Basic'
6 | BATCH_SIZE: 1
7 | FRAME_SKIP: 25
8 | EPOCHS: 5001
9 | RANDOM_SEED: 60
10 | POLICY:
11 | ARCH: 'Perttu'
12 | METHOD: 'Grad'
13 | NETWORK: False
14 | MAX_HORIZON_STEPS: 40
15 | LAYERS:
16 | - 128
17 | - 128
18 | GAMMA: 0.99
19 | INITIAL_LOG_SD: -0.5
20 | INITIAL_SD: 0.5
21 | INITIAL_ACTION_MEAN: 0.0
22 | INITIAL_ACTION_SD: 0.1
23 | GRAD_WEIGHTS: 'average'
24 | SOLVER:
25 | OPTIMIZER: 'adam'
26 | BASE_LR: 0.001
27 | WEIGHT_DECAY: 0.0
28 | BIAS_LR_FACTOR: 1
29 | STD_LR_FACTOR: 1.0
30 | ADAM_BETAS: (0.9, 0.999)
31 | LOG:
32 | TESTING:
33 | ITER_PERIOD: 200
34 | RECORD_VIDEO: True
35 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from datetime import datetime
4 |
5 | import model.engine.trainer
6 | import model.engine.dynamics_model_trainer
7 | from model.config import get_cfg_defaults
8 | import utils.logger as lg
9 | import model.engine.landscape_plot
10 |
11 |
12 | def train(cfg, iter):
13 |
14 | # Create output directories
15 | env_output_dir = os.path.join(cfg.OUTPUT.DIR, cfg.MUJOCO.ENV)
16 | if cfg.OUTPUT.NAME == "timestamp":
17 | output_dir_name = "{0:%Y-%m-%d %H:%M:%S}".format(datetime.now())
18 | else:
19 | output_dir_name = cfg.OUTPUT.NAME
20 | output_dir = os.path.join(env_output_dir, output_dir_name)
21 | output_rec_dir = os.path.join(output_dir, 'recordings')
22 | output_weights_dir = os.path.join(output_dir, 'weights')
23 | output_results_dir = os.path.join(output_dir, 'results')
24 | os.makedirs(output_dir)
25 | os.mkdir(output_weights_dir)
26 | os.mkdir(output_results_dir)
27 | if cfg.LOG.TESTING.ENABLED:
28 | os.mkdir(output_rec_dir)
29 |
30 | # Create logger
31 | logger = lg.setup_logger("model.engine.trainer", output_dir, 'logs')
32 | logger.info("Running with config:\n{}".format(cfg))
33 |
34 | # Repeat for required number of iterations
35 | for i in range(iter):
36 | agent = model.engine.trainer.do_training(
37 | cfg,
38 | logger,
39 | output_results_dir,
40 | output_rec_dir,
41 | output_weights_dir,
42 | i
43 | )
44 | model.engine.landscape_plot.visualise2d(agent, output_results_dir, i)
45 |
46 |
47 | def train_dynamics_model(cfg, iter):
48 |
49 | # Create output directories
50 | env_output_dir = os.path.join(cfg.OUTPUT.DIR, cfg.MUJOCO.ENV)
51 | output_dir = os.path.join(env_output_dir, "{0:%Y-%m-%d %H:%M:%S}".format(datetime.now()))
52 | output_rec_dir = os.path.join(output_dir, 'recordings')
53 | output_weights_dir = os.path.join(output_dir, 'weights')
54 | output_results_dir = os.path.join(output_dir, 'results')
55 | os.makedirs(output_dir)
56 | os.mkdir(output_weights_dir)
57 | os.mkdir(output_results_dir)
58 | if cfg.LOG.TESTING.ENABLED:
59 | os.mkdir(output_rec_dir)
60 |
61 | # Create logger
62 | logger = lg.setup_logger("model.engine.dynamics_model_trainer", output_dir, 'logs')
63 | logger.info("Running with config:\n{}".format(cfg))
64 |
65 | # Train the dynamics model
66 | model.engine.dynamics_model_trainer.do_training(
67 | cfg,
68 | logger,
69 | output_results_dir,
70 | output_rec_dir,
71 | output_weights_dir
72 | )
73 |
74 |
75 | def inference(cfg):
76 | pass
77 |
78 |
79 | def main():
80 | parser = argparse.ArgumentParser(description="PyTorch model-based RL.")
81 | parser.add_argument(
82 | "--config-file",
83 | default="",
84 | metavar="file",
85 | help="path to config file",
86 | type=str,
87 | )
88 | parser.add_argument(
89 | "--mode",
90 | default="train",
91 | metavar="mode",
92 | help="'train' or 'test' or 'dynamics'",
93 | type=str,
94 | )
95 | parser.add_argument(
96 | "--iter",
97 | default=1,
98 | help="Number of iterations",
99 | type=int
100 | )
101 | parser.add_argument(
102 | "--opts",
103 | help="Modify config options using the command-line",
104 | default=[],
105 | nargs=argparse.REMAINDER,
106 | )
107 |
108 | args = parser.parse_args()
109 |
110 | # build the config
111 | cfg = get_cfg_defaults()
112 | cfg.merge_from_file(args.config_file)
113 | cfg.merge_from_list(args.opts)
114 | cfg.freeze()
115 |
116 | # TRAIN
117 | if args.mode == "train":
118 | train(cfg, args.iter)
119 | elif args.mode == "dynamics":
120 | train_dynamics_model(cfg, args.iter)
121 |
122 |
123 | if __name__ == "__main__":
124 | main()
125 |
--------------------------------------------------------------------------------
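`main.py` follows the usual yacs pattern: start from frozen defaults, override them with the YAML file given by `--config-file`, then with any `KEY VALUE` pairs passed via `--opts`. A minimal sketch of that pattern with illustrative keys only (the real schema lives in `model/config/defaults.py`, which is not reproduced in this dump):

```python
from yacs.config import CfgNode as CN

# Illustrative defaults; the real defaults must define every key that
# appears in the YAML files under configs/.
_C = CN()
_C.MUJOCO = CN()
_C.MUJOCO.ENV = "InvertedPendulumEnv"
_C.SOLVER = CN()
_C.SOLVER.BASE_LR = 0.001


def get_cfg_defaults():
    return _C.clone()   # a copy, so merging/freezing never mutates the defaults


cfg = get_cfg_defaults()
# cfg.merge_from_file("configs/inverted_pendulum.yaml")  # YAML overrides defaults
cfg.merge_from_list(["SOLVER.BASE_LR", 0.01])            # --opts pairs override YAML
cfg.freeze()                                             # read-only from here on
```

On the command line this corresponds to, e.g., `python3 main.py --config-file ./configs/hopper.yaml --opts SOLVER.BASE_LR 0.01`.
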
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_model
2 |
3 | __all__ = ["build_model"]
4 |
--------------------------------------------------------------------------------
/model/archs/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import Basic
2 |
3 | __all__ = ["Basic"]
4 |
--------------------------------------------------------------------------------
/model/archs/basic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from model.blocks import build_policy, mj_torch_block_factory
4 | from copy import deepcopy
5 | from solver import build_optimizer
6 | import numpy as np
7 | from ..blocks.policy.strategies import *
8 |
9 |
10 | class Basic(nn.Module):
11 | def __init__(self, cfg, agent):
12 | """Build the model from the fed config node.
13 | :param cfg: CfgNode containing the configurations of everything.
14 | """
15 | super(Basic, self).__init__()
16 | self.cfg = cfg
17 | self.agent = agent
18 |
19 | # build policy net
20 | self.policy_net = build_policy(cfg, self.agent)
21 |
22 | # Make sure unwrapped agent can call policy_net
23 | self.agent.unwrapped.policy_net = self.policy_net
24 |
25 | # build forward dynamics block
26 | self.dynamics_block = mj_torch_block_factory(agent, 'dynamics').apply
27 | self.reward_block = mj_torch_block_factory(agent, 'reward').apply
28 |
29 | def forward(self, state):
30 | """Single pass.
31 | :param state:
32 | :return:
33 | """
34 |
35 | # We're generally using torch.float64 and numpy.float64 for precision, but the net can be trained with
36 | # torch.float32 -- not sure if this really makes a difference wrt speed or memory, but the default layers
37 | # seem to be using torch.float32
38 | action = self.policy_net(state.detach().float()).double()
39 |
40 | # Forward block will drive the simulation forward
41 | next_state = self.dynamics_block(state, action)
42 |
43 | # The reward is actually calculated in the dynamics_block, here we'll just grab it from the agent
44 | reward = self.reward_block(state, action)
45 |
46 | return next_state, reward
47 |
--------------------------------------------------------------------------------
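`Basic.forward` only advances the graph by one time step; the unrolling over a horizon happens in the trainer (`model/engine/trainer.py`, not included in this section). The sketch below shows roughly how such a rollout composes the per-step blocks into a single differentiable return; `agent.reset()` returning a NumPy observation and the sign convention of the loss are assumptions, not the repository's exact code:

```python
import torch

def rollout_loss(model, agent, horizon, gamma=0.99):
    """Unroll policy -> dynamics -> reward for `horizon` steps and return a
    differentiable negative discounted return (illustrative sketch only)."""
    state = torch.from_numpy(agent.reset()).double()
    total_reward = torch.zeros(1, dtype=torch.float64)
    discount = 1.0
    for _ in range(horizon):
        state, reward = model(state)          # MjBlock dynamics step + reward
        total_reward = total_reward + discount * reward
        discount *= gamma
    # Minimising -return backpropagates through the finite-difference
    # Jacobians supplied by MjBlock.backward.
    return -total_reward
```
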
/model/blocks/__init__.py:
--------------------------------------------------------------------------------
1 | from .policy import build_policy
2 | from .mujoco import mj_torch_block_factory
3 |
4 | __all__ = ["build_policy", "mj_torch_block_factory"]
5 |
--------------------------------------------------------------------------------
/model/blocks/mujoco.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import autograd, nn
3 | import numpy as np
4 |
5 |
6 | def mj_torch_block_factory(agent, mode):
7 | mj_forward = agent.forward_factory(mode)
8 | mj_gradients = agent.gradient_factory(mode)
9 |
10 | class MjBlock(autograd.Function):
11 |
12 | @staticmethod
13 | def forward(ctx, state, action):
14 |
15 | # Advance simulation or return reward
16 | if mode == "dynamics":
17 | # We need to get a deep copy of simulation data so we can return to this "snapshot"
18 | # (we can't deepcopy agent.sim.data because some variables are unpicklable)
19 | # We'll calculate gradients in the backward phase of "reward"
20 | agent.data.qpos[:] = state[:agent.model.nq].detach().numpy().copy()
21 | agent.data.qvel[:] = state[agent.model.nq:].detach().numpy().copy()
22 | agent.data.ctrl[:] = action.detach().numpy().copy()
23 | agent.data_snapshot = agent.get_snapshot()
24 |
25 | next_state = mj_forward()
26 | agent.next_state = next_state
27 |
28 | ctx.data_snapshot = agent.data_snapshot
29 | # ctx.reward = agent.reward
30 | # ctx.next_state = agent.next_state
31 |
32 | return torch.from_numpy(next_state)
33 |
34 | elif mode == "reward":
35 | ctx.data_snapshot = agent.data_snapshot
36 | ctx.reward = agent.reward
37 | ctx.next_state = agent.next_state
38 | return torch.Tensor([agent.reward]).double()
39 |
40 | else:
41 |                 raise TypeError("mode has to be 'dynamics' or 'reward'")
42 |
43 | @staticmethod
44 | def backward(ctx, grad_output):
45 |
46 | if agent.cfg.MODEL.POLICY.GRAD_WEIGHTS == "prioritise":
47 | weight = agent.cfg.MODEL.POLICY.MAX_HORIZON_STEPS - ctx.data_snapshot.step_idx.value
48 | else:
49 | weight = 1 / (agent.cfg.MODEL.POLICY.MAX_HORIZON_STEPS - ctx.data_snapshot.step_idx.value)
50 |
51 |             # We only need to calculate gradients once per dynamics/reward cycle
52 | if mode == "dynamics":
53 |
54 | state_jacobian = torch.from_numpy(agent.dynamics_gradients["state"])
55 | action_jacobian = torch.from_numpy(agent.dynamics_gradients["action"])
56 |
57 | if agent.cfg.MODEL.POLICY.GRAD_WEIGHTS == "prioritise":
58 | action_jacobian = (1.0 / agent.running_sum) * action_jacobian
59 | elif agent.cfg.MODEL.POLICY.GRAD_WEIGHTS == "average":
60 | action_jacobian = weight * action_jacobian
61 |
62 | elif mode == "reward":
63 |
64 | if agent.cfg.MODEL.POLICY.GRAD_WEIGHTS == "prioritise":
65 | agent.running_sum += weight
66 |
67 | # Calculate gradients, "reward" is always called first
68 | mj_gradients(ctx.data_snapshot, ctx.next_state, ctx.reward, test=True)
69 | state_jacobian = torch.from_numpy(agent.reward_gradients["state"])
70 | action_jacobian = torch.from_numpy(agent.reward_gradients["action"])
71 |
72 | if agent.cfg.MODEL.POLICY.GRAD_WEIGHTS == "prioritise":
73 | state_jacobian = (weight - 1) * state_jacobian
74 | action_jacobian = (weight / agent.running_sum) * action_jacobian
75 | elif agent.cfg.MODEL.POLICY.GRAD_WEIGHTS == "average":
76 | action_jacobian = weight * action_jacobian
77 |
78 | else:
79 | raise TypeError("mode has to be 'dynamics' or 'reward'")
80 |
81 | if False:
82 | torch.set_printoptions(precision=5, sci_mode=False)
83 | print("{} {}".format(ctx.data_snapshot.time, mode))
84 | print("grad_output")
85 | print(grad_output)
86 | print("state_jacobian")
87 | print(state_jacobian)
88 | print("action_jacobian")
89 | print(action_jacobian)
90 | print("grad_output*state_jacobian")
91 | print(torch.matmul(grad_output, state_jacobian))
92 | print("grad_output*action_jacobian")
93 | print(torch.matmul(grad_output, action_jacobian))
94 | print('weight: {} ---- 1/({} - {})'.format(weight, agent.cfg.MODEL.POLICY.MAX_HORIZON_STEPS, ctx.data_snapshot.step_idx))
95 | print("")
96 |
97 | ds = torch.matmul(grad_output, state_jacobian)
98 | da = torch.matmul(grad_output, action_jacobian)
99 |
100 | #threshold = 5
101 | #ds.clamp_(-threshold, threshold)
102 | #da.clamp_(-threshold, threshold)
103 |
104 | return ds, da
105 |
106 | return MjBlock
107 |
108 |
--------------------------------------------------------------------------------
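The same wrapping pattern — a black-box simulator exposed as an `autograd.Function` whose backward pass multiplies `grad_output` by a finite-difference Jacobian — can be boiled down to a generic toy. The snippet below is a self-contained illustration of that idea applied to an arbitrary NumPy function, not the repository's `MjBlock`:

```python
import numpy as np
import torch
from torch import autograd

def fd_block_factory(f, eps=1e-6):
    """Wrap a black-box NumPy function f: R^n -> R^m as a torch op whose
    backward pass uses a central-difference Jacobian estimate."""

    class FDBlock(autograd.Function):
        @staticmethod
        def forward(ctx, x):
            x_np = x.detach().numpy().astype(np.float64)
            y = f(x_np)
            # Estimate J[i, j] = d f_i / d x_j by central differences.
            jac = np.zeros((y.size, x_np.size))
            for j in range(x_np.size):
                dx = np.zeros_like(x_np)
                dx[j] = eps
                jac[:, j] = (f(x_np + dx) - f(x_np - dx)) / (2 * eps)
            ctx.jac = torch.from_numpy(jac)
            return torch.from_numpy(y)

        @staticmethod
        def backward(ctx, grad_output):
            # Vector-Jacobian product, as in MjBlock.backward above.
            return torch.matmul(grad_output, ctx.jac)

    return FDBlock

# Usage: gradients of a "simulator" that torch cannot trace.
sim = lambda x: np.array([np.sin(x[0]) * x[1], x[0] ** 2])
x = torch.tensor([0.3, 2.0], dtype=torch.float64, requires_grad=True)
y = fd_block_factory(sim).apply(x)
y.sum().backward()
print(x.grad)
```
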
/model/blocks/policy/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_policy
2 | from .deterministic import DeterministicPolicy
3 | from .trajopt import TrajOpt
4 | from .stochastic import StochasticPolicy
5 | from .strategies import CMAES, VariationalOptimization, Perttu
6 |
7 | __all__ = ["build_policy", "DeterministicPolicy", "StochasticPolicy", "TrajOpt", "CMAES", "VariationalOptimization",
8 | "Perttu"]
9 |
--------------------------------------------------------------------------------
/model/blocks/policy/base.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class BasePolicy(nn.Module):
5 | def __init__(self, policy_cfg, agent):
6 | super(BasePolicy, self).__init__()
7 | self.policy_cfg = policy_cfg
8 | self.agent = agent
9 |
10 | def forward(self, s):
11 | raise NotImplementedError
12 |
13 | def episode_callback(self):
14 | raise NotImplementedError
15 |
16 | def batch_callback(self):
17 | raise NotImplementedError
18 |
--------------------------------------------------------------------------------
/model/blocks/policy/build.py:
--------------------------------------------------------------------------------
1 | from model.blocks import policy
2 |
3 |
4 | def build_policy(cfg, agent):
5 | policy_factory = getattr(policy, cfg.MODEL.POLICY.ARCH)
6 | policy_net = policy_factory(cfg, agent)
7 | #if cfg.MODEL.POLICY.OBS_SCALER:
8 | # from .wrappers import ZMUSWrapper
9 | # policy_net = ZMUSWrapper(policy_net)
10 | return policy_net
11 |
12 |
--------------------------------------------------------------------------------
/model/blocks/policy/deterministic.py:
--------------------------------------------------------------------------------
1 | from .base import BasePolicy
2 | from model.layers import FeedForward
3 | from torch.nn.parameter import Parameter
4 | import torch
5 |
6 |
7 | class DeterministicPolicy(BasePolicy):
8 | def __init__(self, policy_cfg, agent):
9 | super(DeterministicPolicy, self).__init__(policy_cfg, agent)
10 |
11 | # for deterministic policies the model is just a simple feed forward net
12 | self.net = FeedForward(
13 | agent.observation_space.shape[0],
14 | policy_cfg.LAYERS,
15 | agent.action_space.shape[0])
16 |
17 | if policy_cfg.VARIATIONAL:
18 | self.logstd = Parameter(torch.empty(policy_cfg.MAX_HORIZON_STEPS, agent.action_space.shape[0]).fill_(0))
19 | self.index = 0
20 |
21 | def forward(self, s):
22 | # Get mean of action value
23 | a_mean = self.net(s)
24 |
25 | if self.policy_cfg.VARIATIONAL:
26 | # Get standard deviation of action
27 | a_std = torch.exp(self.logstd[self.index])
28 | self.index += 1
29 |
30 | # Sample with re-parameterization trick
31 | a = torch.distributions.Normal(a_mean, a_std).rsample()
32 |
33 | else:
34 | a = a_mean
35 |
36 | return a
37 |
38 | def episode_callback(self):
39 | if self.policy_cfg.VARIATIONAL:
40 | self.index = 0
41 |
42 | def batch_callback(self):
43 | pass
44 |
45 |
--------------------------------------------------------------------------------
/model/blocks/policy/dynamics.py:
--------------------------------------------------------------------------------
1 | from solver import build_optimizer
2 | import torch
3 |
4 |
5 | class DynamicsModel(torch.nn.Module):
6 | def __init__(self, agent):
7 | super(DynamicsModel, self).__init__()
8 |
9 | # We're using a simple two layered feedforward net
10 | self.net = torch.nn.Sequential(
11 | torch.nn.Linear(agent.observation_space.shape[0] + agent.action_space.shape[0], 128),
12 | torch.nn.LeakyReLU(),
13 | torch.nn.Linear(128, 128),
14 | torch.nn.LeakyReLU(),
15 | torch.nn.Linear(128, agent.observation_space.shape[0])
16 | ).to(torch.device("cpu"))
17 |
18 | self.optimizer = torch.optim.Adam(self.net.parameters(), 0.001)
19 |
20 | def forward(self, state, action):
21 |
22 | # Predict next state given current state and action
23 | next_state = self.net(torch.cat([state, action]))
24 |
25 | return next_state
26 |
27 |
--------------------------------------------------------------------------------
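Since `DynamicsModel` bundles its own Adam optimizer, training it reduces to supervised regression on observed transitions. A minimal sketch of one such update (the actual loop lives in `model/engine/dynamics_model_trainer.py`, which is not shown in this section):

```python
import torch

def dynamics_update(dynamics_model, state, action, next_state):
    """One supervised step: predict s' from (s, a) and regress onto the
    observed next state. Tensors are 1-D, matching DynamicsModel.forward."""
    predicted = dynamics_model(state, action)
    loss = torch.nn.functional.mse_loss(predicted, next_state)
    dynamics_model.optimizer.zero_grad()
    loss.backward()
    dynamics_model.optimizer.step()
    return loss.item()
```
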
/model/blocks/policy/stochastic.py:
--------------------------------------------------------------------------------
1 | from .base import BasePolicy
2 | from model.layers import FeedForward
3 | import torch
4 | import torch.nn as nn
5 | import torch.distributions as tdist
6 | from model.blocks.utils import build_soft_lower_bound_fn
7 |
8 |
9 | class StochasticPolicy(BasePolicy):
10 | def __init__(self, policy_cfg, agent):
11 | super(StochasticPolicy, self).__init__(policy_cfg, agent)
12 |
13 | # Get number of actions
14 | self.num_actions = agent.action_space.shape[0]
15 |
16 | # The network outputs a gaussian distribution
17 | self.net = FeedForward(
18 | agent.observation_space.shape[0],
19 | policy_cfg.LAYERS,
20 | self.num_actions*2
21 | )
22 |
23 | def forward(self, s):
24 | # Get means and logs of standard deviations
25 | output = self.net(s)
26 | means = output[:self.num_actions]
27 | log_stds = output[self.num_actions:]
28 |
29 | # Return only means when testing
30 | if not self.training:
31 | return means
32 |
33 | # Get the actual standard deviations
34 | stds = torch.exp(log_stds)
35 |
36 | # Sample with re-parameterization trick
37 | a = tdist.Normal(means, stds).rsample()
38 |
39 | return a
40 |
41 | def episode_callback(self):
42 | pass
43 |
44 | def batch_callback(self):
45 | pass
46 |
47 |
--------------------------------------------------------------------------------
/model/blocks/policy/strategies.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import torch
3 | import numpy as np
4 | import cma
5 | from torch import nn
6 | from model.layers import FeedForward
7 | from solver import build_optimizer
8 | from torch.nn.parameter import Parameter
9 | from optimizer import Optimizer
10 |
11 |
12 | class BaseStrategy(ABC, nn.Module):
13 |
14 | def __init__(self, cfg, agent, reinforce_loss_weight=1.0,
15 | min_reinforce_loss_weight=0.0, min_sd=0, soft_relu_beta=0.2,
16 | adam_betas=(0.9, 0.999)):
17 |
18 | super(BaseStrategy, self).__init__()
19 |
20 | self.cfg = cfg
21 | self.method = cfg.MODEL.POLICY.METHOD
22 |
23 | # Get action dimension, horizon length, and batch size
24 | self.action_dim = agent.action_space.sample().shape[0]
25 | self.state_dim = agent.observation_space.sample().shape[0]
26 | self.horizon = cfg.MODEL.POLICY.MAX_HORIZON_STEPS
27 | self.dim = (self.action_dim, self.horizon)
28 | self.batch_size = cfg.MODEL.BATCH_SIZE
29 |
30 | # Set initial values
31 | self.mean = []
32 | self.clamped_action = []
33 | self.sd = []
34 | #self.clamped_sd = []
35 | self.min_sd = min_sd
36 | self.soft_relu_beta = soft_relu_beta
37 | self.best_actions = []
38 | self.sd_threshold = 0.0001
39 |
40 | # Set loss parameters
41 | self.gamma = cfg.MODEL.POLICY.GAMMA
42 | self.min_reinforce_loss_weight = min_reinforce_loss_weight
43 | self.reinforce_loss_weight = reinforce_loss_weight
44 | self.eps = np.finfo(np.float32).eps.item()
45 | self.loss_functions = {"IR": self.IR_loss, "PR": self.PR_loss, "H": self.H_loss,
46 | "SIR": self.SIR_loss, "R": self.R_loss}
47 | self.log_prob = []
48 |
49 | # Set optimizer parameters
50 | self.optimizer = None
51 | self.learning_rate = cfg.SOLVER.BASE_LR
52 | self.adam_betas = adam_betas
53 | self.optimize_functions = {"default": self.standard_optimize, "H": self.H_optimize}
54 |
55 | # Get references in case we want to track episode and step
56 | self.step_idx = agent.get_step_idx()
57 | self.episode_idx = agent.get_episode_idx()
58 |
59 | @abstractmethod
60 | def optimize(self, batch_loss):
61 | pass
62 |
63 | @abstractmethod
64 | def forward(self, state):
65 | pass
66 |
67 | @staticmethod
68 | def clip(x, mean, limit):
69 | xmin = mean - limit
70 | xmax = mean + limit
71 | return torch.max(torch.min(x, xmax), xmin)
72 |
73 | @staticmethod
74 | def clip_negative(x):
75 | #return torch.max(x, torch.zeros(x.shape, dtype=torch.double))
76 | return torch.abs(x)
77 |
78 | @staticmethod
79 | def soft_relu(x, beta=0.1):
80 | return (torch.sqrt(x**2 + beta**2) + x) / 2.0
81 |
82 | # From Tassa 2012 synthesis and stabilization of complex behaviour
83 | def clamp_sd(self, sd):
84 | #return torch.max(sd, 0.001*torch.ones(sd.shape, dtype=torch.double))
85 | #return torch.exp(sd)
86 | #return self.min_sd + self.soft_relu(sd - self.min_sd, self.soft_relu_beta)
87 | return sd
88 |
89 | def clamp_action(self, action):
90 | #return self.soft_relu(action, self.soft_relu_beta)
91 | # return torch.clamp(action, min=0.0, max=1.0) # NB! Kills gradients at borders
92 | #return torch.exp(action)
93 |
94 | if self.cfg.MODEL.POLICY.NETWORK:
95 | #for idx in range(len(action)):
96 | # if idx < 10:
97 | # action[idx] = action[idx]/10
98 | # else:
99 | # action[idx] = torch.exp(action[idx])/20
100 | pass
101 | #else:
102 | # for idx in range(len(action)):
103 | # if idx < 10 or action[idx] >= 0:
104 | # continue
105 | # elif action[idx] < 0:
106 | # action[idx].data -= action[idx].data
107 |
108 | return action
109 |
110 | def get_clamped_sd(self):
111 | return self.clamp_sd(self.sd).detach().numpy()
112 |
113 | def get_clamped_action(self):
114 | return self.clamped_action
115 |
116 | def initialise_mean(self, loc=0.0, sd=0.1, seed=0):
117 | if self.dim is not None:
118 | #return np.zeros(self.dim, dtype=np.float64)
119 | if seed > 0:
120 | np.random.seed(seed)
121 | return np.asarray(np.random.normal(loc, sd, self.dim), dtype=np.float64)
122 |
123 | def initialise_sd(self, factor=1.0):
124 | if self.dim is not None:
125 | return factor*np.ones(self.dim, dtype=np.float64)
126 |
127 | # Initialise mean / sd or get dim from initial values
128 |
129 | def calculate_reinforce_loss(self, batch_loss, stepwise_loss=False):
130 | # Initialise a tensor for returns
131 | returns = torch.empty(batch_loss.shape, dtype=torch.float64)
132 |
133 | # Calculate returns
134 | for episode_idx, episode_loss in enumerate(batch_loss):
135 | R = 0
136 | for step_idx in range(1, len(episode_loss)+1):
137 | R = -episode_loss[-step_idx] + self.gamma*R
138 | returns[episode_idx, -step_idx] = R
139 |
140 | # Remove baseline
141 | advantages = ((returns - returns.mean(dim=0)) / (returns.std(dim=0) + self.eps)).detach()
142 |
143 | # Return REINFORCE loss
144 | #reinforce_loss = torch.mean((-advantages * self.log_prob).sum(dim=1))
145 | #reinforce_loss = torch.sum((-advantages * self.log_prob))
146 | avg_stepwise_loss = (-advantages * self.log_prob).mean(dim=0)
147 | if stepwise_loss:
148 | return avg_stepwise_loss
149 | else:
150 | #return torch.sum(avg_stepwise_loss)
151 | return torch.mean(torch.sum((-advantages * self.log_prob), dim=1))
152 |
153 | def calculate_objective_loss(self, batch_loss, stepwise_loss=False):
154 | #batch_loss = (batch_loss.mean(dim=0) - batch_loss) / (batch_loss.std(dim=0) + self.eps)
155 | #batch_loss = batch_loss.mean() - batch_loss
156 |
157 | # Initialise a tensor for returns
158 | #ret = torch.empty(batch_loss.shape, dtype=torch.float64)
159 | ## Calculate returns
160 | #for episode_idx, episode_loss in enumerate(batch_loss):
161 | # R = torch.zeros(1, dtype=torch.float64)
162 | # for step_idx in range(1, len(episode_loss)+1):
163 | # R = -episode_loss[-step_idx] + self.gamma*R.detach()
164 | # ret[episode_idx, -step_idx] = R
165 | #avg_stepwise_loss = torch.mean(-ret, dim=0)
166 | #advantages = ret - ret.mean(dim=0)
167 |
168 | #means = torch.mean(batch_loss, dim=0)
169 | #idxs = (batch_loss > torch.mean(batch_loss)).detach()
170 | #idxs = sums < np.mean(sums)
171 |
172 | #advantages = torch.zeros((1, batch_loss.shape[1]))
173 | #advantages = []
174 | #idxs = batch_loss < torch.mean(batch_loss)
175 | #for i in range(batch_loss.shape[1]):
176 | # avg = torch.mean(batch_loss[:,i]).detach().numpy()
177 | # idxs = batch_loss[:,i].detach().numpy() < avg
178 | # advantages[0,i] = torch.mean(batch_loss[idxs,i], dim=0)
179 | # if torch.any(idxs[:,i]):
180 | # advantages.append(torch.mean(batch_loss[idxs[:,i],i], dim=0))
181 |
182 | #batch_loss = (batch_loss - batch_loss.mean(dim=1).detach()) / (batch_loss.std(dim=1).detach() + self.eps)
183 | #batch_loss = (batch_loss - batch_loss.mean(dim=0).detach()) / (batch_loss.std(dim=0).detach() + self.eps)
184 | #batch_loss = (batch_loss - batch_loss.mean(dim=0).detach())
185 |
186 | #fvals = torch.empty(batch_loss.shape[0])
187 | #for episode_idx, episode_loss in enumerate(batch_loss):
188 | # nans = torch.isnan(episode_loss)
189 | # fvals[episode_idx] = torch.sum(episode_loss[nans == False])
190 |
191 | nans = torch.isnan(batch_loss)
192 | batch_loss[nans] = 0
193 |
194 | if stepwise_loss:
195 | return torch.mean(batch_loss, dim=0)
196 | else:
197 | #return torch.sum(batch_loss[0][100])
198 | #return torch.sum(advantages[advantages<0])
199 | #idxs = torch.isnan(batch_loss) == False
200 | #return torch.sum(batch_loss[idxs])
201 | #return torch.mean(fvals)
202 | #return torch.mean(torch.sum(batch_loss, dim=1))
203 |
204 | return torch.sum(torch.mean(batch_loss, dim=0))
205 | #return batch_loss[0][-1]
206 | #return torch.sum(torch.stack(advantages))
207 |
208 | # Remove baseline
209 | #advantages = ((ret - ret.mean(dim=0)) / (ret.std(dim=0) + self.eps))
210 | #advantages = ret
211 |
212 | # Return REINFORCE loss
213 | #avg_stepwise_loss = -advantages.mean(dim=0)
214 | #if stepwise_loss:
215 | # return avg_stepwise_loss
216 | #else:
217 | # return torch.sum(avg_stepwise_loss)
218 |
219 | def R_loss(self, batch_loss):
220 |
221 | # Get objective loss
222 | objective_loss = self.calculate_objective_loss(batch_loss)
223 |
224 | return objective_loss, {"objective_loss": float(torch.mean(torch.sum(batch_loss, dim=1)).detach().numpy()),
225 | "total_loss": float(objective_loss.detach().numpy())}
226 | #"total_loss": batch_loss.detach().numpy()}
227 |
228 |
229 | def PR_loss(self, batch_loss):
230 |
231 | # Get REINFORCE loss
232 | reinforce_loss = self.calculate_reinforce_loss(batch_loss)
233 |
234 | # Get objective loss
235 | objective_loss = self.calculate_objective_loss(batch_loss)
236 |
237 | # Return a sum of the objective loss and REINFORCE loss
238 | loss = objective_loss + self.reinforce_loss_weight*reinforce_loss
239 |
240 | return loss, {"objective_loss": float(torch.mean(torch.sum(batch_loss, dim=1)).detach().numpy()),
241 | "reinforce_loss": float(reinforce_loss.detach().numpy()),
242 | "total_loss": float(loss.detach().numpy())}
243 |
244 | def IR_loss(self, batch_loss):
245 |
246 | # Get REINFORCE loss
247 | reinforce_loss = self.calculate_reinforce_loss(batch_loss)
248 |
249 | # Get objective loss
250 | objective_loss = self.calculate_objective_loss(batch_loss)
251 |
252 | # Return an interpolated mix of the objective loss and REINFORCE loss
253 | clamped_sd = self.clamp_sd(self.sd)
254 | mix_factor = 0.5*((clamped_sd - self.min_sd) / (self.initial_clamped_sd - self.min_sd)).mean().detach().numpy()
255 | mix_factor = self.min_reinforce_loss_weight + (1.0 - self.min_reinforce_loss_weight)*mix_factor
256 | mix_factor = np.maximum(np.minimum(mix_factor, 1), 0)
257 |
258 | loss = (1.0 - mix_factor) * objective_loss + mix_factor * reinforce_loss
259 | return loss, {"objective_loss": [objective_loss.detach().numpy()],
260 | "reinforce_loss": [reinforce_loss.detach().numpy()],
261 | "total_loss": [loss.detach().numpy()]}
262 |
263 | def SIR_loss(self, batch_loss):
264 |
265 | # Get REINFORCE loss
266 | reinforce_loss = self.calculate_reinforce_loss(batch_loss, stepwise_loss=True)
267 |
268 | # Get objective loss
269 | objective_loss = self.calculate_objective_loss(batch_loss, stepwise_loss=True)
270 |
271 | # Return an interpolated mix of the objective loss and REINFORCE loss
272 | clamped_sd = self.clamp_sd(self.sd)
273 | mix_factor = 0.5*((clamped_sd - self.min_sd) / (self.initial_clamped_sd - self.min_sd)).detach()
274 | mix_factor = self.min_reinforce_loss_weight + (1.0 - self.min_reinforce_loss_weight)*mix_factor
275 | mix_factor = torch.max(torch.min(mix_factor, torch.ones_like(mix_factor)), torch.zeros_like(mix_factor))
276 |
277 | loss = ((1.0 - mix_factor) * objective_loss + mix_factor * reinforce_loss).sum()
278 | return loss, {"objective_loss": [objective_loss.detach().numpy()],
279 | "reinforce_loss": [reinforce_loss.detach().numpy()],
280 | "total_loss": [loss.detach().numpy()]}
281 |
282 | def H_loss(self, batch_loss):
283 |
284 | # Get REINFORCE loss
285 | reinforce_loss = self.calculate_reinforce_loss(batch_loss)
286 |
287 | # Get objective loss
288 | objective_loss = torch.sum(batch_loss, dim=1)
289 |
290 | # Get idx of best sampled action values
291 | best_idx = objective_loss.argmin()
292 |
293 | # If some other sample than the mean was best, update the mean to it
294 | if best_idx != 0:
295 | self.best_actions = self.clamped_action[:, :, best_idx].copy()
296 |
297 | return reinforce_loss, objective_loss[best_idx], \
298 | {"objective_loss": [objective_loss.detach().numpy()],
299 | "reinforce_loss": [reinforce_loss.detach().numpy()],
300 | "total_loss": [reinforce_loss.detach().numpy()]}
301 |
302 | def get_named_parameters(self, name):
303 | params = {}
304 | for key, val in self.named_parameters():
305 | if key.split(".")[0] == name:
306 | params[key] = val
307 |
308 | # Return a generator that behaves like self.named_parameters()
309 | return ((x, params[x]) for x in params)
310 |
311 | def standard_optimize(self, batch_loss):
312 |
313 | # Get appropriate loss
314 | loss, stats = self.loss_functions[self.method](batch_loss)
315 |
316 | self.optimizer.zero_grad()
317 | loss.backward()
318 | #nn.utils.clip_grad_norm_(self.parameters(), 0.1)
319 | if self.cfg.SOLVER.OPTIMIZER == "sgd":
320 | nn.utils.clip_grad_value_(self.parameters(), 1)
321 |
322 | #total_norm_sqr = 0
323 | #total_norm_sqr += self.mean._layers["linear_layer_0"].weight.grad.norm() ** 2
324 | #total_norm_sqr += self.mean._layers["linear_layer_0"].bias.grad.norm() ** 2
325 | #total_norm_sqr += self.mean._layers["linear_layer_1"].weight.grad.norm() ** 2
326 | #total_norm_sqr += self.mean._layers["linear_layer_1"].bias.grad.norm() ** 2
327 | #total_norm_sqr += self.mean._layers["linear_layer_2"].weight.grad.norm() ** 2
328 | #total_norm_sqr += self.mean._layers["linear_layer_2"].bias.grad.norm() ** 2
329 |
330 | #gradient_clip = 0.01
331 | #scale = min(1.0, gradient_clip / (total_norm_sqr ** 0.5 + 1e-4))
332 | #print("lr: ", scale * self.cfg.SOLVER.BASE_LR)
333 |
334 | #if total_norm_sqr ** 0.5 > gradient_clip:
335 |
336 | # for param in self.parameters():
337 | # param.data.sub_(scale * self.cfg.SOLVER.BASE_LR)
338 | # if param.grad is not None:
339 | # param.grad = (param.grad * gradient_clip) / (total_norm_sqr ** 0.5)
340 | #pass
341 | self.optimizer.step()
342 |
343 | #print("ll0 weight: {} {}".format(self.mean._layers["linear_layer_0"].weight_std.min(),
344 | # self.mean._layers["linear_layer_0"].weight_std.max()))
345 | #print("ll1 weight: {} {}".format(self.mean._layers["linear_layer_1"].weight_std.min(),
346 | # self.mean._layers["linear_layer_1"].weight_std.max()))
347 | #print("ll2 weight: {} {}".format(self.mean._layers["linear_layer_2"].weight_std.min(),
348 | # self.mean._layers["linear_layer_2"].weight_std.max()))
349 | #print("ll0 bias: {} {}".format(self.mean._layers["linear_layer_0"].bias_std.min(),
350 | # self.mean._layers["linear_layer_0"].bias_std.max()))
351 | #print("ll1 bias: {} {}".format(self.mean._layers["linear_layer_1"].bias_std.min(),
352 | # self.mean._layers["linear_layer_1"].bias_std.max()))
353 | #print("ll2 bias: {} {}".format(self.mean._layers["linear_layer_2"].bias_std.min(),
354 | # self.mean._layers["linear_layer_2"].bias_std.max()))
355 | #print("")
356 |
357 | # Make sure sd is not negative
358 | idxs = self.sd < self.sd_threshold
359 | self.sd.data[idxs] = self.sd_threshold
360 |
361 | # Empty log probs
362 | self.log_prob = torch.empty(self.batch_size, self.horizon, dtype=torch.float64)
363 |
364 | return stats
365 |
366 | def H_optimize(self, batch_loss):
367 |
368 | # Get appropriate loss
369 | loss, best_objective_loss, stats = self.loss_functions[self.method](batch_loss)
370 |
371 | # Adapt the mean
372 | self.optimizer["mean"].zero_grad()
373 | best_objective_loss.backward(retain_graph=True)
374 | self.optimizer["mean"].step()
375 |
376 | # Adapt the sd
377 | self.optimizer["sd"].zero_grad()
378 | loss.backward()
379 | self.optimizer["sd"].step()
380 |
381 | # Empty log probs
382 | self.log_prob = torch.empty(self.batch_size, self.horizon, dtype=torch.float64)
383 |
384 | return stats
385 |
386 |
387 | class VariationalOptimization(BaseStrategy):
388 |
389 | def __init__(self, *args, **kwargs):
390 | super(VariationalOptimization, self).__init__(*args, **kwargs)
391 |
392 | # Initialise mean and sd
393 | if self.cfg.MODEL.POLICY.NETWORK:
394 | # Set a feedforward network for means
395 | self.mean = FeedForward(
396 | self.state_dim,
397 | self.cfg.MODEL.POLICY.LAYERS,
398 | self.action_dim
399 | )
400 | else:
401 | # Set tensors for means
402 | self.mean = Parameter(torch.from_numpy(
403 | self.initialise_mean(self.cfg.MODEL.POLICY.INITIAL_ACTION_MEAN, self.cfg.MODEL.POLICY.INITIAL_ACTION_SD)
404 | ))
405 | self.register_parameter("mean", self.mean)
406 |
407 | # Set tensors for standard deviations
408 | self.sd = Parameter(torch.from_numpy(self.initialise_sd(self.cfg.MODEL.POLICY.INITIAL_SD)))
409 | self.initial_clamped_sd = self.clamp_sd(self.sd.detach())
410 | self.register_parameter("sd", self.sd)
411 | #self.clamped_sd = np.zeros((self.action_dim, self.horizon), dtype=np.float64)
412 | self.clamped_action = np.zeros((self.action_dim, self.horizon, self.batch_size), dtype=np.float64)
413 |
414 | # Initialise optimizer
415 | if self.method == "H":
416 | # Separate mean and sd optimizers (not sure if actually necessary)
417 | self.optimizer = {"mean": build_optimizer(self.cfg, self.get_named_parameters("mean")),
418 | "sd": build_optimizer(self.cfg, self.get_named_parameters("sd"))}
419 | self.best_actions = np.empty(self.sd.shape)
420 | self.best_actions.fill(np.nan)
421 | else:
422 | self.optimizer = build_optimizer(self.cfg, self.named_parameters())
423 |
424 | # We need log probabilities for calculating REINFORCE loss
425 | self.log_prob = torch.empty(self.batch_size, self.horizon, dtype=torch.float64)
426 |
427 | def forward(self, state):
428 |
429 | # Get clamped sd
430 | clamped_sd = self.clamp_sd(self.sd[:, self.step_idx])
431 |
432 | # Get mean of action value
433 | if self.cfg.MODEL.POLICY.NETWORK:
434 | mean = self.mean(state).double()
435 | else:
436 | mean = self.mean[:, self.step_idx]
437 |
438 | if not self.training:
439 | return self.clamp_action(mean)
440 |
441 | # Get normal distribution
442 | dist = torch.distributions.Normal(mean, clamped_sd)
443 |
444 | # Sample action
445 | if self.method == "H" and self.episode_idx == 0:
446 | if np.all(np.isnan(self.best_actions[:, self.step_idx])):
447 | action = mean
448 | else:
449 | action = torch.from_numpy(self.best_actions[:, self.step_idx])
450 | elif self.batch_size > 1:
451 | action = dist.rsample()
452 | else:
453 | action = mean
454 |
455 | # Clip action
456 | action = self.clamp_action(action)
457 | #action = self.clip(action, mean, 2.0*clamped_sd)
458 |
459 | self.clamped_action[:, self.step_idx, self.episode_idx-1] = action.detach().numpy()
460 |
461 | # Get log prob for REINFORCE loss calculations
462 | self.log_prob[self.episode_idx-1, self.step_idx] = dist.log_prob(action.detach()).sum()
463 |
464 | return action
465 |
466 | def optimize(self, batch_loss):
467 | return self.optimize_functions.get(self.method, self.standard_optimize)(batch_loss)
468 |
469 |
470 | class CMAES(BaseStrategy):
471 |
472 | def __init__(self, *args, **kwargs):
473 | super(CMAES, self).__init__(*args, **kwargs)
474 |
475 | # Make sure batch size is larger than one
476 | assert self.batch_size > 1, "Batch size must be >1 for CMA-ES"
477 |
478 | # Set up CMA-ES options
479 | cmaes_options = {"popsize": self.batch_size, "CMA_diagonal": True}
480 |
481 | # Initialise mean and flatten it
482 | self.mean = self.initialise_mean()
483 | self.mean = np.reshape(self.mean, (self.mean.size,))
484 |
485 | # We want to store original (list of batches) actions for tell
486 | self.orig_actions = []
487 |
488 | # Initialise CMA-ES
489 | self.optimizer = cma.CMAEvolutionStrategy(self.mean, self.cfg.MODEL.POLICY.INITIAL_SD, inopts=cmaes_options)
490 | self.actions = []
491 |
492 | def forward(self, state):
493 |
494 | # If we've hit the end of minibatch we need to sample more actions
495 | if self.step_idx == 0 and self.episode_idx - 1 == 0 and self.training:
496 | self.orig_actions = self.optimizer.ask()
497 | self.actions = torch.empty(self.action_dim, self.horizon, self.batch_size, dtype=torch.float64)
498 | for ep_idx, ep_actions in enumerate(self.orig_actions):
499 | self.actions[:, :, ep_idx] = torch.from_numpy(np.reshape(ep_actions, (self.action_dim, self.horizon)))
500 |
501 | # Get action
502 | action = self.actions[:, self.step_idx, self.episode_idx-1]
503 |
504 | return action
505 |
506 | def optimize(self, batch_loss):
507 | loss = batch_loss.sum(axis=1)
508 | self.optimizer.tell(self.orig_actions, loss.detach().numpy())
509 | return {"objective_loss": float(loss.detach().numpy().mean()), "total_loss": float(loss.detach().numpy().mean())}
510 |
511 | def get_clamped_sd(self):
512 | return np.asarray(self.sd)
513 |
514 | def get_clamped_action(self):
515 | return self.actions.detach().numpy()
516 |
517 |
518 | class Perttu(BaseStrategy):
519 | def __init__(self, *args, **kwargs):
520 | super(Perttu, self).__init__(*args, **kwargs)
521 |
522 | # Initialise optimizer object
523 | self.optimizer = \
524 | Optimizer(mode=self.method,
525 | initialMean=np.random.normal(self.cfg.MODEL.POLICY.INITIAL_ACTION_MEAN,
526 | self.cfg.MODEL.POLICY.INITIAL_ACTION_SD,
527 | (self.action_dim, self.horizon)),
528 | initialSd=self.cfg.MODEL.POLICY.INITIAL_SD*np.ones((self.action_dim, self.horizon)),
529 | #initialSd=self.cfg.MODEL.POLICY.INITIAL_SD*np.ones((1, 1)),
530 | learningRate=self.cfg.SOLVER.BASE_LR,
531 | adamBetas=(0.9, 0.99),
532 | minReinforceLossWeight=0.0,
533 | nBatch=self.cfg.MODEL.BATCH_SIZE,
534 | solver=self.cfg.SOLVER.OPTIMIZER)
535 |
536 | def forward(self, state):
537 |
538 | # If we've hit the end of minibatch we need to sample more actions
539 | if self.training:
540 | if self.step_idx == 0 and self.episode_idx-1 == 0:
541 | samples = self.optimizer.ask()
542 | self.actions = torch.empty(self.action_dim, self.horizon, self.batch_size)
543 | for ep_idx, ep_actions in enumerate(samples):
544 | self.actions[:, :, ep_idx] = torch.reshape(ep_actions, (self.action_dim, self.horizon))
545 |
546 | # Get action
547 | action = self.actions[:, self.step_idx, self.episode_idx-1]
548 |
549 | else:
550 | if self.method != "CMA-ES":
551 | if self.step_idx == 0:
552 |                     samples = self.optimizer.ask(testing=not self.training)
553 | self.actions = torch.empty(self.action_dim, self.horizon, self.batch_size)
554 | for ep_idx, ep_actions in enumerate(samples):
555 | self.actions[:, :, ep_idx] = torch.reshape(ep_actions, (self.action_dim, self.horizon))
556 |
557 | # Get action
558 | action = self.actions[:, self.step_idx, 0]
559 |
560 | return action.double()
561 |
562 | def optimize(self, batch_loss):
563 | loss, meanFval = self.optimizer.tell(batch_loss)
564 | return {"objective_loss": float(meanFval), "total_loss": float(loss)}
565 |
566 | def get_clamped_sd(self):
567 | return self.optimizer.getClampedSd().reshape([self.optimizer.original_dim, self.optimizer.steps]).detach().numpy()
568 |
569 | def get_clamped_action(self):
570 | return self.actions.detach().numpy()
571 |
--------------------------------------------------------------------------------
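For orientation, the CMAES strategy above drives the cma package through its ask/tell interface: ask() draws a population of flattened action sequences, each one is rolled out to obtain a scalar episode loss, and tell() updates the search distribution. A minimal standalone sketch of that loop, with a toy quadratic objective standing in for the simulated episode loss (the sizes are arbitrary):

    import numpy as np
    import cma

    action_dim, horizon, popsize = 2, 5, 8
    mean0 = np.zeros(action_dim * horizon)      # flattened initial action sequence

    es = cma.CMAEvolutionStrategy(mean0, 0.5, inopts={"popsize": popsize, "CMA_diagonal": True})
    for _ in range(50):
        candidates = es.ask()                   # popsize candidate action sequences
        losses = [float(np.sum(c ** 2)) for c in candidates]   # stand-in for the episode loss
        es.tell(candidates, losses)             # update the search distribution
    best = es.result.xbest.reshape(action_dim, horizon)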
/model/blocks/policy/trajopt.py:
--------------------------------------------------------------------------------
1 | from .base import BasePolicy
2 | import torch
3 | from torch.nn.parameter import Parameter
4 | from .strategies import *
5 | import numpy as np
6 |
7 |
8 | class TrajOpt(BasePolicy):
9 | """Trajectory Optimization Network"""
10 | def __init__(self, cfg, agent):
11 | super(TrajOpt, self).__init__(cfg, agent)
12 |
13 | # Parametrize optimization actions
14 | action_size = self.agent.action_space.sample().shape[0]
15 | horizon = cfg.MODEL.POLICY.MAX_HORIZON_STEPS
16 |
17 | #self.action_mean = Parameter(torch.zeros(horizon, action_size))
18 | #nn.init.normal_(self.action_mean, mean=0.0, std=1.0)
19 | #self.register_parameter("action_mean", self.action_mean)
20 |
21 | # Get standard deviations as well when doing variational optimization
22 | #if policy_cfg.VARIATIONAL:
23 | # self.action_std = Parameter(torch.empty(horizon, action_size).fill_(-2))
24 | # self.register_parameter("action_std", self.action_std)
25 |
26 | self.strategy = PRVO(dim=np.array([action_size, horizon]), nbatch=cfg.MODEL.POLICY.BATCH_SIZE,
27 | gamma=cfg.MODEL.POLICY.GAMMA, learning_rate=cfg.SOLVER.BASE_LR)
28 |
29 |         # Reset step and episode indices to zero
30 | self.step_index = 0
31 | self.episode_index = 0
32 |
33 | def forward(self, s):
34 | #if self.policy_cfg.VARIATIONAL:
35 | # action = torch.distributions.Normal(self.action_mean[self.index], torch.exp(self.action_std[self.index])).rsample()
36 | #else:
37 | # action = self.action_mean[self.index]
38 |
39 | # Sample a new set of actions when required
40 | if self.step_index == 0:
41 | self.actions = self.strategy.sample(self.training)
42 |
43 | action = self.actions[:, self.step_index]
44 | self.step_index += 1
45 | return action
46 |
47 | def episode_callback(self):
48 | self.step_index = 0
49 |
50 | def optimize(self, batch_loss):
51 | return self.strategy.optimize(batch_loss)
52 |
53 |
--------------------------------------------------------------------------------
/model/blocks/policy/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | from .zmus import ZMUSWrapper
2 |
3 | __all__ = ["ZMUSWrapper"]
4 |
--------------------------------------------------------------------------------
/model/blocks/policy/wrappers/zmus.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.parameter import Parameter
4 |
5 |
6 | class ZMUSWrapper(nn.Module):
7 | """Zero-mean Unit-STD States
8 | see: https://stats.stackexchange.com/questions/43159/how-to-calculate-pooled-variance-of-two-groups-given-known-group-variances-mean
9 | """
10 |
11 | def __init__(self, policy_net, eps=1e-6):
12 | super(ZMUSWrapper, self).__init__()
13 |
14 | self.eps = eps
15 | self.policy_net = policy_net
16 | self.policy_cfg = self.policy_net.policy_cfg
17 |
18 | # Parameters
19 | state_size = self.policy_net.agent.observation_space.shape
20 | self.state_mean = Parameter(torch.Tensor(state_size, ))
21 | self.state_variance = Parameter(torch.Tensor(state_size, ))
22 | self.state_mean.requires_grad = False
23 | self.state_variance.requires_grad = False
24 |
25 |         # cache
26 | self.size = 0
27 | self.ep_states_data = []
28 |
29 | self.first_pass = True
30 |
31 | def _get_state_mean(self):
32 | return self.state_mean.detach()
33 |
34 | def _get_state_variance(self):
35 | return self.state_variance.detach()
36 |
37 | def forward(self, s):
38 | self.size += 1
39 | self.ep_states_data.append(s)
40 | if not self.first_pass:
41 | s = (s - self._get_state_mean()) / \
42 | (torch.sqrt(self._get_state_variance()) + self.eps)
43 | return self.policy_net(s)
44 |
45 | def episode_callback(self):
46 | ep_states_tensor = torch.stack(self.ep_states_data)
47 | new_data_mean = torch.mean(ep_states_tensor, dim=0)
48 | new_data_var = torch.var(ep_states_tensor, dim=0)
49 | if self.first_pass:
50 | self.state_mean.data = new_data_mean
51 | self.state_variance.data = new_data_var
52 | self.first_pass = False
53 | else:
54 | n = len(self.ep_states_data)
55 | mean = self._get_state_mean()
56 | var = self._get_state_variance()
57 | new_data_mean_sq = torch.mul(new_data_mean, new_data_mean)
58 | size = min(self.policy_cfg.FORGET_COUNT_OBS_SCALER, self.size)
59 | new_mean = ((mean * size) + (new_data_mean * n)) / (size + n)
60 | new_var = (((size * (var + torch.mul(mean, mean))) + (n * (new_data_var + new_data_mean_sq))) /
61 | (size + n) - torch.mul(new_mean, new_mean))
62 | self.state_mean.data = new_mean
63 | self.state_variance.data = torch.clamp(new_var, 0.) # occasionally goes negative, clip
64 | self.size += n
65 |
66 | def batch_callback(self):
67 | pass
68 |
--------------------------------------------------------------------------------
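The running-statistics update in ZMUSWrapper.episode_callback is the pooled mean/variance identity referenced in the docstring. A quick self-contained NumPy check (illustrative data only) that the identity reproduces the statistics of the concatenated data; note it is exact for the population variance (ddof=0), whereas torch.var defaults to the unbiased estimate:

    import numpy as np

    old = np.random.randn(1000, 4)    # previously accumulated states (stand-in data)
    new = np.random.randn(50, 4)      # states from one new episode

    m, n = len(old), len(new)
    mean_old, var_old = old.mean(0), old.var(0)
    mean_new, var_new = new.mean(0), new.var(0)

    pooled_mean = (m * mean_old + n * mean_new) / (m + n)
    pooled_var = ((m * (var_old + mean_old ** 2) + n * (var_new + mean_new ** 2)) / (m + n)
                  - pooled_mean ** 2)

    both = np.concatenate([old, new])
    assert np.allclose(pooled_mean, both.mean(0))
    assert np.allclose(pooled_var, both.var(0))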
/model/blocks/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_soft_lower_bound_fn
2 |
3 | __all__ = ["build_soft_lower_bound_fn"]
4 |
--------------------------------------------------------------------------------
/model/blocks/utils/build.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from .functions import soft_lower_bound
3 |
4 |
5 | def build_soft_lower_bound_fn(policy_cfg):
6 | soft_lower_bound_fn = functools.partial(soft_lower_bound,
7 | bound=policy_cfg.SOFT_LOWER_STD_BOUND,
8 | threshold=policy_cfg.SOFT_LOWER_STD_THRESHOLD,
9 | )
10 | return soft_lower_bound_fn
11 |
--------------------------------------------------------------------------------
/model/blocks/utils/functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def soft_lower_bound(x, bound, threshold):
5 | range = threshold - bound
6 | return torch.max(x, range * torch.tanh((x - threshold) / range) + threshold)
7 |
--------------------------------------------------------------------------------
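soft_lower_bound acts as the identity for inputs comfortably above threshold and saturates smoothly towards bound below it, which is how the policy keeps its standard deviation from collapsing to zero. A small illustrative call using the default bound/threshold values from the config (import path assumed from the package layout):

    import torch
    from model.blocks.utils.functions import soft_lower_bound

    x = torch.tensor([-10.0, -1.0, 0.0, 0.05, 0.2, 1.0])
    y = soft_lower_bound(x, bound=1e-4, threshold=1e-1)
    # Roughly [1e-4, 1e-4, 0.024, 0.054, 0.2, 1.0]: values above the threshold pass
    # through unchanged, values below it are squashed towards the lower bound.
    print(y)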
/model/build.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from model import archs
3 |
4 |
5 | def build_model(cfg, agent):
6 | model_factory = getattr(archs, cfg.MODEL.META_ARCHITECTURE)
7 | model = model_factory(cfg, agent)
8 | # if cfg.MODEL.WEIGHTS != "":
9 | # model.load_state_dict(torch.load(cfg.MODEL.WEIGHTS))
10 | return model
11 |
--------------------------------------------------------------------------------
/model/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .defaults import get_cfg_defaults
2 |
3 | __all__ = ["get_cfg_defaults"]
4 |
--------------------------------------------------------------------------------
/model/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | # ---------------------------------------------------------------------------- #
4 | # Define Config Node
5 | # ---------------------------------------------------------------------------- #
6 | _C = CN()
7 |
8 | # ---------------------------------------------------------------------------- #
9 | # Model Configs
10 | # ---------------------------------------------------------------------------- #
11 | _C.MODEL = CN()
12 | _C.MODEL.META_ARCHITECTURE = 'Basic'
13 | _C.MODEL.DEVICE = "cpu"
14 | _C.MODEL.WEIGHTS = "" # should be a path to .pth pytorch state dict file
15 | _C.MODEL.EPOCHS = 10000
16 | _C.MODEL.BATCH_SIZE = 8
17 |
18 | # ---------------------------------------------------------------------------- #
19 | # __Policy Net Configs
20 | # ---------------------------------------------------------------------------- #
21 | _C.MODEL.POLICY = CN()
22 |
23 | _C.MODEL.POLICY = CN()
24 | _C.MODEL.POLICY.ARCH = "StochasticPolicy"
25 | _C.MODEL.POLICY.MAX_HORIZON_STEPS = 100
26 | _C.MODEL.POLICY.LAYERS = [64, 64, 32, 8] # a list of hidden layer sizes for output fc. [] means no hidden
27 | #_C.MODEL.POLICY.NORM_LAYERS = [0, 1, 2] # should be a list of layer indices, example [0, 1, ...]
28 | _C.MODEL.POLICY.STD_SCALER = 1e-1
29 | _C.MODEL.POLICY.SOFT_LOWER_STD_BOUND = 1e-4
30 | _C.MODEL.POLICY.SOFT_LOWER_STD_THRESHOLD = 1e-1
31 | _C.MODEL.POLICY.OBS_SCALER = False
32 | _C.MODEL.POLICY.FORGET_COUNT_OBS_SCALER = 5000
33 | _C.MODEL.POLICY.GAMMA = 0.99
34 | _C.MODEL.POLICY.METHOD = "None"
35 | _C.MODEL.POLICY.INITIAL_LOG_SD = 0.0
36 | _C.MODEL.POLICY.INITIAL_SD = 0.0
37 | _C.MODEL.POLICY.INITIAL_ACTION_MEAN = 0.0
38 | _C.MODEL.POLICY.INITIAL_ACTION_SD = 0.1
39 | _C.MODEL.POLICY.GRAD_WEIGHTS = 'average'
40 | _C.MODEL.POLICY.NETWORK = False
41 | _C.MODEL.NSTEPS_FOR_BACKWARD = 1
42 | _C.MODEL.FRAME_SKIP = 1
43 | _C.MODEL.TIMESTEP = 0.0
44 | _C.MODEL.RANDOM_SEED = 0
45 |
46 | # ---------------------------------------------------------------------------- #
47 | # MuJoCo Configs
48 | # ---------------------------------------------------------------------------- #
49 | _C.MUJOCO = CN()
50 | _C.MUJOCO.ENV = 'InvertedPendulumEnv'
51 | _C.MUJOCO.ASSETS_PATH = "./mujoco/assets/"
52 | _C.MUJOCO.REWARD_SCALE = 1
53 | _C.MUJOCO.CLIP_ACTIONS = True
54 | _C.MUJOCO.POOL_SIZE = CN()
55 |
56 | # ---------------------------------------------------------------------------- #
57 | # Experience Replay
58 | # ---------------------------------------------------------------------------- #
59 | _C.EXPERIENCE_REPLAY = CN()
60 | _C.EXPERIENCE_REPLAY.SIZE = 2 ** 15
61 | _C.EXPERIENCE_REPLAY.SHUFFLE = True
62 | _C.EXPERIENCE_REPLAY.ENV_INIT_STATE_NUM = 2 ** 15 * 3 / 4
63 |
64 | # ---------------------------------------------------------------------------- #
65 | # Solver Configs
66 | # ---------------------------------------------------------------------------- #
67 | _C.SOLVER = CN()
68 |
69 | _C.SOLVER.OPTIMIZER = 'adam'
70 | _C.SOLVER.BASE_LR = 0.001
71 | _C.SOLVER.STD_LR_FACTOR = 0.001
72 | _C.SOLVER.BIAS_LR_FACTOR = 2
73 | _C.SOLVER.WEIGHT_DECAY_SD = 0.0
74 |
75 | _C.SOLVER.MOMENTUM = 0.9
76 |
77 | _C.SOLVER.WEIGHT_DECAY = 0.0005
78 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0
79 | _C.SOLVER.ADAM_BETAS = (0.9, 0.999)
80 |
81 | # ---------------------------------------------------------------------------- #
82 | # Output Configs
83 | # ---------------------------------------------------------------------------- #
84 | _C.OUTPUT = CN()
85 | _C.OUTPUT.DIR = './output'
86 | _C.OUTPUT.NAME = 'timestamp'
87 |
88 | # ---------------------------------------------------------------------------- #
89 | # Log Configs
90 | # ---------------------------------------------------------------------------- #
91 | _C.LOG = CN()
92 | _C.LOG.PERIOD = 1
93 | _C.LOG.PLOT = CN()
94 | _C.LOG.PLOT.ENABLED = True
95 | _C.LOG.PLOT.DISPLAY_PORT = 8097
96 | _C.LOG.PLOT.ITER_PERIOD = 1 # effective plotting step is _C.LOG.PERIOD * LOG.PLOT.ITER_PERIOD
97 | _C.LOG.TESTING = CN()
98 | _C.LOG.TESTING.ENABLED = True
99 | _C.LOG.TESTING.ITER_PERIOD = 1
100 | _C.LOG.TESTING.RECORD_VIDEO = False
101 | _C.LOG.TESTING.COUNT_PER_ITER = 1
102 | _C.LOG.CHECKPOINT_PERIOD = 25000
103 |
104 |
105 | def get_cfg_defaults():
106 | """Get a yacs CfgNode object with default values for my_project."""
107 | # Return a clone so that the defaults will not be altered
108 | # This is for the "local variable" use pattern
109 | return _C.clone()
110 |
--------------------------------------------------------------------------------
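These defaults are meant to be overridden per experiment. A sketch of the usual yacs pattern (main.py's exact plumbing may differ, and the override keys below are only examples), merging one of the YAML files from configs/ and then a flat list of key/value pairs:

    from model.config import get_cfg_defaults

    cfg = get_cfg_defaults()
    cfg.merge_from_file("configs/inverted_pendulum.yaml")      # experiment config from this repo
    cfg.merge_from_list(["MODEL.BATCH_SIZE", 16, "SOLVER.BASE_LR", 0.0005])
    cfg.freeze()
    print(cfg.MUJOCO.ENV, cfg.MODEL.BATCH_SIZE, cfg.SOLVER.BASE_LR)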
/model/engine/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/model/engine/__init__.py
--------------------------------------------------------------------------------
/model/engine/dynamics_model_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 |
5 | from model.engine.tester import do_testing
6 | import utils.logger as lg
7 | from utils.visdom_plots import VisdomLogger
8 | from mujoco import build_agent
9 | from model import build_model
10 | from model.blocks.policy.dynamics import DynamicsModel
11 | import torchviz
12 |
13 |
14 | def do_training(
15 | cfg,
16 | logger,
17 | output_results_dir,
18 | output_rec_dir,
19 | output_weights_dir
20 | ):
21 | # Build the agent
22 | agent = build_agent(cfg)
23 |
24 | # Build a forward dynamics model
25 | dynamics_model = DynamicsModel(agent)
26 |
27 | # Set mode to training (aside from policy output, matters for Dropout, BatchNorm, etc.)
28 | dynamics_model.train()
29 |
30 | # Set up visdom
31 | if cfg.LOG.PLOT.ENABLED:
32 | visdom = VisdomLogger(cfg.LOG.PLOT.DISPLAY_PORT)
33 | visdom.register_keys(['total_loss', 'average_sd', 'average_action', "reinforce_loss",
34 | "objective_loss", "sd", "action_grad", "sd_grad", "actions"])
35 |
36 | # wrap screen recorder if testing mode is on
37 | if cfg.LOG.TESTING.ENABLED:
38 | if cfg.LOG.PLOT.ENABLED:
39 | visdom.register_keys(['test_reward'])
40 |
41 | # Collect losses here
42 | output = {"epoch": [], "objective_loss": []}
43 |
44 | # Start training
45 | for epoch_idx in range(cfg.MODEL.EPOCHS):
46 | batch_loss = torch.empty(cfg.MODEL.BATCH_SIZE, cfg.MODEL.POLICY.MAX_HORIZON_STEPS, dtype=torch.float64)
47 | batch_loss.fill_(np.nan)
48 |
49 | for episode_idx in range(cfg.MODEL.BATCH_SIZE):
50 |
51 | # Generate "random walk" set of actions (separately for each dimension)
52 | action = torch.zeros(agent.action_space.shape, dtype=torch.float64)
53 | #actions = np.zeros((agent.action_space.shape[0], cfg.MODEL.POLICY.MAX_HORIZON_STEPS))
54 | actions = []
55 |
56 | initial_state = torch.Tensor(agent.reset())
57 | predicted_states = []
58 | real_states = []
59 | corrections = []
60 | for step_idx in range(cfg.MODEL.POLICY.MAX_HORIZON_STEPS):
61 |
62 | # Generate random actions
63 | action = action + 0.1*(2*torch.rand(agent.action_space.shape) - 1)
64 |
65 | # Clamp to [-1, 1]
66 | action.clamp_(-1, 1)
67 |
68 | # Save action
69 | actions.append(action)
70 |
71 | previous_state = torch.from_numpy(agent.unwrapped._get_obs())
72 |
73 | # Advance the actual simulation
74 | next_state, _, _, _ = agent.step(action)
75 | next_state = torch.from_numpy(next_state)
76 | real_states.append(next_state)
77 |
78 | # Advance with learned dynamics simulation
79 | pred_next_state = dynamics_model(previous_state.float(), action.float()).double()
80 |
81 | batch_loss[episode_idx, step_idx] = torch.pow(next_state - pred_next_state, 2).mean()
82 | #if agent.is_done:
83 | # break
84 |
85 | #dot = torchviz.make_dot(pred_next_state, params=dict(dynamics_model.named_parameters()))
86 |
87 | loss = torch.sum(batch_loss)
88 | dynamics_model.optimizer.zero_grad()
89 | loss.backward()
90 | dynamics_model.optimizer.step()
91 |
92 | output["objective_loss"].append(loss.detach().numpy())
93 | output["epoch"].append(epoch_idx)
94 |
95 | if epoch_idx % cfg.LOG.PERIOD == 0:
96 |
97 | if cfg.LOG.PLOT.ENABLED:
98 | visdom.update({"total_loss": loss.detach().numpy()})
99 | visdom.set({'actions': torch.stack(actions).detach().numpy()})
100 | #visdom.set({'total_loss': loss["total_loss"].transpose()})
101 | #visdom.update({'average_grad': np.log(torch.mean(model.policy_net.mean._layers["linear_layer_0"].weight.grad.abs()).detach().numpy())})
102 |
103 |             logger.info("LOSS: \t\t{} (iteration {})".format(loss.detach().numpy(), epoch_idx))
104 |
105 | if cfg.LOG.PLOT.ENABLED and epoch_idx % cfg.LOG.PLOT.ITER_PERIOD == 0:
106 | visdom.do_plotting()
107 |
108 | # if epoch_idx % cfg.LOG.CHECKPOINT_PERIOD == 0:
109 | # torch.save(model.state_dict(),
110 | # os.path.join(output_weights_dir, 'iter_{}.pth'.format(epoch_idx)))
111 |
112 | if False:#cfg.LOG.TESTING.ENABLED:
113 | if epoch_idx % cfg.LOG.TESTING.ITER_PERIOD == 0:
114 |
115 | # Record if required
116 | agent.start_recording(os.path.join(output_rec_dir, "iter_{}.mp4".format(epoch_idx)))
117 |
118 | test_rewards = []
119 | for _ in range(cfg.LOG.TESTING.COUNT_PER_ITER):
120 | test_reward = do_testing(
121 | cfg,
122 | model,
123 | agent,
124 | # first_state=state_xr.get_item(),
125 | )
126 | test_rewards.append(test_reward)
127 |
128 | # Set training mode on again
129 | model.train()
130 |
131 | # Close the recorder
132 | agent.stop_recording()
133 |
134 | # Save outputs into log folder
135 | lg.save_dict_into_csv(output_results_dir, "output", output)
136 |
137 | # Save model
138 | torch.save(dynamics_model.state_dict(), os.path.join(output_weights_dir, "final_weights.pt"))
139 |
--------------------------------------------------------------------------------
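To sanity-check a trained dynamics model outside this trainer, it can be rolled out step by step against the real simulator, reusing only the calls that appear above; a rough sketch, where the checkpoint path and the use of random actions are assumptions:

    import torch
    from model.config import get_cfg_defaults
    from mujoco import build_agent
    from model.blocks.policy.dynamics import DynamicsModel

    cfg = get_cfg_defaults()
    agent = build_agent(cfg)
    dynamics_model = DynamicsModel(agent)
    dynamics_model.load_state_dict(torch.load("final_weights.pt"))   # hypothetical checkpoint path
    dynamics_model.eval()

    state = torch.Tensor(agent.reset())
    for _ in range(20):
        action = torch.from_numpy(agent.action_space.sample())
        next_state, _, _, _ = agent.step(action)
        next_state = torch.from_numpy(next_state)
        pred_next = dynamics_model(state.float(), action.float())
        print(float(torch.pow(next_state - pred_next.double(), 2).mean()))   # one-step prediction error
        state = next_state.float()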
/model/engine/landscape_plot.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 |
5 | from model import build_model
6 | import matplotlib.pyplot as pp
7 |
8 |
9 | def visualise2d(agent, output_dir, iter):
10 |
11 | # Get actions
12 | actions = agent.policy_net.optimizer.mean.clone().detach()
13 |
14 | # Build the model
15 | model = build_model(agent.cfg, agent)
16 | device = torch.device(agent.cfg.MODEL.DEVICE)
17 | model.to(device)
18 |
19 | # Set mode to test, we're not interested in optimizing the actions now
20 | model.eval()
21 |
22 | # Choose two random directions
23 | i_dir = torch.from_numpy(np.random.randn(actions.shape[0])).detach()
24 | j_dir = torch.from_numpy(np.random.randn(actions.shape[0])).detach()
25 |
26 | # Define range
27 | i_range = torch.from_numpy(np.linspace(-1, 1, 100)).requires_grad_()
28 | j_range = torch.from_numpy(np.linspace(-1, 1, 100)).requires_grad_()
29 |
30 | # Collect losses
31 | loss = torch.zeros((len(i_range), len(j_range)))
32 |
33 | # Collect grads
34 | #grads = torch.zeros((len(i_range), len(j_range), 2))
35 |
36 | # Need two loops for two directions
37 | for i_idx, i in enumerate(i_range):
38 | for j_idx, j in enumerate(j_range):
39 |
40 | # Calculate new parameters
41 | new_actions = actions + i*i_dir + j*j_dir
42 |
43 | # Set new actions
44 | model.policy_net.optimizer.mean = new_actions
45 |
46 | # Loop through whole simulation
47 | state = torch.Tensor(agent.reset())
48 | for step_idx in range(agent.cfg.MODEL.POLICY.MAX_HORIZON_STEPS):
49 |
50 | # Advance the simulation
51 | state, reward = model(state.detach())
52 |
53 | # Collect losses
54 | loss[i_idx, j_idx] += -reward.squeeze()
55 |
56 | # Get gradients
57 | #loss[i_idx, j_idx].backward(retain_graph=True)
58 | #grads[i_idx, j_idx, 0] = i_range.grad[i_idx]
59 | #grads[i_idx, j_idx, 1] = j_range.grad[j_idx]
60 |
61 | # Do some plotting here
62 | pp.figure(figsize=(12, 12))
63 | contours = pp.contour(i_range.detach().numpy(), j_range.detach().numpy(), loss.detach().numpy(), colors='black')
64 | pp.clabel(contours, inline=True, fontsize=8)
65 | pp.imshow(loss.detach().numpy(), extent=[-1, 1, -1, 1], origin="lower", cmap="RdGy", alpha=0.5)
66 | pp.colorbar()
67 | pp.savefig(os.path.join(output_dir, "contour_{}.png".format(iter)))
68 |
--------------------------------------------------------------------------------
/model/engine/tester.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def do_testing(
5 | cfg,
6 | model,
7 | agent,
8 | samples=None,
9 | first_state=None
10 | ):
11 |
12 | # Let pytorch know we're evaluating a model
13 | model.eval()
14 |
15 | # We don't need gradients now
16 | with torch.no_grad():
17 |
18 | if first_state is None:
19 | state = torch.DoubleTensor(agent.reset(update_episode_idx=False))
20 | else:
21 | state = first_state
22 | agent.set_from_torch_state(state)
23 | reward_sum = 0.
24 | episode_iteration = 0
25 | for step_idx in range(cfg.MODEL.POLICY.MAX_HORIZON_STEPS):
26 | if cfg.LOG.TESTING.RECORD_VIDEO:
27 | agent.capture_frame()
28 | else:
29 | agent.render()
30 | #state, reward = model(state, samples[:, step_idx])
31 | state, reward = model(state)
32 |             reward_sum += reward
33 |             episode_iteration += 1
34 |             if agent.is_done:
35 |                 break
36 |     return reward_sum / episode_iteration
37 |
--------------------------------------------------------------------------------
/model/engine/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 |
5 | from model.engine.tester import do_testing
6 | import utils.logger as lg
7 | from utils.visdom_plots import VisdomLogger
8 | from mujoco import build_agent
9 | from model import build_model
10 |
11 |
12 | def do_training(
13 | cfg,
14 | logger,
15 | output_results_dir,
16 | output_rec_dir,
17 | output_weights_dir,
18 | iter
19 | ):
20 |
21 | if cfg.MODEL.RANDOM_SEED > 0:
22 | np.random.seed(cfg.MODEL.RANDOM_SEED + iter)
23 | torch.manual_seed(cfg.MODEL.RANDOM_SEED + iter)
24 |
25 | # Build the agent
26 | agent = build_agent(cfg)
27 |
28 | # Build the model
29 | model = build_model(cfg, agent)
30 | device = torch.device(cfg.MODEL.DEVICE)
31 | model.to(device)
32 |
33 | # Set mode to training (aside from policy output, matters for Dropout, BatchNorm, etc.)
34 | model.train()
35 |
36 | # Set up visdom
37 | if cfg.LOG.PLOT.ENABLED:
38 | visdom = VisdomLogger(cfg.LOG.PLOT.DISPLAY_PORT)
39 | visdom.register_keys(['total_loss', 'average_sd', 'average_action', "reinforce_loss",
40 | "objective_loss", "sd", "action_grad", "sd_grad", "average_grad"])
41 | for action_idx in range(model.policy_net.action_dim):
42 | visdom.register_keys(["action_" + str(action_idx)])
43 |
44 | # wrap screen recorder if testing mode is on
45 | if cfg.LOG.TESTING.ENABLED:
46 | if cfg.LOG.PLOT.ENABLED:
47 | visdom.register_keys(['test_reward'])
48 |
49 | # Collect losses here
50 | output = {"epoch": [], "objective_loss": [], "average_sd": []}
51 |
52 | # Start training
53 | for epoch_idx in range(cfg.MODEL.EPOCHS):
54 | batch_loss = torch.empty(cfg.MODEL.BATCH_SIZE, cfg.MODEL.POLICY.MAX_HORIZON_STEPS, dtype=torch.float64)
55 | batch_loss.fill_(np.nan)
56 |
57 | for episode_idx in range(cfg.MODEL.BATCH_SIZE):
58 |
59 | initial_state = torch.DoubleTensor(agent.reset())
60 | states = []
61 | states.append(initial_state)
62 | #grads = np.zeros((cfg.MODEL.POLICY.MAX_HORIZON_STEPS, 120))
63 | for step_idx in range(cfg.MODEL.POLICY.MAX_HORIZON_STEPS):
64 | state, reward = model(states[step_idx])
65 | batch_loss[episode_idx, step_idx] = -reward
66 | #(-reward).backward(retain_graph=True)
67 | #grads[step_idx, :] = model.policy_net.optimizer.mean.grad.detach().numpy()
68 | #grads[step_idx, step_idx+1:40] = np.nan
69 | #grads[step_idx, 40+step_idx+1:80] = np.nan
70 | #grads[step_idx, 80+step_idx+1:] = np.nan
71 | #model.policy_net.optimizer.optimizer.zero_grad()
72 | states.append(state)
73 | if agent.is_done:
74 | break
75 |
76 | agent.running_sum = 0
77 | loss = model.policy_net.optimize(batch_loss)
78 | #zero = np.abs(grads) < 1e-9
79 | #grads[zero] = np.nan
80 | #medians = np.nanmedian(grads, axis=0)
81 | #model.policy_net.optimizer.mean.grad.data = torch.from_numpy(medians)
82 | #torch.nn.utils.clip_grad_value_([model.policy_net.optimizer.mean, model.policy_net.optimizer.sd], 1)
83 | #model.policy_net.optimizer.optimizer.step()
84 | #model.policy_net.optimizer.optimizer.zero_grad()
85 | #loss = {'objective_loss': torch.sum(batch_loss, dim=1).mean().detach().numpy()}
86 |
87 | output["objective_loss"].append(loss["objective_loss"])
88 | output["epoch"].append(epoch_idx)
89 | output["average_sd"].append(np.mean(model.policy_net.get_clamped_sd()))
90 |
91 | if epoch_idx % cfg.LOG.PERIOD == 0:
92 |
93 | if cfg.LOG.PLOT.ENABLED:
94 | visdom.update(loss)
95 | #visdom.set({'total_loss': loss["total_loss"].transpose()})
96 |
97 | clamped_sd = model.policy_net.get_clamped_sd()
98 | clamped_action = model.policy_net.get_clamped_action()
99 |
100 | #visdom.update({'average_grad': np.log(torch.mean(model.policy_net.mean._layers["linear_layer_0"].weight.grad.abs()).detach().numpy())})
101 |
102 | if len(clamped_sd) > 0:
103 | visdom.update({'average_sd': np.mean(clamped_sd, axis=1)})
104 | visdom.update({'average_action': np.mean(clamped_action, axis=(1, 2)).squeeze()})
105 |
106 | for action_idx in range(model.policy_net.action_dim):
107 | visdom.set({'action_'+str(action_idx): clamped_action[action_idx, :, :]})
108 | if clamped_sd is not None:
109 | visdom.set({'sd': clamped_sd.transpose()})
110 | # visdom.set({'action_grad': model.policy_net.mean.grad.detach().numpy().transpose()})
111 |
112 |             logger.info("LOSS: \t\t{} (iteration {})".format(loss["objective_loss"], epoch_idx))
113 |
114 | if cfg.LOG.PLOT.ENABLED and epoch_idx % cfg.LOG.PLOT.ITER_PERIOD == 0:
115 | visdom.do_plotting()
116 |
117 | if epoch_idx % cfg.LOG.CHECKPOINT_PERIOD == 0:
118 | torch.save(model.state_dict(),
119 | os.path.join(output_weights_dir, 'iter_{}.pth'.format(epoch_idx)))
120 |
121 | if cfg.LOG.TESTING.ENABLED:
122 | if epoch_idx % cfg.LOG.TESTING.ITER_PERIOD == 0:
123 |
124 | # Record if required
125 | agent.start_recording(os.path.join(output_rec_dir, "iter_{}_{}.mp4".format(iter, epoch_idx)))
126 |
127 | test_rewards = []
128 | for _ in range(cfg.LOG.TESTING.COUNT_PER_ITER):
129 | test_reward = do_testing(
130 | cfg,
131 | model,
132 | agent,
133 | # first_state=state_xr.get_item(),
134 | )
135 | test_rewards.append(test_reward)
136 |
137 | # Set training mode on again
138 | model.train()
139 |
140 | # Close the recorder
141 | agent.stop_recording()
142 |
143 | # Save outputs into log folder
144 | lg.save_dict_into_csv(output_results_dir, "output_{}".format(iter), output)
145 |
146 | # Return actions
147 | return agent
--------------------------------------------------------------------------------
/model/engine/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_state_experience_replay, build_state_experience_replay_data_loader
2 |
3 | __all__ = ["build_state_experience_replay", "build_state_experience_replay_data_loader"]
4 |
--------------------------------------------------------------------------------
/model/engine/utils/build.py:
--------------------------------------------------------------------------------
1 | from .experience_replay import StateQueue, make_data_sampler, make_batch_data_sampler, batch_collator, make_data_loader
2 |
3 |
4 | def build_state_experience_replay_data_loader(cfg):
5 | state_queue = StateQueue(cfg.EXPERIENCE_REPLAY.SIZE)
6 | sampler = make_data_sampler(state_queue, cfg.EXPERIENCE_REPLAY.SHUFFLE)
7 | batch_sampler = make_batch_data_sampler(sampler, cfg.SOLVER.BATCH_SIZE)
8 | data_loader = make_data_loader(state_queue, batch_sampler, batch_collator)
9 | return data_loader
10 |
11 |
12 | def build_state_experience_replay(cfg):
13 | state_queue = StateQueue(cfg.EXPERIENCE_REPLAY.SIZE)
14 | return state_queue
15 |
--------------------------------------------------------------------------------
/model/engine/utils/experience_replay.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import random
4 |
5 |
6 | class StateQueue(object):
7 |
8 | def __init__(self, size):
9 | self.size = size
10 | self.queue = []
11 |
12 | def __getitem__(self, idx):
13 | idx = idx % len(self.queue)
14 | return self.queue[idx]
15 |
16 | def __len__(self):
17 |         return len(self.queue) or self.size  # report capacity while still empty so samplers have a nonzero length
18 |
19 | def add(self, state):
20 | self.queue.append(state)
21 | self.queue = self.queue[-self.size:]
22 |
23 | def add_batch(self, states):
24 | self.queue.extend([state for state in states])
25 | self.queue = self.queue[-self.size:]
26 |
27 | def get_item(self):
28 | return random.choice(self.queue)
29 |
30 |
31 | def make_data_sampler(dataset, shuffle):
32 | if shuffle:
33 | sampler = torch.utils.data.sampler.RandomSampler(dataset)
34 | else:
35 | sampler = torch.utils.data.sampler.SequentialSampler(dataset)
36 | return sampler
37 |
38 |
39 | def make_batch_data_sampler(sampler, batch_size):
40 | batch_sampler = torch.utils.data.sampler.BatchSampler(
41 | sampler, batch_size, drop_last=False
42 | )
43 | return batch_sampler
44 |
45 |
46 | def batch_collator(batch):
47 | return torch.stack(batch, dim=0)
48 |
49 |
50 | def make_data_loader(dataset, batch_sampler, collator):
51 | data_loader = torch.utils.data.DataLoader(
52 | dataset,
53 | batch_sampler=batch_sampler,
54 | collate_fn=collator,
55 | )
56 | return data_loader
57 |
--------------------------------------------------------------------------------
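A small usage sketch of the pieces above, wiring a StateQueue through the sampler helpers into a torch DataLoader (the 4-dimensional dummy states are made up for illustration):

    import torch
    from model.engine.utils.experience_replay import (
        StateQueue, make_data_sampler, make_batch_data_sampler,
        batch_collator, make_data_loader,
    )

    queue = StateQueue(size=2 ** 10)
    queue.add_batch([torch.randn(4) for _ in range(100)])    # e.g. stored environment states

    sampler = make_data_sampler(queue, shuffle=True)
    batch_sampler = make_batch_data_sampler(sampler, batch_size=8)
    loader = make_data_loader(queue, batch_sampler, batch_collator)

    batch = next(iter(loader))    # an (8, 4) tensor, stacked by batch_collator
    print(batch.shape)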
/model/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .feed_forward import FeedForward
2 |
3 | __all__ = ["FeedForward"]
4 |
--------------------------------------------------------------------------------
/model/layers/feed_forward.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from itertools import tee
4 | from collections import OrderedDict
5 | from torch.nn.parameter import Parameter
6 | from torch.nn import init
7 | from torch.nn import functional
8 | import math
9 |
10 | def pairwise(iterable):
11 | """s -> (s0,s1), (s1,s2), (s2, s3), ..."""
12 | a, b = tee(iterable)
13 | next(b, None)
14 | return zip(a, b)
15 |
16 |
17 | class Identity(nn.Module):
18 | def __init__(self):
19 | super(Identity, self).__init__()
20 |
21 | def forward(self, x):
22 | return x
23 |
24 |
25 | class FeedForward(nn.Module):
26 | def __init__(self, input_size, middle_sizes, output_size, norm_layers=[], activation_fn=nn.LeakyReLU, activation_out=False):
27 | """
28 | :param input_size: int input feature size
29 | :param middle_sizes: [int] list of intermediate hidden state sizes
30 | :param output_size: int output size of the network
31 | """
32 | # TODO: use setattr for linear layers and activations, parameters are empty
33 | super(FeedForward, self).__init__()
34 | self._sizes = [input_size] + middle_sizes + [output_size]
35 | self._layers = OrderedDict()
36 | for i, (in_size, out_size) in enumerate(pairwise(self._sizes)):
37 | # Add linear layer
38 | #linear_layer = Variational(in_size, out_size, bias=True)
39 | linear_layer = nn.Linear(in_size, out_size, bias=True)
40 | self.__setattr__('linear_layer_{}'.format(str(i)), linear_layer)
41 | self._layers.update({'linear_layer_{}'.format(str(i)): linear_layer})
42 | # Add batch normalization layer
43 | if i in norm_layers:
44 | batchnorm_layer = nn.BatchNorm1d(out_size)
45 | self.__setattr__('batchnorm_layer_{}'.format(str(i)), batchnorm_layer)
46 | self._layers.update({'batchnorm_layer_{}'.format(str(i)): batchnorm_layer})
47 | # Add activation layer
48 | self.__setattr__('activation_layer_{}'.format(str(i)), activation_fn()) # relu for the last layer also makes sense
49 | self._layers.update({'activation_layer_{}'.format(str(i)): activation_fn()})
50 | if not activation_out:
51 | self._layers.popitem()
52 |
53 | self.sequential = nn.Sequential(self._layers)
54 |
55 | def forward(self, x):
56 | # out = x
57 | # for i in range(len(self._sizes) - 1):
58 | # fc = self.__getattr__('linear_layer_' + str(i))
59 | # ac = self.__getattr__('activation_layer_' + str(i))
60 | # out = ac(fc(out))
61 | out = self.sequential(x)
62 | return out
63 |
64 |
65 | class Variational(nn.Module):
66 | __constants__ = ['bias', 'in_features', 'out_features']
67 |
68 | def __init__(self, in_features, out_features, bias=True):
69 | super(Variational, self).__init__()
70 | self.in_features = in_features
71 | self.out_features = out_features
72 | self.weight_mean = Parameter(torch.Tensor(out_features, in_features))
73 | self.weight_std = Parameter(torch.Tensor(out_features, in_features))
74 | if bias:
75 | self.bias_mean = Parameter(torch.Tensor(out_features))
76 | self.bias_std = Parameter(torch.Tensor(out_features))
77 | else:
78 | self.register_parameter('bias_mean', None)
79 | self.register_parameter('bias_std', None)
80 | self.reset_parameters()
81 |
82 | def reset_parameters(self):
83 | init.kaiming_uniform_(self.weight_mean, a=math.sqrt(5))
84 | init.constant_(self.weight_std, 0.03)
85 | if self.bias_mean is not None:
86 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight_mean)
87 | bound = 1 / math.sqrt(fan_in)
88 | init.uniform_(self.bias_mean, -bound, bound)
89 | init.constant_(self.bias_std, 0.03)
90 |
91 | def forward(self, input):
92 | neg_weight = self.weight_std.data < 0
93 | self.weight_std.data[neg_weight] = 0.00001
94 | weight = torch.distributions.Normal(self.weight_mean, self.weight_std).rsample()
95 | neg_bias = self.bias_std.data < 0
96 | self.bias_std.data[neg_bias] = 0.00001
97 | bias = torch.distributions.Normal(self.bias_mean, self.bias_std).rsample()
98 | return functional.linear(input, weight, bias)
99 |
100 | def extra_repr(self):
101 | return 'in_features={}, out_features={}, bias={}'.format(
102 |             self.in_features, self.out_features, self.bias_mean is not None
103 | )
104 |
105 |
--------------------------------------------------------------------------------
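FeedForward assembles an MLP from the size list via nn.Sequential, with LeakyReLU between hidden layers and no activation on the output by default. A short illustrative instantiation (the sizes are arbitrary):

    import torch
    from model.layers import FeedForward

    net = FeedForward(input_size=11, middle_sizes=[64, 64], output_size=3)
    x = torch.randn(5, 11)     # batch of 5 eleven-dimensional inputs
    y = net(x)
    print(y.shape)             # torch.Size([5, 3])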
/mujoco/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_agent
2 |
3 | __all__ = ["build_agent"]
4 |
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/bofoot.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/bofoot.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/calcn_r.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/calcn_r.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/capitate_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/capitate_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/clavicle_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/clavicle_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/femur_r.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/femur_r.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/fibula_r.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/fibula_r.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/foot.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/foot.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/hamate_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/hamate_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/humerus_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/humerus_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/indexDistal_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/indexDistal_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/indexMetacarpal_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/indexMetacarpal_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/indexMid_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/indexMid_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/indexProx_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/indexProx_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/l_pelvis.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/l_pelvis.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/lunate_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/lunate_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/middleDistal_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/middleDistal_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/middleMetacarpal_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/middleMetacarpal_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/middleMid_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/middleMid_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/middleProx_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/middleProx_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/pat.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/pat.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/patella_r.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/patella_r.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/pelvis.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/pelvis.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/pinkyDistal_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/pinkyDistal_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/pinkyMetacarpal_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/pinkyMetacarpal_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/pinkyMid_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/pinkyMid_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/pinkyProx_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/pinkyProx_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/pisiform_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/pisiform_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/radius_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/radius_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/ribcageFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/ribcageFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/ringDistal_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/ringDistal_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/ringMetacarpal_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/ringMetacarpal_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/ringMid_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/ringMid_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/ringProx_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/ringProx_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/sacrum.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/sacrum.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/scaphoid_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/scaphoid_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/scapula_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/scapula_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/talus.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/talus.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/talus_r.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/talus_r.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/thoracic10FB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic10FB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/thoracic11FB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic11FB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/thoracic12FB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic12FB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/thoracic1FB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic1FB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/thoracic2FB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic2FB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/thoracic3FB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic3FB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/thoracic4FB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic4FB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/thoracic5FB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic5FB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/thoracic6FB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic6FB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/thoracic7FB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic7FB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/thoracic8FB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic8FB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/thoracic9FB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic9FB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/thumbDistal_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thumbDistal_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/thumbMid_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thumbMid_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/thumbProx_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thumbProx_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/tibia_r.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/tibia_r.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/toes_r.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/toes_r.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/trapezium_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/trapezium_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/trapezoid_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/trapezoid_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/triquetral_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/triquetral_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/Geometry/ulna_rFB.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/ulna_rFB.stl
--------------------------------------------------------------------------------
/mujoco/assets/half_cheetah.xml:
--------------------------------------------------------------------------------
(XML markup stripped during extraction; content not recoverable from this dump.)
--------------------------------------------------------------------------------
/mujoco/assets/hopper.xml:
--------------------------------------------------------------------------------
(XML markup stripped during extraction; content not recoverable from this dump.)
--------------------------------------------------------------------------------
/mujoco/assets/inverted_double_pendulum.xml:
--------------------------------------------------------------------------------
(XML markup stripped during extraction; content not recoverable from this dump.)
--------------------------------------------------------------------------------
/mujoco/assets/inverted_pendulum.xml:
--------------------------------------------------------------------------------
(XML markup stripped during extraction; content not recoverable from this dump.)
--------------------------------------------------------------------------------
/mujoco/assets/leg6dof9musc_converted.xml:
--------------------------------------------------------------------------------
(XML markup stripped during extraction; content not recoverable from this dump.)
--------------------------------------------------------------------------------
/mujoco/assets/swimmer.xml:
--------------------------------------------------------------------------------
(MuJoCo model XML; markup was stripped in this text export)
--------------------------------------------------------------------------------
/mujoco/assets/walker2d.xml:
--------------------------------------------------------------------------------
(MuJoCo model XML; markup was stripped in this text export)
--------------------------------------------------------------------------------
/mujoco/build.py:
--------------------------------------------------------------------------------
1 | from mujoco import envs
2 | from mujoco.utils.wrappers.mj_block import MjBlockWrapper
3 | from mujoco.utils.wrappers.etc import SnapshotWrapper, IndexWrapper, ViewerWrapper
4 |
5 |
6 | def build_agent(cfg):
7 | agent_factory = getattr(envs, cfg.MUJOCO.ENV)
8 | agent = agent_factory(cfg)
9 | # if cfg.MUJOCO.REWARD_SCALE != 1.0:
10 | # from .utils import RewardScaleWrapper
11 | # agent = RewardScaleWrapper(agent, cfg.MUJOCO.REWARD_SCALE)
12 | # if cfg.MUJOCO.CLIP_ACTIONS:
13 | # from .utils import ClipActionsWrapper
14 | # agent = ClipActionsWrapper(agent)
15 | # if cfg.MODEL.POLICY.ARCH == 'TrajOpt':
16 | # from .utils import FixedStateWrapper
17 | # agent = FixedStateWrapper(agent)
18 |
19 | # Make configs accessible through agent
20 | #agent.cfg = cfg
21 |
22 | # Record video
23 | agent = ViewerWrapper(agent)
24 |
25 | # Keep track of step, episode, and batch indices
26 | agent = IndexWrapper(agent, cfg.MODEL.BATCH_SIZE)
27 |
28 | # Grab and set snapshots of data
29 | agent = SnapshotWrapper(agent)
30 |
31 | # This wrapper should be applied last (outermost) so it sees all of the other wrappers
32 | agent = MjBlockWrapper(agent)
33 |
34 | # Set solver tolerance to zero so the MuJoCo solver doesn't stop early, keeping finite-difference gradients consistent
35 | agent.model.opt.tolerance = 0
36 |
37 | return agent
38 |
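   | # Rough usage sketch (hedged; the real call site lives in the engine/training code, not in this file),
   | # assuming a yacs-style cfg like the ones under configs/:
   | #     cfg.MUJOCO.ENV = "InvertedPendulumEnv"   # any class exported by mujoco.envs
   | #     agent = build_agent(cfg)
   | #     obs = agent.reset()
   | #     obs, reward, done, info = agent.step(agent.action_space.sample())
   | # Wrapper order matters above: MjBlockWrapper is applied last (outermost) so its forward/gradient
   | # factories see the snapshot- and index-handling wrappers beneath it.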
--------------------------------------------------------------------------------
/mujoco/envs/HandModelTSLAdjusted.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from gym import utils
4 | from gym.envs.mujoco import mujoco_env
5 | import xmltodict
6 | from how_bots_type import MocapWithKeyStateData
7 | import Utils
8 | from pyquaternion import Quaternion
9 | from copy import deepcopy
10 |
11 |
12 | class HandModelTSLAdjustedEnv(mujoco_env.MujocoEnv, utils.EzPickle):
13 | """
14 | """
15 | def __init__(self, cfg):
16 |
17 | self.initialised = False
18 | self.cfg = cfg
19 |
20 | # Get model path
21 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/")
22 | mujoco_xml_file = os.path.join(mujoco_assets_dir, "HandModelTSLAdjusted_converted.xml")
23 |
24 | # Add sites to fingertips by reading the model xml file, adding the sites, and saving the model file
25 | with open(mujoco_xml_file) as f:
26 | text = f.read()
27 | p = xmltodict.parse(text)
28 |
29 | # We might want to add a "keyboard" object into the simulation; it should be in front of torso
30 | torso = p["mujoco"]["worldbody"]["body"]
31 | pos = np.array(torso["@pos"].split(), dtype=float)
32 | quat = np.array(torso["@quat"].split(), dtype=float)
33 | T_torso = Utils.create_transformation_matrix(pos=pos, quat=quat)
34 | R_keyboard = np.matmul(Utils.create_rotation_matrix(axis=[0, 0, 1], deg=180), Utils.create_rotation_matrix(axis=[0, 1, 0], deg=-90))
35 | T_keyboard = Utils.create_transformation_matrix(pos=np.array([0.6, -0.1, -0.25]), R=R_keyboard[:3, :3])
36 | T = np.matmul(T_torso, T_keyboard)
37 |
38 | # Now create the keyboard object
39 | #keyboard = {"@name": "keyboard",
40 | # "@pos": Utils.array_to_string(T[:3, 3]),
41 | # "@quat": Utils.array_to_string(Quaternion(matrix=T).elements),
42 | # "geom": {"@type": "box", "@pos": "0.25 0.005 -0.15", "@size": "0.25 0.005 0.15"}}
43 |
44 | keyboard = {"@name": "keyboard",
45 | "@pos": Utils.array_to_string(pos + np.array([0.5, 0, 0])),
46 | "@quat": torso["@quat"],
47 | "geom": {"@type": "sphere", "@size": "0.02 0 0", "@rgba": "0.9 0.1 0.1 1"}}
48 |
49 |
50 | # Add it to the xml file
51 | if not isinstance(p["mujoco"]["worldbody"]["body"], list):
52 | p["mujoco"]["worldbody"]["body"] = [p["mujoco"]["worldbody"]["body"]]
53 | p["mujoco"]["worldbody"]["body"].append(keyboard)
54 |
55 | # Get a reference to trapezium_pre_r in the kinematic tree
56 | trapezium = torso["body"]["body"][1]["body"]["body"]["body"]["body"]["body"]["body"][1]["body"]["body"][2]["body"]["body"]
57 |
58 | # Add fingertip site to thumb
59 | self.add_fingertip_site(trapezium[0]["body"]["body"]["body"]["body"]["body"]["site"],
60 | "fingertip_thumb", "-0.0015 -0.022 0.002")
61 |
62 | # Add fingertip site to index finger
63 | self.add_fingertip_site(trapezium[2]["body"]["body"]["body"]["body"]["site"],
64 | "fingertip_index", "-0.0005 -0.0185 0.001")
65 |
66 | # Add fingertip site to middle finger
67 | self.add_fingertip_site(trapezium[3]["body"]["body"]["body"]["site"],
68 | "fingertip_middle", "0.0005 -0.01875 0.0")
69 |
70 | # Add fingertip site to ring finger
71 | self.add_fingertip_site(trapezium[4]["body"]["body"]["body"]["body"]["site"],
72 | "fingertip_ring", "-0.001 -0.01925 0.00075")
73 |
74 | # Add fingertip site to little finger
75 | self.add_fingertip_site(trapezium[5]["body"]["body"]["body"]["body"]["site"],
76 | "fingertip_little", "-0.0005 -0.0175 0.0")
77 |
78 | # Save the modified model xml file
79 | mujoco_xml_file_modified = os.path.join(mujoco_assets_dir, "HandModelTSLAdjusted_converted_fingertips.xml")
80 | with open(mujoco_xml_file_modified, 'w') as f:
81 | f.write(xmltodict.unparse(p, pretty=True, indent=" "))
82 |
83 | # Load model
84 | self.frame_skip = self.cfg.MODEL.FRAME_SKIP
85 | mujoco_env.MujocoEnv.__init__(self, mujoco_xml_file_modified, self.frame_skip)
86 |
87 | # Set timestep
88 | if cfg.MODEL.TIMESTEP > 0:
89 | self.model.opt.timestep = cfg.MODEL.TIMESTEP
90 |
91 | # The default initial state isn't stable; run a simulation for a couple of seconds to get a stable initial state
92 | duration = 4
93 | for _ in range(int(duration/self.model.opt.timestep)):
94 | self.sim.step()
95 |
96 | # Use these joint values for initial state
97 | self.init_qpos = deepcopy(self.data.qpos)
98 |
99 | # Reset model
100 | self.reset_model()
101 |
102 | # Load a dataset from how-bots-type
103 | if False:
104 | data_filename = "/home/aleksi/Workspace/how-bots-type/data/24fps/000895_sentences_mocap_with_key_states.msgpack"
105 | mocap_data = MocapWithKeyStateData.load(data_filename)
106 | mocap_ds = mocap_data.as_dataset()
107 |
108 | # Get positions of each fingertip, maybe start with just 60 seconds
109 | fps = 24
110 | length = 10
111 | self.fingers = ["thumb", "index", "middle", "ring", "little"]
112 | fingers_coords = mocap_ds.fingers.isel(time=range(0, length*fps)).sel(hand='right', joint='tip', finger=self.fingers).values
113 |
114 | # We need to transform the mocap coordinates to simulation coordinates
115 | for time_idx in range(fingers_coords.shape[0]):
116 | for finger_idx in range(fingers_coords.shape[1]):
117 | coords = np.concatenate((fingers_coords[time_idx, finger_idx, :], np.array([1])))
118 | coords[1] *= -1
119 | coords = np.matmul(T, coords)
120 | fingers_coords[time_idx, finger_idx, :] = coords[:3]
121 |
122 | # Get timestamps in seconds
123 | time = mocap_ds.isel(time=range(0, length*fps)).time.values / np.timedelta64(1, 's')
124 |
125 | # We want to interpolate finger positions every (self.frame_skip * self.model.opt.timestep) second
126 | time_interp = np.arange(0, length, self.frame_skip * self.model.opt.timestep)
127 | self.fingers_targets = {x: np.empty(time_interp.shape + fingers_coords.shape[2:]) for x in self.fingers}
128 | #self.fingers_targets = np.empty(time_interp.shape + fingers_coords.shape[1:])
129 | for finger_idx in range(fingers_coords.shape[1]):
130 | for coord_idx in range(fingers_coords.shape[2]):
131 | #self.fingers_targets[:, finger_idx, coord_idx] = np.interp(time_interp, time, self.fingers_coords[:, self.finger_idx, coord_idx])
132 | self.fingers_targets[self.fingers[finger_idx]][:, coord_idx] = \
133 | np.interp(time_interp, time, fingers_coords[:, finger_idx, coord_idx])
134 |
135 | self.initialised = True
136 |
137 | def add_fingertip_site(self, sites, name, pos):
138 | # Check if this fingertip is already in sites, and if so, overwrite the position
139 | for site in sites:
140 | if site["@name"] == name:
141 | site["@pos"] = pos
142 | return
143 |
144 | # Create the site
145 | sites.append({"@name": name, "@pos": pos})
146 |
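   | # Note: xmltodict stores XML attributes under "@"-prefixed keys, so a dict such as
   | #     {"@name": "fingertip_index", "@pos": "-0.0005 -0.0185 0.001"}
   | # is written back by xmltodict.unparse() as
   | #     <site name="fingertip_index" pos="-0.0005 -0.0185 0.001"></site>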
147 | def step(self, a):
148 |
149 | if not self.initialised:
150 | return self._get_obs(), 0, False, {}
151 |
152 | # Step forward
153 | self.do_simulation(a, self.frame_skip)
154 |
155 | # Cost is difference between target and simulated fingertip positions
156 | #err = {x: 0 for x in self.fingers}
157 | #for finger in self.fingers:
158 | # err[finger] = np.linalg.norm(self.fingers_targets[finger][self._step_idx, :] -
159 | # self.data.site_xpos[self.model._site_name2id["fingertip_"+finger], :], ord=2)
160 |
161 | err = {}
162 | cost = np.linalg.norm(self.data.body_xpos[self.model._body_name2id["keyboard"], :] -
163 | self.data.site_xpos[self.model._site_name2id["fingertip_index"], :], ord=2)
164 |
165 | # Product of errors? Or sum? Squared errors?
166 | #cost = np.prod(list(err.values()))
167 |
168 | return self._get_obs(), -cost, False, err
169 |
170 | def _get_obs(self):
171 | """DIFFERENT FROM ORIGINAL GYM"""
172 | qpos = self.sim.data.qpos
173 | qvel = self.sim.data.qvel
174 | return np.concatenate([qpos.flat, qvel.flat])
175 |
176 | def reset_model(self):
177 | # Reset to initial pose?
178 | #self.set_state(
179 | # self.init_qpos + self.np_random.uniform(low=-1, high=1, size=self.model.nq),
180 | # self.init_qvel + self.np_random.uniform(low=-1, high=1, size=self.model.nv)
181 | #)
182 | self.sim.reset()
183 | self.set_state(self.init_qpos, self.init_qvel)
184 | return self._get_obs()
185 |
186 | @staticmethod
187 | def is_done(state):
188 | done = False
189 | return done
190 |
--------------------------------------------------------------------------------
/mujoco/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .hopper import HopperEnv
2 | from .half_cheetah import HalfCheetahEnv
3 | from .swimmer import SwimmerEnv
4 | from .inverted_pendulum import InvertedPendulumEnv
5 | from .inverted_double_pendulum import InvertedDoublePendulumEnv
6 | from .leg import LegEnv
7 | from .HandModelTSLAdjusted import HandModelTSLAdjustedEnv
8 | from .walker2d import Walker2dEnv
9 |
10 | __all__ = [
11 | "HopperEnv",
12 | "HalfCheetahEnv",
13 | "SwimmerEnv",
14 | "InvertedPendulumEnv",
15 | "InvertedDoublePendulumEnv",
16 | "LegEnv",
17 | "HandModelTSLAdjustedEnv",
18 | "Walker2dEnv"
19 | ]
20 |
--------------------------------------------------------------------------------
/mujoco/envs/half_cheetah.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from gym import utils
5 | from gym.envs.mujoco import mujoco_env
6 |
7 |
8 | class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
9 | """
10 | COPIED FROM GYM. W/ SLIGHT MODIFICATIONS:
11 | * READING FROM OWN .XML.
12 | * FULL STATE OBSERVATIONS, I.E. QPOS CONCAT'D WITH QVEL.
13 | * is_done METHOD SHOULD BE IMPLEMENTED
14 | """
15 | def __init__(self, cfg):
16 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/")
17 | self.cfg = cfg
18 | self.frame_skip = self.cfg.MODEL.FRAME_SKIP
19 | self.initialised = False
20 | mujoco_env.MujocoEnv.__init__(self, os.path.join(mujoco_assets_dir, "half_cheetah.xml"), self.frame_skip)
21 | utils.EzPickle.__init__(self)
22 | self.initialised = True
23 |
24 | def step(self, action):
25 | xposbefore = self.sim.data.qpos[0]
26 | self.do_simulation(action, self.frame_skip)
27 | xposafter = self.sim.data.qpos[0]
28 | ob = self._get_obs()
29 | reward_ctrl = - 0.01 * np.square(action).sum()
30 | reward_run = (xposafter - xposbefore) / self.dt
31 | reward = reward_ctrl + reward_run
32 | done = False
33 | return ob, reward, done, dict(reward_run=reward_run, reward_ctrl=reward_ctrl)
34 |
35 | def _get_obs(self):
36 | """DIFFERENT FROM ORIGINAL GYM"""
37 | return np.concatenate([
38 | self.sim.data.qpos.flat,
39 | self.sim.data.qvel.flat,
40 | ])
41 |
42 | def reset_model(self):
43 | self.sim.reset()
44 | if self.cfg.MODEL.POLICY.NETWORK:
45 | self.set_state(
46 | self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq),
47 | self.init_qvel + self.np_random.randn(self.model.nv) * .1
48 | )
49 | else:
50 | self.set_state(self.init_qpos, self.init_qvel)
51 | return self._get_obs()
52 |
53 | def viewer_setup(self):
54 | self.viewer.cam.distance = self.model.stat.extent * 0.5
55 |
56 | @staticmethod
57 | def is_done(state):
58 | done = False
59 | return done
60 |
61 | def tensor_reward(self, state, action, next_state):
62 | """DIFFERENT FROM ORIGINAL GYM"""
63 | xposbefore = state[0]
64 | xposafter = next_state[0]
65 | reward_ctrl = - 0.01 * torch.sum(torch.mul(action, action))
66 | reward_run = (xposafter - xposbefore) / self.dt
67 | reward = reward_ctrl + reward_run
68 | return reward.view([1, ])
69 |
--------------------------------------------------------------------------------
/mujoco/envs/hopper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from gym import utils
5 | from gym.envs.mujoco import mujoco_env
6 | import math
7 |
8 |
9 | class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
10 | """
11 | COPIED FROM GYM. W/ SLIGHT MODIFICATIONS:
12 | * READING FROM OWN .XML.
13 | * FULL STATE OBSERVATIONS, I.E. QPOS CONCAT'D WITH QVEL.
14 | * is_done METHOD SHOULD BE IMPLEMENTED
15 | * torch implementation of reward function
16 | """
17 |
18 | def __init__(self, cfg):
19 | self.cfg = cfg
20 | self.frame_skip = self.cfg.MODEL.FRAME_SKIP
21 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/")
22 | self.initialised = False
23 | mujoco_env.MujocoEnv.__init__(self, os.path.join(mujoco_assets_dir, "hopper.xml"), self.frame_skip)
24 | utils.EzPickle.__init__(self)
25 | self.initialised = True
26 |
27 | def sigmoid(self, x, mi, mx): return mi + (mx - mi) * (lambda t: (1 + 200 ** (-t + 0.5)) ** (-1))((x - mi) / (mx - mi))
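   | # The one-liner above is a generalised logistic squash: it maps x from the interval [mi, mx] back onto
   | # [mi, mx] along an S-curve (steepness set by the base 200). Below it scales the alive bonus by how
   | # close the torso height and angle are to their upright targets.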
28 |
29 | def step(self, a):
30 | posbefore = self.sim.data.qpos[0]
31 | self.do_simulation(a, self.frame_skip)
32 | posafter, height, ang = self.sim.data.qpos[0:3]
33 | alive_bonus = 1.0
34 | reward = (posafter - posbefore) / self.dt
35 | ang_abs = abs(ang) % (2*math.pi)
36 | if ang_abs > math.pi:
37 | ang_abs = 2*math.pi - ang_abs
38 | coeff1 = self.sigmoid(height/1.25, 0, 1)
39 | coeff2 = self.sigmoid((math.pi - ang_abs)/math.pi, 0, 1)
40 | reward += coeff1 * alive_bonus + coeff2 * alive_bonus
41 | reward -= 1e-3 * np.square(a).sum()
42 | s = self.state_vector()
43 | done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and
44 | (height > .7) and (abs(ang) < .2)) and self.initialised
45 | ob = self._get_obs()
46 | return ob, reward, False, {}
47 |
48 | def _get_obs(self):
49 | """DIFFERENT FROM ORIGINAL GYM"""
50 | return np.concatenate([
51 | self.sim.data.qpos.flat, # this part different from gym. expose the whole thing.
52 | self.sim.data.qvel.flat, # this part different from gym. clip nothing.
53 | ])
54 |
55 | def reset_model(self):
56 | self.sim.reset()
57 | if self.cfg.MODEL.POLICY.NETWORK:
58 | qpos = self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq)
59 | qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
60 | self.set_state(qpos, qvel)
61 | else:
62 | self.set_state(self.init_qpos, self.init_qvel)
63 | return self._get_obs()
64 |
65 | def viewer_setup(self):
66 | self.viewer.cam.trackbodyid = 2
67 | self.viewer.cam.distance = self.model.stat.extent * 0.75
68 | self.viewer.cam.lookat[2] = 1.15
69 | self.viewer.cam.elevation = -20
70 |
71 | @staticmethod
72 | def is_done(state):
73 | height, ang = state[1:3]
74 | done = not (np.isfinite(state).all() and (np.abs(state[2:]) < 100).all() and
75 | (height > .7) and (abs(ang) < .2))
76 | return done
77 |
78 | def tensor_reward(self, state, action, next_state):
79 | """DIFFERENT FROM ORIGINAL GYM"""
80 | posbefore = state[0]
81 | posafter, height, ang = next_state[0:3]
82 | alive_bonus = 1.0
83 | reward = (posafter - posbefore) / self.dt
84 | reward += alive_bonus
85 | reward -= 1e-3 * torch.sum(torch.mul(action, action))
86 | return reward.view([1, ])
87 |
--------------------------------------------------------------------------------
/mujoco/envs/inverted_double_pendulum.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from gym import utils
5 | from gym.envs.mujoco import mujoco_env
6 |
7 |
8 | class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
9 | """
10 | COPIED FROM GYM. W/ SLIGHT MODIFICATIONS:
11 | * READING FROM OWN .XML.
12 | * FULL STATE OBSERVATIONS, I.E. QPOS CONCAT'D WITH QVEL.
13 | * is_done METHOD SHOULD BE IMPLEMENTED
14 | * torch implementation of reward function
15 | """
16 |
17 | def __init__(self, cfg):
18 | self.cfg = cfg
19 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/")
20 | self.initialised = False
21 | mujoco_env.MujocoEnv.__init__(self, os.path.join(mujoco_assets_dir, "inverted_double_pendulum.xml"), self.cfg.MODEL.FRAME_SKIP)
22 | utils.EzPickle.__init__(self)
23 | self.initialised = True
24 |
25 | def step(self, action):
26 | self.do_simulation(action, self.frame_skip)
27 | ob = self._get_obs()
28 | x, _, y = self.sim.data.site_xpos[0]
29 | dist_penalty = 0.01 * x ** 2 + (y - 2) ** 2
30 | v1, v2 = self.sim.data.qvel[1:3]
31 | vel_penalty = 1e-3 * v1 ** 2 + 5e-3 * v2 ** 2
32 | alive_bonus = 10
33 | r = - dist_penalty - vel_penalty
34 | done = bool(y <= 1)
35 | return ob, r, False, {}
36 |
37 | def _get_obs(self):
38 | """DIFFERENT FROM ORIGINAL GYM"""
39 | return np.concatenate([
40 | self.sim.data.qpos.flat, # this part different from gym. expose the whole thing.
41 | self.sim.data.qvel.flat, # this part different from gym. clip nothing.
42 | ])
43 |
44 | def reset_model(self):
45 | self.sim.reset()
46 | self.set_state(
47 | self.init_qpos,# + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq),
48 | self.init_qvel# + self.np_random.randn(self.model.nv) * .1
49 | )
50 | return self._get_obs()
51 |
52 | def viewer_setup(self):
53 | v = self.viewer
54 | v.cam.trackbodyid = 0
55 | v.cam.distance = self.model.stat.extent * 0.5
56 | v.cam.lookat[2] = 0.12250000000000005 # v.model.stat.center[2]
57 |
58 | @staticmethod
59 | def is_done(state):
60 | arm_length = 0.6
61 | theta_1 = state[1]
62 | theta_2 = state[2]
63 | y = arm_length * np.cos(theta_1) + \
64 | arm_length * np.sin(theta_1 + theta_2)
65 | done = bool(y <= 1)
66 | return done
67 |
68 | def tensor_reward(self, state, action, next_state):
69 | """DIFFERENT FROM ORIGINAL GYM"""
70 | arm_length = 0.6
71 | theta_1 = next_state[1]
72 | theta_2 = next_state[2]
73 | y = arm_length * torch.cos(theta_1) + \
74 | arm_length * torch.cos(theta_1 + theta_2)
75 | x = arm_length * torch.cos(theta_1) + \
76 | arm_length * torch.cos(theta_1 + theta_2) + \
77 | next_state[0]
78 | dist_penalty = 0.01 * x ** 2 + (y - 2) ** 2
79 | v1, v2 = next_state[4:6]
80 | vel_penalty = 1e-3 * v1 ** 2 + 5e-3 * v2 ** 2
81 | alive_bonus = 10
82 | reward = alive_bonus - dist_penalty - vel_penalty
83 | return reward.view([1, ])
84 |
--------------------------------------------------------------------------------
/mujoco/envs/inverted_pendulum.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from gym import utils
5 | from gym.envs.mujoco import mujoco_env
6 |
7 |
8 | class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
9 | """
10 | COPIED FROM GYM. W/ SLIGHT MODIFICATIONS:
11 | * READING FROM OWN .XML.
12 | * FULL STATE OBSERVATIONS, I.E. QPOS CONCAT'D WITH QVEL.
13 | * is_done METHOD SHOULD BE IMPLEMENTED
14 | * torch implementation of reward function
15 | """
16 |
17 | def __init__(self, cfg):
18 | self.cfg = cfg
19 | self.frame_skip = self.cfg.MODEL.FRAME_SKIP
20 | utils.EzPickle.__init__(self)
21 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/")
22 | mujoco_env.MujocoEnv.__init__(self, os.path.join(mujoco_assets_dir, "inverted_pendulum.xml"), self.frame_skip)
23 |
24 | def step(self, a):
25 | """DIFFERENT FROM ORIGINAL GYM"""
26 | arm_length = 0.6
27 | self.do_simulation(a, self.frame_skip)
28 | ob = self._get_obs()
29 | theta = ob[1]
30 | y = arm_length * np.cos(theta)
31 | x = arm_length * np.cos(theta)
32 | dist_penalty = 0.01 * x ** 2 + (y - 1) ** 2
33 | #v = ob[3]
34 | #vel_penalty = 1e-3 * v ** 2
35 | reward = -dist_penalty - 0.001*(a**2)
36 | notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= .2)
37 | done = not notdone
38 | return ob, reward, False, {}
39 |
40 | def reset_model(self):
41 | #qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01)
42 | #qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01)
43 | #self.set_state(qpos, qvel)
44 | self.set_state(np.array([0.0, 0.0]), np.array([0.0, 0.0]))
45 | return self._get_obs()
46 |
47 | def _get_obs(self):
48 | return np.concatenate([self.sim.data.qpos, self.sim.data.qvel]).ravel()
49 |
50 | def viewer_setup(self):
51 | v = self.viewer
52 | v.cam.trackbodyid = 0
53 | v.cam.distance = self.model.stat.extent
54 |
55 | @staticmethod
56 | def is_done(state):
57 | done = False
58 | return done
59 |
60 | def tensor_reward(self, state, action, next_state):
61 | """DIFFERENT FROM ORIGINAL GYM"""
62 | arm_length = 0.6
63 | theta = next_state[1]
64 | y = arm_length * torch.cos(theta)
65 | x = arm_length * torch.cos(theta)
66 | dist_penalty = 0.01 * x ** 2 + (y - 1) ** 2
67 | v = next_state[3]
68 | vel_penalty = 1e-3 * v ** 2
69 | reward = -dist_penalty - vel_penalty
70 | return reward.view([1, ])
71 |
--------------------------------------------------------------------------------
/mujoco/envs/leg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from gym import utils, spaces
5 | from gym.envs.mujoco import mujoco_env
6 | from model_accuracy_tests.find_parameter_mappings import get_control, get_kinematics, update_equality_constraint, reindex_dataframe
7 |
8 |
9 | class LegEnv(mujoco_env.MujocoEnv, utils.EzPickle):
10 | """
11 | """
12 | def __init__(self, cfg=None):
   | self.cfg = cfg
13 | 
14 | self.initialised = False
15 |
16 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/")
17 | mujoco_env.MujocoEnv.__init__(self, os.path.join(mujoco_assets_dir, "leg6dof9musc_converted.xml"), 1)
18 | utils.EzPickle.__init__(self)
19 |
20 | # Get muscle control values
21 | control_file = "/home/aleksi/Workspace/O2MConverter/models/opensim/Leg6Dof9Musc/CMC/leg6dof9musc_controls.sto"
22 | control_values, control_header = get_control(self.model, control_file)
23 |
24 | # Get joint kinematics values
25 | qpos_file = "/home/aleksi/Workspace/O2MConverter/models/opensim/Leg6Dof9Musc/CMC/leg6dof9musc_Kinematics_q.sto"
26 | qpos, qpos_header = get_kinematics(self.model, qpos_file)
27 | qvel_file = "/home/aleksi/Workspace/O2MConverter/models/opensim/Leg6Dof9Musc/CMC/leg6dof9musc_Kinematics_u.sto"
28 | qvel, qvel_header = get_kinematics(self.model, qvel_file)
29 |
30 | # Make sure both muscle control and joint kinematics have the same timesteps
31 | if not control_values.index.equals(qpos.index) or not control_values.index.equals(qvel.index):
32 | print("Timesteps do not match between muscle control and joint kinematics")
33 | return
34 |
35 | # Timestep might not be constant in the OpenSim reference movement (weird). We can't change timestep dynamically in
36 | # mujoco, at least the viewer does weird things and it could be reflecting underlying issues. Thus, we should
37 | # interpolate the muscle control and joint kinematics with model.opt.timestep
38 | # model.opt.timestep /= 2.65
39 | self.control_values = reindex_dataframe(control_values, self.model.opt.timestep)
40 | self.target_qpos = reindex_dataframe(qpos, self.model.opt.timestep)
41 | self.target_qvel = reindex_dataframe(qvel, self.model.opt.timestep)
42 |
43 | # Get initial state values and reset model
44 | for joint_idx, joint_name in self.model._joint_id2name.items():
45 | self.init_qpos[joint_idx] = self.target_qpos[joint_name][0]
46 | self.init_qvel[joint_idx] = self.target_qvel[joint_name][0]
47 | self.reset_model()
48 |
49 | # We need to update equality constraints here
50 | update_equality_constraint(self.model, self.target_qpos)
51 |
52 | self.initialised = True
53 |
54 | def step(self, a):
55 |
56 | if not self.initialised:
57 | return self._get_obs(), 0, False, {}
58 |
59 | # Step forward
60 | self.do_simulation(a, self.frame_skip)
61 |
62 | # Cost is difference between target and simulated qpos and qvel
63 | e_qpos = 0
64 | e_qvel = 0
65 | for joint_idx, joint_name in self.model._joint_id2name.items():
66 | e_qpos += np.abs(self.target_qpos[joint_name].iloc[self._step_idx.value] - self.data.qpos[joint_idx])
67 | e_qvel += np.abs(self.target_qvel[joint_name].iloc[self._step_idx.value] - self.data.qvel[joint_idx])
68 |
69 | cost = np.square(e_qpos) + 0.001*np.square(e_qvel) + 0.001*np.matmul(a,a)
70 |
71 | return self._get_obs(), -cost, False, {"e_qpos": e_qpos, "e_qvel": e_qvel}
72 |
73 | def _get_obs(self):
74 | """DIFFERENT FROM ORIGINAL GYM"""
75 | qpos = self.sim.data.qpos
76 | qvel = self.sim.data.qvel
77 | return np.concatenate([qpos.flat, qvel.flat])
78 |
79 | def reset_model(self):
80 | # Reset to initial pose?
81 | #self.set_state(
82 | # self.init_qpos + self.np_random.uniform(low=-1, high=1, size=self.model.nq),
83 | # self.init_qvel + self.np_random.uniform(low=-1, high=1, size=self.model.nv)
84 | #)
85 | self.set_state(self.init_qpos, self.init_qvel)
86 | return self._get_obs()
87 |
88 | @staticmethod
89 | def is_done(state):
90 | done = False
91 | return done
92 |
--------------------------------------------------------------------------------
/mujoco/envs/swimmer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from gym import utils
5 | from gym.envs.mujoco import mujoco_env
6 |
7 |
8 | class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
9 | """
10 | COPIED FROM GYM. W/ SLIGHT MODIFICATIONS:
11 | * READING FROM OWN .XML.
12 | * FULL STATE OBSERVATIONS, I.E. QPOS CONCAT'D WITH QVEL.
13 | * is_done METHOD SHOULD BE IMPLEMENTED
14 | """
15 | def __init__(self, cfg):
16 | self.cfg = cfg
17 | self.frame_skip = self.cfg.MODEL.FRAME_SKIP
18 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/")
19 | mujoco_env.MujocoEnv.__init__(self, os.path.join(mujoco_assets_dir, "swimmer.xml"), self.frame_skip)
20 | utils.EzPickle.__init__(self)
21 |
22 | def step(self, a):
23 | ctrl_cost_coeff = 0.0001
24 | xposbefore = self.sim.data.qpos[0]
25 | self.do_simulation(a, self.frame_skip)
26 | xposafter = self.sim.data.qpos[0]
27 | reward_fwd = (xposafter - xposbefore) / self.dt
28 | reward_ctrl = - ctrl_cost_coeff * np.square(a).sum()
29 | reward = reward_fwd + reward_ctrl
30 |
31 | ob = self._get_obs()
32 | return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl)
33 |
34 | def _get_obs(self):
35 | """DIFFERENT FROM ORIGINAL GYM"""
36 | qpos = self.sim.data.qpos
37 | qvel = self.sim.data.qvel
38 | return np.concatenate([qpos.flat, qvel.flat])
39 |
40 | def reset_model(self):
41 | self.sim.reset()
42 | if self.cfg.MODEL.POLICY.NETWORK:
43 | self.set_state(
44 | self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq),
45 | self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv)
46 | )
47 | else:
48 | self.set_state(self.init_qpos, self.init_qvel)
49 | return self._get_obs()
50 |
51 | @staticmethod
52 | def is_done(state):
53 | done = False
54 | return done
55 |
56 | def tensor_reward(self, state, action, next_state):
57 | """DIFFERENT FROM ORIGINAL GYM"""
58 | ctrl_cost_coeff = 0.0001
59 | xposbefore = state[0]
60 | xposafter = next_state[0]
61 | reward_fwd = (xposafter - xposbefore) / self.dt
62 | reward_ctrl = - ctrl_cost_coeff * torch.sum(torch.mul(action, action))
63 | reward = reward_fwd + reward_ctrl
64 | return reward.view([1, ])
65 |
--------------------------------------------------------------------------------
/mujoco/envs/walker2d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym import utils
3 | from gym.envs.mujoco import mujoco_env
4 | import os
5 | import math
6 |
7 |
8 | class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
9 | """
10 | COPIED FROM GYM. W/ SLIGHT MODIFICATIONS
11 | """
12 |
13 | def __init__(self, cfg):
14 | self.cfg = cfg
15 | self.frame_skip = self.cfg.MODEL.FRAME_SKIP
16 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/")
17 | self.initialised = False
18 | mujoco_env.MujocoEnv.__init__(self, os.path.join(mujoco_assets_dir, "walker2d.xml"), self.frame_skip)
19 | utils.EzPickle.__init__(self)
20 | self.initialised = True
21 |
22 | def step(self, a):
23 | posbefore = self.sim.data.qpos[0]
24 | self.do_simulation(a, self.frame_skip)
25 | posafter, height, ang = self.sim.data.qpos[0:3]
26 | alive_bonus = 1.0
27 | reward = ((posafter - posbefore) / self.dt)
28 | coeff = min(max(height/0.8, 0), 1)*0.5 + max(((math.pi - abs(ang))/math.pi), 1)*0.5
29 | reward += coeff*alive_bonus
30 | reward -= 1e-3 * np.square(a).sum()
31 | done = not (height > 0.8 and height < 2.0 and
32 | ang > -1.0 and ang < 1.0) and self.initialised
33 | ob = self._get_obs()
34 | #if not done:
35 | # reward += alive_bonus
36 | return ob, reward, False, {}
37 |
38 | def _get_obs(self):
39 | """DIFFERENT FROM ORIGINAL GYM"""
40 | return np.concatenate([
41 | self.sim.data.qpos.flat, # this part different from gym. expose the whole thing.
42 | self.sim.data.qvel.flat, # this part different from gym. clip nothing.
43 | ])
44 |
45 | def reset_model(self):
46 | self.sim.reset()
47 | #self.set_state(
48 | # self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq),
49 | # self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
50 | #)
51 | self.set_state(self.init_qpos, self.init_qvel)
52 | return self._get_obs()
53 |
54 | def viewer_setup(self):
55 | self.viewer.cam.trackbodyid = 2
56 | self.viewer.cam.distance = self.model.stat.extent * 0.5
57 | self.viewer.cam.lookat[2] = 1.15
58 | self.viewer.cam.elevation = -20
--------------------------------------------------------------------------------
/mujoco/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .wrappers import MjBlockWrapper, RewardScaleWrapper, ClipActionsWrapper, FixedStateWrapper, TorchTensorWrapper
2 |
3 | __all__ = ["MjBlockWrapper", "RewardScaleWrapper", "ClipActionsWrapper", "FixedStateWrapper", "TorchTensorWrapper"]
4 |
--------------------------------------------------------------------------------
/mujoco/utils/backward.py:
--------------------------------------------------------------------------------
1 | # TODO: Breaks down in case of quaternions, i.e. free 3D bodies or ball joints.
2 |
3 | import mujoco_py as mj
4 | import numpy as np
5 | import torch
6 | from copy import deepcopy
7 |
8 | # ============================================================
9 | # CONFIG
10 | # ============================================================
11 | niter = 30
12 | nwarmup = 3
13 | eps = 1e-6
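   | # niter is only referenced by the commented-out solver setup in mj_gradients_factory below; nwarmup is
   | # the number of extra mj_forwardSkip calls used to improve the qacc warmstart at the centre point;
   | # eps is the finite-difference perturbation applied to ctrl / qvel / qpos.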
14 |
15 |
16 | def copy_data(m, d_source, d_dest):
17 | d_dest.time = d_source.time
18 | d_dest.qpos[:] = d_source.qpos
19 | d_dest.qvel[:] = d_source.qvel
20 | d_dest.qacc[:] = d_source.qacc
21 | d_dest.qacc_warmstart[:] = d_source.qacc_warmstart
22 | d_dest.qfrc_applied[:] = d_source.qfrc_applied
23 | for i in range(m.nbody):
24 | d_dest.xfrc_applied[i][:] = d_source.xfrc_applied[i]
25 | d_dest.ctrl[:] = d_source.ctrl
26 |
27 | # copy d_source.act? probably need to when dealing with muscle actuators
28 | #d_dest.act = d_source.act
29 |
30 |
31 | def initialise_simulation(d, time, qpos, qvel, qacc_warmstart, ctrl, act=None):
32 | d.time = time
33 | d.qpos[:] = qpos
34 | d.qvel[:] = qvel
35 | d.qacc_warmstart[:] = qacc_warmstart
36 | d.ctrl[:] = ctrl
37 | if act is not None:
38 | d.act[:] = act
39 |
40 |
41 | def integrate_dynamics_gradients(env, dqaccdqpos, dqaccdqvel, dqaccdctrl):
42 | m = env.model
43 | nv = m.nv
44 | nu = m.nu
45 | dt = env.model.opt.timestep * env.frame_skip
46 |
47 | # dfds: d(next_state)/d(current_state): consists of four blocks ul, ur, dl, and dr
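   | # Derivation, for reference: treating one aggregated step as Euler integration over dt,
   | #     qpos' = qpos + dt * qvel
   | #     qvel' = qvel + dt * qacc(qpos, qvel, ctrl)
   | # the Jacobian of s' = [qpos', qvel'] w.r.t. s = [qpos, qvel] is, block-wise,
   | #     [ I                  dt * I                 ]      (ul  ur)
   | #     [ dt * dqacc/dqpos   I + dt * dqacc/dqvel   ]  ->  (dl  dr)
   | # a first-order approximation that ignores how qacc itself changes within the step.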
48 | ul = np.identity(nv, dtype=np.float32)
49 | ur = np.identity(nv, dtype=np.float32) * dt
50 | dl = dqaccdqpos * dt
51 | dr = np.identity(nv, dtype=np.float32) + dqaccdqvel * dt
52 | dfds = np.concatenate([np.concatenate([ul, dl], axis=0),
53 | np.concatenate([ur, dr], axis=0)],
54 | axis=1)
55 |
56 | # dfda: d(next_state)/d(action_values)
57 | dfda = np.concatenate([np.zeros([nv, nu]), dqaccdctrl * dt], axis=0)
58 | return dfds, dfda
59 |
60 |
61 | def integrate_reward_gradient(env, drdqpos, drdqvel, drdctrl):
62 | return np.concatenate([np.array(drdqpos).reshape(1, env.model.nq),
63 | np.array(drdqvel).reshape(1, env.model.nv)], axis=1), \
64 | np.array(drdctrl).reshape(1, env.model.nu)
65 |
66 |
67 | def dynamics_worker(env, d):
68 | m = env.sim.model
69 | dmain = env.sim.data
70 |
71 | dqaccdqpos = [None] * m.nv * m.nv
72 | dqaccdqvel = [None] * m.nv * m.nv
73 | dqaccdctrl = [None] * m.nv * m.nu
74 |
75 | # copy state and control from dmain to thread-specific d
76 | copy_data(m, dmain, d)
77 |
78 | # is_forward
79 | mj.functions.mj_forward(m, d)
80 |
81 | # extra solver iterations to improve warmstart (qacc) at center point
82 | for rep in range(nwarmup):
83 | mj.functions.mj_forwardSkip(m, d, mj.const.STAGE_VEL, 1)
84 |
85 | # select output from forward dynamics
86 | output = d.qacc # always differentiate qacc
87 |
88 | # save output for center point and warmstart (needed in forward only)
89 | center = output.copy()
90 | warmstart = d.qacc_warmstart.copy()
91 |
92 | # finite-difference over control values: skip = mjSTAGE_VEL
93 | for i in range(m.nu):
94 |
95 | # perturb selected target
96 | d.ctrl[i] += eps
97 |
98 | # evaluate dynamics, with center warmstart
99 | mj.functions.mju_copy(d.qacc_warmstart, warmstart, m.nv)
100 | mj.functions.mj_forwardSkip(m, d, mj.const.STAGE_VEL, 1)
101 |
102 | # undo perturbation
103 | d.ctrl[i] = dmain.ctrl[i]
104 |
105 | # compute column i of derivative 2
106 | for j in range(m.nv):
107 | dqaccdctrl[i + j * m.nu] = (output[j] - center[j]) / eps
108 |
109 | # finite-difference over velocity: skip = mjSTAGE_POS
110 | for i in range(m.nv):
111 |
112 | # perturb velocity
113 | d.qvel[i] += eps
114 |
115 | # evaluate dynamics, with center warmstart
116 | mj.functions.mju_copy(d.qacc_warmstart, warmstart, m.nv)
117 | mj.functions.mj_forwardSkip(m, d, mj.const.STAGE_POS, 1)
118 |
119 | # undo perturbation
120 | d.qvel[i] = dmain.qvel[i]
121 |
122 | # compute column i of derivative 1
123 | for j in range(m.nv):
124 | dqaccdqvel[i + j * m.nv] = (output[j] - center[j]) / eps
125 |
126 | # finite-difference over position: skip = mjSTAGE_NONE
127 | for i in range(m.nv):
128 |
129 | # get joint id for this dof
130 | jid = m.dof_jntid[i]
131 |
132 | # get quaternion address and dof position within quaternion (-1: not in quaternion)
133 | quatadr = -1
134 | dofpos = 0
135 | if m.jnt_type[jid] == mj.const.JNT_BALL:
136 | quatadr = m.jnt_qposadr[jid]
137 | dofpos = i - m.jnt_dofadr[jid]
138 | elif m.jnt_type[jid] == mj.const.JNT_FREE and i >= m.jnt_dofadr[jid] + 3:
139 | quatadr = m.jnt_qposadr[jid] + 3
140 | dofpos = i - m.jnt_dofadr[jid] - 3
141 |
142 | # apply quaternion or simple perturbation
143 | if quatadr >= 0:
144 | angvel = np.array([0., 0., 0.])
145 | angvel[dofpos] = eps
146 | mj.functions.mju_quatIntegrate(d.qpos + quatadr, angvel, 1)
147 | else:
148 | d.qpos[m.jnt_qposadr[jid] + i - m.jnt_dofadr[jid]] += eps
149 |
150 | # evaluate dynamics, with center warmstart
151 | mj.functions.mju_copy(d.qacc_warmstart, warmstart, m.nv)
152 | mj.functions.mj_forwardSkip(m, d, mj.const.STAGE_NONE, 1)
153 |
154 | # undo perturbation
155 | mj.functions.mju_copy(d.qpos, dmain.qpos, m.nq)
156 |
157 | # compute column i of derivative 0
158 | for j in range(m.nv):
159 | dqaccdqpos[i + j * m.nv] = (output[j] - center[j]) / eps
160 |
161 | dqaccdqpos = np.array(dqaccdqpos).reshape(m.nv, m.nv)
162 | dqaccdqvel = np.array(dqaccdqvel).reshape(m.nv, m.nv)
163 | dqaccdctrl = np.array(dqaccdctrl).reshape(m.nv, m.nu)
164 | dfds, dfda = integrate_dynamics_gradients(env, dqaccdqpos, dqaccdqvel, dqaccdctrl)
165 | return dfds, dfda
166 |
167 |
168 | def dynamics_worker_separate(env):
169 | m = env.sim.model
170 | d = env.sim.data
171 |
172 | dsdctrl = np.empty((m.nq + m.nv, m.nu))
173 | dsdqpos = np.empty((m.nq + m.nv, m.nq))
174 | dsdqvel = np.empty((m.nq + m.nv, m.nv))
175 |
176 | # Copy initial state
177 | time_initial = d.time
178 | qvel_initial = d.qvel.copy()
179 | qpos_initial = d.qpos.copy()
180 | ctrl_initial = d.ctrl.copy()
181 | qacc_warmstart_initial = d.qacc_warmstart.copy()
182 | act_initial = d.act.copy() if d.act is not None else None
183 |
184 | # Step with the main simulation
185 | mj.functions.mj_step(m, d)
186 |
187 | # Get qpos, qvel of the main simulation
188 | qpos = d.qpos.copy()
189 | qvel = d.qvel.copy()
190 |
191 | # finite-difference over control values: skip = mjSTAGE_VEL
192 | for i in range(m.nu):
193 |
194 | # Initialise simulation
195 | initialise_simulation(d, time_initial, qpos_initial, qvel_initial, qacc_warmstart_initial, ctrl_initial, act_initial)
196 |
197 | # Perturb control
198 | d.ctrl[i] += eps
199 |
200 | # Step with perturbed simulation
201 | mj.functions.mj_step(m, d)
202 |
203 | # Compute gradients of qpos and qvel wrt control
204 | dsdctrl[:m.nq, i] = (d.qpos - qpos) / eps
205 | dsdctrl[m.nq:, i] = (d.qvel - qvel) / eps
206 |
207 | # finite-difference over velocity: skip = mjSTAGE_POS
208 | for i in range(m.nv):
209 |
210 | # Initialise simulation
211 | initialise_simulation(d, time_initial, qpos_initial, qvel_initial, qacc_warmstart_initial, ctrl_initial, act_initial)
212 |
213 | # Perturb velocity
214 | d.qvel[i] += eps
215 |
216 | # Step with perturbed simulation
217 | mj.functions.mj_step(m, d)
218 |
219 | # Compute gradients of qpos and qvel wrt qvel
220 | dsdqvel[:m.nq, i] = (d.qpos - qpos) / eps
221 | dsdqvel[m.nq:, i] = (d.qvel - qvel) / eps
222 |
223 | # finite-difference over position: skip = mjSTAGE_NONE
224 | for i in range(m.nq):
225 |
226 | # Initialise simulation
227 | initialise_simulation(d, time_initial, qpos_initial, qvel_initial, qacc_warmstart_initial, ctrl_initial, act_initial)
228 |
229 | # Get joint id for this dof
230 | jid = m.dof_jntid[i]
231 |
232 | # Get quaternion address and dof position within quaternion (-1: not in quaternion)
233 | quatadr = -1
234 | dofpos = 0
235 | if m.jnt_type[jid] == mj.const.JNT_BALL:
236 | quatadr = m.jnt_qposadr[jid]
237 | dofpos = i - m.jnt_dofadr[jid]
238 | elif m.jnt_type[jid] == mj.const.JNT_FREE and i >= m.jnt_dofadr[jid] + 3:
239 | quatadr = m.jnt_qposadr[jid] + 3
240 | dofpos = i - m.jnt_dofadr[jid] - 3
241 |
242 | # Apply quaternion or simple perturbation
243 | if quatadr >= 0:
244 | angvel = np.array([0., 0., 0.])
245 | angvel[dofpos] = eps
246 | mj.functions.mju_quatIntegrate(d.qpos + quatadr, angvel, 1)
247 | else:
248 | d.qpos[m.jnt_qposadr[jid] + i - m.jnt_dofadr[jid]] += eps
249 |
250 | # Step simulation with perturbed position
251 | mj.functions.mj_step(m, d)
252 |
253 | # Compute gradients of qpos and qvel wrt qpos
254 | dsdqpos[:m.nq, i] = (d.qpos - qpos) / eps
255 | dsdqpos[m.nq:, i] = (d.qvel - qvel) / eps
256 |
257 | return np.concatenate((dsdqpos, dsdqvel), axis=1), dsdctrl
258 |
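   | # Note on the two schemes in this file: dynamics_worker differentiates qacc via mj_forwardSkip and then
   | # integrates those derivatives analytically (integrate_dynamics_gradients), whereas
   | # dynamics_worker_separate and calculate_gradients finite-difference the stepped state directly
   | # (mj_step here, the full agent.step including frame skipping in calculate_gradients), which folds the
   | # integrator into the difference at the cost of extra simulation work.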
259 |
260 | def reward_worker(env, d):
261 | m = env.sim.model
262 | dmain = env.sim.data
263 |
264 | drdqpos = [None] * m.nv
265 | drdqvel = [None] * m.nv
266 | drdctrl = [None] * m.nu
267 |
268 | # copy state and control from dmain to thread-specific d
269 | copy_data(m, dmain, d)
270 |
271 | # is_forward
272 | mj.functions.mj_forward(m, d)
273 |
274 | # extra solver iterations to improve warmstart (qacc) at center point
275 | for rep in range(nwarmup):
276 | mj.functions.mj_forwardSkip(m, d, mj.const.STAGE_VEL, 1)
277 |
278 | # get center reward
279 | _, center, _, _ = env.step(d.ctrl)
280 | copy_data(m, d, dmain) # revert changes to state and forces
281 |
282 | # finite-difference over control values
283 | for i in range(m.nu):
284 | # perturb selected target
285 | d.ctrl[i] += eps
286 |
287 | _, output, _, _ = env.step(d.ctrl)
288 | copy_data(m, d, dmain) # undo perturbation
289 |
290 | drdctrl[i] = (output - center) / eps
291 |
292 | # finite-difference over velocity
293 | for i in range(m.nv):
294 | # perturb velocity
295 | d.qvel[i] += eps
296 |
297 | _, output, _, _ = env.step(d.ctrl)
298 | copy_data(m, d, dmain) # undo perturbation
299 |
300 | drdqvel[i] = (output - center) / eps
301 |
302 | # finite-difference over position: skip = mjSTAGE_NONE
303 | for i in range(m.nv):
304 |
305 | # get joint id for this dof
306 | jid = m.dof_jntid[i]
307 |
308 | # get quaternion address and dof position within quaternion (-1: not in quaternion)
309 | quatadr = -1
310 | dofpos = 0
311 | if m.jnt_type[jid] == mj.const.JNT_BALL:
312 | quatadr = m.jnt_qposadr[jid]
313 | dofpos = i - m.jnt_dofadr[jid]
314 | elif m.jnt_type[jid] == mj.const.JNT_FREE and i >= m.jnt_dofadr[jid] + 3:
315 | quatadr = m.jnt_qposadr[jid] + 3
316 | dofpos = i - m.jnt_dofadr[jid] - 3
317 |
318 | # apply quaternion or simple perturbation
319 | if quatadr >= 0:
320 | angvel = np.array([0., 0., 0.])
321 | angvel[dofpos] = eps
322 | mj.functions.mju_quatIntegrate(d.qpos + quatadr, angvel, 1)
323 | else:
324 | d.qpos[m.jnt_qposadr[jid] + i - m.jnt_dofadr[jid]] += eps
325 |
326 | _, output, _, _ = env.step(d.ctrl)
327 | copy_data(m, d, dmain) # undo perturbation
328 |
329 | # compute column i of derivative 0
330 | drdqpos[i] = (output - center) / eps
331 |
332 | drds, drda = integrate_reward_gradient(env, drdqpos, drdqvel, drdctrl)
333 | return drds, drda
334 |
335 |
336 | def calculate_reward(env, qpos, qvel, ctrl, qpos_next, qvel_next):
337 | current_state = np.concatenate((qpos, qvel))
338 | next_state = np.concatenate((qpos_next, qvel_next))
339 | reward = env.tensor_reward(torch.DoubleTensor(current_state), torch.DoubleTensor(ctrl), torch.DoubleTensor(next_state))
340 | return reward.detach().numpy()
341 |
342 |
343 | def calculate_gradients(agent, data_snapshot, next_state, reward, test=False):
344 | # Defining m and d just for shorter notations
345 | m = agent.model
346 | d = agent.data
347 |
348 | # Dynamics gradients
349 | dsdctrl = np.empty((m.nq + m.nv, m.nu))
350 | dsdqpos = np.empty((m.nq + m.nv, m.nq))
351 | dsdqvel = np.empty((m.nq + m.nv, m.nv))
352 |
353 | # Reward gradients
354 | drdctrl = np.empty((1, m.nu))
355 | drdqpos = np.empty((1, m.nq))
356 | drdqvel = np.empty((1, m.nv))
357 |
358 | # Get number of steps (must be >=2 for muscles)
359 | nsteps = agent.cfg.MODEL.NSTEPS_FOR_BACKWARD
360 |
361 | # For testing purposes
362 | if test:
363 |
364 | # Reset simulation to snapshot
365 | agent.set_snapshot(data_snapshot)
366 |
367 | # Step with the main simulation
368 | info = agent.step(d.ctrl.copy())
369 |
370 | # Sanity check. "reward" must equal info[1], otherwise this simulation has diverged from the forward pass
371 | assert reward == info[1], "reward is different from forward pass [{} != {}] at timepoint {}".format(reward, info[1], data_snapshot.time)
372 |
373 | # Another check. "next_state" must equal info[0], otherwise this simulation has diverged from the forward pass
374 | assert (next_state == info[0]).all(), "state is different from forward pass"
375 |
376 | # Get state from the forward pass
377 | if nsteps > 1:
378 | agent.set_snapshot(data_snapshot)
379 | for _ in range(nsteps):
380 | info = agent.step(d.ctrl.copy())
381 | qpos_fwd = info[0][:agent.model.nq]
382 | qvel_fwd = info[0][agent.model.nq:]
383 | reward = info[1]
384 | else:
385 | qpos_fwd = next_state[:agent.model.nq]
386 | qvel_fwd = next_state[agent.model.nq:]
387 |
388 | # finite-difference over control values
389 | for i in range(m.nu):
390 |
391 | # Initialise simulation
392 | agent.set_snapshot(data_snapshot)
393 |
394 | # Perturb control
395 | d.ctrl[i] += eps
396 |
397 | # Step with perturbed simulation
398 | for _ in range(nsteps):
399 | info = agent.step(d.ctrl.copy())
400 |
401 | # Compute gradient of state wrt control
402 | dsdctrl[:m.nq, i] = (d.qpos - qpos_fwd) / eps
403 | dsdctrl[m.nq:, i] = (d.qvel - qvel_fwd) / eps
404 |
405 | # Compute gradient of reward wrt to control
406 | drdctrl[0, i] = (info[1] - reward) / eps
407 |
408 | # finite-difference over velocity
409 | for i in range(m.nv):
410 |
411 | # Initialise simulation
412 | agent.set_snapshot(data_snapshot)
413 |
414 | # Perturb velocity
415 | d.qvel[i] += eps
416 |
417 | # Calculate new ctrl (if it's dependent on state)
418 | if agent.cfg.MODEL.POLICY.NETWORK:
419 | d.ctrl[:] = agent.policy_net(torch.from_numpy(np.concatenate((d.qpos, d.qvel))).float()).double().detach().numpy()
420 |
421 | # Step with perturbed simulation
422 | for _ in range(nsteps):
423 | info = agent.step(d.ctrl)
424 |
425 | # Compute gradient of state wrt qvel
426 | dsdqvel[:m.nq, i] = (d.qpos - qpos_fwd) / eps
427 | dsdqvel[m.nq:, i] = (d.qvel - qvel_fwd) / eps
428 |
429 | # Compute gradient of reward wrt qvel
430 | drdqvel[0, i] = (info[1] - reward) / eps
431 |
432 | # finite-difference over position
433 | for i in range(m.nq):
434 |
435 | # Initialise simulation
436 | agent.set_snapshot(data_snapshot)
437 |
438 | # Get joint id for this dof
439 | jid = m.dof_jntid[i]
440 |
441 | # Get quaternion address and dof position within quaternion (-1: not in quaternion)
442 | quatadr = -1
443 | dofpos = 0
444 | if m.jnt_type[jid] == mj.const.JNT_BALL:
445 | quatadr = m.jnt_qposadr[jid]
446 | dofpos = i - m.jnt_dofadr[jid]
447 | elif m.jnt_type[jid] == mj.const.JNT_FREE and i >= m.jnt_dofadr[jid] + 3:
448 | quatadr = m.jnt_qposadr[jid] + 3
449 | dofpos = i - m.jnt_dofadr[jid] - 3
450 |
451 | # Apply quaternion or simple perturbation
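   | # Quaternion components cannot be perturbed additively (they must stay on the unit sphere), so for
   | # ball/free joints the orientation is nudged by integrating an angular velocity of magnitude eps about
   | # the dof axis with mju_quatIntegrate; per the TODO at the top of this file, this path is still
   | # considered broken for quaternion joints.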
452 | if quatadr >= 0:
453 | angvel = np.array([0., 0., 0.])
454 | angvel[dofpos] = eps
455 | mj.functions.mju_quatIntegrate(d.qpos + quatadr, angvel, 1)
456 | else:
457 | d.qpos[m.jnt_qposadr[jid] + i - m.jnt_dofadr[jid]] += eps
458 |
459 | # Calculate new ctrl (if it's dependent on state)
460 | if agent.cfg.MODEL.POLICY.NETWORK:
461 | d.ctrl[:] = agent.policy_net(torch.from_numpy(np.concatenate((d.qpos, d.qvel))).float()).double().detach().numpy()
462 |
463 | # Step simulation with perturbed position
464 | for _ in range(nsteps):
465 | info = agent.step(d.ctrl)
466 |
467 | # Compute gradient of state wrt qpos
468 | dsdqpos[:m.nq, i] = (d.qpos - qpos_fwd) / eps
469 | dsdqpos[m.nq:, i] = (d.qvel - qvel_fwd) / eps
470 |
471 | # Compute gradient of reward wrt qpos
472 | drdqpos[0, i] = (info[1] - reward) / eps
473 |
474 | # Set dynamics gradients
475 | agent.dynamics_gradients = {"state": np.concatenate((dsdqpos, dsdqvel), axis=1), "action": dsdctrl}
476 |
477 | # Set reward gradients
478 | agent.reward_gradients = {"state": np.concatenate((drdqpos, drdqvel), axis=1), "action": drdctrl}
479 |
480 | return
481 |
482 |
483 | def mj_gradients_factory(agent, mode):
484 | """
485 | :param agent: wrapped gym.envs.mujoco.mujoco_env.MujocoEnv
486 | :param mode: 'dynamics' or 'reward'
487 | :return:
488 | """
489 | #mj_sim_main = env.sim
490 | #mj_sim = mj.MjSim(mj_sim_main.model)
491 |
492 | #worker = {'dynamics': reward_worker, 'reward': reward_worker}[mode]
493 |
494 | @agent.gradient_wrapper(mode)
495 | def mj_gradients(data_snapshot, next_state, reward, test=False):
496 | #state = state_action[:env.model.nq + env.model.nv]
497 | #qpos = state[:env.model.nq]
498 | #qvel = state[env.model.nq:]
499 | #ctrl = state_action[-env.model.nu:]
500 | #env.set_state(qpos, qvel)
501 | #env.data.ctrl[:] = ctrl
502 | #d = mj_sim.data
503 | # set solver options for finite differences
504 | #mj_sim_main.model.opt.iterations = niter
505 | #mj_sim_main.model.opt.tolerance = 0
506 | # env.sim.model.opt.iterations = niter
507 | # env.sim.model.opt.tolerance = 0
508 | #dfds, dfda = worker(env)
509 |
510 | calculate_gradients(agent, data_snapshot, next_state, reward, test=test)
511 |
512 | return mj_gradients
513 |
--------------------------------------------------------------------------------
/mujoco/utils/forward.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def mj_forward_factory(agent, mode):
6 | """
7 | :param agent: wrapped gym.envs.mujoco.mujoco_env.MujocoEnv
8 | :param mode: 'dynamics' or 'reward'
9 | :return:
10 | """
11 |
12 | @agent.forward_wrapper(mode)
13 | def mj_forward(action=None):
14 | """
15 | :param action: np.array of action -- if missing, agent.data.ctrl is used as action
16 | :return:
17 | """
18 |
19 | # If action wasn't explicitly given use the current one in agent's env
20 | if action is None:
21 | action = agent.data.ctrl
22 |
23 | # Convert tensor to numpy array, and make sure we're using an action value that isn't referencing agent's data
24 | # (otherwise we might get into trouble if ctrl is limited and frame_skip is larger than one)
25 | if isinstance(action, torch.Tensor):
26 | action = action.detach().numpy().copy()
27 | elif isinstance(action, np.ndarray):
28 | action = action.copy()
29 | else:
30 | raise TypeError("Expecting a torch tensor or numpy ndarray")
31 |
32 | # Make sure dtype is numpy.float64
33 | assert action.dtype == np.float64, "You must use dtype numpy.float64 for actions"
34 |
35 | # Advance simulation with one step
36 | next_state, agent.reward, agent.is_done, _ = agent.step(action)
37 |
38 | return next_state
39 |
40 | return mj_forward
41 |
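   | # Rough usage sketch (hedged; the real call sites live in the policy/engine code, and agent is assumed
   | # to expose the forward_wrapper decorator used above):
   | #     forward_dynamics = mj_forward_factory(agent, "dynamics")
   | #     next_state = forward_dynamics(action)   # action: np.float64 array or torch tensor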
--------------------------------------------------------------------------------
/mujoco/utils/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | from .mj_block import MjBlockWrapper
2 | from .etc import AccumulateWrapper, RewardScaleWrapper, ClipActionsWrapper, FixedStateWrapper, TorchTensorWrapper
3 |
4 | __all__ = ["MjBlockWrapper",
5 | "AccumulateWrapper",
6 | "RewardScaleWrapper",
7 | "ClipActionsWrapper",
8 | "FixedStateWrapper",
9 | "TorchTensorWrapper"]
10 |
--------------------------------------------------------------------------------
/mujoco/utils/wrappers/etc.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import torch
3 | import numpy as np
4 | from copy import deepcopy
5 | from utils.index import Index
6 | import mujoco_py
7 | from multiprocessing import Process, Queue
8 | import skvideo.io
9 | import glfw
10 |
11 |
12 | class AccumulateWrapper(gym.Wrapper):
13 |
14 | def __init__(self, env, gamma):
15 | super(AccumulateWrapper, self).__init__(env)
16 | self.gamma = gamma
17 | self.accumulated_return = 0.
18 | self.accumulated_observations = []
19 | self.accumulated_rewards = []
20 |
21 | def step(self, action):
22 | observation, reward, done, info = self.env.step(action)
23 | self.accumulated_observations.append(observation)
24 | self.accumulated_rewards.append(reward)
25 | self.accumulated_return = self.gamma * self.accumulated_return + reward
   | return observation, reward, done, info
26 |
27 | def reset(self, **kwargs):
28 | self.accumulated_return = 0.
29 | self.accumulated_observations = []
30 | self.accumulated_rewards = []
31 | return torch.Tensor(self.env.reset(**kwargs))
32 |
33 |
34 | class RewardScaleWrapper(gym.RewardWrapper):
35 | """Bring rewards to a reasonable scale."""
36 |
37 | def __init__(self, env, scale=0.01):
38 | super(RewardScaleWrapper, self).__init__(env)
39 | self.scale = scale
40 |
41 | def reward(self, reward):
42 | return self.scale * reward
43 |
44 | def torch_reward(self, state, action, next_state):
45 | return self.scale * self.env.torch_reward(state, action, next_state)
46 |
47 |
48 | class ClipActionsWrapper(gym.Wrapper):
49 | def step(self, action):
50 | action = np.nan_to_num(action) # less error prone but your optimization won't benefit from this
51 | action = np.clip(action, self.action_space.low, self.action_space.high)
52 | return self.env.step(action)
53 |
54 | def reset(self, **kwargs):
55 | return self.env.reset(**kwargs)
56 |
57 |
58 | class FixedStateWrapper(gym.Wrapper):
59 | def __init__(self, env):
60 | super(FixedStateWrapper, self).__init__(env)
61 | self.fixed_state = self.env.reset_model()
62 | self.fixed_qpos = self.env.sim.data.qpos
63 | self.fixed_qvel = self.env.sim.data.qvel
64 |
65 | def reset(self, **kwargs):
66 | self.env.reset(**kwargs)
67 | self.env.set_state(self.fixed_qpos, self.fixed_qvel)
68 | return self.env.env._get_obs()
69 |
70 |
71 | class TorchTensorWrapper(gym.Wrapper):
72 | """Takes care of torch Tensors in step and reset modules."""
73 |
74 | #def step(self, action):
75 | # action = action.detach().numpy()
76 | # state, reward, done, info = self.env.step(action)
77 | # return torch.Tensor(state), reward, done, info
78 |
79 | def reset(self, **kwargs):
80 | return torch.Tensor(self.env.reset(**kwargs))
81 |
82 | #def set_from_torch_state(self, state):
83 | # qpos, qvel = np.split(state.detach().numpy(), 2)
84 | # self.env.set_state(qpos, qvel)
85 |
86 | #def is_done(self, state):
87 | # state = state.detach().numpy()
88 | # return self.env.is_done(state)
89 |
90 |
91 | class SnapshotWrapper(gym.Wrapper):
92 | """Handles all stateful stuff, like getting and setting snapshots of states, and resetting"""
93 | def get_snapshot(self):
94 |
95 | class DataSnapshot:
96 | # Note: You should not modify these parameters after creation
97 |
98 | def __init__(self, d_source, step_idx):
99 | self.time = deepcopy(d_source.time)
100 | self.qpos = deepcopy(d_source.qpos)
101 | self.qvel = deepcopy(d_source.qvel)
102 | self.qacc_warmstart = deepcopy(d_source.qacc_warmstart)
103 | self.ctrl = deepcopy(d_source.ctrl)
104 | self.act = deepcopy(d_source.act)
105 | self.qfrc_applied = deepcopy(d_source.qfrc_applied)
106 | self.xfrc_applied = deepcopy(d_source.xfrc_applied)
107 |
108 | self.step_idx = deepcopy(step_idx)
109 |
110 | # These probably aren't strictly necessary, but they should keep the bodies in the same pose
111 | # with respect to the world frame
112 | self.body_xpos = deepcopy(d_source.body_xpos)
113 | self.body_xquat = deepcopy(d_source.body_xquat)
114 |
115 | return DataSnapshot(self.env.sim.data, self.get_step_idx())
116 |
117 | def set_snapshot(self, snapshot_data):
118 | self.env.sim.data.time = deepcopy(snapshot_data.time)
119 | self.env.sim.data.qpos[:] = deepcopy(snapshot_data.qpos)
120 | self.env.sim.data.qvel[:] = deepcopy(snapshot_data.qvel)
121 | self.env.sim.data.qacc_warmstart[:] = deepcopy(snapshot_data.qacc_warmstart)
122 | self.env.sim.data.ctrl[:] = deepcopy(snapshot_data.ctrl)
123 | if snapshot_data.act is not None:
124 | self.env.sim.data.act[:] = deepcopy(snapshot_data.act)
125 | self.env.sim.data.qfrc_applied[:] = deepcopy(snapshot_data.qfrc_applied)
126 | self.env.sim.data.xfrc_applied[:] = deepcopy(snapshot_data.xfrc_applied)
127 |
128 | self.set_step_idx(snapshot_data.step_idx)
129 |
130 | self.env.sim.data.body_xpos[:] = deepcopy(snapshot_data.body_xpos)
131 | self.env.sim.data.body_xquat[:] = deepcopy(snapshot_data.body_xquat)
132 |
133 |
134 | class IndexWrapper(gym.Wrapper):
135 | """Counts steps and episodes"""
136 |
137 | def __init__(self, env, batch_size):
138 | super(IndexWrapper, self).__init__(env)
139 | self._step_idx = Index(0)
140 | self._episode_idx = Index(0)
141 | self._batch_idx = Index(0)
142 | self._batch_size = batch_size
143 |
144 |         # Add references to the unwrapped env because we're going to need at least step_idx there
145 |         # NOTE! This means we must never rebind self._step_idx, self._episode_idx, or self._batch_idx,
146 |         # or we lose the shared reference
147 | env.unwrapped._step_idx = self._step_idx
148 | env.unwrapped._episode_idx = self._episode_idx
149 | env.unwrapped._batch_idx = self._batch_idx
150 |
151 | def step(self, action):
152 | self._step_idx += 1
153 | return self.env.step(action)
154 |
155 | def reset(self, update_episode_idx=True):
156 | self._step_idx.set(0)
157 |
158 | # We don't want to update episode_idx during testing
159 | if update_episode_idx:
160 | if self._episode_idx == self._batch_size:
161 | self._batch_idx += 1
162 | self._episode_idx.set(1)
163 | else:
164 | self._episode_idx += 1
165 |
166 | return self.env.reset()
167 |
168 | def get_step_idx(self):
169 | return self._step_idx
170 |
171 | def get_episode_idx(self):
172 | return self._episode_idx
173 |
174 | def get_batch_idx(self):
175 | return self._batch_idx
176 |
177 | def set_step_idx(self, idx):
178 | self._step_idx.set(idx)
179 |
180 | def set_episode_idx(self, idx):
181 | self._episode_idx.set(idx)
182 |
183 | def set_batch_idx(self, idx):
184 | self._batch_idx.set(idx)
185 |
186 |
187 | class ViewerWrapper(gym.Wrapper):
188 |
189 | def __init__(self, env):
190 | super(ViewerWrapper, self).__init__(env)
191 |
192 | # Keep params in this class to reduce clutter
193 | class Recorder:
194 | width = 1600
195 | height = 1200
196 | imgs = []
197 | record = False
198 | filepath = None
199 |
200 | self.recorder = Recorder
201 |
202 | # Check if we want to record roll-outs
203 | if self.cfg.LOG.TESTING.RECORD_VIDEO:
204 | self.recorder.record = True
205 |
206 | # Create a viewer if we're not recording
207 | else:
208 | # Initialise a MjViewer
209 | self._viewer = mujoco_py.MjViewer(self.sim)
210 | self._viewer._run_speed = 1/self.cfg.MODEL.FRAME_SKIP
211 | self.unwrapped._viewers["human"] = self._viewer
212 |
213 | def capture_frame(self):
214 | if self.recorder.record:
215 | self.recorder.imgs.append(np.flip(self.sim.render(self.recorder.width, self.recorder.height), axis=0))
216 |
217 | def start_recording(self, filepath):
218 | if self.recorder.record:
219 | self.recorder.filepath = filepath
220 | self.recorder.imgs.clear()
221 |
222 | def stop_recording(self):
223 | if self.recorder.record:
224 | writer = skvideo.io.FFmpegWriter(
225 | self.recorder.filepath, inputdict={"-s": "{}x{}".format(self.recorder.width, self.recorder.height),
226 | "-r": str(1 / (self.model.opt.timestep*self.cfg.MODEL.FRAME_SKIP))})
227 | for img in self.recorder.imgs:
228 | writer.writeFrame(img)
229 | writer.close()
230 | self.recorder.imgs.clear()
231 |
--------------------------------------------------------------------------------
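
Usage sketch for the wrappers defined above in mujoco/utils/wrappers/etc.py. This is not repository code: the "Hopper-v2" base env and the wrapping order are assumptions (the real stack is assembled in mujoco/build.py), and the classic gym API (reset() returning an observation, 4-tuple step()) is assumed throughout, as in the rest of the repository.

    import gym
    from mujoco.utils.wrappers.etc import TorchTensorWrapper, IndexWrapper, SnapshotWrapper

    # "Hopper-v2" is only a stand-in base env; this repository builds its own envs (mujoco/envs/).
    raw_env = gym.make("Hopper-v2")
    env = SnapshotWrapper(IndexWrapper(TorchTensorWrapper(raw_env), batch_size=4))

    obs = env.reset()                  # torch.Tensor observation (TorchTensorWrapper)
    snap = env.get_snapshot()          # deep copies of time/qpos/qvel/ctrl/... plus the step index
    for _ in range(10):
        obs, reward, done, info = env.step(env.action_space.sample())
    env.set_snapshot(snap)             # rewind the simulation (and the step counter) to the saved state

The point of the snapshot pair is that a rollout can be rewound to a saved MjData state for gradient evaluation without re-simulating from the start, which is how the tests in tests/test_gradients.py use it.
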
/mujoco/utils/wrappers/mj_block.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 |
4 | from mujoco.utils.forward import mj_forward_factory
5 | from mujoco.utils.backward import mj_gradients_factory
6 |
7 |
8 | class MjBlockWrapper(gym.Wrapper):
9 | """Wrap the forward and backward model. Further used for PyTorch blocks."""
10 |
11 | def __init__(self, env):
12 | gym.Wrapper.__init__(self, env)
13 |
14 | def gradient_factory(self, mode):
15 | """
16 | :param mode: 'dynamics' or 'reward'
17 | :return:
18 | """
19 | # TODO: due to dynamics and reward isolation, this isn't the
20 | # most efficient way to handle this; lazy, but simple
21 | #env = self.clone()
22 | return mj_gradients_factory(self, mode)
23 |
24 | def forward_factory(self, mode):
25 | """
26 | :param mode: 'dynamics' or 'reward'
27 | :return:
28 | """
29 | #env = self.clone()
30 | return mj_forward_factory(self, mode)
31 |
32 | def gradient_wrapper(self, mode):
33 | """
34 |         Decorator for reshaping gradients so that they match, e.g., the size of the observations.
35 | :param mode: either 'dynamics' or 'reward'
36 | :return:
37 | """
38 |
39 | # mode agnostic for now
40 | def decorator(gradients_fn):
41 | def wrapper(*args, **kwargs):
42 | #if mode == "forward":
43 | gradients_fn(*args, **kwargs)
44 | #else:
45 | # dfds, dfda = gradients_fn(*args, **kwargs)
46 | # # no further reshaping is needed for the case of hopper, also it's mode-agnostic
47 | # gradients = np.concatenate([dfds, dfda], axis=1)
48 | return
49 |
50 | return wrapper
51 |
52 | return decorator
53 |
54 | def forward_wrapper(self, mode):
55 | """
56 |         Decorator for reshaping the forward output so that it matches, e.g., the size of the observations.
57 | :param mode: either 'dynamics' or 'reward'
58 | :return:
59 | """
60 |
61 | # mode agnostic for now
62 | def decorator(forward_fn):
63 | def wrapper(*args, **kwargs):
64 | f = forward_fn(*args, **kwargs) # next state
65 | # no further reshaping is needed for the case of hopper, also it's mode-agnostic
66 | return f
67 |
68 | return wrapper
69 |
70 | return decorator
71 |
--------------------------------------------------------------------------------
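
A sketch of how the MjBlockWrapper factories are consumed, following the call pattern in tests/test_gradients.py. The column layout noted in the comments ([d/dstate | d/daction]) is the convention those tests slice with; everything else is illustrative rather than the exact wiring in model/blocks/mujoco.py.

    import numpy as np
    from model.config import get_cfg_defaults
    from mujoco import build_agent

    cfg = get_cfg_defaults()
    cfg.MUJOCO.ENV = "InvertedPendulumEnv"
    agent = build_agent(cfg)                                # env wrapped with MjBlockWrapper

    dynamics_fn = agent.forward_factory("dynamics")         # state_action -> next state
    dynamics_grad_fn = agent.gradient_factory("dynamics")   # state_action -> Jacobian

    agent.reset()
    state = np.concatenate([agent.data.qpos, agent.data.qvel])
    action = agent.action_space.sample()
    state_action = np.concatenate([state, action])

    next_state = dynamics_fn(state_action)
    jacobian = dynamics_grad_fn(state_action)               # columns ordered [d/dstate | d/daction]
    dsds = jacobian[:, :2 * agent.sim.model.nv]
    dsda = jacobian[:, -agent.sim.model.nu:]
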
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | mujoco_py
3 | numpy
4 | gym
5 | yacs
6 | visdom
7 | scikit-video  # provides the skvideo module used for video recording in mujoco/utils/wrappers/etc.py
8 | 
--------------------------------------------------------------------------------
/solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_optimizer
2 |
3 | __all__ = ["build_optimizer"]
4 |
--------------------------------------------------------------------------------
/solver/build.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def build_optimizer(cfg, named_parameters):
5 | params = []
6 |     for key, value in named_parameters:
7 |         if not value.requires_grad:
8 |             continue
9 |         # Reset the per-parameter defaults on every iteration so that the bias/std
10 |         # overrides below don't leak into the settings of later parameters
11 |         lr = cfg.SOLVER.BASE_LR
12 |         #lr = cfg.SOLVER.BASE_LR / cfg.SOLVER.BATCH_SIZE  # due to funny way gradients are batched
13 |         #weight_decay = cfg.SOLVER.WEIGHT_DECAY
14 |         weight_decay = 0.0
15 |         if "bias" in key:
16 |             lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
17 |             #weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
18 |         if any(x in key for x in ["std", "sd"]):
19 |             lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.STD_LR_FACTOR
20 |             weight_decay = cfg.SOLVER.WEIGHT_DECAY_SD
21 |         params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
22 | 
23 |     if cfg.SOLVER.OPTIMIZER == "sgd":
24 |         optimizer = torch.optim.SGD(params, cfg.SOLVER.BASE_LR)
25 |     else:
26 |         optimizer = torch.optim.Adam(params, cfg.SOLVER.BASE_LR, betas=cfg.SOLVER.ADAM_BETAS)
27 |     return optimizer
28 | 
--------------------------------------------------------------------------------
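
A usage sketch for build_optimizer with an arbitrary torch module; nn.Linear is a stand-in for the policy networks built under model/, and the cfg comes from the same get_cfg_defaults used by the tests.

    import torch
    import torch.nn as nn
    from model.config import get_cfg_defaults
    from solver import build_optimizer

    cfg = get_cfg_defaults()
    net = nn.Linear(4, 2)                                   # stand-in for the policy networks under model/
    optimizer = build_optimizer(cfg, net.named_parameters())

    # ordinary PyTorch training step
    loss = net(torch.randn(1, 4)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
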
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_gradients.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import torch
4 | import numpy as np
5 | from model.config import get_cfg_defaults
6 | from mujoco import build_agent
7 | from utils.index import Index
8 | from model import build_model
9 |
10 | class TestGradients(unittest.TestCase):
11 |
12 | def test_reward_gradients(self):
13 | cfg = get_cfg_defaults()
14 | cfg.MUJOCO.ENV = "InvertedPendulumEnv"
15 | agent = build_agent(cfg)
16 | mj_reward_forward_fn = agent.forward_factory("reward")
17 | mj_reward_gradients_fn = agent.gradient_factory("reward")
18 |
19 | nwarmup = 5
20 | agent.reset()
21 | nv = agent.sim.model.nv
22 | nu = agent.sim.model.nu
23 | for _ in range(nwarmup):
24 | action = torch.Tensor(agent.action_space.sample())
25 | ob, r, _, _ = agent.step(action)
26 | state = ob.detach().numpy()
27 | action = agent.action_space.sample()
28 | state_action = np.concatenate([state, action], axis=0)
29 | drdsa = mj_reward_gradients_fn(state_action)
30 | drds = drdsa[:, :nv * 2]
31 | drda = drdsa[:, -nu:]
32 |
33 | eps = 1e-6
34 | state_action_prime = state_action + eps
35 | r = mj_reward_forward_fn(state_action)
36 | r_prime = mj_reward_forward_fn(state_action_prime)
37 |
38 | r_prime_estimate = r + \
39 | np.squeeze(np.matmul(drds, np.array([eps] * 2 * nv).reshape([-1, 1]))) + \
40 | np.squeeze(np.matmul(drda, np.array([eps] * nu).reshape([-1, 1])))
41 | self.assertAlmostEqual(r_prime[0], r_prime_estimate[0], places=5)
42 |
43 | def test_dynamics_gradients(self):
44 | cfg = get_cfg_defaults()
45 |         cfg.merge_from_file("configs/swimmer.yaml")  # path relative to the repository root
46 | agent = build_agent(cfg)
47 | mj_dynamics_forward_fn = agent.forward_factory("dynamics")
48 | mj_dynamics_gradients_fn = agent.gradient_factory("dynamics")
49 |
50 | nwarmup = 5
51 | agent.reset()
52 | nv = agent.sim.model.nv
53 | nu = agent.sim.model.nu
54 | for _ in range(nwarmup):
55 | action = torch.Tensor(agent.action_space.sample())
56 | ob, r, _, _ = agent.step(action)
57 | state = ob.detach().numpy()
58 | action = agent.action_space.sample()
59 | state_action = np.concatenate([state, action], axis=0)
60 | dsdsa = mj_dynamics_gradients_fn(state_action)
61 | dsds = dsdsa[:, :nv * 2]
62 | dsda = dsdsa[:, -nu:]
63 |
64 | eps = 1e-6
65 | state_action_prime = state_action + eps
66 | s = mj_dynamics_forward_fn(state_action)
67 | s_prime = mj_dynamics_forward_fn(state_action_prime)
68 |
69 | s_prime_estimate = s + \
70 | np.squeeze(np.matmul(dsds, np.array([eps] * 2 * nv).reshape([-1, 1]))) + \
71 | np.squeeze(np.matmul(dsda, np.array([eps] * nu).reshape([-1, 1])))
72 | print(s)
73 | print(s_prime)
74 | print(s_prime_estimate)
75 |         self.assertTrue(np.allclose(s_prime, s_prime_estimate))
76 |
77 | def test_action_rewards(self):
78 |
79 | # Run multiple times with different sigma for generating action values
80 | N = 100
81 | sigma = np.linspace(1e-2, 1e1, N)
82 |         cfg_file = "configs/swimmer.yaml"  # path relative to the repository root
83 |
84 | for s in sigma:
85 | self.run_reward_test(cfg_file, s)
86 |
87 | def run_reward_test(self, cfg_file, sigma):
88 |
89 | cfg = get_cfg_defaults()
90 | cfg.merge_from_file(cfg_file)
91 | agent = build_agent(cfg)
92 | mj_forward_fn = agent.forward_factory("dynamics")
93 | mj_gradients_fn = agent.gradient_factory("reward")
94 |
95 | model = build_model(cfg, agent)
96 | device = torch.device(cfg.MODEL.DEVICE)
97 | model.to(device)
98 |
99 | # Start from the same state with constant action, make sure reward is equal in both repetitions
100 |
101 | # Drive both simulations forward 5 steps
102 | nwarmup = 5
103 |
104 | # Reset and get initial state
105 | agent.reset()
106 | init_qpos = agent.data.qpos.copy()
107 | init_qvel = agent.data.qvel.copy()
108 |
109 | # Set constant action
110 | na = agent.model.actuator_acc0.shape[0]
111 | action = torch.DoubleTensor(np.random.randn(na)*sigma)
112 |
113 | # Do first simulation
114 | for _ in range(nwarmup):
115 | mj_forward_fn(action)
116 |
117 | # Take a snapshot of this state so we can use it in gradient calculations
118 | agent.data.ctrl[:] = action.detach().numpy().copy()
119 | data = agent.get_snapshot()
120 |
121 | # Advance simulation with one step and get the reward
122 | mj_forward_fn(action)
123 | reward = agent.reward.copy()
124 | next_state = np.concatenate((agent.data.qpos.copy(), agent.data.qvel.copy()))
125 |
126 | # Reset and set to initial state, then do the second simulation; this time call mj_forward_fn without args
127 | agent.reset()
128 | agent.data.qpos[:] = init_qpos
129 | agent.data.qvel[:] = init_qvel
130 | for _ in range(nwarmup):
131 | agent.data.ctrl[:] = action.detach().numpy()
132 | mj_forward_fn()
133 |
134 | # Advance simulation with one step and get reward
135 | agent.data.ctrl[:] = action.detach().numpy()
136 | mj_forward_fn()
137 | reward2 = agent.reward.copy()
138 |
139 |         # reward and reward2 should be equal
140 | self.assertEqual(reward, reward2, "Simulations from same initial state diverged")
141 |
142 | # Then make sure simulation from snapshot doesn't diverge from original simulation
143 | agent.set_snapshot(data)
144 | mj_forward_fn()
145 | reward_snapshot = agent.reward.copy()
146 | self.assertEqual(reward, reward_snapshot, "Simulation from snapshot diverged")
147 |
148 | # Make sure simulations are correct in the gradient calculations as well
149 | mj_gradients_fn(data, next_state, reward, test=True)
150 |
151 | def test_dynamics_reward(self):
152 |
153 | # Run multiple times with different sigma for generating action values
154 | N = 100
155 | sigma = np.linspace(1e-2, 1e1, N)
156 |         cfg_file = "configs/leg.yaml"  # path relative to the repository root
157 |
158 | for s in sigma:
159 | self.run_dynamics_reward_test(cfg_file, s)
160 |
161 | def run_dynamics_reward_test(self, cfg_file, sigma):
162 |
163 | cfg = get_cfg_defaults()
164 | cfg.merge_from_file(cfg_file)
165 | agent = build_agent(cfg)
166 | mj_dynamics_forward_fn = agent.forward_factory("dynamics")
167 | #mj_reward_forward_fn = agent.forward_factory("reward")
168 | #mj_dynamics_gradients_fn = agent.gradient_factory("dynamics")
169 | mj_reward_gradients_fn = agent.gradient_factory("reward")
170 |
171 | # Start from the same state with constant action, make sure reward is equal in both repetitions
172 |
173 | # Drive both simulations forward 5 steps
174 | nwarmup = 5
175 |
176 | # Reset and get initial state
177 | agent.reset()
178 | init_qpos = agent.data.qpos.copy()
179 | init_qvel = agent.data.qvel.copy()
180 |
181 | # Set constant action
182 | na = agent.model.actuator_acc0.shape[0]
183 | action = torch.DoubleTensor(np.random.randn(na)*sigma)
184 |
185 | # Do first simulation
186 | for _ in range(nwarmup):
187 | mj_dynamics_forward_fn(action)
188 |
189 | # Take a snapshot of this state so we can use it in gradient calculations
190 | agent.data.ctrl[:] = action.detach().numpy().copy()
191 | data = agent.get_snapshot()
192 |
193 | # Advance simulation with one step and get the reward
194 | mj_dynamics_forward_fn(action)
195 | reward = agent.reward.copy()
196 | next_state = np.concatenate((agent.data.qpos.copy(), agent.data.qvel.copy()))
197 |
198 | # Reset and set to initial state, then do the second simulation; this time call mj_forward_fn without args
199 | agent.reset()
200 | agent.data.qpos[:] = init_qpos
201 | agent.data.qvel[:] = init_qvel
202 | for _ in range(nwarmup):
203 | agent.data.ctrl[:] = action.detach().numpy()
204 | mj_dynamics_forward_fn()
205 |
206 | # Advance simulation with one step and get reward
207 | agent.data.ctrl[:] = action.detach().numpy()
208 | mj_dynamics_forward_fn()
209 | reward2 = agent.reward.copy()
210 |
211 |         # reward and reward2 should be equal
212 | self.assertEqual(reward, reward2, "Simulations from same initial state diverged")
213 |
214 | # Then make sure simulation from snapshot doesn't diverge from original simulation
215 | agent.set_snapshot(data)
216 | mj_dynamics_forward_fn()
217 | reward_snapshot = agent.reward.copy()
218 | self.assertEqual(reward, reward_snapshot, "Simulation from snapshot diverged")
219 |
220 | # Make sure simulations are correct in the gradient calculations as well
221 | mj_reward_gradients_fn(data, next_state, reward, test=True)
222 |
223 |
224 | if __name__ == '__main__':
225 | unittest.main()
226 |
--------------------------------------------------------------------------------
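
All of the gradient tests above follow the same pattern: perturb the input by a small eps and check that the first-order Taylor estimate, f(x) plus the Jacobian applied to the perturbation, matches a second forward pass. A self-contained illustration of that check on an analytic toy function (none of this is repository code):

    import numpy as np

    def f(x):                                   # toy "forward" function
        return np.array([np.sin(x[0]) + x[1] ** 2])

    def jac(x):                                 # its analytic Jacobian, shape (1, 2)
        return np.array([[np.cos(x[0]), 2.0 * x[1]]])

    x = np.array([0.3, -1.2])
    eps = 1e-6
    estimate = f(x) + jac(x) @ np.full(2, eps)  # first-order Taylor step, as in the tests above
    assert np.allclose(f(x + eps), estimate)
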
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/utils/__init__.py
--------------------------------------------------------------------------------
/utils/index.py:
--------------------------------------------------------------------------------
1 | class Index(object):
2 | def __init__(self, value=0):
3 | self.value = value
4 |
5 | def __index__(self):
6 | return self.value
7 |
8 | def __iadd__(self, other):
9 | self.value += other
10 | return self
11 |
12 | def __add__(self, other):
13 | if isinstance(other, Index):
14 | return Index(self.value + other.value)
15 | elif isinstance(other, int):
16 | return Index(self.value + other)
17 | else:
18 | raise NotImplementedError
19 |
20 | def __sub__(self, other):
21 | if isinstance(other, Index):
22 | return Index(self.value - other.value)
23 | elif isinstance(other, int):
24 | return Index(self.value - other)
25 | else:
26 | raise NotImplementedError
27 |
28 | def __eq__(self, other):
29 | return self.value == other
30 |
31 | def __repr__(self):
32 | return "Index({})".format(self.value)
33 |
34 | def __str__(self):
35 | return "{}".format(self.value)
36 |
37 | def set(self, other):
38 | if isinstance(other, int):
39 | self.value = other
40 | elif isinstance(other, Index):
41 | self.value = other.value
42 | else:
43 | raise NotImplementedError
44 |
--------------------------------------------------------------------------------
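
Index is a small mutable counter meant to be shared by reference between the wrappers and the unwrapped env (see IndexWrapper above): in-place updates through one name are visible through every other, while rebinding a name silently breaks the link, which is exactly what the NOTE in IndexWrapper warns about. A short illustration:

    from utils.index import Index

    step_idx = Index(0)
    alias = step_idx              # e.g. env.unwrapped._step_idx in IndexWrapper

    step_idx += 1                 # __iadd__ mutates in place and returns self
    step_idx.set(5)               # so does set()
    print(alias)                  # -> 5: both names still point at the same object

    alias = alias + 1             # __add__ returns a NEW Index; the shared reference is gone
    print(step_idx, alias)        # -> 5 6
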
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import sys
4 | import csv
5 | from datetime import datetime
6 |
7 | def setup_logger(name, save_dir, txt_file_name='log'):
8 | logger = logging.getLogger(name)
9 | logger.setLevel(logging.DEBUG)
10 | ch = logging.StreamHandler(stream=sys.stdout)
11 | ch.setLevel(logging.DEBUG)
12 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
13 | ch.setFormatter(formatter)
14 | logger.addHandler(ch)
15 |
16 | if save_dir:
17 | fh = logging.FileHandler(os.path.join(save_dir, "{}.txt".format(txt_file_name)))
18 | fh.setLevel(logging.DEBUG)
19 | fh.setFormatter(formatter)
20 | logger.addHandler(fh)
21 |
22 | return logger
23 |
24 |
25 | def save_dict_into_csv(save_dir, file_name, output):
26 | try:
27 | # Append file_name with datetime
28 | #file_name = os.path.join(save_dir, file_name + "_{0:%Y-%m-%dT%H:%M:%S}".format(datetime.now()))
29 | file_name = os.path.join(save_dir, file_name)
30 | with open(file_name, "w") as file:
31 | writer = csv.writer(file)
32 | writer.writerow(output.keys())
33 | writer.writerows(zip(*output.values()))
34 | except IOError:
35 | print("Failed to save file {}".format(file_name))
36 |
--------------------------------------------------------------------------------
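
A usage sketch for the two helpers in utils/logger.py; the logger name, directory, and file names are placeholders:

    import os
    from utils.logger import setup_logger, save_dict_into_csv

    os.makedirs("output", exist_ok=True)
    logger = setup_logger("model_based_rl", "output", txt_file_name="train_log")
    logger.info("batch %d, loss %.4f", 3, 0.1234)

    # keys become the CSV header row; each value is a list of per-step entries
    history = {"batch_idx": [1, 2, 3], "loss": [0.9, 0.5, 0.3]}
    save_dict_into_csv("output", "training_history.csv", history)
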
/utils/visdom_plots.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import visdom
3 | import numpy as np
4 | from subprocess import Popen, PIPE
5 |
6 |
7 | def create_visdom_connections(port):
8 | """If the program could not connect to Visdom server, this function will start a new server at port < port > """
9 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % port
10 | print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
11 | print('Command: %s' % cmd)
12 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
13 |
14 |
15 | class VisdomLogger(dict):
16 | """Plot losses."""
17 |
18 | def __init__(self, port, *args, **kwargs):
19 | self.visdom_port = port
20 | self.visdom = visdom.Visdom(port=port)
21 | if not self.visdom.check_connection():
22 | create_visdom_connections(port)
23 | super(VisdomLogger, self).__init__(*args, **kwargs)
24 | self.registered = False
25 | self.plot_attributes = {}
26 |
27 | def register_keys(self, keys):
28 | for key in keys:
29 | self[key] = []
30 | for i, key in enumerate(self.keys()):
31 | self.plot_attributes[key] = {'win_id': i}
32 | self.registered = True
33 |
34 | def update(self, new_records):
35 | """Add new updates to records.
36 |
37 | :param new_records: dict of new updates. example: {'loss': [10., 5., 2.], 'lr': [1e-3, 1e-3, 5e-4]}
38 | """
39 | for key, val in new_records.items():
40 | if key in self.keys():
41 | if isinstance(val, list):
42 | self[key].extend(val)
43 | else:
44 | self[key].append(val)
45 |
46 | def set(self, record):
47 | """Set a record (replaces old data)"""
48 | for key, val in record.items():
49 | if key in self.keys():
50 | self[key] = val
51 |
52 | def do_plotting(self):
53 | for k in self.keys():
54 | y_values = np.array(self[k])
55 | x_values = np.arange(len(self[k]))
56 | if len(x_values) < 1:
57 | continue
58 | # if y_values.size > 1:
59 | # x_values = np.reshape(x_values, (len(x_values), 1))
60 |             if len(y_values.shape) > 1 and y_values.shape[1] == 1 and y_values.size != 1:
61 | y_values = y_values.squeeze()
62 | self.visdom.line(Y=y_values, X=x_values, win=self.plot_attributes[k]['win_id'],
63 | opts={'title': k.upper()}, update='append')
64 | # self[k] = []
65 |
--------------------------------------------------------------------------------
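
A usage sketch for VisdomLogger; keys must be registered before updating or plotting, and 8097 (Visdom's default port) is an assumption here:

    from utils.visdom_plots import VisdomLogger

    vis = VisdomLogger(port=8097)
    vis.register_keys(["loss", "reward"])

    for step in range(3):
        vis.update({"loss": 1.0 / (step + 1), "reward": [0.5 * step]})  # scalars or lists both work
    vis.do_plotting()             # one appended line plot per registered key
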