├── deployment ├── solution-assistant │ ├── requirements.in │ ├── requirements.txt │ ├── solution-assistant.yaml │ └── src │ │ └── lambda_function.py ├── arch.png ├── arch_dark.png ├── sagemaker-notebook-instance-stack.yaml ├── sagemaker-permissions-stack.yaml └── sagemaker-graph-entity-resolution.yaml ├── source └── sagemaker │ ├── sagemaker_graph_entity_resolution │ ├── __init__.py │ ├── dgl_entity_resolution │ │ ├── __init__.py │ │ ├── requirements.txt │ │ ├── setup.py │ │ ├── estimator_fns.py │ │ ├── graph.py │ │ ├── data.py │ │ ├── train_dgl_pytorch_entity_resolution.py │ │ ├── utils.py │ │ └── model.py │ ├── requirements.txt │ ├── setup.py │ └── config.py │ ├── setup.sh │ ├── data-preparation │ ├── data_sampling.py │ └── data_prep.py │ ├── data-preprocessing │ └── data_preprocessing.py │ ├── baseline │ └── train_pytorch_mlp_entity_resolution.py │ └── dgl-entity-resolution.ipynb ├── NOTICE ├── CODE_OF_CONDUCT.md ├── metadata └── metadata.json ├── README.md ├── CONTRIBUTING.md └── LICENSE /deployment/solution-assistant/requirements.in: -------------------------------------------------------------------------------- 1 | crhelper 2 | -------------------------------------------------------------------------------- /source/sagemaker/sagemaker_graph_entity_resolution/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | -------------------------------------------------------------------------------- /source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/requirements.txt: -------------------------------------------------------------------------------- 1 | dgl-cu101==0.5.0 -------------------------------------------------------------------------------- /source/sagemaker/sagemaker_graph_entity_resolution/requirements.txt: -------------------------------------------------------------------------------- 1 | sagemaker==1.72.0 2 | awscli>=1.18.140 -------------------------------------------------------------------------------- /deployment/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/sagemaker-graph-entity-resolution/HEAD/deployment/arch.png -------------------------------------------------------------------------------- /deployment/arch_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/sagemaker-graph-entity-resolution/HEAD/deployment/arch_dark.png -------------------------------------------------------------------------------- /source/sagemaker/setup.sh: -------------------------------------------------------------------------------- 1 | export PIP_DISABLE_PIP_VERSION_CHECK=1 2 | 3 | pip install -r ./sagemaker_graph_entity_resolution/requirements.txt -q 4 | pip install -e ./sagemaker_graph_entity_resolution/ -------------------------------------------------------------------------------- /deployment/solution-assistant/requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is 
autogenerated by pip-compile 3 | # To update, run: 4 | # 5 | # pip-compile requirements.in 6 | # 7 | crhelper==2.0.6 # via -r requirements.in 8 | -------------------------------------------------------------------------------- /source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='dgl_entity_resolution', 5 | version='1.0', 6 | description='entity resolution on sagemaker using dgl', 7 | packages=find_packages(exclude=('test',)) 8 | ) -------------------------------------------------------------------------------- /source/sagemaker/sagemaker_graph_entity_resolution/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup( 5 | name='sagemaker_graph_entity_resolution', 6 | version='1.0', 7 | description='A package to organize code in the solution', 8 | packages=find_packages(exclude=('test',)) 9 | ) -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /deployment/solution-assistant/solution-assistant.yaml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: 2010-09-09 2 | Description: Stack for Solution Helper resources. 
3 | Parameters: 4 | SolutionPrefix: 5 | Description: Used as a prefix to name all stack resources. 6 | Type: String 7 | SolutionsRefBucketName: 8 | Description: Amazon S3 Bucket containing solutions 9 | Type: String 10 | SolutionS3BucketName: 11 | Description: Amazon S3 Bucket used to store trained model and data. 12 | Type: String 13 | RoleArn: 14 | Description: Role to use for lambda resource 15 | Type: String 16 | Mappings: 17 | Function: 18 | SolutionAssistant: 19 | S3Key: "Entity-resolution-for-smart-advertising/build/solution_assistant.zip" 20 | Resources: 21 | SolutionAssistant: 22 | Type: "Custom::SolutionAssistant" 23 | Properties: 24 | ServiceToken: !GetAtt SolutionAssistantLambda.Arn 25 | SolutionS3BucketName: !Ref SolutionS3BucketName 26 | SolutionAssistantLambda: 27 | Type: AWS::Lambda::Function 28 | Properties: 29 | Handler: "lambda_function.handler" 30 | FunctionName: !Sub "${SolutionPrefix}-solution-assistant" 31 | Role: !Ref RoleArn 32 | Runtime: "python3.8" 33 | Code: 34 | S3Bucket: !Ref SolutionsRefBucketName 35 | S3Key: !FindInMap 36 | - Function 37 | - SolutionAssistant 38 | - S3Key 39 | Timeout : 60 40 | Metadata: 41 | cfn_nag: 42 | rules_to_suppress: 43 | - id: W58 44 | reason: Passed in role has cloudwatch write permissions 45 | -------------------------------------------------------------------------------- /deployment/solution-assistant/src/lambda_function.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import sys 3 | 4 | sys.path.append('./site-packages') 5 | from crhelper import CfnResource 6 | 7 | helper = CfnResource() 8 | 9 | 10 | @helper.create 11 | def on_create(_, __): 12 | pass 13 | 14 | @helper.update 15 | def on_update(_, __): 16 | pass 17 | 18 | 19 | def delete_s3_objects(bucket_name): 20 | s3_resource = boto3.resource("s3") 21 | try: 22 | s3_resource.Bucket(bucket_name).objects.all().delete() 23 | print( 24 | "Successfully deleted objects in bucket " 25 | "called 
def get_current_folder(global_variables):
    """Return the directory that the calling code lives in.

    :param global_variables: the caller's ``globals()`` mapping
    :return: ``pathlib.Path`` of the resolved parent directory of ``__file__``
        when that key is present (called from a file), otherwise the current
        working directory (e.g. called from a notebook, where ``__file__`` is
        not defined)
    """
    # No "__file__" entry: fall back to wherever the process is running.
    if "__file__" not in global_variables:
        return Path(os.getcwd())
    return Path(global_variables["__file__"]).parent.resolve()
def sample_train(data_dir, output_dir, train_file, sample_src):
    """Sample source nodes from the training pairs and write the matching edges.

    Reads the header-less CSV ``train_file`` from ``data_dir`` (column 0 is the
    source id, column 1 the sink id), draws source ids without replacement,
    keeps every edge whose source was drawn, and writes those edges, again
    without header or index, to ``output_dir`` under the same file name.

    :param data_dir: directory containing the input training file
    :param output_dir: directory the sampled training file is written to
    :param train_file: file name of the training edge list
    :param sample_src: number of distinct source ids to sample; values larger
        than the number of distinct sources select all of them
    :return: array of the unique node ids (sources and sinks) in the sampled edges
    """
    train = pd.read_csv(os.path.join(data_dir, train_file), header=None)

    unique_srcs = train[0].unique()
    # np.random.choice(..., replace=False) raises ValueError when asked for more
    # items than the population holds, so cap the request at what is available.
    num_srcs = min(sample_src, len(unique_srcs))
    sampled_train_srcs = pd.DataFrame({0: np.random.choice(unique_srcs, num_srcs, replace=False)})
    sampled_train = train.merge(sampled_train_srcs, how='inner', on=[0])

    initial_node_count = len(pd.concat([train[0], train[1]]).unique())
    final_nodes = pd.concat([sampled_train[0], sampled_train[1]]).unique()
    final_node_count = len(final_nodes)
    print("Sampled {} train edges from original train set of size {}".format(sampled_train.shape[0], train.shape[0]))
    print("{} unique nodes sampled from a set of size {}".format(final_node_count, initial_node_count))

    with open(os.path.join(output_dir, train_file), 'w') as f:
        sampled_train.to_csv(f, index=False, header=False)

    return final_nodes
In this project, we show how to use SageMaker and Graph Neural Networks to perform cross-device entity linking for online advertising.", 7 | "meta": "entity activity resolution graph gnn network identity", 8 | "tags": ["advertising", "identity resolution", "personalization"], 9 | "parameters": [ 10 | { 11 | "name": "SolutionPrefix", 12 | "type": "text", 13 | "default": "sagemaker-soln-graph-entity" 14 | }, 15 | { 16 | "name": "IamRole", 17 | "type": "text", 18 | "default": "" 19 | }, 20 | { 21 | "name": "S3RawDataPrefix", 22 | "type": "text", 23 | "default": "raw-data" 24 | }, 25 | { 26 | "name": "S3ProcessingJobOutputPrefix", 27 | "type": "text", 28 | "default": "preprocessed-data" 29 | }, 30 | { 31 | "name": "S3TrainingJobOutputPrefix", 32 | "type": "text", 33 | "default": "training-output" 34 | }, 35 | { 36 | "name": "CreateSageMakerNotebookInstance", 37 | "type": "text", 38 | "default": "false" 39 | }, 40 | { 41 | "name": "SageMakerNotebookInstanceType", 42 | "type": "text", 43 | "default": "ml.m4.xlarge" 44 | }, 45 | { 46 | "name": "StackVersion", 47 | "type": "text", 48 | "default": "release" 49 | } 50 | ], 51 | "acknowledgements": ["CAPABILITY_IAM","CAPABILITY_NAMED_IAM"], 52 | "cloudFormationTemplate": "s3-us-east-2.amazonaws.com/sagemaker-solutions-build-us-east-2/Entity-resolution-for-smart-advertising/deployment/sagemaker-graph-entity-resolution.yaml", 53 | "serviceCatalogProduct": "TBD", 54 | "copyS3Source": "sagemaker-solutions-build-us-east-2", 55 | "copyS3SourcePrefix": "Entity-resolution-for-smart-advertising/source/sagemaker", 56 | "notebooksDirectory": "Entity-resolution-for-smart-advertising/source/sagemaker", 57 | "notebookPaths": [ 58 | "Entity-resolution-for-smart-advertising/source/sagemaker/dgl-entity-resolution.ipynb" 59 | ], 60 | "permissions": "TBD" 61 | } -------------------------------------------------------------------------------- /source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/estimator_fns.py: 
def parse_args():
    """Build the training-job argument parser and parse ``sys.argv``.

    The three directory options default to the ``SM_CHANNEL_TRAIN``,
    ``SM_MODEL_DIR`` and ``SM_OUTPUT_DATA_DIR`` environment variables. When one
    of those variables is not set (e.g. when running the script outside a
    SageMaker training container) the corresponding option becomes mandatory on
    the command line instead of the parser failing with ``KeyError`` before any
    argument is read.

    :return: ``argparse.Namespace`` with the parsed options
    """
    parser = argparse.ArgumentParser()

    # Directory options: use the environment value when present, otherwise
    # require the caller to pass the option explicitly.
    for flag, env_var in (('--training-dir', 'SM_CHANNEL_TRAIN'),
                          ('--model-dir', 'SM_MODEL_DIR'),
                          ('--output-dir', 'SM_OUTPUT_DATA_DIR')):
        parser.add_argument(flag, type=str, default=os.environ.get(env_var),
                            required=env_var not in os.environ)

    # Input file names, resolved relative to the training directory by the caller.
    parser.add_argument('--train-edges', type=str, default='user_train_edges.csv')
    parser.add_argument('--test-edges', type=str, default='user_test_edges.csv')
    parser.add_argument('--transient-nodes', type=str, default='transient_nodes.csv')
    parser.add_argument('--transient-edges', type=str, default='transient_edges.csv')
    parser.add_argument('--website-nodes', type=str, default='website_nodes.csv')
    parser.add_argument('--website-group-edges', type=str, default='website_group_edges.csv')

    # Booleans arrive as strings on the command line; accept true/1/yes in any case.
    parser.add_argument('--mini-batch', type=lambda x: (str(x).lower() in ['true', '1', 'yes']),
                        default=True, help='use mini-batch training and sample graph')
    parser.add_argument('--batch-size', type=int, default=5000)
    parser.add_argument('--num-gpus', type=int, default=1)
    parser.add_argument('--optimizer', type=str, default='adam')
    parser.add_argument('--lr', type=float, default=1e-2)
    parser.add_argument('--embedding-size', type=int, default=64, help="embedding size for node embedding")
    parser.add_argument('--n-epochs', type=int, default=20)
    parser.add_argument('--n-neighbors', type=int, default=100, help='number of neighbors to sample')
    parser.add_argument('--negative-sampling-rate', type=int, default=10, help='rate of negatively sampled edges')
    parser.add_argument('--n-hidden', type=int, default=16, help='number of hidden units')
    parser.add_argument('--n-layers', type=int, default=2, help='number of hidden layers')
    parser.add_argument('--weight-decay', type=float, default=5e-4, help='Weight for L2 loss')
    parser.add_argument('--regularization-param', type=float, default=5e-4, help='Weight for regularization of decoder')
    parser.add_argument('--grad-norm', type=float, default=1.0, help='norm to clip gradient to')

    return parser.parse_args()
def merge_websites_with_user_visits(data_dir, facts, url_data, primary_key, output_dir, logs):
    """Join each user's visit facts with website data and write one flat CSV log.

    The facts file is read line by line; every non-blank line is a JSON object
    with a ``uid`` and a ``facts`` list. Each user's facts are normalized into a
    frame, joined column-wise with the matching rows of ``url_data`` (looked up
    by ``primary_key``), tagged with the user's ``uid`` and appended to the
    output file. The CSV header is written once, with the first user's rows.

    :param data_dir: directory containing the facts file
    :param facts: file name of the JSON-lines facts file
    :param url_data: DataFrame of website data indexed by ``primary_key``
    :param primary_key: name of the fact field that identifies a website
    :param output_dir: directory the log file is written to
    :param logs: file name of the output CSV
    :return: None
    """
    write_header = True
    # Open the output once instead of re-opening it for every input line.
    with open(os.path.join(data_dir, facts)) as f_in, \
            open(os.path.join(output_dir, logs), 'w') as f_out:
        for line in f_in:
            line = line.strip()
            if not line:
                # tolerate blank lines (e.g. a trailing newline at end of file)
                continue
            record = json.loads(line)
            user_facts = record.get("facts")
            if not user_facts:
                # a user without any recorded visit contributes no log rows
                continue
            user_visits = pd.json_normalize(user_facts)
            fids = user_visits[primary_key].values
            user_visits = pd.concat((user_visits.set_index(primary_key), url_data.loc[fids]), axis=1)
            user_visits['uid'] = record.get('uid')
            # NOTE: rows are appended without re-aligning columns, so all users'
            # facts are assumed to carry the same fields in the same order.
            user_visits.to_csv(f_out, index=False, header=write_header)
            write_header = False
def construct_graph(training_dir, training_edges, transient_nodes, transient_edges, website_nodes, website_edges):
    """Build the heterogeneous DGL graph and load node features.

    Three edge files are parsed in turn (user-user, user-website,
    website-domain); for each one both the forward relation and its reverse are
    registered. A self relation is then added for every user node, and user
    and website feature matrices are loaded.

    :param training_dir: directory that all file names below are relative to
    :param training_edges: file with user -> user training edges
    :param transient_nodes: file with user node features
    :param transient_edges: file with user -> website edges
    :param website_nodes: file with website node features
    :param website_edges: file with website -> domain edges
    :return: (graph, (user_features, website_features), id_to_node, reverse_etypes)
    """
    # (file, source type, sink type, forward relation, reverse relation, log message)
    relations = (
        (training_edges, 'user', 'user', 'same_entity', 'same_entity_reversed',
         "Read user -> user training edgelist from {}"),
        (transient_edges, 'user', 'website', 'visits', 'visited_by',
         "Read user -> website edgelist from {}"),
        (website_edges, 'website', 'domain', 'owned_by', 'owns',
         "Read website -> domain edgelist from {}"),
    )

    edgelists = {}
    id_to_node = {}
    for filename, src_type, dst_type, forward, backward, message in relations:
        path = os.path.join(training_dir, filename)
        edges, id_to_node = parse_edgelist(path, id_to_node, source_type=src_type, sink_type=dst_type)
        print(message.format(path))
        edgelists[(src_type, forward, dst_type)] = edges
        edgelists[(dst_type, backward, src_type)] = [(sink, source) for source, sink in edges]

    # User features are loaded before the self relation is built because
    # get_features receives (and may extend) the user id mapping.
    user_feature_path = os.path.join(training_dir, transient_nodes)
    user_features, _ = get_features(id_to_node['user'], user_feature_path)
    print("Got user features from {}".format(user_feature_path))

    edgelists[('user', 'self_relation', 'user')] = [(idx, idx) for idx in id_to_node['user'].values()]

    website_feature_path = os.path.join(training_dir, website_nodes)
    website_features = get_website_features(id_to_node['website'], website_feature_path)
    print("Got website features from {}".format(website_feature_path))

    g = dgl.heterograph(edgelists)
    print("Constructed heterograph with the following metagraph structure: Node types {}, Edge types{}".format(
        g.ntypes, g.canonical_etypes))
    print("Number of user nodes : {}".format(g.number_of_nodes('user')))

    # Each relation name mapped to the name of its opposite-direction relation.
    reverse_etypes = {}
    for _, _, _, forward, backward, _ in relations:
        reverse_etypes[forward] = backward
        reverse_etypes[backward] = forward

    print(g)

    return g, (user_features, website_features), id_to_node, reverse_etypes
14 | 15 | **Caution**: Cloning this GitHub repository and running the code manually could lead to unexpected issues! Use the AWS CloudFormation template. You'll get an Amazon SageMaker Notebook instance that's been correctly set up and configured to access the other resources in the solution. 16 | 17 | ## Architecture 18 | The project architecture deployed by the AWS CloudFormation template is shown here. 19 | 20 | ![](deployment/arch_dark.png) 21 | 22 | ## Contents 23 | 24 | * `deployment/` 25 | * `sagemaker-graph-entity-resolution.yaml`: Creates AWS CloudFormation Stack for solution 26 | * `source/` 27 | * `lambda/` 28 | * `data-preprocessing/` 29 | * `index.py`: Lambda function script for invoking SageMaker processing 30 | * `graph-modelling/` 31 | * `index.py`: Lambda function script for invoking SageMaker training 32 | * `sagemaker/` 33 | * `data-preprocessing/` 34 | * `data_preparation_script.py`: Custom script used to prepare CIKM cup data to solution input format 35 | * `data_preprocessing_script.py`: Custom script used by SageMaker Processing for data processing/feature engineering 36 | * `dgl-entity-resolution/` 37 | * `model.py`: Implements the various graph neural network models used in the project with the pytorch backend 38 | * `data.py`: Contains functions for reading edges and node features 39 | * `estimator_fns.py`: Contains functions for parsing input from SageMaker estimator objects 40 | * `graph.py`: Contains functions for constructing DGL Graphs with node features and edge lists 41 | * `requirements.txt`: Describes Python package requirements of the Amazon SageMaker training instance 42 | * `sampler.py`: Contains functions for graph sampling for mini-batch training 43 | * `train_dgl_pytorch_entity_resolution.py`: python entry point used by the notebook for GNN training with DGL pytorch backend 44 | * `utils.py`: python script with utility functions for computing metrics and plots 45 | * `dgl-entity-resolution.ipynb`: Orchestrates the solution. 
Triggers preprocessing and model training 46 | 47 | ## License 48 | 49 | This project is licensed under the Apache-2.0 License. 50 | 51 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *master* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. 
Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | 61 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes. 
def parse_edgelist(edges, id_to_node, source_type='user', sink_type='user'):
    """
    Read a comma separated edge file and translate every row into a pair of node indices
    :param edges: path to a two-column comma separated file; the first row is treated as a header and skipped
    :param id_to_node: dictionary mapping node type -> {node name(id) -> node index}; updated in place
    :param source_type: node type of the first column. defaults to 'user'
    :param sink_type: node type of the second column. defaults to 'user'
    :return: (list, dict) list of (source index, sink index) tuples and the same id_to_node dict
    """
    def _index_of(node_type, node_id, next_idx):
        # Look up node_id under node_type, registering it with next_idx when it
        # has not been seen before; returns (index, next free index).
        known = id_to_node.setdefault(node_type, {})
        if node_id in known:
            return known[node_id], next_idx
        known[node_id] = next_idx
        return next_idx, next_idx + 1

    # When both columns hold the same node type they draw from one counter.
    shared_counter = source_type == sink_type
    pairs = []
    next_source, next_sink = 0, 0
    with open(edges, "r") as handle:
        for line_no, raw in enumerate(handle):
            source, sink = raw.strip().split(",")
            if line_no == 0:
                # Header row: carries no edge; use it as the point to continue
                # numbering after any indices already assigned for each type.
                if source_type in id_to_node:
                    next_source = max(id_to_node[source_type].values()) + 1
                if sink_type in id_to_node:
                    next_sink = max(id_to_node[sink_type].values()) + 1
                continue

            source_node, next_source = _index_of(source_type, source, next_source)
            if shared_counter:
                sink_node, next_source = _index_of(sink_type, sink, next_source)
            else:
                sink_node, next_sink = _index_of(sink_type, sink, next_sink)
            pairs.append((source_node, sink_node))

    return pairs, id_to_node
id_to_node[node_type][node_id] 40 | else: 41 | id_to_node[node_type][node_id] = ptr 42 | node_idx = ptr 43 | ptr += 1 44 | else: 45 | id_to_node[node_type] = {} 46 | id_to_node[node_type][node_id] = ptr 47 | node_idx = ptr 48 | ptr += 1 49 | 50 | return node_idx, id_to_node, ptr 51 | 52 | 53 | def get_features(id_to_node, node_features): 54 | """ 55 | :param id_to_node: dictionary mapping node names(id) to dgl node idx 56 | :param node_features: path to file containing node features 57 | :return: (np.ndarray, list) node feature matrix in order and new nodes not yet in the graph 58 | """ 59 | indices, features, new_nodes = [], [], [] 60 | max_node = max(id_to_node.values()) 61 | with open(node_features, "r") as fh: 62 | for line in fh: 63 | node_feats = line.strip().split(",") 64 | node_id = node_feats[0] 65 | feats = np.array(list(map(float, node_feats[1:]))) 66 | features.append(feats) 67 | if node_id not in id_to_node: 68 | max_node += 1 69 | id_to_node[node_id] = max_node 70 | new_nodes.append(max_node) 71 | 72 | indices.append(id_to_node[node_id]) 73 | 74 | features = np.array(features).astype('float32') 75 | features = features[np.argsort(indices), :] 76 | return features, new_nodes 77 | 78 | 79 | def get_website_features(id_to_node, website_features): 80 | """ 81 | :param id_to_node: dictionary mapping node names(id) to dgl node idx 82 | :param website_features: path to file containing website features 83 | :return: (np.ndarray) website feature matrix in order 84 | """ 85 | features_df = pd.read_csv(website_features, header=None, index_col=0) 86 | features_df.reindex(sorted(features_df.index, key=lambda x: id_to_node[x])) 87 | return features_df.values.astype(np.float32) 88 | -------------------------------------------------------------------------------- /deployment/sagemaker-notebook-instance-stack.yaml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: "2010-09-09" 2 | Description: "(SA0007) - 
sagemaker-graph-entity-resolution SageMaker stack" 3 | Parameters: 4 | SolutionPrefix: 5 | Description: Enter the name of the prefix for the solution used for naming 6 | Type: String 7 | SolutionS3BucketName: 8 | Description: Enter the name of the S3 bucket for the solution 9 | Type: String 10 | S3InputDataPrefix: 11 | Description: S3 prefix where raw data is stored 12 | Type: String 13 | Default: "raw-data" 14 | S3ProcessingJobOutputPrefix: 15 | Description: S3 prefix where preprocessed data is stored after processing 16 | Type: String 17 | Default: "preprocessed-data" 18 | S3TrainingJobOutputPrefix: 19 | Description: S3 prefix where training outputs are stored after training 20 | Type: String 21 | Default: "training-output" 22 | SageMakerNotebookInstanceType: 23 | Description: Instance type of the SageMaker notebook instance 24 | Type: String 25 | Default: "ml.t3.medium" 26 | NotebookInstanceExecutionRoleArn: 27 | Type: String 28 | Description: Execution Role for the SageMaker notebook instance 29 | StackVersion: 30 | Type: String 31 | Description: CloudFormation Stack version. 
32 | Default: "release" 33 | 34 | Mappings: 35 | S3: 36 | release: 37 | BucketPrefix: "sagemaker-solutions-prod" 38 | development: 39 | BucketPrefix: "sagemaker-solutions-devo" 40 | SageMaker: 41 | Source: 42 | S3Key: "Entity-resolution-for-smart-advertising/source/sagemaker/" 43 | 44 | Resources: 45 | NotebookInstance: 46 | Type: AWS::SageMaker::NotebookInstance 47 | Properties: 48 | DirectInternetAccess: Enabled 49 | InstanceType: !Ref SageMakerNotebookInstanceType 50 | LifecycleConfigName: !GetAtt LifeCycleConfig.NotebookInstanceLifecycleConfigName 51 | NotebookInstanceName: !Sub "${SolutionPrefix}-notebook-instance" 52 | RoleArn: !Ref NotebookInstanceExecutionRoleArn 53 | VolumeSizeInGB: 120 54 | Metadata: 55 | cfn_nag: 56 | rules_to_suppress: 57 | - id: W1201 58 | reason: Solution does not have KMS encryption enabled by default 59 | LifeCycleConfig: 60 | Type: AWS::SageMaker::NotebookInstanceLifecycleConfig 61 | Properties: 62 | NotebookInstanceLifecycleConfigName: !Sub "${SolutionPrefix}-nb-lifecycle-config" 63 | OnCreate: 64 | - Content: 65 | Fn::Base64: !Sub 66 | - | 67 | cd /home/ec2-user/SageMaker 68 | aws s3 cp --recursive s3://${SolutionsRefBucketBase}-${AWS::Region}/${SolutionsRefSource} . 
69 | touch stack_outputs.json 70 | echo '{' >> stack_outputs.json 71 | echo ' "IamRole": "${NotebookInstanceExecutionRoleArn}",' >> stack_outputs.json 72 | echo ' "SolutionPrefix": "${SolutionPrefix}",' >> stack_outputs.json 73 | echo ' "SolutionName": "Entity-resolution-for-smart-advertising",' >> stack_outputs.json 74 | echo ' "SolutionUpstreamS3Bucket": "${SolutionsRefBucketBase}-${AWS::Region}",' >> stack_outputs.json 75 | echo ' "SolutionS3Bucket": "${SolutionS3BucketName}",' >> stack_outputs.json 76 | echo ' "S3InputDataPrefix": "${S3InputDataPrefix}",' >> stack_outputs.json 77 | echo ' "S3ProcessingJobOutputPrefix": "${S3ProcessingJobOutputPrefix}",' >> stack_outputs.json 78 | echo ' "S3TrainingJobOutputPrefix": "${S3TrainingJobOutputPrefix}"' >> stack_outputs.json 79 | echo '}' >> stack_outputs.json 80 | sudo chown -R ec2-user:ec2-user . 81 | - SolutionsRefBucketBase: !FindInMap [S3, !Ref StackVersion, BucketPrefix] 82 | SolutionsRefSource: !FindInMap [SageMaker, Source, S3Key] 83 | Outputs: 84 | SourceCode: 85 | Description: "Open Jupyter IDE. This authenticate you against Jupyter." 
86 | Value: !Sub "https://console.aws.amazon.com/sagemaker/home?region=${AWS::Region}#/notebook-instances/openNotebook/${SolutionPrefix}-notebook-instance?view=classic" 87 | NotebookInstance: 88 | Description: "SageMaker Notebook instance to manually orchestrate data preprocessing and model training" 89 | Value: !Sub "https://${SolutionPrefix}-notebook-instance.notebook.${AWS::Region}.sagemaker.aws/notebooks/dgl-entity-resolution.ipynb" -------------------------------------------------------------------------------- /deployment/sagemaker-permissions-stack.yaml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: "2010-09-09" 2 | Description: "(SA0007) - sagemaker-graph-entity-resolution SageMaker permissions stack" 3 | Parameters: 4 | SolutionPrefix: 5 | Description: Enter the name of the prefix for the solution used for naming 6 | Type: String 7 | Default: "sagemaker-soln-graph-entity-resolution" 8 | SolutionS3BucketName: 9 | Description: Enter the name of the S3 bucket for the solution 10 | Type: String 11 | Default: "sagemaker-soln-*" 12 | StackVersion: 13 | Description: Enter the name of the template stack version 14 | Type: String 15 | Default: "release" 16 | 17 | Mappings: 18 | S3: 19 | release: 20 | BucketPrefix: "sagemaker-solutions-prod" 21 | development: 22 | BucketPrefix: "sagemaker-solutions-devo" 23 | 24 | Resources: 25 | NotebookInstanceExecutionRole: 26 | Type: AWS::IAM::Role 27 | Properties: 28 | RoleName: !Sub "${SolutionPrefix}-${AWS::Region}-nb-role" 29 | AssumeRolePolicyDocument: 30 | Version: '2012-10-17' 31 | Statement: 32 | - Effect: Allow 33 | Principal: 34 | AWS: 35 | - !Sub "arn:aws:iam::${AWS::AccountId}:root" 36 | Service: 37 | - sagemaker.amazonaws.com 38 | - lambda.amazonaws.com 39 | Action: 40 | - 'sts:AssumeRole' 41 | Metadata: 42 | cfn_nag: 43 | rules_to_suppress: 44 | - id: W28 45 | reason: Needs to be explicitly named to tighten launch permissions policy 46 | 47 | 
NotebookInstanceIAMPolicy: 48 | Type: AWS::IAM::Policy 49 | Properties: 50 | PolicyName: !Sub "${SolutionPrefix}-nb-instance-policy" 51 | Roles: 52 | - !Ref NotebookInstanceExecutionRole 53 | PolicyDocument: 54 | Version: '2012-10-17' 55 | Statement: 56 | - Effect: Allow 57 | Action: 58 | - sagemaker:CreateTrainingJob 59 | - sagemaker:DescribeTrainingJob 60 | - sagemaker:CreateProcessingJob 61 | - sagemaker:DescribeProcessingJob 62 | Resource: 63 | - !Sub "arn:aws:sagemaker:${AWS::Region}:${AWS::AccountId}:*" 64 | - Effect: Allow 65 | Action: 66 | - ecr:GetAuthorizationToken 67 | - ecr:GetDownloadUrlForLayer 68 | - ecr:BatchGetImage 69 | - ecr:PutImage 70 | - ecr:BatchCheckLayerAvailability 71 | Resource: 72 | - "*" 73 | - !Sub "arn:aws:ecr:${AWS::Region}:${AWS::AccountId}:repository/*" 74 | - Effect: Allow 75 | Action: 76 | - cloudwatch:PutMetricData 77 | - cloudwatch:GetMetricData 78 | - cloudwatch:GetMetricStatistics 79 | - cloudwatch:ListMetrics 80 | Resource: 81 | - !Sub "arn:aws:cloudwatch:${AWS::Region}:${AWS::AccountId}:*" 82 | - Effect: Allow 83 | Action: 84 | - logs:CreateLogGroup 85 | - logs:CreateLogStream 86 | - logs:DescribeLogStreams 87 | - logs:GetLogEvents 88 | - logs:PutLogEvents 89 | Resource: 90 | - !Sub "arn:aws:logs:${AWS::Region}:${AWS::AccountId}:log-group:/aws/sagemaker/*" 91 | - !Sub "arn:aws:logs:${AWS::Region}:${AWS::AccountId}:log-group:/aws/lambda/*" 92 | - Effect: Allow 93 | Action: 94 | - iam:PassRole 95 | Resource: 96 | - !GetAtt NotebookInstanceExecutionRole.Arn 97 | Condition: 98 | StringEquals: 99 | iam:PassedToService: sagemaker.amazonaws.com 100 | - Effect: Allow 101 | Action: 102 | - iam:GetRole 103 | Resource: 104 | - !GetAtt NotebookInstanceExecutionRole.Arn 105 | - Effect: Allow 106 | Action: 107 | - s3:ListBucket 108 | - s3:GetObject 109 | - s3:PutObject 110 | - s3:GetObjectVersion 111 | - s3:DeleteObject 112 | - s3:DeleteBucket 113 | Resource: 114 | - !Sub "arn:aws:s3:::${SolutionS3BucketName}" 115 | - !Sub 
"arn:aws:s3:::${SolutionS3BucketName}/*" 116 | - !Sub 117 | - "arn:aws:s3:::${SolutionRefBucketBase}-${Region}" 118 | - SolutionRefBucketBase: !FindInMap [S3, !Ref StackVersion, BucketPrefix] 119 | Region: !Ref AWS::Region 120 | - !Sub 121 | - "arn:aws:s3:::${SolutionRefBucketBase}-${Region}/*" 122 | - SolutionRefBucketBase: !FindInMap [S3, !Ref StackVersion, BucketPrefix] 123 | Region: !Ref AWS::Region 124 | - Effect: Allow 125 | Action: 126 | - s3:CreateBucket 127 | - s3:ListBucket 128 | - s3:GetObject 129 | - s3:GetObjectVersion 130 | - s3:PutObject 131 | - s3:DeleteObject 132 | Resource: 133 | - !Sub "arn:aws:s3:::sagemaker-${AWS::Region}-${AWS::AccountId}" 134 | - !Sub "arn:aws:s3:::sagemaker-${AWS::Region}-${AWS::AccountId}/*" 135 | Metadata: 136 | cfn_nag: 137 | rules_to_suppress: 138 | - id: W12 139 | reason: ECR GetAuthorizationToken is non resource-specific action 140 | 141 | Outputs: 142 | SageMakerRoleArn: 143 | Description: "SageMaker Execution Role for the solution" 144 | Value: !GetAtt NotebookInstanceExecutionRole.Arn -------------------------------------------------------------------------------- /source/sagemaker/data-preprocessing/data_preprocessing.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | 5 | from datetime import datetime 6 | 7 | from sklearn.model_selection import train_test_split 8 | from sklearn.feature_extraction.text import TfidfVectorizer 9 | from sklearn.decomposition import TruncatedSVD 10 | from sklearn.pipeline import Pipeline 11 | 12 | import pandas as pd 13 | import numpy as np 14 | np.random.seed(0) 15 | 16 | MIN_TIMESTAMP = 1461340800 #2016, 04, 23 17 | MAX_TIMESTAMP = 1466611200 #2016, 06, 23 18 | 19 | 20 | def get_logger(name): 21 | logger = logging.getLogger(name) 22 | log_format = '%(asctime)s %(levelname)s %(name)s: %(message)s' 23 | logging.basicConfig(format=log_format, level=logging.INFO) 24 | logger.setLevel(logging.INFO) 
def get_logger(name):
    """Configure root logging at INFO with a timestamped format and return the logger called name."""
    logger = logging.getLogger(name)
    log_format = '%(asctime)s %(levelname)s %(name)s: %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO)
    logger.setLevel(logging.INFO)
    return logger


def parse_args():
    """Parse the command line options of the preprocessing job (input/output dirs, file names, test ratio)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default='/opt/ml/processing/input')
    parser.add_argument('--output-dir', type=str, default='/opt/ml/processing/output')
    parser.add_argument('--logs', type=str, default='logs.csv', help='log of transient id web activity')
    parser.add_argument('--train', type=str, default='train.csv', help='pairs of transient ids that are the same user')
    parser.add_argument('--test-ratio', type=float, default=0, help='fraction of train data to use for test')
    return parser.parse_args()


def prepare_ground_truth_links(data_dir, train_file, test_ratio, output_dir):
    """
    Copy the ground truth user pairs to output_dir, optionally holding out a test split.

    :param data_dir: directory containing train_file
    :param train_file: headerless csv of transient id pairs
    :param test_ratio: fraction of pairs written to user_test_edges.csv; a falsy value (0) skips the split
    :param output_dir: directory receiving user_train_edges.csv (and user_test_edges.csv when splitting)
    """
    train_path = os.path.join(data_dir, train_file)
    logging.info("Reading train data edges from: {}".format(train_path))
    train = pd.read_csv(train_path, header=None)
    if test_ratio:
        logging.info("Partitioning train data edges into train and test set with test ratio: {}".format(test_ratio))
        train, test = train_test_split(train, test_size=test_ratio)
        test_path = os.path.join(output_dir, 'user_test_edges.csv')
        logging.info("Saving test set to {}".format(test_path))
        with open(test_path, 'w') as f:
            test.to_csv(f, index=False, header=False)
    out_path = os.path.join(output_dir, 'user_train_edges.csv')
    logging.info("Saving train set to {}".format(out_path))
    with open(out_path, 'w') as f:
        train.to_csv(f, index=False, header=False)


def process_logs(data_dir, log_file, output_dir):
    """
    Derive the graph input files from the website visit logs.

    Writes transient_edges.csv (uid, url), transient_nodes.csv (user features),
    website_nodes.csv (url features) and website_group_edges.csv (url, domain) to output_dir.

    :param data_dir: directory containing log_file
    :param log_file: csv with at least the columns uid, urls, ts and titles
    :param output_dir: directory receiving the generated files
    """
    logs = pd.read_csv(os.path.join(data_dir, log_file))
    logging.info("Read user website visit logs from: {}".format(os.path.join(data_dir, log_file)))

    transient_edges = os.path.join(output_dir, 'transient_edges.csv')
    save_file(logs[['uid', 'urls']].drop_duplicates(), transient_edges,
              "Saved user -> url transient edges to {}".format(transient_edges))

    transient_nodes_file = os.path.join(output_dir, 'transient_nodes.csv')
    user_features = get_user_features(logs[['uid', 'ts']])
    save_file(user_features, transient_nodes_file, "Saved transient user features to {}".format(transient_nodes_file))

    website_nodes_file = os.path.join(output_dir, 'website_nodes.csv')
    website_features = get_website_features(logs[['urls', 'titles']].drop_duplicates().fillna(""))
    save_file(website_features, website_nodes_file, "Saved website_features to {}".format(website_nodes_file))

    website_group_file = os.path.join(output_dir, 'website_group_edges.csv')
    website_groupings = get_website_groupings(logs[['urls']].drop_duplicates())
    save_file(website_groupings, website_group_file, "Saved url -> domain edges to {}".format(website_group_file))


def get_user_features(user_data):
    """
    Build per-user activity histograms over the 7*24 (weekday, hour) slots of a week.

    :param user_data: frame with the columns uid and ts (epoch timestamp); not modified
    :return: frame with one row per uid: the uid followed by 168 visit counts
    """
    logging.info("Number of unique users is {}".format(len(user_data['uid'].unique())))
    logging.info("User data has shape {}, columns: {} before transformation".format(user_data.shape, user_data.columns))
    # assign() returns a new frame: the caller hands in a column slice of the raw logs and
    # writing into it in place triggers pandas' SettingWithCopy behaviour.
    user_data = user_data.assign(ts=user_data['ts'].apply(preprocess_timestamp))
    # Keep only visits inside the observation window; rows with a missing timestamp fail
    # both comparisons and are dropped rather than crashing the datetime conversion below.
    in_window = (user_data['ts'] >= MIN_TIMESTAMP) & (user_data['ts'] <= MAX_TIMESTAMP)
    user_data = user_data[in_window]

    # One-hot encode each visit by its slot; explicit int dtype keeps an empty frame indexable.
    slots = user_data['ts'].apply(get_activity_index).to_numpy(dtype=int)
    user_features = np.zeros((user_data.shape[0], 7*24))
    user_features[np.arange(user_features.shape[0]), slots] = 1

    logging.info("User data has shape {} after transformation".format(user_features.shape))
    user_features_df = pd.DataFrame(user_features)
    user_features_df['uid'] = user_data['uid'].values
    # Summing the one-hot rows per user yields the visit count of every slot.
    final_user_feature = user_features_df.groupby('uid').sum().reset_index()
    logging.info("Final user features shape {}".format(final_user_feature.shape))
    return final_user_feature


def preprocess_timestamp(ts):
    """
    Scale a timestamp down to seconds.

    Values above 9999999999 are divided by 1000, at most three times, which covers
    timestamps given in milliseconds, microseconds or nanoseconds.

    :param ts: numeric epoch timestamp
    :return: ts unchanged when it is at most 9999999999, otherwise the scaled-down float
    """
    max_seconds = 9999999999
    for _ in range(3):  # bounded so that inf cannot loop forever
        if not ts > max_seconds:
            break
        ts = ts / 1000
    return ts


def get_activity_index(ts):
    """
    Map a timestamp in seconds to its slot in a week: weekday * 24 + hour, in [0, 168).

    NOTE: datetime.fromtimestamp interprets ts in the local timezone of the machine.
    """
    dt = datetime.fromtimestamp(ts)
    return dt.weekday()*24 + dt.hour


def get_website_features(web_data, n_components=20):
    """
    Embed every website with TF-IDF over its title and leading url parts followed by truncated SVD.

    :param web_data: frame with the string columns urls and titles
    :param n_components: number of SVD components, i.e. feature columns per website
    :return: frame whose first column is urls followed by the n_components feature columns
    """
    logging.info("Web data has shape {}, columns: {} before transformation".format(web_data.shape, web_data.columns))
    # Only the first three "/"-separated pieces of the url are used as extra tokens.
    url_tokens = web_data['urls'].apply(lambda url: " " + " ".join(url.split("/")[:3]))
    documents = web_data['titles'].values + url_tokens.values
    transform_pipeline = Pipeline([('tf_idf', TfidfVectorizer()),
                                   ('dim_reduce', TruncatedSVD(n_components=n_components))])
    web_features = transform_pipeline.fit_transform(documents)
    logging.info("Web data has shape {} after transformation".format(web_features.shape))
    web_features_df = pd.DataFrame(web_features)
    web_features_df.insert(0, 'urls', web_data['urls'].values)
    return web_features_df


def get_website_groupings(urls):
    """
    Pair every url with its domain, taken as the text before the first "/".

    :param urls: frame with the column urls; not modified
    :return: new frame with the columns urls and domain
    """
    return urls.assign(domain=urls['urls'].apply(lambda x: x.split("/")[0]))


def save_file(df, file_name, message):
    """Write df to file_name as csv without index or header, then log message."""
    with open(file_name, 'w') as f:
        df.to_csv(f, index=False, header=False)
    logging.info(message)


if __name__ == '__main__':
    # Rebinds the module-level name so that the functions above log through this logger.
    logging = get_logger(__name__)

    args = parse_args()

    prepare_ground_truth_links(args.data_dir, args.train, args.test_ratio, args.output_dir)
    process_logs(args.data_dir, args.logs, args.output_dir)
def evaluate(model, g, train_triplets, test_triplets, user_features, web_features, batch_size, n_neighbors,
             hits=(1, 3, 10), device=None, filtered=False, mean_ap=False):
    """
    Compute a link prediction metric from embeddings produced by model.inference.

    :param model: model exposing inference(...) and the relation weights w_relation
    :param g: graph passed through to model.inference
    :param train_triplets: (source, relation, sink) triplets; used by the filtered MRR and mAP metrics
    :param test_triplets: (source, relation, sink) triplets being scored
    :param user_features: user node features passed through to model.inference
    :param web_features: website node features passed through to model.inference
    :param batch_size: batch size passed through to model.inference
    :param n_neighbors: neighbor count passed through to model.inference
    :param hits: cutoffs for the Hits@k values reported by the MRR metrics
    :param device: device passed through to model.inference
    :param filtered: compute filtered instead of raw MRR (ignored when mean_ap is set)
    :param mean_ap: compute mean average precision instead of MRR
    :return: the selected metric
    """
    logging.info("Performing model inference to get embeddings")
    embed = model.inference(g, user_features, web_features, batch_size, n_neighbors, device)
    logging.info("Got embeddings, computing metrics")

    # metrics run on a detached CPU copy of the relation weights
    w = model.w_relation.detach().clone().cpu()

    if mean_ap:
        return calc_mAP(embed, w, train_triplets, test_triplets)
    if filtered:
        return calc_filtered_mrr(embed, w, train_triplets, test_triplets, hits=hits)
    return calc_raw_mrr(embed, w, test_triplets, hits=hits, eval_bz=10000)


def train(g, model, train_dataloader, train_triplets, test_triplets, user_features, web_features, optimizer, batch_size,
          n_neighbors, n_epochs, negative_rate, grad_norm, cuda, device=None, run_eval=True):
    """
    Train model for n_epochs on the 'same_entity' edge batches yielded by train_dataloader.

    Each batch carries positive and negative pair graphs; their 'same_entity' edges are
    concatenated and labelled 1 (positive) / 0 (negative) for model.get_loss. When run_eval
    is set, MRR is evaluated every fifth epoch and on the last epoch; otherwise -1 is logged.

    :param negative_rate: number of negative edges sampled per positive edge
    :param grad_norm: maximum gradient norm used for clipping
    :param cuda: move batch tensors and blocks to the GPU before the forward pass
    :return: the trained model
    """
    for epoch in range(n_epochs):
        tic = time.time()
        duration = []
        loss_val = 0.
        n_batches = 0
        mrr = -1.

        model.train()
        for input_nodes, pos_pair_graph, neg_pair_graph, blocks in train_dataloader:
            user_nodes, website_nodes = input_nodes['user'], input_nodes['website']
            u, w = user_features[input_nodes['user']], web_features[input_nodes['website']]

            true_srcs, true_dsts = pos_pair_graph.all_edges(etype='same_entity')
            false_srcs, false_dsts = neg_pair_graph.all_edges(etype='same_entity')
            sources, sinks = torch.cat((true_srcs, false_srcs)), torch.cat((true_dsts, false_dsts))
            # positives come first in the concatenation above, so only the head is labelled 1
            labels = torch.zeros((negative_rate + 1) * len(true_srcs))
            labels[:len(true_srcs)] = 1

            if cuda:
                user_nodes, website_nodes, u, w = user_nodes.cuda(), website_nodes.cuda(), u.cuda(), w.cuda()
                blocks = [blk.to(device) for blk in blocks]
                sources, sinks, labels = sources.cuda(), sinks.cuda(), labels.cuda()
            embeddings = model(blocks, user_nodes, website_nodes, u, w)

            loss = model.get_loss(embeddings, sources, sinks, labels)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
            optimizer.step()

            loss_val += loss.item()
            n_batches += 1

        duration.append(time.time() - tic)
        do_eval = run_eval and ((epoch % 5 == 0) or (epoch == n_epochs-1))
        if do_eval:
            mrr = evaluate(model, g, train_triplets, test_triplets, user_features, web_features, batch_size, n_neighbors,
                           device=device)

        # max(..., 1): an empty dataloader reports a loss of 0 instead of failing on the average
        logging.info("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | MRR {:.4f}".format(
            epoch, np.mean(duration), loss_val / max(n_batches, 1), mrr))
    return model


def run():
    """
    Build the graph from the parsed arguments, train the entity resolution model and return it.

    The 'same_entity' edges are split 70/30 at random into train and test edges.
    """
    logging = get_logger(__name__)
    logging.info('numpy version:{} Pytorch version:{} DGL version:{}'.format(np.__version__,
                                                                              torch.__version__,
                                                                              dgl.__version__))

    args = parse_args()

    g, (user_features, website_features), id_to_node, reverse = construct_graph(args.training_dir, args.train_edges,
                                                                                 args.transient_nodes,
                                                                                 args.transient_edges,
                                                                                 args.website_nodes,
                                                                                 args.website_group_edges)

    logging.info("""----Data statistics------
                #Nodes: {}
                #Edges: {}
                #Same entity train edges: {}
                #User Features Shape: {}
                #Web Features Shape: {}""".format(sum([g.number_of_nodes(n_type) for n_type in g.ntypes]),
                                                  sum([g.number_of_edges(e_type) for e_type in g.etypes]),
                                                  g.number_of_edges('same_entity'),
                                                  user_features.shape,
                                                  website_features.shape))

    user_features, website_features = torch.tensor(user_features), torch.tensor(website_features)

    model = EntityResolution(g, args.embedding_size, args.n_hidden, user_features.shape[1], website_features.shape[1],
                             num_hidden_layers=args.n_layers, reg_param=args.regularization_param)

    cuda = args.num_gpus > 0 and torch.cuda.is_available()
    device = 'cpu'
    if cuda:
        torch.cuda.set_device(0)
        model.cuda()
        device = 'cuda:%d' % torch.cuda.current_device()

    # split into train and test
    us, vs, eids = g.all_edges(etype='same_entity', form='all')
    train_eids = np.random.choice(len(eids), int(0.7 * len(eids)), replace=False)
    test_eids = np.setdiff1d(np.arange(len(eids)), train_eids)
    # triplets are (source, relation, sink) rows; the relation column is all zeros
    train_triplets = torch.tensor(
        np.vstack((us[train_eids], np.zeros(len(train_eids), dtype=int), vs[train_eids])).transpose())
    test_triplets = torch.tensor(
        np.vstack((us[test_eids], np.zeros(len(test_eids), dtype=int), vs[test_eids])).transpose())
    logging.info("Split into train and test edges with {} train and {} test".format(len(train_eids), len(test_eids)))

    sampler = dgl.dataloading.MultiLayerNeighborSampler([args.n_neighbors] * args.n_layers) if args.mini_batch \
        else dgl.dataloading.MultiLayerFullNeighborSampler(args.n_layers)
    neg_sampler = dgl.dataloading.negative_sampler.Uniform(args.negative_sampling_rate)
    collator = dgl.dataloading.EdgeCollator(g, {'same_entity': train_eids}, sampler, exclude='reverse_types',
                                            reverse_etypes=reverse, negative_sampler=neg_sampler)
    train_dataloader = torch.utils.data.DataLoader(collator.dataset, collate_fn=collator.collate, shuffle=True,
                                                   batch_size=args.batch_size, drop_last=False, num_workers=0)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    model = train(g, model, train_dataloader, train_triplets, test_triplets, user_features, website_features, optimizer,
                  args.batch_size, args.n_neighbors, args.n_epochs, args.negative_sampling_rate, args.grad_norm, cuda,
                  device=device)

    # the entry point below binds the result of run(), so hand the trained model back
    return model


