├── .circleci └── config.yml ├── .codeflow.yml ├── .gitignore ├── CONTRIBUTING.md ├── Dockerfile ├── Gemfile ├── Gemfile.lock ├── LICENSE ├── README.md ├── STATE_SPEC.md ├── assets ├── one_small_step_for_gopher.png └── sm.png ├── aws ├── aws.go ├── dynamodb │ ├── lock.go │ └── lock_test.go ├── mocks │ ├── mock_dynamodb.go │ ├── mock_lambda.go │ ├── mock_s3.go │ ├── mock_sfn.go │ └── mocks.go └── s3 │ ├── lock.go │ ├── lock_test.go │ ├── s3.go │ └── s3_test.go ├── bifrost ├── inmemory_locker.go ├── release.go ├── release_halt_test.go ├── release_lock_test.go └── release_test.go ├── client ├── bootstrap.go ├── client.go ├── client_test.go └── deploy.go ├── deployer ├── README.md ├── fuzz_test.go ├── handlers.go ├── helpers_test.go ├── integration_test.go ├── machine.go ├── release.go ├── release_parsing.go └── release_test.go ├── errors └── errors.go ├── examples ├── all_types.json ├── bad_path.json ├── bad_type.json ├── bad_unknown_state.json ├── basic_choice.json ├── basic_pass.json ├── builder.json ├── deployer.json ├── map.json ├── step_deployer.json └── taskfn.json ├── execution └── execution.go ├── go.mod ├── go.sum ├── handler ├── handler.go └── handler_test.go ├── jsonpath ├── jsonpath.go ├── jsonpath_get_test.go ├── jsonpath_path_test.go └── jsonpath_set_test.go ├── machine ├── README.md ├── choice_state.go ├── choice_state_test.go ├── execution.go ├── fail_state.go ├── machine.go ├── machine_test.go ├── map_state.go ├── map_state_test.go ├── parallel_state.go ├── parser.go ├── parser_test.go ├── pass_state.go ├── pass_state_test.go ├── state.go ├── state_test.go ├── succeed_state.go ├── task_state.go ├── task_state_test.go ├── wait_state.go └── wait_state_test.go ├── resources ├── empty_lambda.zip ├── step-deployer.rb ├── step_assumed_policy.json.erb └── step_lambda_policy.json.erb ├── scripts ├── bootstrap_deployer ├── build_lambda_zip └── deploy_deployer ├── step.go └── utils ├── is ├── is.go └── is_test.go ├── run ├── dot.go └── run.go └── to ├── arn.go 
├── arn_test.go ├── json.go ├── json_test.go ├── pointer.go ├── sha256.go ├── to.go └── to_test.go /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build: 4 | docker: 5 | - image: circleci/golang:1.14 6 | working_directory: /go/src/github.com/coinbase/step 7 | steps: 8 | - checkout 9 | - run: export GO111MODULE=on && go mod download 10 | - run: export GO111MODULE=on && go test ./... 11 | 12 | -------------------------------------------------------------------------------- /.codeflow.yml: -------------------------------------------------------------------------------- 1 | deploy: 2 | engine: Step 3 | secure: 4 | required_reviews: 1 5 | upstream_repository: coinbase/step 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | step.zip 3 | step 4 | lambda 5 | lambda.zip 6 | tmp 7 | coverage.out 8 | coverage.html 9 | vendor 10 | .idea 11 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to step 2 | 3 | ## Code of Conduct 4 | 5 | All interactions with this project follow our [Code of Conduct][code-of-conduct]. 6 | By participating, you are expected to honor this code. Violators can be banned 7 | from further participation in this project, or potentially all Coinbase projects. 8 | 9 | [code-of-conduct]: https://github.com/coinbase/code-of-conduct 10 | 11 | ## Bug Reports 12 | 13 | * Ensure your issue [has not already been reported][1]. It may already be fixed! 14 | * Include the steps you carried out to produce the problem. 15 | * Include the behavior you observed along with the behavior you expected, and 16 | why you expected it. 17 | * Include any relevant stack traces or debugging output. 
18 | 19 | ## Feature Requests 20 | 21 | We welcome feedback with or without pull requests. If you have an idea for how 22 | to improve the project, great! All we ask is that you take the time to write a 23 | clear and concise explanation of what need you are trying to solve. If you have 24 | thoughts on _how_ it can be solved, include those too! 25 | 26 | The best way to see a feature added, however, is to submit a pull request. 27 | 28 | ## Pull Requests 29 | 30 | * Before creating your pull request, it's usually worth asking if the code 31 | you're planning on writing will actually be considered for merging. You can 32 | do this by [opening an issue][1] and asking. It may also help give the 33 | maintainers context for when the time comes to review your code. 34 | 35 | * Ensure your [commit messages are well-written][2]. This can double as your 36 | pull request message, so it pays to take the time to write a clear message. 37 | 38 | * Add tests for your feature. You should be able to look at other tests for 39 | examples. If you're unsure, don't hesitate to [open an issue][1] and ask! 40 | 41 | * Submit your pull request! 42 | 43 | ## Support Requests 44 | 45 | For security reasons, any communication referencing support tickets for Coinbase 46 | products will be ignored. The request will have its content redacted and will 47 | be locked to prevent further discussion. 48 | 49 | All support requests must be made via [our support team][3]. 
50 | 51 | [1]: https://github.com/coinbase/step/issues 52 | [2]: https://medium.com/brigade-engineering/the-secrets-to-great-commit-messages-106fc0a92a25 53 | [3]: https://support.coinbase.com/customer/en/portal/articles/2288496-how-can-i-contact-coinbase-support- 54 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang@sha256:ebe7f5d1a2a6b884bc1a45b8c1ff7e26b7b95938a3e8847ea96fc6761fdc2b77 2 | 3 | # Install Zip 4 | RUN apt-get update && apt-get upgrade -y && apt-get install -y zip 5 | 6 | WORKDIR /go/src/github.com/coinbase/step 7 | 8 | ENV GO111MODULE on 9 | ENV GOPATH /go 10 | 11 | COPY go.mod go.sum ./ 12 | RUN go mod download 13 | 14 | COPY . . 15 | 16 | RUN go build && go install 17 | 18 | # builds lambda.zip 19 | RUN ./scripts/build_lambda_zip 20 | RUN shasum -a 256 lambda.zip | awk '{print $1}' > lambda.zip.sha256 21 | 22 | RUN mv lambda.zip.sha256 lambda.zip / 23 | RUN step json > /state_machine.json 24 | 25 | CMD ["step"] 26 | -------------------------------------------------------------------------------- /Gemfile: -------------------------------------------------------------------------------- 1 | source 'https://rubygems.org' 2 | 3 | gem "geoengineer", { 4 | git: 'https://github.com/coinbase/geoengineer.git', 5 | ref: '9970098b88015c7d2157dcaf1276751f266b2ea7' 6 | } 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Step (Beta) 2 | 3 | One Small Step for Go 4 | 5 | Step is an opinionated implementation of the [AWS State Machine language](./STATE_SPEC.md) in [Go](https://golang.org/) used to build and test [AWS Step Functions](https://docs.aws.amazon.com/step-functions/latest/dg/getting-started.html) and [Lambdas](https://docs.aws.amazon.com/lambda/latest/dg/getting-started.html). 
Step combines the **Structure** of a state machine with the **Code** of a lambda so that the two can be developed, tested and maintained together. 6 | 7 | The three core components of Step are: 8 | 9 | 1. **Library**: tools for building and deploying Step Functions in Go. 10 | 2. **Implementation**: of the AWS State Machine specification to test with the code together ([README](./machine)). 11 | 3. **Deployer**: to deploy Lambdas and Step Functions securely ([README](./deployer)). 12 | 13 | ### Getting Started 14 | 15 | A Step function has two parts: 16 | 17 | 1. A **State Machine** description in JSON, which outlines the flow of execution. 18 | 2. The **Lambda Function** which executes the `TaskFn` states of the step function. 19 | 20 | Create a State Machine like this: 21 | 22 | ```go 23 | func StateMachine(lambdaArn string) (machine.StateMachine, error) { 24 | state_machine, err := machine.FromJSON([]byte(`{ 25 | "Comment": "Hello World", 26 | "StartAt": "Hello", 27 | "States": { 28 | "Hello": { 29 | "Type": "TaskFn", 30 | "Comment": "Deploy Step Function", 31 | "End": true 32 | } 33 | } 34 | }`)) 35 | 36 | if err != nil { 37 | return nil, err 38 | } 39 | 40 | // Set the Handlers 41 | state_machine.SetTaskFnHandlers(CreateTaskHandlers()) 42 | 43 | // Set Lambda Arn to call with Task States 44 | state_machine.SetResource(&lambdaArn) 45 | 46 | return state_machine, nil 47 | } 48 | ``` 49 | 50 | `TaskFn` is a custom state type that injects `Parameters` to execute the correct handler. 51 | 52 | Each `TaskFn` must have a handler that implements `func(context.Context, <input_type>) (interface{}, error)`. 
These are defined like: 53 | 54 | ```go 55 | func CreateTaskHandlers() *handler.TaskHandlers { 56 | tm := handler.TaskHandlers{} 57 | // Assign Hello state the HelloHandler 58 | tm["Hello"] = HelloHandler 59 | return &tm 60 | } 61 | 62 | type Hello struct { 63 | Greeting string 64 | } 65 | 66 | // HelloHandler takes a Hello struct, alters its greeting, and returns it 67 | func HelloHandler(_ context.Context, hello *Hello) (*Hello, error) { 68 | if hello.Greeting == "" { 69 | hello.Greeting = "Hello World" 70 | } 71 | return hello, nil 72 | } 73 | ``` 74 | 75 | To build a Step Function we then need an executable that can: 76 | 77 | 1. Be executed in a Lambda 78 | 2. Build the State Machine 79 | 80 | ```go 81 | func main() { 82 | var arg, command string 83 | switch len(os.Args) { 84 | case 1: 85 | fmt.Println("Starting Lambda") 86 | run.LambdaTasks(StateMachine("lambda")) 87 | case 2: 88 | command = os.Args[1] 89 | arg = "" 90 | case 3: 91 | command = os.Args[1] 92 | arg = os.Args[2] 93 | default: 94 | printUsage() // Print how to use and exit 95 | } 96 | 97 | switch command { 98 | case "json": 99 | run.JSON(StateMachine(arg)) 100 | case "exec": 101 | run.Exec(StateMachine(""))(&arg) 102 | default: 103 | printUsage() // Print how to use and exit 104 | } 105 | 106 | } 107 | ``` 108 | 109 | 1. `./step-hello-world` will run as a Lambda Function 110 | 2. `./step-hello-world json` will print out the state machine 111 | 112 | ### Testing 113 | 114 | A core benefit when using Step and joining the State Machine and Lambda together is that it makes it possible to test your Step Function's execution. 
115 | 116 | For example, a basic test that ensures the correct output and execution path through the Hello World step function looks like: 117 | 118 | ```go 119 | func Test_HelloWorld_StateMachine(t *testing.T) { 120 | state_machine, err := StateMachine("") 121 | assert.NoError(t, err) 122 | 123 | exec, err := state_machine.Execute(&Hello{}) 124 | assert.NoError(t, err) 125 | assert.Equal(t, "Hello World", exec.Output["Greeting"]) 126 | 127 | assert.Equal(t, state_machine.Path(), []string{ 128 | "Hello", 129 | }) 130 | } 131 | ``` 132 | 133 | ### Deploying 134 | 135 | There are two ways to get a State Machine into the cloud: 136 | 137 | 1. **Bootstrap**: Directly upload the Lambda and Step Function to AWS 138 | 2. **Deploy**: Using the Step Deployer which is a Step Function included in this library. 139 | 140 | The Step executable can perform both of these functions. 141 | 142 | *Step does not create the Lambda or Step Function in AWS, it only modifies them. So before either bootstrapping or deploying the resources must already be created.* 143 | 144 | First build and install step with: 145 | 146 | ```bash 147 | go build && go install 148 | ``` 149 | 150 | Bootstrap (directly upload to the Step Function and Lambda): 151 | 152 | ```bash 153 | # Use AWS credentials or assume-role into AWS 154 | # Build linux zip for lambda 155 | GOOS=linux go build -o lambda 156 | zip lambda.zip lambda 157 | 158 | # Tell step to bootstrap this lambda 159 | step bootstrap \ 160 | -lambda "coinbase-step-hello-world" \ 161 | -step "coinbase-step-hello-world" \ 162 | -states "$(./step-hello-world json)" 163 | ``` 164 | 165 | Deploy (via the step-deployer step function): 166 | 167 | ```bash 168 | GOOS=linux go build -o lambda 169 | zip lambda.zip lambda 170 | 171 | # Tell step-deployer to deploy this lambda 172 | step deploy \ 173 | -lambda "coinbase-step-hello-world" \ 174 | -step "coinbase-step-hello-world" \ 175 | -states "$(./step-hello-world json)" 176 | ``` 177 | 178 | ### 
Development State 179 | 180 | Step is still Beta and its API might change quickly. 181 | 182 | ### More Links 183 | 184 | 1. [AWS Step Functions, State Machines, Bifrost, and Building Deployers](https://blog.coinbase.com/aws-step-functions-state-machines-bifrost-and-building-deployers-5e3745fe645b) 185 | 1. [Open Sourcing Coinbase’s Secure Deployment Pipeline](https://engineering.coinbase.com/open-sourcing-coinbases-secure-deployment-pipeline-ae6c78e25517) 186 | 1. https://docs.aws.amazon.com/step-functions/latest/dg/step-functions-dg.pdf 187 | 1. https://github.com/vkkis93/serverless-step-functions-offline 188 | 1. https://github.com/totherik/step 189 | 190 | *CC Renee French for the logo, borrowed from GopherCon 2017* 191 | -------------------------------------------------------------------------------- /assets/one_small_step_for_gopher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coinbase/step/301282845bfb07879a39d0c2af36720633a61609/assets/one_small_step_for_gopher.png -------------------------------------------------------------------------------- /assets/sm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coinbase/step/301282845bfb07879a39d0c2af36720633a61609/assets/sm.png -------------------------------------------------------------------------------- /aws/aws.go: -------------------------------------------------------------------------------- 1 | package aws 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/aws/aws-sdk-go/aws" 7 | "github.com/aws/aws-sdk-go/aws/credentials/stscreds" 8 | "github.com/aws/aws-sdk-go/aws/session" 9 | "github.com/aws/aws-sdk-go/service/dynamodb" 10 | "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" 11 | "github.com/aws/aws-sdk-go/service/lambda" 12 | "github.com/aws/aws-sdk-go/service/lambda/lambdaiface" 13 | "github.com/aws/aws-sdk-go/service/s3" 14 | 
"github.com/aws/aws-sdk-go/service/s3/s3iface" 15 | "github.com/aws/aws-sdk-go/service/sfn" 16 | "github.com/aws/aws-sdk-go/service/sfn/sfniface" 17 | ) 18 | 19 | //////////// 20 | // Interfaces 21 | //////////// 22 | 23 | type S3API s3iface.S3API 24 | type LambdaAPI lambdaiface.LambdaAPI 25 | type SFNAPI sfniface.SFNAPI 26 | type DynamoDBAPI dynamodbiface.DynamoDBAPI 27 | 28 | type AwsClients interface { 29 | S3Client(region *string, account_id *string, role *string) S3API 30 | LambdaClient(region *string, account_id *string, role *string) LambdaAPI 31 | SFNClient(region *string, account_id *string, role *string) SFNAPI 32 | DynamoDBClient(region *string, account_id *string, role *string) DynamoDBAPI 33 | } 34 | 35 | //////////// 36 | // AWS Clients 37 | //////////// 38 | 39 | type Clients struct { 40 | session *session.Session 41 | configs map[string]*aws.Config 42 | } 43 | 44 | func (c Clients) Session() *session.Session { 45 | if c.session != nil { 46 | return c.session 47 | } 48 | // new session 49 | sess := session.Must(session.NewSession()) 50 | c.session = sess 51 | return sess 52 | } 53 | 54 | func (c Clients) Config( 55 | region *string, 56 | account_id *string, 57 | role *string) *aws.Config { 58 | 59 | config := aws.NewConfig().WithMaxRetries(10) 60 | 61 | if region != nil { 62 | config = config.WithRegion(*region) 63 | } 64 | 65 | // return no config for nil inputs 66 | if account_id == nil || role == nil { 67 | return config 68 | } 69 | 70 | // Assume a role 71 | arn := fmt.Sprintf( 72 | "arn:aws:iam::%v:role/%v", 73 | *account_id, 74 | *role, 75 | ) 76 | 77 | // include region in cache key otherwise concurrency errors 78 | key := fmt.Sprintf("%v::%v", *region, arn) 79 | 80 | // check for cached config 81 | if c.configs != nil && c.configs[key] != nil { 82 | return c.configs[key] 83 | } 84 | 85 | // new creds 86 | creds := stscreds.NewCredentials(c.Session(), arn) 87 | 88 | // new config 89 | config = config.WithCredentials(creds) 90 | 91 | if 
c.configs == nil { 92 | c.configs = map[string]*aws.Config{} 93 | } 94 | 95 | c.configs[key] = config 96 | return config 97 | } 98 | 99 | func (c *Clients) S3Client( 100 | region *string, 101 | account_id *string, 102 | role *string) S3API { 103 | return s3.New(c.Session(), c.Config(region, account_id, role)) 104 | } 105 | 106 | func (c *Clients) LambdaClient(region *string, account_id *string, role *string) LambdaAPI { 107 | return lambda.New(c.Session(), c.Config(region, account_id, role)) 108 | } 109 | 110 | func (c *Clients) SFNClient(region *string, account_id *string, role *string) SFNAPI { 111 | return sfn.New(c.Session(), c.Config(region, account_id, role)) 112 | } 113 | 114 | func (c *Clients) DynamoDBClient(region *string, account_id *string, role *string) DynamoDBAPI { 115 | return dynamodb.New(c.Session(), c.Config(region, account_id, role)) 116 | } 117 | -------------------------------------------------------------------------------- /aws/dynamodb/lock.go: -------------------------------------------------------------------------------- 1 | package dynamodb 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | awssdk "github.com/aws/aws-sdk-go/aws" 8 | "github.com/aws/aws-sdk-go/aws/awserr" 9 | "github.com/aws/aws-sdk-go/service/dynamodb" 10 | "github.com/aws/aws-sdk-go/service/dynamodb/expression" 11 | 12 | stepaws "github.com/coinbase/step/aws" 13 | ) 14 | 15 | var ( 16 | columnKey = "key" 17 | columnId = "id" 18 | columnTime = "time" 19 | ) 20 | 21 | type DynamoDBLocker struct { 22 | client stepaws.DynamoDBAPI 23 | } 24 | 25 | func NewDynamoDBLocker(client stepaws.DynamoDBAPI) *DynamoDBLocker { 26 | return &DynamoDBLocker{client} 27 | } 28 | 29 | func (l *DynamoDBLocker) GrabLock(namespace string, lockPath string, uuid string, reason string) (bool, error) { 30 | // Construct a conditional expression such that we only allow a new lock 31 | // to be created if there is not already one for the same key. 
32 | condExp := expression.Name(columnKey).AttributeNotExists() 33 | condExp = condExp.Or(expression.Name(columnId).Equal(expression.Value(uuid))) 34 | 35 | expr, err := expression.NewBuilder().WithCondition(condExp).Build() 36 | if err != nil { 37 | return false, err 38 | } 39 | 40 | // Attempt to create a lock 41 | _, err = l.client.PutItem(&dynamodb.PutItemInput{ 42 | TableName: awssdk.String(namespace), 43 | ConditionExpression: expr.Condition(), 44 | ExpressionAttributeNames: expr.Names(), 45 | ExpressionAttributeValues: expr.Values(), 46 | Item: map[string]*dynamodb.AttributeValue{ 47 | columnKey: { 48 | S: awssdk.String(lockPath), 49 | }, 50 | columnId: { 51 | S: awssdk.String(uuid), 52 | }, 53 | columnTime: { 54 | S: awssdk.String(time.Now().Format(time.RFC3339)), 55 | }, 56 | }, 57 | }) 58 | 59 | if err != nil { 60 | awsErr, ok := err.(awserr.Error) 61 | // A lock already exists for the same key. 62 | if ok && awsErr.Code() == dynamodb.ErrCodeConditionalCheckFailedException { 63 | return false, nil 64 | } 65 | 66 | return false, err 67 | } 68 | 69 | return true, nil 70 | } 71 | 72 | func (l *DynamoDBLocker) ReleaseLock(namespace string, lockPath string, uuid string) error { 73 | // Construct a condition expression such that we only allow a lock 74 | // to be deleted if the key, and the UUID aligns. 
75 | condExp := expression.Name(columnId).Equal(expression.Value(uuid)) 76 | expr, err := expression.NewBuilder().WithCondition(condExp).Build() 77 | if err != nil { 78 | return err 79 | } 80 | 81 | // Attempt to delete lock 82 | _, err = l.client.DeleteItem(&dynamodb.DeleteItemInput{ 83 | TableName: awssdk.String(namespace), 84 | ConditionExpression: expr.Condition(), 85 | ExpressionAttributeNames: expr.Names(), 86 | ExpressionAttributeValues: expr.Values(), 87 | Key: map[string]*dynamodb.AttributeValue{ 88 | columnKey: { 89 | S: awssdk.String(lockPath), 90 | }, 91 | }, 92 | }) 93 | 94 | if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == dynamodb.ErrCodeConditionalCheckFailedException { 95 | // A lock already exists, but with a different UUID. 96 | return fmt.Errorf("Lock was stolen for release with UUID(%v)", uuid) 97 | } 98 | return err 99 | } 100 | -------------------------------------------------------------------------------- /aws/dynamodb/lock_test.go: -------------------------------------------------------------------------------- 1 | package dynamodb 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/aws/aws-sdk-go/aws/awserr" 8 | "github.com/aws/aws-sdk-go/service/dynamodb" 9 | "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | type MockDynamoDBClient struct { 14 | dynamodbiface.DynamoDBAPI 15 | putItemCallback func(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) 16 | deleteItemCallback func(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error) 17 | } 18 | 19 | func (c *MockDynamoDBClient) PutItem(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) { 20 | return c.putItemCallback(input) 21 | } 22 | 23 | func (c *MockDynamoDBClient) DeleteItem(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error) { 24 | return c.deleteItemCallback(input) 25 | } 26 | 27 | func TestLock(t *testing.T) { 28 | t.Run("lock failure", 
func(t *testing.T) { 29 | client := &MockDynamoDBClient{} 30 | client.putItemCallback = func(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) { 31 | return nil, awserr.New(dynamodb.ErrCodeConditionalCheckFailedException, "The conditional request failed.", errors.New("fake error")) 32 | } 33 | 34 | locker := &DynamoDBLocker{client} 35 | 36 | grabbed, err := locker.GrabLock("tableName", "lockPath", "uuid", "testing") 37 | assert.NoError(t, err) 38 | assert.False(t, grabbed) 39 | }) 40 | 41 | t.Run("lock acquired successfully", func(t *testing.T) { 42 | client := &MockDynamoDBClient{} 43 | client.putItemCallback = func(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) { 44 | assert.Equal(t, "tableName", *input.TableName) 45 | assert.Equal(t, "lockPath", *input.Item[columnKey].S) 46 | assert.Equal(t, "uuid", *input.Item[columnId].S) 47 | assert.Equal(t, "(attribute_not_exists (#0)) OR (#1 = :0)", *input.ConditionExpression) 48 | 49 | assert.Equal(t, "key", *input.ExpressionAttributeNames["#0"]) 50 | assert.Equal(t, "id", *input.ExpressionAttributeNames["#1"]) 51 | assert.Equal(t, "uuid", *input.ExpressionAttributeValues[":0"].S) 52 | 53 | return &dynamodb.PutItemOutput{}, nil 54 | } 55 | 56 | locker := &DynamoDBLocker{client} 57 | 58 | grabbed, err := locker.GrabLock("tableName", "lockPath", "uuid", "testing") 59 | assert.NoError(t, err) 60 | assert.True(t, grabbed) 61 | }) 62 | } 63 | 64 | func TestUnlock(t *testing.T) { 65 | t.Run("unlock failure", func(t *testing.T) { 66 | client := &MockDynamoDBClient{} 67 | client.deleteItemCallback = func(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error) { 68 | return nil, awserr.New(dynamodb.ErrCodeConditionalCheckFailedException, "The conditional request failed.", errors.New("fake error")) 69 | } 70 | 71 | locker := &DynamoDBLocker{client} 72 | 73 | err := locker.ReleaseLock("tableName", "lockPath", "uuid") 74 | assert.Error(t, err) 75 | }) 76 | 77 | t.Run("unlock released", func(t 
*testing.T) { 78 | client := &MockDynamoDBClient{} 79 | client.deleteItemCallback = func(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error) { 80 | assert.Equal(t, "tableName", *input.TableName) 81 | 82 | assert.Equal(t, "id", *input.ExpressionAttributeNames["#0"]) 83 | assert.Equal(t, "uuid", *input.ExpressionAttributeValues[":0"].S) 84 | assert.Equal(t, "lockPath", *input.Key[columnKey].S) 85 | 86 | return &dynamodb.DeleteItemOutput{}, nil 87 | } 88 | 89 | locker := &DynamoDBLocker{client} 90 | 91 | err := locker.ReleaseLock("tableName", "lockPath", "uuid") 92 | assert.NoError(t, err) 93 | }) 94 | } 95 | -------------------------------------------------------------------------------- /aws/mocks/mock_dynamodb.go: -------------------------------------------------------------------------------- 1 | package mocks 2 | 3 | import ( 4 | "github.com/aws/aws-sdk-go/service/dynamodb" 5 | "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" 6 | ) 7 | 8 | type MockDynamoDBClient struct { 9 | dynamodbiface.DynamoDBAPI 10 | 11 | PutItemInputs []*dynamodb.PutItemInput 12 | DeleteItemInputs []*dynamodb.DeleteItemInput 13 | } 14 | 15 | func (m *MockDynamoDBClient) init() { 16 | if m.PutItemInputs == nil { 17 | m.PutItemInputs = []*dynamodb.PutItemInput{} 18 | } 19 | 20 | if m.DeleteItemInputs == nil { 21 | m.DeleteItemInputs = []*dynamodb.DeleteItemInput{} 22 | } 23 | } 24 | 25 | func (m *MockDynamoDBClient) PutItem(input *dynamodb.PutItemInput) (*dynamodb.PutItemOutput, error) { 26 | m.PutItemInputs = append(m.PutItemInputs, input) 27 | return &dynamodb.PutItemOutput{}, nil 28 | } 29 | 30 | func (m *MockDynamoDBClient) DeleteItem(input *dynamodb.DeleteItemInput) (*dynamodb.DeleteItemOutput, error) { 31 | m.DeleteItemInputs = append(m.DeleteItemInputs, input) 32 | return &dynamodb.DeleteItemOutput{}, nil 33 | } 34 | -------------------------------------------------------------------------------- /aws/mocks/mock_lambda.go: 
-------------------------------------------------------------------------------- 1 | package mocks 2 | 3 | import ( 4 | "github.com/aws/aws-sdk-go/service/lambda" 5 | "github.com/aws/aws-sdk-go/service/lambda/lambdaiface" 6 | ) 7 | 8 | type MockLambdaClient struct { 9 | lambdaiface.LambdaAPI 10 | UpdateFunctionCodeResp *lambda.FunctionConfiguration 11 | UpdateFunctionCodeError error 12 | ListTagsResp *lambda.ListTagsOutput 13 | } 14 | 15 | func (m *MockLambdaClient) init() { 16 | if m.UpdateFunctionCodeResp == nil { 17 | m.UpdateFunctionCodeResp = &lambda.FunctionConfiguration{} 18 | } 19 | } 20 | 21 | func (m *MockLambdaClient) UpdateFunctionCode(in *lambda.UpdateFunctionCodeInput) (*lambda.FunctionConfiguration, error) { 22 | m.init() 23 | return m.UpdateFunctionCodeResp, m.UpdateFunctionCodeError 24 | } 25 | 26 | func (m *MockLambdaClient) ListTags(in *lambda.ListTagsInput) (*lambda.ListTagsOutput, error) { 27 | m.init() 28 | return m.ListTagsResp, nil 29 | } 30 | -------------------------------------------------------------------------------- /aws/mocks/mock_s3.go: -------------------------------------------------------------------------------- 1 | package mocks 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "io/ioutil" 8 | "strings" 9 | "time" 10 | 11 | "github.com/aws/aws-sdk-go/aws/awserr" 12 | "github.com/aws/aws-sdk-go/service/s3" 13 | "github.com/aws/aws-sdk-go/service/s3/s3iface" 14 | "github.com/coinbase/step/utils/to" 15 | ) 16 | 17 | // S3Client 18 | type GetObjectResponse struct { 19 | Resp *s3.GetObjectOutput 20 | Body string 21 | Error error 22 | } 23 | 24 | type PutObjectResponse struct { 25 | Resp *s3.PutObjectOutput 26 | Error error 27 | } 28 | 29 | type DeleteObjectResponse struct { 30 | Resp *s3.DeleteObjectOutput 31 | Error error 32 | } 33 | 34 | type GetBucketTaggingResponse struct { 35 | Resp *s3.GetBucketTaggingOutput 36 | Error error 37 | } 38 | 39 | type MockS3Client struct { 40 | s3iface.S3API 41 | 42 | GetObjectResp 
map[string]*GetObjectResponse 43 | 44 | PutObjectResp map[string]*PutObjectResponse 45 | 46 | DeleteObjectResp map[string]*DeleteObjectResponse 47 | 48 | GetBucketTaggingResp map[string]*GetBucketTaggingResponse 49 | } 50 | 51 | func (m *MockS3Client) init() { 52 | if m.GetObjectResp == nil { 53 | m.GetObjectResp = map[string]*GetObjectResponse{} 54 | } 55 | 56 | if m.PutObjectResp == nil { 57 | m.PutObjectResp = map[string]*PutObjectResponse{} 58 | } 59 | 60 | if m.DeleteObjectResp == nil { 61 | m.DeleteObjectResp = map[string]*DeleteObjectResponse{} 62 | } 63 | 64 | if m.GetBucketTaggingResp == nil { 65 | m.GetBucketTaggingResp = map[string]*GetBucketTaggingResponse{} 66 | } 67 | } 68 | 69 | func MakeS3Body(ret string) io.ReadCloser { 70 | return ioutil.NopCloser(strings.NewReader(ret)) 71 | } 72 | 73 | func makeS3Resp(ret string, contentType *string, cacheControl *string) *s3.GetObjectOutput { 74 | return &s3.GetObjectOutput{ 75 | Body: MakeS3Body(ret), 76 | ContentType: contentType, 77 | CacheControl: cacheControl, 78 | LastModified: to.Timep(time.Now()), 79 | } 80 | } 81 | 82 | func AWSS3NotFoundError() error { 83 | return awserr.New(s3.ErrCodeNoSuchKey, "not found", nil) 84 | } 85 | 86 | func (m *MockS3Client) addGetObjectWithContentTypeAndCacheControl(key string, body string, contentType *string, cacheControl *string, err error) { 87 | m.init() 88 | m.GetObjectResp[key] = &GetObjectResponse{ 89 | Resp: makeS3Resp(body, contentType, cacheControl), 90 | Body: body, 91 | Error: err, 92 | } 93 | } 94 | 95 | func (m *MockS3Client) AddGetObject(key string, body string, err error) { 96 | m.addGetObjectWithContentTypeAndCacheControl(key, body, nil, nil, err) 97 | } 98 | 99 | func (m *MockS3Client) AddPutObject(key string, err error) { 100 | m.init() 101 | m.PutObjectResp[key] = &PutObjectResponse{ 102 | Resp: &s3.PutObjectOutput{}, 103 | Error: err, 104 | } 105 | } 106 | 107 | func (m *MockS3Client) SetBucketTags(bucket string, tags map[string]string, err error) { 
108 | m.init() 109 | tagSet := []*s3.Tag{} 110 | 111 | for tk, tv := range tags { 112 | tagSet = append(tagSet, &s3.Tag{Key: to.Strp(tk), Value: to.Strp(tv)}) 113 | } 114 | 115 | m.GetBucketTaggingResp[bucket] = &GetBucketTaggingResponse{ 116 | Resp: &s3.GetBucketTaggingOutput{ 117 | TagSet: tagSet, 118 | }, 119 | Error: err, 120 | } 121 | } 122 | 123 | func (m *MockS3Client) GetObject(in *s3.GetObjectInput) (*s3.GetObjectOutput, error) { 124 | m.init() 125 | resp := m.GetObjectResp[*in.Key] 126 | 127 | if resp == nil { 128 | return nil, AWSS3NotFoundError() 129 | } 130 | 131 | resp.Resp.Body = MakeS3Body(resp.Body) 132 | return resp.Resp, resp.Error 133 | } 134 | 135 | func (m *MockS3Client) ListObjects(in *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) { 136 | return nil, nil 137 | } 138 | 139 | func (m *MockS3Client) PutObject(in *s3.PutObjectInput) (*s3.PutObjectOutput, error) { 140 | m.init() 141 | 142 | resp := m.PutObjectResp[*in.Key] 143 | // Simulates adding the object 144 | buf := new(bytes.Buffer) 145 | buf.ReadFrom(in.Body) 146 | m.addGetObjectWithContentTypeAndCacheControl(*in.Key, buf.String(), in.ContentType, in.CacheControl, nil) 147 | 148 | if resp == nil { 149 | return &s3.PutObjectOutput{}, nil 150 | } 151 | return resp.Resp, resp.Error 152 | } 153 | 154 | func (m *MockS3Client) GetBucketTagging(in *s3.GetBucketTaggingInput) (*s3.GetBucketTaggingOutput, error) { 155 | m.init() 156 | resp := m.GetBucketTaggingResp[*in.Bucket] 157 | if resp == nil { 158 | return nil, fmt.Errorf("Unkown Bucket, should mock the tags") 159 | } 160 | return resp.Resp, resp.Error 161 | } 162 | 163 | func (m *MockS3Client) DeleteObject(in *s3.DeleteObjectInput) (*s3.DeleteObjectOutput, error) { 164 | m.init() 165 | 166 | resp := m.DeleteObjectResp[*in.Key] 167 | 168 | delete(m.GetObjectResp, *in.Key) 169 | 170 | if resp == nil { 171 | return &s3.DeleteObjectOutput{}, nil 172 | } 173 | return resp.Resp, resp.Error 174 | } 175 | 
-------------------------------------------------------------------------------- /aws/mocks/mock_sfn.go: -------------------------------------------------------------------------------- 1 | package mocks 2 | 3 | import ( 4 | "github.com/aws/aws-sdk-go/service/sfn" 5 | "github.com/aws/aws-sdk-go/service/sfn/sfniface" 6 | "github.com/coinbase/step/utils/to" 7 | ) 8 | 9 | type MockSFNClient struct { 10 | sfniface.SFNAPI 11 | UpdateStateMachineResp *sfn.UpdateStateMachineOutput 12 | UpdateStateMachineError error 13 | StartExecutionResp *sfn.StartExecutionOutput 14 | DescribeExecutionResp *sfn.DescribeExecutionOutput 15 | GetExecutionHistoryResp *sfn.GetExecutionHistoryOutput 16 | DescribeStateMachineResp *sfn.DescribeStateMachineOutput 17 | ListExecutionsResp *sfn.ListExecutionsOutput 18 | } 19 | 20 | func (m *MockSFNClient) init() { 21 | 22 | if m.UpdateStateMachineResp == nil { 23 | m.UpdateStateMachineResp = &sfn.UpdateStateMachineOutput{} 24 | } 25 | 26 | if m.StartExecutionResp == nil { 27 | m.StartExecutionResp = &sfn.StartExecutionOutput{} 28 | } 29 | 30 | if m.DescribeExecutionResp == nil { 31 | m.DescribeExecutionResp = &sfn.DescribeExecutionOutput{Status: to.Strp("SUCCEEDED")} 32 | } 33 | 34 | if m.GetExecutionHistoryResp == nil { 35 | m.GetExecutionHistoryResp = &sfn.GetExecutionHistoryOutput{Events: []*sfn.HistoryEvent{}} 36 | } 37 | 38 | if m.ListExecutionsResp == nil { 39 | m.ListExecutionsResp = &sfn.ListExecutionsOutput{Executions: []*sfn.ExecutionListItem{}} 40 | } 41 | } 42 | 43 | func (m *MockSFNClient) UpdateStateMachine(in *sfn.UpdateStateMachineInput) (*sfn.UpdateStateMachineOutput, error) { 44 | m.init() 45 | return m.UpdateStateMachineResp, m.UpdateStateMachineError 46 | } 47 | 48 | func (m *MockSFNClient) StartExecution(in *sfn.StartExecutionInput) (*sfn.StartExecutionOutput, error) { 49 | m.init() 50 | return m.StartExecutionResp, nil 51 | } 52 | 53 | func (m *MockSFNClient) DescribeExecution(in *sfn.DescribeExecutionInput) 
(*sfn.DescribeExecutionOutput, error) { 54 | m.init() 55 | return m.DescribeExecutionResp, nil 56 | } 57 | 58 | func (m *MockSFNClient) GetExecutionHistory(in *sfn.GetExecutionHistoryInput) (*sfn.GetExecutionHistoryOutput, error) { 59 | m.init() 60 | return m.GetExecutionHistoryResp, nil 61 | } 62 | 63 | func (m *MockSFNClient) DescribeStateMachine(in *sfn.DescribeStateMachineInput) (*sfn.DescribeStateMachineOutput, error) { 64 | m.init() 65 | return m.DescribeStateMachineResp, nil 66 | } 67 | 68 | func (m *MockSFNClient) ListExecutions(in *sfn.ListExecutionsInput) (*sfn.ListExecutionsOutput, error) { 69 | m.init() 70 | return m.ListExecutionsResp, nil 71 | } 72 | -------------------------------------------------------------------------------- /aws/mocks/mocks.go: -------------------------------------------------------------------------------- 1 | package mocks 2 | 3 | import "github.com/coinbase/step/aws" 4 | 5 | type MockClients struct { 6 | S3 *MockS3Client 7 | Lambda *MockLambdaClient 8 | SFN *MockSFNClient 9 | DynamoDB *MockDynamoDBClient 10 | } 11 | 12 | func (awsc *MockClients) S3Client(*string, *string, *string) aws.S3API { 13 | return awsc.S3 14 | } 15 | 16 | func (awsc *MockClients) LambdaClient(*string, *string, *string) aws.LambdaAPI { 17 | return awsc.Lambda 18 | } 19 | 20 | func (awsc *MockClients) SFNClient(*string, *string, *string) aws.SFNAPI { 21 | return awsc.SFN 22 | } 23 | 24 | func (awsc *MockClients) DynamoDBClient(*string, *string, *string) aws.DynamoDBAPI { 25 | return awsc.DynamoDB 26 | } 27 | 28 | func MockAwsClients() *MockClients { 29 | return &MockClients{ 30 | &MockS3Client{}, 31 | &MockLambdaClient{}, 32 | &MockSFNClient{}, 33 | &MockDynamoDBClient{}, 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /aws/s3/lock.go: -------------------------------------------------------------------------------- 1 | package s3 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/coinbase/step/aws" 7 | ) 8 
| 9 | type Lock struct { 10 | UUID string `json:"uuid,omitempty"` 11 | } 12 | 13 | type UserLock struct { 14 | User string `json:"user,omitempty"` 15 | LockReason string `json:"lock_reason", omitempty"` 16 | } 17 | 18 | func CheckUserLock(s3c aws.S3API, bucket *string, lock_path *string) error { 19 | var userLock UserLock 20 | err := GetStruct(s3c, bucket, lock_path, &userLock) 21 | if err != nil { 22 | switch err.(type) { 23 | case *NotFoundError: 24 | // good we want this 25 | return nil 26 | default: 27 | return err // All other errors return 28 | } 29 | } 30 | if userLock == (UserLock{}) { 31 | return nil 32 | } 33 | return fmt.Errorf("Deploys locked by %v for reason: %v", userLock.User, userLock.LockReason) 34 | } 35 | 36 | // GrabLock creates a lock file in S3 with a UUID 37 | // it returns a grabbed bool, and error 38 | // if the Lock already exists and UUID is equal to the existing lock it will returns true, otherwise false 39 | // if the Lock doesn't exist it will create the file and return true 40 | func GrabLock(s3c aws.S3API, bucket *string, lock_path *string, uuid string) (bool, error) { 41 | lock := &Lock{uuid} 42 | var s3_lock Lock 43 | 44 | err := GetStruct(s3c, bucket, lock_path, &s3_lock) 45 | if err != nil { 46 | switch err.(type) { 47 | case *NotFoundError: 48 | // good we want this 49 | default: 50 | return false, err // All other errors return 51 | } 52 | } 53 | 54 | // If s3_lock unmarshalled and the UUID 55 | if s3_lock.UUID != "" { 56 | // if UUID is the same 57 | if s3_lock.UUID == lock.UUID { 58 | // Already have the lock (caused by a retry ... 
maybe) 59 | return true, nil 60 | } else { 61 | return false, nil 62 | } 63 | } 64 | 65 | // After this point we might have created the lock so return true 66 | // Create the Lock 67 | err = PutStruct(s3c, bucket, lock_path, lock) 68 | 69 | if err != nil { 70 | return true, err 71 | } 72 | 73 | return true, nil 74 | } 75 | 76 | // ReleaseLock removes the lock file for UUID 77 | // If the lock file exists and is not the same UUID it returns an error 78 | func ReleaseLock(s3c aws.S3API, bucket *string, lock_path *string, uuid string) error { 79 | var s3_lock Lock 80 | 81 | err := GetStruct(s3c, bucket, lock_path, &s3_lock) 82 | if err != nil { 83 | switch err.(type) { 84 | case *NotFoundError: 85 | // No lock to release 86 | return nil 87 | default: 88 | return err // All other errors return 89 | } 90 | } 91 | 92 | // if s3_lock unmarshalled and the UUID is different then error 93 | if s3_lock.UUID != "" && s3_lock.UUID != uuid { 94 | return fmt.Errorf("Release with UUID(%v) is trying to unlock UUID(%v)", uuid, s3_lock.UUID) 95 | } 96 | 97 | return Delete(s3c, bucket, lock_path) 98 | } 99 | -------------------------------------------------------------------------------- /aws/s3/lock_test.go: -------------------------------------------------------------------------------- 1 | package s3 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/coinbase/step/aws/mocks" 8 | "github.com/coinbase/step/utils/to" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func Test_GrabLock_Success(t *testing.T) { 13 | s3c := &mocks.MockS3Client{} 14 | bucket := to.Strp("bucket") 15 | path := to.Strp("path") 16 | 17 | grabbed, err := GrabLock(s3c, bucket, path, "UUID") 18 | 19 | assert.NoError(t, err) 20 | assert.True(t, grabbed) 21 | } 22 | 23 | func Test_CheckUserLock_Success(t *testing.T) { 24 | s3c := &mocks.MockS3Client{} 25 | bucket := to.Strp("bucket") 26 | path := to.Strp("path") 27 | 28 | err := CheckUserLock(s3c, bucket, path) 29 | 30 | assert.NoError(t, err) 31 
| } 32 | 33 | func Test_GrabLock_Success_Already_Has_Lock(t *testing.T) { 34 | s3c := &mocks.MockS3Client{} 35 | bucket := to.Strp("bucket") 36 | path := to.Strp("path") 37 | 38 | s3c.AddGetObject(*path, `{"uuid": "UUID"}`, nil) 39 | grabbed, err := GrabLock(s3c, bucket, path, "UUID") 40 | 41 | assert.NoError(t, err) 42 | assert.True(t, grabbed) 43 | } 44 | 45 | func Test_GrabLock_Failure_Already_Locked(t *testing.T) { 46 | s3c := &mocks.MockS3Client{} 47 | bucket := to.Strp("bucket") 48 | path := to.Strp("path") 49 | 50 | s3c.AddGetObject(*path, `{"uuid": "NOT_UUID"}`, nil) 51 | grabbed, err := GrabLock(s3c, bucket, path, "UUID") 52 | 53 | assert.NoError(t, err) 54 | assert.False(t, grabbed) 55 | } 56 | func Test_CheckUserLock_Failure_Already_Locked(t *testing.T) { 57 | s3c := &mocks.MockS3Client{} 58 | bucket := to.Strp("bucket") 59 | path := to.Strp("path") 60 | 61 | s3c.AddGetObject(*path, `{"user": "test", "lock_reason": "testing"}`, nil) 62 | err := CheckUserLock(s3c, bucket, path) 63 | assert.Error(t, err) 64 | } 65 | 66 | func Test_GrabLock_Failure_S3_Get_Error(t *testing.T) { 67 | s3c := &mocks.MockS3Client{} 68 | bucket := to.Strp("bucket") 69 | path := to.Strp("path") 70 | 71 | s3c.AddGetObject(*path, `{"uuid": "NOT_UUID"}`, fmt.Errorf("ERRRR")) 72 | grabbed, err := GrabLock(s3c, bucket, path, "UUID") 73 | 74 | assert.Error(t, err) 75 | assert.False(t, grabbed) 76 | } 77 | 78 | func Test_CheckUserLock_Failure_S3_Get_Error(t *testing.T) { 79 | s3c := &mocks.MockS3Client{} 80 | bucket := to.Strp("bucket") 81 | path := to.Strp("path") 82 | 83 | s3c.AddGetObject(*path, `{"user": "test", "lock_reason": "hello"}`, fmt.Errorf("ERRRR")) 84 | err := CheckUserLock(s3c, bucket, path) 85 | 86 | assert.Error(t, err) 87 | } 88 | 89 | func Test_GrabLock_Failure_S3_Upload_Error(t *testing.T) { 90 | s3c := &mocks.MockS3Client{} 91 | bucket := to.Strp("bucket") 92 | path := to.Strp("path") 93 | 94 | s3c.AddPutObject(*path, fmt.Errorf("ERRRR")) 95 | grabbed, err := 
GrabLock(s3c, bucket, path, "UUID") 96 | 97 | assert.Error(t, err) 98 | assert.True(t, grabbed) 99 | } 100 | 101 | func Test_ReleaseLock_Success_No_Object(t *testing.T) { 102 | s3c := &mocks.MockS3Client{} 103 | bucket := to.Strp("bucket") 104 | path := to.Strp("path") 105 | 106 | err := ReleaseLock(s3c, bucket, path, "UUID") 107 | 108 | assert.NoError(t, err) 109 | } 110 | 111 | func Test_ReleaseLock_Success_Correct_Lock(t *testing.T) { 112 | s3c := &mocks.MockS3Client{} 113 | bucket := to.Strp("bucket") 114 | path := to.Strp("path") 115 | 116 | s3c.AddGetObject(*path, `{"uuid": "UUID"}`, nil) 117 | err := ReleaseLock(s3c, bucket, path, "UUID") 118 | 119 | assert.NoError(t, err) 120 | } 121 | 122 | func Test_ReleaseLock_Failure_AnotherReleasesLock(t *testing.T) { 123 | s3c := &mocks.MockS3Client{} 124 | bucket := to.Strp("bucket") 125 | path := to.Strp("path") 126 | 127 | s3c.AddGetObject(*path, `{"uuid": "NOT_UUID"}`, nil) 128 | err := ReleaseLock(s3c, bucket, path, "UUID") 129 | 130 | assert.Error(t, err) 131 | } 132 | -------------------------------------------------------------------------------- /aws/s3/s3_test.go: -------------------------------------------------------------------------------- 1 | package s3 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/coinbase/step/aws/mocks" 7 | "github.com/coinbase/step/utils/to" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func Test_Get_Success(t *testing.T) { 12 | s3c := &mocks.MockS3Client{} 13 | _, err := Get(s3c, to.Strp("bucket"), to.Strp("/path")) 14 | assert.Error(t, err) 15 | assert.IsType(t, &NotFoundError{}, err) 16 | 17 | s3c.AddGetObject("/path", "asd", nil) 18 | out, err := Get(s3c, to.Strp("bucket"), to.Strp("/path")) 19 | assert.NoError(t, err) 20 | assert.Equal(t, "asd", string(*out)) 21 | } 22 | 23 | func Test_Put_Success(t *testing.T) { 24 | s3c := &mocks.MockS3Client{} 25 | bucket := to.Strp("bucket") 26 | key := to.Strp("/path") 27 | err := PutStr(s3c, bucket, key, to.Strp("asdji")) 
28 | assert.NoError(t, err) 29 | 30 | out, err := Get(s3c, bucket, key) 31 | assert.NoError(t, err) 32 | assert.Equal(t, "asdji", string(*out)) 33 | } 34 | 35 | func Test_Put_With_Type_Success(t *testing.T) { 36 | s3c := &mocks.MockS3Client{} 37 | bucket := to.Strp("bucket") 38 | key := to.Strp("/path") 39 | content := []byte("") 40 | contentType := to.Strp("text/html") 41 | err := PutWithType(s3c, bucket, key, &content, contentType) 42 | assert.NoError(t, err) 43 | 44 | object, out, err := GetObject(s3c, bucket, key) 45 | assert.NoError(t, err) 46 | assert.Equal(t, "", string(*out)) 47 | assert.Equal(t, "text/html", string(*object.ContentType)) 48 | } 49 | 50 | func Test_Put_With_Cache_Control_Success(t *testing.T) { 51 | s3c := &mocks.MockS3Client{} 52 | bucket := to.Strp("bucket") 53 | key := to.Strp("/path") 54 | content := []byte("asdji") 55 | cacheControl := to.Strp("public, max-age=31556926") 56 | err := PutWithCacheControl(s3c, bucket, key, &content, cacheControl) 57 | assert.NoError(t, err) 58 | 59 | object, out, err := GetObject(s3c, bucket, key) 60 | assert.NoError(t, err) 61 | assert.Equal(t, "asdji", string(*out)) 62 | assert.Equal(t, "public, max-age=31556926", string(*object.CacheControl)) 63 | } 64 | 65 | func Test_Put_With_Type_And_Cache_Control_Success(t *testing.T) { 66 | s3c := &mocks.MockS3Client{} 67 | bucket := to.Strp("bucket") 68 | key := to.Strp("/path") 69 | content := []byte("") 70 | contentType := to.Strp("text/html") 71 | cacheControl := to.Strp("public, max-age=31556926") 72 | err := PutWithTypeAndCacheControl(s3c, bucket, key, &content, contentType, cacheControl) 73 | assert.NoError(t, err) 74 | 75 | object, out, err := GetObject(s3c, bucket, key) 76 | assert.NoError(t, err) 77 | assert.Equal(t, "", string(*out)) 78 | assert.Equal(t, "text/html", string(*object.ContentType)) 79 | assert.Equal(t, "public, max-age=31556926", string(*object.CacheControl)) 80 | } 81 | 82 | func Test_Delete_Success(t *testing.T) { 83 | s3c := 
// Lock is one entry held by an InMemoryLocker: the locked path, the UUID of
// the holder and the reason given when the lock was grabbed.
type Lock struct {
	lockPath string
	uuid     string
	reason   string
}

// InMemoryLocker keeps locks in process memory, grouped by namespace.
// All methods are safe for concurrent use and the zero value is usable.
type InMemoryLocker struct {
	mu    sync.RWMutex
	locks map[string][]*Lock
}

// NewInMemoryLocker returns an empty locker.
func NewInMemoryLocker() *InMemoryLocker {
	return &InMemoryLocker{
		locks: make(map[string][]*Lock),
	}
}

// GrabLock tries to take the lock at lockPath in namespace for uuid.
// If the path is already locked it reports whether the existing lock belongs
// to uuid; otherwise it records a new lock and returns true. The returned
// error is always nil.
//
// The lookup and the insert happen under a single write lock, so two callers
// racing for the same path can never both win.
func (l *InMemoryLocker) GrabLock(namespace string, lockPath string, uuid string, reason string) (bool, error) {
	l.mu.Lock()
	defer l.mu.Unlock()

	if held := l.lockAt(namespace, lockPath); held != nil {
		return held.uuid == uuid, nil
	}

	if l.locks == nil {
		// Zero-value locker: writing to a nil map would panic.
		l.locks = make(map[string][]*Lock)
	}

	l.locks[namespace] = append(l.locks[namespace], &Lock{
		lockPath: lockPath,
		uuid:     uuid,
		reason:   reason,
	})

	return true, nil
}

// ReleaseLock removes the lock at lockPath in namespace if it is held by uuid.
// It returns an error when the path is locked by a different UUID and nil when
// the lock was removed or no lock exists at that path. Locks held by uuid on
// other paths of the namespace are left untouched.
func (l *InMemoryLocker) ReleaseLock(namespace string, lockPath string, uuid string) error {
	l.mu.Lock()
	defer l.mu.Unlock()

	held := l.locks[namespace]
	for i, lock := range held {
		if lock.lockPath != lockPath {
			continue
		}
		if lock.uuid != uuid {
			return fmt.Errorf("failed to release lock: %s is currently held by UUID(%v)", lockPath, lock.uuid)
		}

		// Build a fresh slice rather than shifting elements in place.
		remaining := make([]*Lock, 0, len(held)-1)
		remaining = append(remaining, held[:i]...)
		remaining = append(remaining, held[i+1:]...)
		l.locks[namespace] = remaining
		return nil
	}

	return nil
}

// GetLockByNamespace returns the locks currently held in namespace. The result
// is a copy of the internal list (never nil), so callers may keep or modify it.
func (l *InMemoryLocker) GetLockByNamespace(namespace string) []*Lock {
	l.mu.RLock()
	defer l.mu.RUnlock()

	return append([]*Lock{}, l.locks[namespace]...)
}

// GetLockByPath returns the lock held at lockPath in namespace, or nil when
// the path is not locked.
func (l *InMemoryLocker) GetLockByPath(namespace string, lockPath string) *Lock {
	l.mu.RLock()
	defer l.mu.RUnlock()

	return l.lockAt(namespace, lockPath)
}

// lockAt finds the lock at lockPath in namespace. The caller must hold l.mu;
// keeping this lookup lock-free avoids acquiring the read lock recursively.
func (l *InMemoryLocker) lockAt(namespace string, lockPath string) *Lock {
	for _, lock := range l.locks[namespace] {
		if lock.lockPath == lockPath {
			return lock
		}
	}
	return nil
}
MockAwsClients(r) 14 | 15 | assert.NoError(t, r.IsHalt(awsc.S3)) 16 | r.Timeout = to.Intp(0) 17 | assert.Error(t, r.IsHalt(awsc.S3)) 18 | 19 | // 10 second halt 20 | r.StartedAt = to.Timep(time.Now().Add(-1 * (9 * time.Second))) 21 | r.Timeout = to.Intp(10) 22 | assert.NoError(t, r.IsHalt(awsc.S3)) 23 | r.StartedAt = to.Timep(time.Now().Add(-1 * (11 * time.Second))) 24 | assert.Error(t, r.IsHalt(awsc.S3)) 25 | } 26 | 27 | func Test_IsHalt_HaltKey(t *testing.T) { 28 | r := MockRelease() 29 | awsc := MockAwsClients(r) 30 | 31 | assert.NoError(t, r.IsHalt(awsc.S3)) 32 | assert.NoError(t, r.Halt(awsc.S3, to.Strp("error"))) 33 | assert.Error(t, r.IsHalt(awsc.S3)) 34 | 35 | // If the Halt key is older than 5 mins ignore it 36 | awsc.S3.GetObjectResp[*r.HaltPath()].Resp.LastModified = to.Timep(time.Now().Add(-1 * (10 * time.Minute))) 37 | assert.NoError(t, r.IsHalt(awsc.S3)) 38 | } 39 | -------------------------------------------------------------------------------- /bifrost/release_lock_test.go: -------------------------------------------------------------------------------- 1 | package bifrost 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/coinbase/step/utils/to" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func Test_Lock_GrabRootLock(t *testing.T) { 11 | r := MockRelease() 12 | 13 | r2 := MockRelease() 14 | r2.UUID = to.Strp("NOTUUID") 15 | 16 | awsc := MockAwsClients(r) 17 | s3c := awsc.S3Client(r.AwsRegion, nil, nil) 18 | locker := NewInMemoryLocker() 19 | 20 | t.Run("root lock acquired", func(t *testing.T) { 21 | assert.NoError(t, r.GrabRootLock(s3c, locker, "lambdaname")) 22 | }) 23 | 24 | t.Run("same root lock acquired", func(t *testing.T) { 25 | assert.NoError(t, r.GrabRootLock(s3c, locker, "lambdaname")) 26 | 27 | locks := locker.GetLockByNamespace("lambdaname") 28 | // We are re-using an existing lock 29 | assert.Equal(t, len(locks), 1) 30 | assert.Equal(t, locks[0].lockPath, "account/project/config/lock") 31 | }) 32 | 33 | t.Run("conflict when 
acquiring root lock with different uuid", func(t *testing.T) { 34 | assert.Error(t, r2.GrabRootLock(s3c, locker, "lambdaname")) 35 | // There should be no changes in the existing locks 36 | assert.Equal(t, len(locker.GetLockByNamespace("lambdaname")), 1) 37 | }) 38 | 39 | t.Run("root lock released", func(t *testing.T) { 40 | assert.NoError(t, r.UnlockRoot(s3c, locker, "lambdaname")) 41 | assert.Equal(t, len(locker.GetLockByNamespace("lambdaname")), 0) 42 | }) 43 | 44 | t.Run("same root lock released", func(t *testing.T) { 45 | assert.NoError(t, r.UnlockRoot(s3c, locker, "lambdaname")) 46 | assert.Equal(t, len(locker.GetLockByNamespace("lambdaname")), 0) 47 | }) 48 | 49 | t.Run("root lock with same uuid acquired", func(t *testing.T) { 50 | assert.NoError(t, r2.GrabRootLock(s3c, locker, "lambdaname")) 51 | locks := locker.GetLockByNamespace("lambdaname") 52 | assert.Equal(t, len(locks), 1) 53 | assert.Equal(t, locks[0].lockPath, "account/project/config/lock") 54 | assert.Equal(t, locks[0].uuid, "NOTUUID") 55 | }) 56 | } 57 | -------------------------------------------------------------------------------- /bifrost/release_test.go: -------------------------------------------------------------------------------- 1 | package bifrost 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "testing" 7 | "time" 8 | 9 | "github.com/coinbase/step/aws/mocks" 10 | "github.com/coinbase/step/utils/to" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func MockRelease() *Release { 15 | return &Release{ 16 | AwsRegion: to.Strp("region"), 17 | AwsAccountID: to.Strp("account"), 18 | ReleaseID: to.TimeUUID("release-"), 19 | CreatedAt: to.Timep(time.Now()), 20 | ProjectName: to.Strp("project"), 21 | ConfigName: to.Strp("config"), 22 | Bucket: to.Strp("bucket"), 23 | } 24 | } 25 | 26 | func MockAwsClients(r *Release) *mocks.MockClients { 27 | awsc := mocks.MockAwsClients() 28 | 29 | raw, _ := json.Marshal(r) 30 | 31 | awsc.S3.AddGetObject(fmt.Sprintf("%v/%v/%v/%v/release", 
*r.AwsAccountID, *r.ProjectName, *r.ConfigName, *r.ReleaseID), string(raw), nil) 32 | r.ReleaseSHA256 = to.SHA256Struct(&r) 33 | 34 | r.SetDefaults(r.AwsRegion, r.AwsAccountID, "") 35 | 36 | return awsc 37 | } 38 | 39 | func TestReleasePaths(t *testing.T) { 40 | release := MockRelease() 41 | release.ReleaseID = to.Strp("id") 42 | 43 | assert.Equal(t, "account/project", *release.ProjectDir()) 44 | assert.Equal(t, "account/project/config", *release.RootDir()) 45 | assert.Equal(t, "account/project/config/id", *release.ReleaseDir()) 46 | assert.Equal(t, "account/project/config/id/release", *release.ReleasePath()) 47 | assert.Equal(t, "account/project/config/id/log", *release.LogPath()) 48 | assert.Equal(t, "account/project/config/lock", *release.RootLockPath()) 49 | assert.Equal(t, "account/project/config/id/lock", *release.ReleaseLockPath()) 50 | assert.Equal(t, "account/project/_shared", *release.SharedProjectDir()) 51 | } 52 | 53 | func Test_Bifrost_Release_Is_Valid(t *testing.T) { 54 | release := MockRelease() 55 | awsc := MockAwsClients(release) 56 | 57 | assert.NoError(t, release.Validate(awsc.S3Client(release.AwsRegion, nil, nil), &Release{})) 58 | } 59 | -------------------------------------------------------------------------------- /client/bootstrap.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | 7 | "github.com/coinbase/step/aws" 8 | "github.com/coinbase/step/deployer" 9 | "github.com/coinbase/step/utils/to" 10 | ) 11 | 12 | // Bootstrap takes release information and uploads directly to Step Function and Lambda 13 | func Bootstrap(release *deployer.Release, zip_file_path *string) error { 14 | awsc := &aws.Clients{} 15 | 16 | fmt.Println("Preparing Release Bundle") 17 | err := PrepareRelease(release, zip_file_path) 18 | if err != nil { 19 | return err 20 | } 21 | 22 | bts, err := ioutil.ReadFile(*zip_file_path) 23 | if err != nil { 24 | return err 25 | } 26 
| 27 | fmt.Println("Deploying Step Function") 28 | fmt.Println(to.PrettyJSONStr(release)) 29 | 30 | err = release.DeployStepFunction(awsc.SFNClient(nil, nil, nil)) 31 | if err != nil { 32 | return err 33 | } 34 | 35 | fmt.Println("Deploying Lambda Function") 36 | 37 | err = release.DeployLambdaCode(awsc.LambdaClient(nil, nil, nil), &bts) 38 | if err != nil { 39 | return err 40 | } 41 | 42 | fmt.Println("Success") 43 | return nil 44 | } 45 | -------------------------------------------------------------------------------- /client/client.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/coinbase/step/aws" 7 | "github.com/coinbase/step/aws/s3" 8 | "github.com/coinbase/step/deployer" 9 | "github.com/coinbase/step/utils/to" 10 | ) 11 | 12 | // PrepareRelease returns a release with additional information filled in 13 | func PrepareRelease(release *deployer.Release, zip_file_path *string) error { 14 | region, account_id := to.RegionAccount() 15 | release.SetDefaults(region, account_id, "coinbase-step-deployer-") 16 | 17 | lambda_sha, err := to.SHA256File(*zip_file_path) 18 | if err != nil { 19 | return err 20 | } 21 | release.LambdaSHA256 = &lambda_sha 22 | 23 | // Interpolate variables for resource strings 24 | release.StateMachineJSON = to.InterpolateArnVariables( 25 | release.StateMachineJSON, 26 | release.AwsRegion, 27 | release.AwsAccountID, 28 | release.LambdaName, 29 | ) 30 | 31 | return nil 32 | } 33 | 34 | // PrepareReleaseBundle builds and uploads necessary info for a deploy 35 | func PrepareReleaseBundle(awsc aws.AwsClients, release *deployer.Release, zip_file_path *string) error { 36 | if err := PrepareRelease(release, zip_file_path); err != nil { 37 | return err 38 | } 39 | 40 | err := s3.PutFile( 41 | awsc.S3Client(release.AwsRegion, nil, nil), 42 | zip_file_path, 43 | release.Bucket, 44 | release.LambdaZipPath(), 45 | ) 46 | 47 | if err != nil { 48 | return 
err 49 | } 50 | 51 | // reset CreateAt because it can take a while to upload the lambda 52 | release.CreatedAt = to.Timep(time.Now()) 53 | 54 | // Uploading the Release to S3 to match SHAs 55 | if err := s3.PutStruct(awsc.S3Client(release.AwsRegion, nil, nil), release.Bucket, release.ReleasePath(), release); err != nil { 56 | return err 57 | } 58 | 59 | return nil 60 | } 61 | -------------------------------------------------------------------------------- /client/client_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/coinbase/step/aws/mocks" 8 | "github.com/coinbase/step/bifrost" 9 | "github.com/coinbase/step/deployer" 10 | "github.com/coinbase/step/machine" 11 | "github.com/coinbase/step/utils/to" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func Test_Client_PrepareReleaseBundle(t *testing.T) { 16 | awsc := mocks.MockAwsClients() 17 | release := &deployer.Release{ 18 | Release: bifrost.Release{ 19 | AwsRegion: to.Strp("project"), 20 | AwsAccountID: to.Strp("project"), 21 | ReleaseID: to.TimeUUID("release-"), 22 | CreatedAt: to.Timep(time.Now()), 23 | ProjectName: to.Strp("project"), 24 | ConfigName: to.Strp("project"), 25 | Bucket: to.Strp("project"), 26 | }, 27 | LambdaName: to.Strp("project"), 28 | StepFnName: to.Strp("project"), 29 | StateMachineJSON: to.Strp(machine.EmptyStateMachine), 30 | } 31 | 32 | err := PrepareReleaseBundle( 33 | awsc, 34 | release, 35 | to.Strp("../resources/empty_lambda.zip"), // Location to empty zip file 36 | ) 37 | 38 | assert.NoError(t, err) 39 | } 40 | -------------------------------------------------------------------------------- /client/deploy.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/coinbase/step/aws" 8 | "github.com/coinbase/step/bifrost" 9 | 
"github.com/coinbase/step/deployer" 10 | "github.com/coinbase/step/execution" 11 | "github.com/coinbase/step/utils/to" 12 | ) 13 | 14 | // Deploy takes release information and Calls the Step Deployer to deploy the release 15 | func Deploy(release *deployer.Release, zip_file_path *string, deployer_arn *string) error { 16 | awsc := &aws.Clients{} 17 | 18 | fmt.Println("Preparing Release Bundle") 19 | err := PrepareReleaseBundle(awsc, release, zip_file_path) 20 | if err != nil { 21 | return err 22 | } 23 | 24 | fmt.Println("Preparing Deploy") 25 | fmt.Println(to.PrettyJSONStr(release)) 26 | err = sendDeployToDeployer(awsc.SFNClient(nil, nil, nil), release.ReleaseID, release, deployer_arn) 27 | if err != nil { 28 | return err 29 | } 30 | 31 | return nil 32 | } 33 | 34 | // sendDeployToDeployer Calls the Step Deployer Step Function, 35 | // This function will wait for the execution to finish but will timeout after 20 seconds 36 | func sendDeployToDeployer(sfnc aws.SFNAPI, name *string, release *deployer.Release, deployer_arn *string) error { 37 | 38 | exec, err := execution.StartExecution(sfnc, deployer_arn, name, release) 39 | if err != nil { 40 | return err 41 | } 42 | 43 | fmt.Printf("\nStarting Deploy") 44 | 45 | exec.WaitForExecution(sfnc, 1, func(ed *execution.Execution, sd *execution.StateDetails, err error) error { 46 | if err != nil { 47 | return fmt.Errorf("Unexpected Error %v", err.Error()) 48 | } 49 | 50 | var release_error struct { 51 | Error *bifrost.ReleaseError `json:"error,omitempty"` 52 | } 53 | 54 | fmt.Printf("\rExecution: %v", *ed.Status) 55 | 56 | if sd.LastOutput != nil { 57 | json.Unmarshal([]byte(*sd.LastOutput), &release_error) 58 | 59 | if release_error.Error != nil { 60 | fmt.Printf("\nError: %v\nCause: %v\n", to.Strs(release_error.Error.Error), to.Strs(release_error.Error.Cause)) 61 | } 62 | } 63 | 64 | return nil 65 | }) 66 | 67 | return nil 68 | } 69 | -------------------------------------------------------------------------------- 
/deployer/README.md: -------------------------------------------------------------------------------- 1 | # Step Deployer 2 | 3 | deployer state machine 4 | 5 | The Step Deployer is an [AWS Step Function](https://docs.aws.amazon.com/step-functions/latest/dg/getting-started.html) that can deploy step functions, so it can recursively deploy itself. 6 | 7 | To create the necessary AWS resources you can use GeoEngineer which requires `ruby` and `terraform`: 8 | 9 | ```bash 10 | bundle install 11 | ./scripts/geo apply resources/step-deployer.rb 12 | ``` 13 | 14 | We prefer to use AWS credentials exported via [assume-role](https://github.com/coinbase/assume-role) but any AWS access keys will do: 15 | 16 | ```bash 17 | # Use AWS Creds or assume-role 18 | ./scripts/bootstrap_deployer 19 | ``` 20 | 21 | To update the deployer you can use: 22 | 23 | ```bash 24 | git pull # pull down new code 25 | ./scripts/deploy_deployer # recursive deployer 26 | ``` 27 | 28 | To use the deployer: 29 | 30 | ```bash 31 | step deploy -lambda \ 32 | -step \ 33 | -states 34 | ``` 35 | 36 | This will default the AWS region and account to those in the environment variables, the project and config names to tags on the lambda, and the lambda file to `./lambda.zip`. 37 | 38 | ### Implementation 39 | 40 | The tasks of the deployer are: 41 | 42 | 1. **Validate**: Validate the sent release bundle 43 | 2. **Lock**: grab a lock in S3 so others cannot deploy at the same time 44 | 3. **ValidateResources**: Validate the referenced resources exist and have the correct tags and paths 45 | 4. **Deploy**: Update the State Machine and Lambda, then release the Lock 46 | 5. **ReleaseLockFailure**: If something goes wrong, try to release the lock and fail 47 | 48 | The end states are: 49 | 50 | 1. **Success**: deployed correctly 51 | 2. **FailureClean**: something went wrong but it has recovered the previous good state 52 | 3. **FailureDirty**: something went wrong and it is not in a good state.
The existing step function, Lambda and/or lock require manual cleanup 53 | 54 | The limitations are: 55 | 56 | 1. **State machine size** must be less than 30Kb as it is sent as part of the step-function input. 57 | 2. **Lambda size** must be less than the RAM available to the `step-deployer` lambda as it validates the lambda SHA256 in memory 58 | 59 | ### Security 60 | 61 | Deployers are critical pieces of infrastructure as they may be used to compromise software they deploy. As such, we take security very seriously around the `step-deployer` and answer the following questions: 62 | 63 | 1. *Authentication*: Who can deploy? 64 | 2. *Authorization*: What can be deployed? 65 | 3. *Replay* and *Man-in-the-middle (MITM)*: Can some unauthorized person edit or reuse a release to change what is deployed? 66 | 4. *Audit*: Who has done what, and when? 67 | 68 | #### Authentication 69 | 70 | The central authentication mechanisms are the AWS IAM permissions for step functions, lambda, and S3. 71 | 72 | By limiting the `lambda:UpdateFunctionCode`, `lambda:UpdateFunctionConfiguration`, `lambda:Invoke*` and `states:UpdateStateMachine` permissions the `step-deployer` function becomes the only way to deploy. Once this is the case, limiting permissions to `states:StartExecution` of the `step-deployer` directly limits who can deploy. 73 | 74 | Ensuring the `step-deployer` Lambda's role can only access a single S3 bucket with: 75 | 76 | ``` 77 | { 78 | "Effect": "Allow", 79 | "Action": [ 80 | "s3:GetObject*", "s3:PutObject*", 81 | "s3:List*", "s3:DeleteObject*" 82 | ], 83 | "Resource": [ 84 | "arn:aws:s3:::#{s3_bucket_name}/*", 85 | "arn:aws:s3:::#{s3_bucket_name}" 86 | ] 87 | }, 88 | { 89 | "Effect": "Deny", 90 | "Action": ["s3:*"], 91 | "NotResource": [ 92 | "arn:aws:s3:::#{s3_bucket_name}/*", 93 | "arn:aws:s3:::#{s3_bucket_name}" 94 | ] 95 | }, 96 | ``` 97 | 98 | Further restricts who can deploy to those that also can `s3:PutObject` to the bucket.
99 | 100 | Who can execute the step function, and who can upload to S3 are the two permissions that guard deploys. Additionally, if you separate those two permissions, you gain extra security, e.g. by only allowing your CI/CD pipeline to upload releases, and developers to execute the step function you can ensure only valid builds are ever deployed. 101 | 102 | #### Authorization 103 | 104 | We use tags and paths to restrict the resources that the `step-deployer` can deploy to. 105 | 106 | The lambda function must have a `ProjectName` and `ConfigName` tag that match the release, and a `DeployWith` tag equal to `"step-deployer"`. 107 | 108 | Step functions don't support tags, so the path on their role must be equal to `/step///`. 109 | 110 | Assets uploaded to S3 are in the path `//` so limiting who can `s3:PutObject` to a path can be used to limit what project-configs they can deploy. 111 | 112 | #### Replay and MITM 113 | 114 | For each release, the client generates a `release_id`, a `created_at` date, and a SHA256 of the lambda file, and also uploads the release to S3. 115 | 116 | The `step-deployer` will reject any request where the `created_at` date is not recent, the lambda's SHA does not match the uploaded zip, or the release sent to the step function and S3 don't match. This means that if a user can invoke the step function, but not upload to S3 (or vice-versa) it is not possible to deploy old or malicious code. 117 | 118 | #### Audit 119 | 120 | Working out what happened when is very useful for debugging and security response. Step functions make it easy to see the history of all executions in the AWS console and via API. S3 can log all access to CloudTrail, so collecting from these two sources will show all information about a deploy. 121 | 122 | ### Continuing Deployment 123 | 124 | Some TODOs for the deployer are: 125 | 126 | 1. Automated rollback on a bad deploy 127 | 1.
Assume-role sts into other accounts to deploy there, so only one `step-deployer` is needed for many accounts. 128 | 1. Health checking the Lambdas and Step Functions, if they have a health check. 129 | -------------------------------------------------------------------------------- /deployer/fuzz_test.go: -------------------------------------------------------------------------------- 1 | package deployer 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/coinbase/step/aws/mocks" 7 | "github.com/coinbase/step/machine" 8 | "github.com/coinbase/step/utils/to" 9 | fuzz "github.com/google/gofuzz" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func Test_Release_Basic_Fuzz(t *testing.T) { 14 | for i := 0; i < 20; i++ { 15 | f := fuzz.New() 16 | var release Release 17 | f.Fuzz(&release) 18 | 19 | assertNoPanic(t, &release) 20 | } 21 | } 22 | 23 | func Test_Release_ValidSM_Fuzz(t *testing.T) { 24 | for i := 0; i < 20; i++ { 25 | f := fuzz.New() 26 | var release Release 27 | f.Fuzz(&release) 28 | 29 | release.StateMachineJSON = to.Strp(machine.EmptyStateMachine) 30 | assertNoPanic(t, &release) 31 | } 32 | } 33 | 34 | func assertNoPanic(t *testing.T, release *Release) { 35 | state_machine := createTestStateMachine(t, mocks.MockAwsClients()) 36 | 37 | exec, err := state_machine.Execute(release) 38 | if err != nil { 39 | assert.NotRegexp(t, "Panic", err.Error()) 40 | } 41 | 42 | assert.NotRegexp(t, "Panic", exec.OutputJSON) 43 | } 44 | -------------------------------------------------------------------------------- /deployer/handlers.go: -------------------------------------------------------------------------------- 1 | /* 2 | The deployer package contains the Step Deployer service 3 | that is a Step Function that Deploys Step Functions. 4 | 5 | It also contains a client for messaging and bootstrapping the Step Deployer. 
6 | */ 7 | package deployer 8 | 9 | import ( 10 | "context" 11 | "fmt" 12 | 13 | "github.com/coinbase/step/aws" 14 | "github.com/coinbase/step/aws/dynamodb" 15 | "github.com/coinbase/step/errors" 16 | "github.com/coinbase/step/utils/to" 17 | ) 18 | 19 | //////// 20 | // ERRORS 21 | /////// 22 | 23 | type DeploySFNError struct { 24 | err error 25 | } 26 | 27 | type DeployLambdaError struct { 28 | err error 29 | } 30 | 31 | func (e DeploySFNError) Error() string { 32 | return fmt.Sprintf("DeploySFNError: %v", e.err.Error()) 33 | } 34 | 35 | func (e DeployLambdaError) Error() string { 36 | return fmt.Sprintf("DeployLambdaError: %v", e.err.Error()) 37 | } 38 | 39 | //////////// 40 | // HANDLERS 41 | //////////// 42 | 43 | var assumed_role = to.Strp("coinbase-step-deployer-assumed") 44 | 45 | func ValidateHandler(awsc aws.AwsClients) interface{} { 46 | return func(ctx context.Context, release *Release) (*Release, error) { 47 | // Override any attributes set by the client 48 | release.ReleaseSHA256 = to.SHA256Struct(release) 49 | release.WipeControlledValues() 50 | 51 | region, account := to.AwsRegionAccountFromContext(ctx) 52 | release.SetDefaults(region, account, "coinbase-step-deployer-") 53 | 54 | // Validate the attributes for the release 55 | if err := release.Validate(awsc.S3Client(release.AwsRegion, nil, nil)); err != nil { 56 | return nil, errors.BadReleaseError{err.Error()} 57 | } 58 | 59 | return release, nil 60 | } 61 | } 62 | 63 | func LockHandler(awsc aws.AwsClients) interface{} { 64 | return func(ctx context.Context, release *Release) (*Release, error) { 65 | // returns LockExistsError, LockError 66 | locker := dynamodb.NewDynamoDBLocker(awsc.DynamoDBClient(nil, nil, nil)) 67 | return release, release.GrabLocks(awsc.S3Client(release.AwsRegion, nil, nil), locker, getLockTableNameFromContext(ctx, "-locks")) 68 | } 69 | } 70 | 71 | func ValidateResourcesHandler(awsc aws.AwsClients) interface{} { 72 | return func(ctx context.Context, release *Release) 
(*Release, error) { 73 | // Validate the Resources for the release 74 | if err := release.ValidateResources(awsc.LambdaClient(release.AwsRegion, release.AwsAccountID, assumed_role), awsc.SFNClient(release.AwsRegion, release.AwsAccountID, assumed_role)); err != nil { 75 | return nil, errors.BadReleaseError{err.Error()} 76 | } 77 | 78 | return release, nil 79 | } 80 | } 81 | 82 | func DeployHandler(awsc aws.AwsClients) interface{} { 83 | return func(ctx context.Context, release *Release) (*Release, error) { 84 | 85 | // Update Step Function first because State Machine if it fails we can recover 86 | if err := release.DeployStepFunction(awsc.SFNClient(release.AwsRegion, release.AwsAccountID, assumed_role)); err != nil { 87 | return nil, DeploySFNError{err} 88 | } 89 | 90 | if err := release.DeployLambda(awsc.LambdaClient(release.AwsRegion, release.AwsAccountID, assumed_role), awsc.S3Client(release.AwsRegion, nil, nil)); err != nil { 91 | return nil, DeployLambdaError{err} 92 | } 93 | 94 | release.Success = to.Boolp(true) 95 | locker := dynamodb.NewDynamoDBLocker(awsc.DynamoDBClient(nil, nil, nil)) 96 | release.UnlockRoot(awsc.S3Client(release.AwsRegion, nil, nil), locker, getLockTableNameFromContext(ctx, "-locks")) 97 | 98 | return release, nil 99 | } 100 | } 101 | 102 | func ReleaseLockFailureHandler(awsc aws.AwsClients) interface{} { 103 | return func(ctx context.Context, release *Release) (*Release, error) { 104 | locker := dynamodb.NewDynamoDBLocker(awsc.DynamoDBClient(nil, nil, nil)) 105 | if err := release.UnlockRoot(awsc.S3Client(release.AwsRegion, nil, nil), locker, getLockTableNameFromContext(ctx, "-locks")); err != nil { 106 | return nil, errors.LockError{err.Error()} 107 | } 108 | 109 | return release, nil 110 | } 111 | } 112 | 113 | func getLockTableNameFromContext(ctx context.Context, postfix string) string { 114 | _, _, lambdaName := to.AwsRegionAccountLambdaNameFromContext(ctx) 115 | return fmt.Sprintf("%s%s", lambdaName, postfix) 116 | } 117 | 
-------------------------------------------------------------------------------- /deployer/helpers_test.go: -------------------------------------------------------------------------------- 1 | package deployer 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "testing" 7 | "time" 8 | 9 | "github.com/aws/aws-sdk-go/service/lambda" 10 | "github.com/aws/aws-sdk-go/service/sfn" 11 | "github.com/coinbase/step/aws" 12 | "github.com/coinbase/step/aws/mocks" 13 | "github.com/coinbase/step/aws/s3" 14 | "github.com/coinbase/step/bifrost" 15 | "github.com/coinbase/step/machine" 16 | "github.com/coinbase/step/utils/to" 17 | "github.com/stretchr/testify/assert" 18 | ) 19 | 20 | //////// 21 | // RELEASE 22 | //////// 23 | 24 | func MockRelease() *Release { 25 | return &Release{ 26 | Release: bifrost.Release{ 27 | AwsAccountID: to.Strp("00000000"), 28 | ReleaseID: to.Strp("release-1"), 29 | ProjectName: to.Strp("project"), 30 | ConfigName: to.Strp("development"), 31 | CreatedAt: to.Timep(time.Now()), 32 | Metadata: map[string]string{"User": "User@user.com"}, 33 | }, 34 | LambdaName: to.Strp("lambdaname"), 35 | StepFnName: to.Strp("stepfnname"), 36 | StateMachineJSON: to.Strp(machine.EmptyStateMachine), 37 | } 38 | } 39 | 40 | func MockAwsClients(r *Release) *mocks.MockClients { 41 | awsc := mocks.MockAwsClients() 42 | 43 | awsc.Lambda.ListTagsResp = &lambda.ListTagsOutput{ 44 | Tags: map[string]*string{"ProjectName": r.ProjectName, "ConfigName": r.ConfigName, "DeployWith": to.Strp("step-deployer")}, 45 | } 46 | 47 | awsc.SFN.DescribeStateMachineResp = &sfn.DescribeStateMachineOutput{ 48 | RoleArn: to.Strp(fmt.Sprintf("arn:aws:iam::000000000000:role/step/%v/%v/role-name", *r.ProjectName, *r.ConfigName)), 49 | } 50 | 51 | lambda_zip_file_contents := "lambda_zip" 52 | awsc.S3.AddGetObject(*r.LambdaZipPath(), lambda_zip_file_contents, nil) 53 | 54 | if r.LambdaSHA256 == nil { 55 | r.LambdaSHA256 = to.Strp(to.SHA256Str(&lambda_zip_file_contents)) 56 | } 57 | 58 | raw, _ := 
json.Marshal(r) 59 | 60 | account_id := r.AwsAccountID 61 | if account_id == nil { 62 | account_id = to.Strp("000000000000") 63 | } 64 | 65 | awsc.S3.AddGetObject(fmt.Sprintf("%v/%v/%v/%v/release", *account_id, *r.ProjectName, *r.ConfigName, *r.ReleaseID), string(raw), nil) 66 | 67 | return awsc 68 | } 69 | 70 | //////// 71 | // State Machine 72 | //////// 73 | 74 | func createTestStateMachine(t *testing.T, awsc *mocks.MockClients) *machine.StateMachine { 75 | stateMachine, err := StateMachine() 76 | assert.NoError(t, err) 77 | 78 | tfs := CreateTaskFunctions(awsc) 79 | 80 | err = stateMachine.SetTaskFnHandlers(tfs) 81 | assert.NoError(t, err) 82 | 83 | return stateMachine 84 | } 85 | 86 | func assertNoRootLock(t *testing.T, awsc aws.AwsClients, release *Release) { 87 | _, err := s3.Get(awsc.S3Client(release.AwsRegion, nil, nil), release.Bucket, release.RootLockPath()) 88 | assert.Error(t, err) // Not found error 89 | assert.IsType(t, &s3.NotFoundError{}, err) 90 | } 91 | 92 | func assertNoRootLockWithReleseLock(t *testing.T, awsc aws.AwsClients, release *Release) { 93 | assertNoRootLock(t, awsc, release) 94 | 95 | _, err := s3.Get(awsc.S3Client(release.AwsRegion, nil, nil), release.Bucket, release.ReleaseLockPath()) 96 | assert.NoError(t, err) // Not error 97 | } 98 | 99 | func assertNoRootLockNoReleseLock(t *testing.T, awsc aws.AwsClients, release *Release) { 100 | assertNoRootLock(t, awsc, release) 101 | 102 | _, err := s3.Get(awsc.S3Client(release.AwsRegion, nil, nil), release.Bucket, release.ReleaseLockPath()) 103 | assert.Error(t, err) // Not found error 104 | assert.IsType(t, &s3.NotFoundError{}, err) 105 | } 106 | -------------------------------------------------------------------------------- /deployer/machine.go: -------------------------------------------------------------------------------- 1 | package deployer 2 | 3 | import ( 4 | "github.com/coinbase/step/aws" 5 | "github.com/coinbase/step/handler" 6 | "github.com/coinbase/step/machine" 7 | ) 8 | 9 | 
// StateMachine returns the StateMachine for the deployer 10 | func StateMachine() (*machine.StateMachine, error) { 11 | return machine.FromJSON([]byte(`{ 12 | "Comment": "Step Function Deployer", 13 | "StartAt": "Validate", 14 | "States": { 15 | "Validate": { 16 | "Type": "TaskFn", 17 | "Resource": "arn:aws:lambda:{{aws_region}}:{{aws_account}}:function:{{lambda_name}}", 18 | "Comment": "Validate and Set Defaults", 19 | "Next": "Lock", 20 | "Catch": [ 21 | { 22 | "Comment": "Bad Release or Error GoTo end", 23 | "ErrorEquals": ["States.ALL"], 24 | "ResultPath": "$.error", 25 | "Next": "FailureClean" 26 | } 27 | ] 28 | }, 29 | "Lock": { 30 | "Type": "TaskFn", 31 | "Resource": "arn:aws:lambda:{{aws_region}}:{{aws_account}}:function:{{lambda_name}}", 32 | "Comment": "Grab Lock", 33 | "Next": "ValidateResources", 34 | "Catch": [ 35 | { 36 | "Comment": "Something else is deploying", 37 | "ErrorEquals": ["LockExistsError"], 38 | "ResultPath": "$.error", 39 | "Next": "FailureClean" 40 | }, 41 | { 42 | "Comment": "Try Release Lock Then Fail", 43 | "ErrorEquals": ["States.ALL"], 44 | "ResultPath": "$.error", 45 | "Next": "ReleaseLockFailure" 46 | } 47 | ] 48 | }, 49 | "ValidateResources": { 50 | "Type": "TaskFn", 51 | "Resource": "arn:aws:lambda:{{aws_region}}:{{aws_account}}:function:{{lambda_name}}", 52 | "Comment": "ValidateResources", 53 | "Next": "Deploy", 54 | "Catch": [ 55 | { 56 | "Comment": "Try Release Lock Then Fail", 57 | "ErrorEquals": ["States.ALL"], 58 | "ResultPath": "$.error", 59 | "Next": "ReleaseLockFailure" 60 | } 61 | ] 62 | }, 63 | "Deploy": { 64 | "Type": "TaskFn", 65 | "Resource": "arn:aws:lambda:{{aws_region}}:{{aws_account}}:function:{{lambda_name}}", 66 | "Comment": "Upload Step-Function and Lambda", 67 | "Next": "Success", 68 | "Catch": [ 69 | { 70 | "Comment": "Unsure of State, Leave Lock and Fail", 71 | "ErrorEquals": ["DeploySFNError"], 72 | "ResultPath": "$.error", 73 | "Next": "ReleaseLockFailure" 74 | }, 75 | { 76 | "Comment": "Unsure of 
State, Leave Lock and Fail", 77 | "ErrorEquals": ["States.ALL"], 78 | "ResultPath": "$.error", 79 | "Next": "FailureDirty" 80 | } 81 | ] 82 | }, 83 | "ReleaseLockFailure": { 84 | "Type": "TaskFn", 85 | "Resource": "arn:aws:lambda:{{aws_region}}:{{aws_account}}:function:{{lambda_name}}", 86 | "Comment": "Release the Lock and Fail", 87 | "Next": "FailureClean", 88 | "Retry": [ { 89 | "Comment": "Keep trying to Release", 90 | "ErrorEquals": ["States.ALL"], 91 | "MaxAttempts": 3, 92 | "IntervalSeconds": 30 93 | }], 94 | "Catch": [{ 95 | "ErrorEquals": ["States.ALL"], 96 | "ResultPath": "$.error", 97 | "Next": "FailureDirty" 98 | }] 99 | }, 100 | "FailureClean": { 101 | "Comment": "Deploy Failed Cleanly", 102 | "Type": "Fail", 103 | "Error": "NotifyError" 104 | }, 105 | "FailureDirty": { 106 | "Comment": "Deploy Failed, Resources left in Bad State, ALERT!", 107 | "Type": "Fail", 108 | "Error": "AlertError" 109 | }, 110 | "Success": { 111 | "Type": "Succeed" 112 | } 113 | } 114 | }`)) 115 | } 116 | 117 | // TaskHandlers returns 118 | func TaskHandlers() *handler.TaskHandlers { 119 | return CreateTaskFunctions(&aws.Clients{}) 120 | } 121 | 122 | // CreateTaskFunctions returns 123 | func CreateTaskFunctions(awsc aws.AwsClients) *handler.TaskHandlers { 124 | tm := handler.TaskHandlers{} 125 | tm["Validate"] = ValidateHandler(awsc) 126 | tm["Lock"] = LockHandler(awsc) 127 | tm["ValidateResources"] = ValidateResourcesHandler(awsc) 128 | tm["Deploy"] = DeployHandler(awsc) 129 | tm["ReleaseLockFailure"] = ReleaseLockFailureHandler(awsc) 130 | return &tm 131 | } 132 | -------------------------------------------------------------------------------- /deployer/release.go: -------------------------------------------------------------------------------- 1 | package deployer 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/aws/aws-sdk-go/service/lambda" 7 | "github.com/aws/aws-sdk-go/service/sfn" 8 | "github.com/coinbase/step/aws" 9 | "github.com/coinbase/step/aws/s3" 10 | 
"github.com/coinbase/step/bifrost" 11 | "github.com/coinbase/step/machine" 12 | "github.com/coinbase/step/utils/is" 13 | "github.com/coinbase/step/utils/to" 14 | ) 15 | 16 | // Release is the Data Structure passed between Client and Deployer 17 | type Release struct { 18 | bifrost.Release 19 | 20 | // Deploy Releases 21 | LambdaName *string `json:"lambda_name,omitempty"` // Lambda Name 22 | LambdaSHA256 *string `json:"lambda_sha256,omitempty"` // Lambda SHA256 Zip file 23 | StepFnName *string `json:"step_fn_name,omitempty"` // Step Function Name 24 | 25 | StateMachineJSON *string `json:"state_machine_json,omitempty"` 26 | } 27 | 28 | ////////// 29 | // Validations 30 | ////////// 31 | 32 | func (r *Release) Validate(s3c aws.S3API) error { 33 | if err := r.Release.Validate(s3c, &Release{}); err != nil { 34 | return err 35 | } 36 | 37 | if is.EmptyStr(r.LambdaName) { 38 | return fmt.Errorf("LambdaName must be defined") 39 | } 40 | 41 | if is.EmptyStr(r.LambdaSHA256) { 42 | return fmt.Errorf("LambdaSHA256 must be defined") 43 | } 44 | 45 | if is.EmptyStr(r.StepFnName) { 46 | return fmt.Errorf("StepFnName must be defined") 47 | } 48 | 49 | if is.EmptyStr(r.StateMachineJSON) { 50 | return fmt.Errorf("StateMachineJSON must be defined") 51 | } 52 | 53 | // Validate State machine 54 | if err := machine.Validate(r.StateMachineJSON); err != nil { 55 | return fmt.Errorf("StateMachineJSON invalid with '%v'", err.Error()) 56 | } 57 | 58 | if err := r.deployLambdaInput(to.ABytep([]byte{})).Validate(); err != nil { 59 | return err 60 | } 61 | 62 | if err := r.deployStepFunctionInput().Validate(); err != nil { 63 | return err 64 | } 65 | 66 | if err := r.ValidateLambdaSHA(s3c); err != nil { 67 | return err 68 | } 69 | 70 | return nil 71 | } 72 | 73 | // Resource Validations 74 | 75 | func (r *Release) ValidateResources(lambdac aws.LambdaAPI, sfnc aws.SFNAPI) error { 76 | if err := r.ValidateLambdaFunctionTags(lambdac); err != nil { 77 | return err 78 | } 79 | 80 | if err := 
r.ValidateStepFunctionPath(sfnc); err != nil { 81 | return err 82 | } 83 | 84 | return nil 85 | } 86 | 87 | func (r *Release) ValidateLambdaFunctionTags(lambdac aws.LambdaAPI) error { 88 | project, config, deployer, err := r.LambdaProjectConfigDeployerTags(lambdac) 89 | if err != nil { 90 | return err 91 | } 92 | 93 | if project == nil || config == nil || deployer == nil { 94 | return fmt.Errorf("ProjectName, ConfigName and or DeployWith tag on lambda is nil") 95 | } 96 | 97 | if *r.ProjectName != *project { 98 | return fmt.Errorf("Lambda ProjectName tag incorrect, expecting %v has %v", *r.ProjectName, *project) 99 | } 100 | 101 | if *r.ConfigName != *config { 102 | return fmt.Errorf("Lambda ConfigName tag incorrect, expecting %v has %v", *r.ConfigName, *config) 103 | } 104 | 105 | if "step-deployer" != *deployer { 106 | return fmt.Errorf("Lambda DeployWith tag incorrect, expecting step-deployer has %v", *deployer) 107 | } 108 | 109 | return nil 110 | } 111 | 112 | func (r *Release) ValidateStepFunctionPath(sfnc aws.SFNAPI) error { 113 | out, err := sfnc.DescribeStateMachine(&sfn.DescribeStateMachineInput{StateMachineArn: r.StepArn()}) 114 | 115 | if err != nil { 116 | return err 117 | } 118 | 119 | if out == nil || out.RoleArn == nil { 120 | return fmt.Errorf("Unknown Step Function Error") 121 | } 122 | 123 | path := to.ArnPath(*out.RoleArn) 124 | 125 | expected := fmt.Sprintf("/step/%v/%v/", *r.ProjectName, *r.ConfigName) 126 | if path != expected { 127 | return fmt.Errorf("Incorrect Step Function Role Path, expecting %v, got %v", expected, path) 128 | } 129 | 130 | return nil 131 | } 132 | 133 | func (r *Release) ValidateLambdaSHA(s3c aws.S3API) error { 134 | sha, err := s3.GetSHA256(s3c, r.Bucket, r.LambdaZipPath()) 135 | if err != nil { 136 | return err 137 | } 138 | 139 | if sha != *r.LambdaSHA256 { 140 | return fmt.Errorf("Lambda SHA mismatch, expecting %v, got %v", *r.LambdaSHA256, sha) 141 | } 142 | 143 | return nil 144 | } 145 | 146 | func (r *Release) 
LambdaProjectConfigDeployerTags(lambdac aws.LambdaAPI) (*string, *string, *string, error) { 147 | out, err := lambdac.ListTags(&lambda.ListTagsInput{ 148 | Resource: r.LambdaArn(), 149 | }) 150 | 151 | if err != nil { 152 | return nil, nil, nil, err 153 | } 154 | 155 | if out == nil { 156 | return nil, nil, nil, fmt.Errorf("Unknown Lambda Tags Error") 157 | } 158 | 159 | return out.Tags["ProjectName"], out.Tags["ConfigName"], out.Tags["DeployWith"], nil 160 | } 161 | 162 | ////////// 163 | // AWS Methods 164 | ////////// 165 | 166 | func (release *Release) deployLambdaInput(zip *[]byte) *lambda.UpdateFunctionCodeInput { 167 | return &lambda.UpdateFunctionCodeInput{ 168 | FunctionName: release.LambdaArn(), 169 | ZipFile: *zip, 170 | } 171 | } 172 | 173 | // DeployLambdaCode 174 | func (release *Release) DeployLambdaCode(lambdaClient aws.LambdaAPI, zip *[]byte) error { 175 | _, err := lambdaClient.UpdateFunctionCode(release.deployLambdaInput(zip)) 176 | return err 177 | } 178 | 179 | // DeployLambda uploads new Code to the Lambda 180 | func (release *Release) DeployLambda(lambdaClient aws.LambdaAPI, s3c aws.S3API) error { 181 | // Download and pass Zip file because lambda might be in another region or account 182 | zip, err := s3.Get(s3c, release.Bucket, release.LambdaZipPath()) 183 | if err != nil { 184 | return err 185 | } 186 | 187 | err = release.DeployLambdaCode(lambdaClient, zip) 188 | if err != nil { 189 | return err 190 | } 191 | 192 | return nil 193 | } 194 | 195 | func (release *Release) deployStepFunctionInput() *sfn.UpdateStateMachineInput { 196 | return &sfn.UpdateStateMachineInput{ 197 | Definition: to.Strp(to.PrettyJSONStr(release.StateMachineJSON)), 198 | StateMachineArn: release.StepArn(), 199 | } 200 | } 201 | 202 | // DeployStepFunction updates the step function State Machine 203 | func (release *Release) DeployStepFunction(sfnClient aws.SFNAPI) error { 204 | _, err := sfnClient.UpdateStateMachine(release.deployStepFunctionInput()) 205 | 206 | if 
err != nil { 207 | return err 208 | } 209 | 210 | return nil 211 | } 212 | 213 | /////// 214 | // Lambda 215 | /////// 216 | 217 | func (release *Release) LambdaZipPath() *string { 218 | s := fmt.Sprintf("%v/lambda.zip", *release.ReleaseDir()) 219 | return &s 220 | } 221 | 222 | func (release *Release) LambdaArn() *string { 223 | return to.LambdaArn(release.AwsRegion, release.AwsAccountID, release.LambdaName) 224 | } 225 | 226 | /////// 227 | // Step 228 | /////// 229 | 230 | func (release *Release) StepArn() *string { 231 | return to.StepArn(release.AwsRegion, release.AwsAccountID, release.StepFnName) 232 | } 233 | -------------------------------------------------------------------------------- /deployer/release_parsing.go: -------------------------------------------------------------------------------- 1 | package deployer 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | ) 7 | 8 | // The goal here is to raise an error if a key is sent that is not supported. 9 | // This should stop many dangerous problems, like misspelling a parameter. 
10 | type releaseAlias Release 11 | 12 | // But the problem is that there are exceptions that we have 13 | type XRelease struct { 14 | releaseAlias 15 | Task *string // Do not include the Task because that can be implemented 16 | } 17 | 18 | // UnmarshalJSON should error if there is something unexpected 19 | func (release *Release) UnmarshalJSON(data []byte) error { 20 | var releaseWithExceptions XRelease 21 | dec := json.NewDecoder(bytes.NewReader(data)) 22 | dec.DisallowUnknownFields() // Force 23 | 24 | if err := dec.Decode(&releaseWithExceptions); err != nil { 25 | return err 26 | } 27 | 28 | *release = Release(releaseWithExceptions.releaseAlias) 29 | return nil 30 | } 31 | -------------------------------------------------------------------------------- /deployer/release_test.go: -------------------------------------------------------------------------------- 1 | package deployer 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | 8 | "github.com/coinbase/step/aws/mocks" 9 | "github.com/coinbase/step/utils/to" 10 | ) 11 | 12 | func Test_Release_DeployStepFunction(t *testing.T) { 13 | sfnClient := &mocks.MockSFNClient{} 14 | r := MockRelease() 15 | 16 | err := r.DeployStepFunction(sfnClient) 17 | assert.NoError(t, err) 18 | } 19 | 20 | func Test_Release_DeployLambda(t *testing.T) { 21 | lambdaClient := &mocks.MockLambdaClient{} 22 | s3c := &mocks.MockS3Client{} 23 | 24 | r := MockRelease() 25 | r.Bucket = to.Strp("bucket") 26 | s3c.AddGetObject(*r.LambdaZipPath(), "", nil) 27 | 28 | err := r.DeployLambda(lambdaClient, s3c) 29 | assert.NoError(t, err) 30 | 31 | } 32 | -------------------------------------------------------------------------------- /errors/errors.go: -------------------------------------------------------------------------------- 1 | // errors has a list of common errors and error functions 2 | package errors 3 | 4 | import ( 5 | "fmt" 6 | ) 7 | 8 | // 9 | // General Errors that represent levels of action to be taken 10 
//
// General Errors that represent levels of action to be taken
//

