├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── build.gradle ├── src ├── main │ ├── java │ │ ├── com │ │ │ └── amazonaws │ │ │ │ ├── dagger │ │ │ │ ├── AWSClientModule.java │ │ │ │ └── LambdaFunctionsComponent.java │ │ │ │ └── lambda │ │ │ │ ├── demandpublishing │ │ │ │ ├── DemandRecord.java │ │ │ │ └── PublishDemandHandler.java │ │ │ │ ├── predictiongeneration │ │ │ │ ├── AbstractPredictionGenerationLambdaHandler.java │ │ │ │ ├── CreateDatasetGroupHandler.java │ │ │ │ ├── CreateDatasetHandler.java │ │ │ │ ├── CreateDatasetImportJobHandler.java │ │ │ │ ├── CreateForecastExportJobHandler.java │ │ │ │ ├── CreateForecastHandler.java │ │ │ │ ├── CreatePredictorHandler.java │ │ │ │ ├── DeleteOutdatedDatasetGroupsHandler.java │ │ │ │ ├── DeleteOutdatedDatasetImportJobsHandler.java │ │ │ │ ├── DeleteOutdatedDatasetsHandler.java │ │ │ │ ├── DeleteOutdatedForecastExportJobsHandler.java │ │ │ │ ├── DeleteOutdatedForecastsHandler.java │ │ │ │ ├── DeleteOutdatedPredictorsHandler.java │ │ │ │ ├── GenerateForecastResourcesIdsCronHandler.java │ │ │ │ ├── GenerateForecastResourcesIdsHandler.java │ │ │ │ ├── PredictionGenerationUtils.java │ │ │ │ └── exception │ │ │ │ │ ├── ResourceCleanupInProgressException.java │ │ │ │ │ ├── ResourceSetupFailureException.java │ │ │ │ │ └── ResourceSetupInProgressException.java │ │ │ │ └── queryingpredictionresult │ │ │ │ ├── LoadDataFromS3ToDynamoDBHandler.java │ │ │ │ └── PredictionResultItem.java │ │ └── log4j2.xml │ └── resources │ │ └── raw_demand_requests.csv └── test │ ├── java │ └── com │ │ └── amazonaws │ │ └── lambda │ │ ├── demandpublishing │ │ └── PublishDemandHandlerTest.java │ │ ├── predictiongeneration │ │ ├── BaseTest.java │ │ ├── CreateDatasetGroupHandlerTest.java │ │ ├── CreateDatasetHandlerTest.java │ │ ├── CreateDatasetImportJobHandlerTest.java │ │ ├── CreateForecastExportJobHandlerTest.java │ │ ├── CreateForecastHandlerTest.java │ │ ├── CreatePredictorHandlerTest.java │ │ ├── 
DeleteOutdatedDatasetGroupsHandlerTest.java │ │ ├── DeleteOutdatedDatasetImportJobsHandlerTest.java │ │ ├── DeleteOutdatedDatasetsHandlerTest.java │ │ ├── DeleteOutdatedPredictorsHandlerTest.java │ │ ├── GenerateForecastResourcesIdsCronHandlerTest.java │ │ └── GenerateForecastResourcesIdsHandlerTest.java │ │ └── queryingpredictionresult │ │ └── LoadDataFromS3ToDynamoDBHandlerTest.java │ └── resources │ ├── test_raw_demand_requests.csv │ └── tgt │ ├── empty_forecast_export_job_2019-10-16T21-40-00Z_part0.csv │ ├── forecast_export_job1_2019-10-16T21-40-00Z_part0.csv │ └── forecast_export_job_with_one_record_2019-10-16T21-40-00Z_part0.csv └── template.yaml /.gitignore: -------------------------------------------------------------------------------- 1 | AwsCredentials.properties 2 | .idea 3 | target/ 4 | *.iml 5 | .gradle 6 | build/ 7 | gradle/ 8 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 
8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *master* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. 
As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | 61 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes. 62 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so. 
8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 10 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 11 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 12 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 13 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 14 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 15 | 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SAM Application for Automated Forecast 2 | 3 | This is a sample application to demonstrate how to build a system around 4 | [the time-series forecasting service Amazon Forecast](https://aws.amazon.com/forecast/), 5 | which can automatically: 6 | * publish the historical demand to S3 bucket as the training data, 7 | * create the machine learning model and generate the prediction result, 8 | * load the latest prediction result to DynamoDB for querying. 9 | 10 | Details can be found in the [blog post](https://aws.amazon.com/blogs/machine-learning/automating-your-amazon-forecast-workflow-with-lambda-step-functions-and-cloudwatch-events-rule/) 11 | 12 | ```bash 13 | . 
14 | ├── README.md <-- This instructions file 15 | ├── LICENSE <-- MIT No Attribution License (MIT-0) 16 | ├── NOTICE.txt <-- Copyright notices 17 | ├── build.gradle <-- Java dependencies 18 | ├── src 19 | │ ├── main 20 | │ │ └── resources <-- Contains a dummy demand records csv file used for simulating the database 21 | │ │ └── java 22 | │ │ ├── com.amazonaws.dagger <-- Classes to manage Dagger 2 dependency injection 23 | │ │ │ ├── AWSClientModule.java <-- Provides dependencies like the Forecast client for injection 24 | │ │ │ └── LambdaFunctionsComponent.java <-- Contains inject methods for handler entrypoints 25 | │ │ └── com.amazonaws.lambda <-- Source code for lambda functions 26 | │ │ ├── demandpublishing <-- Lambda functions for demand publishing component 27 | | | | ├── DemandRecord.java <-- POJO shape for parsing the demand record from CSV file 28 | | | | ├── PublishDemandHandler.java <-- Lambda function for querying the historical demand and publishing it to S3 29 | │ │ ├── predictiongeneration <-- Lambda functions for prediction generation component 30 | | | | ├── exception <-- Source code for custom exceptions 31 | | | | | ├── ResourceCleanupInProgressException.java <-- Can be thrown when the resource cannot be immediately deleted 32 | | | | | ├── ResourceSetupFailureException.java <-- Can be thrown when the resource fails to be created 33 | | | | | └── ResourceSetupInProgressException.java <-- Can be thrown when the resource cannot be immediately created 34 | | | | ├── PredictionGenerationUtils.java <-- Contains common util methods 35 | | | | ├── GenerateForecastResourcesIdsHandler.java <-- Generates required forecast resource ids for model generation 36 | | | | ├── GenerateForecastResourcesIdsCronHandler.java <-- Generates required forecast resource ids for forecast generation 37 | | | | ├── AbstractPredictionGenerationLambdaHandler.java <-- Abstract handler containing methods shared by inherited handlers 38 | | | | ├── CreateDatasetHandler.java <--
Function implementation for creating forecast dataset resource 39 | | | | ├── CreateDatasetGroupHandler.java <-- Function implementation for creating forecast dataset group resource 40 | | | | ├── CreateDatasetImportJobHandler.java <-- Function implementation for creating forecast dataset import job resource 41 | | | | ├── CreatePredictorHandler.java <-- Function implementation for creating forecast predictor (ML model) resource 42 | | | | ├── CreateForecastHandler.java <-- Function implementation for creating forecast resource 43 | | | | ├── CreateForecastExportJobHandler.java <-- Function implementation for creating forecast export job resource 44 | | | | ├── DeleteOutdatedForecastExportJobsHandler.java <-- Function implementation for deleting expired export job resources 45 | | | | ├── DeleteOutdatedForecastsHandler.java <-- Function implementation for deleting expired forecast resources 46 | | | | ├── DeleteOutdatedPredictorsHandler.java <-- Function implementation for deleting expired predictor resources 47 | | | | ├── DeleteOutdatedDatasetImportJobsHandler.java <-- Function implementation for deleting expired dataset import job resources 48 | | | | ├── DeleteOutdatedDatasetsHandler.java <-- Function implementation for deleting expired dataset resources 49 | | | | └── DeleteOutdatedDatasetGroupsHandler.java <-- Function implementation for deleting expired dataset group resources 50 | │ │ └── queryingpredictionresult <-- Lambda functions for querying prediction result component 51 | | | ├── LoadDataFromS3ToDynamoDBHandler.java <-- Function implementation for loading data from S3 to DynamoDB table 52 | | | └── PredictionResultItem.java <-- POJO shape for a prediction result record 53 | │ └── test <-- Unit tests 54 | │ └── resources <-- Contains dummy prediction result csv file used for testing LoadDataFromS3ToDynamoDBHandler.java 55 | │ └── java 56 | │ └── com.amazonaws.lambda <-- Unit tests for handlers 57 | │ ├── demandpublishing <-- Unit tests for demand 
publishing related handlers 58 | │ | ├── PublishDemandHandlerTest.java <-- Unit tests for PublishDemandHandler.java 59 | │ ├── predictiongeneration <-- Unit tests for prediction generation related handlers 60 | │ | ├── GenerateForecastResourcesIdsHandlerTest.java <-- Unit tests for GenerateForecastResourcesIdsHandler.java 61 | │ | ├── GenerateForecastResourcesIdsCronHandlerTest.java <-- Unit tests for GenerateForecastResourcesIdsCronHandler.java 62 | │ | ├── CreateDatasetHandlerTest.java <-- Unit tests for CreateDatasetHandler.java 63 | │ | ├── CreateDatasetGroupHandlerTest.java <-- Unit tests for CreateDatasetGroupHandler.java 64 | │ | ├── CreatePredictorHandlerTest.java <-- Unit tests for CreatePredictorHandler.java 65 | │ | ├── CreateForecastHandlerTest.java <-- Unit tests for CreateForecastHandler.java 66 | │ | ├── CreateForecastExportJobHandlerTest.java <-- Unit tests for CreateForecastExportJobHandler.java 67 | │ | ├── DeleteOutdatedForecastExportJobsHandlerTest.java <-- Unit tests for DeleteOutdatedForecastExportJobsHandler.java 68 | │ | ├── DeleteOutdatedForecastsHandlerTest.java <-- Unit tests for DeleteOutdatedForecastsHandler.java 69 | │ | ├── DeleteOutdatedPredictorsHandlerTest.java <-- Unit tests for DeleteOutdatedPredictorsHandler.java 70 | │ | ├── DeleteOutdatedDatasetImportJobsHandlerTest.java <-- Unit tests for DeleteOutdatedDatasetImportJobsHandler.java 71 | │ | ├── DeleteOutdatedDatasetImportJobsHandlerTest.java <-- Unit tests for DeleteOutdatedDatasetImportJobsHandler.java 72 | │ | ├── DeleteOutdatedDatasetsHandlerTest.java <-- Unit tests for DeleteOutdatedDatasetsHandler.java 73 | │ | └── DeleteOutdatedDatasetGroupsHandlerTest.java <-- Unit tests for DeleteOutdatedDatasetGroupsHandler.java 74 | │ └── queryingpredictionresult <-- Unit tests for querying prediction result related handlers 75 | │ └── LoadDataFromS3ToDynamoDBHandlerTest.java <-- Unit tests for LoadDataFromS3ToDynamoDBHandler.java 76 | └── template.yaml <-- Contains cloudformation 
resources for lambda, S3, step function, cloudwatch event, dynamodb, iam role, etc. 77 | ``` 78 | 79 | ## Requirements 80 | 81 | * AWS CLI already configured with at least PowerUser permission 82 | * [Gradle Build Tool](https://gradle.org/) 83 | * [Java SE Development Kit 8 installed](https://www.oracle.com/java/technologies/javase-jdk8-downloads.html) 84 | * [SAM CLI](https://github.com/awslabs/aws-sam-cli) 85 | 86 | ## Setup process 87 | 88 | ### Installing dependencies and Building code 89 | 90 | We use `sam` to trigger the default `gradle` build tool for 91 | installing our dependencies and building our application into a JAR file: 92 | 93 | ```bash 94 | sam build 95 | ``` 96 | 97 | ## Packaging and deployment 98 | 99 | AWS Lambda Java runtime accepts either a zip file or a standalone JAR file - We use the latter in 100 | this example. SAM uses the `CodeUri` property to know where to look up both the application and its 101 | dependencies. As all functions use the same jar, we declare it in the Globals: 102 | 103 | ```yaml 104 | ... 105 | Globals: 106 | Function: 107 | AutoPublishAlias: live 108 | DeploymentPreference: 109 | Type: AllAtOnce 110 | MemorySize: 1024 111 | Runtime: java8 112 | Timeout: 180 113 | ReservedConcurrentExecutions: 2 # There can be two state machines executing the same function at the same time 114 | CodeUri: . 115 | ``` 116 | 117 | First, we need an `S3 bucket` where we can upload our Lambda functions packaged as ZIP before we 118 | deploy anything - If you don't have an S3 bucket to store code artifacts then this is a good time to 119 | create one: 120 | 121 | ```bash 122 | export BUCKET_NAME= 123 | aws s3 mb s3://$BUCKET_NAME 124 | ``` 125 | 126 | Next, the following command will create a CloudFormation Stack and deploy your SAM resources. 127 | Note: since an S3 bucket name has to be unique across all accounts and all regions, 128 | we make it a CloudFormation input parameter.
129 | So please specify your own bucket name for the automated forecast project. 130 | 131 | ```bash 132 | sam deploy \ 133 | --stack-name sam-automated-forecast \ 134 | --capabilities CAPABILITY_NAMED_IAM \ 135 | --s3-bucket $BUCKET_NAME \ 136 | --parameter-overrides PredictionS3BucketName= 137 | ``` 138 | 139 | > **See [Serverless Application Model (SAM) HOWTO Guide](https://github.com/awslabs/serverless-application-model/blob/master/HOWTO.md) for more details on how to get started.** 140 | 141 | ## Testing 142 | 143 | ### Running unit tests 144 | We use `JUnit` for testing our code. You can run unit tests with the following command: 145 | 146 | ```bash 147 | gradle test 148 | ``` 149 | -------------------------------------------------------------------------------- /build.gradle: -------------------------------------------------------------------------------- 1 | plugins { 2 | id 'java' 3 | } 4 | 5 | sourceCompatibility = 1.8 6 | 7 | repositories { 8 | mavenCentral() 9 | maven { 10 | url "https://s3-us-west-2.amazonaws.com/dynamodb-local/release" 11 | } 12 | } 13 | 14 | configurations { 15 | dynamodb 16 | } 17 | 18 | dependencies { 19 | compile group: 'commons-io', name: 'commons-io', version: '2.6' 20 | compile group: 'com.amazonaws', name: 'aws-java-sdk-dynamodb', version: '1.11.715' 21 | compile group: 'com.amazonaws', name: 'aws-java-sdk-forecast', version: '1.11.715' 22 | compile group: 'com.amazonaws', name: 'aws-java-sdk-s3', version: '1.11.715' 23 | compile group: 'com.amazonaws', name: 'aws-lambda-java-events', version: '2.2.7' 24 | compile group: 'com.amazonaws', name: 'aws-lambda-java-log4j2', version: '1.1.0' 25 | compile group: 'com.google.collections', name: 'google-collections', version: '1.0' 26 | compile group: 'com.google.dagger', name: 'dagger', version: '2.26' 27 | annotationProcessor group: 'com.google.dagger', name: 'dagger-compiler', version: '2.26' 28 | compile group: 'com.opencsv', name: 'opencsv', version: '5.1' 29 | compile group: 
'org.apache.commons', name: 'commons-collections4', version: '4.0' 30 | compile group: 'org.apache.commons', name: 'commons-lang3', version: '3.9' 31 | compile group: 'org.projectlombok', name: 'lombok', version: '1.18.10' 32 | compile group: 'org.slf4j', name: 'slf4j-simple', version: '1.7.30' 33 | annotationProcessor group: 'org.projectlombok', name: 'lombok', version: '1.18.10' 34 | 35 | testCompile group: 'com.amazonaws', name: 'DynamoDBLocal', version: '1.12.0' 36 | dynamodb fileTree (dir: 'lib', include: ["*.dylib", "*.so", "*.dll"]) 37 | dynamodb group: 'com.amazonaws', name: 'DynamoDBLocal', version: '1.12.0' 38 | testCompile group: 'com.github.stefanbirkner', name: 'system-rules', version: '1.17.2' 39 | testCompile group: 'junit', name: 'junit', version: '4.13' 40 | testCompile group: 'org.junit.jupiter', name: 'junit-jupiter-api', version: '5.6.0' 41 | testCompile group: 'org.junit.jupiter', name: 'junit-jupiter-engine', version: '5.6.0' 42 | testCompile group: 'org.mockito', name: 'mockito-core', version: '2.10.0' 43 | } 44 | 45 | task copyNativeDeps(type: Copy) { 46 | from configurations.dynamodb 47 | into "$project.buildDir/libs/" 48 | } 49 | 50 | test.dependsOn copyNativeDeps 51 | test.doFirst { 52 | systemProperty "java.library.path", 'build/libs' 53 | } 54 | 55 | test { 56 | useJUnitPlatform() 57 | } 58 | -------------------------------------------------------------------------------- /src/main/java/com/amazonaws/dagger/AWSClientModule.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.dagger; 2 | 3 | import com.amazonaws.ClientConfiguration; 4 | import com.amazonaws.regions.Regions; 5 | import com.amazonaws.retry.PredefinedRetryPolicies; 6 | import com.amazonaws.retry.RetryPolicy; 7 | import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; 8 | import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClientBuilder; 9 | import com.amazonaws.services.forecast.AmazonForecast; 10 | import 
com.amazonaws.services.forecast.AmazonForecastClientBuilder; 11 | import com.amazonaws.services.s3.AmazonS3; 12 | import com.amazonaws.services.s3.AmazonS3ClientBuilder; 13 | import dagger.Module; 14 | import dagger.Provides; 15 | 16 | import javax.inject.Singleton; 17 | 18 | @Module 19 | public class AWSClientModule { 20 | 21 | private static final int NUMBER_OF_RETRIES = 10; 22 | private static final RetryPolicy RETRY_POLICY = new RetryPolicy(PredefinedRetryPolicies.DEFAULT_RETRY_CONDITION, 23 | PredefinedRetryPolicies.DEFAULT_BACKOFF_STRATEGY, NUMBER_OF_RETRIES, false); 24 | private static final ClientConfiguration CLIENT_CONFIG = new ClientConfiguration().withRetryPolicy(RETRY_POLICY); 25 | 26 | @Provides 27 | @Singleton 28 | static AmazonDynamoDB provideDDBClient() { 29 | return AmazonDynamoDBClientBuilder.standard() 30 | .withClientConfiguration(CLIENT_CONFIG) 31 | .withRegion(Regions.fromName(System.getenv("AWS_REGION"))) 32 | .build(); 33 | } 34 | 35 | @Provides 36 | @Singleton 37 | static AmazonForecast provideForecastClient() { 38 | return AmazonForecastClientBuilder.standard() 39 | .withClientConfiguration(CLIENT_CONFIG) 40 | .withRegion(Regions.fromName(System.getenv("AWS_REGION"))) 41 | .build(); 42 | } 43 | 44 | @Provides 45 | @Singleton 46 | static AmazonS3 provideS3Client() { 47 | return AmazonS3ClientBuilder.standard() 48 | .withClientConfiguration(CLIENT_CONFIG) 49 | .withRegion(Regions.fromName(System.getenv("AWS_REGION"))) 50 | .build(); 51 | } 52 | } -------------------------------------------------------------------------------- /src/main/java/com/amazonaws/dagger/LambdaFunctionsComponent.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.dagger; 2 | 3 | import com.amazonaws.lambda.demandpublishing.PublishDemandHandler; 4 | import com.amazonaws.lambda.predictiongeneration.AbstractPredictionGenerationLambdaHandler; 5 | import 
com.amazonaws.lambda.predictiongeneration.GenerateForecastResourcesIdsCronHandler; 6 | import com.amazonaws.lambda.predictiongeneration.GenerateForecastResourcesIdsHandler; 7 | import com.amazonaws.lambda.queryingpredictionresult.LoadDataFromS3ToDynamoDBHandler; 8 | import dagger.Component; 9 | 10 | import javax.inject.Singleton; 11 | 12 | @Singleton 13 | @Component(modules = {AWSClientModule.class}) 14 | public interface LambdaFunctionsComponent { 15 | 16 | void inject(PublishDemandHandler handler); 17 | 18 | void inject(AbstractPredictionGenerationLambdaHandler handler); 19 | 20 | void inject(GenerateForecastResourcesIdsHandler handler); 21 | 22 | void inject(GenerateForecastResourcesIdsCronHandler handler); 23 | 24 | void inject(LoadDataFromS3ToDynamoDBHandler handler); 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/com/amazonaws/lambda/demandpublishing/DemandRecord.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.demandpublishing; 2 | 3 | import com.opencsv.bean.AbstractBeanField; 4 | import com.opencsv.bean.CsvBindByName; 5 | import com.opencsv.bean.CsvCustomBindByName; 6 | import lombok.AllArgsConstructor; 7 | import lombok.Builder; 8 | import lombok.Data; 9 | import lombok.NoArgsConstructor; 10 | 11 | import java.time.LocalDateTime; 12 | import java.time.format.DateTimeFormatter; 13 | import java.util.StringJoiner; 14 | 15 | import static com.amazonaws.lambda.demandpublishing.DemandRecord.LocalDateTimeConverter.FORECAST_DATE_TIME_FORMATTER; 16 | 17 | @Data 18 | @AllArgsConstructor 19 | @NoArgsConstructor 20 | @Builder 21 | public class DemandRecord { 22 | 23 | public static class Attribute { 24 | public static final String ITEM_ID = "item_id"; 25 | public static final String TIMESTAMP = "timestamp"; 26 | public static final String TARGET_VALUE = "target_value"; 27 | } 28 | 29 | public static class LocalDateTimeConverter extends 
AbstractBeanField { 30 | 31 | /* 32 | * Refer to: https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html#patterns, 33 | * uuuu: year 34 | * More details about using 'u' instead of 'y' are in: 35 | * https://stackoverflow.com/questions/41177442/uuuu-versus-yyyy-in-datetimeformatter-formatting-pattern-codes-in-java 36 | * MM: month-of-year 37 | * dd: day-of-year; ('DD' is the day-of-year) 38 | * HH: hour-of-day (0-23) ('hh' is the clock-hour-of-am-pm (1-12)) 39 | * mm: minute-of-hour 40 | * ss: second-of-minute 41 | */ 42 | public static final DateTimeFormatter FORECAST_DATE_TIME_FORMATTER = DateTimeFormatter.ofPattern("uuuu-MM-dd HH:mm:ss"); 43 | 44 | @Override 45 | protected LocalDateTime convert(String s) { 46 | /* 47 | * Refer to: https://docs.oracle.com/javase/8/docs/api/java/time/format/DateTimeFormatter.html#patterns, 48 | * uuuu: year 49 | * More details about using 'u' instead of 'y' are in: 50 | * https://stackoverflow.com/questions/41177442/uuuu-versus-yyyy-in-datetimeformatter-formatting-pattern-codes-in-java 51 | * MM: month-of-year 52 | * dd: day-of-year; ('DD' is the day-of-year) 53 | * HH: hour-of-day (0-23) ('hh' is the clock-hour-of-am-pm (1-12)) 54 | * mm: minute-of-hour 55 | * ss: second-of-minute 56 | */ 57 | return LocalDateTime.parse(s, FORECAST_DATE_TIME_FORMATTER); 58 | } 59 | } 60 | 61 | @CsvBindByName(column = Attribute.ITEM_ID, required = true) 62 | private String itemId; 63 | 64 | @CsvCustomBindByName(column = Attribute.TIMESTAMP, converter = LocalDateTimeConverter.class) 65 | private LocalDateTime timestamp; 66 | 67 | @CsvBindByName(column = Attribute.TARGET_VALUE, required = true) 68 | private String targetValue; 69 | 70 | public String toCsvRowString() { 71 | StringJoiner sj = new StringJoiner(","); 72 | sj.add(itemId).add(FORECAST_DATE_TIME_FORMATTER.format(timestamp)).add(targetValue); 73 | return sj.toString(); 74 | } 75 | } 76 | 77 | 78 | 79 | 
-------------------------------------------------------------------------------- /src/main/java/com/amazonaws/lambda/demandpublishing/PublishDemandHandler.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.demandpublishing; 2 | 3 | import com.amazonaws.dagger.DaggerLambdaFunctionsComponent; 4 | import com.amazonaws.services.lambda.runtime.Context; 5 | import com.amazonaws.services.lambda.runtime.RequestHandler; 6 | import com.amazonaws.services.s3.AmazonS3; 7 | import com.amazonaws.services.s3.model.ObjectMetadata; 8 | import com.amazonaws.services.s3.transfer.TransferManager; 9 | import com.amazonaws.services.s3.transfer.TransferManagerBuilder; 10 | import com.google.common.annotations.VisibleForTesting; 11 | import com.google.common.collect.Lists; 12 | import com.opencsv.bean.CsvToBean; 13 | import com.opencsv.bean.CsvToBeanBuilder; 14 | import lombok.NonNull; 15 | import lombok.extern.slf4j.Slf4j; 16 | import org.apache.commons.io.IOUtils; 17 | 18 | import javax.inject.Inject; 19 | import java.io.BufferedReader; 20 | import java.io.InputStreamReader; 21 | import java.nio.charset.StandardCharsets; 22 | import java.time.Clock; 23 | import java.time.LocalDateTime; 24 | import java.util.List; 25 | import java.util.StringJoiner; 26 | import java.util.stream.Collectors; 27 | 28 | @Slf4j 29 | public class PublishDemandHandler implements RequestHandler { 30 | 31 | private static final int YEAR_IN_DEMONSTRATION_FILE = 2020; 32 | private static final int LOOK_BACK_DURATION_IN_DAYS = 60; 33 | private static final String HISTORICAL_DEMAND_FILE_HEADER = "item_id,timestamp,target_value"; 34 | private static final String PREDICTION_S3_BUCKET_NAME = System.getenv("PREDICTION_S3_BUCKET_NAME"); 35 | private static final String PREDICTION_S3_HISTORICAL_DEMAND_FILE_KEY = 36 | String.format("%s/%s", System.getenv("SRC_S3_FOLDER"), System.getenv("S3_TRAINING_DATA_FILE_NAME")); 37 | 38 | private final Clock clock; 39 
// Dagger-injected S3 client; left unset by the @VisibleForTesting constructor, which skips injection.
    @Inject
    @NonNull
    AmazonS3 s3Client;

    // Classpath resource path of the CSV that stands in for the historical demand data source.
    private final String rawDemandRequestsFilePath;

    // Uploads the assembled historical demand CSV to S3.
    private final TransferManager s3TransferManager;

    /**
     * Creates a handler backed by the system UTC clock.
     */
    public PublishDemandHandler() {
        this(Clock.systemUTC());
    }

    /**
     * Creates a handler that reads the bundled {@code /raw_demand_requests.csv} resource and uploads
     * through a {@link TransferManager} built on the Dagger-injected {@link AmazonS3} client.
     *
     * @param clock clock used to determine the end of the look-back window
     */
    public PublishDemandHandler(final Clock clock) {
        this.clock = clock;
        this.rawDemandRequestsFilePath = "/raw_demand_requests.csv";
        DaggerLambdaFunctionsComponent.create().inject(this);
        s3TransferManager = TransferManagerBuilder.standard().withS3Client(s3Client).build();
    }

    /**
     * Test-only constructor. Skips Dagger injection, so {@code s3Client} is not populated and the
     * upload goes through the supplied {@code transferManager}.
     *
     * @param clock                     clock used to determine the end of the look-back window
     * @param rawDemandRequestsFilePath classpath resource path of the raw demand CSV
     * @param transferManager           transfer manager used for the S3 upload
     */
    @VisibleForTesting
    PublishDemandHandler(final Clock clock,
                         final String rawDemandRequestsFilePath,
                         final TransferManager transferManager) {
        this.clock = clock;
        this.rawDemandRequestsFilePath = rawDemandRequestsFilePath;
        this.s3TransferManager = transferManager;
    }

    /**
     * Lambda entry point: collects the historical demand records inside the look-back window and
     * uploads them to S3 as a CSV file.
     *
     * @param input   unused
     * @param context Lambda context (unused)
     * @return always {@code null}
     */
    @Override
    public Void handleRequest(final Void input, Context context) {

        List<DemandRecord> historicalDemandRecords = getHistoricalDemandRecords();
        log.info(String.format("Fetched [%d] historical demand records", historicalDemandRecords.size()));

        uploadHistoricalDemandToS3(historicalDemandRecords);

        return null;
    }

    /**
     * Query data source to get the historical demand records.
     * For demonstration purposes, a local CSV file mimics the data source;
     * in a real production environment you would query your database (e.g. RDS) for this information.
     *
     * <p>Only records whose timestamp lies strictly between
     * {@code end - LOOK_BACK_DURATION_IN_DAYS} and {@code end} are returned. {@code end} is "now" from
     * {@link #clock}, moved into {@code YEAR_IN_DEMONSTRATION_FILE} when the clock is past that year.
     *
     * @return a list of historical demand record {@link DemandRecord}
     * @throws IllegalStateException        if the CSV resource cannot be found on the classpath
     * @throws java.io.UncheckedIOException if closing the CSV resource fails
     */
    private List<DemandRecord> getHistoricalDemandRecords() {
        final List<DemandRecord> demandRecords = readAllDemandRecords();

        LocalDateTime currentTime = LocalDateTime.now(clock);

        final LocalDateTime predictionWindowEndTime;
        /*
         * As the demonstration csv file only contains data for year 2020,
         * if someone runs this sample code in the future, we need to normalize the timestamp to a time in 2020.
         */
        if (YEAR_IN_DEMONSTRATION_FILE < currentTime.getYear()) {
            log.info(String.format("currentTime [%s] is after year 2020, normalizing it", currentTime));
            // Same result as LocalDateTime.of(year, month, day, h, m, s) for valid dates, but withYear
            // clamps Feb 29 to Feb 28 instead of throwing if the demo year is not a leap year.
            predictionWindowEndTime = currentTime.withYear(YEAR_IN_DEMONSTRATION_FILE).withNano(0);
            log.info(String.format("predictionWindowEndTime [%s] after normalization", predictionWindowEndTime));
        } else {
            predictionWindowEndTime = currentTime;
        }

        final LocalDateTime predictionWindowStartTime = predictionWindowEndTime.minusDays(LOOK_BACK_DURATION_IN_DAYS);
        log.info(String.format("Use lookback period [%s - %s] for fetching the historical demand records",
                predictionWindowStartTime, predictionWindowEndTime));

        return demandRecords.stream()
                .filter(record -> record.getTimestamp().isAfter(predictionWindowStartTime)
                        && record.getTimestamp().isBefore(predictionWindowEndTime))
                .collect(Collectors.toList());
    }

    /**
     * Parses every row of the raw demand CSV resource into {@link DemandRecord}s, closing the
     * underlying reader before returning.
     *
     * @return all parsed records, in file order
     * @throws IllegalStateException        if the resource is not on the classpath
     * @throws java.io.UncheckedIOException if closing the reader fails
     */
    private List<DemandRecord> readAllDemandRecords() {
        final java.io.InputStream rawRequestsStream = getClass().getResourceAsStream(rawDemandRequestsFilePath);
        if (rawRequestsStream == null) {
            throw new IllegalStateException(
                    String.format("Cannot find raw demand requests resource [%s]", rawDemandRequestsFilePath));
        }
        try (BufferedReader rawRequestsReader = new BufferedReader(
                new InputStreamReader(rawRequestsStream, StandardCharsets.UTF_8))) {
            final CsvToBean<DemandRecord> csvToBean = new CsvToBeanBuilder<DemandRecord>(rawRequestsReader)
                    .withType(DemandRecord.class)
                    .withIgnoreLeadingWhiteSpace(true)
                    .build();
            // Materialize eagerly: the bean iterator pulls from the reader, which is closed on exit.
            return Lists.newArrayList(csvToBean.iterator());
        } catch (java.io.IOException e) {
            throw new java.io.UncheckedIOException(e);
        }
    }

    /**
     * Serializes the records to CSV and uploads them to
     * {@code PREDICTION_S3_BUCKET_NAME/PREDICTION_S3_HISTORICAL_DEMAND_FILE_KEY}, blocking until done.
     * If the waiting thread is interrupted, the interrupt flag is restored and the success log is skipped.
     *
     * @param demandRecords records to upload (a header row is always written)
     */
    private void uploadHistoricalDemandToS3(final List<DemandRecord> demandRecords) {
        String demandRecordsListCsvStr = convertListOfDemandRecordToString(demandRecords);
        // S3 expects the content length in bytes of the encoded payload; String.length() counts
        // UTF-16 code units and under-reports for any non-ASCII content.
        long demandCsvFileSize = demandRecordsListCsvStr.getBytes(StandardCharsets.UTF_8).length;

        ObjectMetadata metadata = new ObjectMetadata();
        metadata.setContentLength(demandCsvFileSize);
        try {
            s3TransferManager.upload(PREDICTION_S3_BUCKET_NAME, PREDICTION_S3_HISTORICAL_DEMAND_FILE_KEY,
                    IOUtils.toInputStream(demandRecordsListCsvStr, StandardCharsets.UTF_8), metadata)
                    .waitForCompletion();
        } catch (InterruptedException e) {
            log.warn("Got InterruptedException while uploading the data to S3", e);
            // Preserve the interrupt for the caller and don't claim the upload finished.
            Thread.currentThread().interrupt();
            return;
        }
        log.info("Finished uploading the historical demand data to S3");
    }

    /**
     * Renders the records as CSV text: {@code HISTORICAL_DEMAND_FILE_HEADER} followed by one
     * {@link DemandRecord#toCsvRowString()} row per record, joined with {@code \n} (no trailing newline).
     */
    private String convertListOfDemandRecordToString(final List<DemandRecord> demandRecords) {
        StringJoiner sj = new StringJoiner("\n");
        sj.add(HISTORICAL_DEMAND_FILE_HEADER);
        for (DemandRecord demandRecord : demandRecords) {
            sj.add(demandRecord.toCsvRowString());
        }
        return sj.toString();
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/predictiongeneration/AbstractPredictionGenerationLambdaHandler.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration;

import com.amazonaws.dagger.DaggerLambdaFunctionsComponent;
import com.amazonaws.lambda.predictiongeneration.exception.ResourceSetupFailureException;
import com.amazonaws.lambda.predictiongeneration.exception.ResourceSetupInProgressException;
import com.amazonaws.services.forecast.AmazonForecast;
import com.amazonaws.services.forecast.model.DatasetSummary;
import com.amazonaws.services.forecast.model.PredictorSummary;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;

import javax.inject.Inject;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.RESOURCE_ACTIVE_STATUS;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.RESOURCE_FAILED_STATUS;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.listDatasets;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.listPredictors;

/**
 * Base class for the prediction-generation Lambda handlers. Parses the JSON input into a
 * resource-id map, delegates to {@link #process(Map)}, and echoes the input back as output.
 */
@Slf4j
public abstract class AbstractPredictionGenerationLambdaHandler implements RequestHandler<String, String> {

    // Will be used for CreateDataset and CreateDatasetGroup APIs
    protected static final String DOMAIN = "CUSTOM";

    @Inject
    @NonNull
    protected AmazonForecast forecastClient;

    /** Production constructor: the Forecast client is injected by Dagger. */
    AbstractPredictionGenerationLambdaHandler() {
        DaggerLambdaFunctionsComponent.create().inject(this);
    }

    /** Constructor with an explicit Forecast client (no Dagger injection). */
    AbstractPredictionGenerationLambdaHandler(final AmazonForecast forecastClient) {
        this.forecastClient = forecastClient;
    }

    /**
     * Deserializes {@code input} as a JSON object of string keys to string values, passes it to
     * {@link #process(Map)}, and returns {@code input} unchanged.
     *
     * @param input   JSON object whose values are all strings
     * @param context Lambda context (unused)
     * @return the original {@code input}
     * @throws RuntimeException if {@code input} cannot be parsed; the parse error is kept as the cause
     */
    @Override
    public String handleRequest(final String input, Context context) {
        Map<String, String> resourceIdMap;
        try {
            resourceIdMap = new ObjectMapper().readValue(input, new TypeReference<Map<String, String>>() {
            });
        } catch (IOException e) {
            String errorMsg = e.getMessage();
            log.error(errorMsg, e);
            // Chain the cause so the JSON parse location/stack trace is not lost.
            throw new RuntimeException(errorMsg, e);
        }
        process(resourceIdMap);
        return input;
    }

    /**
     * Handler-specific work for one invocation.
     *
     * @param resourceIdMap resource names/ARNs parsed from the Lambda input
     */
    abstract void process(Map<String, String> resourceIdMap);

    /**
     * Translates a Forecast resource status into a return value or an exception.
     *
     * @param resourceStatus status string returned by a Forecast describe call
     * @param resourceType   resource type label, used only in log/exception messages
     * @param resourceName   resource name or ARN, used only in log/exception messages
     * @return true if status is ACTIVE, which indicates the resource setup finished successfully
     * @throws ResourceSetupFailureException    if resourceStatus is FAILED
     * @throws ResourceSetupInProgressException for any other status
     */
    protected boolean takeActionByResourceStatus(final String resourceStatus,
                                                 final String resourceType,
                                                 final String resourceName)
            throws ResourceSetupFailureException, ResourceSetupInProgressException {

        switch (resourceStatus) {
            case RESOURCE_ACTIVE_STATUS:
                log.info(String.format("Successfully created %s %s: [%s]", RESOURCE_ACTIVE_STATUS, resourceType, resourceName));
                return true;

            case RESOURCE_FAILED_STATUS:
                throw new ResourceSetupFailureException(String.format("%s: [%s] setup failed.", resourceType, resourceName));

            default:
                throw new ResourceSetupInProgressException(
                        String.format("%s: [%s] setup is in progress with current status [%s]",
                                resourceType, resourceName, resourceStatus));
        }
    }

    /**
     * @param currentDatasetArn ARN to keep
     * @return ARNs of all existing datasets except {@code currentDatasetArn}
     */
    protected List<String> listOutdatedDatasetArns(final String currentDatasetArn) {
        List<String> existingDatasetArns = listDatasetArns();
        existingDatasetArns.remove(currentDatasetArn);
        return existingDatasetArns;
    }

    private List<String> listDatasetArns() {
        return listDatasets(forecastClient).stream().map(DatasetSummary::getDatasetArn).collect(Collectors.toList());
    }

    /**
     * @param currentPredictorArn ARN to keep
     * @return ARNs of all existing predictors except {@code currentPredictorArn}
     */
    protected List<String> listOutdatedPredictorArns(final String currentPredictorArn) {
        List<String> existingPredictorArns = listPredictorArns();
        existingPredictorArns.remove(currentPredictorArn);
        return existingPredictorArns;
    }

    private List<String> listPredictorArns() {
        List<PredictorSummary> existingPredictors = listPredictors(forecastClient);
        return existingPredictors.stream().map(PredictorSummary::getPredictorArn).collect(Collectors.toList());
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/predictiongeneration/CreateDatasetGroupHandler.java:
-------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.predictiongeneration; 2 | 3 | import com.amazonaws.services.forecast.AmazonForecast; 4 | import com.amazonaws.services.forecast.model.CreateDatasetGroupRequest; 5 | import com.amazonaws.services.forecast.model.ResourceAlreadyExistsException; 6 | import lombok.extern.slf4j.Slf4j; 7 | 8 | import java.util.Collections; 9 | import java.util.Map; 10 | 11 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_ARN_KEY; 12 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_GROUP_NAME_KEY; 13 | 14 | @Slf4j 15 | public class CreateDatasetGroupHandler extends AbstractPredictionGenerationLambdaHandler { 16 | 17 | private static final String DATASET_GROUP_RESOURCE_TYPE = "datasetGroup"; 18 | 19 | public CreateDatasetGroupHandler() { 20 | super(); 21 | } 22 | 23 | public CreateDatasetGroupHandler(final AmazonForecast forecastClient) { 24 | super(forecastClient); 25 | } 26 | 27 | @Override 28 | public void process(final Map resourceIdMap) { 29 | 30 | String datasetArn = resourceIdMap.get(DATASET_ARN_KEY); 31 | String datasetGroupName = resourceIdMap.get(DATASET_GROUP_NAME_KEY); 32 | log.info(String.format("The datasetArn and %s getting from resourceIdMap are [%s] and [%s]", 33 | DATASET_GROUP_RESOURCE_TYPE, datasetArn, datasetGroupName)); 34 | 35 | /* 36 | * Create the datasetGroup, since this API call is synchronized, 37 | * once it returns 200, we know the corresponding datasetGroup is created successfully. 38 | * So we don't need to describe for the datasetGroup status. 39 | * Also, refer to: https://docs.aws.amazon.com/forecast/latest/dg/API_DescribeDatasetGroup.html 40 | * datasetGroup doesn't even have the 'Status' attribute. 
41 | */ 42 | try { 43 | createDatasetGroup(datasetArn, datasetGroupName, DOMAIN); 44 | } catch (ResourceAlreadyExistsException e) { 45 | log.info(String.format("The %s [%s] already exists.", DATASET_GROUP_RESOURCE_TYPE, datasetGroupName)); 46 | } 47 | 48 | log.info(String.format("Successfully setup the %s %s", DATASET_GROUP_RESOURCE_TYPE, datasetGroupName)); 49 | } 50 | 51 | private void createDatasetGroup(final String datasetArn, 52 | final String datasetGroupName, 53 | final String domain) { 54 | CreateDatasetGroupRequest createDatasetGroupRequest = new CreateDatasetGroupRequest(); 55 | createDatasetGroupRequest.setDatasetArns(Collections.singletonList(datasetArn)); 56 | createDatasetGroupRequest.setDatasetGroupName(datasetGroupName); 57 | createDatasetGroupRequest.setDomain(domain); 58 | forecastClient.createDatasetGroup(createDatasetGroupRequest); 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/main/java/com/amazonaws/lambda/predictiongeneration/CreateDatasetHandler.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.predictiongeneration; 2 | 3 | import com.amazonaws.services.forecast.AmazonForecast; 4 | import com.amazonaws.services.forecast.model.CreateDatasetRequest; 5 | import com.amazonaws.services.forecast.model.DescribeDatasetRequest; 6 | import com.amazonaws.services.forecast.model.ResourceNotFoundException; 7 | import com.amazonaws.services.forecast.model.Schema; 8 | import com.amazonaws.services.forecast.model.SchemaAttribute; 9 | import com.google.common.collect.Lists; 10 | import lombok.extern.slf4j.Slf4j; 11 | 12 | import java.util.List; 13 | import java.util.Map; 14 | 15 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_ARN_KEY; 16 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_NAME_KEY; 17 | import static 
com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATA_FREQUENCY_KEY; 18 | 19 | @Slf4j 20 | public class CreateDatasetHandler extends AbstractPredictionGenerationLambdaHandler { 21 | 22 | private static final String DATASET_TYPE = "TARGET_TIME_SERIES"; 23 | private static final String DATASET_RESOURCE_TYPE = "dataset"; 24 | 25 | public CreateDatasetHandler() { 26 | super(); 27 | } 28 | 29 | public CreateDatasetHandler(AmazonForecast forecastClient) { 30 | super(forecastClient); 31 | } 32 | 33 | @Override 34 | public void process(final Map resourceIdMap) { 35 | 36 | final String datasetName = resourceIdMap.get(DATASET_NAME_KEY); 37 | final String datasetArn = resourceIdMap.get(DATASET_ARN_KEY); 38 | final String dataFrequency = resourceIdMap.get(DATA_FREQUENCY_KEY); 39 | log.info(String.format("The %s and dataFrequency getting from resourceIdMap are [%s] and [%s]", 40 | DATASET_RESOURCE_TYPE, datasetArn, dataFrequency)); 41 | 42 | // Check if dataset exists 43 | try { 44 | String currentStatus = describeDatasetStatus(datasetArn); 45 | if (takeActionByResourceStatus(currentStatus, DATASET_RESOURCE_TYPE, datasetArn)) { 46 | return; 47 | } 48 | } catch (ResourceNotFoundException e) { 49 | log.info(String.format("Cannot find %s with arn [%s]. 
Proceed to create a new one", DATASET_RESOURCE_TYPE, datasetArn)); 50 | } 51 | 52 | // Create the dataset if found no matching dataset name 53 | createDataset(DOMAIN, DATASET_TYPE, datasetName, dataFrequency); 54 | log.info("finish triggering CreateDatasetCall."); 55 | 56 | String newStatus = describeDatasetStatus(datasetArn); 57 | takeActionByResourceStatus(newStatus, DATASET_RESOURCE_TYPE, datasetArn); 58 | } 59 | 60 | private void createDataset(final String domain, 61 | final String datasetType, 62 | final String datasetName, 63 | final String dataFrequency) { 64 | CreateDatasetRequest createDatasetRequest = new CreateDatasetRequest(); 65 | createDatasetRequest.setDomain(domain); 66 | createDatasetRequest.setDatasetType(datasetType); 67 | createDatasetRequest.setDatasetName(datasetName); 68 | createDatasetRequest.setDataFrequency(dataFrequency); 69 | 70 | // schema configuration 71 | List schemaAttributes = Lists.newArrayList( 72 | 73 | // Refer to https://docs.aws.amazon.com/forecast/latest/dg/API_CreateDataset.html#forecast-CreateDataset-request-Schema 74 | // The schema attributes and their order must match the fields in your training data file. 
75 | new SchemaAttribute().withAttributeName("item_id").withAttributeType("string"), 76 | new SchemaAttribute().withAttributeName("timestamp").withAttributeType("timestamp"), 77 | new SchemaAttribute().withAttributeName("target_value").withAttributeType("integer") 78 | ); 79 | Schema schema = new Schema().withAttributes(schemaAttributes); 80 | createDatasetRequest.setSchema(schema); 81 | forecastClient.createDataset(createDatasetRequest); 82 | } 83 | 84 | private String describeDatasetStatus(final String datasetArn) { 85 | DescribeDatasetRequest describeDatasetRequest = new DescribeDatasetRequest(); 86 | describeDatasetRequest.setDatasetArn(datasetArn); 87 | return forecastClient.describeDataset(describeDatasetRequest).getStatus(); 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /src/main/java/com/amazonaws/lambda/predictiongeneration/CreateDatasetImportJobHandler.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.predictiongeneration; 2 | 3 | import com.amazonaws.services.forecast.AmazonForecast; 4 | import com.amazonaws.services.forecast.model.CreateDatasetImportJobRequest; 5 | import com.amazonaws.services.forecast.model.DataSource; 6 | import com.amazonaws.services.forecast.model.DescribeDatasetImportJobRequest; 7 | import com.amazonaws.services.forecast.model.ResourceNotFoundException; 8 | import com.amazonaws.services.forecast.model.S3Config; 9 | import lombok.extern.slf4j.Slf4j; 10 | 11 | import java.util.Map; 12 | 13 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_ARN_KEY; 14 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_IMPORT_JOB_ARN_KEY; 15 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_IMPORT_JOB_NAME_KEY; 16 | 17 | @Slf4j 18 | public class CreateDatasetImportJobHandler extends 
AbstractPredictionGenerationLambdaHandler { 19 | 20 | private static final String FORECAST_IMPORT_TRAINING_DATA_ROLE_ARN; 21 | static { 22 | String forecastImportTrainingDataRoleArn = System.getenv("FORECAST_IMPORT_TRAINING_DATA_ROLE_ARN"); 23 | log.info(String.format("forecastImportTrainingDataRoleArn getting from environment variable is [%s]", 24 | forecastImportTrainingDataRoleArn)); 25 | FORECAST_IMPORT_TRAINING_DATA_ROLE_ARN = forecastImportTrainingDataRoleArn; 26 | } 27 | private static final String TIMESTAMP_FORMAT = "yyyy-MM-dd HH:mm:ss"; 28 | private static final String FORECAST_TRAINING_DATA_S3_URI; 29 | static { 30 | String s3TrainingDataBucket = System.getenv("PREDICTION_S3_BUCKET_NAME"); 31 | String s3TrainingDataFolder = System.getenv("SRC_S3_FOLDER"); 32 | String s3TrainingDataFileName = System.getenv("S3_TRAINING_DATA_FILE_NAME"); 33 | String forecastTrainingDataS3Uri = String.format("s3://%s/%s/%s", 34 | s3TrainingDataBucket, 35 | s3TrainingDataFolder, 36 | s3TrainingDataFileName); 37 | log.info(String.format("The forecastTrainingDataS3Uri getting from env variables is %s", 38 | forecastTrainingDataS3Uri)); 39 | FORECAST_TRAINING_DATA_S3_URI = forecastTrainingDataS3Uri; 40 | } 41 | private static final String DATASET_IMPORT_JOB_RESOURCE_TYPE = "datasetImportJob"; 42 | 43 | public CreateDatasetImportJobHandler() { 44 | super(); 45 | } 46 | 47 | public CreateDatasetImportJobHandler(final AmazonForecast forecastClient) { 48 | super(forecastClient); 49 | } 50 | 51 | @Override 52 | public void process(final Map resourceIdMap) { 53 | String datasetArn = resourceIdMap.get(DATASET_ARN_KEY); 54 | String datasetImportJobName = resourceIdMap.get(DATASET_IMPORT_JOB_NAME_KEY); 55 | String datasetImportJobArn = resourceIdMap.get(DATASET_IMPORT_JOB_ARN_KEY); 56 | log.info(String.format( 57 | "The datasetArn, datasetImportJobName, and datasetImportJobArn getting from resourceIdMap are [%s], [%s], and [%s]", 58 | datasetArn, datasetImportJobName, 
datasetImportJobArn)); 59 | 60 | // Check if dataset import job exists 61 | try { 62 | String currentStatus = describeDatasetImportJobStatus(datasetImportJobArn); 63 | if (takeActionByResourceStatus(currentStatus, DATASET_IMPORT_JOB_RESOURCE_TYPE, datasetImportJobArn)) { 64 | return; 65 | } 66 | } catch (ResourceNotFoundException e) { 67 | log.info(String.format("Cannot find %s, %s. Proceed to create a new one", 68 | DATASET_IMPORT_JOB_RESOURCE_TYPE, datasetImportJobArn)); 69 | } 70 | 71 | // Create the dataset import job if found no import job for given dataset name 72 | createDatasetImportJob(datasetImportJobName, 73 | datasetArn, 74 | FORECAST_TRAINING_DATA_S3_URI, 75 | FORECAST_IMPORT_TRAINING_DATA_ROLE_ARN, 76 | TIMESTAMP_FORMAT); 77 | log.info("finish triggering CreateDatasetImportJobCall."); 78 | 79 | String newStatus = describeDatasetImportJobStatus(datasetImportJobArn); 80 | takeActionByResourceStatus(newStatus, DATASET_IMPORT_JOB_RESOURCE_TYPE, datasetImportJobArn); 81 | } 82 | 83 | private void createDatasetImportJob(final String datasetImportJobName, 84 | final String datasetArn, 85 | final String s3Uri, 86 | final String roleArn, 87 | final String timestampFormat) { 88 | CreateDatasetImportJobRequest createDatasetImportJobRequest = new CreateDatasetImportJobRequest(); 89 | createDatasetImportJobRequest.setDatasetImportJobName(datasetImportJobName); 90 | createDatasetImportJobRequest.setDatasetArn(datasetArn); 91 | createDatasetImportJobRequest.setDataSource( 92 | new DataSource().withS3Config( 93 | new S3Config().withPath(s3Uri).withRoleArn(roleArn)) 94 | ); 95 | createDatasetImportJobRequest.setTimestampFormat(timestampFormat); 96 | forecastClient.createDatasetImportJob(createDatasetImportJobRequest); 97 | } 98 | 99 | private String describeDatasetImportJobStatus(final String dataseImportJobArn) { 100 | DescribeDatasetImportJobRequest describeDatasetImportJobRequest = new DescribeDatasetImportJobRequest(); 101 | 
describeDatasetImportJobRequest.setDatasetImportJobArn(dataseImportJobArn); 102 | return forecastClient.describeDatasetImportJob(describeDatasetImportJobRequest).getStatus(); 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /src/main/java/com/amazonaws/lambda/predictiongeneration/CreateForecastExportJobHandler.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.predictiongeneration; 2 | 3 | import com.amazonaws.services.forecast.AmazonForecast; 4 | import com.amazonaws.services.forecast.model.CreateForecastExportJobRequest; 5 | import com.amazonaws.services.forecast.model.DataDestination; 6 | import com.amazonaws.services.forecast.model.DescribeForecastExportJobRequest; 7 | import com.amazonaws.services.forecast.model.ResourceNotFoundException; 8 | import com.amazonaws.services.forecast.model.S3Config; 9 | import lombok.extern.slf4j.Slf4j; 10 | 11 | import java.util.Map; 12 | 13 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_ARN_KEY; 14 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_EXPORT_JOB_ARN_KEY; 15 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_EXPORT_JOB_NAME_KEY; 16 | 17 | @Slf4j 18 | public class CreateForecastExportJobHandler extends AbstractPredictionGenerationLambdaHandler { 19 | 20 | private static final String FORECAST_EXPORT_JOB_RESOURCE_TYPE = "forecastExportJob"; 21 | private static final String FORECAST_EXPORT_RESULT_ROLE_ARN; 22 | static { 23 | String forecastExportResultRoleArn = System.getenv("FORECAST_EXPORT_RESULT_ROLE_ARN"); 24 | log.info(String.format("forecastExportResultRoleArn getting from environment variable is [%s]", forecastExportResultRoleArn)); 25 | FORECAST_EXPORT_RESULT_ROLE_ARN = forecastExportResultRoleArn; 26 | } 27 | private static final String 
FORECAST_EXPORT_RESULT_S3_URI; 28 | static { 29 | String s3ExportResultBucket = System.getenv("PREDICTION_S3_BUCKET_NAME"); 30 | String s3ExportResultFolder = System.getenv("TGT_S3_FOLDER"); 31 | String forecastExportResultS3Uri = String.format("s3://%s/%s", s3ExportResultBucket, s3ExportResultFolder); 32 | log.info(String.format("The forecastExportResultS3Uri getting from env variables is %s", 33 | forecastExportResultS3Uri)); 34 | FORECAST_EXPORT_RESULT_S3_URI = forecastExportResultS3Uri; 35 | } 36 | 37 | public CreateForecastExportJobHandler() { 38 | super(); 39 | } 40 | 41 | public CreateForecastExportJobHandler(AmazonForecast forecastClient) { 42 | super(forecastClient); 43 | } 44 | 45 | @Override 46 | public void process(final Map resourceIdMap) { 47 | 48 | String forecastExportJobName = resourceIdMap.get(FORECAST_EXPORT_JOB_NAME_KEY); 49 | String forecastExportJobArn = resourceIdMap.get(FORECAST_EXPORT_JOB_ARN_KEY); 50 | String forecastArn = resourceIdMap.get(FORECAST_ARN_KEY); 51 | log.info(String.format( 52 | "The forecastExportJobName, forecastExportJobArn, and forecastArn getting from resourceIdMap are [%s], [%s], and [%s]", 53 | forecastExportJobName, forecastExportJobArn, forecastArn)); 54 | 55 | // Check if forecastExportJob exists 56 | try { 57 | String currentStatus = describeForecastExportJobStatus(forecastExportJobArn); 58 | if (takeActionByResourceStatus(currentStatus, FORECAST_EXPORT_JOB_RESOURCE_TYPE, forecastExportJobArn)) { 59 | return; 60 | } 61 | } catch (ResourceNotFoundException e) { 62 | log.info(String.format("Cannot find %s with arn [%s]. 
Proceed to create a new one", 63 | FORECAST_EXPORT_JOB_RESOURCE_TYPE, forecastExportJobArn)); 64 | } 65 | 66 | // create a new forecastExportJob 67 | createForecastExportJob(forecastExportJobName, forecastArn, FORECAST_EXPORT_RESULT_ROLE_ARN, FORECAST_EXPORT_RESULT_S3_URI); 68 | log.info("finish triggering CreateForecastExportJobCall."); 69 | 70 | String newStatus = describeForecastExportJobStatus(forecastExportJobArn); 71 | takeActionByResourceStatus(newStatus, FORECAST_EXPORT_JOB_RESOURCE_TYPE, forecastExportJobArn); 72 | } 73 | 74 | private String describeForecastExportJobStatus(final String forecastExportJobArn) { 75 | DescribeForecastExportJobRequest describeForecastExportJobRequest = new DescribeForecastExportJobRequest(); 76 | describeForecastExportJobRequest.setForecastExportJobArn(forecastExportJobArn); 77 | return forecastClient.describeForecastExportJob(describeForecastExportJobRequest).getStatus(); 78 | } 79 | 80 | private void createForecastExportJob(final String forecastExportJobName, 81 | final String forecastArn, 82 | final String roleArn, 83 | final String s3Uri) { 84 | CreateForecastExportJobRequest createForecastExportJobRequest = new CreateForecastExportJobRequest(); 85 | createForecastExportJobRequest.setForecastExportJobName(forecastExportJobName); 86 | createForecastExportJobRequest.setDestination( 87 | new DataDestination().withS3Config( 88 | new S3Config().withPath(s3Uri).withRoleArn(roleArn) 89 | ) 90 | ); 91 | createForecastExportJobRequest.setForecastArn(forecastArn); 92 | 93 | forecastClient.createForecastExportJob(createForecastExportJobRequest); 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /src/main/java/com/amazonaws/lambda/predictiongeneration/CreateForecastHandler.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.predictiongeneration; 2 | 3 | import com.amazonaws.services.forecast.AmazonForecast; 4 | import 
com.amazonaws.services.forecast.model.CreateForecastRequest; 5 | import com.amazonaws.services.forecast.model.DescribeForecastRequest; 6 | import com.amazonaws.services.forecast.model.ResourceNotFoundException; 7 | import lombok.extern.slf4j.Slf4j; 8 | 9 | import java.util.Map; 10 | 11 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_ARN_KEY; 12 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_NAME_KEY; 13 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.PREDICTOR_ARN_KEY; 14 | 15 | @Slf4j 16 | public class CreateForecastHandler extends AbstractPredictionGenerationLambdaHandler { 17 | 18 | private static final String FORECAST_RESOURCE_TYPE = "forecast"; 19 | 20 | public CreateForecastHandler() { 21 | super(); 22 | } 23 | 24 | public CreateForecastHandler(final AmazonForecast forecastClient) { 25 | super(forecastClient); 26 | } 27 | 28 | @Override 29 | public void process(final Map resourceIdMap) { 30 | 31 | String forecastName = resourceIdMap.get(FORECAST_NAME_KEY); 32 | String forecastArn = resourceIdMap.get(FORECAST_ARN_KEY); 33 | String predictorArn = resourceIdMap.get(PREDICTOR_ARN_KEY); 34 | log.info(String.format( 35 | "The forecastName, forecastArn, and predictorArn getting from resourceIdMap are [%s], [%s], and [%s]", 36 | forecastName, forecastArn, predictorArn)); 37 | 38 | // Check if forecast exists 39 | try { 40 | String currentStatus = describeForecastStatus(forecastArn); 41 | if (takeActionByResourceStatus(currentStatus, FORECAST_RESOURCE_TYPE, forecastArn)) { 42 | return; 43 | } 44 | } catch (ResourceNotFoundException e) { 45 | log.info(String.format("Cannot find %s with arn [%s]. 
Proceed to create a new one", 46 | FORECAST_RESOURCE_TYPE, forecastArn)); 47 | } 48 | 49 | // create a new forecast 50 | createForecast(forecastName, predictorArn); 51 | log.info("finish triggering CreateForecastCall."); 52 | 53 | String newStatus = describeForecastStatus(forecastArn); 54 | takeActionByResourceStatus(newStatus, FORECAST_RESOURCE_TYPE, forecastArn); 55 | } 56 | 57 | private void createForecast(final String forecastName, 58 | final String predictorArn) { 59 | CreateForecastRequest createForecastRequest = new CreateForecastRequest(); 60 | createForecastRequest.setForecastName(forecastName); 61 | createForecastRequest.setPredictorArn(predictorArn); 62 | 63 | forecastClient.createForecast(createForecastRequest); 64 | } 65 | 66 | private String describeForecastStatus(final String forecastArn) { 67 | DescribeForecastRequest describeForecastRequest = new DescribeForecastRequest(); 68 | describeForecastRequest.setForecastArn(forecastArn); 69 | 70 | return forecastClient.describeForecast(describeForecastRequest).getStatus(); 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /src/main/java/com/amazonaws/lambda/predictiongeneration/CreatePredictorHandler.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.predictiongeneration; 2 | 3 | import com.amazonaws.services.forecast.AmazonForecast; 4 | import com.amazonaws.services.forecast.model.CreatePredictorRequest; 5 | import com.amazonaws.services.forecast.model.DescribePredictorRequest; 6 | import com.amazonaws.services.forecast.model.FeaturizationConfig; 7 | import com.amazonaws.services.forecast.model.InputDataConfig; 8 | import com.amazonaws.services.forecast.model.ResourceNotFoundException; 9 | import com.google.common.annotations.VisibleForTesting; 10 | import lombok.extern.slf4j.Slf4j; 11 | import org.apache.commons.lang3.StringUtils; 12 | 13 | import java.util.Map; 14 | 15 | import static 
com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_GROUP_ARN_KEY; 16 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATA_FREQUENCY_KEY; 17 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATA_FREQUENCY_SECONDS_MAPPING; 18 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.PREDICTOR_ARN_KEY; 19 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.PREDICTOR_NAME_KEY; 20 | 21 | @Slf4j 22 | public class CreatePredictorHandler extends AbstractPredictionGenerationLambdaHandler { 23 | 24 | private static final String PREDICTOR_RESOURCE_TYPE = "predictor"; 25 | private static final String FORECAST_PREDICTOR_ALGORITHM_ARN; 26 | static { 27 | String forecastPredictorAlgorithmArn = System.getenv("FORECAST_PREDICTOR_ALGORITHM_ARN"); 28 | log.info(String.format("forecastPredictorAlgorithmArn getting from environment variable is [%s]", forecastPredictorAlgorithmArn)); 29 | FORECAST_PREDICTOR_ALGORITHM_ARN = forecastPredictorAlgorithmArn; 30 | } 31 | 32 | @VisibleForTesting 33 | static final int SECONDS_IN_A_DAY = 86400; 34 | 35 | public CreatePredictorHandler() { 36 | super(); 37 | } 38 | 39 | public CreatePredictorHandler(final AmazonForecast forecastClient) { 40 | super(forecastClient); 41 | } 42 | 43 | @Override 44 | public void process(final Map resourceIdMap) { 45 | 46 | String datasetGroupArn = resourceIdMap.get(DATASET_GROUP_ARN_KEY); 47 | String predictorName = resourceIdMap.get(PREDICTOR_NAME_KEY); 48 | String predictorArn = resourceIdMap.get(PREDICTOR_ARN_KEY); 49 | String dataFrequency = resourceIdMap.get(DATA_FREQUENCY_KEY); 50 | log.info(String.format( 51 | "The datasetGroupArn, %s, and forecastFrequency getting from resourceIdMap are [%s], [%s], and [%s]", 52 | PREDICTOR_RESOURCE_TYPE, datasetGroupArn, predictorName, dataFrequency)); 53 | 54 | 55 | // Check if predictor exists 56 | try { 57 | String 
currentStatus = describePredictorStatus(predictorArn); 58 | if (takeActionByResourceStatus(currentStatus, PREDICTOR_RESOURCE_TYPE, predictorArn)) { 59 | return; 60 | } 61 | } catch (ResourceNotFoundException e) { 62 | log.info(String.format("Cannot find %s with arn [%s]. Proceed to create a new one", 63 | PREDICTOR_RESOURCE_TYPE, predictorArn)); 64 | } 65 | 66 | // Create the new predictor 67 | int forecastHorizonInDays = Integer.parseInt(System.getenv("FORECAST_HORIZON_IN_DAYS")); 68 | int forecastHorizon = forecastHorizonInDays * SECONDS_IN_A_DAY / DATA_FREQUENCY_SECONDS_MAPPING.get(dataFrequency); 69 | log.info(String.format("[forecastHorizonInDay:%d]*[SECONDS_IN_A_DAY:%d]/[DATA_FREQUENCY_SECONDS:%d]=[forecastHorizon:%d]", 70 | forecastHorizonInDays, SECONDS_IN_A_DAY, DATA_FREQUENCY_SECONDS_MAPPING.get(dataFrequency), forecastHorizon)); 71 | 72 | createPredictor(forecastHorizon, dataFrequency, datasetGroupArn, predictorName, FORECAST_PREDICTOR_ALGORITHM_ARN); 73 | log.info("finish triggering CreatePredictorCall."); 74 | 75 | String newStatus = describePredictorStatus(predictorArn); 76 | takeActionByResourceStatus(newStatus, PREDICTOR_RESOURCE_TYPE, predictorName); 77 | } 78 | 79 | private void createPredictor(final int forecastHorizon, 80 | final String forecastFrequency, 81 | final String datasetGroupArn, 82 | final String predictorName, 83 | final String predictorAlgorithmArn) { 84 | 85 | CreatePredictorRequest createPredictorRequest = new CreatePredictorRequest() 86 | .withForecastHorizon(forecastHorizon) 87 | .withFeaturizationConfig(new FeaturizationConfig().withForecastFrequency(forecastFrequency)) 88 | .withInputDataConfig(new InputDataConfig().withDatasetGroupArn(datasetGroupArn)) 89 | .withPredictorName(predictorName); 90 | if (StringUtils.isBlank(predictorAlgorithmArn)) { 91 | createPredictorRequest.setPerformAutoML(true); 92 | } else { 93 | createPredictorRequest.setAlgorithmArn(predictorAlgorithmArn); 94 | } 95 | 96 | 
forecastClient.createPredictor(createPredictorRequest); 97 | } 98 | 99 | private String describePredictorStatus(final String predictorArn) { 100 | DescribePredictorRequest describePredictorRequest = new DescribePredictorRequest(); 101 | describePredictorRequest.setPredictorArn(predictorArn); 102 | return forecastClient.describePredictor(describePredictorRequest).getStatus(); 103 | } 104 | } 105 |
--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/predictiongeneration/DeleteOutdatedDatasetGroupsHandler.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration;

import com.amazonaws.lambda.predictiongeneration.exception.ResourceCleanupInProgressException;
import com.amazonaws.services.forecast.AmazonForecast;
import com.amazonaws.services.forecast.model.DatasetGroupSummary;
import com.amazonaws.services.forecast.model.DeleteDatasetGroupRequest;
import com.amazonaws.services.forecast.model.ListDatasetGroupsRequest;
import com.amazonaws.services.forecast.model.ListDatasetGroupsResult;
import com.amazonaws.services.forecast.model.ResourceNotFoundException;
import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_GROUP_ARN_KEY;

/**
 * Deletes every Amazon Forecast dataset group except the one whose ARN is stored under
 * {@code DATASET_GROUP_ARN_KEY} in the resource id map.
 */
@Slf4j
public class DeleteOutdatedDatasetGroupsHandler extends AbstractPredictionGenerationLambdaHandler {

    public DeleteOutdatedDatasetGroupsHandler() {
        super();
    }

    public DeleteOutdatedDatasetGroupsHandler(final AmazonForecast forecastClient) {
        super(forecastClient);
    }

    /**
     * Issues delete calls for every dataset group other than the preserved one, then re-lists the
     * dataset groups to check whether only the preserved one is left.
     *
     * @param resourceIdMap resource ids of the current run; {@code DATASET_GROUP_ARN_KEY} names the group to keep
     * @throws IllegalStateException              if no dataset group exists at all
     * @throws ResourceCleanupInProgressException if the re-listing returns anything other than exactly the
     *                                            preserved dataset group (presumably deletions still running)
     */
    @Override
    public void process(final Map<String, String> resourceIdMap) {

        String preservedDatasetGroupArn = resourceIdMap.get(DATASET_GROUP_ARN_KEY);
        log.info(String.format("The preserved datasetGroupArn getting from resourceIdMap is %s", preservedDatasetGroupArn));

        // Get all existing datasetGroups and exclude the preserved one
        List<String> outdatedDatasetGroups = listDatasetGroupArns();

        if (outdatedDatasetGroups.isEmpty()) {
            throw new IllegalStateException("There is no existing datasetGroup.");
        }

        outdatedDatasetGroups.remove(preservedDatasetGroupArn);
        if (outdatedDatasetGroups.isEmpty()) {
            log.info("Don't find any outdated datasetGroup, returning");
            return;
        }

        // Delete all outdated datasetGroups
        for (String outdatedDatasetGroupArn : outdatedDatasetGroups) {
            deleteDatasetGroup(outdatedDatasetGroupArn);
        }

        // Verify there is no outdated datasetGroups: only the preserved group may remain listed
        List<String> existingDatasetGroups = listDatasetGroupArns();
        if (!Collections.singletonList(preservedDatasetGroupArn).equals(existingDatasetGroups)) {
            throw new ResourceCleanupInProgressException(
                    String.format("Outdated datasetGroups cleanup is in progress with existing datasetGroups %s",
                            existingDatasetGroups.toString()));
        }

        log.info("Successfully clean up outdated datasetGroups.");
    }

    /**
     * Deletes one dataset group. A group that no longer exists (e.g. a previous delete completed between
     * the listing and this call) is treated as already deleted instead of failing the whole cleanup,
     * matching how dataset import jobs are deleted.
     */
    private void deleteDatasetGroup(final String datasetGroupArn) {
        DeleteDatasetGroupRequest deleteDatasetGroupRequest = new DeleteDatasetGroupRequest();
        deleteDatasetGroupRequest.setDatasetGroupArn(datasetGroupArn);

        log.info(String.format("About to delete datasetGroup: %s", datasetGroupArn));

        try {
            forecastClient.deleteDatasetGroup(deleteDatasetGroupRequest);
        } catch (ResourceNotFoundException ex) {
            log.info(String.format("DatasetGroup [%s] has already been deleted", datasetGroupArn));
        }
    }

    /**
     * Lists the ARNs of all dataset groups, following {@code nextToken} pagination.
     *
     * @return a mutable list of dataset group ARNs
     */
    private List<String> listDatasetGroupArns() {
        List<String> existingDatasetGroups = new ArrayList<>();
        String nextToken = null;
        do {
            ListDatasetGroupsRequest listDatasetGroupsRequest = new ListDatasetGroupsRequest();
            if (nextToken != null) {
                listDatasetGroupsRequest.setNextToken(nextToken);
            }
            ListDatasetGroupsResult listDatasetGroupsResult = forecastClient.listDatasetGroups(listDatasetGroupsRequest);

            existingDatasetGroups.addAll(
                    listDatasetGroupsResult.getDatasetGroups().stream()
                            .map(DatasetGroupSummary::getDatasetGroupArn).collect(Collectors.toList()));
            nextToken = listDatasetGroupsResult.getNextToken();
        } while (nextToken != null);

        return existingDatasetGroups;
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/predictiongeneration/DeleteOutdatedDatasetImportJobsHandler.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration;

import com.amazonaws.lambda.predictiongeneration.exception.ResourceCleanupInProgressException;
import com.amazonaws.services.forecast.AmazonForecast;
import com.amazonaws.services.forecast.model.DatasetImportJobSummary;
import com.amazonaws.services.forecast.model.DeleteDatasetImportJobRequest;
import com.amazonaws.services.forecast.model.Filter;
import com.amazonaws.services.forecast.model.FilterConditionString;
import com.amazonaws.services.forecast.model.ListDatasetImportJobsRequest;
import com.amazonaws.services.forecast.model.ListDatasetImportJobsResult;
import com.amazonaws.services.forecast.model.ResourceNotFoundException;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_ARN_KEY;

/**
 * Deletes the dataset import jobs that belong to outdated datasets, i.e. the datasets returned by
 * {@code listOutdatedDatasetArns} for the preserved dataset ARN.
 */
@Slf4j
public class DeleteOutdatedDatasetImportJobsHandler extends AbstractPredictionGenerationLambdaHandler {

    public DeleteOutdatedDatasetImportJobsHandler() {
        super();
    }

    public DeleteOutdatedDatasetImportJobsHandler(final AmazonForecast forecastClient) {
        super(forecastClient);
    }

    /**
     * Deletes all import jobs of outdated datasets, then re-lists them to check the cleanup finished.
     *
     * @param resourceIdMap resource ids of the current run; {@code DATASET_ARN_KEY} names the dataset to keep
     * @throws ResourceCleanupInProgressException if outdated import jobs are still listed after the deletes
     */
    @Override
    public void process(final Map<String, String> resourceIdMap) {

        String preservedDatasetArn = resourceIdMap.get(DATASET_ARN_KEY);
        log.info(String.format("The preserved datasetArn getting from resourceIdMap is [%s]", preservedDatasetArn));

        // Get all existing datasetImportJobs and exclude the ones associated with the preserved dataset
        Map<String, List<String>> outdatedDatasetImportJobsMap = listOutdatedDatasetImportJobArns(preservedDatasetArn);

        if (CollectionUtils.isEmpty(outdatedDatasetImportJobsMap.keySet())) {
            log.info("Don't find any outdated dataset import job, returning");
            return;
        }

        // Delete all datasetImportJobs associated with outdated datasets
        outdatedDatasetImportJobsMap.values().stream().flatMap(List::stream).forEach(this::deleteDatasetImportJob);

        // Verify there is no outdated datasetImportJobs
        Map<String, List<String>> outdatedDatasetImportJobsMapAfterCleanup = listOutdatedDatasetImportJobArns(preservedDatasetArn);
        if (CollectionUtils.isNotEmpty(outdatedDatasetImportJobsMapAfterCleanup.keySet())) {
            // The map is keyed by dataset ARN; report the remaining import job ARNs as the message states.
            List<String> remainingDatasetImportJobArns = outdatedDatasetImportJobsMapAfterCleanup.values().stream()
                    .flatMap(List::stream)
                    .collect(Collectors.toList());
            throw new ResourceCleanupInProgressException(
                    String.format("Outdated datasetImportJobs cleanup is in progress with outdated datasetImportJobs [%s]",
                            remainingDatasetImportJobArns.toString()));
        }

        log.info("Successfully clean up outdated datasetImportJobs.");
    }

    /**
     * Deletes one dataset import job; a job that no longer exists is treated as already deleted.
     */
    private void deleteDatasetImportJob(final String datasetImportJobArn) {
        DeleteDatasetImportJobRequest deleteDatasetImportJobRequest =
                new DeleteDatasetImportJobRequest().withDatasetImportJobArn(datasetImportJobArn);

        log.info(String.format("About to delete datasetImportJob: %s", datasetImportJobArn));

        try {
            forecastClient.deleteDatasetImportJob(deleteDatasetImportJobRequest);
        } catch (ResourceNotFoundException ex) {
            log.info(String.format("DatasetImportJob [%s] has already been deleted",
                    datasetImportJobArn));
        }
    }

    /**
     * Lists the import jobs of every outdated dataset, following pagination per dataset.
     *
     * @param preservedDatasetArn the dataset whose import jobs must be kept
     * @return map from outdated dataset ARN to its import job ARNs; datasets without jobs are omitted
     */
    private Map<String, List<String>> listOutdatedDatasetImportJobArns(final String preservedDatasetArn) {
        List<String> outdatedDatasetArns = listOutdatedDatasetArns(preservedDatasetArn);

        Map<String, List<String>> outdatedDatasetImportJobsMap = new HashMap<>();

        for (String outdatedDatasetArn : outdatedDatasetArns) {
            List<String> outdatedDatasetImportJobArns = new ArrayList<>();
            String nextToken = null;
            ListDatasetImportJobsRequest listDatasetImportJobsRequest =
                    new ListDatasetImportJobsRequest().withFilters(
                            new Filter()
                                    .withKey("DatasetArn")
                                    .withValue(outdatedDatasetArn)
                                    .withCondition(FilterConditionString.IS));

            do {
                if (nextToken != null) {
                    listDatasetImportJobsRequest.setNextToken(nextToken);
                }
                ListDatasetImportJobsResult listDatasetImportJobsResult = forecastClient
                        .listDatasetImportJobs(listDatasetImportJobsRequest);

                outdatedDatasetImportJobArns.addAll(listDatasetImportJobsResult.getDatasetImportJobs()
                        .stream().map(DatasetImportJobSummary::getDatasetImportJobArn).collect(Collectors.toList()));
                nextToken = listDatasetImportJobsResult.getNextToken();
            } while (nextToken != null);

            if (!outdatedDatasetImportJobArns.isEmpty()) {
                outdatedDatasetImportJobsMap.put(outdatedDatasetArn, outdatedDatasetImportJobArns);
            }
        }

        return outdatedDatasetImportJobsMap;
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/predictiongeneration/DeleteOutdatedDatasetsHandler.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration;

import com.amazonaws.lambda.predictiongeneration.exception.ResourceCleanupInProgressException;
import com.amazonaws.services.forecast.AmazonForecast;
import com.amazonaws.services.forecast.model.DeleteDatasetRequest;
import lombok.extern.slf4j.Slf4j; 7 | import org.apache.commons.collections4.CollectionUtils; 8 | 9 | import java.util.List; 10 | import java.util.Map; 11 | 12 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_ARN_KEY; 13 | 14 | @Slf4j 15 | public class DeleteOutdatedDatasetsHandler extends AbstractPredictionGenerationLambdaHandler { 16 | 17 | public DeleteOutdatedDatasetsHandler() { 18 | super(); 19 | } 20 | 21 | public DeleteOutdatedDatasetsHandler(final AmazonForecast forecastClient) { 22 | super(forecastClient); 23 | } 24 | 25 | @Override 26 | public void process(final Map resourceIdMap) { 27 | 28 | String preservedDatasetArn = resourceIdMap.get(DATASET_ARN_KEY); 29 | log.info(String.format("The preserved datasetArn getting from resourceIdMap is [%s]", preservedDatasetArn)); 30 | 31 | // Get all existing datasets and exclude the preserved one 32 | List outdatedDatasetArns = listOutdatedDatasetArns(preservedDatasetArn); 33 | 34 | if (outdatedDatasetArns.isEmpty()) { 35 | log.info("Don't find any outdated dataset, returning"); 36 | return; 37 | } 38 | 39 | // Delete all outdated datasets 40 | for (String outdatedDatasetArn : outdatedDatasetArns) { 41 | deleteDataset(outdatedDatasetArn); 42 | } 43 | 44 | // Verify there is no outdated datasets 45 | List outdatedDatasetsAfterCleanup = listOutdatedDatasetArns(preservedDatasetArn); 46 | if (CollectionUtils.isNotEmpty(outdatedDatasetsAfterCleanup)) { 47 | throw new ResourceCleanupInProgressException( 48 | String.format("Outdated datasets cleanup is in progress with outdated datasets [%s]", 49 | outdatedDatasetsAfterCleanup.toString())); 50 | } 51 | 52 | log.info("Successfully clean up outdated datasets."); 53 | } 54 | 55 | private void deleteDataset(final String datasetArn) { 56 | DeleteDatasetRequest deleteDatasetRequest = new DeleteDatasetRequest(); 57 | deleteDatasetRequest.setDatasetArn(datasetArn); 58 | forecastClient.deleteDataset(deleteDatasetRequest); 59 | } 60 | } 61 
--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/predictiongeneration/DeleteOutdatedForecastExportJobsHandler.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration;

import com.amazonaws.lambda.predictiongeneration.exception.ResourceCleanupInProgressException;
import com.amazonaws.services.forecast.AmazonForecast;
import com.amazonaws.services.forecast.model.DeleteForecastExportJobRequest;
import com.amazonaws.services.forecast.model.Filter;
import com.amazonaws.services.forecast.model.FilterConditionString;
import com.amazonaws.services.forecast.model.ForecastExportJobSummary;
import com.amazonaws.services.forecast.model.ListForecastExportJobsRequest;
import com.amazonaws.services.forecast.model.ListForecastExportJobsResult;
import com.amazonaws.services.forecast.model.ResourceNotFoundException;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_ARN_KEY;

/**
 * Deletes every forecast export job whose ForecastArn is not the preserved forecast
 * ({@code FORECAST_ARN_KEY} in the resource id map).
 */
@Slf4j
public class DeleteOutdatedForecastExportJobsHandler extends AbstractPredictionGenerationLambdaHandler {

    public DeleteOutdatedForecastExportJobsHandler() {
        super();
    }

    public DeleteOutdatedForecastExportJobsHandler(final AmazonForecast forecastClient) {
        super(forecastClient);
    }

    /**
     * Deletes all outdated forecast export jobs, then re-lists them to check the cleanup finished.
     *
     * @param resourceIdMap resource ids of the current run; {@code FORECAST_ARN_KEY} names the forecast to keep
     * @throws ResourceCleanupInProgressException if outdated export jobs are still listed after the deletes
     */
    @Override
    public void process(final Map<String, String> resourceIdMap) {
        String preservedForecastArn = resourceIdMap.get(FORECAST_ARN_KEY);
        log.info(String.format("The preservedForecastArn getting from resourceIdMap is [%s]", preservedForecastArn));

        // Get all existing forecastExportJobs and exclude the ones exporting the preserved forecast
        List<String> outdatedForecastExportJobArns = listOutdatedForecastExportJobArns(preservedForecastArn);

        if (CollectionUtils.isEmpty(outdatedForecastExportJobArns)) {
            log.info("Don't find any outdated forecast export job, returning");
            return;
        }

        // Delete all forecastExportJobs associated with outdated forecasts
        outdatedForecastExportJobArns.forEach(this::deleteForecastExportJob);

        // Verify there is no outdated forecastExportJobs
        List<String> outdatedForecastExportJobArnsAfterCleanup = listOutdatedForecastExportJobArns(preservedForecastArn);
        if (CollectionUtils.isNotEmpty(outdatedForecastExportJobArnsAfterCleanup)) {
            throw new ResourceCleanupInProgressException(
                    String.format("Outdated forecastExportJobs cleanup is in progress with outdated forecastExportJobs [%s]",
                            outdatedForecastExportJobArnsAfterCleanup));
        }

        log.info("Successfully clean up outdated forecastExportJobs.");
    }

    /**
     * Deletes one forecast export job; a job that no longer exists is treated as already deleted,
     * matching how dataset import jobs are deleted.
     */
    private void deleteForecastExportJob(final String forecastExportJobArn) {
        log.info(String.format("About to delete forecastExportJob: %s", forecastExportJobArn));

        try {
            forecastClient.deleteForecastExportJob(
                    new DeleteForecastExportJobRequest().withForecastExportJobArn(forecastExportJobArn));
        } catch (ResourceNotFoundException ex) {
            log.info(String.format("ForecastExportJob [%s] has already been deleted", forecastExportJobArn));
        }
    }

    /**
     * Lists the ARNs of export jobs whose ForecastArn IS_NOT the preserved one, following pagination.
     */
    private List<String> listOutdatedForecastExportJobArns(final String preservedForecastArn) {

        List<String> outdatedForecastExportJobArns = new ArrayList<>();
        String nextToken = null;
        ListForecastExportJobsRequest listForecastExportJobsRequest =
                new ListForecastExportJobsRequest().withFilters(
                        new Filter()
                                .withKey("ForecastArn")
                                .withValue(preservedForecastArn)
                                .withCondition(FilterConditionString.IS_NOT));

        do {
            if (nextToken != null) {
                listForecastExportJobsRequest.setNextToken(nextToken);
            }
            ListForecastExportJobsResult listForecastExportJobsResult = forecastClient
                    .listForecastExportJobs(listForecastExportJobsRequest);

            outdatedForecastExportJobArns.addAll(listForecastExportJobsResult.getForecastExportJobs()
                    .stream().map(ForecastExportJobSummary::getForecastExportJobArn).collect(Collectors.toList()));
            nextToken = listForecastExportJobsResult.getNextToken();
        } while (nextToken != null);

        return outdatedForecastExportJobArns;
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/predictiongeneration/DeleteOutdatedForecastsHandler.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration;

import com.amazonaws.services.forecast.AmazonForecast;
import com.amazonaws.services.forecast.model.DeleteForecastRequest;
import com.amazonaws.services.forecast.model.Filter;
import com.amazonaws.services.forecast.model.FilterConditionString;
import com.amazonaws.services.forecast.model.ForecastSummary;
import com.amazonaws.services.forecast.model.ListForecastsRequest;
import com.amazonaws.services.forecast.model.ListForecastsResult;
import com.amazonaws.services.forecast.model.ResourceNotFoundException;
import com.google.common.collect.Lists;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_ARN_KEY;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.PREDICTOR_ARN_KEY;

/**
 * Deletes all forecasts of outdated predictors and trims the ACTIVE forecasts of the preserved
 * predictor down to the most recent {@link #NUMBER_OF_OUTDATED_FORECASTS_TO_KEEP} (not counting the
 * forecast of the current run).
 */
@Slf4j
public class DeleteOutdatedForecastsHandler extends AbstractPredictionGenerationLambdaHandler {

    /** How many older ACTIVE forecasts of the preserved predictor are retained besides the current one. */
    private static final int NUMBER_OF_OUTDATED_FORECASTS_TO_KEEP = 5;

    public DeleteOutdatedForecastsHandler() {
        super();
    }

    public DeleteOutdatedForecastsHandler(final AmazonForecast forecastClient) {
        super(forecastClient);
    }

    /**
     * @param resourceIdMap resource ids of the current run; {@code FORECAST_ARN_KEY} is the forecast being
     *                      produced now and {@code PREDICTOR_ARN_KEY} the predictor to keep
     */
    @Override
    public void process(final Map<String, String> resourceIdMap) {
        String currentForecastArn = resourceIdMap.get(FORECAST_ARN_KEY);
        String preservedPredictorArn = resourceIdMap.get(PREDICTOR_ARN_KEY);
        log.info(String.format("The currentForecastArn and preservedPredictorArn getting from resourceIdMap are [%s], [%s]",
                currentForecastArn, preservedPredictorArn));

        // Get all existing predictors and exclude the preserved one
        List<String> outdatedPredictors = listOutdatedPredictorArns(preservedPredictorArn);
        outdatedPredictors.remove(preservedPredictorArn);

        // Delete all forecasts (of any status) for all outdated predictors
        for (String outdatedPredictorArn : outdatedPredictors) {
            log.info(String.format("About to delete forecasts for outdated predictorArn [%s]", outdatedPredictorArn));
            listForecasts(outdatedPredictorArn, null)
                    .forEach(outdatedForecast -> deleteForecast(outdatedForecast.getForecastArn()));
        }

        // Get all existing ACTIVE forecasts associated with the preserved predictorArn
        List<ForecastSummary> outdatedForecasts = listForecasts(preservedPredictorArn, "ACTIVE");

        // Remove the current processing forecast from the list; null-safe in case the key is absent
        outdatedForecasts.removeIf(forecast -> Objects.equals(currentForecastArn, forecast.getForecastArn()));

        int numberOfOutdatedForecasts = outdatedForecasts.size();
        if (numberOfOutdatedForecasts > NUMBER_OF_OUTDATED_FORECASTS_TO_KEEP) {
            // Oldest first, so only the newest NUMBER_OF_OUTDATED_FORECASTS_TO_KEEP survive
            outdatedForecasts
                    .stream()
                    .sorted(Comparator.comparing(ForecastSummary::getCreationTime))
                    .limit(numberOfOutdatedForecasts - NUMBER_OF_OUTDATED_FORECASTS_TO_KEEP)
                    .forEach(forecast -> deleteForecast(forecast.getForecastArn()));
        } else {
            log.info(String.format("We only have %s outdated forecasts, no need to delete", numberOfOutdatedForecasts));
        }
    }

    /**
     * Deletes one forecast; a forecast that no longer exists is treated as already deleted.
     */
    private void deleteForecast(final String forecastArn) {
        log.info(String.format("About to delete forecastArn [%s].", forecastArn));

        try {
            forecastClient.deleteForecast(new DeleteForecastRequest().withForecastArn(forecastArn));
        } catch (ResourceNotFoundException ex) {
            log.info(String.format("Forecast [%s] has already been deleted", forecastArn));
        }
    }

    /**
     * Lists the forecasts created from the given predictor, following pagination.
     *
     * @param predictorArn the predictor arn associated with forecasts
     * @param status       if not blank, only forecasts with this status (e.g. "ACTIVE") are returned;
     *                     if null or blank, forecasts of every status are returned
     * @return a mutable list of matching forecast summaries
     */
    private List<ForecastSummary> listForecasts(final String predictorArn,
                                                final String status) {
        // The filters do not depend on the page, so build them once
        List<Filter> filters = Lists.newArrayList(
                new Filter().withCondition(FilterConditionString.IS).withKey("PredictorArn").withValue(predictorArn));
        if (StringUtils.isNotBlank(status)) {
            filters.add(new Filter().withCondition(FilterConditionString.IS).withKey("Status").withValue(status));
        }

        List<ForecastSummary> existingForecasts = new ArrayList<>();
        String nextToken = null;
        do {
            ListForecastsRequest listForecastsRequest = new ListForecastsRequest();
            listForecastsRequest.setFilters(filters);

            if (nextToken != null) {
                listForecastsRequest.setNextToken(nextToken);
            }
            ListForecastsResult listForecastsResult = forecastClient.listForecasts(listForecastsRequest);

            existingForecasts.addAll(listForecastsResult.getForecasts());
            nextToken = listForecastsResult.getNextToken();
        } while (nextToken != null);

        return existingForecasts;
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/predictiongeneration/DeleteOutdatedPredictorsHandler.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration;

import com.amazonaws.lambda.predictiongeneration.exception.ResourceCleanupInProgressException;
import com.amazonaws.services.forecast.AmazonForecast;
import com.amazonaws.services.forecast.model.DeletePredictorRequest;
import com.amazonaws.services.forecast.model.ResourceNotFoundException;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;

import java.util.List;
import java.util.Map;

import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.PREDICTOR_ARN_KEY;

/**
 * Deletes the predictors returned by {@code listOutdatedPredictorArns} for the preserved predictor ARN.
 */
@Slf4j
public class DeleteOutdatedPredictorsHandler extends AbstractPredictionGenerationLambdaHandler {

    public DeleteOutdatedPredictorsHandler() {
        super();
    }

    public DeleteOutdatedPredictorsHandler(final AmazonForecast forecastClient) {
        super(forecastClient);
    }

    /**
     * Deletes all outdated predictors, then re-lists them to check the cleanup finished.
     *
     * @param resourceIdMap resource ids of the current run; {@code PREDICTOR_ARN_KEY} names the predictor to keep
     * @throws ResourceCleanupInProgressException if outdated predictors are still listed after the deletes
     */
    @Override
    public void process(final Map<String, String> resourceIdMap) {

        String preservedPredictorArn = resourceIdMap.get(PREDICTOR_ARN_KEY);
        log.info(String.format("The preserved predictorArn getting from resourceIdMap is [%s]", preservedPredictorArn));

        // Get all existing predictors and exclude the preserved one
        List<String> outdatedPredictors = listOutdatedPredictorArns(preservedPredictorArn);

        if (outdatedPredictors.isEmpty()) {
            log.info("Don't find any outdated predictor, returning");
            return;
        }

        // Delete all outdated predictors
        for (String outdatedPredictorArn : outdatedPredictors) {
            deletePredictor(outdatedPredictorArn);
        }

        // Verify there is no outdated predictors
        List<String> outdatedPredictorArnsAfterCleanup = listOutdatedPredictorArns(preservedPredictorArn);
        if (CollectionUtils.isNotEmpty(outdatedPredictorArnsAfterCleanup)) {
            throw new ResourceCleanupInProgressException(
                    String.format("Outdated predictors cleanup is in progress with outdated predictors [%s]",
                            outdatedPredictorArnsAfterCleanup.toString()));
        }

        log.info("Successfully clean up outdated predictors.");
    }

    /**
     * Deletes one predictor; a predictor that no longer exists is treated as already deleted.
     */
    private void deletePredictor(final String predictorArn) {
        DeletePredictorRequest deletePredictorRequest = new DeletePredictorRequest();
        deletePredictorRequest.setPredictorArn(predictorArn);

        log.info(String.format("About to delete predictor: %s", predictorArn));

        try {
            forecastClient.deletePredictor(deletePredictorRequest);
        } catch (ResourceNotFoundException ex) {
            log.info(String.format("Predictor [%s] has already been deleted", predictorArn));
        }
    }

}

--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/predictiongeneration/GenerateForecastResourcesIdsCronHandler.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration;

import com.amazonaws.dagger.DaggerLambdaFunctionsComponent;
import com.amazonaws.services.forecast.AmazonForecast;
import com.amazonaws.services.forecast.model.DatasetSummary;
import com.amazonaws.services.forecast.model.PredictorSummary;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;

import javax.inject.Inject;
import java.time.Clock;
import java.util.HashMap;
import java.util.Map;

import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_ARN_KEY;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_IMPORT_JOB_ARN_KEY;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_IMPORT_JOB_NAME_KEY;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_IMPORT_JOB_NAME_PREFIX;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_ARN_KEY;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_EXPORT_JOB_ARN_KEY;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_EXPORT_JOB_NAME_KEY;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_EXPORT_JOB_NAME_PREFIX;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_NAME_KEY;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_NAME_PREFIX;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.PREDICTOR_ARN_KEY;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.deriveForecastResourceArnPrefixFromLambdaFunctionArn;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.getLatestDataset;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.getLatestPredictor;

/**
 * Cron entry point that derives the resource ids (ARNs and names) for a new import/forecast/export
 * round, reusing the dataset and predictor returned by {@code getLatestDataset}/{@code getLatestPredictor}.
 */
@Slf4j
public class GenerateForecastResourcesIdsCronHandler implements RequestHandler<Void, String> {

    private final Clock clock;

    @Inject
    @NonNull
    AmazonForecast forecastClient;

    public GenerateForecastResourcesIdsCronHandler() {
        this(Clock.systemUTC());
    }

    public GenerateForecastResourcesIdsCronHandler(final Clock clock) {
        this.clock = clock;
        DaggerLambdaFunctionsComponent.create().inject(this);
    }

    public GenerateForecastResourcesIdsCronHandler(final Clock clock,
                                                   final AmazonForecast forecastClient) {
        this.clock = clock;
        this.forecastClient = forecastClient;
    }

    /**
     * Builds the cron resource id map and returns it serialized as JSON.
     *
     * @return the resource id map as a JSON object string
     * @throws IllegalStateException if no dataset or no predictor can be found
     * @throws RuntimeException      if JSON serialization fails (the Jackson exception is kept as the cause)
     */
    public String handleRequest(Void input, Context context) {

        final DatasetSummary latestDataset = getLatestDataset(forecastClient);
        if (latestDataset == null) {
            throw new IllegalStateException("cannot find any dataset");
        }

        final PredictorSummary latestPredictor = getLatestPredictor(forecastClient);
        if (latestPredictor == null) {
            throw new IllegalStateException("cannot find any predictor");
        }

        String functionArn = context.getInvokedFunctionArn();
        String forecastResourceArnPrefix = deriveForecastResourceArnPrefixFromLambdaFunctionArn(functionArn);

        long currentTime = clock.millis();
        Map<String, String> cronResourceIdMap = buildCronResourceIdMap(currentTime,
                forecastResourceArnPrefix, latestDataset.getDatasetName(), latestPredictor.getPredictorArn());

        String cronResourceIdMapAsJson;
        try {
            cronResourceIdMapAsJson = new ObjectMapper().writeValueAsString(cronResourceIdMap);
        } catch (JsonProcessingException e) {
            String errorMsg = e.getMessage();
            // Keep the original exception so the stack trace survives in logs and in the Lambda error
            log.error(errorMsg, e);
            throw new RuntimeException(errorMsg, e);
        }

        log.info("Returning cronResourceIdMapAsJson value is " + cronResourceIdMapAsJson);
        return cronResourceIdMapAsJson;
    }

    /**
     * Builds the resource id map for one cron round; every new resource name is suffixed with {@code timestamp}.
     *
     * @param timestamp                 suffix for the generated names (the caller passes {@code clock.millis()})
     * @param forecastResourceArnPrefix prefix prepended to every derived ARN
     * @param datasetName               name of the existing dataset to import into
     * @param predictorArn              ARN of the existing predictor to forecast with
     * @return an immutable map from resource id key to name/ARN
     */
    @VisibleForTesting
    static Map<String, String> buildCronResourceIdMap(final long timestamp,
                                                      final String forecastResourceArnPrefix,
                                                      final String datasetName,
                                                      final String predictorArn) {

        String datasetImportJobName = DATASET_IMPORT_JOB_NAME_PREFIX + timestamp;
        String forecastName = FORECAST_NAME_PREFIX + timestamp;
        String forecastExportJobName = FORECAST_EXPORT_JOB_NAME_PREFIX + timestamp;

        return ImmutableMap.<String, String>builder()
                .put(DATASET_ARN_KEY, forecastResourceArnPrefix + "dataset/" + datasetName)
                .put(DATASET_IMPORT_JOB_NAME_KEY, datasetImportJobName)
                .put(DATASET_IMPORT_JOB_ARN_KEY, forecastResourceArnPrefix
                        + "dataset-import-job/" + datasetName + "/" + datasetImportJobName)
                .put(PREDICTOR_ARN_KEY, predictorArn)
                .put(FORECAST_NAME_KEY, forecastName)
                .put(FORECAST_ARN_KEY, forecastResourceArnPrefix + "forecast/" + forecastName)
                .put(FORECAST_EXPORT_JOB_NAME_KEY, forecastExportJobName)
                .put(FORECAST_EXPORT_JOB_ARN_KEY, forecastResourceArnPrefix
                        + "forecast-export-job/" + forecastName + "/" + forecastExportJobName)
                .build();
    }
}

--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/predictiongeneration/GenerateForecastResourcesIdsHandler.java:
--------------------------------------------------------------------------------
1 | package com.amazonaws.lambda.predictiongeneration; 2 | 3 | import com.amazonaws.dagger.DaggerLambdaFunctionsComponent; 4 | import
com.amazonaws.lambda.predictiongeneration.exception.ResourceSetupInProgressException; 5 | import com.amazonaws.services.lambda.runtime.Context; 6 | import com.amazonaws.services.lambda.runtime.RequestHandler; 7 | import com.amazonaws.services.s3.AmazonS3; 8 | import com.amazonaws.services.s3.model.AmazonS3Exception; 9 | import com.amazonaws.services.s3.model.GetObjectMetadataRequest; 10 | import com.amazonaws.services.s3.model.ObjectMetadata; 11 | import com.fasterxml.jackson.core.JsonProcessingException; 12 | import com.fasterxml.jackson.databind.ObjectMapper; 13 | import com.google.common.annotations.VisibleForTesting; 14 | import com.google.common.collect.ImmutableMap; 15 | import lombok.NonNull; 16 | import lombok.extern.slf4j.Slf4j; 17 | 18 | import javax.inject.Inject; 19 | import java.time.Clock; 20 | import java.time.Duration; 21 | import java.util.Map; 22 | 23 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_ARN_KEY; 24 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_GROUP_ARN_KEY; 25 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_GROUP_NAME_KEY; 26 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_GROUP_NAME_PREFIX; 27 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_IMPORT_JOB_ARN_KEY; 28 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_IMPORT_JOB_NAME_KEY; 29 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_IMPORT_JOB_NAME_PREFIX; 30 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_NAME_KEY; 31 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_NAME_PREFIX; 32 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATA_FREQUENCY_KEY; 33 | import 
static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_ARN_KEY;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_NAME_KEY;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_NAME_PREFIX;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.FORECAST_RESOURCE_ARN_PREFIX_KEY;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.PREDICTOR_ARN_KEY;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.PREDICTOR_NAME_KEY;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.PREDICTOR_NAME_PREFIX;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.deriveForecastResourceArnPrefixFromLambdaFunctionArn;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.getForecastDataFrequencyStr;

/**
 * Lambda handler that derives the names and ARNs of the Amazon Forecast resources (dataset, dataset group,
 * dataset import job, predictor and forecast) for one prediction-generation run and returns them as a JSON object.
 * Every name is a fixed prefix followed by the current epoch-millisecond timestamp; every ARN is built from the
 * invoking Lambda function's ARN. Before generating anything it checks that the historical demand file exists in
 * S3 and is non-empty.
 */
@Slf4j
public class GenerateForecastResourcesIdsHandler implements RequestHandler<Void, String> {

    // Window size of the predictions. getForecastDataFrequencyStr rounds it up to a supported Forecast
    // data-frequency string.
    private static final Duration PREDICTION_WINDOW_SIZE_DURATION = Duration.ofSeconds(2000);
    private static final String PREDICTION_S3_BUCKET_NAME = System.getenv("PREDICTION_S3_BUCKET_NAME");
    private static final String PREDICTION_S3_HISTORICAL_DEMAND_FOLDER = System.getenv("SRC_S3_FOLDER");
    private static final String PREDICTION_S3_HISTORICAL_DEMAND_FILE_NAME = System.getenv("S3_TRAINING_DATA_FILE_NAME");

    // A configured ObjectMapper is safe to share, so it is created once instead of on every invocation.
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    private final Clock clock;

    @Inject
    @NonNull
    AmazonS3 s3Client;

    /**
     * Constructor used by the Lambda runtime: system UTC clock, with the S3 client injected by Dagger.
     */
    public GenerateForecastResourcesIdsHandler() {
        this(Clock.systemUTC());
    }

    /**
     * @param clock source of the timestamp embedded in every generated resource name
     */
    public GenerateForecastResourcesIdsHandler(final Clock clock) {
        this.clock = clock;
        DaggerLambdaFunctionsComponent.create().inject(this);
    }

    @VisibleForTesting
    GenerateForecastResourcesIdsHandler(final Clock clock, final AmazonS3 s3Client) {
        this.clock = clock;
        this.s3Client = s3Client;
    }

    /**
     * Verifies the demand file, then builds and serializes the resource id map.
     *
     * @param input   unused
     * @param context Lambda context; its invoked function ARN supplies the partition, region and account of every
     *                generated Forecast ARN
     * @return the resource id map (see {@link #buildResourceIdMap}) serialized as a JSON object
     * @throws ResourceSetupInProgressException if the demand file metadata cannot be read or the file is empty
     * @throws RuntimeException                 if the map cannot be serialized to JSON (cause attached)
     */
    @Override
    public String handleRequest(Void input, Context context) {

        sanityCheck();

        long currentTime = clock.millis();

        String forecastResourceArnPrefix =
                deriveForecastResourceArnPrefixFromLambdaFunctionArn(context.getInvokedFunctionArn());
        String dataFrequencyValue = getForecastDataFrequencyStr(PREDICTION_WINDOW_SIZE_DURATION);
        Map<String, String> resourceIdMap =
                buildResourceIdMap(currentTime, forecastResourceArnPrefix, dataFrequencyValue);

        String resourceIdMapAsJson;
        try {
            resourceIdMapAsJson = OBJECT_MAPPER.writeValueAsString(resourceIdMap);
        } catch (JsonProcessingException e) {
            String errorMsg = e.getMessage();
            log.error(errorMsg, e);
            // Keep the Jackson exception as the cause so its stack trace is not lost.
            throw new RuntimeException(errorMsg, e);
        }

        log.info("Returning resourceIdMapAsJson value is {}", resourceIdMapAsJson);
        return resourceIdMapAsJson;
    }

    /**
     * Fails fast when the historical demand file at {@code <bucket>/<folder>/<file>} is not ready.
     *
     * @throws ResourceSetupInProgressException if fetching the object's metadata throws an
     *         {@link AmazonS3Exception} (e.g. the object does not exist yet) or the object has zero length
     */
    private void sanityCheck() {
        String demandFileKey = String.format("%s/%s",
                PREDICTION_S3_HISTORICAL_DEMAND_FOLDER, PREDICTION_S3_HISTORICAL_DEMAND_FILE_NAME);

        ObjectMetadata s3ObjectMetadata;
        try {
            s3ObjectMetadata = s3Client.getObjectMetadata(
                    new GetObjectMetadataRequest(PREDICTION_S3_BUCKET_NAME, demandFileKey));
        } catch (AmazonS3Exception e) {
            throw new ResourceSetupInProgressException(String.format("Got exception while getting info of the demand source file: %s",
                    e.getMessage()));
        }

        if (s3ObjectMetadata.getContentLength() == 0) {
            throw new ResourceSetupInProgressException("The demand source file is empty");
        }
    }

    /**
     * Builds the names and ARNs of all Forecast resources for one run.
     *
     * @param timestamp                 suffix appended to every resource name prefix
     * @param forecastResourceArnPrefix ARN prefix ending in ':' (e.g. {@code arn:aws:forecast:us-east-1:123456789012:})
     * @param dataFrequencyValue        Forecast data-frequency string stored under {@code DATA_FREQUENCY_KEY}
     * @return an immutable map from the PredictionGenerationUtils key constants to the generated values. The
     *         dataset import job ARN is nested under its dataset name.
     */
    @VisibleForTesting
    static Map<String, String> buildResourceIdMap(final long timestamp,
                                                  final String forecastResourceArnPrefix,
                                                  final String dataFrequencyValue) {

        String datasetName = DATASET_NAME_PREFIX + timestamp;
        String datasetGroupName = DATASET_GROUP_NAME_PREFIX + timestamp;
        String datasetImportJobName = DATASET_IMPORT_JOB_NAME_PREFIX + timestamp;
        String predictorName = PREDICTOR_NAME_PREFIX + timestamp;
        String forecastName = FORECAST_NAME_PREFIX + timestamp;

        return ImmutableMap.<String, String>builder()
                .put(FORECAST_RESOURCE_ARN_PREFIX_KEY, forecastResourceArnPrefix)
                .put(DATASET_NAME_KEY, datasetName)
                .put(DATASET_ARN_KEY, forecastResourceArnPrefix + "dataset/" + datasetName)
                .put(DATASET_GROUP_NAME_KEY, datasetGroupName)
                .put(DATASET_GROUP_ARN_KEY, forecastResourceArnPrefix + "dataset-group/" + datasetGroupName)
                .put(DATASET_IMPORT_JOB_NAME_KEY, datasetImportJobName)
                .put(DATASET_IMPORT_JOB_ARN_KEY, forecastResourceArnPrefix
                        + "dataset-import-job/" + datasetName + "/" + datasetImportJobName)
                .put(PREDICTOR_NAME_KEY, predictorName)
                .put(PREDICTOR_ARN_KEY, forecastResourceArnPrefix + "predictor/" + predictorName)
                .put(FORECAST_NAME_KEY, forecastName)
                .put(FORECAST_ARN_KEY, forecastResourceArnPrefix + "forecast/" + forecastName)
                .put(DATA_FREQUENCY_KEY, dataFrequencyValue)
                .build();
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/predictiongeneration/PredictionGenerationUtils.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration;

import com.amazonaws.services.forecast.AmazonForecast;
import com.amazonaws.services.forecast.model.DatasetSummary;
import com.amazonaws.services.forecast.model.ForecastSummary;
import com.amazonaws.services.forecast.model.ListDatasetsRequest;
import com.amazonaws.services.forecast.model.ListDatasetsResult;
import com.amazonaws.services.forecast.model.ListForecastsRequest;
import
com.amazonaws.services.forecast.model.ListForecastsResult;
import com.amazonaws.services.forecast.model.ListPredictorsRequest;
import com.amazonaws.services.forecast.model.ListPredictorsResult;
import com.amazonaws.services.forecast.model.PredictorSummary;
import com.google.common.collect.ImmutableMap;
import lombok.NonNull;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;
import java.util.function.Function;


/**
 * Static helpers shared by the prediction-generation Lambda handlers. It provides:
 * <ul>
 *   <li>resource-name prefixes and result-map keys,</li>
 *   <li>paginated Forecast list calls,</li>
 *   <li>derivation of the Forecast ARN prefix from a Lambda function ARN, and</li>
 *   <li>mapping of a window size to a Forecast data-frequency string.</li>
 * </ul>
 */
public final class PredictionGenerationUtils {

    // Private constructor prevents instantiation of this utility class.
    private PredictionGenerationUtils() {}

    private static final Duration ONE_MIN_DURATION = Duration.ofMinutes(1);
    private static final Duration FIVE_MIN_DURATION = Duration.ofMinutes(5);
    private static final Duration TEN_MIN_DURATION = Duration.ofMinutes(10);
    private static final Duration FIFTEEN_MIN_DURATION = Duration.ofMinutes(15);
    private static final Duration THIRTY_MIN_DURATION = Duration.ofMinutes(30);
    private static final Duration ONE_HOUR_DURATION = Duration.ofHours(1);
    private static final Duration ONE_DAY_DURATION = Duration.ofDays(1);
    private static final Duration ONE_WEEK_DURATION = Duration.ofDays(7);
    private static final Duration ONE_MONTH_DURATION = Duration.ofDays(30);

    // Refer to: https://docs.aws.amazon.com/forecast/latest/dg/API_CreateDataset.html#forecast-CreateDataset-request-DataFrequency
    static final String ONE_MIN_DATA_FREQUENCY_STRING = "1min";
    static final String FIVE_MIN_DATA_FREQUENCY_STRING = "5min";
    static final String TEN_MIN_DATA_FREQUENCY_STRING = "10min";
    static final String FIFTEEN_MIN_DATA_FREQUENCY_STRING = "15min";
    static final String THIRTY_MIN_DATA_FREQUENCY_STRING = "30min";
    static final String ONE_HOUR_DATA_FREQUENCY_STRING = "H";
    static final String ONE_DAY_DATA_FREQUENCY_STRING = "D";
    static final String ONE_WEEK_DATA_FREQUENCY_STRING = "W";
    static final String ONE_MONTH_DATA_FREQUENCY_STRING = "M";
    static final String ONE_YEAR_DATA_FREQUENCY_STRING = "Y";

    /**
     * Length in seconds of every data-frequency string that {@link #getForecastDataFrequencyStr} can return.
     * Month and year are approximated as 30 and 365 days. "1min" was previously missing, so a lookup for that
     * frequency returned null.
     */
    static final Map<String, Integer> DATA_FREQUENCY_SECONDS_MAPPING =
            ImmutableMap.<String, Integer>builder()
                    .put(ONE_MIN_DATA_FREQUENCY_STRING, 60)
                    .put(FIVE_MIN_DATA_FREQUENCY_STRING, 5 * 60)
                    .put(TEN_MIN_DATA_FREQUENCY_STRING, 10 * 60)
                    .put(FIFTEEN_MIN_DATA_FREQUENCY_STRING, 15 * 60)
                    .put(THIRTY_MIN_DATA_FREQUENCY_STRING, 30 * 60)
                    .put(ONE_HOUR_DATA_FREQUENCY_STRING, 60 * 60)
                    .put(ONE_DAY_DATA_FREQUENCY_STRING, 60 * 60 * 24)
                    .put(ONE_WEEK_DATA_FREQUENCY_STRING, 60 * 60 * 24 * 7)
                    .put(ONE_MONTH_DATA_FREQUENCY_STRING, 60 * 60 * 24 * 30)
                    .put(ONE_YEAR_DATA_FREQUENCY_STRING, 60 * 60 * 24 * 365)
                    .build();

    static final String ARN_COMPONENT_SPLITTER = ":";
    static final String FORECAST_SERVICE_NAME = "forecast";
    static final String DATASET_NAME_PREFIX = "ds_";
    static final String DATASET_GROUP_NAME_PREFIX = "dsg_";
    static final String DATASET_IMPORT_JOB_NAME_PREFIX = "dsij_";
    static final String PREDICTOR_NAME_PREFIX = "p_";
    static final String FORECAST_NAME_PREFIX = "f_";
    static final String FORECAST_EXPORT_JOB_NAME_PREFIX = "fej_";

    static final String FORECAST_RESOURCE_ARN_PREFIX_KEY = "ForecastResourceArnPrefixKey";
    static final String DATASET_NAME_KEY = "DatasetName";
    static final String DATASET_ARN_KEY = "DatasetArn";
    static final String DATASET_GROUP_NAME_KEY = "DatasetGroupName";
    static final String DATASET_GROUP_ARN_KEY = "DatasetGroupArn";
    static final String DATASET_IMPORT_JOB_NAME_KEY = "DatasetImportJobName";
    static final String DATASET_IMPORT_JOB_ARN_KEY = "DatasetImportJobArn";
    static final String PREDICTOR_NAME_KEY = "PredictorName";
    static final String PREDICTOR_ARN_KEY = "PredictorArn";
    static final String FORECAST_NAME_KEY = "ForecastName";
    static final String FORECAST_ARN_KEY = "ForecastArn";
    static final String FORECAST_EXPORT_JOB_NAME_KEY = "ForecastExportJobName";
    static final String FORECAST_EXPORT_JOB_ARN_KEY = "ForecastExportJobArn";
    static final String DATA_FREQUENCY_KEY = "DataFrequency";

    static final String RESOURCE_ACTIVE_STATUS = "ACTIVE";
    static final String RESOURCE_FAILED_STATUS = "FAILED";

    /**
     * @return the dataset with the greatest creation time, or {@code null} when there are no datasets
     */
    static DatasetSummary getLatestDataset(final AmazonForecast forecastClient) {
        return listDatasets(forecastClient)
                .stream()
                .max(Comparator.comparing(DatasetSummary::getCreationTime))
                .orElse(null);
    }

    /**
     * @return all datasets, following pagination until no next token is returned
     */
    static List<DatasetSummary> listDatasets(final AmazonForecast forecastClient) {
        return collectAllPages(
                nextToken -> forecastClient.listDatasets(new ListDatasetsRequest().withNextToken(nextToken)),
                ListDatasetsResult::getDatasets,
                ListDatasetsResult::getNextToken);
    }

    /**
     * @return the predictor with the greatest creation time, or {@code null} when there are no predictors
     */
    static PredictorSummary getLatestPredictor(final AmazonForecast forecastClient) {
        return listPredictors(forecastClient)
                .stream()
                .max(Comparator.comparing(PredictorSummary::getCreationTime))
                .orElse(null);
    }

    /**
     * @return all predictors, following pagination until no next token is returned
     */
    static List<PredictorSummary> listPredictors(AmazonForecast forecastClient) {
        return collectAllPages(
                nextToken -> forecastClient.listPredictors(new ListPredictorsRequest().withNextToken(nextToken)),
                ListPredictorsResult::getPredictors,
                ListPredictorsResult::getNextToken);
    }

    /**
     * @return all forecasts, following pagination until no next token is returned
     */
    static List<ForecastSummary> listForecasts(AmazonForecast forecastClient) {
        return collectAllPages(
                nextToken -> forecastClient.listForecasts(new ListForecastsRequest().withNextToken(nextToken)),
                ListForecastsResult::getForecasts,
                ListForecastsResult::getNextToken);
    }

    /**
     * Drains a paginated "list" API. The first call is made with a {@code null} token, and each following call
     * passes the previous page's next token until that token is {@code null}.
     *
     * @param fetchPage   issues one list call for the given token
     * @param itemsOf     extracts the items from one page
     * @param nextTokenOf extracts the continuation token from one page
     * @param <R>         page (list result) type
     * @param <T>         item type
     * @return the items of every page, in the order they were returned
     */
    private static <R, T> List<T> collectAllPages(final Function<String, R> fetchPage,
                                                  final Function<R, List<T>> itemsOf,
                                                  final Function<R, String> nextTokenOf) {
        List<T> items = new ArrayList<>();
        String nextToken = null;
        do {
            R page = fetchPage.apply(nextToken);
            items.addAll(itemsOf.apply(page));
            nextToken = nextTokenOf.apply(page);
        } while (nextToken != null);
        return items;
    }

    /**
     * Converts a Lambda function ARN to the ARN prefix (with trailing ':') of Forecast resources in the same
     * partition, region and account. For example:
     * {@code arn:aws:lambda:us-east-1:123456789012:function:CreateDataset}
     * becomes {@code arn:aws:forecast:us-east-1:123456789012:}
     *
     * @throws IllegalArgumentException if the ARN has fewer than five ':'-separated components
     */
    static String deriveForecastResourceArnPrefixFromLambdaFunctionArn(@NonNull final String functionArn) {
        String[] functionArnComponents = functionArn.split(ARN_COMPONENT_SPLITTER);
        if (functionArnComponents.length < 5) {
            throw new IllegalArgumentException(String.format("Malformed Lambda function ARN: %s", functionArn));
        }
        StringJoiner forecastResourceArnPrefix = new StringJoiner(ARN_COMPONENT_SPLITTER, "", ARN_COMPONENT_SPLITTER);
        forecastResourceArnPrefix
                .add(functionArnComponents[0]) // arn
                .add(functionArnComponents[1]) // partition: "aws", or "aws-cn"
                .add(FORECAST_SERVICE_NAME)    // forecast
                .add(functionArnComponents[3]) // region: "us-west-2"
                .add(functionArnComponents[4]); // accountId: "0123456789"
        return forecastResourceArnPrefix.toString();
    }

    /**
     * Rounds a window size up to the smallest Forecast data frequency that is at least as long. Anything longer
     * than 30 days maps to "Y".
     *
     * Refer to: https://docs.aws.amazon.com/forecast/latest/dg/API_CreateDataset.html#forecast-CreateDataset-request-DataFrequency.
     * Valid intervals are Y (Year), M (Month), W (Week), D (Day), H (Hour), 30min (30 minutes),
     * 15min (15 minutes), 10min (10 minutes), 5min (5 minutes), and 1min (1 minute).
     */
    static String getForecastDataFrequencyStr(final Duration dataFrequencyDuration) {
        if (dataFrequencyDuration.compareTo(ONE_MIN_DURATION) <= 0) {
            return ONE_MIN_DATA_FREQUENCY_STRING;
        }
        if (dataFrequencyDuration.compareTo(FIVE_MIN_DURATION) <= 0) {
            return FIVE_MIN_DATA_FREQUENCY_STRING;
        }
        if (dataFrequencyDuration.compareTo(TEN_MIN_DURATION) <= 0) {
            return TEN_MIN_DATA_FREQUENCY_STRING;
        }
        if (dataFrequencyDuration.compareTo(FIFTEEN_MIN_DURATION) <= 0) {
            return FIFTEEN_MIN_DATA_FREQUENCY_STRING;
        }
        if (dataFrequencyDuration.compareTo(THIRTY_MIN_DURATION) <= 0) {
            return THIRTY_MIN_DATA_FREQUENCY_STRING;
        }
        if (dataFrequencyDuration.compareTo(ONE_HOUR_DURATION) <= 0) {
            return ONE_HOUR_DATA_FREQUENCY_STRING;
        }
        if (dataFrequencyDuration.compareTo(ONE_DAY_DURATION) <= 0) {
            return ONE_DAY_DATA_FREQUENCY_STRING;
        }
        if (dataFrequencyDuration.compareTo(ONE_WEEK_DURATION) <= 0) {
            return ONE_WEEK_DATA_FREQUENCY_STRING;
        }
        if (dataFrequencyDuration.compareTo(ONE_MONTH_DURATION) <= 0) {
            return ONE_MONTH_DATA_FREQUENCY_STRING;
        }
        return ONE_YEAR_DATA_FREQUENCY_STRING;
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/predictiongeneration/exception/ResourceCleanupInProgressException.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration.exception;

/**
 * Unchecked exception for resource cleanup that has not finished yet.
 * NOTE(review): semantics inferred from the name; confirm against the throwing handlers.
 */
public class ResourceCleanupInProgressException extends RuntimeException {

    public static final long serialVersionUID = 3032649540135073597L;

    public ResourceCleanupInProgressException(String message) {
        super(message);
    }

    public ResourceCleanupInProgressException(String message, Throwable cause) {
        super(message, cause);
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/predictiongeneration/exception/ResourceSetupFailureException.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration.exception;

/**
 * Unchecked exception for resource setup that has failed.
 * NOTE(review): semantics inferred from the name; confirm against the throwing handlers.
 */
public class ResourceSetupFailureException extends RuntimeException {

    public static final long serialVersionUID = 457081345023074308L;

    public ResourceSetupFailureException(String message) {
        super(message);
    }

    public ResourceSetupFailureException(String message, Throwable cause) {
        super(message, cause);
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/predictiongeneration/exception/ResourceSetupInProgressException.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration.exception;

/**
 * Unchecked exception for a resource (or input) that is not ready yet. For example,
 * GenerateForecastResourcesIdsHandler throws it when the demand file is missing or empty.
 */
public class ResourceSetupInProgressException extends RuntimeException {

    public static final long serialVersionUID = -7480718769026135873L;

    public ResourceSetupInProgressException(String message) {
        super(message);
    }

    public ResourceSetupInProgressException(String message, Throwable cause) {
        super(message, cause);
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/queryingpredictionresult/LoadDataFromS3ToDynamoDBHandler.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.queryingpredictionresult; 2 | 3 | import com.amazonaws.dagger.DaggerLambdaFunctionsComponent; 4 | import com.amazonaws.services.dynamodbv2.AmazonDynamoDB; 5 | import com.amazonaws.services.dynamodbv2.datamodeling.DynamoDBMapper; 6 | import com.amazonaws.services.dynamodbv2.datamodeling.DynamoDBMapperConfig; 7 | import com.amazonaws.services.dynamodbv2.model.AttributeValue; 8 | import com.amazonaws.services.dynamodbv2.model.ComparisonOperator; 9 | import com.amazonaws.services.dynamodbv2.model.Condition; 10 | import com.amazonaws.services.dynamodbv2.model.Put; 11 | import com.amazonaws.services.dynamodbv2.model.QueryRequest; 12 | import com.amazonaws.services.dynamodbv2.model.QueryResult; 13 | import com.amazonaws.services.dynamodbv2.model.TransactWriteItem; 14 | import com.amazonaws.services.dynamodbv2.model.TransactWriteItemsRequest; 15 | import com.amazonaws.services.lambda.runtime.Context; 16 | import com.amazonaws.services.lambda.runtime.RequestHandler; 17 | import com.amazonaws.services.lambda.runtime.events.S3Event; 18 | import com.amazonaws.services.s3.AmazonS3; 19 | import com.amazonaws.services.s3.event.S3EventNotification.S3EventNotificationRecord; 20 | import com.amazonaws.services.s3.model.GetObjectRequest; 21 | import com.amazonaws.services.s3.model.S3Object; 22 | import com.google.common.annotations.VisibleForTesting; 23 | import com.google.common.collect.Lists; 24 | import com.opencsv.bean.CsvToBean; 25 | import com.opencsv.bean.CsvToBeanBuilder; 26 | import lombok.NonNull; 27 | import lombok.RequiredArgsConstructor; 28 | import lombok.extern.slf4j.Slf4j; 29 | import org.joda.time.Duration; 30 | import org.joda.time.format.DateTimeFormat; 31 | import org.joda.time.format.DateTimeFormatter; 32 | 33 | import javax.inject.Inject; 34 | import 
java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Lambda handler invoked with an S3 event for a Forecast export CSV. It:
 * <ol>
 *   <li>extracts the forecast export job name from the object key,</li>
 *   <li>parses the CSV rows into {@link PredictionResultItem}s, appends "$" plus the job name to each hash key,
 *       stamps an expiration time and batch-writes the items to the prediction table,</li>
 *   <li>reads back the first two items of one hash key to derive the prediction data frequency, and</li>
 *   <li>writes the job name and that frequency to the prediction metadata table in one transaction.</li>
 * </ol>
 */
@RequiredArgsConstructor
@Slf4j
public class LoadDataFromS3ToDynamoDBHandler implements RequestHandler<S3Event, Void> {

    private static final String DYNAMODB_PREDICTION_TABLE_NAME = System.getenv("PREDICTION_TABLE_NAME");
    private static final String DYNAMODB_PREDICTION_TABLE_HASH_KEY_NAME = System.getenv("PREDICTION_TABLE_HASH_KEY");
    private static final String DYNAMODB_PREDICTION_TABLE_RANGE_KEY_NAME = System.getenv("PREDICTION_TABLE_RANGE_KEY");

    // The item lifespan should be aligned with the forecast horizon. It is parsed once at class load, so a
    // missing or non-numeric FORECAST_HORIZON_IN_DAYS still fails initialization. The absolute expiration time
    // is computed per invocation: a static value would be frozen at class-load time and reused by every later
    // invocation in the same JVM.
    private static final long DYNAMODB_PREDICTION_TABLE_ITEM_LIFESPAN_IN_DAYS =
            Long.parseLong(System.getenv("FORECAST_HORIZON_IN_DAYS"));
    // The trailing 'Z' is matched as a literal, so parse explicitly in UTC rather than the JVM default zone.
    private static final DateTimeFormatter PREDICTION_TIMESTAMP_FORMATTER =
            DateTimeFormat.forPattern("YYYY-MM-dd'T'HH:mm:ss'Z'").withZoneUTC();
    private static final int REQUIRED_NUMBER_OF_PREDICTION_RESULT_ITEMS_FOR_DERIVING_DATA_FREQUENCY = 2;

    private static final String DYNAMODB_PREDICTION_METADATA_TABLE_NAME = System.getenv("PREDICTION_METADATA_TABLE_NAME");
    private static final String DYNAMODB_PREDICTION_METADATA_HASH_KEY_NAME = System.getenv("PREDICTION_METADATA_TABLE_HASH_KEY");
    private static final String DYNAMODB_PREDICTION_METADATA_ATTRIBUTE_NAME = System.getenv("PREDICTION_METADATA_TABLE_ATTRIBUTE_NAME");

    private static final String PREDICTION_TABLE_CSV_VALUE_SPLITTER = "$";
    @VisibleForTesting
    static final String DYNAMODB_PREDICTION_METADATA_LATEST_PRED_UUID_ATTR_NAME = "LatestPredictionUUID";
    @VisibleForTesting
    static final String DYNAMODB_PREDICTION_METADATA_LATEST_PRED_DATA_FREQ_IN_SEC_ATTR_NAME = "LatestPredictionDataFrequencyInSeconds";

    // An example of prediction file name: target/fej_1571260106456_2019-10-16T21-40-00Z_part0.csv
    private static final String PREDICTION_RESULT_FILE_NAME_REGEX =
            "^([a-zA-Z0-9_-]+)/" +                            // folder, e.g. "target/"
            "([a-zA-Z0-9_-]+)" +                              // forecast export job name, e.g. "fej_1571260106456"
            "_(\\d{4}-\\d{2}-\\d{2}T\\d{2}-\\d{2}-\\d{2}Z)" + // export timestamp, e.g. "_2019-10-16T21-40-00Z"
            "_(part\\d+\\.csv)$";                             // part suffix with any part number, e.g. "_part0.csv", "_part12.csv"
    private static final Pattern PREDICTION_RESULT_FILE_NAME_PATTERN = Pattern.compile(PREDICTION_RESULT_FILE_NAME_REGEX);

    /*
     * Refer to: https://docs.oracle.com/javase/7/docs/api/java/util/regex/Matcher.html#group%28int%29
     * group(0) matches the entire input; group(1) is the first parenthesized group.
     */
    private static final int FORECAST_EXPORT_JOB_NAME_INDEX = 2;

    @Inject
    @NonNull
    AmazonS3 s3Client;

    @Inject
    @NonNull
    AmazonDynamoDB ddbClient;

    public LoadDataFromS3ToDynamoDBHandler() {
        DaggerLambdaFunctionsComponent.create().inject(this);
    }

    /**
     * @param s3Event notification for the newly written prediction result object; only its first record is used
     * @param context unused
     * @return always {@code null}
     * @throws RuntimeException if the key does not match the expected export file name, the file has no
     *                          records, any batch fails to write, or fewer than two items are read back
     */
    @Override
    public Void handleRequest(S3Event s3Event, Context context) {
        /*
         * Based on https://forums.aws.amazon.com/thread.jspa?messageID=592264#592264
         * all S3 event notifications have a single event(record) per notification message,
         * which means there will be only 1 record in records list
         */
        S3EventNotificationRecord record = s3Event.getRecords().get(0);
        String srcBucket = record.getS3().getBucket().getName();
        String srcKey = record.getS3().getObject().getKey();

        String forecastExportJobName = parseForecastExportJobName(srcKey);
        List<PredictionResultItem> predictionResultItems =
                loadPredictionResultItems(srcBucket, srcKey, forecastExportJobName);
        savePredictionResultItems(predictionResultItems);

        // Any hash key works: one prediction result file is assumed to have a single data frequency.
        long predictionDataFreqInSecs =
                queryPredictionDataFrequencyInSeconds(predictionResultItems.get(0).getHashKey());
        updatePredictionMetadata(forecastExportJobName, predictionDataFreqInSecs);

        return null;
    }

    /** Extracts the forecast export job name (e.g. "fej_1571260106456") from the object key. */
    private static String parseForecastExportJobName(final String srcKey) {
        Matcher predictionResultUuidMatcher = PREDICTION_RESULT_FILE_NAME_PATTERN.matcher(srcKey);
        if (!predictionResultUuidMatcher.matches()) {
            String errorMsg = String.format("Cannot parse prediction result object key: %s", srcKey);
            throw new RuntimeException(errorMsg);
        }
        return predictionResultUuidMatcher.group(FORECAST_EXPORT_JOB_NAME_INDEX);
    }

    /**
     * Reads the CSV object into items. Each item's hash key is suffixed with
     * {@code "$" + forecastExportJobName}, and each item's expiration time is set to now plus the configured
     * lifespan (epoch seconds). The S3 content stream is closed once it has been parsed.
     */
    private List<PredictionResultItem> loadPredictionResultItems(final String srcBucket,
                                                                 final String srcKey,
                                                                 final String forecastExportJobName) {
        S3Object s3Object = s3Client.getObject(new GetObjectRequest(srcBucket, srcKey));
        log.info(String.format("Start processing s3 object: %s, with forecast export job name: %s",
                s3Object.toString(), forecastExportJobName));

        List<PredictionResultItem> predictionResultItems;
        try (BufferedReader s3ObjectReader = new BufferedReader(
                new InputStreamReader(s3Object.getObjectContent(), StandardCharsets.UTF_8))) {
            CsvToBean<PredictionResultItem> csvToBean = new CsvToBeanBuilder<PredictionResultItem>(s3ObjectReader)
                    .withType(PredictionResultItem.class)
                    .withIgnoreLeadingWhiteSpace(true)
                    .build();
            predictionResultItems = Lists.newArrayList(csvToBean.iterator());
        } catch (IOException e) {
            throw new UncheckedIOException(String.format("Failed to read prediction result file %s", srcKey), e);
        }

        int numberOfNewItems = predictionResultItems.size();
        if (numberOfNewItems == 0) {
            throw new RuntimeException(String.format("Prediction result file %s contains no record.", srcKey));
        }

        long expirationTime = Instant.now()
                .plus(DYNAMODB_PREDICTION_TABLE_ITEM_LIFESPAN_IN_DAYS, ChronoUnit.DAYS)
                .getEpochSecond();
        for (PredictionResultItem item : predictionResultItems) {
            item.setHashKey(String.format("%s%s%s",
                    item.getHashKey(), PREDICTION_TABLE_CSV_VALUE_SPLITTER, forecastExportJobName));
            item.setExpirationTime(expirationTime);
        }
        log.info(String.format("Finish loading and parsing %d new items from S3.", numberOfNewItems));
        return predictionResultItems;
    }

    /**
     * Batch-writes the items to the prediction table, overriding the table name with PREDICTION_TABLE_NAME.
     * {@code batchSave} reports failed batches in its return value instead of throwing, so they are checked
     * explicitly. This prevents the metadata from pointing at a partially loaded prediction.
     */
    private void savePredictionResultItems(final List<PredictionResultItem> predictionResultItems) {
        DynamoDBMapper mapper = new DynamoDBMapper(ddbClient,
                DynamoDBMapperConfig.builder()
                        .withTableNameOverride(DynamoDBMapperConfig
                                .TableNameOverride.withTableNameReplacement(DYNAMODB_PREDICTION_TABLE_NAME))
                        .build());
        List<DynamoDBMapper.FailedBatch> failedBatches = mapper.batchSave(predictionResultItems);
        if (failedBatches != null && !failedBatches.isEmpty()) {
            throw new RuntimeException(String.format(
                    "Failed to write %d batch(es) of prediction result items to DynamoDB.", failedBatches.size()),
                    failedBatches.get(0).getException());
        }
        log.info("Finish writing to DynamoDB Table.");
    }

    /**
     * Reads the first two items (ascending range key, consistent read) of {@code hashKey} and derives the data
     * frequency from the gap between their timestamps.
     */
    private long queryPredictionDataFrequencyInSeconds(final String hashKey) {
        Condition hashKeyCondition = new Condition()
                .withComparisonOperator(ComparisonOperator.EQ)
                .withAttributeValueList(new AttributeValue(hashKey));
        Map<String, Condition> keyConditions = new HashMap<>();
        keyConditions.put(DYNAMODB_PREDICTION_TABLE_HASH_KEY_NAME, hashKeyCondition);
        QueryRequest queryRequest = new QueryRequest()
                .withTableName(DYNAMODB_PREDICTION_TABLE_NAME)
                .withKeyConditions(keyConditions)
                .withConsistentRead(true)
                .withScanIndexForward(true) /* ascending order for the range key*/
                .withLimit(REQUIRED_NUMBER_OF_PREDICTION_RESULT_ITEMS_FOR_DERIVING_DATA_FREQUENCY); /* get the first 2 items */
        QueryResult queryResult = ddbClient.query(queryRequest);
        return derivePredDataFreqFromConsecutiveItems(queryResult.getItems());
    }

    /**
     * Writes the latest export job name and the data frequency (seconds, stored as a string) to the prediction
     * metadata table with one TransactWriteItems call, so both items are written or neither is.
     */
    private void updatePredictionMetadata(final String forecastExportJobName, final long predictionDataFreqInSecs) {
        Map<String, AttributeValue> latestPredictionUUIDItem = new HashMap<>();
        latestPredictionUUIDItem.put(DYNAMODB_PREDICTION_METADATA_HASH_KEY_NAME,
                new AttributeValue(DYNAMODB_PREDICTION_METADATA_LATEST_PRED_UUID_ATTR_NAME));
        latestPredictionUUIDItem.put(DYNAMODB_PREDICTION_METADATA_ATTRIBUTE_NAME, new AttributeValue(forecastExportJobName));
        Put latestPredictionUUIDItemWrite = new Put()
                .withTableName(DYNAMODB_PREDICTION_METADATA_TABLE_NAME)
                .withItem(latestPredictionUUIDItem);

        Map<String, AttributeValue> latestPredictionDataFrequencyItem = new HashMap<>();
        latestPredictionDataFrequencyItem.put(DYNAMODB_PREDICTION_METADATA_HASH_KEY_NAME,
                new AttributeValue(DYNAMODB_PREDICTION_METADATA_LATEST_PRED_DATA_FREQ_IN_SEC_ATTR_NAME));
        latestPredictionDataFrequencyItem.put(DYNAMODB_PREDICTION_METADATA_ATTRIBUTE_NAME,
                new AttributeValue(String.valueOf(predictionDataFreqInSecs)));
        Put latestPredictionDataFrequencyWrite = new Put()
                .withTableName(DYNAMODB_PREDICTION_METADATA_TABLE_NAME)
                .withItem(latestPredictionDataFrequencyItem);

        Collection<TransactWriteItem> transactWrites = Arrays.asList(
                new TransactWriteItem().withPut(latestPredictionUUIDItemWrite),
                new TransactWriteItem().withPut(latestPredictionDataFrequencyWrite)
        );
        TransactWriteItemsRequest writeItemsRequest = new TransactWriteItemsRequest()
                .withTransactItems(transactWrites);

        ddbClient.transactWriteItems(writeItemsRequest);
        log.info("Finish updating new metadata items for the latest prediction");
    }

    /**
     * Derive the data frequency by calculating the diff on the rangeKey(timestamp) for the top two items.
     * We can make the assumption that one prediction result file can only have one data frequency.
     * @return The data frequency(window size) in seconds
     * @throws RuntimeException if fewer than two items are given or the two timestamps are equal
     */
    private long derivePredDataFreqFromConsecutiveItems(List<Map<String, AttributeValue>> items) {
        if (items == null || items.size() < REQUIRED_NUMBER_OF_PREDICTION_RESULT_ITEMS_FOR_DERIVING_DATA_FREQUENCY) {
            throw new RuntimeException(String.format("Passed in items contains %d item, which is less than 2.",
                    items == null ? 0 : items.size()));
        }

        Map<String, AttributeValue> firstItem = items.get(0);
        Map<String, AttributeValue> secondItem = items.get(1);
        String firstTsStr = firstItem.get(DYNAMODB_PREDICTION_TABLE_RANGE_KEY_NAME).getS();
        String secondTsStr = secondItem.get(DYNAMODB_PREDICTION_TABLE_RANGE_KEY_NAME).getS();
        long dataFreqInSeconds = Math.abs(new Duration(PREDICTION_TIMESTAMP_FORMATTER.parseDateTime(firstTsStr),
                PREDICTION_TIMESTAMP_FORMATTER.parseDateTime(secondTsStr)).getStandardSeconds());

        if (dataFreqInSeconds == 0) {
            throw new RuntimeException(String.format("dataFreqInSeconds [%d] derived from firstItem [%s] and secondItem [%s] is 0",
                    dataFreqInSeconds, firstItem.toString(), secondItem.toString()));
        }
        return dataFreqInSeconds;
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/amazonaws/lambda/queryingpredictionresult/PredictionResultItem.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.queryingpredictionresult;

import com.amazonaws.services.dynamodbv2.datamodeling.DynamoDBAttribute;
import com.amazonaws.services.dynamodbv2.datamodeling.DynamoDBHashKey;
import com.amazonaws.services.dynamodbv2.datamodeling.DynamoDBRangeKey;
import com.amazonaws.services.dynamodbv2.datamodeling.DynamoDBTable;
import com.opencsv.bean.CsvBindByName;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * One row of a prediction result CSV, bound from the CSV by column name and mapped to a DynamoDB item.
 * The item_id, date, p10, p50 and p90 columns are required in the CSV; expirationTime is optional.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
@Builder
@DynamoDBTable(tableName = PredictionResultItem.TABLE_NAME)
public class PredictionResultItem {

    public static final String TABLE_NAME = "PredictionResultItem";

    /** Column / attribute names shared by the CSV binding and the DynamoDB mapping. */
    public static class Attribute {
        public static final String ITEM_ID = "item_id";
        public static final String DATE = "date";
        public static final String P10 = "p10";
        public static final String P50 = "p50";
        public static final String P90 = "p90";
        public static final String EXPIRATION_TIME = "expirationTime";
    }

    @DynamoDBHashKey(attributeName = Attribute.ITEM_ID)
    @CsvBindByName(column = Attribute.ITEM_ID, required = true)
    private String hashKey;

    @DynamoDBRangeKey(attributeName = Attribute.DATE)
    @CsvBindByName(column = Attribute.DATE, required = true)
    private String sortKey;

    @DynamoDBAttribute(attributeName = Attribute.P10)
    @CsvBindByName(column = Attribute.P10, required = true)
    private double p10;

    @DynamoDBAttribute(attributeName = Attribute.P50)
    @CsvBindByName(column = Attribute.P50, required = true)
    private double p50;

    @DynamoDBAttribute(attributeName = Attribute.P90)
    @CsvBindByName(column = Attribute.P90, required = true)
    private double p90;

    @DynamoDBAttribute(attributeName = Attribute.EXPIRATION_TIME)
    @CsvBindByName
    private long expirationTime;
}

--------------------------------------------------------------------------------
/src/main/java/log4j2.xml:
--------------------------------------------------------------------------------

%d{yyyy-MM-dd HH:mm:ss,SSS} %X{AWSRequestId} %-5p %c{1}:%L - %m%n

--------------------------------------------------------------------------------
/src/test/java/com/amazonaws/lambda/demandpublishing/PublishDemandHandlerTest.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.demandpublishing;

import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.s3.model.ObjectMetadata;
import com.amazonaws.services.s3.transfer.TransferManager;
import com.amazonaws.services.s3.transfer.Upload;
import com.amazonaws.util.IOUtils;
import
org.junit.Rule; 9 | import org.junit.contrib.java.lang.system.EnvironmentVariables; 10 | import org.junit.jupiter.api.BeforeEach; 11 | import org.junit.jupiter.api.Test; 12 | import org.mockito.ArgumentCaptor; 13 | import org.mockito.Mock; 14 | 15 | import java.io.InputStream; 16 | import java.time.Clock; 17 | import java.time.LocalDateTime; 18 | import java.time.ZoneId; 19 | import java.time.ZoneOffset; 20 | 21 | import static org.junit.jupiter.api.Assertions.assertEquals; 22 | import static org.mockito.ArgumentMatchers.any; 23 | import static org.mockito.ArgumentMatchers.eq; 24 | import static org.mockito.Mockito.mock; 25 | import static org.mockito.Mockito.times; 26 | import static org.mockito.Mockito.verify; 27 | import static org.mockito.Mockito.when; 28 | 29 | public class PublishDemandHandlerTest { 30 | 31 | private static final String PREDICTION_S3_BUCKET_NAME = "testBucket"; 32 | private static final String SRC_S3_FOLDER = "testSrc"; 33 | private static final String S3_TRAINING_DATA_FILE_NAME = "testDemandFile"; 34 | 35 | @Rule 36 | public final EnvironmentVariables environmentVariables = new EnvironmentVariables(); 37 | 38 | @Mock 39 | private Context context; 40 | 41 | private Clock fixedClock; 42 | private String testRawDemandRequestsFilePath; 43 | private TransferManager mockTransferManager; 44 | private PublishDemandHandler handler; 45 | 46 | @BeforeEach 47 | void setup() { 48 | environmentVariables.set("PREDICTION_S3_BUCKET_NAME", PREDICTION_S3_BUCKET_NAME); 49 | environmentVariables.set("SRC_S3_FOLDER", SRC_S3_FOLDER); 50 | environmentVariables.set("S3_TRAINING_DATA_FILE_NAME", S3_TRAINING_DATA_FILE_NAME); 51 | 52 | fixedClock = Clock.fixed(LocalDateTime.of(2023, 3, 1, 1, 1) 53 | .toInstant(ZoneOffset.UTC), ZoneId.of("UTC")); 54 | testRawDemandRequestsFilePath = "/test_raw_demand_requests.csv"; 55 | mockTransferManager = mock(TransferManager.class); 56 | handler = new PublishDemandHandler(fixedClock, testRawDemandRequestsFilePath, 
mockTransferManager); 57 | } 58 | 59 | @Test 60 | public void testPublishDemand() throws Exception { 61 | // Check the first two records in src/test/resources/test_raw_demand_requests.csv for such info 62 | String expectedDemandRecordsStr = "item_id,timestamp,target_value\n5,2020-01-01 03:50:33,14\n5,2020-02-01 03:53:14,14"; 63 | ObjectMetadata expectedMetadata = new ObjectMetadata(); 64 | expectedMetadata.setContentLength(expectedDemandRecordsStr.length()); 65 | when(mockTransferManager.upload(any(String.class), any(String.class), any(InputStream.class), any(ObjectMetadata.class))) 66 | .thenReturn(mock(Upload.class)); 67 | 68 | handler.handleRequest(null, context); 69 | 70 | ArgumentCaptor streamCaptor = ArgumentCaptor.forClass(InputStream.class); 71 | ArgumentCaptor objectMetadataCaptor = ArgumentCaptor.forClass(ObjectMetadata.class); 72 | verify(mockTransferManager, times(1)).upload(eq(PREDICTION_S3_BUCKET_NAME), 73 | eq(String.format("%s/%s", SRC_S3_FOLDER, S3_TRAINING_DATA_FILE_NAME)), 74 | streamCaptor.capture(), objectMetadataCaptor.capture()); 75 | assertEquals(expectedDemandRecordsStr, IOUtils.toString(streamCaptor.getValue())); 76 | assertEquals(expectedMetadata.getContentLength(), objectMetadataCaptor.getValue().getContentLength()); 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/test/java/com/amazonaws/lambda/predictiongeneration/BaseTest.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.predictiongeneration; 2 | 3 | import com.amazonaws.services.forecast.AmazonForecast; 4 | import org.junit.jupiter.api.BeforeEach; 5 | 6 | import java.util.Map; 7 | 8 | import static com.amazonaws.lambda.predictiongeneration.GenerateForecastResourcesIdsHandler.buildResourceIdMap; 9 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.deriveForecastResourceArnPrefixFromLambdaFunctionArn; 10 | import static 
org.mockito.Mockito.mock; 11 | 12 | public class BaseTest { 13 | 14 | protected static final String TEST_FUNCTION_ARN = "arn:aws:lambda:us-east-1:012345678901:function:Dummy"; 15 | protected static final String TEST_FORECAST_RESOURCE_ARN 16 | = deriveForecastResourceArnPrefixFromLambdaFunctionArn(TEST_FUNCTION_ARN); 17 | protected static final String TEST_RESOURCE_CREATING_STATUS = "CREATING"; 18 | 19 | static final String DEFAULT_DATA_FREQUENCY_VALUE = "30min"; 20 | 21 | protected AmazonForecast mockForecastClient; 22 | protected Map testResourceIdMap; 23 | 24 | @BeforeEach 25 | public void baseSetup() { 26 | mockForecastClient = mock(AmazonForecast.class); 27 | testResourceIdMap = buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE); 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/test/java/com/amazonaws/lambda/predictiongeneration/CreateDatasetGroupHandlerTest.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.predictiongeneration; 2 | 3 | import com.amazonaws.services.forecast.model.CreateDatasetGroupRequest; 4 | import org.junit.jupiter.api.BeforeEach; 5 | import org.junit.jupiter.api.Test; 6 | 7 | import static com.amazonaws.lambda.predictiongeneration.GenerateForecastResourcesIdsHandler.buildResourceIdMap; 8 | import static org.mockito.ArgumentMatchers.any; 9 | import static org.mockito.Mockito.times; 10 | import static org.mockito.Mockito.verify; 11 | 12 | public class CreateDatasetGroupHandlerTest extends BaseTest { 13 | 14 | CreateDatasetGroupHandler handler; 15 | 16 | @BeforeEach 17 | public void setup() { 18 | handler = new CreateDatasetGroupHandler(mockForecastClient); 19 | } 20 | 21 | @Test 22 | public void testProcess() { 23 | handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE)); 24 | 
verify(mockForecastClient, times(1)).createDatasetGroup(any(CreateDatasetGroupRequest.class)); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/test/java/com/amazonaws/lambda/predictiongeneration/CreateDatasetHandlerTest.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.predictiongeneration; 2 | 3 | import com.amazonaws.lambda.predictiongeneration.exception.ResourceSetupFailureException; 4 | import com.amazonaws.lambda.predictiongeneration.exception.ResourceSetupInProgressException; 5 | import com.amazonaws.services.forecast.model.CreateDatasetRequest; 6 | import com.amazonaws.services.forecast.model.DescribeDatasetRequest; 7 | import com.amazonaws.services.forecast.model.DescribeDatasetResult; 8 | import com.amazonaws.services.forecast.model.ResourceNotFoundException; 9 | import org.junit.jupiter.api.BeforeEach; 10 | import org.junit.jupiter.api.Test; 11 | import org.mockito.invocation.InvocationOnMock; 12 | import org.mockito.stubbing.Answer; 13 | 14 | import static com.amazonaws.lambda.predictiongeneration.GenerateForecastResourcesIdsHandler.buildResourceIdMap; 15 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.RESOURCE_ACTIVE_STATUS; 16 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.RESOURCE_FAILED_STATUS; 17 | import static org.junit.jupiter.api.Assertions.assertThrows; 18 | import static org.mockito.ArgumentMatchers.any; 19 | import static org.mockito.Mockito.never; 20 | import static org.mockito.Mockito.times; 21 | import static org.mockito.Mockito.verify; 22 | import static org.mockito.Mockito.when; 23 | 24 | public class CreateDatasetHandlerTest extends BaseTest { 25 | 26 | CreateDatasetHandler handler; 27 | 28 | @BeforeEach 29 | public void setup() { 30 | handler = new CreateDatasetHandler(mockForecastClient); 31 | } 32 | 33 | @Test 34 | public void 
testProcess_withActiveStatus() { 35 | DescribeDatasetResult dummyDescribeDatasetResult = new DescribeDatasetResult().withStatus(RESOURCE_ACTIVE_STATUS); 36 | when(mockForecastClient.describeDataset(any(DescribeDatasetRequest.class))).thenReturn(dummyDescribeDatasetResult); 37 | 38 | handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE)); 39 | 40 | verify(mockForecastClient, times(1)).describeDataset(any(DescribeDatasetRequest.class)); 41 | verify(mockForecastClient, never()).createDataset(any(CreateDatasetRequest.class)); 42 | } 43 | 44 | @Test 45 | public void testProcess_withFailedStatus() { 46 | DescribeDatasetResult dummyDescribeDatasetResult = new DescribeDatasetResult().withStatus(RESOURCE_FAILED_STATUS); 47 | when(mockForecastClient.describeDataset(any(DescribeDatasetRequest.class))).thenReturn(dummyDescribeDatasetResult); 48 | 49 | assertThrows(ResourceSetupFailureException.class, 50 | () -> handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE))); 51 | 52 | verify(mockForecastClient, times(1)).describeDataset(any(DescribeDatasetRequest.class)); 53 | verify(mockForecastClient, never()).createDataset(any(CreateDatasetRequest.class)); 54 | } 55 | 56 | @Test 57 | public void testProcess_withCreateNewDatasetThenInCreating() { 58 | DescribeDatasetResult dummyDescribeDatasetResult = new DescribeDatasetResult().withStatus(TEST_RESOURCE_CREATING_STATUS); 59 | when(mockForecastClient.describeDataset(any(DescribeDatasetRequest.class))) 60 | .thenAnswer(new Answer() { 61 | private int count = 0; 62 | 63 | @Override 64 | public DescribeDatasetResult answer(InvocationOnMock invocation) { 65 | switch (count++) { 66 | case 0: 67 | throw new ResourceNotFoundException("cannot find given dataset"); 68 | case 1: 69 | return dummyDescribeDatasetResult; 70 | default: 71 | throw new IllegalArgumentException(); 72 | } 73 | } 74 | }); 75 | 76 | 
assertThrows(ResourceSetupInProgressException.class, 77 | () -> handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE))); 78 | 79 | verify(mockForecastClient, times(2)).describeDataset(any(DescribeDatasetRequest.class)); 80 | verify(mockForecastClient, times(1)).createDataset(any(CreateDatasetRequest.class)); 81 | } 82 | 83 | @Test 84 | public void testProcess_withCreateNewDatasetThenInActive() { 85 | DescribeDatasetResult dummyDescribeDatasetResult = new DescribeDatasetResult().withStatus(RESOURCE_ACTIVE_STATUS); 86 | when(mockForecastClient.describeDataset(any(DescribeDatasetRequest.class))) 87 | .thenAnswer(new Answer() { 88 | private int count = 0; 89 | 90 | @Override 91 | public DescribeDatasetResult answer(InvocationOnMock invocation) { 92 | switch (count++) { 93 | case 0: 94 | throw new ResourceNotFoundException("cannot find given dataset"); 95 | case 1: 96 | return dummyDescribeDatasetResult; 97 | default: 98 | throw new IllegalArgumentException(); 99 | } 100 | } 101 | }); 102 | 103 | handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE)); 104 | 105 | verify(mockForecastClient, times(2)).describeDataset(any(DescribeDatasetRequest.class)); 106 | verify(mockForecastClient, times(1)).createDataset(any(CreateDatasetRequest.class)); 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /src/test/java/com/amazonaws/lambda/predictiongeneration/CreateDatasetImportJobHandlerTest.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.predictiongeneration; 2 | 3 | import com.amazonaws.lambda.predictiongeneration.exception.ResourceSetupFailureException; 4 | import com.amazonaws.lambda.predictiongeneration.exception.ResourceSetupInProgressException; 5 | import com.amazonaws.services.forecast.model.CreateDatasetImportJobRequest; 
6 | import com.amazonaws.services.forecast.model.DescribeDatasetImportJobRequest; 7 | import com.amazonaws.services.forecast.model.DescribeDatasetImportJobResult; 8 | import com.amazonaws.services.forecast.model.ResourceNotFoundException; 9 | import org.junit.jupiter.api.BeforeEach; 10 | import org.junit.jupiter.api.Test; 11 | import org.mockito.invocation.InvocationOnMock; 12 | import org.mockito.stubbing.Answer; 13 | 14 | import static com.amazonaws.lambda.predictiongeneration.GenerateForecastResourcesIdsHandler.buildResourceIdMap; 15 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.RESOURCE_ACTIVE_STATUS; 16 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.RESOURCE_FAILED_STATUS; 17 | import static org.junit.jupiter.api.Assertions.assertThrows; 18 | import static org.mockito.ArgumentMatchers.any; 19 | import static org.mockito.Mockito.never; 20 | import static org.mockito.Mockito.times; 21 | import static org.mockito.Mockito.verify; 22 | import static org.mockito.Mockito.when; 23 | 24 | public class CreateDatasetImportJobHandlerTest extends BaseTest { 25 | 26 | CreateDatasetImportJobHandler handler; 27 | 28 | @BeforeEach 29 | public void setup() { 30 | handler = new CreateDatasetImportJobHandler(mockForecastClient); 31 | } 32 | 33 | @Test 34 | public void testProcess_withActiveStatus() { 35 | DescribeDatasetImportJobResult dummyDescribeDatasetImportJobResult = new DescribeDatasetImportJobResult() .withStatus(RESOURCE_ACTIVE_STATUS); 36 | when(mockForecastClient.describeDatasetImportJob(any(DescribeDatasetImportJobRequest.class))).thenReturn(dummyDescribeDatasetImportJobResult); 37 | 38 | handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE)); 39 | 40 | verify(mockForecastClient, times(1)).describeDatasetImportJob(any(DescribeDatasetImportJobRequest.class)); 41 | verify(mockForecastClient, 
never()).createDatasetImportJob(any(CreateDatasetImportJobRequest.class)); 42 | } 43 | 44 | @Test 45 | public void testProcess_withFailedStatus() { 46 | DescribeDatasetImportJobResult dummyDescribeDatasetImportJobResult = new DescribeDatasetImportJobResult().withStatus(RESOURCE_FAILED_STATUS); 47 | when(mockForecastClient.describeDatasetImportJob(any(DescribeDatasetImportJobRequest.class))).thenReturn(dummyDescribeDatasetImportJobResult); 48 | 49 | assertThrows(ResourceSetupFailureException.class, 50 | () -> handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE))); 51 | 52 | verify(mockForecastClient, times(1)).describeDatasetImportJob(any(DescribeDatasetImportJobRequest.class)); 53 | verify(mockForecastClient, never()).createDatasetImportJob(any(CreateDatasetImportJobRequest.class)); 54 | } 55 | 56 | @Test 57 | public void testProcess_withCreateNewDatasetImportJobThenInCreating() { 58 | DescribeDatasetImportJobResult dummyDescribeDatasetImportJobResult = new DescribeDatasetImportJobResult().withStatus(TEST_RESOURCE_CREATING_STATUS); 59 | when(mockForecastClient.describeDatasetImportJob(any(DescribeDatasetImportJobRequest.class))) 60 | .thenAnswer(new Answer() { 61 | private int count = 0; 62 | 63 | @Override 64 | public DescribeDatasetImportJobResult answer(InvocationOnMock invocation) { 65 | switch (count++) { 66 | case 0: 67 | throw new ResourceNotFoundException("cannot find given dataset import job"); 68 | case 1: 69 | return dummyDescribeDatasetImportJobResult; 70 | default: 71 | throw new IllegalArgumentException(); 72 | } 73 | } 74 | }); 75 | 76 | assertThrows(ResourceSetupInProgressException.class, 77 | () -> handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE))); 78 | 79 | verify(mockForecastClient, times(2)).describeDatasetImportJob(any(DescribeDatasetImportJobRequest.class)); 80 | verify(mockForecastClient, 
times(1)).createDatasetImportJob(any(CreateDatasetImportJobRequest.class)); 81 | } 82 | 83 | @Test 84 | public void testProcess_withCreateNewDatasetImportJobThenInActive() { 85 | DescribeDatasetImportJobResult dummyDescribeDatasetImportJobResult = new DescribeDatasetImportJobResult().withStatus(RESOURCE_ACTIVE_STATUS); 86 | when(mockForecastClient.describeDatasetImportJob(any(DescribeDatasetImportJobRequest.class))) 87 | .thenAnswer(new Answer() { 88 | private int count = 0; 89 | 90 | @Override 91 | public DescribeDatasetImportJobResult answer(InvocationOnMock invocation) { 92 | switch (count++) { 93 | case 0: 94 | throw new ResourceNotFoundException("cannot find given dataset import job"); 95 | case 1: 96 | return dummyDescribeDatasetImportJobResult; 97 | default: 98 | throw new IllegalArgumentException(); 99 | } 100 | } 101 | }); 102 | 103 | handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE)); 104 | 105 | verify(mockForecastClient, times(2)).describeDatasetImportJob(any(DescribeDatasetImportJobRequest.class)); 106 | verify(mockForecastClient, times(1)).createDatasetImportJob(any(CreateDatasetImportJobRequest.class)); 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /src/test/java/com/amazonaws/lambda/predictiongeneration/CreateForecastExportJobHandlerTest.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.predictiongeneration; 2 | 3 | import com.amazonaws.lambda.predictiongeneration.exception.ResourceSetupFailureException; 4 | import com.amazonaws.lambda.predictiongeneration.exception.ResourceSetupInProgressException; 5 | import com.amazonaws.services.forecast.model.CreateForecastExportJobRequest; 6 | import com.amazonaws.services.forecast.model.DescribeForecastExportJobRequest; 7 | import com.amazonaws.services.forecast.model.DescribeForecastExportJobResult; 8 | import 
com.amazonaws.services.forecast.model.ResourceNotFoundException; 9 | import org.junit.Rule; 10 | import org.junit.contrib.java.lang.system.EnvironmentVariables; 11 | import org.junit.jupiter.api.BeforeEach; 12 | import org.junit.jupiter.api.Test; 13 | import org.mockito.invocation.InvocationOnMock; 14 | import org.mockito.stubbing.Answer; 15 | 16 | import static com.amazonaws.lambda.predictiongeneration.GenerateForecastResourcesIdsHandler.buildResourceIdMap; 17 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.RESOURCE_ACTIVE_STATUS; 18 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.RESOURCE_FAILED_STATUS; 19 | import static org.junit.jupiter.api.Assertions.assertThrows; 20 | import static org.mockito.ArgumentMatchers.any; 21 | import static org.mockito.Mockito.never; 22 | import static org.mockito.Mockito.times; 23 | import static org.mockito.Mockito.verify; 24 | import static org.mockito.Mockito.when; 25 | 26 | public class CreateForecastExportJobHandlerTest extends BaseTest { 27 | 28 | private static final String TEST_FORECAST_EXPORT_RESULT_ROLE_ARN = "testExportResultRoleArn"; 29 | private static final String TEST_PREDICTION_S3_BUCKET_NAMEE = "predictionBucket"; 30 | private static final String TEST_TGT_S3_FOLDER = "resources/tgt"; 31 | @Rule 32 | public final EnvironmentVariables environmentVariables = new EnvironmentVariables(); 33 | 34 | CreateForecastExportJobHandler handler; 35 | 36 | @BeforeEach 37 | public void setup() { 38 | // Setup Env variables 39 | environmentVariables.set("FORECAST_EXPORT_RESULT_ROLE_ARN", TEST_FORECAST_EXPORT_RESULT_ROLE_ARN); 40 | environmentVariables.set("PREDICTION_S3_BUCKET_NAME", TEST_PREDICTION_S3_BUCKET_NAMEE); 41 | environmentVariables.set("TGT_S3_FOLDER", TEST_TGT_S3_FOLDER); 42 | 43 | handler = new CreateForecastExportJobHandler(mockForecastClient); 44 | } 45 | 46 | @Test 47 | public void testProcess_withActiveStatus() { 48 | 
DescribeForecastExportJobResult dummyDescribeForecastExportJobResult = new DescribeForecastExportJobResult().withStatus(RESOURCE_ACTIVE_STATUS); 49 | when(mockForecastClient.describeForecastExportJob(any(DescribeForecastExportJobRequest.class))).thenReturn(dummyDescribeForecastExportJobResult); 50 | 51 | handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE)); 52 | 53 | verify(mockForecastClient, times(1)).describeForecastExportJob(any(DescribeForecastExportJobRequest.class)); 54 | verify(mockForecastClient, never()).createForecastExportJob(any(CreateForecastExportJobRequest.class)); 55 | } 56 | 57 | @Test 58 | public void testProcess_withFailedStatus() { 59 | DescribeForecastExportJobResult dummyDescribeForecastExportJobResult = new DescribeForecastExportJobResult().withStatus(RESOURCE_FAILED_STATUS); 60 | when(mockForecastClient.describeForecastExportJob(any(DescribeForecastExportJobRequest.class))).thenReturn(dummyDescribeForecastExportJobResult); 61 | 62 | assertThrows(ResourceSetupFailureException.class, 63 | () -> handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE))); 64 | 65 | verify(mockForecastClient, times(1)).describeForecastExportJob(any(DescribeForecastExportJobRequest.class)); 66 | verify(mockForecastClient, never()).createForecastExportJob(any(CreateForecastExportJobRequest.class)); 67 | } 68 | 69 | @Test 70 | public void testProcess_withCreateNewForecastExportJobThenInCreating() { 71 | DescribeForecastExportJobResult dummyDescribeForecastExportJobResult = new DescribeForecastExportJobResult().withStatus(TEST_RESOURCE_CREATING_STATUS); 72 | when(mockForecastClient.describeForecastExportJob(any(DescribeForecastExportJobRequest.class))) 73 | .thenAnswer(new Answer() { 74 | private int count = 0; 75 | 76 | @Override 77 | public DescribeForecastExportJobResult answer(InvocationOnMock invocation) { 78 | switch (count++) { 
79 | case 0: 80 | throw new ResourceNotFoundException("cannot find given forecast export job"); 81 | case 1: 82 | return dummyDescribeForecastExportJobResult; 83 | default: 84 | throw new IllegalArgumentException(); 85 | } 86 | } 87 | }); 88 | 89 | assertThrows(ResourceSetupInProgressException.class, 90 | () -> handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE))); 91 | 92 | verify(mockForecastClient, times(2)).describeForecastExportJob(any(DescribeForecastExportJobRequest.class)); 93 | verify(mockForecastClient, times(1)).createForecastExportJob(any(CreateForecastExportJobRequest.class)); 94 | } 95 | 96 | @Test 97 | public void testProcess_withCreateNewForecastExportJobThenInActive() { 98 | DescribeForecastExportJobResult dummyDescribeForecastExportJobResult = new DescribeForecastExportJobResult().withStatus(RESOURCE_ACTIVE_STATUS); 99 | when(mockForecastClient.describeForecastExportJob(any(DescribeForecastExportJobRequest.class))) 100 | .thenAnswer(new Answer() { 101 | private int count = 0; 102 | 103 | @Override 104 | public DescribeForecastExportJobResult answer(InvocationOnMock invocation) { 105 | switch (count++) { 106 | case 0: 107 | throw new ResourceNotFoundException("cannot find given forecast export job"); 108 | case 1: 109 | return dummyDescribeForecastExportJobResult; 110 | default: 111 | throw new IllegalArgumentException(); 112 | } 113 | } 114 | }); 115 | 116 | handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE)); 117 | 118 | verify(mockForecastClient, times(2)).describeForecastExportJob(any(DescribeForecastExportJobRequest.class)); 119 | verify(mockForecastClient, times(1)).createForecastExportJob(any(CreateForecastExportJobRequest.class)); 120 | } 121 | } 122 | -------------------------------------------------------------------------------- 
/src/test/java/com/amazonaws/lambda/predictiongeneration/CreateForecastHandlerTest.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.predictiongeneration; 2 | 3 | import com.amazonaws.lambda.predictiongeneration.exception.ResourceSetupFailureException; 4 | import com.amazonaws.lambda.predictiongeneration.exception.ResourceSetupInProgressException; 5 | import com.amazonaws.services.forecast.model.CreateForecastRequest; 6 | import com.amazonaws.services.forecast.model.DescribeForecastRequest; 7 | import com.amazonaws.services.forecast.model.DescribeForecastResult; 8 | import com.amazonaws.services.forecast.model.ResourceNotFoundException; 9 | import org.junit.jupiter.api.BeforeEach; 10 | import org.junit.jupiter.api.Test; 11 | import org.mockito.invocation.InvocationOnMock; 12 | import org.mockito.stubbing.Answer; 13 | 14 | import static com.amazonaws.lambda.predictiongeneration.GenerateForecastResourcesIdsHandler.buildResourceIdMap; 15 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.RESOURCE_ACTIVE_STATUS; 16 | import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.RESOURCE_FAILED_STATUS; 17 | import static org.junit.jupiter.api.Assertions.assertThrows; 18 | import static org.mockito.ArgumentMatchers.any; 19 | import static org.mockito.Mockito.never; 20 | import static org.mockito.Mockito.times; 21 | import static org.mockito.Mockito.verify; 22 | import static org.mockito.Mockito.when; 23 | 24 | public class CreateForecastHandlerTest extends BaseTest { 25 | 26 | CreateForecastHandler handler; 27 | 28 | @BeforeEach 29 | public void setup() { 30 | handler = new CreateForecastHandler(mockForecastClient); 31 | } 32 | 33 | @Test 34 | public void testProcess_withActiveStatus() { 35 | DescribeForecastResult dummyDescribeForecastResult = new DescribeForecastResult().withStatus(RESOURCE_ACTIVE_STATUS); 36 | 
when(mockForecastClient.describeForecast(any(DescribeForecastRequest.class))).thenReturn(dummyDescribeForecastResult); 37 | 38 | handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE)); 39 | 40 | verify(mockForecastClient, times(1)).describeForecast(any(DescribeForecastRequest.class)); 41 | verify(mockForecastClient, never()).createForecast(any(CreateForecastRequest.class)); 42 | } 43 | 44 | @Test 45 | public void testProcess_withFailedStatus() { 46 | DescribeForecastResult dummyDescribeForecastResult = new DescribeForecastResult().withStatus(RESOURCE_FAILED_STATUS); 47 | when(mockForecastClient.describeForecast(any(DescribeForecastRequest.class))).thenReturn(dummyDescribeForecastResult); 48 | 49 | assertThrows(ResourceSetupFailureException.class, 50 | () -> handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE))); 51 | 52 | verify(mockForecastClient, times(1)).describeForecast(any(DescribeForecastRequest.class)); 53 | verify(mockForecastClient, never()).createForecast(any(CreateForecastRequest.class)); 54 | } 55 | 56 | @Test 57 | public void testProcess_withCreateNewForecastThenInCreating() { 58 | DescribeForecastResult dummyDescribeForecastResult = new DescribeForecastResult().withStatus(TEST_RESOURCE_CREATING_STATUS); 59 | when(mockForecastClient.describeForecast(any(DescribeForecastRequest.class))) 60 | .thenAnswer(new Answer() { 61 | private int count = 0; 62 | 63 | @Override 64 | public DescribeForecastResult answer(InvocationOnMock invocation) { 65 | switch (count++) { 66 | case 0: 67 | throw new ResourceNotFoundException("cannot find given forecast"); 68 | case 1: 69 | return dummyDescribeForecastResult; 70 | default: 71 | throw new IllegalArgumentException(); 72 | } 73 | } 74 | }); 75 | 76 | assertThrows(ResourceSetupInProgressException.class, 77 | () -> handler.process(buildResourceIdMap(System.currentTimeMillis(), 
TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE))); 78 | 79 | verify(mockForecastClient, times(2)).describeForecast(any(DescribeForecastRequest.class)); 80 | verify(mockForecastClient, times(1)).createForecast(any(CreateForecastRequest.class)); 81 | } 82 | 83 | @Test 84 | public void testProcess_withCreateNewForecastThenInActive() { 85 | DescribeForecastResult dummyDescribeForecastResult = new DescribeForecastResult().withStatus(RESOURCE_ACTIVE_STATUS); 86 | when(mockForecastClient.describeForecast(any(DescribeForecastRequest.class))) 87 | .thenAnswer(new Answer() { 88 | private int count = 0; 89 | 90 | @Override 91 | public DescribeForecastResult answer(InvocationOnMock invocation) { 92 | switch (count++) { 93 | case 0: 94 | throw new ResourceNotFoundException("cannot find given forecast"); 95 | case 1: 96 | return dummyDescribeForecastResult; 97 | default: 98 | throw new IllegalArgumentException(); 99 | } 100 | } 101 | }); 102 | 103 | handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE)); 104 | 105 | verify(mockForecastClient, times(2)).describeForecast(any(DescribeForecastRequest.class)); 106 | verify(mockForecastClient, times(1)).createForecast(any(CreateForecastRequest.class)); 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /src/test/java/com/amazonaws/lambda/predictiongeneration/CreatePredictorHandlerTest.java: -------------------------------------------------------------------------------- 1 | package com.amazonaws.lambda.predictiongeneration; 2 | 3 | import com.amazonaws.lambda.predictiongeneration.exception.ResourceSetupFailureException; 4 | import com.amazonaws.lambda.predictiongeneration.exception.ResourceSetupInProgressException; 5 | import com.amazonaws.services.forecast.model.CreatePredictorRequest; 6 | import com.amazonaws.services.forecast.model.DescribePredictorRequest; 7 | import 
com.amazonaws.services.forecast.model.DescribePredictorResult;
import com.amazonaws.services.forecast.model.FeaturizationConfig;
import com.amazonaws.services.forecast.model.InputDataConfig;
import com.amazonaws.services.forecast.model.ResourceNotFoundException;
import org.junit.Rule;
import org.junit.contrib.java.lang.system.EnvironmentVariables;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import static com.amazonaws.lambda.predictiongeneration.CreatePredictorHandler.SECONDS_IN_A_DAY;
import static com.amazonaws.lambda.predictiongeneration.GenerateForecastResourcesIdsHandler.buildResourceIdMap;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_GROUP_NAME_PREFIX;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATA_FREQUENCY_SECONDS_MAPPING;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.PREDICTOR_NAME_PREFIX;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.RESOURCE_ACTIVE_STATUS;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.RESOURCE_FAILED_STATUS;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Unit tests for {@link CreatePredictorHandler}, driven through the mocked Forecast client
 * {@code mockForecastClient} inherited from {@code BaseTest}.
 */
public class CreatePredictorHandlerTest extends BaseTest {

    private static final int TEST_FORECAST_HORIZON_IN_DAYS = 3;

    // NOTE(review): @Rule is a JUnit 4 mechanism, but this class runs on JUnit Jupiter
    // (@BeforeEach / jupiter @Test), which does not apply JUnit 4 rules on its own. The rule's
    // before/after hooks therefore presumably never run (e.g. the variable set in setup() is not
    // restored after the test). Confirm against the build's test dependencies.
    @Rule
    public final EnvironmentVariables environmentVariables = new EnvironmentVariables();

    CreatePredictorHandler handler;

    /**
     * Sets FORECAST_HORIZON_IN_DAYS and then builds the handler. The variable is set before
     * construction, presumably because the handler reads it in its constructor (TODO confirm).
     */
    @BeforeEach
    public void setup() {
        environmentVariables.set("FORECAST_HORIZON_IN_DAYS", String.valueOf(TEST_FORECAST_HORIZON_IN_DAYS));
        handler = new CreatePredictorHandler(mockForecastClient);
    }

    /** An already-ACTIVE predictor is reused: one describe call, no create call. */
    @Test
    public void testProcess_withActiveStatus() {
        DescribePredictorResult dummyDescribePredictorResult = new DescribePredictorResult().withStatus(RESOURCE_ACTIVE_STATUS);
        when(mockForecastClient.describePredictor(any(DescribePredictorRequest.class))).thenReturn(dummyDescribePredictorResult);

        handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE));

        verify(mockForecastClient, times(1)).describePredictor(any(DescribePredictorRequest.class));
        verify(mockForecastClient, never()).createPredictor(any(CreatePredictorRequest.class));
    }

    /** A FAILED predictor surfaces as ResourceSetupFailureException and no create call is made. */
    @Test
    public void testProcess_withFailedStatus() {
        DescribePredictorResult dummyDescribePredictorResult = new DescribePredictorResult().withStatus(RESOURCE_FAILED_STATUS);
        when(mockForecastClient.describePredictor(any(DescribePredictorRequest.class))).thenReturn(dummyDescribePredictorResult);

        assertThrows(ResourceSetupFailureException.class,
                () -> handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE)));

        verify(mockForecastClient, times(1)).describePredictor(any(DescribePredictorRequest.class));
        verify(mockForecastClient, never()).createPredictor(any(CreatePredictorRequest.class));
    }

    /**
     * The predictor is missing on the first describe call and in the creating status on the
     * second. The handler must send exactly the expected CreatePredictorRequest and then throw
     * ResourceSetupInProgressException.
     */
    @Test
    public void testProcess_withCreateNewPredictorThenInCreating() {
        DescribePredictorResult dummyDescribePredictorResult = new DescribePredictorResult().withStatus(TEST_RESOURCE_CREATING_STATUS);
        // Stateful stub: call #1 -> not found, call #2 -> creating, any further call -> IllegalArgumentException.
        when(mockForecastClient.describePredictor(any(DescribePredictorRequest.class)))
                .thenAnswer(new Answer<DescribePredictorResult>() {
                    private int count = 0;

                    @Override
                    public DescribePredictorResult answer(InvocationOnMock invocation) {
                        switch (count++) {
                            case 0:
                                throw new ResourceNotFoundException("cannot find given predictor");
                            case 1:
                                return dummyDescribePredictorResult;
                            default:
                                throw new IllegalArgumentException();
                        }
                    }
                });
        // A single timestamp feeds both the resource-id map and the expected request, so the
        // generated dataset-group ARN and predictor name line up.
        long currentTimeMillis = System.currentTimeMillis();
        // Horizon: days * SECONDS_IN_A_DAY / seconds-per-step of the default data frequency.
        CreatePredictorRequest expectedCreatePredictorRequest = new CreatePredictorRequest()
                .withForecastHorizon(TEST_FORECAST_HORIZON_IN_DAYS * SECONDS_IN_A_DAY / DATA_FREQUENCY_SECONDS_MAPPING.get(DEFAULT_DATA_FREQUENCY_VALUE))
                .withFeaturizationConfig(new FeaturizationConfig().withForecastFrequency(DEFAULT_DATA_FREQUENCY_VALUE))
                .withInputDataConfig(new InputDataConfig().withDatasetGroupArn(TEST_FORECAST_RESOURCE_ARN + "dataset-group/" + DATASET_GROUP_NAME_PREFIX + currentTimeMillis))
                .withPredictorName(PREDICTOR_NAME_PREFIX + currentTimeMillis)
                .withPerformAutoML(true);

        assertThrows(ResourceSetupInProgressException.class,
                () -> handler.process(buildResourceIdMap(currentTimeMillis, TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE)));

        verify(mockForecastClient, times(2)).describePredictor(any(DescribePredictorRequest.class));
        verify(mockForecastClient, times(1)).createPredictor(eq(expectedCreatePredictorRequest));
    }

    /**
     * The predictor is missing on the first describe call and ACTIVE on the second: it is
     * created once and process() returns normally.
     */
    @Test
    public void testProcess_withCreateNewPredictorThenInActive() {
        DescribePredictorResult dummyDescribePredictorResult = new DescribePredictorResult().withStatus(RESOURCE_ACTIVE_STATUS);
        // Stateful stub: call #1 -> not found, call #2 -> ACTIVE, any further call -> IllegalArgumentException.
        when(mockForecastClient.describePredictor(any(DescribePredictorRequest.class)))
                .thenAnswer(new Answer<DescribePredictorResult>() {
                    private int count = 0;

                    @Override
                    public DescribePredictorResult answer(InvocationOnMock invocation) {
                        switch (count++) {
                            case 0:
                                throw new ResourceNotFoundException("cannot find given predictor");
                            case 1:
                                return dummyDescribePredictorResult;
                            default:
                                throw new IllegalArgumentException();
                        }
                    }
                });

        handler.process(buildResourceIdMap(System.currentTimeMillis(), TEST_FORECAST_RESOURCE_ARN, DEFAULT_DATA_FREQUENCY_VALUE));

        verify(mockForecastClient, times(2)).describePredictor(any(DescribePredictorRequest.class));
        verify(mockForecastClient, times(1)).createPredictor(any(CreatePredictorRequest.class));
    }
}

--------------------------------------------------------------------------------
/src/test/java/com/amazonaws/lambda/predictiongeneration/DeleteOutdatedDatasetGroupsHandlerTest.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration;

import com.amazonaws.lambda.predictiongeneration.exception.ResourceCleanupInProgressException;
import com.amazonaws.services.forecast.model.DatasetGroupSummary;
import com.amazonaws.services.forecast.model.DeleteDatasetGroupRequest;
import com.amazonaws.services.forecast.model.ListDatasetGroupsRequest;
import com.amazonaws.services.forecast.model.ListDatasetGroupsResult;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_GROUP_ARN_KEY;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Unit tests for {@link DeleteOutdatedDatasetGroupsHandler}. The dataset group referenced by
 * DATASET_GROUP_ARN_KEY in {@code testResourceIdMap} is preserved; other listed groups are
 * expected to be deleted.
 */
public class DeleteOutdatedDatasetGroupsHandlerTest extends BaseTest {

    DeleteOutdatedDatasetGroupsHandler handler;

    @BeforeEach
    public void setup() {
        handler = new DeleteOutdatedDatasetGroupsHandler(mockForecastClient);
    }

    /** The listing contains no dataset group at all: IllegalStateException, nothing deleted. */
    @Test
    public void testProcess_withNoDatasetGroup() {
        ListDatasetGroupsResult dummyListDatasetGroupsResult = new ListDatasetGroupsResult().withDatasetGroups(Collections.emptyList());
        when(mockForecastClient.listDatasetGroups(any(ListDatasetGroupsRequest.class))).thenReturn(dummyListDatasetGroupsResult);

        assertThrows(IllegalStateException.class,
                () -> handler.process(testResourceIdMap));

        verify(mockForecastClient, times(1)).listDatasetGroups(any(ListDatasetGroupsRequest.class));
        verify(mockForecastClient, never()).deleteDatasetGroup(any(DeleteDatasetGroupRequest.class));
    }

    /** Only the current dataset group exists: one list call, no delete call. */
    @Test
    public void testProcess_withNoOutdatedDatasetGroup() {
        ListDatasetGroupsResult dummyListDatasetGroupsResult = new ListDatasetGroupsResult()
                .withDatasetGroups(Collections.singletonList(new DatasetGroupSummary().withDatasetGroupArn(testResourceIdMap.get(DATASET_GROUP_ARN_KEY))));
        when(mockForecastClient.listDatasetGroups(any(ListDatasetGroupsRequest.class))).thenReturn(dummyListDatasetGroupsResult);

        handler.process(testResourceIdMap);

        verify(mockForecastClient, times(1)).listDatasetGroups(any(ListDatasetGroupsRequest.class));
        verify(mockForecastClient, never()).deleteDatasetGroup(any(DeleteDatasetGroupRequest.class));
    }

    /**
     * The stub returns the outdated groups on every listing. Each outdated group is deleted once,
     * the handler lists a second time, still sees them, and throws
     * ResourceCleanupInProgressException.
     */
    @Test
    public void testProcess_withUnableToDeleteDatasetGroups() {
        List<DatasetGroupSummary> dummyOutdatedDatasetGroups = new ArrayList<>();
        dummyOutdatedDatasetGroups.add(new DatasetGroupSummary().withDatasetGroupArn("dummy1"));
        dummyOutdatedDatasetGroups.add(new DatasetGroupSummary().withDatasetGroupArn("dummy2"));
        List<DatasetGroupSummary> dummyExistingDatasetGroups = new ArrayList<>();
        dummyExistingDatasetGroups.add(new DatasetGroupSummary().withDatasetGroupArn(testResourceIdMap.get(DATASET_GROUP_ARN_KEY)));
        dummyExistingDatasetGroups.addAll(dummyOutdatedDatasetGroups);
        ListDatasetGroupsResult dummyListDatasetGroupsResult = new ListDatasetGroupsResult().withDatasetGroups(dummyExistingDatasetGroups);
        when(mockForecastClient.listDatasetGroups(any(ListDatasetGroupsRequest.class))).thenReturn(dummyListDatasetGroupsResult);

        assertThrows(ResourceCleanupInProgressException.class,
                () -> handler.process(testResourceIdMap));

        verify(mockForecastClient, times(2)).listDatasetGroups(any(ListDatasetGroupsRequest.class));
        verify(mockForecastClient, times(dummyOutdatedDatasetGroups.size())).deleteDatasetGroup(any(DeleteDatasetGroupRequest.class));
    }

    /**
     * The first listing includes the outdated groups and the second lists only the preserved one.
     * Each outdated group is deleted once and process() returns normally.
     */
    @Test
    public void testProcess_withAbleToDeleteDatasetGroups() {
        DatasetGroupSummary testPreservedDatasetGroup = new DatasetGroupSummary()
                .withDatasetGroupArn(testResourceIdMap.get(DATASET_GROUP_ARN_KEY));
        List<DatasetGroupSummary> dummyOutdatedDatasetGroups = new ArrayList<>();
        dummyOutdatedDatasetGroups.add(new DatasetGroupSummary().withDatasetGroupArn("dummy1"));
        dummyOutdatedDatasetGroups.add(new DatasetGroupSummary().withDatasetGroupArn("dummy2"));
        // Stateful stub: call #1 -> preserved + outdated, call #2 -> preserved only,
        // any further call -> IllegalArgumentException.
        when(mockForecastClient.listDatasetGroups(any(ListDatasetGroupsRequest.class)))
                .thenAnswer(new Answer<ListDatasetGroupsResult>() {
                    private int count = 0;

                    @Override
                    public ListDatasetGroupsResult answer(InvocationOnMock invocation) {
                        switch (count++) {
                            case 0:
                                List<DatasetGroupSummary> withOutdatedDatasets = new ArrayList<>();
                                withOutdatedDatasets.add(testPreservedDatasetGroup);
                                withOutdatedDatasets.addAll(dummyOutdatedDatasetGroups);
                                return new ListDatasetGroupsResult().withDatasetGroups(withOutdatedDatasets);
                            case 1:
                                return new ListDatasetGroupsResult()
                                        .withDatasetGroups(Collections.singletonList(testPreservedDatasetGroup));
                            default:
                                throw new IllegalArgumentException();
                        }
                    }
                });

        handler.process(testResourceIdMap);

        verify(mockForecastClient, times(2)).listDatasetGroups(any(ListDatasetGroupsRequest.class));
        verify(mockForecastClient, times(dummyOutdatedDatasetGroups.size())).deleteDatasetGroup(any(DeleteDatasetGroupRequest.class));
    }
}

--------------------------------------------------------------------------------
/src/test/java/com/amazonaws/lambda/predictiongeneration/DeleteOutdatedDatasetImportJobsHandlerTest.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration;

import com.amazonaws.services.forecast.model.DatasetImportJobSummary;
import com.amazonaws.services.forecast.model.DatasetSummary;
import com.amazonaws.services.forecast.model.DeleteDatasetImportJobRequest;
import com.amazonaws.services.forecast.model.Filter;
import com.amazonaws.services.forecast.model.FilterConditionString;
import com.amazonaws.services.forecast.model.ListDatasetImportJobsRequest;
import com.amazonaws.services.forecast.model.ListDatasetImportJobsResult;
import com.amazonaws.services.forecast.model.ListDatasetsRequest;
import com.amazonaws.services.forecast.model.ListDatasetsResult;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.util.Collections;

import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_ARN_KEY;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static
org.mockito.Mockito.when;

/**
 * Unit tests for {@link DeleteOutdatedDatasetImportJobsHandler}. Import jobs are expected to be
 * listed, with a DatasetArn filter, only for datasets other than the one referenced by
 * DATASET_ARN_KEY in {@code testResourceIdMap}.
 */
public class DeleteOutdatedDatasetImportJobsHandlerTest extends BaseTest {

    DeleteOutdatedDatasetImportJobsHandler handler;

    @BeforeEach
    public void setup() {
        handler = new DeleteOutdatedDatasetImportJobsHandler(mockForecastClient);
    }

    /** Only the current dataset exists: no import-job listing and no delete call. */
    @Test
    public void testProcess_withNoOutdatedDatasets() {
        ListDatasetsResult dummyListDatasetsResult = new ListDatasetsResult()
                .withDatasets(Collections.singletonList(new DatasetSummary().withDatasetArn(testResourceIdMap.get(DATASET_ARN_KEY))));
        when(mockForecastClient.listDatasets(any(ListDatasetsRequest.class))).thenReturn(dummyListDatasetsResult);

        handler.process(testResourceIdMap);

        verify(mockForecastClient, times(1)).listDatasets(any(ListDatasetsRequest.class));
        verify(mockForecastClient, never()).listDatasetImportJobs(any(ListDatasetImportJobsRequest.class));
        verify(mockForecastClient, never()).deleteDatasetImportJob(any(DeleteDatasetImportJobRequest.class));
    }

    /** The outdated dataset has no import jobs: they are listed once and nothing is deleted. */
    @Test
    public void testProcess_withNoOutdatedDatasetImportJobs() {
        String dummyOutdatedDatasetArn = "dummyOutdatedDatasetArn";
        ListDatasetsResult dummyListDatasetsResult = new ListDatasetsResult()
                .withDatasets(Lists.newArrayList(new DatasetSummary().withDatasetArn(testResourceIdMap.get(DATASET_ARN_KEY)),
                        new DatasetSummary().withDatasetArn(dummyOutdatedDatasetArn)));
        when(mockForecastClient.listDatasets(any(ListDatasetsRequest.class))).thenReturn(dummyListDatasetsResult);
        when(mockForecastClient.listDatasetImportJobs(eq(listImportJobsRequestFor(dummyOutdatedDatasetArn))))
                .thenReturn(new ListDatasetImportJobsResult().withDatasetImportJobs());

        handler.process(testResourceIdMap);

        verify(mockForecastClient, times(1)).listDatasets(any(ListDatasetsRequest.class));
        verify(mockForecastClient, times(1)).listDatasetImportJobs(any(ListDatasetImportJobsRequest.class));
        verify(mockForecastClient, never()).deleteDatasetImportJob(any(DeleteDatasetImportJobRequest.class));
    }

    /**
     * The outdated dataset has one import job on the first listing and none on the second. The
     * job is deleted once, and both datasets and import jobs are listed twice.
     */
    @Test
    public void testProcess_withAbleToDeleteDatasetImportJobs() {
        String dummyOutdatedDatasetArn = "dummyOutdatedDatasetArn";
        when(mockForecastClient.listDatasets(any(ListDatasetsRequest.class))).thenReturn(new ListDatasetsResult().withDatasets(
                Lists.newArrayList(new DatasetSummary().withDatasetArn(dummyOutdatedDatasetArn),
                        new DatasetSummary().withDatasetArn(testResourceIdMap.get(DATASET_ARN_KEY)))
        ));

        // Stateful stub: call #1 -> one job, call #2 -> no jobs, any further call -> IllegalArgumentException.
        when(mockForecastClient.listDatasetImportJobs(eq(listImportJobsRequestFor(dummyOutdatedDatasetArn))))
                .thenAnswer(new Answer<ListDatasetImportJobsResult>() {

                    private int count = 0;

                    @Override
                    public ListDatasetImportJobsResult answer(InvocationOnMock invocation) {
                        switch (count++) {
                            case 0:
                                return new ListDatasetImportJobsResult()
                                        .withDatasetImportJobs(new DatasetImportJobSummary().withDatasetImportJobArn("dummyDijArn"));
                            case 1:
                                return new ListDatasetImportJobsResult().withDatasetImportJobs();
                            default:
                                throw new IllegalArgumentException();
                        }
                    }
                });

        handler.process(testResourceIdMap);

        verify(mockForecastClient, times(2)).listDatasets(any(ListDatasetsRequest.class));
        verify(mockForecastClient, times(2)).listDatasetImportJobs(any(ListDatasetImportJobsRequest.class));
        verify(mockForecastClient, times(1)).deleteDatasetImportJob(any(DeleteDatasetImportJobRequest.class));
    }

    /**
     * Builds the ListDatasetImportJobsRequest the stubs above match via {@code eq(...)}: a single
     * "DatasetArn IS datasetArn" filter.
     */
    private static ListDatasetImportJobsRequest listImportJobsRequestFor(String datasetArn) {
        return new ListDatasetImportJobsRequest()
                .withFilters(
                        new Filter()
                                .withKey("DatasetArn")
                                .withValue(datasetArn)
                                .withCondition(FilterConditionString.IS));
    }
}

--------------------------------------------------------------------------------
/src/test/java/com/amazonaws/lambda/predictiongeneration/DeleteOutdatedDatasetsHandlerTest.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration;

import com.amazonaws.lambda.predictiongeneration.exception.ResourceCleanupInProgressException;
import com.amazonaws.services.forecast.model.DatasetSummary;
import com.amazonaws.services.forecast.model.DeleteDatasetRequest;
import com.amazonaws.services.forecast.model.ListDatasetsRequest;
import com.amazonaws.services.forecast.model.ListDatasetsResult;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.DATASET_ARN_KEY;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Unit tests for {@link DeleteOutdatedDatasetsHandler}. The dataset referenced by DATASET_ARN_KEY
 * in {@code testResourceIdMap} is preserved; other listed datasets are expected to be deleted.
 */
public class DeleteOutdatedDatasetsHandlerTest extends BaseTest {

    DeleteOutdatedDatasetsHandler handler;

    @BeforeEach
    public void setup() {
        handler = new DeleteOutdatedDatasetsHandler(mockForecastClient);
    }

    /** Only the current dataset exists: one list call, no delete call. */
    @Test
    public void testProcess_withNoOutdatedDataset() {
        ListDatasetsResult dummyListDatasetsResult = new ListDatasetsResult()
                .withDatasets(Collections.singletonList(new DatasetSummary().withDatasetArn(testResourceIdMap.get(DATASET_ARN_KEY))));
        when(mockForecastClient.listDatasets(any(ListDatasetsRequest.class))).thenReturn(dummyListDatasetsResult);

        handler.process(testResourceIdMap);

        verify(mockForecastClient, times(1)).listDatasets(any(ListDatasetsRequest.class));
        verify(mockForecastClient, never()).deleteDataset(any(DeleteDatasetRequest.class));
    }

    /**
     * The stub returns the outdated datasets on every listing. Each one is deleted once, the
     * handler lists a second time, still sees them, and throws ResourceCleanupInProgressException.
     */
    @Test
    public void testProcess_withUnableToDeleteDatasets() {
        List<DatasetSummary> dummyOutdatedDatasets = new ArrayList<>();
        dummyOutdatedDatasets.add(new DatasetSummary().withDatasetArn("dummy1"));
        dummyOutdatedDatasets.add(new DatasetSummary().withDatasetArn("dummy2"));
        List<DatasetSummary> dummyExistingDatasets = new ArrayList<>();
        dummyExistingDatasets.add(new DatasetSummary().withDatasetArn(testResourceIdMap.get(DATASET_ARN_KEY)));
        dummyExistingDatasets.addAll(dummyOutdatedDatasets);
        ListDatasetsResult dummyListDatasetsResult = new ListDatasetsResult().withDatasets(dummyExistingDatasets);
        when(mockForecastClient.listDatasets(any(ListDatasetsRequest.class))).thenReturn(dummyListDatasetsResult);

        assertThrows(ResourceCleanupInProgressException.class, () -> handler.process(testResourceIdMap));

        verify(mockForecastClient, times(2)).listDatasets(any(ListDatasetsRequest.class));
        verify(mockForecastClient, times(dummyOutdatedDatasets.size())).deleteDataset(any(DeleteDatasetRequest.class));
    }

    /**
     * The first listing includes the outdated datasets and the second lists only the preserved
     * one. Each outdated dataset is deleted once and process() returns normally.
     */
    @Test
    public void testProcess_withAbleToDeleteDatasets() {
        DatasetSummary testPreservedDataset = new DatasetSummary().withDatasetArn(testResourceIdMap.get(DATASET_ARN_KEY));
        List<DatasetSummary> dummyOutdatedDatasets = new ArrayList<>();
        dummyOutdatedDatasets.add(new DatasetSummary().withDatasetArn("dummy1"));
        dummyOutdatedDatasets.add(new DatasetSummary().withDatasetArn("dummy2"));
        // Stateful stub: call #1 -> preserved + outdated, call #2 -> preserved only,
        // any further call -> IllegalArgumentException.
        when(mockForecastClient.listDatasets(any(ListDatasetsRequest.class)))
                .thenAnswer(new Answer<ListDatasetsResult>() {
                    private int count = 0;

                    @Override
                    public ListDatasetsResult answer(InvocationOnMock invocation) {
                        switch (count++) {
                            case 0:
                                List<DatasetSummary> withOutdatedDatasets = new ArrayList<>();
                                withOutdatedDatasets.add(testPreservedDataset);
                                withOutdatedDatasets.addAll(dummyOutdatedDatasets);
                                return new ListDatasetsResult().withDatasets(withOutdatedDatasets);
                            case 1:
                                return new ListDatasetsResult().withDatasets(Collections.singletonList(testPreservedDataset));
                            default:
                                throw new IllegalArgumentException();
                        }
                    }
                });

        handler.process(testResourceIdMap);

        verify(mockForecastClient, times(2)).listDatasets(any(ListDatasetsRequest.class));
        verify(mockForecastClient, times(dummyOutdatedDatasets.size())).deleteDataset(any(DeleteDatasetRequest.class));
    }
}

--------------------------------------------------------------------------------
/src/test/java/com/amazonaws/lambda/predictiongeneration/DeleteOutdatedPredictorsHandlerTest.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration;

import com.amazonaws.lambda.predictiongeneration.exception.ResourceCleanupInProgressException;
import com.amazonaws.services.forecast.model.DeletePredictorRequest;
import com.amazonaws.services.forecast.model.ListPredictorsRequest;
import com.amazonaws.services.forecast.model.ListPredictorsResult;
import com.amazonaws.services.forecast.model.PredictorSummary;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.PREDICTOR_ARN_KEY;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static
org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Unit tests for {@link DeleteOutdatedPredictorsHandler}. The predictor referenced by
 * PREDICTOR_ARN_KEY in {@code testResourceIdMap} is preserved; other listed predictors are
 * expected to be deleted.
 */
public class DeleteOutdatedPredictorsHandlerTest extends BaseTest {

    DeleteOutdatedPredictorsHandler handler;

    @BeforeEach
    public void setup() {
        handler = new DeleteOutdatedPredictorsHandler(mockForecastClient);
    }

    /** Only the current predictor exists: one list call, no delete call. */
    @Test
    public void testProcess_withNoOutdatedPredictor() {
        ListPredictorsResult dummyListPredictorsResult = new ListPredictorsResult()
                .withPredictors(Collections.singletonList(new PredictorSummary().withPredictorArn(testResourceIdMap.get(PREDICTOR_ARN_KEY))));
        when(mockForecastClient.listPredictors(any(ListPredictorsRequest.class))).thenReturn(dummyListPredictorsResult);

        handler.process(testResourceIdMap);

        verify(mockForecastClient, times(1)).listPredictors(any(ListPredictorsRequest.class));
        verify(mockForecastClient, never()).deletePredictor(any(DeletePredictorRequest.class));
    }

    /**
     * The stub returns the outdated predictors on every listing. Each one is deleted once, the
     * handler lists a second time, still sees them, and throws ResourceCleanupInProgressException.
     */
    @Test
    public void testProcess_withUnableToDeletePredictor() {
        List<PredictorSummary> dummyOutdatedPredictors = new ArrayList<>();
        dummyOutdatedPredictors.add(new PredictorSummary().withPredictorArn("dummy1"));
        // Distinct ARN (was a duplicated "dummy1"), so the expected delete count is not masked
        // by, or sensitive to, any de-duplication in the handler.
        dummyOutdatedPredictors.add(new PredictorSummary().withPredictorArn("dummy2"));
        List<PredictorSummary> dummyExistingPredictors = new ArrayList<>();
        dummyExistingPredictors.add(new PredictorSummary().withPredictorArn(testResourceIdMap.get(PREDICTOR_ARN_KEY)));
        dummyExistingPredictors.addAll(dummyOutdatedPredictors);
        ListPredictorsResult dummyListPredictorsResult = new ListPredictorsResult().withPredictors(dummyExistingPredictors);
        when(mockForecastClient.listPredictors(any(ListPredictorsRequest.class))).thenReturn(dummyListPredictorsResult);

        assertThrows(ResourceCleanupInProgressException.class, () -> handler.process(testResourceIdMap));

        verify(mockForecastClient, times(2)).listPredictors(any(ListPredictorsRequest.class));
        verify(mockForecastClient, times(dummyOutdatedPredictors.size())).deletePredictor(any(DeletePredictorRequest.class));
    }

    /**
     * The first listing includes the outdated predictors and the second lists only the preserved
     * one. Each outdated predictor is deleted once and process() returns normally.
     */
    @Test
    public void testProcess_withAbleToDeletePredictors() {
        PredictorSummary testPreservedPredictor = new PredictorSummary().withPredictorArn(testResourceIdMap.get(PREDICTOR_ARN_KEY));

        List<PredictorSummary> dummyOutdatedPredictors = new ArrayList<>();
        dummyOutdatedPredictors.add(new PredictorSummary().withPredictorArn("dummy1"));
        dummyOutdatedPredictors.add(new PredictorSummary().withPredictorArn("dummy2"));

        // Stateful stub: call #1 -> preserved + outdated, call #2 -> preserved only,
        // any further call -> IllegalArgumentException.
        when(mockForecastClient.listPredictors(any(ListPredictorsRequest.class)))
                .thenAnswer(new Answer<ListPredictorsResult>() {
                    private int count = 0;

                    @Override
                    public ListPredictorsResult answer(InvocationOnMock invocation) {
                        switch (count++) {
                            case 0:
                                List<PredictorSummary> withOutdatedPredictors = new ArrayList<>();
                                withOutdatedPredictors.add(testPreservedPredictor);
                                withOutdatedPredictors.addAll(dummyOutdatedPredictors);
                                return new ListPredictorsResult().withPredictors(withOutdatedPredictors);
                            case 1:
                                return new ListPredictorsResult()
                                        .withPredictors(Collections.singletonList(testPreservedPredictor));
                            default:
                                throw new IllegalArgumentException();
                        }
                    }
                });

        handler.process(testResourceIdMap);

        verify(mockForecastClient, times(2)).listPredictors(any(ListPredictorsRequest.class));
        verify(mockForecastClient, times(dummyOutdatedPredictors.size())).deletePredictor(any(DeletePredictorRequest.class));
    }
}

--------------------------------------------------------------------------------
/src/test/java/com/amazonaws/lambda/predictiongeneration/GenerateForecastResourcesIdsCronHandlerTest.java:
--------------------------------------------------------------------------------
package
com.amazonaws.lambda.predictiongeneration;

import com.amazonaws.services.forecast.model.DatasetSummary;
import com.amazonaws.services.forecast.model.ListDatasetsRequest;
import com.amazonaws.services.forecast.model.ListDatasetsResult;
import com.amazonaws.services.forecast.model.ListPredictorsRequest;
import com.amazonaws.services.forecast.model.ListPredictorsResult;
import com.amazonaws.services.forecast.model.PredictorSummary;
import com.amazonaws.services.lambda.runtime.Context;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.time.Clock;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.util.Date;
import java.util.Map;

import static com.amazonaws.lambda.predictiongeneration.GenerateForecastResourcesIdsCronHandler.buildCronResourceIdMap;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/** Unit tests for {@link GenerateForecastResourcesIdsCronHandler}. */
public class GenerateForecastResourcesIdsCronHandlerTest extends BaseTest {

    private Clock fixedClock;
    private GenerateForecastResourcesIdsCronHandler handler;

    /**
     * Pins the handler's clock to 2019-01-01T01:01Z so the expected resource-id map can be built
     * from the same instant ({@code fixedClock.millis()}).
     */
    @BeforeEach
    public void setup() {
        fixedClock = Clock.fixed(LocalDateTime.of(2019, 1, 1, 1, 1)
                .toInstant(ZoneOffset.UTC), ZoneId.of("UTC"));
        handler = new GenerateForecastResourcesIdsCronHandler(fixedClock, mockForecastClient);
    }

    /** No dataset is listed: handleRequest throws IllegalStateException. */
    @Test
    public void testHandleRequest_WithNoDataset() {
        when(mockForecastClient.listDatasets(any(ListDatasetsRequest.class))).thenReturn(new ListDatasetsResult().withDatasets());
        assertThrows(IllegalStateException.class, () -> handler.handleRequest(null, mock(Context.class)));
    }

    /**
     * handleRequest returns a JSON string that deserializes to the map buildCronResourceIdMap
     * builds from the fixed time, TEST_FORECAST_RESOURCE_ARN, the listed dataset name and the
     * listed predictor ARN.
     */
    @Test
    public void testHandleRequest() throws Exception {
        long currentTime = fixedClock.millis();

        // Presumably the handler derives its Forecast ARN prefix from the invoked function ARN.
        Context mockContext = mock(Context.class);
        when(mockContext.getInvokedFunctionArn()).thenReturn(TEST_FUNCTION_ARN);
        String dummyDatasetName = "dummyDatasetName";
        when(mockForecastClient.listDatasets(any(ListDatasetsRequest.class)))
                .thenReturn(new ListDatasetsResult()
                        .withDatasets(new DatasetSummary().withDatasetName(dummyDatasetName).withCreationTime(new Date())));
        String dummyPredictorArn = "dummyPredictorArn";
        when(mockForecastClient.listPredictors(any(ListPredictorsRequest.class)))
                .thenReturn(new ListPredictorsResult()
                        .withPredictors(new PredictorSummary().withPredictorArn(dummyPredictorArn).withCreationTime(new Date())));
        // NOTE(review): assumes the resource-id map is String -> String, as testResourceIdMap is
        // used elsewhere; confirm against buildCronResourceIdMap's declared return type.
        Map<String, String> expectedResourceIdMap = buildCronResourceIdMap(currentTime,
                TEST_FORECAST_RESOURCE_ARN, dummyDatasetName, dummyPredictorArn);

        String mapString = handler.handleRequest(null, mockContext);
        // Anonymous subclass so Jackson can capture the full generic target type.
        Map<String, String> actualResourceIdMap = new ObjectMapper().readValue(mapString, new TypeReference<Map<String, String>>() {
        });
        assertEquals(expectedResourceIdMap, actualResourceIdMap);
    }
}

--------------------------------------------------------------------------------
/src/test/java/com/amazonaws/lambda/predictiongeneration/GenerateForecastResourcesIdsHandlerTest.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.predictiongeneration;

import com.amazonaws.lambda.predictiongeneration.exception.ResourceSetupInProgressException;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.AmazonS3Exception;
import com.amazonaws.services.s3.model.GetObjectMetadataRequest;
import
com.amazonaws.services.s3.model.ObjectMetadata;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.time.Clock;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.util.Map;

import static com.amazonaws.lambda.predictiongeneration.GenerateForecastResourcesIdsHandler.buildResourceIdMap;
import static com.amazonaws.lambda.predictiongeneration.PredictionGenerationUtils.ONE_HOUR_DATA_FREQUENCY_STRING;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/** Unit tests for {@link GenerateForecastResourcesIdsHandler}, with a mocked S3 client. */
public class GenerateForecastResourcesIdsHandlerTest extends BaseTest {

    private Clock fixedClock;
    private AmazonS3 mockS3Client;
    private GenerateForecastResourcesIdsHandler handler;

    /**
     * Pins the handler's clock to 2019-01-01T01:01Z and injects a mocked S3 client, so the
     * expected resource-id map can be built from the same instant.
     */
    @BeforeEach
    public void setup() {
        fixedClock = Clock.fixed(LocalDateTime.of(2019, 1, 1, 1, 1)
                .toInstant(ZoneOffset.UTC), ZoneId.of("UTC"));
        mockS3Client = mock(AmazonS3.class);
        handler = new GenerateForecastResourcesIdsHandler(fixedClock, mockS3Client);
    }

    /**
     * getObjectMetadata throws AmazonS3Exception (the source object is missing), and
     * handleRequest reports ResourceSetupInProgressException.
     */
    @Test
    public void testHandleRequest_WithNoSourceFileExist() {
        when(mockS3Client.getObjectMetadata(any(GetObjectMetadataRequest.class)))
                .thenThrow(AmazonS3Exception.class);

        assertThrows(ResourceSetupInProgressException.class,
                () -> handler.handleRequest(null, mock(Context.class)));
    }

    /**
     * A source object with content length 100 is found. The returned JSON string deserializes to
     * buildResourceIdMap(fixed time, TEST_FORECAST_RESOURCE_ARN, ONE_HOUR_DATA_FREQUENCY_STRING).
     */
    @Test
    public void testHandleRequest() throws Exception {
        long currentTime = fixedClock.millis();
        // NOTE(review): assumes buildResourceIdMap returns Map<String, String>, consistent with how
        // testResourceIdMap values are passed as Strings elsewhere; confirm against its signature.
        Map<String, String> expectedResourceIdMap = buildResourceIdMap(currentTime, TEST_FORECAST_RESOURCE_ARN,
                ONE_HOUR_DATA_FREQUENCY_STRING);

        Context mockContext = mock(Context.class);
        when(mockContext.getInvokedFunctionArn()).thenReturn(TEST_FUNCTION_ARN);
        ObjectMetadata dummyObjectMetadata = new ObjectMetadata();
        dummyObjectMetadata.setContentLength(100);
        when(mockS3Client.getObjectMetadata(any(GetObjectMetadataRequest.class))).thenReturn(dummyObjectMetadata);

        String mapString = handler.handleRequest(null, mockContext);
        // Anonymous subclass so Jackson can capture the full generic target type.
        Map<String, String> actualResourceIdMap = new ObjectMapper().readValue(mapString, new TypeReference<Map<String, String>>() {
        });
        assertEquals(expectedResourceIdMap, actualResourceIdMap);
    }
}

--------------------------------------------------------------------------------
/src/test/java/com/amazonaws/lambda/queryingpredictionresult/LoadDataFromS3ToDynamoDBHandlerTest.java:
--------------------------------------------------------------------------------
package com.amazonaws.lambda.queryingpredictionresult;

import com.amazonaws.services.dynamodbv2.AmazonDynamoDB;
import com.amazonaws.services.dynamodbv2.local.embedded.DynamoDBEmbedded;
import com.amazonaws.services.dynamodbv2.model.AttributeDefinition;
import com.amazonaws.services.dynamodbv2.model.AttributeValue;
import com.amazonaws.services.dynamodbv2.model.CreateTableRequest;
import com.amazonaws.services.dynamodbv2.model.GetItemRequest;
import com.amazonaws.services.dynamodbv2.model.GetItemResult;
import com.amazonaws.services.dynamodbv2.model.KeySchemaElement;
import com.amazonaws.services.dynamodbv2.model.KeyType;
import com.amazonaws.services.dynamodbv2.model.ProvisionedThroughput;
import com.amazonaws.services.dynamodbv2.model.ScalarAttributeType;
import com.amazonaws.services.dynamodbv2.model.ScanRequest;
import com.amazonaws.services.dynamodbv2.model.ScanResult;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.events.S3Event;
import
import com.amazonaws.services.s3.AmazonS3; 19 | import com.amazonaws.services.s3.event.S3EventNotification; 20 | import com.amazonaws.services.s3.event.S3EventNotification.S3EventNotificationRecord; 21 | import com.amazonaws.services.s3.model.AmazonS3Exception; 22 | import com.amazonaws.services.s3.model.GetObjectRequest; 23 | import com.amazonaws.services.s3.model.S3Object; 24 | import com.amazonaws.services.s3.model.S3ObjectInputStream; 25 | import org.junit.jupiter.api.AfterEach; 26 | import org.junit.jupiter.api.BeforeEach; 27 | import org.junit.jupiter.api.Test; 28 | import org.junit.Rule; 29 | import org.mockito.Mock; 30 | 31 | import java.io.IOException; 32 | import java.io.InputStream; 33 | import java.io.InputStreamReader; 34 | import java.io.LineNumberReader; 35 | import java.util.ArrayList; 36 | import java.util.Collections; 37 | import java.util.HashMap; 38 | import java.util.List; 39 | import java.util.Map; 40 | 41 | import org.junit.contrib.java.lang.system.EnvironmentVariables; 42 | 43 | import static com.amazonaws.lambda.queryingpredictionresult.LoadDataFromS3ToDynamoDBHandler.DYNAMODB_PREDICTION_METADATA_LATEST_PRED_DATA_FREQ_IN_SEC_ATTR_NAME; 44 | import static com.amazonaws.lambda.queryingpredictionresult.LoadDataFromS3ToDynamoDBHandler.DYNAMODB_PREDICTION_METADATA_LATEST_PRED_UUID_ATTR_NAME; 45 | import static org.junit.jupiter.api.Assertions.assertEquals; 46 | import static org.junit.jupiter.api.Assertions.assertThrows; 47 | import static org.mockito.ArgumentMatchers.any; 48 | import static org.mockito.Mockito.mock; 49 | import static org.mockito.Mockito.when; 50 |
/**
 * Unit tests for {@link LoadDataFromS3ToDynamoDBHandler}.
 *
 * <p>S3 is a Mockito mock that serves objects from class-path test resources (the CSV files
 * under {@code tgt/}). DynamoDB is an embedded DynamoDB Local instance. The prediction and
 * prediction-metadata tables are created fresh before each test, and the instance is shut
 * down after each test.
 */
public class LoadDataFromS3ToDynamoDBHandlerTest {

    private static final String AWS_REGION = "us-east-1";
    private static final String UNIT_TEST_ROOT_CLASS_PATH = "/";
    private static final String UNIT_TEST_S3_FOLDER_NAME = "tgt";
    private static final String PREDICTION_TABLE_NAME = "LocalTestTable";
    private static final String PREDICTION_TABLE_HASH_KEY = PredictionResultItem.Attribute.ITEM_ID;
    private static final String PREDICTION_TABLE_RANGE_KEY = PredictionResultItem.Attribute.DATE;
    private static final String FORECAST_HORIZON_IN_DAYS = "3";
    private static final String PREDICTION_METADATA_TABLE_NAME = "PredictionMetadata";
    private static final String PREDICTION_METADATA_TABLE_HASH_KEY = "metadataKey";
    private static final String PREDICTION_METADATA_TABLE_ATTRIBUTE_NAME = "metadataValue";

    // Test CSV files are located under folder: resources/tgt/
    private static final String TEST_OBJECT_KEY_FORMAT = "%s/%s_2019-10-16T21-40-00Z_part0.csv";

    // Header-only CSV: the handler must reject it as containing no records.
    private static final String TEST_EMPTY_FORECAST_EXPORT_JOB = "empty_forecast_export_job";
    private static final String TEST_OBJECT_KEY0 =
            String.format(TEST_OBJECT_KEY_FORMAT, UNIT_TEST_S3_FOLDER_NAME, TEST_EMPTY_FORECAST_EXPORT_JOB);

    // Well-formed CSV whose rows are one hour apart.
    private static final String TEST_FORECAST_EXPORT_JOB1 = "forecast_export_job1";
    private static final String TEST_OBJECT_KEY1 =
            String.format(TEST_OBJECT_KEY_FORMAT, UNIT_TEST_S3_FOLDER_NAME, TEST_FORECAST_EXPORT_JOB1);
    private static final long TEST_PREDICTION1_DATA_FREQUENCY_IN_SECONDS = 3600L;

    // CSV with a single data row: the handler must reject it (fewer than 2 items).
    private static final String TEST_FORECAST_EXPORT_JOB_WITH_ONE_RECORD = "forecast_export_job_with_one_record";
    private static final String TEST_OBJECT_KEY2 =
            String.format(TEST_OBJECT_KEY_FORMAT, UNIT_TEST_S3_FOLDER_NAME, TEST_FORECAST_EXPORT_JOB_WITH_ONE_RECORD);

    // NOTE(review): @Rule is a JUnit 4 annotation. The JUnit Jupiter engine that runs the
    // @Test methods below does not process it unless rule-migration support is enabled.
    // Depending on the system-rules version, set() may be buffered instead of applied, or the
    // variables may never be restored after a test. Confirm which behaviour the build relies on.
    @Rule
    public final EnvironmentVariables environmentVariables = new EnvironmentVariables();

    // Lambda context passed to the handler. It is created explicitly in setup() because no
    // Mockito extension/runner is registered on this class to initialise an @Mock field.
    private Context context;

    private AmazonS3 mockS3Client;
    private AmazonDynamoDB localDdbClient;
    private LoadDataFromS3ToDynamoDBHandler handler;

    @BeforeEach
    void setup() {
        // Environment variables consumed by the handler under test.
        environmentVariables.set("AWS_REGION", AWS_REGION);
        environmentVariables.set("PREDICTION_TABLE_NAME", PREDICTION_TABLE_NAME);
        environmentVariables.set("PREDICTION_TABLE_HASH_KEY", PREDICTION_TABLE_HASH_KEY);
        environmentVariables.set("PREDICTION_TABLE_RANGE_KEY", PREDICTION_TABLE_RANGE_KEY);
        environmentVariables.set("FORECAST_HORIZON_IN_DAYS", FORECAST_HORIZON_IN_DAYS);
        environmentVariables.set("PREDICTION_METADATA_TABLE_NAME", PREDICTION_METADATA_TABLE_NAME);
        environmentVariables.set("PREDICTION_METADATA_TABLE_HASH_KEY", PREDICTION_METADATA_TABLE_HASH_KEY);
        environmentVariables.set("PREDICTION_METADATA_TABLE_ATTRIBUTE_NAME", PREDICTION_METADATA_TABLE_ATTRIBUTE_NAME);

        context = mock(Context.class);
        mockS3Client = initMockS3Client();
        localDdbClient = initLocalDynamoDB();
        handler = new LoadDataFromS3ToDynamoDBHandler(
                mockS3Client,
                localDdbClient
        );
    }

    /**
     * Drops both tables and shuts down the embedded DynamoDB Local instance used by this test,
     * so no instance leaks across tests.
     */
    @AfterEach
    void tearDown() {
        if (localDdbClient != null) {
            try {
                localDdbClient.deleteTable(PREDICTION_METADATA_TABLE_NAME);
                localDdbClient.deleteTable(PREDICTION_TABLE_NAME);
            } finally {
                localDdbClient.shutdown();
            }
        }
        localDdbClient = null;
        mockS3Client = null;
        handler = null;
        context = null;
    }

    @Test
    public void testLoadDataFromS3ToDynamoDB_WithThrowingCannotParseException() {
        String dummyS3Key = "dummy";
        RuntimeException thrown = assertThrows(RuntimeException.class,
                () -> handler.handleRequest(makeMockS3Event(dummyS3Key), context));

        assertEquals(String.format("Cannot parse prediction result object key: %s", dummyS3Key), thrown.getMessage());
    }

    @Test
    public void testLoadDataFromS3ToDynamoDB_WithEmptyPredictionResultFile() {
        RuntimeException thrown = assertThrows(RuntimeException.class,
                () -> handler.handleRequest(makeMockS3Event(TEST_OBJECT_KEY0), context));

        assertEquals(String.format("Prediction result file %s contains no record.", TEST_OBJECT_KEY0), thrown.getMessage());
    }

    @Test
    public void testLoadDataFromS3ToDynamoDB() throws IOException {
        handler.handleRequest(makeMockS3Event(TEST_OBJECT_KEY1), context);

        verifyDynamoDB(TEST_FORECAST_EXPORT_JOB1, TEST_OBJECT_KEY1, TEST_PREDICTION1_DATA_FREQUENCY_IN_SECONDS);
    }

    @Test
    public void testLoadDataFromS3ToDynamoDB_WithOneRecordPredictionResultFile() {
        RuntimeException thrown = assertThrows(RuntimeException.class,
                () -> handler.handleRequest(makeMockS3Event(TEST_OBJECT_KEY2), context));

        assertEquals(String.format("Passed in items contains %d item, which is less than 2.", 1), thrown.getMessage());
    }

    /**
     * Asserts the DynamoDB state after the handler has loaded {@code objectKey}.
     * The metadata table must hold the expected latest-prediction UUID and data frequency,
     * and the prediction table must hold one item per data row of the source CSV.
     *
     * @param forecastExportJobName          expected value of the latest-prediction-UUID entry
     * @param objectKey                      test resource (relative to the class-path root) that was loaded
     * @param expectedDataFrequencyInSeconds expected value of the latest-prediction-data-frequency entry
     * @throws IOException if the source CSV cannot be read to count its rows
     */
    private void verifyDynamoDB(final String forecastExportJobName,
                                final String objectKey,
                                final long expectedDataFrequencyInSeconds) throws IOException {

        // Verify the latestPredictionUUID in the metadata table.
        assertEquals(new AttributeValue(forecastExportJobName),
                getMetadataValue(DYNAMODB_PREDICTION_METADATA_LATEST_PRED_UUID_ATTR_NAME));

        // Verify the latestPredictionDataFrequency in the metadata table (uses its own key).
        assertEquals(new AttributeValue(String.valueOf(expectedDataFrequencyInSeconds)),
                getMetadataValue(DYNAMODB_PREDICTION_METADATA_LATEST_PRED_DATA_FREQ_IN_SEC_ATTR_NAME));

        // Verify the number of items in the prediction result table.
        ScanResult predictionTableScanResult = localDdbClient.scan(new ScanRequest()
                .withTableName(PREDICTION_TABLE_NAME));
        // The first CSV line is the column header, not a prediction item.
        long expectedItemCount = getNumberOfLines(objectKey) - 1;
        assertEquals(expectedItemCount, predictionTableScanResult.getItems().size());
    }

    /**
     * Reads one entry from the prediction metadata table.
     *
     * @param metadataKey hash-key value of the metadata entry
     * @return the entry's value attribute, or {@code null} if the entry does not exist, so that a
     *         missing entry fails the caller's assertion instead of throwing an NPE
     */
    private AttributeValue getMetadataValue(final String metadataKey) {
        Map<String, AttributeValue> key = Collections.singletonMap(
                PREDICTION_METADATA_TABLE_HASH_KEY, new AttributeValue(metadataKey));
        GetItemResult result = localDdbClient.getItem(new GetItemRequest()
                .withTableName(PREDICTION_METADATA_TABLE_NAME)
                .withKey(key));
        Map<String, AttributeValue> item = result.getItem();
        return item == null ? null : item.get(PREDICTION_METADATA_TABLE_ATTRIBUTE_NAME);
    }

    /**
     * Counts the lines of a class-path resource, including the header line.
     *
     * @param fileName resource path relative to the class-path root
     * @return number of lines in the resource
     * @throws IOException if the resource does not exist or cannot be read
     */
    private long getNumberOfLines(final String fileName) throws IOException {
        final String resourcePath = UNIT_TEST_ROOT_CLASS_PATH + fileName;
        final InputStream inputStream = getClass().getResourceAsStream(resourcePath);
        if (inputStream == null) {
            throw new IOException("Test resource not found: " + resourcePath);
        }
        try (LineNumberReader reader = new LineNumberReader(new InputStreamReader(inputStream))) {
            return reader.lines().count();
        }
    }

    /**
     * Builds an {@link S3Event} with a single record pointing at {@code objectKey} in a dummy bucket.
     * Only the bucket name and object key carry meaningful values; every other field is a
     * placeholder or a mock.
     */
    private S3Event makeMockS3Event(final String objectKey) {

        S3EventNotification.S3BucketEntity bucket = new S3EventNotification.S3BucketEntity("dummyBucket",
                mock(S3EventNotification.UserIdentityEntity.class), "dummyArn");
        S3EventNotification.S3ObjectEntity object = new S3EventNotification.S3ObjectEntity(objectKey, 1024L,
                "dummyEtag", "dummyVersionId", null/*no sequencer*/);
        S3EventNotification.S3Entity s3 = new S3EventNotification.S3Entity("dummyConfigurationId",
                bucket, object, "dummySchemaVer");

        S3EventNotificationRecord record = new S3EventNotificationRecord("dummyRegion",
                "dummyEventName", "dummyEventSrc", "2019-01-01T00:00:00.000",
                "dummyEventVer", mock(S3EventNotification.RequestParametersEntity.class),
                mock(S3EventNotification.ResponseElementsEntity.class), s3, mock(S3EventNotification.UserIdentityEntity.class),
                null/*no glacierEventData*/);

        return new S3Event(Collections.singletonList(record));
    }

    /**
     * Starts an embedded DynamoDB Local instance and creates the two tables the handler writes to:
     * <ul>
     *   <li>the prediction result table: string hash key (item id) and string range key (date);</li>
     *   <li>the prediction metadata table: single string hash key.</li>
     * </ul>
     *
     * @return a client for the new instance; the caller is responsible for shutting it down
     */
    private AmazonDynamoDB initLocalDynamoDB() {
        AmazonDynamoDB ddb = DynamoDBEmbedded.create().amazonDynamoDB();
        ProvisionedThroughput provisionedThroughput = new ProvisionedThroughput()
                .withReadCapacityUnits(200L)
                .withWriteCapacityUnits(200L);

        // Create the local prediction result table.
        List<KeySchemaElement> predictionResultTableKeys = new ArrayList<>();
        predictionResultTableKeys.add(new KeySchemaElement().withAttributeName(PREDICTION_TABLE_HASH_KEY).withKeyType(KeyType.HASH));
        predictionResultTableKeys.add(new KeySchemaElement().withAttributeName(PREDICTION_TABLE_RANGE_KEY).withKeyType(KeyType.RANGE));
        List<AttributeDefinition> predictionResultTableAttrs = new ArrayList<>();
        predictionResultTableAttrs.add(new AttributeDefinition().withAttributeName(PREDICTION_TABLE_HASH_KEY).withAttributeType(ScalarAttributeType.S));
        predictionResultTableAttrs.add(new AttributeDefinition().withAttributeName(PREDICTION_TABLE_RANGE_KEY).withAttributeType(ScalarAttributeType.S));
        ddb.createTable(new CreateTableRequest()
                .withTableName(PREDICTION_TABLE_NAME)
                .withKeySchema(predictionResultTableKeys)
                .withAttributeDefinitions(predictionResultTableAttrs)
                .withProvisionedThroughput(provisionedThroughput));

        // Create the local prediction metadata table.
        KeySchemaElement predictionMetadataTableKey = new KeySchemaElement()
                .withAttributeName(PREDICTION_METADATA_TABLE_HASH_KEY)
                .withKeyType(KeyType.HASH);
        AttributeDefinition predictionMetadataTableAttr = new AttributeDefinition()
                .withAttributeName(PREDICTION_METADATA_TABLE_HASH_KEY)
                .withAttributeType(ScalarAttributeType.S);
        ddb.createTable(new CreateTableRequest()
                .withTableName(PREDICTION_METADATA_TABLE_NAME)
                .withKeySchema(Collections.singletonList(predictionMetadataTableKey))
                .withAttributeDefinitions(Collections.singletonList(predictionMetadataTableAttr))
                .withProvisionedThroughput(provisionedThroughput));

        return ddb;
    }

    /**
     * Creates an {@link AmazonS3} mock whose {@code getObject} serves the class-path resource named
     * by the request key. Unknown keys fail with {@link AmazonS3Exception}.
     */
    private AmazonS3 initMockS3Client() {
        AmazonS3 s3 = mock(AmazonS3.class);
        when(s3.getObject(any(GetObjectRequest.class))).thenAnswer(invocationOnMock -> {
            GetObjectRequest req = invocationOnMock.getArgument(0);
            return mockS3ObjectFromLocalFile(req.getKey());
        });
        return s3;
    }

    /**
     * Wraps a class-path resource in a mocked {@link S3Object} whose content stream is the
     * resource's bytes.
     *
     * @param fileName resource path relative to the class-path root (the S3 object key)
     * @return a mocked S3 object backed by the resource
     * @throws AmazonS3Exception if no resource matches the key, mirroring a failed S3 GET
     */
    private S3Object mockS3ObjectFromLocalFile(final String fileName) {
        final InputStream inputStream = getClass().getResourceAsStream(UNIT_TEST_ROOT_CLASS_PATH + fileName);
        if (inputStream == null) {
            // Any request that cannot find a matching key should fail like a real S3 GET would.
            throw new AmazonS3Exception("Object not found or not available");
        }

        // Stream type returned from a real S3 GET response.
        S3ObjectInputStream s3ObjectInputStream = new S3ObjectInputStream(inputStream, null);
        S3Object s3Object = mock(S3Object.class);
        when(s3Object.getObjectContent()).thenReturn(s3ObjectInputStream);
        return s3Object;
    }
}
294 | -------------------------------------------------------------------------------- /src/test/resources/test_raw_demand_requests.csv: -------------------------------------------------------------------------------- 1 | item_id,timestamp,target_value 2 |
5,2020-01-01 03:50:33,14 3 | 5,2020-02-01 03:53:14,14 4 | 5,2020-03-01 03:55:18,16 5 | 5,2020-04-01 04:00:30,17 6 | 5,2020-05-01 06:26:10,14 7 | 1,2020-06-01 08:30:42,16 8 | 5,2020-07-01 12:28:58,14 9 | 7,2020-08-01 14:00:28,23 10 | 7,2020-09-01 15:15:18,23 11 | 5,2020-10-01 16:25:23,14 12 | 7,2020-11-01 16:45:05,14 13 | 5,2020-12-01 18:27:56,14 14 | -------------------------------------------------------------------------------- /src/test/resources/tgt/empty_forecast_export_job_2019-10-16T21-40-00Z_part0.csv: -------------------------------------------------------------------------------- 1 | date,item_id,mean,p10,p50,p90 -------------------------------------------------------------------------------- /src/test/resources/tgt/forecast_export_job1_2019-10-16T21-40-00Z_part0.csv: -------------------------------------------------------------------------------- 1 | date,item_id,mean,p10,p50,p90 2 | 2019-01-01T00:00:00Z,wp100,21,-8,21,49 3 | 2019-01-01T00:00:00Z,wp101,21,-8,21,49 4 | 2019-01-01T01:00:00Z,wp100,22,-7,22,52 5 | 2019-01-01T01:00:00Z,wp101,22,-7,22,52 6 | 2019-01-01T02:00:00Z,wp100,19,-8,19,49 7 | 2019-01-01T02:00:00Z,wp101,19,-8,19,49 8 | 2019-01-01T03:00:00Z,wp100,14,-15,14,44 9 | 2019-01-01T03:00:00Z,wp101,14,-15,14,44 10 | 2019-01-01T04:00:00Z,wp100,10,-19,10,39 11 | 2019-01-01T04:00:00Z,wp101,10,-19,10,39 12 | 2019-01-01T05:00:00Z,wp100,9,-18,9,39 13 | 2019-01-01T05:00:00Z,wp101,9,-18,9,39 14 | 2019-01-01T06:00:00Z,wp100,9,-19,9,41 15 | 2019-01-01T06:00:00Z,wp101,9,-19,9,41 16 | 2019-01-01T07:00:00Z,wp100,11,-18,11,41 17 | 2019-01-01T07:00:00Z,wp101,11,-18,11,41 -------------------------------------------------------------------------------- /src/test/resources/tgt/forecast_export_job_with_one_record_2019-10-16T21-40-00Z_part0.csv: -------------------------------------------------------------------------------- 1 | date,item_id,mean,p10,p50,p90 2 | 2018-12-31T23:46:00Z,i2.large,0.9406106747904436,-3.7700757298479304,0.9406106747904436,5.065026283952983 
3 | --------------------------------------------------------------------------------