├── ct_dicom ├── gcp │ ├── optional-requirements.txt │ ├── requirements.txt │ ├── README.md │ ├── auth_test.py │ ├── auth.py │ └── dicomweb_beam.py ├── requirements.txt ├── testdata │ ├── img_8c1.npy │ ├── img_8c1.png │ ├── img_16c1.npy │ └── img_16c1.png ├── image_utils.py ├── README.md ├── dicom_utils_test.py ├── example_builder_beam.py ├── example_builder.py ├── pipeline.py ├── image_utils_test.py └── dicom_utils.py ├── fitbit_pregnancy ├── requirements.txt ├── data │ ├── figure_2_a_data.csv │ ├── figure_2_c_data.csv │ ├── figure_2_b_data.csv │ ├── figure_5_b_data.csv │ ├── figure_5_c_data.csv │ ├── figure_5_d_data.csv │ ├── figure_5_a_data.csv │ ├── figure_4_b_data.csv │ ├── figure_4_a_data.csv │ ├── figure_4_c_data.csv │ ├── figure_4_d_data.csv │ ├── figure_4_e_data.csv │ ├── figure_4_f_data.csv │ └── figure_3_data.csv ├── README.md └── figure_generate.py ├── colorectal-survival ├── CONTRIBUTING.md ├── requirements.txt ├── README.md ├── run.sh ├── network_test.py ├── loss_test.py ├── train.py ├── analysis_test.py ├── analysis.py ├── loss.py └── network.py ├── breast_survival_prediction ├── requirements.txt ├── CONTRIBUTING.md ├── README.md ├── run.sh ├── mitotic_features_util_test.py ├── stage2_features.py ├── mitotic_features_util.py ├── stage2_features_test.py └── example.ipynb ├── health_acoustic_representations ├── requirements.txt ├── run.sh ├── api_utils_test.py ├── eval_utils_test.py ├── eval_utils.py ├── hear_demo.ipynb └── api_utils.py ├── analysis ├── requirements.txt ├── README.md └── run.sh ├── colorectal_lymph_node_metastasis_prediction ├── CONTRIBUTING.md ├── requirements.txt ├── README.md └── data_utils.py ├── README ├── fetal_ultrasound_blind_sweeps ├── README.md ├── run.sh ├── networks_test.py └── networks.py ├── LICENSE └── data_splits └── README.md /ct_dicom/gcp/optional-requirements.txt: -------------------------------------------------------------------------------- 1 | pydicom 2 | typing-extensions 3 | 
-------------------------------------------------------------------------------- /fitbit_pregnancy/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.22.4 2 | pandas==1.4.2 3 | matplotlib==3.5.2 -------------------------------------------------------------------------------- /colorectal-survival/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | We are not accepting contributions for this project. 2 | -------------------------------------------------------------------------------- /breast_survival_prediction/requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python>=4.6.0.66 2 | numpy>=1.19.5 3 | Pillow>=9.3.0 4 | -------------------------------------------------------------------------------- /ct_dicom/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | file-io 3 | numpy 4 | pydicom>=2.3.1 5 | pypng 6 | tensorflow-cpu 7 | -------------------------------------------------------------------------------- /health_acoustic_representations/requirements.txt: -------------------------------------------------------------------------------- 1 | google-auth 2 | google-cloud-aiplatform 3 | protobuf 4 | -------------------------------------------------------------------------------- /analysis/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py>=1.0.0 2 | numpy>=1.19.5 3 | scikit-learn>=0.24.1 4 | scipy>=1.2.1 5 | 6 | -------------------------------------------------------------------------------- /breast_survival_prediction/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | We are currently not accepting contributions for this project. 
2 | -------------------------------------------------------------------------------- /colorectal_lymph_node_metastasis_prediction/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | We are not accepting contributions for this project. 2 | -------------------------------------------------------------------------------- /ct_dicom/gcp/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | apache-beam 3 | dicomweb-client[gcp] 4 | google-auth 5 | requests-toolbelt 6 | -------------------------------------------------------------------------------- /ct_dicom/testdata/img_8c1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Google-Health/google-health/master/ct_dicom/testdata/img_8c1.npy -------------------------------------------------------------------------------- /ct_dicom/testdata/img_8c1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Google-Health/google-health/master/ct_dicom/testdata/img_8c1.png -------------------------------------------------------------------------------- /ct_dicom/testdata/img_16c1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Google-Health/google-health/master/ct_dicom/testdata/img_16c1.npy -------------------------------------------------------------------------------- /ct_dicom/testdata/img_16c1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Google-Health/google-health/master/ct_dicom/testdata/img_16c1.png -------------------------------------------------------------------------------- /colorectal_lymph_node_metastasis_prediction/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.19.5 2 
| pandas>=1.0.5 3 | scikit-learn>=0.23.2 4 | scipy>=1.8.0 5 | statsmodels>=0.12.2 -------------------------------------------------------------------------------- /colorectal-survival/requirements.txt: -------------------------------------------------------------------------------- 1 | lifelines>=0.25.4 2 | matplotlib>=3.0.3 3 | numpy>=1.19.5 4 | pandas>=1.0.5 5 | patsy>=0.4.1 6 | scikit-learn>=0.23.2 7 | tensorflow>=2.3.0 8 | -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | NOTE: the content of this research code repository (i) is not intended to be a 2 | medical device; and (ii) is not intended for clinical use of any kind, including 3 | but not limited to diagnosis or prognosis. -------------------------------------------------------------------------------- /ct_dicom/gcp/README.md: -------------------------------------------------------------------------------- 1 | Google Cloud Healthcare API and other Google Cloud Platform related utilities 2 | for healthcare applications. 3 | 4 | The module may be downloaded and used independently of the parent `ct_dicom` 5 | module. 6 | -------------------------------------------------------------------------------- /colorectal-survival/README.md: -------------------------------------------------------------------------------- 1 | # Colorectal cancer survival prediction network and analysis code 2 | 3 | This repo contains the network architecture and statistical analysis code 4 | described in "Interpretable survival prediction for colorectal cancer using deep 5 | learning" (https://doi.org/10.1038/s41746-021-00427-2). 
6 | -------------------------------------------------------------------------------- /fitbit_pregnancy/data/figure_2_a_data.csv: -------------------------------------------------------------------------------- 1 | bin,count 2 | 23.1,1.0 3 | 24.1,2.0 4 | 25.1,2.0 5 | 26.1,2.0 6 | 27.1,2.0 7 | 28.1,2.0 8 | 29.1,5.0 9 | 30.1,3.0 10 | 31.1,8.0 11 | 32.1,9.0 12 | 33.1,24.0 13 | 34.1,40.0 14 | 35.1,74.0 15 | 36.1,248.0 16 | 37.1,395.0 17 | 38.1,778.0 18 | 39.1,629.0 19 | 40.1,269.0 20 | 41.1,47.0 21 | 42.1,0.0 22 | 43.1,0.0 23 | 44.1, 24 | -------------------------------------------------------------------------------- /fetal_ultrasound_blind_sweeps/README.md: -------------------------------------------------------------------------------- 1 | # Blind sweep fetal ultrasound model networks and training loss functions. 2 | 3 | This repository contains the network graph definitions and training loss 4 | functions for the models described in "A mobile-optimized artificial 5 | intelligence system for gestational age and fetal malpresentation assessment" 6 | (https://doi.org/10.1038/s43856-022-00194-5). -------------------------------------------------------------------------------- /fetal_ultrasound_blind_sweeps/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Create a virtual environment for installing dependencies. 3 | python3 -m virtualenv . 4 | source ./bin/activate 5 | # Install required dependencies. 6 | pip3 install tensorflow==1.15.5 7 | pip3 install tf_slim==1.1.0 8 | git clone https://github.com/tensorflow/models 9 | mv models/research/slim/nets nets 10 | mv models/research/lstm_object_detection lstm_object_detection 11 | # Run unit tests. 
12 | python3 -m unittest networks_test.py -------------------------------------------------------------------------------- /fitbit_pregnancy/README.md: -------------------------------------------------------------------------------- 1 | Longitudinal Analysis of Real-world Wearable Device Data Before, During, and After Pregnancy Aggregated Data and Figure Generating Code 2 | 3 | This repo contains the aggregated, anonymized dataset described in the “Insights into Maternal Sleep: A Large-scale Longitudinal Analysis of Real-world Wearable Device Data Before, During, and After Pregnancy” paper, and the code needed to generate the figures in the paper from the aggregated data. 4 | 5 | 6 | NOTE: the content of this research code repository (i) is not intended to be a medical device; and (ii) is not intended for clinical use of any kind, including but not limited to diagnosis or prognosis. 7 | -------------------------------------------------------------------------------- /colorectal_lymph_node_metastasis_prediction/README.md: -------------------------------------------------------------------------------- 1 | # Lymph node metastasis prediction: machine-learned feature generation/selection and model evaluation 2 | 3 | This repo contains the code needed to generate and select cluster-based 4 | machine-learned features while controlling for baseline features as described 5 | in "Predicting lymph node metastasis from primary tumor histology and clinicopathologic factors in colorectal cancer using deep learning." For an example of how this 6 | code may be used, see demo.ipynb. 7 | 8 | NOTE: the content of this research code repository (i) is not intended to be a 9 | medical device; and (ii) is not intended for clinical use of any kind, including 10 | but not limited to diagnosis or prognosis. 
-------------------------------------------------------------------------------- /breast_survival_prediction/README.md: -------------------------------------------------------------------------------- 1 | # Breast Cancer Survival Prediction: Stage-2 Featurization Code 2 | 3 | This repo contain code to reproduce the stage2 featurization of all three 4 | components of the automatic Nottingham grading system. This is described in the 5 | paper "Deep learning models for histologic grading of breast cancer and 6 | association with disease prognosis" 7 | [[Link](https://doi.org/10.1038/s41523-022-00478-y)]. 8 | For an example of how the code should be used, please refer to example.ipynb. 9 | 10 | NOTE: the content of this research code repository (i) is not intended to be a 11 | medical device; and (ii) is not intended for clinical use of any kind, including 12 | but not limited to diagnosis or prognosis. 13 | -------------------------------------------------------------------------------- /fitbit_pregnancy/data/figure_2_c_data.csv: -------------------------------------------------------------------------------- 1 | week,mean 2 | -11.0,91.85258227536055 3 | -10.0,91.7847898434611 4 | -9.0,92.17305558979415 5 | -8.0,91.91421175890545 6 | -7.0,91.96967829409589 7 | -6.0,92.03130777764082 8 | -5.0,92.36410698878343 9 | -4.0,92.733883890053 10 | -3.0,92.38875878220139 11 | -2.0,92.34561814371995 12 | -1.0,91.18082090472082 13 | 0.0,91.5259460125724 14 | 1.0,66.42220304860513 15 | 2.0,86.81745346973992 16 | 3.0,89.63392086774313 17 | 4.0,90.08381609762111 18 | 5.0,90.34882287686429 19 | 6.0,91.22396154320226 20 | 7.0,91.34105756193763 21 | 8.0,91.50745716750895 22 | 9.0,91.26093923332922 23 | 10.0,91.55676075434488 24 | 11.0,92.08061136447677 25 | 12.0,92.83865401207937 26 | 13.0,92.07444841612227 27 | 14.0,92.14224084802169 28 | 15.0,92.25933686675705 29 | 16.0,91.79711574017009 30 | 17.0,92.09293726118575 31 | 18.0,91.06988783433995 32 | 19.0,90.91581412547764 33 | 
20.0,90.40428941205472 34 | -------------------------------------------------------------------------------- /fitbit_pregnancy/data/figure_2_b_data.csv: -------------------------------------------------------------------------------- 1 | week,data 2 | -11.0,89.87626546681663 3 | -10.0,89.85376827896512 4 | -9.0,90.55118110236221 5 | -8.0,90.28683914510685 6 | -7.0,90.19122609673789 7 | -6.0,90.31496062992126 8 | -5.0,90.67491563554556 9 | -4.0,90.43307086614173 10 | -3.0,90.5286839145107 11 | -2.0,90.17435320584927 12 | -1.0,90.03374578177727 13 | 0.0,90.25871766029246 14 | 1.0,89.83689538807648 15 | 2.0,89.73565804274466 16 | 3.0,89.67379077615298 17 | 4.0,90.11248593925758 18 | 5.0,90.5286839145107 19 | 6.0,89.98312710911135 20 | 7.0,89.49381327334085 21 | 8.0,88.97637795275591 22 | 9.0,89.20697412823397 23 | 10.0,88.99887514060742 24 | 11.0,88.63329583802025 25 | 12.0,88.67829021372329 26 | 13.0,89.19572553430821 27 | 14.0,89.15635545556805 28 | 15.0,89.14510686164229 29 | 16.0,89.48256467941506 30 | 17.0,89.3082114735658 31 | 18.0,89.38695163104613 32 | 19.0,89.60067491563554 33 | 20.0,89.92688413948258 34 | 21.0,90.35433070866142 35 | 22.0,89.4431946006749 36 | 23.0,89.28008998875141 37 | 24.0,89.40944881889764 38 | 25.0,89.80314960629921 39 | 26.0,89.89313835770528 40 | 27.0,89.85939257592801 41 | 28.0,90.0281214848144 42 | 29.0,89.4431946006749 43 | 30.0,89.05511811023622 44 | 31.0,88.57705286839145 45 | 32.0,89.11698537682788 46 | 33.0,88.49831271091112 47 | 34.0,87.12035995500561 48 | 35.0,86.45669291338582 49 | 36.0,84.81439820022497 50 | 37.0,81.20359955005624 51 | 38.0,72.64341957255343 52 | -------------------------------------------------------------------------------- /fitbit_pregnancy/data/figure_5_b_data.csv: -------------------------------------------------------------------------------- 1 | ,week,percentile_10,percentile_25,percentile_50,percentile_75,percentile_90 2 | 0,-11.0,384.0,436.5,486.5,535.5,581.5 3 | 
1,-10.0,388.5,439.5,488.0,536.5,583.0 4 | 2,-9.0,387.5,438.625,487.5,536.0,583.0 5 | 3,-8.0,387.0,438.5,488.0,536.5,582.0 6 | 4,-7.0,385.5,438.0,487.0,535.5,581.5 7 | 5,-6.0,384.5,438.0,487.0,535.5,582.0 8 | 6,-5.0,384.5,437.5,487.5,536.0,581.5 9 | 7,-4.0,388.0,438.5,486.5,535.0,581.0 10 | 8,-3.0,386.0,437.0,486.0,535.0,581.5 11 | 9,-2.0,385.0,438.0,486.5,534.5,583.5 12 | 10,-1.0,384.5,437.0,486.0,533.0,579.5 13 | 11,0.0,385.0,436.5,485.5,533.0,578.5 14 | 12,1.0,126.5,212.375,325.0,434.0,517.5 15 | 13,2.0,244.5,339.5,434.5,512.5,574.0 16 | 14,3.0,282.0,370.5,457.0,528.5,589.5 17 | 15,4.0,285.5,376.0,458.5,531.5,589.0 18 | 16,5.0,289.5,376.0,457.0,527.0,583.0 19 | 17,6.0,292.0,379.0,459.0,527.0,584.5 20 | 18,7.0,300.5,388.0,463.0,530.0,584.5 21 | 19,8.0,310.5,394.5,469.0,533.0,589.5 22 | 20,9.0,315.0,398.5,471.5,533.5,586.5 23 | 21,10.0,325.0,404.0,472.5,534.0,586.5 24 | 22,11.0,334.0,412.5,478.0,536.5,587.0 25 | 23,12.0,338.5,413.5,478.5,537.0,588.5 26 | 24,13.0,345.0,414.5,477.5,534.0,584.5 27 | 25,14.0,344.5,414.625,478.0,532.5,584.5 28 | 26,15.0,347.0,415.5,477.5,534.0,585.5 29 | 27,16.0,350.5,418.0,479.0,534.5,584.0 30 | 28,17.0,350.1,416.0,478.5,534.5,584.0 31 | 29,18.0,346.5,417.0,478.5,533.0,582.5 32 | 30,19.0,347.5,416.0,478.5,534.0,584.0 33 | 31,20.0,347.5,416.5,476.0,532.0,582.0 34 | -------------------------------------------------------------------------------- /ct_dicom/image_utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for image encoding.""" 2 | 3 | import io 4 | 5 | import numpy as np 6 | import png 7 | 8 | _NUM_BITS_PER_BYTE = 8 9 | 10 | 11 | def encode_png(array: np.ndarray) -> bytes: 12 | """Converts an unsigned integer 2-D NumPy array to a PNG-encoded string. 13 | 14 | Unsigned 8-bit and 16-bit images are supported. 15 | 16 | Args: 17 | array: Array to be encoded. 18 | 19 | Returns: 20 | PNG-encoded string. 
21 | 22 | Raises: 23 | ValueError: If any of the following occurs: 24 | - `array` is empty. 25 | - `array` is not 2-D. 26 | - `array` data type is unsupported. 27 | """ 28 | supported_types = frozenset([np.uint8, np.uint16]) 29 | # Sanity checks. 30 | if not array.size: 31 | raise ValueError(f'Received an empty image with shape {array.shape}.') 32 | if array.ndim != 2: 33 | raise ValueError(f'Array must be 2-D. Actual dimensions: {array.ndim}') 34 | if array.dtype.type not in supported_types: 35 | raise ValueError( 36 | 'Pixels must be either `uint8` or `uint16`. ' 37 | f'Actual type: {array.dtype.name!r}' 38 | ) 39 | 40 | # Actual conversion. 41 | writer = png.Writer( 42 | width=array.shape[1], 43 | height=array.shape[0], 44 | greyscale=True, 45 | bitdepth=_NUM_BITS_PER_BYTE * array.dtype.itemsize, 46 | ) 47 | output_data = io.BytesIO() 48 | writer.write(output_data, array.tolist()) 49 | return output_data.getvalue() 50 | -------------------------------------------------------------------------------- /ct_dicom/README.md: -------------------------------------------------------------------------------- 1 | # Utilities for sorting and annotating Computed Tomography DICOMs. 2 | 3 | 4 | This repo contains utility code for reading DICOMs from a Google Cloud DICOM 5 | store, sorting axial CTs to prepare Tensorflow examples, and for annotating 6 | DICOMs from model results, as described in "Assistive AI in Lung Cancer Screening: A Retrospective Multinational Study in the United States and Japan" (https://doi.org/10.1148/ryai.230079). 7 | 8 | 9 | 10 | If you use this software in your own research, please cite our paper: 11 | 12 | ``` 13 | @article{doi:10.1148/ryai.230079, 14 | author = {Kiraly, Atilla P. and Cunningham, Corbin A. and Najafi, Ryan and Nabulsi, Zaid and Yang, Jie and Lau, Charles and Ledsam, Joseph R. and Ye, Wenxing and Ardila, Diego and McKinney, ScottM. 
and Pilgrim, Rory and Liu, Yun and Saito, Hiroaki and Shimamura, Yasuteru and Etemadi, Mozziyar and Melnick, David and Jansen, Sunny and Corrado, Greg S. and Peng, Lily and Tse, Daniel and Shetty, Shravya and Prabhakara, Shruthi and Naidich, David P. and Beladia, Neeral and Eswaran, Krish}, 15 | title = {Assistive AI in Lung Cancer Screening: A Retrospective Multinational Study in the United States and Japan}, 16 | journal = {Radiology: Artificial Intelligence}, 17 | volume = {0}, 18 | number = {ja}, 19 | pages = {e230079}, 20 | year = {0}, 21 | doi = {10.1148/ryai.230079}, 22 | note ={PMID: 38477661}, 23 | URL = {https://doi.org/10.1148/ryai.230079 24 | }, 25 | eprint = {https://doi.org/10.1148/ryai.230079 26 | } 27 | } 28 | ``` -------------------------------------------------------------------------------- /fitbit_pregnancy/data/figure_5_c_data.csv: -------------------------------------------------------------------------------- 1 | week,mean,std 2 | -11.0,-0.23106288232157213,6.1183850544665965 3 | -10.0,-0.17952147648406278,6.063223474681437 4 | -9.0,-0.5367778605397747,4.569402421012669 5 | -8.0,-0.2735945993565136,4.898648813182015 6 | -7.0,-0.11018478694670114,4.8646046027069305 7 | -6.0,-0.07909820686012109,4.636457979139631 8 | -5.0,-0.09006213182404603,5.064850179938697 9 | -4.0,-0.1793138243454547,5.108455810236698 10 | -3.0,0.1556535968916827,6.408971132821927 11 | -2.0,0.14622069645878225,5.912917338243098 12 | -1.0,0.4258465227861031,8.231045276385599 13 | 0.0,1.0346252920257577,13.026242599511141 14 | 1.0,143.6366833072027,126.44045245777507 15 | 2.0,166.91572604982366,148.46425590590658 16 | 3.0,158.38363165166857,138.7693624261668 17 | 4.0,157.81107830043115,134.4978748150823 18 | 5.0,156.60862948585378,136.41163554223226 19 | 6.0,151.0426074208455,130.34526853134207 20 | 7.0,145.43803788151445,128.30833998146775 21 | 8.0,137.11795807019615,122.8212700669722 22 | 9.0,131.3311802924184,121.11238678762228 23 | 
10.0,119.7922862475275,111.56209408067095 24 | 11.0,117.9903318075699,112.62626850873491 25 | 12.0,110.59071275995085,107.99908705138066 26 | 13.0,104.09879934003743,102.0880714795232 27 | 14.0,99.95184534786648,97.31044745822525 28 | 15.0,98.03046742313785,92.6906309453956 29 | 16.0,98.99267911113324,95.10585521830379 30 | 17.0,96.20435400497792,94.14901109026951 31 | 18.0,98.72163038951872,95.6073830879085 32 | 19.0,99.15414179643004,94.41049890270314 33 | 20.0,94.40154925863122,91.92592874663704 34 | -------------------------------------------------------------------------------- /fitbit_pregnancy/data/figure_5_d_data.csv: -------------------------------------------------------------------------------- 1 | week,mean,std 2 | -11.0,5.460372242080611,0.4007338396896058 3 | -10.0,5.608283002588438,0.4069879694588699 4 | -9.0,5.4788610871440895,0.6121310513026391 5 | -8.0,5.491186983853075,0.4885343413559105 6 | -7.0,5.534327622334525,0.3708709181608694 7 | -6.0,5.897941575249599,0.45704263222008024 8 | -5.0,5.423394551953655,0.7332700021027168 9 | -4.0,5.4788610871440895,0.5122666991819637 10 | -3.0,5.6021200542339455,0.4984119768543795 11 | -2.0,5.6021200542339455,0.5995506721072507 12 | -1.0,5.115247134229016,0.525334865552227 13 | 0.0,5.43572044866264,0.41528700253648687 14 | 1.0,17.76161715764822,4.263851753078051 15 | 2.0,20.516455072106492,0.6761103588324822 16 | 3.0,19.918649081720694,1.2400499406547574 17 | 4.0,16.621471712067056,1.1540192790623385 18 | 5.0,15.314926660914582,0.6683318975336346 19 | 6.0,13.496856896339208,0.5482022527793288 20 | 7.0,11.339824972266731,0.8068677977026794 21 | 8.0,10.187353629976581,0.6694580957114309 22 | 9.0,9.959324540860347,0.3925786814344486 23 | 10.0,9.262911376802663,0.256090094781995 24 | 11.0,8.689757179834832,0.37925924916221954 25 | 12.0,8.233699001602366,0.5795574941586773 26 | 13.0,7.765314926660914,0.6085735970520225 27 | 14.0,7.666707752989029,0.44951516156318155 28 | 15.0,7.013435227412794,0.5109675170263559 29 
| 16.0,6.976457537285838,0.6522251475569064 30 | 17.0,6.994946382349316,0.3003660269437698 31 | 18.0,6.9025021570319245,0.5462588413351332 32 | 19.0,6.927153950449896,0.3377273205674901 33 | 20.0,6.514236410698879,0.4906161089799248 34 | -------------------------------------------------------------------------------- /fitbit_pregnancy/data/figure_5_a_data.csv: -------------------------------------------------------------------------------- 1 | week,mean,std 2 | -11.0,-0.6490038664241735,30.15727602976822 3 | -10.0,1.4321191526998682,29.309577393880655 4 | -9.0,1.3097827966518876,28.891984670597243 5 | -8.0,0.8874850745766977,28.79110422984371 6 | -7.0,0.08443808754596983,28.11257339484026 7 | -6.0,0.060942006413204894,27.881624571028812 8 | -5.0,-0.41126561128114464,28.27536589794119 9 | -4.0,0.5250996166501658,27.98241566358741 10 | -3.0,-0.34901180987342817,29.04658409483102 11 | -2.0,-0.00920643519133338,28.504673192111106 12 | -1.0,-1.4309969087578978,29.161917041030893 13 | 0.0,-1.8249864112243446,30.06366256375385 14 | 1.0,-157.97947608300132,94.45087321781716 15 | 2.0,-62.70280983581637,87.16456548890275 16 | 3.0,-38.874707466349896,77.59099689764315 17 | 4.0,-36.354987677004964,76.31527680027426 18 | 5.0,-37.89833177583181,74.80707269820088 19 | 6.0,-36.03716898399703,73.17297512779193 20 | 7.0,-31.854359909734768,70.47894905979048 21 | 8.0,-26.299827124712536,68.74728976362839 22 | 9.0,-23.54442666836568,65.61461731630817 23 | 10.0,-20.541775069610097,63.33737573461618 24 | 11.0,-15.968340529008048,60.820159138800044 25 | 12.0,-14.44868101980615,58.60096164412895 26 | 13.0,-14.738484445951627,57.14436716305214 27 | 14.0,-14.98633310232852,55.25463896117928 28 | 15.0,-14.00100295413355,55.13938607705657 29 | 16.0,-12.526104238443628,53.66998041752298 30 | 17.0,-12.977153555235903,53.27227721060504 31 | 18.0,-13.907316159367104,53.48941959790354 32 | 19.0,-13.4756598737474,53.906161111572224 33 | 20.0,-14.294254518745916,52.64062801797197 34 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, Google Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of Google Inc. nor the names of its contributors 15 | may be used to endorse or promote products derived from this software without 16 | specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /analysis/README.md: -------------------------------------------------------------------------------- 1 | # Utilities for evaluating model and reader performance 2 | 3 | This package contains functions for the statistical analysis of multi-reader multi-case (MRMC) studies, a common paradigm in the medical imaging literature. 4 | For statistical considerations and sample usage, please see our open-access [technical report](https://www.medrxiv.org/content/10.1101/2022.05.06.22274773v1). 5 | 6 | As of May 3, 2022, the following published work has used tools in this library: 7 | 8 | * [Majkowska et al. (2019)](https://pubs.rsna.org/doi/full/10.1148/radiol.2019191293) 9 | * [McKinney et al. (2020)](https://www.nature.com/articles/s41586-019-1799-6) 10 | * [Sayres et al. (2020)](https://iovs.arvojournals.org/article.aspx?articleid=2769549) 11 | * [Steiner et al. (2020)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7662146/) 12 | * [Kazemzadeh et al. (2021)](https://arxiv.org/abs/2105.07540) 13 | * [Liu et al. 
(2022)](https://www.sciencedirect.com/science/article/pii/S246865302200001X) 14 | 15 | If you use this software in your own research, please cite our paper: 16 | 17 | ``` 18 | @article {McKinney2022, 19 | author = {McKinney, Scott Mayer}, 20 | title = {Comparing human and AI performance in medical machine learning: An open-source Python library for the statistical analysis of reader study data}, 21 | elocation-id = {2022.05.06.22274773}, 22 | year = {2022}, 23 | doi = {10.1101/2022.05.06.22274773}, 24 | publisher = {Cold Spring Harbor Laboratory Press}, 25 | URL = {https://www.medrxiv.org/content/10.1101/2022.05.06.22274773}, 26 | eprint = {https://www.medrxiv.org/content/10.1101/2022.05.06.22274773.full.pdf}, 27 | journal = {medRxiv} 28 | } 29 | ``` 30 | -------------------------------------------------------------------------------- /health_acoustic_representations/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) 2024, Google Inc. 4 | # All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions 8 | # are met: 9 | # 10 | # 1. Redistributions of source code must retain the above copyright notice, 11 | # this list of conditions and the following disclaimer. 12 | # 13 | # 2. Redistributions in binary form must reproduce the above copyright 14 | # notice, this list of conditions and the following disclaimer in the 15 | # documentation and/or other materials provided with the distribution. 16 | # 17 | # 3. Neither the name of Google Inc. nor the names of its 18 | # contributors may be used to endorse or promote products derived from this 19 | # software without specific prior written permission. 
20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 25 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 31 | # POSSIBILITY OF SUCH DAMAGE. 32 | 33 | set -e 34 | set -x 35 | 36 | 37 | python3 -m virtualenv . 38 | 39 | source ./bin/activate 40 | 41 | pip install -r requirements.txt 42 | -------------------------------------------------------------------------------- /analysis/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) 2022, Google Inc. 4 | # All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions 8 | # are met: 9 | # 10 | # 1. Redistributions of source code must retain the above copyright notice, 11 | # this list of conditions and the following disclaimer. 12 | # 13 | # 2. Redistributions in binary form must reproduce the above copyright 14 | # notice, this list of conditions and the following disclaimer in the 15 | # documentation and/or other materials provided with the distribution. 16 | # 17 | # 3. Neither the name of Google Inc. 
nor the names of its 18 | # contributors may be used to endorse or promote products derived from this 19 | # software without specific prior written permission. 20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 25 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 31 | # POSSIBILITY OF SUCH DAMAGE. 32 | 33 | set -e 34 | set -x 35 | 36 | 37 | python3 -m virtualenv . 38 | 39 | source ./bin/activate 40 | 41 | pip install -r requirements.txt 42 | 43 | python -m unittest discover -p "*_test.py" 44 | -------------------------------------------------------------------------------- /colorectal-survival/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) 2020, Google Inc. 4 | # All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions 8 | # are met: 9 | # 10 | # 1. Redistributions of source code must retain the above copyright notice, 11 | # this list of conditions and the following disclaimer. 12 | # 13 | # 2. 
Redistributions in binary form must reproduce the above copyright 14 | # notice, this list of conditions and the following disclaimer in the 15 | # documentation and/or other materials provided with the distribution. 16 | # 17 | # 3. Neither the name of Google Inc. nor the names of its 18 | # contributors may be used to endorse or promote products derived from this 19 | # software without specific prior written permission. 20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 25 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 31 | # POSSIBILITY OF SUCH DAMAGE. 32 | 33 | set -e 34 | set -x 35 | 36 | 37 | python3 -m virtualenv . 38 | 39 | source ./bin/activate 40 | 41 | pip install -r requirements.txt 42 | 43 | python -m unittest discover -p "*_test.py" 44 | -------------------------------------------------------------------------------- /breast_survival_prediction/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) 2020, Google Inc. 4 | # All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions 8 | # are met: 9 | # 10 | # 1. 
Redistributions of source code must retain the above copyright notice, 11 | # this list of conditions and the following disclaimer. 12 | # 13 | # 2. Redistributions in binary form must reproduce the above copyright 14 | # notice, this list of conditions and the following disclaimer in the 15 | # documentation and/or other materials provided with the distribution. 16 | # 17 | # 3. Neither the name of Google Inc. nor the names of its 18 | # contributors may be used to endorse or promote products derived from this 19 | # software without specific prior written permission. 20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 25 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 31 | # POSSIBILITY OF SUCH DAMAGE. 32 | 33 | set -e 34 | set -x 35 | 36 | 37 | python3 -m virtualenv . 
"""Unit tests for dicom_utils."""

import copy

from absl.testing import absltest
import pydicom

import dicom_utils


def _make_axial_spaced_dicoms(z_pos: float) -> pydicom.Dataset:
  """Builds a minimal DICOM dataset located at the given axial (z) offset."""
  ds = pydicom.Dataset()
  ds.ImagePositionPatient = [0, 0, z_pos]
  return ds


class DicomUtilsTest(absltest.TestCase):

  def testDedupe(self):
    # Build a three-instance series where the first and last instances carry
    # identical Instance/Acquisition numbers, i.e. are duplicates.
    template = pydicom.Dataset()
    template.SeriesInstanceUID = '1.22.333.4444.55555'

    series = []
    for instance_number, acquisition_number in ((1, 1), (2, 2), (1, 1)):
      template.InstanceNumber = instance_number
      template.AcquisitionNumber = acquisition_number
      series.append(copy.deepcopy(template))

    deduped, changed = dicom_utils.dedupe_series(series)
    self.assertTrue(changed)
    self.assertLen(deduped, 1)
    self.assertEqual(deduped[0].InstanceNumber, 1)

  def testGetAverageSliceSpacingReturnsAverage(self):
    dicoms = [_make_axial_spaced_dicoms(z) for z in (0, 1, 2.099, 3.1)]
    self.assertAlmostEqual(
        dicom_utils.try_get_average_slice_spacing(dicoms),
        1.0333,
        delta=1e-3,
    )

  def testGetAverageSliceSpacingRaisesOnDuplicateSlicing(self):
    # The near-coincident slices at z=1 and z=1.01 produce a spacing ratio far
    # outside tolerance, so validation must fail.
    dicoms = [_make_axial_spaced_dicoms(z) for z in (0, 1, 1.01, 3.1, 5.1)]
    with self.assertRaisesRegex(ValueError, 'spacing ratio (.*)208.00(.*)'):
      dicom_utils.validate_slice_spacing(dicoms)


if __name__ == '__main__':
  absltest.main()
"""Tests for mitotic_features_util."""

import unittest

import numpy as np

import mitotic_features_util

# The detector treats a full 3x3 block of positives as one mitosis, so this
# heatmap contains exactly two mitoses: one at the top-left corner and one at
# the bottom-right corner.
_HEATMAP = [
    [1, 1, 1, 0, 0, 1, 0],
    [1, 1, 1, 1, 0, 0, 0],
    [1, 1, 1, 0, 1, 1, 1],
    [0, 0, 0, 0, 1, 1, 1],
    [1, 0, 0, 0.5, 1, 1, 1],
]  # pyformat: disable
_MASK = [
    [1, 1, 0, 0],
    [1, 1, 0, 0],
    [1, 1, 0, 0],
]  # pyformat: disable
_EXPECTED_DETECTION = [(1, 1), (3, 5)]
_EXPECTED_DENSITY_MAP = [[1, 0, 0],
                         [0, 0, 1]]  # pyformat: disable
_EXPECTED_MASKED_DETECTION = [(1, 1)]
_EXPECTED_MASKED_DENSITY_MAP = [[1, np.nan, np.nan],
                                [0, np.nan, np.nan]]  # pyformat: disable
_WINDOW_SIZE = 2
_STRIDE = 2
_DETECTION_TH = 0.7
_MORPH_ERODE_SIZE = 3


class MitoticFeaturesUtilTest(unittest.TestCase):

  def test_mitosis_detection(self):
    detections = mitotic_features_util.heatmap_to_list(
        np.array(_HEATMAP),
        detection_th=_DETECTION_TH,
        mask=None,
        morph_erode_size=_MORPH_ERODE_SIZE)
    # Order detections by row index so the comparison is deterministic.
    detections = sorted(detections, key=lambda coord: coord[0])
    # Allow up to half a pixel of localization error.
    np.testing.assert_allclose(
        np.array(_EXPECTED_DETECTION), detections, atol=0.55)

  def test_density_calculation(self):
    # Slide 2x2 windows over the detections from the heatmap above.
    expected = np.array(_EXPECTED_DENSITY_MAP) / (_WINDOW_SIZE**2)
    actual = mitotic_features_util.calculate_density(
        _EXPECTED_DETECTION, (5, 7), _WINDOW_SIZE, _STRIDE, mask=None)
    np.testing.assert_allclose(expected, actual, atol=0.01)


if __name__ == '__main__':
  unittest.main()
2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions 6 | # are met: 7 | # 8 | # 1. Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # 11 | # 2. Redistributions in binary form must reproduce the above copyright 12 | # notice, this list of conditions and the following disclaimer in the 13 | # documentation and/or other materials provided with the distribution. 14 | # 15 | # 3. Neither the name of Google Inc. nor the names of its 16 | # contributors may be used to endorse or promote products derived from this 17 | # software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 23 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 24 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 25 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 26 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 29 | # POSSIBILITY OF SUCH DAMAGE. 
30 | """Tests network.py.""" 31 | 32 | import unittest 33 | 34 | import tensorflow as tf 35 | 36 | import loss 37 | import network 38 | 39 | SEQUENCE_LENGTH = 2 40 | PATCH_SIZE = 128 41 | INPUT_SHAPE = (SEQUENCE_LENGTH, PATCH_SIZE, PATCH_SIZE, 3) 42 | 43 | 44 | class NetworkTest(unittest.TestCase): 45 | 46 | def test_network(self): 47 | """Test that network compiles.""" 48 | model = network.build_network(INPUT_SHAPE) 49 | model.compile( 50 | optimizer=tf.keras.optimizers.Adam(0.01), 51 | loss=loss.keras_cox_partial_likelihood, 52 | metrics=[]) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /fitbit_pregnancy/data/figure_4_a_data.csv: -------------------------------------------------------------------------------- 1 | week,mean,std 2 | -11.0,-0.6537209892673053,30.854962224243707 3 | -10.0,0.7028511893755772,29.992672243164204 4 | -9.0,0.6594268092784551,29.471308803565204 5 | -8.0,-0.258894843749264,29.573983849212556 6 | -7.0,-0.36746021071533796,28.884725179253543 7 | -6.0,-0.8603991444145338,29.82969264980337 8 | -5.0,0.7501616529061085,29.24428415581838 9 | -4.0,-0.5001535768654576,28.91214680734662 10 | -3.0,-0.05330640431995176,29.419829298716394 11 | -2.0,0.39092185989897577,29.865742179156328 12 | -1.0,-0.7649569385684961,30.72363832116789 13 | 0.0,-1.83264783009431,30.505310689503975 14 | 1.0,0.3032985375359017,32.215502309049974 15 | 2.0,-1.7769392329103384,33.92175266327701 16 | 3.0,-0.1530112029701008,32.98883433024914 17 | 4.0,0.17681661097593832,35.27798026397837 18 | 5.0,0.5508053319615163,37.19759170882378 19 | 6.0,11.656332981151076,37.77957518830591 20 | 7.0,20.415773642708377,39.65274029578804 21 | 8.0,24.20517106884934,41.4722061859448 22 | 9.0,26.65529766595039,42.617451151665456 23 | 10.0,26.059017372234653,42.702828070524774 24 | 11.0,25.180563474166984,43.03720490338458 25 | 12.0,24.77807843324952,42.90283278846893 26 | 
13.0,23.718743937737674,43.72062607979433 27 | 14.0,24.555334809707055,40.30503369282092 28 | 15.0,21.424099566277132,42.21162811939891 29 | 16.0,21.077418085157607,42.03513859077033 30 | 17.0,18.04404891989431,39.89166159757949 31 | 18.0,16.6476249317315,41.78114760831215 32 | 19.0,15.455450158615694,41.22008986509072 33 | 20.0,13.701147269131114,40.269900252072965 34 | 21.0,12.50593085412989,41.66046770322587 35 | 22.0,11.19085552694063,40.67698092490573 36 | 23.0,9.901316166946721,42.273284338357044 37 | 24.0,9.383064593934995,41.1147218761935 38 | 25.0,8.063437369418185,42.10574051172368 39 | 26.0,8.32294333541083,43.021348590321246 40 | 27.0,6.285561357374075,43.156768166633796 41 | 28.0,5.327101920246529,42.919608686174726 42 | 29.0,4.64001059525725,43.39867134209981 43 | 30.0,3.6924925954883534,44.47554338018715 44 | 31.0,2.696922509619769,46.259853352618535 45 | 32.0,0.8182993574291535,45.69775042813069 46 | 33.0,-0.7258011912388612,46.59101922134652 47 | 34.0,-3.87603675783018,48.68138590503555 48 | 35.0,-4.504153986671717,49.621089389937126 49 | 36.0,-8.752369388859341,52.87481454458643 50 | 37.0,-10.70880186565304,54.381425749920744 51 | 38.0,-9.89646729198872,52.35847611584832 52 | -------------------------------------------------------------------------------- /fitbit_pregnancy/data/figure_4_c_data.csv: -------------------------------------------------------------------------------- 1 | week,mean,std 2 | -11.0,5.2980877390326215,0.6094343825466831 3 | -10.0,5.303712035995501,0.5749084350211476 4 | -9.0,5.219347581552306,0.43595860345832366 5 | -8.0,5.331833520809899,0.5585637663954041 6 | -7.0,5.258717660292463,0.3123721908700609 7 | -6.0,5.556805399325085,0.418156802664963 8 | -5.0,5.0956130483689535,0.6064598855077248 9 | -4.0,5.281214848143982,0.4924820609070505 10 | -3.0,5.253093363329584,0.28828751278698095 11 | -2.0,5.455568053993251,0.4550933289221012 12 | -1.0,5.061867266591676,0.49165706943919596 13 | 0.0,5.15185601799775,0.5588279876809452 14 
| 1.0,5.354330708661417,0.4389955035086015 15 | 2.0,5.2643419572553425,0.42018178691829217 16 | 3.0,5.601799775028121,0.5377604417008558 17 | 4.0,5.635545556805399,0.402415729209152 18 | 5.0,6.749156355455568,0.3476009668126009 19 | 6.0,9.032620922384703,1.1342084302868312 20 | 7.0,11.310461192350957,1.1656279130419693 21 | 8.0,12.530933633295836,0.41336350753038964 22 | 9.0,13.048368953880765,0.36429354815739245 23 | 10.0,12.536557930258718,0.7075147995841313 24 | 11.0,11.749156355455568,0.8003712503152111 25 | 12.0,11.541057367829023,0.7957005476413969 26 | 13.0,10.613048368953882,0.48049625620757425 27 | 14.0,10.056242969628796,0.7147800680143611 28 | 15.0,9.904386951631045,0.5407033461425413 29 | 16.0,10.044994375703038,0.5381034668309098 30 | 17.0,9.302587176602925,0.6964218197034631 31 | 18.0,9.173228346456694,0.48538692936763583 32 | 19.0,8.661417322834646,0.38035109570051007 33 | 20.0,8.869516310461192,0.2724931732308897 34 | 21.0,8.908886389201351,0.5842685430164013 35 | 22.0,8.76265466816648,0.8628118833393229 36 | 23.0,9.145106861642294,0.4561463183650341 37 | 24.0,8.92013498312711,0.20814965080139958 38 | 25.0,9.100112485939258,0.37860055720357705 39 | 26.0,9.049493813273342,0.47376704277016024 40 | 27.0,9.100112485939258,0.6955734297802217 41 | 28.0,9.21259842519685,0.3278214172387608 42 | 29.0,9.105736782902138,0.5524518000317887 43 | 30.0,9.370078740157481,0.46081495711494563 44 | 31.0,9.426321709786277,0.5627108408675253 45 | 32.0,9.482564679415074,0.4291329982826101 46 | 33.0,9.403824521934757,0.9113212512158423 47 | 34.0,9.87064116985377,0.7247781947321843 48 | 35.0,9.955005624296962,0.8754653780232265 49 | 36.0,10.47806524184477,0.5629731158017708 50 | 37.0,11.749156355455568,0.608161379512092 51 | 38.0,12.654668166479189,1.1612187077953764 52 | -------------------------------------------------------------------------------- /fitbit_pregnancy/data/figure_4_d_data.csv: 
-------------------------------------------------------------------------------- 1 | week,mean,std 2 | -11.0,-1.3081447231134014,29.067215706249122 3 | -10.0,-0.3414028664492109,26.566439521945863 4 | -9.0,-0.6344616661886984,26.96323917171225 5 | -8.0,-1.166615565113025,27.12241867540226 6 | -7.0,-1.1990877424722652,26.017978796545243 7 | -6.0,-1.4998337045473908,27.580095415540356 8 | -5.0,-0.4594673905893731,26.73394779406295 9 | -4.0,-1.5774874928864557,26.469713325196853 10 | -3.0,-0.8648992177378408,25.737397786524184 11 | -2.0,-0.30321238016977986,27.53863528252136 12 | -1.0,-1.545309974865483,27.14615066186711 13 | 0.0,-2.495021895393383,27.180475987207803 14 | 1.0,-0.1349042568408643,28.472530815823816 15 | 2.0,-1.618947911082014,29.860142230067005 16 | 3.0,-0.9159584630531604,29.844326866384034 17 | 4.0,-1.5424724021156357,30.74905741920309 18 | 5.0,-2.2017840606860384,30.70281279840791 19 | 6.0,6.403798345359102,32.288369527005734 20 | 7.0,14.400876805449808,34.032990465492325 21 | 8.0,17.65097160803184,35.51533078440032 22 | 9.0,19.55477887411855,36.44823810056407 23 | 10.0,19.755317434069667,36.857776293636554 24 | 11.0,18.820203829611096,37.899800206111586 25 | 12.0,18.401351828993246,36.791179992159385 26 | 13.0,17.375638233121702,36.87224488615105 27 | 14.0,18.642235718068005,34.39768588377912 28 | 15.0,16.300942818571173,35.64073532344201 29 | 16.0,15.622775826038772,35.91030146211479 30 | 17.0,13.984253679869573,34.466111946937566 31 | 18.0,12.590616536897581,35.343481898871524 32 | 19.0,11.868135597176025,35.84226885248412 33 | 20.0,9.829240387991106,34.66335512173853 34 | 21.0,9.011342212118475,35.46439100040829 35 | 22.0,7.415330599984004,35.30769534717715 36 | 23.0,6.8709638776838124,35.83743788368235 37 | 24.0,6.066333544269589,35.55249638636693 38 | 25.0,4.164695347647124,35.7706913220952 39 | 26.0,4.274512454925458,35.999017672293704 40 | 27.0,2.7403577268725163,37.392231933013086 41 | 28.0,1.316931180817479,37.10824279204288 42 | 
"""Data-processing utils needed to prep data for feature generation and selection.

