├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── Makefile
├── README.md
├── docs
│   ├── ROADMAP.md
│   └── assets
│       ├── voltron-banner-alpha.png
│       ├── voltron-banner.png
│       └── voltron-framework.png
├── examples
│   ├── pretrain
│   │   ├── README.md
│   │   ├── preprocess.py
│   │   └── pretrain.py
│   ├── usage.py
│   ├── verification
│   │   ├── img
│   │   │   ├── peel-carrot-final.png
│   │   │   ├── peel-carrot-initial.png
│   │   │   ├── place-bottle-final.png
│   │   │   ├── place-bottle-grasp.png
│   │   │   └── place-bottle-initial.png
│   │   └── verify.py
│   └── xla-reference
│       ├── README.md
│       ├── xpreprocess.py
│       └── xpretrain.py
├── pyproject.toml
├── setup.py
└── voltron
    ├── __init__.py
    ├── conf
    │   ├── __init__.py
    │   ├── accelerators.py
    │   ├── datasets.py
    │   ├── models.py
    │   └── tracking.py
    ├── datasets
    │   ├── __init__.py
    │   ├── datasets.py
    │   └── v1
    │       ├── __init__.py
    │       └── stream_datasets.py
    ├── models
    │   ├── __init__.py
    │   ├── core
    │   │   ├── __init__.py
    │   │   ├── vcond.py
    │   │   ├── vdual.py
    │   │   └── vgen.py
    │   ├── instantiate.py
    │   ├── materialize.py
    │   ├── reproductions
    │   │   ├── __init__.py
    │   │   ├── vmvp.py
    │   │   ├── vr3m.py
    │   │   └── vrn3m.py
    │   └── util
    │       ├── __init__.py
    │       ├── extraction.py
    │       ├── optimization.py
    │       └── transformer.py
    ├── overwatch
    │   ├── __init__.py
    │   └── overwatch.py
    ├── preprocessing
    │   ├── __init__.py
    │   ├── core.py
    │   ├── process.py
    │   ├── transforms.py
    │   └── v1
    │       ├── __init__.py
    │       ├── process.py
    │       ├── transforms.py
    │       └── utils.py
    └── util
        ├── __init__.py
        ├── checkpointing.py
        ├── metrics.py
        ├── utilities.py
        └── v1
            ├── __init__.py
            ├── checkpointing.py
            ├── distributed.py
            ├── random.py
            └── xla_logger.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Ruff
132 | .ruff_cache/
133 |
134 | # IDE caches
135 | .idea/
136 | .vscode/
137 |
138 | # Mac OS
139 | .DS_Store
140 |
141 | # Cache
142 | data/
143 | cache/
144 |
145 | # Scratch
146 | scratch/
147 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | exclude: ".git"
4 |
5 | repos:
6 | - repo: https://github.com/charliermarsh/ruff-pre-commit
7 | rev: v0.0.252
8 | hooks:
9 | - id: ruff
10 | args: [ --fix, --exit-non-zero-on-fix ]
11 |
12 | - repo: https://github.com/psf/black
13 | rev: 23.1.0
14 | hooks:
15 | - id: black
16 |
17 | - repo: https://github.com/pre-commit/pre-commit-hooks
18 | rev: v4.4.0
19 | hooks:
20 | - id: check-added-large-files
21 | args: ["--maxkb=40000"]
22 | - id: check-ast
23 | - id: check-case-conflict
24 | - id: check-merge-conflict
25 | - id: check-toml
26 | - id: check-yaml
27 | - id: end-of-file-fixer
28 | - id: trailing-whitespace
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021-present, Siddharth Karamcheti and other contributors.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: help check autoformat
2 | .DEFAULT: help
3 |
4 | # Generates a useful overview/help message for various make features - add to this as necessary!
5 | help:
6 | @echo "make check"
7 | @echo " Run code style and linting (black, ruff) *without* changing files!"
8 | @echo "make autoformat"
9 | @echo " Run code styling (black, ruff) and update in place - committing with pre-commit also does this."
10 |
11 | check:
12 | black --check .
13 | ruff check --show-source .
14 |
15 | autoformat:
16 | black .
17 | ruff check --fix --show-fixes .
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | [](https://arxiv.org/abs/2302.12766)
8 | [](https://pytorch.org/get-started/locally/)
9 | [](https://github.com/psf/black)
10 | [](https://github.com/charliermarsh/ruff)
11 | 
12 |
13 |
14 |
15 | ---
16 |
17 | # Language-Driven Representation Learning for Robotics
18 |
19 | Package repository for Voltron: Language-Driven Representation Learning for Robotics. Provides code for loading
20 | pretrained Voltron, R3M, and MVP representations for adaptation to downstream tasks, as well as code for pretraining
21 | such representations on arbitrary datasets.
22 |
23 | ---
24 |
25 | ## Quickstart
26 |
27 | This repository is built with PyTorch; while specified as a dependency for the package, we highly recommend that
28 | you install the desired version (e.g., with accelerator support) for your given hardware and environment
29 | manager (e.g., `conda`).
30 |
31 | PyTorch installation instructions [can be found here](https://pytorch.org/get-started/locally/). This repository
32 | should work with PyTorch >= 1.12. Releases before 1.1.0 have been thoroughly tested with PyTorch 1.12.0,
33 | Torchvision 0.13.0, and Torchaudio 0.12.0. **Note**: Releases 1.1.0 and after *assume PyTorch 2.0*!
34 |
35 | Once PyTorch has been properly installed, you can install this package via PyPI, and you're off!
36 |
37 | ```bash
38 | pip install voltron-robotics
39 | ```
40 |
41 | You can also install this package locally via an editable installation in case you want to run examples/extend the
42 | current functionality:
43 |
44 | ```bash
45 | git clone https://github.com/siddk/voltron-robotics
46 | cd voltron-robotics
47 | pip install -e .
48 | ```
49 |
50 | ## Usage
51 |
52 | Voltron Robotics (package: `voltron`) is structured to provide easy access to pretrained Voltron models (and
53 | reproductions), to facilitate use for various downstream tasks. Using a pretrained Voltron model is easy:
54 |
55 | ```python
56 | from torchvision.io import read_image
57 | from voltron import instantiate_extractor, load
58 |
59 | # Load a frozen Voltron (V-Cond) model & configure a vector extractor
60 | vcond, preprocess = load("v-cond", device="cuda", freeze=True)
61 | vector_extractor = instantiate_extractor(vcond)()
62 |
63 | # Obtain & Preprocess an image =>> can be from a dataset, or camera on a robot, etc.
64 | # => Feel free to add any language if you have it (Voltron models work either way!)
65 | img = preprocess(read_image("examples/img/peel-carrot-initial.png"))[None, ...].to("cuda")
66 | lang = ["peeling a carrot"]
67 |
68 | # Extract both multimodal AND vision-only embeddings!
69 | multimodal_embeddings = vcond(img, lang, mode="multimodal")
70 | visual_embeddings = vcond(img, mode="visual")
71 |
72 | # Use the `vector_extractor` to output dense vector representations for downstream applications!
73 | # => Pass this representation to model of your choice (object detector, control policy, etc.)
74 | representation = vector_extractor(multimodal_embeddings)
75 | ```
76 |
77 | Voltron representations can be used for a variety of different applications; in the
78 | [`voltron-evaluation`](https://github.com/siddk/voltron-evaluation) repository, you can find code for adapting Voltron
79 | representations to various downstream tasks (segmentation, object detection, control, etc.); all the applications from
80 | our paper.
81 |
82 | ---
83 |
84 | ## API
85 |
86 | 
87 |
88 | The package `voltron` provides the following functionality for using and adapting existing representations:
89 |
90 | #### `voltron.available_models()`
91 |
92 | Returns the names of the available Voltron models; right now, the following models (all models trained in the paper) are
93 | available:
94 |
95 | - `v-cond` – V-Cond (ViT-Small) trained on Sth-Sth; single-frame w/ language-conditioning.
96 | - `v-dual` – V-Dual (ViT-Small) trained on Sth-Sth; dual-frame w/ language-conditioning.
97 | - `v-gen` – V-Gen (ViT-Small) trained on Sth-Sth; dual-frame w/ language conditioning AND generation.
98 | - `r-mvp` – R-MVP (ViT-Small); reproduction of [MVP](https://github.com/ir413/mvp) trained on Sth-Sth.
99 | - `r-r3m-vit` – R-R3M (ViT-Small); reproduction of [R3M](https://github.com/facebookresearch/r3m) trained on Sth-Sth.
100 | - `r-r3m-rn50` – R-R3M (ResNet-50); reproduction of [R3M](https://github.com/facebookresearch/r3m) trained on Sth-Sth.
101 | - `v-cond-base` – V-Cond (ViT-Base) trained on Sth-Sth; larger (86M parameter) variant of V-Cond.
102 |
103 | #### `voltron.load(name: str, device: str, freeze: bool, cache: str = cache/)`
104 |
105 | Returns the model and the Torchvision Transform needed by the model, where `name` is one of the strings returned
106 | by `voltron.available_models()`; this in general follows the same API as
107 | [OpenAI's CLIP](https://github.com/openai/CLIP).
108 |
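A minimal end-to-end sketch of these two calls (the model ID and device here are just illustrative choices; any ID listed above works the same way):

```python
import voltron

# List the available model IDs (strings accepted by `voltron.load`)
print(voltron.available_models())

# Load a frozen model + its matching Torchvision transform (weights cached under `cache/` by default)
model, preprocess = voltron.load("v-cond", device="cpu", freeze=True)
```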
109 | ---
110 |
111 | Voltron models (`v-{cond, dual, gen, ...}`) returned by `voltron.load()` support the following:
112 |
113 | #### `model(img: Tensor, lang: Optional[List[str]], mode: str = "multimodal")`
114 |
115 | Returns a sequence of embeddings corresponding to the output of the multimodal encoder; note that `lang` can be None,
116 | which is totally fine for Voltron models! However, if you have any language (even a coarse task description), it'll
117 | probably be helpful!
118 |
119 | The parameter `mode` in `["multimodal", "visual"]` controls whether the output will contain the fused image patch and
120 | language embeddings, or only the image patch embeddings.
121 |
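Concretely, for `v-cond` at the default 224x224 resolution (14 x 14 = 196 patches) with up to 20 language tokens (`max_lang_len` in the Sth-Sth config), a minimal sketch of the two modes, mirroring [`examples/verification/verify.py`](examples/verification/verify.py) (CPU used here purely for illustration):

```python
import torch
from torchvision.io import read_image

from voltron import load

# Load a frozen V-Cond model + preprocessor, then embed a single frame
vcond, preprocess = load("v-cond", device="cpu", freeze=True)
img = preprocess(read_image("examples/verification/img/peel-carrot-initial.png"))[None, ...]

with torch.no_grad():
    multimodal = vcond(img, ["peeling a carrot"], mode="multimodal")  # sequence of 196 patch + 20 language embeddings
    visual = vcond(img, mode="visual")                                # sequence of 196 patch embeddings only
```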
122 | **Note:** For the API for the non-Voltron models (e.g., R-MVP, R-R3M), take a look at
123 | [`examples/verification/verify.py`](examples/verification/verify.py); this file shows how representations from *every* model can be extracted.
124 |
125 | ### Adaptation
126 |
127 | See [`examples/usage.py`](examples/usage.py) and the [`voltron-evaluation`](https://github.com/siddk/voltron-evaluation)
128 | repository for more examples on the various ways to adapt/use Voltron representations.
129 |
130 | ---
131 |
132 | ## Contributing
133 |
134 | Before committing to the repository, make sure to set up your dev environment!
135 | Here are the basic development environment setup guidelines:
136 |
137 | + Fork/clone the repository, performing an editable installation. Make sure to install with the development dependencies
138 | (e.g., `pip install -e ".[dev]"`); this will install `black`, `ruff`, and `pre-commit`.
139 |
140 | + Install `pre-commit` hooks (`pre-commit install`).
141 |
142 | + Branch for the specific feature/issue, issuing PR against the upstream repository for review.
143 |
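Concretely, a typical setup looks something like the following (substitute your own fork URL):

```bash
git clone https://github.com/<your-username>/voltron-robotics
cd voltron-robotics
pip install -e ".[dev]"
pre-commit install

# Optional: run the same style/lint checks as the pre-commit hooks
make check
```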
144 | Additional Contribution Notes:
145 | - This project has migrated to the recommended
146 | [`pyproject.toml` based configuration for setuptools](https://setuptools.pypa.io/en/latest/userguide/quickstart.html).
147 | However, as some tools haven't yet adopted [PEP 660](https://peps.python.org/pep-0660/), we provide a
148 | [`setup.py` file](https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html).
149 |
150 | - This package follows the [`flat-layout` structure](https://setuptools.pypa.io/en/latest/userguide/package_discovery.html#flat-layout)
151 | described in `setuptools`.
152 |
153 | - Make sure to add any new dependencies to the `pyproject.toml` file!
154 |
155 | ---
156 |
157 | ## Repository Structure
158 |
159 | High-level overview of repository/project file-tree:
160 |
161 | + `docs/` - Package documentation & assets - including project roadmap.
162 | + `voltron` - Package source code; has all core utilities for model specification, loading, feature extraction,
163 | preprocessing, etc.
164 | + `examples/` - Standalone examples scripts for demonstrating various functionality (e.g., extracting different types
165 | of representations, adapting representations in various contexts, pretraining, amongst others).
166 | + `.pre-commit-config.yaml` - Pre-commit configuration file (sane defaults + `black` + `ruff`).
167 | + `LICENSE` - Code is made available under the MIT License.
168 | + `Makefile` - Top-level Makefile (by default, supports linting - checking & auto-fix); extend as needed.
169 | + `pyproject.toml` - Following PEP 621, this file has all project configuration details (including dependencies), as
170 | well as tool configurations (for `black` and `ruff`).
171 | + `README.md` - You are here!
172 |
173 | ---
174 |
175 | ## Citation
176 |
177 | Please cite [our paper](https://arxiv.org/abs/2302.12766) if using any of the Voltron models, evaluation suite, or other parts of our framework in your work.
178 |
179 | ```bibtex
180 | @inproceedings{karamcheti2023voltron,
181 | title={Language-Driven Representation Learning for Robotics},
182 | author={Siddharth Karamcheti and Suraj Nair and Annie S. Chen and Thomas Kollar and Chelsea Finn and Dorsa Sadigh and Percy Liang},
183 | booktitle={Robotics: Science and Systems (RSS)},
184 | year={2023}
185 | }
186 | ```
187 |
--------------------------------------------------------------------------------
/docs/ROADMAP.md:
--------------------------------------------------------------------------------
1 | # Project Roadmap
2 |
3 | We document the future of this project (new features to be added, issues to address) here. For the most part, any
4 | new features/bugfixes are documented as [Github Issues](https://github.com/siddk/voltron-robotics/issues).
5 |
6 | ## Timeline
7 |
8 | [X] - **February 26th, 2023**: Initial Voltron-Robotics release with support for loading/adapting all pretrained models,
9 | with comprehensive verification scripts & a small adaptation example.
10 |
11 | [X] - **April 4, 2023**: [#1](https://github.com/siddk/voltron-robotics/issues/1) - Add `xpretrain.py` reference script,
12 | mostly for completeness. Refactor/rewrite the preprocessing and pretraining pipeline to reflect
13 | the Qualcomm Sth-Sth data format, as well as PyTorch DDP vs. the patched PyTorch XLA!
14 |
15 | [X] - **April 11, 2023**: [#2](https://github.com/siddk/voltron-robotics/issues/2) - Add support and a more general API
16 | for pretraining on other datasets.
17 |
18 | [ ] - **Future**: [#5](https://github.com/siddk/voltron-robotics/issues/5) - Add better documentation and examples
19 | around using the MAP extractor (especially for adaptation tasks).
20 |
--------------------------------------------------------------------------------
/docs/assets/voltron-banner-alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siddk/voltron-robotics/1b299bf5cfa06673a3738aa6e15423b92a9922cd/docs/assets/voltron-banner-alpha.png
--------------------------------------------------------------------------------
/docs/assets/voltron-banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siddk/voltron-robotics/1b299bf5cfa06673a3738aa6e15423b92a9922cd/docs/assets/voltron-banner.png
--------------------------------------------------------------------------------
/docs/assets/voltron-framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siddk/voltron-robotics/1b299bf5cfa06673a3738aa6e15423b92a9922cd/docs/assets/voltron-framework.png
--------------------------------------------------------------------------------
/examples/pretrain/README.md:
--------------------------------------------------------------------------------
1 | # Pretraining Voltron Models
2 |
3 | We provide scripts for pretraining Voltron models on various datasets. Below, we walk through the full pipeline:
4 | downloading the raw Something-Something-v2 dataset from Qualcomm, running preprocessing, and then launching
5 | Distributed Data Parallel (DDP) pretraining on 1+ GPUs via `torchrun` (sketched at a high level below). Adding
6 | support for new datasets should follow this same general flow.
7 |
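At a glance, the pipeline boils down to the following stages (paths, flags, and configuration overrides are covered in detail in the sections below):

```bash
# 1. Download & extract the raw Something-Something-v2 archives from Qualcomm (manual; see below)

# 2. Preprocess --> serialize frames, tokenize language, assemble batch index files
python examples/pretrain/preprocess.py

# 3. Pretrain with DDP (example: single node, 8 GPUs)
torchrun --standalone --nnodes 1 --nproc-per-node 8 examples/pretrain/pretrain.py
```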
8 | ---
9 |
10 | ## Dataset Preprocessing
11 |
12 | We provide end-to-end instructions for downloading, preprocessing, and serializing various pretraining datasets (and
13 | combinations thereof). Where possible, we provide links to batch/dataset index files.
14 |
15 | **Note:** We make a key assumption that you have enough local disk space (e.g., on your server, attached NFS volume) to
16 | store all *raw* and *preprocessed* data; this can range from 100s of GBs to 10s of TBs! We did not have access to such
17 | storage in the original work, necessitating the *streaming* dataloaders defined in
18 | `voltron/datasets/v1/stream_datasets.py`. Given your resources, you might consider adopting a similar approach; feel
19 | free to post an issue with any questions!
20 |
21 | We currently support pretraining on the following datasets:
22 |
23 | - [Something-Something-v2](https://developer.qualcomm.com/software/ai-datasets/something-something)
24 |
25 | Instructions for downloading/preprocessing each dataset can be found below!
26 |
27 | ---
28 |
29 | ### Something-Something-v2
30 |
31 | Dataset Download: [Qualcomm AI Datasets](https://developer.qualcomm.com/software/ai-datasets/something-something)
32 |
33 | #### Obtaining the Raw Dataset
34 |
35 | Follow the instructions [at the above link](https://developer.qualcomm.com/software/ai-datasets/something-something) to
36 | download the dataset. Qualcomm requires that you register for a
37 | [Qualcomm OneID Account](https://myaccount.qualcomm.com/signup?target=https%3A%2F%2Fdeveloper.qualcomm.com)
38 | to get access to the data. Approval might take some time.
39 |
40 | After registering for an account, make sure to download all of the following files to a directory of your choosing
41 | (we create a directory `data/raw/something-something-v2/downloaded/`). *You will need to manually download all 22 of
42 | the following files from the Qualcomm site*:
43 |
44 | 1. Datasheet / Instructions (PDF – optional, but useful): `20bn-something-something_download_instructions_-_091622.pdf`
45 | 2. Labels (includes language annotations): `20bn-something-something_download-package-labels.zip`
46 | 3. Chunked Videos (should be 20 `.zip` archives):
47 | + `20bn-something-something-v2-00.zip`
48 | + ...
49 | + `20bn-something-something-v2-19.zip`
50 |
51 | To extract all the given files (we extract to `data/raw/something-something-v2/`) - *execute the following from inside
52 | the `downloaded/` subdirectory)*:
53 |
54 | ```bash
55 | # Labels (annotations/language) --> creates `data/raw/something-something-v2/labels`
56 | unzip 20bn-something-something-download-package-labels.zip -d ../
57 |
58 | # Videos (following instructions in `20bn-something-something_download_instructions_-_091622.pdf`)
59 | unzip "20bn-something-something-v2-*.zip" -d ../videos
60 | cd ../videos
61 | cat 20bn-something-something-v2-?? | tar -xvzf -
62 | find . -maxdepth 1 -type f -delete
63 | cd 20bn-something-something-v2/
64 | find . -mindepth 1 -maxdepth 1 -exec mv -t .. -- {} +
65 | cd ..
66 | rm -r 20bn-something-something-v2
67 | ls | wc -l  # Should have 220847 `.webm` files!
68 | ```
69 |
70 | #### Dataset Information & Statistics
71 |
72 | Something-Something-v2 consists of 220,847 `.webm` clips (168,913 in the `train` split) each with a height of exactly
73 | 240px, and variable width. The frames are encoded at a fixed 12 FPS.
74 |
75 | There are an average of 45 frames per clip (approx ~7 KB per jpeg); ~7.6M frames total (~56 GB).
76 |
77 | #### Video/Image Transformations --> from Video Clip to "frame" --> "tensor"
78 |
79 | ```python
80 | import av
81 | from PIL import Image, ImageOps
82 |
83 | # Resolutions for "preprocessing" (serialize to disk) and "training"
84 | PREPROCESS_RESOLUTION, TRAIN_RESOLUTION = 240, 224
85 |
86 | # Define Preprocessing Transformation
87 | def preprocess_transform(frames: List[Image.Image]) -> List[Image.Image]:
88 | # Assert width >= height and height >= PREPROCESS_RESOLUTION
89 | orig_w, orig_h = frames[0].size
90 | assert orig_w >= orig_h >= PREPROCESS_RESOLUTION
91 |
92 | # Compute scale factor --> just a function of height and PREPROCESS_RESOLUTION
93 | scale_factor = PREPROCESS_RESOLUTION / orig_h
94 |
95 | # Full Transformation --> scale (preserve aspect ratio, then get square)
96 | for idx in range(len(frames)):
97 | frames[idx] = ImageOps.scale(frames[idx], factor=scale_factor)
98 | left = (frames[idx].size[0] - PREPROCESS_RESOLUTION) // 2
99 | frames[idx] = frames[idx].crop((left, 0, left + PREPROCESS_RESOLUTION, PREPROCESS_RESOLUTION))
100 |
101 | return frames
102 |
103 | def train_transform(img) -> torch.Tensor:
104 | # Assumes square, just resizes to TRAIN_RESOLUTION via `torchvision.transforms`
105 | ...
106 |
107 | def extract_frames(webm_file: str) -> None:
108 | container = av.open(webm_file)
109 | assert int(container.streams.video[0].average_rate) == 12, "FPS for `sth-sth-v2` should be 12!"
110 |
111 | # Extract --> then serialize via `Image.save("frame_{idx}.jpg")`
112 | frames = preprocess_transform([f.to_image() for f in container.decode(video=0)])
113 | ...
114 | ```
115 |
116 |
117 | #### Citation
118 |
119 | If you are pretraining on this dataset, make sure to cite the original research; Something-Something-v2 is the product
120 | of two papers:
121 |
122 | ```bibtex
123 | @inproceedings{goyal2017sthsthv1,
124 | author = {Raghav Goyal and Samira Ebrahimi Kahou and Vincent Michalski and Joanna Materzynska and Susanne Westphal and Heuna Kim and Valentin Haenel and Ingo Fründ and Peter N. Yianilos and Moritz Mueller-Freitag and Florian Hoppe and Christian Thurau and Ingo Bax and Roland Memisevic},
125 | booktitle = {International Conference on Computer Vision (ICCV)},
126 | title = {The ``Something Something'' Video Database for Learning and Evaluating Visual Common Sense},
127 | year = {2017},
128 | }
129 | @article{mahidisoltani2018sthsthv2,
130 | author={Farzaneh Mahdisoltani and Guillaume Berger and Waseem Gharbieh and David J. Fleet and Roland Memisevic},
131 | journal = {arXiv preprint arXiv:1804.09235},
132 | title={On the Effectiveness of Task Granularity for Transfer Learning},
133 | year={2018}
134 | }
135 | ```
136 |
137 | ---
138 |
139 | ## PyTorch Native Pretraining Pipeline
140 |
141 | To pretrain a Voltron model (e.g., `v-cond`) on the processed data, make sure you have first run (and read through) `examples/pretrain/preprocess.py`.
142 | A sample launch command to run with the Something-Something-v2 dataset on a single node with 8 GPUs is as follows:
143 |
144 | ```bash
145 | torchrun --standalone --nnodes 1 --nproc-per-node 8 examples/pretrain/pretrain.py
146 | ```
147 |
148 | Make sure to check the following configuration files and either update them manually (adding your own dataclass,
149 | overriding [DEFAULTS](https://github.com/siddk/voltron-robotics/blob/main/examples/pretrain/pretrain.py#L38)), or
150 | override them at the command line using Hydra semantics (e.g., `... pretrain.py dataset.path="" ...`; see the fuller example after this list):
151 |
152 | - [Accelerator Config](../../voltron/conf/accelerators.py): Depending on hardware, might need to tune `num_workers`
153 | - [Dataset Config](../../voltron/conf/datasets.py): Make sure to override `path` and `artifact_path`
154 | - [Tracking Config](../../voltron/conf/tracking.py): Disable Weights & Biases / change default entity/name
155 |
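For example, assuming the default config group names from `voltron/conf` (`dataset`, `accelerator`, `tracking`), a launch with a few overrides might look like the following (paths here are placeholders; adjust to your setup):

```bash
torchrun --standalone --nnodes 1 --nproc-per-node 8 examples/pretrain/pretrain.py \
    dataset.path="/path/to/raw/sth-sth-v2" \
    dataset.artifact_path="/path/to/processed/sth-sth-v2" \
    accelerator.num_workers=4 \
    tracking.notes="sth-sth-v2-vcond-pretraining"
```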
--------------------------------------------------------------------------------
/examples/pretrain/preprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | preprocess.py
3 |
4 | Centralized script for preprocessing various video/vision-language datasets for GPU pretraining, using a multi-stage,
5 | multiprocessing approach.
6 |
7 | Run as a standalone script, *prior* to calling `pretrain.py` =>> mostly because we want to preprocess the data once, as
8 | a fixed cost.
9 | """
10 | import logging
11 | from dataclasses import dataclass, field
12 | from typing import Any, Dict, List
13 |
14 | import hydra
15 | from hydra.core.config_store import ConfigStore
16 | from omegaconf import MISSING
17 |
18 | from voltron.conf import DatasetConfig
19 | from voltron.overwatch import OverwatchRich
20 | from voltron.preprocessing import extract_frames, preprocess_language, unify_batches
21 | from voltron.util import set_global_seed
22 |
23 | # Grab Logger
24 | overwatch = logging.getLogger(__file__)
25 |
26 |
27 | # Set Defaults (Hydra w/ Structured Configs)
28 | DEFAULTS = ["_self_", {"dataset": "sth-sth-v2"}, {"override hydra/job_logging": "overwatch_rich"}]
29 |
30 |
31 | @dataclass
32 | class PreprocessingConfig:
33 | # fmt: off
34 | defaults: List[Any] = field(default_factory=lambda: DEFAULTS)
35 | hydra: Dict[str, Any] = field(
36 | default_factory=lambda: {"run": {"dir": "./runs/preprocessing/${now:%m-%d}/dataset-${dataset.name}"}}
37 | )
38 |
39 | # Command Line Arguments
40 | seed: int = 21 # Random Seed (for reproducibility)
41 | dry_run: bool = False # Dry Run --> Get a sense of preprocessing/serialization footprint
42 |
43 | # Composable / Structured Arguments
44 | dataset: DatasetConfig = MISSING # Dataset(s) for pretraining/preprocessing
45 | # fmt: on
46 |
47 |
48 | # Hydra Setup :: Retrieve ConfigStore (Singleton) & Register Components
49 | cs = ConfigStore.instance()
50 | cs.store(group="hydra/job_logging", name="overwatch_rich", node=OverwatchRich)
51 | cs.store(name="config", node=PreprocessingConfig)
52 |
53 |
54 | @hydra.main(config_path=None, config_name="config")
55 | def preprocess(cfg: PreprocessingConfig) -> None:
56 | overwatch.info("Preprocessing :: Running Phases for Frame Extraction, Language Compilation, and Batching...")
57 |
58 | # Set Randomness
59 | set_global_seed(cfg.seed)
60 |
61 | # Phase 1 :: Serialize Frames from Video Clips --> get `registry` (index files) for train and validation
62 | train_registry, val_registry, train_dir, val_dir = extract_frames(
63 | cfg.dataset.name,
64 | path=cfg.dataset.path,
65 | artifact_path=cfg.dataset.artifact_path,
66 | preprocess_resolution=cfg.dataset.preprocess_resolution,
67 | n_val_videos=cfg.dataset.n_val_videos,
68 | dry_run=cfg.dry_run,
69 | )
70 |
71 | # Phase 2 :: Normalize & Tokenize Language --> create `index.pt` and `index.json` files
72 | index_dir = preprocess_language(
73 | cfg.dataset.name,
74 | train_registry,
75 | val_registry,
76 | artifact_path=cfg.dataset.artifact_path,
77 | max_lang_len=cfg.dataset.max_lang_len,
78 | language_model=cfg.dataset.language_model,
79 | hf_cache=cfg.dataset.hf_cache,
80 | )
81 |
82 | # Phase 3 :: Assemble "Data-Locked" Batch Sets for Various Models (e.g., for single-frame/dual-frame/quintet)
83 | unify_batches(
84 | cfg.dataset.name,
85 | train_registry,
86 | val_registry,
87 | train_dir,
88 | val_dir,
89 | index_dir,
90 | batch_formats=cfg.dataset.batch_formats,
91 | max_epochs=cfg.dataset.max_epochs,
92 | initial_final_alpha=cfg.dataset.initial_final_alpha,
93 | )
94 |
95 | overwatch.info("Preprocessing Complete!")
96 |
97 |
98 | if __name__ == "__main__":
99 | preprocess()
100 |
--------------------------------------------------------------------------------
/examples/usage.py:
--------------------------------------------------------------------------------
1 | """
2 | usage.py
3 |
4 | Example script demonstrating how to load a Voltron model (`V-Cond`) and instantiate a Multiheaded Attention Pooling
5 | extractor head for downstream tasks.
6 |
7 | This is the basic formula/protocol for using Voltron for arbitrary downstream applications.
8 |
9 | Run with (from root of repository): `python examples/usage.py`
10 | """
11 | import torch
12 | from torchvision.io import read_image
13 |
14 | from voltron import instantiate_extractor, load
15 |
16 |
17 | def usage() -> None:
18 | print("[*] Demonstrating Voltron Usage for Various Adaptation Applications")
19 |
20 | # Get `torch.device` for loading model (note -- we'll load weights directly onto device!)
21 | device = "cuda" if torch.cuda.is_available() else "cpu"
22 |
23 | # Load Voltron model --> specify `freeze`, `device` and get model (nn.Module) and preprocessor
24 | vcond, preprocess = load("v-cond", device=device, freeze=True)
25 |
26 | # Obtain and preprocess an image =>> can be from a dataset, from a camera on a robot, etc.
27 | img = preprocess(read_image("examples/img/peel-carrot-initial.png"))[None, ...].to(device)
28 | lang = ["peeling a carrot"]
29 |
30 | # Get various representations...
31 | with torch.no_grad():
32 | multimodal_features = vcond(img, lang, mode="multimodal") # Fused vision & language features
33 | visual_features = vcond(img, mode="visual") # Vision-only features (no language)
34 |
35 | # Can instantiate various extractors for downstream applications
36 | vector_extractor = instantiate_extractor(vcond, n_latents=1, device=device)()
37 | seq_extractor = instantiate_extractor(vcond, n_latents=64, device=device)()
38 |
39 | # Assertions...
40 | assert list(vector_extractor(multimodal_features).shape) == [1, vcond.embed_dim], "Should return a dense vector!"
41 | assert list(seq_extractor(visual_features).shape) == [1, 64, vcond.embed_dim], "Should return a sequence!"
42 |
43 |
44 | if __name__ == "__main__":
45 | usage()
46 |
--------------------------------------------------------------------------------
/examples/verification/img/peel-carrot-final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siddk/voltron-robotics/1b299bf5cfa06673a3738aa6e15423b92a9922cd/examples/verification/img/peel-carrot-final.png
--------------------------------------------------------------------------------
/examples/verification/img/peel-carrot-initial.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siddk/voltron-robotics/1b299bf5cfa06673a3738aa6e15423b92a9922cd/examples/verification/img/peel-carrot-initial.png
--------------------------------------------------------------------------------
/examples/verification/img/place-bottle-final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siddk/voltron-robotics/1b299bf5cfa06673a3738aa6e15423b92a9922cd/examples/verification/img/place-bottle-final.png
--------------------------------------------------------------------------------
/examples/verification/img/place-bottle-grasp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siddk/voltron-robotics/1b299bf5cfa06673a3738aa6e15423b92a9922cd/examples/verification/img/place-bottle-grasp.png
--------------------------------------------------------------------------------
/examples/verification/img/place-bottle-initial.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siddk/voltron-robotics/1b299bf5cfa06673a3738aa6e15423b92a9922cd/examples/verification/img/place-bottle-initial.png
--------------------------------------------------------------------------------
/examples/verification/verify.py:
--------------------------------------------------------------------------------
1 | """
2 | verify.py
3 |
4 | Example script demonstrating how to load all Voltron models (and reproduced models), take input image(s), and get the
5 | various (e.g., multimodal, image-only) representations.
6 |
7 | Also serves to verify that representation loading is working as advertised.
8 |
9 | Run with (from root of repository): `python examples/verification/verify.py`
10 | """
11 | import torch
12 | from torchvision.io import read_image
13 |
14 | from voltron import load
15 |
16 | # Available Models
17 | MODELS = ["v-cond", "v-dual", "v-gen", "r-mvp", "r-r3m-vit", "r-r3m-rn50"]
18 |
19 | # Sample Inputs
20 | IMG_A, IMG_B = "examples/verification/img/peel-carrot-initial.png", "examples/verification/img/peel-carrot-final.png"
21 | LANGUAGE = "peeling a carrot"
22 |
23 |
24 | def verify() -> None:
25 | print("[*] Running `verify` =>> Verifying Model Representations!")
26 |
27 | # Read both images (we'll use the second image for the dual-frame models)
28 | image_a, image_b = read_image(IMG_A), read_image(IMG_B)
29 |
30 | # Get `torch.device` for loading model (note -- we'll load weights directly onto device!)
31 | device = "cuda" if torch.cuda.is_available() else "cpu"
32 |
33 | for model_id in MODELS:
34 | print(f"\t=> Loading Model ID `{model_id}` and Verifying Representation Shapes!")
35 | model, preprocess = load(model_id, device=device, freeze=True)
36 |
37 | # Preprocess image, run feature extraction --> assert on shapes!
38 | if model_id in {"v-cond", "v-cond-base"}:
39 | for modality, expected in [("multimodal", 196 + 20), ("visual", 196)]:
40 | representation = model(preprocess(image_a)[None, ...].to(device), [LANGUAGE], mode=modality)
41 | assert representation.squeeze(dim=0).shape[0] == expected, "Shape not expected!"
42 |
43 | elif model_id in {"v-dual", "v-gen"}:
44 | for modality, expected in [("multimodal", 196 + 20), ("visual", 196)]:
45 | dual_img = torch.stack([preprocess(image_a), preprocess(image_b)])[None, ...].to(device)
46 | representation = model(dual_img, [LANGUAGE], mode=modality)
47 | assert representation.squeeze(dim=0).shape[0] == expected, "Shape not expected!"
48 |
49 | elif model_id == "r-mvp":
50 | for mode, expected in [("patch", 196), ("cls", 1)]:
51 | representation = model(preprocess(image_a)[None, ...].to(device), mode=mode)
52 | assert representation.squeeze(dim=0).shape[0] == expected, "Shape not expected!"
53 |
54 | elif model_id in {"r-r3m-vit", "r-r3m-rn50"}:
55 | representation = model(preprocess(image_a)[None, ...].to(device))
56 | assert representation.squeeze(dim=0).shape[0] == 1, "Shape not expected!"
57 |
58 | else:
59 | raise ValueError(f"Model {model_id} not supported!")
60 |
61 | # We're good!
62 | print("[*] All representations & shapes verified! Yay!")
63 |
64 |
65 | if __name__ == "__main__":
66 | verify()
67 |
--------------------------------------------------------------------------------
/examples/xla-reference/README.md:
--------------------------------------------------------------------------------
1 | # XLA Reference
2 |
3 | *Note :: This code was written for the experimental PyTorch XLA build in PyTorch 1.12; no guarantees it works with later
4 | versions!*
5 |
6 | We trained the original Voltron models (and data-locked reproductions of R3M and MVP) on TPU v3-8 nodes generously
7 | provided by the [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) program. At the time we started
8 | the project, PyTorch XLA still had some bumps, which was further complicated by the switch from
9 | [TPU Nodes to TPU VMs](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu-arch).
10 |
11 | To get things to work, we had to add some non-intuitive code to facilitate PyTorch + TPUs (vs. a standard distributed
12 | data parallel training pipeline). As a result, `xpretrain.py` is here mostly for documentation purposes; a fully
13 | refactored, DDP-based version is available at `examples/pretrain/pretrain.py`.
14 |
15 | We also include the original cloud preprocessing script `xpreprocess.py` for completeness (this is more general).
16 |
--------------------------------------------------------------------------------
/examples/xla-reference/xpreprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | xpreprocess.py
3 |
4 | Centralized script for preprocessing Sth-Sth-v2 for TPU/GCP pretraining, using a multi-stage, multiprocessing strategy.
5 |
6 | Run as a standalone script, *prior* to calling `xpretrain.py` =>> mostly because we want to preprocess the data
7 | once, as a fixed cost.
8 | """
9 | import logging
10 | from dataclasses import dataclass, field
11 | from typing import Any, Dict, List
12 |
13 | import hydra
14 | from hydra.core.config_store import ConfigStore
15 | from omegaconf import MISSING
16 |
17 | from voltron.conf import DatasetConfig
18 | from voltron.overwatch import OverwatchRich
19 | from voltron.preprocessing.v1 import index, jsonify_language, preprocess_language, preprocess_videos, unify_batches
20 | from voltron.util.v1.random import set_global_seed
21 |
22 | # Grab Logger
23 | overwatch = logging.getLogger(__file__)
24 |
25 |
26 | # Set Defaults (Hydra w/ Structured Configs)
27 | DEFAULTS = ["_self_", {"dataset": "sth-sth-v2"}, {"override hydra/job_logging": "overwatch_rich"}]
28 |
29 |
30 | @dataclass
31 | class PreprocessingConfig:
32 | # fmt: off
33 | defaults: List[Any] = field(default_factory=lambda: DEFAULTS)
34 | hydra: Dict[str, Any] = field(
35 | default_factory=lambda: {"run": {"dir": "./runs/preprocessing/${now:%m-%d}/dataset-${dataset.name}"}}
36 | )
37 |
38 | # Command Line Arguments
39 | seed: int = 21 # Random Seed (for reproducibility)
40 | dry_run: bool = False # Dry Run --> Get a sense of preprocessing/serialization footprint
41 |
42 | # Composable / Structured Arguments
43 | dataset: DatasetConfig = MISSING # Dataset(s) for pretraining/preprocessing
44 | # fmt: on
45 |
46 |
47 | # Hydra Setup :: Retrieve ConfigStore (Singleton) & Register Components
48 | cs = ConfigStore.instance()
49 | cs.store(group="hydra/job_logging", name="overwatch_rich", node=OverwatchRich)
50 | cs.store(name="config", node=PreprocessingConfig)
51 |
52 |
53 | @hydra.main(config_path=None, config_name="config")
54 | def xpreprocess(cfg: PreprocessingConfig) -> None:
55 | overwatch.info("Preprocessing :: Running Phases for Frame Extraction, Language Compilation, and Batching...")
56 |
57 | # Set Randomness
58 | set_global_seed(cfg.seed)
59 |
60 | # Phase 1 :: Serialize Frames from Video Clips --> Get `registry` for train and val (index structure)
61 | train_registry, val_registry, train_dir, val_dir = preprocess_videos(
62 | cfg.dataset.name,
63 | path=cfg.dataset.path,
64 | artifact_path=cfg.dataset.artifact_path,
65 | resolution=cfg.dataset.resolution,
66 | n_val_videos=cfg.dataset.n_val_videos,
67 | dry_run=cfg.dry_run,
68 | )
69 |
70 | # Phase 2 :: Normalize & Tokenize Language --> Create `index.pt` & `index.json` files
71 | preprocess_language(
72 | cfg.dataset.name,
73 | train_registry,
74 | val_registry,
75 | max_lang_len=cfg.dataset.max_lang_len,
76 | language_model=cfg.dataset.language_model,
77 | hf_cache=cfg.dataset.hf_cache,
78 | )
79 | jsonify_language(train_registry, val_registry)
80 | index_dir = index(train_registry, val_registry, cfg.dataset.name, artifact_path=cfg.dataset.artifact_path)
81 |
82 | # Phase 3 :: Assemble & Unify Batch "Sets" across the Varied Dataset Formats (for each Model =>> "data-locked")
83 | unify_batches(
84 | cfg.dataset.artifact_path,
85 | cfg.dataset.name,
86 | train_registry,
87 | val_registry,
88 | train_dir,
89 | val_dir,
90 | index_dir,
91 | cfg.dataset.batch_formats,
92 | max_epochs=cfg.dataset.max_epochs,
93 | initial_final_alpha=cfg.dataset.initial_final_alpha,
94 | )
95 |
96 |
97 | if __name__ == "__main__":
98 | xpreprocess()
99 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "voltron-robotics"
7 | authors = [
8 | {name = "Siddharth Karamcheti", email="skaramcheti@cs.stanford.edu"}
9 | ]
10 | description = "Voltron: Language-Driven Representation Learning for Robotics."
11 | version = "1.1.0"
12 | readme = "README.md"
13 | requires-python = ">=3.8"
14 | keywords = ["robotics", "representation learning", "natural language processing", "machine learning"]
15 | license = {file = "LICENSE"}
16 | classifiers = [
17 | "Development Status :: 3 - Alpha",
18 | "Intended Audience :: Developers",
19 | "Intended Audience :: Education",
20 | "Intended Audience :: Science/Research",
21 | "License :: OSI Approved :: MIT License",
22 | "Operating System :: OS Independent",
23 | "Programming Language :: Python :: 3",
24 | "Programming Language :: Python :: 3.8",
25 | "Programming Language :: Python :: 3.9",
26 | "Programming Language :: Python :: 3.10",
27 | "Programming Language :: Python :: 3 :: Only",
28 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
29 | ]
30 | dependencies = [
31 | "av",
32 | "einops",
33 | "gdown",
34 | "google-cloud-storage",
35 | "h5py",
36 | "hurry.filesize",
37 | "hydra-core==1.1.1", # Lock Hydra =>> future versions break!
38 | "jsonlines",
39 | "omegaconf==2.1.2", # Lock OmegaConf =>> future versions break!
40 | "opencv-python",
41 | "pandas",
42 | "rich",
43 | "torch>=2.0.0", # Native PyTorch Code (Release 2.0.0) uses PyTorch 2.0!
44 | "torchvision>=0.15.0",
45 | "transformers",
46 | "wandb",
47 | ]
48 |
49 | [project.optional-dependencies]
50 | dev = [
51 | "black",
52 | "ipython",
53 | "pre-commit",
54 | "ruff",
55 | ]
56 |
57 | [project.urls]
58 | homepage = "https://github.com/siddk/voltron-robotics"
59 | repository = "https://github.com/siddk/voltron-robotics"
60 | documentation = "https://github.com/siddk/voltron-robotics"
61 |
62 | [tool.black]
63 | line-length = 121
64 | target-version = ["py38", "py39", "py310"]
65 | preview = true
66 |
67 | [tool.ruff]
68 | line-length = 121
69 | target-version = "py38"
70 | select = ["A", "B", "C90", "E", "F", "I", "RUF", "W"]
71 |
72 | [tool.ruff.per-file-ignores]
73 | "__init__.py" = ["E402", "F401"]
74 |
75 | [tool.setuptools.packages.find]
76 | where = ["."]
77 | exclude = ["cache"]
78 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | setup.py
3 |
4 | PEP 621 switches most of Packaging to `pyproject.toml` -- yet keep a "dummy" setup.py for external code that has not
5 | yet upgraded.
6 | """
7 | from setuptools import setup
8 |
9 | setup()
10 |
--------------------------------------------------------------------------------
/voltron/__init__.py:
--------------------------------------------------------------------------------
1 | from .models.materialize import available_models, load
2 | from .models.util import instantiate_extractor
3 |
--------------------------------------------------------------------------------
/voltron/conf/__init__.py:
--------------------------------------------------------------------------------
1 | from .accelerators import AcceleratorConfig
2 | from .datasets import DatasetConfig
3 | from .models import ModelConfig
4 | from .tracking import TrackingConfig
5 |
--------------------------------------------------------------------------------
/voltron/conf/accelerators.py:
--------------------------------------------------------------------------------
1 | """
2 | accelerator.py
3 |
4 | Base Hydra Structured Configs for defining various accelerator schemes. Uses a simple single inheritance structure.
5 | """
6 | import os
7 | from dataclasses import dataclass
8 |
9 | from hydra.core.config_store import ConfigStore
10 | from omegaconf import MISSING
11 |
12 | # === Vanilla Accelerators (Deprecated; mostly for XLA code) ===
13 |
14 |
15 | @dataclass
16 | class AcceleratorConfig:
17 | accelerator: str = MISSING
18 | num_accelerators: int = MISSING
19 | num_workers: int = MISSING
20 |
21 |
22 | @dataclass
23 | class TPUv2OneConfig(AcceleratorConfig):
24 | accelerator = "tpu"
25 | num_accelerators = 1
26 | num_workers = 4
27 |
28 |
29 | @dataclass
30 | class TPUv2EightConfig(AcceleratorConfig):
31 | accelerator = "tpu"
32 | num_accelerators = 8
33 | num_workers = 4
34 |
35 |
36 | @dataclass
37 | class TPUv3OneConfig(AcceleratorConfig):
38 | accelerator = "tpu"
39 | num_accelerators = 1
40 | num_workers = 8
41 |
42 |
43 | @dataclass
44 | class TPUv3EightConfig(AcceleratorConfig):
45 | accelerator = "tpu"
46 | num_accelerators = 8
47 | num_workers = 8
48 |
49 |
50 | # === GPU Default Config --> just set `num_workers`; `torchrun` takes care of the rest! ===
51 | # > Note :: Defaults to 1 GPU if WORLD_SIZE not set (e.g., not running with `torchrun`)
52 |
53 |
54 | @dataclass
55 | class TorchRunDefaultConfig(AcceleratorConfig):
56 | accelerator = "gpu"
57 | num_accelerators = int(os.environ["WORLD_SIZE"] if "WORLD_SIZE" in os.environ else 1)
58 | num_workers = 8
59 |
60 |
61 | # Create a configuration group `accelerator` and populate with the above...
62 | cs = ConfigStore.instance()
63 | cs.store(group="accelerator", name="tpu-v2-1", node=TPUv2OneConfig)
64 | cs.store(group="accelerator", name="tpu-v2-8", node=TPUv2EightConfig)
65 | cs.store(group="accelerator", name="tpu-v3-1", node=TPUv3OneConfig)
66 | cs.store(group="accelerator", name="tpu-v3-8", node=TPUv3EightConfig)
67 |
68 | cs.store(group="accelerator", name="torchrun", node=TorchRunDefaultConfig)
69 |
--------------------------------------------------------------------------------
/voltron/conf/datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | datasets.py
3 |
4 | Base Hydra Structured Config for defining various pretraining datasets and appropriate configurations. Uses a simple,
5 | single inheritance structure.
6 | """
7 | from dataclasses import dataclass
8 | from typing import Any, Tuple
9 |
10 | from hydra.core.config_store import ConfigStore
11 | from hydra.utils import to_absolute_path
12 | from omegaconf import MISSING
13 |
14 |
15 | @dataclass
16 | class DatasetConfig:
17 | name: str = MISSING
18 | path: str = MISSING
19 | artifact_path: str = MISSING
20 |
21 | # Streaming Parameters (assumes fully preprocessed dataset lives at `stream_prefix/...`)
22 | # =>> Deprecated as of `v2`
23 | stream: bool = True
24 | stream_prefix: str = "data/processed"
25 |
26 | # Dataset-Specific Parameters
27 | resolution: int = 224
28 | normalization: Tuple[Any, Any] = MISSING
29 |
30 | # For preprocessing --> maximum size of saved frames (assumed square)
31 | preprocess_resolution: int = MISSING
32 |
33 | # Validation Parameters
34 | n_val_videos: int = MISSING
35 |
36 | # Language Modeling Parameters
37 | language_model: str = "distilbert-base-uncased"
38 | hf_cache: str = to_absolute_path("data/hf-cache")
39 |
40 | # Maximum Length for truncating language inputs... should be computed after the fact (set to -1 to compute!)
41 | max_lang_len: int = MISSING
42 |
43 | # Dataset sets the number of pretraining epochs (general rule :: warmup should be ~5% of full)
44 | warmup_epochs: int = MISSING
45 | max_epochs: int = MISSING
46 |
47 | # Plausible Formats --> These are instantiations each "batch" could take, with a small DSL
48 | # > Note: Assumes final element of the list is the "most expressive" --> used to back-off
49 | batch_formats: Any = (
50 | ("state", ("state_i",)),
51 | ("state+language", ("state_i", "language")),
52 | ("state+ok", ("state_initial", "state_i", "language")),
53 | ("quintet+language", ("state_initial", "state_i", "state_j", "state_k", "state_final", "language")),
54 | )
55 |
56 | # Preprocessing :: Frame-Sampling Parameters
57 | initial_final_alpha: float = 0.2
58 |
59 |
60 | @dataclass
61 | class SthSthv2Config(DatasetConfig):
62 | # fmt: off
63 | name: str = "sth-sth-v2"
64 | path: str = to_absolute_path("data/raw/sth-sth-v2")
65 | artifact_path: str = to_absolute_path("data/processed/sth-sth-v2")
66 |
67 | # Dataset Specific arguments
68 | normalization: Tuple[Any, Any] = ( # Mean & Standard Deviation (default :: ImageNet)
69 | (0.485, 0.456, 0.406),
70 | (0.229, 0.224, 0.225),
71 | )
72 |
73 | # Sth-Sth-v2 Videos have a fixed height of 240; we'll crop to square at this resolution!
74 | preprocess_resolution: int = 240
75 |
76 | # Validation Parameters
77 | n_val_videos: int = 1000 # Number of Validation Clips (fast evaluation!)
78 |
79 | # Epochs for Dataset
80 | warmup_epochs: int = 20
81 | max_epochs: int = 400
82 |
83 | # Language Modeling Parameters
84 | max_lang_len: int = 20
85 | # fmt: on
86 |
87 |
88 | # Create a configuration group `dataset` and populate with the above...
89 | # =>> Note :: this is meant to be extendable --> add arbitrary datasets & mixtures!
90 | cs = ConfigStore.instance()
91 | cs.store(group="dataset", name="sth-sth-v2", node=SthSthv2Config)
92 |
--------------------------------------------------------------------------------
/voltron/conf/tracking.py:
--------------------------------------------------------------------------------
1 | """
2 | tracking.py
3 |
4 | Base Hydra Structured Config for defining various run & experiment tracking configurations, e.g., via Weights & Biases.
5 | Uses a simple single inheritance structure.
6 | """
7 | from dataclasses import dataclass, field
8 | from typing import List, Optional, Tuple
9 |
10 | from hydra.core.config_store import ConfigStore
11 | from omegaconf import MISSING
12 |
13 |
14 | @dataclass
15 | class TrackingConfig:
16 | # Active Loggers --> List of Loggers
17 | active_loggers: List[str] = field(default_factory=lambda: ["jsonl", "wandb"])
18 |
19 | # Generic Logging Frequency --> Matters more for XLA/TPUs... set this to be as large as you can stomach!
20 | log_frequency: int = 100
21 |
22 | # Checkpointing Strategy --> Save each epoch, keep most recent `idx[0]` checkpoints & *every* `idx[1]` checkpoints
23 | # Additionally, save (locally) a checkpoint every `idx[2]` steps for the current epoch (-1).
24 | checkpoint_strategy: Tuple[int, int, int] = (1, 1, 1500)
25 |
26 | # Weights & Biases Setup
27 | project: str = "voltron-pretraining"
28 | entity: str = "voltron-robotics"
29 |
30 | # Notes & Tags are at the discretion of the user... see below
31 | notes: str = MISSING
32 | tags: Optional[List[str]] = None
33 |
34 | # Directory to save W&B Metadata & Logs in General -- if None, defaults to `logs/` in the Hydra CWD
35 | directory: Optional[str] = None
36 |
37 |
38 | @dataclass
39 | class VoltronTrackingConfig(TrackingConfig):
40 |     # Note: I really like using notes to keep track of things, so this will crash unless `notes` is specified for the run.
41 | # > For `tags` I like to populate based on other args in the script, so letting it remain None
42 | notes: str = MISSING
43 |
44 |
45 | # Create a configuration group `trackers` and populate with the above...
46 | cs = ConfigStore.instance()
47 | cs.store(group="tracking", name="voltron-tracking", node=VoltronTrackingConfig)
48 |
--------------------------------------------------------------------------------
/voltron/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import get_datasets
2 |
--------------------------------------------------------------------------------
/voltron/datasets/v1/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siddk/voltron-robotics/1b299bf5cfa06673a3738aa6e15423b92a9922cd/voltron/datasets/v1/__init__.py
--------------------------------------------------------------------------------
/voltron/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .instantiate import VMVP, VR3M, VRN3M, VCond, VDual, VGen, get_model_optimizer
2 |
--------------------------------------------------------------------------------
/voltron/models/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siddk/voltron-robotics/1b299bf5cfa06673a3738aa6e15423b92a9922cd/voltron/models/core/__init__.py
--------------------------------------------------------------------------------
/voltron/models/instantiate.py:
--------------------------------------------------------------------------------
1 | """
2 | instantiate.py
3 |
4 | Simple wrapping script for instantiating a core Voltron/reproduction model and configuring the torch.Optimizer for DDP
5 | pretraining. Meant to be modular and extensible!
6 | """
7 | from typing import Callable, Tuple
8 |
9 | import torch.nn as nn
10 | from torch.optim import Optimizer
11 |
12 | from voltron.conf import DatasetConfig, ModelConfig
13 |
14 | from .core.vcond import VCond
15 | from .core.vdual import VDual
16 | from .core.vgen import VGen
17 | from .reproductions.vmvp import VMVP
18 | from .reproductions.vr3m import VR3M
19 | from .reproductions.vrn3m import VRN3M
20 |
21 |
22 | def get_model_optimizer(
23 | model_cfg: ModelConfig, dataset_cfg: DatasetConfig
24 | ) -> Tuple[nn.Module, Optimizer, Callable[[int, float], float]]:
25 | """Switch on `model_cfg.arch` --> instantiate the correct nn.Module and Optimizer (on CPU/default device)."""
26 |
27 | # Data-Locked Reproductions
28 | if model_cfg.arch == "v-mvp":
29 | model = VMVP(
30 | resolution=dataset_cfg.resolution,
31 | patch_size=model_cfg.patch_size,
32 | encoder_depth=model_cfg.encoder_depth,
33 | encoder_embed_dim=model_cfg.encoder_embed_dim,
34 | encoder_n_heads=model_cfg.encoder_n_heads,
35 | decoder_depth=model_cfg.decoder_depth,
36 | decoder_embed_dim=model_cfg.decoder_embed_dim,
37 | decoder_n_heads=model_cfg.decoder_n_heads,
38 | optimizer=model_cfg.optimizer,
39 | schedule=model_cfg.schedule,
40 | base_lr=model_cfg.base_lr,
41 | min_lr=model_cfg.min_lr,
42 | effective_bsz=model_cfg.effective_bsz,
43 | betas=model_cfg.betas,
44 | weight_decay=model_cfg.weight_decay,
45 | warmup_epochs=dataset_cfg.warmup_epochs,
46 | max_epochs=dataset_cfg.max_epochs,
47 | mlp_ratio=model_cfg.mlp_ratio,
48 | norm_pixel_loss=model_cfg.norm_pixel_loss,
49 | )
50 |
51 | elif model_cfg.arch == "v-r3m":
52 | model = VR3M(
53 | resolution=dataset_cfg.resolution,
54 | patch_size=model_cfg.patch_size,
55 | depth=model_cfg.depth,
56 | embed_dim=model_cfg.embed_dim,
57 | n_heads=model_cfg.n_heads,
58 | language_model=model_cfg.language_model,
59 | hf_cache=model_cfg.hf_cache,
60 | language_dim=model_cfg.language_dim,
61 | reward_dim=model_cfg.reward_dim,
62 | n_negatives=model_cfg.n_negatives,
63 | lang_reward_weight=model_cfg.lang_reward_weight,
64 | tcn_weight=model_cfg.tcn_weight,
65 | l1_weight=model_cfg.l1_weight,
66 | l2_weight=model_cfg.l2_weight,
67 | optimizer=model_cfg.optimizer,
68 | schedule=model_cfg.schedule,
69 | lr=model_cfg.lr,
70 | min_lr=model_cfg.min_lr,
71 | warmup_epochs=dataset_cfg.warmup_epochs,
72 | max_epochs=dataset_cfg.max_epochs,
73 | mlp_ratio=model_cfg.mlp_ratio,
74 | )
75 |
76 | elif model_cfg.arch == "v-rn3m":
77 | model = VRN3M(
78 | resolution=dataset_cfg.resolution,
79 | fc_dim=model_cfg.fc_dim,
80 | language_model=model_cfg.language_model,
81 | hf_cache=model_cfg.hf_cache,
82 | language_dim=model_cfg.language_dim,
83 | reward_dim=model_cfg.reward_dim,
84 | n_negatives=model_cfg.n_negatives,
85 | lang_reward_weight=model_cfg.lang_reward_weight,
86 | tcn_weight=model_cfg.tcn_weight,
87 | l1_weight=model_cfg.l1_weight,
88 | l2_weight=model_cfg.l2_weight,
89 | optimizer=model_cfg.optimizer,
90 | lr=model_cfg.lr,
91 | )
92 |
93 | # Voltron Models
94 | elif model_cfg.arch == "v-cond":
95 | model = VCond(
96 | resolution=dataset_cfg.resolution,
97 | patch_size=model_cfg.patch_size,
98 | encoder_depth=model_cfg.encoder_depth,
99 | encoder_embed_dim=model_cfg.encoder_embed_dim,
100 | encoder_n_heads=model_cfg.encoder_n_heads,
101 | decoder_depth=model_cfg.decoder_depth,
102 | decoder_embed_dim=model_cfg.decoder_embed_dim,
103 | decoder_n_heads=model_cfg.decoder_n_heads,
104 | language_model=model_cfg.language_model,
105 | hf_cache=model_cfg.hf_cache,
106 | language_dim=model_cfg.language_dim,
107 | optimizer=model_cfg.optimizer,
108 | schedule=model_cfg.schedule,
109 | base_lr=model_cfg.base_lr,
110 | min_lr=model_cfg.min_lr,
111 | effective_bsz=model_cfg.effective_bsz,
112 | betas=model_cfg.betas,
113 | weight_decay=model_cfg.weight_decay,
114 | warmup_epochs=dataset_cfg.warmup_epochs,
115 | max_epochs=dataset_cfg.max_epochs,
116 | mlp_ratio=model_cfg.mlp_ratio,
117 | norm_pixel_loss=model_cfg.norm_pixel_loss,
118 | )
119 |
120 | elif model_cfg.arch == "v-dual":
121 | model = VDual(
122 | resolution=dataset_cfg.resolution,
123 | patch_size=model_cfg.patch_size,
124 | encoder_depth=model_cfg.encoder_depth,
125 | encoder_embed_dim=model_cfg.encoder_embed_dim,
126 | encoder_n_heads=model_cfg.encoder_n_heads,
127 | decoder_depth=model_cfg.decoder_depth,
128 | decoder_embed_dim=model_cfg.decoder_embed_dim,
129 | decoder_n_heads=model_cfg.decoder_n_heads,
130 | language_model=model_cfg.language_model,
131 | hf_cache=model_cfg.hf_cache,
132 | language_dim=model_cfg.language_dim,
133 | optimizer=model_cfg.optimizer,
134 | schedule=model_cfg.schedule,
135 | base_lr=model_cfg.base_lr,
136 | min_lr=model_cfg.min_lr,
137 | effective_bsz=model_cfg.effective_bsz,
138 | betas=model_cfg.betas,
139 | weight_decay=model_cfg.weight_decay,
140 | warmup_epochs=dataset_cfg.warmup_epochs,
141 | max_epochs=dataset_cfg.max_epochs,
142 | mlp_ratio=model_cfg.mlp_ratio,
143 | norm_pixel_loss=model_cfg.norm_pixel_loss,
144 | )
145 |
146 | elif model_cfg.arch == "v-gen":
147 | model = VGen(
148 | resolution=dataset_cfg.resolution,
149 | patch_size=model_cfg.patch_size,
150 | encoder_depth=model_cfg.encoder_depth,
151 | encoder_embed_dim=model_cfg.encoder_embed_dim,
152 | encoder_n_heads=model_cfg.encoder_n_heads,
153 | decoder_depth=model_cfg.decoder_depth,
154 | decoder_embed_dim=model_cfg.decoder_embed_dim,
155 | decoder_n_heads=model_cfg.decoder_n_heads,
156 | language_model=model_cfg.language_model,
157 | hf_cache=model_cfg.hf_cache,
158 | language_dim=model_cfg.language_dim,
159 | max_lang_len=dataset_cfg.max_lang_len,
160 | vocab_size=model_cfg.vocab_size,
161 | mae_weight=model_cfg.mae_weight,
162 | lm_weight=model_cfg.lm_weight,
163 | optimizer=model_cfg.optimizer,
164 | schedule=model_cfg.schedule,
165 | base_lr=model_cfg.base_lr,
166 | min_lr=model_cfg.min_lr,
167 | effective_bsz=model_cfg.effective_bsz,
168 | betas=model_cfg.betas,
169 | weight_decay=model_cfg.weight_decay,
170 | warmup_epochs=dataset_cfg.warmup_epochs,
171 | max_epochs=dataset_cfg.max_epochs,
172 | mlp_ratio=model_cfg.mlp_ratio,
173 | norm_pixel_loss=model_cfg.norm_pixel_loss,
174 | )
175 |
176 | else:
177 | raise ValueError(f"Model Architecture `{model_cfg.arch}` is not implemented!")
178 |
179 | # Configure Optimizer --> on same device (CPU)
180 | optimizer, update_lr = model.configure_optimizer()
181 |
182 | return model, optimizer, update_lr
183 |
--------------------------------------------------------------------------------
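
A usage sketch for `get_model_optimizer`: in the repository the two configs come from Hydra (see `voltron/conf`), but any objects exposing the referenced attributes will do. The values below are illustrative ViT-Small-ish stand-ins, not the actual pretraining configuration:

from types import SimpleNamespace

from voltron.models import get_model_optimizer

model_cfg = SimpleNamespace(
    arch="v-mvp", patch_size=16, encoder_depth=12, encoder_embed_dim=384, encoder_n_heads=6,
    decoder_depth=4, decoder_embed_dim=192, decoder_n_heads=6, optimizer="adamw",
    schedule="linear-warmup+cosine-decay", base_lr=1.5e-4, min_lr=0.0, effective_bsz=1024,
    betas=(0.9, 0.95), weight_decay=0.05, mlp_ratio=4.0, norm_pixel_loss=True,
)
dataset_cfg = SimpleNamespace(resolution=224, warmup_epochs=5, max_epochs=100)

model, optimizer, update_lr = get_model_optimizer(model_cfg, dataset_cfg)

The returned `update_lr` is the schedule closure from `voltron/models/util/optimization.py`; the parent training loop is expected to call it with `(epoch, fractional_progress)` before each optimizer step.
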
/voltron/models/materialize.py:
--------------------------------------------------------------------------------
1 | """
2 | materialize.py
3 |
4 | Core functionality for using pretrained models; defines the package-level `load` functionality for downloading and
5 | instantiating pretrained Voltron (and baseline) models.
6 | """
7 | import json
8 | import os
9 | from pathlib import Path
10 | from typing import Callable, List, Tuple
11 |
12 | import gdown
13 | import torch
14 | import torch.nn as nn
15 | import torchvision.transforms as T
16 |
17 | from voltron.models import VMVP, VR3M, VRN3M, VCond, VDual, VGen
18 |
19 | # === Define Useful Variables for Loading Models ===
20 | DEFAULT_CACHE = "cache/"
21 | NORMALIZATION = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
22 |
23 | # Pretrained Model Registry :: "model id" -> {"config" -> gdown ID, "checkpoint" -> gdown ID, "cls" -> Model Class}
24 | MODEL_REGISTRY = {
25 | # === Voltron ViT-Small (Sth-Sth) Models ===
26 | "v-cond": {
27 | "config": "1O4oqRIblfS6PdFlZzUcYIX-Rqe6LbvnD",
28 | "checkpoint": "12g5QckQSMKqrfr4lFY3UPdy7oLw4APpG",
29 | "cls": VCond,
30 | },
31 | "v-dual": {
32 | "config": "1zgKiK81SF9-0lg0XbMZwNhUh1Q7YdZZU",
33 | "checkpoint": "1CCRqrwcvF8xhIbJJmwnCbcWfWTJCK40T",
34 | "cls": VDual,
35 | },
36 | "v-gen": {
37 | "config": "18-mUBDsr-2_-KrGoL2E2YzjcUO8JOwUF",
38 | "checkpoint": "1TzSQpKVKBWKCSvYJf22c45hrKczTQz24",
39 | "cls": VGen,
40 | },
41 | # === Voltron ViT-Base Model ===
42 | "v-cond-base": {
43 | "config": "1CLe7CaIzTEcGCijIgw_S-uqMXHfBFSLI",
44 | "checkpoint": "1PwczOijL0hfYD8DI4xLOPLf1xL_7Kg9S",
45 | "cls": VCond,
46 | },
47 | # === Data-Locked Reproductions ===
48 | "r-mvp": {
49 | "config": "1KKNWag6aS1xkUiUjaJ1Khm9D6F3ROhCR",
50 | "checkpoint": "1-ExshZ6EC8guElOv_s-e8gOJ0R1QEAfj",
51 | "cls": VMVP,
52 | },
53 | "r-r3m-vit": {
54 | "config": "1JGk32BLXwI79uDLAGcpbw0PiupBknf-7",
55 | "checkpoint": "1Yby5oB4oPc33IDQqYxwYjQV3-56hjCTW",
56 | "cls": VR3M,
57 | },
58 | "r-r3m-rn50": {
59 | "config": "1OS3mB4QRm-MFzHoD9chtzSmVhOA-eL_n",
60 | "checkpoint": "1t1gkQYr6JbRSkG3fGqy_9laFg_54IIJL",
61 | "cls": VRN3M,
62 | },
63 | }
64 |
65 |
66 | def available_models() -> List[str]:
67 | return list(MODEL_REGISTRY.keys())
68 |
69 |
70 | def load(
71 | model_id: str, device: torch.device = "cpu", freeze: bool = True, cache: str = DEFAULT_CACHE
72 | ) -> Tuple[nn.Module, Callable[[torch.Tensor], torch.Tensor]]:
73 | """
74 | Download & cache specified model configuration & checkpoint, then load & return module & image processor.
75 |
76 | Note :: We *override* the default `forward()` method of each of the respective model classes with the
77 |     `get_representations()` method --> by default passing "NULL" language for any language-conditioned models.
78 |     This can be overridden either by passing in language (as a `str`) or by invoking the corresponding methods.
79 | """
80 | assert model_id in MODEL_REGISTRY, f"Model ID `{model_id}` not valid, try one of {list(MODEL_REGISTRY.keys())}"
81 |
82 | # Download Config & Checkpoint (if not in cache)
83 | model_cache = Path(cache) / model_id
84 | config_path, checkpoint_path = model_cache / f"{model_id}-config.json", model_cache / f"{model_id}.pt"
85 | os.makedirs(model_cache, exist_ok=True)
86 | if not checkpoint_path.exists() or not config_path.exists():
87 | gdown.download(id=MODEL_REGISTRY[model_id]["config"], output=str(config_path), quiet=False)
88 | gdown.download(id=MODEL_REGISTRY[model_id]["checkpoint"], output=str(checkpoint_path), quiet=False)
89 |
90 | # Load Configuration --> patch `hf_cache` key if present (don't download to random locations on filesystem)
91 | with open(config_path, "r") as f:
92 | model_kwargs = json.load(f)
93 | if "hf_cache" in model_kwargs:
94 | model_kwargs["hf_cache"] = str(Path(cache) / "hf-cache")
95 |
96 | # By default, the model's `__call__` method defaults to `forward` --> for downstream applications, override!
97 | # > Switch `__call__` to `get_representations`
98 | MODEL_REGISTRY[model_id]["cls"].__call__ = MODEL_REGISTRY[model_id]["cls"].get_representations
99 |
100 |     # Materialize Model (load weights from checkpoint; note that the unused element `_` holds the optimizer state...)
101 | model = MODEL_REGISTRY[model_id]["cls"](**model_kwargs)
102 | state_dict, _ = torch.load(checkpoint_path, map_location=device)
103 | model.load_state_dict(state_dict, strict=True)
104 | model.to(device)
105 | model.eval()
106 |
107 | # Freeze model parameters if specified (default: True)
108 | if freeze:
109 | for _, param in model.named_parameters():
110 | param.requires_grad = False
111 |
112 | # Build Visual Preprocessing Transform (assumes image is read into a torch.Tensor, but can be adapted)
113 | if model_id in {"v-cond", "v-dual", "v-gen", "v-cond-base", "r-mvp"}:
114 |         # All models except R3M expect images normalized with the default IN1K (ImageNet-1K) statistics...
115 | preprocess = T.Compose(
116 | [
117 | T.Resize(model_kwargs["resolution"]),
118 | T.CenterCrop(model_kwargs["resolution"]),
119 | T.ConvertImageDtype(torch.float),
120 | T.Normalize(mean=NORMALIZATION[0], std=NORMALIZATION[1]),
121 | ]
122 | )
123 | else:
124 | # R3M models (following original work) expect unnormalized images with values in range [0 - 255)
125 | preprocess = T.Compose(
126 | [
127 | T.Resize(model_kwargs["resolution"]),
128 | T.CenterCrop(model_kwargs["resolution"]),
129 | T.ConvertImageDtype(torch.float),
130 | T.Lambda(lambda x: x * 255.0),
131 | ]
132 | )
133 |
134 | return model, preprocess
135 |
--------------------------------------------------------------------------------
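
A minimal loading sketch (the first call downloads the config and checkpoint into `cache/` via gdown); the random `uint8` image is a stand-in for a frame read with `torchvision.io.read_image`:

import torch

from voltron.models.materialize import available_models, load

print(available_models())  # all pretrained model IDs in MODEL_REGISTRY
model, preprocess = load("v-cond", device="cpu", freeze=True)

img = torch.randint(0, 256, (1, 3, 224, 224), dtype=torch.uint8)
with torch.no_grad():
    representations = model(preprocess(img))  # `__call__` is rerouted to `get_representations`
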
/voltron/models/reproductions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siddk/voltron-robotics/1b299bf5cfa06673a3738aa6e15423b92a9922cd/voltron/models/reproductions/__init__.py
--------------------------------------------------------------------------------
/voltron/models/reproductions/vmvp.py:
--------------------------------------------------------------------------------
1 | """
2 | vmvp.py
3 |
4 | PyTorch Module defining a basic MAE a la Masked Visual Pretraining for Motor Control (MVP), with the requisite
5 | hyperparameters - as defined in the original ImageMAE paper, and as used by both MVP papers.
6 |
7 | References:
8 | - https://github.com/facebookresearch/mae
9 | - https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
10 | """
11 | from typing import Callable, Optional, Tuple
12 |
13 | import torch
14 | import torch.nn as nn
15 | from einops import rearrange
16 |
17 | from voltron.models.util.optimization import get_lr_update
18 | from voltron.models.util.transformer import Block, PatchEmbed, get_2D_position_embeddings
19 |
20 |
21 | class VMVP(nn.Module):
22 | def __init__(
23 | self,
24 | resolution: int,
25 | patch_size: int,
26 | encoder_depth: int,
27 | encoder_embed_dim: int,
28 | encoder_n_heads: int,
29 | decoder_depth: int,
30 | decoder_embed_dim: int,
31 | decoder_n_heads: int,
32 | optimizer: str,
33 | schedule: str,
34 | base_lr: float,
35 | min_lr: float,
36 | effective_bsz: float,
37 | betas: Tuple[float, float],
38 | weight_decay: float,
39 | warmup_epochs: int,
40 | max_epochs: int,
41 | mask_ratio: float = 0.75,
42 | mlp_ratio: float = 4.0,
43 | in_channels: int = 3,
44 | norm_pixel_loss: bool = True,
45 | ):
46 | """
47 |         Initialize a VMVP (MAE) model with the requisite architecture parameters.
48 |
49 | :param resolution: Base image resolution -- usually 224 (ImageNet size).
50 | :param patch_size: Height/Width of each patch in pixels -- usually 16.
51 | :param encoder_depth: Number of Transformer blocks in the encoder -- should be greater than decoder.
52 | :param encoder_embed_dim: Core embedding/hidden dimension for encoder vision transformer backbone.
53 | :param encoder_n_heads: Number of heads for encoder multi-headed self-attention.
54 | :param decoder_depth: Number of Transformer blocks in the decoder -- should be relatively shallow.
55 |         :param decoder_embed_dim: Core embedding/hidden dimension for decoder vision transformer backbone.
56 |         :param decoder_n_heads: Number of heads for decoder multi-headed self-attention.
57 | :param optimizer: String denoting which optimizer to use (for MAEs, usually `adamw`)
58 | :param schedule: Learning rate schedule to use; for Transformers a linear warmup + decay is recommended!
59 | :param base_lr: Base learning rate, to be scaled via a linear scaling rule (from scaling laws).
60 | :param min_lr: Minimum learning rate to decay to over the course of learning (usually 0.0)
61 | :param effective_bsz: Global batch size for update, dictates the scaling of the base_lr.
62 |         :param betas: Adam optimizer betas (only applicable for `adam` and `adamw`). Prevents early loss spiking.
63 | :param weight_decay: Weight decay for global weight regularization (only applied to non-bias, non-LN layers).
64 | :param warmup_epochs: Number of epochs to warmup learning rate for linear warmup schedule.
65 | :param max_epochs: Total number of training epochs to be run.
66 | :param mask_ratio: Ratio for number of patches to mask out for MAE -- should be fairly high!
67 | :param mlp_ratio: Ratio for embedding size to Position-wise FeedForward MLP (gets shrunk back down).
68 | :param in_channels: Default number of channels in the base image -- almost always 3.
69 | :param norm_pixel_loss: Normalize decoder pixel targets for reconstruction (better perf, not interpretable).
70 | """
71 | super().__init__()
72 | self.resolution, self.patch_size, self.mask_ratio = resolution, patch_size, mask_ratio
73 | self.in_channels, self.norm_pixel_loss, self.mlp_ratio = in_channels, norm_pixel_loss, mlp_ratio
74 | self.optimizer, self.schedule, self.betas, self.weight_decay = optimizer, schedule, betas, weight_decay
75 | self.lr, self.base_lr, self.min_lr, self.effective_bsz = None, base_lr, min_lr, effective_bsz
76 | self.warmup_epochs, self.max_epochs = warmup_epochs, max_epochs
77 |
78 | # Encoder/Decoder Parameters
79 | self.encoder_depth, self.decoder_depth = encoder_depth, decoder_depth
80 | self.encoder_embed_dim, self.encoder_n_heads = encoder_embed_dim, encoder_n_heads
81 | self.decoder_embed_dim, self.decoder_n_heads = decoder_embed_dim, decoder_n_heads
82 |
83 | # MAE Encoder Parameters --> MVP uses a CLS Token for feature extraction!
84 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.encoder_embed_dim))
85 | self.patch2embed = PatchEmbed(
86 | self.resolution, self.patch_size, self.encoder_embed_dim, in_channels=self.in_channels
87 | )
88 | self.encoder_pe = nn.Parameter(
89 | torch.zeros(1, self.patch2embed.num_patches + 1, self.encoder_embed_dim), requires_grad=False
90 | )
91 | self.encoder_blocks = nn.ModuleList(
92 | [Block(self.encoder_embed_dim, self.encoder_n_heads, self.mlp_ratio) for _ in range(self.encoder_depth)]
93 | )
94 | self.encoder_norm = nn.LayerNorm(self.encoder_embed_dim, eps=1e-6)
95 |
96 | # Projection from Encoder to Decoder
97 | self.encoder2decoder = nn.Linear(self.encoder_embed_dim, self.decoder_embed_dim)
98 |
99 | # MAE Decoder Parameters -- Remember the CLS Token!
100 | self.mask_token = nn.Parameter(torch.zeros(1, 1, self.decoder_embed_dim))
101 | self.decoder_pe = nn.Parameter(
102 | torch.zeros(1, self.patch2embed.num_patches + 1, self.decoder_embed_dim), requires_grad=False
103 | )
104 | self.decoder_blocks = nn.ModuleList(
105 | [Block(self.decoder_embed_dim, self.decoder_n_heads, self.mlp_ratio) for _ in range(self.decoder_depth)]
106 | )
107 | self.decoder_norm = nn.LayerNorm(self.decoder_embed_dim, eps=1e-6)
108 | self.decoder_prediction = nn.Linear(self.decoder_embed_dim, (patch_size**2) * in_channels, bias=True)
109 |
110 | # Initialize all Weights
111 | self.initialize_weights()
112 |
113 | def initialize_weights(self) -> None:
114 | # Position Encoding -- Fixed 2D Sine-Cosine Embeddings
115 | enc_pe = get_2D_position_embeddings(self.encoder_embed_dim, int(self.patch2embed.num_patches**0.5), True)
116 | self.encoder_pe.data.copy_(torch.from_numpy(enc_pe).float().unsqueeze(0))
117 | dec_pe = get_2D_position_embeddings(self.decoder_embed_dim, int(self.patch2embed.num_patches**0.5), True)
118 | self.decoder_pe.data.copy_(torch.from_numpy(dec_pe).float().unsqueeze(0))
119 |
120 | # Initialize PatchEmbedding as a Linear...
121 | nn.init.xavier_uniform_(self.patch2embed.proj.weight.data.view([self.patch2embed.proj.weight.data.shape[0], -1]))
122 |
123 | # Initialize CLS Token & Mask Token w/ Truncated Normal
124 | nn.init.normal_(self.cls_token, std=0.02)
125 | nn.init.normal_(self.mask_token, std=0.02)
126 |
127 | # Everything else...
128 | self.apply(self.transformer_initializer)
129 |
130 | @staticmethod
131 | def transformer_initializer(m: nn.Module) -> None:
132 | if isinstance(m, nn.Linear):
133 | # Use xavier_uniform following Jax ViT
134 | torch.nn.init.xavier_uniform_(m.weight)
135 | if isinstance(m, nn.Linear) and m.bias is not None:
136 | nn.init.constant_(m.bias, 0.0)
137 | elif isinstance(m, nn.LayerNorm):
138 | nn.init.constant_(m.weight, 1.0)
139 | nn.init.constant_(m.bias, 0.0)
140 |
141 | def mask(
142 | self, patches: torch.Tensor, mask_ratio: Optional[float] = None
143 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
144 | """Perform per-sample random masking by shuffling :: uses argsort random noise to identify masked patches"""
145 | bsz, n_patches, embed_dim = patches.shape
146 | if mask_ratio is not None:
147 | n_keep = int(n_patches * (1 - mask_ratio))
148 | else:
149 | n_keep = int(n_patches * (1 - self.mask_ratio))
150 |
151 | # Sample some noise of n_patches size, argsort to get shuffled IDs (keep small), argsort again to "unshuffle"
152 | # > For clarity -- argsort is an invertible transformation (if argsort `restore`, recovers `shuffle`)
153 | shuffle_idxs = torch.argsort(torch.rand(bsz, n_patches, device=patches.device), dim=1)
154 | restore_idxs = torch.argsort(shuffle_idxs, dim=1)
155 |
156 | # Get "keep" (visible) patches
157 | visible_patches = torch.gather(patches, dim=1, index=shuffle_idxs[:, :n_keep, None].repeat(1, 1, embed_dim))
158 |
159 | # Generate the binary mask --> IMPORTANT :: `0` is keep, `1` is remove (following FAIR MAE convention)
160 | mask = torch.ones(bsz, n_patches, device=patches.device)
161 | mask[:, :n_keep] = 0
162 | mask = torch.gather(mask, dim=1, index=restore_idxs)
163 |
164 | return visible_patches, mask, restore_idxs
165 |
166 | def get_representations(self, img: torch.Tensor, mode: str = "patch") -> torch.Tensor:
167 | """
168 | Given a single image, extract representations subject to the specified mode in < patch | cls >, where "cls"
169 |         denotes extracting the <CLS> token embedding; for our experiments, we find that running multiheaded attention
170 | pooling on top of the "patch" embeddings is *always* better!
171 |
172 | :param img: Processed batch of images :: [bsz, 3, 224, 224]
173 |         :param mode: Type of representation to extract -- `patch` (sequence of patch embeddings) or `cls` (<CLS> token embedding).
174 |
175 | :return: Extracted representations given img input.
176 | """
177 | assert img.ndim == 4, "Invalid input to `get_representations()`"
178 | assert mode in {"patch", "cls"}, f"Extraction mode `{mode}` not supported!"
179 |
180 | # Extract desired representations
181 | representations = self.encode(img)
182 | return representations[:, 1:] if mode == "patch" else representations[:, :1]
183 |
184 | def encode(self, img: torch.Tensor) -> torch.Tensor:
185 | """Run a single image through the MAE and extract patch embeddings."""
186 |
187 | # Note: All of this code is taken near-verbatim from the MVP repository...
188 | # > Ref: https://github.com/ir413/mvp/blob/master/mvp/backbones/vit.py#L30
189 | patches = self.patch2embed(img)
190 | cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
191 |         cls_patches = torch.cat([cls_tokens, patches], dim=1) + self.encoder_pe
192 |
193 | # Apply Transformer Blocks...
194 | for block in self.encoder_blocks:
195 | cls_patches = block(cls_patches)
196 | cls_patches = self.encoder_norm(cls_patches)
197 | return cls_patches
198 |
199 | def forward_encoder(
200 | self, imgs: torch.Tensor, mask_ratio: Optional[float] = None
201 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
202 | # Patchify + Position Embedding (without the CLS Token)
203 | patches = self.patch2embed(imgs)
204 | patches_pe = patches + self.encoder_pe[:, 1:, :]
205 |
206 | # Create mask (and go ahead and mask out patches at the same time)
207 | visible_patches, mask, restore_idxs = self.mask(patches_pe, mask_ratio)
208 |
209 | # Add the CLS Token
210 | cls_token = self.cls_token + self.encoder_pe[:, :1, :]
211 | cls_tokens = cls_token.expand(imgs.shape[0], -1, -1)
212 | cls_visible_patches = torch.cat([cls_tokens, visible_patches], dim=1)
213 |
214 | # Apply Transformer Blocks...
215 | for block in self.encoder_blocks:
216 | cls_visible_patches = block(cls_visible_patches)
217 | cls_visible_patches = self.encoder_norm(cls_visible_patches)
218 |
219 | return cls_visible_patches, mask, restore_idxs
220 |
221 | def forward_decoder(self, visible_patches: torch.Tensor, restore_idxs: torch.Tensor) -> torch.Tensor:
222 | # Project patches into decoder embedding dimension
223 | projected_patches = self.encoder2decoder(visible_patches)
224 |
225 | # Add Mask Tokens to Sequence
226 | mask_tokens = self.mask_token.repeat(
227 | projected_patches.shape[0], restore_idxs.shape[1] - visible_patches.shape[1] + 1, 1
228 | )
229 |
230 | # Remove & add back CLS Token as part of the "unshuffling"
231 | concatenated_patches = torch.cat([projected_patches[:, 1:, :], mask_tokens], dim=1) # Skip CLS Token
232 | unshuffled_patches = torch.gather(
233 | concatenated_patches, dim=1, index=restore_idxs[..., None].repeat(1, 1, self.decoder_embed_dim)
234 | )
235 | cls_unshuffled_patches = torch.cat([projected_patches[:, :1, :], unshuffled_patches], dim=1) # Add CLS Token
236 |
237 | # Add Position Embeddings
238 | cls_decoder_patches = cls_unshuffled_patches + self.decoder_pe
239 |
240 | # Apply Transformer Blocks...
241 | for block in self.decoder_blocks:
242 | cls_decoder_patches = block(cls_decoder_patches)
243 | cls_decoder_patches = self.decoder_norm(cls_decoder_patches)
244 |
245 | # Run final projection, remove the CLS token, and return
246 | cls_decoder_prediction = self.decoder_prediction(cls_decoder_patches)
247 | decoder_prediction = cls_decoder_prediction[:, 1:, :]
248 | return decoder_prediction
249 |
250 | def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
251 | """Convert a batch of images to their patched equivalents, by naive reshaping"""
252 | return rearrange(
253 | imgs,
254 | "bsz c (height patch_h) (width patch_w) -> bsz (height width) (patch_h patch_w c)",
255 | patch_h=self.patch_size,
256 | patch_w=self.patch_size,
257 | )
258 |
259 | def compute_loss(self, imgs: torch.Tensor, reconstructions: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
260 | assert self.norm_pixel_loss, "`norm_pixel_loss` should always be true... false only for visualizations!"
261 | targets = self.patchify(imgs)
262 |
263 | # Normalize targets...
264 | mu, var = targets.mean(dim=-1, keepdim=True), targets.var(dim=-1, unbiased=True, keepdim=True)
265 | targets = (targets - mu) / ((var + 1e-6) ** 0.5)
266 |
267 | # Compute mean loss per patch first...
268 | mse = (reconstructions - targets) ** 2
269 | avg_loss_per_patch = mse.mean(dim=-1)
270 |
271 | # Compute mean loss only on *removed* patches and return
272 | return (avg_loss_per_patch * mask).sum() / mask.sum()
273 |
274 | def forward(
275 | self, imgs: torch.Tensor, mask_ratio: Optional[float] = None
276 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
277 | visible_patches, mask, restore_idxs = self.forward_encoder(imgs, mask_ratio)
278 | reconstructions = self.forward_decoder(visible_patches, restore_idxs)
279 | loss = self.compute_loss(imgs, reconstructions, mask)
280 |
281 | return loss, reconstructions, mask
282 |
283 | def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable[[int, float], float]]:
284 | # Short-Circuit on Valid Optimizers
285 | if self.optimizer not in ["adamw"]:
286 | raise NotImplementedError(f"Optimizer `{self.optimizer}` not supported - try [`adamw`] instead!")
287 |
288 | # Create Parameter Groups --> Bias terms, Normalization layer parameters shouldn't be decayed...
289 | # > This is a compact rewrite of `param_groups_weight_decay()` from TIMM because I don't want the dependency
290 | decay, no_decay = [], []
291 | for name, param in self.named_parameters():
292 | if not param.requires_grad:
293 | continue
294 |
295 | # Check on any parameters with fewer than 2 dimensions or with "bias" in the name...
296 | if param.ndim <= 1 or name.endswith(".bias"):
297 | no_decay.append(param)
298 | else:
299 | decay.append(param)
300 |
301 | # Build Parameter Groups
302 | groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}]
303 |
304 | # Compute LR -- MAE uses the `linear scaling rule` :: lr = base_lr * (effective_bsz / 256)
305 | # > https://github.com/facebookresearch/mae/blob/main/PRETRAIN.md
306 | self.lr = self.base_lr * (self.effective_bsz / 256)
307 |
308 | # Create Optimizer & LR Scheduler
309 | optimizer = torch.optim.AdamW(groups, lr=self.lr, betas=self.betas)
310 | update_lr = get_lr_update(optimizer, self.schedule, self.lr, self.min_lr, self.warmup_epochs, self.max_epochs)
311 | return optimizer, update_lr
312 |
--------------------------------------------------------------------------------
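
The argsort-based masking in `mask()` can be sanity-checked in isolation; a small sketch with illustrative sizes (196 patches, 75% masked):

import torch

bsz, n_patches, embed_dim, mask_ratio = 2, 196, 768, 0.75
patches = torch.randn(bsz, n_patches, embed_dim)

n_keep = int(n_patches * (1 - mask_ratio))                       # 49 visible patches per sample
shuffle_idxs = torch.argsort(torch.rand(bsz, n_patches), dim=1)  # random permutation per sample
restore_idxs = torch.argsort(shuffle_idxs, dim=1)                # inverse permutation ("unshuffle")

visible = torch.gather(patches, 1, shuffle_idxs[:, :n_keep, None].repeat(1, 1, embed_dim))
mask = torch.ones(bsz, n_patches)
mask[:, :n_keep] = 0
mask = torch.gather(mask, 1, restore_idxs)                       # 1 = masked, 0 = kept (FAIR MAE convention)

assert visible.shape == (bsz, n_keep, embed_dim) and int(mask.sum()) == bsz * (n_patches - n_keep)
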
/voltron/models/reproductions/vrn3m.py:
--------------------------------------------------------------------------------
1 | """
2 | vrn3m.py
3 |
4 | PyTorch Module defining an R3M model (with a ResNet-50 encoder), exactly as described in Nair et al. 2022, with all the
5 | requisite hyperparameters.
6 |
7 | Reference:
8 | - https://github.com/facebookresearch/r3m
9 | """
10 | from typing import Callable, Tuple
11 |
12 | import torch
13 | import torch.nn as nn
14 | import transformers
15 | from einops import rearrange
16 | from torchvision.models import resnet50
17 |
18 | from voltron.models.util.optimization import get_lr_update
19 |
20 | # Suppress Transformers Logging
21 | transformers.logging.set_verbosity_error()
22 |
23 |
24 | class VRN3M(nn.Module):
25 | def __init__(
26 | self,
27 | resolution: int,
28 | fc_dim: int,
29 | language_model: str,
30 | hf_cache: str,
31 | language_dim: int,
32 | reward_dim: int,
33 | n_negatives: int,
34 | lang_reward_weight: float,
35 | tcn_weight: float,
36 | l1_weight: float,
37 | l2_weight: float,
38 | optimizer: str,
39 | lr: float,
40 | eps: float = 1e-8,
41 | ):
42 | """
43 |         Initialize a ResNet-50 R3M model with the required architecture parameters.
44 |
45 | :param resolution: Base image resolution -- usually 224 (ImageNet size).
46 | :param fc_dim: Dimensionality of the pooled embedding coming out of the ResNet (for RN50, fc_dim = 2048)
47 | :param language_model: Language model to freeze for encoding narrations/utterances.
48 | :param hf_cache: Cache directory to store pretrained models, for safe distributed training.
49 | :param language_dim: Dimensionality of the language embedding coming out of the pretrained LM.
50 | :param reward_dim: Hidden layer dimensionality for the language-reward MLP.
51 | :param n_negatives: Number of cross-batch negatives to sample for contrastive learning.
52 | :param lang_reward_weight: Weight applied to the contrastive "language alignment" loss term.
53 | :param tcn_weight: Weight applied to the time contrastive loss term.
54 | :param l1_weight: Weight applied to the L1 regularization loss term.
55 | :param l2_weight: Weight applied to the L2 regularization loss term.
56 | :param optimizer: String denoting which optimizer to use (for R3M, usually `adam`).
57 | :param lr: Learning rate (fixed for ResNet R3M models) for training.
58 | :param eps: Epsilon for preventing divide by zero in the InfoNCE loss terms.
59 | """
60 | super().__init__()
61 | self.resolution, self.fc_dim, self.n_negatives, self.eps = resolution, fc_dim, n_negatives, eps
62 | self.language_dim, self.reward_dim, self.optimizer, self.lr = language_dim, reward_dim, optimizer, lr
63 | self.embed_dim = self.fc_dim
64 |
65 | # Weights for each loss term
66 | self.lang_reward_weight, self.tcn_weight = lang_reward_weight, tcn_weight
67 | self.l1_weight, self.l2_weight = l1_weight, l2_weight
68 |
69 | # Create ResNet50 --> set `rn.fc` to the Identity() to extract final features of dim = `fc_dim`
70 | self.resnet = resnet50(weights=None)
71 | self.resnet.fc = nn.Identity()
72 | self.resnet.train()
73 |
74 | # Create Language Reward Model
75 | self.language_reward = nn.Sequential(
76 | nn.Linear(self.fc_dim + self.fc_dim + self.language_dim, self.reward_dim),
77 | nn.ReLU(),
78 | nn.Linear(self.reward_dim, self.reward_dim),
79 | nn.ReLU(),
80 | nn.Linear(self.reward_dim, self.reward_dim),
81 | nn.ReLU(),
82 | nn.Linear(self.reward_dim, self.reward_dim),
83 | nn.ReLU(),
84 | nn.Linear(self.reward_dim, 1),
85 | nn.Sigmoid(),
86 | )
87 |
88 | # Create Language Model & Language Reward MLP --> LM has requires_grad = False
89 | # > For BERT models, our "embedding" is just going to be the last hidden state
90 | # > Assumes inputs to forward pass are pre-tokenized!
91 | self.tokenizer = transformers.AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)
92 | self.lm = transformers.AutoModel.from_pretrained(language_model, cache_dir=hf_cache)
93 | self.lm.eval()
94 |
95 | # Shape Assertion -- make sure self.language_dim actually is the same as the LM dimension!
96 | assert self.lm.config.dim == self.language_dim, "Language model embedding dimension != self.language_dim!"
97 |
98 | # Freeze the LM
99 | for _name, param in self.lm.named_parameters():
100 | param.requires_grad = False
101 |
102 | def get_representations(self, img: torch.Tensor) -> torch.Tensor:
103 | """
104 | Given a single image, extract R3M "default" (ResNet pooled) dense representation.
105 |
106 | :param img: Processed batch of images :: [bsz, 3, 224, 224]
107 | :return: Extracted R3M dense representation given img input.
108 | """
109 | assert img.ndim == 4, "Invalid input to `get_representations()`"
110 | representation = self.resnet(img)
111 | return representation.unsqueeze(1)
112 |
113 | def encode_images(self, imgs: torch.Tensor) -> torch.Tensor:
114 | """Feed images through ResNet-50 to get single embedding after global average pooling."""
115 | return self.resnet(imgs)
116 |
117 | def encode_language(self, lang: torch.Tensor, lang_mask: torch.Tensor) -> torch.Tensor:
118 | """Encode language by feeding the *pre-tokenized text* through the frozen language model."""
119 | self.lm.eval()
120 | with torch.no_grad():
121 | transformer_embeddings = self.lm(lang, attention_mask=lang_mask).last_hidden_state
122 | return transformer_embeddings.mean(dim=1)
123 |
124 | def get_reward(self, initial: torch.Tensor, later: torch.Tensor, lang: torch.Tensor) -> torch.Tensor:
125 | return self.language_reward(torch.cat([initial, later, lang], dim=-1)).squeeze()
126 |
127 | def extract_features(self, img: torch.Tensor) -> torch.Tensor:
128 | """Run a single image of shape [1, 3, 224, 224] through the ResNet and extract the feature."""
129 | return self.encode_images(img).detach()
130 |
131 | def forward(self, imgs: torch.Tensor, lang: torch.Tensor, lang_mask: torch.Tensor) -> Tuple[torch.Tensor, ...]:
132 | """
133 | Run a forward pass through the model, computing the *full* R3M loss -- the TCN contrastive loss, the Language
134 | Alignment loss, and both sparsity losses, as well as the full loss (which will get optimized)!
135 |
136 | :param imgs: A [bsz, 5, in_channels, resolution, resolution] tensor of (start, i, j, k, end) sequences.
137 | :param lang: Tokenized language of dimensionality [bsz, seq_len] to be fed to the language model.
138 | :param lang_mask: Attention mask computed by the tokenizer, as a result of padding to the max_seq_len.
139 |
140 | :return: Tuple of losses, as follows:
141 | > (combined_loss, tcn_loss, reward_loss, l1_loss, l2_loss, tcn_acc, reward_acc)
142 | """
143 | # Encode each image separately... feed to transformer... then reshape
144 | all_images = rearrange(imgs, "bsz n_states c res1 res2 -> (bsz n_states) c res1 res2", n_states=5)
145 | all_embeddings = self.encode_images(all_images)
146 | initial, state_i, state_j, state_k, final = rearrange(
147 | all_embeddings, "(bsz n_states) embed -> n_states bsz embed", n_states=5
148 | )
149 |
150 | # Compute Regularization Losses
151 | l1_loss = torch.linalg.norm(all_embeddings, ord=1, dim=-1).mean()
152 | l2_loss = torch.linalg.norm(all_embeddings, ord=2, dim=-1).mean()
153 |
154 | # Compute TCN Loss
155 | tcn_loss, tcn_acc = self.get_time_contrastive_loss(state_i, state_j, state_k)
156 |
157 | # Compute Language Alignment/Predictive Loss
158 | lang_reward_loss, rew_acc = self.get_reward_loss(lang, lang_mask, initial, state_i, state_j, state_k, final)
159 |
160 | # Compute full weighted loss & return...
161 | loss = (
162 | (self.l1_weight * l1_loss)
163 | + (self.l2_weight * l2_loss)
164 | + (self.tcn_weight * tcn_loss)
165 | + (self.lang_reward_weight * lang_reward_loss)
166 | )
167 | return loss, tcn_loss, lang_reward_loss, l1_loss, l2_loss, tcn_acc, rew_acc
168 |
169 | @staticmethod
170 | def time_similarity(state_x: torch.Tensor, state_y: torch.Tensor, use_l2: bool = True) -> torch.Tensor:
171 | """Computes similarity between embeddings via -L2 distance."""
172 | assert use_l2, "Non-L2 time-similarity functions not yet implemented!"
173 | return -torch.linalg.norm(state_x - state_y, dim=-1)
174 |
175 | def get_time_contrastive_loss(
176 | self, state_i: torch.Tensor, state_j: torch.Tensor, state_k: torch.Tensor
177 | ) -> Tuple[torch.Tensor, ...]:
178 | """Evaluates the Time-Contrastive Loss, computed via InfoNCE."""
179 |
180 | # *Punchline* - we want `sim(i, j)` to be higher than `sim(i, k)` for some k > j (goes both ways)
181 |         #
182 |         # > As our positive examples --> we sample (s_i, s_j) and (s_j, s_k).
183 | # > Our negatives --> other pairs from the triplet, cross-batch negatives!
184 | sim_i_j_exp = torch.exp(self.time_similarity(state_i, state_j))
185 | sim_j_k_exp = torch.exp(self.time_similarity(state_j, state_k))
186 |
187 | # Add a "hard" negative!
188 | neg_i_k_exp = torch.exp(self.time_similarity(state_i, state_k))
189 |
190 | # Obtain *cross-batch* negatives
191 | bsz, neg_i, neg_j = state_i.shape[0], [], []
192 | for _ in range(self.n_negatives):
193 | neg_idx = torch.randperm(bsz)
194 | state_i_shuf = state_i[neg_idx]
195 | neg_idx = torch.randperm(bsz)
196 | state_j_shuf = state_j[neg_idx]
197 | neg_i.append(self.time_similarity(state_i, state_i_shuf))
198 | neg_j.append(self.time_similarity(state_j, state_j_shuf))
199 | neg_i_exp, neg_j_exp = torch.exp(torch.stack(neg_i, -1)), torch.exp(torch.stack(neg_j, -1))
200 |
201 | # Compute InfoNCE
202 | denominator_i = sim_i_j_exp + neg_i_k_exp + neg_i_exp.sum(-1)
203 | denominator_j = sim_j_k_exp + neg_i_k_exp + neg_j_exp.sum(-1)
204 | nce_i = -torch.log(self.eps + (sim_i_j_exp / (self.eps + denominator_i)))
205 | nce_j = -torch.log(self.eps + (sim_j_k_exp / (self.eps + denominator_j)))
206 | nce = (nce_i + nce_j) / 2
207 |
208 | # Compute "accuracy"
209 | i_j_acc = (1.0 * (sim_i_j_exp > neg_i_k_exp)).mean()
210 | j_k_acc = (1.0 * (sim_j_k_exp > neg_i_k_exp)).mean()
211 | acc = (i_j_acc + j_k_acc) / 2
212 |
213 | return nce.mean(), acc
214 |
215 | def get_reward_loss(
216 | self,
217 | lang: torch.Tensor,
218 | lang_mask: torch.Tensor,
219 | initial: torch.Tensor,
220 | state_i: torch.Tensor,
221 | state_j: torch.Tensor,
222 | state_k: torch.Tensor,
223 | final: torch.Tensor,
224 | ) -> Tuple[torch.Tensor, ...]:
225 | """Evaluates the Language-Alignment Reward Loss, computed via InfoNCE."""
226 | lang_embed = self.encode_language(lang, lang_mask)
227 |
228 |         # *Punchline* - we want `Reward(s_0, s_t, l)` to be higher than `Reward(s_0, s_t', l)` for any earlier t' < t
229 |         #
230 |         # > As our positive examples --> we sample s_j, s_k, and s_final (excluding s_i)
231 | pos_final_exp = torch.exp(self.get_reward(initial, final, lang_embed))
232 | pos_j_exp = torch.exp(self.get_reward(initial, state_j, lang_embed))
233 | pos_k_exp = torch.exp(self.get_reward(initial, state_k, lang_embed))
234 |
235 | # Add the within-context negatives <--> these are the most informative examples!
236 |         # > We use (initial, initial) as a negative for the first one, just to get the reward model to "capture progress"
237 | negs_final = [self.get_reward(initial, initial, lang_embed)]
238 | negs_j = [self.get_reward(initial, state_i, lang_embed)]
239 | negs_k = [self.get_reward(initial, state_j, lang_embed)]
240 |
241 | # Cross Batch Negatives -- same as positives (indexing), but from a different batch!
242 | # > @SK :: Unclear how well this will unroll on TPUs...
243 | bsz = initial.shape[0]
244 | for _ in range(self.n_negatives):
245 | # We get three random indices to further minimize correlation... from the R3M codebase!
246 | neg_idx = torch.randperm(bsz)
247 | negs_final.append(self.get_reward(initial[neg_idx], final[neg_idx], lang_embed))
248 | neg_idx = torch.randperm(bsz)
249 | negs_j.append(self.get_reward(initial[neg_idx], state_j[neg_idx], lang_embed))
250 | neg_idx = torch.randperm(bsz)
251 | negs_k.append(self.get_reward(initial[neg_idx], state_k[neg_idx], lang_embed))
252 |
253 | # Flatten & exponentiate; get ready for the InfoNCE
254 | negs_final, negs_j, negs_k = torch.stack(negs_final, -1), torch.stack(negs_j, -1), torch.stack(negs_k, -1)
255 | negs_final_exp, negs_j_exp, negs_k_exp = torch.exp(negs_final), torch.exp(negs_j), torch.exp(negs_k)
256 |
257 | # Compute InfoNCE
258 | denominator_final = pos_final_exp + negs_final_exp.sum(-1)
259 | denominator_j = pos_j_exp + negs_j_exp.sum(-1)
260 | denominator_k = pos_k_exp + negs_k_exp.sum(-1)
261 |
262 | nce_final = -torch.log(self.eps + (pos_final_exp / (self.eps + denominator_final)))
263 | nce_j = -torch.log(self.eps + (pos_j_exp / (self.eps + denominator_j)))
264 | nce_k = -torch.log(self.eps + (pos_k_exp / (self.eps + denominator_k)))
265 |
266 | # Compute "accuracy"
267 | acc_final = (1.0 * (negs_final_exp.max(dim=-1)[0] < pos_final_exp)).mean()
268 | acc_j = (1.0 * (negs_j_exp.max(dim=-1)[0] < pos_j_exp)).mean()
269 | acc_k = (1.0 * (negs_k_exp.max(dim=-1)[0] < pos_k_exp)).mean()
270 | acc = (acc_final + acc_j + acc_k) / 3
271 | nce = (nce_final + nce_j + nce_k) / 3
272 |
273 | return nce.mean(), acc
274 |
275 | def configure_optimizer(self) -> Tuple[torch.optim.Optimizer, Callable[[int, float], float]]:
276 | # Short-Circuit on Valid Optimizers
277 | if self.optimizer not in ["adam"]:
278 | raise NotImplementedError(f"Optimizer `{self.optimizer}` not supported - try [`adam`] instead!")
279 |
280 | # Create Optimizer and (No-Op) LR Scheduler
281 | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
282 | update_lr = get_lr_update(
283 | optimizer, schedule="none", lr=self.lr, min_lr=self.lr, warmup_epochs=-1, max_epochs=-1
284 | )
285 | return optimizer, update_lr
286 |
--------------------------------------------------------------------------------
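
A forward-pass sketch; the hyperparameter values below are illustrative stand-ins (the real ones live in the Hydra configs under `voltron/conf`), and `distilbert-base-uncased` is assumed here only because its hidden size matches `language_dim=768`:

import torch

from voltron.models.reproductions.vrn3m import VRN3M

r3m = VRN3M(
    resolution=224, fc_dim=2048, language_model="distilbert-base-uncased", hf_cache="cache/hf-cache",
    language_dim=768, reward_dim=1024, n_negatives=3, lang_reward_weight=1.0, tcn_weight=1.0,
    l1_weight=1e-5, l2_weight=1e-5, optimizer="adam", lr=1e-4,
)

imgs = torch.rand(4, 5, 3, 224, 224)  # (start, i, j, k, end) frame quintuples
tokens = r3m.tokenizer(["peel the carrot"] * 4, return_tensors="pt", padding=True)
loss, tcn_loss, reward_loss, l1, l2, tcn_acc, rew_acc = r3m(imgs, tokens["input_ids"], tokens["attention_mask"])
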
/voltron/models/util/__init__.py:
--------------------------------------------------------------------------------
1 | from .extraction import instantiate_extractor
2 |
--------------------------------------------------------------------------------
/voltron/models/util/extraction.py:
--------------------------------------------------------------------------------
1 | """
2 | extraction.py
3 |
4 | General Extraction module definitions & associated utilities.
5 |
6 | References:
7 |     - Set Transformer (MAP): https://arxiv.org/abs/1810.00825
8 | """
9 | from typing import Callable
10 |
11 | import torch
12 | import torch.nn as nn
13 | from einops import repeat
14 |
15 | from voltron.models.util.transformer import RMSNorm, SwishGLU
16 |
17 | # === Multiheaded Attention Pooling ===
18 |
19 |
20 | # As defined in Set Transformer (Lee et al., 2019) -- standard multiheaded cross-attention, additionally taking in
21 | # a set of $k$ learned "seed vectors" that are used to "pool" information.
22 | class MAPAttention(nn.Module):
23 | def __init__(self, embed_dim: int, n_heads: int) -> None:
24 | """Multi-Input Multi-Headed Attention Operation"""
25 | super().__init__()
26 | assert embed_dim % n_heads == 0, "`embed_dim` must be divisible by `n_heads`!"
27 | self.n_heads, self.scale = n_heads, (embed_dim // n_heads) ** -0.5
28 |
29 | # Projections (no bias) --> separate for Q (seed vector), and KV ("pool" inputs)
30 | self.q, self.kv = nn.Linear(embed_dim, embed_dim, bias=False), nn.Linear(embed_dim, 2 * embed_dim, bias=False)
31 | self.proj = nn.Linear(embed_dim, embed_dim)
32 |
33 | def forward(self, seed: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
34 | (B_s, K, C_s), (B_x, N, C_x) = seed.shape, x.shape
35 | assert C_s == C_x, "Seed vectors and pool inputs must have the same embedding dimensionality!"
36 |
37 | # Project Seed Vectors to `queries`
38 | q = self.q(seed).reshape(B_s, K, self.n_heads, C_s // self.n_heads).permute(0, 2, 1, 3)
39 | kv = self.kv(x).reshape(B_x, N, 2, self.n_heads, C_x // self.n_heads).permute(2, 0, 3, 1, 4)
40 | k, v = kv.unbind(0)
41 |
42 | # Attention --> compute weighted sum over values!
43 | scores = q @ (k.transpose(-2, -1) * self.scale)
44 | attn = scores.softmax(dim=-1)
45 | vals = (attn @ v).transpose(1, 2).reshape(B_s, K, C_s)
46 |
47 | # Project back to `embed_dim`
48 | return self.proj(vals)
49 |
50 |
51 | class MAPBlock(nn.Module):
52 | def __init__(
53 | self,
54 | n_latents: int,
55 | embed_dim: int,
56 | n_heads: int,
57 | mlp_ratio: float = 4.0,
58 | do_rms_norm: bool = True,
59 | do_swish_glu: bool = True,
60 | ) -> None:
61 | """Multiheaded Attention Pooling Block -- note that for MAP, we adopt earlier post-norm conventions."""
62 | super().__init__()
63 | self.n_latents, self.embed_dim, self.n_heads = n_latents, embed_dim, 2 * n_heads
64 |
65 | # Projection Operator
66 | self.projection = nn.Linear(embed_dim, self.embed_dim)
67 |
68 | # Initialize Latents
69 | self.latents = nn.Parameter(torch.zeros(self.n_latents, self.embed_dim))
70 | nn.init.normal_(self.latents, std=0.02)
71 |
72 | # Custom MAP Attention (seed, encoder outputs) -> seed
73 | self.attn_norm = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6)
74 | self.attn = MAPAttention(self.embed_dim, n_heads=self.n_heads)
75 |
76 | # Position-wise Feed-Forward Components
77 | self.mlp_norm = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6)
78 | self.mlp = nn.Sequential(
79 | # Handle SwishGLU vs. GELU MLP...
80 | (
81 | SwishGLU(self.embed_dim, int(mlp_ratio * self.embed_dim))
82 | if do_swish_glu
83 | else nn.Sequential(nn.Linear(self.embed_dim, int(mlp_ratio * self.embed_dim)), nn.GELU())
84 | ),
85 | nn.Linear(int(mlp_ratio * self.embed_dim), self.embed_dim),
86 | )
87 |
88 | def forward(self, x: torch.Tensor) -> torch.Tensor:
89 | latents = repeat(self.latents, "n_latents d -> bsz n_latents d", bsz=x.shape[0])
90 | latents = self.attn_norm(latents + self.attn(latents, self.projection(x)))
91 | latents = self.mlp_norm(latents + self.mlp(latents))
92 | return latents.squeeze(dim=1)
93 |
94 |
95 | # MAP Extractor Instantiation --> factory for creating extractors with the given parameters.
96 | def instantiate_extractor(backbone: nn.Module, n_latents: int = 1) -> Callable[[], nn.Module]:
97 | def initialize() -> nn.Module:
98 | return MAPBlock(n_latents, backbone.embed_dim, backbone.n_heads)
99 |
100 | return initialize
101 |
--------------------------------------------------------------------------------
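
`MAPBlock` can be exercised on its own with explicit dimensions; in practice, `instantiate_extractor(backbone)()` plays the same role for a loaded backbone exposing `embed_dim` / `n_heads`. A small shape sketch (dimensions illustrative):

import torch

from voltron.models.util.extraction import MAPBlock

pooler = MAPBlock(n_latents=1, embed_dim=384, n_heads=6)
patch_tokens = torch.randn(2, 196, 384)   # [bsz, n_patches, embed_dim] sequence from an encoder
summary = pooler(patch_tokens)            # pooled down to a single vector per example
print(summary.shape)                      # torch.Size([2, 384])
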
/voltron/models/util/optimization.py:
--------------------------------------------------------------------------------
1 | """
2 | optimization.py
3 |
4 | General utilities for optimization, e.g., schedulers such as Linear Warmup w/ Cosine Decay for Transformer training.
5 | Notably *does not* use the base PyTorch LR Scheduler, since we call it continuously, across epochs, across steps;
6 | PyTorch has no built-in way of separating the two without coupling to the DataLoader, so we may as well make this explicit
7 | in the parent loop.
8 |
9 | References
10 | - MAE: https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/lr_sched.py
11 | - ⚡️-Bolts: https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/optimizers/lr_scheduler.py
12 | """
13 | import math
14 | from typing import Callable
15 |
16 | from torch.optim.optimizer import Optimizer
17 |
18 |
19 | def get_lr_update(
20 | opt: Optimizer, schedule: str, lr: float, min_lr: float, warmup_epochs: int, max_epochs: int
21 | ) -> Callable[[int, float], float]:
22 | if schedule == "linear-warmup+cosine-decay":
23 |
24 | def lr_update(epoch: int, fractional_progress: float) -> float:
25 | """Run the warmup check for linear increase, else cosine decay."""
26 | if (epoch + fractional_progress) < warmup_epochs:
27 | new_lr = lr * (epoch + fractional_progress) / max(1.0, warmup_epochs)
28 | else:
29 | # Cosine Decay --> as defined in the SGDR Paper...
30 | progress = ((epoch + fractional_progress) - warmup_epochs) / max(1.0, max_epochs - warmup_epochs)
31 | new_lr = min_lr + (lr - min_lr) * (0.5 * (1 + math.cos(math.pi * progress)))
32 |
33 | # Apply...
34 | for group in opt.param_groups:
35 | if "lr_scale" in group:
36 | group["lr"] = new_lr * group["lr_scale"]
37 | else:
38 | group["lr"] = new_lr
39 |
40 | return new_lr
41 |
42 | elif schedule == "none":
43 |
44 | def lr_update(_: int, __: float) -> float:
45 | return lr
46 |
47 | else:
48 | raise NotImplementedError(f"Schedule `{schedule}` not implemented!")
49 |
50 | return lr_update
51 |
--------------------------------------------------------------------------------
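
A quick sketch of the schedule's behavior (values illustrative): the returned closure both rewrites the optimizer's per-group learning rates and returns the new value, so it can be called once per batch with `(epoch, fractional_progress)`:

import torch

from voltron.models.util.optimization import get_lr_update

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
update_lr = get_lr_update(opt, "linear-warmup+cosine-decay", lr=1e-3, min_lr=0.0, warmup_epochs=5, max_epochs=100)

print(update_lr(0, 0.5))    # mid-warmup --> 1e-4
print(update_lr(5, 0.0))    # warmup complete --> 1e-3
print(update_lr(100, 0.0))  # fully decayed --> 0.0
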
/voltron/models/util/transformer.py:
--------------------------------------------------------------------------------
1 | """
2 | transformer.py
3 |
4 | General Transformer modules & utilities.
5 |
6 | References:
7 | - https://github.com/facebookresearch/mae
8 | - https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
9 | """
10 | from typing import Optional
11 |
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | from einops import rearrange
16 |
17 | # === Position Encoding Utilities ===
18 |
19 |
20 | # Helper/Utility Function -- computes simple 1D sinusoidal position embeddings for both 1D/2D use cases.
21 | # > We'll be combining two 1D sin-cos (traditional) position encodings for height/width of an image (grid features).
22 | def get_1D_sine_cosine(dim: int, pos: np.ndarray) -> np.ndarray:
23 | omega = np.arange(dim // 2, dtype=np.float32) / (dim / 2.0)
24 | omega = 1.0 / (10000**omega)
25 | out = np.einsum("m,d->md", pos.reshape(-1), omega) # [flatten(pos) x omega] -- outer product!
26 | emb_sin, emb_cos = np.sin(out), np.cos(out)
27 | return np.concatenate([emb_sin, emb_cos], axis=1) # [flatten(pos) x D]
28 |
29 |
30 | # 1D Sine-Cosine Position Embedding -- standard from "Attention is all you need!"
31 | def get_1D_position_embeddings(embed_dim: int, length: int) -> np.ndarray:
32 | return get_1D_sine_cosine(embed_dim, np.arange(length))
33 |
34 |
35 | # 2D Sine-Cosine Position Embedding (from MAE repository)
36 | # > https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
37 | def get_2D_position_embeddings(embed_dim: int, grid_size: int, cls_token: bool = False) -> np.ndarray:
38 | # Create 2D Position embeddings by taking cross product of height and width and splicing 1D embeddings...
39 | grid_h, grid_w = np.arange(grid_size, dtype=np.float32), np.arange(grid_size, dtype=np.float32)
40 | grid = np.stack(np.meshgrid(grid_w, grid_h), axis=0).reshape(2, 1, grid_size, grid_size) # w goes first?
41 |
42 | # Use half of dimensions to encode grid_h, other half to encode grid_w
43 | emb_h, emb_w = get_1D_sine_cosine(embed_dim // 2, grid[0]), get_1D_sine_cosine(embed_dim // 2, grid[1])
44 | pos_embed = np.concatenate([emb_h, emb_w], axis=1)
45 |
46 | # CLS token handling (only for R-MVP)
47 | if cls_token:
48 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
49 |
50 | return pos_embed
51 |
52 |
53 | # === Vision Transformer Building Blocks ===
54 |
55 |
56 | # Patch Embedding Module
57 | class PatchEmbed(nn.Module):
58 | def __init__(
59 | self,
60 | resolution: int,
61 | patch_size: int,
62 | embed_dim: int,
63 | in_channels: int = 3,
64 | flatten: bool = True,
65 | ):
66 | super().__init__()
67 | self.resolution, self.patch_size = (resolution, resolution), (patch_size, patch_size)
68 | self.grid_size = (self.resolution[0] // self.patch_size[0], self.resolution[1] // self.patch_size[1])
69 | self.num_patches = self.grid_size[0] * self.grid_size[1]
70 | self.flatten = flatten
71 | self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
72 |
73 | def forward(self, patches: torch.Tensor) -> torch.Tensor:
74 | patch_embeddings = self.proj(patches)
75 | if self.flatten:
76 | return rearrange(patch_embeddings, "bsz embed patch_h patch_w -> bsz (patch_h patch_w) embed")
77 | return patch_embeddings
78 |
79 |
80 | # === Stability Utilities ===
81 |
82 |
83 | # LayerScale -- Trainable scaling for residual blocks -- Mistral/CaIT
84 | class LayerScale(nn.Module):
85 | def __init__(self, dim: int, init_values: float = 0.1) -> None: # CaIT :: 0.1 -> lay 12, 1e-5 -> lay 24, 1e-6...
86 | super().__init__()
87 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
88 |
89 | def forward(self, x: torch.Tensor) -> torch.Tensor:
90 | return x * self.gamma
91 |
92 |
93 | # RMSNorm -- Better, simpler alternative to LayerNorm
94 | class RMSNorm(nn.Module):
95 | def __init__(self, dim: int, eps: float = 1e-8) -> None:
96 | super().__init__()
97 | self.scale, self.eps = dim**-0.5, eps
98 | self.g = nn.Parameter(torch.ones(dim))
99 |
100 | def forward(self, x: torch.Tensor) -> torch.Tensor:
101 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
102 | return x / norm.clamp(min=self.eps) * self.g
103 |
104 |
105 | # SwishGLU -- A Gated Linear Unit (GLU) with the Swish activation; always better than GELU MLP!
106 | class SwishGLU(nn.Module):
107 | def __init__(self, in_dim: int, out_dim: int) -> None:
108 | super().__init__()
109 | self.act, self.project = nn.SiLU(), nn.Linear(in_dim, 2 * out_dim)
110 |
111 | def forward(self, x: torch.Tensor) -> torch.Tensor:
112 | projected, gate = self.project(x).tensor_split(2, dim=-1)
113 | return projected * self.act(gate)
114 |
115 |
116 | # === Fundamental Transformer Building Blocks ===
117 |
118 |
119 | class Attention(nn.Module):
120 | def __init__(self, embed_dim: int, n_heads: int, dropout: float = 0.0) -> None:
121 | """Multi-Headed Self-Attention Operation"""
122 | super().__init__()
123 | assert embed_dim % n_heads == 0, "`embed_dim` must be divisible by `n_heads`!"
124 | self.n_heads, self.scale = n_heads, (embed_dim // n_heads) ** -0.5
125 | self.attn_softmax = None
126 |
127 | # Projections
128 | self.qkv, self.proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True), nn.Linear(embed_dim, embed_dim)
129 | self.dropout = nn.Dropout(dropout)
130 |
131 | def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
132 | B, N, C = x.shape
133 |
134 | # Project to Q-K-V
135 | qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, C // self.n_heads).permute(2, 0, 3, 1, 4)
136 | q, k, v = qkv.unbind(0)
137 |
138 | # Self-attention -- with masking!
139 | scores = q @ (k.transpose(-2, -1) * self.scale)
140 | if mask is not None:
141 | if mask.ndim == 2:
142 | mask = rearrange(mask, "bsz seq -> bsz 1 seq 1")
143 | elif mask.ndim != 4:
144 | raise NotImplementedError("Attention got `mask` of shape not in {2, 4}!")
145 |
146 | # Mask out by filling indices with negative infinity...
147 | scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
148 |
149 | # Compute weighted sum over values
150 | self.attn_softmax = scores.softmax(dim=-1)
151 | vals = (self.attn_softmax @ v).transpose(1, 2).reshape(B, N, C)
152 |
153 | # Project back to `embed_dim` -- with optional dropout
154 | vals = self.dropout(self.proj(vals))
155 | return vals
156 |
157 |
158 | class Block(nn.Module):
159 | def __init__(
160 | self,
161 | embed_dim: int,
162 | n_heads: int,
163 | mlp_ratio: float = 4.0,
164 | dropout: float = 0.0,
165 | do_rms_norm: bool = False,
166 | do_swish_glu: bool = False,
167 | do_layer_scale: bool = False,
168 | ) -> None:
169 | """
170 | Transformer Block Implementation (modality-agnostic).
171 |
172 | :param embed_dim: Core embedding/hidden dimension for vision transformer backbone.
173 | :param n_heads: Number of heads for multi-headed self-attention.
174 | :param mlp_ratio: Ratio for embedding size to position-wise feed-forward MLP (gets shrunk back down).
175 | :param dropout: [Optional] dropout for projection layer and MLPs -- for MAEs, always 0.0!
176 | :param do_rms_norm: Boolean whether or not to use RMSNorm in lieu of LayerNorm within block.
177 | :param do_swish_glu: Use the Swish-variant of the Gated Linear Unit for the feed-forward layers.
178 | :param do_layer_scale: Boolean whether or not to use LayerScale from Mistral/CaIT w/ initialization of 0.1.
179 | """
180 | super().__init__()
181 | self.embed_dim, self.n_heads, self.do_layer_scale = embed_dim, n_heads, do_layer_scale
182 |
183 | # Attention Components
184 | self.pre_norm_attn = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6)
185 | self.attn = Attention(self.embed_dim, n_heads=n_heads, dropout=dropout)
186 | if do_layer_scale:
187 | self.layer_scale_attn = LayerScale(self.embed_dim)
188 |
189 | # Position-wise Feed-Forward Components
190 | self.pre_norm_mlp = RMSNorm(self.embed_dim) if do_rms_norm else nn.LayerNorm(self.embed_dim, eps=1e-6)
191 | self.mlp = nn.Sequential(
192 | # Handle SwishGLU vs. GELU MLP...
193 | (
194 | SwishGLU(embed_dim, int(mlp_ratio * embed_dim))
195 | if do_swish_glu
196 | else nn.Sequential(nn.Linear(embed_dim, int(mlp_ratio * embed_dim)), nn.GELU())
197 | ),
198 | nn.Dropout(dropout),
199 | nn.Linear(int(mlp_ratio * embed_dim), embed_dim),
200 | )
201 | if self.do_layer_scale:
202 | self.layer_scale_mlp = LayerScale(self.embed_dim)
203 |
204 | def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
205 | if self.do_layer_scale:
206 | x = x + self.layer_scale_attn(self.attn(self.pre_norm_attn(x), mask))
207 | x = x + self.layer_scale_mlp(self.mlp(self.pre_norm_mlp(x)))
208 | else:
209 | x = x + self.attn(self.pre_norm_attn(x), mask)
210 | x = x + self.mlp(self.pre_norm_mlp(x))
211 | return x
212 |
--------------------------------------------------------------------------------
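
The building blocks compose as in the MAE-style encoders elsewhere in the repository; a small shape sketch with illustrative ViT-Small-ish dimensions:

import torch

from voltron.models.util.transformer import Block, PatchEmbed, get_2D_position_embeddings

patch2embed = PatchEmbed(resolution=224, patch_size=16, embed_dim=384)
pos_embed = get_2D_position_embeddings(384, grid_size=224 // 16)  # numpy array of shape [196, 384]
block = Block(embed_dim=384, n_heads=6)

imgs = torch.randn(2, 3, 224, 224)
x = patch2embed(imgs) + torch.from_numpy(pos_embed).float()       # [2, 196, 384]
x = block(x)                                                      # shape preserved
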
/voltron/overwatch/__init__.py:
--------------------------------------------------------------------------------
1 | from .overwatch import OverwatchRich
2 |
--------------------------------------------------------------------------------
/voltron/overwatch/overwatch.py:
--------------------------------------------------------------------------------
1 | """
2 | overwatch.py
3 |
4 | Utility class for creating a centralized/standardized logger (to pass to Hydra), with a sane default format.
5 | """
6 | from dataclasses import dataclass, field
7 | from typing import Any, Dict
8 |
9 | # Overwatch Default Format String
10 | FORMATTER, DATEFMT = "[*] %(asctime)s - %(name)s >> %(levelname)s :: %(message)s", "%m/%d [%H:%M:%S]"
11 | RICH_FORMATTER = "| >> %(message)s"
12 |
13 |
14 | # Rich Overwatch Variant --> Good for debugging, and tracing!
15 | @dataclass
16 | class OverwatchRich:
17 | version: int = 1
18 | formatters: Dict[str, Any] = field(
19 | default_factory=lambda: {
20 | "simple-console": {"format": RICH_FORMATTER, "datefmt": DATEFMT},
21 | "simple-file": {"format": FORMATTER, "datefmt": DATEFMT},
22 | }
23 | )
24 | handlers: Dict[str, Any] = field(
25 | default_factory=lambda: {
26 | "console": {
27 | "class": "rich.logging.RichHandler",
28 | "formatter": "simple-console",
29 | "rich_tracebacks": True,
30 | "show_level": True,
31 | "show_path": True,
32 | "show_time": True,
33 | },
34 | "file": {
35 | "class": "logging.FileHandler",
36 | "formatter": "simple-file",
37 | "filename": "${hydra.job.name}.log",
38 | },
39 | }
40 | )
41 | root: Dict[str, Any] = field(default_factory=lambda: {"level": "INFO", "handlers": ["console", "file"]})
42 | disable_existing_loggers: bool = True
43 |
44 |
45 | # Standard Overwatch Variant --> Performant, no bells & whistles
46 | @dataclass
47 | class OverwatchStandard:
48 | version: int = 1
49 | formatters: Dict[str, Any] = field(default_factory=lambda: {"simple": {"format": FORMATTER, "datefmt": DATEFMT}})
50 | handlers: Dict[str, Any] = field(
51 | default_factory=lambda: {
52 | "console": {"class": "logging.StreamHandler", "formatter": "simple", "stream": "ext://sys.stdout"},
53 | "file": {
54 | "class": "logging.FileHandler",
55 | "formatter": "simple",
56 | "filename": "${hydra.job.name}.log",
57 | },
58 | }
59 | )
60 | root: Dict[str, Any] = field(default_factory=lambda: {"level": "INFO", "handlers": ["console", "file"]})
61 | disable_existing_loggers: bool = True
62 |
--------------------------------------------------------------------------------
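
The docstring above says these configs are meant "to pass to Hydra"; as a hedged, standalone sketch (not the project's actual entry point), the same dataclass can be materialized directly with `logging.config.dictConfig`, swapping out the `${hydra.job.name}` interpolation by hand:

import logging
import logging.config
from dataclasses import asdict

from voltron.overwatch import OverwatchRich

cfg = asdict(OverwatchRich())
cfg["handlers"]["file"]["filename"] = "overwatch-demo.log"  # stand-in for "${hydra.job.name}.log"
logging.config.dictConfig(cfg)                              # RichHandler requires `rich` to be installed
logging.getLogger("demo").info("Overwatch online!")
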
/voltron/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | from .process import extract_frames, preprocess_language, unify_batches
2 |
--------------------------------------------------------------------------------
/voltron/preprocessing/core.py:
--------------------------------------------------------------------------------
1 | """
2 | core.py
3 |
4 | Preprocessing utilities, including dry-run and single-video (single-example) processing. This file effectively defines
5 | the "atomic" logic (take one video --> extract all frames, etc.), while the `process.py` functions invoke each unit
6 | in a multiprocessing pool.
7 | """
8 | import glob
9 | import json
10 | import logging
11 | import os
12 | import time
13 | from pathlib import Path
14 | from typing import Any, Callable, Dict, List, Optional, Set, Tuple
15 |
16 | import av
17 | import h5py
18 | import numpy as np
19 | import pandas as pd
20 | from hurry.filesize import alternative, size
21 | from PIL import Image
22 | from rich.progress import track
23 | from tqdm import tqdm
24 |
25 | # Grab Logger
26 | overwatch = logging.getLogger(__file__)
27 | logging.getLogger("libav").setLevel(logging.ERROR)
28 |
29 |
30 | # === General Utilities ===
31 |
32 |
33 | # Videos are saved as `train_dir/{vid}/{vid}_idx={i}.jpg` || if `relpath` then *relative path* `{split}/{vid}/...
34 | def get_path(save_dir: Path, v: str, i: int, relpath: bool = False) -> str:
35 | return str((save_dir if not relpath else Path(save_dir.name)) / v / f"{v}_idx={i}.jpg")
36 |
37 |
38 | # === Dry-Run Functionality ===
39 |
40 |
41 | def do_dry_run(
42 | name: str,
43 | path: str,
44 | train_ids: List[str],
45 | val_ids: List[str],
46 | preprocess_transform: Callable[[List[Image.Image]], List[Image.Image]],
47 | n_train_videos: int = 1000,
48 | n_val_videos: int = 100,
49 | n_samples: int = 1000,
50 | ) -> None:
51 | """Iterates through a small subset of the total dataset, logs n_frames & average image size for estimation."""
52 | overwatch.info(f"Performing Dry-Run with {n_train_videos} Train Videos and {n_val_videos} Validation Videos")
53 | dry_run_metrics = {
54 | "n_frames": [],
55 | "jpg_sizes": [],
56 | "n_samples": n_samples,
57 | "time_per_example": [],
58 | "blank": str(Path(path) / "blank.jpg"),
59 | }
60 |
61 | # Switch on dataset (`name`)
62 | if name == "sth-sth-v2":
63 | for k, n_iter, vids in [("train", n_train_videos, train_ids), ("val", n_val_videos, val_ids)]:
64 | for idx in track(range(n_iter), description=f"Reading {k.capitalize()} Videos =>> ", transient=True):
65 | container = av.open(str(Path(path) / "videos" / f"{vids[idx]}.webm"))
66 | assert int(container.streams.video[0].average_rate) == 12, "FPS for `sth-sth-v2` should be 12!"
67 | try:
68 | imgs = [f.to_image() for f in container.decode(video=0)]
69 | except (RuntimeError, ZeroDivisionError) as e:
70 | overwatch.error(f"{type(e).__name__}: WebM reader cannot open `{vids[idx]}.webm` - continuing...")
71 | continue
72 | container.close()
73 |
74 | # Apply `preprocess_transform`
75 | imgs = preprocess_transform(imgs)
76 |
77 | # Dry-Run Handling --> write a dummy JPEG to collect size statistics, dump, and move on...
78 | dry_run_metrics["n_frames"].append(len(imgs))
79 | while dry_run_metrics["n_samples"] > 0 and len(imgs) > 0:
80 | img = imgs.pop(0)
81 | img.save(str(dry_run_metrics["blank"]))
82 | dry_run_metrics["jpg_sizes"].append(os.path.getsize(dry_run_metrics["blank"]))
83 | dry_run_metrics["n_samples"] -= 1
84 |
85 | # Compute nice totals for "dry-run" estimate...
86 | total_clips = len(train_ids) + len(val_ids)
87 |
88 | else:
89 | raise ValueError(f"Dry Run for Dataset `{name}` not implemented!")
90 |
91 | # Compute aggregate statistics and gently exit...
92 | avg_size, avg_frames = np.mean(dry_run_metrics["jpg_sizes"]), int(np.mean(dry_run_metrics["n_frames"]))
93 | overwatch.info("Dry-Run Statistics =>>")
94 | overwatch.info(f"\t> A video has on average `{avg_frames}` frames at {size(avg_size, system=alternative)}")
95 | overwatch.info(f"\t> So - 1 video ~ {size(avg_frames * avg_size, system=alternative)}")
96 | overwatch.info(
97 | f"\t> With the full dataset of {total_clips} Train + Val videos ~"
98 | f" {size(total_clips * avg_frames * avg_size, system=alternative)}"
99 | )
100 | overwatch.info("Dry-Run complete, do what you will... exiting ✌️")
101 |
102 | # Remove dummy file...
103 | os.remove(dry_run_metrics["blank"])
104 | exit(0)
105 |
106 |
107 | # === Atomic "Processing" Steps ===
108 |
109 |
110 | def process_clip(
111 | name: str,
112 | path: Path,
113 | save: Path,
114 | preprocess_transform: Callable[[List[Image.Image]], List[Image.Image]],
115 | item: Tuple[str, str],
116 | ) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
117 | """Processes a single video clip and extracts/serializes all frames (as jpeg), returning the registry contents."""
118 | if name == "sth-sth-v2":
119 | vid, lang = item
120 | container, registration = av.open(str(Path(path) / "videos" / f"{vid}.webm")), {"language": lang, "n_frames": 0}
121 | assert int(container.streams.video[0].average_rate) == 12, "FPS for `sth-sth-v2` should be 12!"
122 | try:
123 | imgs = [f.to_image() for f in container.decode(video=0)]
124 | except (RuntimeError, ZeroDivisionError) as e:
125 | overwatch.error(f"{type(e).__name__}: WebM reader cannot open `{vid}.webm` - continuing...")
126 | return None, None
127 | container.close()
128 |
129 | # Book-Keeping
130 | os.makedirs(save / vid, exist_ok=True)
131 | registration["n_frames"] = len(imgs)
132 |
133 | # Short Circuit --> Writes are Expensive!
134 | if len(glob.glob1(save / vid, "*.jpg")) == len(imgs):
135 | return vid, registration
136 |
137 | # Apply `preprocess_transform` --> write individual frames, register, and move on!
138 | imgs = preprocess_transform(imgs)
139 | for idx in range(len(imgs)):
140 | imgs[idx].save(get_path(save, vid, idx))
141 |
142 | # Return title & registration
143 | return vid, registration
144 |
145 | else:
146 | raise ValueError(f"Clip Processing for Dataset `{name}` is not implemented!")
147 |
148 |
149 | # ruff: noqa: C901
150 | def serialize_epoch(
151 | index_dir: Path,
152 | registry: Dict[str, Any],
153 | vid_dir: Path,
154 | batch_formats: Tuple[Tuple[str, Tuple[str, ...]], ...],
155 | do_initial: bool,
156 | do_final: bool,
157 | initial_final_alpha: float,
158 | n_int: int,
159 | epoch: int,
160 | is_validation: bool = False,
161 | ) -> Tuple[int, int, Optional[Set[str]]]:
162 | index_file = "validation-batches.json" if is_validation else f"train-epoch={epoch}-batches.json"
163 | index_hdf5 = "validation-batches.hdf5" if is_validation else f"train-epoch={epoch}-batches.hdf5"
164 |
165 | # Short-Circuit
166 | if all([(index_dir / key / index_file).exists() for key, _ in batch_formats]):
167 | return -1, -1, None
168 |
169 | # Random seed is inherited from parent process... we want new randomness w/ each process
170 | np.random.seed((os.getpid() * int(time.time())) % 123456789)
171 |
172 | # Create Tracking Variables
173 | unique_states, batches = set(), {b: [] for b, _ in batch_formats}
174 |
175 | # Iterate through Registry --> Note we're using `tqdm` instead of `track` here because of `position` feature!
176 | for vid in tqdm(registry.keys(), desc=f"Epoch {epoch}", total=len(registry), position=epoch):
177 | # The initial/final states are sampled from the first [0, \alpha) and final 1-\alpha, 1] percent of the video
178 | n_frames = registry[vid]["n_frames"]
179 | initial_idx, final_idx = 0, n_frames - 1
180 | if do_initial:
181 | initial_idx = np.random.randint(0, np.around(n_frames * initial_final_alpha))
182 |
183 | if do_final:
184 | final_idx = np.random.randint(np.around(n_frames * (1 - initial_final_alpha)), n_frames)
185 |
186 | # Assertion --> initial_idx < final_idx - len(state_elements)
187 | assert initial_idx < final_idx - n_int, "Initial & Final are too close... no way to sample!"
188 |
189 | # Assume remaining elements are just random "interior" states --> sort to get ordering!
190 | sampled_idxs = np.random.choice(np.arange(initial_idx + 1, final_idx), size=n_int, replace=False)
191 | sampled_idxs = sorted(list(sampled_idxs))
192 |
193 | # Compile full-set "batch"
194 | retrieved_states = [get_path(vid_dir, vid, x, relpath=True) for x in [initial_idx, *sampled_idxs] + [final_idx]]
195 |
196 | # Add batch to index for specific batch_format key...
197 | batches[batch_formats[-1][0]].append({"vid": vid, "states": retrieved_states, "n_frames": n_frames})
198 | unique_states.update(retrieved_states)
199 |
200 | # Add all other batch formats to indices...
201 | for key, elements in batch_formats[:-1]:
202 | n_states = len([x for x in elements if "state_" in x])
203 | assert (n_states <= 2) or (
204 | n_states == len(retrieved_states)
205 | ), f"Strange value of n_states={n_states} > 2 and not equal to total possible of {len(retrieved_states)}"
206 |
207 | # States are all independent -- each of the retrieved states is its own example...
208 | if n_states == 1:
209 | for idx in range(len(retrieved_states)):
210 | batches[key].append({"vid": vid, "state": retrieved_states[idx], "n_frames": n_frames})
211 |
212 | # 0K-Context is the only "valid" context for n_states == 2
213 | elif n_states == 2:
214 | assert elements == ["state_initial", "state_i", "language"], "n_states = 2 but not 0K context?"
215 |
216 | # Append 0th state to each of the remaining sampled contexts (usually 2 or 4)... each pair is an example
217 | for idx in range(1, len(retrieved_states)):
218 | batches[key].append(
219 | {"vid": vid, "states": [retrieved_states[0], retrieved_states[idx]], "n_frames": n_frames}
220 | )
221 |
222 | # We're treating the entire sequence of retrieved states as a single example (for TCN/R3M/Temporal Models)
223 | else:
224 | batches[key].append({"vid": vid, "states": retrieved_states, "n_frames": n_frames})
225 |
226 | # Write JSON Index directly to disk...
227 | for key in batches:
228 | with open(index_dir / key / index_file, "w") as f:
229 | json.dump(batches[key], f)
230 |
231 | # Write HDF5 Index directly to disk...
232 | for key, elements in batch_formats[:-1]:
233 | n_states = len([x for x in elements if "state_" in x])
234 |
235 | # Create HDF5 File
236 | df = pd.DataFrame(batches[key])
237 | h5 = h5py.File(index_dir / key / index_hdf5, "w")
238 | for k in ["vid", "n_frames"]:
239 | h5.create_dataset(k, data=df[k].values)
240 |
241 | # Handle "state(s)" --> (image path strings) --> add leading dimension (`n_states`)
242 | if n_states == 1:
243 | dfs = df["state"].apply(pd.Series)
244 | h5.create_dataset("states", data=dfs.values)
245 |
246 | else:
247 | dfs = df["states"].apply(pd.Series)
248 | h5.create_dataset("states", data=dfs.values)
249 |
250 | # Close HDF5 File
251 | h5.close()
252 |
253 | return epoch, len(batches["state"]), unique_states
254 |
--------------------------------------------------------------------------------
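
To make the per-clip sampling in `serialize_epoch` concrete, here is a small self-contained sketch of how one video's state indices get drawn (the numbers are illustrative; in the real pipeline `n_frames` comes from the registry and `initial_final_alpha`/`n_int` from the dataset config):

import numpy as np

n_frames, initial_final_alpha, n_int = 40, 0.2, 3
initial_idx = np.random.randint(0, np.around(n_frames * initial_final_alpha))              # drawn from [0, 8)
final_idx = np.random.randint(np.around(n_frames * (1 - initial_final_alpha)), n_frames)   # drawn from [32, 40)
interior = sorted(np.random.choice(np.arange(initial_idx + 1, final_idx), size=n_int, replace=False))
states = [initial_idx, *interior, final_idx]   # one ordered "full-set" example, e.g. [5, 11, 19, 30, 36]
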
/voltron/preprocessing/process.py:
--------------------------------------------------------------------------------
1 | """
2 | process.py
3 |
4 | Utility functions for preprocessing large-scale video/vision-language datasets in multiple passes, using multiprocessing
5 | for parallelization. Exposes a three-phase sequence for preprocessing --> batching data:
6 | - Phase I (`extract_frames`): Read in raw (video clip, language) pairs, extract and serialize *all frames* to disk.
7 | - Phase II (`preprocess_language`) & Phase III (`unify_batches`): tokenize language, then assemble "data-locked" batches.
8 | This script tries to be smart where it can, using multiprocessing.Pool in Phase I to speed up extraction; however, for
9 | larger datasets YMMV. You might consider extracting the relevant logic, and using tools like SLURM Job Arrays, AWS
10 | Lambda Functions, or GCP Cloud Run to "burst preprocess" data.
11 | """
12 | import json
13 | import logging
14 | import multiprocessing as mp
15 | import os
16 | import shutil
17 | from functools import partial
18 | from pathlib import Path
19 | from typing import Tuple
20 |
21 | import torch
22 | from rich.progress import track
23 | from transformers import AutoTokenizer
24 |
25 | from voltron.preprocessing.core import do_dry_run, process_clip, serialize_epoch
26 | from voltron.preprocessing.transforms import get_preprocess_transform
27 |
28 | # Grab Logger
29 | overwatch = logging.getLogger(__file__)
30 |
31 |
32 | def extract_frames(
33 | name: str,
34 | path: str,
35 | artifact_path: str,
36 | preprocess_resolution: int,
37 | n_val_videos: int,
38 | dry_run: bool = False,
39 | ) -> Tuple[Path, Path, Path, Path]:
40 | """Phase I: Extract and serialize *all frames* from video clips; uses multiprocessing to parallelize."""
41 | overwatch.info(f"Phase 1 Preprocessing :: Extracting Frames for Dataset `{name}`")
42 |
43 | # Overview of Return Values:
44 | # `t_registry` and `v_registry` =>> store mappings of "video id" -> {metadata}
45 | # `t_dir` and `v_dir` =>> store "processed data" (extracted frames)
46 | t_dir, v_dir = Path(artifact_path) / name / "train", Path(artifact_path) / name / "val"
47 | t_registry, v_registry = t_dir / "registry.json", v_dir / "registry.json"
48 |
49 | # Short-Circuit
50 | if t_registry.exists() and v_registry.exists():
51 | return t_registry, v_registry, t_dir, v_dir
52 |
53 | # Setup / Book-Keeping
54 | os.makedirs(t_dir, exist_ok=True)
55 | os.makedirs(v_dir, exist_ok=True)
56 |
57 | # Retrieve "pre-serialization" frame transform --> we scale down video frames (*while preserving aspect ratios*)
58 | # and center crop each frame to `(preprocess_resolution, preprocess_resolution)`; saves on disk space (by a lot!)
59 | preprocess_transform = get_preprocess_transform(name, preprocess_resolution=preprocess_resolution)
60 |
61 | # Switch on dataset (`name`)
62 | if name == "sth-sth-v2":
63 | with open(Path(path) / "labels/train.json", "r") as f:
64 | annotations = json.load(f)
65 | train_ids, train_lang = [x["id"] for x in annotations], [x["label"] for x in annotations]
66 |
67 | with open(Path(path) / "labels/validation.json", "r") as f:
68 | annotations = json.load(f)[:n_val_videos]
69 | val_ids, val_lang = [x["id"] for x in annotations], [x["label"] for x in annotations]
70 |
71 | else:
72 | raise ValueError(f"Language/Metadata Extraction Pipeline for Dataset `{name}` not implemented!")
73 |
74 | # Run Dry-Run (if specified) --> single-threaded for debugging
75 | if dry_run:
76 | do_dry_run(name, path, train_ids, val_ids, preprocess_transform)
77 |
78 | # Otherwise =>> Iterate through all videos, dump all frames subject to the following structure:
79 | # |-> .../processed/something-something-v2/
80 | #     |-> <vid>/
81 | #         |-> frames<0..k>.jpg
82 | #
83 | # We'll build a single metadata file with a mapping <vid> : ("language", n_frames)
84 | # > To speed up serialization, we'll use a multiprocessing.Pool and max out CPU workers
85 | with mp.Pool(mp.cpu_count()) as pool:
86 | for k, save, vids, langs in [("train", t_dir, train_ids, train_lang), ("val", v_dir, val_ids, val_lang)]:
87 | overwatch.info(f"\tWriting `{k}` videos to disk...")
88 |
89 | # Spawn!
90 | process_fn, registration = partial(process_clip, name, Path(path), save, preprocess_transform), {}
91 | for key, value in track(
92 | pool.imap_unordered(process_fn, zip(vids, langs)),
93 | total=len(vids),
94 | transient=True,
95 | ):
96 | if key is not None:
97 | registration[key] = value
98 |
99 | # Write Registration to Disk
100 | with open(t_registry if k == "train" else v_registry, "w") as f:
101 | json.dump(registration, f)
102 |
103 | # Return Paths to Registry & Extract Directories...
104 | return t_registry, v_registry, t_dir, v_dir
105 |
106 |
107 | def preprocess_language(
108 | name: str,
109 | train_registry: Path,
110 | val_registry: Path,
111 | artifact_path: str,
112 | max_lang_len: int,
113 | language_model: str,
114 | hf_cache: str,
115 | ) -> Path:
116 | """Phase II: Iterate through Language Captions/Narrations and Normalize/Tokenize (truncate/pad to max length)."""
117 | overwatch.info(f"Phase 2 Preprocessing :: Normalizing & Tokenizing Language for Dataset `{name}`")
118 | t_index, v_index = train_registry.parent / "index.pt", val_registry.parent / "index.pt"
119 | t_json, v_json = train_registry.parent / "index.json", val_registry.parent / "index.json"
120 | index_dir = Path(artifact_path) / name / "index"
121 | os.makedirs(index_dir, exist_ok=True)
122 |
123 | # Short-Circuit
124 | if (index_dir / "train-language-index.json").exists() and (index_dir / "val-language-index.json").exists():
125 | return index_dir
126 |
127 | # Grab Language --> retain metadata for building index structures!
128 | with open(train_registry, "r") as f:
129 | train_metadata = json.load(f)
130 | train = [(vid, train_metadata[vid]["language"], train_metadata[vid]) for vid in train_metadata]
131 |
132 | with open(val_registry, "r") as f:
133 | val_metadata = json.load(f)
134 | val = [(vid, val_metadata[vid]["language"], val_metadata[vid]) for vid in val_metadata]
135 |
136 | # Assemble *all* language
137 | language = [x[1] for x in train + val]
138 |
139 | # Build AutoTokenizer (from `language_model` identifier)
140 | tokenizer = AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)
141 |
142 | # If `max_lang_len` not specified, dump some statistics to compute...
143 | if max_lang_len == -1:
144 | # Naively tokenizes and pads to the "maximum length" of _all_ language... long tail is a problem!
145 | encoded_language = tokenizer(language, return_tensors="pt", padding=True)
146 | lengths = encoded_language["attention_mask"].sum(dim=1)
147 |
148 | # Compute a histogram of lengths
149 | hist = lengths.float().histc(bins=lengths.max()).int()
150 | overwatch.info(f"Histogram: {hist.numpy().tolist()}")
151 | raise AssertionError("Compute max length and update dataset configuration!")
152 |
153 | # Otherwise, we've already set the maximum length, so let's use it!
154 | overwatch.info(f"\tTokenizing all language in dataset to maximum length `{max_lang_len}`")
155 | encoded_language = tokenizer(
156 | language, return_tensors="pt", max_length=max_lang_len, truncation=True, padding="max_length"
157 | )
158 | input_ids, attention_mask = encoded_language["input_ids"], encoded_language["attention_mask"]
159 | train_input_ids, train_attention_mask = input_ids[: len(train)], attention_mask[: len(train)]
160 | val_input_ids, val_attention_mask = input_ids[len(train) :], attention_mask[len(train) :]
161 |
162 | # Assertion, just to sanity check
163 | assert len(val_input_ids) == len(val_attention_mask) == len(val), "Something went wrong tokenizing language..."
164 |
165 | # Compute `index.pt` contents
166 | overwatch.info("\tAssembling `train` and `val` index structures...")
167 | train_pt = {
168 | train[i][0]: {**train[i][2], **{"input_ids": train_input_ids[i], "attention_mask": train_attention_mask[i]}}
169 | for i in range(len(train))
170 | }
171 | val_pt = {
172 | val[i][0]: {**val[i][2], **{"input_ids": val_input_ids[i], "attention_mask": val_attention_mask[i]}}
173 | for i in range(len(val))
174 | }
175 |
176 | # Additionally dump JSON versions of the same --> downstream interpretability, XLA
177 | overwatch.info("JSONifying both Train and Validation Language")
178 | train_json, val_json = {}, {}
179 | for vid in track(train_pt, description="Train Language :: ", transient=True):
180 | train_json[vid] = {
181 | "language": train_pt[vid]["language"],
182 | "n_frames": train_pt[vid]["n_frames"],
183 | "input_ids": train_pt[vid]["input_ids"].numpy().tolist(),
184 | "attention_mask": train_pt[vid]["attention_mask"].numpy().tolist(),
185 | }
186 |
187 | for vid in track(val_pt, description="Validation Language :: ", transient=True):
188 | val_json[vid] = {
189 | "language": val_pt[vid]["language"],
190 | "n_frames": val_pt[vid]["n_frames"],
191 | "input_ids": val_pt[vid]["input_ids"].numpy().tolist(),
192 | "attention_mask": val_pt[vid]["attention_mask"].numpy().tolist(),
193 | }
194 |
195 | # Dump Structures...
196 | overwatch.info(f"Saving Torch indices to `{t_index}` and `{v_index}` respectively...")
197 | torch.save(train_pt, t_index)
198 | torch.save(val_pt, v_index)
199 |
200 | overwatch.info(f"Saving JSON indices to `{t_json}` and `{v_json}` respectively...")
201 | with open(t_json, "w") as f:
202 | json.dump(train_json, f)
203 |
204 | with open(v_json, "w") as f:
205 | json.dump(val_json, f)
206 |
207 | # Pull relevant files out into their own `index` directory...
208 | shutil.copy(t_json, index_dir / "train-language-index.json")
209 | shutil.copy(v_json, index_dir / "val-language-index.json")
210 |
211 | return index_dir
212 |
213 |
214 | def unify_batches(
215 | name: str,
216 | train_registry: Path,
217 | val_registry: Path,
218 | train_dir: Path,
219 | val_dir: Path,
220 | index_dir: Path,
221 | batch_formats: Tuple[Tuple[str, Tuple[str, ...]], ...],
222 | max_epochs: int = 400,
223 | initial_final_alpha: float = 0.2,
224 | ) -> None:
225 | """Phase III: Assemble "Data-Locked" Batches for *all models* for *all epochs* for consistency!"""
226 | overwatch.info(f"Phase 3 Preprocessing :: Assembling *Data-Locked* Batches for Dataset `{name}`")
227 |
228 | # Load Registries
229 | with open(train_registry, "r") as f:
230 | train_registrations = json.load(f)
231 |
232 | with open(val_registry, "r") as f:
233 | val_registrations = json.load(f)
234 |
235 | # Assert last element of `batch_formats` assumes all prior subsets...
236 | full_set_inputs = set(batch_formats[-1][1])
237 | for _, subset_inputs in batch_formats[:-1]:
238 | assert full_set_inputs.issuperset(set(subset_inputs)), "We have a problem with batch formats..."
239 |
240 | # Assemble Tracking Data
241 | b_keys, unique_states = {b[0] for b in batch_formats}, set()
242 |
243 | # Parse out all "state"-specific Elements...
244 | state_elements = [s for s in full_set_inputs if "state_" in s]
245 | do_initial, do_final = "state_initial" in state_elements, "state_final" in state_elements
246 | n_int = len(state_elements) - 2 if ("state_initial" in state_elements and "state_final" in state_elements) else 0
247 |
248 | # Serialize Epochs
249 | overwatch.info("\tSerializing Epochs to JSON --> Storing mapping of Epoch -> Image Paths")
250 | for b in b_keys:
251 | os.makedirs(index_dir / b, exist_ok=True)
252 |
253 | # We only write the Validation Epoch once --> held constant across *all* of training!
254 | overwatch.info("\tWriting Validation Epoch to Disk")
255 | val_epoch_idx, _, uniq_s = serialize_epoch(
256 | index_dir,
257 | val_registrations,
258 | val_dir,
259 | batch_formats,
260 | do_initial,
261 | do_final,
262 | initial_final_alpha,
263 | n_int,
264 | epoch=0,
265 | is_validation=True,
266 | )
267 |
268 | # Update Trackers...
269 | if val_epoch_idx != -1:
270 | unique_states |= uniq_s
271 |
272 | # Compute length of epochs --> CPU Count should be no higher...
273 | epochs, n_frames_per_epoch = list(range(max_epochs)), -1
274 |
275 | # Parallelize Train Epoch Serialization
276 | overwatch.info("\tPlacing the Train Registry into Shared Memory")
277 | manager = mp.Manager()
278 | mg_registry = manager.dict(train_registrations)
279 |
280 | # Multiprocess --> the memory demands here are a bit higher, so limit workers by factor of 4
281 | with mp.Pool(mp.cpu_count() // 4) as pool:
282 | overwatch.info("\tWriting Train Batches per Epoch to Disk")
283 | precompute_fn = partial(
284 | serialize_epoch,
285 | index_dir,
286 | mg_registry,
287 | train_dir,
288 | batch_formats,
289 | do_initial,
290 | do_final,
291 | initial_final_alpha,
292 | n_int,
293 | )
294 | for epoch_idx, n_frames, uniq_s in pool.imap_unordered(precompute_fn, epochs):
295 | if epoch_idx == -1:
296 | continue
297 |
298 | # Update Trackers
299 | unique_states |= uniq_s
300 | n_frames_per_epoch = n_frames
301 |
302 | # Dump Statistics (Note :: Only makes sense on "initial" computation --> uninterrupted!)
303 | overwatch.info(f"Train Uniqueness: {len(unique_states)} States & {len(mg_registry)} Utterances")
304 | overwatch.info(f"Final Statistics :: 1 Epoch has ~ {n_frames_per_epoch} Frames...")
305 |
--------------------------------------------------------------------------------
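
The three phases above are designed to chain; the sketch below shows that flow under assumed arguments (the paths, tokenizer name, and `batch_formats` tuple are placeholders, and per the assertion in `unify_batches` the last format must be a superset of all others -- the repository's actual driver appears to be `examples/pretrain/preprocess.py`):

from voltron.preprocessing import extract_frames, preprocess_language, unify_batches

# Phase I :: extract & serialize all frames
t_reg, v_reg, t_dir, v_dir = extract_frames(
    "sth-sth-v2", path="/data/sth-sth-v2", artifact_path="/data/processed",
    preprocess_resolution=224, n_val_videos=1000,
)

# Phase II :: normalize & tokenize language
index_dir = preprocess_language(
    "sth-sth-v2", t_reg, v_reg, artifact_path="/data/processed",
    max_lang_len=20, language_model="distilbert-base-uncased", hf_cache="/data/hf-cache",
)

# Phase III :: assemble "data-locked" batches (placeholder formats; last one is the full superset)
batch_formats = (
    ("state", ("state_i", "language")),
    ("quintet+language", ("state_initial", "state_i", "state_j", "state_k", "state_final", "language")),
)
unify_batches("sth-sth-v2", t_reg, v_reg, t_dir, v_dir, index_dir, batch_formats, max_epochs=10)
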
/voltron/preprocessing/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | transforms.py
3 |
4 | Default video/image transforms for Voltron preprocessing and training. Provides utilities for defining different scale
5 | and crop transformations on a dataset-specific basis.
6 |
7 | There are two key desiderata we ensure with the transforms:
8 | - Aspect Ratio --> We *never* naively reshape images in a way that distorts the aspect ratio; we crop instead!
9 | - Minimum Size --> We *never* upsample images; processing strictly reduces dimensionality!
10 | """
11 | from functools import partial
12 | from typing import Any, Callable, List, Tuple
13 |
14 | import torch
15 | from PIL import Image, ImageOps
16 | from torchvision.transforms import Compose, ConvertImageDtype, Lambda, Normalize, Resize
17 |
18 |
19 | # Simple Identity Function --> needs to be top-level/pickleable for mp/distributed.spawn()
20 | def identity(x: torch.Tensor) -> torch.Tensor:
21 | return x.float()
22 |
23 |
24 | def scaled_center_crop(target_resolution: int, frames: List[Image.Image]) -> List[Image.Image]:
25 | # Assert width >= height and height >= target_resolution
26 | orig_w, orig_h = frames[0].size
27 | assert orig_w >= orig_h >= target_resolution
28 |
29 | # Compute scale factor --> just a function of height and target_resolution
30 | scale_factor = target_resolution / orig_h
31 | for idx in range(len(frames)):
32 | frames[idx] = ImageOps.scale(frames[idx], factor=scale_factor)
33 | left = (frames[idx].size[0] - target_resolution) // 2
34 | frames[idx] = frames[idx].crop((left, 0, left + target_resolution, target_resolution))
35 |
36 | # Return "scaled and squared" images
37 | return frames
38 |
39 |
40 | def get_preprocess_transform(
41 | dataset_name: str, preprocess_resolution: int
42 | ) -> Callable[[List[Image.Image]], List[Image.Image]]:
43 | """Returns a transform that extracts square crops of `preprocess_resolution` from videos (as [T x H x W x C])."""
44 | if dataset_name == "sth-sth-v2":
45 | return partial(scaled_center_crop, preprocess_resolution)
46 | else:
47 | raise ValueError(f"Preprocessing transform for dataset `{dataset_name}` is not defined!")
48 |
49 |
50 | def get_online_transform(
51 | dataset_name: str, model_arch: str, online_resolution: int, normalization: Tuple[Any, Any]
52 | ) -> Compose:
53 | """Returns an "online" torchvision Transform to be applied during training (batching/inference)."""
54 | if dataset_name == "sth-sth-v2":
55 | # Note: R3M does *not* expect 0-1 normalized (then ImageNet normalized) images --> just cast to float (identity).
56 | if model_arch in {"v-r3m", "v-rn3m"}:
57 | return Compose([Resize((online_resolution, online_resolution), antialias=True), Lambda(identity)])
58 | else:
59 | return Compose(
60 | [
61 | Resize((online_resolution, online_resolution), antialias=True),
62 | ConvertImageDtype(torch.float),
63 | Normalize(mean=normalization[0], std=normalization[1]),
64 | ]
65 | )
66 | else:
67 | raise ValueError(f"Online Transforms for Dataset `{dataset_name}` not implemented!")
68 |
--------------------------------------------------------------------------------
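
A short sketch of the offline transform in action (the frames are synthetic; `scaled_center_crop` asserts that frames are landscape with height >= the target resolution):

from PIL import Image
from voltron.preprocessing.transforms import get_preprocess_transform

transform = get_preprocess_transform("sth-sth-v2", preprocess_resolution=224)
frames = [Image.new("RGB", (427, 240)) for _ in range(8)]   # dummy 240p-style frames (W x H)
squares = transform(frames)                                 # scale by height, then center-crop to a square
assert all(f.size == (224, 224) for f in squares)
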
/voltron/preprocessing/v1/__init__.py:
--------------------------------------------------------------------------------
1 | from .process import index, jsonify_language, preprocess_language, preprocess_videos, unify_batches
2 |
--------------------------------------------------------------------------------
/voltron/preprocessing/v1/process.py:
--------------------------------------------------------------------------------
1 | """
2 | process.py
3 |
4 | Utility functions for serializing datasets in multiple passes, using multiprocessing for efficient parallelization.
5 | Exposes a three-phase sequence for preprocessing:
6 | - Phase I: Read in raw videos (and language), serialize *all extracted* frames to a subdirectory for easy retrieval.
7 | - Phase II: Given image paths and language, assemble language statistics & pre-tokenize for easy batching.
8 | - Phase III: Given a total number of "conceivable epochs", create data-controlled "epoch" sets for each model.
9 |
10 | This script tries to be smart where it can, using multiprocessing.Pool in Phase I to speed up the serialization
11 | process. It also tries to be somewhat safe & efficient, producing idempotent resumes.
12 |
13 | Note :: This code represents the `v1` (initial release) preprocessing flow; this will eventually be deprecated!
14 | """
15 | import json
16 | import logging
17 | import multiprocessing as mp
18 | import os
19 | import shutil
20 | from functools import partial
21 | from pathlib import Path
22 | from typing import Tuple
23 |
24 | import torch
25 | from rich.progress import track
26 | from transformers import AutoTokenizer
27 |
28 | from voltron.preprocessing.v1.transforms import get_pre_transform
29 | from voltron.preprocessing.v1.utils import do_dry_run, precompute_epoch, process_video
30 |
31 | # Grab Logger
32 | overwatch = logging.getLogger(__file__)
33 |
34 |
35 | def preprocess_videos(
36 | name: str,
37 | path: str,
38 | artifact_path: str = "data/processed",
39 | resolution: int = 224,
40 | n_val_videos: int = 1000,
41 | dry_run: bool = False,
42 | ) -> Tuple[Path, Path, Path, Path]:
43 | """Phase I of Preprocessing :: Uses Multiprocessing to Read Videos & Serialize Frames."""
44 | overwatch.info(f"Phase 1 Preprocessing :: Frame serializing videos for dataset `{name}`")
45 |
46 | if name == "sth-sth-v2":
47 | # Overview of Return Values:
48 | # `t_registry` and `v_registry` =>> store mappings of "vid_id" -> {metadata}
49 | # `t_dir` and `v_dir` =>> store "processed data" (extracted frames)
50 | t_dir, v_dir = Path(artifact_path) / name / "train", Path(artifact_path) / name / "val"
51 | t_registry, v_registry = t_dir / "registry.json", v_dir / "registry.json"
52 |
53 | # Short-Circuit / Caching Logic
54 | if t_registry.exists() and v_registry.exists():
55 | return t_registry, v_registry, t_dir, v_dir
56 |
57 | # Setup / Book-Keeping
58 | os.makedirs(t_dir, exist_ok=True)
59 | os.makedirs(v_dir, exist_ok=True)
60 |
61 | # Retrieve Image Transforms (pre-serialization, while running "offline" pass); we crop and scale once, so we're
62 | # not overdoing it on disk storage...
63 | pre_transform = get_pre_transform(name, resolution=resolution)
64 |
65 | # Open & Extract Video ID & Language Metadata
66 | with open(Path(path) / "something-something-v2-train.json", "r") as f:
67 | annotations = json.load(f)
68 | train_ids, train_lang = [x["id"] for x in annotations], [x["label"] for x in annotations]
69 |
70 | with open(Path(path) / "something-something-v2-validation.json", "r") as f:
71 | annotations = json.load(f)[:n_val_videos]
72 | val_ids, val_lang = [x["id"] for x in annotations], [x["label"] for x in annotations]
73 |
74 | # Do Dry-Run --> Single-Threaded!
75 | if dry_run:
76 | do_dry_run(
77 | name,
78 | path,
79 | n_train_videos=1000,
80 | n_val_videos=100,
81 | train_ids=train_ids,
82 | val_ids=val_ids,
83 | pre_transform=pre_transform,
84 | )
85 |
86 | # Go Go Go =>> Iterate through all videos, dump all frames subject to the following structure:
87 | # |-> data/processed/sth-sth-v2/
88 | #     |-> <vid>/
89 | #         |-> frames<0...k>.jpg
90 | # We'll track a single metadata file with the map of <vid> : ("language", n_frames).
91 | # > To speed up the serialization, we'll use a multiprocessing.Pool and max out CPU workers
92 | with mp.Pool(mp.cpu_count()) as pool:
93 | for k, save, vids, langs in [("train", t_dir, train_ids, train_lang), ("val", v_dir, val_ids, val_lang)]:
94 | overwatch.info(f"\tWriting `{k}` videos to disk...")
95 |
96 | # Multiprocess!
97 | process_fn, registration = partial(process_video, name, Path(path), save, pre_transform), {}
98 | for key, value in track(
99 | pool.imap_unordered(process_fn, zip(vids, langs)),
100 | description=f"\t[*] Processing {k}...",
101 | total=len(vids),
102 | transient=True,
103 | ):
104 | if key is not None:
105 | registration[key] = value
106 |
107 | # Write Registration to Disk
108 | with open(t_registry if k == "train" else v_registry, "w") as f:
109 | json.dump(registration, f)
110 |
111 | # Return Paths...
112 | return t_registry, v_registry, t_dir, v_dir
113 |
114 | else:
115 | raise NotImplementedError(f"Preprocessing Pipeline for Dataset `{name}` not implemented!")
116 |
117 |
118 | def preprocess_language(
119 | name: str, train_registry: Path, val_registry: Path, max_lang_len: int, language_model: str, hf_cache: str
120 | ) -> None:
121 | """Phase II of Preprocessing :: Iterate through Language & Normalize/Tokenize to Max Length."""
122 | overwatch.info(f"Phase 2 Preprocessing :: Normalizing & tokenizing language for dataset `{name}`")
123 | t_index, v_index = train_registry.parent / "index.pt", val_registry.parent / "index.pt"
124 | t_json, v_json = train_registry.parent / "index.json", val_registry.parent / "index.json"
125 |
126 | # Short-Circuit Logic
127 | if (t_index.exists() and v_index.exists()) or (t_json.exists() and v_json.exists()):
128 | return t_index, v_index
129 |
130 | # Grab Language, Retaining Metadata for Building Index Structures...
131 | with open(train_registry, "r") as f:
132 | train_metadata = json.load(f)
133 | train = [(vid, train_metadata[vid]["language"], train_metadata[vid]) for vid in train_metadata]
134 |
135 | with open(val_registry, "r") as f:
136 | val_metadata = json.load(f)
137 | val = [(vid, val_metadata[vid]["language"], val_metadata[vid]) for vid in val_metadata]
138 |
139 | # Assemble *all* language
140 | language = [x[1] for x in train + val]
141 |
142 | # Build AutoTokenizer (from `language_model` identifier)
143 | tokenizer = AutoTokenizer.from_pretrained(language_model, cache_dir=hf_cache)
144 |
145 | # If `max_lang_len` not specified, dump some statistics to compute...
146 | if max_lang_len == -1:
147 | # Naively tokenizes and pads to the "maximum length" of _all_ language... long tail is a problem!
148 | encoded_language = tokenizer(language, return_tensors="pt", padding=True)
149 | lengths = encoded_language["attention_mask"].sum(dim=1)
150 |
151 | # Compute a histogram of lengths
152 | hist = lengths.float().histc(bins=lengths.max()).int()
153 | overwatch.info(f"Histogram: {hist.numpy().tolist()}")
154 | raise NotImplementedError("Compute max length and update dataset configuration!")
155 |
156 | # Otherwise, we've already set the maximum length, so let's use it!
157 | else:
158 | overwatch.info(f"\tTokenizing all language in dataset to maximum length `{max_lang_len}`")
159 | encoded_language = tokenizer(
160 | language, return_tensors="pt", max_length=max_lang_len, truncation=True, padding="max_length"
161 | )
162 | input_ids, attention_mask = encoded_language["input_ids"], encoded_language["attention_mask"]
163 | train_input_ids, train_attention_mask = input_ids[: len(train)], attention_mask[: len(train)]
164 | val_input_ids, val_attention_mask = input_ids[len(train) :], attention_mask[len(train) :]
165 |
166 | # Assertion, just to sanity check
167 | assert len(val_input_ids) == len(val_attention_mask) == len(val), "Something went wrong tokenizing language..."
168 |
169 | # Compute `index.pt` contents
170 | overwatch.info("\tAssembling `train` and `val` index structures...")
171 | train_pt = {
172 | train[i][0]: {**train[i][2], **{"input_ids": train_input_ids[i], "attention_mask": train_attention_mask[i]}}
173 | for i in range(len(train))
174 | }
175 | val_pt = {
176 | val[i][0]: {**val[i][2], **{"input_ids": val_input_ids[i], "attention_mask": val_attention_mask[i]}}
177 | for i in range(len(val))
178 | }
179 |
180 | # Dump structures...
181 | overwatch.info(f"Saving index structures to `{t_index}` and `{v_index}` respectively...")
182 | torch.save(train_pt, t_index)
183 | torch.save(val_pt, v_index)
184 |
185 |
186 | def jsonify_language(train_registry: Path, val_registry: Path) -> None:
187 | """Phase 2.5 (Aggregation) :: XLA is weird, won't load torch.Tensors in Dataset; JSONify instead."""
188 | overwatch.info("\tPhase 2 Aggregation :: JSONifying Language Index")
189 | t_index, v_index = train_registry.parent / "index.pt", val_registry.parent / "index.pt"
190 | t_json, v_json = train_registry.parent / "index.json", val_registry.parent / "index.json"
191 | train_json, val_json = {}, {}
192 |
193 | # Short-Circuit Logic
194 | if t_json.exists() and v_json.exists():
195 | return
196 |
197 | # Load Data, iterate through and "de-tensorize", while building up JSON symmetric structure...
198 | train_data, val_data = torch.load(t_index), torch.load(v_index)
199 | overwatch.info("JSONifying both Train and Validation")
200 | for vid in track(train_data, description="Train Language...", transient=True):
201 | train_json[vid] = {
202 | "language": train_data[vid]["language"],
203 | "n_frames": train_data[vid]["n_frames"],
204 | "input_ids": train_data[vid]["input_ids"].numpy().tolist(),
205 | "attention_mask": train_data[vid]["attention_mask"].numpy().tolist(),
206 | }
207 | for vid in track(val_data, description="Val Language...", transient=True):
208 | val_json[vid] = {
209 | "language": val_data[vid]["language"],
210 | "n_frames": val_data[vid]["n_frames"],
211 | "input_ids": val_data[vid]["input_ids"].numpy().tolist(),
212 | "attention_mask": val_data[vid]["attention_mask"].numpy().tolist(),
213 | }
214 |
215 | # Write Data to Disk
216 | overwatch.info("Writing JSON Indices")
217 | with open(t_json, "w") as f:
218 | json.dump(train_json, f)
219 |
220 | with open(v_json, "w") as f:
221 | json.dump(val_json, f)
222 |
223 |
224 | def index(train_registry: Path, val_registry: Path, name: str, artifact_path: str = "data/processed") -> Path:
225 | """Phase 2.75 (Indexing) :: Pull out language.json & other `absolutely necessary` indices to separate directory."""
226 | overwatch.info("\tPhase 2 Indexing :: Indexing Language & Registry Files =>> Extracting to Separate Directory")
227 |
228 | # Create "index" directory...
229 | index_dir = Path(artifact_path) / name / "index"
230 | os.makedirs(index_dir, exist_ok=True)
231 |
232 | # Short-Circuit Logic
233 | if (index_dir / "train-language-index.json").exists() and (index_dir / "val-language-index.json").exists():
234 | return index_dir
235 |
236 | # Retrieve Language JSON indices (train & validation) & copy to new directory...
237 | t_json, v_json = train_registry.parent / "index.json", val_registry.parent / "index.json"
238 | shutil.copy(t_json, index_dir / "train-language-index.json")
239 | shutil.copy(v_json, index_dir / "val-language-index.json")
240 |
241 | return index_dir
242 |
243 |
244 | def unify_batches(
245 | artifact_path: Path,
246 | name: str,
247 | train_registry: Path,
248 | val_registry: Path,
249 | train_dir: Path,
250 | val_dir: Path,
251 | index_dir: Path,
252 | batch_formats: Tuple[Tuple[str, Tuple[str, ...]], ...],
253 | max_epochs: int = 400,
254 | initial_final_alpha: float = 0.2,
255 | ) -> None:
256 | """Phase III of Preprocessing :: Assemble Batches for *all models* for *all epochs* in a consistent manner."""
257 | overwatch.info("Phase 3 Preprocessing :: Assembling Data-Equivalent Epochs for each Model Format")
258 |
259 | # Load Registry Files
260 | with open(train_registry, "r") as f:
261 | train_registrations = json.load(f)
262 |
263 | with open(val_registry, "r") as f:
264 | val_registrations = json.load(f)
265 |
266 | # Assert last element of `batch_formats` assumes all prior subsets...
267 | full_set_inputs = set(batch_formats[-1][1])
268 | for _, subset_inputs in batch_formats[:-1]:
269 | assert full_set_inputs.issuperset(set(subset_inputs)), "We have a problem with batch formats..."
270 |
271 | # Assemble Tracking Data
272 | b_keys, unique_states = {b[0] for b in batch_formats}, set()
273 |
274 | # Parse out all "state"-specific elements...
275 | state_elements = [s for s in full_set_inputs if "state_" in s]
276 | do_initial, do_final = "state_initial" in state_elements, "state_final" in state_elements
277 | n_int = len(state_elements) - 2 if ("state_initial" in state_elements and "state_final" in state_elements) else 0
278 |
279 | # Serialize Epochs to Disk
280 | overwatch.info("\tSerializing epochs to json file, pointing to image paths on disk via a dictionary...")
281 | for b in b_keys:
282 | os.makedirs(index_dir / b, exist_ok=True)
283 |
284 | # We only write the validation epoch once --> held constant across _all_ of training!
285 | overwatch.info("\tWriting Validation Epoch to Disk...")
286 | val_epoch_idx, _, uniq_s = precompute_epoch(
287 | index_dir,
288 | val_registrations,
289 | val_dir,
290 | batch_formats,
291 | do_initial,
292 | do_final,
293 | initial_final_alpha,
294 | n_int,
295 | 0,
296 | is_validation=True,
297 | )
298 |
299 | # Update Trackers...
300 | if val_epoch_idx != -1:
301 | unique_states |= uniq_s
302 |
303 | # Compute length of epochs --> CPU Count should be no higher...
304 | epochs, n_frames_per_epoch = list(range(max_epochs)), -1
305 |
306 | # Load "existing" verification file (if possible)
307 | overwatch.info("\tLoading batch verification file (if possible)...")
308 | verified_batches = Path(artifact_path) / name / "verified-batches.json"
309 | if verified_batches.exists():
310 | with open(verified_batches, "r") as f:
311 | missing_epochs_per_format = json.load(f)
312 |
313 | # Set epochs list by taking union of missing epochs over formats...
314 | epochs = sorted(list(set().union(*missing_epochs_per_format.values())))
315 |
316 | # Dump the big objects into an mp.Manager() so that we can read efficiently from other workers...
317 | overwatch.info("\tPlacing the Train Registry into Shared Memory...")
318 | manager = mp.Manager()
319 | mg_registry = manager.dict(train_registrations)
320 |
321 | with mp.Pool(4) as pool:
322 | overwatch.info("\tWriting Train Batches per Epoch to Disk...")
323 |
324 | # Create partial function for multiprocessing pool...
325 | precompute_fn = partial(
326 | precompute_epoch,
327 | index_dir,
328 | mg_registry,
329 | train_dir,
330 | batch_formats,
331 | do_initial,
332 | do_final,
333 | initial_final_alpha,
334 | n_int,
335 | )
336 | for epoch_idx, n_frames, uniq_s in pool.imap_unordered(precompute_fn, epochs):
337 | if epoch_idx == -1:
338 | continue
339 |
340 | # Update Trackers
341 | unique_states |= uniq_s
342 | n_frames_per_epoch = n_frames
343 |
344 | # Statistics only make sense on initial computation... should unify with code above!
345 | overwatch.info(f"Train Uniqueness: {len(unique_states)} States & {len(mg_registry)} Utterances")
346 | overwatch.info(f"Final Statistics :: 1 Epoch has ~ {n_frames_per_epoch} Frames...")
347 | overwatch.info("Preprocessing Complete!")
348 |
--------------------------------------------------------------------------------
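
For the resume path in the v1 `unify_batches` above, `verified-batches.json` is read (never written) by this module; an illustrative, hypothetical shape maps each batch-format key to the epoch indices whose index files still need to be (re)serialized:

# Hypothetical contents of <artifact_path>/<name>/verified-batches.json
missing_epochs_per_format = {"state": [3, 7], "quintet+language": [7, 12]}

# `unify_batches` then re-serializes only the union of missing epochs across formats
epochs = sorted(set().union(*missing_epochs_per_format.values()))   # -> [3, 7, 12]
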
/voltron/preprocessing/v1/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | transforms.py
3 |
4 | Default image/video transformations for various datasets.
5 | """
6 | from typing import Any, Tuple
7 |
8 | import cv2
9 | import numpy as np
10 | import torch
11 | from torchvision.transforms import Compose, ConvertImageDtype, Lambda, Normalize
12 |
13 |
14 | # Definitions of Video Transformations (Reference: `something-something-v2-baseline`)
15 | class ComposeMix:
16 | def __init__(self, transforms):
17 | self.transforms = transforms
18 |
19 | def __call__(self, imgs):
20 | for transformation, scope in self.transforms:
21 | if scope == "img":
22 | for idx, img in enumerate(imgs):
23 | imgs[idx] = transformation(img)
24 | elif scope == "vid":
25 | imgs = transformation(imgs)
26 | else:
27 | raise ValueError("Please specify a valid transformation...")
28 | return imgs
29 |
30 |
31 | class RandomCropVideo:
32 | def __init__(self, size):
33 | self.size = size
34 |
35 | def __call__(self, imgs):
36 | th, tw = self.size
37 | h, w = imgs[0].shape[:2]
38 | x1, y1 = np.random.randint(0, w - tw), np.random.randint(0, h - th)
39 | for idx, img in enumerate(imgs):
40 | imgs[idx] = img[y1 : y1 + th, x1 : x1 + tw]
41 | return imgs
42 |
43 |
44 | class Scale:
45 | def __init__(self, size):
46 | self.size = size
47 |
48 | def __call__(self, img):
49 | return cv2.resize(img, tuple(self.size))
50 |
51 |
52 | def identity(x):
53 | """Transform needs to be pickleable for multiprocessing.spawn()."""
54 | return x.float()
55 |
56 |
57 | def get_pre_transform(dataset: str, resolution: int, scale_factor: float = 1.1) -> ComposeMix:
58 | """Defines a `pre` transform to be applied *when serializing the images* (first pass)."""
59 | if dataset == "sth-sth-v2":
60 | if scale_factor > 1:
61 | transform = ComposeMix(
62 | [
63 | [Scale((int(resolution * scale_factor), int(resolution * scale_factor))), "img"],
64 | [RandomCropVideo((resolution, resolution)), "vid"],
65 | ]
66 | )
67 | else:
68 | transform = ComposeMix(
69 | [
70 | [Scale((int(resolution * scale_factor), int(resolution * scale_factor))), "img"],
71 | ]
72 | )
73 |
74 | return transform
75 | else:
76 | raise NotImplementedError(f"(Pre) transforms for dataset `{dataset}` not yet implemented!")
77 |
78 |
79 | def get_online_transform(dataset: str, model_arch: str, normalization: Tuple[Any, Any]) -> Compose:
80 | """Defines an `online` transform to be applied *when batching the images* (during training/validation)."""
81 | if dataset == "sth-sth-v2":
82 | # Note: R3M does *not* expect 0-1 normalized (then ImageNet normalized) images --> just cast to float (identity).
83 | if model_arch in {"v-r3m", "v-rn3m"}:
84 | return Compose([Lambda(identity)])
85 | else:
86 | return Compose([ConvertImageDtype(torch.float), Normalize(mean=normalization[0], std=normalization[1])])
87 | else:
88 | raise NotImplementedError(f"(Online) transforms for dataset `{dataset}` not yet implemented!")
89 |
--------------------------------------------------------------------------------
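
A brief sketch of composing the v1 transforms by hand, mirroring `get_pre_transform` above (frame sizes are made up): each `ComposeMix` entry pairs a transform with a scope, where "img" applies per frame and "vid" applies once to the whole clip:

import numpy as np
from voltron.preprocessing.v1.transforms import ComposeMix, RandomCropVideo, Scale

pre = ComposeMix(
    [
        [Scale((246, 246)), "img"],            # ~224 * 1.1, resized frame-by-frame
        [RandomCropVideo((224, 224)), "vid"],  # one shared random crop for the whole clip
    ]
)
clip = [np.zeros((256, 340, 3), dtype=np.uint8) for _ in range(8)]  # dummy H x W x C frames
cropped = pre(clip)                                                 # 8 frames of 224 x 224 x 3
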
/voltron/preprocessing/v1/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | utils.py
3 |
4 | Preprocessing utilities, including functions for dry-runs and processing a single video (helpers for multiprocessing
5 | calls down the line).
6 | """
7 | import glob
8 | import json
9 | import logging
10 | import os
11 | import sys
12 | import time
13 | from pathlib import Path
14 | from typing import Any, Dict, List, Optional, Set, Tuple
15 |
16 | import av
17 | import cv2
18 | import numpy as np
19 | from hurry.filesize import alternative, size
20 | from rich.progress import track
21 | from tqdm import tqdm
22 |
23 | from voltron.preprocessing.v1.transforms import ComposeMix
24 |
25 | # Grab Logger
26 | overwatch = logging.getLogger(__file__)
27 | logging.getLogger("libav").setLevel(logging.ERROR)
28 |
29 |
30 | # Videos are saved as `train_dir/{vid}/{vid}_idx={i}.jpg`
31 | def get_path(save_dir: Path, v: str, i: int) -> str:
32 | return str(save_dir / v / f"{v}_idx={i}.jpg")
33 |
34 |
35 | def do_dry_run(
36 | name: str,
37 | path: str,
38 | n_train_videos: int,
39 | n_val_videos: int,
40 | train_ids: List[str],
41 | val_ids: List[str],
42 | pre_transform: ComposeMix,
43 | n_samples: int = 1000,
44 | ) -> None:
45 | """Iterates through a small subset of the total dataset, logs n_frames & average image size for estimation."""
46 | dry_run_metrics = {
47 | "n_frames": [],
48 | "jpg_sizes": [],
49 | "n_samples": n_samples,
50 | "time_per_example": [],
51 | "blank": str(Path(path) / "blank.jpg"),
52 | }
53 | if name == "sth-sth-v2":
54 | for k, n_iter, vids in [("train", n_train_videos, train_ids), ("val", n_val_videos, val_ids)]:
55 | for idx in track(range(n_iter), description=f"Reading {k.capitalize()} Videos =>> ", transient=True):
56 | vid = vids[idx]
57 | container = av.open(str(Path(path) / "videos" / f"{vid}.webm"))
58 | try:
59 | imgs = [f.to_rgb().to_ndarray() for f in container.decode(video=0)]
60 | except (RuntimeError, ZeroDivisionError) as e:
61 | overwatch.error(f"{type(e).__name__}: WebM reader cannot open `{vid}.webm` - continuing...")
62 | continue
63 |
64 | # Close container
65 | container.close()
66 |
67 | # Apply `pre_transform`
68 | imgs = pre_transform(imgs)
69 |
70 | # Dry-Run Handling --> write a dummy JPEG to collect size statistics, dump, and move on...
71 | dry_run_metrics["n_frames"].append(len(imgs))
72 | while dry_run_metrics["n_samples"] > 0 and len(imgs) > 0:
73 | img = imgs.pop(0)
74 | cv2.imwrite(str(dry_run_metrics["blank"]), cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
75 | dry_run_metrics["jpg_sizes"].append(os.path.getsize(dry_run_metrics["blank"]))
76 | dry_run_metrics["n_samples"] -= 1
77 |
78 | # Compute nice totals for "dry-run" estimation
79 | total_clips = len(train_ids) + len(val_ids)
80 |
81 | else:
82 | raise NotImplementedError(f"Dry Run for Dataset `{name}` not yet implemented!")
83 |
84 | # Compute Aggregate Statistics and gently exit...
85 | avg_size, avg_frames = np.mean(dry_run_metrics["jpg_sizes"]), int(np.mean(dry_run_metrics["n_frames"]))
86 | overwatch.info("Dry-Run Statistics =>>")
87 | overwatch.info(f"\t> A video has on average `{avg_frames}` frames at {size(avg_size, system=alternative)}")
88 | overwatch.info(f"\t> So - 1 video ~ {size(avg_frames * avg_size, system=alternative)}")
89 | overwatch.info(
90 | f"\t> With the full dataset of {total_clips} Train + Val videos ~"
91 | f" {size(total_clips * avg_frames * avg_size, system=alternative)}"
92 | )
93 | overwatch.info("Dry-Run complete, do what you will... exiting ✌️")
94 |
95 | # Remove dummy file...
96 | os.remove(dry_run_metrics["blank"])
97 | sys.exit(0)
98 |
99 |
100 | def process_video(
101 | name: str, path: Path, save: Path, pre_transform: ComposeMix, item: Tuple[str, str]
102 | ) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
103 | """Processes a single video file, dumps to series of image files, and returns the registry contents."""
104 | if name == "sth-sth-v2":
105 | # For sth-sth-v2, `item` corresponds to a single video clip, so just a tuple!
106 | vid, lang = item
107 | container, registration = av.open(str(Path(path) / "videos" / f"{vid}.webm")), {"language": lang, "n_frames": 0}
108 | try:
109 | imgs = [f.to_rgb().to_ndarray() for f in container.decode(video=0)]
110 | except (RuntimeError, ZeroDivisionError) as e:
111 | overwatch.error(f"{type(e).__name__}: WebM reader cannot open `{vid}.webm` - skipping...")
112 | return None, None
113 |
114 | # Close container
115 | container.close()
116 |
117 | # Book-keeping
118 | os.makedirs(save / vid, exist_ok=True)
119 | registration["n_frames"] = len(imgs)
120 |
121 | # Early exit (writes are expensive)
122 | if len(glob.glob1(save / vid, "*.jpg")) == len(imgs):
123 | return vid, registration
124 |
125 | # Apply `pre_transform` --> write individual frames, register, and return
126 | imgs = pre_transform(imgs)
127 | for i in range(len(imgs)):
128 | cv2.imwrite(get_path(save, vid, i), cv2.cvtColor(imgs[i], cv2.COLOR_RGB2BGR))
129 |
130 | # Return title & registration
131 | return vid, registration
132 |
133 | else:
134 | raise NotImplementedError(f"Process Video for Dataset `{name}` not yet implemented!")
135 |
136 |
137 | # ruff: noqa: C901
138 | def precompute_epoch(
139 | index_dir: Path,
140 | registry: Dict[str, Any],
141 | vid_dir: Path,
142 | batch_formats: Tuple[Tuple[str, Tuple[str, ...]], ...],
143 | do_initial: bool,
144 | do_final: bool,
145 | initial_final_alpha: float,
146 | n_int: int,
147 | epoch: int,
148 | is_validation: bool = False,
149 | ) -> Tuple[int, int, Optional[Set[str]]]:
150 | index_file = "validation-batches.json" if is_validation else f"train-epoch={epoch}-batches.json"
151 |
152 | # Short-Circuit
153 | if all([(index_dir / key / index_file).exists() for key, _ in batch_formats]):
154 | return -1, -1, None
155 |
156 | # Random seed is inherited from parent process... we want new randomness w/ each process
157 | np.random.seed((os.getpid() * int(time.time())) % 123456789)
158 |
159 | # Create Tracking Variables
160 | unique_states, batches = set(), {b: [] for b, _ in batch_formats}
161 |
162 | # Iterate through Registry...
163 | for vid in tqdm(registry.keys(), desc=f"Epoch {epoch}", total=len(registry), position=epoch):
164 | # The initial/final states are sampled from the first [0, \alpha) and final 1-\alpha, 1] percent of the video
165 | n_frames = registry[vid]["n_frames"]
166 | initial_idx, final_idx = 0, n_frames - 1
167 | if do_initial:
168 | initial_idx = np.random.randint(0, np.around(n_frames * initial_final_alpha))
169 |
170 | if do_final:
171 | final_idx = np.random.randint(np.around(n_frames * (1 - initial_final_alpha)), n_frames)
172 |
173 | # Assertion --> initial_idx < final_idx - len(state_elements)
174 | assert initial_idx < final_idx - n_int, "Initial & Final are too close... no way to sample!"
175 |
176 | # Assume remaining elements are just random "interior" states --> sort to get ordering!
177 | sampled_idxs = np.random.choice(np.arange(initial_idx + 1, final_idx), size=n_int, replace=False)
178 | sampled_idxs = sorted(list(sampled_idxs))
179 |
180 | # Compile full-set "batch"
181 | retrieved_states = [get_path(vid_dir, vid, x) for x in [initial_idx, *sampled_idxs] + [final_idx]]
182 |
183 | # Add batch to index for specific batch_format key...
184 | batches[batch_formats[-1][0]].append({"vid": vid, "states": retrieved_states, "n_frames": n_frames})
185 | unique_states.update(retrieved_states)
186 |
187 | # Add all other batch formats to indices...
188 | for key, elements in batch_formats[:-1]:
189 | n_states = len([x for x in elements if "state_" in x])
190 | assert (n_states <= 2) or (
191 | n_states == len(retrieved_states)
192 | ), f"Strange value of n_states={n_states} > 2 and not equal to total possible of {len(retrieved_states)}"
193 |
194 | # States are all independent -- each of the retrieved states is its own example...
195 | if n_states == 1:
196 | for idx in range(len(retrieved_states)):
197 | batches[key].append({"vid": vid, "state": retrieved_states[idx], "n_frames": n_frames})
198 |
199 | # 0K-Context is the only "valid" context for n_states == 2
200 | elif n_states == 2:
201 | assert elements == ["state_initial", "state_i", "language"], "n_states = 2 but not 0K context?"
202 |
203 | # Append 0th state to each of the remaining sampled contexts (usually 2 or 4)... each pair is an example
204 | for idx in range(1, len(retrieved_states)):
205 | batches[key].append(
206 | {"vid": vid, "states": [retrieved_states[0], retrieved_states[idx]], "n_frames": n_frames}
207 | )
208 |
209 | # We're treating the entire sequence of retrieved states as a single example (for TCN/R3M/Temporal Models)
210 | else:
211 | batches[key].append({"vid": vid, "states": retrieved_states, "n_frames": n_frames})
212 |
213 | # Write JSON Index directly to disk...
214 | for key in batches:
215 | with open(index_dir / key / index_file, "w") as f:
216 | json.dump(batches[key], f)
217 |
218 | return epoch, len(batches["state"]), unique_states
219 |
--------------------------------------------------------------------------------
/voltron/util/__init__.py:
--------------------------------------------------------------------------------
1 | from .checkpointing import CheckpointSaver, do_resume
2 | from .metrics import Metrics
3 | from .utilities import ResumeableDistributedSampler, set_global_seed
4 |
--------------------------------------------------------------------------------
/voltron/util/checkpointing.py:
--------------------------------------------------------------------------------
1 | """
2 | checkpointing.py
3 |
4 | Core utility class for handling model/optimizer serialization & checkpointing -- including resume from checkpoint logic.
5 |
6 | Supports the following strategies:
7 | - (k, -1, -1) --> Keep only the most recent "k" epoch checkpoints
8 | - (k, m, -1) --> Keep the most recent "k" epoch checkpoints and *every* m epoch checkpoint
9 | - (k, m, s = 2500) --> Keep "k" and "m" subject to above, but also keep a step checkpoint every *s* steps of the current epoch
10 | """
11 | import logging
12 | import os
13 | import re
14 | from collections import deque
15 | from pathlib import Path
16 | from typing import Any, Optional, Tuple
17 |
18 | import torch
19 | import torch.nn as nn
20 | from torch.optim.optimizer import Optimizer
21 |
22 | # Grab Logger
23 | overwatch = logging.getLogger(__file__)
24 |
25 |
26 | class FixedDeck(deque):
27 | def __init__(self, maxlen: int) -> None:
28 | super().__init__(maxlen=maxlen)
29 |
30 | def append(self, x: Any) -> Any:
31 | pop_value = None
32 | if self.__len__() == self.maxlen:
33 | pop_value = self.__getitem__(0)
34 |
35 | # Perform parent append and return popped value, if any!
36 | super().append(x)
37 | return pop_value
38 |
39 |
40 | class CheckpointSaver:
41 | def __init__(self, strategy: Tuple[int, int, int], run_dir: str, is_rank_zero: bool = False) -> None:
42 | """
43 | Create a checkpoint saver with the provided strategy that saves to the given path.
44 |
45 | :param strategy: Strategy, following the (k, -1, -1) -- (k, m, -1) -- (k, m, s) description above.
46 | :param run_dir: Path to root of `run_dir`.
47 | :param is_rank_zero: Boolean whether this process is global zero (no-op if not)!
48 | """
49 | (self.k, self.m, self.s), self.run_dir, self.is_rank_zero = strategy, run_dir, is_rank_zero
50 | self.recents, self.intervals, self.step_checkpoints = FixedDeck(maxlen=self.k), set(), set()
51 |
52 |         # If `self.s == -1` --> *Disable* step checkpoints (only save at the end of an epoch!)
53 | self.enable_step = self.s != -1
54 |
55 | # Create "checkpoints" subdirectory
56 | self.path = Path(run_dir) / "checkpoints"
57 | if self.is_rank_zero:
58 | os.makedirs(self.path, exist_ok=True)
59 |
60 | # Populate `step_checkpoints` on __init__ (if resuming *within* an epoch!)
61 | self.step_checkpoints.update([c for c in self.path.iterdir() if "local-epoch=" in str(c)])
62 |
63 | # Created Saver...
64 |         overwatch.info(f"Created CheckpointSaver with `k = {self.k}` -- `m = {self.m}` -- `s = {self.s}`!")
65 |
66 | def save(
67 | self,
68 | epoch: int,
69 | is_local_step: bool,
70 | model: nn.Module,
71 | optimizer: Optimizer,
72 | duration: int,
73 | local_step: Optional[int] = None,
74 | train_loss: Optional[float] = None,
75 | val_loss: Optional[float] = None,
76 | ) -> None:
77 | """Performs a global zero save operation, unlinking stale checkpoints if necessary."""
78 | if not self.is_rank_zero:
79 | return
80 |
81 | # Check if saving a `local_step` (within an epoch) or if end of epoch...
82 | if self.enable_step and is_local_step and (local_step % self.s) == 0:
83 | step_checkpoint = self.path / f"local-epoch={epoch}-step={local_step}-t={duration}.pt"
84 | torch.save(
85 | {"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict()}, step_checkpoint
86 | )
87 |
88 | # Update Relevant Trackers...
89 | self.step_checkpoints.add(step_checkpoint)
90 |
91 | elif not is_local_step:
92 | if train_loss is None and val_loss is None:
93 | checkpoint = self.path / f"epoch={epoch}-train=inf-val=inf-t={duration}.pt"
94 | else:
95 | checkpoint = self.path / f"epoch={epoch}-train={train_loss:.4f}-val={val_loss:.4f}-t={duration}.pt"
96 | torch.save(
97 | {"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict()}, checkpoint
98 | )
99 |
100 |             # Update Relevant Trackers (`m = -1` disables interval checkpoints per the strategy above)
101 |             if self.m != -1 and epoch % self.m == 0:
102 | self.intervals.add(checkpoint)
103 |
104 | # Remove all "step_checkpoints" now that we've made it to the end of an epoch!
105 | while len(self.step_checkpoints) > 0:
106 | os.remove(self.step_checkpoints.pop())
107 |
108 | # Add to recents & flush stale checkpoints...
109 | to_remove = self.recents.append(checkpoint)
110 | if to_remove is not None and to_remove not in self.intervals:
111 | os.remove(to_remove)
112 |
113 |
114 | def do_resume(resume: bool, run_dir: str) -> Tuple[Optional[Path], int, int]:
115 | """Handle `resume` logic --> consists of retrieving checkpoint_path and epoch/step computation (if resuming)."""
116 | if not resume:
117 | # We're starting a fresh run --> return None for checkpoint_path, resume_epoch = 0, resume_step = 0
118 | return None, 0, 0
119 |
120 | # === Auto-Resume Logic ===
121 | # **IMPORTANT**: We're making a few assumptions on resuming that should eventually become explicit checks:
122 | # - `accumulate_grad_batches` is exactly the same when resuming; this means:
123 | # + `model_cfg.effective_bsz`, `model_cfg.fabric_bsz`, & `accelerator_cfg.num_accelerators` are the same!
124 | # - The Weights & Biases directory `run_dir/wandb` only contains a *single run*
125 | # - The `param_groups` in `optimizer.state_dict()` are exactly the same across resumes!
126 | # + This means that (and generally should be true for resuming altogether) the architecture is the same!
127 | # - The `cfg.seed` should be the same (again, should generally be true...)
128 | all_checkpoints_path, resume_checkpoint, resume_epoch, resume_step = Path(run_dir) / "checkpoints", None, 0, 0
129 | if all_checkpoints_path.exists() and any(all_checkpoints_path.iterdir()):
130 | # Parse out the latest "complete" epoch checkpoint, as well as any "local step" checkpoints...
131 | checkpoints = list(all_checkpoints_path.iterdir())
132 | complete_checkpoint, complete_epoch = max(
133 | [
134 | (c, int(re.search("epoch=(.+?)-train", c.name).group(1)))
135 | for c in checkpoints
136 | if "local-epoch=" not in str(c)
137 | ],
138 | key=lambda x: x[1],
139 | )
140 |
141 | # Case 1 :: We have "local step" checkpoints --> will always override any "full epoch" checkpoints...
142 | local = [
143 | (
144 | c,
145 | int(re.search("local-epoch=(.+?)-step", c.name).group(1)),
146 | int(re.search("step=(.+?)[.-]", c.name).group(1)),
147 | )
148 | for c in checkpoints
149 | if "local-epoch=" in str(c)
150 | ]
151 | if len(local) > 0:
152 |             # Parse out (epoch, "highest" step) + assert no greater "full epoch" checkpoint exists!
153 | resume_checkpoint, resume_epoch, resume_step = max(local, key=lambda x: x[1:])
154 | assert resume_epoch == complete_epoch, "Epoch mismatch in `resume` from local_step!"
155 |
156 | # Case 2 :: Otherwise, we're just going to start with the last "complete" epoch
157 | else:
158 | resume_checkpoint, resume_epoch = complete_checkpoint, complete_epoch
159 |
160 | return resume_checkpoint, resume_epoch, resume_step
161 |
--------------------------------------------------------------------------------
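A minimal single-process sketch of how `do_resume` and `CheckpointSaver` compose, using a hypothetical `runs/demo` directory and a toy model/optimizer (none of which come from the repository's training scripts). The strategy `(2, 5, -1)` keeps the two most recent epoch checkpoints plus every 5th epoch, with step checkpoints disabled.

import torch
import torch.nn as nn

from voltron.util.checkpointing import CheckpointSaver, do_resume

# Hypothetical run directory and toy model/optimizer, for illustration only.
run_dir = "runs/demo"
model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Resolve any existing checkpoint -- returns (None, 0, 0) on a fresh run.
resume_checkpoint, resume_epoch, resume_step = do_resume(resume=True, run_dir=run_dir)
if resume_checkpoint is not None:
    state = torch.load(resume_checkpoint, map_location="cpu")
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])

# Keep the 2 most recent epoch checkpoints and every 5th epoch; `s = -1` disables step checkpoints.
saver = CheckpointSaver(strategy=(2, 5, -1), run_dir=run_dir, is_rank_zero=True)
for epoch in range(resume_epoch, 3):
    # ... one epoch of training/validation would go here ...
    saver.save(epoch=epoch, is_local_step=False, model=model, optimizer=optimizer, duration=0)
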
/voltron/util/metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | metrics.py
3 |
4 | Utility classes defining Metrics containers with model-specific logging to various endpoints (JSONL local logs, W&B).
5 | """
6 | import os
7 | import re
8 | import time
9 | from abc import ABC, abstractmethod
10 | from collections import deque
11 | from datetime import datetime
12 | from pathlib import Path
13 | from typing import Any, Dict, List, Optional, Tuple, Union
14 |
15 | import jsonlines
16 | import numpy as np
17 | import torch
18 | import wandb
19 |
20 | from voltron.conf import TrackingConfig
21 |
22 | # === Define Loggers (`Logger` is an abstract base class) ===
23 |
24 |
25 | class Logger(ABC):
26 | def __init__(self, run_id: str, hparams: Dict[str, Any], is_rank_zero: bool = False) -> None:
27 | self.run_id, self.hparams, self.is_rank_zero = run_id, hparams, is_rank_zero
28 |
29 | @abstractmethod
30 | def write_hyperparameters(self) -> None:
31 | raise NotImplementedError("Logger is an abstract class!")
32 |
33 | @abstractmethod
34 | def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
35 | raise NotImplementedError("Logger is an abstract class!")
36 |
37 | def finalize(self) -> None:
38 | time.sleep(1)
39 |
40 |
41 | class JSONLinesLogger(Logger):
42 | def write_hyperparameters(self) -> None:
43 | if not self.is_rank_zero:
44 | return
45 |
46 | # Only log if `is_rank_zero`
47 | with jsonlines.open(f"{self.run_id}.jsonl", mode="w", sort_keys=True) as js_logger:
48 | js_logger.write(
49 | {
50 | "run_id": self.run_id,
51 | "start_time": datetime.now().strftime("%m-%d-%H:%M"),
52 | "hparams": self.hparams,
53 | }
54 | )
55 |
56 | def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
57 | if not self.is_rank_zero:
58 | return
59 |
60 | # Only log if `is_rank_zero`
61 | with jsonlines.open(f"{self.run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
62 | js_logger.write(metrics)
63 |
64 |
65 | class WeightsBiasesLogger(Logger):
66 | def __init__(
67 | self,
68 | run_id: str,
69 | hparams: Dict[str, Any],
70 | tracking_cfg: TrackingConfig,
71 | tags: List[str],
72 | resume: bool = False,
73 | resume_id: Optional[str] = None,
74 | is_rank_zero: bool = False,
75 | ) -> None:
76 | super().__init__(run_id, hparams, is_rank_zero)
77 | self.tracking_cfg, self.tags, self.resume, self.resume_id = tracking_cfg, tags, resume, resume_id
78 | self.path = Path(os.getcwd() if self.tracking_cfg.directory is None else self.tracking_cfg.directory)
79 |
80 | # Handle (Automatic) Resume if `resume = True`
81 | if self.resume and self.resume_id is None:
82 | wandb_path = self.path / "wandb"
83 | if wandb_path.exists() and any((wandb_path / "latest-run").iterdir()):
84 |                 # Parse unique `run_id` from the `.wandb` file...
85 | wandb_fns = [f.name for f in (wandb_path / "latest-run").iterdir() if f.name.endswith(".wandb")]
86 |                 assert len(wandb_fns) == 1, f"There should only be 1 `.wandb` file... found {len(wandb_fns)}!"
87 |
88 | # Regex Match on `run-{id}.wandb`
89 | self.resume_id = re.search("run-(.+?).wandb", wandb_fns[0]).group(1)
90 |
91 | elif wandb_path.exists():
92 | raise ValueError("Starting Training from Scratch with Preexisting W&B Directory; Remove to Continue!")
93 |
94 | # Call W&B.init()
95 | self.initialize()
96 |
97 | def initialize(self) -> None:
98 | """Run W&B.init on the guarded / rank-zero process."""
99 | if not self.is_rank_zero:
100 | return
101 |
102 | # Only initialize / log if `is_rank_zero`
103 | wandb.init(
104 | project=self.tracking_cfg.project,
105 | entity=self.tracking_cfg.entity,
106 | config=self.hparams,
107 | name=self.run_id,
108 | dir=self.path,
109 | tags=self.tags,
110 | notes=self.tracking_cfg.notes,
111 | resume="allow" if self.resume else False,
112 | id=self.resume_id,
113 | )
114 |
115 | def write_hyperparameters(self) -> None:
116 | if not self.is_rank_zero:
117 | return
118 |
119 | # Only log if `is_rank_zero`
120 | wandb.config = self.hparams
121 |
122 | def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
123 | if not self.is_rank_zero:
124 | return
125 |
126 | # Only log if `is_rank_zero`
127 | wandb.log(metrics, step=global_step)
128 |
129 | def finalize(self) -> None:
130 | wandb.finish()
131 | time.sleep(150)
132 |
133 |
134 | # === Core Metrics Container :: Responsible for Initializing Loggers and Compiling/Pushing Metrics ===
135 |
136 |
137 | class Metrics:
138 | def __init__(
139 | self,
140 | active_loggers: List[str],
141 | run_id: str,
142 | hparams: Dict[str, Any],
143 | model_arch: str,
144 | is_rank_zero: bool,
145 | tracking_cfg: Optional[TrackingConfig] = None,
146 | tags: Optional[List[str]] = None,
147 | resume: bool = False,
148 | resume_id: Optional[str] = None,
149 | window: int = 128,
150 | ) -> None:
151 | """High-Level Container Logic for Metrics Logging; logic defined for each model architecture!"""
152 | self.model_arch, self.is_rank_zero, self.window = model_arch, is_rank_zero, window
153 |
154 | # Initialize Loggers
155 | self.loggers = []
156 | for log_type in active_loggers:
157 | if log_type == "jsonl":
158 | logger = JSONLinesLogger(run_id, hparams, is_rank_zero=is_rank_zero)
159 | elif log_type == "wandb":
160 | logger = WeightsBiasesLogger(
161 | run_id, hparams, tracking_cfg, tags, resume, resume_id, is_rank_zero=is_rank_zero
162 | )
163 | else:
164 | raise ValueError(f"Logger `{log_type}` is not defined!")
165 |
166 | # Add Hyperparameters --> Add to `self.loggers`
167 | logger.write_hyperparameters()
168 | self.loggers.append(logger)
169 |
170 | # Create Universal Trackers
171 | self.global_step, self.start_time, self.resume_time, self.step_start_time = 0, time.time(), 0, time.time()
172 | self.tracker = {
173 | "loss": deque(maxlen=self.window),
174 | "lr": [],
175 | "step_time": deque(maxlen=self.window),
176 | }
177 |
178 | # Create Model-Specific Trackers
179 | if self.model_arch == "v-mvp":
180 | self.tracker.update({"reconstruction_loss": deque(maxlen=self.window)})
181 |
182 | elif self.model_arch in {"v-r3m", "v-rn3m"}:
183 | self.tracker.update(
184 | {
185 | "tcn_loss": deque(maxlen=self.window),
186 | "reward_loss": deque(maxlen=self.window),
187 | "l1_loss": deque(maxlen=self.window),
188 | "l2_loss": deque(maxlen=self.window),
189 | "tcn_accuracy": deque(maxlen=self.window),
190 | "reward_accuracy": deque(maxlen=self.window),
191 | }
192 | )
193 |
194 | elif self.model_arch == "v-cond":
195 | self.tracker.update({"reconstruction_loss": deque(maxlen=self.window)})
196 |
197 | elif self.model_arch == "v-dual":
198 | self.tracker.update(
199 | {
200 | "reconstruction_loss": deque(maxlen=self.window),
201 | "zero_reconstruction_loss": deque(maxlen=self.window),
202 | "k_reconstruction_loss": deque(maxlen=self.window),
203 | }
204 | )
205 |
206 | elif self.model_arch == "v-gen":
207 | self.tracker.update(
208 | {
209 | "reconstruction_loss": deque(maxlen=self.window),
210 | "zero_reconstruction_loss": deque(maxlen=self.window),
211 | "k_reconstruction_loss": deque(maxlen=self.window),
212 | "lm_loss": deque(maxlen=self.window),
213 | "lm_ppl": deque(maxlen=self.window),
214 | }
215 | )
216 |
217 | else:
218 | raise ValueError(f"Metrics for Model `{self.model_arch}` are not implemented!")
219 |
220 | def itemize(self) -> Dict[str, torch.Tensor]:
221 | """Utility method for converting `deque[torch.Tensor] --> mean over Tensors."""
222 | return {
223 | k: torch.stack(list(v)).mean().item()
224 | for k, v in self.tracker.items()
225 | if k not in {"loss", "lr", "step_time"}
226 | }
227 |
228 | def log(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None:
229 | for logger in self.loggers:
230 | logger.write(global_step, metrics)
231 |
232 | def finalize(self) -> None:
233 | for logger in self.loggers:
234 | logger.finalize()
235 |
236 | def get_status(self, epoch: int, loss: Optional[torch.Tensor] = None) -> str:
237 | lr = self.tracker["lr"][-1] if len(self.tracker["lr"]) > 0 else 0
238 | if loss is None:
239 | return f"=>> [Epoch {epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f}"
240 |
241 | # Otherwise, embed `loss` in status!
242 | return f"=>> [Epoch {epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f} -- Loss :: {loss:.4f}"
243 |
244 | def commit(
245 | self,
246 | *,
247 | global_step: Optional[int] = None,
248 | resume_time: Optional[int] = None,
249 | lr: Optional[float] = None,
250 | update_step_time: bool = False,
251 | **kwargs,
252 | ) -> None:
253 | """Update all metrics in `self.tracker` by iterating through special positional arguments & kwargs."""
254 | if not self.is_rank_zero:
255 | return
256 |
257 | # Special Positional Arguments
258 | if global_step is not None:
259 | self.global_step = global_step
260 |
261 | if resume_time is not None:
262 | self.resume_time = resume_time
263 |
264 | if lr is not None:
265 | self.tracker["lr"].append(lr)
266 |
267 | if update_step_time:
268 | self.tracker["step_time"].append(time.time() - self.step_start_time)
269 | self.step_start_time = time.time()
270 |
271 | # Generic Keyword Arguments
272 | for key, value in kwargs.items():
273 | self.tracker[key].append(value.detach())
274 |
275 | def push(self, epoch: int) -> str:
276 | """Push current metrics to loggers with model-specific handling."""
277 | if not self.is_rank_zero:
278 | return
279 |
280 | loss = torch.stack(list(self.tracker["loss"])).mean().item()
281 | step_time, lr = np.mean(list(self.tracker["step_time"])), self.tracker["lr"][-1]
282 | status = self.get_status(epoch, loss)
283 |
284 | # Model-Specific Handling
285 | itemized = self.itemize()
286 | if self.model_arch == "v-mvp":
287 | self.log(
288 | self.global_step,
289 | metrics={
290 | "Pretrain/Step": self.global_step,
291 | "Pretrain/Epoch": epoch,
292 | "Pretrain/V-MVP Train Loss": loss,
293 | "Pretrain/Reconstruction Loss": itemized["reconstruction_loss"],
294 | "Pretrain/Learning Rate": lr,
295 | "Pretrain/Step Time": step_time,
296 | },
297 | )
298 |
299 | elif self.model_arch in {"v-r3m", "v-rn3m"}:
300 | self.log(
301 | self.global_step,
302 | metrics={
303 | "Pretrain/Step": self.global_step,
304 | "Pretrain/Epoch": epoch,
305 | f"Pretrain/V-{'R3M' if self.model_arch == 'v-r3m' else 'RN3M'} Train Loss": loss,
306 | "Pretrain/TCN Loss": itemized["tcn_loss"],
307 | "Pretrain/Reward Loss": itemized["reward_loss"],
308 | "Pretrain/L1 Loss": itemized["l1_loss"],
309 | "Pretrain/L2 Loss": itemized["l2_loss"],
310 | "Pretrain/TCN Accuracy": itemized["tcn_accuracy"],
311 | "Pretrain/Reward Accuracy": itemized["reward_accuracy"],
312 | "Pretrain/Learning Rate": lr,
313 | "Pretrain/Step Time": step_time,
314 | },
315 | )
316 |
317 | elif self.model_arch == "v-cond":
318 | self.log(
319 | self.global_step,
320 | metrics={
321 | "Pretrain/Step": self.global_step,
322 | "Pretrain/Epoch": epoch,
323 | "Pretrain/V-Cond Train Loss": loss,
324 | "Pretrain/Reconstruction Loss": itemized["reconstruction_loss"],
325 | "Pretrain/Learning Rate": lr,
326 | "Pretrain/Step Time": step_time,
327 | },
328 | )
329 |
330 | elif self.model_arch == "v-dual":
331 | self.log(
332 | self.global_step,
333 | metrics={
334 | "Pretrain/Step": self.global_step,
335 | "Pretrain/Epoch": epoch,
336 | "Pretrain/V-Dual Train Loss": loss,
337 | "Pretrain/Reconstruction Loss": itemized["reconstruction_loss"],
338 | "Pretrain/Zero Reconstruction Loss": itemized["zero_reconstruction_loss"],
339 | "Pretrain/K Reconstruction Loss": itemized["k_reconstruction_loss"],
340 | "Pretrain/Learning Rate": lr,
341 | "Pretrain/Step Time": step_time,
342 | },
343 | )
344 |
345 | elif self.model_arch == "v-gen":
346 | self.log(
347 | self.global_step,
348 | metrics={
349 | "Pretrain/Step": self.global_step,
350 | "Pretrain/Epoch": epoch,
351 | "Pretrain/V-Gen Train Loss": loss,
352 | "Pretrain/Reconstruction Loss": itemized["reconstruction_loss"],
353 | "Pretrain/Zero Reconstruction Loss": itemized["zero_reconstruction_loss"],
354 | "Pretrain/K Reconstruction Loss": itemized["k_reconstruction_loss"],
355 | "Pretrain/CLM Loss": itemized["lm_loss"],
356 | "Pretrain/CLM Perplexity": itemized["lm_ppl"],
357 | "Pretrain/LM Loss": itemized["lm_loss"],
358 | "Pretrain/LM Perplexity": itemized["lm_ppl"],
359 | "Pretrain/Learning Rate": lr,
360 | "Pretrain/Step Time": step_time,
361 | },
362 | )
363 |
364 | else:
365 | raise ValueError(f"Metrics.push() for Model `{self.model_arch}` is not implemented!")
366 |
367 | return status
368 |
369 | def push_epoch(self, epoch: int, val_loss: torch.Tensor) -> Tuple[str, torch.Tensor, int]:
370 | """End-of-Epoch => Push accumulated metrics to loggers with model-specific handling."""
371 | if not self.is_rank_zero:
372 | return
373 |
374 | # Compute End-of-Epoch Specialized Metrics
375 | loss, step_time = torch.stack(list(self.tracker["loss"])).mean(), np.mean(list(self.tracker["step_time"]))
376 | lr, duration = self.tracker["lr"][-1], int(time.time() - self.start_time) + self.resume_time
377 | epoch_status = (
378 | f"[Epoch {epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f} -- Loss :: {loss:.4f} "
379 | f"-- Val Loss :: {val_loss:.4f} -- Total Time (sec) :: {duration}"
380 | )
381 |
382 | # Log for Model
383 | p_arch = {
384 | "v-mvp": "MVP",
385 | "v-r3m": "R3M (ViT)",
386 | "v-rn3m": "R3M (RN)",
387 | "v-cond": "V-Cond",
388 | "v-dual": "V-Dual",
389 | "v-gen": "V-Gen",
390 | }[self.model_arch]
391 | self.log(
392 | self.global_step,
393 | metrics={
394 | "Pretrain/Step": self.global_step,
395 | "Pretrain/Epoch": epoch,
396 | "Pretrain/Training Duration": duration,
397 | f"Pretrain/{p_arch} Train Epoch Loss": loss.item(),
398 | f"Pretrain/{p_arch} Train Loss": loss.item(),
399 | f"Pretrain/{p_arch} Validation Loss": val_loss.item(),
400 | "Pretrain/Learning Rate": lr,
401 | "Pretrain/Step Time": step_time,
402 | },
403 | )
404 |
405 | return epoch_status, loss, duration
406 |
--------------------------------------------------------------------------------
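A minimal sketch of the `Metrics` lifecycle with only the JSONL logger active (so no W&B project or `TrackingConfig` is needed), using the `v-cond` trackers defined above; the run id, hyperparameters, and loss values are invented for illustration.

import torch

from voltron.util.metrics import Metrics

# Hypothetical run id & hparams; the `jsonl` logger writes to `<run_id>.jsonl` in the working directory.
metrics = Metrics(
    active_loggers=["jsonl"],
    run_id="v-cond-demo",
    hparams={"lr": 1e-3},
    model_arch="v-cond",
    is_rank_zero=True,
)

for step in range(4):
    # `commit` expects tensor values for tracked metrics (they are `.detach()`-ed internally).
    metrics.commit(
        global_step=step,
        lr=1e-3,
        update_step_time=True,
        loss=torch.tensor(0.5),
        reconstruction_loss=torch.tensor(0.5),
    )

# Push windowed averages to all active loggers; returns a printable status string.
print(metrics.push(epoch=0))
metrics.finalize()
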
/voltron/util/utilities.py:
--------------------------------------------------------------------------------
1 | """
2 | utilities.py
3 |
4 | General utilities for randomness, distributed training, and miscellaneous checks in PyTorch.
5 |
6 | === Randomness ===
7 |
8 | Random `seed_everything` functionality is taken directly from PyTorch-Lightning:
9 | > Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py
10 |
11 | This is pretty important to get right if we're ever randomly generating our masks (or prefix dropout) inside our
12 | Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime
13 | we inject randomness from non-PyTorch sources (e.g., numpy, random)!
14 | > Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
15 |
16 | === Distributed / DDP Training ===
17 |
18 | Utilities provide a standard API across single-GPU/multi-GPU/multi-node training. Assumes that code is running with
19 | one of the following strategies:
20 | - Single Process (on CPU?, GPU)
21 | - DDP (GPU, Multi-Node GPU) --> uses the `torchrun`/`torch.distributed` API & semantics
22 |
23 | Key Terminology
24 | -> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogenous!
25 | -> Rank :: Integer index of current process in the total world size
26 | -> Local Rank :: Local index on given node in [0, Devices per Node]
27 | """
28 | import os
29 | import random
30 | from typing import Callable, Iterator, Optional, TypeVar
31 |
32 | import numpy as np
33 | import torch
34 | from torch.utils.data import Dataset
35 | from torch.utils.data.distributed import DistributedSampler
36 |
37 | T_co = TypeVar("T_co", covariant=True)
38 |
39 |
40 | # === Randomness ===
41 |
42 |
43 | def worker_init_function(worker_id: int) -> None:
44 | """
45 | Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo:
46 | > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
47 |
48 | Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that
49 | you can run iterative splitting on to get new (predictable) randomness.
50 |
51 | :param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question.
52 | """
53 | # Get current `rank` (if running distributed) and `process_seed`
54 | global_rank, process_seed = int(os.environ["LOCAL_RANK"]), torch.initial_seed()
55 |
56 | # Back out the "base" (original) seed - the per-worker seed is set in PyTorch:
57 | # > https://pytorch.org/docs/stable/data.html#data-loading-randomness
58 | base_seed = process_seed - worker_id
59 |
60 | # "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library...
61 | seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank])
62 |
63 | # Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array!
64 | np.random.seed(seed_seq.generate_state(4))
65 |
66 | # Spawn distinct child sequences for PyTorch (reseed) and stdlib random
67 | torch_seed_seq, random_seed_seq = seed_seq.spawn(2)
68 |
69 |     # Torch Manual seed takes 64 bits (so just specify a dtype of uint64)
70 | torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0])
71 |
72 | # Use 128 Bits for `random`, but express as integer instead of as an array
73 | random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum()
74 | random.seed(random_seed)
75 |
76 |
77 | def set_global_seed(seed: int, get_worker_init_fn: bool = False) -> Optional[Callable[[int], None]]:
78 | """Sets seed for all randomness libraries (mostly random, numpy, torch) and produces a `worker_init_fn`"""
79 | assert np.iinfo(np.uint32).min < seed < np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!"
80 |
81 | # Set Seed as an Environment Variable
82 | os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed)
83 | random.seed(seed)
84 | np.random.seed(seed)
85 | torch.manual_seed(seed)
86 |
87 | return worker_init_function if get_worker_init_fn else None
88 |
89 |
90 | # === Distributed Training ===
91 |
92 |
93 | class ResumeableDistributedSampler(DistributedSampler):
94 | def __init__(
95 | self,
96 | seen_examples: int,
97 | resume_epoch: int,
98 | dataset: Dataset,
99 | num_replicas: int,
100 | rank: int,
101 | shuffle: bool = True,
102 | seed: int = 0,
103 | ) -> None:
104 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed)
105 | self.seen_examples, self.resume_epoch, self.do_resume = seen_examples, resume_epoch, True
106 |
107 | # Set `seen_examples_per_replica` --> this is necessary for when we re-wrap the iterator in self.__iter__()
108 | # > Note: `seen_examples` is across _all_ replicas --> so divide!
109 | self.seen_examples_per_replica = self.seen_examples // self.num_replicas
110 |
111 | def __iter__(self) -> Iterator[T_co]:
112 | epoch_iterator = super().__iter__()
113 | if self.do_resume:
114 | # Unpack iterator --> list, slice off the first `seen_examples_per_replica` examples, and re-wrap!
115 | leftover_idxs = list(epoch_iterator)[self.seen_examples_per_replica :]
116 | return iter(leftover_idxs)
117 | else:
118 | return epoch_iterator
119 |
120 | def __len__(self) -> int:
121 | if self.do_resume:
122 |             # Remove the "seen" samples from self.num_samples; num_samples is *per replica*!
123 | return self.num_samples - self.seen_examples_per_replica
124 | else:
125 | return self.num_samples
126 |
127 | def set_epoch(self, epoch: int) -> None:
128 | # If epoch != self.resume_epoch --> we're in "regular DistributedSampler" mode (just a wrapper class)
129 | # > Intuition: We should *only* truncate examples on the first epoch upon resuming!
130 | self.epoch = epoch
131 | if self.epoch != self.resume_epoch:
132 | self.do_resume = False
133 |
--------------------------------------------------------------------------------
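A minimal sketch wiring `set_global_seed` and the returned worker-init function into a `DataLoader`. Note that `worker_init_function` reads `LOCAL_RANK` from the environment, so the sketch sets it explicitly for a single-process run; the dataset and seed are arbitrary.

import os

import torch
from torch.utils.data import DataLoader, TensorDataset

from voltron.util.utilities import set_global_seed

if __name__ == "__main__":
    # `worker_init_function` looks up LOCAL_RANK; set it manually when not launched via torchrun.
    os.environ.setdefault("LOCAL_RANK", "0")

    # Seed random / numpy / torch and retrieve the per-worker init function for the DataLoader.
    worker_init_fn = set_global_seed(7, get_worker_init_fn=True)

    dataset = TensorDataset(torch.randn(32, 4))
    loader = DataLoader(dataset, batch_size=8, num_workers=2, shuffle=True, worker_init_fn=worker_init_fn)
    for (batch,) in loader:
        print(batch.shape)
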
/voltron/util/v1/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siddk/voltron-robotics/1b299bf5cfa06673a3738aa6e15423b92a9922cd/voltron/util/v1/__init__.py
--------------------------------------------------------------------------------
/voltron/util/v1/checkpointing.py:
--------------------------------------------------------------------------------
1 | """
2 | checkpointing.py
3 |
4 | XLA-specific utility class for handling model/optimizer serialization & checkpointing.
5 |
6 | Supports the following strategies:
7 | - (k, -1, -1) --> Keep only the most recent "k" epoch checkpoints
8 | - (k, m, -1) --> Keep the most recent "k" epoch checkpoints and *every* m epoch checkpoint
9 | - (k, m, s = 2500) --> Keep "k" and "m" subject to above, but also keep a step checkpoint every *s* steps of the current epoch
10 | """
11 | import os
12 | from collections import deque
13 | from pathlib import Path
14 | from typing import Any, Optional, Tuple
15 |
16 | import torch.nn as nn
17 | from torch.optim.optimizer import Optimizer
18 |
19 |
20 | class FixedDeck(deque):
21 | def __init__(self, maxlen: int) -> None:
22 | super().__init__(maxlen=maxlen)
23 |
24 | def append(self, x: Any) -> Any:
25 | pop_value = None
26 | if self.__len__() == self.maxlen:
27 | pop_value = self.__getitem__(0)
28 |
29 | # Perform parent append and return popped value, if any!
30 | super().append(x)
31 | return pop_value
32 |
33 |
34 | class XLACheckpointSaver:
35 | def __init__(self, strategy: Tuple[int, int, int], run_dir: str) -> None:
36 | """
37 | Create a checkpoint saver with the provided strategy that saves to the given path, with XLA-specific handling.
38 |
39 | :param strategy: Strategy, following the (k, -1, -1) -- (k, m, -1) -- (k, m, s) description above.
40 | :param run_dir: Path to root of `run_dir`
41 | """
42 | import torch_xla.core.xla_model as xm
43 |
44 | (self.k, self.m, self.s), self.run_dir = strategy, run_dir
45 | self.recents, self.intervals, self.step_checkpoints = FixedDeck(maxlen=self.k), set(), set()
46 |
47 | # If `self.s` is -1 --> disable step_checkpoints
48 | self.enable_step = self.s != -1
49 |
50 | # Create "checkpoints" subdirectory
51 | self.path = Path(run_dir) / "checkpoints"
52 | if xm.is_master_ordinal(local=False):
53 | os.makedirs(self.path, exist_ok=True)
54 |
55 | # Populate `step_checkpoints` on __init__ (if resuming *within* an epoch...)
56 | self.step_checkpoints.update([c for c in self.path.iterdir() if "local-epoch=" in str(c)])
57 |
58 | # Create Saver
59 |         xm.master_print(f"Created Saver w/ `k` = {self.k}, `m` = {self.m}, `s` = {self.s}!")
60 |
61 | def save(
62 | self,
63 | epoch: int,
64 | is_local_step: bool,
65 | model: nn.Module,
66 | optimizer: Optimizer,
67 | duration: int,
68 | local_step: Optional[int] = None,
69 | train_loss: Optional[float] = None,
70 | val_loss: Optional[float] = None,
71 | ) -> None:
72 | """Performs the save operation, unlinking existing stale checkpoints, if necessary."""
73 | import torch_xla.core.xla_model as xm
74 |
75 | # Check if saving a `local_step` (within an epoch) or if saving an `epoch`
76 | if self.enable_step and is_local_step and (local_step % self.s) == 0:
77 | # Create filename
78 | step_checkpoint = self.path / f"local-epoch={epoch}-step={local_step}-t={duration}.pt"
79 |
80 | # Perform actual save action...
81 | # > IMPORTANT --> XLA/XM will throw an error if optimizer has "param_groups" so only save "state"...
82 | xm.save([model.state_dict(), optimizer.state_dict()["state"]], step_checkpoint)
83 | if xm.is_master_ordinal(local=False):
84 | self.step_checkpoints.add(step_checkpoint)
85 |
86 | elif not is_local_step:
87 | # Create filename
88 | if train_loss is None and val_loss is None:
89 | checkpoint = self.path / f"epoch={epoch}-train=inf-val=inf-t={duration}.pt"
90 | else:
91 | checkpoint = self.path / f"epoch={epoch}-train={train_loss:.4f}-val={val_loss:.4f}-t={duration}.pt"
92 |
93 | # Perform actual save action...
94 | # > IMPORTANT --> XLA/XM will throw an error if optimizer has "param_groups" so only save "state"...
95 | xm.save([model.state_dict(), optimizer.state_dict()["state"]], checkpoint)
96 |
97 | if xm.is_master_ordinal(local=False):
98 |                 # Conditional Check for M -- keep if modulated by interval (`m = -1` disables interval checkpoints)
99 |                 if self.m != -1 and epoch % self.m == 0:
100 | self.intervals.add(checkpoint)
101 |
102 | # Remove all "step_checkpoints" now that we successfully made it to the end of the epoch!
103 | while len(self.step_checkpoints) > 0:
104 | os.remove(self.step_checkpoints.pop())
105 |
106 | # Finally, recency add & unlink/delete if necessary
107 | to_remove = self.recents.append(checkpoint)
108 | if to_remove is not None and to_remove not in self.intervals:
109 | os.remove(to_remove)
110 |
--------------------------------------------------------------------------------
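Because `xm.save` above stores only `optimizer.state_dict()["state"]` (dropping `param_groups`), reloading requires rebuilding a full optimizer state dict. The repository's XLA pretraining scripts handle this themselves; the sketch below only illustrates the idea with a toy model/optimizer and a hypothetical checkpoint path, assuming the model and optimizer are constructed identically before loading.

import torch
import torch.nn as nn

# Toy model/optimizer standing in for the real pretraining setup (illustration only).
model = nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Hypothetical path written by XLACheckpointSaver.save(...) above.
checkpoint_path = "runs/demo/checkpoints/epoch=0-train=inf-val=inf-t=0.pt"

# The checkpoint is a two-element list: [model.state_dict(), optimizer.state_dict()["state"]].
model_state, optimizer_state = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(model_state)

# Pair the saved "state" with the *current* param_groups to recover a loadable optimizer state dict.
optimizer_state_dict = optimizer.state_dict()
optimizer_state_dict["state"] = optimizer_state
optimizer.load_state_dict(optimizer_state_dict)
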
/voltron/util/v1/distributed.py:
--------------------------------------------------------------------------------
1 | """
2 | distributed.py
3 |
4 | Key distributed utilities; notably provides a standard API for getting relevant data from either CPU/GPU or XLA (TPU)
5 | devices, since the underlying implementation does differ substantially.
6 |
7 | Assumes that code is running with one of the following strategies:
8 | - Single Process (on CPU, GPU)
9 | - DDP (CPU, GPU)... uses the torch.distributed.launch API & semantics
10 | - XMP Spawn (TPU)... TPU based XLA + Multiprocessing Spawn semantics
11 |
12 | Key Terminology
13 | -> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogenous!
14 | -> Rank :: Integer index of current process in the total world size
15 | -> Local Rank :: Local index on given node in [0, Devices per Node]
16 | """
17 | from importlib.util import find_spec
18 | from typing import Iterator, TypeVar
19 |
20 | import torch
21 | from torch.utils.data import Dataset
22 | from torch.utils.data.distributed import DistributedSampler
23 |
24 | T_co = TypeVar("T_co", covariant=True)
25 |
26 |
27 | class ResumeableDistributedSampler(DistributedSampler):
28 | def __init__(
29 | self,
30 | seen_examples: int,
31 | resume_epoch: int,
32 | dataset: Dataset,
33 | num_replicas: int,
34 | rank: int,
35 | shuffle: bool = True,
36 | seed: int = 0,
37 | ) -> None:
38 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed)
39 | self.seen_examples, self.resume_epoch, self.do_resume = seen_examples, resume_epoch, True
40 |
41 | # Set `seen_examples_per_replica` --> this is necessary for when we re-wrap the iterator in self.__iter__()
42 | # > Note: `seen_examples` is across _all_ replicas --> so divide!
43 | self.seen_examples_per_replica = self.seen_examples // self.num_replicas
44 |
45 | def __iter__(self) -> Iterator[T_co]:
46 | epoch_iterator = super().__iter__()
47 | if self.do_resume:
48 | # Unpack iterator --> list, slice off the first `seen_examples_per_replica` examples, and re-wrap!
49 | leftover_idxs = list(epoch_iterator)[self.seen_examples_per_replica :]
50 | return iter(leftover_idxs)
51 | else:
52 | return epoch_iterator
53 |
54 | def __len__(self) -> int:
55 | if self.do_resume:
56 |             # Remove the "seen" samples from self.num_samples; num_samples is *per replica*!
57 | return self.num_samples - self.seen_examples_per_replica
58 | else:
59 | return self.num_samples
60 |
61 | def set_epoch(self, epoch: int) -> None:
62 | # If epoch != self.resume_epoch --> we're in "regular DistributedSampler" mode (just a wrapper class)
63 | # > Intuition: We should *only* truncate examples on the first epoch upon resuming!
64 | self.epoch = epoch
65 | if self.epoch != self.resume_epoch:
66 | self.do_resume = False
67 |
68 |
69 | def xla_available() -> bool:
70 | try:
71 | return find_spec("torch_xla") is not None
72 | except ModuleNotFoundError:
73 | return False
74 |
75 |
76 | def get_rank() -> int:
77 | """Returns the global rank [0, World Size) of the current process."""
78 | if xla_available():
79 | import torch_xla.core.xla_model as xm
80 |
81 | # By default, if XLA is available, assume we're running under XMP Spawn
82 | return xm.get_ordinal()
83 |
84 | # Try to get rank via torch.distributed, but catch error if only single process
85 | try:
86 | return torch.distributed.get_rank()
87 |
88 | # RuntimeError => not running distributed (single process)
89 | except RuntimeError:
90 | return 0
91 |
--------------------------------------------------------------------------------
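A minimal, single-process illustration of the resume semantics: with `seen_examples=8` split across 2 replicas, each replica skips its first 4 indices on the resume epoch and reverts to standard `DistributedSampler` behavior once `set_epoch` moves past `resume_epoch`. The toy dataset and numbers are invented, and no process group is needed because `num_replicas` and `rank` are passed explicitly. The sketch imports the identical non-XLA copy of the class from `voltron.util.utilities`, since the v1 package may pull in `torch_xla`.

import torch
from torch.utils.data import TensorDataset

from voltron.util.utilities import ResumeableDistributedSampler

# 20 examples over 2 replicas => 10 per-replica samples; pretend 8 (global) were seen in epoch 3.
dataset = TensorDataset(torch.arange(20))
sampler = ResumeableDistributedSampler(
    seen_examples=8, resume_epoch=3, dataset=dataset, num_replicas=2, rank=0, shuffle=True, seed=0
)

sampler.set_epoch(3)
print(len(sampler), list(sampler))  # 10 - (8 // 2) = 6 leftover indices for this replica

sampler.set_epoch(4)
print(len(sampler), list(sampler))  # back to the full 10 per-replica indices
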
/voltron/util/v1/random.py:
--------------------------------------------------------------------------------
1 | """
2 | random.py
3 |
4 | Utilities for dealing with randomness for PyTorch, across devices (CPU, GPU, TPU).
5 |
6 | Loosely inspired by functionality in PyTorch-Lightning:
7 | > Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py
8 |
9 | This is pretty important to get right if we're ever randomly generating our masks (or prefix dropout) inside our
10 | Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime
11 | we inject randomness from non-PyTorch sources (e.g., numpy, random)!
12 | > Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
13 | """
14 | import os
15 | import random
16 | from typing import Callable
17 |
18 | import numpy as np
19 | import torch
20 |
21 | from voltron.util.v1.distributed import get_rank
22 |
23 |
24 | def set_global_seed(seed: int) -> Callable[[int], None]:
25 | """Sets seed for all randomness libraries (mostly random, numpy, torch) and produces a `worker_init_fn`"""
26 | assert np.iinfo(np.uint32).min < seed < np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!"
27 |
28 | # Set Seed as an Environment Variable
29 | os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed)
30 | random.seed(seed)
31 | np.random.seed(seed)
32 | torch.manual_seed(seed)
33 |
34 | return worker_init_function
35 |
36 |
37 | def worker_init_function(worker_id: int) -> None:
38 | """
39 | Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo:
40 | > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
41 |
42 | Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that
43 | you can run iterative splitting on to get new (predictable) randomness.
44 |
45 | :param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question.
46 | """
47 | # Get current `rank` (if running distributed) and `process_seed`
48 | global_rank, process_seed = get_rank(), torch.initial_seed()
49 |
50 | # Back out the "base" (original) seed - the per-worker seed is set in PyTorch:
51 | # > https://pytorch.org/docs/stable/data.html#data-loading-randomness
52 | base_seed = process_seed - worker_id
53 |
54 | # "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library...
55 | seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank])
56 |
57 | # Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array!
58 | np.random.seed(seed_seq.generate_state(4))
59 |
60 | # Spawn distinct child sequences for PyTorch (reseed) and stdlib random
61 | torch_seed_seq, random_seed_seq = seed_seq.spawn(2)
62 |
63 |     # Torch Manual seed takes 64 bits (so just specify a dtype of uint64)
64 | torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0])
65 |
66 | # Use 128 Bits for `random`, but express as integer instead of as an array
67 | random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum()
68 | random.seed(random_seed)
69 |
--------------------------------------------------------------------------------
/voltron/util/v1/xla_logger.py:
--------------------------------------------------------------------------------
1 | """
2 | xla_logger.py
3 |
4 | Utility module defining various XLA logging methods (called within marked closures), for logging metrics periodically
5 | through training & validation.
6 | """
7 | from typing import List
8 |
9 | import jsonlines
10 | import numpy as np
11 | import torch
12 | import torch_xla.core.xla_model as xm
13 | import wandb
14 |
15 |
16 | # === Generic (Cross-Model) Epoch End Update ===
17 | def log_epoch_end_update(
18 | arch: str,
19 | epoch: int,
20 | global_step: int,
21 | run_id: str,
22 | duration: int,
23 | train_losses: List[torch.Tensor],
24 | val_loss: float,
25 | lr: float,
26 | step_times: List[float],
27 | ) -> None:
28 | train_loss = torch.stack(list(train_losses)).mean()
29 | average_step_time = np.mean(list(step_times))
30 |
31 | # Console Logging --> Unclear if it'll work?
32 | xm.master_print(
33 | f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f} "
34 | f"-- Val Loss :: {val_loss:.4f} -- Total Time (sec) :: {duration}"
35 | )
36 |
37 | # Get Log-Friendly Arch
38 | p_arch = {
39 | "v-mvp": "MVP",
40 | "v-r3m": "R3M (ViT)",
41 | "v-rn3m": "R3M (RN)",
42 | "v-cond": "V-Cond",
43 | "v-dual": "V-Dual",
44 | "v-gen": "V-Gen",
45 | }[arch]
46 |
47 | # Log to Weights & Biases & JSONL
48 | blob = {
49 | "Pretrain/Step": global_step,
50 | "Pretrain/Epoch": epoch,
51 | "Pretrain/Training Duration": duration,
52 | "Pretrain/Step Time": average_step_time,
53 | f"Pretrain/{p_arch} Train Epoch Loss": train_loss.item(),
54 | f"Pretrain/{p_arch} Train Loss": train_loss.item(),
55 | f"Pretrain/{p_arch} Validation Loss": val_loss,
56 | "Pretrain/Learning Rate": lr,
57 | }
58 |
59 | wandb.log(blob, step=global_step)
60 | with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
61 | js_logger.write(blob)
62 |
63 |
64 | # === Data-Locked Reproductions ===
65 |
66 |
67 | def log_vmvp_train_update(
68 | epoch: int,
69 | global_step: int,
70 | run_id: str,
71 | train_losses: List[torch.Tensor],
72 | lr: float,
73 | reconstruction_losses: List[torch.Tensor],
74 | step_times: List[float],
75 | ) -> None:
76 | train_loss = torch.stack(list(train_losses)).mean()
77 | reconstruction_loss = torch.stack(list(reconstruction_losses)).mean()
78 | average_step_time = np.mean(list(step_times))
79 |
80 | # Console Logging --> Just log the aggregated train loss...
81 | xm.master_print(
82 | f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f}"
83 | )
84 |
85 | # Log to Weights & Biases + JSONL
86 | blob = {
87 | "Pretrain/Step": global_step,
88 | "Pretrain/Epoch": epoch,
89 | "Pretrain/V-MVP Train Loss": train_loss.item(),
90 | "Pretrain/Reconstruction Loss": reconstruction_loss.item(),
91 | "Pretrain/Learning Rate": lr,
92 | "Pretrain/Step Time": average_step_time,
93 | }
94 | wandb.log(blob, step=global_step)
95 | with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
96 | js_logger.write(blob)
97 |
98 |
99 | def log_vr3m_train_update(
100 | epoch: int,
101 | global_step: int,
102 | run_id: str,
103 | train_losses: List[torch.Tensor],
104 | lr: float,
105 | tcn_losses: List[torch.Tensor],
106 | reward_losses: List[torch.Tensor],
107 | l1_losses: List[torch.Tensor],
108 | l2_losses: List[torch.Tensor],
109 | tcn_accuracies: List[torch.Tensor],
110 | reward_accuracies: List[torch.Tensor],
111 | step_times: List[float],
112 | ) -> None:
113 | train_loss = torch.stack(list(train_losses)).mean()
114 | tcn_loss = torch.stack(list(tcn_losses)).mean()
115 | reward_loss = torch.stack(list(reward_losses)).mean()
116 | l1_loss, l2_loss = torch.stack(list(l1_losses)).mean(), torch.stack(list(l2_losses)).mean()
117 | tcn_accuracy = torch.stack(list(tcn_accuracies)).mean()
118 | reward_accuracy = torch.stack(list(reward_accuracies)).mean()
119 | average_step_time = np.mean(list(step_times))
120 |
121 | # Console Logging --> Just log the aggregated train loss...
122 | xm.master_print(
123 | f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f}"
124 | )
125 |
126 | # Log to Weights & Biases + JSONL
127 | blob = {
128 | "Pretrain/Step": global_step,
129 | "Pretrain/Epoch": epoch,
130 | "Pretrain/V-R3M Train Loss": train_loss.item(),
131 | "Pretrain/TCN Loss": tcn_loss.item(),
132 | "Pretrain/Reward Loss": reward_loss.item(),
133 | "Pretrain/L1 Loss": l1_loss.item(),
134 | "Pretrain/L2 Loss": l2_loss.item(),
135 | "Pretrain/TCN Accuracy": tcn_accuracy.item(),
136 | "Pretrain/Reward Accuracy": reward_accuracy.item(),
137 | "Pretrain/Learning Rate": lr,
138 | "Pretrain/Step Time": average_step_time,
139 | }
140 | wandb.log(blob, step=global_step)
141 | with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
142 | js_logger.write(blob)
143 |
144 |
145 | def log_vrn3m_train_update(
146 | epoch: int,
147 | global_step: int,
148 | run_id: str,
149 | train_losses: List[torch.Tensor],
150 | lr: float,
151 | tcn_losses: List[torch.Tensor],
152 | reward_losses: List[torch.Tensor],
153 | l1_losses: List[torch.Tensor],
154 | l2_losses: List[torch.Tensor],
155 | tcn_accuracies: List[torch.Tensor],
156 | reward_accuracies: List[torch.Tensor],
157 | step_times: List[float],
158 | ) -> None:
159 | train_loss = torch.stack(list(train_losses)).mean()
160 | tcn_loss = torch.stack(list(tcn_losses)).mean()
161 | reward_loss = torch.stack(list(reward_losses)).mean()
162 | l1_loss, l2_loss = torch.stack(list(l1_losses)).mean(), torch.stack(list(l2_losses)).mean()
163 | tcn_accuracy = torch.stack(list(tcn_accuracies)).mean()
164 | reward_accuracy = torch.stack(list(reward_accuracies)).mean()
165 | average_step_time = np.mean(list(step_times))
166 |
167 | # Console Logging --> Just log the aggregated train loss...
168 | xm.master_print(
169 | f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f}"
170 | )
171 |
172 | # Log to Weights & Biases + JSONL
173 | blob = {
174 | "Pretrain/Step": global_step,
175 | "Pretrain/Epoch": epoch,
176 | "Pretrain/V-RN3M Train Loss": train_loss.item(),
177 | "Pretrain/TCN Loss": tcn_loss.item(),
178 | "Pretrain/Reward Loss": reward_loss.item(),
179 | "Pretrain/L1 Loss": l1_loss.item(),
180 | "Pretrain/L2 Loss": l2_loss.item(),
181 | "Pretrain/TCN Accuracy": tcn_accuracy.item(),
182 | "Pretrain/Reward Accuracy": reward_accuracy.item(),
183 | "Pretrain/Learning Rate": lr,
184 | "Pretrain/Step Time": average_step_time,
185 | }
186 | wandb.log(blob, step=global_step)
187 | with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
188 | js_logger.write(blob)
189 |
190 |
191 | # === Voltron Models ===
192 | def log_vcond_train_update(
193 | epoch: int,
194 | global_step: int,
195 | run_id: str,
196 | train_losses: List[torch.Tensor],
197 | lr: float,
198 | reconstruction_losses: List[torch.Tensor],
199 | step_times: List[float],
200 | ) -> None:
201 | train_loss = torch.stack(list(train_losses)).mean()
202 | reconstruction_loss = torch.stack(list(reconstruction_losses)).mean()
203 | average_step_time = np.mean(list(step_times))
204 |
205 | # Console Logging --> Just log the aggregated train loss...
206 | xm.master_print(
207 | f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f}"
208 | )
209 |
210 | # Log to Weights & Biases + JSONL
211 | blob = {
212 | "Pretrain/Step": global_step,
213 | "Pretrain/Epoch": epoch,
214 | "Pretrain/V-Cond Train Loss": train_loss.item(),
215 | "Pretrain/Reconstruction Loss": reconstruction_loss.item(),
216 | "Pretrain/Learning Rate": lr,
217 | "Pretrain/Step Time": average_step_time,
218 | }
219 | wandb.log(blob, step=global_step)
220 | with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
221 | js_logger.write(blob)
222 |
223 |
224 | def log_vdual_train_update(
225 | epoch: int,
226 | global_step: int,
227 | run_id: str,
228 | train_losses: List[torch.Tensor],
229 | lr: float,
230 | reconstruction_losses: List[torch.Tensor],
231 | zero_reconstruction_losses: List[torch.Tensor],
232 | k_reconstruction_losses: List[torch.Tensor],
233 | step_times: List[float],
234 | ) -> None:
235 | train_loss = torch.stack(list(train_losses)).mean()
236 | reconstruction_loss = torch.stack(list(reconstruction_losses)).mean()
237 | zero_reconstruction_loss = torch.stack(list(zero_reconstruction_losses)).mean()
238 | k_reconstruction_loss = torch.stack(list(k_reconstruction_losses)).mean()
239 | average_step_time = np.mean(list(step_times))
240 |
241 | # Console Logging --> Just log the aggregated train loss...
242 | xm.master_print(
243 | f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f}"
244 | )
245 |
246 | # Log to Weights & Biases + JSONL
247 | blob = {
248 | "Pretrain/Step": global_step,
249 | "Pretrain/Epoch": epoch,
250 | "Pretrain/V-Dual Train Loss": train_loss.item(),
251 | "Pretrain/Reconstruction Loss": reconstruction_loss.item(),
252 | "Pretrain/Zero Reconstruction Loss": zero_reconstruction_loss.item(),
253 | "Pretrain/K Reconstruction Loss": k_reconstruction_loss.item(),
254 | "Pretrain/Learning Rate": lr,
255 | "Pretrain/Step Time": average_step_time,
256 | }
257 | wandb.log(blob, step=global_step)
258 | with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
259 | js_logger.write(blob)
260 |
261 |
262 | def log_vgen_train_update(
263 | epoch: int,
264 | global_step: int,
265 | run_id: str,
266 | train_losses: List[torch.Tensor],
267 | lr: float,
268 | reconstruction_losses: List[torch.Tensor],
269 | lm_losses: List[torch.Tensor],
270 | lm_ppl: List[torch.Tensor],
271 | zero_reconstruction_losses: List[torch.Tensor],
272 | k_reconstruction_losses: List[torch.Tensor],
273 | step_times: List[float],
274 | ) -> None:
275 | train_loss = torch.stack(list(train_losses)).mean()
276 | reconstruction_loss = torch.stack(list(reconstruction_losses)).mean()
277 | lm_loss = torch.stack(list(lm_losses)).mean()
278 | lm_perplexity = torch.stack(list(lm_ppl)).mean()
279 | zero_reconstruction_loss = torch.stack(list(zero_reconstruction_losses)).mean()
280 | k_reconstruction_loss = torch.stack(list(k_reconstruction_losses)).mean()
281 | average_step_time = np.mean(list(step_times))
282 |
283 | # Console Logging --> Just log the aggregated train loss...
284 | xm.master_print(
285 | f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {train_loss:.4f} --"
286 | f" Reconstruction Loss {reconstruction_loss:.4f} -- LM Loss {lm_loss:.4f}"
287 | )
288 |
289 | # Log to Weights & Biases + JSONL
290 | blob = {
291 | "Pretrain/Step": global_step,
292 | "Pretrain/Epoch": epoch,
293 | "Pretrain/V-Gen Train Loss": train_loss.item(),
294 | "Pretrain/Reconstruction Loss": reconstruction_loss.item(),
295 | "Pretrain/CLM Loss": lm_loss.item(),
296 | "Pretrain/CLM Perplexity": lm_perplexity.item(),
297 | "Pretrain/LM Loss": lm_loss.item(),
298 | "Pretrain/LM Perplexity": lm_perplexity.item(),
299 | "Pretrain/Zero Reconstruction Loss": zero_reconstruction_loss.item(),
300 | "Pretrain/K Reconstruction Loss": k_reconstruction_loss.item(),
301 | "Pretrain/Learning Rate": lr,
302 | "Pretrain/Step Time": average_step_time,
303 | }
304 | wandb.log(blob, step=global_step)
305 | with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
306 | js_logger.write(blob)
307 |
--------------------------------------------------------------------------------
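These helpers are meant to be invoked from inside XLA step closures (per the module docstring), so logged tensors are only materialized at step boundaries. Below is a hypothetical sketch of scheduling the V-Cond update with `xm.add_step_closure`; it assumes an active W&B run (`wandb.init`) and real loss tensors collected during the epoch, neither of which is shown here.

import torch
import torch_xla.core.xla_model as xm

from voltron.util.v1.xla_logger import log_vcond_train_update

# Hypothetical bookkeeping gathered over the last logging window of the inner training loop.
epoch, global_step, run_id, lr = 0, 100, "v-cond-demo", 1e-4
train_losses = [torch.tensor(0.61), torch.tensor(0.58)]
reconstruction_losses = [torch.tensor(0.34), torch.tensor(0.31)]
step_times = [0.42, 0.40]

# Schedule the logging call so it runs when the current XLA step is materialized.
xm.add_step_closure(
    log_vcond_train_update,
    args=(epoch, global_step, run_id, train_losses, lr, reconstruction_losses, step_times),
)
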