├── .all-contributorsrc ├── .codeclimate.yml ├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── config.yml │ ├── doc_improvement.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md ├── actions │ └── setup-venv │ │ └── action.yml ├── release-please.yml ├── scripts │ └── update_requirements.py └── workflows │ ├── CI.yml │ ├── release.yml │ └── sync_requirements.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── docs │ └── source │ │ └── modules.rst ├── make.bat ├── requirements.txt ├── source │ ├── base_bootstrap.rst │ ├── block_bootstrap.rst │ ├── block_generator.rst │ ├── block_length_sampler.rst │ ├── block_resampler.rst │ ├── bootstrap.rst │ ├── conf.py │ ├── index.rst │ ├── markov_sampler.rst │ ├── odds_and_ends.rst │ ├── ranklags.rst │ ├── time_series_model.rst │ ├── time_series_simulator.rst │ ├── tsfit.rst │ ├── types.rst │ └── validate.rst └── sphinx_build.log ├── extension_templates └── bootstrap.py ├── pyproject.toml ├── report.xml ├── setup.sh ├── src └── tsbootstrap │ ├── __init__.py │ ├── base_bootstrap.py │ ├── base_bootstrap_configs.py │ ├── block_bootstrap.py │ ├── block_bootstrap_configs.py │ ├── block_generator.py │ ├── block_length_sampler.py │ ├── block_resampler.py │ ├── bootstrap.py │ ├── markov_sampler.py │ ├── py.typed │ ├── ranklags.py │ ├── registry │ ├── __init__.py │ ├── _lookup.py │ ├── _tags.py │ └── tests │ │ ├── __init__.py │ │ └── test_tags.py │ ├── tests │ ├── __init__.py │ ├── scenarios │ │ ├── __init__.py │ │ ├── scenarios.py │ │ ├── scenarios_bootstrap.py │ │ └── scenarios_getter.py │ ├── test_all_bootstraps.py │ ├── test_all_estimators.py │ ├── test_class_register.py │ └── test_switch.py │ ├── time_series_model.py │ ├── time_series_simulator.py │ ├── tsfit.py │ └── utils │ ├── __init__.py │ ├── dependencies.py │ ├── estimator_checks.py │ ├── odds_and_ends.py │ ├── types.py │ └── validate.py ├── tests ├── _nopytest_tests.py ├── test_base_bootstrap_configs.py ├── test_block_bootstrap.py ├── test_block_bootstrap_configs.py ├── test_block_generator.py ├── test_block_length_sampler.py ├── test_block_resampler.py ├── test_bootstrap.py ├── test_markov_sampler.py ├── test_odds_and_ends.py ├── test_ranklags.py ├── test_time_series_model.py ├── test_time_series_simulator.py ├── test_tsfit.py └── test_validate.py ├── tox.ini ├── tsbootstrap_logo.png └── uv_vs_pip.jpg /.all-contributorsrc: -------------------------------------------------------------------------------- 1 | { 2 | "projectName": "tsbootstrap", 3 | "projectOwner": "astrogilda", 4 | "repoType": "github", 5 | "repoHost": "https://github.com", 6 | "files": [ 7 | "README.md" 8 | ], 9 | "skipCi": true, 10 | "commitConvention": "angular", 11 | "commitType": "docs", 12 | "imageSize": 100, 13 | "contributorsPerLine": 7 14 | } 15 | -------------------------------------------------------------------------------- /.codeclimate.yml: -------------------------------------------------------------------------------- 1 | version: "2" # required to adjust maintainability checks 2 | checks: 3 | argument-count: 4 | config: 5 | threshold: 7 6 | complex-logic: 7 | config: 8 | threshold: 4 9 | file-lines: 10 | config: 11 | threshold: 2500 12 | method-complexity: 13 | config: 14 | threshold: 10 15 | method-count: 16 | config: 17 | threshold: 30 18 | method-lines: 19 | config: 20 | threshold: 25 21 | nested-control-flow: 22 | config: 23 | threshold: 4 24 | 
return-statements: 25 | config: 26 | threshold: 4 27 | similar-code: 28 | config: 29 | threshold: # language-specific defaults. an override will affect all languages. 30 | identical-code: 31 | config: 32 | threshold: # language-specific defaults. an override will affect all languages. 33 | 34 | 35 | plugins: 36 | bandit: 37 | enabled: true 38 | git-legal: 39 | enabled: true 40 | markdownlint: 41 | enabled: true 42 | radon: 43 | enabled: true 44 | sonar-python: 45 | enabled: true 46 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | custom: https://www.buymeacoffee.com/sankalp.gilda 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | contact_links: 2 | - name: "\U0001F4AC All other questions and general chat" 3 | url: https://discord.gg/5Em6GUrP 4 | about: Chat with the `tsbootstrap` community on Discord 5 | - name: "\u2709\uFE0F Code of Conduct incident reporting" 6 | url: https://www.sktime.net/en/latest/get_involved/code_of_conduct.html#incident-reporting-guidelines 7 | about: Report an incident to the Code of Conduct committee 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/doc_improvement.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F4D6 Documentation improvement" 3 | about: Create a report to help us improve the documentation. Alternatively you can just open a pull request with the suggested change. 
4 | title: "[DOC]" 5 | labels: documentation 6 | assignees: '' 7 | 8 | --- 9 | 10 | #### Describe the issue linked to the documentation 11 | 12 | 15 | 16 | #### Suggest a potential alternative/fix 17 | 18 | 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | # Pull Request Template 3 | 4 | ## Description 5 | 6 | Please include a clear and concise description of what the pull request does. Include any relevant issues this PR addresses. 7 | 8 | ## Type of change 9 | 10 | Please delete options that are not relevant. 11 | 12 | - [ ] Bug fix (non-breaking change which fixes an issue) 13 | - [ ] New feature (non-breaking change which adds functionality) 14 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 15 | - [ ] This change requires a documentation update 16 | 17 | ## How Has This Been Tested? 18 | 19 | Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration. 20 | 21 | - [ ] Test A 22 | - [ ] Test B 23 | 24 | ## Checklist: 25 | 26 | - [ ] My code follows the style guidelines of this project 27 | - [ ] I have performed a self-review of my own code 28 | - [ ] I have commented my code, particularly in hard-to-understand areas 29 | - [ ] I have made corresponding changes to the documentation 30 | - [ ] My changes generate no new warnings 31 | - [ ] Any dependent changes have been merged and published in downstream modules 32 | 33 | ## Additional Information (if applicable) 34 | 35 | - Any additional details you want to add related to the changes 36 | 37 | ## Add All Contributors Command 38 | 39 | Remember to acknowledge your contributions, replace `contribution_type` with your contribution (code, doc, etc.): 40 | 41 | ```plaintext 42 | @all-contributors please add @ for 43 | ``` 44 | -------------------------------------------------------------------------------- /.github/actions/setup-venv/action.yml: -------------------------------------------------------------------------------- 1 | name: Setup Python Virtual Environment 2 | 3 | description: | 4 | This composite action sets up a Python virtual environment using `uv`. It handles the installation of `uv` on different operating systems and creates the virtual environment. This action is reusable across multiple jobs to ensure consistency and reduce duplication. 
5 | 6 | inputs: 7 | python-version: 8 | description: 'Python version to set up' 9 | required: true 10 | default: '3.11' 11 | 12 | runs: 13 | using: "composite" 14 | steps: 15 | # Step 1: Install uv 16 | - name: Install uv on Windows 17 | if: runner.os == 'Windows' 18 | run: | 19 | irm https://astral.sh/uv/install.ps1 | iex 20 | shell: pwsh 21 | 22 | - name: Install uv on Linux and macOS 23 | if: runner.os != 'Windows' 24 | run: | 25 | curl -LsSf https://astral.sh/uv/install.sh | sh 26 | shell: bash 27 | 28 | # Step 2: Update PATH to include uv binaries 29 | - name: Update PATH on Windows 30 | if: runner.os == 'Windows' 31 | run: | 32 | echo "$(python -m site --user-base)/Scripts" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append 33 | shell: pwsh 34 | 35 | - name: Update PATH 36 | if: runner.os != 'Windows' 37 | run: | 38 | echo "$(python -m site --user-base)/bin" >> $GITHUB_PATH 39 | shell: bash 40 | 41 | # Step 3: Create the virtual environment 42 | - name: Create virtual environment on Windows 43 | if: runner.os == 'Windows' 44 | run: | 45 | uv venv .venv 46 | shell: pwsh 47 | 48 | - name: Create virtual environment on Linux and macOS 49 | if: runner.os != 'Windows' 50 | run: | 51 | uv venv .venv 52 | shell: bash 53 | 54 | # Step 4: Activate virtual environment and show Python path 55 | - name: Activate and Verify Virtual Environment 56 | if: runner.os == 'Windows' 57 | run: | 58 | .\.venv\Scripts\Activate.ps1 59 | where python 60 | shell: pwsh 61 | 62 | - name: Activate and Verify Virtual Environment 63 | if: runner.os != 'Windows' 64 | run: | 65 | source .venv/bin/activate 66 | which python 67 | shell: bash 68 | -------------------------------------------------------------------------------- /.github/release-please.yml: -------------------------------------------------------------------------------- 1 | # config for release-please bot 2 | primaryBranch: main 3 | releaseType: python 4 | handleGHRelease: true 5 | -------------------------------------------------------------------------------- /.github/scripts/update_requirements.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import tomlkit 4 | 5 | 6 | def update_requirements(): 7 | # Navigate up two levels to the root directory, then to 'pyproject.toml' 8 | pyproject_path = Path(__file__).parent.parent.parent / "pyproject.toml" 9 | 10 | with Path(pyproject_path).open("r") as pyproject: 11 | data = tomlkit.parse(pyproject.read()) 12 | 13 | # Get the dependencies as a list 14 | dependencies = data["project"]["dependencies"] # type: ignore 15 | 16 | docs_dependencies = data["project"]["optional-dependencies"]["docs"] # type: ignore 17 | 18 | requirements_path = ( 19 | Path(__file__).parent.parent.parent / "docs/requirements.txt" 20 | ) 21 | with Path(requirements_path).open("w") as requirements: 22 | for dep in dependencies: # type: ignore 23 | if dep != "python": 24 | # Directly write the dependency string to requirements.txt 25 | requirements.write(f"{dep}\n") 26 | for docs_dep in docs_dependencies: # type: ignore 27 | requirements.write(f"{docs_dep}\n") 28 | 29 | 30 | if __name__ == "__main__": 31 | update_requirements() 32 | -------------------------------------------------------------------------------- /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | # Minimal permissions following the principle of least privilege 4 | permissions: 5 | contents: read 6 | 7 | # Trigger CI 
on pushes to main and pull requests targeting main or release branches 8 | on: 9 | push: 10 | branches: 11 | - main 12 | paths: 13 | - 'src/**' 14 | - 'tests/**' 15 | - '.github/workflows/CI.yml' 16 | - 'pyproject.toml' 17 | - 'docs/**' 18 | pull_request: 19 | branches: 20 | - main 21 | - 'release**' 22 | paths: 23 | - 'src/**' 24 | - 'tests/**' 25 | - '.github/workflows/CI.yml' 26 | - 'pyproject.toml' 27 | - 'docs/**' 28 | 29 | jobs: 30 | # Job to test core dependencies without optional (soft) dependencies 31 | test-core-dependencies: 32 | name: Test Core Dependencies 33 | runs-on: ubuntu-latest 34 | steps: 35 | # Step 1: Checkout the repository 36 | - uses: actions/checkout@v4 37 | with: 38 | fetch-depth: 0 # Ensure full history for accurate branch information 39 | 40 | # Step 2: Set up Python 3.11 41 | - name: Set up Python 3.11 42 | uses: actions/setup-python@v5 43 | with: 44 | python-version: '3.11' 45 | 46 | # Step 3: Cache pip dependencies to speed up builds 47 | - name: Cache pip 48 | uses: actions/cache@v4 49 | with: 50 | path: ~/.cache/pip 51 | key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml') }} 52 | restore-keys: | 53 | ${{ runner.os }}-pip- 54 | 55 | # Step 4: Setup virtual environment using the composite action 56 | - name: Setup Virtual Environment 57 | uses: ./.github/actions/setup-venv 58 | with: 59 | python-version: '3.11' 60 | 61 | # Step 5: Install package without optional dependencies 62 | - name: Install Package and Core Dependencies (Linux/macOS) 63 | if: runner.os != 'Windows' 64 | run: | 65 | source .venv/bin/activate 66 | uv pip install . --no-cache-dir 67 | shell: bash 68 | 69 | - name: Install Package and Core Dependencies (Windows) 70 | if: runner.os == 'Windows' 71 | run: | 72 | .\.venv\Scripts\Activate.ps1 73 | uv pip install . 
--no-cache-dir 74 | shell: pwsh 75 | 76 | # Step 6: Display installed dependencies for verification 77 | - name: Show Dependencies (Linux/macOS) 78 | if: runner.os != 'Windows' 79 | run: | 80 | source .venv/bin/activate 81 | uv pip list 82 | shell: bash 83 | 84 | - name: Show Dependencies (Windows) 85 | if: runner.os == 'Windows' 86 | run: | 87 | .\.venv\Scripts\Activate.ps1 88 | uv pip list 89 | shell: pwsh 90 | 91 | # Step 7: Run pytest-free tests 92 | - name: Run pytest-free Tests (Linux/macOS) 93 | if: runner.os != 'Windows' 94 | run: | 95 | source .venv/bin/activate 96 | python tests/_nopytest_tests.py 97 | shell: bash 98 | 99 | - name: Run pytest-free Tests (Windows) 100 | if: runner.os == 'Windows' 101 | run: | 102 | .\.venv\Scripts\Activate.ps1 103 | python tests/_nopytest_tests.py 104 | shell: pwsh 105 | 106 | # Job to test without optional dependencies across multiple Python versions and OSes 107 | test-no-optional-dependencies: 108 | name: Test Without Optional Dependencies 109 | needs: test-core-dependencies 110 | runs-on: ${{ matrix.os }} 111 | strategy: 112 | fail-fast: false 113 | matrix: 114 | python-version: ['3.9', '3.10', '3.11', '3.12'] 115 | os: [ubuntu-latest, macos-13, windows-latest] 116 | steps: 117 | # Step 1: Checkout the repository 118 | - uses: actions/checkout@v4 119 | with: 120 | fetch-depth: 0 121 | 122 | # Step 2: Set up Python with the specified version 123 | - name: Set up Python ${{ matrix.python-version }} 124 | uses: actions/setup-python@v5 125 | with: 126 | python-version: ${{ matrix.python-version }} 127 | 128 | # Step 3: Cache pip dependencies to speed up builds 129 | - name: Cache pip 130 | uses: actions/cache@v4 131 | with: 132 | path: ~/.cache/pip 133 | key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml') }} 134 | restore-keys: | 135 | ${{ runner.os }}-pip- 136 | 137 | # Step 4: Setup virtual environment using the composite action 138 | - name: Setup Virtual Environment 139 | uses: ./.github/actions/setup-venv 140 | with: 141 | python-version: ${{ matrix.python-version }} 142 | 143 | # Step 5: Install package and dev dependencies 144 | - name: Install Package and Dev Dependencies (Linux/macOS) 145 | if: runner.os != 'Windows' 146 | run: | 147 | source .venv/bin/activate 148 | uv pip install .[dev] --no-cache-dir 149 | shell: bash 150 | 151 | - name: Install Package and Dev Dependencies (Windows) 152 | if: runner.os == 'Windows' 153 | run: | 154 | .\.venv\Scripts\Activate.ps1 155 | uv pip install .[dev] --no-cache-dir 156 | shell: pwsh 157 | 158 | # Step 6: Display installed dependencies for verification 159 | - name: Show Dependencies (Linux/macOS) 160 | if: runner.os != 'Windows' 161 | run: | 162 | source .venv/bin/activate 163 | uv pip list 164 | shell: bash 165 | 166 | - name: Show Dependencies (Windows) 167 | if: runner.os == 'Windows' 168 | run: | 169 | .\.venv\Scripts\Activate.ps1 170 | uv pip list 171 | shell: pwsh 172 | 173 | # Step 7: Show available branches for debugging 174 | - name: Show Available Branches 175 | run: git branch -a 176 | 177 | # Step 8: Run tests using pytest 178 | - name: Run Tests (Linux/macOS) 179 | if: runner.os != 'Windows' 180 | run: | 181 | source .venv/bin/activate 182 | python -m pytest src/ tests/ -vv 183 | shell: bash 184 | 185 | - name: Run Tests (Windows) 186 | if: runner.os == 'Windows' 187 | run: | 188 | .\.venv\Scripts\Activate.ps1 189 | python -m pytest src/ tests/ -vv 190 | shell: pwsh 191 | 192 | # Job to test with all optional dependencies across multiple Python versions and OSes 193 | 
test-all-optional-dependencies: 194 | name: Test With All Optional Dependencies 195 | needs: test-no-optional-dependencies 196 | runs-on: ${{ matrix.os }} 197 | strategy: 198 | fail-fast: false 199 | matrix: 200 | python-version: ['3.9', '3.10', '3.11', '3.12'] 201 | os: [ubuntu-latest, macos-13, windows-latest] 202 | steps: 203 | # Step 1: Checkout the repository 204 | - uses: actions/checkout@v4 205 | with: 206 | fetch-depth: 0 207 | 208 | # Step 2: Set up Python with the specified version 209 | - name: Set up Python ${{ matrix.python-version }} 210 | uses: actions/setup-python@v5 211 | with: 212 | python-version: ${{ matrix.python-version }} 213 | 214 | # Step 3: Cache pip dependencies to speed up builds 215 | - name: Cache pip 216 | uses: actions/cache@v4 217 | with: 218 | path: ~/.cache/pip 219 | key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml') }} 220 | restore-keys: | 221 | ${{ runner.os }}-pip- 222 | 223 | # Step 4: Setup virtual environment using the composite action 224 | - name: Setup Virtual Environment 225 | uses: ./.github/actions/setup-venv 226 | with: 227 | python-version: ${{ matrix.python-version }} 228 | 229 | # Step 5: Install package and all dependencies including optional ones 230 | - name: Install Package and All Dependencies (Linux/macOS) 231 | if: runner.os != 'Windows' 232 | run: | 233 | source .venv/bin/activate 234 | uv pip install .[all_extras,dev] --no-cache-dir 235 | shell: bash 236 | 237 | - name: Install Package and All Dependencies (Windows) 238 | if: runner.os == 'Windows' 239 | run: | 240 | .\.venv\Scripts\Activate.ps1 241 | uv pip install .[all_extras,dev] --no-cache-dir 242 | shell: pwsh 243 | 244 | # Step 6: Display installed dependencies for verification 245 | - name: Show Dependencies (Linux/macOS) 246 | if: runner.os != 'Windows' 247 | run: | 248 | source .venv/bin/activate 249 | uv pip list 250 | shell: bash 251 | 252 | - name: Show Dependencies (Windows) 253 | if: runner.os == 'Windows' 254 | run: | 255 | .\.venv\Scripts\Activate.ps1 256 | uv pip list 257 | shell: pwsh 258 | 259 | # Step 7: Show available branches for debugging 260 | - name: Show Available Branches 261 | run: git branch -a 262 | 263 | # Step 8: Run tests using pytest 264 | - name: Run Tests (Linux/macOS) 265 | if: runner.os != 'Windows' 266 | run: | 267 | source .venv/bin/activate 268 | python -m pytest src/ tests/ -vv 269 | shell: bash 270 | 271 | - name: Run Tests (Windows) 272 | if: runner.os == 'Windows' 273 | run: | 274 | .\.venv\Scripts\Activate.ps1 275 | python -m pytest src/ tests/ -vv 276 | shell: pwsh 277 | 278 | # Step 9: Upload code coverage report to GitHub artifacts 279 | - name: Upload Coverage Report 280 | uses: actions/upload-artifact@v4 281 | with: 282 | name: coverage-${{ matrix.python-version }}-${{ runner.os }} 283 | path: coverage.md 284 | 285 | # Step 10: Publish code coverage to Codecov 286 | - name: Publish Code Coverage 287 | uses: codecov/codecov-action@v4 288 | with: 289 | token: ${{ secrets.CODECOV_TOKEN }} # Ensure this secret is set in your repository 290 | fail_ci_if_error: true 291 | verbose: true 292 | 293 | # Job to build and test documentation 294 | docs: 295 | name: Test Docs Build 296 | runs-on: ubuntu-latest 297 | steps: 298 | # Step 1: Checkout the repository 299 | - uses: actions/checkout@v4 300 | with: 301 | fetch-depth: 0 302 | 303 | # Step 2: Set up Python 3.11 304 | - name: Set up Python 3.11 305 | uses: actions/setup-python@v5 306 | with: 307 | python-version: '3.11' 308 | 309 | # Step 3: Cache pip dependencies to speed up 
builds 310 | - name: Cache pip 311 | uses: actions/cache@v4 312 | with: 313 | path: ~/.cache/pip 314 | key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml') }} 315 | restore-keys: | 316 | ${{ runner.os }}-pip- 317 | 318 | # Step 4: Install uv and update PATH 319 | - name: Install uv and Update PATH 320 | run: | 321 | pip install uv 322 | echo "$(python -m site --user-base)/bin" >> $GITHUB_PATH 323 | shell: bash 324 | 325 | # Step 5: Setup virtual environment using the composite action 326 | - name: Setup Virtual Environment 327 | uses: ./.github/actions/setup-venv 328 | with: 329 | python-version: '3.11' 330 | 331 | # Step 6: Install package and documentation dependencies 332 | - name: Install Package and Dependencies (Linux/macOS) 333 | if: runner.os != 'Windows' 334 | run: | 335 | source .venv/bin/activate 336 | uv pip install .[dev,docs] --no-cache-dir 337 | shell: bash 338 | 339 | - name: Install Package and Dependencies (Windows) 340 | if: runner.os == 'Windows' 341 | run: | 342 | .\.venv\Scripts\Activate.ps1 343 | uv pip install .[dev,docs] --no-cache-dir 344 | shell: pwsh 345 | 346 | # Step 7: Build Sphinx documentation 347 | - name: Build Sphinx Documentation (Linux/macOS) 348 | if: runner.os != 'Windows' 349 | run: | 350 | source .venv/bin/activate 351 | cd docs 352 | make clean 353 | make html --debug --jobs 2 SPHINXOPTS=" -W -v" 354 | shell: bash 355 | 356 | - name: Build Sphinx Documentation (Windows) 357 | if: runner.os == 'Windows' 358 | run: | 359 | .\.venv\Scripts\Activate.ps1 360 | cd docs 361 | make clean 362 | make html --debug --jobs 2 SPHINXOPTS=" -W -v" 363 | shell: pwsh 364 | 365 | # Step 8: Upload built documentation as an artifact 366 | - name: Upload Built Docs 367 | uses: actions/upload-artifact@v4 368 | with: 369 | name: docs-results-${{ runner.os }} 370 | path: docs/build/html/ 371 | # Ensure this step runs even if previous steps fail 372 | if: success() 373 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Build wheels and publish to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | build_wheels: 9 | name: Build wheels 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - uses: actions/setup-python@v5 16 | with: 17 | python-version: '3.10' 18 | 19 | - name: Install build tools 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install build 23 | 24 | - name: Build wheel 25 | run: | 26 | rm -rf dist/ build/ wheelhouse/ # Clean up previous builds 27 | python -m build --wheel --sdist --outdir wheelhouse 28 | 29 | - name: Store wheels 30 | uses: actions/upload-artifact@v4 31 | with: 32 | name: wheels 33 | path: wheelhouse/* 34 | 35 | test_unix_wheels: 36 | needs: build_wheels 37 | name: Test wheels on ${{ matrix.os }} with ${{ matrix.python-version }} 38 | runs-on: ${{ matrix.os }} 39 | strategy: 40 | fail-fast: false # to not fail all combinations if just one fail 41 | matrix: 42 | os: [ubuntu-latest, macos-13] 43 | python-version: ['3.9', '3.10', '3.11', '3.12'] 44 | 45 | steps: 46 | - uses: actions/checkout@v4 47 | - uses: actions/setup-python@v5 48 | with: 49 | python-version: ${{ matrix.python-version }} 50 | 51 | - uses: actions/download-artifact@v4 52 | with: 53 | name: wheels 54 | path: wheelhouse 55 | 56 | - name: Display downloaded artifacts 57 | run: ls -l wheelhouse 58 | 59 | - name: Get wheel filename 60 | run: echo 
"WHEELNAME=$(ls ./wheelhouse/tsbootstrap-*none-any.whl)" >> $GITHUB_ENV 61 | 62 | - name: Install wheel and extras 63 | run: python -m pip install "${{ env.WHEELNAME }}[all_extras,dev]" 64 | 65 | - name: Run tests 66 | run: python -m pytest 67 | 68 | test_windows_wheels: 69 | needs: build_wheels 70 | name: Test wheels on ${{ matrix.os }} with ${{ matrix.python-version }} 71 | runs-on: windows-latest 72 | strategy: 73 | fail-fast: false # to not fail all combinations if just one fail 74 | matrix: 75 | os: [windows-latest] 76 | python-version: ['3.9', '3.10', '3.11', '3.12'] 77 | 78 | steps: 79 | - uses: actions/checkout@v4 80 | - uses: actions/setup-python@v5 81 | with: 82 | python-version: ${{ matrix.python-version }} 83 | 84 | - uses: actions/download-artifact@v4 85 | with: 86 | name: wheels 87 | path: wheelhouse 88 | 89 | - name: Display downloaded artifacts 90 | run: ls -l wheelhouse 91 | 92 | - name: Get wheel filename 93 | run: echo "WHEELNAME=$(ls ./wheelhouse/tsbootstrap-*none-any.whl)" >> $env:GITHUB_ENV 94 | 95 | - name: Install wheel and extras 96 | run: python -m pip install "${env:WHEELNAME}[all_extras,dev]" 97 | 98 | - name: Run tests # explicit commands as windows does not support make 99 | run: python -m pytest 100 | 101 | upload_wheels: 102 | name: Upload wheels to PyPI 103 | runs-on: ubuntu-latest 104 | needs: [build_wheels,test_unix_wheels,test_windows_wheels] 105 | 106 | steps: 107 | - uses: actions/download-artifact@v4 108 | with: 109 | name: wheels 110 | path: wheelhouse 111 | 112 | - name: Publish package to PyPI 113 | uses: pypa/gh-action-pypi-publish@release/v1 114 | with: 115 | password: ${{ secrets.PYPI_TOKEN }} 116 | packages-dir: wheelhouse/ 117 | skip-existing: true 118 | -------------------------------------------------------------------------------- /.github/workflows/sync_requirements.yml: -------------------------------------------------------------------------------- 1 | name: Synchronize Documentation Requirements 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - 'pyproject.toml' 9 | pull_request: 10 | branches: 11 | - main 12 | paths: 13 | - 'pyproject.toml' 14 | 15 | jobs: 16 | update-docs-requirements: 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - name: Check out the repository 21 | uses: actions/checkout@v4 22 | with: 23 | token: ${{ secrets.MY_GITHUB_PAT }} # Use the PAT for checkout 24 | 25 | - name: Set up Python 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: '3.x' 29 | 30 | - name: Install tomlkit for Python TOML manipulation 31 | run: pip install tomlkit 32 | 33 | - name: Update docs/requirements.txt 34 | run: | 35 | python .github/scripts/update_requirements.py 36 | 37 | - name: Create Pull Request 38 | uses: peter-evans/create-pull-request@v6 39 | with: 40 | token: ${{ secrets.MY_GITHUB_PAT }} # Use the PAT for PR creation 41 | commit-message: Update docs/requirements.txt 42 | title: '[Automated] Update documentation requirements' 43 | branch: update-docs-requirements 44 | base: main 45 | body: | 46 | This is an automated pull request to update the documentation requirements based on pyproject.toml. 
47 | labels: | 48 | automated PR 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | downloads/ 14 | eggs/ 15 | .eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | pip-wheel-metadata/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | super-linter.log 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Sphinx documentation 59 | docs/build/ 60 | docs/source/api/ 61 | docs/source/CHANGELOG.md 62 | 63 | 64 | 65 | # pyenv 66 | .python-version 67 | 68 | # pipenv 69 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 70 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 71 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 72 | # install all needed dependencies. 73 | #Pipfile.lock 74 | 75 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 76 | __pypackages__/ 77 | 78 | 79 | # Environments 80 | .env 81 | .venv* 82 | env/ 83 | ENV/ 84 | env.bak/ 85 | venv.bak/ 86 | venv*/ 87 | 88 | # mkdocs documentation 89 | /site 90 | 91 | # mypy 92 | .mypy_cache/ 93 | .dmypy.json 94 | dmypy.json 95 | 96 | # Pyre type checker 97 | .pyre/ 98 | .idea/ 99 | 100 | # Ignore vscode 101 | .vscode/** 102 | .vscode/ 103 | .devcontainer/ 104 | 105 | 106 | # MacOS files 107 | .DS_Store 108 | 109 | #ignore gitattributes 110 | !.gitattributes 111 | 112 | #ignore gitmodules 113 | !.gitmodules 114 | 115 | #ignore gitkeep 116 | !.gitkeep 117 | 118 | #ignore gitconfig 119 | !.gitconfig 120 | 121 | #ignore gitignore_global 122 | !.gitignore_global 123 | 124 | 125 | #.ruff linter 126 | .ruff_cache/ 127 | 128 | #.whl files 129 | *.whl 130 | 131 | # temporary 132 | README_template.md 133 | 134 | # scratch file 135 | scratch* 136 | 137 | # poetry.lock 138 | poetry.lock 139 | 140 | # we don't need bumpversion anymore 141 | .bumpversion.cfg 142 | .github/workflows/bumpversion.yml 143 | 144 | # image files, except for tsbootstrap_logo.png and uv_vs_pip.jpg 145 | *.png 146 | *.jpg 147 | *.jpeg 148 | *.dot 149 | !tsbootstrap_logo.png 150 | !uv_vs_pip.jpg 151 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | default_install_hook_types: [pre-commit, commit-msg] 3 | 4 | repos: 5 | 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: v4.6.0 8 | hooks: 9 | #- id: check-added-large-files 10 | - id: check-ast 11 | - id: check-builtin-literals 12 | - id: check-case-conflict 13 | - id: check-docstring-first 14 | - id: 
check-shebang-scripts-are-executable 15 | - id: check-merge-conflict 16 | - id: check-json 17 | - id: check-toml 18 | - id: check-xml 19 | - id: check-yaml 20 | - id: debug-statements 21 | - id: destroyed-symlinks 22 | - id: detect-private-key 23 | - id: end-of-file-fixer 24 | exclude: '^LICENSE|.*\.(html|csv|txt|svg|py)$|^poetry\.lock$|\.pyc$|\.pyo$|\.pyd$|__pycache__|^venv|^\.venv' 25 | - id: pretty-format-json 26 | args: ["--autofix", "--no-ensure-ascii", "--no-sort-keys"] 27 | exclude: '\.(html|svg)$|^poetry\.lock$|\.pyc$|\.pyo$|\.pyd$|__pycache__|^venv|^\.venv' 28 | - id: trailing-whitespace 29 | args: [--markdown-linebreak-ext=md] 30 | exclude: '\.(html|svg)$|^poetry\.lock$|\.pyc$|\.pyo$|\.pyd$|__pycache__|^venv|^\.venv' 31 | 32 | #- repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook 33 | #rev: v9.1.0 34 | #hooks: 35 | #- id: commitlint 36 | #stages: [commit-msg] 37 | #additional_dependencies: ["@commitlint/config-conventional", "@commitlint/ensure", "commitlint-plugin-function-rules"] 38 | #exclude: '^poetry\.lock$|\.pyc$|\.pyo$|\.pyd$|__pycache__|^venv|^\.venv' 39 | 40 | - repo: https://github.com/astral-sh/ruff-pre-commit 41 | rev: v0.5.0 42 | hooks: 43 | - id: ruff 44 | args: [--fix, --exit-non-zero-on-fix] 45 | exclude: '^poetry\.lock$|\.pyc$|\.pyo$|\.pyd$|__pycache__|^venv|^\.venv' 46 | 47 | - repo: https://github.com/crate-ci/typos 48 | rev: v1.22.9 49 | hooks: 50 | - id: typos 51 | args: [--write-changes] 52 | exclude: '^LICENSE|.*\.(html|csv|txt|svg|py)$|^poetry\.lock$|\.pyc$|\.pyo$|\.pyd$|__pycache__|^venv|^\.venv|^README' 53 | 54 | - repo: https://github.com/myint/autoflake 55 | rev: v2.3.1 56 | hooks: 57 | - id: autoflake 58 | args: ["--remove-all-unused-imports", "--remove-unused-variables", "--in-place", "--recursive"] 59 | exclude: '^poetry\.lock$|\.pyc$|\.pyo$|\.pyd$|__pycache__|^venv|^\.venv' 60 | 61 | - repo: https://github.com/psf/black 62 | rev: 24.4.2 63 | hooks: 64 | - id: black 65 | args: 66 | - "--target-version=py39" 67 | - "--target-version=py310" 68 | - "--target-version=py311" 69 | - "--target-version=py312" 70 | - "--line-length=79" 71 | types: [python] 72 | exclude: '^poetry\.lock$|\.pyc$|\.pyo$|\.pyd$|__pycache__|^venv|^\.venv' 73 | 74 | - repo: https://github.com/adamchainz/blacken-docs 75 | rev: 1.18.0 76 | hooks: 77 | - id: blacken-docs 78 | additional_dependencies: 79 | - "black==24.4.2" 80 | args: 81 | - "--line-length=79" 82 | exclude: '^poetry\.lock$|\.pyc$|\.pyo$|\.pyd$|__pycache__|^venv|^\.venv' 83 | 84 | - repo: https://github.com/econchick/interrogate 85 | rev: 1.7.0 86 | hooks: 87 | - id: interrogate 88 | args: [src/tsbootstrap, -v, -i, --fail-under=80, "-c", "pyproject.toml"] 89 | pass_filenames: false 90 | exclude: '^poetry\.lock$|\.pyc$|\.pyo$|\.pyd$|__pycache__|^venv|^\.venv' 91 | 92 | #- repo: local 93 | # hooks: 94 | # - id: bumpversion 95 | # name: bumpversion 96 | # entry: poetry run bumpversion patch 97 | # language: system 98 | # stages: [push] 99 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version, and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.10" 13 | 14 | # Build documentation in the "docs/source" directory with Sphinx 15 | 
sphinx: 16 | configuration: docs/source/conf.py 17 | 18 | # Optionally build your docs in additional formats such as PDF and ePub 19 | formats: 20 | - pdf 21 | - epub 22 | 23 | # Declare the Python requirements required to build your documentation 24 | # and install the package itself 25 | python: 26 | install: 27 | - requirements: docs/requirements.txt 28 | - method: pip 29 | path: . 30 | extra_requirements: 31 | - docs 32 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Gilda" 5 | given-names: "Sankalp" 6 | orcid: "https://orcid.org/0000-0002-3645-4501" 7 | title: "tsbootstrap" 8 | version: 0.1.5 9 | doi: 10.5281/zenodo.8226495 10 | date-released: 2024-04-23 11 | url: "https://github.com/astrogilda/tsbootstrap" 12 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct - tsbootstrap 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to a positive environment for our 15 | community include: 16 | 17 | * Demonstrating empathy and kindness toward other people 18 | * Being respectful of differing opinions, viewpoints, and experiences 19 | * Giving and gracefully accepting constructive feedback 20 | * Accepting responsibility and apologizing to those affected by our mistakes, 21 | and learning from the experience 22 | * Focusing on what is best not just for us as individuals, but for the 23 | overall community 24 | 25 | Examples of unacceptable behavior include: 26 | 27 | * The use of sexualized language or imagery, and sexual attention or 28 | advances 29 | * Trolling, insulting or derogatory comments, and personal or political attacks 30 | * Public or private harassment 31 | * Publishing others' private information, such as a physical or email 32 | address, without their explicit permission 33 | * Other conduct which could reasonably be considered inappropriate in a 34 | professional setting 35 | 36 | ## Our Responsibilities 37 | 38 | Project maintainers are responsible for clarifying and enforcing our standards of 39 | acceptable behavior and will take appropriate and fair corrective action in 40 | response to any behavior that they deem inappropriate, 41 | threatening, offensive, or harmful. 42 | 43 | Project maintainers have the right and responsibility to remove, edit, or reject 44 | comments, commits, code, wiki edits, issues, and other contributions that are 45 | not aligned to this Code of Conduct, and will 46 | communicate reasons for moderation decisions when appropriate. 47 | 48 | ## Scope 49 | 50 | This Code of Conduct applies within all community spaces, and also applies when 51 | an individual is officially representing the community in public spaces. 
52 | Examples of representing our community include using an official e-mail address, 53 | posting via an official social media account, or acting as an appointed 54 | representative at an online or offline event. 55 | 56 | ## Enforcement 57 | 58 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 59 | reported to the community leaders responsible for enforcement at . 60 | All complaints will be reviewed and investigated promptly and fairly. 61 | 62 | All community leaders are obligated to respect the privacy and security of the 63 | reporter of any incident. 64 | 65 | ## Enforcement Guidelines 66 | 67 | Community leaders will follow these Community Impact Guidelines in determining 68 | the consequences for any action they deem in violation of this Code of Conduct: 69 | 70 | ### 1. Correction 71 | 72 | **Community Impact**: Use of inappropriate language or other behavior deemed 73 | unprofessional or unwelcome in the community. 74 | 75 | **Consequence**: A private, written warning from community leaders, providing 76 | clarity around the nature of the violation and an explanation of why the 77 | behavior was inappropriate. A public apology may be requested. 78 | 79 | ### 2. Warning 80 | 81 | **Community Impact**: A violation through a single incident or series 82 | of actions. 83 | 84 | **Consequence**: A warning with consequences for continued behavior. No 85 | interaction with the people involved, including unsolicited interaction with 86 | those enforcing the Code of Conduct, for a specified period of time. This 87 | includes avoiding interactions in community spaces as well as external channels 88 | like social media. Violating these terms may lead to a temporary or 89 | permanent ban. 90 | 91 | ### 3. Temporary Ban 92 | 93 | **Community Impact**: A serious violation of community standards, including 94 | sustained inappropriate behavior. 95 | 96 | **Consequence**: A temporary ban from any sort of interaction or public 97 | communication with the community for a specified period of time. No public or 98 | private interaction with the people involved, including unsolicited interaction 99 | with those enforcing the Code of Conduct, is allowed during this period. 100 | Violating these terms may lead to a permanent ban. 101 | 102 | ### 4. Permanent Ban 103 | 104 | **Community Impact**: Demonstrating a pattern of violation of community 105 | standards, including sustained inappropriate behavior, harassment of an 106 | individual, or aggression toward or disparagement of classes of individuals. 107 | 108 | **Consequence**: A permanent ban from any sort of public interaction within 109 | the community. 110 | 111 | ## Attribution 112 | 113 | This Code of Conduct is adapted from the [Contributor Covenant](https://contributor-covenant.org/), version 114 | [1.4](https://www.contributor-covenant.org/version/1/4/code-of-conduct/code_of_conduct.md) and 115 | [2.0](https://www.contributor-covenant.org/version/2/0/code_of_conduct/code_of_conduct.md), 116 | and was generated by [contributing-gen](https://github.com/bttger/contributing-gen). 117 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to tsbootstrap 2 | 3 | Welcome to tsbootstrap, and thank you for considering contributing to our project! 
With over 1 million users, tsbootstrap is a community-driven effort that thrives on the diverse contributions from developers around the world. Whether you're fixing a bug, adding a new feature, improving documentation, or just suggesting an idea, your contribution is invaluable. 4 | 5 | ## Table of Contents 6 | 7 | 1. [Code of Conduct](#code-of-conduct) 8 | 2. [Getting Started](#getting-started) 9 | - [Environment Setup](#environment-setup) 10 | - [Finding Your First Issue](#finding-your-first-issue) 11 | 3. [Issue Creation Guidelines](#issue-creation-guidelines) 12 | - [Reporting Bugs](#reporting-bugs) 13 | - [Suggesting Enhancements](#suggesting-enhancements) 14 | - [Asking Questions](#asking-questions) 15 | 4. [Making Contributions](#making-contributions) 16 | - [Your First Code Contribution](#your-first-code-contribution) 17 | - [Pull Request Process](#pull-request-process) 18 | 5. [Improving Documentation](#improving-documentation) 19 | 6. [Style Guides](#style-guides) 20 | - [Code Style](#code-style) 21 | - [Commit Messages](#commit-messages) 22 | - [Documentation Style](#documentation-style) 23 | 7. [Community and Communication](#community-and-communication) 24 | 8. [Joining The Project Team](#joining-the-project-team) 25 | 9. [Attribution](#attribution) 26 | 27 | ## Code of Conduct 28 | 29 | Before contributing, please read our [Code of Conduct](https://github.com/astrogilda/tsbootstrap/blob/main/CODE_OF_CONDUCT.md). We are committed to providing a welcoming and inclusive environment. All contributors are expected to adhere to this code. 30 | 31 | ## Getting Started 32 | 33 | ### Environment Setup 34 | 35 | To contribute to tsbootstrap, you need to set up your development environment. Detailed instructions are available in our [Setup Guide](https://github.com/astrogilda/tsbootstrap/wiki/Setup-Guide), covering everything from cloning the repository to installing dependencies. 36 | 37 | ### Finding Your First Issue 38 | 39 | Looking for a place to start? Check out issues labeled `good first issue` or `help wanted`. These are great for first-timers. 40 | 41 | ## Issue Creation Guidelines 42 | 43 | ### Reporting Bugs 44 | 45 | Before reporting a bug, ensure it hasn't been reported already. If you find a new bug, create an issue providing: 46 | 47 | - A clear title and description. 48 | - Steps to reproduce. 49 | - Expected behavior. 50 | - Actual behavior. 51 | - Screenshots or code snippets, if applicable. 52 | 53 | ### Suggesting Enhancements 54 | 55 | We love new ideas! Before suggesting an enhancement, please check if it's already been suggested. When creating an enhancement suggestion, include: 56 | 57 | - A clear title and detailed description. 58 | - Why this enhancement would be beneficial. 59 | - Any potential implementation details or challenges. 60 | 61 | ### Asking Questions 62 | 63 | Got a question? First, check our FAQ and past issues. If you don't find an answer, open an issue with your question. Please provide as much context as possible to help us understand and address your question quickly. 64 | 65 | ## Making Contributions 66 | 67 | ### Your First Code Contribution 68 | 69 | Unsure where to begin? Our [Contributor's Guide](https://github.com/astrogilda/tsbootstrap/wiki/Contributor's-Guide) provides step-by-step instructions on how to make your first contribution. 70 | 71 | ### Pull Request Process 72 | 73 | 1. Fork the repository and create your branch from `main`. 74 | 2. If you've added code, add tests. 75 | 3. Ensure the test suite passes. 76 | 4. 
Update the documentation if necessary. 77 | 5. Submit a pull request. 78 | 79 | ## Improving Documentation 80 | 81 | Good documentation is crucial. To contribute: 82 | 83 | - Update, improve, or correct documentation. 84 | - Submit pull requests with your changes. 85 | - Follow our [Documentation Style Guide](https://github.com/astrogilda/tsbootstrap/wiki/Documentation-Style-Guide). 86 | 87 | ## Style Guides 88 | 89 | ### Code Style 90 | 91 | We use [Ruff](https://ruff.io) to ensure code consistency. This is run automatically in the CI when pushing code. 92 | 93 | ### Commit Messages 94 | 95 | Follow [Conventional Commits](https://www.conventionalcommits.org/) for clear, structured commit messages. 96 | 97 | ### Documentation Style 98 | 99 | Documentation should be clear, concise, and written in simple English. Use markdown for formatting. 100 | 101 | ## Community and Communication 102 | 103 | Join our [Slack](https://tsbootstrap.slack.com), [Discord](https://discord.gg/tsbootstrap), or [GitHub Discussions](https://github.com/astrogilda/tsbootstrap/discussions) to connect with other contributors and the core team. 104 | 105 | ## Joining The Project Team 106 | 107 | Interested in joining the core team? Email us at with your contributions and why you're interested in joining. 108 | 109 | ## Attribution 110 | 111 | This CONTRIBUTING guide is inspired by the open-source community and aims to make contributing to tsbootstrap as clear and beneficial as possible for everyone involved. 112 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sankalp Gilda 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | docs 2 | ==== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy<1.27,>=1.21 2 | scikit-base>=0.10.0,<0.11 3 | scikit-learn>=1.5.1,<1.6.0 4 | scipy>=1.13,<1.14.0 5 | packaging>=24.0,<24.2 6 | pydantic>=2.0,<3.0 7 | furo 8 | jupyter 9 | myst-parser 10 | nbsphinx>=0.8.6 11 | numpydoc 12 | pydata-sphinx-theme 13 | Sphinx!=7.2.0,<8.0.0 14 | sphinx-rtd-theme>=1.3.0 15 | sphinx-copybutton>=0.5.2 16 | sphinx-design<0.6.0 17 | sphinx-gallery<0.15.0 18 | sphinx-issues<4.0.0 19 | sphinx-version-warning 20 | tabulate>=0.9.0 21 | -------------------------------------------------------------------------------- /docs/source/base_bootstrap.rst: -------------------------------------------------------------------------------- 1 | Base Bootstrap 2 | ============== 3 | 4 | .. automodule:: tsbootstrap.base_bootstrap 5 | :members: 6 | :noindex: 7 | -------------------------------------------------------------------------------- /docs/source/block_bootstrap.rst: -------------------------------------------------------------------------------- 1 | Block Bootstrap 2 | =============== 3 | 4 | .. automodule:: tsbootstrap.block_bootstrap 5 | :members: 6 | :noindex: 7 | -------------------------------------------------------------------------------- /docs/source/block_generator.rst: -------------------------------------------------------------------------------- 1 | Block Generator 2 | =============== 3 | 4 | .. 
automodule:: tsbootstrap.block_generator 5 | :members: 6 | :noindex: 7 | -------------------------------------------------------------------------------- /docs/source/block_length_sampler.rst: -------------------------------------------------------------------------------- 1 | Block Length Sampler 2 | ==================== 3 | 4 | .. automodule:: tsbootstrap.block_length_sampler 5 | :members: 6 | :noindex: 7 | -------------------------------------------------------------------------------- /docs/source/block_resampler.rst: -------------------------------------------------------------------------------- 1 | Block Resampler 2 | =============== 3 | 4 | .. automodule:: tsbootstrap.block_resampler 5 | :members: 6 | :noindex: 7 | -------------------------------------------------------------------------------- /docs/source/bootstrap.rst: -------------------------------------------------------------------------------- 1 | Bootstrap 2 | ========= 3 | 4 | .. automodule:: tsbootstrap.bootstrap 5 | :members: 6 | :noindex: 7 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | # sys.path.insert(0, str(Path("../").resolve())) 4 | 5 | # Configuration file for the Sphinx documentation builder. 6 | # 7 | # For the full list of built-in configuration values, see the documentation: 8 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 9 | 10 | # -- Project information ----------------------------------------------------- 11 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 12 | 13 | project = "tsbootstrap" 14 | current_year = datetime.now().year 15 | copyright = f"2023 - {current_year} (MIT License), Sankalp Gilda" 16 | author = "Sankalp Gilda" 17 | release = "0.1.5" 18 | 19 | # -- General configuration --------------------------------------------------- 20 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 21 | 22 | extensions = [ 23 | "sphinx.ext.autodoc", 24 | "sphinx.ext.napoleon", 25 | "sphinx.ext.viewcode", 26 | "sphinx.ext.intersphinx", 27 | ] 28 | 29 | templates_path = ["_templates"] 30 | exclude_patterns = [] 31 | suppress_warnings = ["ref.undefined", "ref.footnote"] 32 | 33 | # -- Options for intersphinx extension --------------------------------------- 34 | # https://www.sphinx-doc.org/en/master/usage/extensions/intersphinx.html#module-sphinx.ext.intersphinx 35 | intersphinx_mapping = { 36 | "sklearn": ("https://scikit-learn.org/stable/", None), 37 | "numpy": ("https://numpy.org/doc/stable/", None), 38 | "pandas": ("https://pandas.pydata.org/docs/", None), 39 | "statsmodels": ("https://www.statsmodels.org/stable/", None), 40 | "arch": ("https://arch.readthedocs.io/en/latest/", None), 41 | } 42 | 43 | # -- Options for HTML output ------------------------------------------------- 44 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 45 | 46 | 47 | html_theme = "sphinx_rtd_theme" 48 | html_theme_options = { 49 | "collapse_navigation": False, 50 | "navigation_depth": 3, 51 | "navigation_with_keys": False, 52 | } 53 | 54 | # html_theme = "furo" 55 | html_static_path = [] 56 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. 
tsbootstrap documentation master file, created by 2 | sphinx-quickstart on Mon Aug 7 16:06:45 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to tsbootstrap's documentation! 7 | ======================================= 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | base_bootstrap 14 | block_bootstrap 15 | block_generator 16 | block_length_sampler 17 | block_resampler 18 | bootstrap 19 | markov_sampler 20 | time_series_model 21 | time_series_simulator 22 | tsfit 23 | odds_and_ends 24 | types 25 | validate 26 | ranklags 27 | 28 | 29 | Indices and tables 30 | ================== 31 | 32 | * :ref:`genindex` 33 | * :ref:`modindex` 34 | * :ref:`search` 35 | -------------------------------------------------------------------------------- /docs/source/markov_sampler.rst: -------------------------------------------------------------------------------- 1 | Markov Sampler 2 | ============== 3 | 4 | .. automodule:: tsbootstrap.markov_sampler 5 | :members: 6 | :noindex: 7 | -------------------------------------------------------------------------------- /docs/source/odds_and_ends.rst: -------------------------------------------------------------------------------- 1 | Odds and Ends 2 | ============= 3 | 4 | .. automodule:: tsbootstrap.utils.odds_and_ends 5 | :members: 6 | :noindex: 7 | -------------------------------------------------------------------------------- /docs/source/ranklags.rst: -------------------------------------------------------------------------------- 1 | RankLags 2 | ======================= 3 | 4 | .. automodule:: tsbootstrap.ranklags 5 | :members: 6 | :noindex: 7 | -------------------------------------------------------------------------------- /docs/source/time_series_model.rst: -------------------------------------------------------------------------------- 1 | Time Series Model 2 | ================= 3 | 4 | .. automodule:: tsbootstrap.time_series_model 5 | :members: 6 | :noindex: 7 | -------------------------------------------------------------------------------- /docs/source/time_series_simulator.rst: -------------------------------------------------------------------------------- 1 | Time Series Simulator 2 | ===================== 3 | 4 | .. automodule:: tsbootstrap.time_series_simulator 5 | :members: 6 | :noindex: 7 | -------------------------------------------------------------------------------- /docs/source/tsfit.rst: -------------------------------------------------------------------------------- 1 | TSFit 2 | ===== 3 | 4 | .. automodule:: tsbootstrap.tsfit 5 | :members: 6 | :noindex: 7 | -------------------------------------------------------------------------------- /docs/source/types.rst: -------------------------------------------------------------------------------- 1 | Types 2 | ===== 3 | 4 | .. automodule:: tsbootstrap.utils.types 5 | :members: 6 | :noindex: 7 | -------------------------------------------------------------------------------- /docs/source/validate.rst: -------------------------------------------------------------------------------- 1 | Validate 2 | ======== 3 | 4 | .. automodule:: tsbootstrap.utils.validate 5 | :members: 6 | :noindex: 7 | -------------------------------------------------------------------------------- /docs/sphinx_build.log: -------------------------------------------------------------------------------- 1 | Running Sphinx v7.2.6 2 | loading intersphinx inventory from https://scikit-learn.org/stable/objects.inv... 
3 | loading intersphinx inventory from https://numpy.org/doc/stable/objects.inv... 4 | loading intersphinx inventory from https://pandas.pydata.org/docs/objects.inv... 5 | loading intersphinx inventory from https://www.statsmodels.org/stable/objects.inv... 6 | loading intersphinx inventory from https://arch.readthedocs.io/en/latest/objects.inv... 7 | building [mo]: targets for 0 po files that are out of date 8 | writing output... 9 | building [html]: targets for 17 source files that are out of date 10 | updating environment: [new config] 17 added, 0 changed, 0 removed 11 | reading sources... [ 6%] base_bootstrap reading sources... [ 12%] base_bootstrap_configs reading sources... [ 18%] block_bootstrap reading sources... [ 24%] block_bootstrap_configs reading sources... [ 29%] block_generator reading sources... [ 35%] block_length_sampler reading sources... [ 41%] block_resampler reading sources... [ 47%] bootstrap reading sources... [ 53%] index reading sources... [ 59%] markov_sampler reading sources... [ 65%] odds_and_ends reading sources... [ 71%] ranklags reading sources... [ 76%] time_series_model reading sources... [ 82%] time_series_simulator reading sources... [ 88%] tsfit reading sources... [ 94%] types reading sources... [100%] validate 12 | /home/sgilda/Documents/tsbootstrap/src/tsbootstrap/block_bootstrap_configs.py:docstring of tsbootstrap.block_bootstrap_configs.TukeyBootstrapConfig.tukey_alpha:22: ERROR: Unknown interpreted text role "doi". 13 | looking for now-outdated files... none found 14 | pickling environment... done 15 | checking consistency... done 16 | preparing documents... done 17 | copying assets... copying static files... done 18 | copying extra files... done 19 | done 20 | writing output... [ 6%] base_bootstrap writing output... [ 12%] base_bootstrap_configs writing output... [ 18%] block_bootstrap writing output... [ 24%] block_bootstrap_configs writing output... [ 29%] block_generator writing output... [ 35%] block_length_sampler writing output... [ 41%] block_resampler writing output... [ 47%] bootstrap writing output... [ 53%] index writing output... [ 59%] markov_sampler writing output... [ 65%] odds_and_ends writing output... [ 71%] ranklags writing output... [ 76%] time_series_model writing output... [ 82%] time_series_simulator writing output... [ 88%] tsfit writing output... [ 94%] types writing output... [100%] validate 21 | generating indices... genindex done 22 | highlighting module code... [ 7%] tsbootstrap.base_bootstrap highlighting module code... [ 13%] tsbootstrap.base_bootstrap_configs highlighting module code... [ 20%] tsbootstrap.block_bootstrap highlighting module code... [ 27%] tsbootstrap.block_bootstrap_configs highlighting module code... [ 33%] tsbootstrap.block_generator highlighting module code... [ 40%] tsbootstrap.block_length_sampler highlighting module code... [ 47%] tsbootstrap.block_resampler highlighting module code... [ 53%] tsbootstrap.bootstrap highlighting module code... [ 60%] tsbootstrap.markov_sampler highlighting module code... [ 67%] tsbootstrap.ranklags highlighting module code... [ 73%] tsbootstrap.time_series_model highlighting module code... [ 80%] tsbootstrap.time_series_simulator highlighting module code... [ 87%] tsbootstrap.tsfit highlighting module code... [ 93%] tsbootstrap.utils.odds_and_ends highlighting module code... [100%] tsbootstrap.utils.validate 23 | writing additional pages... search done 24 | dumping search index in English (code: en)... done 25 | dumping object inventory... 
done 26 | build succeeded, 1 warning. 27 | 28 | The HTML pages are in build/html. 29 | -------------------------------------------------------------------------------- /extension_templates/bootstrap.py: -------------------------------------------------------------------------------- 1 | """Extension template for time series bootstrap algorithms. 2 | 3 | Purpose of this implementation template: 4 | quick implementation of new estimators following the template 5 | NOT a concrete class to import! This is NOT a base class or concrete class! 6 | This is to be used as a "fill-in" coding template. 7 | 8 | How to use this implementation template to implement a new estimator: 9 | - make a copy of the template in a suitable location, give it a descriptive name. 10 | - work through all the "todo" comments below 11 | - fill in code for mandatory methods, and optionally for optional methods 12 | - do not write to reserved attributes: _tags, _tags_dynamic 13 | - you can add more private methods, but do not override BaseObject's private methods 14 | an easy way to be safe is to prefix your methods with "_custom" 15 | - change docstrings for functions and the file 16 | - ensure interface compatibility by using check_estimator from tsbootstrap.utils 17 | - once complete: use as a local library, or contribute to tsbootstrap via PR 18 | 19 | Implementation points: 20 | bootstrapping - _bootstrap(self, X, return_indices, y) 21 | 22 | Testing - required for skbase test framework and check_estimator usage: 23 | get default parameters for test instance(s) - get_test_params() 24 | 25 | copyright: tsbootstrap developers, MIT License (see LICENSE file) 26 | """ 27 | 28 | # todo: write an informative docstring for the file or module, remove the above 29 | # todo: add an appropriate copyright notice for your estimator 30 | # estimators contributed to tsbootstrap should have the copyright notice at the top 31 | # estimators of your own do not need to have permissive copyright 32 | 33 | # todo: uncomment the following line, enter authors' GitHub IDs 34 | # __author__ = [authorGitHubID, anotherAuthorGitHubID] 35 | 36 | 37 | from tsbootstrap.base_bootstrap import BaseTimeSeriesBootstrap 38 | 39 | # todo: add any necessary imports here - only core dependencies 40 | 41 | # todo: for soft dependencies: 42 | # - make sure to fill in the "python_dependencies" tag with the package import name 43 | # - import only in class methods, not at the top of the file 44 | 45 | 46 | class MyBootstrap(BaseTimeSeriesBootstrap): 47 | """Custom time series bootstrap. todo: write docstring. 48 | 49 | todo: describe your custom time series bootstrap here 50 | 51 | Parameters 52 | ---------- 53 | parama : int 54 | descriptive explanation of parama 55 | paramb : string, optional (default='default') 56 | descriptive explanation of paramb 57 | paramc : boolean, optional (default= whether paramb is not the default) 58 | descriptive explanation of paramc 59 | and so on 60 | """ 61 | 62 | # optional todo: override base class estimator default tags here if necessary 63 | # these are the default values, only add if different to these.
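    # Illustration only (a hedged sketch, not required by the template): a concrete
    # block bootstrap that supports multivariate input could declare just the tags
    # that differ from the defaults, e.g.
    #     _tags = {"bootstrap_type": ["block"], "capability:multivariate": True}
    # The fully annotated dictionary below lists the available tags and their meaning.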
64 | _tags = { 65 | # packaging info 66 | # -------------- 67 | "authors": ["author1", "author2"], # authors, GitHub handles 68 | "maintainers": [ 69 | "maintainer1", 70 | "maintainer2", 71 | ], # maintainers, GitHub handles 72 | # author = significant contribution to code at some point 73 | # maintainer = algorithm maintainer role, "owner" 74 | # specify one or multiple authors and maintainers, only for contribution 75 | # remove maintainer tag if maintained by tsbootstrap core team 76 | # 77 | "python_version": None, # PEP 440 python version specifier to limit versions 78 | # e.g., ">=3.10", or None if no version limitations 79 | "python_dependencies": None, # PEP 440 python dependencies specifier, 80 | # e.g., "numba>0.53", or a list, e.g., ["numba>0.53", "numpy>=1.19.0"] 81 | # delete if no python dependencies or version limitations 82 | "python_dependencies_aliases": None, 83 | # if import name differs from package name, specify as dict, 84 | # e.g., {"scikit-learn": "sklearn"} 85 | # 86 | # estimator tags 87 | # -------------- 88 | # capability:multivariate = can bootstrap handle multivariate time series? 89 | "capability:multivariate": False, 90 | # valid values: boolean True (yes), False (no) 91 | # if False, raises exception if multivariate data is passed 92 | } 93 | 94 | # todo: add any hyper-parameters and components to constructor 95 | def __init__( 96 | self, 97 | n_bootstraps=10, # every bootstrap must have this as first param 98 | est=None, 99 | parama="foo", 100 | paramb="default", 101 | paramc=None, 102 | ): 103 | # n_bootstraps should be the first parameter, default of 10 104 | # after that, BaseObject descendants should precede other parameters 105 | # all parameters must have default values 106 | 107 | # todo: write any hyper-parameters and components to self 108 | self.n_bootstraps = n_bootstraps 109 | self.est = est 110 | self.parama = parama 111 | self.paramb = paramb 112 | self.paramc = paramc 113 | # IMPORTANT: the self.params should never be overwritten or mutated from now on 114 | # for handling defaults etc, write to other attributes, e.g., self._parama 115 | # for estimators, initialize a clone, e.g., self.est_ = est.clone() 116 | 117 | # leave this as is 118 | super().__init__() 119 | 120 | # todo: optional, parameter checking logic (if applicable) should happen here 121 | # if writes derived values to self, should *not* overwrite self.parama etc 122 | # instead, write to self._parama, self._newparam (starting with _) 123 | 124 | # todo: default estimators should have None arg defaults 125 | # and be initialized here 126 | # do this only with default estimators, not with parameters 127 | # if est is None: 128 | # self.estimator = MyDefaultEstimator() 129 | 130 | # todo: if tags of estimator depend on component tags, set these here 131 | # only needed if estimator is a composite 132 | # tags set in the constructor apply to the object and override the class 133 | # 134 | # example 1: conditional setting of a tag 135 | # if est.foo == 42: 136 | # self.set_tags(handles-missing-data=True) 137 | # example 2: cloning tags from component 138 | # self.clone_tags(est, ["capability:multivariate"]) 139 | 140 | # todo: implement this, mandatory 141 | def _bootstrap(self, X, return_indices=False, y=None): 142 | """Generate indices to split data into training and test set. 143 | 144 | Parameters 145 | ---------- 146 | X : 2D array-like of shape (n_timepoints, n_features) 147 | The endogenous time series to bootstrap. 
148 | Dimension 0 is assumed to be the time dimension, ordered 149 | return_indices : bool, default=False 150 | If True, a second output is returned, integer locations of 151 | index references for the bootstrap sample, in reference to original indices. 152 | Indexed values are not necessarily identical with bootstrapped values. 153 | y : array-like of shape (n_timepoints, n_features_exog), default=None 154 | Exogenous time series to use in bootstrapping. 155 | 156 | Yields 157 | ------ 158 | X_boot_i : 2D np.ndarray-like of shape (n_timepoints_boot_i, n_features) 159 | i-th bootstrapped sample of X. 160 | indices_i : 1D np.ndarray of shape (n_timepoints_boot_i,) integer values, 161 | only returned if return_indices=True. 162 | Index references for the i-th bootstrapped sample of X. 163 | Indexed values are not necessarily identical with bootstrapped values. 164 | """ 165 | # todo: implement the bootstrapping logic here 166 | # 167 | # ensure: no side effects, no mutation of self, X, y 168 | # ensure to deal with return_indices False and True 169 | # y can be ignored if not needed 170 | 171 | yield 42 # replace this with actual bootstrapping logic 172 | 173 | # todo: return default parameters, so that a test instance can be created 174 | # required for automated unit and integration testing of estimator 175 | @classmethod 176 | def get_test_params(cls, parameter_set="default"): 177 | """Return testing parameter settings for the estimator. 178 | 179 | Parameters 180 | ---------- 181 | parameter_set : str, default="default" 182 | Name of the set of test parameters to return, for use in tests. If no 183 | special parameters are defined for a value, will return `"default"` set. 184 | Reserved values for classifiers: 185 | "results_comparison" - used for identity testing in some classifiers 186 | should contain parameter settings comparable to "TSC bakeoff" 187 | 188 | Returns 189 | ------- 190 | params : dict or list of dict, default = {} 191 | Parameters to create testing instances of the class 192 | Each dict contains parameters to construct an "interesting" test instance, i.e., 193 | `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. 194 | `create_test_instance` uses the first (or only) dictionary in `params` 195 | """ 196 | 197 | # todo: set the testing parameters for the estimators 198 | # Testing parameters can be dictionary or list of dictionaries 199 | # 200 | # this can, if required, use: 201 | # class properties (e.g., inherited); parent class test case 202 | # imported objects such as estimators from sklearn 203 | # important: all such imports should be *inside get_test_params*, not at the top 204 | # since imports are used only at testing time 205 | # 206 | # The parameter_set argument is not used for most automated, module level tests. 207 | # It can be used in custom, estimator specific tests, for "special" settings. 208 | # For classification, this is also used in tests for reference settings, 209 | # such as published in benchmarking studies, or for identity testing. 210 | # A parameter dictionary must be returned *for all values* of parameter_set, 211 | # i.e., "parameter_set not available" errors should never be raised. 212 | # 213 | # A good parameter set should primarily satisfy two criteria, 214 | # 1. Chosen set of parameters should have a low testing time, 215 | # ideally in the magnitude of few seconds for the entire test suite.
216 | # This is vital for the cases where default values result in 217 | # "big" models which not only increases test time but also 218 | # run into the risk of test workers crashing. 219 | # 2. There should be a minimum two such parameter sets with different 220 | # sets of values to ensure a wide range of code coverage is provided. 221 | # 222 | # example 1: specify params as dictionary 223 | # any number of params can be specified 224 | # params = {"est": value0, "parama": value1, "paramb": value2} 225 | # 226 | # example 2: specify params as list of dictionary 227 | # note: Only first dictionary will be used by create_test_instance 228 | # params = [{"est": value1, "parama": value2}, 229 | # {"est": value3, "parama": value4}] 230 | # 231 | # example 3: parameter set depending on param_set value 232 | # note: only needed if a separate parameter set is needed in tests 233 | # if parameter_set == "special_param_set": 234 | # params = {"est": value1, "parama": value2} 235 | # return params 236 | # 237 | # # "default" params 238 | # params = {"est": value3, "parama": value4} 239 | # return params 240 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "tsbootstrap" 3 | version = "0.1.5" 4 | description = "A Python package to generate bootstrapped time series" 5 | maintainers = [ 6 | { name = "Sankalp Gilda", email = "sankalp.gilda@gmail.com" }, 7 | { name = "Franz Kiraly", email = "franz.kiraly@sktime.net"}, 8 | { name = "Benedikt Heidrich", email = "benedikt.heidrich@sktime.net"}, 9 | ] 10 | authors = [ 11 | { name = "Sankalp Gilda", email = "sankalp.gilda@gmail.com" }, 12 | ] 13 | license = { file = "LICENSE" } 14 | readme = "README.md" 15 | requires-python = ">=3.9,<3.13" 16 | classifiers = [ 17 | "Development Status :: 4 - Beta", 18 | "Intended Audience :: Science/Research", 19 | "Intended Audience :: Developers", 20 | "Operating System :: MacOS", 21 | "Operating System :: Microsoft :: Windows", 22 | "Operating System :: Unix", 23 | "Programming Language :: Python", 24 | "Programming Language :: Python :: 3 :: Only", 25 | "Programming Language :: Python :: 3.9", 26 | "Programming Language :: Python :: 3.10", 27 | "Programming Language :: Python :: 3.11", 28 | "Programming Language :: Python :: 3.12", 29 | ] 30 | 31 | dependencies = [ 32 | "numpy<1.27,>=1.21", 33 | "scikit-base>=0.10.0,<0.11", 34 | "scikit-learn>=1.5.1,<1.6.0", 35 | "scipy>=1.10,<1.14.0", 36 | "packaging>=24.0,<24.2", 37 | "pydantic>=2.0,<3.0", 38 | ] 39 | 40 | [project.optional-dependencies] 41 | 42 | all-extras = [ 43 | "arch>=7.0.0,<7.1.0", 44 | "hmmlearn>=0.3.0,<0.3.2", 45 | "pyclustering>=0.10.0,<0.11.0", 46 | "scikit_learn_extra>=0.3.0,<0.4.0", 47 | "statsmodels>=0.14.2,<0.15.0", 48 | "dtaidistance; python_version < '3.10'", 49 | ] 50 | 51 | docs = [ 52 | "furo", 53 | "jupyter", 54 | "myst-parser", 55 | "nbsphinx>=0.8.6", 56 | "numpydoc", 57 | "pydata-sphinx-theme", 58 | "Sphinx!=7.2.0,<8.0.0", 59 | "sphinx-rtd-theme>=1.3.0", 60 | "sphinx-copybutton>=0.5.2", 61 | "sphinx-design<0.6.0", 62 | "sphinx-gallery<0.15.0", 63 | "sphinx-issues<4.0.0", 64 | "sphinx-version-warning", 65 | "tabulate>=0.9.0", 66 | ] 67 | 68 | dev = [ 69 | "black>=24.3.0", 70 | "blacken-docs", 71 | "hypothesis", 72 | "pre-commit", 73 | "pytest", 74 | "pytest-cov", 75 | "github-actions", 76 | "importlib-metadata", 77 | "pip-tools", 78 | "pyright", 79 | "ruff", 80 | "autoflake", 81 | "typos", 82 
| "tox", 83 | "tox-gh-actions", 84 | "pycobertura", 85 | "tomlkit" 86 | ] 87 | 88 | [tool.pytest.ini_options] 89 | minversion = "6.0" 90 | addopts = "-ra -q" 91 | testpaths = [ 92 | "tests", 93 | ] 94 | 95 | [tool.pytest.cov] 96 | source = ["src/tsbootstrap"] 97 | 98 | [tool.black] 99 | line-length = 79 100 | target-version = ["py310", "py311"] 101 | 102 | [tool.ruff] 103 | target-version = 'py310' 104 | select = [ 105 | "B", # flake8-bugbear 106 | "C4", # flake8-comprehensions 107 | "D", # pydocstyle 108 | "E", # Error 109 | "F", # pyflakes 110 | "I", # isort 111 | "ISC", # flake8-implicit-str-concat 112 | "N", # pep8-naming 113 | "PGH", # pygrep-hooks 114 | "PTH", # flake8-use-pathlib 115 | "Q", # flake8-quotes 116 | "S", # bandit 117 | "SIM", # flake8-simplify 118 | "TRY", # tryceratops 119 | "UP", # pyupgrade 120 | "W", # pycodestyle warnings 121 | "YTT", # flake8-2020 122 | ] 123 | 124 | exclude = [ 125 | "migrations", 126 | "__pycache__", 127 | "manage.py", 128 | "settings.py", 129 | "env", 130 | ".env", 131 | "venv", 132 | ".venv", 133 | ] 134 | 135 | ignore = [ 136 | "B905", # zip strict=true; remove once python <3.10 support is dropped. 137 | "C901", # function is too complex; overly strict 138 | "D100", 139 | "D101", 140 | "D102", 141 | "D103", 142 | "D104", 143 | "D105", 144 | "D106", 145 | "D107", 146 | "D200", 147 | "D401", 148 | "E402", 149 | "E501", # line length handled by black 150 | "F401", 151 | "N802", # Function name should be lowercase; overly strict 152 | "N803", # Argument name should be lowercase; overly strict 153 | "N806", # Variable in function should be lowercase; overly strict 154 | "N816", # Variable in class scope should not be mixedCase; overly strict 155 | "PGH003", # Use of "eval"; overly strict 156 | "SIM115", # open-file-with-context-handler; overly strict 157 | "TRY003", # Avoid specifying messages outside exception class; overly strict, especially for ValueError 158 | "UP038", # Use `X | Y` in `isinstance` call instead of `(X, Y)`; overly strict 159 | "UP007", # Use `X | Y` for type annotationsRuffUP007; overly strict 160 | "UP006", # Use `list` instead of `List` for type annotations; overly strict 161 | "UP035", # `typing.List` is deprecated, use `list` instead; overly strict 162 | ] 163 | line-length = 79 # Must agree with Black 164 | 165 | [tool.ruff.isort] 166 | order-by-type = true 167 | relative-imports-order = "closest-to-furthest" 168 | extra-standard-library = ["typing"] 169 | section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"] 170 | known-first-party = [] 171 | 172 | 173 | [tool.ruff.flake8-bugbear] 174 | extend-immutable-calls = [ 175 | "chr", 176 | "typer.Argument", 177 | "typer.Option", 178 | ] 179 | 180 | [tool.ruff.pydocstyle] 181 | convention = "numpy" 182 | 183 | [tool.ruff.per-file-ignores] 184 | "tests/*.py" = [ 185 | "D100", 186 | "D101", 187 | "D102", 188 | "D103", 189 | "D104", 190 | "D105", 191 | "D106", 192 | "D107", 193 | "S101", # use of "assert" 194 | "S102", # use of "exec" 195 | "S106", # possible hardcoded password. 196 | "PGH001", # use of "eval" 197 | ] 198 | "src/tsbootstrap/tests/*.py" = [ 199 | "D100", 200 | "D101", 201 | "D102", 202 | "D103", 203 | "D104", 204 | "D105", 205 | "D106", 206 | "D107", 207 | "S101", # use of "assert" 208 | "S102", # use of "exec" 209 | "S106", # possible hardcoded password. 
210 | "PGH001", # use of "eval" 211 | ] 212 | 213 | [tool.ruff.pep8-naming] 214 | staticmethod-decorators = [ 215 | "pydantic.validator", 216 | "pydantic.root_validator", 217 | ] 218 | 219 | [tool.interrogate] 220 | ignore_init_module = true 221 | ignore_init_class = true 222 | ignore_magic = true 223 | ignore_semiprivate = true 224 | ignore_private = true 225 | ignore_nested_functions = true 226 | ignore_nested_classes = true 227 | ignore_imports = false 228 | exclude = [".venv/*", "tests/*", "docs/*", "build/*", "dist/*", "src/tsbootstrap/_version.py", "src/tsbootstrap/__init__.py", "src/tsbootstrap/utils/types.py"] 229 | 230 | [tool.coverage.run] 231 | source = ['src/'] 232 | omit = ['tests/*', '.venv/*'] 233 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python_version=$(python -c 'import sys; print(sys.version_info[:2])') 4 | 5 | poetry config virtualenvs.in-project true 6 | poetry lock 7 | poetry install 8 | 9 | # Only install dtaidistance for Python 3.9 or lower 10 | if [[ "$python_version" != "(3, 10)" && "$python_version" != "(3, 11)" ]]; then 11 | poetry run python -m pip install dtaidistance 12 | fi 13 | -------------------------------------------------------------------------------- /src/tsbootstrap/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import version 2 | 3 | __version__ = version("tsbootstrap") 4 | 5 | from .base_bootstrap import ( 6 | BaseDistributionBootstrap, 7 | BaseMarkovBootstrap, 8 | BaseResidualBootstrap, 9 | BaseSieveBootstrap, 10 | BaseStatisticPreservingBootstrap, 11 | BaseTimeSeriesBootstrap, 12 | ) 13 | from .base_bootstrap_configs import ( 14 | BaseDistributionBootstrapConfig, 15 | BaseMarkovBootstrapConfig, 16 | BaseResidualBootstrapConfig, 17 | BaseSieveBootstrapConfig, 18 | BaseStatisticPreservingBootstrapConfig, 19 | BaseTimeSeriesBootstrapConfig, 20 | ) 21 | from .block_bootstrap import ( 22 | BartlettsBootstrap, 23 | BaseBlockBootstrap, 24 | BlackmanBootstrap, 25 | BlockBootstrap, 26 | CircularBlockBootstrap, 27 | HammingBootstrap, 28 | HanningBootstrap, 29 | MovingBlockBootstrap, 30 | NonOverlappingBlockBootstrap, 31 | StationaryBlockBootstrap, 32 | TukeyBootstrap, 33 | ) 34 | from .block_bootstrap_configs import ( 35 | BartlettsBootstrapConfig, 36 | BaseBlockBootstrapConfig, 37 | BlackmanBootstrapConfig, 38 | BlockBootstrapConfig, 39 | CircularBlockBootstrapConfig, 40 | HammingBootstrapConfig, 41 | HanningBootstrapConfig, 42 | MovingBlockBootstrapConfig, 43 | NonOverlappingBlockBootstrapConfig, 44 | StationaryBlockBootstrapConfig, 45 | TukeyBootstrapConfig, 46 | ) 47 | from .block_generator import BlockGenerator 48 | from .block_length_sampler import BlockLengthSampler 49 | from .block_resampler import BlockResampler 50 | from .bootstrap import ( 51 | BlockDistributionBootstrap, 52 | BlockMarkovBootstrap, 53 | BlockResidualBootstrap, 54 | BlockSieveBootstrap, 55 | BlockStatisticPreservingBootstrap, 56 | WholeDistributionBootstrap, 57 | WholeMarkovBootstrap, 58 | WholeResidualBootstrap, 59 | WholeSieveBootstrap, 60 | WholeStatisticPreservingBootstrap, 61 | ) 62 | from .markov_sampler import ( 63 | BlockCompressor, 64 | MarkovSampler, 65 | MarkovTransitionMatrixCalculator, 66 | ) 67 | from .ranklags import RankLags 68 | from .time_series_model import TimeSeriesModel 69 | from .time_series_simulator import TimeSeriesSimulator 
70 | from .tsfit import TSFit, TSFitBestLag 71 | -------------------------------------------------------------------------------- /src/tsbootstrap/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/astrogilda/tsbootstrap/6d6b9d37a87c7050a212a6f57fa6fca36d1d1ce4/src/tsbootstrap/py.typed -------------------------------------------------------------------------------- /src/tsbootstrap/ranklags.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from numbers import Integral 5 | 6 | import numpy as np 7 | 8 | from tsbootstrap.utils.types import ModelTypes 9 | from tsbootstrap.utils.validate import validate_integers, validate_literal_type 10 | 11 | logger = logging.getLogger("tsbootstrap") 12 | 13 | 14 | class RankLags: 15 | """ 16 | A class that uses several metrics to rank lags for time series models. 17 | 18 | Methods 19 | ------- 20 | rank_lags_by_aic_bic() 21 | Rank lags based on Akaike information criterion (AIC) and Bayesian information criterion (BIC). 22 | rank_lags_by_pacf() 23 | Rank lags based on Partial Autocorrelation Function (PACF) values. 24 | estimate_conservative_lag() 25 | Estimate a conservative lag value by considering various metrics. 26 | get_model(order) 27 | Retrieve a previously fitted model given an order. 28 | 29 | Examples 30 | -------- 31 | >>> from tsbootstrap import RankLags 32 | >>> import numpy as np 33 | >>> X = np.random.normal(size=(100, 1)) 34 | >>> rank_obj = RankLags(X, model_type='ar') 35 | >>> rank_obj.estimate_conservative_lag() 36 | 2 37 | >>> rank_obj.rank_lags_by_aic_bic() 38 | (array([2, 1]), array([2, 1])) 39 | >>> rank_obj.rank_lags_by_pacf() 40 | array([1, 2]) 41 | """ 42 | 43 | _tags = {"python_dependencies": "statsmodels"} 44 | 45 | def __init__( 46 | self, 47 | X: np.ndarray, 48 | model_type: ModelTypes, 49 | max_lag: Integral = 10, 50 | y=None, 51 | save_models: bool = False, 52 | ) -> None: 53 | """ 54 | Initialize the RankLags object. 55 | 56 | Parameters 57 | ---------- 58 | X : np.ndarray 59 | The input data. 60 | model_type : str 61 | The type of model to fit. One of 'ar', 'arima', 'sarima', 'var', 'arch'. 62 | max_lag : int, optional, default=10 63 | Maximum lag to consider. 64 | y : np.ndarray, optional, default=None 65 | Exogenous variables to include in the model. 66 | save_models : bool, optional, default=False 67 | Whether to save the models. 68 | """ 69 | self.X = X 70 | self.max_lag = max_lag 71 | self.model_type = model_type 72 | self.y = y 73 | self.save_models = save_models 74 | self.models = [] 75 | 76 | @property 77 | def X(self) -> np.ndarray: 78 | """ 79 | The input data. 80 | 81 | Returns 82 | ------- 83 | np.ndarray 84 | The input data. 85 | """ 86 | return self._X 87 | 88 | @X.setter 89 | def X(self, value: np.ndarray) -> None: 90 | """ 91 | Set the input data. 92 | 93 | Parameters 94 | ---------- 95 | X : np.ndarray 96 | The input data. 97 | """ 98 | if not isinstance(value, np.ndarray): 99 | raise TypeError("X must be a numpy array.") 100 | self._X = value 101 | 102 | @property 103 | def max_lag(self) -> Integral: 104 | """ 105 | Maximum lag to consider. 106 | 107 | Returns 108 | ------- 109 | int 110 | Maximum lag to consider. 111 | """ 112 | return self._max_lag 113 | 114 | @max_lag.setter 115 | def max_lag(self, value: Integral) -> None: 116 | """ 117 | Set the maximum lag to consider. 
118 | 119 | Parameters 120 | ---------- 121 | max_lag : int 122 | Maximum lag to consider. 123 | """ 124 | validate_integers(value, min_value=1) 125 | self._max_lag = value 126 | 127 | @property 128 | def model_type(self) -> ModelTypes: 129 | """ 130 | The type of model to fit. 131 | 132 | Returns 133 | ------- 134 | str 135 | The type of model to fit. 136 | """ 137 | return self._model_type 138 | 139 | @model_type.setter 140 | def model_type(self, value: ModelTypes) -> None: 141 | """ 142 | Set the type of model to fit. 143 | 144 | Parameters 145 | ---------- 146 | value : ModelTypes 147 | The type of model to fit. One of 'ar', 'arima', 'sarima', 'var', 'arch'. 148 | """ 149 | validate_literal_type(value, ModelTypes) 150 | self._model_type = value.lower() 151 | 152 | @property 153 | def y(self) -> np.ndarray: 154 | """ 155 | Exogenous variables to include in the model. 156 | 157 | Returns 158 | ------- 159 | np.ndarray 160 | Exogenous variables to include in the model. 161 | """ 162 | return self._y 163 | 164 | @y.setter 165 | def y(self, value: np.ndarray) -> None: 166 | """ 167 | Set the exogenous variables to include in the model. 168 | 169 | Parameters 170 | ---------- 171 | y : np.ndarray 172 | Exogenous variables to include in the model. 173 | """ 174 | if value is not None and not isinstance(value, np.ndarray): 175 | raise TypeError("y must be a numpy array.") 176 | self._y = value 177 | 178 | def rank_lags_by_aic_bic(self): 179 | """ 180 | Rank lags based on Akaike information criterion (AIC) and Bayesian information criterion (BIC). 181 | 182 | Returns 183 | ------- 184 | Tuple[np.ndarray, np.ndarray] 185 | aic_ranked_lags: Lags ranked by AIC. 186 | bic_ranked_lags: Lags ranked by BIC. 187 | """ 188 | from tsbootstrap.tsfit import TSFit 189 | 190 | aic_values = [] 191 | bic_values = [] 192 | for lag in range(1, self.max_lag + 1): 193 | try: 194 | fit_obj = TSFit(order=lag, model_type=self.model_type) 195 | model = fit_obj.fit(X=self.X, y=self.y).model 196 | except Exception as e: 197 | # raise RuntimeError(f"An error occurred during fitting: {e}") 198 | logger.warning( 199 | f"An error occurred during fitting for lag {lag}. Skipping remaining lags." 200 | ) 201 | logger.debug(f"{e}") 202 | break 203 | if self.save_models: 204 | self.models.append(model) 205 | aic_values.append(model.aic) 206 | bic_values.append(model.bic) 207 | 208 | aic_ranked_lags = np.argsort(aic_values) + 1 209 | bic_ranked_lags = np.argsort(bic_values) + 1 210 | 211 | return aic_ranked_lags, bic_ranked_lags 212 | 213 | def rank_lags_by_pacf(self) -> np.ndarray: 214 | """ 215 | Rank lags based on Partial Autocorrelation Function (PACF) values. 216 | 217 | Returns 218 | ------- 219 | np.ndarray 220 | Lags ranked by PACF values. 221 | """ 222 | from statsmodels.tsa.stattools import pacf 223 | 224 | # Can only compute partial correlations for lags up to 50% of the sample size. We use the minimum of max_lag and third of the sample size, to allow for other parameters and trends to be included in the model. 225 | pacf_values = pacf( 226 | self.X, nlags=max(min(self.max_lag, self.X.shape[0] // 3 - 1), 1) 227 | )[1:] 228 | ci = 1.96 / np.sqrt(len(self.X)) 229 | significant_lags = np.where(np.abs(pacf_values) > ci)[0] + 1 230 | return significant_lags 231 | 232 | def estimate_conservative_lag(self) -> int: 233 | """ 234 | Estimate a conservative lag value by considering various metrics. 235 | 236 | Returns 237 | ------- 238 | int 239 | A conservative lag value. 
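        Notes
        -----
        The conservative lag is the smallest lag that appears in both the AIC
        and BIC rankings (and, for univariate data, among the PACF-significant
        lags); if the rankings share no common lag, the last AIC-ranked lag is
        returned instead.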
240 | """ 241 | aic_ranked_lags, bic_ranked_lags = self.rank_lags_by_aic_bic() 242 | # PACF is only available for univariate data 243 | if self.X.shape[1] == 1: 244 | pacf_ranked_lags = self.rank_lags_by_pacf() 245 | highest_ranked_lags = set(aic_ranked_lags).intersection( 246 | bic_ranked_lags, pacf_ranked_lags 247 | ) 248 | else: 249 | highest_ranked_lags = set(aic_ranked_lags).intersection( 250 | bic_ranked_lags 251 | ) 252 | 253 | if not highest_ranked_lags: 254 | return aic_ranked_lags[-1] 255 | else: 256 | return min(highest_ranked_lags) 257 | 258 | def get_model(self, order: int): 259 | """ 260 | Retrieve a previously fitted model given an order. 261 | 262 | Parameters 263 | ---------- 264 | order : int 265 | Order of the model to retrieve. 266 | 267 | Returns 268 | ------- 269 | Union[AutoRegResultsWrapper, ARIMAResultsWrapper, SARIMAXResultsWrapper, VARResultsWrapper, ARCHModelResult] 270 | The fitted model. 271 | """ 272 | return self.models[order - 1] if self.save_models else None 273 | -------------------------------------------------------------------------------- /src/tsbootstrap/registry/__init__.py: -------------------------------------------------------------------------------- 1 | """Registry and lookup functionality.""" 2 | 3 | from tsbootstrap.registry._lookup import all_objects 4 | from tsbootstrap.registry._tags import ( 5 | OBJECT_TAG_LIST, 6 | OBJECT_TAG_REGISTER, 7 | ) 8 | 9 | __all__ = [ 10 | "OBJECT_TAG_LIST", 11 | "OBJECT_TAG_REGISTER", 12 | "all_objects", 13 | ] 14 | -------------------------------------------------------------------------------- /src/tsbootstrap/registry/_lookup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Registry lookup methods. 3 | 4 | This module exports the following methods for registry lookup: 5 | 6 | - all_objects(object_types: Optional[Union[str, List[str]]] = None, 7 | filter_tags: Optional[Dict[str, Union[str, List[str], bool]]] = None, 8 | exclude_objects: Optional[Union[str, List[str]]] = None, 9 | return_names: bool = True, 10 | as_dataframe: bool = False, 11 | return_tags: Optional[Union[str, List[str]]] = None, 12 | suppress_import_stdout: bool = True) -> Union[List[Any], List[Tuple]] 13 | Lookup and filtering of objects in the tsbootstrap registry. 14 | """ 15 | 16 | from pathlib import Path 17 | from typing import Any, Dict, List, Optional, Tuple, Union 18 | 19 | from skbase.base import BaseObject 20 | from skbase.lookup import all_objects as _all_objects 21 | 22 | from tsbootstrap.registry._tags import ( 23 | OBJECT_TAG_REGISTER, 24 | check_tag_is_valid, 25 | ) 26 | 27 | VALID_OBJECT_TYPE_STRINGS: set = {tag.scitype for tag in OBJECT_TAG_REGISTER} 28 | 29 | 30 | def all_objects( 31 | object_types: Optional[Union[str, List[str]]] = None, 32 | filter_tags: Optional[ 33 | Union[str, Dict[str, Union[str, List[str], bool]]] 34 | ] = None, 35 | exclude_objects: Optional[Union[str, List[str]]] = None, 36 | return_names: bool = True, 37 | as_dataframe: bool = False, 38 | return_tags: Optional[Union[str, List[str]]] = None, 39 | suppress_import_stdout: bool = True, 40 | ) -> Union[List[Any], List[Tuple]]: 41 | """ 42 | Get a list of all objects from tsbootstrap. 43 | 44 | This function crawls the module and retrieves all classes that inherit 45 | from tsbootstrap's and sklearn's base classes. 
46 | 47 | Excluded from retrieval are: 48 | - The base classes themselves 49 | - Classes defined in test modules 50 | 51 | Parameters 52 | ---------- 53 | object_types : Union[str, List[str]], optional (default=None) 54 | Specifies which types of objects to return. 55 | - If None, no filtering is applied and all objects are returned. 56 | - If str or list of str, only objects matching the specified scitypes are returned. 57 | Valid scitypes are entries in `registry.BASE_CLASS_REGISTER` (first column). 58 | 59 | filter_tags : Union[str, Dict[str, Union[str, List[str], bool]]], optional (default=None) 60 | Dictionary or string to filter returned objects based on their tags. 61 | - If a string, it is treated as a boolean tag filter with the value `True`. 62 | - If a dictionary, each key-value pair represents a filter condition in an "AND" conjunction. 63 | - Key is the tag name to filter on. 64 | - Value is a string, list of strings, or boolean that the tag value must match or be within. 65 | - Only objects satisfying all filter conditions are returned. 66 | 67 | exclude_objects : Union[str, List[str]], optional (default=None) 68 | Names of objects to exclude from the results. 69 | 70 | return_names : bool, optional (default=True) 71 | - If True, the object's class name is included in the returned results. 72 | - If False, the class name is omitted. 73 | 74 | as_dataframe : bool, optional (default=False) 75 | - If True, returns a pandas.DataFrame with named columns for all returned attributes. 76 | - If False, returns a list (of objects or tuples). 77 | 78 | return_tags : Union[str, List[str]], optional (default=None) 79 | - Names of tags to fetch and include in the returned results. 80 | - If specified, tag values are appended as either columns or tuple entries. 81 | 82 | suppress_import_stdout : bool, optional (default=True) 83 | Whether to suppress stdout printout upon import. 84 | 85 | Returns 86 | ------- 87 | Union[List[Any], List[Tuple]] 88 | Depending on the parameters: 89 | 1. List of objects: 90 | - Entries are objects matching the query, in alphabetical order of object name. 91 | 2. List of tuples: 92 | - Each tuple contains (optional object name, object class, optional object tags). 93 | - Ordered alphabetically by object name. 94 | 3. pandas.DataFrame: 95 | - Columns represent the returned attributes. 96 | - Includes "objects", "names", and any specified tag columns. 97 | 98 | Examples 99 | -------- 100 | >>> from tsbootstrap.registry import all_objects 101 | >>> # Return a complete list of objects as a DataFrame 102 | >>> all_objects(as_dataframe=True) 103 | >>> # Return all bootstrap algorithms by filtering for object type 104 | >>> all_objects("bootstrap", as_dataframe=True) 105 | >>> # Return all bootstraps which are block bootstraps 106 | >>> all_objects( 107 | ... object_types="bootstrap", 108 | ... filter_tags={"bootstrap_type": "block"}, 109 | ... as_dataframe=True 110 | ... ) 111 | 112 | References 113 | ---------- 114 | Adapted version of sktime's `all_estimators`, 115 | which is an evolution of scikit-learn's `all_estimators`. 
116 | """ 117 | MODULES_TO_IGNORE = ( 118 | "tests", 119 | "setup", 120 | "contrib", 121 | "utils", 122 | "all", 123 | ) 124 | 125 | result: Union[List[Any], List[Tuple]] = [] 126 | ROOT = str( 127 | Path(__file__).parent.parent 128 | ) # tsbootstrap package root directory 129 | 130 | # Prepare filter_tags 131 | if isinstance(filter_tags, str): 132 | # Ensure the tag expects a boolean value 133 | tag = next( 134 | (t for t in OBJECT_TAG_REGISTER if t.name == filter_tags), None 135 | ) 136 | if not tag: 137 | raise ValueError( 138 | f"Tag '{filter_tags}' not found in OBJECT_TAG_REGISTER." 139 | ) 140 | if tag.value_type != "bool": 141 | raise ValueError( 142 | f"Tag '{filter_tags}' does not expect a boolean value." 143 | ) 144 | filter_tags = {filter_tags: True} 145 | elif isinstance(filter_tags, dict): 146 | # Validate each tag in filter_tags 147 | for key, value in filter_tags.items(): 148 | try: 149 | if not check_tag_is_valid(key, value): 150 | raise ValueError( 151 | f"Invalid value '{value}' for tag '{key}'." 152 | ) 153 | except KeyError as e: 154 | raise ValueError( 155 | f"Tag '{key}' not found in OBJECT_TAG_REGISTER." 156 | ) from e 157 | else: 158 | filter_tags = None 159 | 160 | if object_types: 161 | if isinstance(object_types, str): 162 | object_types = [object_types] 163 | # Validate object_types 164 | invalid_types = set(object_types) - VALID_OBJECT_TYPE_STRINGS 165 | if invalid_types: 166 | raise ValueError( 167 | f"Invalid object_types: {invalid_types}. Valid types are {VALID_OBJECT_TYPE_STRINGS}." 168 | ) 169 | if filter_tags and "object_type" not in filter_tags: 170 | object_tag_filter = {"object_type": object_types} 171 | filter_tags.update(object_tag_filter) 172 | elif filter_tags and "object_type" in filter_tags: 173 | existing_filter = filter_tags.get("object_type", []) 174 | if isinstance(existing_filter, str): 175 | existing_filter = [existing_filter] 176 | elif isinstance(existing_filter, list): 177 | pass 178 | else: 179 | raise ValueError( 180 | f"Unexpected type for 'object_type' filter: {type(existing_filter)}" 181 | ) 182 | combined_filter = list(set(object_types + existing_filter)) 183 | filter_tags["object_type"] = combined_filter 184 | else: 185 | filter_tags = {"object_type": object_types} 186 | 187 | # Retrieve objects using skbase's all_objects 188 | result = _all_objects( 189 | object_types=[BaseObject], 190 | filter_tags=filter_tags, 191 | exclude_objects=exclude_objects, 192 | return_names=return_names, 193 | as_dataframe=as_dataframe, 194 | return_tags=return_tags, 195 | suppress_import_stdout=suppress_import_stdout, 196 | package_name="tsbootstrap", 197 | path=ROOT, 198 | modules_to_ignore=MODULES_TO_IGNORE, 199 | ) 200 | 201 | return result 202 | -------------------------------------------------------------------------------- /src/tsbootstrap/registry/_tags.py: -------------------------------------------------------------------------------- 1 | """ 2 | Register of estimator and object tags. 3 | 4 | Note for extenders: 5 | New tags should be entered in `OBJECT_TAG_REGISTER`. 6 | No other place is necessary to add new tags. 7 | 8 | This module exports the following: 9 | 10 | - OBJECT_TAG_REGISTER : List[Tag] 11 | A list of Tag instances, each representing a tag with its attributes. 12 | 13 | - OBJECT_TAG_TABLE : List[Dict[str, Any]] 14 | `OBJECT_TAG_REGISTER` in table form as a list of dictionaries. 15 | 16 | - OBJECT_TAG_LIST : List[str] 17 | List of tag names extracted from `OBJECT_TAG_REGISTER`. 
18 | 19 | - check_tag_is_valid(tag_name: str, tag_value: Any) -> bool 20 | Function to validate if a tag value is valid for a given tag name. 21 | """ 22 | 23 | from typing import Any, Dict, List, Tuple, Union 24 | 25 | from pydantic import BaseModel, field_validator 26 | 27 | 28 | class Tag(BaseModel): 29 | """ 30 | Represents a single tag with its properties. 31 | 32 | Attributes 33 | ---------- 34 | name : str 35 | Name of the tag as used in the _tags dictionary. 36 | scitype : str 37 | Name of the scitype this tag applies to. 38 | value_type : Union[str, Tuple[str, Union[List[str], str]], List[Union[str, Tuple[str, Union[List[str], str]]]]] 39 | Expected type(s) of the tag value. 40 | description : str 41 | Plain English description of the tag. 42 | """ 43 | 44 | name: str 45 | scitype: str 46 | value_type: Union[ 47 | str, 48 | Tuple[str, Union[List[str], str]], 49 | List[Union[str, Tuple[str, Union[List[str], str]]]], 50 | ] 51 | description: str 52 | 53 | @field_validator("value_type") 54 | @classmethod 55 | def validate_value_type(cls, v): 56 | """ 57 | Validates the `value_type` attribute to ensure it adheres to expected formats. 58 | 59 | Parameters 60 | ---------- 61 | v : Union[str, Tuple[str, Union[List[str], str]], List[Union[str, Tuple[str, Union[List[str], str]]]]] 62 | The value to validate. 63 | 64 | Returns 65 | ------- 66 | Union[str, Tuple[str, Union[List[str], str]], List[Union[str, Tuple[str, Union[List[str], str]]]]] 67 | The validated value. 68 | 69 | Raises 70 | ------ 71 | ValueError 72 | If `v` does not conform to expected types and constraints. 73 | TypeError 74 | If `v` is neither a string, a tuple, nor a list. 75 | """ 76 | valid_base_types = {"bool", "int", "str", "list", "dict"} 77 | 78 | def validate_single_type(single_v): 79 | if isinstance(single_v, str): 80 | if single_v not in valid_base_types: 81 | raise ValueError( 82 | f"Invalid value_type: {single_v}. Must be one of {valid_base_types}." 83 | ) 84 | elif isinstance(single_v, tuple): 85 | if len(single_v) != 2: 86 | raise ValueError( 87 | "Tuple value_type must have exactly two elements." 88 | ) 89 | base, subtype = single_v 90 | if base not in {"str", "list"}: 91 | raise ValueError( 92 | "First element of tuple must be 'str' or 'list'." 93 | ) 94 | if base == "str": 95 | if not isinstance(subtype, list) or not all( 96 | isinstance(item, str) for item in subtype 97 | ): 98 | raise ValueError( 99 | "Second element must be a list of strings when base is 'str'." 100 | ) 101 | elif base == "list" and not ( 102 | ( 103 | isinstance(subtype, list) 104 | and all(isinstance(item, str) for item in subtype) 105 | ) 106 | or isinstance(subtype, str) 107 | ): 108 | raise ValueError( 109 | "Second element must be a list of strings or 'str' when base is 'list'." 110 | ) 111 | else: 112 | raise TypeError( 113 | "Each value_type must be either a string or a tuple." 
114 | ) 115 | 116 | if isinstance(v, list): 117 | if not v: 118 | raise ValueError("value_type list cannot be empty.") 119 | for item in v: 120 | validate_single_type(item) 121 | else: 122 | validate_single_type(v) 123 | 124 | return v 125 | 126 | 127 | # Define the OBJECT_TAG_REGISTER with Tag instances 128 | OBJECT_TAG_REGISTER: List[Tag] = [ 129 | # -------------------------- 130 | # All objects and estimators 131 | # -------------------------- 132 | Tag( 133 | name="object_type", 134 | scitype="object", 135 | value_type=("str", ["regressor", "transformer"]), 136 | description="Type of object, e.g., 'regressor', 'transformer'.", 137 | ), 138 | Tag( 139 | name="python_version", 140 | scitype="object", 141 | value_type="str", 142 | description="Python version specifier (PEP 440) for estimator, or None for all versions.", 143 | ), 144 | Tag( 145 | name="python_dependencies", 146 | scitype="object", 147 | # Allow both string and list of strings 148 | value_type=["str", ("list", "str")], 149 | description="Python dependencies of estimator as string or list of strings.", 150 | ), 151 | Tag( 152 | name="python_dependencies_alias", 153 | scitype="object", 154 | value_type="dict", 155 | description="Alias for Python dependencies if import name differs from package name. Key-value pairs are package name and import name.", 156 | ), 157 | # ----------------------- 158 | # BaseTimeSeriesBootstrap 159 | # ----------------------- 160 | Tag( 161 | name="bootstrap_type", 162 | scitype="bootstrap", 163 | value_type=("list", "str"), 164 | description="Type(s) of bootstrap the algorithm supports.", 165 | ), 166 | Tag( 167 | name="capability:multivariate", 168 | scitype="bootstrap", 169 | value_type="bool", 170 | description="Whether the bootstrap algorithm supports multivariate data.", 171 | ), 172 | # ---------------------------- 173 | # BaseMetaObject reserved tags 174 | # ---------------------------- 175 | Tag( 176 | name="named_object_parameters", 177 | scitype="object", 178 | value_type="str", 179 | description="Name of component list attribute for meta-objects.", 180 | ), 181 | Tag( 182 | name="fitted_named_object_parameters", 183 | scitype="estimator", 184 | value_type="str", 185 | description="Name of fitted component list attribute for meta-objects.", 186 | ), 187 | ] 188 | 189 | # Create OBJECT_TAG_TABLE as a list of dictionaries 190 | OBJECT_TAG_TABLE: List[Dict[str, Any]] = [ 191 | { 192 | "name": tag.name, 193 | "scitype": tag.scitype, 194 | "value_type": tag.value_type, 195 | "description": tag.description, 196 | } 197 | for tag in OBJECT_TAG_REGISTER 198 | ] 199 | 200 | # Create OBJECT_TAG_LIST as a list of tag names 201 | OBJECT_TAG_LIST: List[str] = [tag.name for tag in OBJECT_TAG_REGISTER] 202 | 203 | 204 | def check_tag_is_valid(tag_name: str, tag_value: Any) -> bool: 205 | """ 206 | Check whether a tag value is valid for a given tag name. 207 | 208 | Parameters 209 | ---------- 210 | tag_name : str 211 | The name of the tag to validate. 212 | tag_value : Any 213 | The value to validate against the tag's expected type. 214 | 215 | Returns 216 | ------- 217 | bool 218 | True if the tag value is valid for the tag name, False otherwise. 219 | 220 | Raises 221 | ------ 222 | KeyError 223 | If the tag_name is not found in OBJECT_TAG_REGISTER. 224 | """ 225 | try: 226 | tag = next(tag for tag in OBJECT_TAG_REGISTER if tag.name == tag_name) 227 | except StopIteration as e: 228 | raise KeyError( 229 | f"Tag name '{tag_name}' not found in OBJECT_TAG_REGISTER." 
230 | ) from e 231 | 232 | value_type = tag.value_type 233 | 234 | if isinstance(value_type, list): 235 | # Iterate through each type definition and return True if any matches 236 | for vt in value_type: 237 | if isinstance(vt, str): 238 | if isinstance(tag_value, str): 239 | return True 240 | elif isinstance(vt, tuple): 241 | base_type, subtype = vt 242 | if base_type == "str": 243 | if isinstance(tag_value, str) and tag_value in subtype: 244 | return True 245 | elif base_type == "list" and isinstance(tag_value, list): 246 | if subtype == "str": 247 | if all(isinstance(item, str) for item in tag_value): 248 | return True 249 | elif isinstance(subtype, list) and all( 250 | isinstance(item, str) and item in subtype 251 | for item in tag_value 252 | ): 253 | return True 254 | return False 255 | elif isinstance(value_type, str): 256 | expected_type = value_type 257 | if expected_type == "bool": 258 | return isinstance(tag_value, bool) 259 | elif expected_type == "int": 260 | return isinstance(tag_value, int) 261 | elif expected_type == "str": 262 | return isinstance(tag_value, str) 263 | elif expected_type == "list": 264 | return isinstance(tag_value, list) 265 | elif expected_type == "dict": 266 | return isinstance(tag_value, dict) 267 | else: 268 | return False 269 | elif isinstance(value_type, tuple): 270 | base_type, subtype = value_type 271 | if base_type == "str": 272 | if isinstance(tag_value, str): 273 | return tag_value in subtype 274 | return False 275 | elif base_type == "list": 276 | if not isinstance(tag_value, list): 277 | return False 278 | if isinstance(subtype, list): 279 | return all( 280 | isinstance(item, str) and item in subtype 281 | for item in tag_value 282 | ) 283 | elif subtype == "str": 284 | return all(isinstance(item, str) for item in tag_value) 285 | return False 286 | else: 287 | return False 288 | -------------------------------------------------------------------------------- /src/tsbootstrap/registry/tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for registry and lookup functionality.""" 2 | -------------------------------------------------------------------------------- /src/tsbootstrap/registry/tests/test_tags.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the tag registry and tag validation functionality. 3 | 4 | This module contains tests to ensure that the `OBJECT_TAG_REGISTER` is correctly 5 | configured and that each tag adheres to the specified structure and type constraints. 6 | """ 7 | 8 | from tsbootstrap.registry._tags import ( 9 | OBJECT_TAG_LIST, 10 | OBJECT_TAG_REGISTER, 11 | OBJECT_TAG_TABLE, 12 | Tag, 13 | check_tag_is_valid, 14 | ) 15 | 16 | 17 | def test_tag_register_type(): 18 | """ 19 | Test the specification of the tag register. 20 | 21 | Ensures that `OBJECT_TAG_REGISTER` is a list of `Tag` instances with the correct attributes and types. 22 | 23 | Raises 24 | ------ 25 | TypeError 26 | If `OBJECT_TAG_REGISTER` is not a list or contains non-`Tag` instances. 27 | ValueError 28 | If any `Tag` instance does not conform to the expected structure or type constraints. 
29 | """ 30 | # Verify that OBJECT_TAG_REGISTER is a list 31 | if not isinstance(OBJECT_TAG_REGISTER, list): 32 | raise TypeError("`OBJECT_TAG_REGISTER` is not a list.") 33 | 34 | # Verify that all elements in OBJECT_TAG_REGISTER are instances of Tag 35 | if not all(isinstance(tag, Tag) for tag in OBJECT_TAG_REGISTER): 36 | raise TypeError( 37 | "Not all elements in `OBJECT_TAG_REGISTER` are `Tag` instances." 38 | ) 39 | 40 | # Iterate through each Tag instance to validate its attributes 41 | for tag in OBJECT_TAG_REGISTER: 42 | # Validate the 'name' attribute 43 | if not isinstance(tag.name, str): 44 | raise TypeError(f"Tag name '{tag.name}' is not a string.") 45 | 46 | # Validate the 'scitype' attribute 47 | if not isinstance(tag.scitype, str): 48 | raise TypeError(f"Tag scitype '{tag.scitype}' is not a string.") 49 | 50 | # Validate the 'value_type' attribute 51 | if not isinstance(tag.value_type, (str, tuple, list)): 52 | raise TypeError( 53 | f"Tag value_type '{tag.value_type}' is not a string, tuple, or list." 54 | ) 55 | 56 | if isinstance(tag.value_type, tuple): 57 | if len(tag.value_type) != 2: 58 | raise ValueError( 59 | "Tuple `value_type` must have exactly two elements." 60 | ) 61 | 62 | base_type, subtype = tag.value_type 63 | 64 | # Validate the base type 65 | if base_type not in {"str", "list"}: 66 | raise ValueError( 67 | f"First element of `value_type` tuple must be 'str' or 'list', got '{base_type}'." 68 | ) 69 | 70 | # Validate the subtype based on the base type 71 | if base_type == "str": 72 | if not isinstance(subtype, list) or not all( 73 | isinstance(item, str) for item in subtype 74 | ): 75 | raise TypeError( 76 | "Second element of `value_type` tuple must be a list of strings when base is 'str'." 77 | ) 78 | elif base_type == "list" and not ( 79 | ( 80 | isinstance(subtype, list) 81 | and all(isinstance(item, str) for item in subtype) 82 | ) 83 | or isinstance(subtype, str) 84 | ): 85 | raise TypeError( 86 | "Second element of `value_type` tuple must be a list of strings or 'str' when base is 'list'." 87 | ) 88 | 89 | elif isinstance(tag.value_type, list): 90 | if not tag.value_type: 91 | raise ValueError("`value_type` list cannot be empty.") 92 | 93 | for vt in tag.value_type: 94 | if isinstance(vt, str): 95 | if vt not in {"bool", "int", "str", "list", "dict"}: 96 | raise ValueError( 97 | f"Invalid value_type in list: {vt}. Must be one of {{'bool', 'int', 'str', 'list', 'dict'}}." 98 | ) 99 | elif isinstance(vt, tuple): 100 | if len(vt) != 2: 101 | raise ValueError( 102 | "Tuple in `value_type` list must have exactly two elements." 103 | ) 104 | base, subtype = vt 105 | if base not in {"str", "list"}: 106 | raise ValueError( 107 | "First element of tuple in `value_type` list must be 'str' or 'list'." 108 | ) 109 | if base == "str": 110 | if not isinstance(subtype, list) or not all( 111 | isinstance(item, str) for item in subtype 112 | ): 113 | raise TypeError( 114 | "Second element of tuple in `value_type` list must be a list of strings when base is 'str'." 115 | ) 116 | elif base == "list" and not ( 117 | ( 118 | isinstance(subtype, list) 119 | and all(isinstance(item, str) for item in subtype) 120 | ) 121 | or isinstance(subtype, str) 122 | ): 123 | raise TypeError( 124 | "Second element of tuple in `value_type` list must be a list of strings or 'str' when base is 'list'." 125 | ) 126 | else: 127 | raise TypeError( 128 | "`value_type` list elements must be either strings or tuples." 
129 | ) 130 | 131 | # Validate the 'description' attribute 132 | if not isinstance(tag.description, str): 133 | raise TypeError( 134 | f"Tag description '{tag.description}' is not a string." 135 | ) 136 | 137 | 138 | def test_object_tag_table_structure(): 139 | """ 140 | Test the structure of `OBJECT_TAG_TABLE`. 141 | 142 | Ensures that `OBJECT_TAG_TABLE` is a list of dictionaries, each containing the expected keys and corresponding types. 143 | 144 | Raises 145 | ------ 146 | TypeError 147 | If `OBJECT_TAG_TABLE` is not a list or contains elements that are not dictionaries. 148 | KeyError 149 | If any dictionary in `OBJECT_TAG_TABLE` is missing required keys. 150 | TypeError 151 | If any value in the dictionaries does not match the expected type. 152 | """ 153 | # Define the expected keys and their types 154 | expected_keys = { 155 | "name": str, 156 | "scitype": str, 157 | "value_type": (str, tuple, list), 158 | "description": str, 159 | } 160 | 161 | # Verify that OBJECT_TAG_TABLE is a list 162 | if not isinstance(OBJECT_TAG_TABLE, list): 163 | raise TypeError("`OBJECT_TAG_TABLE` is not a list.") 164 | 165 | # Iterate through each dictionary in OBJECT_TAG_TABLE to validate its structure 166 | for entry in OBJECT_TAG_TABLE: 167 | # Verify that each entry is a dictionary 168 | if not isinstance(entry, dict): 169 | raise TypeError( 170 | "Each entry in `OBJECT_TAG_TABLE` must be a dictionary." 171 | ) 172 | 173 | # Check for the presence of all expected keys 174 | for key, expected_type in expected_keys.items(): 175 | if key not in entry: 176 | raise KeyError( 177 | f"Key '{key}' is missing from an entry in `OBJECT_TAG_TABLE`." 178 | ) 179 | 180 | # Validate the type of each value 181 | if not isinstance(entry[key], expected_type): 182 | raise TypeError( 183 | f"Value for key '{key}' in `OBJECT_TAG_TABLE` entry is not of type {expected_type}." 184 | ) 185 | 186 | 187 | def test_object_tag_list(): 188 | """ 189 | Test the contents of `OBJECT_TAG_LIST`. 190 | 191 | Ensures that `OBJECT_TAG_LIST` contains all tag names present in `OBJECT_TAG_REGISTER` and that each name is a string. 192 | 193 | Raises 194 | ------ 195 | TypeError 196 | If `OBJECT_TAG_LIST` is not a list or contains non-string elements. 197 | ValueError 198 | If any tag name in `OBJECT_TAG_REGISTER` is missing from `OBJECT_TAG_LIST`. 199 | """ 200 | # Verify that OBJECT_TAG_LIST is a list 201 | if not isinstance(OBJECT_TAG_LIST, list): 202 | raise TypeError("`OBJECT_TAG_LIST` is not a list.") 203 | 204 | # Verify that all elements in OBJECT_TAG_LIST are strings 205 | if not all(isinstance(name, str) for name in OBJECT_TAG_LIST): 206 | raise TypeError("All elements in `OBJECT_TAG_LIST` must be strings.") 207 | 208 | # Extract all tag names from OBJECT_TAG_REGISTER 209 | tag_names = {tag.name for tag in OBJECT_TAG_REGISTER} 210 | 211 | # Verify that OBJECT_TAG_LIST contains all tag names 212 | missing_tags = tag_names - set(OBJECT_TAG_LIST) 213 | if missing_tags: 214 | raise ValueError( 215 | f"The following tags are missing from `OBJECT_TAG_LIST`: {missing_tags}" 216 | ) 217 | 218 | 219 | def test_check_tag_is_valid(): 220 | """ 221 | Test the `check_tag_is_valid` function. 222 | 223 | Ensures that `check_tag_is_valid` correctly validates tag values based on their expected types. 224 | 225 | Raises 226 | ------ 227 | AssertionError 228 | If any test case fails. 
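    Notes
    -----
    Illustrative expectations, mirroring the cases defined below:
    ``check_tag_is_valid("capability:multivariate", True)`` evaluates to True,
    while ``check_tag_is_valid("object_type", "classifier")`` evaluates to False,
    because "classifier" is not among the allowed values for ``object_type``.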
229 | """ 230 | # Define test cases as tuples of (tag_name, tag_value, expected_result) 231 | test_cases = [ 232 | ("object_type", "regressor", True), 233 | ("object_type", "transformer", True), 234 | ( 235 | "object_type", 236 | "classifier", 237 | False, 238 | ), # Should be False as it's not in the allowed list 239 | ("object_type", "invalid_type", False), 240 | ("capability:multivariate", True, True), 241 | ("capability:multivariate", False, True), 242 | ("capability:multivariate", "yes", False), 243 | ("python_version", "3.8.5", True), 244 | ("python_version", 3.8, False), 245 | ("python_dependencies", ["numpy", "pandas"], True), 246 | ("python_dependencies", "numpy", True), 247 | ("python_dependencies", ["numpy", 123], False), 248 | ("python_dependencies_alias", {"numpy": "np"}, True), 249 | ("python_dependencies_alias", "numpy", False), 250 | ("non_existent_tag", "value", False), # Should raise KeyError 251 | ] 252 | 253 | for tag_name, tag_value, expected in test_cases: 254 | if tag_name == "non_existent_tag": 255 | try: 256 | check_tag_is_valid(tag_name, tag_value) 257 | raise AssertionError( 258 | f"Expected KeyError for tag '{tag_name}', but no error was raised." 259 | ) 260 | except KeyError: 261 | pass # Expected behavior 262 | else: 263 | result = check_tag_is_valid(tag_name, tag_value) 264 | if result != expected: 265 | raise AssertionError( 266 | f"check_tag_is_valid({tag_name!r}, {tag_value!r}) returned {result}, expected {expected}." 267 | ) 268 | -------------------------------------------------------------------------------- /src/tsbootstrap/tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Suite tests for tsbootstrap package.""" 2 | -------------------------------------------------------------------------------- /src/tsbootstrap/tests/scenarios/__init__.py: -------------------------------------------------------------------------------- 1 | """Test scenarios for estimators.""" 2 | -------------------------------------------------------------------------------- /src/tsbootstrap/tests/scenarios/scenarios.py: -------------------------------------------------------------------------------- 1 | """Testing utility to play back usage scenarios for estimators. 2 | 3 | Contains TestScenario class which applies method/args subsequently 4 | """ 5 | 6 | # copied from sktime. Should be jointly refactored to scikit-base. 7 | 8 | __author__ = ["fkiraly"] 9 | 10 | __all__ = ["TestScenario"] 11 | 12 | 13 | from copy import deepcopy 14 | from inspect import isclass 15 | 16 | 17 | class TestScenario: 18 | """Class to run pre-defined method execution scenarios for objects. 
19 | 20 | Parameters 21 | ---------- 22 | args : dict of dict, default = None 23 | dict of argument dicts to be used in methods 24 | names for keys need not equal names of methods these are used in 25 | but scripted method will look at key with same name as default 26 | must be passed to constructor, set in a child class 27 | or dynamically created in get_args 28 | default_method_sequence : list of str, default = None 29 | default sequence for methods to be called 30 | optional, if given, default method sequence to use in `run` 31 | if not provided, at least one of the sequence arguments must be passed in `run` 32 | or default_arg_sequence must be provided 33 | default_arg_sequence : list of str, default = None 34 | default sequence of keys for keyword argument dicts to be used 35 | names for keys need not equal names of methods 36 | if not provided, at least one of the sequence arguments must be passed in `run` 37 | or default_method_sequence must be provided 38 | 39 | Methods 40 | ------- 41 | run(obj, args=None, default_method_sequence=None) 42 | Run a call(args) scenario on obj, and retrieve method outputs. 43 | is_applicable(obj) 44 | Check whether scenario is applicable to obj. 45 | get_args(key, obj) 46 | Dynamically create args for call defined by key and obj. 47 | Defaults to self.args[key] if not overridden. 48 | """ 49 | 50 | def __init__( 51 | self, 52 | args=None, 53 | default_method_sequence=None, 54 | default_arg_sequence=None, 55 | ): 56 | if default_method_sequence is not None: 57 | self.default_method_sequence = _check_list_of_str( 58 | default_method_sequence 59 | ) 60 | elif not hasattr(self, "default_method_sequence"): 61 | self.default_method_sequence = None 62 | if default_arg_sequence is not None: 63 | self.default_arg_sequence = _check_list_of_str( 64 | default_arg_sequence 65 | ) 66 | elif not hasattr(self, "default_arg_sequence"): 67 | self.default_arg_sequence = None 68 | if args is not None: 69 | self.args = _check_dict_of_dict(args) 70 | else: 71 | if not hasattr(self, "args"): 72 | raise RuntimeError( 73 | f"{self.__class__.__name__} (scenario class) failed to construct, " 74 | "args must either be given to __init__ or set as an attribute" 75 | ) 76 | _check_dict_of_dict(self.args) 77 | 78 | def get_args(self, key, obj=None, deepcopy_args=True): 79 | """Return args for key. Can be overridden for dynamic arg generation. 80 | 81 | If overridden, must not have any side effects on self.args 82 | e.g., avoid assignments args[key] = x without deepcopying self.args first 83 | 84 | Parameters 85 | ---------- 86 | key : str, argument key to construct/retrieve args for 87 | obj : obj, optional, default=None. Object to construct args for. 88 | deepcopy_args : bool, optional, default=True. Whether to deepcopy return. 89 | 90 | Returns 91 | ------- 92 | args : argument dict to be used for a method, keyed by `key` 93 | names for keys need not equal names of methods these are used in 94 | but scripted method will look at key with same name as default 95 | """ 96 | args = self.args.get(key, {}) 97 | if deepcopy_args: 98 | args = deepcopy(args) 99 | return args 100 | 101 | def run( 102 | self, 103 | obj, 104 | method_sequence=None, 105 | arg_sequence=None, 106 | return_all=False, 107 | return_args=False, 108 | deepcopy_return=False, 109 | ): 110 | """Run a call(args) scenario on obj, and retrieve method outputs. 
111 | 112 | Runs a sequence of commands 113 | res_1 = obj.method_1(**args_1) 114 | res_2 = obj.method_2(**args_2) 115 | etc, where method_i is method_sequence[i], 116 | and args_i is self.args[arg_sequence[i]] 117 | and returns results. Args are passed as deepcopy to avoid side effects. 118 | 119 | if method_i is __init__ (a constructor), 120 | obj is changed to obj.__init__(**args_i) from the next line on 121 | 122 | Parameters 123 | ---------- 124 | obj : class or object with methods in method_sequence 125 | method_sequence : list of str, default = arg_sequence if passed 126 | if arg_sequence is also None, then default = self.default_method_sequence 127 | sequence of method names to be run 128 | arg_sequence : list of str, default = method_sequence if passed 129 | if method_sequence is also None, then default = self.default_arg_sequence 130 | sequence of keys for keyword argument dicts to be used 131 | names for keys need not equal names of methods 132 | return_all : bool, default = False 133 | whether all or only the last result should be returned 134 | if False, only the last result is returned 135 | if True, list of deepcopies of intermediate results is returned 136 | return_args : bool, default = False 137 | whether arguments should also be returned 138 | if False, there is no second return argument 139 | if True, "args_after_call" return argument is returned 140 | deepcopy_return : bool, default = False 141 | whether returns are deepcopied before returned 142 | if True, returns are deepcopies of return 143 | if False, returns are references/assignments, not deepcopies 144 | NOTE: if self is returned (e.g., in fit), and deepcopy_return=False 145 | method calls may continue to have side effects on that return 146 | 147 | Returns 148 | ------- 149 | results : output of the last method call, if return_all = False 150 | list of deepcopies of all outputs, if return_all = True 151 | args_after_call : list of args after method call, only if return_args = True 152 | i-th element is deepcopy of args of i-th method call, after method call 153 | this is possibly subject to side effects by the method 154 | """ 155 | # if both None, fill with defaults if exist 156 | if method_sequence is None and arg_sequence is None: 157 | method_sequence = getattr(self, "default_method_sequence", None) 158 | arg_sequence = getattr(self, "default_arg_sequence", None) 159 | 160 | # if both are still None, raise an error 161 | if method_sequence is None and arg_sequence is None: 162 | raise ValueError( 163 | "at least one of method_sequence, arg_sequence must be not None " 164 | "if no defaults are set in the class" 165 | ) 166 | 167 | # if only one is None, fill one with the other 168 | if method_sequence is None: 169 | method_sequence = _check_list_of_str(arg_sequence) 170 | else: 171 | method_sequence = _check_list_of_str(method_sequence) 172 | if arg_sequence is None: 173 | arg_sequence = _check_list_of_str(method_sequence) 174 | else: 175 | arg_sequence = _check_list_of_str(arg_sequence) 176 | 177 | # check that length of sequences is the same 178 | num_calls = len(arg_sequence) 179 | if not num_calls == len(method_sequence): 180 | raise ValueError( 181 | "arg_sequence and method_sequence must have same length" 182 | ) 183 | 184 | # execute the commands in sequence, report result(s) 185 | results = [] 186 | args_after_call = [] 187 | for i in range(num_calls): 188 | methodname = method_sequence[i] 189 | args = deepcopy(self.get_args(key=arg_sequence[i], obj=obj)) 190 | 191 | if methodname != "__init__": 192 | 
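                # regular method: call it on the current obj with the prepared args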
res = getattr(obj, methodname)(**args) 193 | # if constructor is called, run directly and replace obj 194 | else: 195 | res = obj(**args) if isclass(obj) else type(obj)(**args) 196 | obj = res 197 | 198 | args_after_call += [args] 199 | 200 | if deepcopy_return: 201 | res = deepcopy(res) 202 | 203 | if return_all: 204 | results += [res] 205 | else: 206 | results = res 207 | 208 | if return_args: 209 | return results, args_after_call 210 | else: 211 | return results 212 | 213 | def is_applicable(self, obj): 214 | """Check whether scenario is applicable to obj. 215 | 216 | Abstract method, children should implement. This just returns "true". 217 | 218 | Example for child class: scenario is univariate time series forecasting. 219 | Then, this returns False on multivariate, True on univariate forecasters. 220 | 221 | Parameters 222 | ---------- 223 | obj : class or object to check against scenario 224 | 225 | Returns 226 | ------- 227 | applicable: bool 228 | True if self is applicable to obj, False if not 229 | "applicable" is defined as the implementer chooses, as output of this method 230 | False is typically used as a "skip" flag in unit or integration testing 231 | """ 232 | return True 233 | 234 | 235 | def _check_list_of_str(obj, name="obj"): 236 | """Check whether obj is a list of str. 237 | 238 | Parameters 239 | ---------- 240 | obj : any object, check whether is list of str 241 | name : str, default="obj", name of obj to display in error message 242 | 243 | Returns 244 | ------- 245 | obj, unaltered 246 | 247 | Raises 248 | ------ 249 | TypeError if obj is not list of str 250 | """ 251 | if not isinstance(obj, list) or not all(isinstance(x, str) for x in obj): 252 | raise TypeError(f"{name} must be a list of str") 253 | return obj 254 | 255 | 256 | def _check_dict_of_dict(obj, name="obj"): 257 | """Check whether obj is a dict of dict, with str keys. 258 | 259 | Parameters 260 | ---------- 261 | obj : any object, check whether is dict of dict, with str keys 262 | name : str, default="obj", name of obj to display in error message 263 | 264 | Returns 265 | ------- 266 | obj, unaltered 267 | 268 | Raises 269 | ------ 270 | TypeError if obj is not dict of dict, with str keys 271 | """ 272 | if not ( 273 | isinstance(obj, dict) 274 | and all(isinstance(x, dict) for x in obj.values()) 275 | and all(isinstance(x, str) for x in obj) 276 | ): 277 | raise TypeError(f"{name} must be a dict of dict, with str keys") 278 | return obj 279 | -------------------------------------------------------------------------------- /src/tsbootstrap/tests/scenarios/scenarios_bootstrap.py: -------------------------------------------------------------------------------- 1 | """Test scenarios for bootstrap algorithms. 2 | 3 | Contains TestScenario concrete children to run in tests for bootstrap algorithms. 4 | """ 5 | 6 | __author__ = ["fkiraly"] 7 | 8 | __all__ = ["scenarios_bootstrap"] 9 | 10 | from inspect import isclass 11 | 12 | import numpy as np 13 | from skbase.base import BaseObject 14 | 15 | from tsbootstrap.tests.scenarios.scenarios import TestScenario 16 | 17 | RAND_SEED = 42 18 | 19 | rng = np.random.default_rng(RAND_SEED) 20 | 21 | 22 | class _BootstrapTestScenario(TestScenario, BaseObject): 23 | """Generic test scenario for bootstrap algorithms.""" 24 | 25 | def is_applicable(self, obj): 26 | """Check whether scenario is applicable to obj.
27 | 28 | Parameters 29 | ---------- 30 | obj : class or object to check against scenario 31 | 32 | Returns 33 | ------- 34 | applicable: bool 35 | True if self is applicable to obj, False if not 36 | """ 37 | 38 | def get_tag(obj, tag_name): 39 | if isclass(obj): 40 | return obj.get_class_tag(tag_name) 41 | else: 42 | return obj.get_tag(tag_name) 43 | 44 | def scitype(obj): 45 | type_tag = obj.get_class_tag("object_type", "object") 46 | return type_tag 47 | 48 | if scitype(obj) != "bootstrap": 49 | return False 50 | 51 | is_multivariate = not self.get_tag( 52 | "X_univariate", False, raise_error=False 53 | ) 54 | 55 | obj_can_handle_multivariate = get_tag(obj, "capability:multivariate") 56 | 57 | return not (is_multivariate and not obj_can_handle_multivariate) 58 | 59 | 60 | X_np_uni = rng.random((20, 1)) 61 | X_np_mult = rng.random((20, 2)) 62 | exog_np = rng.random((20, 3)) 63 | 64 | 65 | class BootstrapBasicUnivar(_BootstrapTestScenario): 66 | """Simple call, only endogenous data.""" 67 | 68 | _tags = { 69 | "X_univariate": True, 70 | "exog_present": False, 71 | "return_index": False, 72 | } 73 | 74 | args = {"bootstrap": {"X": X_np_uni}} 75 | default_method_sequence = ["bootstrap", "get_n_bootstraps"] 76 | default_arg_sequence = ["bootstrap", "bootstrap"] 77 | 78 | 79 | class BootstrapExogUnivar(_BootstrapTestScenario): 80 | """Call with endogenous and exogenous data.""" 81 | 82 | _tags = { 83 | "X_univariate": True, 84 | "exog_present": True, 85 | "return_index": False, 86 | } 87 | 88 | args = {"bootstrap": {"X": X_np_uni, "y": exog_np}} 89 | default_method_sequence = ["bootstrap", "get_n_bootstraps"] 90 | default_arg_sequence = ["bootstrap", "bootstrap"] 91 | 92 | 93 | class BootstrapUnivarRetIx(_BootstrapTestScenario): 94 | """Call with endogenous and exogenous data, and query to return index.""" 95 | 96 | _tags = { 97 | "X_univariate": True, 98 | "exog_present": True, 99 | "return_index": True, 100 | } 101 | 102 | args = { 103 | "bootstrap": {"X": X_np_uni, "y": exog_np, "return_indices": True}, 104 | "get_n_bootstraps": {"X": X_np_uni, "y": exog_np}, 105 | } 106 | default_method_sequence = ["bootstrap", "get_n_bootstraps"] 107 | default_arg_sequence = ["bootstrap", "bootstrap"] 108 | 109 | 110 | class BootstrapBasicMultivar(_BootstrapTestScenario): 111 | """Simple call, only endogenous data.""" 112 | 113 | _tags = { 114 | "X_univariate": False, 115 | "exog_present": False, 116 | "return_index": False, 117 | } 118 | 119 | args = {"bootstrap": {"X": X_np_mult}} 120 | default_method_sequence = ["bootstrap", "get_n_bootstraps"] 121 | default_arg_sequence = ["bootstrap", "bootstrap"] 122 | 123 | 124 | class BootstrapExogMultivar(_BootstrapTestScenario): 125 | """Call with endogenous and exogenous data.""" 126 | 127 | _tags = { 128 | "X_univariate": False, 129 | "exog_present": True, 130 | "return_index": False, 131 | } 132 | 133 | args = {"bootstrap": {"X": X_np_mult, "y": exog_np}} 134 | default_method_sequence = ["bootstrap", "get_n_bootstraps"] 135 | default_arg_sequence = ["bootstrap", "bootstrap"] 136 | 137 | 138 | class BootstrapMultivarRetIx(_BootstrapTestScenario): 139 | """Call with endogenous and exogenous data, and query to return index.""" 140 | 141 | _tags = { 142 | "X_univariate": False, 143 | "exog_present": True, 144 | "return_index": True, 145 | } 146 | 147 | args = { 148 | "bootstrap": {"X": X_np_mult, "y": exog_np, "return_indices": True}, 149 | "get_n_bootstraps": {"X": X_np_mult, "y": exog_np}, 150 | } 151 | default_method_sequence = ["bootstrap", 
"get_n_bootstraps"] 152 | default_arg_sequence = ["bootstrap", "bootstrap"] 153 | 154 | 155 | scenarios_bootstrap = [ 156 | BootstrapBasicUnivar, 157 | BootstrapExogUnivar, 158 | BootstrapUnivarRetIx, 159 | BootstrapBasicMultivar, 160 | BootstrapExogMultivar, 161 | BootstrapMultivarRetIx, 162 | ] 163 | -------------------------------------------------------------------------------- /src/tsbootstrap/tests/scenarios/scenarios_getter.py: -------------------------------------------------------------------------------- 1 | """Retrieval utility for test scenarios.""" 2 | 3 | # copied from sktime. Should be jointly refactored to scikit-base. 4 | 5 | __author__ = ["fkiraly"] 6 | 7 | __all__ = ["retrieve_scenarios"] 8 | 9 | 10 | from inspect import isclass 11 | 12 | from tsbootstrap.tests.scenarios.scenarios_bootstrap import scenarios_bootstrap 13 | 14 | scenarios = {} 15 | scenarios["bootstrap"] = scenarios_bootstrap 16 | 17 | 18 | def retrieve_scenarios(obj, filter_tags=None): 19 | """Retrieve test scenarios for obj, or by estimator scitype string. 20 | 21 | Exactly one of the arguments obj, estimator_type must be provided. 22 | 23 | Parameters 24 | ---------- 25 | obj : class or object, or string, or list of str. 26 | Which kind of estimator/object to retrieve scenarios for. 27 | If object, must be a class or object inheriting from BaseObject. 28 | If string(s), must be in registry.BASE_CLASS_REGISTER (first col) 29 | for instance 'classifier', 'regressor', 'transformer', 'forecaster' 30 | filter_tags: dict of (str or list of str), default=None 31 | subsets the returned objectss as follows: 32 | each key/value pair is statement in "and"/conjunction 33 | key is tag name to sub-set on 34 | value str or list of string are tag values 35 | condition is "key must be equal to value, or in set(value)" 36 | 37 | Returns 38 | ------- 39 | scenarios : list of objects, instances of BaseScenario 40 | """ 41 | # if class, get scitypes from inference; otherwise, str or list of str 42 | if not isinstance(obj, str): 43 | if isclass(obj): 44 | if hasattr(obj, "get_class_tag"): 45 | estimator_type = obj.get_class_tag("object_type", "object") 46 | else: 47 | estimator_type = "object" 48 | else: 49 | if hasattr(obj, "get_tag"): 50 | estimator_type = obj.get_tag("object_type", "object", False) 51 | else: 52 | estimator_type = "object" 53 | else: 54 | estimator_type = obj 55 | 56 | # coerce to list, ensure estimator_type is list of str 57 | if not isinstance(estimator_type, list): 58 | estimator_type = [estimator_type] 59 | 60 | # now loop through types and retrieve scenarios 61 | scenarios_for_type = [] 62 | for est_type in estimator_type: 63 | scens = scenarios.get(est_type) 64 | if scens is not None: 65 | scenarios_for_type += scenarios.get(est_type) 66 | 67 | # instantiate all scenarios by calling constructor 68 | scenarios_for_type = [x() for x in scenarios_for_type] 69 | 70 | # if obj was an object, filter to applicable scenarios 71 | if not isinstance(obj, str) and not isinstance(obj, list): 72 | scenarios_for_type = [ 73 | x for x in scenarios_for_type if x.is_applicable(obj) 74 | ] 75 | 76 | if filter_tags is not None: 77 | scenarios_for_type = [ 78 | scen 79 | for scen in scenarios_for_type 80 | if _check_tag_cond(scen, filter_tags) 81 | ] 82 | 83 | return scenarios_for_type 84 | 85 | 86 | def _check_tag_cond(obj, filter_tags=None): 87 | """Check whether object satisfies filter_tags condition. 
88 | 89 | Parameters 90 | ---------- 91 | obj: object inheriting from sktime BaseObject 92 | filter_tags: dict of (str or list of str), default=None 93 | subsets the returned objectss as follows: 94 | each key/value pair is statement in "and"/conjunction 95 | key is tag name to sub-set on 96 | value str or list of string are tag values 97 | condition is "key must be equal to value, or in set(value)" 98 | 99 | Returns 100 | ------- 101 | cond_sat: bool, whether estimator satisfies condition in filter_tags 102 | """ 103 | if not isinstance(filter_tags, dict): 104 | raise TypeError("filter_tags must be a dict") 105 | 106 | cond_sat = True 107 | 108 | for key, value in filter_tags.items(): 109 | if not isinstance(value, list): 110 | value = [value] 111 | cond_sat = cond_sat and obj.get_class_tag(key) in set(value) 112 | 113 | return cond_sat 114 | -------------------------------------------------------------------------------- /src/tsbootstrap/tests/test_all_bootstraps.py: -------------------------------------------------------------------------------- 1 | """Automated tests based on the skbase test suite template.""" 2 | 3 | import inspect 4 | 5 | import numpy as np 6 | import pytest 7 | from skbase.testing import QuickTester 8 | 9 | from tsbootstrap.tests.test_all_estimators import ( 10 | BaseFixtureGenerator, 11 | PackageConfig, 12 | ) 13 | 14 | 15 | class TestAllBootstraps(PackageConfig, BaseFixtureGenerator, QuickTester): 16 | """Generic tests for all bootstrap algorithms in tsbootstrap.""" 17 | 18 | # class variables which can be overridden by descendants 19 | # ------------------------------------------------------ 20 | 21 | # which object types are generated; None=all, or class (passed to all_objects) 22 | object_type_filter = "bootstrap" 23 | 24 | def test_class_signature(self, object_class): 25 | """Check constraints on class init signature. 26 | 27 | Tests that: 28 | 29 | * the first parameter is n_bootstraps, with default 10 30 | * all parameters have defaults 31 | """ 32 | init_signature = inspect.signature(object_class.__init__) 33 | 34 | # Consider the constructor parameters excluding 'self' 35 | param_names_in_order = [ 36 | p.name 37 | for p in init_signature.parameters.values() 38 | if p.name != "self" and p.kind != p.VAR_KEYWORD 39 | ] 40 | 41 | param_defaults = object_class.get_param_defaults() 42 | 43 | # test that all parameters have defaults 44 | params_without_default = [ 45 | param 46 | for param in param_names_in_order 47 | if param not in param_defaults 48 | ] 49 | 50 | assert len(params_without_default) == 0, ( 51 | f"All parameters of bootstraps must have default values. " 52 | f"Init parameters without default values: {params_without_default}. " 53 | ) 54 | 55 | # test that first parameter is n_bootstraps, with default 10 56 | assert param_names_in_order[0] == "n_bootstraps" 57 | assert param_defaults["n_bootstraps"] == 10 58 | 59 | def test_n_bootstraps(self, object_instance): 60 | """Tests handling of n_bootstraps parameter.""" 61 | cls_name = object_instance.__class__.__name__ 62 | 63 | params = object_instance.get_params() 64 | 65 | if "n_bootstraps" not in params: 66 | raise ValueError( 67 | f"{cls_name} is a bootstrap algorithm and must have " 68 | "n_bootstraps parameter, but it does not." 
69 | ) 70 | 71 | n_bootstraps = params["n_bootstraps"] 72 | 73 | get_n_bootstraps = object_instance.get_n_bootstraps() 74 | 75 | if not get_n_bootstraps == n_bootstraps: 76 | raise ValueError( 77 | f"{cls_name}.get_n_bootstraps() returned {get_n_bootstraps}, " 78 | f"but n_bootstraps parameter is {n_bootstraps}. " 79 | "These should be equal." 80 | ) 81 | 82 | def test_bootstrap_input_output_contract(self, object_instance, scenario): 83 | """Tests that output of bootstrap method is as specified.""" 84 | import types 85 | 86 | cls_name = object_instance.__class__.__name__ 87 | 88 | result = scenario.run(object_instance, method_sequence=["bootstrap"]) 89 | 90 | if not isinstance(result, types.GeneratorType): 91 | raise TypeError( 92 | f"{cls_name}.bootstrap did not return a generator, " 93 | f"but instead returned {type(result)}." 94 | ) 95 | result = list(result) 96 | 97 | n_timepoints, n_vars = scenario.args["bootstrap"]["X"].shape 98 | n_bs_expected = object_instance.get_params()["n_bootstraps"] 99 | 100 | # if return_index=True, result is a tuple of (dataframe, index) 101 | # results are generators, so we need to convert to list 102 | if scenario.get_tag("return_index", False): 103 | if not all(isinstance(x, tuple) for x in result): 104 | raise TypeError( 105 | f"{cls_name}.bootstrap did not return a generator of tuples, " 106 | f"but instead returned {[type(x) for x in result]}." 107 | ) 108 | if not all(len(x) == 2 for x in result): 109 | raise ValueError( 110 | f"{cls_name}.bootstrap did not return a generator of tuples, " 111 | f"but instead returned {[len(x) for x in result]}." 112 | ) 113 | 114 | bss = [x[0] for x in result] 115 | index = [x[1] for x in result] 116 | 117 | else: 118 | bss = result 119 | 120 | if not len(bss) == n_bs_expected: 121 | raise ValueError( 122 | f"{cls_name}.bootstrap did not yield the expected number of " 123 | f"bootstrap samples. Expected {n_bs_expected}, but got {len(bss)}." 124 | ) 125 | 126 | if not all(isinstance(bs, np.ndarray) for bs in bss): 127 | raise ValueError( 128 | f"{cls_name}.bootstrap must yield numpy.ndarray, " 129 | f"but yielded {[type(bs) for bs in bss]} instead." 130 | "Not all bootstraps are numpy arrays." 131 | ) 132 | 133 | if not all(bs.ndim == 2 for bs in bss): 134 | print([bs.shape for bs in bss]) 135 | raise ValueError( 136 | f"{cls_name}.bootstrap yielded arrays with unexpected number of dimensions. All bootstrap samples should have 2 dimensions." 137 | ) 138 | 139 | if not all(bs.shape[0] == n_timepoints for bs in bss): 140 | raise ValueError( 141 | f"{cls_name}.bootstrap yielded arrays unexpected length," 142 | f" {[bs.shape[0] for bs in bss]}. " 143 | f"All bootstrap samples should have the same, " 144 | f"expected length: {n_timepoints}." 145 | ) 146 | if not all(bs.shape[1] == n_vars for bs in bss): 147 | raise ValueError( 148 | f"{cls_name}.bootstrap yielded arrays with unexpected number of " 149 | f"variables, {[bs.shape[1] for bs in bss]}. " 150 | "All bootstrap samples should have the same, " 151 | f"expected number of variables: {n_vars}." 152 | ) 153 | 154 | if scenario.get_tag("return_index", False): 155 | if not all(isinstance(ix, np.ndarray) for ix in index): 156 | raise ValueError( 157 | f"{cls_name}.bootstrap did not return a generator of tuples, " 158 | f"but instead returned {[type(ix) for ix in index]}." 159 | ) 160 | if not all(ix.ndim == 1 for ix in index): 161 | raise ValueError( 162 | f"{cls_name}.bootstrap yielded arrays with unexpected number of " 163 | "dimensions. 
All bootstrap samples should have 1 dimension." 164 | ) 165 | if not all(ix.shape[0] == n_timepoints for ix in index): 166 | raise ValueError( 167 | f"{cls_name}.bootstrap yielded arrays unexpected length," 168 | f" {[ix.shape[0] for ix in index]}. " 169 | f"All bootstrap samples should have the same, " 170 | f"expected length: {n_timepoints}." 171 | ) 172 | 173 | @pytest.mark.parametrize("test_ratio", [0.2, 0.0, 0.314, 0]) 174 | def test_bootstrap_test_ratio(self, object_instance, scenario, test_ratio): 175 | """Tests that the passing bootstrap test ratio has specified effect.""" 176 | cls_name = object_instance.__class__.__name__ 177 | 178 | bs_kwargs = scenario.args["bootstrap"] 179 | result = object_instance.bootstrap(test_ratio=test_ratio, **bs_kwargs) 180 | result = list(result) 181 | 182 | n_timepoints, n_vars = bs_kwargs["X"].shape 183 | n_bs_expected = object_instance.get_params()["n_bootstraps"] 184 | 185 | expected_length = np.floor(n_timepoints * (1 - test_ratio)).astype(int) 186 | 187 | # if return_index=True, result is a tuple of (dataframe, index) 188 | # results are generators, so we need to convert to list 189 | if scenario.get_tag("return_index", False): 190 | if not all(isinstance(x, tuple) for x in result): 191 | raise TypeError( 192 | f"{cls_name}.bootstrap did not return a generator of tuples, " 193 | f"but instead returned {[type(x) for x in result]}." 194 | ) 195 | if not all(len(x) == 2 for x in result): 196 | raise ValueError( 197 | f"{cls_name}.bootstrap did not return a generator of tuples, " 198 | f"but instead returned {[len(x) for x in result]}." 199 | ) 200 | 201 | bss = [x[0] for x in result] 202 | index = [x[1] for x in result] 203 | 204 | else: 205 | bss = list(result) 206 | 207 | if not len(bss) == n_bs_expected: 208 | raise ValueError( 209 | f"{cls_name}.bootstrap did not yield the expected number of " 210 | f"bootstrap samples. Expected {n_bs_expected}, but got {len(bss)}." 211 | ) 212 | 213 | if not all(isinstance(bs, np.ndarray) for bs in bss): 214 | raise ValueError( 215 | f"{cls_name}.bootstrap must yield numpy.ndarray, " 216 | f"but yielded {[type(bs) for bs in bss]} instead." 217 | "Not all bootstraps are numpy arrays." 218 | ) 219 | 220 | if not all(bs.ndim == 2 for bs in bss): 221 | print([bs.shape for bs in bss]) 222 | raise ValueError( 223 | f"{cls_name}.bootstrap yielded arrays with unexpected number of dimensions. All bootstrap samples should have 2 dimensions." 224 | ) 225 | 226 | if not all(bs.shape[0] == expected_length for bs in bss): 227 | raise ValueError( 228 | f"{cls_name}.bootstrap yielded arrays unexpected length," 229 | f" {[bs.shape[0] for bs in bss]}. " 230 | f"All bootstrap samples should have the same, " 231 | f"expected length: {expected_length}." 232 | ) 233 | if not all(bs.shape[1] == n_vars for bs in bss): 234 | raise ValueError( 235 | f"{cls_name}.bootstrap yielded arrays with unexpected number of " 236 | f"variables, {[bs.shape[1] for bs in bss]}. " 237 | "All bootstrap samples should have the same, " 238 | f"expected number of variables: {n_vars}." 239 | ) 240 | 241 | if scenario.get_tag("return_index", False): 242 | if not all(isinstance(ix, np.ndarray) for ix in index): 243 | raise ValueError( 244 | f"{cls_name}.bootstrap did not return a generator of tuples, " 245 | f"but instead returned {[type(ix) for ix in index]}." 246 | ) 247 | if not all(ix.ndim == 1 for ix in index): 248 | raise ValueError( 249 | f"{cls_name}.bootstrap yielded arrays with unexpected number of " 250 | "dimensions. 
All bootstrap samples should have 1 dimension." 251 | ) 252 | if not all(ix.shape[0] == expected_length for ix in index): 253 | raise ValueError( 254 | f"{cls_name}.bootstrap yielded arrays unexpected length," 255 | f" {[ix.shape[0] for ix in index]}. " 256 | f"All bootstrap samples should have the same, " 257 | f"expected length: {expected_length}." 258 | ) 259 | -------------------------------------------------------------------------------- /src/tsbootstrap/tests/test_all_estimators.py: -------------------------------------------------------------------------------- 1 | """Automated tests based on the skbase test suite template.""" 2 | 3 | from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator 4 | from skbase.testing import TestAllObjects as _TestAllObjects 5 | 6 | from tsbootstrap.registry import OBJECT_TAG_LIST, all_objects 7 | from tsbootstrap.tests.scenarios.scenarios_getter import retrieve_scenarios 8 | from tsbootstrap.tests.test_switch import run_test_for_class 9 | 10 | # whether to test only estimators from modules that are changed w.r.t. main 11 | # default is False, can be set to True by pytest --only_changed_modules True flag 12 | ONLY_CHANGED_MODULES = False 13 | 14 | # objects temporarily excluded due to known bugs 15 | TEMPORARY_EXCLUDED_OBJECTS = [] # ["StationaryBlockBootstrap"] # see bug #73 16 | 17 | 18 | class PackageConfig: 19 | """Contains package config variables for test classes.""" 20 | 21 | # class variables which can be overridden by descendants 22 | # ------------------------------------------------------ 23 | 24 | # package to search for objects 25 | # expected type: str, package/module name, relative to python environment root 26 | package_name = "tsbootstrap" 27 | 28 | # list of object types (class names) to exclude 29 | # expected type: list of str, str are class names 30 | exclude_objects = ["ClassName"] + TEMPORARY_EXCLUDED_OBJECTS 31 | # exclude classes from extension templates 32 | # exclude classes with known bugs 33 | 34 | # list of valid tags 35 | # expected type: list of str, str are tag names 36 | valid_tags = OBJECT_TAG_LIST 37 | 38 | 39 | class BaseFixtureGenerator(_BaseFixtureGenerator): 40 | """Fixture generator for base testing functionality in sktime. 41 | 42 | Test classes inheriting from this and not overriding pytest_generate_tests 43 | will have estimator and scenario fixtures parametrized out of the box. 
44 | 45 | Descendants can override: 46 | estimator_type_filter: str, class variable; None or scitype string 47 | e.g., "forecaster", "transformer", "classifier", see BASE_CLASS_SCITYPE_LIST 48 | which estimators are being retrieved and tested 49 | fixture_sequence: list of str 50 | sequence of fixture variable names in conditional fixture generation 51 | _generate_[variable]: object methods, all (test_name: str, **kwargs) -> list 52 | generating list of fixtures for fixture variable with name [variable] 53 | to be used in test with name test_name 54 | can optionally use values for fixtures earlier in fixture_sequence, 55 | these must be input as kwargs in a call 56 | is_excluded: static method (test_name: str, est: class) -> bool 57 | whether test with name test_name should be excluded for estimator est 58 | should be used only for encoding general rules, not individual skips 59 | individual skips should go on the EXCLUDED_TESTS list in _config 60 | requires _generate_estimator_class and _generate_estimator_instance as is 61 | _excluded_scenario: static method (test_name: str, scenario) -> bool 62 | whether scenario should be skipped in test with test_name test_name 63 | requires _generate_estimator_scenario as is 64 | 65 | Fixtures parametrized 66 | --------------------- 67 | object_class: estimator inheriting from BaseObject 68 | ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS 69 | object_instance: instance of estimator inheriting from BaseObject 70 | ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS 71 | instances are generated by create_test_instance class method of estimator_class 72 | scenario: instance of TestScenario 73 | ranges over all scenarios returned by retrieve_scenarios 74 | applicable for estimator_class or estimator_instance 75 | """ 76 | 77 | # overrides object retrieval in scikit-base 78 | def _all_objects(self): 79 | """Retrieve list of all object classes of type self.object_type_filter.""" 80 | obj_list = all_objects( 81 | object_types=getattr(self, "object_type_filter", None), 82 | return_names=False, 83 | exclude_objects=self.exclude_objects, 84 | ) 85 | 86 | # run_test_for_class selects the estimators to run 87 | # based on whether they have changed, and whether they have all dependencies 88 | # internally, uses the ONLY_CHANGED_MODULES flag, 89 | # and checks the python env against python_dependencies tag 90 | obj_list = [obj for obj in obj_list if run_test_for_class(obj)] 91 | 92 | def scitype(obj): 93 | type_tag = obj.get_class_tag("object_type", "object") 94 | return type_tag 95 | 96 | # exclude config objects and sampler objects 97 | excluded_types = ["config", "sampler"] 98 | obj_list = [ 99 | obj for obj in obj_list if scitype(obj) not in excluded_types 100 | ] 101 | 102 | return obj_list 103 | 104 | # which sequence the conditional fixtures are generated in 105 | fixture_sequence = [ 106 | "object_class", 107 | "object_instance", 108 | "scenario", 109 | ] 110 | 111 | def _generate_scenario(self, test_name, **kwargs): 112 | """Return estimator test scenario. 
113 | 114 | Fixtures parametrized 115 | --------------------- 116 | scenario: instance of TestScenario 117 | ranges over all scenarios returned by retrieve_scenarios 118 | """ 119 | if "object_class" in kwargs: 120 | obj = kwargs["object_class"] 121 | elif "object_instance" in kwargs: 122 | obj = kwargs["object_instance"] 123 | else: 124 | return [] 125 | 126 | scenarios = retrieve_scenarios(obj) 127 | scenarios = [ 128 | s for s in scenarios if not self._excluded_scenario(test_name, s) 129 | ] 130 | scenario_names = [type(scen).__name__ for scen in scenarios] 131 | 132 | return scenarios, scenario_names 133 | 134 | @staticmethod 135 | def _excluded_scenario(test_name, scenario): 136 | """Skip list generator for scenarios to skip in test_name. 137 | 138 | Arguments 139 | --------- 140 | test_name : str, name of test 141 | scenario : instance of TestScenario, to be used in test 142 | 143 | Returns 144 | ------- 145 | bool, whether scenario should be skipped in test_name 146 | """ 147 | # for now, all scenarios are enabled 148 | # if not scenario.get_tag("is_enabled", False, raise_error=False): 149 | # return True 150 | 151 | return False 152 | 153 | 154 | class TestAllObjects(PackageConfig, BaseFixtureGenerator, _TestAllObjects): 155 | """Generic tests for all objects in the mini package.""" 156 | 157 | # override test_constructor to allow for kwargs 158 | def test_constructor(self, object_class): 159 | """Check that the constructor has sklearn compatible signature and behaviour. 160 | 161 | Overrides the test_constructor method from _TestAllObjects, 162 | in order to allow for the constructor to have kwargs. 163 | """ 164 | try: 165 | # dispatch for remaining test logic 166 | super().test_constructor(object_class) 167 | except AssertionError as e: 168 | if "constructor __init__ of" not in str(e): 169 | raise 170 | -------------------------------------------------------------------------------- /src/tsbootstrap/tests/test_class_register.py: -------------------------------------------------------------------------------- 1 | # copyright: tsbootstrap developers, BSD-3-Clause License (see LICENSE file) 2 | """Registry and dispatcher for test classes. 3 | 4 | Module does not contain tests, only test utilities. 5 | """ 6 | 7 | __author__ = ["fkiraly"] 8 | 9 | from inspect import isclass 10 | 11 | 12 | def get_test_class_registry(): 13 | """Return test class registry. 14 | 15 | Wrapped in a function to avoid circular imports. 16 | 17 | Returns 18 | ------- 19 | testclass_dict : dict 20 | test class registry 21 | keys are scitypes, values are test classes TestAll[Scitype] 22 | """ 23 | from tsbootstrap.tests.test_all_bootstraps import TestAllBootstraps 24 | from tsbootstrap.tests.test_all_estimators import TestAllObjects 25 | 26 | testclass_dict = {} 27 | # every object in tsbootstrap inherits from BaseObject 28 | # "object" tests are run for all objects 29 | testclass_dict["object"] = TestAllObjects 30 | # more specific base classes 31 | # these inherit either from BaseEstimator or BaseObject, 32 | # so also imply estimator and object tests, or only object tests 33 | testclass_dict["bootstrap"] = TestAllBootstraps 34 | 35 | return testclass_dict 36 | 37 | 38 | def get_test_classes_for_obj(obj): 39 | """Get all test classes relevant for an object or estimator. 
40 | 41 | Parameters 42 | ---------- 43 | obj : object or estimator, descendant of sktime BaseObject or BaseEstimator 44 | object or estimator for which to get test classes 45 | 46 | Returns 47 | ------- 48 | test_classes : list of test classes 49 | list of test classes relevant for obj 50 | these are references to the actual classes, not strings 51 | if obj was not a descendant of BaseObject or BaseEstimator, returns empty list 52 | """ 53 | from skbase.base import BaseObject 54 | 55 | def is_object(obj): 56 | """Return whether obj is an estimator class or estimator object.""" 57 | if isclass(obj): 58 | return issubclass(obj, BaseObject) 59 | else: 60 | return isinstance(obj, BaseObject) 61 | 62 | # warning: BaseEstimator does not inherit from BaseObject, 63 | # therefore we need to check both 64 | if not is_object(obj): 65 | return [] 66 | 67 | testclass_dict = get_test_class_registry() 68 | 69 | # we always need to run "object" tests 70 | test_clss = [testclass_dict["object"]] 71 | 72 | try: 73 | obj_scitypes = obj.get_class_tag("object_type") 74 | if not isinstance(obj_scitypes, list): 75 | obj_scitypes = [obj_scitypes] 76 | except Exception: 77 | obj_scitypes = [] 78 | 79 | for obj_scitype in obj_scitypes: 80 | if obj_scitype in testclass_dict: 81 | test_clss += [testclass_dict[obj_scitype]] 82 | 83 | return test_clss 84 | -------------------------------------------------------------------------------- /src/tsbootstrap/tests/test_switch.py: -------------------------------------------------------------------------------- 1 | # copyright: 2 | # tsbootstrap developers, BSD-3-Clause License (see LICENSE file) 3 | # based on utility from sktime of the same name 4 | 5 | """Switch utility for determining whether tests for a class should be run or not.""" 6 | 7 | __author__ = ["fkiraly", "astrogilda"] 8 | 9 | from typing import Any, List, Optional, Union 10 | 11 | from tsbootstrap.utils.dependencies import _check_estimator_dependencies 12 | 13 | 14 | def run_test_for_class(cls: Union[Any, List[Any], tuple]) -> bool: 15 | """ 16 | Determine whether tests should be run for a given class or function based on dependency checks. 17 | 18 | This function evaluates whether the provided class/function or a list of them has all required 19 | soft dependencies present in the current environment. If all dependencies are satisfied, it returns 20 | `True`, indicating that tests should be executed. Otherwise, it returns `False`. 21 | 22 | Parameters 23 | ---------- 24 | cls : Union[Any, List[Any], tuple] 25 | A single class/function or a list/tuple of classes/functions for which to determine 26 | whether tests should be run. Each class/function should be a descendant of `BaseObject` 27 | and have the `get_class_tag` method for dependency retrieval. 28 | 29 | Returns 30 | ------- 31 | bool 32 | `True` if all provided classes/functions have their required dependencies present. 33 | `False` otherwise. 34 | 35 | Raises 36 | ------ 37 | ValueError 38 | If the severity level provided in dependency checks is invalid. 39 | TypeError 40 | If any object in `cls` does not have the `get_class_tag` method or is not a `BaseObject` descendant. 
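    Examples
    --------
    A minimal illustrative sketch, not a doctest; it assumes
    ``MovingBlockBootstrap`` is importable from ``tsbootstrap`` as done
    elsewhere in the test suite::

        from tsbootstrap import MovingBlockBootstrap

        run_test_for_class(MovingBlockBootstrap)    # True if all soft deps are met
        run_test_for_class([MovingBlockBootstrap])  # a list or tuple is also accepted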
41 | """ 42 | # Ensure cls is a list for uniform processing 43 | if not isinstance(cls, (list, tuple)): 44 | cls = [cls] 45 | 46 | # Define the severity level and message for dependency checks 47 | # Set to 'none' to silently return False without raising exceptions or warnings 48 | severity = "none" 49 | msg: Optional[str] = None # No custom message 50 | 51 | # Perform dependency checks for all classes/functions 52 | # If any dependency is not met, the function will return False 53 | # Since severity is 'none', no exceptions or warnings will be raised 54 | try: 55 | all_dependencies_present = _check_estimator_dependencies( 56 | obj=cls, severity=severity, msg=msg 57 | ) 58 | except (ValueError, TypeError): 59 | # Log the error if necessary, or handle it as per testing framework 60 | # For now, we assume that any exception means dependencies are not met 61 | all_dependencies_present = False 62 | 63 | return all_dependencies_present 64 | -------------------------------------------------------------------------------- /src/tsbootstrap/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utilities for tsbootstrap package.""" 2 | 3 | from tsbootstrap.utils.estimator_checks import check_estimator 4 | 5 | __all__ = ["check_estimator"] 6 | -------------------------------------------------------------------------------- /src/tsbootstrap/utils/dependencies.py: -------------------------------------------------------------------------------- 1 | """Utility module for checking soft dependency imports and raising warnings or errors.""" 2 | 3 | __author__ = ["fkiraly", "astrogilda"] 4 | 5 | import logging 6 | from enum import Enum 7 | from typing import Any, List, Optional, Union 8 | 9 | from pydantic import BaseModel, Field, ValidationError, field_validator 10 | from skbase.utils.dependencies import ( 11 | _check_python_version, 12 | _check_soft_dependencies, 13 | ) 14 | 15 | # Configure logging for the module 16 | logger = logging.getLogger(__name__) 17 | logger.addHandler(logging.NullHandler()) 18 | 19 | 20 | class SeverityEnum(str, Enum): 21 | """ 22 | Enumeration for severity levels. 23 | 24 | Attributes 25 | ---------- 26 | ERROR : str 27 | Indicates that a `ModuleNotFoundError` should be raised if dependencies are not met. 28 | WARNING : str 29 | Indicates that a warning should be emitted if dependencies are not met. 30 | NONE : str 31 | Indicates that no action should be taken if dependencies are not met. 32 | """ 33 | 34 | ERROR = "error" 35 | WARNING = "warning" 36 | NONE = "none" 37 | 38 | 39 | def _check_estimator_dependencies( 40 | obj: Union[Any, List[Any], tuple], 41 | severity: Union[str, SeverityEnum] = "error", 42 | msg: Optional[str] = None, 43 | ) -> bool: 44 | """ 45 | Check if an object or list of objects' package and Python requirements are met by the current environment. 46 | 47 | This function serves as a convenience wrapper around `_check_python_version` and `_check_soft_dependencies`, 48 | utilizing the estimator tags `"python_version"` and `"python_dependencies"`. 49 | 50 | Parameters 51 | ---------- 52 | obj : Union[Any, List[Any], tuple] 53 | An object (instance or class) that is a descendant of `BaseObject`, or a list/tuple of such objects. 54 | These objects are checked for compatibility with the current Python environment. 55 | severity : Union[str, SeverityEnum], default="error" 56 | Determines the behavior when incompatibility is detected: 57 | - "error": Raises a `ModuleNotFoundError`. 
58 | - "warning": Emits a warning and returns `False`. 59 | - "none": Silently returns `False` without raising an exception or warning. 60 | msg : Optional[str], default=None 61 | Custom error message to be used in the `ModuleNotFoundError`. 62 | Overrides the default message if provided. 63 | 64 | Returns 65 | ------- 66 | bool 67 | `True` if all objects are compatible with the current environment; `False` otherwise. 68 | 69 | Raises 70 | ------ 71 | ModuleNotFoundError 72 | If `severity` is set to "error" and incompatibility is detected. 73 | ValueError 74 | If an invalid severity level is provided. 75 | TypeError 76 | If `obj` is not a `BaseObject` descendant or a list/tuple thereof. 77 | """ 78 | 79 | # Define an inner Pydantic model for validating input parameters 80 | class DependencyCheckConfig(BaseModel): 81 | """ 82 | Pydantic model for configuring dependency checks. 83 | 84 | Attributes 85 | ---------- 86 | severity : SeverityEnum 87 | Determines the behavior when incompatibility is detected. 88 | msg : Optional[str] 89 | Custom error message to be used in the `ModuleNotFoundError`. 90 | """ 91 | 92 | severity: SeverityEnum = Field( 93 | default=SeverityEnum.ERROR, 94 | description=( 95 | "Determines the behavior when incompatibility is detected.\n" 96 | "- 'error': Raises a `ModuleNotFoundError`.\n" 97 | "- 'warning': Emits a warning and returns `False`.\n" 98 | "- 'none': Silently returns `False` without raising an exception or warning." 99 | ), 100 | ) 101 | msg: Optional[str] = Field( 102 | default=None, 103 | description=( 104 | "Custom error message to be used in the `ModuleNotFoundError`. " 105 | "Overrides the default message if provided." 106 | ), 107 | ) 108 | 109 | @field_validator("severity", mode="before") 110 | @classmethod 111 | def validate_severity( 112 | cls, v: Union[str, SeverityEnum] 113 | ) -> SeverityEnum: 114 | """ 115 | Validate and convert the severity level to SeverityEnum. 116 | 117 | Parameters 118 | ---------- 119 | v : Union[str, SeverityEnum] 120 | The severity level to validate. 121 | 122 | Returns 123 | ------- 124 | SeverityEnum 125 | The validated severity level. 126 | 127 | Raises 128 | ------ 129 | ValueError 130 | If the severity level is not one of the defined Enum members. 131 | """ 132 | if isinstance(v, str): 133 | try: 134 | return SeverityEnum(v.lower()) 135 | except ValueError: 136 | raise ValueError( 137 | f"Invalid severity level '{v}'. Choose {[level.value for level in SeverityEnum]}" 138 | ) from None 139 | elif isinstance(v, SeverityEnum): 140 | return v 141 | else: 142 | raise TypeError( 143 | f"Severity must be a string or an instance of SeverityEnum, got {type(v)}." 144 | ) 145 | 146 | try: 147 | # Instantiate DependencyCheckConfig to validate severity and msg 148 | config = DependencyCheckConfig(severity=severity, msg=msg) # type: ignore 149 | except ValidationError as ve: 150 | # Re-raise as a ValueError with detailed message 151 | raise ValueError(f"Invalid input parameters: {ve}") from ve 152 | 153 | def _check_single_dependency(obj_single: Any) -> bool: 154 | """ 155 | Check dependencies for a single object. 156 | 157 | Parameters 158 | ---------- 159 | obj_single : Any 160 | A single `BaseObject` descendant to check. 161 | 162 | Returns 163 | ------- 164 | bool 165 | `True` if the object is compatible; `False` otherwise. 166 | """ 167 | if not hasattr(obj_single, "get_class_tag"): 168 | raise TypeError( 169 | f"Object {obj_single} does not have 'get_class_tag' method." 
170 | ) 171 | 172 | compatible = True 173 | 174 | # Check Python version compatibility 175 | if not _check_python_version( 176 | obj_single, severity=config.severity.value 177 | ): 178 | compatible = False 179 | message = ( 180 | config.msg or f"Python version incompatible for {obj_single}." 181 | ) 182 | if config.severity == SeverityEnum.ERROR: 183 | raise ModuleNotFoundError(message) 184 | elif config.severity == SeverityEnum.WARNING: 185 | logger.warning(message) 186 | 187 | # Check soft dependencies 188 | pkg_deps = obj_single.get_class_tag("python_dependencies", None) 189 | pkg_alias = obj_single.get_class_tag("python_dependencies_alias", None) 190 | 191 | if pkg_deps: 192 | if not isinstance(pkg_deps, list): 193 | pkg_deps = [pkg_deps] 194 | if not _check_soft_dependencies( 195 | *pkg_deps, 196 | severity=config.severity.value, 197 | obj=obj_single, 198 | package_import_alias=pkg_alias, 199 | ): 200 | compatible = False 201 | message = ( 202 | config.msg or f"Missing dependencies for {obj_single}." 203 | ) 204 | if config.severity == SeverityEnum.ERROR: 205 | raise ModuleNotFoundError(message) 206 | elif config.severity == SeverityEnum.WARNING: 207 | logger.warning(message) 208 | 209 | return compatible 210 | 211 | compatible = True 212 | 213 | # If obj is a list or tuple, iterate and check each element 214 | if isinstance(obj, (list, tuple)): 215 | for item in obj: 216 | try: 217 | item_compatible = _check_single_dependency(item) 218 | compatible = compatible and item_compatible 219 | # Early exit if incompatibility detected and severity is ERROR 220 | if not compatible and config.severity == SeverityEnum.ERROR: 221 | break 222 | except (ModuleNotFoundError, TypeError, ValueError): 223 | if config.severity == SeverityEnum.ERROR: 224 | raise 225 | elif config.severity == SeverityEnum.WARNING: 226 | compatible = False 227 | return compatible 228 | 229 | # Single object check 230 | return _check_single_dependency(obj) 231 | -------------------------------------------------------------------------------- /src/tsbootstrap/utils/estimator_checks.py: -------------------------------------------------------------------------------- 1 | """Estimator checker for extension.""" 2 | 3 | __author__ = ["fkiraly"] 4 | __all__ = ["check_estimator"] 5 | 6 | from skbase.utils.dependencies import _check_soft_dependencies 7 | 8 | 9 | def check_estimator( 10 | estimator, 11 | raise_exceptions=False, 12 | tests_to_run=None, 13 | fixtures_to_run=None, 14 | verbose=True, 15 | tests_to_exclude=None, 16 | fixtures_to_exclude=None, 17 | ): 18 | """Run all tests on one single estimator. 19 | 20 | Tests that are run on estimator: 21 | 22 | * all tests in `test_all_estimators` 23 | * all interface compatibility tests from the module of estimator's scitype 24 | 25 | Parameters 26 | ---------- 27 | estimator : estimator class or estimator instance 28 | raise_exceptions : bool, optional, default=False 29 | whether to return exceptions/failures in the results dict, or raise them 30 | 31 | * if False: returns exceptions in returned `results` dict 32 | * if True: raises exceptions as they occur 33 | 34 | tests_to_run : str or list of str, optional. Default = run all tests. 35 | Names (test/function name string) of tests to run. 36 | sub-sets tests that are run to the tests given here. 37 | fixtures_to_run : str or list of str, optional. Default = run all tests. 38 | pytest test-fixture combination codes, which test-fixture combinations to run. 39 | sub-sets tests and fixtures to run to the list given here. 
40 | If both tests_to_run and fixtures_to_run are provided, runs the *union*, 41 | i.e., all test-fixture combinations for tests in tests_to_run, 42 | plus all test-fixture combinations in fixtures_to_run. 43 | verbose : str, optional, default=True. 44 | whether to print out informative summary of tests run. 45 | tests_to_exclude : str or list of str, names of tests to exclude. default = None 46 | removes tests that should not be run, after subsetting via tests_to_run. 47 | fixtures_to_exclude : str or list of str, fixtures to exclude. default = None 48 | removes test-fixture combinations that should not be run. 49 | This is done after subsetting via fixtures_to_run. 50 | 51 | Returns 52 | ------- 53 | results : dict of results of the tests in self 54 | keys are test/fixture strings, identical as in pytest, e.g., test[fixture] 55 | entries are the string "PASSED" if the test passed, 56 | or the exception raised if the test did not pass 57 | returned only if all tests pass, or raise_exceptions=False 58 | 59 | Raises 60 | ------ 61 | if raise_exceptions=True, 62 | raises any exception produced by the tests directly 63 | 64 | Examples 65 | -------- 66 | >>> from tsbootstrap import MovingBlockBootstrap 67 | >>> from tsbootstrap.utils import check_estimator 68 | >>> 69 | >>> check_estimator(MovingBlockBootstrap, raise_exceptions=True) 70 | ... 71 | """ 72 | msg = ( 73 | "check_estimator is a testing utility for developers, and " 74 | "requires pytest to be present " 75 | "in the python environment, but pytest was not found. " 76 | "pytest is a developer dependency and not included in the base " 77 | "sktime installation. Please run: `pip install pytest` to " 78 | "install the pytest package. " 79 | "To install tsbootstrap with all developer dependencies, run:" 80 | " `pip install tsbootstrap[dev]`" 81 | ) 82 | # _check_soft_dependencies("pytest", msg=msg) 83 | _check_soft_dependencies("pytest") 84 | 85 | from tsbootstrap.tests.test_class_register import get_test_classes_for_obj 86 | 87 | test_clss_for_est = get_test_classes_for_obj(estimator) 88 | 89 | results = {} 90 | 91 | for test_cls in test_clss_for_est: 92 | test_cls_results = test_cls().run_tests( 93 | obj=estimator, 94 | raise_exceptions=raise_exceptions, 95 | tests_to_run=tests_to_run, 96 | fixtures_to_run=fixtures_to_run, 97 | tests_to_exclude=tests_to_exclude, 98 | fixtures_to_exclude=fixtures_to_exclude, 99 | ) 100 | results.update(test_cls_results) 101 | 102 | failed_tests = [key for key in results if results[key] != "PASSED"] 103 | if len(failed_tests) > 0: 104 | msg = failed_tests 105 | msg = ["FAILED: " + x for x in msg] 106 | msg = "\n".join(msg) 107 | else: 108 | msg = "All tests PASSED!" 109 | 110 | if verbose: 111 | # printing is an intended feature, for console usage and interactive debugging 112 | print(msg) # noqa: T001 113 | 114 | return results 115 | -------------------------------------------------------------------------------- /src/tsbootstrap/utils/odds_and_ends.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import contextmanager 3 | from numbers import Integral 4 | from typing import Union 5 | 6 | import numpy as np 7 | from numpy.random import Generator 8 | 9 | from tsbootstrap.utils.types import RngTypes 10 | 11 | 12 | def time_series_split(X: np.ndarray, test_ratio: float): 13 | """ 14 | Splits a given time series into training and test sets. 15 | 16 | Parameters 17 | ---------- 18 | X : np.ndarray 19 | The input time series. 
20 | test_ratio : float 21 | The ratio of the test set size to the total size of the series. 22 | 23 | Returns 24 | ------- 25 | Tuple[np.ndarray, np.ndarray] 26 | A tuple containing the training set and the test set. 27 | """ 28 | # Validate test_ratio 29 | if not 0 <= test_ratio <= 1: 30 | raise ValueError( 31 | f"Test ratio must be between 0 and 1. Got {test_ratio}" 32 | ) 33 | 34 | split_index = int(len(X) * (1 - test_ratio)) 35 | return X[:split_index], X[split_index:] 36 | 37 | 38 | def check_generator(seed_or_rng: RngTypes, seed_allowed: bool = True) -> Generator: # type: ignore 39 | """Turn seed into a np.random.Generator instance. 40 | 41 | Parameters 42 | ---------- 43 | seed_or_rng : int, Generator, or None 44 | If seed_or_rng is None, return the Generator singleton used by np.random. 45 | If seed_or_rng is an int, return a new Generator instance seeded with seed_or_rng. 46 | If seed_or_rng is already a Generator instance, return it. 47 | Otherwise raise ValueError. 48 | 49 | seed_allowed : bool, optional 50 | If True, seed_or_rng can be an int. If False, seed_or_rng cannot be an int. 51 | Default is True. 52 | 53 | Returns 54 | ------- 55 | Generator 56 | A numpy.random.Generator instance. 57 | 58 | Raises 59 | ------ 60 | ValueError 61 | If seed_or_rng is not None, an int, or a numpy.random.Generator instance. 62 | If seed_or_rng is an int and seed_allowed is False. 63 | If seed_or_rng is an int and it is not between 0 and 2**32 - 1. 64 | """ 65 | if seed_or_rng is None: 66 | return np.random.default_rng() 67 | if isinstance(seed_or_rng, Generator): 68 | return seed_or_rng 69 | if seed_allowed and isinstance(seed_or_rng, Integral): 70 | if not (0 <= seed_or_rng < 2**32): # type: ignore 71 | raise ValueError( 72 | f"The random seed must be between 0 and 2**32 - 1. Got {seed_or_rng}" 73 | ) 74 | return np.random.default_rng(seed_or_rng) # type: ignore 75 | 76 | raise ValueError( 77 | f"{seed_or_rng} cannot be used to seed a numpy.random.Generator instance" 78 | ) 79 | 80 | 81 | def generate_random_indices( 82 | num_samples: Integral, rng: RngTypes = None # type: ignore 83 | ) -> np.ndarray: 84 | """ 85 | Generate random indices with replacement. 86 | 87 | This function generates random indices from 0 to `num_samples-1` with replacement. 88 | The generated indices can be used for bootstrap sampling, etc. 89 | 90 | Parameters 91 | ---------- 92 | num_samples : Integral 93 | The number of samples for which the indices are to be generated. 94 | This must be a positive integer. 95 | rng : Integral, optional 96 | The seed for the random number generator. If provided, this must be a non-negative integer. 97 | Default is None, which does not set the numpy's random seed and the results will be non-deterministic. 98 | 99 | Returns 100 | ------- 101 | np.ndarray 102 | A numpy array of shape (`num_samples`,) containing randomly generated indices. 103 | 104 | Raises 105 | ------ 106 | ValueError 107 | If `num_samples` is not a positive integer or if `random_seed` is provided and 108 | it is not a non-negative integer. 
109 | 110 | Examples 111 | -------- 112 | >>> generate_random_indices(5, rng=0) 113 | array([4, 0, 3, 3, 3]) 114 | >>> generate_random_indices(5) 115 | array([2, 1, 4, 2, 0]) # random 116 | """ 117 | # Check types and values of num_samples and rng 118 | from tsbootstrap.utils.validate import validate_integers 119 | 120 | validate_integers(num_samples, min_value=1) # type: ignore 121 | rng = check_generator(rng, seed_allowed=True) 122 | 123 | # Generate random indices with replacement 124 | in_bootstrap_indices = rng.choice( 125 | np.arange(num_samples), size=num_samples, replace=True # type: ignore 126 | ) 127 | 128 | return in_bootstrap_indices 129 | 130 | 131 | @contextmanager 132 | def suppress_output(verbose: int = 2): 133 | """A context manager for controlling the suppression of stdout and stderr. 134 | 135 | Parameters 136 | ---------- 137 | verbose : int, optional 138 | Verbosity level controlling suppression. 139 | 2 - No suppression (default) 140 | 1 - Suppress stdout only 141 | 0 - Suppress both stdout and stderr 142 | 143 | Returns 144 | ------- 145 | None 146 | 147 | Examples 148 | -------- 149 | with suppress_output(verbose=1): 150 | print('This will not be printed to stdout') 151 | """ 152 | # No suppression required 153 | if verbose == 2: 154 | yield 155 | return 156 | 157 | # Open null files as needed 158 | null_fds = [ 159 | os.open(os.devnull, os.O_RDWR) for _ in range(2 if verbose == 0 else 1) 160 | ] 161 | # Save the actual stdout (1) and possibly stderr (2) file descriptors. 162 | save_fds = [os.dup(1), os.dup(2)] if verbose == 0 else [os.dup(1)] 163 | try: 164 | # Assign the null pointers as required 165 | os.dup2(null_fds[0], 1) 166 | if verbose == 0: 167 | os.dup2(null_fds[1], 2) 168 | yield 169 | finally: 170 | # Re-assign the saved stdout/stderr back to file descriptors 1 and 2 171 | for fd, save_fd in zip([1, 2], save_fds): 172 | os.dup2(save_fd, fd) 173 | # Close the null files and saved file descriptors 174 | for fd in null_fds + save_fds: 175 | os.close(fd) 176 | 177 | 178 | def _check_nan_inf_locations( 179 | a: np.ndarray, b: np.ndarray, check_same: bool 180 | ) -> bool: 181 | """ 182 | Check the locations of NaNs and Infs in both arrays. 183 | 184 | Parameters 185 | ---------- 186 | a, b : np.ndarray 187 | The arrays to be compared. 188 | check_same : bool 189 | If True, checks if NaNs and Infs are in the same locations. 190 | 191 | Returns 192 | ------- 193 | bool 194 | True if locations do not match and check_same is False, otherwise False. 195 | 196 | Raises 197 | ------ 198 | ValueError 199 | If check_same is True and the arrays have NaNs or Infs in different locations. 200 | """ 201 | a_nan_locs = np.isnan(a) 202 | b_nan_locs = np.isnan(b) 203 | a_inf_locs = np.isinf(a) 204 | b_inf_locs = np.isinf(b) 205 | 206 | if not np.array_equal(a_nan_locs, b_nan_locs) or not np.array_equal( 207 | a_inf_locs, b_inf_locs 208 | ): 209 | if check_same: 210 | raise ValueError("NaNs or Infs in different locations") 211 | else: 212 | return True 213 | 214 | return False 215 | 216 | 217 | def _check_inf_signs(a: np.ndarray, b: np.ndarray, check_same: bool) -> bool: 218 | """ 219 | Check the signs of Infs in both arrays. 220 | 221 | Parameters 222 | ---------- 223 | a, b : np.ndarray 224 | The arrays to be compared. 225 | check_same : bool 226 | If True, checks if Infs have the same signs. 227 | 228 | Returns 229 | ------- 230 | bool 231 | True if signs do not match and check_same is False, otherwise False. 
232 | 233 | Raises 234 | ------ 235 | ValueError 236 | If check_same is True and the arrays have Infs with different signs. 237 | """ 238 | a_inf_locs = np.isinf(a) 239 | b_inf_locs = np.isinf(b) 240 | 241 | if not np.array_equal(np.sign(a[a_inf_locs]), np.sign(b[b_inf_locs])): 242 | if check_same: 243 | raise ValueError("Infs with different signs") 244 | else: 245 | return True 246 | 247 | return False 248 | 249 | 250 | def _check_close_values( 251 | a: np.ndarray, b: np.ndarray, rtol: float, atol: float, check_same: bool 252 | ) -> bool: 253 | """ 254 | Check that the finite values in the arrays are close. 255 | 256 | Parameters 257 | ---------- 258 | a, b : np.ndarray 259 | The arrays to be compared. 260 | rtol : float 261 | The relative tolerance parameter for the np.allclose function. 262 | atol : float 263 | The absolute tolerance parameter for the np.allclose function. 264 | check_same : bool 265 | If True, checks if the arrays are almost equal. 266 | 267 | Returns 268 | ------- 269 | bool 270 | True if values are not close and check_same is False, otherwise False. 271 | 272 | Raises 273 | ------ 274 | ValueError 275 | If check_same is True and the arrays are not almost equal. 276 | """ 277 | a_nan_locs = np.isnan(a) 278 | b_nan_locs = np.isnan(b) 279 | a_inf_locs = np.isinf(a) 280 | b_inf_locs = np.isinf(b) 281 | a_masked = np.ma.masked_where(a_nan_locs | a_inf_locs, a) 282 | b_masked = np.ma.masked_where(b_nan_locs | b_inf_locs, b) 283 | 284 | if check_same: 285 | if not np.allclose(a_masked, b_masked, rtol=rtol, atol=atol): 286 | raise ValueError("Arrays are not almost equal") 287 | else: 288 | if np.any(~np.isclose(a_masked, b_masked, rtol=rtol, atol=atol)): 289 | return True 290 | 291 | return False 292 | 293 | 294 | def assert_arrays_compare( 295 | a: np.ndarray, b: np.ndarray, rtol=1e-5, atol=1e-8, check_same=True 296 | ) -> bool: 297 | """ 298 | Assert that two arrays are almost equal. 299 | 300 | This function compares two arrays for equality, allowing for NaNs and Infs in the arrays. 301 | The arrays are considered equal if the following conditions are satisfied: 302 | 1. The locations of NaNs and Infs in both arrays are the same. 303 | 2. The signs of the infinite values in both arrays are the same. 304 | 3. The finite values are almost equal. 305 | 306 | Parameters 307 | ---------- 308 | a, b : np.ndarray 309 | The arrays to be compared. 310 | rtol : float, optional 311 | The relative tolerance parameter for the np.allclose function. 312 | Default is 1e-5. 313 | atol : float, optional 314 | The absolute tolerance parameter for the np.allclose function. 315 | Default is 1e-8. 316 | check_same : bool, optional 317 | If True, raise an AssertionError if the arrays are not almost equal. 318 | If False, return True if the arrays are not almost equal and False otherwise. 319 | Default is True. 320 | 321 | Returns 322 | ------- 323 | bool 324 | If check_same is False, returns True if the arrays are not almost equal and False otherwise. 325 | If check_same is True, returns True if the arrays are almost equal and False otherwise. 326 | 327 | Raises 328 | ------ 329 | AssertionError 330 | If check_same is True and the arrays are not almost equal. 331 | ValueError 332 | If check_same is True and the arrays have NaNs or Infs in different locations. 333 | If check_same is True and the arrays have Infs with different signs. 
334 | """ 335 | if _check_nan_inf_locations(a, b, check_same): 336 | return not check_same 337 | if _check_inf_signs(a, b, check_same): 338 | return not check_same 339 | if _check_close_values(a, b, rtol, atol, check_same): 340 | return not check_same 341 | 342 | return not check_same if not check_same else True 343 | -------------------------------------------------------------------------------- /src/tsbootstrap/utils/types.py: -------------------------------------------------------------------------------- 1 | # Use future annotations for better handling of forward references. 2 | from __future__ import annotations 3 | 4 | import sys 5 | from enum import Enum 6 | from numbers import Integral 7 | from typing import Any, List, Literal, Optional, Union 8 | 9 | from numpy.random import Generator 10 | from packaging.specifiers import SpecifierSet 11 | 12 | # Define model and block compressor types using Literal for clearer enum-style typing. 13 | ModelTypesWithoutArch = Literal["ar", "arima", "sarima", "var"] 14 | 15 | ModelTypes = Literal["ar", "arima", "sarima", "var", "arch"] 16 | 17 | BlockCompressorTypes = Literal[ 18 | "first", 19 | "middle", 20 | "last", 21 | "mean", 22 | "mode", 23 | "median", 24 | "kmeans", 25 | "kmedians", 26 | "kmedoids", 27 | ] 28 | 29 | 30 | class DistributionTypes(Enum): 31 | """ 32 | Enumeration of supported distribution types for block length sampling. 33 | """ 34 | 35 | NONE = "none" 36 | POISSON = "poisson" 37 | EXPONENTIAL = "exponential" 38 | NORMAL = "normal" 39 | GAMMA = "gamma" 40 | BETA = "beta" 41 | LOGNORMAL = "lognormal" 42 | WEIBULL = "weibull" 43 | PARETO = "pareto" 44 | GEOMETRIC = "geometric" 45 | UNIFORM = "uniform" 46 | 47 | 48 | # Check Python version for compatibility issues. 49 | sys_version = sys.version.split(" ")[0] 50 | new_typing_available = sys_version in SpecifierSet(">=3.10") 51 | 52 | 53 | def FittedModelTypes() -> tuple: 54 | """ 55 | Return a tuple of fitted model types for use in isinstance checks. 56 | 57 | Returns 58 | ------- 59 | tuple: A tuple containing the result wrapper types for fitted models. 60 | """ 61 | from arch.univariate.base import ARCHModelResult 62 | from statsmodels.tsa.ar_model import AutoRegResultsWrapper 63 | from statsmodels.tsa.arima.model import ARIMAResultsWrapper 64 | from statsmodels.tsa.statespace.sarimax import SARIMAXResultsWrapper 65 | from statsmodels.tsa.vector_ar.var_model import VARResultsWrapper 66 | 67 | fmt = ( 68 | AutoRegResultsWrapper, 69 | ARIMAResultsWrapper, 70 | SARIMAXResultsWrapper, 71 | VARResultsWrapper, 72 | ARCHModelResult, 73 | ) 74 | return fmt 75 | 76 | 77 | # Define complex type conditions using the Python 3.10 union operator if available. 
78 | if new_typing_available: 79 | OrderTypesWithoutNone = Union[ 80 | Integral, 81 | List[Integral], 82 | tuple[Integral, Integral, Integral], 83 | tuple[Integral, Integral, Integral, Integral], 84 | ] 85 | OrderTypes = Optional[OrderTypesWithoutNone] 86 | 87 | RngTypes = Optional[Union[Generator, Integral]] 88 | 89 | else: 90 | OrderTypesWithoutNone = Any 91 | OrderTypes = Any 92 | RngTypes = Any 93 | -------------------------------------------------------------------------------- /tests/_nopytest_tests.py: -------------------------------------------------------------------------------- 1 | """Tests to run without pytest, to check pytest isolation.""" 2 | 3 | from skbase.lookup import all_objects 4 | 5 | # all_objects crawls all modules excepting pytest test files 6 | # if it encounters an unisolated import, it will throw an exception 7 | results = all_objects(package_name="tsbootstrap", modules_to_ignore=["tests"]) 8 | -------------------------------------------------------------------------------- /tests/test_block_length_sampler.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import pytest 4 | from hypothesis import given 5 | from hypothesis import strategies as st 6 | from pydantic import ValidationError 7 | from tsbootstrap import BlockLengthSampler 8 | 9 | 10 | class TestPassingCases: 11 | """ 12 | Test suite for all cases where the BlockLengthSampler methods are expected to run successfully. 13 | """ 14 | 15 | @pytest.mark.parametrize( 16 | "distribution_name, avg_block_length", 17 | itertools.product( 18 | [ 19 | "none", 20 | "poisson", 21 | "exponential", 22 | "normal", 23 | "gamma", 24 | "beta", 25 | "lognormal", 26 | "weibull", 27 | "pareto", 28 | "geometric", 29 | "uniform", 30 | ], 31 | [2, 10, 100], 32 | ), 33 | ) 34 | def test_block_length_sampler_initialization( 35 | self, distribution_name, avg_block_length 36 | ): 37 | """ 38 | Test that BlockLengthSampler can be initialized with various valid inputs. 39 | """ 40 | bls = BlockLengthSampler( 41 | block_length_distribution=distribution_name, 42 | avg_block_length=avg_block_length, 43 | ) 44 | assert isinstance(bls, BlockLengthSampler) 45 | 46 | @pytest.mark.parametrize("random_seed", [None, 42, 0, 2**32 - 1]) 47 | def test_block_length_sampler_initialization_with_random_seed( 48 | self, random_seed 49 | ): 50 | """ 51 | Test that BlockLengthSampler can be initialized with various valid random seeds. 52 | """ 53 | bls = BlockLengthSampler( 54 | block_length_distribution="normal", 55 | avg_block_length=10, 56 | rng=random_seed, 57 | ) 58 | assert isinstance(bls, BlockLengthSampler) 59 | 60 | def test_same_random_seed(self): 61 | """ 62 | Test that the same random seed produces the same block lengths. 63 | """ 64 | num_samples = 100 65 | bls1 = BlockLengthSampler( 66 | block_length_distribution="normal", avg_block_length=10, rng=42 67 | ) 68 | bls2 = BlockLengthSampler( 69 | block_length_distribution="normal", avg_block_length=10, rng=42 70 | ) 71 | 72 | samples1 = [bls1.sample_block_length() for _ in range(num_samples)] 73 | samples2 = [bls2.sample_block_length() for _ in range(num_samples)] 74 | 75 | assert samples1 == samples2 76 | 77 | def test_different_random_seeds(self): 78 | """ 79 | Test that different random seeds produce different block lengths. 
80 | """ 81 | num_samples = 100 82 | bls1 = BlockLengthSampler( 83 | block_length_distribution="normal", avg_block_length=10, rng=42 84 | ) 85 | bls2 = BlockLengthSampler( 86 | block_length_distribution="normal", avg_block_length=10, rng=123 87 | ) 88 | 89 | samples1 = [bls1.sample_block_length() for _ in range(num_samples)] 90 | samples2 = [bls2.sample_block_length() for _ in range(num_samples)] 91 | 92 | equal_samples = sum([s1 == s2 for s1, s2 in zip(samples1, samples2)]) 93 | assert equal_samples < num_samples * 0.5 94 | 95 | @given(st.integers(min_value=2, max_value=1000)) 96 | def test_sample_block_length(self, avg_block_length): 97 | """ 98 | Test that BlockLengthSampler's sample_block_length method returns results as expected for various average block lengths. 99 | """ 100 | bls = BlockLengthSampler( 101 | block_length_distribution="none", avg_block_length=avg_block_length 102 | ) 103 | assert bls.sample_block_length() == avg_block_length 104 | 105 | 106 | class TestFailingCases: 107 | """ 108 | Test suite for all cases where the BlockLengthSampler methods are expected to raise an exception. 109 | """ 110 | 111 | def test_invalid_distribution_name(self): 112 | """ 113 | Test that an invalid distribution name raises a ValueError. 114 | """ 115 | with pytest.raises(ValueError): 116 | BlockLengthSampler( 117 | block_length_distribution="invalid_distribution", 118 | avg_block_length=10, 119 | ) 120 | 121 | def test_invalid_distribution_number(self): 122 | """ 123 | Test that an invalid distribution number raises a ValueError. 124 | """ 125 | bls = BlockLengthSampler( 126 | block_length_distribution="uniform", avg_block_length=10 127 | ) 128 | with pytest.raises(TypeError): 129 | bls.block_length_distribution = 999 130 | 131 | def test_invalid_random_seed_low(self): 132 | """ 133 | Test that an invalid random seed (less than 0) raises a ValueError. 134 | """ 135 | with pytest.raises(ValueError): 136 | BlockLengthSampler( 137 | block_length_distribution="normal", avg_block_length=10, rng=-1 138 | ) 139 | 140 | def test_invalid_random_seed_high(self): 141 | """ 142 | Test that an invalid random seed (greater than 2**32) raises a ValueError. 143 | """ 144 | with pytest.raises(ValueError): 145 | BlockLengthSampler( 146 | block_length_distribution="normal", 147 | avg_block_length=10, 148 | rng=2**32, 149 | ) 150 | 151 | def test_zero_avg_block_length(self): 152 | """ 153 | Test that a zero average block length raises a ValueError. 154 | """ 155 | with pytest.warns(UserWarning): 156 | BlockLengthSampler( 157 | block_length_distribution="normal", avg_block_length=0 158 | ) 159 | 160 | @given( 161 | st.floats( 162 | min_value=0, 163 | max_value=2**32 - 1, 164 | allow_nan=False, 165 | allow_infinity=False, 166 | ) 167 | ) 168 | def test_non_integer_random_seed(self, random_seed): 169 | """ 170 | Test that a non-integer random seed raises a TypeError. 171 | """ 172 | with pytest.raises(TypeError): 173 | BlockLengthSampler( 174 | avg_block_length=10, 175 | block_length_distribution="normal", 176 | rng=random_seed, 177 | ) 178 | 179 | @given(st.integers(min_value=-1000, max_value=-1)) 180 | def test_negative_avg_block_length(self, avg_block_length): 181 | """ 182 | Test that a negative average block length raises a UserWarning. 
183 | """ 184 | with pytest.warns(UserWarning): 185 | BlockLengthSampler( 186 | avg_block_length=avg_block_length, 187 | block_length_distribution="normal", 188 | ) 189 | 190 | def test_one_avg_block_length(self): 191 | """ 192 | Test that a one average block length raises a UserWarning. 193 | """ 194 | q = BlockLengthSampler( 195 | avg_block_length=1, block_length_distribution="normal" 196 | ) 197 | print(q.avg_block_length) 198 | with pytest.warns(UserWarning): 199 | BlockLengthSampler( 200 | avg_block_length=1, block_length_distribution="normal" 201 | ) 202 | 203 | @given( 204 | st.floats( 205 | min_value=0.1, 206 | max_value=1000.0, 207 | allow_nan=False, 208 | allow_infinity=False, 209 | ) 210 | ) 211 | def test_non_integer_avg_block_length(self, avg_block_length): 212 | """ 213 | Test that a non-integer average block length raises a TypeError. 214 | """ 215 | # Skip values that are whole numbers. 216 | if avg_block_length.is_integer(): 217 | return 218 | # Skip values that are smaller than 2 since these are automatically converted to 2, even if they are not whole numbers. 219 | if avg_block_length < 2: 220 | return 221 | 222 | with pytest.raises(ValidationError): 223 | print(f"{avg_block_length=}") 224 | print(f"{avg_block_length.is_integer()=}") 225 | BlockLengthSampler( 226 | avg_block_length=avg_block_length, 227 | block_length_distribution="normal", 228 | ) 229 | 230 | def test_none_avg_block_length(self): 231 | """ 232 | Test that the BlockLengthSampler constructor raises a TypeError when given a None type average block length. 233 | """ 234 | with pytest.raises(TypeError): 235 | BlockLengthSampler( 236 | avg_block_length=None, block_length_distribution="normal" 237 | ) 238 | -------------------------------------------------------------------------------- /tests/test_odds_and_ends.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from hypothesis import given 4 | from hypothesis import strategies as st 5 | from tsbootstrap.utils.odds_and_ends import time_series_split 6 | 7 | 8 | class TestTimeSeriesSplit: 9 | class TestPassingCases: 10 | @given( 11 | st.lists( 12 | st.floats(allow_infinity=False, allow_nan=False), 13 | min_size=2, 14 | max_size=100, 15 | ), 16 | st.floats( 17 | min_value=0.1, 18 | max_value=0.9, 19 | allow_nan=False, 20 | allow_infinity=False, 21 | ), 22 | ) 23 | def test_valid_input(self, X, test_ratio): 24 | X = np.array(X) 25 | X_train, X_test = time_series_split(X, test_ratio) 26 | assert len(X_train) == int(len(X) * (1 - test_ratio)) 27 | assert len(X_test) == len(X) - len(X_train) 28 | assert np.all(X_train == X[: len(X_train)]) 29 | assert np.all(X_test == X[len(X_train) :]) 30 | 31 | def test_zero_ratio(self): 32 | X = np.array([1, 2, 3, 4, 5]) 33 | X_train, X_test = time_series_split(X, 0) 34 | assert len(X_train) == 5 35 | assert len(X_test) == 0 36 | 37 | def test_full_ratio(self): 38 | X = np.array([1, 2, 3, 4, 5]) 39 | X_train, X_test = time_series_split(X, 1) 40 | assert len(X_train) == 0 41 | assert len(X_test) == 5 42 | 43 | class TestFailingCases: 44 | def test_negative_ratio(self): 45 | X = np.array([1, 2, 3, 4, 5]) 46 | with pytest.raises(ValueError): 47 | time_series_split(X, -0.5) 48 | 49 | def test_large_ratio(self): 50 | X = np.array([1, 2, 3, 4, 5]) 51 | with pytest.raises(ValueError): 52 | time_series_split(X, 1.5) 53 | -------------------------------------------------------------------------------- /tests/test_ranklags.py: 
-------------------------------------------------------------------------------- 1 | from numbers import Integral 2 | 3 | import numpy as np 4 | import pytest 5 | from skbase.utils.dependencies import _check_soft_dependencies 6 | from tsbootstrap.ranklags import RankLags 7 | 8 | 9 | @pytest.mark.skipif( 10 | not _check_soft_dependencies("statsmodels", severity="none"), 11 | reason="skip test if required soft dependency not available", 12 | ) 13 | class TestRankLags: 14 | class TestPassingCases: 15 | def test_basic_initialization(self): 16 | """ 17 | Test if the RankLags object is created with default parameters. 18 | """ 19 | X = np.random.normal(size=(100, 1)) 20 | rank_obj = RankLags(X, model_type="ar") 21 | assert isinstance(rank_obj, RankLags) 22 | 23 | def test_custom_max_lag_initialization(self): 24 | """ 25 | Test if the RankLags object is created with a custom max_lag. 26 | """ 27 | X = np.random.normal(size=(100, 1)) 28 | max_lag = 5 29 | rank_obj = RankLags(X, model_type="ar", max_lag=max_lag) 30 | assert rank_obj.max_lag == max_lag 31 | 32 | def test_exogenous_variable_initialization(self): 33 | """ 34 | Test if the RankLags object is created with exogenous variables. 35 | """ 36 | X = np.random.normal(size=(100, 1)) 37 | exog = np.random.normal(size=(100, 1)) 38 | rank_obj = RankLags(X, model_type="ar", y=exog) 39 | assert np.array_equal(rank_obj.y, exog) 40 | 41 | def test_save_models_flag_initialization(self): 42 | """ 43 | Test if the RankLags object is created with save_models as True. 44 | """ 45 | X = np.random.normal(size=(100, 1)) 46 | save_models = True 47 | rank_obj = RankLags(X, model_type="ar", save_models=save_models) 48 | assert rank_obj.save_models == save_models 49 | 50 | def test_aic_bic_rankings_univariate(self): 51 | """ 52 | Test AIC BIC rankings with univariate data. 53 | 54 | Ensure that the method returns correct rankings for given univariate data. 55 | """ 56 | X = np.random.normal(size=(100, 1)) 57 | rank_obj = RankLags(X, model_type="ar") 58 | aic_lags, bic_lags = rank_obj.rank_lags_by_aic_bic() 59 | assert isinstance(aic_lags, np.ndarray) 60 | assert isinstance(bic_lags, np.ndarray) 61 | assert len(aic_lags) == rank_obj.max_lag 62 | assert len(bic_lags) == rank_obj.max_lag 63 | 64 | def test_aic_bic_rankings_multivariate(self): 65 | """ 66 | Test AIC BIC rankings with multivariate data. 67 | 68 | Ensure that the method returns correct rankings for given multivariate data. 69 | """ 70 | X = np.random.normal(size=(100, 2)) 71 | rank_obj = RankLags(X, model_type="var", max_lag=2) 72 | aic_lags, bic_lags = rank_obj.rank_lags_by_aic_bic() 73 | assert isinstance(aic_lags, np.ndarray) 74 | assert isinstance(bic_lags, np.ndarray) 75 | assert len(aic_lags) == rank_obj.max_lag 76 | assert len(bic_lags) == rank_obj.max_lag 77 | 78 | def test_pacf_rankings_univariate(self): 79 | """ 80 | Test PACF rankings with univariate data. 81 | 82 | Ensure that the method returns correct PACF rankings for given univariate data. 83 | """ 84 | X = np.random.normal(size=(100, 1)) 85 | rank_obj = RankLags(X, model_type="ar") 86 | pacf_lags = rank_obj.rank_lags_by_pacf() 87 | assert isinstance(pacf_lags, np.ndarray) 88 | assert len(pacf_lags) <= rank_obj.max_lag 89 | 90 | def test_conservative_lag_univariate(self): 91 | """ 92 | Test estimation of conservative lag with univariate data. 93 | 94 | Ensure that the method returns a valid conservative lag for given univariate data. 
95 | """ 96 | X = np.random.normal(size=(100, 1)) 97 | rank_obj = RankLags(X, model_type="ar") 98 | lag = rank_obj.estimate_conservative_lag() 99 | assert isinstance(lag, Integral) 100 | assert lag <= rank_obj.max_lag 101 | 102 | def test_conservative_lag_multivariate(self): 103 | """ 104 | Test estimation of conservative lag with multivariate data. 105 | 106 | Ensure that the method returns a valid conservative lag for given multivariate data. 107 | """ 108 | X = np.random.normal(size=(100, 2)) 109 | rank_obj = RankLags(X, model_type="var") 110 | lag = rank_obj.estimate_conservative_lag() 111 | assert isinstance(lag, Integral) 112 | assert lag <= rank_obj.max_lag 113 | 114 | def test_model_retrieval(self): 115 | """ 116 | Test model retrieval. 117 | 118 | Ensure that the method retrieves a previously fitted model. 119 | """ 120 | X = np.random.normal(size=(100, 1)) 121 | rank_obj = RankLags(X, model_type="ar", save_models=True) 122 | rank_obj.rank_lags_by_aic_bic() # Assuming this saves the models 123 | model = rank_obj.get_model(order=1) 124 | assert ( 125 | model is not None 126 | ) # Additional assertions based on the expected model type 127 | 128 | class TestFailingCases: 129 | def test_invalid_model_type(self): 130 | """ 131 | Test initialization with an invalid model type. 132 | 133 | Ensure that initializing with an invalid model type should raise an exception. 134 | """ 135 | X = np.random.normal(size=(100, 1)) 136 | with pytest.raises(ValueError, match="Invalid input_value"): 137 | RankLags(X, model_type="invalid_type") 138 | 139 | def test_negative_max_lag(self): 140 | """ 141 | Test initialization with a negative max_lag. 142 | 143 | Ensure that initializing with a negative max_lag should raise an exception. 144 | """ 145 | X = np.random.normal(size=(100, 1)) 146 | with pytest.raises(ValueError, match="Integer must be at least 1"): 147 | RankLags(X, model_type="ar", max_lag=-5) 148 | 149 | def test_pacf_rankings_non_univariate(self): 150 | """ 151 | Test PACF rankings with non-univariate data. 152 | 153 | Since PACF is only available for univariate data, the method should handle non-univariate data properly. 154 | """ 155 | X = np.random.normal(size=(100, 2)) 156 | rank_obj = RankLags(X, model_type="ar") 157 | with pytest.raises( 158 | ValueError 159 | ): # , match="PACF rankings are only available for univariate data"): 160 | rank_obj.rank_lags_by_pacf() 161 | 162 | def test_model_retrieval_without_saving(self): 163 | """ 164 | Test model retrieval without saving models. 165 | 166 | Ensure that the method returns None if models were not saved. 
167 | """ 168 | X = np.random.normal(size=(100, 1)) 169 | rank_obj = RankLags(X, model_type="ar") 170 | rank_obj.rank_lags_by_aic_bic() # Assuming this computes but does not save the models 171 | model = rank_obj.get_model(order=1) 172 | assert model is None 173 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | minversion = 3.10.0 3 | envlist = precommit, py310, py311 4 | isolated_build = true 5 | 6 | [gh-actions] 7 | python = 8 | 3.10: py310, precommit 9 | 3.11: py311 10 | 11 | [testenv] 12 | setenv = 13 | PYTHONPATH = {toxinidir} 14 | allowlist_externals = 15 | poetry 16 | bash 17 | commands = 18 | poetry config virtualenvs.in-project true 19 | poetry install -v 20 | poetry run python -c 'import platform, subprocess; version = platform.python_version_tuple(); subprocess.run(["python", "-m", "pip", "install", "dtaidistance"]) if version < ("3", "10") else None' 21 | poetry run pytest --basetemp={envtmpdir} 22 | 23 | [testenv:precommit] 24 | basepython = python3.10 25 | whitelist_externals = poetry 26 | deps = pre-commit 27 | commands = pre-commit run --all-files 28 | -------------------------------------------------------------------------------- /tsbootstrap_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/astrogilda/tsbootstrap/6d6b9d37a87c7050a212a6f57fa6fca36d1d1ce4/tsbootstrap_logo.png -------------------------------------------------------------------------------- /uv_vs_pip.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/astrogilda/tsbootstrap/6d6b9d37a87c7050a212a6f57fa6fca36d1d1ce4/uv_vs_pip.jpg --------------------------------------------------------------------------------