For example use, see demo.ipynb.
"""

import pandas as pd


def bin_age(age, start_cutoff=60, stop_cutoff=80, increment=10):
  """Categorizes an age into half-open bins.

  Ages below `start_cutoff` fall into a single catch-all bin starting at 0;
  ages of `stop_cutoff` or above fall into an open-ended top bin. Missing
  values are passed through unchanged so downstream dropna() can handle them.

  Args:
    age: age to bin; may be None/NaN, a number, or a numeric string.
    start_cutoff: first cutoff to use forming bins.
    stop_cutoff: last cutoff to use forming bins.
    increment: difference in age between bins.

  Returns:
    string representing binned age (e.g. '60-69' or '>=80'), or the original
    value if it is missing.
  """
  # Scalar null check; the previous list-wrapped form relied on the truth
  # value of a one-element ndarray.
  if pd.isnull(age):
    return age
  age = float(age)

  last_cutoff = 0
  for age_cutoff in range(start_cutoff, stop_cutoff + increment, increment):
    if age < age_cutoff:
      return f'{last_cutoff}-{age_cutoff - 1}'
    last_cutoff = age_cutoff
  return f'>={stop_cutoff}'


def prep_features(df, feature_cols, label_cols):
  """Returns a pd.DataFrame suitable for modeling.

  1) Select only desired `feature_cols` (plus `label_cols`).
  2) Remove rows with nan values.
  3) Convert categorical features to dummy variables.
  4) Remove constant `feature_cols`.

  Args:
    df: pd.DataFrame containing `feature_cols`, `label_cols`.
    feature_cols: a list of columns in `df` containing regression features.
    label_cols: a list of column names in `df` containing labels. These columns
      are not coded as dummies if they are categorical.

  Returns:
    tuple of (dataframe, expanded feature cols). Note the expanded feature
    cols may include columns subsequently dropped for being constant.
  """
  # Select subset of required columns.
  df = df.copy()[label_cols + feature_cols]

  # Remove rows with missing values.
  n_rows = df.shape[0]
  df = df.dropna()
  delta = n_rows - df.shape[0]
  if delta > 0:
    print(f'Dropped {delta} rows due to missing values.')

  # Convert categorical cols to dummy vars; drop_first avoids collinearity.
  df_labels = df[label_cols]
  df = pd.get_dummies(df[feature_cols], drop_first=True)
  expanded_feature_cols = list(df.columns)

  # Remove constant feature columns (they carry no modeling information).
  df = df.loc[:, (df != df.iloc[0]).any()]
  dropped = set(expanded_feature_cols) - set(df.columns)
  if dropped:
    print(f'Dropped {list(dropped)} constant feature columns.')

  # Re-attach labels for regression.
  df = pd.concat([df_labels, df], axis=1)
  return df, expanded_feature_cols
class GenerateGCPCredentials(parameterized.TestCase):
  """Tests for auth.create_gcp_credentials()."""

  def setUp(self):
    super().setUp()
    auth.define_flags()

  def test_application_default_credentials(self):
    """ADC is used when no other credentials flags are set."""
    with mock.patch.object(gauth, 'default', autospec=True) as default_mock:
      auth.create_gcp_credentials()
      default_mock.assert_called_once()

  @flagsaver.flagsaver(use_gce_credentials=True)
  def test_use_gce_credentials(self):
    """GCE credentials used when `use_gce_credentials` flag is set."""
    with mock.patch.object(
        gce_credentials, 'Credentials', autospec=True) as credentials_mock:
      auth.create_gcp_credentials()
      credentials_mock.assert_called_once_with()

  def test_use_gce_credentials_flag_not_defined(self):
    """Falls back to default credentials when the GCE flag is absent."""
    del flags.FLAGS.use_gce_credentials
    with mock.patch.object(gauth, 'default', autospec=True):
      auth.create_gcp_credentials()

  @flagsaver.flagsaver(access_token='dummy_access_token')
  def test_access_token(self):
    """Access Token is used when `access_token` flag is set."""
    with mock.patch.object(
        auth, '_AccessTokenCredentials', autospec=True) as token_cls_mock:
      auth.create_gcp_credentials()
      token_cls_mock.assert_called_once_with('dummy_access_token')

  def test_access_token_flag_not_defined(self):
    """Falls back to default credentials when the token flag is absent."""
    del flags.FLAGS.access_token
    with mock.patch.object(gauth, 'default', autospec=True):
      auth.create_gcp_credentials()

  @parameterized.named_parameters(*_TOO_MANY_FLAGS_PARAMS)
  def test_more_than_one_credential_requested(self, **kwargs):
    """Requesting multiple credential types at once must raise."""
    with flagsaver.flagsaver(**kwargs):
      with self.assertRaisesRegex(ValueError, 'one credential'):
        auth.create_gcp_credentials()


if __name__ == '__main__':
  absltest.main()
"""Beam wrappers to create Examples from DICOM files.

