├── .dockerignore
├── .gitignore
├── LICENSE
├── README.md
├── azure
│   ├── Dockerfile
│   ├── config.py
│   ├── download.sh
│   ├── files
│   │   ├── 10_nvidia.json
│   │   └── Xdummy
│   ├── launch_plan.py
│   ├── launch_train.py
│   ├── make_fuse_config.sh
│   ├── mount.sh
│   └── sync.sh
├── config
│   └── offline.py
├── environment.yml
├── plotting
│   ├── bar.png
│   ├── plot.py
│   ├── read_results.py
│   ├── scores.py
│   └── table.py
├── pretrained.sh
├── scripts
│   ├── plan.py
│   └── train.py
├── setup.py
└── trajectory
    ├── __init__.py
    ├── datasets
    │   ├── __init__.py
    │   ├── d4rl.py
    │   ├── preprocessing.py
    │   └── sequence.py
    ├── models
    │   ├── __init__.py
    │   ├── ein.py
    │   ├── embeddings.py
    │   ├── mlp.py
    │   └── transformers.py
    ├── search
    │   ├── __init__.py
    │   ├── core.py
    │   ├── sampling.py
    │   └── utils.py
    └── utils
        ├── __init__.py
        ├── arrays.py
        ├── config.py
        ├── discretization.py
        ├── git_utils.py
        ├── progress.py
        ├── rendering.py
        ├── serialization.py
        ├── setup.py
        ├── timer.py
        ├── training.py
        └── video.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | *.ipynb_checkpoints
2 | *__pycache__*
3 | *.egg-info
4 | logs/*
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | __pycache__/
3 | *.egg-info
4 | .DS_Store
5 | bin/
6 |
7 | logs/*
8 | data/*
9 | slurm/config.sh
10 | slurm/*sandbox*
11 | *.img
12 | slurm-*.out
13 | *.png
14 | *.pdf
15 | *.pkl
16 | mjkey.txt
17 | azure/fuse.cfg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Trajectory Transformer authors
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Trajectory Transformer
2 |
3 | Code release for [Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039).
4 |
5 | **New:** Also see [Alexander Nikulin's fork](https://github.com/Howuhh/faster-trajectory-transformer) with attention caching and vectorized rollouts!
6 |
7 | ## Installation
8 |
9 | All python dependencies are in [`environment.yml`](environment.yml). Install with:
10 |
11 | ```
12 | conda env create -f environment.yml
13 | conda activate trajectory
14 | pip install -e .
15 | ```
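
If the install succeeded, the package should be importable (a quick, optional sanity check; this assumes the editable install registers the package as `trajectory`, the package directory name in this repo):
```
python -c "import trajectory; print(trajectory.__file__)"
```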
16 |
17 | For reproducibility, we have also included system requirements in a [`Dockerfile`](azure/Dockerfile) (see [installation instructions](#Docker)), but the conda installation should work on most standard Linux machines.
18 |
19 | ## Usage
20 |
21 | Train a transformer with: `python scripts/train.py --dataset halfcheetah-medium-v2`
22 |
23 | To reproduce the offline RL results: `python scripts/plan.py --dataset halfcheetah-medium-v2`
24 |
25 | By default, these commands will use the hyperparameters in [`config/offline.py`](config/offline.py). You can override them with runtime flags:
26 | ```
27 | python scripts/plan.py --dataset halfcheetah-medium-v2 \
28 | --horizon 5 --beam_width 32
29 | ```
30 |
31 | A few hyperparameters are different from those listed in the paper because of changes to the discretization strategy. These hyperparameters will be updated in the next arXiv version to match what is currently in the codebase.
32 |
33 | ## Pretrained models
34 |
35 | We have provided [pretrained models](https://www.dropbox.com/sh/r09lkdoj66kx43w/AACbXjMhcI6YNsn1qU4LParja?dl=0) for 16 datasets: `{halfcheetah, hopper, walker2d, ant}-{expert-v2, medium-expert-v2, medium-v2, medium-replay-v2}`. Download them with `./pretrained.sh`
36 |
37 | The models will be saved in `logs/$DATASET/gpt/pretrained`. To plan with these models, refer to them using the `gpt_loadpath` flag:
38 | ```
39 | python scripts/plan.py --dataset halfcheetah-medium-v2 \
40 | --gpt_loadpath gpt/pretrained
41 | ```
42 |
43 | `pretrained.sh` will also download 15 [plans](https://www.dropbox.com/sh/po0nul2u6qk8r2i/AABPDrOEJplQ8JT13DASdOWWa?dl=0) from each model, saved to `logs/$DATASET/plans/pretrained`. Read them with `python plotting/read_results.py`.
45 |
46 |
47 | To create the table of offline RL results from the paper, run `python plotting/table.py`. This will print a table that can be copied into a LaTeX document; the table source is shown below.
48 |
49 | ```
50 | \begin{table*}[h]
51 | \centering
52 | \small
53 | \begin{tabular}{llrrrrrr}
54 | \toprule
55 | \multicolumn{1}{c}{\bf Dataset} & \multicolumn{1}{c}{\bf Environment} & \multicolumn{1}{c}{\bf BC} & \multicolumn{1}{c}{\bf MBOP} & \multicolumn{1}{c}{\bf BRAC} & \multicolumn{1}{c}{\bf CQL} & \multicolumn{1}{c}{\bf DT} & \multicolumn{1}{c}{\bf TT (Ours)} \\
56 | \midrule
57 | Medium-Expert & HalfCheetah & $59.9$ & $105.9$ & $41.9$ & $91.6$ & $86.8$ & $95.0$ \scriptsize{\raisebox{1pt}{$\pm 0.2$}} \\
58 | Medium-Expert & Hopper & $79.6$ & $55.1$ & $0.9$ & $105.4$ & $107.6$ & $110.0$ \scriptsize{\raisebox{1pt}{$\pm 2.7$}} \\
59 | Medium-Expert & Walker2d & $36.6$ & $70.2$ & $81.6$ & $108.8$ & $108.1$ & $101.9$ \scriptsize{\raisebox{1pt}{$\pm 6.8$}} \\
60 | Medium-Expert & Ant & $-$ & $-$ & $-$ & $-$ & $-$ & $116.1$ \scriptsize{\raisebox{1pt}{$\pm 9.0$}} \\
61 | \midrule
62 | Medium & HalfCheetah & $43.1$ & $44.6$ & $46.3$ & $44.0$ & $42.6$ & $46.9$ \scriptsize{\raisebox{1pt}{$\pm 0.4$}} \\
63 | Medium & Hopper & $63.9$ & $48.8$ & $31.3$ & $58.5$ & $67.6$ & $61.1$ \scriptsize{\raisebox{1pt}{$\pm 3.6$}} \\
64 | Medium & Walker2d & $77.3$ & $41.0$ & $81.1$ & $72.5$ & $74.0$ & $79.0$ \scriptsize{\raisebox{1pt}{$\pm 2.8$}} \\
65 | Medium & Ant & $-$ & $-$ & $-$ & $-$ & $-$ & $83.1$ \scriptsize{\raisebox{1pt}{$\pm 7.3$}} \\
66 | \midrule
67 | Medium-Replay & HalfCheetah & $4.3$ & $42.3$ & $47.7$ & $45.5$ & $36.6$ & $41.9$ \scriptsize{\raisebox{1pt}{$\pm 2.5$}} \\
68 | Medium-Replay & Hopper & $27.6$ & $12.4$ & $0.6$ & $95.0$ & $82.7$ & $91.5$ \scriptsize{\raisebox{1pt}{$\pm 3.6$}} \\
69 | Medium-Replay & Walker2d & $36.9$ & $9.7$ & $0.9$ & $77.2$ & $66.6$ & $82.6$ \scriptsize{\raisebox{1pt}{$\pm 6.9$}} \\
70 | Medium-Replay & Ant & $-$ & $-$ & $-$ & $-$ & $-$ & $77.0$ \scriptsize{\raisebox{1pt}{$\pm 6.8$}} \\
71 | \midrule
72 | \multicolumn{2}{c}{\bf Average (without Ant)} & 47.7 & 47.8 & 36.9 & 77.6 & 74.7 & 78.9 \hspace{.6cm} \\
73 | \multicolumn{2}{c}{\bf Average (all settings)} & $-$ & $-$ & $-$ & $-$ & $-$ & 82.2 \hspace{.6cm} \\
74 | \bottomrule
75 | \end{tabular}
76 | \label{table:d4rl}
77 | \end{table*}
78 | ```
79 |
80 | 
81 |
82 |
83 |
84 |
85 | To create the average performance plot, run `python plotting/plot.py`.
86 |
88 |
89 |
90 | 
91 |
92 |
93 | ## Docker
94 |
95 | Copy your MuJoCo key to the Docker build context and build the container:
96 | ```
97 | cp ~/.mujoco/mjkey.txt azure/files/
98 | docker build -f azure/Dockerfile . -t trajectory
99 | ```
100 |
101 | Test the container:
102 | ```
103 | docker run -it --rm --gpus all \
104 | --mount type=bind,source=$PWD,target=/home/code \
105 | --mount type=bind,source=$HOME/.d4rl,target=/root/.d4rl \
106 | trajectory \
107 | bash -c \
108 | "export PYTHONPATH=$PYTHONPATH:/home/code && \
109 | python /home/code/scripts/train.py --dataset hopper-medium-expert-v2 --exp_name docker/"
110 | ```
111 |
112 | ## Running on Azure
113 |
114 | #### Setup
115 |
116 | 1. Launching jobs on Azure requires one more python dependency:
117 | ```
118 | pip install git+https://github.com/JannerM/doodad.git@janner
119 | ```
120 |
121 | 2. Tag the image built in [the previous section](#Docker) and push it to Docker Hub:
122 | ```
123 | export DOCKER_USERNAME=$(docker info | sed '/Username:/!d;s/.* //')
124 | docker tag trajectory ${DOCKER_USERNAME}/trajectory:latest
125 | docker image push ${DOCKER_USERNAME}/trajectory
126 | ```
127 |
128 | 3. Update [`azure/config.py`](azure/config.py), either by modifying the file directly or by setting the relevant [environment variables](azure/config.py#L47-L52) (see the example after this list). To set the `AZURE_STORAGE_CONNECTION_STRING` variable, navigate to the `Access keys` section of your storage account, click `Show keys`, and copy the `Connection string`.
129 |
130 | 4. Download [`azcopy`](https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10): `./azure/download.sh`
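
For reference, the [environment variables](azure/config.py#L47-L52) read by `azure/config.py` can also be exported in your shell before launching jobs. A minimal sketch with placeholder values (replace each `<...>` with your own credentials):
```
export AZURE_SUBSCRIPTION_ID=<subscription-id>
export AZURE_CLIENT_ID=<client-id>
export AZURE_TENANT_ID=<tenant-id>
export AZURE_CLIENT_SECRET=<client-secret>
export AZURE_STORAGE_CONTAINER=<storage-container-name>
export AZURE_STORAGE_CONNECTION_STRING='<connection-string-from-access-keys>'
export DOCKER_USERNAME=<docker-hub-username>  # optional; otherwise parsed from `docker info`
```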
131 |
132 | #### Usage
133 |
134 | Launch training jobs with `python azure/launch_train.py` and planning jobs with `python azure/launch_plan.py`.
135 |
136 | These scripts do not take runtime arguments. Instead, they run the corresponding scripts ([`scripts/train.py`](scripts/train.py) and [`scripts/plan.py`](scripts/plan.py), respectively) using the Cartesian product of the parameters in [`params_to_sweep`](azure/launch_train.py#L36-L38).
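
As a rough illustration only (not the actual launcher code, which submits jobs through doodad rather than running them locally), a sweep over two hypothetical parameter lists expands into one job per combination, as if you had run:
```
for dataset in halfcheetah-medium-v2 hopper-medium-v2; do
    for horizon in 5 15; do
        python scripts/plan.py --dataset $dataset --horizon $horizon
    done
done
```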
137 |
138 | #### Viewing results
139 |
140 | To rsync the results from the Azure storage container, run `./azure/sync.sh`.
141 |
142 | To mount the storage container:
143 | 1. Create a blobfuse config with `./azure/make_fuse_config.sh`
144 | 2. Run `./azure/mount.sh` to mount the storage container to `~/azure_mount`
145 |
146 | To unmount the container, run `sudo umount -f ~/azure_mount; rm -r ~/azure_mount`
147 |
148 | ## Reference
149 | ```
150 | @inproceedings{janner2021sequence,
151 | title = {Offline Reinforcement Learning as One Big Sequence Modeling Problem},
152 | author = {Michael Janner and Qiyang Li and Sergey Levine},
153 | booktitle = {Advances in Neural Information Processing Systems},
154 | year = {2021},
155 | }
156 | ```
157 |
158 | ## Acknowledgements
159 |
160 | The GPT implementation is from Andrej Karpathy's [minGPT](https://github.com/karpathy/minGPT) repo.
161 |
--------------------------------------------------------------------------------
/azure/Dockerfile:
--------------------------------------------------------------------------------
1 | # We need the CUDA base dockerfile to enable GPU rendering
2 | # on hosts with GPUs.
3 | # The base image below is nvidia/cuda:10.1-cudnn7-devel-ubuntu16.04.
4 | # If updating the base image, be sure to test on GPU since it has broken in the past.
5 | FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu16.04
6 |
7 | SHELL ["/bin/bash", "-c"]
8 |
9 | ##########################################################
10 | ### System dependencies
11 | ##########################################################
12 |
13 | RUN apt-get update -q \
14 | && DEBIAN_FRONTEND=noninteractive apt-get install -y \
15 | cmake \
16 | curl \
17 | git \
18 | libav-tools \
19 | libgl1-mesa-dev \
20 | libgl1-mesa-glx \
21 | libglew-dev \
22 | libosmesa6-dev \
23 | net-tools \
24 | software-properties-common \
25 | swig \
26 | unzip \
27 | vim \
28 | wget \
29 | xpra \
30 | xserver-xorg-dev \
31 | zlib1g-dev \
32 | && apt-get clean \
33 | && rm -rf /var/lib/apt/lists/*
34 |
35 | ENV LANG C.UTF-8
36 |
37 | COPY ./azure/files/Xdummy /usr/local/bin/Xdummy
38 | RUN chmod +x /usr/local/bin/Xdummy
39 |
40 | # Workaround for https://bugs.launchpad.net/ubuntu/+source/nvidia-graphics-drivers-375/+bug/1674677
41 | COPY ./azure/files/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json
42 | COPY ./environment.yml /opt/environment.yml
43 |
44 | ENV LD_LIBRARY_PATH /usr/local/nvidia/lib64:${LD_LIBRARY_PATH}
45 |
46 | ##########################################################
47 | ### MuJoCo
48 | ##########################################################
49 | # Note: ~ is an alias for /root
50 | RUN mkdir -p /root/.mujoco \
51 | && wget https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip \
52 | && unzip mujoco.zip -d /root/.mujoco \
53 | && rm mujoco.zip
54 | RUN mkdir -p /root/.mujoco \
55 | && wget https://www.roboti.us/download/mjpro150_linux.zip -O mujoco.zip \
56 | && unzip mujoco.zip -d /root/.mujoco \
57 | && rm mujoco.zip
58 | RUN ln -s /root/.mujoco/mujoco200_linux /root/.mujoco/mujoco200
59 | ENV LD_LIBRARY_PATH /root/.mujoco/mjpro150/bin:${LD_LIBRARY_PATH}
60 | ENV LD_LIBRARY_PATH /root/.mujoco/mujoco200/bin:${LD_LIBRARY_PATH}
61 | ENV LD_LIBRARY_PATH /root/.mujoco/mujoco200_linux/bin:${LD_LIBRARY_PATH}
62 | COPY ./azure/files/mjkey.txt /root/.mujoco
63 |
64 | ##########################################################
65 | ### Example Python Installation
66 | ##########################################################
67 | ENV PATH /opt/conda/bin:$PATH
68 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && \
69 | /bin/bash /tmp/miniconda.sh -b -p /opt/conda && \
70 | rm /tmp/miniconda.sh && \
71 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
72 | echo ". /opt/conda/etc/profile.d/conda.sh" >> /etc/bash.bashrc
73 |
74 | RUN conda update -y --name base conda && conda clean --all -y
75 |
76 | RUN conda env create -f /opt/environment.yml
77 | ENV PATH /opt/conda/envs/trajectory/bin:$PATH
78 |
79 | ##########################################################
80 | ### gym sometimes has this patchelf issue
81 | ##########################################################
82 | RUN curl -o /usr/local/bin/patchelf https://s3-us-west-2.amazonaws.com/openai-sci-artifacts/manual-builds/patchelf_0.9_amd64.elf \
83 | && chmod +x /usr/local/bin/patchelf
84 | # RUN pip install gym[all]==0.12.5
85 |
86 | RUN echo "source activate /opt/conda/envs/trajectory && export PYTHONPATH=$PYTHONPATH:/home/code && export CUDA_VISIBLE_DEVICES=0" >> ~/.bashrc
87 | RUN source ~/.bashrc
88 |
--------------------------------------------------------------------------------
/azure/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | def get_docker_username():
4 | import subprocess
5 | import shlex
6 | ps = subprocess.Popen(shlex.split('docker info'), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
7 | output = subprocess.check_output(shlex.split("sed '/Username:/!d;s/.* //'"), stdin=ps.stdout)
8 | username = output.decode('utf-8').replace('\n', '')
9 | print(f'[ azure/config ] Grabbed username from `docker info`: {username}')
10 | return username
11 |
12 | ## /path/to/trajectory-transformer/azure
13 | CWD = os.path.dirname(__file__)
14 | ## /path/to/trajectory-transformer
15 | MODULE_PATH = os.path.dirname(CWD)
16 |
17 | CODE_DIRS_TO_MOUNT = [
18 | ]
19 | NON_CODE_DIRS_TO_MOUNT = [
20 | dict(
21 | local_dir=MODULE_PATH,
22 | mount_point='/home/code',
23 | ),
24 | ]
25 | REMOTE_DIRS_TO_MOUNT = [
26 | dict(
27 | local_dir='/doodad_tmp/',
28 | mount_point='/doodad_tmp/',
29 | ),
30 | ]
31 | LOCAL_LOG_DIR = '/tmp'
32 |
33 | DEFAULT_AZURE_GPU_MODEL = 'nvidia-tesla-t4'
34 | DEFAULT_AZURE_INSTANCE_TYPE = 'Standard_DS1_v2'
35 | DEFAULT_AZURE_REGION = 'eastus'
36 | DEFAULT_AZURE_RESOURCE_GROUP = 'traj'
37 | DEFAULT_AZURE_VM_NAME = 'traj-vm'
38 | DEFAULT_AZURE_VM_PASSWORD = 'Azure1'
39 |
40 | DOCKER_USERNAME = os.environ.get('DOCKER_USERNAME', get_docker_username())
41 | DEFAULT_DOCKER = f'docker.io/{DOCKER_USERNAME}/trajectory:latest'
42 |
43 | print(f'[ azure/config ] Local dir: {MODULE_PATH}')
44 | print(f'[ azure/config ] Default GPU model: {DEFAULT_AZURE_GPU_MODEL}')
45 | print(f'[ azure/config ] Default Docker image: {DEFAULT_DOCKER}')
46 |
47 | AZ_SUB_ID = os.environ['AZURE_SUBSCRIPTION_ID']
48 | AZ_CLIENT_ID = os.environ['AZURE_CLIENT_ID']
49 | AZ_TENANT_ID = os.environ['AZURE_TENANT_ID']
50 | AZ_SECRET = os.environ['AZURE_CLIENT_SECRET']
51 | AZ_CONTAINER = os.environ['AZURE_STORAGE_CONTAINER']
52 | AZ_CONN_STR = os.environ['AZURE_STORAGE_CONNECTION_STRING']
53 |
--------------------------------------------------------------------------------
/azure/download.sh:
--------------------------------------------------------------------------------
1 | DOWNLOAD_DIR=bin
2 | mkdir $DOWNLOAD_DIR
3 | wget https://aka.ms/downloadazcopy-v10-linux -O $DOWNLOAD_DIR/download.tar.gz
4 | tar -xvf $DOWNLOAD_DIR/download.tar.gz --one-top-level=$DOWNLOAD_DIR
5 | mv $DOWNLOAD_DIR/*/azcopy $DOWNLOAD_DIR
6 | rm $DOWNLOAD_DIR/download.tar.gz
7 | rm -r $DOWNLOAD_DIR/azcopy_linux*
--------------------------------------------------------------------------------
/azure/files/10_nvidia.json:
--------------------------------------------------------------------------------
1 | {
2 | "file_format_version" : "1.0.0",
3 | "ICD" : {
4 | "library_path" : "libEGL_nvidia.so.0"
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/azure/files/Xdummy:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # ----------------------------------------------------------------------
3 | # Copyright (C) 2005-2011 Karl J. Runge
4 | # All rights reserved.
5 | #
6 | # This file is part of Xdummy.
7 | #
8 | # Xdummy is free software; you can redistribute it and/or modify
9 | # it under the terms of the GNU General Public License as published by
10 | # the Free Software Foundation; either version 2 of the License, or (at
11 | # your option) any later version.
12 | #
13 | # Xdummy is distributed in the hope that it will be useful,
14 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
15 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 | # GNU General Public License for more details.
17 | #
18 | # You should have received a copy of the GNU General Public License
19 | # along with Xdummy; if not, write to the Free Software
20 | # Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA
21 | # or see <http://www.gnu.org/licenses/>.
22 | # ----------------------------------------------------------------------
23 | #
24 | #
25 | # Xdummy: an LD_PRELOAD hack to run a stock Xorg(1) or XFree86(1) server
26 | # with the "dummy" video driver to make it avoid Linux VT switching, etc.
27 | #
28 | # Run "Xdummy -help" for more info.
29 | #
30 | install=""
31 | uninstall=""
32 | runit=1
33 | prconf=""
34 | notweak=""
35 | root=""
36 | nosudo=""
37 | xserver=""
38 | geom=""
39 | nomodelines=""
40 | depth=""
41 | debug=""
42 | strace=""
43 | cmdline_config=""
44 |
45 | PATH=$PATH:/bin:/usr/bin
46 | export PATH
47 |
48 | program=`basename "$0"`
49 |
50 | help () {
51 | ${PAGER:-more} << END
52 | $program:
53 |
54 | A hack to run a stock Xorg(1) or XFree86(1) X server with the "dummy"
55 | (RAM-only framebuffer) video driver such that it AVOIDS the Linux VT
56 | switching, opening device files in /dev, keyboard and mouse conflicts,
57 | and other problems associated with the normal use of "dummy".
58 |
59 | In other words, it tries to make Xorg/XFree86 with the "dummy"
60 | device driver act more like Xvfb(1).
61 |
62 | The primary motivation for the Xdummy script is to provide a virtual X
63 | server for x11vnc but with more features than Xvfb (or Xvnc); however
64 | it could be used for other reasons (e.g. better automated testing
65 | than with Xvfb.) One nice thing is the dummy server supports RANDR
66 | dynamic resizing while Xvfb does not.
67 |
68 | So, for example, x11vnc+Xdummy terminal services are a little better
69 | than x11vnc+Xvfb.
70 |
71 | To achieve this, while running the real Xserver $program intercepts
72 | system and library calls via the LD_PRELOAD method and modifies
73 | the behavior to make it work correctly (e.g. avoid the VT stuff.)
74 | LD_PRELOAD tricks are usually "clever hacks" and so might not work
75 | in all situations or break when something changes.
76 |
77 | WARNING: Take care in using Xdummy; although it never has, it is
78 | possible that it could damage hardware. One can use the -prconf
79 | option to have it print out the xorg.conf config that it would use
80 | and then inspect it carefully before actually using it.
81 |
82 | This program no longer needs to be run as root as of 12/2009.
83 | However, if there are problems for certain situations (usually older
84 | servers) it may perform better if run as root (use the -root option.)
85 | When running as root remember the previous paragraph and that Xdummy
86 | comes without any warranty.
87 |
88 | gcc/cc and other build tools are required for this script to be able
89 | to compile the LD_PRELOAD shared object. Be sure they are installed
90 | on the system. See -install and -uninstall described below.
91 |
92 | Your Linux distribution may not install the dummy driver by default,
93 | e.g:
94 |
95 | /usr/lib/xorg/modules/drivers/dummy_drv.so
96 |
97 | some have it in a package named xserver-xorg-video-dummy that you
98 | need to install.
99 |
100 | Usage:
101 |
102 | $program <${program}-args> <Xserver-args>
103 |
104 | (actually, the arguments can be supplied in any order.)
105 |
106 | Examples:
107 |
108 | $program -install
109 |
110 | $program :1
111 |
112 | $program -debug :1
113 |
114 | $program -tmpdir ~/mytmp :1 -nolisten tcp
115 |
116 | startx example:
117 |
118 | startx -e bash -- $program :2 -depth 16
119 |
120 | (if startx needs to be run as root, you can su(1) to a normal
121 | user in the bash shell and then launch ~/.xinitrc or ~/.xsession,
122 | gnome-session, startkde, startxfce4, etc.)
123 |
124 | xdm example:
125 |
126 | xdm -config /usr/local/dummy/xdm-config -nodaemon
127 |
128 | where the xdm-config file has line:
129 |
130 | DisplayManager.servers: /usr/local/dummy/Xservers
131 |
132 | and /usr/local/dummy/Xservers has lines:
133 |
134 | :1 local /usr/local/dummy/Xdummy :1 -debug
135 | :2 local /usr/local/dummy/Xdummy :2 -debug
136 |
137 | (-debug is optional)
138 |
139 | gdm/kdm example:
140 |
141 | TBD.
142 |
143 | Config file:
144 |
145 | If the file $program.cfg exists it will be sourced as shell
146 | commands. Usually one will set some variables this way.
147 | To disable sourcing, supply -nocfg or set XDUMMY_NOCFG=1.
148 |
149 | Root permission and x11vnc:
150 |
151 | Update: as of 12/2009 this program no longer must be run as root.
152 | So try it as non-root before running it as root and/or the
153 | following schemes.
154 |
155 | In some circumstances X server program may need to be run as root.
156 | If so, one could run x11vnc as root with -unixpw (it switches
157 | to the user that logs in) and that may be OK, some other ideas:
158 |
159 | - add this to sudo via visudo:
160 |
161 | ALL ALL = NOPASSWD: /usr/local/bin/Xdummy
162 |
163 | - use this little suid wrapper:
164 | /*
165 | * xdummy.c
166 | *
167 | cc -o ./xdummy xdummy.c
168 | sudo cp ./xdummy /usr/local/bin/xdummy
169 | sudo chown root:root /usr/local/bin/xdummy
170 | sudo chmod u+s /usr/local/bin/xdummy
171 | *
172 | */
173 | #include <stdio.h>
174 | #include <stdlib.h>
175 | #include <unistd.h>
176 | #include <sys/types.h>
177 |
178 | int main (int argc, char *argv[]) {
179 | extern char **environ;
180 | char str[100];
181 | sprintf(str, "XDUMMY_UID=%d", (int) getuid());
182 | putenv(str);
183 | setuid(0);
184 | setgid(0);
185 | execv("/usr/local/bin/Xdummy", argv);
186 | exit(1);
187 | return 1;
188 | }
189 |
190 |
191 | Options:
192 |
193 | ${program}-args:
194 |
195 | -install Compile the LD_PRELOAD shared object and install it
196 | next to the $program script file as:
197 |
198 | $0.so
199 |
200 | When that file exists it is used as the LD_PRELOAD
201 | shared object without recompiling. Otherwise,
202 | each time $program is run the LD_PRELOAD shared
203 | object is compiled as a file in /tmp (or -tmpdir)
204 |
205 | If you set the environment variable
206 | INTERPOSE_GETUID=1 when building, then when
207 | $program is run as an ordinary user, the shared
208 | object will interpose getuid() calls and pretend
209 | to be root. Otherwise it doesn't pretend to
210 | be root.
211 |
212 | You can also set the CFLAGS environment variable
213 | to anything else you want on the compile cmdline.
214 |
215 | -uninstall Remove the file:
216 |
217 | $0.so
218 |
219 | The LD_PRELOAD shared object will then be compiled
220 | each time this program is run.
221 |
222 | The X server is not started under -install, -uninstall, or -prconf.
223 |
224 |
225 | :N The DISPLAY (e.g. :15) is often the first
226 | argument. It is passed to the real X server and
227 | also used by the Xdummy script as an identifier.
228 |
229 | -geom geom1[,geom2...] Take the geometry (e.g. 1024x768) or list
230 | of geometries and insert them into the Screen
231 | section of the tweaked X server config file.
232 | Use this to have a different geometry than the
233 | one(s) in the system config file.
234 |
235 | The option -geometry can be used instead of -geom;
236 | x11vnc calls Xdummy and Xvfb this way.
237 |
238 | -nomodelines When you specify -geom/-geometry, $program will
239 | create Modelines for each geometry and put them
240 | in the Monitor section. If you do not want this
241 | then supply -nomodelines.
242 |
243 | -depth n Use pixel color depth n (e.g. 8, 16, or 24). This
244 | makes sure the X config file has a Screen.Display
245 | subsection of this depth. Note this option is
246 | ALSO passed to the X server.
247 |
248 | -DEPTH n Same as -depth, except not passed to X server.
249 |
250 | -tmpdir dir Specify a temporary directory, owned by you and
251 | only writable by you. This is used in place of
252 | /tmp/Xdummy.\$USER/.. to place the $program.so
253 | shared object, tweaked config files, etc.
254 |
255 | -nonroot Run in non-root mode (working 12/2009, now default)
256 |
257 | -root Run as root (may still be needed in some
258 | environments.) Same as XDUMMY_RUN_AS_ROOT=1.
259 |
260 | -nosudo Do not try to use sudo(1) when re-running as root,
261 | use su(1) instead.
262 |
263 | -xserver path Specify the path to the Xserver to use. Default
264 | is to try "Xorg" first and then "XFree86". If
265 | those are not in \$PATH, it tries these locations:
266 | /usr/bin/Xorg
267 | /usr/X11R6/bin/Xorg
268 | /usr/X11R6/bin/XFree86
269 |
270 | -n Do not run the command to start the X server,
271 | just show the command that $program would run.
272 | The LD_PRELOAD shared object will be built,
273 | if needed. Also note any XDUMMY* environment
274 | variables that need to be set.
275 |
276 | -prconf Print, to stdout, the tweaked Xorg/XFree86
277 | config file (-config and -xf86config server
278 | options, respectively.) The Xserver is not
279 | started.
280 |
281 | -notweak Do not tweak (modify) the Xorg/XFree86 config file
282 | (system or server command line) at all. The -geom
283 | and similar config file modifications are ignored.
284 |
285 | It is up to you to make sure it is a working
286 | config file (e.g. "dummy" driver, etc.)
287 | Perhaps you want to use a file based on the
288 | -prconf output.
289 |
290 | -nocfg Do not try to source $program.cfg even if it
291 | exists. Same as setting XDUMMY_NOCFG=1.
292 |
293 | -debug Extra debugging output.
294 |
295 | -strace strace(1) the Xserver process (for troubleshooting.)
296 | -ltrace ltrace(1) instead of strace (can be slow.)
297 |
298 | -h, -help Print out this help.
299 |
300 |
301 | Xserver-args:
302 |
303 | Most of the Xorg and XFree86 options will work and are simply
304 | passed along if you supply them. Important ones that may be
305 | supplied if missing:
306 |
307 | :N X Display number for server to use.
308 |
309 | vtNN Linux virtual terminal (VT) to use (a VT is currently
310 | still used, just not switched to and from.)
311 |
312 | -config file Driver "dummy" tweaked config file, a
313 | -xf86config file number of settings are tweaked besides Driver.
314 |
315 | If -config/-xf86config is not given, the system one
316 | (e.g. /etc/X11/xorg.conf) is used. If the system one cannot be
317 | found, a built-in one is used. Any settings in the config file
318 | that are not consistent with "dummy" mode will be overwritten
319 | (unless -notweak is specified.)
320 |
321 | Use -config xdummy-builtin to force usage of the builtin config.
322 |
323 | If "file" is only a basename (e.g. "xorg.dummy.conf") with no /'s,
324 | then no tweaking of it is done: the X server will look for that
325 | basename via its normal search algorithm. If the found file does
326 | not refer to the "dummy" driver, etc, then the X server will fail.
327 |
328 | You can set the env. var. XDUMMY_EXTRA_SERVER_ARGS to hold some
329 | extra Xserver-args too. (Useful for cfg file.)
330 |
331 | Notes:
332 |
333 | The Xorg/XFree86 "dummy" driver is currently undocumented. It works
334 | well in this mode, but it is evidently not intended for end-users.
335 | So it could be removed or broken at any time.
336 |
337 | If the display Xserver-arg (e.g. :1) is not given, or ":" is given
338 | that indicates $program should try to find a free one (based on
339 | tcp ports.)
340 |
341 | If the display virtual terminal, VT, (e.g. vt9) is not given that
342 | indicates $program should try to find a free one (or guess a high one.)
343 |
344 | This program is not completely secure WRT files in /tmp (but it tries
345 | to a good degree.) Better is to use the -tmpdir option to supply a
346 | directory only writable by you. Even better is to get rid of users
347 | on the local machine you do not trust :-)
348 |
349 | Set XDUMMY_SET_XV=1 to turn on debugging output for this script.
350 |
351 | END
352 | }
353 |
354 | warn() {
355 | echo "$*" 1>&2
356 | }
357 |
358 | if [ "X$XDUMMY_SET_XV" != "X" ]; then
359 | set -xv
360 | fi
361 |
362 | if [ "X$XDUMMY_UID" = "X" ]; then
363 | XDUMMY_UID=`id -u`
364 | export XDUMMY_UID
365 | fi
366 | if [ "X$XDUMMY_UID" = "X0" ]; then
367 | if [ "X$SUDO_UID" != "X" ]; then
368 | XDUMMY_UID=$SUDO_UID
369 | export XDUMMY_UID
370 | fi
371 | fi
372 |
373 | # check if root=1 first:
374 | #
375 | if [ "X$XDUMMY_RUN_AS_ROOT" = "X1" ]; then
376 | root=1
377 | fi
378 | for arg in $*
379 | do
380 | if [ "X$arg" = "X-nonroot" ]; then
381 | root=""
382 | elif [ "X$arg" = "X-root" ]; then
383 | root=1
384 | elif [ "X$arg" = "X-nocfg" ]; then
385 | XDUMMY_NOCFG=1
386 | export XDUMMY_NOCFG
387 | fi
388 | done
389 |
390 | if [ "X$XDUMMY_NOCFG" = "X" -a -f "$0.cfg" ]; then
391 | . "$0.cfg"
392 | fi
393 |
394 | # See if it really needs to be run as root:
395 | #
396 | if [ "X$XDUMMY_SU_EXEC" = "X" -a "X$root" = "X1" -a "X`id -u`" != "X0" ]; then
397 | # this is to prevent infinite loop in case su/sudo doesn't work:
398 | XDUMMY_SU_EXEC=1
399 | export XDUMMY_SU_EXEC
400 |
401 | dosu=1
402 | nosudo=""
403 |
404 | for arg in $*
405 | do
406 | if [ "X$arg" = "X-nonroot" ]; then
407 | dosu=""
408 | elif [ "X$arg" = "X-nosudo" ]; then
409 | nosudo="1"
410 | elif [ "X$arg" = "X-help" ]; then
411 | dosu=""
412 | elif [ "X$arg" = "X-h" ]; then
413 | dosu=""
414 | elif [ "X$arg" = "X-install" ]; then
415 | dosu=""
416 | elif [ "X$arg" = "X-uninstall" ]; then
417 | dosu=""
418 | elif [ "X$arg" = "X-n" ]; then
419 | dosu=""
420 | elif [ "X$arg" = "X-prconf" ]; then
421 | dosu=""
422 | fi
423 | done
424 | if [ $dosu ]; then
425 | # we need to restart it with su/sudo:
426 | if type sudo > /dev/null 2>&1; then
427 | :
428 | else
429 | nosudo=1
430 | fi
431 | if [ "X$nosudo" = "X" ]; then
432 | warn "$program: supply the sudo password to restart as root:"
433 | if [ "X$XDUMMY_UID" != "X" ]; then
434 | exec sudo $0 -uid $XDUMMY_UID "$@"
435 | else
436 | exec sudo $0 "$@"
437 | fi
438 | else
439 | warn "$program: supply the root password to restart as root:"
440 | if [ "X$XDUMMY_UID" != "X" ]; then
441 | exec su -c "$0 -uid $XDUMMY_UID $*"
442 | else
443 | exec su -c "$0 $*"
444 | fi
445 | fi
446 | # DONE:
447 | exit
448 | fi
449 | fi
450 |
451 | # This will hold the X display, e.g. :20
452 | #
453 | disp=""
454 | args=""
455 | cmdline_config=""
456 |
457 | # Process Xdummy args:
458 | #
459 | while [ "X$1" != "X" ]
460 | do
461 | if [ "X$1" = "X-config" -o "X$1" = "X-xf86config" ]; then
462 | cmdline_config="$2"
463 | fi
464 | case $1 in
465 | ":"*) disp=$1
466 | ;;
467 | "-install") install=1; runit=""
468 | ;;
469 | "-uninstall") uninstall=1; runit=""
470 | ;;
471 | "-n") runit=""
472 | ;;
473 | "-no") runit=""
474 | ;;
475 | "-norun") runit=""
476 | ;;
477 | "-prconf") prconf=1; runit=""
478 | ;;
479 | "-notweak") notweak=1
480 | ;;
481 | "-noconf") notweak=1
482 | ;;
483 | "-nonroot") root=""
484 | ;;
485 | "-root") root=1
486 | ;;
487 | "-nosudo") nosudo=1
488 | ;;
489 | "-xserver") xserver="$2"; shift
490 | ;;
491 | "-uid") XDUMMY_UID="$2"; shift
492 | export XDUMMY_UID
493 | ;;
494 | "-geom") geom="$2"; shift
495 | ;;
496 | "-geometry") geom="$2"; shift
497 | ;;
498 | "-nomodelines") nomodelines=1
499 | ;;
500 | "-depth") depth="$2"; args="$args -depth $2";
501 | shift
502 | ;;
503 | "-DEPTH") depth="$2"; shift
504 | ;;
505 | "-tmpdir") XDUMMY_TMPDIR="$2"; shift
506 | ;;
507 | "-debug") debug=1
508 | ;;
509 | "-nocfg") :
510 | ;;
511 | "-nodebug") debug=""
512 | ;;
513 | "-strace") strace=1
514 | ;;
515 | "-ltrace") strace=2
516 | ;;
517 | "-h") help; exit 0
518 | ;;
519 | "-help") help; exit 0
520 | ;;
521 | *) args="$args $1"
522 | ;;
523 | esac
524 | shift
525 | done
526 |
527 | if [ "X$XDUMMY_EXTRA_SERVER_ARGS" != "X" ]; then
528 | args="$args $XDUMMY_EXTRA_SERVER_ARGS"
529 | fi
530 |
531 | # Try to get a username for use in our tmp directory, etc.
532 | #
533 | user=""
534 | if [ X`id -u` = "X0" ]; then
535 | user=root # this will also be used below for id=0
536 | elif [ "X$USER" != "X" ]; then
537 | user=$USER
538 | elif [ "X$LOGNAME" != "X" ]; then
539 | user=$LOGNAME
540 | fi
541 |
542 | # Keep trying...
543 | #
544 | if [ "X$user" = "X" ]; then
545 | user=`whoami 2>/dev/null`
546 | fi
547 | if [ "X$user" = "X" ]; then
548 | user=`basename "$HOME"`
549 | fi
550 | if [ "X$user" = "X" -o "X$user" = "X." ]; then
551 | user="u$$"
552 | fi
553 |
554 | if [ "X$debug" = "X1" -a "X$runit" != "X" ]; then
555 | echo ""
556 | echo "/usr/bin/env:"
557 | env | egrep -v '^(LS_COLORS|TERMCAP)' | sort
558 | echo ""
559 | fi
560 |
561 | # Function to compile the LD_PRELOAD shared object:
562 | #
563 | make_so() {
564 | # extract code embedded in this script into a tmp C file:
565 | n1=`grep -n '^#code_begin' $0 | head -1 | awk -F: '{print $1}'`
566 | n2=`grep -n '^#code_end' $0 | head -1 | awk -F: '{print $1}'`
567 | n1=`expr $n1 + 1`
568 | dn=`expr $n2 - $n1`
569 |
570 | tmp=$tdir/Xdummy.$RANDOM$$.c
571 | rm -f $tmp
572 | if [ -e $tmp -o -h $tmp ]; then
573 | warn "$tmp still exists."
574 | exit 1
575 | fi
576 | touch $tmp || exit 1
577 | tail -n +$n1 $0 | head -n $dn > $tmp
578 |
579 | # compile it to Xdummy.so:
580 | if [ -f "$SO" ]; then
581 | mv $SO $SO.$$
582 | rm -f $SO.$$
583 | fi
584 | rm -f $SO
585 | touch $SO
586 | if [ ! -f "$SO" ]; then
587 | SO=$tdir/Xdummy.$user.so
588 | warn "warning switching LD_PRELOAD shared object to: $SO"
589 | fi
590 |
591 | if [ -f "$SO" ]; then
592 | mv $SO $SO.$$
593 | rm -f $SO.$$
594 | fi
595 | rm -f $SO
596 |
597 | # we assume gcc:
598 | if [ "X$INTERPOSE_GETUID" = "X1" ]; then
599 | CFLAGS="$CFLAGS -DINTERPOSE_GETUID"
600 | fi
601 | echo "$program:" cc -shared -fPIC $CFLAGS -o $SO $tmp -ldl
602 | cc -shared -fPIC $CFLAGS -o $SO $tmp -ldl
603 | rc=$?
604 | rm -f $tmp
605 | if [ $rc != 0 ]; then
606 | warn "$program: cannot build $SO"
607 | exit 1
608 | fi
609 | if [ "X$debug" != "X" -o "X$install" != "X" ]; then
610 | warn "$program: created $SO"
611 | ls -l "$SO"
612 | fi
613 | }
614 |
615 | # Set tdir to tmp dir for make_so():
616 | if [ "X$XDUMMY_TMPDIR" != "X" ]; then
617 | tdir=$XDUMMY_TMPDIR
618 | mkdir -p $tdir
619 | else
620 | tdir="/tmp"
621 | fi
622 |
623 | # Handle -install/-uninstall case:
624 | SO=$0.so
625 | if [ "X$install" != "X" -o "X$uninstall" != "X" ]; then
626 | if [ -e "$SO" -o -h "$SO" ]; then
627 | warn "$program: removing $SO"
628 | fi
629 | if [ -f "$SO" ]; then
630 | mv $SO $SO.$$
631 | rm -f $SO.$$
632 | fi
633 | rm -f $SO
634 | if [ -e "$SO" -o -h "$SO" ]; then
635 | warn "warning: $SO still exists."
636 | exit 1
637 | fi
638 | if [ $install ]; then
639 | make_so
640 | if [ ! -f "$SO" ]; then
641 | exit 1
642 | fi
643 | fi
644 | exit 0
645 | fi
646 |
647 | # We need a tmp directory for the .so, tweaked config file, and for
648 | # redirecting filenames we cannot create (under -nonroot)
649 | #
650 | tack=""
651 | if [ "X$XDUMMY_TMPDIR" = "X" ]; then
652 | XDUMMY_TMPDIR="/tmp/Xdummy.$user"
653 |
654 | # try to tack on a unique subdir (display number or pid)
655 | # to allow multiple instances
656 | #
657 | if [ "X$disp" != "X" ]; then
658 | t0=$disp
659 | else
660 | t0=$1
661 | fi
662 | tack=`echo "$t0" | sed -e 's/^.*://'`
663 | if echo "$tack" | grep '^[0-9][0-9]*$' > /dev/null; then
664 | :
665 | else
666 | tack=$$
667 | fi
668 | if [ "X$tack" != "X" ]; then
669 | XDUMMY_TMPDIR="$XDUMMY_TMPDIR/$tack"
670 | fi
671 | fi
672 |
673 | tmp=$XDUMMY_TMPDIR
674 | if echo "$tmp" | grep '^/tmp' > /dev/null; then
675 | if [ "X$tmp" != "X/tmp" -a "X$tmp" != "X/tmp/" ]; then
676 | # clean this subdir of /tmp out, otherwise leave it...
677 | rm -rf $XDUMMY_TMPDIR
678 | if [ -e $XDUMMY_TMPDIR ]; then
679 | warn "$XDUMMY_TMPDIR still exists"
680 | exit 1
681 | fi
682 | fi
683 | fi
684 |
685 | mkdir -p $XDUMMY_TMPDIR
686 | chmod 700 $XDUMMY_TMPDIR
687 | if [ "X$tack" != "X" ]; then
688 | chmod 700 `dirname "$XDUMMY_TMPDIR"` 2>/dev/null
689 | fi
690 |
691 | # See if we can write something there:
692 | #
693 | tfile="$XDUMMY_TMPDIR/test.file"
694 | touch $tfile
695 | if [ ! -f "$tfile" ]; then
696 | XDUMMY_TMPDIR="/tmp/Xdummy.$$.$USER"
697 | warn "warning: setting tmpdir to $XDUMMY_TMPDIR ..."
698 | rm -rf $XDUMMY_TMPDIR || exit 1
699 | mkdir -p $XDUMMY_TMPDIR || exit 1
700 | fi
701 | rm -f $tfile
702 |
703 | export XDUMMY_TMPDIR
704 |
705 | # Compile the LD_PRELOAD shared object if needed (needs XDUMMY_TMPDIR)
706 | #
707 | if [ ! -f "$SO" ]; then
708 | SO="$XDUMMY_TMPDIR/Xdummy.so"
709 | make_so
710 | fi
711 |
712 | # Decide which X server to use:
713 | #
714 | if [ "X$xserver" = "X" ]; then
715 | if type Xorg >/dev/null 2>&1; then
716 | xserver="Xorg"
717 | elif type XFree86 >/dev/null 2>&1; then
718 | xserver="XFree86"
719 | elif [ -x /usr/bin/Xorg ]; then
720 | xserver="/usr/bin/Xorg"
721 | elif [ -x /usr/X11R6/bin/Xorg ]; then
722 | xserver="/usr/X11R6/bin/Xorg"
723 | elif [ -x /usr/X11R6/bin/XFree86 ]; then
724 | xserver="/usr/X11R6/bin/XFree86"
725 | fi
726 | if [ "X$xserver" = "X" ]; then
727 | # just let it fail below.
728 | xserver="/usr/bin/Xorg"
729 | warn "$program: cannot locate a stock Xserver... assuming $xserver"
730 | fi
731 | fi
732 |
733 | # See if the binary is suid or not readable under -nonroot mode:
734 | #
735 | if [ "X$BASH_VERSION" != "X" ]; then
736 | xserver_path=`type -p $xserver 2>/dev/null`
737 | else
738 | xserver_path=`type $xserver 2>/dev/null | awk '{print $NF}'`
739 | fi
740 | if [ -e "$xserver_path" -a "X$root" = "X" -a "X$runit" != "X" ]; then
741 | if [ ! -r $xserver_path -o -u $xserver_path -o -g $xserver_path ]; then
742 | # XXX not quite correct with rm -rf $XDUMMY_TMPDIR ...
743 | # we keep on a filesystem we know root can write to.
744 | base=`basename "$xserver_path"`
745 | new="/tmp/$base.$user.bin"
746 | if [ -e $new ]; then
747 | snew=`ls -l $new | awk '{print $5}' | grep '^[0-9][0-9]*$'`
748 | sold=`ls -l $xserver_path | awk '{print $5}' | grep '^[0-9][0-9]*$'`
749 | if [ "X$snew" != "X" -a "X$sold" != "X" -a "X$sold" != "X$snew" ]; then
750 | warn "removing different sized copy:"
751 | ls -l $new $xserver_path
752 | rm -f $new
753 | fi
754 | fi
755 | if [ ! -e $new -o ! -s $new ]; then
756 | rm -f $new
757 | touch $new || exit 1
758 | chmod 700 $new || exit 1
759 | if [ ! -r $xserver_path ]; then
760 | warn ""
761 | warn "NEED TO COPY UNREADABLE $xserver_path to $new as root:"
762 | warn ""
763 | ls -l $xserver_path 1>&2
764 | warn ""
765 | warn "This only needs to be done once:"
766 | warn " cat $xserver_path > $new"
767 | warn ""
768 | nos=$nosudo
769 | if type sudo > /dev/null 2>&1; then
770 | :
771 | else
772 | nos=1
773 | fi
774 | if [ "X$nos" = "X1" ]; then
775 | warn "Please supply root passwd to 'su -c'"
776 | su -c "cat $xserver_path > $new"
777 | else
778 | warn "Please supply the sudo passwd if asked:"
779 | sudo /bin/sh -c "cat $xserver_path > $new"
780 | fi
781 | else
782 | warn ""
783 | warn "COPYING SETUID $xserver_path to $new"
784 | warn ""
785 | ls -l $xserver_path 1>&2
786 | warn ""
787 | cat $xserver_path > $new
788 | fi
789 | ls -l $new
790 | if [ -s $new ]; then
791 | :
792 | else
793 | rm -f $new
794 | ls -l $new
795 | exit 1
796 | fi
797 | warn ""
798 | warn "Please restart Xdummy now."
799 | exit 0
800 | fi
801 | if [ ! -O $new ]; then
802 | warn "file \"$new\" not owned by us!"
803 | ls -l $new
804 | exit 1
805 | fi
806 | xserver=$new
807 | fi
808 | fi
809 |
810 | # Work out display:
811 | #
812 | if [ "X$disp" != "X" ]; then
813 | :
814 | elif [ "X$1" != "X" ]; then
815 | if echo "$1" | grep '^:[0-9]' > /dev/null; then
816 | disp=$1
817 | shift
818 | elif [ "X$1" = "X:" ]; then
819 | # ":" means for us to find one.
820 | shift
821 | fi
822 | fi
823 | if [ "X$disp" = "X" -o "X$disp" = "X:" ]; then
824 | # try to find an open display port:
825 | # (tcp outdated...)
826 | ports=`netstat -ant | grep LISTEN | awk '{print $4}' | sed -e 's/^.*://'`
827 | n=0
828 | while [ $n -le 20 ]
829 | do
830 | port=`printf "60%02d" $n`
831 | if echo "$ports" | grep "^${port}\$" > /dev/null; then
832 | :
833 | else
834 | disp=":$n"
835 | warn "$program: auto-selected DISPLAY $disp"
836 | break
837 | fi
838 | n=`expr $n + 1`
839 | done
840 | fi
841 |
842 | # Work out which vt to use, try to find/guess an open one if necessary.
843 | #
844 | vt=""
845 | for arg in $*
846 | do
847 | if echo "$arg" | grep '^vt' > /dev/null; then
848 | vt=$arg
849 | break
850 | fi
851 | done
852 | if [ "X$vt" = "X" ]; then
853 | if [ "X$user" = "Xroot" ]; then
854 | # root can use fuser(1) to see if it is in use:
855 | if type fuser >/dev/null 2>&1; then
856 | # try /dev/tty17 thru /dev/tty32
857 | n=17
858 | while [ $n -le 32 ]
859 | do
860 | dev="/dev/tty$n"
861 | if fuser $dev >/dev/null 2>&1; then
862 | :
863 | else
864 | vt="vt$n"
865 | warn "$program: auto-selected VT $vt => $dev"
866 | break
867 | fi
868 | n=`expr $n + 1`
869 | done
870 | fi
871 | fi
872 | if [ "X$vt" = "X" ]; then
873 | # take a wild guess...
874 | vt=vt16
875 | warn "$program: selected fallback VT $vt"
876 | fi
877 | else
878 | vt=""
879 | fi
880 |
881 | # Decide flavor of Xserver:
882 | #
883 | stype=`basename "$xserver"`
884 | if echo "$stype" | grep -i xfree86 > /dev/null; then
885 | stype=xfree86
886 | else
887 | stype=xorg
888 | fi
889 |
890 | tweak_config() {
891 | in="$1"
892 | config2="$XDUMMY_TMPDIR/xdummy_modified_xconfig.conf"
893 | if [ "X$disp" != "X" ]; then
894 | d=`echo "$disp" | sed -e 's,/,,g' -e 's/:/_/g'`
895 | config2="$config2$d"
896 | fi
897 |
898 | # perl script to tweak the config file... add/delete options, etc.
899 | #
900 | env XDUMMY_GEOM=$geom \
901 | XDUMMY_DEPTH=$depth \
902 | XDUMMY_NOMODELINES=$nomodelines \
903 | perl > $config2 < $in -e '
904 | $n = 0;
905 | $geom = $ENV{XDUMMY_GEOM};
906 | $depth = $ENV{XDUMMY_DEPTH};
907 | $nomodelines = $ENV{XDUMMY_NOMODELINES};
908 | $mode_str = "";
909 | $videoram = "24000";
910 | $HorizSync = "30.0 - 130.0";
911 | $VertRefresh = "50.0 - 250.0";
912 | if ($geom ne "") {
913 | my $tmp = "";
914 | foreach $g (split(/,/, $geom)) {
915 | $tmp .= "\"$g\" ";
916 | if (!$nomodelines && $g =~ /(\d+)x(\d+)/) {
917 | my $w = $1;
918 | my $h = $2;
919 | $mode_str .= " Modeline \"$g\" ";
920 | my $dot = sprintf("%.2f", $w * $h * 70 * 1.e-6);
921 | $mode_str .= $dot;
922 | $mode_str .= " " . $w;
923 | $mode_str .= " " . int(1.02 * $w);
924 | $mode_str .= " " . int(1.10 * $w);
925 | $mode_str .= " " . int(1.20 * $w);
926 | $mode_str .= " " . $h;
927 | $mode_str .= " " . int($h + 1);
928 | $mode_str .= " " . int($h + 3);
929 | $mode_str .= " " . int($h + 20);
930 | $mode_str .= "\n";
931 | }
932 | }
933 | $tmp =~ s/\s*$//;
934 | $geom = $tmp;
935 | }
936 | while (<>) {
937 | if ($ENV{XDUMMY_NOTWEAK}) {
938 | print $_;
939 | next;
940 | }
941 | $n++;
942 | if (/^\s*#/) {
943 | # pass comments straight thru
944 | print;
945 | next;
946 | }
947 | if (/^\s*Section\s+(\S+)/i) {
948 | # start of Section
949 | $sect = $1;
950 | $sect =~ s/\W//g;
951 | $sect =~ y/A-Z/a-z/;
952 | $sects{$sect} = 1;
953 | print;
954 | next;
955 | }
956 | if (/^\s*EndSection/i) {
957 | # end of Section
958 | if ($sect eq "serverflags") {
959 | if (!$got_DontVTSwitch) {
960 | print " ##Xdummy:##\n";
961 | print " Option \"DontVTSwitch\" \"true\"\n";
962 | }
963 | if (!$got_AllowMouseOpenFail) {
964 | print " ##Xdummy:##\n";
965 | print " Option \"AllowMouseOpenFail\" \"true\"\n";
966 | }
967 | if (!$got_PciForceNone) {
968 | print " ##Xdummy:##\n";
969 | print " Option \"PciForceNone\" \"true\"\n";
970 | }
971 | } elsif ($sect eq "device") {
972 | if (!$got_Driver) {
973 | print " ##Xdummy:##\n";
974 | print " Driver \"dummy\"\n";
975 | }
976 | if (!$got_VideoRam) {
977 | print " ##Xdummy:##\n";
978 | print " VideoRam $videoram\n";
979 | }
980 | } elsif ($sect eq "screen") {
981 | if ($depth ne "" && !$got_DefaultDepth) {
982 | print " ##Xdummy:##\n";
983 | print " DefaultDepth $depth\n";
984 | }
985 | if ($got_Monitor eq "") {
986 | print " ##Xdummy:##\n";
987 | print " Monitor \"Monitor0\"\n";
988 | }
989 | } elsif ($sect eq "monitor") {
990 | if (!$got_HorizSync) {
991 | print " ##Xdummy:##\n";
992 | print " HorizSync $HorizSync\n";
993 | }
994 | if (!$got_VertRefresh) {
995 | print " ##Xdummy:##\n";
996 | print " VertRefresh $VertRefresh\n";
997 | }
998 | if (!$nomodelines) {
999 | print " ##Xdummy:##\n";
1000 | print $mode_str;
1001 | }
1002 | }
1003 | $sect = "";
1004 | print;
1005 | next;
1006 | }
1007 |
1008 | if (/^\s*SubSection\s+(\S+)/i) {
1009 | # start of Section
1010 | $subsect = $1;
1011 | $subsect =~ s/\W//g;
1012 | $subsect =~ y/A-Z/a-z/;
1013 | $subsects{$subsect} = 1;
1014 | if ($sect eq "screen" && $subsect eq "display") {
1015 | $got_Modes = 0;
1016 | }
1017 | print;
1018 | next;
1019 | }
1020 | if (/^\s*EndSubSection/i) {
1021 | # end of SubSection
1022 | if ($sect eq "screen") {
1023 | if ($subsect eq "display") {
1024 | if ($depth ne "" && !$set_Depth) {
1025 | print " ##Xdummy:##\n";
1026 | print " Depth\t$depth\n";
1027 | }
1028 | if ($geom ne "" && ! $got_Modes) {
1029 | print " ##Xdummy:##\n";
1030 | print " Modes\t$geom\n";
1031 | }
1032 | }
1033 | }
1034 | $subsect = "";
1035 | print;
1036 | next;
1037 | }
1038 |
1039 | $l = $_;
1040 | $l =~ s/#.*$//;
1041 | if ($sect eq "serverflags") {
1042 | if ($l =~ /^\s*Option.*DontVTSwitch/i) {
1043 | $_ =~ s/false/true/ig;
1044 | $got_DontVTSwitch = 1;
1045 | }
1046 | if ($l =~ /^\s*Option.*AllowMouseOpenFail/i) {
1047 | $_ =~ s/false/true/ig;
1048 | $got_AllowMouseOpenFail = 1;
1049 | }
1050 | if ($l =~ /^\s*Option.*PciForceNone/i) {
1051 | $_ =~ s/false/true/ig;
1052 | $got_PciForceNone= 1;
1053 | }
1054 | }
1055 | if ($sect eq "module") {
1056 | if ($l =~ /^\s*Load.*\b(dri|fbdevhw)\b/i) {
1057 | $_ = "##Xdummy## $_";
1058 | }
1059 | }
1060 | if ($sect eq "monitor") {
1061 | if ($l =~ /^\s*HorizSync/i) {
1062 | $got_HorizSync = 1;
1063 | }
1064 | if ($l =~ /^\s*VertRefresh/i) {
1065 | $got_VertRefresh = 1;
1066 | }
1067 | }
1068 | if ($sect eq "device") {
1069 | if ($l =~ /^(\s*Driver)\b/i) {
1070 | $_ = "$1 \"dummy\"\n";
1071 | $got_Driver = 1;
1072 | }
1073 | if ($l =~ /^\s*VideoRam/i) {
1074 | $got_VideoRam= 1;
1075 | }
1076 | }
1077 | if ($sect eq "inputdevice") {
1078 | if ($l =~ /^\s*Option.*\bDevice\b/i) {
1079 | print " ##Xdummy:##\n";
1080 | $_ = " Option \"Device\" \"/dev/dilbert$n\"\n";
1081 | }
1082 | }
1083 | if ($sect eq "screen") {
1084 | if ($l =~ /^\s*DefaultDepth\s+(\d+)/i) {
1085 | if ($depth ne "") {
1086 | print " ##Xdummy:##\n";
1087 | $_ = " DefaultDepth\t$depth\n";
1088 | }
1089 | $got_DefaultDepth = 1;
1090 | }
1091 | if ($l =~ /^\s*Monitor\s+(\S+)/i) {
1092 | $got_Monitor = $1;
1093 | $got_Monitor =~ s/"//g;
1094 | }
1095 | if ($subsect eq "display") {
1096 | if ($geom ne "") {
1097 | if ($l =~ /^(\s*Modes)\b/i) {
1098 | print " ##Xdummy:##\n";
1099 | $_ = "$1 $geom\n";
1100 | $got_Modes = 1;
1101 | }
1102 | }
1103 | if ($l =~ /^\s*Depth\s+(\d+)/i) {
1104 | my $d = $1;
1105 | if (!$set_Depth && $depth ne "") {
1106 | $set_Depth = 1;
1107 | if ($depth != $d) {
1108 | print " ##Xdummy:##\n";
1109 | $_ = " Depth\t$depth\n";
1110 | }
1111 | }
1112 | }
1113 | }
1114 | }
1115 | print;
1116 | }
1117 | if ($ENV{XDUMMY_NOTWEAK}) {
1118 | exit;
1119 | }
1120 | # create any crucial sections that are missing:
1121 | if (! exists($sects{serverflags})) {
1122 | print "\n##Xdummy:##\n";
1123 | print "Section \"ServerFlags\"\n";
1124 | print " Option \"DontVTSwitch\" \"true\"\n";
1125 | print " Option \"AllowMouseOpenFail\" \"true\"\n";
1126 | print " Option \"PciForceNone\" \"true\"\n";
1127 | print "EndSection\n";
1128 | }
1129 | if (! exists($sects{device})) {
1130 | print "\n##Xdummy:##\n";
1131 | print "Section \"Device\"\n";
1132 | print " Identifier \"Videocard0\"\n";
1133 | print " Driver \"dummy\"\n";
1134 | print " VideoRam $videoram\n";
1135 | print "EndSection\n";
1136 | }
1137 | if (! exists($sects{monitor})) {
1138 | print "\n##Xdummy:##\n";
1139 | print "Section \"Monitor\"\n";
1140 | print " Identifier \"Monitor0\"\n";
1141 | print " HorizSync $HorizSync\n";
1142 | print " VertRefresh $VertRefresh\n";
1143 | print "EndSection\n";
1144 | }
1145 | if (! exists($sects{screen})) {
1146 | print "\n##Xdummy:##\n";
1147 | print "Section \"Screen\"\n";
1148 | print " Identifier \"Screen0\"\n";
1149 | print " Device \"Videocard0\"\n";
1150 | if ($got_Monitor ne "") {
1151 | print " Monitor \"$got_Monitor\"\n";
1152 | } else {
1153 | print " Monitor \"Monitor0\"\n";
1154 | }
1155 | if ($depth ne "") {
1156 | print " DefaultDepth $depth\n";
1157 | } else {
1158 | print " DefaultDepth 24\n";
1159 | }
1160 | print " SubSection \"Display\"\n";
1161 | print " Viewport 0 0\n";
1162 | print " Depth 24\n";
1163 | if ($got_Modes) {
1164 | ;
1165 | } elsif ($geom ne "") {
1166 | print " Modes $geom\n";
1167 | } else {
1168 | print " Modes \"1280x1024\" \"1024x768\" \"800x600\"\n";
1169 | }
1170 | print " EndSubSection\n";
1171 | print "EndSection\n";
1172 | }
1173 | ';
1174 | }
1175 |
1176 | # Work out config file and tweak it.
1177 | #
1178 | if [ "X$cmdline_config" = "X" ]; then
1179 | :
1180 | elif [ "X$cmdline_config" = "Xxdummy-builtin" ]; then
1181 | :
1182 | elif echo "$cmdline_config" | grep '/' > /dev/null; then
1183 | :
1184 | else
1185 | # ignore basename only case (let server handle it)
1186 | cmdline_config=""
1187 | notweak=1
1188 | fi
1189 |
1190 | config=$cmdline_config
1191 |
1192 | if [ "X$notweak" = "X1" -a "X$root" = "X" -a -f "$cmdline_config" ]; then
1193 | # if not root we need to copy (but not tweak) the specified config.
1194 | XDUMMY_NOTWEAK=1
1195 | export XDUMMY_NOTWEAK
1196 | notweak=""
1197 | fi
1198 |
1199 | if [ ! $notweak ]; then
1200 | # tweaked config will be put in $config2:
1201 | config2=""
1202 | if [ "X$config" = "X" ]; then
1203 | # use the default one:
1204 | if [ "X$stype" = "Xxorg" ]; then
1205 | config=/etc/X11/xorg.conf
1206 | else
1207 | if [ -f "/etc/X11/XF86Config-4" ]; then
1208 | config="/etc/X11/XF86Config-4"
1209 | else
1210 | config="/etc/X11/XF86Config"
1211 | fi
1212 | fi
1213 | if [ ! -f "$config" ]; then
1214 | for c in /etc/X11/xorg.conf /etc/X11/XF86Config-4 /etc/X11/XF86Config
1215 | do
1216 | if [ -f $c ]; then
1217 | config=$c
1218 | break
1219 | fi
1220 | done
1221 | fi
1222 | fi
1223 |
1224 | if [ "X$config" = "Xxdummy-builtin" ]; then
1225 | config=""
1226 | fi
1227 |
1228 | if [ ! -f "$config" ]; then
1229 | config="$XDUMMY_TMPDIR/xorg.conf"
1230 | warn "$program: using minimal built-in xorg.conf settings."
1231 | cat > $config < /dev/null; then
1359 | so=`echo "$so" | sed -e "s,^\.,$pwd,"`
1360 | fi
1361 | if echo "$so" | grep '/' > /dev/null; then
1362 | :
1363 | else
1364 | so="$pwd/$so"
1365 | fi
1366 | warn "env LD_PRELOAD=$so $xserver $disp $args $vt"
1367 | warn ""
1368 | if [ ! $runit ]; then
1369 | exit 0
1370 | fi
1371 | fi
1372 |
1373 | if [ $strace ]; then
1374 | if [ "X$strace" = "X2" ]; then
1375 | ltrace -f env LD_PRELOAD=$SO $xserver $disp $args $vt
1376 | else
1377 | strace -f env LD_PRELOAD=$SO $xserver $disp $args $vt
1378 | fi
1379 | else
1380 | exec env LD_PRELOAD=$SO $xserver $disp $args $vt
1381 | fi
1382 |
1383 | exit $?
1384 |
1385 | #########################################################################
1386 |
1387 | code() {
1388 | #code_begin
1389 | #include <stdio.h>
1390 | #define O_ACCMODE 0003
1391 | #define O_RDONLY 00
1392 | #define O_WRONLY 01
1393 | #define O_RDWR 02
1394 | #define O_CREAT 0100 /* not fcntl */
1395 | #define O_EXCL 0200 /* not fcntl */
1396 | #define O_NOCTTY 0400 /* not fcntl */
1397 | #define O_TRUNC 01000 /* not fcntl */
1398 | #define O_APPEND 02000
1399 | #define O_NONBLOCK 04000
1400 | #define O_NDELAY O_NONBLOCK
1401 | #define O_SYNC 010000
1402 | #define O_FSYNC O_SYNC
1403 | #define O_ASYNC 020000
1404 |
1405 | #include <sys/types.h>
1406 | #include <sys/stat.h>
1407 | #include <unistd.h>
1408 |
1409 | #include <stdlib.h>
1410 | #include <string.h>
1411 |
1412 | #define __USE_GNU
1413 | #include <dlfcn.h>
1414 |
1415 | static char tmpdir[4096];
1416 | static char str1[4096];
1417 | static char str2[4096];
1418 |
1419 | static char devs[256][1024];
1420 | static int debug = -1;
1421 | static int root = -1;
1422 | static int changed_uid = 0;
1423 | static int saw_fonts = 0;
1424 | static int saw_lib_modules = 0;
1425 |
1426 | static time_t start = 0;
1427 |
1428 | void check_debug(void) {
1429 | if (debug < 0) {
1430 | if (getenv("XDUMMY_DEBUG") != NULL) {
1431 | debug = 1;
1432 | } else {
1433 | debug = 0;
1434 | }
1435 | /* prevent other processes using the preload: */
1436 | putenv("LD_PRELOAD=");
1437 | }
1438 | }
1439 | void check_root(void) {
1440 | if (root < 0) {
1441 | /* script tells us if we are root */
1442 | if (getenv("XDUMMY_ROOT") != NULL) {
1443 | root = 1;
1444 | } else {
1445 | root = 0;
1446 | }
1447 | }
1448 | }
1449 |
1450 | void check_uid(void) {
1451 | if (start == 0) {
1452 | start = time(NULL);
1453 | if (debug) fprintf(stderr, "START: %u\n", (unsigned int) start);
1454 | return;
1455 | } else if (changed_uid == 0) {
1456 | if (saw_fonts || time(NULL) > start + 20) {
1457 | if (getenv("XDUMMY_UID")) {
1458 | int uid = atoi(getenv("XDUMMY_UID"));
1459 | if (debug) fprintf(stderr, "SETREUID: %d saw_fonts=%d\n", uid, saw_fonts);
1460 | if (uid >= 0) {
1461 | /* this will simply fail in -nonroot mode: */
1462 | setreuid(uid, -1);
1463 | }
1464 | }
1465 | changed_uid = 1;
1466 | }
1467 | }
1468 | }
1469 |
1470 | #define CHECKIT if (debug < 0) check_debug(); \
1471 | if (root < 0) check_root(); \
1472 | check_uid();
1473 |
1474 | static void set_tmpdir(void) {
1475 | char *s;
1476 | static int didset = 0;
1477 | if (didset) {
1478 | return;
1479 | }
1480 | s = getenv("XDUMMY_TMPDIR");
1481 | if (! s) {
1482 | s = "/tmp";
1483 | }
1484 | tmpdir[0] = '\0';
1485 | strcat(tmpdir, s);
1486 | strcat(tmpdir, "/");
1487 | didset = 1;
1488 | }
1489 |
1490 | static char *tmpdir_path(const char *path) {
1491 | char *str;
1492 | set_tmpdir();
1493 | strcpy(str2, path);
1494 | str = str2;
1495 | while (*str) {
1496 | if (*str == '/') {
1497 | *str = '_';
1498 | }
1499 | str++;
1500 | }
1501 | strcpy(str1, tmpdir);
1502 | strcat(str1, str2);
1503 | return str1;
1504 | }
1505 |
1506 | int open(const char *pathname, int flags, unsigned short mode) {
1507 | int fd;
1508 | char *store_dev = NULL;
1509 | static int (*real_open)(const char *, int , unsigned short) = NULL;
1510 |
1511 | CHECKIT
1512 | if (! real_open) {
1513 | real_open = (int (*)(const char *, int , unsigned short))
1514 | dlsym(RTLD_NEXT, "open");
1515 | }
1516 |
1517 | if (strstr(pathname, "lib/modules/")) {
1518 | /* not currently used. */
1519 | saw_lib_modules = 1;
1520 | }
1521 |
1522 | if (!root) {
1523 | if (strstr(pathname, "/dev/") == pathname) {
1524 | store_dev = strdup(pathname);
1525 | }
1526 | if (strstr(pathname, "/dev/tty") == pathname && strcmp(pathname, "/dev/tty")) {
1527 | pathname = tmpdir_path(pathname);
1528 | if (debug) fprintf(stderr, "OPEN: %s -> %s (as FIFO)\n", store_dev, pathname);
1529 | /* we make it a FIFO so ioctl on it does not fail */
1530 | unlink(pathname);
1531 | mkfifo(pathname, 0666);
1532 | } else if (0) {
1533 | /* we used to handle more /dev files ... */
1534 | fd = real_open(pathname, O_WRONLY|O_CREAT, 0777);
1535 | close(fd);
1536 | }
1537 | }
1538 |
1539 | fd = real_open(pathname, flags, mode);
1540 |
1541 | if (debug) fprintf(stderr, "OPEN: %s %d %d fd=%d\n", pathname, flags, mode, fd);
1542 |
1543 | if (! root) {
1544 | if (store_dev) {
1545 | if (fd < 256) {
1546 | strcpy(devs[fd], store_dev);
1547 | }
1548 | free(store_dev);
1549 | }
1550 | }
1551 |
1552 | return(fd);
1553 | }
1554 |
1555 | int open64(const char *pathname, int flags, unsigned short mode) {
1556 | int fd;
1557 |
1558 | CHECKIT
1559 | if (debug) fprintf(stderr, "OPEN64: %s %d %d\n", pathname, flags, mode);
1560 |
1561 | fd = open(pathname, flags, mode);
1562 | return(fd);
1563 | }
1564 |
1565 | int rename(const char *oldpath, const char *newpath) {
1566 | static int (*real_rename)(const char *, const char *) = NULL;
1567 |
1568 | CHECKIT
1569 | if (! real_rename) {
1570 | real_rename = (int (*)(const char *, const char *))
1571 | dlsym(RTLD_NEXT, "rename");
1572 | }
1573 |
1574 | if (debug) fprintf(stderr, "RENAME: %s %s\n", oldpath, newpath);
1575 |
1576 | if (root) {
1577 | return(real_rename(oldpath, newpath));
1578 | }
1579 |
1580 | if (strstr(oldpath, "/var/log") == oldpath) {
1581 | if (debug) fprintf(stderr, "RENAME: returning 0\n");
1582 | return 0;
1583 | }
1584 | return(real_rename(oldpath, newpath));
1585 | }
1586 |
1587 | FILE *fopen(const char *pathname, const char *mode) {
1588 | static FILE* (*real_fopen)(const char *, const char *) = NULL;
1589 | char *str;
1590 |
1591 | if (! saw_fonts) {
1592 | if (strstr(pathname, "/fonts/")) {
1593 | if (strstr(pathname, "fonts.dir")) {
1594 | saw_fonts = 1;
1595 | } else if (strstr(pathname, "fonts.alias")) {
1596 | saw_fonts = 1;
1597 | }
1598 | }
1599 | }
1600 |
1601 | CHECKIT
1602 | if (! real_fopen) {
1603 | real_fopen = (FILE* (*)(const char *, const char *))
1604 | dlsym(RTLD_NEXT, "fopen");
1605 | }
1606 |
1607 | if (debug) fprintf(stderr, "FOPEN: %s %s\n", pathname, mode);
1608 |
1609 | if (strstr(pathname, "xdummy_modified_xconfig.conf")) {
1610 | /* make our config appear to be in /etc/X11, etc. */
1611 | char *q = strrchr(pathname, '/');
1612 | if (q != NULL && getenv("XDUMMY_TMPDIR") != NULL) {
1613 | strcpy(str1, getenv("XDUMMY_TMPDIR"));
1614 | strcat(str1, q);
1615 | if (debug) fprintf(stderr, "FOPEN: %s -> %s\n", pathname, str1);
1616 | pathname = str1;
1617 | }
1618 | }
1619 |
1620 | if (root) {
1621 | return(real_fopen(pathname, mode));
1622 | }
1623 |
1624 | str = (char *) pathname;
1625 | if (strstr(pathname, "/var/log") == pathname) {
1626 | str = tmpdir_path(pathname);
1627 | if (debug) fprintf(stderr, "FOPEN: %s -> %s\n", pathname, str);
1628 | }
1629 | return(real_fopen(str, mode));
1630 | }
1631 |
1632 |
1633 | #define RETURN0 if (debug) \
1634 | {fprintf(stderr, "IOCTL: covered %d 0x%x\n", fd, req);} return 0;
1635 | #define RETURN1 if (debug) \
1636 | {fprintf(stderr, "IOCTL: covered %d 0x%x\n", fd, req);} return -1;
1637 |
1638 | int ioctl(int fd, int req, void *ptr) {
1639 | static int closed_xf86Info_consoleFd = 0;
1640 | static int (*real_ioctl)(int, int , void *) = NULL;
1641 |
1642 | CHECKIT
1643 | if (! real_ioctl) {
1644 | real_ioctl = (int (*)(int, int , void *))
1645 | dlsym(RTLD_NEXT, "open");
1646 | }
1647 | if (debug) fprintf(stderr, "IOCTL: %d 0x%x %p\n", fd, req, ptr);
1648 |
1649 | /* based on xorg-x11-6.8.1-dualhead.patch */
1650 | if (req == VT_GETMODE) {
1651 | /* close(xf86Info.consoleFd) */
1652 | if (0 && ! closed_xf86Info_consoleFd) {
1653 | /* I think better not to close it... */
1654 | close(fd);
1655 | closed_xf86Info_consoleFd = 1;
1656 | }
1657 | RETURN0
1658 | } else if (req == VT_SETMODE) {
1659 | RETURN0
1660 | } else if (req == VT_GETSTATE) {
1661 | RETURN0
1662 | } else if (req == KDSETMODE) {
1663 | RETURN0
1664 | } else if (req == KDSETLED) {
1665 | RETURN0
1666 | } else if (req == KDGKBMODE) {
1667 | RETURN0
1668 | } else if (req == KDSKBMODE) {
1669 | RETURN0
1670 | } else if (req == VT_ACTIVATE) {
1671 | RETURN0
1672 | } else if (req == VT_WAITACTIVE) {
1673 | RETURN0
1674 | } else if (req == VT_RELDISP) {
1675 | if (ptr == (void *) 1) {
1676 | RETURN1
1677 | } else if (ptr == (void *) VT_ACKACQ) {
1678 | RETURN0
1679 | }
1680 | }
1681 |
1682 | return(real_ioctl(fd, req, ptr));
1683 | }
1684 |
1685 | typedef void (*sighandler_t)(int);
1686 | #define SIGUSR1 10
1687 | #define SIG_DFL ((sighandler_t)0)
1688 |
1689 | sighandler_t signal(int signum, sighandler_t handler) {
1690 | static sighandler_t (*real_signal)(int, sighandler_t) = NULL;
1691 |
1692 | CHECKIT
1693 | if (! real_signal) {
1694 | real_signal = (sighandler_t (*)(int, sighandler_t))
1695 | dlsym(RTLD_NEXT, "signal");
1696 | }
1697 |
1698 | if (debug) fprintf(stderr, "SIGNAL: %d %p\n", signum, handler);
1699 |
1700 | if (signum == SIGUSR1) {
1701 | if (debug) fprintf(stderr, "SIGNAL: skip SIGUSR1\n");
1702 | return SIG_DFL;
1703 | }
1704 |
1705 | return(real_signal(signum, handler));
1706 | }
1707 |
1708 | int close(int fd) {
1709 | static int (*real_close)(int) = NULL;
1710 |
1711 | CHECKIT
1712 | if (! real_close) {
1713 | real_close = (int (*)(int)) dlsym(RTLD_NEXT, "close");
1714 | }
1715 |
1716 | if (debug) fprintf(stderr, "CLOSE: %d\n", fd);
1717 | if (!root) {
1718 | if (fd < 256) {
1719 | devs[fd][0] = '\0';
1720 | }
1721 | }
1722 | return(real_close(fd));
1723 | }
1724 |
1725 | struct stat {
1726 | int foo;
1727 | };
1728 |
1729 | int stat(const char *path, struct stat *buf) {
1730 | static int (*real_stat)(const char *, struct stat *) = NULL;
1731 |
1732 | CHECKIT
1733 | if (! real_stat) {
1734 | real_stat = (int (*)(const char *, struct stat *))
1735 | dlsym(RTLD_NEXT, "stat");
1736 | }
1737 |
1738 | if (debug) fprintf(stderr, "STAT: %s\n", path);
1739 |
1740 | return(real_stat(path, buf));
1741 | }
1742 |
1743 | int stat64(const char *path, struct stat *buf) {
1744 | static int (*real_stat64)(const char *, struct stat *) = NULL;
1745 |
1746 | CHECKIT
1747 | if (! real_stat64) {
1748 | real_stat64 = (int (*)(const char *, struct stat *))
1749 | dlsym(RTLD_NEXT, "stat64");
1750 | }
1751 |
1752 | if (debug) fprintf(stderr, "STAT64: %s\n", path);
1753 |
1754 | return(real_stat64(path, buf));
1755 | }
1756 |
1757 | int chown(const char *path, uid_t owner, gid_t group) {
1758 | static int (*real_chown)(const char *, uid_t, gid_t) = NULL;
1759 |
1760 | CHECKIT
1761 | if (! real_chown) {
1762 | real_chown = (int (*)(const char *, uid_t, gid_t))
1763 | dlsym(RTLD_NEXT, "chown");
1764 | }
1765 |
1766 | if (root) {
1767 | return(real_chown(path, owner, group));
1768 | }
1769 |
1770 | if (debug) fprintf(stderr, "CHOWN: %s %d %d\n", path, owner, group);
1771 |
1772 | if (strstr(path, "/dev") == path) {
1773 | if (debug) fprintf(stderr, "CHOWN: return 0\n");
1774 | return 0;
1775 | }
1776 |
1777 | return(real_chown(path, owner, group));
1778 | }
1779 |
1780 | extern int *__errno_location (void);
1781 | #ifndef ENODEV
1782 | #define ENODEV 19
1783 | #endif
1784 |
1785 | int ioperm(unsigned long from, unsigned long num, int turn_on) {
1786 | static int (*real_ioperm)(unsigned long, unsigned long, int) = NULL;
1787 |
1788 | CHECKIT
1789 | if (! real_ioperm) {
1790 | real_ioperm = (int (*)(unsigned long, unsigned long, int))
1791 | dlsym(RTLD_NEXT, "ioperm");
1792 | }
1793 | if (debug) fprintf(stderr, "IOPERM: %d %d %d\n", (int) from, (int) num, turn_on);
1794 | if (root) {
1795 | return(real_ioperm(from, num, turn_on));
1796 | }
1797 | if (from == 0 && num == 1024 && turn_on == 1) {
1798 | /* we want xf86EnableIO to fail */
1799 | if (debug) fprintf(stderr, "IOPERM: setting ENODEV.\n");
1800 | *__errno_location() = ENODEV;
1801 | return -1;
1802 | }
1803 | return 0;
1804 | }
1805 |
1806 | int iopl(int level) {
1807 | static int (*real_iopl)(int) = NULL;
1808 |
1809 | CHECKIT
1810 | if (! real_iopl) {
1811 | real_iopl = (int (*)(int)) dlsym(RTLD_NEXT, "iopl");
1812 | }
1813 | if (debug) fprintf(stderr, "IOPL: %d\n", level);
1814 | if (root) {
1815 | return(real_iopl(level));
1816 | }
1817 | return 0;
1818 | }
1819 |
1820 | #ifdef INTERPOSE_GETUID
1821 |
1822 | /*
1823 | * we got things to work w/o pretending to be root.
1824 | * so we no longer interpose getuid(), etc.
1825 | */
1826 |
1827 | uid_t getuid(void) {
1828 | static uid_t (*real_getuid)(void) = NULL;
1829 | CHECKIT
1830 | if (! real_getuid) {
1831 | real_getuid = (uid_t (*)(void)) dlsym(RTLD_NEXT, "getuid");
1832 | }
1833 | if (root) {
1834 | return(real_getuid());
1835 | }
1836 | if (debug) fprintf(stderr, "GETUID: 0\n");
1837 | return 0;
1838 | }
1839 | uid_t geteuid(void) {
1840 | static uid_t (*real_geteuid)(void) = NULL;
1841 | CHECKIT
1842 | if (! real_geteuid) {
1843 | real_geteuid = (uid_t (*)(void)) dlsym(RTLD_NEXT, "geteuid");
1844 | }
1845 | if (root) {
1846 | return(real_geteuid());
1847 | }
1848 | if (debug) fprintf(stderr, "GETEUID: 0\n");
1849 | return 0;
1850 | }
1851 | uid_t geteuid_kludge1(void) {
1852 | static uid_t (*real_geteuid)(void) = NULL;
1853 | CHECKIT
1854 | if (! real_geteuid) {
1855 | real_geteuid = (uid_t (*)(void)) dlsym(RTLD_NEXT, "geteuid");
1856 | }
1857 | if (debug) fprintf(stderr, "GETEUID: 0 saw_libmodules=%d\n", saw_lib_modules);
1858 | if (root && !saw_lib_modules) {
1859 | return(real_geteuid());
1860 | } else {
1861 | saw_lib_modules = 0;
1862 | return 0;
1863 | }
1864 | }
1865 |
1866 | uid_t getuid32(void) {
1867 | static uid_t (*real_getuid32)(void) = NULL;
1868 | CHECKIT
1869 | if (! real_getuid32) {
1870 | real_getuid32 = (uid_t (*)(void)) dlsym(RTLD_NEXT, "getuid32");
1871 | }
1872 | if (root) {
1873 | return(real_getuid32());
1874 | }
1875 | if (debug) fprintf(stderr, "GETUID32: 0\n");
1876 | return 0;
1877 | }
1878 | uid_t geteuid32(void) {
1879 | static uid_t (*real_geteuid32)(void) = NULL;
1880 | CHECKIT
1881 | if (! real_geteuid32) {
1882 | real_geteuid32 = (uid_t (*)(void)) dlsym(RTLD_NEXT, "geteuid32");
1883 | }
1884 | if (root) {
1885 | return(real_geteuid32());
1886 | }
1887 | if (debug) fprintf(stderr, "GETEUID32: 0\n");
1888 | return 0;
1889 | }
1890 |
1891 | gid_t getgid(void) {
1892 | static gid_t (*real_getgid)(void) = NULL;
1893 | CHECKIT
1894 | if (! real_getgid) {
1895 | real_getgid = (gid_t (*)(void)) dlsym(RTLD_NEXT, "getgid");
1896 | }
1897 | if (root) {
1898 | return(real_getgid());
1899 | }
1900 | if (debug) fprintf(stderr, "GETGID: 0\n");
1901 | return 0;
1902 | }
1903 | gid_t getegid(void) {
1904 | static gid_t (*real_getegid)(void) = NULL;
1905 | CHECKIT
1906 | if (! real_getegid) {
1907 | real_getegid = (gid_t (*)(void)) dlsym(RTLD_NEXT, "getegid");
1908 | }
1909 | if (root) {
1910 | return(real_getegid());
1911 | }
1912 | if (debug) fprintf(stderr, "GETEGID: 0\n");
1913 | return 0;
1914 | }
1915 | gid_t getgid32(void) {
1916 | static gid_t (*real_getgid32)(void) = NULL;
1917 | CHECKIT
1918 | if (! real_getgid32) {
1919 | real_getgid32 = (gid_t (*)(void)) dlsym(RTLD_NEXT, "getgid32");
1920 | }
1921 | if (root) {
1922 | return(real_getgid32());
1923 | }
1924 | if (debug) fprintf(stderr, "GETGID32: 0\n");
1925 | return 0;
1926 | }
1927 | gid_t getegid32(void) {
1928 | static gid_t (*real_getegid32)(void) = NULL;
1929 | CHECKIT
1930 | if (! real_getegid32) {
1931 | real_getegid32 = (gid_t (*)(void)) dlsym(RTLD_NEXT, "getegid32");
1932 | }
1933 | if (root) {
1934 | return(real_getegid32());
1935 | }
1936 | if (debug) fprintf(stderr, "GETEGID32: 0\n");
1937 | return 0;
1938 | }
1939 | #endif
1940 |
1941 | #if 0
1942 | /* maybe we need to interpose on strcmp someday... here is the template */
1943 | int strcmp(const char *s1, const char *s2) {
1944 | static int (*real_strcmp)(const char *, const char *) = NULL;
1945 | CHECKIT
1946 | if (! real_strcmp) {
1947 | real_strcmp = (int (*)(const char *, const char *)) dlsym(RTLD_NEXT, "strcmp");
1948 | }
1949 | if (debug) fprintf(stderr, "STRCMP: '%s' '%s'\n", s1, s2);
1950 | return(real_strcmp(s1, s2));
1951 | }
1952 | #endif
1953 |
1954 | #code_end
1955 | }
1956 |
--------------------------------------------------------------------------------
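The interposed `open` above only rewrites `/dev/tty*` paths when the X server runs without root: `tmpdir_path` mangles the device path into a per-user file under `XDUMMY_TMPDIR` and creates it as a FIFO so later `ioctl` calls on the descriptor do not fail. A minimal Python sketch of that path mangling (the helper name is illustrative, not part of the repo):

```
import os

def fake_tty_path(pathname, default_tmpdir='/tmp'):
    ## mirror of the C `tmpdir_path` helper: replace '/' with '_'
    ## and prepend the XDUMMY_TMPDIR directory
    tmpdir = os.environ.get('XDUMMY_TMPDIR', default_tmpdir)
    return os.path.join(tmpdir, pathname.replace('/', '_'))

print(fake_tty_path('/dev/tty7'))   ## -> /tmp/_dev_tty7 (created as a FIFO)
```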
/azure/launch_plan.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 |
4 | from doodad.wrappers.easy_launch import sweep_function, save_doodad_config
5 |
6 | codepath = '/home/code'
7 | script = 'scripts/plan.py'
8 |
9 | def remote_fn(doodad_config, variant):
10 | ## get suffix range to allow running multiple trials per job
11 | n_suffices = variant['n_suffices']
12 | suffix_start = variant['suffix_start']
13 | del variant['n_suffices']
14 | del variant['suffix_start']
15 |
16 | kwarg_string = ' '.join([
17 | f'--{k} {v}' for k, v in variant.items()
18 | ])
19 | print(kwarg_string)
20 |
21 | d4rl_path = os.path.join(doodad_config.output_directory, 'datasets/')
22 | os.system(f'ls -a {codepath}')
23 | os.system(f'mv {codepath}/git {codepath}/.git')
24 |
25 | for suffix in range(suffix_start, suffix_start + n_suffices):
26 | os.system(
27 | f'''export PYTHONPATH=$PYTHONPATH:{codepath} && '''
28 | f'''export CUDA_VISIBLE_DEVICES=0 && '''
29 | f'''export D4RL_DATASET_DIR={d4rl_path} && '''
30 | f'''python {os.path.join(codepath, script)} '''
31 | f'''--suffix {suffix} '''
32 | f'''{kwarg_string}'''
33 |
34 | )
35 |
36 | save_doodad_config(doodad_config)
37 |
38 | if __name__ == "__main__":
39 |
40 | environments = ['ant']
41 | buffers = ['medium-expert-v2', 'medium-v2', 'medium-replay-v2', 'random-v2']
42 | datasets = [f'{env}-{buf}' for env in environments for buf in buffers]
43 |
44 | azure_logpath = 'defaults/'
45 |
46 | params_to_sweep = {
47 | 'dataset': datasets,
48 | 'horizon': [15],
49 | }
50 |
51 | default_params = {
52 | 'logbase': os.path.join('/doodad_tmp', azure_logpath, 'logs'),
53 | 'prefix': 'plans/azure/',
54 | 'verbose': False,
55 | 'suffix_start': 0,
56 | 'n_suffices': 3,
57 | }
58 |
59 | print(params_to_sweep)
60 | print(default_params)
61 |
62 | sweep_function(
63 | remote_fn,
64 | params_to_sweep,
65 | default_params=default_params,
66 | config_path=os.path.abspath('azure/config.py'),
67 | log_path=azure_logpath,
68 | azure_region='westus2',
69 | # gpu_model='nvidia-tesla-v100',
70 | gpu_model='nvidia-tesla-t4',
71 | filter_dir=['logs', 'bin', 'mount'],
72 | use_gpu=True,
73 | )
74 |
--------------------------------------------------------------------------------
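For reference, `remote_fn` above simply flattens each sweep variant into a CLI string and launches `scripts/plan.py` once per suffix. A standalone sketch of what ends up being executed (the variant values below are illustrative):

```
variant = {
    'dataset': 'ant-medium-v2',
    'horizon': 15,
    'prefix': 'plans/azure/',
    'suffix_start': 0,
    'n_suffices': 3,
}
n_suffices = variant.pop('n_suffices')
suffix_start = variant.pop('suffix_start')

kwarg_string = ' '.join(f'--{k} {v}' for k, v in variant.items())
for suffix in range(suffix_start, suffix_start + n_suffices):
    print(f'python scripts/plan.py --suffix {suffix} {kwarg_string}')
## -> three plan.py invocations differing only in --suffix 0 / 1 / 2
```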
/azure/launch_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 |
4 | from doodad.wrappers.easy_launch import sweep_function, save_doodad_config
5 |
6 | codepath = '/home/code'
7 | script = 'scripts/train.py'
8 |
9 | def remote_fn(doodad_config, variant):
10 | kwarg_string = ' '.join([
11 | f'--{k} {v}' for k, v in variant.items()
12 | ])
13 | print(kwarg_string)
14 |
15 | d4rl_path = os.path.join(doodad_config.output_directory, 'datasets/')
16 | os.system(f'ls -a {codepath}')
17 | os.system(f'mv {codepath}/git {codepath}/.git')
18 | os.system(
19 | f'''export PYTHONPATH=$PYTHONPATH:{codepath} && '''
20 | f'''export CUDA_VISIBLE_DEVICES=0 && '''
21 | f'''export D4RL_DATASET_DIR={d4rl_path} && '''
22 | f'''python {os.path.join(codepath, script)} '''
23 | f'''{kwarg_string}'''
24 |
25 | )
26 | save_doodad_config(doodad_config)
27 |
28 | if __name__ == "__main__":
29 |
30 | environments = ['halfcheetah', 'hopper', 'walker2d', 'ant']
31 | buffers = ['expert-v2']
32 | datasets = [f'{env}-{buf}' for env in environments for buf in buffers]
33 |
34 | azure_logpath = 'defaults/'
35 |
36 | params_to_sweep = {
37 | 'dataset': datasets,
38 | }
39 |
40 | default_params = {
41 | 'logbase': os.path.join('/doodad_tmp', azure_logpath, 'logs'),
42 | 'exp_name': 'gpt/azure',
43 | }
44 |
45 | sweep_function(
46 | remote_fn,
47 | params_to_sweep,
48 | default_params=default_params,
49 | config_path=os.path.abspath('azure/config.py'),
50 | log_path=azure_logpath,
51 | gpu_model='nvidia-tesla-v100',
52 | filter_dir=['logs', 'bin'],
53 | use_gpu=True,
54 | )
55 |
--------------------------------------------------------------------------------
/azure/make_fuse_config.sh:
--------------------------------------------------------------------------------
1 | ## AZURE_STORAGE_CONNECTION_STRING has a substring formatted like:
2 | ## AccountName=${STORAGE_ACCOUNT};AccountKey=${STORAGE_KEY};EndpointSuffix= ...
3 | export AZURE_STORAGE_ACCOUNT=`(echo $AZURE_STORAGE_CONNECTION_STRING | grep -o -P '(?<=AccountName=).*(?=;AccountKey)')`
4 | export AZURE_STORAGE_KEY=`(echo $AZURE_STORAGE_CONNECTION_STRING | grep -o -P '(?<=AccountKey=).*(?=;EndpointSuffix)')`
5 |
6 | echo "accountName" ${AZURE_STORAGE_ACCOUNT} > ./azure/fuse.cfg
7 | echo "accountKey" ${AZURE_STORAGE_KEY} >> ./azure/fuse.cfg
8 | echo "containerName" ${AZURE_STORAGE_CONTAINER} >> ./azure/fuse.cfg
--------------------------------------------------------------------------------
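The two `grep -o -P` lookarounds above pull the account name and key out of `AZURE_STORAGE_CONNECTION_STRING`. A rough Python equivalent, shown only to document the expected format (the sample connection string below is made up):

```
def parse_connection_string(conn):
    ## split 'Key=Value' pairs on ';', keeping any '=' characters inside values
    fields = dict(part.split('=', 1) for part in conn.split(';') if '=' in part)
    return fields['AccountName'], fields['AccountKey']

account, key = parse_connection_string(
    'AccountName=mystorageaccount;AccountKey=abc123==;EndpointSuffix=core.windows.net'
)
print(account, key)   ## mystorageaccount abc123==
```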
/azure/mount.sh:
--------------------------------------------------------------------------------
1 | mkdir ~/azure_mount
2 | blobfuse ~/azure_mount --tmp-path=/tmp --config-file=./azure/fuse.cfg
3 |
--------------------------------------------------------------------------------
/azure/sync.sh:
--------------------------------------------------------------------------------
1 | if keyctl show 2>&1 | grep -q "workaroundSession";
2 | then
3 | echo "Already logged in"
4 | else
5 | echo "Logging in with tenant id:" ${AZURE_TENANT_ID}
6 | keyctl session workaroundSession
7 | ./bin/azcopy login --tenant-id=$AZURE_TENANT_ID
8 | fi
9 |
10 | export LOGBASE=defaults
11 |
12 | ## AZURE_STORAGE_CONNECTION_STRING has a substring formatted like:
13 | ## AccountName=${STORAGE_ACCOUNT};AccountKey= ...
14 | export AZURE_STORAGE_ACCOUNT=`(echo $AZURE_STORAGE_CONNECTION_STRING | grep -o -P '(?<=AccountName=).*(?=;AccountKey)')`
15 |
16 | echo "Syncing from" ${AZURE_STORAGE_ACCOUNT}"/"${AZURE_STORAGE_CONTAINER}"/"${LOGBASE}
17 |
18 | ./bin/azcopy sync https://${AZURE_STORAGE_ACCOUNT}.blob.core.windows.net/${AZURE_STORAGE_CONTAINER}/${LOGBASE}/logs logs/ --recursive
--------------------------------------------------------------------------------
/config/offline.py:
--------------------------------------------------------------------------------
1 | from trajectory.utils import watch
2 |
3 | #------------------------ base ------------------------#
4 |
5 | logbase = 'logs/'
6 | gpt_expname = 'gpt/azure'
7 |
8 | ## automatically make experiment names for planning
9 | ## by labelling folders with these args
10 | args_to_watch = [
11 | ('prefix', ''),
12 | ('plan_freq', 'freq'),
13 | ('horizon', 'H'),
14 | ('beam_width', 'beam'),
15 | ]
16 |
17 | base = {
18 |
19 | 'train': {
20 | 'N': 100,
21 | 'discount': 0.99,
22 | 'n_layer': 4,
23 | 'n_head': 4,
24 |
25 | ## number of epochs for a 1M-size dataset; n_epochs = 1M / dataset_size * n_epochs_ref
26 | 'n_epochs_ref': 50,
27 | 'n_saves': 3,
28 | 'logbase': logbase,
29 | 'device': 'cuda',
30 |
31 | 'n_embd': 32,
32 | 'batch_size': 256,
33 | 'learning_rate': 6e-4,
34 | 'lr_decay': True,
35 | 'seed': 42,
36 |
37 | 'embd_pdrop': 0.1,
38 | 'resid_pdrop': 0.1,
39 | 'attn_pdrop': 0.1,
40 |
41 | 'step': 1,
42 | 'subsampled_sequence_length': 10,
43 | 'termination_penalty': -100,
44 | 'exp_name': gpt_expname,
45 |
46 | 'discretizer': 'QuantileDiscretizer',
47 | 'action_weight': 5,
48 | 'reward_weight': 1,
49 | 'value_weight': 1,
50 | },
51 |
52 | 'plan': {
53 | 'logbase': logbase,
54 | 'gpt_loadpath': gpt_expname,
55 | 'gpt_epoch': 'latest',
56 | 'device': 'cuda',
57 | 'renderer': 'Renderer',
58 |
59 | 'plan_freq': 1,
60 | 'horizon': 15,
61 | 'beam_width': 128,
62 | 'n_expand': 2,
63 |
64 | 'k_obs': 1,
65 | 'k_act': None,
66 | 'cdf_obs': None,
67 | 'cdf_act': 0.6,
68 | 'percentile': 'mean',
69 |
70 | 'max_context_transitions': 5,
71 | 'prefix_context': True,
72 |
73 | 'vis_freq': 50,
74 | 'exp_name': watch(args_to_watch),
75 | 'prefix': 'plans/defaults/',
76 | 'suffix': '0',
77 | 'verbose': True,
78 | },
79 |
80 | }
81 |
82 | #------------------------ locomotion ------------------------#
83 |
84 | ## for all halfcheetah environments, you can reduce the planning horizon and beam width without
85 | ## affecting performance. good for speed and sanity.
86 |
87 | halfcheetah_medium_v2 = halfcheetah_medium_replay_v2 = {
88 | 'plan': {
89 | 'horizon': 5,
90 | 'beam_width': 32,
91 | }
92 | }
93 |
94 | halfcheetah_medium_expert_v2 = {
95 | 'plan': {
96 | 'beam_width': 32,
97 | },
98 | }
99 |
100 | ## if you leave the dictionary empty, it will use the base parameters
101 | hopper_medium_expert_v2 = hopper_medium_v2 = walker2d_medium_v2 = {}
102 |
103 | ## hopper and walker2d are a little more sensitive to planning hyperparameters;
104 | ## proceed with caution when reducing the horizon or increasing the planning frequency
105 |
106 | hopper_medium_replay_v2 = {
107 | 'train': {
108 | ## train on the medium-replay datasets longer
109 | 'n_epochs_ref': 80,
110 | },
111 | }
112 |
113 | walker2d_medium_expert_v2 = {
114 | 'plan': {
115 | ## also safe to reduce the horizon here
116 | 'horizon': 5,
117 | },
118 | }
119 |
120 | walker2d_medium_replay_v2 = {
121 | 'train': {
122 | ## train on the medium-replay datasets longer
123 | 'n_epochs_ref': 80,
124 | },
125 | 'plan': {
126 | ## can reduce beam width, but need to adjust action sampling
127 | ## distribution and increase horizon to accommodate
128 | 'horizon': 20,
129 | 'beam_width': 32,
130 | 'k_act': 40,
131 | 'cdf_act': None,
132 | }
133 | }
134 |
135 | ant_medium_v2 = ant_medium_replay_v2 = ant_random_v2 = {
136 | 'train': {
137 | ## reduce batch size because the dimensionality is larger
138 | 'batch_size': 128,
139 | },
140 | 'plan': {
141 | 'horizon': 5,
142 | }
143 | }
144 |
--------------------------------------------------------------------------------
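The dataset-specific dictionaries above only override the keys they list; everything else falls back to `base`, and the `n_epochs_ref` comment translates into the epoch count computed in `scripts/train.py`. A schematic sketch (the dict merge is simplified relative to the repo's `utils.Parser`, and the dataset size is hypothetical):

```
base_train = {'n_epochs_ref': 50, 'batch_size': 256, 'n_layer': 4}
override = {'n_epochs_ref': 80}            ## e.g. walker2d_medium_replay_v2['train']
train = {**base_train, **override}         ## -> n_epochs_ref=80, batch_size=256, n_layer=4

## scripts/train.py: n_epochs = 1M / dataset_size * n_epochs_ref
dataset_size = 300_000                     ## hypothetical replay-buffer size
n_epochs = int(1e6 / dataset_size * train['n_epochs_ref'])
print(n_epochs)                            ## 266 epochs
```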
/environment.yml:
--------------------------------------------------------------------------------
1 | name: trajectory
2 | channels:
3 | - defaults
4 | - conda-forge
5 | dependencies:
6 | - python=3.8
7 | - pip
8 | - patchelf
9 | - pip:
10 | - -f https://download.pytorch.org/whl/torch_stable.html
11 | - numpy
12 | - gym==0.18.0
13 | - mujoco-py==2.0.2.13
14 | - matplotlib==3.3.4
15 | - torch==1.9.1+cu111
16 | - typed-argument-parser
17 | - git+https://github.com/Farama-Foundation/d4rl@f2a05c0d66722499bf8031b094d9af3aea7c372b#egg=d4rl
18 | - scikit-image==0.17.2
19 | - scikit-video==1.1.11
20 | - gitpython
21 |
--------------------------------------------------------------------------------
/plotting/bar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jannerm/trajectory-transformer/8834a6ed04ceeab8fdb9465e145c6e041c05d71b/plotting/bar.png
--------------------------------------------------------------------------------
/plotting/plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | import matplotlib.pyplot as plt
4 | import pdb
5 |
6 | from plotting.scores import means
7 |
8 | class Colors:
9 | grey = '#B4B4B4'
10 | gold = '#F6C781'
11 | red = '#EC7C7D'
12 | blue = '#70ABCC'
13 |
14 | LABELS = {
15 | # 'BC': 'Behavior\nCloning',
16 | # 'MBOP': 'Model-Based\nOffline Planning',
17 | # 'BRAC': 'Behavior-Reg.\nActor-Critic',
18 | # 'CQL': 'Conservative\nQ-Learning',
19 | }
20 |
21 | def get_mean(results, exclude=None):
22 | '''
23 | results : { environment: score, ... }
24 | '''
25 | filtered = {
26 | k: v for k, v in results.items()
27 | if (not exclude) or (exclude and exclude not in k)
28 | }
29 | return np.mean(list(filtered.values()))
30 |
31 | if __name__ == '__main__':
32 |
33 | #################
34 | ## latex
35 | #################
36 | matplotlib.rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
37 | matplotlib.rc('text', usetex=True)
38 | matplotlib.rcParams['text.latex.preamble']=[r"\usepackage{amsmath}"]
39 | #################
40 |
41 | fig = plt.gcf()
42 | ax = plt.gca()
43 | fig.set_size_inches(7.5, 2.5)
44 |
45 | means = {k: get_mean(v, exclude='ant') for k, v in means.items()}
46 | print(means)
47 |
48 | algs = ['BC', 'MBOP', 'BRAC', 'CQL', 'Decision\nTransformer', 'Trajectory\nTransformer']
49 | vals = [means[alg] for alg in algs]
50 |
51 | colors = [
52 | Colors.grey, Colors.gold,
53 | Colors.red, Colors.red, Colors.blue, Colors.blue
54 | ]
55 |
56 | labels = [LABELS.get(alg, alg) for alg in algs]
57 | plt.bar(labels, vals, color=colors, edgecolor=Colors.gold, lw=0)
58 | plt.ylabel('Average normalized return', labelpad=15)
59 | # plt.title('Offline RL Results')
60 |
61 | legend_labels = ['Behavior Cloning', 'Trajectory Optimization', 'Temporal Difference', 'Sequence Modeling']
62 | colors = [Colors.grey, Colors.gold, Colors.red, Colors.blue]
63 | handles = [plt.Rectangle((0,0),1,1, color=color) for label, color in zip(legend_labels, colors)]
64 | plt.legend(handles, legend_labels, ncol=4,
65 | bbox_to_anchor=(1.07, -.18), fancybox=False, framealpha=0, shadow=False, columnspacing=1.5, handlelength=1.5)
66 |
67 | matplotlib.rcParams['hatch.linewidth'] = 7.5
68 | # ax.patches[-1].set_hatch('/')
69 |
70 | ax.spines['right'].set_visible(False)
71 | ax.spines['top'].set_visible(False)
72 |
73 | # plt.savefig('plotting/bar.pdf', bbox_inches='tight')
74 | plt.savefig('plotting/bar.png', bbox_inches='tight', dpi=500)
75 |
--------------------------------------------------------------------------------
/plotting/read_results.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import numpy as np
4 | import json
5 | import pdb
6 |
7 | import trajectory.utils as utils
8 |
9 | DATASETS = [
10 | f'{env}-{buffer}'
11 | for env in ['hopper', 'walker2d', 'halfcheetah', 'ant']
12 | for buffer in ['medium-expert-v2', 'medium-v2', 'medium-replay-v2']
13 | ]
14 |
15 | LOGBASE = 'logs'
16 | TRIAL = '*'
17 | EXP_NAME = 'plans/pretrained'
18 |
19 | def load_results(paths):
20 | '''
21 | paths : list of paths to directories containing experiment trials
22 | '''
23 | scores = []
24 | for i, path in enumerate(sorted(paths)):
25 | score = load_result(path)
26 | if score is None:
27 | print(f'Skipping {path}')
28 | continue
29 | scores.append(score)
30 |
31 | suffix = path.split('/')[-1]
32 |
33 | mean = np.mean(scores)
34 | err = np.std(scores) / np.sqrt(len(scores))
35 | return mean, err, scores
36 |
37 | def load_result(path):
38 | '''
39 | path : path to experiment directory; expects `rollout.json` to be in directory
40 | '''
41 | fullpath = os.path.join(path, 'rollout.json')
42 | suffix = path.split('/')[-1]
43 |
44 | if not os.path.exists(fullpath):
45 | return None
46 |
47 | results = json.load(open(fullpath, 'rb'))
48 | score = results['score']
49 | return score * 100
50 |
51 | #######################
52 | ######## setup ########
53 | #######################
54 |
55 | if __name__ == '__main__':
56 |
57 | class Parser(utils.Parser):
58 | dataset: str = None
59 |
60 | args = Parser().parse_args()
61 |
62 | for dataset in ([args.dataset] if args.dataset else DATASETS):
63 | subdirs = glob.glob(os.path.join(LOGBASE, dataset, EXP_NAME))
64 |
65 | for subdir in subdirs:
66 | reldir = subdir.split('/')[-1]
67 | paths = glob.glob(os.path.join(subdir, TRIAL))
68 |
69 | mean, err, scores = load_results(paths)
70 | print(f'{dataset.ljust(30)} | {subdir.ljust(50)} | {len(scores)} scores \n {mean:.2f} +/- {err:.2f}\n')
71 |
--------------------------------------------------------------------------------
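`load_results` expects one `rollout.json` per trial under `logs/<dataset>/plans/pretrained/<trial>/` and reports the mean and standard error of the normalized scores (rescaled to the usual 0-100 D4RL range). A small sketch of that aggregation with made-up trial scores:

```
import numpy as np

scores = [95.2, 93.8, 96.1]                   ## hypothetical per-trial scores (already x100)
mean = np.mean(scores)
err = np.std(scores) / np.sqrt(len(scores))   ## standard error over trials
print(f'{mean:.2f} +/- {err:.2f}')            ## 95.03 +/- 0.55
```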
/plotting/scores.py:
--------------------------------------------------------------------------------
1 | means = {
2 | 'Trajectory\nTransformer': {
3 | ##
4 | 'halfcheetah-medium-expert-v2': 95.0,
5 | 'hopper-medium-expert-v2': 110.0,
6 | 'walker2d-medium-expert-v2': 101.9,
7 | 'ant-medium-expert-v2': 116.1,
8 | ##
9 | 'halfcheetah-medium-v2': 46.9,
10 | 'hopper-medium-v2': 61.1,
11 | 'walker2d-medium-v2': 79.0,
12 | 'ant-medium-v2': 83.1,
13 | ##
14 | 'halfcheetah-medium-replay-v2': 41.9,
15 | 'hopper-medium-replay-v2': 91.5,
16 | 'walker2d-medium-replay-v2': 82.6,
17 | 'ant-medium-replay-v2': 77.0,
18 | },
19 | 'Decision\nTransformer': {
20 | ##
21 | 'halfcheetah-medium-expert-v2': 86.8,
22 | 'hopper-medium-expert-v2': 107.6,
23 | 'walker2d-medium-expert-v2': 108.1,
24 | ##
25 | 'halfcheetah-medium-v2': 42.6,
26 | 'hopper-medium-v2': 67.6,
27 | 'walker2d-medium-v2': 74.0,
28 | ##
29 | 'halfcheetah-medium-replay-v2': 36.6,
30 | 'hopper-medium-replay-v2': 82.7,
31 | 'walker2d-medium-replay-v2': 66.6,
32 | },
33 | 'CQL': {
34 | ##
35 | 'halfcheetah-medium-expert-v2': 91.6,
36 | 'hopper-medium-expert-v2': 105.4,
37 | 'walker2d-medium-expert-v2': 108.8,
38 | ##
39 | 'halfcheetah-medium-v2': 44.0,
40 | 'hopper-medium-v2': 58.5,
41 | 'walker2d-medium-v2': 72.5,
42 | ##
43 | 'halfcheetah-medium-replay-v2': 45.5,
44 | 'hopper-medium-replay-v2': 95.0,
45 | 'walker2d-medium-replay-v2': 77.2,
46 | },
47 | 'MOPO': {
48 | ##
49 | 'halfcheetah-medium-expert-v2': 63.3,
50 | 'hopper-medium-expert-v2': 23.7,
51 | 'walker2d-medium-expert-v2': 44.6,
52 | ##
53 | 'halfcheetah-medium-v2': 42.3,
54 | 'hopper-medium-v2': 28.0,
55 | 'walker2d-medium-v2': 17.8,
56 | ##
57 | 'halfcheetah-medium-replay-v2': 53.1,
58 | 'hopper-medium-replay-v2': 67.5,
59 | 'walker2d-medium-replay-v2': 39.0,
60 | },
61 | 'MBOP': {
62 | ##
63 | 'halfcheetah-medium-expert-v2': 105.9,
64 | 'hopper-medium-expert-v2': 55.1,
65 | 'walker2d-medium-expert-v2': 70.2,
66 | ##
67 | 'halfcheetah-medium-v2': 44.6,
68 | 'hopper-medium-v2': 48.8,
69 | 'walker2d-medium-v2': 41.0,
70 | ##
71 | 'halfcheetah-medium-replay-v2': 42.3,
72 | 'hopper-medium-replay-v2': 12.4,
73 | 'walker2d-medium-replay-v2': 9.7,
74 | },
75 | 'BRAC': {
76 | ##
77 | 'halfcheetah-medium-expert-v2': 41.9,
78 | 'hopper-medium-expert-v2': 0.9,
79 | 'walker2d-medium-expert-v2': 81.6,
80 | ##
81 | 'halfcheetah-medium-v2': 46.3,
82 | 'hopper-medium-v2': 31.3,
83 | 'walker2d-medium-v2': 81.1,
84 | ##
85 | 'halfcheetah-medium-replay-v2': 47.7,
86 | 'hopper-medium-replay-v2': 0.6,
87 | 'walker2d-medium-replay-v2': 0.9,
88 | },
89 | 'BC': {
90 | ##
91 | 'halfcheetah-medium-expert-v2': 59.9,
92 | 'hopper-medium-expert-v2': 79.6,
93 | 'walker2d-medium-expert-v2': 36.6,
94 | ##
95 | 'halfcheetah-medium-v2': 43.1,
96 | 'hopper-medium-v2': 63.9,
97 | 'walker2d-medium-v2': 77.3,
98 | ##
99 | 'halfcheetah-medium-replay-v2': 4.3,
100 | 'hopper-medium-replay-v2': 27.6,
101 | 'walker2d-medium-replay-v2': 36.9,
102 | },
103 | }
104 |
105 | errors = {
106 | 'Trajectory\nTransformer': {
107 | ##
108 | 'halfcheetah-medium-expert-v2': 0.2,
109 | 'hopper-medium-expert-v2': 2.7,
110 | 'walker2d-medium-expert-v2': 6.8,
111 | 'ant-medium-expert-v2': 9.0,
112 | ##
113 | 'halfcheetah-medium-v2': 0.4,
114 | 'hopper-medium-v2': 3.6,
115 | 'walker2d-medium-v2': 2.8,
116 | 'ant-medium-v2': 7.3,
117 | ##
118 | 'halfcheetah-medium-replay-v2': 2.5,
119 | 'hopper-medium-replay-v2': 3.6,
120 | 'walker2d-medium-replay-v2': 6.9,
121 | 'ant-medium-replay-v2': 6.8,
122 | },
123 | }
--------------------------------------------------------------------------------
/plotting/table.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pdb
3 |
4 | from plotting.plot import get_mean
5 | from plotting.scores import (
6 | means as MEANS,
7 | errors as ERRORS,
8 | )
9 |
10 | ALGORITHM_STRINGS = {
11 | 'Trajectory\nTransformer': 'TT (Ours)',
12 | 'Decision\nTransformer': 'DT',
13 | }
14 |
15 | BUFFER_STRINGS = {
16 | 'medium-expert': 'Medium-Expert',
17 | 'medium': 'Medium',
18 | 'medium-replay': 'Medium-Replay',
19 | }
20 |
21 | ENVIRONMENT_STRINGS = {
22 | 'halfcheetah': 'HalfCheetah',
23 | 'hopper': 'Hopper',
24 | 'walker2d': 'Walker2d',
25 | 'ant': 'Ant',
26 | }
27 |
28 | SHOW_ERRORS = ['Trajectory\nTransformer']
29 |
30 | def get_result(algorithm, buffer, environment, version='v2'):
31 | key = f'{environment}-{buffer}-{version}'
32 | mean = MEANS[algorithm].get(key, '-')
33 | if algorithm in SHOW_ERRORS:
34 | error = ERRORS[algorithm].get(key)
35 | return (mean, error)
36 | else:
37 | return mean
38 |
39 | def format_result(result):
40 | if type(result) == tuple:
41 | mean, std = result
42 | return f'${mean}$ \\scriptsize{{\\raisebox{{1pt}}{{$\\pm {std}$}}}}'
43 | else:
44 | return f'${result}$'
45 |
46 | def format_row(buffer, environment, results):
47 | buffer_str = BUFFER_STRINGS[buffer]
48 | environment_str = ENVIRONMENT_STRINGS[environment]
49 | results_str = ' & '.join(format_result(result) for result in results)
50 | row = f'{buffer_str} & {environment_str} & {results_str} \\\\ \n'
51 | return row
52 |
53 | def format_buffer_block(algorithms, buffer, environments):
54 | block_str = '\\midrule\n'
55 | for environment in environments:
56 | results = [get_result(alg, buffer, environment) for alg in algorithms]
57 | row_str = format_row(buffer, environment, results)
58 | block_str += row_str
59 | return block_str
60 |
61 | def format_algorithm(algorithm):
62 | algorithm_str = ALGORITHM_STRINGS.get(algorithm, algorithm)
63 | return f'\multicolumn{{1}}{{c}}{{\\bf {algorithm_str}}}'
64 |
65 | def format_algorithms(algorithms):
66 | return ' & '.join(format_algorithm(algorithm) for algorithm in algorithms)
67 |
68 | def format_averages(means, label):
69 | prefix = f'\\multicolumn{{2}}{{c}}{{\\bf Average ({label})}} & '
70 | formatted = ' & '.join(str(mean) for mean in means)
71 | return prefix + formatted
72 |
73 | def format_averages_block(algorithms):
74 | means_filtered = [np.round(get_mean(MEANS[algorithm], exclude='ant'), 1) for algorithm in algorithms]
75 | means_all = [np.round(get_mean(MEANS[algorithm], exclude=None), 1) for algorithm in algorithms]
76 |
77 | means_all = [
78 | means
79 | if 'ant-medium-expert-v2' in MEANS[algorithm]
80 | else '$-$'
81 | for algorithm, means in zip(algorithms, means_all)
82 | ]
83 |
84 | formatted_filtered = format_averages(means_filtered, 'without Ant')
85 | formatted_all = format_averages(means_all, 'all settings')
86 |
87 | formatted_block = (
88 | f'{formatted_filtered} \\hspace{{.6cm}} \\\\ \n'
89 | f'{formatted_all} \\hspace{{.6cm}} \\\\ \n'
90 | )
91 | return formatted_block
92 |
93 | def format_table(algorithms, buffers, environments):
94 | justify_str = 'll' + 'r' * len(algorithms)
95 | algorithm_str = format_algorithms(['Dataset', 'Environment'] + algorithms)
96 | averages_str = format_averages_block(algorithms)
97 | table_prefix = (
98 | '\\begin{table*}[h]\n'
99 | '\\centering\n'
100 | '\\small\n'
101 | f'\\begin{{tabular}}{{{justify_str}}}\n'
102 | '\\toprule\n'
103 | f'{algorithm_str} \\\\ \n'
104 | )
105 | table_suffix = (
106 | '\\midrule\n'
107 | f'{averages_str}'
108 | '\\bottomrule\n'
109 | '\\end{tabular}\n'
110 | '\\label{table:d4rl}\n'
111 | '\\end{table*}'
112 | )
113 | blocks = ''.join(format_buffer_block(algorithms, buffer, environments) for buffer in buffers)
114 | table = (
115 | f'{table_prefix}'
116 | f'{blocks}'
117 | f'{table_suffix}'
118 | )
119 | return table
120 |
121 |
122 | algorithms = ['BC', 'MBOP', 'BRAC', 'CQL', 'Decision\nTransformer', 'Trajectory\nTransformer']
123 | buffers = ['medium-expert', 'medium', 'medium-replay']
124 | environments = ['halfcheetah', 'hopper', 'walker2d', 'ant']
125 |
126 | table = format_table(algorithms, buffers, environments)
127 | print(table)
128 |
--------------------------------------------------------------------------------
/pretrained.sh:
--------------------------------------------------------------------------------
1 | export DOWNLOAD_PATH=logs
2 |
3 | [ ! -d ${DOWNLOAD_PATH} ] && mkdir ${DOWNLOAD_PATH}
4 |
5 | ## downloads pretrained models for 16 datasets:
6 | ## {halfcheetah, hopper, walker2d, ant}
7 | ## x
8 | ## {expert-v2, medium-expert-v2, medium-v2, medium-replay-v2}
9 |
10 | wget https://www.dropbox.com/sh/r09lkdoj66kx43w/AACbXjMhcI6YNsn1qU4LParja?dl=1 -O dropbox_models.zip
11 | unzip dropbox_models.zip -d ${DOWNLOAD_PATH}
12 | rm dropbox_models.zip
13 |
14 | ## downloads 15 plans from each pretrained model
15 | wget https://www.dropbox.com/s/5sn79ep79yo22kv/pretrained-plans.tar?dl=1 -O dropbox_plans.tar
16 | tar -xvf dropbox_plans.tar
17 | cp -r pretrained-plans/* ${DOWNLOAD_PATH}
18 | rm -r pretrained-plans
19 | rm dropbox_plans.tar
20 |
--------------------------------------------------------------------------------
/scripts/plan.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pdb
3 | from os.path import join
4 |
5 | import trajectory.utils as utils
6 | import trajectory.datasets as datasets
7 | from trajectory.search import (
8 | beam_plan,
9 | make_prefix,
10 | extract_actions,
11 | update_context,
12 | )
13 |
14 | class Parser(utils.Parser):
15 | dataset: str = 'halfcheetah-medium-expert-v2'
16 | config: str = 'config.offline'
17 |
18 | #######################
19 | ######## setup ########
20 | #######################
21 |
22 | args = Parser().parse_args('plan')
23 |
24 | #######################
25 | ####### models ########
26 | #######################
27 |
28 | dataset = utils.load_from_config(args.logbase, args.dataset, args.gpt_loadpath,
29 | 'data_config.pkl')
30 |
31 | gpt, gpt_epoch = utils.load_model(args.logbase, args.dataset, args.gpt_loadpath,
32 | epoch=args.gpt_epoch, device=args.device)
33 |
34 | #######################
35 | ####### dataset #######
36 | #######################
37 |
38 | env = datasets.load_environment(args.dataset)
39 | renderer = utils.make_renderer(args)
40 | timer = utils.timer.Timer()
41 |
42 | discretizer = dataset.discretizer
43 | discount = dataset.discount
44 | observation_dim = dataset.observation_dim
45 | action_dim = dataset.action_dim
46 |
47 | value_fn = lambda x: discretizer.value_fn(x, args.percentile)
48 | preprocess_fn = datasets.get_preprocess_fn(env.name)
49 |
50 | #######################
51 | ###### main loop ######
52 | #######################
53 |
54 | observation = env.reset()
55 | total_reward = 0
56 |
57 | ## observations for rendering
58 | rollout = [observation.copy()]
59 |
60 | ## previous (tokenized) transitions for conditioning transformer
61 | context = []
62 |
63 | T = env.max_episode_steps
64 | for t in range(T):
65 |
66 | observation = preprocess_fn(observation)
67 |
68 | if t % args.plan_freq == 0:
69 | ## concatenate previous transitions and current observations to input to model
70 | prefix = make_prefix(discretizer, context, observation, args.prefix_context)
71 |
72 | ## sample sequence from model beginning with `prefix`
73 | sequence = beam_plan(
74 | gpt, value_fn, prefix,
75 | args.horizon, args.beam_width, args.n_expand, observation_dim, action_dim,
76 | discount, args.max_context_transitions, verbose=args.verbose,
77 | k_obs=args.k_obs, k_act=args.k_act, cdf_obs=args.cdf_obs, cdf_act=args.cdf_act,
78 | )
79 |
80 | else:
81 | sequence = sequence[1:]
82 |
83 | ## [ horizon x transition_dim ] convert sampled tokens to continuous trajectory
84 | sequence_recon = discretizer.reconstruct(sequence)
85 |
86 | ## [ action_dim ] index into sampled trajectory to grab first action
87 | action = extract_actions(sequence_recon, observation_dim, action_dim, t=0)
88 |
89 | ## execute action in environment
90 | next_observation, reward, terminal, _ = env.step(action)
91 |
92 | ## update return
93 | total_reward += reward
94 | score = env.get_normalized_score(total_reward)
95 |
96 | ## update rollout observations and context transitions
97 | rollout.append(next_observation.copy())
98 | context = update_context(context, discretizer, observation, action, reward, args.max_context_transitions)
99 |
100 | print(
101 | f'[ plan ] t: {t} / {T} | r: {reward:.2f} | R: {total_reward:.2f} | score: {score:.4f} | '
102 | f'time: {timer():.2f} | {args.dataset} | {args.exp_name} | {args.suffix}\n'
103 | )
104 |
105 | ## visualization
106 | if t % args.vis_freq == 0 or terminal or t == T:
107 |
108 | ## save current plan
109 | renderer.render_plan(join(args.savepath, f'{t}_plan.mp4'), sequence_recon, env.state_vector())
110 |
111 | ## save rollout thus far
112 | renderer.render_rollout(join(args.savepath, f'rollout.mp4'), rollout, fps=80)
113 |
114 | if terminal: break
115 |
116 | observation = next_observation
117 |
118 | ## save result as a json file
119 | json_path = join(args.savepath, 'rollout.json')
120 | json_data = {'score': score, 'step': t, 'return': total_reward, 'term': terminal, 'gpt_epoch': gpt_epoch}
121 | json.dump(json_data, open(json_path, 'w'), indent=2, sort_keys=True)
122 |
--------------------------------------------------------------------------------
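The `plan_freq` branch above is the only subtle part of the loop: when the model is not re-queried, the previously planned token sequence is shifted by one transition and its first remaining action is executed. A schematic of that schedule (strings stand in for planned transitions):

```
plan_freq, horizon = 2, 15
sequence = None
for t in range(6):
    if t % plan_freq == 0:
        ## beam_plan(...) in the real script
        sequence = [f'plan@{t}+{i}' for i in range(horizon)]
    else:
        ## reuse the previous plan, dropping the transition already executed
        sequence = sequence[1:]
    action = sequence[0]    ## extract_actions(..., t=0)
    print(t, action)
## 0 plan@0+0, 1 plan@0+1, 2 plan@2+0, 3 plan@2+1, 4 plan@4+0, 5 plan@4+1
```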
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import pdb
5 |
6 | import trajectory.utils as utils
7 | import trajectory.datasets as datasets
8 | from trajectory.models.transformers import GPT
9 |
10 |
11 | class Parser(utils.Parser):
12 | dataset: str = 'halfcheetah-medium-expert-v2'
13 | config: str = 'config.offline'
14 |
15 | #######################
16 | ######## setup ########
17 | #######################
18 |
19 | args = Parser().parse_args('train')
20 |
21 | #######################
22 | ####### dataset #######
23 | #######################
24 |
25 | env = datasets.load_environment(args.dataset)
26 |
27 | sequence_length = args.subsampled_sequence_length * args.step
28 |
29 | dataset_config = utils.Config(
30 | datasets.DiscretizedDataset,
31 | savepath=(args.savepath, 'data_config.pkl'),
32 | env=args.dataset,
33 | N=args.N,
34 | penalty=args.termination_penalty,
35 | sequence_length=sequence_length,
36 | step=args.step,
37 | discount=args.discount,
38 | discretizer=args.discretizer,
39 | )
40 |
41 | dataset = dataset_config()
42 | obs_dim = dataset.observation_dim
43 | act_dim = dataset.action_dim
44 | transition_dim = dataset.joined_dim
45 |
46 | #######################
47 | ######## model ########
48 | #######################
49 |
50 | block_size = args.subsampled_sequence_length * transition_dim - 1
51 | print(
52 | f'Dataset size: {len(dataset)} | '
53 | f'Joined dim: {transition_dim} '
54 | f'(observation: {obs_dim}, action: {act_dim}) | Block size: {block_size}'
55 | )
56 |
57 | model_config = utils.Config(
58 | GPT,
59 | savepath=(args.savepath, 'model_config.pkl'),
60 | ## discretization
61 | vocab_size=args.N, block_size=block_size,
62 | ## architecture
63 | n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd*args.n_head,
64 | ## dimensions
65 | observation_dim=obs_dim, action_dim=act_dim, transition_dim=transition_dim,
66 | ## loss weighting
67 | action_weight=args.action_weight, reward_weight=args.reward_weight, value_weight=args.value_weight,
68 | ## dropout probabilities
69 | embd_pdrop=args.embd_pdrop, resid_pdrop=args.resid_pdrop, attn_pdrop=args.attn_pdrop,
70 | )
71 |
72 | model = model_config()
73 | model.to(args.device)
74 |
75 | #######################
76 | ####### trainer #######
77 | #######################
78 |
79 | warmup_tokens = len(dataset) * block_size ## number of tokens seen per epoch
80 | final_tokens = 20 * warmup_tokens
81 |
82 | trainer_config = utils.Config(
83 | utils.Trainer,
84 | savepath=(args.savepath, 'trainer_config.pkl'),
85 | # optimization parameters
86 | batch_size=args.batch_size,
87 | learning_rate=args.learning_rate,
88 | betas=(0.9, 0.95),
89 | grad_norm_clip=1.0,
90 | weight_decay=0.1, # only applied on matmul weights
91 | # learning rate decay: linear warmup followed by cosine decay to 10% of original
92 | lr_decay=args.lr_decay,
93 | warmup_tokens=warmup_tokens,
94 | final_tokens=final_tokens,
95 | ## dataloader
96 | num_workers=0,
97 | device=args.device,
98 | )
99 |
100 | trainer = trainer_config()
101 |
102 | #######################
103 | ###### main loop ######
104 | #######################
105 |
106 | ## scale number of epochs to keep number of updates constant
107 | n_epochs = int(1e6 / len(dataset) * args.n_epochs_ref)
108 | save_freq = int(n_epochs // args.n_saves)
109 |
110 | for epoch in range(n_epochs):
111 | print(f'\nEpoch: {epoch} / {n_epochs} | {args.dataset} | {args.exp_name}')
112 |
113 | trainer.train(model, dataset)
114 |
115 | ## get greatest multiple of `save_freq` less than or equal to `epoch + 1`
116 | save_epoch = (epoch + 1) // save_freq * save_freq
117 | statepath = os.path.join(args.savepath, f'state_{save_epoch}.pt')
118 | print(f'Saving model to {statepath}')
119 |
120 | ## save state to disk
121 | state = model.state_dict()
122 | torch.save(state, statepath)
123 |
--------------------------------------------------------------------------------
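A worked example of the epoch and checkpoint arithmetic above, with a hypothetical dataset size. All epochs within one `save_freq` window write to the same `state_<k>.pt` file, so only `n_saves + 1` distinct checkpoints survive:

```
len_dataset = 2_000_000                            ## hypothetical
n_epochs_ref, n_saves = 50, 3

n_epochs = int(1e6 / len_dataset * n_epochs_ref)   ## 25
save_freq = int(n_epochs // n_saves)               ## 8

for epoch in [0, 6, 7, 23, 24]:
    save_epoch = (epoch + 1) // save_freq * save_freq
    print(epoch, f'state_{save_epoch}.pt')
## epochs 0-6 overwrite state_0.pt, 7-14 state_8.pt,
## 15-22 state_16.pt, 23-24 state_24.pt
```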
/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from setuptools import find_packages
3 |
4 | setup(
5 | name='trajectory',
6 | packages=find_packages(),
7 | )
8 |
--------------------------------------------------------------------------------
/trajectory/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jannerm/trajectory-transformer/8834a6ed04ceeab8fdb9465e145c6e041c05d71b/trajectory/__init__.py
--------------------------------------------------------------------------------
/trajectory/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .d4rl import load_environment
2 | from .sequence import *
3 | from .preprocessing import get_preprocess_fn
4 |
--------------------------------------------------------------------------------
/trajectory/datasets/d4rl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import gym
4 | import pdb
5 |
6 | from contextlib import (
7 | contextmanager,
8 | redirect_stderr,
9 | redirect_stdout,
10 | )
11 |
12 | @contextmanager
13 | def suppress_output():
14 | """
15 | A context manager that redirects stdout and stderr to devnull
16 | https://stackoverflow.com/a/52442331
17 | """
18 | with open(os.devnull, 'w') as fnull:
19 | with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out:
20 | yield (err, out)
21 |
22 | with suppress_output():
23 | ## d4rl prints out a variety of warnings
24 | import d4rl
25 |
26 | # def construct_dataloader(dataset, **kwargs):
27 | # dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, pin_memory=True, **kwargs)
28 | # return dataloader
29 |
30 | def qlearning_dataset_with_timeouts(env, dataset=None, terminate_on_end=False, **kwargs):
31 | if dataset is None:
32 | dataset = env.get_dataset(**kwargs)
33 |
34 | N = dataset['rewards'].shape[0]
35 | obs_ = []
36 | next_obs_ = []
37 | action_ = []
38 | reward_ = []
39 | done_ = []
40 | realdone_ = []
41 |
42 | episode_step = 0
43 | for i in range(N-1):
44 | obs = dataset['observations'][i]
45 | new_obs = dataset['observations'][i+1]
46 | action = dataset['actions'][i]
47 | reward = dataset['rewards'][i]
48 | done_bool = bool(dataset['terminals'][i])
49 | realdone_bool = bool(dataset['terminals'][i])
50 | final_timestep = dataset['timeouts'][i]
51 |
52 | if i < N - 1:
53 | done_bool += dataset['timeouts'][i] #+1]
54 |
55 | if (not terminate_on_end) and final_timestep:
56 | # Skip this transition and don't apply terminals on the last step of an episode
57 | episode_step = 0
58 | continue
59 | if done_bool or final_timestep:
60 | episode_step = 0
61 |
62 | obs_.append(obs)
63 | next_obs_.append(new_obs)
64 | action_.append(action)
65 | reward_.append(reward)
66 | done_.append(done_bool)
67 | realdone_.append(realdone_bool)
68 | episode_step += 1
69 |
70 | return {
71 | 'observations': np.array(obs_),
72 | 'actions': np.array(action_),
73 | 'next_observations': np.array(next_obs_),
74 | 'rewards': np.array(reward_)[:,None],
75 | 'terminals': np.array(done_)[:,None],
76 | 'realterminals': np.array(realdone_)[:,None],
77 | }
78 |
79 | def load_environment(name):
80 | with suppress_output():
81 | wrapped_env = gym.make(name)
82 | env = wrapped_env.unwrapped
83 | env.max_episode_steps = wrapped_env._max_episode_steps
84 | env.name = name
85 | return env
86 |
--------------------------------------------------------------------------------
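`qlearning_dataset_with_timeouts` drops timeout steps, keeps a combined terminal-or-timeout flag in `terminals`, and preserves the true environment terminals in `realterminals` (used later for the termination penalty). A quick shape check, assuming the d4rl datasets are available locally (the dataset name is just an example):

```
from trajectory.datasets.d4rl import load_environment, qlearning_dataset_with_timeouts

env = load_environment('hopper-medium-v2')
data = qlearning_dataset_with_timeouts(env)
for key, value in data.items():
    print(key, value.shape)
## observations / actions / next_observations: [N x dim]
## rewards / terminals / realterminals:        [N x 1]
```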
/trajectory/datasets/preprocessing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def kitchen_preprocess_fn(observations):
4 | ## keep first 30 dimensions of 60-dimension observations
5 | keep = observations[:, :30]
6 | remove = observations[:, 30:]
7 | assert (remove.max(0) == remove.min(0)).all(), 'removing important state information'
8 | return keep
9 |
10 | def ant_preprocess_fn(observations):
11 | qpos_dim = 13 ## root_x and root_y removed
12 | qvel_dim = 14
13 | cfrc_dim = 84
14 | assert observations.shape[1] == qpos_dim + qvel_dim + cfrc_dim
15 | keep = observations[:, :qpos_dim + qvel_dim]
16 | return keep
17 |
18 | def vmap(fn):
19 |
20 | def _fn(inputs):
21 | if inputs.ndim == 1:
22 | inputs = inputs[None]
23 | return_1d = True
24 | else:
25 | return_1d = False
26 |
27 | outputs = fn(inputs)
28 |
29 | if return_1d:
30 | return outputs.squeeze(0)
31 | else:
32 | return outputs
33 |
34 | return _fn
35 |
36 | def preprocess_dataset(preprocess_fn):
37 |
38 | def _fn(dataset):
39 | for key in ['observations', 'next_observations']:
40 | dataset[key] = preprocess_fn(dataset[key])
41 | return dataset
42 |
43 | return _fn
44 |
45 | preprocess_functions = {
46 | 'kitchen-complete-v0': vmap(kitchen_preprocess_fn),
47 | 'kitchen-mixed-v0': vmap(kitchen_preprocess_fn),
48 | 'kitchen-partial-v0': vmap(kitchen_preprocess_fn),
49 | 'ant-expert-v2': vmap(ant_preprocess_fn),
50 | 'ant-medium-expert-v2': vmap(ant_preprocess_fn),
51 | 'ant-medium-replay-v2': vmap(ant_preprocess_fn),
52 | 'ant-medium-v2': vmap(ant_preprocess_fn),
53 | 'ant-random-v2': vmap(ant_preprocess_fn),
54 | }
55 |
56 | dataset_preprocess_functions = {
57 | k: preprocess_dataset(fn) for k, fn in preprocess_functions.items()
58 | }
59 |
60 | def get_preprocess_fn(env):
61 | return preprocess_functions.get(env, lambda x: x)
--------------------------------------------------------------------------------
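`vmap` lets the same preprocessing function accept either a single observation or a batch, and `get_preprocess_fn` falls back to the identity for environments without an entry. A small usage sketch using the ant dimensions from the comments above (13 qpos + 14 qvel + 84 contact-force dimensions):

```
import numpy as np
from trajectory.datasets.preprocessing import get_preprocess_fn

preprocess = get_preprocess_fn('ant-medium-v2')

batch = np.zeros((5, 111))
single = np.zeros(111)
print(preprocess(batch).shape)    ## (5, 27) -- contact forces removed
print(preprocess(single).shape)   ## (27,)   -- 1d input handled by vmap

identity = get_preprocess_fn('some-other-env-v2')   ## name arbitrary: no entry
print(identity(single) is single)                   ## True -- identity fallback
```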
/trajectory/datasets/sequence.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import pdb
5 |
6 | from trajectory.utils import discretization
7 | from trajectory.utils.arrays import to_torch
8 |
9 | from .d4rl import load_environment, qlearning_dataset_with_timeouts
10 | from .preprocessing import dataset_preprocess_functions
11 |
12 | def segment(observations, terminals, max_path_length):
13 | """
14 | segment `observations` into trajectories according to `terminals`
15 | """
16 | assert len(observations) == len(terminals)
17 | observation_dim = observations.shape[1]
18 |
19 | trajectories = [[]]
20 | for obs, term in zip(observations, terminals):
21 | trajectories[-1].append(obs)
22 | if term.squeeze():
23 | trajectories.append([])
24 |
25 | if len(trajectories[-1]) == 0:
26 | trajectories = trajectories[:-1]
27 |
28 | ## list of arrays because trajectories lengths will be different
29 | trajectories = [np.stack(traj, axis=0) for traj in trajectories]
30 |
31 | n_trajectories = len(trajectories)
32 | path_lengths = [len(traj) for traj in trajectories]
33 |
34 | ## pad trajectories to be of equal length
35 | trajectories_pad = np.zeros((n_trajectories, max_path_length, observation_dim), dtype=trajectories[0].dtype)
36 | early_termination = np.zeros((n_trajectories, max_path_length), dtype=bool)
37 | for i, traj in enumerate(trajectories):
38 | path_length = path_lengths[i]
39 | trajectories_pad[i,:path_length] = traj
40 | early_termination[i,path_length:] = 1
41 |
42 | return trajectories_pad, early_termination, path_lengths
43 |
44 | class SequenceDataset(torch.utils.data.Dataset):
45 |
46 | def __init__(self, env, sequence_length=250, step=10, discount=0.99, max_path_length=1000, penalty=None, device='cuda:0'):
47 | print(f'[ datasets/sequence ] Sequence length: {sequence_length} | Step: {step} | Max path length: {max_path_length}')
48 | self.env = env = load_environment(env) if type(env) is str else env
49 | self.sequence_length = sequence_length
50 | self.step = step
51 | self.max_path_length = max_path_length
52 | self.device = device
53 |
54 | print(f'[ datasets/sequence ] Loading...', end=' ', flush=True)
55 | dataset = qlearning_dataset_with_timeouts(env.unwrapped, terminate_on_end=True)
56 | print('✓')
57 |
58 | preprocess_fn = dataset_preprocess_functions.get(env.name)
59 | if preprocess_fn:
60 | print(f'[ datasets/sequence ] Modifying environment')
61 | dataset = preprocess_fn(dataset)
62 | ##
63 |
64 | observations = dataset['observations']
65 | actions = dataset['actions']
66 | next_observations = dataset['next_observations']
67 | rewards = dataset['rewards']
68 | terminals = dataset['terminals']
69 | realterminals = dataset['realterminals']
70 |
71 | self.observations_raw = observations
72 | self.actions_raw = actions
73 | self.next_observations_raw = next_observations
74 | self.joined_raw = np.concatenate([observations, actions], axis=-1)
75 | self.rewards_raw = rewards
76 | self.terminals_raw = terminals
77 |
78 | ## terminal penalty
79 | if penalty is not None:
80 | terminal_mask = realterminals.squeeze()
81 | self.rewards_raw[terminal_mask] = penalty
82 |
83 | ## segment
84 | print(f'[ datasets/sequence ] Segmenting...', end=' ', flush=True)
85 | self.joined_segmented, self.termination_flags, self.path_lengths = segment(self.joined_raw, terminals, max_path_length)
86 | self.rewards_segmented, *_ = segment(self.rewards_raw, terminals, max_path_length)
87 | print('✓')
88 |
89 | self.discount = discount
90 | self.discounts = (discount ** np.arange(self.max_path_length))[:,None]
91 |
92 | ## [ n_paths x max_path_length x 1 ]
93 | self.values_segmented = np.zeros(self.rewards_segmented.shape)
94 |
95 | for t in range(max_path_length):
96 | ## [ n_paths x 1 ]
97 | V = (self.rewards_segmented[:,t+1:] * self.discounts[:-t-1]).sum(axis=1)
98 | self.values_segmented[:,t] = V
99 |
100 | ## add (r, V) to `joined`
101 | values_raw = self.values_segmented.squeeze(axis=-1).reshape(-1)
102 | values_mask = ~self.termination_flags.reshape(-1)
103 | self.values_raw = values_raw[values_mask, None]
104 | self.joined_raw = np.concatenate([self.joined_raw, self.rewards_raw, self.values_raw], axis=-1)
105 | self.joined_segmented = np.concatenate([self.joined_segmented, self.rewards_segmented, self.values_segmented], axis=-1)
106 |
107 | ## get valid indices
108 | indices = []
109 | for path_ind, length in enumerate(self.path_lengths):
110 | end = length - 1
111 | for i in range(end):
112 | indices.append((path_ind, i, i+sequence_length))
113 |
114 | self.indices = np.array(indices)
115 | self.observation_dim = observations.shape[1]
116 | self.action_dim = actions.shape[1]
117 | self.joined_dim = self.joined_raw.shape[1]
118 |
119 | ## pad trajectories
120 | n_trajectories, _, joined_dim = self.joined_segmented.shape
121 | self.joined_segmented = np.concatenate([
122 | self.joined_segmented,
123 | np.zeros((n_trajectories, sequence_length-1, joined_dim)),
124 | ], axis=1)
125 | self.termination_flags = np.concatenate([
126 | self.termination_flags,
127 | np.ones((n_trajectories, sequence_length-1), dtype=bool),
128 | ], axis=1)
129 |
130 | def __len__(self):
131 | return len(self.indices)
132 |
133 |
134 | class DiscretizedDataset(SequenceDataset):
135 |
136 | def __init__(self, *args, N=50, discretizer='QuantileDiscretizer', **kwargs):
137 | super().__init__(*args, **kwargs)
138 | self.N = N
139 | discretizer_class = getattr(discretization, discretizer)
140 | self.discretizer = discretizer_class(self.joined_raw, N)
141 |
142 | def __getitem__(self, idx):
143 | path_ind, start_ind, end_ind = self.indices[idx]
144 | path_length = self.path_lengths[path_ind]
145 |
146 | joined = self.joined_segmented[path_ind, start_ind:end_ind:self.step]
147 | terminations = self.termination_flags[path_ind, start_ind:end_ind:self.step]
148 |
149 | joined_discrete = self.discretizer.discretize(joined)
150 |
151 | ## replace with termination token if the sequence has ended
152 | assert (joined[terminations] == 0).all(), \
153 | f'Everything after termination should be 0: {path_ind} | {start_ind} | {end_ind}'
154 | joined_discrete[terminations] = self.N
155 |
156 | ## [ (sequence_length / skip) x observation_dim]
157 | joined_discrete = to_torch(joined_discrete, device='cpu', dtype=torch.long).contiguous()
158 |
159 | ## don't compute loss for parts of the prediction that extend
160 | ## beyond the max path length
161 | traj_inds = torch.arange(start_ind, end_ind, self.step)
162 | mask = torch.ones(joined_discrete.shape, dtype=torch.bool)
163 | mask[traj_inds > self.max_path_length - self.step] = 0
164 |
165 | ## flatten everything
166 | joined_discrete = joined_discrete.view(-1)
167 | mask = mask.view(-1)
168 |
169 | X = joined_discrete[:-1]
170 | Y = joined_discrete[1:]
171 | mask = mask[:-1]
172 |
173 | return X, Y, mask
174 |
175 | class GoalDataset(DiscretizedDataset):
176 |
177 | def __init__(self, *args, **kwargs):
178 | super().__init__(*args, **kwargs)
179 |         # pdb.set_trace()  ## debugging breakpoint; enable if needed
180 |
181 | def __getitem__(self, idx):
182 | X, Y, mask = super().__getitem__(idx)
183 |
184 |         ## get path length for looking up the last transition in the trajectory
185 | path_ind, start_ind, end_ind = self.indices[idx]
186 | path_length = self.path_lengths[path_ind]
187 |
188 | ## the goal is the first `observation_dim` dimensions of the last transition
189 | goal = self.joined_segmented[path_ind, path_length-1, :self.observation_dim]
190 | goal_discrete = self.discretizer.discretize(goal, subslice=(0, self.observation_dim))
191 | goal_discrete = to_torch(goal_discrete, device='cpu', dtype=torch.long).contiguous().view(-1)
192 |
193 | return X, goal_discrete, Y, mask
194 |
--------------------------------------------------------------------------------
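The `__getitem__` above flattens a discretized window of transitions and shifts it by one token to build an autoregressive training pair. A toy sketch of that last step (made-up token values, not repository data):

```
import torch

## toy flattened window of discrete tokens; trailing 50s stand in for termination padding
joined_discrete = torch.tensor([12, 3, 47, 8, 50, 50], dtype=torch.long)
mask = torch.tensor([1, 1, 1, 1, 0, 0], dtype=torch.bool)

X = joined_discrete[:-1]   ## context tokens
Y = joined_discrete[1:]    ## next-token targets: Y[i] is predicted from X[:i+1]
mask = mask[:-1]           ## positions where mask is False are excluded from the loss
```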
/trajectory/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jannerm/trajectory-transformer/8834a6ed04ceeab8fdb9465e145c6e041c05d71b/trajectory/models/__init__.py
--------------------------------------------------------------------------------
/trajectory/models/ein.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import pdb
5 |
6 | class EinLinear(nn.Module):
7 |
8 | def __init__(self, n_models, in_features, out_features, bias):
9 | super().__init__()
10 | self.n_models = n_models
11 | self.out_features = out_features
12 | self.in_features = in_features
13 | self.weight = nn.Parameter(torch.Tensor(n_models, out_features, in_features))
14 | if bias:
15 | self.bias = nn.Parameter(torch.Tensor(n_models, out_features))
16 | else:
17 | self.register_parameter('bias', None)
18 | self.reset_parameters()
19 |
20 | def reset_parameters(self):
21 | for i in range(self.n_models):
22 | nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))
23 | if self.bias is not None:
24 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i])
25 | bound = 1 / math.sqrt(fan_in)
26 | nn.init.uniform_(self.bias[i], -bound, bound)
27 |
28 | def forward(self, input):
29 | """
30 | input : [ B x n_models x input_dim ]
31 | """
32 | ## [ B x n_models x output_dim ]
33 | output = torch.einsum('eoi,bei->beo', self.weight, input)
34 | if self.bias is not None:
35 |             raise RuntimeError('bias is not supported in EinLinear.forward')
36 | return output
37 |
38 | def extra_repr(self):
39 | return 'n_models={}, in_features={}, out_features={}, bias={}'.format(
40 | self.n_models, self.in_features, self.out_features, self.bias is not None
41 | )
42 |
--------------------------------------------------------------------------------
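A quick equivalence check (my own sketch, not repository code, assuming the `trajectory` package is installed): the einsum in `EinLinear.forward` applies one independent linear map per transition slot in a single batched call.

```
import torch
from trajectory.models.ein import EinLinear

n_models, in_features, out_features, batch = 3, 8, 5, 4
layer = EinLinear(n_models, in_features, out_features, bias=False)

## [ batch x n_models x in_features ]
x = torch.randn(batch, n_models, in_features)
out = layer(x)

## reference: apply each slot's weight matrix separately and stack
ref = torch.stack([x[:, i] @ layer.weight[i].T for i in range(n_models)], dim=1)
assert torch.allclose(out, ref, atol=1e-6)
```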
/trajectory/models/embeddings.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import pdb
5 |
6 | def make_weights(N, weights):
7 | assert len(weights) % 2 == 1, f'Expected odd number of weights, got: {weights}'
8 | center = int((len(weights) - 1) / 2)
9 |
10 | tokens = np.zeros((N, N))
11 | for i in range(N):
12 | token = np.zeros(N)
13 | for j, w in enumerate(weights):
14 | ind = i + j - center
15 | ind = np.clip(ind, 0, N-1)
16 | token[ind] += w
17 | tokens[i] = token
18 | assert np.allclose(tokens.sum(axis=-1), 1)
19 | return tokens
20 |
21 | def add_stop_token(tokens):
22 | N = len(tokens)
23 | ## regular tokens put 0 probability on stop token
24 | pad = np.zeros((N, 1))
25 | tokens = np.concatenate([tokens, pad], axis=1)
26 | ## stop token puts 1 probability on itself
27 | stop_weight = np.zeros((1, N+1))
28 | stop_weight[0,-1] = 1
29 | tokens = np.concatenate([tokens, stop_weight], axis=0)
30 |
31 | assert tokens.shape[0] == tokens.shape[1]
32 | assert np.allclose(tokens.sum(axis=-1), 1)
33 | return tokens
34 |
35 | class SmoothEmbedding(nn.Module):
36 |
37 | def __init__(self, num_embeddings, embedding_dim, weights, stop_token=False):
38 | super().__init__()
39 | self.weights = make_weights(num_embeddings, weights)
40 | if stop_token:
41 | self.weights = add_stop_token(self.weights)
42 | num_embeddings += 1
43 | self.weights = torch.tensor(self.weights, dtype=torch.float, device='cuda:0')
44 | self.inds = torch.arange(0, num_embeddings, device='cuda:0')
45 | self._embeddings = nn.Embedding(num_embeddings, embedding_dim)
46 |
47 | def forward(self, x):
48 | '''
49 | x : [ batch_size x context ]
50 | '''
51 | ## [ num_embeddings x embedding_dim ]
52 | embed = self._embeddings(self.inds)
53 | ## [ batch_size x context x num_embeddings ]
54 | weights = self.weights[x]
55 | assert torch.allclose(weights.sum(-1), torch.ones(1, device=weights.device))
56 |
57 | # [ batch_size x context x embedding_dim ]
58 | weighted_embed = torch.einsum('btn,nd->btd', weights, embed)
59 | return weighted_embed
60 |
61 |
62 | if __name__ == '__main__':
63 |
64 | x = torch.randint(0, 100, size=(5, 10,)).cuda()
65 |
66 | ## test with weights
67 | embed = SmoothEmbedding(100, 32, weights=[0.15, 0.2, 0.3, 0.2, 0.15], stop_token=True)
68 | embed.cuda()
69 | out = embed(x)
70 |
71 | ## test limiting case of regular embedding module
72 | embed_1 = SmoothEmbedding(100, 32, weights=[1.0], stop_token=True)
73 | embed_1.cuda()
74 | out_1 = embed_1(x)
75 |
76 | ## reference
77 | out_0 = embed_1._embeddings(x)
78 |
79 | print(f'Same: {(out_0 == out_1).all().item()}')
80 |
--------------------------------------------------------------------------------
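A small illustration (not part of the repository) of `make_weights`: each row is a probability distribution centered on its own index, and mass that would fall outside the vocabulary is clipped onto the edge bins.

```
from trajectory.models.embeddings import make_weights, add_stop_token

W = make_weights(5, [0.25, 0.5, 0.25])
print(W.round(2))
## [[0.75 0.25 0.   0.   0.  ]
##  [0.25 0.5  0.25 0.   0.  ]
##  [0.   0.25 0.5  0.25 0.  ]
##  [0.   0.   0.25 0.5  0.25]
##  [0.   0.   0.   0.25 0.75]]

## adding a stop token appends one absorbing row and column
print(add_stop_token(W).shape)  ## (6, 6)
```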
/trajectory/models/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pdb
4 |
5 |
6 | def get_activation(params):
7 | if type(params) == dict:
8 | name = params['type']
9 | kwargs = params['kwargs']
10 | else:
11 | name = str(params)
12 | kwargs = {}
13 | return lambda: getattr(nn, name)(**kwargs)
14 |
15 | def flatten(condition_dict):
16 | keys = sorted(condition_dict)
17 | vals = [condition_dict[key] for key in keys]
18 | condition = torch.cat(vals, dim=-1)
19 | return condition
20 |
21 | class MLP(nn.Module):
22 |
23 | def __init__(self, input_dim, hidden_dims, output_dim, activation='GELU', output_activation='Identity', name='mlp', model_class=None):
24 | """
25 | @TODO: clean up model instantiation from config so we don't have to pass in `model_class` to the model itself
26 | """
27 | super(MLP, self).__init__()
28 | self.input_dim = input_dim
29 | self.name = name
30 | activation = get_activation(activation)
31 | output_activation = get_activation(output_activation)
32 |
33 | layers = []
34 | current = input_dim
35 | for dim in hidden_dims:
36 | linear = nn.Linear(current, dim)
37 | layers.append(linear)
38 | layers.append(activation())
39 | current = dim
40 |
41 | layers.append(nn.Linear(current, output_dim))
42 | layers.append(output_activation())
43 |
44 | self._layers = nn.Sequential(*layers)
45 |
46 | def forward(self, x):
47 | return self._layers(x)
48 |
49 | @property
50 | def num_parameters(self):
51 | parameters = filter(lambda p: p.requires_grad, self.parameters())
52 | return sum([p.numel() for p in parameters])
53 |
54 | def __repr__(self):
55 | return '[ {} : {} parameters ] {}'.format(
56 | self.name, self.num_parameters,
57 | super().__repr__())
58 |
59 | class FlattenMLP(MLP):
60 |
61 | def forward(self, *args):
62 | x = torch.cat(args, dim=-1)
63 | return super().forward(x)
64 |
--------------------------------------------------------------------------------
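A minimal usage sketch (mine, not repository code) of `MLP`, showing both activation specs accepted by `get_activation`: a plain class name string or a dict with `type` and `kwargs`.

```
import torch
from trajectory.models.mlp import MLP

mlp = MLP(
    input_dim=11,
    hidden_dims=[64, 64],
    output_dim=3,
    activation={'type': 'LeakyReLU', 'kwargs': {'negative_slope': 0.1}},
    output_activation='Tanh',
)

x = torch.randn(7, 11)
print(mlp(x).shape)        ## torch.Size([7, 3])
print(mlp.num_parameters)  ## total number of trainable parameters
```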
/trajectory/models/transformers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import pdb
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import functional as F
8 |
9 | from .ein import EinLinear
10 |
11 | class CausalSelfAttention(nn.Module):
12 |
13 | def __init__(self, config):
14 | super().__init__()
15 | assert config.n_embd % config.n_head == 0
16 | # key, query, value projections for all heads
17 | self.key = nn.Linear(config.n_embd, config.n_embd)
18 | self.query = nn.Linear(config.n_embd, config.n_embd)
19 | self.value = nn.Linear(config.n_embd, config.n_embd)
20 | # regularization
21 | self.attn_drop = nn.Dropout(config.attn_pdrop)
22 | self.resid_drop = nn.Dropout(config.resid_pdrop)
23 | # output projection
24 | self.proj = nn.Linear(config.n_embd, config.n_embd)
25 | # causal mask to ensure that attention is only applied to the left in the input sequence
26 | self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
27 | .view(1, 1, config.block_size, config.block_size))
28 | ## mask previous value estimates
29 | joined_dim = config.observation_dim + config.action_dim + 2
30 | self.mask.squeeze()[:,joined_dim-1::joined_dim] = 0
31 | ##
32 | self.n_head = config.n_head
33 |
34 | def forward(self, x, layer_past=None):
35 | B, T, C = x.size()
36 |
37 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
38 | ## [ B x n_heads x T x head_dim ]
39 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
40 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
41 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
42 |
43 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
44 | ## [ B x n_heads x T x T ]
45 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
46 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
47 | att = F.softmax(att, dim=-1)
48 | self._attn_map = att.clone()
49 | att = self.attn_drop(att)
50 | ## [ B x n_heads x T x head_size ]
51 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
52 | ## [ B x T x embedding_dim ]
53 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
54 |
55 | # output projection
56 | y = self.resid_drop(self.proj(y))
57 | return y
58 |
59 | class Block(nn.Module):
60 |
61 | def __init__(self, config):
62 | super().__init__()
63 | self.ln1 = nn.LayerNorm(config.n_embd)
64 | self.ln2 = nn.LayerNorm(config.n_embd)
65 | self.attn = CausalSelfAttention(config)
66 | self.mlp = nn.Sequential(
67 | nn.Linear(config.n_embd, 4 * config.n_embd),
68 | nn.GELU(),
69 | nn.Linear(4 * config.n_embd, config.n_embd),
70 | nn.Dropout(config.resid_pdrop),
71 | )
72 |
73 | def forward(self, x):
74 | x = x + self.attn(self.ln1(x))
75 | x = x + self.mlp(self.ln2(x))
76 | return x
77 |
78 | class GPT(nn.Module):
79 | """ the full GPT language model, with a context size of block_size """
80 |
81 | def __init__(self, config):
82 | super().__init__()
83 |
84 | # input embedding stem (+1 for stop token)
85 | self.tok_emb = nn.Embedding(config.vocab_size * config.transition_dim + 1, config.n_embd)
86 |
87 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
88 | self.drop = nn.Dropout(config.embd_pdrop)
89 | # transformer
90 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
91 | # decoder head
92 | self.ln_f = nn.LayerNorm(config.n_embd)
93 | # self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
94 | self.head = EinLinear(config.transition_dim, config.n_embd, config.vocab_size + 1, bias=False)
95 |
96 | self.vocab_size = config.vocab_size
97 | self.stop_token = config.vocab_size * config.transition_dim
98 | self.block_size = config.block_size
99 | self.observation_dim = config.observation_dim
100 |
101 | self.action_dim = config.action_dim
102 | self.transition_dim = config.transition_dim
103 | self.action_weight = config.action_weight
104 | self.reward_weight = config.reward_weight
105 | self.value_weight = config.value_weight
106 |
107 | self.embedding_dim = config.n_embd
108 | self.apply(self._init_weights)
109 |
110 | def get_block_size(self):
111 | return self.block_size
112 |
113 | def _init_weights(self, module):
114 | if isinstance(module, (nn.Linear, nn.Embedding)):
115 | module.weight.data.normal_(mean=0.0, std=0.02)
116 | if isinstance(module, nn.Linear) and module.bias is not None:
117 | module.bias.data.zero_()
118 | elif isinstance(module, nn.LayerNorm):
119 | module.bias.data.zero_()
120 | module.weight.data.fill_(1.0)
121 |
122 | def configure_optimizers(self, train_config):
123 | """
124 | This long function is unfortunately doing something very simple and is being very defensive:
125 | We are separating out all parameters of the model into two buckets: those that will experience
126 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
127 | We are then returning the PyTorch optimizer object.
128 | """
129 |
130 | # separate out all parameters to those that will and won't experience regularizing weight decay
131 | decay = set()
132 | no_decay = set()
133 | whitelist_weight_modules = (torch.nn.Linear, EinLinear)
134 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
135 | for mn, m in self.named_modules():
136 | for pn, p in m.named_parameters():
137 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
138 |
139 | if pn.endswith('bias'):
140 | # all biases will not be decayed
141 | no_decay.add(fpn)
142 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
143 | # weights of whitelist modules will be weight decayed
144 | decay.add(fpn)
145 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
146 | # weights of blacklist modules will NOT be weight decayed
147 | no_decay.add(fpn)
148 |
149 | # special case the position embedding parameter in the root GPT module as not decayed
150 | no_decay.add('pos_emb')
151 |
152 | # validate that we considered every parameter
153 | param_dict = {pn: p for pn, p in self.named_parameters()}
154 | inter_params = decay & no_decay
155 | union_params = decay | no_decay
156 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
157 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
158 | % (str(param_dict.keys() - union_params), )
159 |
160 | # create the pytorch optimizer object
161 | optim_groups = [
162 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
163 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
164 | ]
165 | optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
166 | return optimizer
167 |
168 | def offset_tokens(self, idx):
169 | _, t = idx.shape
170 | n_states = int(np.ceil(t / self.transition_dim))
171 | offsets = torch.arange(self.transition_dim) * self.vocab_size
172 | offsets = offsets.repeat(n_states).to(idx.device)
173 | offset_idx = idx + offsets[:t]
174 | offset_idx[idx == self.vocab_size] = self.stop_token
175 | return offset_idx
176 |
177 | def pad_to_full_observation(self, x, verify=False):
178 | b, t, _ = x.shape
179 | n_pad = (self.transition_dim - t % self.transition_dim) % self.transition_dim
180 | padding = torch.zeros(b, n_pad, self.embedding_dim, device=x.device)
181 | ## [ B x T' x embedding_dim ]
182 | x_pad = torch.cat([x, padding], dim=1)
183 | ## [ (B * T' / transition_dim) x transition_dim x embedding_dim ]
184 | x_pad = x_pad.view(-1, self.transition_dim, self.embedding_dim)
185 | if verify:
186 | self.verify(x, x_pad)
187 | return x_pad, n_pad
188 |
189 | def verify(self, x, x_pad):
190 | b, t, embedding_dim = x.shape
191 | n_states = int(np.ceil(t / self.transition_dim))
192 | inds = torch.arange(0, self.transition_dim).repeat(n_states)[:t]
193 | for i in range(self.transition_dim):
194 | x_ = x[:,inds == i]
195 | t_ = x_.shape[1]
196 | x_pad_ = x_pad[:,i].view(b, n_states, embedding_dim)[:,:t_]
197 | print(i, x_.shape, x_pad_.shape)
198 | try:
199 | assert (x_ == x_pad_).all()
200 | except:
201 | pdb.set_trace()
202 |
203 | def forward(self, idx, targets=None, mask=None):
204 | """
205 | idx : [ B x T ]
206 | values : [ B x 1 x 1 ]
207 | """
208 | b, t = idx.size()
209 | assert t <= self.block_size, "Cannot forward, model block size is exhausted."
210 |
211 | offset_idx = self.offset_tokens(idx)
212 | ## [ B x T x embedding_dim ]
213 | # forward the GPT model
214 | token_embeddings = self.tok_emb(offset_idx) # each index maps to a (learnable) vector
215 | ## [ 1 x T x embedding_dim ]
216 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
217 | ## [ B x T x embedding_dim ]
218 | x = self.drop(token_embeddings + position_embeddings)
219 | x = self.blocks(x)
220 | ## [ B x T x embedding_dim ]
221 | x = self.ln_f(x)
222 |
223 | ## [ (B * T' / transition_dim) x transition_dim x embedding_dim ]
224 | x_pad, n_pad = self.pad_to_full_observation(x)
225 | ## [ (B * T' / transition_dim) x transition_dim x (vocab_size + 1) ]
226 | logits = self.head(x_pad)
227 | ## [ B x T' x (vocab_size + 1) ]
228 | logits = logits.reshape(b, t + n_pad, self.vocab_size + 1)
229 | ## [ B x T x (vocab_size + 1) ]
230 | logits = logits[:,:t]
231 |
232 | # if we are given some desired targets also calculate the loss
233 | if targets is not None:
234 | loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1), reduction='none')
235 | if self.action_weight != 1 or self.reward_weight != 1 or self.value_weight != 1:
236 | #### make weights
237 | n_states = int(np.ceil(t / self.transition_dim))
238 | weights = torch.cat([
239 | torch.ones(self.observation_dim, device=idx.device),
240 | torch.ones(self.action_dim, device=idx.device) * self.action_weight,
241 | torch.ones(1, device=idx.device) * self.reward_weight,
242 | torch.ones(1, device=idx.device) * self.value_weight,
243 | ])
244 | ## [ t + 1]
245 | weights = weights.repeat(n_states)
246 | ## [ b x t ]
247 | weights = weights[1:].repeat(b, 1)
248 | ####
249 | loss = loss * weights.view(-1)
250 | loss = (loss * mask.view(-1)).mean()
251 | else:
252 | loss = None
253 |
254 | return logits, loss
255 |
256 | class ConditionalGPT(GPT):
257 |
258 | def __init__(self, config):
259 | ## increase block size by `observation_dim` because we are prepending a goal observation
260 | ## to the sequence
261 | config.block_size += config.observation_dim
262 | super().__init__(config)
263 | self.goal_emb = nn.Embedding(config.vocab_size * config.observation_dim, config.n_embd)
264 |
265 | def get_block_size(self):
266 | return self.block_size - self.observation_dim
267 |
268 | def forward(self, idx, goal, targets=None, mask=None):
269 | b, t = idx.size()
270 | assert t <= self.block_size, "Cannot forward, model block size is exhausted."
271 |
272 | #### goal
273 | offset_goal = self.offset_tokens(goal)
274 | goal_embeddings = self.goal_emb(offset_goal)
275 | #### /goal
276 |
277 | offset_idx = self.offset_tokens(idx)
278 | ## [ B x T x embedding_dim ]
279 | # forward the GPT model
280 | token_embeddings = self.tok_emb(offset_idx) # each index maps to a (learnable) vector
281 | ## [ 1 x T x embedding_dim ]
282 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
283 | ## [ B x T x embedding_dim ]
284 | x = self.drop(token_embeddings + position_embeddings)
285 |
286 | #### goal
287 |         ## [ B x (obs_dim + T) x embedding_dim ]
288 | gx = torch.cat([goal_embeddings, x], dim=1)
289 | gx = self.blocks(gx)
290 | x = gx[:, self.observation_dim:]
291 | #### /goal
292 |
293 | ## [ B x T x embedding_dim ]
294 | x = self.ln_f(x)
295 |
296 | ## [ (B * T' / transition_dim) x transition_dim x embedding_dim ]
297 | x_pad, n_pad = self.pad_to_full_observation(x)
298 | ## [ (B * T' / transition_dim) x transition_dim x (vocab_size + 1) ]
299 | logits = self.head(x_pad)
300 | ## [ B x T' x (vocab_size + 1) ]
301 | logits = logits.reshape(b, t + n_pad, self.vocab_size + 1)
302 | ## [ B x T x (vocab_size + 1) ]
303 | logits = logits[:,:t]
304 |
305 | # if we are given some desired targets also calculate the loss
306 | if targets is not None:
307 | loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1), reduction='none')
308 | if self.action_weight != 1 or self.reward_weight != 1 or self.value_weight != 1:
309 | #### make weights
310 | n_states = int(np.ceil(t / self.transition_dim))
311 | weights = torch.cat([
312 | torch.ones(self.observation_dim, device=idx.device),
313 | torch.ones(self.action_dim, device=idx.device) * self.action_weight,
314 | torch.ones(1, device=idx.device) * self.reward_weight,
315 | torch.ones(1, device=idx.device) * self.value_weight,
316 | ])
317 | ## [ t + 1]
318 | weights = weights.repeat(n_states)
319 | ## [ b x t ]
320 | weights = weights[1:].repeat(b, 1)
321 | ####
322 | loss = loss * weights.view(-1)
323 | loss = (loss * mask.view(-1)).mean()
324 | else:
325 | loss = None
326 |
327 | return logits, loss
328 |
--------------------------------------------------------------------------------
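The token-offset scheme in `GPT.offset_tokens` gives each slot of a transition its own copy of the vocabulary before the shared embedding lookup. A self-contained numeric sketch with assumed sizes (not repository code):

```
import numpy as np
import torch

vocab_size = 100      ## assumed number of discretization bins
transition_dim = 4    ## assumed obs_dim + act_dim + reward + value

idx = torch.tensor([[7, 3, 99, 12, 7, 3, 99, 12]])   ## two toy transitions
t = idx.shape[1]

n_states = int(np.ceil(t / transition_dim))
offsets = (torch.arange(transition_dim) * vocab_size).repeat(n_states)
offset_idx = idx + offsets[:t]
print(offset_idx)
## tensor([[  7, 103, 299, 312,   7, 103, 299, 312]])
## tokens equal to vocab_size would additionally be remapped to the single stop token
```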
/trajectory/search/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import *
2 | from .utils import *
--------------------------------------------------------------------------------
/trajectory/search/core.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import pdb
4 |
5 | from .. import utils
6 | from .sampling import sample_n, get_logp, sort_2d
7 |
8 | REWARD_DIM = VALUE_DIM = 1
9 |
10 | @torch.no_grad()
11 | def beam_plan(
12 | model, value_fn, x,
13 | n_steps, beam_width, n_expand,
14 | observation_dim, action_dim,
15 | discount=0.99, max_context_transitions=None,
16 | k_obs=None, k_act=None, k_rew=1,
17 | cdf_obs=None, cdf_act=None, cdf_rew=None,
18 | verbose=True, previous_actions=None,
19 | ):
20 | '''
21 | x : tensor[ 1 x input_sequence_length ]
22 | '''
23 |
24 | inp = x.clone()
25 |
26 | # convert max number of transitions to max number of tokens
27 | transition_dim = observation_dim + action_dim + REWARD_DIM + VALUE_DIM
28 | max_block = max_context_transitions * transition_dim - 1 if max_context_transitions else None
29 |
30 |     ## pass in max number of tokens to sample function
31 | sample_kwargs = {
32 | 'max_block': max_block,
33 | 'crop_increment': transition_dim,
34 | }
35 |
36 | ## repeat input for search
37 | x = x.repeat(beam_width, 1)
38 |
39 | ## construct reward and discount tensors for estimating values
40 | rewards = torch.zeros(beam_width, n_steps + 1, device=x.device)
41 | discounts = discount ** torch.arange(n_steps + 1, device=x.device)
42 |
43 | ## logging
44 | progress = utils.Progress(n_steps) if verbose else utils.Silent()
45 |
46 | for t in range(n_steps):
47 | ## repeat everything by `n_expand` before we sample actions
48 | x = x.repeat(n_expand, 1)
49 | rewards = rewards.repeat(n_expand, 1)
50 |
51 | ## sample actions
52 | x, _ = sample_n(model, x, action_dim, topk=k_act, cdf=cdf_act, **sample_kwargs)
53 |
54 | ## sample reward and value estimate
55 | x, r_probs = sample_n(model, x, REWARD_DIM + VALUE_DIM, topk=k_rew, cdf=cdf_rew, **sample_kwargs)
56 |
57 | ## optionally, use a percentile or mean of the reward and
58 | ## value distributions instead of sampled tokens
59 | r_t, V_t = value_fn(r_probs)
60 |
61 | ## update rewards tensor
62 | rewards[:, t] = r_t
63 | rewards[:, t+1] = V_t
64 |
65 | ## estimate values using rewards up to `t` and terminal value at `t`
66 | values = (rewards * discounts).sum(dim=-1)
67 |
68 | ## get `beam_width` best actions
69 | values, inds = torch.topk(values, beam_width)
70 |
71 | ## index into search candidates to retain `beam_width` highest-reward sequences
72 | x = x[inds]
73 | rewards = rewards[inds]
74 |
75 | ## sample next observation (unless we have reached the end of the planning horizon)
76 | if t < n_steps - 1:
77 | x, _ = sample_n(model, x, observation_dim, topk=k_obs, cdf=cdf_obs, **sample_kwargs)
78 |
79 | ## logging
80 | progress.update({
81 | 'x': list(x.shape),
82 | 'vmin': values.min(), 'vmax': values.max(),
83 | 'vtmin': V_t.min(), 'vtmax': V_t.max(),
84 | 'discount': discount
85 | })
86 |
87 | progress.stamp()
88 |
89 | ## [ batch_size x (n_context + n_steps) x transition_dim ]
90 | x = x.view(beam_width, -1, transition_dim)
91 |
92 | ## crop out context transitions
93 | ## [ batch_size x n_steps x transition_dim ]
94 | x = x[:, -n_steps:]
95 |
96 | ## return best sequence
97 | argmax = values.argmax()
98 | best_sequence = x[argmax]
99 |
100 | return best_sequence
101 |
102 | @torch.no_grad()
103 | def beam_search(model, x, n_steps, beam_width=512, goal=None, **sample_kwargs):
104 | batch_size = len(x)
105 |
106 | prefix_i = torch.arange(len(x), dtype=torch.long, device=x.device)
107 | cumulative_logp = torch.zeros(batch_size, 1, device=x.device)
108 |
109 | for t in range(n_steps):
110 |
111 | if goal is not None:
112 | goal_rep = goal.repeat(len(x), 1)
113 | logp = get_logp(model, x, goal=goal_rep, **sample_kwargs)
114 | else:
115 | logp = get_logp(model, x, **sample_kwargs)
116 |
117 | candidate_logp = cumulative_logp + logp
118 | sorted_logp, sorted_i, sorted_j = sort_2d(candidate_logp)
119 |
120 | n_candidates = (candidate_logp > -np.inf).sum().item()
121 | n_retain = min(n_candidates, beam_width)
122 | cumulative_logp = sorted_logp[:n_retain].unsqueeze(-1)
123 |
124 | sorted_i = sorted_i[:n_retain]
125 | sorted_j = sorted_j[:n_retain].unsqueeze(-1)
126 |
127 | x = torch.cat([x[sorted_i], sorted_j], dim=-1)
128 | prefix_i = prefix_i[sorted_i]
129 |
130 | x = x[0]
131 | return x, cumulative_logp.squeeze()
132 |
--------------------------------------------------------------------------------
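A numeric sketch (hypothetical rewards and value estimates, not repository code) of the bookkeeping in `beam_plan`: the sampled reward for step `t` goes into column `t` and the value estimate overwrites column `t+1`, so the discounted row sum is always the return so far plus the discounted value-to-go.

```
import torch

discount = 0.99
n_steps = 3
rewards = torch.zeros(1, n_steps + 1)
discounts = discount ** torch.arange(n_steps + 1)

sampled_r = [1.0, 0.5, 2.0]     ## hypothetical rewards decoded from the model
sampled_V = [10.0, 9.0, 8.0]    ## hypothetical value estimates

for t in range(n_steps):
    rewards[:, t] = sampled_r[t]
    rewards[:, t + 1] = sampled_V[t]
    values = (rewards * discounts).sum(dim=-1)
    ## at t=0: 1.0 + 0.99 * 10.0

print(values)  ## 1.0 + 0.99*0.5 + 0.99**2 * 2.0 + 0.99**3 * 8.0
```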
/trajectory/search/sampling.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import pdb
4 |
5 | #-------------------------------- helper functions --------------------------------#
6 |
7 | def top_k_logits(logits, k):
8 | v, ix = torch.topk(logits, k)
9 | out = logits.clone()
10 | out[out < v[:, [-1]]] = -float('Inf')
11 | return out
12 |
13 | def filter_cdf(logits, threshold):
14 | batch_inds = torch.arange(logits.shape[0], device=logits.device, dtype=torch.long)
15 | bins_inds = torch.arange(logits.shape[-1], device=logits.device)
16 | probs = logits.softmax(dim=-1)
17 | probs_sorted, _ = torch.sort(probs, dim=-1)
18 | probs_cum = torch.cumsum(probs_sorted, dim=-1)
19 | ## get minimum probability p such that the cdf up to p is at least `threshold`
20 | mask = probs_cum < threshold
21 | masked_inds = torch.argmax(mask * bins_inds, dim=-1)
22 | probs_threshold = probs_sorted[batch_inds, masked_inds]
23 | ## filter
24 | out = logits.clone()
25 | logits_mask = probs <= probs_threshold.unsqueeze(dim=-1)
26 | out[logits_mask] = -1000
27 | return out
28 |
29 | def round_to_multiple(x, N):
30 | '''
31 | Rounds `x` up to nearest multiple of `N`.
32 |
33 | x : int
34 | N : int
35 | '''
36 | pad = (N - x % N) % N
37 | return x + pad
38 |
39 | def sort_2d(x):
40 | '''
41 | x : [ M x N ]
42 | '''
43 | M, N = x.shape
44 | x = x.view(-1)
45 | x_sort, inds = torch.sort(x, descending=True)
46 |
47 | rows = inds // N
48 | cols = inds % N
49 |
50 | return x_sort, rows, cols
51 |
52 | #-------------------------------- forward pass --------------------------------#
53 |
54 | def forward(model, x, max_block=None, allow_crop=True, crop_increment=None, **kwargs):
55 | '''
56 | A wrapper around a single forward pass of the transformer.
57 | Crops the input if the sequence is too long.
58 |
59 | x : tensor[ batch_size x sequence_length ]
60 | '''
61 | model.eval()
62 |
63 | block_size = min(model.get_block_size(), max_block or np.inf)
64 |
65 | if x.shape[1] > block_size:
66 | assert allow_crop, (
67 | f'[ search/sampling ] input size is {x.shape} and block size is {block_size}, '
68 | 'but cropping not allowed')
69 |
70 |         ## crop out an entire transition at a time so that the first token is always s_t^0
71 | n_crop = round_to_multiple(x.shape[1] - block_size, crop_increment)
72 | assert n_crop % crop_increment == 0
73 | x = x[:, n_crop:]
74 |
75 | logits, _ = model(x, **kwargs)
76 |
77 | return logits
78 |
79 | def get_logp(model, x, temperature=1.0, topk=None, cdf=None, **forward_kwargs):
80 | '''
81 | x : tensor[ batch_size x sequence_length ]
82 | '''
83 | ## [ batch_size x sequence_length x vocab_size ]
84 | logits = forward(model, x, **forward_kwargs)
85 |
86 | ## pluck the logits at the final step and scale by temperature
87 | ## [ batch_size x vocab_size ]
88 | logits = logits[:, -1] / temperature
89 |
90 | ## optionally crop logits to only the top `1 - cdf` percentile
91 | if cdf is not None:
92 | logits = filter_cdf(logits, cdf)
93 |
94 | ## optionally crop logits to only the most likely `k` options
95 | if topk is not None:
96 | logits = top_k_logits(logits, topk)
97 |
98 | ## apply softmax to convert to probabilities
99 | logp = logits.log_softmax(dim=-1)
100 |
101 | return logp
102 |
103 | #-------------------------------- sampling --------------------------------#
104 |
105 | def sample(model, x, temperature=1.0, topk=None, cdf=None, **forward_kwargs):
106 | '''
107 | Samples from the distribution parameterized by `model(x)`.
108 |
109 | x : tensor[ batch_size x sequence_length ]
110 | '''
111 | ## [ batch_size x sequence_length x vocab_size ]
112 | logits = forward(model, x, **forward_kwargs)
113 |
114 | ## pluck the logits at the final step and scale by temperature
115 | ## [ batch_size x vocab_size ]
116 | logits = logits[:, -1] / temperature
117 |
118 | ## keep track of probabilities before modifying logits
119 | raw_probs = logits.softmax(dim=-1)
120 |
121 | ## optionally crop logits to only the top `1 - cdf` percentile
122 | if cdf is not None:
123 | logits = filter_cdf(logits, cdf)
124 |
125 | ## optionally crop logits to only the most likely `k` options
126 | if topk is not None:
127 | logits = top_k_logits(logits, topk)
128 |
129 | ## apply softmax to convert to probabilities
130 | probs = logits.softmax(dim=-1)
131 |
132 | ## sample from the distribution
133 | ## [ batch_size x 1 ]
134 | indices = torch.multinomial(probs, num_samples=1)
135 |
136 | return indices, raw_probs
137 |
138 | @torch.no_grad()
139 | def sample_n(model, x, N, **sample_kwargs):
140 | batch_size = len(x)
141 |
142 | ## keep track of probabilities from each step;
143 | ## `vocab_size + 1` accounts for termination token
144 | probs = torch.zeros(batch_size, N, model.vocab_size + 1, device=x.device)
145 |
146 | for n in range(N):
147 | indices, p = sample(model, x, **sample_kwargs)
148 |
149 | ## append to the sequence and continue
150 | ## [ batch_size x (sequence_length + n) ]
151 | x = torch.cat((x, indices), dim=1)
152 |
153 | probs[:, n] = p
154 |
155 | return x, probs
156 |
--------------------------------------------------------------------------------
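A small demonstration (my own sketch) of the two logit filters above: `top_k_logits` keeps only the `k` most likely bins, while `filter_cdf` discards the lowest-probability bins whose combined probability mass stays below `threshold`.

```
import torch
from trajectory.search.sampling import top_k_logits, filter_cdf

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])

print(top_k_logits(logits, k=2))
## tensor([[2., 1., -inf, -inf, -inf]])

print(filter_cdf(logits, threshold=0.1))
## the two least likely bins are pushed down to -1000; the rest are unchanged
```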
/trajectory/search/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import pdb
4 |
5 | from ..utils.arrays import to_torch
6 |
7 | VALUE_PLACEHOLDER = 1e6
8 |
9 | def make_prefix(discretizer, context, obs, prefix_context=True):
10 | observation_dim = obs.size
11 | obs_discrete = discretizer.discretize(obs, subslice=[0, observation_dim])
12 | obs_discrete = to_torch(obs_discrete, dtype=torch.long)
13 |
14 | if prefix_context:
15 | prefix = torch.cat(context + [obs_discrete], dim=-1)
16 | else:
17 | prefix = obs_discrete
18 |
19 | return prefix
20 |
21 | def extract_actions(x, observation_dim, action_dim, t=None):
22 | assert x.shape[1] == observation_dim + action_dim + 2
23 | actions = x[:, observation_dim:observation_dim+action_dim]
24 | if t is not None:
25 | return actions[t]
26 | else:
27 | return actions
28 |
29 | def update_context(context, discretizer, observation, action, reward, max_context_transitions):
30 | '''
31 | context : list of transitions
32 | [ tensor( transition_dim ), ... ]
33 | '''
34 | ## use a placeholder for value because input values are masked out by model
35 | rew_val = np.array([reward, VALUE_PLACEHOLDER])
36 | transition = np.concatenate([observation, action, rew_val])
37 |
38 | ## discretize transition and convert to torch tensor
39 | transition_discrete = discretizer.discretize(transition)
40 | transition_discrete = to_torch(transition_discrete, dtype=torch.long)
41 |
42 | ## add new transition to context
43 | context.append(transition_discrete)
44 |
45 | ## crop context if necessary
46 | context = context[-max_context_transitions:]
47 |
48 | return context
--------------------------------------------------------------------------------
/trajectory/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .setup import Parser, watch
2 | from .arrays import *
3 | from .serialization import *
4 | from .progress import Progress, Silent
5 | from .rendering import make_renderer
6 | # from .video import *
7 | from .config import Config
8 | from .training import Trainer
9 |
--------------------------------------------------------------------------------
/trajectory/utils/arrays.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | DTYPE = torch.float
5 | DEVICE = 'cuda:0'
6 |
7 | def to_np(x):
8 | if torch.is_tensor(x):
9 | x = x.detach().cpu().numpy()
10 | return x
11 |
12 | def to_torch(x, dtype=None, device=None):
13 | dtype = dtype or DTYPE
14 | device = device or DEVICE
15 | return torch.tensor(x, dtype=dtype, device=device)
16 |
17 | def to_device(*xs, device=DEVICE):
18 | return [x.to(device) for x in xs]
19 |
20 | def normalize(x):
21 | """
22 | scales `x` to [0, 1]
23 | """
24 | x = x - x.min()
25 | x = x / x.max()
26 | return x
27 |
28 | def to_img(x):
29 | normalized = normalize(x)
30 | array = to_np(normalized)
31 | array = np.transpose(array, (1,2,0))
32 | return (array * 255).astype(np.uint8)
33 |
34 | def set_device(device):
35 |     ## update the module-level default device used by `to_torch`
36 |     global DEVICE
37 |     DEVICE = device
38 |     if 'cuda' in device:
39 |         torch.set_default_tensor_type(torch.cuda.FloatTensor)
40 |
--------------------------------------------------------------------------------
/trajectory/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import collections.abc
3 | import pickle
4 |
5 | class Config(collections.abc.Mapping):
6 |
7 | def __init__(self, _class, verbose=True, savepath=None, **kwargs):
8 | self._class = _class
9 | self._dict = {}
10 |
11 | for key, val in kwargs.items():
12 | self._dict[key] = val
13 |
14 | if verbose:
15 | print(self)
16 |
17 | if savepath is not None:
18 | savepath = os.path.join(*savepath) if type(savepath) is tuple else savepath
19 | pickle.dump(self, open(savepath, 'wb'))
20 | print(f'Saved config to: {savepath}\n')
21 |
22 | def __repr__(self):
23 | string = f'\nConfig: {self._class}\n'
24 | for key in sorted(self._dict.keys()):
25 | val = self._dict[key]
26 | string += f' {key}: {val}\n'
27 | return string
28 |
29 | def __iter__(self):
30 | return iter(self._dict)
31 |
32 | def __getitem__(self, item):
33 | return self._dict[item]
34 |
35 | def __len__(self):
36 | return len(self._dict)
37 |
38 | def __call__(self):
39 | return self.make()
40 |
41 | def __getattr__(self, attr):
42 | if attr == '_dict' and '_dict' not in vars(self):
43 | self._dict = {}
44 | try:
45 | return self._dict[attr]
46 | except KeyError:
47 | raise AttributeError(attr)
48 |
49 | def make(self):
50 | if 'GPT' in str(self._class) or 'Trainer' in str(self._class):
51 | ## GPT class expects the config as the sole input
52 | return self._class(self)
53 | else:
54 | return self._class(**self._dict)
55 |
--------------------------------------------------------------------------------
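A usage sketch (not from the repository) of `Config`: it records keyword arguments, behaves like a read-only mapping, and calling the instance constructs the wrapped class with the stored arguments.

```
from trajectory.utils.config import Config

config = Config(dict, verbose=False, a=1, b=2)

print(config['a'], config.b, len(config))   ## 1 2 2
instance = config()                         ## equivalent to dict(a=1, b=2)
print(instance)                             ## {'a': 1, 'b': 2}
```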
/trajectory/utils/discretization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import pdb
4 |
5 | from .arrays import to_np, to_torch
6 |
7 | class QuantileDiscretizer:
8 |
9 | def __init__(self, data, N):
10 | self.data = data
11 | self.N = N
12 |
13 | n_points_per_bin = int(np.ceil(len(data) / N))
14 | obs_sorted = np.sort(data, axis=0)
15 | thresholds = obs_sorted[::n_points_per_bin, :]
16 | maxs = data.max(axis=0, keepdims=True)
17 |
18 | ## [ (N + 1) x dim ]
19 | self.thresholds = np.concatenate([thresholds, maxs], axis=0)
20 |
21 | # threshold_inds = np.linspace(0, len(data) - 1, N + 1, dtype=int)
22 | # obs_sorted = np.sort(data, axis=0)
23 |
24 | # ## [ (N + 1) x dim ]
25 | # self.thresholds = obs_sorted[threshold_inds, :]
26 |
27 | ## [ N x dim ]
28 | self.diffs = self.thresholds[1:] - self.thresholds[:-1]
29 |
30 | ## for sparse reward tasks
31 | # if (self.diffs[:,-1] == 0).any():
32 | # raise RuntimeError('rebin for sparse reward tasks')
33 |
34 | self._test()
35 |
36 | def __call__(self, x):
37 | indices = self.discretize(x)
38 | recon = self.reconstruct(indices)
39 | error = np.abs(recon - x).max(0)
40 | return indices, recon, error
41 |
42 | def _test(self):
43 | print('[ utils/discretization ] Testing...', end=' ', flush=True)
44 | inds = np.random.randint(0, len(self.data), size=1000)
45 | X = self.data[inds]
46 | indices = self.discretize(X)
47 | recon = self.reconstruct(indices)
48 | ## make sure reconstruction error is less than the max allowed per dimension
49 | error = np.abs(X - recon).max(0)
50 | assert (error <= self.diffs.max(axis=0)).all()
51 | ## re-discretize reconstruction and make sure it is the same as original indices
52 | indices_2 = self.discretize(recon)
53 | assert (indices == indices_2).all()
54 | ## reconstruct random indices
55 | ## @TODO: remove duplicate thresholds
56 | # randint = np.random.randint(0, self.N, indices.shape)
57 | # randint_2 = self.discretize(self.reconstruct(randint))
58 | # assert (randint == randint_2).all()
59 | print('✓')
60 |
61 | def discretize(self, x, subslice=(None, None)):
62 | '''
63 | x : [ B x observation_dim ]
64 | '''
65 |
66 | if torch.is_tensor(x):
67 | x = to_np(x)
68 |
69 | ## enforce batch mode
70 | if x.ndim == 1:
71 | x = x[None]
72 |
73 | ## [ N x B x observation_dim ]
74 | start, end = subslice
75 | thresholds = self.thresholds[:, start:end]
76 |
77 | gt = x[None] >= thresholds[:,None]
78 | indices = largest_nonzero_index(gt, dim=0)
79 |
80 | if indices.min() < 0 or indices.max() >= self.N:
81 | indices = np.clip(indices, 0, self.N - 1)
82 |
83 | return indices
84 |
85 | def reconstruct(self, indices, subslice=(None, None)):
86 |
87 | if torch.is_tensor(indices):
88 | indices = to_np(indices)
89 |
90 | ## enforce batch mode
91 | if indices.ndim == 1:
92 | indices = indices[None]
93 |
94 | if indices.min() < 0 or indices.max() >= self.N:
95 | print(f'[ utils/discretization ] indices out of range: ({indices.min()}, {indices.max()}) | N: {self.N}')
96 | indices = np.clip(indices, 0, self.N - 1)
97 |
98 | start, end = subslice
99 | thresholds = self.thresholds[:, start:end]
100 |
101 | left = np.take_along_axis(thresholds, indices, axis=0)
102 | right = np.take_along_axis(thresholds, indices + 1, axis=0)
103 | recon = (left + right) / 2.
104 | return recon
105 |
106 | #---------------------------- wrappers for planning ----------------------------#
107 |
108 | def expectation(self, probs, subslice):
109 | '''
110 | probs : [ B x N ]
111 | '''
112 |
113 | if torch.is_tensor(probs):
114 | probs = to_np(probs)
115 |
116 | ## [ N ]
117 | thresholds = self.thresholds[:, subslice]
118 | ## [ B ]
119 | left = probs @ thresholds[:-1]
120 | right = probs @ thresholds[1:]
121 |
122 | avg = (left + right) / 2.
123 | return avg
124 |
125 | def percentile(self, probs, percentile, subslice):
126 | '''
127 | percentile `p` :
128 | returns least value `v` s.t. cdf up to `v` is >= `p`
129 | e.g., p=0.8 and v=100 indicates that
130 |                 100 is at the 80th percentile of the values
131 | '''
132 | ## [ N ]
133 | thresholds = self.thresholds[:, subslice]
134 | ## [ B x N ]
135 | cumulative = np.cumsum(probs, axis=-1)
136 | valid = cumulative > percentile
137 | ## [ B ]
138 | inds = np.argmax(np.arange(self.N, 0, -1) * valid, axis=-1)
139 | left = thresholds[inds-1]
140 | right = thresholds[inds]
141 | avg = (left + right) / 2.
142 | return avg
143 |
144 | #---------------------------- wrappers for planning ----------------------------#
145 |
146 | def value_expectation(self, probs):
147 | '''
148 | probs : [ B x 2 x ( N + 1 ) ]
149 | extra token comes from termination
150 | '''
151 |
152 | if torch.is_tensor(probs):
153 | probs = to_np(probs)
154 | return_torch = True
155 | else:
156 | return_torch = False
157 |
158 | probs = probs[:, :, :-1]
159 | assert probs.shape[-1] == self.N
160 |
161 | rewards = self.expectation(probs[:, 0], subslice=-2)
162 | next_values = self.expectation(probs[:, 1], subslice=-1)
163 |
164 | if return_torch:
165 | rewards = to_torch(rewards)
166 | next_values = to_torch(next_values)
167 |
168 | return rewards, next_values
169 |
170 | def value_fn(self, probs, percentile):
171 | if percentile == 'mean':
172 | return self.value_expectation(probs)
173 | else:
174 | ## percentile should be interpretable as float,
175 | ## even if passed in as str because of command-line parser
176 | percentile = float(percentile)
177 |
178 | if torch.is_tensor(probs):
179 | probs = to_np(probs)
180 | return_torch = True
181 | else:
182 | return_torch = False
183 |
184 | probs = probs[:, :, :-1]
185 | assert probs.shape[-1] == self.N
186 |
187 | rewards = self.percentile(probs[:, 0], percentile, subslice=-2)
188 | next_values = self.percentile(probs[:, 1], percentile, subslice=-1)
189 |
190 | if return_torch:
191 | rewards = to_torch(rewards)
192 | next_values = to_torch(next_values)
193 |
194 | return rewards, next_values
195 |
196 | def largest_nonzero_index(x, dim):
197 | N = x.shape[dim]
198 | arange = np.arange(N) + 1
199 |
200 | for i in range(dim):
201 | arange = np.expand_dims(arange, axis=0)
202 | for i in range(dim+1, x.ndim):
203 | arange = np.expand_dims(arange, axis=-1)
204 |
205 | inds = np.argmax(x * arange, axis=0)
206 | ## masks for all `False` or all `True`
207 | lt_mask = (~x).all(axis=0)
208 | gt_mask = (x).all(axis=0)
209 |
210 | inds[lt_mask] = 0
211 | inds[gt_mask] = N
212 |
213 | return inds
214 |
--------------------------------------------------------------------------------
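A round-trip sketch (toy data, not from the repository) of `QuantileDiscretizer`: continuous values are mapped to per-dimension quantile bins and reconstructed as bin midpoints, so the reconstruction error is bounded by the local bin width.

```
import numpy as np
from trajectory.utils.discretization import QuantileDiscretizer

data = np.random.randn(10000, 3)            ## toy stand-in for the joined dataset
discretizer = QuantileDiscretizer(data, 100)

x = data[:5]
indices = discretizer.discretize(x)         ## integer bin indices in [0, 100)
recon = discretizer.reconstruct(indices)    ## midpoints of the selected bins
print(np.abs(x - recon).max())              ## small; at most the largest bin width
```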
/trajectory/utils/git_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import git
3 | import pdb
4 |
5 | PROJECT_PATH = os.path.dirname(
6 | os.path.realpath(os.path.join(__file__, '..', '..')))
7 |
8 | def get_repo(path=PROJECT_PATH, search_parent_directories=True):
9 | repo = git.Repo(
10 | path, search_parent_directories=search_parent_directories)
11 | return repo
12 |
13 | def get_git_rev(*args, **kwargs):
14 | try:
15 | repo = get_repo(*args, **kwargs)
16 | if repo.head.is_detached:
17 | git_rev = repo.head.object.name_rev
18 | else:
19 | git_rev = repo.active_branch.commit.name_rev
20 |     except Exception:
21 | git_rev = None
22 |
23 | return git_rev
24 |
25 | def git_diff(*args, **kwargs):
26 | repo = get_repo(*args, **kwargs)
27 | diff = repo.git.diff()
28 | return diff
29 |
30 | def save_git_diff(savepath, *args, **kwargs):
31 | diff = git_diff(*args, **kwargs)
32 | with open(savepath, 'w') as f:
33 | f.write(diff)
34 |
35 | if __name__ == '__main__':
36 |
37 | git_rev = get_git_rev()
38 | print(git_rev)
39 |
40 | save_git_diff('diff_test.txt')
--------------------------------------------------------------------------------
/trajectory/utils/progress.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | import re
4 | import pdb
5 |
6 | class Progress:
7 |
8 | def __init__(self, total, name = 'Progress', ncol=3, max_length=30, indent=8, line_width=100, speed_update_freq=100):
9 | self.total = total
10 | self.name = name
11 | self.ncol = ncol
12 | self.max_length = max_length
13 | self.indent = indent
14 | self.line_width = line_width
15 | self._speed_update_freq = speed_update_freq
16 |
17 | self._step = 0
18 | self._prev_line = '\033[F'
19 | self._clear_line = ' ' * self.line_width
20 |
21 | self._pbar_size = self.ncol * self.max_length
22 | self._complete_pbar = '#' * self._pbar_size
23 | self._incomplete_pbar = ' ' * self._pbar_size
24 |
25 | self.lines = ['']
26 | self.fraction = '{} / {}'.format(0, self.total)
27 |
28 | self.resume()
29 |
30 |
31 | def update(self, description, n=1):
32 | self._step += n
33 | if self._step % self._speed_update_freq == 0:
34 | self._time0 = time.time()
35 | self._step0 = self._step
36 | self.set_description(description)
37 |
38 | def resume(self):
39 | self._skip_lines = 1
40 | print('\n', end='')
41 | self._time0 = time.time()
42 | self._step0 = self._step
43 |
44 | def pause(self):
45 | self._clear()
46 | self._skip_lines = 1
47 |
48 | def set_description(self, params=[]):
49 |
50 | if type(params) == dict:
51 | params = sorted([
52 | (key, val)
53 | for key, val in params.items()
54 | ])
55 |
56 | ############
57 | # Position #
58 | ############
59 | self._clear()
60 |
61 | ###########
62 | # Percent #
63 | ###########
64 | percent, fraction = self._format_percent(self._step, self.total)
65 | self.fraction = fraction
66 |
67 | #########
68 | # Speed #
69 | #########
70 | speed = self._format_speed(self._step)
71 |
72 | ##########
73 | # Params #
74 | ##########
75 | num_params = len(params)
76 | nrow = math.ceil(num_params / self.ncol)
77 | params_split = self._chunk(params, self.ncol)
78 | params_string, lines = self._format(params_split)
79 | self.lines = lines
80 |
81 |
82 | description = '{} | {}{}'.format(percent, speed, params_string)
83 | print(description)
84 | self._skip_lines = nrow + 1
85 |
86 | def append_description(self, descr):
87 | self.lines.append(descr)
88 |
89 | def _clear(self):
90 | position = self._prev_line * self._skip_lines
91 | empty = '\n'.join([self._clear_line for _ in range(self._skip_lines)])
92 | print(position, end='')
93 | print(empty)
94 | print(position, end='')
95 |
96 | def _format_percent(self, n, total):
97 | if total:
98 | percent = n / float(total)
99 |
100 | complete_entries = int(percent * self._pbar_size)
101 | incomplete_entries = self._pbar_size - complete_entries
102 |
103 | pbar = self._complete_pbar[:complete_entries] + self._incomplete_pbar[:incomplete_entries]
104 | fraction = '{} / {}'.format(n, total)
105 | string = '{} [{}] {:3d}%'.format(fraction, pbar, int(percent*100))
106 | else:
107 | fraction = '{}'.format(n)
108 | string = '{} iterations'.format(n)
109 | return string, fraction
110 |
111 | def _format_speed(self, n):
112 | num_steps = n - self._step0
113 | t = time.time() - self._time0
114 | speed = num_steps / t
115 | string = '{:.1f} Hz'.format(speed)
116 | if num_steps > 0:
117 | self._speed = string
118 | return string
119 |
120 | def _chunk(self, l, n):
121 | return [l[i:i+n] for i in range(0, len(l), n)]
122 |
123 | def _format(self, chunks):
124 | lines = [self._format_chunk(chunk) for chunk in chunks]
125 | lines.insert(0,'')
126 | padding = '\n' + ' '*self.indent
127 | string = padding.join(lines)
128 | return string, lines
129 |
130 | def _format_chunk(self, chunk):
131 | line = ' | '.join([self._format_param(param) for param in chunk])
132 | return line
133 |
134 | def _format_param(self, param, str_length=8):
135 | k, v = param
136 | k = k.rjust(str_length)
137 | if type(v) == float or hasattr(v, 'item'):
138 | string = '{}: {:12.4f}'
139 | else:
140 | string = '{}: {}'
141 | v = str(v).rjust(12)
142 | return string.format(k, v)[:self.max_length]
143 |
144 | def stamp(self):
145 | if self.lines != ['']:
146 | params = ' | '.join(self.lines)
147 | string = '[ {} ] {}{} | {}'.format(self.name, self.fraction, params, self._speed)
148 | string = re.sub(r'\s+', ' ', string)
149 | self._clear()
150 | print(string, end='\n')
151 | self._skip_lines = 1
152 | else:
153 | self._clear()
154 | self._skip_lines = 0
155 |
156 | def close(self):
157 | self.pause()
158 |
159 | class Silent:
160 |
161 | def __init__(self, *args, **kwargs):
162 | pass
163 |
164 | def __getattr__(self, attr):
165 | return lambda *args: None
166 |
167 |
168 | if __name__ == '__main__':
169 | silent = Silent()
170 | silent.update()
171 | silent.stamp()
172 |
173 | num_steps = 1000
174 | progress = Progress(num_steps)
175 | for i in range(num_steps):
176 | progress.update()
177 | params = [
178 | ['A', '{:06d}'.format(i)],
179 | ['B', '{:06d}'.format(i)],
180 | ['C', '{:06d}'.format(i)],
181 | ['D', '{:06d}'.format(i)],
182 | ['E', '{:06d}'.format(i)],
183 | ['F', '{:06d}'.format(i)],
184 | ['G', '{:06d}'.format(i)],
185 | ['H', '{:06d}'.format(i)],
186 | ]
187 | progress.set_description(params)
188 | time.sleep(0.01)
189 | progress.close()
190 |
--------------------------------------------------------------------------------
/trajectory/utils/rendering.py:
--------------------------------------------------------------------------------
1 | import time
2 | import sys
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import torch
6 | import gym
7 | import mujoco_py as mjc
8 | import pdb
9 |
10 | from .arrays import to_np
11 | from .video import save_video, save_videos
12 | from ..datasets import load_environment, get_preprocess_fn
13 |
14 | def make_renderer(args):
15 | render_str = getattr(args, 'renderer')
16 | render_class = getattr(sys.modules[__name__], render_str)
17 | ## get dimensions in case the observations are preprocessed
18 | env = load_environment(args.dataset)
19 | preprocess_fn = get_preprocess_fn(args.dataset)
20 | observation = env.reset()
21 | observation = preprocess_fn(observation)
22 | return render_class(args.dataset, observation_dim=observation.size)
23 |
24 | def split(sequence, observation_dim, action_dim):
25 | assert sequence.shape[1] == observation_dim + action_dim + 2
26 | observations = sequence[:, :observation_dim]
27 | actions = sequence[:, observation_dim:observation_dim+action_dim]
28 | rewards = sequence[:, -2]
29 | values = sequence[:, -1]
30 | return observations, actions, rewards, values
31 |
32 | def set_state(env, state):
33 | qpos_dim = env.sim.data.qpos.size
34 | qvel_dim = env.sim.data.qvel.size
35 | qstate_dim = qpos_dim + qvel_dim
36 |
37 | if 'ant' in env.name:
38 | ypos = np.zeros(1)
39 | state = np.concatenate([ypos, state])
40 |
41 | if state.size == qpos_dim - 1 or state.size == qstate_dim - 1:
42 | xpos = np.zeros(1)
43 | state = np.concatenate([xpos, state])
44 |
45 | if state.size == qpos_dim:
46 | qvel = np.zeros(qvel_dim)
47 | state = np.concatenate([state, qvel])
48 |
49 | if 'ant' in env.name and state.size > qpos_dim + qvel_dim:
50 | xpos = np.zeros(1)
51 | state = np.concatenate([xpos, state])[:qstate_dim]
52 |
53 | assert state.size == qpos_dim + qvel_dim
54 |
55 | env.set_state(state[:qpos_dim], state[qpos_dim:])
56 |
57 | def rollout_from_state(env, state, actions):
58 | qpos_dim = env.sim.data.qpos.size
59 | env.set_state(state[:qpos_dim], state[qpos_dim:])
60 | observations = [env._get_obs()]
61 | for act in actions:
62 | obs, rew, term, _ = env.step(act)
63 | observations.append(obs)
64 | if term:
65 | break
66 | for i in range(len(observations), len(actions)+1):
67 | ## if terminated early, pad with zeros
68 | observations.append( np.zeros(obs.size) )
69 | return np.stack(observations)
70 |
71 | class DebugRenderer:
72 |
73 | def __init__(self, *args, **kwargs):
74 | pass
75 |
76 | def render(self, *args, **kwargs):
77 | return np.zeros((10, 10, 3))
78 |
79 | def render_plan(self, *args, **kwargs):
80 | pass
81 |
82 | def render_rollout(self, *args, **kwargs):
83 | pass
84 |
85 | class Renderer:
86 |
87 | def __init__(self, env, observation_dim=None, action_dim=None):
88 | if type(env) is str:
89 | self.env = load_environment(env)
90 | else:
91 | self.env = env
92 |
93 | self.observation_dim = observation_dim or np.prod(self.env.observation_space.shape)
94 | self.action_dim = action_dim or np.prod(self.env.action_space.shape)
95 | self.viewer = mjc.MjRenderContextOffscreen(self.env.sim)
96 |
97 | def __call__(self, *args, **kwargs):
98 | return self.renders(*args, **kwargs)
99 |
100 | def render(self, observation, dim=256, render_kwargs=None):
101 | observation = to_np(observation)
102 |
103 | if render_kwargs is None:
104 | render_kwargs = {
105 | 'trackbodyid': 2,
106 | 'distance': 3,
107 | 'lookat': [0, -0.5, 1],
108 | 'elevation': -20
109 | }
110 |
111 | for key, val in render_kwargs.items():
112 | if key == 'lookat':
113 | self.viewer.cam.lookat[:] = val[:]
114 | else:
115 | setattr(self.viewer.cam, key, val)
116 |
117 | set_state(self.env, observation)
118 |
119 | if type(dim) == int:
120 | dim = (dim, dim)
121 |
122 | self.viewer.render(*dim)
123 | data = self.viewer.read_pixels(*dim, depth=False)
124 | data = data[::-1, :, :]
125 | return data
126 |
127 | def renders(self, observations, **kwargs):
128 | images = []
129 | for observation in observations:
130 | img = self.render(observation, **kwargs)
131 | images.append(img)
132 | return np.stack(images, axis=0)
133 |
134 | def render_plan(self, savepath, sequence, state, fps=30):
135 | '''
136 | state : np.array[ observation_dim ]
137 | sequence : np.array[ horizon x transition_dim ]
138 | as usual, sequence is ordered as [ s_t, a_t, r_t, V_t, ... ]
139 | '''
140 |
141 | if len(sequence) == 1:
142 | return
143 |
144 | sequence = to_np(sequence)
145 |
146 | ## compare to ground truth rollout using actions from sequence
147 | actions = sequence[:-1, self.observation_dim : self.observation_dim + self.action_dim]
148 | rollout_states = rollout_from_state(self.env, state, actions)
149 |
150 | videos = [
151 | self.renders(sequence[:, :self.observation_dim]),
152 | self.renders(rollout_states),
153 | ]
154 |
155 | save_videos(savepath, *videos, fps=fps)
156 |
157 | def render_rollout(self, savepath, states, **video_kwargs):
158 | images = self(states)
159 | save_video(savepath, images, **video_kwargs)
160 |
161 | class KitchenRenderer:
162 |
163 | def __init__(self, env):
164 | if type(env) is str:
165 | self.env = gym.make(env)
166 | else:
167 | self.env = env
168 |
169 | self.observation_dim = np.prod(self.env.observation_space.shape)
170 | self.action_dim = np.prod(self.env.action_space.shape)
171 |
172 | def set_obs(self, obs, goal_dim=30):
173 | robot_dim = self.env.n_jnt
174 | obj_dim = self.env.n_obj
175 | assert robot_dim + obj_dim + goal_dim == obs.size or robot_dim + obj_dim == obs.size
176 | self.env.sim.data.qpos[:robot_dim] = obs[:robot_dim]
177 | self.env.sim.data.qpos[robot_dim:robot_dim+obj_dim] = obs[robot_dim:robot_dim+obj_dim]
178 | self.env.sim.forward()
179 |
180 | def rollout(self, obs, actions):
181 | self.set_obs(obs)
182 |         observations = [self.env._get_obs()]
183 |         for act in actions:
184 |             obs, rew, term, _ = self.env.step(act)
185 | observations.append(obs)
186 | if term:
187 | break
188 | for i in range(len(observations), len(actions)+1):
189 | ## if terminated early, pad with zeros
190 | observations.append( np.zeros(observations[-1].size) )
191 | return np.stack(observations)
192 |
193 | def render(self, observation, dim=512, onscreen=False):
194 | self.env.sim_robot.renderer._camera_settings.update({
195 | 'distance': 4.5,
196 | 'azimuth': 90,
197 | 'elevation': -25,
198 | 'lookat': [0, 1, 2],
199 | })
200 | self.set_obs(observation)
201 | if onscreen:
202 | self.env.render()
203 | return self.env.sim_robot.renderer.render_offscreen(dim, dim)
204 |
205 | def renders(self, observations, **kwargs):
206 | images = []
207 | for observation in observations:
208 | img = self.render(observation, **kwargs)
209 | images.append(img)
210 | return np.stack(images, axis=0)
211 |
212 | def render_plan(self, *args, **kwargs):
213 | return self.render_rollout(*args, **kwargs)
214 |
215 | def render_rollout(self, savepath, states, **video_kwargs):
216 | images = self(states) #np.stack(states, axis=0))
217 | save_video(savepath, images, **video_kwargs)
218 |
219 | def __call__(self, *args, **kwargs):
220 | return self.renders(*args, **kwargs)
221 |
222 | ANTMAZE_BOUNDS = {
223 | 'antmaze-umaze-v0': (-3, 11),
224 | 'antmaze-medium-play-v0': (-3, 23),
225 | 'antmaze-medium-diverse-v0': (-3, 23),
226 | 'antmaze-large-play-v0': (-3, 39),
227 | 'antmaze-large-diverse-v0': (-3, 39),
228 | }
229 |
230 | class AntMazeRenderer:
231 |
232 | def __init__(self, env_name):
233 | self.env_name = env_name
234 | self.env = gym.make(env_name).unwrapped
235 | self.observation_dim = np.prod(self.env.observation_space.shape)
236 | self.action_dim = np.prod(self.env.action_space.shape)
237 |
238 | def renders(self, savepath, X):
239 | plt.clf()
240 |
241 | if X.ndim < 3:
242 | X = X[None]
243 |
244 | N, path_length, _ = X.shape
245 | if N > 4:
246 |             fig, axes = plt.subplots(4, int(np.ceil(N / 4)))  ## round up so every path gets an axis
247 | axes = axes.flatten()
248 | fig.set_size_inches(N/4,8)
249 | elif N > 1:
250 | fig, axes = plt.subplots(1, N)
251 | fig.set_size_inches(8,8)
252 | else:
253 | fig, axes = plt.subplots(1, 1)
254 | fig.set_size_inches(8,8)
255 |
256 | colors = plt.cm.jet(np.linspace(0,1,path_length))
257 | for i in range(N):
258 | ax = axes if N == 1 else axes[i]
259 | xlim, ylim = self.plot_boundaries(ax=ax)
260 | x = X[i]
261 | ax.scatter(x[:,0], x[:,1], c=colors)
262 | ax.set_xticks([])
263 | ax.set_yticks([])
264 | ax.set_xlim(*xlim)
265 | ax.set_ylim(*ylim)
266 | plt.savefig(savepath + '.png')
267 | plt.close()
268 |         print(f'[ utils/rendering ] Saved to: {savepath}.png')
269 |
270 | def plot_boundaries(self, N=100, ax=None):
271 | """
272 | plots the maze boundaries in the antmaze environments
273 | """
274 | ax = ax or plt.gca()
275 |
276 | xlim = ANTMAZE_BOUNDS[self.env_name]
277 | ylim = ANTMAZE_BOUNDS[self.env_name]
278 |
279 | X = np.linspace(*xlim, N)
280 | Y = np.linspace(*ylim, N)
281 |
282 | Z = np.zeros((N, N))
283 | for i, x in enumerate(X):
284 | for j, y in enumerate(Y):
285 | collision = self.env.unwrapped._is_in_collision((x, y))
286 | Z[-j, i] = collision
287 |
288 | ax.imshow(Z, extent=(*xlim, *ylim), aspect='auto', cmap=plt.cm.binary)
289 | return xlim, ylim
290 |
291 | def render_plan(self, savepath, discretizer, state, sequence):
292 | '''
293 | state : np.array[ observation_dim ]
294 | sequence : np.array[ horizon x transition_dim ]
295 | as usual, sequence is ordered as [ s_t, a_t, r_t, V_t, ... ]
296 | '''
297 |
298 | if len(sequence) == 1:
299 | # raise RuntimeError(f'horizon is 1 in Renderer:render_plan: {sequence.shape}')
300 | return
301 |
302 | sequence = to_np(sequence)
303 |
304 | sequence_recon = discretizer.reconstruct(sequence)
305 |
306 | observations, actions, *_ = split(sequence_recon, self.observation_dim, self.action_dim)
307 |
308 | rollout_states = rollout_from_state(self.env, state, actions[:-1])
309 |
310 | X = np.stack([observations, rollout_states], axis=0)
311 |
312 | self.renders(savepath, X)
313 |
314 | def render_rollout(self, savepath, states, **video_kwargs):
315 | if type(states) is list:
316 | states = np.stack(states, axis=0)[None]
317 | images = self.renders(savepath, states)
318 |
319 | class Maze2dRenderer(AntMazeRenderer):
320 |
321 | def _is_in_collision(self, x, y):
322 | '''
323 | 10 : wall
324 | 11 : free
325 | 12 : goal
326 | '''
327 | maze = self.env.maze_arr
328 | ind = maze[int(x), int(y)]
329 | return ind == 10
330 |
331 | def plot_boundaries(self, N=100, ax=None, eps=1e-6):
332 | """
333 |         plots the maze boundaries in the maze2d environments
334 | """
335 | ax = ax or plt.gca()
336 |
337 | maze = self.env.maze_arr
338 | xlim = (0, maze.shape[1]-eps)
339 | ylim = (0, maze.shape[0]-eps)
340 |
341 | X = np.linspace(*xlim, N)
342 | Y = np.linspace(*ylim, N)
343 |
344 | Z = np.zeros((N, N))
345 | for i, x in enumerate(X):
346 | for j, y in enumerate(Y):
347 | collision = self._is_in_collision(x, y)
348 | Z[-j, i] = collision
349 |
350 | ax.imshow(Z, extent=(*xlim, *ylim), aspect='auto', cmap=plt.cm.binary)
351 | return xlim, ylim
352 |
353 | def renders(self, savepath, X):
354 | return super().renders(savepath, X + 0.5)
355 |
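
A minimal slicing sketch (illustrative, not part of rendering.py) of the transition layout that the render_plan docstrings describe: each row of `sequence` concatenates [ s_t, a_t, r_t, V_t ], so the components can be recovered by index arithmetic. The function name and dimensions below are assumptions for illustration only.

import numpy as np

def split_transitions(sequence, observation_dim, action_dim):
    ## sequence : [ horizon x transition_dim ], transition_dim = observation_dim + action_dim + 2
    observations = sequence[:, :observation_dim]
    actions = sequence[:, observation_dim : observation_dim + action_dim]
    rewards = sequence[:, observation_dim + action_dim]
    values = sequence[:, observation_dim + action_dim + 1]
    return observations, actions, rewards, values

## e.g. with illustrative dimensions:
obs, act, rew, val = split_transitions(np.zeros((15, 17 + 6 + 2)), observation_dim=17, action_dim=6)
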
356 | #--------------------------------- planning callbacks ---------------------------------#
357 |
358 |
--------------------------------------------------------------------------------
/trajectory/utils/serialization.py:
--------------------------------------------------------------------------------
1 | import time
2 | import sys
3 | import os
4 | import glob
5 | import pickle
6 | import json
7 | import torch
8 | import pdb
9 |
10 | def mkdir(savepath, prune_fname=False):
11 | """
12 | returns `True` iff `savepath` is created
13 | """
14 | if prune_fname:
15 | savepath = os.path.dirname(savepath)
16 | if not os.path.exists(savepath):
17 | try:
18 | os.makedirs(savepath)
19 | except:
20 | print(f'[ utils/serialization ] Warning: did not make directory: {savepath}')
21 | return False
22 | return True
23 | else:
24 | return False
25 |
26 | def get_latest_epoch(loadpath):
27 | states = glob.glob1(loadpath, 'state_*')
28 | latest_epoch = -1
29 | for state in states:
30 | epoch = int(state.replace('state_', '').replace('.pt', ''))
31 | latest_epoch = max(epoch, latest_epoch)
32 | return latest_epoch
33 |
34 | def load_model(*loadpath, epoch=None, device='cuda:0'):
35 | loadpath = os.path.join(*loadpath)
36 | config_path = os.path.join(loadpath, 'model_config.pkl')
37 |
38 |     if epoch == 'latest':
39 | epoch = get_latest_epoch(loadpath)
40 |
41 | print(f'[ utils/serialization ] Loading model epoch: {epoch}')
42 | state_path = os.path.join(loadpath, f'state_{epoch}.pt')
43 |
44 | config = pickle.load(open(config_path, 'rb'))
45 | state = torch.load(state_path)
46 |
47 | model = config()
48 | model.to(device)
49 | model.load_state_dict(state, strict=True)
50 |
51 | print(f'\n[ utils/serialization ] Loaded config from {config_path}\n')
52 | print(config)
53 |
54 | return model, epoch
55 |
56 | def load_config(*loadpath):
57 | loadpath = os.path.join(*loadpath)
58 | config = pickle.load(open(loadpath, 'rb'))
59 | print(f'[ utils/serialization ] Loaded config from {loadpath}')
60 | print(config)
61 | return config
62 |
63 | def load_from_config(*loadpath):
64 | config = load_config(*loadpath)
65 | return config.make()
66 |
67 | def load_args(*loadpath):
68 | from .setup import Parser
69 | loadpath = os.path.join(*loadpath)
70 | args_path = os.path.join(loadpath, 'args.json')
71 | args = Parser()
72 | args.load(args_path)
73 | return args
74 |
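
A hedged usage sketch for the helpers above, assuming the directory layout this file expects (model_config.pkl, state_<epoch>.pt, and args.json inside one experiment directory); the path is illustrative, not a path shipped with the repository.

model, epoch = load_model('logs/halfcheetah-medium-v2/gpt', epoch='latest', device='cuda:0')
args = load_args('logs/halfcheetah-medium-v2/gpt')
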
--------------------------------------------------------------------------------
/trajectory/utils/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import importlib
3 | import random
4 | import numpy as np
5 | import torch
6 | from tap import Tap
7 | import pdb
8 |
9 | from .serialization import mkdir
10 | from .arrays import set_device
11 | from .git_utils import (
12 | get_git_rev,
13 | save_git_diff,
14 | )
15 |
16 | def set_seed(seed):
17 | random.seed(seed)
18 | np.random.seed(seed)
19 | torch.manual_seed(seed)
20 | torch.cuda.manual_seed_all(seed)
21 |
22 | def watch(args_to_watch):
23 | def _fn(args):
24 | exp_name = []
25 | for key, label in args_to_watch:
26 | if not hasattr(args, key):
27 | continue
28 | val = getattr(args, key)
29 | exp_name.append(f'{label}{val}')
30 | exp_name = '_'.join(exp_name)
31 | exp_name = exp_name.replace('/_', '/')
32 | return exp_name
33 | return _fn
34 |
35 | class Parser(Tap):
36 |
37 | def save(self):
38 | fullpath = os.path.join(self.savepath, 'args.json')
39 | print(f'[ utils/setup ] Saved args to {fullpath}')
40 | super().save(fullpath, skip_unpicklable=True)
41 |
42 | def parse_args(self, experiment=None):
43 | args = super().parse_args(known_only=True)
44 |         ## if not loading from a config script, skip the rest of the setup
45 | if not hasattr(args, 'config'): return args
46 | args = self.read_config(args, experiment)
47 | self.add_extras(args)
48 | self.set_seed(args)
49 | self.get_commit(args)
50 | self.generate_exp_name(args)
51 | self.mkdir(args)
52 | self.save_diff(args)
53 | return args
54 |
55 | def read_config(self, args, experiment):
56 | '''
57 | Load parameters from config file
58 | '''
59 | dataset = args.dataset.replace('-', '_')
60 | print(f'[ utils/setup ] Reading config: {args.config}:{dataset}')
61 | module = importlib.import_module(args.config)
62 | params = getattr(module, 'base')[experiment]
63 |
64 | if hasattr(module, dataset) and experiment in getattr(module, dataset):
65 | print(f'[ utils/setup ] Using overrides | config: {args.config} | dataset: {dataset}')
66 | overrides = getattr(module, dataset)[experiment]
67 | params.update(overrides)
68 | else:
69 | print(f'[ utils/setup ] Not using overrides | config: {args.config} | dataset: {dataset}')
70 |
71 | for key, val in params.items():
72 | setattr(args, key, val)
73 |
74 | return args
75 |
76 | def add_extras(self, args):
77 | '''
78 | Override config parameters with command-line arguments
79 | '''
80 | extras = args.extra_args
81 | if not len(extras):
82 | return
83 |
84 | print(f'[ utils/setup ] Found extras: {extras}')
85 | assert len(extras) % 2 == 0, f'Found odd number ({len(extras)}) of extras: {extras}'
86 | for i in range(0, len(extras), 2):
87 | key = extras[i].replace('--', '')
88 | val = extras[i+1]
89 | assert hasattr(args, key), f'[ utils/setup ] {key} not found in config: {args.config}'
90 | old_val = getattr(args, key)
91 | old_type = type(old_val)
92 | print(f'[ utils/setup ] Overriding config | {key} : {old_val} --> {val}')
93 | if val == 'None':
94 | val = None
95 | elif val == 'latest':
96 |                 val = 'latest'  ## keep as a string; do not cast to the old type
97 | elif old_type in [bool, type(None)]:
98 | val = eval(val)
99 | else:
100 | val = old_type(val)
101 | setattr(args, key, val)
102 |
103 | def set_seed(self, args):
104 |         if 'seed' not in dir(args):
105 | return
106 | set_seed(args.seed)
107 |
108 | def generate_exp_name(self, args):
109 |         if 'exp_name' not in dir(args):
110 | return
111 | exp_name = getattr(args, 'exp_name')
112 | if callable(exp_name):
113 | exp_name_string = exp_name(args)
114 | print(f'[ utils/setup ] Setting exp_name to: {exp_name_string}')
115 | setattr(args, 'exp_name', exp_name_string)
116 |
117 | def mkdir(self, args):
118 | if 'logbase' in dir(args) and 'dataset' in dir(args) and 'exp_name' in dir(args):
119 | args.savepath = os.path.join(args.logbase, args.dataset, args.exp_name)
120 | if 'suffix' in dir(args):
121 | args.savepath = os.path.join(args.savepath, args.suffix)
122 | if mkdir(args.savepath):
123 | print(f'[ utils/setup ] Made savepath: {args.savepath}')
124 | self.save()
125 |
126 | def get_commit(self, args):
127 | args.commit = get_git_rev()
128 |
129 | def save_diff(self, args):
130 | try:
131 | save_git_diff(os.path.join(args.savepath, 'diff.txt'))
132 |         except Exception:
133 | print('[ utils/setup ] WARNING: did not save git diff')
134 |
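
A minimal usage sketch for Parser, assuming `Parser` is importable from `trajectory.utils` and a config module laid out as read_config expects (a `base` dict keyed by experiment name, with optional per-dataset override dicts); the subclass name and default values are illustrative.

import trajectory.utils as utils

class TrainParser(utils.Parser):
    dataset: str = 'halfcheetah-medium-v2'
    config: str = 'config.offline'

## reads the config, applies --key val overrides, seeds, builds the savepath, and saves args.json
args = TrainParser().parse_args('train')
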
--------------------------------------------------------------------------------
/trajectory/utils/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | class Timer:
4 |
5 | def __init__(self):
6 | self._start = time.time()
7 |
8 | def __call__(self, reset=True):
9 | now = time.time()
10 | diff = now - self._start
11 | if reset:
12 | self._start = now
13 | return diff
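
Usage sketch (import path taken from this file's location): the timer returns elapsed seconds and, by default, resets itself on every call.

from trajectory.utils.timer import Timer

timer = Timer()
## ... do some work ...
print(f'{timer():.2f}s')              ## seconds since construction; resets the clock
print(f'{timer(reset=False):.2f}s')   ## seconds since the previous call; leaves the clock untouched
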
--------------------------------------------------------------------------------
/trajectory/utils/training.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.utils.data.dataloader import DataLoader
4 | import pdb
5 |
6 | from .timer import Timer
7 |
8 | def to(xs, device):
9 | return [x.to(device) for x in xs]
10 |
11 | class Trainer:
12 |
13 | def __init__(self, config):
14 | self.config = config
15 | self.device = config.device
16 |
17 | self.n_epochs = 0
18 | self.n_tokens = 0 # counter used for learning rate decay
19 | self.optimizer = None
20 |
21 | def get_optimizer(self, model):
22 | if self.optimizer is None:
23 | print(f'[ utils/training ] Making optimizer at epoch {self.n_epochs}')
24 | self.optimizer = model.configure_optimizers(self.config)
25 | return self.optimizer
26 |
27 | def train(self, model, dataset, n_epochs=1, log_freq=100):
28 |
29 | config = self.config
30 | optimizer = self.get_optimizer(model)
31 | model.train(True)
32 | vocab_size = dataset.N
33 |
34 | loader = DataLoader(dataset, shuffle=True, pin_memory=True,
35 | batch_size=config.batch_size,
36 | num_workers=config.num_workers)
37 |
38 | for _ in range(n_epochs):
39 |
40 | losses = []
41 | timer = Timer()
42 | for it, batch in enumerate(loader):
43 |
44 | batch = to(batch, self.device)
45 |
46 | # forward the model
47 | with torch.set_grad_enabled(True):
48 | logits, loss = model(*batch)
49 | losses.append(loss.item())
50 |
51 | # backprop and update the parameters
52 | model.zero_grad()
53 | loss.backward()
54 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
55 | optimizer.step()
56 |
57 | # decay the learning rate based on our progress
58 | if config.lr_decay:
59 | y = batch[-2]
60 | self.n_tokens += (y != vocab_size).sum() # number of tokens processed this step
61 | if self.n_tokens < config.warmup_tokens:
62 | # linear warmup
63 | lr_mult = float(self.n_tokens) / float(max(1, config.warmup_tokens))
64 | else:
65 | # cosine learning rate decay
66 | progress = float(self.n_tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
67 | lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
68 | lr = config.learning_rate * lr_mult
69 | for param_group in optimizer.param_groups:
70 | param_group['lr'] = lr
71 | else:
72 |                     lr, lr_mult = config.learning_rate, 1.0  ## keep lr_mult defined for the log line below
73 |
74 | # report progress
75 | if it % log_freq == 0:
76 | print(
77 | f'[ utils/training ] epoch {self.n_epochs} [ {it:4d} / {len(loader):4d} ] ',
78 | f'train loss {loss.item():.5f} | lr {lr:.3e} | lr_mult: {lr_mult:.4f} | '
79 | f't: {timer():.2f}')
80 |
81 | self.n_epochs += 1
82 |
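
The learning-rate logic in Trainer.train restated as a standalone sketch (function name illustrative): linear warmup on the running token count, then cosine decay toward a 0.1 floor.

import math

def lr_multiplier(n_tokens, warmup_tokens, final_tokens):
    if n_tokens < warmup_tokens:
        ## linear warmup
        return float(n_tokens) / float(max(1, warmup_tokens))
    ## cosine decay, floored at 0.1
    progress = float(n_tokens - warmup_tokens) / float(max(1, final_tokens - warmup_tokens))
    return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

## e.g. lr = config.learning_rate * lr_multiplier(n_tokens, config.warmup_tokens, config.final_tokens)
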
--------------------------------------------------------------------------------
/trajectory/utils/video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import skvideo.io
4 |
5 | def _make_dir(filename):
6 | folder = os.path.dirname(filename)
7 | if not os.path.exists(folder):
8 | os.makedirs(folder)
9 |
10 | def save_video(filename, video_frames, fps=60, video_format='mp4'):
11 | assert fps == int(fps), fps
12 | _make_dir(filename)
13 |
14 | skvideo.io.vwrite(
15 | filename,
16 | video_frames,
17 | inputdict={
18 | '-r': str(int(fps)),
19 | },
20 | outputdict={
21 | '-f': video_format,
22 | '-pix_fmt': 'yuv420p', # '-pix_fmt=yuv420p' needed for osx https://github.com/scikit-video/scikit-video/issues/74
23 | }
24 | )
25 |
26 | def save_videos(filename, *video_frames, **kwargs):
27 |     ## each element of video_frames : [ T x H x W x C ]; clips are concatenated along the width axis
28 | video_frames = np.concatenate(video_frames, axis=2)
29 | save_video(filename, video_frames, **kwargs)
30 |
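
Usage sketch (filenames and shapes illustrative): save_videos concatenates the given clips along the width axis, so two [ T x H x W x C ] arrays are written side by side in a single file.

import numpy as np

plan_frames = np.zeros((30, 256, 256, 3), dtype=np.uint8)      ## rendered plan
rollout_frames = np.zeros((30, 256, 256, 3), dtype=np.uint8)   ## ground-truth rollout
save_videos('logs/debug/plan_vs_rollout.mp4', plan_frames, rollout_frames, fps=30)
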
--------------------------------------------------------------------------------