├── CONTRIBUTING ├── conda.txt ├── app ├── requirements.txt ├── app_template.yaml ├── openapi.yaml ├── main.py └── recommendations.py ├── requirements.txt ├── wals_ml_engine ├── trainer │ ├── config │ │ ├── config_train.json │ │ ├── config_tune_web.json │ │ └── config_tune.json │ ├── __init__.py │ ├── util.py │ ├── wals.py │ ├── task.py │ └── model.py ├── setup.py ├── README.md └── mltrain.sh ├── .gitignore ├── airflow ├── deploy │ ├── __init__.py │ ├── config │ │ ├── k8s │ │ │ ├── service-redis.yaml │ │ │ ├── service-sql-proxy.yaml │ │ │ ├── service-webserver.yaml │ │ │ ├── redis.yaml │ │ │ ├── deploy-redis.yaml │ │ │ ├── deploy-sql-proxy.yaml │ │ │ ├── job-testing.yaml │ │ │ ├── job-init.yaml │ │ │ ├── deploy-workers.yaml │ │ │ ├── deploy-webserver.yaml │ │ │ └── deploy-scheduler.yaml │ │ ├── settings-template.yaml │ │ └── airflow.cfg │ ├── deploy_airflow.sh │ ├── create_buckets.py │ ├── teardown_local.py │ └── gcp_util.py ├── unittests.cfg ├── dags │ └── training.py ├── plugins │ ├── ml_engine_plugin.py │ └── gae_admin_plugin.py └── airflow.cfg ├── scripts ├── delete_project.sh ├── util.sh ├── prepare_deploy_api.sh ├── generate_traffic.sh ├── query_api_auth.sh ├── query_api.sh └── prepare_deploy_app.sh ├── LICENSE ├── README.md ├── data └── ga_sessions_sample_schema.json └── notebooks └── Part1.ipynb /CONTRIBUTING: -------------------------------------------------------------------------------- 1 | We are not accepting contributions at this time. 2 | -------------------------------------------------------------------------------- /conda.txt: -------------------------------------------------------------------------------- 1 | python==2.7.13 2 | numpy==1.13.1 3 | pandas==0.20.3 4 | scipy==0.19.1 5 | 6 | -------------------------------------------------------------------------------- /app/requirements.txt: -------------------------------------------------------------------------------- 1 | flask 2 | gunicorn 3 | pandas 4 | numpy 5 | google-cloud-storage==1.6.0 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | google-api-python-client 2 | google-compute-engine 3 | google-cloud-storage 4 | sh 5 | retrying 6 | pyyaml 7 | oauth2client==3.0.0 -------------------------------------------------------------------------------- /wals_ml_engine/trainer/config/config_train.json: -------------------------------------------------------------------------------- 1 | { 2 | "trainingInput": { 3 | "scaleTier": "CUSTOM", 4 | "masterType": "complex_model_m_gpu" 5 | } 6 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | model 3 | *.pyc 4 | __pycache__ 5 | logs 6 | *.db 7 | jobs 8 | wals_ml_engine.egg-info 9 | dist 10 | airflow/deploy/config/settings.yaml 11 | airflow/deploy/deployment-settings.yaml 12 | -------------------------------------------------------------------------------- /airflow/deploy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /wals_ml_engine/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /scripts/delete_project.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | source util.sh 17 | 18 | main() { 19 | # Get our working project, or exit if it's not set. 20 | local project_id="$(get_project_id)" 21 | if [[ -z "$project_id" ]]; then 22 | exit 1 23 | fi 24 | gcloud -q projects delete "$project_id" 25 | } 26 | 27 | main "$@" 28 | -------------------------------------------------------------------------------- /airflow/deploy/config/k8s/service-redis.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
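#
# ClusterIP Service that exposes the Redis broker to the other Airflow pods on
# port 6379. The "airflow-redis" name matches the broker_url configured in
# airflow/deploy/config/airflow.cfg (redis://airflow-redis:6379/0).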
14 | 15 | apiVersion: v1 16 | kind: Service 17 | metadata: 18 | labels: 19 | run: airflow-redis 20 | name: airflow-redis 21 | namespace: default 22 | spec: 23 | ports: 24 | - port: 6379 25 | protocol: TCP 26 | targetPort: 6379 27 | selector: 28 | run: airflow-redis 29 | sessionAffinity: None 30 | type: ClusterIP 31 | -------------------------------------------------------------------------------- /airflow/deploy/config/k8s/service-sql-proxy.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | apiVersion: v1 16 | kind: Service 17 | metadata: 18 | labels: 19 | run: airflow-sqlproxy 20 | name: airflow-sqlproxy 21 | namespace: default 22 | spec: 23 | ports: 24 | - port: 3306 25 | protocol: TCP 26 | targetPort: 3306 27 | selector: 28 | run: airflow-sqlproxy 29 | sessionAffinity: None 30 | type: ClusterIP 31 | -------------------------------------------------------------------------------- /airflow/deploy/config/k8s/service-webserver.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | apiVersion: v1 16 | kind: Service 17 | metadata: 18 | labels: 19 | run: airflow-webserver 20 | name: airflow-webserver 21 | namespace: default 22 | spec: 23 | ports: 24 | - port: 8080 25 | protocol: TCP 26 | targetPort: 8080 27 | selector: 28 | run: airflow-webserver 29 | sessionAffinity: None 30 | type: LoadBalancer 31 | -------------------------------------------------------------------------------- /wals_ml_engine/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
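#
# This package is what gets staged to ML Engine for cloud training runs. A rough
# submission sketch is shown below; mltrain.sh builds the real command, the job
# name, bucket and region here are placeholders, and trainer.task is assumed as
# the entry point because the package ships trainer/task.py:
#
#   gcloud ml-engine jobs submit training wals_job_001 \
#     --package-path trainer \
#     --module-name trainer.task \
#     --job-dir gs://mybucket/jobs/wals_job_001 \
#     --region us-central1 \
#     --config trainer/config/config_train.json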
14 | 15 | from setuptools import find_packages 16 | from setuptools import setup 17 | 18 | REQUIRED_PACKAGES = ['sh'] 19 | 20 | setup( 21 | name='wals_ml_engine', 22 | version='0.1', 23 | install_requires=REQUIRED_PACKAGES, 24 | packages=find_packages(), 25 | include_package_data=True, 26 | description='A trainer application package for WALS on ML Engine.' 27 | ) -------------------------------------------------------------------------------- /app/app_template.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | runtime: python 16 | env: flex 17 | entrypoint: gunicorn -b :$PORT main:app 18 | 19 | runtime_config: 20 | python_version: 3 21 | 22 | resources: 23 | cpu: 2 24 | memory_gb: 12 25 | disk_size_gb: 100 26 | 27 | service: 28 | default 29 | 30 | #readiness_check: 31 | # path: /readiness_check 32 | # app_start_timeout_sec: 300 33 | 34 | manual_scaling: 35 | instances: 2 36 | 37 | endpoints_api_service: 38 | # This is a template. Use the deploy_app.sh script to deploy this app. 39 | name: SERVICE_NAME 40 | config_id: SERVICE_CONFIG_ID 41 | -------------------------------------------------------------------------------- /wals_ml_engine/trainer/config/config_tune_web.json: -------------------------------------------------------------------------------- 1 | { 2 | "trainingInput": { 3 | "scaleTier": "CUSTOM", 4 | "masterType": "standard_gpu", 5 | "hyperparameters": { 6 | "goal": "MINIMIZE", 7 | "params": [ 8 | { 9 | "parameterName": "regularization", 10 | "type": "DOUBLE", 11 | "minValue": "0.001", 12 | "maxValue": "10.0", 13 | "scaleType": "UNIT_LINEAR_SCALE" 14 | }, 15 | { 16 | "parameterName": "latent_factors", 17 | "type": "INTEGER", 18 | "minValue": "5", 19 | "maxValue": "50", 20 | "scaleType": "UNIT_LINEAR_SCALE" 21 | }, 22 | { 23 | "parameterName": "unobs_weight", 24 | "type": "DOUBLE", 25 | "minValue": "0.001", 26 | "maxValue": "5.0", 27 | "scaleType": "UNIT_LOG_SCALE" 28 | }, 29 | { 30 | "parameterName": "feature_wt_exp", 31 | "type": "DOUBLE", 32 | "minValue": ".0001", 33 | "maxValue": "10", 34 | "scaleType": "UNIT_LINEAR_SCALE" 35 | } 36 | ], 37 | "maxTrials": 500, 38 | "maxParallelTrials": 1 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /wals_ml_engine/trainer/config/config_tune.json: -------------------------------------------------------------------------------- 1 | { 2 | "trainingInput": { 3 | "scaleTier": "CUSTOM", 4 | "masterType": "standard_gpu", 5 | "hyperparameters": { 6 | "goal": "MINIMIZE", 7 | "params": [ 8 | { 9 | "parameterName": "regularization", 10 | "type": "DOUBLE", 11 | "minValue": "0.001", 12 | "maxValue": "10.0", 13 | "scaleType": "UNIT_REVERSE_LOG_SCALE" 14 | }, 15 | { 16 | "parameterName": "latent_factors", 17 | "type": "INTEGER", 18 | "minValue": "5", 19 | "maxValue": "50", 20 | "scaleType": "UNIT_REVERSE_LOG_SCALE" 21 
| }, 22 | { 23 | "parameterName": "unobs_weight", 24 | "type": "DOUBLE", 25 | "minValue": "0.001", 26 | "maxValue": "5.0", 27 | "scaleType": "UNIT_REVERSE_LOG_SCALE" 28 | }, 29 | { 30 | "parameterName": "feature_wt_factor", 31 | "type": "DOUBLE", 32 | "minValue": "1", 33 | "maxValue": "200", 34 | "scaleType": "UNIT_LOG_SCALE" 35 | } 36 | ], 37 | "maxTrials": 500, 38 | "maxParallelTrials": 1 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /airflow/deploy/deploy_airflow.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2018 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | set -euo pipefail 17 | 18 | # make a copy of the settings template and replace the project id 19 | project_id=$(gcloud config get-value project 2> /dev/null) 20 | 21 | < config/settings-template.yaml sed -E "s/YOUR-PROJECT-ID/${project_id}/g" > "config/settings.yaml" 22 | 23 | # run deploy script 24 | echo "python deploy_airflow.py --airflow_config config/airflow.cfg --settings config/settings.yaml --output_settings ./deployment-settings.yaml" 25 | python deploy_airflow.py --airflow_config config/airflow.cfg \ 26 | --settings config/settings.yaml \ 27 | --output_settings ./deployment-settings.yaml 28 | 29 | -------------------------------------------------------------------------------- /airflow/deploy/create_buckets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
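#
# Run this from the repository root after deploy_airflow.sh has generated
# deployment-settings.yaml (the path opened below is relative to the repo root),
# for example:
#
#   python airflow/deploy/create_buckets.py
#
# It reads the bucket name from the settings file's 'id' key and creates empty
# dags/, logs/ and plugins/ placeholder objects in that bucket.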
14 | 15 | """Create buckets for 'Airflow on GKE' dags, logs and plugins.""" 16 | 17 | import yaml 18 | 19 | from google.cloud import storage 20 | 21 | 22 | def create_folder(gcs_bucket, folder): 23 | blob = gcs_bucket.blob(folder) 24 | blob.upload_from_string('', content_type='application/x-www-form-urlencoded;charset=UTF-8') 25 | 26 | with open('airflow/deploy/deployment-settings.yaml') as f: 27 | settings = yaml.load(f) 28 | bucket_uri = settings['id'] 29 | 30 | gcs_client = storage.Client() 31 | bucket = gcs_client.get_bucket(bucket_uri) 32 | create_folder(bucket, 'dags/') 33 | create_folder(bucket, 'logs/') 34 | create_folder(bucket, 'plugins/') 35 | -------------------------------------------------------------------------------- /airflow/deploy/config/settings-template.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | project: YOUR-PROJECT-ID 16 | web_ui_username: airflow 17 | cloud_sql: 18 | region: us-central1 19 | tier: db-n1-standard-1 20 | 21 | root_user: root 22 | database: airflow 23 | user: airflow 24 | 25 | gcs: 26 | location: us-central1 27 | 28 | kubernetes: 29 | zone: "us-central1-a" 30 | deployment: 31 | webserver: "config/k8s/deploy-webserver.yaml" 32 | redis: "config/k8s/deploy-redis.yaml" 33 | sql_proxy: "config/k8s/deploy-sql-proxy.yaml" 34 | workers: "config/k8s/deploy-workers.yaml" 35 | scheduler: "config/k8s/deploy-scheduler.yaml" 36 | job: 37 | init: "config/k8s/job-init.yaml" 38 | service: 39 | sql_proxy: "config/k8s/service-sql-proxy.yaml" 40 | redis: "config/k8s/service-redis.yaml" 41 | webserver: "config/k8s/service-webserver.yaml" 42 | -------------------------------------------------------------------------------- /wals_ml_engine/README.md: -------------------------------------------------------------------------------- 1 | # Collaborative Filtering in TensorFlow using WALS 2 | 3 | This code implements a collaborative filtering recommendation model using the WALS 4 | (weighted alternating least squares) algorithm provided in TensorFlow, and applies 5 | it to the MovieLens data set. 6 | 7 | ## Install 8 | 9 | This code assumes Python 2. 10 | 11 | * Install miniconda2: 12 | 13 | https://conda.io/docs/user-guide/install/index.html 14 | 15 | 16 | * Create environment and install packages: 17 | 18 | Assuming you are in the repo directory: 19 | 20 | ``` 21 | $ conda create -n tfrec 22 | $ conda install -n tfrec --file requirements.txt 23 | ``` 24 | 25 | * Install TensorFlow.
26 | 27 | CPU: 28 | ``` 29 | $ pip install tensorflow 30 | ``` 31 | 32 | Or GPU, if one is available in your environment: 33 | 34 | ``` 35 | $ pip install tensorflow-gpu 36 | ``` 37 | 38 | 39 | ## Download Movielens Data 40 | 41 | ``` 42 | $ curl -O 'http://files.grouplens.org/datasets/movielens/ml-100k.zip' 43 | $ unzip ml-100k.zip 44 | $ mkdir data 45 | $ cp ml-100k/u.data data/ 46 | ``` 47 | 48 | ## Run 49 | 50 | * Train the model locally 51 | ``` 52 | $ ./mltrain.sh local data/u.data 53 | ``` 54 | 55 | * Train the model on ML Engine: 56 | ``` 57 | $ gsutil cp data/u.data gs://mybucket/data/u.data 58 | $ ./mltrain.sh train gs://mybucket/data/u.data 59 | 60 | ``` 61 | 62 | * Hyperparameter tuning on ML Engine: 63 | ``` 64 | $ ./mltrain.sh tune gs://mybucket/data/u.data 65 | 66 | ``` 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /app/openapi.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | --- 15 | swagger: "2.0" 16 | info: 17 | title: "Recommendations" 18 | description: "Get recommended items for a User given their ID." 19 | version: "1.0.0" 20 | # This field will be replaced by the deploy_api.sh script. 21 | host: "YOUR-PROJECT-ID.appspot.com" 22 | schemes: 23 | - "https" 24 | paths: 25 | "/recommendation": 26 | get: 27 | description: "Get recommended items for a User given their ID." 28 | operationId: "recommendation" 29 | parameters: 30 | - 31 | name: userId 32 | in: query 33 | required: true 34 | type: integer 35 | - 36 | name: numRecs 37 | in: query 38 | required: false 39 | type: integer 40 | responses: 41 | 200: 42 | description: "Success." 43 | schema: 44 | type: string 45 | 400: 46 | description: "The User Id is not found or missing." 47 | -------------------------------------------------------------------------------- /airflow/deploy/config/k8s/redis.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | apiVersion: extensions/v1beta1 16 | kind: Deployment 17 | metadata: 18 | labels: 19 | run: airflow-redis 20 | name: airflow-redis 21 | namespace: default 22 | spec: 23 | replicas: 1 24 | selector: 25 | matchLabels: 26 | run: airflow-redis 27 | strategy: 28 | rollingUpdate: 29 | maxSurge: 1 30 | maxUnavailable: 1 31 | type: RollingUpdate 32 | template: 33 | metadata: 34 | creationTimestamp: null 35 | labels: 36 | run: airflow-redis 37 | spec: 38 | containers: 39 | - image: redis:alpine 40 | imagePullPolicy: Always 41 | name: airflow-redis 42 | ports: 43 | - containerPort: 6379 44 | protocol: TCP 45 | resources: {} 46 | terminationMessagePath: /dev/termination-log 47 | dnsPolicy: ClusterFirst 48 | restartPolicy: Always 49 | securityContext: {} 50 | terminationGracePeriodSeconds: 30 -------------------------------------------------------------------------------- /airflow/deploy/config/k8s/deploy-redis.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | apiVersion: extensions/v1beta1 16 | kind: Deployment 17 | metadata: 18 | labels: 19 | run: airflow-redis 20 | name: airflow-redis 21 | namespace: default 22 | spec: 23 | replicas: 1 24 | selector: 25 | matchLabels: 26 | run: airflow-redis 27 | strategy: 28 | rollingUpdate: 29 | maxSurge: 1 30 | maxUnavailable: 1 31 | type: RollingUpdate 32 | template: 33 | metadata: 34 | creationTimestamp: null 35 | labels: 36 | run: airflow-redis 37 | spec: 38 | containers: 39 | - image: redis:alpine 40 | imagePullPolicy: Always 41 | name: airflow-redis 42 | ports: 43 | - containerPort: 6379 44 | protocol: TCP 45 | resources: {} 46 | terminationMessagePath: /dev/termination-log 47 | dnsPolicy: ClusterFirst 48 | restartPolicy: Always 49 | securityContext: {} 50 | terminationGracePeriodSeconds: 30 51 | -------------------------------------------------------------------------------- /scripts/util.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Make Bash a little less error-prone. 17 | set -euo pipefail 18 | 19 | get_latest_config_id() { 20 | # Given a service name, this returns the most recent deployment of that 21 | # API. 
22 | service_name="$1" 23 | gcloud endpoints configs list \ 24 | --service="$service_name" \ 25 | --sort-by="~config_id" --limit=1 --format="value(CONFIG_ID)" \ 26 | | tr -d '[:space:]' 27 | } 28 | 29 | get_project_id() { 30 | # Find the project ID first by DEVSHELL_PROJECT_ID (in Cloud Shell) 31 | # and then by querying the gcloud default project. 32 | local project="${DEVSHELL_PROJECT_ID:-}" 33 | if [[ -z "$project" ]]; then 34 | project=$(gcloud config get-value project 2> /dev/null) 35 | fi 36 | if [[ -z "$project" ]]; then 37 | >&2 echo "No default project was found, and DEVSHELL_PROJECT_ID is not set." 38 | >&2 echo "Please use the Cloud Shell or set your default project by typing:" 39 | >&2 echo "gcloud config set project YOUR-PROJECT-NAME" 40 | fi 41 | echo "$project" 42 | } 43 | -------------------------------------------------------------------------------- /airflow/deploy/config/airflow.cfg: -------------------------------------------------------------------------------- 1 | [core] 2 | airflow_home = /etc/airflow 3 | dags_folder = /home/airflow/gcs/dags 4 | base_log_folder = /home/airflow/gcs/logs 5 | plugins_folder = /home/airflow/gcs/plugins 6 | 7 | remote_log_conn_id = google_cloud_default 8 | remote_base_log_folder = 9 | 10 | executor = CeleryExecutor 11 | 12 | sql_alchemy_pool_size = 5 13 | sql_alchemy_pool_recycle = 3600 14 | 15 | parallelism = 32 16 | dag_concurrency = 16 17 | max_active_runs_per_dag = 16 18 | dags_are_paused_at_creation = True 19 | load_examples = False 20 | donot_pickle = False 21 | dagbag_import_timeout = 30 22 | 23 | fernet_key = cryptography_not_found_storing_passwords_in_plain_text 24 | 25 | [webserver] 26 | base_url = http://localhost:8080 27 | web_server_host = 0.0.0.0 28 | web_server_port = 8080 29 | secret_key = temporary_key 30 | workers = 2 31 | worker_class = sync 32 | expose_config = true 33 | 34 | # Web UI Authentication 35 | authenticate = True 36 | auth_backend = airflow.contrib.auth.backends.password_auth 37 | filter_by_owner = False 38 | 39 | [email] 40 | email_backend = airflow.utils.send_email_smtp 41 | 42 | [smtp] 43 | smtp_host = localhost 44 | smtp_starttls = True 45 | smtp_ssl = False 46 | smtp_user = airflow 47 | smtp_port = 25 48 | smtp_password = airflow 49 | smtp_mail_from = airflow@airflow.com 50 | 51 | [celery] 52 | celery_app_name = airflow.executors.celery_executor 53 | celeryd_concurrency = 16 54 | worker_log_server_port = 8793 55 | 56 | broker_url = redis://airflow-redis:6379/0 57 | celery_result_backend = redis://airflow-redis:6379/0 58 | 59 | flower_port = 5555 60 | default_queue = default 61 | 62 | [scheduler] 63 | job_heartbeat_sec = 5 64 | scheduler_heartbeat_sec = 5 65 | -------------------------------------------------------------------------------- /scripts/prepare_deploy_api.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | set -euo pipefail 17 | 18 | source util.sh 19 | 20 | main() { 21 | # Get our working project, or exit if it's not set. 22 | local project_id=$(get_project_id) 23 | if [[ -z "$project_id" ]]; then 24 | exit 1 25 | fi 26 | local temp_file=$(mktemp) 27 | export TEMP_FILE="${temp_file}.yaml" 28 | mv "$temp_file" "$TEMP_FILE" 29 | 30 | # Because the included API is a template, we have to do some string 31 | # substitution before we can deploy it. Sed does this nicely. 32 | < "$API_FILE" sed -E "s/YOUR-PROJECT-ID/${project_id}/g" > "$TEMP_FILE" 33 | echo "Preparing config for deploying service in $API_FILE..." 34 | echo "To deploy: gcloud endpoints services deploy $TEMP_FILE" 35 | } 36 | 37 | # Defaults. 38 | API_FILE="../app/openapi.yaml" 39 | 40 | if [[ "$#" == 0 ]]; then 41 | : # Use defaults. 42 | elif [[ "$#" == 1 ]]; then 43 | API_FILE="$1" 44 | else 45 | echo "Wrong number of arguments specified." 46 | echo "Usage: deploy_api.sh [api-file]" 47 | exit 1 48 | fi 49 | 50 | main "$@" 51 | -------------------------------------------------------------------------------- /scripts/generate_traffic.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | set -euo pipefail 17 | 18 | source util.sh 19 | 20 | main() { 21 | # Get our working project, or exit if it's not set. 22 | local project_id=$(get_project_id) 23 | if [[ -z "$project_id" ]]; then 24 | exit 1 25 | fi 26 | local url="https://${project_id}.appspot.com/recommendation?userId=${USER_ID}&numRecs=${NUM_RECS}" 27 | echo "This command will exit automatically in $TIMEOUT_SECONDS seconds." 28 | echo "Generating traffic to ${url}..." 29 | echo "Press Ctrl-C to stop." 30 | local endtime=$(($(date +%s) + $TIMEOUT_SECONDS)) 31 | # Send queries repeatedly until TIMEOUT_SECONDS seconds have elapsed. 32 | while [[ $(date +%s) -lt $endtime ]]; do 33 | curl "$url" &> /dev/null 34 | done 35 | } 36 | 37 | # Defaults. 38 | USER_ID="5448543647176335931" 39 | NUM_RECS=5 40 | TIMEOUT_SECONDS=$((5 * 60)) # Timeout after 5 minutes. 41 | 42 | if [[ "$#" == 0 ]]; then 43 | : # Use defaults. 44 | elif [[ "$#" == 1 ]]; then 45 | USER_ID="$1" 46 | else 47 | echo "Wrong number of arguments specified." 48 | echo "Usage: generate_traffic.sh [user_id]" 49 | exit 1 50 | fi 51 | 52 | main "$@" 53 | -------------------------------------------------------------------------------- /scripts/query_api_auth.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | set -euo pipefail 17 | 18 | source util.sh 19 | 20 | main() { 21 | # Get our working project, or exit if it's not set. 22 | local project_id=$(get_project_id) 23 | if [[ -z "$project_id" ]]; then 24 | exit 1 25 | fi 26 | # Because our included app uses query string parameters, we can include 27 | # them directly in the URL. We use -H to specify a header with our API key. 28 | QUERY="curl -H 'x-api-key: $API_KEY' \"https://${project_id}.appspot.com/recommendation?userId=${USER_ID}\"" 29 | # First, print the command so the user can see what's being executed. 30 | echo "$QUERY" 31 | # Then actually execute it. 32 | # shellcheck disable=SC2086 33 | eval $QUERY 34 | # Our API doesn't print newlines. So we do it ourselves. 35 | printf '\n' 36 | } 37 | 38 | # Defaults. 39 | USER_ID="5448543647176335931" 40 | 41 | if [[ "$#" == 1 ]]; then 42 | API_KEY="$1" 43 | elif [[ "$#" == 2 ]]; then 44 | # "Quiet mode" won't print the curl command. 45 | API_KEY="$1" 46 | USER_ID="$2" 47 | else 48 | echo "Wrong number of arguments specified." 49 | echo "Usage: query_api_auth.sh api-key [user-id]" 50 | exit 1 51 | fi 52 | 53 | main "$@" 54 | -------------------------------------------------------------------------------- /airflow/unittests.cfg: -------------------------------------------------------------------------------- 1 | [core] 2 | unit_test_mode = True 3 | airflow_home = /Users/lramsey/work/tfrec/rec_serve/airflow 4 | dags_folder = /Users/lramsey/work/tfrec/rec_serve/airflow/dags 5 | plugins_folder = /Users/lramsey/work/tfrec/rec_serve/airflow/plugins 6 | base_log_folder = /Users/lramsey/work/tfrec/rec_serve/airflow/logs 7 | executor = SequentialExecutor 8 | sql_alchemy_conn = sqlite:////Users/lramsey/work/tfrec/rec_serve/airflow/unittests.db 9 | load_examples = True 10 | donot_pickle = False 11 | dag_concurrency = 16 12 | dags_are_paused_at_creation = False 13 | fernet_key = cryptography_not_found_storing_passwords_in_plain_text 14 | non_pooled_task_slot_count = 128 15 | 16 | [cli] 17 | api_client = airflow.api.client.local_client 18 | endpoint_url = http://localhost:8080 19 | 20 | [api] 21 | auth_backend = airflow.api.auth.backend.default 22 | 23 | [operators] 24 | default_owner = airflow 25 | 26 | [webserver] 27 | base_url = http://localhost:8080 28 | web_server_host = 0.0.0.0 29 | web_server_port = 8080 30 | dag_orientation = LR 31 | log_fetch_timeout_sec = 5 32 | hide_paused_dags_by_default = False 33 | 34 | [email] 35 | email_backend = airflow.utils.email.send_email_smtp 36 | 37 | [smtp] 38 | smtp_host = localhost 39 | smtp_user = airflow 40 | smtp_port = 25 41 | smtp_password = airflow 42 | smtp_mail_from = airflow@airflow.com 43 | 44 | [celery] 45 | celery_app_name = airflow.executors.celery_executor 46 | celeryd_concurrency = 16 47 | worker_log_server_port = 8793 48 | broker_url = sqla+mysql://airflow:airflow@localhost:3306/airflow 49 | celery_result_backend = db+mysql://airflow:airflow@localhost:3306/airflow 50 | flower_host = 0.0.0.0 51 | flower_port = 5555 52 | default_queue = default 53 | 54 | [scheduler] 55 | job_heartbeat_sec = 1 56 | 
scheduler_heartbeat_sec = 5 57 | authenticate = true 58 | max_threads = 2 59 | catchup_by_default = True 60 | scheduler_zombie_task_threshold = 300 61 | dag_dir_list_interval = 0 62 | -------------------------------------------------------------------------------- /app/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Main module for App Engine app.""" 16 | 17 | from flask import Flask, jsonify, request 18 | 19 | from recommendations import Recommendations 20 | 21 | app = Flask(__name__) 22 | 23 | rec_util = Recommendations() 24 | 25 | DEFAULT_RECS = 5 26 | 27 | 28 | @app.route('/recommendation', methods=['GET']) 29 | def recommendation(): 30 | """Given a user id, return a list of recommended item ids.""" 31 | user_id = request.args.get('userId') 32 | num_recs = request.args.get('numRecs') 33 | 34 | # validate args 35 | if user_id is None: 36 | return 'No User Id provided.', 400 37 | if num_recs is None: 38 | num_recs = DEFAULT_RECS 39 | try: 40 | uid_int = int(user_id) 41 | nrecs_int = int(num_recs) 42 | except: 43 | return 'User id and number of recs arguments must be integers.', 400 44 | 45 | # get recommended articles 46 | rec_list = rec_util.get_recommendations(uid_int, nrecs_int) 47 | 48 | if rec_list is None: 49 | return 'User Id not found : %s' % user_id, 400 50 | 51 | json_response = jsonify({'articles': [str(i) for i in rec_list]}) 52 | return json_response, 200 53 | 54 | 55 | @app.route('/readiness_check', methods=['GET']) 56 | def readiness_check(): 57 | return '', 200 58 | 59 | if __name__ == '__main__': 60 | app.run(host='127.0.0.1', port=8080, debug=True) 61 | -------------------------------------------------------------------------------- /wals_ml_engine/trainer/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import uuid 17 | 18 | import sh 19 | 20 | import tensorflow as tf 21 | from tensorflow.core.framework.summary_pb2 import Summary 22 | 23 | def ensure_local_file(input_file): 24 | """ 25 | Ensure the training ratings file is stored locally. 
26 | """ 27 | if input_file.startswith('gs:/'): 28 | input_path = os.path.join('/tmp/', str(uuid.uuid4())) 29 | os.makedirs(input_path) 30 | tmp_input_file = os.path.join(input_path, os.path.basename(input_file)) 31 | sh.gsutil("cp", "-r", input_file, tmp_input_file) 32 | return tmp_input_file 33 | else: 34 | return input_file 35 | 36 | 37 | def write_hptuning_metric(args, metric): 38 | """ 39 | Write a summary containing the tuning loss metric, as required by hyperparam tuning. 40 | """ 41 | summary = Summary(value=[Summary.Value(tag='training/hptuning/metric', simple_value=metric)]) 42 | 43 | # for hyperparam tuning, we write a summary log to a directory 'eval' below the job directory 44 | eval_path = os.path.join(args['output_dir'], 'eval') 45 | summary_writer = tf.summary.FileWriter(eval_path) 46 | 47 | # Note: adding the summary to the writer is enough for hyperparam tuning. 48 | # The ml engine system is looking for any summary added with the hyperparam metric tag. 49 | summary_writer.add_summary(summary) 50 | summary_writer.flush() 51 | -------------------------------------------------------------------------------- /scripts/query_api.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | set -euo pipefail 17 | 18 | source util.sh 19 | 20 | main() { 21 | # Get our working project, or exit if it's not set. 22 | local project_id=$(get_project_id) 23 | if [[ -z "$project_id" ]]; then 24 | exit 1 25 | fi 26 | # Because our included app uses query string parameters, we can include 27 | # them directly in the URL. 28 | QUERY="curl \"https://${project_id}.appspot.com/recommendation?userId=${USER_ID}&numRecs=${NUM_RECS}\"" 29 | # First, (maybe) print the command so the user can see what's being executed. 30 | if [[ "$QUIET" == "false" ]]; then 31 | echo "$QUERY" 32 | fi 33 | # Then actually execute it. 34 | # shellcheck disable=SC2086 35 | eval $QUERY 36 | # Our API doesn't print newlines. So we do it ourselves. 37 | printf '\n' 38 | } 39 | 40 | # Defaults. 41 | USER_ID="5448543647176335931" 42 | NUM_RECS=5 43 | QUIET="false" 44 | 45 | if [[ "$#" == 0 ]]; then 46 | : # Use defaults. 47 | elif [[ "$#" == 1 ]]; then 48 | USER_ID="$1" 49 | elif [[ "$#" == 2 ]]; then 50 | USER_ID="$1" 51 | NUM_RECS="$2" 52 | elif [[ "$#" == 3 ]]; then 53 | # "Quiet mode" won't print the curl command. 54 | USER_ID="$1" 55 | NUM_RECS="$2" 56 | QUIET="true" 57 | else 58 | echo "Wrong number of arguments specified." 59 | echo "Usage: query_api.sh [user-id] [num-recs] [quiet-mode]" 60 | exit 1 61 | fi 62 | 63 | main "$@" 64 | -------------------------------------------------------------------------------- /scripts/prepare_deploy_app.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 Google Inc. All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | set -euo pipefail 17 | 18 | source util.sh 19 | 20 | main() { 21 | # Get our working project, or exit if it's not set. 22 | local project_id="$(get_project_id)" 23 | if [[ -z "$project_id" ]]; then 24 | exit 1 25 | fi 26 | # Try to create an App Engine project in our selected region. 27 | # If it already exists, return a success ("|| true"). 28 | echo "gcloud app create --region=$REGION" 29 | gcloud app create --region="$REGION" || true 30 | 31 | # Prepare the necessary variables for substitution in our app configuration 32 | # template, and create a temporary file to hold the templatized version. 33 | local service_name="${project_id}.appspot.com" 34 | local config_id=$(get_latest_config_id "$service_name") 35 | export TEMP_FILE="${APP}_deploy.yaml" 36 | < "$APP" \ 37 | sed -E "s/SERVICE_NAME/${service_name}/g" \ 38 | | sed -E "s/SERVICE_CONFIG_ID/${config_id}/g" \ 39 | > "$TEMP_FILE" 40 | 41 | echo "To deploy: gcloud -q app deploy $TEMP_FILE" 42 | } 43 | 44 | # Defaults. 45 | APP="../app/app_template.yaml" 46 | REGION="us-east1" 47 | SERVICE_NAME="default" 48 | 49 | if [[ "$#" == 0 ]]; then 50 | : # Use defaults. 51 | elif [[ "$#" == 1 ]]; then 52 | APP="$1" 53 | elif [[ "$#" == 2 ]]; then 54 | APP="$1" 55 | REGION="$2" 56 | else 57 | echo "Wrong number of arguments specified." 58 | echo "Usage: deploy_app.sh [app-template] [region]" 59 | exit 1 60 | fi 61 | 62 | main "$@" 63 | -------------------------------------------------------------------------------- /airflow/deploy/config/k8s/deploy-sql-proxy.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
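#
# Deployment running the Cloud SQL proxy. The container forwards TCP port 3306 to
# the Cloud SQL instance identified by the sql_project, sql_region and sql_instance
# keys of the deployment-settings ConfigMap; the other Airflow pods reach it through
# the airflow-sqlproxy Service defined in service-sql-proxy.yaml.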
14 | 15 | apiVersion: extensions/v1beta1 16 | kind: Deployment 17 | metadata: 18 | labels: 19 | run: airflow-sqlproxy 20 | name: airflow-sqlproxy 21 | namespace: default 22 | spec: 23 | replicas: 1 24 | selector: 25 | matchLabels: 26 | run: airflow-sqlproxy 27 | strategy: 28 | rollingUpdate: 29 | maxSurge: 1 30 | maxUnavailable: 1 31 | type: RollingUpdate 32 | template: 33 | metadata: 34 | creationTimestamp: null 35 | labels: 36 | run: airflow-sqlproxy 37 | spec: 38 | volumes: 39 | - name: ssl-certs 40 | hostPath: 41 | path: /etc/ssl/certs 42 | containers: 43 | - image: gcr.io/cloud-airflow-public/airflow-sqlproxy:1.8.0 44 | volumeMounts: 45 | - name: ssl-certs 46 | mountPath: /etc/ssl/certs 47 | command: ["/cloud_sql_proxy", 48 | "-instances=$(SQL_PROJECT):$(SQL_REGION):$(SQL_INSTANCE)=tcp:0.0.0.0:3306"] 49 | imagePullPolicy: Always 50 | name: airflow-sqlproxy 51 | ports: 52 | - containerPort: 3306 53 | protocol: TCP 54 | resources: {} 55 | terminationMessagePath: /dev/termination-log 56 | env: 57 | - name: SQL_PROJECT 58 | valueFrom: 59 | configMapKeyRef: 60 | name: deployment-settings 61 | key: sql_project 62 | - name: SQL_REGION 63 | valueFrom: 64 | configMapKeyRef: 65 | name: deployment-settings 66 | key: sql_region 67 | - name: SQL_INSTANCE 68 | valueFrom: 69 | configMapKeyRef: 70 | name: deployment-settings 71 | key: sql_instance 72 | dnsPolicy: ClusterFirst 73 | restartPolicy: Always 74 | securityContext: {} 75 | terminationGracePeriodSeconds: 30 76 | -------------------------------------------------------------------------------- /airflow/deploy/config/k8s/job-testing.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
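#
# One-off Job that runs the airflow-testing image against the deployed environment,
# reusing the airflow-config ConfigMap plus the deployment-settings and
# deployment-secrets values used by the other workloads. The "managed-airflow-test"
# and "us-west1" arguments are passed to the test container unchanged.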
14 | 15 | apiVersion: batch/v1 16 | kind: Job 17 | metadata: 18 | name: airflow-testing-1.8.0 19 | namespace: default 20 | spec: 21 | template: 22 | metadata: 23 | name: airflow-testing 24 | spec: 25 | volumes: 26 | - name: ssl-certs 27 | hostPath: 28 | path: /etc/ssl/certs 29 | - name: airflow-config 30 | configMap: 31 | name: airflow-config 32 | containers: 33 | - image: gcr.io/cloud-airflow-public/airflow-testing:1.8.0 34 | imagePullPolicy: Always 35 | args: 36 | - "managed-airflow-test" 37 | - "us-west1" 38 | name: airflow-testing 39 | env: 40 | - name: GCS_BUCKET 41 | valueFrom: 42 | configMapKeyRef: 43 | name: deployment-settings 44 | key: gcs_bucket 45 | - name: AIRFLOW_HOME 46 | valueFrom: 47 | configMapKeyRef: 48 | name: deployment-settings 49 | key: airflow_home 50 | - name: DAGS_FOLDER 51 | valueFrom: 52 | configMapKeyRef: 53 | name: deployment-settings 54 | key: dags_folder 55 | - name: DAGS_FOLDER 56 | valueFrom: 57 | configMapKeyRef: 58 | name: deployment-settings 59 | key: dags_folder 60 | - name: SQL_DATABASE 61 | valueFrom: 62 | configMapKeyRef: 63 | name: deployment-settings 64 | key: sql_database 65 | - name: SQL_USER 66 | valueFrom: 67 | configMapKeyRef: 68 | name: deployment-settings 69 | key: sql_user 70 | - name: SQL_PASSWORD 71 | valueFrom: 72 | secretKeyRef: 73 | name: deployment-secrets 74 | key: sql_password 75 | - name: AIRFLOW__CORE__SQL_ALCHEMY_CONN 76 | value: mysql+mysqldb://$(SQL_USER):$(SQL_PASSWORD)@airflow-sqlproxy/$(SQL_DATABASE) 77 | volumeMounts: 78 | - name: airflow-config 79 | mountPath: /etc/airflow 80 | - name: ssl-certs 81 | mountPath: /etc/ssl/certs 82 | securityContext: 83 | privileged: true 84 | dnsPolicy: ClusterFirst 85 | restartPolicy: Never 86 | securityContext: {} 87 | terminationGracePeriodSeconds: 30 88 | -------------------------------------------------------------------------------- /airflow/deploy/config/k8s/job-init.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
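#
# One-off initialization Job. It runs /var/local/init.py from the airflow-base image
# to initialize the Airflow metadata database (reached through the airflow-sqlproxy
# Service) and create the web UI user, with credentials taken from the
# deployment-settings ConfigMap and the deployment-secrets Secret.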
14 | 15 | apiVersion: batch/v1 16 | kind: Job 17 | metadata: 18 | name: airflow-init 19 | namespace: default 20 | spec: 21 | template: 22 | metadata: 23 | name: airflow-init 24 | spec: 25 | volumes: 26 | - name: ssl-certs 27 | hostPath: 28 | path: /etc/ssl/certs 29 | - name: airflow-config 30 | configMap: 31 | name: airflow-config 32 | containers: 33 | - image: gcr.io/cloud-airflow-public/airflow-base-1.8.0 34 | imagePullPolicy: Always 35 | command: ["/var/local/init.py"] 36 | name: airflow-init 37 | env: 38 | - name: C_FORCE_ROOT 39 | value: "TRUE" 40 | - name: AIRFLOW_HOME 41 | value: /etc/airflow 42 | - name: SQL_DATABASE 43 | valueFrom: 44 | configMapKeyRef: 45 | name: deployment-settings 46 | key: sql_database 47 | - name: SQL_USER 48 | valueFrom: 49 | configMapKeyRef: 50 | name: deployment-settings 51 | key: sql_user 52 | - name: SQL_PASSWORD 53 | valueFrom: 54 | secretKeyRef: 55 | name: deployment-secrets 56 | key: sql_password 57 | - name: SQL_ROOT_USER 58 | valueFrom: 59 | configMapKeyRef: 60 | name: deployment-settings 61 | key: sql_root_user 62 | - name: SQL_ROOT_PASSWORD 63 | valueFrom: 64 | secretKeyRef: 65 | name: deployment-secrets 66 | key: sql_root_password 67 | - name: AIRFLOW__CORE__SQL_ALCHEMY_CONN 68 | value: mysql+mysqldb://$(SQL_USER):$(SQL_PASSWORD)@airflow-sqlproxy/$(SQL_DATABASE) 69 | - name: WEB_USER 70 | valueFrom: 71 | configMapKeyRef: 72 | name: deployment-settings 73 | key: web_ui_username 74 | - name: WEB_PASSWORD 75 | valueFrom: 76 | secretKeyRef: 77 | name: deployment-secrets 78 | key: web_ui_password 79 | - name: GCP_PROJECT 80 | valueFrom: 81 | configMapKeyRef: 82 | name: deployment-settings 83 | key: gcp_project 84 | volumeMounts: 85 | - name: airflow-config 86 | mountPath: /etc/airflow 87 | - name: ssl-certs 88 | mountPath: /etc/ssl/certs 89 | dnsPolicy: ClusterFirst 90 | restartPolicy: Never 91 | securityContext: {} 92 | terminationGracePeriodSeconds: 30 93 | -------------------------------------------------------------------------------- /airflow/deploy/config/k8s/deploy-workers.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
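#
# Deployment for the Celery workers (container args: ["worker"]). Workers read DAGs
# from the GCS bucket named in the deployment-settings ConfigMap, connect to the
# metadata database through airflow-sqlproxy, and take tasks from the Redis broker
# configured in airflow.cfg.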
14 | 15 | apiVersion: extensions/v1beta1 16 | kind: Deployment 17 | metadata: 18 | annotations: 19 | deployment.kubernetes.io/revision: "2" 20 | labels: 21 | run: airflow-worker 22 | name: airflow-worker 23 | namespace: default 24 | spec: 25 | replicas: 1 26 | selector: 27 | matchLabels: 28 | run: airflow-worker 29 | strategy: 30 | rollingUpdate: 31 | maxSurge: 1 32 | maxUnavailable: 1 33 | type: RollingUpdate 34 | template: 35 | metadata: 36 | labels: 37 | run: airflow-worker 38 | spec: 39 | volumes: 40 | - name: ssl-certs 41 | hostPath: 42 | path: /etc/ssl/certs 43 | - name: airflow-config 44 | configMap: 45 | name: airflow-config 46 | containers: 47 | - image: gcr.io/cloud-airflow-public/airflow-worker-1.8.0 48 | imagePullPolicy: Always 49 | name: airflow-worker 50 | env: 51 | - name: GCS_BUCKET 52 | valueFrom: 53 | configMapKeyRef: 54 | name: deployment-settings 55 | key: gcs_bucket 56 | - name: AIRFLOW_HOME 57 | valueFrom: 58 | configMapKeyRef: 59 | name: deployment-settings 60 | key: airflow_home 61 | - name: DAGS_FOLDER 62 | valueFrom: 63 | configMapKeyRef: 64 | name: deployment-settings 65 | key: dags_folder 66 | - name: C_FORCE_ROOT 67 | value: "TRUE" 68 | - name: SQL_DATABASE 69 | valueFrom: 70 | configMapKeyRef: 71 | name: deployment-settings 72 | key: sql_database 73 | - name: SQL_USER 74 | valueFrom: 75 | configMapKeyRef: 76 | name: deployment-settings 77 | key: sql_user 78 | - name: SQL_PASSWORD 79 | valueFrom: 80 | secretKeyRef: 81 | name: deployment-secrets 82 | key: sql_password 83 | - name: AIRFLOW__CORE__SQL_ALCHEMY_CONN 84 | value: mysql+mysqldb://$(SQL_USER):$(SQL_PASSWORD)@airflow-sqlproxy/$(SQL_DATABASE) 85 | volumeMounts: 86 | - name: airflow-config 87 | mountPath: /etc/airflow 88 | args: ["worker"] 89 | ports: 90 | - containerPort: 8793 91 | resources: {} 92 | terminationMessagePath: /dev/termination-log 93 | securityContext: 94 | privileged: true 95 | dnsPolicy: ClusterFirst 96 | restartPolicy: Always 97 | securityContext: {} 98 | terminationGracePeriodSeconds: 30 99 | -------------------------------------------------------------------------------- /airflow/deploy/config/k8s/deploy-webserver.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
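#
# Deployment for the Airflow webserver (container args: ["webserver"], listening on
# port 8080). It is exposed externally by the LoadBalancer Service defined in
# service-webserver.yaml and uses the same ConfigMap and Secret wiring as the
# scheduler and workers.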
14 | 15 | apiVersion: extensions/v1beta1 16 | kind: Deployment 17 | metadata: 18 | annotations: 19 | deployment.kubernetes.io/revision: "2" 20 | labels: 21 | run: airflow-webserver 22 | name: airflow-webserver 23 | namespace: default 24 | spec: 25 | replicas: 1 26 | selector: 27 | matchLabels: 28 | run: airflow-webserver 29 | strategy: 30 | rollingUpdate: 31 | maxSurge: 1 32 | maxUnavailable: 1 33 | type: RollingUpdate 34 | template: 35 | metadata: 36 | labels: 37 | run: airflow-webserver 38 | spec: 39 | volumes: 40 | - name: ssl-certs 41 | hostPath: 42 | path: /etc/ssl/certs 43 | - name: airflow-config 44 | configMap: 45 | name: airflow-config 46 | containers: 47 | - image: gcr.io/cloud-airflow-public/airflow-base-1.8.0 48 | imagePullPolicy: Always 49 | name: airflow-webserver 50 | env: 51 | - name: GCS_BUCKET 52 | valueFrom: 53 | configMapKeyRef: 54 | name: deployment-settings 55 | key: gcs_bucket 56 | - name: AIRFLOW_HOME 57 | valueFrom: 58 | configMapKeyRef: 59 | name: deployment-settings 60 | key: airflow_home 61 | - name: DAGS_FOLDER 62 | valueFrom: 63 | configMapKeyRef: 64 | name: deployment-settings 65 | key: dags_folder 66 | - name: SQL_DATABASE 67 | valueFrom: 68 | configMapKeyRef: 69 | name: deployment-settings 70 | key: sql_database 71 | - name: SQL_USER 72 | valueFrom: 73 | configMapKeyRef: 74 | name: deployment-settings 75 | key: sql_user 76 | - name: SQL_PASSWORD 77 | valueFrom: 78 | secretKeyRef: 79 | name: deployment-secrets 80 | key: sql_password 81 | - name: AIRFLOW__CORE__SQL_ALCHEMY_CONN 82 | value: mysql+mysqldb://$(SQL_USER):$(SQL_PASSWORD)@airflow-sqlproxy/$(SQL_DATABASE) 83 | volumeMounts: 84 | - name: airflow-config 85 | mountPath: /etc/airflow 86 | - name: ssl-certs 87 | mountPath: /etc/ssl/certs 88 | args: ["webserver"] 89 | ports: 90 | - containerPort: 8080 91 | protocol: TCP 92 | resources: {} 93 | terminationMessagePath: /dev/termination-log 94 | securityContext: 95 | privileged: true 96 | dnsPolicy: ClusterFirst 97 | restartPolicy: Always 98 | securityContext: {} 99 | terminationGracePeriodSeconds: 30 100 | -------------------------------------------------------------------------------- /airflow/deploy/config/k8s/deploy-scheduler.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | apiVersion: extensions/v1beta1 16 | kind: Deployment 17 | metadata: 18 | annotations: 19 | deployment.kubernetes.io/revision: "2" 20 | labels: 21 | run: airflow-scheduler 22 | name: airflow-scheduler 23 | namespace: default 24 | spec: 25 | replicas: 1 26 | selector: 27 | matchLabels: 28 | run: airflow-scheduler 29 | strategy: 30 | rollingUpdate: 31 | maxSurge: 1 32 | maxUnavailable: 1 33 | type: RollingUpdate 34 | template: 35 | metadata: 36 | labels: 37 | run: airflow-scheduler 38 | spec: 39 | volumes: 40 | - name: ssl-certs 41 | hostPath: 42 | path: /etc/ssl/certs 43 | - name: airflow-config 44 | configMap: 45 | name: airflow-config 46 | containers: 47 | - image: gcr.io/cloud-airflow-public/airflow-base-1.8.0 48 | imagePullPolicy: Always 49 | name: airflow-scheduler 50 | env: 51 | - name: GCS_BUCKET 52 | valueFrom: 53 | configMapKeyRef: 54 | name: deployment-settings 55 | key: gcs_bucket 56 | - name: AIRFLOW_HOME 57 | valueFrom: 58 | configMapKeyRef: 59 | name: deployment-settings 60 | key: airflow_home 61 | - name: DAGS_FOLDER 62 | valueFrom: 63 | configMapKeyRef: 64 | name: deployment-settings 65 | key: dags_folder 66 | - name: SQL_DATABASE 67 | valueFrom: 68 | configMapKeyRef: 69 | name: deployment-settings 70 | key: sql_database 71 | - name: SQL_USER 72 | valueFrom: 73 | configMapKeyRef: 74 | name: deployment-settings 75 | key: sql_user 76 | - name: SQL_PASSWORD 77 | valueFrom: 78 | secretKeyRef: 79 | name: deployment-secrets 80 | key: sql_password 81 | - name: AIRFLOW__CORE__SQL_ALCHEMY_CONN 82 | value: mysql+mysqldb://$(SQL_USER):$(SQL_PASSWORD)@airflow-sqlproxy/$(SQL_DATABASE) 83 | volumeMounts: 84 | - name: airflow-config 85 | mountPath: /etc/airflow 86 | - name: ssl-certs 87 | mountPath: /etc/ssl/certs 88 | args: ["scheduler"] 89 | resources: {} 90 | terminationMessagePath: /dev/termination-log 91 | securityContext: 92 | privileged: true 93 | livenessProbe: 94 | exec: 95 | command: 96 | - /var/local/log_checker.sh 97 | initialDelaySeconds: 120 98 | periodSeconds: 30 99 | dnsPolicy: ClusterFirst 100 | restartPolicy: Always 101 | securityContext: {} 102 | terminationGracePeriodSeconds: 30 103 | -------------------------------------------------------------------------------- /wals_ml_engine/mltrain.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | usage () { 17 | echo "usage: mltrain.sh [local | train | tune] [gs://]job_and_data_dir [path_to/].csv 18 | [--data-type ratings|web_views] 19 | [--delimiter ] 20 | [--use-optimized] 21 | [--headers] 22 | 23 | Use 'local' to train locally with a local data file, and 'train' and 'tune' to 24 | run on ML Engine. For ML Engine cloud jobs the data_dir must be prefixed with 25 | gs:// and point to an existing bucket, and the input file must reside on GCS. 
26 | 27 | Optional args: 28 | --data-type: Defaults to 'ratings', meaning MovieLens ratings from 0-5. 29 | Set to 'web_views' for Google Analytics data. 30 | --delimiter: CSV delimiter, defaults to '\t'. 31 | --use-optimized: Use optimized hyperparameters, default False. 32 | --headers: Default False for 'ratings', True for 'web_views'. 33 | 34 | Examples: 35 | 36 | # train locally with unoptimized hyperparams on included web views data set 37 | ./mltrain.sh local ../data recommendation_events.csv --data-type web_views 38 | 39 | # train on ML Engine with optimized hyperparams 40 | ./mltrain.sh train gs://rec_serve data/recommendation_events.csv --data-type web_views --use-optimized 41 | 42 | # tune hyperparams on ML Engine: 43 | ./mltrain.sh tune gs://rec_serve data/recommendation_events.csv --data-type web_views 44 | " 45 | 46 | } 47 | 48 | date 49 | 50 | TIME=`date +"%Y%m%d_%H%M%S"` 51 | 52 | # change to your preferred region 53 | REGION=us-central1 54 | 55 | if [[ $# -lt 3 ]]; then 56 | usage 57 | exit 1 58 | fi 59 | 60 | # set job vars 61 | TRAIN_JOB="$1" 62 | BUCKET="$2" 63 | DATA_FILE="$3" 64 | JOB_NAME=wals_ml_${TRAIN_JOB}_${TIME} 65 | 66 | # add additional args 67 | shift; shift; shift 68 | 69 | if [[ ${TRAIN_JOB} == "local" ]]; then 70 | 71 | ARGS="--train-file $BUCKET/${DATA_FILE} --verbose-logging $@" 72 | 73 | mkdir -p jobs/${JOB_NAME} 74 | 75 | gcloud ai-platform local train \ 76 | --module-name trainer.task \ 77 | --package-path trainer \ 78 | -- \ 79 | --job-dir jobs/${JOB_NAME} \ 80 | ${ARGS} 81 | 82 | elif [[ ${TRAIN_JOB} == "train" ]]; then 83 | 84 | ARGS="--gcs-bucket $BUCKET --train-file ${DATA_FILE} --verbose-logging $@" 85 | 86 | gcloud ai-platform jobs submit training ${JOB_NAME} \ 87 | --region $REGION \ 88 | --scale-tier=CUSTOM \ 89 | --job-dir ${BUCKET}/jobs/${JOB_NAME} \ 90 | --module-name trainer.task \ 91 | --package-path trainer \ 92 | --master-machine-type complex_model_m_gpu \ 93 | --config trainer/config/config_train.json \ 95 | --runtime-version 1.15 \ 96 | -- \ 97 | ${ARGS} 98 | 99 | elif [[ $TRAIN_JOB == "tune" ]]; then 100 | 101 | ARGS="--gcs-bucket $BUCKET --train-file ${DATA_FILE} --verbose-logging $@" 102 | 103 | # set configuration for tuning 104 | CONFIG_TUNE="trainer/config/config_tune.json" 105 | for i in $ARGS ; do 106 | if [[ "$i" == "web_views" ]]; then 107 | CONFIG_TUNE="trainer/config/config_tune_web.json" 108 | break 109 | fi 110 | done 111 | 112 | gcloud ai-platform jobs submit training ${JOB_NAME} \ 113 | --region ${REGION} \ 114 | --scale-tier=CUSTOM \ 115 | --job-dir ${BUCKET}/jobs/${JOB_NAME} \ 116 | --module-name trainer.task \ 117 | --package-path trainer \ 119 | --config ${CONFIG_TUNE} \ 120 | --master-machine-type complex_model_m_gpu \ 121 | --runtime-version 1.15 \ 122 | -- \ 123 | --hypertune \ 124 | ${ARGS} 125 | 126 | else 127 | usage 128 | fi 129 | 130 | date 131 | -------------------------------------------------------------------------------- /airflow/deploy/teardown_local.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Script for tearing down an Airflow deployment on GCP. 16 | 17 | This script essentially just tears down the GCP resources. 18 | It does not interact with Kubernetes or Airflow. 19 | 20 | The cluster to tear down is specified via the deployment settings 21 | 22 | Flags: 23 | --settings: Path to YAML file containing deployment settings. 24 | file emitted by the deployment script. 25 | """ 26 | 27 | import argparse 28 | import logging 29 | import os 30 | import subprocess 31 | import time 32 | 33 | import gcp_util 34 | import oauth2client.client 35 | import yaml 36 | 37 | 38 | def _delete_service(service_file): 39 | """Deletes a Kubernetes service. 40 | 41 | Args: 42 | service_file: a string specifying the path to the file used to 43 | create the service to be deleted 44 | """ 45 | subprocess.call(['kubectl', 'delete', '-f', service_file]) 46 | 47 | 48 | def main(): 49 | args = _parse_args() 50 | logging.basicConfig(level=args.log) 51 | 52 | credentials = ( 53 | oauth2client.client.GoogleCredentials.get_application_default()) 54 | 55 | with open(args.settings, 'r') as fd: 56 | settings = yaml.safe_load(fd) 57 | 58 | # Explicitly deleting the service causes Kubernetes to delete the 59 | # forwarding rule and target pool created for it, which would 60 | # otherwise stick around if we deleted the GKE cluster straightaway. 61 | # The need to do this is a consequence of a known GKE bug. 62 | logging.info('Halting webserver service') 63 | _delete_service(settings['kubernetes']['service']['webserver']) 64 | # Sleep so that the service can clean up networking resources 65 | # before the cluster gets torn down. 66 | # TODO(wwlian): Do something more elegant like polling the GCE API 67 | # for when the target pool and forwarding rule have successfully 68 | # been taken down. 
69 | time.sleep(10) 70 | 71 | logging.info('Deleting GKE cluster: %s', settings['id']) 72 | outcome = gcp_util.delete_gke_cluster( 73 | settings['id'], 74 | settings['project'], 75 | settings['kubernetes']['zone'], 76 | credentials) 77 | logging.info('Success' if outcome else 'Failed') 78 | 79 | logging.info('Deleting GCS bucket: %s', settings['id']) 80 | outcome = gcp_util.delete_gcs_bucket(settings['id'], credentials, 81 | force=True) 82 | logging.info('Success' if outcome else 'Failed') 83 | 84 | logging.info('Deleting Cloud SQL instance: %s', settings['id']) 85 | outcome = gcp_util.delete_sql_instance(settings['id'], 86 | settings['project'], 87 | credentials) 88 | logging.info('Success' if outcome else 'Failed') 89 | 90 | 91 | def _existing_filepath(val): 92 | if os.path.isfile(val): 93 | return val 94 | raise argparse.ArgumentError('Path is not an existing file: %s' % val) 95 | 96 | 97 | def _log_level_str(level): 98 | level_num = getattr(logging, level.upper(), None) 99 | if not isinstance(level_num, int): 100 | raise argparse.ArgumentError('Invalid log level: %s' % level) 101 | return level.upper() 102 | 103 | 104 | def _parse_args(): 105 | parser = argparse.ArgumentParser(description='Tear down Airflow on GCP') 106 | parser.add_argument( 107 | '--settings', help='Path to YAML file containing deployment settings', 108 | type=_existing_filepath, required=True) 109 | 110 | # Because argparse renders the helpstring (which we want to 111 | # contain the string-versions of levels) from the 112 | # choices set and runs the type function prior to checking for membership 113 | # in the choices set, the type function must return the same type as 114 | # members of the choices set. Therefore, the type is string rather than 115 | # enum (int). 116 | parser.add_argument('--log', 117 | help='The logging level to use. (default: INFO)', 118 | choices=set(['DEBUG', 'INFO', 'WARNING', 'ERROR', 119 | 'CRITICAL']), 120 | type=_log_level_str, default='INFO') 121 | return parser.parse_args() 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | -------------------------------------------------------------------------------- /airflow/dags/training.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
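# The AppEngineVersionOperator and MLEngineTrainingOperator imports below are
# not part of core Airflow 1.8; they are exposed under airflow.operators.* by
# the plugin classes defined in airflow/plugins/gae_admin_plugin.py and
# airflow/plugins/ml_engine_plugin.py.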
14 | 15 | """DAG definition for recserv model training.""" 16 | 17 | import airflow 18 | from airflow import DAG 19 | from airflow.contrib.operators.bigquery_operator import BigQueryOperator 20 | from airflow.contrib.operators.bigquery_to_gcs import BigQueryToCloudStorageOperator 21 | from airflow.hooks.base_hook import BaseHook 22 | from airflow.operators.app_engine_admin_plugin import AppEngineVersionOperator 23 | from airflow.operators.ml_engine_plugin import MLEngineTrainingOperator 24 | 25 | import datetime 26 | 27 | def _get_project_id(): 28 | """Get project ID from default GCP connection.""" 29 | 30 | extras = BaseHook.get_connection('google_cloud_default').extra_dejson 31 | key = 'extra__google_cloud_platform__project' 32 | if key in extras: 33 | project_id = extras[key] 34 | else: 35 | raise ('Must configure project_id in google_cloud_default ' 36 | 'connection from Airflow Console') 37 | return project_id 38 | 39 | PROJECT_ID = _get_project_id() 40 | 41 | # Data set constants, used in BigQuery tasks. You can change these 42 | # to conform to your data. 43 | DATASET = 'GA360_test' 44 | TABLE_NAME = 'ga_sessions_sample' 45 | ARTICLE_CUSTOM_DIMENSION = '10' 46 | 47 | # GCS bucket names and region, can also be changed. 48 | BUCKET = 'gs://recserve_' + PROJECT_ID 49 | REGION = 'us-east1' 50 | 51 | # The code package name comes from the model code in the wals_ml_engine 52 | # directory of the solution code base. 53 | PACKAGE_URI = BUCKET + '/code/wals_ml_engine-0.1.tar.gz' 54 | JOB_DIR = BUCKET + '/jobs' 55 | 56 | default_args = { 57 | 'owner': 'airflow', 58 | 'depends_on_past': False, 59 | 'start_date': airflow.utils.dates.days_ago(2), 60 | 'email': ['airflow@example.com'], 61 | 'email_on_failure': True, 62 | 'email_on_retry': False, 63 | 'retries': 5, 64 | 'retry_delay': datetime.timedelta(minutes=5) 65 | } 66 | 67 | # Default schedule interval using cronjob syntax - can be customized here 68 | # or in the Airflow console. 
69 | schedule_interval = '00 21 * * *' 70 | 71 | dag = DAG('recommendations_training_v1', default_args=default_args, 72 | schedule_interval=schedule_interval) 73 | 74 | dag.doc_md = __doc__ 75 | 76 | 77 | # 78 | # 79 | # Task Definition 80 | # 81 | # 82 | 83 | # BigQuery training data query 84 | 85 | bql=''' 86 | #legacySql 87 | SELECT 88 | fullVisitorId as clientId, 89 | ArticleID as contentId, 90 | (nextTime - hits.time) as timeOnPage, 91 | FROM( 92 | SELECT 93 | fullVisitorId, 94 | hits.time, 95 | MAX(IF(hits.customDimensions.index={0}, 96 | hits.customDimensions.value,NULL)) WITHIN hits AS ArticleID, 97 | LEAD(hits.time, 1) OVER (PARTITION BY fullVisitorId, visitNumber 98 | ORDER BY hits.time ASC) as nextTime 99 | FROM [{1}.{2}.{3}] 100 | WHERE hits.type = "PAGE" 101 | ) HAVING timeOnPage is not null and contentId is not null; 102 | ''' 103 | 104 | bql = bql.format(ARTICLE_CUSTOM_DIMENSION, PROJECT_ID, DATASET, TABLE_NAME) 105 | 106 | t1 = BigQueryOperator( 107 | task_id='bq_rec_training_data', 108 | bql=bql, 109 | destination_dataset_table='%s.recommendation_events' % DATASET, 110 | write_disposition='WRITE_TRUNCATE', 111 | dag=dag) 112 | 113 | # BigQuery training data export to GCS 114 | 115 | training_file = BUCKET + '/data/recommendation_events.csv' 116 | t2 = BigQueryToCloudStorageOperator( 117 | task_id='bq_export_op', 118 | source_project_dataset_table='%s.recommendation_events' % DATASET, 119 | destination_cloud_storage_uris=[training_file], 120 | export_format='CSV', 121 | dag=dag 122 | ) 123 | 124 | 125 | # ML Engine training job 126 | 127 | job_id = 'recserve_{0}'.format(datetime.datetime.now().strftime('%Y%m%d%H%M')) 128 | job_dir = BUCKET + '/jobs/' + job_id 129 | output_dir = BUCKET 130 | training_args = ['--job-dir', job_dir, 131 | '--train-file', training_file, 132 | '--output-dir', output_dir, 133 | '--data-type', 'web_views', 134 | '--use-optimized'] 135 | 136 | t3 = MLEngineTrainingOperator( 137 | task_id='ml_engine_training_op', 138 | project_id=PROJECT_ID, 139 | job_id=job_id, 140 | package_uris=[PACKAGE_URI], 141 | training_python_module='trainer.task', 142 | training_args=training_args, 143 | region=REGION, 144 | scale_tier='CUSTOM', 145 | master_type='complex_model_m_gpu', 146 | dag=dag 147 | ) 148 | 149 | # App Engine deploy new version 150 | 151 | t4 = AppEngineVersionOperator( 152 | task_id='app_engine_deploy_version', 153 | project_id=PROJECT_ID, 154 | service_id='default', 155 | region=REGION, 156 | service_spec=None, 157 | dag=dag 158 | ) 159 | 160 | t2.set_upstream(t1) 161 | t3.set_upstream(t2) 162 | t4.set_upstream(t3) 163 | 164 | -------------------------------------------------------------------------------- /wals_ml_engine/trainer/wals.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
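As a rough illustration of how the helpers defined below in wals.py compose, here is a minimal sketch on a toy ratings matrix; the hyperparameter values and the `import wals` form are assumptions for the example, not values from the repository:

import numpy as np
from scipy.sparse import coo_matrix
import wals  # this module (trainer/wals.py)

# toy 2-user x 3-item ratings matrix; zeros are unobserved entries
ratings = coo_matrix(np.array([[5.0, 0.0, 3.0],
                               [0.0, 4.0, 0.0]], dtype=np.float32))

# build the WALS graph, run a few row/column update sweeps, then score the fit
input_tensor, row_factor, col_factor, model = wals.wals_model(
    ratings, dim=2, reg=0.07, unobs=0.01, weights=False)
sess = wals.simple_train(model, input_tensor, num_iterations=10)
rmse = wals.get_rmse(row_factor.eval(session=sess),
                     col_factor.eval(session=sess), ratings)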
14 | 15 | """WALS model core functions.""" 16 | 17 | import math 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | from tensorflow.contrib.factorization.python.ops import factorization_ops 22 | 23 | 24 | def get_rmse(output_row, output_col, actual): 25 | """Compute rmse between predicted and actual ratings. 26 | 27 | Args: 28 | output_row: evaluated numpy array of row_factor 29 | output_col: evaluated numpy array of col_factor 30 | actual: coo_matrix of actual (test) values 31 | 32 | Returns: 33 | rmse 34 | """ 35 | mse = 0 36 | for i in xrange(actual.data.shape[0]): 37 | row_pred = output_row[actual.row[i]] 38 | col_pred = output_col[actual.col[i]] 39 | err = actual.data[i] - np.dot(row_pred, col_pred) 40 | mse += err * err 41 | mse /= actual.data.shape[0] 42 | rmse = math.sqrt(mse) 43 | return rmse 44 | 45 | 46 | def simple_train(model, input_tensor, num_iterations): 47 | """Helper function to train model on input for num_iterations. 48 | 49 | Args: 50 | model: WALSModel instance 51 | input_tensor: SparseTensor for input ratings matrix 52 | num_iterations: number of row/column updates to run 53 | 54 | Returns: 55 | tensorflow session, for evaluating results 56 | """ 57 | sess = tf.Session(graph=input_tensor.graph) 58 | 59 | with input_tensor.graph.as_default(): 60 | row_update_op = model.update_row_factors(sp_input=input_tensor)[1] 61 | col_update_op = model.update_col_factors(sp_input=input_tensor)[1] 62 | 63 | sess.run(model.initialize_op) 64 | sess.run(model.worker_init) 65 | for _ in xrange(num_iterations): 66 | sess.run(model.row_update_prep_gramian_op) 67 | sess.run(model.initialize_row_update_op) 68 | sess.run(row_update_op) 69 | sess.run(model.col_update_prep_gramian_op) 70 | sess.run(model.initialize_col_update_op) 71 | sess.run(col_update_op) 72 | 73 | return sess 74 | 75 | LOG_RATINGS = 0 76 | LINEAR_RATINGS = 1 77 | LINEAR_OBS_W = 100.0 78 | 79 | 80 | def make_wts(data, wt_type, obs_wt, feature_wt_exp, axis): 81 | """Generate observed item weights. 82 | 83 | Args: 84 | data: coo_matrix of ratings data 85 | wt_type: weight type, LOG_RATINGS or LINEAR_RATINGS 86 | obs_wt: linear weight factor 87 | feature_wt_exp: logarithmic weight factor 88 | axis: axis to make weights for, 1=rows/users, 0=cols/items 89 | 90 | Returns: 91 | vector of weights for cols (items) or rows (users) 92 | """ 93 | # recipricol of sum of number of items across rows (if axis is 0) 94 | frac = np.array(1.0/(data > 0.0).sum(axis)) 95 | 96 | # filter any invalid entries 97 | frac[np.ma.masked_invalid(frac).mask] = 0.0 98 | 99 | # normalize weights according to assumed distribution of ratings 100 | if wt_type == LOG_RATINGS: 101 | wts = np.array(np.power(frac, feature_wt_exp)).flatten() 102 | else: 103 | wts = np.array(obs_wt * frac).flatten() 104 | 105 | # check again for any numerically unstable entries 106 | assert np.isfinite(wts).sum() == wts.shape[0] 107 | return wts 108 | 109 | 110 | def wals_model(data, dim, reg, unobs, weights=False, 111 | wt_type=LINEAR_RATINGS, feature_wt_exp=None, 112 | obs_wt=LINEAR_OBS_W): 113 | """Create the WALSModel and input, row and col factor tensors. 
114 | 115 | Args: 116 | data: scipy coo_matrix of item ratings 117 | dim: number of latent factors 118 | reg: regularization constant 119 | unobs: unobserved item weight 120 | weights: True: set obs weights, False: obs weights = unobs weights 121 | wt_type: feature weight type, LOG_RATINGS (0) or LINEAR_RATINGS (1) 122 | feature_wt_exp: feature weight exponent constant 123 | obs_wt: feature weight linear factor constant 124 | 125 | Returns: 126 | input_tensor: tensor holding the input ratings matrix 127 | row_factor: tensor for row_factor 128 | col_factor: tensor for col_factor 129 | model: WALSModel instance 130 | """ 131 | row_wts = None 132 | col_wts = None 133 | 134 | num_rows = data.shape[0] 135 | num_cols = data.shape[1] 136 | 137 | if weights: 138 | assert feature_wt_exp is not None 139 | row_wts = np.ones(num_rows) 140 | col_wts = make_wts(data, wt_type, obs_wt, feature_wt_exp, 0) 141 | 142 | row_factor = None 143 | col_factor = None 144 | 145 | with tf.Graph().as_default(): 146 | 147 | input_tensor = tf.SparseTensor(indices=zip(data.row, data.col), 148 | values=(data.data).astype(np.float32), 149 | dense_shape=data.shape) 150 | 151 | model = factorization_ops.WALSModel(num_rows, num_cols, dim, 152 | unobserved_weight=unobs, 153 | regularization=reg, 154 | row_weights=row_wts, 155 | col_weights=col_wts) 156 | 157 | # retrieve the row and column factors 158 | row_factor = model.row_factors[0] 159 | col_factor = model.col_factors[0] 160 | 161 | return input_tensor, row_factor, col_factor, model 162 | -------------------------------------------------------------------------------- /app/recommendations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Recommendation generation module.""" 16 | 17 | import logging 18 | import numpy as np 19 | import os 20 | import pandas as pd 21 | 22 | import google.auth 23 | import google.cloud.storage as storage 24 | 25 | logging.basicConfig(level=logging.INFO) 26 | 27 | LOCAL_MODEL_PATH = '/tmp' 28 | 29 | ROW_MODEL_FILE = 'model/row.npy' 30 | COL_MODEL_FILE = 'model/col.npy' 31 | USER_MODEL_FILE = 'model/user.npy' 32 | ITEM_MODEL_FILE = 'model/item.npy' 33 | USER_ITEM_DATA_FILE = 'data/recommendation_events.csv' 34 | 35 | 36 | class Recommendations(object): 37 | """Provide recommendations from a pre-trained collaborative filtering model. 38 | 39 | Args: 40 | local_model_path: (string) local path to model files 41 | """ 42 | 43 | def __init__(self, local_model_path=LOCAL_MODEL_PATH): 44 | _, project_id = google.auth.default() 45 | self._bucket = 'recserve_' + project_id 46 | self._load_model(local_model_path) 47 | 48 | def _load_model(self, local_model_path): 49 | """Load recommendation model files from GCS.
50 | 51 | Args: 52 | local_model_path: (string) local path to model files 53 | """ 54 | # download files from GCS to local storage 55 | os.makedirs(os.path.join(local_model_path, 'model'), exist_ok=True) 56 | os.makedirs(os.path.join(local_model_path, 'data'), exist_ok=True) 57 | client = storage.Client() 58 | bucket = client.get_bucket(self._bucket) 59 | 60 | logging.info('Downloading blobs.') 61 | 62 | model_files = [ROW_MODEL_FILE, COL_MODEL_FILE, USER_MODEL_FILE, 63 | ITEM_MODEL_FILE, USER_ITEM_DATA_FILE] 64 | for model_file in model_files: 65 | blob = bucket.blob(model_file) 66 | with open(os.path.join(local_model_path, model_file), 'wb') as file_obj: 67 | blob.download_to_file(file_obj) 68 | 69 | logging.info('Finished downloading blobs.') 70 | 71 | # load npy arrays for user/item factors and user/item maps 72 | self.user_factor = np.load(os.path.join(local_model_path, ROW_MODEL_FILE)) 73 | self.item_factor = np.load(os.path.join(local_model_path, COL_MODEL_FILE)) 74 | self.user_map = np.load(os.path.join(local_model_path, USER_MODEL_FILE)) 75 | self.item_map = np.load(os.path.join(local_model_path, ITEM_MODEL_FILE)) 76 | 77 | logging.info('Finished loading arrays.') 78 | 79 | # load user_item history into pandas dataframe 80 | views_df = pd.read_csv(os.path.join(local_model_path, 81 | USER_ITEM_DATA_FILE), sep=',', header=0) 82 | self.user_items = views_df.groupby('clientId') 83 | 84 | logging.info('Finished loading model.') 85 | 86 | def get_recommendations(self, user_id, num_recs): 87 | """Given a user id, return list of num_recs recommended item ids. 88 | 89 | Args: 90 | user_id: (string) The user id 91 | num_recs: (int) The number of recommended items to return 92 | 93 | Returns: 94 | [item_id_0, item_id_1, ... item_id_k-1]: The list of k recommended items, 95 | if user id is found. 96 | None: The user id was not found. 97 | """ 98 | article_recommendations = None 99 | 100 | # map user id into ratings matrix user index 101 | user_idx = np.searchsorted(self.user_map, user_id) 102 | 103 | if user_idx: 104 | # get already viewed items from views dataframe 105 | already_rated = self.user_items.get_group(user_id).contentId 106 | already_rated_idx = [np.searchsorted(self.item_map, i) 107 | for i in already_rated] 108 | 109 | # generate list of recommended article indexes from model 110 | recommendations = generate_recommendations(user_idx, already_rated_idx, 111 | self.user_factor, 112 | self.item_factor, 113 | num_recs) 114 | 115 | # map article indexes back to article ids 116 | article_recommendations = [self.item_map[i] for i in recommendations] 117 | 118 | return article_recommendations 119 | 120 | 121 | def generate_recommendations(user_idx, user_rated, row_factor, col_factor, k): 122 | """Generate recommendations for a user. 
123 | 124 | Args: 125 | user_idx: the row index of the user in the ratings matrix, 126 | 127 | user_rated: the list of item indexes (column indexes in the ratings matrix) 128 | previously rated by that user (which will be excluded from the 129 | recommendations), 130 | 131 | row_factor: the row factors of the recommendation model 132 | 133 | col_factor: the column factors of the recommendation model 134 | 135 | k: number of recommendations requested 136 | 137 | Returns: 138 | list of k item indexes with the predicted highest rating, 139 | excluding those that the user has already rated 140 | """ 141 | 142 | # bounds checking for args 143 | assert (row_factor.shape[0] - len(user_rated)) >= k 144 | 145 | # retrieve user factor 146 | user_f = row_factor[user_idx] 147 | 148 | # dot product of item factors with user factor gives predicted ratings 149 | pred_ratings = col_factor.dot(user_f) 150 | 151 | # find candidate recommended item indexes sorted by predicted rating 152 | k_r = k + len(user_rated) 153 | candidate_items = np.argsort(pred_ratings)[-k_r:] 154 | 155 | # remove previously rated items and take top k 156 | recommended_items = [i for i in candidate_items if i not in user_rated] 157 | recommended_items = recommended_items[-k:] 158 | 159 | # flip to sort highest rated first 160 | recommended_items.reverse() 161 | 162 | return recommended_items 163 | 164 | -------------------------------------------------------------------------------- /wals_ml_engine/trainer/task.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
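To make the ranking step in app/recommendations.py above concrete, here is a minimal sketch of the same argsort-and-exclude logic used by generate_recommendations, with toy factors and hypothetical item indices:

import numpy as np

row_factor = np.array([[1.0, 0.5]])             # one user, 2 latent factors
col_factor = np.array([[0.9, 0.1],              # four items, 2 latent factors
                       [0.2, 0.8],
                       [0.7, 0.7],
                       [0.1, 0.1]])
user_idx, user_rated, k = 0, [2], 2             # item 2 was already viewed

pred_ratings = col_factor.dot(row_factor[user_idx])
candidate_items = np.argsort(pred_ratings)[-(k + len(user_rated)):]
recommended = [i for i in candidate_items if i not in user_rated][-k:]
recommended.reverse()                           # highest predicted rating first
print(recommended)                              # [0, 1]: item 2 is excluded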
14 | 15 | """Job entry point for ML Engine.""" 16 | 17 | import argparse 18 | import json 19 | import os 20 | import tensorflow as tf 21 | 22 | import model 23 | import util 24 | import wals 25 | 26 | 27 | def main(args): 28 | # process input file 29 | input_file = util.ensure_local_file(args['train_file']) 30 | user_map, item_map, tr_sparse, test_sparse = model.create_test_and_train_sets( 31 | args, input_file, args['data_type']) 32 | 33 | # train model 34 | output_row, output_col = model.train_model(args, tr_sparse) 35 | 36 | # save trained model to job directory 37 | model.save_model(args, user_map, item_map, output_row, output_col) 38 | 39 | # log results 40 | train_rmse = wals.get_rmse(output_row, output_col, tr_sparse) 41 | test_rmse = wals.get_rmse(output_row, output_col, test_sparse) 42 | 43 | if args['hypertune']: 44 | # write test_rmse metric for hyperparam tuning 45 | util.write_hptuning_metric(args, test_rmse) 46 | 47 | tf.logging.info('train RMSE = %.2f' % train_rmse) 48 | tf.logging.info('test RMSE = %.2f' % test_rmse) 49 | 50 | 51 | def parse_arguments(): 52 | """Parse job arguments.""" 53 | 54 | parser = argparse.ArgumentParser() 55 | # required input arguments 56 | parser.add_argument( 57 | '--train-file', 58 | help='path to training data', 59 | required=True 60 | ) 61 | parser.add_argument( 62 | '--job-dir', 63 | help='GCS location to write checkpoints and export models', 64 | required=True 65 | ) 66 | 67 | # hyper params for model 68 | parser.add_argument( 69 | '--latent_factors', 70 | type=int, 71 | help='Number of latent factors', 72 | ) 73 | parser.add_argument( 74 | '--num_iters', 75 | type=int, 76 | help='Number of iterations for alternating least squares factorization', 77 | ) 78 | parser.add_argument( 79 | '--regularization', 80 | type=float, 81 | help='L2 regularization factor', 82 | ) 83 | parser.add_argument( 84 | '--unobs_weight', 85 | type=float, 86 | help='Weight for unobserved values', 87 | ) 88 | parser.add_argument( 89 | '--wt_type', 90 | type=int, 91 | help='Rating weight type (0=linear, 1=log)', 92 | default=wals.LINEAR_RATINGS 93 | ) 94 | parser.add_argument( 95 | '--feature_wt_factor', 96 | type=float, 97 | help='Feature weight factor (linear ratings)', 98 | ) 99 | parser.add_argument( 100 | '--feature_wt_exp', 101 | type=float, 102 | help='Feature weight exponent (log ratings)', 103 | ) 104 | 105 | # other args 106 | parser.add_argument( 107 | '--gcs-bucket', 108 | help='gcs bucket path for training data', 109 | required=False 110 | ) 111 | parser.add_argument( 112 | '--output-dir', 113 | help='GCS location to write model, overriding job-dir', 114 | ) 115 | parser.add_argument( 116 | '--verbose-logging', 117 | default=False, 118 | action='store_true', 119 | help='Switch to turn on or off verbose logging and warnings' 120 | ) 121 | parser.add_argument( 122 | '--hypertune', 123 | default=False, 124 | action='store_true', 125 | help='Switch to turn on or off hyperparam tuning' 126 | ) 127 | parser.add_argument( 128 | '--data-type', 129 | type=str, 130 | default='ratings', 131 | help='Data type, one of ratings (e.g. 
MovieLens) or web_views (GA data)' 132 | ) 133 | parser.add_argument( 134 | '--delimiter', 135 | type=str, 136 | default='\t', 137 | help='Delimiter for csv data files' 138 | ) 139 | parser.add_argument( 140 | '--headers', 141 | default=False, 142 | action='store_true', 143 | help='Input file has a header row' 144 | ) 145 | parser.add_argument( 146 | '--use-optimized', 147 | default=False, 148 | action='store_true', 149 | help='Use optimized hyperparameters' 150 | ) 151 | 152 | args = parser.parse_args() 153 | arguments = args.__dict__ 154 | 155 | # add optional gcs path to training data 156 | if args.gcs_bucket: 157 | args.train_file = os.path.join(args.gcs_bucket, args.train_file) 158 | 159 | # set job name as job directory name 160 | job_dir = args.job_dir 161 | job_dir = job_dir[:-1] if job_dir.endswith('/') else job_dir 162 | job_name = os.path.basename(job_dir) 163 | 164 | # set output directory for model 165 | if args.hypertune: 166 | # if tuning, join the trial number to the output path 167 | config = json.loads(os.environ.get('TF_CONFIG', '{}')) 168 | trial = config.get('task', {}).get('trial', '') 169 | output_dir = os.path.join(job_dir, trial) 170 | elif args.output_dir: 171 | output_dir = args.output_dir 172 | else: 173 | output_dir = job_dir 174 | 175 | if args.verbose_logging: 176 | tf.logging.set_verbosity(tf.logging.INFO) 177 | 178 | # Find out if there's a task value on the environment variable. 179 | # If there is none or it is empty define a default one. 180 | env = json.loads(os.environ.get('TF_CONFIG', '{}')) 181 | task_data = env.get('task') or {'type': 'master', 'index': 0} 182 | 183 | # update default params with any args provided to task 184 | params = model.DEFAULT_PARAMS 185 | params.update({k: arg for k, arg in arguments.iteritems() if arg is not None}) 186 | if args.use_optimized: 187 | if args.data_type == 'web_views': 188 | params.update(model.OPTIMIZED_PARAMS_WEB) 189 | else: 190 | params.update(model.OPTIMIZED_PARAMS) 191 | params.update(task_data) 192 | params.update({'output_dir': output_dir}) 193 | params.update({'job_name': job_name}) 194 | 195 | # For web_view data, default to using the exponential weight formula 196 | # with feature weight exp. 197 | # For movie lens data, default to the linear weight formula. 198 | if args.data_type == 'web_views': 199 | params.update({'wt_type': wals.LOG_RATINGS}) 200 | 201 | return params 202 | 203 | 204 | if __name__ == '__main__': 205 | job_args = parse_arguments() 206 | main(job_args) 207 | 208 | 209 | -------------------------------------------------------------------------------- /airflow/plugins/ml_engine_plugin.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Airflow plugin for ML Engine, backported from v1.9.""" 16 | 17 | from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook 18 | from airflow.exceptions import AirflowException 19 | from airflow.models import BaseOperator 20 | from airflow.plugins_manager import AirflowPlugin 21 | from airflow.utils.decorators import apply_defaults 22 | 23 | from apiclient.discovery import build 24 | from googleapiclient import errors 25 | 26 | import logging 27 | from oauth2client.client import GoogleCredentials 28 | import re 29 | import time 30 | 31 | 32 | class MLEngineHook(GoogleCloudBaseHook): 33 | """Hook for ML Engine.""" 34 | 35 | def __init__(self, gcp_conn_id='google_cloud_default', delegate_to=None): 36 | super(MLEngineHook, self).__init__(gcp_conn_id, delegate_to) 37 | self._mlengine = self.get_conn() 38 | 39 | def normalize_mlengine_job_id(self, job_id): 40 | """Replaces invalid MLEngine job_id characters with '_'. 41 | 42 | This also adds a leading 'z' in case job_id starts with an invalid 43 | character. 44 | 45 | Args: 46 | job_id: A job_id str that may have invalid characters. 47 | 48 | Returns: 49 | A valid job_id representation. 50 | """ 51 | match = re.search(r'\d', job_id) 52 | if match and match.start() is 0: 53 | job_id = 'z_{}'.format(job_id) 54 | return re.sub('[^0-9a-zA-Z]+', '_', job_id) 55 | 56 | def get_conn(self): 57 | """Returns a Google MLEngine service object.""" 58 | credentials = GoogleCredentials.get_application_default() 59 | return build('ml', 'v1', credentials=credentials) 60 | 61 | def create_job(self, project_id, job, use_existing_job_fn=None): 62 | """Launches a MLEngine job and wait for it to reach a terminal state. 63 | 64 | Args: 65 | project_id: project id 66 | job: job name 67 | use_existing_job_fn: existing job to use 68 | 69 | Returns: 70 | The MLEngine job object if the job successfully reach a 71 | terminal state (which might be FAILED or CANCELLED state). 72 | """ 73 | request = self._mlengine.projects().jobs().create( 74 | parent='projects/{}'.format(project_id), 75 | body=job) 76 | job_id = job['jobId'] 77 | 78 | try: 79 | request.execute() 80 | except errors.HttpError as e: 81 | # 409 means there is an existing job with the same job ID. 82 | if e.resp.status == 409: 83 | if use_existing_job_fn is not None: 84 | existing_job = self._get_job(project_id, job_id) 85 | if not use_existing_job_fn(existing_job): 86 | logging.error( 87 | 'Job with job_id %s already exist, but it does ' 88 | 'not match our expectation: %s', 89 | job_id, existing_job 90 | ) 91 | raise 92 | logging.info( 93 | 'Job with job_id %s already exist. Will waiting for it to finish', 94 | job_id 95 | ) 96 | else: 97 | logging.error('Failed to create MLEngine job: %s', e) 98 | raise 99 | 100 | return self._wait_for_job_done(project_id, job_id) 101 | 102 | def _get_job(self, project_id, job_id): 103 | """Gets a MLEngine job based on the job name. 104 | 105 | Args: 106 | project_id: project id 107 | job_id: job id 108 | 109 | Returns: 110 | MLEngine job object if succeed. 
111 | 112 | Raises: 113 | apiclient.errors.HttpError: if HTTP error is returned from server 114 | """ 115 | job_name = 'projects/{}/jobs/{}'.format(project_id, job_id) 116 | request = self._mlengine.projects().jobs().get(name=job_name) 117 | while True: 118 | try: 119 | return request.execute() 120 | except errors.HttpError as e: 121 | if e.resp.status == 429: 122 | # polling after 30 seconds when quota failure occurs 123 | time.sleep(30) 124 | else: 125 | logging.error('Failed to get MLEngine job: %s', e) 126 | raise 127 | 128 | def _wait_for_job_done(self, project_id, job_id, interval=30): 129 | """Waits for the Job to reach a terminal state. 130 | 131 | This method will periodically check the job state until the job reach 132 | a terminal state. 133 | 134 | Args: 135 | project_id: project id 136 | job_id: job id 137 | interval: check interval in seconds 138 | 139 | Returns: 140 | MLEngine job object if succeed. 141 | 142 | Raises: 143 | apiclient.errors.HttpError: if HTTP error is returned when getting 144 | the job 145 | """ 146 | assert interval > 0 147 | while True: 148 | job = self._get_job(project_id, job_id) 149 | if job['state'] in ['SUCCEEDED', 'FAILED', 'CANCELLED']: 150 | return job 151 | time.sleep(interval) 152 | 153 | 154 | class MLEngineTrainingOperator(BaseOperator): 155 | """Operator for launching a MLEngine training job. 156 | """ 157 | 158 | @apply_defaults 159 | def __init__(self, 160 | project_id, 161 | job_id, 162 | package_uris, 163 | training_python_module, 164 | training_args, 165 | region, 166 | scale_tier=None, 167 | master_type=None, 168 | gcp_conn_id='google_cloud_default', 169 | delegate_to=None, 170 | mode='PRODUCTION', 171 | *args, 172 | **kwargs): 173 | super(MLEngineTrainingOperator, self).__init__(*args, **kwargs) 174 | self._project_id = project_id 175 | self._job_id = job_id 176 | self._package_uris = package_uris 177 | self._training_python_module = training_python_module 178 | self._training_args = training_args 179 | self._region = region 180 | self._scale_tier = scale_tier 181 | self._master_type = master_type 182 | self._gcp_conn_id = gcp_conn_id 183 | self._delegate_to = delegate_to 184 | self._mode = mode 185 | 186 | if not self._project_id: 187 | raise AirflowException('Google Cloud project id is required.') 188 | if not self._job_id: 189 | raise AirflowException( 190 | 'An unique job id is required for Google MLEngine training ' 191 | 'job.') 192 | if not package_uris: 193 | raise AirflowException( 194 | 'At least one python package is required for MLEngine ' 195 | 'Training job.') 196 | if not training_python_module: 197 | raise AirflowException( 198 | 'Python module name to run after installing required ' 199 | 'packages is required.') 200 | if not self._region: 201 | raise AirflowException('Google Compute Engine region is required.') 202 | 203 | def execute(self, context): 204 | hook = MLEngineHook( 205 | gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to) 206 | 207 | job_id = hook.normalize_mlengine_job_id(self._job_id) 208 | training_request = { 209 | 'jobId': job_id, 210 | 'trainingInput': { 211 | 'scaleTier': self._scale_tier, 212 | 'packageUris': self._package_uris, 213 | 'pythonModule': self._training_python_module, 214 | 'region': self._region, 215 | 'args': self._training_args, 216 | 'masterType': self._master_type 217 | } 218 | } 219 | 220 | if self._mode == 'DRY_RUN': 221 | logging.info('In dry_run mode.') 222 | logging.info('MLEngine Training job request is: %s', training_request) 223 | return 224 | 225 | # Helper 
method to check if the existing job's training input is the 226 | # same as the request we get here. 227 | def check_existing_job(existing_job): 228 | return (existing_job.get('trainingInput', None) 229 | == training_request['trainingInput']) 230 | 231 | try: 232 | finished_training_job = hook.create_job( 233 | self._project_id, training_request, check_existing_job) 234 | except errors.HttpError: 235 | raise 236 | 237 | if finished_training_job['state'] != 'SUCCEEDED': 238 | logging.error('MLEngine training job failed: %s', 239 | str(finished_training_job)) 240 | raise RuntimeError(finished_training_job['errorMessage']) 241 | 242 | 243 | # Plugin class for GoogleMLEngine 244 | class GoogleMLEnginePlugin(AirflowPlugin): 245 | name = 'ml_engine_plugin' 246 | operators = [MLEngineTrainingOperator] 247 | hooks = [MLEngineHook] 248 | executors = [] 249 | macros = [] 250 | admin_views = [] 251 | flask_blueprints = [] 252 | menu_links = [] 253 | 254 | -------------------------------------------------------------------------------- /airflow/plugins/gae_admin_plugin.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Airflow plugin for Google App Engine Admin functions.""" 16 | 17 | from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook 18 | from airflow.exceptions import AirflowException 19 | from airflow.models import BaseOperator 20 | from airflow.plugins_manager import AirflowPlugin 21 | from airflow.utils.decorators import apply_defaults 22 | from apiclient.discovery import build 23 | from datetime import datetime 24 | from googleapiclient import errors 25 | 26 | import logging 27 | from oauth2client.client import GoogleCredentials 28 | import time 29 | 30 | 31 | class AppEngineAdminHook(GoogleCloudBaseHook): 32 | """Hook for App Engine Flex Admin.""" 33 | 34 | def __init__(self, gcp_conn_id='google_cloud_default', delegate_to=None): 35 | super(AppEngineAdminHook, self).__init__(gcp_conn_id, delegate_to) 36 | self._gaeadmin = self.get_ae_conn() 37 | self._svcadmin = self.get_svc_conn() 38 | 39 | def get_ae_conn(self): 40 | """Returns: a App Engine service object.""" 41 | credentials = GoogleCredentials.get_application_default() 42 | return build('appengine', 'v1', credentials=credentials) 43 | 44 | def get_svc_conn(self): 45 | """Returns: a Services Management service object.""" 46 | credentials = GoogleCredentials.get_application_default() 47 | return build('servicemanagement', 'v1', credentials=credentials) 48 | 49 | def create_version(self, project_id, service_id, version_spec): 50 | """Creates new service version on App Engine Engine. 51 | 52 | Args: 53 | project_id: project id 54 | service_id: service id 55 | version_spec: app version spec 56 | 57 | Returns: 58 | The operation if the version was created successfully and 59 | raises an error otherwise. 
60 | """ 61 | create_request = self._gaeadmin.apps().services().versions().create( 62 | appsId=project_id, servicesId=service_id, body=version_spec) 63 | response = create_request.execute() 64 | op_name = response['name'].split('/')[-1] 65 | return self._wait_for_operation_done(project_id, op_name) 66 | 67 | def migrate_traffic(self, project_id, service_id, new_version): 68 | """Migrate AE traffic from current version to new version. 69 | 70 | Args: 71 | project_id: project id 72 | service_id: service id 73 | new_version: new version id 74 | 75 | Returns: 76 | the operation if the migration was successful and 77 | raises an error otherwise. 78 | """ 79 | split_config = {'split': {'allocations': {new_version: '1'}}} 80 | migrate_request = self._gaeadmin.apps().services().patch( 81 | appsId=project_id, servicesId=service_id, updateMask='split', 82 | body=split_config) 83 | response = migrate_request.execute() 84 | op_name = response['name'].split('/')[-1] 85 | return self._wait_for_operation_done(project_id, op_name) 86 | 87 | def get_endpoint_config(self, service_id): 88 | """Get latest endpoint config for an endpoint service. 89 | 90 | Args: 91 | service_id: service id 92 | 93 | Returns: 94 | the config version if successful and raises an error otherwise. 95 | """ 96 | resource = self._svcadmin.services().rollouts() 97 | list_request = resource.list(serviceName=service_id) 98 | response = list_request.execute() 99 | config_id = response['rollouts'][0]['rolloutId'] 100 | 101 | return config_id 102 | 103 | def get_version(self, project_id, service_id, version): 104 | """Get spec for a version of a service on App Engine Engine. 105 | 106 | Args: 107 | project_id: project id 108 | service_id: service id 109 | version: version id 110 | 111 | Returns: 112 | the version spec if successful and raises an error otherwise. 113 | """ 114 | resource = self._gaeadmin.apps().services().versions() 115 | get_request = resource.get(appsId=project_id, servicesId=service_id, 116 | versionsId=version, view='FULL') 117 | response = get_request.execute() 118 | 119 | return response 120 | 121 | def get_version_identifiers(self, project_id, service_id): 122 | """Get list of versions of a service on App Engine Engine. 123 | 124 | Args: 125 | project_id: project id 126 | service_id: service id 127 | 128 | Returns: 129 | the list of version identifiers if successful and raises an error otherwise. 130 | """ 131 | request = self._gaeadmin.apps().services().versions().list(appsId=project_id, 132 | servicesId=service_id) 133 | versions = [] 134 | while request is not None: 135 | versions_doc = request.execute() 136 | versions.extend([v['id'] for v in versions_doc['versions']]) 137 | request = self._gaeadmin.apps().services().versions().list_next(request, 138 | versions_doc) 139 | 140 | return versions 141 | 142 | def _get_operation(self, project_id, op_name): 143 | """Gets an AppEngine operation based on the operation name. 144 | 145 | Args: 146 | project_id: project id 147 | op_name: operation name 148 | 149 | Returns: 150 | AppEngine operation object if succeed. 151 | 152 | Raises: 153 | apiclient.errors.HttpError: if HTTP error is returned from server 154 | """ 155 | resource = self._gaeadmin.apps().operations() 156 | request = resource.get(appsId=project_id, operationsId=op_name) 157 | return request.execute() 158 | 159 | def _wait_for_operation_done(self, project_id, op_name, interval=30): 160 | """Waits for the Operation to reach a terminal state. 
161 | 162 | This method will periodically check the operation state until the operation reaches 163 | a terminal state. 164 | 165 | Args: 166 | project_id: project id 167 | op_name: operation name 168 | interval: check interval in seconds 169 | 170 | Returns: 171 | AppEngine operation object if successful. 172 | 173 | Raises: 174 | apiclient.errors.HttpError: if HTTP error is returned when getting 175 | the operation 176 | """ 177 | assert interval > 0 178 | while True: 179 | operation = self._get_operation(project_id, op_name) 180 | if 'done' in operation and operation['done']: 181 | return operation 182 | time.sleep(interval) 183 | 184 | 185 | class AppEngineVersionOperator(BaseOperator): 186 | """Operator for creating a new AppEngine flex service version.""" 187 | 188 | @apply_defaults 189 | def __init__(self, 190 | project_id, 191 | service_id, 192 | region, 193 | service_spec=None, 194 | gcp_conn_id='google_cloud_default', 195 | delegate_to=None, 196 | *args, 197 | **kwargs): 198 | super(AppEngineVersionOperator, self).__init__(*args, **kwargs) 199 | self._project_id = project_id 200 | self._service_id = service_id 201 | self._region = region 202 | 203 | self._service_spec = service_spec 204 | self._gcp_conn_id = gcp_conn_id 205 | self._delegate_to = delegate_to 206 | 207 | if not self._project_id: 208 | raise AirflowException('Cloud project id is required.') 209 | if not self._service_id: 210 | raise AirflowException('App Engine service name is required.') 211 | if not self._region: 212 | raise AirflowException('Compute Engine region is required.') 213 | 214 | def execute(self, context): 215 | hook = AppEngineAdminHook( 216 | gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to) 217 | 218 | # if version spec is not provided, use spec for latest version 219 | if self._service_spec is None: 220 | # get version spec for latest version, assuming version ids are sortable 221 | version_list = hook.get_version_identifiers(self._project_id, 222 | self._service_id) 223 | latest_version = max(version_list) 224 | version_spec = hook.get_version(self._project_id, self._service_id, 225 | latest_version) 226 | 227 | # get endpoints config id 228 | endpoint_service = '{0}.appspot.com'.format(self._project_id) 229 | config_id = hook.get_endpoint_config(endpoint_service) 230 | 231 | # clean irrelevant params 232 | version_spec.pop('name', None) 233 | version_spec.pop('threadsafe', None) 234 | version_spec.pop('servingStatus', None) 235 | version_spec.pop('createTime', None) 236 | version_spec.pop('createdBy', None) 237 | version_spec.pop('runtimeApiVersion', None) 238 | version_spec.pop('versionUrl', None) 239 | version_spec.pop('betaSettings', None) 240 | 241 | # fix docker container ref 242 | container_ref = version_spec['deployment']['container']['image'] 243 | container_ref = container_ref.split('@sha')[0] 244 | version_spec['deployment']['container']['image'] = container_ref 245 | 246 | # add endpoint service params 247 | version_spec.update({'endpointsApiService': { 248 | 'name': '{0}.appspot.com'.format(self._project_id), 249 | 'configId': config_id 250 | }}) 251 | else: 252 | version_spec = self._service_spec 253 | 254 | if 'id' not in version_spec: 255 | # generate version id and add to params 256 | now = datetime.now() 257 | version_spec['id'] = '{0}t{1}'.format(now.strftime('%Y%m%d'), 258 | now.strftime('%H%M%S')) 259 | 260 | # deploy new version 261 | try: 262 | finished_version_op = hook.create_version(self._project_id, 263 | self._service_id, version_spec) 264 | except errors.HttpError:
265 | raise 266 | 267 | if 'error' in finished_version_op: 268 | logging.error('AppEngine version deploy failed: %s', 269 | str(finished_version_op)) 270 | raise RuntimeError(finished_version_op['error']['message']) 271 | 272 | # migrate traffic to new version 273 | new_version = version_spec['id'] 274 | try: 275 | finished_migrate_op = hook.migrate_traffic(self._project_id, 276 | self._service_id, new_version) 277 | except errors.HttpError: 278 | raise 279 | 280 | if 'error' in finished_migrate_op: 281 | logging.error('AppEngine version migrate failed: %s', 282 | str(finished_version_op)) 283 | raise RuntimeError(finished_migrate_op['error']['message']) 284 | 285 | 286 | # Plugin class for GAEAdmin 287 | class AppEngineAdminPlugin(AirflowPlugin): 288 | name = 'app_engine_admin_plugin' 289 | operators = [AppEngineVersionOperator] 290 | hooks = [AppEngineAdminHook] 291 | executors = [] 292 | macros = [] 293 | admin_views = [] 294 | flask_blueprints = [] 295 | menu_links = [] 296 | 297 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Recommendations on GCP with TensorFlow 1.x and WALS 2 | 3 | This project deploys a solution for a recommendation service on GCP, using the WALS 4 | algorithm in the contrib.factorization module of TensorFlow 1.x. 
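For orientation, the following is a condensed sketch of what the training code in wals_ml_engine/trainer/model.py does with that module; the toy matrix is illustrative only, and the hyperparameter values shown are the project defaults from DEFAULT_PARAMS:

    # Minimal sketch of the WALS training flow (simplified from trainer/model.py).
    import numpy as np
    from scipy.sparse import coo_matrix
    import wals  # the project's wrapper around tf.contrib.factorization

    # A toy 3-user x 4-item matrix; zeros are unobserved entries.
    tr_sparse = coo_matrix(np.array([[5., 0., 0., 1.],
                                     [0., 3., 0., 0.],
                                     [4., 0., 2., 0.]]))

    # Positional args follow train_model(): sparse matrix, latent_factors,
    # regularization, unobs_weight, weights flag, wt_type, feature_wt_exp,
    # feature_wt_factor.
    input_tensor, row_factor, col_factor, model = wals.wals_model(
        tr_sparse, 5, 0.07, 0.01, True, 0, 0.08, 130.0)

    session = wals.simple_train(model, input_tensor, 20)  # 20 iterations
    user_factors = row_factor.eval(session=session)  # shape (n_users, 5)
    item_factors = col_factor.eval(session=session)  # shape (n_items, 5)
    session.close()
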
Components include: 5 | 6 | - Recommendation model code, and scripts to train and tune the model on ML Engine 7 | - A REST endpoint using [Google Cloud Endpoints](https://cloud.google.com/endpoints/) for serving recommendations 8 | - An Airflow server managed by Cloud Composer (or alternatively, running on GKE) for running scheduled model training 9 | 10 | 11 | ## Before you begin 12 | 13 | 1. Create a new [Cloud Platform project](https://console.cloud.google.com/projectcreate). 14 | 15 | 2. [Enable billing](https://support.google.com/cloud/answer/6293499#enable-billing) 16 | for your project. 17 | 18 | 3. [Enable APIs](https://console.cloud.google.com/apis/dashboard) for 19 | * BigQuery API 20 | * Cloud Resource Manager 21 | * AI Platform Training & Prediction API 22 | * App Engine Admin 23 | * Container Engine (if using Airflow on GKE) 24 | * Cloud SQL API (if using Airflow on GKE) 25 | * Cloud Composer API (if using Cloud Composer for Airflow) 26 | 27 | 28 | ## Installation 29 | 30 | ### Option 1: Use Google Cloud Shell 31 | 32 | 1. Open the [Google Cloud Platform 33 | Console](https://console.cloud.google.com/?_ga=1.38191587.1500870598.1489443487). 34 | 35 | 2. Click the Cloud Shell icon at the top of the screen. 36 | ![Cloud Shell](https://cloud.google.com/shell/docs/images/shell_icon.png) 37 | 38 | ### Option 2: Run Locally in Linux or Mac OS X 39 | 40 | *These scripts will not work in Windows. If you have a Windows machine, we 41 | recommend you use Google Cloud Shell.* 42 | 43 | 1. Download and install the [Google Cloud 44 | SDK](https://cloud.google.com/sdk/docs/), which includes the 45 | [gcloud](https://cloud.google.com/sdk/gcloud/) command-line tool. 46 | 47 | 2. Initialize the Cloud SDK. 48 | 49 | gcloud init 50 | 51 | 3. Set your default project (replace YOUR-PROJECT-ID with the name of your 52 | project). 53 | 54 | gcloud config set project YOUR-PROJECT-ID 55 | 56 | 57 | ### Install Miniconda 2 58 | 59 | This project assumes Python 2. 60 | 61 | * Install miniconda 2: 62 | 63 | https://docs.conda.io/en/latest/miniconda.html#installing 64 | 65 | 66 | * Create environment and install packages: 67 | 68 | Install packages in conda.txt: 69 | 70 | cd tensorflow-recommendation-wals 71 | conda create -y -n recserve 72 | source activate recserve 73 | conda install -y -n recserve --file conda.txt 74 | 75 | * Install TensorFlow version 1.x. This code should work with any version of 1.x. We are using the latest as of June 2020. 76 | 77 | 78 | CPU: 79 | 80 | pip install tensorflow==1.15 81 | 82 | Or GPU, if one is available in your environment: 83 | 84 | pip install tensorflow-gpu==1.15 85 | 86 | Install other requirements not available from conda: 87 | 88 | pip install -r requirements.txt 89 | 90 | ### Upload sample data to BigQuery 91 | 92 | This tutorial comes with a sample Google Analytics data set, containing page tracking events from the Austrian news site Kurier.at. The schema file '''ga_sessions_sample_schema.json''' is located in the folder data in the tutorial code, and the data file '''ga_sessions_sample.json.gz''' is located in a public Cloud Storage bucket associated with this tutorial. To upload this data set to BigQuery: 93 | 94 | 1. Make a GCS bucket with the name recserve_[YOUR-PROJECT-ID]: 95 | 96 | export BUCKET=gs://recserve_$(gcloud config get-value project 2> /dev/null) 97 | gsutil mb ${BUCKET} 98 | 99 | 2. 
Copy the data file ga_sessions_sample.json.gz to the bucket: 100 | 101 | gsutil cp gs://solutions-public-assets/recommendation-tensorflow/data/ga_sessions_sample.json.gz ${BUCKET}/data/ga_sessions_sample.json.gz 102 | 103 | 3. (Option 1) Go to the BigQuery web UI. Create a new dataset named "GA360_test". In the navigation panel, hover on the dataset, click the down arrow icon, and click Create new table. On the Create Table page, in the Source Data section: 104 | - For Location, select Google Cloud Storage, and enter the file path [your_bucket]/data/ga_sessions_sample.json.gz (without the gs:// prefix). 105 | - For File format, select JSON. 106 | - On the Create Table page, in the Destination Table section, for Table name, choose the dataset, and in the table name field, enter the name of the table as 'ga_sessions_sample'. 107 | - Verify that Table type is set to Native table. 108 | - In the Schema section, enter the schema definition. 109 | - Open the file data/ga_sessions_sample_schema.json in a text editor, select all, and copy the complete text of the file to the clipboard. Click Edit as text and paste the table schema into the text field in the web UI. 110 | - Click Create Table. 111 | 112 | 4. (Option 2) Using the command line: 113 | 114 | export PROJECT=$(gcloud config get-value project 2> /dev/null) 115 | 116 | bq --project_id=${PROJECT} mk GA360_test 117 | 118 | bq load --source_format=NEWLINE_DELIMITED_JSON \ 119 | GA360_test.ga_sessions_sample \ 120 | ${BUCKET}/data/ga_sessions_sample.json.gz \ 121 | data/ga_sessions_sample_schema.json 122 | 123 | 124 | ### Install WALS model training package and model data 125 | 126 | 1. Create a distributable package. Copy the package up to the code folder in the bucket you created previously. 127 | 128 | pushd wals_ml_engine 129 | python setup.py sdist 130 | gsutil cp dist/wals_ml_engine-0.1.tar.gz ${BUCKET}/code/ 131 | 132 | 2. Run the WALS model on the sample data set: 133 | 134 | ./mltrain.sh local ../data recommendation_events.csv --data-type web_views --use-optimized 135 | 136 | This will take a couple of minutes, and create a job directory under wals_ml_engine/jobs like "wals_ml_local_20180102_012345/model", containing the model files saved as numpy arrays. 137 | 138 | 3. Copy the model files from this directory to the model folder in the project bucket: 139 | 140 | export JOB_MODEL=$(find jobs -name "model" | tail -1) 141 | gsutil cp ${JOB_MODEL}/* ${BUCKET}/model/ 142 | 143 | 4. Copy the sample data file up to the project bucket: 144 | 145 | gsutil cp ../data/recommendation_events.csv ${BUCKET}/data/ 146 | popd 147 | 148 | ### Install the recserve endpoint 149 | 150 | This step can take several minutes to complete. You can do this in a separate shell so that you can deploy the Airflow service in parallel. Remember to 'source activate recserve' in any new shell that you open, to activate the recserve environment. 151 | 152 | source activate recserve 153 | 154 | 1. Create the App Engine app in your project: 155 | 156 | gcloud app create --region=us-east1 157 | 158 | 2. Prepare the deploy template for the Cloud Endpoints API: 159 | 160 | cd scripts 161 | ./prepare_deploy_api.sh # Prepare config file for the API. 162 | 163 | This will output something like: 164 | 165 | ... 166 | To deploy: gcloud endpoints services deploy /var/folders/1m/r3slmhp92074pzdhhfjvnw0m00dhhl/T/tmp.n6QVl5hO.yaml 167 | 168 | 3. Run the endpoints deploy command output above: 169 | 170 | gcloud endpoints services deploy [TEMP_FILE] 171 | 172 | 4. 
Prepare the deploy template for the App Engine App: 173 | 174 | ./prepare_deploy_app.sh 175 | 176 | You can ignore the script output "ERROR: (gcloud.app.create) The project [...] already contains an App Engine application. You can deploy your application using `gcloud app deploy`." This is expected. 177 | 178 | The script will output something like: 179 | 180 | ... 181 | To deploy: gcloud -q app deploy ../app/app_template.yaml_deploy.yaml 182 | 183 | 5. Run the command above: 184 | 185 | gcloud -q app deploy ../app/app_template.yaml_deploy.yaml 186 | 187 | This will take several minutes. 188 | 189 | cd .. 190 | 191 | ### Deploy the Airflow service 192 | 193 | #### Option 1 (recommended): Use Cloud Composer 194 | Cloud Composer is the GCP managed service for Airflow. It is in beta at the time this code is published. 195 | 196 | 1. Create a new Cloud Composer environment in your project: 197 | 198 | CC_ENV=composer-recserve 199 | 200 | gcloud composer environments create $CC_ENV --location us-central1 201 | 202 | This process takes a few minutes to complete. 203 | 204 | 2. Get the name of the Cloud Storage bucket created for you by Cloud Composer: 205 | 206 | gcloud composer environments describe $CC_ENV \ 207 | --location us-central1 --format="csv[no-heading](config.dagGcsPrefix)" | sed 's/.\{5\}$//' 208 | 209 | In the output, you see the location of the Cloud Storage bucket, like this: 210 | 211 | gs://[region-environment_name-random_id-bucket] 212 | 213 | This bucket contains subfolders for DAGs and plugins. 214 | 215 | 3. Set a shell variable that contains the path to that output: 216 | 217 | export AIRFLOW_BUCKET="gs://[region-environment_name-random_id-bucket]" 218 | 219 | 4. Copy the DAG training.py file to the dags folder in your Cloud Composer bucket: 220 | 221 | gsutil cp airflow/dags/training.py ${AIRFLOW_BUCKET}/dags 222 | 223 | 5. Import the solution plugins to your composer environment: 224 | 225 | gcloud composer environments storage plugins import \ 226 | --location us-central1 --environment ${CC_ENV} --source airflow/plugins/ 227 | 228 | 229 | #### Option 2: Create an Airflow cluster running on GKE 230 | 231 | This can be done in parallel with the app deploy step in a different shell. 232 | 233 | 1. Deploy the Airflow service using the script in airflow/deploy: 234 | 235 | source activate recserve 236 | cd airflow/deploy 237 | ./deploy_airflow.sh 238 | 239 | This will take a few minutes to complete. 240 | 241 | 2. Create "dags," "logs" and "plugins" folders in the GCS bucket created by the deploy script named (managed-airflow-{random hex value}), e.g. gs://managed-airflow-e0c99374808c4d4e8002e481. See https://storage.googleapis.com/solutions-public-assets/recommendation-tensorflow/images/airflow_buckets.png. The name of the bucket is available in the ID field of the airflow/deploy/deployment-settings.yaml file created by the deploy script. You can create the folders in the cloud console, or use the following script: 242 | 243 | cd ../.. 244 | python airflow/deploy/create_buckets.py 245 | 246 | 3. Copy training.py to the dags folder in your airflow bucket: 247 | 248 | export AIRFLOW_BUCKET=`python -c "\ 249 | import yaml;\ 250 | f = open('airflow/deploy/deployment-settings.yaml');\ 251 | settings=yaml.load(f);\ 252 | f.close();\ 253 | print settings['id']"` 254 | 255 | gsutil cp airflow/dags/training.py gs://${AIRFLOW_BUCKET}/dags 256 | 257 | 4. 
Copy plugins to the plugins folder of your airflow bucket: 258 | 259 | gsutil cp -r airflow/plugins gs://${AIRFLOW_BUCKET} 260 | 261 | 5. Restart the airflow webserver pod 262 | 263 | WS_POD=`kubectl get pod | grep -o airflow-webserver-[0-9a-z]*-[0-9a-z]*` 264 | kubectl get pod ${WS_POD} -o yaml | kubectl replace --force -f - 265 | 266 | 267 | 268 | ## Usage 269 | 270 | ### rec_serve endpoint service 271 | 272 | cd scripts 273 | ./query_api.sh # Query the API. 274 | ./generate_traffic.sh # Send traffic to the API. 275 | 276 | 277 | ### Airflow 278 | 279 | The Airflow web console can be used to update the schedule for the DAG, inspect logs, manually 280 | execute tasks, etc. 281 | 282 | #### Option 1 (Cloud Composer) 283 | 284 | Note that after creating the Cloud Composer environment, it takes approximately 25 285 | minutes for the web interface to finish hosting and become accessible. 286 | 287 | Type this command to print the URL for the Cloud Composer web console: 288 | 289 | gcloud composer environments describe $CC_ENV --location us-central1 \ 290 | --format="csv[no-heading](config.airflow_uri)" 291 | 292 | You see output that looks like the following: 293 | 294 | https://x6c9aa336e72ad0dd-tp.appspot.com 295 | 296 | To access the Airflow console for your Cloud Composer instance, go to the URL displayed in the output. 297 | 298 | #### Option 2 (GKE) 299 | You can find the URL and login credentials for the airflow admin interface in the file 300 | airflow/deploy/deployment-settings.yaml. 301 | 302 | e.g. 303 | 304 | ... 305 | web_ui_password: IiDYrpwJcT... 306 | web_ui_url: http://35.226.101.220:8080 307 | web_ui_username: airflow 308 | 309 | The Airflow service can also be accessed from the airflow-webserver pod in the GKE cluster. 310 | Open your project console, navigate to the "Discovery and load balancing" page in GKE, and 311 | click on the endpoint link for the airflow-webserver to access the Airflow admin app. 312 | -------------------------------------------------------------------------------- /airflow/airflow.cfg: -------------------------------------------------------------------------------- 1 | [core] 2 | # The home folder for airflow, default is ~/airflow 3 | airflow_home = /Users/lramsey/work/tfrec/rec_serve/airflow 4 | 5 | # The folder where your airflow pipelines live, most likely a 6 | # subfolder in a code repository 7 | # This path must be absolute 8 | dags_folder = /Users/lramsey/work/tfrec/rec_serve/airflow/dags 9 | 10 | # The folder where airflow should store its log files 11 | # This path must be absolute 12 | base_log_folder = /Users/lramsey/work/tfrec/rec_serve/airflow/logs 13 | 14 | # Airflow can store logs remotely in AWS S3 or Google Cloud Storage. Users 15 | # must supply a remote location URL (starting with either 's3://...' or 16 | # 'gs://...') and an Airflow connection id that provides access to the storage 17 | # location. 18 | remote_base_log_folder = 19 | remote_log_conn_id = 20 | # Use server-side encryption for logs stored in S3 21 | encrypt_s3_logs = False 22 | # DEPRECATED option for remote log storage, use remote_base_log_folder instead! 23 | s3_log_folder = 24 | 25 | # The executor class that airflow should use. Choices include 26 | # SequentialExecutor, LocalExecutor, CeleryExecutor 27 | executor = SequentialExecutor 28 | 29 | # The SqlAlchemy connection string to the metadata database. 
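# For example, a MySQL-backed deployment might use a connection string like the
# following (hypothetical credentials; the GKE deployment in this solution runs
# the metadata database in Cloud SQL behind the sql-proxy service rather than
# using the local sqlite default):
# sql_alchemy_conn = mysql://airflow:airflow@127.0.0.1:3306/airflow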
30 | # SqlAlchemy supports many different database engine, more information 31 | # their website 32 | sql_alchemy_conn = sqlite:////Users/lramsey/work/tfrec/rec_serve/airflow/airflow.db 33 | 34 | # The SqlAlchemy pool size is the maximum number of database connections 35 | # in the pool. 36 | sql_alchemy_pool_size = 5 37 | 38 | # The SqlAlchemy pool recycle is the number of seconds a connection 39 | # can be idle in the pool before it is invalidated. This config does 40 | # not apply to sqlite. 41 | sql_alchemy_pool_recycle = 3600 42 | 43 | # The amount of parallelism as a setting to the executor. This defines 44 | # the max number of task instances that should run simultaneously 45 | # on this airflow installation 46 | parallelism = 32 47 | 48 | # The number of task instances allowed to run concurrently by the scheduler 49 | dag_concurrency = 16 50 | 51 | # Are DAGs paused by default at creation 52 | dags_are_paused_at_creation = True 53 | 54 | # When not using pools, tasks are run in the "default pool", 55 | # whose size is guided by this config element 56 | non_pooled_task_slot_count = 128 57 | 58 | # The maximum number of active DAG runs per DAG 59 | max_active_runs_per_dag = 16 60 | 61 | # Whether to load the examples that ship with Airflow. It's good to 62 | # get started, but you probably want to set this to False in a production 63 | # environment 64 | load_examples = True 65 | 66 | # Where your Airflow plugins are stored 67 | plugins_folder = /Users/lramsey/work/tfrec/rec_serve/airflow/plugins 68 | 69 | # Secret key to save connection passwords in the db 70 | fernet_key = Oi7UoO_f2dzqhTM3Eb3Xg9cAbz0VubC_SMc5GHwZP14= 71 | 72 | # Whether to disable pickling dags 73 | donot_pickle = False 74 | 75 | # How long before timing out a python file import while filling the DagBag 76 | dagbag_import_timeout = 30 77 | 78 | # The class to use for running task instances in a subprocess 79 | task_runner = BashTaskRunner 80 | 81 | # If set, tasks without a `run_as_user` argument will be run with this user 82 | # Can be used to de-elevate a sudo user running Airflow when executing tasks 83 | default_impersonation = 84 | 85 | # What security module to use (for example kerberos): 86 | security = 87 | 88 | # Turn unit test mode on (overwrites many configuration options with test 89 | # values at runtime) 90 | unit_test_mode = False 91 | 92 | [cli] 93 | # In what way should the cli access the API. The LocalClient will use the 94 | # database directly, while the json_client will use the api running on the 95 | # webserver 96 | api_client = airflow.api.client.json_client 97 | endpoint_url = http://"-1-9.appspot.com:8080 98 | 99 | [api] 100 | # How to authenticate users of the API 101 | auth_backend = airflow.api.auth.backend.default 102 | 103 | [operators] 104 | # The default owner assigned to each new operator, unless 105 | # provided explicitly or passed via `default_args` 106 | default_owner = Airflow 107 | default_cpus = 1 108 | default_ram = 512 109 | default_disk = 512 110 | default_gpus = 0 111 | 112 | 113 | [webserver] 114 | # The base url of your website as airflow cannot guess what domain or 115 | # cname you are using. 
This is used in automated emails that 116 | # airflow sends to point links to the right web server 117 | base_url = http://localhost:8080 118 | 119 | # The ip specified when starting the web server 120 | web_server_host = 0.0.0.0 121 | 122 | # The port on which to run the web server 123 | web_server_port = 8080 124 | 125 | # Paths to the SSL certificate and key for the web server. When both are 126 | # provided SSL will be enabled. This does not change the web server port. 127 | web_server_ssl_cert = 128 | web_server_ssl_key = 129 | 130 | # Number of seconds the gunicorn webserver waits before timing out on a worker 131 | web_server_worker_timeout = 120 132 | 133 | # Number of workers to refresh at a time. When set to 0, worker refresh is 134 | # disabled. When nonzero, airflow periodically refreshes webserver workers by 135 | # bringing up new ones and killing old ones. 136 | worker_refresh_batch_size = 1 137 | 138 | # Number of seconds to wait before refreshing a batch of workers. 139 | worker_refresh_interval = 30 140 | 141 | # Secret key used to run your flask app 142 | secret_key = temporary_key 143 | 144 | # Number of workers to run the Gunicorn web server 145 | workers = 4 146 | 147 | # The worker class gunicorn should use. Choices include 148 | # sync (default), eventlet, gevent 149 | worker_class = sync 150 | 151 | # Log files for the gunicorn webserver. '-' means log to stderr. 152 | access_logfile = - 153 | error_logfile = - 154 | 155 | # Expose the configuration file in the web server 156 | expose_config = True 157 | 158 | # Set to true to turn on authentication: 159 | # http://pythonhosted.org/airflow/security.html#web-authentication 160 | authenticate = False 161 | 162 | # Filter the list of dags by owner name (requires authentication to be enabled) 163 | filter_by_owner = False 164 | 165 | # Filtering mode. Choices include user (default) and ldapgroup. 166 | # Ldap group filtering requires using the ldap backend 167 | # 168 | # Note that the ldap server needs the "memberOf" overlay to be set up 169 | # in order to user the ldapgroup mode. 170 | owner_mode = user 171 | 172 | # Default DAG orientation. Valid values are: 173 | # LR (Left->Right), TB (Top->Bottom), RL (Right->Left), BT (Bottom->Top) 174 | dag_orientation = LR 175 | 176 | # Puts the webserver in demonstration mode; blurs the names of Operators for 177 | # privacy. 178 | demo_mode = False 179 | 180 | # The amount of time (in secs) webserver will wait for initial handshake 181 | # while fetching logs from other worker machine 182 | log_fetch_timeout_sec = 5 183 | 184 | # By default, the webserver shows paused DAGs. 
Flip this to hide paused 185 | # DAGs by default 186 | hide_paused_dags_by_default = False 187 | 188 | [email] 189 | email_backend = airflow.utils.email.send_email_smtp 190 | 191 | 192 | [smtp] 193 | # If you want airflow to send emails on retries, failure, and you want to use 194 | # the airflow.utils.email.send_email_smtp function, you have to configure an 195 | # smtp server here 196 | smtp_host = localhost 197 | smtp_starttls = True 198 | smtp_ssl = False 199 | # Uncomment and set the user/pass settings if you want to use SMTP AUTH 200 | # smtp_user = airflow 201 | # smtp_password = airflow 202 | smtp_port = 25 203 | smtp_mail_from = airflow@airflow.com 204 | 205 | 206 | [celery] 207 | # This section only applies if you are using the CeleryExecutor in 208 | # [core] section above 209 | 210 | # The app name that will be used by celery 211 | celery_app_name = airflow.executors.celery_executor 212 | 213 | # The concurrency that will be used when starting workers with the 214 | # "airflow worker" command. This defines the number of task instances that 215 | # a worker will take, so size up your workers based on the resources on 216 | # your worker box and the nature of your tasks 217 | celeryd_concurrency = 16 218 | 219 | # When you start an airflow worker, airflow starts a tiny web server 220 | # subprocess to serve the workers local log files to the airflow main 221 | # web server, who then builds pages and sends them to users. This defines 222 | # the port on which the logs are served. It needs to be unused, and open 223 | # visible from the main web server to connect into the workers. 224 | worker_log_server_port = 8793 225 | 226 | # The Celery broker URL. Celery supports RabbitMQ, Redis and experimentally 227 | # a sqlalchemy database. Refer to the Celery documentation for more 228 | # information. 229 | broker_url = sqla+mysql://airflow:airflow@localhost:3306/airflow 230 | 231 | # Another key Celery setting 232 | celery_result_backend = db+mysql://airflow:airflow@localhost:3306/airflow 233 | 234 | # Celery Flower is a sweet UI for Celery. Airflow has a shortcut to start 235 | # it `airflow flower`. This defines the IP that Celery Flower runs on 236 | flower_host = 0.0.0.0 237 | 238 | # This defines the port that Celery Flower runs on 239 | flower_port = 5555 240 | 241 | # Default queue that tasks get assigned to and that worker listen on. 242 | default_queue = default 243 | 244 | 245 | [scheduler] 246 | # Task instances listen for external kill signal (when you clear tasks 247 | # from the CLI or the UI), this defines the frequency at which they should 248 | # listen (in seconds). 249 | job_heartbeat_sec = 5 250 | 251 | # The scheduler constantly tries to trigger new tasks (look at the 252 | # scheduler section in the docs for more information). This defines 253 | # how often the scheduler should run (in seconds). 254 | scheduler_heartbeat_sec = 5 255 | 256 | # after how much time should the scheduler terminate in seconds 257 | # -1 indicates to run continuously (see also num_runs) 258 | run_duration = -1 259 | 260 | # after how much time a new DAGs should be picked up from the filesystem 261 | min_file_process_interval = 0 262 | 263 | dag_dir_list_interval = 300 264 | 265 | # How often should stats be printed to the logs 266 | print_stats_interval = 30 267 | 268 | child_process_log_directory = /Users/lramsey/work/tfrec/rec_serve/airflow/logs/scheduler 269 | 270 | # Local task jobs periodically heartbeat to the DB. 
If the job has 271 | # not heartbeat in this many seconds, the scheduler will mark the 272 | # associated task instance as failed and will re-schedule the task. 273 | scheduler_zombie_task_threshold = 300 274 | 275 | # Turn off scheduler catchup by setting this to False. 276 | # Default behavior is unchanged and 277 | # Command Line Backfills still work, but the scheduler 278 | # will not do scheduler catchup if this is False, 279 | # however it can be set on a per DAG basis in the 280 | # DAG definition (catchup) 281 | catchup_by_default = True 282 | 283 | # Statsd (https://github.com/etsy/statsd) integration settings 284 | statsd_on = False 285 | statsd_host = localhost 286 | statsd_port = 8125 287 | statsd_prefix = airflow 288 | 289 | # The scheduler can run multiple threads in parallel to schedule dags. 290 | # This defines how many threads will run. However airflow will never 291 | # use more threads than the amount of cpu cores available. 292 | max_threads = 2 293 | 294 | authenticate = False 295 | 296 | 297 | [mesos] 298 | # Mesos master address which MesosExecutor will connect to. 299 | master = localhost:5050 300 | 301 | # The framework name which Airflow scheduler will register itself as on mesos 302 | framework_name = Airflow 303 | 304 | # Number of cpu cores required for running one task instance using 305 | # 'airflow run --local -p ' 306 | # command on a mesos slave 307 | task_cpu = 1 308 | 309 | # Memory in MB required for running one task instance using 310 | # 'airflow run --local -p ' 311 | # command on a mesos slave 312 | task_memory = 256 313 | 314 | # Enable framework checkpointing for mesos 315 | # See http://mesos.apache.org/documentation/latest/slave-recovery/ 316 | checkpoint = False 317 | 318 | # Failover timeout in milliseconds. 319 | # When checkpointing is enabled and this option is set, Mesos waits 320 | # until the configured timeout for 321 | # the MesosExecutor framework to re-register after a failover. Mesos 322 | # shuts down running tasks if the 323 | # MesosExecutor framework fails to re-register within this timeframe. 324 | # failover_timeout = 604800 325 | 326 | # Enable framework authentication for mesos 327 | # See http://mesos.apache.org/documentation/latest/configuration/ 328 | authenticate = False 329 | 330 | # Mesos credentials, if authentication is enabled 331 | # default_principal = admin 332 | # default_secret = admin 333 | 334 | 335 | [kerberos] 336 | ccache = /tmp/airflow_krb5_ccache 337 | # gets augmented with fqdn 338 | principal = airflow 339 | reinit_frequency = 3600 340 | kinit_path = kinit 341 | keytab = airflow.keytab 342 | 343 | 344 | [github_enterprise] 345 | api_rev = v3 346 | 347 | 348 | [admin] 349 | # UI to hide sensitive variable fields when set to True 350 | hide_sensitive_variable_fields = False 351 | 352 | -------------------------------------------------------------------------------- /wals_ml_engine/trainer/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """WALS model input data, training and predict functions.""" 16 | 17 | import datetime 18 | import numpy as np 19 | import os 20 | import pandas as pd 21 | from scipy.sparse import coo_matrix 22 | import sh 23 | import tensorflow as tf 24 | 25 | import wals 26 | 27 | # ratio of train set size to test set size 28 | TEST_SET_RATIO = 10 29 | 30 | # default hyperparameters 31 | DEFAULT_PARAMS = { 32 | 'weights': True, 33 | 'latent_factors': 5, 34 | 'num_iters': 20, 35 | 'regularization': 0.07, 36 | 'unobs_weight': 0.01, 37 | 'wt_type': 0, 38 | 'feature_wt_factor': 130.0, 39 | 'feature_wt_exp': 0.08, 40 | 'delimiter': '\t' 41 | } 42 | 43 | # parameters optimized with hypertuning for the MovieLens data set 44 | OPTIMIZED_PARAMS = { 45 | 'latent_factors': 34, 46 | 'regularization': 9.83, 47 | 'unobs_weight': 0.001, 48 | 'feature_wt_factor': 189.8, 49 | } 50 | 51 | # parameters optimized with hypertuning for the included web views data set 52 | OPTIMIZED_PARAMS_WEB = { 53 | 'latent_factors': 30, 54 | 'regularization': 7.27, 55 | 'unobs_weight': 0.01, 56 | 'feature_wt_exp': 5.05, 57 | } 58 | 59 | 60 | def create_test_and_train_sets(args, input_file, data_type='ratings'): 61 | """Create test and train sets, for different input data types. 62 | 63 | Args: 64 | args: input args for job 65 | input_file: path to csv data file 66 | data_type: 'ratings': MovieLens style ratings matrix 67 | 'web_views': Google Analytics time-on-page data 68 | 69 | Returns: 70 | array of user IDs for each row of the ratings matrix 71 | array of item IDs for each column of the rating matrix 72 | sparse coo_matrix for training 73 | sparse coo_matrix for test 74 | 75 | Raises: 76 | ValueError: if invalid data_type is supplied 77 | """ 78 | if data_type == 'ratings': 79 | return _ratings_train_and_test(args['headers'], args['delimiter'], 80 | input_file) 81 | elif data_type == 'web_views': 82 | return _page_views_train_and_test(input_file) 83 | else: 84 | raise ValueError('data_type arg value %s not supported.' % data_type) 85 | 86 | 87 | def _ratings_train_and_test(use_headers, delimiter, input_file): 88 | """Load data set. Assumes Movielens header, format etc. 89 | 90 | MovieLens data starts with user_id=1. The max user id is close to 91 | the number of users, but there may be missing user_id's or item ids 92 | (i.e. movies). For our sparse matrices we need to map the user/item ids 93 | down to a zero-based set of indices, without missing values. 
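For example (illustrative only): if the raw user ids are [1, 3, 7], then n_users is 3, max_user is 7, and the index array z = np.zeros(max_user + 1) with z[unique_users] = np.arange(n_users) becomes [0, 0, 0, 1, 0, 0, 0, 2]; indexing z[np_users] then maps every raw id in the ratings stack to its 0-based row index (1 -> 0, 3 -> 1, 7 -> 2). The same remapping is applied to the item ids.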
94 | 95 | Args: 96 | use_headers: (boolean) true = headers, false = no headers 97 | delimiter: (string) delimiter to use for csv 98 | input_file: path to csv data file 99 | 100 | Returns: 101 | array of user IDs for each row of the ratings matrix 102 | array of item IDs for each column of the rating matrix 103 | sparse coo_matrix for training 104 | sparse coo_matrix for test 105 | """ 106 | headers = ['user_id', 'item_id', 'rating', 'timestamp'] 107 | header_row = 0 if use_headers else None 108 | ratings_df = pd.read_csv(input_file, 109 | sep=delimiter, 110 | names=headers, 111 | header=header_row, 112 | dtype={ 113 | 'user_id': np.int32, 114 | 'item_id': np.int32, 115 | 'rating': np.float32, 116 | 'timestamp': np.int32, 117 | }) 118 | 119 | np_users = ratings_df.user_id.as_matrix() 120 | np_items = ratings_df.item_id.as_matrix() 121 | unique_users = np.unique(np_users) 122 | unique_items = np.unique(np_items) 123 | 124 | n_users = unique_users.shape[0] 125 | n_items = unique_items.shape[0] 126 | 127 | # make indexes for users and items if necessary 128 | max_user = unique_users[-1] 129 | max_item = unique_items[-1] 130 | if n_users != max_user or n_items != max_item: 131 | # make an array of 0-indexed unique user ids corresponding to the dataset 132 | # stack of user ids 133 | z = np.zeros(max_user+1, dtype=int) 134 | z[unique_users] = np.arange(n_users) 135 | u_r = z[np_users] 136 | 137 | # make an array of 0-indexed unique item ids corresponding to the dataset 138 | # stack of item ids 139 | z = np.zeros(max_item+1, dtype=int) 140 | z[unique_items] = np.arange(n_items) 141 | i_r = z[np_items] 142 | 143 | # construct the ratings set from the three stacks 144 | np_ratings = ratings_df.rating.as_matrix() 145 | ratings = np.zeros((np_ratings.shape[0], 3), dtype=object) 146 | ratings[:, 0] = u_r 147 | ratings[:, 1] = i_r 148 | ratings[:, 2] = np_ratings 149 | else: 150 | ratings = ratings_df.as_matrix(['user_id', 'item_id', 'rating']) 151 | # deal with 1-based user indices 152 | ratings[:, 0] -= 1 153 | ratings[:, 1] -= 1 154 | 155 | tr_sparse, test_sparse = _create_sparse_train_and_test(ratings, 156 | n_users, n_items) 157 | 158 | return ratings[:, 0], ratings[:, 1], tr_sparse, test_sparse 159 | 160 | 161 | def _page_views_train_and_test(input_file): 162 | """Load page views dataset, and create train and set sparse matrices. 163 | 164 | Assumes 'clientId', 'contentId', and 'timeOnPage' columns. 165 | 166 | Args: 167 | input_file: path to csv data file 168 | 169 | Returns: 170 | array of user IDs for each row of the ratings matrix 171 | array of item IDs for each column of the rating matrix 172 | sparse coo_matrix for training 173 | sparse coo_matrix for test 174 | """ 175 | views_df = pd.read_csv(input_file, sep=',', header=0) 176 | 177 | df_items = pd.DataFrame({'contentId': views_df.contentId.unique()}) 178 | df_sorted_items = df_items.sort_values('contentId').reset_index() 179 | pds_items = df_sorted_items.contentId 180 | 181 | # preprocess data. 
df.groupby.agg sorts clientId and contentId 182 | df_user_items = views_df.groupby(['clientId', 'contentId'] 183 | ).agg({'timeOnPage': 'sum'}) 184 | 185 | # create a list of (userId, itemId, timeOnPage) ratings, where userId and 186 | # clientId are 0-indexed 187 | current_u = -1 188 | ux = -1 189 | pv_ratings = [] 190 | user_ux = [] 191 | for timeonpg in df_user_items.itertuples(): 192 | user = timeonpg[0][0] 193 | item = timeonpg[0][1] 194 | 195 | # as we go, build a (sorted) list of user ids 196 | if user != current_u: 197 | user_ux.append(user) 198 | ux += 1 199 | current_u = user 200 | 201 | # this search makes the preprocessing time O(r * i log(i)), 202 | # r = # ratings, i = # items 203 | ix = pds_items.searchsorted(item)[0] 204 | pv_ratings.append((ux, ix, timeonpg[1])) 205 | 206 | # convert ratings list and user map to np array 207 | pv_ratings = np.asarray(pv_ratings) 208 | user_ux = np.asarray(user_ux) 209 | 210 | # create train and test sets 211 | tr_sparse, test_sparse = _create_sparse_train_and_test(pv_ratings, 212 | ux + 1, 213 | df_items.size) 214 | 215 | return user_ux, pds_items.as_matrix(), tr_sparse, test_sparse 216 | 217 | 218 | def _create_sparse_train_and_test(ratings, n_users, n_items): 219 | """Given ratings, create sparse matrices for train and test sets. 220 | 221 | Args: 222 | ratings: list of ratings tuples (u, i, r) 223 | n_users: number of users 224 | n_items: number of items 225 | 226 | Returns: 227 | train, test sparse matrices in scipy coo_matrix format. 228 | """ 229 | # pick a random test set of entries, sorted ascending 230 | test_set_size = len(ratings) / TEST_SET_RATIO 231 | test_set_idx = np.random.choice(xrange(len(ratings)), 232 | size=test_set_size, replace=False) 233 | test_set_idx = sorted(test_set_idx) 234 | 235 | # sift ratings into train and test sets 236 | ts_ratings = ratings[test_set_idx] 237 | tr_ratings = np.delete(ratings, test_set_idx, axis=0) 238 | 239 | # create training and test matrices as coo_matrix's 240 | u_tr, i_tr, r_tr = zip(*tr_ratings) 241 | tr_sparse = coo_matrix((r_tr, (u_tr, i_tr)), shape=(n_users, n_items)) 242 | 243 | u_ts, i_ts, r_ts = zip(*ts_ratings) 244 | test_sparse = coo_matrix((r_ts, (u_ts, i_ts)), shape=(n_users, n_items)) 245 | 246 | return tr_sparse, test_sparse 247 | 248 | 249 | def train_model(args, tr_sparse): 250 | """Instantiate WALS model and use "simple_train" to factorize the matrix. 251 | 252 | Args: 253 | args: training args containing hyperparams 254 | tr_sparse: sparse training matrix 255 | 256 | Returns: 257 | the row and column factors in numpy format. 
258 | """ 259 | dim = args['latent_factors'] 260 | num_iters = args['num_iters'] 261 | reg = args['regularization'] 262 | unobs = args['unobs_weight'] 263 | wt_type = args['wt_type'] 264 | feature_wt_exp = args['feature_wt_exp'] 265 | obs_wt = args['feature_wt_factor'] 266 | 267 | tf.logging.info('Train Start: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now())) 268 | 269 | # generate model 270 | input_tensor, row_factor, col_factor, model = wals.wals_model(tr_sparse, 271 | dim, 272 | reg, 273 | unobs, 274 | args['weights'], 275 | wt_type, 276 | feature_wt_exp, 277 | obs_wt) 278 | 279 | # factorize matrix 280 | session = wals.simple_train(model, input_tensor, num_iters) 281 | 282 | tf.logging.info('Train Finish: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now())) 283 | 284 | # evaluate output factor matrices 285 | output_row = row_factor.eval(session=session) 286 | output_col = col_factor.eval(session=session) 287 | 288 | # close the training session now that we've evaluated the output 289 | session.close() 290 | 291 | return output_row, output_col 292 | 293 | 294 | def save_model(args, user_map, item_map, row_factor, col_factor): 295 | """Save the user map, item map, row factor and column factor matrices in numpy format. 296 | 297 | These matrices together constitute the "recommendation model." 298 | 299 | Args: 300 | args: input args to training job 301 | user_map: user map numpy array 302 | item_map: item map numpy array 303 | row_factor: row_factor numpy array 304 | col_factor: col_factor numpy array 305 | """ 306 | model_dir = os.path.join(args['output_dir'], 'model') 307 | 308 | # if our output directory is a GCS bucket, write model files to /tmp, 309 | # then copy to GCS 310 | gs_model_dir = None 311 | if model_dir.startswith('gs://'): 312 | gs_model_dir = model_dir 313 | model_dir = '/tmp/{0}'.format(args['job_name']) 314 | 315 | os.makedirs(model_dir) 316 | np.save(os.path.join(model_dir, 'user'), user_map) 317 | np.save(os.path.join(model_dir, 'item'), item_map) 318 | np.save(os.path.join(model_dir, 'row'), row_factor) 319 | np.save(os.path.join(model_dir, 'col'), col_factor) 320 | 321 | if gs_model_dir: 322 | sh.gsutil('cp', '-r', os.path.join(model_dir, '*'), gs_model_dir) 323 | 324 | 325 | def generate_recommendations(user_idx, user_rated, row_factor, col_factor, k): 326 | """Generate recommendations for a user. 
327 | 328 | Args: 329 | user_idx: the row index of the user in the ratings matrix, 330 | 331 | user_rated: the list of item indexes (column indexes in the ratings matrix) 332 | previously rated by that user (which will be excluded from the 333 | recommendations) 334 | 335 | row_factor: the row factors of the recommendation model 336 | col_factor: the column factors of the recommendation model 337 | 338 | k: number of recommendations requested 339 | 340 | Returns: 341 | list of k item indexes with the predicted highest rating, excluding 342 | those that the user has already rated 343 | """ 344 | 345 | # bounds checking for args 346 | assert (row_factor.shape[0] - len(user_rated)) >= k 347 | 348 | # retrieve user factor 349 | user_f = row_factor[user_idx] 350 | 351 | # dot product of item factors with user factor gives predicted ratings 352 | pred_ratings = col_factor.dot(user_f) 353 | 354 | # find candidate recommended item indexes sorted by predicted rating 355 | k_r = k + len(user_rated) 356 | candidate_items = np.argsort(pred_ratings)[-k_r:] 357 | 358 | # remove previously rated items and take top k 359 | recommended_items = [i for i in candidate_items if i not in user_rated] 360 | recommended_items = recommended_items[-k:] 361 | 362 | # flip to sort highest rated first 363 | recommended_items.reverse() 364 | 365 | return recommended_items 366 | 367 | -------------------------------------------------------------------------------- /data/ga_sessions_sample_schema.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "mode": "NULLABLE", 4 | "name": "visitorId", 5 | "type": "INTEGER" 6 | }, 7 | { 8 | "mode": "NULLABLE", 9 | "name": "visitNumber", 10 | "type": "INTEGER" 11 | }, 12 | { 13 | "mode": "NULLABLE", 14 | "name": "visitStartTime", 15 | "type": "INTEGER" 16 | }, 17 | { 18 | "mode": "NULLABLE", 19 | "name": "date", 20 | "type": "STRING" 21 | }, 22 | { 23 | "fields": [ 24 | { 25 | "mode": "NULLABLE", 26 | "name": "visits", 27 | "type": "INTEGER" 28 | }, 29 | { 30 | "mode": "NULLABLE", 31 | "name": "hits", 32 | "type": "INTEGER" 33 | }, 34 | { 35 | "mode": "NULLABLE", 36 | "name": "pageviews", 37 | "type": "INTEGER" 38 | }, 39 | { 40 | "mode": "NULLABLE", 41 | "name": "timeOnSite", 42 | "type": "INTEGER" 43 | }, 44 | { 45 | "mode": "NULLABLE", 46 | "name": "bounces", 47 | "type": "INTEGER" 48 | }, 49 | { 50 | "mode": "NULLABLE", 51 | "name": "transactions", 52 | "type": "INTEGER" 53 | }, 54 | { 55 | "mode": "NULLABLE", 56 | "name": "transactionRevenue", 57 | "type": "INTEGER" 58 | }, 59 | { 60 | "mode": "NULLABLE", 61 | "name": "newVisits", 62 | "type": "INTEGER" 63 | }, 64 | { 65 | "mode": "NULLABLE", 66 | "name": "screenviews", 67 | "type": "INTEGER" 68 | }, 69 | { 70 | "mode": "NULLABLE", 71 | "name": "uniqueScreenviews", 72 | "type": "INTEGER" 73 | }, 74 | { 75 | "mode": "NULLABLE", 76 | "name": "timeOnScreen", 77 | "type": "INTEGER" 78 | }, 79 | { 80 | "mode": "NULLABLE", 81 | "name": "totalTransactionRevenue", 82 | "type": "INTEGER" 83 | }, 84 | { 85 | "mode": "NULLABLE", 86 | "name": "sessionQualityDim", 87 | "type": "INTEGER" 88 | } 89 | ], 90 | "mode": "NULLABLE", 91 | "name": "totals", 92 | "type": "RECORD" 93 | }, 94 | { 95 | "fields": [ 96 | { 97 | "mode": "NULLABLE", 98 | "name": "browser", 99 | "type": "STRING" 100 | }, 101 | { 102 | "mode": "NULLABLE", 103 | "name": "browserVersion", 104 | "type": "STRING" 105 | }, 106 | { 107 | "mode": "NULLABLE", 108 | "name": "browserSize", 109 | "type": "STRING" 110 | }, 111 | { 112 | 
"mode": "NULLABLE", 113 | "name": "operatingSystem", 114 | "type": "STRING" 115 | }, 116 | { 117 | "mode": "NULLABLE", 118 | "name": "operatingSystemVersion", 119 | "type": "STRING" 120 | }, 121 | { 122 | "mode": "NULLABLE", 123 | "name": "isMobile", 124 | "type": "BOOLEAN" 125 | }, 126 | { 127 | "mode": "NULLABLE", 128 | "name": "mobileDeviceBranding", 129 | "type": "STRING" 130 | }, 131 | { 132 | "mode": "NULLABLE", 133 | "name": "mobileDeviceModel", 134 | "type": "STRING" 135 | }, 136 | { 137 | "mode": "NULLABLE", 138 | "name": "mobileInputSelector", 139 | "type": "STRING" 140 | }, 141 | { 142 | "mode": "NULLABLE", 143 | "name": "mobileDeviceInfo", 144 | "type": "STRING" 145 | }, 146 | { 147 | "mode": "NULLABLE", 148 | "name": "mobileDeviceMarketingName", 149 | "type": "STRING" 150 | }, 151 | { 152 | "mode": "NULLABLE", 153 | "name": "flashVersion", 154 | "type": "STRING" 155 | }, 156 | { 157 | "mode": "NULLABLE", 158 | "name": "javaEnabled", 159 | "type": "BOOLEAN" 160 | }, 161 | { 162 | "mode": "NULLABLE", 163 | "name": "language", 164 | "type": "STRING" 165 | }, 166 | { 167 | "mode": "NULLABLE", 168 | "name": "screenColors", 169 | "type": "STRING" 170 | }, 171 | { 172 | "mode": "NULLABLE", 173 | "name": "screenResolution", 174 | "type": "STRING" 175 | }, 176 | { 177 | "mode": "NULLABLE", 178 | "name": "deviceCategory", 179 | "type": "STRING" 180 | } 181 | ], 182 | "mode": "NULLABLE", 183 | "name": "device", 184 | "type": "RECORD" 185 | }, 186 | { 187 | "fields": [ 188 | { 189 | "mode": "NULLABLE", 190 | "name": "index", 191 | "type": "INTEGER" 192 | }, 193 | { 194 | "mode": "NULLABLE", 195 | "name": "value", 196 | "type": "STRING" 197 | } 198 | ], 199 | "mode": "REPEATED", 200 | "name": "customDimensions", 201 | "type": "RECORD" 202 | }, 203 | { 204 | "fields": [ 205 | { 206 | "mode": "NULLABLE", 207 | "name": "hitNumber", 208 | "type": "INTEGER" 209 | }, 210 | { 211 | "mode": "NULLABLE", 212 | "name": "time", 213 | "type": "INTEGER" 214 | }, 215 | { 216 | "mode": "NULLABLE", 217 | "name": "hour", 218 | "type": "INTEGER" 219 | }, 220 | { 221 | "mode": "NULLABLE", 222 | "name": "minute", 223 | "type": "INTEGER" 224 | }, 225 | { 226 | "mode": "NULLABLE", 227 | "name": "isSecure", 228 | "type": "BOOLEAN" 229 | }, 230 | { 231 | "mode": "NULLABLE", 232 | "name": "isInteraction", 233 | "type": "BOOLEAN" 234 | }, 235 | { 236 | "mode": "NULLABLE", 237 | "name": "isEntrance", 238 | "type": "BOOLEAN" 239 | }, 240 | { 241 | "mode": "NULLABLE", 242 | "name": "isExit", 243 | "type": "BOOLEAN" 244 | }, 245 | { 246 | "fields": [ 247 | { 248 | "mode": "NULLABLE", 249 | "name": "transactionId", 250 | "type": "STRING" 251 | }, 252 | { 253 | "mode": "NULLABLE", 254 | "name": "transactionRevenue", 255 | "type": "INTEGER" 256 | }, 257 | { 258 | "mode": "NULLABLE", 259 | "name": "transactionTax", 260 | "type": "INTEGER" 261 | }, 262 | { 263 | "mode": "NULLABLE", 264 | "name": "transactionShipping", 265 | "type": "INTEGER" 266 | }, 267 | { 268 | "mode": "NULLABLE", 269 | "name": "affiliation", 270 | "type": "STRING" 271 | }, 272 | { 273 | "mode": "NULLABLE", 274 | "name": "currencyCode", 275 | "type": "STRING" 276 | }, 277 | { 278 | "mode": "NULLABLE", 279 | "name": "localTransactionRevenue", 280 | "type": "INTEGER" 281 | }, 282 | { 283 | "mode": "NULLABLE", 284 | "name": "localTransactionTax", 285 | "type": "INTEGER" 286 | }, 287 | { 288 | "mode": "NULLABLE", 289 | "name": "localTransactionShipping", 290 | "type": "INTEGER" 291 | }, 292 | { 293 | "mode": "NULLABLE", 294 | "name": "transactionCoupon", 295 
| "type": "STRING" 296 | } 297 | ], 298 | "mode": "NULLABLE", 299 | "name": "transaction", 300 | "type": "RECORD" 301 | }, 302 | { 303 | "fields": [ 304 | { 305 | "mode": "NULLABLE", 306 | "name": "index", 307 | "type": "INTEGER" 308 | }, 309 | { 310 | "mode": "NULLABLE", 311 | "name": "value", 312 | "type": "STRING" 313 | } 314 | ], 315 | "mode": "REPEATED", 316 | "name": "customDimensions", 317 | "type": "RECORD" 318 | }, 319 | { 320 | "fields": [ 321 | { 322 | "mode": "NULLABLE", 323 | "name": "eventCategory", 324 | "type": "STRING" 325 | }, 326 | { 327 | "mode": "NULLABLE", 328 | "name": "eventAction", 329 | "type": "STRING" 330 | }, 331 | { 332 | "mode": "NULLABLE", 333 | "name": "eventLabel", 334 | "type": "STRING" 335 | }, 336 | { 337 | "mode": "NULLABLE", 338 | "name": "eventValue", 339 | "type": "INTEGER" 340 | } 341 | ], 342 | "mode": "NULLABLE", 343 | "name": "eventInfo", 344 | "type": "RECORD" 345 | }, 346 | { 347 | "fields": [ 348 | { 349 | "mode": "NULLABLE", 350 | "name": "productSKU", 351 | "type": "STRING" 352 | }, 353 | { 354 | "mode": "NULLABLE", 355 | "name": "v2ProductName", 356 | "type": "STRING" 357 | }, 358 | { 359 | "mode": "NULLABLE", 360 | "name": "v2ProductCategory", 361 | "type": "STRING" 362 | }, 363 | { 364 | "mode": "NULLABLE", 365 | "name": "productVariant", 366 | "type": "STRING" 367 | }, 368 | { 369 | "mode": "NULLABLE", 370 | "name": "productBrand", 371 | "type": "STRING" 372 | }, 373 | { 374 | "mode": "NULLABLE", 375 | "name": "productRevenue", 376 | "type": "INTEGER" 377 | }, 378 | { 379 | "mode": "NULLABLE", 380 | "name": "localProductRevenue", 381 | "type": "INTEGER" 382 | }, 383 | { 384 | "mode": "NULLABLE", 385 | "name": "productPrice", 386 | "type": "INTEGER" 387 | }, 388 | { 389 | "mode": "NULLABLE", 390 | "name": "localProductPrice", 391 | "type": "INTEGER" 392 | }, 393 | { 394 | "mode": "NULLABLE", 395 | "name": "productQuantity", 396 | "type": "INTEGER" 397 | }, 398 | { 399 | "mode": "NULLABLE", 400 | "name": "productRefundAmount", 401 | "type": "INTEGER" 402 | }, 403 | { 404 | "mode": "NULLABLE", 405 | "name": "localProductRefundAmount", 406 | "type": "INTEGER" 407 | }, 408 | { 409 | "mode": "NULLABLE", 410 | "name": "isImpression", 411 | "type": "BOOLEAN" 412 | }, 413 | { 414 | "mode": "NULLABLE", 415 | "name": "isClick", 416 | "type": "BOOLEAN" 417 | }, 418 | { 419 | "fields": [ 420 | { 421 | "mode": "NULLABLE", 422 | "name": "index", 423 | "type": "INTEGER" 424 | }, 425 | { 426 | "mode": "NULLABLE", 427 | "name": "value", 428 | "type": "STRING" 429 | } 430 | ], 431 | "mode": "REPEATED", 432 | "name": "customDimensions", 433 | "type": "RECORD" 434 | }, 435 | { 436 | "fields": [ 437 | { 438 | "mode": "NULLABLE", 439 | "name": "index", 440 | "type": "INTEGER" 441 | }, 442 | { 443 | "mode": "NULLABLE", 444 | "name": "value", 445 | "type": "INTEGER" 446 | } 447 | ], 448 | "mode": "REPEATED", 449 | "name": "customMetrics", 450 | "type": "RECORD" 451 | }, 452 | { 453 | "mode": "NULLABLE", 454 | "name": "productListName", 455 | "type": "STRING" 456 | }, 457 | { 458 | "mode": "NULLABLE", 459 | "name": "productListPosition", 460 | "type": "INTEGER" 461 | } 462 | ], 463 | "mode": "REPEATED", 464 | "name": "product", 465 | "type": "RECORD" 466 | }, 467 | { 468 | "fields": [ 469 | { 470 | "mode": "NULLABLE", 471 | "name": "index", 472 | "type": "INTEGER" 473 | }, 474 | { 475 | "mode": "NULLABLE", 476 | "name": "customVarName", 477 | "type": "STRING" 478 | }, 479 | { 480 | "mode": "NULLABLE", 481 | "name": "customVarValue", 482 | "type": "STRING" 483 | 
} 484 | ], 485 | "mode": "REPEATED", 486 | "name": "customVariables", 487 | "type": "RECORD" 488 | }, 489 | { 490 | "fields": [ 491 | { 492 | "mode": "NULLABLE", 493 | "name": "index", 494 | "type": "INTEGER" 495 | }, 496 | { 497 | "mode": "NULLABLE", 498 | "name": "value", 499 | "type": "INTEGER" 500 | } 501 | ], 502 | "mode": "REPEATED", 503 | "name": "customMetrics", 504 | "type": "RECORD" 505 | }, 506 | { 507 | "mode": "NULLABLE", 508 | "name": "type", 509 | "type": "STRING" 510 | }, 511 | { 512 | "fields": [ 513 | { 514 | "mode": "NULLABLE", 515 | "name": "pageLoadSample", 516 | "type": "INTEGER" 517 | }, 518 | { 519 | "mode": "NULLABLE", 520 | "name": "pageLoadTime", 521 | "type": "INTEGER" 522 | }, 523 | { 524 | "mode": "NULLABLE", 525 | "name": "pageDownloadTime", 526 | "type": "INTEGER" 527 | }, 528 | { 529 | "mode": "NULLABLE", 530 | "name": "redirectionTime", 531 | "type": "INTEGER" 532 | }, 533 | { 534 | "mode": "NULLABLE", 535 | "name": "speedMetricsSample", 536 | "type": "INTEGER" 537 | }, 538 | { 539 | "mode": "NULLABLE", 540 | "name": "domainLookupTime", 541 | "type": "INTEGER" 542 | }, 543 | { 544 | "mode": "NULLABLE", 545 | "name": "serverConnectionTime", 546 | "type": "INTEGER" 547 | }, 548 | { 549 | "mode": "NULLABLE", 550 | "name": "serverResponseTime", 551 | "type": "INTEGER" 552 | }, 553 | { 554 | "mode": "NULLABLE", 555 | "name": "domLatencyMetricsSample", 556 | "type": "INTEGER" 557 | }, 558 | { 559 | "mode": "NULLABLE", 560 | "name": "domInteractiveTime", 561 | "type": "INTEGER" 562 | }, 563 | { 564 | "mode": "NULLABLE", 565 | "name": "domContentLoadedTime", 566 | "type": "INTEGER" 567 | }, 568 | { 569 | "mode": "NULLABLE", 570 | "name": "userTimingValue", 571 | "type": "INTEGER" 572 | }, 573 | { 574 | "mode": "NULLABLE", 575 | "name": "userTimingSample", 576 | "type": "INTEGER" 577 | }, 578 | { 579 | "mode": "NULLABLE", 580 | "name": "userTimingVariable", 581 | "type": "STRING" 582 | }, 583 | { 584 | "mode": "NULLABLE", 585 | "name": "userTimingCategory", 586 | "type": "STRING" 587 | }, 588 | { 589 | "mode": "NULLABLE", 590 | "name": "userTimingLabel", 591 | "type": "STRING" 592 | } 593 | ], 594 | "mode": "NULLABLE", 595 | "name": "latencyTracking", 596 | "type": "RECORD" 597 | }, 598 | { 599 | "fields": [ 600 | { 601 | "mode": "NULLABLE", 602 | "name": "sourcePropertyDisplayName", 603 | "type": "STRING" 604 | }, 605 | { 606 | "mode": "NULLABLE", 607 | "name": "sourcePropertyTrackingId", 608 | "type": "STRING" 609 | } 610 | ], 611 | "mode": "NULLABLE", 612 | "name": "sourcePropertyInfo", 613 | "type": "RECORD" 614 | }, 615 | { 616 | "mode": "NULLABLE", 617 | "name": "dataSource", 618 | "type": "STRING" 619 | } 620 | ], 621 | "mode": "REPEATED", 622 | "name": "hits", 623 | "type": "RECORD" 624 | }, 625 | { 626 | "mode": "NULLABLE", 627 | "name": "fullVisitorId", 628 | "type": "STRING" 629 | }, 630 | { 631 | "mode": "NULLABLE", 632 | "name": "userId", 633 | "type": "STRING" 634 | }, 635 | { 636 | "mode": "NULLABLE", 637 | "name": "channelGrouping", 638 | "type": "STRING" 639 | }, 640 | { 641 | "mode": "NULLABLE", 642 | "name": "socialEngagementType", 643 | "type": "STRING" 644 | } 645 | ] -------------------------------------------------------------------------------- /airflow/deploy/gcp_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper functions for interacting with Google Cloud Platform APIs. 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import logging 23 | import time 24 | 25 | from apiclient import discovery 26 | import apiclient.errors 27 | 28 | 29 | # Cap exponential backoff for polling APIs 30 | MAX_POLL_INTERVAL_SECS = 15 31 | 32 | 33 | def empty_gcs_bucket(bucket_name, credentials): 34 | """Attempts to delete all objects in a bucket. 35 | 36 | If concurrent object creations occur while the bucket is being 37 | emptied, those objects may not be deleted and may cause bucket 38 | deletion to fail. 39 | 40 | Args: 41 | bucket_name: a string specifying the bucket to empty 42 | credentials: oauth2client.Credentials to be used for 43 | authentication 44 | """ 45 | logging.info("Emptying GCS bucket: %s", bucket_name) 46 | service = discovery.build('storage', 'v1', credentials=credentials) 47 | response = service.objects().list(bucket=bucket_name).execute() 48 | _delete_resources(bucket_name, response.get('items', []), credentials) 49 | while 'nextPageToken' in response: 50 | response = service.objects().list( 51 | bucket=bucket_name, pageToken=response['nextPageToken']).execute() 52 | _delete_resources(bucket_name, response.get('items', []), credentials) 53 | 54 | 55 | def _delete_resources(bucket_name, resources, credentials): 56 | """Deletes the specified resources from the given bucket. 57 | 58 | Resources are represented as described in 59 | https://cloud.google.com/storage/docs/json_api/v1/objects#resource 60 | 61 | Args: 62 | bucket_name: a string specifying the bucket from which to 63 | delete 64 | resources: a list of resources 65 | credentials: oauth2client.Credentials to be used for 66 | authentication 67 | """ 68 | logging.info("Deleting %s resources.", len(resources)) 69 | service = discovery.build('storage', 'v1', credentials=credentials) 70 | for r in resources: 71 | try: 72 | service.objects().delete( 73 | bucket=bucket_name, 74 | object=r['name']).execute() 75 | except apiclient.errors.HttpError as e: 76 | logging.warning('Error deleting %s: %s', r, e) 77 | 78 | 79 | def create_gcs_bucket(bucket_name, location, project, credentials): 80 | """Attempts to create a Google Cloud Storage bucket. 81 | 82 | Args: 83 | bucket_name: a string specifying the name of the bucket to 84 | create 85 | location: a string specifying the location where the bucket 86 | should be allocated. See 87 | https://cloud.google.com/storage/docs/bucket-locations 88 | for an authoritative list of values. 89 | project: a string specifying the GCP project in which to create 90 | the bucket 91 | credentials: oauth2client.Credentials to be used for 92 | authentication 93 | 94 | Returns: 95 | True if a bucket named bucket_name was successfully created, False 96 | otherwise. 
Note that False will be returned if there was already a 97 | bucket with the provided bucket_name. 98 | """ 99 | service = discovery.build('storage', 'v1', credentials=credentials) 100 | body = {'name': bucket_name, 'location': location} 101 | try: 102 | service.buckets().insert(project=project, body=body).execute() 103 | logging.info('Created GCS bucket gs://%s', bucket_name) 104 | return True 105 | except apiclient.errors.HttpError as e: 106 | logging.warn('Failed to create GCS bucket gs://%s. %s', bucket_name, e) 107 | return False 108 | 109 | 110 | def delete_gcs_bucket(bucket_name, credentials, force=False): 111 | """Attempts to delete a Google Cloud Storage bucket. 112 | 113 | The REST API doesn't allow for deletion of non-empty buckets; 114 | use force=True to attempt to empty the bucket prior to deletion. 115 | If concurrent object creations occur while the bucket is being 116 | emptied, those objects may not be deleted and may cause bucket 117 | deletion to fail. 118 | 119 | Args: 120 | bucket_name: a string specifying the name of the bucket to 121 | delete 122 | credentials: oauth2client.Credentials to be used for 123 | authentication 124 | force: a boolean specifying whether or not to attempt to empty 125 | the bucket prior to deletion. 126 | 127 | Returns: 128 | True if a bucket named bucket_name was successfully deleted, False 129 | otherwise. 130 | """ 131 | if force: 132 | empty_gcs_bucket(bucket_name, credentials) 133 | 134 | service = discovery.build('storage', 'v1', credentials=credentials) 135 | try: 136 | resp = service.buckets().delete(bucket=bucket_name).execute() 137 | # An empty response indicates a successful deletion. 138 | # https://cloud.google.com/storage/docs/json_api/v1/buckets/delete 139 | return not bool(resp) 140 | except apiclient.errors.HttpError as e: 141 | logging.warn('Error deleting GCS bucket %s: %s', bucket_name, e) 142 | return False 143 | 144 | 145 | def create_gke_cluster(cluster_name, project, zone, credentials): 146 | """Tries to create a GKE cluster. 147 | 148 | TODO(wwlian): Expose more of the node pool's configuration as 149 | needed. 150 | 151 | Args: 152 | cluster_name: string specifying the desired cluster name 153 | project: a string specifying the GCP project in which to create 154 | the cluster 155 | zone: string specifying the GCE zone in which to create the 156 | cluster 157 | credentials: oauth2client.Credentials to be used for 158 | authentication 159 | 160 | Returns: 161 | True if a new cluster with the provided name has been created, 162 | False otherwise. Note that False will be returned if a cluster 163 | with the provided cluster_name already existed. 164 | """ 165 | service = discovery.build('container', 'v1', credentials=credentials) 166 | cluster_body = { 167 | 'name': cluster_name, 168 | 'zone': zone, 169 | 'network': 'default', 170 | 'loggingService': 'logging.googleapis.com', 171 | 'monitoringService': 'none', 172 | 'subnetwork': 'default', 173 | 'nodePools': [{ 174 | 'initialNodeCount': 3, 175 | 'config': { 176 | 'machineType': 'n1-standard-1', 177 | 'imageType': 'GCI', 178 | 'diskSizeGb': 100, 179 | 'oauthScopes': [ 180 | 'https://www.googleapis.com/auth/compute', 181 | 'https://www.googleapis.com/auth/devstorage.read_write', 182 | 'https://www.googleapis.com/auth/sqlservice.admin', 183 | 'https://www.googleapis.com/auth/logging.write', 184 | 'https://www.googleapis.com/auth/servicecontrol', 185 | 'https://www.googleapis.com/auth/service.management.' 
186 | 'readonly',
187 | 'https://www.googleapis.com/auth/trace.append',
188 | 'https://www.googleapis.com/auth/source.read_only',
189 | 'https://www.googleapis.com/auth/cloud-platform'
190 | ]
191 | },
192 | 'autoscaling': {
193 | 'enabled': False
194 | },
195 | 'management': {
196 | 'autoUpgrade': False,
197 | 'autoRepair': False,
198 | 'upgradeOptions': {}
199 | },
200 | 'name': 'default-pool'
201 | }],
202 | 'masterAuth': {
203 | 'username': 'admin'
204 | }
205 | }
206 |
207 | request = service.projects().zones().clusters().create(
208 | projectId=project, zone=zone, body={'cluster': cluster_body})
209 |
210 | logging.info('Waiting for GKE cluster creation: %s', cluster_name)
211 | if not _wait_for_operation(request,
212 | _gke_op_poller_factory(service, project, zone)):
213 | logging.warn('GKE cluster creation failed: %s', cluster_name)
214 | return False
215 |
216 | # Verify creation by trying to retrieve cluster info.
217 | request = service.projects().zones().clusters().get(
218 | projectId=project, zone=zone, clusterId=cluster_name)
219 | try:
220 | request.execute()
221 | logging.info('Created GKE cluster: %s', cluster_name)
222 | return True
223 | except apiclient.errors.HttpError as e:
224 | logging.warn(str(e))
225 | return False
226 |
227 |
228 | def delete_gke_cluster(cluster_name, project, zone, credentials):
229 | """Attempts to delete a GKE cluster.
230 |
231 | Args:
232 | cluster_name: A string specifying the cluster to delete
233 | project: a string specifying the GCP project in which the
234 | cluster resides
235 | zone: The zone from which to delete the cluster
236 | credentials: oauth2client.Credentials to be used for
237 | authentication
238 |
239 | Returns:
240 | True if the specified cluster was successfully deleted from the
241 | specified zone; False otherwise.
242 | """
243 | service = discovery.build('container', 'v1', credentials=credentials)
244 |
245 | # If the cluster is in the process of being provisioned, we have to wait
246 | # until it is up and running before we can initiate deletion.
247 | request = service.projects().zones().clusters().get(
248 | projectId=project, zone=zone, clusterId=cluster_name)
249 | while True:
250 | try:
251 | cluster = request.execute()
252 | except apiclient.errors.HttpError:
253 | # No such cluster; this will get caught when we try to delete
254 | # it.
255 | break
256 | if cluster['status'] == 'RUNNING':
257 | break
258 |
259 | request = service.projects().zones().clusters().delete(
260 | projectId=project, zone=zone, clusterId=cluster_name)
261 |
262 | if not _wait_for_operation(
263 | request, _gke_op_poller_factory(service, project, zone)):
264 | return False
265 |
266 | # Verify deletion by trying to retrieve cluster info.
267 | request = service.projects().zones().clusters().get(
268 | projectId=project, zone=zone, clusterId=cluster_name)
269 | try:
270 | request.execute()
271 | return False
272 | except apiclient.errors.HttpError as e:
273 | return e.resp['status'] == '404'
274 |
275 |
276 | def create_sql_instance(instance_name, db_region, db_tier, project,
277 | credentials):
278 | """Creates a Cloud SQL instance and sets its root password.
279 |
280 | If the instance already exists, the creation step is skipped, but
281 | the root password will still be reset.
282 |
283 | Args:
284 | instance_name: A string specifying the name for the new instance
285 | db_region: A string specifying the region in which the instance
286 | should be created
287 | db_tier: A string specifying the database tier to create.
For a
288 | list of valid tiers and the regions in which they are
289 | available, use 'gcloud sql tiers list'.
290 | project: a string specifying the GCP project in which to create
291 | the instance
292 | credentials: oauth2client.Credentials to be used for authentication
293 |
294 | Returns:
295 | True if the Cloud SQL instance was successfully created, and its
296 | root password was successfully set; False otherwise.
297 | """
298 | service = discovery.build('sqladmin', 'v1beta4', credentials=credentials)
299 | request = service.instances().insert(
300 | project=project,
301 | body={
302 | 'name': instance_name,
303 | 'region': db_region,
304 | 'settings': {
305 | 'tier': db_tier,
306 | 'activationPolicy': 'ALWAYS'
307 | }
308 | }
309 | )
310 |
311 | logging.info('Waiting for Cloud SQL instance creation: %s', instance_name)
312 | if not _wait_for_operation(request,
313 | _cloud_sql_op_poller_factory(service, project)):
314 | return False
315 |
316 | # Verify creation by trying to retrieve instance info.
317 | request = service.instances().get(project=project,
318 | instance=instance_name)
319 | try:
320 | request.execute()
321 | return True
322 | except apiclient.errors.HttpError:
323 | return False
324 |
325 |
326 | def set_sql_root_password(root_pw, instance_name, project, credentials):
327 | """Attempts to set the root SQL password in a Cloud SQL instance.
328 |
329 | Args:
330 | root_pw: A string specifying the root password to set in the
331 | Cloud SQL instance.
332 | instance_name: A string specifying the name of the Cloud SQL
333 | instance
334 | project: a string specifying the GCP project in which the
335 | instance resides
336 | credentials: oauth2client.Credentials to be used for
337 | authentication
338 |
339 | Returns:
340 | True if the instance's root password was successfully set; False
341 | otherwise.
342 | """
343 | service = discovery.build('sqladmin', 'v1beta4', credentials=credentials)
344 | request = service.users().update(
345 | project=project, instance=instance_name, host='%', name='root',
346 | body={'password': root_pw})
347 |
348 | logging.info('Waiting for Cloud SQL root password set: %s', instance_name)
349 | return _wait_for_operation(request,
350 | _cloud_sql_op_poller_factory(service, project))
351 |
352 |
353 | def delete_sql_instance(instance_name, project, credentials):
354 | """Attempts to delete a Google Cloud SQL instance.
355 |
356 | Args:
357 | instance_name: A string specifying the name of the instance to delete
358 | project: a string specifying the GCP project in which the
359 | instance resides
360 | credentials: oauth2client.Credentials to be used for
361 | authentication
362 |
363 | Returns:
364 | True if this attempt to delete the instance succeeded, False
365 | otherwise. Note that this means that this function may return
366 | False if the instance did not exist in the first place or was
367 | deleted concurrently.
368 | """
369 | service = discovery.build('sqladmin', 'v1beta4', credentials=credentials)
370 |
371 | # If the instance is in the process of being provisioned, we have to
372 | # wait until it is up and running before we can initiate deletion.
373 | request = service.instances().get(project=project, instance=instance_name)
374 | while True:
375 | try:
376 | instance = request.execute()
377 | except apiclient.errors.HttpError:
378 | # No such instance; this will get caught when we try to delete
379 | # it.
380 | break
381 | if instance['state'] == 'RUNNABLE':
382 | break
383 |
384 | request = service.instances().delete(project=project,
385 | instance=instance_name)
386 |
387 | if not _wait_for_operation(
388 | request, _cloud_sql_op_poller_factory(service, project)):
389 | return False
390 |
391 | # Verify deletion by trying to retrieve instance info.
392 | request = service.instances().get(project=project,
393 | instance=instance_name)
394 | try:
395 | request.execute()
396 | return False
397 | except apiclient.errors.HttpError as e:
398 | return e.resp['status'] == '404'
399 |
400 |
401 | def _wait_for_operation(request, op_poller):
402 | """Executes a request and waits for its operation to finish.
403 |
404 | Args:
405 | request: An apiclient.http.HttpRequest whose response is expected
406 | to be an Operation.
407 | op_poller: A function whose first argument is expected to be an
408 | Operation. When called on an operation, op_poller should
409 | poll the API and return an updated version of the same
410 | Operation.
411 |
412 | Returns:
413 | True if the request executed without raising an HttpError, False
414 | otherwise.
415 | """
416 | try:
417 | logging.debug('Executing synchronous request: %s', request.to_json())
418 | start_time = time.time()
419 | op = request.execute()
420 | except apiclient.errors.HttpError as e:
421 | logging.warn(str(e))
422 | return False
423 |
424 | poll_interval_secs = 1
425 | while op['status'] != 'DONE':
426 | time.sleep(poll_interval_secs)
427 | logging.debug('Polling Operation: %s', op)
428 | op = op_poller(op)
429 |
430 | # Exponential backoff up to maximum.
431 | poll_interval_secs = min(MAX_POLL_INTERVAL_SECS,
432 | 2 * poll_interval_secs)
433 | duration = time.time() - start_time
434 | logging.debug('Operation completed in %s seconds: %s',
435 | duration, request.to_json())
436 |
437 | return True
438 |
439 |
440 | def _cloud_sql_op_poller_factory(service, project):
441 | """Creates a function that polls a Cloud SQL operation.
442 |
443 | The value returned by a call to this function can be provided as the
444 | op_poller argument to _wait_for_operation.
445 |
446 | Args:
447 | service: an apiclient.discovery.Resource object for interacting
448 | with the Cloud SQL API. This is usually the same object used
449 | to create the request that spawned the operation that will
450 | be waited on.
451 | project: a string specifying the GCP project in which the
452 | operation will be executing
453 |
454 | Returns:
455 | a function that can be used as the second argument to
456 | _wait_for_operation.
457 | """
458 | def op_poller(op):
459 | return (service.operations()
460 | .get(project=project, operation=op['name']).execute())
461 | return op_poller
462 |
463 |
464 | def _gke_op_poller_factory(service, project, zone):
465 | """Creates a function that polls a GKE operation.
466 |
467 | The value returned by a call to this function can be provided as the
468 | op_poller argument to _wait_for_operation.
469 |
470 | Args:
471 | service: an apiclient.discovery.Resource object for interacting
472 | with the GKE API. This is usually the same object used to
473 | create the request that spawned the operation that will be
474 | waited on.
475 | project: a string specifying the GCP project in which the
476 | operation will be executing
477 | zone: a string specifying the GCE zone in which the operation
478 | will be running
479 |
480 | Returns:
481 | a function that can be used as the second argument to
482 | _wait_for_operation.
483 | """ 484 | def op_poller(op): 485 | return (service.projects().zones().operations() 486 | .get(projectId=project, zone=zone, operationId=op['name']) 487 | .execute()) 488 | return op_poller 489 | -------------------------------------------------------------------------------- /notebooks/Part1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Exploring the model code\n", 8 | "\n", 9 | "The model code is contained in the `wals_ml_engine` directory. The code's \n", 10 | "high-level functionality is implemented by the following files:\n", 11 | "\n", 12 | "`mltrain.sh`\n", 13 | " * Launches various types of ML Engine jobs. This shell\n", 14 | " script accepts arguments for the location of the dataset file, the\n", 15 | " delimiter used to separate values in the file, and whether the data\n", 16 | " file has a header line. It's a best practice to create a script that\n", 17 | " automatically configures and executes ML Engine jobs.\n", 18 | "\n", 19 | "`task.py`\n", 20 | " * Parses the arguments for the ML Engine job and executes\n", 21 | " training.\n", 22 | "\n", 23 | "`model.py`\n", 24 | " * Loads the dataset.\n", 25 | " * Creates two sparse matrices from the data, one for training and one for\n", 26 | " testing. Executes WALS on the training sparse matrix of ratings.\n", 27 | "\n", 28 | "`wals.py`\n", 29 | " * Creates the WALS model.\n", 30 | " * Executes the WALS algorithm.\n", 31 | " * Calculates the root-mean-square error [RMSE](https://wikipedia.org/wiki/Root-mean-square_deviation) for a set of row/column factors and a ratings matrix.\n", 32 | "\n", 33 | "## Let's run some code\n", 34 | "\n", 35 | "Now we will run some of the code from `model.py` in this notebook, to help you understand it.\n", 36 | "\n", 37 | "**Execute the following cell to import necessary python modules:**" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "import datetime\n", 47 | "import numpy as np\n", 48 | "import pandas as pd\n", 49 | "from scipy.sparse import coo_matrix\n", 50 | "import tensorflow as tf" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "## The MovieLens data set\n", 58 | "\n", 59 | "**Execute the following cell to download the MovieLens 100k data set:**" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "!curl -O 'http://files.grouplens.org/datasets/movielens/ml-100k.zip'\n", 69 | "!unzip ml-100k.zip\n", 70 | "!mkdir -p ../data\n", 71 | "!cp ml-100k/u.data ../data/" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "**Let's take a look at the MovieLens data.** It consists of rows of individual user ratings of movies. Each row consists of four elements: a unique user id, a unique 'item id' corresponding to a movie, a rating from 1-5, and a timestamp of when the rating was made. 
Each row of data maps to a single entry in the ratings matrix:" 79 | ] 80 | }, 81 | { 82 | "attachments": { 83 | "recommendation-system-tensorflow-movielens-rating-matrix.png": { 84 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAuAAAAJNCAYAAABwXMA5AAAABmJLR0QA/wD/AP+gvaeTAAAn7ElEQVR42u3dC3RV1Z0/8F8gPAIIqEABDWAABRUEi+igVgtVqxSdTqt20GmplVZQsQ7YGUf9j9KlXRVU7Mj4qM9aK51qrQiKSlhtXbLwMX8dcPqnYhGBohZsgcpLJPnvc0YzxBcPA5wkn89ax3POvjeX7e+e3Hyz7747EQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANQ45ZRTfpW2m1QCoHEqVQKA3W5OSUnJn7cK5DPS7pVHH330u0oDIIADUMdS0P6RKgA0XiVKALB7ZVNQ0m55CuIXnHzyyXen0fBvbH1706ZN+z7yyCO/Hz58+Pjq6urzUtP+aVucjic/9thjd75/v3T77LR7MbU3S/uvp21zOv7XtL8zPeb1aX9W2v6ativSv3X3Vv/+F9P9rk736ZtON6Xtd2kbme7zmmcHYNdrogQAe04K1KPSbmbabkwBuCTbUvhemIL5lVn4TiH5oiyQZyE6HV+X2r+69den+3wr7dam7XNp+0G6z9QmTZr8Mu1XpPNj0/6WtL9txIgR3bL7H3/88W3S7pdp+/m77757UNofmbZb0nGVZwNg9zAFBaBgTj/99LJ169ZNSOH6yymQP/Fe86tp5PrQFKhHp+MHtgrgL6YQf8V7pwvSfS5MbVmQv/q9tpdS27iqqqosoP+0devWndLtZSmkP5K+btl793lZ1QEEcIBG6+233+6TgnartD2ewvMHb/7D1ifpPv/3A7e/ntpe2Oq8Om1vpNDdKTuZOXPmq2kU/aF0+Hx67CdT++x0/wdSYH9D5QEEcIBGKY1ON0nBOAvX/VJgfumT7pvut/nDTR9uS4/1/pTD6jTy/Xdf+tKXBqdR8ZNT+9fT9oMUxr+QQvgzqg8ggAM0eCkwv5Myd9OtAvj/27Jly4YUkE9Npy/tin9zxowZz6Zdtl2Vwve8tD8zbQI4gAAO0PClEehXUwg/9otf/GKPZs2avd2yZcu/pDng16b2y9JKJ39Ntz2eQnlZCuVDUluzT7OMYXq8Q9PjfSUdziwtLV2RHvPgFPR7pfM7PBMAu4dVUAD2vCxQb0wh+79TIF65cePG3ilkX5navvfeMoQLUkienW7/Sgrgv/80/1B6/LfTYwxOhzPSyidZ8L8tPe6P0rSU2z0NAAAAAAAAAAAAUC/4U/QANBrZGo2qkH/w189/2IN8CBMAAARwAAAQwAEAAAEcAAAEcAAAQAAHAAABHAAABHAAAEAABwAAARwAABDAAQBAAAcAAAEcAAAQwAEAoKEpVQKobd26ddVF7t+yZcvyfXl5eeinfuofXuf0c1tat25d4qovFiPgAAAggAMAgAAOAAAI4AAAIIADAAACOAAACOAAACCAAwAAAjgAAAjgAACAAA4AAAI4AAAI4AAA23DttdfGHXfcoY+NqJ/sGqVKALvOb37zm3jwwQdj5cqV0aVLlxg5cmQMGjSoMP1bsGBBzJgxIxYvXhyrV6+OcePGxbHHHlvIWs6aNSuefvrpWLp0aZSUlERFRUVez169ehWmj3PmzMn7+cYbb0R1dXX+nJ922mlx9NFHF7Km2XN/zz33xFFHHRXjx48vTL+y75lp06bVamvVqlXeV/asQw89NPbaa69Cfx9sTx+L8HqyPf2sT6/RCOBQCC+++GJMnTo1zjrrrBg4cGD89re/jcmTJ8c111yTv9gXwaZNm6JHjx4xbNiwmDRpUqHr+dRTT8VnP/vZOP300/MfmA899FBMnDgxrr/++ujQoUMh+ti8efMYPnx4dOrUKT9/5pln4sYbb4y2bdtGv379ClXPV199Nf/B3rNnz0I+3926dYsLL7yw5rxJE2/YFsEpp5xS+O+D7eljEV5Ptqef9ek1GgEcCmH69Ol58M5GfjJnn312zJ8/Pw892ShGEWSj8UUakf8kV199da3zbKRq1KhReU2HDh1aiD4ec8wxtc779u0b8+bNi4ULFxYqgG/cuDGmTJkS5557bjz55JOFfL5btGiRBw+KJZs2se+++8a3vvWtwn4fbE8fi/B6sj39rE+v0ewYQwqwC2Rvuy5atCj69+9fqz07f/nllxWoDmzZsiWvczaqVkSbN2+OuXPnxtq1aws3+n3XXXflb38X+Qd7NjUgC0SjR4+OH/7wh7Fs2TIXfT1U5O+D+vR6QsNjBBx2gWyEMdvatWtXqz07z+bx8endeeed0bVr1xgwYECh+pUFxQkTJkRVVVX+VvzYsWOjT58+helfFoaykchs9K2oDjjggDx477fffrFmzZp4+OGH4/LLL4/rrruuMNONqN/fB/Xl9QQBHKAw7r333njppZfiqquuitLSYr2Mde7cOZ/rv379+njuuefi5ptvzn/xKsLo36pVq/JVFy677LJ8ikdRHX744bXODz744Hw+eGVlZZx55pn16lrN5j9nH0bMgl337t1r2rNg+sgjj+THI0aMqDXH/bXXXss/Q5JdS0ceeWRh/9+ef/75WvOSs5B93HHHFer74JP6WKTXk+3pJwI4sA0tW7bMt2z0bmvZefv27RXoU7j77rvzUdwrr7wyX12haJo1axbl5eX58UEHHZSPBM6cObMQATxbSSGbCnDppZfWtGVvu2eyYHvTTTdFx44dC1fTsrKyfDT89ddfr3fX6xNPPJHPf84+97F1AH/33Xfjtttuy49PPvnkfJT4fdk7FNltRxxxRKED+CGHHBI33HBDzfnWr21F+T74pD4W6fVkW/1EAAe2Q/ap+t69e+cf6MlWA3hfdn7ggQcq0E7IgmIWSl544YV8pKqI4fujZHNgmzZtWoi+ZJ9ByFZ52Nrtt9+ej/pl86333nvvQtZww4YNsXz58sKsHsT//mKUbUX+Pvi4Phbt9WRHaokADnyCU089NV9yMJu/mr2lnq0Jnn2w7LzzzitUsHnzzTdrzrP1ypcsWRJt2rQp3FzbbEnHZ599Ni644IJ8aa6sn5lspKgoo0W33nprPsKX1S7rYzbyma3je9FFFxWif9m7Mu+PSm7dlo2+frB9Tz/X2ZSNbDQ+e9coW1Eo+0zFCSecUO9eB6644oqPbM9qno0If5STTjop3+qron8f1JfXk/r2Go0ADoWQBYjzzz+/5o+KZCMs2YeSijSK98orr+Rr377v/vvvz7dsGa4xY8YUqp7Z28TZKNoH18LN1vE944wzCtHH7G337Ll+66238jnWWai95JJLYvDgwb4hdkA2H/q+++7LP
7CcBY1srfLsl9nsQ3IUX334PqgPryf17TWaHVOiBFDbunXrqovcv/eXYyvSiKV+6md9qWP2FzWJ/MORrvfG08/WrVvLe0UbaFACAAAQwAEAQAAHAAAEcAAAEMABAAABHAAABHAAABDAAQAAARwAAARwAABAAAcAAAEcAAAEcAAAQAAHAAABHAAA2FklSgBAY1GdqEL64Z+oAuw5RsABAEAABwAAARwAABDAAQBAAAcAAARwAAAQwAEAQAAHAAAEcAAAEMABAAABHAAABHAAABDAAQAAARwAAARwAABAAAcAAAEcAAAQwAEAQAAHAAABHAAAEMABAKCeKlWCxmPdunXVRe5f69atS4rQj9fWVxW6TovmzoneQ4YW/nrTT/2kuIr+Ouear1vdWzUpcdUXixFwAAAQwAEAQAAHAAAEcAAAEMABAAABHAAABHAAABDAAQAAARwAAARwAABAAAcAAAEcAAAEcAAAQAAHAAABHAAA2GmlSsCO+s1vfhMPPvhgrFy5Mrp06RIjR46MQYMGKQzUE3Mrn4hfz5oZK998PaqrqqJTl/3ixL/9Sgw6+nOKw6dSOeNX8cDdt8fAo46Ob0+4VEFAAKcuvPjiizF16tQ466yzYuDAgfHb3/42Jk+eHNdcc01UVFQoENQDzZo3j6HDT40On+mcn78wb27cOWVStNmrbfTpP0CB2CnLXv1DVD7yq+jes7diwDaYgsIOmT59eh68TzvttOjWrVucffbZ+X7GjBmKA/XEEcceH0cdPyx69T0k307/5uhov8++8YeFv1McdsqmjRvjjvRL3NdGj4m92rVXEBDAqSvV1dWxaNGi6N+/f6327Pzll19WIKiH3t28Of5z7lPx17VrjH6z0/7jrtvioEP6R/9BRyoGbAdTUNhuG9MIR7a1a9euVnt2vnr1agWCemTFstfi6vEXRlWaA55NSfn6+d+Nnn0OVhh2WPYLXPbuyb9ce6NigAAOwMfp1LlrXHbdv8XG9evjv557Ju799xtjr7btjIKzQ/68amX8/PZb4oLLr4rmLVooCAjg1LWWLVvm25o1a2q1Z+ft25vzB/Xqxb9Zs+ha3j0/rjiob6xY+lrMmTldAGeHLF38Sj596Yf//I81bdl0xcz5Z5wa3596e+zTsZNCgQDOziopKYnevXvH/PnzY/jw4TXt2fmBBx6oQFCPvfvu5mjapKlCsEP69hsQV9wwtVbbtB/fnP+C99VR50a7vfdRJBDA+bROPfXUfMnBhx9+OA4//PB8TfClS5fGeeedpzhQT9x3y01ppPuw2KdDx3z1imwZwoXzX4xzvnuJ4rBDWpSV1byTUtPWsiz/XMEH2wEBnJ00YMCAOP/88/M/xDNt2rT8D/FMmDDBGuBQn1740+jk9Pvvjb+8tSpatGgZXcq7xXe+d1kMGPw3igMggFNExx13XL4B9dOZ3/qOIrDLnP8v/6oIsA3WAQcAAAEcAAAEcAAAQAAHAAABHAAAEMABAEAABwAAARwAABDAAQBAAAcAAARwAAAQwAEAQAAHAAB2ixIlaDwWLlxYXeT+9enTpxDX4+zZs6tdLdAwDRs2TBGSyspKRWgk9u1WEQMPrJD3CqZUCaC23kOGFrp/i+bOKXwf9VM/8TrnmgcBnKS8vFwRAAD2MHPAAQBAAAcAAAEcAAAQwAEAQAAHAAAEcAAAEMABAEAABwAABHAAABDAAQAAARwAAARwAAAQwAEAgN2iVAnYEQsWLIgZM2bE4sWLY/Xq1TFu3Lg49thjFeYDHntgWkyf9tNabWWtWsf1P/m54uyEhfNfjMoZD8fSxa/E2tV/iW+OGx+DP/d5hQFAAKfh27RpU/To0SOGDRsWkyZNUpBPsF+3HjFq3D/WnJc08YbTznrnnU1RfkBFHPOFE+OWa69WEAAEcBqPQYMG5Rvb1rxFi9i/R4VC1IH+g47MNwAQwIGP9celS2L8N74Wpc2aRY9eB8ZpZ309upZ3VxgAaOS8Jw67QHlFz/j70WPjwismxtljLoz1696OyZd9L/68aqXiAEAjZwQcdoFDDz+i1nnvg/vF/7lgdDxd+USMOPMsBQKARswIOOwGLcvKovP+5fGnFX9UDAAQwIFdbeOGDfHGsqXRbu99FAMAGjlTUNghG1KQfPPNN2vOV65cGUuWLIk2bdpEhw4dFOg9P5k6JQ4ecHjs27FTrF2zOp58+JexceOGOPaELyrOTv4Cs+rN12vO31r5p1i+ZHG0arNX7NOhowIBIIDTcL3yyisxceLEmvP7778/34YOHRpjxoxRoPdka34/9NO78z8a0zqFxO49e8c//eC6+Mx++yvOTliy6Pdx48TLa86n339vvg0ZdmL8w5hxCgSAAE7D1a9fv/jFL36hENsgFNatPv0HxM0PzFAIABoEc8ABAEAABwAAARwAABDAAQBAAAcAAARwAAAQwAEAQAAHAAAEcAAAEMABAAABHAAABHAAABDAAQCA3aJUCRqPZcuWKQIAwB5mBBwAAHajEiVoPNatW1dd5P61bt26ENfj7Nmzq10t0DANGzZMEZLKykpFaCT27VYRAw+skPcEcATwYgfw19ZXFbpOi+bOid5Dhhb+etNP/SyibmV+7GWWbqh2LTWifnZv1cSFXzCmoAAAgAAOAAACOAAAIIADAIAADgAACOAAACCAAwCAAA4AAAjgAAAggAMAAAI4AAAI4AAA0JiVKgE7YtasWfH000/H0qVLo6SkJCoqKmLkyJHRq1cvxQEAEMCpa0899VR89rOfjdNPPz0P4A899FBMnDgxrr/++ujQoYMCAcBOmlv5RPx61sxY+ebrUV1VFZ267Bcn/u1XYtDRn1McAZzG7Oqrr651no18jxo1KubPnx9Dhw5VIADYSc2aN4+hw0+NDp/pnJ+/MG9u3DllUrTZq2306T9AgQRw+B9btmyJ6urqaNu2rWIAwKdwxLHH1zrv1feQFMKfjj8s/J0ALoDD/7rzzjuja9euMWCAFwYAqCvvbt4c//XcvPjr2jXCtwAO/+vee++Nl156Ka666qooLXUpAcCntWLZa3H1+AujKs0Bz6akfP3870bPPgcrjAAOEXfffXfMnTs3rrzyyujSpYuCAEAd6NS5a1x23b/FxvXr0wj4M3Hvv98Ye7VtZxRcAKcxy+Z733bbbfHCCy/kI9/CNwDUYTBr1iy6lnfPjysO6hsrlr4Wc2ZOF8AFcBqzqVOnxrPPPhsXXHBBbNq0KZYsWZK3t2/fPt8AgLrz7rubo2mTpgohgNOYZdNONqcPhkyaNKlWe7Yu+BlnnKFAALCT7rvlpjTSfVjs06FjbNq4MV+GcOH8F+Oc716iOAI4jdnPfvYzRQCAXRHK0vST6fffG395a1W0aNEyupR3i+9877IYMPhvFEcABwCgrp35re8oQiPRRAkAAEAABwAAARwAABDAAQBAAAcAAARwAAAQwBuz9AdsmqsCAEDjYR3wrZxyyikvlpSU/HTmzJmTt2qblnZvP/roo+e+d/7F6urqq9P9+qbTTWn7XdpGpttfS/uS4cOHj0+3n5eO90/b4nQ8
+bHHHrvz/cdLt89ObQvT1iadfuntt99ekPaf38bjAgDQQBgB3wHHH398Fpp/mbafv/vuuwel/ZFpuyUdV2W3n3zyyVdm4TuF6IuaNm2aBekr0vF1qf2rH3io0U2aNHmqTZs2XTds2DBiW48LAEDDYQR8B7Ru3bpTCthlKTw/kka1l73X/HL2nzSVpGzdunUT0u1fTqPWT7x326tpZPvQFMJHp+MH3n+cdP5UGmW/473Td9KoeMXHPS4AAAJ4o5VC86tpNPuhdPh8CtZPptA8O4XpB1LgfiNNJemTjlul7fF02we/9A9bn6Sve2l7H1fVAQAaFlNQavvQlI8UhJtunZ3TCPXfpZHqz6fjF9JtX0/bohSaj0xtTd67f78UnEs+sPX6QADf+IF/5mMf11MCANCwGAGvHYz/lLYOH2g7MO2e27ptxowZz6Zdtl2VQvK8tD8zhefLt2zZsqGqqurUdP7Szvz7H/W4aXvGMwMAIIA3SClEZyuUjDnxxBOvHzJkyKpnn332gtR88PsBPM3VPjTd/pV0OLO0tHRFCtwHp8CdjW7f8cgjj6xPofnaNHJ9WbrfX9P9Hk+PV5buMyS1NUuj4D/6uH/3kx7XswIAIIA3WOlDlDe1atWqTwrB/53C94bUNC2F51+lcJzfnoLx22l1k8FZSE8rlOydmv6YQvaP0hzu27PbU8i+MoXwlen2sel0cgrRa9Pt/5WOJ33Sv7utx60ry5Yt8yQDAAjgxfHrX/86m5t97nvbh8yaNWtJ2g3/pMdIIXxq2k39uNtTqP7CzjwuAAANQ4kSNB5phL+6yP1LyzwW4nqcPXt2tasFGqZhw4YpQlJZWakIjcS+3Spi4IEV8p4AjgBe7AD+2vqqQtdp0dw50XvI0MJfb/qpn0XUrcyPvczSDdWupUbUz+6tmrjwC8YyhAAAIIADAIAADgAACOAAACCAAwAAAjgAAAjgAAAggAMAAAI4AAAI4AAAgAAOAAACOAAACOAAAMBuUaoE7Ig5c+bErFmz4o033ojq6uro0qVLnHbaaXH00UcrDgCAAE5da968eQwfPjw6deqUnz/zzDNx4403Rtu2baNfv34KBAAggFOXjjnmmFrnffv2jXnz5sXChQsFcAAAAZxdafPmzfHcc8/F2rVrhW8AAAGcXWXZsmUxYcKEqKqqyqekjB07Nvr06aMwAAACOLtC586dY/LkybF+/fp8BPzmm2+Odu3aGQUHABDA2RWaNWsW5eXl+fFBBx2Uj4jPnDlTAAcA2A7WAedTy+aCb9myRSEAALaDEXB2yK233pqPdHfo0CE2bdqUr4CyYMGCuOiiixQHAEAAp65l00+mTZsWb731VrRo0SKfinLJJZfE4MGDFQcAQACnrp1zzjmKAADwKZgDDgAAAjgAAAjgAACAAA4AAAI4AAAggAMAgAAOAAACOAAAIIADAIAADgAACOAAACCAAwCAAA4AAAjgAADQ0JQqQeOxbNkyRQAA2MOMgAMAwG5UogSNx7p166qL3L/WrVsX4nqcPXt2tasFGqZhw4YpQlJZWakIjcS+3Spi4IEV8p4AjgBe7AD+2vqqQtdp0dw50XvI0MJfb/qpn0XUrcyPvczSDdWupUbUz+6tmrjwC8YUFAAAEMABAEAABwAABHAAABDAAQAAARwAAARwAAAQwAEAAAEcAAAEcAAAQAAHAAABHAAABHAAAGC3KFUCdtaMGTPinnvuiaOOOirGjx+vIFBPLJz/YlTOeDiWLn4l1q7+S3xz3PgY/LnPKwyAAE6Rvfrqq3kA79mzp2JAPfPOO5ui/ICKOOYLJ8Yt116tIDR4jz0wLaZP+2mttrJWreP6n/xccRDAqR82btwYU6ZMiXPPPTeefPJJBYF6pv+gI/MNGpP9uvWIUeP+sea8pIlZuAjg1CN33XVXHHrooTFo0CABHIB6oXmLFrF/jwqFQACn/pk7d24sXLgwrr32WsUAoN7449IlMf4bX4vSZs2iR68D47Szvh5dy7srDHuE91/YbqtWrYo77rgjLrroomiRRhIAoD4or+gZfz96bFx4xcQ4e8yFsX7d2zH5su/Fn1etVBz2CCPgbLfFixfH2rVr49JLL61pq66uzvdnnnlm3HTTTdGxY0eFAqBQDj38iFrnvQ/uF//ngtHxdOUTMeLMsxQIAZzi6t+/f1x//fW12m6//fYoLS2NUaNGxd57761IABRey7Ky6Lx/efxpxR8VAwGcgr9gtWwZ5eXlH2pr3rz5h9qB4tq4YUOsevP1mvO3Vv4pli9ZHK3a7BX7dPAuFo3je+CNZUuj2wGW0kUAB2A3WLLo93HjxMtrzqfff2++DRl2YvzDmHEKRIPzk6lT4uABh8e+HTvF2jWr48mHf5mW1N0Qx57wRcVBAKf+2Xo+OFA/9Ok/IG5+YIZC0Ghka34/9NO787/82jq909O9Z+/4px9cF5/Zb3/FQQAHAKhr3tmhaCxDCAAAAjgAAAjgAACAAA4AAAI4AAAggAMAgAAOAAACOAAAIIADAIAADgAACOAAACCAAwCAAA4AAOwWpUrQeCxbtkwRAAD2MCPgAACwG5UoQeOxbt266iL3r3Xr1oW4HmfPnl3taoGGadiwYYqQVFZWKkIjsW+3ihh4YIW8J4AjgBc7gL+2vqrQdVo0d070HjK08NebfupnEXUr82Mvs3RDtWupEfWze6smLvyCMQUFAAAEcAAAEMABAAABHAAABHAAAEAABwAAARwAAARwAABAAAcAAAEcAAAQwAEAQAAHAIDGrFQJ2BEPPvhgTJs2rVZbq1at4p577lEcAAABnF2hW7duceGFF9acN2nijRQAAAGcXaZFixbRo0cPhQAAEMDZHZYuXRqjRo2KZs2aRa9evWLkyJFRXl6uMAAAAjh17YADDojRo0fHfvvtF2vWrImHH344Lr/88rjuuuuiQ4cOCgQAIIBTlw4//PBa5wcffHA+H7yysjLOPPNMBQIA2AafnuNTKSsry0fDX3/9dcUAABDA2dU2bNgQy5cvj7333lsxAAC2gyko7JCpU6fGgAEDomPHjvkc8OnTp8fGjRvjhBNOUBwAAAGcupat+X3ffffF6tWro02bNtGzZ8+45ppromvXrooDACCAU9fGjBmjCAAAn4I54AAAIIADAIAADgAACOAAACCAAwAAAjgAAAjgAAAggAMAAAI4AAAI4AAAgAAOAAACOAAACOAAAIAADgAAAjgAALDTSpSg8ahOVAEA2OngmKjCp2cEHAAABHAAABDAAQAAARwAAARwAABAAAcAAAEcAAAEcAAAQAAHAAABHAAAEMABAEAABwAAARwAABDAAQBAAAcAAARwAAAQwAEAAAEcAAAEcAAAEMABAAABHAAABHAAAEAABwAAARwAAARwAABAAAcAAAEcAADYCaVKwI76/ve/n+9PPPHEOPLIIxUEAEAAZ1eaN29evu/fv79iAADsoBIlaDyqk7p4nMcffzzf9+nTJ7p3766wANBYgmOiCgI4eyCAAwACODvPhzABAEAABwCAhsmHMNlh77zzzv9cPKWl0aSJ3+E
AAHaE9MQO+/KXv5xvjzzyiGIAAAjgAABQXD7J2ojU1SoopqAAQCMNjlZBEcDZMwEcABDA2XmGLwEAQAAHAICGyTKE7DB/ih4AQABnN/rRj36U77/97W8L4AAAAji72hFHHJHvO3furBgAADvIJ1kbEaugAACfKjhaBaVO+BAmAAAI4AAAIIADAAACOAAACOAAAIAADgAAAjgAAAjgAACAAA4AAAI4AAAggAMAgAAOAAACOAAAIIADAIAADgAACOAAACCAAwAAAjgAAAjgAAAggAMAAAI4AADUU6VK0HisX7++0P1r3bp1SRH68dr6quoi12nR3DnRe8jQwl9v+qmfRdStrCSIWLqh2rXUiPrZvVUTF37BGAEHAAABHAAABHAAAEAABwAAARwAABDAAQBAAAcAAAEcAAAQwAEAQAAHAAAEcAAAEMABAEAABwAABHAAABDAAQCAnVaqBNSlOXPmxKxZs+KNN96I6urq6NKlS5x22mlx9NFHKw4AgABOXWvevHkMHz48OnXqlJ8/88wzceONN0bbtm2jX79+CgQACOBKwEeZMGFCHHfccTFixIiatilTpkSLFi1izJgxH/t1xxxzTK3zvn37xrx582LhwoUCOACAAM6utHnz5njuuedi7dq1wjcAgADOrrJs2bJ8BL2qqiqfkjJ27Njo06ePwgAACODsrOeffz4mTZpUc56F7GzKSqZz584xefLkWL9+fT4CfvPNN0e7du2MggMACOB8nCZNPrxCZTai/b5DDjkkbrjhhprz9u3b1xw3a9YsysvL8+ODDjooHxGfOXOmAA4AIIDzcbJVS9asWVOrbcWKFdGzZ8/8uKysLN+2RzYXvGnTpooKABD+EA8f47DDDstXL8lCeLae96OPPhrLly/f5tfdeuutMXfu3Hj55ZdjwYIF8eMf/zjfvz89BQCgsTMCzkc66aST8sB98cUX50sPDhkyJAYPHrzNr8umn0ybNi3eeuut/OuyqSiXXHLJdn0tAIAATqOVrV6Srff9SWt+f5RzzjlH8QAAPoEpKAAAIIADAIAADgAACOAAACCAAwAAAjgAAAjgAAAggAMAAAI4AAAI4AAAgAAOAAACOAAACOAAAMBuUaIEjcfChQuri9y/Pn36FOJ6nD17drWrBRqmYcOGKUJSWVmpCI3Evt0qYuCBFfJewZQqAdTWe8jQQvdv0dw5he+jfuonXudc8yCAk5SXlysCAMAeZg44AAAI4AAAIIADAAACOAAACOAAAIAADgAAAjgAAAjgAACAAA4AAAI4AAAggAMAgAAOAAACOAAAsFuUKgF1adasWfH000/H0qVLo6SkJCoqKmLkyJHRq1cvxSmghfNfjMoZD8fSxa/E2tV/iW+OGx+DP/d5haEwKmf8Kh64+/YYeNTR8e0JlyoI0CAYAadOPfXUUzFw4MAYP358vmUmTpwYq1atUpwCeuedTVF+QPol6dtjFYPCWfbqH6LykV9F9569FQNoUIyA85EmTJgQxx13XIwYMaKmbcqUKdGiRYsYM2bMx37d1VdfXes8G/keNWpUzJ8/P4YOHaqwBdN/0JH5BkWzaePGuGPKpPja6DHx1BOzFARoUIyAs0tt2bIlqquro23btooBbLf/uOu2OOiQ/n5BBARw2FF33nlndO3aNQYMGKAYwHb5z7lPxR8W/i6+8o1vKQbQIJmCwk55/vnnY9KkSTXnY8eOzaesbO3ee++Nl156Ka666qooLXWpAdv251Ur4+e33xIXXH5VNE9T3gAEcBqNJk0+/OZIVVVVzfEhhxwSN9xwQ815+/bta9337rvvjrlz58aVV14ZXbp0UVBgu2Qr8vx17Zr44T//Y01bNo0tc/4Zp8b3p94e+3TspFCAAE7Dk83ZXrNmTa22FStWRM+ePfPjsrKyfPug7AflbbfdFi+88EI+8i18Azuib78BccUNU2u1TfvxzVHarFl8ddS50W7vfRQJEMBpmA477LB4/PHH81VQsjD+2GOPxfLly2sC+MeZOnVqPPvss3HBBRfEpk2bYsmSJXl7NkL+wVFy9ryNGzbEqjdfrzl/a+WfYvmSxdGqzV6xT4eOCsRu1yL9Yt+1vHvttpZl0ax58w+1AwjgNCgnnXRSHrgvvvjifOnBIUOGxODBg7f5ddm0k82bN9eaH545/fTT44wzzlDYglmy6Pdx48TLa86n339vvg0ZdmL8w5hxCgQAAji7S/M02pSt9/1Ja35/lJ/97GeKV4/06T8gbn5ghkJQaOf/y78qAtCgWIYQAAAEcAAAEMABAAABHAAABHAAAEAABwAAARwAAARwAABAAAcAAAEcAAAQwAEAQAAHAAABHAAAEMABAEAABwAAdlqJEjQeCxcurC5y//r06VOI63H27NnVrhZomIYNG6YISWVlpSI0Evt2q4iBB1bIewVTqgRQW+8hQwvdv0Vz5xS+j/qpn3idc82DAE5SXl6uCAAAe5g54AAAIIADAIAADgAACOAAACCAAwAAAjgAAAjgAAAggAMAAAI4AAAI4AAAgAAOAACFU6oE1KUFCxbEjBkzYvHixbF69eoYN25cHHvssQoDAPAeI+DUqU2bNkWPHj1i9OjRigEA8BGMgPORJkyYEMcdd1yMGDGipm3KlCnRokWLGDNmzMd+3aBBg/INAICPZgQcAAAEcAAAaJhMQWGnPP/88zFp0qSa87Fjx+ZTVgAAEMDZCU2afPjNkaqqqprjQw45JG644Yaa8/bt2ysaAIAAzs5q27ZtrFmzplbbihUromfPnvlxWVlZvgEAsGPMAecjHXbYYTFv3rw8hFdXV8ejjz4ay5cv3+bXbdiwIZYsWZJvmZUrV+bHq1atUlQAgDACzsc46aST8sB98cUX50sPDhkyJAYPHrzNr3vllVdi4sSJNef3339/vg0dOvQTly8EABDAadSaN2+eB+YdDc39+vWLX/ziFwoIAPAxTEEBAAABHAAABHAAAEAABwAAARwAABDAAQBAAAcAAAEcAAAQwAEAQAAHAAAEcAAAEMABAEAABwAABHAAABDAAQAAARwAAARwAACglhIlAKCxqE5UIf3wT1QB9hwj4AAAIIADAIAADgAACOAAACCAAwAAAjgAAAjgAAAggAMAAAI4AAAI4AAAgAAOAAACOAAACOAAAIAADgAAAjgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAN2v8HRS/dTmuhlL0AAAAASUVORK5CYII=" 85 | } 86 | }, 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "![recommendation-system-tensorflow-movielens-rating-matrix.png](attachment:recommendation-system-tensorflow-movielens-rating-matrix.png)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | 
"metadata": {}, 96 | "source": [ 97 | "Let's inspect the contents of the file:" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "!head ../data/u.data" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "This model ignores the timestamp. We will only make use of the **user id**, **item** (or movie) **id**, and **rating**." 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "## How the model preprocesses data\n", 121 | "\n", 122 | "The model code performs data preprocessing to create a sparse ratings matrix and prepare it for matrix factorization. \n", 123 | "\n", 124 | "**This involves the following steps:**\n", 125 | "\n", 126 | " * The model code loads rows of data from a delimited text file, in this case the MovieLens data. \n", 127 | "\n", 128 | "**Execute this cell to read the MovieLens data set into a Pandas dataframe:**" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "input_file = '../data/u.data'\n", 138 | "headers = ['user_id', 'item_id', 'rating', 'timestamp']\n", 139 | "header_row = None\n", 140 | "ratings_df = pd.read_csv(input_file,\n", 141 | " sep='\\t',\n", 142 | " names=headers,\n", 143 | " header=header_row,\n", 144 | " dtype={\n", 145 | " 'user_id': np.int32,\n", 146 | " 'item_id': np.int32,\n", 147 | " 'rating': np.float32,\n", 148 | " 'timestamp': np.int32,\n", 149 | " })" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | " * The code establishes a 0-indexed set of unique IDs for users and items. This guarantees that a unique ID corresponds to specific row and column indexes of the sparse ratings matrix.\n", 157 | " \n", 158 | "**Let's find the number of unique users and movies:**" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "np_users = ratings_df.user_id.as_matrix()\n", 168 | "np_items = ratings_df.item_id.as_matrix()\n", 169 | "unique_users = np.unique(np_users)\n", 170 | "unique_items = np.unique(np_items)\n", 171 | "\n", 172 | "n_users = unique_users.shape[0]\n", 173 | "n_items = unique_items.shape[0]" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "n_users" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "n_items" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "The MovieLens 100k data uses 1-based IDs where the lowest index of the unique set is 1. But the TensorFlow code assumes a starting index of zero. 
The code handles this **by converting the Pandas dataframe to a numpy array and subtracting one from each index:**" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "ratings = ratings_df.as_matrix(['user_id', 'item_id', 'rating'])\n", 208 | "ratings[:,0] -= 1\n", 209 | "ratings[:,1] -= 1" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "ratings" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | " * The model code randomly selects a test set of ratings. By default, 10% of the ratings are chosen for the test set. These ratings are removed from the training set and will be used to evaluate the predictive accuracy of the user and item factors." 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "TEST_SET_RATIO = 10\n", 235 | "\n", 236 | "test_set_size = len(ratings) / TEST_SET_RATIO\n", 237 | "test_set_idx = np.random.choice(xrange(len(ratings)),\n", 238 | " size=test_set_size, replace=False)\n", 239 | "test_set_idx = sorted(test_set_idx)\n", 240 | "\n", 241 | "ts_ratings = ratings[test_set_idx]\n", 242 | "tr_ratings = np.delete(ratings, test_set_idx, axis=0)" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "test_set_idx" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "ts_ratings" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "tr_ratings" 270 | ] 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "metadata": {}, 275 | "source": [ 276 | " * Finally, the code creates a `scipy` sparse matrix in coordinate form (`coo_matrix`) that includes the user and item indexes and ratings. The `coo_matrix` object acts as a wrapper for a sparse matrix. It also performs validation of the user and ratings indexes, checking for errors in preprocessing:" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "u_tr, i_tr, r_tr = zip(*tr_ratings)\n", 286 | "tr_sparse = coo_matrix((r_tr, (u_tr, i_tr)), shape=(n_users, n_items))\n", 287 | "\n", 288 | "u_ts, i_ts, r_ts = zip(*ts_ratings)\n", 289 | "test_sparse = coo_matrix((r_ts, (u_ts, i_ts)), shape=(n_users, n_items))" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "tr_sparse" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "test_sparse" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": {}, 313 | "source": [ 314 | "## How the WALS Algorithm is implemented in TensorFlow\n", 315 | "\n", 316 | "You can read the [solution online](https://wikipedia.org/wiki/Root-mean-square_deviation) to understand how the WALS algorithm is implemented in TensorFlow." 
317 | ] 318 | } 319 | ], 320 | "metadata": { 321 | "kernelspec": { 322 | "display_name": "Python 2", 323 | "language": "python", 324 | "name": "python2" 325 | }, 326 | "language_info": { 327 | "codemirror_mode": { 328 | "name": "ipython", 329 | "version": 2 330 | }, 331 | "file_extension": ".py", 332 | "mimetype": "text/x-python", 333 | "name": "python", 334 | "nbconvert_exporter": "python", 335 | "pygments_lexer": "ipython2", 336 | "version": "2.7.13" 337 | } 338 | }, 339 | "nbformat": 4, 340 | "nbformat_minor": 2 341 | } 342 | --------------------------------------------------------------------------------