// AlertError is a general error carrying a free-form cause.
type AlertError struct {
	Cause string
}

// Error returns the cause prefixed with "AlertError: ".
func (e AlertError) Error() string {
	return fmt.Sprintf("AlertError: %v", e.Cause)
}

// NotifyError is a general error carrying a free-form cause.
type NotifyError struct {
	Cause string
}

// Error returns the cause prefixed with "NotifyError: ".
func (e NotifyError) Error() string {
	return fmt.Sprintf("NotifyError: %v", e.Cause)
}

// LogError is a general error carrying a free-form cause.
type LogError struct {
	Cause string
}

// Error returns the cause prefixed with "LogError: ".
func (e LogError) Error() string {
	return fmt.Sprintf("LogError: %v", e.Cause)
}

//
// Low Level Step Errors
//

// UnmarshalError is a low-level step error carrying a free-form cause.
type UnmarshalError struct {
	Cause string
}

// Error returns the cause prefixed with "UnmarshalError: ".
func (e UnmarshalError) Error() string {
	return fmt.Sprintf("UnmarshalError: %v", e.Cause)
}

// PanicError is a low-level step error carrying a free-form cause.
type PanicError struct {
	Cause string
}

// Error returns the cause prefixed with "PanicError: ".
func (e PanicError) Error() string {
	return fmt.Sprintf("PanicError: %v", e.Cause)
}