if __name__ == '__main__':
    model = run()
# raw mrr
def sort_and_rank(score, target):
    """
    Return the 0-indexed rank of each target among its row of scores.

    :param score: (E x V) tensor of candidate scores, one row per test edge
    :param target: (E,) tensor with the candidate index of the true answer of each row
    :return: (E,) tensor with the position of each target after sorting its row in descending order
    """
    _, indices = torch.sort(score, dim=1, descending=True)
    indices = torch.nonzero(indices == target.view(-1, 1))
    return indices[:, 1].view(-1)


def perturb_and_get_raw_rank(embedding, w, a, r, b, test_size, batch_size=100):
    """ Perturb one element in the triplets

    Every entity is scored as a replacement for b given (a, r) and the rank of the true b is returned.

    :param embedding: (V x D) entity embeddings
    :param w: relation weights indexed by relation id
    :param a: (test_size,) indices of the fixed entities
    :param r: (test_size,) relation ids
    :param b: (test_size,) indices of the entities being ranked
    :param test_size: number of triplets
    :param batch_size: number of triplets scored at once
    :return: (test_size,) tensor of 0-indexed ranks
    """
    n_batch = (test_size + batch_size - 1) // batch_size
    candidates = embedding.transpose(0, 1)  # size: D x V, shared by all batches
    ranks = []
    for idx in range(n_batch):
        print("batch {} / {}".format(idx, n_batch))
        batch_start = idx * batch_size
        batch_end = min(test_size, (idx + 1) * batch_size)
        batch_a = a[batch_start: batch_end]
        batch_r = r[batch_start: batch_end]
        emb_ar = embedding[batch_a] * w[batch_r]  # size: E x D
        # A single matrix product gives the E x V scores directly; the equivalent
        # outer-product-then-sum formulation needs a D x E x V intermediate tensor.
        score = torch.sigmoid(torch.mm(emb_ar, candidates))  # size: E x V
        target = b[batch_start: batch_end]
        ranks.append(sort_and_rank(score, target))
    return torch.cat(ranks)


