├── .gitignore ├── LICENSE ├── README.md ├── bin └── .gitignore ├── build.bat ├── build.ps1 ├── cloud └── aws │ ├── README.md │ ├── cluster │ ├── .gitignore │ └── test-cluster.template │ ├── deploy.bat │ ├── deploy.ps1 │ └── node │ ├── .gitignore │ ├── eks-worker-node.pkr.hcl.template │ ├── generate-setup-script.py │ └── scripts │ ├── cleanup.ps1 │ ├── setup.ps1 │ └── startup.ps1 ├── deployments ├── default-daemonsets.yml ├── multitenancy-configmap.yml └── multitenancy-inline.yml ├── examples ├── cuda-devicequery │ ├── cuda-devicequery-mcdm.yml │ └── cuda-devicequery-wddm.yml ├── cuda-montecarlo │ ├── cuda-montecarlo-mcdm.yml │ └── cuda-montecarlo-wddm.yml ├── device-discovery │ ├── device-discovery-mcdm.yml │ └── device-discovery-wddm.yml ├── directml │ ├── directml-mcdm.yml │ └── directml-wddm.yml ├── ffmpeg-amf │ └── ffmpeg-amf.yml ├── ffmpeg-autodetect │ ├── autodetect-encoder.ps1 │ └── ffmpeg-autodetect.yml ├── ffmpeg-nvenc │ └── ffmpeg-nvenc.yml ├── ffmpeg-quicksync │ └── ffmpeg-quicksync.yml ├── nvidia-smi │ ├── nvidia-smi-mcdm.yml │ └── nvidia-smi-wddm.yml ├── opencl-enum │ ├── opencl-enum-mcdm.yml │ └── opencl-enum-wddm.yml └── vulkaninfo │ └── vulkaninfo.yml ├── external └── .gitignore ├── library ├── CMakeLists.txt ├── include │ ├── DeviceDiscovery.h │ └── DeviceFilter.h ├── src │ ├── Adapter.h │ ├── AdapterEnumeration.cpp │ ├── AdapterEnumeration.h │ ├── D3DHelpers.cpp │ ├── D3DHelpers.h │ ├── Device.h │ ├── DeviceDiscovery.cpp │ ├── DeviceDiscoveryImp.cpp │ ├── DeviceDiscoveryImp.h │ ├── DllMain.cpp │ ├── ErrorHandling.cpp │ ├── ErrorHandling.h │ ├── ObjectHelpers.h │ ├── RegistryQuery.cpp │ ├── RegistryQuery.h │ ├── SafeArray.cpp │ ├── SafeArray.h │ ├── WmiQuery.cpp │ ├── WmiQuery.h │ └── pch.h ├── test │ └── test-device-discovery-cpp.cpp └── vcpkg.json ├── plugins ├── cmd │ ├── device-plugin-mcdm │ │ └── main.go │ ├── device-plugin-wddm │ │ └── main.go │ ├── gen-device-mounts │ │ └── main.go │ ├── query-hcs-capabilities │ │ └── main.go │ └── 
test-device-discovery-go │ │ └── main.go ├── go.mod ├── go.sum └── internal │ ├── discovery │ ├── device.go │ ├── device_discovery.go │ ├── device_filter.go │ └── runtime_file.go │ ├── mount │ ├── default_mounts.go │ ├── device_mounts.go │ └── vendors.go │ └── plugin │ ├── common_main.go │ ├── deletion_watcher.go │ ├── device_plugin.go │ ├── device_watcher.go │ └── plugin_configuration.go ├── update-version.bat └── update-version.ps1 /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | build 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022-2023 TensorWorks Pty Ltd 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /bin/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /build.bat: -------------------------------------------------------------------------------- 1 | @powershell -ExecutionPolicy Bypass -File "%~dp0.\build.ps1" %* 2 | -------------------------------------------------------------------------------- /cloud/aws/README.md: -------------------------------------------------------------------------------- 1 | # Amazon EKS demo deployment 2 | 3 | This directory contains scripts that can be used to deploy the Kubernetes device plugins for DirectX to an [Amazon EKS](https://aws.amazon.com/eks/) Kubernetes cluster for demonstration purposes. Note that the deployment created by these scripts **is not intended for production use**, and lacks important functionality such as auto-scaling the Windows node group based on requests for DirectX devices. 4 | 5 | The [main deployment script](./deploy.ps1) performs the following steps: 6 | 7 | - Builds a custom [Amazon Machine Image (AMI)](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AMIs.html) based on Windows Server 2022 for use by Kubernetes worker nodes, with the NVIDIA GPU drivers and containerd v1.7.0 installed. The supporting scripts for building the AMI are located in the [node](./node) subdirectory. 8 | 9 | - Creates an EKS cluster with a Windows node group of `g4dn.xlarge` instances that is configured to use the custom AMI. The supporting configuration files for creating the cluster are located in the [cluster](./cluster) subdirectory. 10 | 11 | - Deploys the Kubernetes device plugins for DirectX to the EKS cluster using the [default HostProcess DaemonSets for the MCDM device plugin and the WDDM device plugin](../../deployments/default-daemonsets.yml). 
12 | 13 | 14 | ## Contents 15 | 16 | - [Requirements](#requirements) 17 | - [Running the deployment script](#running-the-deployment-script) 18 | - [Testing the cluster](#testing-the-cluster) 19 | - [Cleaning up](#cleaning-up) 20 | 21 | 22 | ## Requirements 23 | 24 | To use the deployment scripts, the following requirements must be met: 25 | 26 | - The AWS region that you are using needs to have sufficient quota to run at least one `g4dn.xlarge` EC2 instance. To view or change the relevant limit, log in to the AWS web console and navigate to the [*Running On-Demand G and VT instances*](https://console.aws.amazon.com/servicequotas/home/services/ec2/quotas/L-DB2E81BA) service quota page. The minimum required value is 4 vCPUs. 27 | 28 | - The AWS region that you are using needs to have a default VPC configured with at least one subnet. If you have deleted the default VPC for the target region then you will need to [create a new one](https://docs.aws.amazon.com/vpc/latest/userguide/default-vpc.html#create-default-vpc). 29 | 30 | - [Microsoft PowerShell](https://github.com/PowerShell/PowerShell) needs to be installed when running the deployment scripts under Linux or macOS systems. (Under Windows, the built-in Windows PowerShell is used instead.) 31 | 32 | - The [AWS CLI](https://docs.aws.amazon.com/cli/) needs to be installed and configured with credentials that permit the creation of AMIs and EKS clusters. For details, see [*Configuring the AWS CLI*](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html). 33 | 34 | - [eksctl](https://eksctl.io/) version [0.111.0](https://github.com/weaveworks/eksctl/blob/v0.111.0/docs/release_notes/0.110.0.md) or newer needs to be installed (older versions will refuse to create Windows node groups with GPUs). 35 | 36 | - [HashiCorp Packer](https://www.packer.io/) needs to be installed. 37 | 38 | - [kubectl](https://kubernetes.io/docs/reference/kubectl/) needs to be installed. 
39 | 40 | 41 | ## Running the deployment script 42 | 43 | Under Windows, run the main deployment script using the following command: 44 | 45 | ``` 46 | deploy.bat 47 | ``` 48 | 49 | Under Linux and macOS, use this command instead: 50 | 51 | ```bash 52 | pwsh deploy.ps1 53 | ``` 54 | 55 | The following optional flags can be used to control the deployment options: 56 | 57 | - `-Region`: specifies the AWS region into which resources will be deployed. The default region is `us-east-1`. 58 | 59 | - `-AmiName`: specifies the name to use for the custom worker node AMI. The default name is `eks-worker-node`. 60 | 61 | - `-ClusterName`: specifies the name to use for the EKS cluster. The default name is `demo-cluster`. 62 | 63 | An example usage of these flags is shown below: 64 | 65 | ```bash 66 | # Deploys to the Sydney (ap-southeast-2) AWS region and uses custom names for both the AMI and the EKS cluster 67 | pwsh deploy.ps1 -Region "ap-southeast-2" -AmiName "my-custom-ami" -ClusterName "my-test-cluster" 68 | ``` 69 | 70 | 71 | ## Testing the cluster 72 | 73 | Once the EKS cluster has been created, eksctl will configure kubectl to communicate with that cluster by default. This means you can start using kubectl to deploy examples from the top-level [examples](../../examples) directory without the need for any additional configuration steps: 74 | 75 | 1. The first example you should deploy is the [**device-discovery**](../../examples/device-discovery/) test, which acts as a sanity check to verify that GPUs are being exposed to containers correctly: 76 | 77 | ```bash 78 | kubectl apply -f '../../examples/device-discovery/device-discovery-wddm.yml' 79 | ``` 80 | 81 | Once the Job has been created, wait for the Pod to be assigned to a Windows worker node and then run to completion. If the Job finishes with a status of "Succeeded" then you should check the Pod logs to verify that the NVIDIA Tesla T4 GPU is listed in the output. 
If the Job finishes with a status of "Failed" or if the log output lists zero devices then something has gone wrong. A failure here could indicate an issue with the Kubernetes device plugins for DirectX themselves, or with some aspect of the EKS cluster configuration. 82 | 83 | 2. Since the EKS worker nodes are using NVIDIA GPUs, the second example you should deploy is the [**nvidia-smi**](../../examples/nvidia-smi/) test, which acts as a sanity check to verify that the NVIDIA GPU drivers are able to communicate with the GPU: 84 | 85 | ```bash 86 | kubectl apply -f '../../examples/nvidia-smi/nvidia-smi-wddm.yml' 87 | ``` 88 | 89 | If the Job finishes with a status of "Succeeded" then `nvidia-smi` was able to communicate with the GPU, and you can check the Pod logs to verify that the output is as expected. If the Job finishes with a status of "Failed" then this indicates either an issue with the Kubernetes device plugins for DirectX themselves or with the NVIDIA GPU drivers on the worker node. 90 | 91 | 3. With these basic sanity checks out of the way, you can then try out any of the other examples that work on NVIDIA GPUs. Since the NVIDIA Tesla T4 GPUs provided by `g4dn.xlarge` EC2 instances support both compute and display, you will need to deploy examples that request a `directx.microsoft.com/display` resource. Note that some examples include YAML files for both compute-only (MCDM) and compute+display (WDDM) requests, so be sure to use the version with a filename suffix of `-wddm.yml`. 92 | 93 | Note that if you attempt to run two tests at once, one of them will wait for the other to complete before it can be scheduled. This is because the Windows node group will not automatically scale up in response to requests for DirectX devices, so only one node (and thus one GPU) will be available to allocate to containers at any given time. 
94 | 95 | 96 | ## Cleaning up 97 | 98 | To delete all previously deployed AWS resources, run the main deployment script with the `-Clean` flag. Under Windows: 99 | 100 | ``` 101 | deploy.bat -Clean 102 | ``` 103 | 104 | Under Linux and macOS: 105 | 106 | ```bash 107 | pwsh deploy.ps1 -Clean 108 | ``` 109 | 110 | If you specified flags for the AWS region or custom resource names when deploying the resources then be sure to include these flags when deleting them as well. For example: 111 | 112 | ```bash 113 | # Deletes the AWS resources that were deployed in the example from the earlier section 114 | pwsh deploy.ps1 -Region "ap-southeast-2" -AmiName "my-custom-ami" -ClusterName "my-test-cluster" -Clean 115 | ``` 116 | -------------------------------------------------------------------------------- /cloud/aws/cluster/.gitignore: -------------------------------------------------------------------------------- 1 | test-cluster.yml 2 | -------------------------------------------------------------------------------- /cloud/aws/cluster/test-cluster.template: -------------------------------------------------------------------------------- 1 | apiVersion: eksctl.io/v1alpha5 2 | kind: ClusterConfig 3 | 4 | metadata: 5 | name: "__CLUSTER_NAME__" 6 | region: "__AWS_REGION__" 7 | version: "1.24" 8 | 9 | nodeGroups: 10 | - name: windows 11 | ami: "__AMI_ID__" 12 | amiFamily: WindowsServer2022FullContainer 13 | preBootstrapCommands: ["net user Administrator \"Passw0rd!\""] 14 | instanceType: g4dn.xlarge 15 | containerRuntime: containerd 16 | volumeSize: 100 17 | minSize: 1 18 | maxSize: 3 19 | 20 | managedNodeGroups: 21 | - name: linux 22 | instanceType: t2.large 23 | minSize: 2 24 | maxSize: 3 25 | -------------------------------------------------------------------------------- /cloud/aws/deploy.bat: -------------------------------------------------------------------------------- 1 | @powershell -ExecutionPolicy Bypass -File "%~dp0.\deploy.ps1" %* 2 | 
-------------------------------------------------------------------------------- /cloud/aws/deploy.ps1: -------------------------------------------------------------------------------- 1 | Param ( 2 | [parameter(HelpMessage = "Remove existing resources created by a previous run")] 3 | [switch] $Clean, 4 | 5 | [parameter(HelpMessage = "The AWS region in which to deploy resources")] 6 | $Region = 'us-east-1', 7 | 8 | [parameter(HelpMessage = "The name to use for the custom worker node AMI")] 9 | $AmiName = 'eks-worker-node', 10 | 11 | [parameter(HelpMessage = "The name to use for the EKS cluster")] 12 | $ClusterName = 'demo-cluster' 13 | ) 14 | 15 | 16 | # Halt execution if we encounter an error 17 | $ErrorActionPreference = 'Stop' 18 | 19 | 20 | # Replaces the placeholders in a template file with values and writes the output to a new file 21 | function FillTemplate 22 | { 23 | Param ( 24 | $Template, 25 | $Rendered, 26 | $Values 27 | ) 28 | 29 | $filled = Get-Content -Path $Template -Raw 30 | $Values.GetEnumerator() | ForEach-Object { 31 | $filled = $filled.Replace($_.Key, $_.Value) 32 | } 33 | Set-Content -Path $Rendered -Value $filled -NoNewline 34 | } 35 | 36 | # Represents the output of a native process 37 | class ProcessOutput 38 | { 39 | ProcessOutput([string] $stdout, [string] $stderr) 40 | { 41 | $this.StandardOutput = $stdout 42 | $this.StandardError = $stderr 43 | } 44 | 45 | [string] $StandardOutput 46 | [string] $StandardError 47 | } 48 | 49 | # Helper functions for executing native commands 50 | class ExecutionHelpers 51 | { 52 | # Escapes command-line arguments for passing to a native command 53 | static [string] EscapeArguments([string[]] $arguments) 54 | { 55 | $escaped = @() 56 | 57 | foreach ($arg in $arguments) 58 | { 59 | if ($arg.Contains(' ')) { 60 | $escaped += @("`"$arg`"") 61 | } 62 | else { 63 | $escaped += @($arg) 64 | } 65 | } 66 | 67 | return $escaped -join ' ' 68 | } 69 | 70 | # Executes a command and throws an error if it returns a 
non-zero exit code 71 | static [ProcessOutput] RunCommand([string] $command, [string[]] $arguments, [bool] $captureStdOut, [bool] $captureStdErr) 72 | { 73 | # Log the command 74 | $escapedArgs = [ExecutionHelpers]::EscapeArguments($arguments) 75 | $formatted = "[$command $escapedArgs]" 76 | Write-Host "$formatted" -ForegroundColor DarkYellow 77 | 78 | # Execute the command and wait for it to complete, retrieving the exit code, stdout and stderr 79 | $info = New-Object System.Diagnostics.ProcessStartInfo 80 | $info.FileName = $command 81 | $info.Arguments = $escapedArgs 82 | $info.RedirectStandardError = $captureStdErr 83 | $info.RedirectStandardOutput = $captureStdOut 84 | $info.UseShellExecute = $false 85 | $info.WorkingDirectory = (Get-Location).ToString() 86 | $process = New-Object System.Diagnostics.Process 87 | $process.StartInfo = $info 88 | $process.Start() 89 | $process.WaitForExit() 90 | $exitCode = $process.ExitCode 91 | $stdout = if ($captureStdOut) { $process.StandardOutput.ReadToEnd() } else { '' } 92 | $stderr = if ($captureStdErr) { $process.StandardError.ReadToEnd() } else { '' } 93 | 94 | # If the command terminated with a non-zero exit code then throw an error 95 | if ($exitCode -ne 0) { 96 | throw "Command $formatted terminated with exit code $exitCode, stdout $stdout and stderr $stderr" 97 | } 98 | 99 | # Return the output 100 | return [ProcessOutput]::new($stdout, $stderr) 101 | } 102 | 103 | # Do not capture stdout and stderr of child processes unless the caller explicitly requests it 104 | static [void] RunCommand([string] $command, [string[]] $arguments) { 105 | [ExecutionHelpers]::RunCommand($command, $arguments, $false, $false) 106 | } 107 | 108 | # Tests whether the specified command exists, by attempting to execute it with the supplied arguments 109 | static [bool] CommandExists([string] $command, [string[]] $testArguments) 110 | { 111 | try 112 | { 113 | [ExecutionHelpers]::RunCommand($command, $testArguments, $true, $true) 114 | 
return $true 115 | } 116 | catch { 117 | return $false 118 | } 119 | } 120 | } 121 | 122 | # Represents the Packer manifest data for our EKS worker node AMI 123 | class PackerManifest 124 | { 125 | PackerManifest([string] $path) { 126 | $this.ManifestPath = $path 127 | } 128 | 129 | [bool] Exists() { 130 | return (Test-Path -Path $this.ManifestPath) 131 | } 132 | 133 | [void] Parse() 134 | { 135 | # Parse the Packer manifest JSON and validate the AMI details 136 | $manifestDetails = Get-Content -Path $this.ManifestPath -Raw | ConvertFrom-Json 137 | $amiDetails = ($manifestDetails.builds[0].artifact_id -split ':') 138 | if ($amiDetails.Length -lt 2) { 139 | throw "Malformed 'artifact_id' field in Packer build manifest: '$amiDetails'" 140 | } 141 | 142 | # Extract the region and AMI ID 143 | $this.AmiRegion = $amiDetails[0] 144 | $this.AmiID = $amiDetails[1] 145 | 146 | # If the manifest data doesn't contain the snapshot ID for the AMI then populate it 147 | $this.SnapshotID = $manifestDetails.builds[0].custom_data.snapshot_id 148 | if ($this.SnapshotID.Length -lt 1) 149 | { 150 | # Attempt to retrieve the snapshot ID from the AWS API 151 | Write-Host 'Retrieving the snapshot ID for the AMI...' 
-ForegroundColor Green 152 | $queryOutput = [ExecutionHelpers]::RunCommand('aws', @('ec2', 'describe-images', "--region=$($this.AmiRegion)", "--image-ids=$($this.AmiID)"), $true, $true) 153 | $snapshotDetails = $queryOutput.StandardOutput | ConvertFrom-Json 154 | $this.SnapshotID = $snapshotDetails.Images[0].BlockDeviceMappings[0].Ebs.SnapshotId 155 | if ($amiDetails.Length -lt 1) { 156 | throw "Failed to retrieve snapshot ID for AMI: '$this.AmiID'" 157 | } 158 | 159 | # Inject the snapshot ID into the manifest data 160 | $manifestDetails.builds[0].custom_data.snapshot_id = $this.SnapshotID 161 | 162 | # Write the updated manifest data back to the JSON file 163 | $manifestJson = ConvertTo-Json $manifestDetails -Depth 32 164 | Set-Content -Path $this.ManifestPath -Value $manifestJson -NoNewline 165 | } 166 | } 167 | 168 | [void] Delete() 169 | { 170 | # De-register the AMI 171 | [ExecutionHelpers]::RunCommand('aws', @('ec2', 'deregister-image', "--region=$($this.AmiRegion)", "--image-id=$($this.AmiID)")) 172 | 173 | # Remove the snapshot 174 | [ExecutionHelpers]::RunCommand('aws', @('ec2', 'delete-snapshot', "--region=$($this.AmiRegion)", "--snapshot-id=$($this.SnapshotID)")) 175 | 176 | # Delete the manifest JSON file 177 | Remove-Item -Force $this.ManifestPath 178 | } 179 | 180 | [string] $ManifestPath 181 | [string] $AmiID 182 | [string] $AmiRegion 183 | [string] $SnapshotID 184 | } 185 | 186 | # Represents an EKS cluster managed by eksctl 187 | class EksCluster 188 | { 189 | EksCluster([string] $name) { 190 | $this.Name = $name 191 | } 192 | 193 | [bool] Exists() 194 | { 195 | try 196 | { 197 | [ExecutionHelpers]::RunCommand('eksctl', @('get', 'cluster', "--name=$($this.Name)", "--region=$($global:Region)"), $true, $true) 198 | return $true 199 | } 200 | catch { 201 | return $false 202 | } 203 | } 204 | 205 | [void] Create([string] $yamlFile) { 206 | [ExecutionHelpers]::RunCommand('eksctl', @('create', 'cluster', '-f', $yamlFile.Replace('\', '/'))) 207 | } 208 | 
209 | [void] Delete() { 210 | [ExecutionHelpers]::RunCommand('eksctl', @('delete', 'cluster', "--name=$($this.Name)", "--region=$($global:Region)")) 211 | } 212 | 213 | [string] $Name 214 | } 215 | 216 | 217 | # Verify that all of the native commands we require are available 218 | $requiredCommands = @{ 219 | 'the AWS CLI' = [ExecutionHelpers]::CommandExists('aws', @('help')); 220 | 'eksctl' = [ExecutionHelpers]::CommandExists('eksctl', @('version')); 221 | 'kubectl' = [ExecutionHelpers]::CommandExists('kubectl', @('help')); 222 | 'HashiCorp Packer' = [ExecutionHelpers]::CommandExists('packer', @('version')) 223 | } 224 | foreach ($command in $requiredCommands.GetEnumerator()) 225 | { 226 | if ($command.Value -eq $false) { 227 | throw "Error: $($command.Name) must be installed to run this script!" 228 | } 229 | } 230 | 231 | # Resolve the path to the Packer manifest file and create a helper object to represent the manifest data 232 | $packerDir = "$PSScriptRoot\node" 233 | $packerManifest = [PackerManifest]::new("$packerDir\manifest.json") 234 | 235 | # Create a helper object to represent our test EKS cluster 236 | $eksCluster = [EksCluster]::new($global:ClusterName) 237 | 238 | # Determine whether we are removing existing resources created by a previous run 239 | if ($Clean) 240 | { 241 | # Remove the EKS cluster if it exists 242 | if ($eksCluster.Exists()) 243 | { 244 | Write-Host 'Removing existing EKS cluster...' -ForegroundColor Green 245 | $eksCluster.Delete() 246 | } 247 | 248 | # Delete the AMI and its accompanying snapshot if they exist 249 | if ($packerManifest.Exists()) 250 | { 251 | Write-Host 'Removing AMI and its accompanying snapshot...' 
-ForegroundColor Green 252 | $packerManifest.Parse() 253 | $packerManifest.Delete() 254 | } 255 | 256 | Exit 257 | } 258 | 259 | # Build the custom worker node AMI if it doesn't already exist 260 | if ($packerManifest.Exists() -eq $false) 261 | { 262 | # Populate the Packer template 263 | $packerfile = "$packerDir\eks-worker-node.pkr.hcl" 264 | FillTemplate ` 265 | -Template "$packerDir\eks-worker-node.pkr.hcl.template" ` 266 | -Rendered $packerfile ` 267 | -Values @{ 268 | '__AWS_REGION__' = $global:Region; 269 | '__AMI_NAME__' = $global:AmiName 270 | } 271 | 272 | # Build the AMI 273 | Write-Host 'Building the EKS custom worker node AMI...' -ForegroundColor Green 274 | Push-Location "$packerDir" 275 | [ExecutionHelpers]::RunCommand('packer', @('init', 'eks-worker-node.pkr.hcl')) 276 | [ExecutionHelpers]::RunCommand('packer', @('build', 'eks-worker-node.pkr.hcl')) 277 | Pop-Location 278 | } 279 | 280 | # Parse the Packer manifest JSON and validate the AMI details 281 | $packerManifest.Parse() 282 | 283 | # Populate the cluster template YAML with the values for the AMI 284 | $clusterDir = "$PSScriptRoot\cluster" 285 | $configFile = "$clusterDir\test-cluster.yml" 286 | FillTemplate ` 287 | -Template "$clusterDir\test-cluster.template" ` 288 | -Rendered $configFile ` 289 | -Values @{ 290 | '__CLUSTER_NAME__' = $global:ClusterName; 291 | '__AWS_REGION__' = $packerManifest.AmiRegion; 292 | '__AMI_ID__' = $packerManifest.AmiID 293 | } 294 | 295 | # Deploy the test EKS cluster if it doesn't already exist 296 | if ($eksCluster.Exists() -eq $false) 297 | { 298 | Write-Host 'Deploying a test EKS cluster with a Windows worker node group using the custom AMI...' -ForegroundColor Green 299 | $eksCluster.Create($configFile) 300 | } 301 | 302 | # Deploy the device plugin DaemonSets to the test cluster 303 | Write-Host 'Deploying the DirectX device plugin DaemonSets to the test EKS cluster...' 
-ForegroundColor Green 304 | $deploymentsYaml = "$PSScriptRoot\..\..\deployments\default-daemonsets.yml" 305 | [ExecutionHelpers]::RunCommand('kubectl', @('apply', '-f', $deploymentsYaml.Replace('\', '/'))) 306 | -------------------------------------------------------------------------------- /cloud/aws/node/.gitignore: -------------------------------------------------------------------------------- 1 | eks-worker-node.pkr.hcl 2 | manifest.json 3 | -------------------------------------------------------------------------------- /cloud/aws/node/eks-worker-node.pkr.hcl.template: -------------------------------------------------------------------------------- 1 | packer { 2 | required_plugins { 3 | amazon = { 4 | version = ">= 1.0.9" 5 | source = "github.com/hashicorp/amazon" 6 | } 7 | } 8 | } 9 | 10 | source "amazon-ebs" "eks-worker-node" { 11 | ami_name = "__AMI_NAME__" 12 | instance_type = "g4dn.xlarge" 13 | region = "__AWS_REGION__" 14 | 15 | # Use the latest version of the official Windows Server 2022 base image 16 | source_ami_filter { 17 | filters = { 18 | name = "Windows_Server-2022-English-Full-Base-*" 19 | root-device-type = "ebs" 20 | virtualization-type = "hvm" 21 | } 22 | 23 | most_recent = true 24 | owners = ["amazon"] 25 | } 26 | 27 | # Expand the boot disk to 100GB 28 | launch_block_device_mappings { 29 | device_name = "/dev/sda1" 30 | volume_size = 100 31 | volume_type = "gp3" 32 | delete_on_termination = true 33 | } 34 | 35 | # Allow S3 access for the VM 36 | temporary_iam_instance_profile_policy_document { 37 | Version = "2012-10-17" 38 | Statement { 39 | Action = ["s3:Get*", "s3:List*"] 40 | Effect = "Allow" 41 | Resource = ["*"] 42 | } 43 | } 44 | 45 | # Use our startup script to enable SSH access 46 | user_data_file = "${path.root}/scripts/startup.ps1" 47 | 48 | # Use SSH for running commands in the VM 49 | communicator = "ssh" 50 | ssh_username = "Administrator" 51 | ssh_timeout = "30m" 52 | 53 | # Don't automatically stop the instance, since 
sysprep will perform the shutdown 54 | disable_stop_instance = true 55 | } 56 | 57 | build { 58 | name = "eks-worker-node" 59 | sources = ["source.amazon-ebs.eks-worker-node"] 60 | 61 | # Run our EKS worker node setup script 62 | provisioner "powershell" { 63 | script = "${path.root}/scripts/setup.ps1" 64 | } 65 | 66 | # Perform cleanup and shut down the VM 67 | provisioner "powershell" { 68 | script = "${path.root}/scripts/cleanup.ps1" 69 | valid_exit_codes = [0, 2300218] 70 | } 71 | 72 | # Store the AMI ID in a manifest file when the build completes 73 | post-processor "manifest" { 74 | output = "manifest.json" 75 | custom_data = { 76 | snapshot_id = "" 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /cloud/aws/node/generate-setup-script.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # This script automates the generation of `setup.ps1` in the `scripts` subdirectory 4 | # 5 | import json, re, subprocess, yaml, sys 6 | from pathlib import Path 7 | 8 | 9 | class Utility: 10 | 11 | @staticmethod 12 | def log(message): 13 | """ 14 | Logs a message to stderr 15 | """ 16 | print('[generate-setup-script.py]: {}'.format(message), flush=True, file=sys.stderr) 17 | 18 | @staticmethod 19 | def capture(command, **kwargs): 20 | """ 21 | Executes the specified command and captures its output 22 | """ 23 | 24 | # Log the command being executed 25 | Utility.log(command) 26 | 27 | # Attempt to execute the specified command 28 | result = subprocess.run( 29 | command, 30 | check = True, 31 | capture_output = True, 32 | universal_newlines = True, 33 | **kwargs 34 | ) 35 | 36 | # Return the contents of stdout 37 | return result.stdout.strip() 38 | 39 | @staticmethod 40 | def writeFile(filename, data): 41 | """ 42 | Writes data to the specified file 43 | """ 44 | return Path(filename).write_bytes(data.encode('utf-8')) 45 | 46 | @staticmethod 47 | def 
commentForStep(name): 48 | """ 49 | Returns a descriptive comment for the build step with the specified name 50 | """ 51 | return { 52 | 53 | 'ConfigureDirectories': '# Create each of our directories', 54 | 'DownloadKubernetes': '# Download the Kubernetes components', 55 | 'DownloadEKSArtifacts': '# Download the EKS artifacts archive', 56 | 'ExtractEKSArtifacts': '# Extract the EKS artifacts archive', 57 | 'MoveEKSArtifacts': '# Move the EKS files into place', 58 | 'ExecuteBuildScripts': '# Perform EKS worker node setup', 59 | 'RemoveEKSArtifactDownloadDirectory': '# Perform cleanup', 60 | 61 | 'InstallContainers': '\n'.join([ 62 | '# Install the Windows Containers feature', 63 | '# (Note: this is actually a no-op here, since we install the feature beforehand in startup.ps1)' 64 | ]) 65 | 66 | }.get(name, None) 67 | 68 | @staticmethod 69 | def parseConstants(constants): 70 | """ 71 | Parses an EC2 ImageBuilder component's constants list 72 | """ 73 | parsed = {} 74 | for entry in constants: 75 | for key, values in entry.items(): 76 | parsed[key] = values['value'] 77 | return parsed 78 | 79 | @staticmethod 80 | def replaceConstants(string, constants): 81 | """ 82 | Converts EC2 ImageBuilder constant references to PowerShell variable references 83 | """ 84 | 85 | # If the value of a constant is used as a magic value rather than a reference, 86 | # replace it with a reference to the variable representing the constant instead 87 | transformed = string 88 | for key, value in constants.items(): 89 | transformed = transformed.replace(value, '${}'.format(key)) 90 | 91 | # Convert `{{ variable }}` syntax to PowerShell `$variable` syntax 92 | # (Note that we don't bother to wrap the variable names in curly braces, since we know that none 93 | # of the variable names contain special characters, and they're only ever interpolated as either 94 | # part of a filesystem path surrounded by separators, or as a parameter surrounded by whitespace) 95 | return re.sub('{{ (.+?) 
}}', '$\\1', transformed) 96 | 97 | @staticmethod 98 | def replaceSystemPaths(path): 99 | """ 100 | Replaces hard-coded system paths with the equivalent environment variables 101 | """ 102 | replaced = path 103 | replaced = replaced.replace('C:\\Program Files', '$env:ProgramFiles') 104 | replaced = replaced.replace('C:\\ProgramData', '$env:ProgramData') 105 | return replaced 106 | 107 | @staticmethod 108 | def s3UriToHttpsUrl(s3Uri): 109 | """ 110 | Converts an `s3://` URI to an HTTPS URL 111 | """ 112 | url = s3Uri.replace('s3://', '') 113 | components = url.split('/', 1) 114 | return 'https://{}.s3.amazonaws.com/{}'.format(components[0], components[1]) 115 | 116 | 117 | # Retrieve the contents of the "Amazon EKS Optimized Windows AMI" EC2 ImageBuilder component 118 | componentData = json.loads(Utility.capture([ 119 | 'aws', 120 | 'imagebuilder', 121 | 'get-component', 122 | '--region=us-east-1', 123 | '--component-build-version-arn', 124 | 'arn:aws:imagebuilder:us-east-1:aws:component/eks-optimized-ami-windows/1.24.0' 125 | ])) 126 | 127 | # Parse the pipeline YAML data and extract the list of constants 128 | pipelineData = yaml.load(componentData['component']['data'], Loader=yaml.Loader) 129 | constants = Utility.parseConstants(pipelineData['constants']) 130 | 131 | # Extract the steps for the "build" phase 132 | buildSteps = [p['steps'] for p in pipelineData['phases'] if p['name'] == 'build'][0] 133 | 134 | print('CONSTANTS:') 135 | print(json.dumps(constants, indent=4)) 136 | 137 | print() 138 | print('BUILD STEPS:') 139 | print(json.dumps(buildSteps, indent=4)) 140 | 141 | # Prepend our header to the generated PowerShell code 142 | generated = '''<# 143 | THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT! 144 | 145 | This script is based on the logic from the "Amazon EKS Optimized Windows AMI" 146 | EC2 ImageBuilder component, with modifications to use containerd 1.7.0. 147 | 148 | The original ImageBuilder component logic is Copyright Amazon.com, Inc. 
or 149 | its affiliates, and is licensed under the MIT License. 150 | #> 151 | 152 | # Halt execution if we encounter an error 153 | $ErrorActionPreference = 'Stop' 154 | 155 | 156 | # Applies in-place patches to a file 157 | function PatchFile 158 | { 159 | Param ( 160 | $File, 161 | $Patches 162 | ) 163 | 164 | $patched = Get-Content -Path $File -Raw 165 | $Patches.GetEnumerator() | ForEach-Object { 166 | $patched = $patched.Replace($_.Key, $_.Value) 167 | } 168 | Set-Content -Path $File -Value $patched -NoNewline 169 | } 170 | 171 | 172 | ''' 173 | 174 | # Inject an additional constant for the parent of the temp directory, immediately before the child directory 175 | tempPath = {k:v for k,v in constants.items() if k == 'TempPath'} 176 | otherConstants = {k:v for k,v in constants.items() if k != 'TempPath'} 177 | constants = {**otherConstants, 'TempRoot': 'C:\\TempEKSArtifactDir', **tempPath} 178 | 179 | # Define variables for each of our constants 180 | generated += '# Constants\n' 181 | existingConstants = {} 182 | for key, value in constants.items(): 183 | transformed = Utility.replaceConstants(value, existingConstants) 184 | transformed = Utility.replaceSystemPaths(transformed) 185 | generated += '${} = "{}"\n'.format(key, transformed) 186 | existingConstants[key] = value 187 | 188 | # Process each build step in turn 189 | for step in buildSteps: 190 | 191 | # Determine whether we have custom preprocessing logic for the step 192 | name = step['name'] 193 | if name == 'ConfigureDirectories': 194 | 195 | # Add the temp directory to the list of directories to be created 196 | step['loop']['forEach'] += [constants['TempRoot']] 197 | 198 | elif name == 'DownloadKubernetes': 199 | 200 | # Inject the driver installation step immediately prior to the Kubernetes download step 201 | generated += '\n'.join([ 202 | '', 203 | '# Install the NVIDIA GPU drivers', 204 | "$driverBucket = 'ec2-windows-nvidia-drivers'", 205 | "$driver = Get-S3Object -BucketName $driverBucket 
-KeyPrefix 'latest' -Region 'us-east-1' | Where-Object {$_.Key.Contains('server2022')}", 206 | 'Copy-S3Object -BucketName $driverBucket -Key $driver.Key -LocalFile "$TempRoot\driver.exe" -Region \'us-east-1\'', 207 | "Start-Process -FilePath \"$TempRoot\driver.exe\" -ArgumentList @('-s', '-noreboot') -NoNewWindow -Wait", 208 | '' 209 | ]) 210 | 211 | elif name == 'ExtractEKSArtifacts': 212 | 213 | # Remove the redundant directory creation command 214 | step['inputs']['commands'] = [ 215 | c for c in step['inputs']['commands'] 216 | if not c.startswith('New-Item') 217 | ] 218 | 219 | # Use absolute file and directory paths rather than relative paths 220 | step['inputs']['commands'] = [ 221 | c.replace('EKS-Artifacts.zip', '"C:\\EKS-Artifacts.zip"').replace('TempEKSArtifactDir', 'C:\\TempEKSArtifactDir') 222 | for c in step['inputs']['commands'] 223 | ] 224 | 225 | elif name == 'InstallContainerRuntimes': 226 | 227 | # Inject the containerd 1.7.0 download step, along with our configuration patching steps, immediately prior to the containerd installation step 228 | generated += '\n'.join([ 229 | '', 230 | '# -------', 231 | '', 232 | '# TEMPORARY UNTIL EKS ADDS SUPPORT FOR CONTAINERD v1.7.0:', 233 | '# Download and extract the containerd 1.7.0 release build', 234 | '$containerdTarball = "$TempPath\\containerd-1.7.0.tar.gz"', 235 | '$containerdFiles = "$TempPath\\containerd-1.7.0"', 236 | '$webClient.DownloadFile(\'https://github.com/containerd/containerd/releases/download/v1.7.0/containerd-1.7.0-windows-amd64.tar.gz\', $containerdTarball)', 237 | 'New-Item -Path "$containerdFiles" -ItemType Directory -Force | Out-Null', 238 | 'tar.exe -xvzf "$containerdTarball" -C "$containerdFiles"', 239 | '', 240 | '# Move the containerd files into place', 241 | 'Move-Item -Path "$containerdFiles\\bin\\containerd.exe" -Destination "$ContainerdPath\\containerd.exe" -Force', 242 | 'Move-Item -Path "$containerdFiles\\bin\\containerd-shim-runhcs-v1.exe" -Destination 
"$ContainerdPath\\containerd-shim-runhcs-v1.exe" -Force', 243 | 'Move-Item -Path "$containerdFiles\\bin\\ctr.exe" -Destination "$ContainerdPath\\ctr.exe" -Force', 244 | '', 245 | '# Clean up the containerd intermediate files', 246 | 'Remove-Item -Path "$containerdFiles" -Recurse -Force', 247 | 'Remove-Item -Path "$containerdTarball" -Force', 248 | '', 249 | '# -------', 250 | '', 251 | '# Patch the containerd setup script to configure a log file (rather than just discarding log output) and to use the upstream pause', 252 | '# container image rather than the EKS version, since the latter appears to cause errors when attempting to create Windows Pods', 253 | 'PatchFile -File "$TempPath\Add-ContainerdRuntime.ps1" -Patches @{', 254 | ' "containerd --register-service" = "containerd --register-service --log-file \'C:\\ProgramData\\containerd\\root\\output.log\'";', 255 | ' "amazonaws.com/eks/pause-windows:latest" = "registry.k8s.io/pause:3.9"', 256 | '}', 257 | '', 258 | '# Add the full Windows Server 2022 base image and the pause image to the list of images to pre-pull', 259 | '$baseLayersFile = "$TempPath\eks.baselayers.config"', 260 | '$baseLayers = Get-Content -Path $baseLayersFile -Raw | ConvertFrom-Json', 261 | '$baseLayers.2022 += "mcr.microsoft.com/windows/server:ltsc2022"', 262 | '$baseLayers.2022 += "registry.k8s.io/pause:3.9"', 263 | '$patchedJson = ConvertTo-Json -Depth 100 -InputObject $baseLayers', 264 | 'Set-Content -Path $baseLayersFile -Value $patchedJson -NoNewline', 265 | '', 266 | ]) 267 | 268 | # Simplify the containerd installation command 269 | step['inputs']['commands'] = [ 270 | '', 271 | '# Register containerd as the EKS container runtime', 272 | 'Push-Location $TempPath', 273 | '& .\Add-ContainerdRuntime.ps1 -Path "$ContainerdPath"', 274 | 'Pop-Location' 275 | ] 276 | 277 | elif name == 'ExecuteBuildScripts': 278 | 279 | # Prefix each script invocation with the call operator 280 | step['loop']['forEach'] = [ 281 | '& {}'.format(command) 282 | 
for command in step['loop']['forEach'] 283 | ] 284 | 285 | # Strip away the boilerplate code surrounding each script invocation 286 | step['inputs']['commands'] = ['Push-Location $TempPath'] + step['loop']['forEach'] + ['Pop-Location'] 287 | 288 | # ------- 289 | 290 | # If we have a descriptive comment for the step then include it above its generated code 291 | comment = Utility.commentForStep(name) 292 | if comment != None: 293 | generated += '\n{}\n'.format(comment) 294 | 295 | # ------- 296 | 297 | # Generate code for the step based on its action type 298 | action = step['action'] 299 | 300 | if action == 'CreateFolder': 301 | directories = [Utility.replaceConstants(d, constants) for d in step['loop']['forEach']] 302 | generated += '\n'.join([ 303 | 'foreach ($dir in @({})) {{'.format(', '.join(directories)), 304 | '\tNew-Item -Path $dir -ItemType Directory -Force | Out-Null', 305 | '}' 306 | ]) 307 | 308 | elif action == 'DeleteFolder': 309 | generated += '\n'.join([ 310 | 'Remove-Item -Path "{}" -Recurse -Force'.format(Utility.replaceConstants(input['path'], constants)) 311 | for input in step['inputs'] 312 | ]) 313 | 314 | elif action == 'MoveFile': 315 | generated += '\n'.join([ 316 | 'Move-Item -Path "{}" -Destination "{}" -Force'.format( 317 | Utility.replaceConstants(input['source'], constants), 318 | Utility.replaceConstants(input['destination'], constants) 319 | ) 320 | for input in step['inputs'] 321 | ]) 322 | 323 | elif action == 'S3Download': 324 | generated += '\n'.join([ 325 | '$webClient.DownloadFile("{}", "{}")'.format( 326 | Utility.s3UriToHttpsUrl(input['source']), 327 | Utility.replaceConstants(input['destination'], constants) 328 | ) 329 | for input in step['inputs'] 330 | ]) 331 | 332 | elif action == 'ExecutePowerShell': 333 | generated += '\n'.join([ 334 | Utility.replaceConstants(c, constants).replace("'", '"') 335 | for c in step['inputs']['commands'] 336 | if not c.startswith('$ErrorActionPreference') 337 | ]) 338 | 339 | elif action 
== 'Reboot': 340 | Utility.log('Ignoring reboot step.') 341 | continue 342 | 343 | else: 344 | raise RuntimeError('Unknown build step action: {}'.format(action)) 345 | 346 | # ------- 347 | 348 | # Add a trailing newline after each non-ignored step 349 | generated += '\n' 350 | 351 | # Write the generated code to the output script file 352 | outfile = Path(__file__).parent / 'scripts' / 'setup.ps1' 353 | Utility.writeFile(outfile, generated) 354 | Utility.log('Wrote generated code to {}'.format(outfile)) 355 | -------------------------------------------------------------------------------- /cloud/aws/node/scripts/cleanup.ps1: -------------------------------------------------------------------------------- 1 | # Perform cleanup 2 | Set-Service -Name sshd -StartupType 'Manual' 3 | Remove-Item -Path 'C:\ProgramData\ssh\administrators_authorized_keys' -Force 4 | 5 | # Remove the file for this script, since Packer won't have a chance to perform its own cleanup 6 | Remove-Item -Path $PSCommandPath -Force 7 | 8 | # Perform sysprep and shut down the VM 9 | & "$Env:ProgramFiles\Amazon\EC2Launch\EC2Launch.exe" sysprep --shutdown=true 10 | -------------------------------------------------------------------------------- /cloud/aws/node/scripts/setup.ps1: -------------------------------------------------------------------------------- 1 | <# 2 | THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT! 3 | 4 | This script is based on the logic from the "Amazon EKS Optimized Windows AMI" 5 | EC2 ImageBuilder component, with modifications to use containerd 1.7.0. 6 | 7 | The original ImageBuilder component logic is Copyright Amazon.com, Inc. or 8 | its affiliates, and is licensed under the MIT License. 
9 | #> 10 | 11 | # Halt execution if we encounter an error 12 | $ErrorActionPreference = 'Stop' 13 | 14 | 15 | # Applies in-place patches to a file 16 | function PatchFile 17 | { 18 | Param ( 19 | $File, 20 | $Patches 21 | ) 22 | 23 | $patched = Get-Content -Path $File -Raw 24 | $Patches.GetEnumerator() | ForEach-Object { 25 | $patched = $patched.Replace($_.Key, $_.Value) 26 | } 27 | Set-Content -Path $File -Value $patched -NoNewline 28 | } 29 | 30 | 31 | # Constants 32 | $KubernetesPath = "$env:ProgramFiles\Kubernetes" 33 | $KubernetesDownload = "https://amazon-eks.s3.amazonaws.com/1.24.7/2022-10-31/bin/windows/amd64" 34 | $ContainerdPath = "$env:ProgramFiles\containerd" 35 | $EKSPath = "$env:ProgramFiles\Amazon\EKS" 36 | $CNIPath = "$EKSPath\cni" 37 | $CSIProxyPath = "$EKSPath\bin" 38 | $EKSLogsPath = "$env:ProgramData\Amazon\EKS\logs" 39 | $TempRoot = "C:\TempEKSArtifactDir" 40 | $TempPath = "$TempRoot\EKS-Artifacts" 41 | 42 | # Create each of our directories 43 | foreach ($dir in @($ContainerdPath, $KubernetesPath, $EKSPath, $CNIPath, $CSIProxyPath, $EKSLogsPath, $TempRoot)) { 44 | New-Item -Path $dir -ItemType Directory -Force | Out-Null 45 | } 46 | 47 | # Install the NVIDIA GPU drivers 48 | $driverBucket = 'ec2-windows-nvidia-drivers' 49 | $driver = Get-S3Object -BucketName $driverBucket -KeyPrefix 'latest' -Region 'us-east-1' | Where-Object {$_.Key.Contains('server2022')} 50 | Copy-S3Object -BucketName $driverBucket -Key $driver.Key -LocalFile "$TempRoot\driver.exe" -Region 'us-east-1' 51 | Start-Process -FilePath "$TempRoot\driver.exe" -ArgumentList @('-s', '-noreboot') -NoNewWindow -Wait 52 | 53 | # Download the Kubernetes components 54 | $webClient = New-Object System.Net.WebClient 55 | $webClient.DownloadFile("$KubernetesDownload/kubelet.exe", "$KubernetesPath\kubelet.exe") 56 | $webClient.DownloadFile("$KubernetesDownload/kube-proxy.exe", "$KubernetesPath\kube-proxy.exe") 57 | $webClient.DownloadFile("$KubernetesDownload/aws-iam-authenticator.exe", 
"$EKSPath\aws-iam-authenticator.exe") 58 | 59 | # Download the EKS artifacts archive 60 | $webClient.DownloadFile("https://ec2imagebuilder-managed-resources-us-east-1-prod.s3.amazonaws.com/components/eks-optimized-ami-windows/1.24.0/EKS-Artifacts.zip", "C:\EKS-Artifacts.zip") 61 | 62 | # Extract the EKS artifacts archive 63 | Expand-Archive -Path "C:\EKS-Artifacts.zip" -DestinationPath $TempRoot 64 | Remove-Item -Path "C:\EKS-Artifacts.zip" -Force 65 | 66 | # Move the EKS files into place 67 | Move-Item -Path "$TempPath\ctr.exe" -Destination "$ContainerdPath\ctr.exe" -Force 68 | Move-Item -Path "$TempPath\containerd.exe" -Destination "$ContainerdPath\containerd.exe" -Force 69 | Move-Item -Path "$TempPath\containerd-shim-runhcs-v1.exe" -Destination "$ContainerdPath\containerd-shim-runhcs-v1.exe" -Force 70 | Move-Item -Path "$TempPath\Start-EKSBootstrap.ps1" -Destination "$EKSPath\Start-EKSBootstrap.ps1" -Force 71 | Move-Item -Path "$TempPath\EKS-StartupTask.ps1" -Destination "$EKSPath\EKS-StartupTask.ps1" -Force 72 | Move-Item -Path "$TempPath\vpc-shared-eni.exe" -Destination "$CNIPath\vpc-shared-eni.exe" -Force 73 | Move-Item -Path "$TempPath\csi-proxy.exe" -Destination "$CSIProxyPath\csi-proxy.exe" -Force 74 | 75 | # Install the Windows Containers feature 76 | # (Note: this is actually a no-op here, since we install the feature beforehand in startup.ps1) 77 | Install-WindowsFeature -Name Containers 78 | 79 | # ------- 80 | 81 | # TEMPORARY UNTIL EKS ADDS SUPPORT FOR CONTAINERD v1.7.0: 82 | # Download and extract the containerd 1.7.0 release build 83 | $containerdTarball = "$TempPath\containerd-1.7.0.tar.gz" 84 | $containerdFiles = "$TempPath\containerd-1.7.0" 85 | $webClient.DownloadFile('https://github.com/containerd/containerd/releases/download/v1.7.0/containerd-1.7.0-windows-amd64.tar.gz', $containerdTarball) 86 | New-Item -Path "$containerdFiles" -ItemType Directory -Force | Out-Null 87 | tar.exe -xvzf "$containerdTarball" -C "$containerdFiles" 88 | 89 | # 
Move the containerd files into place 90 | Move-Item -Path "$containerdFiles\bin\containerd.exe" -Destination "$ContainerdPath\containerd.exe" -Force 91 | Move-Item -Path "$containerdFiles\bin\containerd-shim-runhcs-v1.exe" -Destination "$ContainerdPath\containerd-shim-runhcs-v1.exe" -Force 92 | Move-Item -Path "$containerdFiles\bin\ctr.exe" -Destination "$ContainerdPath\ctr.exe" -Force 93 | 94 | # Clean up the containerd intermediate files 95 | Remove-Item -Path "$containerdFiles" -Recurse -Force 96 | Remove-Item -Path "$containerdTarball" -Force 97 | 98 | # ------- 99 | 100 | # Patch the containerd setup script to configure a log file (rather than just discarding log output) and to use the upstream pause 101 | # container image rather than the EKS version, since the latter appears to cause errors when attempting to create Windows Pods 102 | PatchFile -File "$TempPath\Add-ContainerdRuntime.ps1" -Patches @{ 103 | "containerd --register-service" = "containerd --register-service --log-file 'C:\ProgramData\containerd\root\output.log'"; 104 | "amazonaws.com/eks/pause-windows:latest" = "registry.k8s.io/pause:3.9" 105 | } 106 | 107 | # Add the full Windows Server 2022 base image and the pause image to the list of images to pre-pull 108 | $baseLayersFile = "$TempPath\eks.baselayers.config" 109 | $baseLayers = Get-Content -Path $baseLayersFile -Raw | ConvertFrom-Json 110 | $baseLayers.2022 += "mcr.microsoft.com/windows/server:ltsc2022" 111 | $baseLayers.2022 += "registry.k8s.io/pause:3.9" 112 | $patchedJson = ConvertTo-Json -Depth 100 -InputObject $baseLayers 113 | Set-Content -Path $baseLayersFile -Value $patchedJson -NoNewline 114 | 115 | # Register containerd as the EKS container runtime 116 | Push-Location $TempPath 117 | & .\Add-ContainerdRuntime.ps1 -Path "$ContainerdPath" 118 | Pop-Location 119 | 120 | # Perform EKS worker node setup 121 | Push-Location $TempPath 122 | & .\create-windows-pause-image.ps1 -ContainerRuntime containerd 123 | & .\Get-EKSBaseLayers.ps1 
-ConfigFile eks.baselayers.config -ContainerRuntime containerd 124 | & .\Add-CSIProxy.ps1 -Path "$CSIProxyPath" -LogPath "$EKSLogsPath" 125 | & .\EKS-WindowsServiceHost.ps1 126 | & .\Install-EKSWorkerNode.ps1 127 | Pop-Location 128 | 129 | # Perform cleanup 130 | Remove-Item -Path "$TempRoot" -Recurse -Force 131 | -------------------------------------------------------------------------------- /cloud/aws/node/scripts/startup.ps1: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Install the OpenSSH server and set the sshd service to start automatically at system startup 4 | Add-WindowsCapability -Online -Name OpenSSH.Server~~~~0.0.1.0 5 | Set-Service -Name sshd -StartupType 'Automatic' 6 | 7 | # Create the OpenSSH configuration directory if it doesn't already exist 8 | $sshDir = 'C:\ProgramData\ssh' 9 | if ((Test-Path -Path $sshDir) -eq $false) { 10 | New-Item -Path $sshDir -ItemType Directory -Force | Out-Null 11 | } 12 | 13 | # Retrieve the SSH public key from the EC2 metadata service 14 | $authorisedKeys = "$sshDir\administrators_authorized_keys" 15 | curl.exe 'http://169.254.169.254/latest/meta-data/public-keys/0/openssh-key' -o "$authorisedKeys" 16 | 17 | # Set the required ACLs for the authorised keys file 18 | icacls.exe "$authorisedKeys" /inheritance:r /grant "Administrators:F" /grant "SYSTEM:F" 19 | 20 | # Install the Windows feature for containers, which will require a reboot 21 | Install-WindowsFeature -Name Containers -IncludeAllSubFeature 22 | 23 | # Restart the VM 24 | Restart-Computer 25 | 26 | -------------------------------------------------------------------------------- /deployments/default-daemonsets.yml: -------------------------------------------------------------------------------- 1 | # HostProcess DaemonSets for the MCDM device plugin and the WDDM device plugin, using default settings 2 | 3 | apiVersion: apps/v1 4 | kind: DaemonSet 5 | metadata: 6 | name: device-plugin-mcdm 7 | spec: 8 | 
selector: 9 | matchLabels: 10 | app: device-plugin-mcdm 11 | template: 12 | metadata: 13 | labels: 14 | app: device-plugin-mcdm 15 | spec: 16 | nodeSelector: 17 | kubernetes.io/os: 'windows' 18 | kubernetes.io/arch: 'amd64' 19 | node.kubernetes.io/windows-build: '10.0.20348' 20 | securityContext: 21 | windowsOptions: 22 | hostProcess: true 23 | runAsUserName: "NT AUTHORITY\\SYSTEM" 24 | hostNetwork: true 25 | containers: 26 | - name: device-plugin-mcdm 27 | image: "index.docker.io/tensorworks/mcdm-device-plugin:0.0.1" 28 | imagePullPolicy: Always 29 | 30 | --- 31 | 32 | apiVersion: apps/v1 33 | kind: DaemonSet 34 | metadata: 35 | name: device-plugin-wddm 36 | spec: 37 | selector: 38 | matchLabels: 39 | app: device-plugin-wddm 40 | template: 41 | metadata: 42 | labels: 43 | app: device-plugin-wddm 44 | spec: 45 | nodeSelector: 46 | kubernetes.io/os: 'windows' 47 | kubernetes.io/arch: 'amd64' 48 | node.kubernetes.io/windows-build: '10.0.20348' 49 | securityContext: 50 | windowsOptions: 51 | hostProcess: true 52 | runAsUserName: "NT AUTHORITY\\SYSTEM" 53 | hostNetwork: true 54 | containers: 55 | - name: device-plugin-wddm 56 | image: "index.docker.io/tensorworks/wddm-device-plugin:0.0.1" 57 | imagePullPolicy: Always 58 | -------------------------------------------------------------------------------- /deployments/multitenancy-configmap.yml: -------------------------------------------------------------------------------- 1 | # Example HostProcess DaemonSets for the MCDM device plugin and the WDDM device plugin, using settings that enable multitenancy 2 | # 3 | # This version of the DaemonSets uses a ConfigMap to provide configuration values. 
For a version that sets environment variable 4 | # values directly in the Pod spec, see the file `multitenancy-inline.yml` 5 | 6 | apiVersion: v1 7 | kind: ConfigMap 8 | metadata: 9 | name: device-plugin-config 10 | data: 11 | 12 | # Configure the device plugins to allow 4 containers to mount each device simultaneously 13 | multitenancy: '4' 14 | 15 | --- 16 | 17 | apiVersion: apps/v1 18 | kind: DaemonSet 19 | metadata: 20 | name: device-plugin-mcdm 21 | spec: 22 | selector: 23 | matchLabels: 24 | app: device-plugin-mcdm 25 | template: 26 | metadata: 27 | labels: 28 | app: device-plugin-mcdm 29 | spec: 30 | nodeSelector: 31 | kubernetes.io/os: 'windows' 32 | kubernetes.io/arch: 'amd64' 33 | node.kubernetes.io/windows-build: '10.0.20348' 34 | securityContext: 35 | windowsOptions: 36 | hostProcess: true 37 | runAsUserName: "NT AUTHORITY\\SYSTEM" 38 | hostNetwork: true 39 | containers: 40 | - name: device-plugin-mcdm 41 | image: "index.docker.io/tensorworks/mcdm-device-plugin:0.0.1" 42 | imagePullPolicy: Always 43 | 44 | # Use the configuration values from the ConfigMap 45 | env: 46 | - name: MCDM_DEVICE_PLUGIN_MULTITENANCY 47 | valueFrom: 48 | configMapKeyRef: 49 | name: device-plugin-config 50 | key: multitenancy 51 | 52 | --- 53 | 54 | apiVersion: apps/v1 55 | kind: DaemonSet 56 | metadata: 57 | name: device-plugin-wddm 58 | spec: 59 | selector: 60 | matchLabels: 61 | app: device-plugin-wddm 62 | template: 63 | metadata: 64 | labels: 65 | app: device-plugin-wddm 66 | spec: 67 | nodeSelector: 68 | kubernetes.io/os: 'windows' 69 | kubernetes.io/arch: 'amd64' 70 | node.kubernetes.io/windows-build: '10.0.20348' 71 | securityContext: 72 | windowsOptions: 73 | hostProcess: true 74 | runAsUserName: "NT AUTHORITY\\SYSTEM" 75 | hostNetwork: true 76 | containers: 77 | - name: device-plugin-wddm 78 | image: "index.docker.io/tensorworks/wddm-device-plugin:0.0.1" 79 | imagePullPolicy: Always 80 | 81 | # Use the configuration values from the ConfigMap 82 | env: 83 | - name: 
WDDM_DEVICE_PLUGIN_MULTITENANCY 84 | valueFrom: 85 | configMapKeyRef: 86 | name: device-plugin-config 87 | key: multitenancy 88 | -------------------------------------------------------------------------------- /deployments/multitenancy-inline.yml: -------------------------------------------------------------------------------- 1 | # Example HostProcess DaemonSets for the MCDM device plugin and the WDDM device plugin, using settings that enable multitenancy 2 | # 3 | # This version of the DaemonSets sets environment variable values directly in the Pod spec. For a version that uses a ConfigMap 4 | # to provide configuration values, see the file `multitenancy-configmap.yml` 5 | 6 | apiVersion: apps/v1 7 | kind: DaemonSet 8 | metadata: 9 | name: device-plugin-mcdm 10 | spec: 11 | selector: 12 | matchLabels: 13 | app: device-plugin-mcdm 14 | template: 15 | metadata: 16 | labels: 17 | app: device-plugin-mcdm 18 | spec: 19 | nodeSelector: 20 | kubernetes.io/os: 'windows' 21 | kubernetes.io/arch: 'amd64' 22 | node.kubernetes.io/windows-build: '10.0.20348' 23 | securityContext: 24 | windowsOptions: 25 | hostProcess: true 26 | runAsUserName: "NT AUTHORITY\\SYSTEM" 27 | hostNetwork: true 28 | containers: 29 | - name: device-plugin-mcdm 30 | image: "index.docker.io/tensorworks/mcdm-device-plugin:0.0.1" 31 | imagePullPolicy: Always 32 | 33 | # Configure the MCDM device plugin to allow 4 containers to mount each compute-only device simultaneously 34 | env: 35 | - name: MCDM_DEVICE_PLUGIN_MULTITENANCY 36 | value: "4" 37 | 38 | --- 39 | 40 | apiVersion: apps/v1 41 | kind: DaemonSet 42 | metadata: 43 | name: device-plugin-wddm 44 | spec: 45 | selector: 46 | matchLabels: 47 | app: device-plugin-wddm 48 | template: 49 | metadata: 50 | labels: 51 | app: device-plugin-wddm 52 | spec: 53 | nodeSelector: 54 | kubernetes.io/os: 'windows' 55 | kubernetes.io/arch: 'amd64' 56 | node.kubernetes.io/windows-build: '10.0.20348' 57 | securityContext: 58 | windowsOptions: 59 | hostProcess: true 
60 | runAsUserName: "NT AUTHORITY\\SYSTEM" 61 | hostNetwork: true 62 | containers: 63 | - name: device-plugin-wddm 64 | image: "index.docker.io/tensorworks/wddm-device-plugin:0.0.1" 65 | imagePullPolicy: Always 66 | 67 | # Configure the WDDM device plugin to allow 4 containers to mount each display device simultaneously 68 | env: 69 | - name: WDDM_DEVICE_PLUGIN_MULTITENANCY 70 | value: "4" 71 | -------------------------------------------------------------------------------- /examples/cuda-devicequery/cuda-devicequery-mcdm.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running the CUDA deviceQuery sample program inside a container 2 | # 3 | # This version of the Job requests a compute-only device from the MCDM device plugin. For a version that 4 | # requests a display device from the WDDM device plugin, see the file `cuda-devicequery-wddm.yml` 5 | # 6 | # NOTE: this Job will only work when the device allocated by the MCDM device plugin is an NVIDIA GPU, 7 | # otherwise the DLL files required by `deviceQuery.exe` won't exist and the Pod will fail to start. 8 | 9 | apiVersion: batch/v1 10 | kind: Job 11 | metadata: 12 | name: example-cuda-devicequery-mcdm 13 | spec: 14 | template: 15 | spec: 16 | containers: 17 | - name: example-cuda-devicequery-mcdm 18 | image: "index.docker.io/tensorworks/example-cuda-devicequery:0.0.1" 19 | resources: 20 | limits: 21 | directx.microsoft.com/compute: 1 22 | nodeSelector: 23 | "kubernetes.io/os": windows 24 | restartPolicy: Never 25 | backoffLimit: 0 26 | -------------------------------------------------------------------------------- /examples/cuda-devicequery/cuda-devicequery-wddm.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running the CUDA deviceQuery sample program inside a container 2 | # 3 | # This version of the Job requests a display device from the WDDM device plugin. 
For a version that 4 | # requests a compute-only device from the MCDM device plugin, see the file `cuda-devicequery-mcdm.yml` 5 | # 6 | # NOTE: this Job will only work when the device allocated by the WDDM device plugin is an NVIDIA GPU, 7 | # otherwise the DLL files required by `deviceQuery.exe` won't exist and the Pod will fail to start. 8 | 9 | apiVersion: batch/v1 10 | kind: Job 11 | metadata: 12 | name: example-cuda-devicequery-wddm 13 | spec: 14 | template: 15 | spec: 16 | containers: 17 | - name: example-cuda-devicequery-wddm 18 | image: "index.docker.io/tensorworks/example-cuda-devicequery:0.0.1" 19 | resources: 20 | limits: 21 | directx.microsoft.com/display: 1 22 | nodeSelector: 23 | "kubernetes.io/os": windows 24 | restartPolicy: Never 25 | backoffLimit: 0 26 | -------------------------------------------------------------------------------- /examples/cuda-montecarlo/cuda-montecarlo-mcdm.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running the CUDA MC_EstimatePiP sample program inside a container 2 | # 3 | # This version of the Job requests a compute-only device from the MCDM device plugin. For a version that 4 | # requests a display device from the WDDM device plugin, see the file `cuda-montecarlo-wddm.yml` 5 | # 6 | # NOTE: this Job will only work when the device allocated by the MCDM device plugin is an NVIDIA GPU, 7 | # otherwise the DLL files required by `MC_EstimatePiP.exe` won't exist and the Pod will fail to start. 
8 | 9 | apiVersion: batch/v1 10 | kind: Job 11 | metadata: 12 | name: example-cuda-montecarlo-mcdm 13 | spec: 14 | template: 15 | spec: 16 | containers: 17 | - name: example-cuda-montecarlo-mcdm 18 | image: "index.docker.io/tensorworks/example-cuda-montecarlo:0.0.1" 19 | resources: 20 | limits: 21 | directx.microsoft.com/compute: 1 22 | nodeSelector: 23 | "kubernetes.io/os": windows 24 | restartPolicy: Never 25 | backoffLimit: 0 26 | -------------------------------------------------------------------------------- /examples/cuda-montecarlo/cuda-montecarlo-wddm.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running the CUDA MC_EstimatePiP sample program inside a container 2 | # 3 | # This version of the Job requests a display device from the WDDM device plugin. For a version that 4 | # requests a compute-only device from the MCDM device plugin, see the file `cuda-montecarlo-mcdm.yml` 5 | # 6 | # NOTE: this Job will only work when the device allocated by the WDDM device plugin is an NVIDIA GPU, 7 | # otherwise the DLL files required by `MC_EstimatePiP.exe` won't exist and the Pod will fail to start. 
8 | 9 | apiVersion: batch/v1 10 | kind: Job 11 | metadata: 12 | name: example-cuda-montecarlo-wddm 13 | spec: 14 | template: 15 | spec: 16 | containers: 17 | - name: example-cuda-montecarlo-wddm 18 | image: "index.docker.io/tensorworks/example-cuda-montecarlo:0.0.1" 19 | resources: 20 | limits: 21 | directx.microsoft.com/display: 1 22 | nodeSelector: 23 | "kubernetes.io/os": windows 24 | restartPolicy: Never 25 | backoffLimit: 0 26 | -------------------------------------------------------------------------------- /examples/device-discovery/device-discovery-mcdm.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running the device discovery test program inside a container 2 | # 3 | # This version of the Job requests a compute-only device from the MCDM device plugin. For a version that 4 | # requests a display device from the WDDM device plugin, see the file `device-discovery-wddm.yml` 5 | 6 | apiVersion: batch/v1 7 | kind: Job 8 | metadata: 9 | name: example-device-discovery-mcdm 10 | spec: 11 | template: 12 | spec: 13 | containers: 14 | - name: example-device-discovery-mcdm 15 | image: "index.docker.io/tensorworks/example-device-discovery:0.0.1" 16 | resources: 17 | limits: 18 | directx.microsoft.com/compute: 1 19 | nodeSelector: 20 | "kubernetes.io/os": windows 21 | restartPolicy: Never 22 | backoffLimit: 0 23 | -------------------------------------------------------------------------------- /examples/device-discovery/device-discovery-wddm.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running the device discovery test program inside a container 2 | # 3 | # This version of the Job requests a display device from the WDDM device plugin. 
For a version that 4 | # requests a compute-only device from the MCDM device plugin, see the file `device-discovery-mcdm.yml` 5 | 6 | apiVersion: batch/v1 7 | kind: Job 8 | metadata: 9 | name: example-device-discovery-wddm 10 | spec: 11 | template: 12 | spec: 13 | containers: 14 | - name: example-device-discovery-wddm 15 | image: "index.docker.io/tensorworks/example-device-discovery:0.0.1" 16 | resources: 17 | limits: 18 | directx.microsoft.com/display: 1 19 | nodeSelector: 20 | "kubernetes.io/os": windows 21 | restartPolicy: Never 22 | backoffLimit: 0 23 | -------------------------------------------------------------------------------- /examples/directml/directml-mcdm.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running a DirectML sample inside a container 2 | # 3 | # This version of the Job requests a compute-only device from the MCDM device plugin. For a version that 4 | # requests a display device from the WDDM device plugin, see the file `directml-wddm.yml` 5 | 6 | apiVersion: batch/v1 7 | kind: Job 8 | metadata: 9 | name: example-directml-mcdm 10 | spec: 11 | template: 12 | spec: 13 | containers: 14 | - name: example-directml-mcdm 15 | image: "index.docker.io/tensorworks/example-directml:0.0.1" 16 | resources: 17 | limits: 18 | directx.microsoft.com/compute: 1 19 | nodeSelector: 20 | "kubernetes.io/os": windows 21 | restartPolicy: Never 22 | backoffLimit: 0 23 | -------------------------------------------------------------------------------- /examples/directml/directml-wddm.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running a DirectML sample inside a container 2 | # 3 | # This version of the Job requests a display device from the WDDM device plugin. 
For a version that 4 | # requests a compute-only device from the MCDM device plugin, see the file `directml-mcdm.yml` 5 | 6 | apiVersion: batch/v1 7 | kind: Job 8 | metadata: 9 | name: example-directml-wddm 10 | spec: 11 | template: 12 | spec: 13 | containers: 14 | - name: example-directml-wddm 15 | image: "index.docker.io/tensorworks/example-directml:0.0.1" 16 | resources: 17 | limits: 18 | directx.microsoft.com/display: 1 19 | nodeSelector: 20 | "kubernetes.io/os": windows 21 | restartPolicy: Never 22 | backoffLimit: 0 23 | -------------------------------------------------------------------------------- /examples/ffmpeg-amf/ffmpeg-amf.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running an AMD AMF transcode operation with FFmpeg inside a container 2 | # 3 | # NOTE: this Job will only work when the device allocated by the WDDM device plugin is an AMD GPU, 4 | # otherwise the DLL files for AMF won't exist and FFmpeg will fail when it tries to load them. 
5 | 6 | apiVersion: batch/v1 7 | kind: Job 8 | metadata: 9 | name: example-ffmpeg-amf 10 | spec: 11 | template: 12 | spec: 13 | containers: 14 | - name: example-ffmpeg-amf 15 | image: "index.docker.io/tensorworks/example-ffmpeg:0.0.1" 16 | args: ["-i", "C:\\sample-video.mp4", "-c:v", "h264_amf", "-preset", "default", "C:\\output.mp4"] 17 | resources: 18 | limits: 19 | directx.microsoft.com/display: 1 20 | nodeSelector: 21 | "kubernetes.io/os": windows 22 | restartPolicy: Never 23 | backoffLimit: 0 24 | -------------------------------------------------------------------------------- /examples/ffmpeg-autodetect/autodetect-encoder.ps1: -------------------------------------------------------------------------------- 1 | # Attempt to detect the availability of a hardware video encoder 2 | $encoder = '' 3 | if ((Get-ChildItem "C:\Windows\System32\amfrt64.dll" -ErrorAction SilentlyContinue)) 4 | { 5 | Write-Host 'Detected an AMD GPU, using the AMF video encoder' 6 | $encoder = 'h264_amf' 7 | } 8 | elseif ((Get-ChildItem "C:\Windows\System32\intel_gfx_api-x64.dll" -ErrorAction SilentlyContinue)) 9 | { 10 | Write-Host 'Detected an Intel GPU, using the Quick Sync video encoder' 11 | $encoder = 'h264_qsv' 12 | } 13 | elseif ((Get-ChildItem "C:\Windows\System32\nvEncodeAPI64.dll" -ErrorAction SilentlyContinue)) 14 | { 15 | Write-Host 'Detected an NVIDIA GPU, using the NVENC video encoder' 16 | $encoder = 'h264_nvenc' 17 | } 18 | else { 19 | throw "Failed to detect the availability of a supported hardware video encoder" 20 | } 21 | 22 | # Invoke FFmpeg with the detected hardware video encoder 23 | & C:\ffmpeg.exe -i C:\sample-video.mp4 -c:v "$encoder" -preset default C:\output.mp4 24 | if ($LastExitCode -ne 0) { 25 | throw "FFmpeg terminated with exit code $LastExitCode" 26 | } 27 | -------------------------------------------------------------------------------- /examples/ffmpeg-autodetect/ffmpeg-autodetect.yml: 
-------------------------------------------------------------------------------- 1 | # Example Job for running a hardware accelerated transcode operation with FFmpeg inside a container 2 | # 3 | # The transcode script will attempt to detect the availability of the following encoders: 4 | # 5 | # - AMD AMF 6 | # - Intel Quick Sync 7 | # - NVIDIA NVENC 8 | # 9 | # If a hardware encoder is detected then it will be used, otherwise the script will fail. 10 | # 11 | # NOTE: this Job will only work when the device allocated by the WDDM device plugin is an AMD, Intel or NVIDIA GPU, 12 | # otherwise the DLL files for the hardware encoders won't exist and the script will fail when no encoder is detected. 13 | 14 | apiVersion: batch/v1 15 | kind: Job 16 | metadata: 17 | name: example-ffmpeg-autodetect 18 | spec: 19 | template: 20 | spec: 21 | containers: 22 | - name: example-ffmpeg-autodetect 23 | image: "index.docker.io/tensorworks/example-ffmpeg:0.0.1" 24 | command: ["powershell"] 25 | args: ["-ExecutionPolicy", "Bypass", "-File", "C:\\autodetect-encoder.ps1"] 26 | resources: 27 | limits: 28 | directx.microsoft.com/display: 1 29 | nodeSelector: 30 | "kubernetes.io/os": windows 31 | restartPolicy: Never 32 | backoffLimit: 0 33 | -------------------------------------------------------------------------------- /examples/ffmpeg-nvenc/ffmpeg-nvenc.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running an NVIDIA NVENC transcode operation with FFmpeg inside a container 2 | # 3 | # NOTE: this Job will only work when the device allocated by the WDDM device plugin is an NVIDIA GPU, 4 | # otherwise the DLL files for CUDA and NVENC won't exist and FFmpeg will fail when it tries to load them. 
5 | 6 | apiVersion: batch/v1 7 | kind: Job 8 | metadata: 9 | name: example-ffmpeg-nvenc 10 | spec: 11 | template: 12 | spec: 13 | containers: 14 | - name: example-ffmpeg-nvenc 15 | image: "index.docker.io/tensorworks/example-ffmpeg:0.0.1" 16 | args: ["-i", "C:\\sample-video.mp4", "-c:v", "h264_nvenc", "-preset", "default", "C:\\output.mp4"] 17 | resources: 18 | limits: 19 | directx.microsoft.com/display: 1 20 | nodeSelector: 21 | "kubernetes.io/os": windows 22 | restartPolicy: Never 23 | backoffLimit: 0 24 | -------------------------------------------------------------------------------- /examples/ffmpeg-quicksync/ffmpeg-quicksync.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running an Intel Quick Sync transcode operation with FFmpeg inside a container 2 | # 3 | # NOTE: this Job will only work when the device allocated by the WDDM device plugin is an Intel GPU, 4 | # otherwise the DLL files for Quick Sync won't exist and FFmpeg will fail when it tries to load them. 5 | 6 | apiVersion: batch/v1 7 | kind: Job 8 | metadata: 9 | name: example-ffmpeg-quicksync 10 | spec: 11 | template: 12 | spec: 13 | containers: 14 | - name: example-ffmpeg-quicksync 15 | image: "index.docker.io/tensorworks/example-ffmpeg:0.0.1" 16 | args: ["-i", "C:\\sample-video.mp4", "-c:v", "h264_qsv", "-preset", "default", "C:\\output.mp4"] 17 | resources: 18 | limits: 19 | directx.microsoft.com/display: 1 20 | nodeSelector: 21 | "kubernetes.io/os": windows 22 | restartPolicy: Never 23 | backoffLimit: 0 24 | -------------------------------------------------------------------------------- /examples/nvidia-smi/nvidia-smi-mcdm.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running the NVIDIA SMI tool inside a container 2 | # 3 | # This version of the Job requests a compute-only device from the MCDM device plugin. 
For a version that 4 | # requests a display device from the WDDM device plugin, see the file `nvidia-smi-wddm.yml` 5 | # 6 | # NOTE: this Job will only work when the device allocated by the MCDM device plugin is an NVIDIA GPU, 7 | # otherwise the executable `nvidia-smi.exe` won't exist and the Pod will fail to start. 8 | 9 | apiVersion: batch/v1 10 | kind: Job 11 | metadata: 12 | name: example-nvidia-smi 13 | spec: 14 | template: 15 | spec: 16 | containers: 17 | - name: example-nvidia-smi 18 | image: "mcr.microsoft.com/windows/servercore:ltsc2022" 19 | command: ["nvidia-smi.exe"] 20 | resources: 21 | limits: 22 | directx.microsoft.com/compute: 1 23 | nodeSelector: 24 | "kubernetes.io/os": windows 25 | restartPolicy: Never 26 | backoffLimit: 0 27 | -------------------------------------------------------------------------------- /examples/nvidia-smi/nvidia-smi-wddm.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running the NVIDIA SMI tool inside a container 2 | # 3 | # This version of the Job requests a display device from the WDDM device plugin. For a version that 4 | # requests a compute-only device from the MCDM device plugin, see the file `nvidia-smi-mcdm.yml` 5 | # 6 | # NOTE: this Job will only work when the device allocated by the WDDM device plugin is an NVIDIA GPU, 7 | # otherwise the executable `nvidia-smi.exe` won't exist and the Pod will fail to start. 
8 | 9 | apiVersion: batch/v1 10 | kind: Job 11 | metadata: 12 | name: example-nvidia-smi 13 | spec: 14 | template: 15 | spec: 16 | containers: 17 | - name: example-nvidia-smi 18 | image: "mcr.microsoft.com/windows/servercore:ltsc2022" 19 | command: ["nvidia-smi.exe"] 20 | resources: 21 | limits: 22 | directx.microsoft.com/display: 1 23 | nodeSelector: 24 | "kubernetes.io/os": windows 25 | restartPolicy: Never 26 | backoffLimit: 0 27 | -------------------------------------------------------------------------------- /examples/opencl-enum/opencl-enum-mcdm.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running the OpenCL enumopencl sample program inside a container 2 | # 3 | # This version of the Job requests a compute-only device from the MCDM device plugin. For a version that 4 | # requests a display device from the WDDM device plugin, see the file `opencl-enum-wddm.yml` 5 | # 6 | # NOTE: this Job will only work when the device allocated by the MCDM device plugin is a GPU that supports 7 | # OpenCL, otherwise the DLL files required by `enumopencl.exe` won't exist and the Pod will fail to start. 8 | 9 | apiVersion: batch/v1 10 | kind: Job 11 | metadata: 12 | name: example-opencl-enum-mcdm 13 | spec: 14 | template: 15 | spec: 16 | containers: 17 | - name: example-opencl-enum-mcdm 18 | image: "index.docker.io/tensorworks/example-opencl-enum:0.0.1" 19 | resources: 20 | limits: 21 | directx.microsoft.com/compute: 1 22 | nodeSelector: 23 | "kubernetes.io/os": windows 24 | restartPolicy: Never 25 | backoffLimit: 0 26 | -------------------------------------------------------------------------------- /examples/opencl-enum/opencl-enum-wddm.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running the OpenCL enumopencl sample program inside a container 2 | # 3 | # This version of the Job requests a display device from the WDDM device plugin.
For a version that 4 | # requests a compute-only device from the MCDM device plugin, see the file `opencl-enum-mcdm.yml` 5 | # 6 | # NOTE: this Job will only work when the device allocated by the WDDM device plugin is a GPU that supports 7 | # OpenCL, otherwise the DLL files required by `enumopencl.exe` won't exist and the Pod will fail to start. 8 | 9 | apiVersion: batch/v1 10 | kind: Job 11 | metadata: 12 | name: example-opencl-enum-wddm 13 | spec: 14 | template: 15 | spec: 16 | containers: 17 | - name: example-opencl-enum-wddm 18 | image: "index.docker.io/tensorworks/example-opencl-enum:0.0.1" 19 | resources: 20 | limits: 21 | directx.microsoft.com/display: 1 22 | nodeSelector: 23 | "kubernetes.io/os": windows 24 | restartPolicy: Never 25 | backoffLimit: 0 26 | -------------------------------------------------------------------------------- /examples/vulkaninfo/vulkaninfo.yml: -------------------------------------------------------------------------------- 1 | # Example Job for running the Vulkan information tool inside a container 2 | # 3 | # NOTE: this Job will only work when the device allocated by the WDDM device plugin is a GPU that supports 4 | # Vulkan, otherwise the executable `vulkaninfo.exe` won't exist and the Pod will fail to start.
5 | 6 | apiVersion: batch/v1 7 | kind: Job 8 | metadata: 9 | name: example-vulkaninfo 10 | spec: 11 | template: 12 | spec: 13 | containers: 14 | - name: example-vulkaninfo 15 | image: "mcr.microsoft.com/windows/server:ltsc2022" 16 | command: ["vulkaninfo.exe"] 17 | resources: 18 | limits: 19 | directx.microsoft.com/display: 1 20 | nodeSelector: 21 | "kubernetes.io/os": windows 22 | restartPolicy: Never 23 | backoffLimit: 0 24 | -------------------------------------------------------------------------------- /external/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /library/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.22) 2 | project(directx-device-discovery) 3 | 4 | # Set the C++ standard to C++17 5 | set(CMAKE_CXX_STANDARD 17) 6 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 7 | set(CMAKE_CXX_EXTENSIONS OFF) 8 | 9 | # Locate our dependencies (these will be provided by vcpkg) 10 | find_package(cppwinrt CONFIG REQUIRED) 11 | find_package(fmt CONFIG REQUIRED) 12 | find_package(spdlog CONFIG REQUIRED) 13 | find_package(wil CONFIG REQUIRED) 14 | 15 | # Build our shared library 16 | add_library(directx-device-discovery SHARED 17 | src/AdapterEnumeration.cpp 18 | src/D3DHelpers.cpp 19 | src/DeviceDiscovery.cpp 20 | src/DeviceDiscoveryImp.cpp 21 | src/DllMain.cpp 22 | src/ErrorHandling.cpp 23 | src/RegistryQuery.cpp 24 | src/SafeArray.cpp 25 | src/WmiQuery.cpp 26 | ) 27 | target_link_libraries(directx-device-discovery PRIVATE 28 | dxcore.lib 29 | dxguid.lib 30 | fmt::fmt-header-only 31 | gdi32.lib 32 | Microsoft::CppWinRT 33 | spdlog::spdlog_header_only 34 | wbemuuid.lib 35 | WIL::WIL 36 | WindowsApp.lib 37 | ) 38 | set_property(TARGET directx-device-discovery PROPERTY MSVC_RUNTIME_LIBRARY "MultiThreaded") 39 | 
target_include_directories(directx-device-discovery PUBLIC include) 40 | target_precompile_headers(directx-device-discovery PRIVATE src/pch.h) 41 | 42 | # Build our test executable 43 | add_executable(test-device-discovery-cpp test/test-device-discovery-cpp.cpp) 44 | set_property(TARGET test-device-discovery-cpp PROPERTY MSVC_RUNTIME_LIBRARY "MultiThreaded") 45 | target_link_libraries(test-device-discovery-cpp PRIVATE Microsoft::CppWinRT directx-device-discovery) 46 | 47 | # Install the shared library and the test executable to the top-level bin directory 48 | install( 49 | TARGETS directx-device-discovery test-device-discovery-cpp 50 | RUNTIME DESTINATION bin 51 | ) 52 | -------------------------------------------------------------------------------- /library/include/DeviceDiscovery.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "DeviceFilter.h" 3 | 4 | #define DLLEXPORT __declspec(dllexport) 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | // Opaque pointer type for DeviceDiscovery instances 11 | typedef void* DeviceDiscoveryInstance; 12 | 13 | // Returns the version string for the device discovery library 14 | DLLEXPORT const wchar_t* GetDiscoveryLibraryVersion(); 15 | 16 | // Disables verbose logging for the device discovery library (this is the default) 17 | DLLEXPORT void DisableDiscoveryLogging(); 18 | 19 | // Enables verbose logging for the device discovery library 20 | DLLEXPORT void EnableDiscoveryLogging(); 21 | 22 | // Creates a new DeviceDiscovery instance 23 | DLLEXPORT DeviceDiscoveryInstance CreateDeviceDiscoveryInstance(); 24 | 25 | // Frees the memory for a DeviceDiscovery instance 26 | DLLEXPORT void DestroyDeviceDiscoveryInstance(DeviceDiscoveryInstance instance); 27 | 28 | // Retrieves the error message for the last operation performed by the DeviceDiscovery instance. 29 | // If the last operation succeeded then an empty string will be returned. 
30 | DLLEXPORT const wchar_t* DeviceDiscovery_GetLastErrorMessage(DeviceDiscoveryInstance instance); 31 | 32 | // Determines whether the current device list is stale and needs to be refreshed by performing device discovery again 33 | DLLEXPORT int DeviceDiscovery_IsRefreshRequired(DeviceDiscoveryInstance instance); 34 | 35 | // Performs device discovery. Returns 0 on success and -1 on failure. 36 | // Call GetLastErrorMessage to retrieve the error details for a failure. 37 | DLLEXPORT int DeviceDiscovery_DiscoverDevices(DeviceDiscoveryInstance instance, int filter, int includeIntegrated, int includeDetachable); 38 | 39 | // Returns the number of devices found by the last device discovery, or -1 if device discovery has not been performed 40 | DLLEXPORT int DeviceDiscovery_GetNumDevices(DeviceDiscoveryInstance instance); 41 | 42 | DLLEXPORT long long DeviceDiscovery_GetDeviceAdapterLUID(DeviceDiscoveryInstance instance, unsigned int device); 43 | 44 | // Returns the unique ID of the device with the specified index, or a NULL pointer if the specified device index is invalid 45 | DLLEXPORT const wchar_t* DeviceDiscovery_GetDeviceID(DeviceDiscoveryInstance instance, unsigned int device); 46 | 47 | DLLEXPORT const wchar_t* DeviceDiscovery_GetDeviceDescription(DeviceDiscoveryInstance instance, unsigned int device); 48 | 49 | DLLEXPORT const wchar_t* DeviceDiscovery_GetDeviceDriverRegistryKey(DeviceDiscoveryInstance instance, unsigned int device); 50 | 51 | DLLEXPORT const wchar_t* DeviceDiscovery_GetDeviceDriverStorePath(DeviceDiscoveryInstance instance, unsigned int device); 52 | 53 | DLLEXPORT const wchar_t* DeviceDiscovery_GetDeviceLocationPath(DeviceDiscoveryInstance instance, unsigned int device); 54 | 55 | DLLEXPORT const wchar_t* DeviceDiscovery_GetDeviceVendor(DeviceDiscoveryInstance instance, unsigned int device); 56 | 57 | DLLEXPORT int DeviceDiscovery_GetNumRuntimeFiles(DeviceDiscoveryInstance instance, unsigned int device); 58 | 59 | DLLEXPORT const wchar_t* 
DeviceDiscovery_GetRuntimeFileSource(DeviceDiscoveryInstance instance, unsigned int device, unsigned int file); 60 | 61 | DLLEXPORT const wchar_t* DeviceDiscovery_GetRuntimeFileDestination(DeviceDiscoveryInstance instance, unsigned int device, unsigned int file); 62 | 63 | DLLEXPORT int DeviceDiscovery_GetNumRuntimeFilesWow64(DeviceDiscoveryInstance instance, unsigned int device); 64 | 65 | DLLEXPORT const wchar_t* DeviceDiscovery_GetRuntimeFileSourceWow64(DeviceDiscoveryInstance instance, unsigned int device, unsigned int file); 66 | 67 | DLLEXPORT const wchar_t* DeviceDiscovery_GetRuntimeFileDestinationWow64(DeviceDiscoveryInstance instance, unsigned int device, unsigned int file); 68 | 69 | DLLEXPORT int DeviceDiscovery_IsDeviceIntegrated(DeviceDiscoveryInstance instance, unsigned int device); 70 | 71 | DLLEXPORT int DeviceDiscovery_IsDeviceDetachable(DeviceDiscoveryInstance instance, unsigned int device); 72 | 73 | DLLEXPORT int DeviceDiscovery_DoesDeviceSupportDisplay(DeviceDiscoveryInstance instance, unsigned int device); 74 | 75 | DLLEXPORT int DeviceDiscovery_DoesDeviceSupportCompute(DeviceDiscoveryInstance instance, unsigned int device); 76 | 77 | #ifdef __cplusplus 78 | } // extern "C" 79 | 80 | 81 | #include 82 | #include 83 | 84 | // API wrapper classes for C++ clients 85 | 86 | class DeviceDiscoveryException 87 | { 88 | public: 89 | DeviceDiscoveryException(const wchar_t* message) { 90 | this->message = message; 91 | } 92 | 93 | DeviceDiscoveryException(const DeviceDiscoveryException& other) = default; 94 | DeviceDiscoveryException(DeviceDiscoveryException&& other) = default; 95 | DeviceDiscoveryException& operator=(const DeviceDiscoveryException& other) = default; 96 | DeviceDiscoveryException& operator=(DeviceDiscoveryException&& other) = default; 97 | 98 | std::wstring what() const { 99 | return this->message; 100 | } 101 | 102 | private: 103 | std::wstring message; 104 | }; 105 | 106 | class DeviceDiscovery 107 | { 108 | private: 109 | 
DeviceDiscoveryInstance instance; 110 | 111 | public: 112 | 113 | inline DeviceDiscovery() { 114 | this->instance = CreateDeviceDiscoveryInstance(); 115 | } 116 | 117 | inline ~DeviceDiscovery() 118 | { 119 | DestroyDeviceDiscoveryInstance(this->instance); 120 | this->instance = nullptr; 121 | } 122 | 123 | inline const wchar_t* GetLastErrorMessage() { 124 | return DeviceDiscovery_GetLastErrorMessage(this->instance); 125 | } 126 | 127 | inline bool IsRefreshRequired() { 128 | return DeviceDiscovery_IsRefreshRequired(this->instance); 129 | } 130 | 131 | #define THROW_IF_ERROR(sentinel) if (result == sentinel) { throw DeviceDiscoveryException(DeviceDiscovery_GetLastErrorMessage(this->instance)); } 132 | 133 | inline bool DiscoverDevices(DeviceFilter filter, bool includeIntegrated, bool includeDetachable) 134 | { 135 | int result = DeviceDiscovery_DiscoverDevices(this->instance, static_cast(filter), includeIntegrated, includeDetachable); 136 | THROW_IF_ERROR(-1); 137 | return (result == 0); 138 | } 139 | 140 | inline int GetNumDevices() 141 | { 142 | int result = DeviceDiscovery_GetNumDevices(this->instance); 143 | THROW_IF_ERROR(-1); 144 | return result; 145 | } 146 | 147 | inline long long GetDeviceAdapterLUID(unsigned int device) 148 | { 149 | long long result = DeviceDiscovery_GetDeviceAdapterLUID(this->instance, device); 150 | THROW_IF_ERROR(-1); 151 | return result; 152 | } 153 | 154 | inline const wchar_t* GetDeviceID(unsigned int device) 155 | { 156 | const wchar_t* result = DeviceDiscovery_GetDeviceID(this->instance, device); 157 | THROW_IF_ERROR(nullptr); 158 | return result; 159 | } 160 | 161 | inline const wchar_t* GetDeviceDescription(unsigned int device) 162 | { 163 | const wchar_t* result = DeviceDiscovery_GetDeviceDescription(this->instance, device); 164 | THROW_IF_ERROR(nullptr); 165 | return result; 166 | } 167 | 168 | inline const wchar_t* GetDeviceDriverRegistryKey(unsigned int device) 169 | { 170 | const wchar_t* result = 
DeviceDiscovery_GetDeviceDriverRegistryKey(this->instance, device); 171 | THROW_IF_ERROR(nullptr); 172 | return result; 173 | } 174 | 175 | inline const wchar_t* GetDeviceDriverStorePath(unsigned int device) 176 | { 177 | const wchar_t* result = DeviceDiscovery_GetDeviceDriverStorePath(this->instance, device); 178 | THROW_IF_ERROR(nullptr); 179 | return result; 180 | } 181 | 182 | inline const wchar_t* GetDeviceLocationPath(unsigned int device) 183 | { 184 | const wchar_t* result = DeviceDiscovery_GetDeviceLocationPath(this->instance, device); 185 | THROW_IF_ERROR(nullptr); 186 | return result; 187 | } 188 | 189 | inline const wchar_t* GetDeviceVendor(unsigned int device) 190 | { 191 | const wchar_t* result = DeviceDiscovery_GetDeviceVendor(this->instance, device); 192 | THROW_IF_ERROR(nullptr); 193 | return result; 194 | } 195 | 196 | inline int GetNumRuntimeFiles(unsigned int device) 197 | { 198 | int result = DeviceDiscovery_GetNumRuntimeFiles(this->instance, device); 199 | THROW_IF_ERROR(-1); 200 | return result; 201 | } 202 | 203 | inline const wchar_t* GetRuntimeFileSource(unsigned int device, unsigned int file) 204 | { 205 | const wchar_t* result = DeviceDiscovery_GetRuntimeFileSource(this->instance, device, file); 206 | THROW_IF_ERROR(nullptr); 207 | return result; 208 | } 209 | 210 | inline const wchar_t* GetRuntimeFileDestination(unsigned int device, unsigned int file) 211 | { 212 | const wchar_t* result = DeviceDiscovery_GetRuntimeFileDestination(this->instance, device, file); 213 | THROW_IF_ERROR(nullptr); 214 | return result; 215 | } 216 | 217 | inline int GetNumRuntimeFilesWow64(unsigned int device) 218 | { 219 | int result = DeviceDiscovery_GetNumRuntimeFilesWow64(this->instance, device); 220 | THROW_IF_ERROR(-1); 221 | return result; 222 | } 223 | 224 | inline const wchar_t* GetRuntimeFileSourceWow64(unsigned int device, unsigned int file) 225 | { 226 | const wchar_t* result = DeviceDiscovery_GetRuntimeFileSourceWow64(this->instance, device, file); 
227 | THROW_IF_ERROR(nullptr); 228 | return result; 229 | } 230 | 231 | inline const wchar_t* GetRuntimeFileDestinationWow64(unsigned int device, unsigned int file) 232 | { 233 | const wchar_t* result = DeviceDiscovery_GetRuntimeFileDestinationWow64(this->instance, device, file); 234 | THROW_IF_ERROR(nullptr); 235 | return result; 236 | } 237 | 238 | inline bool IsDeviceIntegrated(unsigned int device) 239 | { 240 | int result = DeviceDiscovery_IsDeviceIntegrated(this->instance, device); 241 | THROW_IF_ERROR(-1); 242 | return result; 243 | } 244 | 245 | inline bool IsDeviceDetachable(unsigned int device) 246 | { 247 | int result = DeviceDiscovery_IsDeviceDetachable(this->instance, device); 248 | THROW_IF_ERROR(-1); 249 | return result; 250 | } 251 | 252 | inline bool DoesDeviceSupportDisplay(unsigned int device) 253 | { 254 | int result = DeviceDiscovery_DoesDeviceSupportDisplay(this->instance, device); 255 | THROW_IF_ERROR(-1); 256 | return result; 257 | } 258 | 259 | inline bool DoesDeviceSupportCompute(unsigned int device) 260 | { 261 | int result = DeviceDiscovery_DoesDeviceSupportCompute(this->instance, device); 262 | THROW_IF_ERROR(-1); 263 | return result; 264 | } 265 | 266 | #undef THROW_IF_ERROR 267 | }; 268 | 269 | #endif 270 | -------------------------------------------------------------------------------- /library/include/DeviceFilter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // Enumerate all devices 4 | #define DEVICEFILTER_ALL 0 5 | 6 | // Enumerate devices that support display, irrespective of whether they also support compute 7 | #define DEVICEFILTER_DISPLAY_SUPPORTED 1 8 | 9 | // Enumerate devices that support compute, irrespective of whether they also support display 10 | #define DEVICEFILTER_COMPUTE_SUPPORTED 2 11 | 12 | // Enumerate devices that support display and do not support compute (e.g. 
legacy DirectX 11 devices) 13 | #define DEVICEFILTER_DISPLAY_ONLY 3 14 | 15 | // Enumerate devices that support compute and do not support display (i.e. compute-only DirectX 12 devices) 16 | #define DEVICEFILTER_COMPUTE_ONLY 4 17 | 18 | // Enumerate devices that support both display and compute (i.e. fully-featured DirectX 12 devices) 19 | #define DEVICEFILTER_DISPLAY_AND_COMPUTE 5 20 | 21 | 22 | #ifdef __cplusplus 23 | 24 | #include 25 | 26 | // Device filter enum for C++ clients 27 | enum class DeviceFilter : int 28 | { 29 | AllDevices = DEVICEFILTER_ALL, 30 | DisplaySupported = DEVICEFILTER_DISPLAY_SUPPORTED, 31 | ComputeSupported = DEVICEFILTER_COMPUTE_SUPPORTED, 32 | DisplayOnly = DEVICEFILTER_DISPLAY_ONLY, 33 | ComputeOnly = DEVICEFILTER_COMPUTE_ONLY, 34 | DisplayAndCompute = DEVICEFILTER_DISPLAY_AND_COMPUTE 35 | }; 36 | 37 | // Returns a string representation of a device filter 38 | inline std::wstring DeviceFilterName(DeviceFilter filter) 39 | { 40 | switch (filter) 41 | { 42 | case DeviceFilter::AllDevices: 43 | return L"AllDevices"; 44 | 45 | case DeviceFilter::DisplaySupported: 46 | return L"DisplaySupported"; 47 | 48 | case DeviceFilter::ComputeSupported: 49 | return L"ComputeSupported"; 50 | 51 | case DeviceFilter::DisplayOnly: 52 | return L"DisplayOnly"; 53 | 54 | case DeviceFilter::ComputeOnly: 55 | return L"ComputeOnly"; 56 | 57 | case DeviceFilter::DisplayAndCompute: 58 | return L"DisplayAndCompute"; 59 | 60 | default: 61 | return L""; 62 | } 63 | } 64 | 65 | #endif 66 | -------------------------------------------------------------------------------- /library/src/Adapter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // Represents a DirectX adapter as enumerated by DXCore 4 | // (For additional details, see: ) 5 | struct Adapter 6 | { 7 | inline Adapter() : 8 | InstanceLuid(0), 9 | IsHardware(false), 10 | IsIntegrated(false), 11 | IsDetachable(false), 12 | SupportsDisplay(false), 13 | 
SupportsCompute(false) 14 | {} 15 | 16 | // The locally unique identifier (LUID) for the adapter 17 | int64_t InstanceLuid; 18 | 19 | // The PnP hardware ID information for the adapter 20 | DXCoreHardwareID HardwareID; 21 | 22 | // Specifies whether the adapter is a hardware device (as opposed to a software device) 23 | bool IsHardware; 24 | 25 | // Specifies whether the adapter is an integrated GPU (as opposed to a discrete GPU) 26 | bool IsIntegrated; 27 | 28 | // Specifies whether the adapter is a detachable device (i.e. the device can be removed at runtime) 29 | bool IsDetachable; 30 | 31 | // Specifies whether the adapter supports display 32 | // (i.e. supports either the DXCORE_ADAPTER_ATTRIBUTE_D3D11_GRAPHICS or DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS attributes) 33 | bool SupportsDisplay; 34 | 35 | // Specifies whether the adapter supports compute (i.e. supports the DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE attribute) 36 | bool SupportsCompute; 37 | }; 38 | -------------------------------------------------------------------------------- /library/src/AdapterEnumeration.cpp: -------------------------------------------------------------------------------- 1 | #include "AdapterEnumeration.h" 2 | #include "ErrorHandling.h" 3 | #include "ObjectHelpers.h" 4 | 5 | #include 6 | 7 | AdapterEnumeration::AdapterEnumeration() 8 | { 9 | // Create our DXCore adapter factory 10 | auto error = CheckHresult(DXCoreCreateAdapterFactory(this->adapterFactory.put())); 11 | if (error) { 12 | throw error.Wrap(L"DXCoreCreateAdapterFactory failed"); 13 | } 14 | } 15 | 16 | void AdapterEnumeration::EnumerateAdapters(const DeviceFilter& filter, bool includeIntegrated, bool includeDetachable) 17 | { 18 | // Log our enumeration parameters 19 | LOG( 20 | L"Enumerating DirectX adapters using parameters: {{ filter:{}, includeIntegrated:{}, includeDetachable:{} }}", 21 | DeviceFilterName(filter), 22 | includeIntegrated, 23 | includeDetachable 24 | ); 25 | 26 | // Clear our adapter lists 
and our set of unique adapters 27 | this->adapterLists.clear(); 28 | this->uniqueAdapters.clear(); 29 | 30 | #define ENUMERATE_ADAPTERS(attribute)\ 31 | {\ 32 | GUID attributes[]{ attribute };\ 33 | this->adapterLists.push_back(nullptr); \ 34 | auto error = CheckHresult(this->adapterFactory->CreateAdapterList(_countof(attributes), attributes, this->adapterLists.back().put()));\ 35 | if (error) { \ 36 | throw error.Wrap(L"IDXCoreAdapterFactory::CreateAdapterList() failed for attribute " + wstring(L#attribute));\ 37 | }\ 38 | } 39 | 40 | // Enumerate adapters that support Direct3D 11 41 | if (filter != DeviceFilter::ComputeOnly && filter != DeviceFilter::DisplayAndCompute) { 42 | ENUMERATE_ADAPTERS(DXCORE_ADAPTER_ATTRIBUTE_D3D11_GRAPHICS); 43 | } 44 | 45 | // Enumerate adapters that support Direct3D 12 46 | if (filter != DeviceFilter::ComputeOnly) { 47 | ENUMERATE_ADAPTERS(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS); 48 | } 49 | 50 | // Enumerate adapters that support Direct3D 12 Core 51 | if (filter != DeviceFilter::DisplayOnly) { 52 | ENUMERATE_ADAPTERS(DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE); 53 | } 54 | 55 | #undef ENUMERATE_ADAPTERS 56 | 57 | // Process each of the enumerated adapters and apply our filtering criteria 58 | for (auto const& adapters : this->adapterLists) 59 | { 60 | const uint32_t count = adapters->GetAdapterCount(); 61 | for (uint32_t index = 0; index < count; ++index) 62 | { 63 | // Extract the details for the current adapter 64 | com_ptr adapter; 65 | auto error = CheckHresult(adapters->GetAdapter(index, adapter.put())); 66 | if (error) { 67 | throw error.Wrap(L"IDXCoreAdapterList::GetAdapter() failed for index " + std::to_wstring(index)); 68 | } 69 | Adapter details = this->ExtractAdapterDetails(adapter); 70 | 71 | // Ignore software devices 72 | if (!details.IsHardware) { 73 | continue; 74 | } 75 | 76 | // If the adapter does not match our filter mode then ignore it 77 | if ((filter == DeviceFilter::DisplayOnly && details.SupportsCompute) 
|| 78 | (filter == DeviceFilter::ComputeOnly && details.SupportsDisplay) || 79 | (filter == DeviceFilter::DisplayAndCompute && (!details.SupportsDisplay || !details.SupportsCompute))) { 80 | continue; 81 | } 82 | 83 | // If the adapter is integrated and we are not including integrated devices then ignore it 84 | if (details.IsIntegrated && !includeIntegrated) { 85 | continue; 86 | } 87 | 88 | // If the adapter is detachable and we are not including detachable devices then ignore it 89 | if (details.IsDetachable && !includeDetachable) { 90 | continue; 91 | } 92 | 93 | // Add the adapter to our set of unique adapters 94 | this->uniqueAdapters.insert(std::make_pair(details.InstanceLuid, details)); 95 | } 96 | } 97 | 98 | // Log the list of unique adapter LUIDs 99 | LOG(L"Enumerated DirectX adapters with LUIDs: {}", FMT(ObjectHelpers::GetMappingKeys(this->uniqueAdapters))); 100 | } 101 | 102 | const map& AdapterEnumeration::GetUniqueAdapters() const { 103 | return this->uniqueAdapters; 104 | } 105 | 106 | bool AdapterEnumeration::IsStale() const 107 | { 108 | // If we have not yet performed enumeration then report that our data is stale 109 | if (this->adapterLists.empty()) 110 | { 111 | LOG(L"No adapter lists yet, need to perform enumeration"); 112 | return true; 113 | } 114 | 115 | // If any of our adapter lists are stale then our data is stale 116 | for (auto const& list : this->adapterLists) 117 | { 118 | if (list->IsStale()) 119 | { 120 | LOG(L"Found stale adapter list"); 121 | return true; 122 | } 123 | } 124 | 125 | return false; 126 | } 127 | 128 | Adapter AdapterEnumeration::ExtractAdapterDetails(const com_ptr& adapter) const 129 | { 130 | Adapter details; 131 | DeviceDiscoveryError error; 132 | 133 | // Extract the adapter LUID and convert it to an int64_t 134 | LUID instanceLuid; 135 | error = CheckHresult(adapter->GetProperty(DXCoreAdapterProperty::InstanceLuid, &instanceLuid)); 136 | if (error) { 137 | throw error.Wrap(L"IDXCoreAdapter::GetProperty() 
failed for property InstanceLuid"); 138 | } 139 | details.InstanceLuid = Int64FromLuid(instanceLuid); 140 | 141 | // Extract the PnP hardware ID information 142 | error = CheckHresult(adapter->GetProperty(DXCoreAdapterProperty::HardwareID, &details.HardwareID)); 143 | if (error) { 144 | throw error.Wrap(L"IDXCoreAdapter::GetProperty() failed for property HardwareID"); 145 | } 146 | 147 | // Extract the boolean specifying whether the adapter is a hardware device 148 | error = CheckHresult(adapter->GetProperty(DXCoreAdapterProperty::IsHardware, &details.IsHardware)); 149 | if (error) { 150 | throw error.Wrap(L"IDXCoreAdapter::GetProperty() failed for property IsHardware"); 151 | } 152 | 153 | // Extract the boolean specifying whether the adapter is an integrated GPU 154 | error = CheckHresult(adapter->GetProperty(DXCoreAdapterProperty::IsIntegrated, &details.IsIntegrated)); 155 | if (error) { 156 | throw error.Wrap(L"IDXCoreAdapter::GetProperty() failed for property IsIntegrated"); 157 | } 158 | 159 | // Extract the boolean specifying whether the adapter is detachable 160 | error = CheckHresult(adapter->GetProperty(DXCoreAdapterProperty::IsDetachable, &details.IsDetachable)); 161 | if (error) { 162 | throw error.Wrap(L"IDXCoreAdapter::GetProperty() failed for property IsDetachable"); 163 | } 164 | 165 | // Determine whether the adapter supports display 166 | details.SupportsDisplay = 167 | adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D11_GRAPHICS) || 168 | adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS); 169 | 170 | // Determine whether the adapter supports compute 171 | details.SupportsCompute = adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS); 172 | 173 | return details; 174 | } 175 | -------------------------------------------------------------------------------- /library/src/AdapterEnumeration.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | 
#include "Adapter.h" 4 | #include "DeviceFilter.h" 5 | 6 | using std::map; 7 | using std::vector; 8 | using winrt::com_ptr; 9 | 10 | class AdapterEnumeration 11 | { 12 | public: 13 | AdapterEnumeration(); 14 | 15 | // Enumerates the DirectX adapters that meet the specified filtering criteria 16 | void EnumerateAdapters(const DeviceFilter& filter, bool includeIntegrated, bool includeDetachable); 17 | 18 | // Retrieves the list of unique adapters retrieved during the last enumeration operation 19 | const map& GetUniqueAdapters() const; 20 | 21 | // Determines whether the list of adapters is stale and needs to be refreshed by performing enumeration again 22 | bool IsStale() const; 23 | 24 | private: 25 | 26 | // Extracts the details from a DXCore adapter object 27 | Adapter ExtractAdapterDetails(const com_ptr& adapter) const; 28 | 29 | // Our DXCore adapter factory 30 | com_ptr adapterFactory; 31 | 32 | // Our collection of DXCore adapter lists, used for enumerating adapters with various capabilities 33 | vector< com_ptr > adapterLists; 34 | 35 | // The list of unique adapters retrieved during the last enumeration operation, keyed by adapter LUID 36 | map uniqueAdapters; 37 | }; 38 | -------------------------------------------------------------------------------- /library/src/D3DHelpers.cpp: -------------------------------------------------------------------------------- 1 | #include "D3DHelpers.h" 2 | #include "ErrorHandling.h" 3 | #include "ObjectHelpers.h" 4 | 5 | QueryD3DRegistryInfo::QueryD3DRegistryInfo() 6 | { 7 | this->Resize(0); 8 | this->RegistryInfo->PhysicalAdapterIndex = 0; 9 | } 10 | 11 | void QueryD3DRegistryInfo::SetFilesystemQuery(D3DDDI_QUERYREGISTRY_TYPE queryType) 12 | { 13 | ZeroMemory(this->RegistryInfo->ValueName, sizeof(wchar_t) * MAX_PATH); 14 | this->RegistryInfo->QueryFlags.TranslatePath = 0; 15 | this->RegistryInfo->QueryType = queryType; 16 | this->RegistryInfo->ValueType = 0; 17 | } 18 | 19 | void 
QueryD3DRegistryInfo::SetAdapterKeyQuery(wstring_view name, ULONG valueType, bool translatePaths) 20 | { 21 | memcpy(this->RegistryInfo->ValueName, name.data(), sizeof(wchar_t) * name.size()); 22 | this->RegistryInfo->QueryFlags.TranslatePath = (translatePaths ? 1 : 0); 23 | this->RegistryInfo->QueryType = D3DDDI_QUERYREGISTRY_ADAPTERKEY; 24 | this->RegistryInfo->ValueType = valueType; 25 | } 26 | 27 | void QueryD3DRegistryInfo::Resize(size_t trailingBuffer) 28 | { 29 | // Allocate memory for the new struct + buffer 30 | this->PrivateDataSize = sizeof(D3DDDI_QUERYREGISTRY_INFO) + trailingBuffer; 31 | auto newData = std::make_unique(this->PrivateDataSize); 32 | 33 | // If we have existing struct values then copy them over to the new struct 34 | if (this->PrivateData) { 35 | memcpy(newData.get(), this->PrivateData.get(), sizeof(D3DDDI_QUERYREGISTRY_INFO)); 36 | } 37 | 38 | // Release the existing data (if any) and update our struct pointer 39 | this->PrivateData = std::move(newData); 40 | this->RegistryInfo = reinterpret_cast(this->PrivateData.get()); 41 | } 42 | 43 | void QueryD3DRegistryInfo::PerformQuery(unique_adapter_handle& adapter) 44 | { 45 | while (true) 46 | { 47 | // Attempt to perform the query 48 | auto adapterQuery = this->CreateAdapterQuery(adapter); 49 | auto error = CheckNtStatus(D3DKMTQueryAdapterInfo(&adapterQuery)); 50 | if (error) { 51 | throw error.Wrap(L"D3DKMTQueryAdapterInfo failed"); 52 | } 53 | 54 | // Determine whether we need to resize the trailing buffer and try again 55 | if (this->RegistryInfo->Status == D3DDDI_QUERYREGISTRY_STATUS_BUFFER_OVERFLOW) { 56 | this->Resize(this->RegistryInfo->OutputValueSize); 57 | } 58 | else { 59 | return; 60 | } 61 | } 62 | } 63 | 64 | D3DKMT_QUERYADAPTERINFO QueryD3DRegistryInfo::CreateAdapterQuery(unique_adapter_handle& adapter) 65 | { 66 | auto adapterQuery = ObjectHelpers::GetZeroedStruct(); 67 | adapterQuery.hAdapter = adapter.get(); 68 | adapterQuery.Type = KMTQAITYPE_QUERYREGISTRY; 69 | 
adapterQuery.pPrivateDriverData = this->PrivateData.get(); 70 | adapterQuery.PrivateDriverDataSize = this->PrivateDataSize; 71 | return adapterQuery; 72 | } 73 | -------------------------------------------------------------------------------- /library/src/D3DHelpers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | using std::wstring_view; 4 | 5 | 6 | // Closes the supplied DirectX adapter handle 7 | inline NTSTATUS CloseAdapter(D3DKMT_HANDLE adapter) 8 | { 9 | D3DKMT_CLOSEADAPTER close; 10 | close.hAdapter = adapter; 11 | return D3DKMTCloseAdapter(&close); 12 | } 13 | 14 | // Auto-releasing resource wrapper type for DirectX adapter handles 15 | typedef wil::unique_any unique_adapter_handle; 16 | 17 | 18 | // Encapsulates a D3DDDI_QUERYREGISTRY_INFO struct, along with its trailing buffer for receiving output data 19 | class QueryD3DRegistryInfo 20 | { 21 | public: 22 | 23 | // Use this to access the struct's member fields 24 | D3DDDI_QUERYREGISTRY_INFO* RegistryInfo; 25 | 26 | QueryD3DRegistryInfo(); 27 | 28 | // Populates the struct fields for querying a filesystem path 29 | void SetFilesystemQuery(D3DDDI_QUERYREGISTRY_TYPE queryType); 30 | 31 | // Populates the struct fields for querying a registry value from the adapter key 32 | void SetAdapterKeyQuery(wstring_view name, ULONG valueType, bool translatePaths); 33 | 34 | // Resizes the trailing buffer 35 | void Resize(size_t trailingBuffer); 36 | 37 | // Performs a registry query against the specified adapter, resizing the trailing buffer to accommodate the output data size as needed 38 | void PerformQuery(unique_adapter_handle& adapter); 39 | 40 | private: 41 | 42 | // The underlying data and size for the struct along with its trailing buffer 43 | std::unique_ptr PrivateData; 44 | size_t PrivateDataSize; 45 | 46 | // Creates a D3DKMT_QUERYADAPTERINFO struct that wraps struct and its trailing buffer 47 | D3DKMT_QUERYADAPTERINFO 
CreateAdapterQuery(unique_adapter_handle& adapter); 48 | }; 49 | -------------------------------------------------------------------------------- /library/src/Device.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Adapter.h" 4 | 5 | using std::vector; 6 | using std::wstring; 7 | 8 | 9 | // Represents an additional file that needs to be copied from the driver store to the system directory in order to use a device with non-DirectX runtimes 10 | // (For details, see: ) 11 | struct RuntimeFile 12 | { 13 | RuntimeFile(wstring SourcePath, wstring DestinationFilename) 14 | { 15 | this->SourcePath = SourcePath; 16 | this->DestinationFilename = DestinationFilename; 17 | 18 | // If no destination filename was specified then use the filename from the source path 19 | if (this->DestinationFilename.empty()) { 20 | this->DestinationFilename = std::filesystem::path(this->SourcePath).filename().wstring(); 21 | } 22 | } 23 | 24 | // The relative path to the file in the driver store 25 | wstring SourcePath; 26 | 27 | // The filename that the file should be given when copied to the destination directory 28 | wstring DestinationFilename; 29 | }; 30 | 31 | 32 | // Represents the underlying PnP device associated with a DirectX adapter 33 | struct Device 34 | { 35 | // The DirectX adapter associated with the PnP device 36 | Adapter DeviceAdapter; 37 | 38 | // The unique PNP hardware identifier for the device 39 | wstring ID; 40 | 41 | // A human-readable description of the device (e.g. 
the model name) 42 | wstring Description; 43 | 44 | // The registry key that contains the driver details for the device 45 | wstring DriverRegistryKey; 46 | 47 | // The absolute path to the directory in the driver store that contains the driver files for the device 48 | wstring DriverStorePath; 49 | 50 | // The path to the physical location of the device in the system 51 | wstring LocationPath; 52 | 53 | // The list of additional files that need to be copied from the driver store to the System32 directory in order to use the device with non-DirectX runtimes 54 | vector RuntimeFiles; 55 | 56 | // The list of additional files that need to be copied from the driver store to the SysWOW64 directory in order to use the device with non-DirectX runtimes 57 | vector RuntimeFilesWow64; 58 | 59 | // The vendor of the device (e.g. AMD, Intel, NVIDIA) 60 | wstring Vendor; 61 | }; 62 | -------------------------------------------------------------------------------- /library/src/DeviceDiscovery.cpp: -------------------------------------------------------------------------------- 1 | #include "DeviceDiscovery.h" 2 | #include "DeviceDiscoveryImp.h" 3 | 4 | #define LIBRARY_VERSION L"0.0.1" 5 | 6 | #define INSTANCE (reinterpret_cast(instance)) 7 | 8 | const wchar_t* GetDiscoveryLibraryVersion() { 9 | return LIBRARY_VERSION; 10 | } 11 | 12 | void DisableDiscoveryLogging() { 13 | spdlog::set_level(spdlog::level::off); 14 | } 15 | 16 | void EnableDiscoveryLogging() 17 | { 18 | spdlog::set_pattern("%^[directx-device-discovery.dll %Y-%m-%dT%T%z]%$ [%s:%# %!] 
%v", spdlog::pattern_time_type::local); 19 | spdlog::set_level(spdlog::level::info); 20 | spdlog::flush_on(spdlog::level::info); 21 | } 22 | 23 | DeviceDiscoveryInstance CreateDeviceDiscoveryInstance() { 24 | return new DeviceDiscoveryImp(); 25 | } 26 | 27 | void DestroyDeviceDiscoveryInstance(DeviceDiscoveryInstance instance) { 28 | delete INSTANCE; 29 | } 30 | 31 | const wchar_t* DeviceDiscovery_GetLastErrorMessage(DeviceDiscoveryInstance instance) { 32 | return INSTANCE->GetLastErrorMessage(); 33 | } 34 | 35 | int DeviceDiscovery_IsRefreshRequired(DeviceDiscoveryInstance instance) { 36 | return INSTANCE->IsRefreshRequired(); 37 | } 38 | 39 | int DeviceDiscovery_DiscoverDevices(DeviceDiscoveryInstance instance, int filter, int includeIntegrated, int includeDetachable) 40 | { 41 | bool success = INSTANCE->DiscoverDevices(static_cast(filter), includeIntegrated, includeDetachable); 42 | return (success ? 0 : -1); 43 | } 44 | 45 | int DeviceDiscovery_GetNumDevices(DeviceDiscoveryInstance instance) { 46 | return INSTANCE->GetNumDevices(); 47 | } 48 | 49 | long long DeviceDiscovery_GetDeviceAdapterLUID(DeviceDiscoveryInstance instance, unsigned int device) { 50 | return INSTANCE->GetDeviceAdapterLUID(device); 51 | } 52 | 53 | const wchar_t* DeviceDiscovery_GetDeviceID(DeviceDiscoveryInstance instance, unsigned int device) { 54 | return INSTANCE->GetDeviceID(device); 55 | } 56 | 57 | const wchar_t* DeviceDiscovery_GetDeviceDescription(DeviceDiscoveryInstance instance, unsigned int device) { 58 | return INSTANCE->GetDeviceDescription(device); 59 | } 60 | 61 | const wchar_t* DeviceDiscovery_GetDeviceDriverRegistryKey(DeviceDiscoveryInstance instance, unsigned int device) { 62 | return INSTANCE->GetDeviceDriverRegistryKey(device); 63 | } 64 | 65 | const wchar_t* DeviceDiscovery_GetDeviceDriverStorePath(DeviceDiscoveryInstance instance, unsigned int device) { 66 | return INSTANCE->GetDeviceDriverStorePath(device); 67 | } 68 | 69 | const wchar_t* 
DeviceDiscovery_GetDeviceLocationPath(DeviceDiscoveryInstance instance, unsigned int device) { 70 | return INSTANCE->GetDeviceLocationPath(device); 71 | } 72 | 73 | const wchar_t* DeviceDiscovery_GetDeviceVendor(DeviceDiscoveryInstance instance, unsigned int device) { 74 | return INSTANCE->GetDeviceVendor(device); 75 | } 76 | 77 | int DeviceDiscovery_GetNumRuntimeFiles(DeviceDiscoveryInstance instance, unsigned int device) { 78 | return INSTANCE->GetNumRuntimeFiles(device); 79 | } 80 | 81 | const wchar_t* DeviceDiscovery_GetRuntimeFileSource(DeviceDiscoveryInstance instance, unsigned int device, unsigned int file) { 82 | return INSTANCE->GetRuntimeFileSource(device, file); 83 | } 84 | 85 | const wchar_t* DeviceDiscovery_GetRuntimeFileDestination(DeviceDiscoveryInstance instance, unsigned int device, unsigned int file) { 86 | return INSTANCE->GetRuntimeFileDestination(device, file); 87 | } 88 | 89 | int DeviceDiscovery_GetNumRuntimeFilesWow64(DeviceDiscoveryInstance instance, unsigned int device) { 90 | return INSTANCE->GetNumRuntimeFilesWow64(device); 91 | } 92 | 93 | const wchar_t* DeviceDiscovery_GetRuntimeFileSourceWow64(DeviceDiscoveryInstance instance, unsigned int device, unsigned int file) { 94 | return INSTANCE->GetRuntimeFileSourceWow64(device, file); 95 | } 96 | 97 | const wchar_t* DeviceDiscovery_GetRuntimeFileDestinationWow64(DeviceDiscoveryInstance instance, unsigned int device, unsigned int file) { 98 | return INSTANCE->GetRuntimeFileDestinationWow64(device, file); 99 | } 100 | 101 | int DeviceDiscovery_IsDeviceIntegrated(DeviceDiscoveryInstance instance, unsigned int device) { 102 | return INSTANCE->IsDeviceIntegrated(device); 103 | } 104 | 105 | int DeviceDiscovery_IsDeviceDetachable(DeviceDiscoveryInstance instance, unsigned int device) { 106 | return INSTANCE->IsDeviceDetachable(device); 107 | } 108 | 109 | int DeviceDiscovery_DoesDeviceSupportDisplay(DeviceDiscoveryInstance instance, unsigned int device) { 110 | return 
INSTANCE->DoesDeviceSupportDisplay(device); 111 | } 112 | 113 | int DeviceDiscovery_DoesDeviceSupportCompute(DeviceDiscoveryInstance instance, unsigned int device) { 114 | return INSTANCE->DoesDeviceSupportCompute(device); 115 | } 116 | -------------------------------------------------------------------------------- /library/src/DeviceDiscoveryImp.cpp: -------------------------------------------------------------------------------- 1 | #include "DeviceDiscoveryImp.h" 2 | #include "ErrorHandling.h" 3 | #include "RegistryQuery.h" 4 | 5 | #include 6 | #include 7 | 8 | #define RETURN_ERROR(sentinel, message) this->SetLastErrorMessage(message); return sentinel 9 | #define RETURN_SUCCESS(value) this->SetLastErrorMessage(L""); return value 10 | 11 | #define VERIFY_DEVICE(sentinel) try { this->ValidateRequestedDevice(device); } catch (const DeviceDiscoveryError& err) { RETURN_ERROR(sentinel, err.message); } 12 | #define VERIFY_FILE() if (file >= files.size()) { RETURN_ERROR(nullptr, L"requested runtime file index is invalid: " + std::to_wstring(file)); } 13 | 14 | const wchar_t* DeviceDiscoveryImp::GetLastErrorMessage() const { 15 | return this->lastError.c_str(); 16 | } 17 | 18 | bool DeviceDiscoveryImp::IsRefreshRequired() 19 | { 20 | // Make sure WinRT is initialised for the calling thread 21 | Windows::Foundation::Initialize(RO_INIT_MULTITHREADED); 22 | 23 | // We require a refresh if we have no data or we have stale data 24 | return (this->HaveDevices()) ? 
this->enumeration->IsStale() : true; 25 | } 26 | 27 | bool DeviceDiscoveryImp::DiscoverDevices(DeviceFilter filter, bool includeIntegrated, bool includeDetachable) 28 | { 29 | // Make sure WinRT is initialised for the calling thread 30 | Windows::Foundation::Initialize(RO_INIT_MULTITHREADED); 31 | 32 | try 33 | { 34 | // If this is the first time we're performing device discovery then create our helper objects 35 | if (!this->HaveDevices()) 36 | { 37 | this->enumeration = std::make_unique(); 38 | this->wmi = std::make_unique(); 39 | } 40 | 41 | // Enumerate the DirectX adapters that meet the supplied filtering criteria 42 | this->enumeration->EnumerateAdapters(filter, includeIntegrated, includeDetachable); 43 | 44 | // Retrieve the PnP device details from WMI for each of the enumerated adapters 45 | this->devices = this->wmi->GetDevicesForAdapters(this->enumeration->GetUniqueAdapters()); 46 | 47 | // Retrieve the driver details from the registry for each of the devices 48 | for (auto& device : this->devices) { 49 | RegistryQuery::FillDriverDetails(device); 50 | } 51 | 52 | RETURN_SUCCESS(true); 53 | } 54 | catch (const DeviceDiscoveryError& err) { 55 | RETURN_ERROR(false, err.Pretty()); 56 | } 57 | catch (const std::runtime_error& err) { 58 | RETURN_ERROR(false, winrt::to_hstring(err.what())); 59 | } 60 | } 61 | 62 | int DeviceDiscoveryImp::GetNumDevices() 63 | { 64 | // Verify that we have a device list 65 | if (!this->HaveDevices()) { 66 | RETURN_ERROR(-1, L"attempted to retrieve device count before performing device discovery"); 67 | } 68 | 69 | RETURN_SUCCESS(this->devices.size()); 70 | } 71 | 72 | long long DeviceDiscoveryImp::GetDeviceAdapterLUID(unsigned int device) 73 | { 74 | // Verify that the requested device exists 75 | VERIFY_DEVICE(-1); 76 | 77 | // Retrieve the adapter LUID of the specified device 78 | RETURN_SUCCESS(this->devices[device].DeviceAdapter.InstanceLuid); 79 | } 80 | 81 | const wchar_t* DeviceDiscoveryImp::GetDeviceID(unsigned int device) 
82 | { 83 | // Verify that the requested device exists 84 | VERIFY_DEVICE(nullptr); 85 | 86 | // Retrieve the ID of the specified device 87 | RETURN_SUCCESS(this->devices[device].ID.c_str()); 88 | } 89 | 90 | const wchar_t* DeviceDiscoveryImp::GetDeviceDescription(unsigned int device) 91 | { 92 | // Verify that the requested device exists 93 | VERIFY_DEVICE(nullptr); 94 | 95 | // Retrieve the human-readable description of the specified device 96 | RETURN_SUCCESS(this->devices[device].Description.c_str()); 97 | } 98 | 99 | const wchar_t* DeviceDiscoveryImp::GetDeviceDriverRegistryKey(unsigned int device) 100 | { 101 | // Verify that the requested device exists 102 | VERIFY_DEVICE(nullptr); 103 | 104 | // Retrieve the path of the registry key with the driver details for the specified device 105 | RETURN_SUCCESS(this->devices[device].DriverRegistryKey.c_str()); 106 | } 107 | 108 | const wchar_t* DeviceDiscoveryImp::GetDeviceDriverStorePath(unsigned int device) 109 | { 110 | // Verify that the requested device exists 111 | VERIFY_DEVICE(nullptr); 112 | 113 | // Retrieve the absolute path to the driver store directory for the specified device 114 | RETURN_SUCCESS(this->devices[device].DriverStorePath.c_str()); 115 | } 116 | 117 | const wchar_t* DeviceDiscoveryImp::GetDeviceLocationPath(unsigned int device) 118 | { 119 | // Verify that the requested device exists 120 | VERIFY_DEVICE(nullptr); 121 | 122 | // Retrieve the physical location path of the specified device 123 | RETURN_SUCCESS(this->devices[device].LocationPath.c_str()); 124 | } 125 | 126 | const wchar_t* DeviceDiscoveryImp::GetDeviceVendor(unsigned int device) 127 | { 128 | // Verify that the requested device exists 129 | VERIFY_DEVICE(nullptr); 130 | 131 | // Retrieve the vendor of the specified device 132 | RETURN_SUCCESS(this->devices[device].Vendor.c_str()); 133 | } 134 | 135 | int DeviceDiscoveryImp::GetNumRuntimeFiles(unsigned int device) 136 | { 137 | // Verify that the requested device exists 138 | 
VERIFY_DEVICE(-1); 139 | 140 | // Retrieve the number of additional runtime files for the device 141 | RETURN_SUCCESS(this->devices[device].RuntimeFiles.size()); 142 | } 143 | 144 | const wchar_t* DeviceDiscoveryImp::GetRuntimeFileSource(unsigned int device, unsigned int file) 145 | { 146 | // Verify that the requested device exists 147 | VERIFY_DEVICE(nullptr); 148 | 149 | // Verify that the requested file entry exists 150 | const vector& files = this->devices[device].RuntimeFiles; 151 | VERIFY_FILE(); 152 | 153 | // Retrieve the source path for the file 154 | RETURN_SUCCESS(files[file].SourcePath.c_str()); 155 | } 156 | 157 | const wchar_t* DeviceDiscoveryImp::GetRuntimeFileDestination(unsigned int device, unsigned int file) 158 | { 159 | // Verify that the requested device exists 160 | VERIFY_DEVICE(nullptr); 161 | 162 | // Verify that the requested file entry exists 163 | const vector& files = this->devices[device].RuntimeFiles; 164 | VERIFY_FILE(); 165 | 166 | // Retrieve the destination filename for the file 167 | RETURN_SUCCESS(files[file].DestinationFilename.c_str()); 168 | } 169 | 170 | int DeviceDiscoveryImp::GetNumRuntimeFilesWow64(unsigned int device) 171 | { 172 | // Verify that the requested device exists 173 | VERIFY_DEVICE(-1); 174 | 175 | // Retrieve the number of additional SysWOW64 runtime files for the device 176 | RETURN_SUCCESS(this->devices[device].RuntimeFilesWow64.size()); 177 | } 178 | 179 | const wchar_t* DeviceDiscoveryImp::GetRuntimeFileSourceWow64(unsigned int device, unsigned int file) 180 | { 181 | // Verify that the requested device exists 182 | VERIFY_DEVICE(nullptr); 183 | 184 | // Verify that the requested file entry exists 185 | const vector& files = this->devices[device].RuntimeFilesWow64; 186 | VERIFY_FILE(); 187 | 188 | // Retrieve the source path for the file 189 | RETURN_SUCCESS(files[file].SourcePath.c_str()); 190 | } 191 | 192 | const wchar_t* DeviceDiscoveryImp::GetRuntimeFileDestinationWow64(unsigned int device, 
unsigned int file) 193 | { 194 | // Verify that the requested device exists 195 | VERIFY_DEVICE(nullptr); 196 | 197 | // Verify that the requested file entry exists 198 | const vector& files = this->devices[device].RuntimeFilesWow64; 199 | VERIFY_FILE(); 200 | 201 | // Retrieve the destination filename for the file 202 | RETURN_SUCCESS(files[file].DestinationFilename.c_str()); 203 | } 204 | 205 | int DeviceDiscoveryImp::IsDeviceIntegrated(unsigned int device) 206 | { 207 | // Verify that the requested device exists 208 | VERIFY_DEVICE(-1); 209 | 210 | // Determine whether the specified device is an integrated GPU 211 | RETURN_SUCCESS(this->devices[device].DeviceAdapter.IsIntegrated); 212 | } 213 | 214 | int DeviceDiscoveryImp::IsDeviceDetachable(unsigned int device) 215 | { 216 | // Verify that the requested device exists 217 | VERIFY_DEVICE(-1); 218 | 219 | // Determine whether the specified device is detachable 220 | RETURN_SUCCESS(this->devices[device].DeviceAdapter.IsDetachable); 221 | } 222 | 223 | int DeviceDiscoveryImp::DoesDeviceSupportDisplay(unsigned int device) 224 | { 225 | // Verify that the requested device exists 226 | VERIFY_DEVICE(-1); 227 | 228 | // Determine whether the specified device supports display 229 | RETURN_SUCCESS(this->devices[device].DeviceAdapter.SupportsDisplay); 230 | } 231 | 232 | int DeviceDiscoveryImp::DoesDeviceSupportCompute(unsigned int device) 233 | { 234 | // Verify that the requested device exists 235 | VERIFY_DEVICE(-1); 236 | 237 | // Determine whether the specified device supports compute 238 | RETURN_SUCCESS(this->devices[device].DeviceAdapter.SupportsCompute); 239 | } 240 | 241 | bool DeviceDiscoveryImp::HaveDevices() const { 242 | return (this->enumeration && this->wmi); 243 | } 244 | 245 | void DeviceDiscoveryImp::SetLastErrorMessage(std::wstring_view message) { 246 | this->lastError = message; 247 | } 248 | 249 | void DeviceDiscoveryImp::ValidateRequestedDevice(unsigned int device) 250 | { 251 | // Verify that we 
have a device list 252 | if (!this->HaveDevices()) { 253 | throw CreateError(L"attempted to retrieve device details before performing device discovery"); 254 | } 255 | 256 | // Verify that the specified device index is valid 257 | if (device >= this->GetNumDevices()) { 258 | throw CreateError(L"requested device index is invalid: " + std::to_wstring(device)); 259 | } 260 | } 261 | -------------------------------------------------------------------------------- /library/src/DeviceDiscoveryImp.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "AdapterEnumeration.h" 4 | #include "Device.h" 5 | #include "DeviceFilter.h" 6 | #include "WmiQuery.h" 7 | 8 | using std::wstring; 9 | using std::wstring_view; 10 | using std::unique_ptr; 11 | using std::vector; 12 | 13 | class DeviceDiscoveryImp 14 | { 15 | public: 16 | 17 | DeviceDiscoveryImp() {} 18 | const wchar_t* GetLastErrorMessage() const; 19 | bool IsRefreshRequired(); 20 | bool DiscoverDevices(DeviceFilter filter, bool includeIntegrated, bool includeDetachable); 21 | int GetNumDevices(); 22 | long long GetDeviceAdapterLUID(unsigned int device); 23 | const wchar_t* GetDeviceID(unsigned int device); 24 | const wchar_t* GetDeviceDescription(unsigned int device); 25 | const wchar_t* GetDeviceDriverRegistryKey(unsigned int device); 26 | const wchar_t* GetDeviceDriverStorePath(unsigned int device); 27 | const wchar_t* GetDeviceLocationPath(unsigned int device); 28 | const wchar_t* GetDeviceVendor(unsigned int device); 29 | int GetNumRuntimeFiles(unsigned int device); 30 | const wchar_t* GetRuntimeFileSource(unsigned int device, unsigned int file); 31 | const wchar_t* GetRuntimeFileDestination(unsigned int device, unsigned int file); 32 | int GetNumRuntimeFilesWow64(unsigned int device); 33 | const wchar_t* GetRuntimeFileSourceWow64(unsigned int device, unsigned int file); 34 | const wchar_t* GetRuntimeFileDestinationWow64(unsigned int device, unsigned int 
file); 35 | int IsDeviceIntegrated(unsigned int device); 36 | int IsDeviceDetachable(unsigned int device); 37 | int DoesDeviceSupportDisplay(unsigned int device); 38 | int DoesDeviceSupportCompute(unsigned int device); 39 | 40 | private: 41 | 42 | bool HaveDevices() const; 43 | void SetLastErrorMessage(wstring_view message); 44 | void ValidateRequestedDevice(unsigned int device); 45 | 46 | vector devices; 47 | wstring lastError; 48 | 49 | unique_ptr enumeration; 50 | unique_ptr wmi; 51 | }; 52 | -------------------------------------------------------------------------------- /library/src/DllMain.cpp: -------------------------------------------------------------------------------- 1 | BOOL APIENTRY DllMain(HINSTANCE hModule, DWORD dwReason, PVOID lpReserved) 2 | { 3 | if (dwReason == DLL_PROCESS_ATTACH) 4 | { 5 | // Disable logging by default 6 | spdlog::set_level(spdlog::level::off); 7 | } 8 | 9 | return TRUE; 10 | } 11 | -------------------------------------------------------------------------------- /library/src/ErrorHandling.cpp: -------------------------------------------------------------------------------- 1 | #include "ErrorHandling.h" 2 | 3 | DeviceDiscoveryError ErrorHandling::ErrorForNtStatus(NTSTATUS status, wstring_view file, wstring_view function, size_t line) 4 | { 5 | if (status < 0) 6 | { 7 | // Allocate a buffer to hold the error message 8 | size_t bufSize = 1024; 9 | auto buffer = std::make_unique(bufSize); 10 | 11 | // Attempt to retrieve the error message for the status code 12 | DWORD length = FormatMessageW( 13 | FORMAT_MESSAGE_FROM_HMODULE | FORMAT_MESSAGE_IGNORE_INSERTS, 14 | GetModuleHandleW(L"ntdll.dll"), 15 | status, 16 | 0, 17 | buffer.get(), 18 | bufSize, 19 | nullptr 20 | ); 21 | 22 | if (length > 0) 23 | { 24 | // If the message has a trailing newline then remove it 25 | wstring message(buffer.get(), length); 26 | size_t newline = message.find_last_of(L"\r\n"); 27 | if (newline != wstring::npos) { 28 | message = message.substr(0, 
newline - 1); 29 | } 30 | 31 | // Return an error with the retrieved message 32 | return DeviceDiscoveryError(message, file, function, line); 33 | } 34 | else 35 | { 36 | // Return an error with the hexadecimal representation of the NTSTATUS code 37 | return DeviceDiscoveryError( 38 | fmt::format( 39 | L"Unable to retrieve error message for NTSTATUS code 0x{:0>8X}", 40 | static_cast(status) 41 | ), 42 | file, 43 | function, 44 | line 45 | ); 46 | } 47 | } 48 | 49 | return DeviceDiscoveryError(L"", file, function, line); 50 | } 51 | -------------------------------------------------------------------------------- /library/src/ErrorHandling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | using std::wstring; 4 | using std::wstring_view; 5 | 6 | 7 | // Define Unicode versions of __FILE__ and __FUNCTION__ 8 | #define __WIDE2(x) L##x 9 | #define __WIDE1(x) __WIDE2(x) 10 | #define __WFILE__ __WIDE1(__FILE__) 11 | #define __WFUNCTION__ __WIDE1(__FUNCTION__) 12 | 13 | 14 | // The exception type used to represent all errors inside the device discovery library 15 | class DeviceDiscoveryError 16 | { 17 | public: 18 | inline DeviceDiscoveryError() : message(L""), file(L""), function(L""), line(0) {} 19 | 20 | inline DeviceDiscoveryError(wstring_view message, wstring_view file, wstring_view function, size_t line) : 21 | message(message), file(file), function(function), line(line) 22 | {} 23 | 24 | inline DeviceDiscoveryError(wstring_view message, const DeviceDiscoveryError& inner) : 25 | file(inner.file), function(inner.function), line(inner.line) 26 | { 27 | this->message = wstring(message) + L": " + inner.message; 28 | } 29 | 30 | DeviceDiscoveryError(const DeviceDiscoveryError& other) = default; 31 | DeviceDiscoveryError(DeviceDiscoveryError&& other) = default; 32 | DeviceDiscoveryError& operator=(const DeviceDiscoveryError& other) = default; 33 | DeviceDiscoveryError& operator=(DeviceDiscoveryError&& other) = 
default; 34 | 35 | inline operator bool() const { 36 | return (!this->message.empty()); 37 | } 38 | 39 | // Wraps this error in a surrounding error message 40 | inline DeviceDiscoveryError Wrap(wstring_view message) const { 41 | return DeviceDiscoveryError(message, *this); 42 | } 43 | 44 | // Formats the error details as a pretty string 45 | inline wstring Pretty() const 46 | { 47 | // Extract the filename from the file path 48 | wstring filename = std::filesystem::path(this->file).filename().wstring(); 49 | 50 | // Append the filename, line number and function name to the error message 51 | wstring fileAndLine = filename + L":" + std::to_wstring(this->line); 52 | return this->message + L" [" + fileAndLine + L" " + this->function + L"]"; 53 | } 54 | 55 | wstring message; 56 | wstring file; 57 | wstring function; 58 | size_t line; 59 | }; 60 | 61 | 62 | // Provides functionality related to managing errors 63 | namespace ErrorHandling 64 | { 65 | // Returns an error object representing the supplied NTSTATUS code 66 | DeviceDiscoveryError ErrorForNtStatus(NTSTATUS status, wstring_view file, wstring_view function, size_t line); 67 | 68 | // Returns an error object representing the supplied HRESULT code 69 | inline DeviceDiscoveryError ErrorForHresult(const winrt::hresult& result, wstring_view file, wstring_view function, size_t line) 70 | { 71 | try 72 | { 73 | winrt::check_hresult(result); 74 | return DeviceDiscoveryError(L"", file, function, line); 75 | } 76 | catch (const winrt::hresult_error& err) { 77 | return DeviceDiscoveryError(err.message(), file, function, line); 78 | } 79 | } 80 | 81 | // Returns an error object representing the supplied Win32 error code 82 | template 83 | inline DeviceDiscoveryError ErrorForWin32(T error, wstring_view file, wstring_view function, size_t line) 84 | { 85 | try 86 | { 87 | winrt::check_win32(error); 88 | return DeviceDiscoveryError(L"", file, function, line); 89 | } 90 | catch (const winrt::hresult_error& err) { 91 | return 
DeviceDiscoveryError(err.message(), file, function, line); 92 | } 93 | } 94 | 95 | // Convenience macros for automatically filling out error file, function and line details 96 | #define CreateError(message) DeviceDiscoveryError(message, __WFILE__, __WFUNCTION__, __LINE__) 97 | #define CheckNtStatus(status) ErrorHandling::ErrorForNtStatus(status, __WFILE__, __WFUNCTION__, __LINE__) 98 | #define CheckHresult(status) ErrorHandling::ErrorForHresult(status, __WFILE__, __WFUNCTION__, __LINE__) 99 | #define CheckWin32(status) ErrorHandling::ErrorForWin32(status, __WFILE__, __WFUNCTION__, __LINE__) 100 | 101 | // Catches a winrt::hresult_error object and converts it to a DeviceDiscoveryError object 102 | #define CatchHresult(error, operation) try { operation; error = DeviceDiscoveryError(); } catch (const winrt::hresult_error & err) { error = CreateError(err.message()); } 103 | } 104 | -------------------------------------------------------------------------------- /library/src/ObjectHelpers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace ObjectHelpers { 4 | 5 | 6 | // Retrieves the list of keys for an STL associative container type (maps, sets, etc.) 
7 | // For the full list of supported container types, see: 8 | // - 9 | // - 10 | template 11 | std::vector GetMappingKeys(const MappingType& mapping) 12 | { 13 | std::vector keys; 14 | 15 | for (const auto& pair : mapping) { 16 | keys.push_back(pair.first); 17 | } 18 | 19 | return keys; 20 | } 21 | 22 | // Returns a zeroed-out instance of the specified struct type 23 | template T GetZeroedStruct() 24 | { 25 | T instance; 26 | ZeroMemory(&instance, sizeof(T)); 27 | return instance; 28 | } 29 | 30 | 31 | } // namespace ObjectHelpers 32 | -------------------------------------------------------------------------------- /library/src/RegistryQuery.cpp: -------------------------------------------------------------------------------- 1 | #include "RegistryQuery.h" 2 | #include "D3DHelpers.h" 3 | #include "ErrorHandling.h" 4 | #include "ObjectHelpers.h" 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | map< wstring, vector > RegistryQuery::EnumerateMultiStringValues(unique_hkey& key) 11 | { 12 | map< wstring, vector > values; 13 | 14 | LSTATUS result = ERROR_SUCCESS; 15 | for (int i = 0; ; ++i) 16 | { 17 | // Receives the type of the enumerated value 18 | DWORD valueType = 0; 19 | 20 | // Receives the name of the enumerated value 21 | DWORD nameBufsize = 256; 22 | auto valueName = std::make_unique(nameBufsize); 23 | 24 | // Receives the data of the enumerated value 25 | DWORD dataBufsize = 1024; 26 | auto valueData = std::make_unique(dataBufsize); 27 | 28 | // Retrieve the next value and check to see if we have processed all available values 29 | result = RegEnumValueW(key.get(), i, valueName.get(), &nameBufsize, nullptr, &valueType, valueData.get(), &dataBufsize); 30 | if (result == ERROR_NO_MORE_ITEMS) { 31 | break; 32 | } 33 | 34 | // Report any errors 35 | auto error = CheckWin32(result); 36 | if (error) { 37 | throw error.Wrap(L"RegEnumValueW failed"); 38 | } 39 | 40 | // Verify that the value data is of type REG_MULTI_SZ 41 | wstring name = 
wstring(valueName.get(), nameBufsize); 42 | if (valueType != REG_MULTI_SZ) { 43 | throw CreateError(L"enumerated value was not of type REG_MULTI_SZ: " + name); 44 | } 45 | 46 | // Parse the value data and add it to our mapping 47 | auto strings = RegistryQuery::ExtractMultiStringValue((wchar_t*)(valueData.get()), dataBufsize); 48 | values.insert(std::make_pair(name, strings)); 49 | } 50 | 51 | return values; 52 | } 53 | 54 | vector RegistryQuery::ExtractMultiStringValue(const wchar_t* data, size_t numBytes) 55 | { 56 | vector strings; 57 | 58 | size_t offset = 0; 59 | size_t upperBound = numBytes / sizeof(wchar_t); 60 | while (offset < upperBound) 61 | { 62 | // Extract the next string and check that it's not empty (the scan for the terminator is bounded by the end of the buffer, since registry string data is not guaranteed to be null-terminated) 63 | wstring nextString(data + offset, std::find(data + offset, data + upperBound, L'\0')); 64 | if (nextString.size() == 0) { break; } 65 | 66 | // Add the string to our list and proceed to the next one 67 | strings.push_back(nextString); 68 | offset += strings.back().size() + 1; 69 | } 70 | 71 | return strings; 72 | } 73 | 74 | unique_hkey RegistryQuery::OpenKeyFromString(wstring_view key) 75 | { 76 | // Our list of supported root keys 77 | static map rootKeys = { 78 | { L"HKEY_CLASSES_ROOT", HKEY_CLASSES_ROOT }, 79 | { L"HKEY_CURRENT_CONFIG", HKEY_CURRENT_CONFIG }, 80 | { L"HKEY_CURRENT_USER", HKEY_CURRENT_USER }, 81 | { L"HKEY_LOCAL_MACHINE", HKEY_LOCAL_MACHINE }, 82 | { L"HKEY_PERFORMANCE_DATA", HKEY_PERFORMANCE_DATA }, 83 | { L"HKEY_USERS", HKEY_USERS } 84 | }; 85 | 86 | // Verify that the supplied key path is well-formed 87 | size_t backslash = key.find_first_of(L"\\"); 88 | if (backslash == wstring_view::npos || backslash >= (key.size()-1)) { 89 | throw CreateError(L"invalid registry key path: " + wstring(key)); 90 | } 91 | 92 | // Split the root key name from the rest of the path 93 | wstring rootKeyName = wstring(key.substr(0, backslash)); 94 | wstring keyPath = wstring(key.substr(backslash + 1)); 95 | 96 | // Identify the handle for the specified root key 97 | auto rootKey = 
rootKeys.find(rootKeyName); 98 | if (rootKey == rootKeys.end()) { 99 | throw CreateError(L"unknown registry root key: " + rootKeyName); 100 | } 101 | 102 | // Attempt to open the key 103 | unique_hkey keyHandle; 104 | auto error = CheckWin32(RegOpenKeyExW(rootKey->second, keyPath.c_str(), 0, KEY_READ, keyHandle.put())); 105 | if (error) { 106 | throw error.Wrap(L"failed to open registry key " + wstring(key)); 107 | } 108 | 109 | return keyHandle; 110 | } 111 | 112 | void RegistryQuery::ProcessRuntimeFiles(Device& device, wstring_view key, bool isWow64) 113 | { 114 | try 115 | { 116 | // Determine whether we are adding runtime files to the device's System32 list or SysWOW64 list 117 | auto& list = (isWow64) ? device.RuntimeFilesWow64 : device.RuntimeFiles; 118 | 119 | // Attempt to open the specified registry key and enumerate its REG_MULTI_SZ values 120 | unique_hkey registryKey = RegistryQuery::OpenKeyFromString(device.DriverRegistryKey + L"\\" + wstring(key)); 121 | auto files = RegistryQuery::EnumerateMultiStringValues(registryKey); 122 | for (const auto& pair : files) 123 | { 124 | if (!pair.second.empty()) 125 | { 126 | // Construct a RuntimeFile from the string values 127 | RuntimeFile newFile(pair.second[0], ((pair.second.size() == 2) ? 
pair.second[1] : L"")); 128 | 129 | // Check whether the destination filename for the runtime file clashes with an existing file 130 | auto existing = std::find_if(list.begin(), list.end(), [newFile](RuntimeFile f) { 131 | return f.DestinationFilename == newFile.DestinationFilename; 132 | }); 133 | 134 | // Only add the new runtime file to the list if there's no clash 135 | if (existing == list.end()) { 136 | list.push_back(newFile); 137 | } 138 | else { 139 | LOG(L"{}: ignoring runtime file with duplicate destination filename {}", key, newFile.DestinationFilename); 140 | } 141 | } 142 | } 143 | } 144 | catch (const DeviceDiscoveryError& err) { 145 | LOG(L"Could not enumerate runtime files for the {} key: {}", key, err.message); 146 | } 147 | } 148 | 149 | void RegistryQuery::FillDriverDetails(Device& device) 150 | { 151 | // Log the device ID to provide context for any subsequent log messages and errors 152 | LOG(L"Querying device driver registry details for device {}", device.ID); 153 | 154 | // Attempt to open the DirectX adapter for the device 155 | auto adapterDetails = ObjectHelpers::GetZeroedStruct(); 156 | adapterDetails.AdapterLuid = LuidFromInt64(device.DeviceAdapter.InstanceLuid); 157 | auto error = CheckNtStatus(D3DKMTOpenAdapterFromLuid(&adapterDetails)); 158 | if (error) 159 | { 160 | throw error.Wrap( 161 | L"D3DKMTOpenAdapterFromLuid failed to open adapter with LUID " + 162 | std::to_wstring(device.DeviceAdapter.InstanceLuid) 163 | ); 164 | } 165 | 166 | // Ensure we automatically close the adapter handle when we finish 167 | unique_adapter_handle adapter(adapterDetails.hAdapter); 168 | 169 | // Retrieve the path to the driver store directory for the adapter 170 | QueryD3DRegistryInfo queryDriverStore; 171 | queryDriverStore.SetFilesystemQuery(D3DDDI_QUERYREGISTRY_DRIVERSTOREPATH); 172 | queryDriverStore.PerformQuery(adapter); 173 | device.DriverStorePath = wstring(queryDriverStore.RegistryInfo->OutputString); 174 | 175 | // If the driver store path 
begins with the "\SystemRoot" prefix then expand it 176 | wstring prefix = L"\\SystemRoot"; 177 | wstring systemRoot = wstring(wil::GetEnvironmentVariableW(L"SystemRoot").get()); 178 | if (device.DriverStorePath.find(prefix, 0) == 0) { 179 | device.DriverStorePath = device.DriverStorePath.replace(0, prefix.size(), systemRoot); 180 | } 181 | 182 | // Determine whether we're running on the host or inside a container 183 | // (e.g. when using a client tool to verify that a device has been mounted correctly) 184 | if (device.DriverStorePath.find(L"HostDriverStore", 0) != wstring::npos) 185 | { 186 | // We have no way of enumerating the CopyToVmWhenNewer subkey inside a container, so stop processing here 187 | LOG(L"Running inside a container, skipping runtime file enumeration"); 188 | return; 189 | } 190 | 191 | // Retrieve the list of additional runtime files that need to be copied to the System32 directory 192 | RegistryQuery::ProcessRuntimeFiles(device, L"CopyToVmOverwrite", false); 193 | RegistryQuery::ProcessRuntimeFiles(device, L"CopyToVmWhenNewer", false); 194 | 195 | // Retrieve the list of additional runtime files that need to be copied to the SysWOW64 directory 196 | RegistryQuery::ProcessRuntimeFiles(device, L"CopyToVmOverwriteWow64", true); 197 | RegistryQuery::ProcessRuntimeFiles(device, L"CopyToVmWhenNewerWow64", true); 198 | } 199 | -------------------------------------------------------------------------------- /library/src/RegistryQuery.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Device.h" 4 | 5 | using std::map; 6 | using std::vector; 7 | using std::wstring; 8 | using std::wstring_view; 9 | using wil::unique_hkey; 10 | 11 | // Provides functionality for querying the Windows registry 12 | namespace RegistryQuery 13 | { 14 | // Enumerates the values of the supplied registry key and parses their data as REG_MULTI_SZ 15 | map< wstring, vector > 
EnumerateMultiStringValues(unique_hkey& key); 16 | 17 | // Extracts the individual strings of a REG_MULTI_SZ registry value 18 | vector ExtractMultiStringValue(const wchar_t* data, size_t numBytes); 19 | 20 | // Parses a registry key path and opens it using the appropriate root key 21 | unique_hkey OpenKeyFromString(wstring_view key); 22 | 23 | // Enumerates the runtime files for a device as listed under the specified registry key 24 | void ProcessRuntimeFiles(Device& device, wstring_view key, bool isWow64); 25 | 26 | // Queries the registry to retrieve driver-related details for the supplied PnP device 27 | void FillDriverDetails(Device& device); 28 | } 29 | -------------------------------------------------------------------------------- /library/src/SafeArray.cpp: -------------------------------------------------------------------------------- 1 | #include "SafeArray.h" 2 | 3 | unique_variant SafeArrayFactory::CreateStringArray(initializer_list elems) 4 | { 5 | // Create our array bounds descriptor 6 | SAFEARRAYBOUND bounds; 7 | bounds.lLbound = 0; 8 | bounds.cElements = elems.size(); 9 | 10 | // Create a VARIANT to hold our array 11 | unique_variant vtArray; 12 | vtArray.vt = VT_ARRAY | VT_BSTR; 13 | 14 | // Create the SAFEARRAY and lock it for data access 15 | vtArray.parray = SafeArrayCreate(VT_BSTR, 1, &bounds); 16 | auto error = CheckHresult(SafeArrayLock(vtArray.parray)); 17 | if (error) { 18 | throw error.Wrap(L"SafeArrayLock failed"); 19 | } 20 | 21 | // Populate the array with the supplied elements 22 | BSTR* array = reinterpret_cast(vtArray.parray->pvData); 23 | int index = 0; 24 | for (const auto& elem : elems) 25 | { 26 | // Note that a SAFEARRAY owns the memory of its elements, so we transfer ownership of each BSTR 27 | array[index] = wil::make_bstr(elem.c_str()).release(); 28 | index++; 29 | } 30 | 31 | // Unlock the array 32 | SafeArrayUnlock(vtArray.parray); 33 | return vtArray; 34 | } 35 | 
-------------------------------------------------------------------------------- /library/src/SafeArray.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ErrorHandling.h" 4 | 5 | using std::initializer_list; 6 | using std::wstring; 7 | using wil::unique_variant; 8 | 9 | 10 | // Provides functionality for iterating over the contents of a one-dimensional SAFEARRAY 11 | template 12 | class SafeArrayIterator 13 | { 14 | public: 15 | SafeArrayIterator(SAFEARRAY* array) 16 | { 17 | // Lock the array for data access 18 | this->array = array; 19 | this->reinterpretedArray = reinterpret_cast(this->array->pvData); 20 | auto error = CheckHresult(SafeArrayLock(this->array)); 21 | if (error) { 22 | throw error.Wrap(L"SafeArrayLock failed"); 23 | } 24 | 25 | // Retrieve the array bounds and compute the number of elements 26 | long lowerBound = 0; 27 | long upperBound = 0; 28 | SafeArrayGetLBound(this->array, 1, &lowerBound); 29 | SafeArrayGetUBound(this->array, 1, &upperBound); 30 | this->numElements = (upperBound - lowerBound) + 1; 31 | } 32 | 33 | ~SafeArrayIterator() 34 | { 35 | // Unlock the array 36 | SafeArrayUnlock(this->array); 37 | this->array = nullptr; 38 | this->reinterpretedArray = nullptr; 39 | } 40 | 41 | T* begin() { 42 | return this->reinterpretedArray; 43 | } 44 | 45 | const T* begin() const { 46 | return this->reinterpretedArray; 47 | } 48 | 49 | T* end() { 50 | return this->reinterpretedArray + this->numElements; 51 | } 52 | 53 | const T* end() const { 54 | return this->reinterpretedArray + this->numElements; 55 | } 56 | 57 | private: 58 | SAFEARRAY* array; 59 | T* reinterpretedArray; 60 | long numElements; 61 | }; 62 | 63 | 64 | // Provides functionality for creating one-dimensional SAFEARRAY instances for specific element types 65 | class SafeArrayFactory 66 | { 67 | public: 68 | 69 | // Creates a SAFEARRAY of BSTR strings and wraps it in a VARIANT 70 | static unique_variant 
CreateStringArray(initializer_list elems); 71 | }; 72 | -------------------------------------------------------------------------------- /library/src/WmiQuery.cpp: -------------------------------------------------------------------------------- 1 | #include "WmiQuery.h" 2 | #include "ErrorHandling.h" 3 | #include "SafeArray.h" 4 | 5 | #include 6 | #include 7 | 8 | using std::set; 9 | using wil::unique_variant; 10 | using winrt::hstring; 11 | 12 | namespace 13 | { 14 | // Device property key for retrieving the DirectX adapter LUID 15 | const DEVPROPKEY DEVPKEY_Device_AdapterLuid = { 16 | { 0x60b193cb, 0x5276, 0x4d0f, { 0x96, 0xfc, 0xf1, 0x73, 0xab, 0xad, 0x3e, 0xc6 } }, 17 | 2 18 | }; 19 | 20 | // Formats a DEVPROPKEY as a string in the form "{00000000-0000-0000-0000-000000000000} 0" 21 | wstring DevPropKeyToString(const DEVPROPKEY& key) 22 | { 23 | return fmt::format( 24 | L"{{{:0>8X}-{:0>4X}-{:0>4X}-{:0>2X}{:0>2X}-{:0>2X}{:0>2X}{:0>2X}{:0>2X}{:0>2X}{:0>2X}}} {}", 25 | key.fmtid.Data1, 26 | key.fmtid.Data2, 27 | key.fmtid.Data3, 28 | key.fmtid.Data4[0], 29 | key.fmtid.Data4[1], 30 | key.fmtid.Data4[2], 31 | key.fmtid.Data4[3], 32 | key.fmtid.Data4[4], 33 | key.fmtid.Data4[5], 34 | key.fmtid.Data4[6], 35 | key.fmtid.Data4[7], 36 | key.pid 37 | ); 38 | } 39 | 40 | // Formats a PnP hardware ID for use in a WQL query 41 | wstring FormatHardwareID(const DXCoreHardwareID& dxHardwareID) 42 | { 43 | // Build a PCI hardware identifier string as per: 44 | // 45 | // and insert a trailing wildcard for the device instance 46 | return fmt::format( 47 | L"PCI\\\\VEN_{:0>4X}&DEV_{:0>4X}&SUBSYS_{:0>8X}&REV_{:0>2X}%", 48 | dxHardwareID.vendorID, 49 | dxHardwareID.deviceID, 50 | dxHardwareID.subSysID, 51 | dxHardwareID.revision 52 | ); 53 | } 54 | } 55 | 56 | WmiQuery::WmiQuery() 57 | { 58 | // Create a reusable error object 59 | DeviceDiscoveryError error; 60 | 61 | // Generate the string identifier for the DEVPKEY_Device_AdapterLuid device property key 62 | this->devPropKeyLUID = 
DevPropKeyToString(DEVPKEY_Device_AdapterLuid); 63 | 64 | // Create our IWbemLocator instance 65 | CatchHresult(error, this->wbemLocator = winrt::create_instance(CLSID_WbemLocator)); 66 | if (error) { 67 | throw error.Wrap(L"failed to create an IWbemLocator instance"); 68 | } 69 | 70 | // Connect to the WMI service and retrieve a service proxy object 71 | error = CheckHresult(this->wbemLocator->ConnectServer( 72 | wil::make_bstr(L"ROOT\\CIMV2").get(), 73 | nullptr, 74 | nullptr, 75 | nullptr, 76 | 0, 77 | nullptr, 78 | nullptr, 79 | this->wbemServices.put() 80 | )); 81 | if (error) { 82 | throw error.Wrap(L"failed to connect to the WMI service"); 83 | } 84 | 85 | // Set the security level for the service proxy 86 | error = CheckHresult(CoSetProxyBlanket( 87 | this->wbemServices.get(), 88 | RPC_C_AUTHN_WINNT, 89 | RPC_C_AUTHZ_NONE, 90 | nullptr, 91 | RPC_C_AUTHN_LEVEL_CALL, 92 | RPC_C_IMP_LEVEL_IMPERSONATE, 93 | nullptr, 94 | EOAC_NONE 95 | )); 96 | if (error) { 97 | throw error.Wrap(L"failed to set the security level for the WMI service proxy"); 98 | } 99 | 100 | // Retrieve the CIM class definition for the Win32_PnPEntity class 101 | error = CheckHresult(this->wbemServices->GetObject( 102 | wil::make_bstr(L"Win32_PnPEntity").get(), 103 | 0, 104 | nullptr, 105 | this->pnpEntityClass.put(), 106 | nullptr 107 | )); 108 | if (error) { 109 | throw error.Wrap(L"failed to retrieve the CIM class definition for the Win32_PnPEntity class"); 110 | } 111 | 112 | // Retrieve the input parameters class for the `GetDeviceProperties` method of the CIM class definition 113 | error = CheckHresult(this->pnpEntityClass->GetMethod(L"GetDeviceProperties", 0, this->inputParameters.put(), nullptr)); 114 | if (error) { 115 | throw error.Wrap(L"failed to retrieve the input parameters class for Win32_PnPEntity::GetDeviceProperties"); 116 | } 117 | } 118 | 119 | vector WmiQuery::GetDevicesForAdapters(const map& adapters) 120 | { 121 | // If we don't have any adapters then don't query WMI 122 
| if (adapters.empty()) 123 | { 124 | LOG(L"Empty adapter list provided, skipping WMI query"); 125 | return {}; 126 | } 127 | 128 | // Gather the unique PnP hardware IDs from the DirectX adapters for use in our WQL query string 129 | set hardwareIDs; 130 | for (auto const& adapter : adapters) { 131 | hardwareIDs.insert(FormatHardwareID(adapter.second.HardwareID)); 132 | } 133 | 134 | // Build the WQL query string to retrieve the PnP devices associated with the adapters 135 | wstring query = L"SELECT * FROM Win32_PnPEntity WHERE Present = TRUE AND ("; 136 | int index = 0; 137 | int last = hardwareIDs.size() - 1; 138 | for (auto const& id : hardwareIDs) 139 | { 140 | query += L"DeviceID LIKE \"" + id + L"\"" + ((index < last) ? L" OR " : L""); 141 | index++; 142 | } 143 | query += L")"; 144 | 145 | // Log the query string 146 | LOG(L"Executing WQL query: {}", query); 147 | 148 | // Execute the query 149 | com_ptr enumerator; 150 | auto error = CheckHresult(wbemServices->ExecQuery( 151 | wil::make_bstr(L"WQL").get(), 152 | wil::make_bstr(query.c_str()).get(), 153 | 0, 154 | nullptr, 155 | enumerator.put() 156 | )); 157 | if (error) { 158 | throw error.Wrap(L"WQL query execution failed"); 159 | } 160 | 161 | // Iterate over the retrieved PnP devices and match them to their corresponding DirectX adapters 162 | vector devices; 163 | for (int index = 0; ; index++) 164 | { 165 | // Retrieve the device for the current loop iteration 166 | ULONG numReturned = 0; 167 | com_ptr device; 168 | auto error = CheckHresult(enumerator->Next(WBEM_INFINITE, 1, device.put(), &numReturned)); 169 | if (error) { 170 | throw error.Wrap(L"enumerating PnP devices failed"); 171 | } 172 | if (numReturned == 0) { 173 | break; 174 | } 175 | 176 | // Extract the details for the device and determine whether it matches any of our adapters 177 | Device details = this->ExtractDeviceDetails(device); 178 | auto matchingAdapter = adapters.find(details.DeviceAdapter.InstanceLuid); 179 | if 
(matchingAdapter != adapters.end()) 180 | { 181 | // Log the match 182 | LOG(L"Matched adapter LUID {} to PnP device {}", details.DeviceAdapter.InstanceLuid, details.ID); 183 | 184 | // Replace the device's adapter details with the matching adapter 185 | details.DeviceAdapter = matchingAdapter->second; 186 | 187 | // Include the device in our results 188 | devices.push_back(details); 189 | } 190 | } 191 | 192 | return devices; 193 | } 194 | 195 | Device WmiQuery::ExtractDeviceDetails(const com_ptr& device) const 196 | { 197 | Device details; 198 | DeviceDiscoveryError error; 199 | 200 | // Retrieve the unique PnP device ID of the device 201 | unique_variant vtDeviceID; 202 | error = CheckHresult(device->Get(L"DeviceID", 0, &vtDeviceID, nullptr, nullptr)); 203 | if (error) { 204 | throw error.Wrap(L"failed to retrieve DeviceID property of PnP device"); 205 | } 206 | details.ID = winrt::to_hstring(vtDeviceID.bstrVal); 207 | 208 | // Retrieve the human-readable description of the device 209 | unique_variant vtDescription; 210 | error = CheckHresult(device->Get(L"Description", 0, &vtDescription, nullptr, nullptr)); 211 | if (error) { 212 | throw error.Wrap(L"failed to retrieve Description property of PnP device"); 213 | } 214 | details.Description = winrt::to_hstring(vtDescription.bstrVal); 215 | 216 | // Retrieve the vendor of the device 217 | unique_variant vtVendor; 218 | error = CheckHresult(device->Get(L"Manufacturer", 0, &vtVendor, nullptr, nullptr)); 219 | if (error) { 220 | throw error.Wrap(L"failed to retrieve Manufacturer property of PnP device"); 221 | } 222 | details.Vendor = winrt::to_hstring(vtVendor.bstrVal); 223 | 224 | // Retrieve the object path for the instance so we can call instance methods with it 225 | unique_variant vtPath; 226 | error = CheckHresult(device->Get(L"__Path", 0, &vtPath, nullptr, nullptr)); 227 | if (error) { 228 | throw error.Wrap(L"failed to retrieve __Path property of PnP device"); 229 | } 230 | 231 | // Create an instance of 
the input parameters type for the `GetDeviceProperties` instance method 232 | com_ptr inputArgs; 233 | error = CheckHresult(this->inputParameters->SpawnInstance(0, inputArgs.put())); 234 | if (error) { 235 | throw error.Wrap(L"failed to spawn input parameters instance for Win32_PnPEntity::GetDeviceProperties"); 236 | } 237 | 238 | // Populate the input parameters with the list of device property keys we want to retrieve 239 | unique_variant vtPropertyKeys = SafeArrayFactory::CreateStringArray({ 240 | L"DEVPKEY_Device_Driver", 241 | L"DEVPKEY_Device_LocationPaths", 242 | this->devPropKeyLUID 243 | }); 244 | error = CheckHresult(inputArgs->Put(L"devicePropertyKeys", 0, &vtPropertyKeys, CIM_FLAG_ARRAY | CIM_STRING)); 245 | if (error) { 246 | throw error.Wrap(L"failed to assign input parameters array for Win32_PnPEntity::GetDeviceProperties"); 247 | } 248 | 249 | // Call the `GetDeviceProperties` instance method 250 | com_ptr callResult; 251 | error = CheckHresult(this->wbemServices->ExecMethod( 252 | vtPath.bstrVal, 253 | wil::make_bstr(L"GetDeviceProperties").get(), 254 | 0, 255 | nullptr, 256 | inputArgs.get(), 257 | nullptr, 258 | callResult.put() 259 | )); 260 | if (error) { 261 | throw error.Wrap(L"failed to invoke Win32_PnPEntity::GetDeviceProperties()"); 262 | } 263 | 264 | // Retrieve the return value 265 | com_ptr returnValue; 266 | error = CheckHresult(callResult->GetResultObject(WBEM_INFINITE, returnValue.put())); 267 | if (error) { 268 | throw error.Wrap(L"failed to retrieve return value for Win32_PnPEntity::GetDeviceProperties"); 269 | } 270 | 271 | // Extract the device properties array and verify that it matches the expected type 272 | unique_variant vtPropertiesArray; 273 | error = CheckHresult(returnValue->Get(L"deviceProperties", 0, &vtPropertiesArray, nullptr, nullptr)); 274 | if (error) { 275 | throw error.Wrap(L"failed to retrieve deviceProperties property of Win32_PnPEntity::GetDeviceProperties return value"); 276 | } 277 | if 
(vtPropertiesArray.vt != (VT_ARRAY | VT_UNKNOWN)) { 278 | throw CreateError(L"deviceProperties value was not an array of IUnknown objects"); 279 | } 280 | 281 | // Iterate over the device properties array 282 | SafeArrayIterator propertiesIterator(vtPropertiesArray.parray); 283 | for (auto element : propertiesIterator) 284 | { 285 | // Cast the property object to an IWbemClassObject 286 | IWbemClassObject* object = nullptr; 287 | error = CheckHresult(element->QueryInterface(&object)); 288 | if (error) { 289 | throw error.Wrap(L"IUnknown::QueryInterface() failed for Win32_PnPDeviceProperty object"); 290 | } 291 | 292 | // Retrieve the key name of the property 293 | unique_variant vtKeyName; 294 | error = CheckHresult(object->Get(L"KeyName", 0, &vtKeyName, nullptr, nullptr)); 295 | if (error) { 296 | throw error.Wrap(L"failed to retrieve KeyName property of PnP device property"); 297 | } 298 | wstring keyName(winrt::to_hstring(vtKeyName.bstrVal)); 299 | 300 | // Attempt to retrieve the value of the property 301 | unique_variant data; 302 | HRESULT result = object->Get(L"Data", 0, &data, nullptr, nullptr); 303 | if (FAILED(result)) 304 | { 305 | // The property has no value, so ignore it 306 | continue; 307 | } 308 | 309 | // Determine which device property we are dealing with 310 | if (keyName == L"DEVPKEY_Device_Driver") 311 | { 312 | // Verify that the device driver value is of the expected type 313 | if (data.vt != VT_BSTR) { 314 | throw CreateError(L"DeviceDriver value was not a string"); 315 | } 316 | 317 | // Construct the full path to the registry key for the device's driver 318 | details.DriverRegistryKey = 319 | L"HKEY_LOCAL_MACHINE\\SYSTEM\\CurrentControlSet\\Control\\Class\\" + 320 | winrt::to_hstring(data.bstrVal); 321 | } 322 | else if (keyName == L"DEVPKEY_Device_LocationPaths") 323 | { 324 | // Verify that the LocationPaths array is of the expected type 325 | if (data.vt != (VT_ARRAY | VT_BSTR)) { 326 | throw CreateError(L"LocationPaths value was not 
an array of strings"); 327 | } 328 | 329 | // Retrieve the first element from the LocationPaths array, if it has one (dereferencing begin() of an empty array would read past its data) 330 | SafeArrayIterator locationIterator(data.parray); 331 | if (locationIterator.begin() != locationIterator.end()) { details.LocationPath = winrt::to_hstring(*locationIterator.begin()); } 332 | } 333 | else if (keyName == this->devPropKeyLUID) 334 | { 335 | // Determine whether the LUID value is represented as a raw 64-bit integer or a string representation 336 | if (data.vt == VT_I8) { 337 | details.DeviceAdapter.InstanceLuid = data.llVal; 338 | } 339 | else if (data.vt == VT_BSTR) 340 | { 341 | // Parse the string back into a 64-bit integer 342 | details.DeviceAdapter.InstanceLuid = std::stoll(winrt::to_string(data.bstrVal)); 343 | } 344 | else { 345 | throw CreateError(L"LUID value was not a 64-bit integer or a string"); 346 | } 347 | } 348 | } 349 | 350 | return details; 351 | } 352 | -------------------------------------------------------------------------------- /library/src/WmiQuery.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "Adapter.h" 4 | #include "Device.h" 5 | 6 | using std::map; 7 | using std::vector; 8 | using std::wstring; 9 | using winrt::com_ptr; 10 | 11 | // Provides functionality for querying Windows Management Instrumentation (WMI) 12 | class WmiQuery 13 | { 14 | public: 15 | 16 | WmiQuery(); 17 | 18 | // Retrieves the device details for the underlying PnP devices associated with the supplied DirectX adapters 19 | vector GetDevicesForAdapters(const map& adapters); 20 | 21 | private: 22 | 23 | // Extracts the details from a PnP device 24 | Device ExtractDeviceDetails(const com_ptr& device) const; 25 | 26 | // Our COM objects for communicating with WMI 27 | com_ptr wbemLocator; 28 | com_ptr wbemServices; 29 | com_ptr pnpEntityClass; 30 | com_ptr inputParameters; 31 | 32 | // The string identifier for the DEVPKEY_Device_AdapterLuid device property key 33 | wstring devPropKeyLUID; 34 | }; 35 | 
-------------------------------------------------------------------------------- /library/src/pch.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | // Prevent the GetObject => GetObjectW macro definition from the Windows headers from interfering with the IDL for IWbemServices 12 | #ifdef GetObject 13 | #undef GetObject 14 | #endif 15 | #include 16 | 17 | // Enable wchar_t support for filenames in spdlog 18 | #define SPDLOG_WCHAR_FILENAMES 19 | #include 20 | #define LOG(...) SPDLOG_INFO(__VA_ARGS__) 21 | 22 | // Include the range formatting support from fmt to facilitate logging container types 23 | #include 24 | #define FMT(x) fmt::format(L"{}", x) 25 | 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | #include 37 | -------------------------------------------------------------------------------- /library/test/test-device-discovery-cpp.cpp: -------------------------------------------------------------------------------- 1 | #include "DeviceDiscovery.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | using std::endl; 9 | using std::vector; 10 | using std::wstring; 11 | using std::wclog; 12 | using std::wcout; 13 | 14 | wstring FormatBoolean(bool value) { 15 | return (value ? 
L"true" : L"false"); 16 | } 17 | 18 | int wmain(int argc, wchar_t *argv[], wchar_t *envp[]) 19 | { 20 | // Gather our command-line arguments 21 | vector args; 22 | for (int i = 0; i < argc; ++i) { 23 | args.push_back(argv[i]); 24 | } 25 | 26 | // Enable verbose logging for the device discovery library if it has been requested 27 | if (std::find(args.begin(), args.end(), L"--verbose") != args.end()) { 28 | EnableDiscoveryLogging(); 29 | } 30 | 31 | try 32 | { 33 | // Perform device discovery 34 | DeviceDiscovery discovery; 35 | discovery.DiscoverDevices(DeviceFilter::AllDevices, true, true); 36 | int numDevices = discovery.GetNumDevices(); 37 | wcout << L"DirectX device discovery library version " << GetDiscoveryLibraryVersion() << endl; 38 | wcout << L"Discovered " << numDevices << L" devices.\n" << endl; 39 | 40 | // Print the details for each device 41 | for (int device = 0; device < numDevices; ++device) 42 | { 43 | wcout << L"[Device " << device << L" details]\n\n"; 44 | wcout << L"PnP Hardware ID: " << discovery.GetDeviceID(device) << L"\n"; 45 | wcout << L"DX Adapter LUID: " << discovery.GetDeviceAdapterLUID(device) << L"\n"; 46 | wcout << L"Description: " << discovery.GetDeviceDescription(device) << L"\n"; 47 | wcout << L"Driver Registry Key: " << discovery.GetDeviceDriverRegistryKey(device) << L"\n"; 48 | wcout << L"DriverStore Path: " << discovery.GetDeviceDriverStorePath(device) << L"\n"; 49 | wcout << L"LocationPath: " << discovery.GetDeviceLocationPath(device) << L"\n"; 50 | wcout << L"Vendor: " << discovery.GetDeviceVendor(device) << L"\n"; 51 | wcout << L"Is Integrated: " << FormatBoolean(discovery.IsDeviceIntegrated(device)) << L"\n"; 52 | wcout << L"Is Detachable: " << FormatBoolean(discovery.IsDeviceDetachable(device)) << L"\n"; 53 | wcout << L"Supports Display: " << FormatBoolean(discovery.DoesDeviceSupportDisplay(device)) << L"\n"; 54 | wcout << L"Supports Compute: " << FormatBoolean(discovery.DoesDeviceSupportCompute(device)) << L"\n"; 55 | 56 | 
int numRuntimeFiles = discovery.GetNumRuntimeFiles(device); 57 | wcout << L"\n" << numRuntimeFiles << L" Additional System32 runtime files:\n"; 58 | for (int file = 0; file < numRuntimeFiles; ++file) 59 | { 60 | wcout << L" " 61 | << discovery.GetRuntimeFileSource(device, file) << " => " 62 | << discovery.GetRuntimeFileDestination(device, file) << "\n"; 63 | } 64 | 65 | int numRuntimeFilesWow64 = discovery.GetNumRuntimeFilesWow64(device); 66 | wcout << L"\n" << numRuntimeFilesWow64 << L" Additional SysWOW64 runtime files:\n"; 67 | for (int file = 0; file < numRuntimeFilesWow64; ++file) 68 | { 69 | wcout << L" " 70 | << discovery.GetRuntimeFileSourceWow64(device, file) << L" => " 71 | << discovery.GetRuntimeFileDestinationWow64(device, file) << L"\n"; 72 | } 73 | 74 | wcout << endl; 75 | } 76 | } 77 | catch (const DeviceDiscoveryException& err) { 78 | wclog << L"Error: " << err.what() << endl; 79 | } 80 | 81 | return 0; 82 | } 83 | -------------------------------------------------------------------------------- /library/vcpkg.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://raw.githubusercontent.com/microsoft/vcpkg/master/scripts/vcpkg.schema.json", 3 | "name": "directx-device-discovery", 4 | "version": "0.0.1", 5 | "dependencies": [ 6 | "cppwinrt", 7 | "fmt", 8 | { 9 | "name": "spdlog", 10 | "features": ["wchar"] 11 | }, 12 | "wil" 13 | ], 14 | "supports": "windows & x64", 15 | "builtin-baseline": "ca8bde3748134247a725c215c4f3a364ee9126fe" 16 | } 17 | -------------------------------------------------------------------------------- /plugins/cmd/device-plugin-mcdm/main.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package main 4 | 5 | import ( 6 | "github.com/tensorworks/directx-device-plugins/plugins/internal/discovery" 7 | "github.com/tensorworks/directx-device-plugins/plugins/internal/plugin" 8 | ) 9 | 10 | func main() { 11 | 
plugin.CommonMain("mcdm", "directx.microsoft.com/compute", discovery.ComputeOnly) 12 | } 13 | -------------------------------------------------------------------------------- /plugins/cmd/device-plugin-wddm/main.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package main 4 | 5 | import ( 6 | "github.com/tensorworks/directx-device-plugins/plugins/internal/discovery" 7 | "github.com/tensorworks/directx-device-plugins/plugins/internal/plugin" 8 | ) 9 | 10 | func main() { 11 | plugin.CommonMain("wddm", "directx.microsoft.com/display", discovery.DisplayAndCompute) 12 | } 13 | -------------------------------------------------------------------------------- /plugins/cmd/gen-device-mounts/main.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package main 4 | 5 | import ( 6 | "encoding/json" 7 | "fmt" 8 | "log" 9 | "os" 10 | "os/exec" 11 | "strings" 12 | 13 | "github.com/spf13/pflag" 14 | "github.com/tensorworks/directx-device-plugins/plugins/internal/discovery" 15 | "github.com/tensorworks/directx-device-plugins/plugins/internal/mount" 16 | ) 17 | 18 | // Prints output to stderr (keeps diagnostics separate from the generated flags, which are written to stdout) 19 | func ePrint(a ...any) (n int, err error) { 20 | return fmt.Fprint(os.Stderr, a...) 21 | } 22 | 23 | // Prints output to stderr, with spaces between operands and a trailing newline 24 | func ePrintln(a ...any) (n int, err error) { 25 | return fmt.Fprintln(os.Stderr, a...) 
26 | } 27 | 28 | // Determines whether the specified value exists in the supplied list of values 29 | func contains[T comparable](values []T, value T) bool { 30 | for _, existing := range values { 31 | if existing == value { 32 | return true 33 | } 34 | } 35 | 36 | return false 37 | } 38 | 39 | // Formats a list of numeric values 40 | func formatNumbers[T uint | int64](values []T) string { 41 | formatted := []string{} 42 | for _, value := range values { 43 | formatted = append(formatted, fmt.Sprint(value)) 44 | } 45 | return strings.Join(formatted, ", ") 46 | } 47 | 48 | // Formats a list of string values 49 | func formatStrings(values []string, delimiter string) string { 50 | formatted := []string{} 51 | for _, str := range values { 52 | formatted = append(formatted, fmt.Sprint("\"", str, "\"")) 53 | } 54 | return strings.Join(formatted, delimiter) 55 | } 56 | 57 | func main() { 58 | 59 | // Configure our custom help message 60 | pflag.CommandLine.SetOutput(os.Stderr) 61 | pflag.Usage = func() { 62 | ePrintln("gen-device-mounts: generates flags for exposing devices to containers with `ctr run`") 63 | ePrintln("\nUsage syntax:", os.Args[0], "[-h] [--format text|json] [--all] [--index <index>] [--luid <luid>] [--path <path>] [--run] [--verbose] [<ctr run arguments>]") 64 | ePrintln("\nOptions:") 65 | pflag.PrintDefaults() 66 | ePrintln(strings.Join([]string{ 67 | "", 68 | "The list of available DirectX devices (including their enumeration indices, LUID values, and PCI paths)", 69 | "can be retrieved by running either `test-device-discovery-cpp.exe` or `test-device-discovery-go.exe`.", 70 | "", 71 | "NOTES REGARDING OTHER FRONTENDS", 72 | "-------------------------------", 73 | "", 74 | "Docker:", 75 | "", 76 | "Although Docker version 23.0.0 introduced support for exposing individual devices using their PCI location", 77 | "paths, it still lacks the ability to bind-mount individual files rather than directories, which prevents", 78 | "it from using the flags generated by `gen-device-mounts`. 
This is due to its continued use of the HCSv1", 79 | "API rather than the newer HCSv2 API used by containerd. When Docker eventually migrates to using HCSv2", 80 | "(or using containerd under Windows the way it does under Linux) then `gen-device-mounts` will be updated", 81 | "to add an option to invoke `docker run` instead of `ctr run` when --run is specified.", 82 | "", 83 | "nerdctl:", 84 | "", 85 | "nerdctl is currently blocked by two outstanding issues that prevent it from using the flags generated by", 86 | "`gen-device-mounts`:", 87 | "", 88 | "- https://github.com/containerd/nerdctl/pull/2079", 89 | "- https://github.com/containerd/nerdctl/issues/759", 90 | "", 91 | "Once these blockers have been resolved then `gen-device-mounts` will be updated to add an option to invoke", 92 | "`nerdctl run` instead of `ctr run` when --run is specified.", 93 | }, "\n")) 94 | os.Exit(1) 95 | } 96 | 97 | // Parse our command-line arguments 98 | allDevices := pflag.Bool("all", false, "Expose all available DirectX devices") 99 | outputFormat := pflag.String("format", "text", "The output format for generated flags (\"text\" or \"json\")") 100 | devicesByIndex := pflag.UintSlice("index", []uint{}, "Expose the DirectX device with the specified enumeration index (can be specified multiple times)") 101 | devicesByLUID := pflag.Int64Slice("luid", []int64{}, "Expose the DirectX device with the specified LUID (can be specified multiple times)") 102 | devicesByPath := pflag.StringSlice("path", []string{}, "Expose the DirectX device with the specified PCI path (can be specified multiple times)") 103 | runContainer := pflag.Bool("run", false, "Run a container with the generated flags using ctr run, passing through any additional arguments") 104 | verbose := pflag.Bool("verbose", false, "Enable verbose output") 105 | pflag.Parse() 106 | 107 | // Verify that a valid output format was specified (Fatalf avoids the stray spaces that Fatalln inserts between operands) 108 | if !contains([]string{"text", "json"}, *outputFormat) { 109 | log.Fatalf("Error: unknown output format %q (supported formats are \"text\" and \"json\")", *outputFormat) 110 | } 
111 | 112 | // Print our device selection criteria 113 | if *verbose { 114 | ePrintln("Device selection criteria:") 115 | if *allDevices { 116 | ePrintln("- Include all available DirectX devices") 117 | } else { 118 | if len(*devicesByIndex) > 0 { 119 | ePrintln("- Include DirectX devices with the following enumeration indices:", formatNumbers(*devicesByIndex)) 120 | } 121 | if len(*devicesByLUID) > 0 { 122 | ePrintln("- Include DirectX devices with the following LUID values:", formatNumbers(*devicesByLUID)) 123 | } 124 | if len(*devicesByPath) > 0 { 125 | ePrintln("- Include DirectX devices with the following PCI paths:", formatStrings(*devicesByPath, ", ")) 126 | } 127 | } 128 | ePrintln() 129 | } 130 | 131 | // Attempt to load the DirectX device discovery library 132 | if err := discovery.LoadDiscoveryLibrary(); err != nil { 133 | log.Fatalln("Error:", err) 134 | } 135 | 136 | // Create a new DeviceDiscovery object 137 | deviceDiscovery, err := discovery.NewDeviceDiscovery() 138 | if err != nil { 139 | log.Fatalln("Error:", err) 140 | } 141 | 142 | // Perform device discovery 143 | if err := deviceDiscovery.DiscoverDevices(discovery.AllDevices, true, true); err != nil { 144 | log.Fatalln("Error:", err) 145 | } 146 | 147 | // Filter the list of devices based on the specified selection criteria 148 | filtered := []*discovery.Device{} 149 | for index, device := range deviceDiscovery.Devices { 150 | if *allDevices || contains(*devicesByIndex, uint(index)) || contains(*devicesByLUID, device.AdapterLUID) || contains(*devicesByPath, device.LocationPath) { 151 | filtered = append(filtered, device) 152 | } 153 | } 154 | 155 | // Print the details of the selected devices 156 | if *verbose { 157 | ePrint("Selected ", len(filtered), " device(s) based on selection criteria:\n") 158 | for index, device := range filtered { 159 | ePrint("- Index ", index, ", LUID ", device.AdapterLUID, ", PCI Path ", device.LocationPath, "\n") 160 | } 161 | ePrintln() 162 | } 163 | 164 | // 
Append our default runtime file mounts to the lists for each device 165 | for _, device := range deviceDiscovery.Devices { 166 | 167 | // Determine whether we have any additional runtime files for the device vendor 168 | files, haveFiles := mount.DefaultMounts[strings.ToLower(device.Vendor)] 169 | filesWow64, haveFilesWow64 := mount.DefaultMountsWow64[strings.ToLower(device.Vendor)] 170 | 171 | // Merge any additions for System32 172 | if haveFiles { 173 | ignored := device.AppendRuntimeFiles(files) 174 | for _, file := range ignored { 175 | ePrintln("Ignoring additional 64-bit runtime file because it clashes with an existing filename: ", file) 176 | } 177 | } 178 | 179 | // Merge any additions for SysWOW64 180 | if haveFilesWow64 { 181 | ignored := device.AppendRuntimeFilesWow64(filesWow64) 182 | for _, file := range ignored { 183 | ePrintln("Ignoring additional 32-bit runtime file because it clashes with an existing filename: ", file) 184 | } 185 | } 186 | } 187 | 188 | // Generate the device specs and runtime file mounts for the selected devices 189 | specs := mount.SpecsForDevices(filtered) 190 | mounts := mount.MountsForDevices(filtered) 191 | 192 | // Generate the flags for mounting the devices 193 | flags := []string{} 194 | for _, spec := range specs { 195 | flags = append(flags, "--device", spec.HostPath) 196 | } 197 | for _, mount := range mounts { 198 | flags = append(flags, "--mount", fmt.Sprint("src=", mount.HostPath, ",dst=", mount.GetContainerPath())) 199 | } 200 | 201 | // Determine whether we are just printing the flags, or running a container with them 202 | if *runContainer { 203 | 204 | // Create a command object to represent our `ctr run` invocation 205 | cmd := exec.Command("ctr", "run", "--rm") 206 | 207 | // Allow the child process to inherit all standard streams 208 | cmd.Stdin = os.Stdin 209 | cmd.Stderr = os.Stderr 210 | cmd.Stdout = os.Stdout 211 | 212 | // Append both our generated flags and any loose command-line arguments to the 
invocation 213 | cmd.Args = append(cmd.Args, flags...) 214 | cmd.Args = append(cmd.Args, pflag.Args()...) 215 | 216 | // Print the generated command to stderr, wrapping each flag in quotes 217 | ePrintln(formatStrings(cmd.Args, " ")) 218 | 219 | // Attempt to run `ctr run` 220 | if err := cmd.Run(); err != nil { 221 | log.Fatalln("Error:", err) 222 | } 223 | 224 | } else { 225 | 226 | // Determine which format we are using to print the list of flags 227 | if *outputFormat == "json" { 228 | 229 | // Attempt to format the flags as a JSON array 230 | formatted, err := json.Marshal(flags) 231 | if err != nil { 232 | log.Fatalln("Error:", err) 233 | } 234 | 235 | // Print the JSON array to stdout (json.Marshal returns raw bytes, so convert them to a string before printing) 236 | fmt.Println(string(formatted)) 237 | 238 | } else { 239 | 240 | // Print the list of flags to stdout, wrapping each flag in quotes 241 | fmt.Println(formatStrings(flags, " ")) 242 | 243 | } 244 | } 245 | } 246 | -------------------------------------------------------------------------------- /plugins/cmd/query-hcs-capabilities/main.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package main 4 | 5 | import ( 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "log" 10 | "os/exec" 11 | "strings" 12 | "unsafe" 13 | 14 | "golang.org/x/sys/windows" 15 | "golang.org/x/sys/windows/registry" 16 | ) 17 | 18 | var ( 19 | vmcompute = windows.NewLazyDLL("vmcompute.dll") 20 | hcsGetServiceProperties = vmcompute.NewProc("HcsGetServiceProperties") 21 | ) 22 | 23 | // HCS schema Version structure: 24 | type Version struct { 25 | Major uint32 26 | Minor uint32 27 | } 28 | 29 | // HCS schema BasicInformation structure: 30 | type BasicInformation struct { 31 | SupportedSchemaVersions []*Version 32 | } 33 | 34 | // Modified version of the HCS schema ServiceProperties structure: 35 | // Note that we just treat the array as containing BasicInformation objects, since that's what our specific query returns 36 | type ServiceProperties struct 
{ 37 | Properties []*BasicInformation 38 | } 39 | 40 | func getSupportedSchemas() ([]*Version, error) { 41 | 42 | // Attempt to load vmcompute.dll 43 | if err := vmcompute.Load(); err != nil { 44 | return nil, fmt.Errorf("failed to load %s: %s", vmcompute.Name, err.Error()) 45 | } 46 | 47 | // Convert our query string into a UTF-16 pointer 48 | queryPtr, err := windows.UTF16PtrFromString("{\"PropertyTypes\": [\"Basic\"]}") 49 | if err != nil { 50 | return nil, fmt.Errorf("failed to convert string to UTF-16: %s", err.Error()) 51 | } 52 | 53 | // Call HcsGetServiceProperties() to query the supported schema version 54 | var resultPtr *uint16 = nil 55 | retval, _, _ := hcsGetServiceProperties.Call( 56 | uintptr(unsafe.Pointer(queryPtr)), 57 | uintptr(unsafe.Pointer(&resultPtr)), 58 | ) 59 | 60 | // Verify that the query was successful 61 | if retval != 0 { 62 | return nil, fmt.Errorf("HcsGetServiceProperties() failed: %v", windows.Errno(retval)) 63 | } 64 | 65 | // Convert the result into a JSON string 66 | result := windows.UTF16PtrToString((*uint16)(unsafe.Pointer(resultPtr))) 67 | 68 | // Parse the JSON 69 | serviceProperties := &ServiceProperties{} 70 | if err := json.Unmarshal([]byte(result), &serviceProperties); err != nil { 71 | return nil, err 72 | } 73 | 74 | // Verify that we have at least one supported schema version 75 | if len(serviceProperties.Properties) == 0 || len(serviceProperties.Properties[0].SupportedSchemaVersions) == 0 { 76 | return nil, errors.New("HcsGetServiceProperties() returned zero supported schema versions") 77 | } 78 | 79 | // Return the list of supported schema versions 80 | return serviceProperties.Properties[0].SupportedSchemaVersions, nil 81 | } 82 | 83 | func getWindowsVersion() (string, error) { 84 | 85 | // Use `RtlGetVersion()` to query the Windows version number, so manifest semantics are ignored 86 | versionInfo := windows.RtlGetVersion() 87 | 88 | // Use PowerShell to query WMI for the system caption, since `ProductName` in 
the registry is no longer reliable 89 | productName, err := exec.Command("powershell", "-Command", "(Get-WmiObject -Class Win32_OperatingSystem).Caption").Output() 90 | if err != nil { 91 | return "", fmt.Errorf("failed to query WMI for the product version string: %s", err.Error()) 92 | } 93 | 94 | // Open the registry key for the Windows version information 95 | key, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE) 96 | if err != nil { 97 | return "", fmt.Errorf("failed to open the registry key \"SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\": %s", err.Error()) 98 | } 99 | defer key.Close() 100 | 101 | // Attempt to retrieve the display version on newer systems, falling back to the old release ID value on older systems 102 | displayVersion, _, err := key.GetStringValue("DisplayVersion") 103 | if err != nil { 104 | displayVersion, _, err = key.GetStringValue("ReleaseId") 105 | if err != nil { 106 | return "", fmt.Errorf("failed to retrieve either the \"DisplayVersion\" or \"ReleaseId\" registry value: %s", err.Error()) 107 | } 108 | } 109 | 110 | // Retrieve the revision number, since this isn't included in the `RtlGetVersion()` output 111 | revisionNumber, _, err := key.GetIntegerValue("UBR") 112 | if err != nil { 113 | return "", fmt.Errorf("failed to retrieve the \"UBR\" registry value: %s", err.Error()) 114 | } 115 | 116 | // Build an aggregated version string from the retrieved values 117 | return fmt.Sprintf( 118 | "%s, version %s (OS build %d.%d.%d.%d)", 119 | strings.TrimSpace(string(productName)), 120 | displayVersion, 121 | versionInfo.MajorVersion, 122 | versionInfo.MinorVersion, 123 | versionInfo.BuildNumber, 124 | revisionNumber, 125 | ), nil 126 | } 127 | 128 | func main() { 129 | 130 | // Retrieve the Windows version information 131 | windowsVersion, err := getWindowsVersion() 132 | if err != nil { 133 | log.Fatalf("Failed to retrieve Windows version information: %s", 
err.Error()) 134 | } 135 | 136 | // Query the Host Compute Service (HCS) for the list of supported schema versions 137 | supportedSchemas, err := getSupportedSchemas() 138 | if err != nil { 139 | log.Fatalf("Failed to retrieve the supported HCS schema version: %s", err.Error()) 140 | } 141 | 142 | // Print the Windows version details and supported schema versions (one version per line) 143 | fmt.Println("Operating system version:") 144 | fmt.Println(windowsVersion) 145 | fmt.Println() 146 | fmt.Println("Supported HCS schema versions:") 147 | for _, version := range supportedSchemas { 148 | fmt.Printf("- %d.%d\n", version.Major, version.Minor) 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /plugins/cmd/test-device-discovery-go/main.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package main 4 | 5 | import ( 6 | "flag" 7 | "fmt" 8 | "log" 9 | 10 | "github.com/tensorworks/directx-device-plugins/plugins/internal/discovery" 11 | ) 12 | 13 | func main() { 14 | 15 | // Parse our command-line arguments 16 | verbose := flag.Bool("verbose", false, "enable verbose logging") 17 | flag.Parse() 18 | 19 | // Attempt to load the DirectX device discovery library 20 | if err := discovery.LoadDiscoveryLibrary(); err != nil { 21 | log.Fatalln("Error:", err) 22 | } 23 | 24 | // Enable verbose logging for the device discovery library if it has been requested 25 | if *verbose { 26 | discovery.EnableDiscoveryLogging() 27 | } 28 | 29 | // Create a new DeviceDiscovery object 30 | deviceDiscovery, err := discovery.NewDeviceDiscovery() 31 | if err != nil { 32 | log.Fatalln("Error:", err) 33 | } 34 | 35 | // Perform device discovery 36 | if err := deviceDiscovery.DiscoverDevices(discovery.AllDevices, true, true); err != nil { 37 | log.Fatalln("Error:", err) 38 | } 39 | 40 | // Print the library version string and the number of discovered devices 41 | fmt.Print("DirectX device discovery library 
version ", discovery.GetDiscoveryLibraryVersion(), "\n") 42 | fmt.Print("Discovered ", len(deviceDiscovery.Devices), " devices.\n\n") 43 | 44 | // Print the details for each device 45 | for index, device := range deviceDiscovery.Devices { 46 | fmt.Print("[Device ", index, " details]\n\n") 47 | fmt.Println("PnP Hardware ID: ", device.ID) 48 | fmt.Println("DX Adapter LUID: ", device.AdapterLUID) 49 | fmt.Println("Description: ", device.Description) 50 | fmt.Println("Driver Registry Key:", device.DriverRegistryKey) 51 | fmt.Println("DriverStore Path: ", device.DriverStorePath) 52 | fmt.Println("LocationPath: ", device.LocationPath) 53 | fmt.Println("Vendor: ", device.Vendor) 54 | fmt.Println("Is Integrated: ", device.IsIntegrated) 55 | fmt.Println("Is Detachable: ", device.IsDetachable) 56 | fmt.Println("Supports Display: ", device.SupportsDisplay) 57 | fmt.Println("Supports Compute: ", device.SupportsCompute) 58 | 59 | fmt.Print("\n", len(device.RuntimeFiles), " Additional System32 runtime files:\n") 60 | for _, file := range device.RuntimeFiles { 61 | fmt.Println(" ", file.SourcePath, "=>", file.DestinationFilename) 62 | } 63 | 64 | fmt.Print("\n", len(device.RuntimeFilesWow64), " Additional SysWOW64 runtime files:\n") 65 | for _, file := range device.RuntimeFilesWow64 { 66 | fmt.Println(" ", file.SourcePath, "=>", file.DestinationFilename) 67 | } 68 | 69 | fmt.Print("\n") 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /plugins/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/tensorworks/directx-device-plugins/plugins 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/fsnotify/fsnotify v1.5.4 7 | github.com/spf13/pflag v1.0.5 8 | github.com/spf13/viper v1.12.0 9 | go.uber.org/zap v1.19.0 10 | golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d 11 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a 12 | google.golang.org/grpc v1.46.2 13 | k8s.io/kubelet 
v0.24.1 14 | ) 15 | 16 | require ( 17 | github.com/gogo/protobuf v1.3.2 // indirect 18 | github.com/golang/protobuf v1.5.2 // indirect 19 | github.com/hashicorp/hcl v1.0.0 // indirect 20 | github.com/magiconair/properties v1.8.6 // indirect 21 | github.com/mitchellh/mapstructure v1.5.0 // indirect 22 | github.com/pelletier/go-toml v1.9.5 // indirect 23 | github.com/pelletier/go-toml/v2 v2.0.1 // indirect 24 | github.com/spf13/afero v1.8.2 // indirect 25 | github.com/spf13/cast v1.5.0 // indirect 26 | github.com/spf13/jwalterweatherman v1.1.0 // indirect 27 | github.com/subosito/gotenv v1.3.0 // indirect 28 | go.uber.org/atomic v1.7.0 // indirect 29 | go.uber.org/multierr v1.6.0 // indirect 30 | golang.org/x/net v0.0.0-20220520000938-2e3eb7b945c2 // indirect 31 | golang.org/x/text v0.3.7 // indirect 32 | google.golang.org/genproto v0.0.0-20220519153652-3a47de7e79bd // indirect 33 | google.golang.org/protobuf v1.28.0 // indirect 34 | gopkg.in/ini.v1 v1.66.4 // indirect 35 | gopkg.in/yaml.v2 v2.4.0 // indirect 36 | gopkg.in/yaml.v3 v3.0.0 // indirect 37 | ) 38 | -------------------------------------------------------------------------------- /plugins/internal/discovery/device.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package discovery 4 | 5 | // Represents a DirectX device 6 | type Device struct { 7 | 8 | // The unique PNP hardware identifier for the device 9 | ID string 10 | 11 | // A human-readable description of the device (e.g. 
the model name) 12 | Description string 13 | 14 | // The registry key that contains the driver details for the device 15 | DriverRegistryKey string 16 | 17 | // The absolute path to the directory in the driver store that contains the driver files for the device 18 | DriverStorePath string 19 | 20 | // The path to the physical location of the device in the system 21 | LocationPath string 22 | 23 | // The list of additional files that need to be copied from the driver store to the System32 directory in order to use the device with non-DirectX runtimes 24 | RuntimeFiles []*RuntimeFile 25 | 26 | // The list of additional files that need to be copied from the driver store to the SysWOW64 directory in order to use the device with non-DirectX runtimes 27 | RuntimeFilesWow64 []*RuntimeFile 28 | 29 | // The vendor of the device (e.g. AMD, Intel, NVIDIA) 30 | Vendor string 31 | 32 | // The DirectX adapter LUID associated with the PnP device 33 | AdapterLUID int64 34 | 35 | // Specifies whether the device is an integrated GPU (as opposed to a discrete GPU) 36 | IsIntegrated bool 37 | 38 | // Specifies whether the device is a detachable device (i.e. the device can be removed at runtime) 39 | IsDetachable bool 40 | 41 | // Specifies whether the device supports display 42 | // (i.e. supports either the DXCORE_ADAPTER_ATTRIBUTE_D3D11_GRAPHICS or DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS attributes) 43 | SupportsDisplay bool 44 | 45 | // Specifies whether the device supports compute (i.e. 
supports the DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE attribute) 46 | SupportsCompute bool 47 | } 48 | 49 | // Appends the supplied list of 64-bit runtime files, ignoring and returning any files that clash with existing mount destinations 50 | func (d *Device) AppendRuntimeFiles(files []*RuntimeFile) []*RuntimeFile { 51 | merged, ignored := mergeRuntimeFiles(d.RuntimeFiles, files) 52 | d.RuntimeFiles = merged 53 | return ignored 54 | } 55 | 56 | // Appends the supplied list of 32-bit runtime files, ignoring and returning any files that clash with existing mount destinations 57 | func (d *Device) AppendRuntimeFilesWow64(files []*RuntimeFile) []*RuntimeFile { 58 | merged, ignored := mergeRuntimeFiles(d.RuntimeFilesWow64, files) 59 | d.RuntimeFilesWow64 = merged 60 | return ignored 61 | } 62 | 63 | // Merges two lists of runtime files, ignoring any files that clash with existing mount destinations. Returns both the merged list and the list of ignored files. 64 | func mergeRuntimeFiles(files []*RuntimeFile, additions []*RuntimeFile) ([]*RuntimeFile, []*RuntimeFile) { 65 | merged := files 66 | ignored := []*RuntimeFile{} 67 | 68 | // Add each additional file to the list if it doesn't clash with an existing destination filename 69 | outer: 70 | for _, additionalFile := range additions { 71 | 72 | // Determine whether we have an existing file with the same destination as the new file 73 | for _, existingFile := range merged { 74 | if existingFile.DestinationFilename == additionalFile.DestinationFilename { 75 | ignored = append(ignored, additionalFile) 76 | continue outer 77 | } 78 | } 79 | 80 | // Add the file to the list 81 | merged = append(merged, additionalFile) 82 | } 83 | 84 | return merged, ignored 85 | } 86 | -------------------------------------------------------------------------------- /plugins/internal/discovery/device_filter.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package discovery 
4 | 5 | type DeviceFilter int32 6 | 7 | const ( 8 | AllDevices DeviceFilter = 0 9 | DisplaySupported DeviceFilter = 1 10 | ComputeSupported DeviceFilter = 2 11 | DisplayOnly DeviceFilter = 3 12 | ComputeOnly DeviceFilter = 4 13 | DisplayAndCompute DeviceFilter = 5 14 | ) 15 | -------------------------------------------------------------------------------- /plugins/internal/discovery/runtime_file.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package discovery 4 | 5 | // Represents an additional file that needs to be copied from the driver store to the system directory in order to use a device with non-DirectX runtimes 6 | // (For details, see: ) 7 | type RuntimeFile struct { 8 | 9 | // The relative path to the file in the driver store 10 | SourcePath string `mapstructure:"source"` 11 | 12 | // The filename that the file should be given when copied to the destination directory 13 | DestinationFilename string `mapstructure:"destination"` 14 | } 15 | -------------------------------------------------------------------------------- /plugins/internal/mount/default_mounts.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package mount 4 | 5 | import ( 6 | "github.com/tensorworks/directx-device-plugins/plugins/internal/discovery" 7 | ) 8 | 9 | // The default 64-bit runtime mounts that we add for each device vendor, to supplement those specified in the registry 10 | var DefaultMounts = map[string][]*discovery.RuntimeFile{ 11 | VendorNvidia: { 12 | { 13 | SourcePath: "nvidia-smi.exe", 14 | DestinationFilename: "nvidia-smi.exe", 15 | }, 16 | { 17 | SourcePath: "vulkaninfo-x64.exe", 18 | DestinationFilename: "vulkaninfo.exe", 19 | }, 20 | }, 21 | } 22 | 23 | // The default 32-bit runtime mounts that we add for each device vendor, to supplement those specified in the registry 24 | var DefaultMountsWow64 = map[string][]*discovery.RuntimeFile{ 25 
| VendorNvidia: { 26 | { 27 | SourcePath: "vulkaninfo-x86.exe", 28 | DestinationFilename: "vulkaninfo.exe", 29 | }, 30 | }, 31 | } 32 | -------------------------------------------------------------------------------- /plugins/internal/mount/device_mounts.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package mount 4 | 5 | import ( 6 | "os" 7 | "path/filepath" 8 | 9 | pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" 10 | 11 | "github.com/tensorworks/directx-device-plugins/plugins/internal/discovery" 12 | ) 13 | 14 | // Generates the device specs for the supplied list of devices 15 | func SpecsForDevices(devices []*discovery.Device) []*pluginapi.DeviceSpec { 16 | 17 | specs := []*pluginapi.DeviceSpec{} 18 | 19 | // Provide the physical location path for each device, avoiding duplicates (duplicate paths can occur when 20 | // multitenancy is enabled and two requested device IDs map to the same underlying physical device) 21 | for _, device := range devices { 22 | specs = appendUniqueSpec(specs, &pluginapi.DeviceSpec{ 23 | HostPath: "vpci-location-path://" + device.LocationPath, 24 | ContainerPath: "", 25 | Permissions: "", 26 | }) 27 | } 28 | 29 | return specs 30 | } 31 | 32 | // Generates the runtime file mounts for the supplied list of devices 33 | func MountsForDevices(devices []*discovery.Device) []*pluginapi.Mount { 34 | 35 | mounts := []*pluginapi.Mount{} 36 | 37 | for _, device := range devices { 38 | 39 | // Generates the mounts for a list of runtime files 40 | generateMounts := func(files []*discovery.RuntimeFile, destinationRoot string) { 41 | for _, file := range files { 42 | 43 | // Resolve the absolute paths to the host source file and the container destination file 44 | source := filepath.Join(device.DriverStorePath, file.SourcePath) 45 | destination := filepath.Join(destinationRoot, file.DestinationFilename) 46 | 47 | // Only mount the file if it exists on the host and can be 
accessed, and isn't a duplicate 48 | // (Note that duplicate container paths can occur not only when mounting multiple devices 49 | // from a single vendor, but also when device drivers from different vendors mount files 50 | // to the same target path, which means that a container will only see the files from the 51 | // first device's vendor when collisions occur between different device drivers) 52 | _, err := os.Stat(source) 53 | if err == nil { 54 | mounts = appendUniqueMount(mounts, &pluginapi.Mount{ 55 | HostPath: source, 56 | ContainerPath: destination, 57 | ReadOnly: true, 58 | }) 59 | } 60 | } 61 | } 62 | 63 | // Generate the mounts for both the System32 and SysWOW64 runtime files 64 | generateMounts(device.RuntimeFiles, "C:\\Windows\\System32") 65 | generateMounts(device.RuntimeFilesWow64, "C:\\Windows\\SysWOW64") 66 | } 67 | 68 | return mounts 69 | } 70 | 71 | // Appends a device spec to an existing list of device specs if it's not already present in the list 72 | func appendUniqueSpec(specs []*pluginapi.DeviceSpec, newSpec *pluginapi.DeviceSpec) []*pluginapi.DeviceSpec { 73 | for _, existing := range specs { 74 | if existing.HostPath == newSpec.HostPath { 75 | return specs 76 | } 77 | } 78 | 79 | return append(specs, newSpec) 80 | } 81 | 82 | // Appends a mount to an existing list of mounts if it's not already present in the list 83 | func appendUniqueMount(mounts []*pluginapi.Mount, newMount *pluginapi.Mount) []*pluginapi.Mount { 84 | for _, existing := range mounts { 85 | if existing.ContainerPath == newMount.ContainerPath { 86 | return mounts 87 | } 88 | } 89 | 90 | return append(mounts, newMount) 91 | } 92 | -------------------------------------------------------------------------------- /plugins/internal/mount/vendors.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package mount 4 | 5 | // Vendor identifier for NVIDIA devices 6 | const VendorNvidia = "nvidia" 7 | 
-------------------------------------------------------------------------------- /plugins/internal/plugin/common_main.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package plugin 4 | 5 | import ( 6 | "log" 7 | "os" 8 | "os/signal" 9 | "strings" 10 | "syscall" 11 | 12 | "github.com/tensorworks/directx-device-plugins/plugins/internal/discovery" 13 | "go.uber.org/zap" 14 | ) 15 | 16 | // The version number for the plugin 17 | const version = "0.0.1" 18 | 19 | func CommonMain(pluginName string, resourceName string, filter discovery.DeviceFilter) { 20 | 21 | // Create a logger that prints debug and higher verbosity level messages 22 | logger, err := zap.NewDevelopment() 23 | if err != nil { 24 | log.Fatalln("Error: failed to create the logger:", err) 25 | } 26 | 27 | // Sugar the logger 28 | sugar := logger.Sugar() 29 | defer sugar.Sync() 30 | 31 | // Log the plugin name and version 32 | sugar.Infof("Kubernetes device plugin for %s, version %s", strings.ToUpper(pluginName), version) 33 | 34 | // Load the plugin configuration data 35 | config, err := LoadConfig(pluginName, sugar) 36 | if err != nil { 37 | sugar.Errorf("Error: failed to load the device plugin configuration: %v", err) 38 | return 39 | } 40 | 41 | // Create the device plugin and start the device watcher 42 | server, err := NewDevicePlugin(pluginName, version, resourceName, filter, config, sugar) 43 | if err != nil { 44 | sugar.Errorf("Error: failed to create the device plugin: %v", err) 45 | return 46 | } 47 | 48 | //Ensure the plugin is destroyed and the device watcher stopped when we complete execution 49 | defer server.Destroy() 50 | 51 | // Attempt to start the plugin's gRPC server 52 | if err := server.StartServer(); err != nil { 53 | sugar.Errorf("Error: failed to start the gRPC server: %v", err) 54 | return 55 | } 56 | 57 | // Ensure we perform a graceful shutdown of the gRPC server before we destroy the plugin 58 | defer 
server.StopServer() 59 | 60 | // Attempt to register the device plugin with the Kubelet 61 | if err := server.RegisterWithKubelet(); err != nil { 62 | sugar.Errorf("Error: failed to register the device plugin with the Kubelet: %v", err) 63 | return 64 | } 65 | 66 | // Wire up a signal handler to receive shutdown requests 67 | signals := make(chan os.Signal, 1) 68 | signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) 69 | 70 | // Serve requests until we receive a shutdown request or an error occurs 71 | sugar.Info("Serving until a shutdown request is received") 72 | for { 73 | select { 74 | case sig := <-signals: 75 | sugar.Infow("Received signal", "signal", sig) 76 | return 77 | 78 | case err := <-server.Errors: 79 | sugar.Errorf("Error: %v", err) 80 | return 81 | } 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /plugins/internal/plugin/deletion_watcher.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package plugin 4 | 5 | import ( 6 | "github.com/fsnotify/fsnotify" 7 | ) 8 | 9 | type DeletionWatcher struct { 10 | watcher *fsnotify.Watcher 11 | Deleted chan struct{} 12 | Errors chan error 13 | } 14 | 15 | func WatchForDeletion(file string) (*DeletionWatcher, error) { 16 | 17 | // Create a new filesystem watcher 18 | fsWatcher, err := fsnotify.NewWatcher() 19 | if err != nil { 20 | return nil, err 21 | } 22 | 23 | // Add a watch for the specified file 24 | err = fsWatcher.Add(file) 25 | if err != nil { 26 | return nil, err 27 | } 28 | 29 | // Wrap the filesystem watcher in a deletion watcher 30 | deletionWatcher := &DeletionWatcher{ 31 | watcher: fsWatcher, 32 | Deleted: make(chan struct{}, 1), 33 | Errors: make(chan error, 1), 34 | } 35 | 36 | // Start the watcher goroutine 37 | go deletionWatcher.watch() 38 | return deletionWatcher, nil 39 | } 40 | 41 | // Cancels the watch 42 | func (d *DeletionWatcher) Cancel() { 43 | d.watcher.Close() 44 
| } 45 | 46 | func (d *DeletionWatcher) watch() { 47 | 48 | // Ensure the underlying filesystem watcher is closed when we are done 49 | defer d.watcher.Close() 50 | 51 | // Ensure the channels are closed when we are done 52 | defer close(d.Deleted) 53 | defer close(d.Errors) 54 | 55 | // Process events and errors 56 | for { 57 | select { 58 | case event, ok := <-d.watcher.Events: 59 | if !ok { 60 | return 61 | } 62 | 63 | // Check whether the event is a deletion event 64 | if event.Op&fsnotify.Remove == fsnotify.Remove { 65 | d.Deleted <- struct{}{} 66 | return 67 | } 68 | 69 | case err, ok := <-d.watcher.Errors: 70 | if !ok { 71 | return 72 | } 73 | 74 | d.Errors <- err 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /plugins/internal/plugin/device_plugin.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package plugin 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "net" 9 | "path/filepath" 10 | "strings" 11 | "sync" 12 | "time" 13 | 14 | "go.uber.org/zap" 15 | "google.golang.org/grpc" 16 | "google.golang.org/grpc/codes" 17 | "google.golang.org/grpc/credentials/insecure" 18 | "google.golang.org/grpc/status" 19 | pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" 20 | 21 | "github.com/tensorworks/directx-device-plugins/plugins/internal/discovery" 22 | "github.com/tensorworks/directx-device-plugins/plugins/internal/mount" 23 | ) 24 | 25 | type DevicePlugin struct { 26 | 27 | // The name of the plugin 28 | name string 29 | 30 | // The configuration data for the plugin 31 | config *PluginConfig 32 | 33 | // The Unix socket on which the plugin's gRPC server listens for connections 34 | endpoint string 35 | endpointDeleted *DeletionWatcher 36 | 37 | // The resource name that the plugin advertises to the Kubelet 38 | resourceName string 39 | 40 | // The device watcher that monitors the available DirectX devices 41 | watcher *DeviceWatcher 42 | 
43 | // The most recent device list received from the device watcher, and a mutex to protect concurrent access 44 | currentDevices []*discovery.Device 45 | devicesMutex sync.Mutex 46 | 47 | // The logger used to log diagnostic information 48 | logger *zap.SugaredLogger 49 | 50 | // The gRPC server that services requests from the Kubelet 51 | server *grpc.Server 52 | 53 | // The channel used to trigger a restart of the gRPC server in the event of a Kubelet restart 54 | restart chan struct{} 55 | 56 | // The channel used to stop the ListAndWatch streaming RPC during server shutdown 57 | stopListWatch chan struct{} 58 | 59 | // The channel used for reporting errors while the gRPC server is running 60 | Errors chan error 61 | } 62 | 63 | // Creates a new device plugin 64 | func NewDevicePlugin(pluginName string, pluginVersion string, resourceName string, filter discovery.DeviceFilter, config *PluginConfig, logger *zap.SugaredLogger) (*DevicePlugin, error) { 65 | 66 | // Attempt to create a new DeviceWatcher 67 | watcher, err := NewDeviceWatcher( 68 | pluginVersion, 69 | filter, 70 | config.IncludeIntegrated, 71 | config.IncludeDetachable, 72 | config.AdditionalMounts, 73 | config.AdditionalMountsWow64, 74 | logger, 75 | ) 76 | if err != nil { 77 | return nil, err 78 | } 79 | 80 | // Verify that device watcher can successfully list devices 81 | select { 82 | case <-watcher.Updates: 83 | logger.Info("Initial device list retrieved successfully") 84 | 85 | case <-watcher.Errors: 86 | watcher.Destroy() 87 | return nil, fmt.Errorf("failed to perform device discovery: %v", err) 88 | } 89 | 90 | // Create a new device plugin instance with the device watcher 91 | plugin := &DevicePlugin{ 92 | name: pluginName, 93 | config: config, 94 | endpoint: "", 95 | endpointDeleted: nil, 96 | resourceName: resourceName, 97 | watcher: watcher, 98 | currentDevices: []*discovery.Device{}, 99 | devicesMutex: sync.Mutex{}, 100 | logger: logger, 101 | server: nil, 102 | restart: make(chan 
struct{}, 1), 103 | stopListWatch: nil, 104 | Errors: make(chan error, 1), 105 | } 106 | 107 | // Forward any device watcher errors to the plugin's error channel 108 | go func() { 109 | for err := range plugin.watcher.Errors { 110 | plugin.Errors <- err 111 | } 112 | }() 113 | 114 | // Restart the plugin's gRPC server and perform plugin registration again in the event of a Kubelet restart 115 | go func() { 116 | for range plugin.restart { 117 | 118 | // Restart the gRPC server with a new Unix socket filename since the Kubelet will delete the old one 119 | if err := plugin.RestartServer(); err != nil { 120 | plugin.Errors <- err 121 | } 122 | 123 | // Register the device plugin with the new Kubelet instance 124 | if err := plugin.RegisterWithKubelet(); err != nil { 125 | plugin.Errors <- err 126 | } 127 | } 128 | }() 129 | 130 | return plugin, nil 131 | } 132 | 133 | // Starts the gRPC server for the device plugin 134 | func (p *DevicePlugin) StartServer() error { 135 | 136 | // Create a new gRPC server instance 137 | // (Note that this is necessary to support restarts, since a server instance cannot be reused after it has stopped serving) 138 | p.server = grpc.NewServer() 139 | 140 | // Register our service implementation with the gRPC server 141 | p.logger.Info("Registering the service implementation with the gRPC server") 142 | pluginapi.RegisterDevicePluginServer(p.server, p) 143 | 144 | // Append a timestamp to the filename for the gRPC server's Unix socket to ensure it is unique 145 | p.endpoint = filepath.Join(pluginapi.DevicePluginPathWindows, fmt.Sprintf("%s-%d.sock", p.name, time.Now().UnixMilli())) 146 | 147 | // Attempt to listen for connections on our Unix socket 148 | p.logger.Infow("Listening on endpoint", "endpoint", p.endpoint) 149 | listener, err := net.Listen("unix", p.endpoint) 150 | if err != nil { 151 | return err 152 | } 153 | 154 | // Create the shutdown channel for stopping the ListAndWatch streaming RPC 155 | p.stopListWatch = make(chan 
struct{}) 156 | 157 | // Create a file deletion watcher for our Unix socket 158 | endpointDeleted, err := WatchForDeletion(p.endpoint) 159 | if err != nil { 160 | return err 161 | } 162 | 163 | // We detect Kubelet restarts by detecting the deletion of our socket 164 | p.endpointDeleted = endpointDeleted 165 | go func() { 166 | for { 167 | select { 168 | 169 | case err, ok := <-p.endpointDeleted.Errors: 170 | if !ok { 171 | p.logger.Info("DeletionWatcher error channel closed") 172 | return 173 | } 174 | p.Errors <- err 175 | 176 | case _, ok := <-p.endpointDeleted.Deleted: 177 | if !ok { 178 | p.logger.Info("DeletionWatcher deletion channel closed") 179 | return 180 | } 181 | 182 | p.logger.Info("Endpoint deletion detected, triggering a restart of the gRPC server") 183 | p.restart <- struct{}{} 184 | } 185 | } 186 | }() 187 | 188 | // Start the gRPC server in a new goroutine and send any errors back through our error channel 189 | go func() { 190 | p.logger.Info("Starting the gRPC server") 191 | if err := p.server.Serve(listener); err != nil { 192 | p.Errors <- err 193 | } 194 | }() 195 | 196 | return nil 197 | } 198 | 199 | // Gracefully stops the gRPC server for the device plugin 200 | func (p *DevicePlugin) StopServer() { 201 | 202 | // If StopServer() is called before StartServer() then do nothing 203 | if p.server == nil { 204 | return 205 | } 206 | 207 | // Stop the ListAndWatch streaming RPC if it is running 208 | close(p.stopListWatch) 209 | 210 | // Stop watching our Unix socket for deletion events 211 | p.endpointDeleted.Cancel() 212 | 213 | // Attempt to perform a graceful shutdown of the server (this will delete the Unix socket) 214 | p.logger.Info("Gracefully stopping the gRPC server") 215 | p.server.GracefulStop() 216 | p.server = nil 217 | } 218 | 219 | // Restarts the gRPC server for the device plugin, generating a new Unix socket filename 220 | func (p *DevicePlugin) RestartServer() error { 221 | p.StopServer() 222 | return p.StartServer() 223 | } 
224 | 225 | // Destroys our underlying resources 226 | func (p *DevicePlugin) Destroy() { 227 | p.watcher.Destroy() 228 | close(p.restart) 229 | close(p.Errors) 230 | } 231 | 232 | // Registers the device plugin with the Kubelet 233 | func (p *DevicePlugin) RegisterWithKubelet() error { 234 | 235 | // Set a 60 second timeout when attempting to connect to the Kubelet 236 | ctxConnect, cancelConnect := context.WithTimeout(context.Background(), time.Minute) 237 | defer cancelConnect() 238 | 239 | // Create a dialler that treats the Kubelet's endpoint as a Unix socket rather than a TCP address 240 | dialler := grpc.WithContextDialer(func(ctx context.Context, address string) (net.Conn, error) { 241 | return (&net.Dialer{}).DialContext(ctx, "unix", address) 242 | }) 243 | 244 | // Attempt to connect to the Kubelet's gRPC service using the socket path for Windows 245 | p.logger.Infow("Connecting to the Kubelet", "endpoint", pluginapi.KubeletSocketWindows) 246 | conn, err := grpc.DialContext( 247 | ctxConnect, 248 | pluginapi.KubeletSocketWindows, 249 | grpc.WithBlock(), 250 | grpc.WithTransportCredentials(insecure.NewCredentials()), 251 | dialler, 252 | ) 253 | if err != nil { 254 | return fmt.Errorf("failed to connect to the Kubelet's gRPC service: %v", err) 255 | } 256 | defer conn.Close() 257 | 258 | // Prepare a registration request 259 | request := &pluginapi.RegisterRequest{ 260 | Version: pluginapi.Version, 261 | Endpoint: filepath.Base(p.endpoint), 262 | ResourceName: p.resourceName, 263 | } 264 | 265 | // Set a 60 second timeout when attempting to register with the Kubelet 266 | ctxRegister, cancelRegister := context.WithTimeout(context.Background(), time.Minute) 267 | defer cancelRegister() 268 | 269 | // Create a registration client and attempt to send our registration request 270 | p.logger.Infow("Sending registration request to the Kubelet", "request", request) 271 | client := pluginapi.NewRegistrationClient(conn) 272 | if _, err := 
client.Register(ctxRegister, request); err != nil { 273 | return fmt.Errorf("failed to register the device plugin with the Kubelet: %v", err) 274 | } 275 | 276 | p.logger.Info("Successfully registered the device plugin with the Kubelet") 277 | return nil 278 | } 279 | 280 | func (p *DevicePlugin) GetDevicePluginOptions(ctx context.Context, request *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) { 281 | 282 | // Instruct the Kubelet not to call the GetPreferredAllocation or PreStartContainer RPCs, since they aren't necessary 283 | p.logger.Info("GetDevicePluginOptions RPC invoked") 284 | return &pluginapi.DevicePluginOptions{ 285 | GetPreferredAllocationAvailable: false, 286 | PreStartRequired: false, 287 | }, nil 288 | } 289 | 290 | func (p *DevicePlugin) ListAndWatch(request *pluginapi.Empty, stream pluginapi.DevicePlugin_ListAndWatchServer) error { 291 | 292 | // Force a device list refresh to ensure we have an initial list for the Kubelet 293 | p.logger.Info("ListAndWatch streaming RPC started, refreshing the device list") 294 | p.watcher.ForceRefresh() 295 | 296 | // Continue sending updates as our device list changes or until shutdown is requested 297 | for { 298 | select { 299 | 300 | case <-p.stopListWatch: 301 | p.logger.Info("Shutdown requested, stopping ListAndWatch streaming RPC") 302 | return nil 303 | 304 | case <-stream.Context().Done(): 305 | p.logger.Info("Kubelet disconnect detected, stopping ListAndWatch streaming RPC") 306 | return nil 307 | 308 | case devices := <-p.watcher.Updates: 309 | p.logger.Infow("Received new device list", "devices", devices) 310 | 311 | // Store the device list 312 | p.devicesMutex.Lock() 313 | p.currentDevices = devices 314 | p.devicesMutex.Unlock() 315 | 316 | // Convert the device discovery devices to Kubernetes device plugin API devices 317 | kubeletDevices := []*pluginapi.Device{} 318 | for _, device := range devices { 319 | 320 | // Advertise each device multiple times, as per our multitenancy setting 
321 | for i := uint32(0); i < p.config.Multitenancy; i += 1 { 322 | kubeletDevices = append(kubeletDevices, &pluginapi.Device{ 323 | ID: fmt.Sprintf("%s\\%d", device.ID, i), 324 | Health: pluginapi.Healthy, 325 | }) 326 | } 327 | } 328 | 329 | // Send the device list to the Kubelet 330 | p.logger.Infow("Sending device list to Kubelet", "devices", kubeletDevices) 331 | stream.Send(&pluginapi.ListAndWatchResponse{ 332 | Devices: kubeletDevices, 333 | }) 334 | } 335 | } 336 | } 337 | 338 | func (p *DevicePlugin) GetPreferredAllocation(context.Context, *pluginapi.PreferredAllocationRequest) (*pluginapi.PreferredAllocationResponse, error) { 339 | 340 | // This RPC should never be called 341 | return nil, status.Error(codes.Unimplemented, "GetPreferredAllocation is not implemented") 342 | } 343 | 344 | // Retrieves the device with the specified ID 345 | func (p *DevicePlugin) GetDeviceForID(deviceID string) (*discovery.Device, error) { 346 | 347 | // Strip the multitenancy suffix from the device ID 348 | backslash := strings.LastIndex(deviceID, "\\") 349 | if backslash == -1 { 350 | return nil, fmt.Errorf("malformed device ID \"%s\"", deviceID) 351 | } 352 | stripped := deviceID[0:backslash] 353 | 354 | // Lock the mutex for the device list 355 | p.devicesMutex.Lock() 356 | defer p.devicesMutex.Unlock() 357 | 358 | // Search for a device with the specified ID 359 | for _, device := range p.currentDevices { 360 | if device.ID == stripped { 361 | return device, nil 362 | } 363 | } 364 | 365 | return nil, fmt.Errorf("could not find device with ID \"%s\"", stripped) 366 | } 367 | 368 | func (p *DevicePlugin) Allocate(ctx context.Context, request *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) { 369 | 370 | p.logger.Infow("Allocate RPC invoked, processing allocation request", "request", request) 371 | response := &pluginapi.AllocateResponse{} 372 | 373 | // Process each of the container requests 374 | for _, containerReq := range request.ContainerRequests { 
375 | 376 | // Gather the list of requested devices for the container 377 | devices := []*discovery.Device{} 378 | for _, deviceID := range containerReq.DevicesIDs { 379 | 380 | // Verify that the requested device exists 381 | device, err := p.GetDeviceForID(deviceID) 382 | if err != nil { 383 | return nil, err 384 | } 385 | 386 | // Add the device to the list 387 | devices = append(devices, device) 388 | } 389 | 390 | // Generate the device specs and runtime file mounts for the requested devices, appending the container response to our overall response 391 | response.ContainerResponses = append(response.ContainerResponses, &pluginapi.ContainerAllocateResponse{ 392 | Devices: mount.SpecsForDevices(devices), 393 | Mounts: mount.MountsForDevices(devices), 394 | }) 395 | } 396 | 397 | p.logger.Infow("Sending allocation response", "response", response) 398 | return response, nil 399 | } 400 | 401 | func (p *DevicePlugin) PreStartContainer(context.Context, *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) { 402 | 403 | // This RPC should never be called 404 | return nil, status.Error(codes.Unimplemented, "PreStartContainer is not implemented") 405 | } 406 | -------------------------------------------------------------------------------- /plugins/internal/plugin/device_watcher.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package plugin 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "strings" 9 | "time" 10 | 11 | "github.com/tensorworks/directx-device-plugins/plugins/internal/discovery" 12 | "go.uber.org/zap" 13 | ) 14 | 15 | // Watches for device updates 16 | type DeviceWatcher struct { 17 | 18 | // Our interface to the underlying DeviceDiscovery object from the DirectX device discovery library 19 | deviceDiscovery *discovery.DeviceDiscovery 20 | 21 | // The filter used to control which devices are reported 22 | deviceFilter discovery.DeviceFilter 23 | 24 | // Whether to 
include integrated GPUs when reporting devices 25 | includeIntegrated bool 26 | 27 | // Whether to include detachable devices when reporting devices 28 | includeDetachable bool 29 | 30 | // The list of additional runtime files for each device vendor that will be added to each device's list for System32 31 | additionalRuntimeFiles map[string][]*discovery.RuntimeFile 32 | 33 | // The list of additional runtime files for each device vendor that will be added to each device's list for SysWOW64 34 | additionalRuntimeFilesWow64 map[string][]*discovery.RuntimeFile 35 | 36 | // The logger used to log diagnostic information 37 | logger *zap.SugaredLogger 38 | 39 | // The channel used to request a forced refresh of the device list 40 | refresh chan struct{} 41 | 42 | // The channel used to stop the device discovery goroutine 43 | shutdown chan struct{} 44 | 45 | // The channel used to report errors 46 | Errors chan error 47 | 48 | // The channel used to report device updates 49 | Updates chan []*discovery.Device 50 | } 51 | 52 | func NewDeviceWatcher( 53 | expectedVersion string, 54 | deviceFilter discovery.DeviceFilter, 55 | includeIntegrated bool, 56 | includeDetachable bool, 57 | additionalRuntimeFiles map[string][]*discovery.RuntimeFile, 58 | additionalRuntimeFilesWow64 map[string][]*discovery.RuntimeFile, 59 | logger *zap.SugaredLogger, 60 | ) (*DeviceWatcher, error) { 61 | 62 | // Attempt to load the DirectX device discovery library 63 | if err := discovery.LoadDiscoveryLibrary(); err != nil { 64 | return nil, err 65 | } 66 | 67 | // Verify that the version of the device discovery library matches our expected version 68 | libraryVersion := discovery.GetDiscoveryLibraryVersion() 69 | if libraryVersion != expectedVersion { 70 | return nil, fmt.Errorf( 71 | "device discovery library version mismatch (found %s, expected %s)", 72 | libraryVersion, 73 | expectedVersion, 74 | ) 75 | } 76 | 77 | // Enable verbose logging for the device discovery library 78 | 
discovery.EnableDiscoveryLogging() 79 | 80 | // Create a new DeviceDiscovery object 81 | deviceDiscovery, err := discovery.NewDeviceDiscovery() 82 | if err != nil { 83 | return nil, err 84 | } 85 | 86 | // Create the DeviceWatcher 87 | watcher := &DeviceWatcher{ 88 | deviceDiscovery: deviceDiscovery, 89 | deviceFilter: deviceFilter, 90 | includeIntegrated: includeIntegrated, 91 | includeDetachable: includeDetachable, 92 | additionalRuntimeFiles: additionalRuntimeFiles, 93 | additionalRuntimeFilesWow64: additionalRuntimeFilesWow64, 94 | logger: logger, 95 | refresh: make(chan struct{}, 1), 96 | shutdown: make(chan struct{}), 97 | Errors: make(chan error, 1), 98 | Updates: make(chan []*discovery.Device, 1), 99 | } 100 | 101 | // Start the watcher goroutine 102 | go watcher.watchDevices() 103 | 104 | return watcher, nil 105 | } 106 | 107 | // Stops our goroutine and destroys the underlying DeviceDiscovery object 108 | func (d *DeviceWatcher) Destroy() { 109 | close(d.shutdown) 110 | close(d.refresh) 111 | } 112 | 113 | // Forces a refresh of the device list, irrespective of whether the current list is stale 114 | func (d *DeviceWatcher) ForceRefresh() { 115 | d.refresh <- struct{}{} 116 | } 117 | 118 | // Merges any additional runtime files into the list for a device 119 | func (d *DeviceWatcher) mergeRuntimeFiles(device *discovery.Device) { 120 | 121 | // Determine whether we have any additional runtime files for the device vendor 122 | files, haveFiles := d.additionalRuntimeFiles[strings.ToLower(device.Vendor)] 123 | filesWow64, haveFilesWow64 := d.additionalRuntimeFilesWow64[strings.ToLower(device.Vendor)] 124 | 125 | // Merge any additions for System32 126 | if haveFiles { 127 | ignored := device.AppendRuntimeFiles(files) 128 | for _, file := range ignored { 129 | d.logger.Infow("Ignoring additional 64-bit runtime file because it clashes with an existing filename", "file", file) 130 | } 131 | } 132 | 133 | // Merge any additions for SysWOW64 134 | if 
haveFilesWow64 { 135 | ignored := device.AppendRuntimeFilesWow64(filesWow64) 136 | for _, file := range ignored { 137 | d.logger.Infow("Ignoring additional 32-bit runtime file because it clashes with an existing filename", "file", file) 138 | } 139 | } 140 | } 141 | 142 | // Refreshes the list of devices and reports the new list 143 | func (d *DeviceWatcher) refreshDevices() error { 144 | 145 | // Refresh the list of devices 146 | if err := d.deviceDiscovery.DiscoverDevices(d.deviceFilter, d.includeIntegrated, d.includeDetachable); err != nil { 147 | return err 148 | } 149 | 150 | // Process any additional runtime files for each device 151 | for _, device := range d.deviceDiscovery.Devices { 152 | d.mergeRuntimeFiles(device) 153 | } 154 | 155 | // Report the new device list 156 | d.Updates <- d.deviceDiscovery.Devices 157 | return nil 158 | } 159 | 160 | // The main device watch loop 161 | func (d *DeviceWatcher) watchDevices() { 162 | 163 | // Destroy the underlying DeviceDiscovery object when the loop completes 164 | defer d.deviceDiscovery.Destroy() 165 | 166 | // Use a context for waiting between polling operations rather than sleeping, so we remain responsive to shutdown and refresh events 167 | sleep, cancelSleep := context.WithTimeout(context.Background(), time.Second*0) 168 | defer cancelSleep() 169 | 170 | // Continue sending device updates until shutdown is requested: 171 | forceRefresh := false 172 | for { 173 | select { 174 | 175 | case <-d.shutdown: 176 | return 177 | 178 | case <-d.refresh: 179 | forceRefresh = true 180 | cancelSleep() 181 | 182 | case <-sleep.Done(): 183 | 184 | // Poll for device list changes 185 | refresh, err := d.deviceDiscovery.IsRefreshRequired() 186 | if err != nil { 187 | d.Errors <- err 188 | return 189 | } 190 | 191 | // Retrieve the updated device list if one is available or if a forced refresh has been requested 192 | if refresh || forceRefresh { 193 | if err := d.refreshDevices(); err != nil { 194 | d.Errors <- err 195 | 
return 196 | } 197 | } 198 | 199 | // Wait 10 seconds before polling again 200 | forceRefresh = false 201 | sleep, cancelSleep = context.WithTimeout(context.Background(), time.Second*10) 202 | defer cancelSleep() 203 | } 204 | } 205 | } 206 | -------------------------------------------------------------------------------- /plugins/internal/plugin/plugin_configuration.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package plugin 4 | 5 | import ( 6 | "errors" 7 | "fmt" 8 | "io/fs" 9 | "os" 10 | "path/filepath" 11 | "strings" 12 | 13 | "github.com/spf13/viper" 14 | "github.com/tensorworks/directx-device-plugins/plugins/internal/discovery" 15 | "github.com/tensorworks/directx-device-plugins/plugins/internal/mount" 16 | "go.uber.org/zap" 17 | "golang.org/x/exp/maps" 18 | ) 19 | 20 | // PluginConfig represents the available configuration options for a device plugin 21 | type PluginConfig struct { 22 | 23 | // The number of containers that can access each device simultaneously (set this to 1 for exclusive access) 24 | Multitenancy uint32 25 | 26 | // Specifies whether we advertise integrated devices (i.e. integrated GPUs) 27 | IncludeIntegrated bool 28 | 29 | // Specifies whether we advertise detachable devices (e.g. 
external GPUs) 30 | IncludeDetachable bool 31 | 32 | // The list of additional runtime files to be mounted to System32 for each device vendor 33 | AdditionalMounts map[string][]*discovery.RuntimeFile 34 | 35 | // The list of additional runtime files to be mounted to SysWOW64 for each device vendor 36 | AdditionalMountsWow64 map[string][]*discovery.RuntimeFile 37 | } 38 | 39 | // Appends a default set of mounts to the supplied mounts, converting all vendor names to lower case to ensure consistency 40 | func appendMounts(mounts map[string][]*discovery.RuntimeFile, defaults map[string][]*discovery.RuntimeFile) map[string][]*discovery.RuntimeFile { 41 | 42 | // Gather the set of unique vendor names, converting all names to lower case 43 | vendors := make(map[string]bool) 44 | for _, vendor := range append(maps.Keys(mounts), maps.Keys(defaults)...) { 45 | vendorLower := strings.ToLower(vendor) 46 | if !vendors[vendorLower] { 47 | vendors[vendorLower] = true 48 | } 49 | } 50 | 51 | // Process the mounts for each vendor in turn 52 | appended := make(map[string][]*discovery.RuntimeFile) 53 | for vendor := range vendors { 54 | appended[vendor] = []*discovery.RuntimeFile{} 55 | 56 | // Add the mounts for the vendor if we have any 57 | vendorMounts, haveMounts := mounts[vendor] 58 | if haveMounts { 59 | appended[vendor] = append(appended[vendor], vendorMounts...) 60 | } 61 | 62 | // Add the defaults for the vendor if we have any 63 | vendorDefaults, haveDefaults := defaults[vendor] 64 | if haveDefaults { 65 | appended[vendor] = append(appended[vendor], vendorDefaults...) 66 | } 67 | } 68 | 69 | return appended 70 | } 71 | 72 | // Load loads the configuration data from the runtime environment. 
73 | func LoadConfig(pluginName string, logger *zap.SugaredLogger) (*PluginConfig, error) { 74 | 75 | // Set our default configuration values 76 | v := viper.New() 77 | v.SetDefault("multitenancy", 0) 78 | v.SetDefault("includeIntegrated", false) 79 | v.SetDefault("includeDetachable", false) 80 | v.SetDefault("additionalMounts", make(map[string][]*discovery.RuntimeFile)) 81 | v.SetDefault("additionalMountsWow64", make(map[string][]*discovery.RuntimeFile)) 82 | 83 | // The names of our environment variables reflect the plugin name 84 | envPrefix := fmt.Sprint(strings.ToUpper(pluginName), "_DEVICE_PLUGIN_") 85 | v.BindEnv("multitenancy", fmt.Sprint(envPrefix, "MULTITENANCY")) 86 | v.BindEnv("includeIntegrated", fmt.Sprint(envPrefix, "INCLUDE_INTEGRATED")) 87 | v.BindEnv("includeDetachable", fmt.Sprint(envPrefix, "INCLUDE_DETACHABLE")) 88 | 89 | // Check if a config file path was explicitly specified through an environment variable 90 | configPath, configPathExists := os.LookupEnv(fmt.Sprint(envPrefix, "CONFIG_FILE")) 91 | if configPathExists { 92 | 93 | // Verify that the specified value is an absolute path 94 | if !filepath.IsAbs(configPath) { 95 | return nil, errors.New("configuration file path must be an absolute path") 96 | } 97 | 98 | // Verify that the specified file exists 99 | if _, err := os.Stat(configPath); errors.Is(err, fs.ErrNotExist) { 100 | return nil, fmt.Errorf("specified configuration file does not exist: %s", configPath) 101 | } 102 | 103 | // Use the specified path 104 | v.SetConfigFile(configPath) 105 | 106 | } else { 107 | 108 | // The default name of our YAML configuration file reflects the plugin name 109 | v.SetConfigName(pluginName) 110 | v.SetConfigType("yaml") 111 | 112 | // We search for the configuration file in both our global config directory and the current working directory 113 | v.AddConfigPath(".") 114 | v.AddConfigPath("\\etc\\directx-device-plugins") 115 | } 116 | 117 | // Attempt to parse our YAML configuration file if it 
exists 118 | if err := v.ReadInConfig(); err != nil { 119 | if _, ok := err.(viper.ConfigFileNotFoundError); ok { 120 | logger.Infow("Configuration file not found, using configuration values from environment variables") 121 | } else { 122 | return nil, err 123 | } 124 | } 125 | 126 | // Load the parsed configuration values into our struct 127 | c := &PluginConfig{} 128 | if err := v.Unmarshal(c); err != nil { 129 | return nil, err 130 | } 131 | 132 | // Enforce a minimum value of 1 for multitenancy 133 | if c.Multitenancy == 0 { 134 | c.Multitenancy = 1 135 | } 136 | 137 | // Append our default mounts to any user-supplied values 138 | c.AdditionalMounts = appendMounts(c.AdditionalMounts, mount.DefaultMounts) 139 | c.AdditionalMountsWow64 = appendMounts(c.AdditionalMountsWow64, mount.DefaultMountsWow64) 140 | 141 | // Log the parsed configuration values 142 | logger.Infow("Parsed configuration data", "config", c) 143 | 144 | return c, nil 145 | } 146 | -------------------------------------------------------------------------------- /update-version.bat: -------------------------------------------------------------------------------- 1 | @powershell -ExecutionPolicy Bypass -File "%~dp0.\update-version.ps1" %* 2 | -------------------------------------------------------------------------------- /update-version.ps1: -------------------------------------------------------------------------------- 1 | Param ( 2 | [parameter(Mandatory=$true, Position=0, HelpMessage = "The new version string")] 3 | $version 4 | ) 5 | 6 | 7 | # Patches a file using the specified regular expression search string and replacement 8 | function Patch-File { 9 | Param ( 10 | $Path, 11 | $Search, 12 | $Replace 13 | ) 14 | 15 | Write-Host "Updating $Path" -ForegroundColor Cyan 16 | $content = Get-Content -Path $Path -Raw 17 | $content = $content -replace $Search, $Replace 18 | Set-Content -Path $Path -NoNewline -Value $content 19 | } 20 | 21 | # Patches image tags in a YAML file 22 | function 
Patch-ImageTags { 23 | Param ($Path) 24 | Patch-File -Path $Path ` 25 | -Search '"index.docker.io/tensorworks/(.+):(.+)"' ` 26 | -Replace "`"index.docker.io/tensorworks/`$1:$version`"" 27 | } 28 | 29 | 30 | # Update the version strings in all of our files, ready for a new release 31 | Write-Host "Updating version strings to $version..." -ForegroundColor Green 32 | 33 | # Update the version string for the device discovery library 34 | Patch-File -Path "$PSScriptRoot\library\src\DeviceDiscovery.cpp" ` 35 | -Search '#define LIBRARY_VERSION L"(.+)"' ` 36 | -Replace "#define LIBRARY_VERSION L`"$version`"" 37 | 38 | # Update the version string for the device plugins 39 | Patch-File -Path "$PSScriptRoot\plugins\internal\plugin\common_main.go" ` 40 | -Search 'const version = "(.+)"' ` 41 | -Replace "const version = `"$version`"" 42 | 43 | # Update the deployment YAML files 44 | foreach ($file in Get-ChildItem -Path "$PSScriptRoot\deployments\*.yml") { 45 | Patch-ImageTags -Path $file.FullName 46 | } 47 | 48 | # Update the example YAML files 49 | foreach ($file in Get-ChildItem -Path "$PSScriptRoot\examples\*\*.yml") { 50 | Patch-ImageTags -Path $file.FullName 51 | } 52 | --------------------------------------------------------------------------------