├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml ├── init.sh ├── release_message.sh └── workflows │ ├── lint.yml │ ├── main.yml │ ├── release.yml │ └── rename_project.yml ├── .gitignore ├── CONTRIBUTING.md ├── Containerfile ├── HISTORY.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── examples └── data_collection_and_load.py ├── format.sh ├── mkdocs.yml ├── pyproject.toml ├── pytest.ini ├── robodm ├── __init__.py ├── dataset.py ├── feature.py ├── loader │ ├── __init__.py │ ├── base.py │ ├── hdf5.py │ ├── rlds.py │ └── vla.py ├── trajectory.py ├── trajectory_base.py ├── trajectory_factory.py └── utils.py └── tests ├── README.md ├── __init__.py ├── conftest.py ├── test_fixtures.py ├── test_loaders.py ├── test_openx_trajectory.py ├── test_ray_vla_loader.py ├── test_shape_codec_logic.py ├── test_time_manager.py ├── test_trajectory.py ├── test_trajectory_enhanced_loading.py ├── test_trajectory_loader_edge_cases.py └── test_trajectory_loader_performance.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: ['BerkeleyAutomation'] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug, help wanted 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Version [e.g. 22] 29 | 30 | **Additional context** 31 | Add any other context about the problem here. 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement, question 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 
15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ### Summary :memo: 2 | _Write an overview about it._ 3 | 4 | ### Details 5 | _Describe more what you did on changes._ 6 | 1. (...) 7 | 2. (...) 8 | 9 | ### Bugfixes :bug: (delete if you didn't have any) 10 | - 11 | 12 | ### Checks 13 | - [ ] Closed #798 14 | - [ ] Tested Changes 15 | - [ ] Stakeholder Approval 16 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" -------------------------------------------------------------------------------- /.github/init.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | overwrite_template_dir=0 3 | 4 | while getopts t:o flag 5 | do 6 | case "${flag}" in 7 | t) template=${OPTARG};; 8 | o) overwrite_template_dir=1;; 9 | esac 10 | done 11 | 12 | if [ -z "${template}" ]; then 13 | echo "Available templates: flask" 14 | read -p "Enter template name: " template 15 | fi 16 | 17 | repo_urlname=$(basename -s .git `git config --get remote.origin.url`) 18 | repo_name=$(basename -s .git `git config --get remote.origin.url` | tr '-' '_' | tr '[:upper:]' '[:lower:]') 19 | repo_owner=$(git config --get remote.origin.url | awk -F ':' '{print $2}' | awk -F '/' '{print $1}') 20 | echo "Repo name: ${repo_name}" 21 | echo "Repo owner: ${repo_owner}" 22 | echo "Repo urlname: ${repo_urlname}" 23 | 24 | if [ -f ".github/workflows/rename_project.yml" ]; then 25 | .github/rename_project.sh -a "${repo_owner}" -n "${repo_name}" -u "${repo_urlname}" -d "Awesome ${repo_name} created by ${repo_owner}" 26 | fi 27 | 28 | function download_template { 29 | rm -rf "${template_dir}" 30 | mkdir -p .github/templates 31 | git clone "${template_url}" "${template_dir}" 32 | } 33 | 34 | echo "Using template: ${template}" 35 | template_url="https://github.com/rochacbruno/${template}-project-template" 36 | template_dir=".github/templates/${template}" 37 | if [ -d "${template_dir}" ]; then 38 | # Template directory already exists 39 | if [ "${overwrite_template_dir}" -eq 1 ]; then 40 | # user passed -o flag, delete and re-download 41 | echo "Overwriting ${template_dir}" 42 | download_template 43 | else 44 | # Ask user if they want to overwrite 45 | echo "Directory ${template_dir} already exists." 46 | read -p "Do you want to overwrite it? 
[y/N] " -n 1 -r 47 | echo 48 | if [[ $REPLY =~ ^[Yy]$ ]]; then 49 | echo "Overwriting ${template_dir}" 50 | download_template 51 | else 52 | # User decided not to overwrite 53 | echo "Using existing ${template_dir}" 54 | fi 55 | fi 56 | else 57 | # Template directory does not exist, download it 58 | echo "Downloading ${template_url}" 59 | download_template 60 | fi 61 | 62 | echo "Applying ${template} template to this project" 63 | ./.github/templates/${template}/apply.sh -a "${repo_owner}" -n "${repo_name}" -u "${repo_urlname}" -d "Awesome ${repo_name} created by ${repo_owner}" 64 | 65 | # echo "Removing temporary template files" 66 | # rm -rf .github/templates/${template} 67 | 68 | echo "Done! Review, commit and push the changes" 69 | -------------------------------------------------------------------------------- /.github/release_message.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | previous_tag=$(git tag --sort=-creatordate | sed -n 2p) 3 | git shortlog "${previous_tag}.." | sed 's/^./ &/' 4 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | paths: 6 | - '**.py' 7 | - '**/pyproject.toml' 8 | - '**/pytest.ini' 9 | pull_request: 10 | paths: 11 | - '**.py' 12 | - '**/pyproject.toml' 13 | - '**/pytest.ini' 14 | 15 | jobs: 16 | mypy: 17 | name: MyPy Type Check 18 | runs-on: ubuntu-latest 19 | steps: 20 | - name: Checkout 21 | uses: actions/checkout@v4 22 | 23 | - name: Setup Python 24 | uses: actions/setup-python@v5 25 | with: 26 | python-version: '3.10' 27 | 28 | - name: Cache pip packages 29 | uses: actions/cache@v4 30 | with: 31 | path: ~/.cache/pip 32 | key: ${{ runner.os }}-pip-mypy-${{ hashFiles('**/pyproject.toml') }} 33 | restore-keys: | 34 | ${{ runner.os }}-pip-mypy- 35 | 36 | - name: Install dependencies 37 | run: | 38 | python -m pip install --upgrade pip 39 | pip install mypy 40 | pip install -e . 
41 | 42 | - name: Run mypy 43 | run: | 44 | mypy robodm --ignore-missing-imports --check-untyped-defs --show-error-codes --pretty -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | # This is a basic workflow to help you get started with Actions 2 | 3 | name: CI 4 | 5 | # Controls when the workflow will run 6 | on: 7 | # Triggers the workflow on push or pull request events but only for the main branch 8 | push: 9 | branches: [ main, master ] 10 | pull_request: 11 | branches: [ main, master ] 12 | 13 | # Allows you to run this workflow manually from the Actions tab 14 | workflow_dispatch: 15 | 16 | env: 17 | PYTHONPATH: ${{ github.workspace }} 18 | 19 | jobs: 20 | format-check: 21 | name: Format Check 22 | runs-on: ubuntu-latest 23 | steps: 24 | - uses: actions/checkout@v4 25 | 26 | - name: Set up Python 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: '3.10' 30 | 31 | - name: Cache pip packages 32 | uses: actions/cache@v4 33 | with: 34 | path: ~/.cache/pip 35 | key: ${{ runner.os }}-pip-format-${{ hashFiles('**/pyproject.toml') }} 36 | restore-keys: | 37 | ${{ runner.os }}-pip-format- 38 | 39 | - name: Install formatting tools 40 | run: | 41 | python -m pip install --upgrade pip 42 | pip install yapf black isort mypy pylint flake8 43 | 44 | - name: Run format check 45 | run: | 46 | bash format.sh --all 47 | 48 | - name: Check for formatting changes 49 | run: | 50 | if ! git diff --quiet; then 51 | echo "Code formatting issues detected. Please run 'bash format.sh --all' locally." 52 | git diff 53 | exit 1 54 | fi 55 | 56 | linter: 57 | name: Lint 58 | runs-on: ubuntu-latest 59 | needs: format-check 60 | strategy: 61 | fail-fast: false 62 | matrix: 63 | python-version: ['3.10', '3.11', '3.12'] 64 | steps: 65 | - uses: actions/checkout@v4 66 | 67 | - name: Set up Python ${{ matrix.python-version }} 68 | uses: actions/setup-python@v5 69 | with: 70 | python-version: ${{ matrix.python-version }} 71 | 72 | - name: Cache pip packages 73 | uses: actions/cache@v4 74 | with: 75 | path: ~/.cache/pip 76 | key: ${{ runner.os }}-pip-lint-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }} 77 | restore-keys: | 78 | ${{ runner.os }}-pip-lint-${{ matrix.python-version }}- 79 | 80 | - name: Install project 81 | run: | 82 | python -m pip install --upgrade pip 83 | # Install test dependencies 84 | pip install pytest pytest-cov flake8 black mypy isort yapf pylint 85 | # Install project in editable mode 86 | pip install -e . 
87 | 88 | - name: Run linter 89 | run: make lint 90 | 91 | tests: 92 | name: Tests 93 | runs-on: ${{ matrix.os }} 94 | needs: linter 95 | strategy: 96 | fail-fast: false 97 | matrix: 98 | os: [ubuntu-latest, macos-latest, windows-latest] 99 | python-version: ['3.10', '3.11', '3.12'] 100 | exclude: 101 | # Reduce CI load by testing fewer combinations on non-Ubuntu 102 | - os: macos-latest 103 | python-version: '3.11' 104 | - os: windows-latest 105 | python-version: '3.11' 106 | steps: 107 | - uses: actions/checkout@v4 108 | 109 | - name: Set up Python ${{ matrix.python-version }} 110 | uses: actions/setup-python@v5 111 | with: 112 | python-version: ${{ matrix.python-version }} 113 | 114 | - name: Cache pip packages 115 | uses: actions/cache@v4 116 | with: 117 | path: ~/.cache/pip 118 | key: ${{ runner.os }}-pip-test-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }} 119 | restore-keys: | 120 | ${{ runner.os }}-pip-test-${{ matrix.python-version }}- 121 | 122 | - name: Install system dependencies (Ubuntu) 123 | if: matrix.os == 'ubuntu-latest' 124 | run: | 125 | sudo apt-get update 126 | sudo apt-get install -y ffmpeg 127 | 128 | - name: Install system dependencies (macOS) 129 | if: matrix.os == 'macos-latest' 130 | run: | 131 | brew install ffmpeg 132 | 133 | - name: Install system dependencies (Windows) 134 | if: matrix.os == 'windows-latest' 135 | shell: powershell 136 | run: | 137 | # Install ffmpeg via chocolatey 138 | choco install ffmpeg -y 139 | 140 | - name: Install project with test dependencies 141 | run: | 142 | python -m pip install --upgrade pip 143 | # Install test dependencies 144 | pip install pytest pytest-cov pytest-benchmark coverage 145 | # Install project with optional dependencies for comprehensive testing 146 | pip install -e .[all] 147 | 148 | - name: Run fast tests 149 | run: | 150 | pytest tests/ -v -m "not slow and not benchmark" --cov=robodm --cov-report=xml --cov-report=term-missing 151 | 152 | - name: Run slow tests (Ubuntu only) 153 | if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.10' 154 | run: | 155 | pytest tests/ -v -m "slow" --cov=robodm --cov-append --cov-report=xml 156 | 157 | - name: Upload coverage to Codecov 158 | if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.10' 159 | uses: codecov/codecov-action@v4 160 | with: 161 | file: ./coverage.xml 162 | fail_ci_if_error: false 163 | verbose: true 164 | 165 | benchmark: 166 | name: Benchmark Tests 167 | runs-on: ubuntu-latest 168 | needs: tests 169 | if: github.event_name == 'push' && github.ref == 'refs/heads/main' 170 | steps: 171 | - uses: actions/checkout@v4 172 | 173 | - name: Set up Python 174 | uses: actions/setup-python@v5 175 | with: 176 | python-version: '3.10' 177 | 178 | - name: Install system dependencies 179 | run: | 180 | sudo apt-get update 181 | sudo apt-get install -y ffmpeg 182 | 183 | - name: Install project with all dependencies 184 | run: | 185 | python -m pip install --upgrade pip 186 | pip install pytest pytest-benchmark 187 | pip install -e .[all] 188 | 189 | - name: Run benchmark tests 190 | run: | 191 | pytest tests/ -v -m "benchmark" --benchmark-only --benchmark-json=benchmark.json 192 | 193 | - name: Store benchmark result 194 | uses: benchmark-action/github-action-benchmark@v1 195 | if: always() 196 | with: 197 | tool: 'pytest' 198 | output-file-path: benchmark.json 199 | github-token: ${{ secrets.GITHUB_TOKEN }} 200 | auto-push: true 201 | comment-on-alert: true 202 | alert-threshold: '200%' 203 | fail-on-alert: false 204 | 
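# NOTE: a minimal sketch for reproducing the fast-test job locally, assuming ffmpeg and the
# optional extras install cleanly on your machine (the commands mirror the steps above):
#
#   pip install pytest pytest-cov pytest-benchmark coverage
#   pip install -e .[all]
#   pytest tests/ -v -m "not slow and not benchmark" --cov=robodm --cov-report=term-missing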
-------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' # Push events to matching any tag 7 | workflow_dispatch: 8 | 9 | jobs: 10 | release: 11 | name: Create Release 12 | runs-on: ubuntu-latest 13 | permissions: 14 | contents: write 15 | steps: 16 | - uses: actions/checkout@v4 17 | with: 18 | fetch-depth: 0 # Fetch all history for changelog generation 19 | 20 | - name: Generate Changelog 21 | run: .github/release_message.sh > release_message.md 22 | 23 | - name: Release 24 | uses: softprops/action-gh-release@v2 25 | with: 26 | body_path: release_message.md 27 | 28 | test-before-deploy: 29 | name: Test Before Deploy 30 | runs-on: ubuntu-latest 31 | steps: 32 | - uses: actions/checkout@v4 33 | 34 | - name: Set up Python 35 | uses: actions/setup-python@v5 36 | with: 37 | python-version: '3.10' 38 | 39 | - name: Install system dependencies 40 | run: | 41 | sudo apt-get update 42 | sudo apt-get install -y ffmpeg 43 | 44 | - name: Install and test 45 | run: | 46 | python -m pip install --upgrade pip 47 | pip install pytest 48 | pip install -e .[all] 49 | pytest tests/ -m "not slow and not benchmark" -x 50 | 51 | deploy: 52 | name: Deploy to PyPI 53 | needs: [release, test-before-deploy] 54 | runs-on: ubuntu-latest 55 | environment: release 56 | permissions: 57 | id-token: write # For trusted publishing 58 | steps: 59 | - uses: actions/checkout@v4 60 | 61 | - name: Set up Python 62 | uses: actions/setup-python@v5 63 | with: 64 | python-version: '3.10' 65 | 66 | - name: Install build dependencies 67 | run: | 68 | python -m pip install --upgrade pip 69 | pip install build twine 70 | 71 | - name: Build package 72 | run: python -m build 73 | 74 | - name: Check package 75 | run: twine check dist/* 76 | 77 | - name: Publish to PyPI 78 | uses: pypa/gh-action-pypi-publish@release/v1 79 | with: 80 | skip-existing: true 81 | -------------------------------------------------------------------------------- /.github/workflows/rename_project.yml: -------------------------------------------------------------------------------- 1 | name: Rename the project from template 2 | 3 | on: [push] 4 | 5 | permissions: write-all 6 | 7 | jobs: 8 | rename-project: 9 | if: ${{ !contains (github.repository, '/python-project-template') }} 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | with: 14 | # by default, it uses a depth of 1 15 | # this fetches all history so that we can read each commit 16 | fetch-depth: 0 17 | ref: ${{ github.head_ref }} 18 | 19 | - run: echo "REPOSITORY_NAME=$(echo '${{ github.repository }}' | awk -F '/' '{print $2}' | tr '-' '_' | tr '[:upper:]' '[:lower:]')" >> $GITHUB_ENV 20 | shell: bash 21 | 22 | - run: echo "REPOSITORY_URLNAME=$(echo '${{ github.repository }}' | awk -F '/' '{print $2}')" >> $GITHUB_ENV 23 | shell: bash 24 | 25 | - run: echo "REPOSITORY_OWNER=$(echo '${{ github.repository }}' | awk -F '/' '{print $1}')" >> $GITHUB_ENV 26 | shell: bash 27 | 28 | - name: Is this still a template 29 | id: is_template 30 | run: echo "::set-output name=is_template::$(ls .github/template.yml &> /dev/null && echo true || echo false)" 31 | 32 | - name: Rename the project 33 | if: steps.is_template.outputs.is_template == 'true' 34 | run: | 35 | echo "Renaming the project with -a(author) ${{ env.REPOSITORY_OWNER }} -n(name) ${{ env.REPOSITORY_NAME }} -u(urlname) ${{ 
env.REPOSITORY_URLNAME }}" 36 | .github/rename_project.sh -a ${{ env.REPOSITORY_OWNER }} -n ${{ env.REPOSITORY_NAME }} -u ${{ env.REPOSITORY_URLNAME }} -d "Awesome ${{ env.REPOSITORY_NAME }} created by ${{ env.REPOSITORY_OWNER }}" 37 | 38 | - uses: stefanzweifel/git-auto-commit-action@v5 39 | with: 40 | commit_message: "✅ Ready to clone and code." 41 | # commit_options: '--amend --no-edit' 42 | push_options: --force 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | examples/high-frequency/ 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | # templates 134 | .github/templates/* 135 | 136 | # generated by rtx-examples 137 | temp.gif 138 | 139 | *.vla 140 | *.mkv 141 | *.csv 142 | *.pdf -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Robo-DM 2 | 3 | robodm welcomes contributions from the community. 
4 | 5 | **You need Python 3!** 6 | 7 | These instructions are for Linux-based systems (Linux, macOS, BSD, etc.). 8 | 9 | ## Development Setup 10 | 11 | To set up a development environment: 12 | 13 | 1. Fork the repository on GitHub 14 | 2. Clone your fork of this repo. `git clone git@github.com:YOUR_GIT_USERNAME/robodm.git` 15 | 3. Enter the directory `cd robodm` 16 | 4. Add upstream repo `git remote add upstream https://github.com/BerkeleyAutomation/robodm` 17 | 18 | ## Setting up your own virtual environment 19 | 20 | Run `make virtualenv` to create a virtual environment, 21 | then activate it with `source .venv/bin/activate`. 22 | 23 | ## Install the project in develop mode 24 | 25 | Run `make install` to install the project in develop mode. 26 | 27 | ## Run the tests to ensure everything is working 28 | 29 | Run `make test` to run the tests. 30 | 31 | ## Create a new branch to work on your contribution 32 | 33 | Run `git checkout -b my_contribution` 34 | 35 | ## Make your changes 36 | 37 | Edit the files using your preferred editor (we recommend Vim or VSCode). 38 | 39 | ## Format the code 40 | 41 | Run `make fmt` to format the code. 42 | 43 | ## Run the linter 44 | 45 | Run `make lint` to run the linter. 46 | 47 | ## Test your changes 48 | 49 | Run `make test` to run the tests. 50 | 51 | Ensure the code coverage report shows `100%` coverage; add tests to your PR if it does not. 52 | 53 | ## Build the docs locally 54 | 55 | Run `make docs` to build the docs. 56 | 57 | Ensure your new changes are documented. 58 | 59 | ## Commit your changes 60 | 61 | This project uses [conventional git commit messages](https://www.conventionalcommits.org/en/v1.0.0/). 62 | 63 | Example: `fix(package): update setup.py arguments 🎉` (emojis are fine too) 64 | 65 | ## Push your changes to your fork 66 | 67 | Run `git push origin my_contribution` 68 | 69 | ## Submit a pull request 70 | 71 | On the GitHub interface, click the `Pull Request` button. 72 | 73 | Wait for CI to run, and one of the developers will review your PR. 74 | 75 | ## Makefile utilities 76 | 77 | This project comes with a `Makefile` that contains a number of useful utilities. 78 | 79 | ```bash 80 | ❯ make 81 | Usage: make 82 | 83 | Targets: 84 | help: ## Show the help. 85 | install: ## Install the project in dev mode. 86 | fmt: ## Format code using black & isort. 87 | lint: ## Run pep8, black, mypy linters. 88 | test: lint ## Run tests and generate coverage report. 89 | watch: ## Run tests on every change. 90 | clean: ## Clean unused files. 91 | virtualenv: ## Create a virtual environment. 92 | release: ## Create a new tag for release. 93 | docs: ## Build the documentation. 94 | switch-to-poetry: ## Switch to poetry package manager. 95 | init: ## Initialize the project based on an application template. 96 | ``` 97 | 98 | ## Making a new release 99 | 100 | This project uses [semantic versioning](https://semver.org/) and tags releases with `X.Y.Z`. 101 | Every time a new tag is created and pushed to the remote repo, GitHub Actions will 102 | automatically create a new release on GitHub and trigger a release on PyPI. 103 | 104 | For this to work you need to set up a secret called `PIPY_API_TOKEN` under the project's Settings > Secrets; 105 | this token can be generated on [pypi.org](https://pypi.org/account/). 106 | 107 | To trigger a new release, all you need to do is: 108 | 109 | 1. If you have changes to add to the repo 110 | * Make your changes following the steps described above. 
111 | * Commit your changes following the [conventional git commit messages](https://www.conventionalcommits.org/en/v1.0.0/). 112 | 2. Run the tests to ensure everything is working. 113 | 4. Run `make release` to create a new tag and push it to the remote repo. 114 | 115 | the `make release` will ask you the version number to create the tag, ex: type `0.1.1` when you are asked. 116 | 117 | > **CAUTION**: The make release will change local changelog files and commit all the unstaged changes you have. 118 | -------------------------------------------------------------------------------- /Containerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:latest 2 | 3 | RUN apt-get update && \ 4 | apt-get install -y python3.9 \ 5 | python3-pip \ 6 | libgmp-dev \ 7 | ffmpeg 8 | 9 | RUN pip3 install pandas \ 10 | polars \ 11 | numpy \ 12 | tensorflow \ 13 | torch \ 14 | tensorflow_datasets \ 15 | envlogger \ 16 | datasets \ 17 | pyarrow 18 | 19 | COPY . /app 20 | WORKDIR /app 21 | RUN pip install .[full] 22 | RUN pip3 install jupyter 23 | 24 | COPY . / 25 | 26 | CMD ["robodm"] 27 | -------------------------------------------------------------------------------- /HISTORY.md: -------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | 5 | 0.1.2 (2021-08-14) 6 | ------------------ 7 | - Fix release, README and windows CI. [Bruno Rocha] 8 | - Release: version 0.1.0. [Bruno Rocha] 9 | 10 | 11 | 0.1.0 (2021-08-14) 12 | ------------------ 13 | - Add release command. [Bruno Rocha] 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | 204 | 205 | 206 | ---------------------------- 207 | 208 | 209 | 210 | Software License Agreement (BSD License) 211 | 212 | Redistribution and use in source and binary forms, with or without 213 | modification, are permitted provided that the following conditions 214 | are met: 215 | 216 | * Redistributions of source code must retain the above copyright 217 | notice, this list of conditions and the following disclaimer. 
218 | * Redistributions in binary form must reproduce the above 219 | copyright notice, this list of conditions and the following 220 | disclaimer in the documentation and/or other materials provided 221 | with the distribution. 222 | * Neither the name of Willow Garage, Inc. nor the names of its 223 | contributors may be used to endorse or promote products derived 224 | from this software without specific prior written permission. 225 | 226 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 227 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 228 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 229 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 230 | COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 231 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 232 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 233 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 234 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 235 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 236 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 237 | POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include HISTORY.md 4 | graft robodm 5 | global-exclude *.pyc 6 | global-exclude __pycache__ 7 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .ONESHELL: 2 | ENV_PREFIX=$(shell python -c "if __import__('pathlib').Path('.venv/bin/pip').exists(): print('.venv/bin/')") 3 | USING_POETRY=$(shell grep "tool.poetry" pyproject.toml && echo "yes") 4 | 5 | .PHONY: help 6 | help: ## Show the help. 7 | @echo "Usage: make " 8 | @echo "" 9 | @echo "Targets:" 10 | @fgrep "##" Makefile | fgrep -v fgrep 11 | 12 | 13 | .PHONY: show 14 | show: ## Show the current environment. 15 | @echo "Current environment:" 16 | @if [ "$(USING_POETRY)" ]; then poetry env info && exit; fi 17 | @echo "Running using $(ENV_PREFIX)" 18 | @$(ENV_PREFIX)python -V 19 | @$(ENV_PREFIX)python -m site 20 | 21 | .PHONY: install 22 | install: ## Install the project in dev mode. 23 | @if [ "$(USING_POETRY)" ]; then poetry install && exit; fi 24 | @echo "Don't forget to run 'make virtualenv' if you got errors." 25 | $(ENV_PREFIX)pip install -e .[test] 26 | 27 | .PHONY: fmt 28 | fmt: ## Format code using black & isort. 29 | $(ENV_PREFIX)isort robodm/ 30 | $(ENV_PREFIX)black -l 79 robodm/ 31 | $(ENV_PREFIX)black -l 79 tests/ 32 | $(ENV_PREFIX)isort examples/ 33 | $(ENV_PREFIX)black -l 79 examples/ 34 | 35 | .PHONY: lint 36 | lint: ## Run pep8, black, mypy linters. 37 | $(ENV_PREFIX)flake8 robodm/ 38 | $(ENV_PREFIX)black -l 79 --check robodm/ 39 | $(ENV_PREFIX)black -l 79 --check tests/ 40 | $(ENV_PREFIX)mypy --ignore-missing-imports robodm/ 41 | 42 | .PHONY: test 43 | test: lint ## Run tests and generate coverage report. 44 | $(ENV_PREFIX)pytest -v --cov-config .coveragerc --cov=robodm -l --tb=short --maxfail=1 tests/ 45 | $(ENV_PREFIX)coverage xml 46 | $(ENV_PREFIX)coverage html 47 | 48 | .PHONY: watch 49 | watch: ## Run tests on every change. 
50 | ls **/**.py | entr $(ENV_PREFIX)pytest -s -vvv -l --tb=long --maxfail=1 tests/ 51 | 52 | .PHONY: clean 53 | clean: ## Clean unused files. 54 | @find ./ -name '*.pyc' -exec rm -f {} \; 55 | @find ./ -name '__pycache__' -exec rm -rf {} \; 56 | @find ./ -name 'Thumbs.db' -exec rm -f {} \; 57 | @find ./ -name '*~' -exec rm -f {} \; 58 | @rm -rf .cache 59 | @rm -rf .pytest_cache 60 | @rm -rf .mypy_cache 61 | @rm -rf build 62 | @rm -rf dist 63 | @rm -rf *.egg-info 64 | @rm -rf htmlcov 65 | @rm -rf .tox/ 66 | @rm -rf docs/_build 67 | 68 | .PHONY: virtualenv 69 | virtualenv: ## Create a virtual environment. 70 | @if [ "$(USING_POETRY)" ]; then poetry install && exit; fi 71 | @echo "creating virtualenv ..." 72 | @rm -rf .venv 73 | @python3 -m venv .venv 74 | @./.venv/bin/pip install -U pip 75 | @./.venv/bin/pip install -e .[test] 76 | @echo 77 | @echo "!!! Please run 'source .venv/bin/activate' to enable the environment !!!" 78 | 79 | .PHONY: release 80 | release: ## Create a new tag for release. 81 | @echo "WARNING: This operation will create a version tag and push to GitHub" 82 | @read -p "Version? (provide the next x.y.z semver) : " TAG 83 | @echo "$${TAG}" > robodm/VERSION 84 | @$(ENV_PREFIX)gitchangelog > HISTORY.md 85 | @git add robodm/VERSION HISTORY.md 86 | @git commit -m "release: version $${TAG} 🚀" 87 | @echo "creating git tag : $${TAG}" 88 | @git tag $${TAG} 89 | @git push -u origin HEAD --tags 90 | @echo "GitHub Actions will detect the new tag and release the new version." 91 | 92 | .PHONY: docs 93 | docs: ## Build the documentation. 94 | @echo "building documentation ..." 95 | @$(ENV_PREFIX)mkdocs build 96 | URL="site/index.html"; xdg-open $$URL || sensible-browser $$URL || x-www-browser $$URL || gnome-open $$URL || open $$URL 97 | 98 | .PHONY: switch-to-poetry 99 | switch-to-poetry: ## Switch to poetry package manager. 100 | @echo "Switching to poetry ..." 101 | @if ! poetry --version > /dev/null; then echo 'poetry is required, install from https://python-poetry.org/'; exit 1; fi 102 | @rm -rf .venv 103 | @poetry init --no-interaction --name=a_flask_test --author=rochacbruno 104 | @echo "" >> pyproject.toml 105 | @echo "[tool.poetry.scripts]" >> pyproject.toml 106 | @echo "robodm = 'robodm.__main__:main'" >> pyproject.toml 107 | @cat requirements.txt | while read in; do poetry add --no-interaction "$${in}"; done 108 | @cat requirements-test.txt | while read in; do poetry add --no-interaction "$${in}" --dev; done 109 | @poetry install --no-interaction 110 | @mkdir -p .github/backup 111 | @mv requirements* .github/backup 112 | @mv setup.py .github/backup 113 | @echo "You have switched to https://python-poetry.org/ package manager." 114 | @echo "Please run 'poetry shell' or 'poetry run robodm'" 115 | 116 | .PHONY: init 117 | init: ## Initialize the project based on an application template. 
118 | @./.github/init.sh 119 | 120 | 121 | # This project has been generated from rochacbruno/python-project-template 122 | # __author__ = 'rochacbruno' 123 | # __repo__ = https://github.com/rochacbruno/python-project-template 124 | # __sponsor__ = https://github.com/sponsors/rochacbruno/ 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🦊 Robo-DM 2 | 3 | **An Efficient and Scalable Data Collection and Management Framework For Robotics Learning** 4 | 5 | [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/) 6 | [![License](https://img.shields.io/github/license/BerkeleyAutomation/robodm)](LICENSE) 7 | [![Tests](https://github.com/BerkeleyAutomation/robodm/workflows/Tests/badge.svg)](https://github.com/BerkeleyAutomation/robodm/actions) 8 | 9 | robodm is a high-performance robotics data management framework that enables efficient collection, storage, and retrieval of multimodal robotics trajectories. Built with speed 🚀 and memory efficiency 📈 in mind, robodm provides native support for various robotics data formats and cloud storage systems. 10 | 11 | ## ✨ Key Features 12 | 13 | - **🚀 High Performance**: Optimized for speed with active metadata and lazily-loaded trajectory data 14 | - **📈 Memory Efficient**: Smart data loading and compression strategies minimize memory usage 15 | - **🎥 Advanced Video Compression**: Support for multiple codecs (H.264, H.265, AV1, FFV1) with automatic codec selection 16 | - **🔄 Format Compatibility**: Native support for Open-X-Embodiment, HuggingFace datasets, RLDS, and HDF5 17 | - **🎯 Flexible Data Types**: Handle images, videos, sensor data, and custom features seamlessly 18 | - **🏗️ Distributed Ready**: Flexible dataset partitioning for distributed training workflows 19 | 20 | ## 🛠️ Installation 21 | 22 | ### Basic Installation 23 | 24 | ```bash 25 | git clone https://github.com/BerkeleyAutomation/robodm.git 26 | cd robodm 27 | pip install -e . 
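# Optional: video codec support relies on ffmpeg being available on the system
# (the CI workflows and the Containerfile install it); on Ubuntu, for example:
#   sudo apt-get install -y ffmpeg
# Quick sanity check that the editable install imports:
python -c "import robodm"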
28 | ``` 29 | 30 | ### Installation with Optional Dependencies 31 | 32 | ```bash 33 | # For HuggingFace integration 34 | pip install -e .[hf] 35 | 36 | # For Open-X-Embodiment support 37 | pip install -e .[rtx] 38 | 39 | # For AWS cloud storage 40 | pip install -e .[aws] 41 | 42 | # For PyTorch integration 43 | pip install -e .[torch] 44 | 45 | # Install all optional dependencies 46 | pip install -e .[all] 47 | ``` 48 | 49 | ## 🚀 Quick Start 50 | 51 | ### Basic Data Collection and Loading 52 | 53 | ```python 54 | import numpy as np 55 | import robodm 56 | 57 | # Create a new trajectory for data collection 58 | trajectory = robodm.Trajectory(path="/tmp/robot_demo.vla", mode="w") 59 | 60 | # Collect multimodal robotics data 61 | for step in range(100): 62 | # Add camera observations 63 | trajectory.add("camera/rgb", np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)) 64 | trajectory.add("camera/depth", np.random.rand(480, 640).astype(np.float32)) 65 | 66 | # Add robot state 67 | trajectory.add("robot/joint_positions", np.random.rand(7).astype(np.float32)) 68 | trajectory.add("robot/joint_velocities", np.random.rand(7).astype(np.float32)) 69 | trajectory.add("robot/end_effector_pose", np.random.rand(4, 4).astype(np.float32)) 70 | 71 | # Add action data 72 | trajectory.add("action/gripper_action", np.random.rand(1).astype(np.float32)) 73 | 74 | # Save and close the trajectory 75 | trajectory.close() 76 | 77 | # Load the trajectory for training 78 | trajectory = robodm.Trajectory(path="/tmp/robot_demo.vla", mode="r") 79 | data = trajectory.load() 80 | 81 | print(f"Loaded trajectory with {len(data['camera/rgb'])} timesteps") 82 | print(f"Camera RGB shape: {data['camera/rgb'][0].shape}") 83 | print(f"Joint positions shape: {data['robot/joint_positions'][0].shape}") 84 | ``` 85 | 86 | ### Batch Data Creation 87 | 88 | ```python 89 | import robodm 90 | 91 | # Create trajectory from dictionary of lists 92 | data = { 93 | "observation/image": [np.random.randint(0, 255, (224, 224, 3)) for _ in range(50)], 94 | "observation/state": [np.random.rand(10) for _ in range(50)], 95 | "action": [np.random.rand(7) for _ in range(50)], 96 | } 97 | 98 | trajectory = robodm.Trajectory.from_dict_of_lists( 99 | data=data, 100 | path="/tmp/batch_trajectory.vla", 101 | video_codec="libaom-av1" # Use AV1 codec for efficient compression 102 | ) 103 | ``` 104 | 105 | ### Advanced Configuration 106 | 107 | ```python 108 | import robodm 109 | 110 | # Configure video compression settings 111 | trajectory = robodm.Trajectory( 112 | path="/tmp/compressed_demo.vla", 113 | mode="w", 114 | video_codec="libx265", # Use H.265 codec 115 | codec_options={ 116 | "crf": "23", # Quality setting (lower = higher quality) 117 | "preset": "fast" # Encoding speed 118 | } 119 | ) 120 | 121 | # Use hierarchical feature names 122 | trajectory.add("sensors/lidar/points", lidar_data) 123 | trajectory.add("sensors/camera/front/rgb", front_camera) 124 | trajectory.add("sensors/camera/wrist/rgb", wrist_camera) 125 | trajectory.add("control/arm/joint_positions", joint_positions) 126 | ``` 127 | 128 | ## 🎥 Video Codec Support 129 | 130 | robodm supports multiple video codecs for efficient storage of visual data: 131 | 132 | | Codec | Use Case | Compression | Quality | 133 | |-------|----------|-------------|---------| 134 | | `rawvideo` | Lossless, fast I/O | None | Perfect | 135 | | `ffv1` | Lossless compression | High | Perfect | 136 | | `libx264` | General purpose | Very High | Excellent | 137 | | `libx265` | Better compression | Very High | 
Excellent | 138 | | `libaom-av1` | Best compression | Highest | Excellent | 139 | | `auto` | Automatic selection | Optimal | Optimal | 140 | 141 | ```python 142 | # Automatic codec selection based on data characteristics 143 | trajectory = robodm.Trajectory(path="auto.vla", mode="w", video_codec="auto") 144 | 145 | # Manual codec selection for specific needs 146 | trajectory = robodm.Trajectory(path="lossless.vla", mode="w", video_codec="ffv1") 147 | ``` 148 | 149 | ## 🧪 Development & Testing 150 | 151 | ### Running Tests 152 | 153 | ```bash 154 | # Install development dependencies 155 | pip install -e .[test] 156 | 157 | # Run all tests 158 | make test 159 | 160 | # Run specific test categories 161 | pytest tests/test_trajectory.py -v 162 | pytest tests/test_loaders.py -v 163 | ``` 164 | 165 | 166 | ## 📝 Examples 167 | 168 | Explore the `examples/` directory for more detailed usage patterns: 169 | 170 | - **[Basic Data Collection](./examples/data_collection_and_load.py)**: Simple data collection and loading 171 | - **[Benchmark Scripts](./tests/)**: Performance testing and optimization 172 | 173 | We are actively and heavily refactoring the code to make it more robust and maintainable. See commit `5bbb8b` for the prior ICRA submission. 174 | 175 | 176 | 177 | ## 🤝 Contributing 178 | 179 | We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on: 180 | 181 | - Setting up development environment 182 | - Running tests and benchmarks 183 | - Code style and formatting 184 | - Submitting pull requests 185 | 186 | ## 📄 License 187 | 188 | This project is licensed under the BSD 3-Clause License. See [LICENSE](LICENSE) for details. 189 | 190 | 191 | ## 📚 Citation 192 | 193 | If you use robodm in your research, please cite: 194 | 195 | ```bibtex 196 | @article{chen2025robo, 197 | title={Robo-DM: Data Management For Large Robot Datasets}, 198 | author={Chen, Kaiyuan and Fu, Letian and Huang, David and Zhang, Yanxiang and Chen, Lawrence Yunliang and Huang, Huang and Hari, Kush and Balakrishna, Ashwin and Xiao, Ted and Sanketi, Pannag R and others}, 199 | journal={arXiv preprint arXiv:2505.15558}, 200 | year={2025} 201 | } 202 | ``` 203 | -------------------------------------------------------------------------------- /examples/data_collection_and_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import time 4 | 5 | import numpy as np 6 | 7 | import robodm 8 | 9 | if __name__ == "__main__": 10 | path = os.path.join(tempfile.gettempdir(), "test_trajectory.vla") 11 | 12 | # Create a trajectory 13 | traj = robodm.Trajectory(path=path, mode="w") 14 | 15 | # Add some data 16 | for i in range(10): 17 | traj.add( 18 | "observation/image", 19 | np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8), 20 | ) 21 | traj.add("observation/state", np.random.rand(10).astype(np.float32)) 22 | traj.add("action", np.random.rand(7).astype(np.float32)) 23 | time.sleep(0.1) 24 | 25 | # Close the trajectory 26 | traj.close() 27 | 28 | print(f"Trajectory saved to {path}") 29 | 30 | # Load the trajectory 31 | traj = robodm.Trajectory(path=path, mode="r") 32 | data = traj.load() 33 | 34 | print(f"Loaded trajectory with {len(data['observation/image'])} timesteps") 35 | print(f"Image shape: {data['observation/image'][0].shape}") 36 | print(f"State shape: {data['observation/state'][0].shape}") 37 | print(f"Action shape: {data['action'][0].shape}") 38 | 39 | # Clean up 40 | os.remove(path) 41 | 
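# A possible follow-up sketch (commented out, not run above): batch creation from a
# dict of lists, mirroring the README's "Batch Data Creation" section; the feature
# names and codec choice here are illustrative only.
#
#     batch_path = os.path.join(tempfile.gettempdir(), "batch_trajectory.vla")
#     batch_data = {
#         "observation/state": [np.random.rand(10).astype(np.float32) for _ in range(50)],
#         "action": [np.random.rand(7).astype(np.float32) for _ in range(50)],
#     }
#     robodm.Trajectory.from_dict_of_lists(data=batch_data, path=batch_path, video_codec="ffv1")
#     os.remove(batch_path)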
-------------------------------------------------------------------------------- /format.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # YAPF formatter, adapted for robodm project. 3 | # 4 | # Usage: 5 | # # Do work and commit your work. 6 | 7 | # # Format files that differ from origin/main. 8 | # bash format.sh 9 | 10 | # # Check formatting without making changes (for CI) 11 | # bash format.sh --check 12 | 13 | # # Format all files 14 | # bash format.sh --all 15 | 16 | # # Commit changed files with message 'Run yapf and pylint' 17 | # 18 | # 19 | # YAPF + Black formatter. This script formats all changed files from the last mergebase. 20 | # You are encouraged to run this locally before pushing changes for review. 21 | 22 | # Cause the script to exit if a single command fails 23 | set -eo pipefail 24 | 25 | # this stops git rev-parse from failing if we run this from the .git directory 26 | builtin cd "$(dirname "${BASH_SOURCE:-$0}")" 27 | ROOT="$(git rev-parse --show-toplevel)" 28 | builtin cd "$ROOT" || exit 1 29 | 30 | # Parse command line arguments 31 | CHECK_ONLY=false 32 | RUN_ALL=false 33 | 34 | while [[ $# -gt 0 ]]; do 35 | case $1 in 36 | --check) 37 | CHECK_ONLY=true 38 | shift 39 | ;; 40 | --all) 41 | RUN_ALL=true 42 | shift 43 | ;; 44 | --files) 45 | # Keep existing behavior for --files 46 | break 47 | ;; 48 | *) 49 | echo "Unknown option: $1" 50 | echo "Usage: $0 [--check] [--all] [--files file1 file2 ...]" 51 | exit 1 52 | ;; 53 | esac 54 | done 55 | 56 | # Check if tools are installed before getting versions 57 | check_tool_installed() { 58 | if ! command -v "$1" &> /dev/null; then 59 | echo "Error: $1 is not installed. Please install development dependencies." 60 | echo "You can install them with: pip install yapf black isort mypy pylint flake8" 61 | exit 1 62 | fi 63 | } 64 | 65 | check_tool_installed "yapf" 66 | check_tool_installed "black" 67 | check_tool_installed "isort" 68 | check_tool_installed "mypy" 69 | check_tool_installed "pylint" 70 | 71 | YAPF_VERSION=$(yapf --version | awk '{print $2}') 72 | BLACK_VERSION=$(black --version | head -n 1 | awk '{print $2}') 73 | ISORT_VERSION=$(isort --version | head -n 1 | awk '{print $2}') 74 | MYPY_VERSION=$(mypy --version | awk '{print $2}') 75 | PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}') 76 | 77 | echo "Using formatting tools:" 78 | echo " yapf: $YAPF_VERSION" 79 | echo " black: $BLACK_VERSION" 80 | echo " isort: $ISORT_VERSION" 81 | echo " mypy: $MYPY_VERSION" 82 | echo " pylint: $PYLINT_VERSION" 83 | echo 84 | 85 | YAPF_FLAGS=( 86 | '--recursive' 87 | '--parallel' 88 | ) 89 | 90 | # Add --diff flag for check mode 91 | if [ "$CHECK_ONLY" = true ]; then 92 | YAPF_FLAGS+=('--diff') 93 | BLACK_FLAGS=('--check' '--diff') 94 | ISORT_FLAGS=('--check-only' '--diff') 95 | else 96 | YAPF_FLAGS+=('--in-place') 97 | BLACK_FLAGS=() 98 | ISORT_FLAGS=() 99 | fi 100 | 101 | YAPF_EXCLUDES=( 102 | '--exclude' 'build/**' 103 | '--exclude' '.pytest_cache/**' 104 | '--exclude' 'robodm.egg-info/**' 105 | '--exclude' '__pycache__/**' 106 | ) 107 | 108 | ISORT_EXCLUDES=( 109 | '--sg' 'build/**' 110 | '--sg' '.pytest_cache/**' 111 | '--sg' 'robodm.egg-info/**' 112 | '--sg' '__pycache__/**' 113 | ) 114 | 115 | PYLINT_FLAGS=( 116 | '--disable=C0103,C0114,C0115,C0116' # Disable some overly strict checks 117 | ) 118 | 119 | # Track if any formatting issues were found 120 | FORMAT_ISSUES=false 121 | 122 | # Format specified files 123 | format() { 124 | if [ 
"$CHECK_ONLY" = true ]; then 125 | if ! yapf "${YAPF_FLAGS[@]}" "$@" | grep -q .; then 126 | return 0 127 | else 128 | echo "YAPF formatting issues found" 129 | FORMAT_ISSUES=true 130 | return 1 131 | fi 132 | else 133 | yapf "${YAPF_FLAGS[@]}" "$@" 134 | fi 135 | } 136 | 137 | # Format files that differ from main branch. Ignores dirs that are not slated 138 | # for autoformat yet. 139 | format_changed() { 140 | # The `if` guard ensures that the list of filenames is not empty, which 141 | # could cause yapf to receive 0 positional arguments, making it hang 142 | # waiting for STDIN. 143 | # 144 | # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that 145 | # exist on both branches. 146 | MERGEBASE="$(git merge-base origin/main HEAD 2>/dev/null || git merge-base origin/master HEAD 2>/dev/null || echo HEAD~1)" 147 | 148 | if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then 149 | local files 150 | files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi') 151 | if [ -n "$files" ]; then 152 | echo "$files" | tr '\n' '\0' | xargs -P 5 -0 \ 153 | yapf "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" 154 | fi 155 | fi 156 | } 157 | 158 | # Format all files 159 | format_all() { 160 | if [ "$CHECK_ONLY" = true ]; then 161 | echo "Checking YAPF formatting..." 162 | if ! yapf "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" robodm tests examples | grep -q .; then 163 | echo "✓ YAPF: No formatting issues" 164 | else 165 | echo "✗ YAPF: Formatting issues found" 166 | FORMAT_ISSUES=true 167 | fi 168 | else 169 | yapf "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" robodm tests examples 170 | fi 171 | } 172 | 173 | echo 'robodm Black formatting:' 174 | if [ "$CHECK_ONLY" = true ]; then 175 | echo "Checking Black formatting..." 176 | if black "${BLACK_FLAGS[@]}" robodm tests examples; then 177 | echo "✓ Black: No formatting issues" 178 | else 179 | echo "✗ Black: Formatting issues found" 180 | FORMAT_ISSUES=true 181 | fi 182 | else 183 | black "${BLACK_FLAGS[@]}" robodm tests examples 184 | fi 185 | 186 | ## This flag formats individual files. --files *must* be the first command line 187 | ## arg to use this option. 188 | if [[ "$1" == '--files' ]]; then 189 | format "${@:2}" 190 | # If `--all` is passed, then any further arguments are ignored and the 191 | # entire python directory is formatted. 192 | elif [[ "$RUN_ALL" == true ]]; then 193 | format_all 194 | else 195 | # Format only the files that changed in last commit. 196 | format_changed 197 | fi 198 | echo 'robodm yapf: Done' 199 | 200 | echo 'robodm isort:' 201 | if [ "$CHECK_ONLY" = true ]; then 202 | echo "Checking isort formatting..." 
203 | if isort "${ISORT_FLAGS[@]}" robodm tests examples "${ISORT_EXCLUDES[@]}"; then 204 | echo "✓ isort: No formatting issues" 205 | else 206 | echo "✗ isort: Formatting issues found" 207 | FORMAT_ISSUES=true 208 | fi 209 | else 210 | isort "${ISORT_FLAGS[@]}" robodm tests examples "${ISORT_EXCLUDES[@]}" 211 | fi 212 | 213 | # Run mypy 214 | echo 'robodm mypy:' 215 | # Check if there are any Python files to check 216 | if find robodm -name "*.py" | head -1 | grep -q .; then 217 | if mypy robodm --ignore-missing-imports --check-untyped-defs; then 218 | echo "✓ MyPy: No type issues" 219 | else 220 | echo "✗ MyPy: Type issues found" 221 | if [ "$CHECK_ONLY" = true ]; then 222 | FORMAT_ISSUES=true 223 | fi 224 | fi 225 | else 226 | echo "No Python files found in robodm/" 227 | fi 228 | 229 | # Run Pylint 230 | echo 'robodm Pylint:' 231 | if [[ "$1" == '--files' ]]; then 232 | # If --files is passed, filter to files within robodm/ and pass to pylint. 233 | if pylint "${PYLINT_FLAGS[@]}" "${@:2}"; then 234 | echo "✓ Pylint: No issues" 235 | else 236 | echo "✗ Pylint: Issues found" 237 | if [ "$CHECK_ONLY" = true ]; then 238 | FORMAT_ISSUES=true 239 | fi 240 | fi 241 | elif [[ "$RUN_ALL" == true ]]; then 242 | # Pylint entire robodm directory. 243 | if find robodm -name "*.py" | head -1 | grep -q .; then 244 | if pylint "${PYLINT_FLAGS[@]}" robodm; then 245 | echo "✓ Pylint: No issues" 246 | else 247 | echo "✗ Pylint: Issues found" 248 | if [ "$CHECK_ONLY" = true ]; then 249 | FORMAT_ISSUES=true 250 | fi 251 | fi 252 | else 253 | echo "No Python files found in robodm/" 254 | fi 255 | else 256 | # Pylint only files in robodm/ that have changed in last commit. 257 | MERGEBASE="$(git merge-base origin/main HEAD 2>/dev/null || git merge-base origin/master HEAD 2>/dev/null || echo HEAD~1)" 258 | changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- 'robodm/*.py' 'robodm/**/*.py') 259 | if [[ -n "$changed_files" ]]; then 260 | if echo "$changed_files" | tr '\n' '\0' | xargs -0 pylint "${PYLINT_FLAGS[@]}"; then 261 | echo "✓ Pylint: No issues" 262 | else 263 | echo "✗ Pylint: Issues found" 264 | if [ "$CHECK_ONLY" = true ]; then 265 | FORMAT_ISSUES=true 266 | fi 267 | fi 268 | else 269 | echo 'Pylint skipped: no files changed in robodm/.' 270 | fi 271 | fi 272 | 273 | # Final status check 274 | if [ "$CHECK_ONLY" = true ]; then 275 | if [ "$FORMAT_ISSUES" = true ]; then 276 | echo "" 277 | echo "❌ Code formatting/quality issues detected!" 278 | echo "Please run 'bash format.sh --all' to fix formatting issues." 279 | exit 1 280 | else 281 | echo "" 282 | echo "✅ All code formatting and quality checks passed!" 283 | exit 0 284 | fi 285 | fi 286 | 287 | if ! git diff --quiet &>/dev/null; then 288 | echo 'Reformatted files. Please review and stage the changes.' 289 | echo 'Changes not staged for commit:' 290 | echo 291 | git --no-pager diff --name-only 292 | exit 1 293 | fi 294 | 295 | echo 'robodm formatting complete!' 
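# Note: when run without --check, the script intentionally exits with status 1
# if any files were reformatted (see the final git diff check above), so CI jobs
# and pre-commit hooks can detect that changes still need to be reviewed and staged.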
-------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: robodm 2 | site_description: An Efficient and Scalable Data Collection and Management Framework For Robotics Learning 3 | site_url: https://github.com/BerkeleyAutomation/robodm/ 4 | 5 | nav: 6 | - Home: index.md 7 | - API Reference: api.md 8 | 9 | theme: 10 | name: material 11 | palette: 12 | primary: blue 13 | accent: orange 14 | 15 | plugins: 16 | - search 17 | - mkdocstrings: 18 | default_handler: python 19 | handlers: 20 | python: 21 | options: 22 | heading_level: 3 23 | show_source: false -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "robodm" 7 | version = "0.1.0" 8 | description = "An Efficient and Scalable Data Collection and Management Framework For Robotics Learning" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | license = {text = "BSD-3-Clause"} 12 | authors = [ 13 | {name = "Berkeley Automation Lab", email = "automation@berkeley.edu"}, 14 | ] 15 | keywords = ["robotics", "data management", "machine learning", "trajectories"] 16 | classifiers = [ 17 | "Development Status :: 4 - Beta", 18 | "Intended Audience :: Developers", 19 | "Intended Audience :: Science/Research", 20 | "License :: OSI Approved :: BSD License", 21 | "Operating System :: OS Independent", 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 3.10", 24 | "Programming Language :: Python :: 3.11", 25 | "Programming Language :: Python :: 3.12", 26 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 27 | "Topic :: Software Development :: Libraries :: Python Modules", 28 | ] 29 | dependencies = [ 30 | "numpy>=1.21.0", 31 | "h5py>=3.7.0", 32 | "opencv-python>=4.5.0", 33 | "tqdm>=4.64.0", 34 | "psutil>=5.9.0", 35 | "ray[data]>=2.8.0", 36 | ] 37 | 38 | [project.optional-dependencies] 39 | hf = ["datasets>=2.14.0", "huggingface-hub>=0.16.0"] 40 | rtx = ["tensorflow>=2.13.0", "tensorflow-datasets>=4.9.0"] 41 | aws = ["boto3>=1.26.0", "s3fs>=2023.6.0"] 42 | torch = ["torch>=1.13.0", "torchvision>=0.14.0"] 43 | test = [ 44 | "pytest>=7.0.0", 45 | "pytest-cov>=4.0.0", 46 | "pytest-xdist>=3.0.0", 47 | "pytest-benchmark>=4.0.0", 48 | ] 49 | lerobot = ["lerobot>=0.1.0"] 50 | all = ["robodm[hf,rtx,aws,torch,lerobot]"] 51 | 52 | [project.urls] 53 | homepage = "https://github.com/BerkeleyAutomation/robodm/" 54 | repository = "https://github.com/BerkeleyAutomation/robodm/" 55 | documentation = "https://github.com/BerkeleyAutomation/robodm/" 56 | "Bug Tracker" = "https://github.com/BerkeleyAutomation/robodm/issues" 57 | 58 | [tool.setuptools.packages.find] 59 | include = ["robodm*"] 60 | 61 | [tool.setuptools.package-data] 62 | robodm = ["py.typed"] -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | testpaths = tests 3 | python_files = test_*.py 4 | python_functions = test_* 5 | python_classes = Test* 6 | addopts = 7 | -v 8 | --tb=short 9 | --strict-markers 10 | --disable-warnings 11 | markers = 12 | slow: marks tests as slow (deselect with '-m "not slow"') 13 | integration: marks 
tests as integration tests 14 | benchmark: marks tests as benchmark tests 15 | filterwarnings = 16 | ignore::DeprecationWarning 17 | ignore::PendingDeprecationWarning -------------------------------------------------------------------------------- /robodm/__init__.py: -------------------------------------------------------------------------------- 1 | # robodm: A high-performance robotics data management framework 2 | # Copyright (c) 2024 Berkeley Automation Lab 3 | 4 | import os 5 | 6 | __root_dir__ = os.path.dirname(os.path.abspath(__file__)) 7 | 8 | # from robodm import dataset, episode, feature 9 | # from robodm.dataset import Dataset 10 | # from robodm import trajectory 11 | 12 | from robodm.feature import FeatureType 13 | from robodm.trajectory import Trajectory 14 | from robodm.trajectory_base import (FileSystemInterface, TimeProvider, 15 | TrajectoryInterface) 16 | from robodm.trajectory_factory import TrajectoryFactory, create_trajectory 17 | 18 | __all__ = [ 19 | "FeatureType", 20 | "Trajectory", 21 | "TrajectoryInterface", 22 | "FileSystemInterface", 23 | "TimeProvider", 24 | "TrajectoryFactory", 25 | "create_trajectory", 26 | ] 27 | 28 | # Version of the robodm package 29 | __version__ = "0.1.0" 30 | 31 | # Metadata 32 | __author__ = "Berkeley Automation Lab" 33 | __email__ = "automation@berkeley.edu" 34 | __description__ = "A high-performance robotics data management framework" 35 | __url__ = "https://github.com/BerkeleyAutomation/robodm" 36 | __license__ = "BSD-3-Clause" 37 | 38 | import logging 39 | 40 | _FORMAT = "%(levelname).1s %(asctime)s %(filename)s:%(lineno)d] %(message)s" 41 | logging.basicConfig(format=_FORMAT) 42 | logging.root.setLevel(logging.INFO) 43 | -------------------------------------------------------------------------------- /robodm/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, List, Optional, Text, Union 4 | 5 | import numpy as np 6 | 7 | try: 8 | import ray 9 | import ray.data as rd 10 | 11 | RAY_AVAILABLE = True 12 | except ImportError: 13 | RAY_AVAILABLE = False 14 | 15 | from robodm.loader.vla import (LoadingMode, RayVLALoader, SliceConfig, 16 | create_slice_loader, create_trajectory_loader) 17 | from robodm.utils import data_to_tf_schema 18 | 19 | 20 | @dataclass 21 | class DatasetConfig: 22 | """Configuration for VLADataset.""" 23 | 24 | batch_size: int = 1 25 | shuffle: bool = False 26 | num_parallel_reads: int = 4 27 | ray_init_kwargs: Optional[Dict] = None 28 | 29 | 30 | class VLADataset: 31 | """ 32 | Ray Dataset-based VLA dataset supporting both trajectory and slice loading modes. 33 | 34 | This dataset provides: 35 | 1. Parallel data loading using Ray Dataset 36 | 2. Automatic shuffling and splitting 37 | 3. Support for both trajectory and slice loading modes 38 | 4. Efficient data management for large datasets 39 | """ 40 | 41 | def __init__( 42 | self, 43 | path: Text, 44 | mode: Union[str, LoadingMode] = LoadingMode.TRAJECTORY, 45 | split: str = "all", 46 | return_type: str = "numpy", 47 | config: Optional[DatasetConfig] = None, 48 | slice_config: Optional[SliceConfig] = None, 49 | **kwargs, 50 | ): 51 | """ 52 | Initialize VLA dataset. 
53 | 54 | Args: 55 | path: Path to VLA files (can be glob pattern, directory, or single file) 56 | mode: Loading mode ("trajectory" or "slice", or LoadingMode enum) 57 | split: Data split ("all", "train", "val") 58 | return_type: Return type ("numpy", "tensor", "container") 59 | config: Dataset configuration 60 | slice_config: Slice configuration (required if mode="slice") 61 | **kwargs: Additional arguments passed to RayVLALoader 62 | """ 63 | if not RAY_AVAILABLE: 64 | raise ImportError( 65 | "Ray is required for VLADataset. Install with: pip install 'ray[data]'" 66 | ) 67 | 68 | self.path = path 69 | self.return_type = return_type 70 | self.config = config or DatasetConfig() 71 | 72 | # Handle string mode input 73 | if isinstance(mode, str): 74 | mode = LoadingMode.TRAJECTORY if mode == "trajectory" else LoadingMode.SLICE 75 | self.mode = mode 76 | 77 | # Initialize Ray if not already initialized 78 | if not ray.is_initialized(): 79 | ray.init(**(self.config.ray_init_kwargs or {})) 80 | 81 | # Create the loader 82 | self.loader = RayVLALoader( 83 | path=path, 84 | mode=mode, 85 | batch_size=self.config.batch_size, 86 | return_type=return_type, 87 | shuffle=self.config.shuffle, 88 | num_parallel_reads=self.config.num_parallel_reads, 89 | slice_config=slice_config, 90 | **kwargs, 91 | ) 92 | 93 | # Cache for schema and stats 94 | self._schema = None 95 | self._stats = None 96 | 97 | @classmethod 98 | def create_trajectory_dataset( 99 | cls, 100 | path: Text, 101 | split: str = "all", 102 | return_type: str = "numpy", 103 | config: Optional[DatasetConfig] = None, 104 | **kwargs, 105 | ) -> "VLADataset": 106 | """Create a dataset for loading complete trajectories.""" 107 | return cls( 108 | path=path, 109 | mode=LoadingMode.TRAJECTORY, 110 | return_type=return_type, 111 | config=config, 112 | **kwargs, 113 | ) 114 | 115 | @classmethod 116 | def create_slice_dataset( 117 | cls, 118 | path: Text, 119 | slice_length: int = 100, 120 | return_type: str = "numpy", 121 | config: Optional[DatasetConfig] = None, 122 | min_slice_length: Optional[int] = None, 123 | stride: int = 1, 124 | random_start: bool = True, 125 | overlap_ratio: float = 0.0, 126 | **kwargs, 127 | ) -> "VLADataset": 128 | """Create a dataset for loading trajectory slices.""" 129 | slice_config = SliceConfig( 130 | slice_length=slice_length, 131 | min_slice_length=min_slice_length, 132 | stride=stride, 133 | random_start=random_start, 134 | overlap_ratio=overlap_ratio, 135 | ) 136 | 137 | return cls( 138 | path=path, 139 | mode=LoadingMode.SLICE, 140 | return_type=return_type, 141 | config=config, 142 | slice_config=slice_config, 143 | **kwargs, 144 | ) 145 | 146 | def get_ray_dataset(self) -> rd.Dataset: 147 | """Get the underlying Ray dataset.""" 148 | return self.loader.dataset 149 | 150 | def iter_batches(self, batch_size: Optional[int] = None): 151 | """Iterate over batches of data.""" 152 | return self.loader.iter_batches(batch_size) 153 | 154 | def iter_rows(self): 155 | """Iterate over individual rows of data.""" 156 | return self.loader.iter_rows() 157 | 158 | def take(self, num_items: int) -> List[Dict[str, Any]]: 159 | """Take a specific number of items.""" 160 | return self.loader.take(num_items) 161 | 162 | def sample(self, 163 | num_samples: int, 164 | replace: bool = False) -> List[Dict[str, Any]]: 165 | """Sample from the dataset.""" 166 | return list(self.loader.sample(num_samples, replace)) 167 | 168 | def count(self) -> int: 169 | """Count the number of items in the dataset.""" 170 | return self.loader.count() 
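    # Editorial note (hedged): count(), take(), and sample() delegate to the
    # underlying Ray Dataset and may trigger execution of the read pipeline;
    # for large datasets, the streaming iterators iter_batches()/iter_rows()
    # above are usually the cheaper way to consume data.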
171 | 172 | def schema(self): 173 | """Get the schema of the dataset.""" 174 | if self._schema is None: 175 | self._schema = self.loader.schema() 176 | return self._schema 177 | 178 | def split(self, *fractions: float, shuffle: bool = True): 179 | """Split the dataset into multiple datasets.""" 180 | ray_datasets = self.loader.split(*fractions, shuffle=shuffle) 181 | 182 | # Create new VLADataset instances for each split 183 | split_datasets = [] 184 | for ray_ds in ray_datasets: 185 | split_dataset = VLADataset.__new__(VLADataset) 186 | split_dataset.path = self.path 187 | split_dataset.mode = self.mode 188 | split_dataset.return_type = self.return_type 189 | split_dataset.config = self.config 190 | split_dataset.loader = self.loader.__class__.__new__( 191 | self.loader.__class__) 192 | split_dataset.loader.dataset = ray_ds 193 | split_dataset._schema = self._schema 194 | split_dataset._stats = None 195 | split_datasets.append(split_dataset) 196 | 197 | return split_datasets 198 | 199 | def filter(self, fn): 200 | """Filter the dataset.""" 201 | filtered_dataset = VLADataset.__new__(VLADataset) 202 | filtered_dataset.path = self.path 203 | filtered_dataset.mode = self.mode 204 | filtered_dataset.return_type = self.return_type 205 | filtered_dataset.config = self.config 206 | filtered_dataset.loader = self.loader.__class__.__new__( 207 | self.loader.__class__) 208 | filtered_dataset.loader.dataset = self.loader.dataset.filter(fn) 209 | filtered_dataset._schema = self._schema 210 | filtered_dataset._stats = None 211 | return filtered_dataset 212 | 213 | def map(self, fn, **kwargs): 214 | """Map a function over the dataset.""" 215 | mapped_dataset = VLADataset.__new__(VLADataset) 216 | mapped_dataset.path = self.path 217 | mapped_dataset.mode = self.mode 218 | mapped_dataset.return_type = self.return_type 219 | mapped_dataset.config = self.config 220 | mapped_dataset.loader = self.loader.__class__.__new__( 221 | self.loader.__class__) 222 | mapped_dataset.loader.dataset = self.loader.dataset.map(fn, **kwargs) 223 | mapped_dataset._schema = None # Schema might change after mapping 224 | mapped_dataset._stats = None 225 | return mapped_dataset 226 | 227 | def shuffle(self, seed: Optional[int] = None): 228 | """Shuffle the dataset.""" 229 | shuffled_dataset = VLADataset.__new__(VLADataset) 230 | shuffled_dataset.path = self.path 231 | shuffled_dataset.mode = self.mode 232 | shuffled_dataset.return_type = self.return_type 233 | shuffled_dataset.config = self.config 234 | shuffled_dataset.loader = self.loader.__class__.__new__( 235 | self.loader.__class__) 236 | shuffled_dataset.loader.dataset = self.loader.dataset.random_shuffle( 237 | seed=seed) 238 | shuffled_dataset._schema = self._schema 239 | shuffled_dataset._stats = None 240 | return shuffled_dataset 241 | 242 | def materialize(self): 243 | """Materialize the dataset in memory.""" 244 | return self.loader.materialize() 245 | 246 | def get_stats(self) -> Dict[str, Any]: 247 | """Get dataset statistics.""" 248 | if self._stats is None: 249 | sample = self.peek() 250 | if sample: 251 | self._stats = { 252 | "mode": 253 | self.mode.value, 254 | "return_type": 255 | self.return_type, 256 | "total_items": 257 | self.count(), 258 | "sample_keys": 259 | (list(sample.keys()) if isinstance(sample, dict) else []), 260 | } 261 | 262 | # Add mode-specific stats 263 | if self.mode == LoadingMode.TRAJECTORY: 264 | # For trajectory mode, estimate length from first key 265 | first_key = next(iter(sample.keys())) if sample else None 266 | if first_key and 
hasattr(sample[first_key], "__len__"): 267 | self._stats["trajectory_length"] = len( 268 | sample[first_key]) 269 | elif self.mode == LoadingMode.SLICE: 270 | # For slice mode, estimate length from first key 271 | first_key = next(iter(sample.keys())) if sample else None 272 | if first_key and hasattr(sample[first_key], "__len__"): 273 | self._stats["slice_length"] = len(sample[first_key]) 274 | self._stats["slice_start"] = ( 275 | 0 # Cannot determine from direct data 276 | ) 277 | self._stats["slice_end"] = len(sample[first_key]) 278 | else: 279 | self._stats = {"mode": self.mode.value, "total_items": 0} 280 | 281 | return self._stats 282 | 283 | def peek(self) -> Optional[Dict[str, Any]]: 284 | """Peek at the first item without consuming it.""" 285 | return self.loader.peek() 286 | 287 | def get_tf_schema(self): 288 | """Get TensorFlow schema for the dataset.""" 289 | sample = self.peek() 290 | if sample: 291 | return data_to_tf_schema(sample) 292 | return None 293 | 294 | # Legacy compatibility methods 295 | def __iter__(self): 296 | """Iterate over the dataset (legacy compatibility).""" 297 | for item in self.loader.iter_rows(): 298 | yield item 299 | 300 | def __next__(self): 301 | """Get next item (legacy compatibility).""" 302 | batch = self.loader.get_batch() 303 | if batch: 304 | return batch[0] 305 | raise StopIteration 306 | 307 | def __len__(self) -> int: 308 | """Get the number of items in the dataset.""" 309 | return self.count() 310 | 311 | def __getitem__(self, index): 312 | """Not supported for Ray datasets - use take() or sample() instead.""" 313 | raise NotImplementedError( 314 | "Random access not supported for Ray datasets. " 315 | "Use take(), sample(), or iterate over the dataset instead.") 316 | 317 | def get_loader(self): 318 | """Get the underlying loader (legacy compatibility).""" 319 | return self.loader 320 | 321 | def get_next_trajectory(self): 322 | """Get next trajectory (legacy compatibility).""" 323 | item = next(self) 324 | return item 325 | 326 | 327 | # Utility functions for common dataset operations 328 | def load_trajectory_dataset( 329 | path: Text, 330 | split: str = "all", 331 | return_type: str = "numpy", 332 | batch_size: int = 1, 333 | shuffle: bool = False, 334 | num_parallel_reads: int = 4, 335 | **kwargs, 336 | ) -> VLADataset: 337 | """Load a dataset for complete trajectories.""" 338 | config = DatasetConfig(batch_size=batch_size, 339 | shuffle=shuffle, 340 | num_parallel_reads=num_parallel_reads) 341 | return VLADataset.create_trajectory_dataset(path=path, 342 | return_type=return_type, 343 | config=config, 344 | **kwargs) 345 | 346 | 347 | def load_slice_dataset( 348 | path: Text, 349 | slice_length: int = 100, 350 | split: str = "all", 351 | return_type: str = "numpy", 352 | batch_size: int = 1, 353 | shuffle: bool = False, 354 | num_parallel_reads: int = 4, 355 | min_slice_length: Optional[int] = None, 356 | stride: int = 1, 357 | random_start: bool = True, 358 | overlap_ratio: float = 0.0, 359 | **kwargs, 360 | ) -> VLADataset: 361 | """Load a dataset for trajectory slices.""" 362 | config = DatasetConfig(batch_size=batch_size, 363 | shuffle=shuffle, 364 | num_parallel_reads=num_parallel_reads) 365 | return VLADataset.create_slice_dataset( 366 | path=path, 367 | slice_length=slice_length, 368 | return_type=return_type, 369 | config=config, 370 | min_slice_length=min_slice_length, 371 | stride=stride, 372 | random_start=random_start, 373 | overlap_ratio=overlap_ratio, 374 | **kwargs, 375 | ) 376 | 377 | 378 | def split_dataset( 379 | 
dataset: VLADataset, 380 | train_fraction: float = 0.8, 381 | val_fraction: float = 0.2, 382 | shuffle: bool = False, 383 | ) -> tuple[VLADataset, VLADataset]: 384 | """Split a dataset into train and validation sets.""" 385 | if abs(train_fraction + val_fraction - 1.0) > 1e-6: 386 | raise ValueError("train_fraction + val_fraction must equal 1.0") 387 | 388 | splits = dataset.split(train_fraction, val_fraction, shuffle=shuffle) 389 | return splits[0], splits[1] 390 | -------------------------------------------------------------------------------- /robodm/feature.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Dict, List, Optional, Tuple, Union 3 | 4 | import numpy as np 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | SUPPORTED_DTYPES = [ 9 | "null", 10 | "bool", 11 | "int8", 12 | "int16", 13 | "int32", 14 | "int64", 15 | "uint8", 16 | "uint16", 17 | "uint32", 18 | "uint64", 19 | "float16", 20 | "float32", 21 | "float64", 22 | "timestamp(s)", 23 | "timestamp(ms)", 24 | "timestamp(us)", 25 | "timestamp(ns)", 26 | "timestamp(s, tz)", 27 | "timestamp(ms, tz)", 28 | "timestamp(us, tz)", 29 | "timestamp(ns, tz)", 30 | "binary", 31 | "large_binary", 32 | "string", 33 | "str", 34 | "large_string", 35 | ] 36 | 37 | 38 | class FeatureType: 39 | """ 40 | class for feature definition and conversions 41 | """ 42 | 43 | def __init__( 44 | self, 45 | dtype: Optional[str] = None, 46 | shape: Optional[Tuple[int, ...]] = None, 47 | tf_feature_spec=None, 48 | data=None, 49 | ) -> None: 50 | # scalar: (), vector: (n,), matrix: (n,m) 51 | self.dtype: str = "" 52 | self.shape: Optional[Tuple[int, ...]] = None 53 | 54 | if data is not None: 55 | self.from_data(data) 56 | elif tf_feature_spec is not None: 57 | self.from_tf_feature_type(tf_feature_spec) 58 | elif dtype is not None: 59 | self._set(dtype, shape) 60 | 61 | def __str__(self): 62 | return f"dtype={self.dtype}; shape={self.shape})" 63 | 64 | def __repr__(self): 65 | return self.__str__() 66 | 67 | def _set(self, dtype: str, shape: Optional[Tuple[int, ...]]): 68 | if dtype == "double": # fix inferred type 69 | dtype = "float64" 70 | if dtype == "float": # fix inferred type 71 | dtype = "float32" 72 | if dtype == "int": # fix inferred type 73 | dtype = "int32" 74 | if dtype == "object": 75 | dtype = "string" 76 | if dtype not in SUPPORTED_DTYPES: 77 | raise ValueError(f"Unsupported dtype: {dtype}") 78 | if shape is not None and not isinstance(shape, tuple): 79 | raise ValueError(f"Shape must be a tuple: {shape}") 80 | self.dtype = dtype 81 | self.shape = shape 82 | 83 | def from_tf_feature_type(self, tf_feature_spec): 84 | """ 85 | Convert from tf feature 86 | """ 87 | logger.debug(f"tf_feature_spec: {tf_feature_spec}") 88 | from tensorflow_datasets.core.features import (FeaturesDict, Image, 89 | Scalar, Tensor, Text) 90 | 91 | if isinstance(tf_feature_spec, Tensor): 92 | shape = tf_feature_spec.shape 93 | dtype = tf_feature_spec.dtype.name 94 | elif isinstance(tf_feature_spec, Image): 95 | shape = tf_feature_spec.shape 96 | dtype = tf_feature_spec.np_dtype 97 | # TODO: currently images are not handled efficiently 98 | elif isinstance(tf_feature_spec, Scalar): 99 | shape = () 100 | dtype = tf_feature_spec.dtype.name 101 | elif isinstance(tf_feature_spec, Text): 102 | shape = () 103 | dtype = "string" 104 | else: 105 | raise ValueError( 106 | f"Unsupported conversion from tf feature: {tf_feature_spec}") 107 | self._set(str(dtype), shape) 108 | return self 109 | 110 | 
@classmethod 111 | def from_data(cls, data: Any): 112 | """ 113 | Infer feature type from the provided data. 114 | """ 115 | feature_type = FeatureType() 116 | if isinstance(data, np.ndarray): 117 | feature_type._set(data.dtype.name, data.shape) 118 | elif isinstance(data, np.bool_): 119 | feature_type._set("bool", ()) 120 | elif isinstance(data, list): 121 | dtype = type(data[0]).__name__ 122 | data_shape: Tuple[int, ...] = (len(data), ) 123 | feature_type._set(dtype, data_shape) 124 | else: 125 | dtype = type(data).__name__ 126 | empty_shape: Tuple[int, ...] = () 127 | try: 128 | feature_type._set(dtype, empty_shape) 129 | except ValueError as e: 130 | print(f"Error: {e}") 131 | print(f"dtype: {dtype}") 132 | print(f"shape: {empty_shape}") 133 | print(f"data: {data}") 134 | raise e 135 | return feature_type 136 | 137 | @classmethod 138 | def from_str(cls, feature_str: str): 139 | """ 140 | Parse a string representation of the feature type. 141 | """ 142 | dtype, shape_str = feature_str.split(";") 143 | dtype = dtype.split("=")[1] 144 | shape_eval = eval(shape_str.split("=")[1][:-1]) # strip brackets 145 | # Ensure shape is a tuple 146 | if isinstance(shape_eval, tuple): 147 | shape: Optional[Tuple[int, ...]] = shape_eval 148 | else: 149 | shape = None 150 | return FeatureType(dtype=dtype, shape=shape) 151 | 152 | def to_tf_feature_type(self, first_dim_none=False): 153 | """ 154 | Convert to tf feature 155 | """ 156 | import tensorflow as tf 157 | from tensorflow_datasets.core.features import (FeaturesDict, Image, 158 | Scalar, Tensor, Text) 159 | 160 | str_dtype_to_tf_dtype = { 161 | "int8": tf.int8, 162 | "int16": tf.int16, 163 | "int32": tf.int32, 164 | "int64": tf.int64, 165 | "uint8": tf.uint8, 166 | "uint16": tf.uint16, 167 | "uint32": tf.uint32, 168 | "uint64": tf.uint64, 169 | "float16": tf.float16, 170 | "float32": tf.float32, 171 | "float64": tf.float64, 172 | "string": tf.string, 173 | "str": tf.string, 174 | "bool": tf.bool, 175 | } 176 | tf_detype = str_dtype_to_tf_dtype[self.dtype] 177 | if self.shape is not None and len(self.shape) == 0: 178 | if self.dtype == "string": 179 | return Text() 180 | else: 181 | return Scalar(dtype=tf_detype) 182 | elif self.shape is not None and len(self.shape) >= 1: 183 | if first_dim_none: 184 | tf_shape = [None] + list(self.shape[1:]) 185 | return Tensor(shape=tf_shape, dtype=tf_detype) 186 | else: 187 | return Tensor(shape=self.shape, dtype=tf_detype) 188 | else: 189 | raise ValueError(f"Unsupported conversion to tf feature: {self}") 190 | 191 | def to_pld_storage_type(self): 192 | if self.shape is not None and len(self.shape) == 0: 193 | if self.dtype == "string": 194 | return "large_binary" # TODO: better representation of strings 195 | else: 196 | return self.dtype 197 | else: 198 | return "large_binary" 199 | -------------------------------------------------------------------------------- /robodm/loader/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseLoader 2 | from .hdf5 import HDF5Loader 3 | from .rlds import RLDSLoader 4 | from .vla import NonShuffleVLALoader, VLALoader 5 | -------------------------------------------------------------------------------- /robodm/loader/base.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | 3 | 4 | class BaseLoader: 5 | def __init__(self, path): 6 | super(BaseLoader, self).__init__() 7 | self.logger = getLogger(__name__) 8 | self.path = path 9 | 10 | # def 
get_schema(self) -> Schema: 11 | # raise NotImplementedError 12 | 13 | def __len__(self): 14 | raise NotImplementedError 15 | 16 | def __iter___(self): 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /robodm/loader/hdf5.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | import random 5 | from typing import Any, Dict, List, Optional, Text 6 | 7 | import h5py 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import DataLoader, IterableDataset 11 | 12 | from robodm.utils import _flatten, recursively_read_hdf5_group 13 | 14 | from . import BaseLoader 15 | 16 | 17 | def convert_vla_data_to_hdf5( 18 | data: Dict[str, Any], 19 | output_path: Text, 20 | compression: str = "gzip", 21 | compression_opts: int = 9, 22 | ) -> None: 23 | """ 24 | Convert VLA (Vision-Language-Action) data to HDF5 format. 25 | 26 | Args: 27 | data (Dict[str, Any]): Dictionary containing VLA data with feature names as keys 28 | output_path (Text): Path where the HDF5 file will be saved 29 | compression (str): Compression algorithm to use (default: "gzip") 30 | compression_opts (int): Compression level (0-9, default: 9) 31 | """ 32 | 33 | # Ensure output directory exists 34 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 35 | 36 | try: 37 | with h5py.File(output_path, "w") as h5_file: 38 | _write_dict_to_hdf5_group(h5_file, data, compression, 39 | compression_opts) 40 | logging.info(f"Successfully converted VLA data to HDF5: {output_path}") 41 | except Exception as e: 42 | logging.error(f"Error converting VLA data to HDF5: {e}") 43 | raise 44 | 45 | 46 | def _write_dict_to_hdf5_group( 47 | group: h5py.Group, 48 | data_dict: Dict[str, Any], 49 | compression: str = "gzip", 50 | compression_opts: int = 9, 51 | ) -> None: 52 | """ 53 | Recursively write a dictionary to an HDF5 group. 
54 | 55 | Args: 56 | group (h5py.Group): HDF5 group to write to 57 | data_dict (Dict[str, Any]): Data dictionary to write 58 | compression (str): Compression algorithm 59 | compression_opts (int): Compression level 60 | """ 61 | 62 | for key, value in data_dict.items(): 63 | if isinstance(value, dict): 64 | # Create subgroup for nested dictionaries 65 | subgroup = group.create_group(key) 66 | _write_dict_to_hdf5_group(subgroup, value, compression, 67 | compression_opts) 68 | else: 69 | # Convert value to numpy array if needed 70 | if not isinstance(value, np.ndarray): 71 | if isinstance(value, (list, tuple)): 72 | value = np.array(value) 73 | else: 74 | # Single value 75 | value = np.array([value]) 76 | 77 | # Handle object arrays (strings, mixed types) 78 | if value.dtype == object: 79 | # Convert object arrays to string arrays for HDF5 compatibility 80 | string_data = [] 81 | for item in value.flat: 82 | if isinstance(item, (str, bytes)): 83 | string_data.append(str(item)) 84 | else: 85 | string_data.append(str(item)) 86 | value = np.array(string_data, dtype="S") 87 | 88 | # Create dataset with compression 89 | try: 90 | group.create_dataset( 91 | key, 92 | data=value, 93 | compression=compression, 94 | compression_opts=compression_opts, 95 | ) 96 | except Exception as e: 97 | logging.warning( 98 | f"Failed to compress {key}, saving without compression: {e}" 99 | ) 100 | group.create_dataset(key, data=value) 101 | 102 | 103 | def convert_trajectory_to_hdf5( 104 | trajectory_path: Text, 105 | output_path: Text, 106 | compression: str = "gzip", 107 | compression_opts: int = 9, 108 | ) -> None: 109 | """ 110 | Convert a trajectory container file to HDF5 format. 111 | 112 | Args: 113 | trajectory_path (Text): Path to the trajectory container file 114 | output_path (Text): Path where the HDF5 file will be saved 115 | compression (str): Compression algorithm to use (default: "gzip") 116 | compression_opts (int): Compression level (0-9, default: 9) 117 | """ 118 | 119 | # Import here to avoid circular imports 120 | from ..trajectory import Trajectory 121 | 122 | try: 123 | # Load trajectory data 124 | traj = Trajectory(trajectory_path, mode="r") 125 | data = traj.load(return_type="numpy") 126 | traj.close() 127 | 128 | # Convert to HDF5 129 | convert_vla_data_to_hdf5(data, output_path, compression, 130 | compression_opts) 131 | 132 | except Exception as e: 133 | logging.error(f"Error converting trajectory to HDF5: {e}") 134 | raise 135 | 136 | 137 | def batch_convert_trajectories_to_hdf5( 138 | trajectory_paths: List[Text], 139 | output_dir: Text, 140 | compression: str = "gzip", 141 | compression_opts: int = 9, 142 | parallel: bool = True, 143 | num_workers: Optional[int] = None, 144 | ) -> None: 145 | """ 146 | Convert multiple trajectory files to HDF5 format in batch. 
147 | 148 | Args: 149 | trajectory_paths (List[Text]): List of trajectory file paths 150 | output_dir (Text): Directory where HDF5 files will be saved 151 | compression (str): Compression algorithm to use 152 | compression_opts (int): Compression level 153 | parallel (bool): Whether to use parallel processing 154 | num_workers (Optional[int]): Number of worker processes (default: CPU count) 155 | """ 156 | 157 | os.makedirs(output_dir, exist_ok=True) 158 | 159 | for traj_path in trajectory_paths: 160 | output_filename = os.path.splitext( 161 | os.path.basename(traj_path))[0] + ".h5" 162 | output_path = os.path.join(output_dir, output_filename) 163 | 164 | convert_trajectory_to_hdf5(traj_path, output_path, compression, 165 | compression_opts) 166 | 167 | 168 | def load_and_convert_to_hdf5( 169 | input_path: Text, 170 | output_path: Text, 171 | input_format: str = "auto", 172 | compression: str = "gzip", 173 | compression_opts: int = 9, 174 | ) -> None: 175 | """ 176 | Load data from various formats and convert to HDF5. 177 | 178 | Args: 179 | input_path (Text): Path to input data file 180 | output_path (Text): Path for output HDF5 file 181 | input_format (str): Format of input data ("auto", "trajectory", "numpy", "pickle") 182 | compression (str): HDF5 compression algorithm 183 | compression_opts (int): Compression level 184 | """ 185 | 186 | if input_format == "auto": 187 | # Auto-detect format based on file extension 188 | ext = os.path.splitext(input_path)[1].lower() 189 | if ext in [".mkv", ".mp4", ".avi"]: 190 | input_format = "trajectory" 191 | elif ext in [".npy", ".npz"]: 192 | input_format = "numpy" 193 | elif ext in [".pkl", ".pickle"]: 194 | input_format = "pickle" 195 | else: 196 | raise ValueError( 197 | f"Cannot auto-detect format for file: {input_path}") 198 | 199 | if input_format == "trajectory": 200 | convert_trajectory_to_hdf5(input_path, output_path, compression, 201 | compression_opts) 202 | 203 | elif input_format == "numpy": 204 | if input_path.endswith(".npz"): 205 | data = dict(np.load(input_path)) 206 | else: 207 | data = {"data": np.load(input_path)} 208 | convert_vla_data_to_hdf5(data, output_path, compression, 209 | compression_opts) 210 | 211 | elif input_format == "pickle": 212 | import pickle 213 | 214 | with open(input_path, "rb") as f: 215 | data = pickle.load(f) 216 | if not isinstance(data, dict): 217 | data = {"data": data} 218 | convert_vla_data_to_hdf5(data, output_path, compression, 219 | compression_opts) 220 | 221 | else: 222 | raise ValueError(f"Unsupported input format: {input_format}") 223 | 224 | 225 | def main(): 226 | """ 227 | Command-line interface for VLA data to HDF5 conversion. 
228 | """ 229 | import argparse 230 | 231 | parser = argparse.ArgumentParser( 232 | description="Convert VLA data to HDF5 format") 233 | parser.add_argument("input", help="Input file path") 234 | parser.add_argument("output", help="Output HDF5 file path") 235 | parser.add_argument( 236 | "--format", 237 | choices=["auto", "trajectory", "numpy", "pickle"], 238 | default="auto", 239 | help="Input data format (default: auto)", 240 | ) 241 | parser.add_argument( 242 | "--compression", 243 | default="gzip", 244 | help="HDF5 compression algorithm (default: gzip)", 245 | ) 246 | parser.add_argument( 247 | "--compression-level", 248 | type=int, 249 | default=9, 250 | help="Compression level 0-9 (default: 9)", 251 | ) 252 | parser.add_argument( 253 | "--batch", 254 | action="store_true", 255 | help="Treat input as directory and convert all files", 256 | ) 257 | parser.add_argument( 258 | "--parallel", 259 | action="store_true", 260 | default=True, 261 | help="Use parallel processing for batch conversion", 262 | ) 263 | parser.add_argument( 264 | "--workers", 265 | type=int, 266 | default=None, 267 | help="Number of worker processes (default: CPU count)", 268 | ) 269 | 270 | args = parser.parse_args() 271 | 272 | if args.batch: 273 | if not os.path.isdir(args.input): 274 | raise ValueError("Input must be a directory when using --batch") 275 | 276 | # Find all relevant files in the directory 277 | trajectory_files = [] 278 | for ext in ["*.mkv", "*.mp4", "*.avi"]: 279 | trajectory_files.extend( 280 | glob.glob(os.path.join(args.input, "**", ext), recursive=True)) 281 | 282 | if not trajectory_files: 283 | print(f"No trajectory files found in {args.input}") 284 | return 285 | 286 | print(f"Found {len(trajectory_files)} trajectory files to convert") 287 | batch_convert_trajectories_to_hdf5( 288 | trajectory_files, 289 | args.output, 290 | compression=args.compression, 291 | compression_opts=args.compression_level, 292 | parallel=args.parallel, 293 | num_workers=args.workers, 294 | ) 295 | print(f"Batch conversion completed. 
Files saved to {args.output}") 296 | else: 297 | load_and_convert_to_hdf5( 298 | args.input, 299 | args.output, 300 | input_format=args.format, 301 | compression=args.compression, 302 | compression_opts=args.compression_level, 303 | ) 304 | print(f"Conversion completed: {args.input} -> {args.output}") 305 | 306 | 307 | if __name__ == "__main__": 308 | main() 309 | 310 | 311 | class HDF5Loader(BaseLoader): 312 | 313 | def __init__(self, path, batch_size=1): 314 | super(HDF5Loader, self).__init__(path) 315 | self.files = glob.glob(self.path, recursive=True) 316 | self.batch_size = batch_size 317 | self.index = 0 318 | random.shuffle(self.files) 319 | 320 | def get_batch(self): 321 | batch = [] 322 | 323 | for _ in range(self.batch_size): 324 | if self.index >= len(self.files): 325 | break # No more files available 326 | 327 | file_path = self.files[self.index] 328 | self.index += 1 329 | 330 | try: 331 | data = self._read_hdf5(file_path) 332 | batch.append(data) 333 | except Exception as e: 334 | logging.error(f"Error reading {file_path}: {e}") 335 | continue # Skip this file and continue 336 | 337 | return batch if batch else None 338 | 339 | def __next__(self): 340 | batch = self.get_batch() 341 | if batch is None: 342 | # Reset for next epoch 343 | self.index = 0 344 | random.shuffle(self.files) 345 | raise StopIteration 346 | return batch 347 | 348 | def _read_hdf5(self, data_path): 349 | with h5py.File(data_path, "r") as f: 350 | data_unflattened = recursively_read_hdf5_group(f) 351 | print(data_unflattened.keys()) 352 | 353 | # Flatten the entire data structure to match VLA format 354 | data_flattened = _flatten(data_unflattened) 355 | 356 | return data_flattened 357 | 358 | def __iter__(self): 359 | return self 360 | 361 | def __len__(self): 362 | return len(self.files) 363 | 364 | def peek(self): 365 | if self.index < len(self.files): 366 | file_path = self.files[self.index] 367 | return self._read_hdf5(file_path) 368 | return None 369 | 370 | def __del__(self): 371 | pass 372 | 373 | 374 | class HDF5IterableDataset(IterableDataset): 375 | 376 | def __init__(self, path): 377 | # Note: batch size = 1 is to bypass the dataloader without pytorch dataloader 378 | self.hdf5_loader = HDF5Loader(path, batch_size=1) 379 | 380 | def __iter__(self): 381 | return self 382 | 383 | def __next__(self): 384 | try: 385 | batch = next(self.hdf5_loader) 386 | return batch[0] # Return a single item, not a batch 387 | except StopIteration: 388 | raise StopIteration 389 | 390 | 391 | def hdf5_collate_fn(batch): 392 | # Convert data to PyTorch tensors 393 | return batch 394 | 395 | 396 | def get_hdf5_dataloader(path: str, batch_size: int = 1, num_workers: int = 0): 397 | dataset = HDF5IterableDataset(path) 398 | return DataLoader( 399 | dataset, 400 | batch_size=batch_size, 401 | collate_fn=hdf5_collate_fn, 402 | num_workers=num_workers, 403 | ) 404 | -------------------------------------------------------------------------------- /robodm/loader/rlds.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from . 
import BaseLoader 4 | 5 | 6 | class RLDSLoader(BaseLoader): 7 | def __init__( 8 | self, path, split="train", batch_size=1, shuffle_buffer=10, shuffling=True 9 | ): 10 | super(RLDSLoader, self).__init__(path) 11 | 12 | try: 13 | import tensorflow as tf 14 | import tensorflow_datasets as tfds 15 | except ImportError: 16 | raise ImportError( 17 | "Please install tensorflow and tensorflow_datasets to use rlds loader" 18 | ) 19 | 20 | self.batch_size = batch_size 21 | builder = tfds.builder_from_directory(path) 22 | self.ds = builder.as_dataset(split) 23 | self.length = len(self.ds) 24 | self.shuffling = shuffling 25 | if shuffling: 26 | self.ds = self.ds.repeat() 27 | self.ds = self.ds.shuffle(shuffle_buffer) 28 | self.iterator = iter(self.ds) 29 | 30 | self.split = split 31 | self.index = 0 32 | 33 | def __len__(self): 34 | try: 35 | import tensorflow as tf 36 | except ImportError: 37 | raise ImportError("Please install tensorflow to use rlds loader") 38 | 39 | return self.length 40 | 41 | def __iter__(self): 42 | return self 43 | 44 | def get_batch(self): 45 | batch = self.ds.take(self.batch_size) 46 | self.index += self.batch_size 47 | if not self.shuffling and self.index >= self.length: 48 | raise StopIteration 49 | data = [] 50 | for b in batch: 51 | data.append(self._convert_traj_to_numpy(b)) 52 | return data 53 | 54 | def _convert_traj_to_numpy(self, traj): 55 | import tensorflow as tf 56 | 57 | def to_numpy(step_data): 58 | step = {} 59 | for key in step_data: 60 | val = step_data[key] 61 | if isinstance(val, dict): 62 | step[key] = {k: np.array(v) for k, v in val.items()} 63 | else: 64 | step[key] = np.array(val) 65 | return step 66 | 67 | trajectory = [] 68 | for step in traj["steps"]: 69 | trajectory.append(to_numpy(step)) 70 | return trajectory 71 | 72 | def __next__(self): 73 | data = [self._convert_traj_to_numpy(next(self.iterator))] 74 | self.index += 1 75 | if self.index >= self.length: 76 | raise StopIteration 77 | return data 78 | 79 | def __getitem__(self, idx): 80 | batch = next(iter(self.ds.skip(idx).take(1))) 81 | return self._convert_traj_to_numpy(batch) 82 | -------------------------------------------------------------------------------- /robodm/loader/vla.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | import random 5 | from dataclasses import dataclass 6 | from enum import Enum 7 | from typing import Any, Dict, List, Optional, Text, Union 8 | 9 | import numpy as np 10 | 11 | try: 12 | import ray 13 | import ray.data as rd 14 | 15 | RAY_AVAILABLE = True 16 | except ImportError: 17 | RAY_AVAILABLE = False 18 | 19 | import robodm 20 | from robodm.loader.base import BaseLoader 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class LoadingMode(Enum): 26 | """Loading mode for the VLA loader.""" 27 | 28 | TRAJECTORY = "trajectory" # Load entire trajectories 29 | SLICE = "slice" # Load random slices from trajectories 30 | 31 | 32 | @dataclass 33 | class SliceConfig: 34 | """Configuration for slice loading mode.""" 35 | 36 | slice_length: int = 100 # Number of timesteps per slice 37 | min_slice_length: Optional[int] = ( 38 | None # Minimum slice length (defaults to slice_length) 39 | ) 40 | stride: int = 1 # Stride between consecutive timesteps in slice 41 | random_start: bool = True # Whether to randomly sample start position 42 | overlap_ratio: float = 0.0 # Overlap ratio between consecutive slices (0.0-1.0) 43 | 44 | 45 | class RayVLALoader(BaseLoader): 46 | """ 47 | Ray 
Dataset-based VLA loader supporting both trajectory and slice loading modes. 48 | 49 | This loader uses Ray Dataset for parallel data loading, automatic shuffling, 50 | and efficient data splitting. 51 | """ 52 | 53 | def __init__( 54 | self, 55 | path: Text, 56 | mode: LoadingMode = LoadingMode.TRAJECTORY, 57 | batch_size: int = 1, 58 | return_type: str = "numpy", 59 | shuffle: bool = False, 60 | num_parallel_reads: int = 4, 61 | slice_config: Optional[SliceConfig] = None, 62 | ray_init_kwargs: Optional[Dict] = None, 63 | ): 64 | """ 65 | Initialize the Ray VLA loader. 66 | 67 | Args: 68 | path: Path to VLA files (can be glob pattern, directory, or single file) 69 | mode: Loading mode (TRAJECTORY or SLICE) 70 | batch_size: Batch size for data loading 71 | return_type: Return type ("numpy", "tensor", "container") 72 | shuffle: Whether to shuffle the data 73 | num_parallel_reads: Number of parallel read operations 74 | slice_config: Configuration for slice mode (required if mode=SLICE) 75 | ray_init_kwargs: Additional kwargs for Ray initialization 76 | """ 77 | super().__init__(path) 78 | 79 | if not RAY_AVAILABLE: 80 | raise ImportError( 81 | "Ray is required for RayVLALoader. Install with: pip install 'ray[data]'" 82 | ) 83 | 84 | self.mode = mode 85 | self.batch_size = batch_size 86 | self.return_type = return_type 87 | self.shuffle = shuffle 88 | self.num_parallel_reads = num_parallel_reads 89 | self.slice_config = slice_config or SliceConfig() 90 | 91 | # Initialize Ray if not already initialized 92 | if not ray.is_initialized(): 93 | ray.init(**(ray_init_kwargs or {})) 94 | 95 | # Validate slice config for slice mode 96 | if mode == LoadingMode.SLICE and slice_config is None: 97 | self.slice_config = SliceConfig() 98 | 99 | # Get file paths and create Ray dataset 100 | self.file_paths = self._get_files(path) 101 | self.dataset = self._create_dataset() 102 | 103 | logger.info( 104 | f"Initialized RayVLALoader with {len(self.file_paths)} files in {mode.value} mode" 105 | ) 106 | 107 | def _get_files(self, path: str) -> List[str]: 108 | """Get list of VLA files based on path.""" 109 | files = [] 110 | 111 | if "*" in path: 112 | files = glob.glob(path) 113 | elif os.path.isdir(path): 114 | files = glob.glob(os.path.join(path, "*.vla")) 115 | else: 116 | files = [path] 117 | 118 | return files 119 | 120 | def _create_dataset(self) -> rd.Dataset: 121 | """Create Ray dataset based on loading mode.""" 122 | # Create initial dataset from file paths 123 | dataset = rd.from_items(self.file_paths) 124 | 125 | if self.mode == LoadingMode.TRAJECTORY: 126 | # For trajectory mode, each item is a complete trajectory 127 | dataset = dataset.map( 128 | self._load_trajectory, 129 | num_cpus=self.num_parallel_reads, 130 | concurrency=self.num_parallel_reads, 131 | ) 132 | elif self.mode == LoadingMode.SLICE: 133 | # For slice mode, expand each trajectory into multiple slices 134 | dataset = dataset.flat_map( 135 | self._extract_slices, 136 | num_cpus=self.num_parallel_reads, 137 | concurrency=self.num_parallel_reads, 138 | ) 139 | 140 | # Apply shuffling if requested 141 | if self.shuffle: 142 | dataset = dataset.random_shuffle() 143 | 144 | return dataset 145 | 146 | def _load_trajectory(self, item) -> Dict[str, Any]: 147 | """Load a complete trajectory from file.""" 148 | # Handle both string paths and dict items from Ray dataset 149 | if isinstance(item, dict): 150 | file_path = item.get("item", item) 151 | else: 152 | file_path = item 153 | 154 | try: 155 | traj = robodm.Trajectory(file_path) 156 | 
data = traj.load(return_type=self.return_type) 157 | 158 | return data 159 | 160 | except Exception as e: 161 | logger.error(f"Error loading trajectory {file_path}: {e}") 162 | return {} 163 | 164 | def _extract_slices(self, item) -> List[Dict[str, Any]]: 165 | """Extract slices from a trajectory file.""" 166 | # Handle both string paths and dict items from Ray dataset 167 | if isinstance(item, dict): 168 | file_path = item.get("item", item) 169 | else: 170 | file_path = item 171 | 172 | try: 173 | traj = robodm.Trajectory(file_path) 174 | full_data = traj.load(return_type=self.return_type) 175 | 176 | if not full_data: 177 | return [] 178 | 179 | # Get trajectory length 180 | traj_length = len(next(iter(full_data.values()))) 181 | min_length = (self.slice_config.min_slice_length 182 | or self.slice_config.slice_length) 183 | 184 | if traj_length < min_length: 185 | logger.warning( 186 | f"Trajectory {file_path} too short ({traj_length} < {min_length})" 187 | ) 188 | return [] 189 | 190 | slices = [] 191 | slice_step = max( 192 | 1, 193 | int(self.slice_config.slice_length * 194 | (1 - self.slice_config.overlap_ratio)), 195 | ) 196 | 197 | # Generate slice positions 198 | max_start = traj_length - self.slice_config.slice_length 199 | 200 | if self.slice_config.random_start: 201 | # Random sampling of slice positions 202 | num_slices = max(1, max_start // slice_step) 203 | start_positions = [ 204 | random.randint(0, max_start) for _ in range(num_slices) 205 | ] 206 | else: 207 | # Sequential slicing 208 | start_positions = list(range(0, max_start + 1, slice_step)) 209 | 210 | # Extract slices 211 | for start_idx in start_positions: 212 | end_idx = min(start_idx + self.slice_config.slice_length, 213 | traj_length) 214 | actual_length = end_idx - start_idx 215 | 216 | if actual_length < min_length: 217 | continue 218 | 219 | # Extract slice data 220 | slice_data = {} 221 | for key, values in full_data.items(): 222 | if isinstance(values, np.ndarray): 223 | slice_data[key] = values[start_idx:end_idx:self. 224 | slice_config.stride] 225 | elif isinstance(values, list): 226 | slice_data[key] = values[start_idx:end_idx:self. 
227 | slice_config.stride] 228 | else: 229 | slice_data[key] = values 230 | 231 | slices.append(slice_data) 232 | 233 | return slices 234 | 235 | except Exception as e: 236 | logger.error(f"Error extracting slices from {file_path}: {e}") 237 | return [] 238 | 239 | def get_batch(self) -> List[Dict[str, Any]]: 240 | """Get a batch of data.""" 241 | try: 242 | batch = self.dataset.take(self.batch_size) 243 | return list(batch) 244 | except Exception as e: 245 | logger.error(f"Error getting batch: {e}") 246 | return [] 247 | 248 | def iter_batches(self, batch_size: Optional[int] = None): 249 | """Iterate over batches of data.""" 250 | batch_size = batch_size or self.batch_size 251 | return self.dataset.iter_batches(batch_size=batch_size) 252 | 253 | def iter_rows(self): 254 | """Iterate over individual rows of data.""" 255 | return self.dataset.iter_rows() 256 | 257 | def take(self, num_items: int) -> List[Dict[str, Any]]: 258 | """Take a specific number of items.""" 259 | return list(self.dataset.take(num_items)) 260 | 261 | def count(self) -> int: 262 | """Count the number of items in the dataset.""" 263 | return self.dataset.count() 264 | 265 | def schema(self): 266 | """Get the schema of the dataset.""" 267 | return self.dataset.schema() 268 | 269 | def split(self, *fractions: float, shuffle: bool = True): 270 | """Split the dataset into multiple datasets.""" 271 | # Validate fractions sum to <= 1.0 272 | if sum(fractions) > 1.0: 273 | raise ValueError( 274 | f"Sum of fractions {sum(fractions)} must be <= 1.0") 275 | 276 | # Ray Dataset.split() doesn't support shuffle parameter 277 | # If shuffle is requested, shuffle the dataset first 278 | dataset_to_split = self.dataset.random_shuffle( 279 | ) if shuffle else self.dataset 280 | 281 | if len(fractions) == 1: 282 | # For single fraction, convert to train/test split 283 | return dataset_to_split.train_test_split(test_size=fractions[0], 284 | shuffle=False) 285 | elif len(fractions) == 2 and abs(sum(fractions) - 1.0) < 1e-10: 286 | # Special case: exactly two fractions that sum to 1.0 287 | # Use train_test_split which handles this case 288 | return dataset_to_split.train_test_split(test_size=fractions[1], 289 | shuffle=False) 290 | else: 291 | # For multiple fractions, use split_proportionately 292 | # Ray requires the sum to be < 1.0, so if it equals 1.0, we need to adjust 293 | fractions_list = list(fractions) 294 | total = sum(fractions_list) 295 | 296 | if abs(total - 1.0) < 1e-10: 297 | # If fractions sum to 1.0, subtract a tiny amount from the last fraction 298 | # so Ray doesn't complain, then drop the extra split 299 | fractions_list[-1] -= 1e-6 300 | splits = dataset_to_split.split_proportionately(fractions_list) 301 | # Drop the last split (which will be nearly empty) 302 | return splits[:-1] 303 | else: 304 | return dataset_to_split.split_proportionately(fractions_list) 305 | 306 | def filter(self, fn): 307 | """Filter the dataset.""" 308 | return self.dataset.filter(fn) 309 | 310 | def map(self, fn, **kwargs): 311 | """Map a function over the dataset.""" 312 | return self.dataset.map(fn, **kwargs) 313 | 314 | def sample(self, num_samples: int, replace: bool = False): 315 | """Sample from the dataset.""" 316 | # Ray's random_sample expects a fraction, not absolute count 317 | total_count = self.count() 318 | if total_count == 0: 319 | return [] 320 | 321 | # For exact count without replacement, use take with random shuffle 322 | if not replace: 323 | shuffled_dataset = self.dataset.random_shuffle() 324 | return 
list(shuffled_dataset.take(min(num_samples, total_count))) 325 | else: 326 | # For replacement sampling, use multiple passes if needed 327 | # This is a limitation of Ray's API 328 | import warnings 329 | 330 | warnings.warn( 331 | "Sampling with replacement may not return exact count due to Ray API limitations" 332 | ) 333 | 334 | fraction = min(1.0, num_samples / total_count) 335 | # Sample and take up to the requested amount 336 | sampled = self.dataset.random_sample(fraction) 337 | return list(sampled.take(num_samples)) 338 | 339 | def peek(self) -> Optional[Dict[str, Any]]: 340 | """Peek at the first item without consuming it.""" 341 | try: 342 | return self.dataset.take(1)[0] 343 | except: 344 | return None 345 | 346 | def __len__(self) -> int: 347 | """Get the number of items in the dataset.""" 348 | return self.count() 349 | 350 | def __iter__(self): 351 | """Iterate over the dataset.""" 352 | return self.iter_rows() 353 | 354 | def materialize(self): 355 | """Materialize the dataset in memory.""" 356 | return self.dataset.materialize() 357 | 358 | 359 | # Legacy compatibility loaders (deprecated) 360 | class VLALoader(RayVLALoader): 361 | """Legacy VLA loader - deprecated, use RayVLALoader instead.""" 362 | 363 | def __init__(self, path: Text, batch_size=1, return_type="numpy"): 364 | logger.warning("VLALoader is deprecated. Use RayVLALoader instead.") 365 | super().__init__( 366 | path=path, 367 | mode=LoadingMode.TRAJECTORY, 368 | batch_size=batch_size, 369 | return_type=return_type, 370 | shuffle=True, 371 | ) 372 | 373 | 374 | class NonShuffleVLALoader(RayVLALoader): 375 | """Legacy non-shuffle VLA loader - deprecated, use RayVLALoader instead.""" 376 | 377 | def __init__(self, 378 | path: Text, 379 | batch_size=1, 380 | num_workers=1, 381 | return_type="numpy"): 382 | logger.warning( 383 | "NonShuffleVLALoader is deprecated. Use RayVLALoader instead.") 384 | super().__init__( 385 | path=path, 386 | mode=LoadingMode.TRAJECTORY, 387 | batch_size=batch_size, 388 | return_type=return_type, 389 | shuffle=False, 390 | ) 391 | 392 | 393 | def get_vla_dataloader(path: Text, 394 | batch_size: int = 1, 395 | num_workers: int = 1, 396 | **kwargs): 397 | """Legacy function to get VLA dataloader - deprecated, use create_trajectory_loader instead.""" 398 | logger.warning( 399 | "get_vla_dataloader is deprecated. Use create_trajectory_loader instead." 
400 | ) 401 | loader = RayVLALoader( 402 | path=path, 403 | mode=LoadingMode.TRAJECTORY, 404 | batch_size=batch_size, 405 | return_type="numpy", 406 | shuffle=True, 407 | num_parallel_reads=max(1, num_workers), 408 | **kwargs, 409 | ) 410 | return loader 411 | 412 | 413 | # Factory functions for common use cases 414 | def create_trajectory_loader( 415 | path: Text, 416 | batch_size: int = 1, 417 | return_type: str = "numpy", 418 | shuffle: bool = False, 419 | num_parallel_reads: int = 4, 420 | **kwargs, 421 | ) -> RayVLALoader: 422 | """Create a loader for complete trajectories.""" 423 | return RayVLALoader( 424 | path=path, 425 | mode=LoadingMode.TRAJECTORY, 426 | batch_size=batch_size, 427 | return_type=return_type, 428 | shuffle=shuffle, 429 | num_parallel_reads=num_parallel_reads, 430 | **kwargs, 431 | ) 432 | 433 | 434 | def create_slice_loader( 435 | path: Text, 436 | slice_length: int = 100, 437 | batch_size: int = 1, 438 | return_type: str = "numpy", 439 | shuffle: bool = False, 440 | num_parallel_reads: int = 4, 441 | min_slice_length: Optional[int] = None, 442 | stride: int = 1, 443 | random_start: bool = True, 444 | overlap_ratio: float = 0.0, 445 | **kwargs, 446 | ) -> RayVLALoader: 447 | """Create a loader for trajectory slices.""" 448 | slice_config = SliceConfig( 449 | slice_length=slice_length, 450 | min_slice_length=min_slice_length, 451 | stride=stride, 452 | random_start=random_start, 453 | overlap_ratio=overlap_ratio, 454 | ) 455 | 456 | return RayVLALoader( 457 | path=path, 458 | mode=LoadingMode.SLICE, 459 | batch_size=batch_size, 460 | return_type=return_type, 461 | shuffle=shuffle, 462 | num_parallel_reads=num_parallel_reads, 463 | slice_config=slice_config, 464 | **kwargs, 465 | ) 466 | -------------------------------------------------------------------------------- /robodm/trajectory_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, List, Optional, Text, Union 3 | 4 | import numpy as np 5 | 6 | 7 | class TrajectoryInterface(ABC): 8 | """ 9 | Abstract base class defining the interface for trajectory objects. 10 | This allows for better testing and dependency injection. 
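Concrete implementations such as robodm.Trajectory, as well as lightweight test doubles, can satisfy this interface.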
11 | """ 12 | 13 | @abstractmethod 14 | def add(self, 15 | feature: str, 16 | data: Any, 17 | timestamp: Optional[int] = None) -> None: 18 | """Add a single feature value to the trajectory.""" 19 | pass 20 | 21 | @abstractmethod 22 | def add_by_dict(self, 23 | data: Dict[str, Any], 24 | timestamp: Optional[int] = None) -> None: 25 | """Add multiple features from a dictionary to the trajectory.""" 26 | pass 27 | 28 | @abstractmethod 29 | def load(self, 30 | save_to_cache: bool = True, 31 | return_type: str = "numpy") -> Union[Dict, Any]: 32 | """Load the trajectory data.""" 33 | pass 34 | 35 | @abstractmethod 36 | def close(self, compact: bool = True) -> None: 37 | """Close the trajectory file.""" 38 | pass 39 | 40 | @abstractmethod 41 | def __getitem__(self, key: str) -> Any: 42 | """Get a feature from the trajectory.""" 43 | pass 44 | 45 | 46 | class FileSystemInterface(ABC): 47 | """Abstract interface for file system operations to enable testing with mocks.""" 48 | 49 | @abstractmethod 50 | def exists(self, path: str) -> bool: 51 | """Check if a file exists.""" 52 | pass 53 | 54 | @abstractmethod 55 | def makedirs(self, path: str, exist_ok: bool = False) -> None: 56 | """Create directories.""" 57 | pass 58 | 59 | @abstractmethod 60 | def remove(self, path: str) -> None: 61 | """Remove a file.""" 62 | pass 63 | 64 | @abstractmethod 65 | def rename(self, src: str, dst: str) -> None: 66 | """Rename a file.""" 67 | pass 68 | 69 | 70 | class DefaultFileSystem(FileSystemInterface): 71 | """Default implementation using standard file system operations.""" 72 | 73 | def exists(self, path: str) -> bool: 74 | import os 75 | 76 | return os.path.exists(path) 77 | 78 | def makedirs(self, path: str, exist_ok: bool = False) -> None: 79 | import os 80 | 81 | os.makedirs(path, exist_ok=exist_ok) 82 | 83 | def remove(self, path: str) -> None: 84 | import os 85 | 86 | os.remove(path) 87 | 88 | def rename(self, src: str, dst: str) -> None: 89 | import os 90 | 91 | os.rename(src, dst) 92 | 93 | 94 | class TimeProvider(ABC): 95 | """Abstract interface for time operations to enable testing.""" 96 | 97 | @abstractmethod 98 | def time(self) -> float: 99 | """Get current time.""" 100 | pass 101 | 102 | 103 | class DefaultTimeProvider(TimeProvider): 104 | """Default implementation using standard time operations.""" 105 | 106 | def time(self) -> float: 107 | import time 108 | 109 | return time.time() 110 | -------------------------------------------------------------------------------- /robodm/trajectory_factory.py: -------------------------------------------------------------------------------- 1 | """Factory for creating trajectory instances with dependency injection.""" 2 | 3 | from typing import Any, Dict, Optional, Text 4 | 5 | from .trajectory_base import (DefaultFileSystem, DefaultTimeProvider, 6 | FileSystemInterface, TimeProvider, 7 | TrajectoryInterface) 8 | 9 | 10 | class TrajectoryFactory: 11 | """Factory for creating trajectory instances with configurable dependencies.""" 12 | 13 | def __init__( 14 | self, 15 | filesystem: Optional[FileSystemInterface] = None, 16 | time_provider: Optional[TimeProvider] = None, 17 | ): 18 | self.filesystem = filesystem or DefaultFileSystem() 19 | self.time_provider = time_provider or DefaultTimeProvider() 20 | 21 | def create_trajectory( 22 | self, 23 | path: Text, 24 | mode: str = "r", 25 | video_codec: str = "auto", 26 | codec_options: Optional[Dict[str, Any]] = None, 27 | feature_name_separator: Text = "/", 28 | ) -> TrajectoryInterface: 29 | """ 30 | Create a 
trajectory instance with injected dependencies. 31 | 32 | Args: 33 | path (Text): Path to trajectory file 34 | mode (str): File mode ("r" or "w") 35 | video_codec (str): Video codec to use ("auto", "rawvideo", "h264", "h265", "libaom-av1", "ffv1") 36 | codec_options (Dict[str, Any]): Additional codec-specific options 37 | feature_name_separator (Text): Delimiter for feature names 38 | """ 39 | from .trajectory import Trajectory 40 | 41 | # Create trajectory with dependency injection 42 | trajectory = Trajectory( 43 | path=path, 44 | mode=mode, 45 | video_codec=video_codec, 46 | codec_options=codec_options, 47 | feature_name_separator=feature_name_separator, 48 | filesystem=self.filesystem, 49 | time_provider=self.time_provider, 50 | ) 51 | 52 | return trajectory 53 | 54 | 55 | # Global factory instance for backwards compatibility 56 | default_factory = TrajectoryFactory() 57 | 58 | 59 | def create_trajectory( 60 | path: Text, 61 | mode: str = "r", 62 | video_codec: str = "auto", 63 | codec_options: Optional[Dict[str, Any]] = None, 64 | feature_name_separator: Text = "/", 65 | base_datetime: Optional[Any] = None, 66 | time_unit: str = "ms", 67 | enforce_monotonic: bool = True, 68 | ) -> TrajectoryInterface: 69 | """ 70 | Convenience function to create trajectory with default dependencies. 71 | 72 | Args: 73 | path (Text): Path to trajectory file 74 | mode (str): File mode ("r" or "w") 75 | video_codec (str): Video codec to use ("auto", "rawvideo", "h264", "h265", "libaom-av1", "ffv1") 76 | codec_options (Dict[str, Any]): Additional codec-specific options 77 | feature_name_separator (Text): Delimiter for feature names 78 | base_datetime: Optional base datetime for timestamp calculations 79 | time_unit: Default time unit for timestamp inputs ('ns', 'μs', 'ms', 's') 80 | enforce_monotonic: Whether to enforce monotonically increasing timestamps 81 | """ 82 | from .trajectory import Trajectory 83 | 84 | # Call Trajectory constructor directly since the factory doesn't support time parameters yet 85 | return Trajectory( 86 | path=path, 87 | mode=mode, 88 | video_codec=video_codec, 89 | codec_options=codec_options, 90 | feature_name_separator=feature_name_separator, 91 | base_datetime=base_datetime, 92 | time_unit=time_unit, 93 | enforce_monotonic=enforce_monotonic, 94 | ) 95 | -------------------------------------------------------------------------------- /robodm/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from typing import Any, Dict, List 4 | 5 | import numpy as np 6 | 7 | from robodm.feature import FeatureType 8 | 9 | 10 | def data_to_tf_schema(data: Dict[str, Any]) -> Dict[str, FeatureType]: 11 | """ 12 | Convert data to a tf schema 13 | """ 14 | data = _flatten(data) 15 | schema: Dict[str, Any] = {} 16 | for k, v in data.items(): 17 | if "/" in k: # make the subkey to be within dict 18 | main_key, sub_key = k.split("/") 19 | if main_key not in schema: 20 | schema[main_key] = {} 21 | schema[main_key][sub_key] = FeatureType.from_data(v).to_tf_feature_type( 22 | first_dim_none=True 23 | ) 24 | # replace first element of shape with None 25 | else: 26 | schema[k] = FeatureType.from_data(v).to_tf_feature_type(first_dim_none=True) 27 | return schema 28 | 29 | 30 | # flatten the data such that all data starts with root level tree (observation and action) 31 | def _flatten(data, parent_key="", sep="/"): 32 | items = {} 33 | for k, v in data.items(): 34 | new_key = parent_key + sep + k if parent_key else k 35 | if 
isinstance(v, dict): 36 | items.update(_flatten(v, new_key, sep)) 37 | else: 38 | items[new_key] = v 39 | return items 40 | 41 | 42 | import h5py 43 | 44 | 45 | def recursively_read_hdf5_group(group): 46 | if isinstance(group, h5py.Dataset): 47 | return np.array(group) 48 | elif isinstance(group, h5py.Group): 49 | return {key: recursively_read_hdf5_group(value) for key, value in group.items()} 50 | else: 51 | raise TypeError("Unsupported HDF5 group type") 52 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # robodm Test Suite 2 | 3 | This directory contains comprehensive tests for the robodm trajectory management system, including unit tests, integration tests, and performance benchmarks. 4 | 5 | ## Test Structure 6 | 7 | ``` 8 | tests/ 9 | ├── conftest.py # Pytest configuration and fixtures 10 | ├── test_fixtures.py # Mock objects and test data fixtures 11 | ├── test_trajectory.py # Unit tests for trajectory functionality 12 | ├── test_loaders.py # Unit tests for data loaders 13 | ├── test_benchmark.py # Performance benchmarks 14 | └── README.md # This file 15 | ``` 16 | 17 | ## Test Categories 18 | 19 | ### Unit Tests (`test_trajectory.py`, `test_loaders.py`) 20 | - Test individual components in isolation 21 | - Use mock dependencies for fast, deterministic testing 22 | - Cover core functionality like data creation, loading, and manipulation 23 | 24 | ### Integration Tests 25 | - Test complete workflows across multiple components 26 | - Use real file system operations 27 | - Verify end-to-end functionality 28 | 29 | ### Benchmark Tests (`test_benchmark.py`) 30 | - Compare performance across different formats (VLA, HDF5, TFRecord) 31 | - Measure file sizes, creation time, and loading time 32 | - Generate CSV reports for analysis 33 | 34 | ## Running Tests 35 | 36 | ### Quick Start 37 | ```bash 38 | # Run basic unit tests 39 | python run_tests.py --test-type unit 40 | 41 | # Run with coverage 42 | python run_tests.py --test-type unit --coverage 43 | 44 | # Run benchmarks 45 | python run_tests.py --test-type benchmark --benchmark-size small 46 | ``` 47 | 48 | ### Detailed Commands 49 | 50 | #### Unit Tests Only 51 | ```bash 52 | python -m pytest tests/test_trajectory.py tests/test_loaders.py -m "not slow" 53 | ``` 54 | 55 | #### Integration Tests 56 | ```bash 57 | python -m pytest tests/ -m "integration" 58 | ``` 59 | 60 | #### Benchmark Tests 61 | ```bash 62 | # Small benchmarks (fast) 63 | python -m pytest tests/test_benchmark.py -m "not slow" 64 | 65 | # Full benchmarks (slow) 66 | python -m pytest tests/test_benchmark.py 67 | ``` 68 | 69 | #### All Tests 70 | ```bash 71 | python -m pytest tests/ 72 | ``` 73 | 74 | ## Test Fixtures and Mock Objects 75 | 76 | ### MockFileSystem 77 | - Simulates file system operations without actual I/O 78 | - Enables fast, deterministic testing 79 | - Located in `test_fixtures.py` 80 | 81 | ### MockTimeProvider 82 | - Provides controllable time for testing 83 | - Enables deterministic timestamp testing 84 | - Supports time advancement simulation 85 | 86 | ### Sample Data 87 | - `sample_trajectory_data`: Small datasets for quick tests 88 | - `large_sample_data`: Larger datasets for performance testing 89 | - `sample_dict_of_lists`: Test data in dictionary-of-lists format 90 | 91 | ## Benchmarking 92 | 93 | The benchmark suite compares robodm VLA format against: 94 | - **HDF5**: Popular scientific data format 95 | - 
**TFRecord**: TensorFlow's native format (if available) 96 | 97 | ### Benchmark Metrics 98 | - **File Size**: Compressed size on disk 99 | - **Creation Time**: Time to write data to format 100 | - **Loading Time**: Time to read data from format 101 | - **Compression Ratio**: Uncompressed size / compressed size 102 | - **Scalability**: Performance vs. dataset size 103 | 104 | ### Sample Benchmark Output 105 | ``` 106 | Format | File Size (MB) | Creation (s) | Loading (s) | Compression 107 | -----------|----------------|--------------|-------------|------------ 108 | VLA_lossy | 12.34 | 2.15 | 0.89 | 8.2x 109 | VLA_lossless| 18.67 | 1.98 | 0.76 | 5.4x 110 | HDF5 | 15.23 | 1.87 | 0.92 | 6.6x 111 | TFDS | 20.45 | 3.21 | 1.23 | 4.9x 112 | ``` 113 | 114 | ## Test Configuration 115 | 116 | ### Pytest Markers 117 | - `@pytest.mark.slow`: Tests that take significant time 118 | - `@pytest.mark.integration`: Integration tests requiring real I/O 119 | - `@pytest.mark.benchmark`: Performance benchmark tests 120 | 121 | ### Environment Variables 122 | ```bash 123 | # Skip slow tests 124 | export PYTEST_IGNORE_SLOW=1 125 | 126 | # Set custom temp directory 127 | export PYTEST_TEMP_DIR=/path/to/temp 128 | 129 | # Enable verbose logging 130 | export ROBODM_TEST_VERBOSE=1 131 | ``` 132 | 133 | ## Adding New Tests 134 | 135 | ### Unit Test Example 136 | ```python 137 | def test_my_feature(temp_dir, mock_filesystem): 138 | """Test my new feature.""" 139 | # Create test data 140 | data = {"feature": [1, 2, 3]} 141 | 142 | # Use mock filesystem for fast testing 143 | factory = TrajectoryFactory(filesystem=mock_filesystem) 144 | 145 | # Test your feature 146 | trajectory = factory.create_trajectory("test.vla", mode="w") 147 | # ... test logic 148 | 149 | assert expected_result == actual_result 150 | ``` 151 | 152 | ### Benchmark Test Example 153 | ```python 154 | def test_my_benchmark(temp_dir): 155 | """Benchmark my feature.""" 156 | runner = BenchmarkRunner(temp_dir) 157 | 158 | # Create test data 159 | data = runner.create_test_data(num_samples=100) 160 | 161 | # Run benchmarks 162 | results = runner.run_comprehensive_benchmark() 163 | 164 | # Verify results 165 | for result in results: 166 | assert result.compression_ratio > 1.0 167 | ``` 168 | 169 | ## Dependencies 170 | 171 | ### Required 172 | - `pytest`: Test framework 173 | - `numpy`: Numerical operations 174 | - `h5py`: HDF5 support 175 | 176 | ### Optional 177 | - `tensorflow`: For TFRecord benchmarks 178 | - `pytest-cov`: Coverage reporting 179 | - `pytest-html`: HTML test reports 180 | - `pytest-benchmark`: Advanced benchmarking 181 | 182 | ## Troubleshooting 183 | 184 | ### Common Issues 185 | 186 | #### "TensorFlow not available" 187 | ```bash 188 | pip install tensorflow # For TFRecord benchmarks 189 | ``` 190 | 191 | #### "Permission denied" errors 192 | ```bash 193 | # Ensure temp directory is writable 194 | chmod 755 /tmp/robodm_tests 195 | ``` 196 | 197 | #### Out of memory errors 198 | ```bash 199 | # Reduce benchmark dataset size 200 | python run_tests.py --test-type benchmark --benchmark-size small 201 | ``` 202 | 203 | ### Debug Mode 204 | ```bash 205 | # Run with full output and debugging 206 | python -m pytest tests/ -v -s --tb=long --no-header 207 | ``` 208 | 209 | ## Performance Expectations 210 | 211 | ### Unit Tests 212 | - Should complete in < 30 seconds 213 | - Use minimal memory (< 100MB) 214 | - No external dependencies 215 | 216 | ### Integration Tests 217 | - Should complete in < 2 minutes 218 | - May use up to 1GB memory 219 
| - Require file system access 220 | 221 | ### Benchmark Tests 222 | - Small: < 1 minute, < 500MB memory 223 | - Medium: < 5 minutes, < 2GB memory 224 | - Large: < 15 minutes, < 8GB memory 225 | 226 | ## Contributing 227 | 228 | When adding new tests: 229 | 230 | 1. **Follow naming conventions**: `test_*.py` files, `test_*` functions 231 | 2. **Use appropriate fixtures**: Mock objects for unit tests, real I/O for integration 232 | 3. **Add appropriate markers**: `@pytest.mark.slow` for long-running tests 233 | 4. **Document test purpose**: Clear docstrings explaining what is tested 234 | 5. **Keep tests focused**: One concept per test function 235 | 6. **Use assertions effectively**: Clear, specific assertion messages 236 | 237 | For benchmark tests: 238 | 1. **Use deterministic data**: Set random seeds for reproducible results 239 | 2. **Measure what matters**: Focus on user-relevant metrics 240 | 3. **Consider scalability**: Test with multiple dataset sizes 241 | 4. **Save results**: Generate CSV/reports for analysis -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BerkeleyAutomation/robodm/2ce5d539e49d93273e71dd3d7dfd60a6c49a4e43/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Pytest configuration and fixture registration.""" 2 | 3 | import sys 4 | 5 | import pytest 6 | 7 | # Import all fixtures from test_fixtures 8 | from .test_fixtures import (benchmark_dataset, large_sample_data, 9 | mock_filesystem, mock_time_provider, 10 | sample_dict_of_lists, sample_trajectory_data, 11 | temp_dir) 12 | 13 | # Re-export fixtures so pytest can find them 14 | __all__ = [ 15 | "mock_filesystem", 16 | "mock_time_provider", 17 | "temp_dir", 18 | "sample_trajectory_data", 19 | "sample_dict_of_lists", 20 | "large_sample_data", 21 | "benchmark_dataset", 22 | ] 23 | 24 | 25 | # each test runs on cwd to its temp dir 26 | @pytest.fixture(autouse=True) 27 | def go_to_tmpdir(request): 28 | # Get the fixture dynamically by its name. 29 | tmpdir = request.getfixturevalue("tmpdir") 30 | # ensure local test created packages can be imported 31 | sys.path.insert(0, str(tmpdir)) 32 | # Chdir only for the duration of the test. 
33 | with tmpdir.as_cwd(): 34 | yield 35 | -------------------------------------------------------------------------------- /tests/test_fixtures.py: -------------------------------------------------------------------------------- 1 | """Test fixtures and mock implementations for robodm testing.""" 2 | 3 | import os 4 | import shutil 5 | import tempfile 6 | import time 7 | from typing import Any, Dict, List, Optional, Union 8 | from unittest.mock import MagicMock, Mock 9 | 10 | import numpy as np 11 | import pytest 12 | 13 | from robodm import Trajectory 14 | from robodm.trajectory_base import FileSystemInterface, TimeProvider 15 | 16 | 17 | class MockFileSystem(FileSystemInterface): 18 | """Mock file system for testing.""" 19 | 20 | def __init__(self): 21 | self.files = {} 22 | self.directories = set() 23 | 24 | def exists(self, path: str) -> bool: 25 | return path in self.files or path in self.directories 26 | 27 | def makedirs(self, path: str, exist_ok: bool = False) -> None: 28 | if not exist_ok and path in self.directories: 29 | raise FileExistsError(f"Directory {path} already exists") 30 | self.directories.add(path) 31 | 32 | def remove(self, path: str) -> None: 33 | if path in self.files: 34 | del self.files[path] 35 | else: 36 | raise FileNotFoundError(f"File {path} not found") 37 | 38 | def rename(self, src: str, dst: str) -> None: 39 | if src not in self.files: 40 | raise FileNotFoundError(f"File {src} not found") 41 | self.files[dst] = self.files[src] 42 | del self.files[src] 43 | 44 | def add_file(self, path: str, content: Any = None): 45 | """Add a file to the mock filesystem.""" 46 | self.files[path] = content 47 | 48 | 49 | class MockTimeProvider(TimeProvider): 50 | """Mock time provider for deterministic testing.""" 51 | 52 | def __init__(self, initial_time: float = 0.0): 53 | self._current_time = initial_time 54 | self._time_calls = [] 55 | 56 | def time(self) -> float: 57 | self._time_calls.append(self._current_time) 58 | return self._current_time 59 | 60 | def advance_time(self, seconds: float): 61 | """Advance the mock time.""" 62 | self._current_time += seconds 63 | 64 | def set_time(self, time: float): 65 | """Set the current time.""" 66 | self._current_time = time 67 | 68 | @property 69 | def call_count(self) -> int: 70 | return len(self._time_calls) 71 | 72 | 73 | @pytest.fixture 74 | def mock_filesystem(): 75 | """Pytest fixture for mock filesystem.""" 76 | return MockFileSystem() 77 | 78 | 79 | @pytest.fixture 80 | def mock_time_provider(): 81 | """Pytest fixture for mock time provider.""" 82 | return MockTimeProvider() 83 | 84 | 85 | @pytest.fixture 86 | def temp_dir(): 87 | """Pytest fixture for temporary directory.""" 88 | temp_path = tempfile.mkdtemp() 89 | yield temp_path 90 | shutil.rmtree(temp_path, ignore_errors=True) 91 | 92 | 93 | @pytest.fixture 94 | def sample_trajectory_data(): 95 | """Sample trajectory data for testing.""" 96 | return [ 97 | { 98 | "observation": { 99 | "image": np.random.randint(0, 100 | 255, (480, 640, 3), 101 | dtype=np.uint8), 102 | "joint_positions": np.random.random(7).astype(np.float32), 103 | }, 104 | "action": np.random.random(7).astype(np.float32), 105 | "reward": np.float32(1.0), 106 | }, 107 | { 108 | "observation": { 109 | "image": np.random.randint(0, 110 | 255, (480, 640, 3), 111 | dtype=np.uint8), 112 | "joint_positions": np.random.random(7).astype(np.float32), 113 | }, 114 | "action": np.random.random(7).astype(np.float32), 115 | "reward": np.float32(0.5), 116 | }, 117 | ] 118 | 119 | 120 | @pytest.fixture 121 | def 
sample_dict_of_lists(): 122 | """Sample dictionary of lists for testing.""" 123 | return { 124 | "observation": { 125 | "image": [ 126 | np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8), 127 | np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8), 128 | ], 129 | "joint_positions": [ 130 | np.random.random(7).astype(np.float32), 131 | np.random.random(7).astype(np.float32), 132 | ], 133 | }, 134 | "action": [ 135 | np.random.random(7).astype(np.float32), 136 | np.random.random(7).astype(np.float32), 137 | ], 138 | "reward": [np.float32(1.0), np.float32(0.5)], 139 | } 140 | 141 | 142 | @pytest.fixture 143 | def large_sample_data(): 144 | """Large sample data for benchmarking.""" 145 | num_samples = 100 146 | return { 147 | "observation/image": [ 148 | np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) 149 | for _ in range(num_samples) 150 | ], 151 | "observation/joint_positions": 152 | [np.random.random(7).astype(np.float32) for _ in range(num_samples)], 153 | "action": 154 | [np.random.random(7).astype(np.float32) for _ in range(num_samples)], 155 | "reward": [np.float32(np.random.random()) for _ in range(num_samples)], 156 | } 157 | 158 | 159 | class BenchmarkDataset: 160 | """Helper class for creating benchmark datasets.""" 161 | 162 | @staticmethod 163 | def create_vla_dataset(path: str, 164 | data: Dict[str, List[Any]], 165 | video_codec: str = "auto"): 166 | """Create a VLA dataset file for testing.""" 167 | traj = Trajectory.from_dict_of_lists(data, 168 | path, 169 | video_codec=video_codec) 170 | return traj 171 | 172 | @staticmethod 173 | def create_hdf5_dataset(path: str, data: Dict[str, List[Any]]): 174 | """Create an HDF5 dataset file.""" 175 | import h5py 176 | 177 | with h5py.File(path, "w") as f: 178 | for key, values in data.items(): 179 | if isinstance(values[0], np.ndarray): 180 | stacked_data = np.stack(values) 181 | else: 182 | stacked_data = np.array(values) 183 | f.create_dataset(key, 184 | data=stacked_data, 185 | compression="gzip", 186 | compression_opts=9) 187 | 188 | @staticmethod 189 | def get_file_size(path: str) -> int: 190 | """Get file size in bytes.""" 191 | return os.path.getsize(path) 192 | 193 | 194 | @pytest.fixture 195 | def benchmark_dataset(): 196 | """Pytest fixture for benchmark dataset helper.""" 197 | return BenchmarkDataset() 198 | -------------------------------------------------------------------------------- /tests/test_shape_codec_logic.py: -------------------------------------------------------------------------------- 1 | """Test cases for shape-based codec selection and dimensionality checking.""" 2 | 3 | import os 4 | import tempfile 5 | 6 | import numpy as np 7 | import pytest 8 | 9 | from robodm import FeatureType, Trajectory 10 | from robodm.trajectory import CodecConfig 11 | 12 | 13 | class TestShapeBasedCodecSelection: 14 | """Test codec selection based on data shape.""" 15 | 16 | def test_rgb_image_codec_selection(self): 17 | """Test that RGB images get video codecs when compatible.""" 18 | config = CodecConfig() 19 | 20 | # RGB image with even dimensions should get a video codec 21 | rgb_even = FeatureType(dtype="uint8", shape=(128, 128, 3)) 22 | codec = config.get_codec_for_feature(rgb_even) 23 | assert ( 24 | codec != "rawvideo" 25 | ), f"RGB image with even dimensions should get video codec, got {codec}" 26 | assert codec in [ 27 | "libx264", 28 | "libx265", 29 | "libaom-av1", 30 | "ffv1", 31 | ], f"Got unexpected codec: {codec}" 32 | 33 | def test_non_rgb_shapes_use_rawvideo(self): 34 | """Test that non-RGB 
shapes always use rawvideo.""" 35 | config = CodecConfig() 36 | 37 | test_cases = [ 38 | ((128, 128), "Grayscale image"), 39 | ((10, ), "1D vector"), 40 | ((5, 10), "2D matrix"), 41 | ((128, 128, 1), "Single channel image"), 42 | ((128, 128, 4), "RGBA image"), 43 | ((20, 30, 5), "Multi-channel data"), 44 | ] 45 | 46 | for shape, description in test_cases: 47 | feature_type = FeatureType(dtype="float32", shape=shape) 48 | codec = config.get_codec_for_feature(feature_type) 49 | assert (codec == "rawvideo" 50 | ), f"{description} should use rawvideo, got {codec}" 51 | 52 | def test_user_specified_codec_validation(self): 53 | """Test user-specified codec validation for RGB images.""" 54 | # Valid user-specified codec for compatible RGB image 55 | config = CodecConfig(codec="libx264") 56 | rgb_even = FeatureType(dtype="uint8", shape=(128, 128, 3)) 57 | codec = config.get_codec_for_feature(rgb_even) 58 | assert ( 59 | codec == "libx264" 60 | ), f"Compatible RGB should use user-specified codec, got {codec}" 61 | 62 | # Invalid user-specified codec for incompatible RGB image 63 | config = CodecConfig(codec="libx264") 64 | rgb_odd = FeatureType(dtype="uint8", shape=(127, 129, 3)) 65 | codec = config.get_codec_for_feature(rgb_odd) 66 | assert ( 67 | codec == "rawvideo" 68 | ), f"Incompatible RGB should fall back to rawvideo, got {codec}" 69 | 70 | 71 | class TestCodecCompatibilityValidation: 72 | """Test codec compatibility validation methods.""" 73 | 74 | def test_is_valid_image_shape(self): 75 | """Test the is_valid_image_shape method.""" 76 | test_cases = [ 77 | # (shape, codec, expected_result, description) 78 | ((128, 128, 3), "libx264", True, "Even dimensions should work"), 79 | ((127, 129, 3), "libx264", False, 80 | "Odd dimensions should fail for H.264"), 81 | ((1920, 1080, 3), "libx264", True, 82 | "Large even dimensions should work"), 83 | ((2, 2, 3), "libx264", True, 84 | "Very small even dimensions might work"), 85 | ( 86 | (128, 128), 87 | "libx264", 88 | False, 89 | "Non-RGB should not be valid for video codec", 90 | ), 91 | ((10, ), "libx264", False, "1D data should not be valid"), 92 | ] 93 | 94 | for shape, codec, expected, description in test_cases: 95 | result = CodecConfig.is_valid_image_shape(shape, codec) 96 | assert ( 97 | result == expected 98 | ), f"{description}: shape {shape} with {codec} expected {expected}, got {result}" 99 | 100 | def test_is_codec_config_supported(self): 101 | """Test PyAV codec configuration support.""" 102 | # These should work for most systems 103 | assert CodecConfig.is_codec_config_supported(128, 128, "yuv420p", 104 | "libx264") 105 | 106 | # Very large dimensions might not work 107 | large_result = CodecConfig.is_codec_config_supported( 108 | 10000, 10000, "yuv420p", "libx264") 109 | # Don't assert this as it depends on system capabilities 110 | print(f"Large dimensions test result: {large_result}") 111 | 112 | 113 | class TestRoundtripData: 114 | """Test roundtrip encoding/decoding for various data shapes.""" 115 | 116 | def test_different_shapes_and_types(self): 117 | """Test that different data shapes and types can be handled.""" 118 | config = CodecConfig() 119 | 120 | test_cases = [ 121 | # (shape, dtype, expected_codec_type) 122 | ((128, 128, 3), "uint8", "video"), # RGB image 123 | ((100, 200, 3), "uint8", "video"), # Different RGB size 124 | ((128, 128), "uint8", "rawvideo"), # Grayscale 125 | ((10, ), "float32", "rawvideo"), # Vector 126 | ((5, 10), "float64", "rawvideo"), # Matrix 127 | ((128, 128, 1), "uint8", "rawvideo"), # Single 
channel 128 | ((128, 128, 4), "uint8", "rawvideo"), # RGBA 129 | ] 130 | 131 | for shape, dtype, expected_type in test_cases: 132 | feature_type = FeatureType(dtype=dtype, shape=shape) 133 | codec = config.get_codec_for_feature(feature_type) 134 | 135 | if expected_type == "video": 136 | assert (codec != "rawvideo" 137 | ), f"Shape {shape} should get video codec, got {codec}" 138 | else: 139 | assert (codec == "rawvideo" 140 | ), f"Shape {shape} should get rawvideo, got {codec}" 141 | 142 | def test_mixed_rgb_and_non_rgb_in_trajectory(self): 143 | """Test handling mixed RGB and non-RGB data types.""" 144 | config = CodecConfig() 145 | 146 | # Simulate mixed data in a trajectory 147 | features = { 148 | "camera/rgb": FeatureType(dtype="uint8", 149 | shape=(128, 128, 3)), # RGB 150 | "camera/depth": FeatureType(dtype="float32", 151 | shape=(128, 128)), # Depth 152 | "robot/joint_pos": FeatureType(dtype="float32", 153 | shape=(7, )), # Vector 154 | "camera/mask": FeatureType(dtype="uint8", 155 | shape=(128, 128, 1)), # Mask 156 | } 157 | 158 | codecs = {} 159 | for name, feature_type in features.items(): 160 | codecs[name] = config.get_codec_for_feature(feature_type) 161 | 162 | # Only RGB should get video codec 163 | assert codecs["camera/rgb"] != "rawvideo", "RGB should get video codec" 164 | assert codecs[ 165 | "camera/depth"] == "rawvideo", "Depth should get rawvideo" 166 | assert (codecs["robot/joint_pos"] == "rawvideo" 167 | ), "Joint positions should get rawvideo" 168 | assert codecs["camera/mask"] == "rawvideo", "Mask should get rawvideo" 169 | 170 | 171 | class TestPixelFormatSelection: 172 | """Test pixel format selection logic.""" 173 | 174 | def test_rgb_pixel_format_selection(self): 175 | """Test pixel format selection for RGB data.""" 176 | config = CodecConfig() 177 | 178 | rgb_type = FeatureType(dtype="uint8", shape=(128, 128, 3)) 179 | 180 | # Test different codecs 181 | yuv_codecs = ["libx264", "libx265", "libaom-av1", "ffv1"] 182 | for codec in yuv_codecs: 183 | result = config.get_pixel_format(codec, rgb_type) 184 | assert ( 185 | result == "yuv420p" 186 | ), f"RGB data with {codec} should get yuv420p, got {result}" 187 | 188 | def test_non_rgb_pixel_format_selection(self): 189 | """Test pixel format selection for non-RGB data.""" 190 | config = CodecConfig() 191 | 192 | # Non-RGB data should not get RGB pixel formats 193 | grayscale_type = FeatureType(dtype="uint8", shape=(128, 128)) 194 | vector_type = FeatureType(dtype="float32", shape=(10, )) 195 | 196 | # These should return None (no pixel format for non-RGB) 197 | for data_type in [grayscale_type, vector_type]: 198 | for codec in ["libx264", "libx265", "libaom-av1", "ffv1"]: 199 | result = config.get_pixel_format(codec, data_type) 200 | # Should not return RGB-specific formats 201 | assert ( 202 | result is None 203 | ), f"Non-RGB data should not get pixel format, got {result}" 204 | 205 | def test_rawvideo_pixel_format(self): 206 | """Test that rawvideo returns None for pixel format.""" 207 | config = CodecConfig() 208 | 209 | rgb_type = FeatureType(dtype="uint8", shape=(128, 128, 3)) 210 | result = config.get_pixel_format("rawvideo", rgb_type) 211 | assert ( 212 | result is None 213 | ), f"rawvideo should return None for pixel format, got {result}" 214 | 215 | 216 | @pytest.fixture 217 | def temp_dir(): 218 | """Create a temporary directory for tests.""" 219 | with tempfile.TemporaryDirectory() as tmpdir: 220 | yield tmpdir 221 | -------------------------------------------------------------------------------- 
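For quick reference, the shape-based selection rules verified above could be exercised directly along these lines (a minimal sketch, not part of the test suite; it only uses the `CodecConfig` and `FeatureType` calls shown in the tests):

```python
from robodm import FeatureType
from robodm.trajectory import CodecConfig

config = CodecConfig()  # default "auto" codec selection

# (shape, dtype) pairs mirroring the cases asserted in the tests above
cases = [
    ((128, 128, 3), "uint8"),   # RGB frame with even dims -> video codec
    ((128, 128), "float32"),    # depth / grayscale        -> rawvideo
    ((7,), "float32"),          # joint-position vector    -> rawvideo
    ((128, 128, 4), "uint8"),   # RGBA                     -> rawvideo
]

for shape, dtype in cases:
    ft = FeatureType(dtype=dtype, shape=shape)
    codec = config.get_codec_for_feature(ft)
    # yuv420p for RGB data on a video codec, None for rawvideo / non-RGB data
    pix_fmt = config.get_pixel_format(codec, ft)
    print(f"{shape} {dtype}: codec={codec}, pixel_format={pix_fmt}")
```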
/tests/test_time_manager.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test cases for robodm TimeManager system. 3 | 4 | Tests cover: 5 | - Time unit conversions 6 | - Monotonic timestamp enforcement 7 | - Datetime handling and conversions 8 | - Integration with Trajectory class 9 | - Edge cases and error handling 10 | """ 11 | 12 | import os 13 | import tempfile 14 | from datetime import datetime, timedelta, timezone 15 | 16 | import numpy as np 17 | import pytest 18 | 19 | from robodm import create_trajectory 20 | from robodm.trajectory import TimeManager, Trajectory 21 | 22 | 23 | class TestTimeManager: 24 | """Test the TimeManager class functionality.""" 25 | 26 | def test_time_unit_conversions(self): 27 | """Test conversion between different time units.""" 28 | tm = TimeManager(time_unit="ms") 29 | 30 | # Test conversion to nanoseconds 31 | assert tm.convert_to_nanoseconds(1000, "ms") == 1_000_000_000 32 | assert tm.convert_to_nanoseconds(1, "s") == 1_000_000_000 33 | assert tm.convert_to_nanoseconds(1000, "μs") == 1_000_000 34 | assert tm.convert_to_nanoseconds(1000, "ns") == 1000 35 | 36 | # Test conversion from nanoseconds 37 | assert tm.convert_from_nanoseconds(1_000_000_000, "ms") == 1000 38 | assert tm.convert_from_nanoseconds(1_000_000_000, "s") == 1 39 | assert tm.convert_from_nanoseconds(1_000_000, "μs") == 1000 40 | assert tm.convert_from_nanoseconds(1000, "ns") == 1000 41 | 42 | # Test unit conversion 43 | assert tm.convert_units(1, "s", "ms") == 1000 44 | assert tm.convert_units(1000, "ms", "s") == 1 45 | assert tm.convert_units(1000, "μs", "ms") == 1 46 | 47 | def test_invalid_time_units(self): 48 | """Test handling of invalid time units.""" 49 | with pytest.raises(ValueError): 50 | TimeManager(time_unit="invalid") 51 | 52 | tm = TimeManager() 53 | with pytest.raises(ValueError): 54 | tm.convert_to_nanoseconds(1000, "invalid") 55 | 56 | with pytest.raises(ValueError): 57 | tm.convert_from_nanoseconds(1000, "invalid") 58 | 59 | def test_datetime_conversions(self): 60 | """Test datetime to timestamp conversions.""" 61 | base_dt = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc) 62 | tm = TimeManager(base_datetime=base_dt, time_unit="ms") 63 | 64 | # Test conversion of datetime 1 hour after base 65 | test_dt = base_dt + timedelta(hours=1) 66 | timestamp_ms = tm.datetime_to_timestamp(test_dt, "ms") 67 | assert timestamp_ms == 3600 * 1000 # 1 hour in milliseconds 68 | 69 | # Test reverse conversion 70 | converted_dt = tm.timestamp_to_datetime(timestamp_ms, "ms") 71 | assert converted_dt == test_dt 72 | 73 | # Test with different time units 74 | timestamp_s = tm.datetime_to_timestamp(test_dt, "s") 75 | assert timestamp_s == 3600 # 1 hour in seconds 76 | 77 | def test_monotonic_enforcement(self): 78 | """Test monotonic timestamp enforcement.""" 79 | tm = TimeManager(time_unit="ms", enforce_monotonic=True) 80 | 81 | # First timestamp should pass through 82 | ts1 = tm.validate_timestamp(1000) 83 | assert ts1 == 1000 84 | 85 | # Second timestamp should be adjusted if not monotonic 86 | ts2 = tm.validate_timestamp(500) # Earlier than previous 87 | assert ts2 > ts1 88 | 89 | # Valid monotonic timestamp should pass through 90 | ts3 = tm.validate_timestamp(2000) 91 | assert ts3 == 2000 92 | 93 | def test_non_monotonic_mode(self): 94 | """Test behavior when monotonic enforcement is disabled.""" 95 | tm = TimeManager(time_unit="ms", enforce_monotonic=False) 96 | 97 | ts1 = tm.validate_timestamp(1000) 98 | assert ts1 == 1000 99 | 100 | 
# Should allow non-monotonic timestamps 101 | ts2 = tm.validate_timestamp(500) 102 | assert ts2 == 500 103 | 104 | def test_add_timestep(self): 105 | """Test adding timesteps to current timestamp.""" 106 | tm = TimeManager(time_unit="ms") 107 | 108 | # First timestep 109 | ts1 = tm.add_timestep(100) # 100ms 110 | assert ts1 == 100 111 | 112 | # Second timestep should be cumulative 113 | ts2 = tm.add_timestep(50) # +50ms 114 | assert ts2 == 150 115 | 116 | # Test with different units 117 | ts3 = tm.add_timestep(1, "s") # +1 second = +1000ms 118 | assert ts3 == 1150 119 | 120 | def test_create_timestamp_sequence(self): 121 | """Test creating sequences of monotonic timestamps.""" 122 | tm = TimeManager(time_unit="ms", enforce_monotonic=False 123 | ) # Disable monotonic for predictable sequences 124 | 125 | timestamps = tm.create_timestamp_sequence( 126 | start_timestamp=0, 127 | count=5, 128 | timestep=100 # 100ms steps 129 | ) 130 | 131 | expected = [0, 100, 200, 300, 400] 132 | assert timestamps == expected 133 | 134 | # Test with different units (reset TimeManager) 135 | tm2 = TimeManager(time_unit="ms", enforce_monotonic=False) 136 | timestamps_s = tm2.create_timestamp_sequence(start_timestamp=0, 137 | count=3, 138 | timestep=1, 139 | unit="s") 140 | 141 | expected_s = [0, 1000, 2000] # Converted to milliseconds 142 | assert timestamps_s == expected_s 143 | 144 | def test_reset_functionality(self): 145 | """Test resetting the TimeManager state.""" 146 | tm = TimeManager(time_unit="ms") 147 | 148 | # Add some timestamps 149 | tm.validate_timestamp(1000) 150 | tm.validate_timestamp(2000) 151 | 152 | # Reset should clear internal state 153 | new_base = datetime(2024, 1, 1, tzinfo=timezone.utc) 154 | tm.reset(base_datetime=new_base) 155 | 156 | # Should be able to use earlier timestamps after reset 157 | ts = tm.validate_timestamp(500) 158 | assert ts == 500 159 | 160 | def test_timezone_handling(self): 161 | """Test proper timezone handling in datetime conversions.""" 162 | # Test with UTC timezone 163 | base_dt_utc = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc) 164 | tm_utc = TimeManager(base_datetime=base_dt_utc) 165 | 166 | # Test with different timezone 167 | base_dt_est = datetime(2023, 168 | 1, 169 | 1, 170 | 7, 171 | 0, 172 | 0, 173 | tzinfo=timezone(timedelta(hours=-5))) # EST 174 | tm_est = TimeManager(base_datetime=base_dt_est) 175 | 176 | # Both should give same result for same absolute time 177 | test_dt_utc = base_dt_utc + timedelta(hours=1) 178 | test_dt_est = base_dt_est + timedelta(hours=1) 179 | 180 | ts_utc = tm_utc.datetime_to_timestamp(test_dt_utc) 181 | ts_est = tm_est.datetime_to_timestamp(test_dt_est) 182 | 183 | assert ts_utc == ts_est # Should be the same relative to their bases 184 | 185 | 186 | class TestTrajectoryTimeIntegration: 187 | """Test integration of TimeManager with Trajectory class.""" 188 | 189 | def test_trajectory_with_time_manager(self): 190 | """Test that Trajectory properly uses TimeManager.""" 191 | with tempfile.TemporaryDirectory() as temp_dir: 192 | path = os.path.join(temp_dir, "test_trajectory.mkv") 193 | base_dt = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc) 194 | 195 | # Create trajectory with specific time settings 196 | trajectory = create_trajectory( 197 | path, 198 | mode="w", 199 | base_datetime=base_dt, 200 | time_unit="ms", 201 | enforce_monotonic=True, 202 | ) 203 | 204 | # Add data with explicit timestamps 205 | trajectory.add("feature1", 206 | "value1", 207 | timestamp=1000, 208 | time_unit="ms") 209 | 
trajectory.add("feature1", 210 | "value2", 211 | timestamp=2000, 212 | time_unit="ms") 213 | trajectory.add("feature1", 214 | "value3", 215 | timestamp=1500, 216 | time_unit="ms") # Should be adjusted 217 | 218 | trajectory.close() 219 | 220 | # Load and verify 221 | trajectory_read = Trajectory(path, mode="r") 222 | data = trajectory_read.load() 223 | trajectory_read.close() 224 | 225 | assert len(data["feature1"]) == 3 226 | 227 | def test_trajectory_datetime_based_timestamps(self): 228 | """Test trajectory with datetime-based timestamp calculation.""" 229 | with tempfile.TemporaryDirectory() as temp_dir: 230 | path = os.path.join(temp_dir, "test_trajectory.mkv") 231 | base_dt = datetime(2023, 1, 1, 12, 0, 0, tzinfo=timezone.utc) 232 | 233 | trajectory = create_trajectory(path, 234 | mode="w", 235 | base_datetime=base_dt, 236 | time_unit="ms") 237 | 238 | # Add data at specific datetime points 239 | dt1 = base_dt + timedelta(seconds=1) 240 | dt2 = base_dt + timedelta(seconds=2) 241 | 242 | ts1 = trajectory.time_manager.datetime_to_timestamp(dt1, "ms") 243 | ts2 = trajectory.time_manager.datetime_to_timestamp(dt2, "ms") 244 | 245 | trajectory.add("sensor1", 100.0, timestamp=ts1, time_unit="ms") 246 | trajectory.add("sensor1", 200.0, timestamp=ts2, time_unit="ms") 247 | 248 | trajectory.close() 249 | 250 | # Verify timestamps are as expected 251 | assert ts1 == 1000 # 1 second = 1000ms 252 | assert ts2 == 2000 # 2 seconds = 2000ms 253 | 254 | def test_trajectory_auto_timestamps(self): 255 | """Test trajectory with automatic timestamp generation.""" 256 | with tempfile.TemporaryDirectory() as temp_dir: 257 | path = os.path.join(temp_dir, "test_trajectory.mkv") 258 | 259 | trajectory = create_trajectory(path, mode="w", time_unit="ms") 260 | 261 | # Add data without explicit timestamps 262 | trajectory.add("feature1", "value1") 263 | trajectory.add("feature1", "value2") 264 | trajectory.add("feature1", "value3") 265 | 266 | trajectory.close() 267 | 268 | # Should create trajectory without errors 269 | trajectory_read = Trajectory(path, mode="r") 270 | data = trajectory_read.load() 271 | trajectory_read.close() 272 | 273 | assert len(data["feature1"]) == 3 274 | 275 | def test_trajectory_mixed_time_units(self): 276 | """Test trajectory with mixed time units in different add() calls.""" 277 | with tempfile.TemporaryDirectory() as temp_dir: 278 | path = os.path.join(temp_dir, "test_trajectory.mkv") 279 | 280 | trajectory = create_trajectory(path, mode="w", time_unit="ms") 281 | 282 | # Add data with different time units 283 | trajectory.add("sensor1", 1.0, timestamp=1, 284 | time_unit="s") # 1000ms 285 | trajectory.add("sensor1", 2.0, timestamp=1500, 286 | time_unit="ms") # 1500ms 287 | trajectory.add("sensor1", 3.0, timestamp=2000000, 288 | time_unit="μs") # 2000ms 289 | 290 | trajectory.close() 291 | 292 | trajectory_read = Trajectory(path, mode="r") 293 | data = trajectory_read.load() 294 | trajectory_read.close() 295 | 296 | assert len(data["sensor1"]) == 3 297 | 298 | 299 | class TestTimeManagerEdgeCases: 300 | """Test edge cases and error conditions.""" 301 | 302 | def test_large_timestamp_values(self): 303 | """Test handling of very large timestamp values.""" 304 | tm = TimeManager(time_unit="ns") 305 | 306 | # Test nanosecond precision with large values 307 | large_ns = 9223372036854775807 # Near max int64 308 | ts_ms = tm.convert_from_nanoseconds(large_ns, "ms") 309 | back_to_ns = tm.convert_to_nanoseconds(ts_ms, "ms") 310 | 311 | # Should handle large values without overflow 312 | 
assert isinstance(ts_ms, int) 313 | assert isinstance(back_to_ns, int) 314 | 315 | def test_zero_and_negative_timestamps(self): 316 | """Test handling of zero and negative timestamp values.""" 317 | tm = TimeManager(time_unit="ms", enforce_monotonic=False) 318 | 319 | # Should handle zero timestamps 320 | ts = tm.validate_timestamp(0) 321 | assert ts == 0 322 | 323 | # Should handle negative timestamps when monotonic is disabled 324 | ts_neg = tm.validate_timestamp(-1000) 325 | assert ts_neg == -1000 326 | 327 | def test_floating_point_timestamps(self): 328 | """Test handling of floating point timestamp inputs.""" 329 | tm = TimeManager(time_unit="ms") 330 | 331 | # Should handle float inputs by converting to int 332 | ts = tm.validate_timestamp(1500.7) 333 | assert isinstance(ts, int) 334 | assert ts == 1500 335 | 336 | # Test float conversion in timestep 337 | ts_step = tm.add_timestep(100.5) 338 | assert isinstance(ts_step, int) 339 | 340 | def test_sequence_with_overlap_handling(self): 341 | """Test timestamp sequence generation with overlap scenarios.""" 342 | tm = TimeManager(time_unit="ms", enforce_monotonic=True) 343 | 344 | # Set initial state 345 | tm.validate_timestamp(5000) 346 | 347 | # Create sequence that would overlap with existing state 348 | timestamps = tm.create_timestamp_sequence( 349 | start_timestamp=3000, 350 | count=3, 351 | timestep=1000 # Earlier than current state 352 | ) 353 | 354 | # Should adjust to maintain monotonic ordering 355 | assert all(ts > 5000 for ts in timestamps) 356 | assert timestamps[1] > timestamps[0] 357 | assert timestamps[2] > timestamps[1] 358 | 359 | 360 | class TestTimeManagerPerformance: 361 | """Test performance characteristics of TimeManager.""" 362 | 363 | def test_large_timestamp_sequence_generation(self): 364 | """Test generating large sequences of timestamps efficiently.""" 365 | tm = TimeManager( 366 | time_unit="ms", 367 | enforce_monotonic=False) # Disable for predictable sequence 368 | 369 | # Generate large sequence 370 | timestamps = tm.create_timestamp_sequence(start_timestamp=0, 371 | count=10000, 372 | timestep=1) 373 | 374 | assert len(timestamps) == 10000 375 | assert timestamps[0] == 0 376 | assert timestamps[-1] == 9999 377 | 378 | # Verify monotonic ordering 379 | for i in range(1, len(timestamps)): 380 | assert timestamps[i] > timestamps[i - 1] 381 | 382 | def test_many_timestamp_validations(self): 383 | """Test performance of many timestamp validations.""" 384 | tm = TimeManager(time_unit="ms", enforce_monotonic=True) 385 | 386 | # Validate many timestamps 387 | timestamps = [] 388 | for i in range(1000): 389 | ts = tm.validate_timestamp(i) 390 | timestamps.append(ts) 391 | 392 | # Should maintain monotonic ordering 393 | for i in range(1, len(timestamps)): 394 | assert timestamps[i] >= timestamps[i - 1] 395 | 396 | 397 | if __name__ == "__main__": 398 | # Run tests if executed directly 399 | pytest.main([__file__, "-v"]) 400 | -------------------------------------------------------------------------------- /tests/test_trajectory_loader_edge_cases.py: -------------------------------------------------------------------------------- 1 | """ 2 | Edge case and boundary testing for Trajectory.load functionality. 
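Covers, among others: empty (zero-length) trajectories, empty and out-of-range slices, very large slice steps, frequency-resampling boundary values, sparse and irregular timestamps, unicode string features, large arrays, floating-point precision edge cases, and invalid slice/frequency parameters.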
3 | """ 4 | 5 | import os 6 | import tempfile 7 | from typing import Dict, List 8 | 9 | import av 10 | import numpy as np 11 | import pytest 12 | 13 | from robodm import FeatureType, Trajectory 14 | 15 | 16 | @pytest.fixture 17 | def temp_dir(): 18 | with tempfile.TemporaryDirectory() as td: 19 | yield td 20 | 21 | 22 | @pytest.fixture(scope="session") 23 | def rng() -> np.random.Generator: 24 | return np.random.default_rng(seed=12345) 25 | 26 | 27 | class TestTrajectoryLoaderEdgeCases: 28 | """Edge cases and boundary conditions for the new loader.""" 29 | 30 | def test_zero_length_trajectory(self, temp_dir): 31 | """Test loading trajectory with zero data points.""" 32 | path = os.path.join(temp_dir, "zero_length.vla") 33 | traj = Trajectory(path, mode="w") 34 | traj.close() 35 | 36 | # Check if file exists after creation 37 | if not os.path.exists(path): 38 | # If no file was created (because no data was added), 39 | # the Trajectory constructor should fail when trying to read 40 | with pytest.raises(FileNotFoundError): 41 | t = Trajectory(path, mode="r") 42 | return 43 | 44 | t = Trajectory(path, mode="r") 45 | 46 | # All operations should work on empty trajectory 47 | empty = t.load() 48 | assert isinstance(empty, dict) 49 | assert len(empty) == 0 50 | 51 | # Slicing empty should return empty 52 | sliced = t.load(data_slice=slice(0, 10)) 53 | assert len(sliced) == 0 54 | 55 | # Resampling empty should return empty 56 | resampled = t.load(desired_frequency=10.0) 57 | assert len(resampled) == 0 58 | 59 | # Container return should work 60 | container_path = t.load(return_type="container") 61 | assert container_path == path 62 | 63 | t.close() 64 | 65 | def test_single_packet_with_none_pts(self, temp_dir): 66 | """Test handling of packets with None pts/dts values.""" 67 | path = os.path.join(temp_dir, "none_pts.vla") 68 | traj = Trajectory(path, mode="w") 69 | 70 | # Add one normal data point 71 | traj.add("value", 42, timestamp=100) 72 | traj.close() 73 | 74 | t = Trajectory(path, mode="r") 75 | data = t.load() 76 | 77 | # Should skip packets with None pts and only load valid ones 78 | assert "value" in data 79 | assert len(data["value"]) >= 1 80 | 81 | t.close() 82 | 83 | def test_slice_start_equals_stop(self, temp_dir): 84 | """Test slice where start equals stop (empty slice).""" 85 | path = os.path.join(temp_dir, "equal_start_stop.vla") 86 | traj = Trajectory(path, mode="w") 87 | 88 | for i in range(10): 89 | traj.add("value", i, timestamp=i * 100) 90 | traj.close() 91 | 92 | t = Trajectory(path, mode="r") 93 | 94 | # Empty slices at various positions 95 | for start_stop in [0, 5, 9, 15]: # Including beyond data 96 | empty = t.load(data_slice=slice(start_stop, start_stop)) 97 | if len(empty) > 0: # Only check if trajectory has data 98 | assert all(len(v) == 0 for v in empty.values()) 99 | 100 | t.close() 101 | 102 | def test_slice_with_very_large_step(self, temp_dir): 103 | """Test slicing with step much larger than data length.""" 104 | path = os.path.join(temp_dir, "large_step.vla") 105 | traj = Trajectory(path, mode="w") 106 | 107 | for i in range(20): 108 | traj.add("value", i, timestamp=i * 100) 109 | traj.close() 110 | 111 | t = Trajectory(path, mode="r") 112 | 113 | # Step of 100 on 20 elements should give only first element 114 | result = t.load(data_slice=slice(0, None, 100)) 115 | assert all(len(v) == 1 for v in result.values()) 116 | assert result["value"][0] == 0 117 | 118 | # Step of 10 should give every 10th element 119 | result = t.load(data_slice=slice(0, None, 10)) 120 
| assert all(len(v) == 2 for v in result.values()) # Elements 0 and 10 121 | np.testing.assert_array_equal(result["value"], [0, 10]) 122 | 123 | t.close() 124 | 125 | def test_frequency_boundary_values(self, temp_dir): 126 | """Test frequency resampling with boundary values.""" 127 | path = os.path.join(temp_dir, "freq_boundary.vla") 128 | traj = Trajectory(path, mode="w") 129 | 130 | # Create data at 10Hz (100ms intervals) 131 | for i in range(30): 132 | traj.add("value", i, timestamp=i * 100) 133 | traj.close() 134 | 135 | t = Trajectory(path, mode="r") 136 | 137 | # Very small frequency (much less than 1Hz) 138 | very_small = t.load( 139 | desired_frequency=0.001) # 1 frame per 1000 seconds 140 | assert all(len(v) <= 1 for v in very_small.values()) 141 | 142 | # Frequency that creates exactly one frame period 143 | one_period = t.load(desired_frequency=1.0) # 1Hz = 1000ms period 144 | # Should get roughly every 10th frame (1000ms / 100ms = 10) 145 | expected_len = len(next(iter(one_period.values()))) 146 | assert 2 <= expected_len <= 5 # Allow some tolerance 147 | 148 | t.close() 149 | 150 | def test_seek_beyond_stream_end(self, temp_dir): 151 | """Test seeking to position beyond the stream length.""" 152 | path = os.path.join(temp_dir, "seek_beyond.vla") 153 | traj = Trajectory(path, mode="w") 154 | 155 | # Short trajectory 156 | for i in range(5): 157 | traj.add("value", i, timestamp=i * 100) 158 | traj.close() 159 | 160 | t = Trajectory(path, mode="r") 161 | 162 | # Try to slice starting beyond the data 163 | beyond = t.load(data_slice=slice(10, 20)) 164 | assert all(len(v) == 0 for v in beyond.values()) 165 | 166 | # Slice that starts within data but extends beyond 167 | partial = t.load(data_slice=slice(3, 10)) 168 | full = t.load() 169 | for k in partial: 170 | np.testing.assert_array_equal(partial[k], full[k][3:]) 171 | 172 | t.close() 173 | 174 | def test_mixed_data_types_in_single_feature(self, temp_dir): 175 | """Test trajectory with varying data types for same feature name.""" 176 | path = os.path.join(temp_dir, "mixed_types.vla") 177 | traj = Trajectory(path, mode="w") 178 | 179 | # This should be consistent - all same feature should have same type 180 | for i in range(5): 181 | traj.add("consistent_value", float(i), timestamp=i * 100) 182 | 183 | traj.close() 184 | 185 | t = Trajectory(path, mode="r") 186 | data = t.load() 187 | 188 | # All values for same feature should have consistent type 189 | assert "consistent_value" in data 190 | assert len(data["consistent_value"]) == 5 191 | assert data["consistent_value"].dtype in [np.float32, np.float64] 192 | 193 | t.close() 194 | 195 | def test_very_sparse_timestamps(self, temp_dir): 196 | """Test trajectory with very sparse, irregular timestamps.""" 197 | path = os.path.join(temp_dir, "sparse_timestamps.vla") 198 | traj = Trajectory(path, mode="w") 199 | 200 | # Very irregular timestamps 201 | timestamps = [0, 1000, 5000, 5001, 10000] # ms 202 | for i, ts in enumerate(timestamps): 203 | traj.add("value", i, timestamp=ts) 204 | 205 | traj.close() 206 | 207 | t = Trajectory(path, mode="r") 208 | 209 | # Should handle sparse data gracefully 210 | full = t.load() 211 | assert len(full["value"]) == 5 212 | 213 | # Resampling should work with sparse data 214 | resampled = t.load(desired_frequency=1.0) # 1Hz = 1000ms 215 | # Should get fewer frames due to large gaps 216 | assert len(resampled["value"]) <= 5 217 | 218 | t.close() 219 | 220 | def test_unicode_and_special_characters(self, temp_dir): 221 | """Test handling of unicode and 
special characters in string data.""" 222 | path = os.path.join(temp_dir, "unicode.vla") 223 | traj = Trajectory(path, mode="w") 224 | 225 | special_strings = [ 226 | "hello", 227 | "café", 228 | "🤖", 229 | "データ", 230 | "test\nwith\nnewlines", 231 | "quotes\"and'apostrophes", 232 | "", # empty string 233 | ] 234 | 235 | for i, s in enumerate(special_strings): 236 | traj.add("text", s, timestamp=i * 100) 237 | 238 | traj.close() 239 | 240 | t = Trajectory(path, mode="r") 241 | data = t.load() 242 | 243 | assert "text" in data 244 | assert len(data["text"]) == len(special_strings) 245 | # Should preserve all special characters 246 | for i, expected in enumerate(special_strings): 247 | assert data["text"][i] == expected 248 | 249 | # Test slicing with unicode data 250 | sliced = t.load(data_slice=slice(1, 4)) 251 | np.testing.assert_array_equal(sliced["text"], special_strings[1:4]) 252 | 253 | t.close() 254 | 255 | def test_extremely_large_arrays(self, temp_dir, rng): 256 | """Test loading trajectory with very large numpy arrays.""" 257 | path = os.path.join(temp_dir, "large_arrays.vla") 258 | traj = Trajectory(path, mode="w") 259 | 260 | # Create reasonably large arrays (not extremely large to avoid memory issues) 261 | for i in range(3): 262 | large_array = rng.random((100, 100)).astype(np.float32) 263 | traj.add("large_data", large_array, timestamp=i * 1000) 264 | 265 | traj.close() 266 | 267 | t = Trajectory(path, mode="r") 268 | data = t.load() 269 | 270 | # Should load successfully 271 | assert "large_data" in data 272 | loaded_shape = data["large_data"].shape 273 | assert loaded_shape[0] == 3 # 3 timesteps 274 | assert loaded_shape[1:] == (100, 100) # Each array is 100x100 275 | 276 | t.close() 277 | 278 | def test_load_with_corrupted_metadata(self, temp_dir): 279 | """Test loading trajectory with missing or corrupted stream metadata.""" 280 | path = os.path.join(temp_dir, "normal.vla") 281 | traj = Trajectory(path, mode="w") 282 | 283 | # Create normal trajectory first 284 | for i in range(5): 285 | traj.add("value", i, timestamp=i * 100) 286 | traj.close() 287 | 288 | # Loading should work normally 289 | t = Trajectory(path, mode="r") 290 | data = t.load() 291 | assert "value" in data 292 | assert len(data["value"]) == 5 293 | t.close() 294 | 295 | def test_concurrent_feature_different_lengths(self, temp_dir): 296 | """Test loading when different features might have different packet counts.""" 297 | path = os.path.join(temp_dir, "different_lengths.vla") 298 | traj = Trajectory(path, mode="w") 299 | 300 | # Add features at different rates to same trajectory 301 | # This tests the early termination logic 302 | for i in range(10): 303 | traj.add("frequent", i, timestamp=i * 100) 304 | if i % 2 == 0: # Less frequent feature 305 | traj.add("sparse", i // 2, timestamp=i * 100) 306 | 307 | traj.close() 308 | 309 | t = Trajectory(path, mode="r") 310 | data = t.load() 311 | 312 | # Should load all available data for each feature 313 | assert len(data["frequent"]) == 10 314 | assert len(data["sparse"]) == 5 315 | 316 | # Slicing should work correctly with different lengths 317 | sliced = t.load(data_slice=slice(0, 3)) 318 | # Each feature gets sliced independently 319 | assert len(sliced["frequent"]) == 3 320 | assert len(sliced["sparse"]) <= 3 # Might be fewer due to sparsity 321 | 322 | t.close() 323 | 324 | def test_precision_edge_cases_float(self, temp_dir): 325 | """Test edge cases with floating point precision.""" 326 | path = os.path.join(temp_dir, "float_precision.vla") 327 | traj = 
Trajectory(path, mode="w") 328 | 329 | # Test various floating point edge cases 330 | float_values = [ 331 | 0.0, 332 | -0.0, 333 | 1e-10, # Very small positive 334 | -1e-10, # Very small negative 335 | 1e10, # Very large 336 | np.inf, 337 | -np.inf, 338 | # np.nan, # Skip NaN as it may cause comparison issues 339 | ] 340 | 341 | for i, val in enumerate(float_values): 342 | if not np.isnan(val): # Skip NaN values for now 343 | traj.add("float_val", float(val), timestamp=i * 100) 344 | 345 | traj.close() 346 | 347 | t = Trajectory(path, mode="r") 348 | data = t.load() 349 | 350 | assert "float_val" in data 351 | # Verify precision is maintained (for finite values) 352 | for i, expected in enumerate(float_values): 353 | if not np.isnan(expected) and np.isfinite(expected): 354 | assert abs(data["float_val"][i] - expected) < 1e-12 355 | 356 | t.close() 357 | 358 | def test_memory_efficient_loading_large_slice(self, temp_dir): 359 | """Test that large slices don't load unnecessary data into memory.""" 360 | path = os.path.join(temp_dir, "memory_test.vla") 361 | traj = Trajectory(path, mode="w") 362 | 363 | # Create reasonably sized trajectory 364 | for i in range(100): # Reduced from 1000 to make test faster 365 | traj.add("value", i, timestamp=i * 100) # 100ms intervals 366 | 367 | traj.close() 368 | 369 | t = Trajectory(path, mode="r") 370 | 371 | # Load small slice from middle - should be efficient 372 | small_slice = t.load(data_slice=slice(40, 50)) 373 | assert len(small_slice["value"]) == 10 374 | np.testing.assert_array_equal(small_slice["value"], list(range(40, 375 | 50))) 376 | 377 | # Load with high frequency + slice - should also be efficient 378 | freq_slice = t.load(desired_frequency=5.0, 379 | data_slice=slice(1, 11)) # 5Hz on 10Hz data 380 | assert len(freq_slice["value"]) == 10 381 | 382 | t.close() 383 | 384 | 385 | class TestTrajectoryLoaderErrorHandling: 386 | """Test error handling and recovery in the loader.""" 387 | 388 | def test_invalid_slice_combinations(self, temp_dir): 389 | """Test various invalid slice parameter combinations.""" 390 | path = os.path.join(temp_dir, "for_error_test.vla") 391 | traj = Trajectory(path, mode="w") 392 | 393 | for i in range(10): 394 | traj.add("value", i, timestamp=i * 100) 395 | traj.close() 396 | 397 | t = Trajectory(path, mode="r") 398 | 399 | # Test invalid step values 400 | invalid_slices = [ 401 | slice(0, 10, 0), # Zero step 402 | slice(0, 10, -1), # Negative step 403 | slice(0, 10, -5), # Large negative step 404 | ] 405 | 406 | for invalid_slice in invalid_slices: 407 | with pytest.raises(ValueError): 408 | _ = t.load(data_slice=invalid_slice) 409 | 410 | t.close() 411 | 412 | def test_invalid_frequency_values(self, temp_dir): 413 | """Test various invalid frequency values.""" 414 | path = os.path.join(temp_dir, "for_freq_error.vla") 415 | traj = Trajectory(path, mode="w") 416 | 417 | traj.add("value", 42, timestamp=0) 418 | traj.close() 419 | 420 | t = Trajectory(path, mode="r") 421 | 422 | invalid_frequencies = [ 423 | 0.0, # Zero 424 | -1.0, # Negative 425 | -100.0, # Large negative 426 | ] 427 | 428 | for invalid_freq in invalid_frequencies: 429 | with pytest.raises(ValueError): 430 | _ = t.load(desired_frequency=invalid_freq) 431 | 432 | t.close() 433 | 434 | def test_parameter_combination_edge_cases(self, temp_dir): 435 | """Test edge cases in parameter combinations.""" 436 | path = os.path.join(temp_dir, "param_combos.vla") 437 | traj = Trajectory(path, mode="w") 438 | 439 | for i in range(20): 440 | traj.add("value", i, 
timestamp=i * 100) 441 | traj.close() 442 | 443 | t = Trajectory(path, mode="r") 444 | 445 | # Valid but unusual combinations 446 | edge_cases = [ 447 | # Very high frequency with slice 448 | { 449 | "desired_frequency": 1000.0, 450 | "data_slice": slice(0, 5) 451 | }, 452 | # Very low frequency with large slice 453 | { 454 | "desired_frequency": 0.1, 455 | "data_slice": slice(0, None) 456 | }, 457 | # Frequency with slice that results in no data 458 | { 459 | "desired_frequency": 5.0, 460 | "data_slice": slice(100, 200) 461 | }, 462 | ] 463 | 464 | for params in edge_cases: 465 | # Should not raise errors, just return appropriate results 466 | result = t.load(**params) 467 | assert isinstance(result, dict) 468 | # All features should have same length 469 | if result: 470 | lengths = [len(v) for v in result.values()] 471 | assert len(set(lengths)) == 1 472 | 473 | t.close() 474 | -------------------------------------------------------------------------------- /tests/test_trajectory_loader_performance.py: -------------------------------------------------------------------------------- 1 | """ 2 | Performance and benchmarking tests for Trajectory.load functionality. 3 | """ 4 | 5 | import os 6 | import tempfile 7 | import time 8 | from typing import Dict, List 9 | 10 | import numpy as np 11 | import pytest 12 | 13 | from robodm import Trajectory 14 | 15 | 16 | @pytest.fixture 17 | def temp_dir(): 18 | with tempfile.TemporaryDirectory() as td: 19 | yield td 20 | 21 | 22 | @pytest.fixture(scope="session") 23 | def rng() -> np.random.Generator: 24 | return np.random.default_rng(seed=98765) 25 | 26 | 27 | @pytest.fixture 28 | def large_trajectory_path(temp_dir, rng) -> str: 29 | """Create a larger trajectory for performance testing.""" 30 | path = os.path.join(temp_dir, "large_traj.vla") 31 | traj = Trajectory(path, mode="w") 32 | 33 | # Create 1000 timesteps of multimodal data 34 | for i in range(1000): 35 | timestamp_ms = int(i * 50) # 20Hz data 36 | data = { 37 | "position": rng.normal(size=3).astype(np.float32), 38 | "velocity": rng.normal(size=3).astype(np.float32), 39 | "joint_angles": rng.normal(size=7).astype(np.float32), 40 | "image": (rng.random((32, 32, 3)) * 255).astype(np.uint8), 41 | "depth": rng.random((32, 32)).astype(np.float32), 42 | "status": f"status_{i % 10}", 43 | "metadata": { 44 | "step": i, 45 | "phase": "test" 46 | }, 47 | } 48 | traj.add_by_dict(data, timestamp=timestamp_ms) 49 | 50 | traj.close() 51 | return path 52 | 53 | 54 | class TestTrajectoryLoaderPerformance: 55 | """Performance tests for the trajectory loader.""" 56 | 57 | def test_full_load_performance(self, large_trajectory_path): 58 | """Benchmark full trajectory loading.""" 59 | t = Trajectory(large_trajectory_path, mode="r") 60 | 61 | start_time = time.time() 62 | data = t.load() 63 | load_time = time.time() - start_time 64 | 65 | # Verify correctness 66 | assert len(next(iter(data.values()))) == 1000 67 | assert len(data) > 0 68 | 69 | # Performance check - should load 1000 frames reasonably quickly 70 | # This is not a strict requirement, just a sanity check 71 | assert load_time < 30.0 # Should complete within 30 seconds 72 | 73 | print(f"Full load of 1000 frames took {load_time:.3f}s") 74 | t.close() 75 | 76 | def test_slice_performance_vs_full_load(self, large_trajectory_path): 77 | """Compare performance of sliced vs full loading.""" 78 | t = Trajectory(large_trajectory_path, mode="r") 79 | 80 | # Time full load 81 | start_time = time.time() 82 | full_data = t.load() 83 | full_time = time.time() - 
start_time 84 | 85 | # Time small slice 86 | start_time = time.time() 87 | slice_data = t.load(data_slice=slice(100, 200)) 88 | slice_time = time.time() - start_time 89 | 90 | # Verify correctness 91 | assert len(next(iter(slice_data.values()))) == 100 92 | for k in slice_data: 93 | np.testing.assert_array_equal(slice_data[k], full_data[k][100:200]) 94 | 95 | # Performance - slice should be faster than full load 96 | print(f"Full load: {full_time:.3f}s, Slice load: {slice_time:.3f}s") 97 | 98 | t.close() 99 | 100 | def test_seeking_performance_benefit(self, large_trajectory_path): 101 | """Test that seeking provides performance benefit for large slices.""" 102 | t = Trajectory(large_trajectory_path, mode="r") 103 | 104 | # Test slice from beginning (no seeking needed) 105 | start_time = time.time() 106 | early_slice = t.load(data_slice=slice(0, 100)) 107 | early_time = time.time() - start_time 108 | 109 | # Test slice from middle (seeking should help) 110 | start_time = time.time() 111 | middle_slice = t.load(data_slice=slice(400, 500)) 112 | middle_time = time.time() - start_time 113 | 114 | # Test slice from end (seeking should help significantly) 115 | start_time = time.time() 116 | late_slice = t.load(data_slice=slice( 117 | 800, 900)) # Changed from 900-1000 to avoid edge case 118 | late_time = time.time() - start_time 119 | 120 | # Verify correctness 121 | assert len(next(iter(early_slice.values()))) == 100 122 | assert len(next(iter(middle_slice.values()))) == 100 123 | 124 | # Late slice might have fewer frames if we're near the end of data 125 | late_len = len(next(iter(late_slice.values()))) 126 | assert late_len > 0 # Should have some data 127 | 128 | print( 129 | f"Early slice: {early_time:.3f}s, Middle slice: {middle_time:.3f}s, Late slice: {late_time:.3f}s" 130 | ) 131 | 132 | # All should complete reasonably quickly 133 | assert early_time < 10.0 134 | assert middle_time < 10.0 135 | assert late_time < 10.0 136 | 137 | t.close() 138 | 139 | def test_frequency_resampling_performance(self, large_trajectory_path): 140 | """Test performance of frequency resampling.""" 141 | t = Trajectory(large_trajectory_path, mode="r") 142 | 143 | # Test various downsampling rates 144 | frequencies = [10.0, 5.0, 2.0, 1.0] # Original is 20Hz 145 | times = [] 146 | 147 | for freq in frequencies: 148 | start_time = time.time() 149 | resampled = t.load(desired_frequency=freq) 150 | resample_time = time.time() - start_time 151 | times.append(resample_time) 152 | 153 | # Verify approximate expected length 154 | expected_len = int(1000 * freq / 20.0) # Rough calculation 155 | actual_len = len(next(iter(resampled.values()))) 156 | assert abs(actual_len - expected_len) <= 5 # Allow some tolerance 157 | 158 | print( 159 | f"Resampling to {freq}Hz: {resample_time:.3f}s, {actual_len} frames" 160 | ) 161 | 162 | # All resampling should complete quickly 163 | assert all(t < 15.0 for t in times) 164 | 165 | t.close() 166 | 167 | def test_combined_operations_performance(self, large_trajectory_path): 168 | """Test performance of combined resampling and slicing.""" 169 | t = Trajectory(large_trajectory_path, mode="r") 170 | 171 | # Test various combinations 172 | test_cases = [ 173 | { 174 | "desired_frequency": 10.0, 175 | "data_slice": slice(100, 300) 176 | }, 177 | { 178 | "desired_frequency": 5.0, 179 | "data_slice": slice(0, 500) 180 | }, 181 | { 182 | "desired_frequency": 2.0, 183 | "data_slice": slice(200, 800, 2) 184 | }, 185 | ] 186 | 187 | for i, params in enumerate(test_cases): 188 | start_time = 
time.time() 189 | result = t.load(**params) 190 | operation_time = time.time() - start_time 191 | 192 | # Verify result is reasonable 193 | assert len(result) > 0 194 | result_len = len(next(iter(result.values()))) 195 | # Allow empty results due to resampling effects, but at least verify no error 196 | assert result_len >= 0 197 | 198 | print( 199 | f"Combined operation {i+1}: {operation_time:.3f}s, {result_len} frames" 200 | ) 201 | 202 | # Should complete quickly 203 | assert operation_time < 20.0 204 | 205 | t.close() 206 | 207 | def test_repeated_load_caching_behavior(self, large_trajectory_path): 208 | """Test if repeated loads show any caching behavior or performance patterns.""" 209 | t = Trajectory(large_trajectory_path, mode="r") 210 | 211 | # Perform same load operation multiple times 212 | load_times = [] 213 | slice_params = slice(200, 400) 214 | 215 | for i in range(5): 216 | start_time = time.time() 217 | data = t.load(data_slice=slice_params) 218 | load_time = time.time() - start_time 219 | load_times.append(load_time) 220 | 221 | # Verify consistency 222 | assert len(next(iter(data.values()))) == 200 223 | 224 | print(f"Repeated load times: {[f'{t:.3f}s' for t in load_times]}") 225 | 226 | # All loads should complete within reasonable time 227 | assert all(t < 10.0 for t in load_times) 228 | 229 | # Check if there's significant variance (indicating potential caching) 230 | avg_time = sum(load_times) / len(load_times) 231 | max_deviation = max(abs(t - avg_time) for t in load_times) 232 | print(f"Average: {avg_time:.3f}s, Max deviation: {max_deviation:.3f}s") 233 | 234 | t.close() 235 | 236 | def test_memory_usage_large_slice(self, large_trajectory_path): 237 | """Test memory efficiency with large slices.""" 238 | t = Trajectory(large_trajectory_path, mode="r") 239 | 240 | # Load progressively larger slices 241 | slice_sizes = [10, 50, 100, 200, 500] 242 | 243 | for size in slice_sizes: 244 | start_time = time.time() 245 | data = t.load(data_slice=slice(0, size)) 246 | load_time = time.time() - start_time 247 | 248 | # Verify correct size 249 | assert len(next(iter(data.values()))) == size 250 | 251 | # Check that larger slices don't have dramatically worse performance 252 | print(f"Slice size {size}: {load_time:.3f}s") 253 | 254 | # Performance should scale reasonably 255 | assert load_time < size * 0.01 + 5.0 # Very loose upper bound 256 | 257 | t.close() 258 | 259 | def test_container_return_performance(self, large_trajectory_path): 260 | """Test that container return is consistently fast regardless of other parameters.""" 261 | t = Trajectory(large_trajectory_path, mode="r") 262 | 263 | # Test container return with various parameters 264 | test_cases = [ 265 | {}, # No parameters 266 | { 267 | "data_slice": slice(0, 1000) 268 | }, # Large slice 269 | { 270 | "desired_frequency": 1.0 271 | }, # Heavy resampling 272 | { 273 | "desired_frequency": 5.0, 274 | "data_slice": slice(100, 900) 275 | }, # Combined 276 | ] 277 | 278 | for i, params in enumerate(test_cases): 279 | params["return_type"] = "container" 280 | 281 | start_time = time.time() 282 | result = t.load(**params) 283 | container_time = time.time() - start_time 284 | 285 | # Verify result 286 | assert result == large_trajectory_path 287 | 288 | print(f"Container return {i+1}: {container_time:.3f}s") 289 | 290 | # Should be consistently very fast 291 | assert container_time < 0.1 # Should be nearly instantaneous 292 | 293 | t.close() 294 | 295 | 296 | class TestTrajectoryLoaderScalability: 297 | """Test 
scalability characteristics of the loader.""" 298 | 299 | def test_scaling_with_feature_count(self, temp_dir, rng): 300 | """Test how performance scales with number of features.""" 301 | feature_counts = [5, 10, 20] 302 | times = [] 303 | 304 | for feature_count in feature_counts: 305 | path = os.path.join(temp_dir, f"features_{feature_count}.vla") 306 | traj = Trajectory(path, mode="w") 307 | 308 | # Create trajectory with many features 309 | for i in range(200): # Fewer timesteps to keep test reasonable 310 | data = {} 311 | for j in range(feature_count): 312 | data[f"feature_{j}"] = rng.normal(size=3).astype( 313 | np.float32) 314 | traj.add_by_dict(data, timestamp=i * 100) 315 | 316 | traj.close() 317 | 318 | # Time the loading 319 | t = Trajectory(path, mode="r") 320 | start_time = time.time() 321 | loaded = t.load() 322 | load_time = time.time() - start_time 323 | times.append(load_time) 324 | 325 | # Verify correctness 326 | assert len(loaded) == feature_count 327 | assert len(next(iter(loaded.values()))) == 200 328 | 329 | print(f"Loading {feature_count} features: {load_time:.3f}s") 330 | t.close() 331 | 332 | # Performance should scale reasonably with feature count 333 | assert all(t < 20.0 for t in times) 334 | 335 | def test_scaling_with_data_types(self, temp_dir, rng): 336 | """Test performance with different data types and sizes.""" 337 | path = os.path.join(temp_dir, "mixed_types.vla") 338 | traj = Trajectory(path, mode="w") 339 | 340 | # Create trajectory with varied data types 341 | for i in range(300): 342 | data = { 343 | "small_int": i, 344 | "float_val": float(i * 0.1), 345 | "string_data": f"item_{i}", 346 | "small_array": rng.normal(size=3).astype(np.float32), 347 | "medium_array": rng.normal(size=(10, 10)).astype(np.float32), 348 | "large_array": (rng.random( 349 | (20, 20, 3)) * 255).astype(np.uint8), 350 | } 351 | traj.add_by_dict(data, timestamp=i * 100) 352 | 353 | traj.close() 354 | 355 | t = Trajectory(path, mode="r") 356 | 357 | # Test loading different combinations 358 | test_cases = [ 359 | slice(0, 50), # Small slice 360 | slice(0, 150), # Medium slice 361 | slice(0, 300), # Full data 362 | slice(100, 200), # Middle slice 363 | ] 364 | 365 | for i, slice_params in enumerate(test_cases): 366 | start_time = time.time() 367 | data = t.load(data_slice=slice_params) 368 | load_time = time.time() - start_time 369 | 370 | expected_len = slice_params.stop - slice_params.start 371 | if slice_params.stop > 300: 372 | expected_len = 300 - slice_params.start 373 | 374 | actual_len = len(next(iter(data.values()))) 375 | assert actual_len == expected_len 376 | 377 | print( 378 | f"Mixed types, slice {i+1}: {load_time:.3f}s, {actual_len} frames" 379 | ) 380 | 381 | # Should complete reasonably quickly 382 | assert load_time < 15.0 383 | 384 | t.close() 385 | 386 | def test_performance_regression_protection(self, large_trajectory_path): 387 | """Basic regression test to catch significant performance degradation.""" 388 | t = Trajectory(large_trajectory_path, mode="r") 389 | 390 | # Define performance expectations (these are loose bounds) 391 | performance_expectations = [ 392 | (lambda: t.load(data_slice=slice(0, 10)), 2.0, "Small slice"), 393 | (lambda: t.load(data_slice=slice(0, 100)), 5.0, "Medium slice"), 394 | (lambda: t.load(desired_frequency=5.0), 10.0, "Resampling"), 395 | (lambda: t.load(return_type="container"), 0.1, "Container return"), 396 | ] 397 | 398 | for operation, max_time, description in performance_expectations: 399 | start_time = time.time() 400 | 
result = operation() 401 | operation_time = time.time() - start_time 402 | 403 | print(f"{description}: {operation_time:.3f}s (max: {max_time}s)") 404 | 405 | # Check against regression threshold 406 | if operation_time > max_time: 407 | pytest.fail( 408 | f"Performance regression detected: {description} took " 409 | f"{operation_time:.3f}s, expected < {max_time}s") 410 | 411 | t.close() 412 | 413 | 414 | @pytest.mark.slow 415 | class TestTrajectoryLoaderStressTests: 416 | """Stress tests for the loader (marked as slow).""" 417 | 418 | def test_very_large_trajectory_handling(self, temp_dir, rng): 419 | """Test handling of very large trajectories (if resources allow).""" 420 | path = os.path.join(temp_dir, "very_large.vla") 421 | traj = Trajectory(path, mode="w") 422 | 423 | # Create larger trajectory (but not so large it breaks CI) 424 | n_steps = 5000 425 | for i in range(n_steps): 426 | if i % 1000 == 0: 427 | print(f"Creating step {i}/{n_steps}") 428 | 429 | data = { 430 | "position": rng.normal(size=3).astype(np.float32), 431 | "image": (rng.random((16, 16, 3)) * 255).astype(np.uint8), 432 | } 433 | traj.add_by_dict(data, timestamp=i * 50) 434 | 435 | traj.close() 436 | 437 | t = Trajectory(path, mode="r") 438 | 439 | # Test various operations on large trajectory 440 | start_time = time.time() 441 | small_slice = t.load(data_slice=slice(1000, 1100)) 442 | slice_time = time.time() - start_time 443 | 444 | assert len(next(iter(small_slice.values()))) == 100 445 | print(f"Large trajectory slice: {slice_time:.3f}s") 446 | 447 | # Should still be reasonably fast due to seeking 448 | assert slice_time < 30.0 449 | 450 | t.close() 451 | 452 | def test_high_frequency_resampling_stress(self, large_trajectory_path): 453 | """Test resampling with various challenging frequency combinations.""" 454 | t = Trajectory(large_trajectory_path, mode="r") 455 | 456 | # Test challenging frequency combinations 457 | test_frequencies = [ 458 | 0.1, # Very low frequency 459 | 0.5, # Low frequency 460 | 19.9, # Just under original frequency 461 | 20.0, # Approximately original frequency 462 | 20.1, # Just above original frequency 463 | ] 464 | 465 | for freq in test_frequencies: 466 | start_time = time.time() 467 | resampled = t.load(desired_frequency=freq) 468 | resample_time = time.time() - start_time 469 | 470 | result_len = len(next(iter(resampled.values()))) 471 | print( 472 | f"Frequency {freq}Hz: {resample_time:.3f}s, {result_len} frames" 473 | ) 474 | 475 | # Should complete within reasonable time 476 | assert resample_time < 20.0 477 | 478 | # Result should be reasonable 479 | assert result_len >= 0 480 | 481 | t.close() 482 | --------------------------------------------------------------------------------