//
// Specific Deploy/Release errors
//

// BadReleaseError error
type BadReleaseError struct {
	Cause string
}

// Error returns the cause prefixed with "BadReleaseError: ".
func (e BadReleaseError) Error() string {
	return fmt.Sprintf("BadReleaseError: %v", e.Cause)
}

// LockExistsError error
type LockExistsError struct {
	Cause string
}

// Error returns the cause prefixed with "LockExistsError: ".
func (e LockExistsError) Error() string {
	return fmt.Sprintf("LockExistsError: %v", e.Cause)
}

// LockError error
type LockError struct {
	Cause string
}

// Error returns the cause prefixed with "LockError: ".
func (e LockError) Error() string {
	return fmt.Sprintf("LockError: %v", e.Cause)
}

// DeployError error
type DeployError struct {
	Cause string
}

// Error returns the cause prefixed with "DeployError: ".
func (e DeployError) Error() string {
	return fmt.Sprintf("DeployError: %v", e.Cause)
}

// HealthError error
type HealthError struct {
	Cause string
}

// Error returns the cause prefixed with "HealthError: ".
func (e HealthError) Error() string {
	return fmt.Sprintf("HealthError: %v", e.Cause)
}

// HaltError error
type HaltError struct {
	Cause string
}

// Error returns the cause prefixed with "HaltError: ".
func (e HaltError) Error() string {
	return fmt.Sprintf("HaltError: %v", e.Cause)
}

