├── bhakti-cdk
│   ├── tests
│   │   ├── __init__.py
│   │   └── unit
│   │       ├── __init__.py
│   │       └── test_bhakti_cdk_stack.py
│   ├── bhakti_cdk
│   │   ├── __init__.py
│   │   ├── bhakti_central_components.py
│   │   ├── bhakti_instance_profiles.py
│   │   └── bhakti_monitoring_stack.py
│   ├── lambda
│   │   ├── requirements.txt
│   │   └── monitoring_lambda.py
│   ├── requirements-dev.txt
│   ├── requirements.txt
│   ├── analysis
│   │   ├── requirements.txt
│   │   ├── monitoring_ec2_check.py
│   │   └── checkModel.py
│   ├── cdk.context.json
│   ├── .gitignore
│   ├── source.bat
│   ├── app.py
│   ├── cdk.json
│   └── README.md
├── media
│   └── Bhakti.png
├── analysis
│   ├── requirements.txt
│   ├── monitoring_ec2_check.py
│   └── checkModel.py
├── yara
│   ├── keras-lambda.yara
│   ├── keras-b64-url.yara
│   └── keras-requests.yara
├── README.md
└── LICENSE

/bhakti-cdk/tests/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/bhakti-cdk/bhakti_cdk/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/bhakti-cdk/tests/unit/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/bhakti-cdk/lambda/requirements.txt:
--------------------------------------------------------------------------------
1 | requests
--------------------------------------------------------------------------------
/bhakti-cdk/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | pytest==6.2.5
2 | 
--------------------------------------------------------------------------------
/bhakti-cdk/requirements.txt:
--------------------------------------------------------------------------------
1 | aws-cdk-lib==2.130.0
2 | constructs>=10.0.0,<11.0.0
3 | 
--------------------------------------------------------------------------------
/media/Bhakti.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dropbox/bhakti/main/media/Bhakti.png
--------------------------------------------------------------------------------
/analysis/requirements.txt:
--------------------------------------------------------------------------------
1 | h5py==3.10.0
2 | Requests==2.31.0
3 | tensorflow==2.16.1
4 | 
--------------------------------------------------------------------------------
/bhakti-cdk/analysis/requirements.txt:
--------------------------------------------------------------------------------
1 | h5py==3.10.0
2 | Requests==2.31.0
3 | tensorflow==2.16.1
4 | 
--------------------------------------------------------------------------------
/bhakti-cdk/cdk.context.json:
--------------------------------------------------------------------------------
1 | {
2 |   "deploy_type": "monitoring",
3 |   "deploy_region": "us-west-2",
4 |   "deploy_account": 123456789,
5 |   "sg_id": "sg-example"
6 | }
--------------------------------------------------------------------------------
/bhakti-cdk/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | package-lock.json
3 | __pycache__
4 | .pytest_cache
5 | .venv
6 | *.egg-info
7 | 
8 | # CDK asset staging directory
9 | .cdk.staging
10 | cdk.out
11 | 
--------------------------------------------------------------------------------
/bhakti-cdk/source.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | 
3 | rem The sole purpose of this script is to make the command
4 | rem
5 | rem source .venv/bin/activate
6 | rem
7 | rem (which activates a Python virtualenv on Linux or Mac OS X) work on Windows.
8 | rem On Windows, this command just runs this batch file (the argument is ignored).
9 | rem
10 | rem Now we don't need to document a Windows command for activating a virtualenv.
11 | 
12 | echo Executing .venv\Scripts\activate.bat for you
13 | .venv\Scripts\activate.bat
14 | 
--------------------------------------------------------------------------------
/yara/keras-lambda.yara:
--------------------------------------------------------------------------------
1 | rule KerasLambda
2 | {
3 |     meta:
4 |         author = "Dropbox Threat Intel"
5 |         description = "This signature fires on the presence of a lambda layer in a keras Tensorflow model. The simple presence of such a layer is not an indicator of malicious content, but is worth further investigation."
6 |         created_date = "2024-04-05"
7 |         updated_date = "2024-04-05"
8 | 
9 |     strings:
10 |         $function = "function_type"
11 |         $layer = "lambda"
12 | 
13 |     condition:
14 |         $function and $layer
15 | }
--------------------------------------------------------------------------------
/bhakti-cdk/tests/unit/test_bhakti_cdk_stack.py:
--------------------------------------------------------------------------------
1 | import aws_cdk as core
2 | import aws_cdk.assertions as assertions
3 | 
4 | from bhakti_cdk.bhakti_cdk_stack import BhaktiCdkStack
5 | 
6 | # example tests. To run these tests, uncomment this file along with the example
7 | # resource in bhakti_cdk/bhakti_cdk_stack.py
8 | def test_sqs_queue_created():
9 |     app = core.App()
10 |     stack = BhaktiCdkStack(app, "bhakti-cdk")
11 |     template = assertions.Template.from_stack(stack)
12 | 
13 |     # template.has_resource_properties("AWS::SQS::Queue", {
14 |     #     "VisibilityTimeout": 300
15 |     # })
--------------------------------------------------------------------------------
/yara/keras-b64-url.yara:
--------------------------------------------------------------------------------
1 | rule KerasURL
2 | {
3 |     meta:
4 |         author = "Dropbox Threat Intel"
5 |         description = "This signature fires on the presence of Base64 encoded URI prefixes (http:// and https://) within a lambda layer of a keras Tensorflow model. The simple presence of such strings is not inherently an indicator of malicious content, but is worth further investigation."
6 |         created_date = "2024-04-05"
7 |         updated_date = "2024-04-05"
8 | 
9 |     strings:
10 |         $function = "function_type"
11 |         $layer = "lambda"
12 |         $url = "http" base64
13 | 
14 |     condition:
15 |         $url and ($function and $layer)
16 | 
17 | }
--------------------------------------------------------------------------------
/yara/keras-requests.yara:
--------------------------------------------------------------------------------
1 | rule KerasRequests
2 | {
3 |     meta:
4 |         author = "Dropbox Threat Intel"
5 |         description = "This signature fires on the presence of a Base64 encoded reference to the Python requests library within a lambda layer of a keras Tensorflow model. The simple presence of such a string is not inherently an indicator of malicious content, but is worth further investigation."
6 | created_date = "2024-04-05" 7 | updated_date = "2024-04-05" 8 | strings: 9 | $function = "function_type" 10 | $layer = "lambda" 11 | $req = "requests" base64 12 | 13 | condition: 14 | $req and ($function and $layer) 15 | } -------------------------------------------------------------------------------- /bhakti-cdk/app.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | 4 | import aws_cdk as cdk 5 | 6 | from bhakti_cdk.bhakti_monitoring_stack import MonitoringStack 7 | from bhakti_cdk.bhakti_instance_profiles import InstanceProfiles 8 | from bhakti_cdk.bhakti_central_components import BhaktiShared 9 | 10 | 11 | app = cdk.App() 12 | deploy_type = app.node.try_get_context("deploy_type") 13 | deploy_region = app.node.try_get_context("deploy_region") 14 | deploy_account = app.node.try_get_context("deploy_account") 15 | env = cdk.Environment(account=f'{deploy_account}', region=f'{deploy_region}') 16 | sg_id = app.node.try_get_context("sg_id") 17 | 18 | print(deploy_type) 19 | print(deploy_region) 20 | print(deploy_account) 21 | print(env) 22 | 23 | if not sg_id: 24 | sg_id = None 25 | 26 | shared_resources = BhaktiShared(app, "BhaktiShared", sg_id=sg_id, env=env) 27 | if deploy_type == 'monitoring': 28 | MonitoringStack(app, "MonitoringStack", 29 | hf_token=shared_resources.hf_token, 30 | script_asset=shared_resources.script_asset, 31 | env=env,) 32 | elif deploy_type == 'instance_profile': 33 | InstanceProfiles(app, "InstanceProfileStack", 34 | hf_token=shared_resources.hf_token, 35 | script_asset=shared_resources.script_asset, 36 | sg_id=shared_resources.sg, 37 | env=env, 38 | ) 39 | 40 | app.synth() 41 | -------------------------------------------------------------------------------- /bhakti-cdk/bhakti_cdk/bhakti_central_components.py: -------------------------------------------------------------------------------- 1 | from aws_cdk import ( 2 | # Duration, 3 | Stack, 4 | aws_secretsmanager, 5 | aws_logs as logs, 6 | aws_s3_assets as assets, 7 | aws_ec2 as ec2 8 | ) 9 | from constructs import Construct 10 | from typing import Optional 11 | 12 | class BhaktiShared(Stack): 13 | 14 | def __init__(self, scope: Construct, construct_id: str, sg_id: str, **kwargs) -> None: 15 | super().__init__(scope, construct_id, **kwargs) 16 | huggingface_token = aws_secretsmanager.Secret( 17 | self, 18 | "huggingface_token", 19 | secret_name="huggingface_api_token" 20 | ) 21 | 22 | #bundle analysis script in s3 for use in ec2 user data 23 | s3_script_asset = assets.Asset( 24 | self, "file_asset", 25 | path=("./analysis/") 26 | ) 27 | self._asset = s3_script_asset 28 | self._token = huggingface_token 29 | 30 | if sg_id: 31 | security_group = ec2.SecurityGroup.from_security_group_id(self, 'sg', sg_id) 32 | self._security_group = security_group 33 | else: 34 | self._security_group = None 35 | 36 | @property 37 | def hf_token(self) -> aws_secretsmanager.Secret: 38 | return self._token 39 | 40 | @property 41 | def script_asset(self) -> assets.Asset: 42 | return self._asset 43 | 44 | @property 45 | def sg(self) -> Optional[ec2.SecurityGroup]: 46 | return self._security_group 47 | 48 | -------------------------------------------------------------------------------- /bhakti-cdk/cdk.json: -------------------------------------------------------------------------------- 1 | { 2 | "app": "python3 app.py", 3 | "watch": { 4 | "include": [ 5 | "**" 6 | ], 7 | "exclude": [ 8 | "README.md", 9 | "cdk*.json", 10 | "requirements*.txt", 11 | 
"source.bat", 12 | "**/__init__.py", 13 | "**/__pycache__", 14 | "tests" 15 | ] 16 | }, 17 | "context": { 18 | "@aws-cdk/aws-lambda:recognizeLayerVersion": true, 19 | "@aws-cdk/core:checkSecretUsage": true, 20 | "@aws-cdk/core:target-partitions": [ 21 | "aws", 22 | "aws-cn" 23 | ], 24 | "@aws-cdk-containers/ecs-service-extensions:enableDefaultLogDriver": true, 25 | "@aws-cdk/aws-ec2:uniqueImdsv2TemplateName": true, 26 | "@aws-cdk/aws-ecs:arnFormatIncludesClusterName": true, 27 | "@aws-cdk/aws-iam:minimizePolicies": true, 28 | "@aws-cdk/core:validateSnapshotRemovalPolicy": true, 29 | "@aws-cdk/aws-codepipeline:crossAccountKeyAliasStackSafeResourceName": true, 30 | "@aws-cdk/aws-s3:createDefaultLoggingPolicy": true, 31 | "@aws-cdk/aws-sns-subscriptions:restrictSqsDescryption": true, 32 | "@aws-cdk/aws-apigateway:disableCloudWatchRole": true, 33 | "@aws-cdk/core:enablePartitionLiterals": true, 34 | "@aws-cdk/aws-events:eventsTargetQueueSameAccount": true, 35 | "@aws-cdk/aws-iam:standardizedServicePrincipals": true, 36 | "@aws-cdk/aws-ecs:disableExplicitDeploymentControllerForCircuitBreaker": true, 37 | "@aws-cdk/aws-iam:importedRoleStackSafeDefaultPolicyName": true, 38 | "@aws-cdk/aws-s3:serverAccessLogsUseBucketPolicy": true, 39 | "@aws-cdk/aws-route53-patters:useCertificate": true, 40 | "@aws-cdk/customresources:installLatestAwsSdkDefault": false, 41 | "@aws-cdk/aws-rds:databaseProxyUniqueResourceName": true, 42 | "@aws-cdk/aws-codedeploy:removeAlarmsFromDeploymentGroup": true, 43 | "@aws-cdk/aws-apigateway:authorizerChangeDeploymentLogicalId": true, 44 | "@aws-cdk/aws-ec2:launchTemplateDefaultUserData": true, 45 | "@aws-cdk/aws-secretsmanager:useAttachedSecretResourcePolicyForSecretTargetAttachments": true, 46 | "@aws-cdk/aws-redshift:columnId": true, 47 | "@aws-cdk/aws-stepfunctions-tasks:enableEmrServicePolicyV2": true, 48 | "@aws-cdk/aws-ec2:restrictDefaultSecurityGroup": true, 49 | "@aws-cdk/aws-apigateway:requestValidatorUniqueId": true, 50 | "@aws-cdk/aws-kms:aliasNameRef": true, 51 | "@aws-cdk/aws-autoscaling:generateLaunchTemplateInsteadOfLaunchConfig": true, 52 | "@aws-cdk/core:includePrefixInUniqueNameGeneration": true, 53 | "@aws-cdk/aws-efs:denyAnonymousAccess": true, 54 | "@aws-cdk/aws-opensearchservice:enableOpensearchMultiAzWithStandby": true, 55 | "@aws-cdk/aws-lambda-nodejs:useLatestRuntimeVersion": true, 56 | "@aws-cdk/aws-efs:mountTargetOrderInsensitiveLogicalId": true, 57 | "@aws-cdk/aws-rds:auroraClusterChangeScopeOfInstanceParameterGroupWithEachParameters": true, 58 | "@aws-cdk/aws-appsync:useArnForSourceApiAssociationIdentifier": true, 59 | "@aws-cdk/aws-rds:preventRenderingDeprecatedCredentials": true, 60 | "@aws-cdk/aws-codepipeline-actions:useNewDefaultBranchForCodeCommitSource": true, 61 | "@aws-cdk/aws-cloudwatch-actions:changeLambdaPermissionLogicalIdForLambdaAction": true, 62 | "@aws-cdk/aws-codepipeline:crossAccountKeysDefaultValueToFalse": true 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bhakti 🐕 2 | 3 | [Bhakti](https://finalfantasy.fandom.com/wiki/Bhakti) is a conglomeration of analysis tools to look at certain types of machine learning models for the presence of a code execution layer. 
There are three main components:
4 | 
5 | - Analysis script(s)
6 | - AWS CDK to create an AWS investigation lab or to stand up automated model monitoring
7 | - Yara rules that can be used to identify risky models
8 | 
9 | ## Background
10 | 
11 | This little repo is scrappy tooling that resulted from a threat hunt that the Dropbox Threat Intelligence team conducted across huggingface in 2023 and into 2024. This work coalesced after [@5stars217](https://github.com/5stars217) began to publish his work on malicious models in the keras space (see [On Malicious Models](https://5stars217.github.io/2023-03-30-on-malicious-models/)). Currently, all analysis is restricted to Tensorflow models using Keras that leverage a lambda layer as a vehicle for arbitrary code execution on a victim system.
12 | 
13 | ## Analysis scripts
14 | 
15 | [Analysis Scripts](analysis/)
16 | - `checkModel.py` is designed to assess either a local model or a huggingface repo for a lambda layer. It supports `.h5` and `keras_metadata.pb` formats; it attempts to dump any code found within any identified layers in these kinds of files.
17 | - `monitoring_ec2_check.py` is designed to run as part of huggingface monitoring hosted on AWS; it's deployed with the monitoring cdk stack. It pulls work to do from SQS, runs the lambda-layer check, and records results in DynamoDB.
18 | 
19 | ## YARA rules
20 | [YARA Rules](yara/)
21 | - `keras-lambda.yara` flags on any Tensorflow Keras model containing a lambda layer
22 | - `keras-requests.yara` flags on any Tensorflow Keras model using the requests library in a lambda layer
23 | - `keras-b64-url.yara` flags on any Tensorflow Keras model containing a Base64 encoded URL prefix (http:// or https://) in a lambda layer
24 | 
25 | ## CDK Stuff
26 | [CDK Things](bhakti-cdk/)
27 | 
28 | Parameterizing CDK and making it beautiful and portable is not really my forte, but I've done my best. There's a whole additional [README.md](bhakti-cdk/README.md) file in the cdk sub-folder with more information about standing up this infrastructure in your own account. **AWS isn't free**, please configure your account with appropriate billing alarms so you're not taken aback by anything these stacks might do while trying to be good little robots.
29 | 
30 | - [$monitoring_stack](bhakti-cdk/bhakti_cdk/bhakti_monitoring_stack.py) will attempt to deploy a monitoring solution in a bootstrapped AWS account.
31 | - [$launch_template_stack](bhakti-cdk/bhakti_cdk/bhakti_instance_profiles.py) will attempt to stand up an EC2 launch template in a bootstrapped AWS account to use for ML malware analysis.
32 | 
33 | ## License
34 | 
35 | Unless otherwise noted:
36 | 
37 | ```
38 | Copyright (c) 2023-2024 Dropbox, Inc
39 | 
40 | Licensed under the Apache License, Version 2.0 (the "License");
41 | you may not use this file except in compliance with the License.
42 | You may obtain a copy of the License at
43 | 
44 |    http://www.apache.org/licenses/LICENSE-2.0
45 | 
46 | Unless required by applicable law or agreed to in writing, software
47 | distributed under the License is distributed on an "AS IS" BASIS,
48 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
49 | See the License for the specific language governing permissions and
50 | limitations under the License.
51 | ``` -------------------------------------------------------------------------------- /bhakti-cdk/bhakti_cdk/bhakti_instance_profiles.py: -------------------------------------------------------------------------------- 1 | from aws_cdk import ( 2 | # Duration, 3 | Stack, 4 | aws_secretsmanager, 5 | aws_logs as logs, 6 | aws_ec2 as ec2, 7 | aws_iam as iam, 8 | aws_s3_assets as assets, 9 | ) 10 | from constructs import Construct 11 | from typing import Optional 12 | 13 | class InstanceProfiles(Stack): 14 | def __init__(self, scope: Construct, 15 | construct_id: str, 16 | hf_token: aws_secretsmanager.Secret, 17 | sg_id: ec2.SecurityGroup, 18 | script_asset: assets.Asset, 19 | **kwargs) -> None: 20 | super().__init__(scope, construct_id, **kwargs) 21 | 22 | # Add any AWS access you need on your EC2 instance as PolicyStatements in this PolicyDocument 23 | bhakti_access = iam.PolicyDocument( 24 | statements=[iam.PolicyStatement( 25 | actions=["secretsmanager:GetSecretValue", "secretsmanager:DescribeSecret"], 26 | resources=[hf_token.secret_arn] 27 | )] 28 | ) 29 | 30 | bhakti_ec2_access_policy = iam.Policy( 31 | self, "bhakti_ec2_access_policy", 32 | document=bhakti_access 33 | ) 34 | 35 | bhakti_role = iam.Role( 36 | self, "bhakti_role", 37 | assumed_by=iam.ServicePrincipal("ec2.amazonaws.com"), 38 | description="EC2 instance role for Bhakti analysis instances" 39 | ) 40 | 41 | bhakti_role.attach_inline_policy(bhakti_ec2_access_policy) 42 | 43 | bhakti_instance_profile = iam.CfnInstanceProfile( 44 | self, "bhakti_instance_profile", 45 | roles=[bhakti_role.role_name] 46 | ) 47 | 48 | bhakti_keypair = ec2.KeyPair(self, "bhakti_ssh_key", 49 | key_pair_name='bhakti-ssh-key' 50 | ) 51 | 52 | #bundle analysis script in s3 for use in ec2 user data 53 | #asset = assets.Asset( 54 | # self, "file_asset", 55 | # path=("./analysis/checkModel.py") 56 | #) 57 | 58 | bhakti_user_data = ec2.UserData.for_linux() 59 | 60 | local_path = bhakti_user_data.add_s3_download_command( 61 | bucket=script_asset.bucket, 62 | bucket_key = script_asset.s3_object_key, 63 | ) 64 | 65 | check_model = f'{local_path} -d /home/ec2-user/analysis' 66 | 67 | bhakti_user_data.add_execute_file_command( 68 | file_path='/usr/bin/unzip', 69 | arguments=check_model 70 | ) 71 | 72 | script_asset.grant_read(bhakti_role) 73 | 74 | # This template does not include any default security groups-- this means you won't be able to access it 75 | # until you set-up at least an ssh-allow security group within EC2. 
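        # A hedged alternative to the duplicated LaunchTemplate definitions below
        # (not what this stack currently does): build the shared keyword arguments
        # once and add security_group only when one was supplied, e.g.
        #
        #   template_kwargs = dict(
        #       launch_template_name="bhakti_model_analysis",
        #       machine_image=ec2.MachineImage.lookup(
        #           name='Deep*',
        #           filters={'image-id': ['ami-0b28c78d9f575dfa1']},
        #           owners=["amazon"]),
        #       instance_type=ec2.InstanceType.of(ec2.InstanceClass.G4DN, ec2.InstanceSize.XLARGE),
        #       key_pair=bhakti_keypair,
        #       user_data=bhakti_user_data,
        #       role=bhakti_role,
        #       instance_initiated_shutdown_behavior=ec2.InstanceInitiatedShutdownBehavior.TERMINATE,
        #   )
        #   if sg_id:
        #       template_kwargs["security_group"] = sg_id
        #   bhakti_analysis = ec2.LaunchTemplate(self, "ec2_template", **template_kwargs)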
76 | 77 | if sg_id: 78 | bhakti_analysis = ec2.LaunchTemplate( 79 | self, "ec2_template", 80 | launch_template_name="bhakti_model_analysis", 81 | machine_image=ec2.MachineImage.lookup(name='Deep*',filters={'image-id':['ami-0b28c78d9f575dfa1']}, owners=["amazon"]), 82 | instance_type=ec2.InstanceType.of(ec2.InstanceClass.G4DN, ec2.InstanceSize.XLARGE), 83 | key_pair=bhakti_keypair, 84 | user_data=bhakti_user_data, 85 | role=bhakti_role, 86 | security_group=sg_id, 87 | instance_initiated_shutdown_behavior=ec2.InstanceInitiatedShutdownBehavior.TERMINATE, 88 | ) 89 | else: 90 | bhakti_analysis = ec2.LaunchTemplate( 91 | self, "ec2_template", 92 | launch_template_name="bhakti_model_analysis", 93 | machine_image=ec2.MachineImage.lookup(name='Deep*',filters={'image-id':['ami-0b28c78d9f575dfa1']}, owners=["amazon"]), 94 | instance_type=ec2.InstanceType.of(ec2.InstanceClass.G4DN, ec2.InstanceSize.XLARGE), 95 | key_pair=bhakti_keypair, 96 | user_data=bhakti_user_data, 97 | role=bhakti_role, 98 | instance_initiated_shutdown_behavior=ec2.InstanceInitiatedShutdownBehavior.TERMINATE, 99 | ) -------------------------------------------------------------------------------- /analysis/monitoring_ec2_check.py: -------------------------------------------------------------------------------- 1 | #!/opt/tensorflow/bin/python3 2 | 3 | import boto3 4 | from botocore.exceptions import ClientError 5 | import json 6 | from pathlib import Path 7 | import logging 8 | import requests 9 | from tensorflow.python.keras.protobuf.saved_metadata_pb2 import SavedMetadata 10 | import subprocess 11 | from datetime import datetime 12 | import os 13 | 14 | SQS_QUEUE = os.getenv('SQS_QUEUE') 15 | AWS_REGION = os.getenv('AWS_REG') 16 | HUGGINGFACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN') 17 | MODEL_DIRECTORY='/tmp/models' 18 | DYNAMO_STATUS_TABLE = os.getenv('DYNAMO_STATUS_TABLE') 19 | LOGGING_BUCKET = os.getenv('LOGGING_BUCKET') 20 | 21 | logger = logging.getLogger() 22 | logging.basicConfig(filename='/var/log/bhakti.log', encoding='utf-8', level=logging.DEBUG) 23 | logger.setLevel(logging.INFO) 24 | 25 | def get_api_token(): 26 | client = boto3.client('secretsmanager', region_name=AWS_REGION) 27 | get_secret_value_response = client.get_secret_value(SecretId=HUGGINGFACE_TOKEN) 28 | secret = get_secret_value_response['SecretString'] 29 | return secret 30 | 31 | def download_metadata_file(msg_body, token): 32 | model = msg_body['id'] 33 | filename = '' 34 | for file in msg_body['siblings']: 35 | if 'keras_metadata.pb' in file['rfilename']: 36 | filename = file['rfilename'] 37 | logger.info((f'Attempting to download {model}/{filename} from HuggingFace')) 38 | 39 | downloadLoc = Path(f"{MODEL_DIRECTORY}/{model}/{filename}") 40 | downloadLoc.parent.mkdir(parents=True, exist_ok=True) 41 | downloadLink = f"https://huggingface.co/{model}/resolve/main/{filename}" 42 | logger.info((f'TRYING: {downloadLink}')) 43 | 44 | headers = { 45 | 'Authorization': f'Bearer {token}' 46 | } 47 | 48 | try: 49 | response = requests.get(downloadLink, headers=headers) 50 | if response.status_code == 401: 51 | downloadLoc = describe_no_access(downloadLoc) 52 | logger.info((f"Code 401: {downloadLoc}")) 53 | elif response.status_code == 200: 54 | with open(downloadLoc, "wb") as resultFile: 55 | resultFile.write(response.content) 56 | logger.info((f'wrote file to {downloadLoc}')) 57 | 58 | except Exception as e: 59 | with open(f'{downloadLoc}-FAILED', 'w') as failed: 60 | failed.write("COULD NOT DOWNLOAD") 61 | logger.error((e)) 62 | return f'{downloadLoc}-FAILED' 
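    # Sentinel convention used throughout this script: a 401 response (a gated
    # model) yields a path ending in -GATED via describe_no_access(), and a
    # download exception yields one ending in -FAILED; the main loop at the
    # bottom of the file checks str(local_file).endswith('-GATED') before
    # calling check_for_code().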
63 | 64 | return downloadLoc 65 | 66 | def describe_no_access(location): 67 | with open(f'{location}-GATED', 'w') as noAccess: 68 | noAccess.write("CAN'T FETCH") 69 | logger.info(("couldn't access model")) 70 | return f'{location}-GATED' 71 | 72 | def check_for_code(local_file): 73 | metadata = {} 74 | saved_metadata = SavedMetadata() 75 | logger.info((f"******* Checking {local_file} for keras Lambda Layer *********")) 76 | try: 77 | with open(local_file, 'rb') as f: 78 | saved_metadata.ParseFromString(f.read()) 79 | lambda_code = [layer["config"]["function"]["items"][0] 80 | for layer in [json.loads(node.metadata) 81 | for node in saved_metadata.nodes 82 | if node.identifier == "_tf_keras_layer"] 83 | if layer["class_name"] == "Lambda"] 84 | for code in lambda_code: 85 | logger.info((f"found code in {local_file}: {code}")) 86 | logger.info((f"CODE: {code}")) 87 | code = lambda_code[0] 88 | metadata['extracted_encoded_code'] = code 89 | metadata['contains_code'] = True 90 | return metadata 91 | # If we don't find a lambda layer, the above check will give an IndexError that we can assume 92 | # that the model does not contain a Lambda layer 93 | except IndexError as ie: 94 | metadata['contains_code'] = False 95 | logger.info(("didn't find code")) 96 | return metadata 97 | except Exception as e: 98 | logger.info((f'We had a non-index error analyzing {local_file} : {e}')) 99 | return metadata 100 | 101 | def update_dynamo(result): 102 | result['model_type'] = 'protobuf' 103 | model = result["repo"] 104 | 105 | dynamodb = boto3.resource('dynamodb', region_name=AWS_REGION) 106 | table = dynamodb.Table(DYNAMO_STATUS_TABLE) 107 | 108 | response = table.query( 109 | Select='COUNT', 110 | KeyConditionExpression='repo = :repo', 111 | ExpressionAttributeValues={':repo': model} 112 | ) 113 | 114 | if response['Count'] > 0: 115 | #get and preserve prior analysis 116 | old_analysis = table.get_item(Key={'repo': model, 'version': 'v0'}) 117 | version = f"v{response['Count']}" 118 | old_analysis['Item']['version'] = version 119 | table.put_item( 120 | Item = old_analysis['Item'] 121 | ) 122 | 123 | #update table with new data 124 | result['version'] = 'v0' 125 | addition_response = table.delete_item(Key={'repo': model, 'version': 'v0'}) 126 | table.put_item( 127 | Item = result 128 | ) 129 | logger.info((f'Added {result["repo"]} got code {addition_response["ResponseMetadata"]["HTTPStatusCode"]}')) 130 | 131 | else: 132 | logger.info((f'New model {result["repo"]} analyzed, adding to metadata store')) 133 | result['version'] = 'v0' 134 | addition_response = table.put_item( 135 | Item = result 136 | ) 137 | logger.info((f'Added {result["repo"]} got code {addition_response["ResponseMetadata"]["HTTPStatusCode"]}')) 138 | 139 | sqs = boto3.resource('sqs', region_name=AWS_REGION) 140 | bhakti_queue = sqs.get_queue_by_name( 141 | QueueName=SQS_QUEUE 142 | ) 143 | api_token = get_api_token() 144 | 145 | scanning = True 146 | while scanning: 147 | sqs_messages = bhakti_queue.receive_messages( 148 | MaxNumberOfMessages=1, 149 | AttributeNames=["All"], 150 | MessageAttributeNames=["All"], 151 | WaitTimeSeconds=20, 152 | ) 153 | if len(sqs_messages) == 0: 154 | scanning = False 155 | for sqs_message in sqs_messages: 156 | body = sqs_message.body 157 | msg_body = json.loads(body) 158 | logger.info((f'SQS GIVING US {msg_body}')) 159 | model = msg_body['id'] 160 | 161 | local_file = download_metadata_file(msg_body, api_token) 162 | logger.info((local_file)) 163 | sqs_message.delete() 164 | 165 | result = {} 166 | if 
str(local_file).endswith('-GATED'): 167 | logger.info((f'{model} is not publicly available')) 168 | result['private'] = True 169 | else: 170 | result = check_for_code(local_file) 171 | result['repo'] = model 172 | result['modified_date'] = msg_body['lastModified'] 173 | result['keras_filenam'] = msg_body['keras_filename'] 174 | 175 | logger.info((f'RESULTS {result}')) 176 | update_dynamo(result) 177 | 178 | try: 179 | s3 = boto3.client('s3', region_name=AWS_REGION) 180 | s3.put_object(Bucket = LOGGING_BUCKET, Key=f"{int(round(datetime.timestamp(datetime.now())))}-bhakti.log", Body='/var/log/bhakti.log') 181 | except Exception as e: 182 | print('unable to upload log to s3') 183 | 184 | subprocess.call(["shutdown"]) -------------------------------------------------------------------------------- /bhakti-cdk/analysis/monitoring_ec2_check.py: -------------------------------------------------------------------------------- 1 | #!/opt/tensorflow/bin/python3 2 | 3 | import boto3 4 | from botocore.exceptions import ClientError 5 | import json 6 | from pathlib import Path 7 | import logging 8 | import requests 9 | from tensorflow.python.keras.protobuf.saved_metadata_pb2 import SavedMetadata 10 | import subprocess 11 | from datetime import datetime 12 | import os 13 | 14 | SQS_QUEUE = os.getenv('SQS_QUEUE') 15 | AWS_REGION = os.getenv('AWS_REG') 16 | HUGGINGFACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN') 17 | MODEL_DIRECTORY='/tmp/models' 18 | DYNAMO_STATUS_TABLE = os.getenv('DYNAMO_STATUS_TABLE') 19 | LOGGING_BUCKET = os.getenv('LOGGING_BUCKET') 20 | 21 | logger = logging.getLogger() 22 | logging.basicConfig(filename='/var/log/bhakti.log', encoding='utf-8', level=logging.DEBUG) 23 | logger.setLevel(logging.INFO) 24 | 25 | def get_api_token(): 26 | client = boto3.client('secretsmanager', region_name=AWS_REGION) 27 | get_secret_value_response = client.get_secret_value(SecretId=HUGGINGFACE_TOKEN) 28 | secret = get_secret_value_response['SecretString'] 29 | return secret 30 | 31 | def download_metadata_file(msg_body, token): 32 | model = msg_body['id'] 33 | filename = '' 34 | for file in msg_body['siblings']: 35 | if 'keras_metadata.pb' in file['rfilename']: 36 | filename = file['rfilename'] 37 | logger.info((f'Attempting to download {model}/{filename} from HuggingFace')) 38 | 39 | downloadLoc = Path(f"{MODEL_DIRECTORY}/{model}/{filename}") 40 | downloadLoc.parent.mkdir(parents=True, exist_ok=True) 41 | downloadLink = f"https://huggingface.co/{model}/resolve/main/{filename}" 42 | logger.info((f'TRYING: {downloadLink}')) 43 | 44 | headers = { 45 | 'Authorization': f'Bearer {token}' 46 | } 47 | 48 | try: 49 | response = requests.get(downloadLink, headers=headers) 50 | if response.status_code == 401: 51 | downloadLoc = describe_no_access(downloadLoc) 52 | logger.info((f"Code 401: {downloadLoc}")) 53 | elif response.status_code == 200: 54 | with open(downloadLoc, "wb") as resultFile: 55 | resultFile.write(response.content) 56 | logger.info((f'wrote file to {downloadLoc}')) 57 | 58 | except Exception as e: 59 | with open(f'{downloadLoc}-FAILED', 'w') as failed: 60 | failed.write("COULD NOT DOWNLOAD") 61 | logger.error((e)) 62 | return f'{downloadLoc}-FAILED' 63 | 64 | return downloadLoc 65 | 66 | def describe_no_access(location): 67 | with open(f'{location}-GATED', 'w') as noAccess: 68 | noAccess.write("CAN'T FETCH") 69 | logger.info(("couldn't access model")) 70 | return f'{location}-GATED' 71 | 72 | def check_for_code(local_file): 73 | metadata = {} 74 | saved_metadata = SavedMetadata() 75 | 
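    # How the extraction below works: keras_metadata.pb parses into a
    # SavedMetadata protobuf whose nodes each carry a JSON blob in
    # node.metadata. For nodes with identifier "_tf_keras_layer" whose
    # class_name is "Lambda", the layer's serialized (encoded) function body
    # sits at config["function"]["items"][0]; the nested comprehension pulls
    # every such payload out as lambda_code.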
logger.info((f"******* Checking {local_file} for keras Lambda Layer *********")) 76 | try: 77 | with open(local_file, 'rb') as f: 78 | saved_metadata.ParseFromString(f.read()) 79 | lambda_code = [layer["config"]["function"]["items"][0] 80 | for layer in [json.loads(node.metadata) 81 | for node in saved_metadata.nodes 82 | if node.identifier == "_tf_keras_layer"] 83 | if layer["class_name"] == "Lambda"] 84 | for code in lambda_code: 85 | logger.info((f"found code in {local_file}: {code}")) 86 | logger.info((f"CODE: {code}")) 87 | code = lambda_code[0] 88 | metadata['extracted_encoded_code'] = code 89 | metadata['contains_code'] = True 90 | return metadata 91 | # If we don't find a lambda layer, the above check will give an IndexError that we can assume 92 | # that the model does not contain a Lambda layer 93 | except IndexError as ie: 94 | metadata['contains_code'] = False 95 | logger.info(("didn't find code")) 96 | return metadata 97 | except Exception as e: 98 | logger.info((f'We had a non-index error analyzing {local_file} : {e}')) 99 | return metadata 100 | 101 | def update_dynamo(result): 102 | result['model_type'] = 'protobuf' 103 | model = result["repo"] 104 | 105 | dynamodb = boto3.resource('dynamodb', region_name=AWS_REGION) 106 | table = dynamodb.Table(DYNAMO_STATUS_TABLE) 107 | 108 | response = table.query( 109 | Select='COUNT', 110 | KeyConditionExpression='repo = :repo', 111 | ExpressionAttributeValues={':repo': model} 112 | ) 113 | 114 | if response['Count'] > 0: 115 | #get and preserve prior analysis 116 | old_analysis = table.get_item(Key={'repo': model, 'version': 'v0'}) 117 | version = f"v{response['Count']}" 118 | old_analysis['Item']['version'] = version 119 | table.put_item( 120 | Item = old_analysis['Item'] 121 | ) 122 | 123 | #update table with new data 124 | result['version'] = 'v0' 125 | addition_response = table.delete_item(Key={'repo': model, 'version': 'v0'}) 126 | table.put_item( 127 | Item = result 128 | ) 129 | logger.info((f'Added {result["repo"]} got code {addition_response["ResponseMetadata"]["HTTPStatusCode"]}')) 130 | 131 | else: 132 | logger.info((f'New model {result["repo"]} analyzed, adding to metadata store')) 133 | result['version'] = 'v0' 134 | addition_response = table.put_item( 135 | Item = result 136 | ) 137 | logger.info((f'Added {result["repo"]} got code {addition_response["ResponseMetadata"]["HTTPStatusCode"]}')) 138 | 139 | sqs = boto3.resource('sqs', region_name=AWS_REGION) 140 | bhakti_queue = sqs.get_queue_by_name( 141 | QueueName=SQS_QUEUE 142 | ) 143 | api_token = get_api_token() 144 | 145 | scanning = True 146 | while scanning: 147 | sqs_messages = bhakti_queue.receive_messages( 148 | MaxNumberOfMessages=1, 149 | AttributeNames=["All"], 150 | MessageAttributeNames=["All"], 151 | WaitTimeSeconds=20, 152 | ) 153 | if len(sqs_messages) == 0: 154 | scanning = False 155 | for sqs_message in sqs_messages: 156 | body = sqs_message.body 157 | msg_body = json.loads(body) 158 | logger.info((f'SQS GIVING US {msg_body}')) 159 | model = msg_body['id'] 160 | 161 | local_file = download_metadata_file(msg_body, api_token) 162 | logger.info((local_file)) 163 | sqs_message.delete() 164 | 165 | result = {} 166 | if str(local_file).endswith('-GATED'): 167 | logger.info((f'{model} is not publicly available')) 168 | result['private'] = True 169 | else: 170 | result = check_for_code(local_file) 171 | result['repo'] = model 172 | result['modified_date'] = msg_body['lastModified'] 173 | result['keras_filenam'] = msg_body['keras_filename'] 174 | 175 | 
logger.info((f'RESULTS {result}')) 176 | update_dynamo(result) 177 | 178 | try: 179 | s3 = boto3.client('s3', region_name=AWS_REGION) 180 | s3.put_object(Bucket = LOGGING_BUCKET, Key=f"{int(round(datetime.timestamp(datetime.now())))}-bhakti.log", Body='/var/log/bhakti.log') 181 | except Exception as e: 182 | print('unable to upload log to s3') 183 | 184 | subprocess.call(["shutdown"]) -------------------------------------------------------------------------------- /bhakti-cdk/lambda/monitoring_lambda.py: -------------------------------------------------------------------------------- 1 | import json 2 | import boto3 3 | from botocore.exceptions import ClientError 4 | from datetime import datetime 5 | import logging 6 | import requests 7 | import re 8 | import hashlib 9 | import os 10 | 11 | logger = logging.getLogger() 12 | logger.setLevel(logging.INFO) 13 | 14 | 15 | DATE_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" 16 | EC2_AMI = 'ami-0b28c78d9f575dfa1' 17 | DYNAMO_TABLE = os.getenv('DYNAMO_TABLE') 18 | WORKING_QUEUE = os.getenv('WORKING_QUEUE') 19 | HF_TOKEN = os.getenv('HF_TOKEN') 20 | AWS_REGION = os.getenv('AWS_REG') 21 | INSTANCE_PROFILE_ARN = os.getenv('INSTANCE_PROFILE_ARN') 22 | LOGGING_BUCKET = os.getenv('LOGGING_BUCKET') 23 | ANALYSIS_BUCKET = os.getenv('ANALYSIS_BUCKET') 24 | ANALYSIS_PATH = os.getenv('ANALYSIS_PATH') 25 | 26 | def get_user_data(bucket: str) -> str: 27 | user_data = f"""#!/bin/bash 28 | aws s3 cp s3://{bucket} /tmp/analysis/scripts.zip 29 | unzip /tmp/analysis/scripts.zip -d /tmp/analysis 30 | chmod +x /tmp/analysis/monitoring_ec2_check.py 31 | export SQS_QUEUE={WORKING_QUEUE} 32 | export AWS_REG={AWS_REGION} 33 | export HUGGINGFACE_TOKEN={HF_TOKEN} 34 | export DYNAMO_STATUS_TABLE={DYNAMO_TABLE} 35 | export LOGGING_BUCKET={LOGGING_BUCKET} 36 | ./tmp/analysis/monitoring_ec2_check.py""" 37 | return user_data 38 | 39 | def get_api_token(): 40 | secret_name = HF_TOKEN 41 | region_name = AWS_REGION 42 | 43 | session = boto3.session.Session() 44 | client = session.client( 45 | service_name='secretsmanager', 46 | region_name=region_name 47 | ) 48 | 49 | try: 50 | get_secret_value_response = client.get_secret_value( 51 | SecretId=secret_name 52 | ) 53 | 54 | except ClientError as e: 55 | print(e) 56 | raise e 57 | 58 | secret = get_secret_value_response['SecretString'] 59 | return secret 60 | 61 | def callHuggingFace(urlpointer, token): 62 | url = urlpointer 63 | payload = {} 64 | headers = { 65 | 'Authorization': f'Bearer {token}' 66 | } 67 | 68 | response = requests.request("GET", url, headers=headers, data=payload) 69 | return response 70 | 71 | def findKeras(models, modelType): 72 | with open(f'/tmp/kerasFriends-{modelType}.txt', 'a' ) as kerasFriends: 73 | for model in models: 74 | for file in model['siblings']: 75 | if modelType in file['rfilename']: 76 | model['keras_filename'] = file['rfilename'] 77 | kerasFriends.write(json.dumps(model)) 78 | kerasFriends.write("\n") 79 | 80 | def scanPublicModels(url, api_token, modelType): 81 | response = callHuggingFace(url, api_token) 82 | models = response.json() 83 | findKeras(models, modelType) 84 | 85 | try: 86 | nextPage = response.headers['link'] 87 | url = re.search('<(.+?)>', nextPage).group(1) 88 | return url 89 | 90 | except Exception as sslE: 91 | url = 'DONE' 92 | return url 93 | 94 | 95 | def check_if_model_updated(id, lastModified): 96 | latest_modification_date = datetime.strptime(lastModified, DATE_FORMAT) 97 | try: 98 | dynamodb = boto3.resource('dynamodb', region_name=AWS_REGION) 99 | table = 
dynamodb.Table(DYNAMO_TABLE)
100 |         response = table.get_item(
101 |             Key={'repo': id, 'version': 'v0'}
102 |         )
103 | 
104 |         if 'Item' not in response.keys():
105 |             logger.info(f'New model {id} identified, enqueuing for processing')
106 |             return True
107 |         else:
108 |             last_checked = datetime.strptime(response['Item']['modified_date'], DATE_FORMAT)
109 |             if last_checked == latest_modification_date:
110 |                 logger.info(f"We're up to date with analysis for {id}")
111 |                 return False
112 |             elif last_checked < latest_modification_date:
113 |                 logger.info(f"New version detected for {id}!")
114 |                 return True
115 | 
116 |     except Exception as e:
117 |         logging.error(f'We had some trouble with dynamoDB: {e}')
118 | 
119 | 
120 | def send_sqs(message, queue):
121 |     sqs = boto3.resource("sqs")
122 |     sqs_queue = sqs.get_queue_by_name(QueueName=queue)
123 |     groupid = "bhakti_updates"
124 |     deduplicationid = hashlib.md5(
125 |         (
126 |             groupid + json.dumps(message) + datetime.now().strftime("%d%m%Y%H%M%S")
127 |         ).encode("utf-8")
128 |     ).hexdigest()
129 |     response = sqs_queue.send_message(
130 |         MessageBody=message,
131 |         MessageGroupId=groupid,
132 |         MessageDeduplicationId=deduplicationid,
133 |     )
134 |     return response
135 | 
136 | def handler(event, context):
137 |     logger.info("request: {}".format(json.dumps(event)))
138 | 
139 |     api_token = get_api_token()
140 |     url = "https://huggingface.co/api/models/?full=full"
141 |     scanning = True
142 |     new_models = False
143 | 
144 |     try:
145 |         next_url = scanPublicModels(url, api_token, 'keras_metadata.pb')
146 |         while scanning == True:
147 |             if 'huggingface' in next_url:
148 |                 next_url = scanPublicModels(next_url, api_token, 'keras_metadata.pb')
149 |             else:
150 |                 scanning = False
151 | 
152 |         with open('/tmp/kerasFriends-keras_metadata.pb.txt', 'r') as results:
153 |             current_keras_models = []
154 |             for line in results.readlines():
155 |                 current_model = json.loads(line)
156 |                 current_keras_models.append(current_model)
157 | 
158 |         for model in current_keras_models:
159 |             lastModified = model['lastModified']
160 |             logger.info(f'model: {model["modelId"]} last modified on {lastModified}')
161 |             update_needed = check_if_model_updated(model['id'], lastModified)
162 |             if update_needed:
163 |                 new_models = True
164 |                 model['bhakti_request_date'] = datetime.now().strftime(DATE_FORMAT)
165 |                 logger.info(f'send_sqs_message with {model}')
166 |                 if len(model['siblings']) > 100:
167 |                     model['siblings'] = 'too_many_files'
168 |                     continue
169 |                 send_sqs(json.dumps(model), WORKING_QUEUE)
170 | 
171 |         if new_models:
172 |             ec2 = boto3.client('ec2', region_name=AWS_REGION)
173 |             instance = ec2.run_instances(
174 |                 ImageId=EC2_AMI,
175 |                 InstanceType="g4dn.xlarge",
176 |                 UserData=get_user_data(f'{ANALYSIS_BUCKET}/{ANALYSIS_PATH}'),
177 |                 IamInstanceProfile={ 'Arn': INSTANCE_PROFILE_ARN },
178 |                 InstanceInitiatedShutdownBehavior='terminate',
179 |                 KeyName='bhakti-ssh-key',
180 |                 MinCount=1,
181 |                 MaxCount=1
182 |             )
183 | 
184 |             instance_data = {
185 |                 'status_code': instance['ResponseMetadata']['HTTPStatusCode']
186 |             }
187 | 
188 |             if instance['ResponseMetadata']['HTTPStatusCode'] == 200:
189 |                 logger.info('Started an EC2 instance for analysis...')
190 |                 instance_data['instance_id'] = instance['Instances'][0]['InstanceId']
191 |             else:
192 |                 logger.error('EC2 instance failed to launch')
193 |                 instance_data['instance_id'] = 'N/A, FAILED'
194 | 
195 |             logger.info(instance_data)
196 | 
197 |     except Exception as e:
198 |         logger.error(e)
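A note on the pagination above: `scanPublicModels` walks HuggingFace's model listing by pulling the next page URL out of the RFC 5988 style `link` response header. A minimal, standalone sketch of that parsing step (the helper name `next_page_url` is hypothetical; the regex is the one the lambda uses):

```python
import re

def next_page_url(link_header: str) -> str:
    """Return the next-page URL from a 'link' response header, or 'DONE'."""
    match = re.search('<(.+?)>', link_header)
    return match.group(1) if match else 'DONE'

# e.g. next_page_url('<https://huggingface.co/api/models?cursor=abc>; rel="next"')
# returns 'https://huggingface.co/api/models?cursor=abc'
```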
--------------------------------------------------------------------------------
/bhakti-cdk/bhakti_cdk/bhakti_monitoring_stack.py:
--------------------------------------------------------------------------------
1 | from aws_cdk import (
2 |     aws_s3_assets as assets,
3 |     Stack,
4 |     aws_secretsmanager,
5 |     aws_logs as logs,
6 |     aws_dynamodb,
7 |     aws_s3 as s3,
8 |     aws_sqs as sqs,
9 |     aws_iam as iam,
10 |     Duration,
11 |     aws_lambda,
12 |     aws_events,
13 |     aws_events_targets,
14 | )
15 | from constructs import Construct
16 | 
17 | class MonitoringStack(Stack):
18 | 
19 |     def __init__(self, scope: Construct, construct_id: str, hf_token: aws_secretsmanager.Secret, script_asset: assets.Asset, **kwargs) -> None:
20 |         super().__init__(scope, construct_id, **kwargs)
21 |         huggingface_token = hf_token
22 |         bhakti_log_group = logs.LogGroup(self, 'bhakti_logs')
23 | 
24 |         status_table = aws_dynamodb.TableV2(self, 'status_table',
25 |             partition_key=aws_dynamodb.Attribute(
26 |                 name='repo',
27 |                 type=aws_dynamodb.AttributeType.STRING
28 |             ),
29 |             sort_key=aws_dynamodb.Attribute(
30 |                 name='version',
31 |                 type=aws_dynamodb.AttributeType.STRING
32 |             )
33 |         )
34 |         status_table.add_global_secondary_index(
35 |             index_name="models_with_code",
36 |             partition_key=aws_dynamodb.Attribute(
37 |                 name='extracted_encoded_code',
38 |                 type=aws_dynamodb.AttributeType.STRING))
39 | 
40 |         bhakti_analysis_bucket = s3.Bucket(
41 |             self, 'bhakti_analysis_bucket',
42 |             block_public_access=s3.BlockPublicAccess.BLOCK_ALL,
43 |         )
44 | 
45 |         monitoring_queue = sqs.Queue(
46 |             self,
47 |             "monitoring_queue",
48 |             queue_name="bhakti_monitoring_queue.fifo",
49 |             visibility_timeout=Duration.seconds(660),
50 |             fifo=True,
51 |         )
52 | 
53 |         bhakti_automated_role = iam.Role(
54 |             self, "bhakti_automated_role",
55 |             assumed_by=iam.ServicePrincipal("ec2.amazonaws.com"),
56 |             description="EC2 instance role for automated Bhakti analysis instances"
57 |         )
58 | 
59 |         monitoring_execution = iam.PolicyDocument(
60 |             statements=[
61 |                 iam.PolicyStatement(
62 |                     actions=["dynamodb:GetItem", "dynamodb:Query", "dynamodb:DeleteItem", "dynamodb:PutItem"],
63 |                     resources=[status_table.table_arn],
64 |                 ),
65 |                 iam.PolicyStatement(
66 |                     actions=["sqs:SendMessage", "sqs:GetQueueUrl"],
67 |                     resources=[monitoring_queue.queue_arn],
68 |                 ),
69 |                 iam.PolicyStatement(
70 |                     actions=["secretsmanager:GetSecretValue", "secretsmanager:DescribeSecret"],
71 |                     resources=[huggingface_token.secret_arn],
72 |                 ),
73 |                 iam.PolicyStatement(
74 |                     actions=["s3:PutObject"],
75 |                     resources=[f"{bhakti_analysis_bucket.bucket_arn}/*"],
76 |                 ),
77 |                 iam.PolicyStatement(
78 |                     actions=["iam:PassRole"],
79 |                     resources=[bhakti_automated_role.role_arn],
80 |                 ),
81 |                 iam.PolicyStatement(
82 |                     actions=["ec2:RunInstances", "ec2:CreateTags"],
83 |                     resources=[
84 |                         f"arn:aws:ec2:{self.region}:{self.account}:instance/*",
85 |                         f"arn:aws:ec2:{self.region}:{self.account}:image/ami-0b28c78d9f575dfa1",
86 |                         f"arn:aws:ec2:{self.region}:{self.account}:network-interface/*",
87 |                         f"arn:aws:ec2:{self.region}:{self.account}:security-group/*",
88 |                         f"arn:aws:ec2:{self.region}:{self.account}:subnet/subnet-*",
89 |                         f"arn:aws:ec2:{self.region}:{self.account}:volume/*",
90 |                         f"arn:aws:ec2:{self.region}::image/ami-0b28c78d9f575dfa1",
91 |                     ],
92 |                 ),
93 |             ]
94 |         )
95 |         monitoring_execution_policy = iam.Policy(
96 |             self, "monitoring_execution_policy", document=monitoring_execution
97 |         )
98 | 
99 |         asset_bucket = s3.Bucket.from_bucket_name(self, 'script_bucket', script_asset.s3_bucket_name)
100 | 
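        # Two distinct principals get policies in this stack: monitoring_execution
        # (above) is attached to the monitoring Lambda's execution role further
        # down, while the bhakti_analysis policy document below is attached to
        # bhakti_automated_role, the role assumed by the short-lived EC2 analysis
        # instances that the Lambda launches.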
101 |         bhakti_analysis_policy_statement = iam.PolicyDocument(
102 |             statements=[
103 |                 iam.PolicyStatement(
104 |                     actions=["dynamodb:GetItem", "dynamodb:Query", "dynamodb:DeleteItem", "dynamodb:PutItem"],
105 |                     resources=[status_table.table_arn],
106 |                 ),
107 |                 iam.PolicyStatement(
108 |                     actions=["s3:GetObject"],
109 |                     resources=[f"{asset_bucket.bucket_arn}/*"]
110 |                 ),
111 |                 iam.PolicyStatement(
112 |                     actions=["s3:PutObject"],
113 |                     resources=[f"{bhakti_analysis_bucket.bucket_arn}/*"]
114 |                 ),
115 |                 iam.PolicyStatement(
116 |                     actions=["secretsmanager:GetSecretValue", "secretsmanager:DescribeSecret"],
117 |                     resources=[hf_token.secret_arn]
118 |                 ),
119 |                 iam.PolicyStatement(
120 |                     actions=["sqs:GetQueueUrl", "sqs:ReceiveMessage", "sqs:DeleteMessage"],
121 |                     resources=[monitoring_queue.queue_arn]
122 |                 )
123 |             ]
124 |         )
125 | 
126 |         bhakti_analysis_policy = iam.Policy(self, "bhakti_analysis_policy", document=bhakti_analysis_policy_statement)
127 | 
128 |         bhakti_automated_role.attach_inline_policy(bhakti_analysis_policy)
129 | 
130 |         bhakti_instance_profile = iam.CfnInstanceProfile(
131 |             self, "bhakti_automated_instance_profile",
132 |             roles=[bhakti_automated_role.role_name]
133 |         )
134 | 
135 |         monitoring_lambda = aws_lambda.Function(
136 |             self,
137 |             "monitoring_lambda",
138 |             code=aws_lambda.Code.from_asset(
139 |                 "lambda",
140 |                 bundling={
141 |                     "image":aws_lambda.Runtime.PYTHON_3_12.bundling_image,
142 |                     "command": [
143 |                         'bash','-c',
144 |                         'pip install -r requirements.txt -t /asset-output && cp -au . /asset-output'
145 |                     ],
146 |                 },
147 |             ),
148 |             handler="monitoring_lambda.handler",
149 |             timeout=Duration.seconds(900),
150 |             runtime=aws_lambda.Runtime.PYTHON_3_12,
151 |             memory_size=3072,
152 |             log_group=bhakti_log_group,
153 |             environment={
154 |                 'DYNAMO_TABLE' : status_table.table_name,
155 |                 'WORKING_QUEUE' : monitoring_queue.queue_name,
156 |                 'HF_TOKEN' : huggingface_token.secret_name,
157 |                 'AWS_REG' : self.region,
158 |                 'INSTANCE_PROFILE_ARN' : bhakti_instance_profile.attr_arn,
159 |                 'LOGGING_BUCKET' : bhakti_analysis_bucket.bucket_name,
160 |                 'ANALYSIS_BUCKET' : script_asset.s3_bucket_name,
161 |                 'ANALYSIS_PATH' : script_asset.s3_object_key,
162 |             }
163 |         )
164 | 
165 |         monitoring_lambda.role.attach_inline_policy(monitoring_execution_policy)
166 |         script_asset.grant_read(bhakti_automated_role)
167 | 
168 |         keras_monitoring_event_rule = aws_events.Rule(
169 |             self,
170 |             "keras_monitoring_event_rule",
171 |             schedule=aws_events.Schedule.cron(
172 |                 hour = "1",
173 |                 minute = "0",
174 |             )
175 |         )
176 | 
177 |         keras_monitoring_event_rule.add_target(aws_events_targets.LambdaFunction(monitoring_lambda))
178 | 
179 | 
--------------------------------------------------------------------------------
/bhakti-cdk/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # Welcome to Bhakti CDK!!
3 | 
4 | DBX Threat Intelligence 🤝 Infrastructure as Code
5 | 
6 | The CDK contained here is intended to help folks get started using AWS to look at ML models. There are two main ways to deploy bhakti's cdk-- for monitoring or for a launch template.
7 | 
8 | ## Pre-Requisites
9 | 
10 | Deploying this stack requires that you have an AWS account that's been bootstrapped for CDK (see [Bootstrapping](https://docs.aws.amazon.com/cdk/v2/guide/bootstrapping.html)) and an IAM identity with adequate permissions that you can run the cdk commands from. Whatever you've used to bootstrap your account should be able to deploy your stack.
11 | 
12 | Don't have cdk?
Check out [Getting Started with CDK](https://docs.aws.amazon.com/cdk/v2/guide/getting_started.html) for more info.
13 | 
14 | If you're planning to deploy the monitoring stack, you'll need to have Docker running on your machine (it bundles the requests library into the Lambda using a Docker container).
15 | 
16 | If you're using a f r e s h AWS account, you will probably have to request that AWS grant you a quota for running G class instances. You can do that using the [service quotas dashboard](https://console.aws.amazon.com/servicequotas/home/services/ec2/quotas#) and looking for "Running On-Demand G and VT instances." One `G4DN.XLARGE` has 4 vCPUs (which is what the quota restricts).
17 | 
18 | ## Launch Template Stack
19 | 
20 | The [Launch Template Stack](bhakti_cdk/bhakti_instance_profiles.py) will create an ec2 launch template for you that you can then use to run instances in your account for manual analysis.
21 | - It's got an instance role attached that can get a secret from secrets manager (for your huggingface api key)
22 | - It copies over the contents of the [analysis directory](analysis) to the ec2 instance under the `/home/ec2-user/analysis` directory
23 | - If you launch the stack with a security group id (`sg-{id}`) in the context, it will attach that security group to the ec2 instance by default. For more about security groups, see [this documentation](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-security-groups.html). Basically, you'll need at least a security group allowing SSH access from your IP or a CIDR including your IP to be able to ssh to the ec2 instance. If you don't specify one, you can always attach one either when you choose to launch an instance or even after it's running.
24 | - The keypair for the analysis instances will default to one the cdk creates and stores in ssm.
25 | - This stack does *not* start any instance for you, it simply creates an [EC2 Launch template](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-launch-templates.html) for manual instance usage.
26 | 
27 | ### ✨ **Deployment** ✨
28 | 
29 | To stand up this stack, run the following commands:
30 | 
31 | ```
32 | $ cdk synth --context deploy_type=instance_profile --context deploy_region={region} --context deploy_account={aws_account_id} --context sg_id={your_security_group_to_attach}
33 | $ cdk deploy --context deploy_type=instance_profile --context deploy_region={region} --context deploy_account={aws_account_id} --context sg_id={your_security_group_to_attach} --all
34 | ```
35 | 
36 | You can also store these context variables in the [`cdk.context.json`](cdk.context.json) file in json format.
37 | 
38 | 
39 | ### ✍️ **Post-Deployment** ✍️
40 | 
41 | You'll need to do at least two things post-deployment:
42 | 1. Add your huggingface api key to the blank secret we created in this cdk stack. It'll be under the name "huggingface_api_token" unless you've modified it.
43 | ```
44 | $ aws secretsmanager put-secret-value --secret-id 'huggingface_api_token' --secret-string "myhfapitoken"
45 | {
46 |     "ARN": "arn:aws:secretsmanager:{region}:{account_id}:secret:huggingface_api_token-{version-id}",
47 |     "Name": "huggingface_api_token",
48 |     "VersionId": "f45a72e9-5e1f-4147-b81f-ac48f2c8e4d6",
49 |     "VersionStages": [
50 |         "AWSCURRENT"
51 |     ]
52 | }
53 | ```
54 | 2. Grab the private key for your ssh keypair from the ssm parameter store:
55 | ```
56 | $ aws ec2 describe-key-pairs | grep -B 2 -A 3 'bhakti-ssh-key'
57 |         "KeyPairId": "key-$guid",
58 |         "KeyFingerprint": "4b:2d:02:1b:bb:ec:6b:e0:30:32:39:66:68:5d:ca:68:e0:0e:16:d8",
59 |         "KeyName": "bhakti-ssh-key",
60 |         "KeyType": "rsa",
61 |         "Tags": [],
62 |         "CreateTime": "2024-04-12T22:30:08.693000+00:00"
63 | 
64 | $ aws ssm describe-parameters --parameter-filters Key=Name,Values='/ec2/keypair/key-06792029fba76ba97'
65 | {
66 |     "Parameters": [
67 |         {
68 |             "Name": "/ec2/keypair/key-$guid",
69 |             "ARN": "arn:aws:ssm:us-west-2:{account_id}:parameter/ec2/keypair/key-$guid",
70 |             "Type": "SecureString",
71 |             "KeyId": "alias/aws/ssm",
72 |             "LastModifiedDate": "2024-04-12T15:30:08.779000-07:00",
73 |             "LastModifiedUser": "{arn of user identity}",
74 |             "Version": 1,
75 |             "Tier": "Standard",
76 |             "Policies": [],
77 |             "DataType": "text"
78 |         }
79 |     ]
80 | }
81 | 
82 | $ aws ssm get-parameters --name='/ec2/keypair/key-$guid'
83 | {
84 |     "Parameters": [
85 |         {
86 |             "Name": "/ec2/keypair/key-$guid",
87 |             "Type": "SecureString",
88 |             "Value": "{ base64 encoded private key }",
89 |             "Version": 1,
90 |             "LastModifiedDate": "2024-04-12T15:30:08.779000-07:00",
91 |             "ARN": "{ arn of key }",
92 |             "DataType": "text"
93 |         }
94 |     ],
95 |     "InvalidParameters": []
96 | }
97 | 
98 | ```
99 | 
100 | ## Monitoring Stack
101 | 
102 | The [Monitoring Stack](bhakti_cdk/bhakti_monitoring_stack.py) attempts to stand up a little automation service that will let you monitor huggingface each day for new models to assess. By default, it looks at `keras_metadata.pb` files, runs my stock analysis script over those files, and stores results in a DynamoDB table. If you want to change what's done, follow your heart and modify the script that runs when a new model is found: [monitoring_ec2_check.py](analysis/monitoring_ec2_check.py).
103 | 
104 | Here's the architecture of what the cdk will stand up in your account:
105 | 
106 | ![architecture diagram](../media/Bhakti.png)
107 | 
108 | ### 💸 *Small caution regarding billing* 💸
109 | If you're thinking about putting this in a personal account, please be advised that it will want to look at > 3.5k model metadata files to start out. It's not a huge amount (they're so tiny), but the ML instance used to assess them is a `G4DN.XLARGE`, which costs ~$0.52 an hour to run (with no discounts) based on current pricing. To just write model candidates to SQS and look them over more manually (maybe run yara on them and find the ones you care about?!), you could simply change the conditional check on [line 171 of the lambda](lambda/monitoring_lambda.py#L171) to never evaluate to true.
110 | 
111 | ### ✨ **Deployment** ✨
112 | ```
113 | $ cdk synth --context deploy_type=monitoring --context deploy_region={region} --context deploy_account={aws_account_id}
114 | $ cdk deploy --context deploy_type=monitoring --context deploy_region={region} --context deploy_account={aws_account_id} --all
115 | ```
116 | You can also store these context variables in the [`cdk.context.json`](cdk.context.json) file in json format.
117 | 
118 | ### ✍️ **Post-Deployment** ✍️
119 | 1. You'll need to add your huggingface api key to the secret we created when we stood up this stack, if you haven't done so already. It'll be under "huggingface_api_token" unless you've altered it.
120 | ``` 121 | $ aws secretsmanager put-secret-value --secret-id 'huggingface_api_token' --secret-string "myhfapitoken" 122 | { 123 | "ARN": "arn:aws:secretsmanager:{region}:{account_id}:secret:huggingface_api_token-{version-id}", 124 | "Name": "huggingface_api_token", 125 | "VersionId": "f45a72e9-5e1f-4147-b81f-ac48f2c8e4d6", 126 | "VersionStages": [ 127 | "AWSCURRENT" 128 | ] 129 | } 130 | ``` 131 | 132 | ### **AWS Services Used** 133 | - AWS Lambda to identify f r e s h models once per day 134 | - AWS Events to trigger the Lambda to run once per day 135 | - AWS SQS stores json events correlating to models we find with lambda and want to assess 136 | - AWS DynamoDB has a table with a complete record of our models we've assessed (we update it from Lambda and EC2) 137 | - AWS EC2 functions as our compute to analyze updated models. The instance is started by Lambda and shuts itself down when its launch script completes work (we are afraid of EC2 bills). 138 | - AWS S3 Stores analysis scripts to load onto EC2 instances and logs from EC2 work. 139 | - CloudWatch Logging for the Lambda 140 | 141 | ## Useful commands 142 | 143 | * `cdk ls` list all stacks in the app 144 | * `cdk synth` emits the synthesized CloudFormation template 145 | * `cdk deploy` deploy this stack to your default AWS account/region 146 | * `cdk diff` compare deployed stack with current state 147 | * `cdk docs` open CDK documentation 148 | 149 | Enjoy! 150 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
--------------------------------------------------------------------------------
/analysis/checkModel.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | import logging
4 | import requests
5 | from tensorflow.python.keras.protobuf.saved_metadata_pb2 import SavedMetadata
6 | from optparse import OptionParser
7 | from datetime import datetime
8 | import os
9 | import dis
10 | import codecs
11 | import marshal
12 | import base64
13 | import string
14 | import sys
15 | import h5py
16 | import shutil
17 | from typing import Union, Dict, Any
18 | from collections.abc import Generator
19 |
20 | # output config
21 | logger = logging.getLogger()
22 | logger.setLevel(logging.INFO)
23 | handler = logging.StreamHandler(sys.stdout)
24 | handler.setLevel(logging.INFO)
25 | formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")
26 | handler.setFormatter(formatter)
27 | logger.addHandler(handler)
28 |
29 |
30 | def gather_file(remote_model: str, api_token: str, directory: str) -> Union[Path, str]:
31 |     """Attempts to assess a repo on huggingface and download any h5 or keras_metadata.pb
32 |     files found within it. Returns either an error/empty string or a Path object.
33 |     """
34 |     filename = ""
35 |     pb_filename = ""
36 |     h5_filename = ""
37 |     api_token = api_token
38 |     url = f"https://huggingface.co/api/models/?id={remote_model}&full=full"
39 |
40 |     headers = {"Authorization": f"Bearer {api_token}"}
41 |
42 |     response = requests.get(url, headers=headers)
43 |     hf_model = response.json()
44 |
45 |     for file in hf_model[0]["siblings"]:
46 |         if "keras_metadata.pb" in file["rfilename"]:
47 |             pb_filename = file["rfilename"]
48 |         elif file["rfilename"].endswith(".h5"):
49 |             h5_filename = file["rfilename"]
50 |
51 |     if pb_filename:
52 |         filename = pb_filename
53 |     elif h5_filename:
54 |         filename = h5_filename
55 |
56 |     if filename:
57 |         downloadLoc = Path(f"{directory}/{remote_model}/{filename}")
58 |         downloadLoc.parent.mkdir(parents=True, exist_ok=True)
59 |         downloadLink = f"https://huggingface.co/{remote_model}/resolve/main/{filename}"
60 |         logger.info(f"Attempting to download: {downloadLink}")
61 |         try:
62 |             with requests.get(downloadLink, headers=headers, stream=True) as r:
63 |                 if r.status_code == 401:
64 |                     logger.error(
65 |                         f"!!! Unfortunately, we're not authorized to retrieve {remote_model}"
66 |                     )
67 |                     downloadLoc = "UNAUTHORIZED"
68 |                 elif r.status_code == 200:
69 |                     with open(downloadLoc, "wb") as resultFile:
70 |                         shutil.copyfileobj(r.raw, resultFile)
71 |                     logger.info(f"Wrote file to {downloadLoc}")
72 |
73 |         except Exception as e:
74 |             logger.error(
75 |                 "!!! There was an issue downloading the file from huggingface! "
76 |             )
77 |             logger.error(e)
78 |
79 |         return downloadLoc
80 |
81 |     else:
82 |         logger.info("Couldn't find a keras metadata file for this repo!")
83 |         return ""
84 |
85 | def check_pb_for_code(local_file: Path, id: str) -> Dict[str, Any]:
86 |     """Looks for the presence of a lambda layer within a keras_metadata.pb metadata file.
87 |     If a layer is found, attempts to pull out the embedded code. Returns a dictionary
88 |     describing the model assessed.
89 |     """
90 |     metadata = {"id": id, "type": "pb"}
91 |     saved_metadata = SavedMetadata()
92 |     logger.info(f"Checking {local_file} for keras lambda layer")
93 |     try:
94 |         with open(local_file, "rb") as f:
95 |             saved_metadata.ParseFromString(f.read())
96 |         lambda_code = [
97 |             layer["config"]["function"]["items"][0]
98 |             for layer in [
99 |                 json.loads(node.metadata)
100 |                 for node in saved_metadata.nodes
101 |                 if node.identifier == "_tf_keras_layer"
102 |             ]
103 |             if layer["class_name"] == "Lambda"
104 |         ]
105 |         for code in lambda_code:
106 |             logger.info(f"Found code in {local_file}: ")
107 |             logger.info(f"CODE: {code}")
108 |         code = lambda_code[0]  # keep the first layer's payload for the record
109 |         metadata["extracted_encoded_code"] = code
110 |         metadata["contains_code"] = True
111 |         return metadata
112 |     # If we don't find a lambda layer, the lookup above raises an IndexError,
113 |     # which we take to mean the model does not contain a Lambda layer
114 |     except IndexError:
115 |         metadata["contains_code"] = False
116 |         logger.info(f"Didn't find code in {local_file}")
117 |         return metadata
118 |     except Exception as e:
119 |         logger.info(f"We had a non-index error analyzing {local_file} : {e}")
120 |         return metadata
121 |
122 |
123 | def check_h5_for_code(local_file: str, id: str) -> Dict[str, Any]:
124 |     """Looks for the presence of a lambda layer within an h5 model file.
125 |     If a layer is found, attempts to pull out the embedded code. Only works
126 |     for Keras Tensorflow models saved using .save().
127 |     Returns a dictionary describing the model assessed.
128 |     """
129 |     metadata = {"id": id, "type": "h5"}
130 |     logger.info(f"********* Checking {local_file} for keras lambda layer *********")
131 |     try:
132 |         with h5py.File(local_file, "r") as f:
133 |             # models saved with .save() carry a "model_config" attribute. Keras documentation
134 |             # encourages this saving method, and it is the most consistent way to embed serialized code
135 |             if "model_config" in list(f.attrs.keys()):
136 |                 try:
137 |                     lambda_code = [
138 |                         layer.get("config", {}).get("function", {})
139 |                         for layer in json.loads(f.attrs["model_config"])["config"][
140 |                             "layers"
141 |                         ]
142 |                         if layer["class_name"] == "Lambda"
143 |                     ]
144 |                     code = lambda_code[0][0]
145 |                     logger.info(f"Found code in {local_file}: ")
146 |                     logger.info(f"CODE: {code}")
147 |                     metadata["contains_code"] = True
148 |                     metadata["extracted_encoded_code"] = code
149 |                     return metadata
150 |                 except IndexError:
151 |                     logger.info(f"Didn't find code in {local_file}")
152 |                     metadata["contains_code"] = False
153 |                     return metadata
154 |             else:
155 |                 metadata["contains_code"] = False
156 |                 logger.info(
157 |                     f"!!! Unfortunately, {local_file} was not saved with an extractable model config"
158 |                 )
159 |                 return metadata
160 |     except KeyError as ke:
161 |         logger.info(
162 |             f"!!! Unfortunately, {local_file} was not saved in a way for easy config extraction {ke}"
163 |         )
164 |         return metadata
165 |     except Exception as e:
166 |         logger.error(f"!!! We had a non-index error analyzing {local_file} : {e}")
167 |         return metadata
168 |
169 |
170 | def strings(encoded_code: bytes, min_len: int = 4) -> Generator[str, None, None]:
171 |     """
172 |     Attempts to find printable strings >= min_len characters long (4 by default).
173 |     Approximates the Unix strings capability, but a lot more brittle.
174 |     """
175 |     # latin-1 maps every one of the 256 byte values, so this decode cannot
176 |     # fail; it simply lets us walk the blob as ordinary characters
177 |     encoded_code = encoded_code.decode("latin1")
178 |
179 |     result = ""
180 |     for c in encoded_code:
181 |         if c in string.printable:
182 |             result += c
183 |             continue
184 |         if len(result) >= min_len:
185 |             yield result
186 |         result = ""
187 |     if len(result) >= min_len:
188 |         yield result
189 |
190 |
191 | def main():
192 |     class BhaktiParser(OptionParser):
193 |         def format_epilog(self, formatter):
194 |             return self.epilog
195 |
196 |     usage = "usage: %prog -m author/model -r '/local/results/file' -a 'hf_api_key'"
197 |     epilog = """Information:
198 |     - Either a huggingface repo or a local model file is required.
199 |     - Local files should be either Tensorflow models using keras saved in .h5 or keras_metadata.pb metadata files.
200 |     - Unusual huggingface repo structures might behave oddly.
201 |     - Not specifying a results file will result in results being written to stdout.
202 |     - Requesting a huggingface model without specifying a directory will write the file to the working directory.
203 |
204 |     Examples:
205 |     checkModel.py -m 'author/model' -r '/path/to/local/results/file' -d '/path/to/download/models' -a 'hugging_face_api_key' -c 'True'
206 |     checkModel.py -f '/path/to/local/model' -r '/path/to/local/results/file'"""
207 |     parser = BhaktiParser(usage=usage, epilog=epilog)
208 |     parser.add_option(
209 |         "-m",
210 |         "--model",
211 |         dest="remote_model",
212 |         help="huggingface repo to assess",
213 |         metavar="author/repo",
214 |     )
215 |     parser.add_option(
216 |         "-f",
217 |         "--file",
218 |         dest="local_model",
219 |         help="local model file to assess",
220 |         metavar="/path/to/model",
221 |     )
222 |     parser.add_option(
223 |         "-r",
224 |         "--results_file",
225 |         dest="results_file",
226 |         help="flat file to write results, otherwise results are printed to stdout",
227 |         metavar="/path/to/file",
228 |     )
229 |     parser.add_option(
230 |         "-d",
231 |         "--dir",
232 |         dest="dir",
233 |         metavar="/path/to/working/dir",
234 |         help="local directory to store models downloaded from huggingface",
235 |     )
236 |     parser.add_option(
237 |         "-a",
238 |         "--api_key",
239 |         dest="hf_api_key",
240 |         metavar="hf_{...}",
241 |         help="api token to use to interact with huggingface",
242 |     )
243 |     parser.add_option(
244 |         "-c",
245 |         "--clean_up",
246 |         dest="clean_up",
247 |         metavar="False",
248 |         help="Set to true if you want to delete models that are downloaded",
249 |         default="False",
250 |     )
251 |
252 |     (options, args) = parser.parse_args()
253 |
254 |     if options.remote_model and options.local_model:
255 |         parser.error("specify either a local file or remote repo, but not both :)")
256 |
257 |     if not options.remote_model and not options.local_model:
258 |         parser.error(
259 |             "Please specify at least one model to analyze using either [-m|--model] (remote) or [-f|--file] (local)"
260 |         )
261 |
262 |     if options.remote_model and not options.dir:
263 |         logger.info(
264 |             "No download directory specified, fetching remote model to working directory..."
265 |         )
266 |
267 |     if not options.hf_api_key and options.remote_model:
268 |         logger.info(
269 |             "No api key provided but requesting model, trying to download without authorization"
270 |         )
271 |         hf_api_key = ""
272 |     elif options.hf_api_key:
273 |         hf_api_key = options.hf_api_key
274 |
275 |     results = {}  # stays empty if no assessable file turns up below
276 |     if options.local_model:
277 |         local_model = options.local_model
278 |         if not local_model.endswith(".h5") and local_model.endswith(".pb"):
279 |             results = check_pb_for_code(local_model, local_model)
280 |         elif local_model.endswith(".h5"):
281 |             results = check_h5_for_code(local_model, local_model)
282 |
283 |     elif options.remote_model:
284 |         remote_model = options.remote_model
285 |         api_token = hf_api_key
286 |         if options.dir:
287 |             directory = options.dir
288 |         else:
289 |             directory = "."
290 |         downloaded_file = gather_file(remote_model, api_token, directory)
291 |         file_path = str(downloaded_file)
292 |         if downloaded_file != "UNAUTHORIZED":
293 |             if not file_path.endswith(".h5") and file_path.endswith(".pb"):
294 |                 results = check_pb_for_code(downloaded_file, remote_model)
295 |             elif file_path.endswith(".h5"):
296 |                 results = check_h5_for_code(downloaded_file, remote_model)
297 |
298 |     if code := results.get("extracted_encoded_code"):
299 |         logger.info(
300 |             f"********* Trying to disassemble extracted code layer in {results['id']}: *********"
301 |         )
302 |         try:
303 |             # lambda layers embed a base64'd, marshalled python code object
304 |             dis.dis(marshal.loads(codecs.decode(code.encode("ascii"), "base64")))
305 |         except Exception as e:
306 |             logger.error(f"!!! Unfortunately, dis struggled with {results['id']}: {e}")
307 |         logger.info(
308 |             f"********* Attempting to find strings for {results['id']}: *********"
309 |         )
310 |         decoded_code = base64.b64decode(code)
311 |         sl = list(strings(decoded_code))
312 |         if len(sl) > 0:
313 |             results["string_list"] = sl
314 |             logger.info(f"Found strings in {results['id']}:")
315 |             logger.info(f"STRINGS: {sl}")
316 |         else:
317 |             logger.info(f"Could not find any printable strings in {results['id']}!")
318 |
319 |     if options.results_file:
320 |         results_file = Path(options.results_file)
321 |         results_file.parent.mkdir(parents=True, exist_ok=True)
322 |         with open(results_file, "a") as f:
323 |             f.write(json.dumps(results))
324 |             f.write("\n")
325 |     else:
326 |         logger.info(
327 |             "********* No result file specified, printing results to stdout: *********"
328 |         )
329 |         logger.info(results)
330 |
331 |     clean_up = options.clean_up
332 |     # only clean up when a remote download actually produced a file on disk
333 |     if clean_up.lower() in ["true", "1"] and options.remote_model and isinstance(downloaded_file, Path):
334 |         os.remove(downloaded_file)
335 |         parent_dir = remote_model.split("/")[0]
336 |         if options.dir:
337 |             directory = options.dir.rstrip("/")
338 |             os.rmdir(f"{directory}/{remote_model}")
339 |             os.rmdir(f"{directory}/{parent_dir}")
340 |         else:
341 |             os.rmdir(remote_model)
342 |             os.rmdir(parent_dir)
343 |
344 |
345 | if __name__ == "__main__":
346 |     main()
--------------------------------------------------------------------------------