├── .github └── workflows │ ├── deploy-docs.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .travis.yml ├── CONTRIBUTING.md ├── LICENSE.txt ├── MANIFEST.in ├── PSL_catalog.json ├── README.md ├── conda.recipe ├── bld.bat ├── build.sh └── meta.yaml ├── docs ├── _config.yml ├── _toc.yml ├── api │ ├── custom-adjust.ipynb │ ├── custom-types.ipynb │ ├── extend.ipynb │ ├── guide.md │ ├── indexing.ipynb │ ├── reference.rst │ └── viewing-data.ipynb ├── intro.ipynb └── parameters.md ├── environment.yml ├── paramtools ├── __init__.py ├── contrib │ ├── __init__.py │ ├── fields.py │ └── validate.py ├── examples │ ├── baseball │ │ └── defaults.json │ ├── behresp │ │ └── defaults.json │ ├── taxparams-demo │ │ └── defaults.json │ └── taxparams │ │ └── defaults.json ├── exceptions.py ├── parameters.py ├── schema.py ├── schema_factory.py ├── select.py ├── sorted_key_list.py ├── tests │ ├── __init__.py │ ├── defaults.json │ ├── extend_ex.json │ ├── test_examples │ │ ├── __init__.py │ │ ├── test_baseball.py │ │ ├── test_behresp.py │ │ └── test_tc_ex.py │ ├── test_fields.py │ ├── test_parameters.py │ ├── test_schema.py │ ├── test_select.py │ ├── test_sorted_key_list.py │ ├── test_utils.py │ ├── test_validate.py │ └── test_values.py ├── typing.py ├── utils.py └── values.py ├── pyproject.toml └── setup.py /.github/workflows/deploy-docs.yml: -------------------------------------------------------------------------------- 1 | name: Build and Deploy Jupyter Book 2 | on: 3 | push: 4 | branches: 5 | - master 6 | 7 | jobs: 8 | build-and-deploy: 9 | if: github.repository == 'PSLmodels/ParamTools' 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout 13 | uses: actions/checkout@v2 # If you're using actions/checkout@v2 you must set persist-credentials to false in most cases for the deployment to work correctly. 
14 | with: 15 | persist-credentials: false 16 | 17 | - name: Setup Miniconda 18 | uses: conda-incubator/setup-miniconda@v3 19 | with: 20 | activate-environment: paramtools-dev 21 | environment-file: environment.yml 22 | python-version: 3.12 23 | auto-activate-base: false 24 | 25 | - name: Build # Build Jupyter Book 26 | shell: bash -l {0} 27 | run: | 28 | pip install -e . 29 | conda install pandas --yes 30 | jupyter-book build ./docs 31 | 32 | - name: Deploy 33 | uses: JamesIves/github-pages-deploy-action@releases/v3 34 | with: 35 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 36 | BRANCH: gh-pages # The branch the action should deploy to. 37 | FOLDER: docs/_build/html # The folder the action should deploy. 38 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | 2 | name: Build Package and Test Source Code [Python 3.10, 3.11, 3.12] 3 | 4 | on: 5 | push: 6 | branches: 7 | - master 8 | pull_request: {} 9 | 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: ["3.10", "3.11", "3.12"] 16 | 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@master 20 | with: 21 | persist-credentials: false 22 | 23 | - name: Setup Miniconda using Python ${{ matrix.python-version }} 24 | uses: conda-incubator/setup-miniconda@v3 25 | with: 26 | activate-environment: paramtools-dev 27 | environment-file: environment.yml 28 | python-version: ${{ matrix.python-version }} 29 | auto-activate-base: false 30 | 31 | - name: Build 32 | shell: bash -l {0} 33 | run: | 34 | pip install -e . 
35 | - name: Test 36 | shell: bash -l {0} 37 | working-directory: ./ 38 | run: | 39 | pytest paramtools -v -s -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 
-------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: 18.9b0 4 | hooks: 5 | - id: black 6 | language_version: python3 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v2.1.0 9 | hooks: 10 | - id: flake8 11 | args: ["--max-line-length=100", "--ignore=E203,W503,E712"] 12 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: xenial 2 | 3 | language: python 4 | 5 | python: 6 | - "3.6" 7 | - "3.7" 8 | 9 | # command to install dependencies 10 | install: 11 | - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh 12 | - bash miniconda.sh -b -p $HOME/miniconda 13 | - export PATH="$HOME/miniconda/bin:$PATH" 14 | - conda config --set always_yes yes 15 | - conda update -n base -c defaults conda 16 | - conda create -n paramtools-dev python=$TRAVIS_PYTHON_VERSION; 17 | - source activate paramtools-dev 18 | - conda env update -f environment.yml 19 | 20 | # command to run tests 21 | script: 22 | - pytest -v -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions 4 | ------------------------------ 5 | Contributions are welcome! Open a [PR][2] with your changes (and tests to go along with them!). In this PR describe what your change does and link to any relevant issues. 6 | 7 | Feature Requests 8 | ---------------------------------- 9 | Please open an [issue][1] describing the feature and its potential use cases. 10 | 11 | Bug Reports 12 | ----------------------- 13 | Please open an [issue][1] describing the bug. 
14 | 15 | Dev Setup 16 | ------------------------ 17 | 18 | Fork the repo and clone it so that you have a copy of the source code. Next, run the following commands in the terminal: 19 | 20 | ``` 21 | cd ParamTools 22 | conda env create 23 | conda activate paramtools-dev 24 | pip install -e . 25 | pre-commit install 26 | ``` 27 | 28 | Testing 29 | ------------------- 30 | ``` 31 | py.test -v 32 | ``` 33 | 34 | 35 | [1]: https://github.com/PSLmodels/ParamTools/issues 36 | [2]: https://github.com/PSLmodels/ParamTools/pulls 37 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Henry Doupe 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include paramtools/examples/taxparams-demo/*.json 2 | -------------------------------------------------------------------------------- /PSL_catalog.json: -------------------------------------------------------------------------------- 1 | { 2 | "project_one_line": { 3 | "start_header": null, 4 | "end_header": null, 5 | "source": "https://github.com/PSLmodels/ParamTools", 6 | "type": "html", 7 | "data": "

Library for parameter processing and validation with a focus on computational modeling projects

" 8 | }, 9 | "key_features": { 10 | "start_header": null, 11 | "end_header": null, 12 | "source": null, 13 | "type": null, 14 | "data": null 15 | }, 16 | "project_overview": { 17 | "start_header": "ParamTools", 18 | "end_header": "How to use ParamTools", 19 | "source": "README.md", 20 | "type": "github_file", 21 | "data": null 22 | }, 23 | "citation": { 24 | "start_header": null, 25 | "end_header": null, 26 | "source": null, 27 | "data": null, 28 | "type": null 29 | }, 30 | "license": { 31 | "start_header": null, 32 | "end_header": null, 33 | "source": "https://github.com/PSLmodels/ParamTools/blob/master/LICENSE.txt", 34 | "data": "

MIT

