├── .github ├── pull_request_template.md └── workflows │ └── movieLens-py37.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── build.gradle ├── figures ├── AliceAnnie.png ├── gdmix-kubeflow-pipeline.png ├── gdmix-operation-models.png ├── gdmix-overview.png ├── gdmix-workflow-jobs.png └── logo.png ├── gdmix-data-all └── build.gradle ├── gdmix-data ├── build.gradle └── src │ ├── main │ ├── resources │ │ └── model │ │ │ └── lr_model.avsc │ └── scala │ │ └── com │ │ └── linkedin │ │ └── gdmix │ │ ├── configs │ │ ├── EffectConfig.scala │ │ └── TensorMetadata.scala │ │ ├── data │ │ ├── BestModelSelector.scala │ │ ├── DataPartitioner.scala │ │ ├── MetadataGenerator.scala │ │ └── OffsetUpdater.scala │ │ ├── evaluation │ │ └── Evaluator.scala │ │ ├── model │ │ └── LrModelSplitter.scala │ │ ├── parsers │ │ ├── BestModelSelectorParser.scala │ │ ├── DataPartitionerParser.scala │ │ ├── EffectConfigParser.scala │ │ ├── EvaluatorParser.scala │ │ ├── LrModelSplitterParser.scala │ │ └── OffsetUpdaterParser.scala │ │ └── utils │ │ ├── Constants.scala │ │ ├── ConversionUtils.scala │ │ ├── IoUtils.scala │ │ ├── JsonUtils.scala │ │ └── PartitionUtils.scala │ └── test │ ├── resources │ ├── configs │ │ ├── ConfigWithTwoFixedEffects.json │ │ ├── EffectConfigs.json │ │ ├── EntityNotInColumns.json │ │ ├── FeatureConvertor.json │ │ └── LabelNotInColumns.json │ ├── data │ │ ├── ExpectedGlobalTrainData.avro │ │ ├── ExpectedGlobalValidData.avro │ │ ├── ExpectedPerItemTrainData.avro │ │ ├── ExpectedPerItemValidData.avro │ │ ├── ExpectedPerMemberTrainData.avro │ │ ├── ExpectedPerMemberValidData.avro │ │ └── TrainData.avro │ └── metadata │ │ ├── ExpectedGlobalFeatureList.txt │ │ ├── ExpectedGlobalMetadata.json │ │ ├── ExpectedPerItemFeatureList.txt │ │ ├── ExpectedPerItemMetadata.json │ │ ├── ExpectedPerMemberFeatureList.txt │ │ └── ExpectedPerMemberMetadata.json │ └── scala │ └── com │ └── linkedin │ └── gdmix │ ├── configs │ └── EffectConfigTest.scala │ ├── data │ ├── BestModelSelectorTest.scala │ ├── DataPartitionerTest.scala │ ├── MetadataGeneratorTest.scala │ └── OffsetUpdaterTest.scala │ ├── evaluation │ └── EvaluatorTest.scala │ ├── model │ └── LrModelSplitterTest.scala │ ├── parsers │ ├── BestModelSelectorParserTest.scala │ ├── DataPartitionerParserTest.scala │ ├── EffectConfigParserTest.scala │ ├── EvaluatorParserTest.scala │ ├── LrModelSplitterParserTest.scala │ └── OffsetUpdaterParserTest.scala │ └── utils │ ├── ConversionUtilsTest.scala │ ├── SharedSparkSession.scala │ └── TestUtils.scala ├── gdmix-trainer ├── README.md ├── setup.cfg ├── setup.py ├── src │ └── gdmix │ │ ├── __init__.py │ │ ├── drivers │ │ ├── __init__.py │ │ ├── driver.py │ │ ├── fixed_effect_driver.py │ │ └── random_effect_driver.py │ │ ├── factory │ │ ├── __init__.py │ │ ├── driver_factory.py │ │ └── model_factory.py │ │ ├── gdmix.py │ │ ├── io │ │ ├── __init__.py │ │ ├── dataset_metadata.py │ │ └── input_data_pipeline.py │ │ ├── models │ │ ├── __init__.py │ │ ├── api.py │ │ ├── custom │ │ │ ├── __init__.py │ │ │ ├── base_lr_params.py │ │ │ ├── binary_logistic_regression.py │ │ │ ├── fixed_effect_lr_lbfgs_model.py │ │ │ ├── random_effect_lr_lbfgs_model.py │ │ │ └── scipy │ │ │ │ ├── __init__.py │ │ │ │ └── job_consumers.py │ │ ├── detext │ │ │ ├── __init__.py │ │ │ └── fixed_effect_detext_model.py │ │ ├── detext_writer.py │ │ └── schemas.py │ │ ├── params.py │ │ └── util │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── distribution_utils.py │ │ ├── io_utils.py │ │ └── model_utils.py └── test │ ├── drivers │ ├── 
__init__.py │ ├── test_driver.py │ └── test_helper.py │ ├── factory │ ├── __init__.py │ ├── test_driver_factory.py │ └── test_model_factory.py │ ├── io │ ├── test_dataset_metadata.py │ ├── test_entity_grouped_input_fn.py │ └── test_per_record_input_fn.py │ ├── models │ ├── __init__.py │ ├── custom │ │ ├── __init__.py │ │ ├── test_binary_logistic_regression.py │ │ ├── test_fixed_effect_lr_lbfgs_model.py │ │ ├── test_optimizer_helper.py │ │ └── test_random_effect_lr_lbfgs_model.py │ ├── detext │ │ ├── __init__.py │ │ └── test_detext.py │ └── test_model_api.py │ ├── resources │ ├── bert_config.json │ ├── custom │ │ └── sklearn_data.p │ ├── fe_lbfgs │ │ ├── featureList │ │ │ └── global │ │ ├── metadata │ │ │ └── tensor_metadata.json │ │ └── training_data │ │ │ └── test.tfrecord │ ├── grouped_per_member_train │ │ ├── data.json │ │ ├── data.tfrecord │ │ ├── data_intercept_only.json │ │ ├── data_with_string_entity_id.json │ │ ├── dataset_1.json │ │ ├── dataset_1_feature_file.csv │ │ └── fake_feature_file.csv │ ├── member_ids.avro │ ├── metadata │ │ ├── duplicated_names.json │ │ ├── features.txt │ │ ├── invalid_name.json │ │ ├── invalid_shape.json │ │ ├── invalid_type.json │ │ ├── partition_list.txt │ │ ├── tensor_metadata.json │ │ ├── valid.json │ │ └── valid_metadata.json │ ├── train │ │ └── dataset │ │ │ └── tfrecord │ │ │ └── test.tfrecord │ ├── validate │ │ └── data.avro │ └── vocab.txt │ └── util │ ├── test_distribution_utils.py │ ├── test_io_utils.py │ └── test_model_utils.py ├── gdmix-workflow ├── README.md ├── examples │ └── movielens-100k │ │ ├── detext-movieLens.yaml │ │ └── lr-movieLens.yaml ├── gdmix_config.md ├── images │ ├── gdmix_dev │ │ ├── Dockerfile │ │ └── build_image.sh │ └── launcher │ │ ├── common │ │ └── launch_crd.py │ │ ├── sparkapplication │ │ ├── Dockerfile │ │ ├── build_image.sh │ │ └── src │ │ │ └── launch_sparkapplication.py │ │ └── tfjob │ │ ├── Dockerfile │ │ ├── build_image.sh │ │ └── src │ │ └── launch_tfjob.py ├── setup.cfg ├── setup.py ├── src │ ├── conftest.py │ └── gdmixworkflow │ │ ├── __init__.py │ │ ├── common │ │ ├── __init__.py │ │ ├── constants.py │ │ └── utils.py │ │ ├── distributed │ │ ├── __init__.py │ │ ├── container_ops.py │ │ └── resource │ │ │ ├── __init__.py │ │ │ ├── sparkapplication_component.yaml │ │ │ └── tfjob_component.yaml │ │ ├── distributed_workflow.py │ │ ├── fixed_effect_workflow_generator.py │ │ ├── main.py │ │ ├── random_effect_workflow_generator.py │ │ ├── single_node │ │ ├── __init__.py │ │ └── local_ops.py │ │ ├── single_node_workflow.py │ │ └── workflow_generator.py └── test │ ├── common │ └── test_utils.py │ ├── resources │ ├── detext-movieLens.yaml │ └── lr-movieLens.yaml │ ├── single_node │ └── test_local_ops.py │ └── test_workflow_generator.py ├── gdmix.Dockerfile ├── gradle.properties ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── scripts └── download_process_movieLens_data.py └── settings.gradle /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. 4 | 5 | Fixes # (issue) 6 | 7 | ## Type of change 8 | 9 | Please delete options that are not relevant. 
10 | 11 | - [ ] Bug fix (non-breaking change which fixes an issue) 12 | - [ ] New feature (non-breaking change which adds functionality) 13 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 14 | 15 | ## List all changes 16 | Please list all changes in the commit. 17 | * change1 18 | * change2 19 | 20 | # Testing 21 | Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration. 22 | 23 | 24 | **Test Configuration**: 25 | * Firmware version: 26 | * Hardware: 27 | * Toolchain: 28 | * SDK: 29 | 30 | # Checklist 31 | 32 | - [ ] My code follows the style guidelines of this project 33 | - [ ] I have performed a self-review of my own code 34 | - [ ] I have commented my code, particularly in hard-to-understand areas 35 | - [ ] I have made corresponding changes to the documentation 36 | - [ ] My changes generate no new warnings 37 | - [ ] I have added tests that prove my fix is effective or that my feature works 38 | - [ ] New and existing unit tests pass locally with my changes 39 | - [ ] Any dependent changes have been merged and published in downstream modules 40 | -------------------------------------------------------------------------------- /.github/workflows/movieLens-py37.yml: -------------------------------------------------------------------------------- 1 | # This workflow installs Python dependencies and runs GDMix end to end on the movieLens data 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Gdmix movieLens workflow 5 | 6 | on: 7 | push: 8 | branches-ignore: 9 | - master 10 | jobs: 11 | testbox: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: [3.7] 16 | container: 17 | image: linkedin/gdmix-dev 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Build gdmix-data jar 25 | run: | 26 | ./gradlew shadowJar 27 | - name: Install GDMix dependencies and run unit tests 28 | run: | 29 | python -m pip install --upgrade pip 30 | pip install --upgrade setuptools pytest 31 | cd gdmix-trainer && pip install . && pytest && cd .. 32 | cd gdmix-workflow && pip install . && pytest && cd ..
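# The movieLens steps below assume the shadowJar task above produced the jar under build/gdmix-data-all_2.11/libs, which is what the --jar_path glob points at.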
33 | - name: Prepare movieLens data 34 | run: | 35 | pip install pandas numpy 36 | python scripts/download_process_movieLens_data.py 37 | - name: MovieLens logistic regression workflow 38 | run: | 39 | python -m gdmixworkflow.main --config_path gdmix-workflow/examples/movielens-100k/lr-movieLens.yaml --jar_path build/gdmix-data-all_2.11/libs/gdmix-data-all_2.11-*.jar 40 | - name: MovieLens DeText workflow 41 | run: | 42 | python -m gdmixworkflow.main --config_path gdmix-workflow/examples/movielens-100k/detext-movieLens.yaml --jar_path build/gdmix-data-all_2.11/libs/gdmix-data-all_2.11-*.jar 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /out 2 | *.egg 3 | *.egg-info/ 4 | *.iml 5 | *.ipr 6 | *.iws 7 | *.pyc 8 | *.pyo 9 | *.sublime-* 10 | .*.swo 11 | .*.swp 12 | .cache/ 13 | .coverage 14 | .direnv/ 15 | .env 16 | .envrc 17 | .gradle/ 18 | .idea/ 19 | .tox* 20 | .venv* 21 | /*/*pinned.txt 22 | /*/MANIFEST 23 | /*/activate 24 | /*/build/ 25 | /*/config 26 | /*/coverage.xml 27 | /*/dist/ 28 | /*/htmlcov/ 29 | /*/product-spec.json 30 | /build/ 31 | /config/ 32 | /dist/ 33 | /ligradle/ 34 | TEST-*.xml 35 | .DS_Store 36 | venv/ 37 | .vscode/ 38 | .project 39 | .settings/ 40 | __pycache__ 41 | .cache 42 | 43 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Contribution Agreement 2 | ====================== 3 | 4 | As a contributor, you represent that the code you submit is your original work or 5 | that of your employer (in which case you represent you have the right to bind your 6 | employer). By submitting code, you (and, if applicable, your employer) are 7 | licensing the submitted code to LinkedIn and the open source community subject 8 | to the BSD 2-Clause license. 9 | 10 | Responsible Disclosure of Security Vulnerabilities 11 | ================================================== 12 | 13 | **Do not file an issue on Github for security issues.** Please review 14 | the [guidelines for disclosure][disclosure_guidelines]. Reports should 15 | be encrypted using PGP ([public key][pubkey]) and sent to 16 | [security@linkedin.com][disclosure_email] preferably with the title 17 | "Vulnerability in Github LinkedIn/GDMix - <short summary>". 18 | 19 | Tips for Getting Your Pull Request Accepted 20 | =========================================== 21 | 22 | 1. Make sure all new features are tested and the tests pass. 23 | 2. Bug fixes must include a test case demonstrating the error that it fixes. 24 | 3. Open an issue first and seek advice for your change before submitting 25 | a pull request. Large features which have never been discussed are 26 | unlikely to be accepted. **You have been warned.** 27 | 28 | [disclosure_guidelines]: https://www.linkedin.com/help/linkedin/answer/62924 29 | [pubkey]: https://www.linkedin.com/help/linkedin/answer/79676 30 | [disclosure_email]: mailto:security@linkedin.com?subject=Vulnerability%20in%20Github%20LinkedIn/GDMix%20-%20%3Csummary%3E 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2020 LinkedIn Corporation All Rights Reserved. 
2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 10 | -------------------------------------------------------------------------------- /build.gradle: -------------------------------------------------------------------------------- 1 | buildscript { 2 | repositories { 3 | jcenter() 4 | } 5 | } 6 | 7 | plugins { 8 | id "maven-publish" 9 | id "com.jfrog.bintray" version "1.7.3" 10 | id 'com.github.johnrengelman.shadow' version '2.0.1' 11 | } 12 | 13 | allprojects { 14 | apply plugin: 'eclipse' 15 | apply plugin: 'idea' 16 | repositories { 17 | jcenter() 18 | } 19 | } 20 | 21 | // The gradle variables defined here are visible in sub-projects 22 | ext { 23 | sparkVersion = '2.4.7' 24 | scalaSuffix = "_2.11" 25 | gdmixDataVersion = "0.4.0" 26 | pomConfig = { 27 | licenses { 28 | license { 29 | name "The 2-Clause BSD License" 30 | url "https://opensource.org/licenses/BSD-2-Clause" 31 | distribution "repo" 32 | } 33 | } 34 | developers { 35 | developer { 36 | id "jshi" 37 | name "Jun Shi" 38 | email "jshi@linkedin.com" 39 | } 40 | developer { 41 | id "mizhou-in" 42 | name "Mingzhou Zhou" 43 | email "mizhou@linkedin.com" 44 | } 45 | } 46 | scm { 47 | url "https://github.com/linkedin/gdmix" 48 | } 49 | } 50 | } 51 | 52 | subprojects { 53 | // Put the build dir into the rootProject 54 | buildDir = "../build/$name" 55 | 56 | tasks.withType(Jar) { 57 | version "${gdmixDataVersion}" 58 | } 59 | 60 | plugins.withType(JavaPlugin) { 61 | tasks.withType(Test) { 62 | useTestNG() 63 | 64 | // Exclude tests (ex. 
gradle test -Pexclude=SomeTestClass) 65 | def excludedTests = project.properties['exclude'] 66 | if (excludedTests) { 67 | excludedTests.replaceAll('\\s', '').split('[,]').each { 68 | exclude "**/${it}.class" 69 | } 70 | } 71 | 72 | afterSuite { desc, result -> 73 | if (!desc.parent) { 74 | println ":${project.name} -- Executed ${result.testCount} tests: ${result.successfulTestCount} succeeded, ${result.failedTestCount} failed, ${result.skippedTestCount} skipped" 75 | } 76 | } 77 | 78 | // Forward standard out from child JVMs to the console 79 | testLogging { 80 | showStackTraces = true 81 | showStandardStreams = true 82 | showExceptions = true 83 | showCauses = true 84 | displayGranularity = maxGranularity 85 | exceptionFormat = 'full' 86 | } 87 | 88 | outputs.upToDateWhen { false } 89 | 90 | systemProperty "log4j.configuration", "file:${project.rootDir}/log4j.properties" 91 | 92 | minHeapSize = "2G" 93 | maxHeapSize = "8G" 94 | } 95 | 96 | dependencies { 97 | testCompile 'org.testng:testng:6.10' 98 | } 99 | 100 | sourceCompatibility = 1.8 101 | } 102 | 103 | tasks.withType(ScalaCompile) { 104 | scalaCompileOptions.additionalParameters = ["-feature", "-deprecation", "-verbose", "-optimize", "-unchecked", "-Yinline-warnings", "-g:vars"] 105 | 106 | configure(scalaCompileOptions.forkOptions) { 107 | memoryMaximumSize = '1g' 108 | } 109 | configurations.zinc.transitive = true 110 | 111 | } 112 | 113 | idea { 114 | module { 115 | testSourceDirs += file('src/test/scala') 116 | } 117 | } 118 | 119 | configurations.all { 120 | resolutionStrategy { 121 | force 'com.fasterxml.jackson.core:jackson-databind:2.6.7' 122 | } 123 | } 124 | 125 | // This task lets you get the dependencies of all the sub-projects from the main directory 126 | // (otherwise, you have to change to each sub-project directory to get its dependencies) 127 | task allDeps(type: DependencyReportTask) {} 128 | } 129 | -------------------------------------------------------------------------------- /figures/AliceAnnie.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/figures/AliceAnnie.png -------------------------------------------------------------------------------- /figures/gdmix-kubeflow-pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/figures/gdmix-kubeflow-pipeline.png -------------------------------------------------------------------------------- /figures/gdmix-operation-models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/figures/gdmix-operation-models.png -------------------------------------------------------------------------------- /figures/gdmix-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/figures/gdmix-overview.png -------------------------------------------------------------------------------- /figures/gdmix-workflow-jobs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/figures/gdmix-workflow-jobs.png
-------------------------------------------------------------------------------- /figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/figures/logo.png -------------------------------------------------------------------------------- /gdmix-data-all/build.gradle: -------------------------------------------------------------------------------- 1 | apply plugin: 'scala' 2 | apply plugin: 'com.github.johnrengelman.shadow' 3 | apply plugin: 'maven-publish' 4 | apply plugin: 'com.jfrog.bintray' 5 | 6 | configurations { 7 | all*.exclude group: 'org.eclipse.jetty' 8 | } 9 | 10 | dependencies { 11 | compile project(":gdmix-data$scalaSuffix") 12 | } 13 | 14 | jar.enabled = false 15 | 16 | shadowJar { 17 | // Remove the classifier as we want the shadow jar to be the main jar 18 | classifier = '' 19 | configurations = [project.configurations.runtime] 20 | 21 | mergeServiceFiles() 22 | 23 | relocate 'org.json4s', 'gdmix.shaded.org.json4s' 24 | } 25 | 26 | task sourcesJar(type: Jar, dependsOn: classes) { 27 | classifier = 'sources' 28 | from sourceSets.main.allSource 29 | } 30 | 31 | javadoc.failOnError = false 32 | task javadocJar(type: Jar, dependsOn: javadoc) { 33 | classifier = 'javadoc' 34 | from javadoc.destinationDir 35 | } 36 | 37 | artifacts { 38 | archives shadowJar 39 | archives sourcesJar 40 | archives javadocJar 41 | } 42 | 43 | publishing { 44 | publications { 45 | mavenPublication(MavenPublication) { 46 | from components.java 47 | artifact sourcesJar { 48 | classifier "sources" 49 | } 50 | artifact javadocJar { 51 | classifier "javadoc" 52 | } 53 | groupId 'com.linkedin.gdmix' 54 | artifactId 'gdmix-data-all_2.11' 55 | version "${gdmixDataVersion}" 56 | pom.withXml { 57 | def root = asNode() 58 | root.appendNode('description', 'A data processing library of the deep learning personalization framework GDMix') 59 | root.appendNode('name', 'gdmix-data-all') 60 | root.appendNode('url', 'https://github.com/linkedin/gdmix') 61 | root.children().last() + pomConfig 62 | } 63 | } 64 | } 65 | } 66 | 67 | bintray { 68 | user = System.getenv('BINTRAY_USER') 69 | key = System.getenv('BINTRAY_KEY') 70 | publications = ['mavenPublication'] 71 | publish = true 72 | // dryRun = true 73 | 74 | pkg { 75 | repo = 'maven' 76 | user = System.getenv('BINTRAY_USER') 77 | name = 'gdmix-data-all' 78 | userOrg = 'linkedin' 79 | licenses = ['BSD 2-Clause'] 80 | desc = 'A data processing library of the deep learning personalization framework GDMix' 81 | websiteUrl = 'https://github.com/linkedin/gdmix' 82 | vcsUrl = 'https://github.com/linkedin/gdmix' 83 | version { 84 | name = "${gdmixDataVersion}" 85 | desc = 'A data processing library of the deep learning personalization framework GDMix' 86 | } 87 | } 88 | } 89 | 90 | -------------------------------------------------------------------------------- /gdmix-data/build.gradle: -------------------------------------------------------------------------------- 1 | apply plugin: "scala" 2 | apply plugin: 'maven-publish' 3 | apply plugin: 'com.jfrog.bintray' 4 | 5 | configurations { 6 | all*.exclude group: "org.eclipse.jetty" 7 | } 8 | 9 | dependencies { 10 | compile "com.databricks:spark-avro$scalaSuffix:3.2.0" 11 | compile "com.github.scopt:scopt$scalaSuffix:4.0.0-RC2" 12 | compile "com.linkedin.sparktfrecord:spark-tfrecord$scalaSuffix:0.2.1" 13 | compile "org.apache.spark:spark-avro$scalaSuffix:2.4.4" 14 | compile 
"org.json4s:json4s-core$scalaSuffix:3.3.0" 15 | compile "org.json4s:json4s-jackson$scalaSuffix:3.3.0" 16 | compile "org.json4s:json4s-ext$scalaSuffix:3.3.0" 17 | compile "org.json4s:json4s-ast$scalaSuffix:3.3.0" 18 | 19 | compileOnly "com.fasterxml.jackson.core:jackson-databind:2.6.7.1" 20 | compileOnly "org.apache.spark:spark-core$scalaSuffix:$sparkVersion" 21 | compileOnly "org.apache.spark:spark-sql$scalaSuffix:$sparkVersion" 22 | compileOnly "org.apache.spark:spark-mllib$scalaSuffix:$sparkVersion" 23 | 24 | testCompile "com.fasterxml.jackson.module:jackson-module-paranamer:2.6.7" 25 | testCompile "org.apache.avro:avro-mapred:1.7.7:hadoop2" 26 | testCompile "org.apache.spark:spark-mllib$scalaSuffix:$sparkVersion" 27 | testCompile "org.apache.spark:spark-sql$scalaSuffix:$sparkVersion" 28 | } 29 | 30 | test { 31 | useTestNG() 32 | } 33 | 34 | task sourcesJar(type: Jar, dependsOn: classes) { 35 | classifier = 'sources' 36 | from sourceSets.main.allSource 37 | } 38 | 39 | javadoc.failOnError = false 40 | task javadocJar(type: Jar, dependsOn: javadoc) { 41 | classifier = 'javadoc' 42 | from javadoc.destinationDir 43 | } 44 | 45 | artifacts { 46 | archives sourcesJar 47 | archives javadocJar 48 | archives jar 49 | } 50 | 51 | publishing { 52 | publications { 53 | mavenPublication(MavenPublication) { 54 | from components.java 55 | artifact sourcesJar { 56 | classifier "sources" 57 | } 58 | artifact javadocJar { 59 | classifier "javadoc" 60 | } 61 | groupId 'com.linkedin.gdmix' 62 | artifactId 'gdmix-data_2.11' 63 | version "${gdmixDataVersion}" 64 | pom.withXml { 65 | def root = asNode() 66 | root.appendNode('description', 'A data processing library of the deep learning personalization framework GDMix') 67 | root.appendNode('name', 'gdmix-data') 68 | root.appendNode('url', 'https://github.com/linkedin/gdmix') 69 | root.children().last() + pomConfig 70 | } 71 | } 72 | } 73 | } 74 | 75 | bintray { 76 | user = System.getenv('BINTRAY_USER') 77 | key = System.getenv('BINTRAY_KEY') 78 | publications = ['mavenPublication'] 79 | publish = true 80 | // dryRun = true 81 | 82 | pkg { 83 | repo = 'maven' 84 | user = System.getenv('BINTRAY_USER') 85 | name = 'gdmix-data' 86 | userOrg = 'linkedin' 87 | licenses = ['BSD 2-Clause'] 88 | desc = 'A data processing library of the deep learning personalization framework GDMix' 89 | websiteUrl = 'https://github.com/linkedin/gdmix' 90 | vcsUrl = 'https://github.com/linkedin/gdmix' 91 | version { 92 | name = "${gdmixDataVersion}" 93 | desc = 'A data processing library of the deep learning personalization framework GDMix' 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /gdmix-data/src/main/resources/model/lr_model.avsc: -------------------------------------------------------------------------------- 1 | { 2 | "name": "NameTermValueAvro", 3 | "namespace": "com.linkedin.photon.avro.generated", 4 | "type": "record", 5 | "doc": "A tuple of name, term and value. 
Used as a feature or model coefficient", 6 | "fields": [ 7 | { 8 | "name": "name", 9 | "type": "string" 10 | }, 11 | { 12 | "name": "term", 13 | "type": "string" 14 | }, 15 | { 16 | "name": "value", 17 | "type": "double" 18 | } 19 | ] 20 | } 21 | 22 | { 23 | "name": "BayesianLinearModelAvro", 24 | "namespace": "com.linkedin.photon.avro.generated", 25 | "type": "record", 26 | "doc": "a generic schema to describe a Bayesian linear model with means and variances", 27 | "fields": [ 28 | { 29 | "name": "modelId", 30 | "type": "string" 31 | }, 32 | { 33 | "default": null, 34 | "name": "modelClass", 35 | "type": [ 36 | "null", 37 | "string" 38 | ], 39 | "doc": "The fully-qualified class name of enclosing GLM model class. E.g.: com.linkedin.photon.ml.supervised.classification.LogisticRegressionModel" 40 | }, 41 | { 42 | "name": "means", 43 | "type": { 44 | "items": "NameTermValueAvro", 45 | "type": "array" 46 | } 47 | }, 48 | { 49 | "default": null, 50 | "name": "variances", 51 | "type" : [ 52 | "null", 53 | { 54 | "items" : "NameTermValueAvro", 55 | "type" : "array" 56 | } 57 | ] 58 | }, 59 | { 60 | "default": null, 61 | "name": "lossFunction", 62 | "type": [ 63 | "null", 64 | "string" 65 | ], 66 | "doc": "The loss function used for training as the class name. E.g.: com.linkedin.photon.ml.function.LogisticLossFunction" 67 | } 68 | ] 69 | } 70 | -------------------------------------------------------------------------------- /gdmix-data/src/main/scala/com/linkedin/gdmix/configs/EffectConfig.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.configs 2 | 3 | import com.fasterxml.jackson.core.`type`.TypeReference 4 | import com.fasterxml.jackson.module.scala.JsonScalaEnumeration 5 | 6 | /** 7 | * Enumeration to represent the different data types of tensors 8 | */ 9 | object DataType extends Enumeration { 10 | type DataType = Value 11 | val string, int, long, double, float, byte = Value 12 | } 13 | 14 | /** 15 | * This class is a workaround for Scala's Enumeration so that we can use 16 | * DataType.DataType as a type in the definition. See the example 17 | * in the ColumnConfig definition below, where the member "dtype" is defined as 18 | * DataType.DataType. 19 | * For details: 20 | * https://github.com/FasterXML/jackson-module-scala/wiki/Enumerations 21 | */ 22 | class DataTypeRef extends TypeReference[DataType.type] 23 | 24 | /** 25 | * Case class for a fixed or random effect config 26 | * 27 | * @param isRandomEffect Whether this is a random effect. 28 | * @param coordinateName Coordinate name for a fixed effect or a random effect. 29 | * @param perEntityName Entity name that the random effect is based on; None for a fixed effect. 30 | * @param labels A sequence of label column names. 31 | * @param columnConfigList A sequence of column configs. 32 | */ 33 | case class EffectConfig( 34 | isRandomEffect: Boolean, 35 | coordinateName: String, 36 | perEntityName: Option[String] = None, 37 | labels: Option[Seq[String]] = None, 38 | columnConfigList: Seq[ColumnConfig] 39 | ) extends Ordered[EffectConfig] { 40 | require( 41 | columnConfigList.nonEmpty, 42 | s"Please specify at least one column" 43 | ) 44 | if (isRandomEffect) { 45 | require(!perEntityName.isEmpty) 46 | } else { 47 | require(perEntityName.isEmpty) 48 | } 49 | 50 | // We want the configs to be sorted such that the fixed effect precedes random effects. 51 | // This property is used in the name-term-value to sparse/dense tensor conversion function.
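// For example, Seq(randomEffectConfig, fixedEffectConfig).sorted yields the fixed effect first, since isRandomEffect is false for the fixed effect and Boolean false sorts before true.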
52 | def compare(that: EffectConfig) = this.isRandomEffect.compare(that.isRandomEffect) 53 | } 54 | 55 | /** 56 | * Case class for column config 57 | * 58 | * @param name Column name. 59 | * @param dtype Intended data type after conversion. 60 | * @param shape The data shape of the column. 61 | * @param isInputNTV Whether this is in name-term-value format. 62 | * @param isOutputSparse Whether the output tensor should be a sparse tensor. 63 | * @param sharedFeatureSpace Whether the name-term features share the feature space; only used 64 | * for random effect data conversion. 65 | */ 66 | 67 | case class ColumnConfig( 68 | name: String, 69 | @JsonScalaEnumeration(classOf[DataTypeRef]) dtype: DataType.DataType, 70 | shape: Seq[Int] = Seq(), 71 | isInputNTV: Option[Boolean] = None, 72 | isOutputSparse: Option[Boolean] = None, 73 | sharedFeatureSpace: Option[Boolean] = None) 74 | -------------------------------------------------------------------------------- /gdmix-data/src/main/scala/com/linkedin/gdmix/configs/TensorMetadata.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.configs 2 | 3 | import com.fasterxml.jackson.module.scala.JsonScalaEnumeration 4 | 5 | /** 6 | * Case class for the dataset metadata 7 | * 8 | * @param features Tensor metadata of a sequence of features 9 | * @param labels Tensor metadata of a sequence of labels 10 | */ 11 | case class DatasetMetadata( 12 | features: Seq[TensorMetadata], 13 | labels: Option[Seq[TensorMetadata]] = None) 14 | 15 | /** 16 | * Case class for the tensor metadata 17 | * 18 | * @param name Name of a tensor 19 | * @param dtype Data type of a tensor 20 | * @param shape Shape of a tensor 21 | * @param isSparse Whether it is a sparse tensor 22 | */ 23 | case class TensorMetadata( 24 | name: String, 25 | @JsonScalaEnumeration(classOf[DataTypeRef]) dtype: DataType.DataType, 26 | shape: Seq[Int], 27 | isSparse: Boolean = false 28 | ) 29 | -------------------------------------------------------------------------------- /gdmix-data/src/main/scala/com/linkedin/gdmix/evaluation/Evaluator.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.evaluation 2 | 3 | import com.databricks.spark.avro._ 4 | import org.apache.hadoop.fs.{FileSystem, Path} 5 | import org.apache.hadoop.mapred.JobConf 6 | import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, RegressionMetrics} 7 | import org.apache.spark.sql.{DataFrame, SparkSession} 8 | import org.apache.spark.sql.functions.col 9 | 10 | import com.linkedin.gdmix.parsers.EvaluatorParams 11 | import com.linkedin.gdmix.parsers.EvaluatorParser 12 | import com.linkedin.gdmix.utils.Constants._ 13 | import com.linkedin.gdmix.utils.{IoUtils, JsonUtils} 14 | 15 | 16 | /** 17 | * Metric evaluator. 18 | */ 19 | object Evaluator { 20 | /** 21 | * Compute the evaluation metric based on the metric name. 22 | * 23 | * @param df Input data frame 24 | * @param labelName Name of the label in the dataframe 25 | * @param scoreName Name of the score in the dataframe 26 | * @param metricName Name of the evaluation metric 27 | * @return evaluation metric (e.g., area under the ROC curve, mean squared error, etc.) 28 | */ 29 | def calculateMetric(df: DataFrame, labelName: String, scoreName: String, metricName: String): Double = { 30 | // Cast the columns.
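// The avro input may store the label and score as int or float; BinaryClassificationMetrics and RegressionMetrics expect an RDD[(Double, Double)] of (score, label) pairs, so both columns are cast to double first.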
31 | val scoreLabelDF = df.withColumn(scoreName, col(scoreName).cast("double")) 32 | .withColumn(labelName, col(labelName).cast("double")) 33 | .select(scoreName, labelName) 34 | 35 | // Map to (score, label). 36 | val scoreAndLabels = scoreLabelDF.rdd.map(row => (row.getDouble(0), row.getDouble(1))) 37 | 38 | // Compute evaluation metric. 39 | val metric = metricName match { 40 | case AUC => new BinaryClassificationMetrics(scoreAndLabels).areaUnderROC() 41 | case MSE => new RegressionMetrics(scoreAndLabels).meanSquaredError 42 | case _ => throw new IllegalArgumentException(s"Unsupported metric ${metricName}; currently only 'auc' and 'mse' are supported.") 43 | } 44 | metric 45 | } 46 | 47 | def main(args: Array[String]): Unit = { 48 | 49 | val params = EvaluatorParser.parse(args) 50 | 51 | // Create a Spark session. 52 | val spark: SparkSession = SparkSession 53 | .builder() 54 | .appName(getClass.getName) 55 | .getOrCreate() 56 | 57 | try { 58 | run(spark, params) 59 | } finally { 60 | spark.stop() 61 | } 62 | } 63 | 64 | def run(spark: SparkSession, params: EvaluatorParams): Unit = { 65 | 66 | // Read the input file; the label and score columns are cast to double in calculateMetric. 67 | val df = spark.read.avro(params.metricsInputDir) 68 | 69 | // Compute evaluation metric. 70 | val metric = calculateMetric(df, params.labelColumnName, params.predictionColumnName, params.metricName) 71 | 72 | // Set up Hadoop file system. 73 | val hadoopJobConf = new JobConf() 74 | val fs: FileSystem = FileSystem.get(hadoopJobConf) 75 | 76 | // Convert to json and save to HDFS. 77 | val jsonResult = JsonUtils.toJsonString(Map(params.metricName -> metric)) 78 | IoUtils.writeFile(fs, new Path(params.outputMetricFile, EVAL_SUMMARY_JSON), jsonResult) 79 | } 80 | } -------------------------------------------------------------------------------- /gdmix-data/src/main/scala/com/linkedin/gdmix/model/LrModelSplitter.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.model 2 | 3 | import org.apache.avro.Schema 4 | import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} 5 | import org.apache.spark.sql.functions._ 6 | import com.databricks.spark.avro._ 7 | 8 | import com.linkedin.gdmix.parsers.LrModelSplitterParser 9 | import com.linkedin.gdmix.parsers.LrModelSplitterParams 10 | import com.linkedin.gdmix.utils.Constants._ 11 | import com.linkedin.gdmix.utils.ConversionUtils.{NameTermValueOptionDouble, splitModelIdUdf} 12 | 13 | import java.io.File 14 | import scala.util.{Success, Try} 15 | 16 | /** 17 | * Split a crossed global model into multiple random effect models. 18 | * The input model file contains a single model with crossed feature names, e.g. 19 | * model_1_gdmixcross_feature_1, model_1_gdmixcross_feature_2, 20 | * model_2_gdmixcross_feature_3, model_2_gdmixcross_feature_4, 21 | * model_3_gdmixcross_feature_5, model_3_gdmixcross_feature_6 22 | * 23 | * The resulting model files contain the following models: 24 | * model_1: 25 | * feature_1 26 | * feature_2 27 | * model_2: 28 | * feature_3 29 | * feature_4 30 | * model_3: 31 | * feature_5 32 | * feature_6 33 | */ 34 | object LrModelSplitter { 35 | 36 | val LR_MODEL_SCHEMA_FILE = "model/lr_model.avsc" 37 | 38 | def main(args: Array[String]): Unit = { 39 | 40 | val params = LrModelSplitterParser.parse(args) 41 | 42 | // Create a Spark session.
43 | val spark = SparkSession.builder().appName(getClass.getName).getOrCreate() 44 | try { 45 | run(spark, params) 46 | } finally { 47 | spark.stop() 48 | } 49 | } 50 | 51 | def run(spark: SparkSession, params: LrModelSplitterParams): Unit = { 52 | 53 | // Parse the commandline option. 54 | val modelInputDir = params.modelInputDir 55 | val modelOutputDir = params.modelOutputDir 56 | val numOutputFiles = params.numOutputFiles 57 | 58 | val df = spark.read.avro(modelInputDir) 59 | val means = splitModelId(MEANS, df) 60 | val hasVariances = Try(df.first().getAs[Seq[NameTermValueOptionDouble]](VARIANCES)) match { 61 | case Success(value) if value != null => true 62 | case _ => false 63 | } 64 | 65 | // append variances column 66 | val meansAndVariances = if (hasVariances) { 67 | val variances = splitModelId(VARIANCES, df) 68 | means.join(variances, MODEL_ID) 69 | } else { 70 | means.withColumn(VARIANCES, typedLit[Option[NameTermValueOptionDouble]](None)) 71 | } 72 | 73 | // append other columns 74 | val outDf = meansAndVariances 75 | .withColumn("modelClass", typedLit[String](LR_MODEL_CLASS)) 76 | .withColumn("lossFunction", typedLit[String]("")) 77 | 78 | val schema = new Schema.Parser().parse( 79 | getClass.getClassLoader.getResourceAsStream(LR_MODEL_SCHEMA_FILE)) 80 | 81 | outDf.repartition(numOutputFiles).write.option("forceSchema", schema.toString) 82 | .mode(SaveMode.Overwrite).format(AVRO_FORMAT).save(modelOutputDir) 83 | } 84 | 85 | /** 86 | * Separate the model Id from feature names. 87 | * Break a single model into multiple smaller models identified by their model Ids. 88 | * @param colName: the name of the column that has all coefficients of the global model 89 | * @param df: the input dataframe. 90 | * @return: a dataframe where each row is the coefficients of a separated model. 91 | */ 92 | private[model] def splitModelId(colName: String, df: DataFrame): DataFrame = { 93 | df.select(explode(col(colName)).alias("explodeCol")) 94 | .withColumn("splitCol", splitModelIdUdf(col("explodeCol"))) 95 | .select("splitCol.*") 96 | .withColumnRenamed("_1", MODEL_ID) 97 | .withColumnRenamed("_2", colName) 98 | .groupBy(MODEL_ID) 99 | .agg(collect_list(col(colName)).alias(colName)) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /gdmix-data/src/main/scala/com/linkedin/gdmix/parsers/BestModelSelectorParser.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.parsers 2 | 3 | import com.linkedin.gdmix.utils.Constants._ 4 | import com.linkedin.gdmix.utils.IoUtils 5 | 6 | /** 7 | * Parameters for best model selector job. 8 | */ 9 | case class BestModelSelectorParams( 10 | inputMetricsPaths: Seq[String], 11 | outputBestModelPath: String, 12 | evalMetric: String, 13 | hyperparameters: String, 14 | inputModelPaths: Option[String] = None, 15 | outputBestMetricsPath: Option[String] = None, 16 | copyBestOutput: Boolean = false 17 | ) 18 | 19 | /** 20 | * Parser for best model selector job. 21 | */ 22 | object BestModelSelectorParser { 23 | private val bestModelSelectorParser = new scopt.OptionParser[BestModelSelectorParams]( 24 | "Parsing command line for best model selector job.") { 25 | 26 | opt[String]("inputMetricsPaths").action((x, p) => p.copy( 27 | inputMetricsPaths = x.split(CONFIG_SPLITTER).map(_.trim))) 28 | .required 29 | .text( 30 | """Required. 
31 | |Input model metric paths, separated by semicolons.""".stripMargin) 32 | 33 | 34 | opt[String]("outputBestModelPath").action((x, p) => p.copy(outputBestModelPath = x.trim)) 35 | .required 36 | .text( 37 | """Required. 38 | |Output best model path.""".stripMargin) 39 | 40 | opt[String]("evalMetric").action((x, p) => p.copy(evalMetric = x.trim)) 41 | .required 42 | .text( 43 | """Required. 44 | |Evaluation metric.""".stripMargin) 45 | 46 | opt[String]("hyperparameters").action((x, p) => p.copy(hyperparameters = x.trim)) 47 | .required 48 | .text( 49 | """Required. 50 | |Hyper-parameters of each model encoded in base64.""".stripMargin) 51 | 52 | opt[String]("inputModelPaths").action((x, p) => p.copy(inputModelPaths = if (x.trim.isEmpty) None else Some(x.trim))) 53 | .optional 54 | .text( 55 | """Optional. 56 | |Input model paths, separated by semicolons.""".stripMargin) 57 | 58 | opt[String]("outputBestMetricsPath").action((x, p) => p.copy(outputBestMetricsPath = if (x.trim.isEmpty) None else Some(x.trim))) 59 | .optional 60 | .text( 61 | """Optional. 62 | |Path to the best model metric.""".stripMargin) 63 | 64 | opt[String]("copyBestOutput").action((x, p) => p.copy(copyBestOutput = x.toLowerCase == "true")) 65 | .optional 66 | .text( 67 | """Optional. 68 | |Whether to copy the best model output.""".stripMargin) 69 | 70 | checkConfig(p => 71 | if (p.copyBestOutput) { 72 | if (IoUtils.isEmptyStr(p.inputModelPaths)) { 73 | failure("Option --inputModelPaths is required when --copyBestOutput is true.") 74 | } 75 | 76 | else if (IoUtils.isEmptyStr(p.outputBestMetricsPath)) { 77 | failure("Option --outputBestMetricsPath is required when --copyBestOutput is true.") 78 | } 79 | else success 80 | } 81 | else success) 82 | } 83 | 84 | def parse(args: Seq[String]): BestModelSelectorParams = { 85 | val emptyBestModelSelectorParams = BestModelSelectorParams( 86 | inputMetricsPaths = Seq(""), 87 | outputBestModelPath = "", 88 | evalMetric = "", 89 | hyperparameters = "" 90 | ) 91 | bestModelSelectorParser.parse(args, emptyBestModelSelectorParams) match { 92 | case Some(params) => params 93 | case None => throw new IllegalArgumentException( 94 | s"Parsing the command line arguments failed.\n" + 95 | s"(${args.mkString(", ")}),\n${bestModelSelectorParser.usage}") 96 | } 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /gdmix-data/src/main/scala/com/linkedin/gdmix/parsers/EffectConfigParser.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.parsers 2 | 3 | import com.linkedin.gdmix.configs.{DataType, EffectConfig, ColumnConfig} 4 | import org.json4s.DefaultFormats 5 | import org.json4s.ext.EnumNameSerializer 6 | import org.json4s.jackson.JsonMethods.parse 7 | 8 | /** 9 | * Parser for fixed or random effect configuration [[EffectConfig]] 10 | * 11 | */ 12 | object EffectConfigParser { 13 | 14 | /** 15 | * Get the fixed or random effect config with sanity check 16 | * 17 | * @param jsonString JSON format of a list of EffectConfig. 18 | * @return a list of EffectConfig 19 | */ 20 | def getEffectConfigList(jsonString: String): Seq[EffectConfig] = { 21 | // Define implicit JSON4S default format 22 | implicit val formats = DefaultFormats + new EnumNameSerializer(DataType) 23 | // Use JSON4S to parse and extract a list of EffectConfig. 24 | val configList = parse(jsonString).extract[Seq[EffectConfig]] 25 | sanityCheck(configList) 26 | } 27 | 28 | /** 29 | * Sanity check the list of EffectConfig.
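* The checked configs are returned sorted so that the fixed effect comes first, e.g. Seq(perMemberConfig, fixedEffectConfig) becomes Seq(fixedEffectConfig, perMemberConfig).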
30 | * Throw an exception when there is more than one fixed effect. 31 | * 32 | * @param configList A sequence of EffectConfig to be checked 33 | * @return A sequence of EffectConfig 34 | */ 35 | private def sanityCheck(configList: Seq[EffectConfig]): Seq[EffectConfig] = { 36 | 37 | // A sequence of EffectConfig that represents a dataset should only have one fixed effect. 38 | val numFixedEffect = configList.foldLeft(0)((accum, config) => accum + (if (config.isRandomEffect) 0 else 1)) 39 | if (numFixedEffect > 1) { 40 | throw new IllegalArgumentException(s"There should be only 1 fixed effect, but $numFixedEffect are present") 41 | } 42 | 43 | // Check individual EffectConfig 44 | val checkedConfigList = configList.map(config => checkEffectConfig(config)) 45 | 46 | // Sort the configs such that the fixed effect is at the beginning. 47 | checkedConfigList.sorted 48 | } 49 | 50 | /** 51 | * Check the content of an EffectConfig. 52 | * 53 | * @param config The EffectConfig to be checked 54 | * @return An EffectConfig with missing column info added. 55 | */ 56 | private def checkEffectConfig(config: EffectConfig): EffectConfig = { 57 | 58 | val columnNames = config.columnConfigList.map(column => column.name).toSet 59 | 60 | // Check if the labels are in columnConfig 61 | if (config.labels != None) { 62 | config.labels.get.map { 63 | label => 64 | if (!columnNames.contains(label)) { 65 | throw new IllegalArgumentException(s"Label $label is not in column names") 66 | } 67 | } 68 | } 69 | 70 | // Check if perEntityName is in columnConfig 71 | if (config.perEntityName != None) { 72 | val entityName = config.perEntityName.get 73 | if (!columnNames.contains(entityName)) { 74 | throw new IllegalArgumentException(s"EntityName $entityName is not in column names") 75 | } 76 | } 77 | 78 | config.copy(columnConfigList = config.columnConfigList.map(column => checkColumnConfig(column))) 79 | } 80 | 81 | /** 82 | * Check the content of a ColumnConfig. Fill in the default values if possible. 83 | * Throw an exception when a column is in name-term-value format but its data type is not float. 84 | * 85 | * @param columnConfig ColumnConfig to be checked 86 | * @return A ColumnConfig with missing values filled with defaults 87 | */ 88 | private def checkColumnConfig(columnConfig: ColumnConfig): ColumnConfig = { 89 | 90 | val isInputNTV = setDefaultBoolean(columnConfig.isInputNTV, false) 91 | val isOutputSparse = setDefaultBoolean(columnConfig.isOutputSparse, false) 92 | val sharedFeatureSpace = setDefaultBoolean(columnConfig.sharedFeatureSpace, true) 93 | 94 | if ((isInputNTV.get) && (columnConfig.dtype != DataType.float)) { 95 | throw new IllegalArgumentException(s"Name-Term-Value format output datatype must be float") 96 | } 97 | 98 | columnConfig.copy( 99 | isInputNTV = isInputNTV, 100 | isOutputSparse = isOutputSparse, 101 | sharedFeatureSpace = sharedFeatureSpace) 102 | } 103 | 104 | /** 105 | * A utility function that fills in a default Boolean value when it is missing from an option 106 | * 107 | * @param some an option value. 108 | * @param value a boolean value used as the default. 109 | * @return A Some with the default value when the input is None.
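* e.g. setDefaultBoolean(None, true) returns Some(true), while setDefaultBoolean(Some(false), true) returns Some(false).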
110 | */ 111 | private def setDefaultBoolean(some: Option[Boolean], value: Boolean): Option[Boolean] = { 112 | if (some == None) Some(value) else some 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /gdmix-data/src/main/scala/com/linkedin/gdmix/parsers/EvaluatorParser.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.parsers 2 | 3 | import com.linkedin.gdmix.utils.Constants._ 4 | 5 | /** 6 | * Parameters for the evaluation metric compute job. 7 | */ 8 | case class EvaluatorParams( 9 | metricsInputDir: String, 10 | outputMetricFile: String, 11 | labelColumnName: String, 12 | predictionColumnName: String, 13 | metricName: String 14 | ) 15 | 16 | /** 17 | * Parser for the evaluation metric compute job. 18 | */ 19 | object EvaluatorParser { 20 | private val evaluatorParser = new scopt.OptionParser[EvaluatorParams]( 21 | "Parsing command line for evaluation metric compute job.") { 22 | 23 | opt[String]("metricsInputDir").action((x, p) => p.copy(metricsInputDir = x.trim)) 24 | .required 25 | .text( 26 | """Required. 27 | |Input data path containing the prediction and label columns.""".stripMargin) 28 | 29 | opt[String]("outputMetricFile").action((x, p) => p.copy(outputMetricFile = x.trim)) 30 | .required 31 | .text( 32 | """Required. 33 | |Output file for the computed evaluation metric.""".stripMargin) 34 | 35 | opt[String]("labelColumnName").action((x, p) => p.copy(labelColumnName = x.trim)) 36 | .required 37 | .text( 38 | """Required. 39 | |Label column name.""".stripMargin) 40 | 41 | opt[String]("predictionColumnName").action((x, p) => p.copy(predictionColumnName = x.trim)) 42 | .required 43 | .text( 44 | """Required. 45 | |Prediction score column name.""".stripMargin) 46 | 47 | opt[String]("metricName").action((x, p) => p.copy(metricName = x.trim)) 48 | .required 49 | .text( 50 | """Required. 51 | |Evaluation metric name (currently only 'auc' and 'mse' are supported).""".stripMargin) 52 | 53 | checkConfig(p => 54 | if (!List(AUC, MSE).contains(p.metricName)) { 55 | failure(s"${p.metricName} is not supported, should be in ['auc', 'mse'].") 56 | } 57 | else success) 58 | } 59 | 60 | def parse(args: Seq[String]): EvaluatorParams = { 61 | val emptyEvaluatorParams = EvaluatorParams( 62 | metricsInputDir = "", 63 | outputMetricFile = "", 64 | labelColumnName = "", 65 | predictionColumnName = "", 66 | metricName = "" 67 | ) 68 | evaluatorParser.parse(args, emptyEvaluatorParams) match { 69 | case Some(params) => params 70 | case None => throw new IllegalArgumentException( 71 | s"Parsing the command line arguments failed.\n" + 72 | s"(${args.mkString(", ")}),\n${evaluatorParser.usage}") 73 | } 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /gdmix-data/src/main/scala/com/linkedin/gdmix/parsers/LrModelSplitterParser.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.parsers 2 | 3 | import com.linkedin.gdmix.utils.IoUtils 4 | 5 | /** 6 | * Parameters for the LR model splitter job. 7 | */ 8 | case class LrModelSplitterParams( 9 | modelInputDir: String, 10 | modelOutputDir: String, 11 | numOutputFiles: Int = 200 12 | ) 13 | 14 | /** 15 | * Parser for the model splitting job.
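* A usage sketch (the paths here are hypothetical): parse(Seq("--modelInputDir", "in/models", "--modelOutputDir", "out/models", "--numOutputFiles", "10")) yields LrModelSplitterParams("in/models", "out/models", 10).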
16 | */ 17 | object LrModelSplitterParser { 18 | private val lrModelSplitterParser = new scopt.OptionParser[LrModelSplitterParams]( 19 | "Parsing command line for model splitting job.") { 20 | 21 | opt[String]("modelInputDir").action((x, p) => p.copy(modelInputDir = x.trim)) 22 | .required 23 | .text( 24 | """Required. 25 | |The path for input models.""".stripMargin) 26 | 27 | opt[String]("modelOutputDir").action((x, p) => p.copy(modelOutputDir = x.trim)) 28 | .required 29 | .text( 30 | """Required. 31 | |The path for output models.""".stripMargin) 32 | 33 | opt[Int]("numOutputFiles").action((x, p) => p.copy(numOutputFiles = x)) 34 | .optional 35 | .validate( 36 | x => if (x > 0) success else failure("Option --numOutputFiles must be > 0")) 37 | .text( 38 | """Optional. 39 | |Number of output files.""".stripMargin) 40 | 41 | checkConfig(p => 42 | if (p.modelInputDir == "") { 43 | failure("Model input path cannot be an empty string.") 44 | } 45 | else success) 46 | 47 | checkConfig(p => 48 | if (p.modelOutputDir == "") { 49 | failure("Model output path cannot be an empty string.") 50 | } 51 | else success) 52 | } 53 | 54 | def parse(args: Seq[String]): LrModelSplitterParams = { 55 | val emptyLrModelSplitterParams = LrModelSplitterParams( 56 | modelInputDir = "", 57 | modelOutputDir = "" 58 | ) 59 | lrModelSplitterParser.parse(args, emptyLrModelSplitterParams) match { 60 | case Some(params) => params 61 | case None => throw new IllegalArgumentException( 62 | s"Parsing the command line arguments failed.\n" + 63 | s"(${args.mkString(", ")}),\n${lrModelSplitterParser.usage}") 64 | } 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /gdmix-data/src/main/scala/com/linkedin/gdmix/utils/Constants.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.utils 2 | 3 | /** 4 | * Constants used by the data module.
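* For example, CROSS ("_gdmixcross_") is the delimiter between model id and feature name in crossed global models, and EVAL_SUMMARY_JSON is the file name the Evaluator writes its metric summary to.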
5 | */ 6 | object Constants { 7 | 8 | val ACTIVE = "active" 9 | val AUC = "auc" 10 | val MSE = "mse" 11 | val AVRO = "avro" 12 | val AVRO_FORMAT = "com.databricks.spark.avro" 13 | val CONFIG_SPLITTER = ";" 14 | val COUNT = "count" 15 | val CROSS = "_gdmixcross_" 16 | val EVAL_SUMMARY_JSON = "evalSummary.json" 17 | val FLOAT = "float" 18 | val GLOBAL = "global" 19 | val GROUP_ID = "groupId" 20 | val INDICES = "indices" 21 | val LONG = "long" 22 | val LR_MODEL_CLASS = "com.linkedin.photon.ml.supervised.classification.LogisticRegressionModel" 23 | val MEANS = "means" 24 | val MODEL_ID = "modelId" 25 | val NAME = "name" 26 | val OUTER = "outer" 27 | val OFFSET = "offset" 28 | val PARTITION_ID = "partitionId" 29 | val PASSIVE = "passive" 30 | val PER_ENTITY_GROUP_COUNT = "perEntityGroupCount" 31 | val PER_ENTITY_TOTAL_SAMPLE_COUNT = "perEntityTotalSampleCount" 32 | val PREDICTION_SCORE = "predictionScore" 33 | val PREDICTION_SCORE_PER_COORDINATE = "predictionScorePerCoordinate" 34 | val PREDICTION_SCORE_TEMP = "predictionScoreTemp" 35 | val PREDICTION_SCORE_SUM = "predictionScoreSum" 36 | val RMSE = "rmse" 37 | val TERM = "term" 38 | val TFRECORD = "tfrecord" 39 | val TF_EXAMPLE = "Example" 40 | val TF_SEQUENCE_EXAMPLE = "SequenceExample" 41 | val UID = "uid" 42 | val UID_TEMP = "uidTemp" 43 | val VALUE = "value" 44 | val VALUES = "values" 45 | val VARIANCES = "variances" 46 | } 47 | -------------------------------------------------------------------------------- /gdmix-data/src/main/scala/com/linkedin/gdmix/utils/ConversionUtils.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.utils 2 | 3 | import org.apache.spark.sql.Row 4 | import org.apache.spark.sql.expressions.UserDefinedFunction 5 | import org.apache.spark.sql.functions.udf 6 | import org.apache.spark.sql.types._ 7 | 8 | import com.linkedin.gdmix.configs.DataType 9 | import com.linkedin.gdmix.utils.Constants.{NAME, TERM, VALUE, CROSS} 10 | 11 | /** 12 | * Helper class to convert NTVs to sparse vectors 13 | */ 14 | object ConversionUtils { 15 | 16 | /** 17 | * Case class to represent the NameTermValue (NTV) 18 | * 19 | * @param name Name of a feature 20 | * @param term Term of a feature 21 | * @param value Value of a feature 22 | */ 23 | case class NameTermValue(name: String, term: String, value: Float) 24 | 25 | /** 26 | * Case class to represent the NameTermValue (NTV) where the value is float (nullable) 27 | * 28 | * @param name Name of a feature 29 | * @param term Term of a feature 30 | * @param value Value of a feature (Option[Float]) 31 | */ 32 | case class NameTermValueOptionFloat(name: String, term: String, value: Option[Float]) 33 | 34 | /** 35 | * Case class to represent the NameTermValue (NTV) where the value is double (nullable) 36 | * Model values generated by Photon-ML are in double format and nullable.
37 | * 38 | * @param name Name of a feature 39 | * @param term Term of a feature 40 | * @param value Value of a feature (Option[Double]) 41 | */ 42 | case class NameTermValueOptionDouble(name: String, term: String, value: Option[Double]) 43 | 44 | /** 45 | * Case class for SparseVector type 46 | * @param indices The indices of a sparse vector 47 | * @param values The values of a sparse vector 48 | */ 49 | case class SparseVector(indices: Seq[Long], values: Seq[Float]) 50 | 51 | /** 52 | * UDF to get the name and term given a row of NTV 53 | * @return A (name, term) tuple of strings 54 | */ 55 | def getNameTermUdf: UserDefinedFunction = udf { r: Row => (r.getAs[String](NAME), r.getAs[String](TERM)) } 56 | 57 | /** 58 | * Split the full name into a (model_id, feature_name) tuple 59 | */ 60 | def splitModelIdUdf: UserDefinedFunction = udf { r: Row => 61 | val Array(modelId, name) = r.getAs[String](NAME).split(CROSS) 62 | val term = r.getAs[String](TERM) 63 | val value = r.getAs[Double](VALUE) 64 | (modelId, NameTermValueOptionDouble(name, term, Some(value))) 65 | } 66 | 67 | /** 68 | * Convert input Config DataType to Spark sql DataType 69 | */ 70 | final val ConfigDataTypeMap = Map[DataType.DataType, org.apache.spark.sql.types.DataType]( 71 | DataType.byte -> ByteType, 72 | DataType.double -> DoubleType, 73 | DataType.float -> FloatType, 74 | DataType.int -> IntegerType, 75 | DataType.long -> LongType, 76 | DataType.string -> StringType 77 | ) 78 | 79 | /** 80 | * Map Spark sql DataType -> Config DataType 81 | */ 82 | def mapSparkToConfigDataType( 83 | sparkType: org.apache.spark.sql.types.DataType 84 | ): DataType.DataType = sparkType match { 85 | case ByteType => DataType.byte 86 | case DoubleType => DataType.double 87 | case FloatType => DataType.float 88 | case IntegerType => DataType.int 89 | case LongType => DataType.long 90 | case StringType => DataType.string 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /gdmix-data/src/main/scala/com/linkedin/gdmix/utils/JsonUtils.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.utils 2 | 3 | import com.fasterxml.jackson.annotation.JsonInclude.Include 4 | import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper, SerializationFeature} 5 | import com.fasterxml.jackson.module.scala.DefaultScalaModule 6 | import com.fasterxml.jackson.module.scala.experimental.ScalaObjectMapper 7 | 8 | /** 9 | * Helper class to serialize an Object into a JSON String and to deserialize a JSON String into a Map 10 | */ 11 | object JsonUtils { 12 | 13 | /** 14 | * Define the FasterXML Jackson object mapper 15 | */ 16 | private[this] val mapper = new ObjectMapper() with ScalaObjectMapper 17 | mapper 18 | .registerModule(DefaultScalaModule) 19 | .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) 20 | .configure(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY, true) 21 | .configure(SerializationFeature.INDENT_OUTPUT, true) 22 | .setSerializationInclusion(Include.NON_NULL) 23 | 24 | /** 25 | * Write an object to JSON formatted String 26 | * 27 | * @param value Any object to be written 28 | * @return A JSON pretty formatted String 29 | */ 30 | def toJsonString(value: Any): String = mapper.writeValueAsString(value) 31 | 32 | /** 33 | * Helper function to generate a map from a JSON string 34 | */ 35 | def toMap[V](json: String)(implicit m: Manifest[V]): Map[String, V] = { 36 | if(json.trim.isEmpty) 37 | Map.empty[String, V] 38 | else 39 
| fromJson[Map[String, V]](json) 40 | } 41 | 42 | private def fromJson[T](json: String)(implicit m: Manifest[T]): T = { 43 | mapper.readValue[T](json) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /gdmix-data/src/main/scala/com/linkedin/gdmix/utils/PartitionUtils.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.utils 2 | 3 | import org.apache.spark.sql.DataFrame 4 | import org.apache.spark.sql.expressions.UserDefinedFunction 5 | import org.apache.spark.sql.functions.{col, udf} 6 | 7 | import com.linkedin.gdmix.utils.Constants._ 8 | 9 | /** 10 | * Helper class to partition data in Spark data frame 11 | */ 12 | object PartitionUtils { 13 | 14 | /** 15 | * UDF to add offset to each value in a sequence 16 | */ 17 | def addOffsetUDF: UserDefinedFunction = { 18 | udf{ 19 | (indices: Seq[Long], offset: Int) => { 20 | indices.map(index => index + offset) 21 | } 22 | } 23 | } 24 | 25 | /** 26 | * UDF to get partition id by (hash(item id) % number of partitions) 27 | * 28 | * @param numPartitions Number of partitions 29 | * @return A UDF to get the partition id. 30 | */ 31 | def getPartitionIdUDF(numPartitions: Int): UserDefinedFunction = { 32 | udf { 33 | itemId: String => { 34 | Math.abs(itemId.hashCode) % numPartitions 35 | } 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/configs/ConfigWithTwoFixedEffects.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "isRandomEffect": false, 4 | "coordinateName": "fixed-effect", 5 | "perEntityName": null, 6 | "labels":["response"], 7 | "columnConfigList": [ 8 | { 9 | "name": "weight", 10 | "dtype": "float", 11 | "shape": [], 12 | "isInputNTV": false, 13 | "isOutputSparse": false, 14 | "sharedFeatureSpace": true 15 | }, 16 | { 17 | "name": "response", 18 | "dtype": "float", 19 | "shape": [], 20 | "isInputNTV": false, 21 | "isOutputSparse": false 22 | }, 23 | { 24 | "name": "global", 25 | "dtype": "float", 26 | "shape": [32], 27 | "isInputNTV": false, 28 | "isOutputSparse": false 29 | } 30 | ] 31 | }, 32 | { 33 | "isRandomEffect": false, 34 | "coordinateName": "fixed-effect", 35 | "perEntityName": null, 36 | "labels":["response"], 37 | "columnConfigList": [ 38 | { 39 | "name": "weight", 40 | "dtype": "float", 41 | "shape": [], 42 | "isInputNTV": false, 43 | "isOutputSparse": false, 44 | "sharedFeatureSpace": true 45 | }, 46 | { 47 | "name": "response", 48 | "dtype": "float", 49 | "shape": [], 50 | "isInputNTV": false, 51 | "isOutputSparse": false 52 | }, 53 | { 54 | "name": "fixed-effect", 55 | "dtype": "float", 56 | "shape": [12], 57 | "isInputNTV": false, 58 | "isOutputSparse": false 59 | } 60 | ] 61 | }, 62 | { 63 | "isRandomEffect": true, 64 | "coordinateName": "per-member", 65 | "perEntityName": "memberId", 66 | "labels":["response"], 67 | "columnConfigList": [ 68 | { 69 | "name": "per-member", 70 | "dtype": "float", 71 | "shape": [], 72 | "isInputNTV": true, 73 | "isOutputSparse": false, 74 | "sharedFeatureSpace": true 75 | }, 76 | { 77 | "name": "response", 78 | "dtype": "float", 79 | "shape": [] 80 | } 81 | ] 82 | } 83 | ] -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/configs/EffectConfigs.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "isRandomEffect": true, 4 | 
"coordinateName": "per-member", 5 | "perEntityName": "memberId", 6 | "labels":["response"], 7 | "columnConfigList": [ 8 | { 9 | "name": "weight", 10 | "dtype": "float", 11 | "shape": [], 12 | "isInputNTV": false, 13 | "isOutputSparse": false, 14 | "sharedFeatureSpace": false 15 | }, 16 | { 17 | "name": "response", 18 | "dtype": "int", 19 | "shape": [] 20 | }, 21 | { 22 | "name": "memberId", 23 | "dtype": "string", 24 | "shape": [] 25 | }, 26 | { 27 | "name": "per_member", 28 | "dtype": "float", 29 | "shape": [2, 3], 30 | "isInputNTV": false, 31 | "isOutputSparse": false 32 | } 33 | ] 34 | }, 35 | { 36 | "isRandomEffect": false, 37 | "coordinateName": "fixed-effect", 38 | "perEntityName": null, 39 | "labels":["response", "label"], 40 | "columnConfigList": [ 41 | { 42 | "name": "response", 43 | "dtype": "int", 44 | "shape": [] 45 | }, 46 | { 47 | "name": "label", 48 | "dtype": "float", 49 | "shape": [] 50 | }, 51 | { 52 | "name": "global", 53 | "dtype": "float", 54 | "shape": [12] 55 | } 56 | ] 57 | } 58 | ] -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/configs/EntityNotInColumns.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "isRandomEffect": true, 4 | "coordinateName": "per-member", 5 | "perEntityName": "memberId", 6 | "labels":["response"], 7 | "columnConfigList": [ 8 | { 9 | "name": "weight", 10 | "dtype": "float", 11 | "shape": [], 12 | "isInputNTV": false, 13 | "isOutputSparse": false, 14 | "sharedFeatureSpace": true 15 | }, 16 | { 17 | "name": "response", 18 | "dtype": "float", 19 | "shape": [], 20 | "isInputNTV": false, 21 | "isOutputSparse": false 22 | }, 23 | { 24 | "name": "per-member", 25 | "dtype": "float", 26 | "shape": [32], 27 | "isInputNTV": false, 28 | "isOutputSparse": false 29 | } 30 | ] 31 | } 32 | ] -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/configs/FeatureConvertor.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "isRandomEffect": false, 4 | "coordinateName": "fixed-effect", 5 | "perEntityName": null, 6 | "labels":["response"], 7 | "columnConfigList": [ 8 | { 9 | "name": "response", 10 | "dtype": "float", 11 | "shape": [] 12 | }, 13 | { 14 | "name": "global", 15 | "dtype": "float", 16 | "shape": [], 17 | "isInputNTV": true, 18 | "isOutputSparse": true, 19 | "sharedFeatureSpace": true 20 | }, 21 | { 22 | "name": "uid", 23 | "dtype": "long", 24 | "shape": [] 25 | } 26 | ] 27 | }, 28 | { 29 | "isRandomEffect": true, 30 | "coordinateName": "per-member", 31 | "perEntityName": "memberId", 32 | "labels":["response"], 33 | "columnConfigList": [ 34 | { 35 | "name": "memberId", 36 | "dtype": "long", 37 | "shape": [] 38 | }, 39 | { 40 | "name": "response", 41 | "dtype": "float", 42 | "shape": [] 43 | }, 44 | { 45 | "name": "perMember", 46 | "dtype": "float", 47 | "shape": [], 48 | "isInputNTV": true, 49 | "isOutputSparse": true, 50 | "sharedFeatureSpace": true 51 | }, 52 | { 53 | "name": "uid", 54 | "dtype": "long", 55 | "shape": [] 56 | } 57 | ] 58 | }, 59 | { 60 | "isRandomEffect": true, 61 | "coordinateName": "per-item", 62 | "perEntityName": "itemId", 63 | "labels":["response"], 64 | "columnConfigList": [ 65 | { 66 | "name": "itemId", 67 | "dtype": "long", 68 | "shape": [] 69 | }, 70 | { 71 | "name": "response", 72 | "dtype": "float", 73 | "shape": [] 74 | }, 75 | { 76 | "name": "perItem", 77 | "dtype": "float", 78 | "shape": [], 79 | 
"isInputNTV": true, 80 | "isOutputSparse": false, 81 | "sharedFeatureSpace": false 82 | }, 83 | { 84 | "name": "uid", 85 | "dtype": "long", 86 | "shape": [] 87 | } 88 | ] 89 | } 90 | ] -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/configs/LabelNotInColumns.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "isRandomEffect": true, 4 | "coordinateName": "per-member", 5 | "perEntityName": "memberId", 6 | "labels":["lable", "response"], 7 | "columnConfigList": [ 8 | { 9 | "name": "weight", 10 | "dtype": "float", 11 | "shape": [], 12 | "isInputNTV": false, 13 | "isOutputSparse": false, 14 | "sharedFeatureSpace": true 15 | }, 16 | { 17 | "name": "response", 18 | "dtype": "float", 19 | "shape": [], 20 | "isInputNTV": false, 21 | "isOutputSparse": false 22 | }, 23 | { 24 | "name": "per-member", 25 | "dtype": "float", 26 | "shape": [32], 27 | "isInputNTV": false, 28 | "isOutputSparse": false 29 | } 30 | ] 31 | } 32 | ] -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/data/ExpectedGlobalTrainData.avro: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-data/src/test/resources/data/ExpectedGlobalTrainData.avro -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/data/ExpectedGlobalValidData.avro: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-data/src/test/resources/data/ExpectedGlobalValidData.avro -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/data/ExpectedPerItemTrainData.avro: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-data/src/test/resources/data/ExpectedPerItemTrainData.avro -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/data/ExpectedPerItemValidData.avro: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-data/src/test/resources/data/ExpectedPerItemValidData.avro -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/data/ExpectedPerMemberTrainData.avro: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-data/src/test/resources/data/ExpectedPerMemberTrainData.avro -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/data/ExpectedPerMemberValidData.avro: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-data/src/test/resources/data/ExpectedPerMemberValidData.avro -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/data/TrainData.avro: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-data/src/test/resources/data/TrainData.avro -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/metadata/ExpectedGlobalFeatureList.txt: -------------------------------------------------------------------------------- 1 | global,1 2 | global,2 3 | global,3 4 | -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/metadata/ExpectedGlobalMetadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "features" : [ { 3 | "name" : "global", 4 | "dtype" : "float", 5 | "shape" : [ 3 ], 6 | "isSparse" : true 7 | }, { 8 | "name" : "uid", 9 | "dtype" : "long", 10 | "shape" : [ ], 11 | "isSparse" : false 12 | } ], 13 | "labels" : [ { 14 | "name" : "response", 15 | "dtype" : "float", 16 | "shape" : [ ], 17 | "isSparse" : false 18 | } ] 19 | } 20 | -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/metadata/ExpectedPerItemFeatureList.txt: -------------------------------------------------------------------------------- 1 | perItem,1 2 | perItem,1 3 | perItem,2 4 | perItem,2 5 | perItem,3 6 | perItem,3 7 | -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/metadata/ExpectedPerItemMetadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "features" : [ { 3 | "name" : "itemId", 4 | "dtype" : "long", 5 | "shape" : [ ], 6 | "isSparse" : false 7 | }, { 8 | "name" : "perItem", 9 | "dtype" : "float", 10 | "shape" : [ 6 ], 11 | "isSparse" : false 12 | }, { 13 | "name" : "uid", 14 | "dtype" : "long", 15 | "shape" : [ ], 16 | "isSparse" : false 17 | } ], 18 | "labels" : [ { 19 | "name" : "response", 20 | "dtype" : "float", 21 | "shape" : [ ], 22 | "isSparse" : false 23 | } ] 24 | } 25 | -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/metadata/ExpectedPerMemberFeatureList.txt: -------------------------------------------------------------------------------- 1 | perMember,1 2 | perMember,2 3 | perMember,3 4 | -------------------------------------------------------------------------------- /gdmix-data/src/test/resources/metadata/ExpectedPerMemberMetadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "features" : [ { 3 | "name" : "memberId", 4 | "dtype" : "long", 5 | "shape" : [ ], 6 | "isSparse" : false 7 | }, { 8 | "name" : "perMember", 9 | "dtype" : "float", 10 | "shape" : [ 3 ], 11 | "isSparse" : true 12 | }, { 13 | "name" : "uid", 14 | "dtype" : "long", 15 | "shape" : [ ], 16 | "isSparse" : false 17 | } ], 18 | "labels" : [ { 19 | "name" : "response", 20 | "dtype" : "float", 21 | "shape" : [ ], 22 | "isSparse" : false 23 | } ] 24 | } 25 | -------------------------------------------------------------------------------- /gdmix-data/src/test/scala/com/linkedin/gdmix/configs/EffectConfigTest.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.configs 2 | 3 | import org.testng.Assert.assertEquals 4 | import org.testng.annotations.Test 5 | 6 | /** 7 | * Unit tests for [[EffectConfig]]. 
8 | */ 9 | @Test 10 | class EffectConfigTest { 11 | def testSort(): Unit = { 12 | val col1 = ColumnConfig("feature", DataType.float, Seq(1, 2), Some(true), Some(true), Some(false)) 13 | val col2 = ColumnConfig("global", DataType.int) 14 | val first = EffectConfig(true, "per-member", Some("memberId"), Some(Seq("label")), Seq(col1, col2)) 15 | val second = EffectConfig(false, "global", None, Some(Seq("label", "response")), Seq(col2)) 16 | val configs = Seq(first, second) 17 | assertEquals(configs.sorted, Seq(second, first)) 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /gdmix-data/src/test/scala/com/linkedin/gdmix/data/BestModelSelectorTest.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.data 2 | 3 | import org.testng.Assert.assertEquals 4 | import org.testng.annotations.Test 5 | import com.linkedin.gdmix.utils.Constants._ 6 | import com.linkedin.gdmix.utils.JsonUtils 7 | 8 | /** 9 | * Unit tests for [[BestModelSelector]]. 10 | */ 11 | class BestModelSelectorTest { 12 | 13 | val hparams = "eyIwIjogWyJmZTpvcHRpbWl6ZXJfdHlwZTpVbmlvblRlbXBsYXRlKHsnY29tLmxpbmtlZGluLmdkbWl4LmNvbmZpZ3MuZXhwZXJpbWVudGFsLm9wdGltaXplci5TR0QnOiB7J2xlYXJuaW5nUmF0ZSc6IDAuMDAxfX0pIiwgInJlOm51bWJlcl9vZl9wYXJ0aXRpb246MTAwIl0sICIxIjogWyJmZTpvcHRpbWl6ZXJfdHlwZTpVbmlvblRlbXBsYXRlKHsnY29tLmxpbmtlZGluLmdkbWl4LmNvbmZpZ3MuZXhwZXJpbWVudGFsLm9wdGltaXplci5TR0QnOiB7J2xlYXJuaW5nUmF0ZSc6IDAuMDAxfX0pIiwgInJlOm51bWJlcl9vZl9wYXJ0aXRpb246MjAwIl0sICIyIjogWyJmZTpvcHRpbWl6ZXJfdHlwZTpVbmlvblRlbXBsYXRlKHsnY29tLmxpbmtlZGluLmdkbWl4LmNvbmZpZ3MuZXhwZXJpbWVudGFsLm9wdGltaXplci5TR0QnOiB7J2xlYXJuaW5nUmF0ZSc6IDAuMDF9fSkiLCAicmU6bnVtYmVyX29mX3BhcnRpdGlvbjoxMDAiXSwgIjMiOiBbImZlOm9wdGltaXplcl90eXBlOlVuaW9uVGVtcGxhdGUoeydjb20ubGlua2VkaW4uZ2RtaXguY29uZmlncy5leHBlcmltZW50YWwub3B0aW1pemVyLlNHRCc6IHsnbGVhcm5pbmdSYXRlJzogMC4wMX19KSIsICJyZTpudW1iZXJfb2ZfcGFydGl0aW9uOjIwMCJdLCAiNCI6IFsiZmU6b3B0aW1pemVyX3R5cGU6VW5pb25UZW1wbGF0ZSh7J2NvbS5saW5rZWRpbi5nZG1peC5jb25maWdzLmV4cGVyaW1lbnRhbC5vcHRpbWl6ZXIuU0dEJzogeydsZWFybmluZ1JhdGUnOiAwLjF9fSkiLCAicmU6bnVtYmVyX29mX3BhcnRpdGlvbjoxMDAiXSwgIjUiOiBbImZlOm9wdGltaXplcl90eXBlOlVuaW9uVGVtcGxhdGUoeydjb20ubGlua2VkaW4uZ2RtaXguY29uZmlncy5leHBlcmltZW50YWwub3B0aW1pemVyLlNHRCc6IHsnbGVhcm5pbmdSYXRlJzogMC4xfX0pIiwgInJlOm51bWJlcl9vZl9wYXJ0aXRpb246MjAwIl19" 14 | 15 | @Test 16 | def testDeserialize(): Unit = { 17 | val hparamMap = BestModelSelector.deserialize(hparams) 18 | assertEquals(hparamMap.size, 6) 19 | assertEquals(JsonUtils.toJsonString(hparamMap("0")), "[ \"fe:optimizer_type:UnionTemplate({'com.linkedin.gdmix.configs.experimental.optimizer.SGD': {'learningRate': 0.001}})\", \"re:number_of_partition:100\" ]") 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /gdmix-data/src/test/scala/com/linkedin/gdmix/data/OffsetUpdaterTest.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.data 2 | 3 | import org.testng.Assert.assertEquals 4 | import org.testng.annotations.Test 5 | 6 | import com.linkedin.gdmix.utils.SharedSparkSession 7 | 8 | /** 9 | * Unit tests for [[OffsetUpdater]]. 10 | */ 11 | class OffsetUpdaterTest extends SharedSparkSession { 12 | 13 | import spark.implicits._ 14 | import OffsetUpdaterTest._ 15 | 16 | /** 17 | * Unit test for [[OffsetUpdater.updateOffset()]]. 
18 | */ 19 | @Test 20 | def testUpdateOffset(): Unit = { 21 | val data = Seq((1L, 0.0F), (2L, 0.0F)).toDF(UID, OFFSET) 22 | val lastOffset = Seq((1L, 1.0F), (2L, 2.0F)).toDF(UID, PREDICTION_SCORE) 23 | val perCoordinateScore = Seq((1L, 0.1F), (2L, 0.2F)).toDF(UID, PREDICTION_SCORE_PER_COORDINATE) 24 | 25 | val updatedData1 = OffsetUpdater.updateOffset( 26 | data, 27 | lastOffset, 28 | None, 29 | PREDICTION_SCORE, 30 | PREDICTION_SCORE_PER_COORDINATE, 31 | OFFSET, 32 | UID) 33 | val res1 = updatedData1.map(row => (row.getAs[Long](UID), row.getAs[Float](OFFSET))).collect().toMap 34 | assertEquals(res1(1L), 1.0F) 35 | assertEquals(res1(2L), 2.0F) 36 | 37 | val updatedData2 = OffsetUpdater.updateOffset( 38 | data, 39 | lastOffset, 40 | Some(perCoordinateScore), 41 | PREDICTION_SCORE, 42 | PREDICTION_SCORE_PER_COORDINATE, 43 | OFFSET, 44 | UID) 45 | val res2 = updatedData2.map(row => (row.getAs[Long](UID), row.getAs[Float](OFFSET))).collect().toMap 46 | assertEquals(res2(1L), 0.9F) 47 | assertEquals(res2(2L), 1.8F) 48 | } 49 | } 50 | 51 | object OffsetUpdaterTest { 52 | val OFFSET = "offset" 53 | val UID = "uid" 54 | val PREDICTION_SCORE = "predictionScore" 55 | val PREDICTION_SCORE_PER_COORDINATE = "predictionScorePerCoordinate" 56 | } 57 | -------------------------------------------------------------------------------- /gdmix-data/src/test/scala/com/linkedin/gdmix/evaluation/EvaluatorTest.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.evaluation 2 | 3 | import org.testng.Assert.assertEquals 4 | import org.testng.annotations.{DataProvider, Test} 5 | 6 | import com.linkedin.gdmix.utils.SharedSparkSession 7 | 8 | /** 9 | * Unit tests for [[Evaluator]]. 10 | */ 11 | class EvaluatorTest extends SharedSparkSession { 12 | 13 | import spark.implicits._ 14 | 15 | @DataProvider(name = "AUCScoreAndLabels") 16 | def scoreAndLabels(): Array[Array[Any]] = { 17 | Array( 18 | Array(Array(0.1, 0.4, 0.35, 0.8), 19 | Array(0, 0, 1.0, 1.0), 20 | 0.75), 21 | Array(Array(0.5, 0.7, 0.3, 0.4, 0.45, 0.8), 22 | Array(0, 0, 1.0, 1.0, 0, 1.0), 23 | 0.3333333), 24 | Array(Array(0.5, 0.75, 0.8, 0.2, 0.3, 0.4, 0.45, 0.5), 25 | Array(0, 0, 0, 0, 1.0, 1.0, 0, 1.0), 26 | 0.3) 27 | ) 28 | } 29 | 30 | @DataProvider(name = "MSEScoreAndLabels") 31 | def predictionAndObservations(): Array[Array[Any]] = { 32 | Array( 33 | Array(Array(0.1, 0.4, 0.35, 0.8), 34 | Array(0, 0, 1.0, 2.0), 35 | 0.5081250), 36 | Array(Array(0.5, 0.7, 1.3, 3.4, 5.45, 0.8), 37 | Array(0, 0, 1.0, 2.0, 3.0, 1.0), 38 | 1.4720833), 39 | Array(Array(0.5, 0.75, -0.8, 0.2, -0.3, 0.4, 0.45, 0.5), 40 | Array(0, 0, 0.2, 0, 0.4, -1.1, 0, -1.0), 41 | 0.880625) 42 | ) 43 | } 44 | 45 | /** 46 | * Unit test for [[Evaluator.calculateMetric]] on calculating AUC. 47 | */ 48 | @Test(dataProvider = "AUCScoreAndLabels") 49 | def testCalculateAreaUnderROCCurve(score: Array[Double], label: Array[Double], auc: Double): Unit = { 50 | val metricName = "auc" 51 | val labelName = "label" 52 | val scoreName = "score" 53 | val delta = 1.0e-5 54 | val df = (score zip label).toList.toDF(scoreName, labelName) 55 | val calculatedAUC = Evaluator.calculateMetric(df, labelName, scoreName, metricName) 56 | assertEquals(calculatedAUC, auc, delta) 57 | } 58 | 59 | 60 | /** 61 | * Unit test for [[Evaluator.calculateMetric]] on calculating MSE.
62 | */ 63 | @Test(dataProvider = "MSEScoreAndLabels") 64 | def testCalculateMeanSquaredError(score: Array[Double], label: Array[Double], mse: Double): Unit = { 65 | val metricName = "mse" 66 | val labelName = "label" 67 | val scoreName = "score" 68 | val delta = 1.0e-5 69 | val df = (score zip label).toList.toDF(scoreName, labelName) 70 | val calculatedMSE = Evaluator.calculateMetric(df, labelName, scoreName, metricName) 71 | assertEquals(calculatedMSE, mse, delta) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /gdmix-data/src/test/scala/com/linkedin/gdmix/parsers/BestModelSelectorParserTest.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.parsers 2 | 3 | import org.testng.annotations.{DataProvider, Test} 4 | import org.testng.Assert.assertEquals 5 | 6 | /** 7 | * Unit tests for BestModelSelectorParser. 8 | */ 9 | class BestModelSelectorParserTest { 10 | 11 | @DataProvider 12 | def dataCompleteArgs(): Array[Array[Any]] = { 13 | Array( 14 | Array( 15 | Seq( 16 | "--inputMetricsPaths", "gdmix/0/per-job/metric/0;gdmix/1/per-job/metric/0", 17 | "--inputModelPaths", "gdmix/0/per-job/model_output;gdmix/1/per-job/model_output", 18 | "--outputBestMetricsPath", "gdmix/best/metric", 19 | "--outputBestModelPath", "gdmix/best/model", 20 | "--hyperparameters", "eyIwIjogWyJnbG9iYWw6YmF0Y2hfc2l6", 21 | "--evalMetric", "auc", 22 | "--copyBestOutput", "True"))) 23 | } 24 | 25 | @DataProvider 26 | def dataIncompleteArgs(): Array[Array[Any]] = { 27 | 28 | Array( 29 | // copyBestOutput = true but no inputModelPaths 30 | Array( 31 | Seq( 32 | "--inputMetricsPaths", "gdmix/0/per-job/metric/0;gdmix/1/per-job/metric/0", 33 | "--outputBestMetricsPath", "gdmix/best/metric", 34 | "--outputBestModelPath", "gdmix/best/model", 35 | "--hyperparameters", "eyIwIjogWyJnbG9iYWw6YmF0Y2hfc2l6", 36 | "--evalMetric", "auc", 37 | "--copyBestOutput", "True")), 38 | // copyBestOutput is true but no outputBestMetricsPath 39 | Array( 40 | Seq( 41 | "--inputMetricsPaths", "gdmix/0/per-job/metric/0;gdmix/1/per-job/metric/0", 42 | "--inputModelPaths", "gdmix/0/per-job/model_output;gdmix/1/per-job/model_output", 43 | "--outputBestModelPath", "gdmix/best/model", 44 | "--hyperparameters", "eyIwIjogWyJnbG9iYWw6YmF0Y2hfc2l6", 45 | "--evalMetric", "auc", 46 | "--copyBestOutput", "True")), 47 | // miss inputMetricsPaths 48 | Array( 49 | Seq( 50 | "--outputBestModelPath", "gdmix/best/model", 51 | "--hyperparameters", "eyIwIjogWyJnbG9iYWw6YmF0Y2hfc2l6", 52 | "--evalMetric", "auc")), 53 | // miss outputBestModelPath 54 | Array( 55 | Seq( 56 | "--inputMetricsPaths", "gdmix/0/per-job/metric/0;gdmix/1/per-job/metric/0", 57 | "--hyperparameters", "eyIwIjogWyJnbG9iYWw6YmF0Y2hfc2l6", 58 | "--evalMetric", "auc")), 59 | // miss hyperparameters 60 | Array( 61 | Seq( 62 | "--inputMetricsPaths", "gdmix/0/per-job/metric/0;gdmix/1/per-job/metric/0", 63 | "--outputBestModelPath", "gdmix/best/model", 64 | "--evalMetric", "auc")), 65 | // miss evalMetric 66 | Array( 67 | Seq( 68 | "--inputMetricsPaths", "gdmix/0/per-job/metric/0;gdmix/1/per-job/metric/0", 69 | "--outputBestModelPath", "gdmix/best/model", 70 | "--hyperparameters", "eyIwIjogWyJnbG9iYWw6YmF0Y2hfc2l6")) 71 | ) 72 | } 73 | 74 | @Test(dataProvider = "dataCompleteArgs") 75 | def testParseCompleteArguments(completeArgs: Seq[String]): Unit = { 76 | 77 | val params = BestModelSelectorParser.parse(completeArgs) 78 | val expectedParams = BestModelSelectorParams( 79 | inputMetricsPaths 
= Seq("gdmix/0/per-job/metric/0", "gdmix/1/per-job/metric/0"), 80 | outputBestModelPath = "gdmix/best/model", 81 | evalMetric = "auc", 82 | hyperparameters = "eyIwIjogWyJnbG9iYWw6YmF0Y2hfc2l6", 83 | outputBestMetricsPath = Some("gdmix/best/metric"), 84 | inputModelPaths = Some("gdmix/0/per-job/model_output;gdmix/1/per-job/model_output"), 85 | copyBestOutput = true 86 | ) 87 | assertEquals(params, expectedParams) 88 | } 89 | 90 | @Test(dataProvider = "dataIncompleteArgs", expectedExceptions = Array(classOf[IllegalArgumentException])) 91 | def testThrowIllegalArgumentException(inCompleteArgs: Seq[String]): Unit = { 92 | BestModelSelectorParser.parse(inCompleteArgs) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /gdmix-data/src/test/scala/com/linkedin/gdmix/parsers/EffectConfigParserTest.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.parsers 2 | 3 | import com.linkedin.gdmix.configs.{ColumnConfig, DataType, EffectConfig} 4 | import com.linkedin.gdmix.utils.IoUtils.readFile 5 | import org.testng.Assert.assertEquals 6 | import org.testng.annotations.Test 7 | 8 | /** 9 | * Unit tests for [[EffectConfigParser]]. 10 | */ 11 | class EffectConfigParserTest { 12 | 13 | final val CONFIG_FILE_WITH_TWO_FIXED_EFFECTS = "configs/ConfigWithTwoFixedEffects.json" 14 | final val EFFECT_CONFIG_FILE = "configs/EffectConfigs.json" 15 | final val LABEL_NOT_IN_COLUMNS_FILE = "configs/LabelNotInColumns.json" 16 | final val ENTITY_NOT_IN_COLUMNS_FILE = "configs/EntityNotInColumns.json" 17 | 18 | @Test(expectedExceptions = Array(classOf[IllegalArgumentException])) 19 | def testTwoFixedEffect(): Unit = { 20 | val configJson = readFile(null, CONFIG_FILE_WITH_TWO_FIXED_EFFECTS, true) 21 | EffectConfigParser.getEffectConfigList(configJson) 22 | } 23 | 24 | @Test(expectedExceptions = Array(classOf[IllegalArgumentException])) 25 | def testLabelNotInColumns(): Unit = { 26 | val configJson = readFile(null, CONFIG_FILE_WITH_TWO_FIXED_EFFECTS, true) 27 | EffectConfigParser.getEffectConfigList(configJson) 28 | } 29 | 30 | @Test(expectedExceptions = Array(classOf[IllegalArgumentException])) 31 | def testEntityNotInColumns(): Unit = { 32 | val configJson = readFile(null, CONFIG_FILE_WITH_TWO_FIXED_EFFECTS, true) 33 | EffectConfigParser.getEffectConfigList(configJson) 34 | } 35 | 36 | @Test 37 | def testEffectConfigParser(): Unit = { 38 | val configJson = readFile(null, EFFECT_CONFIG_FILE, true) 39 | val parsedConfigList = EffectConfigParser.getEffectConfigList(configJson) 40 | 41 | // construct expected configs 42 | val fixedEffectCol1 = ColumnConfig("response", DataType.int, Seq(), Some(false), Some(false), Some(true)) 43 | val fixedEffectCol2 = ColumnConfig("label", DataType.float, Seq(), Some(false), Some(false), Some(true)) 44 | val fixedEffectCol3 = ColumnConfig("global", DataType.float, Seq(12), Some(false), Some(false), Some(true)) 45 | 46 | val randomEffectCol1 = ColumnConfig("weight", DataType.float, Seq(), Some(false), Some(false), Some(false)) 47 | val randomEffectCol2 = fixedEffectCol1 48 | val randomEffectCol3 = ColumnConfig("memberId", DataType.string, Seq(), Some(false), Some(false), Some(true)) 49 | val randomEffectCol4 = ColumnConfig("per_member", DataType.float, Seq(2, 3), Some(false), Some(false), Some(true)) 50 | 51 | val fixedEffect = EffectConfig( 52 | false, 53 | "fixed-effect", 54 | None, 55 | Some(Seq("response", "label")), 56 | Seq(fixedEffectCol1, fixedEffectCol2, 
fixedEffectCol3)) 57 | 58 | val randomEffect = EffectConfig( 59 | true, 60 | "per-member", 61 | Some("memberId"), 62 | Some(Seq("response")), 63 | Seq(randomEffectCol1, randomEffectCol2, randomEffectCol3, randomEffectCol4)) 64 | 65 | assertEquals(parsedConfigList, Seq(fixedEffect, randomEffect)) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /gdmix-data/src/test/scala/com/linkedin/gdmix/parsers/EvaluatorParserTest.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.parsers 2 | 3 | import org.testng.annotations.{DataProvider, Test} 4 | import org.testng.Assert.assertEquals 5 | 6 | /** 7 | * Unit tests for EvaluatorParser. 8 | */ 9 | class EvaluatorParserTest { 10 | 11 | @DataProvider 12 | def dataCompleteArgs(): Array[Array[Any]] = { 13 | Array( 14 | Array( 15 | Seq( 16 | "--metricsInputDir", "global/validationScore", 17 | "--outputMetricFile", "global/metric/0", 18 | "--labelColumnName", "response", 19 | "--predictionColumnName", "predictionScore", 20 | "--metricName", "auc"))) 21 | } 22 | 23 | @DataProvider 24 | def dataIncompleteArgs(): Array[Array[Any]] = { 25 | 26 | Array( 27 | // miss metricsInputDir 28 | Array( 29 | Seq( 30 | "--outputMetricFile", "global/metric/0", 31 | "--labelColumnName", "response", 32 | "--predictionColumnName", "predictionScore", 33 | "--metricName", "auc")), 34 | // miss outputMetricFile 35 | Array( 36 | Seq( 37 | "--metricsInputDir", "global/validationScore", 38 | "--labelColumnName", "response", 39 | "--predictionColumnName", "predictionScore", 40 | "--metricName", "auc")), 41 | // miss labelColumnName 42 | Array( 43 | Seq( 44 | "--metricsInputDir", "global/validationScore", 45 | "--outputMetricFile", "global/metric/0", 46 | "--predictionColumnName", "predictionScore", 47 | "--metricName", "auc")), 48 | // miss predictionColumnName 49 | Array( 50 | Seq( 51 | "--metricsInputDir", "global/validationScore", 52 | "--outputMetricFile", "global/metric/0", 53 | "--labelColumnName", "response", 54 | "--metricName", "auc")), 55 | // miss metricName 56 | Array( 57 | Seq( 58 | "--metricsInputDir", "global/validationScore", 59 | "--outputMetricFile", "global/metric/0", 60 | "--labelColumnName", "response", 61 | "--predictionColumnName", "predictionScore")), 62 | // metricName not supported 63 | Array( 64 | Seq( 65 | "--metricsInputDir", "global/validationScore", 66 | "--outputMetricFile", "global/metric/0", 67 | "--labelColumnName", "response", 68 | "--predictionColumnName", "predictionScore", 69 | "--metricName", "UnsupportedMetric")) 70 | ) 71 | } 72 | 73 | @Test(dataProvider = "dataCompleteArgs") 74 | def testParseCompleteArguments(completeArgs: Seq[String]): Unit = { 75 | 76 | val params = EvaluatorParser.parse(completeArgs) 77 | val expectedParams = EvaluatorParams( 78 | metricsInputDir = "global/validationScore", 79 | outputMetricFile = "global/metric/0", 80 | labelColumnName = "response", 81 | predictionColumnName = "predictionScore", 82 | metricName = "auc" 83 | ) 84 | assertEquals(params, expectedParams) 85 | } 86 | 87 | @Test(dataProvider = "dataIncompleteArgs", expectedExceptions = Array(classOf[IllegalArgumentException])) 88 | def testThrowIllegalArgumentException(inCompleteArgs: Seq[String]): Unit = { 89 | EvaluatorParser.parse(inCompleteArgs) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /gdmix-data/src/test/scala/com/linkedin/gdmix/parsers/LrModelSplitterParserTest.scala: 
-------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.parsers 2 | 3 | import org.testng.annotations.{DataProvider, Test} 4 | import org.testng.Assert.assertEquals 5 | 6 | /** 7 | * Unit tests for LrModelSplitterParser. 8 | */ 9 | class LrModelSplitterParserTest { 10 | 11 | @DataProvider 12 | def dataCompleteArgs(): Array[Array[Any]] = { 13 | Array( 14 | Array( 15 | Seq( 16 | "--modelInputDir", "global/input", 17 | "--modelOutputDir", "global/output", 18 | "--numOutputFiles", "100"))) 19 | } 20 | 21 | @DataProvider 22 | def dataIncompleteArgs(): Array[Array[Any]] = { 23 | 24 | Array( 25 | // missing input path 26 | Array( 27 | Seq( 28 | "--modelOutputDir", "global/output", 29 | "--numOutputFiles", "100")), 30 | // missing output path 31 | Array( 32 | Seq( 33 | "--modelInputDir", "global/input", 34 | "--numOutputFiles", "100")) 35 | ) 36 | } 37 | 38 | @Test(dataProvider = "dataCompleteArgs") 39 | def testParseCompleteArguments(completeArgs: Seq[String]): Unit = { 40 | 41 | val params = LrModelSplitterParser.parse(completeArgs) 42 | val expectedParams = LrModelSplitterParams( 43 | modelInputDir = "global/input", 44 | modelOutputDir = "global/output", 45 | numOutputFiles = 100 46 | ) 47 | assertEquals(params, expectedParams) 48 | } 49 | 50 | @Test(dataProvider = "dataIncompleteArgs", expectedExceptions = Array(classOf[IllegalArgumentException])) 51 | def testThrowIllegalArgumentException(inCompleteArgs: Seq[String]): Unit = { 52 | LrModelSplitterParser.parse(inCompleteArgs) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /gdmix-data/src/test/scala/com/linkedin/gdmix/parsers/OffsetUpdaterParserTest.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.parsers 2 | 3 | import org.testng.annotations.{DataProvider, Test} 4 | import org.testng.Assert.assertEquals 5 | 6 | /** 7 | * Unit tests for OffsetUpdaterParser.
8 | */ 9 | class OffsetUpdaterParserTest { 10 | 11 | @DataProvider 12 | def dataCompleteArgs(): Array[Array[Any]] = { 13 | Array( 14 | Array( 15 | Seq( 16 | "--trainingDataDir", "fixed-effect/trainingData", 17 | "--trainingScoreDir", "jymbii_lr/per-job/trainingScore", 18 | "--trainingScorePerCoordinateDir", "jymbii_lr/global/trainingScore", 19 | "--outputTrainingDataDir", "jymbii_lr/global/updatedTrainingData", 20 | "--validationDataDir", "fixed-effect/validationData", 21 | "--validationScoreDir", "jymbii_lr/per-job/validationScore", 22 | "--validationScorePerCoordinateDir", "jymbii_lr/global/validationScore", 23 | "--outputValidationDataDir", "jymbii_lr/global/updatedValidationData", 24 | "--numPartitions", "10"))) 25 | } 26 | 27 | @DataProvider 28 | def dataIncompleteArgs(): Array[Array[Any]] = { 29 | 30 | Array( 31 | // miss trainingDataDir 32 | Array( 33 | Seq( 34 | "--trainingScoreDir", "jymbii_lr/per-job/trainingScore", 35 | "--trainingScorePerCoordinateDir", "jymbii_lr/global/trainingScore", 36 | "--outputTrainingDataDir", "jymbii_lr/global/updatedTrainingData")), 37 | // miss trainingScoreDir 38 | Array( 39 | Seq( 40 | "--trainingDataDir", "fixed-effect/trainingData", 41 | "--trainingScorePerCoordinateDir", "jymbii_lr/global/trainingScore", 42 | "--outputTrainingDataDir", "jymbii_lr/global/updatedTrainingData")), 43 | // miss outputTrainingDataDir 44 | Array( 45 | Seq( 46 | "--trainingDataDir", "fixed-effect/trainingData", 47 | "--trainingScoreDir", "jymbii_lr/per-job/trainingScore", 48 | "--trainingScorePerCoordinateDir", "jymbii_lr/global/trainingScore")) 49 | ) 50 | } 51 | 52 | @DataProvider 53 | def dataIncorrectArgs(): Array[Array[Any]] = { 54 | 55 | Array( 56 | // negative numPartitions 57 | Array( 58 | Seq( 59 | "--trainingDataDir", "fixed-effect/trainingData", 60 | "--trainingScoreDir", "jymbii_lr/per-job/trainingScore", 61 | "--trainingScorePerCoordinateDir", "jymbii_lr/global/trainingScore", 62 | "--outputTrainingDataDir", "jymbii_lr/global/updatedTrainingData", 63 | "--numPartitions", "-10")) 64 | ) 65 | } 66 | 67 | @Test(dataProvider = "dataCompleteArgs") 68 | def testParseCompleteArguments(completeArgs: Seq[String]): Unit = { 69 | 70 | val params = OffsetUpdaterParser.parse(completeArgs) 71 | val expectedParams = OffsetUpdaterParams( 72 | trainingDataDir = "fixed-effect/trainingData", 73 | trainingScoreDir = "jymbii_lr/per-job/trainingScore", 74 | trainingScorePerCoordinateDir = Some("jymbii_lr/global/trainingScore"), 75 | outputTrainingDataDir = "jymbii_lr/global/updatedTrainingData", 76 | validationDataDir = Some("fixed-effect/validationData"), 77 | validationScoreDir = Some("jymbii_lr/per-job/validationScore"), 78 | validationScorePerCoordinateDir = Some("jymbii_lr/global/validationScore"), 79 | outputValidationDataDir = Some("jymbii_lr/global/updatedValidationData"), 80 | numPartitions = 10 81 | ) 82 | assertEquals(params, expectedParams) 83 | } 84 | 85 | @Test(dataProvider = "dataIncompleteArgs", expectedExceptions = Array(classOf[IllegalArgumentException])) 86 | def testThrowIllegalArgumentException(inCompleteArgs: Seq[String]): Unit = { 87 | OffsetUpdaterParser.parse(inCompleteArgs) 88 | } 89 | 90 | @Test(dataProvider = "dataIncorrectArgs", expectedExceptions = Array(classOf[IllegalArgumentException])) 91 | def testIncorrectArgs(inCorrectArgs: Seq[String]): Unit = { 92 | OffsetUpdaterParser.parse(inCorrectArgs) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- 
/gdmix-data/src/test/scala/com/linkedin/gdmix/utils/ConversionUtilsTest.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.utils 2 | 3 | import com.linkedin.gdmix.utils.Constants.{CROSS, MEANS, MODEL_ID, NAME, TERM, VALUE} 4 | import org.apache.spark.sql.functions.col 5 | import org.apache.spark.sql.Row 6 | import org.apache.spark.sql.types._ 7 | import org.testng.Assert.assertTrue 8 | import org.testng.annotations.Test 9 | 10 | /** 11 | * Tests for functions in ConversionUtils 12 | */ 13 | class ConversionUtilsTest extends SharedSparkSession { 14 | 15 | import spark.implicits._ 16 | 17 | /** 18 | * Unit test for [[ConversionUtils.splitModelIdUdf]]. 19 | */ 20 | @Test 21 | def testSplitModelIdUdf(): Unit = { 22 | val schema = StructType(List(StructField(MEANS, StructType(List(StructField(NAME, StringType, true), 23 | StructField(TERM, StringType, true), StructField(VALUE, DoubleType, true))), true))) 24 | val inputData = Seq( 25 | Row(Row(s"m1${CROSS}f1", "t1", 0.3)), 26 | Row(Row(s"m2${CROSS}f2", "", 0.5))) 27 | val inputDf = spark.createDataFrame(spark.sparkContext.parallelize(inputData), schema) 28 | val splitDf = inputDf.withColumn(MEANS, 29 | ConversionUtils.splitModelIdUdf(col(MEANS))) 30 | val expectedData = Seq( 31 | Row(Row("m1", Row("f1", "t1", 0.3))), 32 | Row(Row("m2", Row("f2", "", 0.5)))) 33 | val expectedSchema = StructType(List(StructField(MEANS, StructType(List(StructField("_1", StringType, true), 34 | StructField("_2", StructType(List(StructField(NAME, StringType, true), 35 | StructField(TERM, StringType, true), StructField(VALUE, DoubleType, true))), true)))))) 36 | val expectedDf = spark.createDataFrame(spark.sparkContext.parallelize(expectedData), expectedSchema) 37 | assertTrue(TestUtils.equalSmallDataFrame(splitDf, expectedDf, MEANS)) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /gdmix-data/src/test/scala/com/linkedin/gdmix/utils/SharedSparkSession.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.utils 2 | 3 | import org.apache.log4j.Logger 4 | import org.apache.spark.SparkConf 5 | import org.apache.spark.sql.SparkSession 6 | import org.testng.annotations.{AfterSuite, BeforeSuite} 7 | 8 | /** 9 | * We need a common utility to create the SparkSession. This is because 10 | * of the way Spark sessions work: we cannot have a separate SparkSession 11 | * at each function/compilation unit level.
12 | */ 13 | trait SharedSparkSession { 14 | private var _spark: SparkSession = _ 15 | lazy val spark: SparkSession = _spark 16 | val logger: Logger = Logger.getLogger(getClass) 17 | 18 | @BeforeSuite 19 | def setupSpark(): Unit = { 20 | val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SharedSparkSession") 21 | _spark = SparkSession.builder().config(sparkConf).getOrCreate() 22 | } 23 | 24 | @AfterSuite 25 | def stopSpark(): Unit = { 26 | _spark.stop() 27 | _spark = null 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /gdmix-data/src/test/scala/com/linkedin/gdmix/utils/TestUtils.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.gdmix.utils 2 | 3 | import org.apache.spark.sql.{DataFrame, Row} 4 | import org.apache.spark.sql.expressions.UserDefinedFunction 5 | import org.apache.spark.sql.functions.{col, udf} 6 | 7 | import com.linkedin.gdmix.utils.Constants._ 8 | import com.linkedin.gdmix.utils.ConversionUtils.NameTermValue 9 | 10 | /** 11 | * Helper functions for unit tests in the data module. 12 | */ 13 | object TestUtils { 14 | 15 | /** 16 | * Sort the columns of a DataFrame: 17 | * if the original columns are ["c", "b", "a"], 18 | * the sorted columns are ["a", "b", "c"] 19 | * @param df The input DataFrame 20 | * @return The sorted DataFrame 21 | */ 22 | def sortColumns(df: DataFrame): DataFrame = { 23 | val sortedColumns = df.columns.sorted.map(str => col(str)) 24 | df.select(sortedColumns:_*) 25 | } 26 | 27 | /** 28 | * Check if two small DataFrames have equal content 29 | * @param df1 The first DataFrame 30 | * @param df2 The second DataFrame 31 | * @param sortedBy The column name by which the DataFrames are sorted 32 | * @return true or false 33 | */ 34 | def equalSmallDataFrame(df1: DataFrame, df2: DataFrame, sortedBy: String): Boolean = { 35 | // sort the rows 36 | val sdf1 = df1.sort(sortedBy) 37 | val sdf2 = df2.sort(sortedBy) 38 | 39 | // sort the columns 40 | val ssdf1 = sortColumns(sdf1) 41 | val ssdf2 = sortColumns(sdf2) 42 | 43 | (ssdf1.schema.equals(ssdf2.schema) 44 | && ssdf1.collect().sameElements(ssdf2.collect())) 45 | } 46 | 47 | /** 48 | * Remove the whitespace from a string. 49 | * 50 | * @param s - the string 51 | * @return the string with whitespaces removed 52 | */ 53 | def removeWhiteSpace(s: String): String = s.replaceAll("\\s", "") 54 | } 55 | -------------------------------------------------------------------------------- /gdmix-trainer/README.md: -------------------------------------------------------------------------------- 1 | # GDMix 2 | 3 | ## What is it 4 | Generalized Deep [Mixed Model](https://en.wikipedia.org/wiki/Mixed_model) (GDMix) is a framework to train non-linear fixed effect and random effect models. Models of this kind are widely used for personalization in search and recommender systems. This project is an extension of our earlier effort on generalized linear models, [Photon ML](https://github.com/linkedin/photon-ml). It is implemented in Tensorflow, Scipy and Spark. 5 | 6 | The current version of GDMix supports logistic regression and [DeText](https://github.com/linkedin/detext) models for the fixed effect, and logistic regression for the random effects. In the future, we may support deep models for random effects if the increased complexity can be justified by improvements in relevance metrics.
7 | 8 | ## Supported models 9 | ### Logistic regression 10 | As a basic classification model, logistic regression finds wide usage in search and recommender systems due to its simplicity and training efficiency. Our implementation uses Tensorflow for data reading and gradient computation, and utilizes the L-BFGS solver from Scipy. This combination takes advantage of the versatility of Tensorflow and the fast convergence of L-BFGS. This mode is functionally equivalent to Photon-ML but with improved efficiency. Our internal tests show about 10% to 40% training speed improvement on various datasets. 11 | 12 | ### DeText models 13 | DeText is a framework for ranking with an emphasis on textual features. GDMix supports DeText training natively as a global model. A user can specify the fixed effect model type as DeText and then provide the network specifications. GDMix will train and score it automatically and connect the model to the subsequent random effect models. Currently only the pointwise loss function from DeText can be connected with the logistic regression random effect models. 14 | 15 | ### Other models 16 | GDMix can work with any deep learning fixed effect model. The interface between GDMix and other models is at the file I/O level. A user can train a model outside GDMix, then score the training data with the model and save the scores in files, which serve as the input to the GDMix random effect training. This enables the user to train random effect models based on scores from a custom fixed effect model that is not natively supported by GDMix. 17 | 18 | ## Training efficiency 19 | For logistic regression models, training efficiency is achieved by parallel training. Since the fixed effect model is usually trained on a large amount of data, synchronous training based on the Tensorflow all-reduce operation is utilized. Each worker takes a portion of the training data and computes the local gradient. The gradients are aggregated and then fed to the L-BFGS solver. The training dataset for each random effect model is usually small; however, the number of models (e.g. individual models for all LinkedIn members) can be on the order of hundreds of millions. This requires a partitioning and parallel training strategy, where each worker is responsible for a portion of the population and all the workers train their assigned models independently and simultaneously. Both ideas are sketched below. 20 | 21 | For DeText models, efficiency is achieved by either Tensorflow-based asynchronous parameter server distributed training or Horovod-based synchronous distributed training.
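To make the two logistic regression strategies concrete, here are short sketches. First, the fixed-effect L-BFGS loop: the snippet below is a minimal, single-process sketch (the toy data, dimensions, and the `lr_loss_and_grad` helper are illustrative only, not part of the GDMix API). In the actual trainer, the loss and gradient are computed by Tensorflow on each worker's data shard and aggregated across workers before being handed to the Scipy solver.

```python
import numpy as np
from scipy.optimize import minimize


def lr_loss_and_grad(w, X, y, l2=1.0):
    """L2-regularized binary logistic loss and its gradient.

    In distributed training, each worker would evaluate this on its own
    shard, and the (loss, gradient) pair would be all-reduced before
    being returned to the solver.
    """
    p = 1.0 / (1.0 + np.exp(-(X @ w)))  # predicted probabilities
    eps = 1e-12                         # guard against log(0)
    loss = -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps)) + 0.5 * l2 * w @ w
    grad = X.T @ (p - y) / len(y) + l2 * w
    return loss, grad


rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 32))                  # toy design matrix
y = (X @ rng.normal(size=32) > 0).astype(float)  # toy labels
# jac=True tells Scipy that the objective returns (loss, gradient).
result = minimize(lr_loss_and_grad, np.zeros(32), args=(X, y), jac=True, method="L-BFGS-B")
print(result.success, result.fun)
```

Second, the random effect side parallelizes across models rather than within one model: the entity population is split into partitions, and each worker trains all the models in its assigned partitions locally. A round-robin assignment along the lines of what the random-effect driver does looks like the following sketch (the function and argument names are illustrative):

```python
def assigned_partitions(all_partitions, worker_index, num_workers):
    """With n workers, worker k handles partitions k, k + n, k + 2n, ..."""
    return all_partitions[worker_index::num_workers]


# Example: 10 partitions spread over 3 workers; worker 1 trains partitions 1, 4 and 7.
print(assigned_partitions(list(range(10)), worker_index=1, num_workers=3))  # [1, 4, 7]
```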
22 | -------------------------------------------------------------------------------- /gdmix-trainer/setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 160 3 | 4 | [tool:pytest] 5 | testpaths = test/ 6 | 7 | [coverage:report] 8 | # TODO: Uncomment the line below when you've added tests to enforce minimum required coverage 9 | #fail_under = 100 10 | show_missing = true 11 | 12 | [coverage:run] 13 | branch = true 14 | 15 | [mypy] 16 | mypy_path = src 17 | namespace_packages = true 18 | ignore_missing_imports = true 19 | 20 | [aliases] 21 | test=pytest 22 | -------------------------------------------------------------------------------- /gdmix-trainer/setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from setuptools import find_namespace_packages, setup 3 | from sys import platform as _platform 4 | 5 | import sys 6 | 7 | VERSION = "0.5.0" 8 | current_dir = Path(__file__).resolve().parent 9 | with open(current_dir.joinpath('README.md'), encoding='utf-8') as f: 10 | long_description = f.read() 11 | 12 | if _platform not in ["linux", "linux2", "darwin"]: 13 | print("ERROR: platform {} isn't supported".format(_platform)) 14 | sys.exit(1) 15 | 16 | TF_VERSION_QUANTIFIER = '>=2.4,<2.5' 17 | 18 | setup( 19 | name="gdmix-trainer", 20 | python_requires='>=3.7', 21 | long_description=long_description, 22 | long_description_content_type='text/markdown', 23 | classifiers=["Programming Language :: Python :: 3.7", 24 | "Intended Audience :: Science/Research", 25 | "Intended Audience :: Developers", 26 | "License :: OSI Approved"], 27 | license='BSD-2-CLAUSE', 28 | version=VERSION, 29 | package_dir={'': 'src'}, 30 | packages=find_namespace_packages(where='src'), 31 | include_package_data=True, 32 | install_requires=[ 33 | "numpy>=1.19.5", 34 | "absl-py==0.10", 35 | "decorator==4.4.2", 36 | "detext-nodep==3.2.0", 37 | "gin-config==0.3.0", 38 | "fastavro==0.21.22", 39 | "grpcio==1.32.0", 40 | "protobuf==3.19", 41 | "psutil==5.7.0", 42 | "scikit-learn==1.0", 43 | "setuptools>=41.0.0", 44 | "six==1.15.0", 45 | "smart-arg==0.4", 46 | "statsmodels==0.13.1", 47 | "scipy==1.5.4", 48 | f"tensorflow{TF_VERSION_QUANTIFIER}", 49 | "tensorflow-addons==0.12.1", 50 | f"tensorflow-text{TF_VERSION_QUANTIFIER}", 51 | f"tensorflow-serving-api{TF_VERSION_QUANTIFIER}", 52 | "tensorflow_ranking", 53 | f"tf-models-official{TF_VERSION_QUANTIFIER}", 54 | "tomli==1.2.2" 55 | ], 56 | tests_require=['pytest'] 57 | ) 58 | -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT COPY & PASTE THIS CODE!!!!! 2 | # 3 | # This is a special file only needed for "src/gdmix/__init__.py" 4 | # to declare the "gdmix" package as a "namespace" 5 | # 6 | # All other "__init__.py" files can just be blank, or contain normal Python 7 | # module code. 
8 | try: 9 | __import__('pkg_resources').declare_namespace(__name__) 10 | except ImportError: 11 | from pkgutil import extend_path 12 | __path__ = extend_path(__path__, __name__) 13 | -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/drivers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/src/gdmix/drivers/__init__.py -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/drivers/fixed_effect_driver.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import logging 4 | import tensorflow as tf 5 | from gdmix.drivers.driver import Driver 6 | from gdmix.util import constants 7 | from gdmix.util.distribution_utils import remove_tf_config 8 | 9 | logger = logging.getLogger(__name__) 10 | logger.setLevel(logging.INFO) 11 | 12 | 13 | class FixedEffectDriver(Driver): 14 | """ 15 | Driver class to support fixed-effect training. 16 | """ 17 | 18 | def __init__(self, base_training_params, model): 19 | super().__init__(base_training_params, model, constants.FIXED_EFFECT) 20 | 21 | def _validate_params(self): 22 | pass 23 | 24 | def _setup_cluster(self): 25 | logger.info("Setting up cluster parameters for fixed effect training") 26 | tf_config = os.environ.get(constants.TF_CONFIG) 27 | if not tf_config: 28 | # set up local mode 29 | execution_context = {constants.TASK_TYPE: 'worker', 30 | constants.TASK_INDEX: 0, 31 | constants.CLUSTER_SPEC: {"worker": ["localhost:2222"]}, 32 | constants.NUM_WORKERS: 1, 33 | constants.NUM_SHARDS: 1, 34 | constants.SHARD_INDEX: 0, 35 | constants.IS_CHIEF: True} 36 | return execution_context 37 | tf_config_json = json.loads(tf_config) 38 | cluster = tf_config_json.get('cluster') 39 | if self.base_training_params.action == constants.ACTION_INFERENCE: 40 | # Inference / prediction / validation runs in local mode. 41 | cluster_spec = None 42 | else: 43 | cluster_spec = tf.train.ClusterSpec(cluster) 44 | execution_context = {constants.TASK_TYPE: tf_config_json.get('task', {}).get('type'), 45 | constants.TASK_INDEX: tf_config_json.get('task', {}).get('index'), 46 | constants.CLUSTER_SPEC: cluster_spec, 47 | constants.NUM_WORKERS: tf.train.ClusterSpec(cluster).num_tasks(constants.WORKER), 48 | constants.NUM_SHARDS: tf.train.ClusterSpec(cluster).num_tasks(constants.WORKER), 49 | constants.SHARD_INDEX: tf_config_json.get('task', {}).get('index'), 50 | constants.IS_CHIEF: tf_config_json.get('task', {}).get('index') == 0} 51 | if execution_context[constants.TASK_TYPE] is None or execution_context[constants.TASK_INDEX] is None: 52 | raise Exception('No job name found') 53 | if execution_context[constants.NUM_WORKERS] < 1: 54 | raise Exception('No worker found') 55 | if cluster_spec is None: 56 | # Remove TF_CONFIG if cluster_spec is None.
57 | remove_tf_config() 58 | return execution_context 59 | 60 | def _get_partition_list(self): 61 | # For fixed effect training, partition index is the same as task index 62 | return [self.execution_context[constants.TASK_INDEX]] 63 | 64 | def _anchor_directory(self, directory_path, partition_index): 65 | # For fixed effect, anchoring using partition_index is not required 66 | assert partition_index >= 0 67 | return directory_path 68 | -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/drivers/random_effect_driver.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import logging 4 | import tensorflow as tf 5 | from gdmix.drivers.driver import Driver 6 | from gdmix.util import constants 7 | from gdmix.util.distribution_utils import remove_tf_config 8 | 9 | logger = logging.getLogger(__name__) 10 | logger.setLevel(logging.INFO) 11 | 12 | 13 | class RandomEffectDriver(Driver): 14 | """ 15 | Driver class to support random-effect training. 16 | """ 17 | _RANDOM_EFFECT_PARTITION_DIR_PREFIX = "partitionId=" 18 | 19 | def __init__(self, base_training_params, model): 20 | super().__init__(base_training_params, model, constants.RANDOM_EFFECT) 21 | 22 | def _validate_params(self): 23 | assert self.base_training_params.model_type == constants.LOGISTIC_REGRESSION, \ 24 | "Random effect supports logistic_regression only" 25 | assert self.base_training_params.partition_list_file is not None, \ 26 | "Random effect requires partition list file" 27 | 28 | def _setup_cluster(self): 29 | logger.info("Setting up cluster parameters for random effect training") 30 | tf_config = os.environ.get(constants.TF_CONFIG) 31 | if not tf_config: 32 | # setup local mode 33 | execution_context = {constants.TASK_TYPE: 'worker', 34 | constants.TASK_INDEX: 0, 35 | constants.CLUSTER_SPEC: None, 36 | constants.NUM_WORKERS: 1, 37 | constants.NUM_SHARDS: 1, 38 | constants.SHARD_INDEX: 0, 39 | constants.IS_CHIEF: True} 40 | return execution_context 41 | tf_config_json = json.loads(tf_config) 42 | 43 | cluster = tf_config_json.get('cluster') 44 | execution_context = {constants.TASK_TYPE: tf_config_json.get('task', {}).get('type'), 45 | constants.TASK_INDEX: tf_config_json.get('task', {}).get('index'), 46 | # Random effect runs in local mode 47 | constants.CLUSTER_SPEC: None, 48 | constants.NUM_WORKERS: tf.train.ClusterSpec(cluster).num_tasks(constants.WORKER), 49 | constants.NUM_SHARDS: 1, 50 | constants.SHARD_INDEX: 0, 51 | constants.IS_CHIEF: tf_config_json.get('task', {}).get('index') == 0} 52 | if execution_context[constants.TASK_TYPE] is None or execution_context[constants.TASK_INDEX] is None: 53 | raise Exception('No job name found') 54 | if execution_context[constants.NUM_WORKERS] < 1: 55 | raise Exception('No worker found') 56 | # Since random effect runs in local mode, set TF_CONFIG to {} 57 | remove_tf_config() 58 | return execution_context 59 | 60 | def _get_partition_list(self): 61 | with tf.io.gfile.GFile(self.base_training_params.partition_list_file) as f: 62 | line = f.readline() 63 | all_partitions = [int(l) for l in line.split(',')] 64 | num_partitions = len(all_partitions) 65 | indices = list(range(self.execution_context[constants.TASK_INDEX], num_partitions, 66 | self.execution_context[constants.NUM_WORKERS])) 67 | partition_index_list = [all_partitions[i] for i in indices] 68 | return partition_index_list 69 | 70 | def _anchor_directory(self, directory_path, partition_index): 71 | # For 
random effect, directories should be anchored by attaching partition information 72 | return os.path.join(directory_path, 73 | RandomEffectDriver._RANDOM_EFFECT_PARTITION_DIR_PREFIX + str(partition_index)) 74 | -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/factory/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/src/gdmix/factory/__init__.py -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/factory/driver_factory.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from gdmix.drivers.fixed_effect_driver import FixedEffectDriver 4 | from gdmix.drivers.random_effect_driver import RandomEffectDriver 5 | from gdmix.factory.model_factory import ModelFactory 6 | from gdmix.util import constants 7 | 8 | logger = logging.getLogger(__name__) 9 | logger.setLevel(logging.INFO) 10 | 11 | 12 | class DriverFactory: 13 | """ 14 | Provider class for creating driver and dependencies 15 | 16 | NOTE - for now, only linear and DeText models are supported. In the future, the factory will also 17 | accept model type as an input parameter 18 | """ 19 | 20 | @staticmethod 21 | def get_driver(base_training_params, raw_model_params): 22 | """ 23 | Create driver and associated dependencies, based on type. Only linear and DeText models are supported 24 | for now 25 | :param base_training_params: Parsed base training parameters common to all models. This could include the 26 | path to training data, validation data, metadata file path, learning rate etc. 27 | :param raw_model_params: Raw model parameters, representing model-specific requirements. For example, a 28 | CNN might expose filter_size as a parameter, while a text-based model might expose the size of its word embedding matrix 29 | as a parameter 30 | :return: Fixed or Random effect driver 31 | """ 32 | 33 | driver = DriverFactory.drivers[base_training_params.stage] 34 | model = ModelFactory.get_model(base_training_params, raw_model_params) 35 | logger.info(f"Instantiating model {model} and driver {driver}") 36 | return driver(base_training_params=base_training_params, model=model) 37 | 38 | drivers = {constants.FIXED_EFFECT: FixedEffectDriver, constants.RANDOM_EFFECT: RandomEffectDriver} 39 | -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/factory/model_factory.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import tensorflow as tf 4 | 5 | from gdmix.params import Params 6 | 7 | from gdmix.models.detext.fixed_effect_detext_model import FixedEffectDetextModel 8 | from gdmix.models.custom.random_effect_lr_lbfgs_model import RandomEffectLRLBFGSModel 9 | from gdmix.models.custom.fixed_effect_lr_lbfgs_model import FixedEffectLRModelLBFGS 10 | from gdmix.util import constants 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.INFO) 14 | 15 | 16 | class ModelFactory: 17 | """ 18 | Provider class for creating model instances and dependencies 19 | 20 | NOTE - for now, only linear and DeText models are supported.
In the future, the factory will also 21 | accept model type as an input parameter 22 | """ 23 | 24 | @staticmethod 25 | def get_model(base_training_params: Params, raw_model_params): 26 | """ 27 | Create a model and its associated dependencies, based on type. Only linear and DeText models are supported 28 | for now 29 | :param base_training_params: Parsed base training parameters common to all models. This could include the 30 | path to training data, validation data, metadata file path, learning rate etc. 31 | :param raw_model_params: Raw model parameters, representing model-specific requirements. For example, a 32 | CNN might expose filter_size as a parameter, a text-based model might expose the size of its word embedding matrix 33 | as a parameter 34 | :return: Model instance 35 | """ 36 | model_type = base_training_params.model_type 37 | driver_type = base_training_params.stage 38 | logger.info(f"Instantiating {model_type} model and driver") 39 | if model_type in [constants.LOGISTIC_REGRESSION, constants.LINEAR_REGRESSION]: 40 | tf.compat.v1.disable_eager_execution() 41 | if driver_type == constants.FIXED_EFFECT: 42 | logger.info("Choosing Scipy-LBFGS FE model") 43 | model = FixedEffectLRModelLBFGS( 44 | raw_model_params=raw_model_params, base_training_params=base_training_params) 45 | else: 46 | if model_type == constants.LINEAR_REGRESSION: 47 | raise Exception("Random effect models are not supported for plain linear regression") 48 | logger.info("Choosing Scipy RE model") 49 | model = RandomEffectLRLBFGSModel(raw_model_params=raw_model_params) 50 | elif model_type == constants.DETEXT: 51 | model = FixedEffectDetextModel(raw_model_params=raw_model_params) 52 | else: 53 | raise Exception(f"Unknown training model {model_type}") 54 | return model 55 | -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/gdmix.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | from gdmix.factory.driver_factory import DriverFactory 5 | from gdmix.params import Params, SchemaParams 6 | from gdmix.util import constants 7 | 8 | logging.basicConfig(level=logging.INFO) 9 | logger = logging.getLogger(__name__) 10 | logger.setLevel(logging.INFO) 11 | 12 | 13 | def run(args): 14 | """ 15 | Parse CMD line arguments, instantiate the Driver and Model objects and hand over control to the Driver 16 | :param args: command line arguments 17 | :return: None 18 | """ 19 | # Parse base training parameters that are required for all models.
For other arguments, the 20 | # Driver delegates parsing to the specific model it encapsulates 21 | params = Params.__from_argv__(args, error_on_unknown=False) 22 | schema_params = SchemaParams.__from_argv__(args, error_on_unknown=False) 23 | 24 | # Log parsed base training parameters 25 | logger.info(f"Parsed schema params and gdmix args (params): {params}") 26 | 27 | # Instantiate appropriate driver, encapsulating a specific model 28 | driver = DriverFactory.get_driver(base_training_params=params, raw_model_params=args) 29 | 30 | # Run driver to either [1] train, [2] run evaluation or [3] export model 31 | if params.action == constants.ACTION_TRAIN: 32 | driver.run_training(schema_params=schema_params, export_model=True) 33 | elif params.action == constants.ACTION_INFERENCE: 34 | driver.run_inference(schema_params=schema_params) 35 | else: 36 | raise Exception(f"Unsupported action {params.action}") 37 | 38 | 39 | if __name__ == '__main__': 40 | run(sys.argv) 41 | -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/io/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/src/gdmix/io/__init__.py -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/src/gdmix/models/__init__.py -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/models/api.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Model(abc.ABC): 5 | """ 6 | Abstract class. Must be subclassed to support fixed or random effect training. 7 | 8 | The deriving subclasses can rely on different model frameworks like TF Estimator, session-based training, Keras etc. 9 | 10 | Supports the following functionality: 11 | 12 | 1) Encapsulates model graph specific to fixed or random effect 13 | 2) Interfaces with the underlying training framework. 14 | 3) Wires together model graph, loss functions, optimizers and metrics 15 | 4) Exposes APIs for model compiling, training, prediction, export to be used in the business logic of the driver 16 | """ 17 | def __init__(self, raw_model_params): 18 | self.model_params = self._parse_parameters(raw_model_params) 19 | self.metadata_file = None 20 | self.checkpoint_path = None 21 | self.training_data_dir = None 22 | self.validation_data_dir = None 23 | 24 | @abc.abstractmethod 25 | def train(self, 26 | training_data_dir, 27 | validation_data_dir, 28 | metadata_file, 29 | checkpoint_path, 30 | execution_context, 31 | schema_params): 32 | """ 33 | Fit/train the model 34 | The interface should use internal model parameters `model_params` for model-specific params. 35 | The data path arguments' values could be different from what users specify in the config 36 | because gdmix internally partitions the data into different chunks for distributed training/inference.
37 | :param training_data_dir: the path to training data 38 | :param validation_data_dir: the path to validation data 39 | :param metadata_file: the path to tensor metadata file 40 | :param checkpoint_path: the path to designated savedmodel/checkpoint directory 41 | :param execution_context: the tensorflow cluster setup 42 | :param schema_params: parameters for schema field keyword definition 43 | :return: None 44 | """ 45 | raise NotImplementedError('Must be implemented in subclasses.') 46 | 47 | @abc.abstractmethod 48 | def predict(self, 49 | output_dir, 50 | input_data_path, 51 | metadata_file, 52 | checkpoint_path, 53 | execution_context, 54 | schema_params): 55 | """ 56 | Run inference on a provided dataset 57 | :param output_dir: the path to which inference output should be written 58 | :param input_data_path: the path to input data 59 | :param metadata_file: the path to tensor metadata file 60 | :param checkpoint_path: the path to designated savedmodel/checkpoint directory 61 | :param execution_context: the tensorflow cluster setup 62 | :param schema_params: parameters for schema field keyword definition 63 | :return: None 64 | """ 65 | raise NotImplementedError('Must be implemented in subclasses.') 66 | 67 | @abc.abstractmethod 68 | def export(self, output_model_dir): 69 | """ 70 | Export TF model into the SavedModel format 71 | :param output_model_dir: model directory where model should be exported 72 | :return: None 73 | """ 74 | raise NotImplementedError('Must be implemented in subclasses.') 75 | 76 | @abc.abstractmethod 77 | def _parse_parameters(self, raw_model_parameters): 78 | """ 79 | Parse model-specific parameters. This excludes generic parameters like path to training set, optimization algo 80 | etc. which are necessary for all models 81 | :param raw_model_parameters: raw model parameters to be parsed, e.g. command-line arguments 82 | :return: Parsed dict of model-specific arguments 83 | """ 84 | raise NotImplementedError('Must be implemented in subclasses.') 85 | -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/models/custom/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/src/gdmix/models/custom/__init__.py -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/models/custom/base_lr_params.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class LRParams: 7 | """Base linear model parameters""" 8 | 9 | # Input / output files or directories 10 | metadata_file: str # Path to metadata. 11 | output_model_dir: str # Model output directory. 12 | training_data_dir: Optional[str] = None # Path of directory holding only training data files. 13 | validation_data_dir: Optional[str] = None # Path of directory holding only data files for in-line validation. 14 | # Column names in the dataset 15 | feature_bag: Optional[str] = None # Feature bag name that is used for training and scoring. 16 | 17 | # Arguments for model export 18 | feature_file: Optional[str] = None # Feature file for model exporting. 19 | 20 | # Optimizer related parameters 21 | regularize_bias: bool = True # Boolean for L2 regularization of bias term. 22 | l2_reg_weight: float = 1.0 # Weight of L2 regularization for each feature bag.
23 | lbfgs_tolerance: float = 1e-12 # LBFGS tolerance. 24 | num_of_lbfgs_curvature_pairs: int = 10 # Number of curvature pairs for LBFGS training. 25 | num_of_lbfgs_iterations: int = 100 # Number of LBFGS iterations. 26 | 27 | # Model related parameters 28 | # Whether to include intercept (the "b" in wx+b) 29 | has_intercept: bool = True 30 | offset_column_name: str = "offset" # Score from previous model. 31 | # The model coefficients are treated as zero if their absolute values are less than or equal to sparsity_threshold. 32 | sparsity_threshold = 1.0e-4 33 | 34 | # Dataset parameters 35 | batch_size: int = 16 36 | data_format: str = "tfrecord" 37 | 38 | def __post_init__(self): 39 | assert self.batch_size > 0, "Batch size must be a positive number" 40 | if self.regularize_bias: 41 | assert self.has_intercept, "Intercept must be used when it is regularized" 42 | assert self.feature_bag or self.has_intercept, "Either intercept or feature bag must be used" 43 | -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/models/custom/scipy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/src/gdmix/models/custom/scipy/__init__.py -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/models/detext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/src/gdmix/models/detext/__init__.py -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/models/detext_writer.py: -------------------------------------------------------------------------------- 1 | import fastavro 2 | import logging 3 | import tensorflow as tf 4 | 5 | from gdmix.util.io_utils import try_write_avro_blocks 6 | 7 | logger = logging.getLogger(__name__) 8 | logger.setLevel(logging.INFO) 9 | 10 | 11 | class DetextWriter: 12 | """ 13 | Helper class to run inference and write detext model and/or data 14 | """ 15 | 16 | def __init__(self, schema_params): 17 | self.schema_params = schema_params 18 | 19 | def get_inference_output_avro_schema(self): 20 | schema = { 21 | 'name': 'validation_result', 22 | 'type': 'record', 23 | 'fields': [ 24 | {'name': self.schema_params.uid_column_name, 'type': 'long'}, 25 | {'name': self.schema_params.weight_column_name, 'type': 'float'}, 26 | {'name': self.schema_params.label_column_name, 'type': 'int'}, 27 | {'name': self.schema_params.prediction_score_column_name, 'type': 'float'} 28 | ], 29 | } 30 | return schema 31 | 32 | def append_validation_results(self, records, predicts, ids, labels, weights): 33 | batch_size = predicts.shape[0] 34 | assert predicts.shape[0] == ids.shape[0] 35 | assert predicts.shape[0] == labels.shape[0] 36 | assert predicts.shape[0] == weights.shape[0] 37 | for i in range(batch_size): 38 | # we only support pointwise training for detext 39 | # label is list of one scalar 40 | # score is also scalar 41 | record = {self.schema_params.prediction_score_column_name: predicts[i], 42 | self.schema_params.uid_column_name: ids[i], 43 | self.schema_params.label_column_name: int(labels[i]), 44 | self.schema_params.weight_column_name: weights[i]} 45 | records.append(record) 46 | return
batch_size 47 | 48 | def save_batch(self, f, batch_score, output_file, n_records, n_batch): 49 | validation_results = [] 50 | validation_schema = fastavro.parse_schema(self.get_inference_output_avro_schema()) 51 | # save one batch of score 52 | try: 53 | predict_val = batch_score['score'].numpy() 54 | ids = batch_score[self.schema_params.uid_column_name].numpy() 55 | labels = batch_score[self.schema_params.label_column_name].numpy() 56 | weights = batch_score[self.schema_params.weight_column_name].numpy() 57 | n_records += self.append_validation_results(validation_results, 58 | predict_val, 59 | ids, 60 | labels, 61 | weights) 62 | n_batch += 1 63 | except tf.errors.OutOfRangeError: 64 | logger.info( 65 | 'Iterated through one batch. Finished evaluating work at batch {0}.'.format(n_batch)) 66 | else: 67 | try_write_avro_blocks(f, validation_schema, validation_results, None, 68 | self.create_error_message(n_batch, output_file)) 69 | return n_records, n_batch 70 | 71 | def create_error_message(self, n_batch, output_file): 72 | err_msg = 'An error occurred while writing batch #{} to path {}'.format( 73 | n_batch, output_file) 74 | return err_msg 75 | -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/models/schemas.py: -------------------------------------------------------------------------------- 1 | # photon-ml output model format 2 | 3 | BAYESIAN_LINEAR_MODEL_SCHEMA = """ 4 | { 5 | "type" : "record", 6 | "name" : "BayesianLinearModelAvro", 7 | "namespace" : "com.linkedin.photon.avro.generated", 8 | "doc" : "a generic schema to describe a Bayesian linear model with means and variances", 9 | "fields" : [ { 10 | "name" : "modelId", 11 | "type" : "string" 12 | }, { 13 | "name" : "modelClass", 14 | "type" : [ "null", "string" ], 15 | "doc" : "The fully-qualified class name of enclosing GLM model class. E.g.: com.linkedin.photon.ml.supervised.classification.LogisticRegressionModel", 16 | "default" : null 17 | }, { 18 | "name" : "means", 19 | "type" : { 20 | "type" : "array", 21 | "items" : { 22 | "type" : "record", 23 | "name" : "NameTermValueAvro", 24 | "doc" : "A tuple of name, term and value. Used as feature or model coefficient", 25 | "fields" : [ { 26 | "name" : "name", 27 | "type" : "string" 28 | }, { 29 | "name" : "term", 30 | "type" : "string" 31 | }, { 32 | "name" : "value", 33 | "type" : "double" 34 | } ] 35 | } 36 | } 37 | }, { 38 | "name" : "variances", 39 | "type" : [ "null", { 40 | "type" : "array", 41 | "items" : "NameTermValueAvro" 42 | } ], 43 | "default" : null 44 | }, { 45 | "name" : "lossFunction", 46 | "type" : [ "null", "string" ], 47 | "doc" : "The loss function used for training as the class name. E.g.: com.linkedin.photon.ml.function.LogisticLossFunction", 48 | "default" : null 49 | } ] 50 | } 51 | """ 52 | -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/params.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | from gdmix.util import constants 5 | from smart_arg import arg_suite, LateInit 6 | 7 | _ACTIONS = (constants.ACTION_INFERENCE, constants.ACTION_TRAIN) 8 | _STAGES = (constants.FIXED_EFFECT, constants.RANDOM_EFFECT) 9 | _MODEL_TYPES = (constants.LOGISTIC_REGRESSION, constants.LINEAR_REGRESSION, constants.DETEXT) 10 | 11 | 12 | @dataclass 13 | class GDMixParams: 14 | action: str = _ACTIONS[1] # Train or inference. 
15 | __action = {"choices": _ACTIONS} 16 | stage: str = _STAGES[0] # Fixed or random effect. 17 | __stage = {"choices": _STAGES} 18 | # The model type to train, e.g., logistic regression, linear regression, detext, etc. 19 | model_type: str = _MODEL_TYPES[0] 20 | __model_type = {"choices": _MODEL_TYPES} 21 | 22 | # Input / output directories 23 | training_score_dir: Optional[str] = None # Path to the prediction score directory of the training data. 24 | validation_score_dir: Optional[str] = None # Path to the prediction score directory of the validation data. 25 | 26 | # Driver arguments for random effect training 27 | partition_list_file: Optional[str] = None # File containing a list of all the partition ids, for random effect only. 28 | 29 | def __post_init__(self): 30 | assert self.action in _ACTIONS, f"Action: {self.action} must be in {_ACTIONS}" 31 | assert self.stage in _STAGES, f"Stage: {self.stage} must be in {_STAGES}" 32 | assert self.model_type in _MODEL_TYPES, f"Model type: {self.model_type} must be in {_MODEL_TYPES}" 33 | 34 | 35 | @arg_suite 36 | @dataclass 37 | class SchemaParams: 38 | # Schema names 39 | uid_column_name: str = LateInit # Unique id column name in the train/validation data. 40 | weight_column_name: Optional[str] = None # Weight column name in the train/validation data. 41 | label_column_name: Optional[str] = None # Label column name in the train/validation data. 42 | prediction_score_column_name: Optional[str] = None # Prediction score column name in the generated result file. 43 | prediction_score_per_coordinate_column_name: str = "predictionScorePerCoordinate" # Column name of the prediction score without the offset. 44 | 45 | 46 | @arg_suite 47 | @dataclass 48 | class Params(GDMixParams, SchemaParams): 49 | """GDMix Driver""" 50 | 51 | def __post_init__(self): 52 | super().__post_init__() 53 | assert (self.action == constants.ACTION_TRAIN and self.label_column_name) or \ 54 | (self.action == constants.ACTION_INFERENCE and self.prediction_score_column_name) 55 | -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/src/gdmix/util/__init__.py -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/util/constants.py: -------------------------------------------------------------------------------- 1 | ACTION = "action" 2 | STAGE = "stage" 3 | LABEL_COLUMN_NAME = "label_column_name" 4 | MODEL_TYPE = "model_type" 5 | LBFGS = "lbfgs" 6 | PREDICTION_SCORE_COLUMN_NAME = "prediction_score_column_name" 7 | PREDICTION_SCORE_PER_COORDINATE_COLUMN_NAME = "prediction_score_per_coordinate_column_name" 8 | OFFSET_COLUMN_NAME = "offset_column_name" 9 | UID_COLUMN_NAME = "uid_column_name" 10 | WEIGHT_COLUMN_NAME = "weight_column_name" 11 | FEATURE_BAG = "feature_bag" 12 | PARTITION_ENTITY = "partition_entity" 13 | TRAINING_DATA_DIR = "training_data_dir" 14 | VALIDATION_DATA_DIR = "validation_data_dir" 15 | OUTPUT_MODEL_DIR = "output_model_dir" 16 | DATA_FORMAT = "data_format" 17 | COPY_TO_LOCAL = "copy_to_local" 18 | 19 | FEATURE_FILE = "feature_file" 20 | 21 | # Training parameters-related constants 22 | BATCH_SIZE = "batch_size" 23 | DELAYED_EXIT_IN_SECONDS = "delayed_exit_in_seconds" 24 | ENABLE_LOCAL_INDEXING = "enable_local_indexing" 25 | L2_REG_WEIGHT = "l2_reg_weight"
26 | LBFGS_TOLERANCE = "lbfgs_tolerance" 27 | MAX_TRAINING_QUEUE_SIZE = "max_training_queue_size" 28 | NUM_OF_CONSUMERS = "num_of_consumers" 29 | NUM_OF_LBFGS_CURVATURE_PAIRS = "num_of_lbfgs_curvature_pairs" 30 | NUM_OF_LBFGS_ITERATIONS = "num_of_lbfgs_iterations" 31 | REGULARIZE_BIAS = "regularize_bias" 32 | TRAINING_QUEUE_TIMEOUT_IN_SECONDS = "training_queue_timeout_in_seconds" 33 | 34 | AUC = "auc" 35 | ACCURACY = "accuracy" 36 | ACTIVE = "active" 37 | PASSIVE = "passive" 38 | 39 | TRAINING_SCORE_DIR = "training_score_dir" 40 | VALIDATION_SCORE_DIR = "validation_score_dir" 41 | ACTIVE_TRAINING_OUTPUT_FILE = "active_training_output_file" 42 | PASSIVE_TRAINING_OUTPUT_FILE = "passive_training_output_file" 43 | TFRECORD_GLOB_PATTERN = "*.tfrecord" 44 | VALIDATION_OUTPUT_FILE = "validation_output_file" 45 | PASSIVE_TRAINING_DATA_DIR = "passive_training_data_dir" 46 | RANDOM_EFFECT = "random_effect" 47 | FIXED_EFFECT = "fixed_effect" 48 | 49 | # Constants for random effect training 50 | MODEL_IDS_DIR = "model_ids_dir" 51 | PARTITION_INDEX = "partition_index" 52 | PARTITION_LIST_FILE = "partition_list_file" 53 | 54 | # String constants related to execution context 55 | IS_CHIEF = "is_chief" 56 | NUM_SHARDS = "num_shards" 57 | SHARD_INDEX = "shard_index" 58 | NUM_EPOCHS = "num_epochs" 59 | NUM_WORKERS = "num_workers" 60 | WORKER = "worker" 61 | CLUSTER_SPEC = "cluster_spec" 62 | TASK_INDEX = "task_index" 63 | TASK_TYPE = "task_type" 64 | TASK_TYPE_CHIEF = "chief" 65 | TASK_TYPE_WORKER = "worker" 66 | TF_CONFIG = "TF_CONFIG" 67 | 68 | # Dataset constants 69 | DATASET_MODULE = "dataset_module" 70 | DATASET_CREATOR = "dataset_creator" 71 | TFRECORD = "tfrecord" 72 | INPUT_DIR = "input_dir" 73 | METADATA_FILE = "metadata_file" 74 | ACTION_INFERENCE = "inference" 75 | ACTION_TRAIN = "train" 76 | 77 | # Supported models 78 | LINEAR_REGRESSION = "linear_regression" 79 | LOGISTIC_REGRESSION = "logistic_regression" 80 | DETEXT = "detext" 81 | 82 | # Variance computation 83 | SIMPLE = "simple" 84 | FULL = "full" 85 | -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/util/distribution_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import tensorflow as tf 4 | 5 | from gdmix.util.io_utils import low_rpc_call_glob 6 | 7 | logger = logging.getLogger(__name__) 8 | logger.setLevel(logging.INFO) 9 | 10 | 11 | def shard_input_files(input_path, num_shards, shard_index): 12 | 13 | """List input files in the input_path, then shard them such that 14 | a worker only takes a portion of the input files. If the number 15 | of files is less than the number of shards, the file-level sharding indicator 16 | is set; each shard gets one file if the shard_index is less than 17 | the number of files, otherwise an empty list is returned. 18 | 19 | Input path is a directory or a directory + file pattern. 20 | Possible values for input_path: 21 | 1) a directory: hdfs://namespace.com/jobs/bert/trainData 22 | 2) a filename pattern: /user/data/*.tfrecord 23 | :param input_path: the path where the training dataset is located. It can be 24 | a directory or a filename pattern. 25 | :param num_shards: Total number of shards. 26 | :param shard_index: The index of the current worker. 27 | :return: A tuple of (the list of files belonging to the shard, a boolean 28 | indicating whether sample-level sharding is needed).
29 | """ 30 | assert((shard_index >= 0) and (num_shards >= 1) and (num_shards > shard_index)) 31 | if tf.compat.v1.gfile.IsDirectory(input_path): 32 | input_files = low_rpc_call_glob(os.path.join(input_path, '*')) 33 | else: # This is a file or file pattern 34 | input_files = low_rpc_call_glob(input_path) 35 | # sort the file so that all workers see the same order. 36 | input_files = sorted(input_files) 37 | n = len(input_files) 38 | # there should be at least one file 39 | assert(n > 0), "{} is empty".format(input_files) 40 | if n < num_shards: 41 | if shard_index < n: 42 | return [input_files[shard_index]], True 43 | else: 44 | return [], True 45 | else: 46 | return [input_files[i] 47 | for i in range(shard_index, n, num_shards)], False 48 | 49 | 50 | def remove_tf_config(): 51 | tf_config = os.environ.pop('TF_CONFIG', '') 52 | 53 | if tf_config: 54 | logger.info("====== removing the following tf config environmental variable =======") 55 | logger.info(tf_config) 56 | logger.info("======================================================================") 57 | -------------------------------------------------------------------------------- /gdmix-trainer/src/gdmix/util/model_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def threshold_coefficients(coefficients, threshold_value): 5 | """ 6 | Set coefficients whose absolute values less than or equal to the threshold to 0.0. 7 | 8 | :param coefficients: a list of floats, usually coefficients from a trained model. 9 | :param threshold_value: a positive float used as the threshold value. 10 | :return a numpy array, the zeroed coefficients according to the threshold. 11 | """ 12 | return np.array([0.0 if abs(x) <= threshold_value else x for x in coefficients]) 13 | -------------------------------------------------------------------------------- /gdmix-trainer/test/drivers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/test/drivers/__init__.py -------------------------------------------------------------------------------- /gdmix-trainer/test/drivers/test_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from gdmix.params import SchemaParams, Params 4 | from gdmix.util import constants 5 | 6 | 7 | def set_fake_tf_config(task_type, worker_index): 8 | """ 9 | Set up fake TF_CONFIG environment variable 10 | :param task_type: worker or evaluator 11 | :param worker_index: index of node 12 | :return: None 13 | """ 14 | os.environ[constants.TF_CONFIG] = str( 15 | {"task": {"type": str(task_type), "index": worker_index}, "cluster": 16 | {"worker": ["node1.example.com:25304", 17 | "node2.example.com:32879", 18 | "node3.example.com:8068", 19 | "node4.example.com:25949", 20 | "node5.example.com:28685"], 21 | "evaluator": ["node6.example.com:21243"]}}).replace("'", '"') 22 | 23 | 24 | def setup_fake_base_training_params(training_stage=constants.FIXED_EFFECT, 25 | model_type=constants.LOGISTIC_REGRESSION): 26 | """ 27 | Set up fake parameter dict for testing 28 | :return: fake parameter dict 29 | """ 30 | params = {constants.ACTION: "train", 31 | constants.STAGE: training_stage, 32 | constants.MODEL_TYPE: model_type, 33 | constants.TRAINING_SCORE_DIR: "dummy_training_output_dir", 34 | constants.VALIDATION_SCORE_DIR: "dummy_validation_output_dir", 35 | 36 | 
constants.PARTITION_LIST_FILE: os.path.join(os.getcwd(), "test/resources/metadata", 37 | "partition_list.txt"), 38 | 39 | constants.UID_COLUMN_NAME: "uid", 40 | constants.WEIGHT_COLUMN_NAME: "weight", 41 | constants.LABEL_COLUMN_NAME: "response", 42 | constants.PREDICTION_SCORE_COLUMN_NAME: "predictionScore", 43 | constants.PREDICTION_SCORE_PER_COORDINATE_COLUMN_NAME: "predictionScorePerCoordinate" 44 | } 45 | params = Params(**params) 46 | object.__delattr__(params, '__frozen__') # Allow the test code to mutate the params. 47 | return params 48 | 49 | 50 | def setup_fake_raw_model_params(training_stage=constants.FIXED_EFFECT): 51 | raw_model_params = [f"--{constants.UID_COLUMN_NAME}", "uid", f"--{constants.WEIGHT_COLUMN_NAME}", "weight", 52 | f"--{constants.TRAINING_DATA_DIR}", os.path.join(os.getcwd(), "test/resources/train"), 53 | f"--{constants.VALIDATION_DATA_DIR}", 54 | os.path.join(os.getcwd(), "test/resources/validate"), 55 | f"--{constants.OUTPUT_MODEL_DIR}", "dummy_model_output_dir", 56 | f"--{constants.METADATA_FILE}", 57 | os.path.join(os.getcwd(), "test/resources/fe_lbfgs/metadata/tensor_metadata.json") 58 | ] 59 | if training_stage == constants.RANDOM_EFFECT: 60 | raw_model_params.append(f"--{constants.FEATURE_BAG}") 61 | raw_model_params.append("per_member") 62 | raw_model_params.append(f"--{constants.OFFSET_COLUMN_NAME}") 63 | raw_model_params.append("offset") 64 | else: 65 | raw_model_params.append(f"--{constants.FEATURE_BAG}") 66 | raw_model_params.append("global") 67 | raw_model_params.append(f"--{constants.FEATURE_FILE}") 68 | raw_model_params.append("test/resources/fe_lbfgs/featureList/global",) 69 | 70 | return raw_model_params 71 | 72 | 73 | def setup_fake_schema_params(): 74 | return SchemaParams(**{constants.UID_COLUMN_NAME: "uid", 75 | constants.WEIGHT_COLUMN_NAME: "weight", 76 | constants.LABEL_COLUMN_NAME: "response", 77 | constants.PREDICTION_SCORE_COLUMN_NAME: "predictionScore", 78 | constants.PREDICTION_SCORE_PER_COORDINATE_COLUMN_NAME: "predictionScorePerCoordinate" 79 | }) 80 | -------------------------------------------------------------------------------- /gdmix-trainer/test/factory/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/test/factory/__init__.py -------------------------------------------------------------------------------- /gdmix-trainer/test/factory/test_driver_factory.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from gdmix.factory.driver_factory import DriverFactory 3 | 4 | from gdmix.drivers.fixed_effect_driver import FixedEffectDriver 5 | from gdmix.drivers.random_effect_driver import RandomEffectDriver 6 | from gdmix.util import constants 7 | from drivers.test_helper import set_fake_tf_config, setup_fake_base_training_params, setup_fake_raw_model_params 8 | 9 | 10 | class TestDriverFactory(tf.test.TestCase): 11 | """ 12 | Test DriverFactory 13 | """ 14 | 15 | def setUp(self): 16 | self.task_type = "worker" 17 | self.worker_index = 0 18 | self.num_workers = 5 19 | set_fake_tf_config(task_type=self.task_type, worker_index=self.worker_index) 20 | self.params = setup_fake_base_training_params() 21 | self.model_params = setup_fake_raw_model_params() 22 | 23 | def test_fixed_effect_driver_wiring(self): 24 | fe_driver = DriverFactory.get_driver( 25 | base_training_params=setup_fake_base_training_params(constants.FIXED_EFFECT), 
26 | raw_model_params=self.model_params) 27 | # Assert the type of driver 28 | self.assertIsInstance(fe_driver, FixedEffectDriver) 29 | 30 | def test_random_effect_driver_wiring(self): 31 | re_driver = DriverFactory.get_driver( 32 | base_training_params=setup_fake_base_training_params(constants.RANDOM_EFFECT), 33 | raw_model_params=self.model_params) 34 | # Assert the type of driver 35 | self.assertIsInstance(re_driver, RandomEffectDriver) 36 | -------------------------------------------------------------------------------- /gdmix-trainer/test/factory/test_model_factory.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from gdmix.factory.model_factory import ModelFactory 4 | from gdmix.models.custom.fixed_effect_lr_lbfgs_model import FixedEffectLRModelLBFGS 5 | from gdmix.models.custom.random_effect_lr_lbfgs_model import RandomEffectLRLBFGSModel 6 | from gdmix.util import constants 7 | from drivers.test_helper import setup_fake_base_training_params, setup_fake_raw_model_params 8 | 9 | 10 | class TestModelFactory(tf.test.TestCase): 11 | """ 12 | Test ModelFactory 13 | """ 14 | 15 | def setUp(self): 16 | self.model_params = setup_fake_raw_model_params() 17 | 18 | def test_fixed_effect_logistic_regression_lbfgs_model_creation(self): 19 | fe_model = ModelFactory.get_model( 20 | base_training_params=setup_fake_base_training_params(training_stage=constants.FIXED_EFFECT, 21 | model_type=constants.LOGISTIC_REGRESSION), 22 | raw_model_params=self.model_params) 23 | # Assert the type of model 24 | self.assertIsInstance(fe_model, FixedEffectLRModelLBFGS) 25 | 26 | def test_fixed_effect_linear_regression_lbfgs_model_creation(self): 27 | fe_model = ModelFactory.get_model( 28 | base_training_params=setup_fake_base_training_params(training_stage=constants.FIXED_EFFECT, 29 | model_type=constants.LINEAR_REGRESSION), 30 | raw_model_params=self.model_params) 31 | # Assert the type of model 32 | self.assertIsInstance(fe_model, FixedEffectLRModelLBFGS) 33 | 34 | def test_random_effect_custom_logistic_regression_model_creation(self): 35 | re_model = ModelFactory.get_model( 36 | base_training_params=setup_fake_base_training_params(training_stage=constants.RANDOM_EFFECT, 37 | model_type=constants.LOGISTIC_REGRESSION), 38 | raw_model_params=self.model_params) 39 | self.assertIsInstance(re_model, RandomEffectLRLBFGSModel) 40 | -------------------------------------------------------------------------------- /gdmix-trainer/test/io/test_dataset_metadata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | from gdmix.io.dataset_metadata import DatasetMetadata 5 | 6 | test_metadata_file = os.path.join(os.getcwd(), "test/resources/metadata") 7 | 8 | 9 | class TestDatasetMetadata(tf.test.TestCase): 10 | """Test DatasetMetadata class.""" 11 | dummy_metadata = DatasetMetadata(os.path.join( 12 | test_metadata_file, "valid_metadata.json")) 13 | feature_names = ["weight", "f1"] 14 | label_names = ["response"] 15 | 16 | def test_feature_names(self): 17 | self.assertEqual(self.dummy_metadata.get_feature_names(), self.feature_names) 18 | 19 | def test_label_names(self): 20 | self.assertEqual(self.dummy_metadata.get_label_names(), self.label_names) 21 | 22 | def test_invalid_type(self): 23 | msg_pattern = r"User provided dtype \'.*\' is not supported. Supported types are \'.*\'." 
24 | with self.assertRaises(ValueError, msg=msg_pattern): 25 | DatasetMetadata(os.path.join(test_metadata_file, "invalid_type.json")) 26 | 27 | def test_invalid_name(self): 28 | msg_pattern = r"Feature name can not be None and must be str" 29 | with self.assertRaises(ValueError, msg=msg_pattern): 30 | DatasetMetadata(os.path.join(test_metadata_file, "invalid_name.json")) 31 | 32 | def test_invalid_shape(self): 33 | msg_pattern = r"Feature shape can not be None and must be a list" 34 | with self.assertRaises(ValueError, msg=msg_pattern): 35 | DatasetMetadata(os.path.join(test_metadata_file, "invalid_shape.json")) 36 | 37 | def test_duplicated_names(self): 38 | msg_pattern = r"The following tensor names in your metadata appears more than once:\['weight', 'response'\]" 39 | with self.assertRaises(ValueError, msg=msg_pattern): 40 | DatasetMetadata(os.path.join(test_metadata_file, "duplicated_names.json")) 41 | 42 | def test_map_int(self): 43 | int_dtypes = [tf.int8, tf.uint8, tf.uint16, tf.uint32, tf.uint64, tf.int16, tf.int32, tf.int64] 44 | for id in int_dtypes: 45 | assert tf.int64 == DatasetMetadata.map_int(id) 46 | assert tf.float32 == DatasetMetadata.map_int(tf.float32) 47 | assert tf.float16 == DatasetMetadata.map_int(tf.float16) 48 | assert tf.float64 == DatasetMetadata.map_int(tf.float64) 49 | assert tf.string == DatasetMetadata.map_int(tf.string) 50 | 51 | 52 | if __name__ == '__main__': 53 | tf.test.main() 54 | -------------------------------------------------------------------------------- /gdmix-trainer/test/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/test/models/__init__.py -------------------------------------------------------------------------------- /gdmix-trainer/test/models/custom/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/test/models/custom/__init__.py -------------------------------------------------------------------------------- /gdmix-trainer/test/models/custom/test_optimizer_helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import statsmodels.api as sm 4 | 5 | from gdmix.util import constants 6 | 7 | 8 | def compute_coefficients_and_variance(X, y, weights=None, offsets=None, 9 | variance_mode=constants.SIMPLE, 10 | lambda_l2=0.0, has_intercept=True): 11 | """ 12 | compute coefficients and variance for logistic regression model 13 | :param X: num_samples x num_features matrix 14 | :param y: num_samples binary labels (0 or 1) 15 | :param weights: num_samples floats, weights of each sample. 16 | :param offsets: num_samples floats, offset of each sample 17 | :param variance_mode: full or simple 18 | :param lambda_l2: L2 regularization coefficient 19 | :param has_intercept: whether to include intercept 20 | :return: (mean, variance) tuple 21 | """ 22 | if scipy.sparse.issparse(X): 23 | X = X.toarray() 24 | X_with_intercept = np.hstack((np.ones((X.shape[0], 1)), X)) if has_intercept else X 25 | lr_model = sm.GLM(y, X_with_intercept, family=sm.families.Binomial(), 26 | offset=offsets, freq_weights=weights) 27 | if lambda_l2 != 0.0: 28 | raise ValueError("This function uses statsmodels to compute LR coefficients and its variance. 
" 29 | "However, as of version 0.12.2, the coefficients when non-zero L2 regularization" 30 | " is applied are not correct. So we can only check L2=0.") 31 | lr_results = lr_model.fit_regularized(alpha=lambda_l2, maxiterint=500, 32 | cnvrg_tol=1e-12, L1_wt=0.0) 33 | mean = lr_results.params 34 | hessian = lr_model.hessian(mean) 35 | if variance_mode == constants.SIMPLE: 36 | variance = -1.0 / np.diagonal(hessian) 37 | elif variance_mode == constants.FULL: 38 | variance = -np.diagonal(np.linalg.inv(hessian)) 39 | return mean, variance 40 | -------------------------------------------------------------------------------- /gdmix-trainer/test/models/detext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/test/models/detext/__init__.py -------------------------------------------------------------------------------- /gdmix-trainer/test/models/test_model_api.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from gdmix.models.api import Model 4 | 5 | 6 | class ConcreteModel(Model): 7 | """Derived model class""" 8 | 9 | def train(self, 10 | training_data_dir, 11 | validation_data_dir, 12 | metadata_file, 13 | checkpoint_path, 14 | execution_context, 15 | schema_params): 16 | super(ConcreteModel, self).train(training_data_dir, 17 | validation_data_dir, 18 | metadata_file, 19 | checkpoint_path, 20 | execution_context, 21 | schema_params) 22 | 23 | def predict(self, 24 | output_dir, 25 | input_data_path, 26 | metadata_file, 27 | checkpoint_path, 28 | execution_context, 29 | schema_params): 30 | super(ConcreteModel, self).predict(output_dir, 31 | input_data_path, 32 | metadata_file, 33 | checkpoint_path, 34 | execution_context, 35 | schema_params) 36 | 37 | def export(self, output_model_dir): 38 | super(ConcreteModel, self).export(output_model_dir) 39 | 40 | def _parse_parameters(self, raw_model_parameters): 41 | pass 42 | 43 | 44 | class TestAbstractModel(tf.test.TestCase): 45 | """Test abstract Model class.""" 46 | 47 | raw_model_parameters = None 48 | concrete_model = ConcreteModel(raw_model_parameters) 49 | 50 | def test_train(self): 51 | self.assertRaises(NotImplementedError, self.concrete_model.train, None, None, None, None, None, None) 52 | 53 | def test_predict(self): 54 | self.assertRaises(NotImplementedError, self.concrete_model.predict, None, None, None, None, None, None) 55 | 56 | def test_export(self): 57 | self.assertRaises(NotImplementedError, self.concrete_model.export, None) 58 | 59 | 60 | if __name__ == '__main__': 61 | tf.test.main() 62 | -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 128, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 4, 10 | "num_hidden_layers": 4, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/custom/sklearn_data.p: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/test/resources/custom/sklearn_data.p -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/fe_lbfgs/featureList/global: -------------------------------------------------------------------------------- 1 | n1,t1 2 | n2,t2 3 | n3,t3 4 | n4,t4 5 | n5,t5 6 | n6,t6 7 | n7,t7 8 | n8,t8 9 | n9,t9 10 | n10,t10 11 | n11,t11 12 | n12,t12 13 | n13,t13 14 | n14,t14 15 | n15,t15 16 | n16,t16 17 | n17,t17 18 | n18,t18 19 | n19,t19 20 | n20,t20 21 | n21,t21 22 | n22,t22 23 | n23,t23 24 | n24,t24 25 | n25,t25 26 | n26,t26 27 | n27,t27 28 | n28,t28 29 | n29,t29 30 | n30,t30 31 | n31,t31 32 | n32,t32 33 | n33,t33 34 | n34,t34 35 | n35,t35 36 | n36,t36 37 | n37,t37 38 | n38,t38 39 | n39,t39 40 | n40,t40 41 | n41,t41 42 | n42,t42 43 | n43,t43 44 | n44,t44 45 | n45,t45 46 | n46,t46 47 | n47,t47 48 | n48,t48 49 | n49,t49 50 | n50,t50 51 | n51,t51 52 | n52,t52 53 | n53,t53 54 | n54,t54 55 | n55,t55 56 | n56,t56 57 | n57,t57 58 | n58,t58 59 | n59,t59 60 | n60,t60 61 | n61,t61 62 | n62,t62 63 | n63,t63 64 | n64,t64 65 | n65,t65 66 | n66,t66 67 | n67,t67 68 | n68,t68 69 | n69,t69 70 | n70,t70 71 | n71,t71 72 | n72,t72 73 | n73,t73 74 | n74,t74 75 | n75,t75 76 | n76,t76 77 | n77,t77 78 | n78,t78 79 | n79,t79 80 | n80,t80 81 | n81,t81 82 | n82,t82 83 | n83,t83 84 | n84,t84 85 | n85,t85 86 | n86,t86 87 | n87,t87 88 | n88,t88 89 | n89,t89 90 | n90,t90 91 | n91,t91 92 | n92,t92 93 | n93,t93 94 | n94,t94 95 | n95,t95 96 | n96,t96 97 | n97,t97 98 | n98,t98 99 | n99,t99 100 | n100,t100 101 | n101,t101 102 | n102,t102 103 | n103,t103 104 | n104,t104 105 | n105,t105 106 | n106,t106 107 | n107,t107 108 | n108,t108 109 | n109,t109 110 | n110,t110 111 | n111,t111 112 | n112,t112 113 | n113,t113 114 | n114,t114 115 | n115,t115 116 | n116,t116 117 | n117,t117 118 | n118,t118 119 | n119,t119 120 | n120,t120 121 | n121,t121 122 | n122,t122 123 | n123,t123 124 | n124,t124 125 | n125,t125 126 | n126,t126 127 | n127,t127 128 | n128,t128 129 | n129,t129 130 | n130,t130 131 | n131,t131 132 | n132,t132 133 | n133,t133 134 | n134,t134 135 | n135,t135 136 | n136,t136 137 | n137,t137 138 | n138,t138 139 | n139,t139 140 | n140,t140 141 | n141,t141 142 | n142,t142 143 | n143,t143 144 | n144,t144 145 | n145,t145 146 | n146,t146 147 | n147,t147 148 | n148,t148 149 | n149,t149 150 | n150,t150 151 | n151,t151 152 | n152,t152 153 | n153,t153 154 | n154,t154 155 | n155,t155 156 | n156,t156 157 | n157,t157 158 | n158,t158 159 | n159,t159 160 | n160,t160 161 | n161,t161 162 | n162,t162 163 | n163,t163 164 | n164,t164 165 | n165,t165 166 | n166,t166 167 | n167,t167 168 | n168,t168 169 | n169,t169 170 | n170,t170 171 | n171,t171 172 | n172,t172 173 | n173,t173 174 | n174,t174 175 | n175,t175 176 | n176,t176 177 | n177,t177 178 | n178,t178 179 | n179,t179 180 | n180,t180 181 | n181,t181 182 | n182,t182 183 | n183,t183 184 | n184,t184 185 | n185,t185 186 | n186,t186 187 | n187,t187 188 | n188,t188 189 | n189,t189 190 | n190,t190 191 | n191,t191 192 | n192,t192 193 | n193,t193 194 | n194,t194 195 | n195,t195 196 | n196,t196 197 | n197,t197 198 | n198,t198 199 | n199,t199 200 | n200,t200 201 | n201,t201 202 | n202,t202 203 | n203,t203 204 | n204,t204 205 | n205,t205 206 | n206,t206 207 | n207,t207 208 | n208,t208 209 | n209,t209 210 | n210,t210 211 | n211,t211 212 | n212,t212 213 | n213,t213 214 | n214,t214 215 | n215,t215 216 | n216,t216 217 | n217,t217 218 | n218,t218 219 | n219,t219 220 | 
n220,t220 221 | n221,t221 222 | n222,t222 223 | n223,t223 224 | n224,t224 225 | n225,t225 226 | n226,t226 227 | n227,t227 228 | n228,t228 229 | n229,t229 230 | n230,t230 231 | n231,t231 232 | n232,t232 233 | n233,t233 234 | n234,t234 235 | n235,t235 236 | n236,t236 237 | n237,t237 238 | n238,t238 239 | n239,t239 240 | n240,t240 241 | n241,t241 242 | n242,t242 243 | n243,t243 244 | n244,t244 245 | n245,t245 246 | n246,t246 247 | n247,t247 248 | n248,t248 249 | n249,t249 250 | n250,t250 251 | n251,t251 252 | n252,t252 253 | n253,t253 254 | n254,t254 255 | n255,t255 256 | n256,t256 257 | n257,t257 258 | n258,t258 259 | n259,t259 260 | n260,t260 261 | n261,t261 262 | n262,t262 263 | n263,t263 264 | n264,t264 265 | n265,t265 266 | n266,t266 267 | n267,t267 268 | n268,t268 269 | n269,t269 270 | n270,t270 271 | n271,t271 272 | n272,t272 273 | n273,t273 274 | n274,t274 275 | n275,t275 276 | n276,t276 277 | n277,t277 278 | n278,t278 279 | n279,t279 280 | n280,t280 281 | n281,t281 282 | n282,t282 283 | n283,t283 284 | n284,t284 285 | n285,t285 286 | n286,t286 287 | n287,t287 288 | n288,t288 289 | n289,t289 290 | n290,t290 291 | n291,t291 292 | n292,t292 293 | n293,t293 294 | n294,t294 295 | n295,t295 296 | n296,t296 297 | n297,t297 298 | n298,t298 299 | n299,t299 300 | n300,t300 301 | n301,t301 302 | n302,t302 303 | n303,t303 304 | n304,t304 305 | n305,t305 306 | n306,t306 -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/fe_lbfgs/metadata/tensor_metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "features" : [ { 3 | "name" : "weight", 4 | "dtype" : "float", 5 | "shape" : [ ], 6 | "isSparse" : false 7 | }, { 8 | "name" : "global", 9 | "dtype" : "float", 10 | "shape" : [ 306 ], 11 | "isSparse" : true 12 | }, { 13 | "name" : "uid", 14 | "dtype" : "long", 15 | "shape" : [ ], 16 | "isSparse" : false 17 | } ], 18 | "labels" : [ { 19 | "name" : "response", 20 | "dtype" : "int", 21 | "shape" : [ ], 22 | "isSparse" : false 23 | } ] 24 | } 25 | -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/fe_lbfgs/training_data/test.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/test/resources/fe_lbfgs/training_data/test.tfrecord -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/grouped_per_member_train/data.json: -------------------------------------------------------------------------------- 1 | {"features":[{"name": "per_member", "dtype": "float", "shape": [100], "isSparse": true},{"name": "weight", "dtype": "float", "shape": [], "isSparse": false},{"name": "offset", "dtype": "float", "shape": [], "isSparse": false},{"name": "uid", "dtype": "int", "shape": [], "isSparse": false},{"name": "memberId", "dtype": "int", "shape": [], "isSparse": false}],"labels" :[{"name": "response", "dtype": "int", "shape": [], "isSparse": false}]} -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/grouped_per_member_train/data.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/test/resources/grouped_per_member_train/data.tfrecord 
-------------------------------------------------------------------------------- /gdmix-trainer/test/resources/grouped_per_member_train/data_intercept_only.json: -------------------------------------------------------------------------------- 1 | {"features":[{"name": "weight", "dtype": "float", "shape": [], "isSparse": false},{"name": "offset", "dtype": "float", "shape": [], "isSparse": false},{"name": "uid", "dtype": "int", "shape": [], "isSparse": false},{"name": "memberId", "dtype": "int", "shape": [], "isSparse": false}],"labels" :[{"name": "response", "dtype": "int", "shape": [], "isSparse": false}]} -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/grouped_per_member_train/data_with_string_entity_id.json: -------------------------------------------------------------------------------- 1 | { 2 | "features": [ 3 | { 4 | "dtype": "float", 5 | "isSparse": true, 6 | "name": "per_member", 7 | "shape": [ 8 | 100 9 | ] 10 | }, 11 | { 12 | "dtype": "float", 13 | "isSparse": false, 14 | "name": "weight", 15 | "shape": [] 16 | }, 17 | { 18 | "dtype": "float", 19 | "isSparse": false, 20 | "name": "offset", 21 | "shape": [] 22 | }, 23 | { 24 | "dtype": "int", 25 | "isSparse": false, 26 | "name": "uid", 27 | "shape": [] 28 | }, 29 | { 30 | "dtype": "string", 31 | "isSparse": false, 32 | "name": "memberId", 33 | "shape": [] 34 | } 35 | ], 36 | "labels": [ 37 | { 38 | "dtype": "int", 39 | "isSparse": false, 40 | "name": "response", 41 | "shape": [] 42 | } 43 | ] 44 | } -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/grouped_per_member_train/dataset_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "features": [ 3 | { 4 | "dtype": "float", 5 | "isSparse": true, 6 | "name": "per_member", 7 | "shape": [ 8 | 3 9 | ] 10 | }, 11 | { 12 | "dtype": "float", 13 | "isSparse": false, 14 | "name": "weight", 15 | "shape": [] 16 | }, 17 | { 18 | "dtype": "float", 19 | "isSparse": false, 20 | "name": "offset", 21 | "shape": [] 22 | }, 23 | { 24 | "dtype": "int", 25 | "isSparse": false, 26 | "name": "uid", 27 | "shape": [] 28 | }, 29 | { 30 | "dtype": "string", 31 | "isSparse": false, 32 | "name": "memberId", 33 | "shape": [] 34 | } 35 | ], 36 | "labels": [ 37 | { 38 | "dtype": "int", 39 | "isSparse": false, 40 | "name": "response", 41 | "shape": [] 42 | } 43 | ] 44 | } 45 | -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/grouped_per_member_train/dataset_1_feature_file.csv: -------------------------------------------------------------------------------- 1 | f1,t1 2 | f2,t2 3 | f3,t3 -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/grouped_per_member_train/fake_feature_file.csv: -------------------------------------------------------------------------------- 1 | f1,t1 2 | f2,t2 3 | f3,t3 4 | f4,t4 5 | f5,t5 6 | f6,t6 7 | f7,t7 8 | f8,t8 9 | f9,t9 10 | f10,t10 11 | f11,t11 12 | f12,t12 13 | f13,t13 14 | f14,t14 15 | f15,t15 16 | f16,t16 17 | f17,t17 18 | f18,t18 19 | f19,t19 20 | f20,t20 21 | f21,t21 22 | f22,t22 23 | f23,t23 24 | f24,t24 25 | f25,t25 26 | f26,t26 27 | f27,t27 28 | f28,t28 29 | f29,t29 30 | f30,t30 31 | f31,t31 32 | f32,t32 33 | f33,t33 34 | f34,t34 35 | f35,t35 36 | f36,t36 37 | f37,t37 38 | f38,t38 39 | f39,t39 40 | f40,t40 41 | f41,t41 42 | f42,t42 43 | f43,t43 44 | f44,t44 45 | f45,t45 46 | f46,t46 47 | 
f47,t47 48 | f48,t48 49 | f49,t49 50 | f50,t50 51 | f51,t51 52 | f52,t52 53 | f53,t53 54 | f54,t54 55 | f55,t55 56 | f56,t56 57 | f57,t57 58 | f58,t58 59 | f59,t59 60 | f60,t60 61 | f61,t61 62 | f62,t62 63 | f63,t63 64 | f64,t64 65 | f65,t65 66 | f66,t66 67 | f67,t67 68 | f68,t68 69 | f69,t69 70 | f70,t70 71 | f71,t71 72 | f72,t72 73 | f73,t73 74 | f74,t74 75 | f75,t75 76 | f76,t76 77 | f77,t77 78 | f78,t78 79 | f79,t79 80 | f80,t80 81 | f81,t81 82 | f82,t82 83 | f83,t83 84 | f84,t84 85 | f85,t85 86 | f86,t86 87 | f87,t87 88 | f88,t88 89 | f89,t89 90 | f90,t90 91 | f91,t91 92 | f92,t92 93 | f93,t93 94 | f94,t94 95 | f95,t95 96 | f96,t96 97 | f97,t97 98 | f98,t98 99 | f99,t99 100 | f100,t100 -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/member_ids.avro: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/test/resources/member_ids.avro -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/metadata/duplicated_names.json: -------------------------------------------------------------------------------- 1 | { 2 | "features": [ 3 | { 4 | "name": "weight", 5 | "dtype": "float", 6 | "shape": [], 7 | "isSparse": false 8 | }, 9 | { 10 | "name": "weight", 11 | "dtype": "int", 12 | "shape": [], 13 | "isSparse": false 14 | }, 15 | { 16 | "name": "response", 17 | "dtype": "float", 18 | "shape": [], 19 | "isSparse": false 20 | } 21 | ], 22 | "labels": [ 23 | { 24 | "name": "response", 25 | "dtype": "int", 26 | "shape": [], 27 | "isSparse": false 28 | } 29 | ] 30 | } 31 | -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/metadata/features.txt: -------------------------------------------------------------------------------- 1 | global, 2 | tf_glmix,bias 3 | -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/metadata/invalid_name.json: -------------------------------------------------------------------------------- 1 | { 2 | "features": [ 3 | { 4 | "name": 9, 5 | "dtype": "string", 6 | "shape": [], 7 | "isSparse": false 8 | } 9 | ], 10 | "labels": [ ] 11 | } 12 | -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/metadata/invalid_shape.json: -------------------------------------------------------------------------------- 1 | { 2 | "features": [ 3 | { 4 | "name": "textFeature", 5 | "dtype": "string", 6 | "shape": 7, 7 | "isSparse": false 8 | } 9 | ], 10 | "labels": [ ] 11 | } 12 | -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/metadata/invalid_type.json: -------------------------------------------------------------------------------- 1 | { 2 | "features": [ 3 | { 4 | "name": "textFeature", 5 | "dtype": "text", 6 | "shape": [], 7 | "isSparse": false 8 | } 9 | ], 10 | "labels": [ ] 11 | } 12 | -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/metadata/partition_list.txt: -------------------------------------------------------------------------------- 1 | 0,1,2,3,4,5,6,7,8,9 -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/metadata/tensor_metadata.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "features" : [ 3 | { 4 | "name" : "global", 5 | "dtype" : "float", 6 | "shape" : [ 7 | 2 8 | ], 9 | "numUniqueValues" : null, 10 | "isSparse" : true 11 | }, 12 | { 13 | "name" : "weight", 14 | "dtype" : "float", 15 | "shape" : [ 16 | ], 17 | "numUniqueValues" : null, 18 | "isSparse" : false 19 | }, 20 | { 21 | "name" : "uid", 22 | "dtype" : "long", 23 | "shape" : [ 24 | ], 25 | "numUniqueValues" : 9, 26 | "isSparse" : false 27 | } 28 | ], 29 | "labels" : [ 30 | { 31 | "name" : "label", 32 | "dtype" : "int", 33 | "shape" : [ 34 | ], 35 | "numUniqueValues" : 1, 36 | "isSparse" : false 37 | } 38 | ] 39 | } -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/metadata/valid.json: -------------------------------------------------------------------------------- 1 | { 2 | "features": [ 3 | { 4 | "name": "weight", 5 | "dtype": "float", 6 | "shape": [], 7 | "isSparse": false 8 | }, 9 | { 10 | "name": "profileImage", 11 | "parserKey": "profileImage", 12 | "dtype": "bytes", 13 | "shape": [], 14 | "isSparse": false 15 | }, 16 | { 17 | "name": "stringFixedLenSeq", 18 | "parserKey": "stringFixedLenSeq[*]", 19 | "dtype": "string", 20 | "shape": [ 21 | 4 22 | ], 23 | "isSparse": false, 24 | "numUniqueValues": 123, 25 | "defaultValue": "xxx" 26 | }, 27 | { 28 | "name": "memberJobCrossFeatures", 29 | "parserKey": "memberJobCrossFeatures[*].indices", 30 | "dtype": "long", 31 | "shape": [ 32 | -1 33 | ], 34 | "numUniqueValues": 1964, 35 | "isSparse": false, 36 | "isDocumentFeature": true 37 | }, 38 | { 39 | "name": "intSparseSeq", 40 | "parserKey": "intSparseSeq[*].id", 41 | "dtype": "int", 42 | "shape": [ 43 | -1 44 | ], 45 | "isSparse": false, 46 | "isDocumentFeature": false 47 | }, 48 | { 49 | "name": "imageVarLenSeq", 50 | "parserKey": "imageVarLenSeq[*]", 51 | "dtype": "bytes", 52 | "shape": [ 53 | -1 54 | ], 55 | "isSparse": false 56 | }, 57 | { 58 | "name": "sparseFeature", 59 | "parserKey": "sparseFeature", 60 | "dtype": "float", 61 | "shape": [ 62 | 1001 63 | ], 64 | "isSparse": true 65 | } 66 | ], 67 | "labels": [ 68 | { 69 | "isSparse": false, 70 | "name": "response", 71 | "parserKey": "response", 72 | "dtype": "int", 73 | "shape": [] 74 | } 75 | ] 76 | } 77 | -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/metadata/valid_metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "features": [ 3 | { 4 | "name": "weight", 5 | "dtype": "float", 6 | "shape": [234], 7 | "isSparse": true 8 | }, 9 | { 10 | "name": "f1", 11 | "dtype": "float", 12 | "shape": [3], 13 | "isSparse": false 14 | } 15 | ], 16 | "labels": [ 17 | { 18 | "name": "response", 19 | "dtype": "int", 20 | "shape": [], 21 | "isSparse": false 22 | } 23 | ] 24 | } 25 | -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/train/dataset/tfrecord/test.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/test/resources/train/dataset/tfrecord/test.tfrecord -------------------------------------------------------------------------------- /gdmix-trainer/test/resources/validate/data.avro: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-trainer/test/resources/validate/data.avro -------------------------------------------------------------------------------- /gdmix-trainer/test/util/test_distribution_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import tensorflow as tf 4 | 5 | from gdmix.util.distribution_utils import shard_input_files 6 | 7 | 8 | class TestDistributionUtils(tf.test.TestCase): 9 | """ 10 | Test distribution utils 11 | """ 12 | 13 | def setUp(self): 14 | self._base_dir = tempfile.mkdtemp() 15 | for i in range(10): 16 | with open(os.path.join(self._base_dir, f'{i}.avro'), 'w') as f: 17 | f.write("test") 18 | for i in range(10): 19 | with open(os.path.join(self._base_dir, f'{i}.tfrecord'), 'w') as f: 20 | f.write("test") 21 | 22 | def tearDown(self): 23 | tf.io.gfile.rmtree(self._base_dir) 24 | 25 | def test_shard_input_files_with_wrong_params(self): 26 | with self.assertRaises(AssertionError): 27 | shard_input_files(self._base_dir, 1, 2) 28 | with self.assertRaises(AssertionError): 29 | shard_input_files(self._base_dir, -1, -2) 30 | with self.assertRaises(tf.errors.NotFoundError): 31 | shard_input_files(os.path.join(self._base_dir, "nowhere/nofile"), 3, 2) 32 | 33 | def test_shard_input_files_with_directory(self): 34 | shard_files, _ = shard_input_files(self._base_dir, 2, 0) 35 | expected_files = [os.path.join(self._base_dir, f'{i}.avro') for i in range(10)] 36 | self.assertAllEqual(shard_files, expected_files) 37 | 38 | def test_shard_input_file_with_filename_pattern(self): 39 | input_file_pattern = os.path.join(self._base_dir, "*.tfrecord") 40 | shard_files, indicator = shard_input_files(input_file_pattern, 3, 1) 41 | expected_files = [os.path.join(self._base_dir, f'{i}.tfrecord') for i in range(1, 10, 3)] 42 | self.assertAllEqual(shard_files, expected_files) 43 | self.assertFalse(indicator) 44 | 45 | def test_shard_input_file_with_more_shards(self): 46 | input_file_pattern = os.path.join(self._base_dir, "*.tfrecord") 47 | shard_files, indicator = shard_input_files(input_file_pattern, 20, 1) 48 | expected_files = [os.path.join(self._base_dir, '1.tfrecord')] 49 | self.assertAllEqual(shard_files, expected_files) 50 | self.assertTrue(indicator) 51 | shard_files, indicator = shard_input_files(input_file_pattern, 20, 19) 52 | self.assertEqual(len(shard_files), 0) 53 | self.assertTrue(indicator) 54 | -------------------------------------------------------------------------------- /gdmix-trainer/test/util/test_model_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from gdmix.util.model_utils import threshold_coefficients 5 | 6 | 7 | class TestModelUtils(tf.test.TestCase): 8 | """ 9 | Test Model Utils 10 | """ 11 | def testThresholdCoefficients(self): 12 | coefficients = np.array([1e-5, -1e-4, -0.1, -1e-5, 2.2, 3.3]) 13 | expected_coefficients = np.array([0.0, 0.0, -0.1, 0.0, 2.2, 3.3]) 14 | actual_coefficients = threshold_coefficients(coefficients, 1e-4) 15 | self.assertAllEqual(actual_coefficients, expected_coefficients) 16 | -------------------------------------------------------------------------------- /gdmix-workflow/examples/movielens-100k/detext-movieLens.yaml: -------------------------------------------------------------------------------- 1 | ../../test/resources/detext-movieLens.yaml 
-------------------------------------------------------------------------------- /gdmix-workflow/examples/movielens-100k/lr-movieLens.yaml: -------------------------------------------------------------------------------- 1 | ../../test/resources/lr-movieLens.yaml -------------------------------------------------------------------------------- /gdmix-workflow/gdmix_config.md: -------------------------------------------------------------------------------- 1 | 2 | # GDMix Configs 3 | To train fixed effect and random effect models using GDMix, users need to provide a GDMix config, which consists of configs for 4 | fixed-effect and random-effect models. For distributed training, computing resource configs for TensorFlow and Spark jobs are needed. 5 | 6 | GDMix config examples for movieLens with a fixed-effect `global` model and two random-effect `per-user` and `per-movie` models are available in directory `examples/movielens-100k`: 7 | - [lr-movieLens.yaml](examples/movielens-100k/lr-movieLens.yaml): train logistic regression models for the `global`, `per-user` and `per-movie` models; `spark_config` and `tfjob_config` sections are resources used for distributed training only. 8 | - [detext-movieLens.yaml](examples/movielens-100k/detext-movieLens.yaml): train a deep and wide neural network model for the `global` model and logistic regression models for `per-user` and `per-movie`; `spark_config` and `tfjob_config` sections are resources used for distributed training only. 9 | 10 | ## Logistic regression models 11 | ### Fixed-effect config 12 | Required fields: 13 | - **name**: name of the model. String. 14 | - **training_data_dir**: path to training data directory. String. 15 | - **uid_column_name**: unique id column name in the train/validation data. 16 | - **metadata_file**: path to an input data tensor metadata file. String. 17 | - **feature_file**: path to a feature list file for outputting the model in name-term-value format. String. 18 | - **model_type**: the model type to train, e.g., logistic regression, linear regression, detext, etc. 19 | - **output_model_dir**: model output directory. 20 | 21 | Optional fields: 22 | - **validation_data_dir**: path to validation data directory. String, default is "". 23 | - **regularize_bias**: whether to regularize the intercept. Usually we do not regularize the intercept since it is an important feature. Boolean, default is false. 24 | - **l2_reg_weight**: weight of L2 regularization for each feature bag. Float, default is 0.001. 25 | - **optimizer**: optimizer used in the training; currently only LBFGS is supported. 26 | - **metric**: metric of the model. String, supports "auc" and "mse". Default is "auc". 27 | - **copy_to_local**: whether to copy training data to the local disk. Boolean, default is true. 28 | - **label_column_name**: label column name in the train/validation data. 29 | - **weight_column_name**: weight column name in the train/validation data. 30 | - **prediction_score_column_name**: prediction score column name in the generated result file. 31 | - **feature_bag**: feature bag name that is used for training and scoring. 32 | 33 | ### Random-effect config 34 | Required fields include all fields from fixed-effect config plus: 35 | - **partition_entity**: the column name used to partition data in order to improve random-effect model training parallelism. String. 36 | - **num_partitions**: number of partitions. Integer.
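As a quick illustration, the random-effect sections of a loaded config can be inspected programmatically with the repo's own `yaml_config_file_to_obj` helper from `gdmixworkflow.common.utils`. A minimal sketch (the config path below is illustrative; any config following this format works):

```python
# Minimal sketch: read the required random-effect fields of a GDMix config.
# The config path is illustrative; any config following this format works.
from gdmixworkflow.common.utils import yaml_config_file_to_obj

config = yaml_config_file_to_obj("examples/movielens-100k/lr-movieLens.yaml")
for name, re_config in config.random_effect_config.items():
    # partition_entity and num_partitions are the two extra required fields.
    print(name, re_config["partition_entity"], re_config["num_partitions"])
```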
37 | 38 | Optional fields include all fields from fixed-effect config plus: 39 | - **max_training_queue_size**: maximum size of the training queue in the producer/consumer model. The trainer is implemented in a producer/consumer model: the producer reads data from the hard drive, then the consumers solve the optimization problem for each entity. The blocking queue synchronizes both sides. Integer, default is 10. 40 | - **num_of_consumers**: the number of consumers (processes that optimize the models). This specifies the parallelism inside a trainer. Integer, default is 2. 41 | - **enable_local_indexing**: whether to enable local indexing. Some datasets have a large global feature space but a small per-entity feature space. For example, the total number of features in a dataset could be on the order of millions, while each member has only hundreds of features. Re-indexing the features locally reduces the memory footprint and increases training efficiency. Boolean, default is true. 42 | 43 | ## Neural network model supported by DeText for fixed-effect 44 | ### Fixed-effect config 45 | Please refer to the DeText training manual [TRAINING.md](https://github.com/linkedin/detext/blob/master/TRAINING.md) for the available parameters. The config is a collection of key/value pairs; a DeText config example for movieLens data can be found at [detext-movieLens.yaml](examples/movielens-100k/detext-movieLens.yaml). -------------------------------------------------------------------------------- /gdmix-workflow/images/gdmix_dev/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.7 2 | 3 | # Install spark 2.4 4 | ARG spark_version=2.4.8 5 | ARG spark_pkg=spark-${spark_version}-bin-hadoop2.7 6 | 7 | RUN apt-get update 8 | RUN apt-get install software-properties-common -y 9 | RUN apt-add-repository 'deb http://security.debian.org/debian-security stretch/updates main' 10 | RUN apt-get update && apt-get install openjdk-8-jdk git -y 11 | RUN mkdir -p /opt/spark 12 | RUN wget https://downloads.apache.org/spark/spark-${spark_version}/${spark_pkg}.tgz && tar -xf ${spark_pkg}.tgz && \ 13 | mv ${spark_pkg}/jars /opt/spark && \ 14 | mv ${spark_pkg}/bin /opt/spark && \ 15 | mv ${spark_pkg}/sbin /opt/spark && \ 16 | mv ${spark_pkg}/examples /opt/spark && \ 17 | mv ${spark_pkg}/data /opt/spark && \ 18 | mv ${spark_pkg}/kubernetes/tests /opt/spark && \ 19 | mv ${spark_pkg}/kubernetes/dockerfiles/spark/entrypoint.sh /opt/ && \ 20 | mkdir -p /opt/spark/conf && \ 21 | cp ${spark_pkg}/conf/log4j.properties.template /opt/spark/conf/log4j.properties && \ 22 | sed -i 's/INFO/ERROR/g' /opt/spark/conf/log4j.properties && \ 23 | chmod +x /opt/*.sh && \ 24 | rm -rf spark-* 25 | 26 | ENV SPARK_HOME=/opt/spark 27 | ENV PATH=/opt/spark/bin:$PATH 28 | ENV SPARK_CLASSPATH=$SPARK_CLASSPATH:/opt/spark/jars/ 29 | 30 | RUN rm -rf ~/.gradle/caches/* ~/.cache/pip/* 31 | 32 | -------------------------------------------------------------------------------- /gdmix-workflow/images/gdmix_dev/build_image.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | REGISTRY=linkedin 4 | IMAGE_NAME=gdmix-dev 5 | VERSION_TAG=0.4.0 6 | VERSIONED_IMAGE_NAME=${REGISTRY}/${IMAGE_NAME}:${VERSION_TAG} 7 | 8 | echo "Building image ${VERSIONED_IMAGE_NAME}" 9 | docker build -t ${VERSIONED_IMAGE_NAME} .
10 | 11 | # TODO: uncomment to push to docker hub 12 | # docker push ${VERSIONED_IMAGE_NAME} 13 | 14 | rm -rf *.config *.py 15 | -------------------------------------------------------------------------------- /gdmix-workflow/images/launcher/sparkapplication/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM alpine:3 2 | 3 | RUN apk update && \ 4 | apk add ca-certificates python-dev py-setuptools wget && \ 5 | easy_install-2.7 pip && \ 6 | pip install pyyaml==3.12 kubernetes 7 | 8 | ADD launcher /launcher 9 | 10 | ENTRYPOINT ["python", "/launcher/launch_sparkapplication.py"] 11 | -------------------------------------------------------------------------------- /gdmix-workflow/images/launcher/sparkapplication/build_image.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | rm -rf ./launcher && mkdir -p ./launcher 4 | rsync -arvp src/ ./launcher/ 5 | rsync -arvp ../common/ ./launcher/ 6 | 7 | # TODO: change to LinkedIn dockerhub endpoint and push 8 | REGISTRY=linkedin 9 | IMAGE_NAME=sparkapplication-launcher 10 | VERSION_TAG=0.1 11 | VERSIONED_IMAGE_NAME=${IMAGE_NAME}:${VERSION_TAG} 12 | REMOTE_IMAGE_NAME=${REGISTRY}/${VERSIONED_IMAGE_NAME} 13 | 14 | docker build -t ${VERSIONED_IMAGE_NAME} . 15 | 16 | # TODO: tag and push to dockerhub 17 | # docker tag ${VERSIONED_IMAGE_NAME} ${REMOTE_IMAGE_NAME} 18 | # docker push ${REMOTE_IMAGE_NAME} 19 | 20 | rm -rf ./launcher 21 | -------------------------------------------------------------------------------- /gdmix-workflow/images/launcher/tfjob/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM alpine:3 2 | 3 | RUN apk update && \ 4 | apk add ca-certificates python-dev py-setuptools wget && \ 5 | easy_install-2.7 pip && \ 6 | pip install pyyaml==3.12 kubernetes 7 | 8 | ADD launcher /launcher 9 | 10 | ENTRYPOINT ["python", "/launcher/launch_tfjob.py"] 11 | -------------------------------------------------------------------------------- /gdmix-workflow/images/launcher/tfjob/build_image.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | rm -rf ./launcher && mkdir -p ./launcher 4 | rsync -arvp src/ ./launcher/ 5 | rsync -arvp ../common/ ./launcher/ 6 | 7 | REGISTRY=linkedin 8 | IMAGE_NAME=tfjob-launcher 9 | VERSION_TAG=0.1 10 | VERSIONED_IMAGE_NAME=${REGISTRY}/${IMAGE_NAME}:${VERSION_TAG} 11 | 12 | echo "Building image ${VERSIONED_IMAGE_NAME}" 13 | docker build -t ${VERSIONED_IMAGE_NAME} .
14 | 15 | # TODO: uncomment to push to dockerhub 16 | # docker push ${VERSIONED_IMAGE_NAME} 17 | 18 | rm -rf ./launcher 19 | -------------------------------------------------------------------------------- /gdmix-workflow/setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 160 3 | 4 | [tool:pytest] 5 | addopts = --ignore build/ --ignore dist/ --junitxml TEST-pytest.xml 6 | 7 | [aliases] 8 | test=pytest 9 | -------------------------------------------------------------------------------- /gdmix-workflow/setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from setuptools import find_namespace_packages, setup 3 | from sys import platform as _platform 4 | 5 | import sys 6 | 7 | VERSION = "0.6.0" 8 | current_dir = Path(__file__).resolve().parent 9 | with open(current_dir.joinpath('README.md'), encoding='utf-8') as f: 10 | long_description = f.read() 11 | 12 | if _platform not in ["linux", "linux2", "darwin"]: 13 | print(f"ERROR: platform {_platform} isn't supported") 14 | sys.exit(1) 15 | 16 | setup( 17 | name="gdmix-workflow", 18 | python_requires='>=3.7', 19 | long_description=long_description, 20 | long_description_content_type='text/markdown', 21 | classifiers=["Programming Language :: Python :: 3.7", 22 | "Intended Audience :: Science/Research", 23 | "Intended Audience :: Developers", 24 | "License :: OSI Approved"], 25 | license='BSD-2-CLAUSE', 26 | version=VERSION, 27 | package_dir={'': 'src'}, 28 | packages=find_namespace_packages(where='src'), 29 | package_data={'': ['*.yaml']}, 30 | include_package_data=True, 31 | install_requires=[ 32 | "setuptools>=41.0.0", 33 | "gdmix-trainer>=0.5.0", 34 | "kfp==0.2.5" 35 | ], 36 | tests_require=['pytest'] 37 | ) 38 | -------------------------------------------------------------------------------- /gdmix-workflow/src/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-workflow/src/conftest.py -------------------------------------------------------------------------------- /gdmix-workflow/src/gdmixworkflow/__init__.py: -------------------------------------------------------------------------------- 1 | # DO NOT COPY & PASTE THIS CODE!!!!! 2 | # 3 | # This is a special file only needed for "src/gdmixworkflow/__init__.py" 4 | # to declare the "gdmixworkflow" package as a "namespace" 5 | # 6 | # All other "__init__.py" files can just be blank, or contain normal Python 7 | # module code. 
8 | try: 9 | __import__('pkg_resources').declare_namespace(__name__) 10 | except ImportError: 11 | from pkgutil import extend_path 12 | __path__ = extend_path(__path__, __name__) 13 | -------------------------------------------------------------------------------- /gdmix-workflow/src/gdmixworkflow/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-workflow/src/gdmixworkflow/common/__init__.py -------------------------------------------------------------------------------- /gdmix-workflow/src/gdmixworkflow/common/constants.py: -------------------------------------------------------------------------------- 1 | from gdmix.util.constants import * 2 | 3 | DISTRIBUTED = "distributed" 4 | FIXED_EFFECT_CONFIG = "fixed_effect_config" 5 | GDMIX_TFJOB = "gdmix_tfjob" 6 | GDMIX_SPARKJOB = "gdmix_sparkjob" 7 | METRIC = "metric" 8 | MODELS = "models" 9 | RANDOM_EFFECT_CONFIG = "random_effect_config" 10 | SINGLE_NODE = "single_node" 11 | TRAINING_SCORES = "training_scores" 12 | VALIDATION_SCORES = "validation_scores" 13 | DETEXT_MODEL_OUTPUT_DIR = "out_dir" 14 | DETEXT_DEV_FILE = "dev_file" 15 | DETEXT_TRAIN_FILE = "train_file" 16 | -------------------------------------------------------------------------------- /gdmix-workflow/src/gdmixworkflow/common/utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import yaml 4 | 5 | 6 | def gen_random_string(length=6): 7 | """ Generate fixed-length random string. """ 8 | import random 9 | import string 10 | letters = string.ascii_lowercase 11 | return ''.join(random.choice(letters) for _ in range(length)) 12 | 13 | 14 | def abbr(name): 15 | """ Return abbreviation of a given name. 16 | Example: 17 | fixed-effect -> f10t 18 | per-member -> p8r 19 | """ 20 | return name if len(name) <= 2 else f"{name[0]}{len(name) - 2}{name[-1]}" 21 | 22 | 23 | def rm_backslash(params): 24 | """ A '-' at the beginning of a line is a special character in YAML, 25 | so it is escaped with a backslash; remove the added backslash for local runs. 26 | """ 27 | return {k.strip('\\'): v for k, v in params.items()} 28 | 29 | 30 | def yaml_config_file_to_obj(config_file): 31 | """ Load a gdmix config from a yaml file into an object.
""" 32 | def _yaml_object_hook(d): 33 | return namedtuple('GDMIX_CONFIG', d.keys())(*d.values()) 34 | 35 | with open(config_file) as f: 36 | config_obj = _yaml_object_hook(yaml.safe_load(f)) 37 | return config_obj 38 | -------------------------------------------------------------------------------- /gdmix-workflow/src/gdmixworkflow/distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-workflow/src/gdmixworkflow/distributed/__init__.py -------------------------------------------------------------------------------- /gdmix-workflow/src/gdmixworkflow/distributed/resource/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-workflow/src/gdmixworkflow/distributed/resource/__init__.py -------------------------------------------------------------------------------- /gdmix-workflow/src/gdmixworkflow/distributed/resource/sparkapplication_component.yaml: -------------------------------------------------------------------------------- 1 | name: SparkApplication-launcher-name 2 | description: SparkApplication launcher name 3 | inputs: 4 | - {name: Name, type: String, description: 'SparkApplication name.'} 5 | - {name: Namespace, type: String, default: k8s-spark, description: 'SparkApplication namespace.'} 6 | - {name: Version, type: String, default: v1beta2, description: 'SparkApplication version.'} 7 | - {name: Restart Policy, type: String, default: Never, description: 'Defines the policy when the SparkApplication fails.'} 8 | - {name: Image, type: String, default: '', description: 'spark image'} 9 | - {name: Main Class, type: String, default: '', description: 'spark job main class'} 10 | - {name: Arguments, type: String, default: '', description: 'spark job arguments'} 11 | - {name: Main Application File, type: String, default: '', description: 'spark job main file'} 12 | - {name: Spark Version, type: String, default: '2.4.5-SNAPSHOT', description: 'spark version'} 13 | - {name: Driver Spec, type: JSON, default: '{}', description: 'SparkApplication driver spec.'} 14 | - {name: Executor Spec, type: JSON, default: '{}', description: 'SparkApplication executor spec.'} 15 | - {name: SparkApplication Timeout Minutes, type: Integer, default: 1440, description: 'Time in minutes to wait for the spark application to complete.'} 16 | - {name: Delete Finished SparkApplication, type: Bool, default: 'True' , description: 'Whether to delete the spark application after it is finished.'} 17 | implementation: 18 | container: 19 | image: linkedin/sparkapplication-launcher 20 | command: [python, /launcher/launch_sparkapplication.py] 21 | args: [ 22 | --name, {inputValue: Name}, 23 | --namespace, {inputValue: Namespace}, 24 | --version, {inputValue: Version}, 25 | --restartPolicy, {inputValue: Restart Policy}, 26 | --image, {inputValue: Image}, 27 | --mainClass, {inputValue: Main Class}, 28 | --arguments, {inputValue: Arguments}, 29 | --mainApplicationFile, {inputValue: Main Application File}, 30 | --sparkVersion, {inputValue: Spark Version}, 31 | --driverSpec, {inputValue: Driver Spec}, 32 | --executorSpec, {inputValue: Executor Spec}, 33 | --sparkApplicationTimeoutMinutes, {inputValue: SparkApplication Timeout Minutes}, 34 | --deleteAfterDone, {inputValue: Delete Finished SparkApplication}, 35 | ] 36 | 
-------------------------------------------------------------------------------- /gdmix-workflow/src/gdmixworkflow/distributed/resource/tfjob_component.yaml: -------------------------------------------------------------------------------- 1 | name: TFJob-launcher-name 2 | description: TFJob launcher name 3 | inputs: 4 | - {name: Name, type: String, description: 'TFJob name.'} 5 | - {name: Namespace, type: String, default: kubeflow, description: 'TFJob namespace.'} 6 | - {name: Version, type: String, default: v1, description: 'TFJob version.'} 7 | - {name: ActiveDeadlineSeconds, type: Integer, default: -1, description: 'Specifies the duration (in seconds) since startTime during which the job can remain active before it is terminated. Must be a positive integer. This setting applies only to pods where restartPolicy is OnFailure or Always.'} 8 | - {name: BackoffLimit, type: Integer, default: 5, description: 'Number of retries before marking this job as failed.'} 9 | - {name: ttl Seconds After Finished,type: Integer, default: -1, description: 'Defines the TTL for cleaning up finished TFJobs.'} 10 | - {name: CleanPodPolicy, type: String, default: Running, description: 'Defines the policy for cleaning up pods after the TFJob completes.'} 11 | - {name: PS Spec, type: JSON, default: '{}', description: 'TFJob ps replicaSpecs.'} 12 | - {name: Worker Spec, type: JSON, default: '{}', description: 'TFJob worker replicaSpecs.'} 13 | - {name: Chief Spec, type: JSON, default: '{}', description: 'TFJob chief replicaSpecs.'} 14 | - {name: Evaluator Spec, type: JSON, default: '{}', description: 'TFJob evaluator replicaSpecs.'} 15 | - {name: Tfjob Timeout Minutes, type: Integer, default: 1440, description: 'Time in minutes to wait for the TFJob to complete.'} 16 | - {name: Delete Finished Tfjob, type: Bool, default: 'True' , description: 'Whether to delete the tfjob after it is finished.'} 17 | implementation: 18 | container: 19 | image: linkedin/tfjob-launcher 20 | command: [python, /launcher/launch_tfjob.py] 21 | args: [ 22 | --name, {inputValue: Name}, 23 | --namespace, {inputValue: Namespace}, 24 | --version, {inputValue: Version}, 25 | --activeDeadlineSeconds, {inputValue: ActiveDeadlineSeconds}, 26 | --backoffLimit, {inputValue: BackoffLimit}, 27 | --cleanPodPolicy, {inputValue: CleanPodPolicy}, 28 | --ttlSecondsAfterFinished, {inputValue: ttl Seconds After Finished}, 29 | --psSpec, {inputValue: PS Spec}, 30 | --workerSpec, {inputValue: Worker Spec}, 31 | --chiefSpec, {inputValue: Chief Spec}, 32 | --evaluatorSpec, {inputValue: Evaluator Spec}, 33 | --tfjobTimeoutMinutes, {inputValue: Tfjob Timeout Minutes}, 34 | --deleteAfterDone, {inputValue: Delete Finished Tfjob}, 35 | ] 36 | -------------------------------------------------------------------------------- /gdmix-workflow/src/gdmixworkflow/distributed_workflow.py: -------------------------------------------------------------------------------- 1 | from gdmixworkflow.common.utils import yaml_config_file_to_obj, gen_random_string 2 | from gdmixworkflow.common.constants import * 3 | from gdmixworkflow.distributed.container_ops import no_op 4 | from gdmixworkflow.fixed_effect_workflow_generator \ 5 | import FixedEffectWorkflowGenerator 6 | from gdmixworkflow.random_effect_workflow_generator \ 7 | import RandomEffectWorkflowGenerator 8 | import kfp.dsl as dsl 9 | 10 | 11 | @dsl.pipeline() 12 | def gdmix_distributed_workflow(gdmix_config_file, namespace, secret_name, image, service_account): 13 | """ Generate gdmix kubeflow pipeline using Kubeflow pipeline 
Python DSL (kfp.dsl). 14 | """ 15 | 16 | gdmix_config_obj = yaml_config_file_to_obj(gdmix_config_file) 17 | 18 | current_op = no_op("GDMix-training-start") 19 | suffix = gen_random_string() 20 | 21 | if not hasattr(gdmix_config_obj, FIXED_EFFECT_CONFIG): 22 | raise ValueError(f"Need to define {FIXED_EFFECT_CONFIG}") 23 | fe_tip_op = no_op("fixed-effect-training-start") 24 | fe_tip_op.after(current_op) 25 | fe_workflow = FixedEffectWorkflowGenerator(gdmix_config_obj, 26 | namespace=namespace, 27 | secret_name=secret_name, 28 | image=image, 29 | service_account=service_account, 30 | job_suffix=suffix) 31 | fe_start_op, current_op = fe_workflow.gen_workflow() 32 | fe_start_op.after(fe_tip_op) 33 | 34 | if hasattr(gdmix_config_obj, RANDOM_EFFECT_CONFIG): 35 | re_tip_op = no_op("random-effect-training-start") 36 | re_tip_op.after(current_op) 37 | re_workflow = RandomEffectWorkflowGenerator(gdmix_config_obj, namespace=namespace, secret_name=secret_name, image=image, 38 | service_account=service_account, job_suffix=suffix, prev_model_name=fe_workflow.fixed_effect_name) 39 | re_start_op, _ = re_workflow.gen_workflow() 40 | re_start_op.after(re_tip_op) 41 | -------------------------------------------------------------------------------- /gdmix-workflow/src/gdmixworkflow/fixed_effect_workflow_generator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import replace 2 | from os.path import join as path_join 3 | 4 | from detext.run_detext import DetextArg 5 | from gdmix.models.custom.fixed_effect_lr_lbfgs_model import FixedLRParams 6 | from gdmix.params import Params 7 | 8 | from gdmixworkflow.common.constants import * 9 | from gdmixworkflow.workflow_generator import WorkflowGenerator 10 | 11 | 12 | class FixedEffectWorkflowGenerator(WorkflowGenerator): 13 | """ Generate gdmix fixed effect workflow consisting of 14 | - tfjob: train the model and run inference on training and validation data 15 | - sparkjob: compute-metric 16 | """ 17 | 18 | def __init__(self, gdmix_config_obj, jar_path="", namespace="", 19 | secret_name="", image="", service_account="", job_suffix=""): 20 | """ Init to generate gdmix fixed effect workflow. """ 21 | super().__init__(gdmix_config_obj, jar_path, namespace, secret_name, image, service_account, job_suffix) 22 | self.fixed_effect_name, self.fixed_effect_config = tuple(self.gdmix_config_obj.fixed_effect_config.items())[0] 23 | self.output_dir = path_join(gdmix_config_obj.output_dir, self.fixed_effect_name) 24 | self.output_model_dir = path_join(self.output_dir, MODELS) 25 | self.validation_score_dir = path_join(self.output_dir, VALIDATION_SCORES) 26 | 27 | # Validate gdmix params 28 | self.gdmix_params: Params = Params(**self.fixed_effect_config.pop('gdmix_config'), 29 | training_score_dir=path_join(self.output_dir, TRAINING_SCORES), 30 | validation_score_dir=self.validation_score_dir) 31 | 32 | self.model_type = self.gdmix_params.model_type 33 | 34 | def get_train_job(self): 35 | """ Get tfjob training job.
36 | :return (job_type, job_name, "", job_params) where job_params is a dict of params 37 | """ 38 | model_param_dict = self.fixed_effect_config 39 | if self.model_type == LOGISTIC_REGRESSION: 40 | model_param_dict["output_model_dir"] = self.output_model_dir 41 | elif self.model_type == DETEXT: 42 | # smart-arg's serialization for parameters doesn't support NoneType (from default values), so use the original params 43 | model_param_dict["out_dir"] = self.output_model_dir 44 | else: 45 | raise ValueError(f'unsupported model_type: {self.model_type}') 46 | return GDMIX_TFJOB, f"{self.fixed_effect_name}-tf-train", "", (self.gdmix_params.__dict__, model_param_dict) 47 | 48 | def get_detext_inference_job(self): 49 | """ Get detext inference job. For the LR model, inference is already included in the train 50 | job; this job handles inference for the DeText model. 51 | Return: an inference job that scores training and validation data 52 | (job_type, job_name, "", job_params) 53 | """ 54 | updated_gdmix_params = replace(self.gdmix_params, action=ACTION_INFERENCE) 55 | model_param_dict = self.fixed_effect_config 56 | model_param_dict["out_dir"] = self.output_model_dir 57 | return GDMIX_TFJOB, f"{self.fixed_effect_name}-tf-inference", "", (updated_gdmix_params.__dict__, model_param_dict) 58 | 59 | def get_compute_metric_job(self): 60 | """ Get sparkjob compute metric job. 61 | Return: (job_type, job_name, class_name, job_params) 62 | """ 63 | params = { 64 | r"\--metricsInputDir": self.validation_score_dir, 65 | "--outputMetricFile": path_join(self.output_dir, METRIC), 66 | "--labelColumnName": self.gdmix_params.label_column_name, 67 | "--metricName": "auc", 68 | "--predictionColumnName": self.gdmix_params.prediction_score_column_name 69 | } 70 | return (GDMIX_SPARKJOB, 71 | f"{self.fixed_effect_name}-compute-metric", 72 | "com.linkedin.gdmix.evaluation.Evaluator", 73 | params) 74 | 75 | def get_job_sequence(self): 76 | """ Get job sequence of fixed effect workflow. """ 77 | if self.model_type == LOGISTIC_REGRESSION: 78 | jobs = [self.get_train_job(), self.get_compute_metric_job()] 79 | elif self.model_type == DETEXT: 80 | jobs = [self.get_train_job(), 81 | self.get_detext_inference_job(), 82 | self.get_compute_metric_job()] 83 | else: 84 | raise ValueError(f'unsupported model_type: {self.model_type}') 85 | return jobs 86 | -------------------------------------------------------------------------------- /gdmix-workflow/src/gdmixworkflow/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from functools import partial, update_wrapper 3 | from typing import NamedTuple 4 | 5 | from smart_arg import arg_suite 6 | 7 | from gdmixworkflow.common.constants import SINGLE_NODE, DISTRIBUTED 8 | from gdmixworkflow.distributed_workflow import gdmix_distributed_workflow 9 | from gdmixworkflow.single_node_workflow import run_gdmix_single_node 10 | 11 | 12 | @arg_suite 13 | class FlowArgs(NamedTuple): 14 | """ Creates gdmix workflow.
""" 15 | config_path: str # path to gdmix config 16 | mode: str = SINGLE_NODE # distributed or single_node 17 | jar_path: str = "gdmix-data-all_2.11.jar" # local path to the gdmix-data jar for GDMix processing intermediate data, single_node only 18 | workflow_name: str = "gdmix-workflow" # name for the generated zip file to upload to Kubeflow Pipeline, distributed mode only 19 | namespace: str = "default" # Kubernetes namespace, distributed mode only 20 | secret_name: str = "default" # secret name to access storage, distributed mode only 21 | image: str = "linkedin/gdmix" # image used to launch gdmix jobs on Kubernetes, distributed mode only 22 | service_account: str = "default" # service account to launch spark job, distributed mode only 23 | 24 | 25 | def main(): 26 | args: FlowArgs = FlowArgs.__from_argv__() 27 | 28 | if args.mode == SINGLE_NODE: 29 | try: 30 | output_dir = run_gdmix_single_node(args.config_path, args.jar_path) 31 | except RuntimeError as err: 32 | print(str(err)) 33 | sys.exit(1) 34 | 35 | print(f""" 36 | ------------------------ 37 | GDMix training is finished, results are saved to {output_dir}. 38 | """) 39 | 40 | elif args.mode == DISTRIBUTED: 41 | if not args.namespace: 42 | print("ERROR: --namespace is required for distributed mode") 43 | sys.exit(1) 44 | 45 | wrapper = partial( 46 | gdmix_distributed_workflow, 47 | args.config_path, 48 | args.namespace, 49 | args.secret_name, 50 | args.image, 51 | args.service_account) 52 | update_wrapper(wrapper, gdmix_distributed_workflow) 53 | 54 | output_file_name = args.workflow_name + ".zip" 55 | 56 | import kfp.compiler as compiler 57 | compiler.Compiler().compile(wrapper, output_file_name) 58 | print(f"Workflow file is saved to {output_file_name}") 59 | 60 | else: 61 | print(f"ERROR: --mode={args.mode} isn't supported.") 62 | sys.exit(1) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /gdmix-workflow/src/gdmixworkflow/single_node/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gdmix-workflow/src/gdmixworkflow/single_node/__init__.py -------------------------------------------------------------------------------- /gdmix-workflow/src/gdmixworkflow/single_node/local_ops.py: -------------------------------------------------------------------------------- 1 | from subprocess import Popen, PIPE 2 | 3 | 4 | def get_param_list(params): 5 | """ transform params from dict to list. 
6 | """ 7 | if isinstance(params, dict): 8 | for k, v in params.items(): 9 | yield str(k) 10 | yield str(v) 11 | else: 12 | raise ValueError("job params can only be dict") 13 | 14 | 15 | def get_tfjob_cmd(params): 16 | """ get tfjob command for local execution 17 | """ 18 | cmd = ['python', '-m', 'gdmix.gdmix'] 19 | for param in params: 20 | for k, v in param.items(): 21 | if v != "" and v is not None: 22 | cmd.append(f"--{k}={v}") 23 | return cmd 24 | 25 | 26 | def get_sparkjob_cmd(class_name, params, jar='gdmix-data-all_2.11.jar'): 27 | """ get spark command for local execution 28 | """ 29 | cmd = ['spark-submit', 30 | '--class', class_name, 31 | '--master', 'local[*]', 32 | '--num-executors', '1', 33 | '--driver-memory', '1G', 34 | '--executor-memory', '1G', 35 | '--conf', 'spark.sql.avro.compression.codec=deflate', 36 | '--conf', 'spark.hadoop.mapreduce.fileoutputcommitter.marksuccessfuljobs=false', 37 | jar] 38 | cmd.extend(get_param_list(params)) 39 | return cmd 40 | 41 | 42 | def run_cmd(cmd): 43 | """ run gdmix job locally. 44 | Params: 45 | cmd: shell command, e.g. ['spark-submit', '--class', ...] 46 | """ 47 | process = Popen(cmd, stdout=PIPE, stderr=PIPE) 48 | # wait for the process to terminate 49 | out, err = process.communicate() 50 | print(out.decode("utf-8")) 51 | if process.returncode: 52 | raise RuntimeError(f"ERROR in executing command: {str(' '.join(cmd))}\n\nError message:\n{err.decode('utf-8')}") 53 | else: 54 | print(err.decode("utf-8")) 55 | -------------------------------------------------------------------------------- /gdmix-workflow/src/gdmixworkflow/single_node_workflow.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from os.path import join as path_join 4 | 5 | from gdmixworkflow.common.utils import * 6 | from gdmixworkflow.common.constants import * 7 | from gdmixworkflow.fixed_effect_workflow_generator \ 8 | import FixedEffectWorkflowGenerator 9 | from gdmixworkflow.random_effect_workflow_generator \ 10 | import RandomEffectWorkflowGenerator 11 | 12 | 13 | def create_subdirs(parent_dir): 14 | if os.path.isdir(parent_dir): 15 | shutil.rmtree(parent_dir) 16 | os.makedirs(parent_dir) 17 | for sub_dir_name in (MODELS, METRIC, TRAINING_SCORES, VALIDATION_SCORES): 18 | os.makedirs(path_join(parent_dir, sub_dir_name)) 19 | 20 | 21 | def run_gdmix_single_node(gdmix_config_file, jar_path): 22 | """ Run gdmix jobs locally including: 23 | - fixed-effect jobs 24 | - random-effect jobs 25 | """ 26 | gdmix_config_obj = yaml_config_file_to_obj(gdmix_config_file) 27 | output_dir = gdmix_config_obj.output_dir 28 | 29 | if not hasattr(gdmix_config_obj, FIXED_EFFECT_CONFIG): 30 | raise ValueError(f"Need to define {FIXED_EFFECT_CONFIG}") 31 | fe_workflow = FixedEffectWorkflowGenerator(gdmix_config_obj, jar_path=jar_path) 32 | root_dir = path_join(output_dir, fe_workflow.fixed_effect_name) 33 | create_subdirs(root_dir) 34 | fe_workflow.run() 35 | 36 | if hasattr(gdmix_config_obj, RANDOM_EFFECT_CONFIG): 37 | for name, re_config in gdmix_config_obj.random_effect_config.items(): 38 | root_dir = path_join(output_dir, name) 39 | create_subdirs(root_dir) 40 | num_partitions = re_config['num_partitions'] 41 | for score_output_name in (TRAINING_SCORES, VALIDATION_SCORES): 42 | sub_dir = path_join(root_dir, score_output_name) 43 | for idx in range(num_partitions): 44 | os.makedirs(path_join(sub_dir, f"partitionId={idx}")) 45 | re_workflow = RandomEffectWorkflowGenerator(gdmix_config_obj, jar_path=jar_path, 
prev_model_name=fe_workflow.fixed_effect_name) 46 | re_workflow.run() 47 | 48 | return output_dir 49 | -------------------------------------------------------------------------------- /gdmix-workflow/test/common/test_utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | import tempfile 4 | import unittest 5 | import shutil 6 | from gdmixworkflow.common.utils import * 7 | 8 | 9 | class TestUtils(unittest.TestCase): 10 | """ 11 | Test gdmix workflow utils 12 | """ 13 | 14 | def setUp(self): 15 | self.output_dir = tempfile.mkdtemp() 16 | config = {"a": {"b1": "b2"}, 17 | "c": "d"} 18 | self.config_file_name = os.path.join(self.output_dir, "config.yaml") 19 | with open(self.config_file_name, 'w') as f: 20 | yaml.dump(config, f) 21 | 22 | def tearDown(self): 23 | shutil.rmtree(self.output_dir) 24 | 25 | def test_gen_random_string(self): 26 | expectedLen = 8 27 | actualLen = len(gen_random_string(expectedLen)) 28 | 29 | self.assertEqual(actualLen, expectedLen) 30 | 31 | def test_abbr(self): 32 | inputStr = "fixed-effect" 33 | expected = "f10t" 34 | actual = abbr(inputStr) 35 | self.assertEqual(actual, expected) 36 | 37 | def test_yaml_config_file_to_obj(self): 38 | config_obj = yaml_config_file_to_obj(self.config_file_name) 39 | self.assertEqual(config_obj.a['b1'], "b2") 40 | self.assertEqual(config_obj.c, "d") 41 | 42 | 43 | if __name__ == '__main__': 44 | unittest.main() 45 | -------------------------------------------------------------------------------- /gdmix-workflow/test/resources/detext-movieLens.yaml: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "detext-training", 3 | "fixed_effect_config": { 4 | "global": { 5 | "ftr_ext": "cnn", 6 | "ltr_loss_fn": "pointwise", 7 | "learning_rate": 0.002, 8 | "num_classes": 1, 9 | "max_len": 16, 10 | "min_len": 3, 11 | "num_filters": 50, 12 | "num_train_steps": 1000, 13 | "num_units": 64, 14 | "optimizer": "adamw", 15 | "pmetric": "auc", 16 | "steps_per_stats": 10, 17 | "steps_per_eval": 100, 18 | "train_batch_size": 64, 19 | "test_batch_size": 64, 20 | "vocab_file": "movieLens/detext/vocab.txt", 21 | "resume_training": false, 22 | "train_file": "movieLens/detext/trainingData", 23 | "dev_file": "movieLens/detext/validationData", 24 | "test_file": "movieLens/detext/validationData", 25 | "keep_checkpoint_max": 1, 26 | "distribution_strategy": "one_device", 27 | "task_type": "binary_classification", 28 | "sparse_ftrs_column_names": "wide_ftrs_sp", 29 | "doc_text_column_names": "doc_query", 30 | "nums_sparse_ftrs": 100, 31 | "num_gpu": 0, 32 | gdmix_config: &g { 33 | # GDMixParams 34 | "model_type": "detext", 35 | 36 | # SchemaParams 37 | "label_column_name": "response", 38 | "weight_column_name": "weight", 39 | "uid_column_name": "uid", 40 | "prediction_score_column_name": "predictionScore" 41 | }, 42 | }, 43 | }, 44 | "random_effect_config": { 45 | "per-user": { 46 | "partition_entity": "user_id", 47 | "training_data_dir": "movieLens/per_user/trainingData", 48 | "validation_data_dir": "movieLens/per_user/validationData", 49 | "feature_file": "movieLens/per_user/featureList/per_user", 50 | 51 | "feature_bag": "per_user", 52 | "metadata_file": "movieLens/per_user/metadata/tensor_metadata.json", 53 | "l2_reg_weight": 1.0, 54 | "regularize_bias": false, 55 | 56 | "lbfgs_tolerance": 1.0e-12, 57 | "num_of_lbfgs_iterations": 100, 58 | "num_of_lbfgs_curvature_pairs": 10, 59 | "max_training_queue_size": 10, 60 | "num_of_consumers": 1, 
61 | "enable_local_indexing": false, 62 | 63 | # extra params 64 | "num_partitions": 1, 65 | 66 | gdmix_config: &g_r { 67 | <<: *g, 68 | "model_type": "logistic_regression", 69 | }, 70 | }, 71 | "per-movie": { 72 | "partition_entity": "movie_id", 73 | "training_data_dir": "movieLens/per_movie/trainingData", 74 | "validation_data_dir": "movieLens/per_movie/validationData", 75 | "feature_file": "movieLens/per_movie/featureList/per_movie", 76 | 77 | "feature_bag": "per_movie", 78 | "metadata_file": "movieLens/per_movie/metadata/tensor_metadata.json", 79 | "l2_reg_weight": 1.0, 80 | "regularize_bias": false, 81 | "num_partitions": 1, 82 | 83 | "lbfgs_tolerance": 1.0e-12, 84 | "num_of_lbfgs_iterations": 100, 85 | "num_of_lbfgs_curvature_pairs": 10, 86 | 87 | "max_training_queue_size": 10, 88 | "num_of_consumers": 1, 89 | "enable_local_indexing": false, 90 | 91 | gdmix_config: *g_r, 92 | } 93 | }, 94 | # configs for dstributed runs, will be ignored by single-node runs 95 | "spark_config":{ 96 | "executorInstances":2, 97 | "executorCores":1, 98 | "driverMemory":"1g", 99 | "executorMemory":"1g" 100 | }, 101 | "tfjob_config":{ 102 | "workerType":"gpu", 103 | "needChief":false, 104 | "psNum":1, 105 | "evaluatorNum":0, 106 | "workerNum":2, 107 | "memorySize":"1g" 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /gdmix-workflow/test/resources/lr-movieLens.yaml: -------------------------------------------------------------------------------- 1 | { 2 | "output_dir": "lr-training", 3 | "fixed_effect_config": { 4 | "global": { 5 | # LRParams 6 | "training_data_dir": "movieLens/global/trainingData", 7 | "validation_data_dir": "movieLens/global/validationData", 8 | "feature_file": "movieLens/global/featureList/global", 9 | "feature_bag": "global", 10 | "metadata_file": "movieLens/global/metadata/tensor_metadata.json", 11 | "l2_reg_weight": 1.0, 12 | "regularize_bias": false, 13 | 14 | "lbfgs_tolerance": 1.0e-12, 15 | "num_of_lbfgs_iterations": 100, 16 | "num_of_lbfgs_curvature_pairs": 10, 17 | 18 | # FixedLRParams(LRParams) 19 | "copy_to_local": false, 20 | 21 | # Params 22 | gdmix_config: &g { 23 | # GDMixParams 24 | "model_type": "logistic_regression", 25 | 26 | # SchemaParams 27 | "label_column_name": "response", 28 | "uid_column_name": "uid", 29 | "prediction_score_column_name": "predictionScore", 30 | "weight_column_name": "weight" 31 | }, 32 | } 33 | }, 34 | "random_effect_config": { 35 | "per-user": { 36 | "partition_entity": "user_id", 37 | "training_data_dir": "movieLens/per_user/trainingData", 38 | "validation_data_dir": "movieLens/per_user/validationData", 39 | "feature_file": "movieLens/per_user/featureList/per_user", 40 | 41 | "feature_bag": "per_user", 42 | "metadata_file": "movieLens/per_user/metadata/tensor_metadata.json", 43 | "l2_reg_weight": 1.0, 44 | "regularize_bias": false, 45 | 46 | "lbfgs_tolerance": 1.0e-12, 47 | "num_of_lbfgs_iterations": 100, 48 | "num_of_lbfgs_curvature_pairs": 10, 49 | "max_training_queue_size": 10, 50 | "num_of_consumers": 1, 51 | "enable_local_indexing": false, 52 | 53 | # extra params 54 | "num_partitions": 1, 55 | "gdmix_config": *g, 56 | }, 57 | "per-movie": { 58 | "partition_entity": "movie_id", 59 | "training_data_dir": "movieLens/per_movie/trainingData", 60 | "validation_data_dir": "movieLens/per_movie/validationData", 61 | "feature_file": "movieLens/per_movie/featureList/per_movie", 62 | 63 | "feature_bag": "per_movie" 64 | , 65 | "metadata_file": "movieLens/per_movie/metadata/tensor_metadata.json", 
66 | "l2_reg_weight": 1.0, 67 | "regularize_bias": false, 68 | 69 | "lbfgs_tolerance": 1.0e-12, 70 | "num_of_lbfgs_iterations": 100, 71 | "num_of_lbfgs_curvature_pairs": 10, 72 | 73 | "max_training_queue_size": 10, 74 | "num_of_consumers": 1, 75 | "enable_local_indexing": false, 76 | 77 | "num_partitions": 1, 78 | gdmix_config: *g 79 | } 80 | }, 81 | # configs for dstributed runs, will be ignored by single-node runs 82 | "spark_config":{ 83 | "executorInstances":2, 84 | "executorCores":1, 85 | "driverMemory":"1g", 86 | "executorMemory":"1g" 87 | }, 88 | "tfjob_config":{ 89 | "workerType":"cpu", 90 | "needChief":false, 91 | "psNum":0, 92 | "evaluatorNum":0, 93 | "workerNum":2, 94 | "memorySize":"1g" 95 | } 96 | } -------------------------------------------------------------------------------- /gdmix-workflow/test/single_node/test_local_ops.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from gdmixworkflow.single_node.local_ops import get_tfjob_cmd, get_sparkjob_cmd 3 | 4 | 5 | class TestLocalOps(unittest.TestCase): 6 | """ 7 | Test commands for single node workflow 8 | """ 9 | 10 | def test_get_tfjob_cmd(self): 11 | params = () 12 | expected = ['python', '-m', 'gdmix.gdmix'] 13 | actual = get_tfjob_cmd(params) 14 | self.assertEqual(actual, expected) 15 | 16 | def test_get_sparkjob_cmd(self): 17 | class_name = "Hello" 18 | params = {"-a": "b"} 19 | expected = ['spark-submit', 20 | '--class', "Hello", 21 | '--master', 'local[*]', 22 | '--num-executors','1', 23 | '--driver-memory', '1G', 24 | '--executor-memory', '1G', 25 | '--conf', 'spark.sql.avro.compression.codec=deflate', 26 | '--conf', 'spark.hadoop.mapreduce.fileoutputcommitter.marksuccessfuljobs=false', 27 | 'gdmix-data-all_2.11.jar', 28 | '-a', 'b'] 29 | actual = get_sparkjob_cmd(class_name, params) 30 | self.assertEqual(actual, expected) 31 | 32 | 33 | if __name__ == '__main__': 34 | unittest.main() 35 | -------------------------------------------------------------------------------- /gdmix.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM linkedin/gdmix-dev 2 | 3 | # Install Notebook 4 | RUN pip install notebook jupyter_contrib_nbextensions 5 | RUN jupyter contrib nbextension install 6 | 7 | ARG WORK_DIR="/workspace/notebook" 8 | WORKDIR ${WORK_DIR} 9 | 10 | # Install gdmix components 11 | ARG BUILD_DIR="build_dir" 12 | RUN mkdir ${BUILD_DIR} 13 | COPY gdmix-trainer ${BUILD_DIR}/gdmix-trainer 14 | COPY gdmix-workflow ${BUILD_DIR}/gdmix-workflow 15 | COPY gdmix-data-all ${BUILD_DIR}/gdmix-data-all 16 | COPY gdmix-data ${BUILD_DIR}/gdmix-data 17 | COPY gradle ${BUILD_DIR}/gradle 18 | COPY scripts ${BUILD_DIR}/scripts 19 | COPY build.gradle ${BUILD_DIR}/ 20 | COPY settings.gradle ${BUILD_DIR}/ 21 | COPY gradlew ${BUILD_DIR}/ 22 | 23 | # Install GDMix components 24 | RUN cd ${BUILD_DIR} 25 | RUN python -m pip install --upgrade pip && pip install --upgrade setuptools pytest 26 | RUN cd ${BUILD_DIR}/gdmix-trainer && pip install . && cd ../.. 27 | RUN cd ${BUILD_DIR}/gdmix-workflow && pip install . && cd ../.. 28 | RUN cd ${BUILD_DIR} && sh gradlew shadowJar && cp build/gdmix-data-all_2.11/libs/gdmix-data-all_2.11*.jar ${WORK_DIR} 29 | 30 | # Download and process movieLens data 31 | RUN cp ${WORK_DIR}/${BUILD_DIR}/scripts/download_process_movieLens_data.py . 
32 | RUN pip install pandas 33 | RUN python download_process_movieLens_data.py 34 | 35 | # Copy gdmix configs for movieLens example 36 | RUN cp ${WORK_DIR}/${BUILD_DIR}/gdmix-workflow/examples/movielens-100k/*.yaml . 37 | 38 | RUN rm -rf ~/.gradle/caches/* ~/.cache/pip/* ${WORK_DIR}/${BUILD_DIR} 39 | 40 | -------------------------------------------------------------------------------- /gradle.properties: -------------------------------------------------------------------------------- 1 | # Gradle will run tasks from subprojects in parallel. 2 | # Higher CPU usage, faster builds. 3 | org.gradle.parallel=true 4 | 5 | # Starting from Gradle 5, the default memory limit for Gradle daemon is 512MB. 6 | # This is not enough for LI builds so we need to increase the default. 7 | org.gradle.jvmargs=-Xmx1024m "-XX:MaxMetaspaceSize=256m" 8 | 9 | -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/gdmix/dc24377e808ecc287ece67d88ffb800fb9ffaaa4/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | distributionBase=GRADLE_USER_HOME 2 | distributionPath=wrapper/dists 3 | zipStoreBase=GRADLE_USER_HOME 4 | zipStorePath=wrapper/dists 5 | distributionUrl=https\://services.gradle.org/distributions/gradle-4.5-bin.zip 6 | -------------------------------------------------------------------------------- /gradlew.bat: -------------------------------------------------------------------------------- 1 | @if "%DEBUG%" == "" @echo off 2 | @rem ########################################################################## 3 | @rem 4 | @rem Gradle startup script for Windows 5 | @rem 6 | @rem ########################################################################## 7 | 8 | @rem Set local scope for the variables with windows NT shell 9 | if "%OS%"=="Windows_NT" setlocal 10 | 11 | set DIRNAME=%~dp0 12 | if "%DIRNAME%" == "" set DIRNAME=. 13 | set APP_BASE_NAME=%~n0 14 | set APP_HOME=%DIRNAME% 15 | 16 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 17 | set DEFAULT_JVM_OPTS= 18 | 19 | @rem Find java.exe 20 | if defined JAVA_HOME goto findJavaFromJavaHome 21 | 22 | set JAVA_EXE=java.exe 23 | %JAVA_EXE% -version >NUL 2>&1 24 | if "%ERRORLEVEL%" == "0" goto init 25 | 26 | echo. 27 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 28 | echo. 29 | echo Please set the JAVA_HOME variable in your environment to match the 30 | echo location of your Java installation. 31 | 32 | goto fail 33 | 34 | :findJavaFromJavaHome 35 | set JAVA_HOME=%JAVA_HOME:"=% 36 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 37 | 38 | if exist "%JAVA_EXE%" goto init 39 | 40 | echo. 41 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 42 | echo. 43 | echo Please set the JAVA_HOME variable in your environment to match the 44 | echo location of your Java installation. 45 | 46 | goto fail 47 | 48 | :init 49 | @rem Get command-line arguments, handling Windows variants 50 | 51 | if not "%OS%" == "Windows_NT" goto win9xME_args 52 | 53 | :win9xME_args 54 | @rem Slurp the command line arguments.
55 | set CMD_LINE_ARGS= 56 | set _SKIP=2 57 | 58 | :win9xME_args_slurp 59 | if "x%~1" == "x" goto execute 60 | 61 | set CMD_LINE_ARGS=%* 62 | 63 | :execute 64 | @rem Setup the command line 65 | 66 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 67 | 68 | @rem Execute Gradle 69 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% 70 | 71 | :end 72 | @rem End local scope for the variables with windows NT shell 73 | if "%ERRORLEVEL%"=="0" goto mainEnd 74 | 75 | :fail 76 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 77 | rem the _cmd.exe /c_ return code! 78 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 79 | exit /b 1 80 | 81 | :mainEnd 82 | if "%OS%"=="Windows_NT" endlocal 83 | 84 | :omega 85 | -------------------------------------------------------------------------------- /settings.gradle: -------------------------------------------------------------------------------- 1 | buildscript { 2 | repositories { 3 | jcenter() 4 | } 5 | dependencies { 6 | } 7 | } 8 | 9 | /* Modules to use */ 10 | def modules = [ 11 | 'gdmix-data', 12 | 'gdmix-data-all' 13 | ] 14 | 15 | include(*modules) 16 | 17 | /* Scala projects */ 18 | def scalaProjects = ['gdmix-data', 'gdmix-data-all'] 19 | 20 | def scalaSuffix = "_2.11" 21 | 22 | gradle.ext.scalaSuffix = scalaSuffix 23 | 24 | // Make sure the suffix is in sync with the scala version 25 | scalaProjects.forEach { 26 | project(new File(rootProject.projectDir, it)).name += scalaSuffix 27 | } 28 | 29 | --------------------------------------------------------------------------------