'
13 |
14 | hr_faded: '<hr class="faded"/>'
15 | hr_shaded: '<hr class="shaded"/>'
--------------------------------------------------------------------------------
/torchlife/models/error_dist.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: 65_AFT_error_distributions.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['get_distribution']
4 |
5 | # Cell
6 | from functools import partial
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | # Cell
13 | def get_distribution(dist:str):
14 | """
15 | Get the log-pdf and the log-survival function (log of 1 - CDF) of a given torch distribution
16 | """
17 | dist = getattr(torch.distributions, dist.title())
18 | if not isinstance(dist.support, torch.distributions.constraints._Real):
19 | raise Exception("Distribution needs support over ALL real values.")
20 |
21 | dist = partial(dist, loc=0.0)
22 |
23 | def dist_logpdf(ξ, σ):
24 | return dist(scale=σ).log_prob(ξ)
25 |
26 | def dist_logicdf(ξ, σ):
27 | """
28 | log of the complementary CDF (survival function): log(1 - F(ξ))
29 | """
30 | return torch.log(1 - dist(scale=σ).cdf(ξ))
31 |
32 | return dist_logpdf, dist_logicdf
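
# Usage sketch (illustrative, not part of the library): any torch distribution
# with support over all reals ("Normal", "Laplace", "Gumbel", ...) can be
# requested by name.
if __name__ == "__main__":
    logpdf, log_sf = get_distribution("Gumbel")
    ξ = torch.linspace(-2.0, 2.0, 5)  # standardised residuals
    σ = torch.tensor(1.0)             # scale parameter
    print(logpdf(ξ, σ))               # log density at each residual
    print(log_sf(ξ, σ))               # log survival, log(1 - CDF)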
--------------------------------------------------------------------------------
/docs/hazard.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Hazard Models
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 | summary: "Hazard based models."
10 | description: "Hazard based models."
11 | nb_path: "50_hazard.ipynb"
12 | ---
13 |
<!-- notebook cell markup stripped in extraction -->
34 | We can choose to model the instantaneous hazard directly. There are two methods: the piecewise hazard model and the Cox proportional hazard model.
--------------------------------------------------------------------------------
/docs/licenses/LICENSE:
--------------------------------------------------------------------------------
1 | /* This license pertains to the docs template, except for the Navgoco jQuery component. */
2 |
3 | The MIT License (MIT)
4 |
5 | Original theme: Copyright (c) 2016 Tom Johnson
6 | Modifications: Copyright (c) 2017 onwards fast.ai, Inc
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
25 |
--------------------------------------------------------------------------------
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 | on: [push]
3 | jobs:
4 | build:
5 | runs-on: ubuntu-latest
6 | steps:
7 | - uses: actions/checkout@v1
8 | - uses: actions/setup-python@v1
9 | with:
10 | python-version: '3.6'
11 | architecture: 'x64'
12 | - name: Install the library
13 | run: |
14 | pip install nbdev jupyter
15 | pip install -e .
16 | - name: Read all notebooks
17 | run: |
18 | nbdev_read_nbs
19 | - name: Check if all notebooks are cleaned
20 | run: |
21 | echo "Check we are starting with clean git checkout"
22 | if [ -n "$(git status -uno -s)" ]; then echo "git status is not clean"; false; fi
23 | echo "Trying to strip out notebooks"
24 | nbdev_clean_nbs
25 | echo "Check that strip out was unnecessary"
26 | git status -s # display the status to see which nbs need cleaning up
27 | if [ -n "$(git status -uno -s)" ]; then echo -e "!!! Detected unstripped notebooks\n!!! Remember to run nbdev_install_git_hooks"; false; fi
28 | - name: Check that there is no diff between the library and the notebooks
29 | run: |
30 | if [ -n "$(nbdev_diff_nbs)" ]; then echo -e "!!! Detected difference between the notebooks and the library"; false; fi
31 | - name: Run tests
32 | run: |
33 | nbdev_test_nbs
34 |
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | repository: sachinruk/torchlife
2 | output: web
3 | topnav_title: torchlife
4 | site_title: torchlife
5 | company_name: Sachinthaka Abeywardana
6 | description: Survival Analysis with fastai and pytorch
7 | # Set to false to disable KaTeX math
8 | use_math: true
9 | # Add Google analytics id if you have one and want to use it here
10 | google_analytics:
11 | # See http://nbdev.fast.ai/search for help with adding Search
12 | google_search:
13 |
14 | host: 127.0.0.1
15 | # the preview server used. Leave as is.
16 | port: 4000
17 | # the port where the preview is rendered.
18 |
19 | exclude:
20 | - .idea/
21 | - .gitignore
22 | - vendor
23 |
26 | highlighter: rouge
27 | markdown: kramdown
28 | kramdown:
29 | input: GFM
30 | auto_ids: true
31 | hard_wrap: false
32 | syntax_highlighter: rouge
33 |
34 | collections:
35 | tooltips:
36 | output: false
37 |
38 | defaults:
39 | -
40 | scope:
41 | path: ""
42 | type: "pages"
43 | values:
44 | layout: "page"
45 | comments: true
46 | search: true
47 | sidebar: home_sidebar
48 | topnav: topnav
49 | -
50 | scope:
51 | path: ""
52 | type: "tooltips"
53 | values:
54 | layout: "page"
55 | comments: true
56 | search: true
57 | tooltip: true
58 |
59 | sidebars:
60 | - home_sidebar
61 |
62 | theme: jekyll-theme-cayman
63 | baseurl: /torchlife/
--------------------------------------------------------------------------------
/docs/feed.xml:
--------------------------------------------------------------------------------
1 | ---
2 | search: exclude
3 | layout: none
4 | ---
5 |
6 | <?xml version="1.0" encoding="UTF-8"?>
7 | <rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom">
8 | <channel>
9 | <title>{{ site.title | xml_escape }}</title>
10 | <description>{{ site.description | xml_escape }}</description>
11 | <link>{{ site.url }}/</link>
12 | <atom:link href="{{ '/feed.xml' | prepend: site.url }}" rel="self" type="application/rss+xml"/>
13 | <pubDate>{{ site.time | date_to_rfc822 }}</pubDate>
14 | <lastBuildDate>{{ site.time | date_to_rfc822 }}</lastBuildDate>
15 | <generator>Jekyll v{{ jekyll.version }}</generator>
16 | {% for post in site.posts limit:10 %}
17 | <item>
18 | <title>{{ post.title | xml_escape }}</title>
19 | <description>{{ post.content | xml_escape }}</description>
20 | <pubDate>{{ post.date | date_to_rfc822 }}</pubDate>
21 | <link>{{ post.url | prepend: site.url }}</link>
22 | <guid isPermaLink="true">{{ post.url | prepend: site.url }}</guid>
23 | {% for tag in post.tags %}
24 | <category>{{ tag | xml_escape }}</category>
25 | {% endfor %}
26 | {% for cat in page.categories %}
27 | <category>{{ cat | xml_escape }}</category>
28 | {% endfor %}
29 | </item>
30 | {% endfor %}
31 | </channel>
32 | </rss>
33 |
--------------------------------------------------------------------------------
/torchlife/models/km.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: 20_KaplanMeier.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['KaplanMeier']
4 |
5 | # Cell
6 | import pandas as pd
7 | import matplotlib.pyplot as plt
8 | from scipy.interpolate import interp1d
9 |
10 | # Cell
11 | class KaplanMeier:
12 | def fit(self, df):
13 | """
14 | Estimages the Kaplan-Meier survival function
15 | parameters:
16 | - t: time steps
17 | - e: whether death occured at time step (1) or not (0)
18 | """
19 | d = df.groupby("t")["e"].sum()
20 | n = df.groupby("t")["e"].count()
21 | n = n[::-1].cumsum()[::-1]
22 | self.survival_function = (1 - d / n).cumprod()
23 |
24 | if 0 not in self.survival_function:
25 | self.survival_function[0] = 1
26 | self.survival_function.sort_index(inplace=True)
27 |
28 | t = self.survival_function.index
29 | p = self.survival_function.values
30 | self.interpolate = interp1d(t, p, bounds_error=False, fill_value="extrapolate")
31 |
32 | def predict(self, t):
33 | if any(t < 0):
34 | raise Exception("Time cannot be less than 0.")
35 | p = self.interpolate(t)
36 | p[p<0] = 0
37 | return p
38 |
39 | def plot_survival_function(self):
40 | fig, ax = plt.subplots()
41 | ax.plot(self.survival_function)
42 | ax.set_xlabel("Duration")
43 | ax.set_ylabel("Survival Probability")
44 |
45 | ax.set_title("Survival Function")
46 | return ax
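
# Usage sketch (illustrative, not part of the library): fit on a tiny made-up
# dataframe with the "t" (time) and "e" (event) columns described above, then
# query survival probabilities at arbitrary times.
if __name__ == "__main__":
    import numpy as np
    df = pd.DataFrame({"t": [1, 2, 2, 3, 5], "e": [1, 0, 1, 1, 0]})
    km = KaplanMeier()
    km.fit(df)
    print(km.predict(np.array([0.5, 2.5, 4.0])))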
--------------------------------------------------------------------------------
/docs/_includes/links.html:
--------------------------------------------------------------------------------
1 | {% comment %}Get links from each sidebar, as listed in the _config.yml file under sidebars{% endcomment %}
2 |
3 | {% for sidebar in site.sidebars %}
4 | {% for entry in site.data.sidebars[sidebar].entries %}
5 | {% for folder in entry.folders %}
6 | {% for folderitem in folder.folderitems %}
7 | {% if folderitem.url contains "html#" %}
8 | [{{folderitem.url | remove: "/" }}]: {{folderitem.url | remove: "/"}}
9 | {% else %}
10 | [{{folderitem.url | remove: "/" | remove: ".html"}}]: {{folderitem.url | remove: "/"}}
11 | {% endif %}
12 | {% for subfolders in folderitem.subfolders %}
13 | {% for subfolderitem in subfolders.subfolderitems %}
14 | [{{subfolderitem.url | remove: "/" | remove: ".html"}}]: {{subfolderitem.url | remove: "/"}}
15 | {% endfor %}
16 | {% endfor %}
17 | {% endfor %}
18 | {% endfor %}
19 | {% endfor %}
20 | {% endfor %}
21 |
22 |
23 | {% comment %} Get links from topnav {% endcomment %}
24 |
25 | {% for entry in site.data.topnav.topnav %}
26 | {% for item in entry.items %}
27 | {% if item.external_url == null %}
28 | [{{item.url | remove: "/" | remove: ".html"}}]: {{item.url | remove: "/"}}
29 | {% endif %}
30 | {% endfor %}
31 | {% endfor %}
32 |
33 | {% comment %}Get links from topnav dropdowns {% endcomment %}
34 |
35 | {% for entry in site.data.topnav.topnav_dropdowns %}
36 | {% for folder in entry.folders %}
37 | {% for folderitem in folder.folderitems %}
38 | {% if folderitem.external_url == null %}
39 | [{{folderitem.url | remove: "/" | remove: ".html"}}]: {{folderitem.url | remove: "/"}}
40 | {% endif %}
41 | {% endfor %}
42 | {% endfor %}
43 | {% endfor %}
44 |
45 |
--------------------------------------------------------------------------------
/docs/_data/sidebars/home_sidebar.yml:
--------------------------------------------------------------------------------
1 |
2 | #################################################
3 | ### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ###
4 | #################################################
5 | # Instead edit ../../sidebar.json
6 | entries:
7 | - folders:
8 | - folderitems:
9 | - output: web,pdf
10 | title: Overview
11 | url: /
12 | - output: web,pdf
13 | title: Survival Analysis Theory
14 | url: /SAT
15 | - output: web,pdf
16 | title: Kaplan Meier Model
17 | url: /KaplanMeier
18 | subfolders:
19 | - output: web
20 | subfolderitems:
21 | - output: web,pdf
22 | title: Overview
23 | url: /hazard
24 | - output: web,pdf
25 | title: Piecewise Hazard
26 | url: /hazard.PiecewiseHazard
27 | - output: web,pdf
28 | title: Cox Proportional Hazard
29 | url: /hazard.Cox
30 | title: Hazard Models
31 | - output: web
32 | subfolderitems:
33 | - output: web,pdf
34 | title: Error Distributions
35 | url: /AFT_error_distributions
36 | - output: web,pdf
37 | title: Accelerated Failure Time Models
38 | url: /AFT_models
39 | title: AFT
40 | - output: web
41 | subfolderitems:
42 | - output: web,pdf
43 | title: Data
44 | url: /data
45 | - output: web,pdf
46 | title: Base Model
47 | url: /model
48 | - output: web,pdf
49 | title: Losses
50 | url: /Losses
51 | title: Model
52 | output: web
53 | title: torchlife
54 | output: web
55 | title: Sidebar
56 |
--------------------------------------------------------------------------------
/docs/css/modern-business.css:
--------------------------------------------------------------------------------
1 | /*!
2 | * Start Bootstrap - Modern Business HTML Template (http://startbootstrap.com)
3 | * Code licensed under the Apache License v2.0.
4 | * For details, see http://www.apache.org/licenses/LICENSE-2.0.
5 | */
6 |
7 | /* Global Styles */
8 |
9 | html,
10 | body {
11 | height: 100%;
12 | }
13 |
14 | .img-portfolio {
15 | margin-bottom: 30px;
16 | }
17 |
18 | .img-hover:hover {
19 | opacity: 0.8;
20 | }
21 |
22 | /* Home Page Carousel */
23 |
24 | header.carousel {
25 | height: 50%;
26 | }
27 |
28 | header.carousel .item,
29 | header.carousel .item.active,
30 | header.carousel .carousel-inner {
31 | height: 100%;
32 | }
33 |
34 | header.carousel .fill {
35 | width: 100%;
36 | height: 100%;
37 | background-position: center;
38 | background-size: cover;
39 | }
40 |
41 | /* 404 Page Styles */
42 |
43 | .error-404 {
44 | font-size: 100px;
45 | }
46 |
47 | /* Pricing Page Styles */
48 |
49 | .price {
50 | display: block;
51 | font-size: 50px;
52 | line-height: 50px;
53 | }
54 |
55 | .price sup {
56 | top: -20px;
57 | left: 2px;
58 | font-size: 20px;
59 | }
60 |
61 | .period {
62 | display: block;
63 | font-style: italic;
64 | }
65 |
66 | /* Footer Styles */
67 |
68 | footer {
69 | margin: 50px 0;
70 | }
71 |
72 | /* Responsive Styles */
73 |
74 | @media(max-width:991px) {
75 | .client-img,
76 | .img-related {
77 | margin-bottom: 30px;
78 | }
79 | }
80 |
81 | @media(max-width:767px) {
82 | .img-portfolio {
83 | margin-bottom: 15px;
84 | }
85 |
86 | header.carousel .carousel {
87 | height: 70%;
88 | }
89 | }
90 |
--------------------------------------------------------------------------------
/torchlife/_nbdev.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED BY NBDEV! DO NOT EDIT!
2 |
3 | __all__ = ["index", "modules", "custom_doc_links", "git_url"]
4 |
5 | index = {"KaplanMeier": "20_KaplanMeier.ipynb",
6 | "torch.Tensor.ndim": "55_hazard.PiecewiseHazard.ipynb",
7 | "PieceWiseHazard": "55_hazard.PiecewiseHazard.ipynb",
8 | "ProportionalHazard": "59_hazard.Cox.ipynb",
9 | "AFTModel": "60_AFT_models.ipynb",
10 | "get_distribution": "65_AFT_error_distributions.ipynb",
11 | "TestData": "80_data.ipynb",
12 | "Data": "80_data.ipynb",
13 | "TestDataFrame": "80_data.ipynb",
14 | "DataFrame": "80_data.ipynb",
15 | "create_dl": "80_data.ipynb",
16 | "create_test_dl": "80_data.ipynb",
17 | "get_breakpoints": "80_data.ipynb",
18 | "GeneralModel": "90_model.ipynb",
19 | "train_model": "90_model.ipynb",
20 | "ModelHazard": "90_model.ipynb",
21 | "ModelAFT": "90_model.ipynb",
22 | "Loss": "95_Losses.ipynb",
23 | "LossType": "95_Losses.ipynb",
24 | "AFTLoss": "95_Losses.ipynb",
25 | "aft_loss": "95_Losses.ipynb",
26 | "HazardLoss": "95_Losses.ipynb",
27 | "hazard_loss": "95_Losses.ipynb"}
28 |
29 | modules = ["models/km.py",
30 | "models/ph.py",
31 | "models/cox.py",
32 | "models/aft.py",
33 | "models/error_dist.py",
34 | "data.py",
35 | "model.py",
36 | "losses.py"]
37 |
38 | doc_url = "https://sachinruk.github.io/torchlife/"
39 |
40 | git_url = "https://github.com/sachinruk/torchlife/tree/master/"
41 |
42 | def custom_doc_links(name): return None
43 |
--------------------------------------------------------------------------------
/docs/_includes/head_print.html:
--------------------------------------------------------------------------------
6 | <title>{% if page.homepage == true %} {{site.homepage_title}} {% elsif page.title %}{{ page.title }}{% endif %} | {{ site.site_title }}</title>
<!-- remaining <head> markup (meta/link/script tags) stripped in extraction -->
--------------------------------------------------------------------------------
/docs/licenses/LICENSE-BSD-NAVGOCO.txt:
--------------------------------------------------------------------------------
1 | /* This license pertains to the Navgoco jQuery component used for the sidebar. */
2 |
3 | Copyright (c) 2013, Christodoulos Tsoulloftas, http://www.komposta.net
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without modification,
7 | are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice,
10 | this list of conditions and the following disclaimer.
11 | * Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 | * Neither the name of the
nor the names of its
15 | contributors may be used to endorse or promote products derived from this
16 | software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
21 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
22 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
23 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
25 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
26 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
27 | OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/docs/_layouts/page.html:
--------------------------------------------------------------------------------
1 | ---
2 | layout: default
3 | ---
4 |
5 |
9 |
10 | {% if page.simple_map == true %}
11 |
12 |
17 |
18 | {% include custom/{{page.map_name}}.html %}
19 |
20 | {% elsif page.complex_map == true %}
21 |
22 |
27 |
28 | {% include custom/{{page.map_name}}.html %}
29 |
30 | {% endif %}
31 |
32 |
33 |
34 | {% if page.summary %}
35 | <div class="summary">{{page.summary}}</div>
36 | {% endif %}
37 |
38 | {% unless page.toc == false %}
39 | {% include toc.html %}
40 | {% endunless %}
41 |
42 |
43 | {% if site.github_editme_path %}
44 |
45 |
Edit me
46 |
47 | {% endif %}
48 |
49 | {{content}}
50 |
51 |
62 |
63 |
64 |
65 | {{site.data.alerts.hr_shaded}}
66 |
67 | {% include footer.html %}
68 |
--------------------------------------------------------------------------------
/torchlife/losses.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: 95_Losses.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['Loss', 'LossType', 'AFTLoss', 'aft_loss', 'HazardLoss', 'hazard_loss']
4 |
5 | # Cell
6 | from abc import ABC, abstractmethod
7 | from typing import Callable, Tuple
8 | import torch
9 |
10 | # Cell
11 | class Loss(ABC):
12 | @abstractmethod
13 | def __call__(self, event:torch.Tensor, *args):
14 | pass
15 |
16 | # Cell
17 | LossType = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
18 |
19 | # Cell
20 | class AFTLoss(Loss):
21 | @staticmethod
22 | def __call__(event:torch.Tensor, log_pdf: torch.Tensor, log_icdf: torch.Tensor) -> torch.Tensor:
23 | lik = event * log_pdf + (1 - event) * log_icdf
24 | return -lik.mean()
25 |
26 | # Cell
27 | def _aft_loss(
28 | log_pdf: torch.Tensor, log_icdf: torch.Tensor, e: torch.Tensor
29 | ) -> torch.Tensor:
30 | lik = e * log_pdf + (1 - e) * log_icdf
31 | return -lik.mean()
32 |
33 |
34 | def aft_loss(log_prob, e):
35 | log_pdf, log_icdf = log_prob
36 | return _aft_loss(log_pdf, log_icdf, e)
37 |
38 | # Cell
39 | class HazardLoss(Loss):
40 | @staticmethod
41 | def __call__(event: torch.Tensor, logλ: torch.Tensor, Λ: torch.Tensor) -> torch.Tensor:
42 | log_lik = event * logλ - Λ
43 | return -log_lik.mean()
44 |
45 | # Cell
46 | def _hazard_loss(logλ: torch.Tensor, Λ: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
47 | log_lik = e * logλ - Λ
48 | return -log_lik.mean()
49 |
50 |
51 | def hazard_loss(
52 | hazard: Tuple[torch.Tensor, torch.Tensor], e: torch.Tensor
53 | ) -> torch.Tensor:
54 | """
55 | parameters:
56 | - hazard: tuple of (log hazard, cumulative hazard)
57 | - e: torch.Tensor of 1 if the death event occurred and 0 otherwise
58 | """
59 | logλ, Λ = hazard
60 | return _hazard_loss(logλ, Λ, e)
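
# Usage sketch (illustrative, not part of the library): a batch of three
# observations where the first two deaths are observed and the third is
# censored; the loss is the mean negative log-likelihood.
if __name__ == "__main__":
    e = torch.tensor([1.0, 1.0, 0.0])         # event indicator
    logλ = torch.tensor([-0.1, -0.5, -0.3])   # log hazard at each event time
    Λ = torch.tensor([0.5, 1.2, 0.8])         # cumulative hazard up to each time
    print(hazard_loss((logλ, Λ), e))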
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from packaging.version import parse
2 | from configparser import ConfigParser
3 | import setuptools
4 | assert parse(setuptools.__version__)>=parse('36.2')
5 |
6 | # note: all settings are in settings.ini; edit there, not here
7 | config = ConfigParser(delimiters=['='])
8 | config.read('settings.ini')
9 | cfg = config['DEFAULT']
10 |
11 | cfg_keys = 'version description keywords author author_email'.split()
12 | expected = cfg_keys + "lib_name user branch license status min_python audience language".split()
13 | for o in expected: assert o in cfg, "missing expected setting: {}".format(o)
14 | setup_cfg = {o:cfg[o] for o in cfg_keys}
15 |
16 | licenses = {
17 | 'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'),
18 | }
19 | statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha',
20 | '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ]
21 | py_versions = '2.0 2.1 2.2 2.3 2.4 2.5 2.6 2.7 3.0 3.1 3.2 3.3 3.4 3.5 3.6 3.7 3.8'.split()
22 |
23 | requirements = cfg.get('requirements','').split()
24 | lic = licenses[cfg['license']]
25 | min_python = cfg['min_python']
26 |
27 | setuptools.setup(
28 | name = cfg['lib_name'],
29 | license = lic[0],
30 | classifiers = [
31 | 'Development Status :: ' + statuses[int(cfg['status'])],
32 | 'Intended Audience :: ' + cfg['audience'].title(),
33 | 'License :: ' + lic[1],
34 | 'Natural Language :: ' + cfg['language'].title(),
35 | ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]],
36 | url = 'https://github.com/{}/{}'.format(cfg['user'],cfg['lib_name']),
37 | packages = setuptools.find_packages(),
38 | include_package_data = True,
39 | install_requires = requirements,
40 | python_requires = '>=' + cfg['min_python'],
41 | long_description = open('README.md').read(),
42 | long_description_content_type = 'text/markdown',
43 | zip_safe = False,
44 | entry_points = { 'console_scripts': cfg.get('console_scripts','').split() },
45 | **setup_cfg)
46 |
47 |
--------------------------------------------------------------------------------
/docs/js/customscripts.js:
--------------------------------------------------------------------------------
1 | $('#mysidebar').height($(".nav").height());
2 |
3 |
4 | $( document ).ready(function() {
5 |
6 | //this script says, if the height of the viewport is greater than 800px, then insert affix class, which makes the nav bar float in a fixed
7 | // position as your scroll. if you have a lot of nav items, this height may not work for you.
8 | var h = $(window).height();
9 | //console.log (h);
10 | if (h > 800) {
11 | $( "#mysidebar" ).attr("class", "nav affix");
12 | }
13 | // activate tooltips. although this is a bootstrap js function, it must be activated this way in your theme.
14 | $('[data-toggle="tooltip"]').tooltip({
15 | placement : 'top'
16 | });
17 |
18 | /**
19 | * AnchorJS
20 | */
21 | anchors.add('h2,h3,h4,h5');
22 |
23 | });
24 |
25 | // needed for nav tabs on pages. See Formatting > Nav tabs for more details.
26 | // script from http://stackoverflow.com/questions/10523433/how-do-i-keep-the-current-tab-active-with-twitter-bootstrap-after-a-page-reload
27 | $(function() {
28 | var json, tabsState;
29 | $('a[data-toggle="pill"], a[data-toggle="tab"]').on('shown.bs.tab', function(e) {
30 | var href, json, parentId, tabsState;
31 |
32 | tabsState = localStorage.getItem("tabs-state");
33 | json = JSON.parse(tabsState || "{}");
34 | parentId = $(e.target).parents("ul.nav.nav-pills, ul.nav.nav-tabs").attr("id");
35 | href = $(e.target).attr('href');
36 | json[parentId] = href;
37 |
38 | return localStorage.setItem("tabs-state", JSON.stringify(json));
39 | });
40 |
41 | tabsState = localStorage.getItem("tabs-state");
42 | json = JSON.parse(tabsState || "{}");
43 |
44 | $.each(json, function(containerId, href) {
45 | return $("#" + containerId + " a[href=" + href + "]").tab('show');
46 | });
47 |
48 | $("ul.nav.nav-pills, ul.nav.nav-tabs").each(function() {
49 | var $this = $(this);
50 | if (!json[$this.attr("id")]) {
51 | return $this.find("a[data-toggle=tab]:first, a[data-toggle=pill]:first").tab("show");
52 | }
53 | });
54 | });
55 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.bak
2 | .gitattributes
3 | .last_checked
4 | .gitconfig
5 | *.bak
6 | *.log
7 | *~
8 | ~*
9 | _tmp*
10 | tmp*
11 | tags
12 |
13 | # Byte-compiled / optimized / DLL files
14 | __pycache__/
15 | *.py[cod]
16 | *$py.class
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | env/
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | .hypothesis/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # celery beat schedule file
89 | celerybeat-schedule
90 |
91 | # SageMath parsed files
92 | *.sage.py
93 |
94 | # dotenv
95 | .env
96 |
97 | # virtualenv
98 | .venv
99 | venv/
100 | ENV/
101 |
102 | # Spyder project settings
103 | .spyderproject
104 | .spyproject
105 |
106 | # Rope project settings
107 | .ropeproject
108 |
109 | # mkdocs documentation
110 | /site
111 |
112 | # mypy
113 | .mypy_cache/
114 |
115 | .vscode
116 | *.swp
117 |
118 | # osx generated files
119 | .DS_Store
120 | .DS_Store?
121 | .Trashes
122 | ehthumbs.db
123 | Thumbs.db
124 | .idea
125 |
126 | # pytest
127 | .pytest_cache
128 |
129 | # tools/trust-doc-nbs
130 | docs_src/.last_checked
131 |
132 | # symlinks to fastai
133 | docs_src/fastai
134 | tools/fastai
135 |
136 | # link checker
137 | checklink/cookies.txt
138 |
139 | # .gitconfig is now autogenerated
140 | .gitconfig
141 |
142 | models/
143 | lightning_logs/
144 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 |
3 | ## How to get started
4 |
5 | Before anything else, please install the git hooks that run automatic scripts during each commit and merge to strip the notebooks of superfluous metadata (and avoid merge conflicts). After cloning the repository, run the following command inside it:
6 | ```
7 | nbdev_install_git_hooks
8 | ```
9 |
10 | ## Did you find a bug?
11 |
12 | * Ensure the bug was not already reported by searching on GitHub under Issues.
13 | * If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring.
14 | * Be sure to add the complete error messages.
15 |
16 | #### Did you write a patch that fixes a bug?
17 |
18 | * Open a new GitHub pull request with the patch.
19 | * Ensure that your PR includes a test that fails without your patch, and passes with it.
20 | * Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable.
21 |
22 | ## PR submission guidelines
23 |
24 | * Keep each PR focused. While it's more convenient, do not combine several unrelated fixes together. Create as many branches as needed to keep each PR focused.
25 | * Do not mix style changes/fixes with "functional" changes. It's very difficult to review such PRs and they will most likely get rejected.
26 | * Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can.
27 | * Do not turn an already submitted PR into your development playground. If, after you submitted a PR, you discover that more work is needed, close the PR, do the required work and then submit a new PR. Otherwise each of your commits requires attention from maintainers of the project.
28 | * If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exceptional case where you realize it will take many commits to complete the requests, it's probably best to close the PR, do the work, and then submit it again. Use common sense to choose one way over the other.
29 |
30 | ## Do you want to contribute to the documentation?
31 |
32 | * Docs are automatically created from the notebooks in the nbs folder.
33 |
34 |
--------------------------------------------------------------------------------
/torchlife/models/aft.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: 60_AFT_models.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['AFTModel']
4 |
5 | # Cell
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from sklearn.preprocessing import MaxAbsScaler, StandardScaler
12 |
13 | from .error_dist import get_distribution
14 |
15 | # Cell
16 | class AFTModel(nn.Module):
17 | """
18 | Accelerated Failure Time model
19 | parameters:
20 | - distribution: name of the torch distribution assumed for the error term
21 | - input_dim: input dimensionality of the covariates (0 if none)
22 | - h (optional): tuple with the number of hidden nodes per layer
23 | """
24 | def __init__(self, distribution:str, input_dim:int, h:tuple=()):
25 | super().__init__()
26 | self.logpdf, self.logicdf = get_distribution(distribution)
27 | self.β = nn.Parameter(-torch.rand(1))
28 | self.logσ = nn.Parameter(-torch.rand(1))
29 |
30 | if input_dim > 0:
31 | nodes = (input_dim,) + h + (1,)
32 | self.layers = nn.ModuleList([nn.Linear(a,b, bias=False)
33 | for a,b in zip(nodes[:-1], nodes[1:])])
34 |
35 | self.eps = 1e-7
36 |
37 | def get_mode_time(self, x:torch.Tensor=None):
38 | μ = self.β
39 | if x is not None:
40 | for layer in self.layers[:-1]:
41 | x = F.relu(layer(x))
42 | μ = self.β + self.layers[-1](x)
43 |
44 | σ = torch.exp(self.logσ)
45 | return μ, σ
46 |
47 | def forward(self, t:torch.Tensor, x:torch.Tensor=None):
48 | μ, σ = self.get_mode_time(x)
49 | ξ = torch.log(t + self.eps) - μ
50 | logpdf = self.logpdf(ξ, σ)
51 | logicdf = self.logicdf(ξ, σ)
52 | return logpdf, logicdf
53 |
54 | def survival_function(self, t:np.array, t_scaler, x:np.array=None, x_scaler=None):
55 | if len(t.shape) == 1:
56 | t = t[:,None]
57 | t = t_scaler.transform(t)
58 | t = torch.Tensor(t)
59 | if x is not None:
60 | if len(x.shape) == 1:
61 | x = x[None, :]
62 | if len(x) == 1:
63 | x = np.repeat(x, len(t), axis=0)
64 | x = x_scaler.transform(x)
65 | x = torch.Tensor(x)
66 |
67 | with torch.no_grad():
68 | # forward returns (log pdf, log survival); exponentiate the log survival
69 | _, log_sf = self(t, x)
70 | return torch.exp(log_sf)
71 |
72 | def plot_survival_function(self, t:np.array, t_scaler, x:np.array=None, x_scaler=None):
73 | surv_fun = self.survival_function(t, t_scaler, x, x_scaler)
74 |
75 | # plot
76 | plt.figure(figsize=(12,5))
77 | plt.plot(t, surv_fun)
78 | plt.xlabel('Time')
79 | plt.ylabel('Survival Probability')
80 | plt.show()
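
# Usage sketch (illustrative, not part of the library): an intercept-only AFT
# model with Gumbel errors on made-up times; forward returns the pieces that
# aft_loss in torchlife.losses consumes.
if __name__ == "__main__":
    model = AFTModel("Gumbel", input_dim=0)
    t = torch.rand(8) * 10 + 0.1   # made-up event/censoring times
    log_pdf, log_sf = model(t)
    e = torch.ones(8)              # pretend every death was observed
    print(-(e * log_pdf + (1 - e) * log_sf).mean())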
--------------------------------------------------------------------------------
/docs/css/theme-green.css:
--------------------------------------------------------------------------------
1 | .summary {
2 | color: #808080;
3 | border-left: 5px solid #E50E51;
4 | font-size:16px;
5 | }
6 |
7 |
8 | h3 {color: #E50E51; }
9 | h4 {color: #808080; }
10 |
11 | .nav-tabs > li.active > a, .nav-tabs > li.active > a:hover, .nav-tabs > li.active > a:focus {
12 | background-color: #248ec2;
13 | color: white;
14 | }
15 |
16 | .nav > li.active > a {
17 | background-color: #72ac4a;
18 | }
19 |
20 | .nav > li > a:hover {
21 | background-color: #72ac4a;
22 | }
23 |
24 | div.navbar-collapse .dropdown-menu > li > a:hover {
25 | background-color: #72ac4a;
26 | }
27 |
28 | .navbar-inverse .navbar-nav>li>a, .navbar-inverse .navbar-brand {
29 | color: white;
30 | }
31 |
32 | .navbar-inverse .navbar-nav>li>a:hover, a.fa.fa-home.fa-lg.navbar-brand:hover {
33 | color: #f0f0f0;
34 | }
35 |
36 | .nav li.thirdlevel > a {
37 | background-color: #FAFAFA !important;
38 | color: #72ac4a;
39 | font-weight: bold;
40 | }
41 |
42 | a[data-toggle="tooltip"] {
43 | color: #649345;
44 | font-style: italic;
45 | cursor: default;
46 | }
47 |
48 | .navbar-inverse {
49 | background-color: #72ac4a;
50 | border-color: #5b893c;
51 | }
52 |
53 | .navbar-inverse .navbar-nav > .open > a, .navbar-inverse .navbar-nav > .open > a:hover, .navbar-inverse .navbar-nav > .open > a:focus {
54 | color: #5b893c;
55 | }
56 |
57 | .navbar-inverse .navbar-nav > .open > a, .navbar-inverse .navbar-nav > .open > a:hover, .navbar-inverse .navbar-nav > .open > a:focus {
58 | background-color: #5b893c;
59 | color: #ffffff;
60 | }
61 |
62 | /* not sure if using this ...*/
63 | .navbar-inverse .navbar-collapse, .navbar-inverse .navbar-form {
64 | border-color: #72ac4a !important;
65 | }
66 |
67 | .btn-primary {
68 | color: #ffffff;
69 | background-color: #5b893c;
70 | border-color: #5b893c;
71 | }
72 |
73 | .btn-primary:hover,
74 | .btn-primary:focus,
75 | .btn-primary:active,
76 | .btn-primary.active,
77 | .open .dropdown-toggle.btn-primary {
78 | background-color: #72ac4a;
79 | border-color: #5b893c;
80 | }
81 |
82 | .printTitle {
83 | color: #5b893c !important;
84 | }
85 |
86 | body.print h1 {color: #5b893c !important; font-size:28px;}
87 | body.print h2 {color: #595959 !important; font-size:24px;}
88 | body.print h3 {color: #E50E51 !important; font-size:14px;}
89 | body.print h4 {color: #679DCE !important; font-size:14px; font-style: italic;}
90 |
91 | .anchorjs-link:hover {
92 | color: #4f7233;
93 | }
94 |
95 | div.sidebarTitle {
96 | color: #E50E51;
97 | }
98 |
99 | li.sidebarTitle {
100 | margin-top:20px;
101 | font-weight:normal;
102 | font-size:130%;
103 | color: #ED1951;
104 | margin-bottom:10px;
105 | margin-left: 5px;
106 | }
107 |
108 | .navbar-inverse .navbar-toggle:focus, .navbar-inverse .navbar-toggle:hover {
109 | background-color: #E50E51;
110 | }
111 |
--------------------------------------------------------------------------------
/torchlife/models/cox.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: 59_hazard.Cox.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['ProportionalHazard']
4 |
5 | # Cell
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from sklearn.preprocessing import MaxAbsScaler, StandardScaler
12 |
13 | from ..losses import hazard_loss
14 | from .ph import PieceWiseHazard
15 |
16 | # torch.Tensor.ndim = property(lambda x: x.dim())
17 |
18 | # Cell
19 | class ProportionalHazard(nn.Module):
20 | """
21 | Hazard proportional to time and feature components, as shown above.
22 | parameters:
23 | - breakpoints: time points where the baseline hazard changes
24 | - t_scaler, x_scaler: fitted scalers for time and features
25 | - dim: number of input dimensions of x
26 | - h: (optional) number of hidden units (for x only).
27 | """
28 | def __init__(self, breakpoints:np.array, t_scaler:MaxAbsScaler, x_scaler:StandardScaler,
29 | dim:int, h:tuple=(), **kwargs):
30 | super().__init__()
31 | self.baseλ = PieceWiseHazard(breakpoints, t_scaler)
32 | self.x_scaler = x_scaler
33 | nodes = (dim,) + h + (1,)
34 | self.layers = nn.ModuleList([nn.Linear(a,b, bias=False)
35 | for a,b in zip(nodes[:-1], nodes[1:])])
36 |
37 | def forward(self, t, t_section, x):
38 | logλ, Λ = self.baseλ(t, t_section)
39 |
40 | for layer in self.layers[:-1]:
41 | x = F.relu(layer(x))
42 | log_hx = self.layers[-1](x)
43 |
44 | logλ += log_hx
45 | Λ = torch.exp(log_hx + torch.log(Λ))
46 | return logλ, Λ
47 |
48 | def survival_function(self, t:np.array, x:np.array) -> torch.Tensor:
49 | if len(t.shape) == 1:
50 | t = t[:,None]
51 | t = self.baseλ.t_scaler.transform(t)
52 | if len(x.shape) == 1:
53 | x = x[None, :]
54 | if len(x) == 1:
55 | x = np.repeat(x, len(t), axis=0)
56 | x = self.x_scaler.transform(x)
57 |
58 |
59 | with torch.no_grad():
60 | x = torch.Tensor(x)
61 | # get the times and time sections for survival function
62 | breakpoints = self.baseλ.breakpoints[1:].cpu().numpy()
63 | t_sec_query = np.searchsorted(breakpoints.squeeze(), t.squeeze())
64 | # convert to pytorch tensors
65 | t_query = torch.Tensor(t)
66 | t_sec_query = torch.LongTensor(t_sec_query)
67 |
68 | # calculate cumulative hazard according to above
69 | _, Λ = self.forward(t_query, t_sec_query, x)
70 | return torch.exp(-Λ)
71 |
72 |
73 | def plot_survival_function(self, t:np.array, x:np.array) -> None:
74 | s = self.survival_function(t, x)
75 |
76 | # plot
77 | plt.figure(figsize=(12,5))
78 | plt.plot(t, s)
79 | plt.xlabel('Time')
80 | plt.ylabel('Survival Probability')
81 | plt.show()
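
# Usage sketch (illustrative, not part of the library): a baseline hazard with
# three segments plus two standardised covariates, plotted for one subject.
if __name__ == "__main__":
    bp = np.array([0.0, 2.0, 5.0, 10.0])
    t_scaler = MaxAbsScaler().fit(bp[:, None])
    x = np.random.randn(20, 2)
    x_scaler = StandardScaler().fit(x)
    model = ProportionalHazard(bp, t_scaler, x_scaler, dim=2)
    model.plot_survival_function(np.linspace(0.0, 9.0, 50), x[:1])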
--------------------------------------------------------------------------------
/docs/_includes/sidebar.html:
--------------------------------------------------------------------------------
1 | {% assign sidebar = site.data.sidebars[page.sidebar].entries %}
2 | {% assign pageurl = page.url | remove: ".html" %}
3 |
<!-- sidebar nav markup stripped in extraction -->
--------------------------------------------------------------------------------
/docs/css/theme-blue.css:
--------------------------------------------------------------------------------
1 | .summary {
2 | color: #808080;
3 | border-left: 5px solid #ED1951;
4 | font-size:16px;
5 | }
6 |
7 |
8 | h3 {color: #000000; }
9 | h4 {color: #000000; }
10 |
11 | .nav-tabs > li.active > a, .nav-tabs > li.active > a:hover, .nav-tabs > li.active > a:focus {
12 | background-color: #248ec2;
13 | color: white;
14 | }
15 |
16 | .nav > li.active > a {
17 | background-color: #347DBE;
18 | }
19 |
20 | .nav > li > a:hover {
21 | background-color: #248ec2;
22 | }
23 |
24 | div.navbar-collapse .dropdown-menu > li > a:hover {
25 | background-color: #347DBE;
26 | }
27 |
28 | .nav li.thirdlevel > a {
29 | background-color: #FAFAFA !important;
30 | color: #248EC2;
31 | font-weight: bold;
32 | }
33 |
34 | a[data-toggle="tooltip"] {
35 | color: #649345;
36 | font-style: italic;
37 | cursor: default;
38 | }
39 |
40 | .navbar-inverse {
41 | background-color: #347DBE;
42 | border-color: #015CAE;
43 | }
44 | .navbar-inverse .navbar-nav>li>a, .navbar-inverse .navbar-brand {
45 | color: white;
46 | }
47 |
48 | .navbar-inverse .navbar-nav>li>a:hover, a.fa.fa-home.fa-lg.navbar-brand:hover {
49 | color: #f0f0f0;
50 | }
51 |
52 | a.navbar-brand:hover {
53 | color: #f0f0f0;
54 | }
55 |
56 | .navbar-inverse .navbar-nav > .open > a, .navbar-inverse .navbar-nav > .open > a:hover, .navbar-inverse .navbar-nav > .open > a:focus {
57 | color: #015CAE;
58 | }
59 |
60 | .navbar-inverse .navbar-nav > .open > a, .navbar-inverse .navbar-nav > .open > a:hover, .navbar-inverse .navbar-nav > .open > a:focus {
61 | background-color: #015CAE;
62 | color: #ffffff;
63 | }
64 |
65 | .navbar-inverse .navbar-collapse, .navbar-inverse .navbar-form {
66 | border-color: #248ec2 !important;
67 | }
68 |
69 | .btn-primary {
70 | color: #ffffff;
71 | background-color: #347DBE;
72 | border-color: #347DBE;
73 | }
74 |
75 | .navbar-inverse .navbar-nav > .active > a, .navbar-inverse .navbar-nav > .active > a:hover, .navbar-inverse .navbar-nav > .active > a:focus {
76 | background-color: #347DBE;
77 | }
78 |
79 | .btn-primary:hover,
80 | .btn-primary:focus,
81 | .btn-primary:active,
82 | .btn-primary.active,
83 | .open .dropdown-toggle.btn-primary {
84 | background-color: #248ec2;
85 | border-color: #347DBE;
86 | }
87 |
88 | .printTitle {
89 | color: #015CAE !important;
90 | }
91 |
92 | body.print h1 {color: #015CAE !important; font-size:28px !important;}
93 | body.print h2 {color: #595959 !important; font-size:20px !important;}
94 | body.print h3 {color: #E50E51 !important; font-size:14px !important;}
95 | body.print h4 {color: #679DCE !important; font-size:14px; font-style: italic !important;}
96 |
97 | .anchorjs-link:hover {
98 | color: #216f9b;
99 | }
100 |
101 | div.sidebarTitle {
102 | color: #015CAE;
103 | }
104 |
105 | li.sidebarTitle {
106 | margin-top:20px;
107 | font-weight:normal;
108 | font-size:130%;
109 | color: #ED1951;
110 | margin-bottom:10px;
111 | margin-left: 5px;
112 |
113 | }
114 |
115 | .navbar-inverse .navbar-toggle:focus, .navbar-inverse .navbar-toggle:hover {
116 | background-color: #015CAE;
117 | }
118 |
119 | .navbar-inverse .navbar-toggle {
120 | border-color: #015CAE;
121 | }
122 |
--------------------------------------------------------------------------------
/docs/_includes/topnav.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
13 |
14 |
15 |
16 | Nav
17 |
18 |
19 | {% assign topnav = site.data[page.topnav] %}
20 | {% assign topnav_dropdowns = site.data[page.topnav].topnav_dropdowns %}
21 |
22 | {% for entry in topnav.topnav %}
23 | {% for item in entry.items %}
24 | {% if item.external_url %}
25 | {{item.title}}
26 | {% elsif page.url contains item.url %}
27 | {{item.title}}
28 | {% else %}
29 | {{item.title}}
30 | {% endif %}
31 | {% endfor %}
32 | {% endfor %}
33 |
34 |
35 | {% for entry in topnav_dropdowns %}
36 | {% for folder in entry.folders %}
37 |
38 | {{ folder.title }}
39 |
50 |
51 | {% endfor %}
52 | {% endfor %}
53 | {% if site.google_search %}
54 |
55 | {% include search_google_custom.html %}
56 |
57 | {% endif %}
58 |
59 |
60 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
/docs/js/jquery.navgoco.min.js:
--------------------------------------------------------------------------------
1 | /*
2 | * jQuery Navgoco Menus Plugin v0.2.1 (2014-04-11)
3 | * https://github.com/tefra/navgoco
4 | *
5 | * Copyright (c) 2014 Chris T (@tefra)
6 | * BSD - https://github.com/tefra/navgoco/blob/master/LICENSE-BSD
7 | */
8 | !function(a){"use strict";var b=function(b,c,d){return this.el=b,this.$el=a(b),this.options=c,this.uuid=this.$el.attr("id")?this.$el.attr("id"):d,this.state={},this.init(),this};b.prototype={init:function(){var b=this;b._load(),b.$el.find("ul").each(function(c){var d=a(this);d.attr("data-index",c),b.options.save&&b.state.hasOwnProperty(c)?(d.parent().addClass(b.options.openClass),d.show()):d.parent().hasClass(b.options.openClass)?(d.show(),b.state[c]=1):d.hide()});var c=a(" ").prepend(b.options.caretHtml),d=b.$el.find("li > a");b._trigger(c,!1),b._trigger(d,!0),b.$el.find("li:has(ul) > a").prepend(c)},_trigger:function(b,c){var d=this;b.on("click",function(b){b.stopPropagation();var e=c?a(this).next():a(this).parent().next(),f=!1;if(c){var g=a(this).attr("href");f=void 0===g||""===g||"#"===g}if(e=e.length>0?e:!1,d.options.onClickBefore.call(this,b,e),!c||e&&f)b.preventDefault(),d._toggle(e,e.is(":hidden")),d._save();else if(d.options.accordion){var h=d.state=d._parents(a(this));d.$el.find("ul").filter(":visible").each(function(){var b=a(this),c=b.attr("data-index");h.hasOwnProperty(c)||d._toggle(b,!1)}),d._save()}d.options.onClickAfter.call(this,b,e)})},_toggle:function(b,c){var d=this,e=b.attr("data-index"),f=b.parent();if(d.options.onToggleBefore.call(this,b,c),c){if(f.addClass(d.options.openClass),b.slideDown(d.options.slide),d.state[e]=1,d.options.accordion){var g=d.state=d._parents(b);g[e]=d.state[e]=1,d.$el.find("ul").filter(":visible").each(function(){var b=a(this),c=b.attr("data-index");g.hasOwnProperty(c)||d._toggle(b,!1)})}}else f.removeClass(d.options.openClass),b.slideUp(d.options.slide),d.state[e]=0;d.options.onToggleAfter.call(this,b,c)},_parents:function(b,c){var d={},e=b.parent(),f=e.parents("ul");return f.each(function(){var b=a(this),e=b.attr("data-index");return e?void(d[e]=c?b:1):!1}),d},_save:function(){if(this.options.save){var b={};for(var d in this.state)1===this.state[d]&&(b[d]=1);c[this.uuid]=this.state=b,a.cookie(this.options.cookie.name,JSON.stringify(c),this.options.cookie)}},_load:function(){if(this.options.save){if(null===c){var b=a.cookie(this.options.cookie.name);c=b?JSON.parse(b):{}}this.state=c.hasOwnProperty(this.uuid)?c[this.uuid]:{}}},toggle:function(b){var c=this,d=arguments.length;if(1>=d)c.$el.find("ul").each(function(){var d=a(this);c._toggle(d,b)});else{var e,f={},g=Array.prototype.slice.call(arguments,1);d--;for(var h=0;d>h;h++){e=g[h];var i=c.$el.find('ul[data-index="'+e+'"]').first();if(i&&(f[e]=i,b)){var j=c._parents(i,!0);for(var k in j)f.hasOwnProperty(k)||(f[k]=j[k])}}for(e in f)c._toggle(f[e],b)}c._save()},destroy:function(){a.removeData(this.$el),this.$el.find("li:has(ul) > a").unbind("click"),this.$el.find("li:has(ul) > a > span").unbind("click")}},a.fn.navgoco=function(c){if("string"==typeof c&&"_"!==c.charAt(0)&&"init"!==c)var d=!0,e=Array.prototype.slice.call(arguments,1);else c=a.extend({},a.fn.navgoco.defaults,c||{}),a.cookie||(c.save=!1);return this.each(function(f){var g=a(this),h=g.data("navgoco");h||(h=new b(this,d?a.fn.navgoco.defaults:c,f),g.data("navgoco",h)),d&&h[c].apply(h,e)})};var c=null;a.fn.navgoco.defaults={caretHtml:"",accordion:!1,openClass:"open",save:!0,cookie:{name:"navgoco",expires:!1,path:"/"},slide:{duration:400,easing:"swing"},onClickBefore:a.noop,onClickAfter:a.noop,onToggleBefore:a.noop,onToggleAfter:a.noop}}(jQuery);
--------------------------------------------------------------------------------
/docs/_includes/initialize_shuffle.html:
--------------------------------------------------------------------------------
<!-- shuffle/filter initialisation script stripped in extraction -->
--------------------------------------------------------------------------------
/docs/css/printstyles.css:
--------------------------------------------------------------------------------
1 |
2 | /*body.print .container {max-width: 650px;}*/
3 |
4 | body {
5 | font-size:14px;
6 | }
7 | .nav ul li a {border-top:0px; background-color:transparent; color: #808080; }
8 | #navig a[href] {color: #595959 !important;}
9 | table .table {max-width:650px;}
10 |
11 | #navig li.sectionHead {font-weight: bold; font-size: 18px; color: #595959 !important; }
12 | #navig li {font-weight: normal; }
13 |
14 | #navig a[href]::after { content: leader(".") target-counter(attr(href), page); }
15 |
16 | a[href]::after {
17 | content: " (page " target-counter(attr(href), page) ")"
18 | }
19 |
20 | a[href^="http:"]::after, a[href^="https:"]::after {
21 | content: "";
22 | }
23 |
24 | a[href] {
25 | color: blue !important;
26 | }
27 | a[href*="mailto"]::after, a[data-toggle="tooltip"]::after, a[href].noCrossRef::after {
28 | content: "";
29 | }
30 |
31 |
32 | @page {
33 | margin: 60pt 90pt 60pt 90pt;
34 | font-family: sans-serif;
35 | font-style:none;
36 | color: gray;
37 |
38 | }
39 |
40 | .printTitle {
41 | line-height:30pt;
42 | font-size:27pt;
43 | font-weight: bold;
44 | letter-spacing: -.5px;
45 | margin-bottom:25px;
46 | }
47 |
48 | .printSubtitle {
49 | font-size: 19pt;
50 | color: #cccccc !important;
51 | font-family: "Grotesque MT Light";
52 | line-height: 22pt;
53 | letter-spacing: -.5px;
54 | margin-bottom:20px;
55 | }
56 | .printTitleArea hr {
57 | color: #999999 !important;
58 | height: 2px;
59 | width: 100%;
60 | }
61 |
62 | .printTitleImage {
63 | max-width:300px;
64 | margin-bottom:200px;
65 | }
66 |
67 |
68 | .printTitleImage {
69 | max-width: 250px;
70 | }
71 |
72 | #navig {
73 | /*page-break-before: always;*/
74 | }
75 |
76 | .copyrightBoilerplate {
77 | page-break-before:always;
78 | font-size:14px;
79 | }
80 |
81 | .lastGeneratedDate {
82 | font-style: italic;
83 | font-size:14px;
84 | color: gray;
85 | }
86 |
87 | .alert a {
88 | text-decoration: none !important;
89 | }
90 |
91 |
92 | body.title { page: title }
93 |
94 | @page title {
95 | @top-left {
96 | content: " ";
97 | }
98 | @top-right {
99 | content: " "
100 | }
101 | @bottom-right {
102 | content: " ";
103 | }
104 | @bottom-left {
105 | content: " ";
106 | }
107 | }
108 |
109 | body.frontmatter { page: frontmatter }
110 | body.frontmatter {counter-reset: page 1}
111 |
112 |
113 | @page frontmatter {
114 | @top-left {
115 | content: prince-script(guideName);
116 | }
117 | @top-right {
118 | content: prince-script(datestamp);
119 | }
120 | @bottom-right {
121 | content: counter(page, lower-roman);
122 | }
123 | @bottom-left {
124 | content: "youremail@domain.com"; }
125 | }
126 |
127 | body.first_page {counter-reset: page 1}
128 |
129 | h1 { string-set: doctitle content() }
130 |
131 | @page {
132 | @top-left {
133 | content: string(doctitle);
134 | font-size: 11px;
135 | font-style: italic;
136 | }
137 | @top-right {
138 | content: prince-script(datestamp);
139 | font-size: 11px;
140 | }
141 |
142 | @bottom-right {
143 | content: "Page " counter(page);
144 | font-size: 11px;
145 | }
146 | @bottom-left {
147 | content: prince-script(guideName);
148 | font-size: 11px;
149 | }
150 | }
151 | .alert {
152 | background-color: #fafafa !important;
153 | border-color: #dedede !important;
154 | color: black;
155 | }
156 |
157 | pre {
158 | background-color: #fafafa;
159 | }
160 |
--------------------------------------------------------------------------------
/docs/_includes/head.html:
--------------------------------------------------------------------------------
6 | <title>{{ page.title }} | {{ site.site_title }}</title>
24 | {% if site.use_math %}
<!-- math-rendering script/link tags stripped in extraction -->
39 | {% endif %}
--------------------------------------------------------------------------------
/torchlife/models/ph.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: 55_hazard.PiecewiseHazard.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['PieceWiseHazard']
4 |
5 | # Cell
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from sklearn.preprocessing import MaxAbsScaler
11 |
12 | torch.Tensor.ndim = property(lambda x: x.dim())
13 |
14 | # Cell
15 | class PieceWiseHazard(nn.Module):
16 | """
17 | Piecewise Hazard where the hazard is constant between breakpoints.
18 | parameters:
19 | - breakpoints: time points where hazard would change (must include 0 and max possible time)
20 | """
21 | def __init__(self, breakpoints:np.array, t_scaler:MaxAbsScaler, **kwargs):
22 | super().__init__()
23 | self.t_scaler = t_scaler
24 | if len(breakpoints.shape) == 1:
25 | breakpoints = self.t_scaler.transform(breakpoints[:,None])
26 | else:
27 | breakpoints = self.t_scaler.transform(breakpoints)
28 | self.logλ = nn.Parameter(torch.randn(len(breakpoints)-1, 1))
29 | self.register_buffer('breakpoints', torch.Tensor(breakpoints[:-1]))
30 | self.register_buffer('widths', torch.Tensor(np.diff(breakpoints, axis=0)))
31 | self.prepend_zero = nn.ConstantPad2d((0,0,1,0), 0)
32 |
33 | def cumulative_hazard(self, t, t_section):
34 | """
35 | Integral of hazard wrt time.
36 | """
37 | λ = torch.exp(self.logλ)
38 |
39 | # cumulative hazard
40 | cum_hazard = λ * self.widths
41 | cum_hazard = cum_hazard.cumsum(0)
42 | cum_hazard = self.prepend_zero(cum_hazard)
43 | cum_hazard_sec = cum_hazard[t_section]
44 |
45 | δ_t = t - self.breakpoints[t_section]
46 |
47 | return cum_hazard_sec + λ[t_section] * δ_t
48 |
49 | def forward(self, t, t_section, *args):
50 | return self.logλ[t_section], self.cumulative_hazard(t, t_section)
51 |
52 | def survival_function(self, t:np.array):
53 | """
54 | parameters:
55 | - t: time (do not scale to be between 0 and 1)
56 | """
57 | if len(t.shape) == 1:
58 | t = t[:,None]
59 | t = self.t_scaler.transform(t)
60 |
61 | with torch.no_grad():
62 | # get the times and time sections for survival function
63 | breakpoints = self.breakpoints[1:].cpu().numpy()
64 | t_sec_query = np.searchsorted(breakpoints.squeeze(), t.squeeze())
65 | # convert to pytorch tensors
66 | t_query = torch.Tensor(t)
67 | t_sec_query = torch.LongTensor(t_sec_query)
68 |
69 | # calculate cumulative hazard according to above
70 | Λ = self.cumulative_hazard(t_query, t_sec_query)
71 | return torch.exp(-Λ)
72 |
73 | def hazard(self):
74 | with torch.no_grad():
75 | width = self.widths
76 | breakpoints = self.breakpoints
77 | λ = torch.exp(self.logλ)
78 | return (self.t_scaler.inverse_transform(breakpoints).squeeze(),
79 | self.t_scaler.inverse_transform(width).squeeze(),
80 | λ.squeeze())
81 |
82 | def plot_survival_function(self, t):
83 | s = self.survival_function(t)
84 | # plot
85 | plt.figure(figsize=(12,5))
86 | plt.plot(t, s)
87 | plt.xlabel('Time')
88 | plt.ylabel('Survival Probability')
89 | plt.show()
90 |
91 | def plot_hazard(self):
92 | """
93 | Plot base hazard
94 | """
95 | breakpoints, width, λ = self.hazard()
96 | # plot
97 | plt.figure(figsize=(12,5))
98 | plt.bar(breakpoints, λ, width, align='edge')
99 | plt.ylabel('λ')
100 | plt.xlabel('t')
101 | plt.show()
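
# Usage sketch (illustrative, not part of the library): three constant-hazard
# segments on [0, 10]; breakpoints must include 0 and the maximum time.
if __name__ == "__main__":
    bp = np.array([0.0, 2.0, 5.0, 10.0])
    t_scaler = MaxAbsScaler().fit(bp[:, None])
    model = PieceWiseHazard(bp, t_scaler)
    model.plot_hazard()
    model.plot_survival_function(np.linspace(0.0, 9.0, 50))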
--------------------------------------------------------------------------------
/docs/utils.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Utilities
4 |
5 | keywords: fastai
6 | sidebar: home_sidebar
7 |
8 | summary: "Helper functions used in developing library."
9 | description: "Helper functions used in developing library."
10 | ---
11 |
20 |
21 |
22 |
23 | {% raw %}
24 |
25 |
26 |
27 |
28 | {% endraw %}
29 |
30 |
31 |
32 |
Delegated Inheritance The following functions help with tab completion and with flattening the class structure. See the blog post by fastai for details.
33 |
34 |
35 |
36 |
37 | {% raw %}
38 |
39 |
40 |
41 |
42 | {% endraw %}
43 |
44 | {% raw %}
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
delegates(to=None, keep=False)
56 |
57 |
Decorator: replace **kwargs in signature with params from to
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 | {% endraw %}
68 |
69 | {% raw %}
70 |
71 |
72 |
73 |
74 | {% endraw %}
75 |
76 | {% raw %}
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
custom_dir(c, add)
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 | {% endraw %}
99 |
100 | {% raw %}
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
GetAttr()
112 |
113 |
Base class for attr accesses in self._xtra passed down to self.default
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 | {% endraw %}
124 |
125 |
126 |
127 |
128 |
--------------------------------------------------------------------------------
/docs/css/syntax.css:
--------------------------------------------------------------------------------
1 | .highlight { background: #ffffff; }
2 | .highlight .c { color: #999988; font-style: italic } /* Comment */
3 | .highlight .err { color: #a61717; background-color: #e3d2d2 } /* Error */
4 | .highlight .k { font-weight: bold } /* Keyword */
5 | .highlight .o { font-weight: bold } /* Operator */
6 | .highlight .cm { color: #999988; font-style: italic } /* Comment.Multiline */
7 | .highlight .cp { color: #999999; font-weight: bold } /* Comment.Preproc */
8 | .highlight .c1 { color: #999988; font-style: italic } /* Comment.Single */
9 | .highlight .cs { color: #999999; font-weight: bold; font-style: italic } /* Comment.Special */
10 | .highlight .gd { color: #000000; background-color: #ffdddd } /* Generic.Deleted */
11 | .highlight .gd .x { color: #000000; background-color: #ffaaaa } /* Generic.Deleted.Specific */
12 | .highlight .ge { font-style: italic } /* Generic.Emph */
13 | .highlight .gr { color: #aa0000 } /* Generic.Error */
14 | .highlight .gh { color: #999999 } /* Generic.Heading */
15 | .highlight .gi { color: #000000; background-color: #ddffdd } /* Generic.Inserted */
16 | .highlight .gi .x { color: #000000; background-color: #aaffaa } /* Generic.Inserted.Specific */
17 | .highlight .go { color: #888888 } /* Generic.Output */
18 | .highlight .gp { color: #555555 } /* Generic.Prompt */
19 | .highlight .gs { font-weight: bold } /* Generic.Strong */
20 | .highlight .gu { color: #aaaaaa } /* Generic.Subheading */
21 | .highlight .gt { color: #aa0000 } /* Generic.Traceback */
22 | .highlight .kc { font-weight: bold } /* Keyword.Constant */
23 | .highlight .kd { font-weight: bold } /* Keyword.Declaration */
24 | .highlight .kp { font-weight: bold } /* Keyword.Pseudo */
25 | .highlight .kr { font-weight: bold } /* Keyword.Reserved */
26 | .highlight .kt { color: #445588; font-weight: bold } /* Keyword.Type */
27 | .highlight .m { color: #009999 } /* Literal.Number */
28 | .highlight .s { color: #d14 } /* Literal.String */
29 | .highlight .na { color: #008080 } /* Name.Attribute */
30 | .highlight .nb { color: #0086B3 } /* Name.Builtin */
31 | .highlight .nc { color: #445588; font-weight: bold } /* Name.Class */
32 | .highlight .no { color: #008080 } /* Name.Constant */
33 | .highlight .ni { color: #800080 } /* Name.Entity */
34 | .highlight .ne { color: #990000; font-weight: bold } /* Name.Exception */
35 | .highlight .nf { color: #990000; font-weight: bold } /* Name.Function */
36 | .highlight .nn { color: #555555 } /* Name.Namespace */
37 | .highlight .nt { color: #000080 } /* Name.Tag */
38 | .highlight .nv { color: #008080 } /* Name.Variable */
39 | .highlight .ow { font-weight: bold } /* Operator.Word */
40 | .highlight .w { color: #bbbbbb } /* Text.Whitespace */
41 | .highlight .mf { color: #009999 } /* Literal.Number.Float */
42 | .highlight .mh { color: #009999 } /* Literal.Number.Hex */
43 | .highlight .mi { color: #009999 } /* Literal.Number.Integer */
44 | .highlight .mo { color: #009999 } /* Literal.Number.Oct */
45 | .highlight .sb { color: #d14 } /* Literal.String.Backtick */
46 | .highlight .sc { color: #d14 } /* Literal.String.Char */
47 | .highlight .sd { color: #d14 } /* Literal.String.Doc */
48 | .highlight .s2 { color: #d14 } /* Literal.String.Double */
49 | .highlight .se { color: #d14 } /* Literal.String.Escape */
50 | .highlight .sh { color: #d14 } /* Literal.String.Heredoc */
51 | .highlight .si { color: #d14 } /* Literal.String.Interpol */
52 | .highlight .sx { color: #d14 } /* Literal.String.Other */
53 | .highlight .sr { color: #009926 } /* Literal.String.Regex */
54 | .highlight .s1 { color: #d14 } /* Literal.String.Single */
55 | .highlight .ss { color: #990073 } /* Literal.String.Symbol */
56 | .highlight .bp { color: #999999 } /* Name.Builtin.Pseudo */
57 | .highlight .vc { color: #008080 } /* Name.Variable.Class */
58 | .highlight .vg { color: #008080 } /* Name.Variable.Global */
59 | .highlight .vi { color: #008080 } /* Name.Variable.Instance */
60 | .highlight .il { color: #009999 } /* Literal.Number.Integer.Long */
--------------------------------------------------------------------------------
/docs/js/toc.js:
--------------------------------------------------------------------------------
1 | // https://github.com/ghiculescu/jekyll-table-of-contents
2 | // this library modified by fastai to:
3 | // - update the location.href with the correct anchor when a toc item is clicked on
4 | (function($){
5 | $.fn.toc = function(options) {
6 | var defaults = {
7 | noBackToTopLinks: false,
8 | title: '',
9 | minimumHeaders: 3,
10 | headers: 'h1, h2, h3, h4',
11 | listType: 'ol', // values: [ol|ul]
12 | showEffect: 'show', // values: [show|slideDown|fadeIn|none]
13 | showSpeed: 'slow' // set to 0 to deactivate effect
14 | },
15 | settings = $.extend(defaults, options);
16 |
17 | var headers = $(settings.headers).filter(function() {
18 | // get all headers with an ID
19 | var previousSiblingName = $(this).prev().attr( "name" );
20 | if (!this.id && previousSiblingName) {
21 | this.id = $(this).attr( "id", previousSiblingName.replace(/\./g, "-") );
22 | }
23 | return this.id;
24 | }), output = $(this);
25 | if (!headers.length || headers.length < settings.minimumHeaders || !output.length) {
26 | return;
27 | }
28 |
29 | if (0 === settings.showSpeed) {
30 | settings.showEffect = 'none';
31 | }
32 |
33 | var render = {
34 | show: function() { output.hide().html(html).show(settings.showSpeed); },
35 | slideDown: function() { output.hide().html(html).slideDown(settings.showSpeed); },
36 | fadeIn: function() { output.hide().html(html).fadeIn(settings.showSpeed); },
37 | none: function() { output.html(html); }
38 | };
39 |
40 | var get_level = function(ele) { return parseInt(ele.nodeName.replace("H", ""), 10); }
41 | var highest_level = headers.map(function(_, ele) { return get_level(ele); }).get().sort()[0];
42 | //var return_to_top = ' ';
43 | // other nice icons that can be used instead: glyphicon-upload glyphicon-hand-up glyphicon-chevron-up glyphicon-menu-up glyphicon-triangle-top
44 | var level = get_level(headers[0]),
45 | this_level,
46 | html = settings.title + " <"+settings.listType+">";
47 | headers.on('click', function() {
48 | if (!settings.noBackToTopLinks) {
49 | var pos = $(window).scrollTop();
50 | window.location.hash = this.id;
51 | $(window).scrollTop(pos);
52 | }
53 | })
54 | .addClass('clickable-header')
55 | .each(function(_, header) {
56 | base_url = window.location.href;
57 | base_url = base_url.replace(/#.*$/, "");
58 | this_level = get_level(header);
59 | //if (!settings.noBackToTopLinks && this_level > 1) {
60 | // $(header).addClass('top-level-header').before(return_to_top);
61 | //}
62 | txt = header.textContent.split('¶')[0].split(/\[(test|source)\]/)[0];
63 | if (!txt) {return;}
64 | if (this_level === level) // same level as before; same indenting
65 |         html += "<li><a href='" + base_url + "#" + header.id + "'>" + txt + "</a></li>";
66 |       else if (this_level <= level){ // higher level than before; end parent ol
67 |         for(i = this_level; i < level; i++) {
68 |           html += "</li></"+settings.listType+">"
69 |         }
70 |         html += "<li><a href='" + base_url + "#" + header.id + "'>" + txt + "</a></li>";
71 |       }
72 |       else if (this_level > level) { // lower level than before; expand the previous to contain a ol
73 |         for(i = this_level; i > level; i--) {
74 |           html += "<"+settings.listType+">"+((i-level == 2) ? "<li>" : "")
75 |         }
76 |         html += "<li><a href='" + base_url + "#" + header.id + "'>" + txt + "</a></li>";
77 | }
78 | level = this_level; // update for the next one
79 | });
80 |     html += "</"+settings.listType+">";
81 | if (!settings.noBackToTopLinks) {
82 | $(document).on('click', '.back-to-top', function() {
83 | $(window).scrollTop(0);
84 | window.location.hash = '';
85 | });
86 | }
87 |
88 | render[settings.showEffect]();
89 | };
90 | })(jQuery);
91 |
--------------------------------------------------------------------------------
/docs/_layouts/default.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | {% include head.html %}
5 |
41 |
46 |
57 | {% if page.datatable == true %}
58 |
59 |
60 |
61 |
66 |
76 | {% endif %}
77 |
78 |
79 |
80 | {% include topnav.html %}
81 |
82 |
83 |
84 |
85 |
86 | {% assign content_col_size = "col-md-12" %}
87 | {% unless page.hide_sidebar %}
88 |
89 |
92 | {% assign content_col_size = "col-md-9" %}
93 | {% endunless %}
94 |
95 |
96 |
97 | {{content}}
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 | {% if site.google_analytics %}
108 | {% include google_analytics.html %}
109 | {% endif %}
110 |
111 |
--------------------------------------------------------------------------------
/10_SAT.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Survival Analysis Theory\n",
8 | "> The maths behind survival analysis."
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "metadata": {},
14 | "source": [
15 | "Much of the work here is summarised from the notes in [Generalised Linear Models by Germán Rodríguez](https://data.princeton.edu/wws509/notes/c7s1), Chapter 7."
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {},
21 | "source": [
22 | "## The Survival Function\n",
23 | "Let us define $S(t)$ to be the probability that an object will survive beyond time $t$. If $f(t)$ is the **instantaneous** probability that a death would be observed at time $t$, the survival function is defined as:\n",
24 | "$$\n",
25 | "S(t) = P(T > t) = 1 - \\int_{-\\infty}^t f(x) dx\n",
26 | "$$"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {},
32 | "source": [
33 | "## The Hazard Function\n",
34 | "Another important concept is the hazard function, $\\lambda(t)$, which is the **instantaneous rate** of occurence, given that the object has survived until time $t$.\n",
35 | "$$\n",
36 | "\\lambda(t) = \\lim\\limits_{dt\\to0} \\frac{P(t<T<t+dt | T > t)}{dt}\\\\\n",
37 | "$$\n",
38 | "\n",
39 | "The above can be simplified down by using Bayes Rule, and the definition of $S(t)$ above:\n",
40 | "$$\n",
41 | "\\begin{aligned}\n",
42 | "\\lambda(t) =& \\lim\\limits_{dt\\to0} \\frac{P(t<T<t+dt, T > t)}{P(T > t)\\quad dt}\\\\\n",
43 | "=& \\lim\\limits_{dt\\to0} \\frac{P(t<T<t+dt)}{dt} \\frac{1}{S(t)}\\\\\n",
44 | "=& \\frac{f(t)}{S(t)}\n",
45 | "\\end{aligned}\n",
46 | "$$\n",
47 | "\n",
48 | "Since $S'(t) = -f(t)$ from the first equation, we can also state:\n",
49 | "$$\n",
50 | "\\lambda(t) = - \\frac{d}{dt}\\log S(t)\n",
51 | "$$\n",
52 | "\n",
53 | "Therefore, the survival function can be stated as:\n",
54 | "$$\n",
55 | "S(t) = \\exp\\left(-\\int_{-\\infty}^t \\lambda(x) dx\\right)\n",
56 | "$$\n",
57 | "and this will come in handy when we are coding up the survival function."
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "metadata": {},
63 | "source": [
64 | "## Likelihood Function\n",
65 | "In any probabilistic framework we wish to maximise the likelihood of the observed data given the probability functions. However, unlike classification/regression settings, we need to modify the likelihood $f(t)$. Let us define the case where a death, $d_i$, hasn't been observed as a censored observation. In those cases we know that death will occur at a point $T > t$. Therefore the likelihood $L$ can be defined as:\n",
66 | "$$\n",
67 | "L_i = \\begin{cases}\n",
68 | "f(t_i) = S(t_i)\\lambda(t_i) &\\text{ when }d_i = 1 \\\\\n",
69 | "\\int^{\\infty}_{t_i} f(x) dx = S(t_i) &\\text{ when }d_i = 0\n",
70 | "\\end{cases}\n",
71 | "$$\n",
72 | "\n",
73 | "The above can be simplified as:\n",
74 | "$$\n",
75 | "\\begin{aligned}\n",
76 | "L =& \\prod_{i=1}^N L_i = \\prod_{i=1}^N \\lambda(t_i)^{d_i} S(t_i) \\\\\n",
77 | "\\log L =& \\sum_{i=1}^N d_i \\log \\lambda(t_i) - \\Lambda(t_i) \\\\\n",
78 | "-\\log L =& \\sum_{i=1}^N \\Lambda(t_i) - d_i \\log \\lambda(t_i)\n",
79 | "\\end{aligned}\n",
80 | "$$\n",
81 | "where $\\Lambda(t)\\equiv -\\log S(t)$ is the cumulative hazard function.\n",
82 | "\n",
83 | "Similarly if we wish to avoid taking into account the hazard function, we can define the likelihood function to be:\n",
84 | "$$\n",
85 | "\\begin{aligned}\n",
86 | "L =& \\prod_{i=1}^N L_i = \\prod_{i=1}^N f(t_i)^{d_i} S(t_i)^{1-d_i} \\\\\n",
87 | "-\\log L =& -\\sum_{i=1}^N d_i \\log f(t_i) + (1 - d_i) \\log S(t_i)\n",
88 | "\\end{aligned}\n",
89 | "$$\n",
90 | "\n",
91 | "These two forms of the likelihood will be used when modelling censored data in different modelling settings. We use the negative log likelihood since modern deep learning libraries minimise losses via gradient descent."
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | " "
101 | ]
102 | }
103 | ],
104 | "metadata": {
105 | "kernelspec": {
106 | "display_name": "Python 3",
107 | "language": "python",
108 | "name": "python3"
109 | }
110 | },
111 | "nbformat": 4,
112 | "nbformat_minor": 2
113 | }
114 |
--------------------------------------------------------------------------------
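
The identities in the notebook above are easy to sanity-check numerically. Here is a small sketch assuming an exponential distribution with rate r, for which f(t) = r·exp(-rt), S(t) = exp(-rt) and Λ(t) = r·t; it also confirms that the hazard form and the density form of the log likelihood agree:

```python
import numpy as np

r = 0.7
t = np.linspace(0.1, 5, 50)
f = r * np.exp(-r * t)                 # density f(t)
S = np.exp(-r * t)                     # survival function S(t)
lam = f / S                            # hazard λ(t) = f(t)/S(t)
Lam = r * t                            # cumulative hazard Λ(t) = -log S(t)

assert np.allclose(lam, r)             # the exponential has constant hazard
assert np.allclose(np.exp(-Lam), S)    # S(t) = exp(-Λ(t))

# the two likelihood forms agree: λ^d S  vs  f^d S^(1-d)
d = np.random.randint(0, 2, size=t.shape)
assert np.allclose(d * np.log(lam) - Lam,
                   d * np.log(f) + (1 - d) * np.log(S))
```
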
/docs/SAT.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Survival Analysis Theory
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 | summary: "The maths behind survival analysis."
10 | description: "The maths behind survival analysis."
11 | nb_path: "10_SAT.ipynb"
12 | ---
13 |
22 |
23 |
24 |
25 | {% raw %}
26 |
27 |
28 |
29 |
30 | {% endraw %}
31 |
32 |
39 |
40 |
41 |
The Survival Function Let us define $S(t)$ to be the probability that an object will survive beyond time $t$. If $f(t)$ is the instantaneous probability that a death would be observed at time $t$, the survival function is defined as:
42 | $$
43 | S(t) = P(T > t) = 1 - \int_{-\infty}^t f(x) dx
44 | $$
45 |
46 |
47 |
48 |
49 |
50 |
51 |
The Hazard Function Another important concept is the hazard function, $\lambda(t)$, which is the instantaneous rate of occurence, given that the object has survived until time $t$.
52 | $$
53 | \lambda(t) = \lim\limits_{dt\to0} \frac{P(t<T<t+dt | T > t)}{dt}\\
54 | $$
55 |
The above can be simplified down by using Bayes Rule, and the definition of $S(t)$ above:
56 | $$
57 | \begin{aligned}
58 | \lambda(t) =& \lim\limits_{dt\to0} \frac{P(t<T<t+dt, T > t)}{P(T > t)\quad dt}\\
59 | =& \lim\limits_{dt\to0} \frac{P(t<T<t+dt)}{dt} \frac{1}{S(t)}\\
60 | =& \frac{f(t)}{S(t)}
61 | \end{aligned}
62 | $$
63 |
Since $S'(t) = -f(t)$ from the first equation, we can also state:
64 | $$
65 | \lambda(t) = - \frac{d}{dt}\log S(t)
66 | $$
67 |
67 | Therefore, the survival function can be stated as:
68 | $$
69 | S(t) = \exp\left(-\int_{-\infty}^t \lambda(x) dx\right)
70 | $$
71 | and this will come in handy when we are coding up the survival function.
72 |
73 |
74 |
75 |
76 |
77 |
78 |
78 | Likelihood Function In any probabilistic framework we wish to maximise the likelihood of the observed data given the probability functions. However, unlike classification/regression settings, we need to modify the likelihood $f(t)$. Let us define the case where a death, $d_i$, hasn't been observed as a censored observation. In those cases we know that death will occur at a point $T > t$. Therefore the likelihood $L$ can be defined as:
79 | $$
80 | L_i = \begin{cases}
81 | f(t_i) = S(t_i)\lambda(t_i) &\text{ when }d_i = 1 \\
82 | \int^{\infty}_{t_i} f(x) dx = S(t_i) &\text{ when }d_i = 0
83 | \end{cases}
84 | $$
85 |
The above can be simplified as:
86 | $$
87 | \begin{aligned}
88 | L =& \prod_{i=1}^N L_i = \prod_{i=1}^N \lambda(t_i)^{d_i} S(t_i) \\
89 | \log L =& \sum_{i=1}^N d_i \log \lambda(t_i) - \Lambda(t_i) \\
90 | -\log L =& \sum_{i=1}^N \Lambda(t_i) - d_i \log \lambda(t_i)
91 | \end{aligned}
92 | $$
93 | where $\Lambda(t)\equiv -\log S(t)$ is the cumulative hazard function.
94 |
Similarly if we wish to avoid taking into account the hazard function, we can define the likelihood function to be:
95 | $$
96 | \begin{aligned}
97 | L =& \prod_{i=1}^N L_i = \prod_{i=1}^N f(t_i)^{d_i} S(t_i)^{1-d_i} \\
98 | -\log L =& -\sum_{i=1}^N d_i \log f(t_i) + (1 - d_i) \log S(t_i)
99 | \end{aligned}
100 | $$
101 |
101 | These two forms of the likelihood will be used when modelling censored data in different modelling settings. We use the negative log likelihood since modern deep learning libraries minimise losses via gradient descent.
102 |
103 |
104 |
105 |
106 | {% raw %}
107 |
108 |
121 | {% endraw %}
122 |
123 |
124 |
125 |
126 |
--------------------------------------------------------------------------------
/docs/js/jekyll-search.js:
--------------------------------------------------------------------------------
1 | !function e(t,n,r){function s(o,u){if(!n[o]){if(!t[o]){var a="function"==typeof require&&require;if(!u&&a)return a(o,!0);if(i)return i(o,!0);throw new Error("Cannot find module '"+o+"'")}var f=n[o]={exports:{}};t[o][0].call(f.exports,function(e){var n=t[o][1][e];return s(n?n:e)},f,f.exports,e,t,n,r)}return n[o].exports}for(var i="function"==typeof require&&require,o=0;o=0}var self=this;self.matches=function(string,crit){return"string"!=typeof string?!1:(string=string.trim(),doMatch(string,crit))}}module.exports=new LiteralSearchStrategy},{}],4:[function(require,module){module.exports=function(){function findMatches(store,crit,strategy){for(var data=store.get(),i=0;i{title} ',noResultsText:"No results found",limit:10,fuzzy:!1};self.init=function(_opt){validateOptions(_opt),assignOptions(_opt),isJSON(opt.dataSource)?initWithJSON(opt.dataSource):initWithURL(opt.dataSource)}}var Searcher=require("./Searcher"),Templater=require("./Templater"),Store=require("./Store"),JSONLoader=require("./JSONLoader"),searcher=new Searcher,templater=new Templater,store=new Store,jsonLoader=new JSONLoader;window.SimpleJekyllSearch=new SimpleJekyllSearch}(window,document)},{"./JSONLoader":1,"./Searcher":4,"./Store":5,"./Templater":6}]},{},[7]);
2 |
--------------------------------------------------------------------------------
/docs/AFT_models.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Accelerated Failure Time Models
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 | summary: "AFT Model theory."
10 | description: "AFT Model theory."
11 | nb_path: "60_AFT_models.ipynb"
12 | ---
13 |
22 |
23 |
24 |
25 | {% raw %}
26 |
27 |
28 |
29 |
30 | {% endraw %}
31 |
32 |
33 |
34 |
We can model the time to failure as:
35 | $$
36 | \log T_i = \mu + \xi_i
37 | $$
38 | where $\xi_i\sim p(\xi|\theta)$ and $\mu$ is the most likely log time of death (the mode of the distribution of $T_i$). We model the log of the death time so that we do not need to restrict $\mu + \xi_i$ to be positive.
39 |
In the censored case, where $t_i$ is the time where an instance was censored, and $T_i$ is the unobserved time of death, we have:
40 | $$
41 | \begin{aligned}
42 | \log T_i &= \mu(x_i) + \xi_i > \log t_i\\
43 | \therefore \xi_i &> \log t_i - \mu(x_i)
44 | \end{aligned}
45 | $$
46 | Note that $\mu$ is a function of the features $x$. The log likelihood of the data ($\mathcal{D}$), with $e_i$ the event indicator, can then be shown to be:
47 | $$
48 | \begin{aligned}
49 | \log p(\mathcal{D}) = \sum_{i=1}^N \mathbb{1}(e_i=1)\log p(\xi_i = \log t_i - \mu(x_i)) + \mathbb{1}(e_i=0)\log p(\xi_i &> \log t_i - \mu(x_i))
50 | \end{aligned}
51 | $$
52 |
53 |
54 |
55 |
56 | {% raw %}
57 |
58 |
59 |
60 |
61 | {% endraw %}
62 |
63 | {% raw %}
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
AFTModel(distribution:str, input_dim:int, h:tuple=()) :: Module
75 |
76 |
Accelerated Failure Time model
77 | parameters:
78 |
79 | distribution: assumed distribution of the error term
80 | input_dim: input dimensionality of the features
81 | h (optional): tuple of hidden layer sizes
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 | {% endraw %}
93 |
94 | {% raw %}
95 |
96 |
97 |
98 |
99 | {% endraw %}
100 |
101 |
102 |
103 |
Modelling based on only time and (death) event variables:
104 |
105 |
106 |
107 |
108 | {% raw %}
109 |
110 |
135 | {% endraw %}
136 |
137 | {% raw %}
138 |
139 |
159 | {% endraw %}
160 |
161 | {% raw %}
162 |
163 |
176 | {% endraw %}
177 |
178 | {% raw %}
179 |
180 |
193 | {% endraw %}
194 |
195 |
196 |
197 |
198 |
--------------------------------------------------------------------------------
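
A minimal sketch of the censored AFT log likelihood above, assuming a `Gumbel` error distribution from `torch.distributions` (any family with support over all reals works) and a fixed toy μ in place of a network over the features x:

```python
import torch

dist = torch.distributions.Gumbel(loc=0.0, scale=1.0)

t = torch.tensor([2.0, 5.0, 9.0])        # observed/censoring times (toy data)
e = torch.tensor([1.0, 0.0, 1.0])        # 1 = death observed, 0 = censored
mu = torch.tensor(1.5)                   # mode of log-time; toy value

xi = torch.log(t) - mu                   # residual ξ_i = log t_i - μ(x_i)
log_pdf = dist.log_prob(xi)              # log p(ξ_i), used when e_i = 1
log_icdf = torch.log(1 - dist.cdf(xi))   # log P(ξ > ξ_i), used when e_i = 0
log_lik = e * log_pdf + (1 - e) * log_icdf
print(-log_lik.mean())                   # negative log likelihood to minimise
```

This is exactly the quantity the `AFTLoss` in the next file averages and negates.
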
/95_Losses.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# default_exp losses"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "# Losses\n",
17 | "> All the losses used in SA."
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": null,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "# export\n",
27 | "from abc import ABC, abstractmethod\n",
28 | "from typing import Callable, Tuple\n",
29 | "import torch"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {},
35 | "source": [
36 | "Suppose that we have:\n",
37 | "$$\n",
38 | "t_i = \\mu + \\xi_i\n",
39 | "$$\n",
40 | "and $\\xi_i\\sim p(\\xi_i|\\theta)$. Then $t_i|\\mu\\sim p_\\mu(t_i|\\theta)$, where $p_\\mu$ is simply the distribution $p(\\xi|\\theta)$ shifted to the right by $\\mu$.\n",
41 | "\n",
42 | "When the event is censored ($e_i=0$), we know that $t_i < \\mu + \\xi_i$, since the 'death' offset $\\xi_i$ is not fully observed.\n",
43 | "\n",
44 | "Therefore we may write the likelihood of the data as:\n",
45 | "$$\n",
46 | "\\begin{aligned}\n",
47 | "p(t_i, e_i|\\mu) =& \\left(p(t_i-\\mu)\\right)^{e_i} \\left(\\int_{t_i}^\\infty p(t-\\mu) dt\\right)^{1-e_i}\\\\\n",
48 | "\\log p(t_i, e_i|\\mu) =& e_i \\log p(t_i-\\mu) + (1 - e_i) \\log \\left(1 - \\int_{-\\infty}^{t_i} p(t-\\mu) dt \\right)\n",
49 | "\\end{aligned}\n",
50 | "$$"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "# export\n",
60 | "class Loss(ABC):\n",
61 | " @abstractmethod\n",
62 | " def __call__(event:torch.Tensor, *args):\n",
63 | " pass"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "# export\n",
73 | "LossType = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "# export\n",
83 | "class AFTLoss(Loss):\n",
84 | " @staticmethod\n",
85 | " def __call__(event:torch.Tensor, log_pdf: torch.Tensor, log_icdf: torch.Tensor) -> torch.Tensor:\n",
86 | " lik = event * log_pdf + (1 - event) * log_icdf\n",
87 | " return -lik.mean()"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {},
94 | "outputs": [
95 | {
96 | "data": {
97 | "text/plain": [
98 | "tensor(-0.0400)"
99 | ]
100 | },
101 | "execution_count": null,
102 | "metadata": {},
103 | "output_type": "execute_result"
104 | }
105 | ],
106 | "source": [
107 | "N = 5\n",
108 | "event = torch.randint(0, 2, (N,))\n",
109 | "log_pdf = torch.randn((N,))\n",
110 | "log_cdf = -torch.rand((N,))\n",
111 | "\n",
112 | "aft_loss = AFTLoss()\n",
113 | "aft_loss(event, log_pdf, log_cdf)"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": null,
119 | "metadata": {},
120 | "outputs": [],
121 | "source": [
122 | "# export\n",
123 | "def _aft_loss(\n",
124 | " log_pdf: torch.Tensor, log_cdf: torch.Tensor, e: torch.Tensor\n",
125 | ") -> torch.Tensor:\n",
126 | " lik = e * log_pdf + (1 - e) * log_cdf\n",
127 | " return -lik.mean()\n",
128 | "\n",
129 | "\n",
130 | "def aft_loss(log_prob, e):\n",
131 | " log_pdf, log_cdf = log_prob\n",
132 | " return _aft_loss(log_pdf, log_cdf, e)"
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "metadata": {},
138 | "source": [
139 | "We use the following loss function to train our model. See [here](./SAT#Likelihood-Function) for the theory.\n",
140 | "$$\n",
141 | "-\\log L = \\sum_{i=1}^N \\Lambda(t_i) - d_i \\log \\lambda(t_i)\n",
142 | "$$"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": null,
148 | "metadata": {},
149 | "outputs": [],
150 | "source": [
151 | "# export\n",
152 | "class HazardLoss(Loss):\n",
153 | " @staticmethod\n",
154 | " def __call__(event: torch.Tensor, logλ: torch.Tensor, Λ: torch.Tensor) -> torch.Tensor:\n",
155 | " log_lik = event * logλ - Λ\n",
156 | " return -log_lik.mean()"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": null,
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "# export\n",
166 | "def _hazard_loss(logλ: torch.Tensor, Λ: torch.Tensor, e: torch.Tensor) -> torch.Tensor:\n",
167 | " log_lik = e * logλ - Λ\n",
168 | " return -log_lik.mean()\n",
169 | "\n",
170 | "\n",
171 | "def hazard_loss(\n",
172 | " hazard: Tuple[torch.Tensor, torch.Tensor], e: torch.Tensor\n",
173 | ") -> torch.Tensor:\n",
174 | " \"\"\"\n",
175 | " parameters:\n",
176 | " - hazard: log hazard and Cumulative hazard\n",
177 | " - e: torch.Tensor of 1 if death event occured and 0 otherwise\n",
178 | " \"\"\"\n",
179 | " logλ, Λ = hazard\n",
180 | " return _hazard_loss(logλ, Λ, e)"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "metadata": {},
187 | "outputs": [
188 | {
189 | "name": "stdout",
190 | "output_type": "stream",
191 | "text": [
192 | "Converted 00_index.ipynb.\n",
193 | "Converted 10_SAT.ipynb.\n",
194 | "Converted 20_KaplanMeier.ipynb.\n",
195 | "Converted 30_overall_model.ipynb.\n",
196 | "Converted 50_hazard.ipynb.\n",
197 | "Converted 55_hazard.PiecewiseHazard.ipynb.\n",
198 | "Converted 59_hazard.Cox.ipynb.\n",
199 | "Converted 60_AFT_models.ipynb.\n",
200 | "Converted 65_AFT_error_distributions.ipynb.\n",
201 | "Converted 80_data.ipynb.\n",
202 | "Converted 90_model.ipynb.\n",
203 | "Converted 95_Losses.ipynb.\n"
204 | ]
205 | }
206 | ],
207 | "source": [
208 | "# hide\n",
209 | "from nbdev.export import *\n",
210 | "notebook2script()"
211 | ]
212 | }
213 | ],
214 | "metadata": {
215 | "kernelspec": {
216 | "display_name": "Python 3",
217 | "language": "python",
218 | "name": "python3"
219 | }
220 | },
221 | "nbformat": 4,
222 | "nbformat_minor": 2
223 | }
224 |
--------------------------------------------------------------------------------
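
The notebook above includes a smoke test for `AFTLoss` but not for `HazardLoss`. For symmetry, here is a quick sketch (both classes are exported to `torchlife.losses`), feeding it random log hazards and non-negative cumulative hazards:

```python
import torch
from torchlife.losses import HazardLoss

N = 5
event = torch.randint(0, 2, (N,)).float()   # death indicators d_i
log_lam = torch.randn(N)                    # log λ(t_i)
Lam = torch.rand(N)                         # Λ(t_i) ≥ 0

hazard_loss = HazardLoss()
print(hazard_loss(event, log_lam, Lam))     # mean of Λ(t_i) - d_i·log λ(t_i)
```
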
/README.md:
--------------------------------------------------------------------------------
1 | # TorchLife
2 | > Survival Analysis using pytorch
3 |
4 |
5 | This library takes a deep learning approach to Survival Analysis.
6 |
7 | ## What is Survival Analysis
8 | A lot of classification problems are actually survival analysis problems that haven't been tackled as such. For example, consider a cancer patient from whom you take X-ray data. Over time, patients will **eventually** die from cancer (let's ignore the case where people die from other diseases). The usual approach is to ask: given the X-ray (x), will the patient die in the next 30 days or not (y)?
9 |
10 | Survival analysis instead asks: given the input (x) and a time (t), what is the probability that a patient survives for a time greater than t? In the training dataset, a patient who is still alive would be labelled y = 0 in the classification setting. In survival analysis we instead call this a censored observation, since the patient will die at some time in the future, after the experiment has stopped.
11 |
12 | The above analogy can be thought of in other scenarios such as churn prediction as well.
13 |
14 | A proper dive into theory can be seen [here](./SAT).
15 |
16 | ## What's with the name?
17 | Well, if you torch a life... you probably wouldn't survive. 😬
18 |
19 | ## How to use this library
20 | There are 3 models in here that can be used.
21 | 1. [Kaplan Meier Model]
22 | 2. [Proportional Hazard Models]
23 | 3. [Accelerated Failure Time Models]
24 | All 3 models require a pandas dataframe with the columns `"t"` and `"e"`, indicating the time elapsed and a binary event variable (1 if a death was observed, 0 otherwise). They are all capable of `fit` and `plot_survival_function`.
25 |
26 | ### Kaplan Meier Model
27 | This is the simplest model.
28 |
29 | ### Proportional Hazard Model
30 | This model attempts to model the instantaneous hazard of an instance over time. It does this by binning time and accumulating the hazard up to a given point in time. Its extension, the Cox model, also takes into account features that are not time dependent, so that the hazard can grow or shrink in proportion to the risk associated with those non-temporal features.
31 | ```python
32 | from torchlife.model import ModelHazard
33 |
34 | model = ModelHazard('cox')
35 | model.fit(df)
36 | inst_hazard, surv_probability = model.predict(df)
37 | ```
38 |
39 | ### Accelerated Failure Time Models
40 | This model attempts to model the mode (not the average) of the time to failure as a function of the non-temporal features, x.
41 |
42 | You are free to choose the distribution of the error; however, the `Gumbel` distribution is a popular choice in the SA literature. The distribution needs to have support over the whole real line, not just the positives, since we are modelling log time.
43 | ```python
44 | from torchlife.model import ModelAFT
45 |
46 | model = ModelAFT('Gumbel')
47 | model.fit(df)
48 | surv_prob = model.predict(df)
49 | mode_time = model.predict_time(df)
50 | ```
51 |
52 | ## Kudos
53 | Special thanks to the following libraries and resources.
54 | - [lifelines](https://lifelines.readthedocs.io/en/latest/) and especially Cameron Davidson-Pilon
55 | - [pymc3 survival analysis examples](https://docs.pymc.io/nb_examples/index.html)
56 | - [nbdev](https://nbdev.fast.ai/)
57 | - [pytorch lightning](https://pytorch-lightning.readthedocs.io/)
58 | - [Generalised Linear Models by Germán Rodríguez](https://data.princeton.edu/wws509/notes/c7s1), Chapter 7.
59 |
60 | ## Install
61 |
62 | `pip install torchlife`
63 |
64 | ## How to use
65 | We need a dataframe that has a column named 't' indicating time, and 'e' indicating a death event.
66 |
67 | ```
68 | import pandas as pd
69 | import numpy as np
70 | url = "https://raw.githubusercontent.com/CamDavidsonPilon/lifelines/master/lifelines/datasets/rossi.csv"
71 | df = pd.read_csv(url)
72 | df.rename(columns={'week':'t', 'arrest':'e'}, inplace=True)
73 | ```
74 |
75 | ```
76 | df.head()
77 | ```
78 |
79 |
80 |
81 |
82 |
83 |
96 |
97 |
98 |
99 |
100 | t
101 | e
102 | fin
103 | age
104 | race
105 | wexp
106 | mar
107 | paro
108 | prio
109 |
110 |
111 |
112 |
113 | 0
114 | 20
115 | 1
116 | 0
117 | 27
118 | 1
119 | 0
120 | 0
121 | 1
122 | 3
123 |
124 |
125 | 1
126 | 17
127 | 1
128 | 0
129 | 18
130 | 1
131 | 0
132 | 0
133 | 1
134 | 8
135 |
136 |
137 | 2
138 | 25
139 | 1
140 | 0
141 | 19
142 | 0
143 | 1
144 | 0
145 | 1
146 | 13
147 |
148 |
149 | 3
150 | 52
151 | 0
152 | 1
153 | 23
154 | 1
155 | 1
156 | 1
157 | 1
158 | 1
159 |
160 |
161 | 4
162 | 52
163 | 0
164 | 0
165 | 19
166 | 0
167 | 1
168 | 0
169 | 1
170 | 3
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 | ```
179 | from torchlife.model import ModelHazard
180 |
181 | model = ModelHazard('cox', lr=0.5)
182 | model.fit(df)
183 | λ, S = model.predict(df)
184 | ```
185 |
186 | GPU available: False, used: False
187 | TPU available: None, using: 0 TPU cores
188 |
189 | | Name | Type | Params
190 | --------------------------------------------
191 | 0 | base | ProportionalHazard | 12
192 | --------------------------------------------
193 | 12 Trainable params
194 | 0 Non-trainable params
195 | 12 Total params
196 | Epoch 0: 75%|███████▌ | 3/4 [00:00<00:00, 25.43it/s, loss=nan, v_num=49]
197 | Epoch 0: 100%|██████████| 4/4 [00:00<00:00, 16.39it/s, loss=nan, v_num=49]
198 | Epoch 1: 75%|███████▌ | 3/4 [00:00<00:00, 23.74it/s, loss=nan, v_num=49]
199 | Epoch 1: 100%|██████████| 4/4 [00:00<00:00, 15.25it/s, loss=nan, v_num=49]
200 | Epoch 2: 75%|███████▌ | 3/4 [00:00<00:00, 22.98it/s, loss=nan, v_num=49]
201 | Epoch 2: 100%|██████████| 4/4 [00:00<00:00, 15.53it/s, loss=nan, v_num=49]
202 | Epoch 3: 75%|███████▌ | 3/4 [00:00<00:00, 20.23it/s, loss=nan, v_num=49]
203 | Validating: 0it [00:00, ?it/s]
204 | Epoch 3: 100%|██████████| 4/4 [00:00<00:00, 13.30it/s, loss=nan, v_num=49]
205 | Saving latest checkpoint...
206 | Epoch 3: 100%|██████████| 4/4 [00:00<00:00, 12.92it/s, loss=nan, v_num=49]
207 |
208 |
209 | Let's plot the survival function for the element at index 2 of the dataframe:
210 |
211 | ```
212 | x = df.drop(['t', 'e'], axis=1).iloc[2]
213 | t = np.arange(df['t'].max())
214 | model.plot_survival_function(t, x)
215 | ```
216 |
217 |
218 | 
219 |
220 |
--------------------------------------------------------------------------------
/torchlife/data.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: 80_data.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['TestData', 'Data', 'TestDataFrame', 'DataFrame', 'create_dl', 'create_test_dl', 'get_breakpoints']
4 |
5 | # Cell
6 | from typing import Optional, Tuple, Union
7 |
8 | import multiprocessing as mp
9 | import numpy as np
10 | import pandas as pd
11 | import torch
12 | from fastai.data_block import DataBunch, DatasetType
13 | from pandas import DataFrame
14 | from sklearn.model_selection import train_test_split
15 | from sklearn.preprocessing import MaxAbsScaler, StandardScaler
16 | from torch.utils.data import DataLoader, Dataset
17 |
18 | # Cell
19 | class TestData(Dataset):
20 | """
21 | Create pyTorch Dataset
22 | parameters:
23 | - t: time elapsed
24 | - b: (optional) breakpoints where the hazard is different to previous segment of time.
25 | **Must include 0 as first element and the maximum time as last element**
26 | - x: (optional) features
27 | """
28 | def __init__(self, t:np.array, b:Optional[np.array]=None, x:Optional[np.array]=None,
29 | t_scaler:MaxAbsScaler=None, x_scaler:StandardScaler=None) -> None:
30 | super().__init__()
31 | self.t, self.b, self.x = t, b, x
32 | self.t_scaler = t_scaler
33 | self.x_scaler = x_scaler
34 | if len(t.shape) == 1:
35 | self.t = t[:,None]
36 |
37 | if t_scaler:
38 | self.t_scaler = t_scaler
39 | self.t = self.t_scaler.transform(self.t)
40 | else:
41 | self.t_scaler = MaxAbsScaler()
42 | self.t = self.t_scaler.fit_transform(self.t)
43 |
44 | if b is not None:
45 | b = b[1:-1]
46 | if len(b.shape) == 1:
47 | b = b[:,None]
48 | if t_scaler:
49 | self.b = t_scaler.transform(b).squeeze()
50 | else:
51 | self.b = self.t_scaler.transform(b).squeeze()
52 |
53 | if x is not None:
54 | if len(x.shape) == 1:
55 | self.x = x[None, :]
56 | if x_scaler:
57 | self.x_scaler = x_scaler
58 | self.x = self.x_scaler.transform(self.x)
59 | else:
60 | self.x_scaler = StandardScaler()
61 | self.x = self.x_scaler.fit_transform(self.x)
62 |
63 | self.only_x = False
64 |
65 | def __len__(self) -> int:
66 | return len(self.t)
67 |
68 | def __getitem__(self, i:int) -> Tuple:
69 | if self.only_x:
70 | return torch.Tensor(self.x[i])
71 |
72 | time = torch.Tensor(self.t[i])
73 |
74 | if self.b is None:
75 | x_ = (time,)
76 | else:
77 | t_section = torch.LongTensor([np.searchsorted(self.b, self.t[i])])
78 | x_ = (time, t_section.squeeze())
79 |
80 | if self.x is not None:
81 | x = torch.Tensor(self.x[i])
82 | x_ = x_ + (x,)
83 |
84 | return x_
85 |
86 | # Cell
87 | class Data(TestData):
88 | """
89 | Create pyTorch Dataset
90 | parameters:
91 | - t: time elapsed
92 | - e: (death) event observed. 1 if observed, 0 otherwise.
93 | - b: (optional) breakpoints where the hazard is different to previous segment of time.
94 | - x: (optional) features
95 | """
96 | def __init__(self, t:np.array, e:np.array, b:Optional[np.array]=None, x:Optional[np.array]=None,
97 | t_scaler:MaxAbsScaler=None, x_scaler:StandardScaler=None) -> None:
98 | super().__init__(t, b, x, t_scaler, x_scaler)
99 | self.e = e
100 | if len(e.shape) == 1:
101 | self.e = e[:,None]
102 |
103 | def __getitem__(self, i) -> Tuple:
104 | x_ = super().__getitem__(i)
105 | e = torch.Tensor(self.e[i])
106 | return x_, e
107 |
108 | # Cell
109 | class TestDataFrame(TestData):
110 | """
111 | Wrapper around Data Class that takes in a dataframe instead
112 | parameters:
113 |     - df: dataframe. **Must have a "t" (time) column**; all other columns except "e" are treated as features.
114 | - b: breakpoints of time (optional)
115 | """
116 | def __init__(self, df:DataFrame, b:Optional[np.array]=None,
117 | t_scaler:MaxAbsScaler=None, x_scaler:StandardScaler=None) -> None:
118 | t = df['t'].values
119 | remainder = list(set(df.columns) - set(['t', 'e']))
120 | x = df[remainder].values
121 | if x.shape[1] == 0:
122 | x = None
123 | super().__init__(t, b, x, t_scaler, x_scaler)
124 |
125 | # Cell
126 | class DataFrame(Data):
127 | """
128 | Wrapper around Data Class that takes in a dataframe instead
129 | parameters:
130 |     - df: dataframe. **Must have "t" (time) and "e" (event) columns**; other feature columns are optional.
131 | - b: breakpoints of time (optional)
132 | """
133 | def __init__(self, df:DataFrame, b:Optional[np.array]=None,
134 | t_scaler:MaxAbsScaler=None, x_scaler:StandardScaler=None) -> None:
135 | t = df['t'].values
136 | e = df['e'].values
137 | x = df.drop(['t', 'e'], axis=1).values
138 | if x.shape[1] == 0:
139 | x = None
140 | super().__init__(t, e, b, x, t_scaler, x_scaler)
141 |
142 | # Cell
143 | def create_dl(df:pd.DataFrame, b:Optional[np.array]=None, train_size:float=0.8, random_state=None, bs:int=128)\
144 |     -> Tuple[DataLoader, DataLoader, MaxAbsScaler, StandardScaler]:
145 | """
146 | Take dataframe and split into train, test, val (optional)
147 | and convert to Fastai databunch
148 |
149 | parameters:
150 | - df: pandas dataframe
151 | - b(optional): breakpoints of time. **Must include 0 as first element and the maximum time as last element**
152 |     - train_size: proportion of the data used for training
153 | - bs: batch size
154 | """
155 | df.reset_index(drop=True, inplace=True)
156 | train, val = train_test_split(df, train_size=train_size, stratify=df["e"], random_state=random_state)
157 | train.reset_index(drop=True, inplace=True)
158 | val.reset_index(drop=True, inplace=True)
159 |
160 | train_ds = DataFrame(train, b)
161 | val_ds = DataFrame(val, b, train_ds.t_scaler, train_ds.x_scaler)
162 |
163 | train_dl = DataLoader(train_ds, bs, shuffle=True, drop_last=False, num_workers=mp.cpu_count())
164 | val_dl = DataLoader(val_ds, bs, shuffle=False, drop_last=False, num_workers=mp.cpu_count())
165 |
166 | return train_dl, val_dl, train_ds.t_scaler, train_ds.x_scaler
167 |
168 | def create_test_dl(df:pd.DataFrame, b:Optional[np.array]=None,
169 | t_scaler:MaxAbsScaler=None, x_scaler:StandardScaler=None,
170 | bs:int=128, only_x:bool=False) -> DataLoader:
171 | """
172 | Take dataframe and return a pytorch dataloader.
173 | parameters:
174 | - df: pandas dataframe
175 | - b: breakpoints of time (optional)
176 | - bs: batch size
177 | """
178 | if only_x:
179 | df["t"] = 0
180 | df.reset_index(drop=True, inplace=True)
181 | test_ds = TestDataFrame(df, b, t_scaler, x_scaler)
182 | test_ds.only_x = only_x
183 | test_dl = DataLoader(test_ds, bs, shuffle=False, drop_last=False, num_workers=mp.cpu_count())
184 | return test_dl
185 |
186 | # Cell
187 | def get_breakpoints(df:DataFrame, percentiles:list=[20, 40, 60, 80]) -> np.array:
188 | """
189 | Gives the times at which death events occur at given percentile
190 | parameters:
191 | df - must contain columns 't' (time) and 'e' (death event)
192 | percentiles - list of percentages at which breakpoints occur (do not include 0 and 100)
193 | """
194 | event_times = df.loc[df['e']==1, 't'].values
195 | breakpoints = np.percentile(event_times, percentiles)
196 | breakpoints = np.array([0] + breakpoints.tolist() + [df['t'].max()])
197 |
198 | return breakpoints
--------------------------------------------------------------------------------
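
A toy usage sketch of the helpers above; the dataframe and its columns are made up, and only `t` and `e` are required by the API:

```python
import numpy as np
import pandas as pd
from torchlife.data import get_breakpoints, create_dl

rng = np.random.default_rng(0)
df = pd.DataFrame({
    "t": rng.exponential(10.0, size=200),     # time elapsed
    "e": rng.integers(0, 2, size=200),        # 1 = death observed, 0 = censored
    "age": rng.normal(30.0, 5.0, size=200),   # an arbitrary feature column
})

b = get_breakpoints(df)   # [0, 20/40/60/80th percentiles of death times, max t]
train_dl, val_dl, t_scaler, x_scaler = create_dl(df, b, train_size=0.8)
print(b, len(train_dl.dataset), len(val_dl.dataset))
```
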
/torchlife/model.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: 90_model.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['GeneralModel', 'train_model', 'ModelHazard', 'ModelAFT']
4 |
5 | # Cell
6 | from dataclasses import dataclass
7 | from datetime import datetime
8 | from typing import Callable, List, Optional, Tuple
9 |
10 | import pandas as pd
11 | import pytorch_lightning as pl
12 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from torch.optim.lr_scheduler import ReduceLROnPlateau
17 |
18 | from .models.ph import PieceWiseHazard
19 | from .models.cox import ProportionalHazard
20 | from .models.aft import AFTModel
21 | from .data import create_dl, create_test_dl, get_breakpoints
22 | from .losses import aft_loss, hazard_loss, Loss, HazardLoss, AFTLoss
23 |
24 | # Cell
25 | class GeneralModel(pl.LightningModule):
26 | def __init__(
27 | self,
28 | base: nn.Module,
29 | loss_fn: Loss,
30 | lr: float = 1e-3,
31 | ) -> None:
32 | """
33 |         base: underlying survival model (e.g. PieceWiseHazard, ProportionalHazard, AFTModel).
34 |         loss_fn: loss function matching the base model's outputs.
35 |         lr: learning rate.
36 |         The base model's forward must return a (density_term, cumulative_term) pair.
37 | """
38 | super().__init__()
39 | self.base = base
40 | self.loss_fn = loss_fn
41 | self.lr = lr
42 |
43 | def forward(self, x):
44 | return self.base(*x)
45 |
46 | def common_step(
47 |         self, batch: Tuple[Tuple[torch.Tensor, ...], torch.Tensor]
48 |     ) -> torch.Tensor:
49 | x, e = batch
50 | density_term, cumulative_term = self(x)
51 | loss = self.loss_fn(e.squeeze(), density_term.squeeze(), cumulative_term.squeeze())
52 |
53 |         if torch.isnan(loss):
54 |             raise ValueError("NaN loss encountered; try lowering the learning rate")
55 | return loss
56 |
57 | def training_step(self, batch, *args):
58 | loss = self.common_step(batch)
59 |         self.log("training_loss", loss, on_step=True, on_epoch=True)
60 |         return loss  # Lightning uses the returned loss for the backward pass
60 |
61 | def validation_step(self, batch, *args):
62 | loss = self.common_step(batch)
63 | self.log("val_loss", loss, on_step=True, on_epoch=True)
64 |
65 | def configure_optimizers(self):
66 | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
67 | return {
68 | "optimizer": optimizer,
69 | "lr_scheduler": ReduceLROnPlateau(optimizer, patience=2),
70 | "monitor": "val_loss"
71 | }
72 |
73 | # Cell
74 | def train_model(model, train_dl, valid_dl, epochs):
75 | checkpoint_callback = ModelCheckpoint(
76 | monitor="val_loss",
77 | dirpath="./models/",
78 | filename= "model-{epoch:02d}-{val_loss:.2f}",
79 | save_last=True,
80 | )
81 | early_stopping = EarlyStopping("val_loss")
82 |
83 | trainer = pl.Trainer(
84 | max_epochs=epochs,
85 | gpus=torch.cuda.device_count(),
86 | callbacks=[early_stopping, checkpoint_callback]
87 | )
88 | trainer.fit(model, train_dl, valid_dl)
89 |
90 | # Cell
91 | _text2model_ = {
92 | 'ph': PieceWiseHazard,
93 | 'cox': ProportionalHazard
94 | }
95 |
96 | class ModelHazard:
97 | """
98 | Modelling instantaneous hazard (λ).
99 | parameters:
100 | - model(str): ['ph'|'cox'] which maps to Piecewise Hazard, Cox Proportional Hazard.
101 | - percentiles: list of time percentiles at which time should be broken
102 | - h: list of hidden units (disregarding input units)
103 | - bs: batch size
104 | - epochs: epochs
105 | - lr: learning rate
106 | - beta: l2 penalty on weights
107 | """
108 | def __init__(self, model:str, percentiles=[20, 40, 60, 80], h:tuple=(),
109 | bs:int=128, epochs:int=20, lr:float=1.0, beta:float=0):
110 | self.base_model = _text2model_[model]
111 | self.percentiles = percentiles
112 | self.loss_fn = HazardLoss()
113 | self.h = h
114 | self.bs, self.epochs, self.lr, self.beta = bs, epochs, lr, beta
115 |
116 | def fit(self, df):
117 | breakpoints = get_breakpoints(df, self.percentiles)
118 | train_dl, valid_dl, t_scaler, x_scaler = create_dl(df, breakpoints)
119 | dim = df.shape[1] - 2
120 |         if dim <= 0: raise ValueError("dimensions of x input need to be >0; choose 'ph' instead")
121 |
122 | model_args = {
123 | 'breakpoints': breakpoints,
124 | 't_scaler': t_scaler,
125 | 'x_scaler': x_scaler,
126 | 'h': self.h,
127 | 'dim': dim
128 | }
129 | self.model = GeneralModel(
130 | self.base_model(**model_args),
131 | self.loss_fn,
132 | self.lr
133 | )
134 |
135 | self.breakpoints = breakpoints
136 | self.t_scaler = t_scaler
137 | self.x_scaler = x_scaler
138 | train_model(self.model, train_dl, valid_dl, self.epochs)
139 |
140 | def predict(self, df):
141 | test_dl = create_test_dl(df, self.breakpoints, self.t_scaler, self.x_scaler)
142 | with torch.no_grad():
143 | self.model.eval()
144 | λ, S = [], []
145 | for x in test_dl:
146 | preds = self.model(x)
147 | λ.append(torch.exp(preds[0]))
148 | S.append(torch.exp(-preds[1]))
149 | return torch.cat(λ), torch.cat(S)
150 |
151 | def plot_survival_function(self, *args):
152 | self.model.base.plot_survival_function(*args)
153 |
154 | # Cell
155 | from .models.error_dist import *
156 |
157 | class ModelAFT:
158 | """
159 | Modelling error distribution given inputs x.
160 | parameters:
161 | - dist(str): Univariate distribution of error
162 | - h: list of hidden units (disregarding input units)
163 | - bs: batch size
164 | - epochs: epochs
165 | - lr: learning rate
166 | - beta: l2 penalty on weights
167 | """
168 | def __init__(self, dist:str, h:tuple=(),
169 | bs:int=128, epochs:int=20, lr:float=0.1, beta:float=0):
170 | self.dist = dist
171 | self.loss_fn = AFTLoss()
172 | self.h = h
173 | self.bs, self.epochs, self.lr, self.beta = bs, epochs, lr, beta
174 |
175 | def fit(self, df):
176 | train_dl, valid_dl, self.t_scaler, self.x_scaler = create_dl(df)
177 | dim = df.shape[1] - 2
178 | aft_model = AFTModel(self.dist, dim, self.h)
179 | self.model = GeneralModel(
180 | aft_model,
181 | self.loss_fn,
182 | self.lr
183 | )
184 |
185 | train_model(self.model, train_dl, valid_dl, self.epochs)
186 |
187 | def predict(self, df):
188 | """
189 | Predicts the survival probability
190 | """
191 | test_dl = create_test_dl(df)
192 | with torch.no_grad():
193 | self.model.eval()
194 |             S = []
195 |             for x in test_dl:
196 |                 _, log_icdf = self.model(x)  # second output is log S(t)
197 |                 S.append(torch.exp(log_icdf))
198 |         return torch.cat(S).cpu().numpy()
199 |
200 | def predict_time(self, df):
201 | """
202 | Predicts the mode (not average) time expected for instance.
203 | """
204 | if "t" not in df.columns:
205 | df["t"] = 0
206 | test_dl = create_test_dl(df)
207 | with torch.no_grad():
208 | self.model.eval()
209 | μ = []
210 | for _, x in test_dl:
211 | logμ, _ = self.model.base.get_mode_time(x)
212 | μ.append(torch.exp(logμ))
213 | return self.t_scaler.inverse_transform(torch.cat(μ).cpu().numpy())
214 |
215 | def plot_survival(self, t, x):
216 |         self.model.base.plot_survival_function(t, self.t_scaler, x, self.x_scaler)
--------------------------------------------------------------------------------
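
The `ModelHazard`/`ModelAFT` wrappers above hide the wiring, but any base module whose forward returns a `(density_term, cumulative_term)` pair can be trained directly through `GeneralModel`. Here is a sketch using the pieces defined in this repo (toy data, same assumptions as the data example above):

```python
import numpy as np
import pandas as pd
from torchlife.data import create_dl
from torchlife.losses import AFTLoss
from torchlife.model import GeneralModel, train_model
from torchlife.models.aft import AFTModel

rng = np.random.default_rng(0)
df = pd.DataFrame({"t": rng.exponential(10.0, size=200),
                   "e": rng.integers(0, 2, size=200),
                   "age": rng.normal(30.0, 5.0, size=200)})

train_dl, valid_dl, t_scaler, x_scaler = create_dl(df)
base = AFTModel("Gumbel", input_dim=1)        # one feature column ("age")
model = GeneralModel(base, AFTLoss(), lr=1e-2)
train_model(model, train_dl, valid_dl, epochs=5)
```
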
/docs/Gemfile.lock:
--------------------------------------------------------------------------------
1 | GEM
2 | remote: https://rubygems.org/
3 | specs:
4 | activesupport (4.2.11.1)
5 | i18n (~> 0.7)
6 | minitest (~> 5.1)
7 | thread_safe (~> 0.3, >= 0.3.4)
8 | tzinfo (~> 1.1)
9 | addressable (2.7.0)
10 | public_suffix (>= 2.0.2, < 5.0)
11 | coffee-script (2.4.1)
12 | coffee-script-source
13 | execjs
14 | coffee-script-source (1.11.1)
15 | colorator (1.1.0)
16 | commonmarker (0.17.13)
17 | ruby-enum (~> 0.5)
18 | concurrent-ruby (1.1.5)
19 | dnsruby (1.61.3)
20 | addressable (~> 2.5)
21 | em-websocket (0.5.1)
22 | eventmachine (>= 0.12.9)
23 | http_parser.rb (~> 0.6.0)
24 | ethon (0.12.0)
25 | ffi (>= 1.3.0)
26 | eventmachine (1.2.7)
27 | execjs (2.7.0)
28 | faraday (0.17.0)
29 | multipart-post (>= 1.2, < 3)
30 | ffi (1.11.3)
31 | forwardable-extended (2.6.0)
32 | gemoji (3.0.1)
33 | github-pages (202)
34 | activesupport (= 4.2.11.1)
35 | github-pages-health-check (= 1.16.1)
36 | jekyll (= 3.8.5)
37 | jekyll-avatar (= 0.6.0)
38 | jekyll-coffeescript (= 1.1.1)
39 | jekyll-commonmark-ghpages (= 0.1.6)
40 | jekyll-default-layout (= 0.1.4)
41 | jekyll-feed (= 0.11.0)
42 | jekyll-gist (= 1.5.0)
43 | jekyll-github-metadata (= 2.12.1)
44 | jekyll-mentions (= 1.4.1)
45 | jekyll-optional-front-matter (= 0.3.0)
46 | jekyll-paginate (= 1.1.0)
47 | jekyll-readme-index (= 0.2.0)
48 | jekyll-redirect-from (= 0.14.0)
49 | jekyll-relative-links (= 0.6.0)
50 | jekyll-remote-theme (= 0.4.0)
51 | jekyll-sass-converter (= 1.5.2)
52 | jekyll-seo-tag (= 2.5.0)
53 | jekyll-sitemap (= 1.2.0)
54 | jekyll-swiss (= 0.4.0)
55 | jekyll-theme-architect (= 0.1.1)
56 | jekyll-theme-cayman (= 0.1.1)
57 | jekyll-theme-dinky (= 0.1.1)
58 | jekyll-theme-hacker (= 0.1.1)
59 | jekyll-theme-leap-day (= 0.1.1)
60 | jekyll-theme-merlot (= 0.1.1)
61 | jekyll-theme-midnight (= 0.1.1)
62 | jekyll-theme-minimal (= 0.1.1)
63 | jekyll-theme-modernist (= 0.1.1)
64 | jekyll-theme-primer (= 0.5.3)
65 | jekyll-theme-slate (= 0.1.1)
66 | jekyll-theme-tactile (= 0.1.1)
67 | jekyll-theme-time-machine (= 0.1.1)
68 | jekyll-titles-from-headings (= 0.5.1)
69 | jemoji (= 0.10.2)
70 | kramdown (= 1.17.0)
71 | liquid (= 4.0.0)
72 | listen (= 3.1.5)
73 | mercenary (~> 0.3)
74 | minima (= 2.5.0)
75 | nokogiri (>= 1.10.4, < 2.0)
76 | rouge (= 3.11.0)
77 | terminal-table (~> 1.4)
78 | github-pages-health-check (1.16.1)
79 | addressable (~> 2.3)
80 | dnsruby (~> 1.60)
81 | octokit (~> 4.0)
82 | public_suffix (~> 3.0)
83 | typhoeus (~> 1.3)
84 | html-pipeline (2.12.2)
85 | activesupport (>= 2)
86 | nokogiri (>= 1.4)
87 | http_parser.rb (0.6.0)
88 | i18n (0.9.5)
89 | concurrent-ruby (~> 1.0)
90 | jekyll (3.8.5)
91 | addressable (~> 2.4)
92 | colorator (~> 1.0)
93 | em-websocket (~> 0.5)
94 | i18n (~> 0.7)
95 | jekyll-sass-converter (~> 1.0)
96 | jekyll-watch (~> 2.0)
97 | kramdown (~> 1.14)
98 | liquid (~> 4.0)
99 | mercenary (~> 0.3.3)
100 | pathutil (~> 0.9)
101 | rouge (>= 1.7, < 4)
102 | safe_yaml (~> 1.0)
103 | jekyll-avatar (0.6.0)
104 | jekyll (~> 3.0)
105 | jekyll-coffeescript (1.1.1)
106 | coffee-script (~> 2.2)
107 | coffee-script-source (~> 1.11.1)
108 | jekyll-commonmark (1.3.1)
109 | commonmarker (~> 0.14)
110 | jekyll (>= 3.7, < 5.0)
111 | jekyll-commonmark-ghpages (0.1.6)
112 | commonmarker (~> 0.17.6)
113 | jekyll-commonmark (~> 1.2)
114 | rouge (>= 2.0, < 4.0)
115 | jekyll-default-layout (0.1.4)
116 | jekyll (~> 3.0)
117 | jekyll-feed (0.11.0)
118 | jekyll (~> 3.3)
119 | jekyll-gist (1.5.0)
120 | octokit (~> 4.2)
121 | jekyll-github-metadata (2.12.1)
122 | jekyll (~> 3.4)
123 | octokit (~> 4.0, != 4.4.0)
124 | jekyll-mentions (1.4.1)
125 | html-pipeline (~> 2.3)
126 | jekyll (~> 3.0)
127 | jekyll-optional-front-matter (0.3.0)
128 | jekyll (~> 3.0)
129 | jekyll-paginate (1.1.0)
130 | jekyll-readme-index (0.2.0)
131 | jekyll (~> 3.0)
132 | jekyll-redirect-from (0.14.0)
133 | jekyll (~> 3.3)
134 | jekyll-relative-links (0.6.0)
135 | jekyll (~> 3.3)
136 | jekyll-remote-theme (0.4.0)
137 | addressable (~> 2.0)
138 | jekyll (~> 3.5)
139 | rubyzip (>= 1.2.1, < 3.0)
140 | jekyll-sass-converter (1.5.2)
141 | sass (~> 3.4)
142 | jekyll-seo-tag (2.5.0)
143 | jekyll (~> 3.3)
144 | jekyll-sitemap (1.2.0)
145 | jekyll (~> 3.3)
146 | jekyll-swiss (0.4.0)
147 | jekyll-theme-architect (0.1.1)
148 | jekyll (~> 3.5)
149 | jekyll-seo-tag (~> 2.0)
150 | jekyll-theme-cayman (0.1.1)
151 | jekyll (~> 3.5)
152 | jekyll-seo-tag (~> 2.0)
153 | jekyll-theme-dinky (0.1.1)
154 | jekyll (~> 3.5)
155 | jekyll-seo-tag (~> 2.0)
156 | jekyll-theme-hacker (0.1.1)
157 | jekyll (~> 3.5)
158 | jekyll-seo-tag (~> 2.0)
159 | jekyll-theme-leap-day (0.1.1)
160 | jekyll (~> 3.5)
161 | jekyll-seo-tag (~> 2.0)
162 | jekyll-theme-merlot (0.1.1)
163 | jekyll (~> 3.5)
164 | jekyll-seo-tag (~> 2.0)
165 | jekyll-theme-midnight (0.1.1)
166 | jekyll (~> 3.5)
167 | jekyll-seo-tag (~> 2.0)
168 | jekyll-theme-minimal (0.1.1)
169 | jekyll (~> 3.5)
170 | jekyll-seo-tag (~> 2.0)
171 | jekyll-theme-modernist (0.1.1)
172 | jekyll (~> 3.5)
173 | jekyll-seo-tag (~> 2.0)
174 | jekyll-theme-primer (0.5.3)
175 | jekyll (~> 3.5)
176 | jekyll-github-metadata (~> 2.9)
177 | jekyll-seo-tag (~> 2.0)
178 | jekyll-theme-slate (0.1.1)
179 | jekyll (~> 3.5)
180 | jekyll-seo-tag (~> 2.0)
181 | jekyll-theme-tactile (0.1.1)
182 | jekyll (~> 3.5)
183 | jekyll-seo-tag (~> 2.0)
184 | jekyll-theme-time-machine (0.1.1)
185 | jekyll (~> 3.5)
186 | jekyll-seo-tag (~> 2.0)
187 | jekyll-titles-from-headings (0.5.1)
188 | jekyll (~> 3.3)
189 | jekyll-watch (2.2.1)
190 | listen (~> 3.0)
191 | jemoji (0.10.2)
192 | gemoji (~> 3.0)
193 | html-pipeline (~> 2.2)
194 | jekyll (~> 3.0)
195 | kramdown (1.17.0)
196 | liquid (4.0.0)
197 | listen (3.1.5)
198 | rb-fsevent (~> 0.9, >= 0.9.4)
199 | rb-inotify (~> 0.9, >= 0.9.7)
200 | ruby_dep (~> 1.2)
201 | mercenary (0.3.6)
202 | mini_portile2 (2.4.0)
203 | minima (2.5.0)
204 | jekyll (~> 3.5)
205 | jekyll-feed (~> 0.9)
206 | jekyll-seo-tag (~> 2.1)
207 | minitest (5.13.0)
208 | multipart-post (2.1.1)
209 | nokogiri (1.10.5)
210 | mini_portile2 (~> 2.4.0)
211 | octokit (4.14.0)
212 | sawyer (~> 0.8.0, >= 0.5.3)
213 | pathutil (0.16.2)
214 | forwardable-extended (~> 2.6)
215 | public_suffix (3.1.1)
216 | rb-fsevent (0.10.3)
217 | rb-inotify (0.10.0)
218 | ffi (~> 1.0)
219 | rouge (3.11.0)
220 | ruby-enum (0.7.2)
221 | i18n
222 | ruby_dep (1.5.0)
223 | rubyzip (2.0.0)
224 | safe_yaml (1.0.5)
225 | sass (3.7.4)
226 | sass-listen (~> 4.0.0)
227 | sass-listen (4.0.0)
228 | rb-fsevent (~> 0.9, >= 0.9.4)
229 | rb-inotify (~> 0.9, >= 0.9.7)
230 | sawyer (0.8.2)
231 | addressable (>= 2.3.5)
232 | faraday (> 0.8, < 2.0)
233 | terminal-table (1.8.0)
234 | unicode-display_width (~> 1.1, >= 1.1.1)
235 | thread_safe (0.3.6)
236 | typhoeus (1.3.1)
237 | ethon (>= 0.9.0)
238 | tzinfo (1.2.5)
239 | thread_safe (~> 0.1)
240 | unicode-display_width (1.6.0)
241 |
242 | PLATFORMS
243 | ruby
244 |
245 | DEPENDENCIES
246 | github-pages
247 | jekyll (~> 3.7)
248 |
249 | BUNDLED WITH
250 | 2.0.2
251 |
--------------------------------------------------------------------------------
/docs/Losses.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Losses
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 | summary: "All the losses used in SA."
10 | description: "All the losses used in SA."
11 | nb_path: "95_Losses.ipynb"
12 | ---
13 |
22 |
23 |
24 |
25 | {% raw %}
26 |
27 |
28 |
29 |
30 | {% endraw %}
31 |
32 | {% raw %}
33 |
34 |
35 |
36 |
37 | {% endraw %}
38 |
39 |
40 |
41 |
Suppose that we have:
42 | $$
43 | t_i = \mu + \xi_i
44 | $$
45 | and $\xi_i\sim p(\xi_i|\theta)$. Then $t_i|\mu\sim p_\mu(t_i|\theta)$, where $p_\mu$ is simply the distribution $p(\xi|\theta)$ shifted to the right by $\mu$.
46 |
46 | When the event is censored ($e_i=0$), we know that $t_i < \mu + \xi_i$, since the 'death' offset $\xi_i$ is not fully observed.
47 |
47 | Therefore we may write the likelihood of the data as:
48 | $$
49 | \begin{aligned}
50 | p(t_i, e_i|\mu) =& \left(p(t_i-\mu)\right)^{e_i} \left(\int_{t_i}^\infty p(t-\mu) dt\right)^{1-e_i}\\
51 | \log p(t_i, e_i|\mu) =& e_i \log p(t_i-\mu) + (1 - e_i) \log \left(1 - \int_{-\infty}^{t_i} p(t-\mu) dt \right)
52 | \end{aligned}
53 | $$
54 |
55 |
56 |
57 |
58 | {% raw %}
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
Loss() :: ABC
70 |
71 |
Helper class that provides a standard way to create an ABC using
72 | inheritance.
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 | {% endraw %}
83 |
84 | {% raw %}
85 |
86 |
87 |
88 |
89 | {% endraw %}
90 |
91 | {% raw %}
92 |
93 |
94 |
95 |
96 | {% endraw %}
97 |
98 | {% raw %}
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
AFTLoss() :: Loss
110 |
111 |
Helper class that provides a standard way to create an ABC using
112 | inheritance.
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 | {% endraw %}
123 |
124 | {% raw %}
125 |
126 |
127 |
128 |
129 | {% endraw %}
130 |
131 | {% raw %}
132 |
133 |
134 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
tensor(-0.0400)
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 | {% endraw %}
169 |
170 | {% raw %}
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
aft_loss(log_prob, e)
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 | {% endraw %}
193 |
194 | {% raw %}
195 |
196 |
197 |
198 |
199 | {% endraw %}
200 |
201 |
202 |
203 |
203 | We use the following loss function to train our model. See here for the theory.
204 | $$
205 | -\log L = \sum_{i=1}^N \Lambda(t_i) - d_i \log \lambda(t_i)
206 | $$
207 |
208 |
209 |
210 |
211 | {% raw %}
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
HazardLoss() :: Loss
223 |
224 |
Helper class that provides a standard way to create an ABC using
225 | inheritance.
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 | {% endraw %}
236 |
237 | {% raw %}
238 |
239 |
240 |
241 |
242 | {% endraw %}
243 |
244 | {% raw %}
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
hazard_loss(hazard:Tuple[Tensor, Tensor], e:Tensor)
256 |
257 |
parameters:
258 |
259 | hazard: log hazard and Cumulative hazard
260 | e: torch.Tensor of 1 if death event occured and 0 otherwise
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 | {% endraw %}
272 |
273 | {% raw %}
274 |
275 |
276 |
277 |
278 | {% endraw %}
279 |
280 |
281 |
282 |
283 |
--------------------------------------------------------------------------------
/60_AFT_models.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# default_exp models.aft"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "# Accelerated Failure Time Models\n",
17 | "> AFT Model theory."
18 | ]
19 | },
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {},
23 | "source": [
24 | "We can model the time to failure as:\n",
25 | "$$\n",
26 | "\\log T_i = \\mu + \\xi_i\n",
27 | "$$\n",
28 | "where $\\xi_i\\sim p(\\xi|\\theta)$ and $\\mu$ is the most likely log time of death (the mode of the distribution of $T_i$). We model the log of the death time so that we do not need to restrict $\\mu + \\xi_i$ to be positive.\n",
29 | "\n",
30 | "In the censored case, where $t_i$ is the time where an instance was censored, and $T_i$ is the unobserved time of death, we have:\n",
31 | "$$\n",
32 | "\\begin{aligned}\n",
33 | "\\log T_i &= \\mu(x_i) + \\xi_i > \\log t_i\\\\\n",
34 | "\\therefore \\xi_i &> \\log t_i - \\mu(x_i)\n",
35 | "\\end{aligned}\n",
36 | "$$\n",
37 | "Note that $\\mu$ is a function of the features $x$. The log likelihood of the data ($\\mathcal{D}$), with $e_i$ the event indicator, can then be shown to be:\n",
38 | "$$\n",
39 | "\\begin{aligned}\n",
40 | "\\log p(\\mathcal{D}) = \\sum_{i=1}^N \\mathcal{1}(y_i=1)\\log p(\\xi_i = \\log t_i - \\mu(x_i)) + \\mathcal{1}(y_i=0)\\log p(\\xi_i &> \\log t_i - \\mu(x_i))\n",
41 | "\\end{aligned}\n",
42 | "$$"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "#export\n",
52 | "import matplotlib.pyplot as plt\n",
53 | "import numpy as np\n",
54 | "import torch\n",
55 | "import torch.nn as nn\n",
56 | "import torch.nn.functional as F\n",
57 | "from sklearn.preprocessing import MaxAbsScaler, StandardScaler\n",
58 | "\n",
59 | "from torchlife.models.error_dist import get_distribution"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "# hide\n",
69 | "%load_ext autoreload\n",
70 | "%autoreload 2\n",
71 | "\n",
72 | "%matplotlib inline"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "# export\n",
82 | "class AFTModel(nn.Module):\n",
83 | " \"\"\"\n",
84 | " Accelerated Failure Time model\n",
85 | " parameters:\n",
86 | " - Distribution of which the error is assumed to be\n",
87 | " - dim (optional): input dimensionality of variables\n",
88 | " - h (optional): number of hidden nodes\n",
89 | " \"\"\"\n",
90 | " def __init__(self, distribution:str, input_dim:int, h:tuple=()):\n",
91 | " super().__init__()\n",
92 | " self.logpdf, self.logicdf = get_distribution(distribution)\n",
93 | " self.β = nn.Parameter(-torch.rand(1))\n",
94 | " self.logσ = nn.Parameter(-torch.rand(1))\n",
95 | " \n",
96 | " if input_dim > 0:\n",
97 | " nodes = (input_dim,) + h + (1,)\n",
98 | " self.layers = nn.ModuleList([nn.Linear(a,b, bias=False) \n",
99 | " for a,b in zip(nodes[:-1], nodes[1:])])\n",
100 | "\n",
101 | " self.eps = 1e-7\n",
102 | "\n",
103 | " def get_mode_time(self, x:torch.Tensor=None):\n",
104 | " μ = self.β\n",
105 | " if x is not None:\n",
106 | " for layer in self.layers[:-1]:\n",
107 | " x = F.relu(layer(x))\n",
108 | " μ = self.β + self.layers[-1](x)\n",
109 | "\n",
110 | " σ = torch.exp(self.logσ)\n",
111 | " return μ, σ\n",
112 | " \n",
113 | " def forward(self, t:torch.Tensor, x:torch.Tensor=None):\n",
114 | " μ, σ = self.get_mode_time(x)\n",
115 | " ξ = torch.log(t + self.eps) - μ\n",
116 | " logpdf = self.logpdf(ξ, σ)\n",
117 | " logicdf = self.logicdf(ξ, σ)\n",
118 | " return logpdf, logicdf\n",
119 | " \n",
120 | " def survival_function(self, t:np.array, t_scaler, x:np.array=None, x_scaler=None):\n",
121 | " if len(t.shape) == 1:\n",
122 | " t = t[:,None]\n",
123 | " t = t_scaler.transform(t)\n",
124 | " t = torch.Tensor(t)\n",
125 | " if x is not None:\n",
126 | " if len(x.shape) == 1:\n",
127 | " x = x[None, :]\n",
128 | " if len(x) == 1:\n",
129 | " x = np.repeat(x, len(t), axis=0)\n",
130 | " x = x_scaler.transform(x)\n",
131 | " x = torch.Tensor(x)\n",
132 | " \n",
133 | " with torch.no_grad():\n",
134 | " # calculate cumulative hazard according to above\n",
135 | " _, Λ = self(t, x)\n",
136 | " return torch.exp(Λ)\n",
137 | " \n",
138 | " def plot_survival_function(self, t:np.array, t_scaler, x:np.array=None, x_scaler=None):\n",
139 | " surv_fun = self.survival_function(t, t_scaler, x, x_scaler)\n",
140 | " \n",
141 | " # plot\n",
142 | " plt.figure(figsize=(12,5))\n",
143 | " plt.plot(t, surv_fun)\n",
144 | " plt.xlabel('Time')\n",
145 | " plt.ylabel('Survival Probability')\n",
146 | " plt.show()"
147 | ]
148 | },
149 | {
150 | "cell_type": "markdown",
151 | "metadata": {},
152 | "source": [
153 | "Modelling based on **only** time and (death) event variables:"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": null,
159 | "metadata": {},
160 | "outputs": [],
161 | "source": [
162 | "# from torchlife.data import create_dl\n",
163 | "# import pandas as pd\n",
164 | "\n",
165 | "# url = \"https://raw.githubusercontent.com/vincentarelbundock/Rdatasets/master/csv/survival/flchain.csv\"\n",
166 | "# df = pd.read_csv(url).iloc[:,1:]\n",
167 | "# df.rename(columns={'futime':'t', 'death':'e'}, inplace=True)\n",
168 | "\n",
169 | "# cols = [\"age\", \"sample.yr\", \"kappa\"]\n",
170 | "# db, t_scaler, x_scaler = create_dl(df[['t', 'e'] + cols])\n",
171 | "\n",
172 | "# death_rate = 100*df[\"e\"].mean()\n",
173 | "# print(f\"Death occurs in {death_rate:.2f}% of cases\")\n",
174 | "# print(df.shape)\n",
175 | "# df.head()"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": null,
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "# # hide\n",
185 | "# from fastai.basics import Learner\n",
186 | "# from torchlife.losses import aft_loss\n",
187 | "\n",
188 | "# model = AFTModel(\"Gumbel\", t_scaler, x_scaler)\n",
189 | "# learner = Learner(db, model, loss_func=aft_loss)\n",
190 | "# # wd = 1e-4\n",
191 | "# learner.lr_find(start_lr=1, end_lr=10)\n",
192 | "# learner.recorder.plot()"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": null,
198 | "metadata": {},
199 | "outputs": [],
200 | "source": [
201 | "# learner.fit(epochs=10, lr=2)"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": null,
207 | "metadata": {},
208 | "outputs": [],
209 | "source": [
210 | "# model.plot_survival_function(np.linspace(0, df[\"t\"].max(), 100), df.loc[0, cols])"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": null,
216 | "metadata": {},
217 | "outputs": [
218 | {
219 | "name": "stdout",
220 | "output_type": "stream",
221 | "text": [
222 | "Converted 00_index.ipynb.\n",
223 | "Converted 10_SAT.ipynb.\n",
224 | "Converted 20_KaplanMeier.ipynb.\n",
225 | "Converted 30_overall_model.ipynb.\n",
226 | "Converted 50_hazard.ipynb.\n",
227 | "Converted 55_hazard.PiecewiseHazard.ipynb.\n",
228 | "Converted 59_hazard.Cox.ipynb.\n",
229 | "Converted 60_AFT_models.ipynb.\n",
230 | "Converted 65_AFT_error_distributions.ipynb.\n",
231 | "Converted 80_data.ipynb.\n",
232 | "Converted 90_model.ipynb.\n",
233 | "Converted 95_Losses.ipynb.\n"
234 | ]
235 | }
236 | ],
237 | "source": [
238 | "# hide\n",
239 | "from nbdev.export import *\n",
240 | "notebook2script()"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": null,
246 | "metadata": {},
247 | "outputs": [],
248 | "source": []
249 | }
250 | ],
251 | "metadata": {
252 | "kernelspec": {
253 | "display_name": "Python 3",
254 | "language": "python",
255 | "name": "python3"
256 | }
257 | },
258 | "nbformat": 4,
259 | "nbformat_minor": 2
260 | }
261 |
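
A quick numeric sanity check of the censored log likelihood above: a standalone sketch that swaps in torch.distributions.Normal for $p(\xi)$ and uses made-up times, so torchlife itself is not required:

    import torch
    from torch.distributions import Normal

    # toy data: three observed deaths and one censored instance
    t = torch.tensor([1.0, 2.0, 5.0, 3.0])
    e = torch.tensor([1.0, 1.0, 1.0, 0.0])

    μ, σ = 1.0, 0.5               # constant mode and scale, for illustration
    ξ = torch.log(t) - μ          # error term ξ = log t - μ
    dist = Normal(0.0, σ)

    logpdf = dist.log_prob(ξ)             # observed: log p(ξ = log t - μ)
    logicdf = torch.log(1 - dist.cdf(ξ))  # censored: log p(ξ > log t - μ)

    loglik = (e * logpdf + (1 - e) * logicdf).sum()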
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/docs/data.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Data
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
9 | summary: "Functions used to create pytorch `DataSet`s and `DataLoader`s."
10 | description: "Functions used to create pytorch `DataSet`s and `DataLoader`s."
11 | nb_path: "80_data.ipynb"
12 | ---
13 |
22 |
23 |
24 |
25 | {% raw %}
26 |
27 |
28 |
29 |
30 | {% endraw %}
31 |
32 | {% raw %}
33 |
34 |
35 |
36 |
37 | {% endraw %}
38 |
39 | {% raw %}
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
TestData(t:array, b:Optional[array]=None, x:Optional[array]=None, t_scaler:MaxAbsScaler=None, x_scaler:StandardScaler=None) :: Dataset
51 |
52 |
Create pyTorch Dataset
53 | parameters:
54 |
55 | t: time elapsed
56 | b: (optional) breakpoints where the hazard is different to previous segment of time.
57 | Must include 0 as first element and the maximum time as last element
58 | x: (optional) features
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 | {% endraw %}
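
Minimal usage sketch, mirroring the notebook's hidden test (Data, documented below, adds the event indicator on top of TestData; the torchlife.data module path follows the notebook's default_exp):

    import numpy as np
    from torch.utils.data import DataLoader
    from torchlife.data import Data

    N, D, bs = 100, 3, 64
    t = np.arange(N)                    # elapsed times
    e = np.random.binomial(1, 0.1, N)   # death events
    x = np.random.randn(N, D)           # features

    data = Data(t, e, x=x)              # scales t and x internally
    (time, feats), e_batch = next(iter(DataLoader(data, bs)))
    # time: (64, 1), feats: (64, 3), e_batch: (64, 1)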
70 |
71 | {% raw %}
72 |
73 |
74 |
75 |
76 | {% endraw %}
77 |
78 | {% raw %}
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
Data(t:array, e:array, b:Optional[array]=None, x:Optional[array]=None, t_scaler:MaxAbsScaler=None, x_scaler:StandardScaler=None) :: TestData
90 |
91 |
Create pyTorch Dataset
92 | parameters:
93 |
94 | t: time elapsed
95 | e: (death) event observed. 1 if observed, 0 otherwise.
96 | b: (optional) breakpoints where the hazard is different to previous segment of time.
97 | x: (optional) features
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 | {% endraw %}
109 |
110 | {% raw %}
111 |
112 |
113 |
114 |
115 | {% endraw %}
116 |
117 | {% raw %}
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
TestDataFrame(df:DataFrame, b:Optional[array]=None, t_scaler:MaxAbsScaler=None, x_scaler:StandardScaler=None) :: TestData
129 |
130 |
Wrapper around TestData class that takes in a dataframe instead
131 | parameters:
132 |
df: dataframe. Must have a t (time) column; remaining columns (except e) are used as features.
134 | b: breakpoints of time (optional)
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 | {% endraw %}
146 |
147 | {% raw %}
148 |
149 |
150 |
151 |
152 | {% endraw %}
153 |
154 | {% raw %}
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
DataFrame(df:DataFrame, b:Optional[array]=None, t_scaler:MaxAbsScaler=None, x_scaler:StandardScaler=None) :: Data
166 |
167 |
Wrapper around Data Class that takes in a dataframe instead
168 | parameters:
169 |
df: dataframe. Must have t (time) and e (event) columns, other cols optional.
170 | b: breakpoints of time (optional)
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 | {% endraw %}
247 |
248 | {% raw %}
249 |
250 |
251 |
252 |
253 | {% endraw %}
254 |
255 |
256 |
257 |
Create iterable data loaders/ fastai databunch using above:
258 |
259 |
260 |
261 |
262 | {% raw %}
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
create_dl(df:DataFrame, b:Optional[array]=None, train_size:float=0.8, random_state=None, bs:int=128)
274 |
275 |
Take dataframe, split into train and validation sets
276 | and return pytorch DataLoaders plus the scalers fitted on the training split
277 |
parameters:
278 |
279 | df: pandas dataframe
280 | b (optional): breakpoints of time. Must include 0 as first element and the maximum time as last element
281 | train_size: proportion of the data used for training
282 | bs: batch size
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 | {% endraw %}
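
Usage sketch with the rossi dataset loaded in the notebook's hidden setup cell; note that create_dl returns the two loaders plus the scalers fitted on the training split:

    import pandas as pd
    from torchlife.data import create_dl

    url = ("https://raw.githubusercontent.com/CamDavidsonPilon/"
           "lifelines/master/lifelines/datasets/rossi.csv")
    df = pd.read_csv(url).rename(columns={'week': 't', 'arrest': 'e'})

    train_dl, val_dl, t_scaler, x_scaler = create_dl(df, train_size=0.8, bs=128)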
294 |
295 | {% raw %}
296 |
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 |
create_test_dl(df:DataFrame, b:Optional[array]=None, t_scaler:MaxAbsScaler=None, x_scaler:StandardScaler=None, bs:int=128, only_x:bool=False)
307 |
308 |
Take dataframe and return a pytorch dataloader.
309 | parameters:
310 |
311 | df: pandas dataframe
312 | b: breakpoints of time (optional)
313 | bs: batch size
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 |
324 | {% endraw %}
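
Following on from the create_dl sketch above, scoring new data reuses the scalers learned on the training split; the e column is not needed here, and only_x=True is for frames that carry features but no real times (t is zeroed internally):

    from torchlife.data import create_test_dl

    test_dl = create_test_dl(df.drop(columns='e'),
                             t_scaler=t_scaler, x_scaler=x_scaler)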
325 |
326 | {% raw %}
327 |
328 |
329 |
330 |
331 | {% endraw %}
332 |
333 | {% raw %}
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 |
get_breakpoints(df:DataFrame, percentiles:list=[20, 40, 60, 80])
345 |
346 |
Gives the times at which death events occur at the given percentiles
347 | parameters:
348 | df - must contain columns 't' (time) and 'e' (death event)
349 | percentiles - list of percentages at which breakpoints occur (do not include 0 and 100)
350 |
351 |
352 |
353 |
354 |
355 |
356 |
357 |
358 |
359 | {% endraw %}
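
A small worked example with toy numbers (deaths at the odd times 1, 3, ..., 99, so the breakpoints are the 20/40/60/80th percentiles of those death times, bracketed by 0 and the maximum observed time):

    import numpy as np
    import pandas as pd
    from torchlife.data import get_breakpoints

    df = pd.DataFrame({'t': np.arange(1, 101),
                       'e': np.tile([1, 0], 50)})  # deaths at odd t
    b = get_breakpoints(df)
    # array([  0. ,  20.6,  40.2,  59.8,  79.4, 100. ])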
360 |
361 | {% raw %}
362 |
363 |
364 |
365 |
366 | {% endraw %}
367 |
368 |
369 |
370 |
371 |
--------------------------------------------------------------------------------
/80_data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# default_exp data"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "# Data\n",
17 | "> Functions used to create pytorch `DataSet`s and `DataLoader`s."
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": null,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "# export\n",
27 | "from typing import Optional, Tuple, Union\n",
28 | "\n",
29 | "import multiprocessing as mp\n",
30 | "import numpy as np\n",
31 | "import pandas as pd\n",
32 | "import torch\n",
33 | "from fastai.data_block import DataBunch, DatasetType\n",
34 | "from pandas import DataFrame\n",
35 | "from sklearn.model_selection import train_test_split\n",
36 | "from sklearn.preprocessing import MaxAbsScaler, StandardScaler\n",
37 | "from torch.utils.data import DataLoader, Dataset"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "# hide\n",
47 | "%load_ext autoreload\n",
48 | "%autoreload 2"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "# hide\n",
58 | "import pandas as pd\n",
59 | "\n",
60 | "url = \"https://raw.githubusercontent.com/CamDavidsonPilon/lifelines/master/lifelines/datasets/rossi.csv\"\n",
61 | "df = pd.read_csv(url)\n",
62 | "df.rename(columns={'week':'t', 'arrest':'e'}, inplace=True)"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": null,
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "# export\n",
72 | "class TestData(Dataset):\n",
73 | " \"\"\"\n",
74 | " Create pyTorch Dataset\n",
75 | " parameters:\n",
76 | " - t: time elapsed\n",
77 | " - b: (optional) breakpoints where the hazard is different to previous segment of time. \n",
78 | " **Must include 0 as first element and the maximum time as last element**\n",
79 | " - x: (optional) features\n",
80 | " \"\"\"\n",
81 | " def __init__(self, t:np.array, b:Optional[np.array]=None, x:Optional[np.array]=None, \n",
82 | " t_scaler:MaxAbsScaler=None, x_scaler:StandardScaler=None) -> None:\n",
83 | " super().__init__()\n",
84 | " self.t, self.b, self.x = t, b, x\n",
85 | " self.t_scaler = t_scaler\n",
86 | " self.x_scaler = x_scaler\n",
87 | " if len(t.shape) == 1:\n",
88 | " self.t = t[:,None]\n",
89 | "\n",
90 | " if t_scaler:\n",
91 | " self.t_scaler = t_scaler\n",
92 | " self.t = self.t_scaler.transform(self.t)\n",
93 | " else:\n",
94 | " self.t_scaler = MaxAbsScaler()\n",
95 | " self.t = self.t_scaler.fit_transform(self.t)\n",
96 | " \n",
97 | " if b is not None:\n",
98 | " b = b[1:-1]\n",
99 | " if len(b.shape) == 1:\n",
100 | " b = b[:,None]\n",
101 | " if t_scaler:\n",
102 | " self.b = t_scaler.transform(b).squeeze()\n",
103 | " else:\n",
104 | " self.b = self.t_scaler.transform(b).squeeze()\n",
105 | " \n",
106 | " if x is not None:\n",
107 | " if len(x.shape) == 1:\n",
108 | " self.x = x[None, :]\n",
109 | " if x_scaler:\n",
110 | " self.x_scaler = x_scaler\n",
111 | " self.x = self.x_scaler.transform(self.x)\n",
112 | " else:\n",
113 | " self.x_scaler = StandardScaler()\n",
114 | " self.x = self.x_scaler.fit_transform(self.x)\n",
115 | " \n",
116 | " self.only_x = False\n",
117 | " \n",
118 | " def __len__(self) -> int:\n",
119 | " return len(self.t)\n",
120 | " \n",
121 | " def __getitem__(self, i:int) -> Tuple:\n",
122 | " if self.only_x:\n",
123 | " return torch.Tensor(self.x[i])\n",
124 | " \n",
125 | " time = torch.Tensor(self.t[i])\n",
126 | " \n",
127 | " if self.b is None:\n",
128 | " x_ = (time,)\n",
129 | " else:\n",
130 | " t_section = torch.LongTensor([np.searchsorted(self.b, self.t[i])])\n",
131 | " x_ = (time, t_section.squeeze())\n",
132 | " \n",
133 | " if self.x is not None:\n",
134 | " x = torch.Tensor(self.x[i])\n",
135 | " x_ = x_ + (x,)\n",
136 | " \n",
137 | " return x_"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "# export\n",
147 | "class Data(TestData):\n",
148 | " \"\"\"\n",
149 | " Create pyTorch Dataset\n",
150 | " parameters:\n",
151 | " - t: time elapsed\n",
152 | " - e: (death) event observed. 1 if observed, 0 otherwise.\n",
153 | " - b: (optional) breakpoints where the hazard is different to previous segment of time.\n",
154 | " - x: (optional) features\n",
155 | " \"\"\"\n",
156 | " def __init__(self, t:np.array, e:np.array, b:Optional[np.array]=None, x:Optional[np.array]=None,\n",
157 | " t_scaler:MaxAbsScaler=None, x_scaler:StandardScaler=None) -> None:\n",
158 | " super().__init__(t, b, x, t_scaler, x_scaler)\n",
159 | " self.e = e\n",
160 | " if len(e.shape) == 1:\n",
161 | " self.e = e[:,None]\n",
162 | " \n",
163 | " def __getitem__(self, i) -> Tuple:\n",
164 | " x_ = super().__getitem__(i)\n",
165 | " e = torch.Tensor(self.e[i])\n",
166 | " return x_, e"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "metadata": {},
173 | "outputs": [
174 | {
175 | "data": {
176 | "text/plain": [
177 | "([torch.Size([64, 1]), torch.Size([64, 3])], torch.Size([64, 1]))"
178 | ]
179 | },
180 | "execution_count": null,
181 | "metadata": {},
182 | "output_type": "execute_result"
183 | }
184 | ],
185 | "source": [
186 | "# hide\n",
187 | "np.random.seed(42)\n",
188 | "N = 100\n",
189 | "D = 3\n",
190 | "p = 0.1\n",
191 | "bs = 64\n",
192 | "\n",
193 | "x = np.random.randn(N, D)\n",
194 | "t = np.arange(N)\n",
195 | "e = np.random.binomial(1, p, N)\n",
196 | "\n",
197 | "data = Data(t, e, x=x)\n",
198 | "batch = next(iter(DataLoader(data, bs)))\n",
199 | "assert len(batch[-1]) == bs, (f\"length of batch {len(batch)} is different\" \n",
200 | " f\"to intended batch size {bs}\")\n",
201 | "[b.shape for b in batch[0]], batch[1].shape"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": null,
207 | "metadata": {},
208 | "outputs": [
209 | {
210 | "name": "stdout",
211 | "output_type": "stream",
212 | "text": [
213 | "[torch.Size([64, 1]), torch.Size([64]), torch.Size([64, 3])] torch.Size([64, 1])\n"
214 | ]
215 | }
216 | ],
217 | "source": [
218 | "# hide\n",
219 | "breakpoints = np.array([0, 10, 50, N-1])\n",
220 | "\n",
221 | "data = Data(t, e, breakpoints, x)\n",
222 | "batch2 = next(iter(DataLoader(data, bs)))\n",
223 | "assert len(batch2[-1]) == bs, (f\"length of batch {len(batch2)} is different\" \n",
224 | " f\"to intended batch size {bs}\")\n",
225 | "print([b.shape for b in batch2[0]], batch2[1].shape)\n",
226 | "\n",
227 | "assert torch.all(batch[0][0] == batch2[0][0]), (\"Discrepancy between batch \"\n",
228 | " \"with breakpoints and without\")"
229 | ]
230 | },
231 | {
232 | "cell_type": "code",
233 | "execution_count": null,
234 | "metadata": {},
235 | "outputs": [],
236 | "source": [
237 | "# export\n",
238 | "class TestDataFrame(TestData):\n",
239 | " \"\"\"\n",
240 | " Wrapper around Data Class that takes in a dataframe instead\n",
241 | " parameters:\n",
242 | " - df: dataframe. **Must have t (time) and e (event) columns, other cols optional.\n",
243 | " - b: breakpoints of time (optional)\n",
244 | " \"\"\"\n",
245 | " def __init__(self, df:DataFrame, b:Optional[np.array]=None,\n",
246 | " t_scaler:MaxAbsScaler=None, x_scaler:StandardScaler=None) -> None:\n",
247 | " t = df['t'].values\n",
248 | " remainder = list(set(df.columns) - set(['t', 'e']))\n",
249 | " x = df[remainder].values\n",
250 | " if x.shape[1] == 0:\n",
251 | " x = None\n",
252 | " super().__init__(t, b, x, t_scaler, x_scaler)"
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": null,
258 | "metadata": {},
259 | "outputs": [],
260 | "source": [
261 | "# export\n",
262 | "class DataFrame(Data):\n",
263 | " \"\"\"\n",
264 | " Wrapper around Data Class that takes in a dataframe instead\n",
265 | " parameters:\n",
266 | " - df: dataframe. **Must have t (time) and e (event) columns, other cols optional.\n",
267 | " - b: breakpoints of time (optional)\n",
268 | " \"\"\"\n",
269 | " def __init__(self, df:DataFrame, b:Optional[np.array]=None,\n",
270 | " t_scaler:MaxAbsScaler=None, x_scaler:StandardScaler=None) -> None:\n",
271 | " t = df['t'].values\n",
272 | " e = df['e'].values\n",
273 | " x = df.drop(['t', 'e'], axis=1).values\n",
274 | " if x.shape[1] == 0:\n",
275 | " x = None\n",
276 | " super().__init__(t, e, b, x, t_scaler, x_scaler)"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": null,
282 | "metadata": {},
283 | "outputs": [
284 | {
285 | "data": {
286 | "text/plain": [
287 | "((tensor([0.0101]),), tensor([0.]))"
288 | ]
289 | },
290 | "execution_count": null,
291 | "metadata": {},
292 | "output_type": "execute_result"
293 | }
294 | ],
295 | "source": [
296 | "# hide\n",
297 | "# testing with pandas dataframe\n",
298 | "import pandas as pd\n",
299 | "\n",
300 | "df = pd.DataFrame({'t': t, 'e': e})\n",
301 | "df2 = DataFrame(df)\n",
302 | "df2[1]"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": null,
308 | "metadata": {},
309 | "outputs": [
310 | {
311 | "data": {
312 | "text/plain": [
313 | "((tensor([0.0101]), tensor([ 1.7440, -0.0523, -0.2790])), tensor([0.]))"
314 | ]
315 | },
316 | "execution_count": null,
317 | "metadata": {},
318 | "output_type": "execute_result"
319 | }
320 | ],
321 | "source": [
322 | "# hide\n",
323 | "# testing with x\n",
324 | "new_df = pd.concat([df, pd.DataFrame(x)], axis=1)\n",
325 | "df3 = DataFrame(new_df)\n",
326 | "df3[1]"
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "execution_count": null,
332 | "metadata": {},
333 | "outputs": [
334 | {
335 | "data": {
336 | "text/plain": [
337 | "((tensor([0.0101]), tensor(0), tensor([ 1.7440, -0.0523, -0.2790])),\n",
338 | " tensor([0.]))"
339 | ]
340 | },
341 | "execution_count": null,
342 | "metadata": {},
343 | "output_type": "execute_result"
344 | }
345 | ],
346 | "source": [
347 | "# hide\n",
348 | "# testing with breakpoints\n",
349 | "new_df = pd.concat([df, pd.DataFrame(x)], axis=1)\n",
350 | "df3 = DataFrame(new_df, breakpoints)\n",
351 | "df3[1]"
352 | ]
353 | },
354 | {
355 | "cell_type": "markdown",
356 | "metadata": {},
357 | "source": [
358 | "Create iterable data loaders/ fastai databunch using above:"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": null,
364 | "metadata": {},
365 | "outputs": [],
366 | "source": [
367 | "# export\n",
368 | "def create_dl(df:pd.DataFrame, b:Optional[np.array]=None, train_size:float=0.8, random_state=None, bs:int=128)\\\n",
369 | " -> Tuple[DataBunch, MaxAbsScaler, StandardScaler]:\n",
370 | " \"\"\"\n",
371 | " Take dataframe and split into train, test, val (optional)\n",
372 | " and convert to Fastai databunch\n",
373 | "\n",
374 | " parameters:\n",
375 | " - df: pandas dataframe\n",
376 | " - b(optional): breakpoints of time. **Must include 0 as first element and the maximum time as last element**\n",
377 | " - train_p: training percentage\n",
378 | " - bs: batch size\n",
379 | " \"\"\"\n",
380 | " df.reset_index(drop=True, inplace=True)\n",
381 | " train, val = train_test_split(df, train_size=train_size, stratify=df[\"e\"], random_state=random_state)\n",
382 | " train.reset_index(drop=True, inplace=True)\n",
383 | " val.reset_index(drop=True, inplace=True)\n",
384 | " \n",
385 | " train_ds = DataFrame(train, b)\n",
386 | " val_ds = DataFrame(val, b, train_ds.t_scaler, train_ds.x_scaler)\n",
387 | " \n",
388 | " train_dl = DataLoader(train_ds, bs, shuffle=True, drop_last=False, num_workers=mp.cpu_count())\n",
389 | " val_dl = DataLoader(val_ds, bs, shuffle=False, drop_last=False, num_workers=mp.cpu_count())\n",
390 | " \n",
391 | " return train_dl, val_dl, train_ds.t_scaler, train_ds.x_scaler\n",
392 | "\n",
393 | "def create_test_dl(df:pd.DataFrame, b:Optional[np.array]=None, \n",
394 | " t_scaler:MaxAbsScaler=None, x_scaler:StandardScaler=None, \n",
395 | " bs:int=128, only_x:bool=False) -> DataLoader:\n",
396 | " \"\"\"\n",
397 | " Take dataframe and return a pytorch dataloader.\n",
398 | " parameters:\n",
399 | " - df: pandas dataframe\n",
400 | " - b: breakpoints of time (optional)\n",
401 | " - bs: batch size\n",
402 | " \"\"\"\n",
403 | " if only_x:\n",
404 | " df[\"t\"] = 0\n",
405 | " df.reset_index(drop=True, inplace=True)\n",
406 | " test_ds = TestDataFrame(df, b, t_scaler, x_scaler)\n",
407 | " test_ds.only_x = only_x\n",
408 | " test_dl = DataLoader(test_ds, bs, shuffle=False, drop_last=False, num_workers=mp.cpu_count())\n",
409 | " return test_dl"
410 | ]
411 | },
412 | {
413 | "cell_type": "code",
414 | "execution_count": null,
415 | "metadata": {},
416 | "outputs": [],
417 | "source": [
418 | "# export\n",
419 | "def get_breakpoints(df:DataFrame, percentiles:list=[20, 40, 60, 80]) -> np.array:\n",
420 | " \"\"\"\n",
421 | " Gives the times at which death events occur at given percentile\n",
422 | " parameters:\n",
423 | " df - must contain columns 't' (time) and 'e' (death event)\n",
424 | " percentiles - list of percentages at which breakpoints occur (do not include 0 and 100)\n",
425 | " \"\"\"\n",
426 | " event_times = df.loc[df['e']==1, 't'].values\n",
427 | " breakpoints = np.percentile(event_times, percentiles)\n",
428 | " breakpoints = np.array([0] + breakpoints.tolist() + [df['t'].max()])\n",
429 | " \n",
430 | " return breakpoints"
431 | ]
432 | },
433 | {
434 | "cell_type": "code",
435 | "execution_count": null,
436 | "metadata": {},
437 | "outputs": [
438 | {
439 | "name": "stdout",
440 | "output_type": "stream",
441 | "text": [
442 | "Converted 00_index.ipynb.\n",
443 | "Converted 10_SAT.ipynb.\n",
444 | "Converted 20_KaplanMeier.ipynb.\n",
445 | "Converted 30_overall_model.ipynb.\n",
446 | "Converted 50_hazard.ipynb.\n",
447 | "Converted 55_hazard.PiecewiseHazard.ipynb.\n",
448 | "Converted 59_hazard.Cox.ipynb.\n",
449 | "Converted 60_AFT_models.ipynb.\n",
450 | "Converted 65_AFT_error_distributions.ipynb.\n",
451 | "Converted 80_data.ipynb.\n",
452 | "Converted 90_model.ipynb.\n",
453 | "Converted 95_Losses.ipynb.\n"
454 | ]
455 | }
456 | ],
457 | "source": [
458 | "# hide\n",
459 | "from nbdev.export import *\n",
460 | "notebook2script()"
461 | ]
462 | },
463 | {
464 | "cell_type": "code",
465 | "execution_count": null,
466 | "metadata": {},
467 | "outputs": [],
468 | "source": []
469 | }
470 | ],
471 | "metadata": {
472 | "kernelspec": {
473 | "display_name": "Python 3",
474 | "language": "python",
475 | "name": "python3"
476 | }
477 | },
478 | "nbformat": 4,
479 | "nbformat_minor": 2
480 | }
481 |
--------------------------------------------------------------------------------