", 35 | "type": "html" 36 | }, 37 | "user_documentation": { 38 | "start_header": null, 39 | "end_header": null, 40 | "type": "html", 41 | "source": "https://paramtools.dev", 42 | "data": "https://paramtools.dev" 43 | }, 44 | "user_changelog": { 45 | "start_header": null, 46 | "end_header": null, 47 | "source": "https://github.com/PSLmodels/ParamTools/releases", 48 | "type": "html", 49 | "data": "https://github.com/PSLmodels/ParamTools/releases" 50 | }, 51 | "user_changelog_recent": { 52 | "start_header": null, 53 | "end_header": null, 54 | "source": "https://github.com/PSLmodels/ParamTools/releases/latest", 55 | "type": "html", 56 | "data": "https://github.com/PSLmodels/ParamTools/releases/latest" 57 | }, 58 | "dev_changelog": { 59 | "start_header": null, 60 | "end_header": null, 61 | "source": "https://github.com/PSLmodels/ParamTools/releases", 62 | "type": "html", 63 | "data": "https://github.com/PSLmodels/ParamTools/releases" 64 | }, 65 | "disclaimer": { 66 | "start_header": null, 67 | "end_header": null, 68 | "source": null, 69 | "type": null, 70 | "data": null 71 | }, 72 | "user_case_studies": { 73 | "start_header": null, 74 | "end_header": null, 75 | "source": null, 76 | "type": null, 77 | "data": null 78 | }, 79 | "project_roadmap": { 80 | "start_header": null, 81 | "end_header": null, 82 | "source": null, 83 | "type": null, 84 | "data": null 85 | }, 86 | "contributor_overview": { 87 | "start_header": "Contributions", 88 | "end_header": "Dev Setup", 89 | "source": "CONTRIBUTING.md", 90 | "type": "github_file", 91 | "data": null 92 | }, 93 | "contributor_guide": { 94 | "start_header": null, 95 | "end_header": null, 96 | "source": "https://github.com/PSLmodels/ParamTools/blob/master/CONTRIBUTING.md", 97 | "type": "html", 98 | "data": "https://github.com/PSLmodels/ParamTools/blob/master/CONTRIBUTING.md" 99 | }, 100 | "governance_overview": { 101 | "start_header": null, 102 | "end_header": null, 103 | "source": null, 104 | "type": null, 105 | "data": null 106 | 
}, 107 | "public_funding": { 108 | "start_header": null, 109 | "end_header": null, 110 | "source": null, 111 | "type": null, 112 | "data": null 113 | }, 114 | "link_to_webapp": { 115 | "data": null, 116 | "source": null, 117 | "type": null, 118 | "start_header": null, 119 | "end_header": null 120 | }, 121 | "public_issue_tracker": { 122 | "start_header": null, 123 | "end_header": null, 124 | "data": "https://github.com/PSLmodels/ParamTools/issues", 125 | "source": null, 126 | "type": "html" 127 | }, 128 | "public_qanda": { 129 | "start_header": null, 130 | "end_header": null, 131 | "data": "https://github.com/PSLmodels/ParamTools/issues", 132 | "source": null, 133 | "type": "html" 134 | }, 135 | "core_maintainers": { 136 | "start_header": null, 137 | "end_header": null, 138 | "data": "", 139 | "source": null, 140 | "type": "html" 141 | }, 142 | "unit_test": { 143 | "start_header": null, 144 | "end_header": null, 145 | "data": "https://github.com/PSLmodels/ParamTools/tree/master/paramtools/tests", 146 | "source": "https://github.com/PSLmodels/ParamTools/tree/master/paramtools/tests", 147 | "type": "html" 148 | }, 149 | "integration_test": { 150 | "start_header": null, 151 | "end_header": null, 152 | "data": "https://github.com/PSLmodels/ParamTools/tree/master/paramtools/tests", 153 | "source": "https://github.com/PSLmodels/ParamTools/tree/master/paramtools/tests", 154 | "type": "html" 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ParamTools 2 | 3 | **Define, update, and validate your model's parameters.** 4 | 5 | Install using pip: 6 | 7 | ``` 8 | pip install paramtools 9 | ``` 10 | 11 | Install using conda: 12 | 13 | ``` 14 | conda install -c conda-forge paramtools 15 | ``` 16 | 17 | ## Usage 18 | 19 | Subclass `paramtools.Parameters` and define your model's [parameters](https://paramtools.dev/parameters): 20 | 
21 | ```python 22 | import paramtools 23 | 24 | 25 | class Params(paramtools.Parameters): 26 | defaults = { 27 | "schema": { 28 | "labels": { 29 | "date": { 30 | "type": "date", 31 | "validators": { 32 | "range": { 33 | "min": "2020-01-01", 34 | "max": "2021-01-01", 35 | "step": {"months": 1} 36 | } 37 | } 38 | } 39 | }, 40 | }, 41 | "a": { 42 | "title": "A", 43 | "type": "int", 44 | "value": [ 45 | {"date": "2020-01-01", "value": 2}, 46 | {"date": "2020-10-01", "value": 8}, 47 | ], 48 | "validators": { 49 | "range" : { 50 | "min": 0, "max": "b" 51 | } 52 | } 53 | }, 54 | "b": { 55 | "title": "B", 56 | "type": "float", 57 | "value": [{"date": "2020-01-01", "value": 10.5}] 58 | } 59 | } 60 | ``` 61 | 62 | ### Access parameter values 63 | 64 | Access values using `.sel`: 65 | 66 | ```python 67 | params = Params() 68 | 69 | params.sel["a"] 70 | ``` 71 | 72 | Values([ 73 | {'date': datetime.date(2020, 1, 1), 'value': 2}, 74 | {'date': datetime.date(2020, 10, 1), 'value': 8}, 75 | ]) 76 | 77 | Look up parameter values using a pandas-like api: 78 | 79 | ```python 80 | from datetime import date 81 | 82 | result = params.sel["a"]["date"] == date(2020, 1, 1) 83 | result 84 | ``` 85 | 86 | QueryResult([ 87 | {'date': datetime.date(2020, 1, 1), 'value': 2} 88 | ]) 89 | 90 | ```python 91 | result.isel[0]["value"] 92 | ``` 93 | 94 | 2 95 | 96 | ### Adjust and validate parameter values 97 | 98 | Add a new value: 99 | 100 | ```python 101 | params.adjust({"a": [{"date": "2020-11-01", "value": 22}]}) 102 | 103 | params.sel["a"] 104 | ``` 105 | 106 | Values([ 107 | {'date': datetime.date(2020, 1, 1), 'value': 2}, 108 | {'date': datetime.date(2020, 10, 1), 'value': 8}, 109 | {'date': datetime.date(2020, 11, 1), 'value': 22}, 110 | ]) 111 | 112 | Update an existing value: 113 | 114 | ```python 115 | params.adjust({"a": [{"date": "2020-01-01", "value": 3}]}) 116 | 117 | params.sel["a"] 118 | ``` 119 | 120 | Values([ 121 | {'date': datetime.date(2020, 1, 1), 'value': 3}, 122 | {'date': 
datetime.date(2020, 10, 1), 'value': 8}, 123 | {'date': datetime.date(2020, 11, 1), 'value': 22}, 124 | ]) 125 | 126 | Update all values: 127 | 128 | ```python 129 | params.adjust({"a": 7}) 130 | 131 | params.sel["a"] 132 | ``` 133 | 134 | Values([ 135 | {'date': datetime.date(2020, 1, 1), 'value': 7}, 136 | {'date': datetime.date(2020, 10, 1), 'value': 7}, 137 | {'date': datetime.date(2020, 11, 1), 'value': 7}, 138 | ]) 139 | 140 | Errors on values that are out of range: 141 | 142 | ```python 143 | params.adjust({"a": -1}) 144 | ``` 145 | 146 | --------------------------------------------------------------------------- 147 | 148 | ValidationError Traceback (most recent call last) 149 | 150 | in 151 | ----> 1 params.adjust({"a": -1}) 152 | 153 | 154 | ~/Paramtools/paramtools/parameters.py in adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber) 155 | 253 least one existing value item's corresponding label values. 156 | 254 """ 157 | --> 255 return self._adjust( 158 | 256 params_or_path, 159 | 257 ignore_warnings=ignore_warnings, 160 | 161 | 162 | ~/Paramtools/paramtools/parameters.py in _adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, is_deserialized, clobber) 163 | 371 not ignore_warnings and has_warnings 164 | 372 ): 165 | --> 373 raise self.validation_error 166 | 374 167 | 375 # Update attrs for params that were adjusted. 
168 | 169 | 170 | ValidationError: { 171 | "errors": { 172 | "a": [ 173 | "a -1 < min 0 " 174 | ] 175 | } 176 | } 177 | 178 | ```python 179 | params = Params() 180 | 181 | params.adjust({"a": [{"date": "2020-01-01", "value": 11}]}) 182 | ``` 183 | 184 | --------------------------------------------------------------------------- 185 | 186 | ValidationError Traceback (most recent call last) 187 | 188 | in 189 | 1 params = Params() 190 | 2 191 | ----> 3 params.adjust({"a": [{"date": "2020-01-01", "value": 11}]}) 192 | 193 | 194 | ~/Paramtools/paramtools/parameters.py in adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber) 195 | 253 least one existing value item's corresponding label values. 196 | 254 """ 197 | --> 255 return self._adjust( 198 | 256 params_or_path, 199 | 257 ignore_warnings=ignore_warnings, 200 | 201 | 202 | ~/Paramtools/paramtools/parameters.py in _adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, is_deserialized, clobber) 203 | 371 not ignore_warnings and has_warnings 204 | 372 ): 205 | --> 373 raise self.validation_error 206 | 374 207 | 375 # Update attrs for params that were adjusted. 208 | 209 | 210 | ValidationError: { 211 | "errors": { 212 | "a": [ 213 | "a[date=2020-01-01] 11 > max 10.5 b[date=2020-01-01]" 214 | ] 215 | } 216 | } 217 | 218 | Errors on invalid values: 219 | 220 | ```python 221 | params = Params() 222 | 223 | params.adjust({"b": "abc"}) 224 | ``` 225 | 226 | --------------------------------------------------------------------------- 227 | 228 | ValidationError Traceback (most recent call last) 229 | 230 | in 231 | 1 params = Params() 232 | 2 233 | ----> 3 params.adjust({"b": "abc"}) 234 | 235 | 236 | ~/Paramtools/paramtools/parameters.py in adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber) 237 | 253 least one existing value item's corresponding label values. 
238 | 254 """ 239 | --> 255 return self._adjust( 240 | 256 params_or_path, 241 | 257 ignore_warnings=ignore_warnings, 242 | 243 | 244 | ~/Paramtools/paramtools/parameters.py in _adjust(self, params_or_path, ignore_warnings, raise_errors, extend_adj, is_deserialized, clobber) 245 | 371 not ignore_warnings and has_warnings 246 | 372 ): 247 | --> 373 raise self.validation_error 248 | 374 249 | 375 # Update attrs for params that were adjusted. 250 | 251 | 252 | ValidationError: { 253 | "errors": { 254 | "b": [ 255 | "Not a valid number: abc." 256 | ] 257 | } 258 | } 259 | 260 | ### Extend parameter values using label definitions 261 | 262 | Extend values using `label_to_extend`: 263 | 264 | ```python 265 | params = Params(label_to_extend="date") 266 | ``` 267 | 268 | ```python 269 | params.sel["a"] 270 | ``` 271 | 272 | Values([ 273 | {'date': datetime.date(2020, 1, 1), 'value': 2}, 274 | {'date': datetime.date(2020, 2, 1), 'value': 2, '_auto': True}, 275 | {'date': datetime.date(2020, 3, 1), 'value': 2, '_auto': True}, 276 | {'date': datetime.date(2020, 4, 1), 'value': 2, '_auto': True}, 277 | {'date': datetime.date(2020, 5, 1), 'value': 2, '_auto': True}, 278 | {'date': datetime.date(2020, 6, 1), 'value': 2, '_auto': True}, 279 | {'date': datetime.date(2020, 7, 1), 'value': 2, '_auto': True}, 280 | {'date': datetime.date(2020, 8, 1), 'value': 2, '_auto': True}, 281 | {'date': datetime.date(2020, 9, 1), 'value': 2, '_auto': True}, 282 | {'date': datetime.date(2020, 10, 1), 'value': 8}, 283 | {'date': datetime.date(2020, 11, 1), 'value': 8, '_auto': True}, 284 | {'date': datetime.date(2020, 12, 1), 'value': 8, '_auto': True}, 285 | {'date': datetime.date(2021, 1, 1), 'value': 8, '_auto': True}, 286 | ]) 287 | 288 | Updates to values are carried through to future dates: 289 | 290 | ```python 291 | params.adjust({"a": [{"date": "2020-4-01", "value": 9}]}) 292 | 293 | params.sel["a"] 294 | ``` 295 | 296 | Values([ 297 | {'date': datetime.date(2020, 1, 1), 'value': 2}, 298 
| {'date': datetime.date(2020, 2, 1), 'value': 2, '_auto': True}, 299 | {'date': datetime.date(2020, 3, 1), 'value': 2, '_auto': True}, 300 | {'date': datetime.date(2020, 4, 1), 'value': 9}, 301 | {'date': datetime.date(2020, 5, 1), 'value': 9, '_auto': True}, 302 | {'date': datetime.date(2020, 6, 1), 'value': 9, '_auto': True}, 303 | {'date': datetime.date(2020, 7, 1), 'value': 9, '_auto': True}, 304 | {'date': datetime.date(2020, 8, 1), 'value': 9, '_auto': True}, 305 | {'date': datetime.date(2020, 9, 1), 'value': 9, '_auto': True}, 306 | {'date': datetime.date(2020, 10, 1), 'value': 9, '_auto': True}, 307 | {'date': datetime.date(2020, 11, 1), 'value': 9, '_auto': True}, 308 | {'date': datetime.date(2020, 12, 1), 'value': 9, '_auto': True}, 309 | {'date': datetime.date(2021, 1, 1), 'value': 9, '_auto': True}, 310 | ]) 311 | 312 | Use `clobber` to only update values that were set automatically: 313 | 314 | ```python 315 | params = Params(label_to_extend="date") 316 | params.adjust( 317 | {"a": [{"date": "2020-4-01", "value": 9}]}, 318 | clobber=False, 319 | ) 320 | 321 | # Sort parameter values by date for nicer output 322 | params.sort_values() 323 | params.sel["a"] 324 | ``` 325 | 326 | Values([ 327 | {'date': datetime.date(2020, 1, 1), 'value': 2}, 328 | {'date': datetime.date(2020, 2, 1), 'value': 2, '_auto': True}, 329 | {'date': datetime.date(2020, 3, 1), 'value': 2, '_auto': True}, 330 | {'date': datetime.date(2020, 4, 1), 'value': 9}, 331 | {'date': datetime.date(2020, 5, 1), 'value': 9, '_auto': True}, 332 | {'date': datetime.date(2020, 6, 1), 'value': 9, '_auto': True}, 333 | {'date': datetime.date(2020, 7, 1), 'value': 9, '_auto': True}, 334 | {'date': datetime.date(2020, 8, 1), 'value': 9, '_auto': True}, 335 | {'date': datetime.date(2020, 9, 1), 'value': 9, '_auto': True}, 336 | {'date': datetime.date(2020, 10, 1), 'value': 8}, 337 | {'date': datetime.date(2020, 11, 1), 'value': 8, '_auto': True}, 338 | {'date': datetime.date(2020, 12, 1), 'value': 
8, '_auto': True}, 339 | {'date': datetime.date(2021, 1, 1), 'value': 8, '_auto': True}, 340 | ]) 341 | 342 | ### NumPy integration 343 | 344 | Access values as NumPy arrays with `array_first`: 345 | 346 | ```python 347 | params = Params(label_to_extend="date", array_first=True) 348 | 349 | params.a 350 | ``` 351 | 352 | array([2, 2, 2, 2, 2, 2, 2, 2, 2, 8, 8, 8, 8]) 353 | 354 | ```python 355 | params.a * params.b 356 | ``` 357 | 358 | array([21., 21., 21., 21., 21., 21., 21., 21., 21., 84., 84., 84., 84.]) 359 | 360 | Only get the values that you want: 361 | 362 | ```python 363 | arr = params.to_array("a", date=["2020-01-01", "2020-11-01"]) 364 | arr 365 | ``` 366 | 367 | array([2, 8]) 368 | 369 | Go back to a list of dictionaries: 370 | 371 | ```python 372 | params.from_array("a", arr, date=["2020-01-01", "2020-11-01"]) 373 | ``` 374 | 375 | [{'date': datetime.date(2020, 1, 1), 'value': 2}, 376 | {'date': datetime.date(2020, 11, 1), 'value': 8}] 377 | 378 | ## Documentation 379 | 380 | Full documentation is available at [paramtools.dev](https://paramtools.dev). 381 | 382 | ## Contributing 383 | 384 | Contributions are welcome! Check out [CONTRIBUTING.md][3] to get started. 385 | 386 | ## Credits 387 | 388 | ParamTools is built on top of the excellent [marshmallow][1] JSON schema and validation framework. I encourage everyone to check out their repo and documentation. ParamTools was modeled after [Tax-Calculator's][2] parameter processing and validation engine due to its maturity and sophisticated capabilities. 389 | 390 | [1]: https://github.com/marshmallow-code/marshmallow 391 | [2]: https://github.com/PSLmodels/Tax-Calculator 392 | [3]: https://github.com/PSLmodels/ParamTools/blob/master/CONTRIBUTING.md 393 | -------------------------------------------------------------------------------- /conda.recipe/bld.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | SET BLD_DIR=%CD% 4 | cd /D "%RECIPE_DIR%\.." 
5 | "%PYTHON%" setup.py install 6 | -------------------------------------------------------------------------------- /conda.recipe/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BLD_DIR=`pwd` 4 | 5 | # Recipe and source are stored together 6 | SRC_DIR=$RECIPE_DIR/.. 7 | pushd $SRC_DIR 8 | 9 | $PYTHON setup.py install 10 | popd 11 | -------------------------------------------------------------------------------- /conda.recipe/meta.yaml: -------------------------------------------------------------------------------- 1 | package: 2 | name: paramtools 3 | version: 0.0.0 4 | 5 | requirements: 6 | build: 7 | - python 8 | - "marshmallow>=4.0.0" 9 | - "numpy>=1.13" 10 | - "python-dateutil>=2.8.0" 11 | 12 | run: 13 | - python 14 | - "marshmallow>=4.0.0" 15 | - "numpy>=1.13" 16 | - "python-dateutil>=2.8.0" 17 | 18 | test: 19 | imports: 20 | - paramtools 21 | 22 | about: 23 | home: https://github.com/PSLmodels/ParamTools 24 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | title: ParamTools 2 | author: Hank Doupe 3 | copyright: "2020" 4 | # logo: "logo.png" 5 | 6 | repository: 7 | url: https://github.com/PSLmodels/ParamTools 8 | path_to_book: "docs" 9 | 10 | sphinx: 11 | extra_extensions: ["sphinx.ext.autodoc", "sphinx.ext.viewcode"] 12 | 13 | execute: 14 | allow_errors: true 15 | -------------------------------------------------------------------------------- /docs/_toc.yml: -------------------------------------------------------------------------------- 1 | format: jb-book 2 | root: intro 3 | parts: 4 | - caption: API 5 | chapters: 6 | - file: api/guide 7 | - file: api/viewing-data 8 | - file: api/extend 9 | - file: api/indexing 10 | - file: api/custom-adjust 11 | - file: api/custom-types 12 | - file: api/reference 13 | - caption: Parameters 14 | chapters: 15 | - file: 
parameters -------------------------------------------------------------------------------- /docs/api/custom-adjust.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Custom Adjustments\n", 8 | "\n", 9 | "The ParamTools adjustment format and logic can be augmented significantly. This is helpful for projects that need to support a pre-existing data format or require custom adjustment logic. Projects should customize their adjustments by writing their own `adjust` method and then calling the default `adjust` method from there:\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import paramtools\n", 19 | "\n", 20 | "\n", 21 | "class Params(paramtools.Parameters):\n", 22 | " def adjust(self, params_or_path, **kwargs):\n", 23 | " params = self.read_params(params_or_path)\n", 24 | "\n", 25 | " # ... custom logic here\n", 26 | "\n", 27 | " # call default adjust method.\n", 28 | " return super().adjust(params, **kwargs)\n", 29 | "\n" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "## Example\n", 37 | "\n", 38 | "Some projects may find it convenient to use CSVs for their adjustment format. 
That's no problem for ParamTools as long as the CSV is converted to a JSON file or Python dictionary that meets the ParamTools criteria.\n" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "import io\n", 48 | "import os\n", 49 | "\n", 50 | "import pandas as pd\n", 51 | "\n", 52 | "import paramtools\n", 53 | "\n", 54 | "\n", 55 | "class CSVParams(paramtools.Parameters):\n", 56 | " defaults = {\n", 57 | " \"schema\": {\n", 58 | " \"labels\": {\n", 59 | " \"year\": {\n", 60 | " \"type\": \"int\",\n", 61 | " \"validators\": {\"range\": {\"min\": 2000, \"max\": 2005}}\n", 62 | " }\n", 63 | " }\n", 64 | " },\n", 65 | " \"a\": {\n", 66 | " \"title\": \"A\",\n", 67 | " \"description\": \"a param\",\n", 68 | " \"type\": \"int\",\n", 69 | " \"value\": [\n", 70 | " {\"year\": 2000, \"value\": 1},\n", 71 | " {\"year\": 2001, \"value\": 2},\n", 72 | " ]\n", 73 | " },\n", 74 | " \"b\": {\n", 75 | " \"title\": \"B\",\n", 76 | " \"description\": \"b param\",\n", 77 | " \"type\": \"int\",\n", 78 | " \"value\": [\n", 79 | " {\"year\": 2000, \"value\": 3},\n", 80 | " {\"year\": 2001, \"value\": 4},\n", 81 | " ]\n", 82 | " }\n", 83 | " }\n", 84 | "\n", 85 | " def adjust(self, params_or_path, **kwargs):\n", 86 | " \"\"\"\n", 87 | " A custom adjust method that converts CSV files to\n", 88 | " ParamTools compliant Python dictionaries.\n", 89 | " \"\"\"\n", 90 | " if os.path.exists(params_or_path):\n", 91 | " paramsdf = pd.read_csv(params_or_path, index_col=\"year\")\n", 92 | " else:\n", 93 | " paramsdf = pd.read_csv(io.StringIO(params_or_path), index_col=\"year\")\n", 94 | "\n", 95 | " dfdict = paramsdf.to_dict()\n", 96 | " params = {\"a\": [], \"b\": []}\n", 97 | " for label in params:\n", 98 | " for year, value in dfdict[label].items():\n", 99 | " params[label] += [{\"year\": year, \"value\": value}]\n", 100 | "\n", 101 | " # call adjust method on paramtools.Parameters which will\n", 102 | " # call 
_adjust to actually do the update.\n", 103 | " return super().adjust(params, **kwargs)\n", 104 | "\n" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "Now we create an example CSV file. To keep the example self-contained, the CSV is just a string, but this example works with CSV files, too. The values of \"A\" are updated to 5 in 2000 and 6 in 2001, and the values of \"B\" are updated to 6 in 2000 and 7 in 2001.\n" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 3, 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "data": { 121 | "text/plain": [ 122 | "OrderedDict([('a',\n", 123 | " [OrderedDict([('year', 2000), ('value', 5)]),\n", 124 | " OrderedDict([('year', 2001), ('value', 6)])]),\n", 125 | " ('b',\n", 126 | " [OrderedDict([('year', 2000), ('value', 6)]),\n", 127 | " OrderedDict([('year', 2001), ('value', 7)])])])" 128 | ] 129 | }, 130 | "execution_count": 3, 131 | "metadata": {}, 132 | "output_type": "execute_result" 133 | } 134 | ], 135 | "source": [ 136 | "# this could also be a path to a CSV file.\n", 137 | "csv_string = \"\"\"\n", 138 | "year,a,b\n", 139 | "2000,5,6\\n\n", 140 | "2001,6,7\\n\n", 141 | "\"\"\"\n", 142 | "\n", 143 | "params = CSVParams()\n", 144 | "params.adjust(csv_string)\n" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 4, 150 | "metadata": {}, 151 | "outputs": [ 152 | { 153 | "data": { 154 | "text/plain": [ 155 | "[OrderedDict([('year', 2000), ('value', 5)]),\n", 156 | " OrderedDict([('year', 2001), ('value', 6)])]" 157 | ] 158 | }, 159 | "execution_count": 4, 160 | "metadata": {}, 161 | "output_type": "execute_result" 162 | } 163 | ], 164 | "source": [ 165 | "params.a" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 5, 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "data": { 175 | "text/plain": [ 176 | "[OrderedDict([('year', 2000), ('value', 6)]),\n", 177 | " OrderedDict([('year', 
2001), ('value', 7)])]" 178 | ] 179 | }, 180 | "execution_count": 5, 181 | "metadata": {}, 182 | "output_type": "execute_result" 183 | } 184 | ], 185 | "source": [ 186 | "params.b" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "Now, if we use `array_first` and [`label_to_extend`](/api/extend/), the params instance can be loaded into a Pandas\n", 194 | "DataFrame like this:\n" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 6, 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "data": { 204 | "text/html": [ 205 | "
\n", 206 | "\n", 219 | "\n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | "
ab
056
167
267
367
467
567
\n", 260 | "
" 261 | ], 262 | "text/plain": [ 263 | " a b\n", 264 | "0 5 6\n", 265 | "1 6 7\n", 266 | "2 6 7\n", 267 | "3 6 7\n", 268 | "4 6 7\n", 269 | "5 6 7" 270 | ] 271 | }, 272 | "execution_count": 6, 273 | "metadata": {}, 274 | "output_type": "execute_result" 275 | } 276 | ], 277 | "source": [ 278 | "csv_string = \"\"\"\n", 279 | "year,a,b\n", 280 | "2000,5,6\\n\n", 281 | "2001,6,7\\n\n", 282 | "\"\"\"\n", 283 | "\n", 284 | "params = CSVParams(array_first=True, label_to_extend=\"year\")\n", 285 | "params.adjust(csv_string)\n", 286 | "\n", 287 | "params_df = pd.DataFrame.from_dict(params.to_dict())\n", 288 | "params_df\n" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 7, 294 | "metadata": {}, 295 | "outputs": [ 296 | { 297 | "data": { 298 | "text/html": [ 299 | "
\n", 300 | "\n", 313 | "\n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | "
ab
year
200056
200167
200267
200367
200467
200567
\n", 359 | "
" 360 | ], 361 | "text/plain": [ 362 | " a b\n", 363 | "year \n", 364 | "2000 5 6\n", 365 | "2001 6 7\n", 366 | "2002 6 7\n", 367 | "2003 6 7\n", 368 | "2004 6 7\n", 369 | "2005 6 7" 370 | ] 371 | }, 372 | "execution_count": 7, 373 | "metadata": {}, 374 | "output_type": "execute_result" 375 | } 376 | ], 377 | "source": [ 378 | "params_df[\"year\"] = params.label_grid[\"year\"]\n", 379 | "params_df.set_index(\"year\")" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": null, 385 | "metadata": {}, 386 | "outputs": [], 387 | "source": [] 388 | } 389 | ], 390 | "metadata": { 391 | "kernelspec": { 392 | "display_name": "Python 3", 393 | "language": "python", 394 | "name": "python3" 395 | }, 396 | "language_info": { 397 | "codemirror_mode": { 398 | "name": "ipython", 399 | "version": 3 400 | }, 401 | "file_extension": ".py", 402 | "mimetype": "text/x-python", 403 | "name": "python", 404 | "nbconvert_exporter": "python", 405 | "pygments_lexer": "ipython3", 406 | "version": "3.8.5" 407 | } 408 | }, 409 | "nbformat": 4, 410 | "nbformat_minor": 4 411 | } 412 | -------------------------------------------------------------------------------- /docs/api/custom-types.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Custom Types\n", 8 | "\n", 9 | "Often, the behavior for a field needs to be customized to support a particular shape or validation method that ParamTools does not support out of the box. In this case, you may use the `register_custom_type` function to add your new `type` to the ParamTools type registry. Each `type` has a corresponding `field` that is used for serialization and deserialization. 
ParamTools will then use this `field` any time it is handling a `value`, `label`, or `member` that is of this `type`.\n", 10 | "\n", 11 | "ParamTools is built on top of [`marshmallow`](https://github.com/marshmallow-code/marshmallow), a general purpose validation library. This means that you must implement a custom `marshmallow` field to go along with your new type. Please refer to the `marshmallow` [docs](https://marshmallow.readthedocs.io/en/stable/) if you have questions about the use of `marshmallow` in the examples below.\n", 12 | "\n", 13 | "\n", 14 | "## 32 Bit Integer Example\n", 15 | "\n", 16 | "ParamTools's default integer field uses NumPy's `int64` type. This example shows you how to define an `int32` type and reference it in your `defaults`.\n", 17 | "\n", 18 | "First, let's define the Marshmallow class:\n" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import marshmallow as ma\n", 28 | "import numpy as np\n", 29 | "\n", 30 | "class Int32(ma.fields.Field):\n", 31 | " \"\"\"\n", 32 | " A custom type for np.int32.\n", 33 | " https://numpy.org/devdocs/reference/arrays.dtypes.html\n", 34 | " \"\"\"\n", 35 | " # minor detail that makes this play nice with array_first\n", 36 | " np_type = np.int32\n", 37 | "\n", 38 | " def _serialize(self, value, *args, **kwargs):\n", 39 | " \"\"\"Convert np.int32 to basic, serializable Python int.\"\"\"\n", 40 | " return value.tolist()\n", 41 | "\n", 42 | " def _deserialize(self, value, *args, **kwargs):\n", 43 | " \"\"\"Cast value from JSON to NumPy Int32.\"\"\"\n", 44 | " converted = np.int32(value)\n", 45 | " return converted\n" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "Now, reference it in our defaults JSON/dict object:\n" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | 
"output_type": "stream", 63 | "text": [ 64 | "value: 2, type: \n" 65 | ] 66 | } 67 | ], 68 | "source": [ 69 | "import paramtools as pt\n", 70 | "\n", 71 | "\n", 72 | "# add int32 type to the paramtools type registry\n", 73 | "pt.register_custom_type(\"int32\", Int32())\n", 74 | "\n", 75 | "\n", 76 | "class Params(pt.Parameters):\n", 77 | " defaults = {\n", 78 | " \"small_int\": {\n", 79 | " \"title\": \"Small integer\",\n", 80 | " \"description\": \"Demonstrate how to define a custom type\",\n", 81 | " \"type\": \"int32\",\n", 82 | " \"value\": 2\n", 83 | " }\n", 84 | " }\n", 85 | "\n", 86 | "\n", 87 | "params = Params(array_first=True)\n", 88 | "\n", 89 | "\n", 90 | "print(f\"value: {params.small_int}, type: {type(params.small_int)}\")\n", 91 | "\n" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "One problem with this is that we could run into some deserialization issues. Due to integer overflow, our deserialized result is not the number that we passed in--it's negative!\n" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 3, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "OrderedDict([('small_int', [OrderedDict([('value', -2147483648)])])])" 110 | ] 111 | }, 112 | "execution_count": 3, 113 | "metadata": {}, 114 | "output_type": "execute_result" 115 | } 116 | ], 117 | "source": [ 118 | "params.adjust(dict(\n", 119 | " # this number wasn't chosen randomly.\n", 120 | " small_int=2147483647 + 1\n", 121 | "))\n" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "### Marshmallow Validator\n", 129 | "\n", 130 | "Fortunately, you can specify a custom validator with `marshmallow` or ParamTools. 
Making this works requires modifying the `_deserialize` method to check for overflow like this:\n" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 4, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "class Int32(ma.fields.Field):\n", 140 | " \"\"\"\n", 141 | " A custom type for np.int32.\n", 142 | " https://numpy.org/devdocs/reference/arrays.dtypes.html\n", 143 | " \"\"\"\n", 144 | " # minor detail that makes this play nice with array_first\n", 145 | " np_type = np.int32\n", 146 | "\n", 147 | " def _serialize(self, value, *args, **kwargs):\n", 148 | " \"\"\"Convert np.int32 to basic Python int.\"\"\"\n", 149 | " return value.tolist()\n", 150 | "\n", 151 | " def _deserialize(self, value, *args, **kwargs):\n", 152 | " \"\"\"Cast value from JSON to NumPy Int32.\"\"\"\n", 153 | " converted = np.int32(value)\n", 154 | "\n", 155 | " # check for overflow and let range validator\n", 156 | " # display the error message.\n", 157 | " if converted != int(value):\n", 158 | " return int(value)\n", 159 | "\n", 160 | " return converted\n" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | " Now, let's see how to use `marshmallow` to fix this problem:\n" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 9, 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "ename": "ValidationError", 177 | "evalue": "{\n \"errors\": {\n \"small_int\": [\n \"Must be greater than or equal to -2147483648 and less than or equal to 2147483647.\"\n ]\n }\n}", 178 | "output_type": "error", 179 | "traceback": [ 180 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 181 | "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)", 182 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mParams\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray_first\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m params.adjust(dict(\n\u001b[0m\u001b[1;32m 32\u001b[0m \u001b[0msmall_int\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint64\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_int32\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m ))\n", 183 | "\u001b[0;32m~/ParamTools/paramtools/parameters.py\u001b[0m in \u001b[0;36madjust\u001b[0;34m(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0mleast\u001b[0m \u001b[0mone\u001b[0m \u001b[0mexisting\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0mitem\u001b[0m\u001b[0;31m'\u001b[0m\u001b[0ms\u001b[0m \u001b[0mcorresponding\u001b[0m \u001b[0mlabel\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \"\"\"\n\u001b[0;32m--> 207\u001b[0;31m return self._adjust(\n\u001b[0m\u001b[1;32m 208\u001b[0m \u001b[0mparams_or_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0mignore_warnings\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mignore_warnings\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 184 | "\u001b[0;32m~/ParamTools/paramtools/parameters.py\u001b[0m in \u001b[0;36m_adjust\u001b[0;34m(self, params_or_path, ignore_warnings, raise_errors, extend_adj, is_deserialized, clobber)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mignore_warnings\u001b[0m \u001b[0;32mand\u001b[0m 
\u001b[0mhas_warnings\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 334\u001b[0m ):\n\u001b[0;32m--> 335\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation_error\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 336\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 337\u001b[0m \u001b[0;31m# Update attrs for params that were adjusted.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 185 | "\u001b[0;31mValidationError\u001b[0m: {\n \"errors\": {\n \"small_int\": [\n \"Must be greater than or equal to -2147483648 and less than or equal to 2147483647.\"\n ]\n }\n}" 186 | ] 187 | } 188 | ], 189 | "source": [ 190 | "import marshmallow as ma\n", 191 | "import paramtools as pt\n", 192 | "\n", 193 | "\n", 194 | "# get the minimum and maxium values for 32 bit integers.\n", 195 | "min_int32 = -2147483648 # = np.iinfo(np.int32).min\n", 196 | "max_int32 = 2147483647 # = np.iinfo(np.int32).max\n", 197 | "\n", 198 | "# add int32 type to the paramtools type registry\n", 199 | "pt.register_custom_type(\n", 200 | " \"int32\",\n", 201 | " Int32(validate=[\n", 202 | " ma.validate.Range(min=min_int32, max=max_int32)\n", 203 | " ])\n", 204 | ")\n", 205 | "\n", 206 | "\n", 207 | "class Params(pt.Parameters):\n", 208 | " defaults = {\n", 209 | " \"small_int\": {\n", 210 | " \"title\": \"Small integer\",\n", 211 | " \"description\": \"Demonstrate how to define a custom type\",\n", 212 | " \"type\": \"int32\",\n", 213 | " \"value\": 2\n", 214 | " }\n", 215 | " }\n", 216 | "\n", 217 | "\n", 218 | "params = Params(array_first=True)\n", 219 | "\n", 220 | "params.adjust(dict(\n", 221 | " small_int=np.int64(max_int32) + 1\n", 222 | "))\n" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": {}, 228 | "source": [ 229 | "### ParamTools Validator\n", 230 | "\n", 231 | "Finally, we will use ParamTools to solve this problem. 
We need to modify how we create our custom `marshmallow` field so that it's wrapped by ParamTools's `PartialField`. This makes it clear that your field still needs to be initialized, and that your custom field is able to receive validation information from the `defaults` configuration:\n" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 7, 237 | "metadata": {}, 238 | "outputs": [ 239 | { 240 | "ename": "ValidationError", 241 | "evalue": "{\n \"errors\": {\n \"small_int\": [\n \"small_int 2147483648 > max 2147483647 \"\n ]\n }\n}", 242 | "output_type": "error", 243 | "traceback": [ 244 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 245 | "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)", 246 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mParams\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray_first\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m params.adjust(dict(\n\u001b[0m\u001b[1;32m 28\u001b[0m \u001b[0msmall_int\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2147483647\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m ))\n", 247 | "\u001b[0;32m~/ParamTools/paramtools/parameters.py\u001b[0m in \u001b[0;36madjust\u001b[0;34m(self, params_or_path, ignore_warnings, raise_errors, extend_adj, clobber)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0mleast\u001b[0m \u001b[0mone\u001b[0m \u001b[0mexisting\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0mitem\u001b[0m\u001b[0;31m'\u001b[0m\u001b[0ms\u001b[0m \u001b[0mcorresponding\u001b[0m \u001b[0mlabel\u001b[0m 
\u001b[0mvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \"\"\"\n\u001b[0;32m--> 207\u001b[0;31m return self._adjust(\n\u001b[0m\u001b[1;32m 208\u001b[0m \u001b[0mparams_or_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[0mignore_warnings\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mignore_warnings\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 248 | "\u001b[0;32m~/ParamTools/paramtools/parameters.py\u001b[0m in \u001b[0;36m_adjust\u001b[0;34m(self, params_or_path, ignore_warnings, raise_errors, extend_adj, is_deserialized, clobber)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mignore_warnings\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mhas_warnings\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 334\u001b[0m ):\n\u001b[0;32m--> 335\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation_error\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 336\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 337\u001b[0m \u001b[0;31m# Update attrs for params that were adjusted.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 249 | "\u001b[0;31mValidationError\u001b[0m: {\n \"errors\": {\n \"small_int\": [\n \"small_int 2147483648 > max 2147483647 \"\n ]\n }\n}" 250 | ] 251 | } 252 | ], 253 | "source": [ 254 | "import paramtools as pt\n", 255 | "\n", 256 | "\n", 257 | "# add int32 type to the paramtools type registry\n", 258 | "pt.register_custom_type(\n", 259 | " \"int32\",\n", 260 | " pt.PartialField(Int32)\n", 261 | ")\n", 262 | "\n", 263 | "\n", 264 | "class Params(pt.Parameters):\n", 265 | " defaults = {\n", 266 | " \"small_int\": {\n", 267 | " \"title\": \"Small integer\",\n", 268 | " \"description\": \"Demonstrate how to define a 
custom type\",\n", 269 | " \"type\": \"int32\",\n", 270 | " \"value\": 2,\n", 271 | " \"validators\": {\n", 272 | " \"range\": {\"min\": -2147483648, \"max\": 2147483647}\n", 273 | " }\n", 274 | " }\n", 275 | " }\n", 276 | "\n", 277 | "\n", 278 | "params = Params(array_first=True)\n", 279 | "\n", 280 | "params.adjust(dict(\n", 281 | " small_int=2147483647 + 1\n", 282 | "))\n", 283 | "\n" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "metadata": {}, 290 | "outputs": [], 291 | "source": [] 292 | } 293 | ], 294 | "metadata": { 295 | "kernelspec": { 296 | "display_name": "Python 3", 297 | "language": "python", 298 | "name": "python3" 299 | }, 300 | "language_info": { 301 | "codemirror_mode": { 302 | "name": "ipython", 303 | "version": 3 304 | }, 305 | "file_extension": ".py", 306 | "mimetype": "text/x-python", 307 | "name": "python", 308 | "nbconvert_exporter": "python", 309 | "pygments_lexer": "ipython3", 310 | "version": "3.8.5" 311 | } 312 | }, 313 | "nbformat": 4, 314 | "nbformat_minor": 4 315 | } 316 | -------------------------------------------------------------------------------- /docs/api/extend.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Extend\n", 8 | "\n", 9 | "The values of a parameter can be extended along a specified label. This is helpful when a parameter's values are the same for different values of a label and there is some inherent order in that label. 
The extend feature allows you to simply write down the minimum amount of information needed to fill in a parameter's values and ParamTools will fill in the gaps.\n", 10 | "\n", 11 | "To use the extend feature, set the `label_to_extend` class attribute to the label that should be extended.\n", 12 | "\n", 13 | "## Example\n", 14 | "\n", 15 | "The standard deduction parameter's values only need to be specified when there is a change in the tax law. For the other years, it does not change (unless it's indexed to inflation). It would be annoying to have to manually write out each of its values. Instead, we can more concisely write its values in 2017, its new values in 2018 after the TCJA tax reform was passed, and its values after provisions of the TCJA are phased out in 2026.\n" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "data": { 25 | "text/plain": [ 26 | "array([[ 6350., 12700.],\n", 27 | " [ 6350., 12700.],\n", 28 | " [ 6350., 12700.],\n", 29 | " [ 6350., 12700.],\n", 30 | " [ 6350., 12700.],\n", 31 | " [12000., 24000.],\n", 32 | " [12000., 24000.],\n", 33 | " [12000., 24000.],\n", 34 | " [12000., 24000.],\n", 35 | " [12000., 24000.],\n", 36 | " [12000., 24000.],\n", 37 | " [12000., 24000.],\n", 38 | " [12000., 24000.],\n", 39 | " [ 7685., 15369.],\n", 40 | " [ 7685., 15369.]])" 41 | ] 42 | }, 43 | "execution_count": 1, 44 | "metadata": {}, 45 | "output_type": "execute_result" 46 | } 47 | ], 48 | "source": [ 49 | "import paramtools\n", 50 | "\n", 51 | "\n", 52 | "class TaxParams(paramtools.Parameters):\n", 53 | " defaults = {\n", 54 | " \"schema\": {\n", 55 | " \"labels\": {\n", 56 | " \"year\": {\n", 57 | " \"type\": \"int\",\n", 58 | " \"validators\": {\"range\": {\"min\": 2013, \"max\": 2027}}\n", 59 | " },\n", 60 | " \"marital_status\": {\n", 61 | " \"type\": \"str\",\n", 62 | " \"validators\": {\"choice\": {\"choices\": [\"single\", \"joint\"]}}\n", 63 | " },\n", 64 | " }\n", 65 |
" },\n", 66 | " \"standard_deduction\": {\n", 67 | " \"title\": \"Standard deduction amount\",\n", 68 | " \"description\": \"Amount filing unit can use as a standard deduction.\",\n", 69 | " \"type\": \"float\",\n", 70 | " \"value\": [\n", 71 | " {\"year\": 2017, \"marital_status\": \"single\", \"value\": 6350},\n", 72 | " {\"year\": 2017, \"marital_status\": \"joint\", \"value\": 12700},\n", 73 | " {\"year\": 2018, \"marital_status\": \"single\", \"value\": 12000},\n", 74 | " {\"year\": 2018, \"marital_status\": \"joint\", \"value\": 24000},\n", 75 | " {\"year\": 2026, \"marital_status\": \"single\", \"value\": 7685},\n", 76 | " {\"year\": 2026, \"marital_status\": \"joint\", \"value\": 15369}],\n", 77 | " \"validators\": {\n", 78 | " \"range\": {\n", 79 | " \"min\": 0,\n", 80 | " \"max\": 9e+99\n", 81 | " }\n", 82 | " }\n", 83 | " },\n", 84 | " }\n", 85 | "\n", 86 | " label_to_extend = \"year\"\n", 87 | " array_first = True\n", 88 | "\n", 89 | "params = TaxParams()\n", 90 | "params.standard_deduction\n" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "Adjustments are also extended along `label_to_extend`. 
In the example below, `standard_deduction` is set to 10,000 in 2017, increased to 15,000 for single tax units in 2020, and increased to 20,000 for joint tax units in 2021:\n" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 2, 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "data": { 107 | "text/plain": [ 108 | "array([[ 6350., 12700.],\n", 109 | " [ 6350., 12700.],\n", 110 | " [ 6350., 12700.],\n", 111 | " [ 6350., 12700.],\n", 112 | " [10000., 10000.],\n", 113 | " [10000., 10000.],\n", 114 | " [10000., 10000.],\n", 115 | " [15000., 10000.],\n", 116 | " [15000., 20000.],\n", 117 | " [15000., 20000.],\n", 118 | " [15000., 20000.],\n", 119 | " [15000., 20000.],\n", 120 | " [15000., 20000.],\n", 121 | " [15000., 20000.],\n", 122 | " [15000., 20000.]])" 123 | ] 124 | }, 125 | "execution_count": 2, 126 | "metadata": {}, 127 | "output_type": "execute_result" 128 | } 129 | ], 130 | "source": [ 131 | "params.adjust(\n", 132 | " {\n", 133 | " \"standard_deduction\": [\n", 134 | " {\"year\": 2017, \"value\": 10000},\n", 135 | " {\"year\": 2020, \"marital_status\": \"single\", \"value\": 15000},\n", 136 | " {\"year\": 2021, \"marital_status\": \"joint\", \"value\": 20000}\n", 137 | " ]\n", 138 | " }\n", 139 | ")\n", 140 | "\n", 141 | "params.standard_deduction\n" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "### Clobber\n", 149 | "\n", 150 | "In the previous example, the new values _clobber_ the existing values in years after they are specified. By setting `clobber` to `False`, only values that were added automatically will be replaced by the new ones. 
User-defined values such as those in 2026 will not be overwritten by the new values:\n" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 3, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "data": { 160 | "text/plain": [ 161 | "array([[ 6350., 12700.],\n", 162 | " [ 6350., 12700.],\n", 163 | " [ 6350., 12700.],\n", 164 | " [ 6350., 12700.],\n", 165 | " [10000., 10000.],\n", 166 | " [12000., 24000.],\n", 167 | " [12000., 24000.],\n", 168 | " [15000., 24000.],\n", 169 | " [15000., 20000.],\n", 170 | " [15000., 20000.],\n", 171 | " [15000., 20000.],\n", 172 | " [15000., 20000.],\n", 173 | " [15000., 20000.],\n", 174 | " [ 7685., 15369.],\n", 175 | " [ 7685., 15369.]])" 176 | ] 177 | }, 178 | "execution_count": 3, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "params = TaxParams()\n", 185 | "params.adjust(\n", 186 | " {\n", 187 | " \"standard_deduction\": [\n", 188 | " {\"year\": 2017, \"value\": 10000},\n", 189 | " {\"year\": 2020, \"marital_status\": \"single\", \"value\": 15000},\n", 190 | " {\"year\": 2021, \"marital_status\": \"joint\", \"value\": 20000}\n", 191 | " ]\n", 192 | " },\n", 193 | " clobber=False,\n", 194 | ")\n", 195 | "\n", 196 | "params.standard_deduction\n" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": {}, 202 | "source": [ 203 | "## Extend behavior by validator\n", 204 | "\n", 205 | "ParamTools uses the validator associated with `label_to_extend` to determine how values should be extended by assuming that there is some order among the range of possible values for the label.\n", 206 | "\n", 207 | "Note: You can view the grid of values for any label by inspecting the `label_grid` attribute of a `paramtools.Parameters` derived instance.\n", 208 | "\n", 209 | "### Range\n", 210 | "\n", 211 | "**Type:** `int`\n", 212 | "\n", 213 | "```json\n", 214 | "{\n", 215 | " \"range\": { \"min\": 0, \"max\": 5 }\n", 216 | "}\n", 217 | "```\n", 218
| "\n", 219 | "_Extend values:_\n", 220 | "\n", 221 | "```python\n", 222 | "[0, 1, 2, 3, 4, 5]\n", 223 | "```\n", 224 | "\n", 225 | "**Type:** `float`\n", 226 | "\n", 227 | "```json\n", 228 | "{\n", 229 | " \"range\": { \"min\": 0, \"max\": 2, \"step\": 0.5 }\n", 230 | "}\n", 231 | "```\n", 232 | "\n", 233 | "_Extend values:_\n", 234 | "\n", 235 | "```python\n", 236 | "[0, 0.5, 1.0, 1.5, 2.0]\n", 237 | "```\n", 238 | "\n", 239 | "**Type:** `date`\n", 240 | "\n", 241 | "```json\n", 242 | "{\n", 243 | " \"range\": { \"min\": \"2019-01-01\", \"max\": \"2019-01-05\", \"step\": { \"days\": 2 } }\n", 244 | "}\n", 245 | "```\n", 246 | "\n", 247 | "_Extend values:_\n", 248 | "\n", 249 | "```python\n", 250 | "[datetime.date(2019, 1, 1),\n", 251 | " datetime.date(2019, 1, 3),\n", 252 | " datetime.date(2019, 1, 5)]\n", 253 | "```\n", 254 | "\n", 255 | "### Choice\n", 256 | "\n", 257 | "**Type:** `int`\n", 258 | "\n", 259 | "```json\n", 260 | "{\n", 261 | " \"choice\": { \"choices\": [-1, -2, -3] }\n", 262 | "}\n", 263 | "```\n", 264 | "\n", 265 | "_Extend values:_\n", 266 | "\n", 267 | "```python\n", 268 | "[-1, -2, -3]\n", 269 | "```\n", 270 | "\n", 271 | "**Type:** `str`\n", 272 | "\n", 273 | "```json\n", 274 | "{\n", 275 | " \"choice\": { \"choices\": [\"january\", \"february\", \"march\"] }\n", 276 | "}\n", 277 | "```\n", 278 | "\n", 279 | "_Extend values:_\n", 280 | "\n", 281 | "```python\n", 282 | "[\"january\", \"february\", \"march\"]\n", 283 | "```\n" 284 | ] 285 | } 286 | ], 287 | "metadata": { 288 | "kernelspec": { 289 | "display_name": "Python 3", 290 | "language": "python", 291 | "name": "python3" 292 | }, 293 | "language_info": { 294 | "codemirror_mode": { 295 | "name": "ipython", 296 | "version": 3 297 | }, 298 | "file_extension": ".py", 299 | "mimetype": "text/x-python", 300 | "name": "python", 301 | "nbconvert_exporter": "python", 302 | "pygments_lexer": "ipython3", 303 | "version": "3.8.5" 304 | } 305 | }, 306 | "nbformat": 4, 307 | "nbformat_minor": 4 308 | 
} 309 | -------------------------------------------------------------------------------- /docs/api/guide.md: -------------------------------------------------------------------------------- 1 | # Guide 2 | 3 | **Just getting started?** 4 | Check out the [home page](/intro/) for a quick start guide. 5 | 6 | **Want to learn more about the JSON spec?** 7 | Check out the [parameters spec](/parameters/) to learn more about what types of parameters and validators are supported by ParamTools. 8 | 9 | ## API 10 | 11 | [**Viewing data**](/api/viewing-data/) Working with parameter values. 12 | 13 | [**Extend:**](/api/extend/) Write more concise JSON schemas with the extend capability. 14 | 15 | [**Extend with Indexing:**](/api/indexing/) Index parameter values. 16 | 17 | [**Custom Adjustment Formats and Logic**](/api/custom-adjust/) Customize the adjustment format and logic to meet your project's needs. 18 | 19 | [**Custom types**](/api/custom-types/) Add custom types for your parameters' values, labels, and members. 20 | -------------------------------------------------------------------------------- /docs/api/indexing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Extend with Indexing\n", 8 | "\n", 9 | "ParamTools provides out-of-the-box parameter indexing. This is helpful for projects that have parameters that change at some rate over time. For example, tax parameters like the standard deduction are often indexed to price inflation. 
So, the value of the standard deduction actually increases every year by 1 or 2% depending on that year's inflation rate.\n", 10 | "\n", 11 | "The [extend documentation](/api/extend/) may be useful for gaining a better understanding of how ParamTools extends parameter values along `label_to_extend`.\n", 12 | "\n", 13 | "To use the indexing feature:\n", 14 | "\n", 15 | "- Set the `label_to_extend` class attribute to the label that should be extended\n", 16 | "- Set the `index_rates` class attribute to a dictionary of inflation rates where the keys correspond to the value of `label_to_extend` and the values are the indexing rates.\n", 17 | "- Set the `uses_extend_func` class attribute to `True`.\n", 18 | "- In `defaults` or `defaults.json`, set `indexed` to `True` for each parameter that needs to be indexed.\n", 19 | "\n", 20 | "## Example\n", 21 | "\n", 22 | "This is a continuation of the tax parameters example from the [extend documentation](/api/extend/). The differences are that `indexed` is set to `True` for the `standard_deduction` parameter, `uses_extend_func` is set to `True`, and `index_rates` is specified with inflation rates obtained from the open-source tax modeling package, [Tax-Calculator](https://github.com/PSLmodels/Tax-Calculator/), using version 2.5.0.\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "data": { 32 | "text/plain": [ 33 | "array([[ 6074.92, 12149.84],\n", 34 | " [ 6164.83, 12329.66],\n", 35 | " [ 6262.85, 12525.7 ],\n", 36 | " [ 6270.37, 12540.73],\n", 37 | " [ 6350. , 12700. ],\n", 38 | " [12000. , 24000. ],\n", 39 | " [12268.8 , 24537.6 ],\n", 40 | " [12497. , 24994. ],\n", 41 | " [12788.18, 25576.36],\n", 42 | " [13081.03, 26162.06],\n", 43 | " [13379.28, 26758.55],\n", 44 | " [13674.96, 27349.91],\n", 45 | " [13963.5 , 27926.99],\n", 46 | " [ 7685. , 15369.
],\n", 47 | " [ 7847.15, 15693.29]])" 48 | ] 49 | }, 50 | "execution_count": 1, 51 | "metadata": {}, 52 | "output_type": "execute_result" 53 | } 54 | ], 55 | "source": [ 56 | "import paramtools\n", 57 | "\n", 58 | "\n", 59 | "class TaxParams(paramtools.Parameters):\n", 60 | " defaults = {\n", 61 | " \"schema\": {\n", 62 | " \"labels\": {\n", 63 | " \"year\": {\n", 64 | " \"type\": \"int\",\n", 65 | " \"validators\": {\"range\": {\"min\": 2013, \"max\": 2027}}\n", 66 | " },\n", 67 | " \"marital_status\": {\n", 68 | " \"type\": \"str\",\n", 69 | " \"validators\": {\"choice\": {\"choices\": [\"single\", \"joint\"]}}\n", 70 | " },\n", 71 | " }\n", 72 | " },\n", 73 | " \"standard_deduction\": {\n", 74 | " \"title\": \"Standard deduction amount\",\n", 75 | " \"description\": \"Amount filing unit can use as a standard deduction.\",\n", 76 | " \"type\": \"float\",\n", 77 | "\n", 78 | " # Set indexed to True to extend standard_deduction with the built-in\n", 79 | " # extension logic.\n", 80 | " \"indexed\": True,\n", 81 | "\n", 82 | " \"value\": [\n", 83 | " {\"year\": 2017, \"marital_status\": \"single\", \"value\": 6350},\n", 84 | " {\"year\": 2017, \"marital_status\": \"joint\", \"value\": 12700},\n", 85 | " {\"year\": 2018, \"marital_status\": \"single\", \"value\": 12000},\n", 86 | " {\"year\": 2018, \"marital_status\": \"joint\", \"value\": 24000},\n", 87 | " {\"year\": 2026, \"marital_status\": \"single\", \"value\": 7685},\n", 88 | " {\"year\": 2026, \"marital_status\": \"joint\", \"value\": 15369}],\n", 89 | " \"validators\": {\n", 90 | " \"range\": {\n", 91 | " \"min\": 0,\n", 92 | " \"max\": 9e+99\n", 93 | " }\n", 94 | " }\n", 95 | " },\n", 96 | " }\n", 97 | " array_first = True\n", 98 | " label_to_extend = \"year\"\n", 99 | " # Activate use of extend_func method.\n", 100 | " uses_extend_func = True\n", 101 | " # inflation rates from Tax-Calculator v2.5.0\n", 102 | " index_rates = {\n", 103 | " 2013: 0.0148,\n", 104 | " 2014: 0.0159,\n", 105 | " 2015: 0.0012,\n", 
106 | " 2016: 0.0127,\n", 107 | " 2017: 0.0187,\n", 108 | " 2018: 0.0224,\n", 109 | " 2019: 0.0186,\n", 110 | " 2020: 0.0233,\n", 111 | " 2021: 0.0229,\n", 112 | " 2022: 0.0228,\n", 113 | " 2023: 0.0221,\n", 114 | " 2024: 0.0211,\n", 115 | " 2025: 0.0209,\n", 116 | " 2026: 0.0211,\n", 117 | " 2027: 0.0208,\n", 118 | " 2028: 0.021,\n", 119 | " 2029: 0.021\n", 120 | " }\n", 121 | "\n", 122 | "\n", 123 | "params = TaxParams()\n", 124 | "params.standard_deduction\n" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "Adjustments are also indexed. In the example below, `standard_deduction` is set to 10,000 in 2017, increased to 15,000 for single tax units in 2020, and increased to 20,000 for joint tax units in 2021:\n" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 2, 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "data": { 141 | "text/plain": [ 142 | "array([[ 6074.92, 12149.84],\n", 143 | " [ 6164.83, 12329.66],\n", 144 | " [ 6262.85, 12525.7 ],\n", 145 | " [ 6270.37, 12540.73],\n", 146 | " [10000. , 10000. ],\n", 147 | " [10187. , 10187. ],\n", 148 | " [10415.19, 10415.19],\n", 149 | " [15000. , 10608.91],\n", 150 | " [15349.5 , 20000. ],\n", 151 | " [15701. , 20458. 
],\n", 152 | " [16058.98, 20924.44],\n", 153 | " [16413.88, 21386.87],\n", 154 | " [16760.21, 21838.13],\n", 155 | " [17110.5 , 22294.55],\n", 156 | " [17471.53, 22764.97]])" 157 | ] 158 | }, 159 | "execution_count": 2, 160 | "metadata": {}, 161 | "output_type": "execute_result" 162 | } 163 | ], 164 | "source": [ 165 | "params.adjust(\n", 166 | " {\n", 167 | " \"standard_deduction\": [\n", 168 | " {\"year\": 2017, \"value\": 10000},\n", 169 | " {\"year\": 2020, \"marital_status\": \"single\", \"value\": 15000},\n", 170 | " {\"year\": 2021, \"marital_status\": \"joint\", \"value\": 20000}\n", 171 | " ]\n", 172 | " }\n", 173 | ")\n", 174 | "\n", 175 | "params.standard_deduction\n" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "All values that are added automatically via the `extend` method are given an `_auto` attribute. You can select them like this:\n" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 3, 188 | "metadata": {}, 189 | "outputs": [ 190 | { 191 | "data": { 192 | "text/plain": [ 193 | "[{'year': 2013, 'marital_status': 'single', 'value': 6074.92, '_auto': True},\n", 194 | " {'year': 2014, 'marital_status': 'single', 'value': 6164.83, '_auto': True},\n", 195 | " {'year': 2015, 'marital_status': 'single', 'value': 6262.85, '_auto': True},\n", 196 | " {'year': 2016, 'marital_status': 'single', 'value': 6270.37, '_auto': True},\n", 197 | " {'year': 2019, 'marital_status': 'single', 'value': 12268.8, '_auto': True},\n", 198 | " {'year': 2020, 'marital_status': 'single', 'value': 12497.0, '_auto': True},\n", 199 | " {'year': 2021, 'marital_status': 'single', 'value': 12788.18, '_auto': True},\n", 200 | " {'year': 2022, 'marital_status': 'single', 'value': 13081.03, '_auto': True},\n", 201 | " {'year': 2023, 'marital_status': 'single', 'value': 13379.28, '_auto': True},\n", 202 | " {'year': 2024, 'marital_status': 'single', 'value': 13674.96, '_auto': True},\n", 203 | " {'year': 
2025, 'marital_status': 'single', 'value': 13963.5, '_auto': True},\n", 204 | " {'year': 2027, 'marital_status': 'single', 'value': 7847.15, '_auto': True},\n", 205 | " {'year': 2013, 'marital_status': 'joint', 'value': 12149.84, '_auto': True},\n", 206 | " {'year': 2014, 'marital_status': 'joint', 'value': 12329.66, '_auto': True},\n", 207 | " {'year': 2015, 'marital_status': 'joint', 'value': 12525.7, '_auto': True},\n", 208 | " {'year': 2016, 'marital_status': 'joint', 'value': 12540.73, '_auto': True},\n", 209 | " {'year': 2019, 'marital_status': 'joint', 'value': 24537.6, '_auto': True},\n", 210 | " {'year': 2020, 'marital_status': 'joint', 'value': 24994.0, '_auto': True},\n", 211 | " {'year': 2021, 'marital_status': 'joint', 'value': 25576.36, '_auto': True},\n", 212 | " {'year': 2022, 'marital_status': 'joint', 'value': 26162.06, '_auto': True},\n", 213 | " {'year': 2023, 'marital_status': 'joint', 'value': 26758.55, '_auto': True},\n", 214 | " {'year': 2024, 'marital_status': 'joint', 'value': 27349.91, '_auto': True},\n", 215 | " {'year': 2025, 'marital_status': 'joint', 'value': 27926.99, '_auto': True},\n", 216 | " {'year': 2027, 'marital_status': 'joint', 'value': 15693.29, '_auto': True}]" 217 | ] 218 | }, 219 | "execution_count": 3, 220 | "metadata": {}, 221 | "output_type": "execute_result" 222 | } 223 | ], 224 | "source": [ 225 | "params = TaxParams()\n", 226 | "\n", 227 | "params.select_eq(\n", 228 | " \"standard_deduction\", strict=True, _auto=True\n", 229 | ")\n" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": {}, 235 | "source": [ 236 | "If you want to update the index rates and apply them to your existing values, then all you need to do is remove the values that were added automatically. 
ParamTools will fill in the missing values using the updated index rates:\n" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 4, 242 | "metadata": { 243 | "scrolled": false 244 | }, 245 | "outputs": [ 246 | { 247 | "data": { 248 | "text/plain": [ 249 | "array([[ 6015.22, 12030.41],\n", 250 | " [ 6119.28, 12238.54],\n", 251 | " [ 6231.87, 12463.73],\n", 252 | " [ 6254.93, 12509.85],\n", 253 | " [ 6350. , 12700. ],\n", 254 | " [12000. , 24000. ],\n", 255 | " [12298.8 , 24597.6 ],\n", 256 | " [12558.3 , 25116.61],\n", 257 | " [12882.3 , 25764.62],\n", 258 | " [13209.51, 26419.04],\n", 259 | " [13543.71, 27087.44],\n", 260 | " [13876.89, 27753.79],\n", 261 | " [14204.38, 28408.78],\n", 262 | " [ 7685. , 15369. ],\n", 263 | " [ 7866.37, 15731.71]])" 264 | ] 265 | }, 266 | "execution_count": 4, 267 | "metadata": {}, 268 | "output_type": "execute_result" 269 | } 270 | ], 271 | "source": [ 272 | "params = TaxParams()\n", 273 | "\n", 274 | "offset = 0.0025\n", 275 | "for year, rate in params.index_rates.items():\n", 276 | " params.index_rates[year] = rate + offset\n", 277 | "\n", 278 | "automatically_added = params.select_eq(\n", 279 | " \"standard_deduction\", strict=True, _auto=True\n", 280 | ")\n", 281 | "\n", 282 | "params.delete(\n", 283 | " {\n", 284 | " \"standard_deduction\": automatically_added\n", 285 | " }\n", 286 | ")\n", 287 | "\n", 288 | "params.standard_deduction\n" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "### Code for getting Tax-Calculator index rates\n", 296 | "\n", 297 | "```python\n", 298 | "import taxcalc\n", 299 | "pol = taxcalc.Policy()\n", 300 | "index_rates = {\n", 301 | " year: value\n", 302 | " for year, value in zip(list(range(2013, 2029 + 1)), pol.inflation_rates())\n", 303 | "}\n", 304 | "```\n", 305 | "\n", 306 | "Note that there are some subtle details that are implemented in Tax-Calculator's indexing logic that are not implemented in this example. 
Tax-Calculator has a parameter called `CPI_offset` that adjusts inflation rates up or down by a fixed amount. The `indexed` property can also be turned on and off for each parameter. Implementing these nuanced features is left as the proverbial \"trivial exercise to the reader.\"\n" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [] 315 | } 316 | ], 317 | "metadata": { 318 | "kernelspec": { 319 | "display_name": "Python 3", 320 | "language": "python", 321 | "name": "python3" 322 | }, 323 | "language_info": { 324 | "codemirror_mode": { 325 | "name": "ipython", 326 | "version": 3 327 | }, 328 | "file_extension": ".py", 329 | "mimetype": "text/x-python", 330 | "name": "python", 331 | "nbconvert_exporter": "python", 332 | "pygments_lexer": "ipython3", 333 | "version": "3.8.5" 334 | } 335 | }, 336 | "nbformat": 4, 337 | "nbformat_minor": 4 338 | } 339 | -------------------------------------------------------------------------------- /docs/api/reference.rst: -------------------------------------------------------------------------------- 1 | .. _reference: 2 | 3 | API Reference 4 | ================================================= 5 | 6 | **parameters** 7 | 8 | Parameters 9 | ------------------------------------------ 10 | 11 | .. currentmodule:: paramtools.parameters 12 | 13 | .. autoclass:: Parameters 14 | :members: adjust, read_params, set_state, view_state, clear_state, specification, extend, extend_func, to_array, from_array, parse_labels, sort_values, validate, transaction 15 | 16 | Values 17 | ------------------------------------------ 18 | 19 | .. currentmodule:: paramtools.values 20 | 21 | .. 
autoclass:: Values 22 | :members: eq, ne, gt, lt, lte, isin, isel, __and__, __or__ 23 | 24 | -------------------------------------------------------------------------------- /docs/parameters.md: -------------------------------------------------------------------------------- 1 | # Parameters 2 | 3 | Define your default parameters and let ParamTools handle the rest. 4 | 5 | The ParamTools JSON file is split into two components: a component that defines the structure of your default inputs and a component that defines the variables that are used in your model. The first component is a top level member named `schema`. The second component consists of key-value pairs where the key is the parameter's name and the value is its data. 6 | 7 | ```json 8 | { 9 | "schema": { 10 | "labels": { 11 | "year": { 12 | "type": "int", 13 | "validators": {"range": {"min": 2013, "max": 2027}} 14 | }, 15 | "marital_status": { 16 | "type": "str", 17 | "validators": {"choice": {"choices": ["single", "joint", "separate", 18 | "headhousehold", "widow"]}} 19 | } 20 | }, 21 | "additional_members": { 22 | "cpi_inflatable": {"type": "bool"}, 23 | "cpi_inflated": {"type": "bool"} 24 | }, 25 | "operators": { 26 | "array_first": true, 27 | "label_to_extend": "year", 28 | "uses_extend_func": true 29 | } 30 | }, 31 | "personal_exemption": { 32 | "title": "Personal Exemption", 33 | "description": "A simple version of the personal exemption.", 34 | "cpi_inflatable": true, 35 | "cpi_inflated": true, 36 | "type": "float", 37 | "value": 0, 38 | "validators": { 39 | "range": { 40 | "min": 0 41 | } 42 | } 43 | }, 44 | "standard_deduction": { 45 | "title": "Standard deduction amount", 46 | "description": "Amount filing unit can use as a standard deduction.", 47 | "cpi_inflatable": true, 48 | "cpi_inflated": true, 49 | "type": "float", 50 | "value": [ 51 | {"year": 2024, "marital_status": "single", "value": 13673.68}, 52 | {"year": 2024, "marital_status": "joint", "value": 27347.36}, 53 | {"year": 2024, 
"marital_status": "separate", "value": 13673.68}, 54 | {"year": 2024, "marital_status": "headhousehold", "value": 20510.52}, 55 | {"year": 2024, "marital_status": "widow", "value": 27347.36}, 56 | {"year": 2025, "marital_status": "single", "value": 13967.66}, 57 | {"year": 2025, "marital_status": "joint", "value": 27935.33}, 58 | {"year": 2025, "marital_status": "separate", "value": 13967.66}, 59 | {"year": 2025, "marital_status": "headhousehold", "value": 20951.49}, 60 | {"year": 2025, "marital_status": "widow", "value": 27935.33}], 61 | "validators": { 62 | "range": { 63 | "min": 0, 64 | "level": "warn" 65 | } 66 | } 67 | } 68 | } 69 | ``` 70 | 71 | 72 | 73 | 74 | 75 | ## Parameters Schema 76 | 77 | ```json 78 | { 79 | "schema": { 80 | "labels": { 81 | "year": { 82 | "type": "int", 83 | "validators": {"range": {"min": 2013, "max": 2027}} 84 | } 85 | }, 86 | "additional_members": { 87 | "cpi_inflatable": {"type": "bool"}, 88 | "cpi_inflated": {"type": "bool"} 89 | }, 90 | "operators": { 91 | "array_first": true, 92 | "label_to_extend": "year", 93 | "uses_extend_func": true 94 | } 95 | } 96 | } 97 | ``` 98 | 99 | - `labels`: Labels are used for defining, accessing, and updating a parameter's values. 100 | 101 | - `additional_members`: Additional Members are parameter level members that are specific to your model. For example, "title" is a parameter level member that is required by ParamTools, but "cpi_inflated" is not. Therefore, "cpi_inflated" needs to be defined in `additional_members`. 102 | 103 | - `operators`: Operators affect how the data is read into and handled by the `Parameters` class: 104 | 105 | - `array_first`: If value is `true`, parameters' values will be accessed as arrays by default. 106 | 107 | - `label_to_extend`: The name of the label along which the missing values of the parameters will be extended. For more information, check out the [extend docs](/api/extend/). 
108 | 109 | - `uses_extend_func`: If value is `true`, special logic is applied to the values of the parameters as they are extended. For more information, check out the [indexing docs](/api/indexing/). 110 | 111 | 112 | ## Default Parameters 113 | 114 | ```json 115 | { 116 | "standard_deduction": { 117 | "title": "Standard deduction amount", 118 | "description": "Amount filing unit can use as a standard deduction.", 119 | "cpi_inflatable": true, 120 | "cpi_inflated": true, 121 | "type": "float", 122 | "number_dims": 0, 123 | "value": [ 124 | {"year": 2024, "marital_status": "single", "value": 13673.68}, 125 | {"year": 2024, "marital_status": "joint", "value": 27347.36}, 126 | {"year": 2024, "marital_status": "separate", "value": 13673.68}, 127 | {"year": 2024, "marital_status": "headhousehold", "value": 20510.52}, 128 | {"year": 2024, "marital_status": "widow", "value": 27347.36}, 129 | {"year": 2025, "marital_status": "single", "value": 13967.66}, 130 | {"year": 2025, "marital_status": "joint", "value": 27935.33}, 131 | {"year": 2025, "marital_status": "separate", "value": 13967.66}, 132 | {"year": 2025, "marital_status": "headhousehold", "value": 20951.49}, 133 | {"year": 2025, "marital_status": "widow", "value": 27935.33}], 134 | "validators": { 135 | "range": { 136 | "min": 0, 137 | "max": 9e+99 138 | } 139 | } 140 | } 141 | } 142 | ``` 143 | 144 | ### Members: 145 | 146 | - `title`: A human readable name for the parameter. 147 | 148 | - `description`: Describe the parameter. 149 | 150 | - `notes`: (*optional*) Additional advice or information. 151 | 152 | - `type`: Data type of the parameter. Allowed types are `int`, `float`, `bool`, `str` and `date` (YYYY-MM-DD). 153 | 154 | - `number_dims`: (*optional, default is 0*) Number of dimensions for the value, as defined by [`np.ndim`][1]. 155 | 156 | - `value`: Value of the parameter and optionally, the corresponding labels. 
It can be written in two ways: 157 | 158 | - if labels are used: `{"value": [{"value": "my value", **labels}]}` 159 | 160 | - if labels are not used: `{"value": "my value"}` 161 | 162 | - `validators`: Key-value pairs of the validator objects (*the ranges are inclusive*): 163 | - `level`: All validators take a `level` argument which is either "error" or "warn". By default it is set to "error". 164 | - `when`: 165 | - `is` is set to `equal_to` by default but can also be `greater_than` or `less_than`. 166 | - e.g: `"is": {"greater_than": 0}` 167 | 168 | - If the sub-validators refer to the value of another parameter, then the other parameter 169 | must have `number_dims` equal to 0, i.e. be a scalar value and not an array value. 170 | 171 | ```json 172 | { 173 | "validators": { 174 | "range": {"min": "min value", "max": "max value", "level": "warn"}, 175 | "choice": {"choices": ["list", "of", "allowed", "values"]}, 176 | "date_range": {"min": "2018-01-01", "max": "2018-06-01"}, 177 | "when": { 178 | "param": "other parameter", 179 | "is": "equal_value", 180 | "then": { 181 | "range": { 182 | "min": "min value if other parameter is 'equal_value'" 183 | } 184 | }, 185 | "otherwise": { 186 | "range": { 187 | "min": "min value if other parameter is not 'equal_value'" 188 | } 189 | } 190 | } 191 | } 192 | } 193 | ``` 194 | 195 | [1]: https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.ndim.html 196 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: paramtools-dev 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - "marshmallow>=4.0.0" 6 | - "numpy>=2.1.0" 7 | - "python-dateutil>=2.8.0" 8 | - "pytest>=6.0.0" 9 | - pandas 10 | - fsspec 11 | - pyopenssl 12 | - s3fs # required for tests but not most usage. 13 | - gcsfs # required for tests but not most usage. 14 | - requests # required for tests but not most usage. 
15 | - aiohttp # required for tests but not most usage. 16 | - pip 17 | - pip: 18 | - pre-commit 19 | - black 20 | - flake8 21 | - jupyter-book 22 | - sphinx 23 | - autodoc 24 | - ghp-import 25 | - sortedcontainers 26 | -------------------------------------------------------------------------------- /paramtools/__init__.py: -------------------------------------------------------------------------------- 1 | from paramtools.schema_factory import SchemaFactory 2 | from paramtools.exceptions import ( 3 | ParamToolsError, 4 | ParameterUpdateException, 5 | SparseValueObjectsException, 6 | ValidationError, 7 | InconsistentLabelsException, 8 | collision_list, 9 | ParameterNameCollisionException, 10 | UnknownTypeException, 11 | ) 12 | from paramtools.parameters import Parameters 13 | from paramtools.schema import ( 14 | RangeSchema, 15 | ChoiceSchema, 16 | ValueValidatorSchema, 17 | BaseParamSchema, 18 | EmptySchema, 19 | BaseValidatorSchema, 20 | ALLOWED_TYPES, 21 | FIELD_MAP, 22 | VALIDATOR_MAP, 23 | get_type, 24 | get_param_schema, 25 | register_custom_type, 26 | PartialField, 27 | ) 28 | from paramtools.select import ( 29 | select, 30 | select_eq, 31 | select_ne, 32 | select_gt, 33 | select_gte, 34 | select_gt_ix, 35 | select_lt, 36 | select_lte, 37 | ) 38 | from paramtools.sorted_key_list import SortedKeyList, SortedKeyListResult 39 | from paramtools.typing import ValueObject 40 | from paramtools.utils import ( 41 | read_json, 42 | get_example_paths, 43 | LeafGetter, 44 | get_leaves, 45 | ravel, 46 | consistent_labels, 47 | ensure_value_object, 48 | hashable_value_object, 49 | filter_labels, 50 | make_label_str, 51 | ) 52 | from paramtools.values import Values, Slice, QueryResult 53 | 54 | 55 | name = "paramtools" 56 | __version__ = "0.20.0" 57 | 58 | __all__ = [ 59 | "SchemaFactory", 60 | "ParamToolsError", 61 | "ParameterUpdateException", 62 | "SparseValueObjectsException", 63 | "ValidationError", 64 | "InconsistentLabelsException", 65 | "collision_list", 66 | 
"ParameterNameCollisionException", 67 | "UnknownTypeException", 68 | "Parameters", 69 | "RangeSchema", 70 | "ChoiceSchema", 71 | "ValueValidatorSchema", 72 | "BaseParamSchema", 73 | "EmptySchema", 74 | "BaseValidatorSchema", 75 | "ALLOWED_TYPES", 76 | "FIELD_MAP", 77 | "VALIDATOR_MAP", 78 | "get_type", 79 | "get_param_schema", 80 | "register_custom_type", 81 | "PartialField", 82 | "select", 83 | "select_eq", 84 | "select_gt", 85 | "select_gte", 86 | "select_gt_ix", 87 | "select_lt", 88 | "select_lte", 89 | "select_ne", 90 | "read_json", 91 | "get_example_paths", 92 | "get_defaults", 93 | "LeafGetter", 94 | "get_leaves", 95 | "ravel", 96 | "consistent_labels", 97 | "ensure_value_object", 98 | "hashable_value_object", 99 | "filter_labels", 100 | "make_label_str", 101 | "SortedKeyList", 102 | "SortedKeyListResult", 103 | "ValueObject", 104 | "Values", 105 | "Slice", 106 | "QueryResult", 107 | ] 108 | -------------------------------------------------------------------------------- /paramtools/contrib/__init__.py: -------------------------------------------------------------------------------- 1 | from paramtools.contrib.validate import Range, DateRange, OneOf, When 2 | from paramtools.contrib.fields import ( 3 | Float64, 4 | Int64, 5 | Bool_, 6 | MeshFieldMixin, 7 | Str, 8 | Integer, 9 | Float, 10 | Boolean, 11 | Date, 12 | ) 13 | 14 | 15 | __all__ = [ 16 | "Range", 17 | "DateRange", 18 | "OneOf", 19 | "When", 20 | "Float64", 21 | "Int64", 22 | "Bool_", 23 | "MeshFieldMixin", 24 | "Str", 25 | "Integer", 26 | "Float", 27 | "Boolean", 28 | "Date", 29 | ] 30 | -------------------------------------------------------------------------------- /paramtools/contrib/fields.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import datetime 3 | import json 4 | 5 | import marshmallow as ma 6 | 7 | 8 | default_cmps = { 9 | "key": lambda x: x, 10 | "gt": lambda x, y: x > y, 11 | "gte": lambda x, y: x >= y, 12 | "lt": lambda 
x, y: x < y, 13 | "lte": lambda x, y: x <= y, 14 | "ne": lambda x, y: x != y, 15 | "eq": lambda x, y: x == y, 16 | } 17 | 18 | 19 | class NumPySerializeMixin: 20 | def _serialize(self, value, attr, obj, **kwargs): 21 | if hasattr(value, "tolist"): 22 | return value.tolist() 23 | else: 24 | return value 25 | 26 | def _validated(self, value): 27 | value = super()._validated(value) 28 | if value.shape != tuple(): 29 | raise self.make_error("invalid", input=value) 30 | return value 31 | 32 | def cmp_funcs(self, **kwargs): 33 | if not self.validators: 34 | return default_cmps 35 | assert len(self.validators) == 1 36 | cmp_funcs = self.validators[0].cmp_funcs(**kwargs) 37 | if cmp_funcs is None: 38 | return default_cmps 39 | else: 40 | return cmp_funcs 41 | 42 | 43 | class Float64(NumPySerializeMixin, ma.fields.Number): 44 | """ 45 | Implements "float" :ref:`spec:Type property` for parameter values. 46 | Defined as 47 | `numpy.float64 `__ type 48 | """ 49 | 50 | num_type = np_type = np.float64 51 | 52 | 53 | class Int64(NumPySerializeMixin, ma.fields.Integer): 54 | """ 55 | Implements "int" :ref:`spec:Type property` for parameter values. 56 | Defined as `numpy.int64 `__ type 57 | """ 58 | 59 | num_type = np_type = np.int64 60 | 61 | 62 | class Bool_(NumPySerializeMixin, ma.fields.Boolean): 63 | """ 64 | Implements "bool" :ref:`spec:Type property` for parameter values. 
65 | Defined as `numpy.bool_ `__ type 66 | """ 67 | 68 | num_type = np_type = np.bool_ 69 | 70 | def _deserialize(self, value, attr, obj, **kwargs): 71 | return np.bool_(super()._deserialize(value, attr, obj, **kwargs)) 72 | 73 | 74 | class MeshFieldMixin: 75 | """ 76 | Provides method for accessing ``contrib.validate`` 77 | validators' grid methods 78 | """ 79 | 80 | def grid(self): 81 | if not self.validators: 82 | return [] 83 | assert len(self.validators) == 1 84 | return self.validators[0].grid() 85 | 86 | def cmp_funcs(self, **kwargs): 87 | if not self.validators: 88 | return default_cmps 89 | assert len(self.validators) == 1 90 | cmp_funcs = self.validators[0].cmp_funcs(**kwargs) 91 | if cmp_funcs is None: 92 | return default_cmps 93 | else: 94 | return cmp_funcs 95 | 96 | 97 | class Str(MeshFieldMixin, ma.fields.Str): 98 | """ 99 | Implements "str" :ref:`spec:Type property`. 100 | """ 101 | 102 | np_type = object 103 | 104 | 105 | class Integer(MeshFieldMixin, ma.fields.Integer): 106 | """ 107 | Implements "int" :ref:`spec:Type property` for properties 108 | except for parameter values. 109 | """ 110 | 111 | np_type = int 112 | 113 | 114 | class Float(MeshFieldMixin, ma.fields.Float): 115 | """ 116 | Implements "float" :ref:`spec:Type property` for properties 117 | except for parameter values. 118 | """ 119 | 120 | np_type = float 121 | 122 | 123 | class Boolean(MeshFieldMixin, ma.fields.Boolean): 124 | """ 125 | Implements "bool" :ref:`spec:Type property` for properties 126 | except for parameter values. 127 | """ 128 | 129 | np_type = bool 130 | 131 | 132 | class Date(MeshFieldMixin, ma.fields.Date): 133 | """ 134 | Implements "date" :ref:`spec:Type property`. 
135 | """ 136 | 137 | np_type = datetime.date 138 | default_error_messages = { 139 | "invalid": "Not a valid {obj_type}: {input}", 140 | "format": '"{input}" cannot be formatted as a {obj_type}.', 141 | } 142 | 143 | def _deserialize(self, value, attr=None, data=None, **kwargs): 144 | if isinstance(value, (datetime.datetime, datetime.date)): 145 | return value 146 | return super()._deserialize(value, attr, data, **kwargs) 147 | 148 | 149 | class Nested(ma.fields.Nested): 150 | def json_cmp_func(self, data): 151 | try: 152 | return json.dumps(self._serialize(data, None, None)) 153 | except ma.ValidationError as ve: 154 | try: 155 | return json.dumps(data) 156 | except json.JSONDecodeError: 157 | raise ve 158 | 159 | def cmp_funcs(self, **kwargs): 160 | cmp_funcs = getattr(self.nested, "cmp_funcs", None) 161 | if cmp_funcs is None: 162 | # This is not a good comparison function but it's the 163 | # best we can do if the cmp_funcs method is not 164 | # defined. 165 | return {"key": self.json_cmp_func} 166 | else: 167 | return cmp_funcs(**kwargs) 168 | -------------------------------------------------------------------------------- /paramtools/contrib/validate.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import datetime 3 | import itertools 4 | 5 | from dateutil.relativedelta import relativedelta 6 | import numpy as np 7 | import marshmallow as ma 8 | 9 | from paramtools.typing import ValueObject 10 | from paramtools import utils 11 | 12 | 13 | class ValidationError(ma.ValidationError): 14 | def __init__(self, *args, level=None, **kwargs): 15 | self.level = level or "error" 16 | super().__init__(*args, **kwargs) 17 | 18 | 19 | class When(ma.validate.Validator): 20 | then_message = "When value is {is_val}, the input is invalid: {submsg}" 21 | otherwise_message = ( 22 | "When value is {is_val}, the input is invalid: {submsg}" 23 | ) 24 | shape_mismatch = "Shape mismatch between parameters: {shape1} 
{shape2}" 25 | 26 | is_value_evaluators = { 27 | "equal_to": lambda when_value, is_val: when_value == is_val, 28 | "less_than": lambda when_value, is_val: when_value < is_val, 29 | "greater_than": lambda when_value, is_val: when_value > is_val, 30 | } 31 | 32 | def __init__( 33 | self, 34 | is_object, 35 | when_vos: List[ValueObject], 36 | then_validators: List[ma.validate.Validator], 37 | otherwise_validators: List[ma.validate.Validator], 38 | then_message: str = None, 39 | otherwise_message: str = None, 40 | level: str = "error", 41 | type: str = "int", 42 | number_dims: int = 0, 43 | ): 44 | self.is_operator = next(iter(is_object)) 45 | self.is_val = is_object[self.is_operator] 46 | self.when_vos = when_vos 47 | self.then_validators = then_validators 48 | self.otherwise_validators = otherwise_validators 49 | self.then_message = then_message or self.then_message 50 | self.otherwise_message = otherwise_message or self.otherwise_message 51 | self.level = level 52 | self.type = type 53 | self.number_dims = number_dims 54 | 55 | def __call__(self, value, is_value_object=False): 56 | if value is None: 57 | return value 58 | if not is_value_object: 59 | value = {"value": value} 60 | 61 | msgs = [] 62 | arr = np.array(value["value"]) 63 | for when_vo in self.when_vos: 64 | if not isinstance(when_vo["value"], list): 65 | msgs += self.apply_validator( 66 | value, when_vo["value"], value, when_vo, ix=None 67 | ) 68 | continue 69 | 70 | when_arr = np.array(when_vo["value"]) 71 | if when_arr.shape != arr.shape: 72 | raise ValidationError( 73 | self.shape_mismatch.format( 74 | shape1=when_arr.shape, shape2=arr.shape 75 | ) 76 | ) 77 | for ix in itertools.product(*(map(range, arr.shape))): 78 | msgs += self.apply_validator( 79 | {"value": arr[ix]}, when_arr[ix], value, when_vo, ix 80 | ) 81 | 82 | if msgs: 83 | raise ValidationError( 84 | msgs if len(msgs) > 1 else msgs[0], level=self.level 85 | ) 86 | 87 | def evaluate_is_value(self, when_value): 88 | return 
class Range(ma.validate.Range):
    """
    Implements "range" :ref:`spec:Validator object`.

    Bounds are stored as lists of value objects (dicts carrying a
    ``"value"`` key) so that one value can be checked against several
    bounds at once.
    """

    error = ""
    message_min = "Input {input} must be {min_op} {min}."
    message_max = "Input {input} must be {max_op} {max}."

    def __init__(
        self,
        min=None,
        max=None,
        min_vo=None,
        max_vo=None,
        error_min=None,
        error_max=None,
        step=None,
        level=None,
    ):
        """
        Store the bounds and message templates.

        A scalar ``min``/``max`` takes precedence and is wrapped into a
        single value object; otherwise ``min_vo``/``max_vo`` is stored as
        given (possibly ``None``, meaning "no bound on that side").
        """
        self.min = [{"value": min}] if min is not None else min_vo
        self.max = [{"value": max}] if max is not None else max_vo

        self.error_min, self.error_max = error_min, error_max
        # A missing (or falsy) step falls back to 1.
        self.step = step or 1

        self.min_inclusive = self.max_inclusive = None
        self.level = level or "error"

    def __call__(self, value, is_value_object=False):
        """
        Entry point that marshmallow calls by default.

        A bare value is wrapped into a value object first; when
        ``is_value_object`` is true the argument is validated as-is.
        ``None`` is returned untouched.
        """
        if value is None:
            return value
        wrapped = value if is_value_object else {"value": value}
        return self.validate_value_objects(wrapped)

    def validate_value_objects(self, value):
        """
        Check ``value["value"]`` against every configured bound.

        Returns the value object when all checks pass (``None`` when the
        wrapped value is ``None``) and raises ``ValidationError`` carrying
        one message, or a list of messages when several bounds fail.
        """
        raw = value["value"]
        if raw is None:
            return None

        # (bounds, "is outside" test, message template, format key, phrase)
        rules = (
            (
                self.min,
                lambda arr, limit: arr < limit,
                self.error_min or self.message_min,
                "min",
                "greater than",
            ),
            (
                self.max,
                lambda arr, limit: arr > limit,
                self.error_max or self.message_max,
                "max",
                "less than",
            ),
        )

        problems = []
        for bounds, outside, template, side, phrase in rules:
            if bounds is None:
                continue
            for bound in bounds:
                limit = bound["value"]
                if np.any(outside(np.array(raw), limit)):
                    problems.append(
                        template.format(
                            input=raw,
                            labels=utils.make_label_str(value),
                            oth_labels=utils.make_label_str(bound),
                            **{side: limit, f"{side}_op": phrase},
                        )
                    )

        if not problems:
            return value
        payload = problems[0] if len(problems) == 1 else problems
        raise ValidationError(payload, level=self.level)

    def grid(self):
        """
        Return the values from the first min bound to the first max bound,
        inclusive, spaced by ``self.step``, as a plain list.
        """
        upper = self.max[0]["value"]
        lower = self.min[0]["value"]
        # np.arange excludes its stop value, so overshoot by one step and
        # then trim anything beyond the real upper bound.
        candidates = np.arange(lower, upper + self.step, self.step)
        return candidates[candidates <= upper].tolist()

    def cmp_funcs(self, **kwargs):
        """Range supplies no custom comparison functions."""
        return None
295 | step = {"days": 1} 296 | 297 | if not isinstance(step, dict): 298 | raise ValidationError(self.step_msg) 299 | 300 | has_extra_keys = len(set(step.keys()) - self.timedelta_args) 301 | if has_extra_keys: 302 | raise ValidationError(self.step_msg) 303 | 304 | self.step = relativedelta(**step) 305 | 306 | self.level = level or "error" 307 | 308 | def safe_deserialize(self, date): 309 | if isinstance(date, datetime.date): 310 | return date 311 | else: 312 | return ma.fields.Date()._deserialize(date, None, None) 313 | 314 | def grid(self): 315 | # make np.arange inclusive. 316 | max_ = self.max[0]["value"] + self.step 317 | 318 | current = self.min[0]["value"] 319 | result = [] 320 | while current < max_: 321 | result.append(current) 322 | current += self.step 323 | 324 | return result 325 | 326 | 327 | class OneOf(ma.validate.OneOf): 328 | """ 329 | Implements "choice" :ref:`spec:Validator object`. 330 | """ 331 | 332 | default_message = "Input {input} must be one of {choices}." 333 | 334 | def __init__(self, *args, level=None, **kwargs): 335 | self.level = level or "error" 336 | super().__init__(*args, **kwargs) 337 | 338 | def __call__(self, value, is_value_object=False): 339 | if value is None: 340 | return value 341 | if not is_value_object: 342 | vo = {"value": value} 343 | else: 344 | vo = value 345 | if vo["value"] is None: 346 | return None 347 | if not isinstance(vo["value"], list): 348 | vos = {"value": [vo["value"]]} 349 | else: 350 | vos = {"value": utils.ravel(vo["value"])} 351 | for vo in vos["value"]: 352 | try: 353 | if vo not in self.choices: 354 | raise ValidationError( 355 | self._format_error(vo), level=self.level 356 | ) 357 | except TypeError: 358 | raise ValidationError(self._format_error(vo), level=self.level) 359 | return value 360 | 361 | def grid(self): 362 | return self.choices 363 | 364 | def cmp_funcs(self, choices=None, **kwargs): 365 | if choices is None: 366 | choices = self.choices 367 | return { 368 | "key": lambda x: 
choices.index(x), 369 | "gt": lambda x, y: choices.index(x) > choices.index(y), 370 | "gte": lambda x, y: choices.index(x) >= choices.index(y), 371 | "lt": lambda x, y: choices.index(x) < choices.index(y), 372 | "lte": lambda x, y: choices.index(x) <= choices.index(y), 373 | "ne": lambda x, y: x != y, 374 | "eq": lambda x, y: x == y, 375 | } 376 | -------------------------------------------------------------------------------- /paramtools/examples/baseball/defaults.json: -------------------------------------------------------------------------------- 1 | { 2 | "schema": { 3 | "labels": { 4 | "use_2018": {"type": "bool", "validators": {}} 5 | }, 6 | "additional_members": { 7 | "section_1": {"type": "str"}, 8 | "section_2": {"type": "str"} 9 | } 10 | }, 11 | "pitcher": { 12 | "title": "Pitcher Name", 13 | "description": "Name of pitcher to pull data on", 14 | "section_1": "Parameters", 15 | "section_2": "Pitcher", 16 | "notes": "Make sure the name of the pitcher is correct. A good place to reference this is baseball-reference.com", 17 | "type": "str", 18 | "value": [{"value": "Clayton Kershaw"}], 19 | "validators": {"choice": {"choices": ["Clayton Kershaw", "Julio Teheran", "Max Scherzer"]}} 20 | }, 21 | "batter": { 22 | "title": "Batter Name", 23 | "description": "Name of batter for pitching matchup analysis", 24 | "section_1": "Parameters", 25 | "section_2": "Batter", 26 | "notes": "Make sure the name of the batter is correct. 
A good place to reference this is baseball-reference.com", 27 | "type": "str", 28 | "value": [{"value": "Freddie Freeman"}], 29 | "validators": {"choice": {"choices": ["Freddie Freeman", "Bryce Harper", "Mookie Betts"]}} 30 | }, 31 | "start_date": { 32 | "title": "Start Date", 33 | "description": "Date to start pulling statcast information", 34 | "section_1": "Parameters", 35 | "section_2": "Date", 36 | "notes": "If using the 2018 dataset, only use dates in 2018.", 37 | "type": "date", 38 | "value": [{"value": "2018-01-01"}], 39 | "validators": {"date_range": {"min": "2008-01-01", "max": "end_date"}} 40 | }, 41 | "end_date": { 42 | "title": "End Date", 43 | "description": "Date to quit pulling statcast information", 44 | "section_1": "Parameters", 45 | "section_2": "Date", 46 | "notes": "If using the 2018 dataset, only use dates in 2018.", 47 | "type": "date", 48 | "value": [{"value": "2018-11-10"}], 49 | "validators": {"date_range": {"min": "2008-01-01", "max": "2018-11-10"}} 50 | } 51 | } -------------------------------------------------------------------------------- /paramtools/examples/behresp/defaults.json: -------------------------------------------------------------------------------- 1 | { 2 | "schema": { 3 | "labels": {}, 4 | "additional_members": {} 5 | }, 6 | "BE_sub": { 7 | "title": "Substitution elasticity of taxable income", 8 | "description": "Defined as proportional change in taxable income divided by proportional change in marginal net-of-tax rate (1-MTR) on taxpayer earnings caused by the reform. Must be zero or positive.", 9 | "notes": "", 10 | "type": "float", 11 | "value": [{"value": 0.0}], 12 | "validators": { 13 | "range": {"min": 0.0, "max": 9e99} 14 | } 15 | }, 16 | "BE_inc": { 17 | "title": "Income elasticity of taxable income", 18 | "description": "Defined as dollar change in taxable income divided by dollar change in after-tax income caused by the reform. 
Must be zero or negative.", 19 | "notes": "", 20 | "type": "float", 21 | "value": [{"value": 0.0}], 22 | "validators": { 23 | "range": {"min": -9e99, "max": 0} 24 | } 25 | }, 26 | "BE_cg": { 27 | "title": "Semi-elasticity of long-term capital gains", 28 | "description": "Defined as change in logarithm of long-term capital gains divided by change in marginal tax rate (MTR) on long-term capital gains caused by the reform. Must be zero or negative. Read response function documentation (see below) for discussion of appropriate values.", 29 | "notes": "", 30 | "type": "float", 31 | "value": [{"value": 0.0}], 32 | "validators": { 33 | "range": {"min": -9e99, "max": 0} 34 | } 35 | } 36 | } -------------------------------------------------------------------------------- /paramtools/examples/taxparams-demo/defaults.json: -------------------------------------------------------------------------------- 1 | { 2 | "schema": { 3 | "labels": { 4 | "year": { 5 | "type": "int", 6 | "validators": {"range": {"min": 2013, "max": 2027}} 7 | }, 8 | "marital_status": { 9 | "type": "str", 10 | "validators": {"choice": {"choices": ["single", "joint", "separate", 11 | "headhousehold", "widow"]}} 12 | }, 13 | "idedtype": { 14 | "type": "str", 15 | "validators": {"choice": {"choices": ["medical", "statelocal", 16 | "realestate", "casualty", 17 | "misc", "interest", "charity"]}} 18 | }, 19 | "EIC": { 20 | "type": "str", 21 | "validators": {"choice": {"choices": ["0kids", "1kid", 22 | "2kids", "3+kids"]}} 23 | } 24 | }, 25 | "additional_members": { 26 | "section_1": {"type": "str", "number_dims": 0}, 27 | "section_2": {"type": "str", "number_dims": 0}, 28 | "section_3": {"type": "str", "number_dims": 0}, 29 | "irs_ref": {"type": "str", "number_dims": 0}, 30 | "start_year": {"type": "int", "number_dims": 0}, 31 | "cpi_inflatable": {"type": "bool", "number_dims": 0}, 32 | "cpi_inflated": {"type": "bool", "number_dims": 0} 33 | } 34 | }, 35 | "standard_deduction": { 36 | "title": "Standard 
deduction amount", 37 | "description": "Amount filing unit can use as a standard deduction.", 38 | "section_1": "Standard Deduction", 39 | "section_2": "Standard Deduction Amount", 40 | "irs_ref": "Form 1040, line 8, instructions. ", 41 | "notes": "", 42 | "start_year": 2024, 43 | "cpi_inflatable": true, 44 | "cpi_inflated": true, 45 | "type": "float", 46 | "value": [ 47 | {"year": 2024, "marital_status": "single", "value": 13673.68}, 48 | {"year": 2024, "marital_status": "joint", "value": 27347.36}, 49 | {"year": 2024, "marital_status": "separate", "value": 13673.68}, 50 | {"year": 2024, "marital_status": "headhousehold", "value": 20510.52}, 51 | {"year": 2024, "marital_status": "widow", "value": 27347.36}, 52 | {"year": 2025, "marital_status": "single", "value": 13967.66}, 53 | {"year": 2025, "marital_status": "joint", "value": 27935.33}, 54 | {"year": 2025, "marital_status": "separate", "value": 13967.66}, 55 | {"year": 2025, "marital_status": "headhousehold", "value": 20951.49}, 56 | {"year": 2025, "marital_status": "widow", "value": 27935.33}, 57 | {"year": 2026, "marital_status": "single", "value": 7690.0}, 58 | {"year": 2026, "marital_status": "joint", "value": 15380.0}, 59 | {"year": 2026, "marital_status": "separate", "value": 7690.0}, 60 | {"year": 2026, "marital_status": "headhousehold", "value": 11323.0}, 61 | {"year": 2026, "marital_status": "widow", "value": 15380.0}], 62 | "validators": { 63 | "range": { 64 | "min": 0, 65 | "max": 9e+99 66 | } 67 | } 68 | }, 69 | "social_security_tax_rate": { 70 | "description": "Social Security FICA rate, including both employer and employee.", 71 | "section_1": "Payroll Taxes", 72 | "section_2": "Social Security FICA", 73 | "irs_ref": "", 74 | "notes": "", 75 | "start_year": 2026, 76 | "cpi_inflatable": false, 77 | "cpi_inflated": false, 78 | "value": [ 79 | {"year": 2024, "value": 0.124}, 80 | {"year": 2025, "value": 0.124}, 81 | {"year": 2026, "value": 0.124} 82 | ], 83 | "title": "Social Security payroll tax 
rate", 84 | "type": "float", 85 | "validators": { 86 | "range": { 87 | "min": 0, 88 | "max": 1 89 | } 90 | } 91 | }, 92 | "ii_bracket_1": { 93 | "title": "Personal income (regular/non-AMT/non-pass-through) tax bracket (upper threshold) 1", 94 | "description": "Taxable income below this threshold is taxed at tax rate 1.", 95 | "section_1": "Personal Income", 96 | "section_2": "Regular: Non-AMT, Non-Pass-Through", 97 | "irs_ref": "Form 1040, line 44, instruction (Schedule XYZ).", 98 | "notes": "", 99 | "start_year": 2013, 100 | "cpi_inflatable": true, 101 | "cpi_inflated": true, 102 | "type": "float", 103 | "value": [ 104 | {"year": 2024, "marital_status": "single", "value": 10853.48}, 105 | {"year": 2024, "marital_status": "joint", "value": 21706.97}, 106 | {"year": 2024, "marital_status": "separate", "value": 10853.48}, 107 | {"year": 2024, "marital_status": "headhousehold", "value": 15496.84}, 108 | {"year": 2024, "marital_status": "widow", "value": 21706.97}, 109 | {"year": 2025, "marital_status": "single", "value": 11086.83}, 110 | {"year": 2025, "marital_status": "joint", "value": 22173.66}, 111 | {"year": 2025, "marital_status": "separate", "value": 11086.83}, 112 | {"year": 2025, "marital_status": "headhousehold", "value": 15830.02}, 113 | {"year": 2025, "marital_status": "widow", "value": 22173.66}, 114 | {"year": 2026, "marital_status": "single", "value": 11293.0}, 115 | {"year": 2026, "marital_status": "joint", "value": 22585.0}, 116 | {"year": 2026, "marital_status": "separate", "value": 11293.0}, 117 | {"year": 2026, "marital_status": "headhousehold", "value": 16167.0}, 118 | {"year": 2026, "marital_status": "widow", "value": 22585.0}], 119 | "validators": { 120 | "range": { 121 | "min": 0, 122 | "max": "ii_bracket_2" 123 | } 124 | } 125 | }, 126 | "ii_bracket_2": { 127 | "title": "Personal income (regular/non-AMT/non-pass-through) tax bracket (upper threshold) 2", 128 | "description": "Income below this threshold and above tax bracket 1 is taxed at tax 
rate 2.", 129 | "section_1": "Personal Income", 130 | "section_2": "Regular: Non-AMT, Non-Pass-Through", 131 | "irs_ref": "Form 1040, line 11, instruction (Schedule XYZ).", 132 | "notes": "", 133 | "start_year": 2013, 134 | "cpi_inflatable": true, 135 | "cpi_inflated": true, 136 | "type": "float", 137 | "value": [ 138 | {"year": 2024, "marital_status": "single", "value": 44097.61}, 139 | {"year": 2024, "marital_status": "joint", "value": 88195.23}, 140 | {"year": 2024, "marital_status": "separate", "value": 44097.61}, 141 | {"year": 2024, "marital_status": "headhousehold", "value": 59024.71}, 142 | {"year": 2024, "marital_status": "widow", "value": 88195.23}, 143 | {"year": 2025, "marital_status": "single", "value": 45045.71}, 144 | {"year": 2025, "marital_status": "joint", "value": 90091.43}, 145 | {"year": 2025, "marital_status": "separate", "value": 45045.71}, 146 | {"year": 2025, "marital_status": "headhousehold", "value": 60293.74}, 147 | {"year": 2025, "marital_status": "widow", "value": 90091.43}, 148 | {"year": 2026, "marital_status": "single", "value": 45957.0}, 149 | {"year": 2026, "marital_status": "joint", "value": 91915.0}, 150 | {"year": 2026, "marital_status": "separate", "value": 45957.0}, 151 | {"year": 2026, "marital_status": "headhousehold", "value": 61519.0}, 152 | {"year": 2026, "marital_status": "widow", "value": 91915.0}], 153 | "validators": { 154 | "range": { 155 | "min": "ii_bracket_1", 156 | "max": 9e+99 157 | } 158 | } 159 | } 160 | } 161 | -------------------------------------------------------------------------------- /paramtools/exceptions.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | 4 | from paramtools import utils 5 | 6 | 7 | class ParamToolsError(Exception): 8 | pass 9 | 10 | 11 | class ParameterUpdateException(ParamToolsError): 12 | pass 13 | 14 | 15 | class SparseValueObjectsException(ParamToolsError): 16 | pass 17 | 18 | 19 | 
class ValidationError(ParamToolsError):
    """
    Raised with the collected validation messages.

    ``messages`` is a two-level mapping (error type -> parameter name ->
    message or nested messages). The exception text is that mapping, with
    each parameter's messages flattened by ``utils.ravel``, rendered as
    indented JSON. The original ``messages`` and ``labels`` arguments are
    kept on the instance unchanged.
    """

    def __init__(self, messages, labels):
        self.messages = messages
        self.labels = labels

        summary = {}
        for kind, by_param in self.messages.items():
            for param, msg in by_param.items():
                # Only create the per-kind entry once a parameter actually
                # has a message, so empty kinds stay out of the output.
                summary.setdefault(kind, {})[param] = utils.ravel(msg)

        super().__init__(json.dumps(summary, indent=4))
class SchemaFactory:
    """
    Builds marshmallow schema classes from a defaults specification.

    The ``"schema"`` entry of ``defaults`` is loaded with
    ``ParamToolsSchema`` and turned into a base parameter schema plus
    label validators via ``get_param_schema``; every other entry is
    treated as a parameter definition. Call ``schemas`` to build the
    per-parameter schemas and load the defaults.
    """

    def __init__(self, defaults):
        # Everything except the "schema" section is a parameter definition.
        self.defaults = {
            name: spec for name, spec in defaults.items() if name != "schema"
        }
        schema_section = defaults.get("schema", {})
        self.schema = ParamToolsSchema().load(schema_section)
        base_schema, label_validators = get_param_schema(self.schema)
        self.BaseParamSchema = base_schema
        self.label_validators = label_validators

    def schemas(self):
        """
        Build the schemas for all parameters and load the defaults.

        For each parameter two classes are generated:

        - a ``ValidatorItem`` schema describing one value object (its
          ``value`` field, the load-only ``_auto`` flag and the label
          fields), and
        - an ``IndividualParamSchema`` extending the base parameter schema
          with a ``value`` field holding many such value objects.

        These are then combined into a ``DefaultsSchema`` (one nested
        field per parameter) and a ``ValidatorSchema`` (one list of value
        objects per parameter, extending ``BaseValidatorSchema``).

        Returns a 4-tuple: the ``DefaultsSchema`` instance, the
        ``ValidatorSchema`` instance, the loaded schema section, and the
        defaults as loaded by the ``DefaultsSchema`` instance.
        """
        item_schemas = {}
        param_schemas = {}
        for name, spec in self.defaults.items():
            item_attrs = {
                "value": get_type(spec),
                "_auto": fields.Boolean(required=False, load_only=True),
                **self.label_validators,
            }

            # TODO: what about case where number_dims > 0
            # if not isinstance(v["value"], list):
            #     v["value"] = [{"value": v["value"]}]

            item_schema = type("ValidatorItem", (Schema,), item_attrs)
            item_schemas[name] = item_schema
            param_schemas[name] = type(
                "IndividualParamSchema",
                (self.BaseParamSchema,),
                {"value": ValueObject(item_schema, many=True)},
            )

        DefaultsSchema = type(
            "DefaultsSchema",
            (Schema,),
            {
                name: fields.Nested(param_schema)
                for name, param_schema in param_schemas.items()
            },
        )
        defaults_schema = DefaultsSchema()

        ValidatorSchema = type(
            "ValidatorSchema",
            (BaseValidatorSchema,),
            {
                name: ValueObject(item_schema, many=True)
                for name, item_schema in item_schemas.items()
            },
        )
        validator_schema = ValidatorSchema()

        loaded_defaults = defaults_schema.load(self.defaults)
        return (
            defaults_schema,
            validator_schema,
            self.schema,
            loaded_defaults,
        )
def eq_func(x: Any, y: Iterable) -> bool:
    """Return True when ``x`` is a member of ``y``."""
    found = x in y
    return found


def ne_func(x: Any, y: Iterable) -> bool:
    """Return True when ``x`` is not a member of ``y``."""
    return not (x in y)


def gt_func(x: Any, y: Iterable) -> bool:
    """Return True when ``x`` is greater than every item of ``y``."""
    for item in y:
        if not x > item:
            return False
    return True


def gte_func(x: Any, y: Iterable) -> bool:
    """Return True when ``x`` is greater than or equal to every item of ``y``."""
    for item in y:
        if not x >= item:
            return False
    return True


def gt_ix_func(cmp_list: list, x: Any, y: Iterable) -> bool:
    """
    Return True when ``x`` sits later in ``cmp_list`` than every item of
    ``y``; ordering is by position in ``cmp_list`` rather than by value.
    """
    # Looked up before iterating so an unknown ``x`` fails even if ``y``
    # is empty.
    position = cmp_list.index(x)
    for item in y:
        if not position > cmp_list.index(item):
            return False
    return True


def lt_func(x, y) -> bool:
    """Return True when ``x`` is less than every item of ``y``."""
    for item in y:
        if not x < item:
            return False
    return True


def lte_func(x, y) -> bool:
    """Return True when ``x`` is less than or equal to every item of ``y``."""
    for item in y:
        if not x <= item:
            return False
    return True
def select_gte(
    value_objects: List[ValueObject], strict: bool, labels: dict, tree=None
) -> List[ValueObject]:
    """Run ``select`` with ``op="gte"`` and the ``gte_func`` comparator."""
    return select(
        value_objects=value_objects,
        strict=strict,
        cmp_func=gte_func,
        labels=labels,
        tree=tree,
        op="gte",
    )


def select_gt_ix(
    value_objects: List[ValueObject],
    strict: bool,
    labels: dict,
    cmp_list: List,
    tree=None,
) -> List[ValueObject]:
    """No longer supported: always raises, pointing callers to Values."""
    message = "select_gt_ix is deprecated. Use Values instead."
    raise Exception(message)


def select_lt(
    value_objects: List[ValueObject], strict: bool, labels: dict, tree=None
) -> List[ValueObject]:
    """Run ``select`` with ``op="lt"`` and the ``lt_func`` comparator."""
    return select(
        value_objects=value_objects,
        strict=strict,
        cmp_func=lt_func,
        labels=labels,
        tree=tree,
        op="lt",
    )


def select_lte(
    value_objects: List[ValueObject], strict: bool, labels: dict, tree=None
) -> List[ValueObject]:
    """Run ``select`` with ``op="lte"`` and the ``lte_func`` comparator."""
    return select(
        value_objects=value_objects,
        strict=strict,
        cmp_func=lte_func,
        labels=labels,
        tree=tree,
        op="lte",
    )
class SortedKeyList:
    """
    Sorted key list built on top of sortedcontainers. This adds some
    query methods like lt/gt/eq and keeps track of the original indices
    for the values.

    Items are stored as ``(value, index)`` tuples ordered by
    ``keyfunc(value)``.
    """

    def __init__(self, values, keyfunc, index=None):
        """
        Build the list from ``values``.

        ``index`` optionally supplies one index per value; when omitted the
        values are numbered ``0..len(values) - 1``.

        Raises:
            SortedKeyListException: if the keys produced by ``keyfunc``
                cannot be ordered (a ``TypeError`` from sortedcontainers).
        """
        if index:
            assert len(values) == len(index)
        if index is None:
            index = range(len(values))
        pairs = [(val, ix) for ix, val in zip(index, values)]
        self.index = set(index)
        self.keyfunc = keyfunc

        try:
            self.sorted_key_list_2 = sortedcontainers.SortedKeyList(
                pairs, key=lambda t: keyfunc(t[0])
            )
        except TypeError as e:
            raise SortedKeyListException() from e

    def __repr__(self):
        # The backing container lives in ``sorted_key_list_2``; the former
        # ``self.sorted_key_list`` lookup raised AttributeError.
        return str(self.sorted_key_list_2)

    def _query(self, **irange_kwargs):
        """
        Run ``irange_key`` on the backing list and wrap the matches, or
        return ``None`` when nothing matches.
        """
        matches = list(self.sorted_key_list_2.irange_key(**irange_kwargs))
        if matches:
            return SortedKeyListResult(matches)
        return None

    def eq(self, value):
        """Items whose key equals the key of ``value``, or ``None``."""
        key = self.keyfunc(value)
        return self._query(min_key=key, max_key=key)

    def ne(self, value):
        """
        Items whose key differs from the key of ``value``.

        Unlike the other queries this always returns a result object, which
        is empty when nothing matches.
        """
        result = []
        for side in (self.lt(value), self.gt(value)):
            if side is not None:
                result += side.key_list_values

        return SortedKeyListResult(result)

    def lt(self, value):
        """Items whose key is strictly below the key of ``value``, or ``None``."""
        key = self.keyfunc(value)
        # inclusive=(min, max): exclude the upper end itself.
        return self._query(max_key=key, inclusive=(True, False))

    def lte(self, value):
        """Items whose key is at or below the key of ``value``, or ``None``."""
        key = self.keyfunc(value)
        return self._query(max_key=key, inclusive=(True, True))

    def gt(self, value):
        """Items whose key is strictly above the key of ``value``, or ``None``."""
        key = self.keyfunc(value)
        # inclusive=(min, max): exclude the lower end itself.
        return self._query(min_key=key, inclusive=(False, True))

    def gte(self, value):
        """Items whose key is at or above the key of ``value``, or ``None``."""
        key = self.keyfunc(value)
        return self._query(min_key=key, inclusive=(True, True))

    def add(self, value, index=None):
        """
        Insert ``value``. When ``index`` is omitted the next index after
        the current maximum is used (``0`` for an empty list).
        """
        if index is None:
            # default=-1 keeps add() working on an empty list, where a
            # bare max() would raise ValueError.
            index = max(self.index, default=-1) + 1
        self.sorted_key_list_2.add((value, index))
        self.index.add(index)
"additional_members": { 35 | "opt0": { 36 | "type": "str" 37 | } 38 | } 39 | }, 40 | "when_param": { 41 | "title": "When validator reference param", 42 | "description": "Example for using 'when' validator", 43 | "type": "int", 44 | "value": 0, 45 | "validators": { 46 | "when": { 47 | "param": "str_choice_param", 48 | "is": "value0", 49 | "then": { 50 | "range": { 51 | "min": 0, 52 | "max": 0 53 | } 54 | }, 55 | "otherwise": { 56 | "choice": { 57 | "choices": [ 58 | 0, 59 | 2, 60 | 5, 61 | 7 62 | ] 63 | } 64 | } 65 | } 66 | } 67 | }, 68 | "when_array_param": { 69 | "title": "When validator reference array param", 70 | "description": "Example for using 'when' validator with an array param", 71 | "type": "int", 72 | "number_dims": 1, 73 | "value": [ 74 | 0, 75 | 1, 76 | 2, 77 | 5 78 | ], 79 | "validators": { 80 | "when": { 81 | "param": "simple_int_list_param", 82 | "is": 2, 83 | "then": { 84 | "choice": { 85 | "choices": [ 86 | 1 87 | ] 88 | } 89 | }, 90 | "otherwise": { 91 | "choice": { 92 | "choices": [ 93 | 0, 94 | 2, 95 | 5 96 | ] 97 | } 98 | } 99 | } 100 | } 101 | }, 102 | "float_param": { 103 | "title": "Float Reference Param", 104 | "description": "Example for a float param.", 105 | "opt0": "an option", 106 | "type": "float", 107 | "value": 2.0, 108 | "validators": {} 109 | }, 110 | "bool_param": { 111 | "title": "Boolean Reference Param", 112 | "description": "Example for a bool param.", 113 | "opt0": "an option", 114 | "type": "bool", 115 | "value": true, 116 | "validators": {} 117 | }, 118 | "min_int_param": { 119 | "title": "min integer parameter", 120 | "description": "Serves as minimum reference variable.", 121 | "notes": "See max_int_param", 122 | "opt0": "an option", 123 | "type": "int", 124 | "value": [ 125 | { 126 | "label0": "zero", 127 | "label1": 1, 128 | "value": 1 129 | }, 130 | { 131 | "label0": "one", 132 | "label1": 2, 133 | "value": 2 134 | } 135 | ], 136 | "validators": { 137 | "range": { 138 | "min": 0, 139 | "max": "max_int_param" 140 | } 
141 | } 142 | }, 143 | "max_int_param": { 144 | "title": "max integer parameter", 145 | "description": "Serves as maximum reference variable.", 146 | "notes": "See min_int_param", 147 | "opt0": "an option", 148 | "type": "int", 149 | "value": [ 150 | { 151 | "label0": "zero", 152 | "label1": 1, 153 | "value": 3 154 | }, 155 | { 156 | "label0": "one", 157 | "label1": 2, 158 | "value": 4 159 | } 160 | ], 161 | "validators": { 162 | "range": { 163 | "min": "min_int_param", 164 | "max": 10 165 | } 166 | } 167 | }, 168 | "str_choice_param": { 169 | "title": "String Choice Param", 170 | "description": "Example for string type params using a choice validator", 171 | "opt0": "another option", 172 | "type": "str", 173 | "value": "value0", 174 | "validators": { 175 | "choice": { 176 | "choices": [ 177 | "value0", 178 | "value1" 179 | ] 180 | } 181 | } 182 | }, 183 | "date_param": { 184 | "title": "Date parameter", 185 | "description": "Example for a date parameter", 186 | "opt0": "another option", 187 | "type": "date", 188 | "value": [ 189 | { 190 | "label0": "zero", 191 | "label1": 1, 192 | "value": "2018-01-15" 193 | } 194 | ], 195 | "validators": { 196 | "date_range": { 197 | "min": "2018-01-01", 198 | "max": "2018-12-31" 199 | } 200 | } 201 | }, 202 | "date_min_param": { 203 | "title": "Date Min Param", 204 | "description": "Serves as minimum reference variable.", 205 | "notes": "See date_max_param.", 206 | "opt0": "an option", 207 | "type": "date", 208 | "value": [ 209 | { 210 | "label0": "zero", 211 | "label1": 1, 212 | "value": "2018-01-15" 213 | } 214 | ], 215 | "validators": { 216 | "date_range": { 217 | "min": "2018-01-01", 218 | "max": "date_max_param" 219 | } 220 | } 221 | }, 222 | "date_max_param": { 223 | "title": "Date max parameter", 224 | "description": "Serves as maximum reference variable.", 225 | "notes": "See date_min_param.", 226 | "opt0": "an option", 227 | "type": "date", 228 | "value": [ 229 | { 230 | "label0": "zero", 231 | "label1": 1, 232 | 
"value": "2018-01-15" 233 | } 234 | ], 235 | "validators": { 236 | "date_range": { 237 | "min": "date_min_param", 238 | "max": "2018-12-31" 239 | } 240 | } 241 | }, 242 | "float_list_when_param": { 243 | "title": "Float List When Param", 244 | "description": "Reference for a float list param, using a when validator.", 245 | "opt0": "an option", 246 | "type": "float", 247 | "number_dims": 1, 248 | "value": [ 249 | { 250 | "label0": "zero", 251 | "value": [ 252 | 0, 253 | 2.0, 254 | 3.0, 255 | 4.0 256 | ] 257 | } 258 | ], 259 | "validators": { 260 | "when": { 261 | "param": "float_list_param", 262 | "is": { 263 | "greater_than": 1 264 | }, 265 | "then": { 266 | "range": { 267 | "min": 1 268 | } 269 | }, 270 | "otherwise": { 271 | "range": { 272 | "min": 0 273 | } 274 | } 275 | } 276 | } 277 | }, 278 | "float_list_param": { 279 | "title": "Float List Param", 280 | "description": "Example for a float, list param.", 281 | "opt0": "an option", 282 | "type": "float", 283 | "number_dims": 1, 284 | "value": [ 285 | { 286 | "label0": "zero", 287 | "label1": 1, 288 | "value": [ 289 | 1, 290 | 2.0, 291 | 3.5, 292 | 4.6 293 | ] 294 | } 295 | ], 296 | "validators": { 297 | "range": { 298 | "min": 0, 299 | "max": 10 300 | } 301 | } 302 | }, 303 | "simple_int_list_param": { 304 | "title": "Simple Int List Param", 305 | "description": "Test case where param is simple and a list.", 306 | "opt0": "an option", 307 | "type": "int", 308 | "number_dims": 1, 309 | "value": [ 310 | 1, 311 | 2, 312 | 3, 313 | 4 314 | ], 315 | "validators": { 316 | "range": { 317 | "min": 0, 318 | "max": 10 319 | } 320 | } 321 | }, 322 | "int_default_param": { 323 | "title": "Integer Default Reference Param", 324 | "description": "Example for a int param using a default reference value", 325 | "opt0": "an option", 326 | "type": "int", 327 | "value": 2, 328 | "validators": { 329 | "range": { 330 | "min": "default", 331 | "max": 10 332 | } 333 | } 334 | }, 335 | "int_dense_array_param": { 336 | "title": 
"Integer Dense Array Param", 337 | "description": "Example of using an int type param that supports to/from_array.", 338 | "opt0": "an option", 339 | "type": "int", 340 | "value": [ 341 | { 342 | "label0": "zero", 343 | "label1": 0, 344 | "label2": 0, 345 | "value": 1 346 | }, 347 | { 348 | "label0": "zero", 349 | "label1": 0, 350 | "label2": 1, 351 | "value": 2 352 | }, 353 | { 354 | "label0": "zero", 355 | "label1": 0, 356 | "label2": 2, 357 | "value": 3 358 | }, 359 | { 360 | "label0": "zero", 361 | "label1": 1, 362 | "label2": 0, 363 | "value": 4 364 | }, 365 | { 366 | "label0": "zero", 367 | "label1": 1, 368 | "label2": 1, 369 | "value": 5 370 | }, 371 | { 372 | "label0": "zero", 373 | "label1": 1, 374 | "label2": 2, 375 | "value": 6 376 | }, 377 | { 378 | "label0": "zero", 379 | "label1": 2, 380 | "label2": 0, 381 | "value": 7 382 | }, 383 | { 384 | "label0": "zero", 385 | "label1": 2, 386 | "label2": 1, 387 | "value": 8 388 | }, 389 | { 390 | "label0": "zero", 391 | "label1": 2, 392 | "label2": 2, 393 | "value": 9 394 | }, 395 | { 396 | "label0": "zero", 397 | "label1": 3, 398 | "label2": 0, 399 | "value": 10 400 | }, 401 | { 402 | "label0": "zero", 403 | "label1": 3, 404 | "label2": 1, 405 | "value": 11 406 | }, 407 | { 408 | "label0": "zero", 409 | "label1": 3, 410 | "label2": 2, 411 | "value": 12 412 | }, 413 | { 414 | "label0": "zero", 415 | "label1": 4, 416 | "label2": 0, 417 | "value": 13 418 | }, 419 | { 420 | "label0": "zero", 421 | "label1": 4, 422 | "label2": 1, 423 | "value": 14 424 | }, 425 | { 426 | "label0": "zero", 427 | "label1": 4, 428 | "label2": 2, 429 | "value": 15 430 | }, 431 | { 432 | "label0": "zero", 433 | "label1": 5, 434 | "label2": 0, 435 | "value": 16 436 | }, 437 | { 438 | "label0": "zero", 439 | "label1": 5, 440 | "label2": 1, 441 | "value": 17 442 | }, 443 | { 444 | "label0": "zero", 445 | "label1": 5, 446 | "label2": 2, 447 | "value": 18 448 | }, 449 | { 450 | "label0": "one", 451 | "label1": 0, 452 | "label2": 0, 453 | 
"value": 19 454 | }, 455 | { 456 | "label0": "one", 457 | "label1": 0, 458 | "label2": 1, 459 | "value": 20 460 | }, 461 | { 462 | "label0": "one", 463 | "label1": 0, 464 | "label2": 2, 465 | "value": 21 466 | }, 467 | { 468 | "label0": "one", 469 | "label1": 1, 470 | "label2": 0, 471 | "value": 22 472 | }, 473 | { 474 | "label0": "one", 475 | "label1": 1, 476 | "label2": 1, 477 | "value": 23 478 | }, 479 | { 480 | "label0": "one", 481 | "label1": 1, 482 | "label2": 2, 483 | "value": 24 484 | }, 485 | { 486 | "label0": "one", 487 | "label1": 2, 488 | "label2": 0, 489 | "value": 25 490 | }, 491 | { 492 | "label0": "one", 493 | "label1": 2, 494 | "label2": 1, 495 | "value": 26 496 | }, 497 | { 498 | "label0": "one", 499 | "label1": 2, 500 | "label2": 2, 501 | "value": 27 502 | }, 503 | { 504 | "label0": "one", 505 | "label1": 3, 506 | "label2": 0, 507 | "value": 28 508 | }, 509 | { 510 | "label0": "one", 511 | "label1": 3, 512 | "label2": 1, 513 | "value": 29 514 | }, 515 | { 516 | "label0": "one", 517 | "label1": 3, 518 | "label2": 2, 519 | "value": 30 520 | }, 521 | { 522 | "label0": "one", 523 | "label1": 4, 524 | "label2": 0, 525 | "value": 31 526 | }, 527 | { 528 | "label0": "one", 529 | "label1": 4, 530 | "label2": 1, 531 | "value": 32 532 | }, 533 | { 534 | "label0": "one", 535 | "label1": 4, 536 | "label2": 2, 537 | "value": 33 538 | }, 539 | { 540 | "label0": "one", 541 | "label1": 5, 542 | "label2": 0, 543 | "value": 34 544 | }, 545 | { 546 | "label0": "one", 547 | "label1": 5, 548 | "label2": 1, 549 | "value": 35 550 | }, 551 | { 552 | "label0": "one", 553 | "label1": 5, 554 | "label2": 2, 555 | "value": 36 556 | } 557 | ], 558 | "validators": { 559 | "range": { 560 | "min": 1, 561 | "max": 36 562 | } 563 | } 564 | }, 565 | "str_choice_warn_param": { 566 | "title": "String Choice Warnings Param", 567 | "description": "Example for string type params using a choice validator with warnings", 568 | "opt0": "another option", 569 | "type": "str", 570 | "value": 
"value0", 571 | "validators": { 572 | "choice": { 573 | "choices": [ 574 | "value0", 575 | "value1" 576 | ], 577 | "level": "warn" 578 | } 579 | } 580 | }, 581 | "int_warn_param": { 582 | "title": "Integer Parameter that uses warnings", 583 | "description": "Example for a int param using warnings with a range validator", 584 | "opt0": "an option", 585 | "type": "int", 586 | "value": 2, 587 | "validators": { 588 | "range": { 589 | "min": 0, 590 | "max": 10, 591 | "level": "warn" 592 | } 593 | } 594 | } 595 | } -------------------------------------------------------------------------------- /paramtools/tests/extend_ex.json: -------------------------------------------------------------------------------- 1 | { 2 | "schema": { 3 | "labels": { 4 | "d0": { 5 | "type": "int", 6 | "validators": {"range": {"min": 0, "max": 10}} 7 | }, 8 | "d1": { 9 | "type": "str", 10 | "validators": { 11 | "choice": {"choices": ["c1", "c2"]} 12 | } 13 | } 14 | } 15 | }, 16 | "extend_param": { 17 | "title": "extend param", 18 | "description": ".", 19 | "type": "int", 20 | "value": [ 21 | {"d0": 2, "d1": "c1", "value": 1}, 22 | {"d0": 2, "d1": "c2", "value": 2}, 23 | {"d0": 3, "d1": "c1", "value": 3}, 24 | {"d0": 3, "d1": "c2", "value": 4}, 25 | {"d0": 5, "d1": "c1", "value": 5}, 26 | {"d0": 5, "d1": "c2", "value": 6}, 27 | {"d0": 7, "d1": "c1", "value": 7}, 28 | {"d0": 7, "d1": "c2", "value": 8} 29 | ], 30 | "validators": { 31 | "range": { 32 | "min": -100, "max": "related_param" 33 | } 34 | } 35 | }, 36 | "indexed_param": { 37 | "title": "indexed param", 38 | "description": ".", 39 | "type": "float", 40 | "indexed": true, 41 | "value": [ 42 | {"d0": 2, "d1": "c1", "value": 1}, 43 | {"d0": 2, "d1": "c2", "value": 2}, 44 | {"d0": 3, "d1": "c1", "value": 3}, 45 | {"d0": 3, "d1": "c2", "value": 4}, 46 | {"d0": 5, "d1": "c1", "value": 5}, 47 | {"d0": 5, "d1": "c2", "value": 6}, 48 | {"d0": 7, "d1": "c1", "value": 7}, 49 | {"d0": 7, "d1": "c2", "value": 8} 50 | ], 51 | "validators": { 52 | 
"range": { 53 | "min": -100, "max": "related_param" 54 | } 55 | } 56 | }, 57 | "related_param": { 58 | "title": "related param", 59 | "description": "Test error on adjustment extension.", 60 | "type": "int", 61 | "value": [ 62 | {"d0": 0, "d1": "c1", "value": 100}, 63 | {"d0": 0, "d1": "c2", "value": 101}, 64 | {"d0": 7, "d1": "c1", "value": 50}, 65 | {"d0": 7, "d1": "c2", "value": 51} 66 | ] 67 | }, 68 | "nonextend_param": { 69 | "title": "nonextend param", 70 | "description": "Test error on adjustment extension.", 71 | "type": "int", 72 | "value": 2 73 | } 74 | } -------------------------------------------------------------------------------- /paramtools/tests/test_examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PSLmodels/ParamTools/3a9094f50b45b5ec7324d8ffa9d6ef7070957452/paramtools/tests/test_examples/__init__.py -------------------------------------------------------------------------------- /paramtools/tests/test_examples/test_baseball.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from paramtools import parameters 6 | 7 | CURRENT_PATH = os.path.abspath(os.path.dirname(__file__)) 8 | 9 | 10 | @pytest.fixture 11 | def field_map(): 12 | # nothing here for now 13 | return {} 14 | 15 | 16 | @pytest.fixture 17 | def defaults_spec_path(): 18 | return os.path.join(CURRENT_PATH, "../../examples/baseball/defaults.json") 19 | 20 | 21 | @pytest.fixture 22 | def BaseballParams(defaults_spec_path): 23 | class _BaseballParams(parameters.Parameters): 24 | defaults = defaults_spec_path 25 | 26 | return _BaseballParams 27 | 28 | 29 | def test_load_schema(BaseballParams): 30 | params = BaseballParams() 31 | assert params 32 | -------------------------------------------------------------------------------- /paramtools/tests/test_examples/test_behresp.py: 
@pytest.fixture
def BehrespParams(defaults_spec_path):
    """Return a ``Parameters`` subclass whose defaults are the behresp spec.

    The class is created inside the fixture so each test receives a freshly
    defined class object pointing at ``defaults_spec_path``.
    """
    spec_path = defaults_spec_path

    class _BehrespParams(Parameters):
        defaults = spec_path

    return _BehrespParams
def test_np_value_fields():
    """Check the numpy-backed fields in both directions.

    ``_deserialize`` must hand back the numpy scalar type, and ``_serialize``
    must convert that scalar back to the plain Python builtin.
    """
    float64 = fields.Float64()
    res = float64._deserialize("2", None, None)
    assert res == 2.0
    assert isinstance(res, np.float64)
    # Exact type identity on purpose: np.float64 subclasses float, so an
    # isinstance check would not prove the value was converted.
    assert type(float64._serialize(res, None, None)) is float

    int64 = fields.Int64()
    res = int64._deserialize("2", None, None)
    assert res == 2
    assert isinstance(res, np.int64)
    assert type(int64._serialize(res, None, None)) is int

    bool_ = fields.Bool_()
    res = bool_._deserialize("true", None, None)
    assert res is np.bool_(True)
    assert isinstance(res, np.bool_)
    # Identity with the builtin singleton: the serialized value must be a
    # real Python bool, not a numpy scalar that merely compares equal.
    assert bool_._serialize(res, None, None) is True
def test_cmp_funcs():
    """Exercise every comparison function exposed by ``field.cmp_funcs()``.

    Each scenario is a ``(smaller, larger, field)`` triple; the same set of
    ordering assertions is applied to string, date and integer fields.
    """
    in_range = validate.Range(0, 10)
    in_dates = validate.DateRange(
        "2019-01-01", "2019-01-05", step={"days": 2}
    )
    in_choices = validate.OneOf(choices=["one", "two"])

    scenarios = [
        ("one", "two", fields.Str(validate=[in_choices])),
        (
            datetime.date(2019, 1, 2),
            datetime.date(2019, 1, 3),
            fields.Date(validate=[in_dates]),
        ),
        (2, 5, fields.Integer(validate=[in_range])),
    ]

    for smaller, larger, field in scenarios:
        compare = field.cmp_funcs()
        assert compare["gt"](smaller, larger) is False
        assert compare["lt"](smaller, larger) is True
        assert compare["eq"](smaller, larger) is False
        assert compare["eq"](larger, larger) is True
        assert compare["lte"](larger, larger) is True
        assert compare["lte"](smaller, larger) is True
        assert compare["gte"](larger, smaller) is True
def test_register_custom_type():
    """
    Marshmallow field instances and PartialField instances can be registered;
    an uninitialized field class or an unrelated object raises TypeError.
    """
    type_name = "custom"
    assert type_name not in ALLOWED_TYPES
    register_custom_type(type_name, ma.fields.String())
    assert type_name in ALLOWED_TYPES

    register_custom_type("partial-test", PartialField(ma.fields.Str(), {}))
    assert "partial-test" in ALLOWED_TYPES

    class Whatever:
        pass

    # First entry is a field *class* (never instantiated); second is an
    # instance of a class that is not a field at all.
    rejected = [("custom", ma.fields.Str), ("whatever", Whatever())]
    for name, not_a_field in rejected:
        with pytest.raises(TypeError):
            register_custom_type(name, not_a_field)
def test_select_eq(vos):
    """``select_eq`` accepts a single label value or a list of candidates."""
    single = select_eq(vos, False, labels={"d0": 1, "d1": "hello"})
    assert list(single) == [{"d0": 1, "d1": "hello", "value": 1}]

    # A list for "d0" selects value objects whose d0 is any listed value.
    several = select_eq(vos, False, labels={"d0": [1, 2], "d1": "hello"})
    assert list(several) == [
        {"d0": 1, "d1": "hello", "value": 1},
        {"d0": 2, "d1": "hello", "value": 1},
    ]
def test_sorted_key_list():
    """Walk a SortedKeyList through lookups before and after insertions.

    Items are colour names ordered by the integer in ``values``. The list
    starts with four of them; "black" (lowest key), "white" (a middle key)
    and "green" (highest key) are added one at a time, and the
    eq/ne/gt/gte/lt/lte lookups are checked around each insertion.
    """
    values = {
        "red": 2,
        "blue": 3,
        "orange": 5,
        "white": 6,
        "yellow": 7,
        "green": 9,
        "black": 0,
    }

    to_add = ["red", "blue", "orange", "yellow"]

    skl = SortedKeyList(
        to_add, keyfunc=lambda x: values[x], index=list(range(len(to_add)))
    )

    # "black" has the smallest key and is not in the list yet.
    assert skl.eq("black") is None
    assert skl.gte("black").values[0] == "red"
    assert skl.lte("black") is None
    skl.add("black")
    assert skl.gte("black").values[0] == "black"
    assert skl.lte("black").values[-1] == "black"
    assert skl.eq("black").values == ["black"]

    # "white" falls between "orange" and "yellow".
    assert skl.gte("white").values[0] == "yellow"
    assert skl.lte("white").values[-1] == "orange"
    skl.add("white")
    assert skl.gte("white").values[0] == "white"
    assert skl.gt("white").values[0] == "yellow"
    assert skl.lte("white").values[-1] == "white"
    assert skl.lt("yellow").values[-1] == "white"

    # "green" has the largest key.
    assert skl.gte("green") is None
    assert skl.lte("green").values[-1] == "yellow"
    skl.add("green")
    assert skl.gte("green").values[0] == "green"
    assert skl.lte("green").values[-1] == "green"

    # Duplicates are kept.
    skl.add("green")
    assert skl.eq("green").values == ["green", "green"]

    # Iterating a dict yields its keys, so set(values) is the key set.
    assert set(skl.ne("green").values) == set(values) - {"green"}
    values["pokadot"] = -1
    assert set(skl.ne("pokadot").values) == set(values) - {"pokadot"}
def test_get_leaves():
    """``get_leaves`` flattens nested dicts and lists into a list of leaves."""
    tree = {
        0: {"value": {0: ["leaf1", "leaf2"]}, 1: {"value": {0: ["leaf3"]}}},
        1: {
            "value": {1: ["leaf4", "leaf5"]},
            2: {"value": {0: ["leaf6", ["leaf7", "leaf8"]]}},
        },
    }
    every_leaf = [f"leaf{i}" for i in range(1, 9)]

    # (input, expected leaves): the tree itself, the tree wrapped in a list,
    # empty containers, and a bare scalar.
    cases = [
        (tree, every_leaf),
        ([tree], every_leaf),
        ({}, []),
        ([], []),
        ("leaf", ["leaf"]),
    ]
    for nested, expected in cases:
        assert get_leaves(nested) == expected
def test_make_label_str():
    """``make_label_str`` renders the labels of a value object as text.

    The expectations show that "value" is left out, that no labels gives an
    empty string, and that labels appear in alphabetical order.
    """
    expectations = [
        ({"hello": "world", "value": 0}, "[hello=world]"),
        ({"value": 0}, ""),
        ({}, ""),
        ({"b": 0, "c": 1, "a": 2}, "[a=2, b=0, c=1]"),
    ]
    for value_object, rendered in expectations:
        assert make_label_str(value_object) == rendered
def test_Range_grid():
    """``Range.grid`` lists the values from min to max, honouring ``step``."""
    # No step given: every integer from 0 to 10, both ends included.
    every_int = Range(0, 10)
    assert every_int.grid() == list(range(0, 11))

    # With step=3 the grid stops at 9, the last stepped value not above 10.
    every_third = Range(0, 10, step=3)
    assert every_third.grid() == [0, 3, 6, 9]
def test_When():
    """``When`` applies the then/otherwise validators based on its condition.

    The referenced value is always "hello". Comparing it against "world"
    selects the ``otherwise`` validator (a choice of 12 or 15); comparing it
    against "hello" selects the ``then`` validator (a 0-10 range).
    """
    in_range = Range(0, 10)
    in_choices = OneOf(choices=[12, 15])

    def build(expected):
        # Only the comparison target differs between the two validators.
        return When(
            {"equal_to": expected},
            when_vos=[{"value": "hello"}],
            then_validators=[in_range],
            otherwise_validators=[in_choices],
        )

    condition_unmet = build("world")
    condition_unmet(12)
    with pytest.raises(ValidationError):
        condition_unmet(3)

    condition_met = build("hello")
    condition_met(3)
    with pytest.raises(ValidationError):
        condition_met(12)

    assert condition_met.grid() == list(range(10 + 1))
class TestValues:
    """Container-level behaviour of the ``values`` fixture."""

    def test_values(self, values):
        # The ``_values`` fixture supplies four value objects.
        item_count = len(values)
        assert item_count == 4

    def test_key_error(self, values):
        # Asking for a label that no value object carries raises KeyError.
        with pytest.raises(KeyError):
            values["heyo"]

    def test_types(self, values):
        # Label look-up gives a Slice; comparing a Slice gives a QueryResult.
        assert isinstance(values["d0"], Slice)
        assert isinstance(values["d0"] > 1, QueryResult)

        matches = values["d0"] > 1
        type_checks = (
            (matches.isel[0], dict),
            (matches.isel[:], list),
            (values.isel, ValueItem),
            (matches.as_values(), Values),
        )
        for obj, expected_type in type_checks:
            assert isinstance(obj, expected_type)
def test_select_ne(self, values):
    """Two ``!=`` queries joined with ``&`` keep items passing both."""
    not_one_and_not_hello = (values["d0"] != 1) & (values["d1"] != "hello")
    assert list(not_one_and_not_hello) == [
        {"d0": 3, "d1": "world", "value": 1}
    ]

    neither_two_nor_three = (values["d0"] != 2) & (values["d0"] != 3)
    expected = [
        {"d0": 1, "d1": "hello", "value": 1},
        {"d0": 1, "d1": "world", "value": 1},
    ]
    assert list(neither_two_nor_three) == expected
def test_delete(self, values):
    """``delete`` returns a trimmed copy by default and mutates when ``inplace``."""
    original_size = len(copy.deepcopy(values.values))

    trimmed = values.delete(0, inplace=False)

    # A non-inplace delete leaves the source object as it was.
    assert len(values.values) == original_size
    assert len(values.index) == original_size

    one_fewer = original_size - 1
    assert len(trimmed.values) == one_fewer
    assert len(trimmed.index) == one_fewer
    assert trimmed.index == [1, 2, 3]

    # Deleting in place shrinks the same object further.
    trimmed.delete(1, inplace=True)
    two_fewer = original_size - 2
    assert len(trimmed.index) == two_fewer
    assert len(trimmed.values) == two_fewer
    assert trimmed.index == [2, 3]
# A single value object: a dict with string keys (label names plus "value").
ValueObject = NewType("ValueObject", Dict[str, Any])
# A comparison callable taking one item and an iterable and returning a bool.
CmpFunc = NewType("CmpFunc", Callable[[Any, Iterable], bool])

# Anything accepted where parameter data can be supplied: a string, a
# filesystem path, an open file object, or an already-loaded dict.
# The dict's value type is ``typing.Any``; the builtin function ``any`` is not
# a type and must not be used here.
FileDictStringLike = Union[str, Path, IO[AnyStr], Dict[str, Any]]
def remove_comments(string):
    """
    Remove single-line (``//``) and multi-line (``/* */``) comments from JSON.

    Quoted strings are matched first, so comment markers that appear inside a
    string (for example the ``//`` in a URL) are left alone. Backslash escapes
    inside a string are honored, so an escaped quote does not end the string.

    Each comment is replaced by the newlines it contained, or by a single
    space when it contained none. The result therefore has exactly as many
    lines as the input, and tokens on either side of an inline comment are
    never fused together.

    Adapted from https://stackoverflow.com/a/18381470/9100772
    """
    pattern = (
        # group 1: a double- or single-quoted string, escapes included
        r"(\"(?:\\.|[^\"\\])*\"|\'(?:\\.|[^\'\\])*\')"
        # group 2: a /* block */ comment or a // comment up to the line end.
        # No "$" anchor: it only matches before "\n", which made "//"
        # comments on "\r\n"-terminated lines impossible to match.
        r"|(/\*.*?\*/|//[^\r\n]*)"
    )
    regex = re.compile(pattern, re.DOTALL)

    def _replacer(match):
        comment = match.group(2)
        if comment is None:
            # A quoted string was matched: give it back untouched.
            return match.group(1)
        # Keep the line count intact; fall back to one space for comments
        # that do not span a line break.
        return "\n" * comment.count("\n") or " "

    return regex.sub(_replacer, string)
def ensure_value_object(vo) -> ValueObject:
    """
    Make sure ``vo`` is expressed as a list of value objects.

    A non-empty list whose first element is a dict is taken to be a list of
    value objects already and is returned as is (the same object). Anything
    else is treated as a bare value and wrapped as ``[{"value": vo}]``. That
    covers scalars, lists of non-dict items and the empty list. The empty
    list previously raised ``IndexError`` because its first element was
    inspected unconditionally.
    """
    if isinstance(vo, list) and vo and isinstance(vo[0], dict):
        return vo
    return [{"value": vo}]
def grid_sort(vos, label_to_extend, grid):
    """
    Sort value objects by the position of their ``label_to_extend`` value
    in ``grid``.

    Value objects that do not define ``label_to_extend`` are ordered as if
    they sat at the first grid position. The sort is stable, so items sharing
    a position keep their relative input order.

    ``grid`` must offer ``.index`` (e.g. a list); with a list, a label value
    that is not in the grid raises ``ValueError``.
    """

    def key(vo):
        if label_to_extend in vo:
            return grid.index(vo[label_to_extend])
        # Return a position, not the grid element itself: every key must be
        # an int so keys stay comparable whatever type the grid values have.
        return 0

    return sorted(vos, key=key)
def __getitem__(self, item):
    """
    Look up value objects by position and return shallow ``dict`` copies.

    With an ``index`` present, positions refer to entries of ``self.index``,
    whose elements are the keys into ``self.values.values``. Without one,
    positions are used as keys directly. A ``slice`` yields a list of
    copies; any other ``item`` yields a single copy.
    """
    has_index = self.index is not None

    if not isinstance(item, slice):
        # Resolve the storage first, then the key, as a plain
        # ``store[index[item]]`` expression would.
        store = self.values.values
        return dict(store[self.index[item]] if has_index else store[item])

    if has_index:
        positions = range(*item.indices(len(self.index)))
        # The storage is only touched per element, so an empty selection
        # never dereferences ``self.values``.
        return [
            dict(self.values.values[self.index[pos]]) for pos in positions
        ]

    positions = range(*item.indices(len(self.values)))
    return [dict(self.values.values[pos]) for pos in positions]
class QueryResult(ValueBase):
    """
    The outcome of a query against a ``Values`` object.

    ``values`` is the ``Values`` instance that was queried and ``index`` holds
    the keys (into ``values.values``) of the value objects that matched.

    Results can be combined with ``&`` and ``|``, read by position through
    ``.isel``, iterated, or converted back to a ``Values`` with
    ``as_values``. The named comparison methods (``eq``, ``gt``, ...) are
    forwarded to the underlying ``Values``; they query all of it, not only
    the items in this result. The comparison operators and ``[]`` are
    deliberately unsupported and raise ``NotImplementedError``.
    """

    def __init__(self, values: "Values", index: List[Any]):
        self.values = values
        self.index = index

    def _ordered(self, keys):
        """Return the set ``keys`` as a list ordered like ``self.values.index``."""
        if self.values is None:
            # Placeholder results such as ``QueryResult(None, [])`` have no
            # backing Values to take an order from.
            return list(keys)
        ordered = list(
            dict.fromkeys(ix for ix in self.values.index if ix in keys)
        )
        if len(ordered) != len(keys):
            # Keys the backing Values does not know about are kept, after
            # the known ones, rather than silently dropped.
            seen = set(ordered)
            ordered.extend(ix for ix in keys if ix not in seen)
        return ordered

    def __and__(self, queryresult: "QueryResult"):
        # The combined index is a list in storage order. A bare set would
        # iterate in hash order and cannot be extended by ``Values.add``.
        res = set(self.index) & set(queryresult.index)

        return QueryResult(self.values, self._ordered(res))

    def __or__(self, queryresult: "QueryResult"):
        res = set(self.index) | set(queryresult.index)

        return QueryResult(self.values, self._ordered(res))

    def __repr__(self):
        vo_repr = "\n  ".join(
            str(dict(self.values.values[i])) for i in (self.index or [])
        )
        return f"QueryResult([\n  {vo_repr}\n])"

    def __iter__(self):
        # Yields the stored value objects themselves (not copies).
        for i in self.index:
            yield self.values.values[i]

    def __getitem__(self, item):
        raise NotImplementedError(
            "Use .isel to do index-based look ups or as_values to chain queries."
        )

    @property
    def isel(self):
        """Positional access to the matched value objects."""
        return ValueItem(self.values, self.index)

    def tolist(self):
        """Return the matched value objects as a list (not copies)."""
        return [self.values.values[i] for i in self.index]

    def eq(self, strict=True, **labels):
        return self.cmp_attr.eq(strict, **labels)

    def ne(self, strict=True, **labels):
        return self.cmp_attr.ne(strict, **labels)

    def gt(self, strict=True, **labels):
        return self.cmp_attr.gt(strict, **labels)

    def gte(self, strict=True, **labels):
        return self.cmp_attr.gte(strict, **labels)

    def lt(self, strict=True, **labels):
        return self.cmp_attr.lt(strict, **labels)

    def lte(self, strict=True, **labels):
        return self.cmp_attr.lte(strict, **labels)

    def isin(self, strict=True, **labels):
        return self.cmp_attr.isin(strict, **labels)

    def __eq__(self, *args, **kwargs):
        raise NotImplementedError()

    def __ne__(self, *args, **kwargs):
        raise NotImplementedError()

    def __gt__(self, *args, **kwargs):
        raise NotImplementedError()

    def __ge__(self, *args, **kwargs):
        raise NotImplementedError()

    def __lt__(self, *args, **kwargs):
        raise NotImplementedError()

    def __le__(self, *args, **kwargs):
        raise NotImplementedError()

    def as_values(self):
        """Build a new ``Values`` holding only the matched value objects."""
        # Always hand over a list: ``Values`` concatenates and extends its
        # index, which a set (as produced by some queries) does not support.
        return Values(
            values=list(self),
            index=list(self.index),
            keyfuncs=self.values.keyfuncs,
        )

    def delete(self):
        """Delete the matched value objects from the underlying ``Values``."""
        if not self.index:
            # ``Values.delete`` called without keys deletes everything, so
            # an empty result must not be forwarded.
            return
        self.values.delete(*self.index, inplace=True)

    @property
    def cmp_attr(self):
        # Comparisons are answered by the Values that was queried.
        return self.values
def build_skls(self, values, keyfuncs):
    """
    Build one ``SortedKeyList`` per label found in ``values``.

    ``values`` maps an index key to a value object. For every label the
    label's values and the index keys they came from are gathered in
    parallel, in iteration order. They are handed to ``SortedKeyList`` with
    the key function that ``get_keyfunc`` resolves for that label.

    Returns a dict mapping each label to its ``SortedKeyList``.
    """
    # label -> (values seen for the label, index keys they belong to)
    columns = defaultdict(lambda: ([], []))
    for ix, vo in values.items():
        for label, value in vo.items():
            column_values, column_index = columns[label]
            column_values.append(value)
            column_index.append(ix)

    return {
        label: SortedKeyList(
            column_values, self.get_keyfunc(label, keyfuncs), column_index
        )
        for label, (column_values, column_index) in columns.items()
    }
def get_keyfunc(self, label, keyfuncs):
    """
    Return the sort-key function registered for ``label``.

    ``keyfuncs`` maps label names to key functions. It may be ``None``,
    which is what a ``Values`` built without key functions stores, or may
    lack an entry for ``label``. In both cases ``default_cmp_func`` is
    returned.
    """
    keyfunc = (keyfuncs or {}).get(label)
    return keyfunc or default_cmp_func
def gte(self, strict=True, **labels):
    """
    Select value objects whose label value is greater than or equal to the
    one given:

    .. code-block:: Python

        params.sel["my_param"].gte(my_label=5)
        params.sel["my_param"]["my_label"] >= 5

    ``strict`` and the label keyword are passed through to ``_cmp``.
    """
    operation = "gte"
    return self._cmp(operation, strict, **labels)
def add(
    self, values: List[ValueObject], index: List[Any] = None, inplace=False
):
    """
    Add value objects, either to this object or to a new ``Values``.

    ``index`` optionally gives the key for each new value object and must be
    as long as ``values`` (checked with an ``assert``). When omitted, keys
    continue after the current maximum key, or start at 1 when there are no
    keys yet.

    With ``inplace=True`` this object is updated and ``None`` is returned.
    Otherwise this object is left unchanged and a new ``Values`` is returned.
    """
    if index is not None:
        assert len(index) == len(values)
        new_index = index
    else:
        max_index = max(self.index) if self.index else 0
        new_index = [max_index + i + 1 for i in range(len(values))]

    new_values = {ix: value for ix, value in zip(new_index, values)}

    if inplace:
        self.update_skls(new_values)
        self.values.update(new_values)
        self.index += new_index
    else:
        updated_values = dict(self.values)
        updated_values.update(new_values)
        return Values(
            list(updated_values.values()),
            # Carry the key functions over. Without them the new object
            # forgets its custom orderings and fails as soon as it has to
            # rebuild or extend its sorted key lists.
            keyfuncs=self.keyfuncs,
            skls=self.build_skls(updated_values, self.keyfuncs or {}),
            # Take the keys from the merged dict so they line up one-to-one
            # with the values above even when a new key replaces an
            # existing one.
            index=list(updated_values),
        )
def __and__(self, queryresult: "QueryResult"):
    """
    Combine queries with logical 'and':

    .. code-block:: Python

        my_param = params.sel["my_param"]
        (my_param["my_label"] == 5) & (my_param["oth_label"] == "hello")

    The returned ``QueryResult`` refers to this object and lists the keys
    present in both operands, in this object's index order.
    """
    wanted = set(queryresult.index)
    # Keep a list in storage order rather than a set. Iteration over the
    # result is then deterministic and its index stays list-like.
    res = [ix for ix in self.index if ix in wanted]

    return QueryResult(self, res)
def union(
    queryresults: Union[List[ValueBase], Generator[ValueBase, None, None]]
):
    """
    Combine any number of query results with ``|``.

    The results are folded left to right. If ``queryresults`` is empty, an
    empty ``QueryResult`` that is not attached to any ``Values`` is returned.
    """
    result = None
    for queryresult in queryresults:
        if result is None:
            result = queryresult
        else:
            # Plain ``|`` (not ``|=``) so an operand is never modified in
            # place.
            result = result | queryresult

    # Test for "nothing was supplied" explicitly. Truthiness of a result is
    # derived from its length, so a genuine (but falsy) result would
    # otherwise be replaced by the detached placeholder.
    if result is None:
        return QueryResult(None, [])
    return result
computational modeling projects" 15 | ), 16 | long_description=long_description, 17 | long_description_content_type="text/markdown", 18 | url="https://github.com/hdoupe/ParamTools", 19 | packages=setuptools.find_packages(), 20 | install_requires=[ 21 | "marshmallow>=4.0.0", 22 | "numpy", 23 | "python-dateutil>=2.8.0", 24 | "fsspec", 25 | "sortedcontainers", 26 | ], 27 | include_package_data=True, 28 | classifiers=[ 29 | "Programming Language :: Python :: 3", 30 | "License :: OSI Approved :: MIT License", 31 | "Operating System :: OS Independent", 32 | ], 33 | ) 34 | --------------------------------------------------------------------------------