├── .gitignore ├── README.md ├── configs ├── HandModelTSLAdjusted.yaml ├── Model-Based-RL.yml ├── half_cheetah.yaml ├── hopper.yaml ├── inverted_double_pendulum.yaml ├── inverted_pendulum.yaml ├── leg.yaml ├── swimmer.yaml └── walker2d.yaml ├── main.py ├── model ├── __init__.py ├── archs │ ├── __init__.py │ └── basic.py ├── blocks │ ├── __init__.py │ ├── mujoco.py │ ├── policy │ │ ├── __init__.py │ │ ├── base.py │ │ ├── build.py │ │ ├── deterministic.py │ │ ├── dynamics.py │ │ ├── stochastic.py │ │ ├── strategies.py │ │ ├── trajopt.py │ │ └── wrappers │ │ │ ├── __init__.py │ │ │ └── zmus.py │ └── utils │ │ ├── __init__.py │ │ ├── build.py │ │ └── functions.py ├── build.py ├── config │ ├── __init__.py │ └── defaults.py ├── engine │ ├── __init__.py │ ├── dynamics_model_trainer.py │ ├── landscape_plot.py │ ├── tester.py │ ├── trainer.py │ └── utils │ │ ├── __init__.py │ │ ├── build.py │ │ └── experience_replay.py └── layers │ ├── __init__.py │ └── feed_forward.py ├── mujoco ├── __init__.py ├── assets │ ├── Geometry │ │ ├── bofoot.stl │ │ ├── calcn_r.stl │ │ ├── capitate_rFB.stl │ │ ├── clavicle_rFB.stl │ │ ├── femur_r.stl │ │ ├── fibula_r.stl │ │ ├── foot.stl │ │ ├── hamate_rFB.stl │ │ ├── humerus_rFB.stl │ │ ├── indexDistal_rFB.stl │ │ ├── indexMetacarpal_rFB.stl │ │ ├── indexMid_rFB.stl │ │ ├── indexProx_rFB.stl │ │ ├── l_pelvis.stl │ │ ├── lunate_rFB.stl │ │ ├── middleDistal_rFB.stl │ │ ├── middleMetacarpal_rFB.stl │ │ ├── middleMid_rFB.stl │ │ ├── middleProx_rFB.stl │ │ ├── pat.stl │ │ ├── patella_r.stl │ │ ├── pelvis.stl │ │ ├── pinkyDistal_rFB.stl │ │ ├── pinkyMetacarpal_rFB.stl │ │ ├── pinkyMid_rFB.stl │ │ ├── pinkyProx_rFB.stl │ │ ├── pisiform_rFB.stl │ │ ├── radius_rFB.stl │ │ ├── ribcageFB.stl │ │ ├── ringDistal_rFB.stl │ │ ├── ringMetacarpal_rFB.stl │ │ ├── ringMid_rFB.stl │ │ ├── ringProx_rFB.stl │ │ ├── sacrum.stl │ │ ├── scaphoid_rFB.stl │ │ ├── scapula_rFB.stl │ │ ├── talus.stl │ │ ├── talus_r.stl │ │ ├── thoracic10FB.stl │ │ ├── thoracic11FB.stl │ │ ├── thoracic12FB.stl │ │ ├── thoracic1FB.stl │ │ ├── thoracic2FB.stl │ │ ├── thoracic3FB.stl │ │ ├── thoracic4FB.stl │ │ ├── thoracic5FB.stl │ │ ├── thoracic6FB.stl │ │ ├── thoracic7FB.stl │ │ ├── thoracic8FB.stl │ │ ├── thoracic9FB.stl │ │ ├── thumbDistal_rFB.stl │ │ ├── thumbMid_rFB.stl │ │ ├── thumbProx_rFB.stl │ │ ├── tibia_r.stl │ │ ├── toes_r.stl │ │ ├── trapezium_rFB.stl │ │ ├── trapezoid_rFB.stl │ │ ├── triquetral_rFB.stl │ │ └── ulna_rFB.stl │ ├── HandModelTSLAdjusted_converted.xml │ ├── half_cheetah.xml │ ├── hopper.xml │ ├── inverted_double_pendulum.xml │ ├── inverted_pendulum.xml │ ├── leg6dof9musc_converted.xml │ ├── swimmer.xml │ └── walker2d.xml ├── build.py ├── envs │ ├── HandModelTSLAdjusted.py │ ├── __init__.py │ ├── half_cheetah.py │ ├── hopper.py │ ├── inverted_double_pendulum.py │ ├── inverted_pendulum.py │ ├── leg.py │ ├── swimmer.py │ └── walker2d.py └── utils │ ├── __init__.py │ ├── backward.py │ ├── forward.py │ └── wrappers │ ├── __init__.py │ ├── etc.py │ └── mj_block.py ├── requirements.txt ├── solver ├── __init__.py └── build.py ├── tests ├── __init__.py └── test_gradients.py └── utils ├── __init__.py ├── index.py ├── logger.py └── visdom_plots.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | output/ 4 | MUJOCO_LOG.TXT 5 | .pytest_cache 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 
Model-based Reinforcement Learning 2 | 3 | Directly back-propagate into your policy network, from model jacobians calculated in MuJoCo using finite-difference. 4 | 5 | To backprop into stochastic policies, given an unknown model, one has to use the REINFORCE theory, to be able to calculate the gradients by sampling the environment. These methods usually have high variance, so baselines and value/advantage functions were introduced. Another way to backpropagate into your policy network is to use the “reparameterization trick” as in VAEs, but they entail knowledge of upstream gradients, and hence a known model. The policy gradients calculated w/ the reparam trick are often much lower in variance, so one can go wo/ baselines and value networks. This project puts it all together: computation graph of policy and dynamics, upstream gradients from MuJoCo dynamics and rewards, reparam trick, and optimization. 6 | 7 | ### Vanilla Computation Graph 8 | ```txt 9 | +----------+S0+----------+ +----------+S1+----------+ 10 | | | | | 11 | | +------+ A0 +---v----+ + +------+ A1 +---v----+ 12 | S0+------>+Policy+---+--->+Dynamics+---+---+S1+-->+Policy+---+--->+Dynamics+--->S2 ... 13 | | +------+ | +--------+ | + +------+ | +--------+ | 14 | | | | | | | 15 | | +--v---+ | | +--v---+ | 16 | +---+S0+---->+Reward+<-----S1-----+ +---+S1+---->+Reward+<-----S2------+ 17 | +------+ +------+ 18 | ``` 19 | 20 | ### Results 21 | 22 | 23 | 24 | 25 | 26 | 27 | ### This repo contains: 28 | * Finite-difference calculation of MuJoCo dynamics jacobians in `mujoco-py` 29 | * MuJoCo dynamics as a PyTorch Operation (i.e. forward and backward pass) 30 | * Reward function PyTorch Operation 31 | * Flexible design to wire up your own meta computation graph 32 | * Trajectory Optimization module alongside Policy Networks 33 | * Flexible design to define your own environment in `gym` 34 | * Fancy logger and monitoring 35 | 36 | ### Dependencies 37 | Python3.6: 38 | * `torch` 39 | * `mujoco-py` 40 | * `gym` 41 | * `numpy` 42 | * `visdom` 43 | 44 | Other: 45 | * Tested w/ `mujoco200` 46 | 47 | ### Usage 48 | For latest changes: 49 | ```bash 50 | git clone -b development git@github.com:MahanFathi/Model-Based-RL.git 51 | ``` 52 | Run: 53 | ```bash 54 | python3 main.py --config-file ./configs/inverted_pendulum.yaml 55 | ``` 56 | -------------------------------------------------------------------------------- /configs/HandModelTSLAdjusted.yaml: -------------------------------------------------------------------------------- 1 | MUJOCO: 2 | ENV: 'HandModelTSLAdjustedEnv' 3 | CLIP_ACTIONS: false 4 | MODEL: 5 | META_ARCHITECTURE: 'Basic' 6 | BATCH_SIZE: 1 7 | EPOCHS: 1000 8 | NSTEPS_FOR_BACKWARD: 2 9 | FRAME_SKIP: 20 10 | TIMESTEP: 0.002 11 | POLICY: 12 | ARCH: 'Perttu' 13 | METHOD: 'R-VO' 14 | LAYERS: 15 | - 128 16 | - 128 17 | GAMMA: 0.99 18 | MAX_HORIZON_STEPS: 50 19 | INITIAL_SD: 0.2 20 | INITIAL_ACTION_MEAN: 0.0 21 | INITIAL_ACTION_SD: 0.0 22 | SOLVER: 23 | BASE_LR: 0.05 24 | WEIGHT_DECAY: 0.0 25 | BIAS_LR_FACTOR: 1 26 | STD_LR_FACTOR: 1.0 27 | ADAM_BETAS: (0.9, 0.999) 28 | LOG: 29 | TESTING: 30 | ENABLED: True 31 | ITER_PERIOD: 1 32 | RECORD_VIDEO: true -------------------------------------------------------------------------------- /configs/Model-Based-RL.yml: -------------------------------------------------------------------------------- 1 | name: Model-Based-RL 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - anaconda 6 | - https://conda.anaconda.org/conda-forge/ 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | 
- asn1crypto=1.2.0=py37_0 11 | - attrs=19.3.0=py_0 12 | - backcall=0.1.0=py_0 13 | - blas=1.0=mkl 14 | - bleach=3.1.0=py_0 15 | - bzip2=1.0.8=h7b6447c_0 16 | - ca-certificates=2019.11.28=hecc5488_0 17 | - certifi=2019.11.28=py37_0 18 | - cffi=1.12.3=py37h8022711_0 19 | - chardet=3.0.4=py37_1003 20 | - click=7.0=py_0 21 | - cryptography=2.7=py37h72c5cf5_0 22 | - cudatoolkit=10.1.168=0 23 | - curl=7.65.3=hbc83047_0 24 | - cycler=0.10.0=py37_0 25 | - cython=0.29.13=py37he6710b0_0 26 | - dbus=1.13.6=h746ee38_0 27 | - defusedxml=0.6.0=py_0 28 | - entrypoints=0.3=py37_1000 29 | - expat=2.2.6=he6710b0_0 30 | - ffmpeg=4.2=h167e202_0 31 | - fontconfig=2.13.0=h9420a91_0 32 | - freetype=2.9.1=h8a8886c_1 33 | - future=0.17.1=py37_0 34 | - glib=2.56.2=hd408876_0 35 | - gmp=6.1.2=hf484d3e_1000 36 | - gnutls=3.6.5=hd3a4fd2_1002 37 | - gst-plugins-base=1.14.0=hbbd80ab_1 38 | - gstreamer=1.14.0=hb453b48_1 39 | - hdf4=4.2.13=h3ca952b_2 40 | - hdf5=1.10.4=hb1b8bf9_0 41 | - icu=58.2=h9c2bf20_1 42 | - idna=2.8=py37_1000 43 | - imageio=2.6.1=py37_0 44 | - imageio-ffmpeg=0.3.0=py_0 45 | - importlib_metadata=1.2.0=py37_0 46 | - intel-openmp=2019.4=243 47 | - ipykernel=5.1.3=py37h5ca1d4c_0 48 | - ipython_genutils=0.2.0=py_1 49 | - ipywidgets=7.5.1=py_0 50 | - jedi=0.15.1=py37_0 51 | - jinja2=2.10.3=py_0 52 | - jpeg=9b=h024ee3a_2 53 | - json5=0.8.5=py_0 54 | - jsoncpp=1.8.4=hfd86e86_0 55 | - jsonpatch=1.24=py_0 56 | - jsonpointer=2.0=py_0 57 | - jsonschema=3.2.0=py37_0 58 | - jupyter=1.0.0=py_2 59 | - jupyter_client=5.3.3=py37_1 60 | - jupyter_console=5.2.0=py37_1 61 | - jupyter_core=4.6.1=py37_0 62 | - jupyterlab=1.2.3=py_0 63 | - jupyterlab_server=1.0.6=py_0 64 | - kiwisolver=1.1.0=py37he6710b0_0 65 | - krb5=1.16.1=h173b8e3_7 66 | - lame=3.100=h14c3975_1001 67 | - libcurl=7.65.3=h20c2e04_0 68 | - libedit=3.1.20181209=hc058e9b_0 69 | - libffi=3.2.1=hd88cf55_4 70 | - libgcc-ng=9.1.0=hdf63c60_0 71 | - libgfortran-ng=7.3.0=hdf63c60_0 72 | - libiconv=1.15=h516909a_1005 73 | - libnetcdf=4.6.1=h11d0813_2 74 | - libogg=1.3.2=h7b6447c_0 75 | - libpng=1.6.37=hbc83047_0 76 | - libprotobuf=3.10.1=h8b12597_0 77 | - libsodium=1.0.17=h516909a_0 78 | - libssh2=1.8.2=h1ba5d50_0 79 | - libstdcxx-ng=9.1.0=hdf63c60_0 80 | - libtheora=1.1.1=h5ab3b9f_1 81 | - libtiff=4.0.10=h2733197_2 82 | - libuuid=1.0.3=h1bed415_2 83 | - libvorbis=1.3.6=h7b6447c_0 84 | - libxcb=1.13=h1bed415_1 85 | - libxml2=2.9.9=hea5a465_1 86 | - lockfile=0.12.2=py_1 87 | - lz4-c=1.8.1.2=h14c3975_0 88 | - markdown=3.1.1=py_0 89 | - markupsafe=1.1.1=py37h516909a_0 90 | - matplotlib=3.1.1=py37h5429711_0 91 | - mistune=0.8.4=py37h516909a_1000 92 | - mkl=2019.4=243 93 | - mkl-service=2.3.0=py37he904b0f_0 94 | - mkl_fft=1.0.14=py37ha843d7b_0 95 | - mkl_random=1.1.0=py37hd6b4f25_0 96 | - more-itertools=8.0.2=py_0 97 | - nbconvert=5.6.1=py37_0 98 | - nbformat=4.4.0=py_1 99 | - ncurses=6.1=he6710b0_1 100 | - nettle=3.4.1=h1bed415_1002 101 | - ninja=1.9.0=h6bb024c_0 102 | - notebook=6.0.1=py37_0 103 | - numpy=1.17.2=py37haad9e8e_0 104 | - numpy-base=1.17.2=py37hde5b4d6_0 105 | - olefile=0.46=py_0 106 | - openh264=1.8.0=hdbcaa40_1000 107 | - openssl=1.1.1d=h516909a_0 108 | - pandas=0.25.1=py37he6710b0_0 109 | - pandoc=2.8.1=0 110 | - pandocfilters=1.4.2=py_1 111 | - parso=0.5.1=py_0 112 | - patchelf=0.9=he6710b0_3 113 | - patsy=0.5.1=py37_0 114 | - pcre=8.43=he6710b0_0 115 | - pexpect=4.7.0=py37_0 116 | - pillow=6.2.0=py37h34e0f95_0 117 | - pip=19.2.3=py37_0 118 | - prometheus_client=0.7.1=py_0 119 | - prompt_toolkit=3.0.2=py_0 120 | - protobuf=3.10.1=py37he1b5a44_0 121 | - 
ptyprocess=0.6.0=py_1001 122 | - pycparser=2.19=py37_1 123 | - pyopenssl=19.0.0=py37_0 124 | - pyparsing=2.4.2=py_0 125 | - pyqt=5.9.2=py37h05f1152_2 126 | - pyrsistent=0.15.6=py37h516909a_0 127 | - pysocks=1.7.1=py37_0 128 | - python=3.7.4=h265db76_1 129 | - python-dateutil=2.8.0=py37_0 130 | - pytorch=1.3.0=py3.7_cuda10.1.243_cudnn7.6.3_0 131 | - pytz=2019.2=py_0 132 | - pyyaml=5.1.2=py37h516909a_0 133 | - pyzmq=18.1.0=py37h1768529_0 134 | - qt=5.9.7=h5867ecd_1 135 | - qtconsole=4.6.0=py_0 136 | - readline=7.0=h7b6447c_5 137 | - requests=2.22.0=py37_1 138 | - scipy=1.3.1=py37h7c811a0_0 139 | - seaborn=0.9.0=pyh91ea838_1 140 | - send2trash=1.5.0=py_0 141 | - setuptools=41.2.0=py37_0 142 | - sip=4.19.8=py37hf484d3e_0 143 | - six=1.12.0=py37_0 144 | - sk-video=1.1.10=pyh24bf2e0_4 145 | - sqlite=3.29.0=h7b6447c_0 146 | - statsmodels=0.10.1=py37hdd07704_0 147 | - tbb=2019.4=hfd86e86_0 148 | - tensorboard=1.14.0=py37_0 149 | - terminado=0.8.3=py37_0 150 | - testpath=0.4.4=py_0 151 | - tk=8.6.8=hbc83047_0 152 | - torchfile=0.1.0=py_0 153 | - tornado=6.0.3=py37h7b6447c_0 154 | - traitlets=4.3.3=py37_0 155 | - urllib3=1.25.6=py37_0 156 | - visdom=0.1.8.9=0 157 | - vtk=8.2.0=py37haa4764d_200 158 | - webencodings=0.5.1=py_1 159 | - websocket-client=0.56.0=py37_0 160 | - werkzeug=0.16.0=py_0 161 | - wheel=0.33.6=py37_0 162 | - widgetsnbextension=3.5.1=py37_0 163 | - x264=1!152.20180806=h14c3975_0 164 | - xarray=0.14.1=py_1 165 | - xmltodict=0.12.0=py_0 166 | - xorg-fixesproto=5.0=h14c3975_1002 167 | - xorg-kbproto=1.0.7=h14c3975_1002 168 | - xorg-libx11=1.6.9=h516909a_0 169 | - xorg-libxcursor=1.2.0=h516909a_0 170 | - xorg-libxext=1.3.4=h516909a_0 171 | - xorg-libxfixes=5.0.3=h516909a_1004 172 | - xorg-libxinerama=1.1.4=hf484d3e_1000 173 | - xorg-libxrandr=1.5.2=h516909a_1 174 | - xorg-libxrender=0.9.10=h516909a_1002 175 | - xorg-randrproto=1.5.0=h14c3975_1001 176 | - xorg-renderproto=0.11.1=h14c3975_1002 177 | - xorg-xextproto=7.3.0=h14c3975_1002 178 | - xorg-xproto=7.0.31=h14c3975_1007 179 | - xz=5.2.4=h14c3975_4 180 | - yacs=0.1.6=py_0 181 | - yaml=0.1.7=h14c3975_1001 182 | - zeromq=4.3.2=he1b5a44_2 183 | - zipp=0.6.0=py_0 184 | - zlib=1.2.11=h7b6447c_3 185 | - zstd=1.3.7=h0b5b093_0 186 | - pip: 187 | - admesh==0.98.9 188 | - cloudpickle==1.2.2 189 | - cma==2.7.0 190 | - decorator==4.4.0 191 | - glfw==1.8.3 192 | - gym==0.15.3 193 | - ipdb==0.12.2 194 | - ipython==7.8.0 195 | - ipython-genutils==0.2.0 196 | - mujoco-py==2.0.2.5 197 | - pickleshare==0.7.5 198 | - prompt-toolkit==2.0.10 199 | - pyglet==1.3.2 200 | - pygments==2.4.2 201 | - pyquaternion==0.9.5 202 | - python-graphviz==0.13 203 | - torchviz==0.0.1 204 | - wcwidth==0.1.7 205 | -------------------------------------------------------------------------------- /configs/half_cheetah.yaml: -------------------------------------------------------------------------------- 1 | MUJOCO: 2 | ENV: 'HalfCheetahEnv' 3 | CLIP_ACTIONS: false 4 | MODEL: 5 | META_ARCHITECTURE: 'Basic' 6 | BATCH_SIZE: 1 7 | FRAME_SKIP: 10 8 | EPOCHS: 5001 9 | RANDOM_SEED: 40 10 | POLICY: 11 | ARCH: 'Perttu' 12 | METHOD: 'Grad' 13 | NETWORK: False 14 | MAX_HORIZON_STEPS: 40 15 | LAYERS: 16 | - 128 17 | - 128 18 | GAMMA: 0.99 19 | INITIAL_LOG_SD: -0.5 20 | INITIAL_SD: 0.5 21 | INITIAL_ACTION_MEAN: 0.0 22 | INITIAL_ACTION_SD: 0.1 23 | GRAD_WEIGHTS: 'average' 24 | SOLVER: 25 | OPTIMIZER: 'adam' 26 | BASE_LR: 0.001 27 | WEIGHT_DECAY: 0.0 28 | BIAS_LR_FACTOR: 1 29 | STD_LR_FACTOR: 1.0 30 | ADAM_BETAS: (0.9, 0.999) 31 | LOG: 32 | TESTING: 33 | ITER_PERIOD: 200 34 | RECORD_VIDEO: 
True 35 | -------------------------------------------------------------------------------- /configs/hopper.yaml: -------------------------------------------------------------------------------- 1 | MUJOCO: 2 | ENV: 'HopperEnv' 3 | CLIP_ACTIONS: false 4 | MODEL: 5 | META_ARCHITECTURE: 'Basic' 6 | BATCH_SIZE: 1 7 | FRAME_SKIP: 20 8 | EPOCHS: 5001 9 | RANDOM_SEED: 50 10 | POLICY: 11 | ARCH: 'Perttu' 12 | METHOD: 'Grad' 13 | NETWORK: False 14 | MAX_HORIZON_STEPS: 40 15 | LAYERS: 16 | - 128 17 | - 128 18 | GAMMA: 0.99 19 | INITIAL_LOG_SD: -0.5 20 | INITIAL_SD: 0.5 21 | INITIAL_ACTION_MEAN: 0.0 22 | INITIAL_ACTION_SD: 0.1 23 | GRAD_WEIGHTS: 'average' 24 | SOLVER: 25 | OPTIMIZER: 'adam' 26 | BASE_LR: 0.001 27 | WEIGHT_DECAY: 0.0 28 | BIAS_LR_FACTOR: 1 29 | STD_LR_FACTOR: 1.0 30 | ADAM_BETAS: (0.9, 0.999) 31 | LOG: 32 | TESTING: 33 | ITER_PERIOD: 200 34 | RECORD_VIDEO: True 35 | -------------------------------------------------------------------------------- /configs/inverted_double_pendulum.yaml: -------------------------------------------------------------------------------- 1 | MUJOCO: 2 | ENV: 'InvertedDoublePendulumEnv' 3 | CLIP_ACTIONS: false 4 | MODEL: 5 | META_ARCHITECTURE: 'Basic' 6 | BATCH_SIZE: 8 7 | FRAME_SKIP: 4 8 | RANDOM_SEED: 30 9 | EPOCHS: 5001 10 | POLICY: 11 | ARCH: 'Perttu' 12 | METHOD: 'R-VO' 13 | NETWORK: False 14 | MAX_HORIZON_STEPS: 50 15 | LAYERS: 16 | - 128 17 | - 128 18 | GAMMA: 0.99 19 | INITIAL_LOG_SD: -0.5 20 | INITIAL_SD: 0.2 21 | INITIAL_ACTION_MEAN: 0.0 22 | INITIAL_ACTION_SD: 0.1 23 | GRAD_WEIGHTS: "average" 24 | SOLVER: 25 | OPTIMIZER: 'adam' 26 | BASE_LR: 0.001 27 | WEIGHT_DECAY: 0.0 28 | BIAS_LR_FACTOR: 1 29 | STD_LR_FACTOR: 1.0 30 | ADAM_BETAS: (0.9, 0.999) 31 | LOG: 32 | TESTING: 33 | ITER_PERIOD: 100 34 | RECORD_VIDEO: True 35 | -------------------------------------------------------------------------------- /configs/inverted_pendulum.yaml: -------------------------------------------------------------------------------- 1 | MUJOCO: 2 | ENV: 'InvertedPendulumEnv' 3 | CLIP_ACTIONS: false 4 | MODEL: 5 | META_ARCHITECTURE: 'Basic' 6 | BATCH_SIZE: 8 7 | FRAME_SKIP: 2 8 | EPOCHS: 1001 9 | RANDOM_SEED: 10 10 | POLICY: 11 | ARCH: 'Perttu' 12 | METHOD: 'R-VO' 13 | NETWORK: False 14 | MAX_HORIZON_STEPS: 50 15 | LAYERS: 16 | - 128 17 | - 128 18 | GAMMA: 0.99 19 | INITIAL_LOG_SD: -0.5 20 | INITIAL_SD: 0.1 21 | INITIAL_ACTION_MEAN: 0.0 22 | INITIAL_ACTION_SD: 0.0 23 | GRAD_WEIGHTS: 'average' 24 | SOLVER: 25 | OPTIMIZER: 'adam' 26 | BASE_LR: 0.001 27 | WEIGHT_DECAY: 0.0 28 | BIAS_LR_FACTOR: 1 29 | STD_LR_FACTOR: 1.0 30 | ADAM_BETAS: (0.9, 0.999) 31 | LOG: 32 | TESTING: 33 | ENABLED: True 34 | ITER_PERIOD: 200 35 | RECORD_VIDEO: True -------------------------------------------------------------------------------- /configs/leg.yaml: -------------------------------------------------------------------------------- 1 | MUJOCO: 2 | ENV: 'LegEnv' 3 | CLIP_ACTIONS: false 4 | MODEL: 5 | META_ARCHITECTURE: 'Basic' 6 | BATCH_SIZE: 32 7 | EPOCHS: 100 8 | POLICY: 9 | ARCH: 'Perttu' 10 | METHOD: 'VO' 11 | LAYERS: 12 | - 128 13 | - 128 14 | GAMMA: 0.99 15 | MAX_HORIZON_STEPS: 228 16 | INITIAL_LOG_SD: -1.0 17 | INITIAL_ACTION_MEAN: 0.0 18 | INITIAL_ACTION_SD: 0.1 19 | SOLVER: 20 | BASE_LR: 0.1 21 | WEIGHT_DECAY: 0.0 22 | BIAS_LR_FACTOR: 1 23 | STD_LR_FACTOR: 1.0 24 | ADAM_BETAS: (0.9, 0.999) 25 | LOG: 26 | TESTING: 27 | ENABLED: False -------------------------------------------------------------------------------- /configs/swimmer.yaml: 
-------------------------------------------------------------------------------- 1 | MUJOCO: 2 | ENV: 'SwimmerEnv' 3 | CLIP_ACTIONS: false 4 | MODEL: 5 | META_ARCHITECTURE: 'Basic' 6 | BATCH_SIZE: 1 7 | FRAME_SKIP: 10 8 | EPOCHS: 5001 9 | RANDOM_SEED: 20 10 | POLICY: 11 | ARCH: 'Perttu' 12 | METHOD: 'Grad' 13 | NETWORK: False 14 | MAX_HORIZON_STEPS: 40 15 | LAYERS: 16 | - 128 17 | - 128 18 | GAMMA: 0.99 19 | INITIAL_LOG_SD: -0.5 20 | INITIAL_SD: 0.2 21 | INITIAL_ACTION_MEAN: 0.0 22 | INITIAL_ACTION_SD: 0.1 23 | GRAD_WEIGHTS: 'average' 24 | SOLVER: 25 | OPTIMIZER: 'adam' 26 | BASE_LR: 0.001 27 | WEIGHT_DECAY: 0.0 28 | BIAS_LR_FACTOR: 1 29 | STD_LR_FACTOR: 1.0 30 | ADAM_BETAS: (0.9, 0.999) 31 | LOG: 32 | TESTING: 33 | ITER_PERIOD: 200 34 | RECORD_VIDEO: True 35 | -------------------------------------------------------------------------------- /configs/walker2d.yaml: -------------------------------------------------------------------------------- 1 | MUJOCO: 2 | ENV: 'Walker2dEnv' 3 | CLIP_ACTIONS: false 4 | MODEL: 5 | META_ARCHITECTURE: 'Basic' 6 | BATCH_SIZE: 1 7 | FRAME_SKIP: 25 8 | EPOCHS: 5001 9 | RANDOM_SEED: 60 10 | POLICY: 11 | ARCH: 'Perttu' 12 | METHOD: 'Grad' 13 | NETWORK: False 14 | MAX_HORIZON_STEPS: 40 15 | LAYERS: 16 | - 128 17 | - 128 18 | GAMMA: 0.99 19 | INITIAL_LOG_SD: -0.5 20 | INITIAL_SD: 0.5 21 | INITIAL_ACTION_MEAN: 0.0 22 | INITIAL_ACTION_SD: 0.1 23 | GRAD_WEIGHTS: 'average' 24 | SOLVER: 25 | OPTIMIZER: 'adam' 26 | BASE_LR: 0.001 27 | WEIGHT_DECAY: 0.0 28 | BIAS_LR_FACTOR: 1 29 | STD_LR_FACTOR: 1.0 30 | ADAM_BETAS: (0.9, 0.999) 31 | LOG: 32 | TESTING: 33 | ITER_PERIOD: 200 34 | RECORD_VIDEO: True 35 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from datetime import datetime 4 | 5 | import model.engine.trainer 6 | import model.engine.dynamics_model_trainer 7 | from model.config import get_cfg_defaults 8 | import utils.logger as lg 9 | import model.engine.landscape_plot 10 | 11 | 12 | def train(cfg, iter): 13 | 14 | # Create output directories 15 | env_output_dir = os.path.join(cfg.OUTPUT.DIR, cfg.MUJOCO.ENV) 16 | if cfg.OUTPUT.NAME == "timestamp": 17 | output_dir_name = "{0:%Y-%m-%d %H:%M:%S}".format(datetime.now()) 18 | else: 19 | output_dir_name = cfg.OUTPUT.NAME 20 | output_dir = os.path.join(env_output_dir, output_dir_name) 21 | output_rec_dir = os.path.join(output_dir, 'recordings') 22 | output_weights_dir = os.path.join(output_dir, 'weights') 23 | output_results_dir = os.path.join(output_dir, 'results') 24 | os.makedirs(output_dir) 25 | os.mkdir(output_weights_dir) 26 | os.mkdir(output_results_dir) 27 | if cfg.LOG.TESTING.ENABLED: 28 | os.mkdir(output_rec_dir) 29 | 30 | # Create logger 31 | logger = lg.setup_logger("model.engine.trainer", output_dir, 'logs') 32 | logger.info("Running with config:\n{}".format(cfg)) 33 | 34 | # Repeat for required number of iterations 35 | for i in range(iter): 36 | agent = model.engine.trainer.do_training( 37 | cfg, 38 | logger, 39 | output_results_dir, 40 | output_rec_dir, 41 | output_weights_dir, 42 | i 43 | ) 44 | model.engine.landscape_plot.visualise2d(agent, output_results_dir, i) 45 | 46 | 47 | def train_dynamics_model(cfg, iter): 48 | 49 | # Create output directories 50 | env_output_dir = os.path.join(cfg.OUTPUT.DIR, cfg.MUJOCO.ENV) 51 | output_dir = os.path.join(env_output_dir, "{0:%Y-%m-%d %H:%M:%S}".format(datetime.now())) 52 | output_rec_dir = 
os.path.join(output_dir, 'recordings') 53 | output_weights_dir = os.path.join(output_dir, 'weights') 54 | output_results_dir = os.path.join(output_dir, 'results') 55 | os.makedirs(output_dir) 56 | os.mkdir(output_weights_dir) 57 | os.mkdir(output_results_dir) 58 | if cfg.LOG.TESTING.ENABLED: 59 | os.mkdir(output_rec_dir) 60 | 61 | # Create logger 62 | logger = lg.setup_logger("model.engine.dynamics_model_trainer", output_dir, 'logs') 63 | logger.info("Running with config:\n{}".format(cfg)) 64 | 65 | # Train the dynamics model 66 | model.engine.dynamics_model_trainer.do_training( 67 | cfg, 68 | logger, 69 | output_results_dir, 70 | output_rec_dir, 71 | output_weights_dir 72 | ) 73 | 74 | 75 | def inference(cfg): 76 | pass 77 | 78 | 79 | def main(): 80 | parser = argparse.ArgumentParser(description="PyTorch model-based RL.") 81 | parser.add_argument( 82 | "--config-file", 83 | default="", 84 | metavar="file", 85 | help="path to config file", 86 | type=str, 87 | ) 88 | parser.add_argument( 89 | "--mode", 90 | default="train", 91 | metavar="mode", 92 | help="'train' or 'test' or 'dynamics'", 93 | type=str, 94 | ) 95 | parser.add_argument( 96 | "--iter", 97 | default=1, 98 | help="Number of iterations", 99 | type=int 100 | ) 101 | parser.add_argument( 102 | "--opts", 103 | help="Modify config options using the command-line", 104 | default=[], 105 | nargs=argparse.REMAINDER, 106 | ) 107 | 108 | args = parser.parse_args() 109 | 110 | # build the config 111 | cfg = get_cfg_defaults() 112 | cfg.merge_from_file(args.config_file) 113 | cfg.merge_from_list(args.opts) 114 | cfg.freeze() 115 | 116 | # TRAIN 117 | if args.mode == "train": 118 | train(cfg, args.iter) 119 | elif args.mode == "dynamics": 120 | train_dynamics_model(cfg, args.iter) 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model 2 | 3 | __all__ = ["build_model"] 4 | -------------------------------------------------------------------------------- /model/archs/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import Basic 2 | 3 | __all__ = ["Basic"] 4 | -------------------------------------------------------------------------------- /model/archs/basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.blocks import build_policy, mj_torch_block_factory 4 | from copy import deepcopy 5 | from solver import build_optimizer 6 | import numpy as np 7 | from ..blocks.policy.strategies import * 8 | 9 | 10 | class Basic(nn.Module): 11 | def __init__(self, cfg, agent): 12 | """Build the model from the fed config node. 13 | :param cfg: CfgNode containing the configurations of everything. 14 | """ 15 | super(Basic, self).__init__() 16 | self.cfg = cfg 17 | self.agent = agent 18 | 19 | # build policy net 20 | self.policy_net = build_policy(cfg, self.agent) 21 | 22 | # Make sure unwrapped agent can call policy_net 23 | self.agent.unwrapped.policy_net = self.policy_net 24 | 25 | # build forward dynamics block 26 | self.dynamics_block = mj_torch_block_factory(agent, 'dynamics').apply 27 | self.reward_block = mj_torch_block_factory(agent, 'reward').apply 28 | 29 | def forward(self, state): 30 | """Single pass. 
31 | :param state: 32 | :return: 33 | """ 34 | 35 | # We're generally using torch.float64 and numpy.float64 for precision, but the net can be trained with 36 | # torch.float32 -- not sure if this really makes a difference wrt speed or memory, but the default layers 37 | # seem to be using torch.float32 38 | action = self.policy_net(state.detach().float()).double() 39 | 40 | # Forward block will drive the simulation forward 41 | next_state = self.dynamics_block(state, action) 42 | 43 | # The reward is actually calculated in the dynamics_block, here we'll just grab it from the agent 44 | reward = self.reward_block(state, action) 45 | 46 | return next_state, reward 47 | -------------------------------------------------------------------------------- /model/blocks/__init__.py: -------------------------------------------------------------------------------- 1 | from .policy import build_policy 2 | from .mujoco import mj_torch_block_factory 3 | 4 | __all__ = ["build_policy", "mj_torch_block_factory"] 5 | -------------------------------------------------------------------------------- /model/blocks/mujoco.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import autograd, nn 3 | import numpy as np 4 | 5 | 6 | def mj_torch_block_factory(agent, mode): 7 | mj_forward = agent.forward_factory(mode) 8 | mj_gradients = agent.gradient_factory(mode) 9 | 10 | class MjBlock(autograd.Function): 11 | 12 | @staticmethod 13 | def forward(ctx, state, action): 14 | 15 | # Advance simulation or return reward 16 | if mode == "dynamics": 17 | # We need to get a deep copy of simulation data so we can return to this "snapshot" 18 | # (we can't deepcopy agent.sim.data because some variables are unpicklable) 19 | # We'll calculate gradients in the backward phase of "reward" 20 | agent.data.qpos[:] = state[:agent.model.nq].detach().numpy().copy() 21 | agent.data.qvel[:] = state[agent.model.nq:].detach().numpy().copy() 22 | agent.data.ctrl[:] = action.detach().numpy().copy() 23 | agent.data_snapshot = agent.get_snapshot() 24 | 25 | next_state = mj_forward() 26 | agent.next_state = next_state 27 | 28 | ctx.data_snapshot = agent.data_snapshot 29 | # ctx.reward = agent.reward 30 | # ctx.next_state = agent.next_state 31 | 32 | return torch.from_numpy(next_state) 33 | 34 | elif mode == "reward": 35 | ctx.data_snapshot = agent.data_snapshot 36 | ctx.reward = agent.reward 37 | ctx.next_state = agent.next_state 38 | return torch.Tensor([agent.reward]).double() 39 | 40 | else: 41 | raise TypeError("mode has to be 'dynamics' or 'gradient'") 42 | 43 | @staticmethod 44 | def backward(ctx, grad_output): 45 | 46 | if agent.cfg.MODEL.POLICY.GRAD_WEIGHTS == "prioritise": 47 | weight = agent.cfg.MODEL.POLICY.MAX_HORIZON_STEPS - ctx.data_snapshot.step_idx.value 48 | else: 49 | weight = 1 / (agent.cfg.MODEL.POLICY.MAX_HORIZON_STEPS - ctx.data_snapshot.step_idx.value) 50 | 51 | # We should need to calculate gradients only once per dynamics/reward cycle 52 | if mode == "dynamics": 53 | 54 | state_jacobian = torch.from_numpy(agent.dynamics_gradients["state"]) 55 | action_jacobian = torch.from_numpy(agent.dynamics_gradients["action"]) 56 | 57 | if agent.cfg.MODEL.POLICY.GRAD_WEIGHTS == "prioritise": 58 | action_jacobian = (1.0 / agent.running_sum) * action_jacobian 59 | elif agent.cfg.MODEL.POLICY.GRAD_WEIGHTS == "average": 60 | action_jacobian = weight * action_jacobian 61 | 62 | elif mode == "reward": 63 | 64 | if agent.cfg.MODEL.POLICY.GRAD_WEIGHTS == "prioritise": 65 | 
agent.running_sum += weight 66 | 67 | # Calculate gradients, "reward" is always called first 68 | mj_gradients(ctx.data_snapshot, ctx.next_state, ctx.reward, test=True) 69 | state_jacobian = torch.from_numpy(agent.reward_gradients["state"]) 70 | action_jacobian = torch.from_numpy(agent.reward_gradients["action"]) 71 | 72 | if agent.cfg.MODEL.POLICY.GRAD_WEIGHTS == "prioritise": 73 | state_jacobian = (weight - 1) * state_jacobian 74 | action_jacobian = (weight / agent.running_sum) * action_jacobian 75 | elif agent.cfg.MODEL.POLICY.GRAD_WEIGHTS == "average": 76 | action_jacobian = weight * action_jacobian 77 | 78 | else: 79 | raise TypeError("mode has to be 'dynamics' or 'reward'") 80 | 81 | if False: 82 | torch.set_printoptions(precision=5, sci_mode=False) 83 | print("{} {}".format(ctx.data_snapshot.time, mode)) 84 | print("grad_output") 85 | print(grad_output) 86 | print("state_jacobian") 87 | print(state_jacobian) 88 | print("action_jacobian") 89 | print(action_jacobian) 90 | print("grad_output*state_jacobian") 91 | print(torch.matmul(grad_output, state_jacobian)) 92 | print("grad_output*action_jacobian") 93 | print(torch.matmul(grad_output, action_jacobian)) 94 | print('weight: {} ---- 1/({} - {})'.format(weight, agent.cfg.MODEL.POLICY.MAX_HORIZON_STEPS, ctx.data_snapshot.step_idx)) 95 | print("") 96 | 97 | ds = torch.matmul(grad_output, state_jacobian) 98 | da = torch.matmul(grad_output, action_jacobian) 99 | 100 | #threshold = 5 101 | #ds.clamp_(-threshold, threshold) 102 | #da.clamp_(-threshold, threshold) 103 | 104 | return ds, da 105 | 106 | return MjBlock 107 | 108 | -------------------------------------------------------------------------------- /model/blocks/policy/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_policy 2 | from .deterministic import DeterministicPolicy 3 | from .trajopt import TrajOpt 4 | from .stochastic import StochasticPolicy 5 | from .strategies import CMAES, VariationalOptimization, Perttu 6 | 7 | __all__ = ["build_policy", "DeterministicPolicy", "StochasticPolicy", "TrajOpt", "CMAES", "VariationalOptimization", 8 | "Perttu"] 9 | -------------------------------------------------------------------------------- /model/blocks/policy/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BasePolicy(nn.Module): 5 | def __init__(self, policy_cfg, agent): 6 | super(BasePolicy, self).__init__() 7 | self.policy_cfg = policy_cfg 8 | self.agent = agent 9 | 10 | def forward(self, s): 11 | raise NotImplementedError 12 | 13 | def episode_callback(self): 14 | raise NotImplementedError 15 | 16 | def batch_callback(self): 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /model/blocks/policy/build.py: -------------------------------------------------------------------------------- 1 | from model.blocks import policy 2 | 3 | 4 | def build_policy(cfg, agent): 5 | policy_factory = getattr(policy, cfg.MODEL.POLICY.ARCH) 6 | policy_net = policy_factory(cfg, agent) 7 | #if cfg.MODEL.POLICY.OBS_SCALER: 8 | # from .wrappers import ZMUSWrapper 9 | # policy_net = ZMUSWrapper(policy_net) 10 | return policy_net 11 | 12 | -------------------------------------------------------------------------------- /model/blocks/policy/deterministic.py: -------------------------------------------------------------------------------- 1 | from .base import BasePolicy 2 | from 
model.layers import FeedForward 3 | from torch.nn.parameter import Parameter 4 | import torch 5 | 6 | 7 | class DeterministicPolicy(BasePolicy): 8 | def __init__(self, policy_cfg, agent): 9 | super(DeterministicPolicy, self).__init__(policy_cfg, agent) 10 | 11 | # for deterministic policies the model is just a simple feed forward net 12 | self.net = FeedForward( 13 | agent.observation_space.shape[0], 14 | policy_cfg.LAYERS, 15 | agent.action_space.shape[0]) 16 | 17 | if policy_cfg.VARIATIONAL: 18 | self.logstd = Parameter(torch.empty(policy_cfg.MAX_HORIZON_STEPS, agent.action_space.shape[0]).fill_(0)) 19 | self.index = 0 20 | 21 | def forward(self, s): 22 | # Get mean of action value 23 | a_mean = self.net(s) 24 | 25 | if self.policy_cfg.VARIATIONAL: 26 | # Get standard deviation of action 27 | a_std = torch.exp(self.logstd[self.index]) 28 | self.index += 1 29 | 30 | # Sample with re-parameterization trick 31 | a = torch.distributions.Normal(a_mean, a_std).rsample() 32 | 33 | else: 34 | a = a_mean 35 | 36 | return a 37 | 38 | def episode_callback(self): 39 | if self.policy_cfg.VARIATIONAL: 40 | self.index = 0 41 | 42 | def batch_callback(self): 43 | pass 44 | 45 | -------------------------------------------------------------------------------- /model/blocks/policy/dynamics.py: -------------------------------------------------------------------------------- 1 | from solver import build_optimizer 2 | import torch 3 | 4 | 5 | class DynamicsModel(torch.nn.Module): 6 | def __init__(self, agent): 7 | super(DynamicsModel, self).__init__() 8 | 9 | # We're using a simple two layered feedforward net 10 | self.net = torch.nn.Sequential( 11 | torch.nn.Linear(agent.observation_space.shape[0] + agent.action_space.shape[0], 128), 12 | torch.nn.LeakyReLU(), 13 | torch.nn.Linear(128, 128), 14 | torch.nn.LeakyReLU(), 15 | torch.nn.Linear(128, agent.observation_space.shape[0]) 16 | ).to(torch.device("cpu")) 17 | 18 | self.optimizer = torch.optim.Adam(self.net.parameters(), 0.001) 19 | 20 | def forward(self, state, action): 21 | 22 | # Predict next state given current state and action 23 | next_state = self.net(torch.cat([state, action])) 24 | 25 | return next_state 26 | 27 | -------------------------------------------------------------------------------- /model/blocks/policy/stochastic.py: -------------------------------------------------------------------------------- 1 | from .base import BasePolicy 2 | from model.layers import FeedForward 3 | import torch 4 | import torch.nn as nn 5 | import torch.distributions as tdist 6 | from model.blocks.utils import build_soft_lower_bound_fn 7 | 8 | 9 | class StochasticPolicy(BasePolicy): 10 | def __init__(self, policy_cfg, agent): 11 | super(StochasticPolicy, self).__init__(policy_cfg, agent) 12 | 13 | # Get number of actions 14 | self.num_actions = agent.action_space.shape[0] 15 | 16 | # The network outputs a gaussian distribution 17 | self.net = FeedForward( 18 | agent.observation_space.shape[0], 19 | policy_cfg.LAYERS, 20 | self.num_actions*2 21 | ) 22 | 23 | def forward(self, s): 24 | # Get means and logs of standard deviations 25 | output = self.net(s) 26 | means = output[:self.num_actions] 27 | log_stds = output[self.num_actions:] 28 | 29 | # Return only means when testing 30 | if not self.training: 31 | return means 32 | 33 | # Get the actual standard deviations 34 | stds = torch.exp(log_stds) 35 | 36 | # Sample with re-parameterization trick 37 | a = tdist.Normal(means, stds).rsample() 38 | 39 | return a 40 | 41 | def episode_callback(self): 42 | pass 43 | 
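    # Note: the rsample() call in forward() above is the "re-parameterization
    # trick" described in the README -- the sampled action is expressed as
    # mean + std * eps with eps ~ N(0, 1), so gradients from downstream losses
    # flow back into the network outputs instead of requiring REINFORCE.
    # A minimal standalone illustration (the names below are illustrative and
    # not part of this repo):
    #
    #   means = torch.zeros(3, requires_grad=True)
    #   log_stds = torch.full((3,), -0.5, requires_grad=True)
    #   a = torch.distributions.Normal(means, log_stds.exp()).rsample()
    #   a.sum().backward()   # means.grad and log_stds.grad are now populated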
44 | def batch_callback(self): 45 | pass 46 | 47 | -------------------------------------------------------------------------------- /model/blocks/policy/strategies.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | import numpy as np 4 | import cma 5 | from torch import nn 6 | from model.layers import FeedForward 7 | from solver import build_optimizer 8 | from torch.nn.parameter import Parameter 9 | from optimizer import Optimizer 10 | 11 | 12 | class BaseStrategy(ABC, nn.Module): 13 | 14 | def __init__(self, cfg, agent, reinforce_loss_weight=1.0, 15 | min_reinforce_loss_weight=0.0, min_sd=0, soft_relu_beta=0.2, 16 | adam_betas=(0.9, 0.999)): 17 | 18 | super(BaseStrategy, self).__init__() 19 | 20 | self.cfg = cfg 21 | self.method = cfg.MODEL.POLICY.METHOD 22 | 23 | # Get action dimension, horizon length, and batch size 24 | self.action_dim = agent.action_space.sample().shape[0] 25 | self.state_dim = agent.observation_space.sample().shape[0] 26 | self.horizon = cfg.MODEL.POLICY.MAX_HORIZON_STEPS 27 | self.dim = (self.action_dim, self.horizon) 28 | self.batch_size = cfg.MODEL.BATCH_SIZE 29 | 30 | # Set initial values 31 | self.mean = [] 32 | self.clamped_action = [] 33 | self.sd = [] 34 | #self.clamped_sd = [] 35 | self.min_sd = min_sd 36 | self.soft_relu_beta = soft_relu_beta 37 | self.best_actions = [] 38 | self.sd_threshold = 0.0001 39 | 40 | # Set loss parameters 41 | self.gamma = cfg.MODEL.POLICY.GAMMA 42 | self.min_reinforce_loss_weight = min_reinforce_loss_weight 43 | self.reinforce_loss_weight = reinforce_loss_weight 44 | self.eps = np.finfo(np.float32).eps.item() 45 | self.loss_functions = {"IR": self.IR_loss, "PR": self.PR_loss, "H": self.H_loss, 46 | "SIR": self.SIR_loss, "R": self.R_loss} 47 | self.log_prob = [] 48 | 49 | # Set optimizer parameters 50 | self.optimizer = None 51 | self.learning_rate = cfg.SOLVER.BASE_LR 52 | self.adam_betas = adam_betas 53 | self.optimize_functions = {"default": self.standard_optimize, "H": self.H_optimize} 54 | 55 | # Get references in case we want to track episode and step 56 | self.step_idx = agent.get_step_idx() 57 | self.episode_idx = agent.get_episode_idx() 58 | 59 | @abstractmethod 60 | def optimize(self, batch_loss): 61 | pass 62 | 63 | @abstractmethod 64 | def forward(self, state): 65 | pass 66 | 67 | @staticmethod 68 | def clip(x, mean, limit): 69 | xmin = mean - limit 70 | xmax = mean + limit 71 | return torch.max(torch.min(x, xmax), xmin) 72 | 73 | @staticmethod 74 | def clip_negative(x): 75 | #return torch.max(x, torch.zeros(x.shape, dtype=torch.double)) 76 | return torch.abs(x) 77 | 78 | @staticmethod 79 | def soft_relu(x, beta=0.1): 80 | return (torch.sqrt(x**2 + beta**2) + x) / 2.0 81 | 82 | # From Tassa 2012 synthesis and stabilization of complex behaviour 83 | def clamp_sd(self, sd): 84 | #return torch.max(sd, 0.001*torch.ones(sd.shape, dtype=torch.double)) 85 | #return torch.exp(sd) 86 | #return self.min_sd + self.soft_relu(sd - self.min_sd, self.soft_relu_beta) 87 | return sd 88 | 89 | def clamp_action(self, action): 90 | #return self.soft_relu(action, self.soft_relu_beta) 91 | # return torch.clamp(action, min=0.0, max=1.0) # NB! 
Kills gradients at borders 92 | #return torch.exp(action) 93 | 94 | if self.cfg.MODEL.POLICY.NETWORK: 95 | #for idx in range(len(action)): 96 | # if idx < 10: 97 | # action[idx] = action[idx]/10 98 | # else: 99 | # action[idx] = torch.exp(action[idx])/20 100 | pass 101 | #else: 102 | # for idx in range(len(action)): 103 | # if idx < 10 or action[idx] >= 0: 104 | # continue 105 | # elif action[idx] < 0: 106 | # action[idx].data -= action[idx].data 107 | 108 | return action 109 | 110 | def get_clamped_sd(self): 111 | return self.clamp_sd(self.sd).detach().numpy() 112 | 113 | def get_clamped_action(self): 114 | return self.clamped_action 115 | 116 | def initialise_mean(self, loc=0.0, sd=0.1, seed=0): 117 | if self.dim is not None: 118 | #return np.zeros(self.dim, dtype=np.float64) 119 | if seed > 0: 120 | np.random.seed(seed) 121 | return np.asarray(np.random.normal(loc, sd, self.dim), dtype=np.float64) 122 | 123 | def initialise_sd(self, factor=1.0): 124 | if self.dim is not None: 125 | return factor*np.ones(self.dim, dtype=np.float64) 126 | 127 | # Initialise mean / sd or get dim from initial values 128 | 129 | def calculate_reinforce_loss(self, batch_loss, stepwise_loss=False): 130 | # Initialise a tensor for returns 131 | returns = torch.empty(batch_loss.shape, dtype=torch.float64) 132 | 133 | # Calculate returns 134 | for episode_idx, episode_loss in enumerate(batch_loss): 135 | R = 0 136 | for step_idx in range(1, len(episode_loss)+1): 137 | R = -episode_loss[-step_idx] + self.gamma*R 138 | returns[episode_idx, -step_idx] = R 139 | 140 | # Remove baseline 141 | advantages = ((returns - returns.mean(dim=0)) / (returns.std(dim=0) + self.eps)).detach() 142 | 143 | # Return REINFORCE loss 144 | #reinforce_loss = torch.mean((-advantages * self.log_prob).sum(dim=1)) 145 | #reinforce_loss = torch.sum((-advantages * self.log_prob)) 146 | avg_stepwise_loss = (-advantages * self.log_prob).mean(dim=0) 147 | if stepwise_loss: 148 | return avg_stepwise_loss 149 | else: 150 | #return torch.sum(avg_stepwise_loss) 151 | return torch.mean(torch.sum((-advantages * self.log_prob), dim=1)) 152 | 153 | def calculate_objective_loss(self, batch_loss, stepwise_loss=False): 154 | #batch_loss = (batch_loss.mean(dim=0) - batch_loss) / (batch_loss.std(dim=0) + self.eps) 155 | #batch_loss = batch_loss.mean() - batch_loss 156 | 157 | # Initialise a tensor for returns 158 | #ret = torch.empty(batch_loss.shape, dtype=torch.float64) 159 | ## Calculate returns 160 | #for episode_idx, episode_loss in enumerate(batch_loss): 161 | # R = torch.zeros(1, dtype=torch.float64) 162 | # for step_idx in range(1, len(episode_loss)+1): 163 | # R = -episode_loss[-step_idx] + self.gamma*R.detach() 164 | # ret[episode_idx, -step_idx] = R 165 | #avg_stepwise_loss = torch.mean(-ret, dim=0) 166 | #advantages = ret - ret.mean(dim=0) 167 | 168 | #means = torch.mean(batch_loss, dim=0) 169 | #idxs = (batch_loss > torch.mean(batch_loss)).detach() 170 | #idxs = sums < np.mean(sums) 171 | 172 | #advantages = torch.zeros((1, batch_loss.shape[1])) 173 | #advantages = [] 174 | #idxs = batch_loss < torch.mean(batch_loss) 175 | #for i in range(batch_loss.shape[1]): 176 | # avg = torch.mean(batch_loss[:,i]).detach().numpy() 177 | # idxs = batch_loss[:,i].detach().numpy() < avg 178 | # advantages[0,i] = torch.mean(batch_loss[idxs,i], dim=0) 179 | # if torch.any(idxs[:,i]): 180 | # advantages.append(torch.mean(batch_loss[idxs[:,i],i], dim=0)) 181 | 182 | #batch_loss = (batch_loss - batch_loss.mean(dim=1).detach()) / (batch_loss.std(dim=1).detach() + 
self.eps) 183 | #batch_loss = (batch_loss - batch_loss.mean(dim=0).detach()) / (batch_loss.std(dim=0).detach() + self.eps) 184 | #batch_loss = (batch_loss - batch_loss.mean(dim=0).detach()) 185 | 186 | #fvals = torch.empty(batch_loss.shape[0]) 187 | #for episode_idx, episode_loss in enumerate(batch_loss): 188 | # nans = torch.isnan(episode_loss) 189 | # fvals[episode_idx] = torch.sum(episode_loss[nans == False]) 190 | 191 | nans = torch.isnan(batch_loss) 192 | batch_loss[nans] = 0 193 | 194 | if stepwise_loss: 195 | return torch.mean(batch_loss, dim=0) 196 | else: 197 | #return torch.sum(batch_loss[0][100]) 198 | #return torch.sum(advantages[advantages<0]) 199 | #idxs = torch.isnan(batch_loss) == False 200 | #return torch.sum(batch_loss[idxs]) 201 | #return torch.mean(fvals) 202 | #return torch.mean(torch.sum(batch_loss, dim=1)) 203 | 204 | return torch.sum(torch.mean(batch_loss, dim=0)) 205 | #return batch_loss[0][-1] 206 | #return torch.sum(torch.stack(advantages)) 207 | 208 | # Remove baseline 209 | #advantages = ((ret - ret.mean(dim=0)) / (ret.std(dim=0) + self.eps)) 210 | #advantages = ret 211 | 212 | # Return REINFORCE loss 213 | #avg_stepwise_loss = -advantages.mean(dim=0) 214 | #if stepwise_loss: 215 | # return avg_stepwise_loss 216 | #else: 217 | # return torch.sum(avg_stepwise_loss) 218 | 219 | def R_loss(self, batch_loss): 220 | 221 | # Get objective loss 222 | objective_loss = self.calculate_objective_loss(batch_loss) 223 | 224 | return objective_loss, {"objective_loss": float(torch.mean(torch.sum(batch_loss, dim=1)).detach().numpy()), 225 | "total_loss": float(objective_loss.detach().numpy())} 226 | #"total_loss": batch_loss.detach().numpy()} 227 | 228 | 229 | def PR_loss(self, batch_loss): 230 | 231 | # Get REINFORCE loss 232 | reinforce_loss = self.calculate_reinforce_loss(batch_loss) 233 | 234 | # Get objective loss 235 | objective_loss = self.calculate_objective_loss(batch_loss) 236 | 237 | # Return a sum of the objective loss and REINFORCE loss 238 | loss = objective_loss + self.reinforce_loss_weight*reinforce_loss 239 | 240 | return loss, {"objective_loss": float(torch.mean(torch.sum(batch_loss, dim=1)).detach().numpy()), 241 | "reinforce_loss": float(reinforce_loss.detach().numpy()), 242 | "total_loss": float(loss.detach().numpy())} 243 | 244 | def IR_loss(self, batch_loss): 245 | 246 | # Get REINFORCE loss 247 | reinforce_loss = self.calculate_reinforce_loss(batch_loss) 248 | 249 | # Get objective loss 250 | objective_loss = self.calculate_objective_loss(batch_loss) 251 | 252 | # Return an interpolated mix of the objective loss and REINFORCE loss 253 | clamped_sd = self.clamp_sd(self.sd) 254 | mix_factor = 0.5*((clamped_sd - self.min_sd) / (self.initial_clamped_sd - self.min_sd)).mean().detach().numpy() 255 | mix_factor = self.min_reinforce_loss_weight + (1.0 - self.min_reinforce_loss_weight)*mix_factor 256 | mix_factor = np.maximum(np.minimum(mix_factor, 1), 0) 257 | 258 | loss = (1.0 - mix_factor) * objective_loss + mix_factor * reinforce_loss 259 | return loss, {"objective_loss": [objective_loss.detach().numpy()], 260 | "reinforce_loss": [reinforce_loss.detach().numpy()], 261 | "total_loss": [loss.detach().numpy()]} 262 | 263 | def SIR_loss(self, batch_loss): 264 | 265 | # Get REINFORCE loss 266 | reinforce_loss = self.calculate_reinforce_loss(batch_loss, stepwise_loss=True) 267 | 268 | # Get objective loss 269 | objective_loss = self.calculate_objective_loss(batch_loss, stepwise_loss=True) 270 | 271 | # Return an interpolated mix of the objective loss and REINFORCE 
loss 272 | clamped_sd = self.clamp_sd(self.sd) 273 | mix_factor = 0.5*((clamped_sd - self.min_sd) / (self.initial_clamped_sd - self.min_sd)).detach() 274 | mix_factor = self.min_reinforce_loss_weight + (1.0 - self.min_reinforce_loss_weight)*mix_factor 275 | mix_factor = torch.max(torch.min(mix_factor, torch.ones_like(mix_factor)), torch.zeros_like(mix_factor)) 276 | 277 | loss = ((1.0 - mix_factor) * objective_loss + mix_factor * reinforce_loss).sum() 278 | return loss, {"objective_loss": [objective_loss.detach().numpy()], 279 | "reinforce_loss": [reinforce_loss.detach().numpy()], 280 | "total_loss": [loss.detach().numpy()]} 281 | 282 | def H_loss(self, batch_loss): 283 | 284 | # Get REINFORCE loss 285 | reinforce_loss = self.calculate_reinforce_loss(batch_loss) 286 | 287 | # Get objective loss 288 | objective_loss = torch.sum(batch_loss, dim=1) 289 | 290 | # Get idx of best sampled action values 291 | best_idx = objective_loss.argmin() 292 | 293 | # If some other sample than the mean was best, update the mean to it 294 | if best_idx != 0: 295 | self.best_actions = self.clamped_action[:, :, best_idx].copy() 296 | 297 | return reinforce_loss, objective_loss[best_idx], \ 298 | {"objective_loss": [objective_loss.detach().numpy()], 299 | "reinforce_loss": [reinforce_loss.detach().numpy()], 300 | "total_loss": [reinforce_loss.detach().numpy()]} 301 | 302 | def get_named_parameters(self, name): 303 | params = {} 304 | for key, val in self.named_parameters(): 305 | if key.split(".")[0] == name: 306 | params[key] = val 307 | 308 | # Return a generator that behaves like self.named_parameters() 309 | return ((x, params[x]) for x in params) 310 | 311 | def standard_optimize(self, batch_loss): 312 | 313 | # Get appropriate loss 314 | loss, stats = self.loss_functions[self.method](batch_loss) 315 | 316 | self.optimizer.zero_grad() 317 | loss.backward() 318 | #nn.utils.clip_grad_norm_(self.parameters(), 0.1) 319 | if self.cfg.SOLVER.OPTIMIZER == "sgd": 320 | nn.utils.clip_grad_value_(self.parameters(), 1) 321 | 322 | #total_norm_sqr = 0 323 | #total_norm_sqr += self.mean._layers["linear_layer_0"].weight.grad.norm() ** 2 324 | #total_norm_sqr += self.mean._layers["linear_layer_0"].bias.grad.norm() ** 2 325 | #total_norm_sqr += self.mean._layers["linear_layer_1"].weight.grad.norm() ** 2 326 | #total_norm_sqr += self.mean._layers["linear_layer_1"].bias.grad.norm() ** 2 327 | #total_norm_sqr += self.mean._layers["linear_layer_2"].weight.grad.norm() ** 2 328 | #total_norm_sqr += self.mean._layers["linear_layer_2"].bias.grad.norm() ** 2 329 | 330 | #gradient_clip = 0.01 331 | #scale = min(1.0, gradient_clip / (total_norm_sqr ** 0.5 + 1e-4)) 332 | #print("lr: ", scale * self.cfg.SOLVER.BASE_LR) 333 | 334 | #if total_norm_sqr ** 0.5 > gradient_clip: 335 | 336 | # for param in self.parameters(): 337 | # param.data.sub_(scale * self.cfg.SOLVER.BASE_LR) 338 | # if param.grad is not None: 339 | # param.grad = (param.grad * gradient_clip) / (total_norm_sqr ** 0.5) 340 | #pass 341 | self.optimizer.step() 342 | 343 | #print("ll0 weight: {} {}".format(self.mean._layers["linear_layer_0"].weight_std.min(), 344 | # self.mean._layers["linear_layer_0"].weight_std.max())) 345 | #print("ll1 weight: {} {}".format(self.mean._layers["linear_layer_1"].weight_std.min(), 346 | # self.mean._layers["linear_layer_1"].weight_std.max())) 347 | #print("ll2 weight: {} {}".format(self.mean._layers["linear_layer_2"].weight_std.min(), 348 | # self.mean._layers["linear_layer_2"].weight_std.max())) 349 | #print("ll0 bias: {} 
{}".format(self.mean._layers["linear_layer_0"].bias_std.min(), 350 | # self.mean._layers["linear_layer_0"].bias_std.max())) 351 | #print("ll1 bias: {} {}".format(self.mean._layers["linear_layer_1"].bias_std.min(), 352 | # self.mean._layers["linear_layer_1"].bias_std.max())) 353 | #print("ll2 bias: {} {}".format(self.mean._layers["linear_layer_2"].bias_std.min(), 354 | # self.mean._layers["linear_layer_2"].bias_std.max())) 355 | #print("") 356 | 357 | # Make sure sd is not negative 358 | idxs = self.sd < self.sd_threshold 359 | self.sd.data[idxs] = self.sd_threshold 360 | 361 | # Empty log probs 362 | self.log_prob = torch.empty(self.batch_size, self.horizon, dtype=torch.float64) 363 | 364 | return stats 365 | 366 | def H_optimize(self, batch_loss): 367 | 368 | # Get appropriate loss 369 | loss, best_objective_loss, stats = self.loss_functions[self.method](batch_loss) 370 | 371 | # Adapt the mean 372 | self.optimizer["mean"].zero_grad() 373 | best_objective_loss.backward(retain_graph=True) 374 | self.optimizer["mean"].step() 375 | 376 | # Adapt the sd 377 | self.optimizer["sd"].zero_grad() 378 | loss.backward() 379 | self.optimizer["sd"].step() 380 | 381 | # Empty log probs 382 | self.log_prob = torch.empty(self.batch_size, self.horizon, dtype=torch.float64) 383 | 384 | return stats 385 | 386 | 387 | class VariationalOptimization(BaseStrategy): 388 | 389 | def __init__(self, *args, **kwargs): 390 | super(VariationalOptimization, self).__init__(*args, **kwargs) 391 | 392 | # Initialise mean and sd 393 | if self.cfg.MODEL.POLICY.NETWORK: 394 | # Set a feedforward network for means 395 | self.mean = FeedForward( 396 | self.state_dim, 397 | self.cfg.MODEL.POLICY.LAYERS, 398 | self.action_dim 399 | ) 400 | else: 401 | # Set tensors for means 402 | self.mean = Parameter(torch.from_numpy( 403 | self.initialise_mean(self.cfg.MODEL.POLICY.INITIAL_ACTION_MEAN, self.cfg.MODEL.POLICY.INITIAL_ACTION_SD) 404 | )) 405 | self.register_parameter("mean", self.mean) 406 | 407 | # Set tensors for standard deviations 408 | self.sd = Parameter(torch.from_numpy(self.initialise_sd(self.cfg.MODEL.POLICY.INITIAL_SD))) 409 | self.initial_clamped_sd = self.clamp_sd(self.sd.detach()) 410 | self.register_parameter("sd", self.sd) 411 | #self.clamped_sd = np.zeros((self.action_dim, self.horizon), dtype=np.float64) 412 | self.clamped_action = np.zeros((self.action_dim, self.horizon, self.batch_size), dtype=np.float64) 413 | 414 | # Initialise optimizer 415 | if self.method == "H": 416 | # Separate mean and sd optimizers (not sure if actually necessary) 417 | self.optimizer = {"mean": build_optimizer(self.cfg, self.get_named_parameters("mean")), 418 | "sd": build_optimizer(self.cfg, self.get_named_parameters("sd"))} 419 | self.best_actions = np.empty(self.sd.shape) 420 | self.best_actions.fill(np.nan) 421 | else: 422 | self.optimizer = build_optimizer(self.cfg, self.named_parameters()) 423 | 424 | # We need log probabilities for calculating REINFORCE loss 425 | self.log_prob = torch.empty(self.batch_size, self.horizon, dtype=torch.float64) 426 | 427 | def forward(self, state): 428 | 429 | # Get clamped sd 430 | clamped_sd = self.clamp_sd(self.sd[:, self.step_idx]) 431 | 432 | # Get mean of action value 433 | if self.cfg.MODEL.POLICY.NETWORK: 434 | mean = self.mean(state).double() 435 | else: 436 | mean = self.mean[:, self.step_idx] 437 | 438 | if not self.training: 439 | return self.clamp_action(mean) 440 | 441 | # Get normal distribution 442 | dist = torch.distributions.Normal(mean, clamped_sd) 443 | 444 | # Sample action 
445 | if self.method == "H" and self.episode_idx == 0: 446 | if np.all(np.isnan(self.best_actions[:, self.step_idx])): 447 | action = mean 448 | else: 449 | action = torch.from_numpy(self.best_actions[:, self.step_idx]) 450 | elif self.batch_size > 1: 451 | action = dist.rsample() 452 | else: 453 | action = mean 454 | 455 | # Clip action 456 | action = self.clamp_action(action) 457 | #action = self.clip(action, mean, 2.0*clamped_sd) 458 | 459 | self.clamped_action[:, self.step_idx, self.episode_idx-1] = action.detach().numpy() 460 | 461 | # Get log prob for REINFORCE loss calculations 462 | self.log_prob[self.episode_idx-1, self.step_idx] = dist.log_prob(action.detach()).sum() 463 | 464 | return action 465 | 466 | def optimize(self, batch_loss): 467 | return self.optimize_functions.get(self.method, self.standard_optimize)(batch_loss) 468 | 469 | 470 | class CMAES(BaseStrategy): 471 | 472 | def __init__(self, *args, **kwargs): 473 | super(CMAES, self).__init__(*args, **kwargs) 474 | 475 | # Make sure batch size is larger than one 476 | assert self.batch_size > 1, "Batch size must be >1 for CMA-ES" 477 | 478 | # Set up CMA-ES options 479 | cmaes_options = {"popsize": self.batch_size, "CMA_diagonal": True} 480 | 481 | # Initialise mean and flatten it 482 | self.mean = self.initialise_mean() 483 | self.mean = np.reshape(self.mean, (self.mean.size,)) 484 | 485 | # We want to store original (list of batches) actions for tell 486 | self.orig_actions = [] 487 | 488 | # Initialise CMA-ES 489 | self.optimizer = cma.CMAEvolutionStrategy(self.mean, self.cfg.MODEL.POLICY.INITIAL_SD, inopts=cmaes_options) 490 | self.actions = [] 491 | 492 | def forward(self, state): 493 | 494 | # If we've hit the end of minibatch we need to sample more actions 495 | if self.step_idx == 0 and self.episode_idx - 1 == 0 and self.training: 496 | self.orig_actions = self.optimizer.ask() 497 | self.actions = torch.empty(self.action_dim, self.horizon, self.batch_size, dtype=torch.float64) 498 | for ep_idx, ep_actions in enumerate(self.orig_actions): 499 | self.actions[:, :, ep_idx] = torch.from_numpy(np.reshape(ep_actions, (self.action_dim, self.horizon))) 500 | 501 | # Get action 502 | action = self.actions[:, self.step_idx, self.episode_idx-1] 503 | 504 | return action 505 | 506 | def optimize(self, batch_loss): 507 | loss = batch_loss.sum(axis=1) 508 | self.optimizer.tell(self.orig_actions, loss.detach().numpy()) 509 | return {"objective_loss": float(loss.detach().numpy().mean()), "total_loss": float(loss.detach().numpy().mean())} 510 | 511 | def get_clamped_sd(self): 512 | return np.asarray(self.sd) 513 | 514 | def get_clamped_action(self): 515 | return self.actions.detach().numpy() 516 | 517 | 518 | class Perttu(BaseStrategy): 519 | def __init__(self, *args, **kwargs): 520 | super(Perttu, self).__init__(*args, **kwargs) 521 | 522 | # Initialise optimizer object 523 | self.optimizer = \ 524 | Optimizer(mode=self.method, 525 | initialMean=np.random.normal(self.cfg.MODEL.POLICY.INITIAL_ACTION_MEAN, 526 | self.cfg.MODEL.POLICY.INITIAL_ACTION_SD, 527 | (self.action_dim, self.horizon)), 528 | initialSd=self.cfg.MODEL.POLICY.INITIAL_SD*np.ones((self.action_dim, self.horizon)), 529 | #initialSd=self.cfg.MODEL.POLICY.INITIAL_SD*np.ones((1, 1)), 530 | learningRate=self.cfg.SOLVER.BASE_LR, 531 | adamBetas=(0.9, 0.99), 532 | minReinforceLossWeight=0.0, 533 | nBatch=self.cfg.MODEL.BATCH_SIZE, 534 | solver=self.cfg.SOLVER.OPTIMIZER) 535 | 536 | def forward(self, state): 537 | 538 | # If we've hit the end of minibatch we need to sample 
more actions 539 | if self.training: 540 | if self.step_idx == 0 and self.episode_idx-1 == 0: 541 | samples = self.optimizer.ask() 542 | self.actions = torch.empty(self.action_dim, self.horizon, self.batch_size) 543 | for ep_idx, ep_actions in enumerate(samples): 544 | self.actions[:, :, ep_idx] = torch.reshape(ep_actions, (self.action_dim, self.horizon)) 545 | 546 | # Get action 547 | action = self.actions[:, self.step_idx, self.episode_idx-1] 548 | 549 | else: 550 | if self.method != "CMA-ES": 551 | if self.step_idx == 0: 552 | samples = self.optimizer.ask(testing=~self.training) 553 | self.actions = torch.empty(self.action_dim, self.horizon, self.batch_size) 554 | for ep_idx, ep_actions in enumerate(samples): 555 | self.actions[:, :, ep_idx] = torch.reshape(ep_actions, (self.action_dim, self.horizon)) 556 | 557 | # Get action 558 | action = self.actions[:, self.step_idx, 0] 559 | 560 | return action.double() 561 | 562 | def optimize(self, batch_loss): 563 | loss, meanFval = self.optimizer.tell(batch_loss) 564 | return {"objective_loss": float(meanFval), "total_loss": float(loss)} 565 | 566 | def get_clamped_sd(self): 567 | return self.optimizer.getClampedSd().reshape([self.optimizer.original_dim, self.optimizer.steps]).detach().numpy() 568 | 569 | def get_clamped_action(self): 570 | return self.actions.detach().numpy() 571 | -------------------------------------------------------------------------------- /model/blocks/policy/trajopt.py: -------------------------------------------------------------------------------- 1 | from .base import BasePolicy 2 | import torch 3 | from torch.nn.parameter import Parameter 4 | from .strategies import * 5 | import numpy as np 6 | 7 | 8 | class TrajOpt(BasePolicy): 9 | """Trajectory Optimization Network""" 10 | def __init__(self, cfg, agent): 11 | super(TrajOpt, self).__init__(cfg, agent) 12 | 13 | # Parametrize optimization actions 14 | action_size = self.agent.action_space.sample().shape[0] 15 | horizon = cfg.MODEL.POLICY.MAX_HORIZON_STEPS 16 | 17 | #self.action_mean = Parameter(torch.zeros(horizon, action_size)) 18 | #nn.init.normal_(self.action_mean, mean=0.0, std=1.0) 19 | #self.register_parameter("action_mean", self.action_mean) 20 | 21 | # Get standard deviations as well when doing variational optimization 22 | #if policy_cfg.VARIATIONAL: 23 | # self.action_std = Parameter(torch.empty(horizon, action_size).fill_(-2)) 24 | # self.register_parameter("action_std", self.action_std) 25 | 26 | self.strategy = PRVO(dim=np.array([action_size, horizon]), nbatch=cfg.MODEL.POLICY.BATCH_SIZE, 27 | gamma=cfg.MODEL.POLICY.GAMMA, learning_rate=cfg.SOLVER.BASE_LR) 28 | 29 | # Set index to zero 30 | self.step_index = 0 31 | self.episode_index = 0 32 | 33 | def forward(self, s): 34 | #if self.policy_cfg.VARIATIONAL: 35 | # action = torch.distributions.Normal(self.action_mean[self.index], torch.exp(self.action_std[self.index])).rsample() 36 | #else: 37 | # action = self.action_mean[self.index] 38 | 39 | # Sample a new set of actions when required 40 | if self.step_index == 0: 41 | self.actions = self.strategy.sample(self.training) 42 | 43 | action = self.actions[:, self.step_index] 44 | self.step_index += 1 45 | return action 46 | 47 | def episode_callback(self): 48 | self.step_index = 0 49 | 50 | def optimize(self, batch_loss): 51 | return self.strategy.optimize(batch_loss) 52 | 53 | -------------------------------------------------------------------------------- /model/blocks/policy/wrappers/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .zmus import ZMUSWrapper 2 | 3 | __all__ = ["ZMUSWrapper"] 4 | -------------------------------------------------------------------------------- /model/blocks/policy/wrappers/zmus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | 6 | class ZMUSWrapper(nn.Module): 7 | """Zero-mean Unit-STD States 8 | see: https://stats.stackexchange.com/questions/43159/how-to-calculate-pooled-variance-of-two-groups-given-known-group-variances-mean 9 | """ 10 | 11 | def __init__(self, policy_net, eps=1e-6): 12 | super(ZMUSWrapper, self).__init__() 13 | 14 | self.eps = eps 15 | self.policy_net = policy_net 16 | self.policy_cfg = self.policy_net.policy_cfg 17 | 18 | # Parameters 19 | state_size = self.policy_net.agent.observation_space.shape 20 | self.state_mean = Parameter(torch.Tensor(state_size, )) 21 | self.state_variance = Parameter(torch.Tensor(state_size, )) 22 | self.state_mean.requires_grad = False 23 | self.state_variance.requires_grad = False 24 | 25 | # cash 26 | self.size = 0 27 | self.ep_states_data = [] 28 | 29 | self.first_pass = True 30 | 31 | def _get_state_mean(self): 32 | return self.state_mean.detach() 33 | 34 | def _get_state_variance(self): 35 | return self.state_variance.detach() 36 | 37 | def forward(self, s): 38 | self.size += 1 39 | self.ep_states_data.append(s) 40 | if not self.first_pass: 41 | s = (s - self._get_state_mean()) / \ 42 | (torch.sqrt(self._get_state_variance()) + self.eps) 43 | return self.policy_net(s) 44 | 45 | def episode_callback(self): 46 | ep_states_tensor = torch.stack(self.ep_states_data) 47 | new_data_mean = torch.mean(ep_states_tensor, dim=0) 48 | new_data_var = torch.var(ep_states_tensor, dim=0) 49 | if self.first_pass: 50 | self.state_mean.data = new_data_mean 51 | self.state_variance.data = new_data_var 52 | self.first_pass = False 53 | else: 54 | n = len(self.ep_states_data) 55 | mean = self._get_state_mean() 56 | var = self._get_state_variance() 57 | new_data_mean_sq = torch.mul(new_data_mean, new_data_mean) 58 | size = min(self.policy_cfg.FORGET_COUNT_OBS_SCALER, self.size) 59 | new_mean = ((mean * size) + (new_data_mean * n)) / (size + n) 60 | new_var = (((size * (var + torch.mul(mean, mean))) + (n * (new_data_var + new_data_mean_sq))) / 61 | (size + n) - torch.mul(new_mean, new_mean)) 62 | self.state_mean.data = new_mean 63 | self.state_variance.data = torch.clamp(new_var, 0.) 
# occasionally goes negative, clip 64 | self.size += n 65 | 66 | def batch_callback(self): 67 | pass 68 | -------------------------------------------------------------------------------- /model/blocks/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_soft_lower_bound_fn 2 | 3 | __all__ = ["build_soft_lower_bound_fn"] 4 | -------------------------------------------------------------------------------- /model/blocks/utils/build.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from .functions import soft_lower_bound 3 | 4 | 5 | def build_soft_lower_bound_fn(policy_cfg): 6 | soft_lower_bound_fn = functools.partial(soft_lower_bound, 7 | bound=policy_cfg.SOFT_LOWER_STD_BOUND, 8 | threshold=policy_cfg.SOFT_LOWER_STD_THRESHOLD, 9 | ) 10 | return soft_lower_bound_fn 11 | -------------------------------------------------------------------------------- /model/blocks/utils/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def soft_lower_bound(x, bound, threshold): 5 | range = threshold - bound 6 | return torch.max(x, range * torch.tanh((x - threshold) / range) + threshold) 7 | -------------------------------------------------------------------------------- /model/build.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model import archs 3 | 4 | 5 | def build_model(cfg, agent): 6 | model_factory = getattr(archs, cfg.MODEL.META_ARCHITECTURE) 7 | model = model_factory(cfg, agent) 8 | # if cfg.MODEL.WEIGHTS != "": 9 | # model.load_state_dict(torch.load(cfg.MODEL.WEIGHTS)) 10 | return model 11 | -------------------------------------------------------------------------------- /model/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import get_cfg_defaults 2 | 3 | __all__ = ["get_cfg_defaults"] 4 | -------------------------------------------------------------------------------- /model/config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # ---------------------------------------------------------------------------- # 4 | # Define Config Node 5 | # ---------------------------------------------------------------------------- # 6 | _C = CN() 7 | 8 | # ---------------------------------------------------------------------------- # 9 | # Model Configs 10 | # ---------------------------------------------------------------------------- # 11 | _C.MODEL = CN() 12 | _C.MODEL.META_ARCHITECTURE = 'Basic' 13 | _C.MODEL.DEVICE = "cpu" 14 | _C.MODEL.WEIGHTS = "" # should be a path to .pth pytorch state dict file 15 | _C.MODEL.EPOCHS = 10000 16 | _C.MODEL.BATCH_SIZE = 8 17 | 18 | # ---------------------------------------------------------------------------- # 19 | # __Policy Net Configs 20 | # ---------------------------------------------------------------------------- # 21 | _C.MODEL.POLICY = CN() 22 | 23 | _C.MODEL.POLICY = CN() 24 | _C.MODEL.POLICY.ARCH = "StochasticPolicy" 25 | _C.MODEL.POLICY.MAX_HORIZON_STEPS = 100 26 | _C.MODEL.POLICY.LAYERS = [64, 64, 32, 8] # a list of hidden layer sizes for output fc. [] means no hidden 27 | #_C.MODEL.POLICY.NORM_LAYERS = [0, 1, 2] # should be a list of layer indices, example [0, 1, ...] 
28 | _C.MODEL.POLICY.STD_SCALER = 1e-1 29 | _C.MODEL.POLICY.SOFT_LOWER_STD_BOUND = 1e-4 30 | _C.MODEL.POLICY.SOFT_LOWER_STD_THRESHOLD = 1e-1 31 | _C.MODEL.POLICY.OBS_SCALER = False 32 | _C.MODEL.POLICY.FORGET_COUNT_OBS_SCALER = 5000 33 | _C.MODEL.POLICY.GAMMA = 0.99 34 | _C.MODEL.POLICY.METHOD = "None" 35 | _C.MODEL.POLICY.INITIAL_LOG_SD = 0.0 36 | _C.MODEL.POLICY.INITIAL_SD = 0.0 37 | _C.MODEL.POLICY.INITIAL_ACTION_MEAN = 0.0 38 | _C.MODEL.POLICY.INITIAL_ACTION_SD = 0.1 39 | _C.MODEL.POLICY.GRAD_WEIGHTS = 'average' 40 | _C.MODEL.POLICY.NETWORK = False 41 | _C.MODEL.NSTEPS_FOR_BACKWARD = 1 42 | _C.MODEL.FRAME_SKIP = 1 43 | _C.MODEL.TIMESTEP = 0.0 44 | _C.MODEL.RANDOM_SEED = 0 45 | 46 | # ---------------------------------------------------------------------------- # 47 | # Model Configs 48 | # ---------------------------------------------------------------------------- # 49 | _C.MUJOCO = CN() 50 | _C.MUJOCO.ENV = 'InvertedPendulumEnv' 51 | _C.MUJOCO.ASSETS_PATH = "./mujoco/assets/" 52 | _C.MUJOCO.REWARD_SCALE = 1 53 | _C.MUJOCO.CLIP_ACTIONS = True 54 | _C.MUJOCO.POOL_SIZE = CN() 55 | 56 | # ---------------------------------------------------------------------------- # 57 | # Experience Replay 58 | # ---------------------------------------------------------------------------- # 59 | _C.EXPERIENCE_REPLAY = CN() 60 | _C.EXPERIENCE_REPLAY.SIZE = 2 ** 15 61 | _C.EXPERIENCE_REPLAY.SHUFFLE = True 62 | _C.EXPERIENCE_REPLAY.ENV_INIT_STATE_NUM = 2 ** 15 * 3 / 4 63 | 64 | # ---------------------------------------------------------------------------- # 65 | # Solver Configs 66 | # ---------------------------------------------------------------------------- # 67 | _C.SOLVER = CN() 68 | 69 | _C.SOLVER.OPTIMIZER = 'adam' 70 | _C.SOLVER.BASE_LR = 0.001 71 | _C.SOLVER.STD_LR_FACTOR = 0.001 72 | _C.SOLVER.BIAS_LR_FACTOR = 2 73 | _C.SOLVER.WEIGHT_DECAY_SD = 0.0 74 | 75 | _C.SOLVER.MOMENTUM = 0.9 76 | 77 | _C.SOLVER.WEIGHT_DECAY = 0.0005 78 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0 79 | _C.SOLVER.ADAM_BETAS = (0.9, 0.999) 80 | 81 | # ---------------------------------------------------------------------------- # 82 | # Output Configs 83 | # ---------------------------------------------------------------------------- # 84 | _C.OUTPUT = CN() 85 | _C.OUTPUT.DIR = './output' 86 | _C.OUTPUT.NAME = 'timestamp' 87 | 88 | # ---------------------------------------------------------------------------- # 89 | # Log Configs 90 | # ---------------------------------------------------------------------------- # 91 | _C.LOG = CN() 92 | _C.LOG.PERIOD = 1 93 | _C.LOG.PLOT = CN() 94 | _C.LOG.PLOT.ENABLED = True 95 | _C.LOG.PLOT.DISPLAY_PORT = 8097 96 | _C.LOG.PLOT.ITER_PERIOD = 1 # effective plotting step is _C.LOG.PERIOD * LOG.PLOT.ITER_PERIOD 97 | _C.LOG.TESTING = CN() 98 | _C.LOG.TESTING.ENABLED = True 99 | _C.LOG.TESTING.ITER_PERIOD = 1 100 | _C.LOG.TESTING.RECORD_VIDEO = False 101 | _C.LOG.TESTING.COUNT_PER_ITER = 1 102 | _C.LOG.CHECKPOINT_PERIOD = 25000 103 | 104 | 105 | def get_cfg_defaults(): 106 | """Get a yacs CfgNode object with default values for my_project.""" 107 | # Return a clone so that the defaults will not be altered 108 | # This is for the "local variable" use pattern 109 | return _C.clone() 110 | -------------------------------------------------------------------------------- /model/engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/model/engine/__init__.py 
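These defaults are plain yacs nodes, so a run is configured by cloning them and merging one of the YAML files under `./configs/`. A minimal sketch of that flow (the actual entry point is `main.py`, whose argument handling is not shown in this listing):

```python
from model.config import get_cfg_defaults

cfg = get_cfg_defaults()                                 # clone of the _C tree above
cfg.merge_from_file("./configs/inverted_pendulum.yaml")  # YAML values override the defaults
cfg.freeze()                                             # lock the config before building

# Anything the YAML does not mention keeps its default, e.g. MODEL.DEVICE == "cpu"
print(cfg.MUJOCO.ENV, cfg.MODEL.POLICY.MAX_HORIZON_STEPS)
```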
-------------------------------------------------------------------------------- /model/engine/dynamics_model_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | 5 | from model.engine.tester import do_testing 6 | import utils.logger as lg 7 | from utils.visdom_plots import VisdomLogger 8 | from mujoco import build_agent 9 | from model import build_model 10 | from model.blocks.policy.dynamics import DynamicsModel 11 | import torchviz 12 | 13 | 14 | def do_training( 15 | cfg, 16 | logger, 17 | output_results_dir, 18 | output_rec_dir, 19 | output_weights_dir 20 | ): 21 | # Build the agent 22 | agent = build_agent(cfg) 23 | 24 | # Build a forward dynamics model 25 | dynamics_model = DynamicsModel(agent) 26 | 27 | # Set mode to training (aside from policy output, matters for Dropout, BatchNorm, etc.) 28 | dynamics_model.train() 29 | 30 | # Set up visdom 31 | if cfg.LOG.PLOT.ENABLED: 32 | visdom = VisdomLogger(cfg.LOG.PLOT.DISPLAY_PORT) 33 | visdom.register_keys(['total_loss', 'average_sd', 'average_action', "reinforce_loss", 34 | "objective_loss", "sd", "action_grad", "sd_grad", "actions"]) 35 | 36 | # wrap screen recorder if testing mode is on 37 | if cfg.LOG.TESTING.ENABLED: 38 | if cfg.LOG.PLOT.ENABLED: 39 | visdom.register_keys(['test_reward']) 40 | 41 | # Collect losses here 42 | output = {"epoch": [], "objective_loss": []} 43 | 44 | # Start training 45 | for epoch_idx in range(cfg.MODEL.EPOCHS): 46 | batch_loss = torch.empty(cfg.MODEL.BATCH_SIZE, cfg.MODEL.POLICY.MAX_HORIZON_STEPS, dtype=torch.float64) 47 | batch_loss.fill_(np.nan) 48 | 49 | for episode_idx in range(cfg.MODEL.BATCH_SIZE): 50 | 51 | # Generate "random walk" set of actions (separately for each dimension) 52 | action = torch.zeros(agent.action_space.shape, dtype=torch.float64) 53 | #actions = np.zeros((agent.action_space.shape[0], cfg.MODEL.POLICY.MAX_HORIZON_STEPS)) 54 | actions = [] 55 | 56 | initial_state = torch.Tensor(agent.reset()) 57 | predicted_states = [] 58 | real_states = [] 59 | corrections = [] 60 | for step_idx in range(cfg.MODEL.POLICY.MAX_HORIZON_STEPS): 61 | 62 | # Generate random actions 63 | action = action + 0.1*(2*torch.rand(agent.action_space.shape) - 1) 64 | 65 | # Clamp to [-1, 1] 66 | action.clamp_(-1, 1) 67 | 68 | # Save action 69 | actions.append(action) 70 | 71 | previous_state = torch.from_numpy(agent.unwrapped._get_obs()) 72 | 73 | # Advance the actual simulation 74 | next_state, _, _, _ = agent.step(action) 75 | next_state = torch.from_numpy(next_state) 76 | real_states.append(next_state) 77 | 78 | # Advance with learned dynamics simulation 79 | pred_next_state = dynamics_model(previous_state.float(), action.float()).double() 80 | 81 | batch_loss[episode_idx, step_idx] = torch.pow(next_state - pred_next_state, 2).mean() 82 | #if agent.is_done: 83 | # break 84 | 85 | #dot = torchviz.make_dot(pred_next_state, params=dict(dynamics_model.named_parameters())) 86 | 87 | loss = torch.sum(batch_loss) 88 | dynamics_model.optimizer.zero_grad() 89 | loss.backward() 90 | dynamics_model.optimizer.step() 91 | 92 | output["objective_loss"].append(loss.detach().numpy()) 93 | output["epoch"].append(epoch_idx) 94 | 95 | if epoch_idx % cfg.LOG.PERIOD == 0: 96 | 97 | if cfg.LOG.PLOT.ENABLED: 98 | visdom.update({"total_loss": loss.detach().numpy()}) 99 | visdom.set({'actions': torch.stack(actions).detach().numpy()}) 100 | #visdom.set({'total_loss': loss["total_loss"].transpose()}) 101 | #visdom.update({'average_grad': 
np.log(torch.mean(model.policy_net.mean._layers["linear_layer_0"].weight.grad.abs()).detach().numpy())}) 102 | 103 | logger.info("REWARD: \t\t{} (iteration {})".format(loss.detach().numpy(), epoch_idx)) 104 | 105 | if cfg.LOG.PLOT.ENABLED and epoch_idx % cfg.LOG.PLOT.ITER_PERIOD == 0: 106 | visdom.do_plotting() 107 | 108 | # if epoch_idx % cfg.LOG.CHECKPOINT_PERIOD == 0: 109 | # torch.save(model.state_dict(), 110 | # os.path.join(output_weights_dir, 'iter_{}.pth'.format(epoch_idx))) 111 | 112 | if False:#cfg.LOG.TESTING.ENABLED: 113 | if epoch_idx % cfg.LOG.TESTING.ITER_PERIOD == 0: 114 | 115 | # Record if required 116 | agent.start_recording(os.path.join(output_rec_dir, "iter_{}.mp4".format(epoch_idx))) 117 | 118 | test_rewards = [] 119 | for _ in range(cfg.LOG.TESTING.COUNT_PER_ITER): 120 | test_reward = do_testing( 121 | cfg, 122 | model, 123 | agent, 124 | # first_state=state_xr.get_item(), 125 | ) 126 | test_rewards.append(test_reward) 127 | 128 | # Set training mode on again 129 | model.train() 130 | 131 | # Close the recorder 132 | agent.stop_recording() 133 | 134 | # Save outputs into log folder 135 | lg.save_dict_into_csv(output_results_dir, "output", output) 136 | 137 | # Save model 138 | torch.save(dynamics_model.state_dict(), os.path.join(output_weights_dir, "final_weights.pt")) 139 | -------------------------------------------------------------------------------- /model/engine/landscape_plot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | 5 | from model import build_model 6 | import matplotlib.pyplot as pp 7 | 8 | 9 | def visualise2d(agent, output_dir, iter): 10 | 11 | # Get actions 12 | actions = agent.policy_net.optimizer.mean.clone().detach() 13 | 14 | # Build the model 15 | model = build_model(agent.cfg, agent) 16 | device = torch.device(agent.cfg.MODEL.DEVICE) 17 | model.to(device) 18 | 19 | # Set mode to test, we're not interested in optimizing the actions now 20 | model.eval() 21 | 22 | # Choose two random directions 23 | i_dir = torch.from_numpy(np.random.randn(actions.shape[0])).detach() 24 | j_dir = torch.from_numpy(np.random.randn(actions.shape[0])).detach() 25 | 26 | # Define range 27 | i_range = torch.from_numpy(np.linspace(-1, 1, 100)).requires_grad_() 28 | j_range = torch.from_numpy(np.linspace(-1, 1, 100)).requires_grad_() 29 | 30 | # Collect losses 31 | loss = torch.zeros((len(i_range), len(j_range))) 32 | 33 | # Collect grads 34 | #grads = torch.zeros((len(i_range), len(j_range), 2)) 35 | 36 | # Need two loops for two directions 37 | for i_idx, i in enumerate(i_range): 38 | for j_idx, j in enumerate(j_range): 39 | 40 | # Calculate new parameters 41 | new_actions = actions + i*i_dir + j*j_dir 42 | 43 | # Set new actions 44 | model.policy_net.optimizer.mean = new_actions 45 | 46 | # Loop through whole simulation 47 | state = torch.Tensor(agent.reset()) 48 | for step_idx in range(agent.cfg.MODEL.POLICY.MAX_HORIZON_STEPS): 49 | 50 | # Advance the simulation 51 | state, reward = model(state.detach()) 52 | 53 | # Collect losses 54 | loss[i_idx, j_idx] += -reward.squeeze() 55 | 56 | # Get gradients 57 | #loss[i_idx, j_idx].backward(retain_graph=True) 58 | #grads[i_idx, j_idx, 0] = i_range.grad[i_idx] 59 | #grads[i_idx, j_idx, 1] = j_range.grad[j_idx] 60 | 61 | # Do some plotting here 62 | pp.figure(figsize=(12, 12)) 63 | contours = pp.contour(i_range.detach().numpy(), j_range.detach().numpy(), loss.detach().numpy(), colors='black') 64 | pp.clabel(contours, inline=True, 
fontsize=8) 65 | pp.imshow(loss.detach().numpy(), extent=[-1, 1, -1, 1], origin="lower", cmap="RdGy", alpha=0.5) 66 | pp.colorbar() 67 | pp.savefig(os.path.join(output_dir, "contour_{}.png".format(iter))) 68 | -------------------------------------------------------------------------------- /model/engine/tester.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def do_testing( 5 | cfg, 6 | model, 7 | agent, 8 | samples=None, 9 | first_state=None 10 | ): 11 | 12 | # Let pytorch know we're evaluating a model 13 | model.eval() 14 | 15 | # We don't need gradients now 16 | with torch.no_grad(): 17 | 18 | if first_state is None: 19 | state = torch.DoubleTensor(agent.reset(update_episode_idx=False)) 20 | else: 21 | state = first_state 22 | agent.set_from_torch_state(state) 23 | reward_sum = 0. 24 | episode_iteration = 0 25 | for step_idx in range(cfg.MODEL.POLICY.MAX_HORIZON_STEPS): 26 | if cfg.LOG.TESTING.RECORD_VIDEO: 27 | agent.capture_frame() 28 | else: 29 | agent.render() 30 | #state, reward = model(state, samples[:, step_idx]) 31 | state, reward = model(state) 32 | reward_sum += reward 33 | episode_iteration += 1  # count steps taken so the average below is well-defined 34 | if agent.is_done: 35 | break 36 | return reward_sum/episode_iteration 37 | -------------------------------------------------------------------------------- /model/engine/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | 5 | from model.engine.tester import do_testing 6 | import utils.logger as lg 7 | from utils.visdom_plots import VisdomLogger 8 | from mujoco import build_agent 9 | from model import build_model 10 | 11 | 12 | def do_training( 13 | cfg, 14 | logger, 15 | output_results_dir, 16 | output_rec_dir, 17 | output_weights_dir, 18 | iter 19 | ): 20 | 21 | if cfg.MODEL.RANDOM_SEED > 0: 22 | np.random.seed(cfg.MODEL.RANDOM_SEED + iter) 23 | torch.manual_seed(cfg.MODEL.RANDOM_SEED + iter) 24 | 25 | # Build the agent 26 | agent = build_agent(cfg) 27 | 28 | # Build the model 29 | model = build_model(cfg, agent) 30 | device = torch.device(cfg.MODEL.DEVICE) 31 | model.to(device) 32 | 33 | # Set mode to training (aside from policy output, matters for Dropout, BatchNorm, etc.)
34 | model.train() 35 | 36 | # Set up visdom 37 | if cfg.LOG.PLOT.ENABLED: 38 | visdom = VisdomLogger(cfg.LOG.PLOT.DISPLAY_PORT) 39 | visdom.register_keys(['total_loss', 'average_sd', 'average_action', "reinforce_loss", 40 | "objective_loss", "sd", "action_grad", "sd_grad", "average_grad"]) 41 | for action_idx in range(model.policy_net.action_dim): 42 | visdom.register_keys(["action_" + str(action_idx)]) 43 | 44 | # wrap screen recorder if testing mode is on 45 | if cfg.LOG.TESTING.ENABLED: 46 | if cfg.LOG.PLOT.ENABLED: 47 | visdom.register_keys(['test_reward']) 48 | 49 | # Collect losses here 50 | output = {"epoch": [], "objective_loss": [], "average_sd": []} 51 | 52 | # Start training 53 | for epoch_idx in range(cfg.MODEL.EPOCHS): 54 | batch_loss = torch.empty(cfg.MODEL.BATCH_SIZE, cfg.MODEL.POLICY.MAX_HORIZON_STEPS, dtype=torch.float64) 55 | batch_loss.fill_(np.nan) 56 | 57 | for episode_idx in range(cfg.MODEL.BATCH_SIZE): 58 | 59 | initial_state = torch.DoubleTensor(agent.reset()) 60 | states = [] 61 | states.append(initial_state) 62 | #grads = np.zeros((cfg.MODEL.POLICY.MAX_HORIZON_STEPS, 120)) 63 | for step_idx in range(cfg.MODEL.POLICY.MAX_HORIZON_STEPS): 64 | state, reward = model(states[step_idx]) 65 | batch_loss[episode_idx, step_idx] = -reward 66 | #(-reward).backward(retain_graph=True) 67 | #grads[step_idx, :] = model.policy_net.optimizer.mean.grad.detach().numpy() 68 | #grads[step_idx, step_idx+1:40] = np.nan 69 | #grads[step_idx, 40+step_idx+1:80] = np.nan 70 | #grads[step_idx, 80+step_idx+1:] = np.nan 71 | #model.policy_net.optimizer.optimizer.zero_grad() 72 | states.append(state) 73 | if agent.is_done: 74 | break 75 | 76 | agent.running_sum = 0 77 | loss = model.policy_net.optimize(batch_loss) 78 | #zero = np.abs(grads) < 1e-9 79 | #grads[zero] = np.nan 80 | #medians = np.nanmedian(grads, axis=0) 81 | #model.policy_net.optimizer.mean.grad.data = torch.from_numpy(medians) 82 | #torch.nn.utils.clip_grad_value_([model.policy_net.optimizer.mean, model.policy_net.optimizer.sd], 1) 83 | #model.policy_net.optimizer.optimizer.step() 84 | #model.policy_net.optimizer.optimizer.zero_grad() 85 | #loss = {'objective_loss': torch.sum(batch_loss, dim=1).mean().detach().numpy()} 86 | 87 | output["objective_loss"].append(loss["objective_loss"]) 88 | output["epoch"].append(epoch_idx) 89 | output["average_sd"].append(np.mean(model.policy_net.get_clamped_sd())) 90 | 91 | if epoch_idx % cfg.LOG.PERIOD == 0: 92 | 93 | if cfg.LOG.PLOT.ENABLED: 94 | visdom.update(loss) 95 | #visdom.set({'total_loss': loss["total_loss"].transpose()}) 96 | 97 | clamped_sd = model.policy_net.get_clamped_sd() 98 | clamped_action = model.policy_net.get_clamped_action() 99 | 100 | #visdom.update({'average_grad': np.log(torch.mean(model.policy_net.mean._layers["linear_layer_0"].weight.grad.abs()).detach().numpy())}) 101 | 102 | if len(clamped_sd) > 0: 103 | visdom.update({'average_sd': np.mean(clamped_sd, axis=1)}) 104 | visdom.update({'average_action': np.mean(clamped_action, axis=(1, 2)).squeeze()}) 105 | 106 | for action_idx in range(model.policy_net.action_dim): 107 | visdom.set({'action_'+str(action_idx): clamped_action[action_idx, :, :]}) 108 | if clamped_sd is not None: 109 | visdom.set({'sd': clamped_sd.transpose()}) 110 | # visdom.set({'action_grad': model.policy_net.mean.grad.detach().numpy().transpose()}) 111 | 112 | logger.info("REWARD: \t\t{} (iteration {})".format(loss["objective_loss"], epoch_idx)) 113 | 114 | if cfg.LOG.PLOT.ENABLED and epoch_idx % cfg.LOG.PLOT.ITER_PERIOD == 0: 115 | visdom.do_plotting() 
116 | 117 | if epoch_idx % cfg.LOG.CHECKPOINT_PERIOD == 0: 118 | torch.save(model.state_dict(), 119 | os.path.join(output_weights_dir, 'iter_{}.pth'.format(epoch_idx))) 120 | 121 | if cfg.LOG.TESTING.ENABLED: 122 | if epoch_idx % cfg.LOG.TESTING.ITER_PERIOD == 0: 123 | 124 | # Record if required 125 | agent.start_recording(os.path.join(output_rec_dir, "iter_{}_{}.mp4".format(iter, epoch_idx))) 126 | 127 | test_rewards = [] 128 | for _ in range(cfg.LOG.TESTING.COUNT_PER_ITER): 129 | test_reward = do_testing( 130 | cfg, 131 | model, 132 | agent, 133 | # first_state=state_xr.get_item(), 134 | ) 135 | test_rewards.append(test_reward) 136 | 137 | # Set training mode on again 138 | model.train() 139 | 140 | # Close the recorder 141 | agent.stop_recording() 142 | 143 | # Save outputs into log folder 144 | lg.save_dict_into_csv(output_results_dir, "output_{}".format(iter), output) 145 | 146 | # Return actions 147 | return agent -------------------------------------------------------------------------------- /model/engine/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_state_experience_replay, build_state_experience_replay_data_loader 2 | 3 | __all__ = ["build_state_experience_replay", "build_state_experience_replay_data_loader"] 4 | -------------------------------------------------------------------------------- /model/engine/utils/build.py: -------------------------------------------------------------------------------- 1 | from .experience_replay import StateQueue, make_data_sampler, make_batch_data_sampler, batch_collator, make_data_loader 2 | 3 | 4 | def build_state_experience_replay_data_loader(cfg): 5 | state_queue = StateQueue(cfg.EXPERIENCE_REPLAY.SIZE) 6 | sampler = make_data_sampler(state_queue, cfg.EXPERIENCE_REPLAY.SHUFFLE) 7 | batch_sampler = make_batch_data_sampler(sampler, cfg.SOLVER.BATCH_SIZE) 8 | data_loader = make_data_loader(state_queue, batch_sampler, batch_collator) 9 | return data_loader 10 | 11 | 12 | def build_state_experience_replay(cfg): 13 | state_queue = StateQueue(cfg.EXPERIENCE_REPLAY.SIZE) 14 | return state_queue 15 | -------------------------------------------------------------------------------- /model/engine/utils/experience_replay.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import random 4 | 5 | 6 | class StateQueue(object): 7 | 8 | def __init__(self, size): 9 | self.size = size 10 | self.queue = [] 11 | 12 | def __getitem__(self, idx): 13 | idx = idx % len(self.queue) 14 | return self.queue[idx] 15 | 16 | def __len__(self): 17 | return len(self.queue) or self.size 18 | 19 | def add(self, state): 20 | self.queue.append(state) 21 | self.queue = self.queue[-self.size:] 22 | 23 | def add_batch(self, states): 24 | self.queue.extend([state for state in states]) 25 | self.queue = self.queue[-self.size:] 26 | 27 | def get_item(self): 28 | return random.choice(self.queue) 29 | 30 | 31 | def make_data_sampler(dataset, shuffle): 32 | if shuffle: 33 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 34 | else: 35 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 36 | return sampler 37 | 38 | 39 | def make_batch_data_sampler(sampler, batch_size): 40 | batch_sampler = torch.utils.data.sampler.BatchSampler( 41 | sampler, batch_size, drop_last=False 42 | ) 43 | return batch_sampler 44 | 45 | 46 | def batch_collator(batch): 47 | return torch.stack(batch, dim=0) 48 | 49 | 50 | def 
make_data_loader(dataset, batch_sampler, collator): 51 | data_loader = torch.utils.data.DataLoader( 52 | dataset, 53 | batch_sampler=batch_sampler, 54 | collate_fn=collator, 55 | ) 56 | return data_loader 57 | -------------------------------------------------------------------------------- /model/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .feed_forward import FeedForward 2 | 3 | __all__ = ["FeedForward"] 4 | -------------------------------------------------------------------------------- /model/layers/feed_forward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from itertools import tee 4 | from collections import OrderedDict 5 | from torch.nn.parameter import Parameter 6 | from torch.nn import init 7 | from torch.nn import functional 8 | import math 9 | 10 | def pairwise(iterable): 11 | """s -> (s0,s1), (s1,s2), (s2, s3), ...""" 12 | a, b = tee(iterable) 13 | next(b, None) 14 | return zip(a, b) 15 | 16 | 17 | class Identity(nn.Module): 18 | def __init__(self): 19 | super(Identity, self).__init__() 20 | 21 | def forward(self, x): 22 | return x 23 | 24 | 25 | class FeedForward(nn.Module): 26 | def __init__(self, input_size, middle_sizes, output_size, norm_layers=[], activation_fn=nn.LeakyReLU, activation_out=False): 27 | """ 28 | :param input_size: int input feature size 29 | :param middle_sizes: [int] list of intermediate hidden state sizes 30 | :param output_size: int output size of the network 31 | """ 32 | # TODO: use setattr for linear layers and activations, parameters are empty 33 | super(FeedForward, self).__init__() 34 | self._sizes = [input_size] + middle_sizes + [output_size] 35 | self._layers = OrderedDict() 36 | for i, (in_size, out_size) in enumerate(pairwise(self._sizes)): 37 | # Add linear layer 38 | #linear_layer = Variational(in_size, out_size, bias=True) 39 | linear_layer = nn.Linear(in_size, out_size, bias=True) 40 | self.__setattr__('linear_layer_{}'.format(str(i)), linear_layer) 41 | self._layers.update({'linear_layer_{}'.format(str(i)): linear_layer}) 42 | # Add batch normalization layer 43 | if i in norm_layers: 44 | batchnorm_layer = nn.BatchNorm1d(out_size) 45 | self.__setattr__('batchnorm_layer_{}'.format(str(i)), batchnorm_layer) 46 | self._layers.update({'batchnorm_layer_{}'.format(str(i)): batchnorm_layer}) 47 | # Add activation layer 48 | self.__setattr__('activation_layer_{}'.format(str(i)), activation_fn()) # relu for the last layer also makes sense 49 | self._layers.update({'activation_layer_{}'.format(str(i)): activation_fn()}) 50 | if not activation_out: 51 | self._layers.popitem() 52 | 53 | self.sequential = nn.Sequential(self._layers) 54 | 55 | def forward(self, x): 56 | # out = x 57 | # for i in range(len(self._sizes) - 1): 58 | # fc = self.__getattr__('linear_layer_' + str(i)) 59 | # ac = self.__getattr__('activation_layer_' + str(i)) 60 | # out = ac(fc(out)) 61 | out = self.sequential(x) 62 | return out 63 | 64 | 65 | class Variational(nn.Module): 66 | __constants__ = ['bias', 'in_features', 'out_features'] 67 | 68 | def __init__(self, in_features, out_features, bias=True): 69 | super(Variational, self).__init__() 70 | self.in_features = in_features 71 | self.out_features = out_features 72 | self.weight_mean = Parameter(torch.Tensor(out_features, in_features)) 73 | self.weight_std = Parameter(torch.Tensor(out_features, in_features)) 74 | if bias: 75 | self.bias_mean = 
Parameter(torch.Tensor(out_features)) 76 | self.bias_std = Parameter(torch.Tensor(out_features)) 77 | else: 78 | self.register_parameter('bias_mean', None) 79 | self.register_parameter('bias_std', None) 80 | self.reset_parameters() 81 | 82 | def reset_parameters(self): 83 | init.kaiming_uniform_(self.weight_mean, a=math.sqrt(5)) 84 | init.constant_(self.weight_std, 0.03) 85 | if self.bias_mean is not None: 86 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight_mean) 87 | bound = 1 / math.sqrt(fan_in) 88 | init.uniform_(self.bias_mean, -bound, bound) 89 | init.constant_(self.bias_std, 0.03) 90 | 91 | def forward(self, input): 92 | neg_weight = self.weight_std.data < 0 93 | self.weight_std.data[neg_weight] = 0.00001 94 | weight = torch.distributions.Normal(self.weight_mean, self.weight_std).rsample() 95 | neg_bias = self.bias_std.data < 0 96 | self.bias_std.data[neg_bias] = 0.00001 97 | bias = torch.distributions.Normal(self.bias_mean, self.bias_std).rsample() 98 | return functional.linear(input, weight, bias) 99 | 100 | def extra_repr(self): 101 | return 'in_features={}, out_features={}, bias={}'.format( 102 | self.in_features, self.out_features, self.bias is not None 103 | ) 104 | 105 | -------------------------------------------------------------------------------- /mujoco/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_agent 2 | 3 | __all__ = ["build_agent"] 4 | -------------------------------------------------------------------------------- /mujoco/assets/Geometry/bofoot.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/bofoot.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/calcn_r.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/calcn_r.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/capitate_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/capitate_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/clavicle_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/clavicle_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/femur_r.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/femur_r.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/fibula_r.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/fibula_r.stl 
-------------------------------------------------------------------------------- /mujoco/assets/Geometry/foot.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/foot.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/hamate_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/hamate_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/humerus_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/humerus_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/indexDistal_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/indexDistal_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/indexMetacarpal_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/indexMetacarpal_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/indexMid_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/indexMid_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/indexProx_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/indexProx_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/l_pelvis.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/l_pelvis.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/lunate_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/lunate_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/middleDistal_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/middleDistal_rFB.stl -------------------------------------------------------------------------------- 
/mujoco/assets/Geometry/middleMetacarpal_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/middleMetacarpal_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/middleMid_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/middleMid_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/middleProx_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/middleProx_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/pat.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/pat.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/patella_r.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/patella_r.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/pelvis.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/pelvis.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/pinkyDistal_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/pinkyDistal_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/pinkyMetacarpal_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/pinkyMetacarpal_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/pinkyMid_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/pinkyMid_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/pinkyProx_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/pinkyProx_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/pisiform_rFB.stl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/pisiform_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/radius_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/radius_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/ribcageFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/ribcageFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/ringDistal_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/ringDistal_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/ringMetacarpal_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/ringMetacarpal_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/ringMid_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/ringMid_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/ringProx_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/ringProx_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/sacrum.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/sacrum.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/scaphoid_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/scaphoid_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/scapula_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/scapula_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/talus.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/talus.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/talus_r.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/talus_r.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/thoracic10FB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic10FB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/thoracic11FB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic11FB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/thoracic12FB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic12FB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/thoracic1FB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic1FB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/thoracic2FB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic2FB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/thoracic3FB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic3FB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/thoracic4FB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic4FB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/thoracic5FB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic5FB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/thoracic6FB.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic6FB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/thoracic7FB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic7FB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/thoracic8FB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic8FB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/thoracic9FB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thoracic9FB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/thumbDistal_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thumbDistal_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/thumbMid_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thumbMid_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/thumbProx_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/thumbProx_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/tibia_r.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/tibia_r.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/toes_r.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/toes_r.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/trapezium_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/trapezium_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/trapezoid_rFB.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/trapezoid_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/triquetral_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/triquetral_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/Geometry/ulna_rFB.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/mujoco/assets/Geometry/ulna_rFB.stl -------------------------------------------------------------------------------- /mujoco/assets/half_cheetah.xml: -------------------------------------------------------------------------------- /mujoco/assets/hopper.xml: -------------------------------------------------------------------------------- /mujoco/assets/inverted_double_pendulum.xml: -------------------------------------------------------------------------------- /mujoco/assets/inverted_pendulum.xml: -------------------------------------------------------------------------------- /mujoco/assets/leg6dof9musc_converted.xml: -------------------------------------------------------------------------------- /mujoco/assets/swimmer.xml: -------------------------------------------------------------------------------- /mujoco/assets/walker2d.xml: -------------------------------------------------------------------------------- /mujoco/build.py: -------------------------------------------------------------------------------- 1 | from mujoco import envs 2 | from mujoco.utils.wrappers.mj_block import MjBlockWrapper 3 | from
mujoco.utils.wrappers.etc import SnapshotWrapper, IndexWrapper, ViewerWrapper 4 | 5 | 6 | def build_agent(cfg): 7 | agent_factory = getattr(envs, cfg.MUJOCO.ENV) 8 | agent = agent_factory(cfg) 9 | # if cfg.MUJOCO.REWARD_SCALE != 1.0: 10 | # from .utils import RewardScaleWrapper 11 | # agent = RewardScaleWrapper(agent, cfg.MUJOCO.REWARD_SCALE) 12 | # if cfg.MUJOCO.CLIP_ACTIONS: 13 | # from .utils import ClipActionsWrapper 14 | # agent = ClipActionsWrapper(agent) 15 | # if cfg.MODEL.POLICY.ARCH == 'TrajOpt': 16 | # from .utils import FixedStateWrapper 17 | # agent = FixedStateWrapper(agent) 18 | 19 | # Make configs accessible through agent 20 | #agent.cfg = cfg 21 | 22 | # Record video 23 | agent = ViewerWrapper(agent) 24 | 25 | # Keep track of step, episode, and batch indices 26 | agent = IndexWrapper(agent, cfg.MODEL.BATCH_SIZE) 27 | 28 | # Grab and set snapshots of data 29 | agent = SnapshotWrapper(agent) 30 | 31 | # This should probably be last so we get all wrappers 32 | agent = MjBlockWrapper(agent) 33 | 34 | # Maybe we should set opt tolerance to zero so mujoco solvers shouldn't stop early? 35 | agent.model.opt.tolerance = 0 36 | 37 | return agent 38 | -------------------------------------------------------------------------------- /mujoco/envs/HandModelTSLAdjusted.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from gym import utils 4 | from gym.envs.mujoco import mujoco_env 5 | import xmltodict 6 | from how_bots_type import MocapWithKeyStateData 7 | import Utils 8 | from pyquaternion import Quaternion 9 | from copy import deepcopy 10 | 11 | 12 | class HandModelTSLAdjustedEnv(mujoco_env.MujocoEnv, utils.EzPickle): 13 | """ 14 | """ 15 | def __init__(self, cfg): 16 | 17 | self.initialised = False 18 | self.cfg = cfg 19 | 20 | # Get model path 21 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/") 22 | mujoco_xml_file = os.path.join(mujoco_assets_dir, "HandModelTSLAdjusted_converted.xml") 23 | 24 | # Add sites to fingertips by reading the model xml file, adding the sites, and saving the model file 25 | with open(mujoco_xml_file) as f: 26 | text = f.read() 27 | p = xmltodict.parse(text) 28 | 29 | # We might want to add a "keyboard" object into the simulation; it should be in front of torso 30 | torso = p["mujoco"]["worldbody"]["body"] 31 | pos = np.array(torso["@pos"].split(), dtype=float) 32 | quat = np.array(torso["@quat"].split(), dtype=float) 33 | T_torso = Utils.create_transformation_matrix(pos=pos, quat=quat) 34 | R_keyboard = np.matmul(Utils.create_rotation_matrix(axis=[0, 0, 1], deg=180), Utils.create_rotation_matrix(axis=[0, 1, 0], deg=-90)) 35 | T_keyboard = Utils.create_transformation_matrix(pos=np.array([0.6, -0.1, -0.25]), R=R_keyboard[:3, :3]) 36 | T = np.matmul(T_torso, T_keyboard) 37 | 38 | # Now create the keyboard object 39 | #keyboard = {"@name": "keyboard", 40 | # "@pos": Utils.array_to_string(T[:3, 3]), 41 | # "@quat": Utils.array_to_string(Quaternion(matrix=T).elements), 42 | # "geom": {"@type": "box", "@pos": "0.25 0.005 -0.15", "@size": "0.25 0.005 0.15"}} 43 | 44 | keyboard = {"@name": "keyboard", 45 | "@pos": Utils.array_to_string(pos + np.array([0.5, 0, 0])), 46 | "@quat": torso["@quat"], 47 | "geom": {"@type": "sphere", "@size": "0.02 0 0", "@rgba": "0.9 0.1 0.1 1"}} 48 | 49 | 50 | # Add it to the xml file 51 | if not isinstance(p["mujoco"]["worldbody"]["body"], list): 52 | p["mujoco"]["worldbody"]["body"] = [p["mujoco"]["worldbody"]["body"]] 53 | 
p["mujoco"]["worldbody"]["body"].append(keyboard) 54 | 55 | # Get a reference to trapezium_pre_r in the kinematic tree 56 | trapezium = torso["body"]["body"][1]["body"]["body"]["body"]["body"]["body"]["body"][1]["body"]["body"][2]["body"]["body"] 57 | 58 | # Add fingertip site to thumb 59 | self.add_fingertip_site(trapezium[0]["body"]["body"]["body"]["body"]["body"]["site"], 60 | "fingertip_thumb", "-0.0015 -0.022 0.002") 61 | 62 | # Add fingertip site to index finger 63 | self.add_fingertip_site(trapezium[2]["body"]["body"]["body"]["body"]["site"], 64 | "fingertip_index", "-0.0005 -0.0185 0.001") 65 | 66 | # Add fingertip site to middle finger 67 | self.add_fingertip_site(trapezium[3]["body"]["body"]["body"]["site"], 68 | "fingertip_middle", "0.0005 -0.01875 0.0") 69 | 70 | # Add fingertip site to ring finger 71 | self.add_fingertip_site(trapezium[4]["body"]["body"]["body"]["body"]["site"], 72 | "fingertip_ring", "-0.001 -0.01925 0.00075") 73 | 74 | # Add fingertip site to little finger 75 | self.add_fingertip_site(trapezium[5]["body"]["body"]["body"]["body"]["site"], 76 | "fingertip_little", "-0.0005 -0.0175 0.0") 77 | 78 | # Save the modified model xml file 79 | mujoco_xml_file_modified = os.path.join(mujoco_assets_dir, "HandModelTSLAdjusted_converted_fingertips.xml") 80 | with open(mujoco_xml_file_modified, 'w') as f: 81 | f.write(xmltodict.unparse(p, pretty=True, indent=" ")) 82 | 83 | # Load model 84 | self.frame_skip = self.cfg.MODEL.FRAME_SKIP 85 | mujoco_env.MujocoEnv.__init__(self, mujoco_xml_file_modified, self.frame_skip) 86 | 87 | # Set timestep 88 | if cfg.MODEL.TIMESTEP > 0: 89 | self.model.opt.timestep = cfg.MODEL.TIMESTEP 90 | 91 | # The default initial state isn't stable; run a simulation for a couple of seconds to get a stable initial state 92 | duration = 4 93 | for _ in range(int(duration/self.model.opt.timestep)): 94 | self.sim.step() 95 | 96 | # Use these joint values for initial state 97 | self.init_qpos = deepcopy(self.data.qpos) 98 | 99 | # Reset model 100 | self.reset_model() 101 | 102 | # Load a dataset from how-bots-type 103 | if False: 104 | data_filename = "/home/aleksi/Workspace/how-bots-type/data/24fps/000895_sentences_mocap_with_key_states.msgpack" 105 | mocap_data = MocapWithKeyStateData.load(data_filename) 106 | mocap_ds = mocap_data.as_dataset() 107 | 108 | # Get positions of each fingertip, maybe start with just 60 seconds 109 | fps = 24 110 | length = 10 111 | self.fingers = ["thumb", "index", "middle", "ring", "little"] 112 | fingers_coords = mocap_ds.fingers.isel(time=range(0, length*fps)).sel(hand='right', joint='tip', finger=self.fingers).values 113 | 114 | # We need to transform the mocap coordinates to simulation coordinates 115 | for time_idx in range(fingers_coords.shape[0]): 116 | for finger_idx in range(fingers_coords.shape[1]): 117 | coords = np.concatenate((fingers_coords[time_idx, finger_idx, :], np.array([1]))) 118 | coords[1] *= -1 119 | coords = np.matmul(T, coords) 120 | fingers_coords[time_idx, finger_idx, :] = coords[:3] 121 | 122 | # Get timestamps in seconds 123 | time = mocap_ds.isel(time=range(0, length*fps)).time.values / np.timedelta64(1, 's') 124 | 125 | # We want to interpolate finger positions every (self.frame_skip * self.model.opt.timestep) second 126 | time_interp = np.arange(0, length, self.frame_skip * self.model.opt.timestep) 127 | self.fingers_targets = {x: np.empty(time_interp.shape + fingers_coords.shape[2:]) for x in self.fingers} 128 | #self.fingers_targets = np.empty(time_interp.shape + fingers_coords.shape[1:]) 
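# The loop below resamples the mocap fingertip trajectories (recorded at 24 fps) onto the
# control timestep grid time_interp (points frame_skip * opt.timestep apart), one finger and
# one coordinate at a time, because np.interp only handles 1-D sequences.
# Illustrative sketch with made-up numbers (not taken from the dataset):
#   t_src = np.array([0.0, 1/24, 2/24])      # mocap timestamps in seconds
#   x_src = np.array([0.10, 0.12, 0.16])     # one coordinate of one fingertip
#   t_new = np.arange(0.0, 2/24, 0.04)       # control timestep grid -> [0.0, 0.04, 0.08]
#   np.interp(t_new, t_src, x_src)           # -> [0.10, 0.1192, 0.1568] (piecewise linear)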
129 | for finger_idx in range(fingers_coords.shape[1]): 130 | for coord_idx in range(fingers_coords.shape[2]): 131 | #self.fingers_targets[:, finger_idx, coord_idx] = np.interp(time_interp, time, self.fingers_coords[:, self.finger_idx, coord_idx]) 132 | self.fingers_targets[self.fingers[finger_idx]][:, coord_idx] = \ 133 | np.interp(time_interp, time, fingers_coords[:, finger_idx, coord_idx]) 134 | 135 | self.initialised = True 136 | 137 | def add_fingertip_site(self, sites, name, pos): 138 | # Check if this fingertip is already in sites, and if so, overwrite the position 139 | for site in sites: 140 | if site["@name"] == name: 141 | site["@pos"] = pos 142 | return 143 | 144 | # Create the site 145 | sites.append({"@name": name, "@pos": pos}) 146 | 147 | def step(self, a): 148 | 149 | if not self.initialised: 150 | return self._get_obs(), 0, False, {} 151 | 152 | # Step forward 153 | self.do_simulation(a, self.frame_skip) 154 | 155 | # Cost is difference between target and simulated fingertip positions 156 | #err = {x: 0 for x in self.fingers} 157 | #for finger in self.fingers: 158 | # err[finger] = np.linalg.norm(self.fingers_targets[finger][self._step_idx, :] - 159 | # self.data.site_xpos[self.model._site_name2id["fingertip_"+finger], :], ord=2) 160 | 161 | err = {} 162 | cost = np.linalg.norm(self.data.body_xpos[self.model._body_name2id["keyboard"], :] - 163 | self.data.site_xpos[self.model._site_name2id["fingertip_index"], :], ord=2) 164 | 165 | # Product of errors? Or sum? Squared errors? 166 | #cost = np.prod(list(err.values())) 167 | 168 | return self._get_obs(), -cost, False, err 169 | 170 | def _get_obs(self): 171 | """DIFFERENT FROM ORIGINAL GYM""" 172 | qpos = self.sim.data.qpos 173 | qvel = self.sim.data.qvel 174 | return np.concatenate([qpos.flat, qvel.flat]) 175 | 176 | def reset_model(self): 177 | # Reset to initial pose? 178 | #self.set_state( 179 | # self.init_qpos + self.np_random.uniform(low=-1, high=1, size=self.model.nq), 180 | # self.init_qvel + self.np_random.uniform(low=-1, high=1, size=self.model.nv) 181 | #) 182 | self.sim.reset() 183 | self.set_state(self.init_qpos, self.init_qvel) 184 | return self._get_obs() 185 | 186 | @staticmethod 187 | def is_done(state): 188 | done = False 189 | return done 190 | -------------------------------------------------------------------------------- /mujoco/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .hopper import HopperEnv 2 | from .half_cheetah import HalfCheetahEnv 3 | from .swimmer import SwimmerEnv 4 | from .inverted_pendulum import InvertedPendulumEnv 5 | from .inverted_double_pendulum import InvertedDoublePendulumEnv 6 | from .leg import LegEnv 7 | from .HandModelTSLAdjusted import HandModelTSLAdjustedEnv 8 | from .walker2d import Walker2dEnv 9 | 10 | __all__ = [ 11 | "HopperEnv", 12 | "HalfCheetahEnv", 13 | "SwimmerEnv", 14 | "InvertedPendulumEnv", 15 | "InvertedDoublePendulumEnv", 16 | "LegEnv", 17 | "HandModelTSLAdjustedEnv", 18 | "Walker2dEnv" 19 | ] 20 | -------------------------------------------------------------------------------- /mujoco/envs/half_cheetah.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from gym import utils 5 | from gym.envs.mujoco import mujoco_env 6 | 7 | 8 | class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle): 9 | """ 10 | COPIED FROM GYM. W/ SLIGHT MODIFICATIONS: 11 | * READING FROM OWN .XML. 12 | * FULL STATE OBSERVATIONS, I.E. 
QPOS CONCAT'D WITH QVEL. 13 | * is_done METHOD SHOULD BE IMPLEMENTED 14 | """ 15 | def __init__(self, cfg): 16 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/") 17 | self.cfg = cfg 18 | self.frame_skip = self.cfg.MODEL.FRAME_SKIP 19 | self.initialised = False 20 | mujoco_env.MujocoEnv.__init__(self, os.path.join(mujoco_assets_dir, "half_cheetah.xml"), self.frame_skip) 21 | utils.EzPickle.__init__(self) 22 | self.initialised = True 23 | 24 | def step(self, action): 25 | xposbefore = self.sim.data.qpos[0] 26 | self.do_simulation(action, self.frame_skip) 27 | xposafter = self.sim.data.qpos[0] 28 | ob = self._get_obs() 29 | reward_ctrl = - 0.01 * np.square(action).sum() 30 | reward_run = (xposafter - xposbefore) / self.dt 31 | reward = reward_ctrl + reward_run 32 | done = False 33 | return ob, reward, done, dict(reward_run=reward_run, reward_ctrl=reward_ctrl) 34 | 35 | def _get_obs(self): 36 | """DIFFERENT FROM ORIGINAL GYM""" 37 | return np.concatenate([ 38 | self.sim.data.qpos.flat, 39 | self.sim.data.qvel.flat, 40 | ]) 41 | 42 | def reset_model(self): 43 | self.sim.reset() 44 | if self.cfg.MODEL.POLICY.NETWORK: 45 | self.set_state( 46 | self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq), 47 | self.init_qvel + self.np_random.randn(self.model.nv) * .1 48 | ) 49 | else: 50 | self.set_state(self.init_qpos, self.init_qvel) 51 | return self._get_obs() 52 | 53 | def viewer_setup(self): 54 | self.viewer.cam.distance = self.model.stat.extent * 0.5 55 | 56 | @staticmethod 57 | def is_done(state): 58 | done = False 59 | return done 60 | 61 | def tensor_reward(self, state, action, next_state): 62 | """DIFFERENT FROM ORIGINAL GYM""" 63 | xposbefore = state[0] 64 | xposafter = next_state[0] 65 | reward_ctrl = - 0.01 * torch.sum(torch.mul(action, action)) 66 | reward_run = (xposafter - xposbefore) / self.dt 67 | reward = reward_ctrl + reward_run 68 | return reward.view([1, ]) 69 | -------------------------------------------------------------------------------- /mujoco/envs/hopper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from gym import utils 5 | from gym.envs.mujoco import mujoco_env 6 | import math 7 | 8 | 9 | class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle): 10 | """ 11 | COPIED FROM GYM. W/ SLIGHT MODIFICATIONS: 12 | * READING FROM OWN .XML. 13 | * FULL STATE OBSERVATIONS, I.E. QPOS CONCAT'D WITH QVEL. 
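In other words, nothing is dropped or clipped: _get_obs() returns np.concatenate([qpos, qvel]), so the observation has length model.nq + model.nv. A minimal sketch of that layout, assuming `env` is an already-constructed HopperEnv:

    obs = env._get_obs()
    assert obs.shape == (env.model.nq + env.model.nv,)
    qpos, qvel = np.split(obs, [env.model.nq])   # recover the two halves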
14 | * is_done METHOD SHOULD BE IMPLEMENTED 15 | * torch implementation of reward function 16 | """ 17 | 18 | def __init__(self, cfg): 19 | self.cfg = cfg 20 | self.frame_skip = self.cfg.MODEL.FRAME_SKIP 21 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/") 22 | self.initialised = False 23 | mujoco_env.MujocoEnv.__init__(self, os.path.join(mujoco_assets_dir, "hopper.xml"), self.frame_skip) 24 | utils.EzPickle.__init__(self) 25 | self.initialised = True 26 | 27 | def sigmoid(self, x, mi, mx): return mi + (mx - mi) * (lambda t: (1 + 200 ** (-t + 0.5)) ** (-1))((x - mi) / (mx - mi)) 28 | 29 | def step(self, a): 30 | posbefore = self.sim.data.qpos[0] 31 | self.do_simulation(a, self.frame_skip) 32 | posafter, height, ang = self.sim.data.qpos[0:3] 33 | alive_bonus = 1.0 34 | reward = (posafter - posbefore) / self.dt 35 | ang_abs = abs(ang) % (2*math.pi) 36 | if ang_abs > math.pi: 37 | ang_abs = 2*math.pi - ang_abs 38 | coeff1 = self.sigmoid(height/1.25, 0, 1) 39 | coeff2 = self.sigmoid((math.pi - ang_abs)/math.pi, 0, 1) 40 | reward += coeff1 * alive_bonus + coeff2 * alive_bonus 41 | reward -= 1e-3 * np.square(a).sum() 42 | s = self.state_vector() 43 | done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and 44 | (height > .7) and (abs(ang) < .2)) and self.initialised 45 | ob = self._get_obs() 46 | return ob, reward, False, {} 47 | 48 | def _get_obs(self): 49 | """DIFFERENT FROM ORIGINAL GYM""" 50 | return np.concatenate([ 51 | self.sim.data.qpos.flat, # this part different from gym. expose the whole thing. 52 | self.sim.data.qvel.flat, # this part different from gym. clip nothing. 53 | ]) 54 | 55 | def reset_model(self): 56 | self.sim.reset() 57 | if self.cfg.MODEL.POLICY.NETWORK: 58 | qpos = self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq) 59 | qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) 60 | self.set_state(qpos, qvel) 61 | else: 62 | self.set_state(self.init_qpos, self.init_qvel) 63 | return self._get_obs() 64 | 65 | def viewer_setup(self): 66 | self.viewer.cam.trackbodyid = 2 67 | self.viewer.cam.distance = self.model.stat.extent * 0.75 68 | self.viewer.cam.lookat[2] = 1.15 69 | self.viewer.cam.elevation = -20 70 | 71 | @staticmethod 72 | def is_done(state): 73 | height, ang = state[1:3] 74 | done = not (np.isfinite(state).all() and (np.abs(state[2:]) < 100).all() and 75 | (height > .7) and (abs(ang) < .2)) 76 | return done 77 | 78 | def tensor_reward(self, state, action, next_state): 79 | """DIFFERENT FROM ORIGINAL GYM""" 80 | posbefore = state[0] 81 | posafter, height, ang = next_state[0:3] 82 | alive_bonus = 1.0 83 | reward = (posafter - posbefore) / self.dt 84 | reward += alive_bonus 85 | reward -= 1e-3 * torch.sum(torch.mul(action, action)) 86 | return reward.view([1, ]) 87 | -------------------------------------------------------------------------------- /mujoco/envs/inverted_double_pendulum.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from gym import utils 5 | from gym.envs.mujoco import mujoco_env 6 | 7 | 8 | class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle): 9 | """ 10 | COPIED FROM GYM. W/ SLIGHT MODIFICATIONS: 11 | * READING FROM OWN .XML. 12 | * FULL STATE OBSERVATIONS, I.E. QPOS CONCAT'D WITH QVEL. 
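The is_done and tensor_reward methods below reconstruct the pole-tip height from the joint angles instead of reading sim.data.site_xpos the way step() does, so the torch reward stays differentiable. A rough sketch of the reconstruction as written in tensor_reward, using the 0.6 arm length from the code and the two hinge angles theta_1, theta_2:

    y_tip = 0.6 * cos(theta_1) + 0.6 * cos(theta_1 + theta_2)

A tip height of y <= 1 is the failure test (cf. is_done and the `done` flag computed inside step).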
13 | * is_done METHOD SHOULD BE IMPLEMENTED 14 | * torch implementation of reward function 15 | """ 16 | 17 | def __init__(self, cfg): 18 | self.cfg = cfg 19 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/") 20 | self.initialised = False 21 | mujoco_env.MujocoEnv.__init__(self, os.path.join(mujoco_assets_dir, "inverted_double_pendulum.xml"), self.cfg.MODEL.FRAME_SKIP) 22 | utils.EzPickle.__init__(self) 23 | self.initialised = True 24 | 25 | def step(self, action): 26 | self.do_simulation(action, self.frame_skip) 27 | ob = self._get_obs() 28 | x, _, y = self.sim.data.site_xpos[0] 29 | dist_penalty = 0.01 * x ** 2 + (y - 2) ** 2 30 | v1, v2 = self.sim.data.qvel[1:3] 31 | vel_penalty = 1e-3 * v1 ** 2 + 5e-3 * v2 ** 2 32 | alive_bonus = 10 33 | r = - dist_penalty - vel_penalty 34 | done = bool(y <= 1) 35 | return ob, r, False, {} 36 | 37 | def _get_obs(self): 38 | """DIFFERENT FROM ORIGINAL GYM""" 39 | return np.concatenate([ 40 | self.sim.data.qpos.flat, # this part different from gym. expose the whole thing. 41 | self.sim.data.qvel.flat, # this part different from gym. clip nothing. 42 | ]) 43 | 44 | def reset_model(self): 45 | self.sim.reset() 46 | self.set_state( 47 | self.init_qpos,# + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq), 48 | self.init_qvel# + self.np_random.randn(self.model.nv) * .1 49 | ) 50 | return self._get_obs() 51 | 52 | def viewer_setup(self): 53 | v = self.viewer 54 | v.cam.trackbodyid = 0 55 | v.cam.distance = self.model.stat.extent * 0.5 56 | v.cam.lookat[2] = 0.12250000000000005 # v.model.stat.center[2] 57 | 58 | @staticmethod 59 | def is_done(state): 60 | arm_length = 0.6 61 | theta_1 = state[1] 62 | theta_2 = state[2] 63 | y = arm_length * np.cos(theta_1) + \ 64 | arm_length * np.sin(theta_1 + theta_2) 65 | done = bool(y <= 1) 66 | return done 67 | 68 | def tensor_reward(self, state, action, next_state): 69 | """DIFFERENT FROM ORIGINAL GYM""" 70 | arm_length = 0.6 71 | theta_1 = next_state[1] 72 | theta_2 = next_state[2] 73 | y = arm_length * torch.cos(theta_1) + \ 74 | arm_length * torch.cos(theta_1 + theta_2) 75 | x = arm_length * torch.cos(theta_1) + \ 76 | arm_length * torch.cos(theta_1 + theta_2) + \ 77 | next_state[0] 78 | dist_penalty = 0.01 * x ** 2 + (y - 2) ** 2 79 | v1, v2 = next_state[4:6] 80 | vel_penalty = 1e-3 * v1 ** 2 + 5e-3 * v2 ** 2 81 | alive_bonus = 10 82 | reward = alive_bonus - dist_penalty - vel_penalty 83 | return reward.view([1, ]) 84 | -------------------------------------------------------------------------------- /mujoco/envs/inverted_pendulum.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from gym import utils 5 | from gym.envs.mujoco import mujoco_env 6 | 7 | 8 | class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle): 9 | """ 10 | COPIED FROM GYM. W/ SLIGHT MODIFICATIONS: 11 | * READING FROM OWN .XML. 12 | * FULL STATE OBSERVATIONS, I.E. QPOS CONCAT'D WITH QVEL. 
13 | * is_done METHOD SHOULD BE IMPLEMENTED 14 | * torch implementation of reward function 15 | """ 16 | 17 | def __init__(self, cfg): 18 | self.cfg = cfg 19 | self.frame_skip = self.cfg.MODEL.FRAME_SKIP 20 | utils.EzPickle.__init__(self) 21 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/") 22 | mujoco_env.MujocoEnv.__init__(self, os.path.join(mujoco_assets_dir, "inverted_pendulum.xml"), self.frame_skip) 23 | 24 | def step(self, a): 25 | """DIFFERENT FROM ORIGINAL GYM""" 26 | arm_length = 0.6 27 | self.do_simulation(a, self.frame_skip) 28 | ob = self._get_obs() 29 | theta = ob[1] 30 | y = arm_length * np.cos(theta) 31 | x = arm_length * np.cos(theta) 32 | dist_penalty = 0.01 * x ** 2 + (y - 1) ** 2 33 | #v = ob[3] 34 | #vel_penalty = 1e-3 * v ** 2 35 | reward = -dist_penalty - 0.001*(a**2) 36 | notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= .2) 37 | done = not notdone 38 | return ob, reward, False, {} 39 | 40 | def reset_model(self): 41 | #qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01) 42 | #qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01) 43 | #self.set_state(qpos, qvel) 44 | self.set_state(np.array([0.0, 0.0]), np.array([0.0, 0.0])) 45 | return self._get_obs() 46 | 47 | def _get_obs(self): 48 | return np.concatenate([self.sim.data.qpos, self.sim.data.qvel]).ravel() 49 | 50 | def viewer_setup(self): 51 | v = self.viewer 52 | v.cam.trackbodyid = 0 53 | v.cam.distance = self.model.stat.extent 54 | 55 | @staticmethod 56 | def is_done(state): 57 | done = False 58 | return done 59 | 60 | def tensor_reward(self, state, action, next_state): 61 | """DIFFERENT FROM ORIGINAL GYM""" 62 | arm_length = 0.6 63 | theta = next_state[1] 64 | y = arm_length * torch.cos(theta) 65 | x = arm_length * torch.cos(theta) 66 | dist_penalty = 0.01 * x ** 2 + (y - 1) ** 2 67 | v = next_state[3] 68 | vel_penalty = 1e-3 * v ** 2 69 | reward = -dist_penalty - vel_penalty 70 | return reward.view([1, ]) 71 | -------------------------------------------------------------------------------- /mujoco/envs/leg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from gym import utils, spaces 5 | from gym.envs.mujoco import mujoco_env 6 | from model_accuracy_tests.find_parameter_mappings import get_control, get_kinematics, update_equality_constraint, reindex_dataframe 7 | 8 | 9 | class LegEnv(mujoco_env.MujocoEnv, utils.EzPickle): 10 | """ 11 | """ 12 | def __init__(self): 13 | 14 | self.initialised = False 15 | 16 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/") 17 | mujoco_env.MujocoEnv.__init__(self, os.path.join(mujoco_assets_dir, "leg6dof9musc_converted.xml"), 1) 18 | utils.EzPickle.__init__(self) 19 | 20 | # Get muscle control values 21 | control_file = "/home/aleksi/Workspace/O2MConverter/models/opensim/Leg6Dof9Musc/CMC/leg6dof9musc_controls.sto" 22 | control_values, control_header = get_control(self.model, control_file) 23 | 24 | # Get joint kinematics values 25 | qpos_file = "/home/aleksi/Workspace/O2MConverter/models/opensim/Leg6Dof9Musc/CMC/leg6dof9musc_Kinematics_q.sto" 26 | qpos, qpos_header = get_kinematics(self.model, qpos_file) 27 | qvel_file = "/home/aleksi/Workspace/O2MConverter/models/opensim/Leg6Dof9Musc/CMC/leg6dof9musc_Kinematics_u.sto" 28 | qvel, qvel_header = get_kinematics(self.model, qvel_file) 29 | 30 | # Make sure both muscle control and joint kinematics have the same timesteps 31 | if not 
control_values.index.equals(qpos.index) or not control_values.index.equals(qvel.index): 32 | print("Timesteps do not match between muscle control and joint kinematics") 33 | return 34 | 35 | # Timestep might not be constant in the OpenSim reference movement (weird). We can't change timestep dynamically in 36 | # mujoco, at least the viewer does weird things and it could be reflecting underlying issues. Thus, we should 37 | # interpolate the muscle control and joint kinematics with model.opt.timestep 38 | # model.opt.timestep /= 2.65 39 | self.control_values = reindex_dataframe(control_values, self.model.opt.timestep) 40 | self.target_qpos = reindex_dataframe(qpos, self.model.opt.timestep) 41 | self.target_qvel = reindex_dataframe(qvel, self.model.opt.timestep) 42 | 43 | # Get initial state values and reset model 44 | for joint_idx, joint_name in self.model._joint_id2name.items(): 45 | self.init_qpos[joint_idx] = self.target_qpos[joint_name][0] 46 | self.init_qvel[joint_idx] = self.target_qvel[joint_name][0] 47 | self.reset_model() 48 | 49 | # We need to update equality constraints here 50 | update_equality_constraint(self.model, self.target_qpos) 51 | 52 | self.initialised = True 53 | 54 | def step(self, a): 55 | 56 | if not self.initialised: 57 | return self._get_obs(), 0, False, {} 58 | 59 | # Step forward 60 | self.do_simulation(a, self.frame_skip) 61 | 62 | # Cost is difference between target and simulated qpos and qvel 63 | e_qpos = 0 64 | e_qvel = 0 65 | for joint_idx, joint_name in self.model._joint_id2name.items(): 66 | e_qpos += np.abs(self.target_qpos[joint_name].iloc[self._step_idx.value] - self.data.qpos[joint_idx]) 67 | e_qvel += np.abs(self.target_qvel[joint_name].iloc[self._step_idx.value] - self.data.qvel[joint_idx]) 68 | 69 | cost = np.square(e_qpos) + 0.001*np.square(e_qvel) + 0.001*np.matmul(a,a) 70 | 71 | return self._get_obs(), -cost, False, {"e_qpos": e_qpos, "e_qvel": e_qvel} 72 | 73 | def _get_obs(self): 74 | """DIFFERENT FROM ORIGINAL GYM""" 75 | qpos = self.sim.data.qpos 76 | qvel = self.sim.data.qvel 77 | return np.concatenate([qpos.flat, qvel.flat]) 78 | 79 | def reset_model(self): 80 | # Reset to initial pose? 81 | #self.set_state( 82 | # self.init_qpos + self.np_random.uniform(low=-1, high=1, size=self.model.nq), 83 | # self.init_qvel + self.np_random.uniform(low=-1, high=1, size=self.model.nv) 84 | #) 85 | self.set_state(self.init_qpos, self.init_qvel) 86 | return self._get_obs() 87 | 88 | @staticmethod 89 | def is_done(state): 90 | done = False 91 | return done 92 | -------------------------------------------------------------------------------- /mujoco/envs/swimmer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from gym import utils 5 | from gym.envs.mujoco import mujoco_env 6 | 7 | 8 | class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): 9 | """ 10 | COPIED FROM GYM. W/ SLIGHT MODIFICATIONS: 11 | * READING FROM OWN .XML. 12 | * FULL STATE OBSERVATIONS, I.E. QPOS CONCAT'D WITH QVEL. 
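As in the other environments here, the reward is implemented twice: in numpy inside step() for the forward rollout, and again in torch as tensor_reward() so the backward-pass code (see mujoco/utils/backward.py) can evaluate and differentiate it on torch tensors. A minimal sketch of calling the torch version, with hypothetical zero-valued tensors and an already-constructed env:

    s      = torch.zeros(env.model.nq + env.model.nv, dtype=torch.float64, requires_grad=True)
    a      = torch.zeros(env.model.nu, dtype=torch.float64, requires_grad=True)
    s_next = torch.zeros(env.model.nq + env.model.nv, dtype=torch.float64)
    r = env.tensor_reward(s, a, s_next)   # shape [1]; differentiable w.r.t. s and a
    r.backward()                          # populates s.grad and a.grad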
13 | * is_done METHOD SHOULD BE IMPLEMENTED 14 | """ 15 | def __init__(self, cfg): 16 | self.cfg = cfg 17 | self.frame_skip = self.cfg.MODEL.FRAME_SKIP 18 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/") 19 | mujoco_env.MujocoEnv.__init__(self, os.path.join(mujoco_assets_dir, "swimmer.xml"), self.frame_skip) 20 | utils.EzPickle.__init__(self) 21 | 22 | def step(self, a): 23 | ctrl_cost_coeff = 0.0001 24 | xposbefore = self.sim.data.qpos[0] 25 | self.do_simulation(a, self.frame_skip) 26 | xposafter = self.sim.data.qpos[0] 27 | reward_fwd = (xposafter - xposbefore) / self.dt 28 | reward_ctrl = - ctrl_cost_coeff * np.square(a).sum() 29 | reward = reward_fwd + reward_ctrl 30 | 31 | ob = self._get_obs() 32 | return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl) 33 | 34 | def _get_obs(self): 35 | """DIFFERENT FROM ORIGINAL GYM""" 36 | qpos = self.sim.data.qpos 37 | qvel = self.sim.data.qvel 38 | return np.concatenate([qpos.flat, qvel.flat]) 39 | 40 | def reset_model(self): 41 | self.sim.reset() 42 | if self.cfg.MODEL.POLICY.NETWORK: 43 | self.set_state( 44 | self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq), 45 | self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv) 46 | ) 47 | else: 48 | self.set_state(self.init_qpos, self.init_qvel) 49 | return self._get_obs() 50 | 51 | @staticmethod 52 | def is_done(state): 53 | done = False 54 | return done 55 | 56 | def tensor_reward(self, state, action, next_state): 57 | """DIFFERENT FROM ORIGINAL GYM""" 58 | ctrl_cost_coeff = 0.0001 59 | xposbefore = state[0] 60 | xposafter = next_state[0] 61 | reward_fwd = (xposafter - xposbefore) / self.dt 62 | reward_ctrl = - ctrl_cost_coeff * torch.sum(torch.mul(action, action)) 63 | reward = reward_fwd + reward_ctrl 64 | return reward.view([1, ]) 65 | -------------------------------------------------------------------------------- /mujoco/envs/walker2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from gym.envs.mujoco import mujoco_env 4 | import os 5 | import math 6 | 7 | 8 | class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle): 9 | """ 10 | COPIED FROM GYM. W/ SLIGHT MODIFICATIONS 11 | """ 12 | 13 | def __init__(self, cfg): 14 | self.cfg = cfg 15 | self.frame_skip = self.cfg.MODEL.FRAME_SKIP 16 | mujoco_assets_dir = os.path.abspath("./mujoco/assets/") 17 | self.initialised = False 18 | mujoco_env.MujocoEnv.__init__(self, os.path.join(mujoco_assets_dir, "walker2d.xml"), self.frame_skip) 19 | utils.EzPickle.__init__(self) 20 | self.initialised = True 21 | 22 | def step(self, a): 23 | posbefore = self.sim.data.qpos[0] 24 | self.do_simulation(a, self.frame_skip) 25 | posafter, height, ang = self.sim.data.qpos[0:3] 26 | alive_bonus = 1.0 27 | reward = ((posafter - posbefore) / self.dt) 28 | coeff = min(max(height/0.8, 0), 1)*0.5 + max(((math.pi - abs(ang))/math.pi), 1)*0.5 29 | reward += coeff*alive_bonus 30 | reward -= 1e-3 * np.square(a).sum() 31 | done = not (height > 0.8 and height < 2.0 and 32 | ang > -1.0 and ang < 1.0) and self.initialised 33 | ob = self._get_obs() 34 | #if not done: 35 | # reward += alive_bonus 36 | return ob, reward, False, {} 37 | 38 | def _get_obs(self): 39 | """DIFFERENT FROM ORIGINAL GYM""" 40 | return np.concatenate([ 41 | self.sim.data.qpos.flat, # this part different from gym. expose the whole thing. 42 | self.sim.data.qvel.flat, # this part different from gym. clip nothing. 
43 | ]) 44 | 45 | def reset_model(self): 46 | self.sim.reset() 47 | #self.set_state( 48 | # self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq), 49 | # self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) 50 | #) 51 | self.set_state(self.init_qpos, self.init_qvel) 52 | return self._get_obs() 53 | 54 | def viewer_setup(self): 55 | self.viewer.cam.trackbodyid = 2 56 | self.viewer.cam.distance = self.model.stat.extent * 0.5 57 | self.viewer.cam.lookat[2] = 1.15 58 | self.viewer.cam.elevation = -20 -------------------------------------------------------------------------------- /mujoco/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .wrappers import MjBlockWrapper, RewardScaleWrapper, ClipActionsWrapper, FixedStateWrapper, TorchTensorWrapper 2 | 3 | __all__ = ["MjBlockWrapper", "RewardScaleWrapper", "ClipActionsWrapper", "FixedStateWrapper", "TorchTensorWrapper"] 4 | -------------------------------------------------------------------------------- /mujoco/utils/backward.py: -------------------------------------------------------------------------------- 1 | # TODO: Breaks down in case of quaternions, i.e. free 3D bodies or ball joints. 2 | 3 | import mujoco_py as mj 4 | import numpy as np 5 | import torch 6 | from copy import deepcopy 7 | 8 | # ============================================================ 9 | # CONFIG 10 | # ============================================================ 11 | niter = 30 12 | nwarmup = 3 13 | eps = 1e-6 14 | 15 | 16 | def copy_data(m, d_source, d_dest): 17 | d_dest.time = d_source.time 18 | d_dest.qpos[:] = d_source.qpos 19 | d_dest.qvel[:] = d_source.qvel 20 | d_dest.qacc[:] = d_source.qacc 21 | d_dest.qacc_warmstart[:] = d_source.qacc_warmstart 22 | d_dest.qfrc_applied[:] = d_source.qfrc_applied 23 | for i in range(m.nbody): 24 | d_dest.xfrc_applied[i][:] = d_source.xfrc_applied[i] 25 | d_dest.ctrl[:] = d_source.ctrl 26 | 27 | # copy d_source.act? 
probably need to when dealing with muscle actuators 28 | #d_dest.act = d_source.act 29 | 30 | 31 | def initialise_simulation(d_dest, d_source): 32 | d_dest.time = d_source.time 33 | d_dest.qpos[:] = d_source.qpos 34 | d_dest.qvel[:] = d_source.qvel 35 | d_dest.qacc_warmstart[:] = d_source.qacc_warmstart 36 | d_dest.ctrl[:] = d_source.ctrl 37 | if d_source.act is not None: 38 | d_dest.act[:] = d_source.act 39 | 40 | 41 | def integrate_dynamics_gradients(env, dqaccdqpos, dqaccdqvel, dqaccdctrl): 42 | m = env.model 43 | nv = m.nv 44 | nu = m.nu 45 | dt = env.model.opt.timestep * env.frame_skip 46 | 47 | # dfds: d(next_state)/d(current_state): consists of four parts ul, dl, ur, and ud 48 | ul = np.identity(nv, dtype=np.float32) 49 | ur = np.identity(nv, dtype=np.float32) * dt 50 | dl = dqaccdqpos * dt 51 | dr = np.identity(nv, dtype=np.float32) + dqaccdqvel * dt 52 | dfds = np.concatenate([np.concatenate([ul, dl], axis=0), 53 | np.concatenate([ur, dr], axis=0)], 54 | axis=1) 55 | 56 | # dfda: d(next_state)/d(action_values) 57 | dfda = np.concatenate([np.zeros([nv, nu]), dqaccdctrl * dt], axis=0) 58 | return dfds, dfda 59 | 60 | 61 | def integrate_reward_gradient(env, drdqpos, drdqvel, drdctrl): 62 | return np.concatenate([np.array(drdqpos).reshape(1, env.model.nq), 63 | np.array(drdqvel).reshape(1, env.model.nv)], axis=1), \ 64 | np.array(drdctrl).reshape(1, env.model.nu) 65 | 66 | 67 | def dynamics_worker(env, d): 68 | m = env.sim.model 69 | dmain = env.sim.data 70 | 71 | dqaccdqpos = [None] * m.nv * m.nv 72 | dqaccdqvel = [None] * m.nv * m.nv 73 | dqaccdctrl = [None] * m.nv * m.nu 74 | 75 | # copy state and control from dmain to thread-specific d 76 | copy_data(m, dmain, d) 77 | 78 | # is_forward 79 | mj.functions.mj_forward(m, d) 80 | 81 | # extra solver iterations to improve warmstart (qacc) at center point 82 | for rep in range(nwarmup): 83 | mj.functions.mj_forwardSkip(m, d, mj.const.STAGE_VEL, 1) 84 | 85 | # select output from forward dynamics 86 | output = d.qacc # always differentiate qacc 87 | 88 | # save output for center point and warmstart (needed in forward only) 89 | center = output.copy() 90 | warmstart = d.qacc_warmstart.copy() 91 | 92 | # finite-difference over control values: skip = mjSTAGE_VEL 93 | for i in range(m.nu): 94 | 95 | # perturb selected target 96 | d.ctrl[i] += eps 97 | 98 | # evaluate dynamics, with center warmstart 99 | mj.functions.mju_copy(d.qacc_warmstart, warmstart, m.nv) 100 | mj.functions.mj_forwardSkip(m, d, mj.const.STAGE_VEL, 1) 101 | 102 | # undo perturbation 103 | d.ctrl[i] = dmain.ctrl[i] 104 | 105 | # compute column i of derivative 2 106 | for j in range(m.nv): 107 | dqaccdctrl[i + j * m.nu] = (output[j] - center[j]) / eps 108 | 109 | # finite-difference over velocity: skip = mjSTAGE_POS 110 | for i in range(m.nv): 111 | 112 | # perturb velocity 113 | d.qvel[i] += eps 114 | 115 | # evaluate dynamics, with center warmstart 116 | mj.functions.mju_copy(d.qacc_warmstart, warmstart, m.nv) 117 | mj.functions.mj_forwardSkip(m, d, mj.const.STAGE_POS, 1) 118 | 119 | # undo perturbation 120 | d.qvel[i] = dmain.qvel[i] 121 | 122 | # compute column i of derivative 1 123 | for j in range(m.nv): 124 | dqaccdqvel[i + j * m.nv] = (output[j] - center[j]) / eps 125 | 126 | # finite-difference over position: skip = mjSTAGE_NONE 127 | for i in range(m.nv): 128 | 129 | # get joint id for this dof 130 | jid = m.dof_jntid[i] 131 | 132 | # get quaternion address and dof position within quaternion (-1: not in quaternion) 133 | quatadr = -1 134 | dofpos = 0 135 | if 
m.jnt_type[jid] == mj.const.JNT_BALL: 136 | quatadr = m.jnt_qposadr[jid] 137 | dofpos = i - m.jnt_dofadr[jid] 138 | elif m.jnt_type[jid] == mj.const.JNT_FREE and i >= m.jnt_dofadr[jid] + 3: 139 | quatadr = m.jnt_qposadr[jid] + 3 140 | dofpos = i - m.jnt_dofadr[jid] - 3 141 | 142 | # apply quaternion or simple perturbation 143 | if quatadr >= 0: 144 | angvel = np.array([0., 0., 0.]) 145 | angvel[dofpos] = eps 146 | mj.functions.mju_quatIntegrate(d.qpos + quatadr, angvel, 1) 147 | else: 148 | d.qpos[m.jnt_qposadr[jid] + i - m.jnt_dofadr[jid]] += eps 149 | 150 | # evaluate dynamics, with center warmstart 151 | mj.functions.mju_copy(d.qacc_warmstart, warmstart, m.nv) 152 | mj.functions.mj_forwardSkip(m, d, mj.const.STAGE_NONE, 1) 153 | 154 | # undo perturbation 155 | mj.functions.mju_copy(d.qpos, dmain.qpos, m.nq) 156 | 157 | # compute column i of derivative 0 158 | for j in range(m.nv): 159 | dqaccdqpos[i + j * m.nv] = (output[j] - center[j]) / eps 160 | 161 | dqaccdqpos = np.array(dqaccdqpos).reshape(m.nv, m.nv) 162 | dqaccdqvel = np.array(dqaccdqvel).reshape(m.nv, m.nv) 163 | dqaccdctrl = np.array(dqaccdctrl).reshape(m.nv, m.nu) 164 | dfds, dfda = integrate_dynamics_gradients(env, dqaccdqpos, dqaccdqvel, dqaccdctrl) 165 | return dfds, dfda 166 | 167 | 168 | def dynamics_worker_separate(env): 169 | m = env.sim.model 170 | d = env.sim.data 171 | 172 | dsdctrl = np.empty((m.nq + m.nv, m.nu)) 173 | dsdqpos = np.empty((m.nq + m.nv, m.nq)) 174 | dsdqvel = np.empty((m.nq + m.nv, m.nv)) 175 | 176 | # Copy initial state 177 | time_initial = d.time 178 | qvel_initial = d.qvel.copy() 179 | qpos_initial = d.qpos.copy() 180 | ctrl_initial = d.ctrl.copy() 181 | qacc_warmstart_initial = d.qacc_warmstart.copy() 182 | act_initial = d.act.copy() if d.act is not None else None 183 | 184 | # Step with the main simulation 185 | mj.functions.mj_step(m, d) 186 | 187 | # Get qpos, qvel of the main simulation 188 | qpos = d.qpos.copy() 189 | qvel = d.qvel.copy() 190 | 191 | # finite-difference over control values: skip = mjSTAGE_VEL 192 | for i in range(m.nu): 193 | 194 | # Initialise simulation 195 | initialise_simulation(d, time_initial, qpos_initial, qvel_initial, qacc_warmstart_initial, ctrl_initial, act_initial) 196 | 197 | # Perturb control 198 | d.ctrl[i] += eps 199 | 200 | # Step with perturbed simulation 201 | mj.functions.mj_step(m, d) 202 | 203 | # Compute gradients of qpos and qvel wrt control 204 | dsdctrl[:m.nq, i] = (d.qpos - qpos) / eps 205 | dsdctrl[m.nq:, i] = (d.qvel - qvel) / eps 206 | 207 | # finite-difference over velocity: skip = mjSTAGE_POS 208 | for i in range(m.nv): 209 | 210 | # Initialise simulation 211 | initialise_simulation(d, time_initial, qpos_initial, qvel_initial, qacc_warmstart_initial, ctrl_initial, act_initial) 212 | 213 | # Perturb velocity 214 | d.qvel[i] += eps 215 | 216 | # Step with perturbed simulation 217 | mj.functions.mj_step(m, d) 218 | 219 | # Compute gradients of qpos and qvel wrt qvel 220 | dsdqvel[:m.nq, i] = (d.qpos - qpos) / eps 221 | dsdqvel[m.nq:, i] = (d.qvel - qvel) / eps 222 | 223 | # finite-difference over position: skip = mjSTAGE_NONE 224 | for i in range(m.nq): 225 | 226 | # Initialise simulation 227 | initialise_simulation(d, time_initial, qpos_initial, qvel_initial, qacc_warmstart_initial, ctrl_initial, act_initial) 228 | 229 | # Get joint id for this dof 230 | jid = m.dof_jntid[i] 231 | 232 | # Get quaternion address and dof position within quaternion (-1: not in quaternion) 233 | quatadr = -1 234 | dofpos = 0 235 | if m.jnt_type[jid] == 
mj.const.JNT_BALL: 236 | quatadr = m.jnt_qposadr[jid] 237 | dofpos = i - m.jnt_dofadr[jid] 238 | elif m.jnt_type[jid] == mj.const.JNT_FREE and i >= m.jnt_dofadr[jid] + 3: 239 | quatadr = m.jnt_qposadr[jid] + 3 240 | dofpos = i - m.jnt_dofadr[jid] - 3 241 | 242 | # Apply quaternion or simple perturbation 243 | if quatadr >= 0: 244 | angvel = np.array([0., 0., 0.]) 245 | angvel[dofpos] = eps 246 | mj.functions.mju_quatIntegrate(d.qpos + quatadr, angvel, 1) 247 | else: 248 | d.qpos[m.jnt_qposadr[jid] + i - m.jnt_dofadr[jid]] += eps 249 | 250 | # Step simulation with perturbed position 251 | mj.functions.mj_step(m, d) 252 | 253 | # Compute gradients of qpos and qvel wrt qpos 254 | dsdqpos[:m.nq, i] = (d.qpos - qpos) / eps 255 | dsdqpos[m.nq:, i] = (d.qvel - qvel) / eps 256 | 257 | return np.concatenate((dsdqpos, dsdqvel), axis=1), dsdctrl 258 | 259 | 260 | def reward_worker(env, d): 261 | m = env.sim.model 262 | dmain = env.sim.data 263 | 264 | drdqpos = [None] * m.nv 265 | drdqvel = [None] * m.nv 266 | drdctrl = [None] * m.nu 267 | 268 | # copy state and control from dmain to thread-specific d 269 | copy_data(m, dmain, d) 270 | 271 | # is_forward 272 | mj.functions.mj_forward(m, d) 273 | 274 | # extra solver iterations to improve warmstart (qacc) at center point 275 | for rep in range(nwarmup): 276 | mj.functions.mj_forwardSkip(m, d, mj.const.STAGE_VEL, 1) 277 | 278 | # get center reward 279 | _, center, _, _ = env.step(d.ctrl) 280 | copy_data(m, d, dmain) # revert changes to state and forces 281 | 282 | # finite-difference over control values 283 | for i in range(m.nu): 284 | # perturb selected target 285 | d.ctrl[i] += eps 286 | 287 | _, output, _, _ = env.step(d.ctrl) 288 | copy_data(m, d, dmain) # undo perturbation 289 | 290 | drdctrl[i] = (output - center) / eps 291 | 292 | # finite-difference over velocity 293 | for i in range(m.nv): 294 | # perturb velocity 295 | d.qvel[i] += eps 296 | 297 | _, output, _, _ = env.step(d.ctrl) 298 | copy_data(m, d, dmain) # undo perturbation 299 | 300 | drdqvel[i] = (output - center) / eps 301 | 302 | # finite-difference over position: skip = mjSTAGE_NONE 303 | for i in range(m.nv): 304 | 305 | # get joint id for this dof 306 | jid = m.dof_jntid[i] 307 | 308 | # get quaternion address and dof position within quaternion (-1: not in quaternion) 309 | quatadr = -1 310 | dofpos = 0 311 | if m.jnt_type[jid] == mj.const.JNT_BALL: 312 | quatadr = m.jnt_qposadr[jid] 313 | dofpos = i - m.jnt_dofadr[jid] 314 | elif m.jnt_type[jid] == mj.const.JNT_FREE and i >= m.jnt_dofadr[jid] + 3: 315 | quatadr = m.jnt_qposadr[jid] + 3 316 | dofpos = i - m.jnt_dofadr[jid] - 3 317 | 318 | # apply quaternion or simple perturbation 319 | if quatadr >= 0: 320 | angvel = np.array([0., 0., 0.]) 321 | angvel[dofpos] = eps 322 | mj.functions.mju_quatIntegrate(d.qpos + quatadr, angvel, 1) 323 | else: 324 | d.qpos[m.jnt_qposadr[jid] + i - m.jnt_dofadr[jid]] += eps 325 | 326 | _, output, _, _ = env.step(d.ctrl) 327 | copy_data(m, d, dmain) # undo perturbation 328 | 329 | # compute column i of derivative 0 330 | drdqpos[i] = (output - center) / eps 331 | 332 | drds, drda = integrate_reward_gradient(env, drdqpos, drdqvel, drdctrl) 333 | return drds, drda 334 | 335 | 336 | def calculate_reward(env, qpos, qvel, ctrl, qpos_next, qvel_next): 337 | current_state = np.concatenate((qpos, qvel)) 338 | next_state = np.concatenate((qpos_next, qvel_next)) 339 | reward = env.tensor_reward(torch.DoubleTensor(current_state), torch.DoubleTensor(ctrl), torch.DoubleTensor(next_state)) 340 | return 
reward.detach().numpy() 341 | 342 | 343 | def calculate_gradients(agent, data_snapshot, next_state, reward, test=False): 344 | # Defining m and d just for shorter notations 345 | m = agent.model 346 | d = agent.data 347 | 348 | # Dynamics gradients 349 | dsdctrl = np.empty((m.nq + m.nv, m.nu)) 350 | dsdqpos = np.empty((m.nq + m.nv, m.nq)) 351 | dsdqvel = np.empty((m.nq + m.nv, m.nv)) 352 | 353 | # Reward gradients 354 | drdctrl = np.empty((1, m.nu)) 355 | drdqpos = np.empty((1, m.nq)) 356 | drdqvel = np.empty((1, m.nv)) 357 | 358 | # Get number of steps (must be >=2 for muscles) 359 | nsteps = agent.cfg.MODEL.NSTEPS_FOR_BACKWARD 360 | 361 | # For testing purposes 362 | if test: 363 | 364 | # Reset simulation to snapshot 365 | agent.set_snapshot(data_snapshot) 366 | 367 | # Step with the main simulation 368 | info = agent.step(d.ctrl.copy()) 369 | 370 | # Sanity check. "reward" must equal info[1], otherwise this simulation has diverged from the forward pass 371 | assert reward == info[1], "reward is different from forward pass [{} != {}] at timepoint {}".format(reward, info[1], data_snapshot.time) 372 | 373 | # Another check. "next_state" must equal info[0], otherwise this simulation has diverged from the forward pass 374 | assert (next_state == info[0]).all(), "state is different from forward pass" 375 | 376 | # Get state from the forward pass 377 | if nsteps > 1: 378 | agent.set_snapshot(data_snapshot) 379 | for _ in range(nsteps): 380 | info = agent.step(d.ctrl.copy()) 381 | qpos_fwd = info[0][:agent.model.nq] 382 | qvel_fwd = info[0][agent.model.nq:] 383 | reward = info[1] 384 | else: 385 | qpos_fwd = next_state[:agent.model.nq] 386 | qvel_fwd = next_state[agent.model.nq:] 387 | 388 | # finite-difference over control values 389 | for i in range(m.nu): 390 | 391 | # Initialise simulation 392 | agent.set_snapshot(data_snapshot) 393 | 394 | # Perturb control 395 | d.ctrl[i] += eps 396 | 397 | # Step with perturbed simulation 398 | for _ in range(nsteps): 399 | info = agent.step(d.ctrl.copy()) 400 | 401 | # Compute gradient of state wrt control 402 | dsdctrl[:m.nq, i] = (d.qpos - qpos_fwd) / eps 403 | dsdctrl[m.nq:, i] = (d.qvel - qvel_fwd) / eps 404 | 405 | # Compute gradient of reward wrt to control 406 | drdctrl[0, i] = (info[1] - reward) / eps 407 | 408 | # finite-difference over velocity 409 | for i in range(m.nv): 410 | 411 | # Initialise simulation 412 | agent.set_snapshot(data_snapshot) 413 | 414 | # Perturb velocity 415 | d.qvel[i] += eps 416 | 417 | # Calculate new ctrl (if it's dependent on state) 418 | if agent.cfg.MODEL.POLICY.NETWORK: 419 | d.ctrl[:] = agent.policy_net(torch.from_numpy(np.concatenate((d.qpos, d.qvel))).float()).double().detach().numpy() 420 | 421 | # Step with perturbed simulation 422 | for _ in range(nsteps): 423 | info = agent.step(d.ctrl) 424 | 425 | # Compute gradient of state wrt qvel 426 | dsdqvel[:m.nq, i] = (d.qpos - qpos_fwd) / eps 427 | dsdqvel[m.nq:, i] = (d.qvel - qvel_fwd) / eps 428 | 429 | # Compute gradient of reward wrt qvel 430 | drdqvel[0, i] = (info[1] - reward) / eps 431 | 432 | # finite-difference over position 433 | for i in range(m.nq): 434 | 435 | # Initialise simulation 436 | agent.set_snapshot(data_snapshot) 437 | 438 | # Get joint id for this dof 439 | jid = m.dof_jntid[i] 440 | 441 | # Get quaternion address and dof position within quaternion (-1: not in quaternion) 442 | quatadr = -1 443 | dofpos = 0 444 | if m.jnt_type[jid] == mj.const.JNT_BALL: 445 | quatadr = m.jnt_qposadr[jid] 446 | dofpos = i - m.jnt_dofadr[jid] 447 | elif 
m.jnt_type[jid] == mj.const.JNT_FREE and i >= m.jnt_dofadr[jid] + 3: 448 | quatadr = m.jnt_qposadr[jid] + 3 449 | dofpos = i - m.jnt_dofadr[jid] - 3 450 | 451 | # Apply quaternion or simple perturbation 452 | if quatadr >= 0: 453 | angvel = np.array([0., 0., 0.]) 454 | angvel[dofpos] = eps 455 | mj.functions.mju_quatIntegrate(d.qpos + quatadr, angvel, 1) 456 | else: 457 | d.qpos[m.jnt_qposadr[jid] + i - m.jnt_dofadr[jid]] += eps 458 | 459 | # Calculate new ctrl (if it's dependent on state) 460 | if agent.cfg.MODEL.POLICY.NETWORK: 461 | d.ctrl[:] = agent.policy_net(torch.from_numpy(np.concatenate((d.qpos, d.qvel))).float()).double().detach().numpy() 462 | 463 | # Step simulation with perturbed position 464 | for _ in range(nsteps): 465 | info = agent.step(d.ctrl) 466 | 467 | # Compute gradient of state wrt qpos 468 | dsdqpos[:m.nq, i] = (d.qpos - qpos_fwd) / eps 469 | dsdqpos[m.nq:, i] = (d.qvel - qvel_fwd) / eps 470 | 471 | # Compute gradient of reward wrt qpos 472 | drdqpos[0, i] = (info[1] - reward) / eps 473 | 474 | # Set dynamics gradients 475 | agent.dynamics_gradients = {"state": np.concatenate((dsdqpos, dsdqvel), axis=1), "action": dsdctrl} 476 | 477 | # Set reward gradients 478 | agent.reward_gradients = {"state": np.concatenate((drdqpos, drdqvel), axis=1), "action": drdctrl} 479 | 480 | return 481 | 482 | 483 | def mj_gradients_factory(agent, mode): 484 | """ 485 | :param env: gym.envs.mujoco.mujoco_env.mujoco_env.MujocoEnv 486 | :param mode: 'dynamics' or 'reward' 487 | :return: 488 | """ 489 | #mj_sim_main = env.sim 490 | #mj_sim = mj.MjSim(mj_sim_main.model) 491 | 492 | #worker = {'dynamics': reward_worker, 'reward': reward_worker}[mode] 493 | 494 | @agent.gradient_wrapper(mode) 495 | def mj_gradients(data_snapshot, next_state, reward, test=False): 496 | #state = state_action[:env.model.nq + env.model.nv] 497 | #qpos = state[:env.model.nq] 498 | #qvel = state[env.model.nq:] 499 | #ctrl = state_action[-env.model.nu:] 500 | #env.set_state(qpos, qvel) 501 | #env.data.ctrl[:] = ctrl 502 | #d = mj_sim.data 503 | # set solver options for finite differences 504 | #mj_sim_main.model.opt.iterations = niter 505 | #mj_sim_main.model.opt.tolerance = 0 506 | # env.sim.model.opt.iterations = niter 507 | # env.sim.model.opt.tolerance = 0 508 | #dfds, dfda = worker(env) 509 | 510 | calculate_gradients(agent, data_snapshot, next_state, reward, test=test) 511 | 512 | return mj_gradients 513 | -------------------------------------------------------------------------------- /mujoco/utils/forward.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def mj_forward_factory(agent, mode): 6 | """ 7 | :param agent: gym.envs.mujoco.mujoco_env.mujoco_env.MujocoEnv 8 | :param mode: 'dynamics' or 'reward' 9 | :return: 10 | """ 11 | 12 | @agent.forward_wrapper(mode) 13 | def mj_forward(action=None): 14 | """ 15 | :param action: np.array of action -- if missing, agent.data.ctrl is used as action 16 | :return: 17 | """ 18 | 19 | # If action wasn't explicitly given use the current one in agent's env 20 | if action is None: 21 | action = agent.data.ctrl 22 | 23 | # Convert tensor to numpy array, and make sure we're using an action value that isn't referencing agent's data 24 | # (otherwise we might get into trouble if ctrl is limited and frame_skip is larger than one) 25 | if isinstance(action, torch.Tensor): 26 | action = action.detach().numpy().copy() 27 | elif isinstance(action, np.ndarray): 28 | action = action.copy() 29 | 
else: 30 | raise TypeError("Expecting a torch tensor or numpy ndarray") 31 | 32 | # Make sure dtype is numpy.float64 33 | assert action.dtype == np.float64, "You must use dtype numpy.float64 for actions" 34 | 35 | # Advance simulation with one step 36 | next_state, agent.reward, agent.is_done, _ = agent.step(action) 37 | 38 | return next_state 39 | 40 | return mj_forward 41 | -------------------------------------------------------------------------------- /mujoco/utils/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .mj_block import MjBlockWrapper 2 | from .etc import AccumulateWrapper, RewardScaleWrapper, ClipActionsWrapper, FixedStateWrapper, TorchTensorWrapper 3 | 4 | __all__ = ["MjBlockWrapper", 5 | "AccumulateWrapper", 6 | "RewardScaleWrapper", 7 | "ClipActionsWrapper", 8 | "FixedStateWrapper", 9 | "TorchTensorWrapper"] 10 | -------------------------------------------------------------------------------- /mujoco/utils/wrappers/etc.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import torch 3 | import numpy as np 4 | from copy import deepcopy 5 | from utils.index import Index 6 | import mujoco_py 7 | from multiprocessing import Process, Queue 8 | import skvideo 9 | import glfw 10 | 11 | 12 | class AccumulateWrapper(gym.Wrapper): 13 | 14 | def __init__(self, env, gamma): 15 | super(AccumulateWrapper, self).__init__(env) 16 | self.gamma = gamma 17 | self.accumulated_return = 0. 18 | self.accumulated_observations = [] 19 | self.accumulated_rewards = [] 20 | 21 | def step(self, action): 22 | observation, reward, done, info = self.env.step(action) 23 | self.accumulated_observations.append(observation) 24 | self.accumulated_rewards.append(reward) 25 | self.accumulated_return = self.gamma * self.accumulated_return + reward 26 | 27 | def reset(self, **kwargs): 28 | self.accumulated_return = 0. 
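# reset() clears the running return that step() above maintains via
# accumulated_return <- gamma * accumulated_return + reward; after T steps this equals
# sum_k gamma**(T-1-k) * r_k, i.e. the newest reward has weight 1 and older rewards are
# progressively down-weighted. Worked example with hypothetical numbers, gamma = 0.9 and
# rewards [1, 2, 3]:
#   G = ((0 * 0.9 + 1) * 0.9 + 2) * 0.9 + 3 = 0.81 * 1 + 0.9 * 2 + 3 = 5.61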
29 | self.accumulated_observations = [] 30 | self.accumulated_rewards = [] 31 | return torch.Tensor(self.env.reset(**kwargs)) 32 | 33 | 34 | class RewardScaleWrapper(gym.RewardWrapper): 35 | """Bring rewards to a reasonable scale.""" 36 | 37 | def __init__(self, env, scale=0.01): 38 | super(RewardScaleWrapper, self).__init__(env) 39 | self.scale = scale 40 | 41 | def reward(self, reward): 42 | return self.scale * reward 43 | 44 | def torch_reward(self, state, action, next_state): 45 | return self.scale * self.env.torch_reward(state, action, next_state) 46 | 47 | 48 | class ClipActionsWrapper(gym.Wrapper): 49 | def step(self, action): 50 | action = np.nan_to_num(action) # less error prone but your optimization won't benefit from this 51 | action = np.clip(action, self.action_space.low, self.action_space.high) 52 | return self.env.step(action) 53 | 54 | def reset(self, **kwargs): 55 | return self.env.reset(**kwargs) 56 | 57 | 58 | class FixedStateWrapper(gym.Wrapper): 59 | def __init__(self, env): 60 | super(FixedStateWrapper, self).__init__(env) 61 | self.fixed_state = self.env.reset_model() 62 | self.fixed_qpos = self.env.sim.data.qpos 63 | self.fixed_qvel = self.env.sim.data.qvel 64 | 65 | def reset(self, **kwargs): 66 | self.env.reset(**kwargs) 67 | self.env.set_state(self.fixed_qpos, self.fixed_qvel) 68 | return self.env.env._get_obs() 69 | 70 | 71 | class TorchTensorWrapper(gym.Wrapper): 72 | """Takes care of torch Tensors in step and reset modules.""" 73 | 74 | #def step(self, action): 75 | # action = action.detach().numpy() 76 | # state, reward, done, info = self.env.step(action) 77 | # return torch.Tensor(state), reward, done, info 78 | 79 | def reset(self, **kwargs): 80 | return torch.Tensor(self.env.reset(**kwargs)) 81 | 82 | #def set_from_torch_state(self, state): 83 | # qpos, qvel = np.split(state.detach().numpy(), 2) 84 | # self.env.set_state(qpos, qvel) 85 | 86 | #def is_done(self, state): 87 | # state = state.detach().numpy() 88 | # return self.env.is_done(state) 89 | 90 | 91 | class SnapshotWrapper(gym.Wrapper): 92 | """Handles all stateful stuff, like getting and setting snapshots of states, and resetting""" 93 | def get_snapshot(self): 94 | 95 | class DataSnapshot: 96 | # Note: You should not modify these parameters after creation 97 | 98 | def __init__(self, d_source, step_idx): 99 | self.time = deepcopy(d_source.time) 100 | self.qpos = deepcopy(d_source.qpos) 101 | self.qvel = deepcopy(d_source.qvel) 102 | self.qacc_warmstart = deepcopy(d_source.qacc_warmstart) 103 | self.ctrl = deepcopy(d_source.ctrl) 104 | self.act = deepcopy(d_source.act) 105 | self.qfrc_applied = deepcopy(d_source.qfrc_applied) 106 | self.xfrc_applied = deepcopy(d_source.xfrc_applied) 107 | 108 | self.step_idx = deepcopy(step_idx) 109 | 110 | # These probably aren't necessary, but they should fix the body in the same position with 111 | # respect to worldbody frame? 
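# A DataSnapshot stores just enough of MjData (plus the step index) to rewind the
# simulation to this exact point later; the finite-difference code in
# mujoco/utils/backward.py relies on this to re-run the same step repeatedly with
# small perturbations. Commented usage sketch (assuming `agent` is the fully wrapped env
# and `action` is some control vector):
#   snap = agent.get_snapshot()     # freeze time, qpos, qvel, ctrl, warmstart, ...
#   agent.step(action)              # roll forward / perturb
#   agent.set_snapshot(snap)        # rewind to the saved state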
112 | self.body_xpos = deepcopy(d_source.body_xpos) 113 | self.body_xquat = deepcopy(d_source.body_xquat) 114 | 115 | return DataSnapshot(self.env.sim.data, self.get_step_idx()) 116 | 117 | def set_snapshot(self, snapshot_data): 118 | self.env.sim.data.time = deepcopy(snapshot_data.time) 119 | self.env.sim.data.qpos[:] = deepcopy(snapshot_data.qpos) 120 | self.env.sim.data.qvel[:] = deepcopy(snapshot_data.qvel) 121 | self.env.sim.data.qacc_warmstart[:] = deepcopy(snapshot_data.qacc_warmstart) 122 | self.env.sim.data.ctrl[:] = deepcopy(snapshot_data.ctrl) 123 | if snapshot_data.act is not None: 124 | self.env.sim.data.act[:] = deepcopy(snapshot_data.act) 125 | self.env.sim.data.qfrc_applied[:] = deepcopy(snapshot_data.qfrc_applied) 126 | self.env.sim.data.xfrc_applied[:] = deepcopy(snapshot_data.xfrc_applied) 127 | 128 | self.set_step_idx(snapshot_data.step_idx) 129 | 130 | self.env.sim.data.body_xpos[:] = deepcopy(snapshot_data.body_xpos) 131 | self.env.sim.data.body_xquat[:] = deepcopy(snapshot_data.body_xquat) 132 | 133 | 134 | class IndexWrapper(gym.Wrapper): 135 | """Counts steps and episodes""" 136 | 137 | def __init__(self, env, batch_size): 138 | super(IndexWrapper, self).__init__(env) 139 | self._step_idx = Index(0) 140 | self._episode_idx = Index(0) 141 | self._batch_idx = Index(0) 142 | self._batch_size = batch_size 143 | 144 | # Add references to the unwrapped env because we're going to need at least step_idx 145 | # NOTE! This means we can never overwrite self.step_idx, self.episode_idx, or self.batch_idx or we lose 146 | # the reference 147 | env.unwrapped._step_idx = self._step_idx 148 | env.unwrapped._episode_idx = self._episode_idx 149 | env.unwrapped._batch_idx = self._batch_idx 150 | 151 | def step(self, action): 152 | self._step_idx += 1 153 | return self.env.step(action) 154 | 155 | def reset(self, update_episode_idx=True): 156 | self._step_idx.set(0) 157 | 158 | # We don't want to update episode_idx during testing 159 | if update_episode_idx: 160 | if self._episode_idx == self._batch_size: 161 | self._batch_idx += 1 162 | self._episode_idx.set(1) 163 | else: 164 | self._episode_idx += 1 165 | 166 | return self.env.reset() 167 | 168 | def get_step_idx(self): 169 | return self._step_idx 170 | 171 | def get_episode_idx(self): 172 | return self._episode_idx 173 | 174 | def get_batch_idx(self): 175 | return self._batch_idx 176 | 177 | def set_step_idx(self, idx): 178 | self._step_idx.set(idx) 179 | 180 | def set_episode_idx(self, idx): 181 | self._episode_idx.set(idx) 182 | 183 | def set_batch_idx(self, idx): 184 | self._batch_idx.set(idx) 185 | 186 | 187 | class ViewerWrapper(gym.Wrapper): 188 | 189 | def __init__(self, env): 190 | super(ViewerWrapper, self).__init__(env) 191 | 192 | # Keep params in this class to reduce clutter 193 | class Recorder: 194 | width = 1600 195 | height = 1200 196 | imgs = [] 197 | record = False 198 | filepath = None 199 | 200 | self.recorder = Recorder 201 | 202 | # Check if we want to record roll-outs 203 | if self.cfg.LOG.TESTING.RECORD_VIDEO: 204 | self.recorder.record = True 205 | 206 | # Create a viewer if we're not recording 207 | else: 208 | # Initialise a MjViewer 209 | self._viewer = mujoco_py.MjViewer(self.sim) 210 | self._viewer._run_speed = 1/self.cfg.MODEL.FRAME_SKIP 211 | self.unwrapped._viewers["human"] = self._viewer 212 | 213 | def capture_frame(self): 214 | if self.recorder.record: 215 | self.recorder.imgs.append(np.flip(self.sim.render(self.recorder.width, self.recorder.height), axis=0)) 216 | 217 | def 
start_recording(self, filepath): 218 | if self.recorder.record: 219 | self.recorder.filepath = filepath 220 | self.recorder.imgs.clear() 221 | 222 | def stop_recording(self): 223 | if self.recorder.record: 224 | writer = skvideo.io.FFmpegWriter( 225 | self.recorder.filepath, inputdict={"-s": "{}x{}".format(self.recorder.width, self.recorder.height), 226 | "-r": str(1 / (self.model.opt.timestep*self.cfg.MODEL.FRAME_SKIP))}) 227 | for img in self.recorder.imgs: 228 | writer.writeFrame(img) 229 | writer.close() 230 | self.recorder.imgs.clear() 231 | -------------------------------------------------------------------------------- /mujoco/utils/wrappers/mj_block.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | from mujoco.utils.forward import mj_forward_factory 5 | from mujoco.utils.backward import mj_gradients_factory 6 | 7 | 8 | class MjBlockWrapper(gym.Wrapper): 9 | """Wrap the forward and backward model. Further used for PyTorch blocks.""" 10 | 11 | def __init__(self, env): 12 | gym.Wrapper.__init__(self, env) 13 | 14 | def gradient_factory(self, mode): 15 | """ 16 | :param mode: 'dynamics' or 'reward' 17 | :return: 18 | """ 19 | # TODO: due to dynamics and reward isolation, this isn't the 20 | # most efficient way to handle this; lazy, but simple 21 | #env = self.clone() 22 | return mj_gradients_factory(self, mode) 23 | 24 | def forward_factory(self, mode): 25 | """ 26 | :param mode: 'dynamics' or 'reward' 27 | :return: 28 | """ 29 | #env = self.clone() 30 | return mj_forward_factory(self, mode) 31 | 32 | def gradient_wrapper(self, mode): 33 | """ 34 | Decorator for making gradients be the same size the observations for example. 35 | :param mode: either 'dynamics' or 'reward' 36 | :return: 37 | """ 38 | 39 | # mode agnostic for now 40 | def decorator(gradients_fn): 41 | def wrapper(*args, **kwargs): 42 | #if mode == "forward": 43 | gradients_fn(*args, **kwargs) 44 | #else: 45 | # dfds, dfda = gradients_fn(*args, **kwargs) 46 | # # no further reshaping is needed for the case of hopper, also it's mode-agnostic 47 | # gradients = np.concatenate([dfds, dfda], axis=1) 48 | return 49 | 50 | return wrapper 51 | 52 | return decorator 53 | 54 | def forward_wrapper(self, mode): 55 | """ 56 | Decorator for making gradients be the same size the observations for example. 
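These forward_wrapper / gradient_wrapper decorators are applied by forward_factory and gradient_factory above, which is how the rest of the code (and tests/test_gradients.py) obtains forward and backward callables from the wrapped env. A rough usage sketch, assuming `agent` was built with mujoco.build_agent and `action` is a placeholder np.float64 control vector:

            dyn_forward  = agent.forward_factory("dynamics")    # steps the env once, returns next state
            rew_gradient = agent.gradient_factory("reward")     # finite-difference reward gradients
            next_state   = dyn_forward(action)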
57 | :param mode: either 'dynamics' or 'reward' 58 | :return: 59 | """ 60 | 61 | # mode agnostic for now 62 | def decorator(forward_fn): 63 | def wrapper(*args, **kwargs): 64 | f = forward_fn(*args, **kwargs) # next state 65 | # no further reshaping is needed for the case of hopper, also it's mode-agnostic 66 | return f 67 | 68 | return wrapper 69 | 70 | return decorator 71 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | mujoco_py 3 | numpy 4 | gym 5 | yacs 6 | visdom 7 | -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_optimizer 2 | 3 | __all__ = ["build_optimizer"] 4 | -------------------------------------------------------------------------------- /solver/build.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def build_optimizer(cfg, named_parameters): 5 | params = [] 6 | lr = cfg.SOLVER.BASE_LR 7 | for key, value in named_parameters: 8 | if not value.requires_grad: 9 | continue 10 | #lr = cfg.SOLVER.BASE_LR / cfg.SOLVER.BATCH_SIZE # due to funny way gradients are batched 11 | #weight_decay = cfg.SOLVER.WEIGHT_DECAY 12 | weight_decay = 0.0 13 | if "bias" in key: 14 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 15 | #weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 16 | if any(x in key for x in ["std", "sd"]): 17 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.STD_LR_FACTOR 18 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_SD 19 | params += [{"params": [value], "lr": lr}] 20 | 21 | if cfg.SOLVER.OPTIMIZER == "sgd": 22 | optimizer = torch.optim.SGD(params, lr) 23 | else: 24 | optimizer = torch.optim.Adam(params, lr, betas=cfg.SOLVER.ADAM_BETAS) 25 | return optimizer 26 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_gradients.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import numpy as np 5 | from model.config import get_cfg_defaults 6 | from mujoco import build_agent 7 | from utils.index import Index 8 | from model import build_model 9 | 10 | class TestGradients(unittest.TestCase): 11 | 12 | def test_reward_gradients(self): 13 | cfg = get_cfg_defaults() 14 | cfg.MUJOCO.ENV = "InvertedPendulumEnv" 15 | agent = build_agent(cfg) 16 | mj_reward_forward_fn = agent.forward_factory("reward") 17 | mj_reward_gradients_fn = agent.gradient_factory("reward") 18 | 19 | nwarmup = 5 20 | agent.reset() 21 | nv = agent.sim.model.nv 22 | nu = agent.sim.model.nu 23 | for _ in range(nwarmup): 24 | action = torch.Tensor(agent.action_space.sample()) 25 | ob, r, _, _ = agent.step(action) 26 | state = ob.detach().numpy() 27 | action = agent.action_space.sample() 28 | state_action = np.concatenate([state, action], axis=0) 29 | drdsa = mj_reward_gradients_fn(state_action) 30 | drds = drdsa[:, :nv * 2] 31 | drda = drdsa[:, -nu:] 32 | 33 | eps = 1e-6 34 | state_action_prime = state_action + eps 35 | r = 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_gradients.py:
--------------------------------------------------------------------------------
import unittest

import torch
import numpy as np
from model.config import get_cfg_defaults
from mujoco import build_agent
from utils.index import Index
from model import build_model


class TestGradients(unittest.TestCase):

    def test_reward_gradients(self):
        cfg = get_cfg_defaults()
        cfg.MUJOCO.ENV = "InvertedPendulumEnv"
        agent = build_agent(cfg)
        mj_reward_forward_fn = agent.forward_factory("reward")
        mj_reward_gradients_fn = agent.gradient_factory("reward")

        nwarmup = 5
        agent.reset()
        nv = agent.sim.model.nv
        nu = agent.sim.model.nu
        for _ in range(nwarmup):
            action = torch.Tensor(agent.action_space.sample())
            ob, r, _, _ = agent.step(action)
        state = ob.detach().numpy()
        action = agent.action_space.sample()
        state_action = np.concatenate([state, action], axis=0)
        drdsa = mj_reward_gradients_fn(state_action)
        drds = drdsa[:, :nv * 2]
        drda = drdsa[:, -nu:]

        eps = 1e-6
        state_action_prime = state_action + eps
        r = mj_reward_forward_fn(state_action)
        r_prime = mj_reward_forward_fn(state_action_prime)

        r_prime_estimate = r + \
            np.squeeze(np.matmul(drds, np.array([eps] * 2 * nv).reshape([-1, 1]))) + \
            np.squeeze(np.matmul(drda, np.array([eps] * nu).reshape([-1, 1])))
        self.assertAlmostEqual(r_prime[0], r_prime_estimate[0], places=5)

    def test_dynamics_gradients(self):
        cfg = get_cfg_defaults()
        cfg.merge_from_file("./configs/swimmer.yaml")
        agent = build_agent(cfg)
        mj_dynamics_forward_fn = agent.forward_factory("dynamics")
        mj_dynamics_gradients_fn = agent.gradient_factory("dynamics")

        nwarmup = 5
        agent.reset()
        nv = agent.sim.model.nv
        nu = agent.sim.model.nu
        for _ in range(nwarmup):
            action = torch.Tensor(agent.action_space.sample())
            ob, r, _, _ = agent.step(action)
        state = ob.detach().numpy()
        action = agent.action_space.sample()
        state_action = np.concatenate([state, action], axis=0)
        dsdsa = mj_dynamics_gradients_fn(state_action)
        dsds = dsdsa[:, :nv * 2]
        dsda = dsdsa[:, -nu:]

        eps = 1e-6
        state_action_prime = state_action + eps
        s = mj_dynamics_forward_fn(state_action)
        s_prime = mj_dynamics_forward_fn(state_action_prime)

        s_prime_estimate = s + \
            np.squeeze(np.matmul(dsds, np.array([eps] * 2 * nv).reshape([-1, 1]))) + \
            np.squeeze(np.matmul(dsda, np.array([eps] * nu).reshape([-1, 1])))
        print(s)
        print(s_prime)
        print(s_prime_estimate)
        self.assertTrue(np.allclose(s_prime, s_prime_estimate))

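    # Both gradient tests above rely on a first-order Taylor expansion: with the same
    # perturbation eps added to every entry of the (state, action) vector,
    #     f(sa + eps) ~= f(sa) + df/ds @ (eps * ones(2 * nv)) + df/da @ (eps * ones(nu)),
    # so the finite-difference Jacobians returned by gradient_factory() are checked
    # against an independent forward evaluation at the perturbed input.
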
    def test_action_rewards(self):

        # Run multiple times with different sigma for generating action values
        N = 100
        sigma = np.linspace(1e-2, 1e1, N)
        cfg_file = "./configs/swimmer.yaml"

        for s in sigma:
            self.run_reward_test(cfg_file, s)

    def run_reward_test(self, cfg_file, sigma):

        cfg = get_cfg_defaults()
        cfg.merge_from_file(cfg_file)
        agent = build_agent(cfg)
        mj_forward_fn = agent.forward_factory("dynamics")
        mj_gradients_fn = agent.gradient_factory("reward")

        model = build_model(cfg, agent)
        device = torch.device(cfg.MODEL.DEVICE)
        model.to(device)

        # Start from the same state with constant action, make sure reward is equal in both repetitions

        # Drive both simulations forward 5 steps
        nwarmup = 5

        # Reset and get initial state
        agent.reset()
        init_qpos = agent.data.qpos.copy()
        init_qvel = agent.data.qvel.copy()

        # Set constant action
        na = agent.model.actuator_acc0.shape[0]
        action = torch.DoubleTensor(np.random.randn(na) * sigma)

        # Do first simulation
        for _ in range(nwarmup):
            mj_forward_fn(action)

        # Take a snapshot of this state so we can use it in gradient calculations
        agent.data.ctrl[:] = action.detach().numpy().copy()
        data = agent.get_snapshot()

        # Advance simulation with one step and get the reward
        mj_forward_fn(action)
        reward = agent.reward.copy()
        next_state = np.concatenate((agent.data.qpos.copy(), agent.data.qvel.copy()))

        # Reset and set to initial state, then do the second simulation; this time call mj_forward_fn without args
        agent.reset()
        agent.data.qpos[:] = init_qpos
        agent.data.qvel[:] = init_qvel
        for _ in range(nwarmup):
            agent.data.ctrl[:] = action.detach().numpy()
            mj_forward_fn()

        # Advance simulation with one step and get reward
        agent.data.ctrl[:] = action.detach().numpy()
        mj_forward_fn()
        reward2 = agent.reward.copy()

        # reward1 and reward2 should be equal
        self.assertEqual(reward, reward2, "Simulations from same initial state diverged")

        # Then make sure simulation from snapshot doesn't diverge from original simulation
        agent.set_snapshot(data)
        mj_forward_fn()
        reward_snapshot = agent.reward.copy()
        self.assertEqual(reward, reward_snapshot, "Simulation from snapshot diverged")

        # Make sure simulations are correct in the gradient calculations as well
        mj_gradients_fn(data, next_state, reward, test=True)

    def test_dynamics_reward(self):

        # Run multiple times with different sigma for generating action values
        N = 100
        sigma = np.linspace(1e-2, 1e1, N)
        cfg_file = "./configs/leg.yaml"

        for s in sigma:
            self.run_dynamics_reward_test(cfg_file, s)

    def run_dynamics_reward_test(self, cfg_file, sigma):

        cfg = get_cfg_defaults()
        cfg.merge_from_file(cfg_file)
        agent = build_agent(cfg)
        mj_dynamics_forward_fn = agent.forward_factory("dynamics")
        #mj_reward_forward_fn = agent.forward_factory("reward")
        #mj_dynamics_gradients_fn = agent.gradient_factory("dynamics")
        mj_reward_gradients_fn = agent.gradient_factory("reward")

        # Start from the same state with constant action, make sure reward is equal in both repetitions

        # Drive both simulations forward 5 steps
        nwarmup = 5

        # Reset and get initial state
        agent.reset()
        init_qpos = agent.data.qpos.copy()
        init_qvel = agent.data.qvel.copy()

        # Set constant action
        na = agent.model.actuator_acc0.shape[0]
        action = torch.DoubleTensor(np.random.randn(na) * sigma)

        # Do first simulation
        for _ in range(nwarmup):
            mj_dynamics_forward_fn(action)

        # Take a snapshot of this state so we can use it in gradient calculations
        agent.data.ctrl[:] = action.detach().numpy().copy()
        data = agent.get_snapshot()

        # Advance simulation with one step and get the reward
        mj_dynamics_forward_fn(action)
        reward = agent.reward.copy()
        next_state = np.concatenate((agent.data.qpos.copy(), agent.data.qvel.copy()))

        # Reset and set to initial state, then do the second simulation; this time call mj_forward_fn without args
        agent.reset()
        agent.data.qpos[:] = init_qpos
        agent.data.qvel[:] = init_qvel
        for _ in range(nwarmup):
            agent.data.ctrl[:] = action.detach().numpy()
            mj_dynamics_forward_fn()

        # Advance simulation with one step and get reward
        agent.data.ctrl[:] = action.detach().numpy()
        mj_dynamics_forward_fn()
        reward2 = agent.reward.copy()

        # reward1 and reward2 should be equal
        self.assertEqual(reward, reward2, "Simulations from same initial state diverged")

        # Then make sure simulation from snapshot doesn't diverge from original simulation
        agent.set_snapshot(data)
        mj_dynamics_forward_fn()
        reward_snapshot = agent.reward.copy()
        self.assertEqual(reward, reward_snapshot, "Simulation from snapshot diverged")

        # Make sure simulations are correct in the gradient calculations as well
        mj_reward_gradients_fn(data, next_state, reward, test=True)

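    # test_action_rewards / test_dynamics_reward (and their helpers above) do not check
    # Jacobian values; they verify that the simulation is deterministic and that a state
    # restored via get_snapshot()/set_snapshot() reproduces exactly the same reward.
    # The finite-difference gradients re-run the simulation step repeatedly from such a
    # snapshot, so any divergence here would silently corrupt the Jacobians.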

if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MahanFathi/Model-Based-RL/b7d69c3bea44748411b259ddda1b7e9bb42985e0/utils/__init__.py
--------------------------------------------------------------------------------
/utils/index.py:
--------------------------------------------------------------------------------
class Index(object):
    def __init__(self, value=0):
        self.value = value

    def __index__(self):
        return self.value

    def __iadd__(self, other):
        self.value += other
        return self

    def __add__(self, other):
        if isinstance(other, Index):
            return Index(self.value + other.value)
        elif isinstance(other, int):
            return Index(self.value + other)
        else:
            raise NotImplementedError

    def __sub__(self, other):
        if isinstance(other, Index):
            return Index(self.value - other.value)
        elif isinstance(other, int):
            return Index(self.value - other)
        else:
            raise NotImplementedError

    def __eq__(self, other):
        return self.value == other

    def __repr__(self):
        return "Index({})".format(self.value)

    def __str__(self):
        return "{}".format(self.value)

    def set(self, other):
        if isinstance(other, int):
            self.value = other
        elif isinstance(other, Index):
            self.value = other.value
        else:
            raise NotImplementedError
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
import os
import logging
import sys
import csv
from datetime import datetime


def setup_logger(name, save_dir, txt_file_name='log'):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if save_dir:
        fh = logging.FileHandler(os.path.join(save_dir, "{}.txt".format(txt_file_name)))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger


def save_dict_into_csv(save_dir, file_name, output):
    try:
        # Append file_name with datetime
        #file_name = os.path.join(save_dir, file_name + "_{0:%Y-%m-%dT%H:%M:%S}".format(datetime.now()))
        file_name = os.path.join(save_dir, file_name)
        with open(file_name, "w") as file:
            writer = csv.writer(file)
            writer.writerow(output.keys())
            writer.writerows(zip(*output.values()))
    except IOError:
        print("Failed to save file {}".format(file_name))

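# A minimal usage sketch (the directory and record keys are hypothetical):
#
#     logger = setup_logger("model_based_rl", "output/", txt_file_name="train_log")
#     logger.info("starting training")
#     save_dict_into_csv("output/", "rewards.csv", {"epoch": [0, 1, 2], "reward": [1.2, 3.4, 5.6]})
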
--------------------------------------------------------------------------------
/utils/visdom_plots.py:
--------------------------------------------------------------------------------
import sys
import visdom
import numpy as np
from subprocess import Popen, PIPE


def create_visdom_connections(port):
    """If the program could not connect to Visdom server, this function will start a new server at port <port>"""
    cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % port
    print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
    print('Command: %s' % cmd)
    Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)


class VisdomLogger(dict):
    """Plot losses."""

    def __init__(self, port, *args, **kwargs):
        self.visdom_port = port
        self.visdom = visdom.Visdom(port=port)
        if not self.visdom.check_connection():
            create_visdom_connections(port)
        super(VisdomLogger, self).__init__(*args, **kwargs)
        self.registered = False
        self.plot_attributes = {}

    def register_keys(self, keys):
        for key in keys:
            self[key] = []
        for i, key in enumerate(self.keys()):
            self.plot_attributes[key] = {'win_id': i}
        self.registered = True

    def update(self, new_records):
        """Add new updates to records.

        :param new_records: dict of new updates. example: {'loss': [10., 5., 2.], 'lr': [1e-3, 1e-3, 5e-4]}
        """
        for key, val in new_records.items():
            if key in self.keys():
                if isinstance(val, list):
                    self[key].extend(val)
                else:
                    self[key].append(val)

    def set(self, record):
        """Set a record (replaces old data)"""
        for key, val in record.items():
            if key in self.keys():
                self[key] = val

    def do_plotting(self):
        for k in self.keys():
            y_values = np.array(self[k])
            x_values = np.arange(len(self[k]))
            if len(x_values) < 1:
                continue
            # if y_values.size > 1:
            #     x_values = np.reshape(x_values, (len(x_values), 1))
            if len(y_values.shape) > 1 and y_values.shape[1] == 1 and y_values.size != 1:
                y_values = y_values.squeeze()
            self.visdom.line(Y=y_values, X=x_values, win=self.plot_attributes[k]['win_id'],
                             opts={'title': k.upper()}, update='append')
            # self[k] = []
--------------------------------------------------------------------------------
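
For reference, a minimal sketch of how `VisdomLogger` can be driven from a training loop (the port, keys, and values are made up for illustration; 8097 is simply the Visdom default):

```python
from utils.visdom_plots import VisdomLogger

# Connect to (or spawn) a Visdom server and register the quantities to plot.
visdom_logger = VisdomLogger(port=8097)
visdom_logger.register_keys(["loss", "reward"])

# Inside the training loop: accumulate new records, then push them to Visdom.
visdom_logger.update({"loss": [10.0, 5.0, 2.0], "reward": 1.5})
visdom_logger.do_plotting()
```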