├── .editorconfig ├── .github ├── CONTRIBUTING.md ├── ISSUE_TEMPLATE.md ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml └── workflows │ ├── compile-test.yml │ ├── sbt-dependency-graph.yaml │ └── stale.yml ├── .gitignore ├── .tool-versions ├── LICENSE ├── NOTICE ├── README.md ├── build.sbt ├── generate-executable-prefix ├── generate-executable.sh ├── project ├── build.properties └── plugins.sbt ├── scripts └── ssh-report └── src ├── main ├── resources │ └── logback.xml └── scala │ └── com │ └── gu │ └── ssm │ ├── ArgumentParser.scala │ ├── IO.scala │ ├── Interactive.scala │ ├── Logic.scala │ ├── Main.scala │ ├── SSH.scala │ ├── UI.scala │ ├── aws │ ├── AwsAsyncHandler.scala │ ├── EC2.scala │ ├── RDS.scala │ ├── SSM.scala │ └── STS.scala │ ├── models.scala │ └── utils │ ├── FilePermissions.scala │ ├── KeyMaker.scala │ └── attempt │ ├── Attempt.scala │ └── Failure.scala └── test └── scala └── com └── gu └── ssm ├── LogicTest.scala ├── MainTest.scala ├── SSHTest.scala ├── UITest.scala └── utils └── attempt ├── AttemptTest.scala └── AttemptValues.scala /.editorconfig: -------------------------------------------------------------------------------- 1 | 2 | # editorconfig.org 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 2 8 | end_of_line = lf 9 | charset = utf-8 10 | trim_trailing_whitespace = true 11 | insert_final_newline = true 12 | 13 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Opening a pull request 2 | 3 | ## General 4 | 5 | Pull requests are welcome! Here are the rules of engagement: 6 | 7 | - The philosophy here is about communication, taking responsibility for your changes, and fast, incremental delivery. 8 | - Speak to the team before you decide to do anything major. We can probably help design the change to maximise the chances of it being accepted. 
9 | - Pull requests made to main assume the change is ready to be released. 10 | - Many small requests will be reviewed/merged quicker than a giant list of changes. 11 | - If you have a proposal, or want feedback on a branch under development, prefix `[WIP]` to the pull request title. 12 | 13 | ### Please be aware that we use the Apache License 2.0, and so: 14 | 15 | Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 16 | 17 | Your contributions should be wholly your own. 18 | 19 | ## Submission 20 | 21 | ### Guardian employees 22 | 23 | This is applicable to [GMG employees](http://www.gmgplc.co.uk/). 24 | 25 | 1. Fork or clone the repo and make your changes. 26 | 27 | 2. Test your branch locally by running tests: 28 | - `sbt test` 29 | 30 | 3. Open a pull request: 31 | - Explain why you are making this change in the pull request 32 | 33 | 4. A member of the team will review the changes. Once they are satisfied they will approve the pull request. 34 | 35 | 36 | ### External contributions 37 | 38 | Firstly, thanks for helping make our project better! Secondly, we'll try to make this as simple as possible. 39 | 40 | - Fork the project on GitHub, patch the code, and submit a pull request. 41 | - We will test, verify and merge your changes and then deploy the code. 42 | - Certain contributions may require a Contributor License Agreement. 43 | 44 | Finally, have you considered [working for us](https://workforus.theguardian.com/index.php/search-jobs-and-apply/?search_paths%5B%5D=&query=developer)? 
45 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # ISSUE 2 | 3 | ## Steps to Reproduce 4 | 5 | 6 | ## Actual Results (include screenshots) 7 | 8 | 9 | ## Expected Results (include screenshots) -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## What does this change? 2 | 3 | 4 | 5 | 6 | 7 | ## What is the value of this? 8 | 9 | 10 | 11 | ## Any additional notes? -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "github-actions" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "monthly" 12 | -------------------------------------------------------------------------------- /.github/workflows/compile-test.yml: -------------------------------------------------------------------------------- 1 | name: compile-test 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 13 | - name: Setup Scala 14 | uses: guardian/setup-scala@v1 15 | - run: sbt test 16 | -------------------------------------------------------------------------------- /.github/workflows/sbt-dependency-graph.yaml: -------------------------------------------------------------------------------- 1 | name: Update Dependency Graph for sbt 2 | on: 3 | push: 4 | branches: 5 | - main 6 | workflow_dispatch: 7 | jobs: 8 | dependency-graph: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Checkout branch 12 | id: checkout 13 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 14 | - name: Setup Scala 15 | uses: guardian/setup-scala@v1 16 | - name: Submit dependencies 17 | id: submit 18 | uses: scalacenter/sbt-dependency-submission@64084844d2b0a9b6c3765f33acde2fbe3f5ae7d3 # v3.1.0 19 | - name: Log snapshot for user validation 20 | id: validate 21 | run: cat ${{ steps.submit.outputs.snapshot-json-path }} | jq 22 | permissions: 23 | contents: write 24 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | 
# Automatically mark any pull requests that have been inactive for 30 days as "Stale" 2 | # then close them 3 days later if there is still no activity. 3 | name: "Stale PR Handler" 4 | 5 | on: 6 | schedule: 7 | # Check for Stale PRs every Monday to Thursday morning 8 | # Don't check on Fridays as it wouldn't be very nice to have a bot mark your PR as Stale on Friday and then close it on Monday morning! 9 | - cron: "0 6 * * MON-THU" 10 | 11 | permissions: 12 | pull-requests: write 13 | 14 | jobs: 15 | stale: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9.1.0 19 | id: stale 20 | # Read about options here: https://github.com/actions/stale#all-options 21 | with: 22 | # never automatically mark issues as stale 23 | days-before-issue-stale: -1 24 | 25 | # Wait 30 days before marking a PR as stale 26 | days-before-stale: 30 27 | stale-pr-message: > 28 | This PR is stale because it has been open 30 days with no activity. 29 | Unless a comment is added or the “stale” label removed, this will be closed in 3 days 30 | 31 | # Wait 3 days after a PR has been marked as stale before closing 32 | days-before-close: 3 33 | close-pr-message: This PR was closed because it has been stalled for 3 days with no activity. 
34 | 35 | # Ignore PR's raised by Dependabot 36 | exempt-pr-labels: "dependencies" 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | 3 | # Created by https://www.gitignore.io/api/sbt,scala,emacs,intellij 4 | 5 | ### Emacs ### 6 | # -*- mode: gitignore; -*- 7 | *~ 8 | \#*\# 9 | /.emacs.desktop 10 | /.emacs.desktop.lock 11 | *.elc 12 | auto-save-list 13 | tramp 14 | .\#* 15 | 16 | # Org-mode 17 | .org-id-locations 18 | *_archive 19 | 20 | # flymake-mode 21 | *_flymake.* 22 | 23 | # eshell files 24 | /eshell/history 25 | /eshell/lastdir 26 | 27 | # elpa packages 28 | /elpa/ 29 | 30 | # reftex files 31 | *.rel 32 | 33 | # AUCTeX auto folder 34 | /auto/ 35 | 36 | # cask packages 37 | .cask/ 38 | dist/ 39 | 40 | # Flycheck 41 | flycheck_*.el 42 | 43 | # server auth directory 44 | /server/ 45 | 46 | # projectiles files 47 | .projectile 48 | projectile-bookmarks.eld 49 | 50 | # directory configuration 51 | .dir-locals.el 52 | 53 | # saveplace 54 | places 55 | 56 | # url cache 57 | url/cache/ 58 | 59 | # cedet 60 | ede-projects.el 61 | 62 | # smex 63 | smex-items 64 | 65 | # company-statistics 66 | company-statistics-cache.el 67 | 68 | # anaconda-mode 69 | anaconda-mode/ 70 | 71 | ### Intellij ### 72 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 73 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 74 | 75 | # User-specific stuff: 76 | .idea/**/workspace.xml 77 | .idea/**/tasks.xml 78 | .idea/dictionaries 79 | 80 | # Sensitive or high-churn files: 81 | .idea/**/dataSources/ 82 | .idea/**/dataSources.ids 83 | .idea/**/dataSources.xml 84 | .idea/**/dataSources.local.xml 85 | .idea/**/sqlDataSources.xml 86 | .idea/**/dynamic.xml 87 | .idea/**/uiDesigner.xml 88 | 89 | # Gradle: 90 | .idea/**/gradle.xml 91 | .idea/**/libraries 92 | 93 | # CMake 
94 | cmake-build-debug/ 95 | 96 | # Mongo Explorer plugin: 97 | .idea/**/mongoSettings.xml 98 | 99 | ## File-based project format: 100 | *.iws 101 | 102 | ## Plugin-specific files: 103 | 104 | # IntelliJ 105 | /out/ 106 | 107 | # mpeltonen/sbt-idea plugin 108 | .idea_modules/ 109 | 110 | # JIRA plugin 111 | atlassian-ide-plugin.xml 112 | 113 | # Cursive Clojure plugin 114 | .idea/replstate.xml 115 | 116 | # Ruby plugin and RubyMine 117 | /.rakeTasks 118 | 119 | # Crashlytics plugin (for Android Studio and IntelliJ) 120 | com_crashlytics_export_strings.xml 121 | crashlytics.properties 122 | crashlytics-build.properties 123 | fabric.properties 124 | 125 | ### Intellij Patch ### 126 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 127 | 128 | # *.iml 129 | # modules.xml 130 | # .idea/misc.xml 131 | # *.ipr 132 | 133 | # Sonarlint plugin 134 | .idea/sonarlint 135 | 136 | ### SBT ### 137 | # Simple Build Tool 138 | # http://www.scala-sbt.org/release/docs/Getting-Started/Directories.html#configuring-version-control 139 | 140 | dist/* 141 | target/ 142 | lib_managed/ 143 | src_managed/ 144 | project/boot/ 145 | project/plugins/project/ 146 | .history 147 | .cache 148 | .lib/ 149 | .bsp 150 | 151 | ### Scala ### 152 | *.class 153 | *.log 154 | 155 | # End of https://www.gitignore.io/api/sbt,scala,emacs,intellij 156 | -------------------------------------------------------------------------------- /.tool-versions: -------------------------------------------------------------------------------- 1 | java corretto-21 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | SSM-SCALA 2 | Copyright 2018 Guardian News & Media Ltd 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | SSM-Scala 2 | ========= 3 | 4 | SSM-Scala is a command-line tool, written in Scala, for executing commands on EC2 servers using AWS's EC2 Run command. 
It provides the user with: 5 | 6 | 1. standard `ssh` access using short-lived RSA keys 7 | 2. an _alternative_ to `ssh` for running commands on the target 8 | 9 | Both modes apply to servers in AWS accounts to which you have [IAM](https://aws.amazon.com/iam/) access. 10 | 11 | Instructions for using SSM Scala in your own project can be found [below](#How-to-use-SSM-Scala-with-your-own-project). 12 | 13 | ## Installation 14 | 15 | If you have Homebrew installed and want to install ssm, do 16 | 17 | ``` 18 | brew install guardian/devtools/ssm 19 | ``` 20 | 21 | and for an upgrade do 22 | 23 | ``` 24 | brew upgrade ssm 25 | ``` 26 | 27 | Otherwise, fetch the most recently released version of the program from the [GitHub releases page](https://github.com/guardian/ssm-scala/releases/latest) and make sure it is executable (`chmod +x ssm`). You may then want to put it somewhere in your PATH. 28 | 29 | 30 | ## First time here, just show me the SSH thing real quick 31 | 32 | The readme is quite detailed (and shows how to do many more things than what will be shown in this section), but you are probably reading it because you just want to ssh to a box. Here is what you need to do: 33 | 34 | 1. Install ssm. How to do so was explained in the previous section. 35 | 2. Ensure that you have the Janus credentials of the account you want to work with. We are going to assume `frontend` in this section for the examples. 36 | 3. Identify the instance ID of the box you want to reach. It can be found in the AWS developer console. Instance IDs look like this `i-00032c76140bc9140`. 37 | 4. At your console type 38 | 39 | ``` 40 | ssm ssh -i i-00032c76140bc9140 -p frontend 41 | ``` 42 | 43 | and more generally 44 | 45 | ``` 46 | ssm ssh -i <instance-id> -p <profile> 47 | ssm ssh --instances <instance-id> --profile <profile> 48 | ``` 49 | 50 | 5. And that's it! If all went well you have been ssh'ed to the box. 
51 | 52 | 53 | ## Known issues 54 | 55 | If you get an error about Futures timed out after 25 seconds, then the SSM permissions may not be right, or you might need to recycle the instance since adding the permissions. 56 | 57 | If the disk on which the keyfile is stored is full, then ssm-scala cannot add the public key identity prior to logging in to the box. This is often found to be the case, and also can apparently cause the AWS SSM agent to stop. 58 | 59 | One potential workaround for this is rebooting the box using the EC2 console (may clear down logs, for example). 60 | 61 | ## Usage 62 | 63 | The automatically generated help section for `ssm` is 64 | 65 | ``` 66 | Usage: ssm [cmd|repl|ssh|scp] [options] ... 67 | 68 | -p, --profile The AWS profile name to use for authenticating this execution 69 | -i, --instances Specify the instance ID(s) on which the specified command(s) should execute 70 | -t, --tags Search for instances by tag. If you provide less than 3 tags assumed order is app,stage,stack. e.g. '--tags riff-raff,prod' or '--tags grafana' Upper/lowercase variations will be tried. 71 | -r, --region AWS region name (defaults to eu-west-1) 72 | --verbose enable more verbose logging 73 | --use-default-credentials-provider 74 | Use the default AWS credentials provider chain rather than profile credentials. This option is required when running within AWS itself. 
75 | Command: cmd [options] 76 | Execute a single (bash) command, or a file containing bash commands 77 | -u, --user Execute command on remote host as this user (default: ubuntu) 78 | -c, --cmd A bash command to execute 79 | -f, --file A file containing bash commands to execute 80 | Command: repl 81 | Run SSM in interactive/repl mode 82 | Command: ssh [options] 83 | Create and upload a temporary ssh key 84 | -u, --user Connect to remote host as this user (default: ubuntu) 85 | --port Connect to remote host on this port 86 | --newest Selects the newest instance if more than one instance was specified 87 | --oldest Selects the oldest instance if more than one instance was specified 88 | --private Use private IP address (must be routable via VPN Gateway) 89 | --raw Unix pipe-able ssh connection string. Note: disables automatic execution. You must use 'eval' to execute this due to nested quoting 90 | -x, --execute [Deprecated - new default behaviour] Makes ssm behave like a single command (eg: `--raw` with automatic piping to the shell) 91 | -d, --dryrun Generate SSH command but do not execute (previous default behaviour) 92 | -A, --agent Use the local ssh agent to register the private key (and do not use -i); only bastion connections 93 | -a, --no-agent Do not use the local ssh agent 94 | -b, --bastion Connect through the given bastion specified by its instance id; implies -A (use agent) unless followed by -a. 95 | -B, --bastion-tags 96 | Connect through the given bastion identified by its tags; implies -a (use agent) unless followed by -A. 97 | --bastion-port Connect through the given bastion at a given port. 98 | --bastion-user Connect to bastion as this user (default: ubuntu). 
99 | --host-key-alg-preference 100 | The preferred host key algorithms, can be specified multiple times - last is preferred (default: ecdsa-sha2-nistp256, ssh-rsa) 101 | --ssm-tunnel [deprecated] 102 | --no-ssm-proxy Do not connect to the host proxying via AWS Systems Manager - go direct to port 22. Useful for instances running old versions of systems manager (< 2.3.672.0) 103 | --tunnel Forward traffic from the given local port to the given host and port on the remote side. Accepts the format `localPort:host:remotePort`, e.g. --tunnel 5000:a.remote.host.com:5000 104 | --rds-tunnel Forward traffic from a given local port to a RDS database specified by tags. Accepts the format `localPort:tags`, where `tags` is a comma-separated list of tag values, e.g. --rds-tunnel 5000:app,stack,stage 105 | Command: scp [options] [:]... [:]... 106 | Secure Copy 107 | -u, --user Connect to remote host as this user (default: ubuntu) 108 | --port Connect to remote host on this port 109 | --newest Selects the newest instance if more than one instance was specified 110 | --oldest Selects the oldest instance if more than one instance was specified 111 | --private Use private IP address (must be routable via VPN Gateway) 112 | --raw Unix pipe-able scp connection string 113 | -x, --execute [Deprecated - new default behaviour] Makes ssm behave like a single command (eg: `--raw` with automatic piping to the shell) 114 | -d, --dryrun Generate SCP command but do not execute (previous default behaviour) 115 | --ssm-tunnel [deprecated] 116 | --no-ssm-proxy Do not connect to the host proxying via AWS Systems Manager - go direct to port 22. Useful for instances running old versions of systems manager (< 2.3.672.0) 117 | [:]... Source file for the scp sub command. See README for details 118 | [:]... Target file for the scp sub command. See README for details 119 | ``` 120 | 121 | There are two mandatory configuration items. 
122 | 123 | To specify your AWS profile (for more information see [AWS profiles](https://docs.aws.amazon.com/cli/latest/userguide/cli-multiple-profiles.html)), either of: 124 | 125 | - `--profile` 126 | - AWS_PROFILE environment variable 127 | 128 | To target the command, either of: 129 | 130 | - `-i`, where you specify one or more instance IDs, or 131 | - `-t`, where you specify the app name, the stack and the stage. 132 | 133 | ### "Tainted" Instances 134 | 135 | When accessing an instance, the user is greeted with a message of the form 136 | 137 | ``` 138 | This instance should be considered tainted. 139 | It was accessed by 1234567890:alice.smith at Fri Apr 27 08:36:58 BST 2018 140 | ``` 141 | 142 | This message highlights the fact that access is being logged and that the next person will see that the current user has been there. The current wording of **considered tainted** highlights the fact that the user has no idea what has happened during previous ssh sessions and raises awareness of the implications of accessing a box. 143 | 144 | ### "Too many authentication failures" 145 | 146 | This is the result of having too many keys in your agent and exceeding the server's configured authentication attempts. This is fixed as of 0.9.7 as we disable the use of the agent using `IdentitiesOnly`. 147 | 148 | If you still see "Too many authentication failures" then please raise an issue. You can work around it by running `ssh-add -D` to remove all keys from your agent. 149 | 150 | ### --raw usage 151 | 152 | If you need to add extra parameters to the SSH command then you can use `--raw`. In its simplest form the following are equivalent: 153 | ```bash 154 | ssm ssh -i i-0123456789abcdef0 -p composer 155 | ``` 156 | 157 | and 158 | 159 | ```bash 160 | eval $(ssm ssh -i i-0123456789abcdef0 -p composer --raw) 161 | ``` 162 | 163 | This helps to undertake actions such as constructing tunnels. 
For example, to access a remote postgres server: 164 | 165 | ```bash 166 | eval $(ssm ssh -i i-0123456789abcdef0 -p composer --raw) -L 5432:my-postgres-server-hostname:5432 167 | ``` 168 | 169 | Note the use of `eval` in these examples - this is required in order to correctly parse the nested quotes that are output as part of the raw command. If you don't use `eval` then you are likely to see an error message such as `ssh: Could not resolve hostname yes": nodename nor servname provided, or not known`. 170 | 171 | ### Execution targets 172 | 173 | `ssm` needs to be told which instances should execute the provided command(s). You can do this by specifying instance IDs, or by specifying App, Stack, and Stage tags. 174 | 175 | ``` 176 | # by instance ids 177 | --instances i-0123456,i-9876543 178 | -i i-0123456,i-9876543 179 | 180 | # by tag 181 | --tags ,, 182 | -t ,, 183 | ``` 184 | 185 | If you provide tags, `ssm` will search for running instances that have those tags. 186 | 187 | ### Examples 188 | 189 | Examples of using `cmd` are 190 | 191 | ``` 192 | ./ssm cmd -c date --profile security -t security-hq,security,PROD 193 | ``` 194 | or 195 | ``` 196 | export AWS_PROFILE=security 197 | ./ssm cmd -c date -t security-hq,security,PROD 198 | ``` 199 | 200 | where the `date` command will be run on all matching instances. 201 | 202 | An example of using `repl` is: 203 | 204 | ``` 205 | ./ssm repl --profile -t security-hq,security,PROD 206 | ``` 207 | 208 | The REPL mode causes `ssm` to generate a list of instances and then wait for commands to be specified. Each command will be executed on all instances and the user can select the instance to display. 209 | 210 | An example of using the `ssh` command is: 211 | 212 | ``` 213 | ./ssm ssh --profile -t security-hq,security,PROD 214 | ``` 215 | 216 | This causes `ssm` to generate a temporary ssh key and install the public key on a specific instance. It will then output the command to `ssh` directly to that instance.
The instance must already have appropriate security groups. 217 | 218 | The target for the ssh command will be the public IP address if there is one, otherwise the private IP address. The `--private` flag overrides this behavior and defaults to the private IP address. 219 | 220 | Note that if the argument `-t ,,` resolves to more than one instance, the command will stop with an error message. You can circumvent this behaviour and instruct `ssm` to proceed with one single instance using the command line flags `--oldest` and `--newest`, which select either the oldest or newest instances. 221 | 222 | ### --raw 223 | 224 | This flag allows for a pipe-able ssh connection string. For instance 225 | 226 | ``` 227 | ssm ssh --profile security -t security-hq,security,PROD --newest --raw | xargs -0 -o bash -c 228 | ``` 229 | 230 | Will automatically ssh you to the newest instance running security-hq. Note that you still have to manually accept the new ECDSA key fingerprint. 231 | 232 | ### -d, --dryrun 233 | 234 | Generate SSH command but do not execute (previous default behaviour) 235 | 236 | ``` 237 | ssm ssh --profile security -t security-hq,security,PROD --newest --dryrun 238 | ``` 239 | 240 | Example output: 241 | 242 | ``` 243 | ========= i-0566a4df63c0c35bb ========= 244 | # Dryrun mode. The command below will remain valid for 30 seconds: 245 | 246 | ssh -o "IdentitiesOnly yes" -o "UserKnownHostsFile ... 247 | ``` 248 | 249 | ### -x, --execute 250 | 251 | DEPRECATED - flag is now the default behaviour. This flag makes ssm behave like ssh. The raw output is automatically piped to `xargs -0 -o bash -c`. You would then do 252 | 253 | ``` 254 | ssm ssh --profile security -t security-hq,security,PROD --newest --execute 255 | ``` 256 | 257 | instead of the example given in the previous `--raw` section. 258 | 259 | ### --tunnel 260 | 261 | This flag forwards traffic from a local port through the instance to the specified hostname and port. 
For example, 262 | 263 | ``` 264 | ssm ssh --profile security -t security-hq,security,PROD --newest --tunnel 5000:example.com:6000 265 | ``` 266 | 267 | would forward all traffic sent to local port 5000 through the remote instance to example.com:6000. 268 | 269 | ### --rds-tunnel 270 | 271 | Similar to `--tunnel`, this flag forwards traffic from a local port to an AWS RDS database specified by the given tags. For example, 272 | 273 | ``` 274 | ssm ssh --profile security -t security-hq,security,PROD --newest --rds-tunnel 5000:example-db,security,CODE 275 | ``` 276 | 277 | would try to find a single RDS instance with the tags `example-db,security,CODE`, and forward traffic from port 5000 to that RDS instance via the remote instance. 278 | 279 | ## Disabling SSM Tunnel 280 | **By default, SSM proxies your connection via AWS systems manager**, which saves you from opening up port 22, connecting to 281 | the VPN, or using bastion hosts. This requires a recent version of systems manager to be running on your machine and 282 | the target machine. You can still connect the old way via port 22 using the flag `--no-ssm-proxy`. 283 | 284 | ## Enabling SSM Tunnel 285 | 286 | It is strongly encouraged to connect using the default SSM tunnel behaviour.
To get this working you'll need to do the following: 287 | 288 | ### In AWS 289 | 290 | Update the permissions of your instances so that they are allowed to perform these actions: 291 | 292 | ``` 293 | - ec2messages:AcknowledgeMessage 294 | - ec2messages:DeleteMessage 295 | - ec2messages:FailMessage 296 | - ec2messages:GetEndpoint 297 | - ec2messages:GetMessages 298 | - ec2messages:SendReply 299 | - ssm:UpdateInstanceInformation 300 | - ssm:ListInstanceAssociations 301 | - ssm:DescribeInstanceProperties 302 | - ssm:DescribeDocumentParameters 303 | - ssmmessages:CreateControlChannel 304 | - ssmmessages:CreateDataChannel 305 | - ssmmessages:OpenControlChannel 306 | - ssmmessages:OpenDataChannel 307 | ``` 308 | 309 | See [this example](https://github.com/guardian/deploy-tools-platform/blob/master/cloudformation/nexus.template.yaml#L118) for a complete policy. 310 | 311 | You'll also need to ensure you're using a recent AMI that has at least version 2.3.672.0 of systems manager - this is now in our base images so using a recent amigo AMI should do the job. 312 | 313 | Once these permissions are added, the `ssm-agent` service running on the boxes will need to be restarted before connecting. This will happen as boxes are cycled – e.g. by redeploying your app – or you can restart an agent manually with `sudo snap restart amazon-ssm-agent.amazon-ssm-agent`. 314 | 315 | ### On your machine 316 | 317 | Upgrade your local version of ssm and awscli: 318 | 319 | ``` 320 | brew upgrade ssm 321 | brew upgrade awscli 322 | ``` 323 | 324 | You'll also need to install the systems manager plugin on your machine: 325 | 326 | ``` 327 | brew install --cask session-manager-plugin 328 | ``` 329 | 330 | You can then SSH using SSM with the default arguments: 331 | 332 | ``` 333 | ssm ssh -i i-0937fe9baa578095b -p deployTools 334 | ``` 335 | 336 | (Useful tip - you can find the instance id using prism, e.g.
`prism -f instanceName amigo`) 337 | 338 | ### Post setup 339 | 340 | Once you've confirmed this is working, you can remove any security group rules allowing access on port 22. 341 | 342 | ### More info 343 | 344 | Check out the original PR: https://github.com/guardian/ssm-scala/pull/111 for further details on how this works. 345 | 346 | 347 | ## Bastions 348 | 349 | Bastions are proxy servers used as entry points to private networks, and ssm-scala supports their use. 350 | 351 | **You may not need a bastion server at all! Prefer to use an SSM tunnel (see above) where possible.** 352 | 353 | ### Introduction 354 | 355 | In this example we assume that you have a bastion with a public IP address (even though the bastion Ingress rules may restrict it to some IP ranges), identified by aws instance id `i-bastion12345`, and an application server on a private network with a private IP address and instance id `i-application-12345`. You would then use ssm to connect to it using 356 | 357 | ``` 358 | ssm ssh --profile --bastion i-bastion12345 --bastion-port 2022 -i i-application-12345 359 | ``` 360 | 361 | The outcome of this command is a one-liner of the form 362 | 363 | ``` 364 | ssh -A -i /path/to/private/key-file ubuntu@something.example.com -t -t ssh ubuntu@10.123.123.123; 365 | ``` 366 | 367 | ### Handling Ports 368 | 369 | You can specify a port that the bastion runs ssh on, with the option `--bastion-port `, for example 370 | 371 | ``` 372 | ssm ssh --profile --bastion i-bastion12345 --bastion-port 2345 -i i-application-12345 373 | ``` 374 | 375 | 376 | ### Using tags to specify the target instance 377 | 378 | When the bastion is specified by its aws instance id, you can still refer to the application instance using the tag system, as in 379 | 380 | ``` 381 | ssm ssh --profile --bastion i-bastion12345 --bastion-port 2022 --tags app,stage,stack 382 | ``` 383 | 384 | If the tags may resolve to more than one
instance, add the `--oldest` or `--newest` flag: 385 | 386 | ``` 387 | ssm ssh --profile --bastion i-bastion12345 --bastion-port 2022 --tags app,stage,stack --newest 388 | ``` 389 | 390 | ### Using tags to specify the bastion instance 391 | 392 | If you do not know the id of the current bastion, but it is tagged correctly, it is also possible to use: 393 | 394 | ``` 395 | ssm ssh --profile --bastion-tags --bastion-port 2022 -i i-application-12345 396 | ``` 397 | 398 | This will respect any --newest / --oldest switches, although it is anticipated that there will usually only be one bastion. It will always use the public IP address of the bastion. 399 | 400 | ### Bastion users 401 | 402 | You can specify the user for connecting to the bastion with the `--bastion-user ` command line argument. 403 | 404 | ### Bastions with private IP addresses 405 | 406 | When using the standard `ssh` command, the `--private` flag can be used to indicate that the private IP of the target instance should be used for the connection. For bastion connections, the target instance is assumed to always be reachable through a private IP, so this flag instead indicates whether the private IP of the bastion should be used. 407 | 408 | ### Bastions with private key problems 409 | 410 | There have been occurrences of bastion connection strings of the form 411 | 412 | ``` 413 | ssh -A -i /path/to/temp/private/key -t -t ubuntu@bastion-hostname \ 414 | -t -t ssh -t -t ubuntu@target-ip-address; 415 | ``` 416 | not working because the private key file was not found for the second ssh connection, leading to a "Permission denied (publickey)" error message. 417 | 418 | When this happens, use the `-A`, `--agent` flag, which registers the private key with the local ssh agent.
With this flag, the ssm command 419 | 420 | ``` 421 | ssm ssh --profile --bastion \ 422 | -i --agent 423 | ``` 424 | 425 | returns 426 | 427 | ``` 428 | ssh-add /path/to/temp/private/key && \ 429 | ssh -A ubuntu@bastion-hostname \ 430 | -t -t ssh ubuntu@target-ip-address; 431 | ``` 432 | 433 | 434 | ## Secure Copy 435 | 436 | **ssm** supports the **scp** sub command for the secure transfer of files and directories. 437 | 438 | ### Introduction 439 | 440 | An example of usage is 441 | 442 | ``` 443 | ./ssm scp -p account -t app,stage,stack /path/to/file1 :/path/to/file2 444 | ``` 445 | 446 | which outputs 447 | 448 | ``` 449 | # simplified version 450 | scp -i /path/to/identity/file.tmp /path/to/file1 ubuntu@34.242.32.40:/path/to/file2; 451 | ``` 452 | 453 | Conversely, 454 | 455 | ``` 456 | ./ssm scp -p account -t app,stage,stack :/path/to/file1 /path/to/file2 457 | ``` 458 | 459 | outputs 460 | 461 | ``` 462 | # simplified version 463 | scp -i /path/to/identity/file.tmp ubuntu@34.242.32.40:/path/to/file1 /path/to/file2 ; 464 | ``` 465 | 466 | The convention is: the first (left hand side) file is always the source and the second (right hand side) is always the target, and the colon prefix indicates which one is on the remote server.
If you are using a non-Unix operating system, run `sbt assembly` as you would normally do and then run the ssm.jar file using 483 | 484 | ``` 485 | java -jar /ssm.jar [arguments] 486 | ``` 487 | 488 | ## Release a new version 489 | 490 | To release a new version of `ssm` perform the following steps: 491 | 492 | 1. Update the version number in `build.sbt` 493 | 494 | 2. Generate a new executable. Run the following from the root of the repository: 495 | ```bash 496 | ./generate-executable.sh 497 | ``` 498 | Note that this script generates the **tar.gz** file needed for the GitHub release as well as outputting the sha256 hash of that file needed for the homebrew-devtools update. 499 | 500 | 3. Create and merge a PR with the new version number (e.g. #459). 501 | 502 | 4. Create a new tag locally and push it: 503 | ``` 504 | git tag v[version-number] 505 | git push origin v[version-number] 506 | ``` 507 | 508 | 5. Go to the GitHub repository at https://github.com/guardian/ssm-scala/releases 509 | 510 | 6. Draft a new release 511 | 512 | 7. Upload the binary assets: 513 | * The raw executable file (target/scala-X.Y.Z/ssm) 514 | * The tarball (target/scala-X.Y.Z/ssm.tar.gz) 515 | 516 | 8. Publish the release 517 | 518 | 9. Make a PR to [https://github.com/guardian/homebrew-devtools/blob/master/Formula/ssm.rb](https://github.com/guardian/homebrew-devtools/blob/master/Formula/ssm.rb) to update the new version's details. 519 | 520 | 521 | ## How to use SSM Scala with your own project 522 | 523 | To use ssm-scala against the instances of your project, the following needs to happen: 524 | 525 | 1.
Add permissions with a policy like: 526 | 527 | ```yaml 528 | ExampleAppSSMRunCommandPolicy: 529 | Type: AWS::IAM::Policy 530 | Properties: 531 | PolicyName: example-app-ssm-run-command-policy 532 | PolicyDocument: 533 | Statement: 534 | # minimal policy to allow to (only) run commands via ssm 535 | - Effect: Allow 536 | Resource: "*" 537 | Action: 538 | - ec2messages:AcknowledgeMessage 539 | - ec2messages:DeleteMessage 540 | - ec2messages:FailMessage 541 | - ec2messages:GetEndpoint 542 | - ec2messages:GetMessages 543 | - ec2messages:SendReply 544 | - ssm:UpdateInstanceInformation 545 | - ssm:ListInstanceAssociations 546 | - ssm:DescribeInstanceProperties 547 | - ssm:DescribeDocumentParameters 548 | - ssmmessages:CreateControlChannel 549 | - ssmmessages:CreateDataChannel 550 | - ssmmessages:OpenControlChannel 551 | - ssmmessages:OpenDataChannel 552 | Roles: 553 | - !Ref ExampleAppInstanceRole 554 | ``` 555 | 556 | Example taken from the [Security-HQ cloudformation](https://github.com/guardian/security-hq/blob/master/cloudformation/security-hq.template.yaml) file. 557 | 558 | 2. Download the executable from the [project release page](https://github.com/guardian/ssm-scala/releases). Instructions on usage can be found in the above sections. 559 | 560 | Note: SSM requires the target server to allow outbound traffic on port 443 (for ssm-agent's communication with AWS's SSM and EC2 Messages endpoints). 561 | 562 | 563 | ## License 564 | 565 | Copyright (c) 2018 Guardian News & Media. Available under the Apache License.
566 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | name := "ssm-scala" 2 | organization := "com.gu" 3 | version := "3.7.1" 4 | 5 | // be sure to also update this in the `generate-executable.sh` script 6 | scalaVersion := "3.7.0" 7 | 8 | val awsSdkVersion = "1.12.784" 9 | 10 | libraryDependencies ++= Seq( 11 | "com.amazonaws" % "aws-java-sdk-ssm" % awsSdkVersion, 12 | "com.amazonaws" % "aws-java-sdk-sts" % awsSdkVersion, 13 | "com.amazonaws" % "aws-java-sdk-ec2" % awsSdkVersion, 14 | "com.amazonaws" % "aws-java-sdk-rds" % awsSdkVersion, 15 | "com.github.scopt" %% "scopt" % "4.1.0", 16 | "com.googlecode.lanterna" % "lanterna" % "3.1.3", 17 | "ch.qos.logback" % "logback-classic" % "1.5.18", 18 | "com.typesafe.scala-logging" %% "scala-logging" % "3.9.5", 19 | "com.fasterxml.jackson.core" % "jackson-databind" % "2.19.0", 20 | "org.bouncycastle" % "bcpkix-jdk18on" % "1.80", 21 | "org.scalatest" %% "scalatest" % "3.2.19" % Test 22 | ) 23 | 24 | // Required as jackson causes a merge issue with sbt-assembly 25 | // See: https://github.com/sbt/sbt-assembly/issues/391 26 | assemblyMergeStrategy := { 27 | case PathList("META-INF", _*) => MergeStrategy.discard 28 | case _ => MergeStrategy.first 29 | } 30 | assemblyJarName := "ssm.jar" 31 | 32 | scalacOptions := Seq( 33 | "-unchecked", 34 | "-deprecation", 35 | "-release:11", 36 | ) 37 | -------------------------------------------------------------------------------- /generate-executable-prefix: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | MYSELF=`which "$0" 2>/dev/null` 3 | [ $?
-gt 0 -a -f "$0" ] && MYSELF="./$0" 4 | java=java 5 | if test -n "$JAVA_HOME"; then 6 | java="$JAVA_HOME/bin/java" 7 | fi 8 | # EXECUTE_FLAG=1 pipes the generated command into bash; the flags below switch to only printing it 9 | EXECUTE_FLAG=1 10 | for param in "$@"; do 11 | test "$param" = "-d" && EXECUTE_FLAG=0 12 | test "$param" = "--dryrun" && EXECUTE_FLAG=0 13 | # raw mode makes no sense if immediately executing the command 14 | test "$param" = "--raw" && EXECUTE_FLAG=0 15 | test "$param" = "--help" && EXECUTE_FLAG=0 16 | done 17 | 18 | if (test $EXECUTE_FLAG -eq 0); then 19 | exec "$java" -jar "$MYSELF" "$@" 20 | else 21 | exec "$java" -jar "$MYSELF" "$@" | xargs -0 -o bash -c 22 | fi 23 | # Reachable from the piped branch: exec inside a pipeline normally replaces only that pipeline element's subshell 24 | exit $? 25 | -------------------------------------------------------------------------------- /generate-executable.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 3 | SCALA_FOLDER="scala-3.7.0" 4 | cd "$DIR" 5 | sbt assembly 6 | cat "$DIR/generate-executable-prefix" "$DIR/target/$SCALA_FOLDER/ssm.jar" > "$DIR/target/$SCALA_FOLDER/ssm" 7 | chmod +x "$DIR/target/$SCALA_FOLDER/ssm" 8 | echo "ssm executable now available at $DIR/target/$SCALA_FOLDER/ssm" 9 | cd "$DIR/target/$SCALA_FOLDER" 10 | tar -czf ssm.tar.gz ssm 11 | echo "ssm tar.gz file now available at $DIR/target/$SCALA_FOLDER/ssm.tar.gz" 12 | echo "ssm.tar.gz sha256:" 13 | shasum -a 256 ssm.tar.gz 14 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.11.1 2 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.3.1") 2 | 
-------------------------------------------------------------------------------- /scripts/ssh-report:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Lists SSM commands whose first command mentions authorized_keys (presumably ssm ssh key installs) within a time window 3 | if [[ $# -lt 2 ]]; then 4 | echo 'Usage: ssh-report []' 5 | echo ' offsets are of the form supported by macOS date -v, e.g. -1H, -3H' 6 | exit 1 7 | fi 8 | 9 | which jq > /dev/null 10 | if [[ $? -ne 0 ]]; then 11 | echo 'This script requires jq' 12 | exit 1 13 | fi 14 | 15 | PROFILE=$1; shift 16 | START_OFFSET="-v $1"; shift 17 | [[ $1 ]] && STOP_OFFSET="-v $1" && shift 18 | # -u: emit UTC so the literal trailing Z in ISO8601_FORMAT is accurate 19 | ISO8601_FORMAT="%Y-%m-%dT%H:%M:%SZ" 20 | AFTER="$(date -u $START_OFFSET +"$ISO8601_FORMAT")" 21 | BEFORE="$(date -u $STOP_OFFSET +"$ISO8601_FORMAT")" 22 | 23 | FILTER="[ { \"key\": \"InvokedAfter\", \"value\": \"$AFTER\" }, { \"key\": \"InvokedBefore\", \"value\": \"$BEFORE\" } ]" 24 | 25 | aws --profile "$PROFILE" --region eu-west-1 ssm list-commands --filters "$FILTER" \ 26 | | jq '.Commands[]|[select(.Parameters.commands[0]|contains("authorized_keys"))|{Comment, InstanceIds, RequestedDateTime}]' 27 | 28 | -------------------------------------------------------------------------------- /src/main/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | ${user.home}/.ssm/ssm.log 4 | false 5 | 6 | %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/ArgumentParser.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm 2 | 3 | import java.io.File 4 | import com.amazonaws.regions.{Region, RegionUtils, Regions} 5 | import com.gu.ssm.Arguments.{bastionDefaultUser, defaultHostKeyAlgPreference, targetInstanceDefaultUser} 6 | import scopt.OptionParser 7 | 8 | /** Command-line definition for ssm: global options plus the cmd, repl, ssh and scp sub-commands. */ 9 | object ArgumentParser { 10 | 11 | val argParser: OptionParser[Arguments] = new OptionParser[Arguments]("ssm") { 12 | // Top-level options apply regardless of the chosen sub-command 13 | help("help").text("prints this usage text") 14 | 15 | opt[String]('p',
"profile").optional() 16 | .action { (profile, args) => 17 | args.copy(profile = Some(profile)) 18 | } text "The AWS profile name to use for authenticating this execution" 19 | 20 | opt[Seq[String]]('i', "instances") 21 | .action { (instanceIds, args) => 22 | val instances = instanceIds.map(i => InstanceId(i)).toList 23 | args.copy(executionTarget = Some(ExecutionTarget(instances = Some(instances)))) 24 | } text "Specify the instance ID(s) on which the specified command(s) should execute" 25 | 26 | opt[Seq[String]]('t', "tags") 27 | .validate { tagsStr => 28 | Logic.extractSASTags(tagsStr).map(_ => ()) 29 | } 30 | .action { (tagsStr, args) => 31 | Logic.extractSASTags(tagsStr) 32 | .fold( 33 | _ => args, 34 | tagValues => args.copy(executionTarget = Some(ExecutionTarget(tagValues = Some(tagValues)))) 35 | ) 36 | } text "Search for instances by tag. If you provide less than 3 tags assumed order is app,stage,stack." + 37 | " e.g. '--tags riff-raff,prod' or '--tags grafana' Upper/lowercase variations will be tried." 38 | 39 | opt[String]('r', "region").optional() 40 | .validate { region => 41 | try { 42 | if (Option(RegionUtils.getRegion(region)).isEmpty) failure(s"Invalid AWS region name, $region") // getRegion returns null, rather than throwing, for unknown names 43 | else success 44 | } catch { 45 | case _: IllegalArgumentException => 46 | failure(s"Invalid AWS region name, $region") 47 | } 48 | } action { (region, args) => 49 | args.copy(region = RegionUtils.getRegion(region)) 50 | } text "AWS region name (defaults to eu-west-1)" 51 | 52 | opt[Unit]("verbose").action( (_, c) => 53 | c.copy(verbose = true) ).text("enable more verbose logging") 54 | 55 | opt[Unit]("use-default-credentials-provider").optional() 56 | .action((_, args) => args.copy(useDefaultCredentialsProvider = true)) 57 | .text("Use the default AWS credentials provider chain rather than profile credentials.
" + 58 | "This option is required when running within AWS itself.") 59 | // Sub-commands: each sets the mode and may declare its own options via children(...) 60 | cmd("cmd") 61 | .action((_, c) => c.copy(mode = Some(SsmCmd))) 62 | .text("Execute a single (bash) command, or a file containing bash commands") 63 | .children( 64 | opt[String]('u', "user").optional() 65 | .action((user, args) => args.copy(targetInstanceUser = Some(user))) 66 | .text(s"Execute command on remote host as this user (default: $targetInstanceDefaultUser)"), 67 | opt[String]('c', "cmd").optional() 68 | .action((cmd, args) => args.copy(toExecute = Some(cmd))) 69 | .text("A bash command to execute"), 70 | opt[File]('f', "file").optional() 71 | .action((file, args) => args.copy(toExecute = Some(Logic.generateScript(Right(file))))) 72 | .text("A file containing bash commands to execute") 73 | ) 74 | 75 | cmd("repl") 76 | .action((_, c) => c.copy(mode = Some(SsmRepl))) 77 | .text("Run SSM in interactive/repl mode") 78 | 79 | cmd("ssh") 80 | .action((_, c) => c.copy(mode = Some(SsmSsh))) 81 | .text("Create and upload a temporary ssh key") 82 | .children( 83 | opt[String]('u', "user").optional() 84 | .action((user, args) => args.copy(targetInstanceUser = Some(user))) 85 | .text(s"Connect to remote host as this user (default: $targetInstanceDefaultUser)"), 86 | opt[Int]("port").optional() 87 | .action((port, args) => args.copy(targetInstancePortNumber = Some(port))) 88 | .text(s"Connect to remote host on this port"), 89 | opt[Unit]("newest").optional() 90 | .action((_, args) => { 91 | args.copy( 92 | singleInstanceSelectionMode = SismNewest, 93 | isSelectionModeNewest = true) 94 | }) 95 | .text("Selects the newest instance if more than one instance was specified"), 96 | opt[Unit]("oldest").optional() 97 | .action((_, args) => { 98 | args.copy( 99 | singleInstanceSelectionMode = SismOldest, 100 | isSelectionModeOldest = true) 101 | }) 102 | .text("Selects the oldest instance if more than one instance was specified"), 103 | opt[Unit]("private").optional() 104 | 
.action((_, args) => {
105 | args.copy( 106 | usePrivateIpAddress = true) 107 | }) 108 | .text("Use private IP address (must be routable via VPN Gateway)"), 109 | opt[Unit]("raw").optional() 110 | .action((_, args) => { 111 | args.copy( 112 | rawOutput = true) 113 | }) 114 | .text("Unix pipe-able ssh connection string. Note: disables automatic execution. You must use 'eval' to execute this due to nested quoting"), 115 | opt[Unit]('x', "execute").optional() 116 | .action((_, args) => { 117 | args.copy( 118 | rawOutput = true) 119 | }) 120 | .text("[Deprecated - new default behaviour] Makes ssm behave like a single command (eg: `--raw` with automatic piping to the shell)"), 121 | opt[Unit]('d', "dryrun").optional() 122 | .action((_, args) => { 123 | args.copy( 124 | rawOutput = false) 125 | }) 126 | .text("Generate SSH command but do not execute (previous default behaviour)"), 127 | opt[Unit]('A', "agent").optional() 128 | .action((_, args) => { 129 | args.copy( 130 | useAgent = Some(true)) 131 | }) 132 | .text("Use the local ssh agent to register the private key (and do not use -i); only bastion connections"), 133 | opt[Unit]('a', "no-agent").optional() 134 | .action((_, args) => { 135 | args.copy( 136 | useAgent = Some(false)) 137 | }) 138 | .text("Do not use the local ssh agent"), 139 | opt[String]('b', "bastion").optional() 140 | .action((bastion, args) => { 141 | args 142 | .copy(bastionInstance = Some(ExecutionTarget(Some(List(InstanceId(bastion))), None))) 143 | }) 144 | .text(s"Connect through the given bastion specified by its instance id; implies -A (use agent) unless followed by -a."), 145 | opt[Seq[String]]('B', "bastion-tags").optional() 146 | .validate { tagsStr => 147 | Logic.extractSASTags(tagsStr).map(_ => ()) 148 | } 149 | .action { (tagsStr, args) => 150 | Logic.extractSASTags(tagsStr) 151 | .fold( 152 | _ => args, 153 | tagValues => { 154 | args 155 | .copy(bastionInstance = Some(ExecutionTarget(None, Some(tagValues)))) 156 | } 157 | ) 158 | } text(s"Connect through the
given bastion identified by its tags; implies -A (use agent) unless followed by -a."), 159 | opt[Int]("bastion-port").optional() 160 | .action((bastionPortNumber, args) => args.copy(bastionPortNumber = Some(bastionPortNumber))) 161 | .text(s"Connect through the given bastion at a given port. "), 162 | opt[String]("bastion-user").optional() 163 | .action((bastionUser, args) => args.copy(bastionUser = Some(bastionUser))) 164 | .text(s"Connect to bastion as this user (default: $bastionDefaultUser). "), 165 | opt[String]("host-key-alg-preference").optional().unbounded() 166 | .action((alg, args) => args.copy(hostKeyAlgPreference = alg :: args.hostKeyAlgPreference)) 167 | .text(s"The preferred host key algorithms, can be specified multiple times - last is preferred (default: ${defaultHostKeyAlgPreference.mkString(", ")})"), 168 | opt[Unit]("ssm-tunnel").optional() 169 | .text("[deprecated]"), 170 | opt[Unit]("no-ssm-proxy").optional() 171 | .action((_, args) => args.copy(tunnelThroughSystemsManager = false)) 172 | .text("Do not connect to the host proxying via AWS Systems Manager - go direct to port 22. Useful for instances running old versions of systems manager (< 2.3.672.0)"), 173 | opt[String]("tunnel").optional() 174 | .validate { tunnelStr => 175 | Logic.extractTunnelConfig(tunnelStr).map(_ => ()) 176 | } 177 | .action((tunnelStr, args) => { 178 | Logic.extractTunnelConfig(tunnelStr) 179 | .fold( 180 | _ => args, 181 | tunnelTarget => args.copy(tunnelTarget = Some(tunnelTarget))) 182 | }) 183 | .text("Forward traffic from the given local port to the given host and port on the remote side. Accepts the format `localPort:host:remotePort`, " + 184 | "e.g.
--tunnel 5000:a.remote.host.com:5000"), 185 | opt[String]("rds-tunnel").optional() 186 | .validate { tunnelStr => 187 | Logic.extractRDSTunnelConfig(tunnelStr).map(_ => ()) 188 | } 189 | .action((tunnelStr, args) => { 190 | Logic.extractRDSTunnelConfig(tunnelStr) 191 | .fold( 192 | _ => args, 193 | tunnelTarget => args.copy(rdsTunnelTarget = Some(tunnelTarget))) 194 | }) 195 | .text("Forward traffic from a given local port to a RDS database specified by tags. Accepts the format `localPort:tags`, where `tags` is a comma-separated list of tag values, " + 196 | "e.g. --rds-tunnel 5000:app,stack,stage"), 197 | checkConfig( c => 198 | if (c.isSelectionModeOldest && c.isSelectionModeNewest) failure("You cannot both specify --newest and --oldest") 199 | else if (c.tunnelTarget.isDefined && c.rdsTunnelTarget.isDefined) failure("You cannot specify both --tunnel and --rds-tunnel") 200 | else success ) 201 | ) 202 | 203 | cmd("scp") 204 | .action((_, c) => c.copy(mode = Some(SsmScp))) 205 | .text("Secure Copy") 206 | .children( 207 | opt[String]('u', "user").optional() 208 | .action((user, args) => args.copy(targetInstanceUser = Some(user))) 209 | .text(s"Connect to remote host as this user (default: $targetInstanceDefaultUser)"), 210 | opt[Int]("port").optional() 211 | .action((port, args) => args.copy(targetInstancePortNumber = Some(port))) 212 | .text(s"Connect to remote host on this port"), 213 | opt[Unit]("newest").optional() 214 | .action((_, args) => { 215 | args.copy( 216 | singleInstanceSelectionMode = SismNewest, 217 | isSelectionModeNewest = true) 218 | }) 219 | .text("Selects the newest instance if more than one instance was specified"), 220 | opt[Unit]("oldest").optional() 221 | .action((_, args) => { 222 | args.copy( 223 | singleInstanceSelectionMode = SismOldest, 224 | isSelectionModeOldest = true) 225 | }) 226 | .text("Selects the oldest instance if more than one instance was specified"), 227 | opt[Unit]("private").optional() 228 | .action((_, args) => { 229 |
args.copy( 230 | usePrivateIpAddress = true) 231 | }) 232 | .text("Use private IP address (must be routable via VPN Gateway)"), 233 | opt[Unit]("raw").optional() 234 | .action((_, args) => { 235 | args.copy( 236 | rawOutput = true) 237 | }) 238 | .text("Unix pipe-able scp connection string"), 239 | opt[Unit]('x', "execute").optional() 240 | .action((_, args) => { 241 | args.copy( 242 | rawOutput = true) 243 | }) 244 | .text("[Deprecated - new default behaviour] Makes ssm behave like a single command (eg: `--raw` with automatic piping to the shell)"), 245 | opt[Unit]('d', "dryrun").optional() 246 | .action((_, args) => { 247 | args.copy( 248 | rawOutput = false) 249 | }) 250 | .text("Generate SCP command but do not execute (previous default behaviour)"), 251 | opt[Unit]("ssm-tunnel").optional() 252 | .text("[deprecated]"), 253 | opt[Unit]("no-ssm-proxy").optional() 254 | .action((_, args) => args.copy(tunnelThroughSystemsManager = false)) 255 | .text("Do not connect to the host proxying via AWS Systems Manager - go direct to port 22. Useful for instances running old versions of systems manager (< 2.3.672.0)"), 256 | 257 | arg[String]("[:]...").required() 258 | .action( (sourceFile, args) => args.copy(sourceFile = Some(sourceFile)) ) 259 | .text("Source file for the scp sub command. See README for details"), 260 | arg[String]("[:]...").required() 261 | .action( (targetFile, args) => args.copy(targetFile = Some(targetFile)) ) 262 | .text("Target file for the scp sub command.
See README for details"), 263 | checkConfig( c => 264 | if (c.isSelectionModeOldest && c.isSelectionModeNewest) failure("You cannot both specify --newest and --oldest") 265 | else success ) 266 | ) 267 | // Cross-option validation, run by scopt after all arguments have been parsed 268 | checkConfig { args => 269 | if (args.mode.isEmpty) Left("You must select a mode to use: cmd, repl, ssh or scp") 270 | else if (args.toExecute.isEmpty && args.mode.contains(SsmCmd)) Left("You must provide commands to execute (--file or --cmd)") 271 | else if (args.executionTarget.isEmpty) Left("You must provide a list of target instances (-i) or instance App/Stage/Stack tags (-t)") 272 | else if (!args.useDefaultCredentialsProvider && args.profile.isEmpty && !System.getenv().containsKey("AWS_PROFILE")) Left("Expected --profile, --use-default-credentials-provider or AWS_PROFILE environment variable") 273 | else Right(()) 274 | } 275 | } 276 | } 277 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/IO.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm 2 | 3 | import com.amazonaws.regions.Region 4 | import com.amazonaws.services.ec2.AmazonEC2Async 5 | import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceAsync 6 | import com.amazonaws.services.simplesystemsmanagement.AWSSimpleSystemsManagementAsync 7 | import com.gu.ssm.aws.{EC2, SSM, STS, RDS} 8 | import com.gu.ssm.utils.attempt.{ArgumentsError, Attempt, Failure} 9 | 10 | import scala.concurrent.ExecutionContext 11 | import com.amazonaws.services.rds.AmazonRDSAsync 12 | 13 | 14 | object IO { 15 | def resolveInstances(executionTarget: ExecutionTarget, ec2Client: AmazonEC2Async)(implicit ec: ExecutionContext): Attempt[List[Instance]] = { 16 | executionTarget.instances.map( instances => 17 | EC2.resolveInstanceIds(instances, ec2Client) 18 | ).orElse { 19 | executionTarget.tagValues.map(EC2.resolveByTags(_, ec2Client)) 20 | }.getOrElse(Attempt.Left(Failure("Unable to resolve execution
target", "You must provide an execution target (instance(s) or tags)", ArgumentsError))) 21 | } 22 | 23 | def resolveRDSTunnelTarget(target: TunnelTargetWithRDSTags, rdsClient: AmazonRDSAsync)(implicit ec: ExecutionContext): Attempt[TunnelTargetWithHostName] = { 24 | RDS.resolveByTags(target.remoteTags.toList, rdsClient).flatMap { 25 | case rdsInstance :: Nil => Attempt.Right(TunnelTargetWithHostName(target.localPort, rdsInstance.hostname, rdsInstance.port, target.remoteTags)) 26 | case Nil => Attempt.Left(Failure("Could not find target from tags", s"We could not find an RDS instance with the tags: ${target.remoteTags.mkString(", ")}", ArgumentsError)) 27 | case tooManyInstances => 28 | Attempt.Left(Failure("More than one tunnel target resolved from tags", s"We expected to find a single target, but there was more than one tunnel target resolved from the tags: ${target.remoteTags.mkString(", ")}", ArgumentsError)) 29 | } 30 | } 31 | 32 | def executeOnInstances(instanceIds: List[InstanceId], username: String, cmd: String, client: AWSSimpleSystemsManagementAsync)(implicit ec: ExecutionContext): Attempt[List[(InstanceId, Either[CommandStatus, CommandResult])]] = { 33 | for { 34 | cmdId <- SSM.sendCommand(instanceIds, cmd, username, client) 35 | results <- SSM.getCmdOutputs(instanceIds, cmdId, client) 36 | } yield results 37 | } 38 | 39 | def executeOnInstance(instanceId: InstanceId, username: String, script: String, client: AWSSimpleSystemsManagementAsync)(implicit ec: ExecutionContext): Attempt[Either[CommandStatus, CommandResult]] = { 40 | for { 41 | cmdId <- SSM.sendCommand(List(instanceId), script, username, client) 42 | result <- SSM.getCmdOutput(instanceId, cmdId, client).map{ case (_, result) => result } 43 | } yield result 44 | } 45 | 46 | def executeOnInstanceAsync(instanceId: InstanceId, username: String, script: String, client: AWSSimpleSystemsManagementAsync)(implicit ec: ExecutionContext): Attempt[String] = { 47 | for { 48 | cmdId <- 
SSM.sendCommand(List(instanceId), script, username, client) 49 | } yield cmdId 50 | } 51 | 52 | def tagAsTainted(instanceId: InstanceId, username: String,ec2Client: AmazonEC2Async)(implicit ec: ExecutionContext): Attempt[Unit] = 53 | EC2.tagInstance(instanceId, "taintedBy", username, ec2Client) 54 | 55 | def getSSMConfig(ec2Client: AmazonEC2Async, stsClient: AWSSecurityTokenServiceAsync, executionTarget: ExecutionTarget)(implicit ec: ExecutionContext): Attempt[SSMConfig] = { 56 | for { 57 | instances <- IO.resolveInstances(executionTarget, ec2Client) 58 | name <- STS.getCallerIdentity(stsClient) 59 | } yield SSMConfig(instances, name) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/Interactive.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm 2 | 3 | import com.amazonaws.regions.Region 4 | import com.googlecode.lanterna.{TerminalSize, TextColor} 5 | import com.googlecode.lanterna.gui2.Interactable.Result 6 | import com.googlecode.lanterna.gui2._ 7 | import com.googlecode.lanterna.gui2.dialogs.{MessageDialog, WaitingDialog} 8 | import com.googlecode.lanterna.input.{KeyStroke, KeyType} 9 | import com.googlecode.lanterna.terminal.{DefaultTerminalFactory, Terminal, TerminalResizeListener} 10 | import com.gu.ssm.utils.attempt.{Attempt, ErrorCode, FailedAttempt, Failure} 11 | import com.typesafe.scalalogging.LazyLogging 12 | 13 | import scala.concurrent.{ExecutionContext, Future} 14 | 15 | class InteractiveProgram(val awsClients: AWSClients)(implicit ec: ExecutionContext) extends LazyLogging { 16 | val ui = new InteractiveUI(this) 17 | 18 | def main(profile: Option[String], region: Region, executionTarget: ExecutionTarget): Unit = { 19 | // start UI on a new thread (it blocks while it listens for keyboard input) 20 | Future { 21 | ui.start() 22 | } 23 | val configAttempt = for { 24 | config <- IO.getSSMConfig(awsClients.ec2Client, 
awsClients.stsClient, executionTarget) 25 | _ <- Attempt.fromEither(Logic.checkInstancesList(config)) 26 | } yield config 27 | 28 | configAttempt.onComplete { 29 | case Right(SSMConfig(targets, name)) => { 30 | val incorrectInstancesFromInstancesTag = Logic.computeIncorrectInstances(executionTarget, targets.map(i => i.id)) 31 | ui.ready(targets.map(i => i.id), name, incorrectInstancesFromInstancesTag) 32 | } 33 | case Left(failedAttempt) => 34 | ui.displayError(failedAttempt) 35 | ui.ready(List(), "", Nil) 36 | } 37 | } 38 | 39 | /** 40 | * Kick off execution of a new command and update UI when it returns 41 | */ 42 | def executeCommand(command: String, instances: List[InstanceId], username: String, instancesNotFound: List[InstanceId]): Unit = { 43 | IO.executeOnInstances(instances, username, command, awsClients.ssmClient).onComplete { 44 | case Right(results) => 45 | ui.displayResults(instances, username, ResultsWithInstancesNotFound(results, instancesNotFound)) 46 | case Left(fa) => 47 | ui.displayError(fa) 48 | } 49 | } 50 | 51 | def exit(): Unit = { 52 | System.exit(0) 53 | } 54 | } 55 | 56 | class InteractiveUI(program: InteractiveProgram) extends LazyLogging { 57 | val terminalFactory = new DefaultTerminalFactory() 58 | private val screen = terminalFactory.createScreen() 59 | private val guiThreadFactory = new SeparateTextGUIThread.Factory() 60 | val textGUI = new MultiWindowTextGUI(guiThreadFactory, screen) 61 | screen.startScreen() 62 | 63 | /** 64 | * Create window that displays the main UI along with the results of the previous command 65 | */ 66 | def mainWindow(instances: List[InstanceId], username: String, extendedResults: ResultsWithInstancesNotFound): BasicWindow = { 67 | 68 | val window = new BasicWindow(username) 69 | 70 | val initialSize = screen.getTerminal.getTerminalSize 71 | val contentPanel = new Panel(new LinearLayout()) 72 | .setPreferredSize(fullscreenPanelSize(initialSize)) 73 | val layoutManager = 
contentPanel.getLayoutManager.asInstanceOf[LinearLayout] 74 | layoutManager.setSpacing(0) 75 | 76 | val resizer = new TerminalResizeListener { 77 | override def onResized(terminal: Terminal, newSize: TerminalSize): Unit = 78 | contentPanel.setPreferredSize(fullscreenPanelSize(newSize)) 79 | } 80 | 81 | if (instances.nonEmpty) { 82 | contentPanel.addComponent(new Label("Command to run")) 83 | val cmdInput = new TextBox(new TerminalSize(40, 1)) { 84 | override def handleKeyStroke(keyStroke: KeyStroke): Result = { 85 | keyStroke.getKeyType match { 86 | case KeyType.Enter => 87 | program.executeCommand(this.getText, instances, username, extendedResults.instancesNotFound) 88 | val loading = WaitingDialog.createDialog("Executing...", "Executing command on instances") 89 | textGUI.addWindow(loading) 90 | Result.HANDLED 91 | case _ => 92 | super.handleKeyStroke(keyStroke) 93 | } 94 | } 95 | } 96 | contentPanel.addComponent(cmdInput) 97 | } 98 | 99 | if (extendedResults.instancesNotFound.nonEmpty) { 100 | contentPanel.addComponent(new EmptySpace()) 101 | contentPanel.addComponent(new Label(s"The following instance(s) could not be found: ${extendedResults.instancesNotFound.map(_.id).mkString(", ")}").setForegroundColor(TextColor.ANSI.RED)) 102 | contentPanel.addComponent(new EmptySpace()) 103 | } 104 | 105 | // show results, if present 106 | if (extendedResults.results.nonEmpty) { 107 | val outputs = extendedResults.results.zipWithIndex.map { case ((_, result), i) => 108 | val outputStreams = result match { 109 | case Right(cmdResult) => 110 | cmdResult 111 | case Left(status) => 112 | CommandResult("", status.toString, commandFailed = true) 113 | } 114 | i -> outputStreams 115 | }.toMap 116 | 117 | val errOutputBox = new Label(outputs(0).stdErr) 118 | errOutputBox.setForegroundColor(TextColor.ANSI.RED) 119 | val stdOutputBox = new Label(outputs(0).stdOut) 120 | 121 | val listener = new ComboBox.Listener { 122 | override def onSelectionChanged(selectedIndex: Int, 
previousSelection: Int, changedByUserInteraction: Boolean): Unit = { 123 | errOutputBox.setText(outputs(selectedIndex).stdErr) 124 | stdOutputBox.setText(outputs(selectedIndex).stdOut) 125 | } 126 | } 127 | 128 | val instancesComboBox: ComboBox[String] = new ComboBox(instances.map(_.id)*).addListener(listener) 129 | contentPanel.addComponent(instancesComboBox) 130 | 131 | contentPanel.addComponent(new EmptySpace()) 132 | contentPanel.addComponent(new Separator(Direction.HORIZONTAL)) 133 | contentPanel.addComponent(errOutputBox) 134 | contentPanel.addComponent(stdOutputBox) 135 | } 136 | 137 | // close button 138 | contentPanel.addComponent(new EmptySpace()) 139 | contentPanel.addComponent(new Separator(Direction.HORIZONTAL)) 140 | contentPanel.addComponent(new Button("Close", () => { 141 | window.close() 142 | program.exit() 143 | })) 144 | 145 | window.setComponent(contentPanel) 146 | window 147 | } 148 | 149 | /** 150 | * "fullscreen" with space for panel borders 151 | */ 152 | def fullscreenPanelSize(newSize: TerminalSize): TerminalSize = { 153 | new TerminalSize(newSize.getColumns - 4, newSize.getRows - 4) 154 | } 155 | 156 | def start(): Unit = { 157 | logger.debug("Starting interactive UI") 158 | textGUI.getGUIThread.asInstanceOf[AsynchronousTextGUIThread].start() 159 | val window = WaitingDialog.createDialog("Loading...", "Loading instance information") 160 | textGUI.addWindowAndWait(window) 161 | } 162 | 163 | def ready(instances: List[InstanceId], username: String, instancesToReport: List[InstanceId]): Unit = { 164 | logger.trace("resolved instances and username, UI ready") 165 | textGUI.removeWindow(textGUI.getActiveWindow) 166 | textGUI.addWindow(mainWindow(instances, username, ResultsWithInstancesNotFound(Nil, instancesToReport))) 167 | textGUI.updateScreen() 168 | } 169 | 170 | def searching(): Unit = { 171 | logger.trace("waiting to resolve instances and username, UI ready") 172 | textGUI.removeWindow(textGUI.getActiveWindow) 173 | 
textGUI.addWindow(mainWindow(List(), "", ResultsWithInstancesNotFound(Nil, Nil))) 174 | textGUI.updateScreen() 175 | } 176 | 177 | def displayResults(instances: List[InstanceId], username: String, extendedResults: ResultsWithInstancesNotFound): Unit = { 178 | logger.trace("displaying results") 179 | textGUI.removeWindow(textGUI.getActiveWindow) 180 | textGUI.addWindow(mainWindow(instances, username, extendedResults)) 181 | textGUI.updateScreen() 182 | } 183 | 184 | def displayError(fa: FailedAttempt): Unit = { 185 | logger.trace("displaying error") 186 | textGUI.removeWindow(textGUI.getActiveWindow) 187 | MessageDialog.showMessageDialog(textGUI, "Error", fa.failures.map(_.friendlyMessage).mkString(", ")) 188 | } 189 | } 190 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/Logic.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm 2 | 3 | import com.amazonaws.auth.DefaultAWSCredentialsProviderChain 4 | import com.amazonaws.auth.profile.ProfileCredentialsProvider 5 | 6 | import java.io.File 7 | import com.amazonaws.regions.Region 8 | import com.amazonaws.services.ec2.AmazonEC2Async 9 | import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceAsync 10 | import com.amazonaws.services.simplesystemsmanagement.AWSSimpleSystemsManagementAsync 11 | import com.gu.ssm.aws.{EC2, SSM, STS} 12 | import com.gu.ssm.utils.attempt._ 13 | 14 | import scala.io.Source 15 | import com.amazonaws.services.rds.AmazonRDSClient 16 | import com.amazonaws.services.rds.AmazonRDSAsync 17 | import com.gu.ssm.aws.RDS 18 | 19 | object Logic { 20 | def generateScript(toExecute: Either[String, File]): String = { 21 | toExecute match { 22 | case Right(script) => Source.fromFile(script, "UTF-8").mkString 23 | case Left(cmd) => cmd 24 | } 25 | } 26 | 27 | def extractSASTags(tags: Seq[String]): Either[String, List[String]] = { 28 | if (tags.length > 3 || tags.isEmpty || 
tags.head.length == 0) Left("Please supply at least one and no more than 3 tags. " + 29 | "If you specify less than 3 tags order assumed is app,stage,stack") 30 | else Right(tags.toList) 31 | } 32 | 33 | val tunnelValidationErrorMsg = "Please specify a tunnel target in the format localPort:host:remotePort." 34 | 35 | def extractTunnelConfig(tunnelStr: String): Either[String, TunnelTargetWithHostName] = { 36 | tunnelStr.split(":").toList match { 37 | case localPortStr :: targetStr :: remotePortStr :: Nil => 38 | (localPortStr.toIntOption, targetStr, remotePortStr.toIntOption) match { 39 | case (Some(localPort), targetStr, Some(remotePort)) => 40 | Right(TunnelTargetWithHostName(localPort, targetStr, remotePort)) 41 | case _ => Left(s"$tunnelValidationErrorMsg Ports must be integers.") 42 | } 43 | case _ => Left(tunnelValidationErrorMsg) 44 | } 45 | } 46 | 47 | val rdsTunnelValidationErrorMsg = "Please specify a tunnel target in the format localPort:tags, where tags is a comma-separated list of tag values." 
48 | 49 | /** Parses an RDS tunnel spec `localPort:tags`, where `tags` is a comma-separated list checked by extractSASTags; returns rdsTunnelValidationErrorMsg (or extractSASTags' message) on malformed input. */ def extractRDSTunnelConfig(tunnelStr: String): Either[String, TunnelTargetWithRDSTags] = { 50 | tunnelStr.split(":").toList match { 51 | case localPortStr :: tagsStr :: Nil => 52 | localPortStr.toIntOption match { 53 | case Some(localPort) => 54 | extractSASTags(tagsStr.split(",").toSeq).flatMap { tags => 55 | Right(TunnelTargetWithRDSTags(localPort, tags)) 56 | } 57 | case None => Left(rdsTunnelValidationErrorMsg) 58 | } 59 | case _ => Left(rdsTunnelValidationErrorMsg) 60 | } 61 | } 62 | 63 | /** Fails with "No instances found" when the resolved configuration has no target instances. */ def checkInstancesList(config: SSMConfig): Either[FailedAttempt, Unit] = config.targets match { 64 | case List() => Left(FailedAttempt(List(Failure("No instances found", "No instances found", ErrorCode)))) 65 | case _ => Right(()) 66 | } 67 | 68 | /** Picks a single instance: the only one found, or, after sorting by `launchInstant`, the last (SismNewest) or first (SismOldest); fails when none are found or several are found with no selection mode. */ def getSSHInstance(instances: List[Instance], sism: SingleInstanceSelectionMode): Either[FailedAttempt, Instance] = { 69 | instances.sortBy(_.launchInstant) match { 70 | case Nil => Left(FailedAttempt(Failure(s"Unable to identify a single instance", s"Could not find any instance", UnhandledError))) 71 | case instance :: Nil => Right(instance) 72 | case _ :: _ :: _ if sism == SismUnspecified => Left(FailedAttempt(Failure(s"Unable to identify a single instance", s"Error choosing single instance, found ${instances.map(_.id.id).mkString(", ")}.
Use --oldest or --newest to select single instance", UnhandledError))) 73 | case instances if sism == SismNewest => Right(instances.last) // we know that `instances` is not empty, otherwise first case would have applied, therefore calling `.last` is safe 74 | case instance :: _ if sism == SismOldest => Right(instance) 75 | case _ => Left(FailedAttempt(Failure(s"Unable to identify a single instance", s"Could not find any instance", UnhandledError))) 76 | } 77 | } 78 | 79 | /** Builds the SSM, STS, EC2 and RDS async clients for `region`, authenticating with the default provider chain when `useDefaultCredentialsProvider` is set, otherwise with the named `profile` (or the AWS_PROFILE environment variable when no profile is given). */ def getClients(profile: Option[String], region: Region, useDefaultCredentialsProvider: Boolean): AWSClients = { 80 | val credentialsProvider = profile match { 81 | case _ if useDefaultCredentialsProvider => DefaultAWSCredentialsProviderChain.getInstance() 82 | case Some(profile) => new ProfileCredentialsProvider(profile) 83 | // In this case it's set using the AWS_PROFILE environment variable 84 | case _ => new ProfileCredentialsProvider() 85 | } 86 | 87 | val ssmClient: AWSSimpleSystemsManagementAsync = SSM.client(credentialsProvider, region) 88 | val stsClient: AWSSecurityTokenServiceAsync = STS.client(credentialsProvider, region) 89 | val ec2Client: AmazonEC2Async = EC2.client(credentialsProvider, region) 90 | val rdsClient: AmazonRDSAsync = RDS.client(credentialsProvider, region) 91 | AWSClients(ssmClient, stsClient, ec2Client, rdsClient) 92 | } 93 | 94 | /** Returns the explicitly requested instance ids (if any) that do not appear in `instanceIds`. */ def computeIncorrectInstances(executionTarget: ExecutionTarget, instanceIds: List[InstanceId]): List[InstanceId] = 95 | executionTarget.instances.getOrElse(List()).filterNot(instanceIds.toSet) 96 | 97 | /** Chooses the address to connect to: the private IP when `onlyUsePrivateIP` is set, otherwise the public IP if present, falling back to the private IP. Never returns Left. */ def getAddress(instance: Instance, onlyUsePrivateIP: Boolean): Either[FailedAttempt, String] = { 98 | if (onlyUsePrivateIP) { 99 | Right(instance.privateIpAddress) 100 | } else { 101 | instance.publicIpAddressOpt match { 102 | case Some(ipAddress) => Right(ipAddress) 103 | case None => Right(instance.privateIpAddress) 104 | } 105 | } 106 | } 107 | 108 | /** Picks the host key line from the SSM command output whose algorithm prefix comes earliest in `preferredAlgs`; fails with NoHostKey when no line matches, or with AwsError when the command returned a status instead of a result. */ def getHostKeyEntry(ssmResult: Either[CommandStatus, CommandResult], preferredAlgs:
List[String]): Either[FailedAttempt, String] = { 109 | ssmResult match { 110 | case Right(result) => 111 | /* Materialise the lines: they are read twice (key selection, then the diagnostic below), and a one-shot Iterator would already be exhausted by the second read, leaving the diagnostic empty. */ val resultLines = result.stdOut.linesIterator.toList 112 | val preferredKeys = resultLines.filter(hostKey => preferredAlgs.exists(hostKey.startsWith)) 113 | val preferenceOrderedKeys: Seq[String] = preferredKeys.sortBy( 114 | hostKey => preferredAlgs.indexWhere(hostKey.startsWith) 115 | ) 116 | 117 | preferenceOrderedKeys.headOption match { 118 | case Some(hostKey) => Right(hostKey) 119 | case None => Left(Failure( 120 | "host key with preferred algorithm not found", 121 | s"The remote instance did not return a host key with any preferred algorithm (preferred: $preferredAlgs)", 122 | NoHostKey, 123 | s"The result lines returned from the host:\n${resultLines.mkString("\n")}" 124 | ).attempt) 125 | } 126 | case Left(otherStatus) => Left(Failure("host keys not returned", s"The remote instance failed to return the host keys within the timeout window (status: $otherStatus)", AwsError).attempt) 127 | } 128 | } 129 | 130 | } 131 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/Main.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm 2 | 3 | import com.amazonaws.regions.Region 4 | import com.gu.ssm.utils.attempt._ 5 | 6 | import scala.concurrent.duration._ 7 | import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor} 8 | import com.gu.ssm.ArgumentParser.argParser 9 | 10 | /** Command-line entry point: parses the arguments and dispatches to the repl, cmd, ssh or scp mode. */ object Main { 11 | implicit val ec: ExecutionContextExecutor = ExecutionContext.global 12 | 13 | /** Runs the selected mode, prints its output, then exits with the result's error code (0 when there is none). */ def main(args: Array[String]): Unit = { 14 | val (result, verbose) = argParser.parse(args, Arguments.empty()) match { 15 | case Some(Arguments(verbose, Some(executionTarget), toExecuteOpt, profile, region, Some(mode), Some(user), sism, _, _, onlyUsePrivateIP, rawOutput, bastionInstanceIdOpt, bastionPortNumberOpt, Some(bastionUser), targetInstancePortNumberOpt, useAgent,
preferredAlgs, sourceFileOpt, targetFileOpt, tunnelThroughSystemsManager, useDefaultCredentialsProvider, tunnelTarget, rdsTunnelTarget)) => 16 | val awsClients = Logic.getClients(profile, region, useDefaultCredentialsProvider) 17 | val r = mode match { 18 | case SsmRepl => 19 | new InteractiveProgram(awsClients).main(profile, region, executionTarget) 20 | ProgramResult(Nil) 21 | case SsmCmd => 22 | toExecuteOpt match { 23 | case Some(toExecute) => execute(awsClients, executionTarget, user, toExecute) 24 | case _ => fail 25 | } 26 | case SsmSsh => bastionInstanceIdOpt match { 27 | case None => setUpStandardSSH(awsClients, executionTarget, user, sism, onlyUsePrivateIP, rawOutput, targetInstancePortNumberOpt, preferredAlgs, useAgent, profile, region, tunnelThroughSystemsManager, tunnelTarget.orElse(rdsTunnelTarget)) 28 | case Some(bastionInstance) => setUpBastionSSH(awsClients, executionTarget, user, sism, onlyUsePrivateIP, rawOutput, bastionInstance, bastionPortNumberOpt, bastionUser, targetInstancePortNumberOpt, useAgent, preferredAlgs) 29 | } 30 | case SsmScp => (sourceFileOpt, targetFileOpt) match { 31 | case (Some(sourceFile), Some(targetFile)) => setUpStandardScp(awsClients, executionTarget, user, sism, onlyUsePrivateIP, rawOutput, targetInstancePortNumberOpt, preferredAlgs, useAgent, sourceFile, targetFile, profile, region, tunnelThroughSystemsManager) 32 | case _ => fail 33 | } 34 | } 35 | r -> verbose 36 | case Some(_) => fail -> false 37 | case None => ProgramResult(Nil, Some(ArgumentsError)) -> false // parsing cmd line args failed, help message will have been displayed 38 | } 39 | 40 | val ui = new UI(verbose) 41 | ui.printAll(result.output) 42 | System.exit(result.nonZeroExitCode.map(_.code).getOrElse(0)) 43 | } 44 | 45 | /** Result for argument combinations that the CLI parser should already have rejected. */ private def fail: ProgramResult = { 46 | ProgramResult(Seq(Err("Impossible application state! This should be enforced by the CLI parser.
Did not receive valid instructions")), Some(UnhandledError)) 47 | } 48 | 49 | /** Installs a temporary public key on the selected instance via SSM (tagging it as tainted and scheduling the key's removal), writes its host key to a known-hosts file, and returns the ssh command, optionally forwarding a local port to `tunnelTarget` (RDS targets are first resolved to a host name). */ private def setUpStandardSSH( 50 | awsClients: AWSClients, 51 | executionTarget: ExecutionTarget, 52 | user: String, 53 | sism: SingleInstanceSelectionMode, 54 | onlyUsePrivateIP: Boolean, 55 | rawOutput: Boolean, 56 | targetInstancePortNumberOpt: Option[Int], 57 | preferredAlgs: List[String], 58 | useAgent: Option[Boolean], 59 | profile: Option[String], 60 | region: Region, 61 | tunnelThroughSystemsManager: Boolean, 62 | tunnelTarget: Option[TunnelTarget]): ProgramResult = { 63 | val fProgramResult = for { 64 | config <- IO.getSSMConfig(awsClients.ec2Client, awsClients.stsClient, executionTarget) 65 | sshArtifacts <- Attempt.fromEither(SSH.createKey()) 66 | (privateKeyFile, publicKey) = sshArtifacts 67 | addPublicKeyCommand = SSH.addTaintedCommand(config.name) + SSH.addPublicKeyCommand(user, publicKey) + SSH.outputHostKeysCommand() 68 | resolvedTunnelTarget <- Attempt.sequence(tunnelTarget.toList.map { 69 | case t: TunnelTargetWithRDSTags => IO.resolveRDSTunnelTarget(t, awsClients.rdsClient) 70 | case t: TunnelTargetWithHostName => Attempt.Right(t) 71 | }) 72 | removePublicKeyCommand = SSH.removePublicKeyCommand(user, publicKey) 73 | instance <- Attempt.fromEither(Logic.getSSHInstance(config.targets, sism)) 74 | _ <- IO.tagAsTainted(instance.id, config.name, awsClients.ec2Client) 75 | result <- IO.executeOnInstance(instance.id, config.name, addPublicKeyCommand, awsClients.ssmClient) 76 | _ <- IO.executeOnInstanceAsync(instance.id, config.name, removePublicKeyCommand, awsClients.ssmClient) 77 | hostKey <- Attempt.fromEither(Logic.getHostKeyEntry(result, preferredAlgs)) 78 | address <- Attempt.fromEither(Logic.getAddress(instance, onlyUsePrivateIP)) 79 | hostKeyFile <- SSH.writeHostKey((address, hostKey)) 80 | } yield { 81 | SSH.sshCmdStandard(rawOutput)(privateKeyFile, instance, user, address, targetInstancePortNumberOpt, Some(hostKeyFile), useAgent, profile, region,
tunnelThroughSystemsManager, resolvedTunnelTarget.headOption) 82 | } 83 | val programResult = Await.result(fProgramResult.asFuture, Duration.Inf) 84 | ProgramResult.convertErrorToResult(programResult.map(UI.sshOutput(rawOutput))) 85 | } 86 | 87 | /** Like setUpStandardSSH, but installs the temporary key on both the bastion and the target (tainting only the target) and returns an ssh command that proxies through the bastion to the target's private address. */ private def setUpBastionSSH( 88 | awsClients: AWSClients, 89 | executionTarget: ExecutionTarget, 90 | user: String, 91 | sism: SingleInstanceSelectionMode, 92 | onlyUsePrivateIP: Boolean, 93 | rawOutput: Boolean, 94 | bastionInstance: ExecutionTarget, 95 | bastionPortNumberOpt: Option[Int], 96 | bastionUser: String, 97 | targetInstancePortNumberOpt: Option[Int], 98 | useAgent: Option[Boolean], 99 | preferredAlgs: List[String]): ProgramResult = { 100 | val fProgramResult = for { 101 | sshArtifacts <- Attempt.fromEither(SSH.createKey()) 102 | (privateKeyFile, publicKey) = sshArtifacts 103 | bastionConfig <- IO.getSSMConfig(awsClients.ec2Client, awsClients.stsClient, bastionInstance) 104 | bastionInstance <- Attempt.fromEither(Logic.getSSHInstance(bastionConfig.targets, sism)) 105 | /* The ProxyCommand built by sshCmdBastion logs into the bastion as `bastionUser` with this key, so the key must be authorised for that account, not for the target's `user`. */ bastionAddPublicKeyCommand = SSH.addPublicKeyCommand(bastionUser, publicKey) + SSH.outputHostKeysCommand() 106 | bastionRemovePublicKeyCommand = SSH.removePublicKeyCommand(bastionUser, publicKey) 107 | bastionAddress <- Attempt.fromEither(Logic.getAddress(bastionInstance, onlyUsePrivateIP)) 108 | targetConfig <- IO.getSSMConfig(awsClients.ec2Client, awsClients.stsClient, executionTarget) 109 | targetInstance <- Attempt.fromEither(Logic.getSSHInstance(targetConfig.targets, sism)) 110 | targetAddress <- Attempt.fromEither(Logic.getAddress(targetInstance, true)) 111 | targetAddPublicKeyCommand = SSH.addTaintedCommand(targetConfig.name) + SSH.addPublicKeyCommand(user, publicKey) + SSH.outputHostKeysCommand() 112 | targetRemovePublicKeyCommand = SSH.removePublicKeyCommand(user, publicKey) 113 | bastionResult <- IO.executeOnInstance(bastionInstance.id, bastionConfig.name, bastionAddPublicKeyCommand, awsClients.ssmClient) 114 | _ <-
IO.executeOnInstanceAsync(bastionInstance.id, bastionConfig.name, bastionRemovePublicKeyCommand, awsClients.ssmClient) 115 | _ <- IO.tagAsTainted(targetInstance.id, targetConfig.name, awsClients.ec2Client) 116 | targetResult <- IO.executeOnInstance(targetInstance.id, targetConfig.name, targetAddPublicKeyCommand, awsClients.ssmClient) 117 | _ <- IO.executeOnInstanceAsync(targetInstance.id, targetConfig.name, targetRemovePublicKeyCommand, awsClients.ssmClient) 118 | bastionHostKey <- Attempt.fromEither(Logic.getHostKeyEntry(bastionResult, preferredAlgs)) 119 | targetHostKey <- Attempt.fromEither(Logic.getHostKeyEntry(targetResult, preferredAlgs)) 120 | hostKeyFile <- SSH.writeHostKey((bastionAddress, bastionHostKey), (targetAddress, targetHostKey)) 121 | } yield SSH.sshCmdBastion(rawOutput)(privateKeyFile, bastionInstance, targetInstance, user, bastionAddress, targetAddress, bastionPortNumberOpt, bastionUser, targetInstancePortNumberOpt, useAgent, Some(hostKeyFile)) 122 | val programResult = Await.result(fProgramResult.asFuture, Duration.Inf) 123 | ProgramResult.convertErrorToResult(programResult.map(UI.sshOutput(rawOutput))) 124 | } 125 | 126 | /** Installs a temporary key on the selected instance exactly as for ssh and returns the scp command copying `sourceFile` to `targetFile`. */ private def setUpStandardScp( 127 | awsClients: AWSClients, 128 | executionTarget: ExecutionTarget, 129 | user: String, 130 | sism: SingleInstanceSelectionMode, 131 | onlyUsePrivateIP: Boolean, 132 | rawOutput: Boolean, 133 | targetInstancePortNumberOpt: Option[Int], 134 | preferredAlgs: List[String], 135 | useAgent: Option[Boolean], 136 | sourceFile: String, 137 | targetFile: String, 138 | profile: Option[String], 139 | region: Region, 140 | tunnelThroughSystemsManager: Boolean): ProgramResult = { 141 | val fProgramResult = for { 142 | config <- IO.getSSMConfig(awsClients.ec2Client, awsClients.stsClient, executionTarget) 143 | sshArtifacts <- Attempt.fromEither(SSH.createKey()) 144 | (privateKeyFile, publicKey) = sshArtifacts 145 | addPublicKeyCommand = SSH.addTaintedCommand(config.name) + SSH.addPublicKeyCommand(user,
publicKey) + SSH.outputHostKeysCommand() 146 | removePublicKeyCommand = SSH.removePublicKeyCommand(user, publicKey) 147 | instance <- Attempt.fromEither(Logic.getSSHInstance(config.targets, sism)) 148 | _ <- IO.tagAsTainted(instance.id, config.name, awsClients.ec2Client) 149 | result <- IO.executeOnInstance(instance.id, config.name, addPublicKeyCommand, awsClients.ssmClient) 150 | _ <- IO.executeOnInstanceAsync(instance.id, config.name, removePublicKeyCommand, awsClients.ssmClient) 151 | hostKey <- Attempt.fromEither(Logic.getHostKeyEntry(result, preferredAlgs)) 152 | address <- Attempt.fromEither(Logic.getAddress(instance, onlyUsePrivateIP)) 153 | hostKeyFile <- SSH.writeHostKey((address, hostKey)) 154 | } yield { 155 | SSH.scpCmdStandard(rawOutput)(privateKeyFile, instance, user, address, targetInstancePortNumberOpt, useAgent, Some(hostKeyFile), sourceFile, targetFile, profile, region, tunnelThroughSystemsManager) 156 | } 157 | val programResult = Await.result(fProgramResult.asFuture, Duration.Inf) 158 | ProgramResult.convertErrorToResult(programResult.map(UI.sshOutput(rawOutput))) 159 | } 160 | 161 | /** Runs `toExecute` via SSM on every resolved instance and returns the per-instance results together with any explicitly requested instance ids that were not found. */ private def execute(awsClients: AWSClients, executionTarget: ExecutionTarget, user: String, toExecute: String): ProgramResult = { 162 | val fProgramResult = for { 163 | config <- IO.getSSMConfig(awsClients.ec2Client, awsClients.stsClient, executionTarget) 164 | _ <- Attempt.fromEither(Logic.checkInstancesList(config)) 165 | results <- IO.executeOnInstances(config.targets.map(i => i.id), user, toExecute, awsClients.ssmClient) 166 | incorrectInstancesFromInstancesTag = Logic.computeIncorrectInstances(executionTarget, results.map(_._1)) 167 | } yield ResultsWithInstancesNotFound(results, incorrectInstancesFromInstancesTag) 168 | 169 | val programResult = Await.result(fProgramResult.asFuture, Duration.Inf) 170 | ProgramResult.convertErrorToResult(programResult.map(UI.output)) 171 | } 172 | } 173 |
-------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/SSH.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm 2 | 3 | import java.io._ 4 | import java.security.{NoSuchAlgorithmException, NoSuchProviderException} 5 | import java.util.Calendar 6 | 7 | import com.amazonaws.regions.Region 8 | import com.gu.ssm.utils.attempt._ 9 | import com.gu.ssm.utils.{FilePermissions, KeyMaker} 10 | 11 | object SSH { 12 | 13 | val sshCredentialsLifetimeSeconds = 30 14 | 15 | def createKey(): Either[FailedAttempt, (File, String)] = { 16 | 17 | // Write key to file. 18 | val prefix = "security_ssm-scala_temporary-rsa-private-key" 19 | val suffix = ".tmp" 20 | val keyAlgorithm = "RSA" 21 | val keyProvider = "BC" 22 | 23 | try { 24 | val privateKeyFile = File.createTempFile(prefix, suffix, new File(System.getProperty("java.io.tmpdir"))) 25 | FilePermissions(privateKeyFile, "0600") 26 | val publicKey = KeyMaker.makeKey(privateKeyFile, keyAlgorithm, keyProvider) 27 | Right((privateKeyFile, publicKey)) 28 | } catch { 29 | case e:IOException => Left(FailedAttempt( 30 | Failure(s"Unable to create private key file", "Error creating key on disk", UnhandledError, e) 31 | )) 32 | case e:NoSuchAlgorithmException => Left(FailedAttempt( 33 | Failure(s"Unable to create key pair with algorithm $keyAlgorithm", s"Error creating key with algorithm $keyAlgorithm", UnhandledError, e) 34 | )) 35 | case e:NoSuchProviderException => Left(FailedAttempt( 36 | Failure(s"Unable to create key pair with provider $keyProvider", s"Error creating key with provider $keyProvider", UnhandledError, e) 37 | )) 38 | } 39 | } 40 | 41 | def writeHostKey(addressHostKeyTuples: (String, String)*): Attempt[File] = { 42 | // Write key to file. 
43 | val prefix = "security_ssm-scala_temporary-host-key" 44 | val suffix = ".tmp" 45 | 46 | try { 47 | val hostKeyFile = File.createTempFile(prefix, suffix) 48 | val writer = new PrintWriter(new FileOutputStream(hostKeyFile)) 49 | try { 50 | addressHostKeyTuples.foreach { case (address, hostKey) => 51 | writer.println(s"$address $hostKey") 52 | } 53 | } finally { 54 | writer.close() 55 | } 56 | Attempt.Right(hostKeyFile) 57 | } catch { 58 | case e:IOException => Attempt.Left( 59 | Failure(s"Unable to create host key file", "Error creating host key on disk", UnhandledError, e) 60 | ) 61 | } 62 | } 63 | 64 | def addTaintedCommand(name: String): String = { 65 | s""" 66 | | /usr/bin/test -d /etc/update-motd.d/ && 67 | | ( /usr/bin/test -f /etc/update-motd.d/99-tainted || /bin/echo -e '#!/bin/bash' | /usr/bin/sudo /usr/bin/tee -a /etc/update-motd.d/99-tainted >> /dev/null; 68 | | /bin/echo -e 'echo -e "\\033[0;31mThis instance should be considered tainted.\\033[0;39m"' | /usr/bin/sudo /usr/bin/tee -a /etc/update-motd.d/99-tainted >> /dev/null; 69 | | /bin/echo -e 'echo -e "\\033[0;31mIt was accessed by $name at ${Calendar.getInstance().getTime}\\033[0;39m"' | /usr/bin/sudo /usr/bin/tee -a /etc/update-motd.d/99-tainted >> /dev/null; 70 | | /usr/bin/sudo /bin/chmod 0755 /etc/update-motd.d/99-tainted; 71 | | /usr/bin/sudo /bin/run-parts /etc/update-motd.d/ | /usr/bin/sudo /usr/bin/tee /run/motd.dynamic >> /dev/null; 72 | | ) """.stripMargin 73 | } 74 | 75 | def addPublicKeyCommand(user: String, publicKey: String): String = 76 | s""" 77 | | /bin/mkdir -p /home/$user/.ssh; 78 | | /bin/echo '$publicKey' >> /home/$user/.ssh/authorized_keys; 79 | | /bin/chown $user /home/$user/.ssh/authorized_keys; 80 | | /bin/chmod 0600 /home/$user/.ssh/authorized_keys; 81 | |""".stripMargin 82 | 83 | def removePublicKeyCommand(user: String, publicKey: String): String = 84 | s""" 85 | | /bin/sleep $sshCredentialsLifetimeSeconds; 86 | | /bin/sed -i '/${publicKey.replaceAll("/", "\\\\/")}/d' 
/home/$user/.ssh/authorized_keys; 87 | |""".stripMargin 88 | 89 | def outputHostKeysCommand(): String = 90 | """ 91 | | for hostkey in $(sshd -T 2> /dev/null |grep "^hostkey " | cut -d ' ' -f 2); do cat $hostkey.pub; done 92 | """.stripMargin 93 | 94 | def sshCmdStandard(rawOutput: Boolean)(privateKeyFile: File, instance: Instance, user: String, ipAddress: String, targetInstancePortNumberOpt: Option[Int], hostsFile: Option[File], useAgent: Option[Boolean], profile: Option[String], region: Region, tunnelThroughSystemsManager: Boolean, tunnelTarget: Option[TunnelTargetWithHostName]): (InstanceId, Seq[Output]) = { 95 | val targetPortSpecifications = targetInstancePortNumberOpt match { 96 | case Some(portNumber) => s" -p ${portNumber}" 97 | case _ => "" 98 | } 99 | val theTTOptions = if(rawOutput) { " -t -t" }else{ "" } 100 | val useAgentFragment = useAgent match { 101 | case None => "" 102 | case Some(decision) => if(decision) " -A" else " -a" 103 | } 104 | val hostsFileString = hostsFile.map(file => s""" -o "UserKnownHostsFile $file" -o "StrictHostKeyChecking yes"""").getOrElse("") 105 | val proxyFragment = if(tunnelThroughSystemsManager) { s""" -o "ProxyCommand sh -c \\"aws ssm start-session --target ${instance.id.id} --document-name AWS-StartSSHSession --parameters 'portNumber=22' --region $region ${profile.map("--profile " + _).getOrElse("")}\\""""" } else { "" } 106 | 107 | val (tunnelString, tunnelMeta) = tunnelTarget.map(t => ( 108 | s"-L ${t.localPort}:${t.remoteHostName}:${t.remotePort} -N -f", 109 | Seq( 110 | Metadata(s"# If the command succeeded, a tunnel has been established."), 111 | Metadata(s"# Local port: ${t.localPort}"), 112 | Metadata(s"# Remote address: ${t.remoteHostName}:${t.remotePort}") 113 | ) ++ t.remoteTags.map(_.toLowerCase).find(_.contains("prod")).map { _ => 114 | Metadata(s"# The tags indicate that this is a PRODUCTION resource. Please take care! 
Perhaps bring a pair?") 115 | } 116 | )).getOrElse(("", Seq.empty)) 117 | 118 | val connectionString = s"""ssh -o "IdentitiesOnly yes"$useAgentFragment$hostsFileString$targetPortSpecifications$proxyFragment -i ${privateKeyFile.getCanonicalFile.toString}${theTTOptions} $user@$ipAddress $tunnelString""".trim() 119 | 120 | val cmd = if (rawOutput) { 121 | Seq(Out(s"$connectionString", newline = false)) ++ tunnelMeta.toList 122 | } else { 123 | Seq( 124 | Metadata(s"# Dryrun mode. The command below will remain valid for $sshCredentialsLifetimeSeconds seconds:"), 125 | Out(s"$connectionString;") 126 | ) 127 | } 128 | (instance.id, cmd) 129 | } 130 | 131 | def sshCmdBastion(rawOutput: Boolean)(privateKeyFile: File, bastionInstance: Instance, targetInstance: Instance, targetInstanceUser: String, bastionIpAddress: String, targetIpAddress: String, bastionPortNumberOpt: Option[Int], bastionUser: String, targetInstancePortNumberOpt: Option[Int], useAgent: Option[Boolean], hostsFile: Option[File]): (InstanceId, Seq[Output]) = { 132 | val bastionPort = bastionPortNumberOpt.getOrElse(22) 133 | val targetPort = targetInstancePortNumberOpt.getOrElse(22) 134 | val hostsFileString = hostsFile.map(file => s""" -o "UserKnownHostsFile $file" -o "StrictHostKeyChecking yes"""").getOrElse("") 135 | val identityFragment = s"-i ${privateKeyFile.getCanonicalFile.toString}" 136 | val proxyFragment = s"""-o 'ProxyCommand ssh -o "IdentitiesOnly yes" $identityFragment$hostsFileString -p $bastionPort $bastionUser@$bastionIpAddress nc $targetIpAddress $targetPort'""" 137 | val stringFragmentTTOptions = if(rawOutput) { " -t -t" } else { "" } 138 | val useAgentFragment = useAgent match { 139 | case None => "" 140 | case Some(decision) => if(decision) " -A" else " -a" 141 | } 142 | val connectionString = 143 | s"""ssh$useAgentFragment -o "IdentitiesOnly yes" $identityFragment$hostsFileString $proxyFragment$stringFragmentTTOptions $targetInstanceUser@$targetIpAddress""" 144 | val cmd = if(rawOutput) { 
145 | Seq(Out(s"$connectionString", newline = false)) 146 | }else{ 147 | Seq( 148 | Metadata(s"# Dryrun mode. The command below will remain valid for $sshCredentialsLifetimeSeconds seconds:"), 149 | Out(s"$connectionString;") 150 | ) 151 | } 152 | (targetInstance.id, cmd) 153 | } 154 | 155 | // The first file goes to the second file 156 | // The remote file is indicated by a colon 157 | 158 | def scpCmdStandard(rawOutput: Boolean)(privateKeyFile: File, instance: Instance, user: String, ipAddress: String, targetInstancePortNumberOpt: Option[Int], useAgent: Option[Boolean], hostsFile: Option[File], sourceFile: String, targetFile: String, profile: Option[String], region: Region, tunnelThroughSystemsManager: Boolean): (InstanceId, Seq[Output]) = { 159 | 160 | def isRemote(filepath: String): Boolean = { 161 | filepath.startsWith(":") 162 | } 163 | 164 | def exactlyOneArgumentIsRemote(filepath1: String, filepath2: String): Boolean = { 165 | List(filepath1, filepath2).map(isRemote).count(_ == true) == 1 166 | } 167 | 168 | val targetPortSpecifications = targetInstancePortNumberOpt match { 169 | case Some(portNumber) => s" -p ${portNumber}" 170 | case _ => "" 171 | } 172 | val hostsFileString = hostsFile.map(file => s""" -o "UserKnownHostsFile $file" -o "StrictHostKeyChecking yes"""").getOrElse("") 173 | val proxyFragment = if(tunnelThroughSystemsManager) { s""" -o "ProxyCommand sh -c \\"aws ssm start-session --target ${instance.id.id} --document-name AWS-StartSSHSession --parameters 'portNumber=22' --region $region ${profile.map("--profile " + _).getOrElse("")}\\""""" } else { "" } 174 | val useAgentFragment = useAgent match { 175 | case None => "" 176 | case Some(decision) => if(decision) " -A" else " -a" 177 | } 178 | // We are using colon to designate the remote file. 179 | // There should be only one. 
180 | if (exactlyOneArgumentIsRemote(sourceFile, targetFile)) { 181 | val connectionString = 182 | if (isRemote(sourceFile)) { 183 | s"""scp -o "IdentitiesOnly yes"$useAgentFragment$hostsFileString$proxyFragment${targetPortSpecifications} -i ${privateKeyFile.getCanonicalFile.toString} $user@$ipAddress:${sourceFile.stripPrefix(":")} ${targetFile}""" 184 | }else { 185 | s"""scp -o "IdentitiesOnly yes"$useAgentFragment$hostsFileString$proxyFragment${targetPortSpecifications} -i ${privateKeyFile.getCanonicalFile.toString} ${sourceFile} $user@$ipAddress:${targetFile.stripPrefix(":")}""" 186 | } 187 | val cmd = if(rawOutput) { 188 | Seq(Out(s"$connectionString", newline = false)) 189 | }else{ 190 | Seq( 191 | Metadata(s"# Dryrun mode. The command below will remain valid for $sshCredentialsLifetimeSeconds seconds:"), 192 | Out(s"$connectionString;") 193 | ) 194 | } 195 | (instance.id, cmd) 196 | }else{ 197 | (instance.id, Seq(Err("Incorrect remote server specifications, only one file should carry the starting colon"))) 198 | } 199 | 200 | } 201 | 202 | } 203 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/UI.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm 2 | 3 | import java.io.{ByteArrayOutputStream, PrintWriter} 4 | import com.gu.ssm.utils.attempt.{ErrorCode, ExitCode, FailedAttempt} 5 | 6 | import scala.collection.mutable 7 | 8 | sealed trait Output { 9 | def text: String 10 | def newline: Boolean = true 11 | } 12 | case class Out(text: String, override val newline: Boolean = true) extends Output 13 | case class Metadata(text: String) extends Output 14 | case class Err(text: String, throwable: Option[Throwable] = None) extends Output 15 | case class Verbose(text: String) extends Output 16 | 17 | case class ProgramResult(output: Seq[Output], nonZeroExitCode: Option[ExitCode] = None) 18 | object ProgramResult { 19 | def 
convertErrorToResult(programResult: Either[FailedAttempt, ProgramResult]): ProgramResult = { 20 | programResult.fold ( 21 | failedAttempt => ProgramResult(UI.outputFailure(failedAttempt), Some(failedAttempt.exitCode)), 22 | identity 23 | ) 24 | } 25 | } 26 | 27 | object UI { 28 | implicit class RichString(val s: String) extends AnyVal { 29 | def colour(colour: String): String = { 30 | colour + s + Console.RESET 31 | } 32 | } 33 | implicit class RichThrowable(val t: Throwable) extends AnyVal { 34 | def getAsString: String = { 35 | val baos = new ByteArrayOutputStream() 36 | val pw = new PrintWriter(baos) 37 | t.printStackTrace(pw) 38 | pw.close() 39 | baos.toString 40 | } 41 | } 42 | // Builds printable output for a batch of SSM command results; the exit code is ErrorCode if any command failed. 43 | def output(extendedResults: ResultsWithInstancesNotFound): ProgramResult = { 44 | val buffer = mutable.Buffer.empty[Output] 45 | if(extendedResults.instancesNotFound.nonEmpty){ 46 | buffer += Err(s"The following instance(s) could not be found: ${extendedResults.instancesNotFound.map(_.id).mkString(", ")}\n") 47 | } 48 | extendedResults.results.foreach { case (instance, result) => // foreach, not flatMap: only the side effect on buffer is wanted 49 | buffer += Metadata(s"========= ${instance.id} =========") 50 | result match { 51 | case Left(commandStatus) => 52 | buffer += Err(commandStatus.toString) 53 | case Right(commandResult) => 54 | buffer ++= Seq( 55 | Metadata(s"STDOUT:"), 56 | Out(commandResult.stdOut), 57 | Metadata(s"STDERR:"), 58 | Err(commandResult.stdErr) 59 | ) 60 | } 61 | } 62 | 63 | val nonZeroExitCode = if (hasAnyCommandFailed(extendedResults.results)) Some(ErrorCode) else None 64 | ProgramResult(buffer.toList, nonZeroExitCode) 65 | } 66 | 67 | def sshOutput(rawOutput: Boolean)(result: (InstanceId, Seq[Output])): ProgramResult = ProgramResult( 68 | if (rawOutput){ 69 | result._2 70 | } else { 71 | Metadata(s"========= ${result._1.id} =========") +: result._2 72 | } 73 | ) 74 | 75 | def outputFailure(failedAttempt: FailedAttempt): Seq[Output] = { 76 | failedAttempt.failures.flatMap { failure => 77 |
Seq(Err(failure.friendlyMessage, failure.throwable)) ++ failure.context.map(Verbose.apply) 78 | } 79 | } 80 | 81 | def hasAnyCommandFailed(ssmResults: List[(InstanceId, Either[CommandStatus, CommandResult])]): Boolean = { 82 | ssmResults.exists { case(_, result) => result.exists(_.commandFailed) } 83 | } 84 | } 85 | 86 | class UI(verbose: Boolean) { 87 | import UI._ 88 | 89 | def printAll(output: Seq[Output]): Unit = print(output*) 90 | 91 | def print(output: Output*): Unit = { 92 | output.foreach { 93 | case Out(text, true) => System.out.println(text) 94 | case Out(text, false) => System.out.print(text) 95 | case Metadata(text) => printMetadata(text) 96 | case Err(text, maybeThrowable) => 97 | printErr(text) 98 | maybeThrowable.foreach { t => printVerbose(t.getAsString) } 99 | case Verbose(text) => printVerbose(text) 100 | } 101 | } 102 | 103 | def printVerbose(text: String): Unit = { 104 | if (verbose) System.err.println(text.colour(Console.BLUE)) 105 | } 106 | 107 | def printMetadata(text: String): Unit = { 108 | System.err.println(text.colour(Console.CYAN)) 109 | } 110 | 111 | def printErr(text: String): Unit = { 112 | System.err.println(text.colour(Console.YELLOW)) 113 | } 114 | 115 | } 116 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/aws/AwsAsyncHandler.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm.aws 2 | 3 | import com.amazonaws.AmazonWebServiceRequest 4 | import com.amazonaws.handlers.AsyncHandler 5 | import com.gu.ssm.utils.attempt.{Attempt, AwsError, AwsPermissionsError, Failure} 6 | import com.typesafe.scalalogging.LazyLogging 7 | 8 | import scala.concurrent.{ExecutionContext, Future, Promise} 9 | 10 | 11 | object AwsAsyncHandler { 12 | private val ServiceName = ".*Service: ([^;]+);.*".r 13 | def awsToScala[R <: AmazonWebServiceRequest, T](sdkMethod: ( (R, AsyncHandler[R, T]) => java.util.concurrent.Future[T])): (R => 
Future[T]) = { req => 14 | val p = Promise[T]() 15 | sdkMethod(req, new AwsAsyncPromiseHandler(p)) 16 | p.future 17 | } 18 | 19 | /** 20 | * Wraps `f` as an Attempt; on failure, translates well-known AWS error messages into user-friendly Failures. 21 | */ 22 | def handleAWSErrs[T](f: Future[T])(implicit ec: ExecutionContext): Attempt[T] = { 23 | Attempt.fromFuture(f) { case e => 24 | // getMessage may be null (e.g. for some NullPointerExceptions); default to "" so this handler cannot itself throw. 25 | val message = Option(e.getMessage).getOrElse("") 26 | val serviceNameOpt = Some(message).collect { case ServiceName(serviceName) => serviceName } 27 | 28 | if (message.contains("Request has expired")) { 29 | Failure("expired AWS credentials", "Failed to request data from AWS, the temporary credentials have expired", AwsPermissionsError, e).attempt 30 | } else if (message.contains("Unable to load AWS credentials from any provider in the chain")) { 31 | Failure("No AWS credentials found", "No AWS credentials found. Did you mean to set --profile?", AwsPermissionsError, e).attempt 32 | } else if (message.contains("No AWS profile named")) { 33 | Failure("Invalid AWS profile name (does not exist)", "The specified AWS profile does not exist", AwsPermissionsError, e).attempt 34 | } else if (message.contains("is not authorized to perform")) { 35 | val permissionsMessage = serviceNameOpt.fold("You do not have sufficient AWS privileges")(serviceName => s"You do not have sufficient privileges to perform actions on $serviceName") 36 | Failure("insufficient permissions", permissionsMessage, AwsPermissionsError, e).attempt 37 | } else if (message.contains("InvalidInstanceId")) { 38 | Failure("InvalidInstanceId from AWS", "The specified instance(s) are not eligible targets (AWS said InvalidInstanceId)", AwsError, e).attempt 39 | } else { 40 | val details = serviceNameOpt.fold(s"AWS unknown error, unknown service (check logs for stacktrace). $e") { serviceName => 41 | s"AWS unknown error, service: $serviceName (check logs for stacktrace), $e" 42 | } 43 | val friendlyMessage = serviceNameOpt.fold(s"Unknown error while making API calls to AWS.
$e") { serviceName => 44 | s"Unknown error while making an API call to AWS' $serviceName service, $e" 45 | } 46 | Failure(details, friendlyMessage, AwsError, e).attempt 47 | } 48 | } 49 | } 50 | 51 | class AwsAsyncPromiseHandler[R <: AmazonWebServiceRequest, T](promise: Promise[T]) extends AsyncHandler[R, T] with LazyLogging { 52 | def onError(e: Exception): Unit = { 53 | logger.warn("Failed to execute AWS SDK operation", e) 54 | promise failure e 55 | } 56 | def onSuccess(r: R, t: T): Unit = { 57 | promise success t 58 | } 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/aws/EC2.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm.aws 2 | 3 | import com.amazonaws.auth.AWSCredentialsProvider 4 | import com.amazonaws.regions.Region 5 | import com.amazonaws.services.ec2.model._ 6 | import com.amazonaws.services.ec2.{AmazonEC2Async, AmazonEC2AsyncClientBuilder} 7 | import com.gu.ssm.aws.AwsAsyncHandler.{awsToScala, handleAWSErrs} 8 | import com.gu.ssm.utils.attempt.Attempt 9 | import com.gu.ssm.{Instance, InstanceId} 10 | 11 | import scala.concurrent.ExecutionContext 12 | import scala.jdk.CollectionConverters._ 13 | 14 | 15 | object EC2 { 16 | def client(credentialsProvider: AWSCredentialsProvider, region: Region): AmazonEC2Async = { 17 | AmazonEC2AsyncClientBuilder.standard() 18 | .withCredentials(credentialsProvider) 19 | .withRegion(region.getName) 20 | .build() 21 | } 22 | 23 | def makeFilter(tagName: String, values: List[String]) = new Filter(s"tag:$tagName", values.asJava) 24 | 25 | def resolveByTags(tagValues: List[String], client: AmazonEC2Async)(implicit ec: ExecutionContext): Attempt[List[Instance]] = { 26 | val allTags = tagValues ++ tagValues.map(_.toUpperCase) ++ tagValues.map(_.toLowerCase) 27 | 28 | // if user has provided fewer than 3 tags then assume order app,stage,stack 29 | val tagOrder = List("App", "Stage", 
"Stack") 30 | val filters = new Filter("instance-state-name", List("running").asJava) :: 31 | tagOrder.take(tagValues.length).map(makeFilter(_, allTags)) 32 | 33 | val request = new DescribeInstancesRequest() 34 | .withFilters( 35 | filters* 36 | ) 37 | handleAWSErrs(awsToScala(client.describeInstancesAsync)(request).map(extractInstances)) 38 | } 39 | 40 | def resolveInstanceIds(ids: List[InstanceId], client: AmazonEC2Async)(implicit ec: ExecutionContext): Attempt[List[Instance]] = { 41 | val request = new DescribeInstancesRequest() 42 | .withFilters( 43 | new Filter("instance-state-name", List("running").asJava), 44 | new Filter("instance-id", ids.map(i => i.id).asJava) 45 | ) 46 | handleAWSErrs(awsToScala(client.describeInstancesAsync)(request).map(extractInstances)) 47 | } 48 | 49 | private def extractInstances(describeInstancesResult: DescribeInstancesResult): List[Instance] = { 50 | (for { 51 | reservation <- describeInstancesResult.getReservations.asScala 52 | awsInstance <- reservation.getInstances.asScala 53 | instanceId = awsInstance.getInstanceId 54 | launchDateTime = awsInstance.getLaunchTime.toInstant 55 | } yield Instance(InstanceId(instanceId), Option(awsInstance.getPublicDnsName), Option(awsInstance.getPublicIpAddress), awsInstance.getPrivateIpAddress, launchDateTime)).toList 56 | } 57 | 58 | def tagInstance(id: InstanceId, key: String, value: String, client: AmazonEC2Async)(implicit ec: ExecutionContext): Attempt[Unit] = { 59 | val request = new CreateTagsRequest() 60 | .withTags(new Tag(key, value)) 61 | .withResources(id.id) 62 | handleAWSErrs(awsToScala(client.createTagsAsync)(request)).map(_ => ()) 63 | } 64 | 65 | } 66 | 67 | 68 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/aws/RDS.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm.aws 2 | 3 | import com.amazonaws.services.rds.model.{DescribeDBInstancesRequest, Filter} 4 | 
import com.amazonaws.services.rds.AmazonRDSAsync 5 | import com.gu.ssm.utils.attempt.Attempt 6 | import com.gu.ssm.aws.AwsAsyncHandler.{awsToScala, handleAWSErrs} 7 | import com.amazonaws.services.rds.AmazonRDSAsyncClientBuilder 8 | import com.amazonaws.auth.AWSCredentialsProvider 9 | import com.amazonaws.regions.Region 10 | import com.amazonaws.services.rds.model.DescribeDBInstancesResult 11 | import scala.concurrent.ExecutionContext 12 | import scala.jdk.CollectionConverters._ 13 | import com.gu.ssm.RDSInstance 14 | import com.gu.ssm.RDSInstanceId 15 | import com.amazonaws.services.rds.model.DBInstance 16 | 17 | object RDS { 18 | def client(credentialsProvider: AWSCredentialsProvider, region: Region): AmazonRDSAsync = { 19 | AmazonRDSAsyncClientBuilder.standard() 20 | .withCredentials(credentialsProvider) 21 | .withRegion(region.getName) 22 | .build() 23 | } 24 | 25 | def resolveByTags(tagValues: List[String], client: AmazonRDSAsync)(implicit ec: ExecutionContext): Attempt[List[RDSInstance]] = { 26 | val request = new DescribeDBInstancesRequest() // NOTE(review): only the first page of results is read; the response Marker is not followed 27 | 28 | handleAWSErrs(awsToScala(client.describeDBInstancesAsync)(request).map { result => 29 | result.getDBInstances.asScala.toList 30 | .filter(hasTagList(tagValues)) 31 | .flatMap(toInstance) // instances without an endpoint are skipped instead of failing the whole lookup 32 | }) 33 | } 34 | 35 | private def hasTagList(tagValues: List[String])(awsInstance: DBInstance): Boolean = { 36 | val instanceTags = awsInstance.getTagList().asScala.toList.map(_.getValue()) 37 | tagValues.forall(requiredTag => instanceTags.contains(requiredTag)) 38 | } 39 | 40 | private def toInstance(awsInstance: DBInstance): Option[RDSInstance] = 41 | Option(awsInstance.getEndpoint()).map { endpoint => // the endpoint can be absent, e.g. while the instance is still being created 42 | RDSInstance(RDSInstanceId(awsInstance.getDBInstanceIdentifier()), endpoint.getAddress(), endpoint.getPort()) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/aws/SSM.scala: -------------------------------------------------------------------------------- 1 |
package com.gu.ssm.aws 2 | 3 | import com.amazonaws.auth.AWSCredentialsProvider 4 | import com.amazonaws.regions.Region 5 | import com.amazonaws.services.simplesystemsmanagement.model._ 6 | import com.amazonaws.services.simplesystemsmanagement.{AWSSimpleSystemsManagementAsync, AWSSimpleSystemsManagementAsyncClientBuilder} 7 | import com.gu.ssm.aws.AwsAsyncHandler.{awsToScala, handleAWSErrs} 8 | import com.gu.ssm.utils.attempt.Attempt 9 | import com.gu.ssm.{CommandStatus, _} 10 | 11 | import scala.jdk.CollectionConverters._ 12 | import scala.concurrent.ExecutionContext 13 | import scala.concurrent.duration._ 14 | 15 | 16 | object SSM { 17 | def client(credentialsProvider: AWSCredentialsProvider, region: Region): AWSSimpleSystemsManagementAsync = { 18 | AWSSimpleSystemsManagementAsyncClientBuilder.standard() 19 | .withCredentials(credentialsProvider) 20 | .withRegion(region.getName) 21 | .build() 22 | } 23 | 24 | def sendCommand(instanceIds: List[InstanceId], cmd: String, username: String, client: AWSSimpleSystemsManagementAsync)(implicit ec: ExecutionContext): Attempt[String] = { 25 | val parameters = Map("commands" -> List(cmd).asJava).asJava 26 | val sendCommandRequest = new SendCommandRequest() 27 | .withComment(s"Command submitted by $username") 28 | .withInstanceIds(instanceIds.map(_.id).asJava) 29 | .withDocumentName("AWS-RunShellScript") 30 | .withParameters(parameters) 31 | handleAWSErrs(awsToScala(client.sendCommandAsync)(sendCommandRequest).map(extractCommandId)) 32 | } 33 | 34 | def extractCommandId(sendCommandResult: SendCommandResult): String = { 35 | sendCommandResult.getCommand.getCommandId 36 | } 37 | 38 | def getCommandInvocation(instance: InstanceId, commandId: String, client: AWSSimpleSystemsManagementAsync)(implicit ec: ExecutionContext): Attempt[Either[CommandStatus, CommandResult]] = { 39 | val request = new GetCommandInvocationRequest() 40 | .withCommandId(commandId) 41 | .withInstanceId(instance.id) 42 | handleAWSErrs( 43 | 
awsToScala(client.getCommandInvocationAsync)(request) 44 | .map(extractCommandResult) 45 | .recover { case _:InvocationDoesNotExistException => Left(InvocationDoesNotExist) } 46 | ) 47 | } 48 | 49 | def extractCommandResult(getCommandInvocationResult: GetCommandInvocationResult): Either[CommandStatus, CommandResult] = { 50 | commandStatus(getCommandInvocationResult.getStatusDetails) match { 51 | case Success => 52 | Right(CommandResult(getCommandInvocationResult.getStandardOutputContent, getCommandInvocationResult.getStandardErrorContent, commandFailed = false)) 53 | case Failed => 54 | Right(CommandResult(getCommandInvocationResult.getStandardOutputContent, getCommandInvocationResult.getStandardErrorContent, commandFailed = true)) 55 | case status => 56 | Left(status) 57 | } 58 | } 59 | 60 | def getCmdOutput(instance: InstanceId, commandId: String, client: AWSSimpleSystemsManagementAsync)(implicit ec: ExecutionContext): Attempt[(InstanceId, Either[CommandStatus, CommandResult])] = { 61 | for { 62 | cmdResult <- Attempt.retryUntil(delayBetweenRetries = 500.millis, () => getCommandInvocation(instance, commandId, client))(_.isRight) 63 | } yield instance -> cmdResult 64 | } 65 | 66 | def getCmdOutputs(instanceIds: List[InstanceId], commandId: String, client: AWSSimpleSystemsManagementAsync)(implicit ec: ExecutionContext): Attempt[List[(InstanceId, Either[CommandStatus, CommandResult])]] = { 67 | Attempt.traverse(instanceIds)(getCmdOutput(_, commandId, client)) 68 | } 69 | 70 | def commandStatus(statusDetail: String): CommandStatus = { 71 | statusDetail match { 72 | case "Pending" => 73 | Pending 74 | case "InProgress" => 75 | InProgress 76 | case "Delayed" => 77 | Delayed 78 | case "Success" => 79 | Success 80 | case "DeliveryTimedOut" => 81 | DeliveryTimedOut 82 | case "ExecutionTimedOut" => 83 | ExecutionTimedOut 84 | case "Failed" => 85 | Failed 86 | case "Canceled" => 87 | Canceled 88 | case "Undeliverable" => 89 | Undeliverable 90 | case "Terminated" => 91 | 
Terminated 92 | case _ => 93 | throw new RuntimeException(s"Unexpected command status $statusDetail") 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/aws/STS.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm.aws 2 | 3 | import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration 4 | import com.amazonaws.auth.AWSCredentialsProvider 5 | import com.amazonaws.regions.Region 6 | import com.amazonaws.services.securitytoken.model.{GetCallerIdentityRequest, GetCallerIdentityResult} 7 | import com.amazonaws.services.securitytoken.{AWSSecurityTokenServiceAsync, AWSSecurityTokenServiceAsyncClientBuilder} 8 | import com.gu.ssm.aws.AwsAsyncHandler.{awsToScala, handleAWSErrs} 9 | import com.gu.ssm.utils.attempt.Attempt 10 | 11 | import scala.concurrent.ExecutionContext 12 | 13 | 14 | object STS { 15 | def client(credentialsProvider: AWSCredentialsProvider, region: Region): AWSSecurityTokenServiceAsync = { 16 | AWSSecurityTokenServiceAsyncClientBuilder.standard() 17 | .withCredentials(credentialsProvider) 18 | // STS is a global service but you need to access the regional endpoint if using it through an endpoint in VPCs 19 | // that have no outbound internet access. 
https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html 20 | .withEndpointConfiguration(new EndpointConfiguration(region.getServiceEndpoint("sts"), region.getName)) 21 | .build() 22 | } 23 | 24 | def getCallerIdentity(client: AWSSecurityTokenServiceAsync)(implicit ec: ExecutionContext): Attempt[String] = { 25 | val request = new GetCallerIdentityRequest() 26 | handleAWSErrs(awsToScala(client.getCallerIdentityAsync)(request).map(extractUserId)) 27 | } 28 | 29 | def extractUserId(getCallerIdentityResult: GetCallerIdentityResult): String = { 30 | getCallerIdentityResult.getUserId 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/models.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm 2 | 3 | import com.amazonaws.regions.{Region, Regions} 4 | import com.amazonaws.services.ec2.AmazonEC2Async 5 | import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceAsync 6 | import com.amazonaws.services.simplesystemsmanagement.AWSSimpleSystemsManagementAsync 7 | import java.time.Instant 8 | import com.amazonaws.services.rds.AmazonRDSAsync 9 | 10 | case class InstanceId(id: String) extends AnyVal 11 | case class Instance(id: InstanceId, publicDomainNameOpt: Option[String], publicIpAddressOpt: Option[String], privateIpAddress: String, launchInstant: Instant) 12 | case class AppStackStage(app: String, stack: String, stage: String) 13 | case class ExecutionTarget(instances: Option[List[InstanceId]] = None, tagValues: Option[List[String]] = None) 14 | case class RDSInstanceId(id: String) extends AnyVal 15 | case class RDSInstance(id: RDSInstanceId, hostname: String, port: Int) 16 | 17 | case class Arguments( 18 | verbose: Boolean, 19 | executionTarget: Option[ExecutionTarget], 20 | toExecute: Option[String], 21 | profile: Option[String], 22 | region: Region, 23 | mode: Option[SsmMode], 24 | 
targetInstanceUser: Option[String], 25 | singleInstanceSelectionMode: SingleInstanceSelectionMode, 26 | isSelectionModeNewest: Boolean, 27 | isSelectionModeOldest: Boolean, 28 | usePrivateIpAddress: Boolean, 29 | rawOutput: Boolean, 30 | bastionInstance: Option[ExecutionTarget], 31 | bastionPortNumber: Option[Int], 32 | bastionUser: Option[String], 33 | targetInstancePortNumber: Option[Int], 34 | useAgent: Option[Boolean], 35 | hostKeyAlgPreference: List[String], 36 | sourceFile: Option[String], 37 | targetFile: Option[String], 38 | tunnelThroughSystemsManager: Boolean, 39 | useDefaultCredentialsProvider: Boolean, 40 | tunnelTarget: Option[TunnelTargetWithHostName], 41 | rdsTunnelTarget: Option[TunnelTargetWithRDSTags] 42 | ) 43 | 44 | object Arguments { 45 | val targetInstanceDefaultUser = "ubuntu" 46 | val bastionDefaultUser = "ubuntu" 47 | val defaultHostKeyAlgPreference: List[String] = List("ecdsa-sha2-nistp256", "ssh-rsa") 48 | 49 | def empty(): Arguments = Arguments( 50 | verbose = false, 51 | executionTarget = None, 52 | toExecute = None, 53 | profile = None, 54 | region = Region.getRegion(Regions.EU_WEST_1), 55 | mode = None, 56 | targetInstanceUser = Some(targetInstanceDefaultUser), 57 | singleInstanceSelectionMode = SismUnspecified, 58 | isSelectionModeNewest = false, 59 | isSelectionModeOldest = false, 60 | usePrivateIpAddress = false, 61 | rawOutput = true, 62 | bastionInstance = None, 63 | bastionPortNumber = None, 64 | bastionUser = Some(bastionDefaultUser), 65 | targetInstancePortNumber = None, 66 | useAgent = None, 67 | hostKeyAlgPreference = defaultHostKeyAlgPreference, 68 | sourceFile = None, 69 | targetFile = None, 70 | tunnelThroughSystemsManager = true, 71 | useDefaultCredentialsProvider = false, 72 | tunnelTarget = None, 73 | rdsTunnelTarget = None 74 | ) 75 | } 76 | 77 | sealed trait CommandStatus 78 | case object Pending extends CommandStatus 79 | case object InProgress extends CommandStatus 80 | case object Delayed extends CommandStatus 81 
| case object Success extends CommandStatus 82 | case object DeliveryTimedOut extends CommandStatus 83 | case object ExecutionTimedOut extends CommandStatus 84 | case object Failed extends CommandStatus 85 | case object Canceled extends CommandStatus 86 | case object Undeliverable extends CommandStatus 87 | case object Terminated extends CommandStatus 88 | case object InvocationDoesNotExist extends CommandStatus 89 | 90 | sealed trait SsmMode 91 | case object SsmCmd extends SsmMode 92 | case object SsmRepl extends SsmMode 93 | case object SsmSsh extends SsmMode 94 | case object SsmScp extends SsmMode 95 | 96 | case class CommandResult(stdOut: String, stdErr: String, commandFailed: Boolean) 97 | 98 | case class SSMConfig ( 99 | targets: List[Instance], 100 | name: String 101 | ) 102 | 103 | case class AWSClients ( 104 | ssmClient: AWSSimpleSystemsManagementAsync, 105 | stsClient: AWSSecurityTokenServiceAsync, 106 | ec2Client: AmazonEC2Async, 107 | rdsClient: AmazonRDSAsync 108 | ) 109 | 110 | case class ResultsWithInstancesNotFound( 111 | results: List[(InstanceId, scala.Either[CommandStatus, CommandResult])], 112 | instancesNotFound: List[InstanceId] 113 | ) 114 | 115 | sealed trait SingleInstanceSelectionMode 116 | case object SismNewest extends SingleInstanceSelectionMode 117 | case object SismOldest extends SingleInstanceSelectionMode 118 | case object SismUnspecified extends SingleInstanceSelectionMode 119 | 120 | sealed trait TunnelTarget 121 | case class TunnelTargetWithRDSTags(localPort: Int, remoteTags: Seq[String]) extends TunnelTarget 122 | case class TunnelTargetWithHostName(localPort: Int, remoteHostName: String, remotePort: Int, remoteTags: Seq[String] = Seq.empty) extends TunnelTarget 123 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/utils/FilePermissions.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm.utils 2 | 3 | import 
java.io.File 4 | import java.nio.file.Files 5 | import java.nio.file.attribute.{PosixFilePermission, PosixFilePermissions} 6 | 7 | import scala.util.Try 8 | 9 | /** 10 | * Setting the file permissions 11 | */ 12 | object FilePermissions { 13 | 14 | /** 15 | * Using java 7 nio API to set the permissions; fails with a RuntimeException (via sys.error) if they cannot be applied. 16 | * 17 | * @param file to act on 18 | * @param perms in octal format, 3 or 4 digits (e.g. "0600") 19 | */ 20 | def apply(file: File, perms: String): Unit = { 21 | val posix = PosixFilePermissions.fromString(convert(perms)) 22 | val result = Try { 23 | Files.setPosixFilePermissions(file.toPath, posix) 24 | } recoverWith { 25 | // in case of windows (no POSIX attribute view): fall back to java.io.File owner flags derived from the parsed permission set 26 | case _: UnsupportedOperationException => 27 | Try { 28 | file.setExecutable(posix contains PosixFilePermission.OWNER_EXECUTE) 29 | file.setWritable(posix contains PosixFilePermission.OWNER_WRITE) 30 | } 31 | } 32 | 33 | // propagate error 34 | if (result.isFailure) { 35 | val e = result.failed.get 36 | sys.error("Error setting permissions " + perms + " on " + file.getAbsolutePath + ": " + e.getMessage) 37 | } 38 | } 39 | 40 | /** 41 | * Converts an octal unix permission representation into 42 | * a java `PosixFilePermissions` compatible string.
43 | */ 44 | def convert(perms: String): String = { 45 | require(perms.length == 4 || perms.length == 3, s"Permissions must have 3 or 4 digits, got [$perms]"); require(perms.forall(c => c >= '0' && c <= '7'), s"Permissions must only contain octal digits (0-7), got [$perms]") 46 | // ignore the optional leading setuid/setgid/sticky digit 47 | val i = if (perms.length == 3) 0 else 1 48 | val user = Character.getNumericValue(perms.charAt(i)) 49 | val group = Character.getNumericValue(perms.charAt(i + 1)) 50 | val other = Character.getNumericValue(perms.charAt(i + 2)) 51 | 52 | permissionAsString(user) + permissionAsString(group) + permissionAsString(other) 53 | } 54 | // Maps a single octal digit (0-7) to its rwx triplet; any other value is rejected with an IllegalArgumentException. 55 | def permissionAsString(perm: Int): String = perm match { 56 | case 0 => "---" 57 | case 1 => "--x" 58 | case 2 => "-w-" 59 | case 3 => "-wx" 60 | case 4 => "r--" 61 | case 5 => "r-x" 62 | case 6 => "rw-" 63 | case 7 => "rwx" 64 | case other => throw new IllegalArgumentException(s"Invalid octal permission digit: $other") } 65 | 66 | /** Enriches string with `oct` interpolator, parsing string as base 8 integer. */ 67 | implicit class OctalString(val sc: StringContext) extends AnyVal { 68 | def oct(args: Any*): Int = Integer.parseInt(sc.s(args*), 8) 69 | } 70 | 71 | } 72 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/utils/KeyMaker.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm.utils 2 | import java.io._ 3 | 4 | import org.bouncycastle.util.io.pem.PemObject 5 | import org.bouncycastle.util.io.pem.PemWriter 6 | import java.security.Key 7 | 8 | import org.bouncycastle.jce.provider.BouncyCastleProvider 9 | import java.security.KeyPairGenerator 10 | import java.security.Security 11 | 12 | import org.bouncycastle.jcajce.provider.asymmetric.rsa.BCRSAPublicKey 13 | import org.apache.commons.codec.binary.Base64 14 | 15 | object KeyMaker { 16 | // Generates a key pair, writes the private key to `privateKeyFile` as PEM and returns the public key as an OpenSSH authorized_keys line (the public-key encoding assumes RSA). 17 | def makeKey (privateKeyFile: File, algorithm: String, provider: String): String = { 18 | Security.addProvider(new BouncyCastleProvider) 19 | val keyPair = generateKeyPair(algorithm, provider) 20 | val priv = keyPair.getPrivate 21 | val pub = keyPair.getPublic 22
| writePemFile(priv, "RSA PRIVATE KEY", privateKeyFile) 23 | toAuthorizedKey(pub, "security_ssm-scala") 24 | } 25 | 26 | private def generateKeyPair(algorithm: String, provider: String) = { 27 | val generator = KeyPairGenerator.getInstance(algorithm, provider) 28 | generator.initialize(2048) 29 | generator.generateKeyPair 30 | } 31 | // Encodes an RSA public key as an OpenSSH authorized_keys line: "ssh-rsa <base64 of the length-prefixed name, exponent and modulus> <description>". 32 | private def toAuthorizedKey(key: Key, description: String) = { 33 | val rsaPublicKey = key.asInstanceOf[BCRSAPublicKey] 34 | val byteOs: ByteArrayOutputStream = new ByteArrayOutputStream 35 | val dos = new DataOutputStream(byteOs) 36 | dos.writeInt ("ssh-rsa".getBytes.length) 37 | dos.write ("ssh-rsa".getBytes) 38 | dos.writeInt (rsaPublicKey.getPublicExponent.toByteArray.length) 39 | dos.write (rsaPublicKey.getPublicExponent.toByteArray) 40 | dos.writeInt (rsaPublicKey.getModulus.toByteArray.length) 41 | dos.write (rsaPublicKey.getModulus.toByteArray) 42 | val publicKeyEncoded = new String (Base64.encodeBase64 (byteOs.toByteArray) ) 43 | "ssh-rsa " + publicKeyEncoded + " " + description 44 | } 45 | // Writes `key.getEncoded` to `file` as a single PEM block labelled `description`. 46 | private def writePemFile(key: Key, description: String, file: File): Unit = { 47 | val pemObject = new PemObject(description, key.getEncoded) 48 | val pemWriter = new PemWriter(new OutputStreamWriter(new FileOutputStream(file))) 49 | try pemWriter.writeObject(pemObject) 50 | finally pemWriter.close() // close (which also flushes) even if writing fails, so the file handle is never leaked 51 | } 52 | 53 | } 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/utils/attempt/Attempt.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm.utils.attempt 2 | 3 | import java.util.{Timer, TimerTask} 4 | 5 | import scala.concurrent.duration.FiniteDuration 6 | import scala.concurrent.{ExecutionContext, Future, Promise} 7 | import scala.util.control.NonFatal 8 | 9 | 10 | /** 11 | * Represents a value that will need to be calculated using an asynchronous 12 | * computation that may fail.
13 | */ 14 | case class Attempt[A] private (underlying: Future[Either[FailedAttempt, A]]) { 15 | /** 16 | * Change the value within an attempt 17 | */ 18 | def map[B](f: A => B)(implicit ec: ExecutionContext): Attempt[B] = 19 | flatMap(a => Attempt.Right(f(a))) 20 | 21 | /** 22 | * Create an Attempt by combining this with the result of a dependent operation 23 | * that returns a new Attempt. Short-circuits: `f` is not called when this Attempt has failed. 24 | */ 25 | def flatMap[B](f: A => Attempt[B])(implicit ec: ExecutionContext): Attempt[B] = Attempt { 26 | asFuture.flatMap { 27 | case Right(a) => f(a).asFuture 28 | case Left(e) => Future.successful(Left(e)) 29 | } 30 | } 31 | 32 | /** 33 | * Produce a value from an Attempt regardless of whether it failed or succeeded. 34 | * 35 | * Note that Attempts are asynchronous so this must return a Future. 36 | */ 37 | def fold[B](failure: FailedAttempt => B, success: A => B)(implicit ec: ExecutionContext): Future[B] = { 38 | asFuture.map(_.fold(failure, success)) 39 | } 40 | 41 | /** 42 | * Combine this Attempt with another independent attempt. Both wrap Futures that are already running, so they proceed concurrently; this method only combines their results. 43 | */ 44 | def map2[B, C](bAttempt: Attempt[B])(f: (A, B) => C)(implicit ec: ExecutionContext): Attempt[C] = { 45 | Attempt.map2(this, bAttempt)(f) 46 | } 47 | 48 | /** 49 | * If there is an error in the Future itself (e.g. a timeout) we convert it to a 50 | * Left so we have a consistent error representation. Unfortunately, this means 51 | * the error isn't being handled properly so we're left with just the information 52 | * provided by the exception. 53 | * 54 | * Try to avoid hitting this method's failure case by always handling Future errors 55 | * and creating a suitable failure instance for the problem.
56 | */ 57 | def asFuture(implicit ec: ExecutionContext): Future[Either[FailedAttempt, A]] = { 58 | underlying recover { case err => 59 | val apiErrors = FailedAttempt(Failure(err.getMessage, "Unexpected error", UnhandledError, err)) 60 | scala.Left(apiErrors) 61 | } 62 | } 63 | // Yields this attempt's result no earlier than `delay` from now; the underlying Future is already running and is not restarted. 64 | def delay(delay: FiniteDuration)(implicit ec: ExecutionContext): Attempt[A] = { 65 | Attempt.delay(delay).flatMap(_ => this) 66 | } 67 | // Runs `callback` on the outcome. asFuture recovers failures of the underlying Future, so a failed Future here means error handling was bypassed; the original Throwable is kept as the cause. 68 | def onComplete[B](callback: Either[FailedAttempt, A] => B)(implicit ec: ExecutionContext): Unit = { 69 | this.asFuture.onComplete { 70 | case util.Failure(e) => 71 | throw new IllegalStateException("Unexpected error handling was bypassed", e) 72 | case util.Success(either) => 73 | callback(either) 74 | } 75 | } 76 | } 77 | 78 | object Attempt { 79 | def map2[A, B, C](aAttempt: Attempt[A], bAttempt: Attempt[B])(f: (A, B) => C)(implicit ec: ExecutionContext): Attempt[C] = { 80 | for { 81 | a <- aAttempt 82 | b <- bAttempt 83 | } yield f(a, b) 84 | } 85 | 86 | /** 87 | * Changes generated `List[Attempt[A]]` to `Attempt[List[A]]` via provided 88 | * traversal function (like `Future.traverse`). 89 | * 90 | * This implementation returns the first failure in the resulting list, 91 | * or the successful result. 92 | */ 93 | def traverse[A, B](as: List[A])(f: A => Attempt[B])(implicit ec: ExecutionContext): Attempt[List[B]] = { 94 | as.foldRight[Attempt[List[B]]](Right(Nil))(f(_).map2(_)(_ :: _)) 95 | } 96 | 97 | /** 98 | * Using the provided traversal function, sequence the resulting attempts 99 | * into a list that preserves failures. 100 | * 101 | * This is useful if failure is acceptable in part of the application. 102 | */ 103 | def traverseWithFailures[A, B](as: List[A])(f: A => Attempt[B])(implicit ec: ExecutionContext): Attempt[List[Either[FailedAttempt, B]]] = { 104 | sequenceWithFailures(as.map(f)) 105 | } 106 | 107 | /** 108 | * As with `Future.sequence`, changes `List[Attempt[A]]` to `Attempt[List[A]]`.
109 | * 110 | * This implementation returns the first failure in the list, or the successful result. 111 | */ 112 | def sequence[A](responses: List[Attempt[A]])(implicit ec: ExecutionContext): Attempt[List[A]] = { 113 | traverse(responses)(identity) 114 | } 115 | 116 | /** 117 | * Sequence these attempts into a list that preserves failures. 118 | * 119 | * This is useful if failure is acceptable in part of the application. 120 | */ 121 | def sequenceWithFailures[A](attempts: List[Attempt[A]])(implicit ec: ExecutionContext): Attempt[List[Either[FailedAttempt, A]]] = { 122 | Async.Right(Future.traverse(attempts)(_.asFuture)) 123 | } 124 | // Lifts an already-computed Either into an Attempt. 125 | def fromEither[A](e: Either[FailedAttempt, A]): Attempt[A] = 126 | Attempt(Future.successful(e)) 127 | // Lifts an Option into an Attempt, failing with `ifNone` when it is empty. 128 | def fromOption[A](optA: Option[A], ifNone: FailedAttempt): Attempt[A] = 129 | fromEither(optA.toRight(ifNone)) 130 | 131 | /** 132 | * Convert a plain `Future` value to an attempt by providing a recovery handler. Throwables the handler is not defined for leave the Future failed with the original error (instead of a MatchError), so `asFuture` reports them as UnhandledError. 133 | */ 134 | def fromFuture[A](future: Future[A])(recovery: PartialFunction[Throwable, FailedAttempt])(implicit ec: ExecutionContext): Attempt[A] = { 135 | Attempt { 136 | future 137 | .map(scala.Right(_)) 138 | .recover { case t if recovery.isDefinedAt(t) => 139 | scala.Left(recovery(t)) 140 | } 141 | } 142 | } 143 | 144 | /** 145 | * Discard failures from a list of attempts. 146 | * 147 | * **Use with caution**. 148 | */ 149 | def successfulAttempts[A](attempts: List[Attempt[A]])(implicit ec: ExecutionContext): Attempt[List[A]] = { 150 | Attempt.Async.Right { 151 | Future.traverse(attempts)(_.asFuture).map(_.collect { case Right(a) => a }) 152 | } 153 | } 154 | 155 | /** 156 | * Returns a successful attempt after a delay. Can be chained with other Attempts to delay those.
157 | */ 158 | def delay(delay: FiniteDuration)(implicit ctx: ExecutionContext): Attempt[Unit] = { 159 | val timer = new Timer() // dedicated timer per call; the task cancels it once fired so its thread does not outlive the delay 160 | val prom = Promise[Unit]() 161 | val unitTask = new TimerTask { 162 | def run(): Unit = { 163 | timer.cancel(); try ctx.execute(() => prom.complete(util.Success(()))) catch { case NonFatal(e) => prom.tryFailure(e); () } 164 | } 165 | } 166 | timer.schedule(unitTask, delay.toMillis) 167 | Attempt.fromFuture(prom.future) { 168 | case NonFatal(e) => Failure("failed to run delay task", "Internal error while delaying operations", ErrorCode, e).attempt 169 | } 170 | } 171 | 172 | /** 173 | * Retry an attempt until the condition is met. 174 | * There is no retry limit: `attemptA` is re-run, after each delay, for as long as the condition is false. 175 | * Note that this will fail immediately with the failure if a FailedAttempt is returned, 176 | * this function is for testing the successful value. 177 | */ 178 | def retryUntil[A](delayBetweenRetries: FiniteDuration, attemptA: () => Attempt[A])(condition: A => Boolean) 179 | (implicit ec: ExecutionContext): Attempt[A] = { 180 | def loop(a: A, attemptCount: Int): Attempt[A] = { 181 | if (condition(a)) { 182 | Attempt.Right(a) 183 | } else { 184 | for { 185 | _ <- delay(delayBetweenRetries) 186 | nextA <- attemptA() 187 | result <- loop(nextA, attemptCount + 1) 188 | } yield result 189 | } 190 | } 191 | 192 | for { 193 | initialA <- attemptA() 194 | result <- loop(initialA, 1) 195 | } yield result 196 | } 197 | 198 | /** 199 | * Create an Attempt instance from a "good" value. 200 | */ 201 | def Right[A](a: A): Attempt[A] = 202 | Attempt(Future.successful(scala.Right(a))) 203 | 204 | /** 205 | * Create an Attempt failure from a FailedAttempt instance, representing the possibility of multiple failures. 206 | */ 207 | def Left[A](errs: FailedAttempt): Attempt[A] = 208 | Attempt(Future.successful(scala.Left(errs))) 209 | /** 210 | * Syntax sugar to create an Attempt failure if there's only a single error.
211 | */ 212 | def Left[A](err: Failure): Attempt[A] = 213 | Attempt(Future.successful(scala.Left(FailedAttempt(err)))) 214 | 215 | /** 216 | * Asynchronous versions of the Attempt Right/Left helpers for when you have 217 | * a Future that returns a good/bad value directly. 218 | */ 219 | object Async { 220 | /** 221 | * Create an Attempt from a Future of a good value. 222 | */ 223 | def Right[A](fa: Future[A])(implicit ec: ExecutionContext): Attempt[A] = 224 | Attempt(fa.map(scala.Right(_))) 225 | 226 | /** 227 | * Create an Attempt from a known failure in the future. For example, 228 | * if a piece of logic fails but you need to make a Database/API call to 229 | * get the failure information. 230 | */ 231 | def Left[A](ferr: Future[FailedAttempt])(implicit ec: ExecutionContext): Attempt[A] = 232 | Attempt(ferr.map(scala.Left(_))) 233 | } 234 | } 235 | -------------------------------------------------------------------------------- /src/main/scala/com/gu/ssm/utils/attempt/Failure.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm.utils.attempt 2 | 3 | // One or more failures from a computation; exitCode reports the highest exit code among them, or UnhandledError if the list is empty. 4 | case class FailedAttempt(failures: List[Failure]) { 5 | def exitCode: ExitCode = if (failures.isEmpty) UnhandledError else failures.map(_.exitCode).maxBy(_.code) 6 | } 7 | 8 | object FailedAttempt { 9 | def apply(error: Failure): FailedAttempt = { 10 | FailedAttempt(List(error)) 11 | } 12 | def apply(errors: Seq[Failure]): FailedAttempt = { 13 | FailedAttempt(errors.toList) 14 | } 15 | } 16 | 17 | case class Failure( 18 | message: String, 19 | friendlyMessage: String, 20 | exitCode: ExitCode, 21 | context: Option[String] = None, 22 | throwable: Option[Throwable] = None 23 | ) { 24 | def attempt: FailedAttempt = FailedAttempt(this) 25 | } 26 | object Failure { 27 | def apply(message: String, 28 | friendlyMessage: String, 29 | exitCode: ExitCode): Failure = apply(message, friendlyMessage, exitCode, None, None) 30 | def apply(message: String, 31 | friendlyMessage: String, 32 | exitCode: ExitCode, 33
| context: String): Failure = apply(message, friendlyMessage, exitCode, Some(context), None) 34 | def apply(message: String, 35 | friendlyMessage: String, 36 | exitCode: ExitCode, 37 | throwable: Throwable): Failure = apply(message, friendlyMessage, exitCode, None, Some(throwable)) 38 | def apply(message: String, 39 | friendlyMessage: String, 40 | exitCode: ExitCode, 41 | context: String, 42 | throwable: Throwable): Failure = apply(message, friendlyMessage, exitCode, Some(context), Some(throwable)) 43 | } 44 | // Process exit codes carried by each Failure; UnhandledError (255) is the code Attempt.asFuture assigns to unexpected exceptions. 45 | sealed abstract class ExitCode(val code: Int) 46 | case object ErrorCode extends ExitCode(1) 47 | case object ArgumentsError extends ExitCode(2) 48 | case object AwsPermissionsError extends ExitCode(3) 49 | case object AwsError extends ExitCode(4) 50 | case object NoIpAddress extends ExitCode(5) 51 | case object NoHostKey extends ExitCode(6) 52 | case object UnhandledError extends ExitCode(255) -------------------------------------------------------------------------------- /src/test/scala/com/gu/ssm/LogicTest.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm 2 | 3 | import org.scalatest.EitherValues 4 | 5 | import java.time.{Instant, LocalDateTime, ZoneId} 6 | import org.scalatest.freespec.AnyFreeSpec 7 | import org.scalatest.matchers.should.Matchers 8 | 9 | class LogicTest extends AnyFreeSpec with Matchers with EitherValues { 10 | "extractSASTags" - { 11 | import Logic.extractSASTags 12 | 13 | "extracts stack app and stage from valid input" in { 14 | val expected = Right(List("app", "stack", "stage")) 15 | extractSASTags(Seq("app", "stack", "stage")) shouldEqual expected 16 | } 17 | 18 | "provides error if nothing is provided" in { 19 | extractSASTags(Seq("")).isLeft shouldEqual true 20 | } 21 | 22 | "returns error if more than 3 tags are provided" in { 23 | extractSASTags(Seq("a", "b", "c", "d")).isLeft shouldEqual true 24 | } 25 | } 26 | 27 | "extractTunnelConfig" - { 28 | import
Logic.extractTunnelConfig 29 | 30 | val hostname = "example-db.rds.amazonaws.com" 31 | 32 | "extracts tunnel config given ports and hostname" in { 33 | val expected = Right(TunnelTargetWithHostName(5432, hostname, 5432)) 34 | extractTunnelConfig(s"5432:$hostname:5432") shouldBe expected 35 | } 36 | 37 | "returns error if ports are not integers" in { 38 | extractTunnelConfig(s"5432i:$hostname:5432").isLeft shouldBe true 39 | extractTunnelConfig(s"5432:$hostname:5432i").isLeft shouldBe true 40 | } 41 | } 42 | 43 | "extractRDSTunnelConfig" - { 44 | import Logic.extractRDSTunnelConfig 45 | 46 | "extracts tunnel config given ports and tags" in { 47 | val expected = Right(TunnelTargetWithRDSTags(5432, List("APP", "STACK", "STAGE"))) 48 | extractRDSTunnelConfig(s"5432:APP,STACK,STAGE") shouldBe expected 49 | } 50 | 51 | "returns error if no tags are given" in { 52 | extractRDSTunnelConfig(s"5432:,").isLeft shouldBe true 53 | } 54 | } 55 | 56 | "generateScript" - { 57 | import Logic.generateScript 58 | 59 | "returns command if it was provided" in { 60 | generateScript(Left("ls")) shouldEqual "ls" 61 | } 62 | 63 | "returns script contents if it was provided" ignore { 64 | // TODO: testing IO is hard, should extract file's content separately 65 | } 66 | } 67 | 68 | "getSSHInstance" - { 69 | import Logic.getSSHInstance 70 | 71 | def makeInstance(id: String, publicIpOpt: Option[String], privateIp: String, launchDateDayShift: Int): Instance = 72 | Instance(InstanceId(id), None, publicIpOpt, privateIp, LocalDateTime.now().plusDays(launchDateDayShift).atZone(ZoneId.systemDefault()).toInstant()) 73 | 74 | "if given no instances, should be Left" in { 75 | getSSHInstance(List(), SismUnspecified).isLeft shouldBe true 76 | } 77 | 78 | "Given one instance" - { 79 | "If single instance selection mode is SismNewest, returns argument" in { 80 | val i = makeInstance("X", Some("127.0.0.1"), "10.1.1.10", 0) 81 | getSSHInstance(List(i), SismNewest).value shouldEqual i 82 | } 83 | 84 | "If 
single instance selection mode is SismOldest, returns argument" in { 85 | val i = makeInstance("X", Some("127.0.0.1"), "10.1.1.10", 0) 86 | getSSHInstance(List(i), SismOldest).value shouldEqual i 87 | } 88 | 89 | "If single instance selection mode is SismUnspecified, returns argument" in { 90 | val i = makeInstance("X", Some("127.0.0.1"), "10.1.1.10", 0) 91 | getSSHInstance(List(i), SismUnspecified).value shouldEqual i 92 | } 93 | } 94 | 95 | "Given more than one instance" - { 96 | val i1 = makeInstance("X", None, "10.1.1.10", -7) 97 | val i2 = makeInstance("Y", Some("127.0.0.1"), "10.1.1.10", -1) 98 | val i3 = makeInstance("Z", Some("127.0.0.1"), "10.1.1.10", 0) 99 | 100 | "If single instance selection mode is SismNewest, selects the newest instance with public IP" in { 101 | getSSHInstance(List(i1, i2, i3), SismNewest).value shouldEqual i3 102 | } 103 | 104 | "If single instance selection mode is SismOldest, selects the oldest instance with public IP" in { 105 | getSSHInstance(List(i1, i2, i3), SismOldest).value shouldEqual i1 106 | } 107 | 108 | "If single instance selection mode is SismUnspecified, should be Left" in { 109 | getSSHInstance(List(i1, i2, i3), SismUnspecified).isLeft shouldBe true 110 | } 111 | } 112 | 113 | } 114 | 115 | "getIpAddress" - { 116 | import Logic.getAddress 117 | 118 | def makeInstance(id: String, publicDnsOpt: Option[String], publicIpOpt: Option[String], privateIp: String): Instance = 119 | Instance(InstanceId(id), publicDnsOpt, publicIpOpt, privateIp, Instant.now()) 120 | 121 | val instanceWithPrivateIpOnly = makeInstance("id-e32cb1c9d09d", None, None, "10.1.1.10") 122 | val instanceWithPublicIpAndPrivateIp = makeInstance("id-a78414cb9b14", None, Some("34.1.1.10"), "10.1.1.10") 123 | val instanceWithPublicDnsAndPublicIPAndPrivateIp = makeInstance("id-a78414cb9b14", Some("ec2-dnsname"), Some("34.1.1.10"), "10.1.1.10") 124 | 125 | "specifying we want private IP" - { 126 | "return private if only private exists" in { 127 | val result = 
getAddress(instanceWithPrivateIpOnly, onlyUsePrivateIP = true) 128 | result.value shouldEqual "10.1.1.10" 129 | } 130 | 131 | "return private if public and private exists" in { 132 | val result = getAddress(instanceWithPublicIpAndPrivateIp, onlyUsePrivateIP = true) 133 | result.value shouldEqual "10.1.1.10" 134 | } 135 | } 136 | 137 | "not specifying we want private IP" - { 138 | "return public if it exists" in { 139 | val result = getAddress(instanceWithPublicIpAndPrivateIp, onlyUsePrivateIP = false) 140 | result.value shouldEqual "34.1.1.10" 141 | } 142 | 143 | "return private if no public and no dns" in { 144 | val result = getAddress(instanceWithPrivateIpOnly, onlyUsePrivateIP = false) 145 | result.value shouldEqual "10.1.1.10" 146 | } 147 | 148 | "return public IP if it exists, even if public DNS exists" in { 149 | val result = getAddress(instanceWithPublicDnsAndPublicIPAndPrivateIp, onlyUsePrivateIP = false) 150 | result.value shouldEqual "34.1.1.10" 151 | } 152 | } 153 | } 154 | 155 | "getHostKeyEntry" - { 156 | "when the results are sane" - { 157 | val results = 158 | """ 159 | |mfdsafkdlajskl;fjkadls;jfkl;adjs 160 | |fjdlasjfkld;jskl; 161 | |ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCgfV3YLgQ6PKhz3NHwFOhQA1ZgBBxYq9duNF0RdHezuBDQAdz51UKssvsIBi74/DuHk7RjaPPMZaC6yNkAuRMTyJk82S93GGow36iMTQD4HTpDuUFloT+SiTrjez/mkS2Wk+fm4brhjo9Xb8M3TXpOn65AXC/3mrB8JrZwx5Y9d2IwEQT1/r6aM1mUo2JJrSQJ1zv+3+ZFKfij1UncjG7rXsUegmR0lmt8bfAkpef1I+LK3CERgxRNCcuM80ptTws3vgxyP9cS60IiF7W1lwuwtvDvZ9LuDnHlrMi+t1t5EvwRm1CE9eLw9+qTQQijBFVjZlXT03St/6IJLMvBazI7 root@ip-10-248-50-51 162 | |ssh-dss 
AAAAB3NzaC1kc3MAAACBANIaavW/LDw5eBfY2Gimz5avEEQFDEIn/16LZ5a76VFBZdVgSDwZEhxtclfrdOf6JSe7kyvJL/6vFK6nb4dtgCG3Te3Tj0DU/df13SNokRo165OAe1SASpRw7JqOEdX0fMj1GHCmWZ3HhBtv4zZ1qS0IpSe6VdOZ96JtqMQc6xBvAAAAFQDoegry2E7y3iRPWQnsSDO91YLjiwAAAIAUBdldLO++SverqAgcMbNdNNnvqKgmiwfJ1UJ41tDPjw09WeMKdZ0ht2E1GdWMZXPaO/lPffP8nJlFURhW6Tihw4RW8csdJUrD63EWgXbxVTczqC3I0YWlcT7bCVOm9h0/rXOizdPl4ZtseRZ41DwSpKlSTalKAHlOTONl1DdbjAAAAIAfPZ/qIZdVvQUYeUD7fkbScm3zCj3lXbkleg4BFfBZYHtsscqxowRkJXxLHTFSvhtaKYzEAC6J1rlJRuBdr/fTTD9rpLpz+21Gc0H/2+D5ZlWrsyEfeX03pucCpdhBdQjvC5mexZyevBh7y+vD1KeimyZJMGO5MiBn4+QQ/joxMw== root@ip-10-248-50-51 163 | |ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBKDHXJ6sXLoKprcNzMDLF6YVroaf5ycshemnS1TJggIA6cf/FW5EmdzUlf+P0QfBdLsqjBVBxQhyWTtHXD4Byds= root@ip-10-248-50-51 164 | |ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIC1H8FzNqOefx3ApIF1DuY8UqFhzcAAvhAgb8+jkNlKy root@ip-10-248-50-51 165 | |fndsljfkdls;ajkfla 166 | """.stripMargin 167 | 168 | "return the host key using the first algorithm when there is a match" in { 169 | val hostKey = Logic.getHostKeyEntry(Right(CommandResult(results, "", true)), List("ecdsa-sha2-nistp256", "ssh-rsa")) 170 | hostKey.value shouldBe "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBKDHXJ6sXLoKprcNzMDLF6YVroaf5ycshemnS1TJggIA6cf/FW5EmdzUlf+P0QfBdLsqjBVBxQhyWTtHXD4Byds= root@ip-10-248-50-51" 171 | } 172 | 173 | "return the host key using the second algorithm when there is a match for the first" in { 174 | val hostKey = Logic.getHostKeyEntry(Right(CommandResult(results, "", true)), List("ecdsa-idontexist", "ssh-rsa")) 175 | hostKey.value shouldBe "ssh-rsa 
AAAAB3NzaC1yc2EAAAADAQABAAABAQCgfV3YLgQ6PKhz3NHwFOhQA1ZgBBxYq9duNF0RdHezuBDQAdz51UKssvsIBi74/DuHk7RjaPPMZaC6yNkAuRMTyJk82S93GGow36iMTQD4HTpDuUFloT+SiTrjez/mkS2Wk+fm4brhjo9Xb8M3TXpOn65AXC/3mrB8JrZwx5Y9d2IwEQT1/r6aM1mUo2JJrSQJ1zv+3+ZFKfij1UncjG7rXsUegmR0lmt8bfAkpef1I+LK3CERgxRNCcuM80ptTws3vgxyP9cS60IiF7W1lwuwtvDvZ9LuDnHlrMi+t1t5EvwRm1CE9eLw9+qTQQijBFVjZlXT03St/6IJLMvBazI7 root@ip-10-248-50-51" 176 | } 177 | 178 | "error when there are no suitable host keys" in { 179 | val hostKey = Logic.getHostKeyEntry(Right(CommandResult(results, "", true)), List("ssh-bob")) 180 | hostKey.left.value.failures.head.friendlyMessage shouldBe "The remote instance did not return a host key with any preferred algorithm (preferred: List(ssh-bob))" 181 | } 182 | } 183 | 184 | "when the query goes wrong" - { 185 | "error when there are no suitable host keys" in { 186 | val hostKey = Logic.getHostKeyEntry(Left(ExecutionTimedOut), List("ssh-bob")) 187 | hostKey.left.value.failures.head.friendlyMessage shouldBe "The remote instance failed to return the host keys within the timeout window (status: ExecutionTimedOut)" 188 | } 189 | } 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /src/test/scala/com/gu/ssm/MainTest.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm 2 | 3 | import org.scalatest.EitherValues 4 | import org.scalatest.freespec.AnyFreeSpec 5 | import org.scalatest.matchers.should.Matchers 6 | import com.gu.ssm.Logic.computeIncorrectInstances 7 | 8 | 9 | class MainTest extends AnyFreeSpec with Matchers with EitherValues { 10 | 11 | "computeIncorrectInstances" - { 12 | "should return empty list when matching list of Instance Ids" in { 13 | val executionTarget = ExecutionTarget(Some(List(InstanceId("i-096fdd62fd48b5b99")))) 14 | val instanceIds = List(InstanceId("i-096fdd62fd48b5b99")) 15 | computeIncorrectInstances(executionTarget, instanceIds) shouldEqual Nil 16 | } 17 | 
"should return incorrectly submitted Instance Id" in { 18 | val executionTarget = ExecutionTarget(Some(List(InstanceId("i-096fdd62fd48b5b99"),InstanceId("i-12345")))) 19 | val instanceIds = List(InstanceId("i-096fdd62fd48b5b99")) 20 | computeIncorrectInstances(executionTarget, instanceIds) shouldEqual List(InstanceId("i-12345")) 21 | } 22 | } 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/test/scala/com/gu/ssm/SSHTest.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm 2 | 3 | import org.scalatest.EitherValues 4 | 5 | import java.io.File 6 | import java.time.Instant 7 | import com.amazonaws.regions.{Region, Regions} 8 | import org.scalatest.freespec.AnyFreeSpec 9 | import org.scalatest.matchers.should.Matchers 10 | 11 | class SSHTest extends AnyFreeSpec with Matchers with EitherValues { 12 | private val EU_WEST_1 = Region.getRegion(Regions.EU_WEST_1) 13 | 14 | "create add key command" - { 15 | import SSH.addPublicKeyCommand 16 | 17 | "make ssh directory" in { 18 | addPublicKeyCommand("user1", "XXX") should include ("/bin/mkdir -p /home/user1/.ssh;") 19 | } 20 | 21 | "make authorised keys" in { 22 | addPublicKeyCommand("user2", "XXX") should include ("/bin/echo 'XXX' >> /home/user2/.ssh/authorized_keys;") 23 | } 24 | 25 | "ensure authorised key file ownership is correct" in { 26 | addPublicKeyCommand("user3", "XXX") should include ("/bin/chown user3 /home/user3/.ssh/authorized_keys;") 27 | } 28 | 29 | "ensure authorised key file permissions are correct" in { 30 | addPublicKeyCommand("user4", "XXX") should include ("/bin/chmod 0600 /home/user4/.ssh/authorized_keys;") 31 | } 32 | 33 | } 34 | 35 | "create taintedcommand" - { 36 | 37 | "ensure motd command file is present" in { 38 | import SSH.addTaintedCommand 39 | addTaintedCommand("XXX") should include ("test -f /etc/update-motd.d/99-tainted || /bin/echo -e '#!/bin/bash' | /usr/bin/sudo /usr/bin/tee -a 
/etc/update-motd.d/99-tainted >> /dev/null;") 40 | } 41 | "ensure motd command file contains tainted message" in { 42 | import SSH.addTaintedCommand 43 | addTaintedCommand("XXX") should include ("This instance should be considered tainted.") // much text removed from this because of color codes 44 | } 45 | "ensure motd command file contains accessed message" in { 46 | import SSH.addTaintedCommand 47 | addTaintedCommand("XXX") should include ("It was accessed by XXX at") // much text removed from this because of color codes 48 | } 49 | "ensure motd command file has correct permissions" in { 50 | import SSH.addTaintedCommand 51 | addTaintedCommand("XXX") should include ("/bin/chmod 0755 /etc/update-motd.d/99-tainted;") 52 | } 53 | "ensure motd update is executed" in { 54 | import SSH.addTaintedCommand 55 | addTaintedCommand("XXX") should include ("/usr/bin/sudo /bin/run-parts /etc/update-motd.d/ | /usr/bin/sudo /usr/bin/tee /run/motd.dynamic >> /dev/null;") 56 | } 57 | } 58 | 59 | "create ssh command" - { 60 | import SSH.sshCmdStandard 61 | import SSH.sshCmdBastion 62 | import SSH.sshCredentialsLifetimeSeconds 63 | 64 | "create standard ssh command" - { 65 | 66 | val file = new File("/banana") 67 | val instance = Instance(InstanceId("raspberry"), None, Some("34.1.1.10"), "10.1.1.10", Instant.now()) 68 | 69 | "instance id is correct" in { 70 | val (instanceId, _) = sshCmdStandard(false)(file, instance, "user4", "34.1.1.10", None, None, Some(false), None, EU_WEST_1, tunnelThroughSystemsManager = false, tunnelTarget = None) 71 | instanceId.id shouldEqual "raspberry" 72 | } 73 | 74 | "user command" - { 75 | "is correctly formed without port specification" in { 76 | val (_, command) = sshCmdStandard(false)(file, instance, "user4", "34.1.1.10", None, None, Some(false), None, EU_WEST_1, tunnelThroughSystemsManager = false, tunnelTarget = None) 77 | command should contain (Out("""ssh -o "IdentitiesOnly yes" -a -i /banana user4@34.1.1.10;""")) 78 | } 79 | 80 | "is correctly 
formed with port specification" in { 81 | val (_, command) = sshCmdStandard(false)(file, instance, "user4", "34.1.1.10", Some(2345), None, Some(false), None, EU_WEST_1, tunnelThroughSystemsManager = false, tunnelTarget = None) 82 | command should contain (Out("""ssh -o "IdentitiesOnly yes" -a -p 2345 -i /banana user4@34.1.1.10;""")) 83 | } 84 | 85 | "is correctly formed with a hosts file" in { 86 | val (_, command) = sshCmdStandard(false)(file, instance, "user4", "34.1.1.10", Some(2345), Some(new File("/tmp/hostsfile")), Some(false), None, EU_WEST_1, tunnelThroughSystemsManager = false, tunnelTarget = None) 87 | command should contain (Out("""ssh -o "IdentitiesOnly yes" -a -o "UserKnownHostsFile /tmp/hostsfile" -o "StrictHostKeyChecking yes" -p 2345 -i /banana user4@34.1.1.10;""")) 88 | } 89 | 90 | "is correctly formed with agent forwarding file" in { 91 | val (_, command) = sshCmdStandard(false)(file, instance, "user4", "34.1.1.10", Some(2345), None, Some(true), None, EU_WEST_1, tunnelThroughSystemsManager = false, tunnelTarget = None) 92 | command should contain (Out("""ssh -o "IdentitiesOnly yes" -A -p 2345 -i /banana user4@34.1.1.10;""")) 93 | } 94 | } 95 | 96 | "machine command" - { 97 | "is correctly formed without port specification" in { 98 | val (_, command) = sshCmdStandard(true)(file, instance, "user4", "34.1.1.10", None, None, Some(false), None, EU_WEST_1, tunnelThroughSystemsManager = false, tunnelTarget = None) 99 | command.head.text should equal ("""ssh -o "IdentitiesOnly yes" -a -i /banana -t -t user4@34.1.1.10""") 100 | } 101 | 102 | "is correctly formed with port specification" in { 103 | val (_, command) = sshCmdStandard(true)(file, instance, "user4", "34.1.1.10", Some(2345), None, Some(false), None, EU_WEST_1, tunnelThroughSystemsManager = false, tunnelTarget = None) 104 | command.head.text should equal ("""ssh -o "IdentitiesOnly yes" -a -p 2345 -i /banana -t -t user4@34.1.1.10""") 105 | } 106 | } 107 | 108 | "ssm tunnel" - { 109 | "is correctly 
formed" in { 110 | val (_, command) = sshCmdStandard(true)(file, instance, "user4", "34.1.1.10", None, None, Some(false), None, EU_WEST_1, tunnelThroughSystemsManager = true, tunnelTarget = None) 111 | command.head.text should equal ("""ssh -o "IdentitiesOnly yes" -a -o "ProxyCommand sh -c \"aws ssm start-session --target raspberry --document-name AWS-StartSSHSession --parameters 'portNumber=22' --region eu-west-1 \"" -i /banana -t -t user4@34.1.1.10""") 112 | } 113 | } 114 | 115 | "ssh tunnel to remote host" - { 116 | "is correctly formed" in { 117 | val (_, command) = sshCmdStandard(true)(file, instance, "user4", "34.1.1.10", None, None, Some(false), None, EU_WEST_1, tunnelThroughSystemsManager = true, tunnelTarget = Some(TunnelTargetWithHostName(5000, "example-hostname.com", 5432))) 118 | command.head.text should equal ("""ssh -o "IdentitiesOnly yes" -a -o "ProxyCommand sh -c \"aws ssm start-session --target raspberry --document-name AWS-StartSSHSession --parameters 'portNumber=22' --region eu-west-1 \"" -i /banana -t -t user4@34.1.1.10 -L 5000:example-hostname.com:5432 -N -f""") 119 | } 120 | } 121 | } 122 | 123 | "create bastion ssh command" - { 124 | 125 | val file = new File("/banana") 126 | val bastionInstance = Instance(InstanceId("raspberry"), None, Some("34.1.1.10"), "10.1.1.10", Instant.now()) 127 | val targetInstance = Instance(InstanceId("strawberry"), None, Some("34.1.1.11"), "10.1.1.11", Instant.now()) 128 | 129 | "instance id is correct" in { 130 | val (instanceId, _) = sshCmdBastion(false)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", None, "bastionuser", None, Some(false), None) 131 | instanceId.id shouldEqual "strawberry" 132 | } 133 | 134 | "user command" - { 135 | "contains the user instructions" in { 136 | val (_, command) = sshCmdBastion(false)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", None, "bastionuser", None, Some(false), None) 137 | command should contain (Metadata(s"# Dryrun 
mode. The command below will remain valid for $sshCredentialsLifetimeSeconds seconds:")) 138 | } 139 | 140 | "contains the ssh command" in { 141 | val (_, command) = sshCmdBastion(false)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", None, "bastionuser", None, Some(false), None) 142 | command.find(_.isInstanceOf[Out]).head.text should include ("ssh") 143 | } 144 | } 145 | 146 | "machine command" - { 147 | "is well formed without any port specification" - { 148 | "agent-agnostic" in { 149 | val (_, command) = sshCmdBastion(true)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", None, "bastionuser", None, None, None) 150 | command.head.text should equal ("""ssh -o "IdentitiesOnly yes" -i /banana -o 'ProxyCommand ssh -o "IdentitiesOnly yes" -i /banana -p 22 bastionuser@34.1.1.10 nc 10.1.1.11 22' -t -t user5@10.1.1.11""") 151 | } 152 | 153 | "no agent" in { 154 | val (_, command) = sshCmdBastion(true)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", None, "bastionuser", None, Some(false), None) 155 | command.head.text should equal ("""ssh -a -o "IdentitiesOnly yes" -i /banana -o 'ProxyCommand ssh -o "IdentitiesOnly yes" -i /banana -p 22 bastionuser@34.1.1.10 nc 10.1.1.11 22' -t -t user5@10.1.1.11""") 156 | } 157 | 158 | "with agent" in { 159 | val (_, command) = sshCmdBastion(true)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", None, "bastionuser", None, Some(true), None) 160 | command.head.text should equal ("""ssh -A -o "IdentitiesOnly yes" -i /banana -o 'ProxyCommand ssh -o "IdentitiesOnly yes" -i /banana -p 22 bastionuser@34.1.1.10 nc 10.1.1.11 22' -t -t user5@10.1.1.11""") 161 | } 162 | } 163 | 164 | "is well formed with target instance port specification" - { 165 | "agent-agnostic" in { 166 | val (_, command) = sshCmdBastion(true)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", None, "bastionuser", Some(2345), None, None) 167 | 
command.head.text should equal ("""ssh -o "IdentitiesOnly yes" -i /banana -o 'ProxyCommand ssh -o "IdentitiesOnly yes" -i /banana -p 22 bastionuser@34.1.1.10 nc 10.1.1.11 2345' -t -t user5@10.1.1.11""") 168 | } 169 | "no agent" in { 170 | val (_, command) = sshCmdBastion(true)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", None, "bastionuser", Some(2345), Some(false), None) 171 | command.head.text should equal ("""ssh -a -o "IdentitiesOnly yes" -i /banana -o 'ProxyCommand ssh -o "IdentitiesOnly yes" -i /banana -p 22 bastionuser@34.1.1.10 nc 10.1.1.11 2345' -t -t user5@10.1.1.11""") 172 | } 173 | 174 | "with agent" in { 175 | val (_, command) = sshCmdBastion(true)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", None, "bastionuser", Some(2345), Some(true), None) 176 | command.head.text should equal ("""ssh -A -o "IdentitiesOnly yes" -i /banana -o 'ProxyCommand ssh -o "IdentitiesOnly yes" -i /banana -p 22 bastionuser@34.1.1.10 nc 10.1.1.11 2345' -t -t user5@10.1.1.11""") 177 | } 178 | } 179 | 180 | "is well formed with bastion port specification" - { 181 | "agent-agnostic" in { 182 | val (_, command) = sshCmdBastion(true)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", Some(1234), "bastionuser", None, None, None) 183 | command.head.text should equal ("""ssh -o "IdentitiesOnly yes" -i /banana -o 'ProxyCommand ssh -o "IdentitiesOnly yes" -i /banana -p 1234 bastionuser@34.1.1.10 nc 10.1.1.11 22' -t -t user5@10.1.1.11""") 184 | } 185 | "no agent" in { 186 | val (_, command) = sshCmdBastion(true)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", Some(1234), "bastionuser", None, Some(false), None) 187 | command.head.text should equal ("""ssh -a -o "IdentitiesOnly yes" -i /banana -o 'ProxyCommand ssh -o "IdentitiesOnly yes" -i /banana -p 1234 bastionuser@34.1.1.10 nc 10.1.1.11 22' -t -t user5@10.1.1.11""") 188 | } 189 | 190 | "with agent" in { 191 | val (_, command) = 
sshCmdBastion(true)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", Some(1234), "bastionuser", None, Some(true), None) 192 | command.head.text should equal ("""ssh -A -o "IdentitiesOnly yes" -i /banana -o 'ProxyCommand ssh -o "IdentitiesOnly yes" -i /banana -p 1234 bastionuser@34.1.1.10 nc 10.1.1.11 22' -t -t user5@10.1.1.11""") 193 | } 194 | } 195 | 196 | "is well formed with both bastion port and target instance port specifications" - { 197 | "agent-agnostic" in { 198 | val (_, command) = sshCmdBastion(true)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", Some(1234), "bastionuser", Some(2345), None, None) 199 | command.head.text should equal ("""ssh -o "IdentitiesOnly yes" -i /banana -o 'ProxyCommand ssh -o "IdentitiesOnly yes" -i /banana -p 1234 bastionuser@34.1.1.10 nc 10.1.1.11 2345' -t -t user5@10.1.1.11""") 200 | } 201 | 202 | "no agent" in { 203 | val (_, command) = sshCmdBastion(true)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", Some(1234), "bastionuser", Some(2345), Some(false), None) 204 | command.head.text should equal ("""ssh -a -o "IdentitiesOnly yes" -i /banana -o 'ProxyCommand ssh -o "IdentitiesOnly yes" -i /banana -p 1234 bastionuser@34.1.1.10 nc 10.1.1.11 2345' -t -t user5@10.1.1.11""") 205 | } 206 | 207 | "with agent" in { 208 | val (_, command) = sshCmdBastion(true)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", Some(1234), "bastionuser", Some(2345), Some(true), None) 209 | command.head.text should equal ("""ssh -A -o "IdentitiesOnly yes" -i /banana -o 'ProxyCommand ssh -o "IdentitiesOnly yes" -i /banana -p 1234 bastionuser@34.1.1.10 nc 10.1.1.11 2345' -t -t user5@10.1.1.11""") 210 | } 211 | } 212 | 213 | "is well formed with a host key file" in { 214 | val (_, command) = sshCmdBastion(true)(file, bastionInstance, targetInstance, "user5", "34.1.1.10", "10.1.1.11", Some(1234), "bastionuser", Some(2345), Some(false), Some(new 
File("/tmp/hostfile"))) 215 | command.head.text should equal ("""ssh -a -o "IdentitiesOnly yes" -i /banana -o "UserKnownHostsFile /tmp/hostfile" -o "StrictHostKeyChecking yes" -o 'ProxyCommand ssh -o "IdentitiesOnly yes" -i /banana -o "UserKnownHostsFile /tmp/hostfile" -o "StrictHostKeyChecking yes" -p 1234 bastionuser@34.1.1.10 nc 10.1.1.11 2345' -t -t user5@10.1.1.11""") 216 | } 217 | } 218 | } 219 | } 220 | 221 | "create scp command" - { 222 | import SSH.scpCmdStandard 223 | 224 | "create standard scp command" - { 225 | 226 | val file = new File("/banana") 227 | val instance = Instance(InstanceId("raspberry"), None, Some("34.1.1.10"), "10.1.1.10", Instant.now()) 228 | 229 | "instance id is correct" in { 230 | val (instanceId, _) = scpCmdStandard(false)(file, instance, "user4", "34.1.1.10", None, Some(false), None, "/path/to/sourceFile", ":/path/to/targetFile", None, EU_WEST_1, tunnelThroughSystemsManager = false) 231 | instanceId.id shouldEqual "raspberry" 232 | } 233 | 234 | "user command" - { 235 | 236 | "process correctly remote server specifications" - { 237 | 238 | "target file is remote" in { 239 | val (_, command) = scpCmdStandard(false)(file, instance, "user4", "34.1.1.10", None, Some(false), None, "/path/to/sourceFile", ":/path/to/targetFile", None, EU_WEST_1, tunnelThroughSystemsManager = false) 240 | command should contain (Out("""scp -o "IdentitiesOnly yes" -a -i /banana /path/to/sourceFile user4@34.1.1.10:/path/to/targetFile;""")) 241 | } 242 | "source file is remote" in { 243 | val (_, command) = scpCmdStandard(false)(file, instance, "user4", "34.1.1.10", None, Some(false), None, ":/path/to/sourceFile", "/path/to/targetFile", None, EU_WEST_1, tunnelThroughSystemsManager = false) 244 | command should contain (Out("""scp -o "IdentitiesOnly yes" -a -i /banana user4@34.1.1.10:/path/to/sourceFile /path/to/targetFile;""")) 245 | } 246 | "incorrect specifications in" in { 247 | val (_, command) = scpCmdStandard(false)(file, instance, "user4", "34.1.1.10", 
None, Some(false), None, ":/path/to/sourceFile", ":/path/to/targetFile", None, EU_WEST_1, tunnelThroughSystemsManager = false) 248 | command.head.text should include ("Incorrect remote server specifications") 249 | } 250 | } 251 | 252 | "is correctly formed without port specification" in { 253 | val (_, command) = scpCmdStandard(false)(file, instance, "user4", "34.1.1.10", None, Some(false), None, "/path/to/sourceFile", ":/path/to/targetFile", None, EU_WEST_1, tunnelThroughSystemsManager = false) 254 | command should contain (Out("""scp -o "IdentitiesOnly yes" -a -i /banana /path/to/sourceFile user4@34.1.1.10:/path/to/targetFile;""")) 255 | } 256 | 257 | "is correctly formed with port specification" in { 258 | val (_, command) = scpCmdStandard(false)(file, instance, "user4", "34.1.1.10", Some(2345), Some(false), None, "/path/to/sourceFile", ":/path/to/targetFile", None, EU_WEST_1, tunnelThroughSystemsManager = false) 259 | command should contain (Out("""scp -o "IdentitiesOnly yes" -a -p 2345 -i /banana /path/to/sourceFile user4@34.1.1.10:/path/to/targetFile;""")) 260 | } 261 | 262 | "is correctly formed with a hosts file" in { 263 | val (_, command) = scpCmdStandard(false)(file, instance, "user4", "34.1.1.10", Some(2345), Some(false), Some(new File("/tmp/hostsfile")), "/path/to/sourceFile", ":/path/to/targetFile", None, EU_WEST_1, tunnelThroughSystemsManager = false) 264 | command should contain (Out("""scp -o "IdentitiesOnly yes" -a -o "UserKnownHostsFile /tmp/hostsfile" -o "StrictHostKeyChecking yes" -p 2345 -i /banana /path/to/sourceFile user4@34.1.1.10:/path/to/targetFile;""")) 265 | } 266 | 267 | "is correctly formed with agent forwarding file" in { 268 | val (_, command) = scpCmdStandard(false)(file, instance, "user4", "34.1.1.10", Some(2345), Some(true), None, "/path/to/sourceFile", ":/path/to/targetFile", None, EU_WEST_1, tunnelThroughSystemsManager = false) 269 | command should contain (Out("""scp -o "IdentitiesOnly yes" -A -p 2345 -i /banana 
/path/to/sourceFile user4@34.1.1.10:/path/to/targetFile;""")) 270 | } 271 | } 272 | 273 | "machine command" - { 274 | "process correctly remote server specifications" - { 275 | 276 | "target file is remote" in { 277 | val (_, command) = scpCmdStandard(true)(file, instance, "user4", "34.1.1.10", None, Some(false), None, "/path/to/sourceFile", ":/path/to/targetFile", None, EU_WEST_1, tunnelThroughSystemsManager = false) 278 | command.head.text should equal ("""scp -o "IdentitiesOnly yes" -a -i /banana /path/to/sourceFile user4@34.1.1.10:/path/to/targetFile""") 279 | } 280 | "source file is remote" in { 281 | val (_, command) = scpCmdStandard(true)(file, instance, "user4", "34.1.1.10", None, Some(false), None, ":/path/to/sourceFile", "/path/to/targetFile", None, EU_WEST_1, tunnelThroughSystemsManager = false) 282 | command.head.text should equal ("""scp -o "IdentitiesOnly yes" -a -i /banana user4@34.1.1.10:/path/to/sourceFile /path/to/targetFile""") 283 | } 284 | "incorrect specifications in" in { 285 | val (_, command) = scpCmdStandard(true)(file, instance, "user4", "34.1.1.10", None, Some(false), None, ":/path/to/sourceFile", ":/path/to/targetFile", None, EU_WEST_1, tunnelThroughSystemsManager = false) 286 | command.head.text should include ("Incorrect remote server specifications") 287 | } 288 | } 289 | 290 | "is correctly formed without port specification" in { 291 | val (_, command) = scpCmdStandard(true)(file, instance, "user4", "34.1.1.10", None, Some(false), None, ":/path/to/sourceFile", "/path/to/targetFile", None, EU_WEST_1, tunnelThroughSystemsManager = false) 292 | command.head.text should equal ("""scp -o "IdentitiesOnly yes" -a -i /banana user4@34.1.1.10:/path/to/sourceFile /path/to/targetFile""") 293 | } 294 | 295 | "is correctly formed with port specification" in { 296 | val (_, command) = scpCmdStandard(true)(file, instance, "user4", "34.1.1.10", Some(2345), Some(false), None, "/path/to/sourceFile", ":/path/to/targetFile", None, EU_WEST_1, 
tunnelThroughSystemsManager = false) 297 | command.head.text should equal ("""scp -o "IdentitiesOnly yes" -a -p 2345 -i /banana /path/to/sourceFile user4@34.1.1.10:/path/to/targetFile""") 298 | } 299 | } 300 | 301 | "ssm tunnel" - { 302 | "is correctly formed" in { 303 | val (_, command) = scpCmdStandard(true)(file, instance, "user4", "34.1.1.10", None, Some(false), None, "/path/to/sourceFile", ":/path/to/targetFile", None, EU_WEST_1, tunnelThroughSystemsManager = true) 304 | command.head.text should equal ("""scp -o "IdentitiesOnly yes" -a -o "ProxyCommand sh -c \"aws ssm start-session --target raspberry --document-name AWS-StartSSHSession --parameters 'portNumber=22' --region eu-west-1 \"" -i /banana /path/to/sourceFile user4@34.1.1.10:/path/to/targetFile""") 305 | } 306 | } 307 | } 308 | } 309 | } 310 | -------------------------------------------------------------------------------- /src/test/scala/com/gu/ssm/UITest.scala: -------------------------------------------------------------------------------- 1 | package com.gu.ssm 2 | 3 | import org.scalatest.freespec.AnyFreeSpec 4 | import org.scalatest.matchers.should.Matchers 5 | 6 | class UITest extends AnyFreeSpec with Matchers { 7 | "hasAnyCommandFailed" - { 8 | "returns false if no commands failed" in { 9 | val command = CommandResult("", "", commandFailed = false) 10 | UI.hasAnyCommandFailed(List(InstanceId("test") -> Right(command))) shouldBe false 11 | } 12 | 13 | "returns true if a single command failed" in { 14 | val command = CommandResult("", "", commandFailed = true) 15 | UI.hasAnyCommandFailed(List(InstanceId("test") -> Right(command))) shouldBe true 16 | } 17 | 18 | "returns true if at least one command failed" in { 19 | val commands = List( 20 | InstanceId("test1") -> Right(CommandResult("", "", commandFailed = true)), 21 | InstanceId("test2") -> Right(CommandResult("", "", commandFailed = false)) 22 | ) 23 | 24 | UI.hasAnyCommandFailed(commands) shouldBe true 25 | } 26 | } 27 | } 28 | 
--------------------------------------------------------------------------------
/src/test/scala/com/gu/ssm/utils/attempt/AttemptTest.scala:
--------------------------------------------------------------------------------
1 | package com.gu.ssm.utils.attempt
2 |
3 | import java.util.concurrent.TimeoutException
4 | import org.scalatest.EitherValues
5 | import Attempt.{Left, Right}
6 | import org.scalatest.freespec.AnyFreeSpec
7 | import org.scalatest.matchers.should.Matchers
8 |
9 | import scala.concurrent.Await
10 | import scala.concurrent.duration._
11 | import scala.concurrent.ExecutionContext.Implicits.global
12 |
13 |
14 | /**
15 |  * Tests for the [[Attempt]] combinators: traverse, successfulAttempts, delay and retryUntil.
16 |  */
17 | class AttemptTest extends AnyFreeSpec with Matchers with EitherValues with AttemptValues {
18 |   "traverse" - {
19 |     "returns the first failure" in {
20 |       def failOnFourAndSix(i: Int): Attempt[Int] = {
21 |         i match {
22 |           case 4 => expectedFailure("fails on four")
23 |           case 6 => expectedFailure("fails on six")
24 |           case n => Right(n)
25 |         }
26 |       }
27 |       val errors = Attempt.traverse(List(1, 2, 3, 4, 5, 6))(failOnFourAndSix).leftValue()
28 |       checkError(errors, "fails on four")
29 |     }
30 |
31 |     "returns the successful result if there were no failures" in {
32 |       Attempt.traverse(List(1, 2, 3, 4))(Right).value() shouldEqual List(1, 2, 3, 4)
33 |     }
34 |   }
35 |
36 |   "successfulAttempts" - {
37 |     "returns the list if all were successful" in {
38 |       val attempts = List(Right(1), Right(2))
39 |
40 |       Attempt.successfulAttempts(attempts).value() shouldEqual List(1, 2)
41 |     }
42 |
43 |     "returns only the successful attempts if there were failures" in {
44 |       val attempts: List[Attempt[Int]] = List(Right(1), Right(2), expectedFailure("failed"), Right(4))
45 |
46 |       Attempt.successfulAttempts(attempts).value() shouldEqual List(1, 2, 4)
47 |     }
48 |   }
49 |
50 |   "delay" - {
51 |     // NOTE(review): ignored without a recorded reason — confirm whether it can be re-enabled.
52 |     "will cause timeout in shorter time" ignore {
53 |       val future = Attempt.delay(10.millis).asFuture
54 |       intercept[TimeoutException] {
55 |         Await.result(future, 5.millis)
56 |       }
57 |     }
58 |
59 |     "will not cause timeout with longer delay" in {
60 |       // Await with a bound far above the 5ms delay: a bound only a few milliseconds longer
61 |       // races against scheduler and thread-pool start-up jitter on a loaded machine.
62 |       val future = Attempt.delay(5.millis).asFuture
63 |       noException should be thrownBy Await.result(future, 1.second)
64 |     }
65 |   }
66 |
67 |   "retry" - {
68 |     "returns success if the attempt returns successfully" in {
69 |       Attempt.retryUntil(Duration.Zero, () => Attempt.Right(true))(_ == true).isSuccessfulAttempt() shouldBe true
70 |     }
71 |
72 |     "returns the first value that satisfies the predicate after retrying" in {
73 |       var counter = 0
74 |       def incr(): Attempt[Int] = {
75 |         counter += 1
76 |         Attempt.Right(counter)
77 |       }
78 |
79 |       // Values 1, 2 and 3 are rejected by the predicate, so the fourth call's value is returned.
80 |       Attempt.retryUntil(Duration.Zero, () => incr())(_ > 3).value() shouldEqual 4
81 |     }
82 |
83 |     "returns failure if the attempt fails" in {
84 |       val failure = Failure("test failure", "Test failure", ErrorCode).attempt
85 |       Attempt.retryUntil(Duration.Zero, () => Attempt.Left[Boolean](failure))(_ => true).leftValue() shouldEqual failure
86 |     }
87 |
88 |     "delays between retrying" - {
89 |       "and thus will time out in this test" in {
90 |         // Right(false) never satisfies `_ == true`, so retrying (10ms apart) is still in
91 |         // progress when the 25ms await gives up.
92 |         val future = Attempt.retryUntil(10.millis, () => Attempt.Right(false))(_ == true).asFuture
93 |         intercept[TimeoutException] {
94 |           Await.result(future, 25.millis)
95 |         }
96 |       }
97 |     }
98 |   }
99 |
100 |   // Utilities for checking the failure state of attempts
101 |
102 |   /** Asserts that the first failure in `errors` has message `expected`; fails cleanly (no exception) if there are none. */
103 |   def checkError(errors: FailedAttempt, expected: String): Unit = {
104 |     errors.failures.headOption.map(_.message) shouldEqual Some(expected)
105 |   }
106 |
107 |   /** Builds a failed [[Attempt]] from a [[Failure]] carrying the given message. */
108 |   def expectedFailure[A](message: String): Attempt[A] = Left[A](Failure(message, "this will fail", ErrorCode))
109 | }
110 |
--------------------------------------------------------------------------------
/src/test/scala/com/gu/ssm/utils/attempt/AttemptValues.scala:
--------------------------------------------------------------------------------
1 | package com.gu.ssm.utils.attempt
2 |
3 | import java.io.{PrintWriter, StringWriter}
4 | import org.scalatest.exceptions.TestFailedException
5 | import org.scalatest.matchers.should.Matchers
6 |
7 | import scala.concurrent.duration._
8 | import scala.concurrent.{Await, ExecutionContext}
9 |
10 |
11 | /**
12 |  * Test helpers that synchronously unwrap an [[Attempt]] into its value or its [[FailedAttempt]].
13 |  *
14 |  * Every helper blocks for at most `attemptTimeout`; if the attempt has not completed by then,
15 |  * the TimeoutException thrown by `Await.result` propagates to the test.
16 |  */
17 | trait AttemptValues extends Matchers {
18 |   /** How long the helpers below wait for an attempt to complete; override in suites with slower attempts. */
19 |   protected def attemptTimeout: FiniteDuration = 5.seconds
20 |
21 |   implicit class RichAttempt[A](attempt: Attempt[A]) {
22 |     /** Renders the failure's throwable, if it has one, as a printed stack trace; "" otherwise. */
23 |     private def stackTrace(failure: Failure): String = {
24 |       failure.throwable.map { t =>
25 |         // Writing to a StringWriter avoids decoding bytes with the platform default charset.
26 |         val sw = new StringWriter()
27 |         val pw = new PrintWriter(sw)
28 |         try t.printStackTrace(pw) finally pw.close()
29 |         sw.toString
30 |       }.getOrElse("")
31 |     }
32 |
33 |     /**
34 |      * Awaits the attempt and returns its successful value.
35 |      *
36 |      * @throws TestFailedException if the attempt failed; the failure messages and stack traces are attached as a clue
37 |      */
38 |     def value()(implicit ec: ExecutionContext): A = {
39 |       val result = Await.result(attempt.asFuture, attemptTimeout)
40 |       withClue {
41 |         result.fold(
42 |           fa => s"${fa.failures.map(_.message).mkString(", ")} - ${fa.failures.map(stackTrace).mkString("\n\n")}",
43 |           _ => ""
44 |         )
45 |       } {
46 |         result.fold[A](
47 |           // 10 is the failedCodeStackDepth ScalaTest uses to locate the failing line of test code.
48 |           _ => throw new TestFailedException("Could not extract value from failed Attempt", 10),
49 |           identity
50 |         )
51 |       }
52 |     }
53 |
54 |     /**
55 |      * Awaits the attempt and returns its [[FailedAttempt]].
56 |      *
57 |      * @throws TestFailedException if the attempt succeeded; the successful value is attached as a clue
58 |      */
59 |     def leftValue()(implicit ec: ExecutionContext): FailedAttempt = {
60 |       val result = Await.result(attempt.asFuture, attemptTimeout)
61 |       withClue {
62 |         result.fold(
63 |           _ => "",
64 |           a => s"$a"
65 |         )
66 |       } {
67 |         result.fold[FailedAttempt](
68 |           identity,
69 |           _ => throw new TestFailedException("Cannot extract failure from successful Attempt", 10)
70 |         )
71 |       }
72 |     }
73 |
74 |     /** Awaits the attempt and reports whether it succeeded. */
75 |     def isSuccessfulAttempt()(implicit ec: ExecutionContext): Boolean = {
76 |       Await.result(attempt.asFuture, attemptTimeout).fold(
77 |         _ => false,
78 |         _ => true
79 |       )
80 |     }
81 |
82 |     /** Awaits the attempt and reports whether it failed. */
83 |     def isFailedAttempt()(implicit ec: ExecutionContext): Boolean = {
84 |       !isSuccessfulAttempt()
85 |     }
86 |   }
87 | }
88 |
--------------------------------------------------------------------------------