# return MRR (raw), and Hits @ (1, 3, 10)
def calc_raw_mrr(embedding, w, test_triplets, hits=(), eval_bz=100):
    """
    Compute the raw (unfiltered) mean reciprocal rank over subject and object perturbations.

    :param embedding: (V x D) entity embeddings
    :param w: relation weights indexed by relation id
    :param test_triplets: (N x 3) tensor of (subject, relation, object) rows
    :param hits: cutoffs k for which Hits@k is printed
    :param eval_bz: number of triplets scored per batch
    :return: the MRR as a float
    """
    with torch.no_grad():
        s = test_triplets[:, 0]
        r = test_triplets[:, 1]
        o = test_triplets[:, 2]
        test_size = test_triplets.shape[0]

        # perturb subject
        ranks_s = perturb_and_get_raw_rank(embedding, w, o, r, s, test_size, eval_bz)
        # perturb object
        ranks_o = perturb_and_get_raw_rank(embedding, w, s, r, o, test_size, eval_bz)

        ranks = torch.cat([ranks_s, ranks_o])
        ranks += 1  # change to 1-indexed

        mrr = torch.mean(1.0 / ranks.float())
        print("MRR (raw): {:.6f}".format(mrr.item()))

        for hit in hits:
            avg_count = torch.mean((ranks <= hit).float())
            print("Hits (raw) @ {}: {:.6f}".format(hit, avg_count.item()))
    return mrr.item()


