├── .dockerignore ├── .gitignore ├── LICENSE ├── README.md ├── azure ├── Dockerfile ├── config.py ├── download.sh ├── files │ ├── 10_nvidia.json │ └── Xdummy ├── launch_plan.py ├── launch_train.py ├── make_fuse_config.sh ├── mount.sh └── sync.sh ├── config └── offline.py ├── environment.yml ├── plotting ├── bar.png ├── plot.py ├── read_results.py ├── scores.py └── table.py ├── pretrained.sh ├── scripts ├── plan.py └── train.py ├── setup.py └── trajectory ├── __init__.py ├── datasets ├── __init__.py ├── d4rl.py ├── preprocessing.py └── sequence.py ├── models ├── __init__.py ├── ein.py ├── embeddings.py ├── mlp.py └── transformers.py ├── search ├── __init__.py ├── core.py ├── sampling.py └── utils.py └── utils ├── __init__.py ├── arrays.py ├── config.py ├── discretization.py ├── git_utils.py ├── progress.py ├── rendering.py ├── serialization.py ├── setup.py ├── timer.py ├── training.py └── video.py /.dockerignore: -------------------------------------------------------------------------------- 1 | *.ipynb_checkpoints 2 | *__pycache__* 3 | *.egg-info 4 | logs/* 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | __pycache__/ 3 | *.egg-info 4 | .DS_Store 5 | bin/ 6 | 7 | logs/* 8 | data/* 9 | slurm/config.sh 10 | slurm/*sandbox* 11 | *.img 12 | slurm-*.out 13 | *.png 14 | *.pdf 15 | *.pkl 16 | mjkey.txt 17 | azure/fuse.cfg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Trajectory Transformer authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to 
use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Trajectory Transformer 2 | 3 | Code release for [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039). 4 | 5 | **New:** Also see [Alexander Nikulin's fork](https://github.com/Howuhh/faster-trajectory-transformer) with attention caching and vectorized rollouts! 6 | 7 | ## Installation 8 | 9 | All python dependencies are in [`environment.yml`](environment.yml). Install with: 10 | 11 | ``` 12 | conda env create -f environment.yml 13 | conda activate trajectory 14 | pip install -e . 15 | ``` 16 | 17 | For reproducibility, we have also included system requirements in a [`Dockerfile`](azure/Dockerfile) (see [installation instructions](#Docker)), but the conda installation should work on most standard Linux machines. 
18 | 19 | ## Usage 20 | 21 | Train a transformer with: `python scripts/train.py --dataset halfcheetah-medium-v2` 22 | 23 | To reproduce the offline RL results: `python scripts/plan.py --dataset halfcheetah-medium-v2` 24 | 25 | By default, these commands will use the hyperparameters in [`config/offline.py`](config/offline.py). You can override them with runtime flags: 26 | ``` 27 | python scripts/plan.py --dataset halfcheetah-medium-v2 \ 28 | --horizon 5 --beam_width 32 29 | ``` 30 | 31 | A few hyperparameters are different from those listed in the paper because of changes to the discretization strategy. These hyperparameters will be updated in the next arxiv version to match what is currently in the codebase. 32 | 33 | ## Pretrained models 34 | 35 | We have provided [pretrained models](https://www.dropbox.com/sh/r09lkdoj66kx43w/AACbXjMhcI6YNsn1qU4LParja?dl=0) for 16 datasets: `{halfcheetah, hopper, walker2d, ant}-{expert-v2, medium-expert-v2, medium-v2, medium-replay-v2}`. Download them with `./pretrained.sh` 36 | 37 | The models will be saved in `logs/$DATASET/gpt/pretrained`. To plan with these models, refer to them using the `gpt_loadpath` flag: 38 | ``` 39 | python scripts/plan.py --dataset halfcheetah-medium-v2 \ 40 | --gpt_loadpath gpt/pretrained 41 | ``` 42 | 43 | `pretrained.sh` will also download 15 [plans](https://www.dropbox.com/sh/po0nul2u6qk8r2i/AABPDrOEJplQ8JT13DASdOWWa?dl=0) from each model, saved to `logs/$DATASET/plans/pretrained`. Read them with ` 44 | python plotting/read_results.py`. 45 | 46 |
47 | To create the table of offline RL results from the paper, run python plotting/table.py. This will print a table that can be copied into a Latex document. (Expand to view table source.) 48 | 49 | ``` 50 | \begin{table*}[h] 51 | \centering 52 | \small 53 | \begin{tabular}{llrrrrrr} 54 | \toprule 55 | \multicolumn{1}{c}{\bf Dataset} & \multicolumn{1}{c}{\bf Environment} & \multicolumn{1}{c}{\bf BC} & \multicolumn{1}{c}{\bf MBOP} & \multicolumn{1}{c}{\bf BRAC} & \multicolumn{1}{c}{\bf CQL} & \multicolumn{1}{c}{\bf DT} & \multicolumn{1}{c}{\bf TT (Ours)} \\ 56 | \midrule 57 | Medium-Expert & HalfCheetah & $59.9$ & $105.9$ & $41.9$ & $91.6$ & $86.8$ & $95.0$ \scriptsize{\raisebox{1pt}{$\pm 0.2$}} \\ 58 | Medium-Expert & Hopper & $79.6$ & $55.1$ & $0.9$ & $105.4$ & $107.6$ & $110.0$ \scriptsize{\raisebox{1pt}{$\pm 2.7$}} \\ 59 | Medium-Expert & Walker2d & $36.6$ & $70.2$ & $81.6$ & $108.8$ & $108.1$ & $101.9$ \scriptsize{\raisebox{1pt}{$\pm 6.8$}} \\ 60 | Medium-Expert & Ant & $-$ & $-$ & $-$ & $-$ & $-$ & $116.1$ \scriptsize{\raisebox{1pt}{$\pm 9.0$}} \\ 61 | \midrule 62 | Medium & HalfCheetah & $43.1$ & $44.6$ & $46.3$ & $44.0$ & $42.6$ & $46.9$ \scriptsize{\raisebox{1pt}{$\pm 0.4$}} \\ 63 | Medium & Hopper & $63.9$ & $48.8$ & $31.3$ & $58.5$ & $67.6$ & $61.1$ \scriptsize{\raisebox{1pt}{$\pm 3.6$}} \\ 64 | Medium & Walker2d & $77.3$ & $41.0$ & $81.1$ & $72.5$ & $74.0$ & $79.0$ \scriptsize{\raisebox{1pt}{$\pm 2.8$}} \\ 65 | Medium & Ant & $-$ & $-$ & $-$ & $-$ & $-$ & $83.1$ \scriptsize{\raisebox{1pt}{$\pm 7.3$}} \\ 66 | \midrule 67 | Medium-Replay & HalfCheetah & $4.3$ & $42.3$ & $47.7$ & $45.5$ & $36.6$ & $41.9$ \scriptsize{\raisebox{1pt}{$\pm 2.5$}} \\ 68 | Medium-Replay & Hopper & $27.6$ & $12.4$ & $0.6$ & $95.0$ & $82.7$ & $91.5$ \scriptsize{\raisebox{1pt}{$\pm 3.6$}} \\ 69 | Medium-Replay & Walker2d & $36.9$ & $9.7$ & $0.9$ & $77.2$ & $66.6$ & $82.6$ \scriptsize{\raisebox{1pt}{$\pm 6.9$}} \\ 70 | Medium-Replay & Ant & $-$ & $-$ & $-$ & $-$ & $-$ & $77.0$ 
\scriptsize{\raisebox{1pt}{$\pm 6.8$}} \\ 71 | \midrule 72 | \multicolumn{2}{c}{\bf Average (without Ant)} & 47.7 & 47.8 & 36.9 & 77.6 & 74.7 & 78.9 \hspace{.6cm} \\ 73 | \multicolumn{2}{c}{\bf Average (all settings)} & $-$ & $-$ & $-$ & $-$ & $-$ & 82.2 \hspace{.6cm} \\ 74 | \bottomrule 75 | \end{tabular} 76 | \label{table:d4rl} 77 | \end{table*} 78 | ``` 79 | 80 | ![](https://github.com/anonymized-transformer/anonymized-transformer.github.io/blob/master/plots/table.png) 81 |
82 | 83 |
84 | 85 | To create the average performance plot, run python plotting/plot.py. 86 | 87 | (Expand to view plot.) 88 |
89 | 90 | ![](plotting/bar.png) 91 |
92 | 93 | ## Docker 94 | 95 | Copy your MuJoCo key to the Docker build context and build the container: 96 | ``` 97 | cp ~/.mujoco/mjkey.txt azure/files/ 98 | docker build -f azure/Dockerfile . -t trajectory 99 | ``` 100 | 101 | Test the container: 102 | ``` 103 | docker run -it --rm --gpus all \ 104 | --mount type=bind,source=$PWD,target=/home/code \ 105 | --mount type=bind,source=$HOME/.d4rl,target=/root/.d4rl \ 106 | trajectory \ 107 | bash -c \ 108 | "export PYTHONPATH=$PYTHONPATH:/home/code && \ 109 | python /home/code/scripts/train.py --dataset hopper-medium-expert-v2 --exp_name docker/" 110 | ``` 111 | 112 | ## Running on Azure 113 | 114 | #### Setup 115 | 116 | 1. Launching jobs on Azure requires one more python dependency: 117 | ``` 118 | pip install git+https://github.com/JannerM/doodad.git@janner 119 | ``` 120 | 121 | 2. Tag the image built in [the previous section](#Docker) and push it to Docker Hub: 122 | ``` 123 | export DOCKER_USERNAME=$(docker info | sed '/Username:/!d;s/.* //') 124 | docker tag trajectory ${DOCKER_USERNAME}/trajectory:latest 125 | docker image push ${DOCKER_USERNAME}/trajectory 126 | ``` 127 | 128 | 3. Update [`azure/config.py`](azure/config.py), either by modifying the file directly or setting the relevant [environment variables](azure/config.py#L47-L52). To set the `AZURE_STORAGE_CONNECTION` variable, navigate to the `Access keys` section of your storage account. Click `Show keys` and copy the `Connection string`. 129 | 130 | 4. Download [`azcopy`](https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10): `./azure/download.sh` 131 | 132 | #### Usage 133 | 134 | Launch training jobs with `python azure/launch_train.py` and planning jobs with `python azure/launch_plan.py`. 135 | 136 | These scripts do not take runtime arguments. 
Instead, they run the corresponding scripts ([`scripts/train.py`](scripts/train.py) and [`scripts/plan.py`](scripts/plan.py), respectively) using the Cartesian product of the parameters in [`params_to_sweep`](azure/launch_train.py#L36-L38). 137 | 138 | #### Viewing results 139 | 140 | To rsync the results from the Azure storage container, run `./azure/sync.sh`. 141 | 142 | To mount the storage container: 143 | 1. Create a blobfuse config with `./azure/make_fuse_config.sh` 144 | 2. Run `./azure/mount.sh` to mount the storage container to `~/azure_mount` 145 | 146 | To unmount the container, run `sudo umount -f ~/azure_mount; rm -r ~/azure_mount` 147 | 148 | ## Reference 149 | ``` 150 | @inproceedings{janner2021sequence, 151 | title = {Offline Reinforcement Learning as One Big Sequence Modeling Problem}, 152 | author = {Michael Janner and Qiyang Li and Sergey Levine}, 153 | booktitle = {Advances in Neural Information Processing Systems}, 154 | year = {2021}, 155 | } 156 | ``` 157 | 158 | ## Acknowledgements 159 | 160 | The GPT implementation is from Andrej Karpathy's [minGPT](https://github.com/karpathy/minGPT) repo. 161 | -------------------------------------------------------------------------------- /azure/Dockerfile: -------------------------------------------------------------------------------- 1 | # We need the CUDA base dockerfile to enable GPU rendering 2 | # on hosts with GPUs. 3 | # The image below is a pinned version of nvidia/cuda:9.1-cudnn7-devel-ubuntu16.04 (from Jan 2018) 4 | # If updating the base image, be sure to test on GPU since it has broken in the past. 
5 | FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu16.04 6 | 7 | SHELL ["/bin/bash", "-c"] 8 | 9 | ########################################################## 10 | ### System dependencies 11 | ########################################################## 12 | 13 | RUN apt-get update -q \ 14 | && DEBIAN_FRONTEND=noninteractive apt-get install -y \ 15 | cmake \ 16 | curl \ 17 | git \ 18 | libav-tools \ 19 | libgl1-mesa-dev \ 20 | libgl1-mesa-glx \ 21 | libglew-dev \ 22 | libosmesa6-dev \ 23 | net-tools \ 24 | software-properties-common \ 25 | swig \ 26 | unzip \ 27 | vim \ 28 | wget \ 29 | xpra \ 30 | xserver-xorg-dev \ 31 | zlib1g-dev \ 32 | && apt-get clean \ 33 | && rm -rf /var/lib/apt/lists/* 34 | 35 | ENV LANG C.UTF-8 36 | 37 | COPY ./azure/files/Xdummy /usr/local/bin/Xdummy 38 | RUN chmod +x /usr/local/bin/Xdummy 39 | 40 | # Workaround for https://bugs.launchpad.net/ubuntu/+source/nvidia-graphics-drivers-375/+bug/1674677 41 | COPY ./azure/files/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json 42 | COPY ./environment.yml /opt/environment.yml 43 | 44 | ENV LD_LIBRARY_PATH /usr/local/nvidia/lib64:${LD_LIBRARY_PATH} 45 | 46 | ########################################################## 47 | ### MuJoCo 48 | ########################################################## 49 | # Note: ~ is an alias for /root 50 | RUN mkdir -p /root/.mujoco \ 51 | && wget https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip \ 52 | && unzip mujoco.zip -d /root/.mujoco \ 53 | && rm mujoco.zip 54 | RUN mkdir -p /root/.mujoco \ 55 | && wget https://www.roboti.us/download/mjpro150_linux.zip -O mujoco.zip \ 56 | && unzip mujoco.zip -d /root/.mujoco \ 57 | && rm mujoco.zip 58 | RUN ln -s /root/.mujoco/mujoco200_linux /root/.mujoco/mujoco200 59 | ENV LD_LIBRARY_PATH /root/.mujoco/mjpro150/bin:${LD_LIBRARY_PATH} 60 | ENV LD_LIBRARY_PATH /root/.mujoco/mujoco200/bin:${LD_LIBRARY_PATH} 61 | ENV LD_LIBRARY_PATH /root/.mujoco/mujoco200_linux/bin:${LD_LIBRARY_PATH} 62 | COPY 
def get_docker_username():
    """Return the Docker Hub username reported by ``docker info``.

    Pipes ``docker info`` into ``sed '/Username:/!d;s/.* //'``, mirroring the
    shell one-liner documented in the README, and returns the extracted
    username with newlines stripped (empty string when no user is logged in,
    i.e. no ``Username:`` line appears in the output).

    Raises:
        subprocess.CalledProcessError: if the ``sed`` stage exits non-zero.
    """
    import subprocess
    import shlex
    # Keep the sed pipeline (rather than parsing in Python) so behavior stays
    # byte-for-byte consistent with the documented shell command.
    ps = subprocess.Popen(shlex.split('docker info'), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    try:
        output = subprocess.check_output(shlex.split("sed '/Username:/!d;s/.* //'"), stdin=ps.stdout)
    finally:
        # Close our end of the pipe and reap the child so we do not leak a
        # file descriptor or leave a zombie `docker info` process behind.
        ps.stdout.close()
        ps.wait()
    username = output.decode('utf-8').replace('\n', '')
    print(f'[ azure/config ] Grabbed username from `docker info`: {username}')
    return username
/path/to/trajectory-transformer/azure 13 | CWD = os.path.dirname(__file__) 14 | ## /path/to/trajectory-transformer 15 | MODULE_PATH = os.path.dirname(CWD) 16 | 17 | CODE_DIRS_TO_MOUNT = [ 18 | ] 19 | NON_CODE_DIRS_TO_MOUNT = [ 20 | dict( 21 | local_dir=MODULE_PATH, 22 | mount_point='/home/code', 23 | ), 24 | ] 25 | REMOTE_DIRS_TO_MOUNT = [ 26 | dict( 27 | local_dir='/doodad_tmp/', 28 | mount_point='/doodad_tmp/', 29 | ), 30 | ] 31 | LOCAL_LOG_DIR = '/tmp' 32 | 33 | DEFAULT_AZURE_GPU_MODEL = 'nvidia-tesla-t4' 34 | DEFAULT_AZURE_INSTANCE_TYPE = 'Standard_DS1_v2' 35 | DEFAULT_AZURE_REGION = 'eastus' 36 | DEFAULT_AZURE_RESOURCE_GROUP = 'traj' 37 | DEFAULT_AZURE_VM_NAME = 'traj-vm' 38 | DEFAULT_AZURE_VM_PASSWORD = 'Azure1' 39 | 40 | DOCKER_USERNAME = os.environ.get('DOCKER_USERNAME', get_docker_username()) 41 | DEFAULT_DOCKER = f'docker.io/{DOCKER_USERNAME}/trajectory:latest' 42 | 43 | print(f'[ azure/config ] Local dir: {MODULE_PATH}') 44 | print(f'[ azure/config ] Default GPU model: {DEFAULT_AZURE_GPU_MODEL}') 45 | print(f'[ azure/config ] Default Docker image: {DEFAULT_DOCKER}') 46 | 47 | AZ_SUB_ID = os.environ['AZURE_SUBSCRIPTION_ID'] 48 | AZ_CLIENT_ID = os.environ['AZURE_CLIENT_ID'] 49 | AZ_TENANT_ID = os.environ['AZURE_TENANT_ID'] 50 | AZ_SECRET = os.environ['AZURE_CLIENT_SECRET'] 51 | AZ_CONTAINER = os.environ['AZURE_STORAGE_CONTAINER'] 52 | AZ_CONN_STR = os.environ['AZURE_STORAGE_CONNECTION_STRING'] 53 | -------------------------------------------------------------------------------- /azure/download.sh: -------------------------------------------------------------------------------- 1 | DOWNLOAD_DIR=bin 2 | mkdir $DOWNLOAD_DIR 3 | wget https://aka.ms/downloadazcopy-v10-linux -O $DOWNLOAD_DIR/download.tar.gz 4 | tar -xvf $DOWNLOAD_DIR/download.tar.gz --one-top-level=$DOWNLOAD_DIR 5 | mv $DOWNLOAD_DIR/*/azcopy $DOWNLOAD_DIR 6 | rm $DOWNLOAD_DIR/download.tar.gz 7 | rm -r $DOWNLOAD_DIR/azcopy_linux* 
-------------------------------------------------------------------------------- /azure/files/10_nvidia.json: -------------------------------------------------------------------------------- 1 | { 2 | "file_format_version" : "1.0.0", 3 | "ICD" : { 4 | "library_path" : "libEGL_nvidia.so.0" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /azure/files/Xdummy: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # ---------------------------------------------------------------------- 3 | # Copyright (C) 2005-2011 Karl J. Runge 4 | # All rights reserved. 5 | # 6 | # This file is part of Xdummy. 7 | # 8 | # Xdummy is free software; you can redistribute it and/or modify 9 | # it under the terms of the GNU General Public License as published by 10 | # the Free Software Foundation; either version 2 of the License, or (at 11 | # your option) any later version. 12 | # 13 | # Xdummy is distributed in the hope that it will be useful, 14 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 15 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 | # GNU General Public License for more details. 17 | # 18 | # You should have received a copy of the GNU General Public License 19 | # along with Xdummy; if not, write to the Free Software 20 | # Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA 21 | # or see . 22 | # ---------------------------------------------------------------------- 23 | # 24 | # 25 | # Xdummy: an LD_PRELOAD hack to run a stock Xorg(1) or XFree86(1) server 26 | # with the "dummy" video driver to make it avoid Linux VT switching, etc. 27 | # 28 | # Run "Xdummy -help" for more info. 
29 | # 30 | install="" 31 | uninstall="" 32 | runit=1 33 | prconf="" 34 | notweak="" 35 | root="" 36 | nosudo="" 37 | xserver="" 38 | geom="" 39 | nomodelines="" 40 | depth="" 41 | debug="" 42 | strace="" 43 | cmdline_config="" 44 | 45 | PATH=$PATH:/bin:/usr/bin 46 | export PATH 47 | 48 | program=`basename "$0"` 49 | 50 | help () { 51 | ${PAGER:-more} << END 52 | $program: 53 | 54 | A hack to run a stock Xorg(1) or XFree86(1) X server with the "dummy" 55 | (RAM-only framebuffer) video driver such that it AVOIDS the Linux VT 56 | switching, opening device files in /dev, keyboard and mouse conflicts, 57 | and other problems associated with the normal use of "dummy". 58 | 59 | In other words, it tries to make Xorg/XFree86 with the "dummy" 60 | device driver act more like Xvfb(1). 61 | 62 | The primary motivation for the Xdummy script is to provide a virtual X 63 | server for x11vnc but with more features than Xvfb (or Xvnc); however 64 | it could be used for other reasons (e.g. better automated testing 65 | than with Xvfb.) One nice thing is the dummy server supports RANDR 66 | dynamic resizing while Xvfb does not. 67 | 68 | So, for example, x11vnc+Xdummy terminal services are a little better 69 | than x11vnc+Xvfb. 70 | 71 | To achieve this, while running the real Xserver $program intercepts 72 | system and library calls via the LD_PRELOAD method and modifies 73 | the behavior to make it work correctly (e.g. avoid the VT stuff.) 74 | LD_PRELOAD tricks are usually "clever hacks" and so might not work 75 | in all situations or break when something changes. 76 | 77 | WARNING: Take care in using Xdummy, although it never has it is 78 | possible that it could damage hardware. One can use the -prconf 79 | option to have it print out the xorg.conf config that it would use 80 | and then inspect it carefully before actually using it. 81 | 82 | This program no longer needs to be run as root as of 12/2009. 
83 | However, if there are problems for certain situations (usually older 84 | servers) it may perform better if run as root (use the -root option.) 85 | When running as root remember the previous paragraph and that Xdummy 86 | comes without any warranty. 87 | 88 | gcc/cc and other build tools are required for this script to be able 89 | to compile the LD_PRELOAD shared object. Be sure they are installed 90 | on the system. See -install and -uninstall described below. 91 | 92 | Your Linux distribution may not install the dummy driver by default, 93 | e.g: 94 | 95 | /usr/lib/xorg/modules/drivers/dummy_drv.so 96 | 97 | some have it in a package named xserver-xorg-video-dummy you that 98 | need to install. 99 | 100 | Usage: 101 | 102 | $program <${program}-args> 103 | 104 | (actually, the arguments can be supplied in any order.) 105 | 106 | Examples: 107 | 108 | $program -install 109 | 110 | $program :1 111 | 112 | $program -debug :1 113 | 114 | $program -tmpdir ~/mytmp :1 -nolisten tcp 115 | 116 | startx example: 117 | 118 | startx -e bash -- $program :2 -depth 16 119 | 120 | (if startx needs to be run as root, you can su(1) to a normal 121 | user in the bash shell and then launch ~/.xinitrc or ~/.xsession, 122 | gnome-session, startkde, startxfce4, etc.) 123 | 124 | xdm example: 125 | 126 | xdm -config /usr/local/dummy/xdm-config -nodaemon 127 | 128 | where the xdm-config file has line: 129 | 130 | DisplayManager.servers: /usr/local/dummy/Xservers 131 | 132 | and /usr/local/dummy/Xservers has lines: 133 | 134 | :1 local /usr/local/dummy/Xdummy :1 -debug 135 | :2 local /usr/local/dummy/Xdummy :2 -debug 136 | 137 | (-debug is optional) 138 | 139 | gdm/kdm example: 140 | 141 | TBD. 142 | 143 | Config file: 144 | 145 | If the file $program.cfg exists it will be sourced as shell 146 | commands. Usually one will set some variables this way. 147 | To disable sourcing, supply -nocfg or set XDUMMY_NOCFG=1. 
148 | 149 | Root permission and x11vnc: 150 | 151 | Update: as of 12/2009 this program no longer must be run as root. 152 | So try it as non-root before running it as root and/or the 153 | following schemes. 154 | 155 | In some circumstances X server program may need to be run as root. 156 | If so, one could run x11vnc as root with -unixpw (it switches 157 | to the user that logs in) and that may be OK, some other ideas: 158 | 159 | - add this to sudo via visudo: 160 | 161 | ALL ALL = NOPASSWD: /usr/local/bin/Xdummy 162 | 163 | - use this little suid wrapper: 164 | /* 165 | * xdummy.c 166 | * 167 | cc -o ./xdummy xdummy.c 168 | sudo cp ./xdummy /usr/local/bin/xdummy 169 | sudo chown root:root /usr/local/bin/xdummy 170 | sudo chmod u+s /usr/local/bin/xdummy 171 | * 172 | */ 173 | #include 174 | #include 175 | #include 176 | #include 177 | 178 | int main (int argc, char *argv[]) { 179 | extern char **environ; 180 | char str[100]; 181 | sprintf(str, "XDUMMY_UID=%d", (int) getuid()); 182 | putenv(str); 183 | setuid(0); 184 | setgid(0); 185 | execv("/usr/local/bin/Xdummy", argv); 186 | exit(1); 187 | return 1; 188 | } 189 | 190 | 191 | Options: 192 | 193 | ${program}-args: 194 | 195 | -install Compile the LD_PRELOAD shared object and install it 196 | next to the $program script file as: 197 | 198 | $0.so 199 | 200 | When that file exists it is used as the LD_PRELOAD 201 | shared object without recompiling. Otherwise, 202 | each time $program is run the LD_PRELOAD shared 203 | object is compiled as a file in /tmp (or -tmpdir) 204 | 205 | If you set the environment variable 206 | INTERPOSE_GETUID=1 when building, then when 207 | $program is run as an ordinary user, the shared 208 | object will interpose getuid() calls and pretend 209 | to be root. Otherwise it doesn't pretend to 210 | be root. 211 | 212 | You can also set the CFLAGS environment variable 213 | to anything else you want on the compile cmdline. 
214 | 215 | -uninstall Remove the file: 216 | 217 | $0.so 218 | 219 | The LD_PRELOAD shared object will then be compiled 220 | each time this program is run. 221 | 222 | The X server is not started under -install, -uninstall, or -prconf. 223 | 224 | 225 | :N The DISPLAY (e.g. :15) is often the first 226 | argument. It is passed to the real X server and 227 | also used by the Xdummy script as an identifier. 228 | 229 | -geom geom1[,geom2...] Take the geometry (e.g. 1024x768) or list 230 | of geometries and insert them into the Screen 231 | section of the tweaked X server config file. 232 | Use this to have a different geometry than the 233 | one(s) in the system config file. 234 | 235 | The option -geometry can be used instead of -geom; 236 | x11vnc calls Xdummy and Xvfb this way. 237 | 238 | -nomodelines When you specify -geom/-geometry, $program will 239 | create Modelines for each geometry and put them 240 | in the Monitor section. If you do not want this 241 | then supply -nomodelines. 242 | 243 | -depth n Use pixel color depth n (e.g. 8, 16, or 24). This 244 | makes sure the X config file has a Screen.Display 245 | subsection of this depth. Note this option is 246 | ALSO passed to the X server. 247 | 248 | -DEPTH n Same as -depth, except not passed to X server. 249 | 250 | -tmpdir dir Specify a temporary directory, owned by you and 251 | only writable by you. This is used in place of 252 | /tmp/Xdummy.\$USER/.. to place the $program.so 253 | shared object, tweaked config files, etc. 254 | 255 | -nonroot Run in non-root mode (working 12/2009, now default) 256 | 257 | -root Run as root (may still be needed in some 258 | environments.) Same as XDUMMY_RUN_AS_ROOT=1. 259 | 260 | -nosudo Do not try to use sudo(1) when re-running as root, 261 | use su(1) instead. 262 | 263 | -xserver path Specify the path to the Xserver to use. Default 264 | is to try "Xorg" first and then "XFree86". 
If 265 | those are not in \$PATH, it tries these locations: 266 | /usr/bin/Xorg 267 | /usr/X11R6/bin/Xorg 268 | /usr/X11R6/bin/XFree86 269 | 270 | -n Do not run the command to start the X server, 271 | just show the command that $program would run. 272 | The LD_PRELOAD shared object will be built, 273 | if needed. Also note any XDUMMY* environment 274 | variables that need to be set. 275 | 276 | -prconf Print, to stdout, the tweaked Xorg/XFree86 277 | config file (-config and -xf86config server 278 | options, respectively.) The Xserver is not 279 | started. 280 | 281 | -notweak Do not tweak (modify) the Xorg/XFree86 config file 282 | (system or server command line) at all. The -geom 283 | and similar config file modifications are ignored. 284 | 285 | It is up to you to make sure it is a working 286 | config file (e.g. "dummy" driver, etc.) 287 | Perhaps you want to use a file based on the 288 | -prconf output. 289 | 290 | -nocfg Do not try to source $program.cfg even if it 291 | exists. Same as setting XDUMMY_NOCFG=1. 292 | 293 | -debug Extra debugging output. 294 | 295 | -strace strace(1) the Xserver process (for troubleshooting.) 296 | -ltrace ltrace(1) instead of strace (can be slow.) 297 | 298 | -h, -help Print out this help. 299 | 300 | 301 | Xserver-args: 302 | 303 | Most of the Xorg and XFree86 options will work and are simply 304 | passed along if you supply them. Important ones that may be 305 | supplied if missing: 306 | 307 | :N X Display number for server to use. 308 | 309 | vtNN Linux virtual terminal (VT) to use (a VT is currently 310 | still used, just not switched to and from.) 311 | 312 | -config file Driver "dummy" tweaked config file, a 313 | -xf86config file number of settings are tweaked besides Driver. 314 | 315 | If -config/-xf86config is not given, the system one 316 | (e.g. /etc/X11/xorg.conf) is used. If the system one cannot be 317 | found, a built-in one is used. 
Any settings in the config file 318 | that are not consistent with "dummy" mode will be overwritten 319 | (unless -notweak is specified.) 320 | 321 | Use -config xdummy-builtin to force usage of the builtin config. 322 | 323 | If "file" is only a basename (e.g. "xorg.dummy.conf") with no /'s, 324 | then no tweaking of it is done: the X server will look for that 325 | basename via its normal search algorithm. If the found file does 326 | not refer to the "dummy" driver, etc, then the X server will fail. 327 | 328 | You can set the env. var. XDUMMY_EXTRA_SERVER_ARGS to hold some 329 | extra Xserver-args too. (Useful for cfg file.) 330 | 331 | Notes: 332 | 333 | The Xorg/XFree86 "dummy" driver is currently undocumented. It works 334 | well in this mode, but it is evidently not intended for end-users. 335 | So it could be removed or broken at any time. 336 | 337 | If the display Xserver-arg (e.g. :1) is not given, or ":" is given 338 | that indicates $program should try to find a free one (based on 339 | tcp ports.) 340 | 341 | If the display virtual terminal, VT, (e.g. vt9) is not given that 342 | indicates $program should try to find a free one (or guess a high one.) 343 | 344 | This program is not completely secure WRT files in /tmp (but it tries 345 | to a good degree.) Better is to use the -tmpdir option to supply a 346 | directory only writable by you. Even better is to get rid of users 347 | on the local machine you do not trust :-) 348 | 349 | Set XDUMMY_SET_XV=1 to turn on debugging output for this script. 
350 | 351 | END 352 | } 353 | 354 | warn() { 355 | echo "$*" 1>&2 356 | } 357 | 358 | if [ "X$XDUMMY_SET_XV" != "X" ]; then 359 | set -xv 360 | fi 361 | 362 | if [ "X$XDUMMY_UID" = "X" ]; then 363 | XDUMMY_UID=`id -u` 364 | export XDUMMY_UID 365 | fi 366 | if [ "X$XDUMMY_UID" = "X0" ]; then 367 | if [ "X$SUDO_UID" != "X" ]; then 368 | XDUMMY_UID=$SUDO_UID 369 | export XDUMMY_UID 370 | fi 371 | fi 372 | 373 | # check if root=1 first: 374 | # 375 | if [ "X$XDUMMY_RUN_AS_ROOT" = "X1" ]; then 376 | root=1 377 | fi 378 | for arg in $* 379 | do 380 | if [ "X$arg" = "X-nonroot" ]; then 381 | root="" 382 | elif [ "X$arg" = "X-root" ]; then 383 | root=1 384 | elif [ "X$arg" = "X-nocfg" ]; then 385 | XDUMMY_NOCFG=1 386 | export XDUMMY_NOCFG 387 | fi 388 | done 389 | 390 | if [ "X$XDUMMY_NOCFG" = "X" -a -f "$0.cfg" ]; then 391 | . "$0.cfg" 392 | fi 393 | 394 | # See if it really needs to be run as root: 395 | # 396 | if [ "X$XDUMMY_SU_EXEC" = "X" -a "X$root" = "X1" -a "X`id -u`" != "X0" ]; then 397 | # this is to prevent infinite loop in case su/sudo doesn't work: 398 | XDUMMY_SU_EXEC=1 399 | export XDUMMY_SU_EXEC 400 | 401 | dosu=1 402 | nosudo="" 403 | 404 | for arg in $* 405 | do 406 | if [ "X$arg" = "X-nonroot" ]; then 407 | dosu="" 408 | elif [ "X$arg" = "X-nosudo" ]; then 409 | nosudo="1" 410 | elif [ "X$arg" = "X-help" ]; then 411 | dosu="" 412 | elif [ "X$arg" = "X-h" ]; then 413 | dosu="" 414 | elif [ "X$arg" = "X-install" ]; then 415 | dosu="" 416 | elif [ "X$arg" = "X-uninstall" ]; then 417 | dosu="" 418 | elif [ "X$arg" = "X-n" ]; then 419 | dosu="" 420 | elif [ "X$arg" = "X-prconf" ]; then 421 | dosu="" 422 | fi 423 | done 424 | if [ $dosu ]; then 425 | # we need to restart it with su/sudo: 426 | if type sudo > /dev/null 2>&1; then 427 | : 428 | else 429 | nosudo=1 430 | fi 431 | if [ "X$nosudo" = "X" ]; then 432 | warn "$program: supply the sudo password to restart as root:" 433 | if [ "X$XDUMMY_UID" != "X" ]; then 434 | exec sudo $0 -uid $XDUMMY_UID "$@" 435 | 
else 436 | exec sudo $0 "$@" 437 | fi 438 | else 439 | warn "$program: supply the root password to restart as root:" 440 | if [ "X$XDUMMY_UID" != "X" ]; then 441 | exec su -c "$0 -uid $XDUMMY_UID $*" 442 | else 443 | exec su -c "$0 $*" 444 | fi 445 | fi 446 | # DONE: 447 | exit 448 | fi 449 | fi 450 | 451 | # This will hold the X display, e.g. :20 452 | # 453 | disp="" 454 | args="" 455 | cmdline_config="" 456 | 457 | # Process Xdummy args: 458 | # 459 | while [ "X$1" != "X" ] 460 | do 461 | if [ "X$1" = "X-config" -o "X$1" = "X-xf86config" ]; then 462 | cmdline_config="$2" 463 | fi 464 | case $1 in 465 | ":"*) disp=$1 466 | ;; 467 | "-install") install=1; runit="" 468 | ;; 469 | "-uninstall") uninstall=1; runit="" 470 | ;; 471 | "-n") runit="" 472 | ;; 473 | "-no") runit="" 474 | ;; 475 | "-norun") runit="" 476 | ;; 477 | "-prconf") prconf=1; runit="" 478 | ;; 479 | "-notweak") notweak=1 480 | ;; 481 | "-noconf") notweak=1 482 | ;; 483 | "-nonroot") root="" 484 | ;; 485 | "-root") root=1 486 | ;; 487 | "-nosudo") nosudo=1 488 | ;; 489 | "-xserver") xserver="$2"; shift 490 | ;; 491 | "-uid") XDUMMY_UID="$2"; shift 492 | export XDUMMY_UID 493 | ;; 494 | "-geom") geom="$2"; shift 495 | ;; 496 | "-geometry") geom="$2"; shift 497 | ;; 498 | "-nomodelines") nomodelines=1 499 | ;; 500 | "-depth") depth="$2"; args="$args -depth $2"; 501 | shift 502 | ;; 503 | "-DEPTH") depth="$2"; shift 504 | ;; 505 | "-tmpdir") XDUMMY_TMPDIR="$2"; shift 506 | ;; 507 | "-debug") debug=1 508 | ;; 509 | "-nocfg") : 510 | ;; 511 | "-nodebug") debug="" 512 | ;; 513 | "-strace") strace=1 514 | ;; 515 | "-ltrace") strace=2 516 | ;; 517 | "-h") help; exit 0 518 | ;; 519 | "-help") help; exit 0 520 | ;; 521 | *) args="$args $1" 522 | ;; 523 | esac 524 | shift 525 | done 526 | 527 | if [ "X$XDUMMY_EXTRA_SERVER_ARGS" != "X" ]; then 528 | args="$args $XDUMMY_EXTRA_SERVER_ARGS" 529 | fi 530 | 531 | # Try to get a username for use in our tmp directory, etc. 
532 | # 533 | user="" 534 | if [ X`id -u` = "X0" ]; then 535 | user=root # this will also be used below for id=0 536 | elif [ "X$USER" != "X" ]; then 537 | user=$USER 538 | elif [ "X$LOGNAME" != "X" ]; then 539 | user=$LOGNAME 540 | fi 541 | 542 | # Keep trying... 543 | # 544 | if [ "X$user" = "X" ]; then 545 | user=`whoami 2>/dev/null` 546 | fi 547 | if [ "X$user" = "X" ]; then 548 | user=`basename "$HOME"` 549 | fi 550 | if [ "X$user" = "X" -o "X$user" = "X." ]; then 551 | user="u$$" 552 | fi 553 | 554 | if [ "X$debug" = "X1" -a "X$runit" != "X" ]; then 555 | echo "" 556 | echo "/usr/bin/env:" 557 | env | egrep -v '^(LS_COLORS|TERMCAP)' | sort 558 | echo "" 559 | fi 560 | 561 | # Function to compile the LD_PRELOAD shared object: 562 | # 563 | make_so() { 564 | # extract code embedded in this script into a tmp C file: 565 | n1=`grep -n '^#code_begin' $0 | head -1 | awk -F: '{print $1}'` 566 | n2=`grep -n '^#code_end' $0 | head -1 | awk -F: '{print $1}'` 567 | n1=`expr $n1 + 1` 568 | dn=`expr $n2 - $n1` 569 | 570 | tmp=$tdir/Xdummy.$RANDOM$$.c 571 | rm -f $tmp 572 | if [ -e $tmp -o -h $tmp ]; then 573 | warn "$tmp still exists." 574 | exit 1 575 | fi 576 | touch $tmp || exit 1 577 | tail -n +$n1 $0 | head -n $dn > $tmp 578 | 579 | # compile it to Xdummy.so: 580 | if [ -f "$SO" ]; then 581 | mv $SO $SO.$$ 582 | rm -f $SO.$$ 583 | fi 584 | rm -f $SO 585 | touch $SO 586 | if [ ! -f "$SO" ]; then 587 | SO=$tdir/Xdummy.$user.so 588 | warn "warning switching LD_PRELOAD shared object to: $SO" 589 | fi 590 | 591 | if [ -f "$SO" ]; then 592 | mv $SO $SO.$$ 593 | rm -f $SO.$$ 594 | fi 595 | rm -f $SO 596 | 597 | # we assume gcc: 598 | if [ "X$INTERPOSE_GETUID" = "X1" ]; then 599 | CFLAGS="$CFLAGS -DINTERPOSE_GETUID" 600 | fi 601 | echo "$program:" cc -shared -fPIC $CFLAGS -o $SO $tmp -ldl 602 | cc -shared -fPIC $CFLAGS -o $SO $tmp -ldl 603 | rc=$? 
604 | rm -f $tmp 605 | if [ $rc != 0 ]; then 606 | warn "$program: cannot build $SO" 607 | exit 1 608 | fi 609 | if [ "X$debug" != "X" -o "X$install" != "X" ]; then 610 | warn "$program: created $SO" 611 | ls -l "$SO" 612 | fi 613 | } 614 | 615 | # Set tdir to tmp dir for make_so(): 616 | if [ "X$XDUMMY_TMPDIR" != "X" ]; then 617 | tdir=$XDUMMY_TMPDIR 618 | mkdir -p $tdir 619 | else 620 | tdir="/tmp" 621 | fi 622 | 623 | # Handle -install/-uninstall case: 624 | SO=$0.so 625 | if [ "X$install" != "X" -o "X$uninstall" != "X" ]; then 626 | if [ -e "$SO" -o -h "$SO" ]; then 627 | warn "$program: removing $SO" 628 | fi 629 | if [ -f "$SO" ]; then 630 | mv $SO $SO.$$ 631 | rm -f $SO.$$ 632 | fi 633 | rm -f $SO 634 | if [ -e "$SO" -o -h "$SO" ]; then 635 | warn "warning: $SO still exists." 636 | exit 1 637 | fi 638 | if [ $install ]; then 639 | make_so 640 | if [ ! -f "$SO" ]; then 641 | exit 1 642 | fi 643 | fi 644 | exit 0 645 | fi 646 | 647 | # We need a tmp directory for the .so, tweaked config file, and for 648 | # redirecting filenames we cannot create (under -nonroot) 649 | # 650 | tack="" 651 | if [ "X$XDUMMY_TMPDIR" = "X" ]; then 652 | XDUMMY_TMPDIR="/tmp/Xdummy.$user" 653 | 654 | # try to tack on a unique subdir (display number or pid) 655 | # to allow multiple instances 656 | # 657 | if [ "X$disp" != "X" ]; then 658 | t0=$disp 659 | else 660 | t0=$1 661 | fi 662 | tack=`echo "$t0" | sed -e 's/^.*://'` 663 | if echo "$tack" | grep '^[0-9][0-9]*$' > /dev/null; then 664 | : 665 | else 666 | tack=$$ 667 | fi 668 | if [ "X$tack" != "X" ]; then 669 | XDUMMY_TMPDIR="$XDUMMY_TMPDIR/$tack" 670 | fi 671 | fi 672 | 673 | tmp=$XDUMMY_TMPDIR 674 | if echo "$tmp" | grep '^/tmp' > /dev/null; then 675 | if [ "X$tmp" != "X/tmp" -a "X$tmp" != "X/tmp/" ]; then 676 | # clean this subdir of /tmp out, otherwise leave it... 
# NOTE(review): above -- Xdummy bootstrap: warn() helper, optional `set -xv` tracing,
# capture of the invoking UID (preferring $SUDO_UID), re-exec through sudo/su when
# root is wanted, CLI option parsing, username detection for tmp-dir naming,
# make_so() (compiles the embedded C between #code_begin/#code_end into the
# LD_PRELOAD .so with cc -shared -fPIC), -install/-uninstall handling, and the
# per-instance tmp-dir choice under /tmp/Xdummy.$user/<display-or-pid>.
677 | rm -rf $XDUMMY_TMPDIR 678 | if [ -e $XDUMMY_TMPDIR ]; then 679 | warn "$XDUMMY_TMPDIR still exists" 680 | exit 1 681 | fi 682 | fi 683 | fi 684 | 685 | mkdir -p $XDUMMY_TMPDIR 686 | chmod 700 $XDUMMY_TMPDIR 687 | if [ "X$tack" != "X" ]; then 688 | chmod 700 `dirname "$XDUMMY_TMPDIR"` 2>/dev/null 689 | fi 690 | 691 | # See if we can write something there: 692 | # 693 | tfile="$XDUMMY_TMPDIR/test.file" 694 | touch $tfile 695 | if [ ! -f "$tfile" ]; then 696 | XDUMMY_TMPDIR="/tmp/Xdummy.$$.$USER" 697 | warn "warning: setting tmpdir to $XDUMMY_TMPDIR ..." 698 | rm -rf $XDUMMY_TMPDIR || exit 1 699 | mkdir -p $XDUMMY_TMPDIR || exit 1 700 | fi 701 | rm -f $tfile 702 | 703 | export XDUMMY_TMPDIR 704 | 705 | # Compile the LD_PRELOAD shared object if needed (needs XDUMMY_TMPDIR) 706 | # 707 | if [ ! -f "$SO" ]; then 708 | SO="$XDUMMY_TMPDIR/Xdummy.so" 709 | make_so 710 | fi 711 | 712 | # Decide which X server to use: 713 | # 714 | if [ "X$xserver" = "X" ]; then 715 | if type Xorg >/dev/null 2>&1; then 716 | xserver="Xorg" 717 | elif type XFree86 >/dev/null 2>&1; then 718 | xserver="XFree86" 719 | elif -x /usr/bin/Xorg; then 720 | xserver="/usr/bin/Xorg" 721 | elif -x /usr/X11R6/bin/Xorg; then 722 | xserver="/usr/X11R6/bin/Xorg" 723 | elif -x /usr/X11R6/bin/XFree86; then 724 | xserver="/usr/X11R6/bin/XFree86" 725 | fi 726 | if [ "X$xserver" = "X" ]; then 727 | # just let it fail below. 728 | xserver="/usr/bin/Xorg" 729 | warn "$program: cannot locate a stock Xserver... assuming $xserver" 730 | fi 731 | fi 732 | 733 | # See if the binary is suid or not readable under -nonroot mode: 734 | # 735 | if [ "X$BASH_VERSION" != "X" ]; then 736 | xserver_path=`type -p $xserver 2>/dev/null` 737 | else 738 | xserver_path=`type $xserver 2>/dev/null | awk '{print $NF}'` 739 | fi 740 | if [ -e "$xserver_path" -a "X$root" = "X" -a "X$runit" != "X" ]; then 741 | if [ ! 
-r $xserver_path -o -u $xserver_path -o -g $xserver_path ]; then 742 | # XXX not quite correct with rm -rf $XDUMMY_TMPDIR ... 743 | # we keep on a filesystem we know root can write to. 744 | base=`basename "$xserver_path"` 745 | new="/tmp/$base.$user.bin" 746 | if [ -e $new ]; then 747 | snew=`ls -l $new | awk '{print $5}' | grep '^[0-9][0-9]*$'` 748 | sold=`ls -l $xserver_path | awk '{print $5}' | grep '^[0-9][0-9]*$'` 749 | if [ "X$snew" != "X" -a "X$sold" != "X" -a "X$sold" != "X$snew" ]; then 750 | warn "removing different sized copy:" 751 | ls -l $new $xserver_path 752 | rm -f $new 753 | fi 754 | fi 755 | if [ ! -e $new -o ! -s $new ]; then 756 | rm -f $new 757 | touch $new || exit 1 758 | chmod 700 $new || exit 1 759 | if [ ! -r $xserver_path ]; then 760 | warn "" 761 | warn "NEED TO COPY UNREADABLE $xserver_path to $new as root:" 762 | warn "" 763 | ls -l $xserver_path 1>&2 764 | warn "" 765 | warn "This only needs to be done once:" 766 | warn " cat $xserver_path > $new" 767 | warn "" 768 | nos=$nosudo 769 | if type sudo > /dev/null 2>&1; then 770 | : 771 | else 772 | nos=1 773 | fi 774 | if [ "X$nos" = "X1" ]; then 775 | warn "Please supply root passwd to 'su -c'" 776 | su -c "cat $xserver_path > $new" 777 | else 778 | warn "Please supply the sudo passwd if asked:" 779 | sudo /bin/sh -c "cat $xserver_path > $new" 780 | fi 781 | else 782 | warn "" 783 | warn "COPYING SETUID $xserver_path to $new" 784 | warn "" 785 | ls -l $xserver_path 1>&2 786 | warn "" 787 | cat $xserver_path > $new 788 | fi 789 | ls -l $new 790 | if [ -s $new ]; then 791 | : 792 | else 793 | rm -f $new 794 | ls -l $new 795 | exit 1 796 | fi 797 | warn "" 798 | warn "Please restart Xdummy now." 799 | exit 0 800 | fi 801 | if [ ! -O $new ]; then 802 | warn "file \"$new\" not owned by us!" 
803 | ls -l $new 804 | exit 1 805 | fi 806 | xserver=$new 807 | fi 808 | fi 809 | 810 | # Work out display: 811 | # 812 | if [ "X$disp" != "X" ]; then 813 | : 814 | elif [ "X$1" != "X" ]; then 815 | if echo "$1" | grep '^:[0-9]' > /dev/null; then 816 | disp=$1 817 | shift 818 | elif [ "X$1" = "X:" ]; then 819 | # ":" means for us to find one. 820 | shift 821 | fi 822 | fi 823 | if [ "X$disp" = "X" -o "X$disp" = "X:" ]; then 824 | # try to find an open display port: 825 | # (tcp outdated...) 826 | ports=`netstat -ant | grep LISTEN | awk '{print $4}' | sed -e 's/^.*://'` 827 | n=0 828 | while [ $n -le 20 ] 829 | do 830 | port=`printf "60%02d" $n` 831 | if echo "$ports" | grep "^${port}\$" > /dev/null; then 832 | : 833 | else 834 | disp=":$n" 835 | warn "$program: auto-selected DISPLAY $disp" 836 | break 837 | fi 838 | n=`expr $n + 1` 839 | done 840 | fi 841 | 842 | # Work out which vt to use, try to find/guess an open one if necessary. 843 | # 844 | vt="" 845 | for arg in $* 846 | do 847 | if echo "$arg" | grep '^vt' > /dev/null; then 848 | vt=$arg 849 | break 850 | fi 851 | done 852 | if [ "X$vt" = "X" ]; then 853 | if [ "X$user" = "Xroot" ]; then 854 | # root can user fuser(1) to see if it is in use: 855 | if type fuser >/dev/null 2>&1; then 856 | # try /dev/tty17 thru /dev/tty32 857 | n=17 858 | while [ $n -le 32 ] 859 | do 860 | dev="/dev/tty$n" 861 | if fuser $dev >/dev/null 2>&1; then 862 | : 863 | else 864 | vt="vt$n" 865 | warn "$program: auto-selected VT $vt => $dev" 866 | break 867 | fi 868 | n=`expr $n + 1` 869 | done 870 | fi 871 | fi 872 | if [ "X$vt" = "X" ]; then 873 | # take a wild guess... 
874 | vt=vt16 875 | warn "$program: selected fallback VT $vt" 876 | fi 877 | else 878 | vt="" 879 | fi 880 | 881 | # Decide flavor of Xserver: 882 | # 883 | stype=`basename "$xserver"` 884 | if echo "$stype" | grep -i xfree86 > /dev/null; then 885 | stype=xfree86 886 | else 887 | stype=xorg 888 | fi 889 | 
# tweak_config: filters an existing X config through the embedded perl program,
# forcing Driver "dummy", the DontVTSwitch/AllowMouseOpenFail/PciForceNone server
# flags, and Modelines derived from -geom; the result is written to
# $XDUMMY_TMPDIR/xdummy_modified_xconfig.conf (suffixed with the display name).
890 | tweak_config() { 891 | in="$1" 892 | config2="$XDUMMY_TMPDIR/xdummy_modified_xconfig.conf" 893 | if [ "X$disp" != "X" ]; then 894 | d=`echo "$disp" | sed -e 's,/,,g' -e 's/:/_/g'` 895 | config2="$config2$d" 896 | fi 897 | 898 | # perl script to tweak the config file... add/delete options, etc. 899 | # 900 | env XDUMMY_GEOM=$geom \ 901 | XDUMMY_DEPTH=$depth \ 902 | XDUMMY_NOMODELINES=$nomodelines \ 903 | perl > $config2 < $in -e ' 904 | $n = 0; 905 | $geom = $ENV{XDUMMY_GEOM}; 906 | $depth = $ENV{XDUMMY_DEPTH}; 907 | $nomodelines = $ENV{XDUMMY_NOMODELINES}; 908 | $mode_str = ""; 909 | $videoram = "24000"; 910 | $HorizSync = "30.0 - 130.0"; 911 | $VertRefresh = "50.0 - 250.0"; 912 | if ($geom ne "") { 913 | my $tmp = ""; 914 | foreach $g (split(/,/, $geom)) { 915 | $tmp .= "\"$g\" "; 916 | if (!$nomodelines && $g =~ /(\d+)x(\d+)/) { 917 | my $w = $1; 918 | my $h = $2; 919 | $mode_str .= " Modeline \"$g\" "; 920 | my $dot = sprintf("%.2f", $w * $h * 70 * 1.e-6); 921 | $mode_str .= $dot; 922 | $mode_str .= " " . $w; 923 | $mode_str .= " " . int(1.02 * $w); 924 | $mode_str .= " " . int(1.10 * $w); 925 | $mode_str .= " " . int(1.20 * $w); 926 | $mode_str .= " " . $h; 927 | $mode_str .= " " . int($h + 1); 928 | $mode_str .= " " . int($h + 3); 929 | $mode_str .= " " . 
int($h + 20); 930 | $mode_str .= "\n"; 931 | } 932 | } 933 | $tmp =~ s/\s*$//; 934 | $geom = $tmp; 935 | } 936 | while (<>) { 937 | if ($ENV{XDUMMY_NOTWEAK}) { 938 | print $_; 939 | next; 940 | } 941 | $n++; 942 | if (/^\s*#/) { 943 | # pass comments straight thru 944 | print; 945 | next; 946 | } 947 | if (/^\s*Section\s+(\S+)/i) { 948 | # start of Section 949 | $sect = $1; 950 | $sect =~ s/\W//g; 951 | $sect =~ y/A-Z/a-z/; 952 | $sects{$sect} = 1; 953 | print; 954 | next; 955 | } 956 | if (/^\s*EndSection/i) { 957 | # end of Section 958 | if ($sect eq "serverflags") { 959 | if (!$got_DontVTSwitch) { 960 | print " ##Xdummy:##\n"; 961 | print " Option \"DontVTSwitch\" \"true\"\n"; 962 | } 963 | if (!$got_AllowMouseOpenFail) { 964 | print " ##Xdummy:##\n"; 965 | print " Option \"AllowMouseOpenFail\" \"true\"\n"; 966 | } 967 | if (!$got_PciForceNone) { 968 | print " ##Xdummy:##\n"; 969 | print " Option \"PciForceNone\" \"true\"\n"; 970 | } 971 | } elsif ($sect eq "device") { 972 | if (!$got_Driver) { 973 | print " ##Xdummy:##\n"; 974 | print " Driver \"dummy\"\n"; 975 | } 976 | if (!$got_VideoRam) { 977 | print " ##Xdummy:##\n"; 978 | print " VideoRam $videoram\n"; 979 | } 980 | } elsif ($sect eq "screen") { 981 | if ($depth ne "" && !got_DefaultDepth) { 982 | print " ##Xdummy:##\n"; 983 | print " DefaultDepth $depth\n"; 984 | } 985 | if ($got_Monitor eq "") { 986 | print " ##Xdummy:##\n"; 987 | print " Monitor \"Monitor0\"\n"; 988 | } 989 | } elsif ($sect eq "monitor") { 990 | if (!got_HorizSync) { 991 | print " ##Xdummy:##\n"; 992 | print " HorizSync $HorizSync\n"; 993 | } 994 | if (!got_VertRefresh) { 995 | print " ##Xdummy:##\n"; 996 | print " VertRefresh $VertRefresh\n"; 997 | } 998 | if (!$nomodelines) { 999 | print " ##Xdummy:##\n"; 1000 | print $mode_str; 1001 | } 1002 | } 1003 | $sect = ""; 1004 | print; 1005 | next; 1006 | } 1007 | 1008 | if (/^\s*SubSection\s+(\S+)/i) { 1009 | # start of Section 1010 | $subsect = $1; 1011 | $subsect =~ s/\W//g; 1012 | 
$subsect =~ y/A-Z/a-z/; 1013 | $subsects{$subsect} = 1; 1014 | if ($sect eq "screen" && $subsect eq "display") { 1015 | $got_Modes = 0; 1016 | } 1017 | print; 1018 | next; 1019 | } 1020 | if (/^\s*EndSubSection/i) { 1021 | # end of SubSection 1022 | if ($sect eq "screen") { 1023 | if ($subsect eq "display") { 1024 | if ($depth ne "" && !$set_Depth) { 1025 | print " ##Xdummy:##\n"; 1026 | print " Depth\t$depth\n"; 1027 | } 1028 | if ($geom ne "" && ! $got_Modes) { 1029 | print " ##Xdummy:##\n"; 1030 | print " Modes\t$geom\n"; 1031 | } 1032 | } 1033 | } 1034 | $subsect = ""; 1035 | print; 1036 | next; 1037 | } 1038 | 1039 | $l = $_; 1040 | $l =~ s/#.*$//; 1041 | if ($sect eq "serverflags") { 1042 | if ($l =~ /^\s*Option.*DontVTSwitch/i) { 1043 | $_ =~ s/false/true/ig; 1044 | $got_DontVTSwitch = 1; 1045 | } 1046 | if ($l =~ /^\s*Option.*AllowMouseOpenFail/i) { 1047 | $_ =~ s/false/true/ig; 1048 | $got_AllowMouseOpenFail = 1; 1049 | } 1050 | if ($l =~ /^\s*Option.*PciForceNone/i) { 1051 | $_ =~ s/false/true/ig; 1052 | $got_PciForceNone= 1; 1053 | } 1054 | } 1055 | if ($sect eq "module") { 1056 | if ($l =~ /^\s*Load.*\b(dri|fbdevhw)\b/i) { 1057 | $_ = "##Xdummy## $_"; 1058 | } 1059 | } 1060 | if ($sect eq "monitor") { 1061 | if ($l =~ /^\s*HorizSync/i) { 1062 | $got_HorizSync = 1; 1063 | } 1064 | if ($l =~ /^\s*VertRefresh/i) { 1065 | $got_VertRefresh = 1; 1066 | } 1067 | } 1068 | if ($sect eq "device") { 1069 | if ($l =~ /^(\s*Driver)\b/i) { 1070 | $_ = "$1 \"dummy\"\n"; 1071 | $got_Driver = 1; 1072 | } 1073 | if ($l =~ /^\s*VideoRam/i) { 1074 | $got_VideoRam= 1; 1075 | } 1076 | } 1077 | if ($sect eq "inputdevice") { 1078 | if ($l =~ /^\s*Option.*\bDevice\b/i) { 1079 | print " ##Xdummy:##\n"; 1080 | $_ = " Option \"Device\" \"/dev/dilbert$n\"\n"; 1081 | } 1082 | } 1083 | if ($sect eq "screen") { 1084 | if ($l =~ /^\s*DefaultDepth\s+(\d+)/i) { 1085 | if ($depth ne "") { 1086 | print " ##Xdummy:##\n"; 1087 | $_ = " DefaultDepth\t$depth\n"; 1088 | } 1089 | 
$got_DefaultDepth = 1; 1090 | } 1091 | if ($l =~ /^\s*Monitor\s+(\S+)/i) { 1092 | $got_Monitor = $1; 1093 | $got_Monitor =~ s/"//g; 1094 | } 1095 | if ($subsect eq "display") { 1096 | if ($geom ne "") { 1097 | if ($l =~ /^(\s*Modes)\b/i) { 1098 | print " ##Xdummy:##\n"; 1099 | $_ = "$1 $geom\n"; 1100 | $got_Modes = 1; 1101 | } 1102 | } 1103 | if ($l =~ /^\s*Depth\s+(\d+)/i) { 1104 | my $d = $1; 1105 | if (!$set_Depth && $depth ne "") { 1106 | $set_Depth = 1; 1107 | if ($depth != $d) { 1108 | print " ##Xdummy:##\n"; 1109 | $_ = " Depth\t$depth\n"; 1110 | } 1111 | } 1112 | } 1113 | } 1114 | } 1115 | print; 1116 | } 1117 | if ($ENV{XDUMMY_NOTWEAK}) { 1118 | exit; 1119 | } 1120 | # create any crucial sections that are missing: 1121 | if (! exists($sects{serverflags})) { 1122 | print "\n##Xdummy:##\n"; 1123 | print "Section \"ServerFlags\"\n"; 1124 | print " Option \"DontVTSwitch\" \"true\"\n"; 1125 | print " Option \"AllowMouseOpenFail\" \"true\"\n"; 1126 | print " Option \"PciForceNone\" \"true\"\n"; 1127 | print "EndSection\n"; 1128 | } 1129 | if (! exists($sects{device})) { 1130 | print "\n##Xdummy:##\n"; 1131 | print "Section \"Device\"\n"; 1132 | print " Identifier \"Videocard0\"\n"; 1133 | print " Driver \"dummy\"\n"; 1134 | print " VideoRam $videoram\n"; 1135 | print "EndSection\n"; 1136 | } 1137 | if (! exists($sects{monitor})) { 1138 | print "\n##Xdummy:##\n"; 1139 | print "Section \"Monitor\"\n"; 1140 | print " Identifier \"Monitor0\"\n"; 1141 | print " HorizSync $HorizSync\n"; 1142 | print " VertRefresh $VertRefresh\n"; 1143 | print "EndSection\n"; 1144 | } 1145 | if (! 
exists($sects{screen})) { 1146 | print "\n##Xdummy:##\n"; 1147 | print "Section \"Screen\"\n"; 1148 | print " Identifier \"Screen0\"\n"; 1149 | print " Device \"Videocard0\"\n"; 1150 | if ($got_Monitor ne "") { 1151 | print " Monitor \"$got_Monitor\"\n"; 1152 | } else { 1153 | print " Monitor \"Monitor0\"\n"; 1154 | } 1155 | if ($depth ne "") { 1156 | print " DefaultDepth $depth\n"; 1157 | } else { 1158 | print " DefaultDepth 24\n"; 1159 | } 1160 | print " SubSection \"Display\"\n"; 1161 | print " Viewport 0 0\n"; 1162 | print " Depth 24\n"; 1163 | if ($got_Modes) { 1164 | ; 1165 | } elsif ($geom ne "") { 1166 | print " Modes $geom\n"; 1167 | } else { 1168 | print " Modes \"1280x1024\" \"1024x768\" \"800x600\"\n"; 1169 | } 1170 | print " EndSubSection\n"; 1171 | print "EndSection\n"; 1172 | } 1173 | '; 1174 | } 1175 | 1176 | # Work out config file and tweak it. 1177 | # 1178 | if [ "X$cmdline_config" = "X" ]; then 1179 | : 1180 | elif [ "X$cmdline_config" = "Xxdummy-builtin" ]; then 1181 | : 1182 | elif echo "$cmdline_config" | grep '/' > /dev/null; then 1183 | : 1184 | else 1185 | # ignore basename only case (let server handle it) 1186 | cmdline_config="" 1187 | notweak=1 1188 | fi 1189 | 1190 | config=$cmdline_config 1191 | 1192 | if [ "X$notweak" = "X1" -a "X$root" = "X" -a -f "$cmdline_config" ]; then 1193 | # if not root we need to copy (but not tweak) the specified config. 1194 | XDUMMY_NOTWEAK=1 1195 | export XDUMMY_NOTWEAK 1196 | notweak="" 1197 | fi 1198 | 1199 | if [ ! $notweak ]; then 1200 | # tweaked config will be put in $config2: 1201 | config2="" 1202 | if [ "X$config" = "X" ]; then 1203 | # use the default one: 1204 | if [ "X$stype" = "Xxorg" ]; then 1205 | config=/etc/X11/xorg.conf 1206 | else 1207 | if [ -f "/etc/X11/XF86Config-4" ]; then 1208 | config="/etc/X11/XF86Config-4" 1209 | else 1210 | config="/etc/X11/XF86Config" 1211 | fi 1212 | fi 1213 | if [ ! 
-f "$config" ]; then 1214 | for c in /etc/X11/xorg.conf /etc/X11/XF86Config-4 /etc/X11/XF86Config 1215 | do 1216 | if [ -f $c ]; then 1217 | config=$c 1218 | break 1219 | fi 1220 | done 1221 | fi 1222 | fi 1223 | 1224 | if [ "X$config" = "Xxdummy-builtin" ]; then 1225 | config="" 1226 | fi 1227 | 1228 | if [ ! -f "$config" ]; then 1229 | config="$XDUMMY_TMPDIR/xorg.conf" 1230 | warn "$program: using minimal built-in xorg.conf settings." 1231 | cat > $config < /dev/null; then 1359 | so=`echo "$so" | sed -e "s,^\.,$pwd,"` 1360 | fi 1361 | if echo "$so" | grep '/' > /dev/null; then 1362 | : 1363 | else 1364 | so="$pwd/$so" 1365 | fi 1366 | warn "env LD_PRELOAD=$so $xserver $disp $args $vt" 1367 | warn "" 1368 | if [ ! $runit ]; then 1369 | exit 0 1370 | fi 1371 | fi 1372 | 1373 | if [ $strace ]; then 1374 | if [ "X$strace" = "X2" ]; then 1375 | ltrace -f env LD_PRELOAD=$SO $xserver $disp $args $vt 1376 | else 1377 | strace -f env LD_PRELOAD=$SO $xserver $disp $args $vt 1378 | fi 1379 | else 1380 | exec env LD_PRELOAD=$SO $xserver $disp $args $vt 1381 | fi 1382 | 1383 | exit $? 
# NOTE(review): above -- remainder of the embedded perl filter (appends any missing
# ServerFlags/Device/Monitor/Screen sections), selection of the config file
# (command-line, system default, or a minimal built-in written to
# $XDUMMY_TMPDIR/xorg.conf -- NB the built-in heredoc body appears to have been
# lost in this dump, lines ~1232-1358 are missing), and the final launch:
# exec env LD_PRELOAD=$SO $xserver $disp $args $vt, optionally via strace/ltrace.
# The script never runs past this exit during a normal invocation.
# NOTE(review): everything from here down is never executed by the shell (exit
# above); it is the C source of the LD_PRELOAD shim, extracted by make_so()
# between the #code_begin/#code_end markers and compiled with cc. This part:
# globals, debug/root/uid bookkeeping (check_debug clears LD_PRELOAD, check_uid
# drops to $XDUMMY_UID via setreuid once fonts are read or 20s elapse), the
# tmpdir_path() redirection helper, and interposers for open/open64/rename and
# the head of fopen. Under -nonroot, open() redirects /dev/tty* to FIFOs in
# $XDUMMY_TMPDIR and rename() of /var/log paths is a no-op. (The #include
# targets appear to have been stripped by this dump -- TODO restore from
# upstream Xdummy.)
1384 | 1385 | ######################################################################### 1386 | 1387 | code() { 1388 | #code_begin 1389 | #include 1390 | #define O_ACCMODE 0003 1391 | #define O_RDONLY 00 1392 | #define O_WRONLY 01 1393 | #define O_RDWR 02 1394 | #define O_CREAT 0100 /* not fcntl */ 1395 | #define O_EXCL 0200 /* not fcntl */ 1396 | #define O_NOCTTY 0400 /* not fcntl */ 1397 | #define O_TRUNC 01000 /* not fcntl */ 1398 | #define O_APPEND 02000 1399 | #define O_NONBLOCK 04000 1400 | #define O_NDELAY O_NONBLOCK 1401 | #define O_SYNC 010000 1402 | #define O_FSYNC O_SYNC 1403 | #define O_ASYNC 020000 1404 | 1405 | #include 1406 | #include 1407 | #include 1408 | 1409 | #include 1410 | #include 1411 | 1412 | #define __USE_GNU 1413 | #include 1414 | 1415 | static char tmpdir[4096]; 1416 | static char str1[4096]; 1417 | static char str2[4096]; 1418 | 1419 | static char devs[256][1024]; 1420 | static int debug = -1; 1421 | static int root = -1; 1422 | static int changed_uid = 0; 1423 | static int saw_fonts = 0; 1424 | static int saw_lib_modules = 0; 1425 | 1426 | static time_t start = 0; 1427 | 1428 | void check_debug(void) { 1429 | if (debug < 0) { 1430 | if (getenv("XDUMMY_DEBUG") != NULL) { 1431 | debug = 1; 1432 | } else { 1433 | debug = 0; 1434 | } 1435 | /* prevent other processes using the preload: */ 1436 | putenv("LD_PRELOAD="); 1437 | } 1438 | } 1439 | void check_root(void) { 1440 | if (root < 0) { 1441 | /* script tells us if we are root */ 1442 | if (getenv("XDUMMY_ROOT") != NULL) { 1443 | root = 1; 1444 | } else { 1445 | root = 0; 1446 | } 1447 | } 1448 | } 1449 | 1450 | void check_uid(void) { 1451 | if (start == 0) { 1452 | start = time(NULL); 1453 | if (debug) fprintf(stderr, "START: %u\n", (unsigned int) start); 1454 | return; 1455 | } else if (changed_uid == 0) { 1456 | if (saw_fonts || time(NULL) > start + 20) { 1457 | if (getenv("XDUMMY_UID")) { 1458 | int uid = atoi(getenv("XDUMMY_UID")); 1459 | if (debug) fprintf(stderr, "SETREUID: %d 
saw_fonts=%d\n", uid, saw_fonts); 1460 | if (uid >= 0) { 1461 | /* this will simply fail in -nonroot mode: */ 1462 | setreuid(uid, -1); 1463 | } 1464 | } 1465 | changed_uid = 1; 1466 | } 1467 | } 1468 | } 1469 | 1470 | #define CHECKIT if (debug < 0) check_debug(); \ 1471 | if (root < 0) check_root(); \ 1472 | check_uid(); 1473 | 1474 | static void set_tmpdir(void) { 1475 | char *s; 1476 | static int didset = 0; 1477 | if (didset) { 1478 | return; 1479 | } 1480 | s = getenv("XDUMMY_TMPDIR"); 1481 | if (! s) { 1482 | s = "/tmp"; 1483 | } 1484 | tmpdir[0] = '\0'; 1485 | strcat(tmpdir, s); 1486 | strcat(tmpdir, "/"); 1487 | didset = 1; 1488 | } 1489 | 1490 | static char *tmpdir_path(const char *path) { 1491 | char *str; 1492 | set_tmpdir(); 1493 | strcpy(str2, path); 1494 | str = str2; 1495 | while (*str) { 1496 | if (*str == '/') { 1497 | *str = '_'; 1498 | } 1499 | str++; 1500 | } 1501 | strcpy(str1, tmpdir); 1502 | strcat(str1, str2); 1503 | return str1; 1504 | } 1505 | 1506 | int open(const char *pathname, int flags, unsigned short mode) { 1507 | int fd; 1508 | char *store_dev = NULL; 1509 | static int (*real_open)(const char *, int , unsigned short) = NULL; 1510 | 1511 | CHECKIT 1512 | if (! real_open) { 1513 | real_open = (int (*)(const char *, int , unsigned short)) 1514 | dlsym(RTLD_NEXT, "open"); 1515 | } 1516 | 1517 | if (strstr(pathname, "lib/modules/")) { 1518 | /* not currently used. */ 1519 | saw_lib_modules = 1; 1520 | } 1521 | 1522 | if (!root) { 1523 | if (strstr(pathname, "/dev/") == pathname) { 1524 | store_dev = strdup(pathname); 1525 | } 1526 | if (strstr(pathname, "/dev/tty") == pathname && strcmp(pathname, "/dev/tty")) { 1527 | pathname = tmpdir_path(pathname); 1528 | if (debug) fprintf(stderr, "OPEN: %s -> %s (as FIFO)\n", store_dev, pathname); 1529 | /* we make it a FIFO so ioctl on it does not fail */ 1530 | unlink(pathname); 1531 | mkfifo(pathname, 0666); 1532 | } else if (0) { 1533 | /* we used to handle more /dev files ... 
*/ 1534 | fd = real_open(pathname, O_WRONLY|O_CREAT, 0777); 1535 | close(fd); 1536 | } 1537 | } 1538 | 1539 | fd = real_open(pathname, flags, mode); 1540 | 1541 | if (debug) fprintf(stderr, "OPEN: %s %d %d fd=%d\n", pathname, flags, mode, fd); 1542 | 1543 | if (! root) { 1544 | if (store_dev) { 1545 | if (fd < 256) { 1546 | strcpy(devs[fd], store_dev); 1547 | } 1548 | free(store_dev); 1549 | } 1550 | } 1551 | 1552 | return(fd); 1553 | } 1554 | 1555 | int open64(const char *pathname, int flags, unsigned short mode) { 1556 | int fd; 1557 | 1558 | CHECKIT 1559 | if (debug) fprintf(stderr, "OPEN64: %s %d %d\n", pathname, flags, mode); 1560 | 1561 | fd = open(pathname, flags, mode); 1562 | return(fd); 1563 | } 1564 | 1565 | int rename(const char *oldpath, const char *newpath) { 1566 | static int (*real_rename)(const char *, const char *) = NULL; 1567 | 1568 | CHECKIT 1569 | if (! real_rename) { 1570 | real_rename = (int (*)(const char *, const char *)) 1571 | dlsym(RTLD_NEXT, "rename"); 1572 | } 1573 | 1574 | if (debug) fprintf(stderr, "RENAME: %s %s\n", oldpath, newpath); 1575 | 1576 | if (root) { 1577 | return(real_rename(oldpath, newpath)); 1578 | } 1579 | 1580 | if (strstr(oldpath, "/var/log") == oldpath) { 1581 | if (debug) fprintf(stderr, "RENAME: returning 0\n"); 1582 | return 0; 1583 | } 1584 | return(real_rename(oldpath, newpath)); 1585 | } 1586 | 1587 | FILE *fopen(const char *pathname, const char *mode) { 1588 | static FILE* (*real_fopen)(const char *, const char *) = NULL; 1589 | char *str; 1590 | 1591 | if (! saw_fonts) { 1592 | if (strstr(pathname, "/fonts/")) { 1593 | if (strstr(pathname, "fonts.dir")) { 1594 | saw_fonts = 1; 1595 | } else if (strstr(pathname, "fonts.alias")) { 1596 | saw_fonts = 1; 1597 | } 1598 | } 1599 | } 1600 | 1601 | CHECKIT 1602 | if (! 
real_fopen) { 1603 | real_fopen = (FILE* (*)(const char *, const char *)) 1604 | dlsym(RTLD_NEXT, "fopen"); 1605 | } 1606 | 1607 | if (debug) fprintf(stderr, "FOPEN: %s %s\n", pathname, mode); 1608 | 1609 | if (strstr(pathname, "xdummy_modified_xconfig.conf")) { 1610 | /* make our config appear to be in /etc/X11, etc. */ 1611 | char *q = strrchr(pathname, '/'); 1612 | if (q != NULL && getenv("XDUMMY_TMPDIR") != NULL) { 1613 | strcpy(str1, getenv("XDUMMY_TMPDIR")); 1614 | strcat(str1, q); 1615 | if (debug) fprintf(stderr, "FOPEN: %s -> %s\n", pathname, str1); 1616 | pathname = str1; 1617 | } 1618 | } 1619 | 1620 | if (root) { 1621 | return(real_fopen(pathname, mode)); 1622 | } 1623 | 1624 | str = (char *) pathname; 1625 | if (strstr(pathname, "/var/log") == pathname) { 1626 | str = tmpdir_path(pathname); 1627 | if (debug) fprintf(stderr, "FOPEN: %s -> %s\n", pathname, str); 1628 | } 1629 | return(real_fopen(str, mode)); 1630 | } 1631 | 1632 | 1633 | #define RETURN0 if (debug) \ 1634 | {fprintf(stderr, "IOCTL: covered %d 0x%x\n", fd, req);} return 0; 1635 | #define RETURN1 if (debug) \ 1636 | {fprintf(stderr, "IOCTL: covered %d 0x%x\n", fd, req);} return -1; 1637 | 1638 | int ioctl(int fd, int req, void *ptr) { 1639 | static int closed_xf86Info_consoleFd = 0; 1640 | static int (*real_ioctl)(int, int , void *) = NULL; 1641 | 1642 | CHECKIT 1643 | if (! real_ioctl) { 1644 | real_ioctl = (int (*)(int, int , void *)) 1645 | dlsym(RTLD_NEXT, "open"); 1646 | } 1647 | if (debug) fprintf(stderr, "IOCTL: %d 0x%x %p\n", fd, req, ptr); 1648 | 1649 | /* based on xorg-x11-6.8.1-dualhead.patch */ 1650 | if (req == VT_GETMODE) { 1651 | /* close(xf86Info.consoleFd) */ 1652 | if (0 && ! closed_xf86Info_consoleFd) { 1653 | /* I think better not to close it... 
*/ 1654 | close(fd); 1655 | closed_xf86Info_consoleFd = 1; 1656 | } 1657 | RETURN0 1658 | } else if (req == VT_SETMODE) { 1659 | RETURN0 1660 | } else if (req == VT_GETSTATE) { 1661 | RETURN0 1662 | } else if (req == KDSETMODE) { 1663 | RETURN0 1664 | } else if (req == KDSETLED) { 1665 | RETURN0 1666 | } else if (req == KDGKBMODE) { 1667 | RETURN0 1668 | } else if (req == KDSKBMODE) { 1669 | RETURN0 1670 | } else if (req == VT_ACTIVATE) { 1671 | RETURN0 1672 | } else if (req == VT_WAITACTIVE) { 1673 | RETURN0 1674 | } else if (req == VT_RELDISP) { 1675 | if (ptr == (void *) 1) { 1676 | RETURN1 1677 | } else if (ptr == (void *) VT_ACKACQ) { 1678 | RETURN0 1679 | } 1680 | } 1681 | 1682 | return(real_ioctl(fd, req, ptr)); 1683 | } 1684 | 1685 | typedef void (*sighandler_t)(int); 1686 | #define SIGUSR1 10 1687 | #define SIG_DFL ((sighandler_t)0) 1688 | 1689 | sighandler_t signal(int signum, sighandler_t handler) { 1690 | static sighandler_t (*real_signal)(int, sighandler_t) = NULL; 1691 | 1692 | CHECKIT 1693 | if (! real_signal) { 1694 | real_signal = (sighandler_t (*)(int, sighandler_t)) 1695 | dlsym(RTLD_NEXT, "signal"); 1696 | } 1697 | 1698 | if (debug) fprintf(stderr, "SIGNAL: %d %p\n", signum, handler); 1699 | 1700 | if (signum == SIGUSR1) { 1701 | if (debug) fprintf(stderr, "SIGNAL: skip SIGUSR1\n"); 1702 | return SIG_DFL; 1703 | } 1704 | 1705 | return(real_signal(signum, handler)); 1706 | } 1707 | 
/* NOTE(review): below -- interposed close/stat/stat64/chown/ioperm/iopl, and
   (only when compiled with -DINTERPOSE_GETUID, see make_so in the shell part)
   get[e]uid/get[e]gid wrappers that report id 0 so the server believes it is
   root in -nonroot mode. The local one-field struct stat and hard-coded
   SIGUSR1/SIG_DFL/ENODEV definitions deliberately avoid pulling in headers. */