These are meant to be used in conjunction with Beam stages in `dicomweb_beam.py`
to create Examples from DICOMs sourced from a CHC DICOM Store.
"""

import io

from absl import logging
import apache_beam as beam
from apache_beam import pvalue
import pydicom

import example_builder
from gcp import dicomweb_beam


class CreateCTExampleFn(beam.DoFn):
  """Beam wrapper to create CT Example."""

  # Tag identifying the `TaggedOutput` carrying the error strings emitted
  # by `process()`.
  ERROR_OUTPUT_TAG = 'errors'

  def __init__(self, dataset_name: str = 'adhoc') -> None:
    """Creates an instance.

    Args:
      dataset_name: The dataset name in all emitted TF Examples, stored under
        the key "volume/id".
    """
    super().__init__()
    self._dataset_name = dataset_name

  def process(self, series_scope_dicoms: dicomweb_beam.SeriesScopeDICOMs):
    """Creates a CT model-compatible TF Example from DICOMs in a Series.

    If successful, a key-value pair is emitted:
    - Key: the series-scope key string, UTF-8 encoded.
    - Value: The prepared TF Example.

    In case of failure, a string formatted as a CSV row is emitted; it is
    routed to an output tagged 'errors'. The row is comma-separated, with
    column values enclosed within double-quotes. The column contents are:
    - Column 1: Study Instance UID
    - Column 2: Series Instance UID
    - Column 3 onwards: Stringified arguments passed to the Exception that
      was caught for the error.

    Args:
      series_scope_dicoms: DICOMs sharing the same value for the Series and
        Study Instance UID Attributes; its metadata carries the Study and
        Series Instance UID values referenced above.

    Yields:
      Prepared TF Example or error string.
    """
    try:
      parsed = [
          pydicom.filereader.dcmread(io.BytesIO(raw))
          for raw in series_scope_dicoms.dicoms
      ]
      example = example_builder.create_ct_tfexample(parsed, self._dataset_name)
      yield series_scope_dicoms.key.encode('utf-8'), example
    except Exception as e:  # pylint: disable=broad-exception-caught
      logging.exception('Example creation failed %r', series_scope_dicoms.key)
      metadata = series_scope_dicoms.metadata
      yield pvalue.TaggedOutput(
          self.ERROR_OUTPUT_TAG,
          dicomweb_beam.to_csv_row((
              metadata.study_instance_uid,
              metadata.series_instance_uid,
              *e.args,
          )),
      )
"""Tests for loss.py."""

import math
import numpy as np
import tensorflow as tf
import loss


class LossTest(tf.test.TestCase):

  def test_cox_partial_likelihood(self):
    with self.test_session():
      preds = tf.constant([6, 5, 4, 1], dtype=tf.float32)
      event_times = tf.constant([4, 5, 3, 2], dtype=tf.int32)
      censored = tf.constant([True, False, False, False], dtype=tf.bool)

      # Hand-computed partial-likelihood terms for the three uncensored
      # events; each risk set contains subjects still at risk at that time.
      term_1 = math.exp(5) / math.exp(5)
      term_2 = math.exp(4) / (math.exp(6) + math.exp(5) + math.exp(4))
      term_3 = math.exp(1) / (
          math.exp(6) + math.exp(5) + math.exp(4) + math.exp(1))
      expected = -(
          math.log(term_1) + math.log(term_2) + math.log(term_3)) / 3
      actual = loss.cox_partial_likelihood(event_times, censored, preds)
      self.assertAllClose(expected, actual)

  def test_logsumexp_masked(self):
    with self.test_session():
      # Work in log space; masked entries must not contribute to the sum.
      probs = tf.constant([[.3, .5, .2], [.1, .0, .9], [.2, .1, .7]])
      mask = tf.constant([[1, 1, 1], [1, 1, 0], [1, 0, 0]])
      expected = np.log([1.0, 0.1, 0.2])
      actual = loss.logsumexp_masked(tf.math.log(probs), mask)
      self.assertAllClose(expected, actual)
"""Unit tests for model networks and loss functions."""

import tensorflow.compat.v1 as tf
import networks


class NetworksTest(tf.test.TestCase):

  def _evaluate(self, fetches):
    """Initializes all variables, then evaluates `fetches` in a session."""
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      return sess.run(fetches)

  def test_gestational_age_regression_model(self):
    with tf.Graph().as_default():
      # Batch x sequence length x height x width x channels. The sizes used
      # for unit testing are smaller than the real values to reduce resource
      # usage; actual sizes are [8, 24, 432, 576, 1] as discussed in
      # Supplementary Methods of the publication.
      video_clips = tf.zeros([2, 10, 24, 32, 1])
      ages, variances = networks.gestational_age_regression_model(
          video_clips, is_training=False)
      ages_, variances_ = self._evaluate((ages, variances))
      # Expect one output value per clip in the batch.
      self.assertEqual(ages_.shape, (2, 1))
      self.assertEqual(variances_.shape, (2, 1))

  def test_fetal_malpresentation_classification_model(self):
    with tf.Graph().as_default():
      # Actual dimensions are [8, 100, 240, 320, 1] as discussed in
      # Supplementary Methods of the publication.
      video_clips = tf.zeros([2, 10, 24, 32, 1])
      logits = networks.fetal_malpresentation_classification_model(
          video_clips, is_training=False)
      logits_ = self._evaluate(logits)
      # Expect one output value per clip in the batch.
      self.assertEqual(logits_.shape, (2, 1))

  def test_gestational_age_loss_function(self):
    with tf.Graph().as_default():
      video_clips = tf.zeros([2, 10, 24, 32, 1])
      labels = tf.zeros([2, 1])
      ages, variances = networks.gestational_age_regression_model(
          video_clips, is_training=True)
      loss = networks.gestational_age_loss_function(ages, variances, labels)
      # Expect a scalar output.
      self.assertEqual(self._evaluate(loss).shape, ())

  def test_malpresentation_model_loss_function(self):
    with tf.Graph().as_default():
      video_clips = tf.zeros([2, 10, 24, 32, 1])
      labels = tf.zeros([2, 1])
      logits = networks.fetal_malpresentation_classification_model(
          video_clips, is_training=True)
      loss = networks.fetal_malpresentation_loss_function(logits, labels)
      # Expect a scalar output.
      self.assertEqual(self._evaluate(loss).shape, ())


if __name__ == '__main__':
  tf.test.main()
18 | """ 19 | area = np.nansum(heatmap, axis=(0, 1)) 20 | normalized_area = area / (np.nansum(area) + np.finfo(float).eps) 21 | return normalized_area 22 | 23 | 24 | def np_tf_featurizer( 25 | tmap: Mapping[str, Union[Sequence[Sequence[float]], 26 | np.ndarray]]) -> np.ndarray: 27 | """Featurization for Nuclear Pleomorphism (NP) and Tubule Formation (TF). 28 | 29 | Args: 30 | tmap: Dictionary of tensors. Expected to contain two 3D heatmaps of 31 | probabilities (channel order: HWC) representing NP/TF model output (keyed 32 | 'heatmap' and the invasive carcinoma segmentation model output (keyed 33 | 'ic_heatmap'). Channels of first heatmap represent NP/TF1, NP/TF2, NP/TF3. 34 | While the ic_heatmap represents Benign, Invasive Carcinoma, Carcinoma in 35 | situ. The first heatmap can optionally contain NP/TF0 which indicate 36 | redundant non-invasive carcinoma segmentation by the NP/TF model. Both 37 | heatmaps are expected to be of the same size. 38 | 39 | Returns: 40 | Normalized area of NP/TF grade 1, 2, 3 within tumor area and outside tumor 41 | area resulting in 6 numbers in total. 42 | """ 43 | 44 | ic_heatmap = tmap['ic_heatmap'] 45 | ic_positive_mask = np.argmax(ic_heatmap, axis=-1) == 1 46 | ic_positive_mask = np.expand_dims(ic_positive_mask, -1) 47 | 48 | nptf_heatmap = tmap['heatmap'] 49 | # Some heatmap may have extra dim in front, so selecting only the last 50 | # _NUM_NPTF_GRADES channels. 51 | nptf_heatmap = nptf_heatmap[..., -_NUM_NPTF_GRADES:] 52 | # Heatmap within invasive tumor. 53 | heatmap_ic = nptf_heatmap * ic_positive_mask 54 | # heatmap outside invasive tumor. 55 | heatmap_non_ic = nptf_heatmap * (1 - ic_positive_mask) 56 | area_ic = _calc_normalized_area(heatmap_ic) 57 | area_non_ic = _calc_normalized_area(heatmap_non_ic) 58 | return np.concatenate((area_ic, area_non_ic)) 59 | 60 | 61 | def mc_featurizer( 62 | tmap: Mapping[str, Union[Sequence[Sequence[float]], 63 | np.ndarray]]) -> np.ndarray: 64 | """Featurization for Mitotic Count (MC). 
65 | 66 | Args: 67 | tmap: Mapping of precomputed mitotic features. Expected to contain density 68 | map (a 2D array) under the key 'density'. 69 | 70 | Returns: 71 | The 5th, 25th, 50th, 75th, and 95th percentiles of the density map 72 | """ 73 | density_map = tmap['density'] 74 | # drop NaN 75 | density_map = density_map[~np.isnan(density_map)] 76 | if not density_map.size: 77 | density_map = np.array([0]) 78 | # An alternative is to include histogram and total count. 79 | return np.percentile(density_map, [5, 25, 50, 75, 95]) 80 | -------------------------------------------------------------------------------- /colorectal-survival/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Google Inc. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions 6 | # are met: 7 | # 8 | # 1. Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # 11 | # 2. Redistributions in binary form must reproduce the above copyright 12 | # notice, this list of conditions and the following disclaimer in the 13 | # documentation and/or other materials provided with the distribution. 14 | # 15 | # 3. Neither the name of Google Inc. nor the names of its 16 | # contributors may be used to endorse or promote products derived from this 17 | # software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | # ARE DISCLAIMED. 
def main():
  """Trains the survival network on synthetic data and checks it can fit it."""

  def cox_c_index(raw_predictions, times, censor_mask):
    # The network outputs risk scores: higher output means higher hazard, so
    # negate before ranking for concordance.
    return lifelines.utils.concordance_index(
        times, -raw_predictions[:, 0], event_observed=~censor_mask)

  # Synthetic dataset: random "images" plus random survival targets, with
  # roughly a quarter of the examples censored.
  rng = np.random.RandomState(0)
  images = rng.rand(NUM_EXAMPLES, SEQUENCE_LENGTH, PATCH_SIZE, PATCH_SIZE, 3)
  event_times = rng.rand(NUM_EXAMPLES)
  censored = rng.rand(NUM_EXAMPLES) > 0.75
  y_true = np.stack([event_times, censored], axis=1)

  # Assemble the network and compile with the Cox partial-likelihood loss.
  model = network.build_network(images.shape[1:])
  model.compile(
      optimizer=tf.keras.optimizers.Adam(0.01),
      loss=loss.keras_cox_partial_likelihood,
      metrics=[])

  # Baseline concordance before any training.
  c_index_init = cox_c_index(model.predict(images), event_times, censored)
  print(f'Initial C-index: {c_index_init}')

  model.fit(images, y_true, epochs=NUM_EPOCHS)

  # Concordance after training, reported next to the baseline for contrast.
  c_index_final = cox_c_index(model.predict(images), event_times, censored)
  print(f'Initial C-index: {c_index_init}')
  print(f'Final C-index: {c_index_final}')

  # Assert that c-index increased by at least 10 points
  assert c_index_final > (c_index_init + 0.1)
In addition, column INCLUDED_IN_HUMAN_EVAL indicates which questions and answers were included in the human evaluation of the test split in our Technical Report [Advancing Multimodal Medical Capabilities of Gemini](https://arxiv.org/abs/2405.03162). See the Appendix of the report for details on the motivation and criteria for the new splits.


The original VQA-RAD dataset was published by:

Dina Demner-Fushman (ddemner@mail.nih.gov)\
Lister Hill National Center for Biomedical Communications, National Library of Medicine, Bethesda, MD, USA,

hosted at https://osf.io/89kps/.


## PAD-UFES-20: Train and test split

File: **pad_ufes_20_split.tsv**

This file contains a split for the PAD-UFES-20 dataset into a train and test subset, which was used to train and test the Med-Gemini model described in [Advancing Multimodal Medical Capabilities of Gemini](https://arxiv.org/abs/2405.03162). The split is at the patient level, i.e. patients are disjoint between the two splits.
Column PATIENT_ID contains the patient ID, column SPLIT the subset assignment.

The original PAD-UFES-20 dataset was published by:

Pacheco, Andre G. C.; Lima, Gustavo R.; Salomão, Amanda S.; Krohling, Breno; Biral, Igor P.; de Angelo, Gabriel G. ; Alves Jr, Fábio C. R. ; Esgario, José G. M.; Simora, Alana C. ; Castro, Pedro B. C. ; Rodrigues, Felipe B.; Frasson, Patricia H. L. ; Krohling, Renato A.; Knidel, Helder ; Santos, Maria C. S. ; Espírito Santo, Rachel B.; Macedo, Telma L. S. G.; Canuto, Tania R. P. ; de Barros, Luíz F. S. (2020), “PAD-UFES-20: a skin lesion dataset composed of patient data and clinical images collected from smartphones”, Mendeley Data, V1, doi: 10.17632/zr7vgbcyr2.1

and is hosted at https://data.mendeley.com/datasets/zr7vgbcyr2/1.
-------------------------------------------------------------------------------- /colorectal-survival/analysis_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Google Inc. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions 6 | # are met: 7 | # 8 | # 1. Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # 11 | # 2. Redistributions in binary form must reproduce the above copyright 12 | # notice, this list of conditions and the following disclaimer in the 13 | # documentation and/or other materials provided with the distribution. 14 | # 15 | # 3. Neither the name of Google Inc. nor the names of its 16 | # contributors may be used to endorse or promote products derived from this 17 | # software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 23 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 24 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 25 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 26 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 29 | # POSSIBILITY OF SUCH DAMAGE. 
30 | """Tests for analysis.py.""" 31 | 32 | import unittest 33 | import pandas as pd 34 | 35 | import analysis 36 | 37 | 38 | class AnalysisTest(unittest.TestCase): 39 | 40 | def setUp(self): 41 | super().setUp() 42 | self.example_ids = [0, 1, 2, 3, 4] 43 | self.times = [1, 1, 2, 3, 4] # Two examples with same time 44 | self.observed = [1, 1, 0, 1, 1] # One censored example 45 | self.risk_scores = [7, 8, 7, 9, 10] # One out of order, one tied 46 | self.df = pd.DataFrame( 47 | [self.example_ids, self.times, self.observed, self.risk_scores]).T 48 | self.df.columns = [ 49 | 'id', analysis.TIME, analysis.OBSERVED, analysis.RISK_SCORE 50 | ] 51 | 52 | def test_plot_km_curve(self): 53 | analysis.plot_km_curve(self.df, self.df) 54 | 55 | def test_discretize(self): 56 | risk_scores_tune = range(100) 57 | risk_scores_test = range(25, 125) 58 | expected = ['Medium Risk'] * 50 + ['High Risk'] * 50 59 | actual = analysis.discretize(risk_scores_tune, risk_scores_test) 60 | self.assertListEqual(list(actual), expected) 61 | 62 | def test_c_index(self): 63 | # Comparable pairs: 64 | # (0, 2): Tied 65 | # (0, 3): True 66 | # (0, 4): True 67 | # (1, 2): False 68 | # (1, 3): True 69 | # (1, 4): True 70 | # (3, 4): True 71 | expected = (5 + 0.5) / 7 72 | actual = analysis.c_index(self.df) 73 | self.assertEqual(expected, actual) 74 | 75 | def test_survival_auc(self): 76 | # Threshold: 1.5 Comparable pairs: 77 | # (0, 2): Tied 78 | # (0, 3): True 79 | # (0, 4): True 80 | # (1, 2): False 81 | # (1, 3): True 82 | # (1, 4): True 83 | expected = (4 + 0.5) / 6 84 | actual = analysis.survival_auc(self.df, 1.5) 85 | self.assertEqual(expected, actual) 86 | 87 | def test_get_hazard_ratios(self): 88 | analysis.get_hazard_ratios( 89 | self.df[[analysis.TIME, analysis.OBSERVED, analysis.RISK_SCORE]]) 90 | 91 | 92 | if __name__ == '__main__': 93 | unittest.main() 94 | -------------------------------------------------------------------------------- /ct_dicom/gcp/auth.py: 
def create_gcp_credentials() -> credentials.Credentials:
  """Builds a GCP Credentials instance from the command line flags.

  Call `define_flags()` beforehand so that the credential flags exist. If it
  was never called, or no credential flag is set, Application Default
  Credentials are used.

  The supported credential types are:
  - Access Token (flag: "access_token")
  - GCE (flag: "use_gce_credentials")
  - Application Default (no other flags set)

  Returns:
    A GCP Credentials instance.

  Raises:
    ValueError: If more than one command line Credentials flags is set.
  """

  def read_flag_value(flag_name, fallback):
    # A missing flag means define_flags() was never invoked: warn and behave
    # as if the flag had been left unset.
    try:
      return flags.FLAGS[flag_name].value
    except KeyError:
      logging.warning(_WARNING_MSG_TEMPLATE, flag_name)
      return fallback

  access_token = read_flag_value(_ACCESS_TOKEN_FLAG_NAME, None)
  use_gce_credentials = read_flag_value(_GCE_FLAG_NAME, False)

  # The credential sources are mutually exclusive.
  num_sources_selected = (
      int(access_token is not None) + int(bool(use_gce_credentials)))
  if num_sources_selected > 1:
    raise ValueError('At most one credential type can be set.')

  if access_token is not None:
    logging.info('Using Access Token credentials.')
    return _AccessTokenCredentials(access_token)

  if use_gce_credentials:
    logging.info('Using GCE Credentials.')
    return gce_credentials.Credentials()

  logging.info('Using Application Default Credentials.')
  return auth.default()[0]
def calculate_density(detected_centroids: List[Tuple[float]],
                      heatmap_shape: List[int],
                      window_size: int,
                      stride: int,
                      mask: Optional[np.ndarray] = None) -> np.ndarray:
  """Calculate the mitosis density map.

  Args:
    detected_centroids: list of (row, column) centroids in the original
      heatmap's coordinates, e.g. the output of heatmap_to_list.
    heatmap_shape: original heatmap shape (row, column).
    window_size: side length of the (inclusive) counting window.
    stride: spacing between window top-left corners; together with
      heatmap_shape this determines the density-map size.
    mask: optional binary mask selecting where density is computed; cells
      outside the mask are left as NaN.

  Returns:
    2D numpy array of per-window centroid counts divided by the window area.
  """
  n_rows = int(heatmap_shape[0] // stride)
  n_cols = int(heatmap_shape[1] // stride)
  # Start from all-NaN so masked-out cells stay distinguishable from zero.
  density_map = np.full((n_rows, n_cols), np.nan)

  def _covered(centroid, row0, col0):
    # A centroid counts when it lies inside the window whose top-left corner
    # is (row0, col0); both window edges are inclusive.
    row, col = centroid[0], centroid[1]
    return (row0 <= row <= row0 + window_size and
            col0 <= col <= col0 + window_size)

  if mask is not None:
    # Downscale the mask to the density grid. PIL sizes are (width, height),
    # whereas our shapes are (row, column) = (height, width).
    resized = PIL.Image.fromarray(mask).resize(
        (n_cols, n_rows), PIL.Image.Resampling.NEAREST)
    mask = np.array(resized)

  for i in range(n_rows):
    for j in range(n_cols):
      if mask is not None and not mask[i][j]:
        continue  # Outside the region of interest: leave as NaN.
      hits = sum(
          1 for pt in detected_centroids if _covered(pt, i * stride, j * stride))
      density_map[i, j] = float(hits)
  return density_map / (window_size ** 2)
class Stage2FeaturesTest(unittest.TestCase):
  """Tests for the NP/TF and mitotic-count second-stage featurizers."""

  def test_np_tf_featurizer(self):
    """Grade areas are normalized separately inside and outside IC regions."""
    # This heatmap corresponds to:
    # [[np.nan, G1, G1, G2],
    #  [np.nan, G1, G2, G3]],
    # where G1..3 = Grade 1..3, and
    # [[np.nan, B, IC, IC],
    #  [np.nan, CIS, IC, IC]],
    # where B = benign, IC = invasive carcinoma, and CIS = carcinoma in situ.
    # pyformat: disable
    heatmap = np.array([
        [[np.nan, 1.0, 1.0, 0.0],
         [np.nan, 1.0, 0.0, 0.0]],  # Grade 1
        [[np.nan, 0.0, 0.0, 1.0],
         [np.nan, 0.0, 1.0, 0.0]],  # Grade 2
        [[np.nan, 0.0, 0.0, 0.0],
         [np.nan, 0.0, 0.0, 1.0]],  # Grade 3
    ])
    ic_heatmap = np.array([
        [[np.nan, 1.0, 0.1, 0.2],
         [np.nan, 0.1, 0.0, 0.0]],  # Benign
        [[np.nan, 0.0, 0.8, 0.7],
         [np.nan, 0.2, 0.8, 0.9]],  # Invasive Carcinoma
        [[np.nan, 0.0, 0.1, 0.1],
         [np.nan, 0.7, 0.2, 0.1]],  # Carcinoma In Situ
    ])
    # pyformat: enable
    tmap = {
        'heatmap': np.moveaxis(heatmap, 0, -1),
        'ic_heatmap': np.moveaxis(ic_heatmap, 0, -1)
    }
    # First three are area within IC, latter three are outside of IC.
    expected_feature = np.array([0.25, 0.5, 0.25, 1.0, 0.0, 0.0])

    actual_feature = stage2_features.np_tf_featurizer(tmap)

    np.testing.assert_allclose(
        actual_feature, expected_feature, atol=_absolute_tolerance)

  def test_np_tf_7d_heatmap(self):
    """A leading Grade-0 channel is ignored; only the last three count."""
    # This heatmap corresponds to:
    # [[np.nan, G0, G1, G2],
    #  [np.nan, G1, G2, G3]],
    # where G1..3 = NP/TF Grade 1..3, and
    # [[np.nan, B, IC, IC],
    #  [np.nan, CIS, IC, IC]],
    # where B = benign, IC = invasive carcinoma, and CIS = carcinoma in situ.
    # pyformat: disable
    heatmap = np.array([
        [[np.nan, 1.0, 0.0, 0.0],
         [np.nan, 0.0, 0.0, 0.0]],  # Grade 0
        [[np.nan, 0.0, 1.0, 0.0],
         [np.nan, 1.0, 0.0, 0.0]],  # Grade 1
        [[np.nan, 0.0, 0.0, 1.0],
         [np.nan, 0.0, 1.0, 0.0]],  # Grade 2
        [[np.nan, 0.0, 0.0, 0.0],
         [np.nan, 0.0, 0.0, 1.0]],  # Grade 3
    ])
    ic_heatmap = np.array([
        [[np.nan, 1.0, 0.0, 0.0],
         [np.nan, 0.1, 0.0, 0.0]],  # Benign
        [[np.nan, 0.0, 0.8, 0.7],
         [np.nan, 0.2, 0.8, 0.9]],  # Invasive Carcinoma
        [[np.nan, 0.0, 0.1, 0.1],
         [np.nan, 0.7, 0.2, 0.1]],  # Carcinoma In Situ
    ])
    # pyformat: enable
    tmap = {
        'heatmap': np.moveaxis(heatmap, 0, -1),
        'ic_heatmap': np.moveaxis(ic_heatmap, 0, -1)
    }

    # First three are area within IC, latter three are outside of IC.
    expected_feature = np.array([0.25, 0.5, 0.25, 1.0, 0.0, 0.0])

    actual_feature = stage2_features.np_tf_featurizer(tmap)

    np.testing.assert_allclose(actual_feature, expected_feature)

  def test_mc_featurizer(self):
    """MC features are percentiles of the density map, ignoring NaNs."""
    # 101 elements so that n-th percentile is exactly n.
    n_density_elem = 101
    # Two rows: one for actual density, and one for non-tissue area.
    density_map = np.zeros((2, n_density_elem))
    density_map[0, :] = np.arange(n_density_elem)
    # Add Nan for non-tissue area.
    density_map[1, :] = np.nan

    tmap = {'density': density_map}
    # Current feature set is 5th, 25th, 50th, 75th, and 95th percentile.
    expected_feature = np.array([5, 25, 50, 75, 95])

    actual_feature = stage2_features.mc_featurizer(tmap)

    np.testing.assert_allclose(
        actual_feature, expected_feature, atol=_absolute_tolerance)
15,4.0,-0.589428183817027,0.6012089997404763,-1.554253218039085,0.5272072739845781 18 | 16,5.0,-1.961983459972178,2.1499523091129653,-2.3897529098268255,2.254739679920992 19 | 17,6.0,-1.9921709631134101,9.836990687938835,-1.4410213794663218,4.515463768411997 20 | 18,7.0,-2.1438991852208438,17.874958674535083,-1.3301826838644324,6.1038387434865955 21 | 19,8.0,-3.5335418358676804,23.943962970210304,-2.7594495263107865,7.0690812111562344 22 | 20,9.0,-4.000396306396624,25.915228037495027,-2.3600528569798507,7.22960676442253 23 | 21,10.0,-3.61779035147574,25.547756467739042,-2.1746486821936366,6.893675935331102 24 | 22,11.0,-3.557164275330246,24.42047375632448,-2.0431056513831374,6.734783259911473 25 | 23,12.0,-3.3678939837733326,23.825663286539598,-2.0564174737730188,6.737406551167359 26 | 24,13.0,-3.2268605733378672,21.855828910711637,-1.253330104252067,6.297961267329288 27 | 25,14.0,-2.5838818802016235,22.09511348830345,-0.8689958900338232,5.966758126630202 28 | 26,15.0,-2.377813944565751,19.14805000871577,-0.4692932455788475,5.094415212185438 29 | 27,16.0,-2.790129527474489,18.712010485568165,-0.29910513205490435,5.259198662563764 30 | 28,17.0,-2.9867035227009824,16.747826987331397,0.22313021523915716,4.240082414132534 31 | 29,18.0,-3.4645498865565294,15.772952166757387,0.2822142566967247,4.052611639770624 32 | 30,19.0,-3.674874162598021,14.985139221797725,0.5578705379763211,3.7087623670233056 33 | 31,20.0,-4.529605509036173,14.428342316714392,-0.06949641968711044,3.6467891066268385 34 | 32,21.0,-5.119396411554447,14.34489717425584,-0.2141585505829173,3.5896471811114288 35 | 33,22.0,-6.128536650762111,13.82167345290742,-0.2778062021613056,3.5708822553282613 36 | 34,23.0,-6.850109920773781,14.433572610976004,-0.7124988125184096,3.628446686816532 37 | 35,24.0,-7.62730204574743,14.49955255020265,-0.8059169601856315,3.5056341873527894 38 | 36,25.0,-8.869744236676377,14.565953361856867,-1.5315137775333678,3.7810102045477905 39 | 
37,26.0,-9.302224078832264,15.089791341748814,-1.513054807991092,3.905943365982469 40 | 38,27.0,-10.188506761271064,15.466028719197286,-2.537164231053704,3.8375913535308803 41 | 39,28.0,-11.042466799310471,14.969150820837651,-2.609752840709701,3.860991931090473 42 | 40,29.0,-11.799089790332687,15.733905141996464,-2.9656814728253442,4.162668422805882 43 | 41,30.0,-12.79195496551104,16.138969699056172,-3.2460019211994457,4.343326594940516 44 | 42,31.0,-13.726840478005773,16.61050686183551,-4.101269521642114,4.334196147594346 45 | 43,32.0,-14.83809352100832,16.353472601274152,-4.76692867920695,4.662086084146944 46 | 44,33.0,-15.289326906038216,15.708916678834445,-5.348102432077485,4.896711844039379 47 | 45,34.0,-16.434272907990273,14.798187475891663,-6.565886412811004,4.747497932100949 48 | 46,35.0,-17.247978490760836,15.649593113685057,-7.421513001572993,5.394533050466975 49 | 47,36.0,-18.49950553299727,14.26671137261833,-8.287389027148166,5.19256092252507 50 | 48,37.0,-19.157452886864956,13.364667726011207,-8.95747621502568,5.3855002592560055 51 | 49,38.0,-19.880599280268044,14.293788903252699,-8.835443949291964,5.774586499455005 52 | -------------------------------------------------------------------------------- /health_acoustic_representations/api_utils_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from unittest import mock 4 | 5 | import google.auth.credentials 6 | from google.cloud.aiplatform.aiplatform import gapic 7 | import numpy as np 8 | 9 | import api_utils 10 | 11 | 12 | class TestMakePrediction(unittest.TestCase): 13 | 14 | def setUp(self): 15 | super().setUp() 16 | os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/tmp/fake_credentials.json" 17 | with open("/tmp/fake_credentials.json", "w") as f: 18 | # Fake file 19 | d = { 20 | "account": "", 21 | "client_id": "fergfggthyt-grht4thhrtyhy.apps.googleusercontent.com", 22 | "client_secret": "d-grteghrthy", 23 | "refresh_token": 
"1//freghthhyy-getrhythrythyr-egthhrtyhtrth", 24 | "type": "authorized_user", 25 | "universe_domain": "googleapis.com", 26 | } 27 | f.write(json.dumps(d)) 28 | 29 | @mock.patch.object(gapic.PredictionServiceClient, "predict") 30 | def test_raw_audio_endpoint_success(self, mock_predict): 31 | mock_predict.return_value = mock.MagicMock( 32 | predictions=[[0.1] * 512, [0.9] * 512] 33 | ) 34 | instances = np.random.rand(2, 32000) 35 | result = api_utils.make_prediction( 36 | api_utils.RAW_AUDIO_ENDPOINT_PATH, instances, 37 | ) 38 | self.assertEqual(result.shape, (2, 512)) 39 | mock_predict.assert_called_once_with( 40 | endpoint=api_utils.RAW_AUDIO_ENDPOINT_PATH, instances=instances.tolist() 41 | ) 42 | 43 | @mock.patch.object(gapic.PredictionServiceClient, "predict") 44 | def test_gcs_uri_endpoint_success(self, mock_predict): 45 | mock_predict.return_value = mock.MagicMock( 46 | predictions=[[0.1] * 512, [0.9] * 512] 47 | ) 48 | instances = ["gs://bucket/file1.wav", "gs://bucket/file2.wav"] 49 | gcs_bucket_name = "bucket" 50 | gcs_creds = mock.MagicMock(spec=google.auth.credentials.Credentials) 51 | gcs_creds.token = "mocked_token" 52 | 53 | result = api_utils.make_prediction( 54 | api_utils.GCS_URI_ENDPOINT_PATH, 55 | instances, 56 | gcs_bucket_name, 57 | gcs_creds, 58 | ) 59 | self.assertEqual(result.shape, (2, 512)) 60 | expected_instances = api_utils._get_prediction_instances( 61 | image_uris=instances, 62 | gcs_bucket_name=gcs_bucket_name, 63 | gcs_creds=gcs_creds, 64 | ) 65 | mock_predict.assert_called_once_with( 66 | endpoint=api_utils.GCS_URI_ENDPOINT_PATH, instances=expected_instances 67 | ) 68 | 69 | def test_raw_audio_endpoint_invalid_instances_type(self): 70 | instances = ["invalid", "instances"] 71 | with self.assertRaisesRegex(ValueError, "must be a numpy array"): 72 | api_utils.make_prediction( 73 | api_utils.RAW_AUDIO_ENDPOINT_PATH, 74 | instances, 75 | ) 76 | 77 | def test_raw_audio_endpoint_invalid_instances_shape(self): 78 | 79 | instances = 
np.random.rand(2, 1000) 80 | with self.assertRaisesRegex(ValueError, "must be a numpy array of shape"): 81 | api_utils.make_prediction( 82 | endpoint_path=api_utils.RAW_AUDIO_ENDPOINT_PATH, 83 | instances=instances, 84 | ) 85 | 86 | def test_gcs_uri_endpoint_invalid_instances_type(self): 87 | instances = np.random.rand(2, 32000) 88 | with self.assertRaisesRegex(ValueError, "must be a list of strings"): 89 | api_utils.make_prediction( 90 | endpoint_path=api_utils.GCS_URI_ENDPOINT_PATH, 91 | instances=instances, 92 | gcs_bucket_name="bucket", 93 | gcs_creds=mock.MagicMock(), 94 | ) 95 | 96 | def test_gcs_uri_endpoint_missing_bucket_name(self): 97 | instances = ["gs://bucket/file.wav"] 98 | with self.assertRaisesRegex( 99 | ValueError, "`gcs_bucket_name` must be specified" 100 | ): 101 | api_utils.make_prediction( 102 | endpoint_path=api_utils.GCS_URI_ENDPOINT_PATH, 103 | instances=instances, 104 | gcs_creds=mock.MagicMock(), 105 | ) 106 | 107 | def test_gcs_uri_endpoint_missing_credentials(self): 108 | instances = ["gs://bucket/file.wav"] 109 | with self.assertRaisesRegex(ValueError, "`gcs_creds` must be specified"): 110 | api_utils.make_prediction( 111 | endpoint_path=api_utils.GCS_URI_ENDPOINT_PATH, 112 | instances=instances, 113 | gcs_bucket_name="bucket", 114 | ) 115 | 116 | def test_invalid_endpoint_path(self): 117 | instances = np.random.rand(2, 32000) 118 | with self.assertRaisesRegex( 119 | ValueError, "Endpoint invalid_endpoint is not recognized." 
120 | ): 121 | api_utils.make_prediction( 122 | endpoint_path="invalid_endpoint", 123 | instances=instances, 124 | ) 125 | 126 | 127 | if __name__ == "__main__": 128 | unittest.main() 129 | -------------------------------------------------------------------------------- /fitbit_pregnancy/data/figure_4_f_data.csv: -------------------------------------------------------------------------------- 1 | ,week,percentile_10,percentile_25,percentile_50,percentile_75,percentile_90 2 | 0,-11.0,85.86462294072591,87.02093103114689,88.23953210087937,89.4229370456689,90.37313732151198 3 | 1,-10.0,85.7598071693701,86.97932089248307,88.28239174179572,89.39990019569987,90.4303859051916 4 | 2,-9.0,85.79190910377362,86.94487789341204,88.25119101148849,89.39811780024749,90.37694095715804 5 | 3,-8.0,85.82649713808158,87.02364127821438,88.24112655682305,89.38869780683265,90.44470839700281 6 | 4,-7.0,85.82090100276135,87.0174915491607,88.22833296408118,89.35306871217045,90.35913581368673 7 | 5,-6.0,85.94009562069539,87.13365504764444,88.2664042909397,89.42558514148807,90.388900766523 8 | 6,-5.0,85.954196128751,87.08007293942963,88.28541574219267,89.45485886756518,90.406480155753 9 | 7,-4.0,85.9396152259889,87.05579188053429,88.23594540220017,89.38411065768891,90.42246682318569 10 | 8,-3.0,85.89534865424373,87.07443159206069,88.31904398271647,89.47200031910273,90.45292666381407 11 | 9,-2.0,85.9658093025438,87.02104015409748,88.33178533505162,89.4851828058098,90.40762790636005 12 | 10,-1.0,86.00972643141398,87.11212016000708,88.27139726473503,89.41723533322192,90.36166193005097 13 | 11,0.0,85.90562772804844,87.0765844441069,88.3042038430088,89.45222895918737,90.40644183742987 14 | 12,1.0,85.97603558223615,87.0827699682741,88.29092894738473,89.45664505583517,90.429777629565 15 | 13,2.0,86.03032783643359,87.16763175620093,88.36375401085388,89.54266846983543,90.51155103082735 16 | 14,3.0,86.02224182197698,87.19552869926012,88.34952963753037,89.52881094839267,90.44439526225808 17 | 
15,4.0,85.70335359188176,86.90971733381637,88.08091903468974,89.30392347834398,90.34044289980363 18 | 16,5.0,85.38004001754469,86.55102091236904,87.80550644443973,89.00954911754924,89.9492151865888 19 | 17,6.0,85.17969436705579,86.39088577825022,87.64089998583188,88.82445135019398,89.86779770154824 20 | 18,7.0,85.07776892776134,86.34793029704072,87.53891150552711,88.71023094515392,89.76363542602861 21 | 19,8.0,85.08557080050764,86.19570518564953,87.43930574423817,88.625501535571,89.6480777964973 22 | 20,9.0,85.07605627850752,86.18663636697138,87.43692291703971,88.69063404439277,89.69613646732105 23 | 21,10.0,85.16359380154029,86.30093421324474,87.52633706234536,88.68960713302643,89.76385521509906 24 | 22,11.0,85.19080227471966,86.3624740094169,87.546549849134,88.68879445679629,89.71097238354697 25 | 23,12.0,85.1053262814159,86.28186245348022,87.50268914360157,88.70014584977031,89.73978177534141 26 | 24,13.0,85.21390312248398,86.32254343519973,87.57056026133996,88.7568214027562,89.76452049906827 27 | 25,14.0,85.2757872277831,86.47229896949975,87.64517524179156,88.80024490641685,89.78103715288725 28 | 26,15.0,85.34887635011961,86.53365867972595,87.7014100928566,88.89807090471433,90.02994035517558 29 | 27,16.0,85.39260008016673,86.47627542549992,87.71228372282538,88.90594366278788,89.85551091093264 30 | 28,17.0,85.54054342199537,86.62207140571763,87.81902700451592,88.99222244084548,90.01760518906664 31 | 29,18.0,85.49066689287336,86.64849655544643,87.83521965343702,88.98441448279227,89.99118201557978 32 | 30,19.0,85.5642853791832,86.6503977163893,87.87201798105417,89.02761803638924,90.00429331140703 33 | 31,20.0,85.38269134810277,86.58326112592873,87.85973374893032,89.02548553066413,90.03070972981332 34 | 32,21.0,85.49791143308538,86.51901165651839,87.78902163203642,89.0331746787825,89.99376854758657 35 | 33,22.0,85.38504996487367,86.53713474127856,87.74221031100318,89.00252651965644,90.04373167129337 36 | 
34,23.0,85.3715099197865,86.61574155186993,87.79061517733352,88.92521408166681,89.9441892404094 37 | 35,24.0,85.30911367680677,86.54661755887548,87.78307385312873,88.96815930644225,90.03309903475461 38 | 36,25.0,85.19733698612423,86.44677411308587,87.69407673653438,88.8722846946282,89.93216256128733 39 | 37,26.0,85.36436478952368,86.4742567977213,87.61735595657802,88.79545745409062,89.82779027489751 40 | 38,27.0,85.20566298062359,86.38189879362156,87.61603733409002,88.82150802452203,89.83073088461487 41 | 39,28.0,85.16680463056164,86.27271259275872,87.59811455756123,88.75770579130204,89.79687194108274 42 | 40,29.0,85.0973560100317,86.32535754136919,87.51813958690882,88.6995684389303,89.78006017153196 43 | 41,30.0,84.99747795046164,86.22069192189424,87.4923236785173,88.68083677487593,89.73081554881921 44 | 42,31.0,84.93009525687465,86.17678380067417,87.43936641258657,88.64399829605088,89.65379778678928 45 | 43,32.0,84.88035800906245,86.06317166927329,87.34076079937094,88.53266020162694,89.58462373057246 46 | 44,33.0,84.75184864078483,85.88234443857476,87.26600695315936,88.44704012608366,89.54659937645116 47 | 45,34.0,84.7350211986714,85.90806495468973,87.16052224812398,88.3775677088378,89.46199223847911 48 | 46,35.0,84.50255130891017,85.7729188138898,87.04883146268041,88.28717070251142,89.34228215997835 49 | 47,36.0,84.42201447872321,85.65235607434553,87.02002411929512,88.26154257533219,89.35198863050708 50 | 48,37.0,84.32607110192639,85.63226373713577,86.90217783895545,88.09559729888124,89.21587974107281 51 | 49,38.0,84.2226582606012,85.51809610464485,86.83149698558069,88.12600987526244,89.14817226008694 52 | -------------------------------------------------------------------------------- /fitbit_pregnancy/data/figure_3_data.csv: -------------------------------------------------------------------------------- 1 | day,mean,std 2 | -83.0,-0.16769814857886697,2.945680856842309 3 | -82.0,-0.19747108710441513,2.902543855295249 4 | 
-81.0,-0.2487690781382201,2.9272746375600955 5 | -80.0,-0.2576812914249982,2.937955506157665 6 | -79.0,-0.32756284141830605,2.8917429284240894 7 | -78.0,-0.3184432301961573,2.878865459427959 8 | -77.0,-0.3431468986089724,2.796863379936013 9 | -76.0,-0.2830637817954389,2.747434004935879 10 | -75.0,-0.24984857476265987,2.7989743033831087 11 | -74.0,-0.21879158716922553,2.738513880549523 12 | -73.0,-0.16281890267666993,2.7098325828839505 13 | -72.0,-0.12733111063830205,2.7494831741885064 14 | -71.0,-0.0504820135105832,2.7288543770979268 15 | -70.0,0.017508572757704094,2.660707838213464 16 | -69.0,0.06791609865875295,2.620676893378685 17 | -68.0,0.17099110199018802,2.662388772428467 18 | -67.0,0.19360176815422356,2.607286831503329 19 | -66.0,0.20562611612585657,2.6244537045309464 20 | -65.0,0.2655974466146493,2.55358068271235 21 | -64.0,0.2686135887446077,2.5443611190895257 22 | -63.0,0.3298601003099881,2.57064316427218 23 | -62.0,0.2766881459530212,2.5437401612887953 24 | -61.0,0.18890014077775918,2.506946738973213 25 | -60.0,0.1266504354751014,2.54860811524235 26 | -59.0,0.009379430520757197,2.6125234503206913 27 | -58.0,-0.10567351181832688,2.59250241099159 28 | -57.0,-0.17304072675095467,2.5926958027799767 29 | -56.0,-0.2513337178865767,2.622674281964326 30 | -55.0,-0.3383866479613834,2.6220118872991325 31 | -54.0,-0.38943428573228017,2.6291867963125384 32 | -53.0,-0.44406133726656144,2.647626944139557 33 | -52.0,-0.5145050789528759,2.596210720259021 34 | -51.0,-0.5152574287222611,2.6053966188866 35 | -50.0,-0.5147823961233489,2.563975989596137 36 | -49.0,-0.5183606127818892,2.573099540618965 37 | -48.0,-0.4193385947013469,2.5571598971802194 38 | -47.0,-0.3028656668096662,2.569726432017165 39 | -46.0,-0.2386841768792865,2.5601845008986066 40 | -45.0,-0.11197215394751855,2.558576379977996 41 | -44.0,0.004955264059418171,2.5370923499793436 42 | -43.0,0.06942233589788026,2.5155849071366587 43 | -42.0,0.12323378717607746,2.4559997817323627 44 | 
-41.0,0.1910124641789299,2.4373295322908897 45 | -40.0,0.27335660713583504,2.508006064309953 46 | -39.0,0.3825029833509234,2.4898583704461177 47 | -38.0,0.43148408144262,2.5037843158190505 48 | -37.0,0.48544623315965557,2.561038863757535 49 | -36.0,0.5127208807828619,2.590418968096094 50 | -35.0,0.456853331782868,2.557694333724 51 | -34.0,0.42773066020420514,2.580775218815133 52 | -33.0,0.4184667200548852,2.6108073172366666 53 | -32.0,0.33079562476149205,2.6407177137835935 54 | -31.0,0.20996830850539078,2.647825114801626 55 | -30.0,0.014730082252183922,2.660669047220425 56 | -29.0,-0.13649808468380595,2.665501267810027 57 | -28.0,-0.24839142172879294,2.6900934022965806 58 | -27.0,-0.38867122042117697,2.694805703820524 59 | -26.0,-0.512982126799971,2.6632742559146605 60 | -25.0,-0.6691103457034276,2.6260432708047916 61 | -24.0,-0.7961777547294447,2.569454362421822 62 | -23.0,-0.8674974490937355,2.5695615119863615 63 | -22.0,-0.8404837844150919,2.5300708571310437 64 | -21.0,-0.7883830639513417,2.5484949771484553 65 | -20.0,-0.6786396629970394,2.5855099850127474 66 | -19.0,-0.5125993068566362,2.549118903553633 67 | -18.0,-0.3222802615830299,2.5757575860489084 68 | -17.0,-0.20889436927077223,2.6480935111616035 69 | -16.0,-0.052349931609791116,2.6682125156364123 70 | -15.0,0.08978083160679635,2.644385309451993 71 | -14.0,0.28142731083366046,2.650745306161568 72 | -13.0,0.4326753454895309,2.6820969977877733 73 | -12.0,0.5809549433672659,2.6415811721500018 74 | -11.0,0.7110608674091989,2.675745144855068 75 | -10.0,0.7951592829490894,2.7153284340333146 76 | -9.0,0.8287167177426756,2.7106247292338628 77 | -8.0,0.8976483616934918,2.738624782807716 78 | -7.0,0.8885969471758192,2.809580081153524 79 | -6.0,0.8818318080640944,2.9050597043211233 80 | -5.0,0.8298551295202303,2.944665232327594 81 | -4.0,0.6174321887911762,2.974465569704749 82 | -3.0,0.4223725533468706,2.9964980307338225 83 | -2.0,0.19459891765634582,3.0038471010198813 84 | 
-1.0,-0.10517404681554349,3.018816611604175 85 | 0.0,-0.3831039379956871,3.0452831834140883 86 | 1.0,-0.6050434202233175,3.079964507774967 87 | 2.0,-0.8845117346852202,3.0581668080120252 88 | 3.0,-1.0905690258225837,3.013324807481988 89 | 4.0,-1.2367076160314323,2.989340030303619 90 | 5.0,-1.2987994805838832,3.01213283284721 91 | 6.0,-1.2625267398585265,3.0304840369338586 92 | 7.0,-1.1829487426629721,3.043485696394318 93 | 8.0,-1.0589937365611568,3.098800599507673 94 | 9.0,-0.8232614717045557,3.104427353899736 95 | 10.0,-0.5993594495623801,3.1176048382736536 96 | 11.0,-0.42687369973793093,3.1337365264035526 97 | 12.0,-0.18301142290681893,3.1199244521284872 98 | 13.0,0.04035125570756414,3.1322328327704914 99 | 14.0,0.29203650429184685,3.1051939235202273 100 | 15.0,0.4949567859704329,3.1879712188685776 101 | 16.0,0.7367156746341431,3.21017588545266 102 | 17.0,0.905751239200959,3.2253366402603914 103 | 18.0,1.1107181439467213,3.2238832658567627 104 | 19.0,1.248670367635808,3.254520311972774 105 | 20.0,1.4222561830570528,3.1994524548169836 106 | -------------------------------------------------------------------------------- /colorectal-survival/analysis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Google Inc. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions 6 | # are met: 7 | # 8 | # 1. Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # 11 | # 2. Redistributions in binary form must reproduce the above copyright 12 | # notice, this list of conditions and the following disclaimer in the 13 | # documentation and/or other materials provided with the distribution. 14 | # 15 | # 3. Neither the name of Google Inc. 
nor the names of its 16 | # contributors may be used to endorse or promote products derived from this 17 | # software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 23 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 24 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 25 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 26 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 29 | # POSSIBILITY OF SUCH DAMAGE. 30 | """Statistical analysis.""" 31 | 32 | import lifelines 33 | import matplotlib.pyplot as plt 34 | import numpy as np 35 | import pandas as pd 36 | import sklearn.metrics 37 | 38 | TIME = 'time' 39 | OBSERVED = 'observed' 40 | RISK_SCORE = 'risk_score' 41 | 42 | 43 | def plot_km_curve(df_tune, df_test): 44 | """Returns KM curves for each risk group for `df_test`. 45 | 46 | Risk groups are defined via thresholds computed on `df_tune`. 47 | 48 | Args: 49 | df_tune: a pd.DataFrame of tune set data. 50 | df_test: a pd.DataFrame of test set data. 
51 | """ 52 | # Compute risk groups 53 | df_test['risk_group'] = discretize(df_tune[RISK_SCORE], df_test[RISK_SCORE]) 54 | 55 | # Plot KM curves per risk group 56 | fig, ax = plt.subplots() 57 | groups = ['Low Risk', 'Medium Risk', 'High Risk'] 58 | kmfs = [] 59 | for group in groups: 60 | kmf = lifelines.KaplanMeierFitter() 61 | df_group = df_test.query(f"risk_group=='{group}'") 62 | if df_group.empty: 63 | continue 64 | kmf.fit(df_group[TIME], event_observed=df_group[OBSERVED], label=group) 65 | kmf.plot(ax=ax) 66 | kmfs.append(kmf) 67 | lifelines.plotting.add_at_risk_counts(*kmfs, ax=ax) 68 | return fig 69 | 70 | 71 | def discretize(risk_scores_tune, risk_scores_test): 72 | """Discretize `risk_scores_test` based on thresholds from `risk_scores_tune`. 73 | 74 | Args: 75 | risk_scores_tune: np.ndarray of continuous risk scores. 76 | risk_scores_test: np.ndarray of continuous risk scores. 77 | 78 | Returns: 79 | an np.ndarray of disretized test set risk scores. 80 | """ 81 | thresholds_valid = np.percentile(risk_scores_tune, [25, 75]) 82 | risk_groups_test = np.digitize(risk_scores_test, bins=thresholds_valid) 83 | risk_groups_test = pd.Series(risk_groups_test) 84 | risk_group_map = {0: 'Low Risk', 1: 'Medium Risk', 2: 'High Risk'} 85 | risk_groups_test = risk_groups_test.apply(lambda x: risk_group_map[x]) 86 | return risk_groups_test 87 | 88 | 89 | def c_index(df): 90 | return lifelines.utils.concordance_index(df[TIME], df[RISK_SCORE], 91 | df[OBSERVED]) 92 | 93 | 94 | def survival_auc(df, threshold): 95 | """Survival AUC.""" 96 | df_binarized = binarize_time(df, threshold) 97 | return sklearn.metrics.roc_auc_score(df_binarized[TIME], 98 | df_binarized[RISK_SCORE]) 99 | 100 | 101 | def binarize_time(df, threshold): 102 | """Binarize time based on threshold. 103 | 104 | If time > threshold: `observed` and `time` columns are set to 1. 105 | If time <= threshold: unobserved examples are dropped and `time` is set to 0. 
106 | 107 | Args: 108 | df: pd.DataFrame containing `time` and `observed` columns. 109 | threshold: the time threshold on which to binarize. 110 | 111 | Returns: 112 | a pd.Dataframe where time has been discretized. 113 | """ 114 | 115 | def update_observed(row): 116 | if row[TIME] > threshold: 117 | return 1 118 | return row[OBSERVED] 119 | 120 | df[OBSERVED] = df.apply(update_observed, axis=1) 121 | df[TIME] = (df[TIME] > threshold).astype(int) 122 | 123 | # Remove censored examples below threshold. These examples cannot be 124 | # compared to any others. 125 | df = df.query('time != 0 or observed != 0') 126 | return df 127 | 128 | 129 | def get_hazard_ratios(df_test): 130 | cph = lifelines.CoxPHFitter() 131 | cph.fit(df_test, duration_col=TIME, event_col=OBSERVED) 132 | return cph.summary 133 | -------------------------------------------------------------------------------- /health_acoustic_representations/eval_utils_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import linear_model 3 | from sklearn import metrics 4 | 5 | import unittest 6 | import eval_utils 7 | 8 | 9 | class TestEvalUtils(unittest.TestCase): 10 | 11 | def setUp(self): 12 | super().setUp() 13 | np.random.seed(42) 14 | self.x = np.array([[1, 2], [3, 4]]) 15 | self.y_reg = np.array([2, 4]) 16 | self.y_cls = np.array([0, 1]) 17 | 18 | def test_create_linear_probe(self): 19 | probe = eval_utils.create_linear_probe(0.1, is_regression=True) 20 | with self.subTest(name='check_model_type'): 21 | self.assertIsInstance(probe, linear_model.Ridge) 22 | with self.subTest(name='check_regularization_parameter'): 23 | self.assertEqual(probe.alpha, 0.1) 24 | 25 | probe = eval_utils.create_linear_probe( 26 | 0.1, is_regression=False, use_sgd_classifier=True 27 | ) 28 | with self.subTest(name='check_model_type'): 29 | self.assertIsInstance(probe, linear_model.SGDClassifier) 30 | with 
self.subTest(name='check_regularization_parameter'): 31 | self.assertEqual(probe.alpha, 0.1) 32 | with self.subTest(name='check_loss_type'): 33 | self.assertEqual(probe.loss, 'log') 34 | with self.subTest(name='check_class_weight'): 35 | self.assertEqual(probe.class_weight, 'balanced') 36 | 37 | probe = eval_utils.create_linear_probe( 38 | 0.1, is_regression=False, use_sgd_classifier=False 39 | ) 40 | with self.subTest(name='check_model_type'): 41 | self.assertIsInstance(probe, linear_model.LogisticRegression) 42 | with self.subTest(name='check_regularization_parameter'): 43 | self.assertEqual(probe.C, 0.1) 44 | with self.subTest(name='check_class_weight'): 45 | self.assertEqual(probe.class_weight, 'balanced') 46 | 47 | def test_predict_with_probe(self): 48 | probe = linear_model.Ridge().fit(self.x, self.y_reg) 49 | predictions = eval_utils.predict_with_probe( 50 | probe, self.x, is_regression=True 51 | ) 52 | self.assertEqual(predictions.shape, (2,)) 53 | 54 | probe = linear_model.LogisticRegression().fit(self.x, self.y_cls) 55 | predictions = eval_utils.predict_with_probe( 56 | probe, self.x, is_regression=False 57 | ) 58 | with self.subTest(name='check_predictions_shape'): 59 | self.assertEqual(predictions.shape, (2,)) 60 | with self.subTest(name='check_value_in_0_1_range'): 61 | self.assertTrue(all(0 <= p <= 1 for p in predictions)) 62 | 63 | def test_find_reg_coef_with_best_metric(self): 64 | cv_scores = {0.1: 0.8, 0.01: 0.9, 1.0: 0.7} 65 | 66 | best_alpha = eval_utils.find_reg_coef_with_best_metric( 67 | cv_scores, lower_is_better=True 68 | ) 69 | with self.subTest(name='check_alpha_is_highest'): 70 | self.assertEqual(best_alpha, 1.0) 71 | 72 | best_alpha = eval_utils.find_reg_coef_with_best_metric( 73 | cv_scores, lower_is_better=False 74 | ) 75 | with self.subTest(name='check_alpha_is_lowest'): 76 | self.assertEqual(best_alpha, 0.01) 77 | 78 | def test_compute_metrics_for_probe(self): 79 | y_true = np.array([1.0, 2.0, 3.0]) 80 | y_pred = np.array([1.1, 2.1, 
2.9]) 81 | mae = eval_utils.compute_metrics_for_probe( 82 | y_true, y_pred, is_regression=True 83 | ) 84 | self.assertIsInstance(mae, float) 85 | self.assertAlmostEqual(mae, 0.1) 86 | 87 | y_true = np.array([0, 1, 1, 0]) 88 | y_score = np.array([0.1, 0.8, 0.7, 0.2]) 89 | auc = eval_utils.compute_metrics_for_probe( 90 | y_true, y_score, is_regression=False 91 | ) 92 | with self.subTest(name='check_auc_type'): 93 | self.assertIsInstance(auc, float) 94 | with self.subTest(name='check_auc_value'): 95 | self.assertAlmostEqual(auc, metrics.roc_auc_score(y_true, y_score)) 96 | 97 | def test_train_linear_probe_with_participant_level_crossval(self): 98 | n_samples = 100 99 | n_features = 10 100 | n_participants = 20 101 | 102 | features = np.random.randn(n_samples, n_features) 103 | participant_ids = np.repeat( 104 | range(n_participants), n_samples // n_participants 105 | ) 106 | 107 | labels_reg = np.random.randn(n_samples) 108 | probe_reg = eval_utils.train_linear_probe_with_participant_level_crossval( 109 | features, labels_reg, participant_ids, is_regression=True 110 | ) 111 | with self.subTest(name='check_model_type_ridge'): 112 | self.assertIsInstance(probe_reg, linear_model.Ridge) 113 | 114 | labels_cls = np.random.randint(0, 2, n_samples) 115 | probe_cls = eval_utils.train_linear_probe_with_participant_level_crossval( 116 | features, labels_cls, participant_ids, is_regression=False 117 | ) 118 | with self.subTest(name='check_model_type_logistic_regression'): 119 | self.assertIsInstance( 120 | probe_cls, 121 | (linear_model.SGDClassifier, linear_model.LogisticRegression), 122 | ) 123 | 124 | def test_input_validation(self): 125 | with self.assertRaises(AssertionError): 126 | features = np.random.randn(10, 5) 127 | labels = np.random.randn(9) 128 | participant_ids = np.arange(10) 129 | eval_utils.train_linear_probe_with_participant_level_crossval( 130 | features, labels, participant_ids, is_regression=True 131 | ) 132 | 133 | 134 | if __name__ == '__main__': 135 | 
unittest.main() 136 | -------------------------------------------------------------------------------- /colorectal-survival/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Google Inc. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions 6 | # are met: 7 | # 8 | # 1. Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # 11 | # 2. Redistributions in binary form must reproduce the above copyright 12 | # notice, this list of conditions and the following disclaimer in the 13 | # documentation and/or other materials provided with the distribution. 14 | # 15 | # 3. Neither the name of Google Inc. nor the names of its 16 | # contributors may be used to endorse or promote products derived from this 17 | # software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 23 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 24 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 25 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 26 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 29 | # POSSIBILITY OF SUCH DAMAGE. 
30 | """Survival Network.""" 31 | 32 | import tensorflow as tf 33 | 34 | 35 | def cox_partial_likelihood(event_times, censored, preds): 36 | """Returns negative log of cox partial liklihood. 37 | 38 | This implementation uses Breslow's approximation for handling ties. For 39 | details on Breslow's method, see page 144 in 40 | https://www4.stat.ncsu.edu/~dzhang2/st745/chap7.pdf. Note that we calculate 41 | the loss with respect to the negative of `preds` such that preds are 42 | positively correlated with event times. 43 | 44 | Args: 45 | event_times: ground-truth event times. Tensor of shape [batch_size]. 46 | censored: mask indicating whether the example is censored. Tensor of shape 47 | [batch_size]. 48 | preds: predicted event times. Tensor of shape [batch_size]. 49 | 50 | """ 51 | mask = get_risk_set(event_times, ignore_ties=False) 52 | preds = shift_preds(preds) 53 | loss = preds - logsumexp_masked(tile_rows(preds), mask) 54 | observed = tf.cast(tf.logical_not(censored), loss.dtype) 55 | loss = tf.reduce_sum(loss * observed) 56 | loss = tf.math.divide_no_nan(loss, tf.reduce_sum(observed)) 57 | loss = -loss # we minimize the negative liklihood 58 | return loss 59 | 60 | 61 | def get_risk_set(event_times, ignore_ties=False): 62 | """Returns a matrix where row i indicates the risk set for example i. 63 | 64 | If ignore_ties=True: 65 | m[i, j] == 1 iff j > i 66 | If ignore_ties=False: 67 | m[i, j] == 1 iff j >= i 68 | 69 | Args: 70 | event_times: 1D tenseor of event times 71 | ignore_ties: if False, comparable pairs can have tied event times. 72 | """ 73 | m1 = tile_rows(event_times) 74 | m2 = tile_columns(event_times) 75 | if ignore_ties: 76 | return tf.greater(m1, m2) 77 | else: 78 | return tf.greater_equal(m1, m2) 79 | 80 | 81 | def tile_rows(a): 82 | """Returns a matrix where each row is equal to `a`. 83 | 84 | Example: 85 | a = [2, 1, 3] 86 | 87 | m = [[2, 1, 3], 88 | [2, 1, 3], 89 | [2, 1, 3]] 90 | 91 | Args: 92 | a: 1D tensor. 
93 | """ 94 | n = tf.shape(a)[0] 95 | return tf.tile(tf.expand_dims(a, axis=0), (n, 1)) 96 | 97 | 98 | def tile_columns(a): 99 | """Returns a matrix where each column is equal to `a`. 100 | 101 | Example: 102 | a = [2, 1, 3] 103 | 104 | m = [[2, 2, 2], 105 | [1, 1, 1], 106 | [3, 3, 3]] 107 | 108 | Args: 109 | a: 1D tensor. 110 | """ 111 | return tf.transpose(tile_rows(a)) 112 | 113 | 114 | def logsumexp_masked(a, mask): 115 | """Returns row-wise masked log sum exp of a. 116 | 117 | Uses the following trick for numeric stability: 118 | log(sum(exp(x))) == log(sum(exp(x - max(x)))) + max(x) 119 | 120 | Args: 121 | a: 2D tensor. 122 | mask: 2D tensor. 123 | """ 124 | mask = tf.cast(mask, a.dtype) 125 | a_max = tf.math.reduce_max(a * mask, axis=1, keepdims=True) 126 | a = a - a_max 127 | a_exp = tf.math.exp(a) 128 | a_sum_exp = tf.math.reduce_sum(a_exp * mask, axis=1, keepdims=True) 129 | return tf.squeeze(tf.math.log(a_sum_exp) + a_max) 130 | 131 | 132 | def shift_preds(preds): 133 | """Returns uniformly shift preds so minimum is at 0 to avoid underflow. 134 | 135 | Args: 136 | preds: Tensor of shape [batch_size]. 137 | """ 138 | preds_min = tf.reduce_min(preds) 139 | shift = tf.where(preds_min < 0, -preds_min, 0) 140 | return preds + shift 141 | 142 | 143 | def keras_cox_partial_likelihood(y_true, y_pred): 144 | """Keras friendly wrapper for cox_partial_likelihood.""" 145 | event_times = tf.squeeze(tf.gather(y_true, [0], axis=1)) 146 | censored = tf.cast(tf.squeeze(tf.gather(y_true, [1], axis=1)), tf.bool) 147 | preds = tf.squeeze(y_pred, axis=1) 148 | return cox_partial_likelihood(event_times, censored, preds) 149 | -------------------------------------------------------------------------------- /colorectal-survival/network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Google Inc. 2 | # All rights reserved. 
3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions 6 | # are met: 7 | # 8 | # 1. Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # 11 | # 2. Redistributions in binary form must reproduce the above copyright 12 | # notice, this list of conditions and the following disclaimer in the 13 | # documentation and/or other materials provided with the distribution. 14 | # 15 | # 3. Neither the name of Google Inc. nor the names of its 16 | # contributors may be used to endorse or promote products derived from this 17 | # software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 23 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 24 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 25 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 26 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 29 | # POSSIBILITY OF SUCH DAMAGE. 
30 | """Network.""" 31 | 32 | import tensorflow as tf 33 | from tensorflow.keras.layers import BatchNormalization 34 | from tensorflow.keras.layers import Conv2D 35 | from tensorflow.keras.layers import Dense 36 | from tensorflow.keras.layers import GlobalAveragePooling1D 37 | from tensorflow.keras.layers import GlobalAveragePooling2D 38 | from tensorflow.keras.layers import SeparableConv2D 39 | from tensorflow.keras.layers import TimeDistributed 40 | 41 | 42 | def build_network(input_shape, 43 | base_depth=16, 44 | depth_growth=1.25, 45 | stride_2_layers=4, 46 | stride_1_layers=1, 47 | kernel_size=3): 48 | """Returns deep learning network for survival prediction. 49 | 50 | The network takes a set of image patches as input. A feature vector 51 | is extracted from each image patch using a CNN module. Feature vectors are 52 | averaged before being fed into a dense layer with a single output. 53 | 54 | Args: 55 | input_shape: 4D input shape [sequence_length, height, width, depth]. 56 | base_depth: the number of filters in the base Conv2D layer. 57 | depth_growth: the rate at which the number of channels in the feature map 58 | grows after each layer with stride 2. 59 | stride_2_layers: the number of SeparableConv2D layers with stride 2. 60 | stride_1_layers: the number of SeparableConv2D layers with stride 1 61 | between each SeparableConv2D with stride 2 layer. 
62 | kernel_size: integer specifying the height and width of the 2D convolution 63 | window 64 | """ 65 | 66 | if len(input_shape) != 4: 67 | raise ValueError('Expecting 4D input') 68 | 69 | model = tf.keras.Sequential() 70 | 71 | # Create CNN for image patch feature extraction 72 | cnn = build_cnn(input_shape[1:], base_depth, depth_growth, stride_2_layers, 73 | stride_1_layers, kernel_size) 74 | 75 | # Run CNN on each image patch 76 | model.add(TimeDistributed(cnn, input_shape=input_shape)) 77 | 78 | # Compute average of image patch features per example 79 | model.add(GlobalAveragePooling1D()) 80 | 81 | # Compute risk scores 82 | model.add(Dense(1)) 83 | 84 | return model 85 | 86 | 87 | def build_cnn(input_shape, base_depth, depth_growth, stride_2_layers, 88 | stride_1_layers, kernel_size): 89 | """Returns CNN for extracting image patch features. 90 | 91 | Args: 92 | input_shape: 3D shape of input image patches. 93 | base_depth: the number of filters in the base Conv2D layer. 94 | depth_growth: the rate at which the number of channels in the feature map 95 | grows after each layer with stride 2. 96 | stride_2_layers: the number of SeparableConv2D layers with stride 2. 97 | stride_1_layers: the number of SeparableConv2D layers with stride 1 98 | between each SeparableConv2D with stride 2 layer. 
99 | kernel_size: integer specifying the height and width of the 2D convolution 100 | window 101 | """ 102 | model = tf.keras.Sequential() 103 | 104 | # Configure base layer (stride equals kernel size: non-overlapping patches) 105 | base_layer = Conv2D( 106 | base_depth, 107 | kernel_size, 108 | strides=kernel_size, 109 | activation='relu', 110 | padding='same', 111 | input_shape=input_shape) 112 | model.add(base_layer) 113 | model.add(BatchNormalization()) 114 | 115 | # Depthwise separable convolution sequence 116 | for i in range(stride_2_layers): 117 | depth = int(base_depth * depth_growth**i) 118 | model.add( 119 | SeparableConv2D( 120 | depth, kernel_size, strides=2, activation='relu', padding='same')) 121 | model.add(BatchNormalization()) 122 | for _ in range(stride_1_layers): 123 | model.add( 124 | SeparableConv2D( 125 | depth, kernel_size, strides=1, activation='relu', padding='same')) 126 | model.add(BatchNormalization()) 127 | 128 | # Spatial Pooling 129 | model.add(GlobalAveragePooling2D()) 130 | 131 | return model 132 | -------------------------------------------------------------------------------- /ct_dicom/example_builder.py: -------------------------------------------------------------------------------- 1 | """Example Preparation Routines based on Pydicom for running CT models.""" 2 | 3 | import datetime 4 | from typing import Sequence 5 | 6 | import numpy as np 7 | import pydicom 8 | import tensorflow as tf 9 | 10 | import dicom_utils 11 | import image_utils 12 | 13 | # The minimum encoded pixel data in Hounsfield units. Anything lower is clipped 14 | # to this value. 15 | MIN_HU = -1024 16 | 17 | 18 | # TODO(b/339471206): Add regression test for `create_ct_tfexample()`. 19 | def create_ct_tfexample( 20 | dicom_series: Sequence[pydicom.Dataset], dataset_name: str = 'adhoc', 21 | strict_check: bool = False 22 | ) -> tf.train.Example: 23 | """Create a CT tf.example for inference based on a single series as input. 
24 | 25 | Creates the core precursor tf.example produced upon DICOM export for 26 | volumetric images in the CT pipeline. This allows for loaded pydicom images 27 | to be used to make inference / deployment example creation easier. 28 | 29 | Args: 30 | dicom_series: A list of pydicom series as input to create the example. 31 | dataset_name: The dataset-level name given to the key created for the 32 | example. Stored under 'volume/id'. 33 | strict_check: If True, raise ValueError if the DICOM series as error. 34 | Otherwise, return a tf.train.Example with possibly invalid DICOM series. 35 | 36 | Returns: 37 | example: A tf.example in CT format for inference. 38 | """ 39 | 40 | # Dedupe and sort incoming slices. 41 | dicom_series, _ = dicom_utils.dedupe_series(dicom_series, strict_check) 42 | 43 | dicom_images_dict = dicom_utils.map_by_series_instance_uid( 44 | dicom_series, sort_values=True 45 | ) 46 | dicom_images_series_uid = list(dicom_images_dict.keys())[0] 47 | sorted_dicom_images = dicom_images_dict[dicom_images_series_uid] 48 | 49 | if not sorted_dicom_images: 50 | raise ValueError('No DICOM images found.') 51 | 52 | # Filter out derived images. 53 | filtered_dicom_images = [ 54 | image 55 | for image in sorted_dicom_images 56 | if 'DERIVED' not in image.get('ImageType', []) 57 | ] 58 | 59 | if not filtered_dicom_images: 60 | raise ValueError('Series contains only derived images.') 61 | 62 | # Verify slice locations are consecutive (i.e. DICOM list is complete). 63 | spacing = dicom_utils.try_get_average_slice_spacing(filtered_dicom_images) 64 | depth = len(filtered_dicom_images) 65 | patient_id = 'UNKNOWN' 66 | if 'PatientID' in filtered_dicom_images[0]: 67 | patient_id = filtered_dicom_images[0].PatientID 68 | study_uid = filtered_dicom_images[0].StudyInstanceUID 69 | # Extract Age from the DICOM. 
70 | bucketized_age_value = None 71 | try: 72 | if 'PatientAge' in filtered_dicom_images[0]: 73 | patient_age_as = filtered_dicom_images[0].PatientAge 74 | age_value = ''.join(x for x in patient_age_as if x.isdigit()) 75 | age_value = float(age_value) 76 | bucketized_age_value = int(np.floor(age_value / 5.0)) 77 | except ValueError: 78 | bucketized_age_value = None 79 | 80 | # Extract image PNG values for example. 81 | instances_png = [] 82 | widths = set() 83 | heights = set() 84 | pixel_widths = set() 85 | pixel_heights = set() 86 | for a_dicom in filtered_dicom_images: 87 | heights.add(int(a_dicom.Rows)) 88 | widths.add(int(a_dicom.Columns)) 89 | pixel_heights.add(float(a_dicom.PixelSpacing[0])) # Row / Column 90 | pixel_widths.add(float(a_dicom.PixelSpacing[1])) 91 | intercept = float(a_dicom.RescaleIntercept) 92 | slope = float(a_dicom.RescaleSlope) 93 | pixel_data = a_dicom.pixel_array 94 | 95 | hu = pixel_data * slope + intercept # Cast to float. 96 | np.clip(hu, MIN_HU, None, hu) 97 | hu += 1024 98 | hu = hu.astype('uint16') 99 | instances_png.append(image_utils.encode_png(hu)) 100 | 101 | if ( 102 | len(widths) != 1 103 | or len(heights) != 1 104 | or len(pixel_widths) != 1 105 | or len(pixel_heights) != 1 106 | ): 107 | raise ValueError('Images of individual slices are of different dimensions.') 108 | (width,) = widths 109 | (height,) = heights 110 | (pixel_width,) = pixel_widths 111 | (pixel_height,) = pixel_heights 112 | 113 | # Create the tf.example 114 | example = tf.train.Example() 115 | f_dict = example.features.feature 116 | f_dict['volume/encoded'].bytes_list.value[:] = instances_png 117 | f_dict['volume/voxelsize'].float_list.value[:] = [ 118 | pixel_width, 119 | pixel_height, 120 | spacing, 121 | ] 122 | f_dict['volume/width'].int64_list.value.append(width) 123 | f_dict['volume/height'].int64_list.value.append(height) 124 | f_dict['volume/depth'].int64_list.value.append(depth) 125 | if bucketized_age_value is not None: 126 | 
f_dict['AgeIn5YBuckets'].int64_list.value[:] = [bucketized_age_value] 127 | 128 | key = '%s/%s/%s/%s' % ( 129 | dataset_name, 130 | patient_id, 131 | study_uid, 132 | dicom_images_series_uid, 133 | ) 134 | f_dict['volume/id'].bytes_list.value.append(str.encode(key)) 135 | 136 | # Used for prior classification in older pipeline 137 | # Get study date. Assume it's the same for all slices. 138 | study_date_value = 0 139 | if 'StudyDate' in filtered_dicom_images[0]: 140 | study_date = str(filtered_dicom_images[0].StudyDate) 141 | if len(study_date) == 8: 142 | dt = datetime.datetime( 143 | int(study_date[:4]), 144 | int(study_date[4:6]), 145 | int(study_date[6:]), 146 | 0, 147 | 0, 148 | 0, 149 | ) 150 | study_date_value = ( 151 | dt - datetime.datetime(1970, 1, 1) 152 | ).total_seconds() * 1000000 153 | f_dict['volume/stack/STUDY_DATE/value'].int64_list.value.append( 154 | int(study_date_value) 155 | ) 156 | 157 | return example 158 | -------------------------------------------------------------------------------- /fetal_ultrasound_blind_sweeps/networks.py: -------------------------------------------------------------------------------- 1 | """Model network definitions and loss functions. 2 | 3 | Defines networks and loss functions for gestational age and fetal 4 | malpresentation models featured in the publication. 5 | """ 6 | 7 | import tensorflow.compat.v1 as tf 8 | import tf_slim as slim 9 | 10 | from tensorflow.contrib import rnn 11 | from lstm_object_detection.lstm import lstm_cells 12 | from nets.mobilenet import mobilenet_v2 13 | 14 | N_LSTM_UNITS = 512 15 | LSTM_FILTER_SIZE = (3, 3) 16 | MOBILENET_DEPTH_MULTIPLIER = 1.0 17 | 18 | 19 | def _base_network(video_clips, is_training): 20 | """Builds the base network used by both models. 21 | 22 | Extracts image features independently for each image in each video clip, using 23 | MobileNetV2. Then aggregates the image features for each video clip using LSTM 24 | units. 
25 | 26 | Args: 27 | video_clips: Tensor containing a batch of video clips (image sequences). 28 | Dimensions: [batch_size, sequence_length, height, width, image_channels]. 29 | is_training: Boolean value indicating whether the network graph is to be 30 | used for training models. 31 | 32 | Returns: 33 | state_and_output_concat: Tensor containing LSTM state and output values 34 | corresponding to the final image in each video clip. Dimensions: 35 | [batch_size, feature_map_height, feature_map_width, 2 * N_LSTM_UNITS]. 36 | feature_map_height and feature_map_width are determined by the spatial 37 | dimensions of the final feature map layer of the MobileNetV2 image feature 38 | extractor. 39 | """ 40 | video_clips_shape = video_clips.get_shape().as_list() 41 | assert len(video_clips_shape) == 5 42 | n_batch, n_sequence, height, width, n_channels = tuple(video_clips_shape) 43 | # Convert gray scale to RGB for use with MobileNetV2 feature extractor. 44 | if n_channels == 1: 45 | video_clips = tf.image.grayscale_to_rgb(video_clips) 46 | n_channels = 3 47 | # Flatten batch and time dimensions, MobileNetV2 extracts features for each 48 | # image frame independently. 49 | video_clips = tf.reshape( 50 | video_clips, [n_batch * n_sequence, height, width, n_channels]) 51 | 52 | # Weight decay is set to zero in the training scope, but may be overridden 53 | # by training algorithms. 54 | arg_scope = mobilenet_v2.training_scope( 55 | is_training=is_training, weight_decay=0.0) 56 | with slim.arg_scope(arg_scope): 57 | with tf.variable_scope('ImageFeatureExtractor'): 58 | image_feature_maps, _ = mobilenet_v2.mobilenet_base( 59 | video_clips, 60 | depth_multiplier=MOBILENET_DEPTH_MULTIPLIER, 61 | use_explicit_padding=True) 62 | 63 | with tf.variable_scope('LSTM') as lstm_scope: 64 | _, maps_height, maps_width, maps_n_channels = tuple( 65 | image_feature_maps.get_shape().as_list()) 66 | # Reshape the feature maps to recover sequence structure. 
67 | maps_unrolled = tf.reshape(image_feature_maps, [ 68 | n_batch, n_sequence, maps_height, maps_width, maps_n_channels 69 | ]) 70 | feature_maps_sequence = tf.unstack(maps_unrolled, axis=1) 71 | lstm_cell = lstm_cells.GroupedConvLSTMCell( 72 | filter_size=LSTM_FILTER_SIZE, 73 | output_size=(maps_height, maps_width), 74 | num_units=N_LSTM_UNITS, 75 | is_training=is_training, 76 | activation=tf.nn.relu6, 77 | clip_state=True, 78 | output_bottleneck=True, 79 | visualize_gates=False) 80 | current_states_list = lstm_cell.init_state( 81 | state_name='lstm_state', batch_size=n_batch, dtype=tf.float32) 82 | init_state = rnn.LSTMStateTuple(*current_states_list) 83 | 84 | # Feed 2-D feature map sequences into recurrent LSTM cell. 85 | _, state_and_output = tf.nn.static_rnn( 86 | cell=lstm_cell, 87 | inputs=feature_maps_sequence, 88 | initial_state=init_state, 89 | scope=lstm_scope) 90 | # The state_and_output contains LSTM state and output for the last 91 | # image in the sequence. 92 | state_and_output_concat = tf.concat(state_and_output, -1) 93 | return state_and_output_concat 94 | 95 | 96 | def _average_pool(feature_map): 97 | feature_map_shape = feature_map.get_shape() 98 | input_rank = feature_map_shape.ndims 99 | n_batch = feature_map_shape.as_list()[0] 100 | return tf.reshape( 101 | tf.reduce_mean(feature_map, axis=list(range(2, input_rank - 1))), 102 | [n_batch, -1]) 103 | 104 | 105 | def gestational_age_regression_model(video_clips, is_training): 106 | """Model network for gestational age regression model.""" 107 | lstm_state_and_output = _base_network(video_clips, is_training) 108 | spatially_averaged_features = _average_pool(lstm_state_and_output) 109 | age_output = tf.layers.dense(spatially_averaged_features, units=1) 110 | variance_output = tf.layers.dense(spatially_averaged_features, units=1) 111 | # Soft plus unit ensures variances are always positive, and a small positive 112 | # value is added to prevent division by zero or small noise values in the loss 
113 | # function. 114 | variance_output = 1e-6 + tf.math.softplus(variance_output) 115 | return age_output, variance_output 116 | 117 | 118 | def gestational_age_loss_function(predicted_ages, predicted_variances, labels): 119 | """Training loss function for gestational age regression model.""" 120 | squared_errors = (labels - predicted_ages)**2 121 | scaled_errors = tf.math.divide( 122 | squared_errors, predicted_variances) + tf.math.log(predicted_variances) 123 | return tf.math.reduce_mean(0.5 * scaled_errors) 124 | 125 | 126 | def fetal_malpresentation_classification_model(video_clips, is_training): 127 | """Model network for fetal malpresentation classification model.""" 128 | lstm_state_and_output = _base_network(video_clips, is_training) 129 | spatially_averaged_features = _average_pool(lstm_state_and_output) 130 | # During training, the final sigmoid activation is applied by the loss 131 | # function. 132 | malpresentation_output = tf.layers.dense( 133 | spatially_averaged_features, units=1, 134 | activation=None if is_training else tf.nn.sigmoid) 135 | return malpresentation_output 136 | 137 | 138 | def fetal_malpresentation_loss_function(logits, labels): 139 | """Training loss function for fetal malpresentation classification model.""" 140 | per_instance_loss = tf.nn.sigmoid_cross_entropy_with_logits( 141 | labels=labels, logits=tf.reshape(logits, tf.shape(labels))) 142 | # Instance weights are set to default value of 1.0. 143 | return tf.losses.compute_weighted_loss(per_instance_loss) -------------------------------------------------------------------------------- /ct_dicom/pipeline.py: -------------------------------------------------------------------------------- 1 | """End-to-end Beam pipeline(s) for creating CT Examples. 2 | 3 | These can be called in a main file using Beam Runners suitable for the target 4 | runtime environment. 
5 | """ 6 | 7 | import dataclasses 8 | from typing import Optional, Tuple 9 | 10 | import apache_beam as beam 11 | from apache_beam.io import textio 12 | from apache_beam.transforms import util as beam_util 13 | 14 | import example_builder_beam 15 | from gcp import dicomweb_beam 16 | 17 | # Column name in the input CSV file corresponding to Study Instance UIDs. 18 | _STUDY_INSTANCE_UID_COLUMN_NAME = 'study_instance_uid' 19 | 20 | 21 | def _build_dicom_download_from_chc_dicomweb( 22 | root: beam.Pipeline, 23 | chc_dicom_store: dicomweb_beam.ChcDicomStore, 24 | study_instance_uid_filepath: Optional[str], 25 | ) -> Tuple[beam.PCollection, beam.PCollection]: 26 | """Builds pipeline fragment to retrieve DICOMs from DICOM Store.""" 27 | if study_instance_uid_filepath is not None: 28 | study_instance_uids = ( 29 | root 30 | | 'Collect Study Instance UIDs from CSV' 31 | >> textio.ReadFromCsv( 32 | study_instance_uid_filepath, 33 | usecols=[_STUDY_INSTANCE_UID_COLUMN_NAME], 34 | ) 35 | | beam.Map(lambda x: getattr(x, _STUDY_INSTANCE_UID_COLUMN_NAME, None)) 36 | | beam.Filter(lambda x: x is not None).with_output_types(str) 37 | ) 38 | else: 39 | study_instance_uids = ( 40 | root 41 | | 'Collect Study Instance UIDs from DICOMweb' 42 | >> beam.ParDo(dicomweb_beam.QueryStudyInstanceUidsFn(chc_dicom_store)) 43 | ) 44 | 45 | dicoms = ( 46 | study_instance_uids 47 | | beam_util.Reshuffle() 48 | | 'Collect Series Instance UIDs' 49 | >> beam.ParDo(dicomweb_beam.QuerySeriesInstanceUidsFn(chc_dicom_store)) 50 | # No reshuffling here, otherwise the DICOMweb API will be bombarded with 51 | # O(number of Studies) queries in a short duration. No reshuffling couples 52 | # Series Instance UID retrieval with DICOM download (a slower step), 53 | # helping spread the API calls over time. 
54 | | 'Retrieve Series DICOMs' 55 | >> beam.ParDo( 56 | dicomweb_beam.DownloadMultipartDicomSeriesFn(chc_dicom_store) 57 | ).with_outputs( 58 | dicomweb_beam.DownloadMultipartDicomSeriesFn.ERROR_OUTPUT_TAG, 59 | main='values', 60 | ) 61 | ) 62 | 63 | return ( 64 | dicoms.values, 65 | dicoms[dicomweb_beam.DownloadMultipartDicomSeriesFn.ERROR_OUTPUT_TAG], 66 | ) 67 | 68 | 69 | def _build_example_creation_from_dicom_bytes( 70 | dicoms: beam.PCollection, 71 | ) -> Tuple[beam.PCollection, beam.PCollection]: 72 | """Builds pipeline fragment to create Examples from downloaded DICOMs.""" 73 | examples = dicoms | 'Create Examples' >> beam.ParDo( 74 | example_builder_beam.CreateCTExampleFn() 75 | ).with_outputs( 76 | example_builder_beam.CreateCTExampleFn.ERROR_OUTPUT_TAG, 77 | main='values', 78 | ) 79 | return ( 80 | examples.values, 81 | examples[example_builder_beam.CreateCTExampleFn.ERROR_OUTPUT_TAG], 82 | ) 83 | 84 | 85 | @dataclasses.dataclass(frozen=True) 86 | class Outputs: 87 | """Container for PCollections returned by `build_for_chc_dicomweb_api()`. 88 | 89 | Attributes: 90 | example_key_values: Key-value pairs where the value is the created TF 91 | Example and the key is a unique string (formatted as 92 | "<dataset_name>/<patient_id>/<study_instance_uid>/<series_instance_uid>"; all slices to create the Example have the 93 | same Study and Series Instance UID Attribute values) to identify the 94 | Example. 95 | error_csv_rows: CSV row-formatted error strings. Each "row" corresponds to 96 | a unique Study-Series Instance UID pair for which an error was 97 | encountered either while downloading a DICOM, parsing DICOM bytes, or 98 | creating an Example from parsed DICOMs. 
99 | """ 100 | 101 | example_key_values: beam.PCollection 102 | error_csv_rows: beam.PCollection 103 | 104 | 105 | def build_for_chc_dicomweb_api( 106 | root: beam.Pipeline, 107 | chc_dicom_store: dicomweb_beam.ChcDicomStore, 108 | study_instance_uid_filepath: Optional[str] = None, 109 | ) -> Outputs: 110 | """Builds CT Example creation pipeline reading DICOMs from CHC DICOMweb API. 111 | 112 | If `study_instance_uid_filepath` is not set, the pipeline runs on all Study 113 | Instance UIDs within the CHC DICOM Store. Otherwise, it uses Study Instance 114 | UIDs listed in this file. 115 | 116 | The outputs include the created Examples (under the `example_key_values` 117 | attribute) and CSV-formatted error strings (under the `error_csv_rows` 118 | attribute). 119 | 120 | Each CSV-formatted string in the `error_csv_rows` output attribute captures 121 | the first error encountered while downloading, parsing, and creating an 122 | Example for a given Series Instance UID. It is a comma-separated, 123 | double-quoted sequence of entries: 124 | - Column 1: Study Instance UID associated with the Series (Column 2). 125 | - Column 2: The Series Instance UID for which Example creation failed. 126 | - Column 3 onwards: Arguments to the Exception caught. 127 | 128 | By default, this pipeline uses Application Default Credentials to authenticate 129 | with the CHC DICOMweb API. Additional runtime environments are supported via 130 | command-line flags that must be declared by calling `gcp.auth.define_flags()` 131 | before Abseil parses command-line flags at runtime. 132 | 133 | Args: 134 | root: The root (source) of the Beam pipeline to connect. 135 | chc_dicom_store: CHC DICOM Store to download DICOMs from. 136 | study_instance_uid_filepath: A CSV file containing input Study Instance UIDs 137 | to query from the DICOM Store (in lieu of querying all Study Instance UIDs 138 | from the CHC DICOM Store). The UIDs must be present in a column titled 139 | "study_instance_uid". 
140 | 141 | Returns: 142 | PCollections containing the created Examples (`example_key_values` 143 | attribute) and any CSV-formatted error strings (`error_csv_rows` attribute). 144 | """ 145 | dicoms, download_errors = _build_dicom_download_from_chc_dicomweb( 146 | root, chc_dicom_store, study_instance_uid_filepath 147 | ) 148 | example_key_values, example_creation_errors = ( 149 | _build_example_creation_from_dicom_bytes(dicoms) 150 | ) 151 | errors = ( 152 | download_errors, 153 | example_creation_errors, 154 | ) | 'Collect Errors' >> beam.Flatten() 155 | return Outputs(example_key_values, errors) 156 | -------------------------------------------------------------------------------- /health_acoustic_representations/eval_utils.py: -------------------------------------------------------------------------------- 1 | """Eval utils. 2 | 3 | Train linear models on top of frozen embeddings. 4 | """ 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from sklearn import linear_model 9 | from sklearn import metrics 10 | from sklearn import model_selection 11 | from sklearn import utils as sk_utils 12 | 13 | 14 | LinearProbe = ( 15 | linear_model.LogisticRegressionCV 16 | | linear_model.RidgeCV 17 | | linear_model.Ridge 18 | | model_selection.GridSearchCV 19 | | linear_model.SGDClassifier 20 | | linear_model.LogisticRegression 21 | | linear_model.ElasticNetCV 22 | | linear_model.ElasticNet 23 | ) 24 | 25 | 26 | def create_linear_probe( 27 | regularization_coef: float, 28 | is_regression: bool, 29 | use_sgd_classifier: bool = True, 30 | ) -> LinearProbe: 31 | """Creates linear probe.""" 32 | if is_regression: 33 | return linear_model.Ridge(alpha=regularization_coef) 34 | else: 35 | if use_sgd_classifier: 36 | return linear_model.SGDClassifier( 37 | loss='log', 38 | penalty='l2', 39 | alpha=regularization_coef, 40 | class_weight='balanced', 41 | max_iter=1_000_000, 42 | tol=1e-3, 43 | random_state=42, 44 | ) 45 | else: 46 | return linear_model.LogisticRegression( 47 | 
C=regularization_coef, 48 | class_weight='balanced', 49 | penalty='l2', 50 | max_iter=10_000_000, 51 | ) 52 | 53 | 54 | def predict_with_probe( 55 | probe: LinearProbe, 56 | features: np.ndarray, 57 | is_regression: bool, 58 | ) -> np.ndarray: 59 | """Computes trained linear probe's predictions.""" 60 | # pytype: disable=attribute-error 61 | if is_regression: 62 | return probe.predict(features) 63 | else: 64 | return probe.predict_proba(features)[:, 1] 65 | # pytype: enable=attribute-error 66 | 67 | 68 | def find_reg_coef_with_best_metric( 69 | cv_scores: dict[float, float], lower_is_better: bool 70 | ) -> float: 71 | """Finds regularization coef with best metric. 72 | 73 | Args: 74 | cv_scores: A map between regularization coefficients and the cross-validated 75 | performance (ROCAUC for classification or MAE for regression) on the 76 | held-out folds. 77 | lower_is_better: A boolean indicating if the metric is best when lowest 78 | (True) or highest (False). 79 | 80 | Returns: 81 | The key of `cv_scores` corresponding to the best metric. 
82 | """ 83 | best_alpha = -1.0 84 | if lower_is_better: 85 | best_metric = 1e50 86 | else: 87 | best_metric = 0 88 | for alpha, metric in cv_scores.items(): 89 | if lower_is_better: 90 | if metric < best_metric: 91 | best_alpha = alpha 92 | best_metric = metric 93 | else: 94 | if metric > best_metric: 95 | best_alpha = alpha 96 | best_metric = metric 97 | return best_alpha 98 | 99 | 100 | def compute_metrics_for_probe( 101 | y_true: np.ndarray, 102 | y_score: np.ndarray, 103 | is_regression: bool, 104 | ) -> float: 105 | if is_regression: 106 | return metrics.mean_absolute_error(y_true=y_true, y_pred=y_score) 107 | else: 108 | return metrics.roc_auc_score(y_true=y_true, y_score=y_score) 109 | 110 | 111 | def train_linear_probe_with_participant_level_crossval( 112 | features: np.ndarray, 113 | labels: np.ndarray, 114 | participant_ids: np.ndarray, 115 | is_regression: bool, 116 | n_folds: int = 5, 117 | use_sgd_classifier: bool = True, 118 | stratify_per_label: bool = True, 119 | ) -> LinearProbe: 120 | """Trains a linear probe using cross-validated l2 penalization parameter.""" 121 | assert features.shape[0] == labels.shape[0] == participant_ids.shape[0] 122 | 123 | if is_regression: 124 | label_by_participant_ids = ( 125 | pd.DataFrame({'participant_id': participant_ids, 'label': labels}) 126 | .groupby('participant_id') 127 | .mean() 128 | ) 129 | label_by_participant_ids = label_by_participant_ids.label.to_dict() 130 | else: 131 | label_by_participant_ids = dict(zip(participant_ids, labels)) 132 | unique_participant_ids = np.array(list(set(participant_ids))) 133 | unique_labels = np.array( 134 | [label_by_participant_ids[k] for k in unique_participant_ids] 135 | ) 136 | if stratify_per_label and not is_regression: 137 | folds = list( 138 | model_selection.StratifiedKFold( 139 | n_folds, shuffle=True, random_state=43 140 | ).split(unique_participant_ids, unique_labels) 141 | ) 142 | else: 143 | folds = list( 144 | model_selection.KFold(n_folds, shuffle=True, 
random_state=43).split( 145 | unique_participant_ids 146 | ) 147 | ) 148 | 149 | cv_scores = {} 150 | for alpha in np.logspace(-5, 5, num=50): 151 | cross_validated_metrics = 0 152 | 153 | for random_seed, (train_idx, test_idx) in enumerate(folds): 154 | 155 | # `train_idx` and `test_idx` are arrays of integers corresponding to 156 | # indices within `unique_participant_ids`. They take values in 157 | # `range(len(unique_participant_ids))`. 158 | train_unique_participant_idx = set(unique_participant_ids[train_idx]) 159 | test_unique_participant_idx = set(unique_participant_ids[test_idx]) 160 | assert not (test_unique_participant_idx & train_unique_participant_idx) 161 | 162 | keep_train = [ 163 | pid in train_unique_participant_idx for pid in participant_ids 164 | ] 165 | train_fold_features = features[keep_train] * 1 166 | train_fold_labels = labels[keep_train] * 1 167 | train_fold_features, train_fold_labels = sk_utils.shuffle( 168 | train_fold_features, train_fold_labels, random_state=random_seed 169 | ) 170 | 171 | keep_test = [ 172 | pid in test_unique_participant_idx for pid in participant_ids 173 | ] 174 | test_fold_features = features[keep_test] 175 | test_fold_labels = labels[keep_test] 176 | 177 | lr = create_linear_probe( 178 | regularization_coef=alpha, 179 | is_regression=is_regression, 180 | use_sgd_classifier=use_sgd_classifier, 181 | ) 182 | lr.fit(train_fold_features, train_fold_labels) 183 | 184 | predictions = predict_with_probe( 185 | probe=lr, 186 | features=test_fold_features, 187 | is_regression=is_regression, 188 | ) 189 | 190 | cross_validated_metrics += compute_metrics_for_probe( 191 | y_true=test_fold_labels, 192 | y_score=predictions, 193 | is_regression=is_regression, 194 | ) 195 | 196 | cv_scores[alpha] = cross_validated_metrics / n_folds 197 | 198 | best_alpha = find_reg_coef_with_best_metric( 199 | cv_scores=cv_scores, lower_is_better=is_regression 200 | ) 201 | lr = create_linear_probe( 202 | regularization_coef=best_alpha, 203 | 
is_regression=is_regression, 204 | use_sgd_classifier=use_sgd_classifier, 205 | ) 206 | 207 | lr.fit(features, labels) 208 | return lr 209 | -------------------------------------------------------------------------------- /ct_dicom/image_utils_test.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import io 3 | import os 4 | from typing import Mapping 5 | 6 | from absl import logging 7 | from absl.testing import absltest 8 | from absl.testing import parameterized 9 | import numpy as np 10 | from PIL import Image 11 | 12 | import image_utils 13 | 14 | 15 | class BitWidth(enum.Enum): 16 | B16 = '16-bits' 17 | B8 = '8-bits' 18 | 19 | 20 | _TEST_DATA_DIR = 'google_health/ct_dicom/testdata' 21 | 22 | _NP_IMAGE_FILENAME_BY_BIT_WIDTH = { 23 | BitWidth.B8: os.path.join(_TEST_DATA_DIR, 'img_8c1.npy'), 24 | BitWidth.B16: os.path.join(_TEST_DATA_DIR, 'img_16c1.npy'), 25 | } 26 | _PNG_IMAGE_FILENAME_BY_BIT_WIDTH = { 27 | BitWidth.B8: os.path.join(_TEST_DATA_DIR, 'img_8c1.png'), 28 | BitWidth.B16: os.path.join(_TEST_DATA_DIR, 'img_16c1.png'), 29 | } 30 | _PIXEL_ERROR_MARGIN = 0.00001 31 | # Some encoder tests are parameterized by bit width. 
32 | _DTYPE_BY_BIT_WIDTH = { 33 | BitWidth.B8: np.uint8, 34 | BitWidth.B16: np.uint16, 35 | } 36 | _TEST_PARAMS = tuple((bit_width,) for bit_width in BitWidth) 37 | 38 | 39 | def _GetPixelStats(encoded_string: str) -> Mapping[str, np.array]: 40 | """Returns, min, max and average pixel value in the encoded_string.""" 41 | decoded_array = Image.open(io.BytesIO(encoded_string)) 42 | npdecoded_array = np.array(decoded_array) 43 | return { 44 | 'min': np.min(npdecoded_array), 45 | 'max': np.max(npdecoded_array), 46 | 'ave': np.average(npdecoded_array), 47 | } 48 | 49 | 50 | class TestEncodePng(parameterized.TestCase): 51 | """Unit tests for `encode_png()`.""" 52 | 53 | @classmethod 54 | def setUpClass(cls): 55 | """Preloads test resources.""" 56 | super().setUpClass() 57 | cls._NP_IMAGE_BY_BIT_WIDTH = {} 58 | cls._PNG_IMAGE_BY_BIT_WIDTH = {} 59 | 60 | for bit_width in BitWidth: 61 | logging.info('Loading: %s', _NP_IMAGE_FILENAME_BY_BIT_WIDTH[bit_width]) 62 | with gfile.Open(_NP_IMAGE_FILENAME_BY_BIT_WIDTH[bit_width], 'rb') as f: 63 | cls._NP_IMAGE_BY_BIT_WIDTH[bit_width] = np.load(f) 64 | 65 | logging.info('Loading: %s', _PNG_IMAGE_FILENAME_BY_BIT_WIDTH[bit_width]) 66 | with gfile.Open(_PNG_IMAGE_FILENAME_BY_BIT_WIDTH[bit_width], 'rb') as f: 67 | cls._PNG_IMAGE_BY_BIT_WIDTH[bit_width] = f.read() 68 | 69 | @parameterized.parameters(*_TEST_PARAMS) 70 | def testSuccess_Range(self, bit_width): 71 | """Tests image (w, h) = (4, 2) with maximum range of values for uint*.""" 72 | self.assertIn(bit_width, _DTYPE_BY_BIT_WIDTH) 73 | dtype = _DTYPE_BY_BIT_WIDTH[bit_width] 74 | 75 | test_array = np.array( 76 | [[0, 1, 2, 3], [np.iinfo(dtype).max, 12000, 100, 150]] 77 | ).astype(dtype) 78 | png_text = image_utils.encode_png(test_array) 79 | result_array = np.array(Image.open(io.BytesIO(png_text))) 80 | np.testing.assert_array_equal(test_array, result_array) 81 | 82 | @parameterized.parameters(*_TEST_PARAMS) 83 | def testSuccess_Idempotence(self, bit_width): 84 | """Tests that 
`decode(encode(*))` is an identity op.""" 85 | self.assertIn(bit_width, self._PNG_IMAGE_BY_BIT_WIDTH) 86 | loaded_png_bytes = self._PNG_IMAGE_BY_BIT_WIDTH[bit_width] 87 | 88 | canonical_pixels = self.DecodePng(loaded_png_bytes) 89 | encoded_png_bytes = image_utils.encode_png(canonical_pixels) 90 | actual_pixels = self.DecodePng(encoded_png_bytes) 91 | self.assertEqual(canonical_pixels.dtype, actual_pixels.dtype) 92 | self.assertEqual(canonical_pixels.shape, actual_pixels.shape) 93 | self.assertTrue(np.array_equal(canonical_pixels, actual_pixels)) 94 | 95 | @parameterized.parameters(*_TEST_PARAMS) 96 | def testSuccess_Regression(self, bit_width): 97 | """Captures difference in outputs from OpenCV-based encoder.""" 98 | self.assertIn(bit_width, self._PNG_IMAGE_BY_BIT_WIDTH) 99 | canonical_png_bytes = self._PNG_IMAGE_BY_BIT_WIDTH[bit_width] 100 | self.assertIn(bit_width, self._NP_IMAGE_BY_BIT_WIDTH) 101 | test_png_bytes = image_utils.encode_png( 102 | self._NP_IMAGE_BY_BIT_WIDTH[bit_width] 103 | ) 104 | 105 | canonical_pixel_stats = _GetPixelStats(canonical_png_bytes) 106 | test_pixel_stats = _GetPixelStats(test_png_bytes) 107 | 108 | logging.info('Canonical pixels: %s', str(canonical_pixel_stats)) 109 | logging.info('Instance pixels: %s', str(test_pixel_stats)) 110 | logging.info( 111 | 'Diff in average pixel values: %f', 112 | test_pixel_stats['ave'] - canonical_pixel_stats['ave'], 113 | ) 114 | 115 | self.assertEqual(test_pixel_stats['min'], canonical_pixel_stats['min']) 116 | self.assertEqual(test_pixel_stats['max'], canonical_pixel_stats['max']) 117 | self.assertAlmostEqual( 118 | test_pixel_stats['ave'], 119 | canonical_pixel_stats['ave'], 120 | delta=_PIXEL_ERROR_MARGIN, 121 | ) 122 | 123 | @parameterized.parameters(np.int32, np.uint32, np.int16, np.int8) 124 | def testFailure_Dtype(self, dtype): 125 | """Tests failure to convert to PNG for invalid image dimensions.""" 126 | array = np.array([[0, 1], [2, 4]], dtype=dtype) 127 | with self.assertRaisesRegex( 128 
| ValueError, 'Pixels must be either `uint8` or `uint16`.' 129 | ): 130 | image_utils.encode_png(array) 131 | 132 | def testFailure_Dimensions(self): 133 | """Tests failure to convert with wrong input dimensions.""" 134 | test_array_3d = np.ones([2, 2, 2]).astype(np.uint16) 135 | with self.assertRaisesRegex(ValueError, 'Array must be 2-D.'): 136 | image_utils.encode_png(test_array_3d) 137 | 138 | def testEncodeRaisesErrorWithBadInput(self): 139 | with self.assertRaisesRegex(ValueError, 'empty image'): 140 | image_utils.encode_png(np.zeros((50, 0), dtype=np.uint16)) 141 | 142 | def testConversionToPNGImage16bitExtremes(self): 143 | # Test image wXh = 4X2 with maximum range of values for unsigned int 16. 144 | test_array = np.array( 145 | [[0, 1, 2, 3], [np.iinfo(np.uint16).max, 12000, 100, 150]] 146 | ).astype(np.uint16) 147 | 148 | png_bytes = image_utils.encode_png(test_array) 149 | result_array = np.array(Image.open(io.BytesIO(png_bytes))) 150 | np.testing.assert_array_equal(test_array, result_array) 151 | 152 | def testConversionToPNGImage8bitExtremes(self): 153 | # Test image wXh = 4X2 with maximum range of values for unsigned int 8. 154 | test_array = np.array( 155 | [[0, 1, 2, 3], [np.iinfo(np.uint8).max, 128, 64, 32]] 156 | ).astype(np.uint8) 157 | 158 | png_bytes = image_utils.encode_png(test_array) 159 | result_array = np.array(Image.open(io.BytesIO(png_bytes))) 160 | np.testing.assert_array_equal(test_array, result_array) 161 | 162 | def DecodePng(self, png_bytes: bytes) -> np.ndarray: 163 | """Converts an encoded 16-bit grayscale PNG to a 2D uint16 array.""" 164 | # The use of np.uint8 here is for the png bytes, not the pixel values. 
165 | byte_array = np.frombuffer(png_bytes, np.uint8) 166 | pixel_array = np.array(Image.open(io.BytesIO(byte_array))).astype(np.uint16) 167 | self.assertEqual(pixel_array.dtype, np.uint16) 168 | self.assertEqual(pixel_array.ndim, 2) 169 | return pixel_array 170 | 171 | 172 | if __name__ == '__main__': 173 | absltest.main() 174 | -------------------------------------------------------------------------------- /ct_dicom/dicom_utils.py: -------------------------------------------------------------------------------- 1 | """DICOM utilities based on pydicom for sorting and examining DICOM data.""" 2 | 3 | import collections 4 | from typing import Any, Callable, Iterable, Mapping, Sequence, Tuple 5 | 6 | import numpy as np 7 | import pydicom 8 | 9 | # Expect that the axial spacing between slices is consistent to a factor of 40%. 10 | _SLICE_SPACING_TOLERANCE_RATIO = 0.4 11 | 12 | # Index of the axial (Z) dimension for slice spacing computations. Used for 13 | # indexing into the Image Position (Patient) (0020,0032) Attribute value. 14 | _IMG_POS_PAT_ZCOORD = 2 15 | 16 | 17 | def validate_slice_spacing(dicoms: Sequence[pydicom.Dataset]) -> None: 18 | """Verifies slice spacing based on sanity checks on the average spacing. 19 | 20 | The following requirements are validated: 21 | - At least 2 DICOMs in `dicoms` to infer slice spacing. 22 | - Slices are sorted in increasing order of axial dimension of Image Position. 23 | - No duplicate slices. 24 | - The max slice spacing is no more than 50% of the min slice spacing. 25 | 26 | Args: 27 | dicoms: Sequence of DICOM images in increasing order of the axial dimension 28 | values for the Image Position (Patient) (0020,0032) Attribute. Must have 29 | at least 2 DICOMs. 30 | 31 | Raises: 32 | ValueError: If any one of the requirements (in the description) fail. 
33 | """ 34 | slice_positions = tuple( 35 | dicom.ImagePositionPatient[_IMG_POS_PAT_ZCOORD] for dicom in dicoms 36 | ) 37 | slice_spacings = np.array( 38 | [cur - prev for prev, cur in zip(slice_positions, slice_positions[1:])] 39 | ) 40 | if not slice_spacings.size: 41 | raise ValueError(f'Too few DICOMs ({len(dicoms)}) to infer slice spacing.') 42 | if np.any(np.isclose(slice_spacings, 0.0)): 43 | raise ValueError( 44 | 'DICOM slices are not ordered in increasing value of axial dimension of' 45 | ' the Image Position (Patient) (0020,0032) Attribute.' 46 | ) 47 | 48 | min_slice_spacing = np.min(slice_spacings) 49 | max_slice_spacing = np.max(slice_spacings) 50 | try: 51 | spacing_factor = (max_slice_spacing - min_slice_spacing) / min_slice_spacing 52 | except ZeroDivisionError as e: 53 | raise ValueError( 54 | 'Found a pair of duplicate or non-axially aligned slices.' 55 | ) from e 56 | 57 | if spacing_factor > _SLICE_SPACING_TOLERANCE_RATIO: 58 | raise ValueError( 59 | f'CT Instance spacing ratio {spacing_factor:.2f} exceeds the allowed' 60 | f' {_SLICE_SPACING_TOLERANCE_RATIO:.2f} tolerance. Max spacing:' 61 | f' {max_slice_spacing:.2f}mm, Min spacing: {min_slice_spacing:.2f}mm' 62 | ) 63 | 64 | 65 | def try_get_average_slice_spacing( 66 | dicoms: Sequence[pydicom.Dataset], 67 | ) -> float: 68 | """Returns an average of the slice spacing. 69 | 70 | Exceptions from `validate_slice_spacing()` are passed through to the caller. 71 | 72 | Args: 73 | dicoms: Sequence of DICOM images in increasing order of the axial dimension 74 | values for the Image Position (Patient) (0020,0032) Attribute. Must have 75 | at least 2 DICOMs. 76 | 77 | Returns: 78 | The average slice spacing. 
79 | """ 80 | validate_slice_spacing(dicoms) 81 | assert len(dicoms) > 1 82 | return ( 83 | dicoms[-1].ImagePositionPatient[_IMG_POS_PAT_ZCOORD] 84 | - dicoms[0].ImagePositionPatient[_IMG_POS_PAT_ZCOORD] 85 | ) / (len(dicoms) - 1) 86 | 87 | 88 | def dedupe_series( 89 | dicom_datasets: Sequence[pydicom.Dataset], 90 | strict_check: bool = False 91 | ) -> Tuple[Sequence[pydicom.Dataset], bool]: 92 | """Deduplicates slices of a single series by acquisition and instance number. 93 | 94 | In some cases unrelated slices are grouped into the same series UID. 95 | This attempts to remove duplicates by selecting the acquisition with the 96 | greatest number of slices followed by having unique instance numbers. 97 | 98 | Args: 99 | dicom_datasets: List of pydicom datasets to be de-duped. Note, this assumes 100 | that all belong to the same series instance UID. 101 | strict_check: If True, raise ValueError if the DICOM series as error. 102 | Otherwise, return DICOMs with possibly invalid DICOM series (e.g. missing 103 | AcquisitionNumber). 104 | 105 | Returns: 106 | final_dicoms: List of deduped cases. 107 | needed_correction: Set to True iff dicoms were eliminated. 108 | """ 109 | needed_correction = False 110 | 111 | # Get the acquisitions with the most slices and ensure a single series. 
112 | series_uid = set() 113 | acquisitions = {} 114 | for a_dicom in dicom_datasets: 115 | series_uid.add(a_dicom.SeriesInstanceUID) 116 | a_acquisition_number = -1 117 | if 'AcquisitionNumber' not in a_dicom and strict_check: 118 | raise ValueError('DICOM does not have AcquisitionNumber metadata.') 119 | elif 'AcquisitionNumber' in a_dicom: 120 | a_acquisition_number = a_dicom.AcquisitionNumber 121 | if a_acquisition_number not in acquisitions: 122 | acquisitions[a_acquisition_number] = [] 123 | acquisitions[a_acquisition_number].append(a_dicom) 124 | most_slices = max(acquisitions, key=lambda k: len(acquisitions[k])) 125 | 126 | if len(series_uid) != 1: 127 | raise ValueError( 128 | 'Got {len(series_uid)} unique series. Function only operates on a' 129 | ' single series.' 130 | ) 131 | 132 | # Dedupe by instance number. 133 | final_dicoms = [] 134 | instance_numbers = set() 135 | for instance in acquisitions[most_slices]: 136 | if instance.InstanceNumber not in instance_numbers: 137 | instance_numbers.add(instance.InstanceNumber) 138 | final_dicoms.append(instance) 139 | 140 | if len(dicom_datasets) != len(final_dicoms): 141 | needed_correction = True 142 | return final_dicoms, needed_correction 143 | 144 | 145 | def _sort_series_by_image_position_patient( 146 | dicoms: Sequence[pydicom.Dataset], 147 | ) -> Tuple[pydicom.Dataset, ...]: 148 | """Sorts a series of pydicom data by Z of ImagePositionPatient (Axial CT).""" 149 | return tuple( 150 | sorted(dicoms, key=lambda d: d.ImagePositionPatient[_IMG_POS_PAT_ZCOORD]) 151 | ) 152 | 153 | 154 | def _map_by_dicom_attribute( 155 | dicoms: Iterable[pydicom.Dataset], 156 | attr: str, 157 | attr_transformer: Callable[[Any], Any] = str, 158 | sort_values=True, 159 | ) -> Mapping[bytes, Sequence[pydicom.Dataset]]: 160 | """Maps DICOMs by Attribute Value (optionally sorted by Series Number.""" 161 | dicoms_by_attribute = collections.defaultdict(list) 162 | for d in dicoms: 163 | 
dicoms_by_attribute[attr_transformer(getattr(d, attr))].append(d) 164 | 165 | if sort_values: 166 | for attribute_value, unsorted_dicoms in dicoms_by_attribute.items(): 167 | dicoms_by_attribute[attribute_value] = ( 168 | _sort_series_by_image_position_patient(unsorted_dicoms) 169 | ) 170 | return dicoms_by_attribute 171 | 172 | 173 | def map_by_series_instance_uid( 174 | dicoms: Iterable[pydicom.Dataset], sort_values=True 175 | ) -> Mapping[bytes, Sequence[pydicom.Dataset]]: 176 | """Get DICOMs mapped by Series Instance UID (0020, 000E). 177 | 178 | Each UID maps to a list of DICOMs, which may optionally be sorted in 179 | increasing order of the value of the third (Z) dimension of the Image Position 180 | (Patient) Attribute (0020, 0032). 181 | 182 | Args: 183 | dicoms: Input DICOM datasets to map. 184 | sort_values: Sort the mapped DICOM by the Z coordinate of the Image Position 185 | (Patient) Attribute. See method docstring for details. 186 | 187 | Returns: 188 | Mapping of Series Number Attribute values to DICOM Dataset sequences. 189 | """ 190 | return _map_by_dicom_attribute( 191 | dicoms, 'SeriesInstanceUID', sort_values=sort_values 192 | ) 193 | 194 | 195 | def map_by_series_number( 196 | dicoms: Iterable[pydicom.Dataset], sort_values: bool = True 197 | ) -> Mapping[bytes, Sequence[pydicom.Dataset]]: 198 | """Get DICOMs mapped by Series Number (0020, 0011). 199 | 200 | Each UID maps to a list of DICOMs, which may optionally be sorted in 201 | increasing order of the value of the third (Z) dimension of the Image Position 202 | (Patient) Attribute (0020, 0032). 203 | 204 | Args: 205 | dicoms: Input DICOM datasets to map. 206 | sort_values: Sort the mapped DICOM by the Z coordinate of the Image Position 207 | (Patient) Attribute. See method docstring for details. 208 | 209 | Returns: 210 | Mapping of Series Number Attribute values to DICOM Dataset sequences. 
211 | """ 212 | return _map_by_dicom_attribute( 213 | dicoms, 'SeriesNumber', attr_transformer=int, sort_values=sort_values 214 | ) 215 | -------------------------------------------------------------------------------- /breast_survival_prediction/example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "kIbu0zjWkPZn" 7 | }, 8 | "source": [ 9 | "This notebook illustrates the usage of stage-2 features as described in \"Deep learning models for histologic grading of breast cancer and association with disease prognosis\"." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "id": "eN2-OtOHSUbp" 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import numpy as np" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": { 26 | "id": "tIbsTm9fyMp5" 27 | }, 28 | "source": [ 29 | "# Mitotic Count" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "id": "ymkgekfhSpTA" 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "def generate_mitotic_heatmap(heatmap_size, list_of_coordinates, mitosis_size=6):\n", 41 | " \"\"\"Generates mitotic heatmap with the given list of coordinates.\n", 42 | " \n", 43 | " Args:\n", 44 | " heatmap_size: size of the heatmap to generate\n", 45 | " list_of_coordinate: coordinates (tuple) of center of the mitoses.\n", 46 | " mitosis_size: size of each mitosis.\n", 47 | " Returns:\n", 48 | " Heatmaps that represent mitosis detection.\n", 49 | " \"\"\"\n", 50 | " half_mitosis_size = int(mitosis_size / 2)\n", 51 | " heatmap = np.zeros(heatmap_size)\n", 52 | " for coord in list_of_coordinates:\n", 53 | " y, x = coord\n", 54 | " y = y - half_mitosis_size\n", 55 | " x = x - half_mitosis_size\n", 56 | " heatmap[y :(y + mitosis_size - 1), x:(x + mitosis_size - 1)] = 1\n", 57 | " return heatmap\n", 58 | "\n", 59 | "\n", 60 | 
"def detect_and_calc_density(heatmap,\n", 61 | " detection_th=0.5,\n", 62 | " morph_erode_size=4,\n", 63 | " window_size=128,\n", 64 | " stride=64):\n", 65 | " \"\"\"Combined steps of detection and density calculation.\n", 66 | "\n", 67 | " Args:\n", 68 | " heatmap: 2D array of shape (height, width) that represent probability of\n", 69 | " mitotic activity.\n", 70 | " detection_th: detection threshold, see mc_util.heatmap_to_list.\n", 71 | " morph_erode_size: size of structuring element for detection cleanup, see\n", 72 | " mc_util.heatmap_to_list.\n", 73 | " stride: density window stride, see mc_util.calculate_density.\n", 74 | " window_size: density window size, see mc_util.calculate_density.\n", 75 | "\n", 76 | " Returns:\n", 77 | " Dict of detection and density.\n", 78 | " \"\"\"\n", 79 | " # Resize mask so it is in the same size as the heatmap.\n", 80 | "\n", 81 | " detection = mc_util.heatmap_to_list(\n", 82 | " heatmap,\n", 83 | " detection_th,\n", 84 | " morph_erode_size=morph_erode_size)\n", 85 | " heatmap_size = heatmap.shape\n", 86 | "\n", 87 | " density = mc_util.calculate_density(\n", 88 | " detection, heatmap_size, window_size, stride)\n", 89 | " return {'density': density, 'detection': detection}" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": { 96 | "executionInfo": { 97 | "elapsed": 326, 98 | "status": "ok", 99 | "timestamp": 1659999029378, 100 | "user": { 101 | "displayName": "", 102 | "userId": "" 103 | }, 104 | "user_tz": 420 105 | }, 106 | "id": "ggoAd2z3ZOI0", 107 | "outputId": "a164b777-c101-4470-f022-07c6632c098f" 108 | }, 109 | "outputs": [ 110 | { 111 | "name": "stdout", 112 | "output_type": "stream", 113 | "text": [ 114 | "Input mitosis list: [(256, 650), (265, 467), (267, 514), (279, 443), (287, 458), (288, 438), (294, 744), (297, 314), (298, 627), (299, 616)]\n", 115 | "Detected mitosis list: [(255, 649), (264, 466), (266, 513), (278, 442), (286, 457), (287, 437), (293, 743), (296, 313), 
(297, 626), (298, 615)]\n", 116 | "calculated_features: [0. 0. 0. 0.00012207 0.0004425 ]\n" 117 | ] 118 | } 119 | ], 120 | "source": [ 121 | "np.random.seed(0)\n", 122 | "\n", 123 | "heatmap_size = (1024, 1024)\n", 124 | "n_mitosis = 100\n", 125 | "list_of_mitosis = [(np.random.randint(256, 768), np.random.randint(256, 768)) for _ in range(n_mitosis)]\n", 126 | "\n", 127 | "\n", 128 | "heatmap = generate_mitotic_heatmap(heatmap_size, list_of_mitosis)\n", 129 | "res = detect_and_calc_density(heatmap)\n", 130 | "mc_features = stage2_features.mc_featurizer(res)\n", 131 | "detected_mitosis = [(int(x[0]), int(x[1])) for x in res['detection']]\n", 132 | "print('Input mitosis list: ', sorted(list_of_mitosis, key=lambda x: x[0])[:10])\n", 133 | "print('Detected mitosis list:',sorted(detected_mitosis, key=lambda x: x[0])[:10])\n", 134 | "print('calculated_features:', mc_features)\n" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": { 140 | "id": "6O-4d1q5yOiX" 141 | }, 142 | "source": [ 143 | "# Nuclear Pleomorphism and Tubule Formation" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": { 150 | "id": "DHNUaryHzC2i" 151 | }, 152 | "outputs": [], 153 | "source": [ 154 | "ic_heatmap = [\n", 155 | " np.zeros((5, 4)),\n", 156 | "[\n", 157 | " [np.nan, np.nan, np.nan, np.nan,],\n", 158 | " [1, 1, 1, 0],\n", 159 | " [1, 1, 1, 0],\n", 160 | " [1, 1, 1, 0],\n", 161 | " [1, np.nan, np.nan, np.nan,],\n", 162 | "]\n", 163 | "]\n", 164 | "# NP/TF 1 heatmap: 0.4 of IC area, 0.33 of non-IC area\n", 165 | "nptf1_heatmap = [\n", 166 | " [0, 0, 0, 0],\n", 167 | " [1, 1, 1, 1],\n", 168 | " [1, 0, 0, 0],\n", 169 | " [0, 0, 0, 0],\n", 170 | " [0, 0, 0, 0],\n", 171 | "]\n", 172 | "# NP/TF 2 heatmap: 0.2 of IC area, 0.66 of non-IC area\n", 173 | "nptf2_heatmap = [\n", 174 | " [0, 0, 0, 0],\n", 175 | " [0, 0, 0, 0],\n", 176 | " [0, 1, 1, 1],\n", 177 | " [0, 0, 0, 1],\n", 178 | " [0, 0, 0, 0],\n", 179 | "]\n", 180 | "# NP/TF 
3 heatmap: 0.4 of IC area, 0.0 of non-IC area\n", 181 | "nptf3_heatmap = [\n", 182 | " [0, 0, 0, 0],\n", 183 | " [0, 0, 0, 0],\n", 184 | " [0, 0, 0, 0],\n", 185 | " [1, 1, 1, 0],\n", 186 | " [1, 0, 0, 0],\n", 187 | "]\n", 188 | "\n", 189 | "# IC Heatmap is expected to be 3D with last channel representing the probability of being invasive carcinoma.\n", 190 | "ic_heatmap = np.dstack(ic_heatmap)\n", 191 | "nptf_heatmap = np.dstack([nptf1_heatmap, nptf2_heatmap, nptf3_heatmap])\n", 192 | "\n", 193 | "tmap = {'ic_heatmap': ic_heatmap, 'heatmap': nptf_heatmap}" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": { 200 | "executionInfo": { 201 | "elapsed": 288, 202 | "status": "ok", 203 | "timestamp": 1659999030423, 204 | "user": { 205 | "displayName": "", 206 | "userId": "" 207 | }, 208 | "user_tz": 420 209 | }, 210 | "id": "NKwIr4KSgsxd", 211 | "outputId": "263e1c38-3c1b-4ae4-f3e0-c3d58e676fdb" 212 | }, 213 | "outputs": [ 214 | { 215 | "data": { 216 | "text/plain": [ 217 | "array([0.4 , 0.2 , 0.4 , 0.33333333, 0.66666667,\n", 218 | " 0. 
])" 219 | ] 220 | }, 221 | "execution_count": 6, 222 | "metadata": {}, 223 | "output_type": "execute_result" 224 | } 225 | ], 226 | "source": [ 227 | "stage2_features.np_tf_featurizer(tmap)" 228 | ] 229 | } 230 | ], 231 | "metadata": { 232 | "colab": { 233 | "collapsed_sections": [], 234 | "last_runtime": { 235 | "build_target": "", 236 | "kind": "local" 237 | }, 238 | "name": "Example Usages of Stage-2 Featurization", 239 | "provenance": [ 240 | { 241 | "file_id": "1e_Vb20SWXN6aL_IDabLRxLZppERlI5Uj", 242 | "timestamp": 1659388964753 243 | } 244 | ] 245 | }, 246 | "kernelspec": { 247 | "display_name": "Python 3", 248 | "name": "python3" 249 | }, 250 | "language_info": { 251 | "name": "python" 252 | } 253 | }, 254 | "nbformat": 4, 255 | "nbformat_minor": 0 256 | } 257 | -------------------------------------------------------------------------------- /health_acoustic_representations/hear_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "p2hQbI_77OjW" 7 | }, 8 | "source": [ 9 | "```\n", 10 | "Copyright (c) 2024, Google Inc.\n", 11 | "All rights reserved.\n", 12 | "Redistribution and use in source and binary forms, with or without modification,\n", 13 | "are permitted provided that the following conditions are met:\n", 14 | "1. Redistributions of source code must retain the above copyright notice, this\n", 15 | " list of conditions and the following disclaimer.\n", 16 | "2. Redistributions in binary form must reproduce the above copyright notice,\n", 17 | " this list of conditions and the following disclaimer in the documentation\n", 18 | " and/or other materials provided with the distribution.\n", 19 | "3. Neither the name of Google Inc. 
nor the names of its contributors\n", 20 | " may be used to endorse or promote products derived from this software without\n", 21 | " specific prior written permission.\n", 22 | "THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n", 23 | "ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n", 24 | "WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n", 25 | "DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR\n", 26 | "ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n", 27 | "(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n", 28 | "LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON\n", 29 | "ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n", 30 | "(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n", 31 | "SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n", 32 | "```\n" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": { 38 | "id": "ozCvnTeNgfSd" 39 | }, 40 | "source": [ 41 | "# Imports" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 1, 47 | "metadata": { 48 | "id": "qG8fphmPbwd-" 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "import concurrent.futures\n", 53 | "import os\n", 54 | "import random\n", 55 | "\n", 56 | "import google.auth\n", 57 | "import google.auth.transport.requests\n", 58 | "import numpy as np\n" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": { 64 | "id": "A137c1bi7Pi5" 65 | }, 66 | "source": [ 67 | "# Authentication" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": { 73 | "id": "kflxFmKocH4k" 74 | }, 75 | "source": [ 76 | "The JSON file mentioned in the cell below is created by running the following command (for service accounts)\n", 77 | "\n", 78 | "```\n", 79 | "gcloud auth application-default login 
--impersonate-service-account SERVICE_ACCT\n", 80 | "```\n", 81 | "\n", 82 | "or that command\n", 83 | "\n", 84 | "```\n", 85 | "gcloud auth application-default login\n", 86 | "```\n", 87 | "\n", 88 | "to identify with your own account.\n", 89 | "\n", 90 | "This assumes that you have first [installed](https://cloud.google.com/sdk/docs/install) `gcloud` CLI and created a service account (see [[1]](https://cloud.google.com/iam/docs/service-account-overview), [[2]](https://cloud.google.com/iam/docs/service-accounts-create)) (identified by `SERVICE_ACCT` above)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": { 97 | "id": "zEqnwMb8b2Yq" 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/path/to/your/credentials/json/file'" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 2, 107 | "metadata": { 108 | "id": "sDD7Ks4svhRV" 109 | }, 110 | "outputs": [], 111 | "source": [ 112 | "# Environment variable `GOOGLE_APPLICATION_CREDENTIALS` must be set for these\n", 113 | "# imports to work.\n", 114 | "import api_utils" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": { 120 | "id": "muMiY8lS7Q5Z" 121 | }, 122 | "source": [ 123 | "# Online predictions" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": { 129 | "id": "zScxEKPI7XFU" 130 | }, 131 | "source": [ 132 | "## With raw audio" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "id": "-zx9BkYv61WY" 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "raw_audio = np.array([[random.random() for _ in range(32000)] for _ in range(4)])\n", 144 | "embeddings = api_utils.make_prediction(\n", 145 | " endpoint_path=api_utils.RAW_AUDIO_ENDPOINT_PATH,\n", 146 | " instances=raw_audio,\n", 147 | ")" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": { 153 | "id": "Ww2Klcij7YoI" 154 | }, 
155 | "source": [ 156 | "## With GCS bucket URIs" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": { 163 | "id": "6cVSozwVSJTl" 164 | }, 165 | "outputs": [], 166 | "source": [ 167 | "gcs_creds, project = google.auth.default()\n", 168 | "api_utils.initial_token_refresh(gcs_creds)" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": { 175 | "id": "6-WC0o3oSZro" 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "# copybara:strip_begin(Internal repo)\n", 180 | "gcs_bucket_name = 'hear-demo'\n", 181 | "# copybara:strip_end_and_replace_begin\n", 182 | "# gcs_bucket_name = 'your-bucket-name'\n", 183 | "# copybara:replace_end\n", 184 | "\n", 185 | "predictions = api_utils.make_prediction(\n", 186 | " endpoint_path=api_utils.GCS_URI_ENDPOINT_PATH,\n", 187 | " # copybara:strip_begin(Internal filepaths)\n", 188 | " instances=['data/test.wav', 'data/test.wav'],\n", 189 | " # copybara:strip_end_and_replace_begin\n", 190 | " # instances=['path/to/your/file1.wav', 'path/to/your/file2.wav'],\n", 191 | " # copybara:replace_end\n", 192 | " gcs_bucket_name=gcs_bucket_name,\n", 193 | " gcs_creds=gcs_creds,\n", 194 | ")" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": { 200 | "id": "MY9wituPny0G" 201 | }, 202 | "source": [ 203 | "# If you have a lot of queries to run\n", 204 | "\n", 205 | "Example with the raw-audio endpoint (202) using ThreadPoolExecutor." 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": { 212 | "id": "Vx3gaUe0cb_h" 213 | }, 214 | "outputs": [], 215 | "source": [ 216 | "# 1000 batches of 4 clips. 
This is the format expected for the raw audio endpoint\n", 217 | "instances = np.random.uniform(size=(1000, 4, 32000)) # update with your data\n", 218 | "\n", 219 | "responses = {}\n", 220 | "\n", 221 | "with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:\n", 222 | " futures_to_batch_idx = {\n", 223 | " executor.submit(\n", 224 | " api_utils.make_prediction_with_exponential_backoff,\n", 225 | " api_utils.RAW_AUDIO_ENDPOINT_PATH,\n", 226 | " instance\n", 227 | " ): batch_idx\n", 228 | " for batch_idx, instance in enumerate(instances)\n", 229 | " }\n", 230 | "\n", 231 | " for future in concurrent.futures.as_completed(futures_to_batch_idx):\n", 232 | " batch_idx = futures_to_batch_idx[future]\n", 233 | " try:\n", 234 | " responses[batch_idx] = future.result()\n", 235 | " except Exception as e:\n", 236 | " print(\"An error occurred:\", e)" 237 | ] 238 | } 239 | ], 240 | "metadata": { 241 | "colab": { 242 | "last_runtime": { 243 | "build_target": "//medical/discovery/colab:acoustic_notebook", 244 | "kind": "private" 245 | }, 246 | "private_outputs": true, 247 | "provenance": [ 248 | { 249 | "file_id": "google_health/health_acoustic_representations/hear_demo.ipynb", 250 | "timestamp": 1721672412468 251 | } 252 | ] 253 | }, 254 | "kernelspec": { 255 | "display_name": "Python 3", 256 | "name": "python3" 257 | }, 258 | "language_info": { 259 | "name": "python" 260 | } 261 | }, 262 | "nbformat": 4, 263 | "nbformat_minor": 0 264 | } 265 | -------------------------------------------------------------------------------- /health_acoustic_representations/api_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for calling HeAR on Vertex AI.""" 2 | 3 | import datetime 4 | import time 5 | from typing import Literal 6 | 7 | import google.auth 8 | import google.auth.transport.requests 9 | from google.cloud.aiplatform.aiplatform import gapic 10 | from google.protobuf import json_format 11 | import numpy as np 12 | 
13 | from google.protobuf import struct_pb2 14 | 15 | LOCATION = 'us-central1' 16 | PROJECT_ID = '132886652110' 17 | RAW_AUDIO_ENDPOINT_ID = '202' 18 | GCS_URI_ENDPOINT_ID = '203' 19 | 20 | CLIENT_OPTIONS = {'api_endpoint': f'{LOCATION}-aiplatform.googleapis.com'} 21 | 22 | try: 23 | CLIENT = gapic.PredictionServiceClient(client_options=CLIENT_OPTIONS) 24 | except google.auth.exceptions.DefaultCredentialsError as exc: 25 | # pylint: disable=line-too-long 26 | raise ValueError( 27 | 'Note: you have not defined environment variable ' 28 | '`GOOGLE_APPLICATION_CREDENTIALS`. That variable should point to the ' 29 | 'path of your service account key file, which you can create by running ' 30 | '`gcloud auth application-default login` for your own identity or ' 31 | '`gcloud auth application-default login --impersonate-service-account SERVICE_ACCT`' 32 | 'for service accounts. This assumes that you have first installed ' 33 | 'https://cloud.google.com/sdk/docs/install) `gcloud` CLI and created a ' 34 | 'service account ' 35 | '(see https://cloud.google.com/iam/docs/service-account-overview, ' 36 | 'https://cloud.google.com/iam/docs/service-accounts-create) ' 37 | 'identified by `SERVICE_ACCT` above.' 
38 | ) from exc 39 | # pylint: enable=line-too-long 40 | 41 | 42 | RAW_AUDIO_ENDPOINT_PATH = CLIENT.endpoint_path( 43 | project=PROJECT_ID, 44 | location=LOCATION, 45 | endpoint=RAW_AUDIO_ENDPOINT_ID, 46 | ) 47 | GCS_URI_ENDPOINT_PATH = CLIENT.endpoint_path( 48 | project=PROJECT_ID, 49 | location=LOCATION, 50 | endpoint=GCS_URI_ENDPOINT_ID, 51 | ) 52 | 53 | 54 | def initial_token_refresh( 55 | gcs_creds: google.auth.credentials.Credentials, 56 | ) -> None: 57 | """Obtains short lived credentials for your GCS bucket.""" 58 | auth_req = google.auth.transport.requests.Request() 59 | gcs_creds.refresh(auth_req) 60 | if not gcs_creds.valid: 61 | raise ValueError('Unexpected error: GCS Credentials are invalid') 62 | assert isinstance(gcs_creds.valid, datetime.datetime) # for pytype 63 | time_until_expiry = ( 64 | gcs_creds.expiry - datetime.datetime.now() 65 | ).total_seconds() // 60 66 | print( 67 | 'Token will expire at' 68 | f' {gcs_creds.expiry.strftime("%Y-%m-%d %H:%M:%S")} UTC' 69 | f' ({time_until_expiry} minutes)' 70 | ) 71 | 72 | 73 | def _get_prediction_instances( 74 | image_uris: list[str], 75 | gcs_bucket_name: str, 76 | gcs_creds: google.auth.credentials.Credentials, 77 | ) -> list[struct_pb2.Value]: 78 | """Gets a list of dicts to pass as Vertex PredictionService instances.""" 79 | instances = [] 80 | for image_uri in image_uris: 81 | instance_dict = { 82 | 'bucket_name': gcs_bucket_name, 83 | 'object_uri': image_uri, 84 | 'bearer_token': gcs_creds.token, 85 | } 86 | instance = json_format.ParseDict(instance_dict, struct_pb2.Value()) 87 | instances.append(instance) 88 | return instances 89 | 90 | 91 | def make_prediction( 92 | endpoint_path: Literal[RAW_AUDIO_ENDPOINT_PATH, GCS_URI_ENDPOINT_PATH], 93 | instances: np.ndarray | list[str], 94 | gcs_bucket_name: str | None = None, 95 | gcs_creds: google.auth.credentials.Credentials | None = None, 96 | client: gapic.PredictionServiceClient = CLIENT, 97 | ) -> np.ndarray: 98 | """Makes prediction with HeAR. 
99 | 100 | Args: 101 | endpoint_path: The endpoint to use for making the prediction. 102 | instances: The instances to use for making the prediction. When endpoint is 103 | `RAW_AUDIO_ENDPOINT_PATH`, `instances` must be a numpy array of shape 104 | [num_samples, num_timesteps], where num_timesteps = 32000. When endpoint 105 | is `GCS_URI_ENDPOINT_PATH`, `instances` must be a list of strings, each 106 | string corresponding to a path to a wav file in GCS. 107 | gcs_bucket_name: The name of the GCS bucket to use for making the prediction 108 | when endpoint is `GCS_URI_ENDPOINT_PATH`. 109 | gcs_creds: The credentials to use for making the prediction when endpoint is 110 | `GCS_URI_ENDPOINT_PATH`. These must be obtained by calling `gcs_creds, 111 | project = google.auth.default()` and `initial_token_refresh(gcs_creds)`. 112 | client: The client to use for making the prediction. 113 | 114 | Returns: 115 | The predictions from the model. Embeddings of shape [num_samples, 116 | embedding_dim], where embedding_dim is 512. 117 | 118 | Raises: 119 | ValueError: If the instances don't have the right type, if the endpoint is 120 | not recognized, or if the gcs_bucket_name or gcs_creds are not specified 121 | when endpoint is `GCS_URI_ENDPOINT_PATH`. 122 | """ 123 | if endpoint_path == RAW_AUDIO_ENDPOINT_PATH: 124 | if not isinstance(instances, np.ndarray): 125 | raise ValueError( 126 | 'For endpoint `RAW_AUDIO_ENDPOINT_PATH`, `instances` must be a numpy ' 127 | f'array but was of type {type(instances)} with value {instances}' 128 | ) 129 | instances = instances.astype(float) 130 | if instances.ndim != 2 or instances.shape[-1] != 32000: 131 | raise ValueError( 132 | 'For endpoint `RAW_AUDIO_ENDPOINT_PATH`, `instances` must be a numpy ' 133 | 'array of shape [num_samples, num_timesteps], where num_timesteps = ' 134 | f'32000, but got {instances.shape}.' 
135 | ) 136 | instances = instances.tolist() 137 | elif endpoint_path == GCS_URI_ENDPOINT_PATH: 138 | if not isinstance(instances, list) and not isinstance(instances[0], str): 139 | raise ValueError( 140 | 'For endpoint `GCS_URI_ENDPOINT_PATH`, `instances` must be a list of ' 141 | 'strings.' 142 | ) 143 | if gcs_bucket_name is None: 144 | raise ValueError( 145 | 'For endpoint `GCS_URI_ENDPOINT_PATH`, `gcs_bucket_name` must be ' 146 | 'specified.' 147 | ) 148 | if gcs_creds is None: 149 | raise ValueError( 150 | 'For endpoint `GCS_URI_ENDPOINT_PATH`, `gcs_creds` must be specified.' 151 | ) 152 | instances = _get_prediction_instances( 153 | image_uris=instances, 154 | gcs_bucket_name=gcs_bucket_name, 155 | gcs_creds=gcs_creds, 156 | ) 157 | else: 158 | raise ValueError(f'Endpoint {endpoint_path} is not recognized.') 159 | response = client.predict(endpoint=endpoint_path, instances=instances) 160 | result = np.array(response.predictions) 161 | return result 162 | 163 | 164 | def make_prediction_with_exponential_backoff( 165 | endpoint_path: Literal[RAW_AUDIO_ENDPOINT_PATH, GCS_URI_ENDPOINT_PATH], 166 | instances: np.ndarray | list[str], 167 | max_retries: int = 10, 168 | base_delay_secs: float = 1, 169 | max_delay_secs: float = 60, 170 | gcs_bucket_name: str | None = None, 171 | gcs_creds: google.auth.credentials.Credentials | None = None, 172 | client: gapic.PredictionServiceClient = CLIENT, 173 | ) -> np.ndarray: 174 | """Makes prediction with exponential backoff. 175 | 176 | Args: 177 | endpoint_path: The endpoint to use for making the prediction. 178 | instances: The instances to use for making the prediction. Array of shape 179 | [num_samples, num_timesteps], where num_timesteps = 32000. 180 | max_retries: The maximum number of retries to make. 181 | base_delay_secs: The base delay in seconds. 182 | max_delay_secs: The maximum delay in seconds. 
183 | gcs_bucket_name: The name of the GCS bucket to use for making the prediction 184 | when endpoint is `GCS_URI_ENDPOINT_PATH`. 185 | gcs_creds: The credentials to use for making the prediction when endpoint is 186 | `GCS_URI_ENDPOINT_PATH`. These must be obtained by calling `gcs_creds, 187 | project = google.auth.default()` and `initial_token_refresh(gcs_creds)`. 188 | client: The client to use for making the prediction. 189 | 190 | Returns: 191 | The predictions from the model. Embeddings of shape [num_samples, 192 | embedding_dim], where embedding_dim is 512. 193 | 194 | Raises: 195 | ValueError: If the endpoint is not recognized,or if the query failed too 196 | many times and the maximum of retries is reached. 197 | """ 198 | if endpoint_path not in {RAW_AUDIO_ENDPOINT_PATH, GCS_URI_ENDPOINT_PATH}: 199 | raise ValueError( 200 | f'Endpoint must be one of {RAW_AUDIO_ENDPOINT_PATH} or' 201 | f' {GCS_URI_ENDPOINT_PATH}, but got {endpoint_path}.' 202 | ) 203 | 204 | retries = 0 205 | while retries < max_retries: 206 | try: 207 | result = make_prediction( 208 | endpoint_path=endpoint_path, 209 | instances=instances, 210 | client=client, 211 | gcs_bucket_name=gcs_bucket_name, 212 | gcs_creds=gcs_creds, 213 | ) 214 | return result 215 | except Exception as e: # pylint: disable=broad-except 216 | retries += 1 217 | if retries == max_retries: 218 | raise ValueError(f'Max retries reached. Last error: {e}') from e 219 | 220 | delay = min(max_delay_secs, base_delay_secs * (2 ** (retries - 1))) 221 | 222 | print(f'Attempt {retries} failed. 
Retrying in {delay} seconds...') 223 | time.sleep(delay) 224 | 225 | raise ValueError( 226 | 'Unexpected error in `make_prediction_with_exponential_backoff`' 227 | ) 228 | -------------------------------------------------------------------------------- /ct_dicom/gcp/dicomweb_beam.py: -------------------------------------------------------------------------------- 1 | """Beam wrappers for DICOMweb API helpers in `dicomweb.py`.""" 2 | 3 | import dataclasses 4 | from typing import Iterable, Optional, Tuple 5 | 6 | from absl import logging 7 | import apache_beam as beam 8 | from apache_beam import pvalue 9 | from google.auth.transport import requests 10 | 11 | from gcp import auth 12 | from gcp import dicomweb 13 | 14 | # CHC DICOMweb API search query limits: 15 | # https://cloud.google.com/healthcare-api/docs/dicom#search_parameters 16 | _STUDY_SEARCH_QUERY_LIMIT = 5000 17 | _SERIES_SEARCH_QUERY_LIMIT = 5000 18 | 19 | 20 | def to_csv_row(values: Iterable[str]) -> str: 21 | """Converts sequence of strings to a comma-separated, quoted CSV row.""" 22 | return ','.join(f'"{value}"' for value in values) 23 | 24 | 25 | @dataclasses.dataclass(frozen=True) 26 | class ChcDicomStore: 27 | """URI components to identify a CHC DICOM Store. 28 | 29 | Attributes: 30 | project_id: The GCP Project name that hosts the DICOM Store. 31 | location: The DICOM Dataset location (region) which contains the DICOM 32 | Store. 33 | dataset_id: The name of the DICOM Dataset which contains the DICOM Store. 34 | dicom_store_id: The name of the CHC DICOM Store. 35 | """ 36 | 37 | project_id: str 38 | location: str 39 | dataset_id: str 40 | dicom_store_id: str 41 | 42 | 43 | @dataclasses.dataclass(frozen=True) 44 | class SeriesScopeMetadata: 45 | """DICOM metadata shared at the scope of a Series Instance UID. 46 | 47 | Attributes: 48 | study_instance_uid: The DICOM Study Instance UID Attribute (0020, 000D) 49 | value. 
@dataclasses.dataclass(frozen=True)
class SeriesScopeDICOMs:
  """DICOM bytes plus shared metadata at the Series Instance UID scope.

  Attributes:
    metadata: Study/Series Instance UID metadata shared by every DICOM in
      `dicoms`.
    dicoms: Tuple of DICOM blobs, all carrying the same Study Instance UID
      Attribute (0020, 000D) value and Series Instance UID Attribute
      (0020, 000E) value.
  """

  metadata: SeriesScopeMetadata
  dicoms: Tuple[bytes, ...]

  @property
  def key(self) -> str:
    """Key for building key-value pairs; delegates to the metadata key."""
    return self.metadata.key


