├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── config.yaml ├── hptuning_config.yaml ├── preprocess ├── 1000_genomes_metadata.jinja ├── 1000_genomes_phase3_b37.jinja ├── 1000_genomes_phase3_b37_limit10.jinja ├── 1000_genomes_phase3_b37_snps_only.jinja ├── pgp_data_b37.jinja ├── pgp_data_b37_1kg_variants_only.jinja ├── pgp_data_b37_limit10.jinja ├── pgp_metadata.jinja ├── platinum_genomes_b37.jinja ├── platinum_genomes_metadata.jinja ├── sgdp_data_b37.jinja ├── sgdp_metadata.jinja └── sgdp_metadata_remap_labels.jinja ├── setup.py └── trainer ├── __init__.py ├── ancestry_metadata_encoder.py ├── feature_encoder.py ├── feature_encoder_test.py ├── preprocess_data.py ├── revise_preprocessed_data.py ├── util.py ├── variant_encoder.py ├── variant_encoder_test.py └── variants_inference.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Want to contribute? Great! First, read this page (including the small print at the end). 2 | 3 | ### Before you contribute 4 | Before we can use your code, you must sign the 5 | [Google Individual Contributor License 6 | Agreement](https://cla.developers.google.com/about/google-individual) 7 | (CLA), which you can do online. The CLA is necessary mainly because you own the 8 | copyright to your changes, even after your contribution becomes part of our 9 | codebase, so we need your permission to use and distribute your code. We also 10 | need to be sure of various other things—for instance that you'll tell us if you 11 | know that your code infringes on other people's patents. You don't have to sign 12 | the CLA until after you've submitted your code for review and a member has 13 | approved it, but you must do it before we can put your code into our codebase. 
14 | Before you start working on a larger contribution, you should get in touch with 15 | us first through the issue tracker with your idea so that we can help out and 16 | possibly guide you. Coordinating up front makes it much easier to avoid 17 | frustration later on. 18 | 19 | ### Code reviews 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. 22 | 23 | ### The small print 24 | Contributions made by corporations are covered by a different agreement than 25 | the one above, the 26 | [Software Grant and Corporate Contributor License 27 | Agreement](https://cla.developers.google.com/about/google-corporate). 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Disclaimer 2 | 3 | This is not an official Google product. 4 | 5 | cloudml-examples 6 | ================ 7 | 8 | This repository contains an example of applying machine learning to genomic data using [Cloud Machine Learning Engine (Cloud ML Engine)](https://cloud.google.com/ml-engine/). The learning problem demonstrated is an ancestry inference. 
Identification of genetic ancestry is important for adjusting putative genetic associations with traits that may be driven largely through population structure. It is also important for assessing the accuracy of self-reported ancestry. 9 | 10 | The instructions below train a model to predict 1000 Genomes super population labels. The training data are the [1000 Genomes](https://cloud.google.com/genomics/docs/public-datasets/1000-genomes) phase 3 variants stored in [Google BigQuery](https://cloud.google.com/bigquery/). The validation data are the [Simons Genome Diversity Project](https://cloud.google.com/genomics/docs/public-datasets/simons) variants stored in BigQuery. The training data is pre-processed using pipelines written with [Apache Beam](https://beam.apache.org/) and executed on [Google Cloud Dataflow](https://cloud.google.com/dataflow/docs/). 11 | 12 | This approach uses continuous vectors of genomic variants for analysis and inference on Machine Learning pipelines. For related work, see also [Diet Networks: Thin Parameters for Fat Genomics](https://openreview.net/pdf?id=Sk-oDY9ge) Romero et. al. 13 | 14 | This is a non-trivial example in terms of cost (it may consume a large portion 15 | of the [free trial credit](https://cloud.google.com/free/)) and also in terms of 16 | the variety of tools used. We suggest working through the introductory materials 17 | for each tool before working with the code in this repository. 18 | 19 | ## Blog Post 20 | 21 | [Genomic ancestry inference with deep learning](https://cloud.google.com/blog/big-data/2017/09/genomic-ancestry-inference-with-deep-learning) blog post provides a great overview of the end-to-end reference implementation. It also links to pre-processed data and trained model in Google Cloud Storage if you would like to skip some of the steps below. 22 | 23 | ## Getting Started 24 | 25 | 1. [Set up the Dataflow SDK for Python](https://cloud.google.com/dataflow/docs/quickstarts/quickstart-python) 26 | 27 | 2. [Set up Cloud ML Engine](https://cloud.google.com/ml-engine/docs/quickstarts/command-line) 28 | 29 | 3. This code depends on a few additional python packages. If you are 30 | using [virtualenv](https://virtualenv.pypa.io/), the following commands will 31 | create a virtualenv, activate it, and install those dependencies. 32 | 33 | ``` 34 | virtualenv --system-site-packages ~/virtualEnvs/tensorflow 35 | source ~/virtualEnvs/tensorflow/bin/activate 36 | pip2.7 install --upgrade pip jinja2 pyfarmhash apache-beam[gcp] tensorflow 37 | ``` 38 | 39 | 4. Set some environment variables to make copy/pasting commands a bit easier. 40 | 41 | * `PROJECT_ID=` 42 | * `BUCKET=gs://` this should be the **regional** bucket you 43 | created during Cloud ML Engine setup. 44 | 45 | 5. git clone this repository and change into its directory 46 | 47 | ## Pre-processing using Apache Beam 48 | 49 | * See if a query for the data you want to work with is already available in the [`preprocess`](./preprocess) directory. If not: 50 | * See also [Select Genomic Data to work with](https://cloud.google.com/genomics/docs/public-datasets/) for other public data and how to load your own data. 51 | * Write jinja files containing the queries for your desired data. 52 | * Run a [`preprocess_data.py`](./trainer/preprocess_data.py) pipeline to convert 53 | the data from BigQuery to TFRecords in Cloud Storage. 
For example: 54 | 55 | Preprocess training data: 56 | 57 | ``` 58 | python2.7 -m trainer.preprocess_data \ 59 | --setup_file ./setup.py \ 60 | --output ${BUCKET}/1000-genomes \ 61 | --project ${PROJECT_ID} \ 62 | --metadata ./preprocess/1000_genomes_metadata.jinja \ 63 | --input ./preprocess/1000_genomes_phase3_b37.jinja \ 64 | --runner DataflowRunner \ 65 | --worker_machine_type n1-highmem-8 \ 66 | --no_hethom_words 67 | ``` 68 | 69 | Preprocess validation data: 70 | 71 | ``` 72 | python2.7 -m trainer.preprocess_data \ 73 | --setup_file ./setup.py \ 74 | --output ${BUCKET}/sgdp \ 75 | --project ${PROJECT_ID} \ 76 | --metadata ./preprocess/sgdp_metadata.jinja \ 77 | --input ./preprocess/sgdp_data_b37.jinja \ 78 | --runner DataflowRunner \ 79 | --no_hethom_words 80 | ``` 81 | 82 | ## Training using CloudML 83 | 84 | ``` 85 | EXAMPLES_SUBDIR= 86 | JOB_NAME=super_population_1000_genomes 87 | gcloud ai-platform jobs submit training ${JOB_NAME} \ 88 | --project ${PROJECT_ID} \ 89 | --region us-central1 \ 90 | --config config.yaml \ 91 | --package-path ./trainer \ 92 | --module-name trainer.variants_inference \ 93 | --job-dir ${BUCKET}/models/${JOB_NAME} \ 94 | --runtime-version 1.2 \ 95 | -- \ 96 | --input_dir ${BUCKET}/1000-genomes/${EXAMPLES_SUBDIR}/ \ 97 | --export_dir ${BUCKET}/models/${JOB_NAME} \ 98 | --sparse_features all_not_x_y \ 99 | --num_classes 5 \ 100 | --eval_labels="AFR,AMR,EAS,EUR,SAS" \ 101 | --target_field super_population \ 102 | --hidden_units 20 \ 103 | --num_buckets 50000 \ 104 | --num_train_steps 10000 105 | ``` 106 | 107 | If training results in an out-of-memory exception, add the argument `--num_eval_steps 1` to the command line. 108 | 109 | To inspect the behavior of training, launch TensorBoard and point it at the summary logs produced during training; this works both while the job is running and after it has finished. 110 | 111 | ``` 112 | tensorboard --port=8080 \ 113 | --logdir ${BUCKET}/models/${JOB_NAME}/ 114 | ``` 115 | 116 | *Tip: When running all of these commands from [Google Cloud Shell](https://cloud.google.com/shell/docs/), the [web preview](https://cloud.google.com/shell/docs/using-web-preview) feature can be used to view the TensorBoard user interface.* 117 | 118 | The model generally converges in fewer than 10,000 steps, which you can confirm in TensorBoard. Training can be stopped early to avoid overfitting. To obtain the "saved model" needed for prediction, start training again from the exact same output directory (it will pick up where it left off) and have it run for a few more steps than it has already completed. 119 | 120 | For example, if the job was cancelled after completing step 5,632, the following command will trigger a saved-model export. 121 | 122 | ``` 123 | gcloud ai-platform jobs submit training ${JOB_NAME}_save_model \ 124 | ... 125 | --num_train_steps 5700 126 | ``` 127 | 128 | ## Hyperparameter tuning 129 | 130 | Cloud ML Engine provides out-of-the-box support for [Hyperparameter 131 | tuning](https://cloud.google.com/ml-engine/docs/concepts/hyperparameter-tuning-overview). Running a hyperparameter tuning job is exactly the same as running a training job, except that you also provide tuning options in [TrainingInput](https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#traininginput).
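In this repository those options come from [`hptuning_config.yaml`](./hptuning_config.yaml), passed via `--config` in the command below. Its `hyperparameters` block asks Cloud ML Engine to maximize the `precision_at_1` metric over ten trials while tuning the learning rate:

```
  hyperparameters:
    goal: MAXIMIZE
    hyperparameterMetricTag: precision_at_1
    maxTrials: 10
    maxParallelTrials: 4
    params:
      - parameterName: learning_rate
        type: DOUBLE
        minValue: 0.0001
        maxValue: 0.01
        scaleType: UNIT_REVERSE_LOG_SCALE
```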
132 | 133 | ``` 134 | EXAMPLES_SUBDIR= 135 | gcloud ai-platform jobs submit training ${JOB_NAME} \ 136 | --project ${PROJECT_ID} \ 137 | --region us-central1 \ 138 | --package-path ./trainer \ 139 | --module-name trainer.variants_inference \ 140 | --job-dir ${BUCKET}/hptuning/${JOB_NAME} \ 141 | --config hptuning_config.yaml \ 142 | -- \ 143 | --input_dir ${BUCKET}/1000-genomes/${EXAMPLES_SUBDIR}/examples* \ 144 | --export_dir ${BUCKET}/hptuning/${JOB_NAME} \ 145 | --sparse_features all_not_x_y \ 146 | --num_classes 5 \ 147 | --eval_labels="AFR,AMR,EAS,EUR,SAS" \ 148 | --target_field super_population \ 149 | --hidden_units 20 \ 150 | --num_buckets 50000 \ 151 | --num_train_steps 10000 152 | ``` 153 | 154 | ## Batch predict 155 | 156 | ``` 157 | EXAMPLES_SUBDIR= 158 | EXPORT_SUBDIR= 159 | gcloud --project ${PROJECT_ID} ai-platform jobs submit \ 160 | prediction ${JOB_NAME}_predict \ 161 | --model-dir \ 162 | ${BUCKET}/models/${JOB_NAME}/export/Servo/${EXPORT_SUBDIR} \ 163 | --input-paths ${BUCKET}/sgdp/${EXAMPLES_SUBDIR}/examples* \ 164 | --output-path ${BUCKET}/predictions/${JOB_NAME} \ 165 | --region us-central1 \ 166 | --data-format TF_RECORD_GZIP 167 | ``` 168 | 169 | If prediction yields an error regarding the size of the saved model, request more quota for your project. 170 | 171 | ## Examine the prediction results 172 | 173 | For Simons Genome Diversity project data, one might examine the prediction results as follows: 174 | 175 | ``` 176 | bq load --source_format NEWLINE_DELIMITED_JSON --autodetect \ 177 | YOUR-DATASET.sgdp_ancestry_prediction_results \ 178 | ${BUCKET}/predictions/prediction.results* 179 | ``` 180 | 181 | ``` 182 | SELECT 183 | key, 184 | probabilities[ORDINAL(1)] AS AFR, 185 | probabilities[ORDINAL(2)] AS AMR, 186 | probabilities[ORDINAL(3)] AS EAS, 187 | probabilities[ORDINAL(4)] AS EUR, 188 | probabilities[ORDINAL(5)] AS SAS, 189 | info.* 190 | FROM 191 | `YOUR-DATASET.sgdp_ancestry_prediction_results` 192 | JOIN 193 | `bigquery-public-data.human_genome_variants.simons_genome_diversity_project_sample_attributes` AS info 194 | ON 195 | key = id_from_vcf 196 | ORDER BY 197 | region, population 198 | ``` 199 | 200 | If you are using the BigQuery web UI, you can click on `Save to GoogleSheets` and then in GoogleSheets: 201 | 202 | * select the 5 columns of prediction probabilities 203 | * `Format` -> `Conditional Formatting` -> `Color Scale` and use bright yellow for `Max Value` 204 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | trainingInput: 2 | scaleTier: CUSTOM 3 | masterType: large_model 4 | workerType: large_model 5 | parameterServerType: large_model 6 | workerCount: 5 7 | parameterServerCount: 1 8 | -------------------------------------------------------------------------------- /hptuning_config.yaml: -------------------------------------------------------------------------------- 1 | trainingInput: 2 | scaleTier: CUSTOM 3 | masterType: large_model 4 | workerType: large_model 5 | parameterServerType: large_model 6 | workerCount: 5 7 | parameterServerCount: 1 8 | hyperparameters: 9 | goal: MAXIMIZE 10 | hyperparameterMetricTag: precision_at_1 11 | maxTrials: 10 12 | maxParallelTrials: 4 13 | params: 14 | - parameterName: learning_rate 15 | type: DOUBLE 16 | minValue: 0.0001 17 | maxValue: 0.01 18 | scaleType: UNIT_REVERSE_LOG_SCALE 19 | -------------------------------------------------------------------------------- 
/preprocess/1000_genomes_metadata.jinja: -------------------------------------------------------------------------------- 1 | -- 2 | -- Retrieve metadata for 1000 Genomes individuals. 3 | -- 4 | SELECT 5 | Sample AS {{ KEY_COLUMN }}, 6 | Population AS {{ POPULATION_COLUMN }}, 7 | Super_Population AS {{ SUPER_POPULATION_COLUMN }}, 8 | Gender AS {{ GENDER_COLUMN }} 9 | FROM 10 | `genomics-public-data.1000_genomes.sample_info` 11 | -------------------------------------------------------------------------------- /preprocess/1000_genomes_phase3_b37.jinja: -------------------------------------------------------------------------------- 1 | -- 2 | -- Retrieve all variant calls for all individuals in 1000 Genomes. 3 | -- 4 | SELECT 5 | call.call_set_name AS {{ KEY_COLUMN }}, 6 | reference_name AS {{ CONTIG_COLUMN }}, 7 | start AS {{ START_COLUMN }}, 8 | -- 'end' is needed when alt is symbolic such as https://github.com/googlegenomics/bigquery-examples/tree/master/1000genomes/data-stories/understanding-alternate-alleles 9 | `end` AS {{ END_COLUMN }}, 10 | reference_bases AS {{ REF_COLUMN }}, 11 | alt AS {{ ALT_COLUMN }}, 12 | alt_offset + 1 AS {{ ALT_NUM_COLUMN }}, 13 | call.genotype[SAFE_ORDINAL(1)] AS {{ FIRST_ALLELE_COLUMN }}, 14 | call.genotype[SAFE_ORDINAL(2)] AS {{ SECOND_ALLELE_COLUMN }} 15 | FROM `genomics-public-data.1000_genomes_phase_3.variants_20150220_release` v, 16 | v.call call, 17 | v.alternate_bases alt WITH OFFSET alt_offset 18 | WHERE 19 | EXISTS (SELECT gt FROM UNNEST(call.genotype) gt WHERE gt = (alt_offset + 1)) 20 | -------------------------------------------------------------------------------- /preprocess/1000_genomes_phase3_b37_limit10.jinja: -------------------------------------------------------------------------------- 1 | -- 2 | -- Retrieve just a few variant calls for all individuals in 1000 Genomes. 3 | -- This query is useful for local development of the preprocess pipeline. 4 | -- 5 | SELECT 6 | call.call_set_name AS {{ KEY_COLUMN }}, 7 | reference_name AS {{ CONTIG_COLUMN }}, 8 | start AS {{ START_COLUMN }}, 9 | -- 'end' is needed when alt is symbolic such as https://github.com/googlegenomics/bigquery-examples/tree/master/1000genomes/data-stories/understanding-alternate-alleles 10 | `end` AS {{ END_COLUMN }}, 11 | reference_bases AS {{ REF_COLUMN }}, 12 | alt AS {{ ALT_COLUMN }}, 13 | alt_offset + 1 AS {{ ALT_NUM_COLUMN }}, 14 | call.genotype[SAFE_ORDINAL(1)] AS {{ FIRST_ALLELE_COLUMN }}, 15 | call.genotype[SAFE_ORDINAL(2)] AS {{ SECOND_ALLELE_COLUMN }} 16 | FROM `genomics-public-data.1000_genomes_phase_3.variants_20150220_release` v, 17 | v.call call, 18 | v.alternate_bases alt WITH OFFSET alt_offset 19 | WHERE 20 | EXISTS (SELECT gt FROM UNNEST(call.genotype) gt WHERE gt = (alt_offset + 1)) 21 | LIMIT 10 -- Just grab enough data for a smoke test. 22 | -------------------------------------------------------------------------------- /preprocess/1000_genomes_phase3_b37_snps_only.jinja: -------------------------------------------------------------------------------- 1 | -- 2 | -- Retrieve all SNP variant calls for all individuals in 1000 Genomes. 
3 | -- 4 | SELECT 5 | call.call_set_name AS {{ KEY_COLUMN }}, 6 | reference_name AS {{ CONTIG_COLUMN }}, 7 | start AS {{ START_COLUMN }}, 8 | -- 'end' is needed when alt is symbolic such as https://github.com/googlegenomics/bigquery-examples/tree/master/1000genomes/data-stories/understanding-alternate-alleles 9 | `end` AS {{ END_COLUMN }}, 10 | reference_bases AS {{ REF_COLUMN }}, 11 | alt AS {{ ALT_COLUMN }}, 12 | alt_offset + 1 AS {{ ALT_NUM_COLUMN }}, 13 | call.genotype[SAFE_ORDINAL(1)] AS {{ FIRST_ALLELE_COLUMN }}, 14 | call.genotype[SAFE_ORDINAL(2)] AS {{ SECOND_ALLELE_COLUMN }} 15 | FROM `genomics-public-data.1000_genomes_phase_3.variants_20150220_release` v, 16 | v.call call, 17 | v.alternate_bases alt WITH OFFSET alt_offset 18 | WHERE 19 | EXISTS (SELECT gt FROM UNNEST(call.genotype) gt WHERE gt = (alt_offset + 1)) 20 | -- Include only SNPs. 21 | AND reference_bases IN ('A','C','G','T') 22 | AND alt IN ('A','C','G','T') 23 | -------------------------------------------------------------------------------- /preprocess/pgp_data_b37.jinja: -------------------------------------------------------------------------------- 1 | -- 2 | -- Retrieve all variant calls for Personal Genome Project participants. 3 | -- 4 | SELECT 5 | call.call_set_name AS {{ KEY_COLUMN }}, 6 | reference_name AS {{ CONTIG_COLUMN }}, 7 | start AS {{ START_COLUMN }}, 8 | -- 'end' is needed when alt is symbolic such as https://github.com/googlegenomics/bigquery-examples/tree/master/1000genomes/data-stories/understanding-alternate-alleles 9 | `end` AS {{ END_COLUMN }}, 10 | reference_bases AS {{ REF_COLUMN }}, 11 | alt AS {{ ALT_COLUMN }}, 12 | alt_offset + 1 AS {{ ALT_NUM_COLUMN }}, 13 | call.genotype[SAFE_ORDINAL(1)] AS {{ FIRST_ALLELE_COLUMN }}, 14 | call.genotype[SAFE_ORDINAL(2)] AS {{ SECOND_ALLELE_COLUMN }} 15 | FROM 16 | `google.com:biggene.pgp_20150205.genome_calls` v, 17 | v.call call, 18 | v.alternate_bases alt WITH OFFSET alt_offset 19 | WHERE 20 | EXISTS (SELECT gt FROM UNNEST(call.genotype) gt WHERE gt = (alt_offset + 1)) 21 | -- Skip no-calls and non-passing variants. 22 | AND NOT EXISTS (SELECT gt FROM UNNEST(call.genotype) gt WHERE gt < 0) 23 | AND call.allele1VariantQuality = 'VQHIGH' 24 | AND IFNULL(call.allele2VariantQuality = 'VQHIGH', TRUE) 25 | -------------------------------------------------------------------------------- /preprocess/pgp_data_b37_1kg_variants_only.jinja: -------------------------------------------------------------------------------- 1 | -- 2 | -- Retrieve variant calls for Personal Genome Project participants, limiting 3 | -- to only those present in the training data to attempt to minimize batch 4 | -- effects from the differing sequence platforms and variant calling pipelines.
5 | -- 6 | SELECT 7 | call.call_set_name AS {{ KEY_COLUMN }}, 8 | v.reference_name AS {{ CONTIG_COLUMN }}, 9 | v.start AS {{ START_COLUMN }}, 10 | -- 'end' is needed when alt is symbolic such as https://github.com/googlegenomics/bigquery-examples/tree/master/1000genomes/data-stories/understanding-alternate-alleles 11 | v.`end` AS {{ END_COLUMN }}, 12 | v.reference_bases AS {{ REF_COLUMN }}, 13 | alt AS {{ ALT_COLUMN }}, 14 | alt_offset + 1 AS {{ ALT_NUM_COLUMN }}, 15 | call.genotype[SAFE_ORDINAL(1)] AS {{ FIRST_ALLELE_COLUMN }}, 16 | call.genotype[SAFE_ORDINAL(2)] AS {{ SECOND_ALLELE_COLUMN }} 17 | FROM 18 | `google.com:biggene.pgp_20150205.genome_calls` v, 19 | v.call call, 20 | v.alternate_bases alt WITH OFFSET alt_offset 21 | JOIN ( 22 | SELECT 23 | CONCAT('chr', reference_name) AS chrom, 24 | start, 25 | `end`, 26 | reference_bases, 27 | alt 28 | FROM 29 | `genomics-public-data.1000_genomes_phase_3.variants_20150220_release` kgv, 30 | kgv.alternate_bases alt 31 | ) AS kg 32 | ON 33 | kg.chrom = v.reference_name 34 | AND kg.start = v.start 35 | AND kg.`end` = v.`end` 36 | AND kg.reference_bases = v.reference_bases 37 | AND kg.alt = alt 38 | WHERE 39 | EXISTS (SELECT gt FROM UNNEST(call.genotype) gt WHERE gt = (alt_offset + 1)) 40 | -- Skip no-calls and non-passing variants. 41 | AND NOT EXISTS (SELECT gt FROM UNNEST(call.genotype) gt WHERE gt < 0) 42 | AND call.allele1VariantQuality = 'VQHIGH' 43 | AND IFNULL(call.allele2VariantQuality = 'VQHIGH', TRUE) 44 | -------------------------------------------------------------------------------- /preprocess/pgp_data_b37_limit10.jinja: -------------------------------------------------------------------------------- 1 | -- 2 | -- Retrieve just a few variant calls for Personal Genome Project participants. 3 | -- This query is useful for local development of the preprocess pipeline. 4 | -- 5 | SELECT 6 | call.call_set_name AS {{ KEY_COLUMN }}, 7 | reference_name AS {{ CONTIG_COLUMN }}, 8 | start AS {{ START_COLUMN }}, 9 | -- 'end' is needed when alt is symbolic such as https://github.com/googlegenomics/bigquery-examples/tree/master/1000genomes/data-stories/understanding-alternate-alleles 10 | `end` AS {{ END_COLUMN }}, 11 | reference_bases AS {{ REF_COLUMN }}, 12 | alt AS {{ ALT_COLUMN }}, 13 | alt_offset + 1 AS {{ ALT_NUM_COLUMN }}, 14 | call.genotype[SAFE_ORDINAL(1)] AS {{ FIRST_ALLELE_COLUMN }}, 15 | call.genotype[SAFE_ORDINAL(2)] AS {{ SECOND_ALLELE_COLUMN }} 16 | FROM 17 | `google.com:biggene.pgp_20150205.genome_calls` v, 18 | v.call call, 19 | v.alternate_bases alt WITH OFFSET alt_offset 20 | WHERE 21 | EXISTS (SELECT gt FROM UNNEST(call.genotype) gt WHERE gt = (alt_offset + 1)) 22 | -- Skip no-calls and non-passing variants. 23 | AND NOT EXISTS (SELECT gt FROM UNNEST(call.genotype) gt WHERE gt < 0) 24 | AND call.allele1VariantQuality = 'VQHIGH' 25 | AND IFNULL(call.allele2VariantQuality = 'VQHIGH', TRUE) 26 | LIMIT 10 -- Just grab enough data for a smoke test. 27 | -------------------------------------------------------------------------------- /preprocess/pgp_metadata.jinja: -------------------------------------------------------------------------------- 1 | -- 2 | -- Retrieve metadata for Personal Genome Project participants. 3 | -- 4 | SELECT 5 | Participant AS {{ KEY_COLUMN }}, 6 | -- This data more closely matches super population, but use 7 | -- the same metadata column for both as a best effort. 
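-- Note that these free-text values generally do not match the 1000 Genomes
-- population labels, so the trainer's ancestry_metadata_encoder maps them to
-- its NA integer value while still carrying the verbatim strings in the
-- *_string features.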
8 | Race_ethnicity AS {{ POPULATION_COLUMN }}, 9 | Race_ethnicity AS {{ SUPER_POPULATION_COLUMN }}, 10 | Sex_Gender AS {{ GENDER_COLUMN }} 11 | FROM 12 | `google.com:biggene.pgp.phenotypes` 13 | -------------------------------------------------------------------------------- /preprocess/platinum_genomes_b37.jinja: -------------------------------------------------------------------------------- 1 | -- 2 | -- Retrieve all passing variant calls for all individuals in Platinum Genomes. 3 | -- 4 | SELECT 5 | call.call_set_name AS {{ KEY_COLUMN }}, 6 | reference_name AS {{ CONTIG_COLUMN }}, 7 | start AS {{ START_COLUMN }}, 8 | -- 'end' is needed when alt is symbolic such as https://github.com/googlegenomics/bigquery-examples/tree/master/1000genomes/data-stories/understanding-alternate-alleles 9 | `end` AS {{ END_COLUMN }}, 10 | reference_bases AS {{ REF_COLUMN }}, 11 | alt AS {{ ALT_COLUMN }}, 12 | alt_offset + 1 AS {{ ALT_NUM_COLUMN }}, 13 | call.genotype[SAFE_ORDINAL(1)] AS {{ FIRST_ALLELE_COLUMN }}, 14 | call.genotype[SAFE_ORDINAL(2)] AS {{ SECOND_ALLELE_COLUMN }} 15 | FROM 16 | `genomics-public-data.platinum_genomes.variants` v, 17 | v.call call, 18 | v.alternate_bases alt WITH OFFSET alt_offset 19 | WHERE 20 | EXISTS (SELECT gt FROM UNNEST(call.genotype) gt WHERE gt = (alt_offset + 1)) 21 | -- Skip no-calls and non-passing variants. 22 | AND NOT EXISTS (SELECT gt FROM UNNEST(call.genotype) gt WHERE gt < 0) 23 | AND NOT EXISTS (SELECT ft FROM UNNEST(call.FILTER) ft WHERE ft NOT IN ('PASS', '.')) 24 | -------------------------------------------------------------------------------- /preprocess/platinum_genomes_metadata.jinja: -------------------------------------------------------------------------------- 1 | -- 2 | -- Retrieve metadata for Platinum Genomes individuals. 3 | -- 4 | SELECT 5 | catalog_id AS {{ KEY_COLUMN }}, 6 | description AS {{ POPULATION_COLUMN }}, 7 | race AS {{ SUPER_POPULATION_COLUMN }}, 8 | gender AS {{ GENDER_COLUMN }} 9 | FROM 10 | `google.com:biggene.platinum_genomes.sample_info` 11 | -------------------------------------------------------------------------------- /preprocess/sgdp_data_b37.jinja: -------------------------------------------------------------------------------- 1 | -- 2 | -- Retrieve high quality variant calls for the Simons Genome Diversity 3 | -- Project Dataset. 4 | -- 5 | SELECT 6 | call.call_set_name AS {{ KEY_COLUMN }}, 7 | reference_name AS {{ CONTIG_COLUMN }}, 8 | start AS {{ START_COLUMN }}, 9 | -- 'end' is needed when alt is symbolic such as 10 | `end` AS {{ END_COLUMN }}, 11 | reference_bases AS {{ REF_COLUMN }}, 12 | alt AS {{ ALT_COLUMN }}, 13 | alt_offset + 1 AS {{ ALT_NUM_COLUMN }}, 14 | call.genotype[SAFE_ORDINAL(1)] AS {{ FIRST_ALLELE_COLUMN }}, 15 | call.genotype[SAFE_ORDINAL(2)] AS {{ SECOND_ALLELE_COLUMN }} 16 | FROM `genomics-public-data.simons_genome_diversity_project.single_sample_genome_calls` v, 17 | v.call call, 18 | v.alternate_bases alt WITH OFFSET alt_offset 19 | WHERE 20 | -- Ensure that at least one of the genotypes corresponds to this alternate. 21 | EXISTS (SELECT gt FROM UNNEST(call.genotype) gt WHERE gt = (alt_offset + 1)) 22 | -- VCF header says: 23 | -- "filter level in range 0-9 or no value (non-integer: N,?) with zero 24 | -- being least reliable; to threshold at FL=n, use all levels n-9" 25 | -- Note that there are two samples with FL="" for all variants. 
26 | AND call.FL IN ("", "1","2","3","4","5","6","7","8","9") 27 | -------------------------------------------------------------------------------- /preprocess/sgdp_metadata.jinja: -------------------------------------------------------------------------------- 1 | -- 2 | -- Retrieve metadata for the Simons Genome Diversity 3 | -- Project Dataset. 4 | -- 5 | SELECT 6 | id_from_vcf AS {{ KEY_COLUMN }}, 7 | population AS {{ POPULATION_COLUMN }}, 8 | region AS {{ SUPER_POPULATION_COLUMN }}, 9 | sex AS {{ GENDER_COLUMN }} 10 | FROM 11 | `genomics-public-data.simons_genome_diversity_project.sample_attributes` 12 | -------------------------------------------------------------------------------- /preprocess/sgdp_metadata_remap_labels.jinja: -------------------------------------------------------------------------------- 1 | -- 2 | -- Retrieve metadata for a subset of individuals from the Simons Genome 3 | -- Diversity Project Dataset, remapping labels to be those from 4 | -- the 1000 Genomes study. 5 | -- 6 | SELECT 7 | id_from_vcf AS {{ KEY_COLUMN }}, 8 | population AS {{ POPULATION_COLUMN }}, 9 | CASE region 10 | WHEN "Africa" THEN "AFR" 11 | WHEN "America" THEN "AMR" 12 | WHEN "EastAsia" THEN "EAS" 13 | WHEN "SouthAsia" THEN "SAS" 14 | WHEN "WestEurasia" THEN "EUR" 15 | END AS {{ SUPER_POPULATION_COLUMN }}, 16 | sex AS {{ GENDER_COLUMN }} 17 | FROM 18 | `genomics-public-data.simons_genome_diversity_project.sample_attributes` 19 | WHERE 20 | region IN ('WestEurasia', 'Africa', 'America', 'EastAsia', 'SouthAsia') 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # Copyright 2016 Google Inc. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from setuptools import find_packages 18 | from setuptools import setup 19 | 20 | REQUIRED_PACKAGES = ['Jinja2>=2.10.1', 'pyfarmhash==0.2.2', 'absl-py>=0.7.1'] 21 | 22 | setup( 23 | name='trainer', 24 | version='0.1', 25 | author='Google', 26 | author_email='google-genomics-contact@googlegroups.com', 27 | install_requires=REQUIRED_PACKAGES, 28 | packages=find_packages(), 29 | include_package_data=True, 30 | description='Google Cloud Machine Learning genomics example', 31 | requires=[]) 32 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Package marker file.""" 15 | -------------------------------------------------------------------------------- /trainer/ancestry_metadata_encoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # Copyright 2016 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Encode sample metadata as features for an ancestry inference.""" 17 | 18 | from collections import defaultdict 19 | 20 | import trainer.feature_encoder as encoder 21 | import trainer.util as util 22 | 23 | # Decouple source table column names from the dictionary keys used 24 | # in this code. 25 | POPULATION_COLUMN = 'pop' 26 | SUPER_POPULATION_COLUMN = 'sup' 27 | GENDER_COLUMN = 'sex' 28 | 29 | # Normalize over possible sex/gender values. 30 | GENDER_MAP = defaultdict(lambda: encoder.NA_INTEGER) 31 | GENDER_MAP.update({ 32 | 'male': 0, 33 | 'female': 1, 34 | 'Male': 0, 35 | 'Female': 1, 36 | 'm': 0, 37 | 'f': 1, 38 | 'M': 0, 39 | 'F': 1 40 | }) 41 | 42 | # Population and Super population labels used for training and evaluation 43 | # will always be the ones from 1000 Genomes plus a label for unknown 44 | # populations. If we wish to use another dataset for training and 45 | # evaluation, we'll need to provide the mapping from 1000 Genomes 46 | # labels to those used for the dataset. 47 | SUPER_POPULATIONS = ['AFR', 'AMR', 'EAS', 'EUR', 'SAS', 'UNK'] 48 | 49 | SUPER_POPULATION_MAP = defaultdict(lambda: encoder.NA_INTEGER) 50 | for pop in range(len(SUPER_POPULATIONS)): 51 | SUPER_POPULATION_MAP[SUPER_POPULATIONS[pop]] = pop 52 | 53 | POPULATIONS = [ 54 | 'ACB', 'ASW', 'BEB', 'CDX', 'CEU', 'CHB', 'CHS', 'CLM', 'ESN', 'FIN', 'GBR', 55 | 'GIH', 'GWD', 'IBS', 'ITU', 'JPT', 'KHV', 'LWK', 'MSL', 'MXL', 'PEL', 'PJL', 56 | 'PUR', 'STU', 'TSI', 'YRI' 57 | ] 58 | 59 | POPULATION_MAP = defaultdict(lambda: encoder.NA_INTEGER) 60 | for pop in range(len(POPULATIONS)): 61 | POPULATION_MAP[POPULATIONS[pop]] = pop 62 | 63 | # Metadata feature name constants 64 | POPULATION_FEATURE = 'population' 65 | POPULATION_STRING_FEATURE = 'population_string' 66 | SUPER_POPULATION_FEATURE = 'super_population' 67 | SUPER_POPULATION_STRING_FEATURE = 'super_population_string' 68 | GENDER_FEATURE = 'gender' 69 | GENDER_STRING_FEATURE = 'gender_string' 70 | 71 | 72 | def metadata_to_ancestry_features(sample_metadata): 73 | """Create features from sample metadata. 74 | 75 | Args: 76 | sample_metadata: dictionary of metadata for one sample 77 | 78 | Returns: 79 | A dictionary of TensorFlow features. 
80 | """ 81 | features = { 82 | # Nomalize population to integer or NA_INTEGER if no match. 83 | POPULATION_FEATURE: 84 | util.int64_feature( 85 | POPULATION_MAP[str(sample_metadata[POPULATION_COLUMN])]), 86 | # Use verbatim value of population. 87 | POPULATION_STRING_FEATURE: 88 | util.bytes_feature(str(sample_metadata[POPULATION_COLUMN])), 89 | # Nomalize super population to integer or NA_INTEGER if no match. 90 | SUPER_POPULATION_FEATURE: 91 | util.int64_feature(SUPER_POPULATION_MAP[str(sample_metadata[ 92 | SUPER_POPULATION_COLUMN])]), 93 | # Use verbatim value of super population. 94 | SUPER_POPULATION_STRING_FEATURE: 95 | util.bytes_feature(str(sample_metadata[SUPER_POPULATION_COLUMN])), 96 | # Nomalize sex/gender to integer or NA_INTEGER if no match. 97 | GENDER_FEATURE: 98 | util.int64_feature(GENDER_MAP[str(sample_metadata[GENDER_COLUMN])]), 99 | # Use verbatim value of sex/gender. 100 | GENDER_STRING_FEATURE: 101 | util.bytes_feature(str(sample_metadata[GENDER_COLUMN])) 102 | } 103 | return features 104 | -------------------------------------------------------------------------------- /trainer/feature_encoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # Copyright 2016 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Encode sample metadata and variant calls as TensorFlow features. 17 | 18 | Given sample metadata and variants, the sample_to_example method will 19 | call the specified encoding strategies to construct the final TensorFlow 20 | Example protocol buffer. 21 | """ 22 | 23 | from collections import defaultdict 24 | 25 | import tensorflow as tf 26 | 27 | import trainer.util as util 28 | 29 | # Values to use for missing data. 30 | NA_STRING = 'NA' 31 | NA_INTEGER = -1 32 | 33 | # Decouple variant data source table column names from the dictionary 34 | # keys used in the variant encoders. 35 | KEY_COLUMN = 'key' 36 | CONTIG_COLUMN = 'contig' 37 | START_COLUMN = 'start_pos' 38 | END_COLUMN = 'end_pos' 39 | REF_COLUMN = 'ref' 40 | ALT_COLUMN = 'alt' 41 | ALT_NUM_COLUMN = 'alt_num' 42 | FIRST_ALLELE_COLUMN = 'first' 43 | SECOND_ALLELE_COLUMN = 'second' 44 | 45 | # Feature name constants 46 | SAMPLE_NAME_FEATURE = 'sample_name' 47 | 48 | 49 | def build_sample_to_example(metadata_to_features_fn, variants_to_features_fn): 50 | """Builder for the strategy to construct examples from sample data. 51 | 52 | Args: 53 | metadata_to_features_fn: the strategy to encode sample metadata as features 54 | variants_to_features_fn: the strategy to encode sample variants as features 55 | 56 | Returns: 57 | The instantiated strategy. 58 | """ 59 | 60 | def sample_to_example(sample, sample_variants, samples_metadata): 61 | """Convert sample metadata and variants to TensorFlow examples. 
62 | 63 | Args: 64 | sample: the identifier for the sample 65 | sample_variants: the sample's variant calls 66 | samples_metadata: dictionary of metadata for all samples 67 | 68 | Returns: 69 | A filled in TensorFlow Example protocol buffer for this sample. 70 | """ 71 | features = {SAMPLE_NAME_FEATURE: util.bytes_feature(str(sample))} 72 | 73 | # Some samples may have no metadata, but we may still want to preprocess 74 | # the data for prediction use cases. 75 | metadata = defaultdict(lambda: NA_STRING) 76 | if sample in samples_metadata: 77 | metadata.update(samples_metadata[sample]) 78 | 79 | # Fill in features from metadata. 80 | features.update(metadata_to_features_fn(metadata)) 81 | 82 | # Fill in features from variants. 83 | features.update(variants_to_features_fn(sample_variants)) 84 | 85 | return tf.train.Example(features=tf.train.Features(feature=features)) 86 | 87 | return sample_to_example 88 | -------------------------------------------------------------------------------- /trainer/feature_encoder_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # Copyright 2016 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Test encoding of sample metadata and variant calls to TensorFlow features.""" 17 | 18 | import copy 19 | import unittest 20 | import trainer.ancestry_metadata_encoder as metadata_encoder 21 | import trainer.feature_encoder as encoder 22 | import trainer.variant_encoder as variant_encoder 23 | 24 | # Test data. 
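# The expected feature values in the proto text below follow the encoders
# above: 'gender' is 1 ('female' in GENDER_MAP), 'super_population' is 4
# ('SAS' is at index 4 of SUPER_POPULATIONS), and 'population' is -1
# (NA_INTEGER) because the test population string is not a 1000 Genomes label.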
25 | SAMPLE_ID = 'sample1' 26 | 27 | SAMPLE_METADATA = {} 28 | SAMPLE_METADATA[SAMPLE_ID] = { 29 | encoder.KEY_COLUMN: SAMPLE_ID, 30 | metadata_encoder.GENDER_COLUMN: 'female', 31 | metadata_encoder.SUPER_POPULATION_COLUMN: 'SAS', 32 | metadata_encoder.POPULATION_COLUMN: 'some pop not in the training labels' 33 | } 34 | 35 | HETEROZYGOUS_VARIANT_CALL = { 36 | encoder.KEY_COLUMN: SAMPLE_ID, 37 | encoder.CONTIG_COLUMN: 'chr9', 38 | encoder.START_COLUMN: 3500000, 39 | encoder.END_COLUMN: 3500001, 40 | encoder.REF_COLUMN: 'T', 41 | encoder.ALT_COLUMN: 'G', 42 | encoder.ALT_NUM_COLUMN: 1, 43 | encoder.FIRST_ALLELE_COLUMN: 0, 44 | encoder.SECOND_ALLELE_COLUMN: 1 45 | } 46 | 47 | HOMOZYGOUS_ALT_VARIANT_CALL = copy.copy(HETEROZYGOUS_VARIANT_CALL) 48 | HOMOZYGOUS_ALT_VARIANT_CALL[encoder.FIRST_ALLELE_COLUMN] = 1 49 | HOMOZYGOUS_ALT_VARIANT_CALL[encoder.SECOND_ALLELE_COLUMN] = 1 50 | 51 | HOMOZYGOUS_REF_VARIANT_CALL = copy.copy(HETEROZYGOUS_VARIANT_CALL) 52 | HOMOZYGOUS_REF_VARIANT_CALL[encoder.FIRST_ALLELE_COLUMN] = 0 53 | HOMOZYGOUS_REF_VARIANT_CALL[encoder.SECOND_ALLELE_COLUMN] = 0 54 | 55 | 56 | class FeatureEncoderTest(unittest.TestCase): 57 | 58 | def test_sample_to_example(self): 59 | expected = """features { 60 | feature { 61 | key: "gender" 62 | value { 63 | int64_list { 64 | value: 1 65 | } 66 | } 67 | } 68 | feature { 69 | key: "gender_string" 70 | value { 71 | bytes_list { 72 | value: "female" 73 | } 74 | } 75 | } 76 | feature { 77 | key: "population" 78 | value { 79 | int64_list { 80 | value: -1 81 | } 82 | } 83 | } 84 | feature { 85 | key: "population_string" 86 | value { 87 | bytes_list { 88 | value: "some pop not in the training labels" 89 | } 90 | } 91 | } 92 | feature { 93 | key: "sample_name" 94 | value { 95 | bytes_list { 96 | value: "sample1" 97 | } 98 | } 99 | } 100 | feature { 101 | key: "super_population" 102 | value { 103 | int64_list { 104 | value: 4 105 | } 106 | } 107 | } 108 | feature { 109 | key: "super_population_string" 110 | value { 111 | bytes_list { 112 | value: "SAS" 113 | } 114 | } 115 | } 116 | feature { 117 | key: "variants_9" 118 | value { 119 | int64_list { 120 | value: -5153783975271321865 121 | } 122 | } 123 | } 124 | } 125 | """ 126 | variants_to_features_fn = variant_encoder.build_variants_to_features( 127 | variant_to_feature_name_fn=variant_encoder.variant_to_contig_feature_name, 128 | variant_to_words_fn=variant_encoder.build_variant_to_words( 129 | add_hethom=False)) 130 | 131 | sample_to_example = encoder.build_sample_to_example( 132 | metadata_to_features_fn=metadata_encoder.metadata_to_ancestry_features, 133 | variants_to_features_fn=variants_to_features_fn) 134 | self.assertEqual( 135 | expected, 136 | str( 137 | sample_to_example(SAMPLE_ID, [HETEROZYGOUS_VARIANT_CALL], 138 | SAMPLE_METADATA))) 139 | 140 | def test_sample_to_example_add_hethom(self): 141 | expected = """features { 142 | feature { 143 | key: "gender" 144 | value { 145 | int64_list { 146 | value: 1 147 | } 148 | } 149 | } 150 | feature { 151 | key: "gender_string" 152 | value { 153 | bytes_list { 154 | value: "female" 155 | } 156 | } 157 | } 158 | feature { 159 | key: "population" 160 | value { 161 | int64_list { 162 | value: -1 163 | } 164 | } 165 | } 166 | feature { 167 | key: "population_string" 168 | value { 169 | bytes_list { 170 | value: "some pop not in the training labels" 171 | } 172 | } 173 | } 174 | feature { 175 | key: "sample_name" 176 | value { 177 | bytes_list { 178 | value: "sample1" 179 | } 180 | } 181 | } 182 | feature { 183 | key: "super_population" 184 | value { 
185 | int64_list { 186 | value: 4 187 | } 188 | } 189 | } 190 | feature { 191 | key: "super_population_string" 192 | value { 193 | bytes_list { 194 | value: "SAS" 195 | } 196 | } 197 | } 198 | feature { 199 | key: "variants_9" 200 | value { 201 | int64_list { 202 | value: -5153783975271321865 203 | value: 1206215103517908850 204 | } 205 | } 206 | } 207 | } 208 | """ 209 | variants_to_features_fn = variant_encoder.build_variants_to_features( 210 | variant_to_feature_name_fn=variant_encoder.variant_to_contig_feature_name, 211 | variant_to_words_fn=variant_encoder.build_variant_to_words( 212 | add_hethom=True)) 213 | 214 | sample_to_example = encoder.build_sample_to_example( 215 | metadata_to_features_fn=metadata_encoder.metadata_to_ancestry_features, 216 | variants_to_features_fn=variants_to_features_fn) 217 | self.assertEqual( 218 | expected, 219 | str( 220 | sample_to_example(SAMPLE_ID, [HETEROZYGOUS_VARIANT_CALL], 221 | SAMPLE_METADATA))) 222 | 223 | def test_sample_to_example_binned_variants(self): 224 | expected = """features { 225 | feature { 226 | key: "gender" 227 | value { 228 | int64_list { 229 | value: 1 230 | } 231 | } 232 | } 233 | feature { 234 | key: "gender_string" 235 | value { 236 | bytes_list { 237 | value: "female" 238 | } 239 | } 240 | } 241 | feature { 242 | key: "population" 243 | value { 244 | int64_list { 245 | value: -1 246 | } 247 | } 248 | } 249 | feature { 250 | key: "population_string" 251 | value { 252 | bytes_list { 253 | value: "some pop not in the training labels" 254 | } 255 | } 256 | } 257 | feature { 258 | key: "sample_name" 259 | value { 260 | bytes_list { 261 | value: "sample1" 262 | } 263 | } 264 | } 265 | feature { 266 | key: "super_population" 267 | value { 268 | int64_list { 269 | value: 4 270 | } 271 | } 272 | } 273 | feature { 274 | key: "super_population_string" 275 | value { 276 | bytes_list { 277 | value: "SAS" 278 | } 279 | } 280 | } 281 | feature { 282 | key: "variants_9_3" 283 | value { 284 | int64_list { 285 | value: -5153783975271321865 286 | } 287 | } 288 | } 289 | } 290 | """ 291 | variants_to_features_fn = variant_encoder.build_variants_to_features( 292 | variant_to_feature_name_fn=variant_encoder. 
293 | build_variant_to_binned_feature_name(), 294 | variant_to_words_fn=variant_encoder.build_variant_to_words( 295 | add_hethom=False)) 296 | 297 | sample_to_example = encoder.build_sample_to_example( 298 | metadata_to_features_fn=metadata_encoder.metadata_to_ancestry_features, 299 | variants_to_features_fn=variants_to_features_fn) 300 | self.assertEqual( 301 | expected, 302 | str( 303 | sample_to_example(SAMPLE_ID, [HETEROZYGOUS_VARIANT_CALL], 304 | SAMPLE_METADATA))) 305 | 306 | def test_sample_to_example_smaller_bins(self): 307 | expected = """features { 308 | feature { 309 | key: "gender" 310 | value { 311 | int64_list { 312 | value: 1 313 | } 314 | } 315 | } 316 | feature { 317 | key: "gender_string" 318 | value { 319 | bytes_list { 320 | value: "female" 321 | } 322 | } 323 | } 324 | feature { 325 | key: "population" 326 | value { 327 | int64_list { 328 | value: -1 329 | } 330 | } 331 | } 332 | feature { 333 | key: "population_string" 334 | value { 335 | bytes_list { 336 | value: "some pop not in the training labels" 337 | } 338 | } 339 | } 340 | feature { 341 | key: "sample_name" 342 | value { 343 | bytes_list { 344 | value: "sample1" 345 | } 346 | } 347 | } 348 | feature { 349 | key: "super_population" 350 | value { 351 | int64_list { 352 | value: 4 353 | } 354 | } 355 | } 356 | feature { 357 | key: "super_population_string" 358 | value { 359 | bytes_list { 360 | value: "SAS" 361 | } 362 | } 363 | } 364 | feature { 365 | key: "variants_9_35" 366 | value { 367 | int64_list { 368 | value: -5153783975271321865 369 | } 370 | } 371 | } 372 | } 373 | """ 374 | variants_to_features_fn = variant_encoder.build_variants_to_features( 375 | variant_to_feature_name_fn=variant_encoder. 376 | build_variant_to_binned_feature_name(bin_size=100000), 377 | variant_to_words_fn=variant_encoder.build_variant_to_words( 378 | add_hethom=False)) 379 | 380 | sample_to_example = encoder.build_sample_to_example( 381 | metadata_to_features_fn=metadata_encoder.metadata_to_ancestry_features, 382 | variants_to_features_fn=variants_to_features_fn) 383 | self.assertEqual( 384 | expected, 385 | str( 386 | sample_to_example(SAMPLE_ID, [HETEROZYGOUS_VARIANT_CALL], 387 | SAMPLE_METADATA))) 388 | 389 | def test_sample_to_example_binned_variants_add_hethom(self): 390 | expected = """features { 391 | feature { 392 | key: "gender" 393 | value { 394 | int64_list { 395 | value: 1 396 | } 397 | } 398 | } 399 | feature { 400 | key: "gender_string" 401 | value { 402 | bytes_list { 403 | value: "female" 404 | } 405 | } 406 | } 407 | feature { 408 | key: "population" 409 | value { 410 | int64_list { 411 | value: -1 412 | } 413 | } 414 | } 415 | feature { 416 | key: "population_string" 417 | value { 418 | bytes_list { 419 | value: "some pop not in the training labels" 420 | } 421 | } 422 | } 423 | feature { 424 | key: "sample_name" 425 | value { 426 | bytes_list { 427 | value: "sample1" 428 | } 429 | } 430 | } 431 | feature { 432 | key: "super_population" 433 | value { 434 | int64_list { 435 | value: 4 436 | } 437 | } 438 | } 439 | feature { 440 | key: "super_population_string" 441 | value { 442 | bytes_list { 443 | value: "SAS" 444 | } 445 | } 446 | } 447 | feature { 448 | key: "variants_9_3" 449 | value { 450 | int64_list { 451 | value: -5153783975271321865 452 | value: 1206215103517908850 453 | } 454 | } 455 | } 456 | } 457 | """ 458 | variants_to_features_fn = variant_encoder.build_variants_to_features( 459 | variant_to_feature_name_fn=variant_encoder. 
460 | build_variant_to_binned_feature_name(), 461 | variant_to_words_fn=variant_encoder.build_variant_to_words( 462 | add_hethom=True)) 463 | 464 | sample_to_example = encoder.build_sample_to_example( 465 | metadata_to_features_fn=metadata_encoder.metadata_to_ancestry_features, 466 | variants_to_features_fn=variants_to_features_fn) 467 | self.assertEqual( 468 | expected, 469 | str( 470 | sample_to_example(SAMPLE_ID, [HETEROZYGOUS_VARIANT_CALL], 471 | SAMPLE_METADATA))) 472 | 473 | 474 | if __name__ == '__main__': 475 | unittest.main() 476 | -------------------------------------------------------------------------------- /trainer/preprocess_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # Copyright 2016 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | r"""Pipeline to convert variant data from BigQuery to TensorFlow Example protos. 18 | 19 | For any samples without corresponding metadata, values indicating 20 | NA will be used instead for the metadata. 21 | 22 | USAGE: 23 | python -m trainer.preprocess_data \ 24 | --setup_file ./setup.py \ 25 | --project ${PROJECT_ID} \ 26 | --metadata preprocess/1000_genomes_metadata.jinja \ 27 | --input preprocess/1000_genomes_phase3_b37_limit10.jinja \ 28 | --output ${BUCKET}/1000-genomes 29 | """ 30 | 31 | import datetime 32 | import logging 33 | import os 34 | 35 | import apache_beam as beam 36 | from apache_beam.io import tfrecordio 37 | from apache_beam.io.filesystem import CompressionTypes 38 | from apache_beam.options.pipeline_options import GoogleCloudOptions 39 | from apache_beam.options.pipeline_options import PipelineOptions 40 | from apache_beam.options.pipeline_options import SetupOptions 41 | from apache_beam.options.pipeline_options import WorkerOptions 42 | from jinja2 import Template 43 | 44 | import trainer.ancestry_metadata_encoder as metadata_encoder 45 | import trainer.feature_encoder as encoder 46 | import trainer.util as util 47 | import trainer.variant_encoder as variant_encoder 48 | 49 | 50 | # Jinja template replacements to decouple column names from the source 51 | # tables from the dictionart keys used in this pipeline. 
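# For example (an illustrative sketch; the actual .jinja files under
# preprocess/ may phrase their queries differently), a metadata template
# containing a line such as
#
#   SELECT Sample AS {{ KEY_COLUMN }}, Gender AS {{ GENDER_COLUMN }}
#
# renders with the replacements below into column aliases that match the
# dictionary keys used throughout this pipeline, so the code never needs to
# know the source table's own column names.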
52 | 53 | METADATA_QUERY_REPLACEMENTS = { 54 | 'KEY_COLUMN': encoder.KEY_COLUMN, 55 | 'POPULATION_COLUMN': metadata_encoder.POPULATION_COLUMN, 56 | 'SUPER_POPULATION_COLUMN': metadata_encoder.SUPER_POPULATION_COLUMN, 57 | 'GENDER_COLUMN': metadata_encoder.GENDER_COLUMN, 58 | } 59 | 60 | DATA_QUERY_REPLACEMENTS = { 61 | 'KEY_COLUMN': encoder.KEY_COLUMN, 62 | 'CONTIG_COLUMN': encoder.CONTIG_COLUMN, 63 | 'START_COLUMN': encoder.START_COLUMN, 64 | 'END_COLUMN': encoder.END_COLUMN, 65 | 'REF_COLUMN': encoder.REF_COLUMN, 66 | 'ALT_COLUMN': encoder.ALT_COLUMN, 67 | 'ALT_NUM_COLUMN': encoder.ALT_NUM_COLUMN, 68 | 'FIRST_ALLELE_COLUMN': encoder.FIRST_ALLELE_COLUMN, 69 | 'SECOND_ALLELE_COLUMN': encoder.SECOND_ALLELE_COLUMN 70 | } 71 | 72 | 73 | def variants_to_examples(input_data, samples_metadata, sample_to_example_fn): 74 | """Converts variants to TensorFlow Example protos. 75 | 76 | Args: 77 | input_data: variant call dictionary objects with keys from 78 | DATA_QUERY_REPLACEMENTS 79 | samples_metadata: metadata dictionary objects with keys from 80 | METADATA_QUERY_REPLACEMENTS 81 | sample_to_example_fn: the feature encoder strategy to use to 82 | convert the source data into TensorFlow Example protos. 83 | 84 | Returns: 85 | TensorFlow Example protos. 86 | """ 87 | variant_kvs = input_data | 'BucketVariants' >> beam.Map( 88 | lambda row: (row[encoder.KEY_COLUMN], row)) 89 | 90 | sample_variant_kvs = variant_kvs | 'GroupBySample' >> beam.GroupByKey() 91 | 92 | examples = ( 93 | sample_variant_kvs 94 | | 'SamplesToExamples' >> beam.Map( 95 | lambda (key, vals), samples_metadata: sample_to_example_fn( 96 | key, vals, samples_metadata), 97 | beam.pvalue.AsSingleton(samples_metadata))) 98 | 99 | return examples 100 | 101 | 102 | class PreprocessOptions(PipelineOptions): 103 | 104 | @classmethod 105 | def _add_argparse_args(cls, parser): 106 | parser.add_argument( 107 | '--output', 108 | required=True, 109 | help='Output directory to which to write results.') 110 | parser.add_argument( 111 | '--input', 112 | required=True, 113 | help='Jinja file holding the query for the sample data.') 114 | parser.add_argument( 115 | '--metadata', 116 | required=True, 117 | help='Jinja file holding the query for the sample metadata.') 118 | parser.add_argument( 119 | '--hethom_words', 120 | dest='add_hethom', 121 | action='store_true', 122 | help='Add variant heterozygous/homozygous "word".') 123 | parser.add_argument( 124 | '--no_hethom_words', 125 | dest='add_hethom', 126 | action='store_false', 127 | help='Do not add variant heterozygous/homozygous "word".') 128 | parser.set_defaults(add_hethom=True) 129 | parser.add_argument( 130 | '--bin_size', 131 | type=int, 132 | help='The number of contiguous base pairs to use for each "bin". ' 133 | 'This parameter enables the placement of variant "words" into ' 134 | 'smaller genomic region features (as opposed to the default ' 135 | 'feature-per-chromosome) ') 136 | 137 | 138 | def run(argv=None): 139 | """Runs the variant preprocess pipeline. 140 | 141 | Args: 142 | argv: Pipeline options as a list of arguments. 
143 | """ 144 | pipeline_options = PipelineOptions(flags=argv) 145 | preprocess_options = pipeline_options.view_as(PreprocessOptions) 146 | cloud_options = pipeline_options.view_as(GoogleCloudOptions) 147 | output_dir = os.path.join(preprocess_options.output, 148 | datetime.datetime.now().strftime('%Y%m%d-%H%M%S')) 149 | pipeline_options.view_as(SetupOptions).save_main_session = True 150 | pipeline_options.view_as( 151 | WorkerOptions).autoscaling_algorithm = 'THROUGHPUT_BASED' 152 | cloud_options.staging_location = os.path.join(output_dir, 'tmp', 'staging') 153 | cloud_options.temp_location = os.path.join(output_dir, 'tmp') 154 | cloud_options.job_name = 'preprocess-varianteatures-%s' % ( 155 | datetime.datetime.now().strftime('%y%m%d-%H%M%S')) 156 | 157 | metadata_query = str( 158 | Template(open(preprocess_options.metadata, 'r').read()).render( 159 | METADATA_QUERY_REPLACEMENTS)) 160 | logging.info('metadata query : %s', metadata_query) 161 | 162 | data_query = str( 163 | Template(open(preprocess_options.input, 'r').read()).render( 164 | DATA_QUERY_REPLACEMENTS)) 165 | logging.info('data query : %s', data_query) 166 | 167 | # Assemble the strategies to be used to convert the raw data to features. 168 | variant_to_feature_name_fn = variant_encoder.variant_to_contig_feature_name 169 | if preprocess_options.bin_size is not None: 170 | variant_to_feature_name_fn = variant_encoder.build_variant_to_binned_feature_name( 171 | bin_size=preprocess_options.bin_size) 172 | 173 | variants_to_features_fn = variant_encoder.build_variants_to_features( 174 | variant_to_feature_name_fn=variant_to_feature_name_fn, 175 | variant_to_words_fn=variant_encoder.build_variant_to_words( 176 | add_hethom=preprocess_options.add_hethom)) 177 | 178 | sample_to_example_fn = encoder.build_sample_to_example( 179 | metadata_to_features_fn=metadata_encoder.metadata_to_ancestry_features, 180 | variants_to_features_fn=variants_to_features_fn) 181 | 182 | with beam.Pipeline(options=pipeline_options) as p: 183 | # Gather our sample metadata into a python dictionary. 184 | samples_metadata = ( 185 | p 186 | | 'ReadSampleMetadata' >> beam.io.Read( 187 | beam.io.BigQuerySource(query=metadata_query, use_standard_sql=True)) 188 | | 'TableToDictionary' >> beam.CombineGlobally( 189 | util.TableToDictCombineFn(key_column=encoder.KEY_COLUMN))) 190 | 191 | # Read the table rows into a PCollection. 192 | rows = p | 'ReadVariants' >> beam.io.Read( 193 | beam.io.BigQuerySource(query=data_query, use_standard_sql=True)) 194 | 195 | # Convert the data into TensorFlow Example Protocol Buffers. 196 | examples = variants_to_examples( 197 | rows, samples_metadata, sample_to_example_fn=sample_to_example_fn) 198 | 199 | # Write the serialized compressed protocol buffers to Cloud Storage. 200 | _ = (examples 201 | | 'EncodeExamples' >> beam.Map( 202 | lambda example: example.SerializeToString()) 203 | | 'WriteExamples' >> tfrecordio.WriteToTFRecord( 204 | file_path_prefix=os.path.join(output_dir, 'examples'), 205 | compression_type=CompressionTypes.GZIP, 206 | file_name_suffix='.tfrecord.gz')) 207 | 208 | 209 | if __name__ == '__main__': 210 | run() 211 | -------------------------------------------------------------------------------- /trainer/revise_preprocessed_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # Copyright 2017 Google Inc. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | r"""Filter and revise a collection of TensorFlow Example protos. 17 | 18 | This pipeline is useful when a full dataset has been preprocessed but for 19 | a subsequent experiment you wish to use only a subset of the examples and/or 20 | modify the labels. The subset to be copied is determined by the list of 21 | sample names returned by the metadata query. 22 | 23 | USAGE: 24 | python -m trainer.revise_preprocessed_data \ 25 | --setup_file ./setup.py \ 26 | --project ${PROJECT_ID} \ 27 | --input ${BUCKET}/sgdp \ 28 | --metadata preprocess/sgdp_metadata_remap_labels.jinja \ 29 | --output ${BUCKET}/sgdp_relabeled_subset 30 | """ 31 | 32 | import datetime 33 | import logging 34 | import os 35 | 36 | import apache_beam as beam 37 | from apache_beam.io import tfrecordio 38 | from apache_beam.io.filesystem import CompressionTypes 39 | from apache_beam.options.pipeline_options import GoogleCloudOptions 40 | from apache_beam.options.pipeline_options import PipelineOptions 41 | from apache_beam.options.pipeline_options import SetupOptions 42 | from apache_beam.options.pipeline_options import WorkerOptions 43 | from jinja2 import Template 44 | import tensorflow as tf 45 | 46 | import trainer.ancestry_metadata_encoder as metadata_encoder 47 | import trainer.feature_encoder as encoder 48 | import trainer.util as util 49 | 50 | METADATA_QUERY_REPLACEMENTS = { 51 | 'KEY_COLUMN': encoder.KEY_COLUMN, 52 | 'POPULATION_COLUMN': metadata_encoder.POPULATION_COLUMN, 53 | 'SUPER_POPULATION_COLUMN': metadata_encoder.SUPER_POPULATION_COLUMN, 54 | 'GENDER_COLUMN': metadata_encoder.GENDER_COLUMN, 55 | } 56 | 57 | 58 | def filter_and_revise_example(serialized_example, samples_metadata): 59 | """Filter and revise a collection of existing TensorFlow examples. 60 | 61 | Args: 62 | serialized_example: the example to be revised and/or filtered 63 | samples_metadata: dictionary of metadata for all samples 64 | 65 | Returns: 66 | A list containing the revised example or the empty list if the 67 | example should be removed from the collection. 68 | """ 69 | example = tf.train.Example.FromString(serialized_example) 70 | sample_name = example.features.feature[ 71 | encoder.SAMPLE_NAME_FEATURE].bytes_list.value[0] 72 | logging.info('Checking ' + sample_name) 73 | if sample_name not in samples_metadata: 74 | logging.info('Omitting ' + sample_name) 75 | return [] 76 | 77 | revised_features = {} 78 | # Initialize with current example features. 79 | revised_features.update(example.features.feature) 80 | # Overwrite metadata features. 
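  # Note: dict.update() overwrites keys that are already present, so the
  # ancestry features derived from the new metadata query take precedence
  # over the values copied from the original example.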
81 | revised_features.update( 82 | metadata_encoder.metadata_to_ancestry_features( 83 | samples_metadata[sample_name])) 84 | return [ 85 | tf.train.Example(features=tf.train.Features(feature=revised_features)) 86 | ] 87 | 88 | 89 | class ReviseOptions(PipelineOptions): 90 | 91 | @classmethod 92 | def _add_argparse_args(cls, parser): 93 | parser.add_argument( 94 | '--input', 95 | dest='input', 96 | required=True, 97 | help='Input directory holding the previously preprocessed examples.') 98 | parser.add_argument( 99 | '--metadata', 100 | dest='metadata', 101 | required=True, 102 | help='Jinja file holding the metadata query.') 103 | parser.add_argument( 104 | '--output', 105 | dest='output', 106 | required=True, 107 | help='Output directory to which to write filtered and revised ' 108 | 'examples.') 109 | 110 | 111 | def run(argv=None): 112 | """Runs the revise preprocessed data pipeline. 113 | 114 | Args: 115 | argv: Pipeline options as a list of arguments. 116 | """ 117 | pipeline_options = PipelineOptions(flags=argv) 118 | revise_options = pipeline_options.view_as(ReviseOptions) 119 | cloud_options = pipeline_options.view_as(GoogleCloudOptions) 120 | output_dir = os.path.join(revise_options.output, 121 | datetime.datetime.now().strftime('%Y%m%d-%H%M%S')) 122 | pipeline_options.view_as(SetupOptions).save_main_session = True 123 | pipeline_options.view_as( 124 | WorkerOptions).autoscaling_algorithm = 'THROUGHPUT_BASED' 125 | cloud_options.staging_location = os.path.join(output_dir, 'tmp', 'staging') 126 | cloud_options.temp_location = os.path.join(output_dir, 'tmp') 127 | cloud_options.job_name = 'relabel-examples-%s' % ( 128 | datetime.datetime.now().strftime('%y%m%d-%H%M%S')) 129 | 130 | metadata_query = str( 131 | Template(open(revise_options.metadata, 'r').read()).render( 132 | METADATA_QUERY_REPLACEMENTS)) 133 | logging.info('metadata query : %s', metadata_query) 134 | 135 | with beam.Pipeline(options=pipeline_options) as p: 136 | # Gather our sample metadata into a python dictionary. 137 | samples_metadata = ( 138 | p 139 | | 'ReadSampleMetadata' >> beam.io.Read( 140 | beam.io.BigQuerySource(query=metadata_query, use_standard_sql=True)) 141 | | 'TableToDictionary' >> beam.CombineGlobally( 142 | util.TableToDictCombineFn(key_column=encoder.KEY_COLUMN))) 143 | 144 | # Read the tf.Example protos into a PCollection. 145 | examples = p | 'ReadExamples' >> tfrecordio.ReadFromTFRecord( 146 | file_pattern=revise_options.input, 147 | compression_type=CompressionTypes.GZIP) 148 | 149 | # Filter the TensorFlow Example Protocol Buffers. 150 | filtered_examples = (examples | 'ReviseExamples' >> beam.FlatMap( 151 | lambda example, samples_metadata: 152 | filter_and_revise_example(example, samples_metadata), 153 | beam.pvalue.AsSingleton(samples_metadata))) 154 | 155 | # Write the subset of tf.Example protos to Cloud Storage. 156 | _ = (filtered_examples 157 | | 'SerializeExamples' >> 158 | beam.Map(lambda example: example.SerializeToString()) 159 | | 'WriteExamples' >> tfrecordio.WriteToTFRecord( 160 | file_path_prefix=os.path.join(output_dir, 'examples'), 161 | compression_type=CompressionTypes.GZIP, 162 | file_name_suffix='.tfrecord.gz')) 163 | 164 | 165 | if __name__ == '__main__': 166 | logging.getLogger().setLevel(logging.INFO) 167 | run() 168 | -------------------------------------------------------------------------------- /trainer/util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # Copyright 2016 Google Inc. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Reusable utility functions. 17 | 18 | This file is generic and can be reused by other models without modification. 19 | """ 20 | 21 | from apache_beam.transforms import core 22 | import tensorflow as tf 23 | 24 | 25 | def int64_feature(value): 26 | """Create a multi-valued int64 feature from a single value.""" 27 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 28 | 29 | 30 | def bytes_feature(value): 31 | """Create a multi-valued bytes feature from a single value.""" 32 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 33 | 34 | 35 | def float_feature(value): 36 | """Create a multi-valued float feature from a single value.""" 37 | return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 38 | 39 | 40 | class DefaultToKeyDict(dict): 41 | """Custom dictionary to use the key as the value for any missing entries.""" 42 | 43 | def __missing__(self, key): 44 | return str(key) 45 | 46 | 47 | class TableToDictCombineFn(core.CombineFn): 48 | """Beam transform to create a python dictionary from a BigQuery table. 49 | 50 | This CombineFn reshapes rows from a BigQuery table using the specified key 51 | column to a Python dictionary. 52 | """ 53 | 54 | def __init__(self, key_column): 55 | self.key_column = key_column 56 | 57 | def create_accumulator(self): 58 | return dict() 59 | 60 | def add_input(self, accumulator, element): 61 | accumulator[element[self.key_column]] = element 62 | return accumulator 63 | 64 | def add_inputs(self, accumulator, elements): 65 | for element in elements: 66 | self.add_input(accumulator, element) 67 | return accumulator 68 | 69 | def merge_accumulators(self, accumulators): 70 | final_accumulator = {} 71 | for accumulator in accumulators: 72 | final_accumulator.update(accumulator) 73 | return final_accumulator 74 | 75 | def extract_output(self, accumulator): 76 | return accumulator 77 | -------------------------------------------------------------------------------- /trainer/variant_encoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # Copyright 2016 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """Encode variant calls as TensorFlow features.""" 17 | 18 | import collections 19 | import struct 20 | 21 | import farmhash 22 | import tensorflow as tf 23 | import trainer.feature_encoder as encoder 24 | import trainer.util as util 25 | 26 | # Normalize reference names, defaulting to the vertabim reference name 27 | # if value is not present in the dictionary. 28 | CONTIG_MAP = util.DefaultToKeyDict() 29 | CONTIG_MAP.update({('chr' + str(x)): str(x) for x in range(1, 23)}) 30 | CONTIG_MAP['chrX'] = 'X' 31 | CONTIG_MAP['chrY'] = 'Y' 32 | CONTIG_MAP['chrM'] = 'MT' 33 | CONTIG_MAP['chrMT'] = 'MT' 34 | CONTIG_MAP['M'] = 'MT' 35 | 36 | 37 | def normalize_contig_name(variant): 38 | """Normalize reference (contig) names. 39 | 40 | For example chromosome X might be 'X' in one dataset and 'chrX' in 41 | another. 42 | 43 | Args: 44 | variant: a variant call 45 | 46 | Returns: 47 | The canonical name of the reference (contig) specified in the 48 | variant call. 49 | """ 50 | return CONTIG_MAP[variant[encoder.CONTIG_COLUMN]] 51 | 52 | 53 | def variant_to_contig_feature_name(variant): 54 | """Create the feature name for the variant call. 55 | 56 | In this implementation the feature name is merly the name of the 57 | reference (contig) within the variant call. 58 | 59 | Args: 60 | variant: a variant call 61 | 62 | Returns: 63 | The name for the feature in which this variant should be stored. 64 | """ 65 | # Use normalized reference name as feature name. 66 | return normalize_contig_name(variant) 67 | 68 | 69 | def sample_has_variant(variant): 70 | """Check whether the sample has this particular variant. 71 | 72 | Since the input data was FLATTENED on alternate_bases, we do this by 73 | checking whether either allele value corresponds to the alternate 74 | currently under consideration. Note that the values of the first 75 | allele and the second allele are genotypes --> which are essentially 76 | an index into the alternate_bases repeated field. See 77 | http://vcftools.sourceforge.net/VCF-poster.pdf for more detail. 78 | 79 | Args: 80 | variant: a variant call 81 | 82 | Returns: 83 | A count of the alleles for this alternate. This count can also be 84 | interpreted as a boolean to indicate whether or not the sample 85 | has this particular variant allele. 86 | """ 87 | alt_num = int(variant[encoder.ALT_NUM_COLUMN]) 88 | return ((variant[encoder.FIRST_ALLELE_COLUMN] == alt_num) + 89 | (encoder.SECOND_ALLELE_COLUMN in variant and 90 | variant[encoder.SECOND_ALLELE_COLUMN] == alt_num)) 91 | 92 | 93 | def build_variant_to_binned_feature_name(bin_size=1000000): 94 | """Builder for strategy for separate features for contiguous genomic regions. 95 | 96 | Args: 97 | bin_size: the maximum size of each contiguous genomic region. 98 | 99 | Returns: 100 | The instantiated strategy. 101 | """ 102 | 103 | def variant_to_binned_feature_name(variant): 104 | """Create the feature name for the variant call. 105 | 106 | In this implementation the feature name is merly the name of the 107 | reference (contig) within the variant call contenated with the bin 108 | number of the genomic region in which the variant call resides. 109 | 110 | Args: 111 | variant: a variant call 112 | 113 | Returns: 114 | The name for the feature in which this variant call should be stored. 115 | """ 116 | # variant_bin will be the floor result of the division since bin_size 117 | # is an integer. 
118 | variant_bin = int(variant[encoder.START_COLUMN]) / bin_size 119 | return '_'.join([normalize_contig_name(variant), str(variant_bin)]) 120 | 121 | return variant_to_binned_feature_name 122 | 123 | 124 | def build_variant_to_words(add_hethom=True): 125 | """Builder for strategy to convert a variant to words. 126 | 127 | This encoder will create separate bag-of-words features for each 128 | reference name (contig) in the source data. 129 | 130 | Args: 131 | add_hethom: whether or not to add additional words representing 132 | the zygosity of the variant call. 133 | 134 | Returns: 135 | The instantiated strategy. 136 | """ 137 | 138 | def variant_to_words(variant): 139 | """Encode a variant call as one or more "words". 140 | 141 | Given a variant call record with certain expected fields, create 142 | "words" that uniquely describe it. The first word would match 143 | both heterozygous or homozygous variant calls. A second word is 144 | created when add_hethom=True and is more specific, matching just 145 | one of heterozygous or homozygous. 146 | 147 | Args: 148 | variant: a variant call 149 | 150 | Returns: 151 | One or more "words" that represent the variant call. 152 | """ 153 | # Only add words only if the sample has a variant at this site. 154 | if not sample_has_variant(variant): 155 | return [] 156 | 157 | # Normalize reference names in the words. 158 | contig = normalize_contig_name(variant) 159 | 160 | variant_word = '_'.join([ 161 | contig, 162 | str(variant[encoder.START_COLUMN]), 163 | str(variant[encoder.END_COLUMN]), 164 | str(variant[encoder.REF_COLUMN]), 165 | str(variant[encoder.ALT_COLUMN]) 166 | ]) 167 | 168 | if not add_hethom: 169 | return [variant_word] 170 | 171 | # Add two words, one for the variant itself and another more specific word 172 | # (a synonym) regarding heterozygosity/homozygosity of the observation. 173 | if ((encoder.SECOND_ALLELE_COLUMN not in variant) or 174 | (variant[encoder.FIRST_ALLELE_COLUMN] != 175 | variant[encoder.SECOND_ALLELE_COLUMN])): 176 | return [variant_word, '_'.join([variant_word, 'het'])] 177 | 178 | return [variant_word, '_'.join([variant_word, 'hom'])] 179 | 180 | return variant_to_words 181 | 182 | 183 | def build_variants_to_features(variant_to_feature_name_fn, variant_to_words_fn): 184 | """Builder for the strategy to convert variants to bag-of-words features. 185 | 186 | Args: 187 | variant_to_feature_name_fn: strategy to determine the feature name (bag 188 | name) from the variant 189 | variant_to_words_fn: strategy to encode the variant as one or more words 190 | 191 | Returns: 192 | The instantiated strategy. 193 | """ 194 | 195 | def variants_to_features(sample_variants): 196 | """Convert variant calls to TensorFlow features. 197 | 198 | See also 199 | https://www.tensorflow.org/versions/r0.10/how_tos/reading_data/index.html 200 | 201 | Args: 202 | sample_variants: the sample's variant calls 203 | 204 | Returns: 205 | A dictionary of TensorFlow features. 206 | """ 207 | variants_by_feature = collections.defaultdict(list) 208 | for variant in sample_variants: 209 | feature_name = variant_to_feature_name_fn(variant) 210 | words = variant_to_words_fn(variant) 211 | variants_by_feature[feature_name].extend( 212 | # fingerprint64 returns an unsigned int64 but int64 features are 213 | # signed. Convert from from unsigned to signed. 214 | [ 215 | struct.unpack('q', struct.pack('Q', farmhash.fingerprint64(w)))[0] 216 | for w in words 217 | ]) 218 | 219 | # Fill in features from variants. 
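    # The 'variants_' prefix added below is what the training job expects:
    # trainer/variants_inference.py constructs its sparse feature columns
    # from names such as 'variants_9' (see its --sparse_features flag).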
220 | features = {} 221 | for feature, variants in variants_by_feature.iteritems(): 222 | if variants: 223 | features['variants_' + feature] = tf.train.Feature( 224 | int64_list=tf.train.Int64List(value=variants)) 225 | 226 | return features 227 | 228 | return variants_to_features 229 | -------------------------------------------------------------------------------- /trainer/variant_encoder_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # Copyright 2016 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Test encoding of variant calls to TensorFlow features.""" 17 | 18 | import copy 19 | import unittest 20 | 21 | import trainer.feature_encoder as encoder 22 | import trainer.variant_encoder as variant_encoder 23 | 24 | SAMPLE_ID = 'sample1' 25 | 26 | HETEROZYGOUS_VARIANT_CALL = { 27 | encoder.KEY_COLUMN: SAMPLE_ID, 28 | encoder.CONTIG_COLUMN: 'chr9', 29 | encoder.START_COLUMN: 3500000, 30 | encoder.END_COLUMN: 3500001, 31 | encoder.REF_COLUMN: 'T', 32 | encoder.ALT_COLUMN: 'G', 33 | encoder.ALT_NUM_COLUMN: 1, 34 | encoder.FIRST_ALLELE_COLUMN: 0, 35 | encoder.SECOND_ALLELE_COLUMN: 1 36 | } 37 | 38 | HOMOZYGOUS_ALT_VARIANT_CALL = copy.copy(HETEROZYGOUS_VARIANT_CALL) 39 | HOMOZYGOUS_ALT_VARIANT_CALL[encoder.FIRST_ALLELE_COLUMN] = 1 40 | HOMOZYGOUS_ALT_VARIANT_CALL[encoder.SECOND_ALLELE_COLUMN] = 1 41 | 42 | HOMOZYGOUS_REF_VARIANT_CALL = copy.copy(HETEROZYGOUS_VARIANT_CALL) 43 | HOMOZYGOUS_REF_VARIANT_CALL[encoder.FIRST_ALLELE_COLUMN] = 0 44 | HOMOZYGOUS_REF_VARIANT_CALL[encoder.SECOND_ALLELE_COLUMN] = 0 45 | 46 | 47 | class VariantEncoderTest(unittest.TestCase): 48 | 49 | def test_normalize_autosome_contig_names(self): 50 | self.assertEqual('1', 51 | variant_encoder.normalize_contig_name({ 52 | encoder.CONTIG_COLUMN: '1' 53 | })) 54 | self.assertEqual('1', 55 | variant_encoder.normalize_contig_name({ 56 | encoder.CONTIG_COLUMN: 'chr1' 57 | })) 58 | self.assertEqual('21', 59 | variant_encoder.normalize_contig_name({ 60 | encoder.CONTIG_COLUMN: '21' 61 | })) 62 | self.assertEqual('21', 63 | variant_encoder.normalize_contig_name({ 64 | encoder.CONTIG_COLUMN: 'chr21' 65 | })) 66 | 67 | def test_normalize_sex_contig_names(self): 68 | self.assertEqual('X', 69 | variant_encoder.normalize_contig_name({ 70 | encoder.CONTIG_COLUMN: 'X' 71 | })) 72 | self.assertEqual('X', 73 | variant_encoder.normalize_contig_name({ 74 | encoder.CONTIG_COLUMN: 'chrX' 75 | })) 76 | self.assertEqual('Y', 77 | variant_encoder.normalize_contig_name({ 78 | encoder.CONTIG_COLUMN: 'Y' 79 | })) 80 | self.assertEqual('Y', 81 | variant_encoder.normalize_contig_name({ 82 | encoder.CONTIG_COLUMN: 'chrY' 83 | })) 84 | 85 | def test_normalize_mitochondrial_contig_names(self): 86 | self.assertEqual('MT', 87 | variant_encoder.normalize_contig_name({ 88 | encoder.CONTIG_COLUMN: 'MT' 89 | })) 90 | self.assertEqual('MT', 91 | variant_encoder.normalize_contig_name({ 92 | encoder.CONTIG_COLUMN: 'M' 93 | })) 94 | 
self.assertEqual('MT', 95 | variant_encoder.normalize_contig_name({ 96 | encoder.CONTIG_COLUMN: 'chrM' 97 | })) 98 | self.assertEqual('MT', 99 | variant_encoder.normalize_contig_name({ 100 | encoder.CONTIG_COLUMN: 'chrMT' 101 | })) 102 | 103 | def test_normalize_other_contig_names(self): 104 | # All others pass through as-is. 105 | self.assertEqual('KI270375.1', 106 | variant_encoder.normalize_contig_name({ 107 | encoder.CONTIG_COLUMN: 'KI270375.1' 108 | })) 109 | 110 | def test_variant_to_feature_name(self): 111 | self.assertEqual('9', 112 | variant_encoder.variant_to_contig_feature_name({ 113 | encoder.CONTIG_COLUMN: 'chr9', 114 | encoder.START_COLUMN: 3500000 115 | })) 116 | 117 | def test_variant_to_binned_feature_name(self): 118 | variant_to_feature_name_fn = ( 119 | variant_encoder.build_variant_to_binned_feature_name()) 120 | self.assertEqual('9_3', 121 | variant_to_feature_name_fn({ 122 | encoder.CONTIG_COLUMN: 'chr9', 123 | encoder.START_COLUMN: 3500000 124 | })) 125 | self.assertEqual('9_4', 126 | variant_to_feature_name_fn({ 127 | encoder.CONTIG_COLUMN: 'chr9', 128 | encoder.START_COLUMN: 4000000 129 | })) 130 | 131 | def test_variant_to_smaller_binned_feature_name(self): 132 | variant_to_feature_name_fn = ( 133 | variant_encoder.build_variant_to_binned_feature_name(bin_size=100000)) 134 | self.assertEqual('9_35', 135 | variant_to_feature_name_fn({ 136 | encoder.CONTIG_COLUMN: 'chr9', 137 | encoder.START_COLUMN: 3500000 138 | })) 139 | self.assertEqual('9_40', 140 | variant_to_feature_name_fn({ 141 | encoder.CONTIG_COLUMN: 'chr9', 142 | encoder.START_COLUMN: 4000000 143 | })) 144 | 145 | def test_variant_to_words(self): 146 | variant_to_words_fn = variant_encoder.build_variant_to_words( 147 | add_hethom=False) 148 | self.assertEqual(['9_3500000_3500001_T_G'], 149 | variant_to_words_fn(HETEROZYGOUS_VARIANT_CALL)) 150 | self.assertEqual(['9_3500000_3500001_T_G'], 151 | variant_to_words_fn(HOMOZYGOUS_ALT_VARIANT_CALL)) 152 | self.assertEqual([], variant_to_words_fn(HOMOZYGOUS_REF_VARIANT_CALL)) 153 | 154 | def test_variant_to_words_add_het_hom(self): 155 | variant_to_words_fn = variant_encoder.build_variant_to_words() 156 | self.assertEqual(['9_3500000_3500001_T_G', '9_3500000_3500001_T_G_het'], 157 | variant_to_words_fn(HETEROZYGOUS_VARIANT_CALL)) 158 | self.assertEqual(['9_3500000_3500001_T_G', '9_3500000_3500001_T_G_hom'], 159 | variant_to_words_fn(HOMOZYGOUS_ALT_VARIANT_CALL)) 160 | self.assertEqual([], variant_to_words_fn(HOMOZYGOUS_REF_VARIANT_CALL)) 161 | 162 | 163 | if __name__ == '__main__': 164 | unittest.main() 165 | -------------------------------------------------------------------------------- /trainer/variants_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tensorflow implementation of variants inference.""" 15 | 16 | import functools 17 | import os 18 | import sys 19 | 20 | 21 | from absl import flags 22 | import tensorflow as tf 23 | 24 | from tensorflow.contrib.learn.python.learn import learn_runner 25 | from tensorflow.contrib.learn.python.learn.estimators import model_fn 26 | from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec 27 | from tensorflow.contrib.learn.python.learn.utils import input_fn_utils 28 | from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils 29 | from tensorflow.python.lib.io.tf_record import TFRecordCompressionType 30 | 31 | 32 | DEFAULT_OUTPUT_ALTERNATIVE = "output_alt" 33 | PREDICTION_KEY = "key" 34 | PREDICTION_EXAMPLES = "examples" 35 | logging = tf.logging 36 | metrics_lib = tf.contrib.metrics 37 | 38 | flags.DEFINE_float( 39 | "learning_rate", 0.001, "Learning rate.") 40 | flags.DEFINE_float( 41 | "momentum", 0.9, "Momentum.") 42 | flags.DEFINE_integer( 43 | "num_classes", None, "The number of classes on the dataset.") 44 | flags.DEFINE_integer( 45 | "hidden_units", None, "The number of hidden units on the hidden layer.") 46 | flags.DEFINE_integer( 47 | "num_buckets", 100000, "The number of buckets to use for hashing.") 48 | flags.DEFINE_integer( 49 | "embedding_dimension", 100, 50 | "The total embedding dimension is obtained by multiplying this number " 51 | "by the number of feature columns.") 52 | flags.DEFINE_integer( 53 | "batch_size", 10, "The size of the train and test batches.") 54 | flags.DEFINE_integer( 55 | "feature_queue_capacity", 5, "The size of the feature queue.") 56 | flags.DEFINE_string( 57 | "input_dir", None, "Path to the input files.") 58 | flags.DEFINE_string( 59 | "eval_dir", None, 60 | "If specified use a separate eval dataset. The dataset labels and label " 61 | "indexes should match exactly those of the training set for the evaluation " 62 | "to work correctly. In other words you only need to use this if you want a " 63 | " different train/eval split than the one provided by default.") 64 | flags.DEFINE_string( 65 | "export_dir", "", 66 | "The directory in which the saved model will be stored.") 67 | flags.DEFINE_string( 68 | "job-dir", "", 69 | "Base output directory. Used by the local and cloud jobs.") 70 | flags.DEFINE_boolean( 71 | "use_integerized_features", True, 72 | "Whether the features are int64 values.") 73 | flags.DEFINE_boolean( 74 | "use_gzip", True, 75 | "Whether the tfrecord files are compressed.") 76 | flags.DEFINE_integer( 77 | "num_train_steps", 10000, 78 | "Number of training iterations. None means continuous training.") 79 | flags.DEFINE_integer( 80 | "num_eval_steps", 10, 81 | "Number of evaluation iterations. When running continuous_eval, this is " 82 | "the number of eval steps run for each evaluation of a checkpoint.") 83 | flags.DEFINE_string( 84 | "target_field", None, 85 | "The name of the field that contains the labels.") 86 | flags.DEFINE_string( 87 | "id_field", "sample_name", 88 | "The name of the field that contains the sample ids.") 89 | flags.DEFINE_string( 90 | "sparse_features", None, 91 | "A list of the sparse features to process. For example " 92 | "variants_2,variants_17. 
Alternatively specify 'all_not_x_y' " 93 | "to indicate chromosomes 1 through 22.") 94 | flags.DEFINE_string( 95 | "eval_labels", "", 96 | "Optional, comma separated values of the labels used for" 97 | "per-class evaluation, the order should be the same as the one used " 98 | "when extracting the features.") 99 | 100 | FLAGS = flags.FLAGS 101 | FLAGS(sys.argv) 102 | 103 | SPARSE_FEATURE_NAMES = ["variants"] 104 | 105 | 106 | def _get_feature_names(): 107 | sparse_features = FLAGS.sparse_features.split(",") 108 | if sparse_features and sparse_features[0] == "all_not_x_y": 109 | return ["variants_%s" % ref for ref in range(1, 23)] 110 | else: 111 | return sparse_features 112 | 113 | 114 | def _get_eval_labels(): 115 | return enumerate(FLAGS.eval_labels.split(",")) if FLAGS.eval_labels else [] 116 | 117 | 118 | def _get_feature_columns(include_target_column): 119 | """Generates a tuple of `FeatureColumn` objects for our inputs. 120 | 121 | Args: 122 | include_target_column: Whether to include the target columns. 123 | 124 | Returns: 125 | Tuple of `FeatureColumn` objects. 126 | """ 127 | embedding_columns = [] 128 | for column_name in _get_feature_names(): 129 | if FLAGS.use_integerized_features: 130 | sparse_column = tf.contrib.layers.sparse_column_with_integerized_feature( 131 | column_name=column_name, 132 | bucket_size=FLAGS.num_buckets, 133 | combiner="sqrtn", 134 | dtype=tf.int64) 135 | else: 136 | sparse_column = tf.contrib.layers.sparse_column_with_hash_bucket( 137 | column_name=column_name, 138 | hash_bucket_size=FLAGS.num_buckets, 139 | combiner="sqrtn", 140 | dtype=tf.string) 141 | 142 | embedding = tf.contrib.layers.embedding_column( 143 | sparse_id_column=sparse_column, 144 | combiner="sqrtn", 145 | dimension=FLAGS.embedding_dimension) 146 | embedding_columns.append(embedding) 147 | feature_columns = tuple(sorted(embedding_columns)) 148 | if include_target_column: 149 | label_column = tf.contrib.layers.real_valued_column( 150 | FLAGS.target_field, dtype=tf.int64) 151 | feature_columns += (label_column,) 152 | return feature_columns 153 | 154 | 155 | def _build_input_fn(input_file_pattern, batch_size, mode): 156 | """Build input function. 157 | 158 | Args: 159 | input_file_pattern: The file patter for examples 160 | batch_size: Batch size 161 | mode: The execution mode, as defined in tf.contrib.learn.ModeKeys. 162 | 163 | Returns: 164 | Tuple, dictionary of feature column name to tensor and labels. 165 | """ 166 | def _input_fn(): 167 | """Supplies the input to the model. 168 | 169 | Returns: 170 | A tuple consisting of 1) a dictionary of tensors whose keys are 171 | the feature names, and 2) a tensor of target labels if the mode 172 | is not INFER (and None, otherwise). 
173 | """ 174 | logging.info("Reading files from %s", input_file_pattern) 175 | input_files = sorted(list(tf.gfile.Glob(input_file_pattern))) 176 | logging.info("Reading files from %s", input_files) 177 | include_target_column = (mode != tf.contrib.learn.ModeKeys.INFER) 178 | features_spec = tf.contrib.layers.create_feature_spec_for_parsing( 179 | feature_columns=_get_feature_columns(include_target_column)) 180 | 181 | if FLAGS.use_gzip: 182 | def gzip_reader(): 183 | return tf.TFRecordReader( 184 | options=tf.python_io.TFRecordOptions( 185 | compression_type=TFRecordCompressionType.GZIP)) 186 | reader_fn = gzip_reader 187 | else: 188 | reader_fn = tf.TFRecordReader 189 | 190 | features = tf.contrib.learn.io.read_batch_features( 191 | file_pattern=input_files, 192 | batch_size=batch_size, 193 | queue_capacity=3*batch_size, 194 | randomize_input=mode == tf.contrib.learn.ModeKeys.TRAIN, 195 | feature_queue_capacity=FLAGS.feature_queue_capacity, 196 | reader=reader_fn, 197 | features=features_spec) 198 | target = None 199 | if include_target_column: 200 | target = features.pop(FLAGS.target_field) 201 | return features, target 202 | 203 | return _input_fn 204 | 205 | 206 | def _predict_input_fn(): 207 | """Supplies the input to the model. 208 | 209 | Returns: 210 | A tuple consisting of 1) a dictionary of tensors whose keys are 211 | the feature names, and 2) a tensor of target labels if the mode 212 | is not INFER (and None, otherwise). 213 | """ 214 | feature_spec = tf.contrib.layers.create_feature_spec_for_parsing( 215 | feature_columns=_get_feature_columns(include_target_column=False)) 216 | 217 | feature_spec[FLAGS.id_field] = tf.FixedLenFeature([], dtype=tf.string) 218 | feature_spec[FLAGS.target_field + "_string"] = tf.FixedLenFeature( 219 | [], dtype=tf.string) 220 | 221 | # Add a placeholder for the serialized tf.Example proto input. 222 | examples = tf.placeholder(tf.string, shape=(None,), name="examples") 223 | 224 | features = tf.parse_example(examples, feature_spec) 225 | features[PREDICTION_KEY] = features[FLAGS.id_field] 226 | 227 | inputs = {PREDICTION_EXAMPLES: examples} 228 | 229 | return input_fn_utils.InputFnOps( 230 | features=features, labels=None, default_inputs=inputs) 231 | 232 | 233 | def _build_model_fn(): 234 | """Build model function. 235 | 236 | Returns: 237 | A model function that can be passed to `Estimator` constructor. 238 | """ 239 | def _model_fn(features, labels, mode): 240 | """Creates the prediction and its loss. 241 | 242 | Args: 243 | features: A dictionary of tensors keyed by the feature name. 244 | labels: A tensor representing the labels. 245 | mode: The execution mode, as defined in tf.contrib.learn.ModeKeys. 246 | 247 | Returns: 248 | A tuple consisting of the prediction, loss, and train_op. 249 | """ 250 | # Generate one embedding per sparse feature column and concatenate them. 251 | concat_embeddings = tf.contrib.layers.input_from_feature_columns( 252 | columns_to_tensors=features, 253 | feature_columns=_get_feature_columns(include_target_column=False)) 254 | 255 | # Add one hidden layer. 256 | hidden_layer_0 = tf.contrib.layers.relu( 257 | concat_embeddings, FLAGS.hidden_units) 258 | 259 | # Output and logistic loss. 
260 | logits = tf.contrib.layers.linear(hidden_layer_0, FLAGS.num_classes) 261 | 262 | predictions = tf.contrib.layers.softmax(logits) 263 | if mode == tf.contrib.learn.ModeKeys.INFER: 264 | predictions = { 265 | tf.contrib.learn.PredictionKey.PROBABILITIES: predictions, 266 | PREDICTION_KEY: features[PREDICTION_KEY] 267 | } 268 | output_alternatives = { 269 | DEFAULT_OUTPUT_ALTERNATIVE: (tf.contrib.learn.ProblemType.UNSPECIFIED, 270 | predictions) 271 | } 272 | return model_fn.ModelFnOps( 273 | mode=mode, 274 | predictions=predictions, 275 | output_alternatives=output_alternatives) 276 | 277 | target_one_hot = tf.one_hot(labels, FLAGS.num_classes) 278 | target_one_hot = tf.reduce_sum( 279 | input_tensor=target_one_hot, reduction_indices=[1]) 280 | loss = tf.losses.softmax_cross_entropy(target_one_hot, logits) 281 | if mode == tf.contrib.learn.ModeKeys.EVAL: 282 | return predictions, loss, None 283 | 284 | opt = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum) 285 | train_op = tf.contrib.layers.optimize_loss( 286 | loss=loss, 287 | global_step=tf.contrib.framework.get_global_step(), 288 | learning_rate=FLAGS.learning_rate, 289 | optimizer=opt) 290 | return model_fn.ModelFnOps( 291 | mode=mode, predictions=predictions, loss=loss, train_op=train_op) 292 | 293 | return _model_fn 294 | 295 | 296 | def _create_evaluation_metrics(): 297 | """Creates the evaluation metrics for the model. 298 | 299 | Returns: 300 | A dictionary with keys that are strings naming the evaluation 301 | metrics and values that are functions taking arguments of 302 | (predictions, targets), returning a tuple of a tensor of the metric's 303 | value together with an op to update the metric's value. 304 | """ 305 | eval_metrics = {} 306 | for k in [1]: 307 | eval_metrics["precision_at_%d" % k] = MetricSpec( 308 | metric_fn=functools.partial( 309 | tf.contrib.metrics.streaming_sparse_precision_at_k, k=k)) 310 | eval_metrics["recall_at_%d" % k] = MetricSpec(metric_fn=functools.partial( 311 | tf.contrib.metrics.streaming_sparse_recall_at_k, k=k)) 312 | 313 | for class_id, class_label in _get_eval_labels(): 314 | k = 1 315 | eval_metrics["precision_at_%d_%s" % (k, class_label)] = MetricSpec( 316 | metric_fn=functools.partial( 317 | tf.contrib.metrics.streaming_sparse_precision_at_k, 318 | k=k, 319 | class_id=class_id)) 320 | eval_metrics["recall_at_%d_%s" % (k, class_label)] = MetricSpec( 321 | metric_fn=functools.partial( 322 | tf.contrib.metrics.streaming_sparse_recall_at_k, 323 | k=k, 324 | class_id=class_id)) 325 | return eval_metrics 326 | 327 | 328 | def _def_experiment( 329 | train_file_pattern, eval_file_pattern, batch_size): 330 | """Creates the function used to configure the experiment runner. 331 | 332 | This function creates a function that is used by the learn_runner 333 | module to create an Experiment. 334 | 335 | Args: 336 | train_file_pattern: The directory the train data can be found in. 337 | eval_file_pattern: The directory the test data can be found in. 338 | batch_size: Batch size 339 | 340 | Returns: 341 | A function that creates an Experiment object for the runner. 342 | """ 343 | 344 | def _experiment_fn(output_dir): 345 | """Experiment function used by learn_runner to run training/eval/etc. 346 | 347 | Args: 348 | output_dir: String path of directory to use for outputs. 349 | 350 | Returns: 351 | tf.learn `Experiment`. 
352 | """ 353 | estimator = tf.contrib.learn.Estimator( 354 | model_fn=_build_model_fn(), 355 | model_dir=output_dir) 356 | train_input_fn = _build_input_fn( 357 | input_file_pattern=train_file_pattern, 358 | batch_size=batch_size, 359 | mode=tf.contrib.learn.ModeKeys.TRAIN) 360 | eval_input_fn = _build_input_fn( 361 | input_file_pattern=eval_file_pattern, 362 | batch_size=batch_size, 363 | mode=tf.contrib.learn.ModeKeys.EVAL) 364 | 365 | return tf.contrib.learn.Experiment( 366 | estimator=estimator, 367 | train_input_fn=train_input_fn, 368 | train_steps=FLAGS.num_train_steps, 369 | eval_input_fn=eval_input_fn, 370 | eval_steps=FLAGS.num_eval_steps, 371 | eval_metrics=_create_evaluation_metrics(), 372 | min_eval_frequency=100, 373 | export_strategies=[ 374 | saved_model_export_utils.make_export_strategy( 375 | _predict_input_fn, 376 | exports_to_keep=5, 377 | default_output_alternative_key=DEFAULT_OUTPUT_ALTERNATIVE) 378 | ]) 379 | 380 | return _experiment_fn 381 | 382 | 383 | def main(unused_argv): 384 | if not FLAGS.input_dir: 385 | raise ValueError("Input dir should be specified.") 386 | 387 | if FLAGS.eval_dir: 388 | train_file_pattern = os.path.join(FLAGS.input_dir, "examples*") 389 | eval_file_pattern = os.path.join(FLAGS.eval_dir, "examples*") 390 | else: 391 | train_file_pattern = os.path.join(FLAGS.input_dir, "examples*[0-7]-of-*") 392 | eval_file_pattern = os.path.join(FLAGS.input_dir, "examples*[89]-of-*") 393 | 394 | if not FLAGS.num_classes: 395 | raise ValueError("Number of classes should be specified.") 396 | 397 | if not FLAGS.sparse_features: 398 | raise ValueError("Name of the sparse features should be specified.") 399 | 400 | learn_runner.run( 401 | experiment_fn=_def_experiment( 402 | train_file_pattern, 403 | eval_file_pattern, 404 | FLAGS.batch_size), 405 | output_dir=FLAGS.export_dir) 406 | 407 | if __name__ == "__main__": 408 | tf.logging.set_verbosity(tf.logging.INFO) 409 | tf.app.run() 410 | --------------------------------------------------------------------------------