# filtered mrr

def filter_o(triplets_to_filter, target_s, target_r, target_o, num_entities):
    """
    List the candidate objects for (target_s, target_r, ?).

    The true object is always kept, since we want to predict on it; any other object that
    forms a known triplet is left out. triplets_to_filter is not modified, so the known
    triplets keep being filtered for every later test triplet.

    :return: LongTensor of candidate object indices in increasing order
    """
    target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
    filtered_o = [o for o in range(num_entities)
                  if o == target_o or (target_s, target_r, o) not in triplets_to_filter]
    return torch.LongTensor(filtered_o)


def filter_s(triplets_to_filter, target_s, target_r, target_o, num_entities):
    """
    List the candidate subjects for (?, target_r, target_o).

    The true subject is always kept, since we want to predict on it; any other subject that
    forms a known triplet is left out. triplets_to_filter is not modified.

    :return: LongTensor of candidate subject indices in increasing order
    """
    target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
    filtered_s = [s for s in range(num_entities)
                  if s == target_s or (s, target_r, target_o) not in triplets_to_filter]
    return torch.LongTensor(filtered_s)


def perturb_o_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter):
    """ Perturb object in the triplets

    :return: LongTensor with the 0-indexed rank of each true object among its filtered candidates
    """
    num_entities = embedding.shape[0]
    ranks = []
    for idx in range(test_size):
        if idx % 100 == 0:
            print("test triplet {} / {}".format(idx, test_size))
        target_s = s[idx]
        target_r = r[idx]
        target_o = o[idx]
        filtered_o = filter_o(triplets_to_filter, target_s, target_r, target_o, num_entities)
        # position of the true object inside the candidate list
        target_o_idx = int((filtered_o == target_o).nonzero())
        emb_s = embedding[target_s]
        emb_r = w[target_r]
        emb_o = embedding[filtered_o]
        emb_triplet = emb_s * emb_r * emb_o
        scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))
        _, indices = torch.sort(scores, descending=True)
        rank = int((indices == target_o_idx).nonzero())
        ranks.append(rank)
    return torch.LongTensor(ranks)