class BaseDoFnWithAuthorizedSessionForChc(beam.DoFn):
  """Beam DoFn base that manages Credentials and a Session for CHC APIs.

  Fresh Google Auth Credentials and an Authorized Session (Cloud Healthcare
  API scope) are created before each bundle is processed. Neither Credentials
  nor Sessions are Beam-serializable, so each worker builds its own; the
  command-line flags that drive credential creation do get serialized and are
  therefore available to workers for initialization.

  Re-creating Credentials on every bundle may be more than strictly needed,
  but some Credentials have timebound validity, and managing their freshness
  inside a DoFn complicates implementations; per-bundle re-initialization
  ought to suffice for most cases.
  """

  @property
  def session(self) -> Optional[requests.AuthorizedSession]:
    """The CHC-scoped Session, or None before the first `start_bundle()`.

    Returns:
      An initialized Session, if available.
    """
    # getattr covers the case where start_bundle() has not yet created it.
    return getattr(self, '_session', None)

  def start_bundle(self):
    """Initializes an Authorized Session for the CHC API Scope."""
    self._session = dicomweb.create_authorized_session(
        auth.create_gcp_credentials()
    )
98 | """ 99 | 100 | @property 101 | def session(self) -> Optional[requests.AuthorizedSession]: 102 | """Returns an initialized Session at CHC API Scope. 103 | 104 | Unless the Session is initialized by at least one call to `start_bundle()`, 105 | it returns `None`. 106 | 107 | Returns: 108 | An initialized Session, if available. 109 | """ 110 | try: 111 | return self._session 112 | except AttributeError: 113 | return None 114 | 115 | def start_bundle(self): 116 | """Initializes an Authorized Session for the CHC API Scope.""" 117 | credentials = auth.create_gcp_credentials() 118 | self._session = dicomweb.create_authorized_session(credentials) 119 | 120 | 121 | class QueryStudyInstanceUidsFn(BaseDoFnWithAuthorizedSessionForChc): 122 | """Beam DoFn to query all Study Instance UIDs within a CHC DICOM Store. 123 | 124 | All Study Instance UIDs are generated by one Beam worker. Study Instance 125 | UIDs (0020,000D) are internally retrieved lazily in batches of up to 5000 126 | values, but emitted one-by-one in the `process()` method. 127 | 128 | This also means that if retrieval takes sufficiently long, some Credentials 129 | may expire. Other than explicit Credentials and Session management, a 130 | workaround could be to force the generation of all UIDs before feeding them to 131 | the downstream Beam pipeline components by adding a Shuffle operation to the 132 | output of this class. 133 | """ 134 | 135 | def __init__(self, dicom_store: ChcDicomStore) -> None: 136 | """Initializes an instance. 137 | 138 | Args: 139 | dicom_store: The CHC DICOM Store URI components. 
140 | """ 141 | super().__init__() 142 | self._dicom_store = dicom_store 143 | 144 | def process(self, _) -> Iterable[str]: 145 | """Generates all Study Instance UIDs within the CHC DICOM Store.""" 146 | assert self.session is not None 147 | return dicomweb.search_study_instance_uids( 148 | self._dicom_store.project_id, 149 | self._dicom_store.location, 150 | self._dicom_store.dataset_id, 151 | self._dicom_store.dicom_store_id, 152 | self.session, 153 | limit=_STUDY_SEARCH_QUERY_LIMIT, 154 | ) 155 | 156 | 157 | class QuerySeriesInstanceUidsFn(BaseDoFnWithAuthorizedSessionForChc): 158 | """Beam DoFn to query all Series Instance UIDs within a Study Instance UID.""" 159 | 160 | def __init__(self, dicom_store: ChcDicomStore) -> None: 161 | """Initializes an instance. 162 | 163 | Args: 164 | dicom_store: The CHC DICOM Store URI components. 165 | """ 166 | super().__init__() 167 | self._dicom_store = dicom_store 168 | 169 | def process(self, study_instance_uid: str) -> Iterable[SeriesScopeMetadata]: 170 | """Generates all Series Instance UIDs within the input Study Instance UID. 171 | 172 | Args: 173 | study_instance_uid: The DICOM Study Instance UID (0020, 000D) Attribute 174 | value to scope the query within. 175 | 176 | Yields: 177 | All Series Instance UID (0020, 000E) Attribute value within all DICOMs 178 | matching the Study Instance UID Attribute value of `study_instance_uid`. 
179 | """ 180 | assert self.session is not None 181 | for series_instance_uid in dicomweb.search_series_instance_uids( 182 | self._dicom_store.project_id, 183 | self._dicom_store.location, 184 | self._dicom_store.dataset_id, 185 | self._dicom_store.dicom_store_id, 186 | self.session, 187 | study_instance_uid, 188 | limit=_SERIES_SEARCH_QUERY_LIMIT, 189 | ): 190 | yield SeriesScopeMetadata(study_instance_uid, series_instance_uid) 191 | 192 | 193 | class DownloadMultipartDicomSeriesFn(BaseDoFnWithAuthorizedSessionForChc): 194 | """Beam DoFn to download all DICOMs with the same Series and Study UIDs. 195 | 196 | The DICOMs are downloaded in one multi-part GET request for efficiency. It 197 | reduces the network and DICOM Store latency involved in otherwise issuing 198 | multiple GET requests, as well as the CHC API Quota usage. 199 | 200 | The downloaded DICOM bytes and information about any errors during download 201 | are routed to separate, tagged PCollections. 202 | """ 203 | 204 | ERROR_OUTPUT_TAG = 'errors' 205 | 206 | def __init__(self, dicom_store: ChcDicomStore) -> None: 207 | """Initializes an instance. 208 | 209 | Args: 210 | dicom_store: The CHC DICOM Store URI components. 211 | """ 212 | super().__init__() 213 | self._dicom_store = dicom_store 214 | 215 | def process(self, series_scope_metadata: SeriesScopeMetadata): 216 | """Emits a DICOM collection sharing the same Study and Series UIDs. 217 | 218 | In case of errors during download, the error string is output to a Beam 219 | TaggedOutput with tag "errors", in the form of a CSV row with entries: 220 | "","","" 221 | 222 | Args: 223 | series_scope_metadata: The Study Instance UID (0020, 000D) and Series 224 | Instance UID (0020, 000E) Attribute values for the DICOMs to download. 225 | 226 | Yields: 227 | A single collection of DICOMs sharing the same Study and Series Instance 228 | UID if download was successful. A CSV row to a TaggedOutput, otherwise. 
229 | """ 230 | assert self.session is not None 231 | try: 232 | yield SeriesScopeDICOMs( 233 | series_scope_metadata, 234 | tuple( 235 | dicomweb.download_multipart_dicom_series( 236 | self._dicom_store.project_id, 237 | self._dicom_store.location, 238 | self._dicom_store.dataset_id, 239 | self._dicom_store.dicom_store_id, 240 | self.session, 241 | series_scope_metadata.study_instance_uid, 242 | series_scope_metadata.series_instance_uid, 243 | ) 244 | ), 245 | ) 246 | except Exception as e: # pylint: disable=broad-exception-caught 247 | logging.error('Error downloading %r', series_scope_metadata.key) 248 | yield pvalue.TaggedOutput( 249 | self.ERROR_OUTPUT_TAG, 250 | to_csv_row(( 251 | series_scope_metadata.study_instance_uid, 252 | series_scope_metadata.series_instance_uid, 253 | *e.args, 254 | )), 255 | ) 256 | -------------------------------------------------------------------------------- /fitbit_pregnancy/figure_generate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, Google Inc. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions 6 | # are met: 7 | # 8 | # 1. Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # 11 | # 2. Redistributions in binary form must reproduce the above copyright 12 | # notice, this list of conditions and the following disclaimer in the 13 | # documentation and/or other materials provided with the distribution. 14 | # 15 | # 3. Neither the name of Google Inc. nor the names of its 16 | # contributors may be used to endorse or promote products derived from this 17 | # software without specific prior written permission. 
18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 23 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 24 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 25 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 26 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 29 | # POSSIBILITY OF SUCH DAMAGE. 30 | 31 | import os 32 | import matplotlib.pyplot as plt 33 | import numpy as np 34 | import pandas as pd 35 | 36 | data_dir = 'data/' 37 | 38 | figure_dir = 'figures/' 39 | if not os.path.exists(figure_dir): 40 | os.makedirs(figure_dir) 41 | 42 | ### Figure 2A 43 | 44 | data = pd.read_csv(data_dir + 'figure_2_a_data.csv') 45 | 46 | plt.figure() 47 | plt.bar(data['bin'], data['count'], color='k') 48 | plt.grid('on') 49 | xlab = 'Pregnancy length (weeks)' 50 | ylab = 'Participants' 51 | plt.ylabel(ylab) 52 | plt.xlabel(xlab) 53 | plt.title('Figure 2a') 54 | file_path = os.path.join(figure_dir, 'Figure 2a.png') 55 | plt.savefig(file_path, transparent=True) 56 | 57 | 58 | ### Figure 2B 59 | 60 | data = pd.read_csv(data_dir + 'figure_2_b_data.csv') 61 | 62 | plt.figure() 63 | plt.plot(data['week'], data['data'], '-ok') 64 | plt.grid('on') 65 | xlab = 'Time pregnant (weeks)' 66 | ylab = 'Participants (%)' 67 | plt.ylabel(ylab) 68 | plt.xlabel(xlab) 69 | plt.title('Figure 2b') 70 | file_path = os.path.join(figure_dir, 'Figure 2b.png') 71 | plt.savefig(file_path, transparent=True) 72 | 73 | 74 | ### Figure 2C 75 | 76 | data = pd.read_csv(data_dir + 
'figure_2_c_data.csv') 77 | 78 | plt.figure() 79 | plt.plot(data['week'], data['mean'], '-ok') 80 | plt.grid('on') 81 | xlab = 'Time (weeks)' 82 | ylab = 'Participants (%)' 83 | plt.ylabel(ylab) 84 | plt.xlabel(xlab) 85 | plt.title('Figure 2c') 86 | file_path = os.path.join(figure_dir, 'Figure 2c.png') 87 | plt.savefig(file_path, transparent=True) 88 | 89 | ### Figure 3 90 | 91 | data = pd.read_csv(data_dir + 'figure_3_data.csv') 92 | 93 | plt.figure() 94 | plt.plot(data['day'], data['mean'], '-k') 95 | plt.fill_between( 96 | data['day'], 97 | list(data['mean'].values - data['std'].values), 98 | list(data['mean'].values + data['std'].values), 99 | alpha=0.2, 100 | color='k', 101 | ) 102 | 103 | plt.grid('on') 104 | xlab = 'Time pregnant (days)' 105 | ylab = 'Heart rate change (bpm)' 106 | plt.ylabel(ylab) 107 | plt.xlabel(xlab) 108 | plt.title('Figure 3') 109 | file_path = os.path.join(figure_dir, 'Figure 3.png') 110 | plt.savefig(file_path, transparent=True) 111 | 112 | 113 | ### Figure 4A 114 | 115 | data = pd.read_csv(data_dir + 'figure_4_a_data.csv') 116 | 117 | plt.figure() 118 | plt.plot(data['week'], data['mean'], '-ok') 119 | plt.fill_between( 120 | data['week'], 121 | list(data['mean'].values - data['std'].values), 122 | list(data['mean'].values + data['std'].values), 123 | alpha=0.2, 124 | color='k', 125 | ) 126 | 127 | plt.grid('on') 128 | xlab = 'Time pregnant (weeks)' 129 | ylab = 'Normalized TIB (min)' 130 | plt.ylabel(ylab) 131 | plt.xlabel(xlab) 132 | plt.title('Figure 4a') 133 | file_path = os.path.join(figure_dir, 'Figure 4a.png') 134 | plt.savefig(file_path, transparent=True) 135 | 136 | 137 | ### Figure 4B 138 | 139 | data = pd.read_csv(data_dir + 'figure_4_b_data.csv') 140 | 141 | plt.figure() 142 | plt.plot(data['week'], data['percentile_10'], '-k', alpha=0.2) 143 | plt.plot(data['week'], data['percentile_25'], '-k', alpha=0.5) 144 | plt.plot(data['week'], data['percentile_50'], '-k', alpha=0.9) 145 | plt.plot(data['week'], 
data['percentile_75'], '-k', alpha=0.5) 146 | plt.plot(data['week'], data['percentile_90'], '-k', alpha=0.2) 147 | 148 | ##uncomment the line below to add the confidence interval to plot 149 | # plt.fill_between(data['week'],list(data['mean'].values-data['std'].values),list(data['mean'].values+data['std'].values),alpha=0.2,color='k') 150 | 151 | plt.grid('on') 152 | xlab = 'Time pregnant (weeks)' 153 | ylab = 'Total TIB (min)' 154 | plt.ylabel(ylab) 155 | plt.xlabel(xlab) 156 | plt.title('Figure 4b') 157 | file_path = os.path.join(figure_dir, 'Figure 4b.png') 158 | plt.savefig(file_path, transparent=True) 159 | 160 | ### Figure 4C 161 | 162 | data = pd.read_csv(data_dir + 'figure_4_c_data.csv') 163 | 164 | plt.figure() 165 | plt.plot(data['week'], data['mean'], '-ok') 166 | plt.fill_between( 167 | data['week'], 168 | list(data['mean'].values - data['std'].values), 169 | list(data['mean'].values + data['std'].values), 170 | alpha=0.2, 171 | color='k', 172 | ) 173 | 174 | plt.grid('on') 175 | xlab = 'Time pregnant (weeks)' 176 | ylab = 'Participants(%)' 177 | plt.ylabel(ylab) 178 | plt.xlabel(xlab) 179 | plt.title('Figure 4c') 180 | file_path = os.path.join(figure_dir, 'Figure 4c.png') 181 | plt.savefig(file_path, transparent=True) 182 | 183 | 184 | ### Figure 4 D 185 | 186 | data = pd.read_csv(data_dir + 'figure_4_d_data.csv') 187 | 188 | plt.figure() 189 | plt.plot(data['week'], data['mean'], '-k') 190 | plt.fill_between( 191 | data['week'], 192 | list(data['mean'].values - data['std'].values), 193 | list(data['mean'].values + data['std'].values), 194 | alpha=0.2, 195 | color='k', 196 | ) 197 | 198 | plt.grid('on') 199 | xlab = 'Time pregnant (weeks)' 200 | ylab = 'Normalized TST (min)' 201 | plt.ylabel(ylab) 202 | plt.xlabel(xlab) 203 | plt.title('Figure 4d') 204 | file_path = os.path.join(figure_dir, 'Figure 4d.png') 205 | plt.savefig(file_path, transparent=True) 206 | 207 | 208 | ### Figure 4 E 209 | 210 | data = pd.read_csv(data_dir + 'figure_4_e_data.csv') 211 
| allcolor = ( 212 | np.array([[0, 144, 181], [32, 133, 78], [255, 135, 39], [188, 60, 41]]) 213 | / 256.0 214 | ) 215 | 216 | plt.figure() 217 | plt.plot(data['week'], data['deep'], color=allcolor[0], label='deep') 218 | plt.plot(data['week'], data['light'], color=allcolor[1], label='light') 219 | plt.plot(data['week'], data['rem'], color=allcolor[2], label='rem') 220 | plt.plot(data['week'], data['wake'], color=allcolor[3], label='wake') 221 | 222 | plt.grid('on') 223 | xlab = 'Time pregnant (weeks)' 224 | ylab = 'Normalized sleep stage (min)' 225 | plt.ylabel(ylab) 226 | plt.xlabel(xlab) 227 | plt.title('Figure 4e') 228 | file_path = os.path.join(figure_dir, 'Figure 4e.png') 229 | plt.savefig(file_path, transparent=True) 230 | 231 | 232 | ### Figure 4 F 233 | 234 | data = pd.read_csv(data_dir + 'figure_4_f_data.csv') 235 | 236 | plt.figure() 237 | plt.plot(data['week'], data['percentile_10'], '-k', alpha=0.2) 238 | plt.plot(data['week'], data['percentile_25'], '-k', alpha=0.5) 239 | plt.plot(data['week'], data['percentile_50'], '-k', alpha=0.9) 240 | plt.plot(data['week'], data['percentile_75'], '-k', alpha=0.5) 241 | plt.plot(data['week'], data['percentile_90'], '-k', alpha=0.2) 242 | 243 | ##uncomment the line below to add the confidence interval to plot 244 | # plt.fill_between(data['week'],list(data['mean'].values-data['std'].values),list(data['mean'].values+data['std'].values),alpha=0.2,color='k') 245 | 246 | plt.grid('on') 247 | xlab = 'Time pregnant (weeks)' 248 | ylab = 'Sleep efficiency (%)' 249 | plt.ylabel(ylab) 250 | plt.xlabel(xlab) 251 | plt.title('Figure 4f') 252 | file_path = os.path.join(figure_dir, 'Figure 4f.png') 253 | plt.savefig(file_path, transparent=True) 254 | 255 | 256 | ### Figure 5 A 257 | 258 | data = pd.read_csv(data_dir + 'figure_5_a_data.csv') 259 | 260 | plt.figure() 261 | plt.plot(data['week'], data['mean'], '-ok') 262 | plt.fill_between( 263 | data['week'], 264 | list(data['mean'].values - data['std'].values), 265 | 
list(data['mean'].values + data['std'].values), 266 | alpha=0.2, 267 | color='k', 268 | ) 269 | 270 | plt.grid('on') 271 | xlab = 'Time (weeks)' 272 | ylab = 'Participants (%)' 273 | plt.ylabel(ylab) 274 | plt.xlabel(xlab) 275 | plt.title('Figure 5a') 276 | file_path = os.path.join(figure_dir, 'Figure 5a.png') 277 | plt.savefig(file_path, transparent=True) 278 | 279 | 280 | ### Figure 5 B 281 | 282 | data = pd.read_csv(data_dir + 'figure_5_b_data.csv') 283 | 284 | plt.figure() 285 | plt.plot(data['week'], data['percentile_10'], '-k', alpha=0.2) 286 | plt.plot(data['week'], data['percentile_25'], '-k', alpha=0.5) 287 | plt.plot(data['week'], data['percentile_50'], '-k', alpha=0.9) 288 | plt.plot(data['week'], data['percentile_75'], '-k', alpha=0.5) 289 | plt.plot(data['week'], data['percentile_90'], '-k', alpha=0.2) 290 | 291 | plt.grid('on') 292 | xlab = 'Time (weeks)' 293 | ylab = 'Total TIB (min)' 294 | plt.ylabel(ylab) 295 | plt.xlabel(xlab) 296 | plt.title('Figure 5b') 297 | file_path = os.path.join(figure_dir, 'Figure 5b.png') 298 | plt.savefig(file_path, transparent=True) 299 | 300 | 301 | ### Figure 5 c 302 | 303 | data = pd.read_csv(data_dir + 'figure_5_c_data.csv') 304 | 305 | plt.figure() 306 | plt.plot(data['week'], data['mean'], '-ok') 307 | plt.fill_between( 308 | data['week'], 309 | list(data['mean'].values - data['std'].values), 310 | list(data['mean'].values + data['std'].values), 311 | alpha=0.2, 312 | color='k', 313 | ) 314 | 315 | plt.grid('on') 316 | xlab = 'Time (weeks)' 317 | ylab = 'Participants (%)' 318 | plt.ylabel(ylab) 319 | plt.xlabel(xlab) 320 | plt.title('Figure 5c') 321 | file_path = os.path.join(figure_dir, 'Figure 5c.png') 322 | plt.savefig(file_path, transparent=True) 323 | 324 | 325 | ### Figure 5 D 326 | 327 | data = pd.read_csv(data_dir + 'figure_5_d_data.csv') 328 | 329 | plt.figure() 330 | plt.plot(data['week'], data['mean'], '-ok') 331 | plt.fill_between( 332 | data['week'], 333 | list(data['mean'].values - 
data['std'].values), 334 | list(data['mean'].values + data['std'].values), 335 | alpha=0.2, 336 | color='k', 337 | ) 338 | 339 | plt.grid('on') 340 | xlab = 'Time (weeks)' 341 | ylab = 'Participants (%)' 342 | plt.ylabel(ylab) 343 | plt.xlabel(xlab) 344 | plt.title('Figure 5d') 345 | file_path = os.path.join(figure_dir, 'Figure 5d.png') 346 | plt.savefig(file_path, transparent=True) 347 | --------------------------------------------------------------------------------