1708 | int close(int fd) { 1709 | static int (*real_close)(int) = NULL; 1710 | 1711 | CHECKIT 1712 | if (! 
real_close) { 1713 | real_close = (int (*)(int)) dlsym(RTLD_NEXT, "close"); 1714 | } 1715 | 1716 | if (debug) fprintf(stderr, "CLOSE: %d\n", fd); 1717 | if (!root) { 1718 | if (fd < 256) { 1719 | devs[fd][0] = '\0'; 1720 | } 1721 | } 1722 | return(real_close(fd)); 1723 | } 1724 | 1725 | struct stat { 1726 | int foo; 1727 | }; 1728 | 1729 | int stat(const char *path, struct stat *buf) { 1730 | static int (*real_stat)(const char *, struct stat *) = NULL; 1731 | 1732 | CHECKIT 1733 | if (! real_stat) { 1734 | real_stat = (int (*)(const char *, struct stat *)) 1735 | dlsym(RTLD_NEXT, "stat"); 1736 | } 1737 | 1738 | if (debug) fprintf(stderr, "STAT: %s\n", path); 1739 | 1740 | return(real_stat(path, buf)); 1741 | } 1742 | 1743 | int stat64(const char *path, struct stat *buf) { 1744 | static int (*real_stat64)(const char *, struct stat *) = NULL; 1745 | 1746 | CHECKIT 1747 | if (! real_stat64) { 1748 | real_stat64 = (int (*)(const char *, struct stat *)) 1749 | dlsym(RTLD_NEXT, "stat64"); 1750 | } 1751 | 1752 | if (debug) fprintf(stderr, "STAT64: %s\n", path); 1753 | 1754 | return(real_stat64(path, buf)); 1755 | } 1756 | 1757 | int chown(const char *path, uid_t owner, gid_t group) { 1758 | static int (*real_chown)(const char *, uid_t, gid_t) = NULL; 1759 | 1760 | CHECKIT 1761 | if (! 
real_chown) { 1762 | real_chown = (int (*)(const char *, uid_t, gid_t)) 1763 | dlsym(RTLD_NEXT, "chown"); 1764 | } 1765 | 1766 | if (root) { 1767 | return(real_chown(path, owner, group)); 1768 | } 1769 | 1770 | if (debug) fprintf(stderr, "CHOWN: %s %d %d\n", path, owner, group); 1771 | 1772 | if (strstr(path, "/dev") == path) { 1773 | if (debug) fprintf(stderr, "CHOWN: return 0\n"); 1774 | return 0; 1775 | } 1776 | 1777 | return(real_chown(path, owner, group)); 1778 | } 1779 | 1780 | extern int *__errno_location (void); 1781 | #ifndef ENODEV 1782 | #define ENODEV 19 1783 | #endif 1784 | 1785 | int ioperm(unsigned long from, unsigned long num, int turn_on) { 1786 | static int (*real_ioperm)(unsigned long, unsigned long, int) = NULL; 1787 | 1788 | CHECKIT 1789 | if (! real_ioperm) { 1790 | real_ioperm = (int (*)(unsigned long, unsigned long, int)) 1791 | dlsym(RTLD_NEXT, "ioperm"); 1792 | } 1793 | if (debug) fprintf(stderr, "IOPERM: %d %d %d\n", (int) from, (int) num, turn_on); 1794 | if (root) { 1795 | return(real_ioperm(from, num, turn_on)); 1796 | } 1797 | if (from == 0 && num == 1024 && turn_on == 1) { 1798 | /* we want xf86EnableIO to fail */ 1799 | if (debug) fprintf(stderr, "IOPERM: setting ENODEV.\n"); 1800 | *__errno_location() = ENODEV; 1801 | return -1; 1802 | } 1803 | return 0; 1804 | } 1805 | 1806 | int iopl(int level) { 1807 | static int (*real_iopl)(int) = NULL; 1808 | 1809 | CHECKIT 1810 | if (! real_iopl) { 1811 | real_iopl = (int (*)(int)) dlsym(RTLD_NEXT, "iopl"); 1812 | } 1813 | if (debug) fprintf(stderr, "IOPL: %d\n", level); 1814 | if (root) { 1815 | return(real_iopl(level)); 1816 | } 1817 | return 0; 1818 | } 1819 | 1820 | #ifdef INTERPOSE_GETUID 1821 | 1822 | /* 1823 | * we got things to work w/o pretending to be root. 1824 | * so we no longer interpose getuid(), etc. 1825 | */ 1826 | 1827 | uid_t getuid(void) { 1828 | static uid_t (*real_getuid)(void) = NULL; 1829 | CHECKIT 1830 | if (! 
real_getuid) { 1831 | real_getuid = (uid_t (*)(void)) dlsym(RTLD_NEXT, "getuid"); 1832 | } 1833 | if (root) { 1834 | return(real_getuid()); 1835 | } 1836 | if (debug) fprintf(stderr, "GETUID: 0\n"); 1837 | return 0; 1838 | } 1839 | uid_t geteuid(void) { 1840 | static uid_t (*real_geteuid)(void) = NULL; 1841 | CHECKIT 1842 | if (! real_geteuid) { 1843 | real_geteuid = (uid_t (*)(void)) dlsym(RTLD_NEXT, "geteuid"); 1844 | } 1845 | if (root) { 1846 | return(real_geteuid()); 1847 | } 1848 | if (debug) fprintf(stderr, "GETEUID: 0\n"); 1849 | return 0; 1850 | } 1851 | uid_t geteuid_kludge1(void) { 1852 | static uid_t (*real_geteuid)(void) = NULL; 1853 | CHECKIT 1854 | if (! real_geteuid) { 1855 | real_geteuid = (uid_t (*)(void)) dlsym(RTLD_NEXT, "geteuid"); 1856 | } 1857 | if (debug) fprintf(stderr, "GETEUID: 0 saw_libmodules=%d\n", saw_lib_modules); 1858 | if (root && !saw_lib_modules) { 1859 | return(real_geteuid()); 1860 | } else { 1861 | saw_lib_modules = 0; 1862 | return 0; 1863 | } 1864 | } 1865 | 1866 | uid_t getuid32(void) { 1867 | static uid_t (*real_getuid32)(void) = NULL; 1868 | CHECKIT 1869 | if (! real_getuid32) { 1870 | real_getuid32 = (uid_t (*)(void)) dlsym(RTLD_NEXT, "getuid32"); 1871 | } 1872 | if (root) { 1873 | return(real_getuid32()); 1874 | } 1875 | if (debug) fprintf(stderr, "GETUID32: 0\n"); 1876 | return 0; 1877 | } 1878 | uid_t geteuid32(void) { 1879 | static uid_t (*real_geteuid32)(void) = NULL; 1880 | CHECKIT 1881 | if (! real_geteuid32) { 1882 | real_geteuid32 = (uid_t (*)(void)) dlsym(RTLD_NEXT, "geteuid32"); 1883 | } 1884 | if (root) { 1885 | return(real_geteuid32()); 1886 | } 1887 | if (debug) fprintf(stderr, "GETEUID32: 0\n"); 1888 | return 0; 1889 | } 1890 | 1891 | gid_t getgid(void) { 1892 | static gid_t (*real_getgid)(void) = NULL; 1893 | CHECKIT 1894 | if (! 
real_getgid) { 1895 | real_getgid = (gid_t (*)(void)) dlsym(RTLD_NEXT, "getgid"); 1896 | } 1897 | if (root) { 1898 | return(real_getgid()); 1899 | } 1900 | if (debug) fprintf(stderr, "GETGID: 0\n"); 1901 | return 0; 1902 | } 1903 | gid_t getegid(void) { 1904 | static gid_t (*real_getegid)(void) = NULL; 1905 | CHECKIT 1906 | if (! real_getegid) { 1907 | real_getegid = (gid_t (*)(void)) dlsym(RTLD_NEXT, "getegid"); 1908 | } 1909 | if (root) { 1910 | return(real_getegid()); 1911 | } 1912 | if (debug) fprintf(stderr, "GETEGID: 0\n"); 1913 | return 0; 1914 | } 1915 | gid_t getgid32(void) { 1916 | static gid_t (*real_getgid32)(void) = NULL; 1917 | CHECKIT 1918 | if (! real_getgid32) { 1919 | real_getgid32 = (gid_t (*)(void)) dlsym(RTLD_NEXT, "getgid32"); 1920 | } 1921 | if (root) { 1922 | return(real_getgid32()); 1923 | } 1924 | if (debug) fprintf(stderr, "GETGID32: 0\n"); 1925 | return 0; 1926 | } 1927 | gid_t getegid32(void) { 1928 | static gid_t (*real_getegid32)(void) = NULL; 1929 | CHECKIT 1930 | if (! real_getegid32) { 1931 | real_getegid32 = (gid_t (*)(void)) dlsym(RTLD_NEXT, "getegid32"); 1932 | } 1933 | if (root) { 1934 | return(real_getegid32()); 1935 | } 1936 | if (debug) fprintf(stderr, "GETEGID32: 0\n"); 1937 | return 0; 1938 | } 1939 | #endif 1940 | 1941 | #if 0 1942 | /* maybe we need to interpose on strcmp someday... here is the template */ 1943 | int strcmp(const char *s1, const char *s2) { 1944 | static int (*real_strcmp)(const char *, const char *) = NULL; 1945 | CHECKIT 1946 | if (! 
def remote_fn(doodad_config, variant):
    '''
    Entry point executed on the remote machine by doodad.

    Runs `scripts/plan.py` once per suffix in
    [suffix_start, suffix_start + n_suffices), forwarding every remaining
    key in `variant` as a `--key value` command-line argument.

    doodad_config : doodad config object; `output_directory` is used as
        the D4RL dataset cache location
    variant : dict of hyperparameters; must contain `n_suffices` and
        `suffix_start`, which are consumed here rather than forwarded
    '''
    ## get suffix range to allow running multiple trials per job;
    ## pop() both reads and removes the keys in one step
    n_suffices = variant.pop('n_suffices')
    suffix_start = variant.pop('suffix_start')

    kwarg_string = ' '.join([
        f'--{k} {v}' for k, v in variant.items()
    ])
    print(kwarg_string)

    d4rl_path = os.path.join(doodad_config.output_directory, 'datasets/')
    os.system(f'ls -a {codepath}')
    ## presumably the code sync step renamed `.git` to `git`; restore it
    ## so gitpython can read the repo state -- TODO confirm against doodad
    os.system(f'mv {codepath}/git {codepath}/.git')

    ## NOTE(review): commands are built by string interpolation and run
    ## through the shell; safe only because all values originate from the
    ## launcher below, never from untrusted input
    for suffix in range(suffix_start, suffix_start + n_suffices):
        os.system(
            f'''export PYTHONPATH=$PYTHONPATH:{codepath} && '''
            f'''export CUDA_VISIBLE_DEVICES=0 && '''
            f'''export D4RL_DATASET_DIR={d4rl_path} && '''
            f'''python {os.path.join(codepath, script)} '''
            f'''--suffix {suffix} '''
            f'''{kwarg_string}'''
        )

    save_doodad_config(doodad_config)
