├── NOTICE.txt ├── src ├── main │ ├── scala │ │ ├── reinvent │ │ │ └── securityanalytics │ │ │ │ ├── utilities │ │ │ │ ├── GeoIPException.scala │ │ │ │ ├── TorExitNodeList.scala │ │ │ │ ├── ObjectMapperSingleton.scala │ │ │ │ ├── S3Singleton.scala │ │ │ │ ├── SQSSingleton.scala │ │ │ │ ├── SNSSingleton.scala │ │ │ │ ├── GeoIPDBReaderSingleton.scala │ │ │ │ ├── Configuration.scala │ │ │ │ └── CloudTrailS3Utilities.scala │ │ │ │ ├── state │ │ │ │ ├── Field.scala │ │ │ │ ├── ActivityProfile.scala │ │ │ │ └── FieldProfileState.scala │ │ │ │ ├── profilers │ │ │ │ ├── GenericCloudTrailProfiler.scala │ │ │ │ ├── AccessKeyIDProfiler.scala │ │ │ │ ├── GeoIPCityProfiler.scala │ │ │ │ ├── GeoIPCountryProfiler.scala │ │ │ │ ├── TorProfiler.scala │ │ │ │ └── ActivityProfiler.scala │ │ │ │ ├── receivers │ │ │ │ ├── ReplayExistingCloudTrailEventsReceiver.scala │ │ │ │ ├── GenericCloudTrailEventsReceiver.scala │ │ │ │ └── NewCloudTrailEventsReceiver.scala │ │ │ │ ├── GeoIPLookup.scala │ │ │ │ ├── TorExitLookup.scala │ │ │ │ └── CloudTrailProfileAnalyzer.scala │ │ ├── S3LogsToSQL.scala │ │ └── CloudTrailToSQL.scala │ └── resources │ │ ├── sparkShell.sh │ │ ├── startStreaming.sh │ │ └── config │ │ └── emptyConfig.properties └── test │ └── scala │ └── reinvent │ └── securityanalytics │ └── receivers │ └── NewCloudTrailEventsReceiverTest.scala ├── LICENSE.txt ├── README.md └── pom.xml /NOTICE.txt: -------------------------------------------------------------------------------- 1 | Timely Security Analytics 2 | Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved. -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/utilities/GeoIPException.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.utilities 2 | 3 | class GeoIPException(message:String, underlying:Throwable) extends Exception(message, underlying) with Serializable -------------------------------------------------------------------------------- /src/main/resources/sparkShell.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | spark-shell \ 3 | --master yarn-client \ 4 | --num-executors 40 \ 5 | --conf spark.executor.cores=2 \ 6 | --jars /home/hadoop/cloudtrailanalysisdemo-1.0-SNAPSHOT-jar-with-dependencies.jar 7 | 8 | -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/utilities/TorExitNodeList.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.utilities 2 | 3 | import java.net.InetAddress 4 | 5 | //See TorExitLookup 6 | class TorExitNodeList (val exitNodes:Set[InetAddress], val retrievedTimestamp:Long) extends Serializable 7 | -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/state/Field.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.state 2 | 3 | class Field(val name:String) extends Serializable { 4 | def ignore(value:String):Boolean = false 5 | } 6 | 7 | case object SOURCE_IP_ADDRESS extends Field("sourceIPAddress") 8 | case object AWS_ACCESS_KEY_ID extends Field("accessKeyId") 9 | case object PRINCIPAL_ID extends Field("principalId") 10 | case object PRINCIPAL_ARN extends Field("arn") 11 | 
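The ignore hook above is a no-op for every field defined in this file; a field that needs to filter values before profiling would override it. The sketch below is illustrative only and not part of the original repository: the USER_AGENT field and the values it skips are hypothetical, purely to show the extension point.

case object USER_AGENT extends Field("userAgent") {
  //Hypothetical example: skip empty values and AWS-internal user agents so obvious noise stays out of the profile
  override def ignore(value:String):Boolean = value == null || value.isEmpty || value.endsWith("amazonaws.com")
}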
-------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/profilers/GenericCloudTrailProfiler.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.profilers 2 | 3 | import reinvent.securityanalytics.state.Field 4 | import reinvent.securityanalytics.utilities.Configuration 5 | 6 | /**A GenericCloudTrailProfiler will track new values for a given field in CloudTrail events. */ 7 | class GenericCloudTrailProfiler[T] (field:Field, config:Configuration) extends ActivityProfiler[T](field.name, config) 8 | -------------------------------------------------------------------------------- /src/main/resources/startStreaming.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | spark-submit \ 4 | --master yarn-client \ 5 | --num-executors 40 \ 6 | --conf spark.executor.cores=2 \ 7 | --conf spark.streaming.receiver.maxRate=3 \ 8 | --conf spark.rdd.compress=true \ 9 | --conf spark.cleaner.ttl=3600 \ 10 | --conf spark.streaming.concurrentJobs=5 \ 11 | --class reinvent.securityanalytics.CloudTrailProfileAnalyzer \ 12 | /home/hadoop/cloudtrailanalysisdemo-1.0-SNAPSHOT-jar-with-dependencies.jar \ 13 | config/reinventConfig.properties \ 14 | > output.txt 2>&1 # capture both stdout and stderr in output.txt 15 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with 4 | the License. A copy of the License is located at 5 | 6 | http://aws.amazon.com/apache2.0/ 7 | 8 | or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 9 | CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions 10 | and limitations under the License.
-------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/utilities/ObjectMapperSingleton.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.utilities 2 | 3 | import com.fasterxml.jackson.databind.ObjectMapper 4 | 5 | //Make sure we load the ObjectMapper once and only once per executor 6 | object ObjectMapperSingleton { 7 | private var _mapper:Option[ObjectMapper] = None 8 | def mapper:ObjectMapper = { 9 | _mapper match { 10 | case Some(wrapped) => { wrapped } 11 | case None => { 12 | _mapper = Some(new ObjectMapper()) 13 | mapper 14 | } 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/utilities/S3Singleton.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.utilities 2 | 3 | import com.amazonaws.regions.{Regions,Region} 4 | import com.amazonaws.services.s3.AmazonS3Client 5 | 6 | //Make sure we have only one S3 client per executor 7 | object S3Singleton { 8 | private var _client:Option[AmazonS3Client] = None 9 | def client(region:String):AmazonS3Client = { 10 | _client match { 11 | case Some(wrapped) => { wrapped } 12 | case None => { 13 | val newClient = new AmazonS3Client() 14 | newClient.setRegion(Region.getRegion(Regions.fromName(region))) 15 | _client = Some(newClient) 16 | client(region) 17 | } 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/utilities/SQSSingleton.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.utilities 2 | 3 | import com.amazonaws.regions.{Regions, Region} 4 | import com.amazonaws.services.sqs.AmazonSQSClient 5 | 6 | //Make sure we have only SQS client per executor 7 | object SQSSingleton { 8 | private var _client:Option[AmazonSQSClient] = None 9 | def client(region:String):AmazonSQSClient = { 10 | _client match { 11 | case Some(wrapped) => { wrapped } 12 | case None => { 13 | val newClient = new AmazonSQSClient() 14 | newClient.setRegion(Region.getRegion(Regions.fromName(region))) 15 | _client = Some(newClient) 16 | client(region) 17 | } 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/utilities/SNSSingleton.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.utilities 2 | 3 | import com.amazonaws.regions.{Regions, Region} 4 | import com.amazonaws.services.sns.AmazonSNSClient 5 | 6 | //Make sure we have only one SNS client per executor 7 | object SNSSingleton { 8 | private var _client:Option[AmazonSNSClient] = None 9 | def client(region:String):AmazonSNSClient = { 10 | _client match { 11 | case Some(wrapped) => { wrapped } 12 | case None => { 13 | val newClient = new AmazonSNSClient() 14 | newClient.setRegion(Region.getRegion(Regions.fromName(region))) 15 | _client = Some(newClient) 16 | client(region) 17 | } 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/state/ActivityProfile.scala: -------------------------------------------------------------------------------- 1 | package 
reinvent.securityanalytics.state 2 | 3 | /**ActivityProfile contains a set which is the profile of the previous activity and a count of how many 4 | * alerts it has sent.*/ 5 | class ActivityProfile[T] (previousActivity:Set[T], val alertCount:Int = 0) extends Serializable { 6 | 7 | //Union the old and new activity and increment the alert count 8 | def addNewActivity(newActivity:Set[T]):ActivityProfile[T] = { 9 | new ActivityProfile(previousActivity ++ newActivity, alertCount + 1) 10 | } 11 | 12 | def activity:Set[T] = previousActivity 13 | 14 | override def toString:String = { 15 | val stringBuffer = new StringBuffer() 16 | stringBuffer.append("Profile contents are " + previousActivity.mkString(",") + " (Size=" + 17 | previousActivity.size + ")\n") 18 | stringBuffer.append("Count of alerts sent: " + alertCount + "\n") 19 | stringBuffer.toString 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/profilers/AccessKeyIDProfiler.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.profilers 2 | 3 | import reinvent.securityanalytics.state.{ActivityProfile, AWS_ACCESS_KEY_ID} 4 | import reinvent.securityanalytics.utilities.Configuration 5 | 6 | /** AccessKeyIDProfiler will build a profile of the activity of long-term AWS access keys IDs 7 | * Access Key IDs beginning in "ASIA" will be ignored as they are temporary and will result in 8 | * significant false positives. */ 9 | class AccessKeyIDProfiler(config:Configuration) extends ActivityProfiler[String] (AWS_ACCESS_KEY_ID.name, config) { 10 | override def compareNewActivity(newActivity:Set[String], profile:ActivityProfile[String]):ActivityProfile[String] = { 11 | val transformedNewActivity = newActivity.filter((accessKeyId:String) => !accessKeyId.startsWith("ASIA")) 12 | if (transformedNewActivity.nonEmpty) { 13 | compareNewTransformedActivity(transformedNewActivity, profile) 14 | } 15 | else { //Return the current profile 16 | profile 17 | } 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/main/resources/config/emptyConfig.properties: -------------------------------------------------------------------------------- 1 | #Store this file in S3 and pass the bucket and key to your streaming application. Also, see utilities.Configuration.scala 2 | #The bucket and prefix where your CloudTrail logs can be found. They will be replayed by the ReplayExistingCloudTrailReceiver 3 | cloudTrailBucket= 4 | cloudTrailPrefix= 5 | #The bucket in which configuration data, e.g. Maxmind GeoIP DB can be found 6 | configDataBucket= 7 | #The path in the config bucket for the geoIP DB 8 | geoIPDatabaseKey= 9 | #The base path for Spark streaming checkpoints. See code and Spark documentation 10 | checkpointPath= 11 | #Where to load the Tor exit node list. Note this changes frequently and we only load it once. 12 | exitNodeURL=https://check.torproject.org/exit-addresses 13 | #The Spark streaming batch interval 14 | batchIntervalSeconds=30 15 | #Our Spark application name 16 | appName=CloudtrailProfiler 17 | #AWS region 18 | regionName=us-west-2 19 | #The SNS topic to which all alerts should be sent. 20 | alertTopic= 21 | #The queue URL which has notificatiosn of all new CloudTrail logs. Wire your CloudTrail SNS topic to this SQS queue. 
22 | cloudTrailQueue= -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/profilers/GeoIPCityProfiler.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.profilers 2 | 3 | import org.apache.commons.logging.LogFactory 4 | import reinvent.securityanalytics.GeoIPLookup 5 | import reinvent.securityanalytics.state.ActivityProfile 6 | import reinvent.securityanalytics.utilities.{Configuration, GeoIPException} 7 | 8 | /**A GeoIPCityProfiler will alert you when AWS activity involving your AWS resources comes from a new city, as 9 | * determined by a GeoIP database.*/ 10 | class GeoIPCityProfiler(config:Configuration) extends ActivityProfiler[String] ("SourceIP-City", config) { 11 | val geoIPLookup = new GeoIPLookup(config) 12 | private val logger = LogFactory.getLog(this.getClass) 13 | 14 | override def compareNewActivity(newActivity:Set[String], profile:ActivityProfile[String]):ActivityProfile[String] = { 15 | try { 16 | val transformedNewActivity = newActivity.map((sourceIpAddress: String) => { 17 | geoIPLookup.cityLookup(sourceIpAddress) 18 | }) 19 | compareNewTransformedActivity(transformedNewActivity, profile) 20 | } 21 | catch { 22 | case (g:GeoIPException) => { 23 | logger.info("Could not look up city for " + newActivity) 24 | new ActivityProfile(Set.empty) 25 | } 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/profilers/GeoIPCountryProfiler.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.profilers 2 | 3 | import org.apache.commons.logging.LogFactory 4 | import reinvent.securityanalytics.GeoIPLookup 5 | import reinvent.securityanalytics.state.ActivityProfile 6 | import reinvent.securityanalytics.utilities.{Configuration, GeoIPException} 7 | 8 | /**A GeoIPCountryProfiler will alert you when AWS activity involving your AWS resources comes from a new country, as 9 | * determined by a GeoIP database.*/ 10 | class GeoIPCountryProfiler(config:Configuration) extends ActivityProfiler[String] ("SourceIP-Country", config) { 11 | private val logger = LogFactory.getLog(this.getClass) 12 | val geoIPLookup = new GeoIPLookup(config) 13 | 14 | override def compareNewActivity(newActivity:Set[String], profile:ActivityProfile[String]):ActivityProfile[String] = { 15 | try { 16 | val transformedNewActivity = newActivity.map((sourceIpAddress: String) => { 17 | geoIPLookup.countryLookup(sourceIpAddress) 18 | }) 19 | compareNewTransformedActivity(transformedNewActivity, profile) 20 | } 21 | catch { 22 | case (g:GeoIPException) => { 23 | logger.error("Could not look up country for " + newActivity) 24 | new ActivityProfile(Set.empty) 25 | } 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/utilities/GeoIPDBReaderSingleton.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.utilities 2 | 3 | import java.io.InputStream 4 | 5 | import com.amazonaws.services.s3.AmazonS3Client 6 | import com.maxmind.geoip2.DatabaseReader 7 | import org.apache.spark.Logging 8 | 9 | import scala.util.{Failure, Success, Try} 10 | 11 | /**Make sure we load the DB once, and only once per executor.*/ 12 | object GeoIPDBReaderSingleton extends Logging 
{ 13 | private var _dbreader:Option[DatabaseReader] = None 14 | 15 | /** @return the DatabaseReader singleton that can be used to access the GeoIP database*/ 16 | def dbreader(config:Configuration):DatabaseReader = { 17 | _dbreader match { 18 | case None => { 19 | loadDBFromS3(config) match { 20 | case Success(db) => { 21 | logInfo("Creating new GeoIP DB reader") 22 | _dbreader = Some(new DatabaseReader.Builder(db).build()) 23 | } 24 | case Failure(e) => { 25 | throw new GeoIPException("Geo IP DB could not be loaded. Make sure you have it specified in your config.", e) 26 | } 27 | } 28 | dbreader(config) 29 | } 30 | case Some(reader) => reader 31 | } 32 | } 33 | 34 | /**Attempt to load the database from a location in S3 35 | * @return a Success(inputstream) in the event of success or a Failure wrapping the exception*/ 36 | def loadDBFromS3(config:Configuration):Try[InputStream] = { 37 | val s3 = new AmazonS3Client() 38 | Try(s3.getObject(config.getString(Configuration.CONFIG_DATA_BUCKET), config.getString(Configuration.GEO_IP_DB_KEY)).getObjectContent) //Wrap in Try so an S3 failure surfaces as a Failure, as documented above, rather than an uncaught exception 39 | } 40 | } -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/receivers/ReplayExistingCloudTrailEventsReceiver.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.receivers 2 | 3 | import com.amazonaws.services.s3.AmazonS3Client 4 | import org.apache.commons.logging.LogFactory 5 | import org.apache.spark.Logging 6 | import org.apache.spark.storage.StorageLevel 7 | import org.apache.spark.streaming.receiver.Receiver 8 | import reinvent.securityanalytics.utilities.{Configuration, CloudTrailS3Utilities} 9 | 10 | /**One of the best ways to test this code is to replay your existing CloudTrail records as if they arrived fresh. 11 | * This receiver does that by looking in the configured bucket, with the configured prefix, for CloudTrail logs. 12 | * As it finds them, it converts them to the raw events and store()s them. Control the incoming rate with the 13 | * spark.streaming.receiver.maxRate Spark configuration value. */ 14 | class ReplayExistingCloudTrailEventsReceiver(config:Configuration) extends GenericCloudTrailEventsReceiver(config) { 15 | override protected val logger = LogFactory.getLog(this.getClass) 16 | 17 | override def onStart():Unit = { 18 | val bucket = config.getString(Configuration.CLOUDTRAIL_BUCKET) 19 | val prefix = config.getString(Configuration.CLOUDTRAIL_PREFIX, true) 20 | val s3 = new AmazonS3Client() 21 | val cloudTrailS3Objects = CloudTrailS3Utilities.findCloudTrailDataInBucket(bucket, prefix) 22 | 23 | //For each S3 object containing CloudTrail data...
24 | cloudTrailS3Objects.foreach(bucketKeyPair => { 25 | readAndStoreRawCloudTrailEvents(bucketKeyPair._1, bucketKeyPair._2) 26 | }) 27 | 28 | logger.error("Finished reading existing CloudTrail events.") 29 | } 30 | } 31 | 32 | -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/receivers/GenericCloudTrailEventsReceiver.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.receivers 2 | 3 | import org.apache.commons.logging.LogFactory 4 | import org.apache.spark.storage.StorageLevel 5 | import org.apache.spark.streaming.receiver.Receiver 6 | import reinvent.securityanalytics.utilities._ 7 | 8 | //Common code between receivers for new and existing CloudTrail events 9 | abstract class GenericCloudTrailEventsReceiver(config:Configuration) extends Receiver[String](StorageLevel.MEMORY_ONLY) { 10 | protected var run = true 11 | private var storeCount = 0 12 | protected val logger = LogFactory.getLog(this.getClass) 13 | 14 | def readAndStoreRawCloudTrailEvents(s3Bucket:String, s3Key:String) = { 15 | require(s3Bucket.nonEmpty, "Cannot read CloudTrail logs from an empty S3 bucket.") 16 | require(s3Key.nonEmpty, "Cannot read CloudTrail logs from an empty S3 key.") 17 | 18 | val s3 = S3Singleton.client(config.getString(Configuration.REGION)) 19 | 20 | //Pull the raw strings from S3 21 | val rawEvents = CloudTrailS3Utilities.readRawCloudTrailEventsFromS3Object(s3Bucket, s3Key, s3) 22 | //Convert these into individual CloudTrail events (JSON strings) 23 | val cloudTrailEvents = CloudTrailS3Utilities.readCloudtrailRecords(rawEvents) 24 | cloudTrailEvents.foreach((event:String) => { 25 | if (run) { 26 | if (event == null || event.equals("")) { 27 | logger.warn("I will not store an empty event.") 28 | } 29 | else { 30 | store(event) //Store each event individually 31 | storeCount += 1 32 | } 33 | } 34 | else { 35 | logger.error("Not storing due to receiver shutdown.") 36 | } 37 | }) 38 | 39 | logger.info(storeCount + " CloudTrail events have been read and stored in total..") 40 | } 41 | 42 | override def onStop():Unit = { run = false} 43 | } 44 | -------------------------------------------------------------------------------- /src/test/scala/reinvent/securityanalytics/receivers/NewCloudTrailEventsReceiverTest.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.receivers 2 | 3 | import com.amazonaws.services.sqs.AmazonSQSClient 4 | import org.junit.runner.RunWith 5 | import org.scalatest.{FlatSpec,Matchers} 6 | import org.mockito.Mockito._ 7 | import reinvent.securityanalytics.utilities.Configuration 8 | import com.amazonaws.services.sqs.model.Message 9 | import org.scalatest.junit.JUnitRunner 10 | 11 | @RunWith(classOf[JUnitRunner]) 12 | class NewCloudTrailEventsReceiverTest extends FlatSpec with Matchers { 13 | it should "properly parse an SNS message" in { 14 | val sqsMessageBody = "{\n \"Type\" : \"Notification\",\n \"MessageId\" : \"id\",\n \"TopicArn\" : \"arn:aws:sns:us-west-2::cloudTrailLogArrivals\",\n \"Message\" : \"{\\\"s3Bucket\\\":\\\"cloudtrail17\\\",\\\"s3ObjectKey\\\":[\\\"objectPath\\\"]}\",\n \"Timestamp\" : \"2015-10-03T01:36:22.964Z\",\n \"SignatureVersion\" : \"1\",\n \"Signature\" : \"AAAAA\",\n \"SigningCertURL\" : \"https://sns.us-west-2.amazonaws.com/SimpleNotificationService-bb750dd426d95ee9390147a5624348ee.pem\",\n \"UnsubscribeURL\" : \"url\"\n}" 15 | val 
receipt = "RECEIPT" 16 | val queueURL = "URL" 17 | val sqs = mock(classOf[AmazonSQSClient]) 18 | val config = mock(classOf[Configuration]) 19 | when(config.getString(Configuration.CLOUDTRAIL_NEW_LOGS_QUEUE)).thenReturn(queueURL) 20 | val sqsMessage = mock(classOf[Message]) 21 | when(sqsMessage.getBody).thenReturn(sqsMessageBody) 22 | when(sqsMessage.getReceiptHandle).thenReturn(receipt) 23 | 24 | var stored = false 25 | def readAndStoreFunction(bucket:String, key:String):Unit = { 26 | stored = true 27 | } 28 | 29 | val receiver = new NewCloudTrailEventsReceiver(config) 30 | receiver.processSNSMessageInSQSMessage(sqsMessage, sqs, readAndStoreFunction) 31 | 32 | assert(stored) 33 | verify(sqs).deleteMessage(queueURL, receipt) 34 | } 35 | } -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/state/FieldProfileState.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.state 2 | 3 | /* We want to use Spark's DStream.updateStateByKey functionality for scalably comparing historical data against 4 | * new data. Doing so requires us to maintain state for each of the CloudTrail fields we want to keep track of. 5 | * This state is the values of the CloudTrail field that we've seen in the past. In some cases (e.g., SourceIPAddress), 6 | * we may transform the CloudTrail values (e.g., look up the GeoIP country of origin) and compare those transformed 7 | * values against any previous values. Therefore, we have a one-to-many mapping of fields to the profiles we want to 8 | * keep for them. FieldProfileState contains this mapping. Specifically, it maps the profiler name to the set of 9 | * previously seen values. Therefore, updateStateByKey will take the new activity, look up the profilers for the field, 10 | * and pass the new data to each profiler. Each profiler will return the new profile (or state) to be saved for that 11 | * profiler and we will store that state in the FieldProfileState map.*/ 12 | 13 | class FieldProfileState(stateMappings:Map[String, ActivityProfile[String]], updateCount:Int = 0) extends Serializable { 14 | def updateProfilerStateMappings(newMappings:Map[String, ActivityProfile[String]]):FieldProfileState = { 15 | new FieldProfileState(newMappings, updateCount + 1) 16 | } 17 | 18 | def mappings = stateMappings 19 | 20 | override def toString:String = { 21 | val stringBuffer = new StringBuffer() 22 | stringBuffer.append("This field's state has been updated " + updateCount + " times.\n") 23 | stringBuffer.append("State mappings size is " + stateMappings.size + ". State mappings are:\n") 24 | stateMappings.foreach(pair => { 25 | val profilerName = pair._1 26 | val profile:ActivityProfile[String] = pair._2 27 | stringBuffer.append("Profile for " + profilerName + " is " + profile.toString + "\n") 28 | }) 29 | stringBuffer.toString 30 | } 31 | } 32 | 33 | -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/utilities/Configuration.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.utilities 2 | 3 | import java.util.Properties 4 | 5 | import com.amazonaws.regions.{Region, Regions} 6 | import com.amazonaws.services.s3.AmazonS3Client 7 | import org.apache.commons.logging.LogFactory 8 | 9 | //A serializable wrapper for a Java properties file which contains the application's configuration. 
10 | class Configuration(configBucket:String, configKey:String) extends Serializable { 11 | private val logger = LogFactory.getLog(this.getClass.getSimpleName) 12 | var _properties:Option[Properties] = None 13 | 14 | def region:Region = { 15 | Region.getRegion(Regions.fromName(getString(Configuration.REGION))) 16 | } 17 | 18 | private def properties:Properties = { 19 | _properties match { 20 | case None => { 21 | logger.info("Loading configuration...") 22 | val s3 = new AmazonS3Client() 23 | val inputStream = s3.getObject(configBucket, configKey).getObjectContent 24 | val underlying = new Properties() 25 | underlying.load(inputStream) 26 | _properties = Some(underlying) 27 | underlying 28 | } 29 | case Some(underlying) => { 30 | underlying 31 | } 32 | } 33 | } 34 | 35 | def getString(key:String, isEmptyValueOkay:Boolean = false):String = { 36 | val value = properties.getProperty(key) 37 | if (value == null || value.equals("")) { 38 | if (isEmptyValueOkay) { 39 | logger.warn("Empty value found for " + key + ". Returning empty string.") 40 | "" 41 | } 42 | else { 43 | throw new IllegalArgumentException("Could not find config value for " + key) 44 | } 45 | } 46 | else { 47 | value 48 | } 49 | } 50 | 51 | def getInt(key:String):Int = { 52 | getString(key).toInt 53 | } 54 | } 55 | 56 | object Configuration { 57 | val CLOUDTRAIL_BUCKET = "cloudTrailBucket" 58 | val CLOUDTRAIL_PREFIX = "cloudTrailPrefix" 59 | val CONFIG_DATA_BUCKET = "configDataBucket" 60 | val STATE_OUTPUT_BUCKET = "stateOutputBucket" 61 | val STATE_OUTPUT_PREFIX = "stateOutputPrefix" 62 | val STATE_OUTPUT_NAME = "stateOutputName" 63 | val CHECKPOINT_PATH = "checkpointPath" 64 | val GEO_IP_DB_KEY = "geoIPDatabaseKey" 65 | val EXIT_NODE_URL = "exitNodeURL" 66 | val BATCH_INTERVAL_SECONDS = "batchIntervalSeconds" 67 | val APP_NAME = "appName" 68 | val REGION = "regionName" 69 | val ALERT_TOPIC = "alertTopic" 70 | val ALERT_QUEUE = "alertQueue" 71 | val CLOUDTRAIL_NEW_LOGS_QUEUE = "cloudTrailQueue" 72 | } -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/profilers/TorProfiler.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.profilers 2 | 3 | import java.net.InetAddress 4 | import org.apache.commons.logging.LogFactory 5 | import org.apache.spark.SparkContext 6 | import reinvent.securityanalytics.TorExitLookup 7 | import reinvent.securityanalytics.state.ActivityProfile 8 | import reinvent.securityanalytics.utilities.{Configuration, SNSSingleton} 9 | 10 | /**A TorProfiler which will alert you when AWS activity involving your AWS resources originates from the Tor anonymizing 11 | * network. While this may be fine for some administrators, this may be unacceptable to others who would see the 12 | * use of an anonymizing proxy as a sign of hiding one's tracks.*/ 13 | class TorProfiler(config:Configuration) extends ActivityProfiler[String] ("SourceIP-TorExitNode", config) { 14 | private val logger = LogFactory.getLog(this.getClass) 15 | 16 | val torExitLookup = new TorExitLookup(config) 17 | 18 | override def initialize(sparkContext:SparkContext) = { 19 | logger.info("Initializing TorProfiler") 20 | torExitLookup.initialize(sparkContext) //Load the exit node list. 
21 | } 22 | 23 | override def compareNewActivity(newActivity:Set[String], profile:ActivityProfile[String]):ActivityProfile[String] = { 24 | var updatedProfile = profile 25 | newActivity.foreach((sourceIpAddress:String) => { 26 | try { 27 | val ipAddress = InetAddress.getByName(sourceIpAddress) 28 | if (torExitLookup.isExitNode(ipAddress)) { 29 | val sns = SNSSingleton.client(config.getString(Configuration.REGION)) 30 | val subject = "Anonymizing proxy in use with your AWS account" 31 | val message = "Activity on your AWS resources has been seen from " + sourceIpAddress + " which is a Tor exit node." 32 | sns.publish(config.getString(Configuration.ALERT_TOPIC), message, subject) 33 | /*TODO Currently this will generate an alarm for each remote IP 34 | Instead, we should track either the long-term credential or the principalID to avoid duplicates. In which case, 35 | an expiration scheme is needed. */ 36 | updatedProfile = profile.addNewActivity(Set.empty) //Will increment alerts counter. 37 | } 38 | } 39 | catch { //Catch the specific exception first so the generic Exception case below is reachable 40 | case (u:java.net.UnknownHostException) => { 41 | logger.error("Invalid IP address: " + sourceIpAddress, u) 42 | } 43 | case (e:Exception) => { 44 | logger.error("Unexpected exception while parsing IP address", e) 45 | } 46 | } 47 | }) 48 | updatedProfile 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/GeoIPLookup.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics 2 | 3 | import java.net.{InetAddress, UnknownHostException} 4 | 5 | import com.maxmind.geoip2.model.CityResponse 6 | import org.apache.commons.logging.LogFactory 7 | import org.apache.spark.sql.SQLContext 8 | import reinvent.securityanalytics.utilities.{Configuration, GeoIPDBReaderSingleton} 9 | 10 | class GeoIPLookup(config:Configuration) extends Serializable { 11 | private val logger = LogFactory.getLog(this.getClass) 12 | 13 | /**@param ipAddress the IP address for which to look up the city 14 | * @return the city name for the IP address*/ 15 | def cityLookup(ipAddress:String):String = { 16 | try { 17 | ipGeoLocationLookup(ipAddress).getCity.getName 18 | } 19 | catch { 20 | case (u:UnknownHostException) => { 21 | logger.error("Invalid IP address: " + ipAddress, u) 22 | "Unknown" 23 | } 24 | case (e:Exception) => { 25 | logger.error("Unexpected exception while parsing IP address", e) 26 | "Unknown" 27 | } 28 | } 29 | } 30 | 31 | /**@param ipAddress the IP address for which to look up the country 32 | * @return the country for the IP address*/ 33 | def countryLookup(ipAddress:String):String = { 34 | try { 35 | ipGeoLocationLookup(ipAddress).getCountry.getName 36 | } 37 | catch { 38 | case (u:UnknownHostException) => { //Match cityLookup: catch the specific exception first so the Exception case below is reachable 39 | logger.error("Invalid IP address: " + ipAddress, u) 40 | "Unknown" 41 | } 42 | case (e:Exception) => { 43 | logger.error("Unexpected exception while parsing IP address", e) 44 | "Unknown" 45 | } 46 | } 47 | } 48 | 49 | /**@param ipAddress the IP address string for which to look up the city 50 | * @return the CityResponse object for that string*/ 51 | def ipGeoLocationLookup(ipAddress:String):CityResponse = { 52 | ipGeoLocationLookup(InetAddress.getByName(ipAddress)) 53 | } 54 | 55 | /**@param inetAddress the IP address object for which to look up the city 56 | * @return the CityResponse object for that string*/ 57 | def ipGeoLocationLookup(inetAddress:InetAddress):CityResponse = { 58 | GeoIPDBReaderSingleton.dbreader(config).city(inetAddress) 59 | } 60 | 61 |
/**Registers user-defined functions in the SQL context for querying GeoIP data*/ 62 | def registerUDFs(sqlContext:SQLContext):Unit = { 63 | sqlContext.udf.register("city", cityLookup(_:String)) 64 | sqlContext.udf.register("country", countryLookup(_:String)) 65 | } 66 | 67 | //Can only be run after registerUDFs is called, otherwise UDFs won't be defined. 68 | def runSampleGeoIPQuery(sqlContext:SQLContext) = { 69 | sqlContext.sql("select distinct sourceIpAddress, city(sourceIpAddress) as city, country(sourceIpAddress) as country from cloudtrail").show(10000) 70 | } 71 | } -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/profilers/ActivityProfiler.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.profilers 2 | 3 | import org.apache.commons.logging.LogFactory 4 | import org.apache.spark.{SparkContext, Logging} 5 | import reinvent.securityanalytics.state.ActivityProfile 6 | import reinvent.securityanalytics.utilities.{Configuration, SNSSingleton} 7 | 8 | /**ActivityProfiler is the most generic "profile" builder. It contains most of the comparison and alerting loci. 9 | * @param key a string that uniquely identifies this profiler*/ 10 | class ActivityProfiler[T](val key:String, config:Configuration) extends Serializable { 11 | private val logger = LogFactory.getLog(this.getClass) 12 | private val EOL = "\r\n" 13 | protected var alerts = 0 14 | 15 | protected def updateProfile(newActivity:Set[T], previousProfile:ActivityProfile[T]):ActivityProfile[T] = { 16 | logger.info("Updating profile for " + key) 17 | 18 | if (newActivity != null && newActivity.nonEmpty) { 19 | val oldSize = previousProfile.activity.size 20 | 21 | //Add new activity to the existing profile //TODO Handle expiration in cases where it makes sense 22 | val newProfile:ActivityProfile[T] = previousProfile.addNewActivity(newActivity) 23 | 24 | val newSize = newProfile.activity.size 25 | if (newSize <= oldSize) { 26 | logger.error("Problem: profile did not grow. " + oldSize + " vs. " + newSize) 27 | } 28 | logger.info("Saved " + newProfile + " as profile for " + key) 29 | 30 | newProfile 31 | } 32 | else { 33 | logger.error("Empty profile found for " + key ) 34 | new ActivityProfile(Set.empty) 35 | } 36 | } 37 | 38 | //Override this if you need to transform the input in some way 39 | def compareNewActivity(newActivity:Set[T], profile:ActivityProfile[T]):ActivityProfile[T] = { 40 | compareNewTransformedActivity(newActivity, profile) 41 | } 42 | 43 | /*compareNewTransformedActivity will compare new (transformed) activity against the current profile*/ 44 | protected def compareNewTransformedActivity(newActivity:Set[T], profile:ActivityProfile[T]):ActivityProfile[T] = { 45 | if (profile.activity.isEmpty) { //If no profile exists, we'll create one from this batch of data 46 | logger.info("Empty profile found for " + key + ". We will create a new profile using " + newActivity) 47 | val newProfile = new ActivityProfile(newActivity) 48 | logger.info("New profile" + newProfile) 49 | newProfile 50 | } 51 | else { 52 | if (newActivity.subsetOf(profile.activity)) { 53 | logger.info("The following activity is In compliance with our profile for " + key + ": " 54 | + newActivity.mkString(",")) 55 | profile //Return the existing profile as the new state 56 | } 57 | else { //We have a deviation from the profile. Send an alert and update the profile. 
58 | alert(newActivity, profile) 59 | updateProfile(newActivity, profile) 60 | } 61 | } 62 | } 63 | 64 | //Logic for generating an SNS alert 65 | private def alert(newActivity:Set[T], profile:ActivityProfile[T]): Unit = { 66 | val intersection = profile.activity.intersect(newActivity) 67 | val sns = SNSSingleton.client(config.getString(Configuration.REGION)) 68 | val subject = "CloudTrail Profile Mismatch for " + key 69 | val body = new StringBuilder() 70 | body.append("New activity which doesn't fit current profile: " + 71 | newActivity.diff(profile.activity).mkString(",") + EOL) 72 | body.append(EOL + "Current activity (including compliant activity): " + newActivity.mkString(",") + EOL) 73 | body.append("Current profile (which was not matched): " + profile.activity.mkString(",") + EOL) 74 | body.append("The following is the intersection between the existing profile and new activity: " + 75 | intersection.mkString(",") + EOL) 76 | body.append(EOL + "The profile will be updated (to avoid duplicate alarms)") 77 | 78 | logger.info("***** Sending alert: Subject="+subject+" Body="+body.toString()) 79 | sns.publish(config.getString(Configuration.ALERT_TOPIC), body.toString(), subject) 80 | alerts += 1 81 | } 82 | 83 | override def toString():String = { 84 | key + " generated " + alerts + " alerts." 85 | } 86 | 87 | //Child classes should override this 88 | def initialize(sparkContext:SparkContext):Unit = {} 89 | } -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/TorExitLookup.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics 2 | 3 | import java.net.{InetAddress, UnknownHostException} 4 | 5 | import org.apache.commons.logging.LogFactory 6 | import org.apache.spark.SparkContext 7 | import org.apache.spark.broadcast.Broadcast 8 | import org.apache.spark.rdd.RDD 9 | import reinvent.securityanalytics.utilities.{Configuration, TorExitNodeList} 10 | 11 | import scala.io.Source 12 | 13 | //TODO: Update the exit node list periodically. 
14 | class TorExitLookup(config:Configuration) extends Serializable { 15 | private val logger = LogFactory.getLog(this.getClass) 16 | val EXIT_FILE_ADDRESS_LINE = "ExitAddress" 17 | var exitNodesListOption:Option[Broadcast[TorExitNodeList]] = None 18 | var lastUpdateTimestamp:Long = 0 19 | 20 | def getExitNodeRDD(sparkContext:SparkContext):RDD[InetAddress] = { 21 | loadExitNodeList() match { 22 | case None => sparkContext.parallelize(Seq.empty[InetAddress]) 23 | case Some(exitNodesList) => { 24 | sparkContext.parallelize(exitNodesList.exitNodes.toSeq) 25 | } 26 | } 27 | } 28 | 29 | def initialize(sparkContext:SparkContext):Unit = { 30 | exitNodesListOption = Some(getBroadcastExitNodeList(sparkContext)) 31 | } 32 | 33 | def getBroadcastExitNodeList(sparkContext:SparkContext):Broadcast[TorExitNodeList] = { 34 | val nodeList = loadExitNodeList() match { 35 | case None => new TorExitNodeList(Set.empty[InetAddress], 0) 36 | case Some(exitNodesList) => { 37 | exitNodesList 38 | } 39 | } 40 | sparkContext.broadcast(nodeList) 41 | } 42 | 43 | /*Format looks like: 44 | ExitNode 0011BD2485AD45D984EC4159C88FC066E5E3300E 45 | Published 2015-09-27 22:16:46 46 | LastStatus 2015-09-27 23:08:39 47 | ExitAddress 162.247.72.201 2015-09-27 23:17:58 48 | ExitNode 0098C475875ABC4AA864738B1D1079F711C38287 49 | Published 2015-09-28 13:59:30 50 | LastStatus 2015-09-28 15:03:16 51 | ExitAddress 162.248.160.151 2015-09-28 15:12:01 52 | ExitNode 00B70D1F261EBF4576D06CE0DA69E1F700598239 53 | Published 2015-09-28 10:21:07 54 | LastStatus 2015-09-28 11:02:17 55 | ExitAddress 193.34.116.18 2015-09-28 11:10:34 56 | ... 57 | 58 | We just want the IPs. The following function gets them. 59 | */ 60 | private def loadExitNodeList():Option[TorExitNodeList] = { 61 | try { 62 | val exitNodesDump = Source.fromURL(config.getString(Configuration.EXIT_NODE_URL)).getLines() 63 | val exitNodes = exitNodesDump.filter(line => line.contains(EXIT_FILE_ADDRESS_LINE)).map(line => { 64 | val lineArr = line.split(" ") 65 | if (lineArr == null || lineArr.length < 2) { 66 | logger.warn("The following line could not be converted to an IP string: " + line) 67 | None 68 | } 69 | else { 70 | val ipAddressString = lineArr(1) 71 | 72 | try { 73 | Some(InetAddress.getByName(ipAddressString)) 74 | } 75 | catch { 76 | case (u: UnknownHostException) => { 77 | logger.warn("Could not parse " + ipAddressString + " into an IP address") 78 | None 79 | } 80 | } 81 | } 82 | }).filter(_.isDefined).map(_.get).toSet //Filter out the entries we couldn't parse 83 | 84 | if (exitNodes.size < 1) { 85 | logger.error("No exit node information found at " + config.getString(Configuration.EXIT_NODE_URL)) 86 | } 87 | Some(new TorExitNodeList(exitNodes, System.currentTimeMillis())) 88 | } 89 | catch { 90 | case (e:Exception) => { 91 | logger.error("Problem initializing the Tor exit node list. Continuing without initialization.", e) 92 | None 93 | } 94 | } 95 | } 96 | 97 | def isExitNode(ip:InetAddress):Boolean = { 98 | if (exitNodesListOption.isDefined) { 99 | val exitNodeList = exitNodesListOption.get.value.exitNodes 100 | if (exitNodeList.isEmpty) { 101 | logger.warn("Exit node list is empty. 
Make sure to initialize it first.") 102 | false 103 | } 104 | else { 105 | exitNodeList.contains(ip) 106 | } 107 | } 108 | else { 109 | logger.error("Error: Exit node list not initialized.") 110 | false 111 | } 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /src/main/scala/S3LogsToSQL.scala: -------------------------------------------------------------------------------- 1 | import org.apache.spark.sql.types.{StructField, StructType} 2 | import org.apache.spark.{Logging, SparkContext} 3 | import org.apache.spark.rdd.RDD 4 | import org.apache.spark.sql.{DataFrame, SQLContext, Row} 5 | import org.apache.spark.sql.types.{StructType,StructField,StringType} 6 | 7 | /*Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved. 8 | 9 | Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with 10 | the License. A copy of the License is located at 11 | 12 | http://aws.amazon.com/apache2.0/ 13 | 14 | or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 15 | CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions 16 | and limitations under the License.*/ 17 | 18 | 19 | /**Like CloudTrailToSQL, S3LogsToSQL is written for you to cut and paste it into your spark-shell and use it from there. 20 | * TODO: We do not auto-discover your S3 logs so you need to pass a bucket and prefix. 21 | * 22 | * Sample use: 23 | * 1. Start the Spark shell as follows: 24 | spark-shell --master yarn-client --num-executors 40 --conf spark.executor.cores=2 25 | * 2. Copy and paste this entire file into the shell. It will create an object called S3LogsToSQL 26 | * 3. Load your S3 logs into a Hive table by running a command like the following: 27 | S3LogsToSQL.createHiveTable("", "S3logs", sqlContext) 28 | * 4. Query them, e.g., 29 | sqlContext.sql("select distinct ip from s3logshive").show(10000) 30 | * */ 31 | object S3LogsToSQL extends Logging { 32 | val TABLE_NAME = "s3logs" 33 | val schemaString = "bucketOwner bucket date ip requester operation key" 34 | val RDD_PARTITION_COUNT = 1000 //TODO: Consider making this dynamic and configurable. 35 | 36 | // Generate the schema from the schema string 37 | val schema = 38 | StructType( 39 | schemaString.split(" ").map(fieldName => StructField(fieldName, StringType, true))) 40 | 41 | def rawS3Logs(sparkContext:SparkContext, logsBucket:String, logsPrefix:String):RDD[String] = { 42 | val path = "s3://" + logsBucket + "/" + logsPrefix + "*" 43 | val rawStrings:RDD[String] = sparkContext.textFile(path).coalesce(RDD_PARTITION_COUNT, true) 44 | rawStrings 45 | } 46 | 47 | def rawS3LogsToTempTable(rawStrings:RDD[String], sqlContext:SQLContext):DataFrame = { 48 | val s3LogsDF = rawS3LogsToDataFrame(rawStrings, sqlContext) 49 | s3LogsDF.registerTempTable(TABLE_NAME) 50 | s3LogsDF 51 | } 52 | 53 | def rawS3LogsToHiveTable(rawStrings:RDD[String], sqlContext:SQLContext):DataFrame = { 54 | val s3LogsDF = rawS3LogsToDataFrame(rawStrings, sqlContext) 55 | s3LogsDF.write.saveAsTable(TABLE_NAME+"hive") 56 | s3LogsDF 57 | } 58 | 59 | /** See http://docs.aws.amazon.com/AmazonS3/latest/dev/LogFormat.html for the format. 60 | * TODO, Right now we only load in part of the logs. 
It would be better to use regexes or a more thoughtful 61 | * parsing scheme.*/ 62 | def rawS3LogsToDataFrame(rawStrings:RDD[String], sqlContext:SQLContext):DataFrame = { 63 | val rowRDD = rawStrings.map((line:String) => { 64 | val lineArray = line.split(" ") 65 | 66 | if (lineArray.length >= 9) { 67 | val bucketOwner = lineArray(0) 68 | val bucket = lineArray(1) 69 | val date = (lineArray(2) + " " + lineArray(3)) 70 | val ip = lineArray(4) 71 | val requester = lineArray(5) 72 | val operation = lineArray(7) 73 | val key = lineArray(8) 74 | Some(Row(bucketOwner, bucket, date, ip, requester, operation, key)) 75 | } 76 | else { 77 | println("Could not parse the following line: " + line) 78 | None 79 | } 80 | }).filter(_.isDefined).map(_.get) 81 | 82 | val s3LogsDF = sqlContext.createDataFrame(rowRDD, schema) 83 | s3LogsDF.cache() 84 | s3LogsDF 85 | } 86 | 87 | def createTable(logsBucket:String, logsPrefix:String, sqlContext:SQLContext):DataFrame = { 88 | val rawLogsRDD = rawS3Logs(sqlContext.sparkContext, logsBucket, logsPrefix) 89 | rawS3LogsToTempTable(rawLogsRDD, sqlContext) 90 | } 91 | 92 | def createHiveTable(logsBucket:String, logsPrefix:String, sqlContext:SQLContext):DataFrame = { 93 | val rawLogsRDD = rawS3Logs(sqlContext.sparkContext, logsBucket, logsPrefix) 94 | rawS3LogsToHiveTable(rawLogsRDD, sqlContext) 95 | } 96 | } 97 | 98 | -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/receivers/NewCloudTrailEventsReceiver.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.receivers 2 | 3 | import com.amazonaws.services.sqs.AmazonSQSClient 4 | import com.amazonaws.services.sqs.model.Message 5 | import com.fasterxml.jackson.databind.JsonNode 6 | import com.fasterxml.jackson.databind.node.TextNode 7 | import org.apache.commons.logging.LogFactory 8 | import reinvent.securityanalytics.utilities._ 9 | import scala.collection.JavaConversions._ 10 | 11 | /**NewCloudTrailEventsReceiver will receive CloudTrail event notifications for new log arrivals, go find those 12 | * logs, parse out the events, and send them along for processing. */ 13 | class NewCloudTrailEventsReceiver(config:Configuration) extends GenericCloudTrailEventsReceiver(config) { 14 | private val noMessageSleepTimeMillis = 15000 //wait 15 seconds between empty message batches 15 | private val SNS_MESSAGE_FIELD_IN_SQS = "Message" 16 | private val S3_BUCKET_FIELD_IN_SNS = "s3Bucket" 17 | private val S3_OBJECT_FIELD_IN_SNS = "s3ObjectKey" 18 | override protected val logger = LogFactory.getLog(this.getClass) 19 | 20 | override def onStart():Unit = { 21 | val sqs = SQSSingleton.client(config.getString(Configuration.REGION)) 22 | 23 | while (run) { 24 | val sqsMessages = sqs.receiveMessage(config.getString(Configuration.CLOUDTRAIL_NEW_LOGS_QUEUE)).getMessages 25 | if (sqsMessages.size() < 1) { 26 | logger.info("No messages are available so we'll take a short nap so we don't hammer SQS.") 27 | Thread.sleep(noMessageSleepTimeMillis) 28 | } 29 | else { 30 | //We get a batch of SQS messages back 31 | sqsMessages.foreach(processSNSMessageInSQSMessage(_, sqs)) 32 | } 33 | } 34 | } 35 | 36 | def processSNSMessageInSQSMessage(sqsMessage:Message, sqs:AmazonSQSClient):Unit = { 37 | processSNSMessageInSQSMessage(sqsMessage, sqs, readAndStoreRawCloudTrailEvents) 38 | } 39 | 40 | /* SNS notifications sent to SQS look like: 41 | { 42 | ... 
43 | "Message" : "{\"s3Bucket\":\"cloudtrail17\",\"s3ObjectKey\":[\"pathToCloudTrailObject.json.gz\"]}", 44 | ... 45 | } 46 | See http://docs.aws.amazon.com/awscloudtrail/latest/userguide/configure-cloudtrail-to-send-notifications.html 47 | * */ 48 | def processSNSMessageInSQSMessage(sqsMessage:Message, sqs:AmazonSQSClient, readAndStore:(String,String)=>Unit):Unit = { 49 | val sqsMessageBody = sqsMessage.getBody 50 | //The body is a JSON object, so find the root of the object 51 | val sqsMessageRoot = ObjectMapperSingleton.mapper.readTree(sqsMessageBody) 52 | 53 | //Find the SNS message in the SQS message 54 | val snsMessages = sqsMessageRoot.findValuesAsText(SNS_MESSAGE_FIELD_IN_SQS) 55 | 56 | if (snsMessages.size < 1) { 57 | logger.info("Could not find SNS message in SQS message: " + sqsMessageBody) 58 | } 59 | else { 60 | //There should only be one SNS message, but in case there are more we are doing foreach 61 | snsMessages.foreach(processNewCloudTrailLogsSNSNotification(_, sqsMessage, sqs, readAndStore)) 62 | } 63 | } 64 | 65 | //Find and process the notification in the SNS message 66 | def processNewCloudTrailLogsSNSNotification(snsMessage:String, sqsMessage:Message, sqs:AmazonSQSClient, readAndStore:(String, String)=>Unit) { 67 | val sqsMessageBody = sqsMessage.getBody 68 | //Find the root of the SNS message, which is also a JSON object 69 | val snsMessageRoot = ObjectMapperSingleton.mapper.readTree(snsMessage) 70 | val s3BucketList = snsMessageRoot.findValuesAsText(S3_BUCKET_FIELD_IN_SNS) //Find the bucket name 71 | if (s3BucketList.size() != 1) { 72 | logger.error("Was expecting to find 1 S3 bucket and found " + s3BucketList.size() + " instead in " + sqsMessageBody) 73 | } else { 74 | val s3Bucket = s3BucketList.head //Get the one bucket name 75 | 76 | val s3ObjectListNode = snsMessageRoot.get(S3_OBJECT_FIELD_IN_SNS) 77 | val s3ObjectList = s3ObjectListNode.elements().toSet 78 | if (s3ObjectList.size < 1) { 79 | logger.error("Found empty S3 object list in " + sqsMessageBody) 80 | } 81 | else { 82 | s3ObjectList.foreach((node:JsonNode) => { 83 | if (node.isInstanceOf[TextNode]) { 84 | val s3Key = node.asInstanceOf[TextNode].textValue() 85 | //Get and store individual CloudTrail events 86 | readAndStore(s3Bucket, s3Key) 87 | 88 | //Delete the message since we've successfully processed it. 89 | sqs.deleteMessage(config.getString(Configuration.CLOUDTRAIL_NEW_LOGS_QUEUE), sqsMessage.getReceiptHandle) 90 | } 91 | else { 92 | logger.error("Could not read S3 object key in " + node) 93 | } 94 | }) 95 | } 96 | } 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # timely-security-analytics Overview 2 | This repo contains demo code for the Timely Security Analytics and Analysis presentation at the Amazon Web Services 2015 Re:Invent conference. We are open sourcing this project so that others may learn from it and potentially build on it. It really contains three interesting, independent units: 3 | * CloudTrailToSQL - See the comments at the top of the file which contain instructions to convert your AWS CloudTrail logs to a Spark Data Frame which you can then query with SQL 4 | * reinvent.securityanalytics.CloudTrailProfileAnalyzer is a Spark Streaming application which will read your CloudTrail logs (new and historical), build profiles based on those logs, and alert you when activity deviates from those profiles. 
5 | * S3LogsToSQL is alpha code to convert your Amazon S3 logs into tables which can be queried. 6 | 7 | This page will be updated with a link to the video of the talk, when it is available. 8 | 9 | #CloudTrailToSQL 10 | Do you want to run SQL queries over your CloudTrail logs? Well, you're in the right place. This code is written so you can cut-and-paste it into a spark-shell and then start running SQL queries (hence why there is no package in this code and I've used as few libraries as possible). All the libraries it requires are already included in the build of Spark 1.4.1 that's available through Amazon's Elastic MapReduce (EMR). For more information, see https://aws.amazon.com/blogs/aws/new-apache-spark-on-amazon-emr/ 11 | 12 | ##How to use it 13 | 1. Provision an EMR cluster with Spark 1.4.1 and an IAM role that has CloudTrail and S3 describe and read permissions. 14 | 2. SSH to that cluster and run that spark-shell, e.g. spark-shell --master yarn-client --num-executors 40 --conf spark.executor.cores=2 15 | 3. Cut and paste the contents of CloudTrailToSQL.scala (found in this package) into your Spark Shell (once the scala> prompt is available) 16 | 4. Run the following commands: 17 | ``` 18 | var cloudtrail = CloudTrailToSQL.createTable(sc, sqlContext) //creates and registers the Spark SQL table 19 | CloudTrailToSQL.runSampleQuery(sqlContext) //runs a sample query 20 | ``` 21 | Note that these commands will take some time to run as they load your CloudTrail data from S3 and store it in-memory on the Spark cluster. Run the sample query again and you'll see the speed up that the in-memory caching provides. 22 | 5. Run any SQL query you want over the data, e.g. 23 | ``` 24 | sqlContext.sql("select distinct eventSource, eventName, userIdentity.principalId from cloudtrail where userIdentity.principalId = userIdentity.accountId").show(99999) //Find services and APIs called with root credentials 25 | ``` 26 | 6. You can create a Hive table (that will persist after your program exits) by running 27 | ``` 28 | var cloudtrail = CloudTrailToSQL.createHiveTable(sc, sqlContext) 29 | ``` 30 | ##Additional uses 31 | You can configure and invoke geoIP lookup functions using code like that below. To do this, you will need a copy of the Maxmind GeoIP database. See the Dependencies section of this documentation. 32 | ``` 33 | import reinvent.securityanalytics.utilities.Configuration 34 | import reinvent.securityanalytics.GeoIPLookup 35 | var config = new Configuration("", "config/reinventConfig.properties") 36 | var geoIP = new GeoIPLookup(config) 37 | geoIP.registerUDFs(sqlContext) //Registers UDFs that you can use for lookups. 38 | sqlContext.sql("select distinct sourceIpAddress, city(sourceIpAddress), country(sourceIpAddress) from cloudtrail").collect.foreach(println) 39 | ``` 40 | #CloudTrailProfileAnalyzer 41 | ##How to use it 42 | 1. Fill out a config file and load it in S3 43 | 2. (Optional) License Maxmind's GeoIP DB to get use of the GeoIP functionality 44 | 3. Start an EMR cluster with Spark 1.4.1 45 | 4. Compile the code with "mvn package" 46 | 5. Upload the fat jar (e.g., cloudtrailanalysisdemo-1.0-SNAPSHOT-jar-with-dependencies.jar) to your EMR cluster 47 | 6. Submit it using spark-submit. See resources/startStreaming.sh for an example. Make sure to pass the bucket and key that points to your config file. 48 | 7. Look for alerts via the subscriptions set up on your SNS topic. 
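If nothing is subscribed to the alert topic yet, one way to add an email subscription is from the spark-shell (or any JVM with the fat jar on the classpath), reusing the same SNS client wrapper the profilers use. The region, topic ARN, and address below are placeholders rather than values from this repository; SNS will ask the recipient to confirm the subscription before any alerts are delivered.
```
import reinvent.securityanalytics.utilities.SNSSingleton
val sns = SNSSingleton.client("us-west-2") //use the regionName from your config file
//Placeholder topic ARN and address; use the alertTopic ARN you configured
sns.subscribe("arn:aws:sns:us-west-2:123456789012:cloudTrailAlerts", "email", "security-team@example.com")
```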
49 | 50 | ##Future work 51 | * The current code builds profiles based on deduplicated historical data, which doesn't allow for easy frequency analysis. Passing the data, with duplicates, would allow for more meaningful analysis. 52 | * The current code operates on a per-field basis and should be improved to allow profilers to operate on the full activity. For example, its hard prevent duplicate alarms on Tor usage since Tor usage, by definition, comes from different exit nodes. It would make more sense to ignore duplicates based on the actor rather than the source IP address. 53 | 54 | #Dependencies 55 | This code has the key dependencies described below. For a full list, including versions, please see the pom.xml file included in the repo. 56 | * Apache Spark is our core processing engine. 57 | * AWS Java SDK is how we communicate with AWS. 58 | * Scala version 2.10 is the language in which this code is written. 59 | * The Maxmind [GeoIP database](http://dev.maxmind.com/geoip/geoip2/downloadable/). 60 | * The Maxmind [GeoIP2 Java library](https://github.com/maxmind/GeoIP2-java). 61 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | reinvent.securityanalytics 5 | cloudtrailanalysisdemo 6 | jar 7 | 1.0-SNAPSHOT 8 | 9 | cloudtrailanalysisdemo 10 | http://maven.apache.org 11 | 12 | 13 | 14 | scala-tools.org 15 | Scala-tools Maven2 Repository 16 | http://scala-tools.org/repo-releases 17 | 18 | 19 | 20 | 21 | org.scala-lang 22 | scala-library 23 | 2.10.6 24 | 25 | 26 | com.amazonaws 27 | aws-java-sdk 28 | 1.10.20 29 | 30 | 31 | org.apache.spark 32 | spark-core_2.10 33 | 1.4.1 34 | 35 | 36 | com.maxmind.geoip2 37 | geoip2 38 | 2.3.1 39 | 40 | 41 | org.apache.spark 42 | spark-streaming_2.10 43 | 1.4.1 44 | 45 | 46 | org.apache.spark 47 | spark-sql_2.10 48 | 1.4.1 49 | 50 | 51 | org.apache.spark 52 | spark-hive_2.10 53 | 1.4.1 54 | 55 | 56 | org.scalatest 57 | scalatest_2.10 58 | 2.2.5 59 | test 60 | 61 | 62 | junit 63 | junit 64 | 4.11 65 | test 66 | 67 | 68 | org.mockito 69 | mockito-all 70 | 1.10.19 71 | test 72 | 73 | 74 | 75 | 76 | 77 | org.scala-tools 78 | maven-scala-plugin 79 | 2.15.2 80 | 81 | 82 | 83 | compile 84 | testCompile 85 | 86 | 87 | 88 | 89 | src/main/scala 90 | 91 | -Xms64m 92 | -Xmx1024m 93 | 94 | 95 | 96 | 97 | org.apache.maven.plugins 98 | maven-assembly-plugin 99 | 2.4.1 100 | 101 | 102 | 103 | jar-with-dependencies 104 | 105 | 106 | 107 | 108 | reinvent.securityanalytics.CloudTrailProfileAnalyzer 109 | 110 | 111 | 112 | 113 | 114 | 115 | make-assembly 116 | 117 | package 118 | 119 | single 120 | 121 | 122 | 123 | 124 | 125 | 126 | org.apache.maven.plugins 127 | maven-surefire-plugin 128 | 2.7 129 | 130 | true 131 | 132 | 133 | 134 | 135 | org.scalatest 136 | scalatest-maven-plugin 137 | 1.0 138 | 139 | ${project.build.directory}/surefire-reports 140 | . 
141 | WDF TestSuite.txt 142 | 143 | 144 | 145 | test 146 | 147 | test 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/utilities/CloudTrailS3Utilities.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics.utilities 2 | 3 | import java.util.zip.GZIPInputStream 4 | 5 | import com.amazonaws.regions.{Region, RegionUtils} 6 | import com.amazonaws.services.cloudtrail.AWSCloudTrailClient 7 | import com.amazonaws.services.cloudtrail.model.Trail 8 | import com.amazonaws.services.s3.AmazonS3Client 9 | import com.amazonaws.services.s3.model.S3ObjectSummary 10 | import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper} 11 | import org.apache.commons.logging.LogFactory 12 | import org.apache.spark.{SparkContext, Logging} 13 | import org.apache.spark.rdd.RDD 14 | 15 | import scala.collection.JavaConversions._ 16 | import scala.collection.mutable 17 | import scala.collection.mutable.Buffer 18 | import scala.io.Source 19 | 20 | /**Functions for loading CloudTrail events from S3*/ 21 | object CloudTrailS3Utilities { 22 | private val logger = LogFactory.getLog(this.getClass) 23 | private val mapper = new ObjectMapper() 24 | 25 | /**Query CloudTrail to learn where (i.e., which S3 buckets) CloudTrail data is being stored. If more than one trail 26 | * logs to the same bucket we will de-dupe so that we don't read from the bucket twice. 27 | * @return a set of (bucket,prefix) pairs where CloudTrail logs can be found.*/ 28 | def getCloudTrailBucketsAndPrefixes:mutable.Set[(String,String)] = { 29 | val trailSet:mutable.Set[(String,String)] = new mutable.HashSet[(String,String)] //Sets will remove duplicates 30 | 31 | //For each AWS region, find the CloudTrail trails in that region. 32 | RegionUtils.getRegions.foreach((region:Region) => { 33 | try { 34 | logger.info("Looking for CloudTrail trails in " + region) 35 | val ct = new AWSCloudTrailClient 36 | ct.setRegion(region) 37 | 38 | val trails = ct.describeTrails().getTrailList //Retrieve the list of trails (usually 1 or 0) 39 | if (trails.size() > 0) { 40 | logger.info("Trails found: " + trails) 41 | trails.foreach((trail: Trail) => { 42 | trailSet.add((trail.getS3BucketName, trail.getS3KeyPrefix)) 43 | }) 44 | } 45 | else { 46 | logger.error("No CloudTrail trail configured in " + region + ". Go turn Cloudtrail on!") 47 | } 48 | } 49 | catch { 50 | case e: java.lang.Throwable => { 51 | logger.error("Problem with region: " + region + ". 
Error: " + e) 52 | Iterator.empty 53 | } 54 | } 55 | }) 56 | 57 | logger.info("Found " + trailSet.size + " S3 buckets and prefixed to explore (after removing duplicates)") 58 | trailSet 59 | } 60 | 61 | /** See getCloudTrailBucketsAndPrefixes except that we will return a list of all S3 objects that have CloudTrail data 62 | * @return a set of (bucket, S3 key) pairs that contain CloudTrail data.*/ 63 | def getCloudTrailS3Objects:mutable.Set[(String,String)] = { 64 | val trailSet = getCloudTrailBucketsAndPrefixes 65 | getCloudTrailS3Objects(trailSet) 66 | } 67 | 68 | /** See getCloudTrailBucketsAndPrefixes except that we will return a list of all S3 objects that have CloudTrail data 69 | * @param trailSet a set of (bucket,prefix) pairs to be explored for CloudTrail data 70 | * @return a set of (bucket, S3 key) pairs that contain CloudTrail data.*/ 71 | def getCloudTrailS3Objects(trailSet:mutable.Set[(String,String)]):mutable.Set[(String,String)] = { 72 | trailSet.flatMap((trailS3Info:(String,String)) => { 73 | findCloudTrailDataInBucket(trailS3Info._1, trailS3Info._2) 74 | }) 75 | } 76 | 77 | /**For a given bucket, go find the keys under the given prefix. S3 returns results in batches so this will page 78 | * through the batches and return the full set of keys 79 | * @param bucket the S3 bucket in which to look for CloudTraildata 80 | * @param prefix the S3 prefix in that bucket to look for CloudTrail data 81 | * @return a buffer of (bucket, S3 key) pairs that contain CloudTrail data.*/ 82 | def findCloudTrailDataInBucket(bucket:String, prefix:String):mutable.Buffer[(String,String)] = { 83 | /*We create a new S3 client each time this method is called. This is necessary because the AmazonS3Client is not 84 | * serializable to Spark executors. An improvement is to use a proxy pattern that returns a singleton instead of 85 | * creating a new client. However, to keep it simple, we're just going to stick with one client per bucket. */ 86 | val s3 = new AmazonS3Client 87 | 88 | logger.info("Looking in " + bucket + "/" + prefix + " (null means there's no prefix) for CloudTrail logs") 89 | var objectList = s3.listObjects(bucket, prefix) 90 | 91 | if (objectList.getObjectSummaries.size() > 0) { 92 | val cloudTrailS3Objects = getCloudTrailS3Keys(objectList.getObjectSummaries) //Get the first batch 93 | 94 | while (objectList.isTruncated) { //If there is more data to be retrieved... 95 | logger.info("Looking for another batch of S3 objects...") 96 | objectList = s3.listNextBatchOfObjects(objectList) //... get the next batch ... 97 | cloudTrailS3Objects ++= getCloudTrailS3Keys(objectList.getObjectSummaries) //... and add it to the original batch 98 | } 99 | logger.info("Found " + cloudTrailS3Objects.size + " S3 objects with CloudTrail data") 100 | cloudTrailS3Objects.map((key:String) => (bucket, key)) //We will need to bucket later, so add it. 101 | } 102 | else { 103 | logger.error("No S3 objects found! bucket=" + bucket + " prefix=" + prefix) 104 | Buffer.empty 105 | } 106 | } 107 | 108 | /** Take an S3 object list, remove irrelevant objects, and extract the key. 
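 * ("Irrelevant" here means anything that is not a compressed CloudTrail log: only keys ending in ".json.gz" are
 * kept; see isCompressedJson at the bottom of this object.)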
109 | * @param objectList an object list returned by S3 110 | * @return a buffer (list) of S3 keys */ 111 | def getCloudTrailS3Keys(objectList:java.util.List[S3ObjectSummary]):Buffer[String] = { 112 | objectList.filter((summary: S3ObjectSummary) => { 113 | isCompressedJson(summary.getKey) 114 | }).map((summary:S3ObjectSummary) => { 115 | summary.getKey 116 | }) 117 | } 118 | 119 | /** Starting from a SparkContext, retrieve the raw CloudTrail data 120 | * @param sparkContext the sparkContext 121 | * @return an RDD of raw CloudTrail strings*/ 122 | def loadFromS3(sparkContext:SparkContext):RDD[String] = { 123 | val list:mutable.Set[(String,String)] = getCloudTrailS3Objects 124 | val rdd:RDD[(String,String)] = sparkContext.parallelize(list.toSeq) 125 | loadFromS3(rdd) 126 | } 127 | 128 | /** Starting from an RDD of (bucket,key) pairs, retrieve the raw CloudTrail data 129 | * @param bucketKeyPairs the S3 (bucket,key) pairs that contain CloudTrail data 130 | * @return an RDD of raw CloudTrail strings*/ 131 | def loadFromS3(bucketKeyPairs:RDD[(String,String)]):RDD[String] = { 132 | bucketKeyPairs.flatMap(bucketKeyPair => { 133 | val s3 = new AmazonS3Client() 134 | readRawCloudTrailEventsFromS3Object(bucketKeyPair._1, bucketKeyPair._2, s3) 135 | }) 136 | } 137 | 138 | /**@param bucket the S3 bucket that contains the CloudTrail data 139 | * @param key the S3 key of the object holding the CloudTrail data 140 | * @return an iterator over the raw CloudTrail data*/ 141 | def readRawCloudTrailEventsFromS3Object(bucket:String, key:String, s3:AmazonS3Client):Iterator[String] = { 142 | try { 143 | Source.fromInputStream(new GZIPInputStream(s3.getObject(bucket, key).getObjectContent)).getLines() 144 | } 145 | catch { 146 | case (e:Exception) => { 147 | logger.error("Could not read CloudTrail log from s3://" + bucket + "/" + key, e) 148 | Iterator.empty 149 | } 150 | } 151 | } 152 | 153 | /**Convert a JSON array of CloudTrail events into an iterator of strings where each one is a CloudTrail event 154 | * @param line the JSON array fo CloudTrail events 155 | * @return an iterator (iterable, really) over strings where each one is a CloudTrail event*/ 156 | def cloudTrailDataToEvents(line:String):Iterable[String] = { 157 | val root = mapper.readTree(line) //Find the root node 158 | val nodeIterator: Iterable[JsonNode] = root.flatMap((n: JsonNode) => n.iterator()) //get an iterator over each Cloudtrail event 159 | val jsonStrings: Iterable[String] = nodeIterator.map((n: JsonNode) => n.toString) //get the string representation for each event 160 | jsonStrings 161 | } 162 | 163 | /** Turn raw CloudTrail data (arrays of events) into one JSON string per event 164 | * @param rawCloudTrailData CloudTrail strings that have been read from S3 in their original format 165 | * @return an RDD of strings where each string is an individual CloudTrail event. This is the format that Spark 166 | * needs them in. 
*/ 167 | def readCloudtrailRecords(rawCloudTrailData:Iterator[String]):Iterator[String] = { 168 | val cloudTrailRecords = rawCloudTrailData.flatMap(cloudTrailDataToEvents) 169 | cloudTrailRecords 170 | } 171 | 172 | private def isCompressedJson(name:String):Boolean = { 173 | name.endsWith(".json.gz") 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /src/main/scala/CloudTrailToSQL.scala: -------------------------------------------------------------------------------- 1 | import java.util.zip.GZIPInputStream 2 | import com.amazonaws.regions.RegionUtils 3 | import com.amazonaws.services.cloudtrail.AWSCloudTrailClient 4 | import com.amazonaws.services.cloudtrail.model.Trail 5 | import com.amazonaws.services.s3.AmazonS3Client 6 | import com.amazonaws.services.s3.model.S3ObjectSummary 7 | import com.fasterxml.jackson.databind.{ObjectMapper, JsonNode} 8 | import org.apache.spark.rdd.RDD 9 | import org.apache.spark.sql.hive.HiveContext 10 | import org.apache.spark.sql.{DataFrame, SQLContext} 11 | import org.apache.spark.{SparkContext, Logging} 12 | import com.amazonaws.regions.Region 13 | import scala.collection.mutable 14 | import scala.collection.mutable.Buffer 15 | import scala.io.Source 16 | import scala.collection.JavaConversions._ 17 | 18 | /*Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved. 19 | 20 | Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with 21 | the License. A copy of the License is located at 22 | 23 | http://aws.amazon.com/apache2.0/ 24 | 25 | or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 26 | CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions 27 | and limitations under the License.*/ 28 | 29 | /** Do you want to run SQL queries over your CloudTrail logs? Well, you're in the right place. This code is written 30 | * so you can cut-and-paste it into a spark-shell and then start running SQL queries (hence why there is no package 31 | * in this code and I've used as few libraries as possible). All the libraries it requires are already included in 32 | * the build of Spark 1.4.1 that's available through Amazon's Elastic MapReduce (EMR). For more information, see 33 | * https://aws.amazon.com/blogs/aws/new-apache-spark-on-amazon-emr/ 34 | * 35 | * Quick Start: To use this code, do the following: 36 | * 1. Provision an EMR cluster with Spark 1.4.1 and an IAM role that has CloudTrail and S3 describe and read permissions. 37 | * 2. SSH to that cluster and run that spark-shell, e.g.: 38 | spark-shell --master yarn-client --num-executors 40 --conf spark.executor.cores=2 39 | * 3. Cut and paste all of this code into your Spark Shell (once the scala> prompt is available) 40 | * 4. Run the following commands: 41 | var cloudtrail = CloudTrailToSQL.createTable(sc, sqlContext) //creates and registers the Spark SQL table 42 | CloudTrailToSQL.runSampleQuery(sqlContext) //runs a sample query 43 | 44 | Note that these commands will take some time to run as they load your CloudTrail data from S3 and store it in-memory 45 | on the Spark cluster. Run the sample query again and you'll see the speed up that the in-memory caching provides. 46 | * 5. Run any SQL query you want over the data, e.g. 
47 | sqlContext.sql("select distinct eventSource, eventName, userIdentity.principalId from cloudtrail where userIdentity.principalId = userIdentity.accountId").show(99999) //Find services and APIs called with root credentials 48 | */ 49 | object CloudTrailToSQL extends Logging { 50 | private val cloudTrailTableName = "cloudtrail" 51 | 52 | /**Query CloudTrail to learn where (i.e., which S3 buckets) CloudTrail data is being stored. If more than one trail 53 | * logs to the same bucket we will de-dupe so that we don't read from the bucket twice. 54 | * @return a set of (bucket,prefix) pairs where CloudTrail logs can be found.*/ 55 | def getCloudTrailBucketsAndPrefixes:mutable.Set[(String,String)] = { 56 | val trailSet:mutable.Set[(String,String)] = new mutable.HashSet[(String,String)] //Sets will remove duplicates 57 | 58 | //For each AWS region, find the CloudTrail trails in that region. 59 | RegionUtils.getRegions.foreach((region:Region) => { 60 | try { 61 | logInfo("Looking for CloudTrail trails in " + region) 62 | val ct = new AWSCloudTrailClient 63 | ct.setRegion(region) 64 | 65 | val trails = ct.describeTrails().getTrailList //Retrieve the list of trails (usually 1 or 0) 66 | if (trails.size() > 0) { 67 | logInfo("Trails found: " + trails) 68 | trails.foreach((trail: Trail) => { 69 | trailSet.add((trail.getS3BucketName, trail.getS3KeyPrefix)) 70 | }) 71 | } 72 | else { 73 | logError("No CloudTrail trail configured in " + region + ". Go turn Cloudtrail on!") 74 | } 75 | } 76 | catch { 77 | case e: java.lang.Throwable => { 78 | logError("Problem with region: " + region + ". Error: " + e) 79 | Iterator.empty 80 | } 81 | } 82 | }) 83 | 84 | logInfo("Found " + trailSet.size + " S3 buckets and prefixed to explore (after removing duplicates)") 85 | trailSet 86 | } 87 | 88 | /** See getCloudTrailBucketsAndPrefixes except that we will return a list of all S3 objects that have CloudTrail data 89 | * @return a set of (bucket, S3 key) pairs that contain CloudTrail data.*/ 90 | def getCloudTrailS3Objects:mutable.Set[(String,String)] = { 91 | val trailSet = getCloudTrailBucketsAndPrefixes 92 | getCloudTrailS3Objects(trailSet) 93 | } 94 | 95 | /** See getCloudTrailBucketsAndPrefixes except that we will return a list of all S3 objects that have CloudTrail data 96 | * @param trailSet a set of (bucket,prefix) pairs to be explored for CloudTrail data 97 | * @return a set of (bucket, S3 key) pairs that contain CloudTrail data.*/ 98 | def getCloudTrailS3Objects(trailSet:mutable.Set[(String,String)]):mutable.Set[(String,String)] = { 99 | trailSet.flatMap((trailS3Info:(String,String)) => { 100 | findCloudTrailDataInBucket(trailS3Info._1, trailS3Info._2) 101 | }) 102 | } 103 | 104 | /**For a given bucket, go find the keys under the given prefix. S3 returns results in batches so this will page 105 | * through the batches and return the full set of keys 106 | * @param bucket the S3 bucket in which to look for CloudTraildata 107 | * @param prefix the S3 prefix in that bucket to look for CloudTrail data 108 | * @return a buffer of (bucket, S3 key) pairs that contain CloudTrail data.*/ 109 | def findCloudTrailDataInBucket(bucket:String, prefix:String):mutable.Buffer[(String,String)] = { 110 | /*We create a new S3 client each time this method is called. This is necessary because the AmazonS3Client is not 111 | * serializable to Spark executors. An improvement is to use a proxy pattern that returns a singleton instead of 112 | * creating a new client. 
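 * (A rough sketch of that proxy/singleton idea, using an illustrative name rather than an existing API in this
 * file, would be a tiny holder such as object S3ClientHolder { lazy val client = new AmazonS3Client } that each
 * executor reuses; the packaged code ships a similar S3Singleton utility.)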
However, to keep it simple, we're just going to stick with one client per bucket. */ 113 | val s3 = new AmazonS3Client 114 | 115 | logInfo("Looking in " + bucket + "/" + prefix + " (null means there's no prefix) for CloudTrail logs") 116 | var objectList = s3.listObjects(bucket, prefix) 117 | 118 | if (objectList.getObjectSummaries.size() > 0) { 119 | val cloudTrailS3Objects = getCloudTrailS3Keys(objectList.getObjectSummaries) //Get the first batch 120 | 121 | while (objectList.isTruncated) { //If there is more data to be retrieved... 122 | logInfo("Looking for another batch of S3 objects...") 123 | objectList = s3.listNextBatchOfObjects(objectList) //... get the next batch ... 124 | cloudTrailS3Objects ++= getCloudTrailS3Keys(objectList.getObjectSummaries) //... and add it to the original batch 125 | } 126 | logInfo("Found " + cloudTrailS3Objects.size + " S3 objects with CloudTrail data") 127 | cloudTrailS3Objects.map((key:String) => (bucket, key)) //We will need to bucket later, so add it. 128 | } 129 | else { 130 | logError("No S3 objects found! bucket=" + bucket + " prefix=" + prefix) 131 | Buffer.empty 132 | } 133 | } 134 | 135 | /** Take an S3 object list, remove irrelevant objects, and extract the key. 136 | * @param objectList an object list returned by S3 137 | * @return a buffer (list) of S3 keys */ 138 | def getCloudTrailS3Keys(objectList:java.util.List[S3ObjectSummary]):Buffer[String] = { 139 | objectList.filter((summary: S3ObjectSummary) => { 140 | isCompressedJson(summary.getKey) 141 | }).map((summary:S3ObjectSummary) => { 142 | summary.getKey 143 | }) 144 | } 145 | 146 | /** Starting from a SparkContext, retrieve the raw CloudTrail data 147 | * @param sparkContext the sparkContext 148 | * @return an RDD of raw CloudTrail strings*/ 149 | def loadFromS3(sparkContext:SparkContext):RDD[String] = { 150 | val list:mutable.Set[(String,String)] = getCloudTrailS3Objects 151 | val rdd:RDD[(String,String)] = sparkContext.parallelize(list.toSeq) 152 | loadFromS3(rdd) 153 | } 154 | 155 | /** Starting from an RDD of (bucket,key) pairs, retrieve the raw CloudTrail data 156 | * @param bucketKeyPairs the S3 (bucket,key) pairs that contain CloudTrail data 157 | * @return an RDD of raw CloudTrail strings*/ 158 | def loadFromS3(bucketKeyPairs:RDD[(String,String)]):RDD[String] = { 159 | bucketKeyPairs.flatMap(bucketKeyPair => { 160 | val s3 = new AmazonS3Client() 161 | Source.fromInputStream(new GZIPInputStream(s3.getObject(bucketKeyPair._1, bucketKeyPair._2).getObjectContent)).getLines() 162 | }) 163 | } 164 | 165 | /** Turn raw CloudTrail data (arrays of events) into one JSON string per event 166 | * @param rawCloudTrailData CloudTrail strings that have been read from S3 in their original format 167 | * @return an RDD of strings where each string is an individual CloudTrail event. This is the format that Spark 168 | * needs them in. */ 169 | def readCloudtrailRecordsFromRDD(rawCloudTrailData:RDD[String]):RDD[String] = { 170 | val cloudTrailRecords = rawCloudTrailData.flatMap((line:String) => { 171 | val mapper = new ObjectMapper() //A singleton pattern might be beneficial here as well. 
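      //(The packaged code does this via ObjectMapperSingleton.mapper; see
      // CloudTrailProfileAnalyzer.findValuesForCloudTrailField for an example of reusing a shared mapper.)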
172 | val root = mapper.readTree(line) //Find the root node 173 | 174 | //Get an iterator over each Cloudtrail event 175 | val nodeIterator: Iterable[JsonNode] = root.flatMap((n: JsonNode) => n.iterator()) 176 | 177 | //Get the string representation for each event 178 | val jsonStrings: Iterable[String] = nodeIterator.map((n: JsonNode) => n.toString) 179 | jsonStrings 180 | }) 181 | cloudTrailRecords.persist() //Store these events in memory so the first query doesn't have to re-fetch from S3 182 | cloudTrailRecords 183 | } 184 | 185 | /**Main entry point: convert all CloudTrail logs into a Spark DataFrame, register it as a table 186 | * @param sc the Spark Context 187 | * @param sqlContext the SQL Context 188 | * @return a DataFrame that represents all your CloudTrail logs*/ 189 | def createTable(sc:SparkContext, sqlContext:SQLContext):DataFrame = { 190 | val rawCloudTrailData:RDD[String] = loadFromS3(sc) 191 | val individualCloudTrailEvents:RDD[String] = readCloudtrailRecordsFromRDD(rawCloudTrailData) 192 | val cloudtrailRecordsDataFrame:DataFrame = sqlContext.read.json(individualCloudTrailEvents) 193 | cloudtrailRecordsDataFrame.cache() //After your first query, all data will be cached in memory 194 | 195 | //Enable querying as the given table name via the SQL context 196 | cloudtrailRecordsDataFrame.registerTempTable(cloudTrailTableName) 197 | cloudtrailRecordsDataFrame 198 | } 199 | 200 | def createHiveTable(sc:SparkContext, hiveContext:SQLContext):DataFrame = { 201 | require(hiveContext.isInstanceOf[HiveContext], "You must pass a SQL context that is a HiveContext. Use the sqlContext val created for you.") 202 | val rawCloudTrailData = loadFromS3(sc) 203 | val individualCloudTrailEvents = readCloudtrailRecordsFromRDD(rawCloudTrailData) 204 | val hiveCloudTrailDataFrame = hiveContext.read.json(individualCloudTrailEvents) 205 | hiveCloudTrailDataFrame.cache() 206 | hiveCloudTrailDataFrame.write.saveAsTable(cloudTrailTableName+"hive") 207 | hiveCloudTrailDataFrame 208 | } 209 | 210 | def runSampleQuery(sqlContext:SQLContext) = { 211 | sqlContext.sql("select distinct userIdentity.principalId, sourceIPAddress, userIdentity.accessKeyId from " + cloudTrailTableName + " order by accessKeyId").show(10000) 212 | } 213 | 214 | private def isCompressedJson(name:String):Boolean = { 215 | name.endsWith(".json.gz") 216 | } 217 | } 218 | -------------------------------------------------------------------------------- /src/main/scala/reinvent/securityanalytics/CloudTrailProfileAnalyzer.scala: -------------------------------------------------------------------------------- 1 | package reinvent.securityanalytics 2 | 3 | import org.apache.commons.logging.LogFactory 4 | import org.apache.spark._ 5 | import org.apache.spark.rdd.RDD 6 | import org.apache.spark.streaming.{Duration, Seconds, StreamingContext} 7 | import org.apache.spark.streaming.dstream.DStream 8 | import reinvent.securityanalytics.profilers._ 9 | import reinvent.securityanalytics.receivers.{NewCloudTrailEventsReceiver, ReplayExistingCloudTrailEventsReceiver} 10 | import reinvent.securityanalytics.state._ 11 | import reinvent.securityanalytics.utilities._ 12 | import scala.collection.JavaConversions._ 13 | import scala.collection.mutable 14 | 15 | /*Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved. 16 | 17 | Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with 18 | the License. 
A copy of the License is located at 19 | 20 | http://aws.amazon.com/apache2.0/ 21 | 22 | or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 23 | CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions 24 | and limitations under the License.*/ 25 | 26 | /**CloudTrailProfileAnalyzer uses Spark Streaming to analyze CloudTrail data as soon as it is available in S3. It uses 27 | * this data to build a "profile" of the activity seen to date and then checks new activity against this profile. When 28 | * the new activity doesn't match the historical activity, it sends an alert to the SNS topic specified in the 29 | * configuration. 30 | * 31 | * The stock profiling logic is encapsulated in "profilers," specifically the following: 32 | * AccessKeyIDProfiler will alert you when a new long-term (i.e., not temporary) access key is used with your account. 33 | * This may be interesting to you if you want to learn about keys which were dormant (i.e., unused) that are now in 34 | * use. 35 | * A GenericCloudTrailProfiler will alert you when a new principal, as identified by the ARN, interacts with your 36 | * AWS resources. 37 | * A GenericCloudTrailProfiler will alert you when a new principal, as identified by the Principal ID, interacts with 38 | * your AWS resources. "Why do you have two different principal ID alerts?" The alerts provide different resolution. 39 | * The ARN alert will alert you when a new role name is used with a role principal ID (i.e., ARO...) whereas the 40 | * principal ID alert will not. If you'd like more details, use the principal ARN alert. If you'd like fewer alerts, 41 | * potentially at the expense of missing important details, use the principal ID alert. 42 | * A GenericCloudTrailProfiler will alert you when AWS activity involving your AWS resources comes from a previously- 43 | * unseen Source IP Address. 44 | * A GeoIPCityProfiler will alert you when AWS activity involving your AWS resources comes from a new city, as 45 | * determined by a GeoIP database. 46 | * A GeoIPCountryProfiler will alert you when AWS activity involving your AWS resources comes from a new country, as 47 | * determined by a GeoIP database. 48 | * A TorProfiler which will alert you when AWS activity involving your AWS resources originates from the Tor anonymizing 49 | * network. While this may be fine for some administrators, this may be unacceptable to others who would see the 50 | * use of an anonymizing proxy as a sign of hiding one's tracks. 51 | * 52 | * You can either build simple profiles using profile.GenericCloudTrailProfiler or you can build custom profiles by 53 | * extending either profile.ActivityProfiler or profile.GenericCloudTrailProfiler. Other profiler ideas include: 54 | * Alerts when a root account is used to access your account's resources. 
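 *
 * As a rough illustration of the wiring (this is not one of the stock profilers above), an additional field could be
 * profiled by adding an entry to the fieldToProfilerMap built in main below, for example:
 * {{{
 *   //Hypothetical extra entry: track previously-unseen principal IDs with the generic new-value logic.
 *   PRINCIPAL_ID -> Set(new GenericCloudTrailProfiler[String](PRINCIPAL_ID, config))
 * }}}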
*/ 55 | object CloudTrailProfileAnalyzer { 56 | private val logger = LogFactory.getLog(this.getClass) 57 | 58 | def main(args:Array[String]) = { 59 | require(args.length == 2, "You must supply the S3 bucket and key to a config file") 60 | val config = new Configuration(args(0), args(1)) 61 | 62 | //Each field we're interested in profiling maps to one or more profilers to be applied to it 63 | val fieldToProfilerMap:Map[Field, Set[ActivityProfiler[String]]] = Map( 64 | AWS_ACCESS_KEY_ID -> Set(new AccessKeyIDProfiler(config)), 65 | PRINCIPAL_ARN -> Set(new GenericCloudTrailProfiler[String](PRINCIPAL_ARN, config)), 66 | SOURCE_IP_ADDRESS -> Set( 67 | new GeoIPCityProfiler(config), 68 | new GeoIPCountryProfiler(config), 69 | new TorProfiler(config)) 70 | ) 71 | 72 | def createStreamingContextFunction = { 73 | createStreamingContext(config, fieldToProfilerMap) 74 | } 75 | 76 | val streamingContext = StreamingContext.getOrCreate(config.getString(Configuration.CHECKPOINT_PATH), 77 | createStreamingContextFunction _, new org.apache.hadoop.conf.Configuration(), 78 | true //Start application even if checkpoint restore fails 79 | ) 80 | 81 | streamingContext.start() 82 | streamingContext.awaitTermination() 83 | } 84 | 85 | def createStreamingContext(config:Configuration, 86 | fieldToProfilerMap:Map[Field, Set[ActivityProfiler[String]]]):StreamingContext = { 87 | val sparkConf = new SparkConf().setAppName(config.getString(Configuration.APP_NAME)) 88 | 89 | val duration = Seconds(config.getInt(Configuration.BATCH_INTERVAL_SECONDS)) 90 | 91 | val streamingContext = new StreamingContext(sparkConf, duration) //Set our batch interval 92 | 93 | //Make sure the streaming application can recover after failure 94 | val checkpointPath = config.getString(Configuration.CHECKPOINT_PATH) + "/" + System.currentTimeMillis() 95 | //TODO right now we are creating a new checkpoint per run. This is to work around an error in Spark. To restore 96 | // from this checkpoint, you will need to modify this code. 
97 | streamingContext.checkpoint(checkpointPath) 98 | 99 | val sparkContext = streamingContext.sparkContext 100 | 101 | //Initialize any profilers that require initialization 102 | initializeProfilers(fieldToProfilerMap, sparkContext) 103 | 104 | //Create DStream of CloudTrail events from receivers 105 | val cloudTrailEventsStream:DStream[String] = createDStream(streamingContext, config) 106 | 107 | //Create empty initial state 108 | val initialState = createEmptyInitialState(sparkContext, fieldToProfilerMap) 109 | 110 | //Profile incoming CloudTrail data 111 | val stateDStream = profileCloudTrailEvents(cloudTrailEventsStream, sparkContext, initialState, fieldToProfilerMap) 112 | stateDStream.checkpoint(duration) 113 | 114 | //Output the state after the current batch 115 | outputStatusOfProfilersAndProfiles(sparkContext, config, stateDStream) 116 | 117 | streamingContext 118 | } 119 | 120 | //Perform any profiler initialization that may be necessary 121 | def initializeProfilers(fieldToProfilerMap:Map[Field, Set[ActivityProfiler[String]]], sparkContext:SparkContext) = { 122 | fieldToProfilerMap.foreach(pair => { 123 | val profilerSet = pair._2 124 | profilerSet.foreach((profiler:ActivityProfiler[String]) => { 125 | profiler.initialize(sparkContext) 126 | }) 127 | }) 128 | } 129 | 130 | def createDStream(streamingContext:StreamingContext, config:Configuration):DStream[String] = { 131 | //Configure our receiver 132 | val existingEventsReceiver = new ReplayExistingCloudTrailEventsReceiver(config) 133 | val newEventsReceiver = new NewCloudTrailEventsReceiver(config) 134 | 135 | val existingEventsStream = streamingContext.receiverStream(existingEventsReceiver) 136 | val newEventsStream = streamingContext.receiverStream(newEventsReceiver) 137 | 138 | //Union the two event streams so we have one event stream. 139 | existingEventsStream.union(newEventsStream) 140 | } 141 | 142 | /*Output the number of alerts for each profiler (and the fields they are profiling) that have been issued. This 143 | * serves two purposes. First, the operator can evaluate the effectiveness the frequency of alerts. Next, the 144 | * foreachRDD causes the CloudTrail data to be read and processed. Without this, nothing would happen due to 145 | * Spark's lazy loading. */ 146 | def outputStatusOfProfilersAndProfiles(sparkContext:SparkContext, config:Configuration, 147 | stateDStream:DStream[(Field, FieldProfileState)]):Unit = { 148 | val profileAlertStatusAccumulator = sparkContext.accumulableCollection(new mutable.HashSet[String]) 149 | 150 | stateDStream.foreachRDD(fieldStateRDD => { 151 | fieldStateRDD.foreach((fieldStatePair) => { 152 | val field = fieldStatePair._1 153 | val fieldState = fieldStatePair._2 154 | fieldState.mappings.foreach(profilerProfilePair => { 155 | val profilerName = profilerProfilePair._1 156 | val currentProfile = profilerProfilePair._2 157 | val alertCount = currentProfile.alertCount 158 | 159 | val statusMessage = "The " + profilerName + " profiler for the field " + field.name + " has issued " + alertCount + " alerts." 160 | profileAlertStatusAccumulator.add(statusMessage) 161 | }) 162 | }) 163 | 164 | val set = profileAlertStatusAccumulator.value 165 | logger.info("Batch complete. Current state of profilers, profiles, and alerts follows with " + set.size + " messages") 166 | 167 | set.foreach((statusMessage) => { 168 | logger.info(statusMessage) 169 | }) 170 | 171 | //Remove status messages since they've already been written out. 
172 | profileAlertStatusAccumulator.setValue(profileAlertStatusAccumulator.zero) 173 | }) 174 | } 175 | 176 | def createEmptyInitialState(sparkContext:SparkContext, 177 | fieldToProfilerMap:Map[Field, Set[ActivityProfiler[String]]]) 178 | :RDD[(Field, FieldProfileState)] = { 179 | logger.info("Creating empty initial state.") 180 | 181 | val initiateState = fieldToProfilerMap.map(pair => { 182 | val field = pair._1 183 | val profilerSet = pair._2 184 | val profilerToNewProfileMappings:Map[String, ActivityProfile[String]] = profilerSet.map((profiler: ActivityProfiler[String]) => { 185 | val profilerKey = profiler.key 186 | (profilerKey, new ActivityProfile(Set.empty[String])) 187 | }).toMap 188 | 189 | (field, new FieldProfileState(profilerToNewProfileMappings)) //Map the field to the map of profilers and profiles for use on in the next batch 190 | }).toSeq 191 | 192 | sparkContext.parallelize(initiateState) 193 | } 194 | 195 | def profileCloudTrailEvents(arrivingEvents:DStream[String], 196 | sparkContext:SparkContext, 197 | initialState:RDD[(Field, FieldProfileState)], 198 | fieldToProfilerMap:Map[Field, Set[ActivityProfiler[String]]]):DStream[(Field, FieldProfileState)] = { 199 | logger.info("Configuring transforms for streaming CloudTrail data") 200 | val targetFieldsAndValues:DStream[(Field, String)] = getTargetFieldsAndValues(arrivingEvents, fieldToProfilerMap) 201 | 202 | val hashPartitioner:Partitioner = new HashPartitioner(sparkContext.defaultParallelism) 203 | 204 | //Create partially-applied function for updatingStateByKey 205 | def updateFunction:(Iterator[(Field, Seq[String], Option[FieldProfileState])] 206 | => Iterator[(Field, FieldProfileState)]) = profileAndUpdateState(_, fieldToProfilerMap) 207 | 208 | val stateDStream = targetFieldsAndValues.updateStateByKey(updateFunction, hashPartitioner, true, initialState) 209 | stateDStream 210 | } 211 | 212 | /*Map from CloudTrail events to pairs of (target fields, list of values) 213 | * Our input is individual CloudTrail events as JSON strings. We need to convert them to JSON trees and 214 | * then find all the values for only the fields we're interested in. We then want to de-duplicate values across 215 | * different CloudTrail events. We want the results as a map between the fields we're interested in and the 216 | * values for those fields.*/ 217 | def getTargetFieldsAndValues(arrivingEvents:DStream[String], 218 | fieldToProfilerMap:Map[Field, Set[ActivityProfiler[String]]]):DStream[(Field, String)] = { 219 | if (fieldToProfilerMap.isEmpty) { 220 | logger.warn("Field to profile map is empty!") 221 | } 222 | 223 | val targetFieldsAndValues:DStream[(Field, String)] = arrivingEvents.flatMap((eventAsJsonString:String) => { 224 | if (eventAsJsonString == null || eventAsJsonString.equals("")) { 225 | logger.warn("An empty string was received instead of a CloudTrail event.") 226 | } 227 | val fieldValuePairs:Set[(Field, String)] = fieldToProfilerMap.keySet.flatMap((fieldToProfile: Field) => { 228 | val fieldValues = findValuesForCloudTrailField(fieldToProfile.name, eventAsJsonString) 229 | 230 | if (fieldValues.isEmpty) { 231 | logger.warn("No values were found for " + fieldToProfile.name + " in " + eventAsJsonString) 232 | } 233 | else { 234 | logger.info("New values for " + fieldToProfile.name + " are " + fieldValues) 235 | } 236 | 237 | //Pair each value up with its field name (and metadata). This will be collapsed later. 
238 | fieldValues.map((fieldValue:String) => { 239 | (fieldToProfile, fieldValue) 240 | }) 241 | }) 242 | fieldValuePairs 243 | }) 244 | 245 | targetFieldsAndValues 246 | } 247 | 248 | def profileAndUpdateState(targetFieldsAndValues:Iterator[(Field, Seq[String], Option[FieldProfileState])], 249 | fieldToProfilerMap:Map[Field, Set[ActivityProfiler[String]]]) 250 | :Iterator[(Field, FieldProfileState)] = { 251 | logger.info("Profiling and updating state for all fields") 252 | 253 | if (targetFieldsAndValues.isEmpty) { 254 | logger.error("No target fields and values found!") 255 | } 256 | else { 257 | logger.debug("Target fields and values found.") 258 | } 259 | 260 | val newState = targetFieldsAndValues.map(triple => { 261 | val field = triple._1 262 | val newValues = triple._2 263 | val previousState = triple._3 264 | profileAndUpdateStateForField(field, newValues, previousState, fieldToProfilerMap) 265 | }) 266 | 267 | newState 268 | } 269 | 270 | def profileAndUpdateStateForField(field:Field, 271 | newValues:Seq[String], 272 | previousStateOption:Option[FieldProfileState], 273 | fieldToProfilerMap:Map[Field, Set[ActivityProfiler[String]]]) 274 | :(Field, FieldProfileState) = { 275 | logger.info("Profiling and saving new state for " + field.name) 276 | logger.info("Previous state is " + previousStateOption) 277 | val newActivity = newValues.toSet //Remove duplicates 278 | 279 | val profilerSetOption = fieldToProfilerMap.get(field) 280 | if (profilerSetOption.isDefined) { 281 | val profilerSet = profilerSetOption.get 282 | 283 | val newState:FieldProfileState = if (previousStateOption.isDefined) { 284 | val previousState = previousStateOption.get 285 | 286 | //Iterate over all the profilers for this field and compare against previous profiles 287 | val newMappings = compareAgainstPreviousProfilesAndUpdateProfiles(profilerSet, previousState, 288 | newActivity, field) 289 | logger.info("New mappings: " + newMappings) 290 | previousState.updateProfilerStateMappings(newMappings) 291 | } 292 | else { 293 | logger.warn("No previous state found for comparisons. 
We will create state from the values in the current run.") 294 | val newMappings = profilerSet.map((profiler: ActivityProfiler[String]) => { 295 | val profilerKey = profiler.key 296 | (profilerKey, new ActivityProfile(newActivity)) 297 | }).toMap 298 | 299 | new FieldProfileState(newMappings) 300 | } 301 | 302 | logger.info("New state is " + newState) 303 | (field, newState) //Map the field to the map of profilers and profiles for use on in the next batch 304 | } 305 | else { 306 | logger.error("Could not find a profiler for the following field: " + field.name) 307 | (field, new FieldProfileState(Map.empty)) //Return empty state since there's no profiler for which to store state 308 | } 309 | } 310 | 311 | def compareAgainstPreviousProfilesAndUpdateProfiles(profilerSet:Set[ActivityProfiler[String]], 312 | previousState:FieldProfileState, 313 | newActivity:Set[String], 314 | field:Field):Map[String, ActivityProfile[String]] = { 315 | logger.info("Comparing new activity against past state for " + field.name + " and updating profiles") 316 | 317 | if (profilerSet.isEmpty) { 318 | logger.warn("Set of profilers is empty.") 319 | } 320 | 321 | if (previousState.mappings.isEmpty) { 322 | logger.warn("No previous state found.") 323 | } 324 | 325 | if (newActivity.isEmpty) { 326 | logger.warn("No new activity found.") 327 | } 328 | 329 | //Iterate over all the profilers for this field and compare against previous profiles 330 | val profilerToNewProfileMappings = profilerSet.map((profiler: ActivityProfiler[String]) => { 331 | val profilerKey = profiler.key 332 | val previousProfileOption = previousState.mappings.get(profilerKey) 333 | if (previousProfileOption.isDefined) { 334 | val previousProfile:ActivityProfile[String] = previousProfileOption.get 335 | logger.info("Old profile was " + previousProfile) 336 | val newProfile:ActivityProfile[String] = profiler.compareNewActivity(newActivity, previousProfile) 337 | logger.info("New profile is " + newProfile) 338 | (profilerKey, newProfile) //Map the profiler to the profile so we can retrieve it later. 339 | } 340 | else { 341 | logger.error("Could not find a previous profile for the following field:" + field.name) 342 | (profilerKey, new ActivityProfile(newActivity)) //Map the profiler to the profile so we can retrieve it later. 343 | } 344 | }).toMap 345 | 346 | if (profilerToNewProfileMappings.isEmpty) { 347 | logger.warn("Profiler to new profile mappings is empty!") 348 | } 349 | 350 | profilerToNewProfileMappings 351 | } 352 | 353 | def findValuesForCloudTrailField(fieldName:String, cloudTrailEvent:String):Set[String] = { 354 | val tree = ObjectMapperSingleton.mapper.readTree(cloudTrailEvent) 355 | tree.findValuesAsText(fieldName).toSet 356 | } 357 | } 358 | --------------------------------------------------------------------------------