├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── enhancement_or_feature_request.md ├── dependabot.yml ├── pull_request_template.md └── workflows │ ├── ci.yaml │ └── clippy-linting.yml ├── .gitignore ├── .rustfmt.toml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Cargo.toml ├── LICENSE ├── MAINTAINERS.md ├── Makefile ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── build.rs ├── config └── azure-init.service ├── demo ├── customdata_template.yml └── image_creation.sh ├── doc ├── configuration.md └── libazurekvp.md ├── docs └── E2E_TESTING.md ├── libazureinit ├── Cargo.toml ├── README.md ├── build.rs └── src │ ├── config.rs │ ├── error.rs │ ├── goalstate.rs │ ├── http.rs │ ├── imds.rs │ ├── lib.rs │ ├── media.rs │ ├── provision │ ├── hostname.rs │ ├── mod.rs │ ├── password.rs │ ├── ssh.rs │ └── user.rs │ ├── status.rs │ └── unittest.rs ├── src ├── kvp.rs ├── logging.rs └── main.rs └── tests ├── cli.rs ├── functional_tests.rs └── functional_tests.sh /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a bug report 4 | title: '' 5 | labels: "bug" 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Description 11 | 12 | 13 | 14 | ## Impact 15 | 16 | 17 | 18 | ## Environment and steps to reproduce 19 | 20 | 21 | 1. **Set-up**: 22 | 23 | 24 | 2. **Action(s)**: 25 | 26 | a. 27 | 28 | 29 | b. 30 | 31 | 32 | 3. 
**Error**: 33 | 34 | ## Expected behavior 35 | 36 | 37 | 38 | ## Additional information 39 | 40 | 41 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/enhancement_or_feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Enhancement or Feature request 3 | about: Suggest a new feature or enhancement 4 | title: '[RFE]' 5 | labels: "feature" 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Current situation 11 | 12 | 13 | 14 | ## Impact 15 | 16 | 17 | 18 | ## Ideal future situation 19 | 20 | 21 | 22 | ## Implementation options 23 | 24 | 25 | 26 | ## Additional information 27 | 28 | 29 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | - package-ecosystem: "cargo" 8 | directory: "/" 9 | schedule: 10 | interval: "weekly" 11 | ignore: 12 | - dependency-name: "idna_adapter" 13 | # stay with <= 1.2.0 for Rust 1.78 14 | - dependency-name: "litemap" 15 | # stay with <= 0.7.4 for Rust 1.76 16 | - dependency-name: "zerofrom" 17 | # stay with <= 0.1.5 for Rust 1.76 18 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## How to use 4 | 5 | 6 | 7 | ## Testing done 8 | 9 | 10 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: "Run CI" 2 | permissions: 3 | contents: read 4 | pull-requests: write 5 | on: 6 | push: 7 | branches: 8 | - main 9 | pull_request: 10 | branches: 11 | - main 12 | 13 | jobs:
14 | build-test: 15 | name: Build and test azure-init 16 | runs-on: ubuntu-latest 17 | strategy: 18 | matrix: 19 | rust-toolchain: [stable, stable 12 months ago] 20 | steps: 21 | - name: Install libudev 22 | run: | 23 | sudo apt update 24 | sudo apt install -y libudev-dev 25 | - uses: actions/checkout@v4 26 | - uses: dtolnay/rust-toolchain@master 27 | with: 28 | toolchain: ${{matrix.rust-toolchain}} 29 | components: clippy, rustfmt 30 | - name: Rustfmt Check 31 | run: cargo fmt --all --check 32 | - name: Build azure-init 33 | run: cargo build --verbose 34 | - name: Run unit tests 35 | run: cargo test --verbose --all-features --workspace 36 | - name: Run clippy 37 | run: cargo clippy --verbose -- --deny warnings 38 | -------------------------------------------------------------------------------- /.github/workflows/clippy-linting.yml: -------------------------------------------------------------------------------- 1 | name: "Run Clippy for Linting" 2 | permissions: 3 | contents: read 4 | pull-requests: write 5 | on: 6 | pull_request: 7 | branches: 8 | - main 9 | schedule: 10 | - cron: '0 0 * * 0' 11 | 12 | jobs: 13 | clippy: 14 | name: Run clippy on azure-init 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | rust-toolchain: [stable, stable 12 months ago] 19 | steps: 20 | - name: Install libudev 21 | run: | 22 | sudo apt update 23 | sudo apt install -y libudev-dev 24 | - uses: actions/checkout@v4 25 | - uses: dtolnay/rust-toolchain@master 26 | with: 27 | toolchain: ${{matrix.rust-toolchain}} 28 | components: clippy 29 | - name: Run clippy 30 | run: cargo clippy --all-targets --all-features --verbose -- --deny warnings 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an 
executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | 16 | # .tgz Files are used for creating the image and are auto generated by the script 17 | azure-init.tgz 18 | -------------------------------------------------------------------------------- /.rustfmt.toml: -------------------------------------------------------------------------------- 1 | edition="2021" 2 | fn_params_layout="Tall" 3 | force_explicit_abi=true 4 | max_width=80 5 | tab_spaces=4 6 | use_field_init_shorthand=true 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribute to Azure-Init 2 | 3 | This project welcomes contributions and suggestions. Most contributions require you to 4 | agree to a Contributor License Agreement (CLA) declaring that you have the right to, 5 | and actually do, grant us the rights to use your contribution. For details, visit 6 | https://cla.microsoft.com. 
7 | 8 | When you submit a pull request, a CLA-bot will automatically determine whether you need 9 | to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the 10 | instructions provided by the bot. You will only need to do this once across all repositories using our CLA. 11 | 12 | For each pull request, CI automatically runs unit tests with `cargo test`, and also checks coding style and lints with `cargo fmt` and `cargo clippy`. Please make sure that all three steps (`test`, `fmt`, and `clippy`) pass with your changes, so that CI does not fail on such issues. 13 | 14 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 15 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 16 | or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 17 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "azure-init" 3 | version = "0.1.1" 4 | edition = "2021" 5 | rust-version = "1.74" 6 | repository = "https://github.com/Azure/azure-init/" 7 | homepage = "https://github.com/Azure/azure-init/" 8 | license = "MIT" 9 | readme = "README.md" 10 | description = "A reference implementation for provisioning Linux VMs on Azure."
11 | build = "build.rs" 12 | 13 | [dependencies] 14 | exitcode = "1.1.2" 15 | anyhow = "1.0.81" 16 | tokio = { version = "1", features = ["full"] } 17 | tracing = "0.1.40" 18 | clap = { version = "4.5.21", features = ["derive", "cargo", "env"] } 19 | sysinfo = "0.35" 20 | tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } 21 | opentelemetry = "0.30" 22 | opentelemetry_sdk = "0.30" 23 | tracing-opentelemetry = "0.31" 24 | uuid = { version = "1.2", features = ["v4"] } 25 | chrono = "0.4" 26 | 27 | [dev-dependencies] 28 | assert_cmd = "2.0.16" 29 | predicates = "3.1.2" 30 | predicates-core = "1.0.8" 31 | predicates-tree = "1.0.11" 32 | tempfile = "3.3.0" 33 | 34 | # Pin idna_adapter to <=1.2.0 for MSRV issues with cargo-clippy of Rust 1.78.0. 35 | idna_adapter = "<=1.2.0" 36 | 37 | [dependencies.libazureinit] 38 | path = "libazureinit" 39 | version = "0.1.0" 40 | 41 | [profile.dev] 42 | incremental = true 43 | 44 | [[bin]] 45 | name = "azure-init" 46 | path = "src/main.rs" 47 | 48 | [[bin]] 49 | name = "functional_tests" 50 | path = "tests/functional_tests.rs" 51 | 52 | [workspace] 53 | members = [ 54 | "libazureinit", 55 | ] 56 | 57 | [features] 58 | passwd = [] 59 | hostnamectl = [] 60 | useradd = [] 61 | 62 | systemd_linux = ["passwd", "hostnamectl", "useradd"] 63 | 64 | default = ["systemd_linux"] 65 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /MAINTAINERS.md: -------------------------------------------------------------------------------- 1 | # Maintainers 2 | 3 | * Nell Shamrell-Harrington - [nellshamrell](https://github.com/nellshamrell) 4 | * Cade Jacobson - [@cadejacobson](https://github.com/cadejacobson) 5 | * Dongsu Park - [@dongsupark](https://github.com/dongsupark) 6 | * Thilo Fromm - [@t-lo](https://github.com/t-lo) 7 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | BUILD_MODE ?= debug 2 | BUILD_FLAG := $(if $(filter release,$(BUILD_MODE)),--release,) 3 | 4 | build-all: 5 | @echo "" 6 | @echo "**********************************" 7 | @echo "* Building the source code" 8 | @echo 
"**********************************" 9 | @echo "" 10 | @cargo build --all $(BUILD_FLAG) 11 | 12 | tests: build-all 13 | @echo "" 14 | @echo "**********************************" 15 | @echo "* Unit testing" 16 | @echo "**********************************" 17 | @echo "" 18 | @cargo test --all --verbose 19 | 20 | e2e-test: build-all 21 | @./tests/functional_tests.sh 22 | 23 | fmt: 24 | @echo "" 25 | @echo "**********************************" 26 | @echo "* Formatting" 27 | @echo "**********************************" 28 | @echo "" 29 | @cargo fmt --all --check 30 | 31 | clippy: 32 | @echo "" 33 | @echo "**********************************" 34 | @echo "* Linting with clippy" 35 | @echo "**********************************" 36 | @echo "" 37 | @cargo clippy --verbose -- --deny warnings 38 | 39 | 40 | install: build-all 41 | @echo "" 42 | @echo "**********************************" 43 | @echo "* Installing binaries" 44 | @echo "**********************************" 45 | @echo "" 46 | @/bin/install -d $(DESTDIR)/usr/bin 47 | @/bin/install -m 0755 target/$(BUILD_MODE)/azure-init $(DESTDIR)/usr/bin/ 48 | 49 | @echo "" 50 | @echo "**********************************" 51 | @echo "* Installing systemd service file" 52 | @echo "**********************************" 53 | @echo "" 54 | @/bin/install -d $(DESTDIR)/usr/lib/systemd/system 55 | @/bin/install -m 0644 config/azure-init.service $(DESTDIR)/usr/lib/systemd/system/azure-init.service 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Azure-Init 2 | 3 | [![Github CI](https://github.com/Azure/azure-init/actions/workflows/ci.yaml/badge.svg)](https://github.com/Azure/azure-init/actions) 4 | 5 | A reference implementation for provisioning Linux VMs on Azure. 6 | 7 | Azure-init configures Linux guests from provisioning metadata. 8 | Contrary to complex guest configuration and customisation systems like e.g. 
cloud-init, azure-init aims to be minimal. 9 | It strictly focuses on basic instance initialisation from Azure metadata. 10 | 11 | Azure-init has very few requirements on its environment, so it may run in a very early stage of the boot process. 12 | 13 | ## Installing Rust 14 | 15 | To install Rust, see https://www.rust-lang.org/tools/install. 16 | 17 | ## Building the Project 18 | 19 | Building this project can be done by going to the base of the repository in the command line and entering the command 20 | `cargo build --all`. This project contains two binaries, the main provisioning agent and the functional testing binary, 21 | so this command builds both. These binaries are quite small, but you can build only one by entering 22 | `cargo build --bin <binary_name>`, where `<binary_name>` is either `azure-init` or `functional_tests`. 23 | 24 | To run a binary, enter `cargo run --bin <binary_name>` with the corresponding binary name. 25 | 26 | ## Testing 27 | 28 | Azure-init includes two types of tests: unit tests and end-to-end (e2e) tests. 29 | 30 | ### Running Unit Tests 31 | 32 | From the root directory of the repository, run: 33 | 34 | ``` 35 | cargo test --verbose --all-features --workspace 36 | ``` 37 | 38 | This will run the unit tests for every library in the repository, not just for azure-init. 39 | Doing so ensures your testing will match what is run in the CI pipeline. 40 | 41 | ### Running End-to-End (e2e) Tests 42 | Please refer to [E2E_TESTING.md](docs/E2E_TESTING.md) for end-to-end testing. 43 | 44 | ## Contributing 45 | 46 | Contributions require you to agree to Microsoft's Contributor License Agreement (CLA). 47 | Please refer to [CONTRIBUTING.md](CONTRIBUTING.md) for detailed instructions. 48 | 49 | This project adheres to the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 50 | Check out [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) for a brief collection of links and references.
51 | 52 | ## Trademarks 53 | 54 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 55 | trademarks or logos is subject to and must follow 56 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 57 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 58 | Any use of third-party trademarks or logos are subject to those third-party's policies. 59 | 60 | ## libazureinit 61 | 62 | For common library used by this reference implementation, please refer to [libazureinit](https://github.com/Azure/azure-init/tree/main/libazureinit/). -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 
14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). 16 | If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 17 | 18 | You should receive a response within 24 hours. 19 | If for some reason you do not, please follow up via email to ensure we received your original message. 20 | Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 21 | 22 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 23 | 24 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 25 | * Full paths of source file(s) related to the manifestation of the issue 26 | * The location of the affected source code (tag/branch/commit or direct URL) 27 | * Any special configuration required to reproduce the issue 28 | * Step-by-step instructions to reproduce the issue 29 | * Proof-of-concept or exploit code (if possible) 30 | * Impact of the issue, including how an attacker might exploit the issue 31 | 32 | This information will help us triage your report more quickly. 33 | 34 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. 35 | Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 36 | 37 | ## Preferred Languages 38 | 39 | We prefer all communications to be in English. 40 | 41 | ## Policy 42 | 43 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 
44 | 45 | 46 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | This is a community project. 4 | There are no commercial support options available at this time. 5 | 6 | ## How to file issues and get help 7 | 8 | This project uses GitHub Issues to track bugs and feature requests. 9 | Please [search the existing issues](https://github.com/Azure/azure-init/issues) before filing new issues to avoid duplicates. 10 | For new issues, file your bug or feature request as a new issue. 11 | 12 | 13 | For more general help and questions about using this project please [search our discussions](https://github.com/Azure/azure-init/discussions). 14 | Feel free to [start a new discussion](https://github.com/Azure/azure-init/discussions/new/choose) if your topic has not been brought up. 15 | 16 | ## Microsoft Support Policy 17 | 18 | Support for azure-init is limited to the resources listed above.
19 | -------------------------------------------------------------------------------- /build.rs: -------------------------------------------------------------------------------- 1 | use std::process::Command; 2 | 3 | fn main() { 4 | // Get the short commit hash 5 | let output = Command::new("git") 6 | .args(["rev-parse", "--short", "HEAD"]) 7 | .output() 8 | .expect("Failed to execute git command"); 9 | 10 | let git_hash = 11 | String::from_utf8(output.stdout).expect("Invalid UTF-8 sequence"); 12 | 13 | // Set the environment variable 14 | println!("cargo:rustc-env=GIT_COMMIT_HASH={}", git_hash.trim()); 15 | } 16 | -------------------------------------------------------------------------------- /config/azure-init.service: -------------------------------------------------------------------------------- 1 | [Unit] 2 | Description=Azure-Init 3 | After=hypervkvpd.service hv-kvp-daemon.service 4 | Wants=hypervkvpd.service hv-kvp-daemon.service 5 | After=network-online.target 6 | Wants=network-online.target 7 | 8 | [Service] 9 | Type=oneshot 10 | ExecStart=/usr/bin/azure-init 11 | StandardOutput=journal+console 12 | StandardError=journal+console 13 | 14 | [Install] 15 | WantedBy=multi-user.target 16 | -------------------------------------------------------------------------------- /demo/customdata_template.yml: -------------------------------------------------------------------------------- 1 | #cloud-config 2 | 3 | runcmd: 4 | - nohup bash /var/log/azure/image/setup.sh 2>&1 | 5 | tee -a /dev/console /var/log/azure/image/setup.log & 6 | 7 | write_files: 8 | - path: /var/log/azure/image/setup.sh 9 | content: | 10 | set -eux -o pipefail 11 | echo "SIGTOOL_START" 12 | 13 | # Wait until system is ready. 
14 | state="" 15 | while [[ $state != "running" && $state != "degraded" ]]; do 16 | state=$(systemctl is-system-running || true) 17 | sleep 1 18 | done 19 | wget -O /run/azure-init.tgz __SASURL__ 20 | tar -xf /run/azure-init.tgz -C / 21 | systemctl enable /lib/systemd/system/azure-init.service 22 | mkdir --parents /etc/netplan 23 | cat > /etc/netplan/eth0.yaml < Result<(), anyhow::Error> { 52 | event!(Level::INFO, msg = "Starting the provision process..."); 53 | // Other logic... 54 | } 55 | ``` 56 | 57 | 1. **Initialization**: 58 | The `initialize_tracing` function is called at the start of the program to set up the tracing subscriber with the configured layers (`EmitKVPLayer`, `OpenTelemetryLayer`, and `stderr_layer`). 59 | 60 | 2. **Instrumenting the `provision()` Function**: 61 | The `#[instrument]` attribute is used to automatically create a span for the `provision()` function. 62 | - The `name = "root"` part of the `#[instrument]` attribute specifies the name of the span. 63 | - This span will trace the entire execution of the `provision()` function, capturing any relevant metadata (e.g., function parameters, return values). 64 | 65 | 3. **Span Processing**: 66 | As the `provision()` function is called and spans are created, entered, exited, and closed, they are processed by the layers configured in `initialize_tracing`: 67 | - **EmitKVPLayer** processes the span, generates key-value pairs, encodes them, and writes them directly to `/var/lib/hyperv/.kvp_pool_1`. 68 | - **OpenTelemetryLayer** handles context propagation and exports span data to a tracing backend or stdout. 69 | - **stderr_layer** logs span information to stderr or another specified output for immediate visibility. 
70 | -------------------------------------------------------------------------------- /docs/E2E_TESTING.md: -------------------------------------------------------------------------------- 1 | # Azure-init End-to-end Testing 2 | 3 | End-to-end tests validate the integration of the entire system. These tests require additional setup, such as setting a subscription ID. 4 | 5 | ## Quickstart 6 | 7 | To run e2e tests, use the following command from the repository root: 8 | 9 | ``` 10 | make e2e-test 11 | ``` 12 | 13 | This command will: 14 | 15 | 1. Create a test user and associated SSH directory. 16 | 2. Place mock SSH keys for testing. 17 | 3. Run the tests and then clean up any test artifacts generated during the process. 18 | 19 | ## Details 20 | 21 | End-to-end testing of azure-init consists of two steps: preparing a SIG (Shared Image Gallery) image, and running the actual tests. 22 | 23 | ### Preparation of Azure SIG image 24 | 25 | To create an Azure SIG image to be used for end-to-end testing, run `image_creation.sh`. 26 | This creates a resource group, a storage account, and a virtual machine, then generates and publishes a SIG image. 27 | 28 | ``` 29 | demo/image_creation.sh 30 | ``` 31 | 32 | If you want to run the script with custom values for the resource group, VM location, VM size, base image URN, etc., specify the corresponding environment variables. For example: 33 | 34 | ``` 35 | RG="mytest-azinit" LOCATION="westeurope" VM_SIZE="Standard_D2ds_v5" BASE_IMAGE="Canonical:0001-com-ubuntu-server-jammy:22_04-lts-gen2:latest" demo/image_creation.sh 36 | ``` 37 | 38 | One current limitation is that `BASE_IMAGE` should be a Debian derivative such as Ubuntu. When the build host OS differs from the target host OS, the `functional_tests` binary might not run because of mismatched package versions (e.g. glibc).
39 | 40 | ### Running end-to-end testing 41 | 42 | To run end-to-end testing, use `make e2e-test`, which creates a test user and its SSH directory, places mock SSH keys, and 43 | then cleans up the test artifacts afterwards. 44 | 45 | `VM_IMAGE` should be specified to pick the correct SIG image created in the previous step. 46 | 47 | ``` 48 | VM_IMAGE="$(az sig image-definition list --resource-group testgalleryazinitrg --gallery-name testgalleryazinit | jq -r .[].id)" make e2e-test 49 | ``` 50 | 51 | It is also possible to pass custom environment variables. For example: 52 | 53 | ``` 54 | RG="mytest-azinit" LOCATION="westeurope" VM_SIZE="Standard_D2ds_v5" VM_IMAGE="$(az sig image-definition list --resource-group testgalleryazinitrg --gallery-name testgalleryazinit | jq -r .[].id)" make e2e-test 55 | ``` 56 | 57 | When testing is done, it is recommended to clean up the resource group used for the SIG images. 58 | 59 | ``` 60 | az group delete --resource-group testgalleryazinitrg 61 | ``` 62 | -------------------------------------------------------------------------------- /libazureinit/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "libazureinit" 3 | version = "0.1.1" 4 | edition = "2021" 5 | rust-version = "1.74" 6 | build = "build.rs" 7 | repository = "https://github.com/Azure/azure-init/" 8 | homepage = "https://github.com/Azure/azure-init/" 9 | license = "MIT" 10 | description = "A common library for provisioning Linux VMs on Azure."
11 | 12 | [dependencies] 13 | reqwest = { version = "0.12.0", default-features = false, features = ["blocking", "json"] } 14 | serde = {version = "1.0.163", features = ["derive"]} 15 | thiserror = "2.0.3" 16 | tokio = { version = "1", features = ["full"] } 17 | serde-xml-rs = "0.8.0" 18 | serde_json = "1.0.96" 19 | nix = {version = "0.30.1", features = ["fs", "user"]} 20 | block-utils = "0.11.1" 21 | tracing = "0.1.40" 22 | fstab = "0.4.0" 23 | toml = "0.8" 24 | regex = "1" 25 | lazy_static = "1.4" 26 | figment = { version = "0.10", features = ["toml"] } 27 | # Pinned to 0.1.5 since 0.1.6 bumps its MSRV to 1.81. 28 | # The major difference in 0.1.6 is the switch to core::error 29 | # This should be unpinned on or around 2025-09-05 if we continue with our ~1 year MSRV policy 30 | zerofrom = "=0.1.5" 31 | # Pinned to 0.7.4 since 0.7.5 bumps its MSRV to 1.81. 32 | # The major difference in 0.7.5 is the switch to core::error; there's also a few API additions. 33 | # This should be unpinned on or around 2025-09-05 if we continue with our ~1 year MSRV policy 34 | litemap = "=0.7.4" 35 | uuid = "1.3" 36 | 37 | [dev-dependencies] 38 | tracing-test = { version = "0.2", features = ["no-env-filter"] } 39 | tempfile = "3" 40 | tokio = { version = "1", features = ["full"] } 41 | tokio-util = "0.7.11" 42 | whoami = "1" 43 | anyhow = "1.0.81" 44 | 45 | [lib] 46 | name = "libazureinit" 47 | path = "src/lib.rs" 48 | 49 | -------------------------------------------------------------------------------- /libazureinit/README.md: -------------------------------------------------------------------------------- 1 | # libazureinit 2 | 3 | A common library for provisioning Linux VMs on Azure. 
4 | 5 | Features: 6 | 7 | * retrieve provisioning metadata from Azure Instance Metadata Service 8 | * configure the VM according to the provisioning metadata 9 | * report provisioning complete to Azure platform 10 | * basic features for instance initialisation 11 | 12 | [azure-init](https://github.com/Azure/azure-init) is a reference implementation that leverages the APIs provided by libazureinit. 13 | 14 | The goal is to provide APIs for other components to perform VM provisioning on Azure platform. 15 | 16 | For other instructions, like installing Rust, building the project, testing, etc. please refer to [azure-init's README](https://github.com/Azure/azure-init/blob/main/README.md). 17 | -------------------------------------------------------------------------------- /libazureinit/build.rs: -------------------------------------------------------------------------------- 1 | fn main() { 2 | println!("cargo:rerun-if-changed=build.rs"); 3 | 4 | // Pass in build-time environment variables, which could be used in 5 | // crates by `env!` macros. 6 | println!("cargo:rustc-env=PATH_HOSTNAMECTL=hostnamectl"); 7 | println!("cargo:rustc-env=PATH_USERADD=useradd"); 8 | println!("cargo:rustc-env=PATH_PASSWD=passwd"); 9 | } 10 | -------------------------------------------------------------------------------- /libazureinit/src/error.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | /// Set of error codes that can be used by libazureinit. 
5 | /// 6 | /// # Example 7 | /// 8 | /// ```rust 9 | /// # use libazureinit::error::Error; 10 | /// # use std::process::Command; 11 | /// 12 | /// fn run_ls() -> Result<(), Error> { 13 | /// let ls_status = Command::new("ls").arg("/tmp").status().unwrap(); 14 | /// if !ls_status.success() { 15 | /// Err(Error::SubprocessFailed { 16 | /// command: "ls".to_string(), 17 | /// status: ls_status, 18 | /// }) 19 | /// } else { 20 | /// Ok(()) 21 | /// } 22 | /// } 23 | /// 24 | /// ``` 25 | #[derive(thiserror::Error, Debug)] 26 | pub enum Error { 27 | #[error("Unable to deserialize or serialize JSON data: {0}")] 28 | Json(#[from] serde_json::Error), 29 | #[error("Unable to deserialize or serialize XML data: {0}")] 30 | Xml(#[from] serde_xml_rs::Error), 31 | #[error("HTTP client error occurred: {0}")] 32 | Http(#[from] reqwest::Error), 33 | #[error("An I/O error occurred: {0}")] 34 | Io(#[from] std::io::Error), 35 | #[error("HTTP request did not succeed (HTTP {status} from {endpoint})")] 36 | HttpStatus { 37 | endpoint: String, 38 | status: reqwest::StatusCode, 39 | }, 40 | #[error("executing {command} failed: {status}")] 41 | SubprocessFailed { 42 | command: String, 43 | status: std::process::ExitStatus, 44 | }, 45 | #[error("failed to construct a C-style string")] 46 | NulError(#[from] std::ffi::NulError), 47 | #[error("nix call failed: {0}")] 48 | Nix(#[from] nix::Error), 49 | #[error("The user {user} does not exist")] 50 | UserMissing { user: String }, 51 | #[error("failed to get username from IMDS or local OVF files")] 52 | UsernameFailure, 53 | #[error("failed to get instance metadata from IMDS")] 54 | InstanceMetadataFailure, 55 | #[error("Provisioning a user with a non-empty password is not supported")] 56 | NonEmptyPassword, 57 | #[error("Unable to get list of block devices: {0}")] 58 | BlockUtils(#[from] block_utils::BlockUtilsError), 59 | #[error( 60 | "Failed to set the hostname; none of the provided backends succeeded" 61 | )] 62 | NoHostnameProvisioner, 63 
| #[error( 64 | "Failed to create a user; none of the provided backends succeeded" 65 | )] 66 | NoUserProvisioner, 67 | #[error( 68 | "Failed to set the user password; none of the provided backends succeeded" 69 | )] 70 | NoPasswordProvisioner, 71 | #[error("A timeout error occurred")] 72 | Timeout, 73 | #[error("Failed to update the sshd configuration")] 74 | UpdateSshdConfig, 75 | } 76 | 77 | impl From for Error { 78 | fn from(_: tokio::time::error::Elapsed) -> Self { 79 | Self::Timeout 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /libazureinit/src/goalstate.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | use reqwest::header::HeaderMap; 5 | use reqwest::header::HeaderValue; 6 | use reqwest::Client; 7 | use tracing::instrument; 8 | 9 | use std::time::Duration; 10 | 11 | use serde::Deserialize; 12 | use serde_xml_rs::from_str; 13 | 14 | use crate::error::Error; 15 | use crate::http; 16 | 17 | /// Azure goalstate of the virtual machine. Metadata is written in XML format. 18 | /// 19 | /// Required fields are Container, Version, Incarnation. 
20 | /// 21 | /// # Example 22 | /// 23 | /// ``` 24 | /// # use libazureinit::goalstate::Goalstate; 25 | /// 26 | /// static GOALSTATE_STR: &str = " 27 | /// 28 | /// 2 29 | /// 30 | /// 31 | /// test_user_instance_id 32 | /// 33 | /// 34 | /// 35 | /// example_version 36 | /// test_goal_incarnation 37 | /// "; 38 | /// 39 | /// let goalstate: Goalstate = serde_xml_rs::from_str(GOALSTATE_STR) 40 | /// .expect("Failed to parse the goalstate XML."); 41 | /// ``` 42 | #[derive(Debug, Deserialize, PartialEq)] 43 | pub struct Goalstate { 44 | #[serde(rename = "Container")] 45 | container: Container, 46 | #[serde(rename = "Version")] 47 | version: String, 48 | #[serde(rename = "Incarnation")] 49 | incarnation: String, 50 | } 51 | 52 | /// Container of [`Goalstate`] of the virtual machine. Metadata is written in XML format. 53 | #[derive(Debug, Deserialize, PartialEq)] 54 | pub struct Container { 55 | #[serde(rename = "ContainerId")] 56 | container_id: String, 57 | #[serde(rename = "RoleInstanceList")] 58 | role_instance_list: RoleInstanceList, 59 | } 60 | 61 | /// List of role instances of goalstate. Metadata is written in XML format. 62 | #[derive(Debug, Deserialize, PartialEq)] 63 | pub struct RoleInstanceList { 64 | #[serde(rename = "RoleInstance")] 65 | role_instance: RoleInstance, 66 | } 67 | 68 | /// Role instance of goalstate. Metadata is written in XML format. 69 | #[derive(Debug, Deserialize, PartialEq)] 70 | pub struct RoleInstance { 71 | #[serde(rename = "InstanceId")] 72 | instance_id: String, 73 | } 74 | 75 | const DEFAULT_GOALSTATE_URL: &str = 76 | "http://168.63.129.16/machine/?comp=goalstate"; 77 | 78 | /// Fetch Azure goalstate of Azure wireserver. 79 | /// 80 | /// Caller needs to pass 3 required parameters, client, retry_interval, 81 | /// total_timeout. It is therefore required to create a reqwest::Client 82 | /// variable with possible options, to pass it as parameter. 83 | /// 84 | /// Parameter url is optional. 
If None is passed, it defaults to 85 | /// DEFAULT_GOALSTATE_URL, an internal goalstate URL available in the Azure VM. 86 | /// 87 | /// # Example 88 | /// 89 | /// ``` 90 | /// # use std::time::Duration; 91 | /// use libazureinit::reqwest::Client; 92 | /// 93 | /// let client = Client::builder() 94 | /// .timeout(std::time::Duration::from_secs(5)) 95 | /// .build() 96 | /// .unwrap(); 97 | /// 98 | /// let res = libazureinit::goalstate::get_goalstate( 99 | /// &client, Duration::from_secs(1), Duration::from_secs(5), 100 | /// Some("http://127.0.0.1:8000/"), 101 | /// ); 102 | /// ``` 103 | #[instrument(err, skip_all)] 104 | pub async fn get_goalstate( 105 | client: &Client, 106 | retry_interval: Duration, 107 | mut total_timeout: Duration, 108 | url: Option<&str>, 109 | ) -> Result { 110 | let mut headers = HeaderMap::new(); 111 | headers.insert("x-ms-agent-name", HeaderValue::from_static("azure-init")); 112 | headers.insert("x-ms-version", HeaderValue::from_static("2012-11-30")); 113 | let url = url.unwrap_or(DEFAULT_GOALSTATE_URL); 114 | let request_timeout = 115 | Duration::from_secs(http::WIRESERVER_HTTP_TIMEOUT_SEC); 116 | 117 | while !total_timeout.is_zero() { 118 | let (response, remaining_timeout) = http::get( 119 | client, 120 | headers.clone(), 121 | request_timeout, 122 | retry_interval, 123 | total_timeout, 124 | url, 125 | ) 126 | .await?; 127 | match response.text().await { 128 | Ok(body) => { 129 | let goalstate = from_str(&body).map_err(|error| { 130 | tracing::warn!( 131 | ?error, 132 | "The response body was invalid and could not be deserialized" 133 | ); 134 | error.into() 135 | }); 136 | if goalstate.is_ok() { 137 | return goalstate; 138 | } 139 | } 140 | Err(error) => { 141 | tracing::warn!(?error, "Failed to read the full response body") 142 | } 143 | } 144 | 145 | total_timeout = remaining_timeout; 146 | } 147 | 148 | Err(Error::Timeout) 149 | } 150 | 151 | const DEFAULT_HEALTH_URL: &str = "http://168.63.129.16/machine/?comp=health"; 152 | 
153 | /// Report health stateus to Azure wireserver. 154 | /// 155 | /// Caller needs to pass 4 required parameters, client, retry_interval, 156 | /// total_timeout, goalstate. It is therefore required to create a reqwest::Client 157 | /// variable with possible options, to pass it as parameter. Also caller must 158 | /// first run get_goalstate to get GoalState variable to pass it as parameter of 159 | /// report_health. 160 | /// 161 | /// Parameter url optional. If None is passed, it defaults to DEFAULT_HEALTH_URL, 162 | /// an internal health report URL available in the Azure VM. 163 | /// 164 | /// # Example 165 | /// 166 | /// ```rust,no_run 167 | /// # use std::time::Duration; 168 | /// use libazureinit::reqwest::Client; 169 | /// 170 | /// #[tokio::main] 171 | /// async fn main() { 172 | /// let client = Client::builder() 173 | /// .timeout(std::time::Duration::from_secs(5)) 174 | /// .build() 175 | /// .unwrap(); 176 | /// 177 | /// let vm_goalstate = libazureinit::goalstate::get_goalstate( 178 | /// &client, Duration::from_secs(1), Duration::from_secs(5), 179 | /// Some("http://127.0.0.1:8000/"), 180 | /// ).await.unwrap(); 181 | /// 182 | /// let res = libazureinit::goalstate::report_health( 183 | /// &client, vm_goalstate, Duration::from_secs(1), Duration::from_secs(5), 184 | /// Some("http://127.0.0.1:8000/"), 185 | /// ); 186 | /// } 187 | /// ``` 188 | #[instrument(err, skip_all)] 189 | pub async fn report_health( 190 | client: &Client, 191 | goalstate: Goalstate, 192 | retry_interval: Duration, 193 | total_timeout: Duration, 194 | url: Option<&str>, 195 | ) -> Result<(), Error> { 196 | let mut headers = HeaderMap::new(); 197 | headers.insert("x-ms-agent-name", HeaderValue::from_static("azure-init")); 198 | headers.insert("x-ms-version", HeaderValue::from_static("2012-11-30")); 199 | headers.insert( 200 | "Content-Type", 201 | HeaderValue::from_static("text/xml;charset=utf-8"), 202 | ); 203 | let request_timeout = 204 | 
Duration::from_secs(http::WIRESERVER_HTTP_TIMEOUT_SEC); 205 | let url = url.unwrap_or(DEFAULT_HEALTH_URL); 206 | 207 | let post_request = build_report_health_file(goalstate); 208 | 209 | _ = http::post( 210 | client, 211 | headers, 212 | post_request, 213 | request_timeout, 214 | retry_interval, 215 | total_timeout, 216 | url, 217 | ) 218 | .await?; 219 | 220 | Ok(()) 221 | } 222 | 223 | fn build_report_health_file(goalstate: Goalstate) -> String { 224 | let post_request = 225 | "\n\ 226 | \n\ 227 | $GOAL_STATE_INCARNATION\n\ 228 | \n\ 229 | $CONTAINER_ID\n\ 230 | \n\ 231 | \n\ 232 | $INSTANCE_ID\n\ 233 | \n\ 234 | Ready\n\ 235 | \n\ 236 | \n\ 237 | \n\ 238 | \n\ 239 | "; 240 | 241 | let post_request = 242 | post_request.replace("$GOAL_STATE_INCARNATION", &goalstate.incarnation); 243 | let post_request = post_request 244 | .replace("$CONTAINER_ID", &goalstate.container.container_id); 245 | post_request.replace( 246 | "$INSTANCE_ID", 247 | &goalstate 248 | .container 249 | .role_instance_list 250 | .role_instance 251 | .instance_id, 252 | ) 253 | } 254 | 255 | #[cfg(test)] 256 | mod tests { 257 | use super::{ 258 | build_report_health_file, get_goalstate, report_health, Goalstate, 259 | }; 260 | 261 | use reqwest::{header, Client, StatusCode}; 262 | use std::time::Duration; 263 | use tokio::net::TcpListener; 264 | 265 | use crate::{http, unittest}; 266 | 267 | static GOALSTATE_STR: &str = " 268 | 269 | 2 270 | 271 | 272 | test_user_instance_id 273 | 274 | 275 | 276 | example_version 277 | test_goal_incarnation 278 | "; 279 | 280 | static HEALTH_STR: &str = "\n\ 281 | \n\ 282 | test_goal_incarnation\n\ 283 | \n\ 284 | 2\n\ 285 | \n\ 286 | \n\ 287 | test_user_instance_id\n\ 288 | \n\ 289 | Ready\n\ 290 | \n\ 291 | \n\ 292 | \n\ 293 | \n\ 294 | "; 295 | #[test] 296 | fn test_parsing_goalstate() { 297 | let goalstate: Goalstate = serde_xml_rs::from_str(GOALSTATE_STR) 298 | .expect("Failed to parse the goalstate XML."); 299 | assert_eq!(goalstate.container.container_id, 
"2".to_owned()); 300 | assert_eq!( 301 | goalstate 302 | .container 303 | .role_instance_list 304 | .role_instance 305 | .instance_id, 306 | "test_user_instance_id".to_owned() 307 | ); 308 | assert_eq!(goalstate.version, "example_version".to_owned()); 309 | assert_eq!(goalstate.incarnation, "test_goal_incarnation".to_owned()); 310 | } 311 | 312 | #[tokio::test] 313 | async fn test_build_report_health_file() { 314 | let goalstate: Goalstate = serde_xml_rs::from_str(GOALSTATE_STR) 315 | .expect("Failed to parse the goalstate XML."); 316 | 317 | let actual_output = build_report_health_file(goalstate); 318 | assert_eq!(actual_output, HEALTH_STR); 319 | } 320 | 321 | // Runs a test around sending via get_goalstate() with a given statuscode. 322 | async fn run_goalstate_retry(statuscode: &StatusCode) -> bool { 323 | const HTTP_TOTAL_TIMEOUT_SEC: u64 = 5; 324 | const HTTP_PERCLIENT_TIMEOUT_SEC: u64 = 5; 325 | const HTTP_RETRY_INTERVAL_SEC: u64 = 1; 326 | 327 | let mut default_headers = header::HeaderMap::new(); 328 | let user_agent = 329 | header::HeaderValue::from_str("azure-init test").unwrap(); 330 | 331 | // Run local test servers for goalstate and health that reply with simple test data. 
332 | let gs_ok_payload = 333 | unittest::get_http_response_payload(statuscode, GOALSTATE_STR); 334 | let gs_serverlistener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 335 | let gs_addr = gs_serverlistener.local_addr().unwrap(); 336 | 337 | let health_ok_payload = 338 | unittest::get_http_response_payload(statuscode, HEALTH_STR); 339 | let health_serverlistener = 340 | TcpListener::bind("127.0.0.1:0").await.unwrap(); 341 | let health_addr = health_serverlistener.local_addr().unwrap(); 342 | 343 | let cancel_token = tokio_util::sync::CancellationToken::new(); 344 | 345 | let gs_server = tokio::spawn(unittest::serve_requests( 346 | gs_serverlistener, 347 | gs_ok_payload, 348 | cancel_token.clone(), 349 | )); 350 | let health_server = tokio::spawn(unittest::serve_requests( 351 | health_serverlistener, 352 | health_ok_payload, 353 | cancel_token.clone(), 354 | )); 355 | 356 | default_headers.insert(header::USER_AGENT, user_agent); 357 | let client = Client::builder() 358 | .timeout(std::time::Duration::from_secs(HTTP_PERCLIENT_TIMEOUT_SEC)) 359 | .default_headers(default_headers) 360 | .build() 361 | .unwrap(); 362 | 363 | let vm_goalstate = get_goalstate( 364 | &client, 365 | Duration::from_secs(HTTP_RETRY_INTERVAL_SEC), 366 | Duration::from_secs(HTTP_TOTAL_TIMEOUT_SEC), 367 | Some( 368 | format!("http://{:}:{:}/", gs_addr.ip(), gs_addr.port()) 369 | .as_str(), 370 | ), 371 | ) 372 | .await; 373 | 374 | if !vm_goalstate.is_ok() { 375 | cancel_token.cancel(); 376 | 377 | let gs_requests = gs_server.await.unwrap(); 378 | let health_requests = health_server.await.unwrap(); 379 | 380 | if http::HARDFAIL_CODES.contains(statuscode) { 381 | assert_eq!(gs_requests, 1); 382 | assert_eq!(health_requests, 0); 383 | } 384 | 385 | if http::RETRY_CODES.contains(statuscode) { 386 | assert!(gs_requests >= 4); 387 | assert_eq!(health_requests, 0); 388 | } 389 | 390 | return false; 391 | } 392 | 393 | let res_health = report_health( 394 | &client, 395 | vm_goalstate.unwrap(), 
396 | Duration::from_secs(HTTP_RETRY_INTERVAL_SEC), 397 | Duration::from_secs(HTTP_TOTAL_TIMEOUT_SEC), 398 | Some( 399 | format!( 400 | "http://{:}:{:}/", 401 | health_addr.ip(), 402 | health_addr.port() 403 | ) 404 | .as_str(), 405 | ), 406 | ) 407 | .await; 408 | 409 | res_health.is_ok() 410 | } 411 | 412 | #[tokio::test] 413 | async fn goalstate_query_retry() { 414 | // status codes that should succeed. 415 | assert!(run_goalstate_retry(&StatusCode::OK).await); 416 | 417 | // status codes that should be retried up to 5 minutes. 418 | for rc in http::RETRY_CODES { 419 | assert!(!run_goalstate_retry(rc).await); 420 | } 421 | 422 | // status codes that should result into immediate failures. 423 | for rc in http::HARDFAIL_CODES { 424 | assert!(!run_goalstate_retry(rc).await); 425 | } 426 | } 427 | 428 | // Assert malformed responses are retried. 429 | // 430 | // In this case the server doesn't return XML at all. 431 | #[tokio::test] 432 | #[tracing_test::traced_test] 433 | async fn malformed_response() { 434 | let body = "You thought this was XML, but you were wrong"; 435 | let payload = format!( 436 | "HTTP/1.1 {} {}\r\nContent-Type: application/xml\r\nContent-Length: {}\r\n\r\n{}", 437 | StatusCode::OK.as_u16(), 438 | StatusCode::OK.to_string(), 439 | body.len(), 440 | body 441 | ); 442 | 443 | let serverlistener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 444 | let addr = serverlistener.local_addr().unwrap(); 445 | let cancel_token = tokio_util::sync::CancellationToken::new(); 446 | let server = tokio::spawn(unittest::serve_requests( 447 | serverlistener, 448 | payload, 449 | cancel_token.clone(), 450 | )); 451 | 452 | let client = Client::builder() 453 | .timeout(std::time::Duration::from_secs(5)) 454 | .build() 455 | .unwrap(); 456 | 457 | let res = get_goalstate( 458 | &client, 459 | Duration::from_millis(10), 460 | Duration::from_millis(50), 461 | Some(format!("http://{:}:{:}/", addr.ip(), addr.port()).as_str()), 462 | ) 463 | .await; 464 | 465 | 
cancel_token.cancel(); 466 | 467 | let requests = server.await.unwrap(); 468 | assert!(requests >= 2); 469 | assert!(logs_contain( 470 | "The response body was invalid and could not be deserialized" 471 | )); 472 | match res { 473 | Err(crate::error::Error::Timeout) => {} 474 | _ => panic!("Response should have timed out"), 475 | }; 476 | } 477 | } 478 | -------------------------------------------------------------------------------- /libazureinit/src/http.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | use std::time::Duration; 5 | 6 | use reqwest::{header::HeaderMap, Client, Request, StatusCode}; 7 | use tokio::time::timeout; 8 | use tracing::{instrument, Instrument}; 9 | 10 | use crate::error::Error; 11 | 12 | /// Set of StatusCodes that should be retried, 13 | /// e.g. 400, 404, 410, 429, 500, 503. 14 | /// 15 | /// # Example 16 | /// 17 | /// ```rust,ignore 18 | /// # use libazureinit::http::RETRY_CODES; 19 | /// # use reqwest::StatusCode; 20 | /// 21 | /// assert!(RETRY_CODES.contains(StatusCode::NOT_FOUND)); 22 | /// ``` 23 | pub(crate) const RETRY_CODES: &[StatusCode] = &[ 24 | StatusCode::BAD_REQUEST, 25 | StatusCode::NOT_FOUND, 26 | StatusCode::GONE, 27 | StatusCode::TOO_MANY_REQUESTS, 28 | StatusCode::INTERNAL_SERVER_ERROR, 29 | StatusCode::SERVICE_UNAVAILABLE, 30 | ]; 31 | 32 | /// Set of StatusCodes that should immediately fail, 33 | /// e.g. 401, 403, 405. 34 | /// 35 | /// # Example 36 | /// 37 | /// ```rust,ignore 38 | /// # use libazureinit::http::HARDFAIL_CODES; 39 | /// # use reqwest::StatusCode; 40 | /// 41 | /// assert!(HARDFAIL_CODES.contains(StatusCode::FORBIDDEN)); 42 | /// ``` 43 | #[allow(dead_code)] 44 | pub(crate) const HARDFAIL_CODES: &[StatusCode] = &[ 45 | StatusCode::UNAUTHORIZED, 46 | StatusCode::FORBIDDEN, 47 | StatusCode::METHOD_NOT_ALLOWED, 48 | ]; 49 | 50 | /// Timeout for communicating with IMDS. 
51 | pub(crate) const IMDS_HTTP_TIMEOUT_SEC: u64 = 30; 52 | /// Timeout for communicating with wireserver for goalstate, health. 53 | pub(crate) const WIRESERVER_HTTP_TIMEOUT_SEC: u64 = 30; 54 | 55 | /// Send an HTTP GET request to the given URL with an empty body. 56 | #[instrument(err, skip_all)] 57 | pub(crate) async fn get( 58 | client: &Client, 59 | headers: HeaderMap, 60 | request_timeout: Duration, 61 | retry_interval: Duration, 62 | retry_for: Duration, 63 | url: &str, 64 | ) -> Result<(reqwest::Response, Duration), Error> { 65 | let req = client 66 | .get(url) 67 | .headers(headers) 68 | .timeout(request_timeout) 69 | .build()?; 70 | request(client, req, retry_interval, retry_for).await 71 | } 72 | 73 | /// Send an HTTP GET request to the given URL with an empty body. 74 | /// 75 | /// `body` must implement Clone as retries must clone the entire request. 76 | #[instrument(err, skip_all)] 77 | pub(crate) async fn post + Clone>( 78 | client: &Client, 79 | headers: HeaderMap, 80 | body: T, 81 | request_timeout: Duration, 82 | retry_interval: Duration, 83 | retry_for: Duration, 84 | url: &str, 85 | ) -> Result<(reqwest::Response, Duration), Error> { 86 | let req = client 87 | .post(url) 88 | .headers(headers) 89 | .body(body) 90 | .timeout(request_timeout) 91 | .build()?; 92 | request(client, req, retry_interval, retry_for).await 93 | } 94 | 95 | /// Retry an HTTP request until it returns HTTP 200 or the timeout is reached. 96 | /// 97 | /// In the event that the request succeeds, the total remaining timeout is returned with the response. 98 | /// This can be used to resume retrying in the event that the body is malformed. 99 | /// 100 | /// # Panics 101 | /// 102 | /// This function will panic if the request passed cannot be cloned (i.e. the body is a Stream). 103 | /// Functions wrapping this must ensure to include an additional bound on `Body` (see [`post`]). 
104 | async fn request( 105 | client: &Client, 106 | request: Request, 107 | retry_interval: Duration, 108 | retry_for: Duration, 109 | ) -> Result<(reqwest::Response, Duration), Error> { 110 | timeout(retry_for, async { 111 | let now = std::time::Instant::now(); 112 | let mut attempt = 0_u32; 113 | loop { 114 | let span = tracing::info_span!("request", attempt, http_status = tracing::field::Empty); 115 | let req = request.try_clone().expect("The request body MUST be clone-able"); 116 | match client 117 | .execute(req) 118 | .instrument(span.clone()) 119 | .await { 120 | Ok(response) => { 121 | let _enter = span.enter(); 122 | let statuscode = response.status(); 123 | span.record("http_status", statuscode.as_u16()); 124 | tracing::info!(target: "libazureinit::http::received", url=response.url().as_str(), "HTTP response received"); 125 | 126 | match response.error_for_status() { 127 | Ok(response) => { 128 | if statuscode == StatusCode::OK { 129 | tracing::info!(target: "libazureinit::http::success", "HTTP response succeeded with status {}", statuscode); 130 | return Ok((response, retry_for.saturating_sub(now.elapsed() + retry_interval))); 131 | } 132 | }, 133 | Err(error) => { 134 | if !RETRY_CODES.contains(&statuscode) { 135 | tracing::error!( 136 | ?error, 137 | "HTTP response status code is fatal and the request will not be retried", 138 | ); 139 | return Err(error.into()); 140 | } 141 | }, 142 | } 143 | 144 | }, 145 | Err(error) => { 146 | let _enter = span.enter(); 147 | tracing::error!(?error, "HTTP request failed to complete"); 148 | }, 149 | } 150 | span.in_scope(||{ 151 | tracing::warn!( 152 | "Failed to get a successful HTTP response, retrying in {} sec, remaining timeout {} sec.", 153 | retry_interval.as_secs(), 154 | retry_for.saturating_sub(now.elapsed()).as_secs() 155 | ); 156 | }); 157 | // Explicitly dropping here to ensure the sleep isn't included in the request timings 158 | drop(span); 159 | 160 | attempt += 1; 161 | 
tokio::time::sleep(retry_interval).await; 162 | } 163 | }).await? 164 | } 165 | 166 | #[cfg(test)] 167 | pub(crate) mod tests { 168 | use reqwest::{header, Client, StatusCode}; 169 | use std::time::Duration; 170 | use tokio::{io::AsyncWriteExt, net::TcpListener}; 171 | 172 | use crate::unittest::{get_http_response_payload, serve_requests}; 173 | 174 | const BODY_CONTENTS: &str = "hello world"; 175 | 176 | // Helper that returns how many attempts were made on a given HTTP status code. 177 | async fn serve_valid_http_with( 178 | statuscode: &StatusCode, 179 | body: &str, 180 | ) -> bool { 181 | let serverlistener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 182 | let addr = serverlistener.local_addr().unwrap(); 183 | let cancel_token = tokio_util::sync::CancellationToken::new(); 184 | let server = tokio::spawn(serve_requests( 185 | serverlistener, 186 | get_http_response_payload(statuscode, body), 187 | cancel_token.clone(), 188 | )); 189 | 190 | let client = Client::builder() 191 | .timeout(std::time::Duration::from_secs(1)) 192 | .build() 193 | .unwrap(); 194 | 195 | let res = super::get( 196 | &client, 197 | header::HeaderMap::new(), 198 | Duration::from_millis(500), 199 | Duration::from_millis(5), 200 | Duration::from_millis(100), 201 | format!("http://{:}:{:}/", addr.ip(), addr.port()).as_str(), 202 | ) 203 | .await; 204 | 205 | cancel_token.cancel(); 206 | 207 | let requests = server.await.unwrap(); 208 | 209 | if super::HARDFAIL_CODES.contains(statuscode) { 210 | assert_eq!(requests, 1); 211 | } 212 | 213 | if *statuscode == StatusCode::OK { 214 | assert_eq!(requests, 1); 215 | } 216 | 217 | if super::RETRY_CODES.contains(statuscode) { 218 | assert!(requests >= 10); 219 | } 220 | 221 | res.is_ok() 222 | } 223 | 224 | // Assert requests that don't receive data after the connection is accepted retry. 
225 | #[tokio::test] 226 | #[tracing_test::traced_test] 227 | async fn get_slow_write() { 228 | let serverlistener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 229 | let addr = serverlistener.local_addr().unwrap(); 230 | let task_cancel = tokio_util::sync::CancellationToken::new(); 231 | let cancel_token = task_cancel.clone(); 232 | let server = tokio::spawn(async move { 233 | let mut requests_accepted = 0; 234 | loop { 235 | tokio::select! { 236 | _ = task_cancel.cancelled() => { 237 | break; 238 | } 239 | _ = async { 240 | let (mut serverstream, _) = serverlistener.accept().await.unwrap(); 241 | requests_accepted += 1; 242 | // Do this asynchronously so we accept the next request in a timely manner; 243 | // there's a separate test for slow accepts. 244 | tokio::spawn(async move { 245 | tokio::time::sleep(Duration::from_millis(200)).await; 246 | let _ = serverstream.write_all( 247 | get_http_response_payload(&StatusCode::FORBIDDEN, "too slow").as_bytes() 248 | ).await; 249 | }); 250 | } => {} 251 | } 252 | } 253 | requests_accepted 254 | }); 255 | 256 | let client = Client::builder().build().unwrap(); 257 | 258 | let res = super::get( 259 | &client, 260 | header::HeaderMap::new(), 261 | Duration::from_millis(100), 262 | Duration::from_millis(200), 263 | Duration::from_millis(500), 264 | format!("http://{:}:{:}/", addr.ip(), addr.port()).as_str(), 265 | ) 266 | .await; 267 | 268 | cancel_token.cancel(); 269 | 270 | let requests = server.await.unwrap(); 271 | assert!(requests >= 2); 272 | match res { 273 | Err(crate::error::Error::Timeout) => {} 274 | _ => panic!("Response should have timed out"), 275 | }; 276 | } 277 | 278 | // Assert requests that never get accepted retry 279 | #[tokio::test] 280 | #[tracing_test::traced_test] 281 | async fn get_slow_accept() { 282 | let serverlistener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 283 | let addr = serverlistener.local_addr().unwrap(); 284 | let task_cancel = tokio_util::sync::CancellationToken::new(); 
285 | let cancel_token = task_cancel.clone(); 286 | let server = tokio::spawn(async move { 287 | let mut requests_attempted = 0; 288 | loop { 289 | tokio::select! { 290 | _ = task_cancel.cancelled() => { 291 | break; 292 | } 293 | _ = async { 294 | requests_attempted += 1; 295 | tokio::time::sleep(Duration::from_millis(150)).await; 296 | let _ = serverlistener.accept().await; 297 | } => {} 298 | } 299 | } 300 | requests_attempted 301 | }); 302 | 303 | let client = Client::builder().build().unwrap(); 304 | 305 | let res = super::get( 306 | &client, 307 | header::HeaderMap::new(), 308 | Duration::from_millis(100), 309 | Duration::from_millis(200), 310 | Duration::from_millis(1000), 311 | format!("http://{:}:{:}/", addr.ip(), addr.port()).as_str(), 312 | ) 313 | .await; 314 | 315 | cancel_token.cancel(); 316 | 317 | let requests = server.await.unwrap(); 318 | assert!(requests >= 2); 319 | match res { 320 | Err(crate::error::Error::Timeout) => {} 321 | _ => panic!("Response should have timed out"), 322 | }; 323 | } 324 | 325 | // Assert a response with 200 OK is returned. 326 | #[tokio::test] 327 | #[tracing_test::traced_test] 328 | async fn get_ok() { 329 | assert!(serve_valid_http_with(&StatusCode::OK, BODY_CONTENTS).await); 330 | assert!(logs_contain("HTTP response succeeded with status 200 OK")); 331 | } 332 | 333 | // Assert status codes in the list are retried 334 | #[tokio::test] 335 | #[tracing_test::traced_test] 336 | async fn get_retry_responses() { 337 | for rc in super::RETRY_CODES { 338 | assert!(!serve_valid_http_with(rc, BODY_CONTENTS).await); 339 | } 340 | } 341 | 342 | #[tokio::test] 343 | #[tracing_test::traced_test] 344 | async fn get_fast_fail() { 345 | // status codes that should result into immediate failures. 
346 | for rc in super::HARDFAIL_CODES { 347 | assert!(!serve_valid_http_with(rc, BODY_CONTENTS).await); 348 | } 349 | } 350 | } 351 | -------------------------------------------------------------------------------- /libazureinit/src/imds.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | use reqwest::header::HeaderMap; 5 | use reqwest::header::HeaderValue; 6 | use reqwest::Client; 7 | use tracing::instrument; 8 | 9 | use std::time::Duration; 10 | 11 | use serde::{Deserialize, Deserializer}; 12 | use serde_json; 13 | use serde_json::Value; 14 | 15 | use crate::error::Error; 16 | use crate::http; 17 | 18 | /// Azure instance metadata obtained from IMDS. Written in JSON format. 19 | /// 20 | /// Required fields are osProfile and publicKeys. 21 | /// 22 | /// # Example 23 | /// 24 | /// ``` 25 | /// # use libazureinit::imds; 26 | /// static TESTDATA: &str = r#" 27 | ///{ 28 | /// "compute": { 29 | /// "osProfile": { 30 | /// "adminUsername": "testuser", 31 | /// "computerName": "testcomputer", 32 | /// "disablePasswordAuthentication": "true" 33 | /// }, 34 | /// "publicKeys": [] 35 | /// } 36 | ///}"#; 37 | /// let metadata: imds::InstanceMetadata = 38 | /// serde_json::from_str(&TESTDATA.to_string()).unwrap(); 39 | /// ``` 40 | #[derive(Debug, Deserialize, PartialEq, Clone)] 41 | pub struct InstanceMetadata { 42 | /// Compute metadata 43 | pub compute: Compute, 44 | } 45 | 46 | /// Metadata about the instance's virtual machine. Written in JSON format. 47 | #[derive(Debug, Deserialize, PartialEq, Clone)] 48 | pub struct Compute { 49 | /// Metadata about the operating system. 50 | #[serde(rename = "osProfile")] 51 | pub os_profile: OsProfile, 52 | /// SSH Public keys. 53 | #[serde(rename = "publicKeys")] 54 | pub public_keys: Vec, 55 | } 56 | 57 | /// Azure Metadata about the virtual machine's operating system, obtained from IMDS. 
58 | /// Written in JSON format. 59 | /// 60 | /// Required fields are adminUsername, computerName, disablePasswordAuthentication. 61 | /// 62 | /// # Example 63 | /// 64 | /// ``` 65 | /// # use serde_json::json; 66 | /// # use libazureinit::imds::OsProfile; 67 | /// 68 | /// let TESTDATA = json!({ 69 | /// "adminUsername": "testuser", 70 | /// "computerName": "testcomputer", 71 | /// "disablePasswordAuthentication": "true" 72 | /// }); 73 | /// let os_profile: OsProfile = serde_json::from_value(TESTDATA).unwrap(); 74 | /// ``` 75 | #[derive(Debug, Deserialize, PartialEq, Clone)] 76 | pub struct OsProfile { 77 | /// The admin account's username. 78 | #[serde(rename = "adminUsername")] 79 | pub admin_username: String, 80 | /// The name of the virtual machine. 81 | #[serde(rename = "computerName")] 82 | pub computer_name: String, 83 | /// Specifies whether or not password authentication is disabled. 84 | #[serde( 85 | rename = "disablePasswordAuthentication", 86 | deserialize_with = "string_bool" 87 | )] 88 | pub disable_password_authentication: bool, 89 | } 90 | 91 | /// Azure Metadata's SSH public key obtained from IMDS. Written in JSON format. 92 | /// 93 | /// # Example 94 | /// 95 | /// ``` 96 | /// # use serde_json::json; 97 | /// # use libazureinit::imds::PublicKeys; 98 | /// 99 | /// let TESTDATA = json!({ 100 | /// "keyData": "ssh-rsa test_key1", 101 | /// "path": "/path/to/.ssh/authorized_keys" 102 | /// }); 103 | /// let ssh_key: PublicKeys = serde_json::from_value(TESTDATA).unwrap(); 104 | /// ``` 105 | #[derive(Debug, Deserialize, PartialEq, Clone)] 106 | pub struct PublicKeys { 107 | /// The SSH public key certificate used to authenticate with the virtual machine. 108 | #[serde(rename = "keyData")] 109 | pub key_data: String, 110 | /// The full path on the virtual machine where the SSH public key is stored. 
111 | #[serde(rename = "path")] 112 | pub path: String, 113 | } 114 | 115 | impl From<&str> for PublicKeys { 116 | fn from(value: &str) -> Self { 117 | Self { 118 | key_data: value.to_string(), 119 | path: String::new(), 120 | } 121 | } 122 | } 123 | 124 | /// Deserializer that handles the string "true" and "false" that the IMDS API returns. 125 | fn string_bool<'de, D>(deserializer: D) -> Result 126 | where 127 | D: Deserializer<'de>, 128 | { 129 | match Deserialize::deserialize(deserializer)? { 130 | Value::String(string) => match string.as_str() { 131 | "true" => Ok(true), 132 | "false" => Ok(false), 133 | unknown => Err(serde::de::Error::unknown_variant( 134 | unknown, 135 | &["true", "false"], 136 | )), 137 | }, 138 | Value::Bool(boolean) => Ok(boolean), 139 | _ => Err(serde::de::Error::custom( 140 | "Unexpected type, expected 'true' or 'false'", 141 | )), 142 | } 143 | } 144 | 145 | const DEFAULT_IMDS_URL: &str = 146 | "http://169.254.169.254/metadata/instance?api-version=2023-11-15&extended=true"; 147 | 148 | /// Send queries to IMDS to fetch Azure instance metadata. 149 | /// 150 | /// Caller needs to pass 3 required parameters, client, retry_interval, 151 | /// total_timeout. It is therefore required to create a reqwest::Client 152 | /// variable with possible options, to pass it as parameter. 153 | /// 154 | /// Parameter url optional. If None is passed, it defaults to 155 | /// DEFAULT_IMDS_URL, an internal IMDS URL available in the Azure VM. 
156 | /// 157 | /// # Example 158 | /// 159 | /// ``` 160 | /// # use reqwest::Client; 161 | /// # use std::time::Duration; 162 | /// 163 | /// let client = Client::builder() 164 | /// .timeout(std::time::Duration::from_secs(5)) 165 | /// .build() 166 | /// .unwrap(); 167 | /// 168 | /// let res = libazureinit::imds::query( 169 | /// &client, Duration::from_secs(1), Duration::from_secs(5), 170 | /// Some("http://127.0.0.1:8000/"), 171 | /// ); 172 | /// ``` 173 | #[instrument(err, skip_all)] 174 | pub async fn query( 175 | client: &Client, 176 | retry_interval: Duration, 177 | mut total_timeout: Duration, 178 | url: Option<&str>, 179 | ) -> Result { 180 | let mut headers = HeaderMap::new(); 181 | headers.insert("Metadata", HeaderValue::from_static("true")); 182 | let url = url.unwrap_or(DEFAULT_IMDS_URL); 183 | let request_timeout = Duration::from_secs(http::IMDS_HTTP_TIMEOUT_SEC); 184 | 185 | while !total_timeout.is_zero() { 186 | let (response, remaining_timeout) = http::get( 187 | client, 188 | headers.clone(), 189 | request_timeout, 190 | retry_interval, 191 | total_timeout, 192 | url, 193 | ) 194 | .await?; 195 | match response.text().await { 196 | Ok(text) => { 197 | let metadata = 198 | serde_json::from_str(text.as_str()).map_err(|error| { 199 | tracing::warn!( 200 | ?error, 201 | "The response body was invalid and could not be deserialized" 202 | ); 203 | error.into() 204 | }); 205 | if metadata.is_ok() { 206 | return metadata; 207 | } 208 | } 209 | Err(error) => { 210 | tracing::warn!(?error, "Failed to read the full response body") 211 | } 212 | } 213 | 214 | total_timeout = remaining_timeout; 215 | } 216 | 217 | Err(Error::Timeout) 218 | } 219 | 220 | #[cfg(test)] 221 | mod tests { 222 | use serde_json::json; 223 | 224 | use super::{query, InstanceMetadata, OsProfile}; 225 | 226 | use reqwest::{header, Client, StatusCode}; 227 | use std::time::Duration; 228 | use tokio::net::TcpListener; 229 | 230 | use crate::{http, unittest}; 231 | 232 | static 
BODY_CONTENTS: &str = r#" 233 | { 234 | "compute": { 235 | "azEnvironment": "cloud_env", 236 | "customData": "", 237 | "evictionPolicy": "", 238 | "isHostCompatibilityLayerVm": "false", 239 | "licenseType": "", 240 | "location": "eastus", 241 | "name": "AzTux-MinProvAgent-Test-0001", 242 | "offer": "0001-com-ubuntu-server-focal", 243 | "osProfile": { 244 | "adminUsername": "MinProvAgentUser", 245 | "computerName": "AzTux-MinProvAgent-Test-0001", 246 | "disablePasswordAuthentication": "true" 247 | }, 248 | "publicKeys": [ 249 | { 250 | "keyData": "ssh-rsa test_key1", 251 | "path": "/path/to/.ssh/authorized_keys" 252 | }, 253 | { 254 | "keyData": "ssh-rsa test_key2", 255 | "path": "/path/to/.ssh/authorized_keys" 256 | } 257 | ] 258 | } 259 | }"#; 260 | 261 | #[test] 262 | fn instance_metadata_deserialization() { 263 | let file_body = BODY_CONTENTS.to_string(); 264 | 265 | let metadata: InstanceMetadata = 266 | serde_json::from_str(&file_body).unwrap(); 267 | 268 | assert!(metadata.compute.os_profile.disable_password_authentication); 269 | assert_eq!( 270 | metadata.compute.public_keys[0].key_data, 271 | "ssh-rsa test_key1".to_string() 272 | ); 273 | assert_eq!( 274 | metadata.compute.public_keys[1].key_data, 275 | "ssh-rsa test_key2".to_string() 276 | ); 277 | assert_eq!( 278 | metadata.compute.os_profile.admin_username, 279 | "MinProvAgentUser".to_string() 280 | ); 281 | assert_eq!( 282 | metadata.compute.os_profile.computer_name, 283 | "AzTux-MinProvAgent-Test-0001".to_string() 284 | ); 285 | assert_eq!( 286 | metadata.compute.os_profile.disable_password_authentication, 287 | true 288 | ); 289 | } 290 | 291 | #[test] 292 | fn deserialization_disable_password_true() { 293 | let os_profile = json!({ 294 | "adminUsername": "MinProvAgentUser", 295 | "computerName": "AzTux-MinProvAgent-Test-0001", 296 | "disablePasswordAuthentication": "true" 297 | }); 298 | let os_profile: OsProfile = serde_json::from_value(os_profile).unwrap(); 299 | 
assert!(os_profile.disable_password_authentication); 300 | } 301 | 302 | #[test] 303 | fn deserialization_disable_password_false() { 304 | let os_profile = json!({ 305 | "adminUsername": "MinProvAgentUser", 306 | "computerName": "AzTux-MinProvAgent-Test-0001", 307 | "disablePasswordAuthentication": "false" 308 | }); 309 | let os_profile: OsProfile = serde_json::from_value(os_profile).unwrap(); 310 | assert_eq!(os_profile.disable_password_authentication, false); 311 | } 312 | 313 | #[test] 314 | fn deserialization_disable_password_nonsense() { 315 | let os_profile = json!({ 316 | "adminUsername": "MinProvAgentUser", 317 | "computerName": "AzTux-MinProvAgent-Test-0001", 318 | "disablePasswordAuthentication": "nonsense" 319 | }); 320 | let os_profile: Result = 321 | serde_json::from_value(os_profile); 322 | assert!(os_profile.is_err_and(|err| err.is_data())); 323 | } 324 | 325 | // Runs a test around sending via imds::query() with a given statuscode. 326 | async fn run_imds_query_retry(statuscode: &StatusCode) -> bool { 327 | const IMDS_HTTP_TOTAL_TIMEOUT_SEC: u64 = 5; 328 | const IMDS_HTTP_PERCLIENT_TIMEOUT_SEC: u64 = 5; 329 | const IMDS_HTTP_RETRY_INTERVAL_SEC: u64 = 1; 330 | 331 | let mut default_headers = header::HeaderMap::new(); 332 | let user_agent = 333 | header::HeaderValue::from_str("azure-init test").unwrap(); 334 | 335 | let ok_payload = 336 | unittest::get_http_response_payload(statuscode, BODY_CONTENTS); 337 | let serverlistener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 338 | let addr = serverlistener.local_addr().unwrap(); 339 | 340 | let cancel_token = tokio_util::sync::CancellationToken::new(); 341 | 342 | let server = tokio::spawn(unittest::serve_requests( 343 | serverlistener, 344 | ok_payload, 345 | cancel_token.clone(), 346 | )); 347 | 348 | default_headers.insert(header::USER_AGENT, user_agent); 349 | let client = Client::builder() 350 | .timeout(std::time::Duration::from_secs( 351 | IMDS_HTTP_PERCLIENT_TIMEOUT_SEC, 352 | )) 353 | 
.default_headers(default_headers) 354 | .build() 355 | .unwrap(); 356 | 357 | let res = query( 358 | &client, 359 | Duration::from_secs(IMDS_HTTP_RETRY_INTERVAL_SEC), 360 | Duration::from_secs(IMDS_HTTP_TOTAL_TIMEOUT_SEC), 361 | Some(format!("http://{:}:{:}/", addr.ip(), addr.port()).as_str()), 362 | ) 363 | .await; 364 | 365 | cancel_token.cancel(); 366 | 367 | let requests = server.await.unwrap(); 368 | 369 | if http::HARDFAIL_CODES.contains(statuscode) { 370 | assert_eq!(requests, 1); 371 | } 372 | 373 | if http::RETRY_CODES.contains(statuscode) { 374 | assert!(requests >= 4); 375 | } 376 | 377 | res.is_ok() 378 | } 379 | 380 | #[tokio::test] 381 | async fn imds_query_retry() { 382 | // status codes that should succeed. 383 | assert!(run_imds_query_retry(&StatusCode::OK).await); 384 | 385 | // status codes that should be retried up to 5 minutes. 386 | for rc in http::RETRY_CODES { 387 | assert!(!run_imds_query_retry(rc).await); 388 | } 389 | 390 | // status codes that should result into immediate failures. 391 | for rc in http::HARDFAIL_CODES { 392 | assert!(!run_imds_query_retry(rc).await); 393 | } 394 | } 395 | 396 | // Assert malformed responses are retried. 397 | // 398 | // In this case the server declares a content-type of JSON, but doesn't return JSON. 
399 | #[tokio::test] 400 | #[tracing_test::traced_test] 401 | async fn malformed_response() { 402 | let body = "not json, whoops"; 403 | let payload = format!( 404 | "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", 405 | StatusCode::OK.as_u16(), 406 | StatusCode::OK.to_string(), 407 | body.len(), 408 | body 409 | ); 410 | 411 | let serverlistener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 412 | let addr = serverlistener.local_addr().unwrap(); 413 | let cancel_token = tokio_util::sync::CancellationToken::new(); 414 | let server = tokio::spawn(unittest::serve_requests( 415 | serverlistener, 416 | payload, 417 | cancel_token.clone(), 418 | )); 419 | 420 | let client = Client::builder() 421 | .timeout(std::time::Duration::from_secs(5)) 422 | .build() 423 | .unwrap(); 424 | 425 | let res = query( 426 | &client, 427 | Duration::from_millis(10), 428 | Duration::from_millis(50), 429 | Some(format!("http://{:}:{:}/", addr.ip(), addr.port()).as_str()), 430 | ) 431 | .await; 432 | 433 | cancel_token.cancel(); 434 | 435 | let requests = server.await.unwrap(); 436 | assert!(requests >= 2); 437 | assert!(logs_contain( 438 | "The response body was invalid and could not be deserialized" 439 | )); 440 | match res { 441 | Err(crate::error::Error::Timeout) => {} 442 | _ => panic!("Response should have timed out"), 443 | }; 444 | } 445 | } 446 | -------------------------------------------------------------------------------- /libazureinit/src/lib.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 
3 | pub mod config; 4 | pub use config::{HostnameProvisioner, PasswordProvisioner, UserProvisioner}; 5 | pub mod error; 6 | pub mod goalstate; 7 | pub(crate) mod http; 8 | pub mod imds; 9 | pub mod media; 10 | 11 | mod provision; 12 | pub use provision::{user::User, Provision}; 13 | mod status; 14 | pub use status::{ 15 | get_vm_id, is_provisioning_complete, mark_provisioning_complete, 16 | }; 17 | 18 | #[cfg(test)] 19 | mod unittest; 20 | 21 | // Re-export as the Client is used in our API. 22 | pub use reqwest; 23 | 24 | /// Run a command, capturing its output and logging it if it fails. 25 | /// 26 | /// In the event of a failure, the provided `error_message` is logged at 27 | /// error level. 28 | /// 29 | ///
30 | /// 31 | /// This logs the command and its arguments, and as such is not appropriate 32 | /// if the command contains sensitive information. 33 | /// 34 | ///
35 | pub(crate) fn run( 36 | mut command: std::process::Command, 37 | ) -> Result<(), error::Error> { 38 | let program = command.get_program().to_string_lossy().to_string(); 39 | let span = tracing::info_span!("subprocess", program); 40 | let _entered = span.enter(); 41 | 42 | tracing::debug!(?command, "About to execute system program"); 43 | let output = command.output()?; 44 | let status = output.status; 45 | tracing::debug!(?status, "System program completed"); 46 | 47 | if !status.success() { 48 | let stderr = String::from_utf8_lossy(&output.stderr); 49 | let stdout = String::from_utf8_lossy(&output.stdout); 50 | tracing::error!( 51 | ?status, 52 | ?command, 53 | ?stdout, 54 | ?stderr, 55 | "Command '{}' failed", 56 | program 57 | ); 58 | return Err(error::Error::SubprocessFailed { 59 | command: format!("{:?}", command), 60 | status, 61 | }); 62 | } 63 | 64 | Ok(()) 65 | } 66 | -------------------------------------------------------------------------------- /libazureinit/src/media.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | //! This module provides functionality for handling media devices, including mounting, 5 | //! unmounting, and reading [`OVF`] (Open Virtualization Format) environment data. It defines 6 | //! the [`Media`] struct with state management for [`Mounted`] and [`Unmounted`] states, as well 7 | //! as utility functions for parsing [`OVF`] environment data and retrieving mounted devices 8 | //! with CDROM-type filesystems. 9 | //! 10 | //! # Overview 11 | //! 12 | //! The `media` module is designed to manage media devices in a cloud environment. It 13 | //! includes functionality to mount and unmount media devices, read [`OVF`] environment data, 14 | //! and parse the data into structured formats. This is particularly useful for provisioning 15 | //! virtual machines with specific configurations. 16 | //! 17 | //! 
# Key Components 18 | //! 19 | //! - [`Media`]: A struct representing a media device, with state management for [`Mounted`] and [`Unmounted`] states. 20 | //! - [`Mounted`] and [`Unmounted`]: Zero-sized structs used to indicate the state of a [`Media`] instance. 21 | //! - [`parse_ovf_env`]: A function to parse [`OVF`] environment data from a string. 22 | //! - [`mount_parse_ovf_env`]: A function to mount a media device, read its [`OVF`] environment data, and return the parsed data. 23 | //! - [`get_mount_device`]: A function to retrieve a list of mounted devices with CDROM-type filesystems. 24 | //! 25 | //! [`Media`]: struct.Media.html 26 | //! [`Mounted`]: struct.Mounted.html 27 | //! [`Unmounted`]: struct.Unmounted.html 28 | //! [`parse_ovf_env`]: fn.parse_ovf_env.html 29 | //! [`mount_parse_ovf_env`]: fn.mount_parse_ovf_env.html 30 | //! [`get_mount_device`]: fn.get_mount_device.html 31 | //! [`OVF`]: https://www.dmtf.org/standards/ovf 32 | 33 | use std::fs; 34 | use std::fs::create_dir_all; 35 | use std::fs::File; 36 | use std::io::Read; 37 | use std::os::unix::fs::PermissionsExt; 38 | use std::path::Path; 39 | use std::path::PathBuf; 40 | use std::process::Command; 41 | 42 | use serde::Deserialize; 43 | use serde_xml_rs::from_str; 44 | 45 | use tracing; 46 | use tracing::instrument; 47 | 48 | use crate::error::Error; 49 | use fstab::FsTab; 50 | 51 | /// Represents a media device. 52 | /// 53 | /// # Type Parameters 54 | /// 55 | /// * `State` - The state of the media, either `Mounted` or `Unmounted`. 56 | #[derive(Debug, Default, Deserialize, PartialEq, Clone)] 57 | pub struct Environment { 58 | #[serde(rename = "wa:ProvisioningSection")] 59 | pub provisioning_section: ProvisioningSection, 60 | #[serde(rename = "wa:PlatformSettingsSection")] 61 | pub platform_settings_section: PlatformSettingsSection, 62 | } 63 | 64 | /// Provisioning section of the environment configuration. 
65 | #[derive(Debug, Default, Deserialize, PartialEq, Clone)] 66 | pub struct ProvisioningSection { 67 | #[serde(rename = "wa:Version")] 68 | pub version: String, 69 | #[serde(rename = "LinuxProvisioningConfigurationSet")] 70 | pub linux_prov_conf_set: LinuxProvisioningConfigurationSet, 71 | } 72 | 73 | /// Linux provisioning configuration set. 74 | #[derive(Debug, Default, Deserialize, PartialEq, Clone)] 75 | pub struct LinuxProvisioningConfigurationSet { 76 | #[serde(rename = "UserName")] 77 | pub username: String, 78 | #[serde(default = "default_password", rename = "UserPassword")] 79 | pub password: String, 80 | #[serde(rename = "HostName")] 81 | pub hostname: String, 82 | } 83 | 84 | /// Platform settings section of the environment configuration. 85 | #[derive(Debug, Default, Deserialize, PartialEq, Clone)] 86 | pub struct PlatformSettingsSection { 87 | #[serde(rename = "wa:Version")] 88 | pub version: String, 89 | #[serde(rename = "PlatformSettings")] 90 | pub platform_settings: PlatformSettings, 91 | } 92 | 93 | /// Platform settings details. 94 | #[derive(Debug, Default, Deserialize, PartialEq, Clone)] 95 | pub struct PlatformSettings { 96 | #[serde(default = "default_preprov", rename = "PreprovisionedVm")] 97 | pub preprovisioned_vm: bool, 98 | #[serde(default = "default_preprov_type", rename = "PreprovisionedVmType")] 99 | pub preprovisioned_vm_type: String, 100 | } 101 | 102 | /// Returns an empty string as the default password. 103 | /// 104 | /// # Returns 105 | /// 106 | /// A `String` containing an empty password. 107 | fn default_password() -> String { 108 | "".to_owned() 109 | } 110 | 111 | /// Returns `false` as the default value for preprovisioned VM. 112 | /// 113 | /// # Returns 114 | /// 115 | /// A `bool` indicating that the VM is not preprovisioned. 116 | fn default_preprov() -> bool { 117 | false 118 | } 119 | 120 | /// Returns "None" as the default type for preprovisioned VM. 
121 | /// 122 | /// # Returns 123 | /// 124 | /// A `String` containing "None" as the default preprovisioned VM type. 125 | fn default_preprov_type() -> String { 126 | "None".to_owned() 127 | } 128 | 129 | /// Path to the default mount device. 130 | pub const PATH_MOUNT_DEVICE: &str = "/dev/sr0"; 131 | /// Path to the default mount point. 132 | pub const PATH_MOUNT_POINT: &str = "/run/azure-init/media/"; 133 | 134 | /// Valid filesystems for CDROM devices. 135 | const CDROM_VALID_FS: &[&str] = &["iso9660", "udf"]; 136 | /// Path to the mount table file. 137 | const MTAB_PATH: &str = "/etc/mtab"; 138 | 139 | /// Retrieves a list of mounted devices with CDROM-type filesystems. 140 | /// 141 | /// # Arguments 142 | /// 143 | /// * `path` - Optional path to the mount table file. 144 | /// 145 | /// # Returns 146 | /// 147 | /// A `Result` containing a vector of device paths as strings, or an `Error`. 148 | #[instrument] 149 | pub fn get_mount_device(path: Option<&Path>) -> Result, Error> { 150 | let fstab = FsTab::new(path.unwrap_or_else(|| Path::new(MTAB_PATH))); 151 | let entries = fstab.get_entries()?; 152 | 153 | // Retrieve the names of all devices that have cdrom-type filesystem (e.g., udf) 154 | let cdrom_devices = entries 155 | .into_iter() 156 | .filter_map(|entry| { 157 | if CDROM_VALID_FS.contains(&entry.vfs_type.as_str()) { 158 | Some(entry.fs_spec) 159 | } else { 160 | None 161 | } 162 | }) 163 | .collect(); 164 | 165 | Ok(cdrom_devices) 166 | } 167 | 168 | /// Represents the state of a mounted media. 169 | #[derive(Debug)] 170 | pub struct Mounted; 171 | 172 | /// Represents the state of an unmounted media. 173 | #[derive(Debug)] 174 | pub struct Unmounted; 175 | 176 | /// Represents a media device. 177 | /// 178 | /// # Type Parameters 179 | /// 180 | /// * `State` - The state of the media, either `Mounted` or `Unmounted`. 
181 | #[derive(Debug)] 182 | pub struct Media { 183 | device_path: PathBuf, 184 | mount_path: PathBuf, 185 | state: std::marker::PhantomData, 186 | } 187 | 188 | impl Media { 189 | /// Creates a new `Media` instance. 190 | /// 191 | /// # Arguments 192 | /// 193 | /// * `device_path` - The path to the media device. 194 | /// * `mount_path` - The path where the media will be mounted. 195 | /// 196 | /// # Returns 197 | /// 198 | /// A new `Media` instance in the `Unmounted` state. 199 | pub fn new(device_path: PathBuf, mount_path: PathBuf) -> Media { 200 | Media { 201 | device_path, 202 | mount_path, 203 | state: std::marker::PhantomData, 204 | } 205 | } 206 | 207 | /// Mounts the media device. 208 | /// 209 | /// # Returns 210 | /// 211 | /// A `Result` containing the `Media` instance in the `Mounted` state, or an `Error`. 212 | #[instrument] 213 | pub fn mount(self) -> Result, Error> { 214 | create_dir_all(&self.mount_path)?; 215 | 216 | let metadata = fs::metadata(&self.mount_path)?; 217 | let permissions = metadata.permissions(); 218 | let mut new_permissions = permissions; 219 | new_permissions.set_mode(0o700); 220 | fs::set_permissions(&self.mount_path, new_permissions)?; 221 | 222 | let mut command = Command::new("mount"); 223 | command 224 | .arg("-o") 225 | .arg("ro") 226 | .arg(&self.device_path) 227 | .arg(&self.mount_path); 228 | crate::run(command)?; 229 | 230 | Ok(Media { 231 | device_path: self.device_path, 232 | mount_path: self.mount_path, 233 | state: std::marker::PhantomData, 234 | }) 235 | } 236 | } 237 | 238 | impl Media { 239 | /// Unmounts the media device. 240 | /// 241 | /// # Returns 242 | /// 243 | /// A `Result` indicating success or failure. 
244 | #[instrument] 245 | pub fn unmount(self) -> Result<(), Error> { 246 | let mut command = Command::new("umount"); 247 | command.arg(self.mount_path); 248 | crate::run(command)?; 249 | 250 | let mut command = Command::new("eject"); 251 | command.arg(self.device_path); 252 | crate::run(command) 253 | } 254 | 255 | /// Reads the OVF environment data to a string. 256 | /// 257 | /// # Returns 258 | /// 259 | /// A `Result` containing the OVF environment data as a string, or an `Error`. 260 | #[instrument] 261 | pub fn read_ovf_env_to_string(&self) -> Result { 262 | let mut file_path = self.mount_path.clone(); 263 | file_path.push("ovf-env.xml"); 264 | let mut file = 265 | File::open(file_path.to_str().unwrap_or(PATH_MOUNT_POINT))?; 266 | let mut contents = String::new(); 267 | file.read_to_string(&mut contents)?; 268 | 269 | Ok(contents) 270 | } 271 | } 272 | 273 | /// Parses the OVF environment data. 274 | /// 275 | /// # Arguments 276 | /// 277 | /// * `ovf_body` - A string slice containing the OVF environment data. 278 | /// 279 | /// # Returns 280 | /// 281 | /// A `Result` containing the parsed `Environment` struct, or an `Error`. 
282 | /// 283 | /// # Example 284 | /// 285 | /// ``` 286 | /// use libazureinit::media::parse_ovf_env; 287 | /// 288 | /// // Example dummy OVF environment data 289 | /// let ovf_body = r#" 290 | /// 292 | /// 293 | /// 1.0 294 | /// 295 | /// myusername 296 | /// 297 | /// false 298 | /// myhostname 299 | /// 300 | /// 301 | /// 302 | /// 1.0 303 | /// 304 | /// false 305 | /// None 306 | /// 307 | /// 308 | /// 309 | /// "#; 310 | /// 311 | /// let environment = parse_ovf_env(ovf_body).unwrap(); 312 | /// assert_eq!(environment.provisioning_section.linux_prov_conf_set.username, "myusername"); 313 | /// assert_eq!(environment.provisioning_section.linux_prov_conf_set.password, ""); 314 | /// assert_eq!(environment.provisioning_section.linux_prov_conf_set.hostname, "myhostname"); 315 | /// assert_eq!(environment.platform_settings_section.platform_settings.preprovisioned_vm, false); 316 | /// assert_eq!(environment.platform_settings_section.platform_settings.preprovisioned_vm_type, "None"); 317 | /// ``` 318 | #[instrument(skip_all)] 319 | pub fn parse_ovf_env(ovf_body: &str) -> Result { 320 | let environment: Environment = from_str(ovf_body)?; 321 | 322 | if !environment 323 | .provisioning_section 324 | .linux_prov_conf_set 325 | .password 326 | .is_empty() 327 | { 328 | Err(Error::NonEmptyPassword) 329 | } else { 330 | Ok(environment) 331 | } 332 | } 333 | 334 | /// Mounts the given device, gets OVF environment data, and returns it. 335 | /// 336 | /// # Arguments 337 | /// 338 | /// * `dev` - A string containing the device path. 339 | /// 340 | /// # Returns 341 | /// 342 | /// A `Result` containing the parsed `Environment` struct, or an `Error`. 
343 | #[instrument(skip_all)] 344 | pub fn mount_parse_ovf_env(dev: String) -> Result { 345 | let mount_media = 346 | Media::new(PathBuf::from(dev), PathBuf::from(PATH_MOUNT_POINT)); 347 | let mounted = mount_media.mount().map_err(|e| { 348 | tracing::error!(error = ?e, "Failed to mount media."); 349 | e 350 | })?; 351 | 352 | let ovf_body = mounted.read_ovf_env_to_string()?; 353 | let environment = parse_ovf_env(ovf_body.as_str())?; 354 | 355 | mounted.unmount().map_err(|e| { 356 | tracing::error!(error = ?e, "Failed to remove media."); 357 | e 358 | })?; 359 | 360 | Ok(environment) 361 | } 362 | 363 | #[cfg(test)] 364 | mod tests { 365 | use super::*; 366 | use crate::error::Error; 367 | use std::io::Write; 368 | use tempfile::NamedTempFile; 369 | 370 | #[test] 371 | fn test_get_ovf_env_none_missing() { 372 | let ovf_body = r#" 373 | 377 | 378 | 1.0 379 | 381 | LinuxProvisioningConfiguration 382 | myusername 383 | 384 | false 385 | myhostname 386 | 387 | 388 | 389 | 1.0 390 | 392 | kms.core.windows.net 393 | true 394 | 395 | true 396 | true 397 | false 398 | None 399 | false 400 | 401 | 402 | "#; 403 | 404 | let environment: Environment = parse_ovf_env(ovf_body).unwrap(); 405 | 406 | assert_eq!( 407 | environment 408 | .provisioning_section 409 | .linux_prov_conf_set 410 | .username, 411 | "myusername" 412 | ); 413 | assert_eq!( 414 | environment 415 | .provisioning_section 416 | .linux_prov_conf_set 417 | .password, 418 | "" 419 | ); 420 | assert_eq!( 421 | environment 422 | .provisioning_section 423 | .linux_prov_conf_set 424 | .hostname, 425 | "myhostname" 426 | ); 427 | assert_eq!( 428 | environment 429 | .platform_settings_section 430 | .platform_settings 431 | .preprovisioned_vm, 432 | false 433 | ); 434 | assert_eq!( 435 | environment 436 | .platform_settings_section 437 | .platform_settings 438 | .preprovisioned_vm_type, 439 | "None" 440 | ); 441 | } 442 | 443 | #[test] 444 | fn test_get_ovf_env_missing_type() { 445 | let ovf_body = r#" 446 | 450 | 451 | 
1.0 452 | 455 | LinuxProvisioningConfiguration 456 | myusername 457 | 458 | false 459 | myhostname 460 | 461 | 462 | 463 | 1.0 464 | 466 | kms.core.windows.net 467 | true 468 | 469 | true 470 | true 471 | false 472 | false 473 | 474 | 475 | "#; 476 | 477 | let environment: Environment = parse_ovf_env(ovf_body).unwrap(); 478 | 479 | assert_eq!( 480 | environment 481 | .provisioning_section 482 | .linux_prov_conf_set 483 | .username, 484 | "myusername" 485 | ); 486 | assert_eq!( 487 | environment 488 | .provisioning_section 489 | .linux_prov_conf_set 490 | .password, 491 | "" 492 | ); 493 | assert_eq!( 494 | environment 495 | .provisioning_section 496 | .linux_prov_conf_set 497 | .hostname, 498 | "myhostname" 499 | ); 500 | assert_eq!( 501 | environment 502 | .platform_settings_section 503 | .platform_settings 504 | .preprovisioned_vm, 505 | false 506 | ); 507 | assert_eq!( 508 | environment 509 | .platform_settings_section 510 | .platform_settings 511 | .preprovisioned_vm_type, 512 | "None" 513 | ); 514 | } 515 | 516 | #[test] 517 | fn test_get_ovf_env_password_provided() { 518 | let ovf_body = r#" 519 | 523 | 524 | 1.0 525 | 527 | LinuxProvisioningConfiguration 528 | myusername 529 | mypassword 530 | false 531 | myhostname 532 | 533 | 534 | 535 | 1.0 536 | 538 | kms.core.windows.net 539 | true 540 | 541 | true 542 | true 543 | true 544 | false 545 | 546 | 547 | "#; 548 | match parse_ovf_env(ovf_body) { 549 | Err(Error::NonEmptyPassword) => {} 550 | _ => panic!("Non-empty passwords aren't allowed"), 551 | }; 552 | } 553 | 554 | #[test] 555 | fn test_get_mount_device_with_cdrom_entries() { 556 | let mut temp_file = 557 | NamedTempFile::new().expect("Failed to create temporary file"); 558 | let mount_table = r#" 559 | /dev/sr0 /mnt/cdrom iso9660 ro,user,noauto 0 0 560 | /dev/sr1 /mnt/cdrom2 udf ro,user,noauto 0 0 561 | "#; 562 | temp_file 563 | .write_all(mount_table.as_bytes()) 564 | .expect("Failed to write to temporary file"); 565 | let temp_path = 
temp_file.into_temp_path(); 566 | let result = get_mount_device(Some(temp_path.as_ref())); 567 | 568 | let list_devices = result.unwrap(); 569 | assert_eq!( 570 | list_devices, 571 | vec!["/dev/sr0".to_string(), "/dev/sr1".to_string()] 572 | ); 573 | } 574 | 575 | #[test] 576 | fn test_get_mount_device_without_cdrom_entries() { 577 | let mut temp_file = 578 | NamedTempFile::new().expect("Failed to create temporary file"); 579 | let mount_table = r#" 580 | /dev/sda1 / ext4 defaults 0 0 581 | /dev/sda2 /home ext4 defaults 0 0 582 | "#; 583 | temp_file 584 | .write_all(mount_table.as_bytes()) 585 | .expect("Failed to write to temporary file"); 586 | let temp_path = temp_file.into_temp_path(); 587 | let result = get_mount_device(Some(temp_path.as_ref())); 588 | 589 | let list_devices = result.unwrap(); 590 | assert!(list_devices.is_empty()); 591 | } 592 | 593 | #[test] 594 | fn test_get_mount_device_with_mixed_entries() { 595 | let mut temp_file = 596 | NamedTempFile::new().expect("Failed to create temporary file"); 597 | let mount_table = r#" 598 | /dev/sr0 /mnt/cdrom iso9660 ro,user,noauto 0 0 599 | /dev/sda1 / ext4 defaults 0 0 600 | /dev/sr1 /mnt/cdrom2 udf ro,user,noauto 0 0 601 | "#; 602 | temp_file 603 | .write_all(mount_table.as_bytes()) 604 | .expect("Failed to write to temporary file"); 605 | let temp_path = temp_file.into_temp_path(); 606 | let result = get_mount_device(Some(temp_path.as_ref())); 607 | 608 | let list_devices = result.unwrap(); 609 | assert_eq!( 610 | list_devices, 611 | vec!["/dev/sr0".to_string(), "/dev/sr1".to_string()] 612 | ); 613 | } 614 | 615 | #[test] 616 | fn test_get_mount_device_with_empty_table() { 617 | let mut temp_file = 618 | NamedTempFile::new().expect("Failed to create temporary file"); 619 | let mount_table = ""; 620 | temp_file 621 | .write_all(mount_table.as_bytes()) 622 | .expect("Failed to write to temporary file"); 623 | let temp_path = temp_file.into_temp_path(); 624 | let result = 
get_mount_device(Some(temp_path.as_ref())); 625 | 626 | let list_devices = result.unwrap(); 627 | assert!(list_devices.is_empty()); 628 | } 629 | } 630 | -------------------------------------------------------------------------------- /libazureinit/src/provision/hostname.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | use std::process::Command; 5 | 6 | use tracing::instrument; 7 | 8 | use crate::error::Error; 9 | 10 | use crate::provision::HostnameProvisioner; 11 | 12 | impl HostnameProvisioner { 13 | pub(crate) fn set(&self, hostname: impl AsRef) -> Result<(), Error> { 14 | match self { 15 | Self::Hostnamectl => hostnamectl(hostname.as_ref()), 16 | #[cfg(test)] 17 | Self::FakeHostnamectl => Ok(()), 18 | } 19 | } 20 | } 21 | 22 | #[instrument(skip_all)] 23 | fn hostnamectl(hostname: &str) -> Result<(), Error> { 24 | let path_hostnamectl = env!("PATH_HOSTNAMECTL"); 25 | 26 | let mut command = Command::new(path_hostnamectl); 27 | command.arg("set-hostname").arg(hostname); 28 | crate::run(command) 29 | } 30 | -------------------------------------------------------------------------------- /libazureinit/src/provision/mod.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | pub mod hostname; 4 | pub mod password; 5 | pub(crate) mod ssh; 6 | pub mod user; 7 | 8 | use crate::config::{ 9 | Config, HostnameProvisioner, PasswordProvisioner, UserProvisioner, 10 | }; 11 | use crate::error::Error; 12 | use crate::User; 13 | use tracing::instrument; 14 | 15 | /// The interface for applying the desired configuration to the host. 16 | /// 17 | /// By default, all known tools for provisioning a particular resource are tried 18 | /// until one succeeds. 
Particular tools can be selected via the 19 | /// `*_provisioners()` methods ([`Provision::hostname_provisioners`], 20 | /// [`Provision::user_provisioners`], etc). 21 | /// 22 | /// To actually apply the configuration, use [`Provision::provision`]. 23 | #[derive(Clone)] 24 | pub struct Provision { 25 | hostname: String, 26 | user: User, 27 | config: Config, 28 | } 29 | 30 | impl Provision { 31 | pub fn new( 32 | hostname: impl Into, 33 | user: User, 34 | config: Config, 35 | ) -> Self { 36 | Self { 37 | hostname: hostname.into(), 38 | user, 39 | config, 40 | } 41 | } 42 | 43 | /// Provisioning can fail if the host lacks the necessary tools. For example, 44 | /// if there is no useradd command on the system's PATH, or if the command 45 | /// returns an error, this will return an error. It does not attempt to undo 46 | /// partial provisioning. 47 | #[instrument(skip_all)] 48 | pub fn provision(self) -> Result<(), Error> { 49 | self.config 50 | .hostname_provisioners 51 | .backends 52 | .iter() 53 | .find_map(|backend| match backend { 54 | HostnameProvisioner::Hostnamectl => { 55 | HostnameProvisioner::Hostnamectl.set(&self.hostname).ok() 56 | } 57 | #[cfg(test)] 58 | HostnameProvisioner::FakeHostnamectl => Some(()), 59 | }) 60 | .ok_or(Error::NoHostnameProvisioner)?; 61 | 62 | self.config 63 | .user_provisioners 64 | .backends 65 | .iter() 66 | .find_map(|backend| match backend { 67 | UserProvisioner::Useradd => { 68 | UserProvisioner::Useradd.create(&self.user).ok() 69 | } 70 | #[cfg(test)] 71 | UserProvisioner::FakeUseradd => Some(()), 72 | }) 73 | .ok_or(Error::NoUserProvisioner)?; 74 | 75 | self.config 76 | .password_provisioners 77 | .backends 78 | .iter() 79 | .find_map(|backend| match backend { 80 | PasswordProvisioner::Passwd => { 81 | PasswordProvisioner::Passwd.set(&self.user).ok() 82 | } 83 | #[cfg(test)] 84 | PasswordProvisioner::FakePasswd => Some(()), 85 | }) 86 | .ok_or(Error::NoPasswordProvisioner)?; 87 | 88 | if !self.user.ssh_keys.is_empty() { 89 
| let authorized_keys_path = self.config.ssh.authorized_keys_path; 90 | let query_sshd_config = self.config.ssh.query_sshd_config; 91 | 92 | let user = nix::unistd::User::from_name(&self.user.name)?.ok_or( 93 | Error::UserMissing { 94 | user: self.user.name, 95 | }, 96 | )?; 97 | ssh::provision_ssh( 98 | &user, 99 | &self.user.ssh_keys, 100 | authorized_keys_path, 101 | query_sshd_config, 102 | )?; 103 | } 104 | 105 | Ok(()) 106 | } 107 | } 108 | 109 | #[cfg(test)] 110 | mod tests { 111 | use super::{Config, Provision}; 112 | use crate::config::{ 113 | HostnameProvisioner, PasswordProvisioner, UserProvisioner, 114 | }; 115 | use crate::config::{ 116 | HostnameProvisioners, PasswordProvisioners, UserProvisioners, 117 | }; 118 | use crate::User; 119 | 120 | #[test] 121 | fn test_successful_provision() { 122 | let mock_config = Config { 123 | hostname_provisioners: HostnameProvisioners { 124 | backends: vec![HostnameProvisioner::FakeHostnamectl], 125 | }, 126 | user_provisioners: UserProvisioners { 127 | backends: vec![UserProvisioner::FakeUseradd], 128 | }, 129 | password_provisioners: PasswordProvisioners { 130 | backends: vec![PasswordProvisioner::FakePasswd], 131 | }, 132 | ..Config::default() 133 | }; 134 | 135 | let _p = Provision::new( 136 | "my-hostname".to_string(), 137 | User::new("azureuser", vec![]), 138 | mock_config, 139 | ) 140 | .provision() 141 | .unwrap(); 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /libazureinit/src/provision/password.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 
3 | 4 | use std::process::Command; 5 | 6 | use std::path::PathBuf; 7 | 8 | use tracing::instrument; 9 | 10 | use crate::{error::Error, User}; 11 | 12 | use super::ssh::update_sshd_config; 13 | 14 | use crate::provision::PasswordProvisioner; 15 | 16 | impl PasswordProvisioner { 17 | pub(crate) fn set(&self, user: &User) -> Result<(), Error> { 18 | match self { 19 | Self::Passwd => passwd(user), 20 | #[cfg(test)] 21 | Self::FakePasswd => Ok(()), 22 | } 23 | } 24 | } 25 | 26 | // Determines the appropriate SSH configuration file path based on the filesystem. 27 | // If the "/etc/ssh/sshd_config.d" directory exists, it returns the path for a drop-in configuration file. 28 | // Otherwise, it defaults to the main SSH configuration file at "/etc/ssh/sshd_config". 29 | fn get_sshd_config_path() -> &'static str { 30 | if PathBuf::from("/etc/ssh/sshd_config.d").is_dir() { 31 | "/etc/ssh/sshd_config.d/50-azure-init.conf" 32 | } else { 33 | "/etc/ssh/sshd_config" 34 | } 35 | } 36 | 37 | #[instrument(skip_all)] 38 | fn passwd(user: &User) -> Result<(), Error> { 39 | // Update the sshd configuration to allow password authentication. 40 | let sshd_config_path = get_sshd_config_path(); 41 | if let Err(error) = update_sshd_config(sshd_config_path) { 42 | tracing::error!( 43 | ?error, 44 | sshd_config_path, 45 | "Failed to update sshd configuration for password authentication" 46 | ); 47 | return Err(Error::UpdateSshdConfig); 48 | } 49 | let path_passwd = env!("PATH_PASSWD"); 50 | 51 | if user.password.is_none() { 52 | let mut command = Command::new(path_passwd); 53 | command.arg("-d").arg(&user.name); 54 | crate::run(command)?; 55 | } else { 56 | // creating user with a non-empty password is not allowed. 
57 | return Err(Error::NonEmptyPassword); 58 | } 59 | 60 | Ok(()) 61 | } 62 | -------------------------------------------------------------------------------- /libazureinit/src/provision/ssh.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | //! This module provides functionality for provisioning SSH keys for a user. 5 | //! 6 | //! It includes functions to create the necessary `.ssh` directory, set the appropriate 7 | //! permissions, and write the provided public keys to the `authorized_keys` file. 8 | 9 | use crate::error::Error; 10 | use crate::imds::PublicKeys; 11 | use lazy_static::lazy_static; 12 | use nix::unistd::{chown, User}; 13 | use regex::Regex; 14 | use std::{ 15 | fs::{ 16 | OpenOptions, {File, Permissions}, 17 | }, 18 | io::{self, Read, Write}, 19 | os::unix::fs::{DirBuilderExt, PermissionsExt}, 20 | path::PathBuf, 21 | process::{Command, Output}, 22 | }; 23 | use tracing::{error, info, instrument}; 24 | 25 | lazy_static! { 26 | /// A regular expression to match the `PasswordAuthentication` setting in the SSH configuration. 27 | static ref PASSWORD_REGEX: Regex = Regex::new( 28 | r"(?m)^\s*#?\s*PasswordAuthentication\s+(yes|no)\s*$" 29 | ) 30 | .expect( 31 | "The regular expression is invalid or exceeds the default regex size" 32 | ); 33 | } 34 | 35 | /// Provisions SSH keys for the specified user. 36 | /// 37 | /// Creates the `.ssh` directory in the user's home directory, sets the appropriate 38 | /// permissions, and writes the provided public keys to the `authorized_keys` file. 39 | /// 40 | /// # Arguments 41 | /// 42 | /// * `user` - A reference to the user for whom the SSH keys are being provisioned. 43 | /// * `keys` - A slice of `PublicKeys` to be added to the `authorized_keys` file. 44 | /// * `authorized_keys_path_string` - An optional string specifying the path to the `authorized_keys` file. 
45 | /// 46 | /// # Returns 47 | /// 48 | /// This function returns `Result<(), Error>` indicating success or failure. 49 | /// 50 | /// # Errors 51 | /// 52 | /// This function will return an error if it fails to create the `.ssh` directory, set permissions, 53 | /// or write to the `authorized_keys` file. 54 | #[instrument(skip_all, name = "ssh")] 55 | pub(crate) fn provision_ssh( 56 | user: &User, 57 | keys: &[PublicKeys], 58 | authorized_keys_path: PathBuf, 59 | query_sshd_config: bool, 60 | ) -> Result<(), Error> { 61 | let authorized_keys_path = if query_sshd_config { 62 | tracing::info!( 63 | "Attempting to get authorized keys path via sshd -G as configured." 64 | ); 65 | 66 | match get_authorized_keys_path_from_sshd(|| { 67 | Command::new("sshd").arg("-G").output() 68 | }) { 69 | Some(path) => user.dir.join(path), 70 | None => { 71 | tracing::warn!("sshd -G failed; using configured authorized_keys_path as fallback."); 72 | user.dir.join(authorized_keys_path) 73 | } 74 | } 75 | } else { 76 | user.dir.join(authorized_keys_path) 77 | }; 78 | 79 | let ssh_dir = user.dir.join(".ssh"); 80 | std::fs::DirBuilder::new() 81 | .recursive(true) 82 | .mode(0o700) 83 | .create(&ssh_dir)?; 84 | std::fs::set_permissions(&ssh_dir, Permissions::from_mode(0o700))?; 85 | 86 | chown(&ssh_dir, Some(user.uid), Some(user.gid))?; 87 | 88 | tracing::info!( 89 | target: "libazureinit::ssh::authorized_keys", 90 | "Using authorized_keys path: {:?}", 91 | authorized_keys_path 92 | ); 93 | 94 | let mut authorized_keys = File::create(&authorized_keys_path)?; 95 | authorized_keys.set_permissions(Permissions::from_mode(0o600))?; 96 | 97 | keys.iter() 98 | .try_for_each(|key| writeln!(authorized_keys, "{}", key.key_data))?; 99 | 100 | chown(&authorized_keys_path, Some(user.uid), Some(user.gid))?; 101 | 102 | Ok(()) 103 | } 104 | 105 | /// Retrieves the path to the `authorized_keys` file from the SSH daemon configuration. 
106 | /// 107 | /// Runs the SSH daemon to get the configuration and extracts 108 | /// the `AuthorizedKeysFile` setting. 109 | /// 110 | /// # Arguments 111 | /// 112 | /// * `sshd_config_command_runner` - A function that runs the SSH daemon command and returns its output. 113 | /// 114 | /// # Returns 115 | /// 116 | /// This function returns a path to the `authorized_keys` file if found, 117 | /// or `None` if the setting is not found. 118 | fn get_authorized_keys_path_from_sshd( 119 | sshd_config_command_runner: impl Fn() -> io::Result, 120 | ) -> Option { 121 | let output = run_sshd_command(sshd_config_command_runner)?; 122 | 123 | let path = extract_authorized_keys_file_path(&output.stdout); 124 | if path.is_none() { 125 | error!("No authorizedkeysfile setting found in sshd configuration"); 126 | } 127 | path 128 | } 129 | 130 | /// Runs the SSH daemon command to get its configuration. 131 | /// 132 | /// # Arguments 133 | /// 134 | /// * `sshd_config_command_runner` - A function that runs the SSH daemon command and returns its output. 135 | /// 136 | /// # Returns 137 | /// 138 | /// This function returns an output of the command. 
139 | fn run_sshd_command( 140 | sshd_config_command_runner: impl Fn() -> io::Result, 141 | ) -> Option { 142 | match sshd_config_command_runner() { 143 | Ok(output) if output.status.success() => { 144 | info!( 145 | target: "libazureinit::ssh::success", 146 | stdout_length = output.stdout.len(), 147 | "Executed sshd -G successfully", 148 | ); 149 | Some(output) 150 | } 151 | Ok(output) => { 152 | let stdout = String::from_utf8_lossy(&output.stdout); 153 | let stderr = String::from_utf8_lossy(&output.stderr); 154 | error!( 155 | code=output.status.code().unwrap_or(-1), 156 | stdout=%stdout, 157 | stderr=%stderr, 158 | "Failed to execute sshd -G, assuming sshd configuration defaults" 159 | ); 160 | None 161 | } 162 | Err(e) => { 163 | error!( 164 | error=%e, 165 | "Failed to execute sshd -G, assuming sshd configuration defaults", 166 | ); 167 | None 168 | } 169 | } 170 | } 171 | 172 | /// Extracts the `AuthorizedKeysFile` path from the SSH daemon configuration output. 173 | /// 174 | /// Parses the output of the SSH daemon configuration command and extracts the 175 | /// `AuthorizedKeysFile` setting. 176 | /// 177 | /// # Arguments 178 | /// 179 | /// * `sshd_config_output` - A byte slice containing the output of the SSH daemon configuration command. 180 | /// 181 | /// # Returns 182 | /// 183 | /// This function returns an `Option` containing the path to the `authorized_keys` file if found, 184 | /// or `None` if the setting is not found. 
185 | fn extract_authorized_keys_file_path(stdout: &[u8]) -> Option { 186 | let output = String::from_utf8_lossy(stdout); 187 | for line in output.lines() { 188 | if line.starts_with("authorizedkeysfile") { 189 | let keypath = line.split_whitespace().nth(1).map(|s| { 190 | info!( 191 | target: "libazureinit::ssh::authorized_keys", 192 | authorizedkeysfile = %s, 193 | "Using sshd's authorizedkeysfile path configuration" 194 | ); 195 | s.to_string() 196 | }); 197 | if keypath.is_some() { 198 | return keypath; 199 | } 200 | } 201 | } 202 | None 203 | } 204 | 205 | /// Updates the SSH daemon configuration to ensure `PasswordAuthentication` is set to `yes`. 206 | /// 207 | /// Checks if the `sshd_config` file exists and updates the `PasswordAuthentication` 208 | /// setting to `yes`. If the file does not exist, it creates a new one with the appropriate setting. 209 | /// 210 | /// # Arguments 211 | /// 212 | /// * `sshd_config_path` - A string slice containing the path to the `sshd_config` file. 213 | /// 214 | /// # Returns 215 | /// 216 | /// This function returns `Result<(), io::Error>` indicating success or failure. 217 | /// 218 | /// # Errors 219 | /// 220 | /// This function will return an error if it fails to read, write, or create the `sshd_config` file. 
221 | pub(crate) fn update_sshd_config( 222 | sshd_config_path: &str, 223 | ) -> Result<(), io::Error> { 224 | // Check if the path exists otherwise create it 225 | let sshd_config_path = PathBuf::from(sshd_config_path); 226 | if !sshd_config_path.exists() { 227 | let mut file = std::fs::File::create(&sshd_config_path)?; 228 | file.set_permissions(Permissions::from_mode(0o600))?; 229 | file.write_all(b"PasswordAuthentication yes\n")?; 230 | tracing::info!( 231 | ?sshd_config_path, 232 | "Created new sshd drop-in configuration file" 233 | ); 234 | return Ok(()); 235 | } 236 | 237 | let mut file_content = String::new(); 238 | { 239 | let mut file = OpenOptions::new().read(true).open(&sshd_config_path)?; 240 | file.read_to_string(&mut file_content)?; 241 | } 242 | 243 | let re = &PASSWORD_REGEX; 244 | if re.is_match(&file_content) { 245 | let modified_content = re.replace_all( 246 | &file_content, 247 | "PasswordAuthentication yes # modified by azure-init\n", 248 | ); 249 | 250 | let mut sshd_config = OpenOptions::new() 251 | .write(true) 252 | .truncate(true) 253 | .open(&sshd_config_path)?; 254 | sshd_config.write_all(modified_content.as_bytes())?; 255 | 256 | tracing::info!( 257 | ?sshd_config_path, 258 | "Updated existing sshd setting to allow password authentication" 259 | ); 260 | } else { 261 | let mut file = 262 | OpenOptions::new().append(true).open(&sshd_config_path)?; 263 | file.write_all(b"PasswordAuthentication yes # added by azure-init\n")?; 264 | 265 | tracing::info!( 266 | ?sshd_config_path, 267 | "Added new sshd setting to allow password authentication" 268 | ); 269 | } 270 | 271 | Ok(()) 272 | } 273 | 274 | #[cfg(test)] 275 | mod tests { 276 | use crate::imds::PublicKeys; 277 | use crate::provision::ssh::{ 278 | extract_authorized_keys_file_path, get_authorized_keys_path_from_sshd, 279 | provision_ssh, run_sshd_command, update_sshd_config, 280 | }; 281 | use std::{ 282 | fs::{File, Permissions}, 283 | io::{self, Read, Write}, 284 | 
os::unix::fs::{DirBuilderExt, PermissionsExt}, 285 | os::unix::process::ExitStatusExt, 286 | process::{ExitStatus, Output}, 287 | }; 288 | use tempfile::TempDir; 289 | 290 | fn create_output(status_code: i32, stdout: &str, stderr: &str) -> Output { 291 | Output { 292 | status: ExitStatus::from_raw(status_code), 293 | stdout: stdout.as_bytes().to_vec(), 294 | stderr: stderr.as_bytes().to_vec(), 295 | } 296 | } 297 | 298 | fn get_test_user_with_home_dir(create_ssh_dir: bool) -> nix::unistd::User { 299 | let home_dir = 300 | tempfile::TempDir::new().expect("Failed to create temp directory"); 301 | 302 | let mut user = 303 | nix::unistd::User::from_name(whoami::username().as_str()) 304 | .expect("Failed to get user") 305 | .expect("User does not exist"); 306 | user.dir = home_dir.path().into(); 307 | 308 | if create_ssh_dir { 309 | std::fs::DirBuilder::new() 310 | .mode(0o700) 311 | .create(user.dir.join(".ssh")) 312 | .expect("Failed to create .ssh directory"); 313 | } 314 | 315 | user 316 | } 317 | 318 | #[test] 319 | fn test_run_sshd_command_success() { 320 | let expected_stdout = "authorizedkeysfile .ssh/test_authorized_keys"; 321 | let mock_runner = 322 | || Ok(create_output(0, expected_stdout, "some stderr")); 323 | 324 | let result = run_sshd_command(mock_runner); 325 | assert!(result.is_some()); 326 | assert_eq!( 327 | String::from_utf8_lossy(&result.unwrap().stdout), 328 | expected_stdout 329 | ); 330 | } 331 | 332 | #[test] 333 | fn test_run_sshd_command_failure() { 334 | let stdout = "authorizedkeysfile .ssh/test_authorized_keys"; 335 | let mock_runner = 336 | || Ok(create_output(1, stdout, "Error running sshd -G")); 337 | 338 | let result = run_sshd_command(mock_runner); 339 | assert!(result.is_none()); 340 | } 341 | 342 | #[test] 343 | fn test_run_sshd_command_error() { 344 | let mock_runner = || { 345 | Err(io::Error::new(io::ErrorKind::NotFound, "command not found")) 346 | }; 347 | 348 | let result = run_sshd_command(mock_runner); 349 | 
assert!(result.is_none()); 350 | } 351 | 352 | #[test] 353 | fn test_get_authorized_keys_path_from_sshd_success() { 354 | let test_cases = vec![ 355 | ( 356 | "authorizedkeysfile .ssh/authorized_keys", 357 | Some(".ssh/authorized_keys"), 358 | ), 359 | ( 360 | "authorizedkeysfile .ssh/other_authorized_keys", 361 | Some(".ssh/other_authorized_keys"), 362 | ), 363 | ( 364 | "authorizedkeysfile /custom/path/to/keys", 365 | Some("/custom/path/to/keys"), 366 | ), 367 | ("# No authorizedkeysfile line here", None), // Case with no match 368 | ]; 369 | 370 | for (stdout, expected_path) in test_cases { 371 | let mock_runner = || Ok(create_output(0, stdout, "some stderr")); 372 | 373 | let result: Option = run_sshd_command(mock_runner); 374 | assert!(result.is_some(), "Expected a successful command output"); 375 | 376 | let output: Output = result.unwrap(); 377 | let stdout_str = String::from_utf8_lossy(&output.stdout); 378 | assert_eq!(stdout_str, stdout); 379 | 380 | let extracted_path: Option = 381 | extract_authorized_keys_file_path(&output.stdout); 382 | assert_eq!( 383 | extracted_path, 384 | expected_path.map(|s| s.to_string()), 385 | "Expected path extraction to match for stdout: {}", 386 | stdout 387 | ); 388 | } 389 | } 390 | 391 | #[test] 392 | fn test_get_authorized_keys_path_from_sshd_no_authorized_keys() { 393 | let mock_runner = 394 | || Ok(create_output(0, "no authorizedkeysfile here", "")); 395 | 396 | let result = get_authorized_keys_path_from_sshd(mock_runner); 397 | assert!(result.is_none()); 398 | } 399 | 400 | #[test] 401 | fn test_get_authorized_keys_path_from_sshd_command_fails() { 402 | // Mock sshd command runner that simulates a failed command execution 403 | let mock_runner = 404 | || Err(io::Error::new(io::ErrorKind::Other, "command error")); 405 | 406 | let result = get_authorized_keys_path_from_sshd(mock_runner); 407 | assert!(result.is_none()); 408 | } 409 | 410 | #[test] 411 | fn test_extract_authorized_keys_file_path_valid() { 412 | let 
stdout = b"authorizedkeysfile .ssh/test_authorized_keys\n"; 413 | let result = extract_authorized_keys_file_path(stdout); 414 | assert_eq!(result, Some(".ssh/test_authorized_keys".to_string())); 415 | } 416 | 417 | #[test] 418 | fn test_extract_authorized_keys_file_path_invalid() { 419 | let stdout = b"some irrelevant output\n"; 420 | let result = extract_authorized_keys_file_path(stdout); 421 | assert!(result.is_none()); 422 | } 423 | 424 | // Test that we set the permission bits correctly on the ssh files; sadly it's difficult to test 425 | // chown without elevated permissions. 426 | #[test] 427 | fn test_provision_ssh() { 428 | let user = get_test_user_with_home_dir(false); 429 | let keys = vec![ 430 | PublicKeys { 431 | key_data: "not-a-real-key abc123".to_string(), 432 | path: "unused".to_string(), 433 | }, 434 | PublicKeys { 435 | key_data: "not-a-real-key xyz987".to_string(), 436 | path: "unused".to_string(), 437 | }, 438 | ]; 439 | 440 | let authorized_keys_path = user.dir.join(".ssh/xauthorized_keys"); 441 | 442 | provision_ssh(&user, &keys, authorized_keys_path, false).unwrap(); 443 | 444 | let ssh_path = user.dir.join(".ssh"); 445 | let ssh_dir = std::fs::File::open(&ssh_path).unwrap(); 446 | let mut auth_file = 447 | std::fs::File::open(&ssh_path.join("xauthorized_keys")).unwrap(); 448 | let mut buf = String::new(); 449 | auth_file.read_to_string(&mut buf).unwrap(); 450 | 451 | assert_eq!("not-a-real-key abc123\nnot-a-real-key xyz987\n", buf); 452 | // Refer to man 7 inode for details on the mode - 100000 is a regular file, 040000 is a directory 453 | assert_eq!( 454 | ssh_dir.metadata().unwrap().permissions(), 455 | Permissions::from_mode(0o040700) 456 | ); 457 | assert_eq!( 458 | auth_file.metadata().unwrap().permissions(), 459 | Permissions::from_mode(0o100600) 460 | ); 461 | } 462 | 463 | // Test that if the .ssh directory already exists, we handle it gracefully. This can occur if, for example, 464 | // /etc/skel includes it. 
This also checks that we fix the permissions if /etc/skel has been mis-configured. 465 | #[test] 466 | fn test_pre_existing_ssh_dir() { 467 | let user = get_test_user_with_home_dir(true); 468 | let keys = vec![ 469 | PublicKeys { 470 | key_data: "not-a-real-key abc123".to_string(), 471 | path: "unused".to_string(), 472 | }, 473 | PublicKeys { 474 | key_data: "not-a-real-key xyz987".to_string(), 475 | path: "unused".to_string(), 476 | }, 477 | ]; 478 | 479 | let authorized_keys_path = user.dir.join(".ssh/xauthorized_keys"); 480 | 481 | provision_ssh(&user, &keys, authorized_keys_path, false).unwrap(); 482 | 483 | let ssh_dir = std::fs::File::open(user.dir.join(".ssh")).unwrap(); 484 | assert_eq!( 485 | ssh_dir.metadata().unwrap().permissions(), 486 | Permissions::from_mode(0o040700) 487 | ); 488 | } 489 | 490 | // Test that any pre-existing authorized_keys are overwritten. 491 | #[test] 492 | fn test_pre_existing_authorized_keys() { 493 | let user = get_test_user_with_home_dir(true); 494 | let keys = vec![ 495 | PublicKeys { 496 | key_data: "not-a-real-key abc123".to_string(), 497 | path: "unused".to_string(), 498 | }, 499 | PublicKeys { 500 | key_data: "not-a-real-key xyz987".to_string(), 501 | path: "unused".to_string(), 502 | }, 503 | ]; 504 | 505 | let authorized_keys_path = user.dir.join(".ssh/xauthorized_keys"); 506 | 507 | provision_ssh(&user, &keys[1..], authorized_keys_path.clone(), false) 508 | .unwrap(); 509 | 510 | provision_ssh(&user, &keys[1..], authorized_keys_path.clone(), false) 511 | .unwrap(); 512 | 513 | let mut auth_file = 514 | std::fs::File::open(user.dir.join(".ssh/xauthorized_keys")) 515 | .unwrap(); 516 | let mut buf = String::new(); 517 | auth_file.read_to_string(&mut buf).unwrap(); 518 | 519 | assert_eq!("not-a-real-key xyz987\n", buf); 520 | } 521 | 522 | #[test] 523 | fn test_update_sshd_config_create_new() -> io::Result<()> { 524 | let temp_dir = TempDir::new().unwrap(); 525 | let sshd_config_path = temp_dir.path().join("sshd_config"); 
526 | let ret: Result<(), io::Error> = 527 | update_sshd_config(sshd_config_path.to_str().unwrap()); 528 | assert!(ret.is_ok()); 529 | 530 | let mut updated_content = String::new(); 531 | let mut file = File::open(&sshd_config_path).unwrap(); 532 | file.read_to_string(&mut updated_content).unwrap(); 533 | assert!(updated_content.contains("PasswordAuthentication yes")); 534 | Ok(()) 535 | } 536 | 537 | #[test] 538 | fn test_update_sshd_config_change() -> io::Result<()> { 539 | let temp_dir = TempDir::new()?; 540 | let sshd_config_path = temp_dir.path().join("sshd_config"); 541 | { 542 | let mut file = File::create(&sshd_config_path)?; 543 | writeln!(file, "PasswordAuthentication no")?; 544 | } 545 | 546 | let ret: Result<(), io::Error> = 547 | update_sshd_config(sshd_config_path.to_str().unwrap()); 548 | assert!(ret.is_ok()); 549 | let mut updated_content = String::new(); 550 | { 551 | let mut file = File::open(&sshd_config_path)?; 552 | file.read_to_string(&mut updated_content)?; 553 | } 554 | assert!(updated_content.contains("PasswordAuthentication yes")); 555 | assert!(!updated_content.contains("PasswordAuthentication no")); 556 | 557 | Ok(()) 558 | } 559 | 560 | #[test] 561 | fn test_update_sshd_config_no_change() -> io::Result<()> { 562 | let temp_dir = TempDir::new()?; 563 | let sshd_config_path = temp_dir.path().join("sshd_config"); 564 | { 565 | let mut file = File::create(&sshd_config_path)?; 566 | writeln!(file, "PasswordAuthentication yes")?; 567 | } 568 | let ret: Result<(), io::Error> = 569 | update_sshd_config(sshd_config_path.to_str().unwrap()); 570 | assert!(ret.is_ok()); 571 | let mut updated_content = String::new(); 572 | { 573 | let mut file = File::open(&sshd_config_path)?; 574 | file.read_to_string(&mut updated_content)?; 575 | } 576 | assert!(updated_content.contains("PasswordAuthentication yes")); 577 | assert!(!updated_content.contains("PasswordAuthentication no")); 578 | 579 | Ok(()) 580 | } 581 | } 582 | 
-------------------------------------------------------------------------------- /libazureinit/src/provision/user.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | use std::{os::unix::fs::OpenOptionsExt, process::Command}; 5 | 6 | use std::io::Write; 7 | 8 | use tracing::instrument; 9 | 10 | use crate::{error::Error, imds::PublicKeys}; 11 | 12 | use crate::config::UserProvisioner; 13 | 14 | /// The user and its related configuration to create on the host. 15 | /// 16 | /// A bare minimum user includes a name and a set of SSH public keys to allow 17 | /// the user to log into the host. Additional configuration includes a set of 18 | /// supplementary groups to add the user to, and a password to set for the user. 19 | /// 20 | /// By default, the user is not included in any group. To grant administrator 21 | /// privileges via the `sudo` command, additional groups like "wheel" can be 22 | /// added with the [`User::with_groups`] method. 23 | /// 24 | /// # Example 25 | /// 26 | /// ``` 27 | /// # use libazureinit::User; 28 | /// let user = User::new("azure-user", ["ssh-ed25519 NOTAREALKEY".into()]) 29 | /// .with_groups(["wheel".to_string(), "dialout".to_string()]); 30 | /// ``` 31 | /// 32 | /// The [`useradd`] and [`user_exists`] functions handle the creation and 33 | /// management of system users, including group assignments. These functions 34 | /// ensure that the specified user is correctly set up with the appropriate 35 | /// group memberships, whether they are newly created or already exist on the 36 | /// system. 37 | /// 38 | /// - **User Creation:** 39 | /// - If the user does not already exist, it is created with the specified 40 | /// groups. 41 | /// - **Existing User:** 42 | /// - If the user already exists and belongs to the specified groups, no 43 | /// changes are made, and the function exits. 
44 | /// - If the user exists but does not belong to one or more of the specified 45 | /// groups, the user will be added to those groups using the `usermod -aG` 46 | /// command. 47 | /// - **Group Management:** 48 | /// - The `usermod -aG` command is used to add the user to the specified 49 | /// groups without removing them from any existing groups. 50 | /// 51 | /// # Examples 52 | /// 53 | /// ``` 54 | /// # use libazureinit::User; 55 | /// let user = User::new("azureuser", vec![]).with_groups(["wheel".to_string()]); 56 | /// let user_with_new_group = User::new("azureuser", vec![]).with_groups(["adm".to_string()]); 57 | /// ``` 58 | #[derive(Clone)] 59 | pub struct User { 60 | pub(crate) name: String, 61 | pub(crate) groups: Vec, 62 | pub(crate) ssh_keys: Vec, 63 | pub(crate) password: Option, 64 | } 65 | 66 | impl core::fmt::Debug for User { 67 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 68 | // This is manually implemented to avoid printing the password if it's set 69 | f.debug_struct("User") 70 | .field("name", &self.name) 71 | .field("groups", &self.groups) 72 | .field("ssh_keys", &self.ssh_keys) 73 | .field("password", &self.password.is_some()) 74 | .finish() 75 | } 76 | } 77 | 78 | impl User { 79 | /// Configure the user being provisioned on the host. 80 | /// 81 | /// What constitutes a valid username depends on the host configuration and 82 | /// no validation will occur prior to provisioning the host. 83 | pub fn new( 84 | name: impl Into, 85 | ssh_keys: impl Into>, 86 | ) -> Self { 87 | Self { 88 | name: name.into(), 89 | groups: vec![], 90 | ssh_keys: ssh_keys.into(), 91 | password: None, 92 | } 93 | } 94 | 95 | /// Set a password for the user; this is optional. 96 | pub fn with_password(mut self, password: impl Into) -> Self { 97 | self.password = Some(password.into()); 98 | self 99 | } 100 | 101 | /// A list of supplemental group names to add the user to. 
102 | /// 103 | /// If any of the groups do not exist on the host, provisioning will fail. 104 | pub fn with_groups(mut self, groups: impl Into>) -> Self { 105 | self.groups = groups.into(); 106 | self 107 | } 108 | } 109 | 110 | impl UserProvisioner { 111 | pub(crate) fn create(&self, user: &User) -> Result<(), Error> { 112 | match self { 113 | Self::Useradd => { 114 | useradd(user)?; 115 | let path = "/etc/sudoers.d/azure-init-user"; 116 | add_user_for_passwordless_sudo(user.name.as_str(), path) 117 | } 118 | #[cfg(test)] 119 | Self::FakeUseradd => Ok(()), 120 | } 121 | } 122 | } 123 | 124 | #[instrument(skip_all)] 125 | fn user_exists(username: &str) -> Result { 126 | let output = Command::new("getent") 127 | .arg("passwd") 128 | .arg(username) 129 | .output()?; 130 | 131 | Ok(output.status.success()) 132 | } 133 | 134 | #[instrument(skip_all)] 135 | fn useradd(user: &User) -> Result<(), Error> { 136 | if user_exists(&user.name)? { 137 | tracing::info!( 138 | target: "libazureinit::user::add", 139 | "User '{}' already exists. 
Skipping user creation.", 140 | user.name 141 | ); 142 | 143 | let group_list = user.groups.join(","); 144 | 145 | tracing::info!( 146 | target: "libazureinit:user::add", 147 | "User '{}' is being added to the following groups: {}", 148 | user.name, 149 | group_list 150 | ); 151 | 152 | let mut command = Command::new("usermod"); 153 | command.arg("-aG").arg(&group_list).arg(&user.name); 154 | return crate::run(command); 155 | } 156 | 157 | let path_useradd = env!("PATH_USERADD"); 158 | 159 | let mut command = Command::new(path_useradd); 160 | command 161 | .arg(&user.name) 162 | .arg("--comment") 163 | .arg("azure-init created this user based on username provided in IMDS") 164 | .arg("--groups") 165 | .arg(user.groups.join(",")) 166 | .arg("-d") 167 | .arg(format!("/home/{}", user.name)) 168 | .arg("-m"); 169 | crate::run(command) 170 | } 171 | 172 | fn add_user_for_passwordless_sudo( 173 | username: &str, 174 | path: &str, 175 | ) -> Result<(), Error> { 176 | // Create a file under /etc/sudoers.d with azure-init-user 177 | let mut sudoers_file = std::fs::OpenOptions::new() 178 | .write(true) 179 | .create(true) 180 | .truncate(true) 181 | .mode(0o600) 182 | .open(path)?; 183 | 184 | writeln!(sudoers_file, "{} ALL=(ALL) NOPASSWD: ALL", username)?; 185 | sudoers_file.flush()?; 186 | Ok(()) 187 | } 188 | 189 | #[cfg(test)] 190 | mod tests { 191 | use std::{fs, os::unix::fs::PermissionsExt}; 192 | use tempfile::tempdir; 193 | 194 | use crate::User; 195 | 196 | use super::add_user_for_passwordless_sudo; 197 | 198 | #[test] 199 | fn password_skipped_in_debug() { 200 | let user_with_password = 201 | User::new("azureuser", []).with_password("hunter2"); 202 | let user_without_password = User::new("azureuser", []); 203 | 204 | assert_eq!( 205 | "User { name: \"azureuser\", groups: [], ssh_keys: [], password: true }", 206 | format!("{:?}", user_with_password) 207 | ); 208 | assert_eq!( 209 | "User { name: \"azureuser\", groups: [], ssh_keys: [], password: false }", 210 | 
format!("{:?}", user_without_password) 211 | ); 212 | } 213 | 214 | #[test] 215 | fn test_passwordless_sudo_configured_successful() { 216 | let dir = tempdir().unwrap(); 217 | let path = dir.path().join("sudoers_file"); 218 | let path_str = path.to_str().unwrap(); 219 | 220 | let _user_insecure = User::new("azureuser", []); 221 | let ret = 222 | add_user_for_passwordless_sudo(&_user_insecure.name, path_str); 223 | 224 | assert!(ret.is_ok()); 225 | assert!( 226 | fs::metadata(path.clone()).is_ok(), 227 | "{path_str} file not created" 228 | ); 229 | let mode = fs::metadata(path_str) 230 | .expect("Sudoer file not created") 231 | .permissions() 232 | .mode(); 233 | assert_eq!(mode & 0o777, 0o600, "Permissions are not set properly"); 234 | assert_eq!( 235 | fs::read_to_string(path).unwrap(), 236 | "azureuser ALL=(ALL) NOPASSWD: ALL\n", 237 | "Contents of the file are not as expected" 238 | ); 239 | } 240 | } 241 | -------------------------------------------------------------------------------- /libazureinit/src/status.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | //! Provisioning status management for azure-init. 5 | //! 6 | //! This module ensures that provisioning is performed only when necessary 7 | //! by tracking the VM ID. It stores a provisioning status file named after 8 | //! the VM ID in a persistent location (`/var/lib/azure-init/`). 9 | //! 10 | //! # Logic Overview 11 | //! - Retrieves the VM ID using reading `/sys/class/dmi/id/product_uuid` and byte-swapping if Gen1 VM. 12 | //! - Determines if provisioning is required by checking if a status file exists. 13 | //! - The azure-init data directory is configurable via the Config struct (defaulting to `/var/lib/azure-init/`). 14 | //! - Creates the provisioning status file upon successful provisioning. 15 | //! 
- Prevents unnecessary re-provisioning on reboot, unless the VM ID changes. 16 | //! 17 | //! # Behavior 18 | //! - On **first boot**, provisioning runs and a file is created: `/var/lib/azure-init/{vm_id}` 19 | //! - On **reboot**, if the same VM ID exists, provisioning is skipped. 20 | //! - If the **VM ID changes** (e.g., due to VM cloning), provisioning runs again. 21 | 22 | use std::fs::{self, OpenOptions}; 23 | use std::os::unix::fs::{OpenOptionsExt, PermissionsExt}; 24 | use std::path::{Path, PathBuf}; 25 | use uuid::Uuid; 26 | 27 | use crate::config::{Config, DEFAULT_AZURE_INIT_DATA_DIR}; 28 | use crate::error::Error; 29 | 30 | /// This function determines the effective provisioning directory. 31 | /// 32 | /// If a [`Config`] is provided, this function returns `config.azure_init_data_dir.path`. 33 | /// Otherwise, it falls back to the default `/var/lib/azure-init/`. 34 | fn get_provisioning_dir(config: Option<&Config>) -> PathBuf { 35 | config 36 | .map(|cfg| cfg.azure_init_data_dir.path.clone()) 37 | .unwrap_or_else(|| PathBuf::from(DEFAULT_AZURE_INIT_DATA_DIR)) 38 | } 39 | 40 | /// This function checks if the azure-init data directory is present, and if not, 41 | /// it creates it. 42 | fn check_provision_dir(config: Option<&Config>) -> Result<(), Error> { 43 | let dir = get_provisioning_dir(config); 44 | if !dir.exists() { 45 | fs::create_dir_all(&dir)?; 46 | tracing::info!("Created provisioning directory: {}", dir.display()); 47 | 48 | if let Err(e) = 49 | fs::set_permissions(&dir, fs::Permissions::from_mode(0o700)) 50 | { 51 | tracing::warn!( 52 | "Failed to set permissions on {}: {}", 53 | dir.display(), 54 | e 55 | ); 56 | } else { 57 | tracing::info!( 58 | "Set secure permissions (700) on provisioning directory: {}", 59 | dir.display() 60 | ); 61 | } 62 | } 63 | 64 | Ok(()) 65 | } 66 | 67 | /// Determines if VM is a gen1 or gen2 based on EFI detection, 68 | /// Returns `true` if it is a Gen1 VM (i.e., not UEFI/Gen2). 
69 | /// 70 | /// # Parameters: 71 | /// * `sysfs_efi_path` - An optional override for the default EFI path (`/sys/firmware/efi`). 72 | /// * `dev_efi_path` - An optional override for the default EFI device path (`/dev/efi`). 73 | /// 74 | /// If both parameters are `None`, the function checks the real system paths: 75 | /// `/sys/firmware/efi` and `/dev/efi`. 76 | fn is_vm_gen1( 77 | sysfs_efi_path: Option<&str>, 78 | dev_efi_path: Option<&str>, 79 | ) -> bool { 80 | let sysfs_efi = sysfs_efi_path.unwrap_or("/sys/firmware/efi"); 81 | let dev_efi = dev_efi_path.unwrap_or("/dev/efi"); 82 | 83 | // If *either* efi path exists, this is Gen2; if *neither* exist, Gen1 84 | // (equivalent to `!(exists(sysfs_efi) || exists(dev_efi))`) 85 | !Path::new(sysfs_efi).exists() && !Path::new(dev_efi).exists() 86 | } 87 | 88 | /// Converts the first three fields of a 16-byte array from big-endian to 89 | /// the native endianness, then returns it as a `Uuid`. 90 | /// 91 | /// This partially swaps the UUID so that d1 (4 bytes), d2 (2 bytes), and d3 (2 bytes) 92 | /// are converted from big-endian to the local endianness, leaving the final 8 bytes as-is. 
93 | fn swap_uuid_to_little_endian(mut bytes: [u8; 16]) -> Uuid { 94 | let (d1, remainder) = bytes.split_at(std::mem::size_of::()); 95 | let d1 = d1 96 | .try_into() 97 | .map(u32::from_be_bytes) 98 | .unwrap_or(0) 99 | .to_ne_bytes(); 100 | 101 | let (d2, remainder) = remainder.split_at(std::mem::size_of::()); 102 | let d2 = d2 103 | .try_into() 104 | .map(u16::from_be_bytes) 105 | .unwrap_or(0) 106 | .to_ne_bytes(); 107 | 108 | let (d3, _) = remainder.split_at(std::mem::size_of::()); 109 | let d3 = d3 110 | .try_into() 111 | .map(u16::from_be_bytes) 112 | .unwrap_or(0) 113 | .to_ne_bytes(); 114 | 115 | let native_endian = d1.into_iter().chain(d2).chain(d3).collect::>(); 116 | debug_assert_eq!(native_endian.len(), 8); 117 | bytes[..native_endian.len()].copy_from_slice(&native_endian); 118 | uuid::Uuid::from_bytes(bytes) 119 | } 120 | 121 | /// Retrieves the VM ID by reading `/sys/class/dmi/id/product_uuid` and byte-swapping if Gen1. 122 | /// 123 | /// The VM ID is a unique system identifier that persists across reboots but changes 124 | /// when a VM is cloned or redeployed. 125 | /// 126 | /// # Returns 127 | /// - `Some(String)` containing the VM ID if retrieval is successful. 128 | /// - `None` if something fails or the output is empty. 
129 | pub fn get_vm_id() -> Option { 130 | private_get_vm_id(None, None, None) 131 | } 132 | 133 | fn private_get_vm_id( 134 | product_uuid_path: Option<&str>, 135 | sysfs_efi_path: Option<&str>, 136 | dev_efi_path: Option<&str>, 137 | ) -> Option { 138 | let path = product_uuid_path.unwrap_or("/sys/class/dmi/id/product_uuid"); 139 | 140 | let system_uuid = match fs::read_to_string(path) { 141 | Ok(s) => s.trim().to_lowercase(), 142 | Err(err) => { 143 | tracing::error!("Failed to read VM ID from {}: {}", path, err); 144 | return None; 145 | } 146 | }; 147 | 148 | if system_uuid.is_empty() { 149 | tracing::info!(target: "libazureinit::status::retrieved_vm_id", "VM ID file is empty at path: {}", path); 150 | return None; 151 | } 152 | 153 | if is_vm_gen1(sysfs_efi_path, dev_efi_path) { 154 | match Uuid::parse_str(&system_uuid) { 155 | Ok(uuid_parsed) => { 156 | let swapped_uuid = 157 | swap_uuid_to_little_endian(*uuid_parsed.as_bytes()); 158 | let final_id = swapped_uuid.to_string(); 159 | tracing::info!( 160 | target: "libazureinit::status::retrieved_vm_id", 161 | "VM ID (Gen1, swapped): {}", 162 | final_id 163 | ); 164 | Some(final_id) 165 | } 166 | Err(e) => { 167 | tracing::error!( 168 | "Failed to parse system UUID '{}': {}", 169 | system_uuid, 170 | e 171 | ); 172 | Some(system_uuid) 173 | } 174 | } 175 | } else { 176 | tracing::info!( 177 | target: "libazureinit::status::retrieved_vm_id", 178 | "VM ID (Gen2, no swap): {}", 179 | system_uuid 180 | ); 181 | Some(system_uuid) 182 | } 183 | } 184 | 185 | /// Checks whether a provisioning status file exists for the current VM ID. 186 | /// 187 | /// If the provisioning status file exists, it indicates that provisioning has already been 188 | /// completed, and the process should be skipped. If the file does not exist or the VM ID has 189 | /// changed, provisioning should proceed. 
190 | /// 191 | /// # Parameters 192 | /// - `config`: An optional configuration reference used to determine the provisioning directory. 193 | /// If `None`, the default provisioning directory defined by `DEFAULT_AZURE_INIT_DATA_DIR` is used. 194 | /// 195 | /// # Returns 196 | /// - `true` if provisioning is complete (i.e., the provisioning file exists). 197 | /// - `false` if provisioning has not been completed (i.e., no provisioning file exists). 198 | pub fn is_provisioning_complete(config: Option<&Config>, vm_id: &str) -> bool { 199 | let file_path = 200 | get_provisioning_dir(config).join(format!("{}.provisioned", vm_id)); 201 | 202 | if std::path::Path::new(&file_path).exists() { 203 | tracing::info!("Provisioning already complete. Skipping..."); 204 | return true; 205 | } 206 | tracing::info!("Provisioning required."); 207 | false 208 | } 209 | 210 | /// Marks provisioning as complete by creating a provisioning status file. 211 | /// 212 | /// This function ensures that the provisioning directory exists, retrieves the VM ID, 213 | /// and creates a `{vm_id}.provisioned` file in the provisioning directory. 214 | /// 215 | /// # Parameters 216 | /// - `config`: An optional configuration reference used to determine the provisioning directory. 217 | /// If `None`, the default provisioning directory defined by `DEFAULT_AZURE_INIT_DATA_DIR` is used. 218 | /// 219 | /// # Returns 220 | /// - `Ok(())` if the provisioning status file was successfully created. 221 | /// - `Err(Error)` if an error occurred while creating the provisioning file. 
222 | pub fn mark_provisioning_complete( 223 | config: Option<&Config>, 224 | vm_id: &str, 225 | ) -> Result<(), Error> { 226 | check_provision_dir(config)?; 227 | let file_path = 228 | get_provisioning_dir(config).join(format!("{}.provisioned", vm_id)); 229 | 230 | match OpenOptions::new() 231 | .create(true) 232 | .write(true) 233 | .truncate(true) 234 | .mode(0o600) // Ensures correct permissions from the start 235 | .open(&file_path) 236 | { 237 | Ok(_) => { 238 | tracing::info!( 239 | target: "libazureinit::status::success", 240 | "Provisioning complete. File created: {}", 241 | file_path.display() 242 | ); 243 | } 244 | Err(error) => { 245 | tracing::error!( 246 | ?error, 247 | file_path=?file_path, 248 | "Failed to create provisioning status file" 249 | ); 250 | return Err(error.into()); 251 | } 252 | } 253 | 254 | Ok(()) 255 | } 256 | 257 | #[cfg(test)] 258 | mod tests { 259 | use super::*; 260 | use std::fs; 261 | use std::fs::{create_dir, remove_dir}; 262 | use tempfile::TempDir; 263 | 264 | /// Creates a temporary directory and returns a default `Config` 265 | /// whose `azure_init_data_dir` points to that temp directory. 266 | /// Also returns the `TempDir` so it remains in scope for the test. 
267 | fn create_test_config() -> (Config, TempDir) { 268 | let test_dir = TempDir::new().unwrap(); 269 | 270 | let mut test_config = Config::default(); 271 | test_config.azure_init_data_dir.path = test_dir.path().to_path_buf(); 272 | 273 | (test_config, test_dir) 274 | } 275 | 276 | #[test] 277 | fn test_gen1_vm() { 278 | assert!(is_vm_gen1( 279 | Some("/nonexistent_sysfs_efi"), 280 | Some("/nonexistent_dev_efi") 281 | )); 282 | } 283 | 284 | #[test] 285 | fn test_gen2_vm_with_sysfs_efi() { 286 | let mock_path = "/tmp/mock_efi"; 287 | create_dir(mock_path).ok(); 288 | assert!(!is_vm_gen1(Some(mock_path), Some("/nonexistent_dev_efi"))); 289 | remove_dir(mock_path).ok(); 290 | } 291 | 292 | #[test] 293 | fn test_gen2_vm_with_dev_efi() { 294 | let mock_path = "/tmp/mock_dev_efi"; 295 | create_dir(mock_path).ok(); 296 | assert!(!is_vm_gen1(Some("/nonexistent_sysfs_efi"), Some(mock_path))); 297 | remove_dir(mock_path).ok(); 298 | } 299 | 300 | #[test] 301 | fn test_mark_provisioning_complete() { 302 | let (test_config, test_dir) = create_test_config(); 303 | 304 | let mock_vm_id_path = test_dir.path().join("mock_product_uuid"); 305 | fs::write(&mock_vm_id_path, "550e8400-e29b-41d4-a716-446655440000") 306 | .unwrap(); 307 | let vm_id = private_get_vm_id( 308 | Some(mock_vm_id_path.to_str().unwrap()), 309 | None, 310 | None, 311 | ) 312 | .unwrap(); 313 | 314 | let file_path = test_dir.path().join(format!("{}.provisioned", vm_id)); 315 | assert!( 316 | !file_path.exists(), 317 | "File should not exist before provisioning" 318 | ); 319 | 320 | mark_provisioning_complete(Some(&test_config), &vm_id).unwrap(); 321 | assert!(file_path.exists(), "Provisioning file should be created"); 322 | } 323 | 324 | #[test] 325 | fn test_is_provisioning_complete() { 326 | let (test_config, test_dir) = create_test_config(); 327 | 328 | let mock_vm_id_path = test_dir.path().join("mock_product_uuid"); 329 | fs::write(&mock_vm_id_path, "550e8400-e29b-41d4-a716-446655440001") 330 | .unwrap(); 
331 | 332 | let vm_id = private_get_vm_id( 333 | Some(mock_vm_id_path.to_str().unwrap()), 334 | None, 335 | None, 336 | ) 337 | .unwrap(); 338 | 339 | let file_path = test_dir.path().join(format!("{}.provisioned", vm_id)); 340 | fs::File::create(&file_path).unwrap(); 341 | 342 | assert!( 343 | is_provisioning_complete(Some(&test_config), &vm_id,), 344 | "Provisioning should be complete if file exists" 345 | ); 346 | } 347 | 348 | #[test] 349 | fn test_provisioning_skipped_on_simulated_reboot() { 350 | let (test_config, test_dir) = create_test_config(); 351 | 352 | let mock_vm_id_path = test_dir.path().join("mock_product_uuid"); 353 | fs::write(&mock_vm_id_path, "550e8400-e29b-41d4-a716-446655440002") 354 | .unwrap(); 355 | 356 | let vm_id = private_get_vm_id( 357 | Some(mock_vm_id_path.to_str().unwrap()), 358 | None, 359 | None, 360 | ) 361 | .unwrap(); 362 | 363 | assert!( 364 | !is_provisioning_complete(Some(&test_config), &vm_id), 365 | "Provisioning should NOT be complete initially" 366 | ); 367 | 368 | mark_provisioning_complete(Some(&test_config), &vm_id).unwrap(); 369 | 370 | // Simulate a "reboot" by calling again 371 | assert!( 372 | is_provisioning_complete(Some(&test_config), &vm_id,), 373 | "Provisioning should be skipped on second run (file exists)" 374 | ); 375 | } 376 | 377 | #[test] 378 | fn test_get_vm_id_gen1() { 379 | let tmpdir = TempDir::new().unwrap(); 380 | let vm_uuid_path = tmpdir.path().join("product_uuid"); 381 | fs::write(&vm_uuid_path, "550e8400-e29b-41d4-a716-446655440000") 382 | .unwrap(); 383 | 384 | // No sysfs_efi or dev_efi path created => means neither exists => expect Gen1 385 | let res = private_get_vm_id( 386 | Some(vm_uuid_path.to_str().unwrap()), 387 | Some("/this_does_not_exist"), 388 | Some("/still_nope"), 389 | ); 390 | assert_eq!( 391 | res.unwrap(), 392 | "00840e55-9be2-d441-a716-446655440000", 393 | "Should byte-swap for Gen1" 394 | ); 395 | } 396 | 397 | #[test] 398 | fn test_get_vm_id_gen2() { 399 | let tmpdir = 
TempDir::new().unwrap(); 400 | let vm_uuid_path = tmpdir.path().join("product_uuid"); 401 | fs::write(&vm_uuid_path, "550e8400-e29b-41d4-a716-446655440000") 402 | .unwrap(); 403 | 404 | // Create a mock EFI directory => at least one path exists => Gen2 405 | let mock_efi_dir = tmpdir.path().join("mock_efi"); 406 | fs::create_dir(&mock_efi_dir).unwrap(); 407 | 408 | let res = private_get_vm_id( 409 | Some(vm_uuid_path.to_str().unwrap()), 410 | Some(mock_efi_dir.to_str().unwrap()), 411 | None, 412 | ); 413 | assert_eq!( 414 | res.unwrap(), 415 | "550e8400-e29b-41d4-a716-446655440000", 416 | "Should not byte-swap for Gen2" 417 | ); 418 | } 419 | } 420 | -------------------------------------------------------------------------------- /libazureinit/src/unittest.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | use reqwest::StatusCode; 5 | use tokio::io::AsyncWriteExt; 6 | use tokio::net::TcpListener; 7 | use tokio_util::sync::CancellationToken; 8 | 9 | /// Returns expected HTTP response for the given status code and body string. 10 | pub(crate) fn get_http_response_payload( 11 | statuscode: &StatusCode, 12 | body_str: &str, 13 | ) -> String { 14 | // Reply message includes the whole body in case of OK, otherwise empty data. 15 | let res = match statuscode { 16 | &StatusCode::OK => format!("HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", statuscode.as_u16(), statuscode.to_string(), body_str.len(), body_str.to_string()), 17 | _ => { 18 | format!("HTTP/1.1 {} {}\r\n\r\n", statuscode.as_u16(), statuscode.to_string()) 19 | } 20 | }; 21 | 22 | res 23 | } 24 | 25 | /// Accept incoming connections until the cancellation token is used, then return the count 26 | /// of accepted connections. 
27 | pub(crate) async fn serve_requests( 28 | listener: TcpListener, 29 | payload: String, 30 | cancel_token: CancellationToken, 31 | ) -> u32 { 32 | let mut request_count = 0; 33 | 34 | loop { 35 | tokio::select! { 36 | _ = cancel_token.cancelled() => { 37 | break; 38 | } 39 | _ = async { 40 | let (mut serverstream, _) = listener.accept().await.unwrap(); 41 | 42 | serverstream.write_all(payload.as_bytes()).await.unwrap(); 43 | } => { 44 | request_count += 1; 45 | } 46 | } 47 | } 48 | 49 | request_count 50 | } 51 | -------------------------------------------------------------------------------- /src/logging.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | use opentelemetry::{global, trace::TracerProvider}; 5 | use opentelemetry_sdk::trace::{self as sdktrace, Sampler, SdkTracerProvider}; 6 | use std::fs::{OpenOptions, Permissions}; 7 | use std::os::unix::fs::PermissionsExt; 8 | use tracing::{event, Level}; 9 | use tracing_opentelemetry::OpenTelemetryLayer; 10 | use tracing_subscriber::fmt::format::FmtSpan; 11 | use tracing_subscriber::{ 12 | fmt, layer::SubscriberExt, EnvFilter, Layer, Registry, 13 | }; 14 | 15 | use crate::kvp::EmitKVPLayer; 16 | use libazureinit::config::Config; 17 | 18 | pub fn initialize_tracing() -> sdktrace::Tracer { 19 | let provider = SdkTracerProvider::builder() 20 | .with_sampler(Sampler::AlwaysOn) 21 | .build(); 22 | 23 | global::set_tracer_provider(provider.clone()); 24 | provider.tracer("azure-kvp") 25 | } 26 | 27 | /// Builds a `tracing` subscriber that can optionally write azure-init.log 28 | /// to a specific location if `Some(&Config)` is provided. 29 | /// 30 | /// This function follows a two-phase initialization: 31 | /// - Minimal Setup (Pre-Config): When called initially, it sets up basic logging 32 | /// to console (`stderr`), KVP (Hyper-V), and OpenTelemetry without file logging. 
33 | /// 34 | /// - Full Setup (Post-Config): After the configuration is loaded, it is called again 35 | /// with `config`, adding file logging to `config.azure_init_log_path.path` or 36 | /// falling back to `DEFAULT_AZURE_INIT_LOG_PATH` if unspecified. 37 | pub fn setup_layers( 38 | tracer: sdktrace::Tracer, 39 | vm_id: &str, 40 | config: &Config, 41 | ) -> Result<(), Box> { 42 | let otel_layer = OpenTelemetryLayer::new(tracer) 43 | .with_filter(EnvFilter::from_env("AZURE_INIT_LOG")); 44 | 45 | let kvp_filter = EnvFilter::builder().parse( 46 | [ 47 | "WARN", 48 | "azure_init=INFO", 49 | "libazureinit::config::success", 50 | "libazureinit::http::received", 51 | "libazureinit::http::success", 52 | "libazureinit::ssh::authorized_keys", 53 | "libazureinit::ssh::success", 54 | "libazureinit::user::add", 55 | "libazureinit::status::success", 56 | "libazureinit::status::retrieved_vm_id", 57 | ] 58 | .join(","), 59 | )?; 60 | 61 | let emit_kvp_layer = if config.telemetry.kvp_diagnostics { 62 | match EmitKVPLayer::new( 63 | std::path::PathBuf::from("/var/lib/hyperv/.kvp_pool_1"), 64 | vm_id, 65 | ) { 66 | Ok(layer) => Some(layer.with_filter(kvp_filter)), 67 | Err(e) => { 68 | event!(Level::ERROR, "Failed to initialize EmitKVPLayer: {}. Continuing without KVP logging.", e); 69 | None 70 | } 71 | } 72 | } else { 73 | event!( 74 | Level::INFO, 75 | "Hyper-V KVP diagnostics are disabled via config. It is recommended to be enabled for support purposes." 
76 | ); 77 | None 78 | }; 79 | 80 | let stderr_layer = fmt::layer() 81 | .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) 82 | .with_writer(std::io::stderr) 83 | .with_filter(EnvFilter::from_env("AZURE_INIT_LOG")); 84 | 85 | let file_layer = match OpenOptions::new() 86 | .create(true) 87 | .append(true) 88 | .open(&config.azure_init_log_path.path) 89 | { 90 | Ok(file) => { 91 | if let Err(e) = file.set_permissions(Permissions::from_mode(0o600)) 92 | { 93 | event!( 94 | Level::WARN, 95 | "Failed to set permissions on {}: {}.", 96 | config.azure_init_log_path.path.display(), 97 | e, 98 | ); 99 | } 100 | 101 | Some( 102 | fmt::layer() 103 | .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) 104 | .with_writer(file) 105 | .with_filter(EnvFilter::from_env("AZURE_INIT_LOG")), 106 | ) 107 | } 108 | Err(e) => { 109 | event!( 110 | Level::ERROR, 111 | "Could not open configured log file {}: {}. Continuing without file logging.", 112 | config.azure_init_log_path.path.display(), 113 | e 114 | ); 115 | 116 | None 117 | } 118 | }; 119 | 120 | let subscriber = Registry::default() 121 | .with(stderr_layer) 122 | .with(otel_layer) 123 | .with(emit_kvp_layer) 124 | .with(file_layer); 125 | 126 | tracing::subscriber::set_global_default(subscriber)?; 127 | 128 | Ok(()) 129 | } 130 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 
3 | use std::path::PathBuf; 4 | mod kvp; 5 | mod logging; 6 | pub use logging::{initialize_tracing, setup_layers}; 7 | 8 | use anyhow::Context; 9 | use clap::{Parser, Subcommand}; 10 | use libazureinit::config::Config; 11 | use libazureinit::imds::InstanceMetadata; 12 | use libazureinit::User; 13 | use libazureinit::{ 14 | error::Error as LibError, 15 | goalstate, imds, media, 16 | media::{get_mount_device, Environment}, 17 | reqwest::{header, Client}, 18 | Provision, 19 | }; 20 | use libazureinit::{ 21 | get_vm_id, is_provisioning_complete, mark_provisioning_complete, 22 | }; 23 | use std::process::ExitCode; 24 | use std::time::Duration; 25 | use sysinfo::System; 26 | use tracing::instrument; 27 | use tracing_subscriber::{prelude::*, Layer}; 28 | 29 | // These should be set during the build process 30 | const VERSION: &str = env!("CARGO_PKG_VERSION"); 31 | const COMMIT_HASH: &str = env!("GIT_COMMIT_HASH"); 32 | 33 | /// Minimal provisioning agent for Azure 34 | /// 35 | /// Create a user, add SSH public keys, and set the hostname. 36 | /// By default, if no subcommand is specified, this will provision the host. 37 | /// 38 | /// Arguments provided via command-line arguments override any arguments provided 39 | /// via environment variables. 40 | #[derive(Parser, Debug)] 41 | struct Cli { 42 | /// List of supplementary groups of the provisioned user account. 43 | /// 44 | /// Values can be comma-separated and the argument can be provided multiple times. 45 | #[arg( 46 | long, 47 | short, 48 | env = "AZURE_INIT_USER_GROUPS", 49 | value_delimiter = ',', 50 | default_value = "" 51 | )] 52 | groups: Vec, 53 | 54 | #[arg( 55 | long, 56 | help = "Path to the configuration file", 57 | env = "AZURE_INIT_CONFIG" 58 | )] 59 | config: Option, 60 | 61 | #[command(subcommand)] 62 | command: Option, 63 | } 64 | 65 | #[derive(Subcommand, Debug)] 66 | enum Command { 67 | /// By default, this removes provisioning state data. 
Optional flags can be used 68 | /// to clean logs or additional generated files. 69 | Clean { 70 | /// Cleans the log files as defined in the configuration file 71 | #[arg(long)] 72 | logs: bool, 73 | }, 74 | } 75 | 76 | #[instrument] 77 | fn get_environment() -> Result { 78 | let ovf_devices = get_mount_device(None)?; 79 | let mut environment: Option = None; 80 | 81 | // loop until it finds a correct device. 82 | for dev in ovf_devices { 83 | environment = match media::mount_parse_ovf_env(dev) { 84 | Ok(env) => Some(env), 85 | Err(_) => continue, 86 | } 87 | } 88 | 89 | environment.ok_or_else(|| { 90 | tracing::error!("Unable to get list of block devices"); 91 | anyhow::anyhow!("Unable to get list of block devices") 92 | }) 93 | } 94 | 95 | #[instrument(skip_all)] 96 | fn get_username( 97 | instance_metadata: Option<&InstanceMetadata>, 98 | environment: Option<&Environment>, 99 | ) -> Result { 100 | if let Some(metadata) = instance_metadata { 101 | return Ok(metadata.compute.os_profile.admin_username.clone()); 102 | } 103 | 104 | // Read username from OVF environment via mounted local device. 105 | environment 106 | .map(|env| { 107 | env.clone() 108 | .provisioning_section 109 | .linux_prov_conf_set 110 | .username 111 | }) 112 | .ok_or_else(|| { 113 | tracing::error!("Username Failure"); 114 | LibError::UsernameFailure.into() 115 | }) 116 | } 117 | 118 | /// Cleans all provisioning state marker files from the azure-init data directory. 119 | /// 120 | /// This removes all files ending in `.provisioned` from the directory specified 121 | /// by `azure_init_data_dir` (typically `/var/lib/azure-init`). These marker files 122 | /// indicate that provisioning has completed. Removing them allows azure-init to 123 | /// re-run provisioning logic on the next boot. 
124 | #[instrument] 125 | fn clean_provisioning_status(config: &Config) -> Result<(), std::io::Error> { 126 | let data_dir = &config.azure_init_data_dir.path; 127 | let mut found = false; 128 | 129 | for entry in std::fs::read_dir(data_dir)? { 130 | let path = match entry { 131 | Ok(e) => e.path(), 132 | Err(e) => { 133 | tracing::error!( 134 | "Failed to read directory entry in {:?}: {:?}", 135 | data_dir, 136 | e 137 | ); 138 | return Err(e); 139 | } 140 | }; 141 | 142 | if path.extension().is_some_and(|ext| ext == "provisioned") { 143 | found = true; 144 | 145 | match std::fs::remove_file(&path) { 146 | Ok(_) => { 147 | tracing::info!( 148 | "Successfully removed provisioning state at: {:?}", 149 | path 150 | ); 151 | } 152 | Err(e) if e.kind() == std::io::ErrorKind::NotFound => { 153 | tracing::info!( 154 | "No provisioning marker found at: {:?}", 155 | path 156 | ); 157 | } 158 | Err(e) => { 159 | tracing::error!( 160 | "Failed to clean provisioning marker {:?}: {:?}", 161 | path, 162 | e 163 | ); 164 | return Err(e); 165 | } 166 | } 167 | } 168 | } 169 | 170 | if !found { 171 | tracing::info!( 172 | "No provisioning marker files (*.provisioned) found in {:?}", 173 | data_dir 174 | ); 175 | } 176 | 177 | Ok(()) 178 | } 179 | 180 | /// Cleans the azure-init log file defined in the configuration. 181 | /// 182 | /// This removes the log file at the path configured by `azure_init_log_path`, 183 | /// which defaults to `/var/log/azure-init.log`. If the file does not exist, 184 | /// a message is logged but no error is returned. 
185 | #[instrument] 186 | fn clean_log_file(config: &Config) -> Result<(), std::io::Error> { 187 | let log_path = &config.azure_init_log_path.path; 188 | 189 | match std::fs::remove_file(log_path) { 190 | Ok(_) => { 191 | tracing::info!("Successfully removed log file at: {:?}", log_path); 192 | } 193 | Err(e) if e.kind() == std::io::ErrorKind::NotFound => { 194 | tracing::info!("No log file found at: {:?}", log_path); 195 | } 196 | Err(e) => { 197 | tracing::error!("Failed to clean log file {:?}: {:?}", log_path, e); 198 | return Err(e); 199 | } 200 | } 201 | 202 | Ok(()) 203 | } 204 | 205 | #[tokio::main] 206 | async fn main() -> ExitCode { 207 | let tracer = initialize_tracing(); 208 | let vm_id: String = get_vm_id() 209 | .unwrap_or_else(|| "00000000-0000-0000-0000-000000000000".to_string()); 210 | let opts = Cli::parse(); 211 | 212 | let temp_layer = tracing_subscriber::fmt::layer() 213 | .with_span_events(tracing_subscriber::fmt::format::FmtSpan::NONE) 214 | .with_writer(std::io::stderr) 215 | .with_filter(tracing_subscriber::EnvFilter::new( 216 | "libazureinit::config=info", 217 | )); 218 | 219 | let temp_subscriber = 220 | tracing_subscriber::Registry::default().with(temp_layer); 221 | 222 | let config = 223 | match tracing::subscriber::with_default(temp_subscriber, || { 224 | Config::load(opts.config.clone()) 225 | }) { 226 | Ok(cfg) => cfg, 227 | Err(error) => { 228 | eprintln!("Failed to load configuration: {error:?}"); 229 | eprintln!("Example configuration:\n\n{}", Config::default()); 230 | return ExitCode::FAILURE; 231 | } 232 | }; 233 | 234 | if let Err(e) = setup_layers(tracer, &vm_id, &config) { 235 | tracing::error!("Failed to set final logging subscriber: {e:?}"); 236 | } 237 | 238 | tracing::info!( 239 | target = "libazureinit::config::success", 240 | "Final configuration: {:#?}", 241 | config 242 | ); 243 | 244 | if let Some(Command::Clean { logs }) = opts.command { 245 | if clean_provisioning_status(&config).is_err() { 246 | return 
ExitCode::FAILURE; 247 | } 248 | 249 | if logs && clean_log_file(&config).is_err() { 250 | return ExitCode::FAILURE; 251 | } 252 | 253 | return ExitCode::SUCCESS; 254 | } 255 | 256 | if is_provisioning_complete(Some(&config), &vm_id) { 257 | tracing::info!( 258 | "Provisioning already completed earlier. Skipping provisioning." 259 | ); 260 | return ExitCode::SUCCESS; 261 | } 262 | 263 | match provision(config, &vm_id, opts).await { 264 | Ok(_) => ExitCode::SUCCESS, 265 | Err(e) => { 266 | tracing::error!("Provisioning failed with error: {:?}", e); 267 | eprintln!("{:?}", e); 268 | let config: u8 = exitcode::CONFIG 269 | .try_into() 270 | .expect("Error code must be less than 256"); 271 | match e.root_cause().downcast_ref::() { 272 | Some(LibError::UserMissing { user: _ }) => { 273 | ExitCode::from(config) 274 | } 275 | Some(LibError::NonEmptyPassword) => ExitCode::from(config), 276 | Some(_) | None => ExitCode::FAILURE, 277 | } 278 | } 279 | } 280 | } 281 | 282 | #[instrument(name = "root", skip_all)] 283 | async fn provision( 284 | config: Config, 285 | vm_id: &str, 286 | opts: Cli, 287 | ) -> Result<(), anyhow::Error> { 288 | let kernel_version = System::kernel_version() 289 | .unwrap_or("Unknown Kernel Version".to_string()); 290 | let os_version = 291 | System::os_version().unwrap_or("Unknown OS Version".to_string()); 292 | 293 | tracing::info!( 294 | "Kernel Version: {}, OS Version: {}, Azure-Init Version: {}", 295 | kernel_version, 296 | os_version, 297 | VERSION 298 | ); 299 | 300 | let clone_config = config.clone(); 301 | 302 | let mut default_headers = header::HeaderMap::new(); 303 | let user_agent = if cfg!(debug_assertions) { 304 | format!("azure-init v{}-{}", VERSION, COMMIT_HASH) 305 | } else { 306 | format!("azure-init v{}", VERSION) 307 | }; 308 | let user_agent = header::HeaderValue::from_str(user_agent.as_str())?; 309 | default_headers.insert(header::USER_AGENT, user_agent); 310 | let client = Client::builder() 311 | 
.timeout(std::time::Duration::from_secs(30)) 312 | .default_headers(default_headers) 313 | .build()?; 314 | 315 | let imds_http_timeout_sec: u64 = 5 * 60; 316 | let imds_http_retry_interval_sec: u64 = 2; 317 | 318 | // Username can be obtained either via fetching instance metadata from IMDS 319 | // or mounting a local device for OVF environment file. It should not fail 320 | // immediately in a single failure, instead it should fall back to the other 321 | // mechanism. So it is not a good idea to use `?` for query() or 322 | // get_environment(). 323 | let instance_metadata = imds::query( 324 | &client, 325 | Duration::from_secs(imds_http_retry_interval_sec), 326 | Duration::from_secs(imds_http_timeout_sec), 327 | None, // default IMDS URL 328 | ) 329 | .await 330 | .ok(); 331 | 332 | let environment = get_environment().ok(); 333 | 334 | let username = 335 | get_username(instance_metadata.as_ref(), environment.as_ref())?; 336 | 337 | // It is necessary to get the actual instance metadata after getting username, 338 | // as it is not desirable to immediately return error before get_username. 339 | let im = instance_metadata 340 | .clone() 341 | .ok_or::(LibError::InstanceMetadataFailure)?; 342 | 343 | let user = 344 | User::new(username, im.compute.public_keys).with_groups(opts.groups); 345 | 346 | Provision::new(im.compute.os_profile.computer_name, user, config) 347 | .provision()?; 348 | 349 | let vm_goalstate = goalstate::get_goalstate( 350 | &client, 351 | Duration::from_secs(imds_http_retry_interval_sec), 352 | Duration::from_secs(imds_http_timeout_sec), 353 | None, // default wireserver goalstate URL 354 | ) 355 | .await 356 | .with_context(|| { 357 | tracing::error!("Failed to get the desired goalstate."); 358 | "Failed to get desired goalstate." 
359 | })?; 360 | 361 | goalstate::report_health( 362 | &client, 363 | vm_goalstate, 364 | Duration::from_secs(imds_http_retry_interval_sec), 365 | Duration::from_secs(imds_http_timeout_sec), 366 | None, // default wireserver health URL 367 | ) 368 | .await 369 | .with_context(|| { 370 | tracing::error!("Failed to report VM health."); 371 | "Failed to report VM health." 372 | })?; 373 | 374 | mark_provisioning_complete(Some(&clone_config), vm_id).with_context( 375 | || { 376 | tracing::error!("Failed to mark provisioning complete."); 377 | "Failed to mark provisioning complete." 378 | }, 379 | )?; 380 | 381 | Ok(()) 382 | } 383 | -------------------------------------------------------------------------------- /tests/cli.rs: -------------------------------------------------------------------------------- 1 | use std::process::Command; 2 | 3 | use assert_cmd::prelude::*; 4 | use predicates::prelude::*; 5 | 6 | use std::fs::{self, File}; 7 | use std::io::Write; 8 | use tempfile::tempdir; 9 | 10 | // Assert help text includes the --groups flag 11 | #[test] 12 | fn help_groups() -> Result<(), Box> { 13 | let mut command = Command::cargo_bin("azure-init")?; 14 | command.arg("--help"); 15 | command 16 | .assert() 17 | .success() 18 | .stdout(predicate::str::contains("-g, --groups ")); 19 | 20 | Ok(()) 21 | } 22 | 23 | #[test] 24 | fn clean_removes_provision_and_log_files( 25 | ) -> Result<(), Box> { 26 | let temp_dir = tempdir()?; 27 | let data_dir = temp_dir.path().join("data"); 28 | let log_file = temp_dir.path().join("azure-init.log"); 29 | fs::create_dir_all(&data_dir)?; 30 | 31 | let provisioned_file = data_dir.join("vm-id.provisioned"); 32 | File::create(&provisioned_file)?; 33 | 34 | let mut log = File::create(&log_file)?; 35 | writeln!(log, "fake log line")?; 36 | 37 | let config_contents = format!( 38 | r#" 39 | [azure_init_data_dir] 40 | path = "{}" 41 | 42 | [azure_init_log_path] 43 | path = "{}" 44 | "#, 45 | data_dir.display(), 46 | log_file.display() 47 | ); 
48 | let config_path = temp_dir.path().join("azure-init-config.toml"); 49 | fs::write(&config_path, config_contents)?; 50 | 51 | assert!( 52 | provisioned_file.exists(), 53 | ".provisioned file should exist before cleaning" 54 | ); 55 | assert!(log_file.exists(), "log file should exist before cleaning"); 56 | 57 | let mut cmd = Command::cargo_bin("azure-init")?; 58 | cmd.args(["--config", config_path.to_str().unwrap(), "clean", "--logs"]); 59 | 60 | cmd.assert().success(); 61 | 62 | assert!( 63 | !provisioned_file.exists(), 64 | "Expected .provisioned file to be deleted" 65 | ); 66 | assert!( 67 | !log_file.exists(), 68 | "Expected azure-init.log file to be deleted" 69 | ); 70 | 71 | Ok(()) 72 | } 73 | -------------------------------------------------------------------------------- /tests/functional_tests.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | use libazureinit::config::Config; 5 | use libazureinit::imds::PublicKeys; 6 | use libazureinit::User; 7 | use libazureinit::{ 8 | goalstate, imds, 9 | reqwest::{header, Client}, 10 | Provision, 11 | }; 12 | use std::env; 13 | use std::time::Duration; 14 | 15 | #[tokio::main] 16 | async fn main() { 17 | let config = Config::default(); 18 | 19 | let cli_args: Vec = env::args().collect(); 20 | let mut default_headers = header::HeaderMap::new(); 21 | let user_agent = header::HeaderValue::from_str("azure-init").unwrap(); 22 | default_headers.insert(header::USER_AGENT, user_agent); 23 | let client = Client::builder() 24 | .timeout(std::time::Duration::from_secs(30)) 25 | .default_headers(default_headers) 26 | .build() 27 | .unwrap(); 28 | 29 | println!(); 30 | println!("**********************************"); 31 | println!("* Beginning functional testing"); 32 | println!("**********************************"); 33 | println!(); 34 | 35 | println!("Querying wireserver for Goalstate"); 36 | 37 | let 
http_timeout_sec: u64 = 5 * 60; 38 | let http_retry_interval_sec: u64 = 2; 39 | 40 | let get_goalstate_result = goalstate::get_goalstate( 41 | &client, 42 | Duration::from_secs(http_retry_interval_sec), 43 | Duration::from_secs(http_timeout_sec), 44 | None, // default wireserver goalstate URL 45 | ) 46 | .await; 47 | let vm_goalstate = match get_goalstate_result { 48 | Ok(vm_goalstate) => vm_goalstate, 49 | Err(_err) => return, 50 | }; 51 | 52 | println!("Goalstate successfully received"); 53 | println!(); 54 | println!("Reporting VM Health to wireserver"); 55 | 56 | let report_health_result = goalstate::report_health( 57 | &client, 58 | vm_goalstate, 59 | Duration::from_secs(http_retry_interval_sec), 60 | Duration::from_secs(http_timeout_sec), 61 | None, // default wireserver health URL 62 | ) 63 | .await; 64 | match report_health_result { 65 | Ok(report_health) => report_health, 66 | Err(_err) => return, 67 | }; 68 | 69 | println!("VM Health successfully reported"); 70 | 71 | let imds_http_timeout_sec: u64 = 5 * 60; 72 | let imds_http_retry_interval_sec: u64 = 2; 73 | 74 | // Simplified version of calling imds::query. Since username is directly 75 | // given by cli_args below, it is not needed to get instance metadata like 76 | // how it is done in provision() in main. 
77 | let _ = imds::query( 78 | &client, 79 | Duration::from_secs(imds_http_retry_interval_sec), 80 | Duration::from_secs(imds_http_timeout_sec), 81 | None, // default IMDS URL 82 | ) 83 | .await 84 | .expect("Failed to query IMDS"); 85 | 86 | let username = &cli_args[1]; 87 | 88 | let keys: Vec = vec![ 89 | PublicKeys { 90 | path: "/path/to/.ssh/keys/".to_owned(), 91 | key_data: "ssh-rsa test_key_1".to_owned(), 92 | }, 93 | PublicKeys { 94 | path: "/path/to/.ssh/keys/".to_owned(), 95 | key_data: "ssh-rsa test_key_2".to_owned(), 96 | }, 97 | PublicKeys { 98 | path: "/path/to/.ssh/keys/".to_owned(), 99 | key_data: "ssh-rsa test_key_3".to_owned(), 100 | }, 101 | ]; 102 | 103 | Provision::new( 104 | "my-hostname".to_string(), 105 | User::new(username, keys), 106 | config, 107 | ) 108 | .provision() 109 | .expect("Failed to provision host"); 110 | 111 | println!("VM successfully provisioned"); 112 | println!(); 113 | 114 | println!("**********************************"); 115 | println!("* Functional testing completed successfully!"); 116 | println!("**********************************"); 117 | println!(); 118 | } 119 | -------------------------------------------------------------------------------- /tests/functional_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Microsoft Corporation. 3 | # Licensed under the MIT License. 
# End-to-end functional test for azure-init.
#
# Creates a throwaway resource group and VM in Azure, copies the locally
# built `functional_tests` binary onto the VM, runs it there with sudo as
# `./functional_tests test_user`, and deletes the resource group when the
# script exits (whether the test passed or failed). Most settings below can be
# overridden through environment variables of the same name.

set -e

SUBSCRIPTION_ID="${SUBSCRIPTION_ID:-}"
EPOCH=$(date +%s)
RG="${RG:-e2etest-azinit-$EPOCH}"
LOCATION="${LOCATION:-eastus}"
PATH_TO_PUBLIC_SSH_KEY="$HOME/.ssh/id_rsa.pub"
PATH_TO_PRIVATE_SSH_KEY="$HOME/.ssh/id_rsa"
VM_NAME="${VM_NAME:-AzInitFunctionalTest}"
VM_IMAGE="${VM_IMAGE:-Canonical:0001-com-ubuntu-server-jammy:22_04-lts-gen2:latest}"
VM_SIZE="${VM_SIZE:-Standard_D2lds_v5}"
VM_ADMIN_USERNAME="${VM_ADMIN_USERNAME:-azureuser}"
AZURE_SSH_KEY_NAME="${AZURE_SSH_KEY_NAME:-azure-ssh-key}"
VM_NAME_WITH_TIMESTAMP="$VM_NAME-$EPOCH"
VM_SECURITY_TYPE="${VM_SECURITY_TYPE:-TrustedLaunch}"
# Locally built test binary that is copied to the VM and executed there.
FUNCTIONAL_TESTS_BINARY="${FUNCTIONAL_TESTS_BINARY:-./target/debug/functional_tests}"
# How many times to probe SSH (5 seconds apart) before giving up on the VM.
SSH_MAX_ATTEMPTS="${SSH_MAX_ATTEMPTS:-24}"

echo "Starting script"

if [ -z "${SUBSCRIPTION_ID}" ] ; then
    echo "SUBSCRIPTION_ID missing. Either set environment variable or edit $0 to set a subscription ID."
    exit 1
fi

# Fail before creating any billable Azure resources if there is nothing to test.
if [ ! -f "$FUNCTIONAL_TESTS_BINARY" ]; then
    echo "Test binary $FUNCTIONAL_TESTS_BINARY not found. Build it before running $0."
    exit 1
fi

if [ ! -f "$PATH_TO_PUBLIC_SSH_KEY" ]; then
    ssh-keygen -t rsa -b 4096 -f "$PATH_TO_PRIVATE_SSH_KEY" -N ""
    echo "SSH key created."
else
    echo "SSH key already exists."
fi

# Log into Azure (this will open a browser window prompting you to log in)
if az account get-access-token -o none; then
    echo "Using existing Azure account"
else
    echo "Logging you into Azure"
    az login
fi

# Set the subscription you want to use
az account set --subscription "$SUBSCRIPTION_ID"

# Delete the resource group when the script exits. Because of `set -e`, any
# failing step below would otherwise exit before reaching a trailing delete
# and leave the VM running. `|| true` keeps a failed delete from masking the
# original exit status; --no-wait returns without waiting for the deletion.
cleanup() {
    echo "Deleting resource group $RG"
    az group delete -g "$RG" --yes --no-wait || true
}

# Create resource group
az group create -g "$RG" -l "$LOCATION"
trap cleanup EXIT

echo "Creating VM..."
az vm create -n "$VM_NAME_WITH_TIMESTAMP" \
    -g "$RG" \
    --image "$VM_IMAGE" \
    --size "$VM_SIZE" \
    --admin-username "$VM_ADMIN_USERNAME" \
    --ssh-key-value "$PATH_TO_PUBLIC_SSH_KEY" \
    --public-ip-sku Standard \
    --security-type "$VM_SECURITY_TYPE"
echo "VM successfully created"

echo "Getting VM Public IP Address..."
PUBLIC_IP=$(az vm show -d -g "$RG" -n "$VM_NAME_WITH_TIMESTAMP" --query publicIps -o tsv)
echo "$PUBLIC_IP"

# Poll until sshd on the VM accepts our key instead of sleeping a fixed time.
# BatchMode prevents the probe from hanging on an interactive password prompt.
echo "Waiting for SSH access to be set up"
attempt=1
until ssh -o StrictHostKeyChecking=no -o BatchMode=yes -o ConnectTimeout=10 \
        -i "$PATH_TO_PRIVATE_SSH_KEY" "$VM_ADMIN_USERNAME"@"$PUBLIC_IP" true; do
    if [ "$attempt" -ge "$SSH_MAX_ATTEMPTS" ]; then
        echo "SSH to $PUBLIC_IP still not available after $attempt attempts"
        exit 1
    fi
    attempt=$((attempt + 1))
    sleep 5
done

scp -o StrictHostKeyChecking=no -i "$PATH_TO_PRIVATE_SSH_KEY" "$FUNCTIONAL_TESTS_BINARY" "$VM_ADMIN_USERNAME"@"$PUBLIC_IP":~/functional_tests

echo "Logging into VM..."
ssh -o StrictHostKeyChecking=no -i "$PATH_TO_PRIVATE_SSH_KEY" "$VM_ADMIN_USERNAME"@"$PUBLIC_IP" 'sudo ./functional_tests test_user'

# The resource group is deleted by the EXIT trap registered above.