├── .binder ├── environment.yml └── postBuild ├── .codecov.yml ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── documentation_improvement.yml │ └── feature_request.yml ├── dependabot.yml ├── pull_request_template.md └── workflows │ ├── pre-commit.yaml │ ├── tests-workflow.yaml │ ├── tox-workflow.yaml │ └── wheels-workflow.yaml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── .zenodo.json ├── CONTRIBUTING.rst ├── COPYING ├── MANIFEST.in ├── README.rst ├── appveyor.yml ├── ci ├── appveyor │ ├── py310.ps1 │ ├── py311.ps1 │ ├── py312.ps1 │ └── py313.ps1 ├── deps │ ├── py310.sh │ ├── py311.sh │ ├── py312.sh │ ├── py313.sh │ └── requirements.yaml.tmpl ├── nb_sanitize.cfg ├── render-requirements.py ├── run_tests.sh ├── setup_conda.sh └── setup_env.sh ├── doc ├── .gitignore ├── Makefile ├── _static │ ├── custom.css │ ├── github-stats.js │ └── images │ │ └── censoring.svg ├── _templates │ └── navbar-github-links.html ├── api │ ├── compare.rst │ ├── datasets.rst │ ├── ensemble.rst │ ├── functions.rst │ ├── index.rst │ ├── io.rst │ ├── kernels.rst │ ├── linear_model.rst │ ├── meta.rst │ ├── metrics.rst │ ├── nonparametric.rst │ ├── preprocessing.rst │ ├── svm.rst │ ├── tree.rst │ └── util.rst ├── cite.rst ├── conf.py ├── contributing.rst ├── index.rst ├── install.rst ├── release_notes.rst ├── release_notes │ ├── v0.1.rst │ ├── v0.10.rst │ ├── v0.11.rst │ ├── v0.12.rst │ ├── v0.13.rst │ ├── v0.14.rst │ ├── v0.15.rst │ ├── v0.16.rst │ ├── v0.17.rst │ ├── v0.18.rst │ ├── v0.19.rst │ ├── v0.2.rst │ ├── v0.20.rst │ ├── v0.21.rst │ ├── v0.22.rst │ ├── v0.23.rst │ ├── v0.24.rst │ ├── v0.3.rst │ ├── v0.4.rst │ ├── v0.5.rst │ ├── v0.6.rst │ ├── v0.7.rst │ ├── v0.8.rst │ └── v0.9.rst ├── spelling_wordlist.txt └── user_guide │ ├── 00-introduction.ipynb │ ├── boosting.ipynb │ ├── competing-risks.ipynb │ ├── coxnet.ipynb │ ├── evaluating-survival-models.ipynb │ ├── index.rst │ ├── random-survival-forest.ipynb │ ├── survival-svm.ipynb │ └── 
understanding_predictions.rst ├── pyproject.toml ├── setup.py ├── sksurv ├── __init__.py ├── base.py ├── bintrees │ ├── __init__.py │ ├── _binarytrees.pyx │ ├── binarytrees.cpp │ └── binarytrees.h ├── column.py ├── compare.py ├── datasets │ ├── __init__.py │ ├── base.py │ └── data │ │ ├── GBSG2.arff │ │ ├── README.md │ │ ├── actg320.arff │ │ ├── bmt.arff │ │ ├── breast_cancer_GSE7390-metastasis.arff │ │ ├── cgvhd.arff │ │ ├── flchain.arff │ │ ├── veteran.arff │ │ └── whas500.arff ├── ensemble │ ├── __init__.py │ ├── _coxph_loss.pyx │ ├── boosting.py │ ├── forest.py │ └── survival_loss.py ├── exceptions.py ├── functions.py ├── io │ ├── __init__.py │ ├── arffread.py │ └── arffwrite.py ├── kernels │ ├── __init__.py │ ├── _clinical_kernel.pyx │ └── clinical.py ├── linear_model │ ├── __init__.py │ ├── _coxnet.pyx │ ├── aft.py │ ├── coxnet.py │ ├── coxph.py │ └── src │ │ ├── coxnet │ │ ├── constants.h │ │ ├── coxnet.h │ │ ├── data.h │ │ ├── error.h │ │ ├── fit_params.h │ │ ├── fit_result.h │ │ ├── ordered_dict.h │ │ ├── parameters.h │ │ └── soft_threshold.h │ │ └── coxnet_wrapper.h ├── meta │ ├── __init__.py │ ├── base.py │ ├── ensemble_selection.py │ └── stacking.py ├── metrics.py ├── nonparametric.py ├── preprocessing.py ├── svm │ ├── __init__.py │ ├── _minlip.pyx │ ├── _prsvm.pyx │ ├── minlip.py │ ├── naive_survival_svm.py │ └── survival_svm.py ├── testing.py ├── tree │ ├── __init__.py │ ├── _criterion.pyx │ └── tree.py └── util.py ├── tests ├── conftest.py ├── data │ ├── Lagakos_AIDS_adults.csv │ ├── Lagakos_AIDS_children.csv │ ├── breast_cancer_glmnet_coefficients.csv │ ├── breast_cancer_glmnet_coefficients_high.csv │ ├── cgvhd_aalen.npy │ ├── cgvhd_delta.npy │ ├── cgvhd_dinse.npy │ ├── channing.csv │ ├── compnentwise-gradient-boosting-coxph-cumhazard.csv │ ├── compnentwise-gradient-boosting-coxph-surv.csv │ ├── cox-example-coef-1-pf.csv │ ├── cox-example-coef-1-pf2.csv │ ├── cox-example-coef-1-unpen.csv │ ├── cox-example-coef-1.csv │ ├── 
cox-example-coef-2-alpha.csv │ ├── cox-example-coef-2-nalpha-norm.csv │ ├── cox-example-coef-2-nalpha.csv │ ├── cox-example-coef-2-norm.csv │ ├── cox-example-coef-2-std.csv │ ├── cox-example-coef-2.csv │ ├── cox-example.csv │ ├── cox-simple-coef.csv │ ├── gradient-boosting-coxph-cumhazard.csv │ ├── gradient-boosting-coxph-surv.csv │ ├── rossi.csv │ ├── whas500-noties.arff │ └── whas500_predictions.csv ├── test_aft.py ├── test_binarytrees.py ├── test_boosting.py ├── test_clinical_kernel.py ├── test_column.py ├── test_common.py ├── test_compare.py ├── test_coxnet.py ├── test_coxph.py ├── test_datasets.py ├── test_ensemble_selection.py ├── test_forest.py ├── test_functions.py ├── test_io.py ├── test_metrics.py ├── test_minlip.py ├── test_nonparametric.py ├── test_pandas_inputs.py ├── test_preprocessing.py ├── test_show_versions.py ├── test_stacking.py ├── test_survival_function.py ├── test_survival_svm.py ├── test_tree.py └── test_util.py └── tox.ini /.binder/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | - defaults 4 | dependencies: 5 | - python=3.11 6 | - pip 7 | - scikit-survival 8 | - matplotlib~=3.8.0 9 | - numpy 10 | - seaborn==0.11.2 11 | -------------------------------------------------------------------------------- /.binder/postBuild: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | # Back up notebooks 6 | TMP_CONTENT_DIR=/tmp/scikit-survival 7 | mkdir -p $TMP_CONTENT_DIR 8 | cp -r doc/user_guide .binder $TMP_CONTENT_DIR 9 | # delete everything in current directory including dot files and dot folders 10 | find . 
-delete 11 | 12 | # Copy notebooks and remove other files from user_guide folder 13 | NOTEBOOKS_DIR=notebooks 14 | cp -r $TMP_CONTENT_DIR/user_guide $NOTEBOOKS_DIR 15 | find $NOTEBOOKS_DIR -not -name '*.ipynb' -type f -delete 16 | 17 | # Put the .binder folder back (may be useful for debugging purposes) 18 | mv $TMP_CONTENT_DIR/.binder . 19 | # Final clean up 20 | rm -rf $TMP_CONTENT_DIR 21 | -------------------------------------------------------------------------------- /.codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | disable_default_path_fixes: true 3 | fixes: 4 | - ".*/dist-packages/::" 5 | - ".*/site-packages/::" 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: File a bug report. 3 | title: "Bug: " 4 | labels: ["bug", "needs triage"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thanks for taking the time to fill out this bug report! 10 | 11 | > [!IMPORTANT] 12 | > Before submitting a bug, please make sure the issue hasn't been addressed already 13 | > by searching through [the past issues](https://github.com/sebp/scikit-survival/issues). 14 | - type: textarea 15 | id: description 16 | attributes: 17 | label: Describe the bug 18 | description: | 19 | Please provide a clear and concise description of what the bug is. 20 | validations: 21 | required: true 22 | - type: textarea 23 | id: example 24 | attributes: 25 | label: Steps/Code to Reproduce 26 | description: | 27 | Please add a *minimal, reproducible example* that can reproduce the error when running it. 
28 | 29 | Be as succinct as possible, **do not depend on external data files**, instead use synthetically generated data or one of the datasets provided by [sksurv.datasets]( 30 | https://scikit-survival.readthedocs.io/en/latest/api/datasets.html). 31 | 32 | In short, **we are going to copy-paste your code** to run it and we expect to get the same result as you. 33 | 34 | Please follow [this guide](https://matthewrocklin.com/minimal-bug-reports) on how to 35 | provide a minimal, reproducible example. 36 | 37 | Example: 38 | ```python 39 | from sksurv.datasets import load_whas500 40 | from sksurv.preprocessing import OneHotEncoder 41 | from sksurv.linear_model import CoxPHSurvivalAnalysis 42 | 43 | X, y = load_whas500() 44 | 45 | features = OneHotEncoder().fit_transform(X) 46 | cph_model = CoxPHSurvivalAnalysis(alphas=[0.01, 0.1, 1.0]) 47 | cph_model.fit(features, y) 48 | ``` 49 | placeholder: | 50 | ``` 51 | Sample code to reproduce the problem 52 | ``` 53 | validations: 54 | required: true 55 | - type: textarea 56 | id: actual-result 57 | attributes: 58 | label: Actual Results 59 | description: | 60 | Please provide verbose output that clearly demonstrates the problem the reproducible example shows. 61 | 62 | If you observe an error, please paste the error message including the **full traceback** of the exception. For instance the code above raises the following exception: 63 | 64 | ```python-traceback 65 | --------------------------------------------------------------------------- 66 | TypeError Traceback (most recent call last) 67 | File my_bug_report.py:8 68 | 5 X, y = load_whas500() 69 | 7 features = OneHotEncoder().fit_transform(X) 70 | ----> 8 cph_model = CoxPHSurvivalAnalysis(alphas=[0.01, 0.1, 1.0]) 71 | 9 cph_model.fit(features, y) 72 | 73 | TypeError: CoxPHSurvivalAnalysis.__init__() got an unexpected keyword argument 'alphas'. Did you mean 'alpha'? 74 | ``` 75 | placeholder: | 76 | Please paste or specifically describe the actual result or traceback. 
77 | validations: 78 | required: true 79 | - type: textarea 80 | id: expected-results 81 | attributes: 82 | label: Expected Results 83 | description: | 84 | Please describe the expected results. 85 | validations: 86 | required: true 87 | - type: textarea 88 | id: version 89 | attributes: 90 | label: Installed Versions 91 | render: shell 92 | description: | 93 | Please execute the code below and paste the output below. 94 | 95 | ```python 96 | import sksurv; sksurv.show_versions() 97 | ``` 98 | validations: 99 | required: true 100 | - type: markdown 101 | attributes: 102 | value: | 103 | Thanks for contributing 🎉! -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation_improvement.yml: -------------------------------------------------------------------------------- 1 | name: Documentation improvement 2 | description: Report wrong or missing documentation. 3 | title: "DOC: " 4 | labels: ["documentation", "needs triage"] 5 | body: 6 | - type: textarea 7 | id: location 8 | attributes: 9 | label: Location of the documentation 10 | description: > 11 | Please provide the location of the documentation, e.g. "sksurv.metrics.brier_score" or the 12 | URL of the documentation, e.g. 13 | "https://scikit-survival.readthedocs.io/en/latest/api/generated/sksurv.metrics.brier_score.html" 14 | validations: 15 | required: true 16 | - type: textarea 17 | id: problem 18 | attributes: 19 | label: Documentation Problem 20 | description: | 21 | Please provide a description of what documentation you believe needs to be fixed/improved. 22 | validations: 23 | required: true 24 | - type: textarea 25 | id: suggested-fix 26 | attributes: 27 | label: Suggested Fix for Documentation 28 | description: | 29 | Please tell us how we could improve the documentation to resolve the problem. 
30 | validations: 31 | required: true 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Suggest an idea for scikit-survival. 3 | title: "ENH: " 4 | labels: ["enhancement", "needs triage"] 5 | body: 6 | - type: checkboxes 7 | id: checks 8 | attributes: 9 | label: Feature Type 10 | description: Please check what type of feature request you would like to propose. 11 | options: 12 | - label: | 13 | Adding new functionality to scikit-survival 14 | - label: | 15 | Changing existing functionality in scikit-survival 16 | - label: | 17 | Removing existing functionality in scikit-survival 18 | - type: textarea 19 | id: description 20 | attributes: 21 | label: Problem Description 22 | description: | 23 | Please describe what problem the feature would solve or which workflow it would enable, e.g. "I wish scikit-survival would be able to ..." 24 | validations: 25 | required: true 26 | - type: textarea 27 | id: feature 28 | attributes: 29 | label: Feature Description 30 | description: | 31 | Please describe clearly and concisly how the new feature would be implemented. Use pseudocode if relevant. 32 | validations: 33 | required: true 34 | - type: textarea 35 | id: alternative 36 | attributes: 37 | label: Alternative Solutions 38 | description: | 39 | Please describe any alternative solution (existing functionality, 3rd party package, etc.) 40 | that would satisfy the feature request. 41 | - type: textarea 42 | id: references 43 | attributes: 44 | label: References and existing implementations 45 | description: | 46 | If you want to propose a new algorithm for inclusion, please include the original reference to the publication that first proposed the algorithm. If you are aware of any existing implementations, e.g. for R, please include a link too. 
47 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | - package-ecosystem: "pip" 8 | directory: "/" 9 | schedule: 10 | interval: "weekly" 11 | allow: 12 | - dependency-name: "scikit-learn" 13 | versioning-strategy: increase-if-necessary 14 | 15 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 5 | 6 | **Checklist** 7 | 8 | 9 | - [ ] closes #xxxx 10 | - [ ] pytest passes 11 | - [ ] tests are included 12 | - [ ] code is well formatted 13 | - [ ] documentation renders correctly 14 | 15 | **What does this implement/fix? Explain your changes** 16 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yaml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [master] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - uses: actions/setup-python@v5 14 | with: 15 | python-version: '3.x' 16 | - uses: pre-commit/action@v3.0.1 17 | -------------------------------------------------------------------------------- /.github/workflows/tox-workflow.yaml: -------------------------------------------------------------------------------- 1 | name: Tox 2 | on: [push, pull_request] 3 | 4 | jobs: 5 | linting_and_docs: 6 | runs-on: ubuntu-latest 7 | name: Linting and Docs 8 | steps: 9 | - name: Checkout 10 | uses: actions/checkout@v4 11 | with: 12 | submodules: true 13 | - name: Install third-party tools 14 | run: | 15 | sudo apt-get install cmake 
libenchant-2-dev 16 | eval "$(/home/linuxbrew/.linuxbrew/bin/brew shellenv)" 17 | brew install pandoc 18 | - name: Setup Python 19 | uses: astral-sh/setup-uv@v6 20 | with: 21 | version: latest 22 | python-version: '3.11' 23 | cache-dependency-glob: | 24 | **/pyproject.toml 25 | **/tox.ini 26 | - name: Install Tox 27 | run: uv tool install tox --with tox-uv 28 | - name: Run Tox 29 | env: 30 | RUFF_OUTPUT_FORMAT: github 31 | run: | 32 | eval "$(/home/linuxbrew/.linuxbrew/bin/brew shellenv)" 33 | tox run -e ${{ matrix.tox_env }} 34 | - name: Archive documentation 35 | uses: actions/upload-artifact@v4 36 | with: 37 | name: documentation 38 | path: doc/_build/html 39 | compression-level: 9 40 | if: ${{ matrix.tox_env == 'docs' }} 41 | - name: Print debug information 42 | run: cat .tox/${{ matrix.tox_env }}/log/*.log 43 | if: ${{ failure() }} 44 | strategy: 45 | fail-fast: false 46 | matrix: 47 | tox_env: [lint, docs] 48 | -------------------------------------------------------------------------------- /.github/workflows/wheels-workflow.yaml: -------------------------------------------------------------------------------- 1 | name: Build Wheels 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | inputs: 8 | publish_wheel: 9 | description: 'Whether to publish wheels' 10 | default: false 11 | type: boolean 12 | 13 | 14 | jobs: 15 | build_wheels: 16 | name: Build wheels 📦 on ${{ matrix.os }} 17 | runs-on: ${{ matrix.os }} 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | os: [ubuntu-22.04, windows-2019, macos-13, macos-14] 22 | python: [310, 311, 312, 313] 23 | 24 | steps: 25 | - uses: actions/checkout@v4 26 | with: 27 | fetch-depth: 0 28 | submodules: true 29 | 30 | - name: Set platform 31 | id: platform 32 | shell: bash 33 | run: | 34 | if [ ${{ runner.os }} = 'Linux' ]; then 35 | echo "build_platform=manylinux" >> $GITHUB_OUTPUT 36 | elif [ ${{ runner.os }} = 'macOS' ]; then 37 | echo "build_platform=macosx" >> $GITHUB_OUTPUT 38 | elif [ ${{ 
runner.os }} = 'Windows' ]; then 39 | echo "build_platform=win_amd64" >> $GITHUB_OUTPUT 40 | fi 41 | 42 | - name: Build wheels 43 | uses: pypa/cibuildwheel@v2.23.3 44 | env: 45 | CIBW_ARCHS: auto64 46 | CIBW_BUILD: cp${{ matrix.python }}-${{ steps.platform.outputs.build_platform }}* 47 | CIBW_MANYLINUX_X86_64_IMAGE: manylinux2014 48 | CIBW_BUILD_VERBOSITY: 1 49 | CIBW_ENVIRONMENT_MACOS: MACOSX_DEPLOYMENT_TARGET=${{ runner.arch == 'ARM64' && '11.0' || '10.13' }} 50 | CIBW_TEST_COMMAND: pytest --strict-markers -m 'not slow' -k 'not test_fit_and_predict_linear_regression' {project}/tests 51 | CIBW_TEST_COMMAND_WINDOWS: pytest --strict-markers -m "not slow" -k "not test_fit_and_predict_linear_regression" {project}\\tests 52 | CIBW_TEST_REQUIRES: pytest 53 | # Skip trying to test arm64 builds on Intel Macs, and vice versa 54 | CIBW_TEST_SKIP: "*-macosx_${{ runner.arch == 'ARM64' && 'x86_64' || 'arm64' }} *-macosx_universal2:arm64" 55 | - uses: actions/upload-artifact@v4 56 | with: 57 | name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} 58 | path: ./wheelhouse/*.whl 59 | if-no-files-found: error 60 | 61 | build_sdist: 62 | name: Build source distribution 📦 63 | runs-on: ubuntu-latest 64 | steps: 65 | - uses: actions/checkout@v4 66 | with: 67 | fetch-depth: 0 68 | submodules: true 69 | 70 | - name: Build sdist 71 | run: pipx run build --sdist 72 | 73 | - uses: actions/upload-artifact@v4 74 | with: 75 | name: cibw-sdist 76 | path: dist/*.tar.gz 77 | 78 | publish-to-testpypi: 79 | name: >- 80 | Publish scikit-surival 📉 distribution 📦 to TestPyPI 81 | if: github.repository == 'sebp/scikit-survival' && github.event_name == 'workflow_dispatch' && inputs.publish_wheel 82 | needs: 83 | - build_wheels 84 | - build_sdist 85 | runs-on: ubuntu-latest 86 | 87 | environment: 88 | name: testpypi-release 89 | url: https://test.pypi.org/p/scikit-survival # TestPyPI project name 90 | 91 | permissions: 92 | id-token: write # IMPORTANT: mandatory for trusted publishing 93 | 94 | 
steps: 95 | - name: Download packages 96 | uses: actions/download-artifact@v4 97 | with: 98 | pattern: cibw-* 99 | path: dist 100 | merge-multiple: true 101 | 102 | - name: Print out packages 103 | run: ls dist 104 | 105 | - name: Publish to TestPyPI 106 | uses: pypa/gh-action-pypi-publish@release/v1 107 | with: 108 | repository-url: https://test.pypi.org/legacy/ 109 | verbose: true 110 | 111 | publish-to-pypi: 112 | name: >- 113 | Publish scikit-surival 📉 distribution 📦 to PyPI 114 | if: github.repository == 'sebp/scikit-survival' && github.event_name == 'release' && startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes 115 | needs: 116 | - build_wheels 117 | - build_sdist 118 | runs-on: ubuntu-latest 119 | 120 | environment: 121 | name: pypi-release 122 | url: https://pypi.org/p/scikit-survival # PyPI project name 123 | 124 | permissions: 125 | id-token: write # IMPORTANT: mandatory for trusted publishing 126 | 127 | steps: 128 | - name: Download packages 129 | uses: actions/download-artifact@v4 130 | with: 131 | pattern: cibw-* 132 | path: dist 133 | merge-multiple: true 134 | 135 | - name: Print out packages 136 | run: ls dist 137 | 138 | - name: Publish to PyPI 139 | uses: pypa/gh-action-pypi-publish@release/v1 140 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Cython generated files 10 | _*.c 11 | _*.cpp 12 | 13 | # Cython debug symbols 14 | cython_debug/ 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these 
files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Django stuff: 61 | *.log 62 | 63 | # PyBuilder 64 | .pybuilder/ 65 | target/ 66 | 67 | # Jupyter Notebook 68 | .ipynb_checkpoints 69 | 70 | # IPython 71 | profile_default/ 72 | ipython_config.py 73 | 74 | # Environments 75 | .env 76 | .venv 77 | env/ 78 | venv/ 79 | ENV/ 80 | env.bak/ 81 | venv.bak/ 82 | 83 | # Spyder project settings 84 | .spyderproject 85 | .spyproject 86 | 87 | # Rope project settings 88 | .ropeproject 89 | 90 | # mypy 91 | .mypy_cache/ 92 | .dmypy.json 93 | dmypy.json 94 | 95 | # Pyre type checker 96 | .pyre/ 97 | 98 | # pytype static type analyzer 99 | .pytype/ 100 | 101 | # PyCharm 102 | .idea/ 103 | 104 | # ruff 105 | .ruff_cache/ 106 | 107 | # LSP config files 108 | pyrightconfig.json 109 | 110 | # VisualStudioCode 111 | .vscode/ 112 | 113 | # macOS General 114 | .DS_Store 115 | .AppleDouble 116 | .LSOverride 117 | 118 | # Icon must end with two \r 119 | Icon 120 | 121 | # Thumbnails 122 | ._* 123 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "sksurv/linear_model/src/eigen"] 2 | path = sksurv/linear_model/src/eigen 3 | url = https://gitlab.com/libeigen/eigen.git 4 | branch = 3.4 5 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | 
# See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: trailing-whitespace 8 | exclude: '^\.gitignore$' 9 | - id: check-added-large-files 10 | - id: check-ast 11 | - id: check-case-conflict 12 | - id: check-docstring-first 13 | - id: check-merge-conflict 14 | - id: check-symlinks 15 | - id: check-toml 16 | - id: check-yaml 17 | - id: debug-statements 18 | - id: destroyed-symlinks 19 | - id: name-tests-test 20 | args: ["--pytest-test-first"] 21 | - repo: https://github.com/pre-commit/pygrep-hooks 22 | rev: v1.10.0 23 | hooks: 24 | - id: python-check-blanket-noqa 25 | - id: python-no-eval 26 | - id: rst-directive-colons 27 | - id: rst-inline-touching-normal 28 | - repo: https://github.com/python-jsonschema/check-jsonschema 29 | rev: 0.31.3 30 | hooks: 31 | - id: check-dependabot 32 | - id: check-github-workflows 33 | - id: check-readthedocs 34 | - repo: https://github.com/Lucas-C/pre-commit-hooks 35 | rev: v1.5.5 36 | hooks: 37 | - id: forbid-crlf 38 | exclude: '^\.gitignore$' 39 | - id: forbid-tabs 40 | types: [python] 41 | - repo: https://github.com/psf/black-pre-commit-mirror 42 | rev: 25.1.0 43 | hooks: 44 | - id: black-jupyter 45 | - repo: https://github.com/astral-sh/ruff-pre-commit 46 | rev: v0.11.1 47 | hooks: 48 | - id: ruff 49 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | python: 4 | install: 5 | - method: pip 6 | path: . 
7 | extra_requirements: 8 | - docs 9 | 10 | build: 11 | os: ubuntu-22.04 12 | tools: 13 | python: "3.11" 14 | 15 | sphinx: 16 | configuration: doc/conf.py 17 | 18 | submodules: 19 | include: all 20 | recursive: true 21 | -------------------------------------------------------------------------------- /.zenodo.json: -------------------------------------------------------------------------------- 1 | { 2 | "license": "GPL-3.0", 3 | "title": "scikit-survival", 4 | "upload_type": "software", 5 | "creators": [ 6 | { 7 | "name": "Sebastian P\u00f6lsterl" 8 | } 9 | ], 10 | "access_right": "open", 11 | "keywords": [ 12 | "survival-analysis", 13 | "machine-learning", 14 | "python", 15 | "scikit-learn" 16 | ] 17 | } 18 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | Contributing to scikit-survival 2 | =============================== 3 | 4 | Pull requests are always welcome, and we appreciate any help you give. 5 | There are many ways to contribute to scikit-survival: 6 | 7 | - Reporting bugs. 8 | - Writing new code, e.g. implementations of new algorithms, or examples. 9 | - Fixing bugs. 10 | - Improving documentation. 11 | - Reviewing open pull requests. 12 | 13 | For detailed instructions on how to get started, please see 14 | the `contributing guidelines `_. 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # by default, the source distribution includes everything that is under version control 2 | graft sksurv/linear_model/src/eigen/Eigen 3 | 4 | prune doc/_build 5 | prune doc/api/generated 6 | prune .binder 7 | prune .github 8 | prune ci 9 | exclude appveyor.yml .codecov.yml .pre-commit-config.yaml .readthedocs.yaml .zenodo.json 10 | 11 | global-exclude __pycache__ 12 | global-exclude .ipynb_checkpoints 13 | global-exclude .git* 14 | # Cython generated files 15 | global-exclude _*.c 16 | global-exclude _*.cpp 17 | global-exclude *.py[oc] 18 | global-exclude *.bak 19 | global-exclude *.swp 20 | global-exclude *~ 21 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | # AppVeyor.com is a Continuous Integration service to build and run tests under 2 | # Windows 3 | image: Visual Studio 2019 4 | 5 | environment: 6 | # see https://www.appveyor.com/docs/build-environment/#python 7 | matrix: 8 | - TARGET_ARCH: "x64" 9 | CONDA_PY: "310" 10 | - TARGET_ARCH: "x64" 11 | CONDA_PY: "311" 12 | - TARGET_ARCH: "x64" 13 | CONDA_PY: "312" 14 | - TARGET_ARCH: "x64" 15 | CONDA_PY: "313" 16 | 17 | 18 | # We always use a 64-bit machine. 
19 | platform: 20 | - x64 21 | 22 | install: 23 | - ps: $env:CONDA_INSTALL_LOCN="$env:USERPROFILE\Miniconda3" 24 | - ps: Invoke-WebRequest -Uri https://repo.anaconda.com/miniconda/Miniconda3-py312_24.9.2-0-Windows-x86_64.exe -OutFile Miniconda3.exe 25 | - ps: ( Get-FileHash -Algorithm SHA256 Miniconda3.exe ).Hash -eq '3a8897cc5d27236ade8659f0e119f3a3ccaad68a45de45bfdd3102d8bec412ab' 26 | - ps: Start-Process Miniconda3.exe -Wait -ArgumentList @('/S', '/InstallationType=JustMe', '/RegisterPython=1', '/D="$env:CONDA_INSTALL_LOCN"') 27 | - ps: $env:Path += ";$env:CONDA_INSTALL_LOCN\Scripts;$env:CONDA_INSTALL_LOCN\Library\bin" 28 | - ps: conda init powershell 29 | - ps: conda config --set always_yes yes 30 | - ps: conda config --set changeps1 no 31 | - ps: conda config --set auto_update_conda false 32 | - ps: conda config --set notify_outdated_conda false 33 | # - ps: conda update --yes --quiet conda 34 | - ps: conda info -a 35 | 36 | # Install the build and runtime dependencies of the project. 37 | - ps: $envScript=".\ci\appveyor\py$($env:CONDA_PY).ps1" 38 | - ps: "& $envScript" 39 | - ps: python "ci\render-requirements.py" "ci\deps\requirements.yaml.tmpl" > environment.yaml 40 | # work-around for https://github.com/conda/conda/issues/14355 41 | - ps: "& { $env:PYTHONWARNINGS = \"ignore::FutureWarning\"; conda env create -n sksurv-test --file environment.yaml }" 42 | - cmd: call %CONDA_INSTALL_LOCN%\Scripts\activate.bat sksurv-test 43 | - cmd: conda list 44 | # Initialize the submodules 45 | - cmd: git submodule update --init --recursive 46 | # Create binary packages for the project. 47 | - cmd: python -m build . 
48 | - ps: "ls dist" 49 | 50 | # Install the generated wheel package to test it 51 | - cmd: pip install --exists-action=w --pre --no-index --find-links dist/ scikit-survival 52 | - cmd: rmdir sksurv /s /q 53 | 54 | test_script: 55 | - cmd: set "PYTHONWARNINGS=default" 56 | - cmd: pytest -m "not slow" 57 | 58 | artifacts: 59 | # Archive the generated wheel package in the ci.appveyor.com build report. 60 | - path: dist\* 61 | 62 | # Skip .NET project specific build phase. 63 | build: off 64 | -------------------------------------------------------------------------------- /ci/appveyor/py310.ps1: -------------------------------------------------------------------------------- 1 | $env:CI_PYTHON_VERSION="3.10.*" 2 | $env:CI_PANDAS_VERSION="1.5.*" 3 | $env:CI_NUMPY_VERSION="1.25.*" 4 | $env:CI_SKLEARN_VERSION="1.6.*" 5 | -------------------------------------------------------------------------------- /ci/appveyor/py311.ps1: -------------------------------------------------------------------------------- 1 | $env:CI_PYTHON_VERSION="3.11.*" 2 | $env:CI_PANDAS_VERSION="2.0.*" 3 | $env:CI_NUMPY_VERSION="1.26.*" 4 | $env:CI_SKLEARN_VERSION="1.6.*" 5 | -------------------------------------------------------------------------------- /ci/appveyor/py312.ps1: -------------------------------------------------------------------------------- 1 | $env:CI_PYTHON_VERSION="3.12.*" 2 | $env:CI_PANDAS_VERSION="2.2.*" 3 | $env:CI_NUMPY_VERSION="2.0.*" 4 | $env:CI_SKLEARN_VERSION="1.6.*" 5 | -------------------------------------------------------------------------------- /ci/appveyor/py313.ps1: -------------------------------------------------------------------------------- 1 | $env:CI_PYTHON_VERSION="3.13.*" 2 | $env:CI_PANDAS_VERSION="2.2.*" 3 | $env:CI_NUMPY_VERSION="2.1.*" 4 | $env:CI_SKLEARN_VERSION="1.6.*" 5 | -------------------------------------------------------------------------------- /ci/deps/py310.sh: 
-------------------------------------------------------------------------------- 1 | # shellcheck shell=sh 2 | export CI_PYTHON_VERSION='3.10.*' 3 | export CI_PANDAS_VERSION='1.5.*' 4 | export CI_NUMPY_VERSION='1.25.*' 5 | export CI_SKLEARN_VERSION='1.6.*' 6 | export CI_NO_SLOW=false 7 | -------------------------------------------------------------------------------- /ci/deps/py311.sh: -------------------------------------------------------------------------------- 1 | # shellcheck shell=sh 2 | export CI_PYTHON_VERSION='3.11.*' 3 | export CI_PANDAS_VERSION='2.0.*' 4 | export CI_NUMPY_VERSION='1.26.*' 5 | export CI_SKLEARN_VERSION='1.6.*' 6 | export CI_NO_SLOW=true 7 | -------------------------------------------------------------------------------- /ci/deps/py312.sh: -------------------------------------------------------------------------------- 1 | # shellcheck shell=sh 2 | export CI_PYTHON_VERSION='3.12.*' 3 | export CI_PANDAS_VERSION='2.2.*' 4 | export CI_NUMPY_VERSION='2.0.*' 5 | export CI_SKLEARN_VERSION='1.6.*' 6 | export CI_NO_SLOW=true 7 | -------------------------------------------------------------------------------- /ci/deps/py313.sh: -------------------------------------------------------------------------------- 1 | # shellcheck shell=sh 2 | export CI_PYTHON_VERSION='3.13.*' 3 | export CI_PANDAS_VERSION='2.2.*' 4 | export CI_NUMPY_VERSION='2.1.*' 5 | export CI_SKLEARN_VERSION='1.6.*' 6 | export CI_NO_SLOW=false 7 | -------------------------------------------------------------------------------- /ci/deps/requirements.yaml.tmpl: -------------------------------------------------------------------------------- 1 | name: sksurv-test 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - coverage 7 | - cython>=3.0 8 | - ecos 9 | - joblib 10 | - matplotlib>=3.9.0,<3.10 11 | - numexpr 12 | - numpy={CI_NUMPY_VERSION} 13 | - osqp>=0.6.3,<1.0.0 14 | - packaging 15 | - pandas={CI_PANDAS_VERSION} 16 | - pip 17 | - pytest 18 | - 
python={CI_PYTHON_VERSION} 19 | - scikit-learn=={CI_SKLEARN_VERSION} 20 | - scipy>=1.6.0 21 | - seaborn>=0.13.2,<0.14 22 | - setuptools-scm 23 | - wheel 24 | - pip: 25 | - black[jupyter]>=23.3.0,<23.4 26 | - build 27 | - nbval>=0.10.0 28 | - tomli 29 | -------------------------------------------------------------------------------- /ci/nb_sanitize.cfg: -------------------------------------------------------------------------------- 1 | [memory] 2 | regex: at 0x[0-9a-f]+ 3 | replace: MEMORY-LOCATION 4 | -------------------------------------------------------------------------------- /ci/render-requirements.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("yaml_file") 6 | 7 | 8 | def get_pinned_packages(): 9 | pkgs = { 10 | "NUMPY", 11 | "PANDAS", 12 | "SKLEARN", 13 | "PYTHON", 14 | } 15 | pinned = {} 16 | for env_name in pkgs: 17 | key = f"CI_{env_name}_VERSION" 18 | ver = os.environ.get(key, "*") 19 | pinned[key] = ver 20 | return pinned 21 | 22 | 23 | def render_requirements(filename): 24 | pinned = get_pinned_packages() 25 | with open(filename) as fin: 26 | contents = "".join(fin.readlines()) 27 | 28 | return contents.format(**pinned) 29 | 30 | 31 | def main(): 32 | args = parser.parse_args() 33 | req = render_requirements(args.yaml_file) 34 | print(req) 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /ci/run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | export PYTHONWARNINGS="default" 4 | 5 | pytest_opts=("") 6 | 7 | if [ "x${CI_NO_SLOW:-false}" != "xtrue" ]; then 8 | coverage erase 9 | rm -f coverage.xml 10 | else 11 | pytest_opts+=(-m 'not slow') 12 | fi 13 | 14 | coverage run -m pytest "${pytest_opts[@]}" 15 | 16 | coverage xml 17 | coverage report 18 | 
-------------------------------------------------------------------------------- /ci/setup_conda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | RUNNER_OS="${1}" 4 | RUNNER_ARCH="${2}" 5 | CONDA_PKGS_DIR="${3}" 6 | 7 | run_check_sha() { 8 | echo "${1}" | shasum -a 256 --check --strict - 9 | } 10 | 11 | if [[ "${CONDA:-}" = "" ]]; then 12 | # download and install conda 13 | MINICONDA_VERSION="Miniconda3-py312_24.9.2-0" 14 | 15 | if [[ "${RUNNER_OS}" = "macOS" ]] && [[ "${RUNNER_ARCH}" = "ARM64" ]]; then 16 | MINICONDA_VERSION="${MINICONDA_VERSION}-MacOSX-arm64" 17 | MINICONDA_HASH="08d8a82ed21d2dae707554d540b172fe03327347db747644fbb33abfaf07fddd" 18 | elif [[ "${RUNNER_OS}" = "macOS" ]] && [[ "${RUNNER_ARCH}" = "X64" ]]; then 19 | MINICONDA_VERSION="${MINICONDA_VERSION}-MacOSX-x86_64" 20 | MINICONDA_HASH="ce3b440c32c9c636bbe529477fd496798c35b96d9db1838e3df6b0a80714da4e" 21 | elif [[ "${RUNNER_OS}" = "Linux" ]] && [[ "${RUNNER_ARCH}" = "X64" ]]; then 22 | MINICONDA_VERSION="${MINICONDA_VERSION}-Linux-x86_64" 23 | MINICONDA_HASH="8d936ba600300e08eca3d874dee88c61c6f39303597b2b66baee54af4f7b4122" 24 | else 25 | echo "Unsupported OS or ARCH: ${RUNNER_OS} ${RUNNER_ARCH}" 26 | exit 1 27 | fi 28 | 29 | export CONDA="${GITHUB_WORKSPACE}/miniconda3" 30 | 31 | mkdir -p "${CONDA}" && \ 32 | curl "https://repo.anaconda.com/miniconda/${MINICONDA_VERSION}.sh" -o "${CONDA}/miniconda.sh" && \ 33 | run_check_sha "${MINICONDA_HASH} ${CONDA}/miniconda.sh" && \ 34 | bash "${CONDA}/miniconda.sh" -b -u -p "${CONDA}" && \ 35 | rm -rf "${CONDA}/miniconda.sh" || exit 1 36 | 37 | echo "CONDA=${CONDA}" >> "${GITHUB_ENV}" 38 | fi 39 | 40 | "${CONDA}/bin/conda" config --set always_yes yes && \ 41 | "${CONDA}/bin/conda" config --set changeps1 no && \ 42 | "${CONDA}/bin/conda" config --set auto_update_conda false && \ 43 | "${CONDA}/bin/conda" config --set show_channel_urls true || \ 44 | exit 1 45 | 46 | # The directory in which 
packages are located. 47 | # https://docs.conda.io/projects/conda/en/latest/user-guide/configuration/settings.html#pkgs-dirs-specify-package-directories 48 | if [[ ! -d "${CONDA_PKGS_DIR}" ]]; then 49 | mkdir -p "${CONDA_PKGS_DIR}" || exit 1 50 | fi 51 | sudo chown -R "${USER}" "${CONDA_PKGS_DIR}" || \ 52 | exit 1 53 | 54 | sudo "${CONDA}/bin/conda" update -q -n base conda && \ 55 | sudo chown -R "${USER}" "${CONDA}" || \ 56 | exit 1 57 | 58 | export PATH="${CONDA}/bin:${PATH}" 59 | echo "${CONDA}/bin" >> "${GITHUB_PATH}" 60 | -------------------------------------------------------------------------------- /ci/setup_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -xe 3 | 4 | OS="$1" 5 | 6 | if [ "x${OS}" = "xLinux" ]; then 7 | COMPILER=() 8 | elif [ "x${OS}" = "xmacOS" ]; then 9 | COMPILER=(clang_osx-arm64 clangxx_osx-arm64) 10 | else 11 | echo "OS '${OS}' is unsupported." 12 | exit 1 13 | fi 14 | 15 | python ci/render-requirements.py ci/deps/requirements.yaml.tmpl > environment.yaml 16 | 17 | conda env create -n sksurv-test --file environment.yaml 18 | 19 | echo "numpy ${CI_NUMPY_VERSION:?}" > "${CONDA:?}/envs/sksurv-test/conda-meta/pinned" 20 | echo "pandas ${CI_PANDAS_VERSION:?}" >> "${CONDA:?}/envs/sksurv-test/conda-meta/pinned" 21 | echo "scikit-learn ${CI_SKLEARN_VERSION:?}" >> "${CONDA:?}/envs/sksurv-test/conda-meta/pinned" 22 | 23 | # Useful for debugging any issues with conda 24 | conda info -a 25 | 26 | # shellcheck disable=SC1091 27 | source activate sksurv-test 28 | 29 | # delete any version that is already installed 30 | pip uninstall --yes scikit-survival || exit 0 31 | 32 | conda list -n sksurv-test -------------------------------------------------------------------------------- /doc/.gitignore: -------------------------------------------------------------------------------- 1 | _build/ 2 | api/generated/ 3 | 
-------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /doc/_static/custom.css: -------------------------------------------------------------------------------- 1 | /* from https://github.com/pandas-dev/pandas/pull/49811 */ 2 | table { 3 | width: auto; /* Override fit-content which breaks Styler user guide ipynb */ 4 | } 5 | 6 | .nboutput .output_area img { 7 | border: 5px solid var(--pst-color-background); 8 | } 9 | 10 | .navbar-brand { 11 | padding-top: .3125rem; 12 | padding-bottom: .3125rem; 13 | height: 45px; 14 | } 15 | 16 | .overview-grid .sd-card-title { 17 | color: var(--pst-color-link); 18 | font-size: calc(var(--pst-font-size-h4) - .1rem); 19 | line-height: 1.4; 20 | } 21 | 22 | .overview-grid .sd-card-title.sd-font-weight-bold { 23 | font-weight: var(--pst-font-weight-heading) !important; 24 | } 25 | 26 | .overview-grid .sd-card-title svg[data-prefix="fas"] { 27 | position: absolute; 28 | right: 1.2rem; 29 | line-height: 1.4; 30 | transition: all .2s 31 | } 32 | 33 | /* see https://github.com/pydata/pydata-sphinx-theme/issues/2112#issuecomment-2619729198 */ 
34 | .overview-grid .sd-card .sd-card-body { 35 | background-color: unset !important; 36 | } 37 | 38 | .overview-grid .sd-card-hover:hover .sd-card-title { 39 | color: var(--pst-color-link-hover); 40 | } 41 | 42 | .overview-grid .sd-card-hover:hover { 43 | border-color: var(--sd-color-card-border); 44 | } 45 | 46 | .overview-grid .sd-card-hover:after { 47 | content: " "; 48 | position: absolute; 49 | right: 0; 50 | bottom: 0; 51 | left: 0; 52 | height: 3px; 53 | background: #4ce8ff; 54 | background: linear-gradient(90deg, #4ce8ff, #d07cff); 55 | opacity: 0; 56 | transition: all .2s 57 | } 58 | 59 | .overview-grid .sd-card-hover:hover:after { 60 | opacity: 1; 61 | } 62 | 63 | .github-icon { 64 | display: inline-block; 65 | vertical-align: middle; 66 | } 67 | 68 | .github-icon span { 69 | padding: .5rem; 70 | } 71 | 72 | .github-repository-name { 73 | display: inline-block; 74 | max-width: calc(100% - 1.2rem); 75 | overflow: hidden; 76 | text-overflow: ellipsis; 77 | vertical-align: middle; 78 | } 79 | 80 | .navbar-icon-links ul.github-facts { 81 | display: flex; 82 | gap: .4rem; 83 | margin: .1rem 0 0; 84 | padding: 0; 85 | font-size: .95rem; 86 | font-weight: 400; 87 | list-style: none; 88 | opacity: .75; 89 | overflow: hidden; 90 | text-overflow: ellipsis; 91 | width: 100%; 92 | } 93 | 94 | li.github-fact { 95 | display: inline-block; 96 | vertical-align: middle; 97 | } 98 | 99 | .github-fact-stars::before { 100 | font: var(--fa-font-solid); 101 | content: '\f005'; 102 | margin-right: .15rem; 103 | } 104 | 105 | .github-fact-forks::before { 106 | font: var(--fa-font-solid); 107 | content: '\e13b'; 108 | margin-right: .15rem; 109 | } 110 | 111 | a:hover.github-repository { 112 | text-decoration: none; 113 | color: var(--pst-color-link-hover); 114 | } 115 | 116 | a.github-repository { 117 | color: var(--pst-color-text-muted); 118 | } 119 | 120 | a:visited.github-repository { 121 | color: var(--pst-color-text-muted); 122 | } 123 | 124 | 
a:visited:hover.github-repository { 125 | text-decoration: none; 126 | color: var(--pst-color-link-hover); 127 | } -------------------------------------------------------------------------------- /doc/_static/github-stats.js: -------------------------------------------------------------------------------- 1 | "use strict"; 2 | var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) { 3 | function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); } 4 | return new (P || (P = Promise))(function (resolve, reject) { 5 | function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } 6 | function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } 7 | function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); } 8 | step((generator = generator.apply(thisArg, _arguments || [])).next()); 9 | }); 10 | }; 11 | function round(value) { 12 | if (value > 999) { 13 | let digits = +((value - 950) % 1000 > 99); 14 | return `${((value + 0.000001) / 1000).toFixed(digits)}k`; 15 | } 16 | else { 17 | return value.toString(); 18 | } 19 | } 20 | function get_stats(repo) { 21 | return __awaiter(this, void 0, void 0, function* () { 22 | const response = yield fetch("https://api.github.com/repos" + repo, { 23 | headers: { 24 | "Accept": "application/vnd.github+json", 25 | "X-GitHub-Api-Version": "2022-11-28", 26 | }, 27 | redirect: "follow", 28 | }); 29 | return response.json(); 30 | }); 31 | } 32 | function get_repo_url() { 33 | const repoElem = document.querySelector("a.github-repository"); 34 | if (repoElem == null) { 35 | return null; 36 | } 37 | const url_str = repoElem.getAttribute("href"); 38 | if (url_str != null) { 39 | return new URL(url_str); 40 | } 41 | return null; 42 | } 43 | function append_element_stats(elements, text, css_class) { 44 | elements.forEach((elem) => { 45 | let child = 
document.createElement("li"); 46 | child.setAttribute("class", `github-fact ${css_class}`); 47 | const txt = document.createTextNode(text); 48 | child.appendChild(txt); 49 | elem.appendChild(child); 50 | }); 51 | } 52 | function update_repo_stats() { 53 | const url = get_repo_url(); 54 | if (url == null) 55 | return; 56 | let stats_elem = document.querySelectorAll("ul.github-facts"); 57 | get_stats(url.pathname).then((data) => { 58 | append_element_stats(stats_elem, round(data["stargazers_count"]), "github-fact-stars"); 59 | append_element_stats(stats_elem, round(data["forks_count"]), "github-fact-forks"); 60 | }); 61 | } 62 | function update_repo_stats_once() { 63 | return __awaiter(this, void 0, void 0, function* () { 64 | document.removeEventListener("DOMContentLoaded", update_repo_stats); 65 | update_repo_stats(); 66 | }); 67 | } 68 | document.addEventListener("DOMContentLoaded", update_repo_stats_once); 69 | //# sourceMappingURL=github-stats.js.map -------------------------------------------------------------------------------- /doc/_templates/navbar-github-links.html: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /doc/api/compare.rst: -------------------------------------------------------------------------------- 1 | Hypothesis testing 2 | ================== 3 | .. currentmodule:: sksurv.compare 4 | 5 | .. autosummary:: 6 | :toctree: generated/ 7 | 8 | compare_survival 9 | -------------------------------------------------------------------------------- /doc/api/datasets.rst: -------------------------------------------------------------------------------- 1 | Datasets 2 | ======== 3 | .. currentmodule:: sksurv.datasets 4 | 5 | .. 
autosummary:: 6 | :toctree: generated/ 7 | 8 | get_x_y 9 | load_aids 10 | load_arff_files_standardized 11 | load_bmt 12 | load_cgvhd 13 | load_breast_cancer 14 | load_flchain 15 | load_gbsg2 16 | load_whas500 17 | load_veterans_lung_cancer 18 | -------------------------------------------------------------------------------- /doc/api/ensemble.rst: -------------------------------------------------------------------------------- 1 | Ensemble Models 2 | =============== 3 | .. currentmodule:: sksurv.ensemble 4 | 5 | .. autosummary:: 6 | :toctree: generated/ 7 | 8 | ComponentwiseGradientBoostingSurvivalAnalysis 9 | GradientBoostingSurvivalAnalysis 10 | RandomSurvivalForest 11 | ExtraSurvivalTrees 12 | -------------------------------------------------------------------------------- /doc/api/functions.rst: -------------------------------------------------------------------------------- 1 | Functions 2 | ========= 3 | .. currentmodule:: sksurv.functions 4 | 5 | .. autosummary:: 6 | :toctree: generated/ 7 | 8 | StepFunction 9 | -------------------------------------------------------------------------------- /doc/api/index.rst: -------------------------------------------------------------------------------- 1 | API reference 2 | ============= 3 | 4 | This page gives an overview of all public objects, functions and methods. 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | datasets 10 | ensemble 11 | functions 12 | compare 13 | io 14 | kernels 15 | linear_model 16 | meta 17 | metrics 18 | nonparametric 19 | preprocessing 20 | svm 21 | tree 22 | util 23 | -------------------------------------------------------------------------------- /doc/api/io.rst: -------------------------------------------------------------------------------- 1 | I/O Utilities 2 | ============= 3 | .. currentmodule:: sksurv.io 4 | 5 | .. 
autosummary:: 6 | :toctree: generated/ 7 | 8 | loadarff 9 | writearff 10 | -------------------------------------------------------------------------------- /doc/api/kernels.rst: -------------------------------------------------------------------------------- 1 | Kernels 2 | ======= 3 | .. currentmodule:: sksurv.kernels 4 | 5 | .. autosummary:: 6 | :toctree: generated/ 7 | 8 | ClinicalKernelTransform 9 | clinical_kernel 10 | -------------------------------------------------------------------------------- /doc/api/linear_model.rst: -------------------------------------------------------------------------------- 1 | Linear Models 2 | ============= 3 | .. currentmodule:: sksurv.linear_model 4 | 5 | .. autosummary:: 6 | :toctree: generated/ 7 | 8 | CoxnetSurvivalAnalysis 9 | CoxPHSurvivalAnalysis 10 | IPCRidge 11 | -------------------------------------------------------------------------------- /doc/api/meta.rst: -------------------------------------------------------------------------------- 1 | Meta Models 2 | =========== 3 | .. currentmodule:: sksurv.meta 4 | 5 | .. autosummary:: 6 | :toctree: generated/ 7 | 8 | EnsembleSelection 9 | EnsembleSelectionRegressor 10 | Stacking 11 | -------------------------------------------------------------------------------- /doc/api/metrics.rst: -------------------------------------------------------------------------------- 1 | Metrics 2 | ======= 3 | .. currentmodule:: sksurv.metrics 4 | 5 | .. 
autosummary:: 6 | :toctree: generated/ 7 | 8 | brier_score 9 | concordance_index_censored 10 | concordance_index_ipcw 11 | cumulative_dynamic_auc 12 | integrated_brier_score 13 | as_concordance_index_ipcw_scorer 14 | as_cumulative_dynamic_auc_scorer 15 | as_integrated_brier_score_scorer 16 | -------------------------------------------------------------------------------- /doc/api/nonparametric.rst: -------------------------------------------------------------------------------- 1 | Non-parametric Estimators 2 | ========================= 3 | .. currentmodule:: sksurv.nonparametric 4 | 5 | .. autosummary:: 6 | :toctree: generated/ 7 | 8 | CensoringDistributionEstimator 9 | SurvivalFunctionEstimator 10 | cumulative_incidence_competing_risks 11 | ipc_weights 12 | kaplan_meier_estimator 13 | nelson_aalen_estimator 14 | -------------------------------------------------------------------------------- /doc/api/preprocessing.rst: -------------------------------------------------------------------------------- 1 | Pre-Processing 2 | ============== 3 | .. currentmodule:: sksurv.preprocessing 4 | 5 | .. autosummary:: 6 | :toctree: generated/ 7 | 8 | OneHotEncoder 9 | 10 | .. currentmodule:: sksurv.column 11 | 12 | .. autosummary:: 13 | :toctree: generated/ 14 | 15 | categorical_to_numeric 16 | encode_categorical 17 | standardize 18 | -------------------------------------------------------------------------------- /doc/api/svm.rst: -------------------------------------------------------------------------------- 1 | .. _mod-svm: 2 | 3 | Survival Support Vector Machine 4 | =============================== 5 | .. currentmodule:: sksurv.svm 6 | 7 | .. 
autosummary:: 8 | :toctree: generated/ 9 | 10 | HingeLossSurvivalSVM 11 | FastKernelSurvivalSVM 12 | FastSurvivalSVM 13 | MinlipSurvivalAnalysis 14 | NaiveSurvivalSVM 15 | -------------------------------------------------------------------------------- /doc/api/tree.rst: -------------------------------------------------------------------------------- 1 | Survival Trees 2 | ============== 3 | .. currentmodule:: sksurv.tree 4 | 5 | .. autosummary:: 6 | :toctree: generated/ 7 | 8 | SurvivalTree 9 | -------------------------------------------------------------------------------- /doc/api/util.rst: -------------------------------------------------------------------------------- 1 | Utilities 2 | ========= 3 | .. currentmodule:: sksurv.util 4 | 5 | .. autosummary:: 6 | :toctree: generated/ 7 | 8 | Surv 9 | -------------------------------------------------------------------------------- /doc/cite.rst: -------------------------------------------------------------------------------- 1 | Citing scikit-survival 2 | ====================== 3 | 4 | If you are using **scikit-survival** in your scientific research, 5 | please cite the following paper: 6 | 7 | S. Pölsterl, "scikit-survival: A Library for Time-to-Event Analysis Built on Top of scikit-learn," 8 | Journal of Machine Learning Research, vol. 21, no. 212, pp. 1–6, 2020. 9 | 10 | BibTeX entry:: 11 | 12 | @article{sksurv, 13 | author = {Sebastian P{\"o}lsterl}, 14 | title = {scikit-survival: A Library for Time-to-Event Analysis Built on Top of scikit-learn}, 15 | journal = {Journal of Machine Learning Research}, 16 | year = {2020}, 17 | volume = {21}, 18 | number = {212}, 19 | pages = {1-6}, 20 | url = {http://jmlr.org/papers/v21/20-729.html} 21 | } 22 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | .. 
scikit-survival documentation master file 2 | 3 | scikit-survival 4 | =============== 5 | 6 | scikit-survival is a Python module for `survival analysis `_ 7 | built on top of `scikit-learn `_. It allows doing survival analysis 8 | while utilizing the power of scikit-learn, e.g., for pre-processing or doing cross-validation. 9 | 10 | The objective in survival analysis (also referred to as time-to-event or reliability analysis) 11 | is to establish a connection between covariates and the time of an event. 12 | What makes survival analysis differ from traditional machine learning is the fact that 13 | parts of the training data can only be partially observed – they are *censored*. 14 | 15 | For instance, in a clinical study, patients are often monitored for a particular time period, 16 | and events occurring in this particular period are recorded. 17 | If a patient experiences an event, the exact time of the event can 18 | be recorded – the patient’s record is uncensored. In contrast, right censored records 19 | refer to patients that remained event-free during the study period and 20 | it is unknown whether an event has or has not occurred after the study ended. 21 | Consequently, survival analysis demands for models that take 22 | this unique characteristic of such a dataset into account. 23 | 24 | 25 | .. grid:: 2 26 | :gutter: 3 27 | :class-container: overview-grid 28 | 29 | .. grid-item-card:: Install :fas:`download` 30 | :link: install 31 | :link-type: doc 32 | 33 | The easiest way to install scikit-survival is to use 34 | `conda-forge `_ by running:: 35 | 36 | conda install -c conda-forge scikit-survival 37 | 38 | Alternatively, you can install scikit-survival from source 39 | following :ref:`this guide `. 40 | 41 | 42 | .. 
grid-item-card:: User Guide :fas:`book-open` 43 | :link: user_guide/index 44 | :link-type: doc 45 | 46 | The user guide provides in-depth information on the key concepts of scikit-survival, an overview of available survival models, and hands-on examples. 47 | 48 | 49 | .. grid-item-card:: API Reference :fas:`cogs` 50 | :link: api/index 51 | :link-type: doc 52 | 53 | The reference guide contains a detailed description of the scikit-survival API. It describes which classes and functions are available 54 | and what their parameters are. 55 | 56 | 57 | .. grid-item-card:: Contributing :fas:`code` 58 | :link: contributing 59 | :link-type: doc 60 | 61 | Saw a typo in the documentation? Want to add new functionalities? The contributing guidelines will guide you through the process of 62 | setting up a development environment and submitting your changes to the scikit-survival team. 63 | 64 | 65 | .. toctree:: 66 | :maxdepth: 1 67 | :hidden: 68 | :titlesonly: 69 | 70 | Install 71 | user_guide/index 72 | api/index 73 | Contribute 74 | release_notes 75 | Cite 76 | -------------------------------------------------------------------------------- /doc/install.rst: -------------------------------------------------------------------------------- 1 | Installing scikit-survival 2 | ========================== 3 | 4 | The recommended and easiest way to install scikit-survival is to use 5 | :ref:`install-conda` or :ref:`install-pip`. 6 | Pre-built binary packages for scikit-survival are available for Linux, macOS, and Windows. 7 | Alternatively, you can install scikit-survival :ref:`install-from-source`. 8 | 9 | .. _install-conda: 10 | 11 | conda 12 | ----- 13 | 14 | If you have `conda `_ installed, you can 15 | install scikit-survival from the ``conda-forge`` channel by running:: 16 | 17 | conda install -c conda-forge scikit-survival 18 | 19 | .. 
_install-pip: 20 | 21 | pip 22 | --- 23 | 24 | If you use ``pip``, install the latest release of scikit-survival with:: 25 | 26 | pip install scikit-survival 27 | 28 | 29 | .. _install-from-source: 30 | 31 | From Source 32 | ----------- 33 | 34 | If you want to build scikit-survival from source, you 35 | will need a C/C++ compiler to compile extensions. 36 | 37 | **Linux** 38 | 39 | On Linux, you need to install *gcc*, which in most cases is available 40 | via your distribution's packaging system. 41 | Please follow your distribution's instructions on how to install packages. 42 | 43 | **macOS** 44 | 45 | On macOS, you need to install *clang*, which is available from 46 | the *Command Line Tools* package. Open a terminal and execute:: 47 | 48 | xcode-select --install 49 | 50 | Alternatively, you can download it from the 51 | `Apple Developers page `_. 52 | Log in with your Apple ID, then search and download the 53 | *Command Line Tools for Xcode* package. 54 | 55 | **Windows** 56 | 57 | On Windows, the compiler you need depends on the Python version 58 | you are using. See `this guide `_ 59 | to determine which Microsoft Visual C++ compiler to use with a specific Python version. 60 | 61 | 62 | Latest Release 63 | ^^^^^^^^^^^^^^ 64 | 65 | To install the latest release of scikit-survival from source, run:: 66 | 67 | pip install scikit-survival --no-binary scikit-survival 68 | 69 | 70 | .. note:: 71 | 72 | If you have not installed the :ref:`dependencies ` previously, this command 73 | will first install all dependencies before installing scikit-survival. 74 | Therefore, installation might fail if build requirements of some dependencies 75 | are not met. In particular, `osqp `_ 76 | does require `CMake `_ to be installed. 
77 | 78 | Development Version 79 | ^^^^^^^^^^^^^^^^^^^ 80 | 81 | To install the latest source from our `GitHub repository `_, 82 | you need to have `Git `_ installed and 83 | simply run:: 84 | 85 | pip install git+https://github.com/sebp/scikit-survival.git 86 | 87 | 88 | 89 | .. _dependencies: 90 | 91 | Dependencies 92 | ------------ 93 | 94 | The current minimum dependencies to run scikit-survival are: 95 | 96 | - Python 3.10 or later 97 | - ecos 98 | - joblib 99 | - numexpr 100 | - numpy 101 | - osqp 102 | - pandas 1.4.0 or later 103 | - scikit-learn 1.6 104 | - scipy 105 | - C/C++ compiler 106 | -------------------------------------------------------------------------------- /doc/release_notes.rst: -------------------------------------------------------------------------------- 1 | Release Notes 2 | ============= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | release_notes/v0.24 8 | release_notes/v0.23 9 | release_notes/v0.22 10 | release_notes/v0.21 11 | release_notes/v0.20 12 | release_notes/v0.19 13 | release_notes/v0.18 14 | release_notes/v0.17 15 | release_notes/v0.16 16 | release_notes/v0.15 17 | release_notes/v0.14 18 | release_notes/v0.13 19 | release_notes/v0.12 20 | release_notes/v0.11 21 | release_notes/v0.10 22 | release_notes/v0.9 23 | release_notes/v0.8 24 | release_notes/v0.7 25 | release_notes/v0.6 26 | release_notes/v0.5 27 | release_notes/v0.4 28 | release_notes/v0.3 29 | release_notes/v0.2 30 | release_notes/v0.1 31 | -------------------------------------------------------------------------------- /doc/release_notes/v0.1.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_1: 2 | 3 | What's new in 0.1 4 | ================= 5 | 6 | scikit-survival 0.1 (2016-12-29) 7 | -------------------------------- 8 | 9 | This is the initial release of scikit-survival. 
10 | It combines the `implementation of survival support vector machines `_ 11 | with the code used in the `Prostate Cancer DREAM challenge `_. 12 | -------------------------------------------------------------------------------- /doc/release_notes/v0.10.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_10: 2 | 3 | What's new in 0.10 4 | ================== 5 | 6 | scikit-survival 0.10 (2019-09-02) 7 | --------------------------------- 8 | 9 | This release adds the `ties` argument to :class:`sksurv.linear_model.CoxPHSurvivalAnalysis` 10 | to choose between Breslow's and Efron's likelihood in the presence of tied event times. 11 | Moreover, :func:`sksurv.compare.compare_survival` has been added, which implements 12 | the log-rank hypothesis test for comparing the survival function of 2 or more groups. 13 | 14 | Enhancements 15 | ^^^^^^^^^^^^ 16 | 17 | - Update API doc of predict function of boosting estimators (#75). 18 | - Clarify documentation for GradientBoostingSurvivalAnalysis (#78). 19 | - Implement Efron's likelihood for handling tied event times. 20 | - Implement log-rank test for comparing survival curves. 21 | - Add support for scipy 1.3.1 (#66). 22 | 23 | Bug fixes 24 | ^^^^^^^^^ 25 | 26 | - Re-add `baseline_survival_` and `cum_baseline_hazard_` attributes 27 | to :class:`sksurv.linear_model.CoxPHSurvivalAnalysis` (#76). 28 | -------------------------------------------------------------------------------- /doc/release_notes/v0.11.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_11: 2 | 3 | What's new in 0.11 4 | ================== 5 | 6 | scikit-survival 0.11 (2019-12-21) 7 | --------------------------------- 8 | 9 | This release adds :class:`sksurv.tree.SurvivalTree` and :class:`sksurv.ensemble.RandomSurvivalForest`, 10 | which are based on the log-rank split criterion. 
11 | It also adds the OSQP solver as option to :class:`sksurv.svm.MinlipSurvivalAnalysis` 12 | and :class:`sksurv.svm.HingeLossSurvivalSVM`, which will replace the now deprecated 13 | `cvxpy` and `cvxopt` options in a future release. 14 | 15 | This release removes support for sklearn 0.20 and requires sklearn 0.21. 16 | 17 | Deprecations 18 | ^^^^^^^^^^^^ 19 | 20 | - The `cvxpy` and `cvxopt` options for `solver` in :class:`sksurv.svm.MinlipSurvivalAnalysis` 21 | and :class:`sksurv.svm.HingeLossSurvivalSVM` are deprecated and will be removed in a future 22 | version. Choosing `osqp` is the preferred option now. 23 | 24 | Enhancements 25 | ^^^^^^^^^^^^ 26 | 27 | - Add support for pandas 0.25. 28 | - Add OSQP solver option to :class:`sksurv.svm.MinlipSurvivalAnalysis` and 29 | :class:`sksurv.svm.HingeLossSurvivalSVM` which has no additional dependencies. 30 | - Fix issue when using cvxpy 1.0.16 or later. 31 | - Explicitly specify utf-8 encoding when reading README.rst (#89). 32 | - Add :class:`sksurv.tree.SurvivalTree` and :class:`sksurv.ensemble.RandomSurvivalForest` (#90). 33 | 34 | Bug fixes 35 | ^^^^^^^^^ 36 | 37 | - Exclude Cython-generated files from source distribution because 38 | they are not forward compatible. 39 | -------------------------------------------------------------------------------- /doc/release_notes/v0.12.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_12: 2 | 3 | What's new in 0.12 4 | ================== 5 | 6 | scikit-survival 0.12 (2020-04-15) 7 | --------------------------------- 8 | 9 | This release adds support for scikit-learn 0.22, thereby dropping support for 10 | older versions. Moreover, the regularization strength of the ridge penalty 11 | in :class:`sksurv.linear_model.CoxPHSurvivalAnalysis` can now be set per 12 | feature. If you want one or more features to enter the model unpenalized, 13 | set the corresponding penalty weights to zero. 
14 | Finally, :class:`sklearn.pipeline.Pipeline` will now be automatically patched 15 | to add support for `predict_cumulative_hazard_function` and `predict_survival_function` 16 | if the underlying estimator supports it. 17 | 18 | Deprecations 19 | ^^^^^^^^^^^^ 20 | 21 | - Add scikit-learn's deprecation of `presort` in :class:`sksurv.tree.SurvivalTree` and 22 | :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis`. 23 | - Add warning that default `alpha_min_ratio` in :class:`sksurv.linear_model.CoxnetSurvivalAnalysis` 24 | will depend on the ratio of the number of samples to the number of features 25 | in the future (#41). 26 | 27 | Enhancements 28 | ^^^^^^^^^^^^ 29 | 30 | - Add references to API doc of :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis` (#91). 31 | - Add support for pandas 1.0 (#100). 32 | - Add `ccp_alpha` parameter for 33 | `Minimal Cost-Complexity Pruning `_ 34 | to :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis`. 35 | - Patch :class:`sklearn.pipeline.Pipeline` to add support for 36 | `predict_cumulative_hazard_function` and `predict_survival_function` 37 | if the underlying estimator supports it. 38 | - Allow per-feature regularization for :class:`sksurv.linear_model.CoxPHSurvivalAnalysis` (#102). 39 | - Clarify API docs of :func:`sksurv.metrics.concordance_index_censored` (#96). 40 | -------------------------------------------------------------------------------- /doc/release_notes/v0.13.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_13: 2 | 3 | What's new in 0.13 4 | ================== 5 | 6 | scikit-survival 0.13.1 (2020-07-04) 7 | ----------------------------------- 8 | 9 | This release fixes warnings that were introduced with 0.13.0. 10 | 11 | Bug fixes 12 | ^^^^^^^^^ 13 | 14 | - Explicitly pass ``return_array=True`` in :meth:`sksurv.tree.SurvivalTree.predict` 15 | to avoid FutureWarning. 
16 | - Fix error when fitting :class:`sksurv.tree.SurvivalTree` with non-float 17 | dtype for time (#127). 18 | - Fix RuntimeWarning: invalid value encountered in true_divide 19 | in :func:`sksurv.nonparametric.kaplan_meier_estimator`. 20 | - Fix PendingDeprecationWarning about use of matrix when fitting 21 | :class:`sksurv.svm.FastSurvivalSVM` if optimizer is `PRSVM` or `simple`. 22 | 23 | 24 | scikit-survival 0.13.0 (2020-06-28) 25 | ----------------------------------- 26 | 27 | The highlights of this release include the addition of 28 | :func:`sksurv.metrics.brier_score` and 29 | :func:`sksurv.metrics.integrated_brier_score` 30 | and compatibility with scikit-learn 0.23. 31 | 32 | `predict_survival_function` and `predict_cumulative_hazard_function` 33 | of :class:`sksurv.ensemble.RandomSurvivalForest` and 34 | :class:`sksurv.tree.SurvivalTree` can now return an array of 35 | :class:`sksurv.functions.StepFunction`, similar 36 | to :class:`sksurv.linear_model.CoxPHSurvivalAnalysis` 37 | by specifying ``return_array=False``. This will be the default 38 | behavior starting with 0.14.0. 39 | 40 | Note that this release fixes a bug in estimating 41 | inverse probability of censoring weights (IPCW), which will 42 | affect all estimators relying on IPCW. 43 | 44 | Enhancements 45 | ^^^^^^^^^^^^ 46 | 47 | - Make build system compatible with PEP-517/518. 48 | - Added :func:`sksurv.metrics.brier_score` and 49 | :func:`sksurv.metrics.integrated_brier_score` (#101). 50 | - :class:`sksurv.functions.StepFunction` can now be evaluated at multiple points 51 | in a single call. 52 | - Update documentation on usage of `predict_survival_function` and 53 | `predict_cumulative_hazard_function` (#118). 54 | - The default value of `alpha_min_ratio` of 55 | :class:`sksurv.linear_model.CoxnetSurvivalAnalysis` will now depend 56 | on the `n_samples/n_features` ratio. 
57 | If ``n_samples > n_features``, the default value is 0.0001. 58 | If ``n_samples <= n_features``, the default value is 0.01. 59 | - Add support for scikit-learn 0.23 (#119). 60 | 61 | Deprecations 62 | ^^^^^^^^^^^^ 63 | 64 | - `predict_survival_function` and `predict_cumulative_hazard_function` 65 | of :class:`sksurv.ensemble.RandomSurvivalForest` and 66 | :class:`sksurv.tree.SurvivalTree` will return an array of 67 | :class:`sksurv.functions.StepFunction` in the future 68 | (as :class:`sksurv.linear_model.CoxPHSurvivalAnalysis` does). 69 | For the old behavior, use `return_array=True`. 70 | 71 | Bug fixes 72 | ^^^^^^^^^ 73 | 74 | - Fix deprecation of importing joblib via sklearn. 75 | - Fix estimation of censoring distribution for tied times with events. 76 | When estimating the censoring distribution, 77 | by specifying ``reverse=True`` when calling 78 | :func:`sksurv.nonparametric.kaplan_meier_estimator`, 79 | we now consider events to occur before censoring. 80 | For tied time points with an event, those 81 | with an event are not considered at risk anymore and subtracted from 82 | the denominator of the Kaplan-Meier estimator. 83 | The change affects all functions relying on inverse probability 84 | of censoring weights, namely: 85 | 86 | - :class:`sksurv.nonparametric.CensoringDistributionEstimator` 87 | - :func:`sksurv.nonparametric.ipc_weights` 88 | - :class:`sksurv.linear_model.IPCRidge` 89 | - :func:`sksurv.metrics.cumulative_dynamic_auc` 90 | - :func:`sksurv.metrics.concordance_index_ipcw` 91 | 92 | - Throw an exception when trying to estimate c-index from incomparable data (#117). 93 | - Estimators in ``sksurv.svm`` will now throw an 94 | exception when trying to fit a model to data with incomparable pairs. 95 | -------------------------------------------------------------------------------- /doc/release_notes/v0.14.rst: -------------------------------------------------------------------------------- 1 | ..
_release_notes_0_14: 2 | 3 | What's new in 0.14 4 | ================== 5 | 6 | scikit-survival 0.14.0 (2020-10-07) 7 | ----------------------------------- 8 | 9 | This release features a complete overhaul of the :doc:`documentation <../index>`. 10 | It features a new visual design, and the inclusion of several interactive notebooks 11 | in the :ref:`User Guide`. 12 | 13 | In addition, it includes important bug fixes. 14 | It fixes several bugs in :class:`sksurv.linear_model.CoxnetSurvivalAnalysis` 15 | where ``predict``, ``predict_survival_function``, and ``predict_cumulative_hazard_function`` 16 | returned wrong values if features of the training data were not centered. 17 | Moreover, the `score` function of :class:`sksurv.ensemble.ComponentwiseGradientBoostingSurvivalAnalysis` 18 | and :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis` will now 19 | correctly compute the concordance index if ``loss='ipcwls'`` or ``loss='squared'``. 20 | 21 | Bug fixes 22 | ^^^^^^^^^ 23 | 24 | - :func:`sksurv.column.standardize` modified data in-place. Data is now always copied. 25 | - :func:`sksurv.column.standardize` works with integer numpy arrays now. 26 | - :func:`sksurv.column.standardize` used biased standard deviation for numpy arrays (``ddof=0``), 27 | but unbiased standard deviation for pandas objects (``ddof=1``). It always uses ``ddof=1`` now. 28 | Therefore, the output, if the input is a numpy array, will differ from that of previous versions. 29 | - Fixed :meth:`sksurv.linear_model.CoxnetSurvivalAnalysis.predict_survival_function` 30 | and :meth:`sksurv.linear_model.CoxnetSurvivalAnalysis.predict_cumulative_hazard_function`, 31 | which returned wrong values if features of training data were not already centered. 32 | This adds an ``offset_`` attribute that accounts for non-centered data and is added to the 33 | predicted risk score. 
Therefore, the outputs of ``predict``, ``predict_survival_function``, 34 | and ``predict_cumulative_hazard_function`` will be different to previous versions for 35 | non-centered data (#139). 36 | - Rescale coefficients of :class:`sksurv.linear_model.CoxnetSurvivalAnalysis` if 37 | `normalize=True`. 38 | - Fix `score` function of :class:`sksurv.ensemble.ComponentwiseGradientBoostingSurvivalAnalysis` 39 | and :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis` if ``loss='ipcwls'`` or ``loss='squared'`` 40 | is used. Previously, it returned ``1.0 - true_cindex``. 41 | 42 | Enhancements 43 | ^^^^^^^^^^^^ 44 | 45 | - Add :func:`sksurv.show_versions` that prints the version of all dependencies. 46 | - Add support for pandas 1.1. 47 | - Include interactive notebooks in documentation on readthedocs. 48 | - Add user guide on :ref:`penalized Cox models `. 49 | - Add user guide on :ref:`gradient boosted models `. 50 | -------------------------------------------------------------------------------- /doc/release_notes/v0.15.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_15: 2 | 3 | What's new in 0.15 4 | ================== 5 | 6 | scikit-survival 0.15.0 (2021-03-20) 7 | ----------------------------------- 8 | 9 | This release adds support for scikit-learn 0.24 and Python 3.9. 10 | scikit-survival now requires at least pandas 0.25 and scikit-learn 0.24. 11 | Moreover, if :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis` 12 | or :class:`sksurv.ensemble.ComponentwiseGradientBoostingSurvivalAnalysis` 13 | are fit with ``loss='coxph'``, `predict_cumulative_hazard_function` and 14 | `predict_survival_function` are now available. 15 | :func:`sksurv.metrics.cumulative_dynamic_auc` now supports evaluating 16 | time-dependent predictions, for instance for a :class:`sksurv.ensemble.RandomSurvivalForest` 17 | as illustrated in the 18 | :ref:`User Guide `.
19 | 20 | Bug fixes 21 | ^^^^^^^^^ 22 | - Allow passing pandas data frames to all ``fit`` and ``predict`` methods (#148). 23 | - Allow sparse matrices to be passed to 24 | :meth:`sksurv.ensemble.GradientBoostingSurvivalAnalysis.predict`. 25 | - Fix example in user guide using GridSearchCV to determine alphas for CoxnetSurvivalAnalysis (#186). 26 | 27 | Enhancements 28 | ^^^^^^^^^^^^ 29 | - Add score method to :class:`sksurv.meta.Stacking`, 30 | :class:`sksurv.meta.EnsembleSelection`, and 31 | :class:`sksurv.meta.EnsembleSelectionRegressor` (#151). 32 | - Add support for `predict_cumulative_hazard_function` and 33 | `predict_survival_function` to :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis` 34 | and :class:`sksurv.ensemble.ComponentwiseGradientBoostingSurvivalAnalysis` 35 | if model was fit with ``loss='coxph'``. 36 | - Add support for time-dependent predictions to :func:`sksurv.metrics.cumulative_dynamic_auc`. 37 | See the :ref:`User Guide ` 38 | for an example (#134). 39 | 40 | Backwards incompatible changes 41 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 42 | - The score method of :class:`sksurv.linear_model.IPCRidge`, 43 | :class:`sksurv.svm.FastSurvivalSVM`, and :class:`sksurv.svm.FastKernelSurvivalSVM` 44 | (if ``rank_ratio`` is smaller than 1) now converts predictions on log(time) scale 45 | to risk scores prior to computing the concordance index. 46 | - Support for cvxpy and cvxopt solver in :class:`sksurv.svm.MinlipSurvivalAnalysis` 47 | and :class:`sksurv.svm.HingeLossSurvivalSVM` has been dropped. The default solver 48 | is now ECOS, which was used by cvxpy (the previous default) internally. Therefore, 49 | results should be identical. 50 | - Dropped the ``presort`` argument from :class:`sksurv.tree.SurvivalTree` 51 | and :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis`. 52 | - The ``X_idx_sorted`` argument in :meth:`sksurv.tree.SurvivalTree.fit` 53 | has been deprecated in scikit-learn 0.24 and has no effect now.
54 | - `predict_cumulative_hazard_function` and 55 | `predict_survival_function` of :class:`sksurv.ensemble.RandomSurvivalForest` 56 | and :class:`sksurv.tree.SurvivalTree` now return an array of 57 | :class:`sksurv.functions.StepFunction` objects by default. 58 | Use ``return_array=True`` to get the old behavior. 59 | - Support for Python 3.6 has been dropped. 60 | - Increase minimum supported versions of dependencies. We now require: 61 | 62 | +--------------+-----------------+ 63 | | Package | Minimum Version | 64 | +==============+=================+ 65 | | Pandas | 0.25.0 | 66 | +--------------+-----------------+ 67 | | scikit-learn | 0.24.0 | 68 | +--------------+-----------------+ 69 | -------------------------------------------------------------------------------- /doc/release_notes/v0.16.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_16: 2 | 3 | What's new in 0.16 4 | ================== 5 | 6 | scikit-survival 0.16.0 (2021-10-30) 7 | ----------------------------------- 8 | 9 | This release adds support for changing the evaluation metric that 10 | is used in estimators' ``score`` method. This is particular useful 11 | for hyper-parameter optimization using scikit-learn's ``GridSearchCV``. 12 | You can now use :class:`sksurv.metrics.as_concordance_index_ipcw_scorer`, 13 | :class:`sksurv.metrics.as_cumulative_dynamic_auc_scorer`, or 14 | :class:`sksurv.metrics.as_integrated_brier_score_scorer` to adjust the 15 | ``score`` method to your needs. A detailed example is available in the 16 | :ref:`User Guide `. 17 | 18 | Moreover, this release adds :class:`sksurv.ensemble.ExtraSurvivalTrees` 19 | to fit an ensemble of randomized survival trees, and improves the speed 20 | of :func:`sksurv.compare.compare_survival` significantly. 21 | The documentation has been extended by a section on 22 | the :ref:`time-dependent Brier score `. 
23 | 24 | Bug fixes 25 | ^^^^^^^^^ 26 | - Columns are dropped in :func:`sksurv.column.encode_categorical` 27 | despite ``allow_drop=False`` (:issue:`199`). 28 | - Ensure :func:`sksurv.column.categorical_to_numeric` always 29 | returns series with int64 dtype. 30 | 31 | Enhancements 32 | ^^^^^^^^^^^^ 33 | - Add :class:`sksurv.ensemble.ExtraSurvivalTrees` ensemble (:issue:`195`). 34 | - Faster speed for :func:`sksurv.compare.compare_survival` (:issue:`215`). 35 | - Add wrapper classes :class:`sksurv.metrics.as_concordance_index_ipcw_scorer`, 36 | :class:`sksurv.metrics.as_cumulative_dynamic_auc_scorer`, and 37 | :class:`sksurv.metrics.as_integrated_brier_score_scorer` to override the 38 | default ``score`` method of estimators (:issue:`192`). 39 | - Remove use of deprecated numpy dtypes. 40 | - Remove use of ``inplace`` in pandas' ``set_categories``. 41 | 42 | Documentation 43 | ^^^^^^^^^^^^^ 44 | - Remove comments and code suggesting log-transforming times prior to training Survival SVM (:issue:`203`). 45 | - Add documentation for ``max_samples`` parameter to :class:`sksurv.ensemble.ExtraSurvivalTrees` 46 | and :class:`sksurv.ensemble.RandomSurvivalForest` (:issue:`217`). 47 | - Add section on time-dependent Brier score (:issue:`220`). 48 | - Add section on using alternative metrics for hyper-parameter optimization. 49 | -------------------------------------------------------------------------------- /doc/release_notes/v0.17.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_17: 2 | 3 | What's new in 0.17 4 | ================== 5 | 6 | scikit-survival 0.17.2 (2022-04-24) 7 | ----------------------------------- 8 | 9 | This release fixes several issues with packaging scikit-survival. 10 | 11 | Bug fixes 12 | ^^^^^^^^^ 13 | - Added backward support for gcc-c++ (:issue:`255`). 14 | - Do not install C/C++ and Cython source files. 15 | - Add ``packaging`` to build requirements in ``pyproject.toml``. 
16 | - Exclude generated API docs from source distribution. 17 | - Add Python 3.10 to classifiers. 18 | 19 | Documentation 20 | ^^^^^^^^^^^^^ 21 | - Use `permutation_importance `_ 22 | from sklearn instead of eli5. 23 | - Build documentation with Sphinx 4.4.0. 24 | - Fix missing documentation for classes in ``sksurv.meta``. 25 | 26 | 27 | scikit-survival 0.17.1 (2022-03-05) 28 | ----------------------------------- 29 | 30 | This release adds support for Python 3.10. 31 | 32 | 33 | scikit-survival 0.17.0 (2022-01-09) 34 | ----------------------------------- 35 | 36 | This release adds support for scikit-learn 1.0, which includes 37 | support for feature names. 38 | If you pass a pandas dataframe to ``fit``, the estimator will 39 | set a `feature_names_in_` attribute containing the feature names. 40 | When a dataframe is passed to ``predict``, it is checked that the 41 | column names are consistent with those passed to ``fit``. See the 42 | `scikit-learn release highlights `_ 43 | for details. 44 | 45 | Bug fixes 46 | ^^^^^^^^^ 47 | - Fix a variety of build problems with LLVM (:issue:`243`). 48 | 49 | Enhancements 50 | ^^^^^^^^^^^^ 51 | - Add support for ``feature_names_in_`` and ``n_features_in_`` 52 | to all estimators and transforms. 53 | - Add :meth:`sksurv.preprocessing.OneHotEncoder.get_feature_names_out`. 54 | - Update bundled version of Eigen to 3.3.9. 55 | 56 | Backwards incompatible changes 57 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 58 | - Drop ``min_impurity_split`` parameter from 59 | :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis`. 60 | - ``base_estimators`` and ``meta_estimator`` attributes of 61 | :class:`sksurv.meta.Stacking` do not contain fitted models anymore, 62 | use ``estimators_`` and ``final_estimator_``, respectively. 63 | 64 | Deprecations 65 | ^^^^^^^^^^^^ 66 | - The ``normalize`` parameter of :class:`sksurv.linear_model.IPCRidge` 67 | is deprecated and will be removed in a future version. 
Instead, use 68 | a scikit-learn pipeline: 69 | ``make_pipeline(StandardScaler(with_mean=False), IPCRidge())``. 70 | -------------------------------------------------------------------------------- /doc/release_notes/v0.18.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_18: 2 | 3 | What's new in 0.18 4 | ================== 5 | 6 | scikit-survival 0.18.0 (2022-08-15) 7 | ----------------------------------- 8 | 9 | This release adds support for scikit-learn 1.1, which 10 | includes more informative error messages. 11 | Support for Python 3.7 has been dropped, and 12 | the minimum supported versions of dependencies are updated to 13 | 14 | +--------------+-----------------+ 15 | | Package | Minimum Version | 16 | +==============+=================+ 17 | | numpy | 1.17.3 | 18 | +--------------+-----------------+ 19 | | Pandas | 1.0.5 | 20 | +--------------+-----------------+ 21 | | scikit-learn | 1.1.0 | 22 | +--------------+-----------------+ 23 | | scipy | 1.3.2 | 24 | +--------------+-----------------+ 25 | 26 | Enhancements 27 | ^^^^^^^^^^^^ 28 | - Add ``n_iter_`` attribute to all estimators in :ref:`sksurv.svm ` (:issue:`277`). 29 | - Add ``return_array`` argument to all models providing 30 | ``predict_survival_function`` and ``predict_cumulative_hazard_function`` 31 | (:issue:`268`). 32 | 33 | Deprecations 34 | ^^^^^^^^^^^^ 35 | - The ``loss_`` attribute of :class:`sksurv.ensemble.ComponentwiseGradientBoostingSurvivalAnalysis` 36 | and :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis` 37 | has been deprecated. 38 | - The default for the ``max_features`` argument has been changed 39 | from ``'auto'`` to ``'sqrt'`` for :class:`sksurv.ensemble.RandomSurvivalForest` 40 | and :class:`sksurv.ensemble.ExtraSurvivalTrees`. ``'auto'`` and ``'sqrt'`` 41 | have the same effect. 
42 | -------------------------------------------------------------------------------- /doc/release_notes/v0.19.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_19: 2 | 3 | What's new in 0.19 4 | ================== 5 | 6 | scikit-survival 0.19.0 (2022-10-23) 7 | ----------------------------------- 8 | 9 | This release adds :meth:`sksurv.tree.SurvivalTree.apply` and 10 | :meth:`sksurv.tree.SurvivalTree.decision_path`, and support 11 | for sparse matrices to :class:`sksurv.tree.SurvivalTree`. 12 | Moreover, it fixes build issues with scikit-learn 1.1.2 13 | and on macOS with ARM64 CPU. 14 | 15 | Bug fixes 16 | ^^^^^^^^^ 17 | - Fix build issue with scikit-learn 1.1.2, which is binary-incompatible with 18 | previous releases from the 1.1 series. 19 | - Fix build from source on macOS with ARM64 by specifying numpy 1.21.0 as install 20 | requirement for that platform (:issue:`313`). 21 | 22 | Enhancements 23 | ^^^^^^^^^^^^ 24 | - :class:`sksurv.tree.SurvivalTree`: Add :meth:`sksurv.tree.SurvivalTree.apply` and 25 | :meth:`sksurv.tree.SurvivalTree.decision_path` (:issue:`290`). 26 | - :class:`sksurv.tree.SurvivalTree`: Add support for sparse matrices (:issue:`290`). 27 | -------------------------------------------------------------------------------- /doc/release_notes/v0.2.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_2: 2 | 3 | What's new in 0.2 4 | ================= 5 | 6 | scikit-survival 0.2 (2017-05-29) 7 | -------------------------------- 8 | 9 | This release adds support for Python 3.6, and pandas 0.19 and 0.20. 10 | -------------------------------------------------------------------------------- /doc/release_notes/v0.20.rst: -------------------------------------------------------------------------------- 1 | .. 
_release_notes_0_20: 2 | 3 | What's new in 0.20 4 | ================== 5 | 6 | scikit-survival 0.20.0 (2023-03-05) 7 | ----------------------------------- 8 | 9 | This release adds support for scikit-learn 1.2 and drops support for previous versions. 10 | 11 | Enhancements 12 | ^^^^^^^^^^^^ 13 | - Raise more informative error messages when a parameter does 14 | not have a valid type/value (see 15 | `sklearn#23462 `_). 16 | - Add ``positive`` and ``random_state`` parameters to :class:`sksurv.linear_model.IPCRidge`. 17 | 18 | Documentation 19 | ^^^^^^^^^^^^^ 20 | - Update API docs based on scikit-learn 1.2 (where applicable). 21 | 22 | Backwards incompatible changes 23 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 24 | - To align with the scikit-learn API, many parameters of estimators must be 25 | provided with their names, as keyword arguments, instead of positional arguments. 26 | - Remove deprecated ``normalize`` parameter from :class:`sksurv.linear_model.IPCRidge`. 27 | - Remove deprecated ``X_idx_sorted`` argument from :meth:`sksurv.tree.SurvivalTree.fit`. 28 | - Setting ``kernel="polynomial"`` in :class:`sksurv.svm.FastKernelSurvivalSVM`, 29 | :class:`sksurv.svm.HingeLossSurvivalSVM`, and :class:`sksurv.svm.MinlipSurvivalAnalysis` 30 | has been replaced with ``kernel="poly"``. 31 | -------------------------------------------------------------------------------- /doc/release_notes/v0.21.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_21: 2 | 3 | What's new in 0.21 4 | ================== 5 | 6 | scikit-survival 0.21.0 (2023-06-11) 7 | ----------------------------------- 8 | 9 | This is a major release bringing new features and performance improvements. 10 | 11 | - :func:`sksurv.nonparametric.kaplan_meier_estimator` can estimate 12 | pointwise confidence intervals by specifying the `conf_type` parameter. 
13 | - :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis` supports 14 | early-stopping via the `monitor` parameter of 15 | :meth:`sksurv.ensemble.GradientBoostingSurvivalAnalysis.fit`. 16 | - :func:`sksurv.metrics.concordance_index_censored` has a significantly 17 | reduced memory footprint. Memory usage now scales linear, instead of quadratic, 18 | in the number of samples. 19 | - Fitting of :class:`sksurv.tree.SurvivalTree`, 20 | :class:`sksurv.ensemble.RandomSurvivalForest`, or :class:`sksurv.ensemble.ExtraSurvivalTrees` 21 | is about 3x faster. 22 | - Finally, the release adds support for Python 3.11 and pandas 2.0. 23 | 24 | Bug fixes 25 | ^^^^^^^^^ 26 | - Fix bug where `times` passed to :func:`sksurv.metrics.brier_score` 27 | was downcast, resulting in a loss of precision that may lead 28 | to duplicate time points (:issue:`349`). 29 | - Fix inconsistent behavior of evaluating functions returned by 30 | `predict_cumulative_hazard_function` or `predict_survival_function` 31 | (:issue:`375`). 32 | 33 | Enhancements 34 | ^^^^^^^^^^^^ 35 | - :func:`sksurv.nonparametric.kaplan_meier_estimator` 36 | and :class:`sksurv.nonparametric.CensoringDistributionEstimator` 37 | support returning confidence intervals by specifying the `conf_type` 38 | parameter (:issue:`348`). 39 | - Configure package via pyproject.toml (:issue:`347`). 40 | - Add support for Python 3.11 (:issue:`350`). 41 | - Add support for early-stopping to 42 | :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis` 43 | (:issue:`354`). 44 | - Do not use deprecated `pkg_resources` API (:issue:`353`). 45 | - Significantly reduce memory usage of :func:`sksurv.metrics.concordance_index_censored` 46 | (:issue:`362`). 47 | - Set `criterion` attribute in :class:`sksurv.tree.SurvivalTree` 48 | such that :func:`sklearn.tree.plot_tree` can be used (:issue:`366`). 
49 | - Significantly improve speed to fit a :class:`sksurv.tree.SurvivalTree`, 50 | :class:`sksurv.ensemble.RandomSurvivalForest`, or :class:`sksurv.ensemble.ExtraSurvivalTrees` 51 | (:issue:`371`). 52 | - Expose ``_predict_risk_score`` attribute in :class:`sklearn.pipeline.Pipeline` 53 | if the final estimator of the pipeline has such property (:issue:`374`). 54 | - Add support for pandas 2.0 (:issue:`373`). 55 | 56 | Documentation 57 | ^^^^^^^^^^^^^ 58 | - Fix wrong number of selected features in the guide 59 | :ref:`Introduction to Survival Analysis ` 60 | (:issue:`345`). 61 | - Fix broken links with nbsphinx 0.9.2 (:issue:`367`). 62 | 63 | Backwards incompatible changes 64 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 65 | - The attribute ``event_times_`` of estimators has been replaced by ``unique_times_`` 66 | to clarify that these are all the unique times points, not just the once where 67 | an event occurred (:issue:`371`). 68 | - Functions returned by `predict_cumulative_hazard_function` and `predict_survival_function` 69 | of :class:`sksurv.tree.SurvivalTree`, :class:`sksurv.ensemble.RandomSurvivalForest`, 70 | and :class:`sksurv.ensemble.ExtraSurvivalTrees` are over all unique time points 71 | passed as training data, instead of all unique time points where events occurred 72 | (:issue:`371`). 73 | - Evaluating a function returned by `predict_cumulative_hazard_function` 74 | or `predict_survival_function` will no longer raise an exception if the 75 | specified time point is smaller than the smallest time point observed 76 | during training. Instead, the value at ``StepFunction.x[0]`` will be returned 77 | (:issue:`375`). 78 | -------------------------------------------------------------------------------- /doc/release_notes/v0.22.rst: -------------------------------------------------------------------------------- 1 | .. 
_release_notes_0_22: 2 | 3 | What's new in 0.22 4 | ================== 5 | 6 | scikit-survival 0.22.2 (2023-12-30) 7 | ----------------------------------- 8 | 9 | This release adds support for Python 3.12. 10 | 11 | Bug fixes 12 | ^^^^^^^^^ 13 | - Fix invalid escape sequence in :ref:`Introduction ` of user guide. 14 | 15 | Enhancements 16 | ^^^^^^^^^^^^ 17 | - Mark Cython functions as noexcept (:issue:`418`). 18 | - Add support for Python 3.12 (:issue:`422`). 19 | - Do not use deprecated ``is_categorical_dtype()`` of Pandas API. 20 | 21 | Documentation 22 | ^^^^^^^^^^^^^ 23 | - Add section :ref:`building-cython-code` to contributing guidelines (:issue:`379`). 24 | - Improve the description of the ``estimate`` parameter in :func:`sksurv.metrics.brier_score` 25 | and :func:`sksurv.metrics.integrated_brier_score` (:issue:`424`). 26 | 27 | 28 | scikit-survival 0.22.1 (2023-10-08) 29 | ----------------------------------- 30 | 31 | Bug fixes 32 | ^^^^^^^^^ 33 | - Fix error in :meth:`sksurv.tree.SurvivalTree.fit` if ``X`` has missing values and dtype other than float32 (:issue:`412`). 34 | 35 | 36 | scikit-survival 0.22.0 (2023-10-01) 37 | ----------------------------------- 38 | 39 | This release adds support for scikit-learn 1.3, 40 | which includes :ref:`missing value support ` for 41 | :class:`sksurv.tree.SurvivalTree`. 42 | Support for previous versions of scikit-learn has been dropped. 43 | 44 | Moreover, a ``low_memory`` option has been added to :class:`sksurv.ensemble.RandomSurvivalForest`, 45 | :class:`sksurv.ensemble.ExtraSurvivalTrees`, and :class:`sksurv.tree.SurvivalTree` 46 | which reduces the memory footprint of calling ``predict``, but disables the use 47 | of ``predict_cumulative_hazard_function`` and ``predict_survival_function``. 48 | 49 | Bug fixes 50 | ^^^^^^^^^ 51 | - Fix issue where an estimator could be fit to data containing 52 | negative event times (:issue:`410`). 
53 | 54 | Enhancements 55 | ^^^^^^^^^^^^ 56 | - Expand test_stacking.py coverage w.r.t. ``predict_log_proba`` (:issue:`380`). 57 | - Add ``low_memory`` option to :class:`sksurv.ensemble.RandomSurvivalForest`, 58 | :class:`sksurv.ensemble.ExtraSurvivalTrees`, and 59 | :class:`sksurv.tree.SurvivalTree`. If set, ``predict`` computations use 60 | less memory, but ``predict_cumulative_hazard_function`` 61 | and ``predict_survival_function`` are not implemented (:issue:`369`). 62 | - Allow calling :meth:`sksurv.meta.Stacking.predict_cumulative_hazard_function` 63 | and :meth:`sksurv.meta.Stacking.predict_survival_function` 64 | if the meta estimator supports it (:issue:`388`). 65 | - Add support for missing values in :class:`sksurv.tree.SurvivalTree` based 66 | on missing value support in scikit-learn 1.3 (:issue:`405`). 67 | - Update bundled Eigen to 3.4.0. 68 | 69 | Documentation 70 | ^^^^^^^^^^^^^ 71 | - Add :attr:`sksurv.meta.Stacking.unique_times_` to API docs. 72 | - Upgrade to Sphinx 6.2.1 and pydata_sphinx_theme 0.13.3 (:issue:`390`). 73 | 74 | Backwards incompatible changes 75 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 76 | - The ``loss_`` attribute of :class:`sksurv.ensemble.ComponentwiseGradientBoostingSurvivalAnalysis` 77 | and :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis` has been removed (:issue:`402`). 78 | - Support for ``max_features='auto'`` in :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis` 79 | and :class:`sksurv.tree.SurvivalTree` has been removed (:issue:`402`). 80 | -------------------------------------------------------------------------------- /doc/release_notes/v0.23.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_23: 2 | 3 | What's new in 0.23 4 | ================== 5 | 6 | scikit-survival 0.23.1 (2024-11-04) 7 | ----------------------------------- 8 | 9 | This release adds support for Python 3.13. 
10 | The minimum required version for pandas has been bumped to pandas 1.4.0. 11 | 12 | Bug fixes 13 | ^^^^^^^^^ 14 | - Add `SurvivalAnalysisMixin` base class to :class:`sksurv.ensemble.ExtraSurvivalTrees` 15 | to enable the :meth:`sksurv.ensemble.ExtraSurvivalTrees.score` method that has been 16 | unintentionally removed in 0.23.0 (:issue:`488`). 17 | 18 | Enhancements 19 | ^^^^^^^^^^^^ 20 | - Improve performance of :func:`sksurv.metrics.concordance_index_censored` and 21 | :func:`sksurv.metrics.concordance_index_ipcw` (:issue:`465`). 22 | 23 | Backwards incompatible changes 24 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 25 | - Support for pandas versions before 1.4.0 has been dropped. 26 | 27 | 28 | scikit-survival 0.23.0 (2024-06-30) 29 | ----------------------------------- 30 | 31 | This release adds support for scikit-learn 1.4 and 1.5, which 32 | includes :ref:`missing value support ` 33 | for :class:`sksurv.ensemble.RandomSurvivalForest`. 34 | 35 | Moreover, this release fixes critical bugs. When fitting :class:`sksurv.tree.SurvivalTree`, 36 | the `sample_weight` is now correctly considered when computing the log-rank statistic 37 | for each split. This change also affects :class:`sksurv.ensemble.RandomSurvivalForest` and 38 | :class:`sksurv.ensemble.ExtraSurvivalTrees` which pass `sample_weight` to the individual 39 | trees in the ensemble. 40 | 41 | This release fixes a bug in :class:`sksurv.ensemble.ComponentwiseGradientBoostingSurvivalAnalysis` 42 | and :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis` when dropout is used. 43 | Previously, dropout was only applied starting with the third iteration, now dropout is applied 44 | in the second iteration too. 45 | 46 | Finally, this release adds compatibility with numpy 2.0 and drops support for Python 3.8. 
47 | 48 | Bug fixes 49 | ^^^^^^^^^ 50 | - Fix issue with `dropout` in :class:`sksurv.ensemble.ComponentwiseGradientBoostingSurvivalAnalysis` 51 | and :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis`, where it was only applied starting with the third iteration. 52 | - Fix LogrankCriterion in :class:`sksurv.tree.SurvivalTree` to handle `sample_weight` correctly (:issue:`464`). 53 | 54 | Enhancements 55 | ^^^^^^^^^^^^ 56 | - Fix deprecations with pandas 2.2. 57 | - Drop importlib-resources dependency. 58 | - Add support for scikit-learn 1.4 (:issue:`441`). 59 | - Add `warm_start` support to :class:`sksurv.ensemble.ComponentwiseGradientBoostingSurvivalAnalysis`. 60 | - Add missing values support to :class:`sksurv.ensemble.RandomSurvivalForest`. 61 | - Add `require_y` tag to :class:`sksurv.base.SurvivalAnalysisMixin`. 62 | - Upgrade to ruff 0.4.3 (:issue:`456`). 63 | - Add support for scikit-learn 1.5 (:issue:`461`). 64 | 65 | Documentation 66 | ^^^^^^^^^^^^^ 67 | - Fix :func:`sksurv.nonparametric.kaplan_meier_estimator` documentation to give correct default value for `conf_type` (:issue:`430`). 68 | - Fix allowed values for `kernel` in :class:`sksurv.svm.FastSurvivalSVM` and :class:`sksurv.svm.MinlipSurvivalAnalysis` (:issue:`444`). 69 | - Fix typo in API doc of :class:`sksurv.ensemble.RandomSurvivalForest` and :class:`sksurv.ensemble.ExtraSurvivalTrees` (:issue:`446`). 70 | - Fix API doc for the `criterion` parameter of :class:`sksurv.ensemble.GradientBoostingSurvivalAnalysis` (:issue:`449`). 71 | - Update links to scipy, pandas and numpy documentation. 72 | - Fix letter of hyper-parameter used in the formula for :class:`sksurv.linear_model.IPCRidge` (:issue:`454`). 73 | - Upgrade Sphinx to 7.3 and pydata-sphinx-theme to 0.15 (:issue:`455`). 74 | 75 | Backwards incompatible changes 76 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 77 | - Drop support for Python 3.8 (:issue:`427`). 
78 | -------------------------------------------------------------------------------- /doc/release_notes/v0.24.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_24: 2 | 3 | What's new in 0.24 4 | ================== 5 | 6 | scikit-survival 0.24.1 (2025-03-25) 7 | ----------------------------------- 8 | 9 | This release restricts the version of osqp to versions prior to 1.0.0. 10 | 11 | 12 | scikit-survival 0.24.0 (2025-02-24) 13 | ----------------------------------- 14 | 15 | This release adds support for scikit-learn 1.6, which includes missing-values support 16 | for :class:`sksurv.ensemble.ExtraSurvivalTrees`. 17 | Moreover, the release features :func:`sksurv.nonparametric.cumulative_incidence_competing_risks` 18 | which implements a non-parameteric estimator of the cumulative incidence function 19 | for competing risks. 20 | See the :ref:`user guide on the analysis of competing risks `. 21 | 22 | Bug fixes 23 | ^^^^^^^^^ 24 | - In the C++ code of :class:`sksurv.linear_model.CoxnetSurvivalAnalysis`, set type of ``n_alphas`` 25 | to ``VectorType::Index`` instead of ``ìnt``, because on Windows, 26 | int and Eigen's Index differ in size. 27 | - Fix printing of Python version in :func:`sksurv.show_versions`. 28 | - Give an error if ``max_sample`` is set, but ``bootstrap`` is False in 29 | :class:`sksurv.ensemble.RandomSurvivalForest` and 30 | :class:`sksurv.ensemble.ExtraSurvivalTrees`. 31 | 32 | Enhancements 33 | ^^^^^^^^^^^^ 34 | - Add :func:`sksurv.nonparametric.cumulative_incidence_competing_risks` to estimate 35 | the cumulative incidence function in the case of competing risks (:issue:`491`, :issue:`500`). 36 | - Add :func:`sksurv.datasets.load_bmt` and :func:`sksurv.datasets.load_cgvhd` which are 37 | datasets with competing risks (:issue:`491`, :issue:`500`). 38 | - Add missing-values support to :class:`sksurv.ensemble.ExtraSurvivalTrees` (:issue:`504`). 
39 | - Add ``estimators_samples_`` property to :class:`sksurv.ensemble.RandomSurvivalForest` and 40 | :class:`sksurv.ensemble.ExtraSurvivalTrees`. 41 | - Upgrade syntax to Python 3.10. 42 | - Run nbval with Python 3.12, matplotlib 3.9, and seaborn 0.13. 43 | 44 | Documentation 45 | ^^^^^^^^^^^^^ 46 | - Fix links in documentation. 47 | - Add :ref:`user guide on the analysis of competing risks `. 48 | 49 | Backwards incompatible changes 50 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 51 | - Support for scikit-learn versions before 1.6.1 has been dropped (:issue:`504`). 52 | - Support for Python versions before 3.10 has been dropped. 53 | -------------------------------------------------------------------------------- /doc/release_notes/v0.3.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_3: 2 | 3 | What's new in 0.3 4 | ================= 5 | 6 | scikit-survival 0.3 (2017-08-01) 7 | -------------------------------- 8 | 9 | This release adds :meth:`sksurv.linear_model.CoxPHSurvivalAnalysis.predict_survival_function` 10 | and :meth:`sksurv.linear_model.CoxPHSurvivalAnalysis.predict_cumulative_hazard_function`, 11 | which return the survival function and cumulative hazard function using Breslow's 12 | estimator. 13 | Moreover, it fixes a build error on Windows (:issue:`3`) 14 | and adds the :class:`sksurv.preprocessing.OneHotEncoder` class, which can be used in 15 | a `scikit-learn pipeline `_. 16 | -------------------------------------------------------------------------------- /doc/release_notes/v0.4.rst: -------------------------------------------------------------------------------- 1 | .. 
_release_notes_0_4: 2 | 3 | What's new in 0.4 4 | ================= 5 | 6 | scikit-survival 0.4 (2017-10-28) 7 | -------------------------------- 8 | 9 | This release adds :class:`sksurv.linear_model.CoxnetSurvivalAnalysis`, which implements 10 | an efficient algorithm to fit Cox's proportional hazards model with LASSO, ridge, and 11 | elastic net penalty. 12 | Moreover, it includes support for Windows with Python 3.5 and later by making the cvxopt 13 | package optional. 14 | -------------------------------------------------------------------------------- /doc/release_notes/v0.5.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_5: 2 | 3 | What's new in 0.5 4 | ================= 5 | 6 | scikit-survival 0.5 (2017-12-09) 7 | -------------------------------- 8 | 9 | This release adds support for scikit-learn 0.19 and pandas 0.21. In turn, 10 | support for older versions is dropped, namely Python 3.4, scikit-learn 0.18, 11 | and pandas 0.18. 12 | -------------------------------------------------------------------------------- /doc/release_notes/v0.6.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_6: 2 | 3 | What's new in 0.6 4 | ================= 5 | 6 | scikit-survival 0.6 (2018-10-07) 7 | -------------------------------- 8 | 9 | This release adds support for numpy 1.14 and pandas up to 0.23. 10 | In addition, the new class :class:`sksurv.util.Surv` makes it easier 11 | to construct a structured array from numpy arrays, lists, or a pandas data frame. 12 | 13 | **Changes:** 14 | 15 | - Support numpy 1.14 and pandas 0.22, 0.23 (#36). 16 | - Enable support for cvxopt with Python 3.5+ on Windows (requires cvxopt >=1.1.9). 17 | - Add `max_iter` parameter to :class:`sksurv.svm.MinlipSurvivalAnalysis` 18 | and :class:`sksurv.svm.HingeLossSurvivalSVM`. 19 | - Fix score function of :class:`sksurv.svm.NaiveSurvivalSVM` to use concordance index. 
20 | - :class:`sksurv.linear_model.CoxnetSurvivalAnalysis` now throws an exception if coefficients get too large (#47). 21 | - Add :class:`sksurv.util.Surv` class to ease constructing a structured array (#26). 22 | -------------------------------------------------------------------------------- /doc/release_notes/v0.7.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_7: 2 | 3 | What's new in 0.7 4 | ================= 5 | 6 | scikit-survival 0.7 (2019-02-27) 7 | -------------------------------- 8 | 9 | This release adds support for Python 3.7 and sklearn 0.20. 10 | 11 | **Changes:** 12 | 13 | - Add support for sklearn 0.20 (#48). 14 | - Migrate to py.test (#50). 15 | - Explicitly request ECOS solver for :class:`sksurv.svm.MinlipSurvivalAnalysis` 16 | and :class:`sksurv.svm.HingeLossSurvivalSVM`. 17 | - Add support for Python 3.7 (#49). 18 | - Add support for cvxpy >=1.0. 19 | - Add support for numpy 1.15. 20 | -------------------------------------------------------------------------------- /doc/release_notes/v0.8.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_8: 2 | 3 | What's new in 0.8 4 | ================= 5 | 6 | scikit-survival 0.8 (2019-05-01) 7 | -------------------------------- 8 | 9 | Enhancements 10 | ^^^^^^^^^^^^ 11 | 12 | - Add :meth:`sksurv.linear_model.CoxnetSurvivalAnalysis.predict_survival_function` 13 | and :meth:`sksurv.linear_model.CoxnetSurvivalAnalysis.predict_cumulative_hazard_function` 14 | (#46). 15 | - Add :class:`sksurv.nonparametric.SurvivalFunctionEstimator` 16 | and :class:`sksurv.nonparametric.CensoringDistributionEstimator` that 17 | wrap :func:`sksurv.nonparametric.kaplan_meier_estimator` and provide 18 | a `predict_proba` method for evaluating the estimated function on 19 | test data. 20 | - Implement censoring-adjusted C-statistic proposed by Uno et al. 
(2011) 21 | in :func:`sksurv.metrics.concordance_index_ipcw`. 22 | - Add estimator of cumulative/dynamic AUC of Uno et al. (2007) 23 | in :func:`sksurv.metrics.cumulative_dynamic_auc`. 24 | - Add flchain dataset (see :func:`sksurv.datasets.load_flchain`). 25 | 26 | Bug fixes 27 | ^^^^^^^^^ 28 | 29 | - The `tied_time` return value of :func:`sksurv.metrics.concordance_index_censored` 30 | now correctly reflects the number of comparable pairs that share the same time 31 | and that are used in computing the concordance index. 32 | - Fix a bug in :func:`sksurv.metrics.concordance_index_censored` where a 33 | pair with risk estimates within tolerance was counted both as 34 | concordant and tied. 35 | -------------------------------------------------------------------------------- /doc/release_notes/v0.9.rst: -------------------------------------------------------------------------------- 1 | .. _release_notes_0_9: 2 | 3 | What's new in 0.9 4 | ================= 5 | 6 | scikit-survival 0.9 (2019-07-26) 7 | -------------------------------- 8 | 9 | This release adds support for sklearn 0.21 and pandas 0.24. 10 | 11 | Enhancements 12 | ^^^^^^^^^^^^ 13 | 14 | - Add reference to IPCRidge (#65). 15 | - Use scipy.special.comb instead of deprecated scipy.misc.comb. 16 | - Add support for pandas 0.24 and drop support for 0.20. 17 | - Add support for scikit-learn 0.21 and drop support for 0.20 (#71). 18 | - Explain use of intercept in ComponentwiseGradientBoostingSurvivalAnalysis (#68) 19 | - Bump Eigen to 3.3.7. 20 | 21 | Bug fixes 22 | ^^^^^^^^^ 23 | - Disallow scipy 1.3.0 due to scipy regression (#66). 
24 | -------------------------------------------------------------------------------- /doc/spelling_wordlist.txt: -------------------------------------------------------------------------------- 1 | arff 2 | biomarker 3 | Breiman 4 | Breslow 5 | boolean 6 | callables 7 | Covariance 8 | covariance 9 | covariates 10 | cvxpy 11 | cvxopt 12 | Cython 13 | DataFrame 14 | dataframe 15 | dataset 16 | datasets 17 | Deprecations 18 | dimensionality 19 | discriminative 20 | dtype 21 | dtypes 22 | Efron 23 | Eigen 24 | filename 25 | hyperplane 26 | infeasible 27 | Kaplan-Meier 28 | Lipschitz 29 | macOS 30 | mae 31 | mse 32 | ndarray 33 | Nelson-Aalen 34 | numpy 35 | overfitting 36 | params 37 | parameterized 38 | precomputed 39 | Pre-Processing 40 | reStructuredText 41 | readthedocs 42 | scikit 43 | scikit-learn 44 | scikit-survival 45 | scipy 46 | sigmoid 47 | sklearn 48 | sksurv 49 | sqrt 50 | stagewise 51 | subobjects 52 | subsampling 53 | tol 54 | Uno 55 | unpenalized -------------------------------------------------------------------------------- /doc/user_guide/index.rst: -------------------------------------------------------------------------------- 1 | .. _User Guide: 2 | 3 | User Guide 4 | ========== 5 | 6 | The User Guide covers the most important aspects of doing to survival analysis with scikit-survival. 7 | 8 | It is assumed that users have a basic understanding of survival analysis. If you are brand-new to survival 9 | analysis, consider studying the basics first, e.g. by reading an introductory book, such as 10 | 11 | * David G. Kleinbaum and Mitchel Klein (2012), Survival Analysis: A Self-Learning Text, Springer. 12 | * John P. Klein and Melvin L. Moeschberger (2003), Survival Analysis: Techniques for Censored and Truncated Data, Springer. 13 | 14 | Users new to scikit-survival should read :ref:`understanding_predictions` to get familiar with the basic concepts. 
from importlib.metadata import PackageNotFoundError, version
import platform
import sys

from sklearn.pipeline import Pipeline, _final_estimator_has
from sklearn.utils.metaestimators import available_if

from .util import property_available_if


def _get_version(name):
    """Return the installed version string of distribution `name`.

    Returns None when the distribution is not installed.
    """
    try:
        pkg_version = version(name)
    except PackageNotFoundError:
        # version() raises PackageNotFoundError for a missing distribution.
        # Catching it (instead of the broader ImportError) keeps unrelated
        # import problems visible instead of silently reporting None.
        pkg_version = None
    return pkg_version


def show_versions():
    """Print system information and versions of scikit-survival's dependencies.

    Intended to be included in bug reports.
    """
    sys_info = {
        "Platform": platform.platform(),
        "Python version": f"{platform.python_implementation()} {platform.python_version()}",
        "Python interpreter": sys.executable,
    }

    deps = [
        "scikit-survival",
        "scikit-learn",
        "numpy",
        "scipy",
        "pandas",
        "numexpr",
        "ecos",
        "osqp",
        "joblib",
        "matplotlib",
        "pytest",
        "sphinx",
        "Cython",
        "pip",
        "setuptools",
    ]
    # Align both tables on the longest name.
    minwidth = max(
        max(map(len, deps)),
        max(map(len, sys_info.keys())),
    )
    fmt = f"{{0:<{minwidth}s}}: {{1}}"

    print("SYSTEM")
    print("------")
    for name, version_string in sys_info.items():
        print(fmt.format(name, version_string))

    print("\nDEPENDENCIES")
    print("------------")
    for dep in deps:
        version_string = _get_version(dep)
        print(fmt.format(dep, version_string))


@available_if(_final_estimator_has("predict_cumulative_hazard_function"))
def predict_cumulative_hazard_function(self, X, **kwargs):
    """Predict cumulative hazard function.

    The cumulative hazard function for an individual
    with feature vector :math:`x` is defined as

    .. math::

        H(t \\mid x) = \\exp(x^\\top \\beta) H_0(t) ,

    where :math:`H_0(t)` is the baseline hazard function,
    estimated by Breslow's estimator.

    Parameters
    ----------
    X : array-like, shape = (n_samples, n_features)
        Data matrix.

    Returns
    -------
    cum_hazard : ndarray, shape = (n_samples,)
        Predicted cumulative hazard functions.
    """
    # Apply all transformers of the pipeline, then delegate to the
    # final estimator's own implementation.
    Xt = X
    for _, _, transform in self._iter(with_final=False):
        Xt = transform.transform(Xt)
    return self.steps[-1][-1].predict_cumulative_hazard_function(Xt, **kwargs)


@available_if(_final_estimator_has("predict_survival_function"))
def predict_survival_function(self, X, **kwargs):
    """Predict survival function.

    The survival function for an individual
    with feature vector :math:`x` is defined as

    .. math::

        S(t \\mid x) = S_0(t)^{\\exp(x^\\top \\beta)} ,

    where :math:`S_0(t)` is the baseline survival function,
    estimated by Breslow's estimator.

    Parameters
    ----------
    X : array-like, shape = (n_samples, n_features)
        Data matrix.

    Returns
    -------
    survival : ndarray, shape = (n_samples,)
        Predicted survival functions.
    """
    # Apply all transformers of the pipeline, then delegate to the
    # final estimator's own implementation.
    Xt = X
    for _, _, transform in self._iter(with_final=False):
        Xt = transform.transform(Xt)
    return self.steps[-1][-1].predict_survival_function(Xt, **kwargs)


@property_available_if(_final_estimator_has("_predict_risk_score"))
def _predict_risk_score(self):
    # Expose the final estimator's flag on the pipeline itself.
    return self.steps[-1][-1]._predict_risk_score


def patch_pipeline():
    """Attach survival-specific prediction methods to sklearn's Pipeline.

    The methods are only available if the final estimator of the
    pipeline provides them (enforced by ``available_if`` above).
    """
    Pipeline.predict_survival_function = predict_survival_function
    Pipeline.predict_cumulative_hazard_function = predict_cumulative_hazard_function
    Pipeline._predict_risk_score = _predict_risk_score


try:
    __version__ = version("scikit-survival")
except PackageNotFoundError:  # pragma: no cover
    # package is not installed
    __version__ = "unknown"

patch_pipeline()
import numpy as np


class SurvivalAnalysisMixin:
    """Mixin providing shared prediction helpers and concordance-index scoring."""

    def _predict_function(self, func_name, baseline_model, prediction, return_array):
        # Ask the baseline model for one step function per sample.
        step_functions = getattr(baseline_model, func_name)(prediction)

        if not return_array:
            return step_functions

        # Evaluate each step function on the grid of unique event times,
        # collecting the results row by row into a float matrix.
        grid = baseline_model.unique_times_
        values = np.empty((prediction.shape[0], grid.shape[0]), dtype=float)
        for row, fn in enumerate(step_functions):
            values[row] = fn(grid)
        return values

    def _predict_survival_function(self, baseline_model, prediction, return_array):
        """Return survival functions.

        Parameters
        ----------
        baseline_model : sksurv.linear_model.coxph.BreslowEstimator
            Estimator of baseline survival function.

        prediction : array-like, shape=(n_samples,)
            Predicted risk scores.

        return_array : bool
            If True, return a float array of the survival function
            evaluated at the unique event times, otherwise return
            an array of :class:`sksurv.functions.StepFunction` instances.

        Returns
        -------
        survival : ndarray
        """
        return self._predict_function("get_survival_function", baseline_model, prediction, return_array)

    def _predict_cumulative_hazard_function(self, baseline_model, prediction, return_array):
        """Return cumulative hazard functions.

        Parameters
        ----------
        baseline_model : sksurv.linear_model.coxph.BreslowEstimator
            Estimator of baseline cumulative hazard function.

        prediction : array-like, shape=(n_samples,)
            Predicted risk scores.

        return_array : bool
            If True, return a float array of the cumulative hazard function
            evaluated at the unique event times, otherwise return
            an array of :class:`sksurv.functions.StepFunction` instances.

        Returns
        -------
        cum_hazard : ndarray
        """
        return self._predict_function("get_cumulative_hazard_function", baseline_model, prediction, return_array)

    def score(self, X, y):
        """Returns the concordance index of the prediction.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Test samples.

        y : structured array, shape = (n_samples,)
            A structured array containing the binary event indicator
            as first field, and time of event or time of censoring as
            second field.

        Returns
        -------
        cindex : float
            Estimated concordance index.
        """
        from .metrics import concordance_index_censored

        event_field, time_field = y.dtype.names

        estimate = self.predict(X)
        # Models predicting on the time scale (larger = longer survival)
        # advertise _predict_risk_score = False; negate to get a risk score.
        if not getattr(self, "_predict_risk_score", True):
            estimate *= -1

        cindex, *_ = concordance_index_censored(y[event_field], y[time_field], estimate)
        return cindex

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        # Survival estimators cannot be fit without the structured outcome y.
        tags.target_tags.required = True
        return tags
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
cimport cython
from libcpp cimport bool
from libcpp.cast cimport dynamic_cast


# Declarations mirroring binarytrees.h: an order-statistic red-black tree
# plus AVL and AA variants that share the rbtree interface.
cdef extern from "binarytrees.h":
    cdef cppclass rbtree:
        rbtree(int l)
        void insert_node(double key, double value)
        void count_larger(double key, int *count_ret, double *acc_value_ret)
        void count_smaller(double key, int *count_ret, double *acc_value_ret)
        double vector_sum_larger(double key)
        double vector_sum_smaller(double key)
        int get_size()

    cdef cppclass avl(rbtree):
        avl(int l)

    cdef cppclass aatree(rbtree):
        aatree(int l)


# Pointer alias so dynamic_cast[...] below can name the target type.
ctypedef rbtree* rbtree_ptr


cdef class BaseTree:
    """Python wrapper owning a heap-allocated C++ rbtree (or subclass).

    Subclasses' ``__cinit__`` must allocate ``treeptr``; this base class
    frees it in ``__dealloc__``.
    """

    # Owned pointer to the underlying C++ tree; freed exactly once.
    cdef rbtree_ptr treeptr

    def __dealloc__(self):
        if self.treeptr is not NULL:
            del self.treeptr
            # Reset to NULL to guard against double free.
            self.treeptr = NULL

    def __len__(self):
        # Number of nodes currently stored in the tree.
        return self.treeptr.get_size()

    def insert(self, double key, double value):
        """Insert a (key, value) pair into the tree."""
        self.treeptr.insert_node(key, value)

    def count_smaller(self, double key):
        """Return (count, accumulated value) over entries with smaller key.

        Exact strictness of the comparison is defined by the C++
        implementation in binarytrees.cpp.
        """
        cdef int count_ret;
        cdef double acc_value_ret;

        self.treeptr.count_smaller(key, &count_ret, &acc_value_ret)

        return count_ret, acc_value_ret

    def count_larger(self, double key):
        """Return (count, accumulated value) over entries with larger key.

        Exact strictness of the comparison is defined by the C++
        implementation in binarytrees.cpp.
        """
        cdef int count_ret
        cdef double acc_value_ret

        self.treeptr.count_larger(key, &count_ret, &acc_value_ret)

        return count_ret, acc_value_ret

    def vector_sum_smaller(self, double key):
        """Return the sum of values over entries with smaller key."""
        return self.treeptr.vector_sum_smaller(key)

    def vector_sum_larger(self, double key):
        """Return the sum of values over entries with larger key."""
        return self.treeptr.vector_sum_larger(key)

    def count_larger_with_event(self, double key, bool has_event):
        """Like :meth:`count_larger`, but censored samples contribute nothing.

        If ``has_event`` is False, return ``(0, 0.0)`` without querying the tree.
        """
        if not has_event:
            return 0, 0.0
        return self.count_larger(key)


@cython.final
cdef class RBTree(BaseTree):
    """Red-black tree pre-sized for ``size`` insertions."""

    def __cinit__(self, int size):
        if size <= 0:
            raise ValueError('size must be greater zero')
        self.treeptr = new rbtree(size)

@cython.final
cdef class AVLTree(BaseTree):
    """AVL tree pre-sized for ``size`` insertions."""

    def __cinit__(self, int size):
        if size <= 0:
            raise ValueError('size must be greater zero')
        # Upcast to the base pointer type stored in BaseTree.
        self.treeptr = dynamic_cast[rbtree_ptr](new avl(size))

@cython.final
cdef class AATree(BaseTree):
    """AA tree pre-sized for ``size`` insertions."""

    def __cinit__(self, int size):
        if size <= 0:
            raise ValueError('size must be greater zero')
        # Upcast to the base pointer type stored in BaseTree.
        self.treeptr = dynamic_cast[rbtree_ptr](new aatree(size))
// IN NO EVENT SHALL THE REGENTS OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef _BINARYTREES
#define _BINARYTREES

// Node color for the red-black variant.
enum {RED,BLACK};
// Child slot indices.
enum {LEFT,RIGHT};
// Tree node augmented with subtree size and a running sum of inserted
// values, which enables rank and prefix-sum style queries.
struct node
{
	node* parent;
	node* child[2];
	double key;
	int size;           // number of nodes in the subtree rooted here
	bool color;         // RED or BLACK (red-black variant)
	int height;         // subtree height (AVL variant)
	double vector_sum;  // sum of inserted values within this subtree
};

// Order-statistic red-black tree over (key, value) pairs.
// The constructor argument l pre-allocates storage for l nodes.
class rbtree
{
public:
	rbtree(int l);
	virtual ~rbtree();
	virtual void insert_node(double key, double value);
	// Count entries with larger/smaller key and accumulate their values.
	// Strictness of the comparison is defined in binarytrees.cpp.
	void count_larger(double key, int* count_ret, double* acc_value_ret) const;
	void count_smaller(double key, int* count_ret, double* acc_value_ret) const;
	// Sum of values over entries with larger/smaller key.
	double vector_sum_larger(double key) const;
	double vector_sum_smaller(double key) const;
	int get_size() const { return tree_size;}
protected:
	node* null_node;   // sentinel leaf
	int tree_size;     // number of inserted nodes
	void rotate(node* x, int direction);
	// Rebalancing hook; overridden by subclasses with their own scheme.
	virtual void tree_color_fix(node* x);
	node* root;
	node* tree_nodes;  // pre-allocated node pool of length l
};


// AVL tree: rebalances via height bookkeeping in insert_node.
class avl: public rbtree
{
public:
	avl(int l);
	virtual void insert_node(double key, double value);
private:
	void tree_balance_fix(node* x);
};

// AA tree: reuses rbtree insertion but with its own color fix-up.
class aatree: public rbtree
{
public:
	aatree(int l):rbtree(l){};
protected:
	virtual void tree_color_fix(node* x);
};

#endif
from collections import OrderedDict

import numpy as np
import pandas as pd
from scipy import stats
from sklearn.utils.validation import check_array

from .util import check_array_survival

__all__ = ["compare_survival"]


def compare_survival(y, group_indicator, return_stats=False):
    """K-sample log-rank hypothesis test of identical survival functions.

    Compares the pooled hazard rate with each group-specific
    hazard rate. The alternative hypothesis is that the hazard
    rate of at least one group differs from the others at some time.

    See [1]_ for more details.

    Parameters
    ----------
    y : structured array, shape = (n_samples,)
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    group_indicator : array-like, shape = (n_samples,)
        Group membership of each sample.

    return_stats : bool, optional, default: False
        Whether to return a data frame with statistics for each group
        and the covariance matrix of the test statistic.

    Returns
    -------
    chisq : float
        Test statistic.
    pvalue : float
        Two-sided p-value with respect to the null hypothesis
        that the hazard rates across all groups are equal.
    stats : pandas.DataFrame
        Summary statistics for each group: number of samples,
        observed number of events, expected number of events,
        and test statistic.
        Only provided if `return_stats` is True.
    covariance : array, shape=(n_groups, n_groups)
        Covariance matrix of the test statistic.
        Only provided if `return_stats` is True.

    References
    ----------
    .. [1] Fleming, T. R. and Harrington, D. P.
           A Class of Hypothesis Tests for One and Two Samples of Censored Survival Data.
           Communications In Statistics 10 (1981): 763-794.
    """

    # group_indicator takes the place of X here; check_array_survival
    # validates y and — presumably — that lengths match (see sksurv.util).
    event, time = check_array_survival(group_indicator, y)
    group_indicator = check_array(
        group_indicator,
        dtype="O",
        ensure_2d=False,
        estimator="compare_survival",
        input_name="group_indicator",
    )

    n_samples = time.shape[0]
    groups, group_counts = np.unique(group_indicator, return_counts=True)
    n_groups = groups.shape[0]
    if n_groups == 1:
        raise ValueError("At least two groups must be specified, but only one was provided.")

    # sort descending: after processing all samples with time >= t, the
    # loop index k equals the size of the risk set at t.
    o = np.argsort(-time, kind="mergesort")
    x = group_indicator[o]
    event = event[o]
    time = time[o]

    # Per-group accumulators over all distinct event times.
    at_risk = np.zeros(n_groups, dtype=int)
    observed = np.zeros(n_groups, dtype=int)
    expected = np.zeros(n_groups, dtype=float)
    covar = np.zeros((n_groups, n_groups), dtype=float)

    covar_indices = np.diag_indices(n_groups)

    k = 0
    while k < n_samples:
        ti = time[k]
        total_events = 0
        # Inner loop consumes all samples tied at time ti, tallying
        # per-group events and growing the per-group risk sets.
        while k < n_samples and ti == time[k]:
            # groups is sorted (np.unique), so searchsorted maps the
            # group label to its index.
            idx = np.searchsorted(groups, x[k])
            if event[k]:
                observed[idx] += 1
                total_events += 1
            at_risk[idx] += 1
            k += 1

        if total_events != 0:
            # Because of the descending sort, k is the total number at risk at ti.
            total_at_risk = k
            # Expected events per group under the null: n_j * d / n.
            expected += at_risk * (total_events / total_at_risk)
            if total_at_risk > 1:
                # Hypergeometric variance/covariance contribution of this
                # event time to the vector of log-rank statistics.
                multiplier = total_events * (total_at_risk - total_events) / (total_at_risk * (total_at_risk - 1))
                temp = at_risk * multiplier
                covar[covar_indices] += temp
                covar -= np.outer(temp, at_risk) / total_at_risk

    # The statistic components sum to zero, so the full covariance matrix is
    # singular; drop the last group before solving.
    df = n_groups - 1
    zz = observed[:df] - expected[:df]
    chisq = np.linalg.solve(covar[:df, :df], zz).dot(zz)
    pval = stats.chi2.sf(chisq, df)

    if return_stats:
        table = OrderedDict()
        table["counts"] = group_counts
        table["observed"] = observed
        table["expected"] = expected
        table["statistic"] = observed - expected
        table = pd.DataFrame.from_dict(table)
        table.index = pd.Index(groups, name="group", dtype=groups.dtype)
        return chisq, pval, table, covar

    return chisq, pval
5 | 6 | | Dataset | Description | Samples | Features | Events | Outcome | 7 | |-------------------------------|------------------------------------------------------|---------|----------|--------------|------------------------------| 8 | | actg320_aids or actg320_death | [AIDS study][Hosmer2008] | 1,151 | 11 | 96 (8.3%) | AIDS defining event or death | 9 | | breast-cancer | [Breast cancer][Desmedt2007] | 198 | 80 | 51 (25.8%) | Distant metastases | 10 | | flchain | [Assay of serum free light chain][Dispenzieri2012] | 7874 | 9 | 2169 (27.5%) | Death | 11 | | GBSG2 | [German Breast Cancer Study Group 2][Schumacher1994] | 686 | 8 | 299 (43.6%) | Recurrence free survival | 12 | | veteran | [Veteran's Lung Cancer][Kalbfleisch2008] | 137 | 6 | 128 (93.4%) | Death | 13 | | whas500 | [Worcester Heart Attack Study][Hosmer2008] | 500 | 14 | 215 (43.0%) | Death | 14 | | BMT | [Leukemia HSC Transplant][Scrucca2007] | 35 | 1 | 24 (68.6%) | Transplant-related death or relapse | 15 | | CGVHD | [CGVHD][Pintilie2006] | 100 | 4 | 96 (96%) | Chronic graft versus host disease (CGVHD), relapse or death | 16 | 17 | [Desmedt2007]: http://dx.doi.org/10.1158/1078-0432.CCR-06-2765 "Desmedt, C., Piette, F., Loi et al.: Strong Time Dependence of the 76-Gene Prognostic Signature for Node-Negative Breast Cancer Patients in the TRANSBIG Multicenter Independent Validation Series. Clin. Cancer Res. 13(11), 3207–14 (2007)" 18 | 19 | [Dispenzieri2012]: https://doi.org/10.1016/j.mayocp.2012.03.009 "Dispenzieri, A., Katzmann, J., Kyle, R., Larson, D., Therneau, T., Colby, C., Clark, R., Mead, G., Kumar, S., Melton III, LJ. and Rajkumar, SV. Use of monclonal serum immunoglobulin free light chains to predict overall survival in the general population, Mayo Clinic Proceedings 87:512-523. (2012)" 20 | 21 | [Hosmer2008]: http://www.wiley.com/WileyCDA/WileyTitle/productCd-0471754994.html "Hosmer, D., Lemeshow, S., May, S.: Applied Survival Analysis: Regression Modeling of Time to Event Data. 
John Wiley & Sons, Inc. (2008)" 22 | 23 | [Kalbfleisch2008]: http://www.wiley.com/WileyCDA/WileyTitle/productCd-047136357X.html "Kalbfleisch, J.D., Prentice, R.L.: The Statistical Analysis of Failure Time Data. John Wiley & Sons, Inc. (2002)" 24 | 25 | [Schumacher1994]: http://ascopubs.org/doi/abs/10.1200/jco.1994.12.10.2086 "Schumacher, M., Basert, G., Bojar, H., et al. Randomized 2 × 2 trial evaluating hormonal treatment and the duration of chemotherapy in node-positive breast cancer patients. Journal of Clinical Oncology 12, 2086–2093. (1994)" 26 | 27 | [Scrucca2007]: https://doi.org/10.1038/sj.bmt.1705727 "Scrucca, L., Santucci, A. & Aversa, F. Competing risk analysis using R: an easy guide for clinicians. Bone Marrow Transplant 40, 381–387 (2007)" 28 | 29 | [Pintilie2006]: https://www.wiley.com/en-us/Competing+Risks%3A+A+Practical+Perspective-p-9780470870693 "Melania Pintilie: Competing Risks: A Practical Perspective. John Wiley & Sons, (2006)" 30 | -------------------------------------------------------------------------------- /sksurv/datasets/data/bmt.arff: -------------------------------------------------------------------------------- 1 | % Scrucca L., Santucci A., Aversa F. (2007) 2 | % Competing risks analysis using R: an easy guide for clinicians. 3 | % Bone Marrow Transplantation 40, 381-387. 
4 | % https://luca-scr.github.io/R/bmt.csv 5 | @RELATION BMT 6 | 7 | @ATTRIBUTE dis {0,1} 8 | @ATTRIBUTE ftime NUMERIC 9 | @ATTRIBUTE status {0,1,2} 10 | 11 | @DATA 12 | 0,13,2 13 | 0,1,1 14 | 0,72,0 15 | 0,7,2 16 | 0,8,2 17 | 1,67,0 18 | 0,9,2 19 | 0,5,2 20 | 1,70,0 21 | 1,4,0 22 | 1,7,0 23 | 1,68,0 24 | 0,1,2 25 | 1,10,2 26 | 1,7,2 27 | 1,3,1 28 | 1,4,1 29 | 1,4,1 30 | 1,3,1 31 | 1,3,1 32 | 0,22,2 33 | 1,8,1 34 | 1,2,2 35 | 0,0,2 36 | 0,0,1 37 | 0,35,0 38 | 1,35,0 39 | 0,4,2 40 | 0,14,2 41 | 0,26,2 42 | 0,3,2 43 | 1,2,0 44 | 1,8,0 45 | 1,32,0 46 | 0,12,1 47 | -------------------------------------------------------------------------------- /sksurv/ensemble/__init__.py: -------------------------------------------------------------------------------- 1 | from .boosting import ComponentwiseGradientBoostingSurvivalAnalysis, GradientBoostingSurvivalAnalysis # noqa: F401 2 | from .forest import ExtraSurvivalTrees, RandomSurvivalForest # noqa: F401 3 | -------------------------------------------------------------------------------- /sksurv/ensemble/_coxph_loss.pyx: -------------------------------------------------------------------------------- 1 | # This program is free software: you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # This program is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 9 | # GNU General Public License for more details. 10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with this program. If not, see . 
cimport cython
from libc cimport math

import numpy as np
cimport numpy as cnp

cnp.import_array()


@cython.wraparound(False)
@cython.cdivision(True)
@cython.boundscheck(False)
def coxph_negative_gradient(cnp.npy_uint8[:] event,
                            cnp.npy_double[:] time,
                            cnp.npy_double[:] y_pred):
    """Negative gradient of the Cox partial likelihood w.r.t. the predictions.

    For sample i the gradient is
    ``event[i] - sum_j event[j] * exp(y_pred[i]) / R[j]`` where the sum runs
    over all j with ``time[i] >= time[j]`` and ``R[j]`` is the sum of
    ``exp(y_pred[k])`` over all k still at risk at ``time[j]``.

    Parameters
    ----------
    event : uint8 array, shape = (n_samples,)
        Event indicator (1 = event occurred, 0 = censored).
    time : double array, shape = (n_samples,)
        Observed times.
    y_pred : double array, shape = (n_samples,)
        Current value of the linear predictor.

    Returns
    -------
    gradient : ndarray, shape = (n_samples,), dtype = float64
    """
    cdef cnp.npy_double risk_sum
    cdef int i
    cdef int j
    cdef cnp.npy_intp n_samples = event.shape[0]

    cdef cnp.ndarray[cnp.npy_double, ndim=1] gradient = cnp.PyArray_EMPTY(1, &n_samples, cnp.NPY_DOUBLE, 0)
    # exp_at_risk[i]: denominator of the risk set at time[i]
    cdef cnp.npy_double[:] exp_at_risk = cnp.PyArray_ZEROS(1, &n_samples, cnp.NPY_DOUBLE, 0)

    cdef cnp.npy_double[:] exp_pred = np.exp(y_pred)
    with nogil:
        # accumulate exp(y_pred[j]) over every sample j still at risk at time[i]
        for i in range(n_samples):
            for j in range(n_samples):
                if time[j] >= time[i]:
                    exp_at_risk[i] += exp_pred[j]

        for i in range(n_samples):
            risk_sum = 0
            for j in range(n_samples):
                if event[j] and time[i] >= time[j]:
                    risk_sum += exp_pred[i] / exp_at_risk[j]
            gradient[i] = event[i] - risk_sum

    return gradient


@cython.wraparound(False)
@cython.cdivision(True)
@cython.boundscheck(False)
def coxph_loss(cnp.npy_uint8[:] event,
               cnp.npy_double[:] time,
               cnp.npy_double[:] y_pred):
    """Negative Cox partial log-likelihood of the predictions.

    Computes ``- sum_i event[i] * (y_pred[i] - log(sum_{j at risk} exp(y_pred[j])))``,
    where sample j is at risk at time[i] iff ``time[j] >= time[i]``.

    Parameters
    ----------
    event : uint8 array, shape = (n_samples,)
        Event indicator (1 = event occurred, 0 = censored).
    time : double array, shape = (n_samples,)
        Observed times.
    y_pred : double array, shape = (n_samples,)
        Current value of the linear predictor.

    Returns
    -------
    loss : float
        Negative partial log-likelihood (lower is better).
    """
    cdef cnp.npy_intp n_samples = event.shape[0]
    cdef cnp.npy_double at_risk
    cdef cnp.npy_double loss = 0
    cdef int i
    cdef int j

    with nogil:
        for i in range(n_samples):
            at_risk = 0
            for j in range(n_samples):
                if time[j] >= time[i]:
                    at_risk += math.exp(y_pred[j])
            loss += event[i] * (y_pred[i] - math.log(at_risk))

    return -loss
--------------------------------------------------------------------------------
/sksurv/exceptions.py:
--------------------------------------------------------------------------------
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the
Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # This program is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 9 | # GNU General Public License for more details. 10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with this program. If not, see . 13 | 14 | 15 | class NoComparablePairException(ValueError): 16 | """An error indicating that data of censored event times 17 | does not contain one or more comparable pairs. 18 | """ 19 | -------------------------------------------------------------------------------- /sksurv/functions.py: -------------------------------------------------------------------------------- 1 | # This program is free software: you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # This program is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 9 | # GNU General Public License for more details. 10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with this program. If not, see . 13 | 14 | import numpy as np 15 | from sklearn.utils.validation import check_consistent_length 16 | 17 | __all__ = ["StepFunction"] 18 | 19 | 20 | class StepFunction: 21 | """Callable step function. 22 | 23 | .. math:: 24 | 25 | f(z) = a * y_i + b, 26 | x_i \\leq z < x_{i + 1} 27 | 28 | Parameters 29 | ---------- 30 | x : ndarray, shape = (n_points,) 31 | Values on the x axis in ascending order. 
32 | 33 | y : ndarray, shape = (n_points,) 34 | Corresponding values on the y axis. 35 | 36 | a : float, optional, default: 1.0 37 | Constant to multiply by. 38 | 39 | b : float, optional, default: 0.0 40 | Constant offset term. 41 | 42 | domain : tuple, optional 43 | A tuple with two entries that sets the limits of the 44 | domain of the step function. 45 | If entry is `None`, use the first/last value of `x` as limit. 46 | """ 47 | 48 | def __init__(self, x, y, *, a=1.0, b=0.0, domain=(0, None)): 49 | check_consistent_length(x, y) 50 | self.x = x 51 | self.y = y 52 | self.a = a 53 | self.b = b 54 | domain_lower = self.x[0] if domain[0] is None else domain[0] 55 | domain_upper = self.x[-1] if domain[1] is None else domain[1] 56 | self._domain = (float(domain_lower), float(domain_upper)) 57 | 58 | @property 59 | def domain(self): 60 | """Returns the domain of the function, that means 61 | the range of values that the function accepts. 62 | 63 | Returns 64 | ------- 65 | lower_limit : float 66 | Lower limit of domain. 67 | 68 | upper_limit : float 69 | Upper limit of domain. 70 | """ 71 | return self._domain 72 | 73 | def __call__(self, x): 74 | """Evaluate step function. 75 | 76 | Values outside the interval specified by `self.domain` 77 | will raise an exception. 78 | Values in `x` that are in the interval `[self.domain[0]; self.x[0]]` 79 | get mapped to `self.y[0]`. 80 | 81 | Parameters 82 | ---------- 83 | x : float|array-like, shape=(n_values,) 84 | Values to evaluate step function at. 85 | 86 | Returns 87 | ------- 88 | y : float|array-like, shape=(n_values,) 89 | Values of step function at `x`. 
90 | """ 91 | x = np.atleast_1d(x) 92 | if not np.isfinite(x).all(): 93 | raise ValueError("x must be finite") 94 | if np.min(x) < self._domain[0] or np.max(x) > self.domain[1]: 95 | raise ValueError(f"x must be within [{self.domain[0]:f}; {self.domain[1]:f}]") 96 | 97 | # x is within the domain, but we need to account for self.domain[0] <= x < self.x[0] 98 | x = np.clip(x, a_min=self.x[0], a_max=None) 99 | 100 | i = np.searchsorted(self.x, x, side="left") 101 | not_exact = self.x[i] != x 102 | i[not_exact] -= 1 103 | value = self.a * self.y[i] + self.b 104 | if value.shape[0] == 1: 105 | return value[0] 106 | return value 107 | 108 | def __repr__(self): 109 | return f"StepFunction(x={self.x!r}, y={self.y!r}, a={self.a!r}, b={self.b!r})" 110 | 111 | def __eq__(self, other): 112 | if isinstance(other, type(self)): 113 | return all(self.x == other.x) and all(self.y == other.y) and self.a == other.a and self.b == other.b 114 | return False 115 | -------------------------------------------------------------------------------- /sksurv/io/__init__.py: -------------------------------------------------------------------------------- 1 | from .arffread import loadarff # noqa: F401 2 | from .arffwrite import writearff # noqa: F401 3 | -------------------------------------------------------------------------------- /sksurv/io/arffread.py: -------------------------------------------------------------------------------- 1 | # This program is free software: you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # This program is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 9 | # GNU General Public License for more details. 
10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with this program. If not, see . 13 | import numpy as np 14 | import pandas as pd 15 | from scipy.io.arff import loadarff as scipy_loadarff 16 | 17 | __all__ = ["loadarff"] 18 | 19 | 20 | def _to_pandas(data, meta): 21 | data_dict = {} 22 | attrnames = sorted(meta.names()) 23 | for name in attrnames: 24 | tp, attr_format = meta[name] 25 | if tp == "nominal": 26 | raw = [] 27 | for b in data[name]: 28 | # replace missing values with NaN 29 | if b == b"?": 30 | raw.append(np.nan) 31 | else: 32 | raw.append(b.decode()) 33 | 34 | data_dict[name] = pd.Categorical(raw, categories=attr_format, ordered=False) 35 | else: 36 | arr = data[name] 37 | p = pd.Series(arr, dtype=arr.dtype) 38 | data_dict[name] = p 39 | 40 | # currently, this step converts all pandas.Categorial columns back to pandas.Series 41 | return pd.DataFrame.from_dict(data_dict) 42 | 43 | 44 | def loadarff(filename): 45 | """Load ARFF file 46 | 47 | Parameters 48 | ---------- 49 | filename : string 50 | Path to ARFF file 51 | 52 | Returns 53 | ------- 54 | data_frame : :class:`pandas.DataFrame` 55 | DataFrame containing data of ARFF file 56 | """ 57 | data, meta = scipy_loadarff(filename) 58 | return _to_pandas(data, meta) 59 | -------------------------------------------------------------------------------- /sksurv/io/arffwrite.py: -------------------------------------------------------------------------------- 1 | # This program is free software: you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # This program is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 9 | # GNU General Public License for more details. 10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with this program. If not, see . 13 | import os.path 14 | import re 15 | 16 | import numpy as np 17 | import pandas as pd 18 | from pandas.api.types import CategoricalDtype, is_object_dtype 19 | 20 | _ILLEGAL_CHARACTER_PAT = re.compile(r"[^-_=\w\d\(\)<>\.]") 21 | 22 | 23 | def writearff(data, filename, relation_name=None, index=True): 24 | """Write ARFF file 25 | 26 | Parameters 27 | ---------- 28 | data : :class:`pandas.DataFrame` 29 | DataFrame containing data 30 | 31 | filename : string or file-like object 32 | Path to ARFF file or file-like object. In the latter case, 33 | the handle is closed by calling this function. 34 | 35 | relation_name : string, optional, default: "pandas" 36 | Name of relation in ARFF file. 37 | 38 | index : boolean, optional, default: True 39 | Write row names (index) 40 | """ 41 | if isinstance(filename, str): 42 | fp = open(filename, "w") 43 | 44 | if relation_name is None: 45 | relation_name = os.path.basename(filename) 46 | else: 47 | fp = filename 48 | 49 | if relation_name is None: 50 | relation_name = "pandas" 51 | 52 | try: 53 | data = _write_header(data, fp, relation_name, index) 54 | fp.write("\n") 55 | _write_data(data, fp) 56 | finally: 57 | fp.close() 58 | 59 | 60 | def _write_header(data, fp, relation_name, index): 61 | """Write header containing attribute names and types""" 62 | fp.write(f"@relation {relation_name}\n\n") 63 | 64 | if index: 65 | data = data.reset_index() 66 | 67 | attribute_names = _sanitize_column_names(data) 68 | 69 | for column, series in data.items(): 70 | name = attribute_names[column] 71 | fp.write(f"@attribute {name}\t") 72 | 73 | if isinstance(series.dtype, CategoricalDtype) or is_object_dtype(series): 74 | _write_attribute_categorical(series, fp) 75 | elif np.issubdtype(series.dtype, np.floating): 76 | fp.write("real") 77 | elif 
np.issubdtype(series.dtype, np.integer): 78 | fp.write("integer") 79 | elif np.issubdtype(series.dtype, np.datetime64): 80 | fp.write("date 'yyyy-MM-dd HH:mm:ss'") 81 | else: 82 | raise TypeError(f"unsupported type {series.dtype}") 83 | 84 | fp.write("\n") 85 | return data 86 | 87 | 88 | def _sanitize_column_names(data): 89 | """Replace illegal characters with underscore""" 90 | new_names = {} 91 | for name in data.columns: 92 | new_names[name] = _ILLEGAL_CHARACTER_PAT.sub("_", name) 93 | return new_names 94 | 95 | 96 | def _check_str_value(x): 97 | """If string has a space, wrap it in double quotes and remove/escape illegal characters""" 98 | if isinstance(x, str): 99 | # remove commas, and single quotation marks since loadarff cannot deal with it 100 | x = x.replace(",", ".").replace(chr(0x2018), "'").replace(chr(0x2019), "'") 101 | 102 | # put string in double quotes 103 | if " " in x: 104 | if x[0] in ('"', "'"): 105 | x = x[1:] 106 | if x[-1] in ('"', "'"): 107 | x = x[: len(x) - 1] 108 | x = '"' + x.replace('"', '\\"') + '"' 109 | return str(x) 110 | 111 | 112 | _check_str_array = np.frompyfunc(_check_str_value, 1, 1) 113 | 114 | 115 | def _write_attribute_categorical(series, fp): 116 | """Write categories of a categorical/nominal attribute""" 117 | if isinstance(series.dtype, CategoricalDtype): 118 | categories = series.cat.categories 119 | string_values = _check_str_array(categories) 120 | else: 121 | categories = series.dropna().unique() 122 | string_values = sorted(_check_str_array(categories), key=lambda x: x.strip('"')) 123 | 124 | values = ",".join(string_values) 125 | fp.write("{") 126 | fp.write(values) 127 | fp.write("}") 128 | 129 | 130 | def _write_data(data, fp): 131 | """Write the data section""" 132 | fp.write("@data\n") 133 | 134 | def to_str(x): 135 | if pd.isnull(x): 136 | return "?" 
137 | return str(x) 138 | 139 | data = data.applymap(to_str) 140 | n_rows = data.shape[0] 141 | for i in range(n_rows): 142 | str_values = list(data.iloc[i, :].apply(_check_str_array)) 143 | line = ",".join(str_values) 144 | fp.write(line) 145 | fp.write("\n") 146 | -------------------------------------------------------------------------------- /sksurv/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | from .clinical import ClinicalKernelTransform, clinical_kernel # noqa: F401 2 | -------------------------------------------------------------------------------- /sksurv/kernels/_clinical_kernel.pyx: -------------------------------------------------------------------------------- 1 | # This program is free software: you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # This program is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 9 | # GNU General Public License for more details. 10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with this program. If not, see . 
cimport cython
cimport numpy as cnp
from libc cimport math

cnp.import_array()


@cython.wraparound(False)
@cython.cdivision(True)
@cython.boundscheck(False)
cdef void _get_min_and_max(cnp.npy_double[:] x, cnp.npy_double * min_out, cnp.npy_double * max_out) noexcept nogil:
    # Single pass over x, writing its minimum and maximum through the
    # out-pointers. Assumes x is non-empty (x[0] is read unconditionally).
    cdef cnp.npy_double amin = x[0]
    cdef cnp.npy_double amax = x[0]
    cdef int i

    for i in range(x.shape[0]):
        if x[i] < amin:
            amin = x[i]
        if x[i] > amax:
            amax = x[i]

    min_out[0] = amin
    max_out[0] = amax


@cython.wraparound(False)
@cython.cdivision(True)
@cython.boundscheck(False)
def continuous_ordinal_kernel_with_ranges(cnp.npy_double[:, :] x,
                                          cnp.npy_double[:, :] y,
                                          cnp.npy_double[:] ranges,
                                          cnp.npy_double[:, :] out):
    """Clinical kernel for continuous/ordinal features with precomputed ranges.

    For every pair (i, j), ``(ranges[k] - |x[i, k] - y[j, k]|) / ranges[k]``
    summed over all features k is *added* to ``out[i, j]``; ``out`` is
    modified in place and returned.

    Parameters
    ----------
    x : double array, shape = (n_samples_x, n_features)
    y : double array, shape = (n_samples_y, n_features)
    ranges : double array, shape = (n_features,)
        Value range (max - min) of each feature.
    out : double array, shape = (n_samples_x, n_samples_y)
        Matrix the kernel values are accumulated into.

    Raises
    ------
    ValueError
        If ``out`` does not have shape (n_samples_x, n_samples_y).
    """
    cdef cnp.npy_intp n_samples_x = x.shape[0]
    cdef cnp.npy_intp n_samples_y = y.shape[0]
    cdef cnp.npy_intp n_features = x.shape[1]
    cdef int i, j, k

    if out.shape[0] != n_samples_x or out.shape[1] != n_samples_y:
        raise ValueError("out matrix must be of shape (%d, %d)" % out.shape)

    with nogil:
        for i in range(n_samples_x):
            for j in range(n_samples_y):
                for k in range(n_features):
                    out[i, j] += (ranges[k] - math.fabs(x[i, k] - y[j, k])) / ranges[k]

    return out


@cython.wraparound(False)
@cython.cdivision(True)
@cython.boundscheck(False)
def continuous_ordinal_kernel(cnp.npy_double[:, :] x,
                              cnp.npy_double[:, :] y,
                              cnp.npy_double[:, :] out):
    """Clinical kernel for continuous/ordinal features.

    Computes the per-feature value range over the combined values of ``x``
    and ``y``, then delegates to
    :func:`continuous_ordinal_kernel_with_ranges`.

    Parameters
    ----------
    x : double array, shape = (n_samples_x, n_features)
    y : double array, shape = (n_samples_y, n_features)
    out : double array, shape = (n_samples_x, n_samples_y)
        Matrix the kernel values are accumulated into; returned.
    """
    cdef cnp.npy_intp n_features = x.shape[1]
    cdef cnp.npy_double min_x, max_x, min_y, max_y
    cdef int i

    # BUG FIX: the ranges buffer holds one entry per *feature*. It was
    # previously allocated with n_samples_x entries, so ranges[i] wrote out
    # of bounds whenever x had more features than samples.
    cdef cnp.npy_double[:] ranges = cnp.PyArray_EMPTY(1, &n_features, cnp.NPY_DOUBLE, 0)
    with nogil:
        for i in range(n_features):
            _get_min_and_max(x[:, i], &min_x, &max_x)
            _get_min_and_max(y[:, i], &min_y, &max_y)
            ranges[i] = max(max_x, max_y) - min(min_x, min_y)

    return continuous_ordinal_kernel_with_ranges(x, y, ranges, out)


@cython.wraparound(False)
@cython.cdivision(True)
@cython.boundscheck(False)
def pairwise_continuous_ordinal_kernel(cnp.npy_double[:] x,
                                       cnp.npy_double[:] y,
                                       cnp.npy_double[:] ranges):
    """Kernel value for a single pair of continuous/ordinal feature vectors.

    Returns ``sum_k (ranges[k] - |x[k] - y[k]|) / ranges[k]``.
    """
    cdef cnp.npy_double out = 0
    cdef int k

    with nogil:
        for k in range(x.shape[0]):
            out += (ranges[k] - math.fabs(x[k] - y[k])) / ranges[k]

    return out


@cython.wraparound(False)
@cython.cdivision(True)
@cython.boundscheck(False)
def pairwise_nominal_kernel(cnp.npy_int8[:] x,
                            cnp.npy_int8[:] y):
    """Number of positions at which two nominal feature vectors agree."""
    cdef cnp.npy_double out = 0
    cdef int k

    with nogil:
        for k in range(x.shape[0]):
            if x[k] == y[k]:
                out += 1.

    return out
--------------------------------------------------------------------------------
/sksurv/linear_model/__init__.py:
--------------------------------------------------------------------------------
from .aft import IPCRidge  # noqa: F401
from .coxnet import CoxnetSurvivalAnalysis  # noqa: F401
from .coxph import CoxPHSurvivalAnalysis  # noqa: F401
--------------------------------------------------------------------------------
/sksurv/linear_model/_coxnet.pyx:
--------------------------------------------------------------------------------
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with this program. If not, see . 13 | cimport numpy as cnp 14 | from libcpp cimport bool 15 | 16 | cnp.import_array() 17 | 18 | 19 | cdef extern from "coxnet_wrapper.h": 20 | cdef object fit_coxnet[T, S, U] (cnp.ndarray, cnp.ndarray, cnp.ndarray, cnp.ndarray, 21 | cnp.ndarray, bool, cnp.npy_float64, cnp.npy_float64, int, double, bool) except + 22 | 23 | 24 | cdef extern from "coxnet_wrapper.h" namespace "Eigen": 25 | 26 | cdef cppclass Dynamic: 27 | pass 28 | 29 | cdef cppclass RowMajor: 30 | pass 31 | 32 | cdef cppclass ColMajor: 33 | pass 34 | 35 | cdef cppclass Aligned: 36 | pass 37 | 38 | cdef cppclass Unaligned: 39 | pass 40 | 41 | cdef cppclass PlainObjectBase: 42 | pass 43 | 44 | cdef cppclass Matrix(PlainObjectBase): 45 | pass 46 | 47 | cdef cppclass MatrixXd(PlainObjectBase): 48 | pass 49 | 50 | cdef cppclass VectorXd(PlainObjectBase): 51 | pass 52 | 53 | cdef cppclass VectorXuint8(PlainObjectBase): 54 | pass 55 | 56 | 57 | def call_fit_coxnet(cnp.ndarray[cnp.npy_float64, ndim=2, mode='fortran'] X, 58 | cnp.ndarray[cnp.npy_float64, ndim=1] time, 59 | cnp.ndarray[cnp.npy_uint8, ndim=1] event, 60 | cnp.ndarray[cnp.npy_float64, ndim=1] penalty_factor, 61 | cnp.ndarray[cnp.npy_float64, ndim=1] alphas, 62 | bool create_path, 63 | cnp.npy_float64 alpha_min_ratio, 64 | cnp.npy_float64 l1_ratio, 65 | int max_iter, 66 | cnp.npy_float64 eps, 67 | bool verbose): 68 | cdef object result = fit_coxnet[MatrixXd, VectorXd, VectorXuint8] ( 69 | X, time, event, penalty_factor, alphas, create_path, 70 | alpha_min_ratio, l1_ratio, max_iter, eps, verbose) 71 | return result 72 | -------------------------------------------------------------------------------- /sksurv/linear_model/src/coxnet/constants.h: -------------------------------------------------------------------------------- 1 | /** 2 | * This program is free software: you can redistribute it and/or modify 3 | * it under the 
terms of the GNU General Public License as published by 4 | * the Free Software Foundation, either version 3 of the License, or 5 | * (at your option) any later version. 6 | * 7 | * This program is distributed in the hope that it will be useful, 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 | * GNU General Public License for more details. 11 | * 12 | * You should have received a copy of the GNU General Public License 13 | * along with this program. If not, see . 14 | */ 15 | #ifndef GLMNET_CONSTANTS_H 16 | #define GLMNET_CONSTANTS_H 17 | 18 | #include 19 | 20 | #define LOG_99999 11.512915464920228086754123920088321036576923514461324197617865972730311703692501039693619281678632872501544746697 21 | 22 | #if defined(_MSC_VER) 23 | #define COXNET_CONSTEXPR 24 | #else 25 | #define COXNET_CONSTEXPR constexpr 26 | #endif 27 | 28 | 29 | namespace coxnet { 30 | 31 | template 32 | struct Constants { 33 | static COXNET_CONSTEXPR Scalar BIG() { return Scalar(1e35); } 34 | static COXNET_CONSTEXPR Scalar WEIGHTS_SUM_MIN() { return Scalar((1.0+1.0E-5)*1.0E-5*(1.0-1.0E-5)); } 35 | static COXNET_CONSTEXPR Scalar PMAX() { return Scalar(LOG_99999); } 36 | static COXNET_CONSTEXPR Scalar PMIN() { return Scalar(-LOG_99999); } 37 | static COXNET_CONSTEXPR int MIN_ALPHAS() { return 5; } 38 | }; 39 | 40 | }; 41 | 42 | #endif //GLMNET_CONSTANTS_H 43 | -------------------------------------------------------------------------------- /sksurv/linear_model/src/coxnet/data.h: -------------------------------------------------------------------------------- 1 | /** 2 | * This program is free software: you can redistribute it and/or modify 3 | * it under the terms of the GNU General Public License as published by 4 | * the Free Software Foundation, either version 3 of the License, or 5 | * (at your option) any later version. 
6 | * 7 | * This program is distributed in the hope that it will be useful, 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 | * GNU General Public License for more details. 11 | * 12 | * You should have received a copy of the GNU General Public License 13 | * along with this program. If not, see . 14 | */ 15 | #ifndef GLMNET_DATA_H 16 | #define GLMNET_DATA_H 17 | 18 | #include 19 | #include 20 | 21 | namespace coxnet { 22 | 23 | enum { 24 | FLOATING_POINT_ARGUMENT_PASSED__INTEGER_WAS_EXPECTED=1 25 | }; 26 | 27 | template < 28 | typename DerivedMatrix, 29 | typename DerivedFloatVector, 30 | typename DerivedIntVector > 31 | class Data 32 | { 33 | public: 34 | typedef typename DerivedMatrix::Index Index; 35 | typedef Eigen::MatrixBase Matrix; 36 | typedef Eigen::MatrixBase FloatVector; 37 | typedef Eigen::MatrixBase IntVector; 38 | 39 | EIGEN_STATIC_ASSERT_VECTOR_ONLY(DerivedFloatVector); 40 | EIGEN_STATIC_ASSERT_VECTOR_ONLY(DerivedIntVector); 41 | EIGEN_STATIC_ASSERT(Eigen::NumTraits::IsInteger, 42 | FLOATING_POINT_ARGUMENT_PASSED__INTEGER_WAS_EXPECTED); 43 | 44 | Data(const Matrix &x, 45 | const FloatVector &time, 46 | const IntVector &event, 47 | const FloatVector &penalty_factor) : m_x(x), m_time(time), m_event(event), 48 | m_penalty_factor(penalty_factor), 49 | m_samples(x.rows()), 50 | m_features(x.cols()) 51 | { 52 | eigen_assert (time.size() == x.rows()); 53 | eigen_assert (event.size() == x.rows()); 54 | eigen_assert (penalty_factor.size() == x.cols()); 55 | eigen_assert ((event.array() >= 0).all()); 56 | eigen_assert ((event.array() <= 1).all()); 57 | } 58 | 59 | const DerivedMatrix& x() const { return m_x.derived(); } 60 | const DerivedFloatVector& time() const { return m_time.derived(); } 61 | const DerivedIntVector& event() const { return m_event.derived(); } 62 | const DerivedFloatVector& penalty_factor() const { return m_penalty_factor.derived(); } 63 | const 
Index& n_samples() const { return m_samples; } 64 | const Index& n_features() const { return m_features; } 65 | 66 | template 67 | friend std::ostream& operator<< (std::ostream& os, const Data<_M, _V, _I> &obj); 68 | 69 | private: 70 | const Matrix &m_x; 71 | const FloatVector &m_time; 72 | const IntVector &m_event; 73 | const FloatVector &m_penalty_factor; 74 | const Index m_samples; 75 | const Index m_features; 76 | }; 77 | 78 | template 79 | std::ostream& operator<< (std::ostream& os, const Data<_M, _V, _I> &obj) { 80 | os << "Data(x=" << obj.m_x.size() << ", " 81 | << "time=" << obj.m_time.size() << ", " 82 | << "event=" << obj.m_event.size() << ", " 83 | << "penalty_factor=" << obj.m_penalty_factor.size() 84 | << ")"; 85 | return os; 86 | } 87 | 88 | } 89 | 90 | #endif 91 | -------------------------------------------------------------------------------- /sksurv/linear_model/src/coxnet/error.h: -------------------------------------------------------------------------------- 1 | /** 2 | * This program is free software: you can redistribute it and/or modify 3 | * it under the terms of the GNU General Public License as published by 4 | * the Free Software Foundation, either version 3 of the License, or 5 | * (at your option) any later version. 6 | * 7 | * This program is distributed in the hope that it will be useful, 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 | * GNU General Public License for more details. 11 | * 12 | * You should have received a copy of the GNU General Public License 13 | * along with this program. If not, see . 
14 | */ 15 | #ifndef GLMNET_ERROR_H 16 | #define GLMNET_ERROR_H 17 | 18 | enum ErrorType { 19 | NONE, 20 | WEIGHT_TOO_LARGE, // exponential of weight is infinite 21 | }; 22 | 23 | #endif //GLMNET_ERROR_H 24 | -------------------------------------------------------------------------------- /sksurv/linear_model/src/coxnet/fit_params.h: -------------------------------------------------------------------------------- 1 | /** 2 | * This program is free software: you can redistribute it and/or modify 3 | * it under the terms of the GNU General Public License as published by 4 | * the Free Software Foundation, either version 3 of the License, or 5 | * (at your option) any later version. 6 | * 7 | * This program is distributed in the hope that it will be useful, 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 | * GNU General Public License for more details. 11 | * 12 | * You should have received a copy of the GNU General Public License 13 | * along with this program. If not, see . 
14 | */ 15 | #ifndef GLMNET_FIT_PARAMS_H 16 | #define GLMNET_FIT_PARAMS_H 17 | 18 | #include 19 | #include 20 | 21 | #include "error.h" 22 | #include "ordered_dict.h" 23 | 24 | 25 | namespace coxnet { 26 | 27 | template 28 | struct FitParams { 29 | typedef typename VectorType::Index Index; 30 | typedef typename VectorType::Scalar Scalar; 31 | 32 | FitParams(Index n_samples, 33 | Index n_features, 34 | double _eps) : coef_x(n_features), 35 | residuals(n_samples), 36 | weights(n_samples), 37 | risk_set(n_samples), 38 | xw(n_samples), 39 | eps(_eps), 40 | maybe_active_set(n_features), 41 | n_iterations(0), 42 | error_type(NONE) 43 | { 44 | } 45 | 46 | void init(); 47 | bool has_error() const { return error_type != NONE; } 48 | 49 | VectorType coef_x; 50 | VectorType residuals; 51 | VectorType weights; 52 | VectorType risk_set; 53 | VectorType xw; 54 | double eps; 55 | 56 | Eigen::Array maybe_active_set; 57 | ordered_dict active_set; 58 | std::size_t n_iterations; 59 | ErrorType error_type; 60 | }; 61 | 62 | 63 | template 64 | void FitParams::init() { 65 | maybe_active_set.setZero(); 66 | coef_x.setZero(); 67 | xw.setZero(); 68 | } 69 | 70 | }; 71 | 72 | #endif //GLMNET_FIT_PARAMS_H 73 | -------------------------------------------------------------------------------- /sksurv/linear_model/src/coxnet/fit_result.h: -------------------------------------------------------------------------------- 1 | /** 2 | * This program is free software: you can redistribute it and/or modify 3 | * it under the terms of the GNU General Public License as published by 4 | * the Free Software Foundation, either version 3 of the License, or 5 | * (at your option) any later version. 6 | * 7 | * This program is distributed in the hope that it will be useful, 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 | * GNU General Public License for more details. 
11 | * 12 | * You should have received a copy of the GNU General Public License 13 | * along with this program. If not, see . 14 | */ 15 | #ifndef GLMNET_FIT_RESULT_H 16 | #define GLMNET_FIT_RESULT_H 17 | 18 | #include 19 | #include "error.h" 20 | 21 | 22 | namespace coxnet { 23 | 24 | template 25 | class FitResult { 26 | public: 27 | typedef T MatrixType; 28 | typedef S VectorType; 29 | 30 | FitResult(MatrixType &coef, 31 | VectorType &alphas, 32 | VectorType &deviance_ratio) : m_coef_path(coef), 33 | m_alphas(alphas), 34 | m_deviance_ratio(deviance_ratio), 35 | m_iterations(0), 36 | m_n_alphas(0), 37 | m_error(NONE) 38 | {} 39 | 40 | const MatrixType& getCoefficientPath() const { 41 | return m_coef_path; 42 | } 43 | MatrixType& getCoefficientPath() { 44 | return m_coef_path; 45 | } 46 | 47 | const VectorType& getAlphas() const { 48 | return m_alphas; 49 | } 50 | VectorType& getAlphas() { 51 | return m_alphas; 52 | } 53 | 54 | const VectorType& getDevianceRatio() const { 55 | return m_deviance_ratio; 56 | } 57 | VectorType& getDevianceRatio() { 58 | return m_deviance_ratio; 59 | } 60 | 61 | std::size_t getNumberOfIterations() const { 62 | return m_iterations; 63 | } 64 | void setNumberOfIterations(const std::size_t value) { 65 | m_iterations = value; 66 | } 67 | 68 | typename VectorType::Index getNumberOfAlphas() const { 69 | return m_n_alphas; 70 | } 71 | void setNumberOfAlphas(const typename VectorType::Index value) { 72 | m_n_alphas = value; 73 | } 74 | 75 | ErrorType getError() const { 76 | return m_error; 77 | } 78 | void setError(const ErrorType error_type) { 79 | m_error = error_type; 80 | } 81 | 82 | private: 83 | MatrixType &m_coef_path; 84 | VectorType &m_alphas; 85 | VectorType &m_deviance_ratio; 86 | std::size_t m_iterations; 87 | typename VectorType::Index m_n_alphas; 88 | ErrorType m_error; 89 | 90 | // intentionally not implemented 91 | FitResult (const FitResult&); 92 | FitResult& operator=(const FitResult&); 93 | }; 94 | 95 | } 96 | 97 | #endif 
//GLMNET_FIT_RESULT_H 98 | -------------------------------------------------------------------------------- /sksurv/linear_model/src/coxnet/ordered_dict.h: -------------------------------------------------------------------------------- 1 | /** 2 | * This program is free software: you can redistribute it and/or modify 3 | * it under the terms of the GNU General Public License as published by 4 | * the Free Software Foundation, either version 3 of the License, or 5 | * (at your option) any later version. 6 | * 7 | * This program is distributed in the hope that it will be useful, 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 | * GNU General Public License for more details. 11 | * 12 | * You should have received a copy of the GNU General Public License 13 | * along with this program. If not, see . 14 | */ 15 | #ifndef GLMNET_ORDERED_DICT_H 16 | #define GLMNET_ORDERED_DICT_H 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | 23 | namespace coxnet { 24 | 25 | template 26 | struct __link { 27 | typedef Key key_type; 28 | typedef __link link_type; 29 | typedef std::shared_ptr pointer; 30 | 31 | key_type key; 32 | pointer next; 33 | std::weak_ptr prev; 34 | 35 | explicit __link() {} 36 | __link (const Key &_key) : key(_key) {} 37 | __link (const Key &_key, 38 | pointer &_next, 39 | pointer &_prev) : key(_key), next(_next), prev(_prev) {} 40 | }; 41 | 42 | template 43 | class ordered_dict_iterator { 44 | public: 45 | typedef T link_type; 46 | typedef std::shared_ptr link_pointer; 47 | typedef typename link_type::key_type key_type; 48 | typedef ordered_dict_iterator iterator; 49 | 50 | ordered_dict_iterator(const link_pointer &__root) : m_root(__root) {} 51 | 52 | iterator& operator++() { 53 | link_pointer curr = m_root->next; 54 | m_root = curr; 55 | return *this; 56 | } 57 | 58 | const key_type& operator*() const { 59 | return m_root->key; 60 | } 61 | 62 | friend 63 | bool 
operator==(const iterator& __x, const iterator& __y) { 64 | return __x.m_root == __y.m_root; 65 | } 66 | friend 67 | bool operator!=(const iterator& __x, const iterator& __y) { 68 | return !(__x.m_root == __y.m_root); 69 | } 70 | 71 | private: 72 | link_pointer m_root; 73 | }; 74 | 75 | template 76 | class ordered_dict : public std::set { 77 | public: 78 | typedef std::set base; 79 | typedef typename base::key_type key_type; 80 | typedef __link link_type; 81 | typedef std::shared_ptr link_pointer; 82 | typedef ordered_dict_iterator const_iterator; 83 | 84 | explicit ordered_dict() { 85 | m_root = std::make_shared(-1); 86 | m_root->next = m_root; 87 | m_root->prev = m_root; 88 | } 89 | 90 | void insert_ordered( const key_type &key ); 91 | 92 | const_iterator cbegin_ordered() const { 93 | return const_iterator(m_root->next); 94 | } 95 | const_iterator cend_ordered() const { 96 | return const_iterator(m_root); 97 | } 98 | 99 | private: 100 | std::map m_map; 101 | link_pointer m_root; 102 | }; 103 | 104 | 105 | template 106 | void ordered_dict::insert_ordered( const key_type &key ) 107 | { 108 | auto search = this->find(key); 109 | if (search == this->end()) { 110 | link_pointer last(m_root->prev.lock()); 111 | link_pointer new_link = std::make_shared(key, m_root, last); 112 | last->next = new_link; 113 | m_root->prev = new_link; 114 | 115 | m_map.emplace(std::make_pair(key, new_link)); 116 | } 117 | this->insert(key); 118 | }; 119 | 120 | }; 121 | 122 | #endif //GLMNET_ORDERED_DICT_H 123 | -------------------------------------------------------------------------------- /sksurv/linear_model/src/coxnet/parameters.h: -------------------------------------------------------------------------------- 1 | /** 2 | * This program is free software: you can redistribute it and/or modify 3 | * it under the terms of the GNU General Public License as published by 4 | * the Free Software Foundation, either version 3 of the License, or 5 | * (at your option) any later version. 
6 | * 7 | * This program is distributed in the hope that it will be useful, 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 | * GNU General Public License for more details. 11 | * 12 | * You should have received a copy of the GNU General Public License 13 | * along with this program. If not, see . 14 | */ 15 | #ifndef GLMNET_PARAMETERS_H 16 | #define GLMNET_PARAMETERS_H 17 | 18 | #include 19 | 20 | 21 | namespace coxnet { 22 | 23 | class Parameters { 24 | public: 25 | Parameters() : m_alpha_min_ratio(0.01), m_l1_ratio(0.5), m_max_iter(10000), m_eps(1e-7), m_verbose(false) {} 26 | Parameters(const double alpha_min_ratio, 27 | const double l1_ratio, 28 | const std::size_t max_iter, 29 | const double eps, 30 | const bool verbose) : m_alpha_min_ratio(alpha_min_ratio), 31 | m_l1_ratio(l1_ratio), 32 | m_max_iter(max_iter), 33 | m_eps(eps), 34 | m_verbose(verbose) {} 35 | 36 | void set_alpha_min_ratio(const double value) { 37 | m_alpha_min_ratio = value; 38 | } 39 | double get_alpha_min_ratio() const { 40 | return m_alpha_min_ratio; 41 | } 42 | 43 | void set_l1_ratio(const double value) { 44 | m_l1_ratio = value; 45 | } 46 | double get_l1_ratio() const { 47 | return m_l1_ratio; 48 | } 49 | 50 | void set_max_iter(const std::size_t value) { 51 | m_max_iter = value; 52 | } 53 | std::size_t get_max_iter() const { 54 | return m_max_iter; 55 | } 56 | 57 | void set_tolerance(const double value) { 58 | m_eps = value; 59 | } 60 | double get_tolerance() const { 61 | return m_eps; 62 | } 63 | 64 | void set_verbose(const bool value) { 65 | m_verbose = value; 66 | } 67 | bool is_verbose() const { 68 | return m_verbose; 69 | } 70 | 71 | private: 72 | double m_alpha_min_ratio; 73 | double m_l1_ratio; 74 | std::size_t m_max_iter; 75 | double m_eps; 76 | bool m_verbose; 77 | }; 78 | 79 | } 80 | 81 | #endif // GLMNET_PARAMETERS_H 82 | 
-------------------------------------------------------------------------------- /sksurv/linear_model/src/coxnet/soft_threshold.h: -------------------------------------------------------------------------------- 1 | /** 2 | * This program is free software: you can redistribute it and/or modify 3 | * it under the terms of the GNU General Public License as published by 4 | * the Free Software Foundation, either version 3 of the License, or 5 | * (at your option) any later version. 6 | * 7 | * This program is distributed in the hope that it will be useful, 8 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 | * GNU General Public License for more details. 11 | * 12 | * You should have received a copy of the GNU General Public License 13 | * along with this program. If not, see . 14 | */ 15 | #ifndef GLMNET_SOFT_THRESHOLD_H 16 | #define GLMNET_SOFT_THRESHOLD_H 17 | 18 | #include 19 | #include 20 | 21 | 22 | template 23 | bool is_zero (const _T value, 24 | const _T &prec = Eigen::NumTraits<_T>::dummy_precision()) 25 | { 26 | return (std::fabs(value) <= prec); 27 | } 28 | 29 | inline 30 | double fsign(double f) { 31 | double val; 32 | if (is_zero(f)) 33 | val = 0.; 34 | else if (f > 0) 35 | val = 1.; 36 | else //if (f < 0) 37 | val = -1.; 38 | return val; 39 | } 40 | 41 | inline 42 | float fsign(float f) { 43 | float val; 44 | if (is_zero(f)) 45 | val = 0.f; 46 | else if (f > 0) 47 | val = 1.f; 48 | else //if (f < 0) 49 | val = -1.f; 50 | return val; 51 | } 52 | 53 | inline 54 | double soft_threshold(double z, double t) { 55 | double v = std::fabs(z) - t; 56 | if (!is_zero(v) && v > 0) 57 | return fsign(z) * v; 58 | return 0.0; 59 | } 60 | 61 | inline 62 | float soft_threshold(float z, float t) { 63 | float v = std::fabs(z) - t; 64 | if (!is_zero(v) && v > 0) 65 | return fsign(z) * v; 66 | return 0.0f; 67 | } 68 | 69 | #endif //GLMNET_SOFT_THRESHOLD_H 70 | 
-------------------------------------------------------------------------------- /sksurv/meta/__init__.py: -------------------------------------------------------------------------------- 1 | from .ensemble_selection import EnsembleSelection, EnsembleSelectionRegressor, MeanEstimator 2 | from .stacking import Stacking 3 | 4 | __all__ = ["EnsembleSelection", "EnsembleSelectionRegressor", "MeanEstimator", "Stacking"] 5 | -------------------------------------------------------------------------------- /sksurv/meta/base.py: -------------------------------------------------------------------------------- 1 | # This program is free software: you can redistribute it and/or modify 2 | # it under the terms of the GNU General Public License as published by 3 | # the Free Software Foundation, either version 3 of the License, or 4 | # (at your option) any later version. 5 | # 6 | # This program is distributed in the hope that it will be useful, 7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 9 | # GNU General Public License for more details. 10 | # 11 | # You should have received a copy of the GNU General Public License 12 | # along with this program. If not, see . 
def _fit_and_score(est, x, y, scorer, train_index, test_index, parameters, fit_params, predict_params):
    """Fit ``est`` on one train/test split and return its test score.

    Parameters are applied via ``set_params`` before fitting; ``fit_params``
    and ``predict_params`` are shallow-copied so the caller's dicts are never
    mutated.  Raises ``ValueError`` if the scorer does not return a number.
    """
    x_train, y_train = _safe_split(est, x, y, train_index)

    # configure and train the estimator on the training subset
    est.set_params(**parameters)
    est.fit(x_train, y_train, **fit_params.copy())

    # evaluate on the held-out subset; train_index is needed so that
    # _safe_split can slice precomputed kernel matrices correctly
    x_test, y_test = _safe_split(est, x, y, test_index, train_index)
    score = scorer(est, x_test, y_test, **predict_params.copy())

    if not isinstance(score, numbers.Number):
        raise ValueError(f"scoring must return a number, got {score!s} ({type(score)}) instead.")

    return score
cimport cython

from cython.operator import preincrement
import numpy as np

cimport numpy as cnp

from scipy.sparse import csr_matrix

cnp.import_array()


@cython.wraparound(False)
@cython.cdivision(True)
@cython.boundscheck(False)
def survival_constraints_simple(cnp.npy_uint8[:] y):
    """Build the sparse pairwise-ranking constraint matrix for a survival SVM.

    ``y`` is a 0/1 event indicator; samples are presumably ordered by observed
    time (TODO confirm with caller).  For every pair (i, j) with i < j where
    sample i experienced an event (y[i] == 1), one constraint row is emitted
    with -1 in column i and +1 in column j.

    Returns a ``scipy.sparse.csr_matrix`` of shape (n_pairs, n_samples) with
    int8 entries.
    """
    cdef int i
    cdef int j
    # k counts non-zero entries written so far; each pair contributes two
    cdef int k = 0
    cdef cnp.npy_intp n_samples = y.shape[0]
    # upper bound on the number of non-zeros: 2 entries per ordered pair
    cdef cnp.npy_intp n = n_samples * (n_samples - 1)

    # over-allocate uninitialized buffers; trimmed to size k below
    cdef cnp.ndarray[cnp.npy_int8, ndim=1] data = cnp.PyArray_EMPTY(1, &n, cnp.NPY_INT8, 0)
    cdef cnp.ndarray[cnp.npy_intp, ndim=1] indices = cnp.PyArray_EMPTY(1, &n, cnp.NPY_INTP, 0)

    with nogil:
        for i in range(n_samples - 1):
            # a censored sample cannot be the earlier element of a pair
            if y[i] == 0:
                continue

            for j in range(i + 1, n_samples):
                data[k] = -1
                data[k + 1] = 1
                indices[k] = i
                indices[k + 1] = j
                k += 2

    # shrink buffers in place to the entries actually written
    data.resize(k, refcheck=False)
    indices.resize(k, refcheck=False)

    # every CSR row holds exactly two entries (-1 and +1)
    cdef object indptr = cnp.PyArray_Arange(0, k + 1, 2, cnp.NPY_INTP)
    A = csr_matrix((data, indices, indptr), shape=(k // 2, n_samples), dtype=np.int8)

    return A


@cython.wraparound(False)
@cython.cdivision(True)
@cython.boundscheck(False)
def survival_constraints_with_support_vectors(cnp.npy_uint8[:] y,
                                              cnp.npy_double[:] xw):
    """Like :func:`survival_constraints_simple`, but keep only active pairs.

    ``xw`` holds the current predicted scores.  A pair (i, j) is included only
    when xw[i] + 1 > xw[j], i.e. the margin between the two predictions is
    violated; these correspond to the support vectors of the ranking problem
    (NOTE(review): margin of 1 is hard-coded here).

    Returns a ``scipy.sparse.csr_matrix`` of shape (n_pairs, n_samples) with
    int8 entries.
    """
    cdef int i
    cdef int j
    cdef cnp.npy_double vi
    # k counts non-zero entries written so far; each kept pair contributes two
    cdef int k = 0
    cdef cnp.npy_intp n_samples = y.shape[0]
    # upper bound on the number of non-zeros: 2 entries per ordered pair
    cdef cnp.npy_intp n = n_samples * (n_samples - 1)

    # over-allocate uninitialized buffers; trimmed to size k below
    cdef cnp.ndarray[cnp.npy_int8, ndim=1] data = cnp.PyArray_EMPTY(1, &n, cnp.NPY_INT8, 0)
    cdef cnp.ndarray[cnp.npy_intp, ndim=1] indices = cnp.PyArray_EMPTY(1, &n, cnp.NPY_INTP, 0)

    with nogil:
        for i in range(n_samples - 1):
            # a censored sample cannot be the earlier element of a pair
            if y[i] == 0:
                continue
            # score of sample i plus the margin of 1
            vi = xw[i] + 1.

            for j in range(i + 1, n_samples):
                # keep the pair only if the margin constraint is violated
                if vi > xw[j]:
                    data[k] = -1
                    data[k + 1] = 1
                    indices[k] = i
                    indices[k + 1] = j
                    k += 2

    # shrink buffers in place to the entries actually written
    data.resize(k, refcheck=False)
    indices.resize(k, refcheck=False)

    # every CSR row holds exactly two entries (-1 and +1)
    cdef object indptr = cnp.PyArray_Arange(0, k + 1, 2, cnp.NPY_INTP)
    A = csr_matrix((data, indices, indptr), shape=(k // 2, n_samples), dtype=np.int8)

    return A
def assert_cindex_almost_equal(event_indicator, event_time, estimate, expected):
    """Check concordance-index output against ``expected``.

    Verifies the pair counts exactly, that the reported index matches the
    value recomputed from those counts, and that it matches ``expected[0]``.
    """
    actual = concordance_index_censored(event_indicator, event_time, estimate)
    assert_array_equal(actual[1:], expected[1:])
    n_concordant, n_discordant, n_tied_risk = actual[1:4]
    recomputed = (n_concordant + 0.5 * n_tied_risk) / (n_concordant + n_discordant + n_tied_risk)
    assert_almost_equal(actual[0], recomputed)
    assert_almost_equal(actual[0], expected[0])


def assert_survival_function_properties(surv_fns):
    """Raise ``AssertionError`` unless ``surv_fns`` look like survival functions.

    Expects one function per row, evaluated over increasing time points per
    column: values must be finite, within [0, 1], and non-increasing, and the
    majority should start near 1 and end near 0.
    """
    if not np.all(np.isfinite(surv_fns)):
        raise AssertionError("survival function contains values that are not finite")
    if (surv_fns < 0.0).any():
        raise AssertionError("survival function contains negative values")
    if (surv_fns > 1.0).any():
        raise AssertionError("survival function contains values larger 1")

    # consecutive differences along time must never be positive
    if np.any(np.diff(surv_fns, axis=1) > 0):
        raise AssertionError("survival functions are not monotonically decreasing")

    # at the first time point most probabilities should still be close to 1
    n_first_near_zero = np.sum(1.0 - surv_fns[:, 0] >= surv_fns[:, 0])
    if n_first_near_zero / surv_fns.shape[0] > 0.5:
        raise AssertionError(f"most ({n_first_near_zero}) probabilities at first time point are closer to 0 than 1")

    # at the last time point most probabilities should have dropped towards 0
    n_last_near_one = np.sum(1.0 - surv_fns[:, -1] < surv_fns[:, -1])
    if n_last_near_one / surv_fns.shape[0] > 0.5:
        raise AssertionError(f"most ({n_last_near_one}) probabilities at last time point are closer to 1 than 0")
def assert_chf_properties(chf):
    """Raise ``AssertionError`` unless ``chf`` look like cumulative hazards.

    Expects one function per row over increasing time points per column:
    values must be finite, non-negative, and non-decreasing, and the majority
    should start near 0.
    """
    if not np.all(np.isfinite(chf)):
        raise AssertionError("chf contains values that are not finite")
    if (chf < 0.0).any():
        raise AssertionError("chf contains negative values")

    # consecutive differences along time must never be negative
    if np.any(np.diff(chf, axis=1) < 0):
        raise AssertionError("chf are not monotonically increasing")

    # at the first time point most hazards should still be close to 0
    n_first_near_one = np.sum(1.0 - chf[:, 0] < chf[:, 0])
    if n_first_near_one / chf.shape[0] > 0.5:
        raise AssertionError(f"most ({n_first_near_one}) hazard rates at first time point are closer to 1 than 0")


def _is_survival_estimator(x):
    """Return True if ``x`` is a public, non-transformer estimator class from sksurv."""
    if not (inspect.isclass(x) and issubclass(x, BaseEstimator)):
        return False
    if issubclass(x, TransformerMixin):
        return False
    if not x.__module__.startswith("sksurv.") or x.__name__.startswith("_"):
        return False
    # metrics and nonparametric modules contain scorers/estimators of a
    # different kind and are deliberately excluded
    return x.__module__.split(".", 2)[1] not in {"metrics", "nonparametric"}


def all_survival_estimators():
    """Collect every concrete survival estimator class in the sksurv package.

    Walks all submodules below the installed ``sksurv`` package, skipping
    ``sksurv.meta`` (meta-estimators require base estimators), and returns
    the set of non-abstract classes accepted by ``_is_survival_estimator``.
    """
    package_dir = str(Path(sksurv.__file__).parent)
    estimators = set()
    for _importer, module_name, _ispkg in pkgutil.walk_packages(path=[package_dir], prefix="sksurv."):
        # meta-estimators require base estimators
        if module_name.startswith("sksurv.meta"):
            continue
        module = import_module(module_name)
        for _name, klass in inspect.getmembers(module, _is_survival_estimator):
            if not inspect.isabstract(klass):
                estimators.add(klass)
    return estimators


class FixtureParameterFactory:
    """Turn ``data_*`` methods of a subclass into pytest parameter cases."""

    def get_cases(self):
        # inspect.getmembers returns members sorted by name, giving a
        # deterministic case order
        return [
            pytest.param(*func(), id=name)
            for name, func in inspect.getmembers(self)
            if name.startswith("data_")
        ]
# Named containers shared by the fixtures below.
DataSet = namedtuple("DataSet", ["x", "y"])
DataSetWithNames = namedtuple("DataSetWithNames", ["x", "y", "names", "x_data_frame"])
SparseDataSet = namedtuple("SparseDataSet", ["x_dense", "x_sparse", "y"])


def pytest_configure(config):
    """Register the custom ``slow`` marker with pytest."""
    config.addinivalue_line("markers", "slow: marks test as slow (deselect with '-m \"not slow\"')")


@pytest.fixture()
def fake_data():
    """Random features with an all-events survival outcome of 100 samples."""
    n_samples = 100
    x = np.random.randn(n_samples, 11)
    y = Surv.from_arrays(np.ones(n_samples, dtype=bool), np.arange(1, n_samples + 1, dtype=float))
    return x, y


@pytest.fixture()
def breast_cancer():
    """Breast cancer data with categorical features one-hot encoded."""
    x_raw, y = load_breast_cancer()
    return encode_categorical(x_raw), y


@pytest.fixture()
def make_whas500():
    """Load and standardize WHAS500 data."""

    def _make_whas500(with_mean=True, with_std=True, to_numeric=False):
        features, y = load_whas500()
        if with_mean:
            features = standardize(features, with_std=with_std)
        if to_numeric:
            features = categorical_to_numeric(features)
        names = ["(Intercept)"] + features.columns.tolist()
        return DataSetWithNames(x=features.values, y=y, names=names, x_data_frame=features)

    return _make_whas500


@pytest.fixture()
def whas500_sparse_data():
    """WHAS500 categorical features as dense indicators plus an equivalent sparse matrix."""
    x, y = load_whas500()
    x_dense = categorical_to_numeric(x.select_dtypes(exclude=[float]))

    # collect COO triplets for every non-zero indicator entry
    values = []
    row_idx = []
    col_idx = []
    for j, (_, column) in enumerate(x_dense.items()):
        nonzero = np.flatnonzero(column.values)
        values.extend([1] * len(nonzero))
        row_idx.extend(nonzero)
        col_idx.extend([j] * len(nonzero))

    x_sparse = coo_matrix((values, (row_idx, col_idx)))
    return SparseDataSet(x_dense=x_dense, x_sparse=x_sparse, y=y)


@pytest.fixture()
def whas500_uncomparable(make_whas500):
    """WHAS500 data where only the sample with the longest follow-up has an event.

    With a single event at the maximum observed time, no pair of samples is
    comparable, which exercises degenerate-input handling in metrics/models.
    """
    data = make_whas500(to_numeric=True)
    longest = np.argmax(data.y["lenfol"])
    data.y["fstat"][:] = False
    data.y["fstat"][longest] = True
    return data


@pytest.fixture()
def rossi():
    """Load rossi.csv"""
    csv_path = Path(__file__).parent / "data" / "rossi.csv"
    frame = pd.read_csv(csv_path)
    y = Surv.from_dataframe("arrest", "week", frame)
    x = frame.drop(["arrest", "week"], axis=1)
    return DataSet(x=x, y=y)


@pytest.fixture(params=[np.inf, -np.inf, np.nan])
def non_finite_value(request):
    """Inf/-Inf/NaN value."""
    return request.param


@pytest.fixture()
def temp_file():
    """Yield an open temporary text file and remove it after the test.

    The handle is closed before unlinking; the original left it open, which
    leaks the handle and makes the unlink fail on Windows.
    """
    handle = tempfile.NamedTemporaryFile(mode="w", delete=False)
    path = Path(handle.name)
    yield handle
    handle.close()
    path.unlink()
doi:10.1093/biomet/75.3.515 4 | INF,DIAG 5 | 0.0,5.0 6 | 0.25,6.75 7 | 0.75,5.0 8 | 0.75,5.0 9 | 0.75,7.25 10 | 1.0,4.25 11 | 1.0,5.75 12 | 1.0,6.25 13 | 1.0,6.5 14 | 1.25,4.0 15 | 1.25,4.25 16 | 1.25,4.75 17 | 1.25,5.75 18 | 1.5,2.75 19 | 1.5,3.75 20 | 1.5,5.0 21 | 1.5,5.5 22 | 1.5,6.5 23 | 1.75,2.75 24 | 1.75,3.0 25 | 1.75,5.25 26 | 1.75,5.25 27 | 2.0,2.25 28 | 2.0,3.0 29 | 2.0,4.0 30 | 2.0,4.5 31 | 2.0,4.75 32 | 2.0,5.0 33 | 2.0,5.25 34 | 2.0,5.25 35 | 2.0,5.5 36 | 2.0,5.5 37 | 2.0,6.0 38 | 2.25,3.0 39 | 2.25,5.5 40 | 2.5,2.25 41 | 2.5,2.25 42 | 2.5,2.25 43 | 2.5,2.25 44 | 2.5,2.5 45 | 2.5,2.75 46 | 2.5,3.0 47 | 2.5,3.25 48 | 2.5,3.25 49 | 2.5,4.0 50 | 2.5,4.0 51 | 2.5,4.0 52 | 2.75,1.25 53 | 2.75,1.5 54 | 2.75,2.5 55 | 2.75,3.0 56 | 2.75,3.0 57 | 2.75,3.25 58 | 2.75,3.75 59 | 2.75,4.5 60 | 2.75,4.5 61 | 2.75,5.0 62 | 2.75,5.0 63 | 2.75,5.25 64 | 2.75,5.25 65 | 2.75,5.25 66 | 2.75,5.25 67 | 2.75,5.25 68 | 3.0,2.0 69 | 3.0,3.25 70 | 3.0,3.5 71 | 3.0,3.75 72 | 3.0,4.0 73 | 3.0,4.0 74 | 3.0,4.25 75 | 3.0,4.25 76 | 3.0,4.25 77 | 3.0,4.75 78 | 3.0,4.75 79 | 3.0,4.75 80 | 3.0,5.0 81 | 3.25,1.25 82 | 3.25,1.75 83 | 3.25,2.0 84 | 3.25,2.0 85 | 3.25,2.75 86 | 3.25,3.0 87 | 3.25,3.0 88 | 3.25,3.5 89 | 3.25,3.5 90 | 3.25,4.25 91 | 3.25,4.5 92 | 3.5,1.25 93 | 3.5,2.25 94 | 3.5,2.25 95 | 3.5,2.5 96 | 3.5,2.75 97 | 3.5,2.75 98 | 3.5,3.0 99 | 3.5,3.25 100 | 3.5,3.5 101 | 3.5,3.5 102 | 3.5,4.0 103 | 3.5,4.0 104 | 3.5,4.25 105 | 3.5,4.5 106 | 3.5,4.5 107 | 3.75,1.25 108 | 3.75,1.75 109 | 3.75,1.75 110 | 3.75,2.0 111 | 3.75,2.75 112 | 3.75,3.0 113 | 3.75,3.0 114 | 3.75,3.0 115 | 3.75,4.0 116 | 3.75,4.25 117 | 3.75,4.25 118 | 4.0,1.0 119 | 4.0,1.5 120 | 4.0,1.5 121 | 4.0,2.0 122 | 4.0,2.25 123 | 4.0,2.75 124 | 4.0,3.5 125 | 4.0,3.75 126 | 4.0,3.75 127 | 4.0,4.0 128 | 4.25,1.25 129 | 4.25,1.5 130 | 4.25,1.5 131 | 4.25,2.0 132 | 4.25,2.0 133 | 4.25,2.0 134 | 4.25,2.25 135 | 4.25,2.5 136 | 4.25,2.5 137 | 4.25,2.5 138 | 4.25,3.0 139 | 4.25,3.5 140 | 4.25,3.5 141 | 4.5,1.0 142 | 
4.5,1.5 143 | 4.5,1.5 144 | 4.5,1.5 145 | 4.5,1.75 146 | 4.5,2.25 147 | 4.5,2.25 148 | 4.5,2.5 149 | 4.5,2.5 150 | 4.5,2.5 151 | 4.5,2.5 152 | 4.5,2.75 153 | 4.5,2.75 154 | 4.5,2.75 155 | 4.5,2.75 156 | 4.5,3.0 157 | 4.5,3.0 158 | 4.5,3.0 159 | 4.5,3.25 160 | 4.5,3.25 161 | 4.75,1.0 162 | 4.75,1.5 163 | 4.75,1.5 164 | 4.75,1.5 165 | 4.75,1.75 166 | 4.75,1.75 167 | 4.75,2.0 168 | 4.75,2.25 169 | 4.75,2.75 170 | 4.75,3.0 171 | 4.75,3.0 172 | 4.75,3.25 173 | 4.75,3.25 174 | 4.75,3.25 175 | 4.75,3.25 176 | 4.75,3.25 177 | 4.75,3.25 178 | 5.0,0.5 179 | 5.0,1.5 180 | 5.0,1.5 181 | 5.0,1.75 182 | 5.0,2.0 183 | 5.0,2.25 184 | 5.0,2.25 185 | 5.0,2.25 186 | 5.0,2.5 187 | 5.0,2.5 188 | 5.0,3.0 189 | 5.0,3.0 190 | 5.0,3.0 191 | 5.25,0.25 192 | 5.25,0.25 193 | 5.25,0.75 194 | 5.25,0.75 195 | 5.25,0.75 196 | 5.25,1.0 197 | 5.25,1.0 198 | 5.25,1.25 199 | 5.25,1.25 200 | 5.25,1.5 201 | 5.25,1.5 202 | 5.25,1.5 203 | 5.25,1.5 204 | 5.25,2.25 205 | 5.25,2.25 206 | 5.25,2.5 207 | 5.25,2.5 208 | 5.25,2.75 209 | 5.5,1.0 210 | 5.5,1.0 211 | 5.5,1.0 212 | 5.5,1.25 213 | 5.5,1.25 214 | 5.5,1.75 215 | 5.5,2.0 216 | 5.5,2.25 217 | 5.5,2.25 218 | 5.5,2.5 219 | 5.75,0.25 220 | 5.75,0.75 221 | 5.75,1.0 222 | 5.75,1.5 223 | 5.75,1.5 224 | 5.75,1.5 225 | 5.75,2.0 226 | 5.75,2.0 227 | 5.75,2.25 228 | 6.0,0.5 229 | 6.0,0.75 230 | 6.0,0.75 231 | 6.0,0.75 232 | 6.0,1.0 233 | 6.0,1.0 234 | 6.0,1.0 235 | 6.0,1.25 236 | 6.0,1.25 237 | 6.0,1.5 238 | 6.0,1.5 239 | 6.0,1.75 240 | 6.0,1.75 241 | 6.0,1.75 242 | 6.0,2.0 243 | 6.25,0.75 244 | 6.25,1.0 245 | 6.25,1.25 246 | 6.25,1.75 247 | 6.25,1.75 248 | 6.5,0.25 249 | 6.5,0.25 250 | 6.5,0.75 251 | 6.5,1.0 252 | 6.5,1.25 253 | 6.5,1.5 254 | 6.75,0.75 255 | 6.75,0.75 256 | 6.75,0.75 257 | 6.75,1.0 258 | 6.75,1.25 259 | 6.75,1.25 260 | 6.75,1.25 261 | 7.0,0.75 262 | 7.25,0.25 263 | -------------------------------------------------------------------------------- /tests/data/Lagakos_AIDS_children.csv: 
-------------------------------------------------------------------------------- 1 | # Lagakos, S. W., Barraj, L. M., De Gruttola, V. (1988). 2 | # Nonparametric analysis of truncated survival data, with application to AIDS. 3 | # Biometrika, 75(3), 515–523. doi:10.1093/biomet/75.3.515 4 | INF,DIAG 5 | 1,5.5 6 | 1.5,2.25 7 | 2.25,3 8 | 2.75,1 9 | 3,1.75 10 | 3.5,0.75 11 | 3.75,0.75 12 | 3.75,1 13 | 3.75,2.75 14 | 3.75,3 15 | 3.75,3.5 16 | 3.75,4.25 17 | 4,1 18 | 4.25,1.75 19 | 4.5,3.25 20 | 4.75,1 21 | 4.75,2.25 22 | 5,0.5 23 | 5,0.75 24 | 5,1.5 25 | 5,2.5 26 | 5.25,0.25 27 | 5.25,1 28 | 5.25,1.5 29 | 5.5,0.5 30 | 5.5,1.5 31 | 5.5,2.5 32 | 5.75,1.75 33 | 6,0.5 34 | 6,1.25 35 | 6.25,0.5 36 | 6.25,1.25 37 | 6.5,0.75 38 | 6.75,0.5 39 | 6.75,0.75 40 | 7,0.75 41 | 7.25,0.25 -------------------------------------------------------------------------------- /tests/data/cgvhd_aalen.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sebp/scikit-survival/05030894296c07eda84a36e619e814843852aa0a/tests/data/cgvhd_aalen.npy -------------------------------------------------------------------------------- /tests/data/cgvhd_delta.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sebp/scikit-survival/05030894296c07eda84a36e619e814843852aa0a/tests/data/cgvhd_delta.npy -------------------------------------------------------------------------------- /tests/data/cgvhd_dinse.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sebp/scikit-survival/05030894296c07eda84a36e619e814843852aa0a/tests/data/cgvhd_dinse.npy -------------------------------------------------------------------------------- /tests/data/cox-example-coef-2-alpha.csv: -------------------------------------------------------------------------------- 1 | "s0","s1","s2","s3","s4","s5","s6" 2 | 0,0,0,0,0,0,0.352790212410126 
3 | 0,0,0,0,0,0,-0.108348034453021 4 | 0,0,0,0,0,0,-0.14417019166909 5 | 0,0,0,0,0,0,0.112497593457371 6 | 0,0,0,0,0,0,-0.127578315262323 7 | 0,0,0,0,0,0,-0.364521710050781 8 | 0,0,0,0,0,0,0.23806306221791 9 | 0,0,0,0,0,0,0.0554724028754736 10 | 0,0,0,0,0,0,0.328160690091434 11 | 0,0,0,0,0,0,0.0619991987139122 12 | 0,0,0,0,0,0,0 13 | 0,0,0,0,0,0,0 14 | 0,0,0,0,0,0,0.00273691909908345 15 | 0,0,0,0,0,0,0 16 | 0,0,0,0,0,0,0 17 | 0,0,0,0,0,0,0 18 | 0,0,0,0,0,0,-7.8658095715925e-05 19 | 0,0,0,0,0,0,0 20 | 0,0,0,0,0,0,0 21 | 0,0,0,0,0,0,0 22 | 0,0,0,0,0,0,0 23 | 0,0,0,0,0,0,0 24 | 0,0,0,0,0,0,0 25 | 0,0,0,0,0,0,0 26 | 0,0,0,0,0,0,0 27 | 0,0,0,0,0,0,0 28 | 0,0,0,0,0,0,0 29 | 0,0,0,0,0,0,0 30 | 0,0,0,0,0,0,0 31 | 0,0,0,0,0,0,0 32 | -------------------------------------------------------------------------------- /tests/data/cox-example-coef-2-nalpha-norm.csv: -------------------------------------------------------------------------------- 1 | "s0","s1","s2","s3","s4","s5","s6","s7","s8","s9","s10" 2 | 0,0.101157742631475,0.230332221918114,0.345949902096499,0.431282962517806,0.486881372962124,0.518780153198447,0.537043970996675,0.546426729038658,0.551216980844768,0.553572104291376 3 | 0,0,-0.039687900093849,-0.104594307663587,-0.151507198359495,-0.179631800770021,-0.195393608456192,-0.204004450886213,-0.208413952116315,-0.210657574374462,-0.21175918728253 4 | 0,0,-0.0698322261390677,-0.140099094858407,-0.191466465414609,-0.223502234509198,-0.241193018670412,-0.25090992174618,-0.256084848803537,-0.258773162035965,-0.260107522812531 5 | 0,0,0.0457240779762313,0.108839346842319,0.152330706582482,0.180377903037592,0.195978378479931,0.20496164468832,0.209677349160301,0.212082319860934,0.213247083526382 6 | 0,0,-0.0640598314919326,-0.124085279770885,-0.166091995527082,-0.192835380945313,-0.207749802431937,-0.216116938348396,-0.220125406480167,-0.22215364636594,-0.223163886665636 7 | 
0,-0.122830038903105,-0.248352955194318,-0.358216027448526,-0.434799044280752,-0.484035467115441,-0.512787402941905,-0.528650156432608,-0.536860637451282,-0.541041694827307,-0.543094275488652 8 | 0,0.0447260350592355,0.145029461442723,0.232960325258719,0.295925105447328,0.336603391477607,0.359416430227044,0.371998058254122,0.37822470656736,0.381331984674934,0.382855422815459 9 | 0,0,0.00811294822723447,0.0531471338168118,0.0817777684602242,0.0972765813190051,0.105021385381085,0.108903542854024,0.110990118125211,0.112073584922062,0.112609917985845 10 | 0,0.0952741362791348,0.216849837937998,0.321910602363338,0.397879771780052,0.445537848142615,0.472699260394528,0.487847803734235,0.495433606424665,0.49929173787191,0.501173664878597 11 | 0,0,0,0.058678380254328,0.0986680125740883,0.123112795383886,0.136570957591533,0.143972143033712,0.147685465994468,0.149568477669935,0.150512315168427 12 | 0,0,0,0,0,-0.0071250733666086,-0.0147190015491433,-0.0193483320416341,-0.0218079498619399,-0.0230717977406199,-0.0236955739934211 13 | 0,0,0,0,0,0,0,0.00126913372086278,0.00253288106296641,0.00315196292533745,0.00346544677521315 14 | 0,0,0,0.00117451733921073,0.0191945062320497,0.0294845316520058,0.0352219836023814,0.0385602395399207,0.040301481746046,0.0411815727762595,0.0416135017826611 15 | 0,0,0,0,0,0,0,0,0,0,0 16 | 0,0,0,0,0,0,0,0,-0.00126145987229018,-0.00240823866678263,-0.00300197700544213 17 | 0,0,0,0,0,0,0,0,0,0.000470074816884513,0.000984125915858206 18 | 0,0,0,0,-0.0167391493492525,-0.027819184671025,-0.0341425700655545,-0.0372924585526771,-0.038949895650505,-0.0398831542558697,-0.0403877672927384 19 | 0,0,0,0,0,-0.000126359817622088,-0.00568834211857645,-0.00832971826701309,-0.00964494918977269,-0.0102877817689168,-0.0106029273367256 20 | 0,0,0,0,0,0,0.00264672463158286,0.00646811965357053,0.00829630466967871,0.00920697358179043,0.00966327571210263 21 | 0,0,0,0,0,0,0,0,0,0,0.000174616895329835 22 | 
0,0,0,0,-0.00292096214866786,-0.0131558795118032,-0.0184718869987966,-0.0209654494391727,-0.0222431200225692,-0.0228643414611026,-0.0231550527248918 23 | 0,0,0,0,0,-0.011844696857389,-0.0196697403658202,-0.0240426630655554,-0.0262756240672492,-0.0273835549880396,-0.0279439813127631 24 | 0,0,0,0,0,0,-0.000390529423260579,-0.00502022828892123,-0.0075391654176011,-0.00882578975775135,-0.00946261836193899 25 | 0,0,0,0,0,0,-7.44768120743759e-05,-0.00443331060764391,-0.00682431594951285,-0.00808381254575262,-0.00873639057481113 26 | 0,0,0,0,-0.0205434769026525,-0.0348706213579541,-0.0430944235166272,-0.0479787802513129,-0.0504835881472502,-0.0517320635000316,-0.0523547826448766 27 | 0,0,0,0,0,0,0,-0.000265226907815689,-0.00213030360299081,-0.00311176288273769,-0.00359999588634632 28 | 0,0,0,0,0.000325067810701004,0.0152505140449114,0.0232164085333583,0.0274211998779581,0.0296051527122931,0.0307136437389022,0.0312861492837056 29 | 0,0,0,0,0,-0.00458268167249203,-0.00880734267131486,-0.010901371690565,-0.011972187025799,-0.0125213190965137,-0.012789487771725 30 | 0,0,0,0,0,0,0.00360381453736373,0.00533964944900609,0.00635218338890861,0.00691351111478025,0.00722979305580961 31 | 0,0,0,0,-0.00902154044384413,-0.0209981120197067,-0.0273985599849231,-0.0310613197371897,-0.0331070666035073,-0.0341378592869992,-0.0346518803373749 32 | -------------------------------------------------------------------------------- /tests/data/cox-example-coef-2-nalpha.csv: -------------------------------------------------------------------------------- 1 | "s0","s1","s2","s3","s4","s5","s6","s7","s8","s9","s10" 2 | 0,0.188379897144616,0.344250015266454,0.440139094191937,0.493722148611339,0.52379719163387,0.539436393242641,0.547680075380291,0.5518817115364,0.55394616900512,0.554639126278747 3 | 0,0,-0.0761716660983314,-0.141337457041594,-0.177053945088993,-0.194614101657467,-0.203585345742739,-0.208222639362887,-0.210576285044589,-0.211738766850343,-0.212240688688108 4 | 
0,-0.00218588754370381,-0.115447043313753,-0.183390061223191,-0.221330472978123,-0.241124340681711,-0.250899617572233,-0.255894467761448,-0.258654184140974,-0.260065700354794,-0.260555177073583 5 | 0,0,0.0784179110985707,0.144256649191689,0.177413317832577,0.195625351590061,0.204630848315522,0.209456447338459,0.212001586758049,0.213247637705685,0.213756565992394 6 | 0,0,-0.0935986329578498,-0.155042061881949,-0.188316387722855,-0.206247645953485,-0.215330471904231,-0.220061895095285,-0.22213671897946,-0.223163992700592,-0.223640530292392 7 | 0,-0.213001581159395,-0.357360146084311,-0.445994600409445,-0.491633333823296,-0.517426846427584,-0.531102435072959,-0.538059051798062,-0.541687297281017,-0.543454490164664,-0.544118534741004 8 | 0,0.0888275540574932,0.216340405177518,0.294491943687599,0.337188050549421,0.361122364934831,0.372949959971717,0.378956627141877,0.381750147424197,0.383091631512976,0.383644975084633 9 | 0,0,0.021144463872504,0.0689799288013892,0.0928505181478764,0.103541831034763,0.108339989770335,0.110588306755649,0.111872237027616,0.112513041962667,0.112842870970619 10 | 0,0.168790477292464,0.314606910575026,0.402929819044781,0.451340833584817,0.477048294128823,0.490136460018706,0.496832231288758,0.500041708335753,0.50159173682166,0.50224174986692 11 | 0,0,0.0162212099125674,0.0845175341827935,0.11759019004849,0.134566660278935,0.143051575386876,0.147346247286474,0.1494007345186,0.150435499804778,0.150964018729624 12 | 0,0,0,0,0,-0.0104863008376034,-0.0169127687655104,-0.0205405870828035,-0.0224380122391535,-0.0233930482336063,-0.0237848256456186 13 | 0,0,0,0,0,0,0,0.00110342886097246,0.00244742856713572,0.00310866128299545,0.00339066869133957 14 | 0,0,0,0.00129634782575739,0.0201682498626346,0.0301749655214231,0.0357859187834796,0.0388715312819611,0.0404705512555458,0.0412660965697554,0.041552367578563 15 | 0,0,0,0,0,0,0,0,0,0,0 16 | 0,0,0,0,0,0,0,0,-0.00156323079119563,-0.00257031728905216,-0.00300897786373177 17 | 
0,0,0,0,0,0,0,0,0,0.000578754497236348,0.000954142607637004 18 | 0,0,0,-0.000340084590573705,-0.0201397548142474,-0.030498220258115,-0.0357283078546612,-0.0381521213947762,-0.0394149607450197,-0.0401339036942181,-0.0405323686771122 19 | 0,0,0,0,0,-0.000873228311047719,-0.00617683989882797,-0.00861373652821605,-0.00978328525973577,-0.010358308417395,-0.0106640918505683 20 | 0,0,0,0,0,0,0.00353045649672897,0.00695928343299943,0.00853542683639042,0.00933210841436594,0.00969402810203858 21 | 0,0,0,0,0,0,0,0,0,0,0.000118912808452622 22 | 0,0,0,0,-0.00481185025383026,-0.0145086666503297,-0.0192702520909482,-0.0213967718699201,-0.0224744692107516,-0.0229717669805133,-0.0233052640760358 23 | 0,0,0,0,-0.00348154423535438,-0.0155332133429143,-0.0219408674979574,-0.0252500576088072,-0.0268968805811487,-0.0276948283110578,-0.0280679030079186 24 | 0,0,0,0,0,0,-0.00254256536593351,-0.006223244290189,-0.00818749377834513,-0.0091538683789294,-0.00958440030391387 25 | 0,0,0,0,0,0,-0.000779555142952622,-0.00486816554142604,-0.00708632500813779,-0.00822291422313495,-0.00871449878900857 26 | 0,0,0,-0.00219620992995105,-0.0252466675573078,-0.0378454233200919,-0.0449402009405363,-0.0489595534388782,-0.050986898304077,-0.0519827419048036,-0.0523948243005398 27 | 0,0,0,0,0,0,0,-0.000276895036918721,-0.00217445894522609,-0.00313437325823785,-0.00350328368106385 28 | 0,0,0,0,0.0030577201778878,0.0174684830506897,0.0244991553328014,0.0281111262530663,0.0299501593567974,0.0308968646811434,0.0313291438365989 29 | 0,0,0,0,0,-0.00406117066744012,-0.00859616685690651,-0.0107902896829827,-0.0119194092614571,-0.0124970689885372,-0.0128019857784268 30 | 0,0,0,0,0,0,0.00297435159930023,0.00498834161286277,0.00618393963863962,0.00683767629460158,0.00717166265112352 31 | 0,0,0,0,-0.010481617862575,-0.0221871537978954,-0.0281673939528422,-0.0314870245814881,-0.0333261845756012,-0.0342524737442581,-0.0346767973194823 32 | -------------------------------------------------------------------------------- 
/tests/data/cox-simple-coef.csv: -------------------------------------------------------------------------------- 1 | "s0","s1","s2","s3","s4","s5","s6","s7","s8","s9","s10","s11","s12","s13","s14","s15","s16","s17","s18","s19","s20","s21","s22","s23","s24","s25","s26","s27","s28","s29","s30","s31","s32","s33","s34","s35","s36","s37","s38","s39","s40","s41","s42","s43","s44","s45","s46","s47","s48","s49","s50" 2 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 3 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.000356117667614688,0.00103109513672771 4 | 0,0.0109695831467285,0.0202917520453709,0.0285253087429664,0.0360038536145453,0.0429421878402113,0.0494878726120129,0.0557454829942189,0.0617936496367534,0.0676928247821,0.0734936650636488,0.0792321479220874,0.0849419878353832,0.0906505283936411,0.0963809719171504,0.102153052263327,0.107983466405607,0.113886139347814,0.119883317429973,0.125965016092286,0.132145712236168,0.138429249889062,0.144817157837473,0.151308813754268,0.157901663293676,0.16459149556609,0.171372762350935,0.178238921564117,0.185182782352613,0.192196829954851,0.199273512479321,0.206405477749859,0.213585754904693,0.22080788130077,0.2280659797038,0.235354793471163,0.242657634617521,0.250013892284216,0.257363877664968,0.264729630325065,0.27210885723057,0.279506946081561,0.286889904908067,0.294302778950165,0.30173389272566,0.309172818123136,0.316558067827113,0.32402719211167,0.331337625544557,0.338432782708286,0.345240249041518 5 | 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 6 | -------------------------------------------------------------------------------- /tests/test_aft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.testing import assert_array_almost_equal 3 | import pytest 4 | from sklearn.pipeline import 
make_pipeline 5 | 6 | from sksurv.base import SurvivalAnalysisMixin 7 | from sksurv.linear_model import IPCRidge 8 | from sksurv.testing import assert_cindex_almost_equal 9 | 10 | 11 | class TestIPCRidge: 12 | @staticmethod 13 | def test_fit(make_whas500): 14 | whas500 = make_whas500() 15 | model = IPCRidge() 16 | model.fit(whas500.x, whas500.y) 17 | 18 | assert model.intercept_ == pytest.approx(5.867520370855396, 1e-7) 19 | expected = np.array( 20 | [ 21 | 0.168481, 22 | -0.24962, 23 | 2.185086, 24 | 0.53682, 25 | -0.514611, 26 | 0.09124, 27 | 0.613114, 28 | 0.480357, 29 | -0.055972, 30 | 0.238472, 31 | -0.127209, 32 | -0.144063, 33 | -1.625081, 34 | -0.217591, 35 | ] 36 | ) 37 | assert_array_almost_equal(model.coef_, expected) 38 | 39 | @staticmethod 40 | def test_predict(make_whas500): 41 | whas500 = make_whas500() 42 | model = IPCRidge() 43 | model.fit(whas500.x[:400], whas500.y[:400]) 44 | 45 | x_test = whas500.x[400:] 46 | y_test = whas500.y[400:] 47 | p = model.predict(x_test) 48 | assert_cindex_almost_equal( 49 | y_test["fstat"], 50 | y_test["lenfol"], 51 | -p, 52 | (0.66925817946226107, 2066, 1021, 0, 1), 53 | ) 54 | 55 | assert model.score(x_test, y_test) == 0.66925817946226107 56 | 57 | @staticmethod 58 | def test_pipeline_score(make_whas500): 59 | whas500 = make_whas500() 60 | pipe = make_pipeline(IPCRidge()) 61 | pipe.fit(whas500.x[:400], whas500.y[:400]) 62 | 63 | x_test = whas500.x[400:] 64 | y_test = whas500.y[400:] 65 | p = pipe.predict(x_test) 66 | assert_cindex_almost_equal( 67 | y_test["fstat"], 68 | y_test["lenfol"], 69 | -p, 70 | (0.66925817946226107, 2066, 1021, 0, 1), 71 | ) 72 | 73 | assert SurvivalAnalysisMixin.score(pipe, x_test, y_test) == 0.66925817946226107 74 | -------------------------------------------------------------------------------- /tests/test_binarytrees.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from sksurv.bintrees import AVLTree, RBTree 4 | 5 | 6 | 
@pytest.fixture(params=[RBTree, AVLTree]) 7 | def tree(request): 8 | return request.param(10) 9 | 10 | 11 | class TestBinaryTree: 12 | @staticmethod 13 | def test_insert(tree): 14 | for k in (12, 34, 45, 16, 35, 57): 15 | tree.insert(k, k) 16 | assert 6 == len(tree) 17 | 18 | @staticmethod 19 | def test_count_smaller(tree): 20 | for k in (12, 34, 45, 16, 35, 57): 21 | tree.insert(k, k) 22 | 23 | c, _ = tree.count_smaller(12) 24 | assert 0 == c 25 | 26 | c, _ = tree.count_smaller(16) 27 | assert 1 == c 28 | 29 | c, _ = tree.count_smaller(34) 30 | assert 2 == c 31 | 32 | c, _ = tree.count_smaller(35) 33 | assert 3 == c 34 | 35 | c, _ = tree.count_smaller(45) 36 | assert 4 == c 37 | 38 | c, _ = tree.count_smaller(57) 39 | assert 5 == c 40 | 41 | @staticmethod 42 | def test_count_larger(tree): 43 | for k in (12, 34, 45, 16, 35, 57): 44 | tree.insert(k, k) 45 | 46 | c, _ = tree.count_larger(12) 47 | assert 5 == c 48 | 49 | c, _ = tree.count_larger(16) 50 | assert 4 == c 51 | 52 | c, _ = tree.count_larger(34) 53 | assert 3 == c 54 | 55 | c, _ = tree.count_larger(35) 56 | assert 2 == c 57 | 58 | c, _ = tree.count_larger(45) 59 | assert 1 == c 60 | 61 | c, _ = tree.count_larger(57) 62 | assert 0 == c 63 | -------------------------------------------------------------------------------- /tests/test_common.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from sksurv.base import SurvivalAnalysisMixin 4 | from sksurv.testing import all_survival_estimators 5 | 6 | 7 | @pytest.mark.parametrize("estimator_cls", all_survival_estimators()) 8 | def test_survival_analysis_base_clas(estimator_cls): 9 | assert hasattr(estimator_cls, "fit") 10 | assert hasattr(estimator_cls, "predict") 11 | assert hasattr(estimator_cls, "score") 12 | assert issubclass(estimator_cls, SurvivalAnalysisMixin) 13 | -------------------------------------------------------------------------------- /tests/test_io.py: 
-------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from io import StringIO 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import pandas.testing as tm 7 | import pytest 8 | 9 | from sksurv.io import loadarff, writearff 10 | from sksurv.testing import FixtureParameterFactory 11 | 12 | EXPECTED_1 = [ 13 | "@relation test_nominal\n", 14 | "\n", 15 | "@attribute attr_nominal\t{beer,water,wine}\n", 16 | '@attribute attr_nominal_spaces\t{"hard liquor",mate,"red wine"}\n', 17 | "\n", 18 | "@data\n", 19 | 'water,"red wine"\n', 20 | 'wine,"hard liquor"\n', 21 | "beer,?\n", 22 | "?,mate\n", 23 | 'wine,"hard liquor"\n', 24 | "water,mate\n", 25 | ] 26 | 27 | 28 | EXPECTED_NO_QUOTES = [ 29 | "@relation test_nominal\n", 30 | "\n", 31 | "@attribute attr_nominal\t{beer,water,wine}\n", 32 | "@attribute attr_nominal_spaces\t{hard liquor,mate,red wine}\n", 33 | "\n", 34 | "@data\n", 35 | "water,red wine\n", 36 | "wine,hard liquor\n", 37 | "beer,?\n", 38 | "?,mate\n", 39 | "wine,hard liquor\n", 40 | "water,mate\n", 41 | ] 42 | 43 | 44 | EXPECTED_DATETIME = [ 45 | "@relation test_datetime\n", 46 | "\n", 47 | "@attribute attr_datetime\tdate 'yyyy-MM-dd HH:mm:ss'\n", 48 | "\n", 49 | "@data\n", 50 | '"2014-10-31 14:13:01"\n', 51 | '"2004-03-13 19:49:31"\n', 52 | '"1998-12-06 09:10:11"\n', 53 | ] 54 | 55 | 56 | class DataFrameCases(FixtureParameterFactory): 57 | def data_nominal(self): 58 | data = pd.DataFrame( 59 | { 60 | "attr_nominal": ["water", "wine", "beer", None, "wine", "water"], 61 | "attr_nominal_spaces": ["red wine", "hard liquor", None, "mate", "hard liquor", "mate"], 62 | } 63 | ) 64 | return data, "test_nominal", EXPECTED_1.copy() 65 | 66 | def data_nominal_with_quotes(self): 67 | data, rel_name, expected = self.data_nominal() 68 | data["attr_nominal_spaces"] = ["'red wine'", "'hard liquor'", None, "mate", "'hard liquor'", "mate"] 69 | return data, rel_name, expected 70 | 71 | def 
data_nominal_as_category(self): 72 | data, rel_name, expected = self.data_nominal_with_quotes() 73 | for name, series in data.items(): 74 | data[name] = pd.Categorical(series, ordered=False) 75 | 76 | expected[3] = '@attribute attr_nominal_spaces\t{"hard liquor","red wine",mate}\n' 77 | return data, rel_name, expected 78 | 79 | def data_nominal_as_category_extra(self): 80 | data, rel_name, expected = self.data_nominal_as_category() 81 | data["attr_nominal"] = pd.Categorical( 82 | ["water", "wine", "beer", None, "wine", "water"], 83 | categories=["beer", "coke", "water", "wine"], 84 | ordered=False, 85 | ) 86 | 87 | expected[2] = "@attribute attr_nominal\t{beer,coke,water,wine}\n" 88 | return data, rel_name, expected 89 | 90 | def data_nominal_with_category_ordering(self): 91 | data, rel_name, expected = self.data_nominal_with_quotes() 92 | data["attr_nominal"] = pd.Categorical( 93 | ["water", "wine", "beer", None, "wine", "water"], 94 | categories=["water", "coke", "beer", "wine"], 95 | ordered=False, 96 | ) 97 | 98 | expected[2] = "@attribute attr_nominal\t{water,coke,beer,wine}\n" 99 | return data, rel_name, expected 100 | 101 | def data_datetime(self): 102 | data = pd.DataFrame( 103 | { 104 | "attr_datetime": np.array( 105 | ["2014-10-31 14:13:01", "2004-03-13 19:49:31", "1998-12-06 09:10:11"], dtype="datetime64" 106 | ) 107 | } 108 | ) 109 | return data, "test_datetime", EXPECTED_DATETIME.copy() 110 | 111 | 112 | def test_loadarff_dataframe(): 113 | contents = "".join(EXPECTED_NO_QUOTES) 114 | with StringIO(contents) as fp: 115 | actual_df = loadarff(fp) 116 | 117 | expected_df = pd.DataFrame.from_dict( 118 | OrderedDict( 119 | [ 120 | ("attr_nominal", pd.Series(pd.Categorical.from_codes([1, 2, 0, -1, 2, 1], ["beer", "water", "wine"]))), 121 | ( 122 | "attr_nominal_spaces", 123 | pd.Series(pd.Categorical.from_codes([2, 0, -1, 1, 0, 1], ["hard liquor", "mate", "red wine"])), 124 | ), 125 | ] 126 | ) 127 | ) 128 | 129 | tm.assert_frame_equal(expected_df, 
actual_df, check_exact=True) 130 | 131 | 132 | @pytest.mark.parametrize("data_frame,relation_name,expectation", DataFrameCases().get_cases()) 133 | def test_writearff(data_frame, relation_name, expectation, temp_file): 134 | writearff(data_frame, temp_file, relation_name=relation_name, index=False) 135 | 136 | with open(temp_file.name) as fp: 137 | read_date = fp.readlines() 138 | 139 | assert expectation == read_date 140 | 141 | 142 | def test_writearff_unsupported_column_type(temp_file): 143 | data = pd.DataFrame( 144 | { 145 | "attr_datetime": np.array([2 + 3j, 45.1 - 1j, 0 - 1j, 7 + 0j, 132 - 3j, 1 - 0.41j], dtype="complex128"), 146 | } 147 | ) 148 | 149 | with pytest.raises(TypeError, match="unsupported type complex128"): 150 | writearff(data, temp_file, relation_name="test_delta", index=False) 151 | -------------------------------------------------------------------------------- /tests/test_pandas_inputs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.testing import assert_array_equal 3 | import pandas as pd 4 | import pytest 5 | 6 | from sksurv.datasets import load_whas500 7 | from sksurv.testing import all_survival_estimators 8 | 9 | 10 | @pytest.mark.parametrize("estimator_cls", all_survival_estimators()) 11 | def test_pandas_inputs(estimator_cls): 12 | X, y = load_whas500() 13 | X = X.iloc[:50] 14 | y = y[:50] 15 | X_df = X.loc[:, ["age", "bmi", "chf", "gender"]].astype(float) 16 | X_np = X_df.values 17 | 18 | estimator = estimator_cls() 19 | if "kernel" in estimator.get_params(): 20 | estimator.set_params(kernel="rbf") 21 | estimator.fit(X_df, y) 22 | assert hasattr(estimator, "feature_names_in_") 23 | assert_array_equal(estimator.feature_names_in_, np.asarray(X_df.columns, dtype=object)) 24 | estimator.predict(X_df) 25 | 26 | msg = "Feature names must be in the same order as they were in fit" 27 | X_bad = pd.DataFrame(X_np, columns=X_df.columns.tolist()[::-1]) 28 | with 
pytest.raises(ValueError, match=msg): 29 | estimator.predict(X_bad) 30 | 31 | # warns when fitted on dataframe and transforming a ndarray 32 | msg = f"X does not have valid feature names, but {estimator_cls.__name__} was fitted with feature names" 33 | with pytest.warns(UserWarning, match=msg): 34 | estimator.predict(X_np) 35 | 36 | # warns when fitted on a ndarray and transforming dataframe 37 | msg = f"X has feature names, but {estimator_cls.__name__} was fitted without feature names" 38 | estimator.fit(X_np, y) 39 | with pytest.warns(UserWarning, match=msg): 40 | estimator.predict(X_df) 41 | -------------------------------------------------------------------------------- /tests/test_preprocessing.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | from numpy.testing import assert_array_equal 5 | import pandas as pd 6 | import pandas.testing as tm 7 | import pytest 8 | 9 | from sksurv.preprocessing import OneHotEncoder 10 | 11 | 12 | def _encoded_data(data): 13 | expected = [] 14 | for nam, col in data.items(): 15 | if hasattr(col, "cat"): 16 | for cat in col.cat.categories[1:]: 17 | name = f"{nam}={cat}" 18 | s = pd.Series(col == cat, dtype=np.float64) 19 | expected.append((name, s)) 20 | else: 21 | expected.append((nam, col)) 22 | 23 | expected_data = pd.DataFrame.from_dict(OrderedDict(expected)) 24 | return expected_data 25 | 26 | 27 | @pytest.fixture() 28 | def create_data(): 29 | def _create_data(n_samples=117): 30 | rnd = np.random.RandomState(51365192) 31 | data_num = pd.DataFrame(rnd.rand(n_samples, 5), columns=[f"N{i}" for i in range(5)]) 32 | 33 | dat_cat = pd.DataFrame( 34 | OrderedDict( 35 | [ 36 | ("binary_1", pd.Categorical.from_codes(rnd.binomial(1, 0.6, n_samples), ["Yes", "No"])), 37 | ("binary_2", pd.Categorical.from_codes(rnd.binomial(1, 0.376, n_samples), ["East", "West"])), 38 | ("trinary", pd.Categorical.from_codes(rnd.binomial(2, 0.76, 
n_samples), ["Green", "Blue", "Red"])), 39 | ( 40 | "many", 41 | pd.Categorical.from_codes( 42 | rnd.binomial(5, 0.47, n_samples), ["One", "Two", "Three", "Four", "Five", "Six"] 43 | ), 44 | ), 45 | ] 46 | ) 47 | ) 48 | data = pd.concat((data_num, dat_cat), axis=1) 49 | return data, _encoded_data(data) 50 | 51 | return _create_data 52 | 53 | 54 | class TestOneHotEncoder: 55 | @staticmethod 56 | def test_fit(create_data): 57 | data, expected_data = create_data() 58 | 59 | t = OneHotEncoder().fit(data) 60 | 61 | assert t.feature_names_.tolist() == ["binary_1", "binary_2", "trinary", "many"] 62 | assert set(t.encoded_columns_) == set(expected_data.columns) 63 | 64 | assert t.categories_ == {k: data[k].cat.categories for k in ["binary_1", "binary_2", "trinary", "many"]} 65 | 66 | @staticmethod 67 | def test_fit_transform(create_data): 68 | data, expected_data = create_data() 69 | 70 | actual_data = OneHotEncoder().fit_transform(data) 71 | tm.assert_frame_equal(actual_data, expected_data) 72 | 73 | @staticmethod 74 | def test_transform(create_data): 75 | data, _ = create_data() 76 | 77 | t = OneHotEncoder().fit(data) 78 | data, expected_data = create_data(165) 79 | actual_data = t.transform(data) 80 | tm.assert_frame_equal(actual_data, expected_data) 81 | 82 | data = pd.concat((data.iloc[:, :2], data.iloc[:, 5:], data.iloc[:, 2:5]), axis=1) 83 | actual_data = t.transform(data) 84 | tm.assert_frame_equal(actual_data, expected_data) 85 | 86 | @staticmethod 87 | def test_get_feature_names_out(create_data): 88 | data, expected_data = create_data() 89 | 90 | t = OneHotEncoder() 91 | t.fit(data) 92 | 93 | out_names = t.get_feature_names_out() 94 | assert_array_equal(out_names, expected_data.columns.values) 95 | 96 | @staticmethod 97 | def test_get_feature_names_out_shuffled(create_data): 98 | data, _ = create_data() 99 | order = np.array(["binary_1", "N0", "N3", "trinary", "binary_2", "N1", "N2", "many"]) 100 | expected_columns = np.array( 101 | [ 102 | "binary_1=No", 103 | 
"N0", 104 | "N3", 105 | "trinary=Blue", 106 | "trinary=Red", 107 | "binary_2=West", 108 | "N1", 109 | "N2", 110 | "many=Two", 111 | "many=Three", 112 | "many=Four", 113 | "many=Five", 114 | "many=Six", 115 | ] 116 | ) 117 | 118 | t = OneHotEncoder() 119 | t.fit(data.loc[:, order]) 120 | 121 | out_names = t.get_feature_names_out() 122 | assert_array_equal(out_names, expected_columns) 123 | 124 | with pytest.raises(ValueError, match="input_features is not equal to feature_names_in_"): 125 | t.get_feature_names_out(data.columns.tolist()) 126 | 127 | @staticmethod 128 | def test_transform_other_columns(create_data): 129 | data, _ = create_data() 130 | 131 | t = OneHotEncoder().fit(data) 132 | data, _ = create_data(125) 133 | 134 | data_renamed = data.rename(columns={"binary_1": "renamed_1"}) 135 | with pytest.raises(ValueError, match=r"1 features are missing from data: \['binary_1'\]"): 136 | t.transform(data_renamed) 137 | 138 | data_dropped = data.drop("trinary", axis=1) 139 | error_msg = "X has 8 features, but OneHotEncoder is expecting 9 features as input" 140 | with pytest.raises(ValueError, match=error_msg): 141 | t.transform(data_dropped) 142 | 143 | data_renamed = data.rename(columns={"binary_1": "renamed_1", "many": "too_many"}) 144 | with pytest.raises(ValueError, match=r"2 features are missing from data: \['binary_1', 'many'\]"): 145 | t.transform(data_renamed) 146 | -------------------------------------------------------------------------------- /tests/test_show_versions.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import sksurv 4 | 5 | 6 | def test_show_versions(capsys): 7 | sksurv.show_versions() 8 | captured = capsys.readouterr() 9 | 10 | assert "SYSTEM" in captured.out 11 | assert "DEPENDENCIES" in captured.out 12 | 13 | # check required dependency 14 | assert re.search(r"numpy\s*:\s([0-9\.\+a-f]|dev)+\n", captured.out) 15 | assert re.search(r"pandas\s*:\s([0-9\.\+a-f]|dev)+\n", captured.out) 
16 | assert re.search(r"scikit-learn\s*:\s([0-9\.\+a-f]|dev|post\d)+\n", captured.out) 17 | 18 | # check optional dependency 19 | assert re.search(r"matplotlib\s*:\s([0-9\.]+|None)\n", captured.out) 20 | -------------------------------------------------------------------------------- /tests/test_survival_function.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.testing import assert_array_almost_equal 3 | import pytest 4 | 5 | from sksurv.linear_model import CoxnetSurvivalAnalysis 6 | from sksurv.testing import all_survival_estimators 7 | from sksurv.util import Surv 8 | 9 | 10 | def all_survival_function_estimators(): 11 | estimators = set() 12 | for cls in all_survival_estimators(): 13 | if hasattr(cls, "predict_survival_function"): 14 | if issubclass(cls, CoxnetSurvivalAnalysis): 15 | est = cls(fit_baseline_model=True) 16 | else: 17 | est = cls() 18 | estimators.add(est) 19 | return estimators 20 | 21 | 22 | @pytest.mark.parametrize("estimator", all_survival_function_estimators()) 23 | def test_survival_functions(estimator, make_whas500): 24 | data = make_whas500(to_numeric=True) 25 | 26 | estimator.fit(data.x[150:], data.y[150:]) 27 | fns_cls = estimator.predict_survival_function(data.x[:150]) 28 | fns_arr = estimator.predict_survival_function(data.x[:150], return_array=True) 29 | 30 | times = estimator.unique_times_ 31 | arr = np.vstack([fn(times) for fn in fns_cls]) 32 | 33 | assert_array_almost_equal(arr, fns_arr) 34 | 35 | 36 | @pytest.mark.parametrize("estimator", all_survival_function_estimators()) 37 | @pytest.mark.parametrize("y_time", [-1e-8, -1, np.finfo(float).min]) 38 | def test_fit_negative_survial_time_raises(estimator, y_time): 39 | X = np.random.randn(7, 3) 40 | y = Surv.from_arrays(event=np.ones(7, dtype=bool), time=[1, 9, 3, y_time, 1, 8, 1e10]) 41 | 42 | with pytest.raises(ValueError, match="observed time contains values smaller zero"): 43 | estimator.fit(X, y) 44 | 
-------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | requires = tox>=4 3 | env_list = lint,docs 4 | 5 | [testenv] 6 | deps = 7 | cython 8 | numpy 9 | 10 | [testenv:lint] 11 | description = Run linters 12 | skip_install = true 13 | deps = 14 | ruff~=0.8.4 15 | commands = ruff check sksurv/ tests/ setup.py 16 | pass_env = RUFF_* 17 | 18 | # Documentation 19 | [testenv:docs] 20 | description = Build documentation 21 | deps = 22 | {[testenv]deps} 23 | extras = 24 | docs 25 | change_dir = doc 26 | commands = 27 | sphinx-build -j 1 -d _build{/}doctrees -E -W -b html . _build{/}html 28 | 29 | [testenv:spelling] 30 | description = Spellcheck documentation 31 | deps = 32 | {[testenv:docs]deps} 33 | extras = 34 | {[testenv:docs]extras} 35 | change_dir = doc 36 | commands = 37 | sphinx-build -j auto -d _build{/}doctrees -E -W -b spelling . _build{/}spelling 38 | --------------------------------------------------------------------------------