def perturb_s_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter):
    """ Perturb subject in the triplets

    :return: LongTensor with the 0-indexed rank of each true subject among its filtered candidates
    """
    num_entities = embedding.shape[0]
    ranks = []
    for idx in range(test_size):
        if idx % 100 == 0:
            print("test triplet {} / {}".format(idx, test_size))
        target_s = s[idx]
        target_r = r[idx]
        target_o = o[idx]
        filtered_s = filter_s(triplets_to_filter, target_s, target_r, target_o, num_entities)
        # position of the true subject inside the candidate list
        target_s_idx = int((filtered_s == target_s).nonzero())
        emb_s = embedding[filtered_s]
        emb_r = w[target_r]
        emb_o = embedding[target_o]
        emb_triplet = emb_s * emb_r * emb_o
        scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))
        _, indices = torch.sort(scores, descending=True)
        rank = int((indices == target_s_idx).nonzero())
        ranks.append(rank)
    return torch.LongTensor(ranks)


def calc_filtered_mrr(embedding, w, train_triplets, test_triplets, hits=()):
    """
    Compute the filtered mean reciprocal rank over subject and object perturbations.

    Candidates that form a known train or test triplet are excluded from each ranking.

    :param embedding: (V x D) entity embeddings
    :param w: relation weights indexed by relation id
    :param train_triplets: (M x 3) tensor of known (subject, relation, object) rows
    :param test_triplets: (N x 3) tensor of (subject, relation, object) rows being ranked
    :param hits: cutoffs k for which Hits@k is printed
    :return: the MRR as a float
    """
    with torch.no_grad():
        s = test_triplets[:, 0]
        r = test_triplets[:, 1]
        o = test_triplets[:, 2]
        test_size = test_triplets.shape[0]

        triplets_to_filter = torch.cat([train_triplets, test_triplets]).tolist()
        triplets_to_filter = {tuple(triplet) for triplet in triplets_to_filter}
        print('Perturbing subject...')
        ranks_s = perturb_s_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter)
        print('Perturbing object...')
        ranks_o = perturb_o_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter)

        ranks = torch.cat([ranks_s, ranks_o])
        ranks += 1  # change to 1-indexed

        mrr = torch.mean(1.0 / ranks.float())
        print("MRR (filtered): {:.6f}".format(mrr.item()))

        for hit in hits:
            avg_count = torch.mean((ranks <= hit).float())
            print("Hits (filtered) @ {}: {:.6f}".format(hit, avg_count.item()))
    return mrr.item()


# maP
def convert_to_adj_list(i, j):
    """
    Build an undirected adjacency list from the parallel endpoint sequences i and j.

    Self loops (a == b) are skipped; repeated edges produce repeated neighbours.

    :return: dict mapping each node to the list of its neighbours
    """
    adj_list = {}
    for (a, b) in zip(i, j):
        if a == b:
            continue
        adj_list.setdefault(a, []).append(b)
        adj_list.setdefault(b, []).append(a)
    return adj_list


def calc_mAP(embedding, w, train_triplets, test_triplets):
    """
    Compute the mean average precision of link prediction over the test source nodes.

    For each source node of a test triplet every other entity is scored as a partner,
    and the entities linked to it in the train or test triplets are the positives.

    :param embedding: (V x D) entity embeddings
    :param w: relation weights broadcast against the embeddings
    :param train_triplets: (M x 3) tensor of (source, relation, sink) rows
    :param test_triplets: (N x 3) tensor of (source, relation, sink) rows
    :return: mean of the per-node average precision scores
    """
    sources = torch.cat((test_triplets[:, 0], train_triplets[:, 0]))
    sinks = torch.cat((test_triplets[:, 2], train_triplets[:, 2]))
    # Plain Python ints on both sides: tensors hash by identity, so a tensor element
    # used as the lookup key would never find its adjacency entry.
    adj_list = convert_to_adj_list(sources.tolist(), sinks.tolist())
    num_entities = embedding.shape[0]
    aps = []
    for node in test_triplets[:, 0].tolist():
        # every entity except the node itself is a candidate partner
        embed_j = torch.tensor(list(range(0, node)) + list(range(node + 1, num_entities)), dtype=torch.long)
        score = torch.sum(w * embedding[node] * embedding[embed_j], dim=1)
        pred_proba = torch.sigmoid(score).detach().numpy()
        labels = np.zeros(num_entities)
        labels[adj_list.get(node, [])] = 1
        # drop the node's own entry so that labels line up with the candidates
        labels = np.concatenate((labels[:node], labels[node + 1:]))
        ap = average_precision_score(labels, pred_proba)
        aps.append(ap)

    return np.mean(aps)
