├── .copier-answers.yml
├── .github
│   └── workflows
│       ├── linting.yml
│       ├── publish-to-pypi.yml
│       ├── smoke-test.yml
│       └── testing-and-coverage.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yml
├── .travis.yml
├── LICENSE
├── README.md
├── archive
│   ├── README.md
│   ├── x_setup.cfg
│   ├── x_setup.py
│   └── x_tox.ini
├── docs
│   ├── Makefile
│   ├── conf.py
│   ├── index.rst
│   ├── notebooks.rst
│   ├── notebooks
│   │   ├── README.md
│   │   └── intro_notebook.ipynb
│   └── requirements.txt
├── pyproject.toml
├── src
│   ├── .pylintrc
│   └── flexcode
│       ├── __init__.py
│       ├── basis_functions.py
│       ├── core.py
│       ├── helpers.py
│       ├── loss_functions.py
│       ├── post_processing.py
│       └── regression_models.py
├── tests
│   └── flexcode
│       ├── .pylintrc
│       ├── conftest.py
│       ├── context.py
│       ├── test_cv_optim.py
│       ├── test_models_fit.py
│       ├── test_params_handling.py
│       └── test_post_processing.py
├── tutorial
│   └── Flexcode-tutorial-teddy.ipynb
└── vignettes
    ├── Custom Class.ipynb
    └── Model Save and Bumps Removal - Flexcode.ipynb

/.copier-answers.yml:
--------------------------------------------------------------------------------
1 | # Changes here will be overwritten by Copier
2 | _commit: v1.3.0
3 | _src_path: gh:lincc-frameworks/python-project-template
4 | author_email: annlee@andrew.cmu.edu
5 | author_name: Ann Lee
6 | create_example_module: false
7 | custom_install: true
8 | include_notebooks: true
9 | module_name: flexcode
10 | mypy_type_checking: none
11 | preferred_linter: black
12 | project_license: none
13 | project_name: flexcode
14 | use_gitlfs: none
15 | use_isort: true
16 | 
--------------------------------------------------------------------------------
/.github/workflows/linting.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, then perform static linting analysis.
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 | 
4 | name: Lint
5 | 
6 | on:
7 |   push:
8 |     branches: [ main ]
9 |   pull_request:
10 |     branches: [ main ]
11 | 
12 | jobs:
13 |   build:
14 |     runs-on: ubuntu-latest
15 |     steps:
16 |     - uses: actions/checkout@v3
17 |     - name: Set up Python 3.10
18 |       uses: actions/setup-python@v4
19 |       with:
20 |         python-version: '3.10'
21 |     - name: Install dependencies
22 |       run: |
23 |         sudo apt-get update
24 |         python -m pip install --upgrade pip
25 |         pip install .
26 |         pip install .[dev]
27 |         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
28 |     - name: Analyze code with linter
29 | 
30 |       uses: psf/black@stable
31 |       with:
32 |         src: ./src
33 | 
34 | 
--------------------------------------------------------------------------------
/.github/workflows/publish-to-pypi.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 | 
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v4 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/smoke-test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will run daily at 06:45. 2 | # It will install Python dependencies and run tests with a variety of Python versions. 3 | 4 | name: Unit test smoke test 5 | 6 | on: 7 | schedule: 8 | - cron: 45 6 * * * 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: ['3.8', '3.9', '3.10'] 17 | 18 | steps: 19 | - uses: actions/checkout@v3 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v4 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Install dependencies 25 | run: | 26 | sudo apt-get update 27 | python -m pip install --upgrade pip 28 | pip install . 29 | pip install .[dev] 30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 31 | - name: Run unit tests with pytest 32 | run: | 33 | python -m pytest tests 34 | -------------------------------------------------------------------------------- /.github/workflows/testing-and-coverage.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and report code coverage with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Unit test and code coverage 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: ['3.8', '3.9', '3.10'] 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v4 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | sudo apt-get update 29 | python -m pip install --upgrade pip 30 | pip install . 
31 | pip install .[dev] 32 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 33 | - name: Run unit tests with pytest 34 | run: | 35 | python -m pytest tests --cov=flexcode --cov-report=xml 36 | - name: Upload coverage report to codecov 37 | uses: codecov/codecov-action@v3 38 | - name: Install notebook requirements 39 | run: | 40 | sudo apt-get install pandoc 41 | - name: Build docs 42 | run: | 43 | sphinx-build -T -E -b html -d docs/build/doctrees ./docs docs/build/html -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | _version.py 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | _readthedocs/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | # vscode 134 | .vscode/ 135 | 136 | # dask 137 | dask-worker-space/ 138 | 139 | # tmp directory 140 | tmp/ 141 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | 3 | # Clear output from jupyter notebooks so that only the input cells are committed. 
4 | - repo: local
5 |   hooks:
6 |     - id: jupyter-nb-clear-output
7 |       name: jupyter-nb-clear-output
8 |       description: Clear output from Jupyter notebooks.
9 |       files: \.ipynb$
10 |       stages: [commit]
11 |       language: system
12 |       entry: jupyter nbconvert --ClearOutputPreprocessor.enabled=True --inplace
13 | 
14 | # Run unit tests, verify that they pass. Note that coverage is run against
15 | # the ./src directory here because that is what will be committed. In the
16 | # github workflow script, the coverage is run against the installed package
17 | # and uploaded to Codecov by calling pytest like so:
18 | # `python -m pytest --cov=flexcode --cov-report=xml`
19 | - repo: local
20 |   hooks:
21 |     - id: pytest-check
22 |       name: pytest-check
23 |       description: Run unit tests with pytest.
24 |       entry: bash -c "if python -m pytest --co -qq; then python -m pytest --cov=./src --cov-report=html; fi"
25 |       language: system
26 |       pass_filenames: false
27 |       always_run: true
28 | 
29 | # prevents committing directly to branches named 'main' and 'master'.
30 | - repo: https://github.com/pre-commit/pre-commit-hooks
31 |   rev: v4.4.0
32 |   hooks:
33 |     - id: no-commit-to-branch
34 |       name: Don't commit to main or master branch
35 |       description: Prevent the user from committing directly to the primary branch.
36 |     - id: check-added-large-files
37 |       name: Check for large files
38 |       description: Prevent the user from committing very large files.
39 |       args: ['--maxkb=500']
40 | 
41 | # verify that pyproject.toml is well formed
42 | - repo: https://github.com/abravalheri/validate-pyproject
43 |   rev: v0.12.1
44 |   hooks:
45 |     - id: validate-pyproject
46 |       name: Validate syntax of pyproject.toml
47 |       description: Verify that pyproject.toml adheres to the established schema.
48 | 
49 | # Automatically sort the imports used in .py files
50 | - repo: https://github.com/pycqa/isort
51 |   rev: 5.12.0
52 |   hooks:
53 |     - id: isort
54 |       name: isort (python files in src/ and tests/)
55 |       description: Sort and organize imports in .py files.
56 |       types: [python]
57 |       files: ^(src|tests)/
58 | 
59 | 
60 | # Analyze the code style and report code that doesn't adhere.
61 | - repo: https://github.com/psf/black
62 |   rev: 23.1.0
63 |   hooks:
64 |     - id: black
65 |       types: [python]
66 |       files: ^(src|tests)/
67 |       # It is recommended to specify the latest version of Python
68 |       # supported by your project here, or alternatively use
69 |       # pre-commit's default_language_version, see
70 |       # https://pre-commit.com/#top_level-default_language_version
71 |       language_version: python3.10
72 | 
73 | 
74 | # Make sure Sphinx can build the documentation while explicitly omitting
75 | # notebooks from the docs, so users don't have to wait through the execution
76 | # of each notebook on each commit. By default, these will be checked in the
77 | # GitHub workflows.
78 | - repo: local 79 | hooks: 80 | - id: sphinx-build 81 | name: Build documentation with Sphinx 82 | entry: sphinx-build 83 | language: system 84 | always_run: true 85 | exclude_types: [file, symlink] 86 | args: 87 | [ 88 | "-M", # Run sphinx in make mode, so we can use -D flag later 89 | # Note: -M requires next 3 args to be builder, source, output 90 | "html", # Specify builder 91 | "./docs", # Source directory of documents 92 | "./docs/build/html", # Output directory for rendered documents 93 | "-T", # Show full trace back on exception 94 | "-E", # Don't use saved env; always read all files 95 | "-d", # Flag for cached environment and doctrees 96 | "./docs/build/doctrees", # Directory 97 | "-D", # Flag to override settings in conf.py 98 | "exclude_patterns=notebooks/*", # Exclude our notebooks from pre-commit 99 | ] -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.10" 12 | 13 | # Build documentation in the docs/ directory with Sphinx 14 | sphinx: 15 | configuration: docs/conf.py 16 | 17 | # Optionally declare the Python requirements required to build your docs 18 | python: 19 | install: 20 | - requirements: docs/requirements.txt 21 | - method: pip 22 | path: . -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: python 3 | python: 4 | - '2.7' 5 | - '3.6' 6 | install: pip install tox-travis 7 | script: tox 8 | deploy: 9 | provider: pypi 10 | on: 11 | tags: true 12 | condition: $TRAVIS_PYTHON_VERSION = "3.6" 13 | user: tpospisi 14 | password: 15 | secure: oOtSw47ptRh+ELY0r3llOm2tjMy5iqhaxpnMEVgp48owe2MhvMxvvhtSxvFfRklzmFjB45hBVuERYjp1LXA9paC1rCEnigqgHisiyNHx7+QdQ4qGehv155pTDUOSXiPObiozEgE2rM6QXkwdAbAVvDt9f7YA0UjgQ+AQrDGBGuAWxqLi/aoCsqUcnwdkk64CNJq4J9W9S7KfJuUPcOVTorujA21p9/r8sOUUFDwkb7TmXOc+8LDDl6IEkgjhcckjOMGBW4Byh+bz6EfSbM1JkHdSlBLYFMdysO9WII3YhvMr4kGTEmBfeSI5p/J9lAtQAexiV2pWLobMFOZYt37LHDtMY7xaBR1bWPITkJevS26rFQ6WB+hypi1TDINB9apVXOzw/rPd8OvjKg4H6pNbCIGHL1NEwHxuvEIVFVxU/n9981CesC02NyC9PqIiL2ahihkk0f0j7b/an2tyundipgPPydg9Ut+GWVGG597Nban4awEy3zaCdGF8VCzaW76/neT0zrteOVKtpzSg/yJBq9/7Zqm4L7uCX5qGFIpguT4B4iab2GMPOZ5xSE5NqRjq2++3xr8aP86MSyJel4d1z/4C22FBmsugJb6QoiQD4VdvJVVDKxNqLC8zOAghZ4sIJ3/EgCDLSw5h+3iT/VcLiaQZLio+PqN/Zrx//LYFgDg= 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc., 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. 
This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. (Some other Free Software Foundation software is covered by 18 | the GNU Lesser General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. 
The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. 
You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. 
You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. 
If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 | 293 | {description} 294 | Copyright (C) {year} {fullname} 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 305 | 306 | You should have received a copy of the GNU General Public License along 307 | with this program; if not, write to the Free Software Foundation, Inc., 308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 309 | 310 | Also add information on how to contact you by electronic and paper mail. 
311 | 
312 | If the program is interactive, make it output a short notice like this
313 | when it starts in an interactive mode:
314 | 
315 |     Gnomovision version 69, Copyright (C) year name of author
316 |     Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
317 |     This is free software, and you are welcome to redistribute it
318 |     under certain conditions; type `show c' for details.
319 | 
320 | The hypothetical commands `show w' and `show c' should show the appropriate
321 | parts of the General Public License. Of course, the commands you use may
322 | be called something other than `show w' and `show c'; they could even be
323 | mouse-clicks or menu items--whatever suits your program.
324 | 
325 | You should also get your employer (if you work as a programmer) or your
326 | school, if any, to sign a "copyright disclaimer" for the program, if
327 | necessary. Here is a sample; alter the names:
328 | 
329 |   Yoyodyne, Inc., hereby disclaims all copyright interest in the program
330 |   `Gnomovision' (which makes passes at compilers) written by James Hacker.
331 | 
332 |   {signature of Ty Coon}, 1 April 1989
333 |   Ty Coon, President of Vice
334 | 
335 | This General Public License does not permit incorporating your program into
336 | proprietary programs. If your program is a subroutine library, you may
337 | consider it more useful to permit linking proprietary applications with the
338 | library. If this is what you want to do, use the GNU Lesser General
339 | Public License instead of this License.
340 | 
341 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Implementation of the Flexible Conditional Density Estimator (FlexCode) in Python. See Izbicki, R. and Lee, A.B., [Converting High-Dimensional Regression to High-Dimensional Conditional Density Estimation](https://projecteuclid.org/euclid.ejs/1499133755), Electronic Journal of Statistics, 2017, for details. Port of the original [R package](https://github.com/rizbicki/FlexCoDE).
2 | 
3 | 
4 | # FlexCode
5 | 
6 | FlexCode is a general-purpose method for converting any conditional mean point estimator of $z$ into a conditional density estimator $f(z \vert x)$, where $x$ represents the covariates. The key idea is to expand the unknown function $f(z \vert x)$ in an orthonormal basis $\{\phi_i(z)\}_{i}$:
7 | 
8 | $$f(z|x)=\sum_{i}\beta_{i}(x)\phi_i(z)$$
9 | 
10 | Because the basis is orthonormal ($\int \phi_i(z)\phi_j(z)dz=\delta_{ij}$), multiplying the expansion by $\phi_i(z)$ and integrating over $z$ shows that the expansion coefficients are just conditional means
11 | 
12 | $$\beta_{i}(x) = \mathbb{E}\left[\phi_i(z)|x\right] \equiv \int f(z|x) \phi_i(z) dz$$
13 | 
14 | where each coefficient can be estimated from data by an appropriate regression method.
15 | 
16 | 
17 | # Installation
18 | 
19 | ```shell
20 | git clone https://github.com/lee-group-cmu/FlexCode.git
21 | pip install FlexCode[all]
22 | ```
23 | 
24 | FlexCode supports a number of regression models; if you wish to avoid installing all of the dependencies, you can specify just the regression methods you need via the optional requirements ("extras") in brackets. Targets include
Targets include 25 | 26 | - xgboost 27 | - scikit-learn (for nearest neighbor regression, random forests) 28 | 29 | 30 | # A simple example 31 | 32 | ```python 33 | import numpy as np 34 | import scipy.stats 35 | import flexcode 36 | from flexcode.regression_models import NN 37 | import matplotlib.pyplot as plt 38 | 39 | # Generate data p(z | x) = N(x, 1) 40 | def generate_data(n_draws): 41 | x = np.random.normal(0, 1, n_draws) 42 | z = np.random.normal(x, 1, n_draws) 43 | return x.reshape((len(x), 1)), z.reshape((len(z), 1)) 44 | 45 | x_train, z_train = generate_data(10000) 46 | x_validation, z_validation = generate_data(10000) 47 | x_test, z_test = generate_data(10000) 48 | 49 | # Parameterize model 50 | model = flexcode.FlexCodeModel(NN, max_basis=31, basis_system="cosine", 51 | regression_params={"k":20}) 52 | 53 | # Fit and tune model 54 | model.fit(x_train, z_train) 55 | model.tune(x_validation, z_validation) 56 | 57 | # Estimate CDE loss 58 | print(model.estimate_error(x_test, z_test)) 59 | 60 | # Calculate conditional density estimates 61 | cdes, z_grid = model.predict(x_test, n_grid=200) 62 | 63 | for ii in range(10): 64 | true_density = scipy.stats.norm.pdf(z_grid, x_test[ii], 1) 65 | plt.plot(z_grid, cdes[ii, :]) 66 | plt.plot(z_grid, true_density, color = "green") 67 | plt.axvline(x=z_test[ii], color="red") 68 | plt.show() 69 | 70 | ``` 71 | 72 | 73 | # FlexZBoost Buzzard Data 74 | 75 | One particular realization of the FlexCode algorithm is FlexZBoost which uses XGBoost as the regression method. We apply this method to photo-z estimation in the LSST DESC DC-1. For members of the LSST DESC, you can find information on obtaining the data [here](https://confluence.slac.stanford.edu/pages/viewpage.action?spaceKey=LSSTDESC&title=DC1+resources). 
76 | 
77 | ```python
78 | import numpy as np
79 | import pandas as pd
80 | import flexcode
81 | from flexcode.regression_models import XGBoost
82 | 
83 | # Read in data
84 | def process_data(feature_file, has_z=False):
85 |     """Processes buzzard data"""
86 |     df = pd.read_table(feature_file, sep=" ")
87 | 
88 |     # Construct colors and propagated color errors
89 |     df = df.assign(ug = df.u - df.g,
90 |                    gr = df.g - df.r,
91 |                    ri = df.r - df.i,
92 |                    iz = df.i - df.z,
93 |                    zy = df.z - df.y,
94 |                    ug_err = np.sqrt(df['u.err'] ** 2 + df['g.err'] ** 2),
95 |                    gr_err = np.sqrt(df['g.err'] ** 2 + df['r.err'] ** 2),
96 |                    ri_err = np.sqrt(df['r.err'] ** 2 + df['i.err'] ** 2),
97 |                    iz_err = np.sqrt(df['i.err'] ** 2 + df['z.err'] ** 2),
98 |                    zy_err = np.sqrt(df['z.err'] ** 2 + df['y.err'] ** 2))
99 | 
100 |     if has_z:
101 |         z = df.redshift.to_numpy()
102 |         df.drop('redshift', axis=1, inplace=True)
103 |     else:
104 |         z = None
105 | 
106 |     return df.to_numpy(), z
107 | 
108 | x_data, z_data = process_data('buzzard_spec_witherrors_mass.txt', has_z=True)
109 | x_test, _ = process_data('buzzard_phot_witherrors_mass.txt', has_z=False)
110 | 
111 | n_obs = x_data.shape[0]
112 | n_train = round(n_obs * 0.8)
113 | n_validation = n_obs - n_train
114 | 
115 | perm = np.random.permutation(n_obs)
116 | x_train = x_data[perm[:n_train], :]
117 | z_train = z_data[perm[:n_train]]
118 | x_validation = x_data[perm[n_train:], :]
119 | z_validation = z_data[perm[n_train:]]
120 | 
121 | # Fit the model
122 | model = flexcode.FlexCodeModel(XGBoost, max_basis=40, basis_system='cosine',
123 |                                regression_params={"max_depth": 8})
124 | model.fit(x_train, z_train)
125 | model.tune(x_validation, z_validation)
126 | 
127 | # Make predictions
128 | cdes, z_grid = model.predict(x_test, n_grid=200)
129 | 
130 | ```
131 | 
--------------------------------------------------------------------------------
/archive/README.md:
--------------------------------------------------------------------------------
1 | This directory contains **archived** configuration files.
2 | 
3 | Flexcode has been migrated to use the [LINCC Frameworks Python Project Template](https://github.com/lincc-frameworks/python-project-template), which uses a pyproject.toml file for build definitions and requirements.
4 | 
5 | The pyproject.toml file replaces the need for setup.py, setup.cfg, and tox.ini,
6 | but the files will be retained for a period of time as a convenience.
7 | 
8 | **These files will be removed at a later date!**
9 | Please do not use them.
10 | Update pyproject.toml in the base directory if changes to the build are required.
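For example, a new runtime dependency would go in the `dependencies` list of the `[project]` table rather than in these archived files. A minimal sketch, mirroring the real pyproject.toml that appears in full further below (the commented entry is a hypothetical placeholder, not an actual requirement):

```toml
[project]
dependencies = [
    "deprecated",
    "ipykernel",
    "pywavelets",
    "scikit-learn>=0.18",
    "xgboost",
    # "somepackage>=1.0",  # hypothetical: add new runtime dependencies here
]
```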
11 | -------------------------------------------------------------------------------- /archive/x_setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | 4 | [aliases] 5 | test=pytest -------------------------------------------------------------------------------- /archive/x_setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name = "flexcode", 4 | version = "0.2", 5 | license="GPL", 6 | description="Fits Flexible Conditional Density Estimator (FlexCode)", 7 | author="Taylor Pospisil, Nic Dalmasso", 8 | maintainer="Taylor Pospisil", 9 | author_email="tpospisi@andrew.cmu.edu", 10 | url="http://github.com/tpospisi/Flexcode", 11 | package_dir={"":"src"}, 12 | packages=["flexcode"], 13 | install_requires=["numpy", "pywavelets"], 14 | setup_requires=["pytest-runner"], 15 | tests_require=["pytest", "scikit-learn", "xgboost"], 16 | zip_safe=True, 17 | extras_require={ 18 | "xgboost" : ["xgboost"], 19 | "scikit-learn" : ["scikit-learn>=0.18"], 20 | "all" : ["scikit-learn", "xgboost"], 21 | }, 22 | ) 23 | -------------------------------------------------------------------------------- /archive/x_tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | minversion = 2.0 3 | envlist = py{27,36}--{linux} 4 | 5 | [testenv] 6 | changedir = tests 7 | deps = 8 | numpy 9 | pywavelets 10 | pytest 11 | scikit-learn 12 | scipy 13 | xgboost==0.82 14 | commands = pytest --basetemp={envtmpdir} {posargs} 15 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= -T -E -d _build/doctrees -D language=en 7 | EXCLUDENB ?= -D exclude_patterns="notebooks/*" 8 | SPHINXBUILD ?= sphinx-build 9 | SOURCEDIR = . 10 | BUILDDIR = ../_readthedocs/ 11 | 12 | .PHONY: help no-nb no-notebooks clean Makefile 13 | 14 | # Put it first so that "make" without argument is like "make help". 15 | help: 16 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 17 | 18 | # Build all Sphinx docs locally, except the notebooks 19 | no-nb no-notebooks: 20 | @$(SPHINXBUILD) -M html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(EXCLUDENB) $(O) 21 | 22 | # Cleans up files generated by the build process 23 | clean: 24 | rm -r "_build/doctrees" 25 | rm -r "$(BUILDDIR)" 26 | 27 | # Catch-all target: route all unknown targets to Sphinx using the new 28 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 29 | %: Makefile 30 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 31 | 32 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 | 
6 | 
7 | import os
8 | import sys
9 | 
10 | import autoapi
11 | from importlib.metadata import version
12 | 
13 | # Define path to the code to be documented **relative to where conf.py (this file) is kept**
14 | sys.path.insert(0, os.path.abspath('../src/'))
15 | 
16 | # -- Project information -----------------------------------------------------
17 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
18 | 
19 | project = "flexcode"
20 | copyright = "2023, Ann Lee"
21 | author = "Ann Lee"
22 | release = version("flexcode")
23 | # for example take major/minor
24 | version = ".".join(release.split(".")[:2])
25 | 
26 | # -- General configuration ---------------------------------------------------
27 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
28 | 
29 | extensions = ["sphinx.ext.mathjax", "sphinx.ext.napoleon", "sphinx.ext.viewcode"]
30 | 
31 | extensions.append("autoapi.extension")
32 | extensions.append("nbsphinx")
33 | 
34 | templates_path = []
35 | exclude_patterns = ['_build', '**.ipynb_checkpoints']
36 | 
37 | master_doc = "index"  # This assumes that sphinx-build is called from the root directory
38 | html_show_sourcelink = False  # Remove 'view source code' from top of page (for html, not python)
39 | add_module_names = False  # Remove namespaces from class/method signatures
40 | 
41 | autoapi_type = "python"
42 | autoapi_dirs = ["../src"]
43 | autoapi_add_toc_tree_entry = False
44 | autoapi_member_order = "bysource"
45 | 
46 | html_theme = "sphinx_rtd_theme"
47 | 
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. flexcode documentation main file.
2 |    You can adapt this file completely to your liking, but it should at least
3 |    contain the root `toctree` directive.
4 | 
5 | Welcome to flexcode's documentation!
6 | ========================================================================================
7 | 
8 | .. toctree::
9 |    :hidden:
10 | 
11 |    Home page <self>
12 |    API Reference <autoapi/index>
13 |    Notebooks <notebooks>
14 | 
--------------------------------------------------------------------------------
/docs/notebooks.rst:
--------------------------------------------------------------------------------
1 | Notebooks
2 | ========================================================================================
3 | 
4 | .. toctree::
5 | 
6 |     Introducing Jupyter Notebooks <notebooks/intro_notebook>
--------------------------------------------------------------------------------
/docs/notebooks/README.md:
--------------------------------------------------------------------------------
1 | Put your Jupyter notebooks here :)
--------------------------------------------------------------------------------
/docs/notebooks/intro_notebook.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "id": "accepting-editor",
6 |    "metadata": {
7 |     "cell_marker": "\"\"\""
8 |    },
9 |    "source": [
10 |     "# Introducing Jupyter Notebooks\n",
11 |     "\n",
12 |     "_(The example used here is JamesALeedham's notebook: [intro.ipynb](https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion/blob/master/docs/notebooks/intro.ipynb))_\n",
13 |     "\n",
14 |     "First, set up the environment:"
15 |    ]
16 |   },
17 |   {
18 |    "cell_type": "code",
19 |    "execution_count": null,
20 |    "id": "actual-thirty",
21 |    "metadata": {},
22 |    "outputs": [],
23 |    "source": [
24 |     "import matplotlib\n",
25 |     "import matplotlib.pyplot as pl\n",
26 |     "import numpy as np\n",
27 |     "\n",
28 |     "try:\n",
29 |     "    from IPython import get_ipython\n",
30 |     "    get_ipython().run_line_magic('matplotlib', 'inline')\n",
31 |     "except AttributeError:\n",
32 |     "    print('Magic function can only be used in IPython environment')\n",
33 |     "    matplotlib.use('Agg')\n",
34 |     "\n",
35 |     "pl.rcParams[\"figure.figsize\"] = [15, 8]"
36 |    ]
37 |   },
38 |   {
39 |    "cell_type": "markdown",
40 |    "id": "coral-upper",
41 |    "metadata": {
42 |     "cell_marker": "\"\"\"",
43 |     "lines_to_next_cell": 1
44 |    },
45 |    "source": [
46 |     "Then, define a function that creates a pretty graph:"
47 |    ]
48 |   },
49 |   {
50 |    "cell_type": "code",
51 |    "execution_count": null,
52 |    "id": "funded-protection",
53 |    "metadata": {
54 |     "lines_to_next_cell": 1
55 |    },
56 |    "outputs": [],
57 |    "source": [
58 |     "def SineAndCosineWaves():\n",
59 |     "    # Get a large number of X values for a nice smooth curve.
Using Pi as np.sin requires radians...\n", 60 | " x = np.linspace(0, 2 * np.pi, 180)\n", 61 | " # Convert radians to degrees to make for a meaningful X axis (1 radian = 57.29* degrees)\n", 62 | " xdeg = 57.29577951308232 * np.array(x)\n", 63 | " # Calculate the sine of each value of X\n", 64 | " y = np.sin(x)\n", 65 | " # Calculate the cosine of each value of X\n", 66 | " z = np.cos(x)\n", 67 | " # Plot the sine wave in blue, using degrees rather than radians on the X axis\n", 68 | " pl.plot(xdeg, y, color='blue', label='Sine wave')\n", 69 | " # Plot the cos wave in green, using degrees rather than radians on the X axis\n", 70 | " pl.plot(xdeg, z, color='green', label='Cosine wave')\n", 71 | " pl.xlabel(\"Degrees\")\n", 72 | " # More sensible X axis values\n", 73 | " pl.xticks(np.arange(0, 361, 45))\n", 74 | " pl.legend()\n", 75 | " pl.show()" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "id": "thorough-cutting", 81 | "metadata": { 82 | "cell_marker": "\"\"\"" 83 | }, 84 | "source": [ 85 | "Finally, call that function to display the graph:" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "imported-uruguay", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "SineAndCosineWaves()" 96 | ] 97 | } 98 | ], 99 | "metadata": { 100 | "jupytext": { 101 | "cell_markers": "\"\"\"" 102 | }, 103 | "kernelspec": { 104 | "display_name": "Python 3", 105 | "language": "python", 106 | "name": "python3" 107 | } 108 | }, 109 | "nbformat": 4, 110 | "nbformat_minor": 5 111 | } 112 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==6.1.3 2 | sphinx_rtd_theme==1.2.0 3 | sphinx-autoapi==2.0.1 4 | nbsphinx 5 | ipython 6 | jupytext 7 | jupyter 8 | matplotlib 9 | numpy -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "flexcode" 3 | license = {file = "LICENSE"} 4 | readme = "README.md" 5 | authors = [ 6 | { name = "Ann Lee", email = "annlee@andrew.cmu.edu" } 7 | ] 8 | classifiers = [ 9 | "Development Status :: 4 - Beta", 10 | "Intended Audience :: Developers", 11 | "Intended Audience :: Science/Research", 12 | "Operating System :: OS Independent", 13 | "Programming Language :: Python", 14 | ] 15 | dynamic = ["version"] 16 | dependencies = [ 17 | "deprecated", 18 | "ipykernel", # Support for Jupyter notebooks 19 | "pywavelets", 20 | "scikit-learn>=0.18", 21 | "xgboost" 22 | ] 23 | 24 | # On a mac, install optional dependencies with `pip install '.[dev]'` (include the single quotes) 25 | [project.optional-dependencies] 26 | dev = [ 27 | "pytest", 28 | "pytest-cov", # Used to report total code coverage 29 | "pre-commit", # Used to run checks before finalizing a git commit 30 | "sphinx==6.1.3", # Used to automatically generate documentation 31 | "sphinx_rtd_theme==1.2.0", # Used to render documentation 32 | "sphinx-autoapi==2.0.1", # Used to automatically generate api documentation 33 | "black", # Used for static linting of files 34 | "nbconvert", # Needed for pre-commit check to clear output from Python notebooks 35 | "nbsphinx", # Used to itegrate Python notebooks into Sphinx documentation 36 | "ipython", # Also used in building notebooks into Sphinx 37 | "matplotlib", # Used in sample notebook intro_notebook.ipynb 38 | "numpy", # Used in sample notebook 
intro_notebook.ipynb 39 | ] 40 | 41 | 42 | [build-system] 43 | requires = [ 44 | "setuptools>=45", # Used to build and package the Python project 45 | "setuptools_scm>=6.2", # Gets release version from git. Makes it available programmatically 46 | ] 47 | build-backend = "setuptools.build_meta" 48 | 49 | [tool.setuptools_scm] 50 | write_to = "src/flexcode/_version.py" 51 | 52 | [tool.black] 53 | line-length = 110 54 | -------------------------------------------------------------------------------- /src/.pylintrc: -------------------------------------------------------------------------------- 1 | [MAIN] 2 | 3 | # Analyse import fallback blocks. This can be used to support both Python 2 and 4 | # 3 compatible code, which means that the block might have code that exists 5 | # only in one or another interpreter, leading to false positives when analysed. 6 | analyse-fallback-blocks=no 7 | 8 | # Clear in-memory caches upon conclusion of linting. Useful if running pylint 9 | # in a server-like mode. 10 | clear-cache-post-run=no 11 | 12 | # Load and enable all available extensions. Use --list-extensions to see a list 13 | # all available extensions. 14 | #enable-all-extensions= 15 | 16 | # In error mode, messages with a category besides ERROR or FATAL are 17 | # suppressed, and no reports are done by default. Error mode is compatible with 18 | # disabling specific errors. 19 | #errors-only= 20 | 21 | # Always return a 0 (non-error) status code, even if lint errors are found. 22 | # This is primarily useful in continuous integration scripts. 23 | #exit-zero= 24 | 25 | # A comma-separated list of package or module names from where C extensions may 26 | # be loaded. Extensions are loading into the active Python interpreter and may 27 | # run arbitrary code. 28 | extension-pkg-allow-list= 29 | 30 | # A comma-separated list of package or module names from where C extensions may 31 | # be loaded. Extensions are loading into the active Python interpreter and may 32 | # run arbitrary code. (This is an alternative name to extension-pkg-allow-list 33 | # for backward compatibility.) 34 | extension-pkg-whitelist= 35 | 36 | # Return non-zero exit code if any of these messages/categories are detected, 37 | # even if score is above --fail-under value. Syntax same as enable. Messages 38 | # specified are enabled, while categories only check already-enabled messages. 39 | fail-on= 40 | 41 | # Specify a score threshold under which the program will exit with error. 42 | fail-under=10 43 | 44 | # Interpret the stdin as a python script, whose filename needs to be passed as 45 | # the module_or_package argument. 46 | #from-stdin= 47 | 48 | # Files or directories to be skipped. They should be base names, not paths. 49 | ignore=CVS 50 | 51 | # Add files or directories matching the regular expressions patterns to the 52 | # ignore-list. The regex matches against paths and can be in Posix or Windows 53 | # format. Because '\\' represents the directory delimiter on Windows systems, 54 | # it can't be used as an escape character. 55 | ignore-paths= 56 | 57 | # Files or directories matching the regular expression patterns are skipped. 58 | # The regex matches against base names, not paths. 59 | ignore-patterns=_version.py 60 | 61 | # List of module names for which member attributes should not be checked 62 | # (useful for modules/projects where namespaces are manipulated during runtime 63 | # and thus existing member attributes cannot be deduced by static analysis). 
It 64 | # supports qualified module names, as well as Unix pattern matching. 65 | ignored-modules= 66 | 67 | # Python code to execute, usually for sys.path manipulation such as 68 | # pygtk.require(). 69 | #init-hook= 70 | 71 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the 72 | # number of processors available to use, and will cap the count on Windows to 73 | # avoid hangs. 74 | jobs=1 75 | 76 | # Control the amount of potential inferred values when inferring a single 77 | # object. This can help the performance when dealing with large functions or 78 | # complex, nested conditions. 79 | limit-inference-results=100 80 | 81 | # List of plugins (as comma separated values of python module names) to load, 82 | # usually to register additional checkers. 83 | load-plugins= 84 | 85 | # Pickle collected data for later comparisons. 86 | persistent=yes 87 | 88 | # Minimum Python version to use for version dependent checks. Will default to 89 | # the version used to run pylint. 90 | py-version=3.9 91 | 92 | # Discover python modules and packages in the file system subtree. 93 | recursive=no 94 | 95 | # When enabled, pylint would attempt to guess common misconfiguration and emit 96 | # user-friendly hints instead of false-positive error messages. 97 | suggestion-mode=yes 98 | 99 | # Allow loading of arbitrary C extensions. Extensions are imported into the 100 | # active Python interpreter and may run arbitrary code. 101 | unsafe-load-any-extension=no 102 | 103 | # In verbose mode, extra non-checker-related info will be displayed. 104 | #verbose= 105 | 106 | 107 | [BASIC] 108 | 109 | # Naming style matching correct argument names. 110 | argument-naming-style=snake_case 111 | 112 | # Regular expression matching correct argument names. Overrides argument- 113 | # naming-style. If left empty, argument names will be checked with the set 114 | # naming style. 115 | #argument-rgx= 116 | 117 | # Naming style matching correct attribute names. 118 | attr-naming-style=snake_case 119 | 120 | # Regular expression matching correct attribute names. Overrides attr-naming- 121 | # style. If left empty, attribute names will be checked with the set naming 122 | # style. 123 | #attr-rgx= 124 | 125 | # Bad variable names which should always be refused, separated by a comma. 126 | bad-names=foo, 127 | bar, 128 | baz, 129 | toto, 130 | tutu, 131 | tata 132 | 133 | # Bad variable names regexes, separated by a comma. If names match any regex, 134 | # they will always be refused 135 | bad-names-rgxs= 136 | 137 | # Naming style matching correct class attribute names. 138 | class-attribute-naming-style=any 139 | 140 | # Regular expression matching correct class attribute names. Overrides class- 141 | # attribute-naming-style. If left empty, class attribute names will be checked 142 | # with the set naming style. 143 | #class-attribute-rgx= 144 | 145 | # Naming style matching correct class constant names. 146 | class-const-naming-style=UPPER_CASE 147 | 148 | # Regular expression matching correct class constant names. Overrides class- 149 | # const-naming-style. If left empty, class constant names will be checked with 150 | # the set naming style. 151 | #class-const-rgx= 152 | 153 | # Naming style matching correct class names. 154 | class-naming-style=PascalCase 155 | 156 | # Regular expression matching correct class names. Overrides class-naming- 157 | # style. If left empty, class names will be checked with the set naming style. 
158 | #class-rgx= 159 | 160 | # Naming style matching correct constant names. 161 | const-naming-style=UPPER_CASE 162 | 163 | # Regular expression matching correct constant names. Overrides const-naming- 164 | # style. If left empty, constant names will be checked with the set naming 165 | # style. 166 | #const-rgx= 167 | 168 | # Minimum line length for functions/classes that require docstrings, shorter 169 | # ones are exempt. 170 | docstring-min-length=-1 171 | 172 | # Naming style matching correct function names. 173 | function-naming-style=snake_case 174 | 175 | # Regular expression matching correct function names. Overrides function- 176 | # naming-style. If left empty, function names will be checked with the set 177 | # naming style. 178 | #function-rgx= 179 | 180 | # Good variable names which should always be accepted, separated by a comma. 181 | good-names=i, 182 | j, 183 | k, 184 | ex, 185 | Run, 186 | _ 187 | 188 | # Good variable names regexes, separated by a comma. If names match any regex, 189 | # they will always be accepted 190 | good-names-rgxs= 191 | 192 | # Include a hint for the correct naming format with invalid-name. 193 | include-naming-hint=no 194 | 195 | # Naming style matching correct inline iteration names. 196 | inlinevar-naming-style=any 197 | 198 | # Regular expression matching correct inline iteration names. Overrides 199 | # inlinevar-naming-style. If left empty, inline iteration names will be checked 200 | # with the set naming style. 201 | #inlinevar-rgx= 202 | 203 | # Naming style matching correct method names. 204 | method-naming-style=snake_case 205 | 206 | # Regular expression matching correct method names. Overrides method-naming- 207 | # style. If left empty, method names will be checked with the set naming style. 208 | #method-rgx= 209 | 210 | # Naming style matching correct module names. 211 | module-naming-style=snake_case 212 | 213 | # Regular expression matching correct module names. Overrides module-naming- 214 | # style. If left empty, module names will be checked with the set naming style. 215 | #module-rgx= 216 | 217 | # Colon-delimited sets of names that determine each other's naming style when 218 | # the name regexes allow several styles. 219 | name-group= 220 | 221 | # Regular expression which should only match function or class names that do 222 | # not require a docstring. 223 | no-docstring-rgx=^_ 224 | 225 | # List of decorators that produce properties, such as abc.abstractproperty. Add 226 | # to this list to register other decorators that produce valid properties. 227 | # These decorators are taken in consideration only for invalid-name. 228 | property-classes=abc.abstractproperty 229 | 230 | # Regular expression matching correct type variable names. If left empty, type 231 | # variable names will be checked with the set naming style. 232 | #typevar-rgx= 233 | 234 | # Naming style matching correct variable names. 235 | variable-naming-style=snake_case 236 | 237 | # Regular expression matching correct variable names. Overrides variable- 238 | # naming-style. If left empty, variable names will be checked with the set 239 | # naming style. 240 | #variable-rgx= 241 | 242 | 243 | [CLASSES] 244 | 245 | # Warn about protected attribute access inside special methods 246 | check-protected-access-in-special-methods=no 247 | 248 | # List of method names used to declare (i.e. assign) instance attributes. 
249 | defining-attr-methods=__init__, 250 | __new__, 251 | setUp, 252 | __post_init__ 253 | 254 | # List of member names, which should be excluded from the protected access 255 | # warning. 256 | exclude-protected=_asdict, 257 | _fields, 258 | _replace, 259 | _source, 260 | _make 261 | 262 | # List of valid names for the first argument in a class method. 263 | valid-classmethod-first-arg=cls 264 | 265 | # List of valid names for the first argument in a metaclass class method. 266 | valid-metaclass-classmethod-first-arg=mcs 267 | 268 | 269 | [DESIGN] 270 | 271 | # List of regular expressions of class ancestor names to ignore when counting 272 | # public methods (see R0903) 273 | exclude-too-few-public-methods= 274 | 275 | # List of qualified class names to ignore when counting class parents (see 276 | # R0901) 277 | ignored-parents= 278 | 279 | # Maximum number of arguments for function / method. 280 | max-args=5 281 | 282 | # Maximum number of attributes for a class (see R0902). 283 | max-attributes=7 284 | 285 | # Maximum number of boolean expressions in an if statement (see R0916). 286 | max-bool-expr=5 287 | 288 | # Maximum number of branch for function / method body. 289 | max-branches=12 290 | 291 | # Maximum number of locals for function / method body. 292 | max-locals=15 293 | 294 | # Maximum number of parents for a class (see R0901). 295 | max-parents=7 296 | 297 | # Maximum number of public methods for a class (see R0904). 298 | max-public-methods=20 299 | 300 | # Maximum number of return / yield for function / method body. 301 | max-returns=6 302 | 303 | # Maximum number of statements in function / method body. 304 | max-statements=50 305 | 306 | # Minimum number of public methods for a class (see R0903). 307 | min-public-methods=2 308 | 309 | 310 | [EXCEPTIONS] 311 | 312 | # Exceptions that will emit a warning when caught. 313 | overgeneral-exceptions=builtins.BaseException,builtins.Exception 314 | 315 | 316 | [FORMAT] 317 | 318 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 319 | expected-line-ending-format= 320 | 321 | # Regexp for a line that is allowed to be longer than the limit. 322 | ignore-long-lines=^\s*(# )??$ 323 | 324 | # Number of spaces of indent required inside a hanging or continued line. 325 | indent-after-paren=4 326 | 327 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 328 | # tab). 329 | indent-string=' ' 330 | 331 | # Maximum number of characters on a single line. 332 | max-line-length=100 333 | 334 | # Maximum number of lines in a module. 335 | max-module-lines=1000 336 | 337 | # Allow the body of a class to be on the same line as the declaration if body 338 | # contains single statement. 339 | single-line-class-stmt=no 340 | 341 | # Allow the body of an if to be on the same line as the test if there is no 342 | # else. 343 | single-line-if-stmt=no 344 | 345 | 346 | [IMPORTS] 347 | 348 | # List of modules that can be imported at any level, not just the top level 349 | # one. 350 | allow-any-import-level= 351 | 352 | # Allow explicit reexports by alias from a package __init__. 353 | allow-reexport-from-package=no 354 | 355 | # Allow wildcard imports from modules that define __all__. 356 | allow-wildcard-with-all=no 357 | 358 | # Deprecated modules which should not be used, separated by a comma. 359 | deprecated-modules= 360 | 361 | # Output a graph (.gv or any supported image format) of external dependencies 362 | # to the given file (report RP0402 must not be disabled). 
363 | ext-import-graph= 364 | 365 | # Output a graph (.gv or any supported image format) of all (i.e. internal and 366 | # external) dependencies to the given file (report RP0402 must not be 367 | # disabled). 368 | import-graph= 369 | 370 | # Output a graph (.gv or any supported image format) of internal dependencies 371 | # to the given file (report RP0402 must not be disabled). 372 | int-import-graph= 373 | 374 | # Force import order to recognize a module as part of the standard 375 | # compatibility libraries. 376 | known-standard-library= 377 | 378 | # Force import order to recognize a module as part of a third party library. 379 | known-third-party=enchant 380 | 381 | # Couples of modules and preferred modules, separated by a comma. 382 | preferred-modules= 383 | 384 | 385 | [LOGGING] 386 | 387 | # The type of string formatting that logging methods do. `old` means using % 388 | # formatting, `new` is for `{}` formatting. 389 | logging-format-style=old 390 | 391 | # Logging modules to check that the string format arguments are in logging 392 | # function parameter format. 393 | logging-modules=logging 394 | 395 | 396 | [MESSAGES CONTROL] 397 | 398 | # Only show warnings with the listed confidence levels. Leave empty to show 399 | # all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, 400 | # UNDEFINED. 401 | confidence=HIGH, 402 | CONTROL_FLOW, 403 | INFERENCE, 404 | INFERENCE_FAILURE, 405 | UNDEFINED 406 | 407 | # Disable the message, report, category or checker with the given id(s). You 408 | # can either give multiple identifiers separated by comma (,) or put this 409 | # option multiple times (only on the command line, not in the configuration 410 | # file where it should appear only once). You can also use "--disable=all" to 411 | # disable everything first and then re-enable specific checks. For example, if 412 | # you want to run only the similarities checker, you can use "--disable=all 413 | # --enable=similarities". If you want to run only the classes checker, but have 414 | # no Warning level messages displayed, use "--disable=all --enable=classes 415 | # --disable=W". 416 | disable=raw-checker-failed, 417 | bad-inline-option, 418 | locally-disabled, 419 | file-ignored, 420 | suppressed-message, 421 | useless-suppression, 422 | deprecated-pragma, 423 | use-symbolic-message-instead, 424 | missing-module-docstring, 425 | unnecessary-pass, 426 | 427 | 428 | # Enable the message, report, category or checker with the given id(s). You can 429 | # either give multiple identifier separated by comma (,) or put this option 430 | # multiple time (only on the command line, not in the configuration file where 431 | # it should appear only once). See also the "--disable" option for examples. 432 | enable=c-extension-no-member 433 | 434 | 435 | [METHOD_ARGS] 436 | 437 | # List of qualified names (i.e., library.method) which require a timeout 438 | # parameter e.g. 'requests.api.get,requests.api.post' 439 | timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request 440 | 441 | 442 | [MISCELLANEOUS] 443 | 444 | # List of note tags to take in consideration, separated by a comma. 445 | notes=FIXME, 446 | XXX, 447 | TODO 448 | 449 | # Regular expression of note tags to take in consideration. 
450 | notes-rgx= 451 | 452 | 453 | [REFACTORING] 454 | 455 | # Maximum number of nested blocks for function / method body 456 | max-nested-blocks=5 457 | 458 | # Complete name of functions that never returns. When checking for 459 | # inconsistent-return-statements if a never returning function is called then 460 | # it will be considered as an explicit return statement and no message will be 461 | # printed. 462 | never-returning-functions=sys.exit,argparse.parse_error 463 | 464 | 465 | [REPORTS] 466 | 467 | # Python expression which should return a score less than or equal to 10. You 468 | # have access to the variables 'fatal', 'error', 'warning', 'refactor', 469 | # 'convention', and 'info' which contain the number of messages in each 470 | # category, as well as 'statement' which is the total number of statements 471 | # analyzed. This score is used by the global evaluation report (RP0004). 472 | evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) 473 | 474 | # Template used to display messages. This is a python new-style format string 475 | # used to format the message information. See doc for all details. 476 | msg-template= 477 | 478 | # Set the output format. Available formats are text, parseable, colorized, json 479 | # and msvs (visual studio). You can also give a reporter class, e.g. 480 | # mypackage.mymodule.MyReporterClass. 481 | #output-format= 482 | 483 | # Tells whether to display a full report or only the messages. 484 | reports=no 485 | 486 | # Activate the evaluation score. 487 | score=yes 488 | 489 | 490 | [SIMILARITIES] 491 | 492 | # Comments are removed from the similarity computation 493 | ignore-comments=yes 494 | 495 | # Docstrings are removed from the similarity computation 496 | ignore-docstrings=yes 497 | 498 | # Imports are removed from the similarity computation 499 | ignore-imports=yes 500 | 501 | # Signatures are removed from the similarity computation 502 | ignore-signatures=yes 503 | 504 | # Minimum lines number of a similarity. 505 | min-similarity-lines=4 506 | 507 | 508 | [SPELLING] 509 | 510 | # Limits count of emitted suggestions for spelling mistakes. 511 | max-spelling-suggestions=4 512 | 513 | # Spelling dictionary name. Available dictionaries: none. To make it work, 514 | # install the 'python-enchant' package. 515 | spelling-dict= 516 | 517 | # List of comma separated words that should be considered directives if they 518 | # appear at the beginning of a comment and should not be checked. 519 | spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: 520 | 521 | # List of comma separated words that should not be checked. 522 | spelling-ignore-words= 523 | 524 | # A path to a file that contains the private dictionary; one word per line. 525 | spelling-private-dict-file= 526 | 527 | # Tells whether to store unknown words to the private dictionary (see the 528 | # --spelling-private-dict-file option) instead of raising a message. 529 | spelling-store-unknown-words=no 530 | 531 | 532 | [STRING] 533 | 534 | # This flag controls whether inconsistent-quotes generates a warning when the 535 | # character used as a quote delimiter is used inconsistently within a module. 536 | check-quote-consistency=no 537 | 538 | # This flag controls whether the implicit-str-concat should generate a warning 539 | # on implicit string concatenation in sequences defined over several lines. 
540 | check-str-concat-over-line-jumps=no 541 | 542 | 543 | [TYPECHECK] 544 | 545 | # List of decorators that produce context managers, such as 546 | # contextlib.contextmanager. Add to this list to register other decorators that 547 | # produce valid context managers. 548 | contextmanager-decorators=contextlib.contextmanager 549 | 550 | # List of members which are set dynamically and missed by pylint inference 551 | # system, and so shouldn't trigger E1101 when accessed. Python regular 552 | # expressions are accepted. 553 | generated-members= 554 | 555 | # Tells whether to warn about missing members when the owner of the attribute 556 | # is inferred to be None. 557 | ignore-none=yes 558 | 559 | # This flag controls whether pylint should warn about no-member and similar 560 | # checks whenever an opaque object is returned when inferring. The inference 561 | # can return multiple potential results while evaluating a Python object, but 562 | # some branches might not be evaluated, which results in partial inference. In 563 | # that case, it might be useful to still emit no-member and other checks for 564 | # the rest of the inferred objects. 565 | ignore-on-opaque-inference=yes 566 | 567 | # List of symbolic message names to ignore for Mixin members. 568 | ignored-checks-for-mixins=no-member, 569 | not-async-context-manager, 570 | not-context-manager, 571 | attribute-defined-outside-init 572 | 573 | # List of class names for which member attributes should not be checked (useful 574 | # for classes with dynamically set attributes). This supports the use of 575 | # qualified names. 576 | ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace 577 | 578 | # Show a hint with possible names when a member name was not found. The aspect 579 | # of finding the hint is based on edit distance. 580 | missing-member-hint=yes 581 | 582 | # The minimum edit distance a name should have in order to be considered a 583 | # similar match for a missing member name. 584 | missing-member-hint-distance=1 585 | 586 | # The total number of similar names that should be taken in consideration when 587 | # showing a hint for a missing member. 588 | missing-member-max-choices=1 589 | 590 | # Regex pattern to define which classes are considered mixins. 591 | mixin-class-rgx=.*[Mm]ixin 592 | 593 | # List of decorators that change the signature of a decorated function. 594 | signature-mutators= 595 | 596 | 597 | [VARIABLES] 598 | 599 | # List of additional names supposed to be defined in builtins. Remember that 600 | # you should avoid defining new builtins when possible. 601 | additional-builtins= 602 | 603 | # Tells whether unused global variables should be treated as a violation. 604 | allow-global-unused-variables=yes 605 | 606 | # List of names allowed to shadow builtins 607 | allowed-redefined-builtins= 608 | 609 | # List of strings which can identify a callback function by name. A callback 610 | # name must start or end with one of those strings. 611 | callbacks=cb_, 612 | _cb 613 | 614 | # A regular expression matching the name of dummy variables (i.e. expected to 615 | # not be used). 616 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ 617 | 618 | # Argument names that match this expression will be ignored. 619 | ignored-argument-names=_.*|^ignored_|^unused_ 620 | 621 | # Tells whether we should check for unused import in __init__ files. 
622 | init-import=no
623 | 
624 | # List of qualified module names which can have objects that can redefine
625 | # builtins.
626 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
627 | 
--------------------------------------------------------------------------------
/src/flexcode/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import FlexCodeModel
2 | 
--------------------------------------------------------------------------------
/src/flexcode/basis_functions.py:
--------------------------------------------------------------------------------
1 | """Functions for evaluation of orthogonal basis functions."""
2 | 
3 | import numpy as np
4 | import pywt
5 | 
6 | from .helpers import box_transform, make_grid
7 | from .post_processing import *
8 | 
9 | 
10 | def evaluate_basis(responses, n_basis, basis_system):
11 |     """Evaluates a system of basis functions.
12 | 
13 |     Arguments
14 |     ---------
15 |     responses : array
16 |         An array of responses in [0, 1].
17 |     n_basis : integer
18 |         The number of basis functions to calculate.
19 |     basis_system : {'cosine', 'Fourier', 'db4'}
20 |         String denoting the system of orthogonal basis functions.
21 | 
22 |     Returns
23 |     -------
24 |     numpy matrix
25 |         A matrix of basis function evaluations. Each row corresponds
26 |         to a value of `responses`, each column corresponds to a basis function.
27 | 
28 |     Raises
29 |     ------
30 |     ValueError
31 |         If the basis system isn't recognized.
32 | 
33 |     """
34 |     systems = {"cosine": cosine_basis, "Fourier": fourier_basis, "db4": wavelet_basis}
35 |     try:
36 |         basis_fn = systems[basis_system]
37 |     except KeyError:
38 |         raise ValueError("Basis system {} not recognized".format(basis_system))
39 | 
40 |     n_dim = responses.shape[1]
41 |     if n_dim == 1:
42 |         return basis_fn(responses, n_basis)
43 |     else:
44 |         if len(n_basis) == 1:
45 |             n_basis = [n_basis[0]] * n_dim  # repeat the single count per dimension; `[n_basis] * n_dim` would nest the list
46 |         return tensor_basis(responses, n_basis, basis_fn)
47 | 
48 | 
49 | def tensor_basis(responses, n_basis, basis_fn):
50 |     r"""Evaluates tensor basis.
51 | 
52 |     Combines single-dimensional basis functions $\phi_{d}(z_{d})$ to form
53 |     the orthogonal tensor basis $\phi(z_{1}, \dots, z_{D}) = \prod_{d}
54 |     \phi_{d}(z_{d})$.
55 | 
56 |     Arguments
57 |     ---------
58 |     responses : numpy matrix
59 |         A matrix of responses in [0, 1]^(n_dim). Each column
60 |         corresponds to a variable, each row corresponds to an
61 |         observation.
62 |     n_basis : list of integers
63 |         The number of basis functions for each dimension. Should have
64 |         the same length as the number of columns of `responses`.
65 |     basis_fn : function
66 |         The function which evaluates the one-dimensional basis
67 |         functions.
68 | 
69 |     Returns
70 |     -------
71 |     numpy matrix
72 |         Returns a matrix where each column is a basis function and
73 |         each row is an observation.
74 | 
75 |     """
76 |     n_obs, n_dims = responses.shape
77 | 
78 |     basis = np.ones((n_obs, np.prod(n_basis)))
79 |     period = 1
80 |     for dim in range(n_dims):
81 |         sub_basis = basis_fn(responses[:, dim], n_basis[dim])
82 |         col = 0
83 |         for _ in range(np.prod(n_basis) // (n_basis[dim] * period)):
84 |             for sub_col in range(n_basis[dim]):
85 |                 for _ in range(period):
86 |                     basis[:, col] *= sub_basis[:, sub_col]
87 |                     col += 1
88 |         period *= n_basis[dim]
89 |     return basis
90 | 
91 | 
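# Illustrative sketch of the tensor construction above, under assumed toy
# shapes (kept as a comment; not part of the module):
#
#     z = np.random.rand(50, 2)                   # 50 responses in [0, 1]^2
#     tb = tensor_basis(z, [4, 3], cosine_basis)  # tb.shape == (50, 12)
#     # column i + 4 * j is cosine basis i in dim 0 times cosine basis j in dim 1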
92 | def cosine_basis(responses, n_basis):
93 |     """Evaluates cosine basis.
94 | 
95 |     Arguments
96 |     ---------
97 |     responses : array
98 |         An array of responses in [0, 1].
99 |     n_basis : integer
100 |         The number of basis functions to evaluate.
101 | 
102 |     Returns
103 |     -------
104 |     numpy matrix
105 |         A matrix of cosine basis functions evaluated at `responses`. Each row
106 |         corresponds to a value of `responses`, each column corresponds to a
107 |         basis function.
108 | 
109 |     """
110 |     n_obs = responses.shape[0]
111 |     basis = np.empty((n_obs, n_basis))
112 | 
113 |     responses = responses.flatten()
114 | 
115 |     basis[:, 0] = 1.0
116 |     for col in range(1, n_basis):
117 |         basis[:, col] = np.sqrt(2) * np.cos(np.pi * col * responses)
118 |     return basis
119 | 
120 | 
121 | def fourier_basis(responses, n_basis):
122 |     """Evaluates Fourier basis.
123 | 
124 |     Arguments
125 |     ---------
126 |     responses : array
127 |         An array of responses in [0, 1].
128 |     n_basis : integer
129 |         The number of basis functions to evaluate.
130 | 
131 |     Returns
132 |     -------
133 |     numpy matrix
134 |         A matrix of Fourier basis functions evaluated at `responses`. Each row
135 |         corresponds to a value of `responses`, each column corresponds to a
136 |         basis function.
137 | 
138 |     """
139 |     n_obs = responses.shape[0]
140 |     basis = np.zeros((n_obs, n_basis))
141 | 
142 |     responses = responses.flatten()
143 | 
144 |     basis[:, 0] = 1.0
145 |     for col in range(1, (n_basis + 1) // 2):
146 |         basis[:, 2 * col - 1] = np.sqrt(2) * np.sin(2 * np.pi * col * responses)
147 |         basis[:, 2 * col] = np.sqrt(2) * np.cos(2 * np.pi * col * responses)
148 |     if n_basis % 2 == 0:
149 |         basis[:, -1] = np.sqrt(2) * np.sin(np.pi * n_basis * responses)
150 |     return basis
151 | 
152 | 
153 | def wavelet_basis(responses, n_basis, family="db4"):
154 |     """Evaluates Daubechies basis.
155 | 
156 |     Arguments
157 |     ---------
158 |     responses : array
159 |         An array of responses in [0, 1].
160 |     n_basis : integer
161 |         The number of basis functions to evaluate.
162 |     family : string
163 |         The wavelet family to evaluate.
164 | 
165 |     Returns
166 |     -------
167 |     numpy matrix
168 |         A matrix of Daubechies wavelet basis functions evaluated at `responses`.
169 |         Each row corresponds to a value of `responses`, each column
170 |         corresponds to a basis function.
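    Examples
    --------
    A minimal sketch, assuming a length-200 response array (illustrative,
    not run as a doctest)::

        z = np.random.rand(200)
        basis = wavelet_basis(z, n_basis=16)  # basis.shape == (200, 16)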
171 | 172 | """ 173 | responses = responses.flatten() 174 | 175 | n_aux = 15 176 | rez = pywt.DiscreteContinuousWavelet(family).wavefun(n_aux) 177 | if len(rez) == 2: 178 | wavelet, x_grid = rez 179 | else: 180 | _, wavelet, x_grid = rez 181 | wavelet *= np.sqrt(max(x_grid) - min(x_grid)) 182 | x_grid = (x_grid - min(x_grid)) / (max(x_grid) - min(x_grid)) 183 | 184 | def _wave_fun(val): 185 | if val < 0 or val > 1: 186 | return 0.0 187 | return wavelet[np.argmin(abs(val - x_grid))] 188 | 189 | n_obs = responses.shape[0] 190 | basis = np.empty((n_obs, n_basis)) 191 | basis[:, 0] = 1.0 192 | 193 | loc = 0 194 | level = 0 195 | for col in range(1, n_basis): 196 | basis[:, col] = [2 ** (level / 2) * _wave_fun(a * 2**level - loc) for a in responses] 197 | loc += 1 198 | if loc == 2**level: 199 | loc = 0 200 | level += 1 201 | return basis 202 | 203 | 204 | class BasisCoefs(object): 205 | def __init__(self, coefs, basis_system, z_min, z_max, bump_threshold=None, sharpen_alpha=None): 206 | self.coefs = coefs 207 | self.basis_system = basis_system 208 | self.z_min = z_min 209 | self.z_max = z_max 210 | self.bump_threshold = bump_threshold 211 | self.sharpen_alpha = sharpen_alpha 212 | 213 | def evaluate(self, z_grid): 214 | basis = evaluate_basis( 215 | box_transform(z_grid, self.z_min, self.z_max), self.coefs.shape[1], self.basis_system 216 | ) 217 | cdes = np.matmul(self.coefs, basis.T) 218 | 219 | normalize(cdes) 220 | if self.bump_threshold is not None: 221 | remove_bumps(cdes, self.bump_threshold) 222 | if self.sharpen_alpha is not None: 223 | sharpen(cdes, self.sharpen_alpha) 224 | cdes /= self.z_max - self.z_min 225 | return cdes 226 | -------------------------------------------------------------------------------- /src/flexcode/core.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .basis_functions import BasisCoefs, evaluate_basis 4 | from .helpers import box_transform, make_grid 5 | from .loss_functions import cde_loss 6 | from .post_processing import * 7 | 8 | 9 | class FlexCodeModel(object): 10 | def __init__( 11 | self, 12 | model, 13 | max_basis, 14 | basis_system="cosine", 15 | z_min=None, 16 | z_max=None, 17 | regression_params={}, 18 | custom_model=None, 19 | ): 20 | """Initialize FlexCodeModel object 21 | 22 | :param model: A FlexCodeRegression object 23 | :param max_basis: int, the maximal number of basis functions 24 | :param basis_system: string, the basis system: options are "cosine" 25 | :param z_min: float, the minimum z value; if None will default 26 | to the minimum of the training values 27 | :param z_max: float, the maximum z value; if None will default 28 | to the maximum of the training values 29 | :param regression_params: A dictionary of tuning parameters 30 | for the regression model 31 | :param custom_model: a scikit-learn-type model, i.e. with fit and 32 | predict method. 33 | """ 34 | self.max_basis = max_basis 35 | self.best_basis = range(max_basis) 36 | self.basis_system = basis_system 37 | self.model = model(max_basis, regression_params, custom_model) 38 | 39 | self.z_min = z_min 40 | self.z_max = z_max 41 | 42 | self.bump_threshold = None 43 | self.sharpen_alpha = None 44 | 45 | def fit(self, x_train, z_train, weight=None): 46 | """Fits basis function regression models. 47 | 48 | :param x_train: a numpy matrix of training covariates. 49 | :param z_train: a numpy array of z values. 50 | :param weight: (optional) a numpy array of weights. 51 | :returns: None. 
:rtype:
52 | 
53 |         """
54 | 
55 |         if len(x_train.shape) == 1:
56 |             x_train = x_train.reshape(-1, 1)
57 |         if len(z_train.shape) == 1:
58 |             z_train = z_train.reshape(-1, 1)
59 | 
60 |         if self.z_min is None:
61 |             self.z_min = min(z_train)
62 |         if self.z_max is None:
63 |             self.z_max = max(z_train)
64 | 
65 |         z_basis = evaluate_basis(
66 |             box_transform(z_train, self.z_min, self.z_max), self.max_basis, self.basis_system
67 |         )
68 | 
69 |         self.model.fit(x_train, z_basis, weight)
70 | 
71 |     def tune(self, x_validation, z_validation, bump_threshold_grid=None, sharpen_grid=None, n_grid=1000):
72 |         """Set tuning parameters to minimize CDE loss
73 | 
74 |         Sets the best_basis, bump_threshold, and sharpen_alpha attributes
75 | 
76 |         :param x_validation: a numpy matrix of covariates
77 |         :param z_validation: a numpy array of z values
78 |         :param bump_threshold_grid: an array of candidate bump threshold values
79 |         :param sharpen_grid: an array of candidate sharpen parameter values
80 |         :param n_grid: integer, the number of grid points to evaluate
81 |         :returns: None
82 |         :rtype:
83 | 
84 |         """
85 |         if len(x_validation.shape) == 1:
86 |             x_validation = x_validation.reshape(-1, 1)
87 |         if len(z_validation.shape) == 1:
88 |             z_validation = z_validation.reshape(-1, 1)
89 | 
90 |         z_validation = box_transform(z_validation, self.z_min, self.z_max)
91 |         z_basis = evaluate_basis(z_validation, self.max_basis, self.basis_system)
92 | 
93 |         coefs = self.model.predict(x_validation)
94 | 
95 |         term1 = np.mean(coefs**2, 0)
96 |         term2 = np.mean(coefs * z_basis, 0)
97 |         # per-coefficient CDE loss contribution is term1 - 2 * term2; keep the basis functions where it is negative
98 |         self.best_basis = np.where(term1 - 2 * term2 < 0.0)[0]
99 | 
100 |         if bump_threshold_grid is not None or sharpen_grid is not None:
101 |             coefs = coefs[:, self.best_basis]
102 |             z_grid = make_grid(n_grid, self.z_min, self.z_max)
103 |             z_basis = evaluate_basis(
104 |                 box_transform(z_grid, self.z_min, self.z_max), max(self.best_basis) + 1, self.basis_system
105 |             )
106 |             z_basis = z_basis[:, self.best_basis]
107 |             cdes = np.matmul(coefs, z_basis.T)
108 |             normalize(cdes)
109 | 
110 |             if bump_threshold_grid is not None:
111 |                 self.bump_threshold = choose_bump_threshold(cdes, z_grid, z_validation, bump_threshold_grid)
112 | 
113 |                 remove_bumps(cdes, self.bump_threshold)
114 |                 normalize(cdes)
115 | 
116 |             if sharpen_grid is not None:
117 |                 self.sharpen_alpha = choose_sharpen(cdes, z_grid, z_validation, sharpen_grid)
118 | 
119 |     def predict_coefs(self, x_new):
120 |         if len(x_new.shape) == 1:
121 |             x_new = x_new.reshape(-1, 1)
122 | 
123 |         coefs = self.model.predict(x_new)[:, self.best_basis]
124 |         return BasisCoefs(
125 |             coefs, self.basis_system, self.z_min, self.z_max, self.bump_threshold, self.sharpen_alpha
126 |         )
127 | 
128 |     def predict(self, x_new, n_grid):
129 |         """Predict conditional density estimates on new data
130 | 
131 |         :param x_new: A numpy matrix of covariates at which to predict
132 |         :param n_grid: int, the number of grid points at which to
133 |             predict the conditional density
134 |         :returns: A numpy matrix where each row is a conditional
135 |             density estimate at the grid points
136 |         :rtype: numpy matrix
137 | 
138 |         """
139 |         if len(x_new.shape) == 1:
140 |             x_new = x_new.reshape(-1, 1)
141 | 
142 |         z_grid = make_grid(n_grid, 0.0, 1.0)
143 |         z_basis = evaluate_basis(z_grid, max(self.best_basis) + 1, self.basis_system)
144 |         z_basis = z_basis[:, self.best_basis]
145 |         coefs = self.model.predict(x_new)[:, self.best_basis]
146 |         cdes = np.matmul(coefs, z_basis.T)
147 | 
148 |         # Post-process
149 |         normalize(cdes)
150 |         if self.bump_threshold is not None:
151 |             remove_bumps(cdes, self.bump_threshold)
152 |         if self.sharpen_alpha is not None:
153 |             sharpen(cdes, self.sharpen_alpha)
154 |         cdes /= self.z_max - self.z_min
155 |         return cdes, make_grid(n_grid, self.z_min, self.z_max)
156 | 
157 |     def estimate_error(self, x_test, z_test, n_grid=1000):
158 |         """Estimates CDE loss on test data
159 | 
160 |         :param x_test: A numpy matrix of covariates
161 |         :param z_test: A numpy matrix of z values
162 |         :param n_grid: Number of grid points at which to predict the
163 |             conditional density
164 |         :returns: an estimate of the CDE loss
165 |         :rtype: float
166 | 
167 |         """
168 |         if len(x_test.shape) == 1:
169 |             x_test = x_test.reshape(-1, 1)
170 |         if len(z_test.shape) == 1:
171 |             z_test = z_test.reshape(-1, 1)
172 | 
173 |         cde_estimate, z_grid = self.predict(x_test, n_grid)
174 |         return cde_loss(cde_estimate, z_grid, z_test)
175 | 
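For orientation, a minimal end-to-end sketch of the FlexCodeModel workflow
defined above; the toy data, split sizes, and the choice k=20 are illustrative
assumptions, not part of the library:

    import numpy as np
    import flexcode
    from flexcode.regression_models import NN

    x = np.random.normal(size=(1000, 1))
    z = x[:, 0] + 0.5 * np.random.normal(size=1000)

    model = flexcode.FlexCodeModel(NN, max_basis=31, basis_system="cosine",
                                   regression_params={"k": 20})
    model.fit(x[:600], z[:600])          # fit the basis-coefficient regressions
    model.tune(x[600:800], z[600:800])   # select the basis functions
    cdes, z_grid = model.predict(x[800:], n_grid=200)
    loss = model.estimate_error(x[800:], z[800:])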
--------------------------------------------------------------------------------
/src/flexcode/helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def box_transform(z, z_min, z_max):
5 |     """Projects z from box [z_min, z_max] to [0, 1]
6 | 
7 |     :param z: an array of z values
8 |     :param z_min: float, the minimum value of the z box
9 |     :param z_max: float, the maximum value of the z box
10 |     :returns: z projected onto [0, 1]
11 | 
12 |     """
13 | 
14 |     return (z - z_min) / (z_max - z_min)
15 | 
16 | 
17 | def make_grid(n_grid, z_min, z_max):
18 |     """Create grid of equally spaced points
19 | 
20 |     :param n_grid: integer number of grid points
21 |     :param z_min: float, the minimum value of the z box
22 |     :param z_max: float, the maximum value of the z box
23 |     :returns: a grid of n_grid equally spaced points between z_min and z_max
24 | 
25 |     """
26 |     return np.linspace(z_min, z_max, n_grid).reshape((n_grid, 1))
27 | 
28 | 
29 | def params_dict_optim_decision(params, multi_output=False):
30 |     """
31 |     Ingests a parameter dictionary and determines whether to do CV optimization.
32 |     If any parameter value is a list or numpy array, the dictionary is
33 |     automatically formatted for GridSearchCV.
34 | 
35 |     :param params: dictionary of model parameters
36 |     :param multi_output: boolean flag, whether the optimization would need
37 |         to be performed on a MultiOutputRegressor
38 |     :returns: a dictionary of parameters and a boolean flag indicating whether
39 |         CV optimization will be performed. If CV optimization is to happen,
40 |         the parameter dictionary is formatted accordingly.
41 |     """
42 | 
43 |     # Determine whether any of the dictionary values is a list or array
44 |     opt_flag = False
45 |     for k, value in params.items():
46 |         if type(value) == tuple:
47 |             raise ValueError(
48 |                 "Parameter values need to be lists or np.array, not tuple."
49 | "Current issues with parameter %s" % (k) 50 | ) 51 | if type(value) == list or type(value) == np.ndarray: 52 | opt_flag = True 53 | break 54 | 55 | # Format the dictionary if necessary - put int, str and float into a list 56 | # with one element 57 | out_param_dict = {} if opt_flag else params.copy() 58 | if opt_flag: 59 | for k, value in params.items(): 60 | out_value = value.tolist() if type(value) == np.ndarray else value 61 | out_value = [out_value] if type(out_value) != list else out_value 62 | out_key = "estimator__" + k if multi_output else k 63 | out_param_dict[out_key] = out_value 64 | 65 | return out_param_dict, opt_flag 66 | 67 | 68 | def params_name_format(params, str_rem): 69 | """ 70 | Changes all the key in dictionaries to remove a specific word from each key (``estimator__``). 71 | This is because in order to GridsearchCV on MultiOutputRegressor one needs to 72 | use ``estimator__`` in all parameters - but once the best parameters are fetched 73 | the name needs to be changed. 74 | 75 | :param params: dictionary of model parameters 76 | :param str_rem: word to be removed 77 | :returns: dictionary of parameters in which the word has been removed in keys 78 | """ 79 | out_dict = {} 80 | for k, v in params.items(): 81 | new_key = k.replace(str_rem, "") if str_rem in k else k 82 | out_dict[new_key] = v 83 | return out_dict 84 | -------------------------------------------------------------------------------- /src/flexcode/loss_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def cde_loss(cde_estimates, z_grid, true_z): 5 | """Calculates conditional density estimation loss on holdout data 6 | 7 | @param cde_estimates: a numpy array where each row is a density 8 | estimate on z_grid 9 | @param z_grid: a numpy array of the grid points at which cde_estimates is evaluated 10 | @param true_z: a numpy array of the true z values corresponding to the rows of cde_estimates 11 | 12 | @returns The CDE loss (up to a constant) for the CDE estimator on 13 | the holdout data 14 | """ 15 | 16 | n_obs, n_grid = cde_estimates.shape 17 | 18 | term1 = np.mean(np.trapz(cde_estimates**2, z_grid.flatten())) 19 | 20 | nns = [np.argmin(np.abs(z_grid - true_z[ii])) for ii in range(n_obs)] 21 | term2 = np.mean(cde_estimates[range(n_obs), nns]) 22 | return term1 - 2 * term2 23 | -------------------------------------------------------------------------------- /src/flexcode/post_processing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .loss_functions import cde_loss 4 | 5 | 6 | def normalize(cde_estimates, tol=1e-6, max_iter=200): 7 | """Normalizes conditional density estimates to be non-negative and 8 | integrate to one. 9 | 10 | Assumes densities are evaluated on the unit grid. 11 | 12 | :param cde_estimates: a numpy array or matrix of conditional density estimates. 13 | :param tol: float, the tolerance to accept for abs(area - 1). 14 | :param max_iter: int, the maximal number of search iterations. 15 | :returns: the normalized conditional density estimates. 16 | :rtype: numpy array or matrix. 17 | 18 | """ 19 | if cde_estimates.ndim == 1: 20 | _normalize(cde_estimates, tol, max_iter) 21 | else: 22 | np.apply_along_axis(_normalize, 1, cde_estimates, tol=tol, max_iter=max_iter) 23 | 24 | 25 | def _normalize(density, tol=1e-6, max_iter=500): 26 | """Normalizes a density estimate to be non-negative and integrate to 27 | one. 
28 | 29 | Assumes density is evaluated on the unit grid. 30 | 31 | :param density: a numpy array of density estimates. 32 | :param z_grid: an array, the grid points at the density is estimated. 33 | :param tol: float, the tolerance to accept for abs(area - 1). 34 | :param max_iter: int, the maximal number of search iterations. 35 | :returns: the normalized density estimate. 36 | :rtype: numpy array. 37 | 38 | """ 39 | hi = np.max(density) 40 | lo = 0.0 41 | 42 | area = np.mean(np.maximum(density, 0.0)) 43 | if area == 0.0: 44 | # replace with uniform if all negative density 45 | density[:] = 1.0 46 | elif area < 1: 47 | density /= area 48 | density[density < 0.0] = 0.0 49 | return 50 | 51 | for _ in range(max_iter): 52 | mid = (hi + lo) / 2 53 | area = np.mean(np.maximum(density - mid, 0.0)) 54 | if abs(1.0 - area) <= tol: 55 | break 56 | if area < 1.0: 57 | hi = mid 58 | else: 59 | lo = mid 60 | 61 | # update in place 62 | density -= mid 63 | density[density < 0.0] = 0.0 64 | 65 | 66 | def sharpen(cde_estimates, alpha): 67 | """Sharpens conditional density estimates. 68 | 69 | Assumes densities are evaluated on the unit grid. 70 | 71 | :param cde_estimates: a numpy array or matrix of conditional density estimates. 72 | :param alpha: float, the exponent to which the estimate is raised. 73 | :returns: the sharpened conditional density estimate. 74 | :rtype: numpy array or matrix. 75 | 76 | """ 77 | cde_estimates **= alpha 78 | normalize(cde_estimates) 79 | 80 | 81 | def choose_sharpen(cde_estimates, z_grid, true_z, alpha_grid): 82 | """Chooses the sharpen parameter by minimizing cde loss. 83 | 84 | :param cde_estimates: a numpy matrix of conditional density estimates 85 | :param true_z: an array of the true z values corresponding to the cde_estimates. 86 | :param alpha_grid: an array of candidate sharpen parameter values. 87 | :returns: the sharpen parameter value from alpha_grid which minimizes cde loss. 88 | :rtype: float 89 | 90 | """ 91 | best_alpha = None 92 | best_loss = np.inf 93 | for alpha in alpha_grid: 94 | tmp_estimates = cde_estimates.copy() 95 | sharpen(tmp_estimates, alpha) 96 | loss = cde_loss(tmp_estimates, z_grid, true_z) 97 | if loss < best_loss: 98 | best_loss = loss 99 | best_alpha = alpha 100 | return best_alpha 101 | 102 | 103 | def remove_bumps(cde_estimates, delta): 104 | """Removes bumps in conditional density estimates 105 | 106 | Assumes that cde_estimates are on the unit grid. 107 | 108 | :param cde_estimates: a numpy array or matrix of conditional density estimates. 109 | :param delta: float, the threshold for bump removal 110 | :returns: the conditional density estimates with bumps removed 111 | :rtype: numpy array or matrix 112 | 113 | """ 114 | if cde_estimates.ndim == 1: 115 | _remove_bumps(cde_estimates, delta) 116 | else: 117 | np.apply_along_axis(_remove_bumps, 1, cde_estimates, delta=delta) 118 | 119 | 120 | def _remove_bumps(density, delta): 121 | """Removes bumps in conditional density estimates. 122 | 123 | Assumes estimates are on the unit grid. 124 | 125 | :param density: a numpy array of conditional density estimate. 126 | :param delta: float, the threshold for bump removal. 127 | :returns: the conditional density estimate with bumps removed. 128 | :rtype: numpy array. 
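    Example
    -------
    An illustrative sketch, assuming a length-4 grid (not run as a doctest)::

        density = np.array([0.05, 0.0, 2.0, 1.95])
        _remove_bumps(density, delta=0.1)
        # the leading bump has area 0.05 * (1 / 4) = 0.0125 < delta, so it is
        # zeroed out and the remaining density is renormalized in place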
129 | 
130 |     """
131 |     bin_size = 1.0 / len(density)
132 |     area = 0.0
133 |     left_idx = 0
134 |     removed_area = 0.0
135 |     for right_idx, val in enumerate(density):
136 |         if val <= 0.0:
137 |             if area < delta:
138 |                 density[left_idx : (right_idx + 1)] = 0.0
139 |                 removed_area += area
140 |             left_idx = right_idx + 1
141 |             area = 0.0
142 |         else:
143 |             area += val * bin_size
144 |     if area < delta:  # final check at end
145 |         density[left_idx:] = 0.0
146 |         removed_area += area
147 |     _normalize(density)
148 | 
149 | 
150 | def choose_bump_threshold(cde_estimates, z_grid, true_z, delta_grid):
151 |     """Chooses the bump threshold which minimizes cde loss.
152 | 
153 |     :param cde_estimates: a numpy array or matrix of conditional density estimates.
154 |     :param z_grid: an array, the grid points at which the density is estimated.
155 |     :param true_z: the true z values corresponding to the conditional
156 |         density estimates.
157 |     :param delta_grid: an array of candidate bump threshold values
158 |     :returns: the bump threshold value from delta_grid which minimizes CDE loss
159 |     :rtype: float
160 | 
161 |     """
162 |     best_delta = None
163 |     best_loss = np.inf
164 |     for delta in delta_grid:
165 |         tmp_estimates = cde_estimates.copy()
166 |         remove_bumps(tmp_estimates, delta)
167 |         normalize(tmp_estimates)
168 |         loss = cde_loss(tmp_estimates, z_grid, true_z)
169 |         if loss < best_loss:
170 |             best_loss = loss
171 |             best_delta = delta
172 |     return best_delta
173 | 
--------------------------------------------------------------------------------
/src/flexcode/regression_models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from .helpers import params_dict_optim_decision, params_name_format
4 | 
5 | try:
6 |     import xgboost as xgb
7 | 
8 |     XGBOOST_AVAILABLE = True
9 | except ImportError:
10 |     XGBOOST_AVAILABLE = False
11 | 
12 | try:
13 |     import sklearn.ensemble
14 |     import sklearn.linear_model
15 |     import sklearn.model_selection
16 |     import sklearn.multioutput
17 |     import sklearn.neighbors
18 | 
19 |     SKLEARN_AVAILABLE = True
20 | except ImportError:
21 |     SKLEARN_AVAILABLE = False
22 | 
23 | 
24 | class FlexCodeRegression(object):
25 |     def __init__(self, max_basis):
26 |         self.max_basis = max_basis
27 | 
28 |     def fit(self, x_train, z_basis, weight):
29 |         pass
30 | 
31 |     def predict(self, x_new):
32 |         pass
33 | 
34 | 
35 | class NN(FlexCodeRegression):
36 |     def __init__(self, max_basis, params, *args, **kwargs):
37 |         if not SKLEARN_AVAILABLE:
38 |             raise Exception("NN requires scikit-learn to be installed")
39 | 
40 |         super(NN, self).__init__(max_basis)
41 | 
42 |         # Historically, we have used 'k' to indicate the number of neighbors, so
43 |         # this just puts the right notation for KNeighborsRegressor
44 |         if "k" in params:
45 |             params["n_neighbors"] = params["k"]
46 |             del params["k"]
47 |         params_opt, opt_flag = params_dict_optim_decision(params, multi_output=True)
48 |         self.params = params_opt
49 |         self.models = (
50 |             None
51 |             if opt_flag
52 |             else sklearn.multioutput.MultiOutputRegressor(
53 |                 sklearn.neighbors.KNeighborsRegressor(**self.params), n_jobs=-1
54 |             )
55 |         )
56 | 
57 |     def fit(self, x_train, z_basis, weight):
58 |         if weight is not None:
59 |             raise Exception("Weights not implemented for NN")
60 | 
61 |         if self.models is None:
62 |             self.cv_optim(x_train, z_basis)
63 | 
64 |         self.models.fit(x_train, z_basis)
65 | 
66 |     def cv_optim(self, x_train, z_basis):
67 |         nn_obj = sklearn.multioutput.MultiOutputRegressor(sklearn.neighbors.KNeighborsRegressor(), n_jobs=-1)
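        # A note on naming (illustrative): by this point params_dict_optim_decision
        # has already prefixed the user's parameter names for the multi-output
        # wrapper, e.g. a user-supplied {"k": [5, 10, 20]} arrives here as
        # {"estimator__n_neighbors": [5, 10, 20]}, which is the form GridSearchCV
        # expects when tuning a MultiOutputRegressor.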
68 | clf = sklearn.model_selection.GridSearchCV( 69 | nn_obj, self.params, cv=5, scoring="neg_mean_squared_error", verbose=2 70 | ) 71 | clf.fit(x_train, z_basis) 72 | 73 | self.params = params_name_format(clf.best_params_, str_rem="estimator__") 74 | self.models = sklearn.multioutput.MultiOutputRegressor( 75 | sklearn.neighbors.KNeighborsRegressor(**self.params), n_jobs=-1 76 | ) 77 | 78 | def predict(self, x_test): 79 | coefs = self.models.predict(x_test) 80 | return coefs 81 | 82 | 83 | class RandomForest(FlexCodeRegression): 84 | def __init__(self, max_basis, params, *args, **kwargs): 85 | if not SKLEARN_AVAILABLE: 86 | raise Exception("RandomForest requires scikit-learn to be installed") 87 | 88 | super(RandomForest, self).__init__(max_basis) 89 | 90 | params_opt, opt_flag = params_dict_optim_decision(params, multi_output=True) 91 | self.params = params_opt 92 | self.models = ( 93 | None 94 | if opt_flag 95 | else sklearn.multioutput.MultiOutputRegressor( 96 | sklearn.ensemble.RandomForestRegressor(**self.params), n_jobs=-1 97 | ) 98 | ) 99 | 100 | def fit(self, x_train, z_basis, weight=None): 101 | if self.models is None: 102 | self.cv_optim(x_train, z_basis, weight) 103 | 104 | self.models.fit(x_train, z_basis, sample_weight=weight) 105 | 106 | def cv_optim(self, x_train, z_basis, weight=None): 107 | rf_obj = sklearn.multioutput.MultiOutputRegressor(sklearn.ensemble.RandomForestRegressor(), n_jobs=-1) 108 | clf = sklearn.model_selection.GridSearchCV( 109 | rf_obj, self.params, cv=5, scoring="neg_mean_squared_error", verbose=2 110 | ) 111 | clf.fit(x_train, z_basis, sample_weight=weight) 112 | 113 | self.params = params_name_format(clf.best_params_, str_rem="estimator__") 114 | self.models = sklearn.multioutput.MultiOutputRegressor( 115 | sklearn.ensemble.RandomForestRegressor(**self.params), n_jobs=-1 116 | ) 117 | 118 | def predict(self, x_test): 119 | coefs = self.models.predict(x_test) 120 | return coefs 121 | 122 | 123 | class XGBoost(FlexCodeRegression): 124 | def __init__(self, max_basis, params, *args, **kwargs): 125 | if not XGBOOST_AVAILABLE: 126 | raise Exception("XGBoost requires xgboost to be installed") 127 | super(XGBoost, self).__init__(max_basis) 128 | 129 | # Historically, people have used `eta` for `learning_rate` - taking that 130 | # into account 131 | if "eta" in params: 132 | params["learning_rate"] = params["eta"] 133 | del params["eta"] 134 | 135 | # Also, set the default values if not passed 136 | params["max_depth"] = params.get("max_depth", 6) 137 | params["learning_rate"] = params.get("learning_rate", 0.3) 138 | params["silent"] = params.get("silent", 1) 139 | params["objective"] = params.get("objective", "reg:linear") 140 | 141 | params_opt, opt_flag = params_dict_optim_decision(params, multi_output=True) 142 | self.params = params_opt 143 | self.models = ( 144 | None 145 | if opt_flag 146 | else sklearn.multioutput.MultiOutputRegressor(xgb.XGBRegressor(**self.params), n_jobs=-1) 147 | ) 148 | 149 | def fit(self, x_train, z_basis, weight=None): 150 | if self.models is None: 151 | self.cv_optim(x_train, z_basis, weight) 152 | 153 | self.models.fit(x_train, z_basis, sample_weight=weight) 154 | 155 | def cv_optim(self, x_train, z_basis, weight=None): 156 | xgb_obj = sklearn.multioutput.MultiOutputRegressor(xgb.XGBRegressor(), n_jobs=-1) 157 | clf = sklearn.model_selection.GridSearchCV( 158 | xgb_obj, self.params, cv=5, scoring="neg_mean_squared_error", verbose=2 159 | ) 160 | clf.fit(x_train, z_basis, sample_weight=weight) 161 | 162 | self.params = 
params_name_format(clf.best_params_, str_rem="estimator__") 163 | self.models = sklearn.multioutput.MultiOutputRegressor(xgb.XGBRegressor(**self.params), n_jobs=-1) 164 | 165 | def predict(self, x_test): 166 | coefs = self.models.predict(x_test) 167 | return coefs 168 | 169 | 170 | class Lasso(FlexCodeRegression): 171 | def __init__(self, max_basis, params, *args, **kwargs): 172 | if not SKLEARN_AVAILABLE: 173 | raise Exception("Lasso requires scikit-learn to be installed") 174 | super(Lasso, self).__init__(max_basis) 175 | 176 | # Also, set the default values if not passed 177 | params["alpha"] = params.get("alpha", 1.0) 178 | params["l1_ratio"] = params.get("l1_ratio", 1.0) 179 | 180 | params_opt, opt_flag = params_dict_optim_decision(params, multi_output=True) 181 | self.params = params_opt 182 | self.models = ( 183 | None 184 | if opt_flag 185 | else sklearn.multioutput.MultiOutputRegressor( 186 | sklearn.linear_model.ElasticNet(**self.params), n_jobs=-1 187 | ) 188 | ) 189 | 190 | def fit(self, x_train, z_basis, weight=None): 191 | if weight is not None: 192 | raise ValueError( 193 | "Weights are not supported in the ElasticNet/Lasso " "implementation in scikit-learn." 194 | ) 195 | 196 | if self.models is None: 197 | self.cv_optim(x_train, z_basis) 198 | 199 | self.models.fit(x_train, z_basis) 200 | 201 | def cv_optim(self, x_train, z_basis): 202 | lasso_obj = sklearn.multioutput.MultiOutputRegressor(sklearn.linear_model.ElasticNet(), n_jobs=-1) 203 | clf = sklearn.model_selection.GridSearchCV( 204 | lasso_obj, self.params, cv=5, scoring="neg_mean_squared_error", verbose=2 205 | ) 206 | clf.fit(x_train, z_basis) 207 | 208 | self.params = params_name_format(clf.best_params_, str_rem="estimator__") 209 | self.models = sklearn.multioutput.MultiOutputRegressor( 210 | sklearn.linear_model.ElasticNet(**self.params), n_jobs=-1 211 | ) 212 | 213 | def predict(self, x_test): 214 | coefs = self.models.predict(x_test) 215 | return coefs 216 | 217 | 218 | class CustomModel(FlexCodeRegression): 219 | def __init__(self, max_basis, params, custom_model, *args, **kwargs): 220 | if not SKLEARN_AVAILABLE: 221 | raise Exception("Custom class requires scikit-learn to be installed") 222 | super(CustomModel, self).__init__(max_basis) 223 | 224 | params_opt, opt_flag = params_dict_optim_decision(params, multi_output=True) 225 | self.params = params_opt 226 | self.base_model = custom_model 227 | self.models = ( 228 | None 229 | if opt_flag 230 | else sklearn.multioutput.MultiOutputRegressor(self.base_model(**self.params), n_jobs=-1) 231 | ) 232 | 233 | def fit(self, x_train, z_basis, weight=None): 234 | # Given it's a custom class, work would need to be done 235 | # for sample weights - for now this is not implemented. 
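        # (An assumed, illustrative usage: `custom_model` can be any scikit-learn
        # style regressor class, constructed as custom_model(**params) and exposing
        # fit/predict, e.g.
        #     FlexCodeModel(CustomModel, max_basis=31,
        #                   regression_params={"max_depth": 5},
        #                   custom_model=sklearn.tree.DecisionTreeRegressor)
        # sklearn.tree is not imported in this module; the example is a sketch only.)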
236 |         if weight is not None:  # a bare `if weight:` would raise on a numpy array
237 |             raise NotImplementedError("Weights for custom class not implemented.")
238 | 
239 |         if self.models is None:
240 |             self.cv_optim(x_train, z_basis)
241 | 
242 |         self.models.fit(x_train, z_basis)
243 | 
244 |     def cv_optim(self, x_train, z_basis):
245 |         custom_obj = sklearn.multioutput.MultiOutputRegressor(self.base_model(), n_jobs=-1)
246 |         clf = sklearn.model_selection.GridSearchCV(
247 |             custom_obj, self.params, cv=5, scoring="neg_mean_squared_error", verbose=2
248 |         )
249 |         clf.fit(x_train, z_basis)
250 | 
251 |         self.params = params_name_format(clf.best_params_, str_rem="estimator__")
252 |         self.models = sklearn.multioutput.MultiOutputRegressor(self.base_model(**self.params), n_jobs=-1)
253 | 
254 |     def predict(self, x_test):
255 |         coefs = self.models.predict(x_test)
256 |         return coefs
257 | 
--------------------------------------------------------------------------------
/tests/flexcode/.pylintrc:
--------------------------------------------------------------------------------
1 | [MAIN]
2 | 
3 | # Analyse import fallback blocks. This can be used to support both Python 2 and
4 | # 3 compatible code, which means that the block might have code that exists
5 | # only in one or another interpreter, leading to false positives when analysed.
6 | analyse-fallback-blocks=no
7 | 
8 | # Clear in-memory caches upon conclusion of linting. Useful if running pylint
9 | # in a server-like mode.
10 | clear-cache-post-run=no
11 | 
12 | # Load and enable all available extensions. Use --list-extensions to see a list
13 | # all available extensions.
14 | #enable-all-extensions=
15 | 
16 | # In error mode, messages with a category besides ERROR or FATAL are
17 | # suppressed, and no reports are done by default. Error mode is compatible with
18 | # disabling specific errors.
19 | #errors-only=
20 | 
21 | # Always return a 0 (non-error) status code, even if lint errors are found.
22 | # This is primarily useful in continuous integration scripts.
23 | #exit-zero=
24 | 
25 | # A comma-separated list of package or module names from where C extensions may
26 | # be loaded. Extensions are loading into the active Python interpreter and may
27 | # run arbitrary code.
28 | extension-pkg-allow-list=
29 | 
30 | # A comma-separated list of package or module names from where C extensions may
31 | # be loaded. Extensions are loading into the active Python interpreter and may
32 | # run arbitrary code. (This is an alternative name to extension-pkg-allow-list
33 | # for backward compatibility.)
34 | extension-pkg-whitelist=
35 | 
36 | # Return non-zero exit code if any of these messages/categories are detected,
37 | # even if score is above --fail-under value. Syntax same as enable. Messages
38 | # specified are enabled, while categories only check already-enabled messages.
39 | fail-on=
40 | 
41 | # Specify a score threshold under which the program will exit with error.
42 | fail-under=10
43 | 
44 | # Interpret the stdin as a python script, whose filename needs to be passed as
45 | # the module_or_package argument.
46 | #from-stdin=
47 | 
48 | # Files or directories to be skipped. They should be base names, not paths.
49 | ignore=CVS
50 | 
51 | # Add files or directories matching the regular expressions patterns to the
52 | # ignore-list. The regex matches against paths and can be in Posix or Windows
53 | # format. Because '\\' represents the directory delimiter on Windows systems,
54 | # it can't be used as an escape character.
55 | ignore-paths= 56 | 57 | # Files or directories matching the regular expression patterns are skipped. 58 | # The regex matches against base names, not paths. 59 | ignore-patterns=_version.py 60 | 61 | # List of module names for which member attributes should not be checked 62 | # (useful for modules/projects where namespaces are manipulated during runtime 63 | # and thus existing member attributes cannot be deduced by static analysis). It 64 | # supports qualified module names, as well as Unix pattern matching. 65 | ignored-modules= 66 | 67 | # Python code to execute, usually for sys.path manipulation such as 68 | # pygtk.require(). 69 | #init-hook= 70 | 71 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the 72 | # number of processors available to use, and will cap the count on Windows to 73 | # avoid hangs. 74 | jobs=1 75 | 76 | # Control the amount of potential inferred values when inferring a single 77 | # object. This can help the performance when dealing with large functions or 78 | # complex, nested conditions. 79 | limit-inference-results=100 80 | 81 | # List of plugins (as comma separated values of python module names) to load, 82 | # usually to register additional checkers. 83 | load-plugins= 84 | 85 | # Pickle collected data for later comparisons. 86 | persistent=yes 87 | 88 | # Minimum Python version to use for version dependent checks. Will default to 89 | # the version used to run pylint. 90 | py-version=3.9 91 | 92 | # Discover python modules and packages in the file system subtree. 93 | recursive=no 94 | 95 | # When enabled, pylint would attempt to guess common misconfiguration and emit 96 | # user-friendly hints instead of false-positive error messages. 97 | suggestion-mode=yes 98 | 99 | # Allow loading of arbitrary C extensions. Extensions are imported into the 100 | # active Python interpreter and may run arbitrary code. 101 | unsafe-load-any-extension=no 102 | 103 | # In verbose mode, extra non-checker-related info will be displayed. 104 | #verbose= 105 | 106 | 107 | [BASIC] 108 | 109 | # Naming style matching correct argument names. 110 | argument-naming-style=snake_case 111 | 112 | # Regular expression matching correct argument names. Overrides argument- 113 | # naming-style. If left empty, argument names will be checked with the set 114 | # naming style. 115 | #argument-rgx= 116 | 117 | # Naming style matching correct attribute names. 118 | attr-naming-style=snake_case 119 | 120 | # Regular expression matching correct attribute names. Overrides attr-naming- 121 | # style. If left empty, attribute names will be checked with the set naming 122 | # style. 123 | #attr-rgx= 124 | 125 | # Bad variable names which should always be refused, separated by a comma. 126 | bad-names=foo, 127 | bar, 128 | baz, 129 | toto, 130 | tutu, 131 | tata 132 | 133 | # Bad variable names regexes, separated by a comma. If names match any regex, 134 | # they will always be refused 135 | bad-names-rgxs= 136 | 137 | # Naming style matching correct class attribute names. 138 | class-attribute-naming-style=any 139 | 140 | # Regular expression matching correct class attribute names. Overrides class- 141 | # attribute-naming-style. If left empty, class attribute names will be checked 142 | # with the set naming style. 143 | #class-attribute-rgx= 144 | 145 | # Naming style matching correct class constant names. 146 | class-const-naming-style=UPPER_CASE 147 | 148 | # Regular expression matching correct class constant names. Overrides class- 149 | # const-naming-style. 
If left empty, class constant names will be checked with 150 | # the set naming style. 151 | #class-const-rgx= 152 | 153 | # Naming style matching correct class names. 154 | class-naming-style=PascalCase 155 | 156 | # Regular expression matching correct class names. Overrides class-naming- 157 | # style. If left empty, class names will be checked with the set naming style. 158 | #class-rgx= 159 | 160 | # Naming style matching correct constant names. 161 | const-naming-style=UPPER_CASE 162 | 163 | # Regular expression matching correct constant names. Overrides const-naming- 164 | # style. If left empty, constant names will be checked with the set naming 165 | # style. 166 | #const-rgx= 167 | 168 | # Minimum line length for functions/classes that require docstrings, shorter 169 | # ones are exempt. 170 | docstring-min-length=-1 171 | 172 | # Naming style matching correct function names. 173 | function-naming-style=snake_case 174 | 175 | # Regular expression matching correct function names. Overrides function- 176 | # naming-style. If left empty, function names will be checked with the set 177 | # naming style. 178 | #function-rgx= 179 | 180 | # Good variable names which should always be accepted, separated by a comma. 181 | good-names=i, 182 | j, 183 | k, 184 | ex, 185 | Run, 186 | _ 187 | 188 | # Good variable names regexes, separated by a comma. If names match any regex, 189 | # they will always be accepted 190 | good-names-rgxs= 191 | 192 | # Include a hint for the correct naming format with invalid-name. 193 | include-naming-hint=no 194 | 195 | # Naming style matching correct inline iteration names. 196 | inlinevar-naming-style=any 197 | 198 | # Regular expression matching correct inline iteration names. Overrides 199 | # inlinevar-naming-style. If left empty, inline iteration names will be checked 200 | # with the set naming style. 201 | #inlinevar-rgx= 202 | 203 | # Naming style matching correct method names. 204 | method-naming-style=snake_case 205 | 206 | # Regular expression matching correct method names. Overrides method-naming- 207 | # style. If left empty, method names will be checked with the set naming style. 208 | #method-rgx= 209 | 210 | # Naming style matching correct module names. 211 | module-naming-style=snake_case 212 | 213 | # Regular expression matching correct module names. Overrides module-naming- 214 | # style. If left empty, module names will be checked with the set naming style. 215 | #module-rgx= 216 | 217 | # Colon-delimited sets of names that determine each other's naming style when 218 | # the name regexes allow several styles. 219 | name-group= 220 | 221 | # Regular expression which should only match function or class names that do 222 | # not require a docstring. 223 | no-docstring-rgx=^_ 224 | 225 | # List of decorators that produce properties, such as abc.abstractproperty. Add 226 | # to this list to register other decorators that produce valid properties. 227 | # These decorators are taken in consideration only for invalid-name. 228 | property-classes=abc.abstractproperty 229 | 230 | # Regular expression matching correct type variable names. If left empty, type 231 | # variable names will be checked with the set naming style. 232 | #typevar-rgx= 233 | 234 | # Naming style matching correct variable names. 235 | variable-naming-style=snake_case 236 | 237 | # Regular expression matching correct variable names. Overrides variable- 238 | # naming-style. If left empty, variable names will be checked with the set 239 | # naming style. 
240 | #variable-rgx= 241 | 242 | 243 | [CLASSES] 244 | 245 | # Warn about protected attribute access inside special methods 246 | check-protected-access-in-special-methods=no 247 | 248 | # List of method names used to declare (i.e. assign) instance attributes. 249 | defining-attr-methods=__init__, 250 | __new__, 251 | setUp, 252 | __post_init__ 253 | 254 | # List of member names, which should be excluded from the protected access 255 | # warning. 256 | exclude-protected=_asdict, 257 | _fields, 258 | _replace, 259 | _source, 260 | _make 261 | 262 | # List of valid names for the first argument in a class method. 263 | valid-classmethod-first-arg=cls 264 | 265 | # List of valid names for the first argument in a metaclass class method. 266 | valid-metaclass-classmethod-first-arg=mcs 267 | 268 | 269 | [DESIGN] 270 | 271 | # List of regular expressions of class ancestor names to ignore when counting 272 | # public methods (see R0903) 273 | exclude-too-few-public-methods= 274 | 275 | # List of qualified class names to ignore when counting class parents (see 276 | # R0901) 277 | ignored-parents= 278 | 279 | # Maximum number of arguments for function / method. 280 | max-args=5 281 | 282 | # Maximum number of attributes for a class (see R0902). 283 | max-attributes=7 284 | 285 | # Maximum number of boolean expressions in an if statement (see R0916). 286 | max-bool-expr=5 287 | 288 | # Maximum number of branch for function / method body. 289 | max-branches=12 290 | 291 | # Maximum number of locals for function / method body. 292 | max-locals=15 293 | 294 | # Maximum number of parents for a class (see R0901). 295 | max-parents=7 296 | 297 | # Maximum number of public methods for a class (see R0904). 298 | max-public-methods=20 299 | 300 | # Maximum number of return / yield for function / method body. 301 | max-returns=6 302 | 303 | # Maximum number of statements in function / method body. 304 | max-statements=50 305 | 306 | # Minimum number of public methods for a class (see R0903). 307 | min-public-methods=2 308 | 309 | 310 | [EXCEPTIONS] 311 | 312 | # Exceptions that will emit a warning when caught. 313 | overgeneral-exceptions=builtins.BaseException,builtins.Exception 314 | 315 | 316 | [FORMAT] 317 | 318 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 319 | expected-line-ending-format= 320 | 321 | # Regexp for a line that is allowed to be longer than the limit. 322 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$ 323 | 324 | # Number of spaces of indent required inside a hanging or continued line. 325 | indent-after-paren=4 326 | 327 | # String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1 328 | # tab). 329 | indent-string='    ' 330 | 331 | # Maximum number of characters on a single line. 332 | max-line-length=100 333 | 334 | # Maximum number of lines in a module. 335 | max-module-lines=1000 336 | 337 | # Allow the body of a class to be on the same line as the declaration if body 338 | # contains single statement. 339 | single-line-class-stmt=no 340 | 341 | # Allow the body of an if to be on the same line as the test if there is no 342 | # else. 343 | single-line-if-stmt=no 344 | 345 | 346 | [IMPORTS] 347 | 348 | # List of modules that can be imported at any level, not just the top level 349 | # one. 350 | allow-any-import-level= 351 | 352 | # Allow explicit reexports by alias from a package __init__. 353 | allow-reexport-from-package=no 354 | 355 | # Allow wildcard imports from modules that define __all__.
356 | allow-wildcard-with-all=no 357 | 358 | # Deprecated modules which should not be used, separated by a comma. 359 | deprecated-modules= 360 | 361 | # Output a graph (.gv or any supported image format) of external dependencies 362 | # to the given file (report RP0402 must not be disabled). 363 | ext-import-graph= 364 | 365 | # Output a graph (.gv or any supported image format) of all (i.e. internal and 366 | # external) dependencies to the given file (report RP0402 must not be 367 | # disabled). 368 | import-graph= 369 | 370 | # Output a graph (.gv or any supported image format) of internal dependencies 371 | # to the given file (report RP0402 must not be disabled). 372 | int-import-graph= 373 | 374 | # Force import order to recognize a module as part of the standard 375 | # compatibility libraries. 376 | known-standard-library= 377 | 378 | # Force import order to recognize a module as part of a third party library. 379 | known-third-party=enchant 380 | 381 | # Couples of modules and preferred modules, separated by a comma. 382 | preferred-modules= 383 | 384 | 385 | [LOGGING] 386 | 387 | # The type of string formatting that logging methods do. `old` means using % 388 | # formatting, `new` is for `{}` formatting. 389 | logging-format-style=old 390 | 391 | # Logging modules to check that the string format arguments are in logging 392 | # function parameter format. 393 | logging-modules=logging 394 | 395 | 396 | [MESSAGES CONTROL] 397 | 398 | # Only show warnings with the listed confidence levels. Leave empty to show 399 | # all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, 400 | # UNDEFINED. 401 | confidence=HIGH, 402 | CONTROL_FLOW, 403 | INFERENCE, 404 | INFERENCE_FAILURE, 405 | UNDEFINED 406 | 407 | # Disable the message, report, category or checker with the given id(s). You 408 | # can either give multiple identifiers separated by comma (,) or put this 409 | # option multiple times (only on the command line, not in the configuration 410 | # file where it should appear only once). You can also use "--disable=all" to 411 | # disable everything first and then re-enable specific checks. For example, if 412 | # you want to run only the similarities checker, you can use "--disable=all 413 | # --enable=similarities". If you want to run only the classes checker, but have 414 | # no Warning level messages displayed, use "--disable=all --enable=classes 415 | # --disable=W". 416 | disable=raw-checker-failed, 417 | bad-inline-option, 418 | locally-disabled, 419 | file-ignored, 420 | suppressed-message, 421 | useless-suppression, 422 | deprecated-pragma, 423 | use-symbolic-message-instead, 424 | missing-function-docstring, 425 | redefined-outer-name, 426 | protected-access, 427 | missing-module-docstring, 428 | 429 | # Enable the message, report, category or checker with the given id(s). You can 430 | # either give multiple identifier separated by comma (,) or put this option 431 | # multiple time (only on the command line, not in the configuration file where 432 | # it should appear only once). See also the "--disable" option for examples. 433 | enable=c-extension-no-member 434 | 435 | 436 | [METHOD_ARGS] 437 | 438 | # List of qualified names (i.e., library.method) which require a timeout 439 | # parameter e.g. 
'requests.api.get,requests.api.post' 440 | timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request 441 | 442 | 443 | [MISCELLANEOUS] 444 | 445 | # List of note tags to take in consideration, separated by a comma. 446 | notes=FIXME, 447 | XXX, 448 | TODO 449 | 450 | # Regular expression of note tags to take in consideration. 451 | notes-rgx= 452 | 453 | 454 | [REFACTORING] 455 | 456 | # Maximum number of nested blocks for function / method body 457 | max-nested-blocks=5 458 | 459 | # Complete name of functions that never returns. When checking for 460 | # inconsistent-return-statements if a never returning function is called then 461 | # it will be considered as an explicit return statement and no message will be 462 | # printed. 463 | never-returning-functions=sys.exit,argparse.parse_error 464 | 465 | 466 | [REPORTS] 467 | 468 | # Python expression which should return a score less than or equal to 10. You 469 | # have access to the variables 'fatal', 'error', 'warning', 'refactor', 470 | # 'convention', and 'info' which contain the number of messages in each 471 | # category, as well as 'statement' which is the total number of statements 472 | # analyzed. This score is used by the global evaluation report (RP0004). 473 | evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) 474 | 475 | # Template used to display messages. This is a python new-style format string 476 | # used to format the message information. See doc for all details. 477 | msg-template= 478 | 479 | # Set the output format. Available formats are text, parseable, colorized, json 480 | # and msvs (visual studio). You can also give a reporter class, e.g. 481 | # mypackage.mymodule.MyReporterClass. 482 | #output-format= 483 | 484 | # Tells whether to display a full report or only the messages. 485 | reports=no 486 | 487 | # Activate the evaluation score. 488 | score=yes 489 | 490 | 491 | [SIMILARITIES] 492 | 493 | # Comments are removed from the similarity computation 494 | ignore-comments=yes 495 | 496 | # Docstrings are removed from the similarity computation 497 | ignore-docstrings=yes 498 | 499 | # Imports are removed from the similarity computation 500 | ignore-imports=yes 501 | 502 | # Signatures are removed from the similarity computation 503 | ignore-signatures=yes 504 | 505 | # Minimum lines number of a similarity. 506 | min-similarity-lines=4 507 | 508 | 509 | [SPELLING] 510 | 511 | # Limits count of emitted suggestions for spelling mistakes. 512 | max-spelling-suggestions=4 513 | 514 | # Spelling dictionary name. Available dictionaries: none. To make it work, 515 | # install the 'python-enchant' package. 516 | spelling-dict= 517 | 518 | # List of comma separated words that should be considered directives if they 519 | # appear at the beginning of a comment and should not be checked. 520 | spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: 521 | 522 | # List of comma separated words that should not be checked. 523 | spelling-ignore-words= 524 | 525 | # A path to a file that contains the private dictionary; one word per line. 526 | spelling-private-dict-file= 527 | 528 | # Tells whether to store unknown words to the private dictionary (see the 529 | # --spelling-private-dict-file option) instead of raising a message. 
530 | spelling-store-unknown-words=no 531 | 532 | 533 | [STRING] 534 | 535 | # This flag controls whether inconsistent-quotes generates a warning when the 536 | # character used as a quote delimiter is used inconsistently within a module. 537 | check-quote-consistency=no 538 | 539 | # This flag controls whether the implicit-str-concat should generate a warning 540 | # on implicit string concatenation in sequences defined over several lines. 541 | check-str-concat-over-line-jumps=no 542 | 543 | 544 | [TYPECHECK] 545 | 546 | # List of decorators that produce context managers, such as 547 | # contextlib.contextmanager. Add to this list to register other decorators that 548 | # produce valid context managers. 549 | contextmanager-decorators=contextlib.contextmanager 550 | 551 | # List of members which are set dynamically and missed by pylint inference 552 | # system, and so shouldn't trigger E1101 when accessed. Python regular 553 | # expressions are accepted. 554 | generated-members= 555 | 556 | # Tells whether to warn about missing members when the owner of the attribute 557 | # is inferred to be None. 558 | ignore-none=yes 559 | 560 | # This flag controls whether pylint should warn about no-member and similar 561 | # checks whenever an opaque object is returned when inferring. The inference 562 | # can return multiple potential results while evaluating a Python object, but 563 | # some branches might not be evaluated, which results in partial inference. In 564 | # that case, it might be useful to still emit no-member and other checks for 565 | # the rest of the inferred objects. 566 | ignore-on-opaque-inference=yes 567 | 568 | # List of symbolic message names to ignore for Mixin members. 569 | ignored-checks-for-mixins=no-member, 570 | not-async-context-manager, 571 | not-context-manager, 572 | attribute-defined-outside-init 573 | 574 | # List of class names for which member attributes should not be checked (useful 575 | # for classes with dynamically set attributes). This supports the use of 576 | # qualified names. 577 | ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace 578 | 579 | # Show a hint with possible names when a member name was not found. The aspect 580 | # of finding the hint is based on edit distance. 581 | missing-member-hint=yes 582 | 583 | # The minimum edit distance a name should have in order to be considered a 584 | # similar match for a missing member name. 585 | missing-member-hint-distance=1 586 | 587 | # The total number of similar names that should be taken in consideration when 588 | # showing a hint for a missing member. 589 | missing-member-max-choices=1 590 | 591 | # Regex pattern to define which classes are considered mixins. 592 | mixin-class-rgx=.*[Mm]ixin 593 | 594 | # List of decorators that change the signature of a decorated function. 595 | signature-mutators= 596 | 597 | 598 | [VARIABLES] 599 | 600 | # List of additional names supposed to be defined in builtins. Remember that 601 | # you should avoid defining new builtins when possible. 602 | additional-builtins= 603 | 604 | # Tells whether unused global variables should be treated as a violation. 605 | allow-global-unused-variables=yes 606 | 607 | # List of names allowed to shadow builtins 608 | allowed-redefined-builtins= 609 | 610 | # List of strings which can identify a callback function by name. A callback 611 | # name must start or end with one of those strings. 612 | callbacks=cb_, 613 | _cb 614 | 615 | # A regular expression matching the name of dummy variables (i.e. 
expected to 616 | # not be used). 617 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ 618 | 619 | # Argument names that match this expression will be ignored. 620 | ignored-argument-names=_.*|^ignored_|^unused_ 621 | 622 | # Tells whether we should check for unused import in __init__ files. 623 | init-import=no 624 | 625 | # List of qualified module names which can have objects that can redefine 626 | # builtins. 627 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io 628 | -------------------------------------------------------------------------------- /tests/flexcode/conftest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | # defining some constants used throughout the test suite 5 | BUMP_THRESHOLD_GRID = np.linspace(0, 0.2, 3) 6 | SHARPEN_GRID = np.linspace(0.5, 1.5, 3) 7 | 8 | 9 | def generate_data(n_draws): 10 | """Generate data p(z | x) = N(x, 1) 11 | 12 | Parameters 13 | ---------- 14 | n_draws : int 15 | number of samples to generate 16 | 17 | Returns 18 | ------- 19 | x : numpy.ndarray 20 | Samples drawn from a N(0, 1) normal distribution 21 | z : numpy.ndarray 22 | Samples drawn from a N(`x`, 1) normal distribution, where `x` is the random 23 | variate created earlier. 24 | """ 25 | x = np.random.normal(0, 1, n_draws) 26 | z = np.random.normal(x, 1, n_draws) 27 | return x, z 28 | -------------------------------------------------------------------------------- /tests/flexcode/context.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 5 | 6 | import flexcode 7 | -------------------------------------------------------------------------------- /tests/flexcode/test_cv_optim.py: -------------------------------------------------------------------------------- 1 | """This module tests a particular code path in the FlexCodeModel.fit function. 2 | Each of the test cases specifies a flexcode.FlexCodeModel that is parameterized 3 | with a `regression_params` variable that defines a dictionary in which at least one 4 | value is an array. 5 | 6 | For example, in `test_coef_predict_same_as_predict_nn`, the model is defined with 7 | the regression parameter `"k": [5, 30]`. The effect of this is that within the 8 | `NN.fit` method, the `self.cv_optim` method will be called instead of `self.models.fit`. 9 | 10 | This module is structurally similar to `test_models_fit`, but tests a slightly 11 | different code path. 12 | """ 13 | 14 | import numpy as np 15 | import xgboost as xgb 16 | from conftest import BUMP_THRESHOLD_GRID, SHARPEN_GRID, generate_data 17 | 18 | import flexcode 19 | from flexcode.regression_models import NN, CustomModel, Lasso, RandomForest, XGBoost 20 | 21 | 22 | def test_coef_predict_same_as_predict_nn(): 23 | # Here we generate 3000 random variates because 1000 (like the other tests) 24 | # was causing test instability.
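# NOTE: the list-valued "k" below ([5, 30]) is what routes NN.fit through
# self.cv_optim rather than self.models.fit, so this test exercises the
# cross-validated parameter search described in the module docstring.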
25 | x_train, z_train = generate_data(3000) 26 | x_validation, z_validation = generate_data(3000) 27 | x_test, _ = generate_data(3000) 28 | 29 | # Parameterize model 30 | model = flexcode.FlexCodeModel(NN, max_basis=31, basis_system="cosine", regression_params={"k": [5, 30]}) 31 | 32 | # Fit and tune model 33 | model.fit(x_train, z_train) 34 | model.tune(x_validation, z_validation, bump_threshold_grid=BUMP_THRESHOLD_GRID, sharpen_grid=SHARPEN_GRID) 35 | 36 | cdes_predict, z_grid = model.predict(x_test, n_grid=200) 37 | 38 | coefs = model.predict_coefs(x_test) 39 | cdes_coefs = coefs.evaluate(z_grid) 40 | 41 | assert np.max(np.abs(cdes_predict - cdes_coefs)) <= 1e-4 42 | 43 | 44 | def test_coef_predict_same_as_predict_rf(): 45 | x_train, z_train = generate_data(1000) 46 | x_validation, z_validation = generate_data(1000) 47 | x_test, _ = generate_data(1000) 48 | 49 | # Parameterize model 50 | model = flexcode.FlexCodeModel( 51 | RandomForest, 52 | max_basis=31, 53 | basis_system="cosine", 54 | regression_params={"n_estimators": [10, 30], "min_samples_split": [2]}, 55 | ) 56 | 57 | # Fit and tune model 58 | model.fit(x_train, z_train) 59 | model.tune(x_validation, z_validation, bump_threshold_grid=BUMP_THRESHOLD_GRID, sharpen_grid=SHARPEN_GRID) 60 | 61 | cdes_predict, z_grid = model.predict(x_test, n_grid=200) 62 | 63 | coefs = model.predict_coefs(x_test) 64 | cdes_coefs = coefs.evaluate(z_grid) 65 | 66 | assert np.max(np.abs(cdes_predict - cdes_coefs)) <= 1e-4 67 | 68 | 69 | def test_coef_predict_same_as_predict_xgb(): 70 | x_train, z_train = generate_data(1000) 71 | x_validation, z_validation = generate_data(1000) 72 | x_test, _ = generate_data(1000) 73 | 74 | # Parameterize model 75 | model = flexcode.FlexCodeModel( 76 | XGBoost, max_basis=31, basis_system="cosine", regression_params={"max_depth": [3, 8], "eta": [0.1]} 77 | ) 78 | 79 | # Fit and tune model 80 | model.fit(x_train, z_train) 81 | model.tune(x_validation, z_validation, bump_threshold_grid=BUMP_THRESHOLD_GRID, sharpen_grid=SHARPEN_GRID) 82 | 83 | cdes_predict, z_grid = model.predict(x_test, n_grid=200) 84 | 85 | coefs = model.predict_coefs(x_test) 86 | cdes_coefs = coefs.evaluate(z_grid) 87 | 88 | assert np.max(np.abs(cdes_predict - cdes_coefs)) <= 1e-4 89 | 90 | 91 | def test_coef_predict_same_as_predict_lasso(): 92 | x_train, z_train = generate_data(1000) 93 | x_validation, z_validation = generate_data(1000) 94 | x_test, _ = generate_data(1000) 95 | 96 | # Parameterize model 97 | model = flexcode.FlexCodeModel( 98 | Lasso, max_basis=31, basis_system="cosine", regression_params={"alpha": [1.0, 1.1]} 99 | ) 100 | 101 | # Fit and tune model 102 | model.fit(x_train, z_train) 103 | model.tune( 104 | x_validation, 105 | z_validation, 106 | bump_threshold_grid=np.linspace(0, 0.2, 3), 107 | sharpen_grid=np.linspace(0.5, 1.5, 3), 108 | ) 109 | 110 | cdes_predict, z_grid = model.predict(x_test, n_grid=200) 111 | 112 | coefs = model.predict_coefs(x_test) 113 | cdes_coefs = coefs.evaluate(z_grid) 114 | 115 | assert np.max(np.abs(cdes_predict - cdes_coefs)) <= 0.5 116 | 117 | 118 | def test_coef_predict_same_as_predict_custom_model(): 119 | x_train, z_train = generate_data(1000) 120 | x_validation, z_validation = generate_data(1000) 121 | x_test, _ = generate_data(1000) 122 | 123 | # Parameterize model 124 | custom_model = xgb.XGBRegressor 125 | model = flexcode.FlexCodeModel( 126 | CustomModel, 127 | max_basis=31, 128 | basis_system="cosine", 129 | regression_params={"max_depth": [3, 8], "eta": [0.1]}, 130 | custom_model=custom_model, 
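# CustomModel only requires that the wrapped regressor (here xgboost's
# sklearn-style XGBRegressor) expose fit() and predict() methods.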
131 | ) 132 | 133 | # Fit and tune model 134 | model.fit(x_train, z_train) 135 | model.tune(x_validation, z_validation, bump_threshold_grid=BUMP_THRESHOLD_GRID, sharpen_grid=SHARPEN_GRID) 136 | 137 | cdes_predict, z_grid = model.predict(x_test, n_grid=200) 138 | 139 | coefs = model.predict_coefs(x_test) 140 | cdes_coefs = coefs.evaluate(z_grid) 141 | 142 | assert np.max(np.abs(cdes_predict - cdes_coefs)) <= 1e-4 143 | -------------------------------------------------------------------------------- /tests/flexcode/test_models_fit.py: -------------------------------------------------------------------------------- 1 | """This module tests a particular code path in the FlexCodeModel.fit function. 2 | Each of the test cases specifies a flexcode.FlexCodeModel that is parameterized 3 | with a `regression_params` variable that defines a dictionary whose values are a single float 4 | or integer. 5 | 6 | For example, in `test_coef_predict_same_as_predict_nn`, the model is defined with 7 | the regression parameter `"k": 20`. The effect of this is that within the 8 | `NN.fit` method, the `self.models.fit` method will be called instead of `self.cv_optim`. 9 | 10 | This module is structurally similar to `test_cv_optim`, but tests a slightly 11 | different code path. 12 | """ 13 | 14 | import numpy as np 15 | import pytest 16 | import xgboost as xgb 17 | from conftest import BUMP_THRESHOLD_GRID, SHARPEN_GRID, generate_data 18 | 19 | import flexcode 20 | from flexcode.regression_models import NN, CustomModel, Lasso, RandomForest, XGBoost 21 | 22 | 23 | @pytest.mark.skip(reason="The assertion is meaningless and the test is a duplicate") 24 | def test_example(): 25 | x_train, z_train = generate_data(1000) 26 | x_validation, z_validation = generate_data(1000) 27 | x_test, z_test = generate_data(1000) 28 | 29 | # Parameterize model 30 | model = flexcode.FlexCodeModel(NN, max_basis=31, basis_system="cosine", regression_params={"k": 20}) 31 | 32 | # Fit and tune model 33 | model.fit(x_train, z_train) 34 | model.tune(x_validation, z_validation, bump_threshold_grid=BUMP_THRESHOLD_GRID, sharpen_grid=SHARPEN_GRID) 35 | 36 | # Estimate CDE loss 37 | model.estimate_error(x_test, z_test) 38 | 39 | cdes, z_grid = model.predict(x_test, n_grid=200) 40 | 41 | assert True 42 | 43 | 44 | @pytest.mark.skip(reason="The assertion is meaningless and this test is a duplicate") 45 | def test_unshaped_example(): 46 | x_train, z_train = generate_data(1000) 47 | x_validation, z_validation = generate_data(1000) 48 | x_test, z_test = generate_data(1000) 49 | 50 | # Parameterize model 51 | model = flexcode.FlexCodeModel(NN, max_basis=31, basis_system="cosine", regression_params={"k": 20}) 52 | 53 | # Fit and tune model 54 | model.fit(x_train, z_train) 55 | model.tune(x_validation, z_validation, bump_threshold_grid=BUMP_THRESHOLD_GRID, sharpen_grid=SHARPEN_GRID) 56 | 57 | # Estimate CDE loss 58 | model.estimate_error(x_test, z_test) 59 | 60 | cdes, z_grid = model.predict(x_test, n_grid=200) 61 | 62 | assert True 63 | 64 | 65 | def test_coef_predict_same_as_predict_nn(): 66 | x_train, z_train = generate_data(1000) 67 | x_validation, z_validation = generate_data(1000) 68 | x_test, _ = generate_data(1000) 69 | 70 | # Parameterize model 71 | model = flexcode.FlexCodeModel(NN, max_basis=31, basis_system="cosine", regression_params={"k": 20}) 72 | 73 | # Fit and tune model 74 | model.fit(x_train, z_train) 75 | model.tune(x_validation, z_validation, bump_threshold_grid=BUMP_THRESHOLD_GRID, sharpen_grid=SHARPEN_GRID) 76 | 77 | cdes_predict, z_grid =
model.predict(x_test, n_grid=200) 78 | 79 | coefs = model.predict_coefs(x_test) 80 | cdes_coefs = coefs.evaluate(z_grid) 81 | 82 | assert np.max(np.abs(cdes_predict - cdes_coefs)) <= 1e-4 83 | 84 | 85 | def test_coef_predict_same_as_predict_rf(): 86 | x_train, z_train = generate_data(1000) 87 | x_validation, z_validation = generate_data(1000) 88 | x_test, _ = generate_data(1000) 89 | 90 | # Parameterize model 91 | model = flexcode.FlexCodeModel( 92 | RandomForest, max_basis=31, basis_system="cosine", regression_params={"n_estimators": 10} 93 | ) 94 | 95 | # Fit and tune model 96 | model.fit(x_train, z_train) 97 | model.tune(x_validation, z_validation, bump_threshold_grid=BUMP_THRESHOLD_GRID, sharpen_grid=SHARPEN_GRID) 98 | 99 | cdes_predict, z_grid = model.predict(x_test, n_grid=200) 100 | 101 | coefs = model.predict_coefs(x_test) 102 | cdes_coefs = coefs.evaluate(z_grid) 103 | 104 | assert np.max(np.abs(cdes_predict - cdes_coefs)) <= 1e-4 105 | 106 | 107 | def test_coef_predict_same_as_predict_xgb(): 108 | x_train, z_train = generate_data(1000) 109 | x_validation, z_validation = generate_data(1000) 110 | x_test, _ = generate_data(1000) 111 | 112 | # Parameterize model 113 | model = flexcode.FlexCodeModel( 114 | XGBoost, max_basis=31, basis_system="cosine", regression_params={"max_depth": 5} 115 | ) 116 | 117 | # Fit and tune model 118 | model.fit(x_train, z_train) 119 | model.tune(x_validation, z_validation, bump_threshold_grid=BUMP_THRESHOLD_GRID, sharpen_grid=SHARPEN_GRID) 120 | 121 | cdes_predict, z_grid = model.predict(x_test, n_grid=200) 122 | 123 | coefs = model.predict_coefs(x_test) 124 | cdes_coefs = coefs.evaluate(z_grid) 125 | 126 | assert np.max(np.abs(cdes_predict - cdes_coefs)) <= 1e-4 127 | 128 | 129 | def test_coef_predict_same_as_predict_lasso(): 130 | x_train, z_train = generate_data(1000) 131 | x_validation, z_validation = generate_data(1000) 132 | x_test, _ = generate_data(1000) 133 | 134 | # Parameterize model 135 | model = flexcode.FlexCodeModel( 136 | Lasso, max_basis=31, basis_system="cosine", regression_params={"alpha": 1.0} 137 | ) 138 | 139 | # Fit and tune model 140 | model.fit(x_train, z_train) 141 | model.tune(x_validation, z_validation, bump_threshold_grid=BUMP_THRESHOLD_GRID, sharpen_grid=SHARPEN_GRID) 142 | 143 | cdes_predict, z_grid = model.predict(x_test, n_grid=200) 144 | 145 | coefs = model.predict_coefs(x_test) 146 | cdes_coefs = coefs.evaluate(z_grid) 147 | 148 | assert np.max(np.abs(cdes_predict - cdes_coefs)) <= 0.5 149 | 150 | 151 | def test_coef_predict_same_as_predict_custom_model(): 152 | x_train, z_train = generate_data(1000) 153 | x_validation, z_validation = generate_data(1000) 154 | x_test, _ = generate_data(1000) 155 | 156 | # Parameterize model 157 | custom_model = xgb.XGBRegressor 158 | model = flexcode.FlexCodeModel( 159 | CustomModel, 160 | max_basis=31, 161 | basis_system="cosine", 162 | regression_params={"max_depth": 5}, 163 | custom_model=custom_model, 164 | ) 165 | 166 | # Fit and tune model 167 | model.fit(x_train, z_train) 168 | model.tune(x_validation, z_validation, bump_threshold_grid=BUMP_THRESHOLD_GRID, sharpen_grid=SHARPEN_GRID) 169 | 170 | cdes_predict, z_grid = model.predict(x_test, n_grid=200) 171 | 172 | coefs = model.predict_coefs(x_test) 173 | cdes_coefs = coefs.evaluate(z_grid) 174 | 175 | assert np.max(np.abs(cdes_predict - cdes_coefs)) <= 1e-4 176 | -------------------------------------------------------------------------------- /tests/flexcode/test_params_handling.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from flexcode.helpers import params_dict_optim_decision 5 | 6 | 7 | def test_params_transform(): 8 | dict1 = {"k": [1, 2, 3, 4, 5, 6]} 9 | dictout1a, flag1a = params_dict_optim_decision(dict1) 10 | 11 | assert dictout1a == dict1 12 | assert flag1a == True 13 | 14 | dictout1b, flag1b = params_dict_optim_decision(dict1, True) 15 | assert dictout1b == {"estimator__k": [1, 2, 3, 4, 5, 6]} 16 | assert flag1b == True 17 | 18 | ############################################################ 19 | 20 | dict2 = {"k": 1} 21 | dictout2a, flag2a = params_dict_optim_decision(dict2) 22 | 23 | assert dictout2a == dict2 24 | assert flag2a == False 25 | 26 | dictout2b, flag2b = params_dict_optim_decision(dict2, True) 27 | assert dictout2b == dict2 28 | assert flag2b == False 29 | 30 | ############################################################# 31 | 32 | dict3 = {"k": [1, 2, 3, 4, 5, 6], "obj": "linear", "eta": 0.3} 33 | dictout3a, flag3a = params_dict_optim_decision(dict3) 34 | 35 | assert dictout3a == {"k": [1, 2, 3, 4, 5, 6], "obj": ["linear"], "eta": [0.3]} 36 | assert flag3a == True 37 | 38 | dictout3b, flag3b = params_dict_optim_decision(dict3, True) 39 | assert dictout3b == { 40 | "estimator__k": [1, 2, 3, 4, 5, 6], 41 | "estimator__obj": ["linear"], 42 | "estimator__eta": [0.3], 43 | } 44 | assert flag3b == True 45 | 46 | ############################################################ 47 | 48 | dict4 = {"k": 1, "obj": "linear", "eta": 0.3} 49 | dictout4a, flag4a = params_dict_optim_decision(dict4) 50 | 51 | assert dictout4a == dict4 52 | assert flag4a == False 53 | 54 | dictout4b, flag4b = params_dict_optim_decision(dict4, True) 55 | assert dictout4b == dict4 56 | assert flag4b == False 57 | 58 | ############################################################ 59 | 60 | dict5 = {"k": (1, 2, 3)} 61 | with pytest.raises(Exception): 62 | _ = params_dict_optim_decision(dict5) 63 | 64 | ############################################################# 65 | 66 | dict6 = {"k": [1, 2, 3, 4, 5, 6], "obj": "linear", "eta": np.linspace(0.3, 0.5, 3)} 67 | dictout6a, flag6a = params_dict_optim_decision(dict6) 68 | 69 | assert dictout6a == {"k": [1, 2, 3, 4, 5, 6], "obj": ["linear"], "eta": [0.3, 0.4, 0.5]} 70 | assert flag6a == True 71 | 72 | dictout6b, flag6b = params_dict_optim_decision(dict6, True) 73 | assert dictout6b == { 74 | "estimator__k": [1, 2, 3, 4, 5, 6], 75 | "estimator__obj": ["linear"], 76 | "estimator__eta": [0.3, 0.4, 0.5], 77 | } 78 | assert flag6b == True 79 | -------------------------------------------------------------------------------- /tests/flexcode/test_post_processing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from context import flexcode 4 | 5 | from flexcode.post_processing import * 6 | 7 | 8 | def test_remove_bumps(): 9 | density = np.ones(100) 10 | density[4] = 0.0 11 | density[96] = 0.0 12 | 13 | z_grid = np.linspace(0, 1, 100) 14 | delta = 0.1 15 | 16 | target_density = density.copy() 17 | target_density[:4] = 0.0 18 | target_density[96:] = 0.0 19 | normalize(target_density) 20 | 21 | remove_bumps(density, delta) 22 | 23 | np.testing.assert_array_equal(density, target_density) 24 | 25 | 26 | def test_normalize(): 27 | n_grid = 1000 28 | 29 | for _ in range(10): 30 | density = np.random.gamma(1, 1, size=n_grid) 31 | normalize(density) 32 | area = np.mean(density) 33 | assert 
all(density >= 0.0) 34 | assert area == pytest.approx(1.0) 35 | -------------------------------------------------------------------------------- /tutorial/Flexcode-tutorial-teddy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "The `flexcode` package can be installed directly by cloning the Github repository:\n", 8 | "\n", 9 | "git clone https://github.com/tpospisi/FlexCode.git
\n", 10 | "python setup.py install " 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import numpy as np\n", 20 | "import flexcode\n", 21 | "from cdetools.cde_loss import cde_loss\n", 22 | "from matplotlib import pyplot as plt\n", 23 | "from flexcode.regression_models import RandomForest" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "In the following cell we run the `wget` module to fetch the data from the [COINtoolbox Github repository](https://github.com/COINtoolbox/photoz_catalogues/tree/master/Teddy).
\n", 31 | "You can download the Teddy A and B manually if you prefer." 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "import os\n", 41 | "import wget\n", 42 | "\n", 43 | "data_dir = 'data/'\n", 44 | "if not os.path.exists(data_dir):\n", 45 | " os.makedirs(data_dir)\n", 46 | " print('\"data\" subfolder created')\n", 47 | " \n", 48 | "_ = wget.download('https://github.com/COINtoolbox/photoz_catalogues/raw/master/Teddy/teddy_A', \n", 49 | " out='data/teddy_A.txt')\n", 50 | "\n", 51 | "_ = wget.download('https://github.com/COINtoolbox/photoz_catalogues/raw/master/Teddy/teddy_B', \n", 52 | " out='data/teddy_B.txt')" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "The following lines assume you have the datasets `Teddy A` and `Teddy B` in a subfolder of your current directory. By default this subfolder is `data`, but it can be changed below. The following function extract the information from the .txt file and generates numpy array.

\n", 60 | "\n", 61 | "You can find the `Teddy A` and `Teddy B` dataset in the [COINtoolbox Github repository](https://github.com/COINtoolbox/photoz_catalogues/tree/master/Teddy)." 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "def extract_teddy_data(filename, train_data, directory='data/'):\n", 71 | " \n", 72 | " full_data = []\n", 73 | " outfiles = ('teddy_x_train.npy', 'teddy_z_train.npy') if train_data else ('teddy_x_test.npy', 'teddy_z_test.npy')\n", 74 | " with open(filename) as fp:\n", 75 | " full_lines = fp.readlines()\n", 76 | " for line in full_lines:\n", 77 | " if '#' in line:\n", 78 | " continue\n", 79 | " full_data.append([float(el) for el in line.strip().split(' ') if el])\n", 80 | " fp.close()\n", 81 | " \n", 82 | " # Saving the formatted Teddy data\n", 83 | " np.save(arr=np.array(full_data)[:, 7:12], file=directory + outfiles[0])\n", 84 | " np.save(arr=np.array(full_data)[:, 6], file=directory + outfiles[1])\n", 85 | " print('Extraction and Saving Done!')" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "extract_teddy_data(filename='data/teddy_A.txt', train_data=True, directory='data/')" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "extract_teddy_data(filename='data/teddy_B.txt', train_data=False, directory='data/')" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "## Load Teddy Cosmology Data" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "The $\\texttt{TEDDY}$ data are comprised of 4 datasets, generated by subsampling from the [SDSS DR12](https://www.sdss.org/dr12/).
\n", 118 | "The 4 datasets are named respectively A, B, C and D.\n", 119 | "\n", 120 | "We use dataset A for training and B for testing.
\n", 121 | "Data in these two datasets share the same underlying distribution, so training-based algorithms do not need any further adjustments.
\n", 122 | "For more information, consult the [TEDDY Github Repo](https://github.com/COINtoolbox/photoz_catalogues/tree/master/Teddy)." 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "Both datasets have around 74,000 spectroscopic samples in it.
\n", 130 | "We downsample both training and testing, including only the first 2,000 and 500 galaxies respectively.
\n", 131 | "Here we also use a validation set of 500 galaxies, taken from the training set." 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "train_limit_points = 2000\n", 141 | "valid_limit_points = 500\n", 142 | "test_limit_points = 500\n", 143 | "\n", 144 | "x_train = np.load(file='data/teddy_x_train.npy')[:train_limit_points, :]\n", 145 | "x_validation = np.load(file='data/teddy_x_train.npy')[train_limit_points:train_limit_points + valid_limit_points, :]\n", 146 | "x_test = np.load(file='data/teddy_x_test.npy')[:test_limit_points, :]\n", 147 | "\n", 148 | "z_train = np.load(file='data/teddy_y_train.npy')[:train_limit_points]\n", 149 | "z_validation = np.load(file='data/teddy_y_train.npy')[train_limit_points:train_limit_points + valid_limit_points]\n", 150 | "z_test = np.load(file='data/teddy_y_test.npy')[:test_limit_points]" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "## Running Flexcode" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "As any $\\texttt{scikit-learn}$ model, with the first call we initialize the model.
\n", 165 | "We need to specify the following:\n", 166 | "* regression model hyper-parameters; here we include those for Random Forest:\n", 167 | " - the number of trees `n_estimators`, \n", 168 | " - the maximum depth of a tree with `max_depth` and \n", 169 | " - the splitting criterion `criterion`.\n", 170 | "* we also pass in the basis system and the maximum number of basis we want Flexcode to consider. Currently, Flexcode will automatically select the best number of basis according to the training data available." 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "n_estimators = 100\n", 180 | "criterion = 'mse'\n", 181 | "max_depth = 5\n", 182 | "\n", 183 | "max_basis = 31\n", 184 | "basis_system = 'cosine'\n", 185 | "\n", 186 | "model = flexcode.FlexCodeModel(RandomForest, max_basis=max_basis, basis_system=basis_system,\n", 187 | " regression_params={'max_depth': max_depth, 'n_estimators': n_estimators, \n", 188 | " 'criterion': criterion})" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "We then train the model by using the `train` method" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "model.fit(x_train, z_train)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [ 211 | "For prediction, we only need to specify the number of points `n_grid` in the CDE support, i.e. the grid over which we want the CDE to be predicted.
\n", 212 | "Flexcode creates a grid with that number of points between the minimum and maximum of the response in the training data." 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "n_grid = 1000\n", 222 | "cde_test, z_grid = model.predict(x_test, n_grid=n_grid)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "model.__dict__" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "For Flexcode CDE predictions, `cde_test` is a numpy array, but the actual full conditional density estimates is completely identified by the `n_basis` coefficients. In other words, one can achieve any resolution by just storing `n_basis` floats.
\n", 239 | "Densities are also normalized, i.e. they integrate to 1." 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "from scipy.integrate import simps\n", 249 | "\n", 250 | "print(type(cde_test), cde_test.shape)\n", 251 | "\n", 252 | "den_integral = simps(cde_test[0, :], x=np.linspace(model.z_min[0], model.z_max[0], n_grid))\n", 253 | "print('Integral of the first density integrates to: %.2f' % den_integral)" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": {}, 259 | "source": [ 260 | "### Visualize Predicted CDEs" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "metadata": {}, 266 | "source": [ 267 | "We can calculate the CDE loss function importing the function from the [`cdetools` package](https://github.com/tpospisi/cdetools)." 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": null, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "cde_loss_val, std_cde_loss = cde_loss(cde_test, z_grid, z_test)\n", 277 | "print('CDE Loss: %4.2f \\pm %.2f' % (cde_loss_val, std_cde_loss))" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "We here visualize the first 12 CDEs" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": null, 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "fig = plt.figure(figsize=(30, 20))\n", 294 | "for jj, cde_predicted in enumerate(cde_test[:12,:]):\n", 295 | " ax = fig.add_subplot(3, 4, jj + 1)\n", 296 | " plt.plot(z_grid, cde_predicted, label=r'$\\hat{p}(z| x_{\\rm obs})$')\n", 297 | " plt.axvline(z_test[jj], color='red', label=r'$z_{\\rm obs}$')\n", 298 | " plt.xticks(size=16)\n", 299 | " plt.yticks(size=16)\n", 300 | " plt.xlabel(r'Redshift $z$', size=20)\n", 301 | " plt.ylabel('CDE', size=20)\n", 302 | " plt.legend(loc='upper right', prop={'size': 20})\n", 303 | "plt.show()" 304 | ] 305 | }, 306 | { 307 | "cell_type": "markdown", 308 | "metadata": {}, 309 | "source": [ 310 | "### Remove Basis Bumps" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": {}, 316 | "source": [ 317 | "Basys systems can create artificial \"bumps\" in the conditional distribution.
\n", 318 | "Flexcode allows to remove those bumps by selecting a threshold in density. The best threshold is chosen according to its CDE loss.
\n", 319 | "We can use the method `tune` and pass an array of thresholds." 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "threshold_vec = [5e-2, 1e-1, 5e-1]\n", 329 | "model.tune(x_validation, z_validation, bump_threshold_grid=threshold_vec, n_grid=n_grid)\n", 330 | "print('Best Bump Removal Threshold: %s' % model.bump_threshold)" 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "metadata": {}, 336 | "source": [ 337 | "The model saves internally the best bump thresholds, and so the next time the `predict` method is called the density will be adjusted accordingly." 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "cde_test, z_grid = model.predict(x_test, n_grid=n_grid)" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "cde_loss_val, std_cde_loss = cde_loss(cde_test, z_grid, z_test)\n", 356 | "print('CDE Loss: %4.2f \\pm %.2f' % (cde_loss_val, std_cde_loss))" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": null, 362 | "metadata": {}, 363 | "outputs": [], 364 | "source": [ 365 | "fig = plt.figure(figsize=(30, 20))\n", 366 | "for jj, cde_predicted in enumerate(cde_test[:12,:]):\n", 367 | " ax = fig.add_subplot(3, 4, jj + 1)\n", 368 | " plt.plot(z_grid, cde_predicted, label=r'$\\hat{p}(z| x_{\\rm obs})$')\n", 369 | " plt.axvline(z_test[jj], color='red', label=r'$z_{\\rm obs}$')\n", 370 | " plt.xticks(size=16)\n", 371 | " plt.yticks(size=16)\n", 372 | " plt.xlabel(r'Redshift $z$', size=20)\n", 373 | " plt.ylabel('CDE', size=20)\n", 374 | " plt.legend(loc='upper right', prop={'size': 20})\n", 375 | "plt.show()" 376 | ] 377 | }, 378 | { 379 | "cell_type": "markdown", 380 | "metadata": {}, 381 | "source": [ 382 | "## Use Custom Regression Models" 383 | ] 384 | }, 385 | { 386 | "cell_type": "markdown", 387 | "metadata": {}, 388 | "source": [ 389 | "Flexcode can be used with any scikit-learn-compatible regression model.
\n", 390 | "Here we show that the internal implementation of XGBoost (Gradient Boosted Trees Regression) and the XGBoost regressor from the `xgboost` package via Flexcode `CustomModel` yield identical results." 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": null, 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "import xgboost as xgb\n", 400 | "from flexcode.regression_models import XGBoost, CustomModel" 401 | ] 402 | }, 403 | { 404 | "cell_type": "markdown", 405 | "metadata": {}, 406 | "source": [ 407 | "Flexcode XGBoost Implementation" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": null, 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "# Parameterize model\n", 417 | "model = flexcode.FlexCodeModel(XGBoost, max_basis=31, basis_system=\"cosine\",\n", 418 | " regression_params={'max_depth': 3, 'learning_rate': 0.5, 'objective': 'reg:linear'})\n", 419 | "\n", 420 | "# Fit and tune model\n", 421 | "model.fit(x_train, z_train)\n", 422 | "cdes_predict_xgb, z_grid = model.predict(x_test, n_grid=1000)" 423 | ] 424 | }, 425 | { 426 | "cell_type": "markdown", 427 | "metadata": {}, 428 | "source": [ 429 | "Flexcode Custom Model, using `XGBRegressor`" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "metadata": {}, 436 | "outputs": [], 437 | "source": [ 438 | "from xgboost import XGBRegressor\n", 439 | "\n", 440 | "model_c = flexcode.FlexCodeModel(CustomModel, max_basis=31, basis_system=\"cosine\",\n", 441 | " regression_params={'max_depth': 3, 'learning_rate': 0.5, 'objective': 'reg:linear'},\n", 442 | " custom_model=XGBRegressor)\n", 443 | "\n", 444 | "# Fit and tune model\n", 445 | "model_c.fit(x_train, z_train)\n", 446 | "cdes_predict_custom, z_grid = model_c.predict(x_test, n_grid=1000)" 447 | ] 448 | }, 449 | { 450 | "cell_type": "markdown", 451 | "metadata": {}, 452 | "source": [ 453 | "We calculate the largest discrepancy between the two sets of predicted CDEs" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": null, 459 | "metadata": {}, 460 | "outputs": [], 461 | "source": [ 462 | "np.max(np.abs(cdes_predict_custom - cdes_predict_xgb))" 463 | ] 464 | }, 465 | { 466 | "cell_type": "markdown", 467 | "metadata": {}, 468 | "source": [ 469 | "## CDE Diagnostics" 470 | ] 471 | }, 472 | { 473 | "cell_type": "markdown", 474 | "metadata": {}, 475 | "source": [ 476 | "We can compute the PIT and HPD values over our estimated CDEs.
\n", 477 | "Functions to calculate both values can be found in the [`cdetools` package](https://github.com/tpospisi/cdetools).\n", 478 | "
\n", 479 | "We suggest to clone the Github repository and install it.\n", 480 | "\n", 481 | "git clone https://github.com/tpospisi/cdetools.git
\n", 482 | "cd cdetools/python/
\n", 483 | "python setup.py install " 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "metadata": {}, 490 | "outputs": [], 491 | "source": [ 492 | "from cdetools.hpd_coverage import hpd_coverage\n", 493 | "from cdetools.cdf_coverage import cdf_coverage\n", 494 | "from cdetools.plot_utils import plot_with_uniform_band\n", 495 | "\n", 496 | "# Computing the values\n", 497 | "z_grid = np.linspace(z_train.min(), z_train.max(), n_grid)\n", 498 | "pit_values = cdf_coverage(cde_test, z_grid, z_test)\n", 499 | "hpd_values = hpd_coverage(cde_test, z_grid, z_test)" 500 | ] 501 | }, 502 | { 503 | "cell_type": "markdown", 504 | "metadata": {}, 505 | "source": [ 506 | "Computing the number of values per each bin in the histogram under uniformity assumptions.
\n", 507 | "We look at the 99% CI." 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": null, 513 | "metadata": {}, 514 | "outputs": [], 515 | "source": [ 516 | "fig_pit = plot_with_uniform_band(values=pit_values, ci_level=0.99, x_label='PIT Values', n_bins=30)\n", 517 | "fig_pit" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": null, 523 | "metadata": {}, 524 | "outputs": [], 525 | "source": [ 526 | "fig_hpd = plot_with_uniform_band(values=hpd_values, ci_level=0.99, x_label='HPD Values', n_bins=30) \n", 527 | "fig_hpd" 528 | ] 529 | } 530 | ], 531 | "metadata": { 532 | "kernelspec": { 533 | "display_name": "Python 3", 534 | "language": "python", 535 | "name": "python3" 536 | }, 537 | "language_info": { 538 | "codemirror_mode": { 539 | "name": "ipython", 540 | "version": 3 541 | }, 542 | "file_extension": ".py", 543 | "mimetype": "text/x-python", 544 | "name": "python", 545 | "nbconvert_exporter": "python", 546 | "pygments_lexer": "ipython3", 547 | "version": "3.7.1" 548 | } 549 | }, 550 | "nbformat": 4, 551 | "nbformat_minor": 4 552 | } 553 | -------------------------------------------------------------------------------- /vignettes/Custom Class.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This notebook provides an example on how to use a custom class within Flexcode.
\n", 8 | "In order to be compatible, a regression method needs to have a `fit` and `predict` method implemented - i.e. \n", 9 | "`model.fit()` and `model.predict()` need to be the functions used for training and predicting respectively.\n", 10 | "\n", 11 | "We provide here an example with artifical data.
\n", 12 | "We compare the FlexZBoost (Flexcode with builtin XGBoost) with the custom class of FLexcode when passing\n", 13 | "XGBoost Regressor. The two should give basically identical results." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import flexcode\n", 23 | "import numpy as np\n", 24 | "import xgboost as xgb\n", 25 | "from flexcode.regression_models import XGBoost, CustomModel" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## Data Creation" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "def generate_data(n_draws):\n", 42 | " x = np.random.normal(0, 1, n_draws)\n", 43 | " z = np.random.normal(x, 1, n_draws)\n", 44 | " return x, z\n", 45 | "\n", 46 | "x_train, z_train = generate_data(5000)\n", 47 | "x_validation, z_validation = generate_data(5000)\n", 48 | "x_test, z_test = generate_data(5000)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "## FlexZBoost" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# Parameterize model\n", 65 | "model = flexcode.FlexCodeModel(XGBoost, max_basis=31, basis_system=\"cosine\",\n", 66 | " regression_params={'max_depth': 3, 'learning_rate': 0.5, 'objective': 'reg:linear'})\n", 67 | "\n", 68 | "# Fit and tune model\n", 69 | "model.fit(x_train, z_train)\n", 70 | "\n", 71 | "cdes_predict_xgb, z_grid = model.predict(x_test, n_grid=200)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "model.__dict__" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "import pickle\n", 90 | "\n", 91 | "pickle.dump(file=open('example.pkl', 'wb'), obj=model, \n", 92 | " protocol=pickle.HIGHEST_PROTOCOL)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "model = pickle.load(open('example.pkl', 'rb'))\n", 102 | "model.__dict__" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "cdes_predict_xgb, z_grid = model.predict(x_test, n_grid=200)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "## Custom Model" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "Our custom model in this case is going to be XGBRegressor.
\n", 126 | "The only difference with the above is that we are going to use the `CustomModel` class and we are going to pass\n", 127 | "XGBRegressor as `custom_model`.\n", 128 | "After that, everything is exactly as above.
\n", 129 | "\n", 130 | "Parameters can be passed also in the same way as above." 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "# Parameterize model\n", 140 | "my_model = xgb.XGBRegressor\n", 141 | "model_c = flexcode.FlexCodeModel(CustomModel, max_basis=31, basis_system=\"cosine\",\n", 142 | " regression_params={'max_depth': 3, 'learning_rate': 0.5, 'objective': 'reg:linear'},\n", 143 | " custom_model=my_model)\n", 144 | "\n", 145 | "# Fit and tune model\n", 146 | "model_c.fit(x_train, z_train)\n", 147 | "cdes_predict_custom, z_grid = model_c.predict(x_test, n_grid=200)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "The two conditional density estimates should be the same across the board.
\n", 155 | "We check the maximum difference in absolute value between the two." 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "np.max(np.abs(cdes_predict_custom - cdes_predict_xgb))" 165 | ] 166 | } 167 | ], 168 | "metadata": { 169 | "kernelspec": { 170 | "display_name": "Python 3", 171 | "language": "python", 172 | "name": "python3" 173 | }, 174 | "language_info": { 175 | "codemirror_mode": { 176 | "name": "ipython", 177 | "version": 3 178 | }, 179 | "file_extension": ".py", 180 | "mimetype": "text/x-python", 181 | "name": "python", 182 | "nbconvert_exporter": "python", 183 | "pygments_lexer": "ipython3", 184 | "version": "3.5.2" 185 | } 186 | }, 187 | "nbformat": 4, 188 | "nbformat_minor": 2 189 | } 190 | -------------------------------------------------------------------------------- /vignettes/Model Save and Bumps Removal - Flexcode.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This notebook is used to show:\n", 8 | "1. How to save and reload the Flexcode model so not to have to re-train it every time;\n", 9 | "2. How to select the best bump removal parameter from an array of potential value" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import flexcode\n", 19 | "import numpy as np\n", 20 | "import xgboost as xgb\n", 21 | "from flexcode.regression_models import XGBoost, CustomModel\n", 22 | "\n", 23 | "from matplotlib import pyplot as plt\n", 24 | "%matplotlib inline" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "## Data Creation" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "def generate_data(n_draws):\n", 41 | " x = np.random.normal(0, 1, n_draws)\n", 42 | " z = np.random.normal(x, 1, n_draws)\n", 43 | " return x, z\n", 44 | "\n", 45 | "x_train, z_train = generate_data(1000)\n", 46 | "x_validation, z_validation = generate_data(1000)\n", 47 | "x_test, z_test = generate_data(1000)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "## Saving and Reload Model" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "# Parameterize model\n", 64 | "model = flexcode.FlexCodeModel(XGBoost, max_basis=31, basis_system=\"cosine\",\n", 65 | " regression_params={'max_depth': 3, 'learning_rate': 0.5, 'objective': 'reg:linear'})\n", 66 | "\n", 67 | "# Fit model and predict on test data\n", 68 | "model.fit(x_train, z_train)\n", 69 | "cdes_predict_xgb, z_grid = model.predict(x_test, n_grid=200)\n", 70 | "\n", 71 | "# Show output some general values of the first two predictions\n", 72 | "# for further check\n", 73 | "print(np.max(cdes_predict_xgb[7, :]))\n", 74 | "print(np.max(cdes_predict_xgb[42, :]))" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "# Print model parameters\n", 84 | "model.__dict__" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "import pickle\n", 94 | "\n", 95 | "# Saving the model\n", 96 | 
"pickle.dump(file=open('flexcode_model.pkl', 'wb'), obj=model, \n", 97 | " protocol=pickle.HIGHEST_PROTOCOL)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "# Reaload the model\n", 107 | "model_reloaded = pickle.load(open('flexcode_model.pkl', 'rb'))\n", 108 | "\n", 109 | "# Predict again\n", 110 | "cdes_predict_xgb_reloaded, z_grid = model_reloaded.predict(x_test, n_grid=200)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "# Same output as above\n", 120 | "print(np.max(cdes_predict_xgb_reloaded[7, :]))\n", 121 | "print(np.max(cdes_predict_xgb_reloaded[42, :]))" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "# Print parameters of the reloaded model\n", 131 | "model_reloaded.__dict__" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "## Spurious Bump Removal - Tune Using Validation Data" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "# Define a grid of values to tune over\n", 148 | "bump_removal_grid = np.linspace(0.01, 0.2, 20)\n", 149 | "print(bump_removal_grid)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "# Parameterize model\n", 159 | "model = flexcode.FlexCodeModel(XGBoost, max_basis=31, basis_system=\"cosine\",\n", 160 | " regression_params={'max_depth': 3, 'learning_rate': 0.5, 'objective': 'reg:linear'})\n", 161 | "\n", 162 | "# Fit model\n", 163 | "model.fit(x_train, z_train)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "# Tune the model - the bump removal grid is passed directly \n", 173 | "# in the tune function\n", 174 | "\n", 175 | "model.tune(x_validation, z_validation, \n", 176 | " bump_threshold_grid=bump_removal_grid)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "Under the hood, it selects the bump value corresponding to the smallest CDE loss on the validation data.\n", 184 | "\n", 185 | "The best value is accessible among the attributes of the model, as below:" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "model.bump_threshold" 195 | ] 196 | } 197 | ], 198 | "metadata": { 199 | "kernelspec": { 200 | "display_name": "Python 3", 201 | "language": "python", 202 | "name": "python3" 203 | }, 204 | "language_info": { 205 | "codemirror_mode": { 206 | "name": "ipython", 207 | "version": 3 208 | }, 209 | "file_extension": ".py", 210 | "mimetype": "text/x-python", 211 | "name": "python", 212 | "nbconvert_exporter": "python", 213 | "pygments_lexer": "ipython3", 214 | "version": "3.5.2" 215 | } 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 2 219 | } 220 | --------------------------------------------------------------------------------