// CleanUpError error
type CleanUpError struct {
	Cause string
}

// Error returns the cause prefixed with "CleanUpError: ".
func (e CleanUpError) Error() string {
	return fmt.Sprintf("CleanUpError: %v", e.Cause)
}

// throw writes err's message to standard output (no trailing newline) and
// returns err unchanged.
func throw(err error) error {
	// Print, not Printf: the message is data, not a format string. With
	// Printf any '%' in the message was expanded as a formatting verb and
	// printed as a "%!verb(MISSING)" artefact.
	fmt.Print(err.Error())
	return err
}
| }, 77 | "Fail": { 78 | "Type": "Fail", 79 | "Error": "ERROR" 80 | }, 81 | "Succeed": { 82 | "Type": "Succeed" 83 | }, 84 | "Parallel": { 85 | "Type": "Parallel" 86 | }, 87 | "Wait": { 88 | "Type": "Wait", 89 | "End": true, 90 | "Seconds": 10 91 | } 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /examples/bad_path.json: -------------------------------------------------------------------------------- 1 | { 2 | "Comment": "Adds some coordinates to the input", 3 | "StartAt": "Coords", 4 | "States": { 5 | "Coords": { 6 | "Type": "Pass", 7 | "Result": { 8 | "x": 0.1337, 9 | "y": 3.14159 10 | }, 11 | "ResultPath": "$.", 12 | "End": true 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /examples/bad_type.json: -------------------------------------------------------------------------------- 1 | { 2 | "StartAt": "Start", 3 | "States": { 4 | "Start": { 5 | "Type": "NOT_A_TYPE" 6 | } 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /examples/bad_unknown_state.json: -------------------------------------------------------------------------------- 1 | { 2 | "StartAt": "Start", 3 | "States": { 4 | "NotStart": { 5 | "Type": "Pass", 6 | "End": true 7 | } 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /examples/basic_choice.json: -------------------------------------------------------------------------------- 1 | { 2 | "Comment": "Adds some coordinates to the input", 3 | "StartAt": "ChoiceStateX", 4 | "States": { 5 | "ChoiceStateX": { 6 | "Type": "Choice", 7 | "Choices": [ 8 | { 9 | "Not": { 10 | "Variable": "$.type", 11 | "StringEquals": "Private" 12 | }, 13 | "Next": "Public" 14 | }, 15 | { 16 | "Variable": "$.value", 17 | "NumericEquals": 0, 18 | "Next": "ValueIsZero" 19 | }, 20 | { 21 | "And": [ 22 | { 23 | "Variable": "$.value", 24 | "NumericGreaterThanEquals": 20 25 | 
}, 26 | { 27 | "Variable": "$.value", 28 | "NumericLessThan": 30 29 | } 30 | ], 31 | "Next": "ValueInTwenties" 32 | } 33 | ], 34 | "Default": "DefaultState" 35 | }, 36 | "Public": { 37 | "Type": "Pass", 38 | "Next": "NextState" 39 | }, 40 | "ValueIsZero": { 41 | "Type": "Pass", 42 | "Next": "NextState" 43 | }, 44 | "ValueInTwenties": { 45 | "Type": "Pass", 46 | "Next": "NextState" 47 | }, 48 | "DefaultState": { 49 | "Type": "Fail", 50 | "Error": "ERROR", 51 | "Cause": "No Matches!" 52 | } 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /examples/basic_pass.json: -------------------------------------------------------------------------------- 1 | { 2 | "Comment": "Adds some coordinates to the input", 3 | "StartAt": "Coords", 4 | "States": { 5 | "Coords": { 6 | "Type": "Pass", 7 | "Result": { 8 | "x": 0.1337, 9 | "y": 3.14159 10 | }, 11 | "ResultPath": "$.coords", 12 | "End": true 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /examples/builder.json: -------------------------------------------------------------------------------- 1 | { 2 | "Comment": "Builder Example", 3 | "StartAt": "FetchValidateCreateFn", 4 | "States": { 5 | "FetchValidateCreateFn": { 6 | "Type": "Pass", 7 | "Result": "FetchValidateCreate", 8 | "ResultPath": "$.Task", 9 | "Next": "FetchValidateCreate" 10 | }, 11 | "FetchValidateCreate": { 12 | "Type": "Task", 13 | "Comment": "Fetch, Validate, Create Resources", 14 | "Resource": "go://localhost/FetchValidateCreate", 15 | "Next": "WaitForBuilt" 16 | }, 17 | "WaitForBuilt": { 18 | "Type": "Wait", 19 | "Seconds" : 20, 20 | "Next": "CheckBuiltFn" 21 | }, 22 | "CheckBuiltFn": { 23 | "Type": "Pass", 24 | "Result": "CheckBuilt", 25 | "ResultPath": "$.Task", 26 | "Next": "CheckBuilt" 27 | }, 28 | "CheckBuilt": { 29 | "Type": "Task", 30 | "Comment": "Is the build finished, has it errored?", 31 | "Resource": "go://localhost/CheckBuilt", 32 | 
"Next": "Built?" 33 | }, 34 | "Built?": { 35 | "Type": "Choice", 36 | "Choices": [ 37 | { 38 | "Variable": "$.Built", 39 | "BooleanEquals": true, 40 | "Next": "Success" 41 | }, 42 | { 43 | "Variable": "$.Error", 44 | "BooleanEquals": true, 45 | "Next": "Fail" 46 | } 47 | ], 48 | "Default": "CleanUpFailureFn" 49 | }, 50 | "Fail": { 51 | "Type": "Fail" 52 | }, 53 | "Success": { 54 | "Type": "Success" 55 | } 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /examples/deployer.json: -------------------------------------------------------------------------------- 1 | { 2 | "Comment": "Deployer Example", 3 | "StartAt": "FetchValidateCreateFn", 4 | "States": { 5 | "FetchValidateCreateFn": { 6 | "Type": "Pass", 7 | "Result": "FetchValidateCreate", 8 | "ResultPath": "$.Task", 9 | "Next": "FetchValidateCreate" 10 | }, 11 | "FetchValidateCreate": { 12 | "Type": "Task", 13 | "Comment": "Fetch, Validate, Create Resources", 14 | "Resource": "go://localhost/FetchValidateCreate", 15 | "Next": "WaitForHealthy" 16 | }, 17 | "WaitForHealthy": { 18 | "Type": "Wait", 19 | "Seconds" : 20, 20 | "Next": "CheckHealthyFn" 21 | }, 22 | "CheckHealthyFn": { 23 | "Type": "Pass", 24 | "Result": "CheckHealthy", 25 | "ResultPath": "$.Task", 26 | "Next": "CheckHealthy" 27 | }, 28 | "CheckHealthy": { 29 | "Type": "Task", 30 | "Comment": "Is the new deploy healthy? Should we continue checking?", 31 | "Resource": "go://localhost/CheckHealthy", 32 | "Next": "Healthy?" 
33 | }, 34 | "Healthy?": { 35 | "Type": "Choice", 36 | "Choices": [ 37 | { 38 | "Variable": "$.Healthy", 39 | "BooleanEquals": true, 40 | "Next": "CleanUpFn" 41 | }, 42 | { 43 | "Variable": "$.CheckAgain", 44 | "BooleanEquals": true, 45 | "Next": "WaitForHealthy" 46 | } 47 | ], 48 | "Default": "CleanUpFailureFn" 49 | }, 50 | "CleanUpFn": { 51 | "Type": "Pass", 52 | "Result": "CleanUp", 53 | "ResultPath": "$.Task", 54 | "Next": "CleanUp" 55 | }, 56 | "CleanUp": { 57 | "Type": "Task", 58 | "Comment": "Delete Old Resources", 59 | "Resource": "go://localhost/CleanUp", 60 | "Next": "Success" 61 | }, 62 | "CleanUpFailureFn": { 63 | "Type": "Pass", 64 | "Result": "CleanUpFailure", 65 | "ResultPath": "$.Task", 66 | "Next": "CleanUpFailure" 67 | }, 68 | "CleanUpFailure": { 69 | "Type": "Task", 70 | "Comment": "Delete Old Resources", 71 | "Resource": "go://localhost/CleanUpFailure", 72 | "Next": "Fail" 73 | }, 74 | "Fail": { 75 | "Type": "Fail" 76 | }, 77 | "Success": { 78 | "Type": "Succeed" 79 | } 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /examples/map.json: -------------------------------------------------------------------------------- 1 | { 2 | "Comment": "Adds some coordinates to the input", 3 | "StartAt": "Start", 4 | "States": { 5 | "Start": { 6 | "Type": "Map", 7 | "InputPath": "$.detail", 8 | "ItemsPath": "$.shipped", 9 | "ResultPath": "$.detail.shipped", 10 | "MaxConcurrency": 0, 11 | "Iterator": { 12 | "StartAt": "Validate", 13 | "States": { 14 | "Validate": { 15 | "Type": "Task", 16 | "Resource": "arn:aws:lambda:us-east-1:123456789012:function:ship-val", 17 | "End": true 18 | } 19 | } 20 | }, 21 | "End": true 22 | } 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /examples/step_deployer.json: -------------------------------------------------------------------------------- 1 | { 2 | "Comment": "Step Deployer Example", 3 | "StartAt": 
"FetchValidateCreateFn", 4 | "States": { 5 | "FetchValidateCreateFn": { 6 | "Type": "Pass", 7 | "Result": "FetchValidateCreate", 8 | "ResultPath": "$.Task", 9 | "Next": "FetchValidateCreate" 10 | }, 11 | "FetchValidateCreate": { 12 | "Type": "Task", 13 | "Comment": "Fetch, Validate, Create Resources", 14 | "Resource": "", 15 | "Next": "WaitForHealthy" 16 | }, 17 | "Success": { 18 | "Type": "Success" 19 | } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /examples/taskfn.json: -------------------------------------------------------------------------------- 1 | { 2 | "Comment": "Contrived Valid Example that should have all State types", 3 | "StartAt": "Pass", 4 | "States": { 5 | "TaskFn": { 6 | "Type": "TaskFn", 7 | "Resource": "asd", 8 | "Catch": [ 9 | { 10 | "ErrorEquals": [ 11 | "CustomError1", 12 | "CustomError2" 13 | ], 14 | "ResultPath": "$.asd", 15 | "Next": "Pass" 16 | } 17 | ], 18 | "Retry": [ 19 | { 20 | "ErrorEquals": [ 21 | "CustomError1", 22 | "CustomError2" 23 | ], 24 | "IntervalSeconds": 3, 25 | "MaxAttempts": 10, 26 | "BackoffRate": 2.5 27 | } 28 | ], 29 | "End": true 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /execution/execution.go: -------------------------------------------------------------------------------- 1 | package execution 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/aws/aws-sdk-go/service/sfn" 8 | "github.com/aws/aws-sdk-go/service/sfn/sfniface" 9 | "github.com/coinbase/step/aws" 10 | "github.com/coinbase/step/utils/to" 11 | ) 12 | 13 | type Execution struct { 14 | ExecutionArn *string 15 | Input *string 16 | Name *string 17 | Output *string 18 | StartDate *time.Time 19 | StateMachineArn *string 20 | Status *string 21 | StopDate *time.Time 22 | } 23 | 24 | type ExecutionWaiter func(*Execution, *StateDetails, error) error 25 | 26 | func StartExecution(sfnc sfniface.SFNAPI, arn *string, name *string, input 
interface{}) (*Execution, error) { 27 | input_json, err := to.PrettyJSON(input) 28 | 29 | if err != nil { 30 | return nil, err 31 | } 32 | return StartExecutionRaw(sfnc, arn, name, to.Strp(string(input_json))) 33 | } 34 | 35 | func StartExecutionRaw(sfnc sfniface.SFNAPI, arn *string, name *string, input_json *string) (*Execution, error) { 36 | out, err := sfnc.StartExecution(&sfn.StartExecutionInput{ 37 | Input: input_json, 38 | StateMachineArn: arn, 39 | Name: name, 40 | }) 41 | 42 | if err != nil { 43 | return nil, err 44 | } 45 | 46 | return &Execution{ExecutionArn: out.ExecutionArn, StartDate: out.StartDate}, nil 47 | } 48 | 49 | // executions lists executions with an option to filter 50 | func ExecutionsAfter(sfnc aws.SFNAPI, arn *string, status *string, afterTime time.Time) ([]*Execution, error) { 51 | allExecutions := []*Execution{} 52 | 53 | pagefn := func(page *sfn.ListExecutionsOutput, lastPage bool) bool { 54 | for _, exe := range page.Executions { 55 | if exe.StartDate.Before(afterTime) { 56 | // Break the pagination 57 | return false 58 | } 59 | 60 | allExecutions = append(allExecutions, fromExectionListItem(exe)) 61 | } 62 | 63 | return !lastPage 64 | } 65 | 66 | err := sfnc.ListExecutionsPages(&sfn.ListExecutionsInput{ 67 | MaxResults: to.Int64p(100), 68 | StateMachineArn: arn, 69 | StatusFilter: status, 70 | }, pagefn) 71 | 72 | if err != nil { 73 | return nil, err 74 | } 75 | 76 | return allExecutions, nil 77 | } 78 | 79 | func fromExectionListItem(e *sfn.ExecutionListItem) *Execution { 80 | ed := Execution{} 81 | 82 | ed.ExecutionArn = e.ExecutionArn 83 | ed.Name = e.Name 84 | ed.StartDate = e.StartDate 85 | ed.StateMachineArn = e.StateMachineArn 86 | ed.Status = e.Status 87 | ed.StopDate = e.StopDate 88 | 89 | return &ed 90 | } 91 | 92 | func FindExecution(sfnc sfniface.SFNAPI, arn *string, name_prefix string) (*Execution, error) { 93 | // TODO search through pages for first match 94 | out, err := sfnc.ListExecutions(&sfn.ListExecutionsInput{ 95 
| MaxResults: to.Int64p(100), 96 | StatusFilter: to.Strp("RUNNING"), 97 | StateMachineArn: arn, 98 | }) 99 | 100 | if err != nil { 101 | return nil, err 102 | } 103 | 104 | for _, exec := range out.Executions { 105 | name := *exec.Name 106 | if len(name) < len(name_prefix) { 107 | continue 108 | } 109 | 110 | if name[0:len(name_prefix)] == name_prefix { 111 | return &Execution{ExecutionArn: exec.ExecutionArn, StartDate: exec.StartDate}, nil 112 | } 113 | } 114 | 115 | return nil, nil 116 | } 117 | 118 | type StateDetails struct { 119 | LastStateName *string 120 | LastTaskName *string 121 | LastOutput *string 122 | Timestamp *time.Time 123 | } 124 | 125 | func GetDetails(sfnc sfniface.SFNAPI, executionArn *string) (*Execution, *StateDetails, error) { 126 | exec_out, err := sfnc.DescribeExecution(&sfn.DescribeExecutionInput{ 127 | ExecutionArn: executionArn, 128 | }) 129 | 130 | if err != nil { 131 | return nil, nil, err 132 | } 133 | 134 | ed := Execution{ 135 | ExecutionArn: exec_out.ExecutionArn, 136 | Input: exec_out.Input, 137 | Name: exec_out.Name, 138 | Output: exec_out.Output, 139 | StartDate: exec_out.StartDate, 140 | StateMachineArn: exec_out.StateMachineArn, 141 | Status: exec_out.Status, 142 | StopDate: exec_out.StopDate, 143 | } 144 | 145 | sd, err := ed.GetStateDetails(sfnc) 146 | 147 | if err != nil { 148 | return nil, nil, err 149 | } 150 | 151 | return &ed, sd, nil 152 | } 153 | 154 | func (e *Execution) GetStateDetails(sfnc sfniface.SFNAPI) (*StateDetails, error) { 155 | history_out, err := sfnc.GetExecutionHistory(&sfn.GetExecutionHistoryInput{ 156 | ExecutionArn: e.ExecutionArn, 157 | ReverseOrder: to.Boolp(true), 158 | MaxResults: to.Int64p(20), // Enough to Get the Most Recent State Output 159 | }) 160 | 161 | if err != nil { 162 | return nil, err 163 | } 164 | 165 | sd := StateDetails{} 166 | 167 | // We reverse look for last State Existed Event with Output. 
168 | // So even on Failure we can see the final details of Failure 169 | for _, he := range history_out.Events { 170 | if he.Timestamp == nil { 171 | sd.Timestamp = he.Timestamp 172 | } 173 | 174 | if he.StateEnteredEventDetails != nil { 175 | if sd.LastStateName == nil { 176 | sd.LastStateName = he.StateEnteredEventDetails.Name 177 | } 178 | } 179 | 180 | if he.StateExitedEventDetails != nil { 181 | if sd.LastStateName == nil { 182 | sd.LastStateName = he.StateExitedEventDetails.Name 183 | } 184 | 185 | if sd.LastOutput == nil { 186 | sd.LastOutput = he.StateExitedEventDetails.Output 187 | } 188 | 189 | if sd.LastTaskName == nil && *he.Type == "TaskStateExited" { 190 | sd.LastTaskName = he.StateExitedEventDetails.Name 191 | } 192 | } 193 | } 194 | 195 | return &sd, nil 196 | } 197 | 198 | // WaitForExecution allows another application to wait for the execution to finish 199 | // and process output as it comes in for usability 200 | func (e *Execution) WaitForExecution(sfnc sfniface.SFNAPI, sleep int, fn ExecutionWaiter) { 201 | for { 202 | exec, state, err := GetDetails(sfnc, e.ExecutionArn) 203 | 204 | // Copy allowed values over 205 | e.Output = exec.Output 206 | e.StartDate = exec.StartDate 207 | e.Status = exec.Status 208 | e.StopDate = exec.StopDate 209 | 210 | err = fn(exec, state, err) 211 | 212 | if err != nil { 213 | fmt.Println(err.Error()) 214 | return 215 | } 216 | 217 | if *exec.Status != "RUNNING" { 218 | return // Exit out of loop if execution has finished 219 | } 220 | 221 | time.Sleep(time.Duration(int64(sleep)) * time.Second) 222 | } 223 | } 224 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/coinbase/step 2 | 3 | go 1.14 4 | 5 | require ( 6 | github.com/DataDog/datadog-lambda-go v0.6.0 // indirect 7 | github.com/aws/aws-lambda-go v1.17.0 8 | github.com/aws/aws-sdk-go v1.31.8 9 | 
github.com/aws/aws-xray-sdk-go v1.0.1 // indirect 10 | github.com/cenkalti/backoff v2.2.1+incompatible // indirect 11 | github.com/davecgh/go-spew v1.1.1 // indirect 12 | github.com/google/gofuzz v0.0.0-20170612174753-24818f796faf 13 | github.com/stretchr/testify v1.5.1 14 | ) 15 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 2 | github.com/DATA-DOG/go-sqlmock v1.4.1/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= 3 | github.com/DataDog/datadog-lambda-go v0.6.0 h1://2QePQGtIQAyFbsv/Bew4EX8VVBUaXltPyxp7rHkZo= 4 | github.com/DataDog/datadog-lambda-go v0.6.0/go.mod h1:8IH+3AngDt+on4Fc7qeFAxj2h6oPuIgsXs5lEPFImto= 5 | github.com/aws/aws-lambda-go v1.11.1 h1:wuOnhS5aqzPOWns71FO35PtbtBKHr4MYsPVt5qXLSfI= 6 | github.com/aws/aws-lambda-go v1.11.1/go.mod h1:Rr2SMTLeSMKgD45uep9V/NP8tnbCcySgu04cx0k/6cw= 7 | github.com/aws/aws-lambda-go v1.17.0 h1:Ogihmi8BnpmCNktKAGpNwSiILNNING1MiosnKUfU8m0= 8 | github.com/aws/aws-lambda-go v1.17.0/go.mod h1:FEwgPLE6+8wcGBTe5cJN3JWurd1Ztm9zN4jsXsjzKKw= 9 | github.com/aws/aws-sdk-go v1.17.12/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= 10 | github.com/aws/aws-sdk-go v1.20.2 h1:/BBeW8F4PPmvJ5jpFvgkCK4RJQXErNndVRnNhO2qEkQ= 11 | github.com/aws/aws-sdk-go v1.20.2/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= 12 | github.com/aws/aws-sdk-go v1.31.8 h1:qbA8nsLYcqtGjMGDogqykuO0LyUONkP9YlsKu1SVV5M= 13 | github.com/aws/aws-sdk-go v1.31.8/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= 14 | github.com/aws/aws-xray-sdk-go v1.0.0-rc.9 h1:MC5zypTWx5YIbWE3pgcPaG8+1ytirvfCVBkcgHbVZ5Q= 15 | github.com/aws/aws-xray-sdk-go v1.0.0-rc.9/go.mod h1:XtMKdBQfpVut+tJEwI7+dJFRxxRdxHDyVNp2tHXRq04= 16 | github.com/aws/aws-xray-sdk-go v1.0.1 h1:En3DuQ3fAIlNPKoMcAY7bv0lINCJPV0lElK8kEEXsKM= 17 | github.com/aws/aws-xray-sdk-go 
v1.0.1/go.mod h1:tmxq1c+yeEbMh39OmRFuXOrse5ajRlMmDXJ6LrCVsIs= 18 | github.com/cenkalti/backoff v2.1.1+incompatible h1:tKJnvO2kl0zmb/jA5UKAt4VoEVw1qxKWjE/Bpp46npY= 19 | github.com/cenkalti/backoff v2.1.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= 20 | github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= 21 | github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= 22 | github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575 h1:kHaBemcxl8o/pQ5VM1c8PVE1PubbNx3mjUr09OqWGCs= 23 | github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575/go.mod h1:9d6lWj8KzO/fd/NrVaLscBKmPigpZpn5YawRPw+e3Yo= 24 | github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= 25 | github.com/davecgh/go-spew v0.0.0-20160907170601-6d212800a42e/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 26 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 27 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 28 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 29 | github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= 30 | github.com/google/gofuzz v0.0.0-20170612174753-24818f796faf h1:+RRA9JqSOZFfKrOeqr2z77+8R2RKyh8PG66dcu1V0ck= 31 | github.com/google/gofuzz v0.0.0-20170612174753-24818f796faf/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= 32 | github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= 33 | github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= 34 | github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc= 35 | github.com/jmespath/go-jmespath v0.3.0/go.mod 
h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= 36 | github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= 37 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 38 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 39 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 40 | github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 41 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 42 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 43 | github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 44 | github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= 45 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 46 | github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 47 | github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 48 | github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= 49 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 50 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 51 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 52 | github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= 53 | github.com/urfave/cli/v2 v2.1.1/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= 54 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 55 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 56 | 
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980 h1:dfGZHvZk057jK2MCeWus/TowKpJ8y4AmooUzdBSR9GU= 57 | golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 58 | golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 59 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 60 | golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= 61 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 62 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 63 | gopkg.in/urfave/cli.v1 v1.20.0/go.mod h1:vuBzUtMdQeixQj8LVd+/98pzhxNGQoyuPBlsXHOQNO0= 64 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 65 | -------------------------------------------------------------------------------- /handler/handler.go: -------------------------------------------------------------------------------- 1 | // Lambda Handler Data Structures and types 2 | package handler 3 | 4 | import ( 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "reflect" 9 | "runtime/debug" 10 | 11 | "github.com/coinbase/step/errors" 12 | ) 13 | 14 | /////////// 15 | // TYPES 16 | /////////// 17 | 18 | // TaskHandlers maps a Task Name String to a function
// TaskReflection caches lots of the reflected values from the Task functions in order to speed up calls
type TaskReflection struct {
	Handler   reflect.Value // the handler function, ready to be Call()ed
	Type      reflect.Type  // the handler's function type
	EventType reflect.Type  // type of the handler's second parameter
}

// CreateTaskReflection creates a TaskReflection from a handler function.
//
// handlerSymbol must be a function with at least two parameters; the type of
// the second one is recorded as EventType. No validation happens here
// (reflect panics on anything else).
func CreateTaskReflection(handlerSymbol interface{}) TaskReflection {
	handlerType := reflect.TypeOf(handlerSymbol)

	return TaskReflection{
		Handler: reflect.ValueOf(handlerSymbol),
		// Populate Type as well so the declared field is never left nil.
		Type:      handlerType,
		EventType: handlerType.In(1),
	}
}
// validateArguments checks the shape of a handler's function type: it must
// take exactly two parameters, the first of which implements context.Context,
// and return exactly two values, the second of which implements error.
// It returns a descriptive error for the first rule that is violated.
func validateArguments(handler reflect.Type) error {
	if in := handler.NumIn(); in != 2 {
		return fmt.Errorf("handlers must take two arguments, but handler takes %d", in)
	}

	if out := handler.NumOut(); out != 2 {
		return fmt.Errorf("handlers must return two arguments, but handler returns %d", out)
	}

	// The first parameter has to satisfy context.Context.
	ctxInterface := reflect.TypeOf((*context.Context)(nil)).Elem()
	if firstParam := handler.In(0); !firstParam.Implements(ctxInterface) {
		return fmt.Errorf("handlers first argument must implement context.Context")
	}

	// The second return value has to satisfy error.
	errInterface := reflect.TypeOf((*error)(nil)).Elem()
	if secondResult := handler.Out(1); !secondResult.Implements(errInterface) {
		return fmt.Errorf("handlers second return value must be error")
	}

	return nil
}
for_task := "" 158 | with_taskmap := "" 159 | 160 | if t.Task != nil { 161 | for_task = fmt.Sprintf("(%v)", *t.Task) 162 | } 163 | 164 | if t.Tasks != nil { 165 | with_taskmap = fmt.Sprintf(" : %v", t.Tasks) 166 | } 167 | 168 | return fmt.Sprintf("TaskError%v%v: %v", for_task, with_taskmap, t.ErrorString) 169 | } 170 | 171 | /////////// 172 | // FUNCTIONS 173 | /////////// 174 | 175 | // CreateHandler returns the handler passed to the lambda.Start function 176 | func CreateHandler(tm *TaskHandlers) (func(context context.Context, input *RawMessage) (interface{}, error), error) { 177 | if err := tm.Validate(); err != nil { 178 | return nil, err 179 | } 180 | 181 | // This does most reflection before the run handler, 182 | // that way there is less reflection in the main call 183 | reflections := tm.Reflect() 184 | 185 | handler := func(ctx context.Context, input *RawMessage) (interface{}, error) { 186 | // Find Resource Handler 187 | task_name := input.Task 188 | if task_name == nil { 189 | // If task_name cannot be found look for empty string (NoTask) handler 190 | reflection, ok := reflections[""] 191 | if !ok { 192 | return nil, &TaskError{"Nil Task In Message", nil, nil} 193 | } 194 | // call NoTask handler 195 | return CallHandler(reflection, ctx, input.raw) 196 | } 197 | 198 | reflection, ok := reflections[*task_name] 199 | 200 | if !ok { 201 | return nil, &TaskError{"Cannot Find Task", task_name, tm.Tasks()} 202 | } 203 | 204 | return CallHandler(reflection, ctx, input.Input) 205 | } 206 | 207 | return handler, nil 208 | } 209 | 210 | func recoveryError(r interface{}) error { 211 | switch x := r.(type) { 212 | case string: 213 | return errors.PanicError{x} 214 | case error: 215 | return errors.PanicError{x.Error()} 216 | default: 217 | return errors.PanicError{fmt.Sprintf("Unknown %v", x)} 218 | } 219 | 220 | } 221 | 222 | // HANDLERS 223 | 224 | // CallHandler calls a TaskReflections Handler with the correct objects using reflection 225 | // Mostly borrowed 
from the aws-lambda-go package 226 | func CallHandler(reflection TaskReflection, ctx context.Context, input []byte) (ret interface{}, err error) { 227 | defer func() { 228 | if r := recover(); r != nil { 229 | fmt.Println("Recovering", r, fmt.Sprintf("%s\n", debug.Stack())) 230 | err = recoveryError(r) 231 | ret = nil 232 | } 233 | }() 234 | 235 | event := reflect.New(reflection.EventType) 236 | 237 | if err = json.Unmarshal(input, event.Interface()); err != nil { 238 | return nil, errors.UnmarshalError{err.Error()} 239 | } 240 | 241 | // Get Type of Function Input 242 | var args []reflect.Value 243 | if ctx == nil { 244 | ctx = context.Background() 245 | } 246 | 247 | args = append(args, reflect.ValueOf(ctx)) 248 | args = append(args, event.Elem()) 249 | 250 | response := reflection.Handler.Call(args) 251 | 252 | if errVal, ok := response[1].Interface().(error); ok { 253 | err = errVal 254 | } 255 | ret = response[0].Interface() 256 | 257 | return ret, err 258 | } 259 | 260 | // CallHandlerFunction does reflection inline and should only be used for testing 261 | func CallHandlerFunction(handlerSymbol interface{}, ctx context.Context, input interface{}) (interface{}, error) { 262 | if err := ValidateHandler(handlerSymbol); err != nil { 263 | return nil, err 264 | } 265 | 266 | raw_json, err := json.Marshal(input) 267 | 268 | if err != nil { 269 | return nil, fmt.Errorf("JSON Marshall Error: %v", err) 270 | } 271 | 272 | reflection := CreateTaskReflection(handlerSymbol) 273 | return CallHandler(reflection, ctx, raw_json) 274 | } 275 | -------------------------------------------------------------------------------- /handler/handler_test.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "testing" 7 | 8 | "github.com/coinbase/step/utils/to" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | type TestStruct struct { 13 | Message *string 14 | } 15 | 16 | 
func Test_Handler_Execution(t *testing.T) { 17 | called := false 18 | testHandler := func(_ context.Context, ts *TestStruct) (interface{}, error) { 19 | assert.Equal(t, ts.Message, to.Strp("mmss")) 20 | called = true 21 | return "asd", nil 22 | } 23 | 24 | tm := TaskHandlers{"Tester": testHandler} 25 | handle, err := CreateHandler(&tm) 26 | assert.NoError(t, err) 27 | 28 | var raw RawMessage 29 | err = json.Unmarshal([]byte(`{"Task": "Tester", "Input": {"Message": "mmss"}}`), &raw) 30 | assert.NoError(t, err) 31 | 32 | out, err := handle(nil, &raw) 33 | 34 | assert.NoError(t, err) 35 | assert.True(t, called) 36 | assert.Equal(t, out, "asd") 37 | } 38 | 39 | func Test_Handler_Execution_with_TaskHandler_and_NoTaskHandler(t *testing.T) { 40 | nthCalled := false 41 | thCalled := false 42 | 43 | noTaskHandler := func(_ context.Context, ts *TestStruct) (interface{}, error) { 44 | assert.Equal(t, ts.Message, to.Strp("mmss")) 45 | nthCalled = true 46 | return "nth", nil 47 | } 48 | 49 | taskHandler := func(_ context.Context, ts *TestStruct) (interface{}, error) { 50 | assert.Equal(t, ts.Message, to.Strp("mmss")) 51 | thCalled = true 52 | return "th", nil 53 | } 54 | 55 | tm := TaskHandlers{"": noTaskHandler, "Tester": taskHandler} 56 | 57 | handle, err := CreateHandler(&tm) 58 | assert.NoError(t, err) 59 | 60 | var rawNth RawMessage 61 | err = json.Unmarshal([]byte(`{"Message": "mmss"}`), &rawNth) 62 | assert.NoError(t, err) 63 | 64 | var rawTh RawMessage 65 | err = json.Unmarshal([]byte(`{"Task": "Tester", "Input": {"Message": "mmss"}}`), &rawTh) 66 | assert.NoError(t, err) 67 | 68 | outNth, err := handle(nil, &rawNth) 69 | assert.NoError(t, err) 70 | 71 | outTh, err := handle(nil, &rawTh) 72 | assert.NoError(t, err) 73 | 74 | assert.True(t, nthCalled) 75 | assert.True(t, thCalled) 76 | 77 | assert.Equal(t, outNth, "nth") 78 | assert.Equal(t, outTh, "th") 79 | } 80 | 81 | func Test_Handler_Failure(t *testing.T) { 82 | tm := TaskHandlers{} 83 | handle, err := 
CreateHandler(&tm) 84 | assert.NoError(t, err) 85 | 86 | _, err = handle(nil, &RawMessage{}) 87 | assert.Error(t, err) 88 | 89 | _, err = handle(nil, &RawMessage{Task: to.Strp("Tester")}) 90 | assert.Error(t, err) 91 | } 92 | -------------------------------------------------------------------------------- /jsonpath/jsonpath.go: -------------------------------------------------------------------------------- 1 | // Simple Implementation of JSON Path for state machine 2 | package jsonpath 3 | 4 | import ( 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "reflect" 9 | "strings" 10 | "time" 11 | ) 12 | 13 | /* 14 | The `data` must be from JSON Unmarshal, that way we can guarantee the types: 15 | 16 | bool, for JSON booleans 17 | float64, for JSON numbers 18 | string, for JSON strings 19 | []interface{}, for JSON arrays 20 | map[string]interface{}, for JSON objects 21 | nil for JSON null 22 | 23 | */ 24 | 25 | var NOT_FOUND_ERROR = errors.New("Not Found") 26 | 27 | type Path struct { 28 | path []string 29 | } 30 | 31 | // NewPath takes string returns JSONPath Object 32 | func NewPath(path_string string) (*Path, error) { 33 | path := Path{} 34 | path_array, err := ParsePathString(path_string) 35 | path.path = path_array 36 | return &path, err 37 | } 38 | 39 | // UnmarshalJSON makes a path out of a json string 40 | func (path *Path) UnmarshalJSON(b []byte) error { 41 | var path_string string 42 | err := json.Unmarshal(b, &path_string) 43 | 44 | if err != nil { 45 | return err 46 | } 47 | 48 | path_array, err := ParsePathString(path_string) 49 | 50 | if err != nil { 51 | return err 52 | } 53 | 54 | path.path = path_array 55 | return nil 56 | } 57 | 58 | // MarshalJSON converts path to json string 59 | func (path *Path) MarshalJSON() ([]byte, error) { 60 | if len(path.path) == 0 { 61 | return json.Marshal("$") 62 | } 63 | return json.Marshal(path.String()) 64 | } 65 | 66 | func (path *Path) String() string { 67 | return fmt.Sprintf("$.%v", strings.Join(path.path[:], ".")) 68 
// ParsePathString parses a JSON path string into its elements.
//
// Accepted forms are "$" (the root, returned as an empty non-nil slice) and
// "$." followed by one or more non-empty, dot-separated element names, e.g.
// "$.a.b" yields ["a", "b"]. Anything else returns an error.
func ParsePathString(path_string string) ([]string, error) {
	// must start with $ otherwise it is not a path at all
	if !strings.HasPrefix(path_string, "$") {
		return nil, fmt.Errorf("Bad JSON path: must start with $")
	}

	if path_string == "$" {
		// Default is no path
		return []string{}, nil
	}

	// Anything longer than "$" must continue with a dot. Without this check
	// the second character was silently dropped, so "$ab" parsed as "$.b".
	if !strings.HasPrefix(path_string, "$.") {
		return nil, fmt.Errorf("Bad JSON path: $ must be followed by a dot")
	}

	path_array := strings.Split(path_string[2:], ".")

	// An empty element means a trailing, leading or doubled dot ("$.", "$.a..b")
	for _, p := range path_array {
		if p == "" {
			return nil, fmt.Errorf("Bad JSON path: has empty element")
		}
	}

	return path_array, nil
}
bool") 138 | } 139 | 140 | return &output, nil 141 | } 142 | 143 | // GetNumber returns Number from Path 144 | func (path *Path) GetNumber(input interface{}) (*float64, error) { 145 | output_value, err := path.Get(input) 146 | 147 | if err != nil { 148 | return nil, fmt.Errorf("GetFloat Error %q", err) 149 | } 150 | 151 | var output float64 152 | switch output_value.(type) { 153 | case float64: 154 | output = output_value.(float64) 155 | case int: 156 | output = float64(output_value.(int)) 157 | default: 158 | return nil, fmt.Errorf("GetFloat Error: must return float") 159 | } 160 | 161 | return &output, nil 162 | } 163 | 164 | // GetString returns String from Path 165 | func (path *Path) GetString(input interface{}) (*string, error) { 166 | output_value, err := path.Get(input) 167 | 168 | if err != nil { 169 | return nil, fmt.Errorf("GetString Error %q", err) 170 | } 171 | 172 | var output string 173 | switch output_value.(type) { 174 | case string: 175 | output = output_value.(string) 176 | default: 177 | return nil, fmt.Errorf("GetString Error: must return string") 178 | } 179 | 180 | return &output, nil 181 | } 182 | 183 | // GetMap returns Map from Path 184 | func (path *Path) GetMap(input interface{}) (output map[string]interface{}, err error) { 185 | output_value, err := path.Get(input) 186 | 187 | if err != nil { 188 | return nil, fmt.Errorf("GetMap Error %q", err) 189 | } 190 | 191 | switch output_value.(type) { 192 | case map[string]interface{}: 193 | output = output_value.(map[string]interface{}) 194 | default: 195 | return nil, fmt.Errorf("GetMap Error: must return map") 196 | } 197 | 198 | return output, nil 199 | } 200 | 201 | // Get returns interface from Path 202 | func (path *Path) Get(input interface{}) (value interface{}, err error) { 203 | if path == nil { 204 | return input, nil // Default is $ 205 | } 206 | return recursiveGet(input, path.path) 207 | } 208 | 209 | // GetSlice returns array from Path 210 | 211 | func (path *Path) GetSlice(input 
// recursiveSet stores value at the nested location described by path and
// returns the map that holds the first path element.
//
// data is used as that map when it is a non-nil map[string]interface{}; any
// other value (including nil and a typed-nil map) is replaced by a fresh map,
// so intermediate non-map values along the path are overwritten. When data is
// a usable map it is modified in place.
//
// An empty path has nothing to set: the (possibly freshly created) map is
// returned unchanged instead of panicking on path[0].
func recursiveSet(data interface{}, value interface{}, path []string) (output map[string]interface{}) {
	dataMap, ok := data.(map[string]interface{})
	if !ok || dataMap == nil {
		// Overwrite the current data with a new map. A typed-nil map passes the
		// type assertion but would panic on assignment, so it is replaced too.
		dataMap = make(map[string]interface{})
	}

	if len(path) == 0 {
		return dataMap
	}

	key, rest := path[0], path[1:]
	if len(rest) == 0 {
		dataMap[key] = value
	} else {
		dataMap[key] = recursiveSet(dataMap[key], value, rest)
	}

	return dataMap
}
Found") 280 | } 281 | 282 | switch data.(type) { 283 | case map[string]interface{}: 284 | value, ok := data.(map[string]interface{})[path[0]] 285 | 286 | if !ok { 287 | return data, NOT_FOUND_ERROR 288 | } 289 | 290 | return recursiveGet(value, path[1:]) 291 | 292 | default: 293 | return data, NOT_FOUND_ERROR 294 | } 295 | } 296 | -------------------------------------------------------------------------------- /jsonpath/jsonpath_get_test.go: -------------------------------------------------------------------------------- 1 | package jsonpath 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func Test_JSONPath_NotFound(t *testing.T) { 10 | test := map[string]interface{}{} 11 | 12 | path, err := NewPath("$.a") 13 | assert.NoError(t, err) 14 | 15 | _, err = path.Get(test) 16 | 17 | assert.Error(t, err) 18 | assert.Equal(t, err.Error(), "Not Found") 19 | } 20 | 21 | func Test_JSONPath_Get_Default(t *testing.T) { 22 | test := map[string]interface{}{"a": "b"} 23 | 24 | path, err := NewPath("$") 25 | assert.NoError(t, err) 26 | 27 | out, err := path.Get(test) 28 | assert.NoError(t, err) 29 | assert.Equal(t, out, test) 30 | } 31 | 32 | func Test_JSONPath_Get_Simple(t *testing.T) { 33 | test := map[string]interface{}{"a": "b"} 34 | 35 | path, err := NewPath("$.a") 36 | assert.NoError(t, err) 37 | 38 | out, err := path.Get(test) 39 | assert.NoError(t, err) 40 | assert.Equal(t, out, "b") 41 | } 42 | 43 | func Test_JSONPath_Get_Deep(t *testing.T) { 44 | test := map[string]interface{}{"a": "b"} 45 | outer := map[string]interface{}{"x": test} 46 | 47 | path, err := NewPath("$.x.a") 48 | assert.NoError(t, err) 49 | 50 | out, err := path.Get(outer) 51 | assert.NoError(t, err) 52 | assert.Equal(t, out, "b") 53 | } 54 | 55 | func Test_JSONPath_GetMap(t *testing.T) { 56 | test := map[string]interface{}{"a": "b"} 57 | outer := map[string]interface{}{"x": test} 58 | 59 | path, err := NewPath("$.x") 60 | assert.NoError(t, err) 61 | 62 | out, err 
:= path.GetMap(outer) 63 | assert.NoError(t, err) 64 | assert.Equal(t, out, test) 65 | } 66 | 67 | func Test_JSONPath_GetMap_Error(t *testing.T) { 68 | test := map[string]interface{}{"a": "b"} 69 | outer := map[string]interface{}{"x": test} 70 | 71 | path, err := NewPath("$.x.a") 72 | assert.NoError(t, err) 73 | 74 | _, err = path.GetMap(outer) 75 | assert.Equal(t, err.Error(), "GetMap Error: must return map") 76 | } 77 | 78 | func Test_JSONPath_GetTime(t *testing.T) { 79 | test := "2006-01-02T15:04:05Z" 80 | outer := map[string]interface{}{"x": test} 81 | 82 | path, err := NewPath("$.x") 83 | assert.NoError(t, err) 84 | 85 | out, err := path.GetTime(outer) 86 | assert.NoError(t, err) 87 | assert.Equal(t, out.Year(), 2006) 88 | } 89 | 90 | func Test_JSONPath_GetBool(t *testing.T) { 91 | test := true 92 | outer := map[string]interface{}{"x": test} 93 | 94 | path, err := NewPath("$.x") 95 | assert.NoError(t, err) 96 | 97 | out, err := path.GetBool(outer) 98 | assert.NoError(t, err) 99 | assert.Equal(t, *out, test) 100 | } 101 | 102 | func Test_JSONPath_GetNumber(t *testing.T) { 103 | test := 1.2 104 | outer := map[string]interface{}{"x": test} 105 | 106 | path, err := NewPath("$.x") 107 | assert.NoError(t, err) 108 | 109 | out, err := path.GetNumber(outer) 110 | assert.NoError(t, err) 111 | assert.Equal(t, *out, test) 112 | } 113 | 114 | func Test_JSONPath_GetString(t *testing.T) { 115 | test := "String" 116 | outer := map[string]interface{}{"x": test} 117 | 118 | path, err := NewPath("$.x") 119 | assert.NoError(t, err) 120 | 121 | out, err := path.GetString(outer) 122 | assert.NoError(t, err) 123 | assert.Equal(t, *out, test) 124 | } 125 | 126 | func Test_JSONPath_GetSplice(t *testing.T) { 127 | test := []interface{}{1,2,3} 128 | outer := map[string]interface{}{"x": test} 129 | 130 | path, err := NewPath("$.x") 131 | assert.NoError(t, err) 132 | 133 | out, err := path.GetSlice(outer) 134 | assert.NoError(t, err) 135 | assert.Equal(t, out, test) 136 | 137 | } 138 | 
-------------------------------------------------------------------------------- /jsonpath/jsonpath_path_test.go: -------------------------------------------------------------------------------- 1 | package jsonpath 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func Test_JSONPath_Parse_Path(t *testing.T) { 11 | out, err := ParsePathString("$") 12 | assert.NoError(t, err) 13 | assert.Equal(t, len(out), 0) 14 | } 15 | 16 | func Test_JSONPath_Parse_PathLong(t *testing.T) { 17 | out, err := ParsePathString("$.a.b.c") 18 | 19 | assert.NoError(t, err) 20 | 21 | assert.Equal(t, len(out), 3) 22 | 23 | assert.Equal(t, out[0], "a") 24 | assert.Equal(t, out[1], "b") 25 | assert.Equal(t, out[2], "c") 26 | } 27 | 28 | func Test_JSONPath_NewPath(t *testing.T) { 29 | path, err := NewPath("$.a.b.c") 30 | 31 | assert.NoError(t, err) 32 | 33 | assert.Equal(t, len(path.path), 3) 34 | 35 | assert.Equal(t, path.path[0], "a") 36 | assert.Equal(t, path.path[1], "b") 37 | assert.Equal(t, path.path[2], "c") 38 | } 39 | 40 | type testPathStruct struct { 41 | Input Path 42 | } 43 | 44 | func Test_JSONPath_Parsing(t *testing.T) { 45 | raw := []byte(`"$.a.b.c"`) 46 | 47 | var pathstr Path 48 | err := json.Unmarshal(raw, &pathstr) 49 | 50 | assert.NoError(t, err) 51 | 52 | assert.Equal(t, len(pathstr.path), 3) 53 | 54 | assert.Equal(t, pathstr.path[0], "a") 55 | assert.Equal(t, pathstr.path[1], "b") 56 | assert.Equal(t, pathstr.path[2], "c") 57 | } 58 | -------------------------------------------------------------------------------- /jsonpath/jsonpath_set_test.go: -------------------------------------------------------------------------------- 1 | package jsonpath 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func Test_JSONPath_Set_Default(t *testing.T) { 10 | test := map[string]interface{}{"a": "b"} 11 | value := map[string]interface{}{"c": "d"} 12 | 13 | path, err := NewPath("$") 14 | 
assert.NoError(t, err) 15 | 16 | setted, err := path.Set(test, value) 17 | 18 | assert.NoError(t, err) 19 | assert.Equal(t, setted, value) 20 | } 21 | 22 | func Test_JSONPath_Set_Simple(t *testing.T) { 23 | test := map[string]interface{}{"a": "b"} 24 | 25 | path, err := NewPath("$.a") 26 | assert.NoError(t, err) 27 | 28 | setted, err := path.Set(test, "s") 29 | assert.NoError(t, err) 30 | 31 | out, err := path.Get(setted) 32 | assert.NoError(t, err) 33 | assert.Equal(t, "s", out) 34 | } 35 | 36 | func Test_JSONPath_Set_Deep(t *testing.T) { 37 | test := map[string]interface{}{"a": "b"} 38 | outer := map[string]interface{}{"x": test} 39 | 40 | path, err := NewPath("$.x.a") 41 | assert.NoError(t, err) 42 | 43 | setted, err := path.Set(outer, "s") 44 | assert.NoError(t, err) 45 | 46 | out, err := path.Get(setted) 47 | assert.NoError(t, err) 48 | assert.Equal(t, "s", out) 49 | } 50 | 51 | func Test_JSONPath_Set_Create(t *testing.T) { 52 | test := map[string]interface{}{} 53 | 54 | path, err := NewPath("$.a") 55 | assert.NoError(t, err) 56 | 57 | setted, err := path.Set(test, "s") 58 | assert.NoError(t, err) 59 | 60 | out, err := path.Get(setted) 61 | assert.NoError(t, err) 62 | assert.Equal(t, "s", out) 63 | } 64 | 65 | func Test_JSONPath_Set_Overwrite(t *testing.T) { 66 | test := map[string]interface{}{"a": "b"} 67 | 68 | path, err := NewPath("$.a.b") 69 | assert.NoError(t, err) 70 | 71 | setted, err := path.Set(test, "s") 72 | assert.NoError(t, err) 73 | 74 | out, err := path.Get(setted) 75 | assert.NoError(t, err) 76 | assert.Equal(t, "s", out) 77 | } 78 | -------------------------------------------------------------------------------- /machine/README.md: -------------------------------------------------------------------------------- 1 | # Step Machine 2 | 3 | `machine` is an implementation of the AWS State Machine specification. The primary goal of this implementation is to enable testing of state machines and code together. 
4 | 5 | ### Continuing Development 6 | 7 | Step at the moment is still very beta, and its API will likely change more before it stabilizes. If you have ideas for improvements please reach out. 8 | 9 | Some of the TODOs left for the library are: 10 | 11 | 1. Support for Parallel States 12 | 1. Better Validations e.g. making sure all states are reachable and executable 13 | 1. Client side visualization of state machine and execution using GraphViz 14 | 15 | -------------------------------------------------------------------------------- /machine/execution.go: -------------------------------------------------------------------------------- 1 | package machine 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/aws/aws-sdk-go/service/sfn" 9 | "github.com/coinbase/step/utils/to" 10 | ) 11 | 12 | type HistoryEvent struct { 13 | sfn.HistoryEvent 14 | } 15 | 16 | type Execution struct { 17 | Output map[string]interface{} 18 | OutputJSON string 19 | Error error 20 | 21 | LastOutput map[string]interface{} // interim output 22 | LastOutputJSON string 23 | LastError error // interim error 24 | 25 | ExecutionHistory []HistoryEvent 26 | } 27 | 28 | func (sm *Execution) SetOutput(output interface{}, err error) { 29 | switch output.(type) { 30 | case map[string]interface{}: 31 | sm.Output = output.(map[string]interface{}) 32 | sm.OutputJSON, _ = to.PrettyJSON(output) 33 | } 34 | 35 | if err != nil { 36 | sm.Error = err 37 | } 38 | } 39 | 40 | func (sm *Execution) SetLastOutput(output interface{}, err error) { 41 | switch output.(type) { 42 | case map[string]interface{}: 43 | sm.LastOutput = output.(map[string]interface{}) 44 | sm.LastOutputJSON, _ = to.PrettyJSON(output) 45 | } 46 | 47 | if err != nil { 48 | sm.LastError = err 49 | } 50 | } 51 | 52 | func (sm *Execution) EnteredEvent(s State, input interface{}) { 53 | sm.ExecutionHistory = append(sm.ExecutionHistory, createEnteredEvent(s, input)) 54 | } 55 | 56 | func (sm *Execution) ExitedEvent(s State, 
output interface{}) { 57 | sm.ExecutionHistory = append(sm.ExecutionHistory, createExitedEvent(s, output)) 58 | } 59 | 60 | func (sm *Execution) Start() { 61 | sm.ExecutionHistory = []HistoryEvent{createEvent("ExecutionStarted")} 62 | } 63 | 64 | func (sm *Execution) Failed() { 65 | sm.ExecutionHistory = append(sm.ExecutionHistory, createEvent("ExecutionFailed")) 66 | } 67 | 68 | func (sm *Execution) Succeeded() { 69 | sm.ExecutionHistory = append(sm.ExecutionHistory, createEvent("ExecutionSucceeded")) 70 | } 71 | 72 | // Path returns the Path of States, ignoreing TaskFn states 73 | func (sm *Execution) Path() []string { 74 | path := []string{} 75 | for _, er := range sm.ExecutionHistory { 76 | if er.StateEnteredEventDetails != nil { 77 | name := *er.StateEnteredEventDetails.Name 78 | path = append(path, name) 79 | } 80 | } 81 | return path 82 | } 83 | 84 | func createEvent(name string) HistoryEvent { 85 | t := time.Now() 86 | return HistoryEvent{ 87 | sfn.HistoryEvent{ 88 | Type: to.Strp(name), 89 | Timestamp: &t, 90 | }, 91 | } 92 | } 93 | 94 | func createEnteredEvent(state State, input interface{}) HistoryEvent { 95 | event := createEvent(fmt.Sprintf("%vStateEntered", *state.GetType())) 96 | json_raw, err := json.Marshal(input) 97 | 98 | if err != nil { 99 | json_raw = []byte{} 100 | } 101 | 102 | event.StateEnteredEventDetails = &sfn.StateEnteredEventDetails{ 103 | Name: state.Name(), 104 | Input: to.Strp(string(json_raw)), 105 | } 106 | 107 | return event 108 | } 109 | 110 | func createExitedEvent(state State, output interface{}) HistoryEvent { 111 | event := createEvent(fmt.Sprintf("%vStateExited", *state.GetType())) 112 | json_raw, err := json.Marshal(output) 113 | 114 | if err != nil { 115 | json_raw = []byte{} 116 | } 117 | 118 | event.StateExitedEventDetails = &sfn.StateExitedEventDetails{ 119 | Name: state.Name(), 120 | Output: to.Strp(string(json_raw)), 121 | } 122 | 123 | return event 124 | } 125 | 
-------------------------------------------------------------------------------- /machine/fail_state.go: -------------------------------------------------------------------------------- 1 | package machine 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/coinbase/step/utils/is" 8 | "github.com/coinbase/step/utils/to" 9 | ) 10 | 11 | type FailState struct { 12 | stateStr // Include Defaults 13 | 14 | Type *string 15 | Comment *string `json:",omitempty"` 16 | 17 | Error *string `json:",omitempty"` 18 | Cause *string `json:",omitempty"` 19 | } 20 | 21 | func (s *FailState) Execute(_ context.Context, input interface{}) (output interface{}, next *string, err error) { 22 | return errorOutput(s.Error, s.Cause), nil, fmt.Errorf("Fail") 23 | } 24 | 25 | func (s *FailState) Validate() error { 26 | s.SetType(to.Strp("Fail")) 27 | 28 | if err := ValidateNameAndType(s); err != nil { 29 | return fmt.Errorf("%v %v", errorPrefix(s), err) 30 | } 31 | 32 | if is.EmptyStr(s.Error) { 33 | return fmt.Errorf("%v %v", errorPrefix(s), "must contain Error") 34 | } 35 | 36 | return nil 37 | } 38 | 39 | func (s *FailState) SetType(t *string) { 40 | s.Type = t 41 | } 42 | 43 | func (s *FailState) GetType() *string { 44 | return s.Type 45 | } 46 | -------------------------------------------------------------------------------- /machine/machine.go: -------------------------------------------------------------------------------- 1 | // State Machine implementation 2 | package machine 3 | 4 | import ( 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | 10 | "github.com/aws/aws-lambda-go/lambdacontext" 11 | "github.com/coinbase/step/handler" 12 | "github.com/coinbase/step/utils/is" 13 | "github.com/coinbase/step/utils/to" 14 | ) 15 | 16 | func DefaultHandler(_ context.Context, input interface{}) (interface{}, error) { 17 | return map[string]string{}, nil 18 | } 19 | 20 | // EmptyStateMachine is a small Valid StateMachine 21 | var EmptyStateMachine = `{ 22 | "StartAt": "WIN", 
23 | "States": { "WIN": {"Type": "Succeed"}} 24 | }` 25 | 26 | // IMPLEMENTATION 27 | 28 | // States is the collection of states 29 | type States map[string]State 30 | 31 | // StateMachine the core struct for the machine 32 | type StateMachine struct { 33 | Comment *string `json:",omitempty"` 34 | 35 | StartAt *string 36 | 37 | States States 38 | } 39 | 40 | // Global Methods 41 | func Validate(sm_json *string) error { 42 | state_machine, err := FromJSON([]byte(*sm_json)) 43 | if err != nil { 44 | return err 45 | } 46 | 47 | if err := state_machine.Validate(); err != nil { 48 | return err 49 | } 50 | 51 | return nil 52 | } 53 | 54 | func (sm *StateMachine) FindTask(name string) (*TaskState, error) { 55 | task, ok := sm.Tasks()[name] 56 | 57 | if !ok { 58 | return nil, fmt.Errorf("Handler Error: Cannot Find Task %v", name) 59 | } 60 | 61 | return task, nil 62 | } 63 | 64 | func (sm *StateMachine) Tasks() map[string]*TaskState { 65 | tasks := map[string]*TaskState{} 66 | for name, s := range sm.States { 67 | switch s.(type) { 68 | case *TaskState: 69 | tasks[name] = s.(*TaskState) 70 | } 71 | } 72 | return tasks 73 | } 74 | 75 | func (sm *StateMachine) SetResource(lambda_arn *string) { 76 | for _, task := range sm.Tasks() { 77 | if task.Resource == nil { 78 | task.Resource = lambda_arn 79 | } 80 | } 81 | } 82 | 83 | func (sm *StateMachine) SetDefaultHandler() { 84 | for _, task := range sm.Tasks() { 85 | task.SetTaskHandler(DefaultHandler) 86 | } 87 | } 88 | 89 | func (sm *StateMachine) SetTaskFnHandlers(tfs *handler.TaskHandlers) error { 90 | taskHandlers, err := handler.CreateHandler(tfs) 91 | if err != nil { 92 | return err 93 | } 94 | 95 | for name, _ := range *tfs { 96 | if name == "" { 97 | continue // Skip default Handler 98 | } 99 | if err := sm.SetTaskHandler(name, taskHandlers); err != nil { 100 | return err 101 | } 102 | } 103 | 104 | return nil 105 | } 106 | 107 | func (sm *StateMachine) SetTaskHandler(task_name string, resource_fn interface{}) error { 
108 | task, err := sm.FindTask(task_name) 109 | if err != nil { 110 | return err 111 | } 112 | 113 | task.SetTaskHandler(resource_fn) 114 | return nil 115 | } 116 | 117 | func (sm *StateMachine) Validate() error { 118 | if is.EmptyStr(sm.StartAt) { 119 | return errors.New("State Machine requires StartAt") 120 | } 121 | 122 | if sm.States == nil { 123 | return errors.New("State Machine must have States") 124 | } 125 | 126 | if len(sm.States) == 0 { 127 | return errors.New("State Machine must have States") 128 | } 129 | 130 | state_errors := []string{} 131 | 132 | for _, state := range sm.States { 133 | err := state.Validate() 134 | if err != nil { 135 | state_errors = append(state_errors, err.Error()) 136 | } 137 | } 138 | 139 | if len(state_errors) != 0 { 140 | return fmt.Errorf("State Errors %q", state_errors) 141 | } 142 | 143 | // TODO: validate all states are reachable 144 | return nil 145 | } 146 | 147 | func (sm *StateMachine) DefaultLambdaContext(lambda_name string) context.Context { 148 | return lambdacontext.NewContext(context.Background(), &lambdacontext.LambdaContext{ 149 | InvokedFunctionArn: fmt.Sprintf("arn:aws:lambda:us-east-1:000000000000:function:%v", lambda_name), 150 | }) 151 | } 152 | 153 | func processInput(input interface{}) (interface{}, error) { 154 | // Make 155 | switch input.(type) { 156 | case string: 157 | var json_input map[string]interface{} 158 | if err := json.Unmarshal([]byte(input.(string)), &json_input); err != nil { 159 | return nil, err 160 | } 161 | return json_input, nil 162 | case *string: 163 | var json_input map[string]interface{} 164 | if err := json.Unmarshal([]byte(*(input.(*string))), &json_input); err != nil { 165 | return nil, err 166 | } 167 | return json_input, nil 168 | } 169 | 170 | // Converts the input interface into map[string]interface{} 171 | return to.FromJSON(input) 172 | } 173 | 174 | func (sm *StateMachine) Execute(input interface{}) (*Execution, error) { 175 | if err := sm.Validate(); err != nil { 176 | 
return nil, err 177 | } 178 | 179 | input, err := processInput(input) 180 | if err != nil { 181 | return nil, err 182 | } 183 | 184 | // Start Execution (records the history, inputs, outputs...) 185 | exec := &Execution{} 186 | exec.Start() 187 | 188 | // Execute Start State 189 | output, err := sm.stateLoop(exec, sm.StartAt, input) 190 | 191 | // Set Final Output 192 | exec.SetOutput(output, err) 193 | 194 | if err != nil { 195 | exec.Failed() 196 | } else { 197 | exec.Succeeded() 198 | } 199 | 200 | return exec, err 201 | } 202 | 203 | func (sm *StateMachine) stateLoop(exec *Execution, next *string, input interface{}) (output interface{}, err error) { 204 | // Flat loop instead of recursion to better implement timeouts 205 | for { 206 | s, ok := sm.States[*next] 207 | 208 | if !ok { 209 | return nil, fmt.Errorf("Unknown State: %v", *next) 210 | } 211 | 212 | if len(exec.ExecutionHistory) > 250 { 213 | return nil, fmt.Errorf("State Overflow") 214 | } 215 | 216 | exec.EnteredEvent(s, input) 217 | 218 | output, next, err = s.Execute(sm.DefaultLambdaContext(*s.Name()), input) 219 | 220 | if *s.GetType() != "Fail" { 221 | // Failure States Dont exit. 
222 | exec.SetLastOutput(output, err) 223 | exec.ExitedEvent(s, output) 224 | } 225 | 226 | // If Error return error 227 | if err != nil { 228 | return output, err 229 | } 230 | 231 | // If next is nil then END 232 | if next == nil { 233 | return output, nil 234 | } 235 | 236 | input = output 237 | } 238 | } 239 | -------------------------------------------------------------------------------- /machine/machine_test.go: -------------------------------------------------------------------------------- 1 | package machine 2 | 3 | import ( 4 | "encoding/json" 5 | "io/ioutil" 6 | "testing" 7 | 8 | "github.com/coinbase/step/utils/to" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func loadFixture(file string, t *testing.T) *StateMachine { 13 | example_machine, err := ParseFile(file) 14 | assert.NoError(t, err) 15 | return example_machine 16 | } 17 | 18 | func execute(json []byte, input interface{}, t *testing.T) (map[string]interface{}, error) { 19 | example_machine, err := FromJSON(json) 20 | assert.NoError(t, err) 21 | example_machine.SetDefaultHandler() 22 | 23 | exec, err := example_machine.Execute(input) 24 | 25 | return exec.Output, err 26 | } 27 | 28 | func executeFixture(file string, input map[string]interface{}, t *testing.T) map[string]interface{} { 29 | example_machine := loadFixture(file, t) 30 | 31 | exec, err := example_machine.Execute(input) 32 | 33 | assert.NoError(t, err) 34 | 35 | return exec.Output 36 | } 37 | 38 | ////// 39 | // TESTS 40 | ////// 41 | 42 | func Test_Machine_EmptyStateMachinePassExample(t *testing.T) { 43 | _, err := execute([]byte(EmptyStateMachine), make(map[string]interface{}), t) 44 | assert.NoError(t, err) 45 | } 46 | 47 | func Test_Machine_SimplePassExample_With_Execute(t *testing.T) { 48 | json := []byte(` 49 | { 50 | "StartAt": "start", 51 | "States": { 52 | "start": { 53 | "Type": "Pass", 54 | "Result": "b", 55 | "ResultPath": "$.a", 56 | "End": true 57 | } 58 | } 59 | }`) 60 | 61 | output, err := execute(json, 
make(map[string]interface{}), t) 62 | assert.NoError(t, err) 63 | assert.Equal(t, output["a"], "b") 64 | 65 | output, err = execute(json, "{}", t) 66 | assert.NoError(t, err) 67 | assert.Equal(t, output["a"], "b") 68 | 69 | output, err = execute(json, to.Strp("{}"), t) 70 | assert.NoError(t, err) 71 | assert.Equal(t, output["a"], "b") 72 | } 73 | 74 | func Test_Machine_ErrorUnknownState(t *testing.T) { 75 | example_machine := loadFixture("../examples/bad_unknown_state.json", t) 76 | _, err := example_machine.Execute(make(map[string]interface{})) 77 | 78 | assert.Error(t, err) 79 | assert.Regexp(t, "Unknown State", err.Error()) 80 | } 81 | 82 | func Test_Machine_MarshallAllTypes(t *testing.T) { 83 | file := "../examples/all_types.json" 84 | sm, err := ParseFile(file) 85 | assert.NoError(t, err) 86 | 87 | sm.SetDefaultHandler() 88 | assert.NoError(t, sm.Validate()) 89 | 90 | marshalled_json, err := json.Marshal(sm) 91 | assert.NoError(t, err) 92 | 93 | raw_json, err := ioutil.ReadFile(file) 94 | assert.NoError(t, err) 95 | 96 | assert.JSONEq(t, string(raw_json), string(marshalled_json)) 97 | } 98 | -------------------------------------------------------------------------------- /machine/map_state.go: -------------------------------------------------------------------------------- 1 | package machine 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/coinbase/step/jsonpath" 7 | "github.com/coinbase/step/utils/to" 8 | ) 9 | 10 | type MapState struct { 11 | stateStr // Include Defaults 12 | 13 | Type *string 14 | Comment *string `json:",omitempty"` 15 | 16 | Iterator *StateMachine 17 | ItemsPath *jsonpath.Path `json:",omitempty"` 18 | Parameters interface{} `json:",omitempty"` 19 | 20 | MaxConcurrency *float64 `json:",omitempty"` 21 | 22 | InputPath *jsonpath.Path `json:",omitempty"` 23 | OutputPath *jsonpath.Path `json:",omitempty"` 24 | ResultPath *jsonpath.Path `json:",omitempty"` 25 | 26 | Catch []*Catcher `json:",omitempty"` 27 | Retry []*Retrier 
`json:",omitempty"` 28 | 29 | Next *string `json:",omitempty"` 30 | End *bool `json:",omitempty"` 31 | } 32 | 33 | func (s *MapState) process(ctx context.Context, input interface{}) (interface{}, *string, error) { 34 | output, err := s.ItemsPath.GetSlice(input) 35 | if err != nil { 36 | return input, nextState(s.Next, s.End), err 37 | } 38 | var res []map[string]interface{} 39 | 40 | for _, item := range output { 41 | execution, err := s.Iterator.Execute(item) 42 | if err != nil { 43 | return input, nextState(s.Next, s.End), err 44 | } 45 | res = append(res, execution.Output) 46 | } 47 | 48 | return res, nextState(s.Next, s.End), nil 49 | } 50 | 51 | func (s *MapState) Execute(ctx context.Context, input interface{}) (output interface{}, next *string, err error) { 52 | return processError(s, 53 | processCatcher(s.Catch, 54 | processRetrier(s.Name(), s.Retry, 55 | inputOutput( 56 | s.InputPath, 57 | s.OutputPath, 58 | withParams( 59 | s.Parameters, 60 | result(s.ResultPath, s.process), 61 | ), 62 | ), 63 | ), 64 | ), 65 | )(ctx, input) 66 | } 67 | 68 | func (s *MapState) Validate() error { 69 | s.SetType(to.Strp("Map")) 70 | 71 | if err := ValidateNameAndType(s); err != nil { 72 | return fmt.Errorf("%v %v", errorPrefix(s), err) 73 | } 74 | 75 | if err := endValid(s.Next, s.End); err != nil { 76 | return fmt.Errorf("%v %v", errorPrefix(s), err) 77 | } 78 | 79 | if s.Iterator == nil { 80 | return fmt.Errorf("%v Requires Iterator", errorPrefix(s)) 81 | } 82 | 83 | if err := s.Iterator.Validate(); err != nil { 84 | return fmt.Errorf("%v %v", errorPrefix(s), err) 85 | } 86 | return nil 87 | } 88 | 89 | func (s *MapState) SetType(t *string) { 90 | s.Type = t 91 | } 92 | 93 | func (s *MapState) GetType() *string { 94 | return s.Type 95 | } 96 | -------------------------------------------------------------------------------- /machine/map_state_test.go: -------------------------------------------------------------------------------- 1 | package machine 2 | 3 | import ( 4 | 
"github.com/coinbase/step/utils/to" 5 | "github.com/stretchr/testify/assert" 6 | "testing" 7 | ) 8 | 9 | ///////// 10 | // Helpers 11 | ///////// 12 | 13 | func initialize_state_machine(state *StateMachine, t *testing.T) { 14 | state.StartAt = to.Strp("Start") 15 | state.States = States{} 16 | sm := parseTaskState([]byte(`{ 17 | "Resource": "asd", 18 | "Next": "Pass", 19 | "Retry": [{ "ErrorEquals": ["States.ALL"] }] 20 | }`), t) 21 | state.States["start"] = sm 22 | 23 | } 24 | 25 | // Execution 26 | 27 | func Test_MapState_ValidateResource(t *testing.T) { 28 | state := parseMapState([]byte(`{ "Next": "Pass"}`), t) 29 | assert.Error(t, state.Validate()) 30 | state.Iterator = &StateMachine{} 31 | assert.Error(t, state.Validate()) 32 | initialize_state_machine(state.Iterator, t) 33 | assert.NoError(t, state.Validate()) 34 | } 35 | 36 | func Test_MapState_SingleState(t *testing.T) { 37 | state := parseMapState([]byte(`{ 38 | "Type": "Map", 39 | "ItemsPath": "$.shipped", 40 | "ResultPath": "$.output.data", 41 | "OutputPath": "$.output", 42 | "MaxConcurrency": 0, 43 | "Iterator": { 44 | "StartAt": "Validate", 45 | "States": { 46 | "Validate": { 47 | "Type": "Pass", 48 | "Result": {"key": "value"}, 49 | "End": true 50 | } 51 | } 52 | }, 53 | "End": true 54 | }`), t) 55 | // Default 56 | outputResults := map[string]interface{}{} 57 | var res []map[string]interface{} 58 | res = append(res, map[string]interface{}{"key": "value"}) 59 | res = append(res, map[string]interface{}{"key": "value"}) 60 | res = append(res, map[string]interface{}{"key": "value"}) 61 | 62 | outputResults["data"] = res 63 | testState(state, stateTestData{ 64 | Input: map[string]interface{}{"shipped": []interface{}{1, 2, 3}, "output": []interface{}{}}, 65 | Output: outputResults, 66 | }, t) 67 | 68 | testState(state, stateTestData{ 69 | Input: map[string]interface{}{}, 70 | Error: to.Strp("GetSlice Error \"Not Found\""), 71 | }, t) 72 | } 73 | 74 | func Test_MapState_Catch(t *testing.T) { 75 | state := 
parseMapState([]byte(`{ 76 | "Type": "Map", 77 | "ItemsPath": "$.shipped", 78 | "ResultPath": "$.output.data", 79 | "OutputPath": "$.output", 80 | 81 | "MaxConcurrency": 0, 82 | "Catch": [{ 83 | "ErrorEquals": ["States.ALL"], 84 | "Next": "Fail" 85 | }], 86 | "Iterator": { 87 | "StartAt": "Validate", 88 | "States": { 89 | "Validate": { 90 | "Type": "Pass", 91 | "Result": {"key": "value"}, 92 | "End": true 93 | } 94 | } 95 | }, 96 | "End": true 97 | }`), t) 98 | 99 | // No Input path data. Should be caught 100 | testState(state, stateTestData{ 101 | Input: map[string]interface{}{}, 102 | Output: map[string]interface{}{"Error": "errorString", "Cause": "GetSlice Error \"Not Found\""}, 103 | }, t) 104 | 105 | } 106 | 107 | func Test_MapState_Integration(t *testing.T) { 108 | state := parseMapState([]byte(`{ 109 | "Type": "Map", 110 | "ItemsPath": "$.shipped", 111 | "ResultPath": "$.output.data", 112 | "OutputPath": "$.output", 113 | "MaxConcurrency": 0, 114 | "Iterator": { 115 | "StartAt": "Validate", 116 | "States": { 117 | "Validate": { 118 | "Type": "Pass", 119 | "Next": "Task" 120 | }, 121 | "Task" : { 122 | "Type": "TaskFn", 123 | "Resource": "arn:aws:lambda:{{aws_region}}:{{aws_account}}:function:{{lambda_name}}", 124 | "Result": {"key": "value"}, 125 | "End": true 126 | } 127 | } 128 | }, 129 | "End": true 130 | }`), t) 131 | 132 | // Default 133 | var task = state.Iterator.States["Task"].(*TaskState) 134 | task.SetTaskHandler(ReturnInputHandler) 135 | outputResults := map[string]interface{}{} 136 | var res []map[string]interface{} 137 | res = append(res, map[string]interface{}{"Task": "Task", "Input": float64(11)}) 138 | res = append(res, map[string]interface{}{"Task": "Task", "Input": float64(12)}) 139 | res = append(res, map[string]interface{}{"Task": "Task", "Input": float64(13)}) 140 | 141 | outputResults["data"] = res 142 | testState(state, stateTestData{ 143 | Input: map[string]interface{}{"shipped": []interface{}{11, 12, 13}, "output": []interface{}{}}, 
144 | Output: outputResults, 145 | }, t) 146 | } 147 | -------------------------------------------------------------------------------- /machine/parallel_state.go: -------------------------------------------------------------------------------- 1 | package machine 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/coinbase/step/utils/to" 8 | ) 9 | 10 | type ParallelState struct { 11 | stateStr // Include Defaults 12 | 13 | Type *string 14 | Comment *string `json:",omitempty"` 15 | } 16 | 17 | func (s *ParallelState) Execute(_ context.Context, input interface{}) (output interface{}, next *string, err error) { 18 | return input, nil, nil 19 | } 20 | 21 | func (s *ParallelState) Validate() error { 22 | s.SetType(to.Strp("Parallel")) 23 | 24 | if err := ValidateNameAndType(s); err != nil { 25 | return fmt.Errorf("%v %v", errorPrefix(s), err) 26 | } 27 | 28 | return nil 29 | } 30 | 31 | func (s *ParallelState) SetType(t *string) { 32 | s.Type = t 33 | } 34 | 35 | func (s *ParallelState) GetType() *string { 36 | return s.Type 37 | } 38 | -------------------------------------------------------------------------------- /machine/parser.go: -------------------------------------------------------------------------------- 1 | // State Machine Parser 2 | package machine 3 | 4 | import ( 5 | "encoding/json" 6 | "fmt" 7 | "io/ioutil" 8 | 9 | "github.com/coinbase/step/utils/to" 10 | ) 11 | 12 | // Takes a file, and a map of Task Function s 13 | func ParseFile(file string) (*StateMachine, error) { 14 | raw, err := ioutil.ReadFile(file) 15 | if err != nil { 16 | return nil, err 17 | } 18 | 19 | json_sm, err := FromJSON(raw) 20 | return json_sm, err 21 | } 22 | 23 | func FromJSON(raw []byte) (*StateMachine, error) { 24 | var sm StateMachine 25 | err := json.Unmarshal(raw, &sm) 26 | return &sm, err 27 | } 28 | 29 | func (sm *States) UnmarshalJSON(b []byte) error { 30 | // States 31 | var rawStates map[string]*json.RawMessage 32 | err := json.Unmarshal(b, &rawStates) 33 | 34 
| if err != nil { 35 | return err 36 | } 37 | 38 | newStates := States{} 39 | for name, raw := range rawStates { 40 | states, err := unmarshallState(name, raw) 41 | if err != nil { 42 | return err 43 | } 44 | 45 | for _, s := range states { 46 | newStates[*s.Name()] = s 47 | } 48 | } 49 | 50 | *sm = newStates 51 | return nil 52 | } 53 | 54 | type stateType struct { 55 | Type string 56 | } 57 | 58 | func unmarshallState(name string, raw_json *json.RawMessage) ([]State, error) { 59 | var err error 60 | 61 | // extract type (safer than regex) 62 | var state_type stateType 63 | if err = json.Unmarshal(*raw_json, &state_type); err != nil { 64 | return nil, err 65 | } 66 | 67 | var newState State 68 | 69 | switch state_type.Type { 70 | case "Pass": 71 | var s PassState 72 | err = json.Unmarshal(*raw_json, &s) 73 | newState = &s 74 | case "Task": 75 | var s TaskState 76 | err = json.Unmarshal(*raw_json, &s) 77 | newState = &s 78 | case "Choice": 79 | var s ChoiceState 80 | err = json.Unmarshal(*raw_json, &s) 81 | newState = &s 82 | case "Wait": 83 | var s WaitState 84 | err = json.Unmarshal(*raw_json, &s) 85 | newState = &s 86 | case "Succeed": 87 | var s SucceedState 88 | err = json.Unmarshal(*raw_json, &s) 89 | newState = &s 90 | case "Fail": 91 | var s FailState 92 | err = json.Unmarshal(*raw_json, &s) 93 | newState = &s 94 | case "Parallel": 95 | var s ParallelState 96 | err = json.Unmarshal(*raw_json, &s) 97 | newState = &s 98 | case "Map": 99 | var s MapState 100 | err = json.Unmarshal(*raw_json, &s) 101 | newState = &s 102 | case "TaskFn": 103 | // This is a custom state that adds values to Task to be handled 104 | var s TaskState 105 | err = json.Unmarshal(*raw_json, &s) 106 | // This will inject the Task name into the input 107 | s.Parameters = map[string]interface{}{"Task": name, "Input.$": "$"} 108 | s.Type = to.Strp("Task") 109 | newState = &s 110 | default: 111 | err = fmt.Errorf("Unknown State %q", state_type.Type) 112 | } 113 | 114 | // End of loop return 
error 115 | if err != nil { 116 | return nil, err 117 | } 118 | 119 | // Set Name and Defaults 120 | newName := name 121 | newState.SetName(&newName) // Require New Variable Pointer 122 | 123 | return []State{newState}, nil 124 | } 125 | -------------------------------------------------------------------------------- /machine/parser_test.go: -------------------------------------------------------------------------------- 1 | package machine 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func Test_Machine_Parser_FromJSON(t *testing.T) { 10 | json := []byte(` 11 | { 12 | "Comment": "Adds some coordinates to the input", 13 | "StartAt": "Coords", 14 | "States": { 15 | "Coords": { 16 | "Type": "Pass", 17 | "Result": { 18 | "x": 3.14, 19 | "y": 103.14159 20 | }, 21 | "ResultPath": "$.coords", 22 | "End": true 23 | } 24 | } 25 | }`) 26 | 27 | _, err := FromJSON(json) 28 | 29 | assert.Equal(t, err, nil) 30 | } 31 | 32 | func Test_Parser_Expands_TaskFn(t *testing.T) { 33 | json := []byte(` 34 | { 35 | "StartAt": "A", 36 | "States": { 37 | "A": { 38 | "Type": "TaskFn", 39 | "Next": "B" 40 | }, 41 | "B": { 42 | "Type": "TaskFn", 43 | "End": true 44 | } 45 | } 46 | }`) 47 | 48 | sm, err := FromJSON(json) 49 | assert.NoError(t, err) 50 | 51 | // Names and Types 52 | assert.Equal(t, len(sm.States), 2) 53 | assert.Equal(t, *sm.States["A"].GetType(), "Task") 54 | assert.Equal(t, *sm.States["B"].GetType(), "Task") 55 | 56 | ataskState := sm.States["A"].(*TaskState) 57 | btaskState := sm.States["B"].(*TaskState) 58 | 59 | // ORDER 60 | assert.Equal(t, ataskState.Parameters, map[string]interface{}{"Task": "A", "Input.$": "$"}) 61 | assert.Equal(t, btaskState.Parameters, map[string]interface{}{"Task": "B", "Input.$": "$"}) 62 | } 63 | 64 | func Test_Machine_Parser_FileNonexistantFile(t *testing.T) { 65 | _, err := ParseFile("../examples/non_existent_file.json") 66 | assert.Error(t, err) 67 | } 68 | 69 | func 
Test_Machine_Parser_OfBadStateType(t *testing.T) { 70 | _, err := ParseFile("../examples/bad_type.json") 71 | 72 | assert.Error(t, err) 73 | assert.Regexp(t, "Unknown State", err.Error()) 74 | } 75 | 76 | func Test_Machine_Parser_OfBadPath(t *testing.T) { 77 | _, err := ParseFile("../examples/bad_path.json") 78 | 79 | assert.Error(t, err) 80 | assert.Regexp(t, "Bad JSON path", err.Error()) 81 | } 82 | 83 | // BASIC TYPE TESTS 84 | 85 | func Test_Machine_Parser_AllTypes(t *testing.T) { 86 | sm, err := ParseFile("../examples/all_types.json") 87 | assert.NoError(t, err) 88 | 89 | assert.NoError(t, sm.Validate()) 90 | } 91 | 92 | func Test_Machine_Parser_BasicPass(t *testing.T) { 93 | sm, err := ParseFile("../examples/basic_pass.json") 94 | assert.NoError(t, err) 95 | assert.NoError(t, sm.Validate()) 96 | } 97 | 98 | func Test_Machine_Parser_BasicChoice(t *testing.T) { 99 | sm, err := ParseFile("../examples/basic_choice.json") 100 | 101 | assert.Equal(t, err, nil) 102 | assert.NoError(t, sm.Validate()) 103 | } 104 | 105 | func Test_Machine_Parser_TaskFn(t *testing.T) { 106 | sm, err := ParseFile("../examples/taskfn.json") 107 | 108 | assert.Equal(t, err, nil) 109 | assert.NoError(t, sm.Validate()) 110 | } 111 | 112 | func Test_Machine_Parser_Map(t *testing.T) { 113 | sm, err := ParseFile("../examples/map.json") 114 | var mapState *MapState 115 | mapState = sm.States["Start"].(*MapState) 116 | assert.Equal(t, err, nil) 117 | assert.NoError(t, sm.Validate()) 118 | assert.Equal(t, "$.detail", mapState.InputPath.String(), ) 119 | assert.Equal(t, "$.shipped", mapState.ItemsPath.String(), ) 120 | assert.Equal(t, "$.detail.shipped", mapState.ResultPath.String(), ) 121 | assert.Equal(t, 1, len(mapState.Iterator.States)) 122 | assert.Equal(t, "Task", *mapState.Iterator.States["Validate"].GetType(), ) 123 | 124 | } 125 | -------------------------------------------------------------------------------- /machine/pass_state.go: 
-------------------------------------------------------------------------------- 1 | package machine 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/coinbase/step/jsonpath" 8 | "github.com/coinbase/step/utils/to" 9 | ) 10 | 11 | type PassState struct { 12 | stateStr // Include Defaults 13 | 14 | Type *string 15 | Comment *string `json:",omitempty"` 16 | 17 | InputPath *jsonpath.Path `json:",omitempty"` 18 | OutputPath *jsonpath.Path `json:",omitempty"` 19 | ResultPath *jsonpath.Path `json:",omitempty"` 20 | 21 | Result interface{} `json:",omitempty"` 22 | 23 | Next *string `json:",omitempty"` 24 | End *bool `json:",omitempty"` 25 | } 26 | 27 | func (s *PassState) Execute(ctx context.Context, input interface{}) (output interface{}, next *string, err error) { 28 | return processError(s, 29 | inputOutput( 30 | s.InputPath, 31 | s.OutputPath, 32 | result(s.ResultPath, s.process), 33 | ), 34 | )(ctx, input) 35 | } 36 | 37 | func (s *PassState) process(ctx context.Context, input interface{}) (output interface{}, next *string, err error) { 38 | return s.Result, nextState(s.Next, s.End), nil 39 | } 40 | 41 | func (s *PassState) Validate() error { 42 | s.SetType(to.Strp("Pass")) 43 | 44 | if err := ValidateNameAndType(s); err != nil { 45 | return fmt.Errorf("%v %v", errorPrefix(s), err) 46 | } 47 | 48 | // Next xor End 49 | if err := endValid(s.Next, s.End); err != nil { 50 | return fmt.Errorf("%v %v", errorPrefix(s), err) 51 | } 52 | 53 | return nil 54 | } 55 | 56 | func (s *PassState) SetType(t *string) { 57 | s.Type = t 58 | } 59 | 60 | func (s *PassState) GetType() *string { 61 | return s.Type 62 | } 63 | -------------------------------------------------------------------------------- /machine/pass_state_test.go: -------------------------------------------------------------------------------- 1 | package machine 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/coinbase/step/utils/to" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func 
Test_PassState_Defaults(t *testing.T) { 11 | state := parsePassState([]byte(`{ "Next": "Pass", "End": true}`), t) 12 | err := state.Validate() 13 | assert.Error(t, err) 14 | 15 | assert.Equal(t, *state.GetType(), "Pass") 16 | assert.Equal(t, errorPrefix(state), "PassState(TestState) Error:") 17 | 18 | assert.Regexp(t, "End and Next both defined", err.Error()) 19 | } 20 | 21 | // Validations 22 | 23 | func Test_PassState_EndNextBothDefined(t *testing.T) { 24 | state := parsePassState([]byte(`{ "Next": "Pass", "End": true}`), t) 25 | err := state.Validate() 26 | assert.Error(t, err) 27 | 28 | assert.Regexp(t, "End and Next both defined", err.Error()) 29 | } 30 | 31 | func Test_PassState_EndNextBothUnDefined(t *testing.T) { 32 | state := parsePassState([]byte(`{}`), t) 33 | err := state.Validate() 34 | assert.Error(t, err) 35 | 36 | assert.Regexp(t, "End and Next both undefined", err.Error()) 37 | } 38 | 39 | // Execution 40 | 41 | func Test_PassState_ResultPath(t *testing.T) { 42 | state := parsePassState([]byte(`{ "Next": "Pass", "Result": "b", "ResultPath": "$.a"}`), t) 43 | testState(state, stateTestData{Output: map[string]interface{}{"a": "b"}}, t) 44 | } 45 | 46 | func Test_PassState_ResultPathOverrwite(t *testing.T) { 47 | state := parsePassState([]byte(`{ "Next": "Pass", "Result": "b", "ResultPath": "$.a"}`), t) 48 | testState(state, stateTestData{ 49 | Input: map[string]interface{}{"a": "c"}, 50 | Output: map[string]interface{}{"a": "b"}, 51 | }, t) 52 | } 53 | 54 | func Test_PassState_InputPath(t *testing.T) { 55 | state := parsePassState([]byte(`{"Next": "Pass", "InputPath": "$.a"}`), t) 56 | 57 | deep := map[string]interface{}{"a": "b"} 58 | input := map[string]interface{}{"a": deep} 59 | 60 | testState(state, stateTestData{ 61 | Input: input, 62 | Output: deep, 63 | }, t) 64 | } 65 | 66 | func Test_PassState_OutputPath(t *testing.T) { 67 | state := parsePassState([]byte(`{ "Next": "Pass", "OutputPath": "$.a"}`), t) 68 | 69 | deep := 
map[string]interface{}{"a": "b"} 70 | input := map[string]interface{}{"a": deep} 71 | 72 | testState(state, stateTestData{ 73 | Input: input, 74 | Output: deep, 75 | }, t) 76 | } 77 | 78 | // Bad Execution 79 | 80 | func Test_PassState_BadInputPath(t *testing.T) { 81 | state := parsePassState([]byte(`{"Next": "Pass","InputPath": "$.a.b"}`), t) 82 | 83 | testState(state, stateTestData{ 84 | Input: map[string]interface{}{"a": "b"}, 85 | Error: to.Strp("Input Error"), 86 | }, t) 87 | } 88 | 89 | func Test_PassState_BadOutputPath(t *testing.T) { 90 | state := parsePassState([]byte(`{"Next": "Pass","OutputPath": "$.a.b"}`), t) 91 | 92 | testState(state, stateTestData{ 93 | Input: map[string]interface{}{"a": "b"}, 94 | Error: to.Strp("Output Error"), 95 | }, t) 96 | } 97 | -------------------------------------------------------------------------------- /machine/state_test.go: -------------------------------------------------------------------------------- 1 | package machine 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | "github.com/coinbase/step/utils/to" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type stateTestData struct { 12 | Input map[string]interface{} 13 | Output map[string]interface{} 14 | Error *string 15 | Next *string 16 | } 17 | 18 | func testState(state State, std stateTestData, t *testing.T) { 19 | // Make sure the execution is on Valid State 20 | err := state.Validate() 21 | assert.NoError(t, err) 22 | 23 | // default empty input 24 | if std.Input == nil { 25 | std.Input = map[string]interface{}{} 26 | } 27 | 28 | output, next, err := state.Execute(nil, std.Input) 29 | 30 | // expecting error? 
31 | if std.Error != nil { 32 | assert.Error(t, err) 33 | assert.Regexp(t, *std.Error, err.Error()) 34 | } else if err != nil { 35 | assert.NoError(t, err) 36 | } 37 | 38 | if std.Output != nil { 39 | assert.Equal(t, std.Output, output) 40 | } 41 | 42 | if std.Next != nil { 43 | assert.Equal(t, *std.Next, *next) 44 | } 45 | } 46 | 47 | func parseChoiceState(b []byte, t *testing.T) *ChoiceState { 48 | var p ChoiceState 49 | err := json.Unmarshal(b, &p) 50 | assert.NoError(t, err) 51 | p.SetName(to.Strp("TestState")) 52 | p.SetType(to.Strp("Choice")) 53 | return &p 54 | } 55 | 56 | func parsePassState(b []byte, t *testing.T) *PassState { 57 | var p PassState 58 | err := json.Unmarshal(b, &p) 59 | assert.NoError(t, err) 60 | p.SetName(to.Strp("TestState")) 61 | p.SetType(to.Strp("Pass")) 62 | return &p 63 | } 64 | 65 | func parseWaitState(b []byte, t *testing.T) *WaitState { 66 | var p WaitState 67 | err := json.Unmarshal(b, &p) 68 | assert.NoError(t, err) 69 | p.SetName(to.Strp("TestState")) 70 | p.SetType(to.Strp("Wait")) 71 | return &p 72 | } 73 | 74 | func parseTaskState(b []byte, t *testing.T) *TaskState { 75 | var p TaskState 76 | err := json.Unmarshal(b, &p) 77 | assert.NoError(t, err) 78 | p.SetName(to.Strp("TestState")) 79 | p.SetType(to.Strp("Task")) 80 | return &p 81 | } 82 | 83 | func parseValidTaskState(b []byte, handler interface{}, t *testing.T) *TaskState { 84 | state := parseTaskState(b, t) 85 | state.SetTaskHandler(handler) 86 | assert.NoError(t, state.Validate()) 87 | return state 88 | } 89 | 90 | func parseMapState(b []byte, t *testing.T) *MapState { 91 | var p MapState 92 | err := json.Unmarshal(b, &p) 93 | assert.NoError(t, err) 94 | p.SetName(to.Strp("TestState")) 95 | p.SetType(to.Strp("Map")) 96 | return &p 97 | } 98 | -------------------------------------------------------------------------------- /machine/succeed_state.go: -------------------------------------------------------------------------------- 1 | package machine 2 | 3 | import ( 4 
| "context" 5 | "fmt" 6 | 7 | "github.com/coinbase/step/jsonpath" 8 | "github.com/coinbase/step/utils/to" 9 | ) 10 | 11 | type SucceedState struct { 12 | stateStr // Include Defaults 13 | 14 | Type *string 15 | Comment *string `json:",omitempty"` 16 | 17 | InputPath *jsonpath.Path `json:",omitempty"` 18 | OutputPath *jsonpath.Path `json:",omitempty"` 19 | } 20 | 21 | func (s *SucceedState) process(ctx context.Context, input interface{}) (interface{}, *string, error) { 22 | return input, nil, nil 23 | } 24 | 25 | func (s *SucceedState) Execute(ctx context.Context, input interface{}) (output interface{}, next *string, err error) { 26 | return processError(s, 27 | inputOutput( 28 | s.InputPath, 29 | s.OutputPath, 30 | s.process, 31 | ), 32 | )(ctx, input) 33 | } 34 | 35 | func (s *SucceedState) Validate() error { 36 | s.SetType(to.Strp("Succeed")) 37 | 38 | if err := ValidateNameAndType(s); err != nil { 39 | return fmt.Errorf("%v %v", errorPrefix(s), err) 40 | } 41 | 42 | return nil 43 | } 44 | 45 | func (s *SucceedState) SetType(t *string) { 46 | s.Type = t 47 | } 48 | 49 | func (s *SucceedState) GetType() *string { 50 | return s.Type 51 | } 52 | -------------------------------------------------------------------------------- /machine/task_state.go: -------------------------------------------------------------------------------- 1 | package machine 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/coinbase/step/handler" 8 | "github.com/coinbase/step/jsonpath" 9 | "github.com/coinbase/step/utils/to" 10 | ) 11 | 12 | type TaskState struct { 13 | stateStr // Include Defaults 14 | 15 | Type *string 16 | Comment *string `json:",omitempty"` 17 | 18 | InputPath *jsonpath.Path `json:",omitempty"` 19 | OutputPath *jsonpath.Path `json:",omitempty"` 20 | ResultPath *jsonpath.Path `json:",omitempty"` 21 | Parameters interface{} `json:",omitempty"` 22 | 23 | Resource *string `json:",omitempty"` 24 | 25 | Catch []*Catcher `json:",omitempty"` 26 | Retry []*Retrier 
`json:",omitempty"` 27 | 28 | // Maps a Lambda Handler Function 29 | TaskHandler interface{} `json:"-"` 30 | 31 | Next *string `json:",omitempty"` 32 | End *bool `json:",omitempty"` 33 | 34 | TimeoutSeconds int `json:",omitempty"` 35 | HeartbeatSeconds int `json:",omitempty"` 36 | } 37 | 38 | func (s *TaskState) SetTaskHandler(resourcefn interface{}) { 39 | s.TaskHandler = resourcefn 40 | } 41 | 42 | func (s *TaskState) process(ctx context.Context, input interface{}) (interface{}, *string, error) { 43 | result, err := handler.CallHandlerFunction(s.TaskHandler, ctx, input) 44 | 45 | if err != nil { 46 | return nil, nil, err 47 | } 48 | 49 | result, err = to.FromJSON(result) 50 | 51 | if err != nil { 52 | return nil, nil, err 53 | } 54 | 55 | return result, nextState(s.Next, s.End), nil 56 | } 57 | 58 | // Input must include the Task name in $.Task 59 | func (s *TaskState) Execute(ctx context.Context, input interface{}) (output interface{}, next *string, err error) { 60 | return processError(s, 61 | processCatcher(s.Catch, 62 | processRetrier(s.Name(), s.Retry, 63 | inputOutput( 64 | s.InputPath, 65 | s.OutputPath, 66 | withParams( 67 | s.Parameters, 68 | result(s.ResultPath, s.process), 69 | ), 70 | ), 71 | ), 72 | ), 73 | )(ctx, input) 74 | } 75 | 76 | func (s *TaskState) Validate() error { 77 | s.SetType(to.Strp("Task")) 78 | 79 | if err := ValidateNameAndType(s); err != nil { 80 | return fmt.Errorf("%v %v", errorPrefix(s), err) 81 | } 82 | 83 | if err := endValid(s.Next, s.End); err != nil { 84 | return fmt.Errorf("%v %v", errorPrefix(s), err) 85 | } 86 | 87 | if s.Resource == nil { 88 | return fmt.Errorf("%v Requires Resource", errorPrefix(s)) 89 | } 90 | 91 | if s.TaskHandler != nil { 92 | if err := handler.ValidateHandler(s.TaskHandler); err != nil { 93 | return err 94 | } 95 | } 96 | 97 | if err := catchValid(s.Catch); err != nil { 98 | return err 99 | } 100 | 101 | if err := retryValid(s.Retry); err != nil { 102 | return err 103 | } 104 | 105 | return nil 
106 | } 107 | 108 | func (s *TaskState) SetType(t *string) { 109 | s.Type = t 110 | } 111 | 112 | func (s *TaskState) GetType() *string { 113 | return s.Type 114 | } 115 | -------------------------------------------------------------------------------- /machine/task_state_test.go: -------------------------------------------------------------------------------- 1 | package machine 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/coinbase/step/utils/to" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | ///////// 12 | // TYPES 13 | ///////// 14 | 15 | type TestError struct{} 16 | 17 | func (t *TestError) Error() string { 18 | return "This is a Test Error" 19 | } 20 | 21 | type TestHandler func(context.Context, interface{}) (interface{}, error) 22 | 23 | func countCalls(th TestHandler) (TestHandler, *int) { 24 | calls := 0 25 | return func(ctx context.Context, input interface{}) (interface{}, error) { 26 | calls++ 27 | return th(ctx, input) 28 | }, &calls 29 | } 30 | 31 | func ThrowTestErrorHandler(_ context.Context, input interface{}) (interface{}, error) { 32 | return nil, &TestError{} 33 | } 34 | 35 | func ReturnMapTestHandler(_ context.Context, input interface{}) (interface{}, error) { 36 | return map[string]interface{}{"z": "y"}, nil 37 | } 38 | 39 | func ReturnInputHandler(_ context.Context, input interface{}) (interface{}, error) { 40 | return input, nil 41 | } 42 | 43 | // Execution 44 | 45 | func Test_TaskState_ValidateResource(t *testing.T) { 46 | state := parseTaskState([]byte(`{ "Next": "Pass"}`), t) 47 | assert.Error(t, state.Validate()) 48 | state.Resource = to.Strp("resource") 49 | assert.NoError(t, state.Validate()) 50 | } 51 | 52 | func Test_TaskState_Valid_ErrorEquals_StatesAll(t *testing.T) { 53 | state := parseTaskState([]byte(`{ 54 | "Resource": "asd", 55 | "Next": "Pass", 56 | "Retry": [{ "ErrorEquals": ["States.ALL"] }] 57 | }`), t) 58 | 59 | assert.NoError(t, state.Validate()) 60 | 61 | state = parseTaskState([]byte(`{ 62 
| "Resource": "asd", 63 | "Next": "Pass", 64 | "Retry": [{ "ErrorEquals": ["States.ALL", "NoMoreErrors"] }] 65 | }`), t) 66 | assert.Error(t, state.Validate()) 67 | 68 | state = parseTaskState([]byte(`{ 69 | "Resource": "asd", 70 | "Next": "Pass", 71 | "Retry": [{ "ErrorEquals": ["States.ALL"] }, { "ErrorEquals": ["NotLast"] }] 72 | }`), t) 73 | 74 | state = parseTaskState([]byte(`{ 75 | "Resource": "asd", 76 | "Next": "Pass", 77 | "Retry": [{ "ErrorEquals": ["States.NotRealError"] }] 78 | }`), t) 79 | 80 | assert.Error(t, state.Validate()) 81 | } 82 | 83 | func Test_TaskState_TaskHandler(t *testing.T) { 84 | th, calls := countCalls(ReturnMapTestHandler) 85 | 86 | state := parseValidTaskState([]byte(`{ "Next": "Pass", "Resource": "test"}`), th, t) 87 | 88 | testState(state, stateTestData{ 89 | Input: map[string]interface{}{"a": "c"}, 90 | Output: map[string]interface{}{"z": "y"}, 91 | }, t) 92 | 93 | assert.Equal(t, 1, *calls) 94 | } 95 | 96 | func Test_TaskState_Catch_Works(t *testing.T) { 97 | state := parseValidTaskState([]byte(`{ 98 | "Next": "Pass", 99 | "Resource": "test", 100 | "Catch": [{ 101 | "ErrorEquals": ["TestError"], 102 | "Next": "Fail" 103 | }] 104 | }`), ThrowTestErrorHandler, t) 105 | 106 | testState(state, stateTestData{ 107 | Input: map[string]interface{}{"a": "c"}, 108 | Output: map[string]interface{}{"Error": "TestError", "Cause": "This is a Test Error"}, 109 | Next: to.Strp("Fail"), 110 | }, t) 111 | } 112 | 113 | func Test_TaskState_Catch_Doesnt_Catch(t *testing.T) { 114 | state := parseValidTaskState([]byte(`{ 115 | "Next": "Pass", 116 | "Resource": "test", 117 | "Catch": [{ 118 | "ErrorEquals": ["NotTestError"], 119 | "Next": "Fail" 120 | }] 121 | }`), ThrowTestErrorHandler, t) 122 | 123 | testState(state, stateTestData{ 124 | Input: map[string]interface{}{"a": "c"}, 125 | Error: to.Strp("This is a Test Error"), 126 | }, t) 127 | } 128 | 129 | func Test_TaskState_Retry_Works(t *testing.T) { 130 | th, calls := 
countCalls(ThrowTestErrorHandler) 131 | 132 | state := parseValidTaskState([]byte(`{ 133 | "Next": "Pass", 134 | "Resource": "test", 135 | "Retry": [{ 136 | "ErrorEquals": ["TestError"], 137 | "MaxAttempts": 2 138 | }] 139 | }`), th, t) 140 | 141 | testState(state, stateTestData{ 142 | Input: map[string]interface{}{"a": "c"}, 143 | Next: state.Name(), 144 | }, t) 145 | 146 | testState(state, stateTestData{ 147 | Input: map[string]interface{}{"a": "c"}, 148 | Next: state.Name(), 149 | }, t) 150 | 151 | testState(state, stateTestData{ 152 | Input: map[string]interface{}{"a": "c"}, 153 | Error: to.Strp("This is a Test Error"), 154 | }, t) 155 | 156 | // 1 initial call, + 2 retries 157 | assert.Equal(t, 3, *calls) 158 | } 159 | 160 | func Test_TaskState_Catch_AND_Retry_Works(t *testing.T) { 161 | th, calls := countCalls(ThrowTestErrorHandler) 162 | 163 | state := parseValidTaskState([]byte(`{ 164 | "Next": "Pass", 165 | "Resource": "test", 166 | "Retry": [{ 167 | "ErrorEquals": ["TestError"], 168 | "MaxAttempts": 1 169 | }], 170 | "Catch": [{ 171 | "ErrorEquals": ["TestError"], 172 | "Next": "Fail" 173 | }] 174 | }`), th, t) 175 | 176 | testState(state, stateTestData{ 177 | Input: map[string]interface{}{"a": "c"}, 178 | Next: state.Name(), 179 | }, t) 180 | 181 | testState(state, stateTestData{ 182 | Input: map[string]interface{}{"a": "c"}, 183 | Next: to.Strp("Fail"), 184 | }, t) 185 | 186 | assert.Equal(t, 2, *calls) 187 | } 188 | 189 | func Test_TaskState_Catch_AND_Retry_StateAll(t *testing.T) { 190 | th, calls := countCalls(ThrowTestErrorHandler) 191 | 192 | state := parseValidTaskState([]byte(`{ 193 | "Next": "Pass", 194 | "Resource": "test", 195 | "Retry": [{ 196 | "ErrorEquals": ["States.ALL"], 197 | "MaxAttempts": 1 198 | }], 199 | "Catch": [{ 200 | "ErrorEquals": ["States.ALL"], 201 | "Next": "Fail" 202 | }] 203 | }`), th, t) 204 | 205 | testState(state, stateTestData{ 206 | Input: map[string]interface{}{"a": "c"}, 207 | Next: state.Name(), 208 | }, t) 209 | 
210 | testState(state, stateTestData{ 211 | Input: map[string]interface{}{"a": "c"}, 212 | Next: to.Strp("Fail"), 213 | }, t) 214 | 215 | assert.Equal(t, 2, *calls) 216 | } 217 | 218 | func Test_TaskState_Catch_AND_Dont_Retry(t *testing.T) { 219 | th, calls := countCalls(ThrowTestErrorHandler) 220 | 221 | state := parseValidTaskState([]byte(`{ 222 | "Next": "Pass", 223 | "Resource": "test", 224 | "Retry": [{ 225 | "ErrorEquals": ["TestError"], 226 | "MaxAttempts": 1 227 | },{ 228 | "ErrorEquals": ["States.ALL"] 229 | }], 230 | "Catch": [{ 231 | "ErrorEquals": ["States.ALL"], 232 | "Next": "Fail" 233 | }] 234 | }`), th, t) 235 | 236 | testState(state, stateTestData{ 237 | Input: map[string]interface{}{"a": "c"}, 238 | Next: state.Name(), 239 | }, t) 240 | 241 | testState(state, stateTestData{ 242 | Input: map[string]interface{}{"a": "c"}, 243 | Next: to.Strp("Fail"), 244 | }, t) 245 | 246 | assert.Equal(t, 2, *calls) 247 | } 248 | 249 | func Test_TaskState_Parameters(t *testing.T) { 250 | state := parseValidTaskState([]byte(`{ 251 | "Next": "Pass", 252 | "Resource": "test", 253 | "Parameters": {"Task": "Noop", "Input.$": "$.x"} 254 | }`), ReturnInputHandler, t) 255 | 256 | testState(state, stateTestData{ 257 | Input: map[string]interface{}{"x": "AHAH"}, 258 | Output: map[string]interface{}{"Task": "Noop", "Input": "AHAH"}, 259 | }, t) 260 | } 261 | 262 | func Test_TaskState_InputPath_and_Parameters(t *testing.T) { 263 | state := parseValidTaskState([]byte(`{ 264 | "Next": "Pass", 265 | "Resource": "test", 266 | "InputPath": "$.x", 267 | "Parameters": {"Task": "Noop", "Input.$": "$"} 268 | }`), ReturnInputHandler, t) 269 | 270 | testState(state, stateTestData{ 271 | Input: map[string]interface{}{"x": "AHAH"}, 272 | Output: map[string]interface{}{"Task": "Noop", "Input": "AHAH"}, 273 | }, t) 274 | } 275 | -------------------------------------------------------------------------------- /machine/wait_state.go: 
-------------------------------------------------------------------------------- 1 | package machine 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/coinbase/step/jsonpath" 9 | "github.com/coinbase/step/utils/to" 10 | ) 11 | 12 | type WaitState struct { 13 | stateStr // Include Defaults 14 | 15 | Type *string 16 | Comment *string `json:",omitempty"` 17 | 18 | InputPath *jsonpath.Path `json:",omitempty"` 19 | OutputPath *jsonpath.Path `json:",omitempty"` 20 | 21 | Seconds *float64 `json:",omitempty"` 22 | SecondsPath *jsonpath.Path `json:",omitempty"` 23 | 24 | Timestamp *time.Time `json:",omitempty"` 25 | TimestampPath *jsonpath.Path `json:",omitempty"` 26 | 27 | Next *string `json:",omitempty"` 28 | End *bool `json:",omitempty"` 29 | } 30 | 31 | func (s *WaitState) process(ctx context.Context, input interface{}) (interface{}, *string, error) { 32 | 33 | if s.SecondsPath != nil { 34 | // Validate the path exists 35 | _, err := s.SecondsPath.GetNumber(input) 36 | if err != nil { 37 | return nil, nil, err 38 | } 39 | 40 | } else if s.TimestampPath != nil { 41 | // Validate the path exists 42 | _, err := s.TimestampPath.GetTime(input) 43 | if err != nil { 44 | return nil, nil, err 45 | } 46 | } 47 | 48 | // Always sleep the same amount of time, as this is a simulation 49 | time.Sleep(50 * time.Millisecond) 50 | 51 | return input, nextState(s.Next, s.End), nil 52 | } 53 | 54 | func (s *WaitState) Execute(ctx context.Context, input interface{}) (output interface{}, next *string, err error) { 55 | return processError(s, 56 | inputOutput( 57 | s.InputPath, 58 | s.OutputPath, 59 | s.process, 60 | ), 61 | )(ctx, input) 62 | } 63 | 64 | func (s *WaitState) Validate() error { 65 | s.SetType(to.Strp("Wait")) 66 | 67 | if err := ValidateNameAndType(s); err != nil { 68 | return fmt.Errorf("%v %v", errorPrefix(s), err) 69 | } 70 | 71 | // Next xor End 72 | if err := endValid(s.Next, s.End); err != nil { 73 | return fmt.Errorf("%v %v", errorPrefix(s), err) 
74 | } 75 | 76 | exactly_one := []bool{ 77 | s.Seconds != nil, 78 | s.SecondsPath != nil, 79 | s.Timestamp != nil, 80 | s.TimestampPath != nil, 81 | } 82 | 83 | count := 0 84 | for _, c := range exactly_one { 85 | if c { 86 | count += 1 87 | } 88 | } 89 | 90 | if count != 1 { 91 | return fmt.Errorf("%v Exactly One (Seconds,SecondsPath,TimeStamp,TimeStampPath)", errorPrefix(s)) 92 | } 93 | 94 | return nil 95 | } 96 | 97 | func (s *WaitState) SetType(t *string) { 98 | s.Type = t 99 | } 100 | 101 | func (s *WaitState) GetType() *string { 102 | return s.Type 103 | } 104 | -------------------------------------------------------------------------------- /machine/wait_state_test.go: -------------------------------------------------------------------------------- 1 | package machine 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func Test_WaitState_XORofFields(t *testing.T) { 10 | state := parseWaitState([]byte(` 11 | { 12 | "Seconds": 10, 13 | "TimestampPath": "$.a.b", 14 | "Timestamp": "2006-01-02T15:04:05Z", 15 | "Next": "Public" 16 | }`), t) 17 | 18 | err := state.Validate() 19 | assert.Error(t, err) 20 | 21 | assert.Regexp(t, "Exactly One", err.Error()) 22 | } 23 | 24 | func Test_WaitState_SecondsPath(t *testing.T) { 25 | state := parseWaitState([]byte(` 26 | { 27 | "SecondsPath": "$.path", 28 | "Next": "Public" 29 | }`), t) 30 | 31 | _, _, err := state.Execute(nil, map[string]interface{}{"path": 30}) 32 | assert.NoError(t, err) 33 | 34 | _, _, err = state.Execute(nil, map[string]interface{}{}) 35 | assert.Error(t, err) 36 | } 37 | -------------------------------------------------------------------------------- /resources/empty_lambda.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coinbase/step/301282845bfb07879a39d0c2af36720633a61609/resources/empty_lambda.zip -------------------------------------------------------------------------------- 
# GeoEngineer resources for the Step Function deployer.
# Apply with:
#   GEO_ENV=development bundle exec geo apply resources/step-deployer.rb

# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------

# ENV.fetch raises if the variable is not set, so both are required.
env = environment('development') do
  region ENV.fetch('AWS_REGION')
  account_id ENV.fetch('AWS_ACCOUNT_ID')
end

# ---------------------------------------------------------------------------
# Project
# ---------------------------------------------------------------------------

project = project('coinbase', 'step-deployer') do
  environments 'development'
  tags do
    ProjectName "coinbase/step-deployer"
    ConfigName "development"
    DeployWith "step-deployer"
    self[:org] = "coinbase"
    self[:project] = "step-deployer"
  end
end

# Values shared by the lambda policy template and the assumed-role template.
context = {
  assumed_role_name: "coinbase-step-deployer-assumed",
  assumable_from: [ ENV['AWS_ACCOUNT_ID'] ],
  assumed_policy_file: "#{__dir__}/step_assumed_policy.json.erb"
}

deployer_options = {
  lambda_policy_file: "#{__dir__}/step_lambda_policy.json.erb",
  lambda_policy_context: context
}
project.from_template('bifrost_deployer', 'step-deployer', deployer_options)

# The assumed role exists in all environments
project.from_template('step_assumed', 'coinbase-step-deployer-assumed', context)
#!/bin/bash
# Shortcut script to bootstrap the step deployer.
# Bootstrapping is deploying itself from a local environment.
set -e

# Remove the build artifact however the script ends, so a failed bootstrap
# does not leave lambda.zip behind.
trap 'rm -f lambda.zip' EXIT

./scripts/build_lambda_zip

# Run these as separate commands: in `go build && go install` a failing
# `go build` is exempt from `set -e`, and the script would carry on with a
# stale `step` binary.
go build
go install

# Capture the state machine JSON first. A failure inside a command
# substitution used directly as an argument does not trip `set -e`, whereas a
# failing assignment does.
states="$(step json)"

step bootstrap \
  -lambda "coinbase-step-deployer" \
  -step "coinbase-step-deployer" \
  -states "$states" \
  -project "coinbase/step-deployer" \
  -config "development"
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # shortcut script to Deploy the step deployer 3 | # So this uses the Step executable to call out to the Step Deployer 4 | # which is a Step Function to deploy itself 5 | set -e 6 | 7 | ./scripts/build_lambda_zip 8 | 9 | go build && go install 10 | step deploy \ 11 | -lambda "coinbase-step-deployer" \ 12 | -step "coinbase-step-deployer" \ 13 | -states "$(step json)" \ 14 | -project "coinbase/step-deployer"\ 15 | -config "development" 16 | 17 | rm lambda.zip 18 | -------------------------------------------------------------------------------- /step.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "os" 7 | "time" 8 | 9 | "github.com/coinbase/step/machine" 10 | 11 | "github.com/coinbase/step/bifrost" 12 | "github.com/coinbase/step/client" 13 | "github.com/coinbase/step/deployer" 14 | "github.com/coinbase/step/utils/run" 15 | "github.com/coinbase/step/utils/to" 16 | ) 17 | 18 | func main() { 19 | default_name := "coinbase-step-deployer" 20 | region, account_id := to.RegionAccount() 21 | def_step_arn := to.Strp("") 22 | if region != nil && account_id != nil { 23 | def_step_arn = to.StepArn(region, account_id, &default_name) 24 | } 25 | 26 | // Step Subcommands 27 | jsonCommand := flag.NewFlagSet("json", flag.ExitOnError) 28 | 29 | dotCommand := flag.NewFlagSet("dot", flag.ExitOnError) 30 | dotStates := dotCommand.String("states", "{}", "State Machine JSON") 31 | 32 | // Other Subcommands 33 | bootstrapCommand := flag.NewFlagSet("bootstrap", flag.ExitOnError) 34 | deployCommand := flag.NewFlagSet("deploy", flag.ExitOnError) 35 | 36 | // bootstrap args 37 | bootstrapStates := bootstrapCommand.String("states", "{}", "State Machine JSON") 38 | bootstrapLambda := bootstrapCommand.String("lambda", "", "lambda name or arn") 39 | bootstrapStep := bootstrapCommand.String("step", "", 
"step function name or arn") 40 | bootstrapBucket := bootstrapCommand.String("bucket", "", "s3 bucket to upload release to") 41 | bootstrapZip := bootstrapCommand.String("zip", "lambda.zip", "zip of lambda") 42 | bootstrapProject := bootstrapCommand.String("project", "", "project name") 43 | bootstrapConfig := bootstrapCommand.String("config", "", "config name") 44 | bootstrapRegion := bootstrapCommand.String("region", "", "AWS region") 45 | bootstrapAccount := bootstrapCommand.String("account", "", "AWS account id") 46 | 47 | // deploy args 48 | deployStates := deployCommand.String("states", "{}", "State Machine JSON") 49 | deployLambda := deployCommand.String("lambda", "", "lambda name or arn") 50 | deployStep := deployCommand.String("step", "", "step function name or arn") 51 | deployBucket := deployCommand.String("bucket", "", "s3 bucket to upload release to") 52 | deployDeployer := deployCommand.String("deployer", *def_step_arn, "step function deployer name or arn") 53 | deployZip := deployCommand.String("zip", "lambda.zip", "zip of lambda") 54 | deployProject := deployCommand.String("project", "", "project name") 55 | deployConfig := deployCommand.String("config", "", "config name") 56 | deployRegion := deployCommand.String("region", "", "AWS region") 57 | deployAccount := deployCommand.String("account", "", "AWS account id") 58 | 59 | // By Default Run Lambda Function 60 | if len(os.Args) == 1 { 61 | fmt.Println("Starting Lambda") 62 | run.LambdaTasks(deployer.TaskHandlers()) 63 | } 64 | 65 | switch os.Args[1] { 66 | case "json": 67 | jsonCommand.Parse(os.Args[2:]) 68 | case "dot": 69 | dotCommand.Parse(os.Args[2:]) 70 | case "bootstrap": 71 | bootstrapCommand.Parse(os.Args[2:]) 72 | case "deploy": 73 | deployCommand.Parse(os.Args[2:]) 74 | default: 75 | fmt.Println("Usage of step: step (No args starts Lambda)") 76 | fmt.Println("json") 77 | jsonCommand.PrintDefaults() 78 | fmt.Println("dot") 79 | dotCommand.PrintDefaults() 80 | fmt.Println("bootstrap") 81 | 
bootstrapCommand.PrintDefaults() 82 | fmt.Println("deploy") 83 | deployCommand.PrintDefaults() 84 | os.Exit(1) 85 | } 86 | 87 | // Create the State machine 88 | if jsonCommand.Parsed() { 89 | run.JSON(deployer.StateMachine()) 90 | } else if dotCommand.Parsed() { 91 | run.Dot(machine.FromJSON([]byte(*dotStates))) 92 | } else if bootstrapCommand.Parsed() { 93 | r := newRelease( 94 | bootstrapProject, 95 | bootstrapConfig, 96 | bootstrapLambda, 97 | bootstrapStep, 98 | bootstrapBucket, 99 | bootstrapStates, 100 | bootstrapRegion, 101 | bootstrapAccount, 102 | ) 103 | bootstrapRun(r, bootstrapZip) 104 | 105 | } else if deployCommand.Parsed() { 106 | region, account_id := to.RegionAccountOrExit() 107 | r := newRelease( 108 | deployProject, 109 | deployConfig, 110 | deployLambda, 111 | deployStep, 112 | deployBucket, 113 | deployStates, 114 | deployRegion, 115 | deployAccount, 116 | ) 117 | arn := to.StepArn(region, account_id, deployDeployer) 118 | deployRun(r, deployZip, arn) 119 | } else { 120 | fmt.Println("ERROR: Command Line Not Parsed") 121 | os.Exit(1) 122 | } 123 | } 124 | 125 | func check(err error) { 126 | if err == nil { 127 | return 128 | } 129 | fmt.Println("ERROR", err) 130 | os.Exit(1) 131 | } 132 | 133 | func bootstrapRun(release *deployer.Release, zip *string) { 134 | err := client.Bootstrap(release, zip) 135 | check(err) 136 | } 137 | 138 | func deployRun(release *deployer.Release, zip *string, deployer_arn *string) { 139 | err := client.Deploy(release, zip, deployer_arn) 140 | check(err) 141 | } 142 | 143 | func newRelease(project *string, config *string, lambda *string, step *string, bucket *string, states *string, region *string, account_id *string) *deployer.Release { 144 | return &deployer.Release{ 145 | Release: bifrost.Release{ 146 | AwsRegion: region, 147 | AwsAccountID: account_id, 148 | ReleaseID: to.TimeUUID("release-"), 149 | CreatedAt: to.Timep(time.Now()), 150 | ProjectName: project, 151 | ConfigName: config, 152 | Bucket: bucket, 153 | 
// WithinTimeFrame reports whether tt lies strictly after now-diff_back and
// strictly before now+diff_forward.
//
// A nil tt is never inside the window, so it yields false instead of
// panicking. Both bounds are derived from a single clock reading so they
// describe one consistent window.
func WithinTimeFrame(tt *time.Time, diff_back time.Duration, diff_forward time.Duration) bool {
	if tt == nil {
		return false
	}

	now := time.Now()
	earliest := now.Add(-diff_back)
	latest := now.Add(diff_forward)

	return tt.After(earliest) && tt.Before(latest)
}
assert.False(t, WithinTimeFrame(to.Timep(time.Now().Add(10*time.Minute)), 10*time.Second, 10*time.Second)) 24 | assert.False(t, WithinTimeFrame(to.Timep(time.Now().Add(-10*time.Minute)), 10*time.Second, 10*time.Second)) 25 | } 26 | -------------------------------------------------------------------------------- /utils/run/dot.go: -------------------------------------------------------------------------------- 1 | package run 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "strings" 7 | 8 | "github.com/coinbase/step/machine" 9 | "github.com/coinbase/step/utils/to" 10 | ) 11 | 12 | // Output Dot Format For State Machine 13 | 14 | // JSON prints a state machine as JSON 15 | func Dot(stateMachine *machine.StateMachine, err error) { 16 | if err != nil { 17 | fmt.Println("ERROR", err) 18 | os.Exit(1) 19 | } 20 | 21 | dotStr := toDot(stateMachine) 22 | fmt.Println(dotStr) 23 | os.Exit(0) 24 | } 25 | 26 | func toDot(stateMachine *machine.StateMachine) string { 27 | return fmt.Sprintf(`digraph StateMachine { 28 | node [style="rounded,filled,bold", shape=box, width=2, fontname="Arial" fontcolor="#183153", color="#183153"]; 29 | edge [style=bold, fontname="Arial", fontcolor="#183153", color="#183153"]; 30 | _Start [fillcolor="#183153", shape=circle, label="", width=0.25]; 31 | _End [fillcolor="#183153", shape=doublecircle, label="", width=0.3]; 32 | 33 | _Start -> "%v" [weight=1000]; 34 | %v 35 | }`, *stateMachine.StartAt, processStates(*stateMachine.StartAt, stateMachine.States)) 36 | } 37 | 38 | func processStates(start string, states map[string]machine.State) string { 39 | orderedStates := orderStates(start, states) 40 | 41 | var stateStrings []string 42 | for _, stateNode := range orderedStates { 43 | stateStrings = append(stateStrings, processState(stateNode)) 44 | } 45 | return strings.Join(stateStrings, "\n\n ") 46 | } 47 | 48 | // Order states from start to end consistently to generate deterministic graphs. 
49 | func orderStates(start string, states map[string]machine.State) []machine.State { 50 | var orderedStates []machine.State 51 | startState := states[start] 52 | stateQueue := []machine.State{startState} 53 | seenStates := make(map[string]struct{}) 54 | 55 | for len(stateQueue) > 0 { 56 | var stateNode machine.State 57 | stateNode, stateQueue = stateQueue[0], stateQueue[1:] 58 | 59 | orderedStates = append(orderedStates, stateNode) 60 | 61 | var connectedStates []machine.State 62 | switch stateNode.(type) { 63 | case *machine.PassState: 64 | stateNode := stateNode.(*machine.PassState) 65 | if stateNode.Next != nil { 66 | connectedStates = append(connectedStates, states[*stateNode.Next]) 67 | } 68 | case *machine.TaskState: 69 | stateNode := stateNode.(*machine.TaskState) 70 | 71 | if stateNode.Catch != nil { 72 | for _, catch := range stateNode.Catch { 73 | connectedStates = append(connectedStates, states[*catch.Next]) 74 | } 75 | } 76 | 77 | if stateNode.Next != nil { 78 | connectedStates = append(connectedStates, states[*stateNode.Next]) 79 | } 80 | case *machine.ChoiceState: 81 | stateNode := stateNode.(*machine.ChoiceState) 82 | 83 | if stateNode.Choices != nil { 84 | for _, choice := range stateNode.Choices { 85 | connectedStates = append(connectedStates, states[*choice.Next]) 86 | } 87 | } 88 | case *machine.WaitState: 89 | stateNode := stateNode.(*machine.WaitState) 90 | 91 | if stateNode.Next != nil { 92 | connectedStates = append(connectedStates, states[*stateNode.Next]) 93 | } 94 | } 95 | 96 | for _, connectedState := range connectedStates { 97 | stateName := *connectedState.Name() 98 | if _, seen := seenStates[stateName]; !seen { 99 | stateQueue = append(stateQueue, connectedState) 100 | seenStates[stateName] = struct{}{} 101 | } 102 | } 103 | } 104 | 105 | return orderedStates 106 | } 107 | 108 | func processState(stateNode machine.State) string { 109 | var lines []string 110 | name := *stateNode.Name() 111 | switch stateNode.(type) { 112 | case 
*machine.PassState: 113 | stateNode := stateNode.(*machine.PassState) 114 | lines = append(lines, fmt.Sprintf(`%q [fillcolor="#FBFBFB"];`, name)) 115 | if stateNode.Next != nil { 116 | lines = append(lines, fmt.Sprintf(`%q -> %q [weight=100];`, name, *stateNode.Next)) 117 | } 118 | if stateNode.End != nil { 119 | lines = append(lines, fmt.Sprintf(`%q -> _End;`, name)) 120 | } 121 | case *machine.TaskState: 122 | stateNode := stateNode.(*machine.TaskState) 123 | lines = append(lines, fmt.Sprintf(`%q [fillcolor="#FBFBFB"];`, name)) 124 | 125 | if stateNode.Catch != nil { 126 | for _, catch := range stateNode.Catch { 127 | catchName := fmt.Sprintf("%q", strings.Join(to.StrSlice(catch.ErrorEquals), ",")) 128 | if len(catch.ErrorEquals) == 1 && *catch.ErrorEquals[0] == "States.ALL" { 129 | catchName = "" 130 | } 131 | lines = append(lines, fmt.Sprintf(`%q -> %q [color="#949494", label=%q, style=solid];`, name, *catch.Next, catchName)) 132 | } 133 | } 134 | 135 | if stateNode.Next != nil { 136 | lines = append(lines, fmt.Sprintf(`%q -> %q [weight=100];`, name, *stateNode.Next)) 137 | } 138 | 139 | if stateNode.End != nil { 140 | lines = append(lines, fmt.Sprintf(`%q -> _End;`, name)) 141 | } 142 | case *machine.ChoiceState: 143 | stateNode := stateNode.(*machine.ChoiceState) 144 | lines = append(lines, fmt.Sprintf(`%q [shape=egg, fillcolor="#FBFBFB"];`, name)) 145 | 146 | if stateNode.Choices != nil { 147 | for _, choice := range stateNode.Choices { 148 | lines = append(lines, fmt.Sprintf(`%q -> %q [weight=100];`, name, *choice.Next)) 149 | } 150 | } 151 | case *machine.WaitState: 152 | stateNode := stateNode.(*machine.WaitState) 153 | 154 | lines = append(lines, fmt.Sprintf(`%q [width=0.5, shape=doublecircle, fillcolor="#FBFBFB", label="Wait"];`, name)) 155 | 156 | if stateNode.Next != nil { 157 | lines = append(lines, fmt.Sprintf(`%q -> %q [weight=100];`, name, *stateNode.Next)) 158 | } 159 | case *machine.FailState: 160 | lines = append(lines, fmt.Sprintf(`%q 
[fillcolor="#F9E4D1"];`, name)) 161 | lines = append(lines, fmt.Sprintf(`%q -> _End [weight=1000];`, name)) 162 | case *machine.SucceedState: 163 | lines = append(lines, fmt.Sprintf(`%q [fillcolor="#e5eddb"];`, name)) 164 | lines = append(lines, fmt.Sprintf(`%q -> _End [weight=1000];`, name)) 165 | } 166 | 167 | return strings.Join(lines, "\n ") 168 | } 169 | -------------------------------------------------------------------------------- /utils/run/run.go: -------------------------------------------------------------------------------- 1 | // run takes arguments 2 | package run 3 | 4 | import ( 5 | "fmt" 6 | "os" 7 | 8 | "github.com/aws/aws-lambda-go/lambda" 9 | "github.com/coinbase/step/handler" 10 | "github.com/coinbase/step/machine" 11 | "github.com/coinbase/step/utils/is" 12 | "github.com/coinbase/step/utils/to" 13 | 14 | ddlambda "github.com/DataDog/datadog-lambda-go" 15 | ) 16 | 17 | // Exec returns a function that will execute the state machine 18 | func Exec(state_machine *machine.StateMachine, err error) func(*string) { 19 | if err != nil { 20 | return func(input *string) { 21 | fmt.Println("ERROR", err) 22 | os.Exit(1) 23 | } 24 | } 25 | 26 | return func(input *string) { 27 | 28 | if is.EmptyStr(input) { 29 | input = to.Strp("{}") 30 | } 31 | 32 | exec, err := state_machine.Execute(input) 33 | output_json := exec.OutputJSON 34 | 35 | if err != nil { 36 | fmt.Println("ERROR", err) 37 | os.Exit(1) 38 | } 39 | 40 | fmt.Println(output_json) 41 | os.Exit(0) 42 | } 43 | } 44 | 45 | // JSON prints a state machine as JSON 46 | func JSON(state_machine *machine.StateMachine, err error) { 47 | if err != nil { 48 | fmt.Println("ERROR", err) 49 | os.Exit(1) 50 | } 51 | 52 | json, err := to.PrettyJSON(state_machine) 53 | 54 | if err != nil { 55 | fmt.Println("ERROR", err) 56 | os.Exit(1) 57 | } 58 | 59 | fmt.Println(string(json)) 60 | os.Exit(0) 61 | } 62 | 63 | // LambdaTasks takes task functions and and executes as a lambda 64 | func LambdaTasks(task_functions 
// InterpolateArnVariables replaces the resource parameter templates
// {{aws_account}}, {{aws_region}} and {{lambda_name}} in state_machine with
// account_id, region and name_or_arn respectively.
//
// The string behind state_machine is updated in place and the same pointer is
// returned. A nil state_machine yields nil. A nil value leaves its template
// untouched rather than panicking. All templates are substituted in a single
// pass, so a replacement value that itself contains a template is not
// expanded again and the result does not depend on iteration order.
func InterpolateArnVariables(state_machine *string, region *string, account_id *string, name_or_arn *string) *string {
	if state_machine == nil {
		return nil
	}

	templates := []struct {
		placeholder string
		value       *string
	}{
		{"{{aws_account}}", account_id},
		{"{{aws_region}}", region},
		{"{{lambda_name}}", name_or_arn},
	}

	// Build the old/new pairs for the replacer, skipping unknown values.
	pairs := make([]string, 0, 2*len(templates))
	for _, tpl := range templates {
		if tpl.value == nil {
			continue
		}
		pairs = append(pairs, tpl.placeholder, *tpl.value)
	}

	*state_machine = strings.NewReplacer(pairs...).Replace(*state_machine)
	return state_machine
}
// createArn expands name_or_arn into a full ARN using the fmt template
// arn_str, which receives region, account_id and the name, in that order.
//
// name_or_arn is returned unchanged when it is nil, when region or account_id
// is nil, when it already starts with "arn:", or when it is shorter than five
// characters. The length cut-off is kept from the original implementation;
// it also covers the empty string.
// NOTE(review): the cut-off means 1-4 character names are never expanded --
// confirm this is intended.
func createArn(arn_str string, region *string, account_id *string, name_or_arn *string) *string {
	// Validate every pointer before dereferencing any of them.
	if name_or_arn == nil || region == nil || account_id == nil {
		return name_or_arn
	}

	if len(*name_or_arn) < 5 || strings.HasPrefix(*name_or_arn, "arn:") {
		return name_or_arn
	}

	full := fmt.Sprintf(arn_str, *region, *account_id, *name_or_arn)
	return &full
}
-------------------------------------------------------------------------------- /utils/to/arn_test.go: -------------------------------------------------------------------------------- 1 | package to 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | var InputStateMachine = `{ 10 | "StartAt": "Start", 11 | "States": { 12 | "Start": { 13 | "Type": "Task", 14 | "Resource": "arn:aws:lambda:{{aws_region}}:{{aws_account}}:function:{{lambda_name}}", 15 | "Next": "WIN" 16 | }, 17 | "WIN": {"Type": "Succeed"} 18 | } 19 | }` 20 | 21 | var DesiredStateMachine = `{ 22 | "StartAt": "Start", 23 | "States": { 24 | "Start": { 25 | "Type": "Task", 26 | "Resource": "arn:aws:lambda:test-region:test-account:function:test-lambda", 27 | "Next": "WIN" 28 | }, 29 | "WIN": {"Type": "Succeed"} 30 | } 31 | }` 32 | 33 | func Test_to_InterpolateArnVariables(t *testing.T) { 34 | resultStateMachine := InterpolateArnVariables( 35 | &InputStateMachine, 36 | Strp("test-region"), 37 | Strp("test-account"), 38 | Strp("test-lambda"), 39 | ) 40 | assert.Equal(t, *resultStateMachine, DesiredStateMachine) 41 | } 42 | -------------------------------------------------------------------------------- /utils/to/json.go: -------------------------------------------------------------------------------- 1 | package to 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | ) 7 | 8 | // FromJSON Map Converts a string of JSON or a Struct into a map[string]interface{} 9 | func FromJSON(input interface{}) (interface{}, error) { 10 | str, err := PrettyJSON(input) 11 | if err != nil { 12 | return nil, err 13 | } 14 | 15 | var v interface{} 16 | if err := json.Unmarshal([]byte(str), &v); err != nil { 17 | return nil, err 18 | } 19 | 20 | return v, nil 21 | } 22 | 23 | // Takes a string, *string, or struct and returns []byte (json marshal) 24 | func AByte(input interface{}) ([]byte, error) { 25 | switch input.(type) { 26 | case nil: 27 | return []byte(""), nil 28 | case string: 29 | 
return []byte(input.(string)), nil 30 | case *string: 31 | str := input.(*string) 32 | if str == nil { 33 | return []byte(""), nil 34 | } 35 | return []byte(*str), nil 36 | case []byte: 37 | return input.([]byte), nil 38 | case *[]byte: 39 | by := input.(*[]byte) 40 | if by == nil { 41 | return []byte(""), nil 42 | } 43 | return *by, nil 44 | default: 45 | return json.Marshal(input) 46 | } 47 | } 48 | 49 | // PrettyJSON takes a string or a struct and returns it as PrettyJSON 50 | func PrettyJSON(input interface{}) (string, error) { 51 | raw, err := AByte(input) 52 | if err != nil { 53 | return "", err 54 | } 55 | 56 | var json_str interface{} 57 | if err := json.Unmarshal(raw, &json_str); err != nil { 58 | return string(raw), nil 59 | } 60 | 61 | by, err := json.MarshalIndent(json_str, "", " ") 62 | return string(by), err 63 | } 64 | 65 | // PrettyJSONStr takes a string or a struct and returns it as PrettyJSON, no error 66 | func PrettyJSONStr(input interface{}) string { 67 | str, _ := PrettyJSON(input) 68 | return str 69 | } 70 | 71 | func CompactJSON(input interface{}) (string, error) { 72 | raw, err := AByte(input) 73 | if err != nil { 74 | return "", err 75 | } 76 | 77 | b := bytes.NewBuffer(nil) 78 | err = json.Compact(b, raw) 79 | if err != nil { 80 | return "", err 81 | } 82 | 83 | return string(b.Bytes()), nil 84 | } 85 | 86 | func CompactJSONStr(input interface{}) string { 87 | str, _ := CompactJSON(input) 88 | return str 89 | } 90 | -------------------------------------------------------------------------------- /utils/to/json_test.go: -------------------------------------------------------------------------------- 1 | package to 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func Test_CompactJSONStr(t *testing.T) { 10 | assert.Equal(t, `{}`, CompactJSONStr(Strp("{}"))) 11 | assert.Equal(t, `{"a":"b"}`, CompactJSONStr(Strp("{\n \"a\": \"b\"\n}"))) 12 | } 13 | 14 | func Test_PrettyJSONStr(t *testing.T) { 15 | 
// Base64 returns the standard base64 encoding of the string str points to.
// A nil pointer is encoded like the empty string.
func Base64(str *string) string {
	var raw []byte
	if str != nil {
		raw = []byte(*str)
	}
	return base64.StdEncoding.EncodeToString(raw)
}
// SHA256Str returns the lowercase hex encoding of the SHA-256 digest of the
// string str points to. A nil pointer is hashed as the empty string instead
// of panicking.
func SHA256Str(str *string) string {
	var data []byte
	if str != nil {
		data = []byte(*str)
	}

	sum := sha256.Sum256(data)
	return hex.EncodeToString(sum[:])
}
// RegionAccount reads the AWS region and account id from the AWS_REGION and
// AWS_ACCOUNT_ID environment variables.
//
// It returns (nil, nil) unless both variables are set to non-empty values, so
// callers only need to handle "both known" or "neither known".
func RegionAccount() (*string, *string) {
	region := os.Getenv("AWS_REGION")
	accountID := os.Getenv("AWS_ACCOUNT_ID")

	if region == "" || accountID == "" {
		return nil, nil
	}

	return &region, &accountID
}
assert.Equal(t, "/bla/foo/", ArnPath("arn:aws:iam::000000:role/bla/foo/bar")) 13 | assert.Equal(t, "/", ArnPath("arn:aws:iam::000000:role/bar")) 14 | } 15 | 16 | func Test_to_ArnRegionAccountResource(t *testing.T) { 17 | r, a, res := ArnRegionAccountResource("arn:aws:lambda:::function:") 18 | 19 | assert.Equal(t, "", r) 20 | assert.Equal(t, "", a) 21 | assert.Equal(t, "function:", res) 22 | 23 | r, a, res = ArnRegionAccountResource("arn:aws:iam::000000:instance-profile/bla/foo/bar") 24 | assert.Equal(t, "", r) 25 | assert.Equal(t, "000000", a) 26 | assert.Equal(t, "instance-profile/bla/foo/bar", res) 27 | } 28 | --------------------------------------------------------------------------------