def train(model, dataloader, features, n_epochs, optimizer, neg_rate, cuda):
    """Train `model` on positive index pairs plus randomly sampled negative pairs.

    For every batch of positive pairs, `neg_rate` random pairs per positive are
    drawn uniformly (with replacement) from all rows of `features` and labelled 0.
    The draws are not filtered against the known positives.

    Args:
        model: module such that `model(features)` returns embeddings and
            `model.get_loss(embed, i, j, labels)` returns a scalar loss.
        dataloader: iterable yielding `(i, j)` batches of 1-D index tensors.
        features: input passed to `model` on every step; `features.shape[0]`
            bounds the sampled negative indices.
        n_epochs: number of passes over `dataloader`.
        optimizer: optimizer stepped once per batch.
        neg_rate: number of negative pairs sampled per positive pair.
        cuda: when true, index and label tensors are moved to the current CUDA device.
    """
    n_nodes = features.shape[0]
    for epoch in range(n_epochs):
        tic = time.time()
        loss_val = 0.
        n_batches = 0
        metric = -1  # placeholder: no evaluation metric is computed in this loop
        for i, j in dataloader:
            n_pos = len(i)
            labels = torch.zeros((neg_rate + 1) * n_pos)
            labels[:n_pos] = 1
            # Match the positives' dtype so torch.cat never sees mixed integer types.
            neg_i = torch.tensor(np.random.choice(n_nodes, neg_rate * n_pos), dtype=i.dtype)
            neg_j = torch.tensor(np.random.choice(n_nodes, neg_rate * len(j)), dtype=j.dtype)
            i = torch.cat((i, neg_i))
            j = torch.cat((j, neg_j))

            if cuda:
                i, j, labels = i.cuda(), j.cuda(), labels.cuda()

            embed = model(features)
            loss = model.get_loss(embed, i, j, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_val += loss.item()
            n_batches += 1

        # NOTE(review): timing is taken once per epoch; confirm this matches the
        # intended meaning of "Time(s)" in the log line.
        epoch_time = time.time() - tic
        print(loss_val)
        # An empty dataloader used to crash on the undefined loop index; report NaN instead.
        mean_loss = loss_val / n_batches if n_batches else float('nan')
        logging.info("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | MRR {:.4f}".format(
            epoch, epoch_time, mean_loss, metric))
if __name__ == '__main__':
    # Script entry point: read the data, split the known pairs 70/30, build the
    # baseline model and train it on the training pairs.
    # NOTE: this rebinds the module-level name `logging` to a Logger instance;
    # read_data() and train() call `logging.info`, which resolves to it from here on.
    logging = get_logger(__name__)
    logging.info('numpy version:{} Pytorch version:{}'.format(np.__version__, torch.__version__))

    args = parse_args()
    features, true_i, true_j, uid_to_idx = read_data(
        args.training_dir,
        args.user_features,
        args.url_features,
        args.transient_edges,
        args.train_edges,
    )

    # Random split over the positions of the ground-truth pairs.
    n_pairs = len(true_i)
    train_idxs = np.random.choice(n_pairs, int(0.7 * n_pairs), replace=False)
    test_idxs = np.setdiff1d(np.arange(n_pairs), train_idxs)
    train_i = true_i[train_idxs]
    train_j = true_j[train_idxs]
    test_i = true_i[test_idxs]
    test_j = true_j[test_idxs]

    adj_list = convert_to_adj_list(true_i, true_j)
    features = torch.tensor(features)

    model = SiamesePairwiseClassification(args.n_layers, features.shape[1], hidden_size=args.n_hidden)

    # Use the first GPU only when one was requested and is actually present.
    cuda = args.num_gpus > 0 and torch.cuda.is_available()
    if cuda:
        torch.cuda.set_device(0)
        model.cuda()
        features = features.cuda()
        device = 'cuda:%d' % torch.cuda.current_device()
    else:
        device = 'cpu'

    # Only the training pairs are fed to the loader.
    train_dataset = torch.utils.data.TensorDataset(torch.tensor(train_i.values),
                                                   torch.tensor(train_j.values))
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   shuffle=True,
                                                   batch_size=args.batch_size,
                                                   drop_last=False,
                                                   num_workers=0)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    train(model, train_dataloader, features, args.n_epochs, optimizer, args.negative_sampling_rate, cuda)
class HeteroRGCN(nn.Module):
    """Stack of HeteroRGCNLayer modules over a heterogeneous graph.

    Node types other than 'user' and 'website' have no input features and get a
    trainable embedding table instead; 'user' and 'website' inputs are supplied
    by the caller of `forward`.

    Args:
        g: graph providing `ntypes`, `etypes`, `canonical_etypes` and
            `number_of_nodes(ntype)`.
        in_size: dict of input widths keyed by node type, with a 'default'
            entry used for every node type that has no entry of its own.
        hidden_size: output width of every layer.
        n_layers: number of layers (at least one is always built).
    """

    def __init__(self, g, in_size, hidden_size, n_layers):
        super(HeteroRGCN, self).__init__()
        # Use trainable node embeddings as featureless inputs.
        embed_dict = {ntype: nn.Parameter(torch.empty(g.number_of_nodes(ntype), in_size['default']))
                      for ntype in g.ntypes if ntype != 'user' and ntype != 'website'}
        for embed in embed_dict.values():
            nn.init.xavier_uniform_(embed)
        self.embed = nn.ParameterDict(embed_dict)

        # Input width of each relation is the width of its source node type.
        in_sizes = [in_size[srctype] if srctype in in_size else in_size['default']
                    for srctype, etype, dsttype in g.canonical_etypes]
        hidden_sizes = [hidden_size] * len(g.etypes)

        # First layer maps input widths to hidden_size; the rest are hidden-to-hidden.
        layers = [HeteroRGCNLayer(in_sizes, hidden_sizes, g.etypes)]
        for _ in range(n_layers - 1):
            layers.append(HeteroRGCNLayer(hidden_sizes, hidden_sizes, g.etypes))
        self.layers = nn.Sequential(*layers)

    def forward(self, g, user_features, website_features):
        """Run all layers and return the final 'user' representations.

        Args:
            g: indexable by layer number; `g[i]` is the graph passed to layer `i`
                (presumably a list of DGL blocks, one per layer; confirm against caller).
            user_features: input tensor for 'user' nodes.
            website_features: input tensor for 'website' nodes.
        """
        # Pass the inputs through unchanged. Wrapping them in nn.Parameter would
        # create new leaf tensors and cut the autograd graph, so whatever produced
        # them upstream (e.g. embedding tables) would never receive gradients.
        h_dict = {'user': user_features, 'website': website_features}

        # Featureless node types: look up their learned embeddings.
        for ntype in self.embed:
            if g[0].number_of_nodes(ntype) > 0:
                h_dict[ntype] = self.embed[ntype][g[0].nodes(ntype).long(), :]

        # Apply the non-linearity between layers, not before the first one.
        for i, layer in enumerate(self.layers):
            if i != 0:
                h_dict = {k: F.leaky_relu(h) for k, h in h_dict.items()}
            h_dict = layer(g[i], h_dict)

        # get user logits
        return h_dict['user']
