├── src └── femr │ ├── py.typed │ ├── __init__.py │ ├── post_etl_pipelines │ ├── __init__.py │ └── stanford.py │ ├── models │ ├── tokenizer │ │ ├── __init__.py │ │ └── flat_tokenizer.py │ ├── rmsnorm.py │ ├── xformers.py │ └── config.py │ ├── labelers │ ├── __init__.py │ ├── omop.py │ └── core.py │ ├── featurizers │ ├── __init__.py │ └── utils.py │ ├── pat_utils.py │ ├── splits.py │ ├── stat_utils.py │ ├── transforms │ ├── __init__.py │ └── stanford.py │ └── ontology.py ├── tests ├── __init__.py ├── models │ ├── test_survival_calculator.py │ └── test_batch_creator.py ├── featurizers │ └── test_OnlineStatistics.py ├── femr_test_tools.py ├── test_ontology.py ├── test_transforms.py └── labelers │ └── test_TimeHorizonEventLabeler.py ├── tutorials ├── trash │ └── .gitkeep ├── input │ ├── synthetic_meds │ │ ├── meds_reader.version │ │ ├── meds_reader.properties │ │ ├── code │ │ │ ├── data │ │ │ ├── zdict │ │ │ └── dictionary │ │ ├── time │ │ │ ├── data │ │ │ └── zdict │ │ ├── subject_id │ │ ├── numeric_value │ │ │ ├── data │ │ │ └── zdict │ │ ├── meds_reader.null_map │ │ │ ├── data │ │ │ └── zdict │ │ ├── metadata │ │ │ └── metadata.json │ │ └── meds_reader.length │ ├── ontology.pkl │ ├── labels.parquet │ ├── synthetic │ │ ├── ontology.pkl │ │ ├── data │ │ │ └── subjects.parquet │ │ └── metadata │ │ │ └── metadata.json │ ├── clmbr_model │ │ ├── model.safetensors │ │ ├── dictionary.msgpack │ │ ├── config.json │ │ └── main_split.csv │ └── motor_model │ │ ├── model.safetensors │ │ ├── dictionary.msgpack │ │ └── main_split.csv ├── synthetic_data_generation │ └── generate_patients.py ├── 1_Ontology.ipynb ├── 2_Labeling.ipynb ├── 5_MOTOR Featurization And Modeling.ipynb ├── 3_Count Featurization And Modeling.ipynb └── 4_Train MOTOR.ipynb ├── .gitattributes ├── _config.yml ├── .flake8 ├── .gitignore ├── .github ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md └── workflows │ ├── tests.yaml │ └── build.yaml ├── .pre-commit-config.yaml ├── .mypy.ini ├── pyproject.toml ├── README.md ├── tools └── stanford │ └── download_bigquery.py └── LICENSE /src/femr/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tutorials/trash/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | tutorials/input/* binary 2 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman 2 | -------------------------------------------------------------------------------- /tutorials/input/synthetic_meds/meds_reader.version: -------------------------------------------------------------------------------- 1 | 2 2 | -------------------------------------------------------------------------------- /src/femr/__init__.py: -------------------------------------------------------------------------------- 1 | from femr._version import __version__ # noqa 2 | -------------------------------------------------------------------------------- 
/src/femr/post_etl_pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | """A collection of common ETL pipelines.""" 2 | -------------------------------------------------------------------------------- /src/femr/models/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .hierarchical_tokenizer import HierarchicalTokenizer -------------------------------------------------------------------------------- /tutorials/input/ontology.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/ontology.pkl -------------------------------------------------------------------------------- /tutorials/input/labels.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/labels.parquet -------------------------------------------------------------------------------- /tutorials/input/synthetic_meds/meds_reader.properties: -------------------------------------------------------------------------------- 1 | code numeric_valuetime -------------------------------------------------------------------------------- /tutorials/input/synthetic/ontology.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/synthetic/ontology.pkl -------------------------------------------------------------------------------- /tutorials/input/synthetic_meds/code/data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/synthetic_meds/code/data -------------------------------------------------------------------------------- /tutorials/input/synthetic_meds/time/data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/synthetic_meds/time/data -------------------------------------------------------------------------------- /tutorials/input/synthetic_meds/code/zdict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/synthetic_meds/code/zdict -------------------------------------------------------------------------------- /tutorials/input/synthetic_meds/subject_id: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/synthetic_meds/subject_id -------------------------------------------------------------------------------- /tutorials/input/synthetic_meds/time/zdict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/synthetic_meds/time/zdict -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | extend-ignore = E203, W503, FI10, FI11, FI12, FI13, FI14, FI15, FI16, FI17, FI58, E301, E302 4 | -------------------------------------------------------------------------------- /tutorials/input/clmbr_model/model.safetensors: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/clmbr_model/model.safetensors -------------------------------------------------------------------------------- /tutorials/input/motor_model/model.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/motor_model/model.safetensors -------------------------------------------------------------------------------- /tutorials/input/clmbr_model/dictionary.msgpack: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/clmbr_model/dictionary.msgpack -------------------------------------------------------------------------------- /tutorials/input/motor_model/dictionary.msgpack: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/motor_model/dictionary.msgpack -------------------------------------------------------------------------------- /tutorials/input/synthetic/data/subjects.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/synthetic/data/subjects.parquet -------------------------------------------------------------------------------- /tutorials/input/synthetic_meds/code/dictionary: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/synthetic_meds/code/dictionary -------------------------------------------------------------------------------- /tutorials/input/synthetic_meds/numeric_value/data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/synthetic_meds/numeric_value/data -------------------------------------------------------------------------------- /tutorials/input/synthetic_meds/numeric_value/zdict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/synthetic_meds/numeric_value/zdict -------------------------------------------------------------------------------- /tutorials/input/synthetic_meds/meds_reader.null_map/data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/synthetic_meds/meds_reader.null_map/data -------------------------------------------------------------------------------- /tutorials/input/synthetic_meds/meds_reader.null_map/zdict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/som-shahlab/femr/HEAD/tutorials/input/synthetic_meds/meds_reader.null_map/zdict -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | __pycache__ 3 | *.so 4 | bazel-* 5 | *.egg-info 6 | ignore/* 7 | *.ipynb_checkpoints* 8 | tutorials/trash 9 | tutorials/tmp_trainer 10 | _version.py 11 | -------------------------------------------------------------------------------- /src/femr/labelers/__init__.py: 
-------------------------------------------------------------------------------- 1 | """A module for generating labels on subject timelines.""" 2 | 3 | from __future__ import annotations 4 | 5 | from femr.labelers.core import * # noqa 6 | -------------------------------------------------------------------------------- /tutorials/input/synthetic/metadata/metadata.json: -------------------------------------------------------------------------------- 1 | {"dataset_name": "femr synthetic datata", "dataset_version": "1", "etl_name": "synthetic data", "etl_version": "1", "code_metadata": {}} -------------------------------------------------------------------------------- /tutorials/input/synthetic_meds/metadata/metadata.json: -------------------------------------------------------------------------------- 1 | {"dataset_name": "femr synthetic datata", "dataset_version": "1", "etl_name": "synthetic data", "etl_version": "1", "code_metadata": {}} -------------------------------------------------------------------------------- /src/femr/featurizers/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from femr.featurizers.core import * # noqa 4 | from femr.featurizers.featurizers import AgeFeaturizer, CountFeaturizer # noqa 5 | -------------------------------------------------------------------------------- /src/femr/pat_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import meds 4 | import meds_reader 5 | 6 | 7 | def get_subject_birthdate(subject: meds_reader.Subject) -> datetime.datetime: 8 | for e in subject.events: 9 | if e.code == meds.birth_code: 10 | return e.time 11 | raise ValueError("Couldn't find subject birthdate -- Subject has no events " + str(subject.events[:5])) 12 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Describe the bug 11 | A clear and concise description of what the bug is. 12 | 13 | ## Steps to reproduce the bug 14 | ```python 15 | # Sample code to reproduce the bug 16 | ``` 17 | 18 | ## Expected results 19 | A clear and concise description of the expected results. 20 | 21 | ## Actual results 22 | Specify the actual results or traceback. 
23 | 24 | ## Environment info 25 | 26 | - `datasets` version: 27 | - Platform: 28 | - Python version: 29 | - PyArrow version: 30 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: '^tutorials/input' 2 | 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.6.0 6 | hooks: 7 | - id: check-yaml 8 | - id: end-of-file-fixer 9 | - id: trailing-whitespace 10 | - repo: https://github.com/psf/black 11 | rev: 24.4.2 12 | hooks: 13 | - id: black 14 | - repo: https://github.com/PyCQA/flake8 15 | rev: 7.0.0 16 | hooks: 17 | - id: flake8 18 | exclude: ^tutorials 19 | - repo: https://github.com/pre-commit/mirrors-mypy 20 | rev: v1.10.0 21 | hooks: 22 | - id: mypy 23 | exclude: ^tutorials 24 | - repo: https://github.com/PyCQA/isort 25 | rev: 5.13.2 26 | hooks: 27 | - id: isort 28 | -------------------------------------------------------------------------------- /.github/workflows/tests.yaml: -------------------------------------------------------------------------------- 1 | name: Run tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | test: 13 | 14 | runs-on: ${{ matrix.os }} 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | os: ["ubuntu-latest"] 19 | python-version: ["3.9", "3.10", "3.11", "3.12"] 20 | 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install pytest 31 | python -m pip install -e . 32 | - name: Python tests 33 | run: | 34 | pytest tests 35 | -------------------------------------------------------------------------------- /tutorials/input/synthetic_meds/meds_reader.length: -------------------------------------------------------------------------------- 1 | 2 |        ! 3 |    4 |  5 |     6 | 7 | 8 |  !   9 | 10 | !      11 | 12 |    13 |     14 | 15 |       ! !      
16 | -------------------------------------------------------------------------------- /src/femr/models/rmsnorm.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py 2 | # coding=utf-8 3 | 4 | from __future__ import absolute_import, division, print_function 5 | 6 | import torch 7 | import torch.nn as nn 8 | import transformers.pytorch_utils 9 | 10 | 11 | # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral 12 | class RMSNorm(nn.Module): 13 | def __init__(self, hidden_size, eps=1e-6): 14 | """ 15 | MistralRMSNorm is equivalent to T5LayerNorm 16 | """ 17 | super().__init__() 18 | self.weight = nn.Parameter(torch.ones(hidden_size)) 19 | self.variance_epsilon = eps 20 | 21 | def forward(self, hidden_states): 22 | input_dtype = hidden_states.dtype 23 | hidden_states = hidden_states.to(torch.float32) 24 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 25 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 26 | return self.weight * hidden_states.to(input_dtype) 27 | 28 | 29 | transformers.pytorch_utils.ALL_LAYERNORM_LAYERS.append(RMSNorm) 30 | -------------------------------------------------------------------------------- /.mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.10 3 | show_column_numbers = True 4 | warn_unused_configs = True 5 | 6 | [mypy-femr] 7 | warn_return_any = True 8 | disallow_untyped_defs = True 9 | disallow_incomplete_defs = True 10 | no_implicit_optional = True 11 | warn_unused_ignores = True 12 | warn_unreachable = True 13 | strict_equality = True 14 | follow_imports = silent 15 | disallow_any_generics = True 16 | 17 | [mypy-numpy.*] 18 | ignore_missing_imports = True 19 | 20 | [mypy-scipy.*] 21 | ignore_missing_imports = True 22 | 23 | [mypy-setuptools.*] 24 | ignore_missing_imports = True 25 | 26 | [mypy-pytest.*] 27 | ignore_missing_imports = True 28 | 29 | [mypy-sklearn.*] 30 | ignore_missing_imports = True 31 | 32 | [mypy-toml.*] 33 | ignore_errors = True 34 | ignore_missing_imports = True 35 | 36 | [mypy-datasets.*] 37 | ignore_missing_imports = True 38 | 39 | [mypy-flash_attn.*] 40 | ignore_missing_imports = True 41 | 42 | [mypy-pyarrow.*] 43 | ignore_missing_imports = True 44 | 45 | [mypy-torch.*] 46 | ignore_missing_imports = True 47 | 48 | [mypy-einops.*] 49 | ignore_missing_imports = True 50 | 51 | [mypy-transformers.*] 52 | ignore_missing_imports = True 53 | 54 | [mypy-msgpack.*] 55 | ignore_missing_imports = True 56 | 57 | [mypy-xformers.*] 58 | ignore_missing_imports = True 59 | -------------------------------------------------------------------------------- /tests/models/test_survival_calculator.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from typing import Set 3 | 4 | from femr_test_tools import DummyEvent, DummySubject 5 | 6 | import femr.models.tasks 7 | 8 | 9 | class DummyOntology: 10 | def get_all_parents(self, code: str) -> Set[str]: 11 | if code == "2": 12 | return {"2", "2_parent"} 13 | else: 14 | return {code} 15 | 16 | 17 | def test_calculator(): 18 | subject = DummySubject( 19 | subject_id=100, 20 | events=[ 21 | DummyEvent(time=datetime.datetime(1990, 1, 10), code="1"), 22 | DummyEvent(time=datetime.datetime(1990, 1, 20), code="2"), 23 | DummyEvent(time=datetime.datetime(1990, 1, 25), code="3"), 24 | 
DummyEvent(time=datetime.datetime(1990, 1, 25), code="1"), 25 | ], 26 | ) 27 | 28 | calculator = femr.models.tasks.SurvivalCalculator(DummyOntology(), subject) 29 | 30 | assert calculator.get_future_events_for_time(datetime.datetime(1990, 1, 1)) == ( 31 | datetime.timedelta(days=24), 32 | { 33 | "1": datetime.timedelta(days=9), 34 | "2": datetime.timedelta(days=19), 35 | "2_parent": datetime.timedelta(days=19), 36 | "3": datetime.timedelta(days=24), 37 | }, 38 | ) 39 | assert calculator.get_future_events_for_time(datetime.datetime(1990, 1, 10)) == ( 40 | datetime.timedelta(days=15), 41 | { 42 | "1": datetime.timedelta(days=15), 43 | "2": datetime.timedelta(days=10), 44 | "2_parent": datetime.timedelta(days=10), 45 | "3": datetime.timedelta(days=15), 46 | }, 47 | ) 48 | assert calculator.get_future_events_for_time(datetime.datetime(1990, 1, 20)) == ( 49 | datetime.timedelta(days=5), 50 | {"1": datetime.timedelta(days=5), "3": datetime.timedelta(days=5)}, 51 | ) 52 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 69.0", "setuptools-scm>=8.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "femr" 7 | description = "Framework for Electronic Medical Records. A python package for building models using EHR data." 8 | readme = "README.md" 9 | dependencies = [ 10 | "numpy >= 1.22", 11 | "scipy >= 1.6", 12 | "scikit-learn >= 0.24", 13 | "tqdm >= 4.60.0", 14 | "zstandard >= 0.18", 15 | "icecream == 2.1.3", 16 | "nptyping == 2.4.1", 17 | "msgpack >= 1.0.5", 18 | "meds == 0.3.3", 19 | "meds_reader >= 0.1.3", 20 | "torch >= 2.1.2", 21 | "transformers >= 4.25", 22 | "datasets >= 2.15", 23 | "polars >= 0.20", 24 | "dill >= 0.3.7", 25 | "pandas >= 2.2", 26 | "pandas-stubs >= 2.2", 27 | "types-tqdm >= 4.60.0", 28 | "xformers >= 0.0.28", 29 | "torch_hawk", 30 | "accelerate >= 0.26.0", 31 | ] 32 | requires-python=">3.9" 33 | dynamic = ["version"] 34 | 35 | [tool.setuptools_scm] 36 | version_file = "src/femr/_version.py" 37 | 38 | [project.scripts] 39 | 40 | femr_stanford_omop_fixer = "femr.post_etl_pipelines.stanford:femr_stanford_omop_fixer_program" 41 | 42 | [project.optional-dependencies] 43 | build = [ 44 | "pytest >= 5.2", 45 | "flake8-future-import >= 0.4.6", 46 | "black >= 19.10b0", 47 | "isort >= 5.3.2", 48 | "mypy >= 0.782", 49 | "flake8 >= 3.8.3", 50 | "sphinx >= 3.2.1", 51 | "sphinx-rtd-theme >= 0.5.0", 52 | "sphinx-autoapi >= 1.5.1", 53 | "torchtyping == 0.1.4", 54 | ] 55 | 56 | [tool.isort] 57 | multi_line_output = 3 58 | include_trailing_comma = true 59 | force_grid_wrap = 0 60 | use_parentheses = true 61 | ensure_newline_before_comments = true 62 | line_length = 120 63 | 64 | [tool.black] 65 | line_length = 120 66 | target_version = ['py310'] 67 | 68 | [tool.pydocstyle] 69 | match = "src/.*\\.py" 70 | 71 | 72 | [tool.pytest.ini_options] 73 | pythonpath = [ 74 | "tests" 75 | ] 76 | -------------------------------------------------------------------------------- /src/femr/splits.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import csv 4 | import dataclasses 5 | import hashlib 6 | import struct 7 | from typing import List 8 | 9 | 10 | @dataclasses.dataclass 11 | class SubjectSplit: 12 | train_subject_ids: List[int] 13 | test_subject_ids: List[int] 14 | 15 | def save_to_csv(self, fname: str): 16 | with 
open(fname, "w") as f: 17 | writer = csv.DictWriter(f, ("subject_id", "split_name")) 18 | writer.writeheader() 19 | for train in self.train_subject_ids: 20 | writer.writerow({"subject_id": train, "split_name": "train"}) 21 | for test in self.test_subject_ids: 22 | writer.writerow({"subject_id": test, "split_name": "test"}) 23 | 24 | @classmethod 25 | def load_from_csv(cls, fname: str): 26 | train_subject_ids: List[int] = [] 27 | test_subject_ids: List[int] = [] 28 | with open(fname, "r") as f: 29 | for row in csv.DictReader(f): 30 | if row["split_name"] == "train": 31 | train_subject_ids.append(int(row["subject_id"])) 32 | else: 33 | test_subject_ids.append(int(row["subject_id"])) 34 | 35 | return SubjectSplit(train_subject_ids=train_subject_ids, test_subject_ids=test_subject_ids) 36 | 37 | 38 | def generate_hash_split(subject_ids: List[int], seed: int, frac_test: float = 0.15) -> SubjectSplit: 39 | train_subject_ids = [] 40 | test_subject_ids = [] 41 | 42 | for subject_id in subject_ids: 43 | # Convert the integer to bytes 44 | value_bytes = struct.pack(">q", seed) + struct.pack(">q", subject_id) 45 | 46 | # Calculate SHA-256 hash 47 | sha256_hash = hashlib.sha256(value_bytes).hexdigest() 48 | 49 | # Convert the hexadecimal hash to an integer 50 | hash_int = int(sha256_hash, 16) 51 | 52 | # Take the modulus 53 | result = hash_int % (2**16) 54 | if result <= frac_test * (2**16): 55 | test_subject_ids.append(subject_id) 56 | else: 57 | train_subject_ids.append(subject_id) 58 | 59 | return SubjectSplit(train_subject_ids=train_subject_ids, test_subject_ids=test_subject_ids) 60 | -------------------------------------------------------------------------------- /src/femr/models/xformers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import xformers.ops 3 | 4 | # From https://github.com/facebookresearch/xformers/blob/042abc8aa47d1f5bcc2e82df041811de218924ba/tests/test_mem_eff_attention.py#L511 # noqa 5 | 6 | 7 | def ref_attention(q, k, v, attn_bias=None, drop_mask=None, p=0.0, scale=None): 8 | q = q.float() 9 | k = k.float() 10 | v = v.float() 11 | 12 | scale = scale if scale is not None else (1 / q.shape[-1] ** 0.5) 13 | q = q * scale 14 | 15 | attn = q @ k.transpose(-2, -1) 16 | if attn_bias is not None: 17 | if isinstance(attn_bias, xformers.ops.AttentionBias): 18 | # Always create in B,H,Mq,Mk format 19 | attn_bias_tensor = attn_bias.materialize( 20 | (q.shape[0], 1, q.shape[1], k.shape[1]), 21 | device=q.device, 22 | dtype=torch.float32, 23 | ) 24 | else: 25 | attn_bias_tensor = attn_bias 26 | if attn_bias_tensor.ndim == 4: 27 | assert q.shape[0] == attn_bias_tensor.shape[0] * attn_bias_tensor.shape[1] 28 | attn_bias_tensor = attn_bias_tensor.reshape([-1, *attn_bias_tensor.shape[2:]]) 29 | attn = attn + attn_bias_tensor.float() 30 | attn = attn.softmax(-1) 31 | if drop_mask is not None: 32 | attn = attn * (drop_mask / (1 - p)) 33 | return attn @ v 34 | 35 | 36 | def ref_attention_bmhk(q, k, v, attn_bias, scale=None) -> torch.Tensor: 37 | assert q.ndim == 4 38 | 39 | def T(t): 40 | return t.permute((0, 2, 1, 3)).reshape([t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]) 41 | 42 | if isinstance(attn_bias, xformers.ops.AttentionBias): 43 | attn_bias = attn_bias.materialize( 44 | (q.shape[0], q.shape[2], q.shape[1], k.shape[1]), 45 | device=q.device, 46 | dtype=torch.float32, 47 | ).reshape([q.shape[0] * q.shape[2], q.shape[1], k.shape[1]]) 48 | out = ref_attention(T(q), T(k), T(v), attn_bias, scale=scale) 49 | out = 
out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]]) 50 | return out.permute((0, 2, 1, 3)) 51 | 52 | 53 | def memory_efficient_attention_wrapper(q, k, v, attn_bias): 54 | if q.device.type == "cpu": 55 | return ref_attention_bmhk(q, k, v, attn_bias) 56 | else: 57 | return xformers.ops.memory_efficient_attention(q, k, v, attn_bias) 58 | -------------------------------------------------------------------------------- /tutorials/input/clmbr_model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "FEMRModel" 4 | ], 5 | "task_config": { 6 | "_name_or_path": "", 7 | "add_cross_attention": false, 8 | "architectures": null, 9 | "bad_words_ids": null, 10 | "begin_suppress_tokens": null, 11 | "bos_token_id": null, 12 | "chunk_size_feed_forward": 0, 13 | "cross_attention_hidden_size": null, 14 | "decoder_start_token_id": null, 15 | "diversity_penalty": 0.0, 16 | "do_sample": false, 17 | "early_stopping": false, 18 | "encoder_no_repeat_ngram_size": 0, 19 | "eos_token_id": null, 20 | "exponential_decay_length_penalty": null, 21 | "finetuning_task": null, 22 | "forced_bos_token_id": null, 23 | "forced_eos_token_id": null, 24 | "id2label": { 25 | "0": "LABEL_0", 26 | "1": "LABEL_1" 27 | }, 28 | "is_decoder": false, 29 | "is_encoder_decoder": false, 30 | "label2id": { 31 | "LABEL_0": 0, 32 | "LABEL_1": 1 33 | }, 34 | "length_penalty": 1.0, 35 | "max_length": 20, 36 | "min_length": 0, 37 | "model_type": "", 38 | "no_repeat_ngram_size": 0, 39 | "num_beam_groups": 1, 40 | "num_beams": 1, 41 | "num_return_sequences": 1, 42 | "output_attentions": false, 43 | "output_hidden_states": false, 44 | "output_scores": false, 45 | "pad_token_id": null, 46 | "prefix": null, 47 | "problem_type": null, 48 | "pruned_heads": {}, 49 | "remove_invalid_values": false, 50 | "repetition_penalty": 1.0, 51 | "return_dict": true, 52 | "return_dict_in_generate": false, 53 | "sep_token_id": null, 54 | "suppress_tokens": null, 55 | "task_kwargs": { 56 | "clmbr_vocab_size": 64 57 | }, 58 | "task_specific_params": null, 59 | "task_type": "clmbr", 60 | "temperature": 1.0, 61 | "tf_legacy_loss": false, 62 | "tie_encoder_decoder": false, 63 | "tie_word_embeddings": true, 64 | "tokenizer_class": null, 65 | "top_k": 50, 66 | "top_p": 1.0, 67 | "torch_dtype": null, 68 | "torchscript": false, 69 | "typical_p": 1.0, 70 | "use_bfloat16": false 71 | }, 72 | "torch_dtype": "float32", 73 | "transformer_config": { 74 | "hidden_size": 64, 75 | "intermediate_size": 128, 76 | "model_type": "", 77 | "n_heads": 8, 78 | "n_layers": 2, 79 | "vocab_size": 128 80 | }, 81 | "transformers_version": "4.39.0" 82 | } 83 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI 2 | 3 | on: push 4 | 5 | jobs: 6 | build: 7 | name: Build distribution 📦 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: "3.x" 16 | - name: Install pypa/build 17 | run: >- 18 | python3 -m 19 | pip install 20 | build 21 | --user 22 | - name: Build a binary wheel and a source tarball 23 | run: python3 -m build 24 | - name: Store the distribution packages 25 | uses: actions/upload-artifact@v4 26 | with: 27 | name: python-package-distributions 28 | path: dist/ 29 | 30 | publish-to-pypi: 31 | 
name: >- 32 | Publish Python 🐍 distribution 📦 to PyPI 33 | if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes 34 | needs: 35 | - build 36 | runs-on: ubuntu-latest 37 | environment: 38 | name: pypi 39 | url: https://pypi.org/p/femr # Replace with your PyPI project name 40 | permissions: 41 | id-token: write # IMPORTANT: mandatory for trusted publishing 42 | 43 | steps: 44 | - name: Download all the dists 45 | uses: actions/download-artifact@v4 46 | with: 47 | name: python-package-distributions 48 | path: dist/ 49 | - name: Publish distribution 📦 to PyPI 50 | uses: pypa/gh-action-pypi-publish@release/v1 51 | 52 | github-release: 53 | name: >- 54 | Sign the Python 🐍 distribution 📦 with Sigstore 55 | and upload them to GitHub Release 56 | needs: 57 | - publish-to-pypi 58 | runs-on: ubuntu-latest 59 | 60 | permissions: 61 | contents: write # IMPORTANT: mandatory for making GitHub Releases 62 | id-token: write # IMPORTANT: mandatory for sigstore 63 | 64 | steps: 65 | - name: Download all the dists 66 | uses: actions/download-artifact@v4 67 | with: 68 | name: python-package-distributions 69 | path: dist/ 70 | - name: Sign the dists with Sigstore 71 | uses: sigstore/gh-action-sigstore-python@v3.0.1 72 | with: 73 | inputs: >- 74 | ./dist/*.tar.gz 75 | ./dist/*.whl 76 | - name: Create GitHub Release 77 | env: 78 | GITHUB_TOKEN: ${{ github.token }} 79 | run: >- 80 | gh release create 81 | '${{ github.ref_name }}' 82 | --repo '${{ github.repository }}' 83 | --notes "" 84 | - name: Upload artifact signatures to GitHub Release 85 | env: 86 | GITHUB_TOKEN: ${{ github.token }} 87 | # Upload to GitHub Release using the `gh` CLI. 88 | # `dist/` contains the built packages, and the 89 | # sigstore-produced signatures and certificates. 90 | run: >- 91 | gh release upload 92 | '${{ github.ref_name }}' dist/** 93 | --repo '${{ github.repository }}' 94 | -------------------------------------------------------------------------------- /src/femr/stat_utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import math 3 | import random 4 | 5 | 6 | @dataclasses.dataclass 7 | class OnlineStatistics: 8 | """ 9 | A class for computing online statistics such as mean and variance. 10 | From https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. 11 | """ 12 | 13 | count: float 14 | current_mean: float 15 | variance: float 16 | 17 | def __init__(self): 18 | """ 19 | Initialize online statistics. Optionally takes the results of self.to_dict() to initialize from old data. 20 | """ 21 | self.count = 0 22 | self.current_mean = 0 23 | self.variance = 0 24 | 25 | def add(self, weight: float, value: float) -> None: 26 | """ 27 | Add an observation to the calculation. 28 | """ 29 | self.count += weight 30 | delta = value - self.current_mean 31 | self.current_mean += delta * (weight / self.count) 32 | delta2 = value - self.current_mean 33 | 34 | self.variance += weight * (delta * delta2) 35 | 36 | def mean(self) -> float: 37 | """ 38 | Return the current mean. 39 | """ 40 | return self.current_mean 41 | 42 | def standard_deviation(self) -> float: 43 | """ 44 | Return the current standard devation. 
45 | """ 46 | return math.sqrt(self.variance / self.count) 47 | 48 | def combine(self, other) -> None: 49 | if self.count == 0: 50 | self.count = other.count 51 | self.current_mean = other.current_mean 52 | self.variance = other.variance 53 | elif other.count != 0: 54 | total = self.count + other.count 55 | delta = other.current_mean - self.current_mean 56 | new_mean = self.current_mean + delta * (other.count / total) 57 | new_variance = self.variance + other.variance + (delta * self.count) * (delta * other.count) / total 58 | 59 | self.count = total 60 | self.current_mean = new_mean 61 | self.variance = new_variance 62 | 63 | 64 | class ReservoirSampler: 65 | def __init__(self, size): 66 | self.total_weight = 0 67 | self.size = size 68 | self.samples = [] 69 | 70 | def add(self, sample, weight): 71 | self.total_weight += weight 72 | if len(self.samples) < self.size: 73 | self.samples.append(sample) 74 | if len(self.samples) == self.size: 75 | self.j = random.random() 76 | self.p_none = 1 77 | else: 78 | prob = weight / self.total_weight 79 | self.j -= prob * self.p_none 80 | self.p_none = self.p_none * (1 - prob) 81 | 82 | if self.j <= 0: 83 | self.samples[random.randint(0, self.size - 1)] = sample 84 | self.j = random.random() 85 | self.p_none = 1 86 | 87 | def combine(self, other): 88 | for val in other.samples: 89 | self.add(val, other.total_weight / len(other.samples)) 90 | -------------------------------------------------------------------------------- /tutorials/input/clmbr_model/main_split.csv: -------------------------------------------------------------------------------- 1 | subject_id,split_name 2 | 0,train 3 | 1,train 4 | 2,train 5 | 4,train 6 | 6,train 7 | 7,train 8 | 10,train 9 | 11,train 10 | 12,train 11 | 13,train 12 | 14,train 13 | 15,train 14 | 18,train 15 | 19,train 16 | 20,train 17 | 21,train 18 | 22,train 19 | 23,train 20 | 24,train 21 | 25,train 22 | 26,train 23 | 27,train 24 | 28,train 25 | 29,train 26 | 30,train 27 | 31,train 28 | 33,train 29 | 36,train 30 | 37,train 31 | 38,train 32 | 39,train 33 | 40,train 34 | 42,train 35 | 44,train 36 | 45,train 37 | 46,train 38 | 47,train 39 | 49,train 40 | 50,train 41 | 51,train 42 | 52,train 43 | 53,train 44 | 54,train 45 | 55,train 46 | 56,train 47 | 57,train 48 | 58,train 49 | 59,train 50 | 61,train 51 | 62,train 52 | 63,train 53 | 64,train 54 | 65,train 55 | 66,train 56 | 67,train 57 | 69,train 58 | 70,train 59 | 71,train 60 | 73,train 61 | 74,train 62 | 75,train 63 | 76,train 64 | 77,train 65 | 79,train 66 | 80,train 67 | 82,train 68 | 83,train 69 | 84,train 70 | 85,train 71 | 86,train 72 | 87,train 73 | 88,train 74 | 89,train 75 | 90,train 76 | 91,train 77 | 92,train 78 | 93,train 79 | 94,train 80 | 95,train 81 | 96,train 82 | 97,train 83 | 98,train 84 | 100,train 85 | 101,train 86 | 102,train 87 | 103,train 88 | 104,train 89 | 105,train 90 | 106,train 91 | 107,train 92 | 108,train 93 | 109,train 94 | 110,train 95 | 112,train 96 | 113,train 97 | 114,train 98 | 115,train 99 | 116,train 100 | 117,train 101 | 118,train 102 | 120,train 103 | 121,train 104 | 122,train 105 | 123,train 106 | 124,train 107 | 125,train 108 | 126,train 109 | 127,train 110 | 128,train 111 | 131,train 112 | 132,train 113 | 133,train 114 | 134,train 115 | 135,train 116 | 136,train 117 | 137,train 118 | 138,train 119 | 139,train 120 | 141,train 121 | 142,train 122 | 143,train 123 | 144,train 124 | 146,train 125 | 147,train 126 | 148,train 127 | 149,train 128 | 150,train 129 | 151,train 130 | 152,train 131 | 153,train 132 | 154,train 133 | 155,train 
134 | 156,train 135 | 157,train 136 | 158,train 137 | 159,train 138 | 160,train 139 | 161,train 140 | 162,train 141 | 163,train 142 | 165,train 143 | 166,train 144 | 168,train 145 | 169,train 146 | 171,train 147 | 172,train 148 | 173,train 149 | 174,train 150 | 177,train 151 | 178,train 152 | 179,train 153 | 180,train 154 | 181,train 155 | 182,train 156 | 183,train 157 | 184,train 158 | 185,train 159 | 186,train 160 | 187,train 161 | 188,train 162 | 189,train 163 | 190,train 164 | 191,train 165 | 192,train 166 | 193,train 167 | 195,train 168 | 196,train 169 | 197,train 170 | 198,train 171 | 199,train 172 | 3,test 173 | 5,test 174 | 8,test 175 | 9,test 176 | 16,test 177 | 17,test 178 | 32,test 179 | 34,test 180 | 35,test 181 | 41,test 182 | 43,test 183 | 48,test 184 | 60,test 185 | 68,test 186 | 72,test 187 | 78,test 188 | 81,test 189 | 99,test 190 | 111,test 191 | 119,test 192 | 129,test 193 | 130,test 194 | 140,test 195 | 145,test 196 | 164,test 197 | 167,test 198 | 170,test 199 | 175,test 200 | 176,test 201 | 194,test 202 | -------------------------------------------------------------------------------- /tutorials/input/motor_model/main_split.csv: -------------------------------------------------------------------------------- 1 | subject_id,split_name 2 | 0,train 3 | 1,train 4 | 2,train 5 | 4,train 6 | 6,train 7 | 7,train 8 | 10,train 9 | 11,train 10 | 12,train 11 | 13,train 12 | 14,train 13 | 15,train 14 | 18,train 15 | 19,train 16 | 20,train 17 | 21,train 18 | 22,train 19 | 23,train 20 | 24,train 21 | 25,train 22 | 26,train 23 | 27,train 24 | 28,train 25 | 29,train 26 | 30,train 27 | 31,train 28 | 33,train 29 | 36,train 30 | 37,train 31 | 38,train 32 | 39,train 33 | 40,train 34 | 42,train 35 | 44,train 36 | 45,train 37 | 46,train 38 | 47,train 39 | 49,train 40 | 50,train 41 | 51,train 42 | 52,train 43 | 53,train 44 | 54,train 45 | 55,train 46 | 56,train 47 | 57,train 48 | 58,train 49 | 59,train 50 | 61,train 51 | 62,train 52 | 63,train 53 | 64,train 54 | 65,train 55 | 66,train 56 | 67,train 57 | 69,train 58 | 70,train 59 | 71,train 60 | 73,train 61 | 74,train 62 | 75,train 63 | 76,train 64 | 77,train 65 | 79,train 66 | 80,train 67 | 82,train 68 | 83,train 69 | 84,train 70 | 85,train 71 | 86,train 72 | 87,train 73 | 88,train 74 | 89,train 75 | 90,train 76 | 91,train 77 | 92,train 78 | 93,train 79 | 94,train 80 | 95,train 81 | 96,train 82 | 97,train 83 | 98,train 84 | 100,train 85 | 101,train 86 | 102,train 87 | 103,train 88 | 104,train 89 | 105,train 90 | 106,train 91 | 107,train 92 | 108,train 93 | 109,train 94 | 110,train 95 | 112,train 96 | 113,train 97 | 114,train 98 | 115,train 99 | 116,train 100 | 117,train 101 | 118,train 102 | 120,train 103 | 121,train 104 | 122,train 105 | 123,train 106 | 124,train 107 | 125,train 108 | 126,train 109 | 127,train 110 | 128,train 111 | 131,train 112 | 132,train 113 | 133,train 114 | 134,train 115 | 135,train 116 | 136,train 117 | 137,train 118 | 138,train 119 | 139,train 120 | 141,train 121 | 142,train 122 | 143,train 123 | 144,train 124 | 146,train 125 | 147,train 126 | 148,train 127 | 149,train 128 | 150,train 129 | 151,train 130 | 152,train 131 | 153,train 132 | 154,train 133 | 155,train 134 | 156,train 135 | 157,train 136 | 158,train 137 | 159,train 138 | 160,train 139 | 161,train 140 | 162,train 141 | 163,train 142 | 165,train 143 | 166,train 144 | 168,train 145 | 169,train 146 | 171,train 147 | 172,train 148 | 173,train 149 | 174,train 150 | 177,train 151 | 178,train 152 | 179,train 153 | 180,train 154 | 181,train 155 | 182,train 156 | 
183,train 157 | 184,train 158 | 185,train 159 | 186,train 160 | 187,train 161 | 188,train 162 | 189,train 163 | 190,train 164 | 191,train 165 | 192,train 166 | 193,train 167 | 195,train 168 | 196,train 169 | 197,train 170 | 198,train 171 | 199,train 172 | 3,test 173 | 5,test 174 | 8,test 175 | 9,test 176 | 16,test 177 | 17,test 178 | 32,test 179 | 34,test 180 | 35,test 181 | 41,test 182 | 43,test 183 | 48,test 184 | 60,test 185 | 68,test 186 | 72,test 187 | 78,test 188 | 81,test 189 | 99,test 190 | 111,test 191 | 119,test 192 | 129,test 193 | 130,test 194 | 140,test 195 | 145,test 196 | 164,test 197 | 167,test 198 | 170,test 199 | 175,test 200 | 176,test 201 | 194,test 202 | -------------------------------------------------------------------------------- /src/femr/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | """A collection of general use transforms.""" 2 | 3 | import datetime 4 | from typing import Any, Callable, Dict, List, Optional, Set, Tuple 5 | 6 | import meds_reader 7 | import meds_reader.transform 8 | 9 | 10 | def remove_nones( 11 | subject: meds_reader.transform.MutableSubject, 12 | do_not_apply_to_filter: Optional[Callable[[meds_reader.Event], bool]] = None, 13 | ) -> meds_reader.transform.MutableSubject: 14 | """Remove duplicate codes w/in same day if duplicate code has None value. 15 | 16 | There is no point having a NONE value in a timeline when we have an actual value within the same day. 17 | 18 | This removes those unnecessary NONE values. 19 | """ 20 | do_not_apply_to_filter = do_not_apply_to_filter or (lambda _: False) 21 | has_value: Set[Tuple[str, datetime.date]] = set() 22 | 23 | for event in subject.events: 24 | value = (event.numeric_value, event.text_value) 25 | if any(v is not None for v in value): 26 | has_value.add((event.code, event.time.date())) 27 | 28 | new_events: List[meds_reader.transform.MutableEvent] = [] 29 | for event in subject.events: 30 | value = (event.numeric_value, event.text_value) 31 | if ( 32 | all(v is None for v in value) 33 | and (event.code, event.time.date()) in has_value 34 | and not do_not_apply_to_filter(event) 35 | ): 36 | # Skip this event as already in there 37 | continue 38 | 39 | new_events.append(event) 40 | 41 | subject.events = new_events 42 | subject.events.sort(key=lambda a: a.time) 43 | 44 | return subject 45 | 46 | 47 | def delta_encode( 48 | subject: meds_reader.transform.MutableSubject, 49 | do_not_apply_to_filter: Optional[Callable[[meds_reader.Event], bool]] = None, 50 | ) -> meds_reader.transform.MutableSubject: 51 | """Delta encodes the subject. 52 | 53 | The idea behind delta encoding is that if we get duplicate values within a short amount of time 54 | (1 day for this code), there is not much point retaining the duplicate. 55 | 56 | This code removes all *sequential* duplicates within the same day. 
57 | """ 58 | do_not_apply_to_filter = do_not_apply_to_filter or (lambda _: False) 59 | 60 | last_value: Dict[Tuple[str, datetime.date], Any] = {} 61 | 62 | new_events: List[meds_reader.transform.MutableEvent] = [] 63 | for event in subject.events: 64 | key = (event.code, event.time.date()) 65 | value = (event.numeric_value, event.text_value) 66 | if key in last_value and last_value[key] == value and not do_not_apply_to_filter(event): 67 | continue 68 | last_value[key] = value 69 | new_events.append(event) 70 | 71 | subject.events = new_events 72 | subject.events.sort(key=lambda a: a.time) 73 | 74 | return subject 75 | 76 | 77 | def fix_events(subject: meds_reader.transform.MutableSubject) -> meds_reader.transform.MutableSubject: 78 | """After a series of transformations, sometimes the subject structure gets a bit messed up. 79 | The usual issues are either duplicate event times or missorted events. 80 | 81 | This does a final cleanup pass to meet the MEDS requirements. 82 | """ 83 | subject.events = sorted(subject.events, key=lambda a: a.time) 84 | 85 | return subject 86 | -------------------------------------------------------------------------------- /tutorials/synthetic_data_generation/generate_patients.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import os 5 | import pickle 6 | import random 7 | 8 | import jsonschema 9 | import meds 10 | import meds_reader 11 | import pyarrow 12 | import pyarrow.parquet 13 | 14 | import femr.ontology 15 | import femr.transforms 16 | 17 | if __name__ == "__main__": 18 | 19 | parser = argparse.ArgumentParser(prog="generate_subjects", description="Create synthetic data") 20 | parser.add_argument("athena", type=str) 21 | parser.add_argument("destination", type=str) 22 | args = parser.parse_args() 23 | 24 | random.seed(4533) 25 | 26 | def get_random_subject(subject_id): 27 | epoch = datetime.datetime(1990, 1, 1) 28 | birth = epoch + datetime.timedelta(days=random.randint(100, 1000)) 29 | current_date = birth 30 | 31 | gender = "Gender/" + random.choice(["F", "M"]) 32 | race = "Race/" + random.choice(["White", "Non-White"]) 33 | 34 | rows = [] 35 | 36 | birth_codes = [meds.birth_code, gender, race] 37 | 38 | for birth_code in birth_codes: 39 | rows.append({"subject_id": subject_id, "time": birth, "code": birth_code}) 40 | 41 | code_cats = ["ICD9CM", "RxNorm"] 42 | for code in range(random.randint(1, 10 + (20 if gender == "Gender/F" else 0))): 43 | code_cat = random.choice(code_cats) 44 | if code_cat == "RxNorm": 45 | code = str(random.randint(0, 10000)) 46 | else: 47 | code = str(random.randint(0, 10000)) 48 | if len(code) > 3: 49 | code = code[:3] + "." 
+ code[3:] 50 | current_date = current_date + datetime.timedelta(days=random.randint(1, 100)) 51 | code = code_cat + "/" + code 52 | rows.append({"subject_id": subject_id, "time": current_date, "code": code}) 53 | 54 | return rows 55 | 56 | subjects = [] 57 | for i in range(200): 58 | subjects.extend(get_random_subject(i)) 59 | 60 | subject_schema = meds.schema.data_schema() 61 | 62 | subject_table = pyarrow.Table.from_pylist(subjects, subject_schema) 63 | 64 | os.makedirs(os.path.join(args.destination, "data"), exist_ok=True) 65 | os.makedirs(os.path.join(args.destination, "metadata"), exist_ok=True) 66 | 67 | pyarrow.parquet.write_table(subject_table, os.path.join(args.destination, "data", "subjects.parquet")) 68 | 69 | metadata = { 70 | "dataset_name": "femr synthetic datata", 71 | "dataset_version": "1", 72 | "etl_name": "synthetic data", 73 | "etl_version": "1", 74 | "code_metadata": {}, 75 | } 76 | 77 | jsonschema.validate(instance=metadata, schema=meds.dataset_metadata_schema) 78 | 79 | with open(os.path.join(args.destination, "metadata", "metadata.json"), "w") as f: 80 | json.dump(metadata, f) 81 | 82 | print("Converting") 83 | os.system(f"meds_reader_convert {args.destination} {args.destination}_meds") 84 | 85 | print("Opening database") 86 | 87 | with meds_reader.SubjectDatabase(args.destination + "_meds", num_threads=6) as database: 88 | print("Creating ontology") 89 | ontology = femr.ontology.Ontology(args.athena) 90 | 91 | print("Pruning ontology") 92 | ontology.prune_to_dataset(database, remove_ontologies=("SNOMED")) 93 | 94 | with open(os.path.join(args.destination, "ontology.pkl"), "wb") as f: 95 | pickle.dump(ontology, f) 96 | -------------------------------------------------------------------------------- /src/femr/post_etl_pipelines/stanford.py: -------------------------------------------------------------------------------- 1 | """An ETL script for doing an end to end transform of Stanford data into a SubjectDatabase.""" 2 | 3 | import argparse 4 | import functools 5 | import json 6 | import os 7 | from typing import Callable, Sequence 8 | 9 | import meds_reader 10 | import meds_reader.transform 11 | 12 | from femr.transforms import delta_encode, remove_nones 13 | from femr.transforms.stanford import ( 14 | move_billing_codes, 15 | move_pre_birth, 16 | move_to_day_end, 17 | move_visit_start_to_first_event_start, 18 | switch_to_icd10cm, 19 | ) 20 | 21 | 22 | def _is_visit_measurement(e: meds_reader.Event) -> bool: 23 | return e.table == "visit" 24 | 25 | 26 | def _apply_transformations(subject, *, transforms): 27 | for transform in transforms: 28 | subject = transform(subject) 29 | return subject 30 | 31 | 32 | def _remove_flowsheets(subject: meds_reader.transform.MutableSubject) -> meds_reader.transform.MutableSubject: 33 | """Flowsheets in STARR-OMOP have known timing bugs, making them unsuitable for use as either features or labels. 
34 | 35 | TODO: Investigate them so we can add them back as features 36 | """ 37 | new_events = [] 38 | for event in subject.events: 39 | if event.code != "STANFORD_OBS/Flowsheet": 40 | new_events.append(event) 41 | 42 | subject.events = new_events 43 | return subject 44 | 45 | 46 | def _get_stanford_transformations() -> ( 47 | Callable[[meds_reader.transform.MutableSubject], meds_reader.transform.MutableSubject] 48 | ): 49 | """Get the list of current OMOP transformations.""" 50 | # All of these transformations are information preserving 51 | transforms: Sequence[Callable[[meds_reader.transform.MutableSubject], meds_reader.transform.MutableSubject]] = [ 52 | move_pre_birth, 53 | move_visit_start_to_first_event_start, 54 | move_to_day_end, 55 | switch_to_icd10cm, 56 | move_billing_codes, 57 | functools.partial( 58 | remove_nones, # We have to keep visits in order to sync up visit_ids later in the process 59 | # If we ever remove or revisit visit_id, we would want to revisit this 60 | do_not_apply_to_filter=_is_visit_measurement, 61 | ), 62 | functools.partial( 63 | delta_encode, # We have to keep visits in order to sync up visit_ids later in the process 64 | # If we ever remove or revisit visit_id, we would want to revisit this 65 | do_not_apply_to_filter=_is_visit_measurement, 66 | ), 67 | _remove_flowsheets, 68 | ] 69 | 70 | return functools.partial(_apply_transformations, transforms=transforms) 71 | 72 | 73 | def femr_stanford_omop_fixer_program() -> None: 74 | """Extract data from an Stanford STARR-OMOP v5 source to create a femr SubjectDatabase.""" 75 | parser = argparse.ArgumentParser(description="An extraction tool for STARR-OMOP v5 sources") 76 | 77 | parser.add_argument( 78 | "source_dataset", 79 | type=str, 80 | help="Path of the folder to source dataset", 81 | ) 82 | 83 | parser.add_argument( 84 | "target_dataset", 85 | type=str, 86 | help="The place to store the extract", 87 | ) 88 | 89 | parser.add_argument( 90 | "--num_proc", 91 | type=int, 92 | help="The number of threads to use", 93 | default=1, 94 | ) 95 | 96 | args = parser.parse_args() 97 | 98 | meds_reader.transform.transform_meds_dataset( 99 | args.source_dataset, args.target_dataset, _get_stanford_transformations(), num_threads=args.num_proc 100 | ) 101 | 102 | with open(os.path.join(args.target_dataset, "metadata/dataset.json")) as f: 103 | metadata = json.load(f) 104 | 105 | # Let's mark that we modified this dataset 106 | metadata["post_etl_name"] = "femr_stanford_omop_fixer" 107 | metadata["post_etl_version"] = "0.1" 108 | 109 | with open(os.path.join(args.target_dataset, "metadata/dataset.json"), "w") as f: 110 | json.dump(metadata, f) 111 | -------------------------------------------------------------------------------- /src/femr/models/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Dict, Mapping, Optional 4 | 5 | import transformers 6 | 7 | 8 | class FEMRTransformerConfig(transformers.PretrainedConfig): 9 | def __init__( 10 | self, 11 | vocab_size: int = 32768, 12 | is_hierarchical: bool = False, 13 | hidden_size: int = 768, 14 | intermediate_size: int = 3072, 15 | n_heads: int = 12, 16 | n_layers: int = 6, 17 | attention_width: int = 496, 18 | use_normed_ages: bool = False, 19 | use_bias: bool = True, 20 | hidden_act: str = "gelu", 21 | **kwargs, 22 | ) -> None: 23 | """Defined a configuration for a FEMR Transformer. 
24 | 25 | Arguments: 26 | vocab_size: The number of tokens in the vocabulary 27 | is_hierarchical: Whether to use a hierarchical vocabulary. See FEMRTokenizer for more information 28 | hidden_size: The internal representation size 29 | intermediate_size: The size of the FFN in the transformer layers 30 | n_heads: The number of attention heads 31 | n_layers: The number of transformer encoder layers 32 | attention_width: FEMR by default uses a local attention transformer with a width defined here 33 | use_normed_ages: Whether or not to provide normalized ages as a feature to the model 34 | use_bias: Whether or not to use bias terms in the transformer layers 35 | hidden_act: The type of activation function to use in the transformer 36 | """ 37 | super().__init__(**kwargs) 38 | 39 | self.vocab_size = vocab_size 40 | self.is_hierarchical = is_hierarchical 41 | 42 | self.hidden_size = hidden_size 43 | self.intermediate_size = intermediate_size 44 | self.n_heads = n_heads 45 | self.n_layers = n_layers 46 | self.attention_width = attention_width 47 | 48 | self.use_normed_ages = use_normed_ages 49 | 50 | self.use_bias = use_bias 51 | self.hidden_act = hidden_act 52 | 53 | 54 | class FEMRTaskConfig(transformers.PretrainedConfig): 55 | def __init__(self, task_type: str = "", task_kwargs: Mapping[str, Any] = {}, **kwargs): 56 | """A generic FEMR task definition. This holds state used for initalizing a tasks.py class. 57 | 58 | Task.get_task_config returns the task type and kwargs used to initialize this. 59 | 60 | Arguments: 61 | task_type: The name of the task. 62 | task_kwargs: Arbitrary arguments used to store state for that task. 63 | """ 64 | super().__init__(**kwargs) 65 | self.task_type = task_type 66 | self.task_kwargs = task_kwargs 67 | 68 | 69 | class FEMRModelConfig(transformers.PretrainedConfig): 70 | """A model config is defined as the combination of a transformer config and a task config.""" 71 | 72 | def __init__( 73 | self, 74 | transformer_config: Optional[Dict[str, Any]] = None, 75 | task_config: Optional[Dict[str, Any]] = None, 76 | **kwargs, 77 | ): 78 | """A combination of a transformer config and a task config. 79 | 80 | It is possible to initialize this with only a transformer config, in which 81 | case the model will be configured for inference only. 82 | """ 83 | super().__init__(**kwargs) 84 | if transformer_config is None: 85 | transformer_config = {} 86 | self.transformer_config = FEMRTransformerConfig(**transformer_config) 87 | 88 | self.task_config: Optional[FEMRTaskConfig] 89 | 90 | if task_config is not None: 91 | self.task_config = FEMRTaskConfig(**task_config) 92 | else: 93 | self.task_config = None 94 | 95 | @classmethod 96 | def from_transformer_task_configs( 97 | cls, transformer_config: FEMRTransformerConfig, task_config: FEMRTaskConfig 98 | ) -> FEMRModelConfig: 99 | """ 100 | Combine a transformer configuration and task configuration into a model configuration. 
101 |         """ 102 |         if task_config is not None: 103 |             task_config_dict = task_config.to_dict() 104 |         else: 105 |             task_config_dict = None 106 | 107 |         return cls(transformer_config=transformer_config.to_dict(), task_config=task_config_dict) 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FEMR 2 | ### Framework for Electronic Medical Records 3 | 4 | **FEMR** is a Python package for manipulating longitudinal EHR data for machine learning, with a focus on supporting the creation of foundation models and verifying their [presumed benefits](https://hai.stanford.edu/news/how-foundation-models-can-advance-ai-healthcare) in healthcare. Such a framework is needed given the [current state of large language models in healthcare](https://hai.stanford.edu/news/shaky-foundations-foundation-models-healthcare) and the need for better evaluation frameworks. 5 | 6 | The currently supported foundation model is [MOTOR](https://arxiv.org/abs/2301.03150). 7 | 8 | (Users who want to train auto-regressive CLMBR-style models should use [FEMR 0.1.16](https://github.com/som-shahlab/femr/releases/tag/0.1.16) or https://github.com/som-shahlab/hf_ehr) 9 | 10 | **FEMR** works with data that has been converted to the [MEDS](https://github.com/Medical-Event-Data-Standard/) schema, a simple schema that supports a wide variety of EHR / claims datasets. Please see the MEDS documentation, and in particular its [provided ETLs](https://github.com/Medical-Event-Data-Standard/meds_etl), for help converting your data to MEDS. 11 | 12 | **FEMR** helps users: 13 | 1. [Use ontologies to better understand / featurize medical codes](http://github.com/som-shahlab/femr/blob/main/tutorials/1_Ontology.ipynb) 14 | 2. [Algorithmically label subject records based on structured data](https://github.com/som-shahlab/femr/blob/main/tutorials/2_Labeling.ipynb) 15 | 3. [Generate tabular features from subject timelines for use with traditional gradient boosted tree models](https://github.com/som-shahlab/femr/blob/main/tutorials/3_Count%20Featurization%20And%20Modeling.ipynb) 16 | 4. [Train](https://github.com/som-shahlab/femr/blob/main/tutorials/4_Train%20MOTOR.ipynb) and [finetune](https://github.com/som-shahlab/femr/blob/main/tutorials/5_MOTOR%20Featurization%20And%20Modeling.ipynb) MOTOR-derived models for binary classification and prediction tasks. 17 | 18 | We recommend users start with our [tutorial folder](https://github.com/som-shahlab/femr/tree/main/tutorials). 19 | 20 | # Installation 21 | 22 | ```bash 23 | pip install femr 24 | 25 | # If you are using deep learning, you also need to install xformers 26 | # 27 | # Note that xformers has some known issues with MacOS. 28 | # If you are using MacOS you might also need to install llvm. See https://stackoverflow.com/questions/60005176/how-to-deal-with-clang-error-unsupported-option-fopenmp-on-travis 29 | pip install xformers 30 | 31 | ``` 32 | # Getting Started 33 | 34 | The first step of using **FEMR** is to convert your subject data into [MEDS](https://github.com/Medical-Event-Data-Standard), the standard input format expected by the **FEMR** codebase. 35 | 36 | **Note: FEMR currently only supports MEDS v3, so you will need to install MEDS v3 versions of packages, i.e. `pip install meds-etl==0.3.11`.** 37 | 38 | The best way to do this is with the [ETLs provided by MEDS](https://github.com/Medical-Event-Data-Standard/meds_etl).
39 | 40 | 41 | ## OMOP Data 42 | 43 | If you have OMOP CDM formated data, follow these instructions: 44 | 45 | 1. Download your OMOP dataset to `[PATH_TO_SOURCE_OMOP]`. 46 | 2. Convert OMOP => MEDS using the following: 47 | ```bash 48 | # Convert OMOP => MEDS data format 49 | meds_etl_omop [PATH_TO_SOURCE_OMOP] [PATH_TO_OUTPUT_MEDS] 50 | ``` 51 | 52 | ## Stanford STARR-OMOP Data 53 | 54 | If you are using the STARR-OMOP dataset from Stanford (which uses the OMOP CDM), we add an initial Stanford-specific preprocessing step. Otherwise this should be identical to the **OMOP Data** section. Follow these instructions: 55 | 56 | 1. Download your STARR-OMOP dataset to `[PATH_TO_SOURCE_OMOP]`. 57 | 2. Convert STARR-OMOP => MEDS using the following: 58 | ```bash 59 | # Convert OMOP => MEDS data format 60 | meds_etl_omop [PATH_TO_SOURCE_OMOP] [PATH_TO_OUTPUT_MEDS]_raw 61 | 62 | # Apply Stanford fixes 63 | femr_stanford_omop_fixer [PATH_TO_OUTPUT_MEDS]_raw [PATH_TO_OUTPUT_MEDS] 64 | ``` 65 | 66 | # Development 67 | 68 | The following guides are for developers who want to contribute to **FEMR**. 69 | 70 | ## Precommit checks 71 | 72 | Before committing, please run the following commands to ensure that your code is formatted correctly and passes all tests. 73 | 74 | ### Installation 75 | ```bash 76 | conda install pre-commit pytest -y 77 | pre-commit install 78 | ``` 79 | 80 | ### Running 81 | 82 | #### Test Functions 83 | 84 | ```bash 85 | pytest tests 86 | ``` 87 | 88 | ### Formatting Checks 89 | 90 | ```bash 91 | pre-commit run --all-files 92 | ``` 93 | -------------------------------------------------------------------------------- /tests/featurizers/test_OnlineStatistics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from femr.featurizers.utils import OnlineStatistics 5 | 6 | 7 | def _assert_correct_stats(stat: OnlineStatistics, values: list): 8 | TOLERANCE = 1e-6 # Allow for some floating point error 9 | true_mean = np.mean(values) 10 | true_sample_variance = np.var(values, ddof=1) 11 | true_m2 = true_sample_variance * (len(values) - 1) 12 | assert stat.current_count == len(values), f"{stat.current_count} != {len(values)}" 13 | assert np.isclose(stat.mean(), true_mean), f"{stat.mean()} != {true_mean}" 14 | assert np.isclose( 15 | stat.variance(), true_sample_variance, atol=TOLERANCE 16 | ), f"{stat.variance()} != {true_sample_variance}" 17 | assert np.isclose(stat.current_M2, true_m2, atol=TOLERANCE), f"{stat.current_M2} != {true_m2}" 18 | 19 | 20 | def test_add(): 21 | # Test adding things to the statistics 22 | def _run_test(values): 23 | stat = OnlineStatistics() 24 | for i in values: 25 | stat.add(i) 26 | _assert_correct_stats(stat, values) 27 | 28 | # Positive integers 29 | _run_test(range(51)) 30 | _run_test(range(10, 10000, 3)) 31 | # Negative integers 32 | _run_test(range(-400, -300)) 33 | # Positive/negative integers 34 | _run_test(list(range(4, 900, 2)) + list(range(-1000, -300, 7))) 35 | _run_test(list(range(-100, 100, 7)) + list(range(-100, 100, 2))) 36 | # Decimals 37 | _run_test(np.linspace(0, 1, 100)) 38 | _run_test(np.logspace(-100, 3, 100)) 39 | # Small lists 40 | _run_test([0, 1]) 41 | _run_test([-1, 1]) 42 | 43 | 44 | def test_constructor(): 45 | # Test default 46 | stat = OnlineStatistics() 47 | assert stat.current_count == 0 48 | assert stat.current_mean == stat.mean() == 0 49 | assert stat.current_M2 == 0 50 | 51 | # Test explicitly setting args 52 | stat = 
OnlineStatistics(current_count=1, current_mean=2, current_variance=3) 53 | assert stat.current_count == 1 54 | assert stat.current_mean == stat.mean() == 2 55 | assert stat.current_M2 == 0 56 | 57 | # Test M2 58 | stat = OnlineStatistics(current_count=10, current_mean=20, current_variance=30) 59 | assert stat.current_count == 10 60 | assert stat.current_mean == 20 61 | assert stat.current_M2 == 30 * (10 - 1) 62 | 63 | # Test getters/setters 64 | stat = OnlineStatistics(current_count=10, current_mean=20, current_variance=30) 65 | assert stat.mean() == 20 66 | assert stat.variance() == 30 67 | assert stat.standard_deviation() == np.sqrt(30) 68 | 69 | # Test fail cases 70 | with pytest.raises(ValueError) as _: 71 | # Negative count 72 | stat = OnlineStatistics(current_count=-1, current_mean=2, current_variance=3) 73 | with pytest.raises(ValueError) as _: 74 | # Negative variance 75 | stat = OnlineStatistics(current_count=1, current_mean=2, current_variance=-3) 76 | with pytest.raises(ValueError) as _: 77 | # Positive variance with 0 count 78 | stat = OnlineStatistics(current_count=0, current_mean=2, current_variance=1) 79 | with pytest.raises(ValueError) as _: 80 | # Can only compute variance with >1 observation 81 | stat = OnlineStatistics() 82 | stat.add(1) 83 | stat.variance() 84 | 85 | 86 | def test_merge_pair(): 87 | # Simulate two statistics being merged via `merge_pair`` 88 | stat1 = OnlineStatistics() 89 | values1 = list(range(-300, 300, 4)) + list(range(400, 450)) 90 | for i in values1: 91 | stat1.add(i) 92 | stat2 = OnlineStatistics() 93 | values2 = list(range(100, 150)) 94 | for i in values2: 95 | stat2.add(i) 96 | merged_stat = OnlineStatistics.merge_pair(stat1, stat2) 97 | merged_stat_values = values1 + values2 98 | _assert_correct_stats(merged_stat, merged_stat_values) 99 | 100 | 101 | def test_merge(): 102 | # Simulate parallel statistics being merged via `merge` 103 | stats = [] 104 | values = [ 105 | np.linspace(-100, 100, 50), 106 | np.linspace(100, 200, 50), 107 | np.linspace(100, 150, 100), 108 | np.linspace(-10, 100, 100), 109 | np.linspace(10, 200, 3), 110 | ] 111 | for i in range(len(values)): 112 | stat = OnlineStatistics() 113 | for v in values[i]: 114 | stat.add(v) 115 | stats.append(stat) 116 | merged_stat = OnlineStatistics.merge(stats) 117 | _assert_correct_stats(merged_stat, np.concatenate(values)) 118 | -------------------------------------------------------------------------------- /tutorials/1_Ontology.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# FEMR Ontology support\n", 8 | "\n", 9 | "FEMR provides support for querying ontologies using the OMOP Vocabulary. \n", 10 | "\n", 11 | "This enables easier definition of labeling functions as well as better feature generation." 
12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "# Downloading the OMOP Vocabulary\n", 19 | "\n", 20 | "The OMOP Vocabulary can be downloaded for free from the [OHDSI ATHENA website.](https://athena.ohdsi.org/)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "# Processing the OMOP Vocabulary\n", 28 | "\n", 29 | "femr.ontology.Ontology allows you to process, and then use the OMOP Vocabulary, optionally combining it with [code metadata from MEDS](https://github.com/Medical-Event-Data-Standard/meds/blob/e93f63a2f9642123c49a31ecffcdb84d877dc54a/src/meds/__init__.py#L94).\n", 30 | "\n", 31 | "```python \n", 32 | "ontology = femr.ontology.Ontology(path_to_athena, code_metadata)\n", 33 | "```" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "# Working with an Ontology object\n", 41 | "\n", 42 | "The following code samples illustrate the main ways to use a vocabulary object" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 1, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "Loaded ontology\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "import pickle\n", 60 | "\n", 61 | "# You can load / save ontology objects with pickle\n", 62 | "\n", 63 | "with open('input/ontology.pkl', 'rb') as f:\n", 64 | " ontology = pickle.load(f)\n", 65 | "\n", 66 | "print(\"Loaded ontology\")" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 2, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "# Ontology datasets downloaded by Athena tend to be very large as they contain many codes, including several that are no longer used.\n", 76 | "# We therefore provide a function to prune ontologies to a particular dataset of interest.\n", 77 | "# This makes it much cheaper to store and use an ontology object, both in terms of disk space and RAM\n", 78 | "\n", 79 | "\n", 80 | "import meds_reader\n", 81 | "\n", 82 | "database = meds_reader.SubjectDatabase(\"input/synthetic_meds\")\n", 83 | "\n", 84 | "ontology.prune_to_dataset(database)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 3, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "Description DRUGS FOR PEPTIC ULCER AND GASTRO-OESOPHAGEAL REFLUX DISEASE (GORD)\n", 97 | "Parents {'ATC/A02'}\n", 98 | "Children {'ATC/A02BX'}\n", 99 | "All children {'ATC/A02BX71', 'RxNorm/2344', 'ATC/A02B', 'RxNorm/8730', 'RxNorm/2403', 'RxNorm/6852', 'ATC/A02BX', 'RxNorm/2620', 'RxNorm/7815', 'RxNorm/4501', 'RxNorm/2018', 'ATC/A02BX77', 'RxNorm/2353', 'RxNorm/7019', 'RxNorm/38574', 'RxNorm/2017', 'RxNorm/8704', 'RxNorm/8705'}\n", 100 | "All parents {'ATC/A02B', 'ATC/A02', 'ATC/A'}\n" 101 | ] 102 | } 103 | ], 104 | "source": [ 105 | "# First, we can query the description for a particular code\n", 106 | "print(\"Description\", ontology.get_description(\"ATC/A02B\"))\n", 107 | "\n", 108 | "# Second, we can search for the parents of a particular code\n", 109 | "print(\"Parents\", ontology.get_parents(\"ATC/A02B\"))\n", 110 | "\n", 111 | "# Finally, we can search for the children of a particular code\n", 112 | "print(\"Children\", ontology.get_children(\"ATC/A02B\"))\n", 113 | "\n", 114 | "# For the sake of convience, we also support the recursive versions of querying parents and children\n", 115 | "print(\"All children\", 
ontology.get_all_children(\"ATC/A02B\"))\n", 116 | "print(\"All parents\", ontology.get_all_parents(\"ATC/A02B\"))" 117 | ] 118 | } 119 | ], 120 | "metadata": { 121 | "kernelspec": { 122 | "display_name": "Python 3 (ipykernel)", 123 | "language": "python", 124 | "name": "python3" 125 | }, 126 | "language_info": { 127 | "codemirror_mode": { 128 | "name": "ipython", 129 | "version": 3 130 | }, 131 | "file_extension": ".py", 132 | "mimetype": "text/x-python", 133 | "name": "python", 134 | "nbconvert_exporter": "python", 135 | "pygments_lexer": "ipython3", 136 | "version": "3.13.3" 137 | } 138 | }, 139 | "nbformat": 4, 140 | "nbformat_minor": 4 141 | } 142 | -------------------------------------------------------------------------------- /src/femr/featurizers/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import copy 4 | import math 5 | from typing import List 6 | 7 | 8 | class OnlineStatistics: 9 | """ 10 | A class for computing online statistics such as mean and variance. 11 | Uses Welford's online algorithm. 12 | From https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm. 13 | 14 | NOTE: The variance we calculate is the sample variance, not the population variance. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | current_count: int = 0, 20 | current_mean: float = 0, 21 | current_variance: float = 0, 22 | ): 23 | """ 24 | Initialize online statistics. 25 | `mean` accumulates the mean of the entire dataset 26 | `count` aggregates the number of samples seen so far 27 | `current_M2` aggregates the squared distances from the mean 28 | """ 29 | if not (current_count >= 0 and current_variance >= 0): 30 | raise ValueError( 31 | "Must set `current_count` and `current_variance` to be non-negative." 32 | f"You specified `current_count` = {current_count} and `current_variance` = {current_variance}." 33 | ) 34 | self.current_count: int = current_count 35 | self.current_mean: float = current_mean 36 | if current_count == 0 and current_variance == 0: 37 | self.current_M2 = 0.0 38 | elif current_count > 0: 39 | self.current_M2 = current_variance * (current_count - 1) 40 | else: 41 | raise ValueError( 42 | "Cannot specify `current_variance` with a value > 0" 43 | "without specifying `current_count` with a value > 0. " 44 | f"You specified `current_count` = {current_count} and `current_variance` = {current_variance}." 45 | ) 46 | 47 | def add(self, newValue: float) -> None: 48 | """ 49 | Add an observation to the calculation using Welford's online algorithm. 50 | 51 | Taken from: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm 52 | """ 53 | self.current_count += 1 54 | delta: float = newValue - self.current_mean 55 | self.current_mean += delta / self.current_count 56 | delta2: float = newValue - self.current_mean 57 | self.current_M2 += delta * delta2 58 | 59 | def mean(self) -> float: 60 | """ 61 | Return the current mean. 62 | """ 63 | return self.current_mean 64 | 65 | def variance(self) -> float: 66 | """ 67 | Return the current sample variance. 68 | """ 69 | if self.current_count < 2: 70 | raise ValueError(f"Cannot compute variance with only {self.current_count} observations.") 71 | 72 | return self.current_M2 / (self.current_count - 1) 73 | 74 | def standard_deviation(self) -> float: 75 | """ 76 | Return the current standard devation. 
77 | """ 78 | return math.sqrt(self.variance()) 79 | 80 | @classmethod 81 | def merge_pair(cls, stats1: OnlineStatistics, stats2: OnlineStatistics) -> OnlineStatistics: 82 | """ 83 | Merge two sets of online statistics using Chan's parallel algorithm. 84 | 85 | Taken from: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 86 | """ 87 | if stats1.current_count == 0: 88 | return stats2 89 | elif stats2.current_count == 0: 90 | return stats1 91 | 92 | count: int = stats1.current_count + stats2.current_count 93 | delta: float = stats2.current_mean - stats1.current_mean 94 | mean: float = stats1.current_mean + delta * stats2.current_count / count 95 | M2 = stats1.current_M2 + stats2.current_M2 + delta**2 * stats1.current_count * stats2.current_count / count 96 | return OnlineStatistics(count, mean, M2 / (count - 1)) 97 | 98 | @classmethod 99 | def merge(cls, stats_list: List[OnlineStatistics]) -> OnlineStatistics: 100 | """ 101 | Merge a list of online statistics. 102 | """ 103 | if len(stats_list) == 0: 104 | raise ValueError("Cannot merge an empty list of statistics.") 105 | unmerged_stats: List[OnlineStatistics] = copy.deepcopy(stats_list) 106 | # Run tree reduction to merge together all pairs of statistics 107 | # in a numerically stable way 108 | # Example: 1 2 3 4 5 -> 3 7 5 -> 10 5 -> 15 109 | while len(unmerged_stats) > 1: 110 | merged_stats: List[OnlineStatistics] = [] 111 | for i in range(0, len(unmerged_stats), 2): 112 | if i + 1 < len(unmerged_stats): 113 | # If there's another stat after this one, merge them 114 | merged_stats.append(cls.merge_pair(unmerged_stats[i], unmerged_stats[i + 1])) 115 | else: 116 | # We've reached the end of our list, so just add the last stat back 117 | merged_stats.append(unmerged_stats[i]) 118 | unmerged_stats = merged_stats 119 | assert len(unmerged_stats) == 1, f"Should only have one stat left after merging, not ({len(unmerged_stats)})." 
120 | return unmerged_stats[0] 121 | -------------------------------------------------------------------------------- /tests/femr_test_tools.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import collections 4 | import dataclasses 5 | import datetime 6 | from typing import Any, List, Optional, Sequence, Tuple, Union, cast 7 | 8 | import meds 9 | import meds_reader 10 | import pandas as pd 11 | 12 | from femr.labelers import Label, Labeler 13 | 14 | # 2nd elem of tuple -- 'skip' means no label, None means censored 15 | EventsWithLabels = List[Tuple[Tuple[Tuple, int, Any], Union[bool, str]]] 16 | 17 | 18 | DUMMY_EVENTS = [ 19 | ((1995, 1, 3), meds.birth_code, None), 20 | ((2010, 1, 1), 1, "test_value"), 21 | ((2010, 1, 1), 1, "test_value"), 22 | ((2010, 1, 5), 2, 1), 23 | ((2010, 6, 5), 3, None), 24 | ((2010, 8, 5), 2, None), 25 | ((2011, 7, 5), 2, None), 26 | ((2012, 10, 5), 3, None), 27 | ((2015, 6, 5, 0), 2, None), 28 | ((2015, 6, 5, 10, 10), 2, None), 29 | ((2015, 6, 15, 11), 3, None), 30 | ((2016, 1, 1), 2, None), 31 | ((2016, 3, 1, 10, 10, 10), 4, None), 32 | ] 33 | 34 | NUM_EVENTS = len(DUMMY_EVENTS) 35 | NUM_PATIENTS = 10 36 | 37 | 38 | @dataclasses.dataclass 39 | class DummyEvent: 40 | time: datetime.datetime 41 | code: str 42 | text_value: Optional[str] = None 43 | numeric_value: Optional[float] = None 44 | visit_id: Optional[int] = None 45 | table: Optional[str] = None 46 | clarity_table: Optional[str] = None 47 | end: Optional[datetime.datetime] = None 48 | 49 | def __getattr__(self, name: str) -> Any: 50 | return None 51 | 52 | 53 | @dataclasses.dataclass 54 | class DummySubject: 55 | subject_id: int 56 | events: Sequence[DummyEvent] 57 | 58 | 59 | class DummyDatabase(dict): 60 | def filter(self, subject_ids): 61 | return DummyDatabase({p: self[p] for p in subject_ids}) 62 | 63 | def map( 64 | self, 65 | map_func, 66 | ) -> Any: 67 | return [map_func(self.values())] 68 | 69 | def map_with_data(self, map_func, data, assume_sorted) -> Any: 70 | entries = collections.defaultdict(list) 71 | 72 | for row in data.itertuples(index=False): 73 | entries[row.subject_id].append(row) 74 | 75 | temp = [] 76 | for k, v in entries.items(): 77 | temp.append((self[k], v)) 78 | 79 | return [map_func(temp)] 80 | 81 | 82 | def create_subjects_dataset( 83 | num_subjects: int, events: List[Tuple[Tuple, Any, Any]] = DUMMY_EVENTS 84 | ) -> meds_reader.SubjectDatabase: 85 | """Creates a list of subjects, each with the same events contained in `events`""" 86 | 87 | converted_events: List[DummyEvent] = [] 88 | 89 | for event in events: 90 | if isinstance(event[1], int): 91 | code = str(event[1]) 92 | else: 93 | code = event[1] 94 | 95 | dummy_event = DummyEvent(time=datetime.datetime(*event[0]), code=code) 96 | 97 | if isinstance(event[2], str): 98 | dummy_event.text_value = event[2] 99 | else: 100 | dummy_event.numeric_value = event[2] 101 | 102 | converted_events.append(dummy_event) 103 | 104 | result = DummyDatabase( 105 | (subject_id, DummySubject(subject_id, converted_events)) for subject_id in range(num_subjects) 106 | ) 107 | return cast(meds_reader.SubjectDatabase, result) 108 | 109 | 110 | def assert_labels_are_accurate( 111 | labeled_subjects: pd.DataFrame, 112 | subject_id: int, 113 | true_labels: List[Tuple[datetime.datetime, Optional[bool]]], 114 | help_text: str = "", 115 | ): 116 | """Passes if the labels in `labeled_subjects` for `subject_id` exactly match the labels in `true_labels`.""" 117 | 
generated_labels: List[Label] = [a for a in labeled_subjects.itertuples(index=False) if a.subject_id == subject_id] 118 | # Check that length of lists of labels are the same 119 | 120 | assert len(generated_labels) == len( 121 | true_labels 122 | ), f"len(generated): {len(generated_labels)} != len(expected): {len(true_labels)} | {help_text}" 123 | # Check that value of labels are the same 124 | for idx, (label, true_label) in enumerate(zip(generated_labels, true_labels)): 125 | assert label.boolean_value == true_label[1] and label.prediction_time == true_label[0], ( 126 | f"subject_id={subject_id}, label_idx={idx}, label={label} | " 127 | f"{label} (Assigned) != {true_label} (Expected) | " 128 | f"{help_text}" 129 | ) 130 | 131 | 132 | def run_test_for_labeler( 133 | labeler: Labeler, 134 | events_with_labels: EventsWithLabels, 135 | true_outcome_times: Optional[List[datetime.datetime]] = None, 136 | true_prediction_times: Optional[List[datetime.datetime]] = None, 137 | help_text: str = "", 138 | ) -> None: 139 | subjects: meds_reader.SubjectDatabase = create_subjects_dataset(10, [x[0] for x in events_with_labels]) 140 | 141 | true_labels: List[Tuple[datetime.datetime, Optional[bool]]] = [ 142 | (datetime.datetime(*x[0][0]), x[1]) for x in events_with_labels if isinstance(x[1], bool) 143 | ] 144 | if true_prediction_times is not None: 145 | # If manually specified prediction times, adjust labels from occurring at `event.start` 146 | # e.g. we may make predictions at `event.end` or `event.start + 1 day` 147 | true_labels = [(tp, tl[1]) for (tl, tp) in zip(true_labels, true_prediction_times)] 148 | labeled_subjects: List[meds.Label] = labeler.apply(subjects) 149 | 150 | # Check accuracy of Labels 151 | for subject_id in subjects: 152 | assert_labels_are_accurate( 153 | labeled_subjects, 154 | subject_id, 155 | true_labels, 156 | help_text=help_text, 157 | ) 158 | -------------------------------------------------------------------------------- /tests/test_ontology.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: E501 2 | 3 | import pathlib 4 | 5 | import meds 6 | import pyarrow as pa 7 | import pyarrow.parquet as pq 8 | 9 | import femr.ontology 10 | 11 | 12 | def create_fake_athena(tmp_path: pathlib.Path) -> pathlib.Path: 13 | athena = tmp_path / "athena" 14 | athena.mkdir() 15 | 16 | concept = athena / "CONCEPT.csv" 17 | concept.write_text( 18 | """concept_id concept_name domain_id vocabulary_id concept_class_id standard_concept concept_code valid_start_date valid_end_date invalid_reason 19 | 37200198\tType 2 diabetes mellitus with mild nonproliferative diabetic retinopathy with macular edema, right eye\tCondition\tICD10CM\t7-char billing code\t\tE11.3211\t19700101\t20991231\t 20 | 380097\tMacular edema due to diabetes mellitus\tCondition\tSNOMED\tClinical Finding\tS\t312912001\t20020131\t20991231 21 | 4334884 Disorder of macula due to diabetes mellitus Condition SNOMED Clinical Finding S 232020009 20020131 20991231 22 | 4174977 Retinopathy due to diabetes mellitus Condition SNOMED Clinical Finding S 4855003 20020131 20991231 23 | 4208223 Disorder of macula of retina Condition SNOMED Clinical Finding S 312999006 20020131 20991231 24 | 4290333 Macular retinal edema Condition SNOMED Clinical Finding S 37231002 20020131 20991231 25 | 35626904 Retinal edema due to diabetes mellitus Condition SNOMED Clinical Finding S 770323005 20180731 20991231 26 | 45757435 Mild nonproliferative retinopathy due to type 2 diabetes mellitus Condition SNOMED 
Clinical Finding S 138911000119106 20150131 20991231 27 | """ 28 | ) 29 | relationship = athena / "CONCEPT_RELATIONSHIP.csv" 30 | relationship.write_text( 31 | """concept_id_1 concept_id_2 relationship_id valid_start_date valid_end_date invalid_reason 32 | 37200198 380097 Maps to 20171001 20991231 33 | 37200198 1567956 Is a 20170428 20991231 34 | 37200198 1567959 Is a 20170428 20991231 35 | 37200198 45757435 Maps to 20171001 20991231 36 | 37200198 1567961 Is a 20170428 20991231 37 | 37200198 45552385 Is a 20170428 20991231 38 | 35977781 35977781 Mapped from 20200913 20991231 39 | 46135811 40642538 Has status 20220128 20991231 40 | 46135811 35631990 Has Module 20220128 20991231""" 41 | ) 42 | 43 | ancestor = athena / "CONCEPT_ANCESTOR.csv" 44 | ancestor.write_text( 45 | """ancestor_concept_id\tdescendant_concept_id\tmin_levels_of_separation\tmax_levels_of_separation 46 | 373499 4334884 4 6 47 | 442793 4334884 3 3 48 | 255919 4334884 6 9 49 | 433128 4334884 4 4 50 | 4180628 4334884 6 8 51 | 4209989 4334884 3 3 52 | 441840 4334884 6 10 53 | 4274025 4334884 5 9 54 | 4082284 4334884 2 2 55 | 4042836 4334884 5 7 56 | 443767 4334884 2 2 57 | 4038502 4334884 5 8 58 | 4174977 4334884 1 1 59 | 4334884 4334884 0 0 60 | 4134440 4334884 5 7 61 | 4247371 4334884 5 8 62 | 375252 4334884 3 5 63 | 4027883 4334884 4 4 64 | 4208223 4334884 1 1 65 | 378416 4334884 2 2 66 | 4080992 4334884 4 6 67 | 4162092 4334884 3 3 68 | 255919 380097 7 10 69 | 433128 380097 5 5 70 | 442793 380097 4 4 71 | 373499 380097 5 7 72 | 4180628 380097 7 9 73 | 4209989 380097 4 4 74 | 37018677 380097 3 3 75 | 441840 380097 5 11 76 | 380097 380097 0 0 77 | 4274025 380097 4 10 78 | 372903 380097 2 2 79 | 433595 380097 4 4 80 | 4042836 380097 6 8 81 | 4082284 380097 3 3 82 | 443767 380097 3 3 83 | 4038502 380097 6 9 84 | 4174977 380097 2 2 85 | 4334884 380097 1 1 86 | 4134440 380097 6 8 87 | 4247371 380097 6 9 88 | 4290333 380097 1 1 89 | 375252 380097 4 6 90 | 4027883 380097 5 5 91 | 4040388 380097 3 3 92 | 4208223 380097 2 2 93 | 35626904 380097 1 1 94 | 378416 380097 3 3 95 | 4080992 380097 5 7 96 | 4162092 380097 4 4 97 | """ 98 | ) 99 | 100 | return athena 101 | 102 | 103 | def test_only_athena(tmp_path: pathlib.Path) -> None: 104 | fake_athena = create_fake_athena(tmp_path) 105 | 106 | ontology = femr.ontology.Ontology(str(fake_athena)) 107 | 108 | assert ( 109 | ontology.get_description("ICD10CM/E11.3211") 110 | == "Type 2 diabetes mellitus with mild nonproliferative diabetic retinopathy with macular edema, right eye" 111 | ) 112 | 113 | assert ontology.get_parents("ICD10CM/E11.3211") == {"SNOMED/312912001", "SNOMED/138911000119106"} 114 | assert ontology.get_parents("SNOMED/312912001") == {"SNOMED/37231002", "SNOMED/232020009", "SNOMED/770323005"} 115 | 116 | assert ontology.get_all_parents("ICD10CM/E11.3211") == { 117 | "SNOMED/37231002", 118 | "SNOMED/138911000119106", 119 | "SNOMED/4855003", 120 | "SNOMED/312999006", 121 | "SNOMED/312912001", 122 | "ICD10CM/E11.3211", 123 | "SNOMED/770323005", 124 | "SNOMED/232020009", 125 | } 126 | 127 | assert ontology.get_children("SNOMED/312912001") == {"ICD10CM/E11.3211"} 128 | assert ontology.get_children("SNOMED/37231002") == {"SNOMED/312912001"} 129 | 130 | assert ontology.get_all_children("SNOMED/37231002") == {"ICD10CM/E11.3211", "SNOMED/312912001", "SNOMED/37231002"} 131 | 132 | 133 | def test_athena_and_custom(tmp_path: pathlib.Path) -> None: 134 | fake_athena = create_fake_athena(tmp_path) 135 | 136 | code_metadata = [ 137 | {"code": "CUSTOM/CustomDiabetes", "description": "A 
nice diabetes code", "parent_codes": ["ICD10CM/E11.3211"]} 138 | ] 139 | 140 | table = pa.Table.from_pylist(code_metadata, schema=meds.code_metadata_schema()) 141 | pq.write_table(table, tmp_path / "codes.parquet") 142 | 143 | ontology = femr.ontology.Ontology(str(fake_athena), str(tmp_path / "codes.parquet")) 144 | 145 | assert ontology.get_description("CUSTOM/CustomDiabetes") == "A nice diabetes code" 146 | assert ontology.get_all_parents("CUSTOM/CustomDiabetes") == { 147 | "CUSTOM/CustomDiabetes", 148 | "SNOMED/37231002", 149 | "SNOMED/138911000119106", 150 | "SNOMED/4855003", 151 | "SNOMED/312999006", 152 | "SNOMED/312912001", 153 | "ICD10CM/E11.3211", 154 | "SNOMED/770323005", 155 | "SNOMED/232020009", 156 | } 157 | 158 | assert ontology.get_all_children("SNOMED/37231002") == { 159 | "CUSTOM/CustomDiabetes", 160 | "ICD10CM/E11.3211", 161 | "SNOMED/312912001", 162 | "SNOMED/37231002", 163 | } 164 | -------------------------------------------------------------------------------- /tutorials/2_Labeling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "43f4d50c-4e7b-4652-9701-be9366ff70c4", 6 | "metadata": {}, 7 | "source": [ 8 | "# Labeling\n", 9 | "\n", 10 | "A core component of FEMR is labeling subjects.\n", 11 | "\n", 12 | "Labels within FEMR follow the [label schema within MEDS](https://github.com/Medical-Event-Data-Standard/meds/blob/e93f63a2f9642123c49a31ecffcdb84d877dc54a/src/meds/__init__.py#L70).\n", 13 | "\n", 14 | "Per MEDS, each label consists of three attributes:\n", 15 | "\n", 16 | "* `subject_id` (int64): The identifier for the subject to predict on\n", 17 | "* `prediction_time` (datetime.datetime): The timestamp for when the prediction should be made. This indicates what features are allowed to be used for prediction.\n", 18 | "* `boolean_value` (bool): The target to predict\n", 19 | "\n", 20 | "Additional types of labels will be added to MEDS over time, and then supported here." 
21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "id": "c6ac5c41-bc99-4731-ad82-7152274c67e1", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import shutil\n", 31 | "import os\n", 32 | "\n", 33 | "TARGET_DIR = 'trash/tutorial_2'\n", 34 | "\n", 35 | "if os.path.exists(TARGET_DIR):\n", 36 | " shutil.rmtree(TARGET_DIR)\n", 37 | "\n", 38 | "os.mkdir(TARGET_DIR)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "id": "7e98dd85", 44 | "metadata": {}, 45 | "source": [ 46 | "# Demonstration of some example labels" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "id": "8d9e2ccd-71c2-4ae0-897b-7ec022f9fdf4", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# We can construct these labels manually\n", 57 | "\n", 58 | "import femr.labelers\n", 59 | "import datetime\n", 60 | "import meds\n", 61 | "\n", 62 | "# Predict False on March 2nd, 1994\n", 63 | "example_label = {'subject_id': 100, 'prediction_time': datetime.datetime(1994, 3, 2), 'boolean_value': False}\n", 64 | "\n", 65 | "# Predict True on March 2nd, 2009\n", 66 | "example_label2 = {'subject_id': 100, 'prediction_time': datetime.datetime(2009, 3, 2), 'boolean_value': True}\n", 67 | "\n", 68 | "\n", 69 | "# Multiple labels are stored using a list\n", 70 | "labels = [example_label, example_label2]" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "id": "e77b1bfc-8d2d-4f79-b855-f90b3a73736e", 76 | "metadata": {}, 77 | "source": [ 78 | "# Generating labels programatically within FEMR\n", 79 | "\n", 80 | "One core feature of FEMR is the ability to algorithmically generate labels through the use of a labeling function class.\n", 81 | "\n", 82 | "The core for FEMR's labeling code is the abstract base class [Labeler](https://github.com/som-shahlab/femr/blob/main/src/femr/labelers/core.py#L40).\n", 83 | "\n", 84 | "Labeler has one abstract methods:\n", 85 | "\n", 86 | "```python\n", 87 | "def label(self, subject: meds_reader.Subject) -> List[meds.Label]:\n", 88 | " Generate a list of labels for a subject\n", 89 | "```\n", 90 | "\n", 91 | "Note that the subject is assumed to be the [MEDS Subject schema](https://github.com/Medical-Event-Data-Standard/meds/blob/e93f63a2f9642123c49a31ecffcdb84d877dc54a/src/meds/__init__.py#L18).\n", 92 | "\n", 93 | "Once this method is implemented, the apply function becomes available for generating labels." 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 12, 99 | "id": "9ac22dbe-ef34-468a-8ab3-673e58e5a920", 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | " subject_id prediction_time boolean_value\n", 107 | "0 0 1993-01-31 False\n", 108 | "1 1 1991-08-31 True\n", 109 | "2 2 1992-08-05 True\n", 110 | "3 3 1991-01-11 True\n", 111 | "4 4 1994-04-05 True\n", 112 | ".. ... ... 
...\n", 113 | "195 195 1995-10-07 False\n", 114 | "196 196 1995-08-31 False\n", 115 | "197 197 1992-05-29 True\n", 116 | "198 198 1992-10-06 True\n", 117 | "199 199 1993-05-02 True\n", 118 | "\n", 119 | "[200 rows x 3 columns]\n" 120 | ] 121 | } 122 | ], 123 | "source": [ 124 | "from typing import List\n", 125 | "import femr.pat_utils\n", 126 | "import meds_reader\n", 127 | "import meds\n", 128 | "import femr.labelers\n", 129 | "\n", 130 | "\n", 131 | "class IsMaleLabeler(femr.labelers.Labeler):\n", 132 | " # Dummy labeler to predict gender at birth\n", 133 | " \n", 134 | " def label(self, subject: meds_reader.Subject) -> List[meds.Label]:\n", 135 | " is_male = any('Gender/M' == event.code for event in subject.events)\n", 136 | " return [{\n", 137 | " 'subject_id': subject.subject_id, \n", 138 | " 'prediction_time': subject.events[-1].time,\n", 139 | " 'boolean_value': is_male,\n", 140 | " }]\n", 141 | " \n", 142 | "database = meds_reader.SubjectDatabase(\"input/synthetic_meds\")\n", 143 | "\n", 144 | "labeler = IsMaleLabeler()\n", 145 | "labeled_subjects = labeler.apply(database)\n", 146 | "\n", 147 | "\n", 148 | "print(labeled_subjects)\n", 149 | "\n" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 13, 155 | "id": "20bd7859", 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "# We can save these to a parquet\n", 160 | "\n", 161 | "labeled_subjects.to_parquet(\"trash/tutorial_2/labels.parquet\", index=False)" 162 | ] 163 | } 164 | ], 165 | "metadata": { 166 | "kernelspec": { 167 | "display_name": "Python 3 (ipykernel)", 168 | "language": "python", 169 | "name": "python3" 170 | }, 171 | "language_info": { 172 | "codemirror_mode": { 173 | "name": "ipython", 174 | "version": 3 175 | }, 176 | "file_extension": ".py", 177 | "mimetype": "text/x-python", 178 | "name": "python", 179 | "nbconvert_exporter": "python", 180 | "pygments_lexer": "ipython3", 181 | "version": "3.13.3" 182 | } 183 | }, 184 | "nbformat": 4, 185 | "nbformat_minor": 5 186 | } 187 | -------------------------------------------------------------------------------- /tools/stanford/download_bigquery.py: -------------------------------------------------------------------------------- 1 | """ 2 | A tool for downloading datasets from BigQuery. 3 | 4 | Setup: 5 | ``` 6 | pip install --upgrade google-cloud-bigquery 7 | pip install --upgrade google-cloud-storage 8 | conda install google-cloud-sdk -c conda-forge 9 | ``` 10 | 11 | Note: After installing above packages, run `gcloud auth application-default login` on your terminal. 12 | You will be prompted with a authorization link that you will need to follow and approve using your 13 | email address. Then you will copy-paste authorization code to the terminal. 14 | 15 | How to run: 16 | ``` 17 | python download_bigquery.py \ 18 | 19 | --excluded_tables <(Optional) NAME OF TABLE 1 TO BE IGNORED> <(Optional) NAME OF TABLE 2 TO BE IGNORED> 20 | ``` 21 | 22 | Example: python download_bigquery.py som-nero-nigam-starr \ 23 | som-rit-phi-starr-prod.starr_omop_cdm5_deid_1pcent_lite_2023_02_08 . 
24 | """ 25 | 26 | from __future__ import annotations 27 | 28 | import argparse 29 | import hashlib 30 | import os 31 | import random 32 | import threading 33 | from functools import partial 34 | 35 | import google 36 | from google.cloud import bigquery, storage 37 | 38 | if __name__ == "__main__": 39 | parser = argparse.ArgumentParser(description="Download a Google BigQuery dataset") 40 | parser.add_argument( 41 | "gcp_project_name", 42 | type=str, 43 | help=( 44 | "The name of *YOUR* GCP project (e.g. 'som-nero-nigam-starr')." 45 | " Note that this need NOT be the GCP project that contains the dataset." 46 | " It just needs to be a GCP project where you have Bucket creation + BigQuery creation permissions." 47 | ), 48 | ) 49 | parser.add_argument( 50 | "gcp_dataset_id", 51 | type=str, 52 | help=( 53 | "The Dataset ID of the GCP dataset to download" 54 | " (e.g. 'som-rit-phi-starr-prod.starr_omop_cdm5_deid_2022_12_03')." 55 | " Note that this is the full ID of the dataset (project name + dataset name)" 56 | ), 57 | ) 58 | parser.add_argument( 59 | "output_dir", 60 | type=str, 61 | help=( 62 | "Path to output directory. Note: The downloaded files will be saved in a subdirectory of this," 63 | " i.e. `output_dir/gcp_dataset_id/...`" 64 | ), 65 | ) 66 | parser.add_argument( 67 | "--excluded_tables", 68 | type=str, 69 | nargs="*", # 0 or more values expected => creates a list 70 | default=[], 71 | help=( 72 | "Optional. Name(s) of tables to exclude. List tables separated by spaces," 73 | " i.e. `--excluded_tables observation note_nlp`" 74 | ), 75 | ) 76 | parser.add_argument( 77 | "--scratch_bucket_postfix", 78 | type=str, 79 | default="_extract_scratch", 80 | help="The postfix for the GCP bucket used for storing temporary files while downloading.", 81 | ) 82 | args = parser.parse_args() 83 | 84 | target = f"{args.output_dir}/{args.gcp_dataset_id}" 85 | os.mkdir(target) 86 | 87 | print('Make sure to run "gcloud auth application-default login" before running this command') 88 | 89 | # Connect to our BigQuery project 90 | client = bigquery.Client(project=args.gcp_project_name) 91 | storage_client = storage.Client(project=args.gcp_project_name) 92 | 93 | random_dir = hashlib.md5(random.randbytes(16)).hexdigest() 94 | 95 | scratch_bucket_name = args.gcp_project_name.replace("-", "_") + args.scratch_bucket_postfix 96 | 97 | print(f"Storing temporary files in gs://{scratch_bucket_name}/{random_dir}") 98 | 99 | try: 100 | bucket = storage_client.get_bucket(scratch_bucket_name) 101 | except google.api_core.exceptions.NotFound as e: 102 | print(f"Could not find the requested bucket? 
gs://{scratch_bucket_name} in project {args.gcp_project_name}") 103 | raise e 104 | 105 | # Get list of all tables in this GCP dataset 106 | # NOTE: the `HTTPIterator` can be iterated over like a list, but only once (it's a generator) 107 | tables: google.api_core.page_iterator.HTTPIterator = client.list_tables(args.gcp_dataset_id) 108 | print(f"Downloading dataset {args.gcp_dataset_id} using your project {args.gcp_project_name}") 109 | 110 | # Use GZIP compression and export as CVSs 111 | extract_config = bigquery.job.ExtractJobConfig( 112 | compression=bigquery.job.Compression.GZIP, 113 | destination_format=bigquery.job.DestinationFormat.CSV, 114 | field_delimiter=",", 115 | ) 116 | 117 | sem = threading.Semaphore(value=0) # needed for keeping track of how many tables have been downloaded 118 | 119 | def download(table_id: str, f): 120 | """Download the results (a set of .csv.gz's) of the BigQuery extract job to our local filesystem 121 | Note that a single table will be extracted into possibly dozens of smaller .csv.gz files 122 | 123 | Args: 124 | table_id (str): Name of table (e.g. "attribute_definition") 125 | """ 126 | if f.errors is not None: 127 | print("Could not extract, got errors", f.errors, "for", table_id) 128 | os.abort() 129 | sem.release() 130 | 131 | n_tables: int = 0 132 | for table in tables: 133 | # Get the full name of the table 134 | table_name: str = f"{table.project}.{table.dataset_id}.{table.table_id}" 135 | if table.table_id in args.excluded_tables: 136 | print(f"Skipping extraction | table = {table.table_id}") 137 | continue 138 | print(f"Extracting | table = {table.table_id}") 139 | # Create Google Cloud Storage bucket to extract this table into 140 | bucket_target_path: str = f"gs://{scratch_bucket_name}/{random_dir}/{table.table_id}/*.csv.gz" 141 | extract_job = client.extract_table(table.reference, bucket_target_path, job_config=extract_config) 142 | # Call the `download()` function asynchronously to download the bucket contents to our local filesystem 143 | extract_job.add_done_callback(partial(download, table.table_id)) 144 | n_tables += 1 145 | 146 | print(f"\n** Extracting a total of {n_tables} tables**\n") 147 | for i in range(1, n_tables + 1): 148 | sem.acquire() 149 | print(f"====> Finished extracting {i} out of {n_tables} tables") 150 | 151 | print("Starting to download tables") 152 | 153 | os.system(f"gsutil -m rsync -r gs://{scratch_bucket_name}/{random_dir} {target}") 154 | 155 | print("------\n------") 156 | print("Successfully downloaded all tables!") 157 | print("------\n------") 158 | 159 | # Delete the temporary Google Cloud Storage bucket 160 | print("\nDeleting temporary files...") 161 | os.system(f"gsutil -m rm -r gs://{scratch_bucket_name}/{random_dir}") 162 | print("------\n------") 163 | print("Successfully deleted temporary Google Cloud Storage files!") 164 | print("------\n------") 165 | -------------------------------------------------------------------------------- /tests/models/test_batch_creator.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import meds 4 | from femr_test_tools import create_subjects_dataset 5 | 6 | import femr.models.processor 7 | import femr.models.tasks 8 | import femr.models.tokenizer 9 | 10 | class DummyTokenizer(femr.models.tokenizer.HierarchicalTokenizer): 11 | def __init__(self, is_hierarchical: bool = True): 12 | self.is_hierarchical = is_hierarchical 13 | self.ontology = None 14 | self.vocab_size = 100 15 | 16 | def start_subject(self): 
17 | pass 18 | 19 | def get_feature_codes(self, event): 20 | if event.code == meds.birth_code: 21 | return [1], [1] 22 | else: 23 | return [int(event.code)], [1] 24 | 25 | 26 | def get_time_data(self, age: datetime.timedelta, delta: datetime.timedelta) -> float: 27 | return [1,1,1,1] 28 | 29 | def normalize_age(self, age): 30 | return 0.5 31 | 32 | 33 | def assert_two_batches_equal_third(batch1, batch2, batch3): 34 | """This asserts that batch1 + batch2 = batchs3""" 35 | assert batch3["subject_ids"].tolist() == batch1["subject_ids"].tolist() + batch2["subject_ids"].tolist() 36 | 37 | batch3['transformer']['ages'][len(batch1["transformer"]["ages"])] = 0 38 | batch3['transformer']['timestamps'][len(batch1["transformer"]["ages"])] = batch3['transformer']['timestamps'][0] 39 | 40 | assert ( 41 | batch3["transformer"]["ages"].tolist() 42 | == batch1["transformer"]["ages"].tolist() + batch2["transformer"]["ages"].tolist() 43 | ) 44 | assert ( 45 | batch3["transformer"]["timestamps"].tolist() 46 | == batch1["transformer"]["timestamps"].tolist() + batch2["transformer"]["timestamps"].tolist() 47 | ) 48 | 49 | # Checking the label indices is a bit more involved as we have to map to age/subject id and then check that 50 | target_label_ages = [] 51 | target_label_subject_ids = [] 52 | 53 | for label_index in batch1["transformer"]["label_indices"].tolist(): 54 | target_label_ages.append(batch1["transformer"]["ages"][label_index]) 55 | target_label_subject_ids.append(batch1["subject_ids"][label_index]) 56 | 57 | for label_index in batch2["transformer"]["label_indices"].tolist(): 58 | target_label_ages.append(batch2["transformer"]["ages"][label_index]) 59 | target_label_subject_ids.append(batch2["subject_ids"][label_index]) 60 | 61 | actual_label_ages = [] 62 | actual_label_subject_ids = [] 63 | 64 | for label_index in batch3["transformer"]["label_indices"].tolist(): 65 | actual_label_ages.append(batch3["transformer"]["ages"][label_index]) 66 | actual_label_subject_ids.append(batch3["subject_ids"][label_index]) 67 | 68 | assert target_label_ages == actual_label_ages 69 | assert target_label_subject_ids == actual_label_subject_ids 70 | 71 | batch3['transformer']['hierarchical_tokens'][len(batch1["transformer"]["hierarchical_tokens"])] = batch3['transformer']['hierarchical_tokens'][0] 72 | 73 | assert ( 74 | batch3["transformer"]["hierarchical_tokens"].tolist() 75 | == batch1["transformer"]["hierarchical_tokens"].tolist() + batch2["transformer"]["hierarchical_tokens"].tolist() 76 | ) 77 | 78 | def test_two_subjects_concat_no_task(): 79 | tokenizer = DummyTokenizer() 80 | 81 | fake_subjects = create_subjects_dataset(10) 82 | 83 | fake_subject1 = fake_subjects[1] 84 | fake_subject2 = fake_subjects[5] 85 | 86 | creator = femr.models.processor.BatchCreator(tokenizer) 87 | 88 | creator.start_batch() 89 | creator.add_subject(fake_subject1) 90 | 91 | data_for_subject1 = creator.get_batch_data() 92 | 93 | creator.start_batch() 94 | creator.add_subject(fake_subject2) 95 | 96 | data_for_subject2 = creator.get_batch_data() 97 | 98 | creator.start_batch() 99 | creator.add_subject(fake_subject1) 100 | creator.add_subject(fake_subject2) 101 | 102 | data_for_subjects = creator.get_batch_data() 103 | 104 | assert_two_batches_equal_third(data_for_subject1, data_for_subject2, data_for_subjects) 105 | 106 | 107 | def test_split_subjects_concat_no_task(): 108 | tokenizer = DummyTokenizer() 109 | 110 | fake_subjects = create_subjects_dataset(10) 111 | 112 | fake_subject = fake_subjects[1] 113 | 114 | creator = 
femr.models.processor.BatchCreator(tokenizer) 115 | 116 | creator.start_batch() 117 | creator.add_subject(fake_subject) 118 | 119 | data_for_subject = creator.get_batch_data() 120 | 121 | length = len(data_for_subject["transformer"]["timestamps"]) 122 | 123 | creator.start_batch() 124 | creator.add_subject(fake_subject, offset=0, max_length=length // 2) 125 | 126 | data_for_part1 = creator.get_batch_data() 127 | 128 | creator.start_batch() 129 | creator.add_subject(fake_subject, offset=length // 2, max_length=None) 130 | 131 | data_for_part2 = creator.get_batch_data() 132 | 133 | assert_two_batches_equal_third(data_for_part1, data_for_part2, data_for_subject) 134 | 135 | 136 | def test_two_subjects_concat_task(): 137 | tokenizer = DummyTokenizer() 138 | 139 | fake_subjects = create_subjects_dataset(10) 140 | 141 | labels = [ 142 | {"subject_id": 1, "prediction_time": datetime.datetime(2011, 7, 6)}, 143 | {"subject_id": 1, "prediction_time": datetime.datetime(2017, 1, 1)}, 144 | {"subject_id": 5, "prediction_time": datetime.datetime(2011, 11, 6)}, 145 | {"subject_id": 5, "prediction_time": datetime.datetime(2017, 2, 1)}, 146 | ] 147 | labels = [meds.Label(**label) for label in labels] 148 | 149 | task = femr.models.tasks.LabeledSubjectTask(labels) 150 | 151 | fake_subject1 = fake_subjects[1] 152 | fake_subject2 = fake_subjects[5] 153 | 154 | creator = femr.models.processor.BatchCreator(tokenizer, task=task) 155 | 156 | creator.start_batch() 157 | creator.add_subject(fake_subject1) 158 | 159 | data_for_subject1 = creator.get_batch_data() 160 | 161 | creator.start_batch() 162 | creator.add_subject(fake_subject2) 163 | 164 | data_for_subject2 = creator.get_batch_data() 165 | 166 | creator.start_batch() 167 | creator.add_subject(fake_subject1) 168 | creator.add_subject(fake_subject2) 169 | 170 | data_for_subjects = creator.get_batch_data() 171 | 172 | assert_two_batches_equal_third(data_for_subject1, data_for_subject2, data_for_subjects) 173 | 174 | 175 | def test_split_subjects_concat_task(): 176 | tokenizer = DummyTokenizer() 177 | 178 | fake_subjects = create_subjects_dataset(10) 179 | 180 | fake_subject = fake_subjects[1] 181 | 182 | task = femr.models.tasks.LabeledSubjectTask( 183 | [ 184 | {"subject_id": 1, "prediction_time": datetime.datetime(2010, 8, 6)}, 185 | {"subject_id": 1, "prediction_time": datetime.datetime(2017, 1, 1)}, 186 | ] 187 | ) 188 | 189 | creator = femr.models.processor.BatchCreator(tokenizer, task=task) 190 | 191 | creator.start_batch() 192 | creator.add_subject(fake_subject) 193 | 194 | data_for_subject = creator.get_batch_data() 195 | 196 | length = len(data_for_subject["transformer"]["timestamps"]) 197 | 198 | creator.start_batch() 199 | creator.add_subject(fake_subject, offset=0, max_length=length // 2) 200 | 201 | data_for_part1 = creator.get_batch_data() 202 | 203 | creator.start_batch() 204 | creator.add_subject(fake_subject, offset=length // 2, max_length=None) 205 | 206 | data_for_part2 = creator.get_batch_data() 207 | 208 | assert_two_batches_equal_third(data_for_part1, data_for_part2, data_for_subject) 209 | -------------------------------------------------------------------------------- /src/femr/ontology.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import collections 4 | import os 5 | from typing import Any, Dict, Iterable, Iterator, Optional, Set 6 | 7 | import meds_reader 8 | import polars as pl 9 | 10 | 11 | def _get_all_codes_map(subjects: 
Iterator[meds_reader.Subject]) -> Set[str]: 12 | result = set() 13 | for subject in subjects: 14 | for event in subject.events: 15 | result.add(event.code) 16 | return result 17 | 18 | 19 | class Ontology: 20 | def __init__(self, athena_path: str, code_metadata_path: Optional[str] = None): 21 | """Create an Ontology from an Athena download and an optional meds Code Metadata structure. 22 | 23 | NOTE: This is an expensive operation. 24 | It is recommended to create an ontology once and then save/load it as necessary. 25 | """ 26 | # Load from code metadata 27 | self.description_map: Dict[str, str] = {} 28 | self.parents_map: Dict[str, Set[str]] = collections.defaultdict(set) 29 | 30 | # Load from the athena path ... 31 | concept = pl.scan_csv(os.path.join(athena_path, "CONCEPT.csv"), separator="\t", infer_schema_length=0, quote_char=None) 32 | code_col = pl.col("vocabulary_id") + "/" + pl.col("concept_code") 33 | description_col = pl.col("concept_name") 34 | concept_id_col = pl.col("concept_id").cast(pl.Int64) 35 | 36 | processed_concepts = ( 37 | concept.select(code_col, concept_id_col, description_col, pl.col("standard_concept").is_null()) 38 | .collect() 39 | .rows() 40 | ) 41 | 42 | concept_id_to_code_map = {} 43 | 44 | non_standard_concepts = set() 45 | 46 | for code, concept_id, description, is_non_standard in processed_concepts: 47 | concept_id_to_code_map[concept_id] = code 48 | 49 | # We don't want to override code metadata 50 | if code not in self.description_map: 51 | self.description_map[code] = description 52 | 53 | if is_non_standard: 54 | non_standard_concepts.add(concept_id) 55 | 56 | relationship = pl.scan_csv( 57 | os.path.join(athena_path, "CONCEPT_RELATIONSHIP.csv"), separator="\t", infer_schema_length=0 58 | ) 59 | relationship_id = pl.col("relationship_id") 60 | relationship = relationship.filter( 61 | relationship_id == "Maps to", pl.col("concept_id_1") != pl.col("concept_id_2") 62 | ) 63 | for concept_id_1, concept_id_2 in ( 64 | relationship.select(pl.col("concept_id_1").cast(pl.Int64), pl.col("concept_id_2").cast(pl.Int64)) 65 | .collect() 66 | .rows() 67 | ): 68 | if concept_id_1 in non_standard_concepts: 69 | self.parents_map[concept_id_to_code_map[concept_id_1]].add(concept_id_to_code_map[concept_id_2]) 70 | 71 | ancestor = pl.scan_csv(os.path.join(athena_path, "CONCEPT_ANCESTOR.csv"), separator="\t", infer_schema_length=0) 72 | ancestor = ancestor.filter(pl.col("min_levels_of_separation") == "1") 73 | for concept_id, parent_concept_id in ( 74 | ancestor.select( 75 | pl.col("descendant_concept_id").cast(pl.Int64), pl.col("ancestor_concept_id").cast(pl.Int64) 76 | ) 77 | .collect() 78 | .rows() 79 | ): 80 | self.parents_map[concept_id_to_code_map[concept_id]].add(concept_id_to_code_map[parent_concept_id]) 81 | 82 | if code_metadata_path is not None: 83 | code_metadata = pl.scan_parquet(code_metadata_path) 84 | code_metadat_items = ( 85 | code_metadata.select(pl.col("code"), pl.col("description"), pl.col("parent_codes")).collect().to_dicts() 86 | ) 87 | 88 | # Have to add after OMOP to overwrite ... 
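            # Each element here is a MEDS code-metadata record; the loop below only reads
            # its "code", "description", and "parent_codes" fields. An illustrative row
            # (mirroring the structure used in tests/test_ontology.py):
            #     {"code": "CUSTOM/CustomDiabetes",
            #      "description": "A nice diabetes code",
            #      "parent_codes": ["ICD10CM/E11.3211"]}
            # Descriptions and parents supplied this way override the Athena-derived ones.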
89 | for code_info in code_metadat_items: 90 | code = code_info.get("code") 91 | if code is not None: 92 | if code_info.get("description") is not None: 93 | self.description_map[code] = code_info["description"] 94 | if code_info.get("parent_codes") is not None: 95 | self.parents_map[code] = set(i for i in code_info["parent_codes"] if i is not None) 96 | 97 | self.children_map = collections.defaultdict(set) 98 | for code, parents in self.parents_map.items(): 99 | for parent in parents: 100 | self.children_map[parent].add(code) 101 | 102 | self.all_parents_map: Dict[str, Set[str]] = {} 103 | self.all_children_map: Dict[str, Set[str]] = {} 104 | 105 | def prune_to_dataset( 106 | self, 107 | data_pool: meds_reader.SubjectDatabase, 108 | prune_all_descriptions: bool = False, 109 | remove_ontologies: Set[str] = set(), 110 | ) -> None: 111 | valid_codes = set() 112 | for chunk_codes in data_pool.map(_get_all_codes_map): 113 | valid_codes |= chunk_codes 114 | 115 | if prune_all_descriptions: 116 | self.description_map = {} 117 | 118 | all_parents = set() 119 | 120 | for code in valid_codes: 121 | all_parents |= self.get_all_parents(code) 122 | 123 | def is_valid(code): 124 | ontology = code.split("/")[0] 125 | return (code in valid_codes) or ((ontology not in remove_ontologies) and (code in all_parents)) 126 | 127 | codes = self.children_map.keys() | self.parents_map.keys() | self.description_map.keys() 128 | for code in codes: 129 | m: Any 130 | if is_valid(code): 131 | for m in (self.children_map, self.parents_map): 132 | m[code] = {a for a in m[code] if is_valid(a)} 133 | else: 134 | for m in (self.children_map, self.parents_map, self.description_map): 135 | if code in m: 136 | del m[code] 137 | 138 | self.all_parents_map = {} 139 | self.all_children_map = {} 140 | 141 | # Prime the pump 142 | for code in self.children_map.keys() | self.parents_map.keys(): 143 | self.get_all_parents(code) 144 | 145 | def get_description(self, code: str) -> Optional[str]: 146 | """Get a description of a code.""" 147 | return self.description_map.get(code) 148 | 149 | def get_children(self, code: str) -> Iterable[str]: 150 | """Get the children for a given code.""" 151 | return self.children_map.get(code, set()) 152 | 153 | def get_parents(self, code: str) -> Iterable[str]: 154 | """Get the parents for a given code.""" 155 | return self.parents_map.get(code, set()) 156 | 157 | def get_all_children(self, code: str) -> Set[str]: 158 | """Get all children, including through the ontology.""" 159 | if code not in self.all_children_map: 160 | result = {code} 161 | for child in self.children_map.get(code, set()): 162 | result |= self.get_all_children(child) 163 | self.all_children_map[code] = result 164 | return self.all_children_map[code] 165 | 166 | def get_all_children_for_codes(self, codes: Set[str]) -> Set[str]: 167 | result = set() 168 | for code in codes: 169 | result |= self.get_all_children(code) 170 | return result 171 | 172 | def get_all_parents(self, code: str) -> Set[str]: 173 | """Get all parents, including through the ontology.""" 174 | if code not in self.all_parents_map: 175 | result = {code} 176 | for parent in self.parents_map.get(code, set()): 177 | result |= self.get_all_parents(parent) 178 | self.all_parents_map[code] = result 179 | 180 | return self.all_parents_map[code] 181 | 182 | def get_all_parents_for_codes(self, codes: Set[str]) -> Set[str]: 183 | result = set() 184 | for code in codes: 185 | result |= self.get_all_parents(code) 186 | return result 187 | 
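# Minimal usage sketch for the Ontology class defined above (paths are placeholders;
# see tutorials/1_Ontology.ipynb for the full walkthrough):
#
#     import pickle
#     import meds_reader
#     import femr.ontology
#
#     # Building from an Athena download is expensive; do it once and cache it.
#     ontology = femr.ontology.Ontology("PATH_TO_ATHENA_DOWNLOAD")
#
#     # Optionally prune to the codes that actually appear in a dataset.
#     database = meds_reader.SubjectDatabase("PATH_TO_MEDS_READER_DATABASE")
#     ontology.prune_to_dataset(database)
#
#     # Cache the pruned ontology for later reuse.
#     with open("ontology.pkl", "wb") as f:
#         pickle.dump(ontology, f)
#
#     print(ontology.get_description("ATC/A02B"))
#     print(ontology.get_all_parents("ATC/A02B"))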
-------------------------------------------------------------------------------- /src/femr/transforms/stanford.py: -------------------------------------------------------------------------------- 1 | # mypy: disable-error-code="attr-defined" 2 | 3 | """Transforms that are unique to STARR OMOP.""" 4 | 5 | import datetime 6 | from typing import Dict, Tuple 7 | 8 | import meds 9 | import meds_reader.transform 10 | 11 | 12 | def _move_date_to_end( 13 | d: datetime.datetime, 14 | ) -> datetime.datetime: 15 | if d.time() == datetime.time.min: 16 | return d + datetime.timedelta(days=1) - datetime.timedelta(minutes=1) 17 | else: 18 | return d 19 | 20 | 21 | def move_visit_start_to_first_event_start( 22 | subject: meds_reader.transform.MutableSubject, 23 | ) -> meds_reader.transform.MutableSubject: 24 | """Assign visit start times to equal start time of first event in visit 25 | 26 | This function assigns the start time associated with each visit to be 27 | the start time of the first event that (1) is associated with the visit 28 | (i.e., shares the same visit ID as the visit event), (2) is a non-visit 29 | event, and (3) occurs on the same day as the visit event. If the visit 30 | has no non-visit events or all events associated with the visit have 31 | the same start time as the visit event (e.g., events with a start time 32 | of midnight such as billing codes, assuming visit events also have a 33 | midnight start time) then the visit start time remains unchanged. 34 | Events that occur on days prior to the visit do not affect the visit 35 | start time. 36 | 37 | Note that not all visit start times are set to 12:00 AM in the raw data. 38 | STARR-OMOP currently uses the first available value out of (1) hospital 39 | admission time, (2) effective date datetime, and (3) effective date, in 40 | that order. In the OMOP DEID from 12/20/2022 about 10% of visits have 41 | a time that is not '00:00:00'. 
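    Illustrative example (times are invented): if a visit event starts at
    2020-01-01 00:00 and the earliest non-visit event sharing its visit ID that
    starts after it occurs at 2020-01-01 09:30, the visit start is moved to 09:30.
    If no such event exists, the visit start is left unchanged.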
42 | """ 43 | first_event_starts: Dict[int, datetime.datetime] = {} 44 | visit_starts: Dict[int, datetime.datetime] = {} 45 | 46 | # Find the stated start time for each visit 47 | for event in subject.events: 48 | if event.table == "visit": 49 | if event.visit_id in visit_starts and visit_starts[event.visit_id] != event.time: 50 | raise RuntimeError( 51 | f"Multiple visit events with visit ID {event.visit_id} " + f" for subject ID {subject.subject_id}" 52 | ) 53 | visit_starts[event.visit_id] = event.time 54 | 55 | # Find the minimum start time over all non-visit events associated with each visit 56 | for event in subject.events: 57 | if event.visit_id is not None: 58 | # Only trigger for non-visit events with start time after associated visit start 59 | # Note: ignores non-visit events starting same time as visit (i.e., at midnight) 60 | if event.visit_id in visit_starts and event.time > visit_starts[event.visit_id]: 61 | first_event_starts[event.visit_id] = min( 62 | event.time, 63 | first_event_starts.get(event.visit_id, event.time), 64 | ) 65 | 66 | # Assign visit start times to be same as first non-visit event with same visit ID 67 | for event in subject.events: 68 | if event.table == "visit": 69 | # Triggers if there is a non-visit event associated with the visit ID that has 70 | # start time strictly after the recorded visit start 71 | if event.visit_id in first_event_starts: 72 | event.time = first_event_starts[event.visit_id] 73 | 74 | if event.end is not None: 75 | # Reset the visit end to be ≥ the visit start 76 | event.end = max(event.time, event.end) 77 | 78 | subject.events.sort(key=lambda a: a.time) 79 | 80 | return subject 81 | 82 | 83 | def move_to_day_end(subject: meds_reader.transform.MutableSubject) -> meds_reader.transform.MutableSubject: 84 | """We assume that everything coded at midnight should actually be moved to the end of the day.""" 85 | for event in subject.events: 86 | event.time = _move_date_to_end(event.time) 87 | if event.end is not None: 88 | event.end = _move_date_to_end(event.end) 89 | event.end = max(event.end, event.time) 90 | 91 | subject.events.sort(key=lambda a: a.time) 92 | 93 | return subject 94 | 95 | 96 | def switch_to_icd10cm(subject: meds_reader.transform.MutableSubject) -> meds_reader.transform.MutableSubject: 97 | """Switch from ICD10 to ICD10CM.""" 98 | for event in subject.events: 99 | if event.code.startswith("ICD10/"): 100 | event.code = event.code.replace("ICD10/", "ICD10CM/", 1) 101 | 102 | return subject 103 | 104 | 105 | def move_pre_birth(subject: meds_reader.transform.MutableSubject) -> meds_reader.transform.MutableSubject: 106 | """Move all events to after the birth of a subject.""" 107 | birth_date = None 108 | for event in subject.events: 109 | if event.code == meds.birth_code: 110 | birth_date = event.time 111 | 112 | assert birth_date is not None 113 | 114 | new_events = [] 115 | for event in subject.events: 116 | if event.time < birth_date: 117 | delta = birth_date - event.time 118 | if delta > datetime.timedelta(days=30): 119 | continue 120 | 121 | event.time = birth_date 122 | 123 | if event.end is not None and event.end < birth_date: 124 | event.end = birth_date 125 | 126 | new_events.append(event) 127 | 128 | subject.events = new_events 129 | subject.events.sort(key=lambda a: a.time) 130 | 131 | return subject 132 | 133 | 134 | def move_billing_codes(subject: meds_reader.transform.MutableSubject) -> meds_reader.transform.MutableSubject: 135 | """Move billing codes to the end of each visit. 
136 | 137 | One issue with our OMOP extract is that billing codes are incorrectly assigned at the start of the visit. 138 | This class fixes that by assigning them to the end of the visit. 139 | """ 140 | end_visits: Dict[int, datetime.datetime] = {} # Map from visit ID to visit end time 141 | lowest_visit: Dict[Tuple[datetime.datetime, str], int] = {} # Map from code/start time pairs to visit ID 142 | 143 | # List of billing code tables based on the original Clarity queries used to form STRIDE 144 | billing_codes = [ 145 | "pat_enc_dx", 146 | "hsp_acct_dx_list", 147 | "arpb_transactions", 148 | ] 149 | 150 | all_billing_codes = {(prefix + "_" + billing_code) for billing_code in billing_codes for prefix in ["shc", "lpch"]} 151 | 152 | for event in subject.events: 153 | # For events that share the same code/start time, we find the lowest visit ID 154 | if event.clarity_table in all_billing_codes and event.visit_id is not None: 155 | key = (event.time, event.code) 156 | if key not in lowest_visit: 157 | lowest_visit[key] = event.visit_id 158 | else: 159 | lowest_visit[key] = min(lowest_visit[key], event.visit_id) 160 | 161 | if event.clarity_table in ("lpch_pat_enc", "shc_pat_enc"): 162 | if event.end is not None: 163 | if event.visit_id is None: 164 | # Every event with an end time should have a visit ID associated with it 165 | raise RuntimeError(f"Expected visit id for visit? {subject.subject_id} {event}") 166 | if end_visits.get(event.visit_id, event.end) != event.end: 167 | # Every event associated with a visit should have an end time that matches the visit end time 168 | # Also the end times of all events associated with a visit should have the same end time 169 | raise RuntimeError(f"Multiple end visits? {end_visits.get(event.visit_id)} {event}") 170 | end_visits[event.visit_id] = event.end 171 | 172 | for event in subject.events: 173 | if event.clarity_table in all_billing_codes: 174 | key = (event.time, event.code) 175 | if event.visit_id != lowest_visit.get(key, None): 176 | # Drop this event as we already have it, just with a different visit_id? 
177 | # We only keep the copy of the event associated with the lowest visit id 178 | # (Lowest visit id is arbitrary, no explicit connection to time) 179 | continue 180 | 181 | if event.visit_id is None: 182 | # This is a bad code (it has no associated visit_id), but 183 | # we would rather keep it than get rid of it 184 | continue 185 | 186 | end_visit = end_visits.get(event.visit_id) 187 | if end_visit is None: 188 | raise RuntimeError(f"Expected visit end for code {subject.subject_id} {event} {subject}") 189 | 190 | # The end time for an event should be no later than its associated visit end time 191 | if event.end is not None: 192 | event.end = max(event.end, end_visit) 193 | 194 | # The start time for an event should be no later than its associated visit end time 195 | event.time = max(event.time, end_visit) 196 | 197 | subject.events.sort(key=lambda a: a.time) 198 | 199 | return subject 200 | -------------------------------------------------------------------------------- /tutorials/5_MOTOR Featurization And Modeling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "55c59464-ef8f-4fb9-a0a6-ebb68475e2a9", 6 | "metadata": {}, 7 | "source": [ 8 | "# Using MOTOR to generate features and training models on those features\n", 9 | "\n", 10 | "We can use a trained MOTOR model to generate features and then use those features in a logistic regression model." 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "fe93d59d-f135-46f6-b0a7-2d75d9b18e6f", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import shutil\n", 21 | "import os\n", 22 | "\n", 23 | "TARGET_DIR = 'trash/tutorial_5'\n", 24 | "\n", 25 | "if os.path.exists(TARGET_DIR):\n", 26 | " shutil.rmtree(TARGET_DIR)\n", 27 | "\n", 28 | "os.mkdir(TARGET_DIR)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "id": "25d741a7-46a2-4760-a369-3efb01afd804", 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "name": "stderr", 39 | "output_type": "stream", 40 | "text": [ 41 | "/home/ethanid/envs/motor_nlp/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 42 | " from .autonotebook import tqdm as notebook_tqdm\n", 43 | "Some weights of the model checkpoint at input/motor_model were not used when initializing FEMRModel: ['task_model.final_layer.bias', 'task_model.final_layer.weight', 'task_model.norm.weight', 'task_model.task_layer.bias', 'task_model.task_layer.weight', 'task_model.task_time_bias']\n", 44 | "- This IS expected if you are initializing FEMRModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 45 | "- This IS NOT expected if you are initializing FEMRModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" 46 | ] 47 | }, 48 | { 49 | "name": "stdout", 50 | "output_type": "stream", 51 | "text": [ 52 | "Got batches 18\n" 53 | ] 54 | }, 55 | { 56 | "name": "stderr", 57 | "output_type": "stream", 58 | "text": [ 59 | "Generating train split: 18 examples [00:00, 923.66 examples/s]\n", 60 | "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 138.96it/s]" 61 | ] 62 | }, 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "subject_ids (200,)\n", 68 | "feature_times (200,)\n", 69 | "features (200, 64)\n" 70 | ] 71 | }, 72 | { 73 | "name": "stderr", 74 | "output_type": "stream", 75 | "text": [ 76 | "\n" 77 | ] 78 | } 79 | ], 80 | "source": [ 81 | "import femr.models.transformer\n", 82 | "import pandas as pd\n", 83 | "import meds_reader\n", 84 | "import pickle\n", 85 | "\n", 86 | "# First, we compute our features\n", 87 | "\n", 88 | "# Load some labels\n", 89 | "labels = pd.read_parquet('input/labels.parquet')\n", 90 | "\n", 91 | "# Load our data\n", 92 | "database = meds_reader.SubjectDatabase('input/synthetic_meds')\n", 93 | "\n", 94 | "# We need an ontology for MOTOR\n", 95 | "with open('input/ontology.pkl', 'rb') as f:\n", 96 | " ontology = pickle.load(f)\n", 97 | "\n", 98 | "features = femr.models.transformer.compute_features(database, 'input/motor_model', labels=list(labels.itertuples()), num_proc=4, tokens_per_batch=128, ontology=ontology)\n", 99 | "\n", 100 | "# We have our features\n", 101 | "for k, v in features.items():\n", 102 | " print(k, v.shape)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "id": "4c5c75a9", 108 | "metadata": {}, 109 | "source": [ 110 | "# Joining features and labels\n", 111 | "\n", 112 | "Given a feature set, it's important to be able to join a set of labels to those features.\n", 113 | "\n", 114 | "This can be done with femr.featurizers.join_labels" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 6, 120 | "id": "9ad882f7", 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "name": "stdout", 125 | "output_type": "stream", 126 | "text": [ 127 | "boolean_values (200,)\n", 128 | "subject_ids (200,)\n", 129 | "times (200,)\n", 130 | "features (200, 64)\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "import femr.featurizers\n", 136 | "\n", 137 | "features_and_labels = femr.featurizers.join_labels(features, labels)\n", 138 | "\n", 139 | "for k, v in features_and_labels.items():\n", 140 | " print(k, v.shape)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "id": "7192ccc8", 146 | "metadata": {}, 147 | "source": [ 148 | "# Data Splitting\n", 149 | "\n", 150 | "When using a pretrained CLMBR model, we have to be very careful to use the splits used for the original model" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 7, 156 | "id": "b5c49417", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "import femr.splits\n", 161 | "import numpy as np\n", 162 | "\n", 163 | "# We split into a global training and test set\n", 164 | "split = femr.splits.SubjectSplit.load_from_csv('input/motor_model/main_split.csv')\n", 165 | "\n", 
166 | "train_mask = np.isin(features_and_labels['subject_ids'], split.train_subject_ids)\n", 167 | "test_mask = np.isin(features_and_labels['subject_ids'], split.test_subject_ids)\n", 168 | "\n", 169 | "percent_train = .70\n", 170 | "X_train, y_train = (\n", 171 | " features_and_labels['features'][train_mask],\n", 172 | " features_and_labels['boolean_values'][train_mask],\n", 173 | ")\n", 174 | "X_test, y_test = (\n", 175 | " features_and_labels['features'][test_mask],\n", 176 | " features_and_labels['boolean_values'][test_mask],\n", 177 | ")" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "id": "8deca785", 183 | "metadata": {}, 184 | "source": [ 185 | "# Building Models\n", 186 | "\n", 187 | "The generated features can then be used to build your standard models. In this case we construct both logistic regression and XGBoost models and evaluate them.\n", 188 | "\n", 189 | "Performance is perfect since our task (predicting gender) is 100% determined by the features" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 8, 195 | "id": "bad5ad4f", 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "name": "stdout", 200 | "output_type": "stream", 201 | "text": [ 202 | "---- Logistic Regression ----\n", 203 | "Train:\n", 204 | "\tAUROC: 0.9182372505543237\n", 205 | "\tAPS: 0.9169251055618138\n", 206 | "\tAccuracy: 0.8235294117647058\n", 207 | "\tF1 Score: 0.8148148148148148\n", 208 | "Test:\n", 209 | "\tAUROC: 0.55\n", 210 | "\tAPS: 0.757244869423662\n", 211 | "\tAccuracy: 0.4666666666666667\n", 212 | "\tF1 Score: 0.5294117647058824\n" 213 | ] 214 | } 215 | ], 216 | "source": [ 217 | "import sklearn.linear_model\n", 218 | "import sklearn.metrics\n", 219 | "import sklearn.preprocessing\n", 220 | "\n", 221 | "def run_analysis(title: str, y_train, y_train_proba, y_test, y_test_proba):\n", 222 | " print(f\"---- {title} ----\")\n", 223 | " print(\"Train:\")\n", 224 | " print_metrics(y_train, y_train_proba)\n", 225 | " print(\"Test:\")\n", 226 | " print_metrics(y_test, y_test_proba)\n", 227 | "\n", 228 | "def print_metrics(y_true, y_proba):\n", 229 | " y_pred = y_proba > 0.5\n", 230 | " auroc = sklearn.metrics.roc_auc_score(y_true, y_proba)\n", 231 | " aps = sklearn.metrics.average_precision_score(y_true, y_proba)\n", 232 | " accuracy = sklearn.metrics.accuracy_score(y_true, y_pred)\n", 233 | " f1 = sklearn.metrics.f1_score(y_true, y_pred)\n", 234 | " print(\"\\tAUROC:\", auroc)\n", 235 | " print(\"\\tAPS:\", aps)\n", 236 | " print(\"\\tAccuracy:\", accuracy)\n", 237 | " print(\"\\tF1 Score:\", f1)\n", 238 | "\n", 239 | "\n", 240 | "model = sklearn.linear_model.LogisticRegressionCV(penalty=\"l2\", solver=\"liblinear\").fit(X_train, y_train)\n", 241 | "y_train_proba = model.predict_proba(X_train)[::, 1]\n", 242 | "y_test_proba = model.predict_proba(X_test)[::, 1]\n", 243 | "run_analysis(\"Logistic Regression\", y_train, y_train_proba, y_test, y_test_proba)" 244 | ] 245 | } 246 | ], 247 | "metadata": { 248 | "kernelspec": { 249 | "display_name": "Python 3 (ipykernel)", 250 | "language": "python", 251 | "name": "python3" 252 | }, 253 | "language_info": { 254 | "codemirror_mode": { 255 | "name": "ipython", 256 | "version": 3 257 | }, 258 | "file_extension": ".py", 259 | "mimetype": "text/x-python", 260 | "name": "python", 261 | "nbconvert_exporter": "python", 262 | "pygments_lexer": "ipython3", 263 | "version": "3.13.3" 264 | } 265 | }, 266 | "nbformat": 4, 267 | "nbformat_minor": 5 268 | } 269 | 
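The markdown cell above says that both a logistic regression and an XGBoost model are evaluated, but only the logistic regression is actually run. A minimal sketch of the matching XGBoost step, assuming the optional xgboost package is installed and reusing X_train, y_train, X_test, y_test, and run_analysis from the cells above (this mirrors the XGBoost cell in the count-featurization tutorial):

import xgboost as xgb

# Gradient-boosted trees consume the dense MOTOR features directly,
# so no scaling step is needed before fitting.
model = xgb.XGBClassifier()
model.fit(X_train, y_train)
y_train_proba = model.predict_proba(X_train)[:, 1]
y_test_proba = model.predict_proba(X_test)[:, 1]
run_analysis("XGBoost", y_train, y_train_proba, y_test, y_test_proba)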
-------------------------------------------------------------------------------- /src/femr/labelers/omop.py: -------------------------------------------------------------------------------- 1 | """meds.Labeling functions for OMOP data.""" 2 | 3 | from __future__ import annotations 4 | 5 | import datetime 6 | from typing import Any, Callable, List, Optional 7 | 8 | import meds 9 | import meds_reader 10 | 11 | import femr.ontology 12 | 13 | from .core import TimeHorizon, TimeHorizonEventLabeler 14 | 15 | 16 | def identity(x: Any) -> Any: 17 | return x 18 | 19 | 20 | def get_death_concepts() -> List[str]: 21 | return [ 22 | meds.death_code, 23 | ] 24 | 25 | 26 | def move_datetime_to_end_of_day(date: datetime.datetime) -> datetime.datetime: 27 | return date.replace(hour=23, minute=59, second=0) 28 | 29 | 30 | ########################################################## 31 | ########################################################## 32 | # Abstract classes derived from TimeHorizonEventLabeler 33 | ########################################################## 34 | ########################################################## 35 | 36 | 37 | class CodeLabeler(TimeHorizonEventLabeler): 38 | """Apply a label based on 1+ outcome_codes' occurrence(s) over a fixed time horizon.""" 39 | 40 | def __init__( 41 | self, 42 | outcome_codes: List[str], 43 | time_horizon: TimeHorizon, 44 | prediction_codes: Optional[List[str]] = None, 45 | prediction_time_adjustment_func: Callable = identity, 46 | ): 47 | """Create a CodeLabeler, which labels events whose index in your Ontology is in `self.outcome_codes` 48 | 49 | Args: 50 | prediction_codes (List[int]): Events that count as an occurrence of the outcome. 51 | time_horizon (TimeHorizon): An interval of time. If the event occurs during this time horizon, then 52 | the label is TRUE. Otherwise, FALSE. 53 | prediction_codes (Optional[List[int]]): If not None, limit events at which you make predictions to 54 | only events with an `event.code` in these codes. 55 | prediction_time_adjustment_func (Optional[Callable]). A function that takes in a `datetime.datetime` 56 | and returns a different `datetime.datetime`. Defaults to the identity function. 57 | """ 58 | self.outcome_codes: List[str] = outcome_codes 59 | self.time_horizon: TimeHorizon = time_horizon 60 | self.prediction_codes: Optional[List[str]] = prediction_codes 61 | self.prediction_time_adjustment_func: Callable = prediction_time_adjustment_func 62 | 63 | def get_prediction_times(self, subject: meds_reader.Subject) -> List[datetime.datetime]: 64 | """Return each event's start time (possibly modified by prediction_time_adjustment_func) 65 | as the time to make a prediction. 
Default to all events whose `code` is in `self.prediction_codes`.""" 66 | times: List[datetime.datetime] = [] 67 | last_time = None 68 | for e in subject.events: 69 | prediction_time: datetime.datetime = self.prediction_time_adjustment_func(e.time) 70 | 71 | if ((self.prediction_codes is None) or (e.code in self.prediction_codes)) and ( 72 | last_time != prediction_time 73 | ): 74 | times.append(prediction_time) 75 | last_time = prediction_time 76 | return times 77 | 78 | def get_time_horizon(self) -> TimeHorizon: 79 | return self.time_horizon 80 | 81 | def get_outcome_times(self, subject: meds_reader.Subject) -> List[datetime.datetime]: 82 | """Return the start times of this subject's events whose `code` is in `self.outcome_codes`.""" 83 | times: List[datetime.datetime] = [] 84 | for event in subject.events: 85 | if event.code in self.outcome_codes: 86 | times.append(event.time) 87 | return times 88 | 89 | def allow_same_time_labels(self) -> bool: 90 | # We cannot allow labels at the same time as the codes since they will generally be available as features ... 91 | return False 92 | 93 | 94 | class OMOPConceptCodeLabeler(CodeLabeler): 95 | """Same as CodeLabeler, but add the extra step of mapping OMOP concept IDs 96 | (stored in `omop_concept_ids`) to femr codes (stored in `codes`).""" 97 | 98 | # parent OMOP concept codes, from which all the outcome 99 | # are derived (as children from our ontology) 100 | original_omop_concept_codes: List[str] = [] 101 | 102 | def __init__( 103 | self, 104 | ontology: femr.ontology.Ontology, 105 | time_horizon: TimeHorizon, 106 | prediction_codes: Optional[List[str]] = None, 107 | prediction_time_adjustment_func: Callable = identity, 108 | ): 109 | outcome_codes: List[str] = [] 110 | for code in self.original_omop_concept_codes: 111 | outcome_codes.extend(ontology.get_all_children(code)) 112 | super().__init__( 113 | outcome_codes=outcome_codes, 114 | time_horizon=time_horizon, 115 | prediction_codes=prediction_codes, 116 | prediction_time_adjustment_func=prediction_time_adjustment_func, 117 | ) 118 | 119 | 120 | ########################################################## 121 | ########################################################## 122 | # meds.Labeling functions derived from CodeLabeler 123 | ########################################################## 124 | ########################################################## 125 | 126 | 127 | class MortalityCodeLabeler(CodeLabeler): 128 | """Apply a label for whether or not a subject dies within the `time_horizon`. 129 | Make prediction at admission time. 130 | """ 131 | 132 | def __init__( 133 | self, 134 | ontology: femr.ontology.Ontology, 135 | time_horizon: TimeHorizon, 136 | prediction_codes: Optional[List[str]] = None, 137 | prediction_time_adjustment_func: Callable = identity, 138 | ): 139 | """Create a Mortality labeler.""" 140 | outcome_codes: List[str] = [] 141 | for code in get_death_concepts(): 142 | outcome_codes.extend(ontology.get_all_children(code)) 143 | 144 | super().__init__( 145 | outcome_codes=outcome_codes, 146 | time_horizon=time_horizon, 147 | prediction_codes=prediction_codes, 148 | prediction_time_adjustment_func=prediction_time_adjustment_func, 149 | ) 150 | 151 | 152 | class LupusCodeLabeler(OMOPConceptCodeLabeler): 153 | """ 154 | meds.Label if subject is diagnosed with Lupus. 
155 | """ 156 | 157 | original_omop_concept_codes = ["SNOMED/55464009", "SNOMED/201436003"] 158 | 159 | 160 | ########################################################## 161 | ########################################################## 162 | # Labeling functions derived from OMOPConceptCodeLabeler 163 | ########################################################## 164 | ########################################################## 165 | 166 | 167 | class HypoglycemiaCodeLabeler(OMOPConceptCodeLabeler): 168 | """Apply a label for whether a subject has at 1+ explicitly 169 | coded occurrence(s) of Hypoglycemia in `time_horizon`.""" 170 | 171 | # fmt: off 172 | original_omop_concept_codes = [ 173 | 'SNOMED/267384006', 'SNOMED/421725003', 'SNOMED/719216001', 174 | 'SNOMED/302866003', 'SNOMED/237633009', 'SNOMED/120731000119103', 175 | 'SNOMED/190448007', 'SNOMED/230796005', 'SNOMED/421437000', 176 | 'SNOMED/52767006', 'SNOMED/237637005', 'SNOMED/84371000119108' 177 | ] 178 | # fmt: on 179 | 180 | 181 | class AKICodeLabeler(OMOPConceptCodeLabeler): 182 | """Apply a label for whether a subject has at 1+ explicitly 183 | coded occurrence(s) of AKI in `time_horizon`.""" 184 | 185 | # fmt: off 186 | original_omop_concept_codes = [ 187 | 'SNOMED/14669001', 'SNOMED/298015003', 'SNOMED/35455006', 188 | ] 189 | # fmt: on 190 | 191 | 192 | class AnemiaCodeLabeler(OMOPConceptCodeLabeler): 193 | """Apply a label for whether a subject has at 1+ explicitly 194 | coded occurrence(s) of Anemia in `time_horizon`.""" 195 | 196 | # fmt: off 197 | original_omop_concept_codes = [ 198 | 'SNOMED/271737000', 'SNOMED/713496008', 'SNOMED/713349004', 'SNOMED/767657005', 199 | 'SNOMED/111570005', 'SNOMED/691401000119104', 'SNOMED/691411000119101', 200 | ] 201 | # fmt: on 202 | 203 | 204 | class HyperkalemiaCodeLabeler(OMOPConceptCodeLabeler): 205 | """Apply a label for whether a subject has at 1+ explicitly 206 | coded occurrence(s) of Hyperkalemia in `time_horizon`.""" 207 | 208 | # fmt: off 209 | original_omop_concept_codes = [ 210 | 'SNOMED/14140009', 211 | ] 212 | # fmt: on 213 | 214 | 215 | class HyponatremiaCodeLabeler(OMOPConceptCodeLabeler): 216 | """Apply a label for whether a subject has at 1+ explicitly 217 | coded occurrence(s) of Hyponatremia in `time_horizon`.""" 218 | 219 | # fmt: off 220 | original_omop_concept_codes = [ 221 | 'SNOMED/267447008', 'SNOMED/89627008' 222 | ] 223 | # fmt: on 224 | 225 | 226 | class ThrombocytopeniaCodeLabeler(OMOPConceptCodeLabeler): 227 | """Apply a label for whether a subject has at 1+ explicitly 228 | coded occurrence(s) of Thrombocytopenia in `time_horizon`.""" 229 | 230 | # fmt: off 231 | original_omop_concept_codes = [ 232 | 'SNOMED/267447008', 'SNOMED/89627008', 233 | ] 234 | # fmt: on 235 | 236 | 237 | class NeutropeniaCodeLabeler(OMOPConceptCodeLabeler): 238 | """Apply a label for whether a subject has at 1+ explicitly 239 | coded occurrence(s) of Neutkropenia in `time_horizon`.""" 240 | 241 | # fmt: off 242 | original_omop_concept_codes = [ 243 | 'SNOMED/165517008', 244 | ] 245 | # fmt: on 246 | -------------------------------------------------------------------------------- /tutorials/3_Count Featurization And Modeling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "cf343c1c-ad8d-4fdb-a142-c501e579e288", 6 | "metadata": {}, 7 | "source": [ 8 | "# Count Featurization And Models\n", 9 | "\n", 10 | "FEMR contains several utilities to implement common tabular 
featurization strategies.\n", 11 | "\n", 12 | "[CountFeaturizer](https://github.com/som-shahlab/femr/blob/main/src/femr/featurizers/featurizers.py#L180) is the main class and it documents the various supported options.\n", 13 | "\n", 14 | "In order to use the featurizer, you must construct a featurizer list, prepare the featurizers, and then featurize." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "id": "892ab2d5-0c5a-43c9-a210-9201f775e4fb", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "import pickle\n", 25 | "import femr.featurizers\n", 26 | "import femr.labelers\n", 27 | "import meds\n", 28 | "import pandas as pd\n", 29 | "import meds_reader\n", 30 | "\n", 31 | "# Load some labels\n", 32 | "labels = pd.read_parquet('input/labels.parquet')\n", 33 | "\n", 34 | "# Load our data\n", 35 | "database = meds_reader.SubjectDatabase(\"input/synthetic_meds\")\n", 36 | " \n", 37 | "# Define our featurizer\n", 38 | "\n", 39 | "# Note that we are using both ages and counts here\n", 40 | "age = femr.featurizers.AgeFeaturizer(is_normalize=False)\n", 41 | "count = femr.featurizers.CountFeaturizer(string_value_combination=True)\n", 42 | "featurizer_age_count = femr.featurizers.FeaturizerList([age, count])\n", 43 | "\n", 44 | "# Preprocessing the featurizers, which includes processes such as normalizing age.\n", 45 | "featurizer_age_count.preprocess_featurizers(database, labels)\n", 46 | "\n", 47 | "# Actually do the featurization\n", 48 | "features = featurizer_age_count.featurize(database, labels)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "id": "112fe99d", 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "name": "stdout", 59 | "output_type": "stream", 60 | "text": [ 61 | "subject_ids (200,)\n", 62 | "feature_times (200,)\n", 63 | "features (200, 1884)\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "# Results consist of three components, the subject ids, feature times, and the features themselves\n", 69 | "\n", 70 | "for k, v in features.items():\n", 71 | " print(k, v.shape)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "id": "fafa8ea8", 77 | "metadata": {}, 78 | "source": [ 79 | "# Joining features and labels\n", 80 | "\n", 81 | "Given a feature set, it's important to be able to join a set of labels to those features.\n", 82 | "\n", 83 | "This can be done with femr.featurizers.join_labels" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 3, 89 | "id": "cd0f43fd", 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "boolean_values (200,)\n", 97 | "subject_ids (200,)\n", 98 | "times (200,)\n", 99 | "features (200, 1884)\n" 100 | ] 101 | } 102 | ], 103 | "source": [ 104 | "features_and_labels = femr.featurizers.join_labels(features, labels)\n", 105 | "\n", 106 | "for k, v in features_and_labels.items():\n", 107 | " print(k, v.shape)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "id": "66934476-c40a-467c-8702-b0d7021d92bf", 113 | "metadata": {}, 114 | "source": [ 115 | "# Data Splitting\n", 116 | "\n", 117 | "FEMR contains utilities for doing hash based subject splitting, where splits are determined based on a hash value of the subject id.\n", 118 | "\n", 119 | "This is a deterministic approximate approach for splitting that is both stable and scalable." 
120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 4, 125 | "id": "01acd922-668b-481b-8dbb-54ab6ae433af", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "import femr.splits\n", 130 | "import numpy as np\n", 131 | "\n", 132 | "# We split into a global training and test set\n", 133 | "split = femr.splits.generate_hash_split(set(features_and_labels['subject_ids']), seed=87, frac_test=0.3)\n", 134 | "\n", 135 | "train_mask = np.isin(features_and_labels['subject_ids'], split.train_subject_ids)\n", 136 | "test_mask = np.isin(features_and_labels['subject_ids'], split.test_subject_ids)\n", 137 | "\n", 138 | "percent_train = .70\n", 139 | "X_train, y_train = (\n", 140 | " features_and_labels['features'][train_mask],\n", 141 | " features_and_labels['boolean_values'][train_mask],\n", 142 | ")\n", 143 | "X_test, y_test = (\n", 144 | " features_and_labels['features'][test_mask],\n", 145 | " features_and_labels['boolean_values'][test_mask],\n", 146 | ")" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "id": "9aaeb7e5-eb48-46f5-ae59-9abfbc0dcef5", 152 | "metadata": {}, 153 | "source": [ 154 | "# Building Models\n", 155 | "\n", 156 | "The generated features can then be used to build your standard models. In this case we construct both logistic regression and XGBoost models and evaluate them.\n", 157 | "\n", 158 | "Performance is perfect since our task (predicting gender) is 100% determined by the features" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 5, 164 | "id": "caae3126-1437-408e-b25f-04568e15c96a", 165 | "metadata": {}, 166 | "outputs": [ 167 | { 168 | "ename": "ModuleNotFoundError", 169 | "evalue": "No module named 'xgboost'", 170 | "output_type": "error", 171 | "traceback": [ 172 | "\u001b[31m---------------------------------------------------------------------------\u001b[39m", 173 | "\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)", 174 | "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mxgboost\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mxgb\u001b[39;00m\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01msklearn\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mlinear_model\u001b[39;00m\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01msklearn\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmetrics\u001b[39;00m\n", 175 | "\u001b[31mModuleNotFoundError\u001b[39m: No module named 'xgboost'" 176 | ] 177 | } 178 | ], 179 | "source": [ 180 | "import xgboost as xgb\n", 181 | "import sklearn.linear_model\n", 182 | "import sklearn.metrics\n", 183 | "import sklearn.preprocessing\n", 184 | "\n", 185 | "def run_analysis(title: str, y_train, y_train_proba, y_test, y_test_proba):\n", 186 | " print(f\"---- {title} ----\")\n", 187 | " print(\"Train:\")\n", 188 | " print_metrics(y_train, y_train_proba)\n", 189 | " print(\"Test:\")\n", 190 | " print_metrics(y_test, y_test_proba)\n", 191 | "\n", 192 | "def print_metrics(y_true, y_proba):\n", 193 | " y_pred = y_proba > 0.5\n", 194 | " auroc = sklearn.metrics.roc_auc_score(y_true, y_proba)\n", 195 | " aps = sklearn.metrics.average_precision_score(y_true, y_proba)\n", 196 | " accuracy = 
sklearn.metrics.accuracy_score(y_true, y_pred)\n", 197 | " f1 = sklearn.metrics.f1_score(y_true, y_pred)\n", 198 | " print(\"\\tAUROC:\", auroc)\n", 199 | " print(\"\\tAPS:\", aps)\n", 200 | " print(\"\\tAccuracy:\", accuracy)\n", 201 | " print(\"\\tF1 Score:\", f1)\n", 202 | "\n", 203 | "\n", 204 | "scaler = sklearn.preprocessing.MaxAbsScaler().fit(\n", 205 | " X_train\n", 206 | ") # best for sparse data: see https://scikit-learn.org/stable/modules/preprocessing.html#scaling-sparse-data\n", 207 | "X_train_scaled = scaler.fit_transform(X_train)\n", 208 | "X_test_scaled = scaler.transform(X_test)\n", 209 | "model = sklearn.linear_model.LogisticRegressionCV(penalty=\"l2\", solver=\"liblinear\").fit(X_train_scaled, y_train)\n", 210 | "y_train_proba = model.predict_proba(X_train_scaled)[::, 1]\n", 211 | "y_test_proba = model.predict_proba(X_test_scaled)[::, 1]\n", 212 | "run_analysis(\"Logistic Regression\", y_train, y_train_proba, y_test, y_test_proba)\n", 213 | "\n", 214 | "\n", 215 | "# XGBoost\n", 216 | "model = xgb.XGBClassifier()\n", 217 | "model.fit(X_train, y_train)\n", 218 | "y_train_proba = model.predict_proba(X_train)[::, 1]\n", 219 | "y_test_proba = model.predict_proba(X_test)[::, 1]\n", 220 | "run_analysis(\"XGBoost\", y_train, y_train_proba, y_test, y_test_proba)" 221 | ] 222 | } 223 | ], 224 | "metadata": { 225 | "kernelspec": { 226 | "display_name": "Python 3 (ipykernel)", 227 | "language": "python", 228 | "name": "python3" 229 | }, 230 | "language_info": { 231 | "codemirror_mode": { 232 | "name": "ipython", 233 | "version": 3 234 | }, 235 | "file_extension": ".py", 236 | "mimetype": "text/x-python", 237 | "name": "python", 238 | "nbconvert_exporter": "python", 239 | "pygments_lexer": "ipython3", 240 | "version": "3.13.3" 241 | } 242 | }, 243 | "nbformat": 4, 244 | "nbformat_minor": 5 245 | } 246 | -------------------------------------------------------------------------------- /tests/test_transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import datetime 4 | 5 | import meds 6 | import meds_reader.transform 7 | from femr_test_tools import DummyEvent, DummySubject 8 | 9 | from femr.transforms import delta_encode, remove_nones 10 | from femr.transforms.stanford import ( 11 | move_billing_codes, 12 | move_pre_birth, 13 | move_to_day_end, 14 | move_visit_start_to_first_event_start, 15 | ) 16 | 17 | 18 | def test_pre_birth() -> None: 19 | subject = DummySubject( 20 | subject_id=123, 21 | events=[ 22 | DummyEvent(time=datetime.datetime(1999, 7, 2), code="1234"), 23 | DummyEvent(time=datetime.datetime(1999, 7, 9), code=meds.birth_code), 24 | DummyEvent(time=datetime.datetime(1999, 7, 11), code="12345"), 25 | ], 26 | ) 27 | 28 | expected = DummySubject( 29 | subject_id=123, 30 | events=[ 31 | DummyEvent(time=datetime.datetime(1999, 7, 9), code="1234"), 32 | DummyEvent(time=datetime.datetime(1999, 7, 9), code=meds.birth_code), 33 | DummyEvent(time=datetime.datetime(1999, 7, 11), code="12345"), 34 | ], 35 | ) 36 | 37 | assert move_pre_birth(subject) == expected 38 | 39 | 40 | def test_move_visit_start_ignores_other_visits() -> None: 41 | subject = DummySubject( 42 | subject_id=123, 43 | events=[ 44 | # A non-visit event with no explicit start time 45 | DummyEvent(time=datetime.datetime(1999, 7, 2), code="1234", visit_id=9999), 46 | # A visit event with just date specified 47 | DummyEvent( 48 | time=datetime.datetime(1999, 7, 2), 49 | code="4567", 50 | visit_id=9999, 51 | table="visit", 52 | ), 53 
| # A non-visit event from a separate visit ID 54 | DummyEvent( 55 | time=datetime.datetime(1999, 7, 2, 11), 56 | code="2345", 57 | visit_id=8888, 58 | ), 59 | # First recorded non-visit event for visit ID 9999 60 | DummyEvent( 61 | time=datetime.datetime(1999, 7, 2, 12), 62 | code="3456", 63 | visit_id=9999, 64 | ), 65 | ], 66 | ) 67 | 68 | # Note that events are implicitly sorted first by start time, then by code: 69 | # https://github.com/som-shahlab/femr/blob/main/src/femr/__init__.py#L69 70 | expected = DummySubject( 71 | subject_id=123, 72 | events=[ 73 | # A non-visit event with no explicit start time 74 | DummyEvent(time=datetime.datetime(1999, 7, 2), code="1234", visit_id=9999), 75 | # A non-visit event from a separate visit ID 76 | DummyEvent( 77 | time=datetime.datetime(1999, 7, 2, 11), 78 | code="2345", 79 | visit_id=8888, 80 | ), 81 | # A visit event with just date specified 82 | DummyEvent( 83 | time=datetime.datetime(1999, 7, 2, 12), 84 | code="4567", 85 | visit_id=9999, 86 | table="visit", 87 | ), 88 | # First recorded non-visit event for visit ID 9999 89 | DummyEvent( 90 | time=datetime.datetime(1999, 7, 2, 12), 91 | code="3456", 92 | visit_id=9999, 93 | ), 94 | ], 95 | ) 96 | 97 | assert move_visit_start_to_first_event_start(subject) == expected 98 | 99 | 100 | def test_move_visit_start_minute_after_midnight() -> None: 101 | subject = DummySubject( 102 | subject_id=123, 103 | events=[ 104 | DummyEvent(time=datetime.datetime(1999, 7, 2), code="3456", visit_id=9999, table="visit"), 105 | DummyEvent(time=datetime.datetime(1999, 7, 2), code="1234", visit_id=9999), 106 | DummyEvent(time=datetime.datetime(1999, 7, 2, 0, 1), code="2345", visit_id=9999), 107 | DummyEvent(time=datetime.datetime(1999, 7, 2, 2, 12), code="4567", visit_id=9999), 108 | ], 109 | ) 110 | 111 | expected = DummySubject( 112 | subject_id=123, 113 | events=[ 114 | DummyEvent(time=datetime.datetime(1999, 7, 2), code="1234", visit_id=9999), 115 | DummyEvent(time=datetime.datetime(1999, 7, 2, 0, 1), code="3456", visit_id=9999, table="visit"), 116 | DummyEvent(time=datetime.datetime(1999, 7, 2, 0, 1), code="2345", visit_id=9999), 117 | DummyEvent(time=datetime.datetime(1999, 7, 2, 2, 12), code="4567", visit_id=9999), 118 | ], 119 | ) 120 | 121 | assert move_visit_start_to_first_event_start(subject) == expected 122 | 123 | 124 | def test_move_visit_start_doesnt_move_without_event() -> None: 125 | subject = DummySubject( 126 | subject_id=123, 127 | events=[ 128 | DummyEvent(time=datetime.datetime(1999, 7, 2), code="1234", visit_id=9999), 129 | DummyEvent(time=datetime.datetime(1999, 7, 2), code="3456", visit_id=9999, table="visit"), 130 | DummyEvent(time=datetime.datetime(1999, 7, 2), code="2345", visit_id=9999), 131 | ], 132 | ) 133 | 134 | # None of the non-visit events have start time > '00:00:00' so visit event 135 | # start time is unchanged, though order changes based on code under resort. 
136 | 137 | assert move_visit_start_to_first_event_start(subject) == subject 138 | 139 | 140 | def test_move_to_day_end() -> None: 141 | subject = meds_reader.transform.MutableSubject( 142 | subject_id=123, 143 | events=[ 144 | meds_reader.transform.MutableEvent(time=datetime.datetime(1999, 7, 2), code="1234"), 145 | meds_reader.transform.MutableEvent(time=datetime.datetime(1999, 7, 2, 12), code="4321"), 146 | meds_reader.transform.MutableEvent(time=datetime.datetime(1999, 7, 9), code=meds.birth_code), 147 | ], 148 | ) 149 | 150 | expected = meds_reader.transform.MutableSubject( 151 | subject_id=123, 152 | events=[ 153 | meds_reader.transform.MutableEvent(time=datetime.datetime(1999, 7, 2, 12), code="4321"), 154 | meds_reader.transform.MutableEvent(time=datetime.datetime(1999, 7, 2, 23, 59), code="1234"), 155 | meds_reader.transform.MutableEvent(time=datetime.datetime(1999, 7, 9, 23, 59), code=meds.birth_code), 156 | ], 157 | ) 158 | 159 | assert move_to_day_end(subject) == expected 160 | 161 | 162 | def test_remove_nones() -> None: 163 | subject = DummySubject( 164 | subject_id=123, 165 | events=[ 166 | DummyEvent(time=datetime.datetime(1999, 7, 2), code="1234"), # No value, to be removed 167 | DummyEvent(time=datetime.datetime(1999, 7, 2, 12), code="1234", numeric_value=3), 168 | DummyEvent(time=datetime.datetime(1999, 7, 9), code=meds.birth_code), 169 | ], 170 | ) 171 | 172 | expected = DummySubject( 173 | subject_id=123, 174 | events=[ 175 | DummyEvent(time=datetime.datetime(1999, 7, 2, 12), code="1234", numeric_value=3), 176 | DummyEvent(time=datetime.datetime(1999, 7, 9), code=meds.birth_code), 177 | ], 178 | ) 179 | 180 | assert remove_nones(subject) == expected 181 | 182 | 183 | def test_delta_encode() -> None: 184 | subject = DummySubject( 185 | subject_id=123, 186 | events=[ 187 | DummyEvent(time=datetime.datetime(1999, 7, 2), code="1234"), 188 | DummyEvent(time=datetime.datetime(1999, 7, 2), code="1234"), 189 | DummyEvent(time=datetime.datetime(1999, 7, 2, 12), code="1234", numeric_value=3), 190 | DummyEvent(time=datetime.datetime(1999, 7, 2, 14), code="1234", numeric_value=3), 191 | DummyEvent(time=datetime.datetime(1999, 7, 2, 19), code="1234", numeric_value=5), 192 | DummyEvent(time=datetime.datetime(1999, 7, 2, 20), code="1234", numeric_value=3), 193 | ], 194 | ) 195 | 196 | expected = DummySubject( 197 | subject_id=123, 198 | events=[ 199 | DummyEvent(time=datetime.datetime(1999, 7, 2), code="1234"), 200 | DummyEvent(time=datetime.datetime(1999, 7, 2, 12), code="1234", numeric_value=3), 201 | DummyEvent(time=datetime.datetime(1999, 7, 2, 19), code="1234", numeric_value=5), 202 | DummyEvent(time=datetime.datetime(1999, 7, 2, 20), code="1234", numeric_value=3), 203 | ], 204 | ) 205 | 206 | assert delta_encode(subject) == expected 207 | 208 | 209 | def test_move_billing_codes() -> None: 210 | subject = DummySubject( 211 | subject_id=123, 212 | events=[ 213 | DummyEvent( 214 | time=datetime.datetime(1999, 7, 2, 0, 0), 215 | code=1234, 216 | visit_id=10, 217 | clarity_table="lpch_pat_enc", 218 | end=datetime.datetime(1999, 7, 20), 219 | ), 220 | DummyEvent( 221 | time=datetime.datetime(1999, 7, 9, 0, 0), 222 | code="SNOMED/184099003", 223 | visit_id=10, 224 | clarity_table="lpch_pat_enc_dx", 225 | ), 226 | DummyEvent( 227 | time=datetime.datetime(1999, 7, 10, 0, 0), code=42165, visit_id=10, clarity_table="shc_pat_enc_dx" 228 | ), 229 | DummyEvent(time=datetime.datetime(1999, 7, 11, 0, 0), code=12345, visit_id=10, clarity_table=None), 230 | DummyEvent(time=datetime.datetime(1999, 
7, 13, 0, 0), code=123, visit_id=11, clarity_table=None), 231 | ], 232 | ) 233 | 234 | expected = DummySubject( 235 | subject_id=123, 236 | events=[ 237 | DummyEvent( 238 | time=datetime.datetime(1999, 7, 2, 0, 0), 239 | code=1234, 240 | visit_id=10, 241 | clarity_table="lpch_pat_enc", 242 | end=datetime.datetime(1999, 7, 20), 243 | ), 244 | DummyEvent(time=datetime.datetime(1999, 7, 11, 0, 0), code=12345, visit_id=10, clarity_table=None), 245 | DummyEvent(time=datetime.datetime(1999, 7, 13, 0, 0), code=123, visit_id=11, clarity_table=None), 246 | DummyEvent( 247 | time=datetime.datetime(1999, 7, 20, 0, 0), 248 | code="SNOMED/184099003", 249 | visit_id=10, 250 | clarity_table="lpch_pat_enc_dx", 251 | ), 252 | DummyEvent( 253 | time=datetime.datetime(1999, 7, 20, 0, 0), code=42165, visit_id=10, clarity_table="shc_pat_enc_dx" 254 | ), 255 | ], 256 | ) 257 | 258 | assert move_billing_codes(subject) == expected 259 | -------------------------------------------------------------------------------- /tests/labelers/test_TimeHorizonEventLabeler.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: E402 2 | 3 | import datetime 4 | import os 5 | import pathlib 6 | import sys 7 | import warnings 8 | from typing import List 9 | 10 | import meds 11 | import meds_reader 12 | from femr_test_tools import EventsWithLabels, run_test_for_labeler 13 | 14 | from femr.labelers import TimeHorizon, TimeHorizonEventLabeler 15 | 16 | 17 | class DummyLabeler(TimeHorizonEventLabeler): 18 | """Dummy labeler that returns True if the event's `code` is in `self.outcome_codes`.""" 19 | 20 | def __init__(self, outcome_codes: List[int], time_horizon: TimeHorizon, allow_same_time: bool = True): 21 | self.outcome_codes: List[str] = [str(a) for a in outcome_codes] 22 | self.time_horizon: TimeHorizon = time_horizon 23 | self.allow_same_time = allow_same_time 24 | 25 | def allow_same_time_labels(self) -> bool: 26 | return self.allow_same_time 27 | 28 | def get_prediction_times(self, subject: meds_reader.Subject) -> List[datetime.datetime]: 29 | return sorted(list({e.time for e in subject.events})) 30 | 31 | def get_time_horizon(self) -> TimeHorizon: 32 | return self.time_horizon 33 | 34 | def get_outcome_times(self, subject: meds_reader.Subject) -> List[datetime.datetime]: 35 | times: List[datetime.datetime] = [] 36 | for e in subject.events: 37 | if e.code in self.outcome_codes: 38 | times.append(e.time) 39 | return times 40 | 41 | 42 | def test_no_outcomes(tmp_path: pathlib.Path): 43 | # No outcomes occur in this subject's timeline 44 | time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) 45 | labeler = DummyLabeler([100], time_horizon) 46 | events_with_labels: EventsWithLabels = [ 47 | # fmt: off 48 | (((2015, 1, 3), 2, None), "duplicate"), 49 | (((2015, 1, 3), 1, None), "duplicate"), 50 | (((2015, 1, 3), 3, None), False), 51 | (((2015, 10, 5), 1, None), False), 52 | (((2018, 1, 3), 2, None), False), 53 | (((2018, 3, 3), 1, None), False), 54 | (((2018, 5, 3), 2, None), False), 55 | (((2018, 5, 3, 11), 1, None), False), 56 | (((2018, 5, 4), 1, None), False), 57 | (((2018, 12, 4), 1, None), "out of range"), 58 | # fmt: on 59 | ] 60 | run_test_for_labeler(labeler, events_with_labels, help_text="test_no_outcomes") 61 | 62 | 63 | def test_horizon_0_180_days(tmp_path: pathlib.Path): 64 | # (0, 180) days 65 | time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) 66 | labeler = DummyLabeler([2], time_horizon) 67 | 
events_with_labels: EventsWithLabels = [ 68 | # fmt: off 69 | (((2015, 1, 3), 2, None), "duplicate"), 70 | (((2015, 1, 3), 1, None), "duplicate"), 71 | (((2015, 1, 3), 3, None), True), 72 | (((2015, 10, 5), 1, None), False), 73 | (((2018, 1, 3), 2, None), True), 74 | (((2018, 3, 3), 1, None), True), 75 | (((2018, 5, 3), 2, None), True), 76 | (((2018, 5, 3, 11), 1, None), False), 77 | (((2018, 5, 4), 1, None), False), 78 | (((2018, 12, 4), 1, None), "out of range"), 79 | # fmt: on 80 | ] 81 | run_test_for_labeler(labeler, events_with_labels, help_text="test_horizon_0_180_days") 82 | 83 | 84 | def test_horizon_0_180_days_no_same(tmp_path: pathlib.Path): 85 | # (0, 180) days 86 | time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) 87 | labeler = DummyLabeler([2], time_horizon, allow_same_time=False) 88 | events_with_labels: EventsWithLabels = [ 89 | # fmt: off 90 | (((2015, 1, 3), 2, None), "duplicate"), 91 | (((2015, 1, 3), 1, None), "duplicate"), 92 | (((2015, 1, 3), 3, None), "same"), 93 | (((2015, 10, 5), 1, None), False), 94 | (((2018, 1, 3), 2, None), "same"), 95 | (((2018, 3, 3), 1, None), True), 96 | (((2018, 5, 3), 2, None), "same"), 97 | (((2018, 5, 3, 11), 1, None), False), 98 | (((2018, 5, 4), 1, None), False), 99 | (((2018, 12, 4), 1, None), "out of range"), 100 | # fmt: on 101 | ] 102 | run_test_for_labeler(labeler, events_with_labels, help_text="test_horizon_0_180_days") 103 | 104 | 105 | def test_horizon_1_180_days(tmp_path: pathlib.Path): 106 | # (1, 180) days 107 | time_horizon = TimeHorizon(datetime.timedelta(days=1), datetime.timedelta(days=180)) 108 | labeler = DummyLabeler([2], time_horizon) 109 | events_with_labels: EventsWithLabels = [ 110 | # fmt: off 111 | (((2015, 1, 3), 2, None), "duplicate"), 112 | (((2015, 1, 3), 1, None), "duplicate"), 113 | (((2015, 1, 3), 3, None), False), 114 | (((2015, 10, 5), 1, None), False), 115 | (((2018, 1, 3), 2, None), True), 116 | (((2018, 3, 3), 1, None), True), 117 | (((2018, 5, 3), 2, None), False), 118 | (((2018, 5, 3, 11), 1, None), False), 119 | (((2018, 5, 4), 1, None), False), 120 | (((2018, 12, 4), 1, None), "out of range"), 121 | # fmt: on 122 | ] 123 | run_test_for_labeler(labeler, events_with_labels, help_text="test_horizon_1_180_days") 124 | 125 | 126 | def test_horizon_180_365_days(tmp_path: pathlib.Path): 127 | # (180, 365) days 128 | time_horizon = TimeHorizon(datetime.timedelta(days=180), datetime.timedelta(days=365)) 129 | labeler = DummyLabeler([2], time_horizon) 130 | events_with_labels: EventsWithLabels = [ 131 | # fmt: off 132 | (((2000, 1, 3), 2, None), True), 133 | (((2000, 10, 5), 2, None), False), 134 | (((2002, 1, 5), 2, None), True), 135 | (((2002, 3, 1), 1, None), True), 136 | (((2002, 4, 5), 3, None), True), 137 | (((2002, 4, 12), 1, None), True), 138 | (((2002, 12, 5), 2, None), False), 139 | (((2002, 12, 10), 1, None), False), 140 | (((2004, 1, 10), 2, None), False), 141 | (((2008, 1, 10), 2, None), "out of range"), 142 | # fmt: on 143 | ] 144 | run_test_for_labeler(labeler, events_with_labels, help_text="test_horizon_180_365_days") 145 | 146 | 147 | def test_horizon_0_0_days(tmp_path: pathlib.Path): 148 | # (0, 0) days 149 | time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=0)) 150 | labeler = DummyLabeler([2], time_horizon) 151 | events_with_labels: EventsWithLabels = [ 152 | # fmt: off 153 | (((2015, 1, 3), 2, None), "duplicate"), 154 | (((2015, 1, 3), 1, None), True), 155 | (((2015, 1, 4), 1, None), False), 156 | (((2015, 1, 5), 2, 
None), True), 157 | (((2015, 1, 5, 10), 1, None), False), 158 | (((2015, 1, 6), 2, None), True), 159 | # fmt: on 160 | ] 161 | run_test_for_labeler(labeler, events_with_labels, help_text="test_horizon_0_0_days") 162 | 163 | 164 | def test_horizon_10_10_days(tmp_path: pathlib.Path): 165 | # (10, 10) days 166 | time_horizon = TimeHorizon(datetime.timedelta(days=10), datetime.timedelta(days=10)) 167 | labeler = DummyLabeler([2], time_horizon) 168 | events_with_labels: EventsWithLabels = [ 169 | # fmt: off 170 | (((2015, 1, 3), 2, None), False), 171 | (((2015, 1, 13), 1, None), True), 172 | (((2015, 1, 23), 2, None), True), 173 | (((2015, 2, 2), 2, None), False), 174 | (((2015, 3, 10), 1, None), True), 175 | (((2015, 3, 20), 2, None), False), 176 | (((2015, 3, 29), 2, None), "out of range"), 177 | (((2015, 3, 30), 1, None), "out of range"), 178 | # fmt: on 179 | ] 180 | run_test_for_labeler(labeler, events_with_labels, help_text="test_horizon_10_10_days") 181 | 182 | 183 | def test_horizon_0_1000000_days(tmp_path: pathlib.Path): 184 | # (0, 1000000) days 185 | time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=1000000)) 186 | labeler = DummyLabeler([2], time_horizon) 187 | events_with_labels: EventsWithLabels = [ 188 | # fmt: off 189 | (((2000, 1, 3), 2, None), True), 190 | (((2001, 10, 5), 1, None), True), 191 | (((2020, 10, 5), 2, None), True), 192 | (((2021, 10, 5), 1, None), True), 193 | (((2050, 1, 10), 2, None), True), 194 | (((2051, 1, 10), 1, None), False), 195 | (((5000, 1, 10), 1, None), "out of range"), 196 | # fmt: on 197 | ] 198 | run_test_for_labeler(labeler, events_with_labels, help_text="test_horizon_0_1000000_days") 199 | 200 | 201 | def test_horizon_5_10_hours(tmp_path: pathlib.Path): 202 | # (5 hours, 10.5 hours) 203 | time_horizon = TimeHorizon(datetime.timedelta(hours=5), datetime.timedelta(hours=10, minutes=30)) 204 | labeler = DummyLabeler([2], time_horizon) 205 | events_with_labels: EventsWithLabels = [ 206 | # fmt: off 207 | (((2015, 1, 1, 0, 0), 1, None), True), 208 | (((2015, 1, 1, 10, 29), 2, None), False), 209 | (((2015, 1, 1, 10, 30), 1, None), False), 210 | (((2015, 1, 1, 10, 31), 1, None), False), 211 | # 212 | (((2016, 1, 1, 0, 0), 1, None), True), 213 | (((2016, 1, 1, 10, 29), 1, None), False), 214 | (((2016, 1, 1, 10, 30), 2, None), False), 215 | (((2016, 1, 1, 10, 31), 1, None), False), 216 | # 217 | (((2017, 1, 1, 0, 0), 1, None), False), 218 | (((2017, 1, 1, 10, 29), 1, None), False), 219 | (((2017, 1, 1, 10, 30), 1, None), False), 220 | (((2017, 1, 1, 10, 31), 2, None), False), 221 | # 222 | (((2018, 1, 1, 0, 0), 1, None), False), 223 | (((2018, 1, 1, 4, 59, 59), 2, None), False), 224 | (((2018, 1, 1, 5), 1, None), False), 225 | # 226 | (((2019, 1, 1, 0, 0), 1, None), True), 227 | (((2019, 1, 1, 4, 59, 59), 1, None), "out of range"), 228 | (((2019, 1, 1, 5), 2, None), "out of range"), 229 | # fmt: on 230 | ] 231 | run_test_for_labeler(labeler, events_with_labels, help_text="test_horizon_5_10_hours") 232 | 233 | 234 | def test_horizon_infinite(tmp_path: pathlib.Path): 235 | # Infinite horizon 236 | time_horizon = TimeHorizon( 237 | datetime.timedelta(days=10), 238 | None, 239 | ) 240 | labeler = DummyLabeler([2], time_horizon) 241 | events_with_labels: EventsWithLabels = [ 242 | # fmt: off 243 | (((1950, 1, 3), 1, None), True), 244 | (((2000, 1, 3), 1, None), True), 245 | (((2001, 10, 5), 1, None), True), 246 | (((2020, 10, 5), 1, None), True), 247 | (((2021, 10, 5), 1, None), True), 248 | (((2050, 1, 10), 2, None), True), 249 | 
(((2050, 1, 20), 2, None), False),
250 |         (((2051, 1, 10), 1, None), False),
251 |         (((5000, 1, 10), 1, None), False),
252 |         # fmt: on
253 |     ]
254 |     run_test_for_labeler(labeler, events_with_labels, help_text="test_horizon_infinite")
255 |
-------------------------------------------------------------------------------- /src/femr/models/tokenizer/flat_tokenizer.py: --------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import bisect
4 | import collections
5 | import datetime
6 | import functools
7 | import math
8 | import os
9 | from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple, Union
10 |
11 | import meds_reader
12 | import msgpack
13 | import numpy as np
14 | import transformers
15 |
16 | import femr.ontology
17 | import femr.stat_utils
18 | import femr.pat_utils
19 | import pyarrow as pa
20 |
21 |
22 | def train_tokenizer(
23 |     db: meds_reader.SubjectDatabase,
24 |     vocab_size: int,
25 |     num_numeric: int = 1000,
26 | ) -> FlatTokenizer:
27 |     """Train a FEMR tokenizer from the given dataset"""
28 |
29 |     statistics = functools.reduce(
30 |         agg_statistics,
31 |         db.map(
32 |             functools.partial(
33 |                 map_statistics,
34 |                 num_subjects=len(db),
35 |             )
36 |         ),
37 |     )
38 |
39 |     return FlatTokenizer(
40 |         convert_statistics_to_msgpack(statistics, vocab_size, num_numeric)
41 |     )
42 |
43 |
44 | def agg_statistics(stats1, stats2):
45 |     stats1["age_stats"].combine(stats2["age_stats"])
46 |
47 |     for n in ("code_counts", "text_counts"):
48 |         for k, v in stats2[n].items():
49 |             stats1[n][k] += v
50 |
51 |     if stats1.get("numeric_samples_by_lab"):
52 |         for k, v in stats2["numeric_samples_by_lab"].items():
53 |             stats1["numeric_samples_by_lab"][k].combine(v)
54 |
55 |     return stats1
56 |
57 |
58 | def normalize_unit(unit):
59 |     if unit:
60 |         return unit.lower().replace(" ", "")
61 |     else:
62 |         return None
63 |
64 | def is_close_float(t, f):
65 |     if f is None:
66 |         return False
67 |     try:
68 |         v = float(t)
69 |         return abs(f - v) < 0.01 * f
70 |     except (TypeError, ValueError):
71 |         return False
72 |
73 | def map_statistics(
74 |     subjects: Iterator[meds_reader.Subject],
75 |     *,
76 |     num_subjects: int,
77 | ) -> Mapping[str, Any]:
78 |     age_stats = femr.stat_utils.OnlineStatistics()
79 |     code_counts: Dict[str, float] = collections.defaultdict(float)
80 |
81 |     numeric_samples_by_lab = collections.defaultdict(functools.partial(femr.stat_utils.ReservoirSampler, 1_000))
82 |
83 |     text_counts: Dict[Any, float] = collections.defaultdict(float)
84 |
85 |     for subject in subjects:
86 |         total_events = len(subject.events)
87 |
88 |         if total_events == 0:
89 |             continue
90 |
91 |         weight = 1.0 / (num_subjects * total_events)
92 |         birth_date = femr.pat_utils.get_subject_birthdate(subject)
93 |         for event in subject.events:
94 |             if event.time is not None and event.time != birth_date:
95 |                 age_stats.add(weight, (event.time - birth_date).total_seconds())
96 |
97 |             assert numeric_samples_by_lab is not None
98 |             if event.numeric_value is not None:
99 |                 numeric_samples_by_lab[event.code].add(event.numeric_value, weight)
100 |             elif event.text_value is not None:
101 |                 text_counts[(event.code, event.text_value)] += weight
102 |             else:
103 |                 code_counts[event.code] += weight
104 |
105 |     return {
106 |         "age_stats": age_stats,
107 |         "code_counts": code_counts,
108 |         "text_counts": text_counts,
109 |         "numeric_samples_by_lab": numeric_samples_by_lab,
110 |     }
111 |
112 |
113 | def convert_statistics_to_msgpack(
114 |     statistics, vocab_size: int, num_numeric: int,
115 | ):
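    # The loops below score every candidate vocabulary entry with
    # weight * log(weight) + (1 - weight) * log(1 - weight), i.e. the negative binary
    # entropy of its empirical frequency; sorting ascending and truncating to vocab_size
    # keeps the highest-entropy (most informative) code, text, and numeric-range entries.
    # Numeric lab values are bucketed into at most 10 quantile bins per code; note that
    # the num_numeric argument is accepted but not referenced in this function body.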
116 |     vocab = []
117 |
118 |     for code, weight in statistics["code_counts"].items():
119 |         entry = {
120 |             "type": "code",
121 |             "code_string": code,
122 |             "weight": weight * math.log(weight) + (1 - weight) * math.log(1 - weight),
123 |         }
124 |         vocab.append(entry)
125 |
126 |     for (code, text), weight in statistics["text_counts"].items():
127 |         entry = {
128 |             "type": "text",
129 |             "code_string": code,
130 |             "text_string": text,
131 |             "weight": weight * math.log(weight) + (1 - weight) * math.log(1 - weight),
132 |         }
133 |         vocab.append(entry)
134 |
135 |     for code, reservoir in statistics["numeric_samples_by_lab"].items():
136 |         weight = reservoir.total_weight / 10
137 |         samples = reservoir.samples
138 |         samples.sort()
139 |
140 |         samples_per_bin = (len(samples) + 9) // 10
141 |
142 |         for bin_index in range(0, 10):
143 |             if bin_index == 0:
144 |                 start_val = float("-inf")
145 |             else:
146 |                 if bin_index * samples_per_bin >= len(samples):
147 |                     continue
148 |                 start_val = samples[bin_index * samples_per_bin]
149 |
150 |             if bin_index == 9 or (bin_index + 1) * samples_per_bin >= len(samples):
151 |                 end_val = float("inf")
152 |             else:
153 |                 end_val = samples[(bin_index + 1) * samples_per_bin]
154 |
155 |             if start_val == end_val:
156 |                 continue
157 |
158 |             entry = {
159 |                 "type": "numeric",
160 |                 "code_string": code,
161 |                 "val_start": start_val,
162 |                 "val_end": end_val,
163 |                 "weight": weight * math.log(weight) + (1 - weight) * math.log(1 - weight),
164 |             }
165 |             vocab.append(entry)
166 |
167 |
168 |     vocab.sort(key=lambda a: a["weight"])
169 |     vocab = vocab[:vocab_size]
170 |
171 |     result = {
172 |         "vocab": vocab,
173 |         "age_stats": {
174 |             "mean": statistics["age_stats"].mean(),
175 |             "std": statistics["age_stats"].standard_deviation(),
176 |         },
177 |     }
178 |
179 |     return result
180 |
181 |
182 | class FlatTokenizer(transformers.utils.PushToHubMixin):
183 |     def __init__(self, dictionary: Mapping[str, Any], ontology: Optional[femr.ontology.Ontology] = None):
184 |         self.dictionary = dictionary
185 |
186 |         self.is_hierarchical = dictionary.get("is_hierarchical", False)  # dictionaries built by convert_statistics_to_msgpack omit this flag, so default to a flat vocabulary
187 |
188 |         if self.is_hierarchical:
189 |             assert ontology is not None
190 |
191 |         self.ontology = ontology
192 |
193 |         self.dictionary = dictionary
194 |         vocab = dictionary["vocab"]
195 |
196 |         self.string_lookup = {}
197 |         self.code_lookup = {}
198 |
199 |         self.vocab_size = len(vocab)
200 |
201 |         self.numeric_lookup = collections.defaultdict(list)
202 |         for i, dict_entry in enumerate(vocab):
203 |             if dict_entry["type"] == "code":
204 |                 self.code_lookup[dict_entry["code_string"]] = i
205 |             elif dict_entry["type"] == "numeric":
206 |                 self.numeric_lookup[dict_entry["code_string"]].append(
207 |                     (dict_entry["val_start"], dict_entry["val_end"], i)
208 |                 )
209 |             elif dict_entry["type"] == "text":
210 |                 self.string_lookup[(dict_entry["code_string"], dict_entry["text_string"])] = i
211 |             else:
212 |                 pass
213 |
214 |     @classmethod
215 |     def from_pretrained(
216 |         self,
217 |         pretrained_model_name_or_path: Union[str, os.PathLike],
218 |         **kwargs,
219 |     ):
220 |         """
221 |         Load the FEMR tokenizer.
222 |
223 |         Parameters:
224 |             pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
225 |                 Can be either:
226 |                 - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
227 |                   Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
228 |                   user or organization name, like `dbmdz/bert-base-german-cased`.
229 | - A path to a *directory* containing tokenization data saved using 230 | [`save_pretrained`], e.g., `./my_data_directory/`. 231 | ontology: An ontology object for hierarchical tokenizers 232 | kwargs: Arguments for loading to pass to transformers.utils.hub.cached_file 233 | 234 | Returns: 235 | A FEMR Tokenizer 236 | """ 237 | 238 | dictionary_file = transformers.utils.hub.cached_file( 239 | pretrained_model_name_or_path, "dictionary.msgpack", **kwargs 240 | ) 241 | 242 | with open(dictionary_file, "rb") as f: 243 | dictionary = msgpack.load(f) 244 | 245 | return FlatTokenizer(dictionary) 246 | 247 | def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): 248 | """ 249 | Save the FEMR tokenizer. 250 | 251 | 252 | This method make sure the batch processor can then be re-loaded using the 253 | .from_pretrained class method. 254 | 255 | Args: 256 | save_directory (`str` or `os.PathLike`): The path to a directory where the tokenizer will be saved. 257 | push_to_hub (`bool`, *optional*, defaults to `False`): 258 | Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the 259 | repository you want to push to with `repo_id` (will default to the name of `save_directory` in your 260 | namespace). 261 | kwargs (`Dict[str, Any]`, *optional*): 262 | Additional key word arguments passed along to the [`PushToHubMixin.push_to_hub`] method. 263 | """ 264 | assert not os.path.isfile(save_directory), f"Provided path ({save_directory}) should be a directory, not a file" 265 | 266 | os.makedirs(save_directory, exist_ok=True) 267 | 268 | if push_to_hub: 269 | commit_message = kwargs.pop("commit_message", None) 270 | repo_id = kwargs.pop("repo_id", str(save_directory).split(os.path.sep)[-1]) 271 | repo_id = self._create_repo(repo_id, **kwargs) 272 | files_timestamps = self._get_files_timestamps(save_directory) 273 | 274 | with open(os.path.join(save_directory, "dictionary.msgpack"), "wb") as f: 275 | msgpack.dump(self.dictionary, f) 276 | 277 | if push_to_hub: 278 | self._upload_modified_files( 279 | save_directory, 280 | repo_id, 281 | files_timestamps, 282 | commit_message=commit_message, 283 | token=kwargs.get("token"), 284 | ) 285 | 286 | def start_subject(self): 287 | """Compute per-subject statistics that are required to generate features.""" 288 | 289 | # This is currently a null-op, but is required for cost featurization 290 | pass 291 | 292 | def get_feature_codes(self, event: meds_reader.Event) -> Tuple[List[int], Optional[List[float]]]: 293 | """Get codes for the provided measurement and time""" 294 | 295 | # Note that time is currently not used in this code, but it is required for cost featurization 296 | if event.numeric_value is not None: 297 | for start, end, i in self.numeric_lookup.get(event.code, []): 298 | if start <= event.numeric_value < end: 299 | return [i], None 300 | else: 301 | return [], None 302 | elif event.text_value is not None: 303 | value = self.string_lookup.get((event.code, event.text_value)) 304 | if value is not None: 305 | return [value], None 306 | else: 307 | return [], None 308 | else: 309 | value = self.code_lookup.get(event.code) 310 | if value is not None: 311 | return [value], None 312 | else: 313 | return [], None 314 | 315 | def normalize_age(self, age: datetime.timedelta) -> float: 316 | return (age.total_seconds() - self.dictionary["age_stats"]["mean"]) / (self.dictionary["age_stats"]["std"]) 317 | 
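A minimal end-to-end sketch of how the module above is meant to be used. The paths, the vocabulary size, and the assumption that SubjectDatabase acts as a mapping from subject id to Subject are illustrative, not fixed by the library; the database path is the synthetic MEDS extract shipped with the tutorials.

import datetime

import meds_reader

from femr.models.tokenizer.flat_tokenizer import FlatTokenizer, train_tokenizer

# Any MEDS extract works; this one is the synthetic database used by the tutorials.
database = meds_reader.SubjectDatabase("input/synthetic_meds")

# Build a flat vocabulary over raw codes, text values, and binned numeric lab values.
tokenizer = train_tokenizer(database, vocab_size=1024)

# Round-trip through disk, the same way a model checkpoint directory carries its dictionary.msgpack.
tokenizer.save_pretrained("trash/flat_tokenizer")
tokenizer = FlatTokenizer.from_pretrained("trash/flat_tokenizer")

# Tokenize one subject's events.
subject_id = next(iter(database))
for event in database[subject_id].events:
    codes, _ = tokenizer.get_feature_codes(event)
    print(event.code, "->", codes)

# Ages are normalized with the statistics gathered during training.
print(tokenizer.normalize_age(datetime.timedelta(days=40 * 365)))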
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2018 Ethan Steinberg and other contributors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /tutorials/4_Train MOTOR.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "1c81279e-a568-4e36-9906-06317accb622", 6 | "metadata": {}, 7 | "source": [ 8 | "# Train MOTOR\n", 9 | "\n", 10 | "This tutorial walks through the various steps to train a MOTOR model.\n", 11 | "\n", 12 | "Training MOTOR is a four step process:\n", 13 | "\n", 14 | "- Training a tokenizer\n", 15 | "- Prefitting MOTOR\n", 16 | "- Preparing batches\n", 17 | "- Training the model" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "id": "7dcdfd70-58a1-4460-80a8-db737a8c5cd6", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import shutil\n", 28 | "import os\n", 29 | "\n", 30 | "TARGET_DIR = 'trash/tutorial_4'\n", 31 | "\n", 32 | "if os.path.exists(TARGET_DIR):\n", 33 | " shutil.rmtree(TARGET_DIR)\n", 34 | "\n", 35 | "os.mkdir(TARGET_DIR)" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "id": "646f7590", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "import meds_reader\n", 46 | "import femr.splits\n", 47 | "\n", 48 | "# First, we want to split our dataset into train, valid, and test\n", 49 | "# We do this by calling our split functionality twice\n", 50 | "\n", 51 | "database = meds_reader.SubjectDatabase('input/synthetic_meds')\n", 52 | "\n", 53 | "main_split = femr.splits.generate_hash_split(list(database), 97, frac_test=0.15)\n", 54 | "\n", 55 | "os.mkdir(os.path.join(TARGET_DIR, 'motor_model'))\n", 56 | "# Note that we want to save this to the target directory since this is important information\n", 57 | "\n", 58 | "main_split.save_to_csv(os.path.join(TARGET_DIR, \"motor_model\", \"main_split.csv\"))\n", 59 | "\n", 60 | "train_split = femr.splits.generate_hash_split(main_split.train_subject_ids, 87, frac_test=0.15)\n", 61 | "\n", 62 | "main_database = database.filter(main_split.train_subject_ids)\n", 63 | "train_database = main_database.filter(train_split.train_subject_ids)\n", 64 | "val_database = main_database.filter(train_split.test_subject_ids)\n" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "id": "f60ab7df-e851-44a5-ab70-7bee292be00c", 71 
| "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stderr", 75 | "output_type": "stream", 76 | "text": [ 77 | "/home/ethanid/envs/motor_nlp/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 78 | " from .autonotebook import tqdm as notebook_tqdm\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "import femr.models.tokenizer\n", 84 | "import pickle\n", 85 | "\n", 86 | "# First, we need to train a tokenizer\n", 87 | "# Note, we need to use a hierarchical tokenizer for MOTOR\n", 88 | "\n", 89 | "with open('input/ontology.pkl', 'rb') as f:\n", 90 | " ontology = pickle.load(f)\n", 91 | "\n", 92 | "# NOTE: A vocab size of 128 is probably too low for a real model. 128 was chosen to make this tutorial quick to run\n", 93 | "# NOTE: Normally you would train the tokenizer on only the train database, but for such a tiny dataset that's not enough\n", 94 | "tokenizer = femr.models.tokenizer.HierarchicalTokenizer.train(\n", 95 | " database, vocab_size=1024 * 16, ontology=ontology, min_fraction=1e-9) # Normally min_fraction should be set higher, to 1e-4, but need a small min fraction to get enough codes\n", 96 | "\n", 97 | "# Save the tokenizer to the same directory as the model\n", 98 | "tokenizer.save_pretrained(os.path.join(TARGET_DIR, \"motor_model\"))" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 4, 104 | "id": "69b60daa", 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "\n", 109 | "import femr.models.tasks\n", 110 | "\n", 111 | "# Second, we need to prefit the MOTOR model. This is necessary because piecewise exponential models are unstable without an initial fit\n", 112 | "\n", 113 | "motor_task = femr.models.tasks.MOTORTask.fit_pretraining_task_info(\n", 114 | " train_database, tokenizer, num_tasks=2048, num_bins=4, final_layer_size=32, min_fraction=1e-9) # Normally min_fraction should be set higher, to 1e-4, but need a small min fraction to get enough codes\n", 115 | "\n", 116 | "# It's recommended to save this with pickle to avoid recomputing since it's an expensive operation" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 5, 122 | "id": "89611ba9-a242-4b87-9b8f-25670d838fc6", 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "name": "stdout", 127 | "output_type": "stream", 128 | "text": [ 129 | "Convert a single subject\n", 130 | "Convert batches\n", 131 | "Got batches 46\n" 132 | ] 133 | }, 134 | { 135 | "name": "stderr", 136 | "output_type": "stream", 137 | "text": [ 138 | "Generating train split: 46 examples [00:00, 1850.48 examples/s]\n" 139 | ] 140 | }, 141 | { 142 | "name": "stdout", 143 | "output_type": "stream", 144 | "text": [ 145 | "Convert batches to pytorch\n", 146 | "Done\n", 147 | "Got batches 9\n" 148 | ] 149 | }, 150 | { 151 | "name": "stderr", 152 | "output_type": "stream", 153 | "text": [ 154 | "Generating train split: 9 examples [00:00, 1676.98 examples/s]\n" 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "import femr.models.processor\n", 160 | "import femr.models.tasks\n", 161 | "\n", 162 | "# Third, we need to create batches. 
\n", 163 | "\n", 164 | "processor = femr.models.processor.FEMRBatchProcessor(tokenizer, motor_task)\n", 165 | "\n", 166 | "example_subject_id = list(train_database)[0]\n", 167 | "example_subject = train_database[example_subject_id]\n", 168 | "\n", 169 | "# We can do this one subject at a time\n", 170 | "print(\"Convert a single subject\")\n", 171 | "example_batch = processor.collate([processor.convert_subject(example_subject, tensor_type='pt')])\n", 172 | "\n", 173 | "print(\"Convert batches\")\n", 174 | "# But generally we want to convert entire datasets\n", 175 | "train_batches = processor.convert_dataset(train_database, tokens_per_batch=32, num_proc=4)\n", 176 | "\n", 177 | "print(\"Convert batches to pytorch\")\n", 178 | "# Convert our batches to pytorch tensors\n", 179 | "train_batches.set_format(\"pt\")\n", 180 | "print(\"Done\")\n", 181 | "\n", 182 | "val_batches = processor.convert_dataset(val_database, tokens_per_batch=32, num_proc=4)\n", 183 | "# Convert our batches to pytorch tensors\n", 184 | "val_batches.set_format(\"pt\")" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 6, 190 | "id": "f654a46c-5aa7-465c-b6c5-73d8ba26ed67", 191 | "metadata": {}, 192 | "outputs": [ 193 | { 194 | "name": "stderr", 195 | "output_type": "stream", 196 | "text": [ 197 | "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n" 198 | ] 199 | }, 200 | { 201 | "data": { 202 | "text/html": [ 203 | "\n", 204 | "
\n",
        "    [184/184 00:07, Epoch 4/4]\n",
        "\n",
        "    Step    Training Loss    Validation Loss\n",
        "      20         0.039100           0.031973\n",
        "      40         0.044500           0.031971\n",
        "      60         0.047500           0.031969\n",
        "      80         0.037600           0.031968\n",
        "     100         0.036400           0.031967\n",
        "     120         0.034300           0.031967\n",
        "     140         0.047500           0.031966\n",
        "     160         0.049600           0.031966\n",
        "     180         0.033700           0.031966\n",
        "
" 265 | ], 266 | "text/plain": [ 267 | "" 268 | ] 269 | }, 270 | "metadata": {}, 271 | "output_type": "display_data" 272 | } 273 | ], 274 | "source": [ 275 | "import transformers\n", 276 | "\n", 277 | "import femr.models.transformer\n", 278 | "\n", 279 | "# Finally, given the batches, we can train CLMBR.\n", 280 | "# We can use huggingface's trainer to do this.\n", 281 | "\n", 282 | "transformer_config = femr.models.config.FEMRTransformerConfig(\n", 283 | " vocab_size=tokenizer.vocab_size, \n", 284 | " is_hierarchical=True, \n", 285 | " use_normed_ages=True,\n", 286 | " use_bias=False,\n", 287 | " hidden_act='swiglu',\n", 288 | " n_layers=2,\n", 289 | " hidden_size=64, \n", 290 | " intermediate_size=64*2,\n", 291 | " n_heads=8,\n", 292 | ")\n", 293 | "\n", 294 | "config = femr.models.config.FEMRModelConfig.from_transformer_task_configs(transformer_config, motor_task.get_task_config())\n", 295 | "\n", 296 | "model = femr.models.transformer.FEMRModel(config)\n", 297 | "\n", 298 | "collator = processor.collate\n", 299 | "\n", 300 | "trainer_config = transformers.TrainingArguments(\n", 301 | " per_device_train_batch_size=1,\n", 302 | " per_device_eval_batch_size=1,\n", 303 | "\n", 304 | " output_dir='tmp_trainer',\n", 305 | " remove_unused_columns=False,\n", 306 | " num_train_epochs=4,\n", 307 | "\n", 308 | " eval_steps=20,\n", 309 | " eval_strategy=\"steps\",\n", 310 | "\n", 311 | " logging_steps=20,\n", 312 | " logging_strategy='steps',\n", 313 | "\n", 314 | " prediction_loss_only=True,\n", 315 | ")\n", 316 | "\n", 317 | "trainer = transformers.Trainer(\n", 318 | " model=model,\n", 319 | " data_collator=processor.collate,\n", 320 | " train_dataset=train_batches,\n", 321 | " eval_dataset=val_batches,\n", 322 | " args=trainer_config,\n", 323 | ")\n", 324 | "\n", 325 | "trainer.train()\n", 326 | "\n", 327 | "model.save_pretrained(os.path.join(TARGET_DIR, 'motor_model'))" 328 | ] 329 | } 330 | ], 331 | "metadata": { 332 | "kernelspec": { 333 | "display_name": "Python 3 (ipykernel)", 334 | "language": "python", 335 | "name": "python3" 336 | }, 337 | "language_info": { 338 | "codemirror_mode": { 339 | "name": "ipython", 340 | "version": 3 341 | }, 342 | "file_extension": ".py", 343 | "mimetype": "text/x-python", 344 | "name": "python", 345 | "nbconvert_exporter": "python", 346 | "pygments_lexer": "ipython3", 347 | "version": "3.13.3" 348 | } 349 | }, 350 | "nbformat": 4, 351 | "nbformat_minor": 5 352 | } 353 | -------------------------------------------------------------------------------- /src/femr/labelers/core.py: -------------------------------------------------------------------------------- 1 | """Core labeling functionality/schemas, shared across all labeling functions.""" 2 | 3 | from __future__ import annotations 4 | 5 | import datetime 6 | import functools 7 | import hashlib 8 | import itertools 9 | import struct 10 | import warnings 11 | from abc import ABC, abstractmethod 12 | from dataclasses import dataclass 13 | from typing import Iterator, List, NamedTuple, Optional, Tuple 14 | 15 | import meds_reader 16 | import pandas as pd 17 | 18 | 19 | # A more efficient copy of the MEDS label type definition 20 | class Label(NamedTuple): 21 | subject_id: int 22 | prediction_time: datetime.datetime 23 | 24 | boolean_value: bool 25 | 26 | 27 | @dataclass(frozen=True) 28 | class TimeHorizon: 29 | """An interval of time. 
Mandatory `start`, optional `end`.""" 30 | 31 | start: datetime.timedelta 32 | end: datetime.timedelta | None # If NONE, then infinite time horizon 33 | 34 | 35 | def _label_map_func(subjects: Iterator[meds_reader.Subject], *, labeler: Labeler) -> pd.DataFrame: 36 | data = itertools.chain.from_iterable(labeler.label(subject) for subject in subjects) 37 | final = pd.DataFrame.from_records(data, columns=Label._fields) 38 | final["prediction_time"] = final["prediction_time"].astype("datetime64[us]") 39 | return final 40 | 41 | 42 | class Labeler(ABC): 43 | """An interface for labeling functions. 44 | 45 | A labeling function applies a label to a specific datetime in a given subject's timeline. 46 | It can be thought of as generating the following list given a specific subject: 47 | [(subject ID, datetime_1, label_1), (subject ID, datetime_2, label_2), ... ] 48 | Usage: 49 | ``` 50 | labeling_function: Labeler = Labeler(...) 51 | subjects: Sequence[Subject] = ... 52 | labels: LabeledSubject = labeling_function.apply(subjects) 53 | ``` 54 | """ 55 | 56 | @abstractmethod 57 | def label(self, subject: meds_reader.Subject) -> List[Label]: 58 | """Apply every label that is applicable to the provided subject. 59 | 60 | This is only called once per subject. 61 | 62 | Args: 63 | subject (Subject): A subject object 64 | 65 | Returns: 66 | List[Label]: A list of :class:`Label` containing every label for the given subject 67 | """ 68 | pass 69 | 70 | def apply( 71 | self, 72 | db: meds_reader.SubjectDatabase, 73 | ) -> pd.DataFrame: 74 | """Apply the `label()` function one-by-one to each Subject in a sequence of Subjects. 75 | 76 | Args: 77 | dataset (datasets.Dataset): A HuggingFace Dataset with meds_reader.Subject objects to be labeled. 78 | num_proc (int, optional): Number of CPU threads to parallelize across. Defaults to 1. 79 | 80 | Returns: 81 | A list of labels 82 | """ 83 | 84 | # TODO: Cast the schema properly 85 | result = pd.concat(db.map(functools.partial(_label_map_func, labeler=self)), ignore_index=True) 86 | result.sort_values(by=["subject_id", "prediction_time"], inplace=True) 87 | 88 | return result 89 | 90 | 91 | ########################################################## 92 | # Specific Labeler Superclasses 93 | ########################################################## 94 | 95 | 96 | class TimeHorizonEventLabeler(Labeler): 97 | """Label events that occur within a particular time horizon. 98 | This support both "finite" and "infinite" time horizons. 99 | 100 | The time horizon can be "fixed" (i.e. has both a start and end date), or "infinite" (i.e. only a start date) 101 | 102 | A TimeHorizonEventLabeler enables you to label events that occur within a particular 103 | time horizon (i.e. `TimeHorizon`). It is a boolean event that is TRUE if the event of interest 104 | occurs within that time horizon, and FALSE if it doesn't occur by the end of the time horizon. 105 | 106 | No labels are generated if the subject record is "censored" before the end of the horizon. 107 | 108 | You are required to implement three methods: 109 | get_outcome_times() for defining the datetimes of the event of interset 110 | get_prediction_times() for defining the datetimes at which we make our predictions 111 | get_time_horizon() for defining the length of time (i.e. 
`TimeHorizon`) to use for the time horizon 112 | """ 113 | 114 | def __init__(self): 115 | pass 116 | 117 | @abstractmethod 118 | def get_outcome_times(self, subject: meds_reader.Subject) -> List[datetime.datetime]: 119 | """Return a sorted list containing the datetimes that the event of interest "occurs". 120 | 121 | IMPORTANT: Must be sorted ascending (i.e. start -> end of timeline) 122 | 123 | Args: 124 | subject (Subject): A subject object 125 | 126 | Returns: 127 | List[datetime.datetime]: A list of datetimes, one corresponding to an occurrence of the outcome 128 | """ 129 | pass 130 | 131 | @abstractmethod 132 | def get_time_horizon(self) -> TimeHorizon: 133 | """Return time horizon for making predictions with this labeling function. 134 | 135 | Return the (start offset, end offset) of the time horizon (from the prediction time) 136 | used for labeling whether an outcome occurred or not. These can be arbitrary timedeltas. 137 | 138 | If end offset is None, then the time horizon is infinite (i.e. only has a start offset). 139 | If end offset is not None, then the time horizon is finite (i.e. has both a start and end offset), 140 | and it must be true that end offset >= start offset. 141 | 142 | Example: 143 | X is the time that you're making a prediction (given by `get_prediction_times()`) 144 | (A,B) is your time horizon (given by `get_time_horizon()`) 145 | O is an outcome (given by `get_outcome_times()`) 146 | 147 | Then given a subject timeline: 148 | X-----(X+A)------(X+B)------ 149 | 150 | 151 | This has a label of TRUE: 152 | X-----(X+A)--O---(X+B)------ 153 | 154 | This has a label of TRUE: 155 | X-----(X+A)--O---(X+B)----O- 156 | 157 | This has a label of FALSE: 158 | X---O-(X+A)------(X+B)------ 159 | 160 | This has a label of FALSE: 161 | X-----(X+A)------(X+B)--O--- 162 | """ 163 | pass 164 | 165 | @abstractmethod 166 | def get_prediction_times(self, subject: meds_reader.Subject) -> List[datetime.datetime]: 167 | """Return a sorted list containing the datetimes at which we'll make a prediction. 168 | 169 | IMPORTANT: Must be sorted ascending (i.e. start -> end of timeline) 170 | """ 171 | pass 172 | 173 | def get_subject_start_end_times(self, subject: meds_reader.Subject) -> Tuple[datetime.datetime, datetime.datetime]: 174 | """Return the datetimes that we consider the (start, end) of this subject.""" 175 | return (subject.events[0].time, subject.events[-1].time) 176 | 177 | def allow_same_time_labels(self) -> bool: 178 | """Whether or not to allow labels with events at the same time as prediction""" 179 | return True 180 | 181 | def label(self, subject: meds_reader.Subject) -> List[Label]: 182 | """Return a list of Labels for an individual subject. 183 | 184 | Assumes that events in `subject['events']` are already sorted in chronologically 185 | ascending order (i.e. start -> end). 186 | 187 | Args: 188 | subject (Subject): A subject object 189 | 190 | Returns: 191 | List[Label]: A list containing a label for each datetime returned by `get_prediction_times()` 192 | """ 193 | if len(subject.events) == 0: 194 | return [] 195 | 196 | __, end_time = self.get_subject_start_end_times(subject) 197 | outcome_times: List[datetime.datetime] = self.get_outcome_times(subject) 198 | prediction_times: List[datetime.datetime] = self.get_prediction_times(subject) 199 | time_horizon: TimeHorizon = self.get_time_horizon() 200 | 201 | # Get (start, end) of time horizon. 
If end is None, then it's infinite (set timedelta to max) 202 | time_horizon_start: datetime.timedelta = time_horizon.start 203 | time_horizon_end: Optional[datetime.timedelta] = time_horizon.end # `None` if infinite time horizon 204 | 205 | # For each prediction time, check if there is an outcome which occurs within the (start, end) 206 | # of the time horizon 207 | results: List[Label] = [] 208 | curr_outcome_idx: int = 0 209 | last_time = None 210 | 211 | for time in prediction_times: 212 | if last_time is not None: 213 | assert time > last_time, f"Must be ascending prediction times, instead got {last_time} <= {time}" 214 | 215 | last_time = time 216 | while curr_outcome_idx < len(outcome_times) and outcome_times[curr_outcome_idx] < time + time_horizon_start: 217 | # `curr_outcome_idx` is the idx in `outcome_times` that corresponds to the first 218 | # outcome EQUAL or AFTER the time horizon for this prediction time starts (if one exists) 219 | curr_outcome_idx += 1 220 | 221 | if curr_outcome_idx < len(outcome_times) and outcome_times[curr_outcome_idx] == time: 222 | if not self.allow_same_time_labels(): 223 | continue 224 | warnings.warn( 225 | "You are making predictions at the same time as the target outcome." 226 | "This frequently leads to label leakage." 227 | ) 228 | 229 | # TRUE if an event occurs within the time horizon 230 | is_outcome_occurs_in_time_horizon: bool = ( 231 | ( 232 | # ensure there is an outcome 233 | # (needed in case there are 0 outcomes) 234 | curr_outcome_idx 235 | < len(outcome_times) 236 | ) 237 | and ( 238 | # outcome occurs after time horizon starts 239 | time + time_horizon_start 240 | <= outcome_times[curr_outcome_idx] 241 | ) 242 | and ( 243 | # outcome occurs before time horizon ends (if there is an end) 244 | (time_horizon_end is None) 245 | or outcome_times[curr_outcome_idx] <= time + time_horizon_end 246 | ) 247 | ) 248 | # TRUE if subject is censored (i.e. 
timeline ends BEFORE this time horizon ends, 249 | # so we don't know if the outcome happened after the subject timeline ends) 250 | # If infinite time horizon labeler, then assume no censoring 251 | is_censored: bool = end_time < time + time_horizon_end if (time_horizon_end is not None) else False 252 | 253 | if is_outcome_occurs_in_time_horizon: 254 | results.append(Label(subject_id=subject.subject_id, prediction_time=time, boolean_value=True)) 255 | elif not is_censored: 256 | # Not censored + no outcome => FALSE 257 | results.append(Label(subject_id=subject.subject_id, prediction_time=time, boolean_value=False)) 258 | elif is_censored: 259 | # Censored => None 260 | pass 261 | 262 | return results 263 | 264 | 265 | class NLabelsPerSubjectLabeler(Labeler): 266 | """Restricts `self.labeler` to returning a max of `self.k` labels per subject.""" 267 | 268 | def __init__(self, labeler: Labeler, num_labels: int = 1, seed: int = 1): 269 | self.labeler: Labeler = labeler 270 | self.num_labels: int = num_labels # number of labels per subject 271 | self.seed: int = seed 272 | 273 | def label(self, subject: meds_reader.Subject) -> List[Label]: 274 | labels: List[Label] = self.labeler.label(subject) 275 | if len(labels) <= self.num_labels: 276 | return labels 277 | elif self.num_labels == -1: 278 | return labels 279 | hash_to_label_list: List[Tuple[int, int, Label]] = [ 280 | (i, compute_random_num(self.seed, subject.subject_id, i), labels[i]) for i in range(len(labels)) 281 | ] 282 | hash_to_label_list.sort(key=lambda a: a[1]) 283 | n_hash_to_label_list: List[Tuple[int, int, Label]] = hash_to_label_list[: self.num_labels] 284 | n_hash_to_label_list.sort(key=lambda a: a[0]) 285 | n_labels: List[Label] = [hash_to_label[2] for hash_to_label in n_hash_to_label_list] 286 | return n_labels 287 | 288 | 289 | def compute_random_num(seed: int, num_1: int, num_2: int, modulus: int = 100): 290 | network_num_1 = struct.pack("!q", num_1) 291 | network_num_2 = struct.pack("!q", num_2) 292 | network_seed = struct.pack("!q", seed) 293 | 294 | to_hash = network_seed + network_num_1 + network_num_2 295 | 296 | hash_object = hashlib.sha256() 297 | hash_object.update(to_hash) 298 | hash_value = hash_object.digest() 299 | 300 | result = 0 301 | for i in range(len(hash_value)): 302 | result = (result * 256 + hash_value[i]) % modulus 303 | 304 | return result 305 | --------------------------------------------------------------------------------
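To make the `TimeHorizonEventLabeler` contract above concrete, here is a minimal, hypothetical subclass; it is a sketch rather than library code, and the outcome code, prediction code, database path, and one-year horizon are illustrative assumptions.

```python
import datetime
from typing import List

import meds_reader

from femr.labelers.core import TimeHorizon, TimeHorizonEventLabeler


class OutcomeWithinOneYearLabeler(TimeHorizonEventLabeler):
    """Sketch: label True when `outcome_code` occurs 1-365 days after any event
    carrying `prediction_code`; both codes are supplied by the caller."""

    def __init__(self, outcome_code: str, prediction_code: str):
        super().__init__()
        self.outcome_code = outcome_code
        self.prediction_code = prediction_code

    def get_time_horizon(self) -> TimeHorizon:
        # Look from one day after each prediction time up to one year out.
        return TimeHorizon(start=datetime.timedelta(days=1), end=datetime.timedelta(days=365))

    def get_outcome_times(self, subject: meds_reader.Subject) -> List[datetime.datetime]:
        return sorted(event.time for event in subject.events if event.code == self.outcome_code)

    def get_prediction_times(self, subject: meds_reader.Subject) -> List[datetime.datetime]:
        # Deduplicate and sort, because label() asserts strictly ascending prediction times.
        return sorted({event.time for event in subject.events if event.code == self.prediction_code})


# Usage sketch (path and codes are placeholders):
# database = meds_reader.SubjectDatabase("tutorials/input/synthetic_meds")
# labeler = OutcomeWithinOneYearLabeler(outcome_code="SNOMED/22298006", prediction_code="Visit/IP")
# labels = labeler.apply(database)  # DataFrame of subject_id, prediction_time, boolean_value
```

Because the horizon has a finite end, a subject whose record stops before `prediction_time + 365 days` with no observed outcome is treated as censored and receives no label, exactly as in `label()` above.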