def remote_fn(doodad_config, variant):
    '''
    Remote entry point for doodad: launches `scripts/train.py` once,
    forwarding every key in `variant` as a `--key value` argument and
    pointing D4RL's dataset cache at the job's output directory.
    '''
    flags = [f'--{key} {value}' for key, value in variant.items()]
    kwarg_string = ' '.join(flags)
    print(kwarg_string)

    d4rl_path = os.path.join(doodad_config.output_directory, 'datasets/')
    os.system(f'ls -a {codepath}')
    os.system(f'mv {codepath}/git {codepath}/.git')

    command = (
        f'''export PYTHONPATH=$PYTHONPATH:{codepath} && '''
        f'''export CUDA_VISIBLE_DEVICES=0 && '''
        f'''export D4RL_DATASET_DIR={d4rl_path} && '''
        f'''python {os.path.join(codepath, script)} '''
        f'''{kwarg_string}'''
    )
    os.system(command)

    save_doodad_config(doodad_config)
'gpt/azure', 43 | } 44 | 45 | sweep_function( 46 | remote_fn, 47 | params_to_sweep, 48 | default_params=default_params, 49 | config_path=os.path.abspath('azure/config.py'), 50 | log_path=azure_logpath, 51 | gpu_model='nvidia-tesla-v100', 52 | filter_dir=['logs', 'bin'], 53 | use_gpu=True, 54 | ) 55 | -------------------------------------------------------------------------------- /azure/make_fuse_config.sh: -------------------------------------------------------------------------------- 1 | ## AZURE_STORAGE_CONNECTION_STRING has a substring formatted lik: 2 | ## AccountName=${STORAGE_ACCOUNT};AccountKey=${STORAGE_KEY};EndpointSuffix= ... 3 | export AZURE_STORAGE_ACCOUNT=`(echo $AZURE_STORAGE_CONNECTION_STRING | grep -o -P '(?<=AccountName=).*(?=;AccountKey)')` 4 | export AZURE_STORAGE_KEY=`(echo $AZURE_STORAGE_CONNECTION_STRING | grep -o -P '(?<=AccountKey=).*(?=;EndpointSuffix)')` 5 | 6 | echo "accountName" ${AZURE_STORAGE_ACCOUNT} > ./azure/fuse.cfg 7 | echo "accountKey" ${AZURE_STORAGE_KEY} >> ./azure/fuse.cfg 8 | echo "containerName" ${AZURE_STORAGE_CONTAINER} >> ./azure/fuse.cfg -------------------------------------------------------------------------------- /azure/mount.sh: -------------------------------------------------------------------------------- 1 | mkdir ~/azure_mount 2 | blobfuse ~/azure_mount --tmp-path=/tmp --config-file=./azure/fuse.cfg 3 | -------------------------------------------------------------------------------- /azure/sync.sh: -------------------------------------------------------------------------------- 1 | if keyctl show 2>&1 | grep -q "workaroundSession"; 2 | then 3 | echo "Already logged in" 4 | else 5 | echo "Logging in with tenant id:" ${AZURE_TENANT_ID} 6 | keyctl session workaroundSession 7 | ./bin/azcopy login --tenant-id=$AZURE_TENANT_ID 8 | fi 9 | 10 | export LOGBASE=defaults 11 | 12 | ## AZURE_STORAGE_CONNECTION_STRING has a substring formatted lik: 13 | ## AccountName=${STORAGE_ACCOUNT};AccountKey= ... 
14 | export AZURE_STORAGE_ACCOUNT=`(echo $AZURE_STORAGE_CONNECTION_STRING | grep -o -P '(?<=AccountName=).*(?=;AccountKey)')` 15 | 16 | echo "Syncing from" ${AZURE_STORAGE_ACCOUNT}"/"${AZURE_STORAGE_CONTAINER}"/"${LOGBASE} 17 | 18 | ./bin/azcopy sync https://${AZURE_STORAGE_ACCOUNT}.blob.core.windows.net/${AZURE_STORAGE_CONTAINER}/${LOGBASE}/logs logs/ --recursive -------------------------------------------------------------------------------- /config/offline.py: -------------------------------------------------------------------------------- 1 | from trajectory.utils import watch 2 | 3 | #------------------------ base ------------------------# 4 | 5 | logbase = 'logs/' 6 | gpt_expname = 'gpt/azure' 7 | 8 | ## automatically make experiment names for planning 9 | ## by labelling folders with these args 10 | args_to_watch = [ 11 | ('prefix', ''), 12 | ('plan_freq', 'freq'), 13 | ('horizon', 'H'), 14 | ('beam_width', 'beam'), 15 | ] 16 | 17 | base = { 18 | 19 | 'train': { 20 | 'N': 100, 21 | 'discount': 0.99, 22 | 'n_layer': 4, 23 | 'n_head': 4, 24 | 25 | ## number of epochs for a 1M-size dataset; n_epochs = 1M / dataset_size * n_epochs_ref 26 | 'n_epochs_ref': 50, 27 | 'n_saves': 3, 28 | 'logbase': logbase, 29 | 'device': 'cuda', 30 | 31 | 'n_embd': 32, 32 | 'batch_size': 256, 33 | 'learning_rate': 6e-4, 34 | 'lr_decay': True, 35 | 'seed': 42, 36 | 37 | 'embd_pdrop': 0.1, 38 | 'resid_pdrop': 0.1, 39 | 'attn_pdrop': 0.1, 40 | 41 | 'step': 1, 42 | 'subsampled_sequence_length': 10, 43 | 'termination_penalty': -100, 44 | 'exp_name': gpt_expname, 45 | 46 | 'discretizer': 'QuantileDiscretizer', 47 | 'action_weight': 5, 48 | 'reward_weight': 1, 49 | 'value_weight': 1, 50 | }, 51 | 52 | 'plan': { 53 | 'logbase': logbase, 54 | 'gpt_loadpath': gpt_expname, 55 | 'gpt_epoch': 'latest', 56 | 'device': 'cuda', 57 | 'renderer': 'Renderer', 58 | 59 | 'plan_freq': 1, 60 | 'horizon': 15, 61 | 'beam_width': 128, 62 | 'n_expand': 2, 63 | 64 | 'k_obs': 1, 65 | 'k_act': None, 66 | 
'cdf_obs': None, 67 | 'cdf_act': 0.6, 68 | 'percentile': 'mean', 69 | 70 | 'max_context_transitions': 5, 71 | 'prefix_context': True, 72 | 73 | 'vis_freq': 50, 74 | 'exp_name': watch(args_to_watch), 75 | 'prefix': 'plans/defaults/', 76 | 'suffix': '0', 77 | 'verbose': True, 78 | }, 79 | 80 | } 81 | 82 | #------------------------ locomotion ------------------------# 83 | 84 | ## for all halfcheetah environments, you can reduce the planning horizon and beam width without 85 | ## affecting performance. good for speed and sanity. 86 | 87 | halfcheetah_medium_v2 = halfcheetah_medium_replay_v2 = { 88 | 'plan': { 89 | 'horizon': 5, 90 | 'beam_width': 32, 91 | } 92 | } 93 | 94 | halfcheetah_medium_expert_v2 = { 95 | 'plan': { 96 | 'beam_width': 32, 97 | }, 98 | } 99 | 100 | ## if you leave the dictionary empty, it will use the base parameters 101 | hopper_medium_expert_v2 = hopper_medium_v2 = walker2d_medium_v2 = {} 102 | 103 | ## hopper and wlaker2d are a little more sensitive to planning hyperparameters; 104 | ## proceed with caution when reducing the horizon or increasing the planning frequency 105 | 106 | hopper_medium_replay_v2 = { 107 | 'train': { 108 | ## train on the medium-replay datasets longer 109 | 'n_epochs_ref': 80, 110 | }, 111 | } 112 | 113 | walker2d_medium_expert_v2 = { 114 | 'plan': { 115 | ## also safe to reduce the horizon here 116 | 'horizon': 5, 117 | }, 118 | } 119 | 120 | walker2d_medium_replay_v2 = { 121 | 'train': { 122 | ## train on the medium-replay datasets longer 123 | 'n_epochs_ref': 80, 124 | }, 125 | 'plan': { 126 | ## can reduce beam width, but need to adjust action sampling 127 | ## distribution and increase horizon to accomodate 128 | 'horizon': 20, 129 | 'beam_width': 32, 130 | 'k_act': 40, 131 | 'cdf_act': None, 132 | } 133 | } 134 | 135 | ant_medium_v2 = ant_medium_replay_v2 = ant_random_v2 = { 136 | 'train': { 137 | ## reduce batch size because the dimensionality is larger 138 | 'batch_size': 128, 139 | }, 140 | 'plan': { 141 | 
def get_mean(results, exclude=None):
    '''
    Average the scores in `results`, optionally skipping some environments.

    results : dict mapping environment name to score
    exclude : optional substring; any key containing it is dropped from
        the average (e.g. `exclude='ant'` removes all ant environments)

    Returns the mean of the remaining scores (`np.nan` if none remain).
    '''
    ## `(not exclude) or (exclude and exclude not in k)` reduces to this:
    ## keep everything when no filter is given, otherwise drop matches
    filtered = [
        v for k, v in results.items()
        if not exclude or exclude not in k
    ]
    return np.mean(filtered)