self.user_embedding(user_nodes), self.website_embedding(website_nodes) 124 | u = torch.cat((user_embed, user_features), 1) 125 | w = torch.cat((website_embed, website_features), 1) 126 | 127 | return self.rgcn(g, u, w) 128 | 129 | def regularization_loss(self, embedding): 130 | return torch.mean(embedding.pow(2)) + torch.mean(self.w_relation.pow(2)) 131 | 132 | def get_loss(self, embed, sources, sinks, labels): 133 | # sources and sinks is a list of edge data samples (positive and negative) 134 | score = self.calc_score(embed, sources, sinks) 135 | predict_loss = F.binary_cross_entropy_with_logits(score, labels) 136 | reg_loss = self.regularization_loss(embed) 137 | return predict_loss + self.reg_param * reg_loss 138 | 139 | def inference(self, g, user_features, web_features, batch_size, n_neighbors, device, num_workers=0): 140 | for l, layer in enumerate(self.rgcn.layers): 141 | sampler = dgl.dataloading.MultiLayerNeighborSampler([n_neighbors]) 142 | dataloader = dgl.dataloading.NodeDataLoader( 143 | g, 144 | {ntype: torch.arange(g.number_of_nodes(ntype)) for ntype in g.ntypes}, 145 | sampler, 146 | batch_size= batch_size, 147 | shuffle=True, 148 | drop_last=False, 149 | num_workers=num_workers) 150 | 151 | y_user = torch.zeros(g.number_of_nodes('user'), self.n_hidden) 152 | y_website = torch.zeros(g.number_of_nodes('website'), self.n_hidden) 153 | y_others = {ntype: torch.zeros(g.number_of_nodes(ntype), self.n_hidden) 154 | for ntype in g.ntypes if ntype != 'user' and ntype != 'website'} 155 | 156 | for input_nodes, output_nodes, blocks in dataloader: 157 | block = blocks[0].to(device) 158 | 159 | # get initial features 160 | if l == 0: 161 | u_f, w_f = user_features[input_nodes['user']], web_features[input_nodes['website']] 162 | u_f, w_f = u_f.to(device), w_f.to(device) 163 | user_nodes, website_nodes = input_nodes['user'].to(device), input_nodes['website'].to(device) 164 | 165 | # get embeddings and concat with initial features 166 | user_embed, website_embed 
= self.user_embedding(user_nodes), self.website_embedding(website_nodes) 167 | u = torch.cat((user_embed, u_f), 1) 168 | w = torch.cat((website_embed, w_f), 1) 169 | 170 | # get intermediate representations 171 | else: 172 | u = y_user[input_nodes['user']].to(device) 173 | w = y_website[input_nodes['website']].to(device) 174 | 175 | h_dict = {} 176 | h_dict['user'] = nn.Parameter(u) 177 | h_dict['website'] = nn.Parameter(w) 178 | 179 | for ntype in self.rgcn.embed: 180 | if block.number_of_nodes(ntype) > 0: 181 | if l == 0: 182 | h_dict[ntype] = self.rgcn.embed[ntype][block.nodes(ntype).long(), :] 183 | else: 184 | h_dict[ntype] = y_others[ntype][input_nodes[ntype]].to(device) 185 | 186 | h_dict = layer(block, h_dict) 187 | if l != len(self.rgcn.layers) - 1: 188 | h_dict = {k: F.leaky_relu(h) for k, h in h_dict.items()} 189 | if len(output_nodes['user']): 190 | y_user[output_nodes['user']] = h_dict['user'].cpu() 191 | if len(output_nodes['website']): 192 | y_website[output_nodes['website']] = h_dict['website'].cpu() 193 | for ntype in self.rgcn.embed: 194 | if len(output_nodes[ntype]): 195 | y_others[ntype][output_nodes[ntype]] = h_dict[ntype].cpu() 196 | 197 | return y_user -------------------------------------------------------------------------------- /deployment/sagemaker-graph-entity-resolution.yaml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: "2010-09-09" 2 | Description: "(SA0007) - sagemaker-graph-entity-resolution: Solution for training a graph neural network model for entity resolution using Amazon SageMaker. Version 1" 3 | Parameters: 4 | SolutionPrefix: 5 | Type: String 6 | Default: "sagemaker-soln-entity-res" 7 | Description: | 8 | Used to name resources created as part of this stack (and inside nested stacks too). 9 | Can be the same as the stack name used by AWS CloudFormation, but this field has extra 10 | constraints because it's used to name resources with restrictions (e.g. 
Amazon S3 bucket 11 | names cannot contain capital letters). 12 | AllowedPattern: '^sagemaker-soln-[a-z0-9\-]{1,20}$' 13 | ConstraintDescription: | 14 | Only allowed to use lowercase letters, hyphens and/or numbers. 15 | Should also start with 'sagemaker-soln-entity-res' for permission management. 16 | IamRole: 17 | Type: String 18 | Default: "" 19 | Description: | 20 | IAM Role that will be attached to the resources created by this cloudformation to grant them permissions to 21 | perform their required functions. This role should allow SageMaker and Lambda perform the required actions like 22 | creating training jobs and processing jobs. If left blank, the template will attempt to create a role for you. 23 | This can cause a stack creation error if you don't have privileges to create new roles. 24 | S3RawDataPrefix: 25 | Description: Enter the S3 prefix where user interaction logs and known resolved entities are stored. 26 | Type: String 27 | Default: "raw-data" 28 | S3ProcessingJobOutputPrefix: 29 | Description: Enter the S3 prefix where preprocessed data should be stored and monitored for changes to start the training job 30 | Type: String 31 | Default: "preprocessed-data" 32 | S3TrainingJobOutputPrefix: 33 | Description: Enter the S3 prefix where model and output artifacts from the training job should be stored 34 | Type: String 35 | Default: "training-output" 36 | CreateSageMakerNotebookInstance: 37 | Description: Whether to launch classic sagemaker notebook instance 38 | Type: String 39 | AllowedValues: 40 | - "true" 41 | - "false" 42 | Default: "false" 43 | SageMakerNotebookInstanceType: 44 | Description: Instance type of the SageMaker notebook instance 45 | Type: String 46 | Default: "ml.m4.xlarge" 47 | StackVersion: 48 | Description: | 49 | CloudFormation Stack version. 50 | Use 'release' version unless you are customizing the 51 | CloudFormation templates and solution artifacts. 
52 | Type: String 53 | Default: release 54 | AllowedValues: 55 | - release 56 | - development 57 | 58 | Metadata: 59 | AWS::CloudFormation::Interface: 60 | ParameterGroups: 61 | - 62 | Label: 63 | default: Solution Configuration 64 | Parameters: 65 | - SolutionPrefix 66 | - IamRole 67 | - StackVersion 68 | - 69 | Label: 70 | default: S3 Configuration 71 | Parameters: 72 | - S3RawDataPrefix 73 | - S3ProcessingJobOutputPrefix 74 | - S3TrainingJobOutputPrefix 75 | - 76 | Label: 77 | default: SageMaker Configuration 78 | Parameters: 79 | - CreateSageMakerNotebookInstance 80 | - SageMakerNotebookInstanceType 81 | ParameterLabels: 82 | SolutionPrefix: 83 | default: Solution Resources Name Prefix 84 | IamRole: 85 | default: Solution IAM Role Arn 86 | StackVersion: 87 | default: Solution Stack Version 88 | S3RawDataPrefix: 89 | default: S3 Data Prefix 90 | S3ProcessingJobOutputPrefix: 91 | default: S3 Preprocessed Data Prefix 92 | S3TrainingJobOutputPrefix: 93 | default: S3 Training Results Prefix 94 | CreateSageMakerNotebookInstance: 95 | default: Launch Classic SageMaker Notebook Instance 96 | SageMakerNotebookInstanceType: 97 | default: SageMaker Notebook Instance 98 | 99 | Mappings: 100 | S3: 101 | release: 102 | BucketPrefix: "sagemaker-solutions-prod" 103 | development: 104 | BucketPrefix: "sagemaker-solutions-devo" 105 | 106 | Conditions: 107 | CreateClassicSageMakerResources: !Equals [ !Ref CreateSageMakerNotebookInstance, "true" ] 108 | CreateCustomSolutionRole: !Equals [!Ref IamRole, ""] 109 | 110 | Resources: 111 | S3Bucket: 112 | Type: AWS::S3::Bucket 113 | Properties: 114 | BucketName: !Sub "${SolutionPrefix}-${AWS::AccountId}-${AWS::Region}" 115 | PublicAccessBlockConfiguration: 116 | BlockPublicAcls: true 117 | BlockPublicPolicy: true 118 | IgnorePublicAcls: true 119 | RestrictPublicBuckets: true 120 | BucketEncryption: 121 | ServerSideEncryptionConfiguration: 122 | - 123 | ServerSideEncryptionByDefault: 124 | SSEAlgorithm: AES256 125 | Metadata: 126 | 
cfn_nag: 127 | rules_to_suppress: 128 | - id: W35 129 | reason: Configuring logging requires supplying an existing customer S3 bucket to store logs 130 | - id: W51 131 | reason: Default access policy suffices 132 | 133 | SolutionAssistantStack: 134 | Type: "AWS::CloudFormation::Stack" 135 | Properties: 136 | TemplateURL: !Sub 137 | - "https://s3.${Region}.amazonaws.com/${SolutionRefBucketBase}-${Region}/Entity-resolution-for-smart-advertising/deployment/solution-assistant/solution-assistant.yaml" 138 | - SolutionRefBucketBase: !FindInMap [S3, !Ref StackVersion, BucketPrefix] 139 | Region: !Ref AWS::Region 140 | Parameters: 141 | SolutionPrefix: !Ref SolutionPrefix 142 | SolutionsRefBucketName: !Sub 143 | - "${SolutionRefBucketBase}-${AWS::Region}" 144 | - SolutionRefBucketBase: !FindInMap [S3, !Ref StackVersion, BucketPrefix] 145 | SolutionS3BucketName: !Sub "${SolutionPrefix}-${AWS::AccountId}-${AWS::Region}" 146 | RoleArn: !If [CreateCustomSolutionRole, !GetAtt SageMakerPermissionsStack.Outputs.SageMakerRoleArn, !Ref IamRole] 147 | 148 | SageMakerPermissionsStack: 149 | Type: "AWS::CloudFormation::Stack" 150 | Condition: CreateCustomSolutionRole 151 | Properties: 152 | TemplateURL: !Sub 153 | - "https://s3.${Region}.amazonaws.com/${SolutionRefBucketBase}-${Region}/Entity-resolution-for-smart-advertising/deployment/sagemaker-permissions-stack.yaml" 154 | - SolutionRefBucketBase: !FindInMap [S3, !Ref StackVersion, BucketPrefix] 155 | Region: !Ref AWS::Region 156 | Parameters: 157 | SolutionPrefix: !Ref SolutionPrefix 158 | SolutionS3BucketName: !Sub "${SolutionPrefix}-${AWS::AccountId}-${AWS::Region}" 159 | StackVersion: !Ref StackVersion 160 | 161 | SageMakerStack: 162 | Type: "AWS::CloudFormation::Stack" 163 | Condition: CreateClassicSageMakerResources 164 | Properties: 165 | TemplateURL: !Sub 166 | - "https://s3.${Region}.amazonaws.com/${SolutionRefBucketBase}-${Region}/Entity-resolution-for-smart-advertising/deployment/sagemaker-notebook-instance-stack.yaml" 
167 | - SolutionRefBucketBase: !FindInMap [S3, !Ref StackVersion, BucketPrefix] 168 | Region: !Ref AWS::Region 169 | Parameters: 170 | SolutionPrefix: !Ref SolutionPrefix 171 | SolutionS3BucketName: !Sub "${SolutionPrefix}-${AWS::AccountId}-${AWS::Region}" 172 | S3InputDataPrefix: !Ref S3RawDataPrefix 173 | S3ProcessingJobOutputPrefix: !Ref S3ProcessingJobOutputPrefix 174 | S3TrainingJobOutputPrefix: !Ref S3TrainingJobOutputPrefix 175 | SageMakerNotebookInstanceType: !Ref SageMakerNotebookInstanceType 176 | NotebookInstanceExecutionRoleArn: !If [CreateCustomSolutionRole, !GetAtt SageMakerPermissionsStack.Outputs.SageMakerRoleArn, !Ref IamRole] 177 | StackVersion: !Ref StackVersion 178 | 179 | Outputs: 180 | SourceCode: 181 | Condition: CreateClassicSageMakerResources 182 | Description: "Open Jupyter IDE. This authenticates you against Jupyter." 183 | Value: !GetAtt SageMakerStack.Outputs.SourceCode 184 | 185 | NotebookInstance: 186 | Condition: CreateClassicSageMakerResources 187 | Description: "SageMaker Notebook instance to manually orchestrate data preprocessing and model training" 188 | Value: !GetAtt SageMakerStack.Outputs.NotebookInstance 189 | 190 | IamRole: 191 | Description: "Arn of SageMaker Execution Role" 192 | Value: !If [CreateCustomSolutionRole, !GetAtt SageMakerPermissionsStack.Outputs.SageMakerRoleArn, !Ref IamRole] 193 | 194 | SolutionPrefix: 195 | Description: "Solution Prefix for naming SageMaker transient resources" 196 | Value: !Ref SolutionPrefix 197 | 198 | SolutionName: 199 | Description: "Name of the solution" 200 | Value: "Entity-resolution-for-smart-advertising" 201 | 202 | SolutionUpstreamS3Bucket: 203 | Description: "Upstream solutions bucket" 204 | Value: !Sub 205 | - "${SolutionRefBucketBase}-${AWS::Region}" 206 | - SolutionRefBucketBase: !FindInMap [S3, !Ref StackVersion, BucketPrefix] 207 | 208 | SolutionS3Bucket: 209 | Description: "Solution S3 bucket name" 210 | Value: !Sub "${SolutionPrefix}-${AWS::AccountId}-${AWS::Region}" 211
| 212 | S3InputDataPrefix: 213 | Description: "S3 bucket prefix for raw data" 214 | Value: !Ref S3RawDataPrefix 215 | 216 | S3ProcessingJobOutputPrefix: 217 | Description: "S3 bucket prefix for processed data" 218 | Value: !Ref S3ProcessingJobOutputPrefix 219 | 220 | S3TrainingJobOutputPrefix: 221 | Description: "S3 bucket prefix for trained model and other artifacts" 222 | Value: !Ref S3TrainingJobOutputPrefix 223 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /source/sagemaker/dgl-entity-resolution.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Graph Entity Resolution with DGL on Amazon SageMaker\n", 8 | "\n", 9 | "This notebook shows how to train an entity resolution model using graph neural networks. 
Entity resolution is the task of identifying and linking entities in a graph that belong to the same real world entity. This is useful for use-cases like user profiling where users might access an online service via different temporary session IDs generated by different devices. Entity resolution allows us to consolidate all information about a particular user and deduplicate the user profile.\n", 10 | "\n", 11 | "There are two main parts of this notebook.\n", 12 | "* First, we process the raw dataset to prepare the features and construct the graph.\n", 13 | "* Next, we create and launch a training job using SageMaker to train a graph neural network model with DGL." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "!bash setup.sh\n", 23 | "\n", 24 | "import sagemaker\n", 25 | "from sagemaker_graph_entity_resolution import config\n", 26 | "\n", 27 | "role = config.role\n", 28 | "sess = sagemaker.Session()" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## Data Preprocessing and Feature Engineering" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "### Upload raw data to S3\n", 43 | "\n", 44 | "The dataset we use is the [DCA dataset](https://drive.google.com/drive/folders/0B7XZSACQf0KdNXVIUXEyVGlBZnc?usp=drive_open) released as part of the [2016 CIKM Cup competition](https://competitions.codalab.org/competitions/11171). The dataset contains anonymized browsing logs of users accessing various urls. In order to ensure the demonstration runs quickly, and to match the typical format of user activity data that many companies have, we have done some initial preparation steps. We converted the data from json to a relational table format and sampled just a subset of the overall data. 
The data preparation scripts can be seen in the `data-preparation/` folder.\n", 45 | "\n", 46 | "The prepared dataset consists of two files:\n", 47 | "\n", 48 | "* `logs.csv`: Records user browsing activity. Each entry consists of a timestamp, the anonymized urls that the user visited, the anonymized title of the url page, and the anonymized transient user id. The column names for the dataset are `['ts', 'urls', 'titles', 'uid']`\n", 49 | "\n", 50 | "* `train.csv`: Records ground truth links between pairs of transient user ids. Each entry has a pair of uids that are known to be the same real world user. The file has no headers.\n", 51 | "\n", 52 | "\n", 53 | "Now, let's move the raw data to a convenient location in an S3 bucket in your account for this project. There it will be picked up by the preprocessing job and training job.\n", 54 | "\n", 55 | "If you would like to use your own dataset for this demonstration, replace the `raw_data_location` in the cell below with the s3 path or local path of your dataset, and modify the data preprocessing step as needed." 
56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# Replace with an S3 location or local path to point to your own dataset\n", 65 | "raw_data_location = 's3://{}/{}/data'.format(config.solution_upstream_bucket, config.solution_name)\n", 66 | "\n", 67 | "session_prefix = 'dgl-entity-resolution'\n", 68 | "input_data = 's3://{}/{}/{}'.format(config.solution_bucket, session_prefix, config.s3_data_prefix)\n", 69 | "\n", 70 | "!aws s3 cp --recursive $raw_data_location $input_data\n", 71 | "\n", 72 | "# Set S3 locations to store processed data for training and post-training results and artifacts respectively\n", 73 | "train_data = 's3://{}/{}/{}'.format(config.solution_bucket, session_prefix, config.s3_processing_output)\n", 74 | "train_output = 's3://{}/{}/{}'.format(config.solution_bucket, session_prefix, config.s3_train_output)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "### Run Preprocessing job with Amazon SageMaker Processing\n", 82 | "\n", 83 | "The script we have defined at `data-preprocessing/data_preprocessing.py` performs data preprocessing and feature engineering transformations on the raw tabular data.\n", 84 | "\n", 85 | "We convert the relational table to graph edgelists describing the relationships. For example the columns `['uid', 'urls']` are converted to an edgelist for edge type `('user', 'visits', 'url')` and the columns `['urls', 'titles']` are converted into an edgelist for edge type `('url', 'owned_by', 'domain')`.\n", 86 | "\n", 87 | "\n", 88 | "We also perform feature engineering to generate features for each user and each domain.\n", 89 | "\n", 90 | "* User features: We use the timestamps in the `ts` column to generate k-hot feature vectors that encode users' weekly browsing habits. 
For each user, we generate a 168 dimensional (7 days * 24 hours) vector.\n", 91 | "\n", 92 | "* Url features: We use the anonymized tokens in the full url and title to generate features for the url. For example if a url is `a/b/c?d` and the title is `a`, we have a bag of words of `['a', 'b', 'c', 'd']` for the url. We use a [TfidfVectorizer](https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html) to convert the text features into numerical features and then perform [dimensionality reduction](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.TruncatedSVD.html) to obtain a 20 dimensional feature vector.\n", 93 | "\n", 94 | "In order to adapt the preprocessing script to work with your data in the same format, you can modify the python script `data-preprocessing/data_preprocessing.py` used in the cell below.\n", 95 | "\n", 96 | "The python processing script also splits our ground-truth linked entities into a train and test/validation set. The default test-ratio is 0.3 but this can be modified.\n", 97 | "\n", 98 | "We use the built-in SKLearnProcessor provided by SageMaker since it already contains the python dependencies - (pandas, sklearn) - that we need for preprocessing and feature engineering." 
99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "from sagemaker.sklearn.processing import SKLearnProcessor\n", 108 | "from sagemaker.processing import ProcessingInput, ProcessingOutput\n", 109 | "\n", 110 | "sklearn_processor = SKLearnProcessor(framework_version='0.20.0',\n", 111 | " role=role,\n", 112 | " instance_count=1,\n", 113 | " instance_type='ml.m5.xlarge')\n", 114 | "\n", 115 | "sklearn_processor.run(code='data-preprocessing/data_preprocessing.py',\n", 116 | " arguments = ['--test-ratio', '0.3'],\n", 117 | " inputs=[ProcessingInput(source=input_data,\n", 118 | " destination='/opt/ml/processing/input')],\n", 119 | " outputs=[ProcessingOutput(destination=train_data,\n", 120 | " source='/opt/ml/processing/output')])" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "### View Results of Data Preprocessing\n", 128 | "Once the preprocessing job is complete, we can take a look at the contents of the processing output folder in the S3 bucket to see the transformed data. \n", 129 | "\n", 130 | "You should see the following files:\n", 131 | "\n", 132 | "* `transient_edges.csv`: the set of edges between each transient uid and each url visited by uid.\n", 133 | "* `transient_nodes.csv`: the set of transient uid nodes along with the activity feature vector for the uids. \n", 134 | "* `user_train_edges`: the set of ground-truth same entities for pairs of transient uids that will be used during training.\n", 135 | "* `user_test_edges`: the set of ground-truth same entities that will be used to evaluate the trained model. 
\n", 136 | "* `website_group_nodes`: the set of url nodes along with the feature vectors for the urls.\n", 137 | "* `website_group_edges`: the set of edges between each url and its parent domain.\n", 138 | "\n", 139 | "We add these files to our `params` dictionary because our downstream training job will use these file names to construct the identity graph during model training. " 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "from os import path\n", 149 | "from sagemaker.s3 import S3Downloader\n", 150 | "processed_files = S3Downloader.list(train_data)\n", 151 | "print(\"===== Processed Files =====\")\n", 152 | "print('\\n'.join(processed_files))\n", 153 | "\n", 154 | "params = {\n", 155 | " 'train-edges': 'user_train_edges.csv',\n", 156 | " 'test-edges': 'user_test_edges.csv',\n", 157 | " 'transient-nodes': 'transient_nodes.csv',\n", 158 | " 'transient-edges': 'transient_edges.csv',\n", 159 | " 'website-nodes': 'website_nodes.csv',\n", 160 | " 'website-group-edges': 'website_group_edges.csv'\n", 161 | "}\n", 162 | "\n", 163 | "print(\"Graph will be constructed using the following data:\\n{}\".format(params))" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": {}, 169 | "source": [ 170 | "## Train Graph Neural Network with DGL\n", 171 | "\n", 172 | "Graph Neural Networks (GNNs) work by learning numeric representations for nodes and edges informed by the graph structure. 
We can model the entity resolution problem as a link prediction problem, i.e., we have a few links between users that are the same entity and we would like to use that information to predict new links/edges between users that are not linked in the graph but may correspond to linked entities.\n", 173 | "\n", 174 | "In order to train a model that can achieve this we need to make/specify two modelling assumptions: the `GNN Architecture` and the `Self-Supervision Task`.\n", 175 | "\n", 176 | "* *GNN Architecture*: The GNN Architecture determines which GNN framework is used to learn the node representations that are consumed by the downstream task model. Since we have nodes and edges of different types, we will be using a relational graph convolutional neural network model (R-GCN). This architecture works well on heterogeneous graphs. This is also what is known as the `Graph Encoder`\n", 177 | "\n", 178 | "* *Self-Supervision Task*: As alluded to, the task that will be used to supervise the encoder is *link prediction*. Formally, we generate some negative edges by creating links between nodes sampled at random from the graph. The goal of the task is to learn a score function that gives higher scores to the real positive edges and lower scores to the negative edges. This is what is also known as the `Graph Decoder`" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "### Hyperparameters\n", 186 | "To train the graph neural network, we need to define a few hyperparameters that determine the GNN architecture, graph sampling parameters, optimizer, and optimization parameters.\n", 187 | "\n", 188 | "Here we're setting only a few of the hyperparameters; to see all the hyperparameters and their default values, see `sagemaker_graph_entity_resolution/dgl_entity_resolution/estimator_fns.py`. 
The parameters set below are:\n", 189 | "\n", 190 | "\n", 191 | "* `mini-batch`: Whether to perform mini-batch training, which trains the model with a batch of nodes at a time, using the sampled graph neighbourhood of the mini-batch nodes.\n", 192 | "* `batch-size`: The number of nodes in a mini-batch that are used to compute a single forward pass of the GNN.\n", 193 | "* `num-gpus`: The number of gpus to use during training. Use only when training with a GPU enabled instance\n", 194 | "* `embedding-size`: In the inductive case, the number of dimensions of the node specific embedding that is concatenated with the node feature vector. For nodes that don't have features, the dimensionality is just `embedding-size`.\n", 195 | "* `n-neighbors`: The number of neighbours to sample for each target node during graph sampling for mini-batch training\n", 196 | "* `n-layers`: The number of GNN layers in the model\n", 197 | "* `negative-sampling-rate`: How many `negative edges` to sample from the graph for each positive edge in the mini-batch. This is used to supervise the loss function so it penalizes negative edges and distinguishes those from positive edges. \n", 198 | "* `n-epochs`: The number of training epochs for the model training job. We set this to 3 epochs so that the job terminates quickly. 
In order to obtain better predictions, train the model for more epochs.\n", 199 | "* `optimizer`: The optimization algorithm used for gradient based parameter updates\n", 200 | "* `lr`: The learning rate for parameter updates" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "hyperparams = {\n", 210 | " 'mini-batch': 'true',\n", 211 | " 'batch-size': 1000,\n", 212 | " 'num-gpus': 1,\n", 213 | " 'embedding-size': 64,\n", 214 | " 'n-neighbors': 100,\n", 215 | " 'n-hidden': 16,\n", 216 | " 'n-layers': 2,\n", 217 | " 'negative-sampling-rate': 10,\n", 218 | " 'n-epochs': 3,\n", 219 | " 'optimizer': 'adam',\n", 220 | " 'lr': 1e-2\n", 221 | "}\n", 222 | "\n", 223 | "params.update(**hyperparams)" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": {}, 229 | "source": [ 230 | "### Create and Fit SageMaker Estimator\n", 231 | "\n", 232 | "With the hyperparameters defined, we can kick off the training job. We will be using the Deep Graph Library (DGL), with PyTorch as the backend deep learning framework, to define and train the graph neural network. Amazon SageMaker makes it easy to do this with the Framework estimators which have the deep learning frameworks already set up. Here, we create a SageMaker PyTorch estimator and pass in our model training script, hyperparameters, as well as the number and type of training instances we want.\n", 233 | "\n", 234 | "We can then fit the estimator on the training data location in S3." 
235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "from sagemaker.pytorch import PyTorch\n", 244 | "\n", 245 | "estimator = PyTorch(entry_point='train_dgl_pytorch_entity_resolution.py',\n", 246 | " source_dir='sagemaker_graph_entity_resolution/dgl_entity_resolution',\n", 247 | " role=role, \n", 248 | " train_instance_count=1,\n", 249 | " train_instance_type='ml.g4dn.xlarge',\n", 250 | " framework_version=\"1.5.0\",\n", 251 | " py_version='py3',\n", 252 | " hyperparameters=params,\n", 253 | " output_path=train_output,\n", 254 | " code_location=train_output,\n", 255 | " sagemaker_session=sagemaker.Session())\n", 256 | "\n", 257 | "estimator.fit({'train': train_data})" 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": {}, 263 | "source": [ 264 | "Once the training is completed, the training instances are shut off and SageMaker stores the trained model and new predicted links to the output location in S3." 265 | ] 266 | } 267 | ], 268 | "metadata": { 269 | "kernelspec": { 270 | "display_name": "Python 3 (Data Science JumpStart)", 271 | "language": "python", 272 | "name": "HUB_1P_IMAGE" 273 | }, 274 | "language_info": { 275 | "codemirror_mode": { 276 | "name": "ipython", 277 | "version": 3 278 | }, 279 | "file_extension": ".py", 280 | "mimetype": "text/x-python", 281 | "name": "python", 282 | "nbconvert_exporter": "python", 283 | "pygments_lexer": "ipython3", 284 | "version": "3.7.7" 285 | } 286 | }, 287 | "nbformat": 4, 288 | "nbformat_minor": 4 289 | } --------------------------------------------------------------------------------