def load_results(paths):
    '''
    Aggregate rollout scores from a set of experiment trial directories.

    paths : iterable of trial directories, each of which may contain a
        `rollout.json` file

    Returns (mean, standard error, list of scores). Trials without a
    readable `rollout.json` are skipped with a message. If no trial has
    a score, mean and error are `np.nan`.
    '''
    scores = []
    for path in sorted(paths):
        score = load_result(path)
        if score is None:
            print(f'Skipping {path}')
            continue
        scores.append(score)

    if not scores:
        ## avoid numpy warnings from np.mean([]) and division by zero
        return np.nan, np.nan, scores

    mean = np.mean(scores)
    err = np.std(scores) / np.sqrt(len(scores))
    return mean, err, scores

def load_result(path):
    '''
    Load the normalized score from a single experiment directory.

    path : path to experiment directory; expects `rollout.json` inside

    Returns the score scaled to the 0-100 range, or None if the file
    does not exist.
    '''
    fullpath = os.path.join(path, 'rollout.json')

    if not os.path.exists(fullpath):
        return None

    ## open in text mode and close the handle promptly
    with open(fullpath, 'r') as f:
        results = json.load(f)
    return results['score'] * 100
-------------------------------------------------------------------------------- /plotting/scores.py: -------------------------------------------------------------------------------- 1 | means = { 2 | 'Trajectory\nTransformer': { 3 | ## 4 | 'halfcheetah-medium-expert-v2': 95.0, 5 | 'hopper-medium-expert-v2': 110.0, 6 | 'walker2d-medium-expert-v2': 101.9, 7 | 'ant-medium-expert-v2': 116.1, 8 | ## 9 | 'halfcheetah-medium-v2': 46.9, 10 | 'hopper-medium-v2': 61.1, 11 | 'walker2d-medium-v2': 79.0, 12 | 'ant-medium-v2': 83.1, 13 | ## 14 | 'halfcheetah-medium-replay-v2': 41.9, 15 | 'hopper-medium-replay-v2': 91.5, 16 | 'walker2d-medium-replay-v2': 82.6, 17 | 'ant-medium-replay-v2': 77.0, 18 | }, 19 | 'Decision\nTransformer': { 20 | ## 21 | 'halfcheetah-medium-expert-v2': 86.8, 22 | 'hopper-medium-expert-v2': 107.6, 23 | 'walker2d-medium-expert-v2': 108.1, 24 | ## 25 | 'halfcheetah-medium-v2': 42.6, 26 | 'hopper-medium-v2': 67.6, 27 | 'walker2d-medium-v2': 74.0, 28 | ## 29 | 'halfcheetah-medium-replay-v2': 36.6, 30 | 'hopper-medium-replay-v2': 82.7, 31 | 'walker2d-medium-replay-v2': 66.6, 32 | }, 33 | 'CQL': { 34 | ## 35 | 'halfcheetah-medium-expert-v2': 91.6, 36 | 'hopper-medium-expert-v2': 105.4, 37 | 'walker2d-medium-expert-v2': 108.8, 38 | ## 39 | 'halfcheetah-medium-v2': 44.0, 40 | 'hopper-medium-v2': 58.5, 41 | 'walker2d-medium-v2': 72.5, 42 | ## 43 | 'halfcheetah-medium-replay-v2': 45.5, 44 | 'hopper-medium-replay-v2': 95.0, 45 | 'walker2d-medium-replay-v2': 77.2, 46 | }, 47 | 'MOPO': { 48 | ## 49 | 'halfcheetah-medium-expert-v2': 63.3, 50 | 'hopper-medium-expert-v2': 23.7, 51 | 'walker2d-medium-expert-v2': 44.6, 52 | ## 53 | 'halfcheetah-medium-v2': 42.3, 54 | 'hopper-medium-v2': 28.0, 55 | 'walker2d-medium-v2': 17.8, 56 | ## 57 | 'halfcheetah-medium-replay-v2': 53.1, 58 | 'hopper-medium-replay-v2': 67.5, 59 | 'walker2d-medium-replay-v2':39.0, 60 | }, 61 | 'MBOP': { 62 | ## 63 | 'halfcheetah-medium-expert-v2': 105.9, 64 | 'hopper-medium-expert-v2': 55.1, 65 | 
'walker2d-medium-expert-v2': 70.2, 66 | ## 67 | 'halfcheetah-medium-v2': 44.6, 68 | 'hopper-medium-v2': 48.8, 69 | 'walker2d-medium-v2': 41.0, 70 | ## 71 | 'halfcheetah-medium-replay-v2': 42.3, 72 | 'hopper-medium-replay-v2': 12.4, 73 | 'walker2d-medium-replay-v2': 9.7, 74 | }, 75 | 'BRAC': { 76 | ## 77 | 'halfcheetah-medium-expert-v2': 41.9, 78 | 'hopper-medium-expert-v2': 0.9, 79 | 'walker2d-medium-expert-v2': 81.6, 80 | ## 81 | 'halfcheetah-medium-v2': 46.3, 82 | 'hopper-medium-v2': 31.3, 83 | 'walker2d-medium-v2': 81.1, 84 | ## 85 | 'halfcheetah-medium-replay-v2': 47.7, 86 | 'hopper-medium-replay-v2': 0.6, 87 | 'walker2d-medium-replay-v2': 0.9, 88 | }, 89 | 'BC': { 90 | ## 91 | 'halfcheetah-medium-expert-v2': 59.9, 92 | 'hopper-medium-expert-v2': 79.6, 93 | 'walker2d-medium-expert-v2': 36.6, 94 | ## 95 | 'halfcheetah-medium-v2': 43.1, 96 | 'hopper-medium-v2': 63.9, 97 | 'walker2d-medium-v2': 77.3, 98 | ## 99 | 'halfcheetah-medium-replay-v2': 4.3, 100 | 'hopper-medium-replay-v2': 27.6, 101 | 'walker2d-medium-replay-v2': 36.9, 102 | }, 103 | } 104 | 105 | errors = { 106 | 'Trajectory\nTransformer': { 107 | ## 108 | 'halfcheetah-medium-expert-v2': 0.2, 109 | 'hopper-medium-expert-v2': 2.7, 110 | 'walker2d-medium-expert-v2': 6.8, 111 | 'ant-medium-expert-v2': 9.0, 112 | ## 113 | 'halfcheetah-medium-v2': 0.4, 114 | 'hopper-medium-v2': 3.6, 115 | 'walker2d-medium-v2': 2.8, 116 | 'ant-medium-v2': 7.3, 117 | ## 118 | 'halfcheetah-medium-replay-v2': 2.5, 119 | 'hopper-medium-replay-v2': 3.6, 120 | 'walker2d-medium-replay-v2': 6.9, 121 | 'ant-medium-replay-v2': 6.8, 122 | }, 123 | } -------------------------------------------------------------------------------- /plotting/table.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pdb 3 | 4 | from plotting.plot import get_mean 5 | from plotting.scores import ( 6 | means as MEANS, 7 | errors as ERRORS, 8 | ) 9 | 10 | ALGORITHM_STRINGS = { 11 | 
def format_result(result):
    '''
    Format a single table cell as LaTeX math.

    result : either a bare mean (number or '-'), or a (mean, error)
        tuple; tuples are rendered with a small raised `±error` suffix

    Returns the LaTeX string for the cell.
    '''
    ## isinstance is the idiomatic type check (and covers subclasses),
    ## unlike the exact `type(result) == tuple` comparison
    if isinstance(result, tuple):
        mean, std = result
        return f'${mean}$ \\scriptsize{{\\raisebox{{1pt}}{{$\\pm {std}$}}}}'
    else:
        return f'${result}$'
def format_averages_block(algorithms):
    '''
    Build the two LaTeX rows reporting per-algorithm averages: one
    excluding the ant environments and one over all settings.
    Algorithms without ant results show `$-$` in the all-settings row.
    '''
    without_ant = [
        np.round(get_mean(MEANS[algorithm], exclude='ant'), 1)
        for algorithm in algorithms
    ]
    ## only algorithms that actually report ant results get a full
    ## average; the rest render as an empty `$-$` cell
    with_ant = [
        np.round(get_mean(MEANS[algorithm], exclude=None), 1)
        if 'ant-medium-expert-v2' in MEANS[algorithm]
        else '$-$'
        for algorithm in algorithms
    ]

    row_filtered = format_averages(without_ant, 'without Ant')
    row_all = format_averages(with_ant, 'all settings')

    return (
        f'{row_filtered} \\hspace{{.6cm}} \\\\ \n'
        f'{row_all} \\hspace{{.6cm}} \\\\ \n'
    )
-------------------------------------------------------------------------------- /pretrained.sh: -------------------------------------------------------------------------------- 1 | export DOWNLOAD_PATH=logs 2 | 3 | [ ! -d ${DOWNLOAD_PATH} ] && mkdir ${DOWNLOAD_PATH} 4 | 5 | ## downloads pretrained models for 16 datasets: 6 | ## {halfcheetah, hopper, walker2d, ant} 7 | ## x 8 | ## {expert-v2, medium-expert-v2, medium-v2, medium-replay-v2} 9 | 10 | wget https://www.dropbox.com/sh/r09lkdoj66kx43w/AACbXjMhcI6YNsn1qU4LParja?dl=1 -O dropbox_models.zip 11 | unzip dropbox_models.zip -d ${DOWNLOAD_PATH} 12 | rm dropbox_models.zip 13 | 14 | ## downloads 15 plans from each pretrained model 15 | wget https://www.dropbox.com/s/5sn79ep79yo22kv/pretrained-plans.tar?dl=1 -O dropbox_plans.tar 16 | tar -xvf dropbox_plans.tar 17 | cp -r pretrained-plans/* ${DOWNLOAD_PATH} 18 | rm -r pretrained-plans 19 | rm dropbox_plans.tar 20 | -------------------------------------------------------------------------------- /scripts/plan.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pdb 3 | from os.path import join 4 | 5 | import trajectory.utils as utils 6 | import trajectory.datasets as datasets 7 | from trajectory.search import ( 8 | beam_plan, 9 | make_prefix, 10 | extract_actions, 11 | update_context, 12 | ) 13 | 14 | class Parser(utils.Parser): 15 | dataset: str = 'halfcheetah-medium-expert-v2' 16 | config: str = 'config.offline' 17 | 18 | ####################### 19 | ######## setup ######## 20 | ####################### 21 | 22 | args = Parser().parse_args('plan') 23 | 24 | ####################### 25 | ####### models ######## 26 | ####################### 27 | 28 | dataset = utils.load_from_config(args.logbase, args.dataset, args.gpt_loadpath, 29 | 'data_config.pkl') 30 | 31 | gpt, gpt_epoch = utils.load_model(args.logbase, args.dataset, args.gpt_loadpath, 32 | epoch=args.gpt_epoch, device=args.device) 33 | 34 | 
####################### 35 | ####### dataset ####### 36 | ####################### 37 | 38 | env = datasets.load_environment(args.dataset) 39 | renderer = utils.make_renderer(args) 40 | timer = utils.timer.Timer() 41 | 42 | discretizer = dataset.discretizer 43 | discount = dataset.discount 44 | observation_dim = dataset.observation_dim 45 | action_dim = dataset.action_dim 46 | 47 | value_fn = lambda x: discretizer.value_fn(x, args.percentile) 48 | preprocess_fn = datasets.get_preprocess_fn(env.name) 49 | 50 | ####################### 51 | ###### main loop ###### 52 | ####################### 53 | 54 | observation = env.reset() 55 | total_reward = 0 56 | 57 | ## observations for rendering 58 | rollout = [observation.copy()] 59 | 60 | ## previous (tokenized) transitions for conditioning transformer 61 | context = [] 62 | 63 | T = env.max_episode_steps 64 | for t in range(T): 65 | 66 | observation = preprocess_fn(observation) 67 | 68 | if t % args.plan_freq == 0: 69 | ## concatenate previous transitions and current observations to input to model 70 | prefix = make_prefix(discretizer, context, observation, args.prefix_context) 71 | 72 | ## sample sequence from model beginning with `prefix` 73 | sequence = beam_plan( 74 | gpt, value_fn, prefix, 75 | args.horizon, args.beam_width, args.n_expand, observation_dim, action_dim, 76 | discount, args.max_context_transitions, verbose=args.verbose, 77 | k_obs=args.k_obs, k_act=args.k_act, cdf_obs=args.cdf_obs, cdf_act=args.cdf_act, 78 | ) 79 | 80 | else: 81 | sequence = sequence[1:] 82 | 83 | ## [ horizon x transition_dim ] convert sampled tokens to continuous trajectory 84 | sequence_recon = discretizer.reconstruct(sequence) 85 | 86 | ## [ action_dim ] index into sampled trajectory to grab first action 87 | action = extract_actions(sequence_recon, observation_dim, action_dim, t=0) 88 | 89 | ## execute action in environment 90 | next_observation, reward, terminal, _ = env.step(action) 91 | 92 | ## update return 93 | total_reward += 
reward 94 | score = env.get_normalized_score(total_reward) 95 | 96 | ## update rollout observations and context transitions 97 | rollout.append(next_observation.copy()) 98 | context = update_context(context, discretizer, observation, action, reward, args.max_context_transitions) 99 | 100 | print( 101 | f'[ plan ] t: {t} / {T} | r: {reward:.2f} | R: {total_reward:.2f} | score: {score:.4f} | ' 102 | f'time: {timer():.2f} | {args.dataset} | {args.exp_name} | {args.suffix}\n' 103 | ) 104 | 105 | ## visualization 106 | if t % args.vis_freq == 0 or terminal or t == T: 107 | 108 | ## save current plan 109 | renderer.render_plan(join(args.savepath, f'{t}_plan.mp4'), sequence_recon, env.state_vector()) 110 | 111 | ## save rollout thus far 112 | renderer.render_rollout(join(args.savepath, f'rollout.mp4'), rollout, fps=80) 113 | 114 | if terminal: break 115 | 116 | observation = next_observation 117 | 118 | ## save result as a json file 119 | json_path = join(args.savepath, 'rollout.json') 120 | json_data = {'score': score, 'step': t, 'return': total_reward, 'term': terminal, 'gpt_epoch': gpt_epoch} 121 | json.dump(json_data, open(json_path, 'w'), indent=2, sort_keys=True) 122 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import pdb 5 | 6 | import trajectory.utils as utils 7 | import trajectory.datasets as datasets 8 | from trajectory.models.transformers import GPT 9 | 10 | 11 | class Parser(utils.Parser): 12 | dataset: str = 'halfcheetah-medium-expert-v2' 13 | config: str = 'config.offline' 14 | 15 | ####################### 16 | ######## setup ######## 17 | ####################### 18 | 19 | args = Parser().parse_args('train') 20 | 21 | ####################### 22 | ####### dataset ####### 23 | ####################### 24 | 25 | env = datasets.load_environment(args.dataset) 26 | 27 | 
## total raw-transition span covered by one (subsampled) training sequence
sequence_length = args.subsampled_sequence_length * args.step

dataset_config = utils.Config(
    datasets.DiscretizedDataset,
    savepath=(args.savepath, 'data_config.pkl'),
    env=args.dataset,
    N=args.N,
    penalty=args.termination_penalty,
    sequence_length=sequence_length,
    step=args.step,
    discount=args.discount,
    discretizer=args.discretizer,
)

dataset = dataset_config()
obs_dim = dataset.observation_dim
act_dim = dataset.action_dim
transition_dim = dataset.joined_dim

#######################
######## model ########
#######################

## -1 because the final token of the last transition is the prediction target,
## never an input
block_size = args.subsampled_sequence_length * transition_dim - 1
print(
    f'Dataset size: {len(dataset)} | '
    f'Joined dim: {transition_dim} '
    f'(observation: {obs_dim}, action: {act_dim}) | Block size: {block_size}'
)

model_config = utils.Config(
    GPT,
    savepath=(args.savepath, 'model_config.pkl'),
    ## discretization
    vocab_size=args.N, block_size=block_size,
    ## architecture
    n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd*args.n_head,
    ## dimensions
    observation_dim=obs_dim, action_dim=act_dim, transition_dim=transition_dim,
    ## loss weighting
    action_weight=args.action_weight, reward_weight=args.reward_weight, value_weight=args.value_weight,
    ## dropout probabilities
    embd_pdrop=args.embd_pdrop, resid_pdrop=args.resid_pdrop, attn_pdrop=args.attn_pdrop,
)

model = model_config()
model.to(args.device)

#######################
####### trainer #######
#######################

warmup_tokens = len(dataset) * block_size ## number of tokens seen per epoch
final_tokens = 20 * warmup_tokens

trainer_config = utils.Config(
    utils.Trainer,
    savepath=(args.savepath, 'trainer_config.pkl'),
    # optimization parameters
    batch_size=args.batch_size,
    learning_rate=args.learning_rate,
    betas=(0.9, 0.95),
    grad_norm_clip=1.0,
    weight_decay=0.1, # only applied on matmul weights
    # learning rate decay: linear warmup followed by cosine decay to 10% of original
    lr_decay=args.lr_decay,
    warmup_tokens=warmup_tokens,
    final_tokens=final_tokens,
    ## dataloader
    num_workers=0,
    device=args.device,
)

trainer = trainer_config()

#######################
###### main loop ######
#######################

## scale number of epochs to keep number of updates constant
n_epochs = int(1e6 / len(dataset) * args.n_epochs_ref)
## FIX: guard against save_freq == 0 (when n_epochs < n_saves), which
## previously crashed with ZeroDivisionError in the `// save_freq` below
save_freq = max(int(n_epochs // args.n_saves), 1)

for epoch in range(n_epochs):
    print(f'\nEpoch: {epoch} / {n_epochs} | {args.dataset} | {args.exp_name}')

    trainer.train(model, dataset)

    ## get greatest multiple of `save_freq` less than or equal to `save_epoch`
    save_epoch = (epoch + 1) // save_freq * save_freq
    statepath = os.path.join(args.savepath, f'state_{save_epoch}.pt')
    print(f'Saving model to {statepath}')

    ## save state to disk
    state = model.state_dict()
    torch.save(state, statepath)
def qlearning_dataset_with_timeouts(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Convert a D4RL-style dataset into (s, a, s', r, done) transitions,
    treating `timeouts` as episode boundaries in addition to `terminals`.

    Args:
        env: environment providing `get_dataset` (only used when `dataset` is None)
        dataset: optional pre-loaded dataset dict with keys
            'observations', 'actions', 'rewards', 'terminals', 'timeouts'
        terminate_on_end: if False, transitions at a timeout step are dropped,
            because the next observation belongs to a different episode
        **kwargs: forwarded to `env.get_dataset`

    Returns:
        dict of numpy arrays; 'terminals' marks both true terminals and
        timeouts, while 'realterminals' marks only true environment terminals.
        Rewards/terminals/realterminals have a trailing singleton dimension.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    N = dataset['rewards'].shape[0]
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    realdone_ = []

    ## iterate to N-1 because the next observation of transition i is index i+1
    for i in range(N - 1):
        obs = dataset['observations'][i]
        new_obs = dataset['observations'][i + 1]
        action = dataset['actions'][i]
        reward = dataset['rewards'][i]

        ## `realdone` tracks only true environment terminations
        realdone_bool = bool(dataset['terminals'][i])
        final_timestep = bool(dataset['timeouts'][i])
        ## an episode ends on either a terminal or a timeout
        ## (FIX: previously accumulated with `+=`, yielding ints instead of bools,
        ## behind an always-true `if i < N - 1` check inside `range(N-1)`)
        done_bool = realdone_bool or final_timestep

        if (not terminate_on_end) and final_timestep:
            ## skip: `new_obs` is the first observation of the next episode,
            ## so (obs, action, new_obs) is not a valid transition
            continue

        obs_.append(obs)
        next_obs_.append(new_obs)
        action_.append(action)
        reward_.append(reward)
        done_.append(done_bool)
        realdone_.append(realdone_bool)

    return {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'rewards': np.array(reward_)[:, None],
        'terminals': np.array(done_)[:, None],
        'realterminals': np.array(realdone_)[:, None],
    }
def segment(observations, terminals, max_path_length):
    """
    Segment `observations` into trajectories according to `terminals`.

    Args:
        observations: [ N x observation_dim ] array of concatenated episodes
        terminals: length-N flags (any squeezable shape) marking the last
            step of each episode
        max_path_length: pad length of the returned trajectory tensor

    Returns:
        trajectories_pad: [ n_trajectories x max_path_length x observation_dim ]
            zero-padded trajectories
        early_termination: [ n_trajectories x max_path_length ] boolean mask,
            True wherever a trajectory has already ended (i.e. padding)
        path_lengths: list of unpadded trajectory lengths
    """
    assert len(observations) == len(terminals)
    observation_dim = observations.shape[1]

    trajectories = [[]]
    for obs, term in zip(observations, terminals):
        trajectories[-1].append(obs)
        if term.squeeze():
            trajectories.append([])

    if len(trajectories[-1]) == 0:
        trajectories = trajectories[:-1]

    ## list of arrays because trajectories lengths will be different
    trajectories = [np.stack(traj, axis=0) for traj in trajectories]

    n_trajectories = len(trajectories)
    path_lengths = [len(traj) for traj in trajectories]

    ## pad trajectories to be of equal length
    ## FIX: use the builtin `bool`; the `np.bool` alias is deprecated since
    ## numpy 1.20 and removed in numpy >= 1.24
    trajectories_pad = np.zeros((n_trajectories, max_path_length, observation_dim), dtype=trajectories[0].dtype)
    early_termination = np.zeros((n_trajectories, max_path_length), dtype=bool)
    for i, traj in enumerate(trajectories):
        path_length = path_lengths[i]
        trajectories_pad[i, :path_length] = traj
        early_termination[i, path_length:] = 1

    return trajectories_pad, early_termination, path_lengths
terminal penalty 79 | if penalty is not None: 80 | terminal_mask = realterminals.squeeze() 81 | self.rewards_raw[terminal_mask] = penalty 82 | 83 | ## segment 84 | print(f'[ datasets/sequence ] Segmenting...', end=' ', flush=True) 85 | self.joined_segmented, self.termination_flags, self.path_lengths = segment(self.joined_raw, terminals, max_path_length) 86 | self.rewards_segmented, *_ = segment(self.rewards_raw, terminals, max_path_length) 87 | print('✓') 88 | 89 | self.discount = discount 90 | self.discounts = (discount ** np.arange(self.max_path_length))[:,None] 91 | 92 | ## [ n_paths x max_path_length x 1 ] 93 | self.values_segmented = np.zeros(self.rewards_segmented.shape) 94 | 95 | for t in range(max_path_length): 96 | ## [ n_paths x 1 ] 97 | V = (self.rewards_segmented[:,t+1:] * self.discounts[:-t-1]).sum(axis=1) 98 | self.values_segmented[:,t] = V 99 | 100 | ## add (r, V) to `joined` 101 | values_raw = self.values_segmented.squeeze(axis=-1).reshape(-1) 102 | values_mask = ~self.termination_flags.reshape(-1) 103 | self.values_raw = values_raw[values_mask, None] 104 | self.joined_raw = np.concatenate([self.joined_raw, self.rewards_raw, self.values_raw], axis=-1) 105 | self.joined_segmented = np.concatenate([self.joined_segmented, self.rewards_segmented, self.values_segmented], axis=-1) 106 | 107 | ## get valid indices 108 | indices = [] 109 | for path_ind, length in enumerate(self.path_lengths): 110 | end = length - 1 111 | for i in range(end): 112 | indices.append((path_ind, i, i+sequence_length)) 113 | 114 | self.indices = np.array(indices) 115 | self.observation_dim = observations.shape[1] 116 | self.action_dim = actions.shape[1] 117 | self.joined_dim = self.joined_raw.shape[1] 118 | 119 | ## pad trajectories 120 | n_trajectories, _, joined_dim = self.joined_segmented.shape 121 | self.joined_segmented = np.concatenate([ 122 | self.joined_segmented, 123 | np.zeros((n_trajectories, sequence_length-1, joined_dim)), 124 | ], axis=1) 125 | self.termination_flags 
class DiscretizedDataset(SequenceDataset):
    """
    SequenceDataset whose continuous (obs, action, reward, value) transitions
    are discretized into integer tokens for autoregressive modeling.

    Token `N` (one past the largest bin index) is reserved as the
    termination / padding token.
    """

    def __init__(self, *args, N=50, discretizer='QuantileDiscretizer', **kwargs):
        """
        N : number of discretization bins per dimension
        discretizer : name of a class in `trajectory.utils.discretization`,
            fit on the full joined dataset
        """
        super().__init__(*args, **kwargs)
        self.N = N
        discretizer_class = getattr(discretization, discretizer)
        self.discretizer = discretizer_class(self.joined_raw, N)

    def __getitem__(self, idx):
        """
        Returns (X, Y, mask): next-token prediction inputs, targets (inputs
        shifted by one token), and a loss mask, all flattened to 1-D.
        """
        path_ind, start_ind, end_ind = self.indices[idx]
        path_length = self.path_lengths[path_ind]

        ## subsample every `step`-th transition from the padded trajectory
        joined = self.joined_segmented[path_ind, start_ind:end_ind:self.step]
        terminations = self.termination_flags[path_ind, start_ind:end_ind:self.step]

        joined_discrete = self.discretizer.discretize(joined)

        ## replace with termination token if the sequence has ended
        ## (padding rows must be all-zero, otherwise segmentation went wrong)
        assert (joined[terminations] == 0).all(), \
            f'Everything after termination should be 0: {path_ind} | {start_ind} | {end_ind}'
        joined_discrete[terminations] = self.N

        ## [ (sequence_length / skip) x observation_dim]
        joined_discrete = to_torch(joined_discrete, device='cpu', dtype=torch.long).contiguous()

        ## don't compute loss for parts of the prediction that extend
        ## beyond the max path length
        traj_inds = torch.arange(start_ind, end_ind, self.step)
        mask = torch.ones(joined_discrete.shape, dtype=torch.bool)
        mask[traj_inds > self.max_path_length - self.step] = 0

        ## flatten everything
        joined_discrete = joined_discrete.view(-1)
        mask = mask.view(-1)

        ## standard next-token setup: predict token i+1 from tokens <= i
        X = joined_discrete[:-1]
        Y = joined_discrete[1:]
        mask = mask[:-1]

        return X, Y, mask
class EinLinear(nn.Module):
    """
    A batch of `n_models` independent linear layers applied in parallel:
    model `e` maps input[:, e, :] through its own weight (and optional bias).
    """

    def __init__(self, n_models, in_features, out_features, bias):
        super().__init__()
        self.n_models = n_models
        self.out_features = out_features
        self.in_features = in_features
        ## one [ out_features x in_features ] weight matrix per model
        self.weight = nn.Parameter(torch.Tensor(n_models, out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(n_models, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        ## initialize each model exactly as nn.Linear would
        for i in range(self.n_models):
            nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))
            if self.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i])
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(self.bias[i], -bound, bound)

    def forward(self, input):
        """
        input : [ B x n_models x in_features ]
        returns : [ B x n_models x out_features ]
        """
        ## [ B x n_models x out_features ]
        output = torch.einsum('eoi,bei->beo', self.weight, input)
        if self.bias is not None:
            ## FIX: previously raised RuntimeError even though the constructor
            ## creates and initializes a bias; add the per-model bias
            ## ([ n_models x out_features ] broadcasts over the batch dim)
            output = output + self.bias
        return output

    def extra_repr(self):
        return 'n_models={}, in_features={}, out_features={}, bias={}'.format(
            self.n_models, self.in_features, self.out_features, self.bias is not None
        )
def get_activation(params):
    """
    Build a zero-argument factory for a `torch.nn` activation module.

    `params` is either the activation's class name (e.g. 'GELU') or a dict
    of the form {'type': name, 'kwargs': {...}} for parameterized activations.
    """
    if type(params) is dict:
        name, kwargs = params['type'], params['kwargs']
    else:
        name, kwargs = str(params), {}

    def _make():
        return getattr(nn, name)(**kwargs)

    return _make
class MLP(nn.Module):
    """
    A plain fully-connected network: Linear + activation for every hidden
    dimension, then a final Linear followed by the output activation.

    @TODO: clean up model instantiation from config so we don't have to pass
    in `model_class` to the model itself
    """

    def __init__(self, input_dim, hidden_dims, output_dim, activation='GELU', output_activation='Identity', name='mlp', model_class=None):
        super().__init__()
        self.input_dim = input_dim
        self.name = name
        make_act = get_activation(activation)
        make_out_act = get_activation(output_activation)

        ## pair consecutive layer widths: (input_dim -> h1 -> h2 -> ...)
        dims = [input_dim, *hidden_dims]
        layers = []
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers += [nn.Linear(in_dim, out_dim), make_act()]
        layers += [nn.Linear(dims[-1], output_dim), make_out_act()]

        self._layers = nn.Sequential(*layers)

    def forward(self, x):
        return self._layers(x)

    @property
    def num_parameters(self):
        ## count only trainable parameters
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def __repr__(self):
        return '[ {} : {} parameters ] {}'.format(
            self.name, self.num_parameters,
            super().__repr__())
    def forward(self, x, layer_past=None):
        """
        Causal multi-head self-attention over a trajectory token sequence.

        x : [ B x T x C ] token embeddings (C = n_embd)
        layer_past : unused here; kept for interface compatibility with
            GPT-style attention implementations that cache past keys/values
        returns : [ B x T x C ]
        """
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        ## [ B x n_heads x T x head_dim ]
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        ## [ B x n_heads x T x T ] scaled dot-product attention scores
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        ## `self.mask` (built in __init__) is causal AND has columns zeroed at
        ## value-token positions, so no position attends to past value estimates
        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        ## stash the pre-dropout attention map for visualization / debugging
        self._attn_map = att.clone()
        att = self.attn_drop(att)
        ## [ B x n_heads x T x head_size ]
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        ## [ B x T x embedding_dim ]
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y
    def configure_optimizers(self, train_config):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.

        train_config : object providing `weight_decay`, `learning_rate`, `betas`
        returns : a configured `torch.optim.AdamW`
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, EinLinear)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        # (it is a bare nn.Parameter, so it matches none of the module-type rules above)
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer
    def forward(self, idx, targets=None, mask=None):
        """
        idx : [ B x T ] discretized trajectory tokens
            (per-dimension bin indices in [0, vocab_size])
        targets : [ B x T ] optional next-token targets for the loss
        mask : [ B x T ] loss mask (1 = contribute to the loss);
            required whenever `targets` is given
        returns : (logits [ B x T x (vocab_size + 1) ], loss or None)
        """
        b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."

        ## shift each token by a per-dimension offset so every transition
        ## dimension uses its own slice of the shared embedding table
        offset_idx = self.offset_tokens(idx)
        ## [ B x T x embedding_dim ]
        # forward the GPT model
        token_embeddings = self.tok_emb(offset_idx) # each index maps to a (learnable) vector
        ## [ 1 x T x embedding_dim ]
        position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
        ## [ B x T x embedding_dim ]
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        ## [ B x T x embedding_dim ]
        x = self.ln_f(x)

        ## pad so the sequence length is a multiple of transition_dim, letting
        ## the EinLinear head apply one linear model per transition dimension
        ## [ (B * T' / transition_dim) x transition_dim x embedding_dim ]
        x_pad, n_pad = self.pad_to_full_observation(x)
        ## [ (B * T' / transition_dim) x transition_dim x (vocab_size + 1) ]
        logits = self.head(x_pad)
        ## [ B x T' x (vocab_size + 1) ]
        logits = logits.reshape(b, t + n_pad, self.vocab_size + 1)
        ## [ B x T x (vocab_size + 1) ] drop the padded positions
        logits = logits[:,:t]

        # if we are given some desired targets also calculate the loss
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1), reduction='none')
            if self.action_weight != 1 or self.reward_weight != 1 or self.value_weight != 1:
                #### make weights
                n_states = int(np.ceil(t / self.transition_dim))
                ## per-dimension loss weights laid out as [ obs | action | reward | value ]
                weights = torch.cat([
                    torch.ones(self.observation_dim, device=idx.device),
                    torch.ones(self.action_dim, device=idx.device) * self.action_weight,
                    torch.ones(1, device=idx.device) * self.reward_weight,
                    torch.ones(1, device=idx.device) * self.value_weight,
                ])
                ## [ t + 1]
                weights = weights.repeat(n_states)
                ## [ b x t ] drop the first weight: targets are inputs shifted by one token
                weights = weights[1:].repeat(b, 1)
                ####
                loss = loss * weights.view(-1)
            ## NOTE(review): masked-out positions still count toward the mean's
            ## denominator (they contribute zeros) — presumably intentional
            loss = (loss * mask.view(-1)).mean()
        else:
            loss = None

        return logits, loss
## [ B x T x (vocab_size + 1) ] 303 | logits = logits[:,:t] 304 | 305 | # if we are given some desired targets also calculate the loss 306 | if targets is not None: 307 | loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1), reduction='none') 308 | if self.action_weight != 1 or self.reward_weight != 1 or self.value_weight != 1: 309 | #### make weights 310 | n_states = int(np.ceil(t / self.transition_dim)) 311 | weights = torch.cat([ 312 | torch.ones(self.observation_dim, device=idx.device), 313 | torch.ones(self.action_dim, device=idx.device) * self.action_weight, 314 | torch.ones(1, device=idx.device) * self.reward_weight, 315 | torch.ones(1, device=idx.device) * self.value_weight, 316 | ]) 317 | ## [ t + 1] 318 | weights = weights.repeat(n_states) 319 | ## [ b x t ] 320 | weights = weights[1:].repeat(b, 1) 321 | #### 322 | loss = loss * weights.view(-1) 323 | loss = (loss * mask.view(-1)).mean() 324 | else: 325 | loss = None 326 | 327 | return logits, loss 328 | -------------------------------------------------------------------------------- /trajectory/search/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | from .utils import * -------------------------------------------------------------------------------- /trajectory/search/core.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pdb 4 | 5 | from .. 
import utils 6 | from .sampling import sample_n, get_logp, sort_2d 7 | 8 | REWARD_DIM = VALUE_DIM = 1 9 | 10 | @torch.no_grad() 11 | def beam_plan( 12 | model, value_fn, x, 13 | n_steps, beam_width, n_expand, 14 | observation_dim, action_dim, 15 | discount=0.99, max_context_transitions=None, 16 | k_obs=None, k_act=None, k_rew=1, 17 | cdf_obs=None, cdf_act=None, cdf_rew=None, 18 | verbose=True, previous_actions=None, 19 | ): 20 | ''' 21 | x : tensor[ 1 x input_sequence_length ] 22 | ''' 23 | 24 | inp = x.clone() 25 | 26 | # convert max number of transitions to max number of tokens 27 | transition_dim = observation_dim + action_dim + REWARD_DIM + VALUE_DIM 28 | max_block = max_context_transitions * transition_dim - 1 if max_context_transitions else None 29 | 30 | ## pass in max numer of tokens to sample function 31 | sample_kwargs = { 32 | 'max_block': max_block, 33 | 'crop_increment': transition_dim, 34 | } 35 | 36 | ## repeat input for search 37 | x = x.repeat(beam_width, 1) 38 | 39 | ## construct reward and discount tensors for estimating values 40 | rewards = torch.zeros(beam_width, n_steps + 1, device=x.device) 41 | discounts = discount ** torch.arange(n_steps + 1, device=x.device) 42 | 43 | ## logging 44 | progress = utils.Progress(n_steps) if verbose else utils.Silent() 45 | 46 | for t in range(n_steps): 47 | ## repeat everything by `n_expand` before we sample actions 48 | x = x.repeat(n_expand, 1) 49 | rewards = rewards.repeat(n_expand, 1) 50 | 51 | ## sample actions 52 | x, _ = sample_n(model, x, action_dim, topk=k_act, cdf=cdf_act, **sample_kwargs) 53 | 54 | ## sample reward and value estimate 55 | x, r_probs = sample_n(model, x, REWARD_DIM + VALUE_DIM, topk=k_rew, cdf=cdf_rew, **sample_kwargs) 56 | 57 | ## optionally, use a percentile or mean of the reward and 58 | ## value distributions instead of sampled tokens 59 | r_t, V_t = value_fn(r_probs) 60 | 61 | ## update rewards tensor 62 | rewards[:, t] = r_t 63 | rewards[:, t+1] = V_t 64 | 65 | ## estimate 
values using rewards up to `t` and terminal value at `t` 66 | values = (rewards * discounts).sum(dim=-1) 67 | 68 | ## get `beam_width` best actions 69 | values, inds = torch.topk(values, beam_width) 70 | 71 | ## index into search candidates to retain `beam_width` highest-reward sequences 72 | x = x[inds] 73 | rewards = rewards[inds] 74 | 75 | ## sample next observation (unless we have reached the end of the planning horizon) 76 | if t < n_steps - 1: 77 | x, _ = sample_n(model, x, observation_dim, topk=k_obs, cdf=cdf_obs, **sample_kwargs) 78 | 79 | ## logging 80 | progress.update({ 81 | 'x': list(x.shape), 82 | 'vmin': values.min(), 'vmax': values.max(), 83 | 'vtmin': V_t.min(), 'vtmax': V_t.max(), 84 | 'discount': discount 85 | }) 86 | 87 | progress.stamp() 88 | 89 | ## [ batch_size x (n_context + n_steps) x transition_dim ] 90 | x = x.view(beam_width, -1, transition_dim) 91 | 92 | ## crop out context transitions 93 | ## [ batch_size x n_steps x transition_dim ] 94 | x = x[:, -n_steps:] 95 | 96 | ## return best sequence 97 | argmax = values.argmax() 98 | best_sequence = x[argmax] 99 | 100 | return best_sequence 101 | 102 | @torch.no_grad() 103 | def beam_search(model, x, n_steps, beam_width=512, goal=None, **sample_kwargs): 104 | batch_size = len(x) 105 | 106 | prefix_i = torch.arange(len(x), dtype=torch.long, device=x.device) 107 | cumulative_logp = torch.zeros(batch_size, 1, device=x.device) 108 | 109 | for t in range(n_steps): 110 | 111 | if goal is not None: 112 | goal_rep = goal.repeat(len(x), 1) 113 | logp = get_logp(model, x, goal=goal_rep, **sample_kwargs) 114 | else: 115 | logp = get_logp(model, x, **sample_kwargs) 116 | 117 | candidate_logp = cumulative_logp + logp 118 | sorted_logp, sorted_i, sorted_j = sort_2d(candidate_logp) 119 | 120 | n_candidates = (candidate_logp > -np.inf).sum().item() 121 | n_retain = min(n_candidates, beam_width) 122 | cumulative_logp = sorted_logp[:n_retain].unsqueeze(-1) 123 | 124 | sorted_i = sorted_i[:n_retain] 125 | 
sorted_j = sorted_j[:n_retain].unsqueeze(-1) 126 | 127 | x = torch.cat([x[sorted_i], sorted_j], dim=-1) 128 | prefix_i = prefix_i[sorted_i] 129 | 130 | x = x[0] 131 | return x, cumulative_logp.squeeze() 132 | -------------------------------------------------------------------------------- /trajectory/search/sampling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pdb 4 | 5 | #-------------------------------- helper functions --------------------------------# 6 | 7 | def top_k_logits(logits, k): 8 | v, ix = torch.topk(logits, k) 9 | out = logits.clone() 10 | out[out < v[:, [-1]]] = -float('Inf') 11 | return out 12 | 13 | def filter_cdf(logits, threshold): 14 | batch_inds = torch.arange(logits.shape[0], device=logits.device, dtype=torch.long) 15 | bins_inds = torch.arange(logits.shape[-1], device=logits.device) 16 | probs = logits.softmax(dim=-1) 17 | probs_sorted, _ = torch.sort(probs, dim=-1) 18 | probs_cum = torch.cumsum(probs_sorted, dim=-1) 19 | ## get minimum probability p such that the cdf up to p is at least `threshold` 20 | mask = probs_cum < threshold 21 | masked_inds = torch.argmax(mask * bins_inds, dim=-1) 22 | probs_threshold = probs_sorted[batch_inds, masked_inds] 23 | ## filter 24 | out = logits.clone() 25 | logits_mask = probs <= probs_threshold.unsqueeze(dim=-1) 26 | out[logits_mask] = -1000 27 | return out 28 | 29 | def round_to_multiple(x, N): 30 | ''' 31 | Rounds `x` up to nearest multiple of `N`. 
32 | 33 | x : int 34 | N : int 35 | ''' 36 | pad = (N - x % N) % N 37 | return x + pad 38 | 39 | def sort_2d(x): 40 | ''' 41 | x : [ M x N ] 42 | ''' 43 | M, N = x.shape 44 | x = x.view(-1) 45 | x_sort, inds = torch.sort(x, descending=True) 46 | 47 | rows = inds // N 48 | cols = inds % N 49 | 50 | return x_sort, rows, cols 51 | 52 | #-------------------------------- forward pass --------------------------------# 53 | 54 | def forward(model, x, max_block=None, allow_crop=True, crop_increment=None, **kwargs): 55 | ''' 56 | A wrapper around a single forward pass of the transformer. 57 | Crops the input if the sequence is too long. 58 | 59 | x : tensor[ batch_size x sequence_length ] 60 | ''' 61 | model.eval() 62 | 63 | block_size = min(model.get_block_size(), max_block or np.inf) 64 | 65 | if x.shape[1] > block_size: 66 | assert allow_crop, ( 67 | f'[ search/sampling ] input size is {x.shape} and block size is {block_size}, ' 68 | 'but cropping not allowed') 69 | 70 | ## crop out entire transition at a time so that the first token is always s_t^0 71 | n_crop = round_to_multiple(x.shape[1] - block_size, crop_increment) 72 | assert n_crop % crop_increment == 0 73 | x = x[:, n_crop:] 74 | 75 | logits, _ = model(x, **kwargs) 76 | 77 | return logits 78 | 79 | def get_logp(model, x, temperature=1.0, topk=None, cdf=None, **forward_kwargs): 80 | ''' 81 | x : tensor[ batch_size x sequence_length ] 82 | ''' 83 | ## [ batch_size x sequence_length x vocab_size ] 84 | logits = forward(model, x, **forward_kwargs) 85 | 86 | ## pluck the logits at the final step and scale by temperature 87 | ## [ batch_size x vocab_size ] 88 | logits = logits[:, -1] / temperature 89 | 90 | ## optionally crop logits to only the top `1 - cdf` percentile 91 | if cdf is not None: 92 | logits = filter_cdf(logits, cdf) 93 | 94 | ## optionally crop logits to only the most likely `k` options 95 | if topk is not None: 96 | logits = top_k_logits(logits, topk) 97 | 98 | ## apply softmax to convert to 
probabilities 99 | logp = logits.log_softmax(dim=-1) 100 | 101 | return logp 102 | 103 | #-------------------------------- sampling --------------------------------# 104 | 105 | def sample(model, x, temperature=1.0, topk=None, cdf=None, **forward_kwargs): 106 | ''' 107 | Samples from the distribution parameterized by `model(x)`. 108 | 109 | x : tensor[ batch_size x sequence_length ] 110 | ''' 111 | ## [ batch_size x sequence_length x vocab_size ] 112 | logits = forward(model, x, **forward_kwargs) 113 | 114 | ## pluck the logits at the final step and scale by temperature 115 | ## [ batch_size x vocab_size ] 116 | logits = logits[:, -1] / temperature 117 | 118 | ## keep track of probabilities before modifying logits 119 | raw_probs = logits.softmax(dim=-1) 120 | 121 | ## optionally crop logits to only the top `1 - cdf` percentile 122 | if cdf is not None: 123 | logits = filter_cdf(logits, cdf) 124 | 125 | ## optionally crop logits to only the most likely `k` options 126 | if topk is not None: 127 | logits = top_k_logits(logits, topk) 128 | 129 | ## apply softmax to convert to probabilities 130 | probs = logits.softmax(dim=-1) 131 | 132 | ## sample from the distribution 133 | ## [ batch_size x 1 ] 134 | indices = torch.multinomial(probs, num_samples=1) 135 | 136 | return indices, raw_probs 137 | 138 | @torch.no_grad() 139 | def sample_n(model, x, N, **sample_kwargs): 140 | batch_size = len(x) 141 | 142 | ## keep track of probabilities from each step; 143 | ## `vocab_size + 1` accounts for termination token 144 | probs = torch.zeros(batch_size, N, model.vocab_size + 1, device=x.device) 145 | 146 | for n in range(N): 147 | indices, p = sample(model, x, **sample_kwargs) 148 | 149 | ## append to the sequence and continue 150 | ## [ batch_size x (sequence_length + n) ] 151 | x = torch.cat((x, indices), dim=1) 152 | 153 | probs[:, n] = p 154 | 155 | return x, probs 156 | -------------------------------------------------------------------------------- 
/trajectory/search/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pdb 4 | 5 | from ..utils.arrays import to_torch 6 | 7 | VALUE_PLACEHOLDER = 1e6 8 | 9 | def make_prefix(discretizer, context, obs, prefix_context=True): 10 | observation_dim = obs.size 11 | obs_discrete = discretizer.discretize(obs, subslice=[0, observation_dim]) 12 | obs_discrete = to_torch(obs_discrete, dtype=torch.long) 13 | 14 | if prefix_context: 15 | prefix = torch.cat(context + [obs_discrete], dim=-1) 16 | else: 17 | prefix = obs_discrete 18 | 19 | return prefix 20 | 21 | def extract_actions(x, observation_dim, action_dim, t=None): 22 | assert x.shape[1] == observation_dim + action_dim + 2 23 | actions = x[:, observation_dim:observation_dim+action_dim] 24 | if t is not None: 25 | return actions[t] 26 | else: 27 | return actions 28 | 29 | def update_context(context, discretizer, observation, action, reward, max_context_transitions): 30 | ''' 31 | context : list of transitions 32 | [ tensor( transition_dim ), ... 
] 33 | ''' 34 | ## use a placeholder for value because input values are masked out by model 35 | rew_val = np.array([reward, VALUE_PLACEHOLDER]) 36 | transition = np.concatenate([observation, action, rew_val]) 37 | 38 | ## discretize transition and convert to torch tensor 39 | transition_discrete = discretizer.discretize(transition) 40 | transition_discrete = to_torch(transition_discrete, dtype=torch.long) 41 | 42 | ## add new transition to context 43 | context.append(transition_discrete) 44 | 45 | ## crop context if necessary 46 | context = context[-max_context_transitions:] 47 | 48 | return context -------------------------------------------------------------------------------- /trajectory/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .setup import Parser, watch 2 | from .arrays import * 3 | from .serialization import * 4 | from .progress import Progress, Silent 5 | from .rendering import make_renderer 6 | # from .video import * 7 | from .config import Config 8 | from .training import Trainer 9 | -------------------------------------------------------------------------------- /trajectory/utils/arrays.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | DTYPE = torch.float 5 | DEVICE = 'cuda:0' 6 | 7 | def to_np(x): 8 | if torch.is_tensor(x): 9 | x = x.detach().cpu().numpy() 10 | return x 11 | 12 | def to_torch(x, dtype=None, device=None): 13 | dtype = dtype or DTYPE 14 | device = device or DEVICE 15 | return torch.tensor(x, dtype=dtype, device=device) 16 | 17 | def to_device(*xs, device=DEVICE): 18 | return [x.to(device) for x in xs] 19 | 20 | def normalize(x): 21 | """ 22 | scales `x` to [0, 1] 23 | """ 24 | x = x - x.min() 25 | x = x / x.max() 26 | return x 27 | 28 | def to_img(x): 29 | normalized = normalize(x) 30 | array = to_np(normalized) 31 | array = np.transpose(array, (1,2,0)) 32 | return (array * 
255).astype(np.uint8) 33 | 34 | def set_device(device): 35 | DEVICE = device 36 | if 'cuda' in device: 37 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 38 | -------------------------------------------------------------------------------- /trajectory/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import pickle 4 | 5 | class Config(collections.Mapping): 6 | 7 | def __init__(self, _class, verbose=True, savepath=None, **kwargs): 8 | self._class = _class 9 | self._dict = {} 10 | 11 | for key, val in kwargs.items(): 12 | self._dict[key] = val 13 | 14 | if verbose: 15 | print(self) 16 | 17 | if savepath is not None: 18 | savepath = os.path.join(*savepath) if type(savepath) is tuple else savepath 19 | pickle.dump(self, open(savepath, 'wb')) 20 | print(f'Saved config to: {savepath}\n') 21 | 22 | def __repr__(self): 23 | string = f'\nConfig: {self._class}\n' 24 | for key in sorted(self._dict.keys()): 25 | val = self._dict[key] 26 | string += f' {key}: {val}\n' 27 | return string 28 | 29 | def __iter__(self): 30 | return iter(self._dict) 31 | 32 | def __getitem__(self, item): 33 | return self._dict[item] 34 | 35 | def __len__(self): 36 | return len(self._dict) 37 | 38 | def __call__(self): 39 | return self.make() 40 | 41 | def __getattr__(self, attr): 42 | if attr == '_dict' and '_dict' not in vars(self): 43 | self._dict = {} 44 | try: 45 | return self._dict[attr] 46 | except KeyError: 47 | raise AttributeError(attr) 48 | 49 | def make(self): 50 | if 'GPT' in str(self._class) or 'Trainer' in str(self._class): 51 | ## GPT class expects the config as the sole input 52 | return self._class(self) 53 | else: 54 | return self._class(**self._dict) 55 | -------------------------------------------------------------------------------- /trajectory/utils/discretization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import 
torch 3 | import pdb 4 | 5 | from .arrays import to_np, to_torch 6 | 7 | class QuantileDiscretizer: 8 | 9 | def __init__(self, data, N): 10 | self.data = data 11 | self.N = N 12 | 13 | n_points_per_bin = int(np.ceil(len(data) / N)) 14 | obs_sorted = np.sort(data, axis=0) 15 | thresholds = obs_sorted[::n_points_per_bin, :] 16 | maxs = data.max(axis=0, keepdims=True) 17 | 18 | ## [ (N + 1) x dim ] 19 | self.thresholds = np.concatenate([thresholds, maxs], axis=0) 20 | 21 | # threshold_inds = np.linspace(0, len(data) - 1, N + 1, dtype=int) 22 | # obs_sorted = np.sort(data, axis=0) 23 | 24 | # ## [ (N + 1) x dim ] 25 | # self.thresholds = obs_sorted[threshold_inds, :] 26 | 27 | ## [ N x dim ] 28 | self.diffs = self.thresholds[1:] - self.thresholds[:-1] 29 | 30 | ## for sparse reward tasks 31 | # if (self.diffs[:,-1] == 0).any(): 32 | # raise RuntimeError('rebin for sparse reward tasks') 33 | 34 | self._test() 35 | 36 | def __call__(self, x): 37 | indices = self.discretize(x) 38 | recon = self.reconstruct(indices) 39 | error = np.abs(recon - x).max(0) 40 | return indices, recon, error 41 | 42 | def _test(self): 43 | print('[ utils/discretization ] Testing...', end=' ', flush=True) 44 | inds = np.random.randint(0, len(self.data), size=1000) 45 | X = self.data[inds] 46 | indices = self.discretize(X) 47 | recon = self.reconstruct(indices) 48 | ## make sure reconstruction error is less than the max allowed per dimension 49 | error = np.abs(X - recon).max(0) 50 | assert (error <= self.diffs.max(axis=0)).all() 51 | ## re-discretize reconstruction and make sure it is the same as original indices 52 | indices_2 = self.discretize(recon) 53 | assert (indices == indices_2).all() 54 | ## reconstruct random indices 55 | ## @TODO: remove duplicate thresholds 56 | # randint = np.random.randint(0, self.N, indices.shape) 57 | # randint_2 = self.discretize(self.reconstruct(randint)) 58 | # assert (randint == randint_2).all() 59 | print('✓') 60 | 61 | def discretize(self, x, 
subslice=(None, None)): 62 | ''' 63 | x : [ B x observation_dim ] 64 | ''' 65 | 66 | if torch.is_tensor(x): 67 | x = to_np(x) 68 | 69 | ## enforce batch mode 70 | if x.ndim == 1: 71 | x = x[None] 72 | 73 | ## [ N x B x observation_dim ] 74 | start, end = subslice 75 | thresholds = self.thresholds[:, start:end] 76 | 77 | gt = x[None] >= thresholds[:,None] 78 | indices = largest_nonzero_index(gt, dim=0) 79 | 80 | if indices.min() < 0 or indices.max() >= self.N: 81 | indices = np.clip(indices, 0, self.N - 1) 82 | 83 | return indices 84 | 85 | def reconstruct(self, indices, subslice=(None, None)): 86 | 87 | if torch.is_tensor(indices): 88 | indices = to_np(indices) 89 | 90 | ## enforce batch mode 91 | if indices.ndim == 1: 92 | indices = indices[None] 93 | 94 | if indices.min() < 0 or indices.max() >= self.N: 95 | print(f'[ utils/discretization ] indices out of range: ({indices.min()}, {indices.max()}) | N: {self.N}') 96 | indices = np.clip(indices, 0, self.N - 1) 97 | 98 | start, end = subslice 99 | thresholds = self.thresholds[:, start:end] 100 | 101 | left = np.take_along_axis(thresholds, indices, axis=0) 102 | right = np.take_along_axis(thresholds, indices + 1, axis=0) 103 | recon = (left + right) / 2. 104 | return recon 105 | 106 | #---------------------------- wrappers for planning ----------------------------# 107 | 108 | def expectation(self, probs, subslice): 109 | ''' 110 | probs : [ B x N ] 111 | ''' 112 | 113 | if torch.is_tensor(probs): 114 | probs = to_np(probs) 115 | 116 | ## [ N ] 117 | thresholds = self.thresholds[:, subslice] 118 | ## [ B ] 119 | left = probs @ thresholds[:-1] 120 | right = probs @ thresholds[1:] 121 | 122 | avg = (left + right) / 2. 123 | return avg 124 | 125 | def percentile(self, probs, percentile, subslice): 126 | ''' 127 | percentile `p` : 128 | returns least value `v` s.t. 
cdf up to `v` is >= `p` 129 | e.g., p=0.8 and v=100 indicates that 130 | 100 is in the 80% percentile of values 131 | ''' 132 | ## [ N ] 133 | thresholds = self.thresholds[:, subslice] 134 | ## [ B x N ] 135 | cumulative = np.cumsum(probs, axis=-1) 136 | valid = cumulative > percentile 137 | ## [ B ] 138 | inds = np.argmax(np.arange(self.N, 0, -1) * valid, axis=-1) 139 | left = thresholds[inds-1] 140 | right = thresholds[inds] 141 | avg = (left + right) / 2. 142 | return avg 143 | 144 | #---------------------------- wrappers for planning ----------------------------# 145 | 146 | def value_expectation(self, probs): 147 | ''' 148 | probs : [ B x 2 x ( N + 1 ) ] 149 | extra token comes from termination 150 | ''' 151 | 152 | if torch.is_tensor(probs): 153 | probs = to_np(probs) 154 | return_torch = True 155 | else: 156 | return_torch = False 157 | 158 | probs = probs[:, :, :-1] 159 | assert probs.shape[-1] == self.N 160 | 161 | rewards = self.expectation(probs[:, 0], subslice=-2) 162 | next_values = self.expectation(probs[:, 1], subslice=-1) 163 | 164 | if return_torch: 165 | rewards = to_torch(rewards) 166 | next_values = to_torch(next_values) 167 | 168 | return rewards, next_values 169 | 170 | def value_fn(self, probs, percentile): 171 | if percentile == 'mean': 172 | return self.value_expectation(probs) 173 | else: 174 | ## percentile should be interpretable as float, 175 | ## even if passed in as str because of command-line parser 176 | percentile = float(percentile) 177 | 178 | if torch.is_tensor(probs): 179 | probs = to_np(probs) 180 | return_torch = True 181 | else: 182 | return_torch = False 183 | 184 | probs = probs[:, :, :-1] 185 | assert probs.shape[-1] == self.N 186 | 187 | rewards = self.percentile(probs[:, 0], percentile, subslice=-2) 188 | next_values = self.percentile(probs[:, 1], percentile, subslice=-1) 189 | 190 | if return_torch: 191 | rewards = to_torch(rewards) 192 | next_values = to_torch(next_values) 193 | 194 | return rewards, next_values 195 | 
196 | def largest_nonzero_index(x, dim): 197 | N = x.shape[dim] 198 | arange = np.arange(N) + 1 199 | 200 | for i in range(dim): 201 | arange = np.expand_dims(arange, axis=0) 202 | for i in range(dim+1, x.ndim): 203 | arange = np.expand_dims(arange, axis=-1) 204 | 205 | inds = np.argmax(x * arange, axis=0) 206 | ## masks for all `False` or all `True` 207 | lt_mask = (~x).all(axis=0) 208 | gt_mask = (x).all(axis=0) 209 | 210 | inds[lt_mask] = 0 211 | inds[gt_mask] = N 212 | 213 | return inds 214 | -------------------------------------------------------------------------------- /trajectory/utils/git_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import git 3 | import pdb 4 | 5 | PROJECT_PATH = os.path.dirname( 6 | os.path.realpath(os.path.join(__file__, '..', '..'))) 7 | 8 | def get_repo(path=PROJECT_PATH, search_parent_directories=True): 9 | repo = git.Repo( 10 | path, search_parent_directories=search_parent_directories) 11 | return repo 12 | 13 | def get_git_rev(*args, **kwargs): 14 | try: 15 | repo = get_repo(*args, **kwargs) 16 | if repo.head.is_detached: 17 | git_rev = repo.head.object.name_rev 18 | else: 19 | git_rev = repo.active_branch.commit.name_rev 20 | except: 21 | git_rev = None 22 | 23 | return git_rev 24 | 25 | def git_diff(*args, **kwargs): 26 | repo = get_repo(*args, **kwargs) 27 | diff = repo.git.diff() 28 | return diff 29 | 30 | def save_git_diff(savepath, *args, **kwargs): 31 | diff = git_diff(*args, **kwargs) 32 | with open(savepath, 'w') as f: 33 | f.write(diff) 34 | 35 | if __name__ == '__main__': 36 | 37 | git_rev = get_git_rev() 38 | print(git_rev) 39 | 40 | save_git_diff('diff_test.txt') -------------------------------------------------------------------------------- /trajectory/utils/progress.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import re 4 | import pdb 5 | 6 | class Progress: 7 | 8 | def 
__init__(self, total, name = 'Progress', ncol=3, max_length=30, indent=8, line_width=100, speed_update_freq=100): 9 | self.total = total 10 | self.name = name 11 | self.ncol = ncol 12 | self.max_length = max_length 13 | self.indent = indent 14 | self.line_width = line_width 15 | self._speed_update_freq = speed_update_freq 16 | 17 | self._step = 0 18 | self._prev_line = '\033[F' 19 | self._clear_line = ' ' * self.line_width 20 | 21 | self._pbar_size = self.ncol * self.max_length 22 | self._complete_pbar = '#' * self._pbar_size 23 | self._incomplete_pbar = ' ' * self._pbar_size 24 | 25 | self.lines = [''] 26 | self.fraction = '{} / {}'.format(0, self.total) 27 | 28 | self.resume() 29 | 30 | 31 | def update(self, description, n=1): 32 | self._step += n 33 | if self._step % self._speed_update_freq == 0: 34 | self._time0 = time.time() 35 | self._step0 = self._step 36 | self.set_description(description) 37 | 38 | def resume(self): 39 | self._skip_lines = 1 40 | print('\n', end='') 41 | self._time0 = time.time() 42 | self._step0 = self._step 43 | 44 | def pause(self): 45 | self._clear() 46 | self._skip_lines = 1 47 | 48 | def set_description(self, params=[]): 49 | 50 | if type(params) == dict: 51 | params = sorted([ 52 | (key, val) 53 | for key, val in params.items() 54 | ]) 55 | 56 | ############ 57 | # Position # 58 | ############ 59 | self._clear() 60 | 61 | ########### 62 | # Percent # 63 | ########### 64 | percent, fraction = self._format_percent(self._step, self.total) 65 | self.fraction = fraction 66 | 67 | ######### 68 | # Speed # 69 | ######### 70 | speed = self._format_speed(self._step) 71 | 72 | ########## 73 | # Params # 74 | ########## 75 | num_params = len(params) 76 | nrow = math.ceil(num_params / self.ncol) 77 | params_split = self._chunk(params, self.ncol) 78 | params_string, lines = self._format(params_split) 79 | self.lines = lines 80 | 81 | 82 | description = '{} | {}{}'.format(percent, speed, params_string) 83 | print(description) 84 | 
self._skip_lines = nrow + 1 85 | 86 | def append_description(self, descr): 87 | self.lines.append(descr) 88 | 89 | def _clear(self): 90 | position = self._prev_line * self._skip_lines 91 | empty = '\n'.join([self._clear_line for _ in range(self._skip_lines)]) 92 | print(position, end='') 93 | print(empty) 94 | print(position, end='') 95 | 96 | def _format_percent(self, n, total): 97 | if total: 98 | percent = n / float(total) 99 | 100 | complete_entries = int(percent * self._pbar_size) 101 | incomplete_entries = self._pbar_size - complete_entries 102 | 103 | pbar = self._complete_pbar[:complete_entries] + self._incomplete_pbar[:incomplete_entries] 104 | fraction = '{} / {}'.format(n, total) 105 | string = '{} [{}] {:3d}%'.format(fraction, pbar, int(percent*100)) 106 | else: 107 | fraction = '{}'.format(n) 108 | string = '{} iterations'.format(n) 109 | return string, fraction 110 | 111 | def _format_speed(self, n): 112 | num_steps = n - self._step0 113 | t = time.time() - self._time0 114 | speed = num_steps / t 115 | string = '{:.1f} Hz'.format(speed) 116 | if num_steps > 0: 117 | self._speed = string 118 | return string 119 | 120 | def _chunk(self, l, n): 121 | return [l[i:i+n] for i in range(0, len(l), n)] 122 | 123 | def _format(self, chunks): 124 | lines = [self._format_chunk(chunk) for chunk in chunks] 125 | lines.insert(0,'') 126 | padding = '\n' + ' '*self.indent 127 | string = padding.join(lines) 128 | return string, lines 129 | 130 | def _format_chunk(self, chunk): 131 | line = ' | '.join([self._format_param(param) for param in chunk]) 132 | return line 133 | 134 | def _format_param(self, param, str_length=8): 135 | k, v = param 136 | k = k.rjust(str_length) 137 | if type(v) == float or hasattr(v, 'item'): 138 | string = '{}: {:12.4f}' 139 | else: 140 | string = '{}: {}' 141 | v = str(v).rjust(12) 142 | return string.format(k, v)[:self.max_length] 143 | 144 | def stamp(self): 145 | if self.lines != ['']: 146 | params = ' | '.join(self.lines) 147 | string = 
'[ {} ] {}{} | {}'.format(self.name, self.fraction, params, self._speed) 148 | string = re.sub(r'\s+', ' ', string) 149 | self._clear() 150 | print(string, end='\n') 151 | self._skip_lines = 1 152 | else: 153 | self._clear() 154 | self._skip_lines = 0 155 | 156 | def close(self): 157 | self.pause() 158 | 159 | class Silent: 160 | 161 | def __init__(self, *args, **kwargs): 162 | pass 163 | 164 | def __getattr__(self, attr): 165 | return lambda *args: None 166 | 167 | 168 | if __name__ == '__main__': 169 | silent = Silent() 170 | silent.update() 171 | silent.stamp() 172 | 173 | num_steps = 1000 174 | progress = Progress(num_steps) 175 | for i in range(num_steps): 176 | progress.update() 177 | params = [ 178 | ['A', '{:06d}'.format(i)], 179 | ['B', '{:06d}'.format(i)], 180 | ['C', '{:06d}'.format(i)], 181 | ['D', '{:06d}'.format(i)], 182 | ['E', '{:06d}'.format(i)], 183 | ['F', '{:06d}'.format(i)], 184 | ['G', '{:06d}'.format(i)], 185 | ['H', '{:06d}'.format(i)], 186 | ] 187 | progress.set_description(params) 188 | time.sleep(0.01) 189 | progress.close() 190 | -------------------------------------------------------------------------------- /trajectory/utils/rendering.py: -------------------------------------------------------------------------------- 1 | import time 2 | import sys 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import torch 6 | import gym 7 | import mujoco_py as mjc 8 | import pdb 9 | 10 | from .arrays import to_np 11 | from .video import save_video, save_videos 12 | from ..datasets import load_environment, get_preprocess_fn 13 | 14 | def make_renderer(args): 15 | render_str = getattr(args, 'renderer') 16 | render_class = getattr(sys.modules[__name__], render_str) 17 | ## get dimensions in case the observations are preprocessed 18 | env = load_environment(args.dataset) 19 | preprocess_fn = get_preprocess_fn(args.dataset) 20 | observation = env.reset() 21 | observation = preprocess_fn(observation) 22 | return render_class(args.dataset, 
observation_dim=observation.size) 23 | 24 | def split(sequence, observation_dim, action_dim): 25 | assert sequence.shape[1] == observation_dim + action_dim + 2 26 | observations = sequence[:, :observation_dim] 27 | actions = sequence[:, observation_dim:observation_dim+action_dim] 28 | rewards = sequence[:, -2] 29 | values = sequence[:, -1] 30 | return observations, actions, rewards, values 31 | 32 | def set_state(env, state): 33 | qpos_dim = env.sim.data.qpos.size 34 | qvel_dim = env.sim.data.qvel.size 35 | qstate_dim = qpos_dim + qvel_dim 36 | 37 | if 'ant' in env.name: 38 | ypos = np.zeros(1) 39 | state = np.concatenate([ypos, state]) 40 | 41 | if state.size == qpos_dim - 1 or state.size == qstate_dim - 1: 42 | xpos = np.zeros(1) 43 | state = np.concatenate([xpos, state]) 44 | 45 | if state.size == qpos_dim: 46 | qvel = np.zeros(qvel_dim) 47 | state = np.concatenate([state, qvel]) 48 | 49 | if 'ant' in env.name and state.size > qpos_dim + qvel_dim: 50 | xpos = np.zeros(1) 51 | state = np.concatenate([xpos, state])[:qstate_dim] 52 | 53 | assert state.size == qpos_dim + qvel_dim 54 | 55 | env.set_state(state[:qpos_dim], state[qpos_dim:]) 56 | 57 | def rollout_from_state(env, state, actions): 58 | qpos_dim = env.sim.data.qpos.size 59 | env.set_state(state[:qpos_dim], state[qpos_dim:]) 60 | observations = [env._get_obs()] 61 | for act in actions: 62 | obs, rew, term, _ = env.step(act) 63 | observations.append(obs) 64 | if term: 65 | break 66 | for i in range(len(observations), len(actions)+1): 67 | ## if terminated early, pad with zeros 68 | observations.append( np.zeros(obs.size) ) 69 | return np.stack(observations) 70 | 71 | class DebugRenderer: 72 | 73 | def __init__(self, *args, **kwargs): 74 | pass 75 | 76 | def render(self, *args, **kwargs): 77 | return np.zeros((10, 10, 3)) 78 | 79 | def render_plan(self, *args, **kwargs): 80 | pass 81 | 82 | def render_rollout(self, *args, **kwargs): 83 | pass 84 | 85 | class Renderer: 86 | 87 | def __init__(self, env, 
observation_dim=None, action_dim=None): 88 | if type(env) is str: 89 | self.env = load_environment(env) 90 | else: 91 | self.env = env 92 | 93 | self.observation_dim = observation_dim or np.prod(self.env.observation_space.shape) 94 | self.action_dim = action_dim or np.prod(self.env.action_space.shape) 95 | self.viewer = mjc.MjRenderContextOffscreen(self.env.sim) 96 | 97 | def __call__(self, *args, **kwargs): 98 | return self.renders(*args, **kwargs) 99 | 100 | def render(self, observation, dim=256, render_kwargs=None): 101 | observation = to_np(observation) 102 | 103 | if render_kwargs is None: 104 | render_kwargs = { 105 | 'trackbodyid': 2, 106 | 'distance': 3, 107 | 'lookat': [0, -0.5, 1], 108 | 'elevation': -20 109 | } 110 | 111 | for key, val in render_kwargs.items(): 112 | if key == 'lookat': 113 | self.viewer.cam.lookat[:] = val[:] 114 | else: 115 | setattr(self.viewer.cam, key, val) 116 | 117 | set_state(self.env, observation) 118 | 119 | if type(dim) == int: 120 | dim = (dim, dim) 121 | 122 | self.viewer.render(*dim) 123 | data = self.viewer.read_pixels(*dim, depth=False) 124 | data = data[::-1, :, :] 125 | return data 126 | 127 | def renders(self, observations, **kwargs): 128 | images = [] 129 | for observation in observations: 130 | img = self.render(observation, **kwargs) 131 | images.append(img) 132 | return np.stack(images, axis=0) 133 | 134 | def render_plan(self, savepath, sequence, state, fps=30): 135 | ''' 136 | state : np.array[ observation_dim ] 137 | sequence : np.array[ horizon x transition_dim ] 138 | as usual, sequence is ordered as [ s_t, a_t, r_t, V_t, ... 
] 139 | ''' 140 | 141 | if len(sequence) == 1: 142 | return 143 | 144 | sequence = to_np(sequence) 145 | 146 | ## compare to ground truth rollout using actions from sequence 147 | actions = sequence[:-1, self.observation_dim : self.observation_dim + self.action_dim] 148 | rollout_states = rollout_from_state(self.env, state, actions) 149 | 150 | videos = [ 151 | self.renders(sequence[:, :self.observation_dim]), 152 | self.renders(rollout_states), 153 | ] 154 | 155 | save_videos(savepath, *videos, fps=fps) 156 | 157 | def render_rollout(self, savepath, states, **video_kwargs): 158 | images = self(states) 159 | save_video(savepath, images, **video_kwargs) 160 | 161 | class KitchenRenderer: 162 | 163 | def __init__(self, env): 164 | if type(env) is str: 165 | self.env = gym.make(env) 166 | else: 167 | self.env = env 168 | 169 | self.observation_dim = np.prod(self.env.observation_space.shape) 170 | self.action_dim = np.prod(self.env.action_space.shape) 171 | 172 | def set_obs(self, obs, goal_dim=30): 173 | robot_dim = self.env.n_jnt 174 | obj_dim = self.env.n_obj 175 | assert robot_dim + obj_dim + goal_dim == obs.size or robot_dim + obj_dim == obs.size 176 | self.env.sim.data.qpos[:robot_dim] = obs[:robot_dim] 177 | self.env.sim.data.qpos[robot_dim:robot_dim+obj_dim] = obs[robot_dim:robot_dim+obj_dim] 178 | self.env.sim.forward() 179 | 180 | def rollout(self, obs, actions): 181 | self.set_obs(obs) 182 | observations = [env._get_obs()] 183 | for act in actions: 184 | obs, rew, term, _ = env.step(act) 185 | observations.append(obs) 186 | if term: 187 | break 188 | for i in range(len(observations), len(actions)+1): 189 | ## if terminated early, pad with zeros 190 | observations.append( np.zeros(observations[-1].size) ) 191 | return np.stack(observations) 192 | 193 | def render(self, observation, dim=512, onscreen=False): 194 | self.env.sim_robot.renderer._camera_settings.update({ 195 | 'distance': 4.5, 196 | 'azimuth': 90, 197 | 'elevation': -25, 198 | 'lookat': [0, 1, 2], 
199 | }) 200 | self.set_obs(observation) 201 | if onscreen: 202 | self.env.render() 203 | return self.env.sim_robot.renderer.render_offscreen(dim, dim) 204 | 205 | def renders(self, observations, **kwargs): 206 | images = [] 207 | for observation in observations: 208 | img = self.render(observation, **kwargs) 209 | images.append(img) 210 | return np.stack(images, axis=0) 211 | 212 | def render_plan(self, *args, **kwargs): 213 | return self.render_rollout(*args, **kwargs) 214 | 215 | def render_rollout(self, savepath, states, **video_kwargs): 216 | images = self(states) #np.stack(states, axis=0)) 217 | save_video(savepath, images, **video_kwargs) 218 | 219 | def __call__(self, *args, **kwargs): 220 | return self.renders(*args, **kwargs) 221 | 222 | ANTMAZE_BOUNDS = { 223 | 'antmaze-umaze-v0': (-3, 11), 224 | 'antmaze-medium-play-v0': (-3, 23), 225 | 'antmaze-medium-diverse-v0': (-3, 23), 226 | 'antmaze-large-play-v0': (-3, 39), 227 | 'antmaze-large-diverse-v0': (-3, 39), 228 | } 229 | 230 | class AntMazeRenderer: 231 | 232 | def __init__(self, env_name): 233 | self.env_name = env_name 234 | self.env = gym.make(env_name).unwrapped 235 | self.observation_dim = np.prod(self.env.observation_space.shape) 236 | self.action_dim = np.prod(self.env.action_space.shape) 237 | 238 | def renders(self, savepath, X): 239 | plt.clf() 240 | 241 | if X.ndim < 3: 242 | X = X[None] 243 | 244 | N, path_length, _ = X.shape 245 | if N > 4: 246 | fig, axes = plt.subplots(4, int(N/4)) 247 | axes = axes.flatten() 248 | fig.set_size_inches(N/4,8) 249 | elif N > 1: 250 | fig, axes = plt.subplots(1, N) 251 | fig.set_size_inches(8,8) 252 | else: 253 | fig, axes = plt.subplots(1, 1) 254 | fig.set_size_inches(8,8) 255 | 256 | colors = plt.cm.jet(np.linspace(0,1,path_length)) 257 | for i in range(N): 258 | ax = axes if N == 1 else axes[i] 259 | xlim, ylim = self.plot_boundaries(ax=ax) 260 | x = X[i] 261 | ax.scatter(x[:,0], x[:,1], c=colors) 262 | ax.set_xticks([]) 263 | ax.set_yticks([]) 264 | 
ax.set_xlim(*xlim) 265 | ax.set_ylim(*ylim) 266 | plt.savefig(savepath + '.png') 267 | plt.close() 268 | print(f'[ attentive/utils/visualization ] Saved to: {savepath}') 269 | 270 | def plot_boundaries(self, N=100, ax=None): 271 | """ 272 | plots the maze boundaries in the antmaze environments 273 | """ 274 | ax = ax or plt.gca() 275 | 276 | xlim = ANTMAZE_BOUNDS[self.env_name] 277 | ylim = ANTMAZE_BOUNDS[self.env_name] 278 | 279 | X = np.linspace(*xlim, N) 280 | Y = np.linspace(*ylim, N) 281 | 282 | Z = np.zeros((N, N)) 283 | for i, x in enumerate(X): 284 | for j, y in enumerate(Y): 285 | collision = self.env.unwrapped._is_in_collision((x, y)) 286 | Z[-j, i] = collision 287 | 288 | ax.imshow(Z, extent=(*xlim, *ylim), aspect='auto', cmap=plt.cm.binary) 289 | return xlim, ylim 290 | 291 | def render_plan(self, savepath, discretizer, state, sequence): 292 | ''' 293 | state : np.array[ observation_dim ] 294 | sequence : np.array[ horizon x transition_dim ] 295 | as usual, sequence is ordered as [ s_t, a_t, r_t, V_t, ... 
class Maze2dRenderer(AntMazeRenderer):
    """AntMazeRenderer variant for maze2d environments, which expose the
    maze layout directly via `env.maze_arr` instead of a collision query."""

    def _is_in_collision(self, x, y):
        """Cell codes in maze_arr: 10 = wall, 11 = free, 12 = goal."""
        # NOTE(review): maze_arr is indexed as [int(x), int(y)] while
        # plot_boundaries derives xlim from shape[1] — confirm this matches
        # d4rl's maze_arr axis convention
        cell = self.env.maze_arr[int(x), int(y)]
        return cell == 10

    def plot_boundaries(self, N=100, ax=None, eps=1e-6):
        """Draw the maze2d walls as a binary occupancy image on `ax`.

        `eps` keeps the sampling grid strictly inside the array bounds so
        int() flooring never indexes one past the end.
        Returns (xlim, ylim).
        """
        ax = ax or plt.gca()

        maze = self.env.maze_arr
        xlim = (0, maze.shape[1] - eps)
        ylim = (0, maze.shape[0] - eps)

        xs = np.linspace(*xlim, N)
        ys = np.linspace(*ylim, N)

        occupancy = np.zeros((N, N))
        for col, x in enumerate(xs):
            for row, y in enumerate(ys):
                # `-row` flips the image vertically for imshow
                occupancy[-row, col] = self._is_in_collision(x, y)

        ax.imshow(occupancy, extent=(*xlim, *ylim), aspect='auto', cmap=plt.cm.binary)
        return xlim, ylim

    def renders(self, savepath, X):
        ## shift trajectories by half a cell so points land at cell centers
        return super().renders(savepath, X + 0.5)
def load_model(*loadpath, epoch=None, device='cuda:0'):
    """Load a saved model from the directory `os.path.join(*loadpath)`.

    Reads the pickled model config from `model_config.pkl`, instantiates
    the model by calling the config, and restores weights from
    `state_{epoch}.pt`.

    Args:
        *loadpath: path components of the experiment directory
        epoch: checkpoint epoch to load; the string 'latest' selects the
            highest-numbered `state_*.pt` in the directory
        device: torch device the model is moved to

    Returns:
        (model, epoch) tuple, with `epoch` resolved to an int when
        'latest' was requested.
    """
    loadpath = os.path.join(*loadpath)
    config_path = os.path.join(loadpath, 'model_config.pkl')

    ## bug fix: was `epoch is 'latest'` — identity comparison against a
    ## string literal is a SyntaxWarning and only works by interning luck
    if epoch == 'latest':
        epoch = get_latest_epoch(loadpath)

    print(f'[ utils/serialization ] Loading model epoch: {epoch}')
    state_path = os.path.join(loadpath, f'state_{epoch}.pt')

    ## close the config file deterministically instead of leaking the handle
    with open(config_path, 'rb') as f:
        config = pickle.load(f)
    ## map_location lets CPU-only machines load checkpoints saved on GPU
    state = torch.load(state_path, map_location=device)

    model = config()
    model.to(device)
    model.load_state_dict(state, strict=True)

    print(f'\n[ utils/serialization ] Loaded config from {config_path}\n')
    print(config)

    return model, epoch
def watch(args_to_watch):
    """Build an experiment-name function from (attribute, label) pairs.

    The returned callable formats each watched attribute present on `args`
    as `<label><value>`, joins the pieces with underscores, and collapses
    any '/_' into '/' so labels may introduce subdirectories.
    """
    def _fn(args):
        parts = [
            f'{label}{getattr(args, key)}'
            for key, label in args_to_watch
            if hasattr(args, key)
        ]
        return '_'.join(parts).replace('/_', '/')
    return _fn
{args.config}:{dataset}') 61 | module = importlib.import_module(args.config) 62 | params = getattr(module, 'base')[experiment] 63 | 64 | if hasattr(module, dataset) and experiment in getattr(module, dataset): 65 | print(f'[ utils/setup ] Using overrides | config: {args.config} | dataset: {dataset}') 66 | overrides = getattr(module, dataset)[experiment] 67 | params.update(overrides) 68 | else: 69 | print(f'[ utils/setup ] Not using overrides | config: {args.config} | dataset: {dataset}') 70 | 71 | for key, val in params.items(): 72 | setattr(args, key, val) 73 | 74 | return args 75 | 76 | def add_extras(self, args): 77 | ''' 78 | Override config parameters with command-line arguments 79 | ''' 80 | extras = args.extra_args 81 | if not len(extras): 82 | return 83 | 84 | print(f'[ utils/setup ] Found extras: {extras}') 85 | assert len(extras) % 2 == 0, f'Found odd number ({len(extras)}) of extras: {extras}' 86 | for i in range(0, len(extras), 2): 87 | key = extras[i].replace('--', '') 88 | val = extras[i+1] 89 | assert hasattr(args, key), f'[ utils/setup ] {key} not found in config: {args.config}' 90 | old_val = getattr(args, key) 91 | old_type = type(old_val) 92 | print(f'[ utils/setup ] Overriding config | {key} : {old_val} --> {val}') 93 | if val == 'None': 94 | val = None 95 | elif val == 'latest': 96 | val = 'latest' 97 | elif old_type in [bool, type(None)]: 98 | val = eval(val) 99 | else: 100 | val = old_type(val) 101 | setattr(args, key, val) 102 | 103 | def set_seed(self, args): 104 | if not 'seed' in dir(args): 105 | return 106 | set_seed(args.seed) 107 | 108 | def generate_exp_name(self, args): 109 | if not 'exp_name' in dir(args): 110 | return 111 | exp_name = getattr(args, 'exp_name') 112 | if callable(exp_name): 113 | exp_name_string = exp_name(args) 114 | print(f'[ utils/setup ] Setting exp_name to: {exp_name_string}') 115 | setattr(args, 'exp_name', exp_name_string) 116 | 117 | def mkdir(self, args): 118 | if 'logbase' in dir(args) and 'dataset' in 
class Trainer:
    """Minimal training loop for the trajectory transformer.

    Handles lazy optimizer creation, gradient clipping, and the warmup +
    cosine learning-rate schedule driven by the number of non-padding
    tokens processed so far.
    """

    def __init__(self, config):
        self.config = config
        self.device = config.device

        self.n_epochs = 0
        self.n_tokens = 0 # counter used for learning rate decay
        self.optimizer = None

    def get_optimizer(self, model):
        ## construct the optimizer once and reuse it so optimizer state
        ## (momentum, etc.) persists across repeated train() calls
        if self.optimizer is None:
            print(f'[ utils/training ] Making optimizer at epoch {self.n_epochs}')
            self.optimizer = model.configure_optimizers(self.config)
        return self.optimizer

    def train(self, model, dataset, n_epochs=1, log_freq=100):
        """Run `n_epochs` of optimization over `dataset`.

        `dataset.N` is the vocabulary size; target tokens equal to it are
        treated as padding and excluded from the lr-schedule token count.
        """
        config = self.config
        optimizer = self.get_optimizer(model)
        model.train(True)
        vocab_size = dataset.N

        loader = DataLoader(dataset, shuffle=True, pin_memory=True,
                            batch_size=config.batch_size,
                            num_workers=config.num_workers)

        for _ in range(n_epochs):

            timer = Timer()
            for it, batch in enumerate(loader):

                batch = to(batch, self.device)

                # forward the model
                with torch.set_grad_enabled(True):
                    logits, loss = model(*batch)

                # backprop and update the parameters
                model.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                optimizer.step()

                # decay the learning rate based on our progress
                if config.lr_decay:
                    y = batch[-2]
                    self.n_tokens += (y != vocab_size).sum() # number of tokens processed this step
                    if self.n_tokens < config.warmup_tokens:
                        # linear warmup
                        lr_mult = float(self.n_tokens) / float(max(1, config.warmup_tokens))
                    else:
                        # cosine learning rate decay
                        progress = float(self.n_tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                        lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
                    lr = config.learning_rate * lr_mult
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                else:
                    lr = config.learning_rate
                    ## bug fix: lr_mult was unbound on this branch, so the
                    ## very first log line (it == 0 always logs) raised
                    ## NameError whenever lr_decay was disabled
                    lr_mult = 1.0

                # report progress
                if it % log_freq == 0:
                    print(
                        f'[ utils/training ] epoch {self.n_epochs} [ {it:4d} / {len(loader):4d} ] ',
                        f'train loss {loss.item():.5f} | lr {lr:.3e} | lr_mult: {lr_mult:.4f} | '
                        f't: {timer():.2f}')

            self.n_epochs += 1
def save_video(filename, video_frames, fps=60, video_format='mp4'):
    """Write `video_frames` ([ N x H x W x C ]) to `filename` via ffmpeg.

    Args:
        filename: output path; parent directories are created if missing
        video_frames: array-like sequence of frames
        fps: frame rate; must be a whole number
        video_format: ffmpeg container format (default 'mp4')

    Raises:
        ValueError: if `fps` is not a whole number.
    """
    ## was `assert fps == int(fps)`: asserts vanish under `python -O`,
    ## so validate the input explicitly
    if fps != int(fps):
        raise ValueError(f'fps must be a whole number, got: {fps}')
    _make_dir(filename)

    skvideo.io.vwrite(
        filename,
        video_frames,
        inputdict={
            '-r': str(int(fps)),
        },
        outputdict={
            '-f': video_format,
            '-pix_fmt': 'yuv420p', # '-pix_fmt=yuv420p' needed for osx https://github.com/scikit-video/scikit-video/issues/74
        }
    )