├── .clang-format ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── cmdline ├── CMakeLists.txt └── ProxyWifiCmd.cpp ├── doc ├── proxy_wifi_general_architecture.png └── proxy_wifi_service_architecture.png ├── include └── ProxyWifi │ ├── Logs.hpp │ └── ProxyWifiService.hpp ├── lib ├── CMakeLists.txt ├── ClientWlanInterface.cpp ├── ClientWlanInterface.hpp ├── Connection.cpp ├── Connection.hpp ├── Iee80211Utils.hpp ├── Logs.cpp ├── LogsHelpers.hpp ├── Messages.cpp ├── Messages.hpp ├── Networks.cpp ├── Networks.hpp ├── OperationHandler.cpp ├── OperationHandler.hpp ├── OperationHandlerBuilder.hpp ├── Protocol.hpp ├── ProxyWifiServiceImpl.cpp ├── ProxyWifiServiceImpl.hpp ├── RealWlanInterface.cpp ├── RealWlanInterface.hpp ├── SocketHelpers.cpp ├── SocketHelpers.hpp ├── TestWlanInterface.cpp ├── TestWlanInterface.hpp ├── Tracelog.hpp ├── Transport.cpp ├── Transport.hpp ├── WlanInterface.hpp ├── WlanSvcHelpers.cpp ├── WlanSvcHelpers.hpp ├── WlanSvcWrapper.cpp ├── WlanSvcWrapper.hpp └── WlansvcOperationHandler.hpp ├── test ├── CMakeLists.txt ├── TestInit.cpp ├── TestOpHandler.cpp ├── TestUtils.cpp ├── WlansvcMock.hpp └── main.cpp └── util ├── CMakeLists.txt ├── DynamicFunction.hpp ├── GuidUtils.hpp ├── StringUtils.hpp └── WorkQueue.hpp /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: LLVM 4 | AccessModifierOffset: -4 5 | AlignAfterOpenBracket: AlwaysBreak 6 | AlignConsecutiveAssignments: false 7 | AlignEscapedNewlines: DontAlign 8 | AlignOperands: true 9 | AlignTrailingComments: true 10 | AllowAllParametersOfDeclarationOnNextLine: true 11 | AllowShortBlocksOnASingleLine: false 12 | AllowShortCaseLabelsOnASingleLine: false 13 | AllowShortFunctionsOnASingleLine: None 14 | AllowShortIfStatementsOnASingleLine: false 15 | AllowShortLoopsOnASingleLine: false 16 | AlwaysBreakAfterDefinitionReturnType: None 17 | 
AlwaysBreakAfterReturnType: None 18 | AlwaysBreakBeforeMultilineStrings: true 19 | AlwaysBreakTemplateDeclarations: true 20 | BinPackArguments: false 21 | BinPackParameters: false 22 | BreakBeforeBinaryOperators: None 23 | BreakBeforeBraces: Custom 24 | BraceWrapping: 25 | AfterCaseLabel: true 26 | AfterClass: true 27 | AfterControlStatement: true 28 | AfterEnum: true 29 | AfterFunction: true 30 | AfterNamespace: false 31 | AfterStruct: true 32 | AfterUnion: true 33 | AfterExternBlock: false 34 | BeforeCatch: true 35 | BeforeElse: true 36 | SplitEmptyFunction: true 37 | SplitEmptyRecord: true 38 | SplitEmptyNamespace: true 39 | BreakBeforeTernaryOperators: true 40 | BreakConstructorInitializers: BeforeColon 41 | ColumnLimit: 130 42 | CommentPragmas: '^ IWYU pragma:' 43 | CompactNamespaces: true 44 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 45 | ConstructorInitializerIndentWidth: 4 46 | ContinuationIndentWidth: 4 47 | Cpp11BracedListStyle: true 48 | DerivePointerAlignment: false 49 | DisableFormat: false 50 | ExperimentalAutoDetectBinPacking: false 51 | FixNamespaceComments: true 52 | ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] 53 | IncludeBlocks: Regroup 54 | IncludeCategories: 55 | - Regex: '^"(stdafx.h|pch.h|precomp.h)"$' 56 | Priority: -1 57 | IndentCaseLabels: false 58 | IndentWidth: 4 59 | IndentWrappedFunctionNames: false 60 | KeepEmptyLinesAtTheStartOfBlocks: true 61 | MacroBlockBegin: '^BEGIN_COM_MAP$|^BEGIN_CONNECTION_POINT_MAP$|^BEGIN_HELPER_NODEMAP$|^BEGIN_MODULE$|^BEGIN_MSG_MAP$|^BEGIN_OBJECT_MAP$|^BEGIN_TEST_CLASS$|^BEGIN_TEST_METHOD$|^BEGIN_TEST_METHOD_PROPERTIES$' 62 | MacroBlockEnd: '^END_COM_MAP$|^END_CONNECTION_POINT_MAP$|^END_HELPER_NODEMAP$|^END_MODULE$|^END_MSG_MAP$|^END_OBJECT_MAP$|^END_TEST_CLASS$|^END_TEST_METHOD$|^END_TEST_METHOD_PROPERTIES$' 63 | MaxEmptyLinesToKeep: 1 64 | NamespaceIndentation: Inner 65 | ObjCBlockIndentWidth: 2 66 | ObjCSpaceAfterProperty: false 67 | ObjCSpaceBeforeProtocolList: true 68 | 
PenaltyBreakBeforeFirstCallParameter: 19 69 | PenaltyBreakComment: 300 70 | PenaltyBreakFirstLessLess: 120 71 | PenaltyBreakString: 1000 72 | PenaltyExcessCharacter: 1 73 | PenaltyReturnTypeOnItsOwnLine: 1000 74 | PointerAlignment: Left 75 | SortIncludes: false 76 | SpaceAfterCStyleCast: false 77 | SpaceBeforeAssignmentOperators: true 78 | SpaceBeforeParens: ControlStatements 79 | SpaceInEmptyParentheses: false 80 | SpacesBeforeTrailingComments: 1 81 | SpacesInAngles: false 82 | SpacesInContainerLiterals: true 83 | SpacesInCStyleCastParentheses: false 84 | SpacesInParentheses: false 85 | SpacesInSquareBrackets: false 86 | Standard: Cpp11 87 | StatementMacros: [ 88 | _Acquires_exclusive_lock_, 89 | _Acquires_lock_, 90 | _Acquires_nonreentrant_lock_, 91 | _Acquires_shared_lock_, 92 | _Analysis_assume_smart_lock_acquired_, 93 | _Analysis_assume_smart_lock_released_, 94 | _Create_lock_level_, 95 | _Detaches_lock_, 96 | _Function_class_, 97 | _Global_cancel_spin_lock_, 98 | _Global_critical_region_, 99 | _Global_interlock_, 100 | _Global_priority_region_, 101 | _Has_lock_kind_, 102 | _Has_lock_level_, 103 | _IRQL_always_function_max_, 104 | _IRQL_always_function_min_, 105 | _IRQL_raises_, 106 | _IRQL_requires_, 107 | _IRQL_requires_max_, 108 | _IRQL_requires_min_, 109 | _IRQL_requires_same_, 110 | _IRQL_restores_, 111 | _IRQL_restores_global_, 112 | _IRQL_saves_, 113 | _IRQL_saves_global_, 114 | _Lock_level_order_, 115 | _Moves_lock_, 116 | _Must_inspect_result_, 117 | _No_competing_thread_, 118 | _Post_same_lock_, 119 | _Post_writable_byte_size_, 120 | _Pre_satisfies_, 121 | _Releases_exclusive_lock_, 122 | _Releases_lock_, 123 | _Releases_nonreentrant_lock_, 124 | _Releases_shared_lock_, 125 | _Replaces_lock_, 126 | _Requires_exclusive_lock_held_, 127 | _Requires_lock_held_, 128 | _Requires_lock_not_held_, 129 | _Requires_no_locks_held_, 130 | _Requires_shared_lock_held_, 131 | _Ret_maybenull_, 132 | _Ret_range_, 133 | _Success_, 134 | _Swaps_locks_, 135 | 
_Use_decl_annotations_, 136 | _When_, 137 | RpcEndExcept, 138 | ] 139 | TabWidth: 4 140 | TypenameMacros: [ 141 | IFACEMETHOD, 142 | STDMETHOD, 143 | ] 144 | UseTab: Never 145 | ... 146 | 147 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.21) 2 | 3 | project(proxy-wifi LANGUAGES CXX) 4 | 5 | # Provide path for scripts 6 | list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake") 7 | 8 | # Options 9 | option(MICROSOFT_TELEMETRY "Enable Microsoft telemetry collection" OFF) 10 | 11 | include(FetchContent) 12 | 13 | # Configure WIL dependency 14 | set(WIL_BUILD_TESTS OFF CACHE INTERNAL "Turn off wil tests") 15 | set(WIL_BUILD_PACKAGING OFF CACHE INTERNAL "Turn off wil packaging") 16 | FetchContent_Declare(WIL 17 | GIT_REPOSITORY "https://github.com/microsoft/wil" 18 | GIT_TAG "f9284c19c9873664978b873b8858d7dfacc6af1e" 19 | GIT_SHALLOW OFF 20 | ) 21 | FetchContent_MakeAvailable(WIL) 22 | 23 | # Build parameters 24 | 25 | # Default to debug build if unspecified 26 | if(NOT CMAKE_BUILD_TYPE) 27 | set(CMAKE_BUILD_TYPE "Debug") 28 | endif() 29 | 30 | # Rationalize TARGET_PLATFORM (case-insensitive): arm64 if either value is arm64; otherwise x64 if each value is x64, amd64 or empty (the default); otherwise fail. Regexes are anchored on purpose: an empty alternative in an unanchored pattern (e.g. "x64|amd64|") matches any string and would make the FATAL_ERROR unreachable. 31 | if("${CMAKE_GENERATOR_PLATFORM}" MATCHES "^[Aa][Rr][Mm]64$" OR "${TARGET_PLATFORM}" MATCHES "^[Aa][Rr][Mm]64$") 32 | set(TARGET_PLATFORM "arm64") 33 | elseif("${CMAKE_GENERATOR_PLATFORM}" MATCHES "^([Xx]64|[Aa][Mm][Dd]64)?$" AND "${TARGET_PLATFORM}" MATCHES "^([Xx]64|[Aa][Mm][Dd]64)?$") 34 | set(TARGET_PLATFORM "x64") 35 | else() 36 | message(FATAL_ERROR "Unsupported platform: CMAKE_GENERATOR_PLATFORM='${CMAKE_GENERATOR_PLATFORM}', TARGET_PLATFORM='${TARGET_PLATFORM}'") 37 | endif() 38 | 39 | set(CMAKE_CXX_STANDARD 20) 40 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 41 | set(CMAKE_CXX_EXTENSIONS OFF) 42 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 43 | 44 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -DDEBUG -DDBG") 45 | set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /Zi") 46 | 47 | add_compile_options(/sdl /W4 /WX) 48 | 49 |
add_compile_definitions( 50 | UNICODE 51 | NOMINMAX 52 | WIN32_LEAN_AND_MEAN 53 | ) 54 | 55 | if(MICROSOFT_TELEMETRY) 56 | add_compile_definitions(MICROSOFT_TELEMETRY) 57 | endif() 58 | 59 | add_subdirectory(lib) 60 | add_subdirectory(util) 61 | 62 | if(PROJECT_IS_TOP_LEVEL) 63 | add_subdirectory(cmdline) 64 | 65 | include(CTest) 66 | 67 | if(BUILD_TESTING) 68 | add_subdirectory(test) 69 | endif() 70 | endif() -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Proxy_wifi 3 | 4 | This project enables proxying the Wi-Fi functionality from a Windows host machine to a VM. 5 | 6 | It allows the VM to: 7 | - schedule scans and access the Wi-Fi networks visible to the host 8 | - request connections and disconnections to WPA2PSK and Open networks that will be honored by the host 9 | - request connections to any Wi-Fi network the host is currently connected to ("connection mirroring") 10 | - be notified in case of disconnection and signal quality changes 11 | 12 | Networks can also be emulated by proxy_wifi, letting a VM connect to a fake Wi-Fi network. 13 | 14 | ## VM driver 15 | 16 | This repository contains only the Windows host component. A driver must run on the VM and communicate with the host component. 17 | Such a driver has currently been implemented for Linux only, under the name of `proxy_wifi`. 18 | 19 | ## Building 20 | 21 | Requirements: 22 | 23 | * cmake >= 3.21 (`choco install -u cmake`) 24 | * [Windows SDK 19041+](https://developer.microsoft.com/en-us/windows/downloads/windows-10-sdk/) 25 | * [Build Tools for Visual Studio 2019+](https://visualstudio.microsoft.com/downloads/#build-tools-for-visual-studio-2019) 26 | 27 | The project generates an MSBuild project and solution for each target.
These can be used directly by MSBuild tools, Microsoft Visual Studio, or with CMake directly. There are two targets defined: 28 | 29 | 1) proxy-wifi: This is a static library containing the core logic. 30 | 2) proxy-wifi-cmdline: This is a command-line utility for testing proxy_wifi in basic scenarios. 31 | 32 | ### Step 1: Configure the project 33 | 34 | In the project root: 35 | 36 | ```Shell 37 | md build && cd build 38 | cmake .. 39 | ``` 40 | 41 | Note that the build type is set to `Debug` by default. Make sure to set the `CMAKE_BUILD_TYPE` variable appropriately for the desired build flavor. For example, to create a release configuration: 42 | 43 | ```Shell 44 | cmake -DCMAKE_BUILD_TYPE=RelWithDebInfo .. 45 | ``` 46 | 47 | ### Step 2: Build the project 48 | 49 | #### With CMake 50 | 51 | From the `build` directory: 52 | 53 | ```Shell 54 | cmake --build . 55 | ``` 56 | 57 | #### With Visual Studio 58 | 59 | Open `proxy-wifi.sln` which contains a project for each build target: 60 | 61 | * `proxy-wifi` 62 | * `proxy-wifi-cmdline` 63 | * `ALL_BUILD` 64 | * `INSTALL` 65 | 66 | #### With Visual Studio Build Tools 67 | 68 | Point `msbuild.exe` to the project to build: 69 | 70 | ```Shell 71 | msbuild.exe proxy-wifi.vcxproj 72 | ``` 73 | 74 | ### Step 3: Run unit tests 75 | 76 | #### Using the test executable 77 | 78 | ```Shell 79 | .\test\Debug\proxy-wifi-test.exe 80 | ``` 81 | 82 | #### With CTest 83 | 84 | ```Shell 85 | ctest -T test 86 | ``` 87 | 88 | ## Basic usage 89 | 90 | The public interface of Proxy_wifi is present in [`ProxyWifiService.hpp`](include/ProxyWifi/ProxyWifiService.hpp). It lets clients build a `ProxyWifiService` object through the function `BuildProxyWifiService`. 91 | This function allows configuring the transport layer and registering for Proxy_wifi notifications. 92 | 93 | You can additionally set up a logger to choose where logs will be output.
92 | 93 | ```cpp 94 | // Tell the ProxyWifiService to output logs to a tracelogging provider 95 | Log::AddLogger(std::make_unique()); 96 | 97 | class MyObserver: public ProxyWifiObserver 98 | { 99 | Authorization AuthorizeGuestConnectionRequest(OperationType type, const ConnectRequestArgs& connectionInfo) noexcept override 100 | { 101 | std::cout << "The guest requested a connection" << std::endl; 102 | return Authorization::Approve; 103 | } 104 | }; 105 | 106 | // Must be kept alive until the proxy is destroyed 107 | auto observer = MyObserver{}; 108 | 109 | // Callback providing a list of networks that will be simulated by the Wi-Fi proxy 110 | auto provideFakeNetworks = [this]() -> std::vector { return {}; /* Return simulated network SSID + BSSID */ }; 111 | 112 | auto wifiProxy = BuildProxyWifiService( 113 | ProxyWifiHyperVSettings{vmId}, 114 | std::move(provideFakeNetworks), 115 | &observer); 116 | wifiProxy->Start(); 117 | ``` 118 | 119 | ## Architecture overview 120 | 121 | Proxy_wifi consists of a service component (proxy_wifi service) meant to 122 | run on a Windows host, and a driver component (proxy_wifi driver) meant to run 123 | on a guest virtual machine. This repository contains the implementation of 124 | proxy_wifi service for Windows. Currently, an implementation of proxy_wifi 125 | driver exists only for Linux. 126 | 127 | ![General architecture](./doc/proxy_wifi_general_architecture.png) 128 | 129 | The proxy_wifi driver's role is to intercept Wi-Fi control path related operations 130 | from the guest and forward them to proxy_wifi service. Communication between 131 | proxy_wifi driver and proxy_wifi service relies on two Hyper-V sockets (hvsockets): one allows 132 | proxy_wifi driver to make a request to proxy_wifi service and wait for the 133 | answer; the other one allows proxy_wifi service to send spontaneous 134 | notifications to proxy_wifi driver.
135 | 136 | Proxy_wifi service is a library; its code will be running inside your 137 | application process. It uses the Windows Wlan API to perform any Wi-Fi 138 | related operation on the host. It can also emulate operations on fake 139 | networks if requested by the client application. 140 | 141 | Proxy_wifi service presents an observer interface (`ProxyWifiObserver`) that 142 | the client application should implement to be notified when an operation is requested 143 | by the guest or a network change is detected on the host. It is expected that the 144 | client application will monitor these events to synchronize proxy_wifi with 145 | other components virtualizing the data path or L3 layer properties. 146 | 147 | ## Proxy_wifi service architecture 148 | 149 | ![Proxy_wifi service architecture](./doc/proxy_wifi_service_architecture.png) 150 | 151 | Proxy_wifi is composed of two main layers: the transport 152 | ([Transport.hpp](lib/Transport.hpp)) and the operation handler 153 | ([OperationHandler.hpp](lib/OperationHandler.hpp)). 154 | 155 | The concrete implementations of `ProxyWifiService` initialize these two layers 156 | and expose a clean API to client applications. 157 | 158 | The transport is responsible for communications with the proxy_wifi driver on 159 | the guest VM. The operation handler is responsible for honoring these requests 160 | and notifying the client application. 161 | 162 | ### Transport layer 163 | 164 | The transport layer communicates with the guest proxy_wifi driver over two channels: 165 | 166 | - a request channel: it handles commands from the proxy_wifi driver and their 167 | responses, such as connection or scan requests. 168 | - a notification channel: it handles spontaneous notifications from the host, such 169 | as network disconnection, signal quality changes and scanned BSS updates.
170 | 171 | The transport layer deals with `Message`s ([Messages.hpp](lib/Messages.hpp)) and isn't aware of their content or structure beyond the header + body layout. 172 | 173 | Messages coming over the request channel are handled as synchronous transactions: once a request is accepted in `Transport::AcceptConnections` ([Transport.hpp](lib/Transport.hpp)), no other request will be handled until the transaction is completed by sending an answer. 174 | 175 | On the other hand, messages sent over the notification channel are handled in a fire-and-forget fashion: the transport sends notifications asynchronously through a queue and doesn't expect any answer from the proxy_wifi driver. 176 | 177 | Two implementations of the transport layer are available: over HV sockets or over TCP sockets. 178 | Only the HV socket implementation is supported by the current proxy_wifi driver on the guest. The TCP implementation is mostly available for testing. 179 | 180 | ### The operation handler 181 | 182 | The operation handler manages a list of interfaces. It dispatches requests 183 | received by the transport and aggregates the results produced by each interface. It 184 | is also responsible for sending notifications both to the proxy_wifi driver on the guest 185 | and to the client application. 186 | 187 | The operation handler uses typed requests and responses, parsed from the 188 | `Messages` used by the transport ([`Messages.hpp`](lib/Messages.hpp)). 189 | Interface management, request processing and notifications targeting the 190 | guest are serialized through a work queue (`m_serializedRunner`). 191 | 192 | Another work queue (`m_clientNotificationQueue`) serializes notifications sent 193 | to the client application. 194 | 195 | The operation handler maintains a list of `WlanInterface` 196 | ([WlanInterface.hpp](lib/WlanInterface.hpp)) that represent actual Wi-Fi 197 | interfaces on the host or fake interfaces to simulate networks.
198 | 199 | Requests are dispatched to interfaces, either serialized or running 200 | concurrently depending on the operation type. 201 | 202 | In particular: 203 | 204 | - "Scan" operations are run on all available interfaces concurrently, and the scan results are aggregated 205 | - "Connect" operations are run successively on each interface until one succeeds 206 | - "Disconnect" operations target only the previously connected interface 207 | 208 | Each interface handles the command depending on its concrete implementation, 209 | using simulated data or forwarding the operation through the host Wlan API. 210 | 211 | #### Simulated interface 212 | 213 | If the client application provides a `FakeNetworkProvider` callback when instantiating proxy_wifi service, a simulated (or "fake") interface will be created. It allows simulating Wi-Fi networks in the guest. 214 | 215 | At any point in time, the simulated interface considers that all networks returned by the `FakeNetworkProvider` callback are visible and connected. Connection and disconnection requests to these networks will always succeed as long as they are returned by the callback, and they will correspond to no-ops. 216 | 217 | The simulated interface is always considered first when dispatching operations, effectively giving it priority. 218 | 219 | #### Real interfaces 220 | 221 | Real interfaces correspond to the actual wlan STA interfaces on the host. 222 | Proxy_wifi service adds and removes these interfaces dynamically following their 223 | state on the system. 224 | 225 | Operations on real interfaces are implemented through calls to the Win32 Wlan 226 | APIs. However, the operations requested from the guest are not mapped directly 227 | to the equivalent host operation, since those operations can be time-consuming 228 | or impact the host network connectivity.
229 | 230 | In particular: 231 | 232 | - a host interface will be connected at the guest's request only if no other 233 | interface was already connected to the requested network. 234 | - a host interface will be disconnected at the guest's request only if it was 235 | connected because of the guest. 236 | - a host interface will always schedule a scan on the host at the guest's 237 | request, but will return cached results immediately, unless a previous scan 238 | request still hasn't completed. Whenever a scan completes on the host, the 239 | visible networks are spontaneously sent to the guest to limit the number of 240 | scans necessary. 241 | 242 | ## Capabilities and limitations 243 | 244 | Proxy_wifi currently fully supports connections from the guest to WPA2PSK and Open networks only. 245 | This is due to Linux allowing only these two authentication algorithms to be offloaded to the driver. 246 | 247 | In addition, Proxy_wifi allows the guest VM to "mirror" a connection on the host, for any network type. 248 | Since the host is already connected, no authentication is actually needed, which makes the operation straightforward. 249 | However, in order to allow this, the authentication algorithm for any non-supported network is faked in scan results and shown as WPA2-PSK. 250 | 251 | ## Contributing 252 | 253 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 254 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 255 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 256 | 257 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 258 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 259 | provided by the bot. You will only need to do this once across all repos using our CLA.
260 | 261 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 262 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 263 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 264 | 265 | ## Data Collection 266 | 267 | The software may collect information about you and your use of the software and send it to Microsoft. Microsoft may use this information to provide services and improve our products and services. You may turn off the telemetry by setting the CMake cache variable `MICROSOFT_TELEMETRY` to `OFF` (which is also its default value): 268 | 269 | ```Shell 270 | cmake -DMICROSOFT_TELEMETRY=OFF .. 271 | ``` 272 | 273 | There are also some features in the software that may enable you and Microsoft to collect data from users of your applications. If you use these features, you must comply with applicable law, including providing appropriate notices to users of your applications together with a copy of Microsoft’s privacy statement. Our privacy statement is located at https://go.microsoft.com/fwlink/?LinkID=824704. You can learn more about data collection and use in the help documentation and our privacy statement. Your use of the software operates as your consent to these practices. 274 | 275 | ## Trademarks 276 | 277 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 278 | trademarks or logos is subject to and must follow 279 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 280 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 281 | Any use of third-party trademarks or logos are subject to those third-party's policies.
-------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | # Security 3 | 4 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 5 | 6 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 7 | 8 | ## Reporting Security Issues 9 | 10 | **Please do not report security vulnerabilities through public GitHub issues.** 11 | 12 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 13 | 14 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 15 | 16 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 17 | 18 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 19 | 20 | * Type of issue (e.g. 
buffer overflow, SQL injection, cross-site scripting, etc.) 21 | * Full paths of source file(s) related to the manifestation of the issue 22 | * The location of the affected source code (tag/branch/commit or direct URL) 23 | * Any special configuration required to reproduce the issue 24 | * Step-by-step instructions to reproduce the issue 25 | * Proof-of-concept or exploit code (if possible) 26 | * Impact of the issue, including how an attacker might exploit the issue 27 | 28 | This information will help us triage your report more quickly. 29 | 30 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 31 | 32 | ## Preferred Languages 33 | 34 | We prefer all communications to be in English. 35 | 36 | ## Policy 37 | 38 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 39 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | 2 | # Support 3 | 4 | ## How to file issues and get help 5 | 6 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 7 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 8 | feature request as a new Issue. 9 | 10 | ## Microsoft Support Policy 11 | 12 | Support for this project is limited to the resources listed above. 13 | -------------------------------------------------------------------------------- /cmdline/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | add_executable(proxy-wifi-cmdline "") 5 | 6 | target_sources(proxy-wifi-cmdline 7 | PRIVATE 8 | ProxyWifiCmd.cpp 9 | ) 10 | 11 | target_link_libraries(proxy-wifi-cmdline 12 | PRIVATE 13 | rpcrt4.lib 14 | proxy-wifi 15 | proxy-wifi-util 16 | WIL 17 | ) 18 | -------------------------------------------------------------------------------- /cmdline/ProxyWifiCmd.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "ProxyWifi/Logs.hpp" 5 | #include "ProxyWifi/ProxyWifiService.hpp" 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | 19 | using namespace ProxyWifi; 20 | /* Prints the supported command-line options to std::cout. Note: the text does not end with a newline. */ 21 | static void print_help() 22 | { 23 | std::cout << "\nCommand line arguments:\n\n" 24 | "-vmid \n" 25 | "-p \n" 26 | "-n \n" 27 | "-k (Use hv_sockets - default)\n" 28 | "-l \n" 29 | "-f (Fake wlansvc - for tests)\n" 30 | "-tlog (Log to tracelogging in addition to console output)\n" 31 | "-h (Print this help)"; 32 | } 33 | /* Transport used to reach the guest. ProxyWifiConfig defaults to HyperVSocket; the "-l" option switches to TcpSocket. */ 34 | enum class TransportType 35 | { 36 | HyperVSocket, 37 | TcpSocket, 38 | }; 39 | /* Returns a wide-string display name for a TransportType; throws std::invalid_argument for a value that is not one of the enumerators. */ 40 | std::wstring ToString(TransportType e) 41 | { 42 | switch (e) 43 | { 44 | case TransportType::HyperVSocket: 45 | return L"Hyper-V"; 46 | case TransportType::TcpSocket: 47 | return L"Tcp"; 48 | default: 49 | throw std::invalid_argument("Invalid enum value"); 50 | } 51 | } 52 | /* Settings collected from the command line. Defaults: all-zero VmId, no listen IP, default request/response and notification ports, normal operation mode, Hyper-V transport, tracelogging disabled. */ 53 | struct ProxyWifiConfig 54 | { 55 | GUID VmId{}; 56 | std::optional ListenIp; 57 | unsigned short RequestResponsePort = RequestResponsePortDefault; 58 | unsigned short NotificationPort = NotificationPortDefault; 59 | OperationMode Mode = OperationMode::Normal; 60 | TransportType Transport = TransportType::HyperVSocket; 61 | bool UseTracelogging = false; 62 | }; 63 | 64 | std::optional CreateProxyConfigFromArguments(int argc, const char* argv[]) 65 | { 66 | ProxyWifiConfig manager{}; 67 | 68 |
auto i = 1; 69 | while (i < argc) 70 | { 71 | const std::string_view option(argv[i]); 72 | if (option == "-h") 73 | { 74 | print_help(); 75 | return std::nullopt; 76 | } 77 | else if (option == "-p") 78 | { 79 | std::string_view value(argv[++i]); 80 | const auto res = std::from_chars(value.data(), value.data() + value.size(), manager.RequestResponsePort); 81 | if (res.ec == std::errc::invalid_argument) 82 | { 83 | throw std::invalid_argument("Invalid port number"); 84 | } 85 | } 86 | else if (option == "-n") 87 | { 88 | std::string_view value(argv[++i]); 89 | const auto res = std::from_chars(value.data(), value.data() + value.size(), manager.NotificationPort); 90 | if (res.ec == std::errc::invalid_argument) 91 | { 92 | throw std::invalid_argument("Invalid port number"); 93 | } 94 | } 95 | else if (option == "-k") 96 | { 97 | manager.Transport = TransportType::HyperVSocket; 98 | } 99 | else if (option == "-f") 100 | { 101 | manager.Mode = OperationMode::Simulated; 102 | } 103 | else if (option == "-l") 104 | { 105 | manager.ListenIp = argv[++i]; 106 | manager.Transport = TransportType::TcpSocket; 107 | } 108 | else if (option == "-vmid") 109 | { 110 | GUID guestVmId; 111 | const char* guestVmIdStr = argv[++i]; 112 | if (UuidFromStringA((unsigned char*)guestVmIdStr, &guestVmId) != RPC_S_OK) 113 | { 114 | throw std::invalid_argument("Invalid VM Guid"); 115 | } 116 | manager.VmId = guestVmId; 117 | } 118 | else if (option == "-tlog") 119 | { 120 | manager.UseTracelogging = true; 121 | } 122 | else 123 | { 124 | throw std::invalid_argument("Invalid argument " + std::string(argv[i])); 125 | } 126 | 127 | i++; 128 | } 129 | 130 | switch (manager.Transport) 131 | { 132 | case TransportType::TcpSocket: 133 | if (!manager.ListenIp) 134 | { 135 | throw std::invalid_argument("An ip address is required when using tcp transport"); 136 | } 137 | break; 138 | case TransportType::HyperVSocket: 139 | if (manager.ListenIp) 140 | { 141 | throw std::invalid_argument("An ip address not 
not be provided when using HyperV socket transport"); 142 | } 143 | break; 144 | default: 145 | throw std::runtime_error("unsupported proxy transport specified"); 146 | } 147 | 148 | if (manager.VmId == GUID{}) 149 | { 150 | throw std::invalid_argument("A VM GUID is required."); 151 | } 152 | 153 | return manager; 154 | } 155 | 156 | std::unique_ptr BuildProxyWifiService(const ProxyWifiConfig& settings) 157 | { 158 | wil::unique_rpc_wstr vmStr; 159 | THROW_IF_FAILED(UuidToString(&settings.VmId, &vmStr)); 160 | 161 | Log::Info(L"Creating %ws Wi-Fi Proxy for VM id=%ws, request port=%u, notification port=%u", ToString(settings.Transport).c_str(), vmStr.get(), settings.RequestResponsePort, settings.NotificationPort); 162 | 163 | switch (settings.Transport) 164 | { 165 | case TransportType::HyperVSocket: 166 | { 167 | const ProxyWifiHyperVSettings proxySettings{settings.VmId, settings.RequestResponsePort, settings.NotificationPort, settings.Mode}; 168 | return BuildProxyWifiService(proxySettings, {}); 169 | } 170 | case TransportType::TcpSocket: 171 | { 172 | const ProxyWifiTcpSettings proxySettings{settings.ListenIp.value(), settings.RequestResponsePort, settings.NotificationPort, settings.Mode}; 173 | return BuildProxyWifiService(proxySettings, {}); 174 | } 175 | default: 176 | throw std::runtime_error("unsupported proxy protocol transport selected"); 177 | } 178 | } 179 | 180 | int main(int argc, const char* argv[]) 181 | try 182 | { 183 | Log::AddLogger(std::make_unique()); 184 | 185 | // Redirect WIL failures as logs 186 | wil::SetResultLoggingCallback([](const wil::FailureInfo& failure) noexcept { 187 | constexpr std::size_t sizeOfLogMessageWithNul = 2048; 188 | 189 | wchar_t logMessage[sizeOfLogMessageWithNul]{}; 190 | wil::GetFailureLogString(logMessage, sizeOfLogMessageWithNul, failure); 191 | Log::WilFailure(logMessage); 192 | }); 193 | 194 | // Initialize winsock. 
195 | { 196 | WSADATA wsaData; 197 | const auto wsError = WSAStartup(MAKEWORD(2, 2), &wsaData); 198 | if (wsError != 0) 199 | { 200 | LOG_WIN32_MSG(wsError, "WSAStartup failed"); 201 | return -1; 202 | } 203 | } 204 | 205 | auto wsaCleanupOnExit = wil::scope_exit([&] { WSACleanup(); }); 206 | 207 | std::optional proxyConfig; 208 | try 209 | { 210 | proxyConfig = CreateProxyConfigFromArguments(argc, argv); 211 | } 212 | catch (const std::invalid_argument& e) 213 | { 214 | std::cerr << "Invalid parameter: " << e.what() << std::endl; 215 | print_help(); 216 | return -1; 217 | } 218 | 219 | if (!proxyConfig) 220 | { 221 | return 0; 222 | } 223 | 224 | if (proxyConfig->UseTracelogging) 225 | { 226 | Log::AddLogger(std::make_unique()); 227 | } 228 | 229 | const auto proxyService = BuildProxyWifiService(*proxyConfig); 230 | proxyService->Start(); 231 | 232 | // Sleep until the program is interrupted with Ctrl-C 233 | Sleep(INFINITE); 234 | 235 | return 0; 236 | } 237 | catch (const wil::ResultException& ex) 238 | { 239 | std::wcerr << "Caught unhandled exception: " << ex.GetErrorCode(); 240 | if (ex.GetFailureInfo().pszMessage) 241 | { 242 | std::wcerr << ", " << ex.GetFailureInfo().pszMessage; 243 | } 244 | std::wcerr << std::endl; 245 | 246 | return -1; 247 | } 248 | catch (const std::exception& ex) 249 | { 250 | std::cerr << "Caught unhandled exception: " << ex.what() << std::endl; 251 | return -1; 252 | } 253 | catch (...) 
254 | { 255 | std::cerr << "Caught unhandled exception" << std::endl; 256 | FAIL_FAST_MSG("FATAL: UNHANDLED EXCEPTION"); 257 | } 258 | -------------------------------------------------------------------------------- /doc/proxy_wifi_general_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/proxy_wifi/2c97e26ea41defef8834359345811e25a75e743f/doc/proxy_wifi_general_architecture.png -------------------------------------------------------------------------------- /doc/proxy_wifi_service_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/proxy_wifi/2c97e26ea41defef8834359345811e25a75e743f/doc/proxy_wifi_service_architecture.png -------------------------------------------------------------------------------- /include/ProxyWifi/Logs.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | namespace ProxyWifi::Log { 9 | 10 | enum class Level 11 | { 12 | Error, 13 | Info, 14 | Trace, 15 | Debug 16 | }; 17 | 18 | /// @brief Interface to implement for defining a new log output target 19 | class Logger 20 | { 21 | public: 22 | virtual void Log(Level level, const wchar_t* message) noexcept = 0; 23 | virtual ~Logger() = default; 24 | }; 25 | 26 | /// @brief Logger printing to the standard console output 27 | class ConsoleLogger : public Logger 28 | { 29 | public: 30 | void Log(Level level, const wchar_t* message) noexcept override; 31 | }; 32 | 33 | /// @brief Logger generating Tracelogging events 34 | class TraceLoggingLogger : public Logger 35 | { 36 | public: 37 | void Log(Level level, const wchar_t* message) noexcept override; 38 | }; 39 | 40 | namespace Details { 41 | 42 | /// @brief `LogManager` formats log messages and dispatches them to a list of `Logger`s 43 | /// 44 | /// It is built as a singleton and should be accessed through a set of helper functions 45 | class LogManager 46 | { 47 | public: 48 | LogManager(const LogManager&) = delete; 49 | LogManager(LogManager&&) = delete; 50 | LogManager& operator=(const LogManager&) = delete; 51 | LogManager& operator=(LogManager&&) = delete; 52 | 53 | static LogManager& Get() noexcept 54 | { 55 | static LogManager defaultLogger{}; 56 | return defaultLogger; 57 | } 58 | 59 | void AddLogger(std::unique_ptr logger) 60 | { 61 | m_loggers.emplace_back(std::move(logger)); 62 | } 63 | 64 | template 65 | inline void Log(Level level, const wchar_t* format, T&&... args) const noexcept 66 | { 67 | // Ensure only fundamental types or pointers to fundamental types are used. 68 | // Because of the indirection, the compiler doesn't check the format; 69 | // this at least checks that parameters have been converted to a fundamental type or C-string.
70 | static_assert(all(std::is_fundamental_v>>...)); 71 | 72 | wchar_t message[2048]{}; 73 | swprintf_s(message, format, std::forward(args)...); 74 | 75 | for (auto& logger : m_loggers) 76 | { 77 | logger->Log(level, message); 78 | } 79 | } 80 | 81 | private: 82 | std::vector> m_loggers; 83 | LogManager() = default; 84 | 85 | template 86 | static constexpr bool all(Args... args) { return (... && args); } 87 | }; 88 | 89 | } // namespace Details 90 | 91 | inline void AddLogger(std::unique_ptr logger) 92 | { 93 | Details::LogManager::Get().AddLogger(std::move(logger)); 94 | } 95 | 96 | /// @brief Handler for WIL reported failures 97 | /// There is intentionally no function for "Error" level logs: they should be reported through WIL error handlers 98 | inline void WilFailure(const wchar_t* message) 99 | { 100 | Details::LogManager::Get().Log(Level::Error, L"%ws", message); 101 | } 102 | 103 | template 104 | inline void Info(const wchar_t* format, T&&... args) 105 | { 106 | Details::LogManager::Get().Log(Level::Info, format, std::forward(args)...); 107 | } 108 | 109 | template 110 | inline void Trace(const wchar_t* format, T&&... args) 111 | { 112 | Details::LogManager::Get().Log(Level::Trace, format, std::forward(args)...); 113 | } 114 | 115 | template 116 | inline void Debug(const wchar_t* format, T&&... args) 117 | { 118 | Details::LogManager::Get().Log(Level::Debug, format, std::forward(args)...); 119 | } 120 | } // namespace ProxyWifi::Log -------------------------------------------------------------------------------- /include/ProxyWifi/ProxyWifiService.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | namespace ProxyWifi { 15 | 16 | /// @brief The mode controlling how the proxy operates.
17 | enum class OperationMode 18 | { 19 | /// @brief Normal (proxied) mode. 20 | /// 21 | /// The proxy provides real results using available hardware on the host. If 22 | /// no suitable hardware is available, most operations are likely to fail. 23 | Normal, 24 | 25 | /// @brief Simulated mode. 26 | /// 27 | /// The proxy simulates results, not using hardware on the host, even if it 28 | /// is present and available. 29 | Simulated, 30 | }; 31 | 32 | /// @brief Represents a host Wi-Fi proxy service. 33 | class ProxyWifiService 34 | { 35 | public: 36 | virtual ~ProxyWifiService() = default; 37 | 38 | /// @brief Start the proxy. 39 | /// 40 | /// This creates the transport and begins accepting connections on it. Until 41 | /// Start() is called, the proxy is inactive and does not accept connections. 42 | virtual void Start() = 0; 43 | 44 | /// @brief Stop the proxy. 45 | /// 46 | /// Stop the proxy if it is started. This will sever all existing connections 47 | /// to the proxy and destroy the transport. Following execution of this call, 48 | /// the proxy will no longer accept new connections. It may be restarted 49 | /// using Start(). 
50 | virtual void Stop() = 0; 51 | }; 52 | 53 | /// @brief Indicate the status of a guest-initiated operation 54 | enum class OperationStatus 55 | { 56 | Succeeded, ///< The operation was completed successfully and success will be indicated to the guest 57 | Failed, ///< The operation failed and failure will be indicated to the guest 58 | Denied ///< The operation was denied by the client and failure will be indicated to the guest 59 | }; 60 | 61 | /// @brief List basic information about a Wi-Fi network 62 | struct WifiNetworkInfo 63 | { 64 | WifiNetworkInfo() = default; 65 | WifiNetworkInfo(const DOT11_SSID& ssid, const DOT11_MAC_ADDRESS& bssid); 66 | 67 | DOT11_SSID ssid = {}; 68 | DOT11_MAC_ADDRESS bssid = {}; 69 | }; 70 | 71 | /// @brief Indicate the impact a guest-requested operation will have on the host 72 | enum class OperationType 73 | { 74 | GuestDirected, ///< The guest is directing this operation, and the host state will change to accommodate it 75 | HostMirroring ///< The guest was only replicating the state of the host, the host state won't change as a result of this request 76 | }; 77 | 78 | /// @brief Observer class that gets notified on host or guest events 79 | /// Clients should inherit from it and override methods to handle notifications 80 | class ProxyWifiObserver 81 | { 82 | public: 83 | virtual ~ProxyWifiObserver() = default; 84 | 85 | struct ConnectRequestArgs 86 | { 87 | DOT11_SSID ssid; 88 | }; 89 | 90 | struct ConnectCompleteArgs 91 | { 92 | GUID interfaceGuid; 93 | DOT11_SSID ssid; 94 | DOT11_AUTH_ALGORITHM authAlgo; 95 | }; 96 | 97 | struct DisconnectRequestArgs 98 | { 99 | DOT11_SSID ssid; 100 | }; 101 | 102 | struct DisconnectCompleteArgs 103 | { 104 | GUID interfaceGuid; 105 | DOT11_SSID ssid; 106 | }; 107 | 108 | /// @brief Indicate whether proxy_wifi should proceed with a guest request or answer immediately with a failure 109 | enum class Authorization 110 | { 111 | Approve, 112 | Deny 113 | }; 114 | 115 | /// @brief A host WiFi
interface connected to a network 116 | virtual void OnHostConnection(const ConnectCompleteArgs& /* connectionInfo */) noexcept 117 | { 118 | } 119 | 120 | /// @brief A host WiFi interface disconnected from a network 121 | virtual void OnHostDisconnection(const DisconnectCompleteArgs& /* disconnectionInfo */) noexcept 122 | { 123 | } 124 | 125 | /// @brief The guest requested a connection to a network 126 | /// @return Return `Authorization::Approve` to let the connection proceed, return `Authorization::Deny` to answer the guest 127 | /// with a failure. If `type == OperationType::HostMirroring`, a host interface is already connected to the network, otherwise, 128 | /// one will be connected. The connection won't proceed until the callback returns. 129 | virtual Authorization AuthorizeGuestConnectionRequest(OperationType /* type */, const ConnectRequestArgs& /* connectionInfo */) noexcept 130 | { 131 | return Authorization::Approve; 132 | } 133 | 134 | /// @brief A guest connection request was processed 135 | /// If `type == OperationType::HostMirroring`, a host interface was already connected to the network, otherwise, one has 136 | /// been connected. The response won't be sent to the guest until this callback returns. 137 | virtual void OnGuestConnectionCompletion(OperationType /* type */, OperationStatus /* status */, const ConnectCompleteArgs& /* connectionInfo */) noexcept 138 | { 139 | } 140 | 141 | /// @brief The guest requested a disconnection from the connected network 142 | /// If `type == OperationType::HostMirroring`, the host won't be impacted, otherwise, a matching host interface will be 143 | /// disconnected. The disconnection won't proceed until the callback returns. 144 | virtual void OnGuestDisconnectionRequest(OperationType /* type */, const DisconnectRequestArgs& /* connectionInfo */) noexcept 145 | { 146 | } 147 | 148 | /// @brief A guest disconnection request was processed 149 | /// If `type == OperationType::HostMirroring`, this was a no-op
for the host, otherwise, a matching host interface has been 150 | /// disconnected. The response won't be sent to the guest until this callback returns. 151 | virtual void OnGuestDisconnectionCompletion(OperationType /* type */, OperationStatus /* status */, const DisconnectCompleteArgs& /* disconnectionInfo */) noexcept 152 | { 153 | } 154 | 155 | /// @brief The guest requested a scan 156 | /// The scan won't start on the host until this callback returns 157 | virtual void OnGuestScanRequest() noexcept 158 | { 159 | } 160 | 161 | /// @brief A guest scan request was processed 162 | /// The scan results won't be sent to the guest until this callback returns 163 | virtual void OnGuestScanCompletion(OperationStatus /* status */) noexcept 164 | { 165 | } 166 | }; 167 | 168 | /// @brief Type of the callback providing a list of networks that will be simulated by the Wi-Fi proxy 169 | /// They will be shown as WPA2-PSK networks, and are considered as permanently connected for the purpose of notifications 170 | using FakeNetworkProvider = std::function()>; 171 | 172 | /// @brief Guid used in notifications concerning the provided fake networks 173 | /// 1b57e649-a1df-482f-85c2-a16063836418 174 | constexpr GUID FakeInterfaceGuid{0x1b57e649, 0xa1df, 0x482f, {0x85, 0xc2, 0xa, 0x6, 0x6, 0x8, 0x6, 0x18}}; 175 | 176 | /// @brief Default request/response port used for both HyperV and TCP based Wi-Fi 177 | /// proxies if none is explicitly specified. 178 | constexpr unsigned short RequestResponsePortDefault = 12345; 179 | 180 | /// @brief Default notification port used for both HyperV and TCP based Wi-Fi 181 | /// proxies if none is explicitly specified. 182 | constexpr unsigned short NotificationPortDefault = 12346; 183 | 184 | /// @brief Settings controlling a HyperV based Wi-Fi proxy.
185 | struct ProxyWifiHyperVSettings 186 | { 187 | /// @brief Construct a setting object to configure a new Wifi Proxy using an Hyper V transport 188 | /// @param guestVmId The vm id of the HyperV container guest from which to allow connections. 189 | /// @param requestResponsePort The HyperV socket port number for the request/response communication channel. 190 | /// @param notificationPort The HyperV socket port number for the notification communication channel. 191 | /// @param mode The mode of operation used to emulate or virtualize Wifi 192 | ProxyWifiHyperVSettings(const GUID& guestVmId, unsigned short requestResponsePort, unsigned short notificationPort, OperationMode mode); 193 | 194 | /// @brief Construct a setting object to configure a new Wifi Proxy using an Hyper V transport 195 | /// @param guestVmId The vm id of the HyperV container guest from which to allow connections 196 | explicit ProxyWifiHyperVSettings(const GUID& guestVmId); 197 | 198 | /// @brief The HyperV socket port number for the request/response communication channel. 199 | unsigned short RequestResponsePort = RequestResponsePortDefault; 200 | 201 | /// @brief The HyperV socket port number for the notification communication channel. 202 | unsigned short NotificationPort = NotificationPortDefault; 203 | 204 | /// @brief The vm id of the HyperV container guest from which to allow connections. 205 | const GUID GuestVmId{}; 206 | 207 | /// @brief The initial mode for the proxy 208 | const OperationMode ProxyMode = OperationMode::Normal; 209 | }; 210 | 211 | /// @brief Settings controlling a TCP based Wi-Fi proxy. 212 | struct ProxyWifiTcpSettings 213 | { 214 | /// @brief Construct a setting object to configure a new Wifi Proxy using a Tcp transport 215 | /// @param listenIp The TCP/IP address for the proxy to listen for connection. 216 | /// @param requestResponsePort The TCP/IP port number for the request/response communication channel. 
217 | /// @param notificationPort The TCP/IP port number for the notification communication channel. 218 | /// @param mode The mode of operation used to emulate or virtualize Wifi 219 | ProxyWifiTcpSettings(std::string listenIp, unsigned short requestResponsePort, unsigned short notificationPort, OperationMode mode); 220 | 221 | /// @brief Construct a setting object to configure a new Wifi Proxy using a Tcp transport 222 | /// @param listenIp The TCP/IP address for the proxy to listen for connection. 223 | explicit ProxyWifiTcpSettings(std::string listenIp); 224 | 225 | /// @brief The TCP/IP port number for the request/response communication channel. 226 | unsigned short RequestResponsePort = RequestResponsePortDefault; 227 | 228 | /// @brief The TCP/IP port number for the notification communication channel. 229 | unsigned short NotificationPort = NotificationPortDefault; 230 | 231 | /// @brief The TCP/IP address for the proxy to listen for connection. 232 | const std::string ListenIp; 233 | 234 | /// @brief The initial mode for the proxy 235 | const OperationMode ProxyMode = OperationMode::Normal; 236 | }; 237 | 238 | std::unique_ptr BuildProxyWifiService( 239 | const ProxyWifiHyperVSettings& settings, FakeNetworkProvider fakeNetworkCallback = {}, ProxyWifiObserver* pObserver = nullptr); 240 | std::unique_ptr BuildProxyWifiService( 241 | const ProxyWifiTcpSettings& settings, FakeNetworkProvider fakeNetworkCallback = {}, ProxyWifiObserver* pObserver = nullptr); 242 | 243 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | add_library(proxy-wifi STATIC "") 5 | 6 | target_include_directories(proxy-wifi 7 | PUBLIC 8 | $ 9 | ) 10 | 11 | target_link_libraries(proxy-wifi 12 | PUBLIC 13 | Mswsock.lib 14 | Ws2_32.lib 15 | PRIVATE 16 | WIL 17 | proxy-wifi-util 18 | ) 19 | 20 | set(PROXY_WIFI_PUBLIC_HEADERS 21 | ${CMAKE_CURRENT_SOURCE_DIR}/../include/ProxyWifi/ProxyWifiService.hpp 22 | ${CMAKE_CURRENT_SOURCE_DIR}/../include/ProxyWifi/Logs.hpp 23 | ) 24 | 25 | target_sources(proxy-wifi 26 | PRIVATE 27 | ${PROXY_WIFI_PUBLIC_HEADERS} 28 | Iee80211Utils.hpp 29 | ClientWlanInterface.hpp 30 | ClientWlanInterface.cpp 31 | Connection.hpp 32 | Connection.cpp 33 | Networks.hpp 34 | Networks.cpp 35 | Logs.cpp 36 | LogsHelpers.hpp 37 | Messages.hpp 38 | Messages.cpp 39 | OperationHandler.hpp 40 | OperationHandler.cpp 41 | WlanSvcOperationHandler.hpp 42 | OperationHandlerBuilder.hpp 43 | Protocol.hpp 44 | RealWlanInterface.hpp 45 | RealWlanInterface.cpp 46 | SocketHelpers.hpp 47 | SocketHelpers.cpp 48 | TestWlanInterface.hpp 49 | TestWlanInterface.cpp 50 | Tracelog.hpp 51 | Transport.hpp 52 | Transport.cpp 53 | ProxyWifiServiceImpl.hpp 54 | ProxyWifiServiceImpl.cpp 55 | WlanInterface.hpp 56 | WlanSvcHelpers.hpp 57 | WlanSvcHelpers.cpp 58 | WlanSvcWrapper.hpp 59 | WlanSvcWrapper.cpp 60 | ) 61 | 62 | set_target_properties(proxy-wifi PROPERTIES 63 | PUBLIC_HEADER "${PROXY_WIFI_PUBLIC_HEADERS}" 64 | ) 65 | 66 | install(TARGETS proxy-wifi 67 | EXPORT proxy-wifi-targets 68 | ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}/${TARGET_PLATFORM} 69 | INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} 70 | PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ProxyWifi 71 | ) 72 | 73 | install(FILES 74 | $/$.pdb 75 | DESTINATION ${CMAKE_INSTALL_LIBDIR}/${TARGET_PLATFORM} 76 | ) 77 | 78 | install(EXPORT proxy-wifi-targets 79 | NAMESPACE proxy-wifi:: 80 | DESTINATION ${CMAKE_INSTALL_LIBDIR}/${TARGET_PLATFORM}/cmake/proxy-wifi 81 | ) 82 | -------------------------------------------------------------------------------- 
/lib/ClientWlanInterface.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "ClientWlanInterface.hpp" 5 | 6 | #include "StringUtils.hpp" 7 | #include "ProxyWifi/Logs.hpp" 8 | 9 | namespace ProxyWifi { 10 | 11 | ClientWlanInterface::ClientWlanInterface(const GUID& interfaceGuid, std::function()> callback) 12 | : m_getClientBssCallback{std::move(callback)}, m_interfaceGuid{interfaceGuid} 13 | { 14 | } 15 | 16 | void ClientWlanInterface::SetNotificationHandler(INotificationHandler*) 17 | { 18 | // ClientWlanInterface never sends notifications: all its networks are always connected with 100% signal quality 19 | } 20 | 21 | const GUID& ClientWlanInterface::GetGuid() const noexcept 22 | { 23 | return m_interfaceGuid; 24 | } 25 | 26 | std::optional ClientWlanInterface::IsConnectedTo(const Ssid& requestedSsid) noexcept 27 | { 28 | const auto clientNetworks = GetBssFromClient(); 29 | const auto network = 30 | std::ranges::find_if(clientNetworks, [&](const auto& n) { return n.ssid == requestedSsid; }); 31 | 32 | if (network == clientNetworks.cend()) 33 | { 34 | return std::nullopt; 35 | } 36 | // A fake interface is always considered connected to all the networks it handles 37 | Log::Info( 38 | L"Client interface %ws already connected to ssid: %ws", 39 | GuidToString(m_interfaceGuid).c_str(), 40 | SsidToLogString(requestedSsid.value()).c_str()); 41 | // Fake interfaces pretend to see WPA2PSK networks 42 | return ConnectedNetwork{requestedSsid, toBssid(network->bssid), DOT11_AUTH_ALGORITHM_RSNA_PSK}; 43 | } 44 | 45 | std::future> ClientWlanInterface::Connect(const Ssid& requestedSsid, const Bssid&, const WlanSecurity&) 46 | { 47 | const auto clientNetworks = GetBssFromClient(); 48 | const auto network = 49 | std::ranges::find_if(clientNetworks, [&](const auto& n) { return n.ssid == requestedSsid; }); 50 | 51 | std::promise> promise; 52 | if
(network == clientNetworks.cend()) 53 | { 54 | Log::Trace( 55 | L"Could not connect client interface %ws to ssid: %ws", 56 | GuidToString(m_interfaceGuid).c_str(), 57 | SsidToLogString(requestedSsid.value()).c_str()); 58 | promise.set_value({WlanStatus::UnspecifiedFailure, {}}); 59 | } 60 | else 61 | { 62 | Log::Trace( 63 | L"Connected client interface %ws to to ssid: %ws", 64 | GuidToString(m_interfaceGuid).c_str(), 65 | SsidToLogString(requestedSsid.value()).c_str()); 66 | promise.set_value({WlanStatus::Success, {requestedSsid, toBssid(network->bssid), DOT11_AUTH_ALGO_RSNA_PSK}}); 67 | } 68 | return promise.get_future(); 69 | } 70 | 71 | std::future ClientWlanInterface::Disconnect() 72 | { 73 | // Disconnect is a no-op for a fake interface 74 | std::promise promise; 75 | promise.set_value(); 76 | return promise.get_future(); 77 | } 78 | 79 | std::future ClientWlanInterface::Scan(std::optional&) 80 | { 81 | std::vector result; 82 | for (auto bss : GetBssFromClient()) 83 | { 84 | if (bss.ssid.uSSIDLength > c_wlan_max_ssid_len) 85 | { 86 | Log::Info(L"Ignoring an invalid client provided SSID (length: %d)", bss.ssid.uSSIDLength); 87 | continue; 88 | } 89 | 90 | // Create a wpa2psk network with the requested SSID and BSSID 91 | FakeBss fakeBss; 92 | fakeBss.capabilities = BssCapability::Ess | BssCapability::Privacy; 93 | fakeBss.ssid = bss.ssid; 94 | fakeBss.bssid = toBssid(bss.bssid); 95 | fakeBss.akmSuites = {AkmSuite::Psk}; 96 | fakeBss.cipherSuites = {CipherSuite::Ccmp}; 97 | fakeBss.groupCipher = CipherSuite::Ccmp; 98 | 99 | Log::Debug( 100 | L"Reporting client BSS, Bssid: %ws, Ssid: %ws, AkmSuites: {%ws}, CipherSuites: {%ws}, GroupCipher: %.8x, " 101 | L"ChannelCenterFreq: %d", 102 | BssidToString(fakeBss.bssid).c_str(), 103 | SsidToLogString(fakeBss.ssid.value()).c_str(), 104 | ListEnumToHexString(std::span{fakeBss.akmSuites}).c_str(), 105 | ListEnumToHexString(std::span{fakeBss.cipherSuites}).c_str(), 106 | fakeBss.groupCipher ? 
WI_EnumValue(*fakeBss.groupCipher) : 0, 107 | fakeBss.channelCenterFreq); 108 | 109 | result.emplace_back(fakeBss); 110 | } 111 | 112 | Log::Debug(L"%d BSS entries reported on client interface %ws", result.size(), GuidToString(m_interfaceGuid).c_str()); 113 | std::promise promise; 114 | promise.set_value({std::move(result), ScanStatus::Completed}); 115 | return promise.get_future(); 116 | } 117 | 118 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/ClientWlanInterface.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include "WlanInterface.hpp" 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | #include "Networks.hpp" 14 | #include "Iee80211Utils.hpp" 15 | #include "ProxyWifi/ProxyWifiService.hpp" 16 | 17 | namespace ProxyWifi { 18 | 19 | /// @brief This class represents a fake WLAN interface simulating networks provided by the library client 20 | class ClientWlanInterface : public IWlanInterface 21 | { 22 | public: 23 | ClientWlanInterface(const GUID& interfaceGuid, std::function()> callback); 24 | 25 | void SetNotificationHandler(INotificationHandler* handler) override; 26 | const GUID& GetGuid() const noexcept override; 27 | std::optional IsConnectedTo(const Ssid& requestedSsid) noexcept override; 28 | std::future> Connect(const Ssid& requestedSsid, const Bssid& bssid, const WlanSecurity& securityInfo) override; 29 | std::future Disconnect() override; 30 | std::future Scan(std::optional& ssid) override; 31 | 32 | private: 33 | inline std::vector GetBssFromClient() const 34 | { 35 | if (m_getClientBssCallback) 36 | { 37 | return m_getClientBssCallback(); 38 | } 39 | return {}; 40 | } 41 | 42 | std::function()> m_getClientBssCallback; 43 | const GUID m_interfaceGuid = FakeInterfaceGuid; 44 | }; 45 | 46 | } // namespace ProxyWifi
-------------------------------------------------------------------------------- /lib/Connection.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include 5 | 6 | #include "Connection.hpp" 7 | #include "ProxyWifi/Logs.hpp" 8 | #include "SocketHelpers.hpp" 9 | 10 | namespace ProxyWifi { 11 | 12 | Connection::Connection(std::shared_ptr operations) 13 | : m_operations(std::move(operations)) 14 | { 15 | } 16 | 17 | void Connection::Run() 18 | { 19 | std::optional request{}; 20 | 21 | try 22 | { 23 | request = ReceiveMessage(); 24 | } 25 | catch (...) 26 | { 27 | LOG_WIN32_MSG(ERROR_INVALID_DATA, "Failed to received a message"); 28 | return; 29 | } 30 | 31 | if (!request.has_value()) 32 | { 33 | LOG_WIN32_MSG(ERROR_NO_DATA, "No message received"); 34 | return; 35 | } 36 | 37 | Log::Trace(L"Received request <%ws> (%d bytes)", GetProtocolMessageTypeName(request->hdr.operation), request->hdr.size); 38 | 39 | try 40 | { 41 | Message responseMessage{}; 42 | if (request->hdr.version != proxy_wifi_version::VERSION_0_1) 43 | { 44 | throw std::runtime_error( 45 | "Unsuported request version: " + std::to_string(request->hdr.version) + 46 | ", version expected: " + std::to_string(proxy_wifi_version::VERSION_0_1)); 47 | } 48 | 49 | switch (request->hdr.operation) 50 | { 51 | case WIFI_OP_SCAN_REQUEST: 52 | { 53 | const auto command = ScanRequest{std::move(request->body)}; 54 | Log::Info(L"Received: %ws", command.Describe().c_str()); 55 | 56 | auto response = m_operations->HandleScanRequest(command); 57 | Log::Info(L"Answering: %ws", response.Describe().c_str()); 58 | 59 | responseMessage = ScanResponse::ToMessage(std::move(response)); 60 | break; 61 | } 62 | case WIFI_OP_CONNECT_REQUEST: 63 | { 64 | const auto command = ConnectRequest{std::move(request->body)}; 65 | 66 | Log::Info(L"Received: %ws", command.Describe().c_str()); 67 | auto response = 
m_operations->HandleConnectRequest(command); 68 | 69 | Log::Info(L"Answering: %ws", response.Describe().c_str()); 70 | responseMessage = ConnectResponse::ToMessage(std::move(response)); 71 | 72 | break; 73 | } 74 | case WIFI_OP_DISCONNECT_REQUEST: 75 | { 76 | const auto command = DisconnectRequest{std::move(request->body)}; 77 | Log::Info(L"Received: %ws", command.Describe().c_str()); 78 | 79 | auto response = m_operations->HandleDisconnectRequest(command); 80 | Log::Info(L"Answering: %ws", response.Describe().c_str()); 81 | 82 | responseMessage = DisconnectResponse::ToMessage(std::move(response)); 83 | break; 84 | } 85 | default: 86 | THROW_WIN32_MSG(ERROR_INVALID_DATA, "Ignoring unknown command ID: %d", request->hdr.operation); 87 | } 88 | 89 | Log::Trace( 90 | L"Answering with <%ws> (%d bytes)", GetProtocolMessageTypeName(responseMessage.hdr.operation), responseMessage.hdr.size); 91 | SendMessage(responseMessage); 92 | } 93 | catch (...) 94 | { 95 | // Inform the client there was an issue 96 | LOG_CAUGHT_EXCEPTION_MSG("Failed to process message, answering with an error message."); 97 | SendMessage({WIFI_INVALID, {}}); 98 | } 99 | } 100 | 101 | ConnectionSocket::ConnectionSocket(wil::unique_socket socket, const std::shared_ptr& operations) 102 | : Connection(operations), m_socket(std::move(socket)) 103 | { 104 | } 105 | 106 | void ConnectionSocket::SendMessage(const Message& message) 107 | { 108 | SendProxyWifiMessage(m_socket, message); 109 | } 110 | 111 | std::optional ConnectionSocket::ReceiveMessage() 112 | { 113 | return ReceiveProxyWifiMessage(m_socket); 114 | } 115 | 116 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/Connection.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include "Messages.hpp" 12 | #include "OperationHandler.hpp" 13 | 14 | namespace ProxyWifi { 15 | 16 | /// @brief Represents a connection with a proxy client. Defines an interface for 17 | /// sending and receiving protocol messages, and receiving protocol 18 | /// notifications. 19 | /// 20 | /// This is a base class meant to be derived for specific transports. 21 | class Connection 22 | { 23 | public: 24 | /// @brief Creates a new wifi proxy connection. 25 | /// 26 | /// @param operations The object used to handle protocol messages. 27 | Connection(std::shared_ptr operations); 28 | 29 | /// @brief Destroy the Connection object. 30 | virtual ~Connection() = default; 31 | 32 | /// @brief Start accepting requests on this connection. 33 | /// 34 | /// This must be invoked to enable communication on the connection and is 35 | /// stopped when the Teardown() function is called. 36 | void Run(); 37 | 38 | protected: 39 | /// @brief Receive a single protocol message on the connection. 40 | /// 41 | /// This call must block if no message is available on the socket. Once 42 | /// available, the message should be returned. 43 | /// 44 | /// @return std::optional The message received, if one was available. 45 | virtual std::optional ReceiveMessage() = 0; 46 | 47 | /// @brief Send a single protocol message on the connection. 48 | /// @param message The protocol message to send. 49 | virtual void SendMessage(const Message& message) = 0; 50 | 51 | private: 52 | const std::shared_ptr m_operations; 53 | }; 54 | 55 | /// @brief A connection that uses a generic socket as the transport. 56 | class ConnectionSocket : public Connection 57 | { 58 | public: 59 | /// @brief Create a new connection using a socket as the transport. 60 | /// @param socket The socket used for the connection 61 | /// @param operations The object used to handle protocol messages. 
62 | ConnectionSocket(wil::unique_socket socket, const std::shared_ptr& operations); 63 | 64 | private: 65 | /// @brief Sends a protocol message on the connection. 66 | void SendMessage(const Message& message) override; 67 | 68 | /// @brief Receives a protocol message on the connection. 69 | std::optional ReceiveMessage() override; 70 | 71 | private: 72 | wil::unique_socket m_socket; 73 | }; 74 | 75 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/Iee80211Utils.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | /// @brief Defines 802.11 enums and types 17 | 18 | namespace ProxyWifi { 19 | 20 | static constexpr size_t c_wlan_max_ssid_len = DOT11_SSID_MAX_LENGTH; 21 | static constexpr size_t c_wlan_bssid_len = sizeof(DOT11_MAC_ADDRESS); 22 | 23 | enum class AuthAlgo : uint8_t 24 | { 25 | OpenSystem, 26 | SharedKey, 27 | Ft, 28 | NetworkEap, 29 | Sae, 30 | FilsSk, 31 | FilsSkPfs, 32 | FilsPk, 33 | }; 34 | 35 | inline constexpr uint32_t suite(uint32_t oui, uint8_t id) 36 | { 37 | return (oui << 8) | id; 38 | } 39 | 40 | enum class AkmSuite : uint32_t 41 | { 42 | OneX = suite(0x000FAC, 1), 43 | Psk = suite(0x000FAC, 2), 44 | Ft8021x = suite(0x000FAC, 3), 45 | FtPsk = suite(0x000FAC, 4), 46 | OneXSha256 = suite(0x000FAC, 5), 47 | PskSha256 = suite(0x000FAC, 6), 48 | Tdls = suite(0x000FAC, 7), 49 | Sae = suite(0x000FAC, 8), 50 | FtOverSae = suite(0x000FAC, 9), 51 | ApPeerKey = suite(0x000FAC, 10), 52 | OneXSuiteB = suite(0x000FAC, 11), 53 | OneXSuiteB192 = suite(0x000FAC, 12), 54 | Ft8021xSha384 = suite(0x000FAC, 13), 55 | FilsSha256 = suite(0x000FAC, 14), 56 | FilsSha384 = suite(0x000FAC, 15), 57 | FtFilsSha256 = suite(0x000FAC, 16), 58 | 
FtFilsSha384 = suite(0x000FAC, 17), 59 | Owe = suite(0x000FAC, 18), 60 | FtPskSha384 = suite(0x000FAC, 19), 61 | PskSha384 = suite(0x000FAC, 20) 62 | }; 63 | 64 | enum class CipherSuite : uint32_t 65 | { 66 | Wep40 = suite(0x000FAC, 1), 67 | Tkip = suite(0x000FAC, 2), 68 | // Reserved: 3 69 | Ccmp = suite(0x000FAC, 4), 70 | Wep104 = suite(0x000FAC, 5), 71 | AesCmac = suite(0x000FAC, 6), 72 | Gcmp = suite(0x000FAC, 8), 73 | Gcmp256 = suite(0x000FAC, 9), 74 | Ccmp256 = suite(0x000FAC, 10), 75 | BipGmac128 = suite(0x000FAC, 11), 76 | BipGmac256 = suite(0x000FAC, 12), 77 | BipCmac256 = suite(0x000FAC, 13) 78 | }; 79 | 80 | /// @brief 802.11 IE ids 81 | /// 82 | /// Incomplete, value can be added as needed 83 | enum class ElementId : uint8_t 84 | { 85 | Ssid = 0, 86 | Rsn = 48 87 | }; 88 | 89 | enum class BssCapability : uint16_t 90 | { 91 | Ess = 1 << 0, 92 | Ibss = 1 << 1, 93 | CfPollable = 1 << 2, 94 | CfPollRequest = 1 << 3, 95 | Privacy = 1 << 4, 96 | ShortPreamble = 1 << 5, 97 | Pbcc = 1 << 6, 98 | ChannelAgility = 1 << 7, 99 | SpectrumMgmt = 1 << 8, 100 | Qos = 1 << 9, 101 | ShortSlotTime = 1 << 10, 102 | Apsd = 1 << 11, 103 | RadioMeasure = 1 << 12, 104 | DsssOfdm = 1 << 13, 105 | DelBack = 1 << 14, 106 | ImmBack = 1 << 15, 107 | }; 108 | DEFINE_ENUM_FLAG_OPERATORS(BssCapability); 109 | 110 | /// @brief The 80211 status codes 111 | enum class WlanStatus : uint16_t 112 | { 113 | Success = 0, 114 | UnspecifiedFailure = 1, 115 | CapsUnsupported = 10, 116 | ReassocNoAssoc = 11, 117 | AssocDeniedUnspec = 12, 118 | NotSupportedAuthAlg = 13, 119 | UnknownAuthTransaction = 14, 120 | ChallengeFail = 15, 121 | AuthTimeout = 16, 122 | ApUnableToHandleNewSta = 17, 123 | AssocDeniedRates = 18, 124 | /* 802.11b */ 125 | AssocDeniedNoshortpreamble = 19, 126 | AssocDeniedNopbcc = 20, 127 | AssocDeniedNoagility = 21, 128 | /* 802.11h */ 129 | AssocDeniedNospectrum = 22, 130 | AssocRejectedBadPower = 23, 131 | AssocRejectedBadSuppChan = 24, 132 | /* 802.11g */ 133 | 
AssocDeniedNoshorttime = 25, 134 | AssocDeniedNodsssofdm = 26, 135 | /* 802.11w */ 136 | AssocRejectedTemporarily = 30, 137 | RobustMgmtFramePolicyViolation = 31, 138 | /* 802.11i */ 139 | InvalidIe = 40, 140 | InvalidGroupCipher = 41, 141 | InvalidPairwiseCipher = 42, 142 | InvalidAkmp = 43, 143 | UnsuppRsnVersion = 44, 144 | InvalidRsnIeCap = 45, 145 | CipherSuiteRejected = 46, 146 | /* 802.11e */ 147 | UnspecifiedQos = 32, 148 | AssocDeniedNobandwidth = 33, 149 | AssocDeniedLowack = 34, 150 | AssocDeniedUnsuppQos = 35, 151 | RequestDeclined = 37, 152 | InvalidQosParam = 38, 153 | ChangeTspec = 39, 154 | WaitTsDelay = 47, 155 | NoDirectLink = 48, 156 | StaNotPresent = 49, 157 | StaNotQsta = 50, 158 | /* 802.11s */ 159 | AntiClogRequired = 76, 160 | FcgNotSupp = 78, 161 | StaNoTbtt = 78, 162 | /* 802.11ad */ 163 | RejectedWithSuggestedChanges = 39, 164 | RejectedForDelayPeriod = 47, 165 | RejectWithSchedule = 83, 166 | PendingAdmittingFstSession = 86, 167 | PerformingFstNow = 87, 168 | PendingGapInBaWindow = 88, 169 | RejectUPidSetting = 89, 170 | RejectDseBand = 96, 171 | DeniedWithSuggestedBandAndChannel = 99, 172 | DeniedDueToSpectrumManagement = 103, 173 | /* 802.11ai */ 174 | FilsAuthenticationFailure = 108, 175 | UnknownAuthenticationServer = 109, 176 | SaeHashToElement = 126, 177 | SaePk = 127, 178 | }; 179 | 180 | class Ssid { 181 | 182 | public: 183 | Ssid() = default; 184 | 185 | Ssid(DOT11_SSID rhs) : 186 | m_ssid{rhs.ucSSID, rhs.ucSSID + rhs.uSSIDLength} 187 | { 188 | } 189 | 190 | Ssid(const std::span rhs) : 191 | m_ssid{rhs.begin(), rhs.end()} 192 | { 193 | if (rhs.size() > c_wlan_max_ssid_len) 194 | { 195 | throw std::invalid_argument("Ssid too long: " + std::to_string(rhs.size())); 196 | } 197 | } 198 | 199 | Ssid(const std::string_view rhs) : 200 | m_ssid{rhs.begin(), rhs.end()} 201 | { 202 | if (rhs.size() > c_wlan_max_ssid_len) 203 | { 204 | throw std::invalid_argument("Ssid too long: " + std::to_string(rhs.size())); 205 | } 206 | } 207 | 208 | 
operator DOT11_SSID() const noexcept 209 | { 210 | DOT11_SSID r; 211 | r.uSSIDLength = size(); 212 | std::copy_n(m_ssid.data(), m_ssid.size(), r.ucSSID); 213 | return r; 214 | } 215 | 216 | friend bool operator==(const Ssid& lhs, const Ssid& rhs) noexcept 217 | { 218 | return lhs.m_ssid == rhs.m_ssid; 219 | } 220 | 221 | friend bool operator!=(const Ssid& lhs, const Ssid& rhs) noexcept 222 | { 223 | return !(lhs == rhs); 224 | } 225 | 226 | friend bool operator==(const Ssid& lhs, const DOT11_SSID& rhs) noexcept 227 | { 228 | return lhs.m_ssid.size() == rhs.uSSIDLength && std::equal(lhs.m_ssid.begin(), lhs.m_ssid.end(), rhs.ucSSID); 229 | } 230 | 231 | friend bool operator==(const DOT11_SSID& lhs, const Ssid& rhs) noexcept 232 | { 233 | return rhs == lhs; 234 | } 235 | 236 | friend bool operator!=(const Ssid& lhs, const DOT11_SSID& rhs) noexcept 237 | { 238 | return !(lhs == rhs); 239 | } 240 | 241 | friend bool operator!=(const DOT11_SSID& lhs, const Ssid& rhs) noexcept 242 | { 243 | return !(lhs == rhs); 244 | } 245 | 246 | const std::vector& value() const noexcept 247 | { 248 | return m_ssid; 249 | } 250 | 251 | uint8_t size() const noexcept 252 | { 253 | // `m_ssid.size()` <= c_max_ssid_length = 32 as a class invariant 254 | return static_cast(m_ssid.size()); 255 | } 256 | 257 | private: 258 | std::vector m_ssid; 259 | }; 260 | 261 | using Bssid = std::array; 262 | 263 | inline Bssid toBssid(const uint8_t bssid[c_wlan_bssid_len]) 264 | { 265 | Bssid r; 266 | std::copy_n(bssid, c_wlan_bssid_len, r.begin()); 267 | return r; 268 | } 269 | 270 | struct WlanSecurity 271 | { 272 | AuthAlgo auth; 273 | uint8_t wpaVersion; 274 | std::vector akmSuites; 275 | std::vector cipherSuites; 276 | CipherSuite groupCipher; 277 | std::vector key; 278 | }; 279 | 280 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/Logs.cpp: -------------------------------------------------------------------------------- 1 | // 
Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | #include "ProxyWifi/Logs.hpp" 4 | 5 | #include 6 | #include 7 | 8 | #include "Tracelog.hpp" 9 | 10 | namespace ProxyWifi::Log { 11 | 12 | namespace { 13 | constexpr std::array levelNames = {"Error", "Info", "Trace", "Debug"}; 14 | 15 | constexpr const char* LevelToCStr(Level lvl) noexcept 16 | { 17 | return levelNames[WI_EnumValue(lvl)]; 18 | } 19 | } // namespace 20 | 21 | void ConsoleLogger::Log(Level level, const wchar_t* message) noexcept 22 | { 23 | printf_s("%hs: %ws\n", LevelToCStr(level), message); 24 | } 25 | 26 | void TraceLoggingLogger::Log(Level level, const wchar_t* message) noexcept 27 | { 28 | switch (level) 29 | { 30 | case Level::Debug: 31 | TraceProvider::Debug(message); 32 | break; 33 | case Level::Trace: 34 | TraceProvider::Trace(message); 35 | break; 36 | case Level::Info: 37 | TraceProvider::Info(message); 38 | break; 39 | case Level::Error: 40 | TraceProvider::Error(message); 41 | break; 42 | } 43 | } 44 | 45 | } // namespace ProxyWifi::Log -------------------------------------------------------------------------------- /lib/LogsHelpers.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | #pragma once 4 | 5 | #include 6 | #include "ProxyWifi/Logs.hpp" 7 | 8 | namespace ProxyWifi { 9 | 10 | /// @brief Add a WIL failure callback for the current thread. 
11 | /// This allows to log messages from WIL macro and exceptions 12 | inline auto SetThreadWilFailureLogger() 13 | { 14 | return wil::ThreadFailureCallback([](const wil::FailureInfo& failure) noexcept { 15 | constexpr std::size_t sizeOfLogMessageWithNul = 2048; 16 | 17 | wchar_t logMessage[sizeOfLogMessageWithNul]{}; 18 | wil::GetFailureLogString(logMessage, sizeOfLogMessageWithNul, failure); 19 | Log::WilFailure(logMessage); 20 | return false; // This doesn't report any telemetry 21 | }); 22 | } 23 | 24 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/Messages.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | #include "Messages.hpp" 4 | 5 | #include 6 | 7 | namespace ProxyWifi { 8 | 9 | bool ScanResponseBuilder::IsBssAlreadyPresent(const Bssid& bssid) 10 | { 11 | return std::ranges::find_if(m_bssList, [&](const auto& bss) { return bss.bssid == bssid; }) != m_bssList.cend(); 12 | } 13 | 14 | void ScanResponseBuilder::AddBss(ScannedBss bss) 15 | { 16 | if (IsBssAlreadyPresent(bss.bssid)) 17 | { 18 | return; 19 | } 20 | m_bssList.push_back(std::move(bss)); 21 | } 22 | 23 | void ScanResponseBuilder::SetScanComplete(bool isComplete) noexcept 24 | { 25 | m_scanComplete = isComplete; 26 | } 27 | 28 | ScanResponse ScanResponseBuilder::Build() const 29 | { 30 | auto allocSize = sizeof(proxy_wifi_scan_response) + m_bssList.size() * sizeof(proxy_wifi_bss); 31 | for (const auto& bss : m_bssList) 32 | { 33 | allocSize += bss.ies.size(); 34 | } 35 | 36 | ScanResponse scanResponse{allocSize, m_bssList.size(), m_scanComplete}; 37 | 38 | auto nextIe = scanResponse.getIes(); 39 | for (auto i = 0u; i < m_bssList.size(); ++i) 40 | { 41 | const auto& bss = m_bssList[i]; 42 | scanResponse->bss[i] = proxy_wifi_bss{ 43 | {}, bss.capabilities, bss.rssi, bss.beaconInterval, bss.channelCenterFreq, 
wil::safe_cast(bss.ies.size()), {}}; 44 | std::ranges::copy(bss.bssid, scanResponse->bss[i].bssid); 45 | std::ranges::copy(bss.ies, nextIe.data()); 46 | scanResponse->bss[i].ie_offset = 47 | wil::safe_cast(std::distance(reinterpret_cast(&scanResponse->bss[i]), nextIe.data())); 48 | 49 | nextIe = nextIe.subspan(bss.ies.size()); 50 | } 51 | 52 | return scanResponse; 53 | } 54 | 55 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/Messages.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | /// @brief Helper types to build messages 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | 19 | #include "Iee80211Utils.hpp" 20 | #include "Networks.hpp" 21 | #include "Protocol.hpp" 22 | #include "StringUtils.hpp" 23 | 24 | namespace ProxyWifi { 25 | 26 | static constexpr std::array operation_names = { 27 | /* WIFI_INVALID */ L"Invalid", 28 | /* WIFI_OP_SCAN_REQUEST */ L"ScanRequest", 29 | /* WIFI_OP_SCAN_RESPONSE */ L"ScanResponse", 30 | /* WIFI_OP_CONNECT_REQUEST */ L"ConnectRequest", 31 | /* WIFI_OP_CONNECT_RESPONSE */ L"ConnectResponse", 32 | /* WIFI_OP_DISCONNECT_REQUEST */ L"DisconnectRequest", 33 | /* WIFI_OP_DISCONNECT_RESPONSE */ L"DisconnectResponse", 34 | /* WIFI_NOTIF_DISCONNECTED */ L"EventDisconnected", 35 | /* WIFI_NOTIF_SIGNAL_QUALITY */ L"EventSignalQuality"}; 36 | 37 | static_assert(operation_names.size() == WIFI_OP_MAX); 38 | 39 | constexpr const wchar_t* GetProtocolMessageTypeName(uint8_t operation) noexcept 40 | { 41 | return (operation < WIFI_OP_MAX) ? 
operation_names[operation] : L"Invalid"; 42 | } 43 | 44 | /// @brief Helper to manipulate full messages (header + body) 45 | struct Message 46 | { 47 | Message() = default; 48 | 49 | Message(proxy_wifi_operation op, std::vector buffer) 50 | : hdr{op, wil::safe_cast(buffer.size()), proxy_wifi_version::VERSION_0_1}, body{std::move(buffer)} 51 | { 52 | } 53 | 54 | struct proxy_wifi_hdr hdr{}; 55 | std::vector body; 56 | }; 57 | 58 | /// @brief Return the size of a message. 59 | /// Should be specialized for messages which size is larger than the size of the 60 | /// type representing it (VLA, data happened to the message itself...). 61 | template 62 | inline uint32_t size(const T& msg) 63 | { 64 | return wil::safe_cast(sizeof(msg)); 65 | } 66 | 67 | template <> 68 | inline uint32_t size(const proxy_wifi_disconnect_response&) 69 | { 70 | // sizeof = 1 for an empty struct 71 | return 0; 72 | } 73 | 74 | template <> 75 | inline uint32_t size(const proxy_wifi_scan_response& msg) 76 | { 77 | return msg.total_size; 78 | } 79 | 80 | template <> 81 | inline uint32_t size(const proxy_wifi_connect_request& msg) 82 | { 83 | return sizeof(proxy_wifi_connect_request) + msg.key_len; 84 | } 85 | 86 | /// @brief Handle a buffer of bytes and allow to view it as a message body of the specified type 87 | template 88 | class StructuredBuffer 89 | { 90 | using MsgType = T; 91 | 92 | public: 93 | StructuredBuffer() 94 | : m_buffer(sizeof(MsgType)) 95 | { 96 | } 97 | 98 | explicit StructuredBuffer(size_t allocSize) 99 | : m_buffer(std::max(allocSize, sizeof(MsgType))) 100 | { 101 | } 102 | 103 | StructuredBuffer(std::vector buffer) 104 | : m_buffer{std::move(buffer)} 105 | { 106 | if (m_buffer.size() < sizeof(MsgType)) 107 | { 108 | throw std::invalid_argument( 109 | "Message too small: " + std::to_string(buffer.size()) + "bytes but " + std::to_string(sizeof(MsgType)) + 110 | "at least bytes expected"); 111 | } 112 | 113 | // `size()` might access the fixed data to get the size of variable 
length data, so a check with `sizeof` must 114 | // be done first 115 | if (m_buffer.size() != size(*get())) 116 | { 117 | throw std::invalid_argument( 118 | "Unexpected message size:" + std::to_string(buffer.size()) + "bytes but " + std::to_string(size(*get())) + 119 | "bytes expected"); 120 | } 121 | } 122 | 123 | MsgType* get() 124 | { 125 | return reinterpret_cast(m_buffer.data()); 126 | } 127 | 128 | MsgType* operator*() 129 | { 130 | return get(); 131 | } 132 | 133 | MsgType* operator->() 134 | { 135 | return get(); 136 | } 137 | 138 | const MsgType* get() const 139 | { 140 | return reinterpret_cast(m_buffer.data()); 141 | } 142 | 143 | const MsgType* operator*() const 144 | { 145 | return get(); 146 | } 147 | 148 | const MsgType* operator->() const 149 | { 150 | return get(); 151 | } 152 | 153 | std::span AsBytes() 154 | { 155 | return std::span{m_buffer}; 156 | } 157 | 158 | static Message ToMessage(StructuredBuffer&& buffer) 159 | { 160 | // Get the size before moving the buffer, or it may get invalidated 161 | return Message(Operation, std::move(buffer.m_buffer)); 162 | } 163 | 164 | std::wstring Describe() const 165 | { 166 | return GetProtocolMessageTypeName(Operation); 167 | } 168 | 169 | private: 170 | std::vector m_buffer; 171 | }; 172 | 173 | class ConnectRequest: public StructuredBuffer 174 | { 175 | public: 176 | ConnectRequest(std::vector buffer) 177 | : StructuredBuffer{std::move(buffer)} 178 | { 179 | } 180 | 181 | std::wstring Describe() const 182 | { 183 | std::wostringstream stream; 184 | stream << L"Connect request, Ssid: " 185 | << SsidToLogString({get()->ssid, std::min(c_wlan_max_ssid_len, wil::safe_cast(get()->ssid_len))}); 186 | stream << L", Bssid: " << BssidToString(get()->bssid); 187 | stream << L", Auth: " << get()->auth_type; 188 | stream << L", WPA version: " << get()->wpa_versions; 189 | 190 | stream << std::hex << std::setfill(L'0'); 191 | stream << L", AKM Suites: {"; 192 | for (const auto& akm : 193 | 
wil::make_range(get()->akm_suites, std::min(wil::safe_cast(get()->num_akm_suites), c_wlan_max_akm_suites))) 194 | { 195 | stream << L" 0x" << std::setw(8) << akm; 196 | } 197 | stream << L" }, Pairwise Cipher Suites: {"; 198 | for (const auto& cipher : wil::make_range( 199 | get()->pairwise_cipher_suites, 200 | std::min(wil::safe_cast(get()->num_pairwise_cipher_suites), c_wlan_max_pairwise_cipher_suites))) 201 | { 202 | stream << L" 0x" << std::setw(8) << cipher; 203 | } 204 | stream << L" }, Group Cipher Suite: 0x" << std::setw(8) << get()->group_cipher_suite; 205 | stream << std::dec << std::setfill(L' '); 206 | 207 | stream << L", Key present: " + (get()->key_len > 0 ? std::wstring(L"True") : std::wstring(L"False")); 208 | return stream.str(); 209 | } 210 | }; 211 | 212 | class ConnectResponse: public StructuredBuffer 213 | { 214 | public: 215 | ConnectResponse(WlanStatus resultCode, std::span bssid, uint64_t sessionId) 216 | { 217 | get()->result_code = WI_EnumValue(resultCode); 218 | memcpy_s(get()->bssid, sizeof get()->bssid, bssid.data(), bssid.size()); 219 | get()->session_id = sessionId; 220 | } 221 | 222 | std::wstring Describe() const 223 | { 224 | return L"Connect response, Result code: " + std::to_wstring(get()->result_code) + L", Session id: " + 225 | std::to_wstring(get()->session_id) + L", BssId: " + BssidToString(get()->bssid); 226 | } 227 | }; 228 | 229 | class DisconnectRequest: public StructuredBuffer 230 | { 231 | public: 232 | DisconnectRequest(std::vector buffer) 233 | : StructuredBuffer{std::move(buffer)} 234 | { 235 | } 236 | 237 | std::wstring Describe() const 238 | { 239 | return L"Disconnect request, Session id: " + std::to_wstring(get()->session_id); 240 | } 241 | }; 242 | 243 | class DisconnectResponse: public StructuredBuffer 244 | { 245 | }; 246 | 247 | class ScanRequest : public StructuredBuffer 248 | { 249 | public: 250 | ScanRequest(std::vector buffer) 251 | : StructuredBuffer{std::move(buffer)} 252 | { 253 | } 254 | 255 | 
std::wstring Describe() const 256 | { 257 | return L"Scan request, Target ssid: " + 258 | (get()->ssid_len == 0 259 | ? L"*" 260 | : SsidToLogString({get()->ssid, std::min(c_wlan_max_ssid_len, wil::safe_cast(get()->ssid_len))})); 261 | } 262 | }; 263 | 264 | class ScanResponse : public StructuredBuffer 265 | { 266 | public: 267 | ScanResponse(size_t totalSize, size_t numBss, bool scanComplete) 268 | : StructuredBuffer{totalSize}, 269 | m_ies{AsBytes().subspan(sizeof(proxy_wifi_scan_response) + numBss * sizeof(proxy_wifi_bss))} 270 | { 271 | get()->total_size = wil::safe_cast(totalSize); 272 | get()->num_bss = wil::safe_cast(numBss); 273 | get()->scan_complete = scanComplete; 274 | } 275 | 276 | std::span getIes() const 277 | { 278 | return m_ies; 279 | } 280 | 281 | std::wstring Describe() const 282 | { 283 | return L"Scan response, Scan complete: " + std::wstring(get()->scan_complete ? L"true" : L"false") + 284 | L", Number of reported Bss: " + std::to_wstring(get()->num_bss) + L", Total size " + 285 | std::to_wstring(get()->total_size) + L" bytes"; 286 | } 287 | 288 | private: 289 | std::span m_ies; 290 | }; 291 | 292 | /// @brief Builder class to create a scan response 293 | /// 294 | /// Since all information elements are appended at the end of the message and are accessed through offsets, 295 | /// it is necessary to collect all the results first to allocate and build the response message. 
296 | class ScanResponseBuilder 297 | { 298 | public: 299 | void AddBss(ScannedBss bss); 300 | ScanResponse Build() const; 301 | void SetScanComplete(bool isComplete) noexcept; 302 | 303 | private: 304 | bool IsBssAlreadyPresent(const Bssid& bssid); 305 | bool m_scanComplete = false; 306 | std::vector m_bssList; 307 | }; 308 | 309 | class DisconnectNotif: public StructuredBuffer 310 | { 311 | public: 312 | explicit DisconnectNotif(uint64_t sessionId) 313 | { 314 | get()->session_id = sessionId; 315 | } 316 | 317 | std::wstring Describe() const 318 | { 319 | return L"Disconnect notification, Session id: " + std::to_wstring(get()->session_id); 320 | } 321 | }; 322 | 323 | class SignalQualityNotif: public StructuredBuffer 324 | { 325 | public: 326 | explicit SignalQualityNotif(int8_t signal) 327 | { 328 | get()->signal = signal; 329 | } 330 | 331 | std::wstring Describe() const 332 | { 333 | return L"Signal quality notification, Signal: " + std::to_wstring(get()->signal); 334 | } 335 | }; 336 | 337 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/Networks.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include "Networks.hpp" 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace ProxyWifi { 12 | 13 | namespace { 14 | 15 | enum class Endianness 16 | { 17 | Little, 18 | Big 19 | }; 20 | 21 | /// @brief Append an unsigned integer to a vector as a list of bytes with specified endianness 22 | /// @tparam N The number of bytes to append to the vector 23 | /// @tparam E The endianness (Little = CPU order, Big = Network order) 24 | /// @example appendBytes<4, Endianness::Big>(vec, 0x000fac04) -> vec = {..., 0x00, 0x0f, 0xac, 0x04} 25 | template && !std::is_signed_v, int> = 1> 26 | void appendBytes(std::vector& vector, V value) 27 | { 28 | vector.reserve(vector.size() + N); 29 | for (auto i = 0u; i < N; ++i) 30 | { 31 | auto bitShift = 0u; 32 | if constexpr (E == Endianness::Big) 33 | { 34 | bitShift = (N - 1 - i) * 8; 35 | } 36 | else 37 | { 38 | bitShift = i * 8; 39 | } 40 | auto byte = static_cast(value >> bitShift); 41 | vector.push_back(byte); 42 | } 43 | } 44 | 45 | } // namespace 46 | 47 | size_t FakeBss::IeAllocationSize() const 48 | { 49 | constexpr auto ieHeaderSize = 2; // IE id + IE size 50 | constexpr auto rsnIeConstantSize = 12; // Version, Group cipher, Num ciphers, Num akms, capabilities 51 | constexpr auto suiteBlockSize = 4; 52 | 53 | return ieHeaderSize + ssid.size() + ieHeaderSize + rsnIeConstantSize + (akmSuites.size() + cipherSuites.size()) * suiteBlockSize; 54 | } 55 | 56 | std::vector FakeBss::BuildInformationElements() const 57 | { 58 | std::vector ies; 59 | const std::vector& rawSsid = ssid.value(); 60 | if (!rawSsid.empty()) 61 | { 62 | // Build an SSID element 63 | if (rawSsid.size() > c_wlan_max_ssid_len) 64 | { 65 | throw std::invalid_argument("Invalid ssid length: " + std::to_string(rawSsid.size())); 66 | } 67 | 68 | ies.insert(ies.end(), {WI_EnumValue(ElementId::Ssid), static_cast(rawSsid.size())}); 69 | ies.insert(ies.end(), rawSsid.begin(), rawSsid.end()); 70 | } 71 | 72 | if (!akmSuites.empty()) 73 | { 74 | 
assert(!cipherSuites.empty() && groupCipher.has_value()); 75 | 76 | // Build an RSNIE 77 | // Warning: Assume all provided akm and cipher are compatibles 78 | constexpr auto rsnIeBaseSize = 12; // Version, Group cipher, Num ciphers, Num akms, capabilities 79 | constexpr auto suiteBlockSize = 4; 80 | const auto rsnIeSize = wil::safe_cast(rsnIeBaseSize + (akmSuites.size() + cipherSuites.size()) * suiteBlockSize); 81 | 82 | ies.insert(ies.end(), {WI_EnumValue(ElementId::Rsn), rsnIeSize}); 83 | 84 | // Add the version 85 | appendBytes<2, Endianness::Little>(ies, 0x0001u); 86 | 87 | // Add the group cipher 88 | appendBytes<4, Endianness::Big>(ies, WI_EnumValue(*groupCipher)); 89 | 90 | // Add the pairwise cipher suites 91 | appendBytes<2, Endianness::Little>(ies, cipherSuites.size()); 92 | for (const auto cipher : cipherSuites) 93 | { 94 | appendBytes<4, Endianness::Big>(ies, WI_EnumValue(cipher)); 95 | } 96 | 97 | // Add the akm suites 98 | appendBytes<2, Endianness::Little>(ies, akmSuites.size()); 99 | for (const auto akm : akmSuites) 100 | { 101 | appendBytes<4, Endianness::Big>(ies, WI_EnumValue(akm)); 102 | } 103 | 104 | // Add the RSN capabilities 105 | appendBytes<2, Endianness::Little>(ies, 0x0000u); 106 | } 107 | 108 | return ies; 109 | } 110 | 111 | ScannedBss::ScannedBss(const FakeBss& fakeBss) 112 | : bssid{fakeBss.bssid}, 113 | capabilities{WI_EnumValue(fakeBss.capabilities)}, 114 | rssi{fakeBss.rssi}, 115 | channelCenterFreq{fakeBss.channelCenterFreq}, 116 | beaconInterval{fakeBss.beaconInterval}, 117 | ies{fakeBss.BuildInformationElements()} 118 | { 119 | } 120 | 121 | ScannedBss::ScannedBss(Bssid bssid, Ssid ssid, uint16_t capabilities, int8_t rssi, uint32_t channelCenterFreq, uint16_t beaconInterval, std::vector ies) 122 | : bssid{bssid}, ssid{std::move(ssid)}, capabilities{capabilities}, rssi{rssi}, channelCenterFreq{channelCenterFreq}, beaconInterval{beaconInterval}, ies{std::move(ies)} 123 | { 124 | } 125 | 126 | } // namespace ProxyWifi 
-------------------------------------------------------------------------------- /lib/Networks.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "Iee80211Utils.hpp" 11 | 12 | namespace ProxyWifi { 13 | 14 | /// @brief Define fake bss, to emulate a non-existing network 15 | struct FakeBss 16 | { 17 | BssCapability capabilities = BssCapability::Ess; 18 | int32_t rssi = -50; 19 | uint32_t channelCenterFreq = 5240000; // 5GHz by default 20 | uint16_t beaconInterval = 0; 21 | Bssid bssid{}; 22 | Ssid ssid; 23 | std::vector akmSuites; 24 | std::vector cipherSuites; 25 | std::optional groupCipher; 26 | std::vector key; 27 | 28 | std::vector BuildInformationElements() const; 29 | 30 | /// @brief The quantity of memory needed to store contain the IE for this network. 31 | /// The value is an approximation, it may be bigger than the actual size needed. 
32 | size_t IeAllocationSize() const; 33 | }; 34 | 35 | /// @brief Information about a BSS reported by a scan 36 | struct ScannedBss 37 | { 38 | ScannedBss() = default; 39 | explicit ScannedBss(const FakeBss& fakeBss); 40 | ScannedBss(Bssid bssid, Ssid ssid, uint16_t capabilities, int8_t rssi, uint32_t channelCenterFreq, uint16_t beaconInterval, std::vector ies); 41 | 42 | Bssid bssid{}; 43 | Ssid ssid; 44 | uint16_t capabilities = 0; 45 | int32_t rssi = -50; 46 | uint32_t channelCenterFreq = 5240000; // 5GHz by default 47 | uint16_t beaconInterval = 0; 48 | std::vector ies; 49 | }; 50 | 51 | struct ConnectedNetwork 52 | { 53 | Ssid ssid; 54 | Bssid bssid{}; 55 | DOT11_AUTH_ALGORITHM auth = DOT11_AUTH_ALGO_80211_OPEN; 56 | }; 57 | 58 | enum class ScanStatus 59 | { 60 | Running, 61 | Completed 62 | }; 63 | 64 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/OperationHandler.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "Messages.hpp" 13 | #include "WlanInterface.hpp" 14 | #include "WorkQueue.hpp" 15 | #include "ProxyWifi/ProxyWifiService.hpp" 16 | 17 | namespace ProxyWifi { 18 | 19 | // Base class for operation handlers 20 | // Handles requests from the guest, dispatches them to the host Wi-Fi interfaces and builds the responses; also receives host interface notifications 21 | class OperationHandler: private INotificationHandler 22 | { 23 | public: 24 | OperationHandler(ProxyWifiObserver* pObserver, std::vector> wlanInterfaces) // Registers itself as the notification handler of every provided interface 25 | : m_pClientObserver{pObserver}, m_wlanInterfaces{std::move(wlanInterfaces)} 26 | { 27 | for (const auto& wlanIntf: m_wlanInterfaces) 28 | { 29 | wlanIntf->SetNotificationHandler(this); 30 | } 31 | } 32 | /// Destruction order matters: interfaces are destroyed first, then pending serialized work is cancelled. NOTE(review): m_clientNotificationQueue is not cancelled here and relies on its own destructor - confirm this is safe 33 | ~OperationHandler() override 34 | { 35 | // First, destroy all interfaces so no notification will be queued on a destroyed workqueue 36 | m_wlanInterfaces.clear(); 37 | 38 | // Then cancel async works before the object destruction to ensure nothing references `this` 39 | m_serializedRunner.Cancel(); 40 | } 41 | 42 | OperationHandler(const OperationHandler&) = delete; 43 | OperationHandler(OperationHandler&&) = delete; 44 | OperationHandler& operator=(const OperationHandler&) = delete; 45 | OperationHandler& operator=(OperationHandler&&) = delete; 46 | 47 | /// @brief Add an interface to the operation handler 48 | /// Takes a builder function instead of the interface directly to create the interface in the work queue 49 | /// (needed to avoid deadlocks when an interface is created - and subscribe to notifications - from a notification thread) 50 | void AddInterface(const std::function()>& wlanInterfaceBuilder); 51 | void RemoveInterface(const GUID& interfaceGuid); 52 | 53 | ConnectResponse HandleConnectRequest(const ConnectRequest& connectRequest); 54 | DisconnectResponse HandleDisconnectRequest(const DisconnectRequest& disconnectRequest); 55 | ScanResponse HandleScanRequest(const
ScanRequest& scanRequest); 56 | 57 | using GuestNotificationTypes = std::variant; 58 | void RegisterGuestNotificationCallback(std::function notificationCallback); 59 | void ClearGuestNotificationCallback(); 60 | 61 | /// @brief Wait until all client notifications have been processed, then return 62 | /// Unit test helper 63 | void DrainWorkqueues(); 64 | 65 | protected: 66 | 67 | /// @brief Must be called by the interfaces when they connect to a network 68 | void OnHostConnection(const GUID& interfaceGuid, const Ssid& ssid, DOT11_AUTH_ALGORITHM authAlgo) override; 69 | /// @brief Must be called by the interfaces when they disconnect from a network 70 | void OnHostDisconnection(const GUID& interfaceGuid, const Ssid& ssid) override; 71 | /// @brief Must be called by the interfaces when the signal quality changes 72 | void OnHostSignalQualityChange(const GUID& interfaceGuid, unsigned long signalQuality) override; 73 | /// @brief Must be called by the interfaces when scan results are available 74 | void OnHostScanResults(const GUID& interfaceGuid, const std::vector& scannedBss, ScanStatus status) override; 75 | 76 | private: 77 | 78 | // These functions do the actual handling of the request from a serialized work queue 79 | ConnectResponse HandleConnectRequestSerialized(const ConnectRequest& connectRequest); 80 | DisconnectResponse HandleDisconnectRequestSerialized(const DisconnectRequest& disconnectRequest); 81 | ScanResponse HandleScanRequestSerialized(const ScanRequest& scanRequest); 82 | 83 | /// @brief Send a notification to the guest 84 | void SendGuestNotification(GuestNotificationTypes notif); 85 | 86 | /// @brief Notify the client of guest operation requests and completions (presumably through m_pClientObserver - confirm) 87 | ProxyWifiObserver::Authorization AuthorizeGuestConnectionRequest(OperationType type, const Ssid& ssid) noexcept; 88 | void OnGuestConnectionCompletion(OperationType type, OperationStatus status, const GUID& interfaceGuid, const Ssid& ssid, DOT11_AUTH_ALGORITHM authAlgo) noexcept; 89 | void
OnGuestDisconnectionRequest(OperationType type, const Ssid& ssid) noexcept; 90 | void OnGuestDisconnectionCompletion(OperationType type, OperationStatus status, const GUID& interfaceGuid, const Ssid& ssid) noexcept; 91 | void OnGuestScanRequest() noexcept; 92 | void OnGuestScanCompletion(OperationStatus status) noexcept; 93 | 94 | std::shared_mutex m_notificationLock; 95 | std::function m_notificationCallback; 96 | 97 | /// @brief Client provided object to notify client of various events 98 | ProxyWifiObserver* m_pClientObserver = nullptr; 99 | /// @brief How a guest connection was initiated (see ConnectionInfo::type) 100 | enum class ConnectionType 101 | { 102 | Mirrored, 103 | GuestDirected 104 | }; 105 | 106 | struct ConnectionInfo 107 | { 108 | /// @brief Identify how the guest connection was initiated 109 | ConnectionType type; 110 | /// @brief Identify the host interface corresponding to the guest connection 111 | GUID interfaceGuid; 112 | /// @brief Identify the ssid the guest is connected to 113 | Ssid ssid; 114 | }; 115 | /// @brief Current guest connection, if any 116 | std::optional m_guestConnection; 117 | 118 | /// @brief Number identifying the current connection session in the host 119 | /// It keeps the host and guest in sync in some race scenarios, e.g.: 120 | /// 1) HostInitiated send disconnect notif 121 | /// 2) Guest send connect request before receiving disconnect notif 122 | /// 3) HostInitiated process connect request, connect, answer 123 | /// 4) Guest process disconnect notif (blocked in queue while connect request pending) 124 | /// 5) Session ID is expired, prevent the disconnection in the guest 125 | std::atomic m_sessionId{}; 126 | 127 | /// @brief Set of interfaces currently executing a scan 128 | std::set m_scanningInterfaces; 129 | /// @brief Host interfaces used to serve guest requests; cleared first in the destructor 130 | std::vector> m_wlanInterfaces; 131 | 132 | /// @brief Serialized workqueue processing guest requests and guest notifications 133 | SerializedWorkRunner m_serializedRunner; 134 | /// @brief Serialized workqueue sending client notifications 135 | SerializedWorkRunner m_clientNotificationQueue; 136 | }; 137 | 138 |
} // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/OperationHandlerBuilder.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | #pragma once 4 | 5 | #include 6 | 7 | #include 8 | 9 | #include "ProxyWifi/Logs.hpp" 10 | #include "OperationHandler.hpp" 11 | #include "ClientWlanInterface.hpp" 12 | #include "RealWlanInterface.hpp" 13 | #include "TestWlanInterface.hpp" 14 | #include "WlansvcOperationHandler.hpp" 15 | 16 | namespace ProxyWifi { 17 | /// @brief Build the operation handler for the Normal mode: a client interface for the fake networks (if a provider is given), then one interface per host wlansvc interface (interfaces failing to initialize are logged and skipped). Returns a wlansvc enabled handler when `wlansvc` is set, a plain `OperationHandler` otherwise 18 | inline std::unique_ptr MakeWlansvcOperationHandler(std::shared_ptr wlansvc, FakeNetworkProvider fakeNetworkCallback, ProxyWifiObserver* pObserver) 19 | { 20 | std::vector> wlanInterfaces; 21 | // Add an interface for user defined networks, if a callback is provided. Must be first to take priority over the other interfaces 22 | if (fakeNetworkCallback) 23 | { 24 | Log::Info(L"Adding client interface %ws", GuidToString(FakeInterfaceGuid).c_str()); 25 | wlanInterfaces.push_back(std::make_unique(FakeInterfaceGuid, std::move(fakeNetworkCallback))); 26 | } 27 | 28 | // Add the real wlan interfaces 29 | if (wlansvc) 30 | { 31 | const auto interfaces = wlansvc->EnumerateInterfaces(); 32 | for (const auto& i : interfaces) 33 | { 34 | Log::Info(L"Adding interface %ws", GuidToString(i).c_str()); 35 | try 36 | { 37 | wlanInterfaces.push_back(std::make_unique(wlansvc, i)); 38 | } 39 | catch (...) 40 | { 41 | LOG_CAUGHT_EXCEPTION_MSG("Failed to initialize a wlansvc interface.
Skipping it."); 42 | } 43 | } 44 | 45 | Log::Info(L"Creating a Wlansvc enabled operation handler"); 46 | return std::make_unique(pObserver, std::move(wlanInterfaces), wlansvc); 47 | } 48 | else 49 | { 50 | // Without wlansvc, we can't handle interface arrivals/departures anyway, so an `OperationHandler` is enough 51 | Log::Info(L"Creating a client network only operation handler"); 52 | return std::make_unique(pObserver, std::move(wlanInterfaces)); 53 | } 54 | } 55 | /// @brief Build the operation handler for the Simulated mode: a client interface for the fake networks (if a provider is given), plus a test interface simulating networks 56 | inline std::unique_ptr MakeManualTestOperationHandler(FakeNetworkProvider fakeNetworkCallback, ProxyWifiObserver* pObserver) 57 | { 58 | std::vector> wlanInterfaces; 59 | // Add an interface for user defined networks if a callback is provided 60 | if (fakeNetworkCallback) 61 | { 62 | wlanInterfaces.push_back(std::make_unique(FakeInterfaceGuid, std::move(fakeNetworkCallback))); 63 | } 64 | 65 | // Add a test interface simulating networks 66 | wlanInterfaces.push_back( 67 | std::make_unique(GUID{0xc386c570, 0xf576, 0x4f7e, {0xbf, 0x19, 0xd2, 0x32, 0x3a, 0xf8, 0xdd, 0x19}})); 68 | 69 | return std::make_unique(pObserver, std::move(wlanInterfaces)); 70 | } 71 | 72 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/Protocol.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | /// @brief Defines the types used in messages exchanged with the virt_wifi driver 5 | /// Types defined here **must** have the same binary layout as their counterpart in the driver 6 | /// 7 | /// # Protocol 8 | /// 9 | /// ## Message structure 10 | /// Each message is composed of a header `proxy_wifi_hdr` followed by a body. 11 | /// The operation code in the header indicates the type of the body, and its size 12 | /// must match the size indicated in the header.
13 | /// 14 | /// ## Message categories 15 | /// 16 | /// ### Request / Response 17 | /// Requests are messages sent by the guest to the host. Each request must be followed by a 18 | /// response from the host to the guest, whether it has been handled successfully or not. 19 | /// The guest should wait for this response before sending another request. 20 | /// 21 | /// ### Notification 22 | /// Notifications are messages sent spontaneously by the host to the guest. 23 | /// No response is expected after a notification. 24 | 25 | #pragma once 26 | 27 | #include "Iee80211Utils.hpp" 28 | 29 | namespace ProxyWifi { 30 | 31 | static constexpr size_t c_wlan_max_akm_suites = 2; 32 | static constexpr size_t c_wlan_max_pairwise_cipher_suites = 5; 33 | #pragma warning(push) 34 | #pragma pack(push, 1) 35 | 36 | enum proxy_wifi_operation : uint8_t 37 | { 38 | WIFI_INVALID = 0, 39 | WIFI_OP_SCAN_REQUEST, 40 | WIFI_OP_SCAN_RESPONSE, 41 | WIFI_OP_CONNECT_REQUEST, 42 | WIFI_OP_CONNECT_RESPONSE, 43 | WIFI_OP_DISCONNECT_REQUEST, 44 | WIFI_OP_DISCONNECT_RESPONSE, 45 | WIFI_NOTIF_DISCONNECTED, 46 | WIFI_NOTIF_SIGNAL_QUALITY, 47 | WIFI_OP_MAX 48 | }; 49 | 50 | enum proxy_wifi_version : uint16_t 51 | { 52 | VERSION_0_1 = 0x0001 53 | }; 54 | 55 | struct proxy_wifi_hdr 56 | { 57 | proxy_wifi_operation operation; 58 | uint32_t size; 59 | proxy_wifi_version version; 60 | }; 61 | 62 | struct proxy_wifi_scan_request 63 | { 64 | uint8_t ssid_len; 65 | uint8_t ssid[c_wlan_max_ssid_len]; 66 | }; 67 | 68 | struct proxy_wifi_bss 69 | { 70 | uint8_t bssid[c_wlan_bssid_len]; 71 | uint16_t capabilities; 72 | int32_t rssi; 73 | uint16_t beacon_interval; 74 | uint32_t channel_center_freq; 75 | uint32_t ie_size; 76 | uint32_t ie_offset; 77 | }; 78 | 79 | /// @brief A list of bss information 80 | /// 81 | /// The information elements for each BSS are appended to the structure (allocated in the same memory block) 82 | /// and can be accessed using the `ie_offset` and `ie_size` fields of the `proxy_wifi_bss` structure.
83 | /// | num_bss | total_size | bss 1 | ... | bss n | ie bss 1 | ... | ie bss n | 84 | #pragma warning(disable : 4200) // C4200: zero-sized array in struct (variable-length tails below); the previous warning state is restored by warning(pop) at the end of this section 85 | struct proxy_wifi_scan_response 86 | { 87 | uint8_t scan_complete; 88 | uint32_t num_bss; 89 | uint32_t total_size; 90 | proxy_wifi_bss bss[]; 91 | }; 92 | 93 | #pragma warning(disable : 4200) 94 | struct proxy_wifi_connect_request 95 | { 96 | uint8_t ssid_len; 97 | uint8_t ssid[c_wlan_max_ssid_len]; 98 | uint8_t bssid[c_wlan_bssid_len]; 99 | uint8_t auth_type; 100 | uint8_t wpa_versions; 101 | uint8_t num_akm_suites; 102 | uint32_t akm_suites[c_wlan_max_akm_suites]; 103 | uint8_t num_pairwise_cipher_suites; 104 | uint32_t pairwise_cipher_suites[c_wlan_max_pairwise_cipher_suites]; 105 | uint32_t group_cipher_suite; 106 | uint8_t key_len; 107 | uint8_t key[]; 108 | }; 109 | 110 | struct proxy_wifi_connect_response 111 | { 112 | uint16_t result_code; 113 | uint8_t bssid[c_wlan_bssid_len]; 114 | uint64_t session_id; 115 | }; 116 | 117 | struct proxy_wifi_disconnect_request 118 | { 119 | uint64_t session_id; 120 | }; 121 | // NOTE(review): an empty struct has sizeof 1 in C++, whereas GNU C gives an empty struct size 0 - confirm the serialized body size matches the driver 122 | struct proxy_wifi_disconnect_response 123 | { 124 | }; 125 | 126 | struct proxy_wifi_disconnect_notif 127 | { 128 | uint64_t session_id; 129 | }; 130 | 131 | struct proxy_wifi_signal_quality_notif 132 | { 133 | int8_t signal; 134 | }; 135 | 136 | #pragma pack(pop) 137 | #pragma warning(pop) 138 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/ProxyWifiServiceImpl.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license.
3 | #include 4 | 5 | #include 6 | 7 | #include "ClientWlanInterface.hpp" 8 | #include "OperationHandlerBuilder.hpp" 9 | #include "StringUtils.hpp" 10 | #include "ProxyWifiServiceImpl.hpp" 11 | 12 | namespace ProxyWifi { 13 | 14 | namespace { 15 | /** @brief Printable name of a proxy mode, for logs. Throws std::runtime_error for an unknown mode. */ 16 | constexpr const char* GetProxyModeName(OperationMode mode) 17 | { 18 | switch (mode) 19 | { 20 | case OperationMode::Simulated: 21 | return "Simulated"; 22 | case OperationMode::Normal: 23 | return "Normal"; 24 | default: 25 | throw std::runtime_error("Unsupported proxy mode"); 26 | } 27 | } 28 | /** @brief Build the operation handler for the requested mode. In Normal mode, a failure to get a Wlansvc handle is logged and only fake networks are supported. */ 29 | std::shared_ptr GetOperationHandler( 30 | const OperationMode proxyMode, FakeNetworkProvider fakeNetworkCallback, ProxyWifiObserver* pObserver) 31 | { 32 | switch (proxyMode) 33 | { 34 | case OperationMode::Simulated: 35 | return MakeManualTestOperationHandler(std::move(fakeNetworkCallback), pObserver); 36 | case OperationMode::Normal: 37 | { 38 | std::shared_ptr wlansvc; 39 | try 40 | { 41 | wlansvc = std::make_shared(); 42 | } 43 | catch(...) 44 | { 45 | LOG_CAUGHT_EXCEPTION_MSG("Failed to get a Wlansvc handle.
Will only support fake networks."); 46 | } 47 | return MakeWlansvcOperationHandler(wlansvc, std::move(fakeNetworkCallback), pObserver); 48 | } 49 | default: 50 | throw std::runtime_error("Unsupported proxy mode selected"); 51 | } 52 | } 53 | } // namespace 54 | 55 | WifiNetworkInfo::WifiNetworkInfo(const DOT11_SSID& ssid, const DOT11_MAC_ADDRESS& bssid) 56 | : ssid{ssid} 57 | { 58 | memcpy_s(this->bssid, sizeof this->bssid, bssid, sizeof bssid); 59 | } 60 | 61 | ProxyWifiCommon::ProxyWifiCommon(OperationMode mode, FakeNetworkProvider fakeNetworkCallback, ProxyWifiObserver* pObserver) 62 | : m_mode{mode}, m_operationHandler{GetOperationHandler(mode, std::move(fakeNetworkCallback), pObserver)} 63 | { 64 | } 65 | 66 | ProxyWifiCommon::~ProxyWifiCommon() 67 | { 68 | Stop(); 69 | } 70 | // NOTE(review): Start() does not call Shutdown() on an existing transport before replacing it - confirm a second Start() without Stop() is not expected 71 | void ProxyWifiCommon::Start() 72 | { 73 | Log::Info(L"Starting the Wifi proxy"); 74 | m_transport = CreateTransport(); 75 | m_transport->Start(); 76 | } 77 | 78 | void ProxyWifiCommon::Stop() 79 | { 80 | if (!m_transport) 81 | return; 82 | 83 | Log::Info(L"Stopping the Wifi proxy"); 84 | m_transport->Shutdown(); 85 | m_transport = nullptr; 86 | } 87 | 88 | ProxyWifiHyperVSettings::ProxyWifiHyperVSettings(const GUID& guestVmId, unsigned short requestResponsePort, unsigned short notificationPort, OperationMode mode) 89 | : 90 | RequestResponsePort(requestResponsePort), 91 | NotificationPort(notificationPort), 92 | GuestVmId(guestVmId), 93 | ProxyMode(mode) 94 | { 95 | } 96 | 97 | ProxyWifiHyperVSettings::ProxyWifiHyperVSettings(const GUID& guestVmId) 98 | : GuestVmId(guestVmId) 99 | { 100 | } 101 | 102 | ProxyWifiHyperV::ProxyWifiHyperV(const ProxyWifiHyperVSettings& settings, FakeNetworkProvider fakeNetworkCallback, ProxyWifiObserver* pObserver) 103 | : ProxyWifiCommon(settings.ProxyMode, std::move(fakeNetworkCallback), pObserver), m_settings(settings) 104 | { 105 | } 106 | 107 | std::unique_ptr ProxyWifiHyperV::CreateTransport() 108 | { 109 | return std::make_unique( 110 |
m_operationHandler, m_settings.RequestResponsePort, m_settings.NotificationPort, m_settings.GuestVmId); 111 | } 112 | 113 | const ProxyWifiHyperVSettings& ProxyWifiHyperV::Settings() const 114 | { 115 | return m_settings; 116 | } 117 | 118 | ProxyWifiTcpSettings::ProxyWifiTcpSettings(std::string listenIp, unsigned short requestResponsePort, unsigned short notificationPort, OperationMode proxyMode) 119 | : RequestResponsePort(requestResponsePort), 120 | NotificationPort(notificationPort), 121 | ListenIp(std::move(listenIp)), 122 | ProxyMode(proxyMode) 123 | { 124 | } 125 | 126 | ProxyWifiTcpSettings::ProxyWifiTcpSettings(std::string listenIp) 127 | : ListenIp(std::move(listenIp)) 128 | { 129 | } 130 | 131 | ProxyWifiTcp::ProxyWifiTcp( 132 | const ProxyWifiTcpSettings& settings, FakeNetworkProvider fakeNetworkCallback, ProxyWifiObserver* pObserver) 133 | : ProxyWifiCommon(settings.ProxyMode, std::move(fakeNetworkCallback), pObserver), m_settings(settings) 134 | { 135 | } 136 | 137 | std::unique_ptr ProxyWifiTcp::CreateTransport() 138 | { 139 | return std::make_unique( 140 | m_operationHandler, m_settings.RequestResponsePort, m_settings.NotificationPort, m_settings.ListenIp); 141 | } 142 | 143 | const ProxyWifiTcpSettings& ProxyWifiTcp::Settings() const 144 | { 145 | return m_settings; 146 | } 147 | /** @brief Build a Wi-Fi proxy using a HvSocket transport, after logging its settings. */ 148 | std::unique_ptr BuildProxyWifiService( 149 | const ProxyWifiHyperVSettings& settings, FakeNetworkProvider fakeNetworkCallback, ProxyWifiObserver* pObserver) 150 | { 151 | Log::Info( 152 | L"Building a Wifi proxy.
Mode: %hs, Transport: HvSocket, VM Guid: %ws, Request port: %d, Notification port: %d", 153 | GetProxyModeName(settings.ProxyMode), 154 | GuidToString(settings.GuestVmId).c_str(), 155 | settings.RequestResponsePort, 156 | settings.NotificationPort); 157 | return std::make_unique(settings, std::move(fakeNetworkCallback), pObserver); 158 | } 159 | /** @brief Build a Wi-Fi proxy using a TCP transport, after logging its settings. */ 160 | std::unique_ptr BuildProxyWifiService( 161 | const ProxyWifiTcpSettings& settings, FakeNetworkProvider fakeNetworkCallback, ProxyWifiObserver* pObserver) 162 | { 163 | Log::Info( 164 | L"Building a Wifi proxy. Mode: %hs, Transport: TCP, Listen IP: %hs, Request port: %d, Notification port: %d", 165 | GetProxyModeName(settings.ProxyMode), 166 | settings.ListenIp.c_str(), 167 | settings.RequestResponsePort, 168 | settings.NotificationPort); 169 | return std::make_unique(settings, std::move(fakeNetworkCallback), pObserver); 170 | } 171 | 172 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/ProxyWifiServiceImpl.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include "OperationHandler.hpp" 9 | #include "Transport.hpp" 10 | #include "ProxyWifi/ProxyWifiService.hpp" 11 | 12 | namespace ProxyWifi { 13 | 14 | /// @brief Represents a host Wi-Fi proxy: common base owning the operation handler and the transport created by the derived class. 15 | class ProxyWifiCommon: public ProxyWifiService 16 | { 17 | public: 18 | ProxyWifiCommon(OperationMode mode, FakeNetworkProvider fakeNetworkCallback, ProxyWifiObserver* pObserver); 19 | ~ProxyWifiCommon() override; 20 | 21 | ProxyWifiCommon(const ProxyWifiCommon&) = delete; 22 | ProxyWifiCommon& operator=(const ProxyWifiCommon&) = delete; 23 | ProxyWifiCommon(ProxyWifiCommon&&) = delete; 24 | ProxyWifiCommon& operator=(ProxyWifiCommon&&) = delete; 25 | 26 | /// @brief Start the proxy.
27 | /// 28 | /// This creates the transport and begins accepting connections on it. Until 29 | /// Start() is called, the proxy is inactive and does not accept connections. 30 | void Start() override; 31 | 32 | /// @brief Stop the proxy. 33 | /// 34 | /// Stop the proxy if it is started. This will sever all existing connections 35 | /// to the proxy and destroy the transport. Following execution of this call, 36 | /// the proxy will no longer accept new connections. It may be restarted 37 | /// using Start(). 38 | void Stop() override; 39 | 40 | protected: 41 | /// @brief Create a transport for the proxy. 42 | virtual std::unique_ptr CreateTransport() = 0; 43 | 44 | protected: 45 | std::thread m_proxy; // NOTE(review): not referenced in ProxyWifiServiceImpl.cpp - possibly unused 46 | OperationMode m_mode = OperationMode::Normal; 47 | std::shared_ptr m_operationHandler; 48 | std::unique_ptr m_transport; 49 | }; 50 | 51 | /// @brief Represents a Wi-Fi proxy for HyperV container endpoints. 52 | /// 53 | /// This proxy allows connections from HyperV containers. Each proxy instance is 54 | /// bound to exactly one HyperV container and will not allow connections from any 55 | /// other container. 56 | /// 57 | /// The proxy transport uses two (2) HyperV (AF_HYPERV) sockets to facilitate the 58 | /// proxy protocol: 59 | /// 1) Request/Response communication channel. 60 | /// 2) Notification communication channel. 61 | /// 62 | /// The request/response communication channel is driven by the client endpoint 63 | /// which originates requests to which the host responds. 64 | /// 65 | /// The notification communication channel is driven by the host which originates 66 | /// notification messages destined for the client endpoint. This is a one-way 67 | /// communication channel; the host does not listen on it for client responses.
68 | /// 69 | /// Unless the client VM has been expressly configured to allow communication on 70 | /// the ports defined by these sockets, a registry entry must be added denoting 71 | /// registration of the proxy application with the HyperV socket. A registry key 72 | /// must be added under: 73 | /// 74 | /// HKLM\SOFTWARE\Microsoft\Windows NT\CurrentVersion\Virtualization\GuestCommunicationServices 75 | /// 76 | /// with key name equal to a GUID describing the service. This class does not 77 | /// handle such registration; it is the responsibility of the caller. 78 | class ProxyWifiHyperV : public ProxyWifiCommon 79 | { 80 | public: 81 | /// @brief Construct a new Proxy Wifi HyperV object. 82 | /// @param settings The settings controlling operations of the proxy. 83 | /// @param fakeNetworkCallback Function that the proxy will call when it needs a list of fake networks to emulate 84 | /// @param pObserver The handler for guest and host notifications 85 | explicit ProxyWifiHyperV(const ProxyWifiHyperVSettings& settings, FakeNetworkProvider fakeNetworkCallback, ProxyWifiObserver* pObserver); 86 | 87 | /// @brief Get HyperV specific proxy settings. 88 | const ProxyWifiHyperVSettings& Settings() const; 89 | 90 | /// @brief Create a Transport object 91 | std::unique_ptr CreateTransport() override; 92 | 93 | private: 94 | const ProxyWifiHyperVSettings m_settings; 95 | }; 96 | 97 | /// @brief Represents a Wi-Fi proxy for TCP endpoints. 98 | /// 99 | /// This proxy allows connections from TCP/IP endpoints. Each proxy instance is 100 | /// bound to exactly one listening IP address and will not allow connections from 101 | /// any other endpoint. 102 | class ProxyWifiTcp : public ProxyWifiCommon 103 | { 104 | public: 105 | /// @brief Construct a new Proxy Wifi Tcp object. 106 | /// 107 | /// @param settings The settings controlling operations of the proxy.
108 | /// @param fakeNetworkCallback Function that the proxy will call when it needs a list of fake networks to emulate 109 | /// @param pObserver The handler for guest and host notifications 110 | explicit ProxyWifiTcp(const ProxyWifiTcpSettings& settings, FakeNetworkProvider fakeNetworkCallback, ProxyWifiObserver* pObserver); 111 | 112 | /// @brief TCP specific proxy settings. 113 | const ProxyWifiTcpSettings& Settings() const; 114 | 115 | /// @brief Create a Transport object 116 | std::unique_ptr CreateTransport() override; 117 | 118 | private: 119 | const ProxyWifiTcpSettings m_settings; 120 | }; 121 | 122 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/RealWlanInterface.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "RealWlanInterface.hpp" 5 | 6 | #include "StringUtils.hpp" 7 | #include "ProxyWifi/Logs.hpp" 8 | #include "WlanSvcHelpers.hpp" 9 | 10 | #include 11 | #include 12 | 13 | namespace ProxyWifi { 14 | 15 | namespace { 16 | // Whether a scanned BSS can be reported as-is to the guest: true when its network's auth/cipher pair is supported, or when no matching network is found 17 | bool IsBssSupported(const ScannedBss& bss, const std::vector& networks) 18 | { 19 | const auto matchingNetwork = 20 | std::ranges::find_if(networks, [&](const auto& n) { return bss.ssid == n.dot11Ssid; }); 21 | 22 | if (matchingNetwork == networks.cend()) 23 | { 24 | Log::Debug( 25 | L"BSS without matching network, Bssid: %ws, Ssid: %ws, ChannelCenterFreq: %u, Rssi: %d, Ie dump:\n%ws", 26 | BssidToString(bss.bssid).c_str(), 27 | SsidToLogString(bss.ssid.value()).c_str(), 28 | bss.channelCenterFreq, 29 | bss.rssi, 30 | ByteBufferToHexString(bss.ies).c_str()); 31 | // A BSS without a matching network is kept unchanged in the scan results 32 | return true; 33 | } 34 | 35 | if (Wlansvc::IsAuthCipherPairSupported({matchingNetwork->dot11DefaultAuthAlgorithm, matchingNetwork->dot11DefaultCipherAlgorithm})) 36 | { 37 | Log::Debug( 38 | L"Supported BSS, Bssid: %ws, Ssid: %ws, AuthAlgo: %ws, CipherAlgo: %ws,
ChannelCenterFreq: %u, Rssi: %d, " 39 | L"Ie dump:\n%ws", 40 | BssidToString(bss.bssid).c_str(), 41 | SsidToLogString(bss.ssid.value()).c_str(), 42 | Wlansvc::AuthAlgoToString(matchingNetwork->dot11DefaultAuthAlgorithm).c_str(), 43 | Wlansvc::CipherAlgoToString(matchingNetwork->dot11DefaultCipherAlgorithm).c_str(), 44 | bss.channelCenterFreq, 45 | bss.rssi, 46 | ByteBufferToHexString(bss.ies).c_str()); 47 | return true; 48 | } 49 | 50 | Log::Debug( 51 | L"Mirroring only BSS, Bssid: %ws, Ssid: %ws, ChannelCenterFreq: %u, Rssi: %d, Original AuthAlgo: %ws, " 52 | L"Original CipherAlgo: %ws", 53 | BssidToString(bss.bssid).c_str(), 54 | SsidToLogString(bss.ssid.value()).c_str(), 55 | bss.channelCenterFreq, 56 | bss.rssi, 57 | Wlansvc::AuthAlgoToString(matchingNetwork->dot11DefaultAuthAlgorithm).c_str(), 58 | Wlansvc::CipherAlgoToString(matchingNetwork->dot11DefaultCipherAlgorithm).c_str()); 59 | 60 | return false; 61 | } 62 | 63 | /// @brief Return the auth algo that will be shown in scan results to the guest, given the real scanned algo (WPA2-PSK when the real auth/cipher pair is not supported) 64 | DOT11_AUTH_ALGORITHM AdaptAuthAlgo(std::pair authCipher) 65 | { 66 | return Wlansvc::IsAuthCipherPairSupported(authCipher) ?
authCipher.first : DOT11_AUTH_ALGO_RSNA_PSK; 67 | } 68 | /// @brief Build a WPA2-PSK (CCMP) scan entry keeping the identity, signal, channel and beacon interval of `bss` 69 | ScannedBss BuildFakeScanResult(const ScannedBss& bss) 70 | { 71 | // Use a fake Bss to generate the scan result IEs 72 | const FakeBss fakeBss{ 73 | BssCapability::Ess | BssCapability::Privacy, 74 | bss.rssi, 75 | bss.channelCenterFreq, 76 | bss.beaconInterval, 77 | bss.bssid, 78 | bss.ssid, 79 | {AkmSuite::Psk}, 80 | {CipherSuite::Ccmp}, 81 | CipherSuite::Ccmp, 82 | {} // Key won't be used 83 | }; 84 | 85 | return ScannedBss{fakeBss}; 86 | } 87 | /// @brief Keep supported BSSs as-is and replace the others with a WPA2-PSK entry 88 | std::vector AdaptScanResult(const std::vector& bssList, const std::vector& networkList) 89 | { 90 | std::vector results; 91 | for (const auto& bss : bssList) 92 | { 93 | if (!IsBssSupported(bss, networkList)) 94 | { 95 | // Replace every BSS whose network type is unsupported with a WPA2-PSK entry 96 | // A guest initiated connection won't work, but a host connection can be mirrored 97 | results.push_back(BuildFakeScanResult(bss)); 98 | } 99 | else 100 | { 101 | results.push_back(bss); 102 | } 103 | } 104 | return results; 105 | } 106 | 107 | } // namespace 108 | 109 | RealWlanInterface::RealWlanInterface(std::shared_ptr wlansvc, const GUID& interfaceGuid) 110 | : m_wlansvc{std::move(wlansvc)}, m_interfaceGuid{interfaceGuid} 111 | { 112 | m_wlansvc->Subscribe(m_interfaceGuid, [this](const auto& n) { WlanNotificationHandler(n); }); 113 | } 114 | 115 | RealWlanInterface::~RealWlanInterface() 116 | { 117 | try 118 | { 119 | m_wlansvc->Unsubscribe(m_interfaceGuid); 120 | } 121 | CATCH_LOG() 122 | } 123 | 124 | void RealWlanInterface::WlanNotificationHandler(const WLAN_NOTIFICATION_DATA& notification) noexcept 125 | try 126 | { 127 | if (notification.NotificationSource == WLAN_NOTIFICATION_SOURCE_ACM) 128 | { 129 | switch (notification.NotificationCode) 130 | { 131 | case wlan_notification_acm_disconnected: 132 | OnDisconnected(*static_cast(notification.pData)); 133 | break; 134 | case wlan_notification_acm_scan_complete: 135 | case wlan_notification_acm_scan_fail: 136
| OnScanComplete(); 137 | break; 138 | case wlan_notification_acm_connection_complete: 139 | OnConnectComplete(*static_cast(notification.pData)); 140 | break; 141 | default: 142 | return; 143 | } 144 | } 145 | else if (notification.NotificationSource == WLAN_NOTIFICATION_SOURCE_MSM) 146 | { 147 | switch (notification.NotificationCode) 148 | { 149 | case wlan_notification_msm_signal_quality_change: 150 | OnSignalQualityChange(*static_cast(notification.pData)); 151 | break; 152 | default: 153 | return; 154 | } 155 | } 156 | } 157 | CATCH_LOG() 158 | 159 | void RealWlanInterface::SetNotificationHandler(INotificationHandler* handler) 160 | { 161 | { 162 | auto lock = std::scoped_lock(m_notifMutex); 163 | m_notifCallback = handler; 164 | } 165 | 166 | // Send an initial notification if this interface is connected 167 | try 168 | { 169 | const auto currentConnection = m_wlansvc->GetCurrentConnection(m_interfaceGuid); 170 | if (currentConnection && currentConnection->isState == wlan_interface_state_connected) 171 | { 172 | NotifyHostConnection( 173 | currentConnection->wlanAssociationAttributes.dot11Ssid, 174 | AdaptAuthAlgo( 175 | {currentConnection->wlanSecurityAttributes.dot11AuthAlgorithm, currentConnection->wlanSecurityAttributes.dot11CipherAlgorithm})); 176 | } 177 | } 178 | CATCH_LOG() 179 | } 180 | 181 | const GUID& RealWlanInterface::GetGuid() const noexcept 182 | { 183 | return m_interfaceGuid; 184 | } 185 | 186 | std::optional RealWlanInterface::IsConnectedTo(const Ssid& requestedSsid) noexcept 187 | { 188 | try 189 | { 190 | const auto currentConnection = m_wlansvc->GetCurrentConnection(m_interfaceGuid); 191 | // Note: This does not handle transient interface states when connection is being setup 192 | if (!currentConnection || currentConnection->isState != wlan_interface_state_connected) 193 | { 194 | return std::nullopt; 195 | } 196 | 197 | ConnectedNetwork network{ 198 | currentConnection->wlanAssociationAttributes.dot11Ssid, 199
| toBssid(currentConnection->wlanAssociationAttributes.dot11Bssid), 200 | currentConnection->wlanSecurityAttributes.dot11AuthAlgorithm 201 | }; 202 | 203 | if (requestedSsid != network.ssid) 204 | { 205 | return std::nullopt; 206 | } 207 | 208 | Log::Info( 209 | L"Host interface %ws already connected to ssid: %ws, bssid: %ws, auth: %ws", 210 | GuidToString(m_interfaceGuid).c_str(), 211 | SsidToLogString(network.ssid.value()).c_str(), 212 | BssidToString(network.bssid).c_str(), 213 | Wlansvc::AuthAlgoToString(network.auth).c_str()); 214 | return network; 215 | } 216 | CATCH_LOG() 217 | 218 | return std::nullopt; 219 | } 220 | 221 | std::future> RealWlanInterface::Connect(const Ssid& ssid, const Bssid& bssid, const WlanSecurity& securityInfo) 222 | { 223 | const auto authCipher = Wlansvc::DetermineAuthCipherPair(securityInfo); 224 | const auto connectionProfile = Wlansvc::MakeConnectionProfile(ssid, authCipher, securityInfo.key); 225 | 226 | // Parse the requested BSSID from the request 227 | const DOT11_MAC_ADDRESS& requestedBssid = *reinterpret_cast(bssid.data()); 228 | 229 | // Ask Wlansvc to connect 230 | std::scoped_lock connectLock(m_promiseMutex); 231 | Log::Trace(L"Connecting to %ws on host interface %ws", SsidToLogString(ssid.value()).c_str(), GuidToString(m_interfaceGuid).c_str()); 232 | m_wlansvc->Connect(m_interfaceGuid, connectionProfile, requestedBssid); 233 | 234 | m_connectPromise.emplace(); 235 | return m_connectPromise->get_future(); 236 | } 237 | 238 | void RealWlanInterface::OnConnectComplete(const WLAN_CONNECTION_NOTIFICATION_DATA& data) 239 | { 240 | if (data.wlanReasonCode == ERROR_SUCCESS) 241 | { 242 | const auto connInfo = m_wlansvc->GetCurrentConnection(m_interfaceGuid); 243 | if (!connInfo) 244 | { 245 | Log::Trace( 246 | L"Could not get the connection information after connecting the interface %ws", GuidToString(m_interfaceGuid).c_str()); 247 | std::scoped_lock connectLock(m_promiseMutex); 248 | if (m_connectPromise) 249 | { 250
| m_connectPromise->set_value({WlanStatus::UnspecifiedFailure, ConnectedNetwork{}}); 251 | m_connectPromise = std::nullopt; 252 | } 253 | return; 254 | } 255 | 256 | // Notify the client for the host connection outside of the lock 257 | NotifyHostConnection( 258 | data.dot11Ssid, 259 | AdaptAuthAlgo({connInfo->wlanSecurityAttributes.dot11AuthAlgorithm, connInfo->wlanSecurityAttributes.dot11CipherAlgorithm})); 260 | 261 | // If there is a promise, this is a successful guest initiated connection 262 | std::scoped_lock connectLock(m_promiseMutex); 263 | if (m_connectPromise) 264 | { 265 | auto connectedNetwork = ConnectedNetwork{ 266 | connInfo->wlanAssociationAttributes.dot11Ssid, 267 | toBssid(connInfo->wlanAssociationAttributes.dot11Bssid), 268 | connInfo->wlanSecurityAttributes.dot11AuthAlgorithm}; 269 | 270 | m_connectPromise->set_value({WlanStatus::Success, connectedNetwork}); 271 | m_connectPromise = std::nullopt; 272 | } 273 | } 274 | else 275 | { 276 | // If there is a promise, the guest initiated connection failed 277 | std::scoped_lock connectLock(m_promiseMutex); 278 | if (m_connectPromise) 279 | { 280 | Log::Trace(L"Could not connect host interface %ws", GuidToString(m_interfaceGuid).c_str()); 281 | m_connectPromise->set_value({WlanStatus::UnspecifiedFailure, ConnectedNetwork{}}); 282 | m_connectPromise = std::nullopt; 283 | } 284 | } 285 | } 286 | 287 | std::future RealWlanInterface::Disconnect() 288 | { 289 | std::unique_lock disconnectLock(m_promiseMutex); 290 | 291 | Log::Trace(L"Requesting disconnection on host interface %ws", GuidToString(m_interfaceGuid).c_str()); 292 | m_wlansvc->Disconnect(m_interfaceGuid); 293 | 294 | m_disconnectPromise.emplace(); 295 | return m_disconnectPromise->get_future(); 296 | } 297 | 298 | void RealWlanInterface::OnDisconnected(const WLAN_CONNECTION_NOTIFICATION_DATA& data) 299 | { 300 | // Let the client and guest know about the host disconnection out of the lock 301
| Log::Trace(L"Host interface %ws disconnected", GuidToString(m_interfaceGuid).c_str()); 302 | NotifyHostDisconnection(data.dot11Ssid); 303 | 304 | { 305 | // If there is a promise, this is a guest initiated disconnection 306 | std::scoped_lock disconnectLock(m_promiseMutex); 307 | if (m_disconnectPromise) 308 | { 309 | Log::Trace(L"Disconnection complete on host interface %ws", GuidToString(m_interfaceGuid).c_str()); 310 | m_disconnectPromise->set_value(); 311 | m_disconnectPromise = std::nullopt; 312 | return; 313 | } 314 | } 315 | } 316 | 317 | std::future RealWlanInterface::Scan(std::optional& ssid) 318 | { 319 | std::unique_lock scanLock(m_promiseMutex); 320 | if (m_scanRunning) 321 | { 322 | // A scan was already scheduled. Wait for its completion to provide results (this replaces any previously pending m_scanPromise) 323 | Log::Trace(L"A scan was already scheduled on interface %ws. Waiting for its completion.", GuidToString(m_interfaceGuid).c_str()); 324 | m_scanPromise.emplace(); 325 | return m_scanPromise->get_future(); 326 | } 327 | 328 | // A scan request to wlansvc always flushes the BSS cache. Cache the existing results, and use them if the scan fails 329 | // and return no results: drivers can fail a scan right after the host connects (media in use), but relevant results 330 | // have already been scanned.
331 | auto cachedScannedBss = m_wlansvc->GetScannedBssList(m_interfaceGuid); 332 | auto cachedScannedNetworks = m_wlansvc->GetScannedNetworkList(m_interfaceGuid); 333 | 334 | auto scannedBss = AdaptScanResult(cachedScannedBss, cachedScannedNetworks); 335 | 336 | try 337 | { 338 | m_scanRunning = true; 339 | 340 | if (ssid) 341 | { 342 | auto requestedSsid = static_cast(*ssid); 343 | 344 | Log::Trace( 345 | L"Requesting targeted scan on host interface %ws, Ssid: %ws", 346 | GuidToString(m_interfaceGuid).c_str(), 347 | SsidToLogString({requestedSsid.ucSSID, requestedSsid.uSSIDLength}).c_str()); 348 | m_wlansvc->Scan(m_interfaceGuid, &requestedSsid); 349 | } 350 | else 351 | { 352 | Log::Trace(L"Requesting scan on host interface %ws", GuidToString(m_interfaceGuid).c_str()); 353 | m_wlansvc->Scan(m_interfaceGuid); 354 | } 355 | } 356 | catch (...) 357 | { 358 | m_scanRunning = false; LOG_CAUGHT_EXCEPTION_MSG("Failed to request a host scan, reporting cached results only"); 359 | } 360 | 361 | // Always mark the scan as completed: if the cached results are not good enough, the next scan request 362 | // will wait for the real scan completion, or the new results will be sent through a notification 363 | Log::Trace(L"Reporting cached scan results on interface %ws", GuidToString(m_interfaceGuid).c_str()); 364 | 365 | std::promise promise; 366 | promise.set_value({std::move(scannedBss), ScanStatus::Completed}); 367 | return promise.get_future(); 368 | } 369 | 370 | void RealWlanInterface::OnScanComplete() 371 | { 372 | auto scanResults = m_wlansvc->GetScannedBssList(m_interfaceGuid); 373 | auto availableNetworks = m_wlansvc->GetScannedNetworkList(m_interfaceGuid); 374 | 375 | auto results = AdaptScanResult(scanResults, availableNetworks); 376 | 377 | { 378 | std::unique_lock scanLock(m_promiseMutex); 379 | 380 | // The scan is not running anymore 381 | m_scanRunning = false; 382 | 383 | // If a scan request is waiting, complete the promise 384 | if (m_scanPromise) 385 | { 386 | m_scanPromise->set_value({std::move(results), ScanStatus::Completed}); 387 |
m_scanPromise = std::nullopt; 388 | return; 389 | } 390 | } 391 | 392 | // Nobody is waiting, simply notify the guest of the new results 393 | NotifyScanResults(std::move(results), ScanStatus::Completed); 394 | } 395 | 396 | void RealWlanInterface::OnSignalQualityChange(unsigned long signal) const 397 | { 398 | NotifySignalQualityChange(signal); 399 | } 400 | 401 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/RealWlanInterface.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "Networks.hpp" 14 | #include "Iee80211Utils.hpp" 15 | #include "WlanInterface.hpp" 16 | #include "WlanSvcWrapper.hpp" 17 | 18 | namespace ProxyWifi { 19 | 20 | /// @brief This class represent an actual wlan station interface managed by wlansvc 21 | class RealWlanInterface: public IWlanInterface 22 | { 23 | public: 24 | // Add notification handler parameter 25 | RealWlanInterface(std::shared_ptr wlansvc, const GUID& interfaceGuid); 26 | ~RealWlanInterface() override; 27 | 28 | RealWlanInterface(const RealWlanInterface&) = delete; 29 | RealWlanInterface(RealWlanInterface&&) = delete; 30 | RealWlanInterface& operator=(const RealWlanInterface&) = delete; 31 | RealWlanInterface& operator=(RealWlanInterface&&) = delete; 32 | 33 | void SetNotificationHandler(INotificationHandler* handler) override; 34 | 35 | const GUID& GetGuid() const noexcept override; 36 | std::optional IsConnectedTo(const Ssid& requestedSsid) noexcept override; 37 | 38 | std::future> Connect(const Ssid& requestedSsid, const Bssid& bssid, const WlanSecurity& securityInfo) override; 39 | std::future Disconnect() override; 40 | std::future Scan(std::optional& ssid) override; 41 | 42 | private: 43 | void WlanNotificationHandler(const 
WLAN_NOTIFICATION_DATA& notification) noexcept; 44 | void OnConnectComplete(const WLAN_CONNECTION_NOTIFICATION_DATA& data); 45 | void OnDisconnected(const WLAN_CONNECTION_NOTIFICATION_DATA& data); 46 | void OnScanComplete(); 47 | void OnSignalQualityChange(unsigned long signal) const; 48 | 49 | const std::shared_ptr m_wlansvc; 50 | const GUID m_interfaceGuid; 51 | 52 | /// @brief Mutex to protect access to the following promises + m_scanRunning 53 | mutable std::mutex m_promiseMutex; 54 | std::optional>> m_connectPromise; 55 | std::optional> m_disconnectPromise; 56 | std::optional> m_scanPromise; 57 | /// @brief Indicate a scan was requested to wlansvc and no completion notif was received yet 58 | bool m_scanRunning = false; 59 | 60 | mutable std::mutex m_notifMutex; 61 | INotificationHandler* m_notifCallback{}; 62 | 63 | inline void NotifyHostConnection(const Ssid& ssid, DOT11_AUTH_ALGORITHM authAlgo) const 64 | { 65 | auto lock = std::scoped_lock(m_notifMutex); 66 | if (m_notifCallback) 67 | { 68 | m_notifCallback->OnHostConnection(m_interfaceGuid, ssid, authAlgo); 69 | } 70 | } 71 | 72 | inline void NotifyHostDisconnection(const Ssid& ssid) const 73 | { 74 | auto lock = std::scoped_lock(m_notifMutex); 75 | if (m_notifCallback) 76 | { 77 | m_notifCallback->OnHostDisconnection(m_interfaceGuid, ssid); 78 | } 79 | } 80 | 81 | inline void NotifySignalQualityChange(unsigned long signal) const 82 | { 83 | auto lock = std::scoped_lock(m_notifMutex); 84 | if (m_notifCallback) 85 | { 86 | m_notifCallback->OnHostSignalQualityChange(m_interfaceGuid, signal); 87 | } 88 | } 89 | 90 | inline void NotifyScanResults(std::vector scannedBss, ScanStatus status) const 91 | { 92 | auto lock = std::scoped_lock(m_notifMutex); 93 | if (m_notifCallback) 94 | { 95 | m_notifCallback->OnHostScanResults(m_interfaceGuid, scannedBss, status); 96 | } 97 | } 98 | }; 99 | 100 | } // namespace ProxyWifi -------------------------------------------------------------------------------- 
/lib/SocketHelpers.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "SocketHelpers.hpp" 5 | 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include "ProxyWifi/Logs.hpp" 12 | 13 | namespace ProxyWifi { 14 | 15 | std::pair CreateHyperVSocket(const GUID& vmId, unsigned int port) 16 | { 17 | wil::unique_socket sock{WSASocket(AF_HYPERV, SOCK_STREAM, HV_PROTOCOL_RAW, nullptr, 0, WSA_FLAG_OVERLAPPED)}; 18 | if (!sock.is_valid()) 19 | { 20 | THROW_WIN32_MSG(WSAGetLastError(), "Failed to create an hv socket"); 21 | } 22 | 23 | SOCKADDR_HV sockAddrHv{}; 24 | memset(&sockAddrHv, 0, sizeof sockAddrHv); 25 | sockAddrHv.Family = AF_HYPERV; 26 | sockAddrHv.VmId = vmId; 27 | sockAddrHv.ServiceId = HV_GUID_VSOCK_TEMPLATE; 28 | sockAddrHv.ServiceId.Data1 = port; 29 | 30 | // Ensure the socket stays connected when the VM is suspended. 31 | ULONG enable = 1; 32 | THROW_LAST_ERROR_IF( 33 | setsockopt(sock.get(), HV_PROTOCOL_RAW, HVSOCKET_CONNECTED_SUSPEND, reinterpret_cast(&enable), sizeof enable) == SOCKET_ERROR); 34 | 35 | return {std::move(sock), sockAddrHv}; 36 | } 37 | 38 | std::pair CreateTcpSocket(const IN_ADDR& listenIpAddr, unsigned short port) 39 | { 40 | wil::unique_socket sock{WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED)}; 41 | if (!sock.is_valid()) 42 | { 43 | THROW_WIN32_MSG(WSAGetLastError(), "Failed to create a tcp socket"); 44 | } 45 | 46 | // Bind address to server socket. 
47 | sockaddr_in sockaddrIn{}; 48 | sockaddrIn.sin_family = AF_INET; 49 | sockaddrIn.sin_port = htons(port); 50 | sockaddrIn.sin_addr = listenIpAddr; 51 | 52 | return {std::move(sock), sockaddrIn}; 53 | } 54 | 55 | AcceptAsyncContext::AcceptAsyncContext( 56 | wil::unique_socket acceptSocket, wil::unique_event onAcceptEvent, std::unique_ptr buffer, std::unique_ptr overlapped) noexcept 57 | : m_acceptSocket{std::move(acceptSocket)}, m_onAcceptEvent{std::move(onAcceptEvent)}, m_buffer{std::move(buffer)}, m_overlapped{std::move(overlapped)} 58 | { 59 | } 60 | 61 | AcceptAsyncContext::~AcceptAsyncContext() 62 | { 63 | // Make sure the pending IO is completed before destroying the context 64 | if (m_onAcceptEvent.is_valid()) 65 | { 66 | m_onAcceptEvent.wait(); 67 | } 68 | } 69 | 70 | AcceptAsyncContext AcceptAsyncContext::Accept(const wil::unique_socket& listenSocket, const std::function()>& createSocket) 71 | { 72 | auto [acceptSocket, addrSize] = createSocket(); 73 | if (!acceptSocket.is_valid()) 74 | { 75 | THROW_WIN32_MSG(WSAGetLastError(), "Failed to create an hv socket"); 76 | } 77 | 78 | // `AcceptEx` requires an extra 16 bytes for every addresses in the buffer for an internal representation 79 | constexpr auto extraAddressSize = 16; 80 | 81 | auto buffer = std::make_unique(2 * (addrSize + extraAddressSize)); 82 | auto acceptEvent = wil::unique_event(wil::EventOptions::ManualReset); 83 | auto overlapped = std::make_unique(); 84 | overlapped->hEvent = acceptEvent.get(); 85 | 86 | DWORD bytes{}; 87 | const auto addrSizeWithExtra = wil::safe_cast(addrSize + extraAddressSize); 88 | if (!AcceptEx( 89 | listenSocket.get(), acceptSocket.get(), buffer.get(), 0, addrSizeWithExtra, addrSizeWithExtra, &bytes, overlapped.get())) 90 | { 91 | if (WSAGetLastError() != ERROR_IO_PENDING) 92 | { 93 | THROW_WIN32_MSG(WSAGetLastError(), "Failed to accept the connection"); 94 | } 95 | } 96 | 97 | // Create the context only after the call to `AcceptEx` succeeded, to ensure the event 
will be signaled 98 | return AcceptAsyncContext(std::move(acceptSocket), std::move(acceptEvent), std::move(buffer), std::move(overlapped)); 99 | } 100 | 101 | static bool ReceiveBytes(const wil::unique_socket& socket, std::span buffer) 102 | { 103 | while (buffer.size_bytes() > 0) 104 | { 105 | // Wait until there is something to read (or fail after a timeout) 106 | constexpr timeval timeout{.tv_sec{0}, .tv_usec{500000}}; // 0.5 sec timeout 107 | fd_set read_set{}; 108 | FD_SET(socket.get(), &read_set); 109 | const auto socket_ready = select(0, &read_set, nullptr, nullptr, &timeout); 110 | if (socket_ready == SOCKET_ERROR) 111 | { 112 | THROW_LAST_ERROR_MSG("Error while waiting for a message"); 113 | } 114 | else if (socket_ready == 0) 115 | { 116 | THROW_WIN32_MSG(ERROR_TIMEOUT, "Timeout while waiting for a message"); 117 | } 118 | 119 | // Read as many bytes as possible up to the buffer size 120 | const auto transfer_size = recv(socket.get(), reinterpret_cast(buffer.data()), static_cast(buffer.size_bytes()), 0); 121 | if (transfer_size < 0) 122 | { 123 | const auto err = WSAGetLastError(); 124 | THROW_WIN32_MSG(err, "Received invalid message (transfer_size=%d)", transfer_size); 125 | } 126 | if (transfer_size == 0) 127 | { 128 | return false; 129 | } 130 | 131 | buffer = buffer.subspan(transfer_size); 132 | } 133 | 134 | return true; 135 | } 136 | 137 | std::optional ReceiveProxyWifiMessage(const wil::unique_socket& socket) 138 | { 139 | Message message; 140 | if (!ReceiveBytes(socket, {reinterpret_cast(&message.hdr), sizeof(message.hdr)})) 141 | { 142 | // The guest closed the connection 143 | return std::nullopt; 144 | } 145 | 146 | if (message.hdr.size == 0) 147 | { 148 | return message; 149 | } 150 | 151 | // Allocate body 152 | message.body.resize(message.hdr.size); 153 | 154 | if (!ReceiveBytes(socket, {message.body.data(), message.hdr.size})) 155 | { 156 | LOG_HR_MSG(E_UNEXPECTED, "Connection closed when expecting a message body"); 157 | return 
std::nullopt; 158 | } 159 | 160 | return message; 161 | } 162 | 163 | static void SendBytes(wil::unique_socket& socket, std::span dataToSend) 164 | { 165 | while (dataToSend.size_bytes() > 0) 166 | { 167 | // Wait until the destination is ready to receive 168 | constexpr timeval timeout{.tv_sec{0}, .tv_usec{500000}}; // 0.5 sec timeout 169 | fd_set write_set{}; 170 | FD_SET(socket.get(), &write_set); 171 | const auto socket_ready = select(0, nullptr, &write_set, nullptr, &timeout); 172 | if (socket_ready == SOCKET_ERROR) 173 | { 174 | THROW_LAST_ERROR_MSG("Error while waiting to send a message"); 175 | } 176 | else if (socket_ready == 0) 177 | { 178 | THROW_WIN32_MSG(ERROR_TIMEOUT, "Timeout while waiting to send a message"); 179 | } 180 | 181 | // Send the message 182 | const auto transfer_size = 183 | send(socket.get(), reinterpret_cast(dataToSend.data()), wil::safe_cast(dataToSend.size_bytes()), 0); 184 | if (transfer_size <= 0) 185 | { 186 | const auto err = WSAGetLastError(); 187 | THROW_WIN32_MSG(err, "Send failed (transfer_size=%d)", transfer_size); 188 | } 189 | 190 | dataToSend = dataToSend.subspan(transfer_size); 191 | } 192 | } 193 | 194 | void SendProxyWifiMessage(wil::unique_socket& socket, const Message& message) 195 | { 196 | SendBytes(socket, {reinterpret_cast(&message.hdr), sizeof(message.hdr)}); 197 | if (!message.body.empty()) 198 | { 199 | SendBytes(socket, {message.body.data(), message.hdr.size}); 200 | } 201 | } 202 | 203 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/SocketHelpers.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include "Messages.hpp" 15 | 16 | namespace ProxyWifi { 17 | 18 | enum class SocketUse 19 | { 20 | Bound, ///< The socket should be bound, as by calling bind(). 21 | Connected, ///< The socket should be connected, as by calling connect(). 22 | }; 23 | 24 | /// @brief Create a socket of a specific type. 25 | /// 26 | /// This will create a socket and then perform some action on the socket 27 | /// following its creation, preparing it for use. 28 | /// 29 | /// @tparam SockAddrType The type of the generic transport descriptor. 30 | /// @param type The type of socket to create. This dictates the action taken on the socket after it is created. 31 | /// @param createSocket A function used to create the socket and its transport descriptor. 32 | template 33 | wil::unique_socket CreateSocketWithUse(SocketUse type, std::function()> createSocket) 34 | { 35 | auto [socket, sockAddr] = createSocket(); 36 | 37 | switch (type) 38 | { 39 | case SocketUse::Bound: 40 | if (bind(socket.get(), reinterpret_cast(&sockAddr), sizeof sockAddr) != 0) 41 | { 42 | THROW_WIN32_MSG(WSAGetLastError(), "Socket bind failed"); 43 | } 44 | break; 45 | case SocketUse::Connected: 46 | if (connect(socket.get(), reinterpret_cast(&sockAddr), sizeof sockAddr) != ERROR_SUCCESS) 47 | { 48 | THROW_WIN32_MSG(WSAGetLastError(), "Socket connect failed"); 49 | } 50 | break; 51 | default: 52 | throw std::invalid_argument("Unsupported socket type"); 53 | } 54 | 55 | return std::move(socket); 56 | } 57 | 58 | /// @brief Helper function to create a HyperV socket. 59 | /// @param vmId The vm if for which the socket should be restricted to. 60 | /// @param port The port of the socket to bind. 61 | std::pair CreateHyperVSocket(const GUID& vmId, unsigned int port); 62 | 63 | /// @brief Helper function to create a TCP/IP socket. 
64 | /// @param listenIpAddr The TCP/IP address the socket should be bound to. 65 | /// @param port The TCP/IP port number the socket should be bound to. 66 | std::pair CreateTcpSocket(const IN_ADDR& listenIpAddr, unsigned short port); 67 | 68 | class AcceptAsyncContext 69 | { 70 | public: 71 | /// @brief Asynchronously accept a connection 72 | /// @param listenSocket The socket connection are listened on 73 | /// @return The socket the connection will be accepted on 74 | static AcceptAsyncContext Accept(const wil::unique_socket& listenSocket, const std::function()>& createSocket); 75 | 76 | ~AcceptAsyncContext(); 77 | 78 | const wil::unique_event& getOnAcceptEvent() const noexcept 79 | { 80 | return m_onAcceptEvent; 81 | } 82 | 83 | const wil::unique_socket& getSocket() noexcept 84 | { 85 | return m_acceptSocket; 86 | } 87 | 88 | wil::unique_socket releaseSocket() noexcept 89 | { 90 | return std::move(m_acceptSocket); 91 | } 92 | 93 | private: 94 | explicit AcceptAsyncContext( 95 | wil::unique_socket acceptSocket, wil::unique_event onAcceptEvent, std::unique_ptr buffer, std::unique_ptr overlapped) noexcept; 96 | 97 | wil::unique_socket m_acceptSocket; 98 | wil::unique_event m_onAcceptEvent{wil::EventOptions::ManualReset}; 99 | std::unique_ptr m_buffer; 100 | std::unique_ptr m_overlapped; 101 | }; 102 | 103 | /// @brief Helper function which receives a single protocol message on a generic socket. 104 | std::optional ReceiveProxyWifiMessage(const wil::unique_socket& socket); 105 | 106 | /// @brief Helper function which sends a single protocol message on a generic socket. 107 | void SendProxyWifiMessage(wil::unique_socket& socket, const Message& message); 108 | 109 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/TestWlanInterface.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #include "TestWlanInterface.hpp" 5 | 6 | #include "LogsHelpers.hpp" 7 | #include "StringUtils.hpp" 8 | #include "ProxyWifi/Logs.hpp" 9 | 10 | #include 11 | 12 | namespace ProxyWifi { 13 | 14 | namespace { 15 | 16 | DOT11_AUTH_ALGORITHM GetAuthAlgo(const FakeBss& bss) 17 | { 18 | // The test interface support only open and wpa2psk networks 19 | return bss.akmSuites.empty() ? DOT11_AUTH_ALGO_80211_OPEN : DOT11_AUTH_ALGO_RSNA_PSK; 20 | } 21 | 22 | } // namespace 23 | 24 | /* static */ std::vector TestWlanInterface::BuildFakeNetworkList() 25 | { 26 | return { 27 | { 28 | BssCapability::Ess | BssCapability::Privacy, // capabilities 29 | -50, // rssi 30 | 2432000, // channelCenterFreq 31 | 0, // beaconInterval 32 | {0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, // bssid 33 | Ssid{"FakeWpa2Psk"}, // ssid 34 | {AkmSuite::Psk}, // akmSuites 35 | {CipherSuite::Ccmp}, // cipherSuites 36 | CipherSuite::Ccmp, // groupCipher 37 | {0x45, 0xf6, 0x30, 0x20, 0x80, 0xc4, 0x77, 0x93, 0x58, 0x28, 0x11, 0x59, 0xfa, 0x68, 0xbf, 0x4b, 0xf7, 38 | 0x35, 0xd1, 0x01, 0xde, 0x08, 0x85, 0x4e, 0x88, 0x58, 0xaa, 0xb3, 0xeb, 0x03, 0x6a, 0xad} // key: "secretsecret" 39 | }, 40 | { 41 | BssCapability::Ess, // capabilities 42 | -50, // rssi 43 | 5240000, // channelCenterFreq 44 | 0, // beaconInterval 45 | {0x0, 0x0, 0x0, 0x0, 0x0, 0x2}, // bssid 46 | Ssid{"FakeOpen"}, // ssid 47 | {}, // akmSuites 48 | {}, // cipherSuites 49 | std::nullopt, // groupCipher 50 | {}, // key 51 | }, 52 | { 53 | BssCapability::Ess, // capabilities 54 | -50, // rssi 55 | 6115000, // channelCenterFreq 56 | 0, // beaconInterval 57 | {0x0, 0x0, 0x0, 0x0, 0x0, 0x3}, // bssid 58 | Ssid{"Fake6GHz"}, // ssid 59 | {}, // akmSuites 60 | {}, // cipherSuites 61 | std::nullopt, // groupCipher 62 | {}, // key 63 | }}; 64 | } 65 | 66 | TestWlanInterface::TestWlanInterface(const GUID& interfaceGuid) 67 | : m_interfaceGuid{interfaceGuid} 68 | { 69 | // Start accepting user triggered notification 70 | auto notificationThread = 
std::thread([this]() { 71 | const auto logger = SetThreadWilFailureLogger(); 72 | NotificationSender(); 73 | }); 74 | notificationThread.detach(); 75 | } 76 | 77 | void TestWlanInterface::SetNotificationHandler(INotificationHandler* handler) 78 | { 79 | m_notifCallback = handler; 80 | } 81 | 82 | const GUID& TestWlanInterface::GetGuid() const noexcept 83 | { 84 | return m_interfaceGuid; 85 | } 86 | 87 | std::optional TestWlanInterface::IsConnectedTo(const Ssid& requestedSsid) noexcept 88 | { 89 | auto lock = std::scoped_lock{m_connectedNetworkMutex}; 90 | if (!m_connectedNetwork || requestedSsid != m_networks[*m_connectedNetwork].ssid) 91 | { 92 | return std::nullopt; 93 | } 94 | 95 | Log::Info( 96 | L"Test interface %ws already connected to ssid: %ws", 97 | GuidToString(m_interfaceGuid).c_str(), 98 | SsidToLogString(requestedSsid.value()).c_str()); 99 | return ConnectedNetwork{requestedSsid, m_networks[*m_connectedNetwork].bssid, GetAuthAlgo(m_networks[*m_connectedNetwork])}; 100 | } 101 | 102 | std::future> TestWlanInterface::Connect(const Ssid& requestedSsid, const Bssid&, const WlanSecurity&) 103 | { 104 | const auto network = std::ranges::find_if(m_networks, [&](const auto& n) { return n.ssid == requestedSsid; }); 105 | 106 | std::promise> promise; 107 | if (network == m_networks.cend()) 108 | { 109 | Log::Trace( 110 | L"Could not connect test interface %ws to ssid: %ws", 111 | GuidToString(m_interfaceGuid).c_str(), 112 | SsidToLogString(requestedSsid.value()).c_str()); 113 | promise.set_value({WlanStatus::UnspecifiedFailure, {}}); 114 | } 115 | else 116 | { 117 | Log::Trace( 118 | L"Connected test interface %ws to to ssid: %ws", 119 | GuidToString(m_interfaceGuid).c_str(), 120 | SsidToLogString(requestedSsid.value()).c_str()); 121 | auto lock = std::scoped_lock{m_connectedNetworkMutex}; 122 | m_connectedNetwork = std::distance(m_networks.begin(), network); 123 | promise.set_value({WlanStatus::Success, {requestedSsid, network->bssid, GetAuthAlgo(*network)}}); 
124 | } 125 | return promise.get_future(); 126 | } 127 | 128 | std::future TestWlanInterface::Disconnect() 129 | { 130 | auto lock = std::scoped_lock{m_connectedNetworkMutex}; 131 | m_connectedNetwork.reset(); 132 | 133 | std::promise promise; 134 | promise.set_value(); 135 | return promise.get_future(); 136 | } 137 | 138 | std::future TestWlanInterface::Scan(std::optional&) 139 | { 140 | std::vector result; 141 | for (const auto& fakeBss : m_networks) 142 | { 143 | Log::Debug( 144 | L"Reporting fake BSS, Bssid: %ws, Ssid: %ws, AkmSuites: {%ws}, CipherSuites: {%ws}, GroupCipher: %.8x, " 145 | L"ChannelCenterFreq: %d", 146 | BssidToString(fakeBss.bssid).c_str(), 147 | SsidToLogString(fakeBss.ssid.value()).c_str(), 148 | ListEnumToHexString(std::span{fakeBss.akmSuites}).c_str(), 149 | ListEnumToHexString(std::span{fakeBss.cipherSuites}).c_str(), 150 | fakeBss.groupCipher ? WI_EnumValue(*fakeBss.groupCipher) : 0, 151 | fakeBss.channelCenterFreq); 152 | 153 | result.emplace_back(fakeBss); 154 | 155 | if (m_scanBehavior == ScanBehavior::Async) 156 | { 157 | // Only report the first network in the immediate answer of an async scan 158 | break; 159 | } 160 | } 161 | 162 | Log::Debug(L"%d BSS entries reported on test interface %ws", result.size(), GuidToString(m_interfaceGuid).c_str()); 163 | std::promise promise; 164 | promise.set_value({std::move(result), m_scanBehavior == ScanBehavior::Async ? 
ScanStatus::Running : ScanStatus::Completed}); 165 | return promise.get_future(); 166 | } 167 | 168 | void TestWlanInterface::NotificationSender() 169 | { 170 | enum class Notification 171 | { 172 | ConnectedOpen, 173 | ConnectedPsk, 174 | Disconnected, 175 | SignalQuality, 176 | ScanResults, 177 | ScanSync, 178 | ScanAsync 179 | }; 180 | 181 | static const std::array, 7> notifications{ 182 | {{Notification::Disconnected, "Host Disconnected"}, 183 | {Notification::ConnectedOpen, "Host Connected Open"}, 184 | {Notification::ConnectedPsk, "Host Connected Psk"}, 185 | {Notification::SignalQuality, "Signal quality"}, 186 | {Notification::ScanResults, "Notify Scan results"}, 187 | {Notification::ScanSync, "Scan Mode: Sync"}, 188 | {Notification::ScanAsync, "Scan Mode: Async"}}}; 189 | 190 | for (;;) 191 | { 192 | std::cout << ">>> Choose what to do: "; 193 | for (auto i = 0u; i < notifications.size(); ++i) 194 | { 195 | std::cout << "<" << i << " -> " << notifications[i].second << "> "; 196 | } 197 | std::cout << std::endl; 198 | 199 | unsigned int userInput = std::numeric_limits::max(); 200 | std::cin >> userInput; 201 | if (userInput >= notifications.size()) 202 | { 203 | std::cout << "Invalid operation code: " << userInput << std::endl; 204 | continue; 205 | } 206 | 207 | std::cout << "Executing: " << notifications[userInput].second << std::endl; 208 | 209 | switch (notifications[userInput].first) 210 | { 211 | case Notification::Disconnected: 212 | { 213 | DOT11_SSID ssid{}; 214 | { 215 | const auto lock = std::scoped_lock{m_connectedNetworkMutex}; 216 | if (m_connectedNetwork) 217 | { 218 | ssid = m_networks[*m_connectedNetwork].ssid; 219 | m_connectedNetwork.reset(); 220 | } 221 | } 222 | 223 | NotifyDisconnection(ssid); 224 | break; 225 | } 226 | case Notification::ConnectedOpen: 227 | { 228 | NotifyConnection(m_networks.front().ssid, DOT11_AUTH_ALGO_80211_OPEN); 229 | break; 230 | } 231 | case Notification::ConnectedPsk: 232 | { 233 | 
NotifyConnection(m_networks.front().ssid, DOT11_AUTH_ALGO_RSNA_PSK); 234 | break; 235 | } 236 | case Notification::SignalQuality: 237 | { 238 | NotifySignalQualityChange(60); 239 | break; 240 | } 241 | case Notification::ScanResults: 242 | { 243 | std::vector result; 244 | for (const auto& fakeBss : m_networks) 245 | { 246 | Log::Debug( 247 | L"Reporting fake BSS, Bssid: %ws, Ssid: %ws, AkmSuites: {%ws}, CipherSuites: {%ws}, GroupCipher: %.8x, " 248 | L"ChannelCenterFreq: %d", 249 | BssidToString(fakeBss.bssid).c_str(), 250 | SsidToLogString(fakeBss.ssid.value()).c_str(), 251 | ListEnumToHexString(std::span{fakeBss.akmSuites}).c_str(), 252 | ListEnumToHexString(std::span{fakeBss.cipherSuites}).c_str(), 253 | fakeBss.groupCipher ? WI_EnumValue(*fakeBss.groupCipher) : 0, 254 | fakeBss.channelCenterFreq); 255 | 256 | result.emplace_back(fakeBss); 257 | } 258 | NotifyScanResults(result, ScanStatus::Completed); 259 | break; 260 | } 261 | case Notification::ScanSync: 262 | m_scanBehavior = ScanBehavior::Sync; 263 | break; 264 | case Notification::ScanAsync: 265 | m_scanBehavior = ScanBehavior::Async; 266 | break; 267 | default: 268 | throw std::runtime_error("Unsupported notification"); 269 | } 270 | } 271 | } 272 | 273 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/TestWlanInterface.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | 4 | #pragma once 5 | 6 | #include "WlanInterface.hpp" 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | #include "Networks.hpp" 15 | #include "Iee80211Utils.hpp" 16 | 17 | namespace ProxyWifi { 18 | 19 | /// @brief This class represents a fake wlan interface simulating networks provided by the lib client 20 | class TestWlanInterface : public IWlanInterface 21 | { 22 | public: 23 | TestWlanInterface(const GUID& interfaceGuid); 24 | 25 | void SetNotificationHandler(INotificationHandler* handler) override; 26 | const GUID& GetGuid() const noexcept override; 27 | std::optional IsConnectedTo(const Ssid& requestedSsid) noexcept override; 28 | std::future> Connect(const Ssid& requestedSsid, const Bssid& bssid, const WlanSecurity& securityInfo) override; 29 | std::future Disconnect() override; 30 | std::future Scan(std::optional& ssid) override; 31 | 32 | private: 33 | static std::vector BuildFakeNetworkList(); 34 | void NotificationSender(); 35 | 36 | const GUID m_interfaceGuid{}; 37 | const std::vector m_networks{BuildFakeNetworkList()}; 38 | std::mutex m_connectedNetworkMutex; 39 | std::optional m_connectedNetwork; 40 | // Notification handler, nullptr by default (presumably set by SetNotificationHandler); the Notify* helpers skip the call while it is null 41 | INotificationHandler* m_notifCallback{}; 42 | enum class ScanBehavior 43 | { 44 | Sync, 45 | Async 46 | } m_scanBehavior = ScanBehavior::Sync; 47 | /// @brief Forward a host connection event to the notification handler, if one is set 48 | inline void NotifyConnection(const Ssid& ssid, DOT11_AUTH_ALGORITHM authAlgo) const 49 | { 50 | if (m_notifCallback) 51 | { 52 | m_notifCallback->OnHostConnection(m_interfaceGuid, ssid, authAlgo); 53 | } 54 | } 55 | /// @brief Forward a host disconnection event to the notification handler, if one is set 56 | inline void NotifyDisconnection(const Ssid& ssid) const 57 | { 58 | if (m_notifCallback) 59 | { 60 | m_notifCallback->OnHostDisconnection(m_interfaceGuid, ssid); 61 | } 62 | } 63 | /// @brief Forward a signal quality change to the notification handler, if one is set 64 | inline void NotifySignalQualityChange(unsigned long signal) const 65 | { 66 | if (m_notifCallback) 67 | { 68 | m_notifCallback->OnHostSignalQualityChange(m_interfaceGuid, signal); 69 | } 70 | } 71 | /// @brief Forward scan results and the scan status to the notification handler, if one is set 72 | inline void NotifyScanResults(std::vector result, ScanStatus status)
const 73 | { 74 | if (m_notifCallback) 75 | { 76 | m_notifCallback->OnHostScanResults(m_interfaceGuid, result, status); 77 | } 78 | } 79 | }; 80 | 81 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/Tracelog.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | #pragma once 4 | 5 | #include 6 | 7 | namespace ProxyWifi::Log { 8 | // Select the TraceLogging class implementation macro depending on MICROSOFT_TELEMETRY 9 | #if MICROSOFT_TELEMETRY 10 | #define IMPLEMENT_TRACELOGGING_CLASS_WRAP IMPLEMENT_TRACELOGGING_CLASS_WITH_MICROSOFT_TELEMETRY 11 | #else 12 | #define IMPLEMENT_TRACELOGGING_CLASS_WRAP IMPLEMENT_TRACELOGGING_CLASS_WITHOUT_TELEMETRY 13 | #endif 14 | 15 | /// @brief Tracelogging provider 16 | /// 17 | /// It should be used directly only for structured logs (telemetry, performance...) 18 | /// For logs that will be read, use the helper functions and WIL error helpers 19 | class TraceProvider : public wil::TraceLoggingProvider 20 | { 21 | IMPLEMENT_TRACELOGGING_CLASS_WRAP( 22 | TraceProvider, 23 | "Microsoft.WslCore.ProxyWifi", 24 | // 872a70db-e765-45e5-9141-4b35732837b6 25 | (0x872a70db, 0xe765, 0x45e5, 0x91, 0x41, 0x4b, 0x35, 0x73, 0x28, 0x37, 0xb6)) 26 | // Debug and Trace events are both emitted at WINEVENT_LEVEL_VERBOSE 27 | DEFINE_TRACELOGGING_EVENT_STRING(Debug, Log, TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE)) 28 | DEFINE_TRACELOGGING_EVENT_STRING(Trace, Log, TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE)) 29 | DEFINE_TRACELOGGING_EVENT_STRING(Info, Log, TraceLoggingLevel(WINEVENT_LEVEL_INFO)) 30 | DEFINE_TRACELOGGING_EVENT_STRING(Error, Log, TraceLoggingLevel(WINEVENT_LEVEL_ERROR)) 31 | }; 32 | 33 | } // namespace ProxyWifi::Log -------------------------------------------------------------------------------- /lib/Transport.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license.
3 | 4 | #include "Transport.hpp" 5 | 6 | #include 7 | 8 | #include "LogsHelpers.hpp" 9 | #include "SocketHelpers.hpp" 10 | #include "ProxyWifi/Logs.hpp" 11 | 12 | namespace ProxyWifi { 13 | 14 | Transport::Transport(std::shared_ptr& operationHandler, unsigned short requestResponsePort, unsigned short notificationPort) 15 | : m_operationHandler{operationHandler}, m_requestResponsePort(requestResponsePort), m_notificationPort(notificationPort) 16 | { 17 | m_operationHandler->RegisterGuestNotificationCallback([this](auto notif) noexcept { 18 | try 19 | { 20 | std::visit( 21 | [this](auto&& n) { 22 | using T = std::decay_t; 23 | Log::Info(L"Adding notification to queue: %ws", n.Describe().c_str()); 24 | QueueNotification(T::ToMessage(std::forward(n))); 25 | }, 26 | notif); 27 | } 28 | CATCH_LOG() 29 | }); 30 | } 31 | // Derived classes also call Shutdown() from their own destructor, so the worker threads stop while the overrides they use still exist 32 | Transport::~Transport() 33 | { 34 | Shutdown(); 35 | } 36 | 37 | void Transport::AcceptConnections() 38 | { 39 | auto listenSocket = CreateListenSocket(); 40 | 41 | // Start listening for connections. 42 | if (listen(listenSocket.get(), 1) != 0) 43 | { 44 | THROW_LAST_ERROR_MSG("Failed to listen on socket"); 45 | } 46 | 47 | while (true) 48 | { 49 | auto acceptContext = AcceptAsyncContext::Accept(listenSocket, [this] { return CreateAcceptSocket(); }); 50 | 51 | // Accept connection.
52 | Log::Debug(L"Waiting for connection..."); 53 | const std::array events = {m_shutdownEvent.get(), acceptContext.getOnAcceptEvent().get()}; 54 | const auto waitResults = WSAWaitForMultipleEvents(static_cast(events.size()), events.data(), false, WSA_INFINITE, false); 55 | if (waitResults == WSA_WAIT_EVENT_0) 56 | { 57 | // This is the shutdown event 58 | Log::Debug(L"Completing the transport runner thread"); 59 | // Close the listen socket first, the accept context will then wait for the pending IO before destroying itself 60 | listenSocket.reset(); 61 | return; 62 | } 63 | else if (waitResults == WSA_WAIT_EVENT_0 + 1) 64 | { 65 | Log::Debug(L"Got connection"); 66 | 67 | // The guest is now ready, this allows host notifications to be sent 68 | m_guestWasPresent = true; 69 | 70 | // Handle the connection synchronously: the guest requests are serialized 71 | auto connection = ConnectionSocket{acceptContext.releaseSocket(), m_operationHandler}; 72 | connection.Run(); 73 | } 74 | else if (waitResults == WSA_WAIT_FAILED) 75 | { 76 | LOG_WIN32_MSG(WSAGetLastError(), "Failed to wait on for an incomming connection"); 77 | 78 | // Try to recover by resetting the listen socket 79 | listenSocket.reset(); 80 | const auto waitRecover = WSAWaitForMultipleEvents(static_cast(events.size()), events.data(), false, WSA_INFINITE, false); 81 | if (waitRecover == WSA_WAIT_EVENT_0) 82 | { 83 | // This is the shutdown event, end the thread 84 | Log::Debug(L"Completing the transport runner thread while trying to recover from bad wait"); 85 | return; 86 | } 87 | else if (waitRecover != WSA_WAIT_EVENT_0 + 1) 88 | { 89 | // Still failing, abort 90 | FAIL_FAST(); 91 | } 92 | 93 | // Start listening for connections again 94 | listenSocket = CreateListenSocket(); 95 | if (listen(listenSocket.get(), 1) != 0) 96 | { 97 | THROW_LAST_ERROR_MSG("Failed to listen on socket"); 98 | } 99 | continue; 100 | } 101 | else 102 | { 103 | FAIL_FAST_IF_WIN32_ERROR_MSG(ERROR_INVALID_STATE, "Received unexpected
value from WSAWaitForMultipleEvents: %d", waitResults); 104 | } 105 | } 106 | } 107 | 108 | void Transport::QueueNotification(Message&& msg) 109 | { 110 | if (!m_guestWasPresent) 111 | { 112 | Log::Trace(L"Dropping a notification: no guest request have been received, it might not be ready yet"); 113 | return; 114 | } 115 | 116 | m_notifQueue.Submit([this, n = std::move(msg)]() noexcept { 117 | try 118 | { 119 | SendNotification(n); 120 | } 121 | CATCH_LOG_MSG("Failed to send a notification") 122 | }); 123 | } 124 | 125 | void Transport::SendNotification(const Message& message) 126 | { 127 | Log::Trace(L"Sending notification <%ws> (%d bytes)", GetProtocolMessageTypeName(message.hdr.operation), message.hdr.size); 128 | auto socket = CreateNotificationSocket(); 129 | SendProxyWifiMessage(socket, message); 130 | } 131 | 132 | void Transport::Start() 133 | { 134 | m_shutdownEvent.ResetEvent(); 135 | // NOTE(review): AcceptConnections can throw (e.g. when listen fails) and nothing catches it on this thread, so such an exception terminates the process 136 | m_runnerThread = std::thread([this] { 137 | const auto logger = SetThreadWilFailureLogger(); 138 | AcceptConnections(); 139 | }); 140 | } 141 | 142 | void Transport::Shutdown() 143 | { 144 | // Stop accepting requests from the guest 145 | std::thread thread; 146 | std::swap(thread, m_runnerThread); 147 | 148 | m_shutdownEvent.SetEvent(); 149 | 150 | if (thread.joinable()) 151 | { 152 | thread.join(); 153 | } 154 | 155 | // Stop sending notifications to the guest 156 | // First, clear the operation handler callback (and wait for any currently executing one), 157 | // to ensure nothing is queued after `m_notifQueue` is stopped 158 | m_operationHandler->ClearGuestNotificationCallback(); 159 | m_notifQueue.Cancel(); 160 | } 161 | 162 | HyperVTransport::HyperVTransport( 163 | std::shared_ptr& operationHandler, unsigned short requestResponsePort, unsigned short notificationPort, const GUID& guestVmId) 164 | : Transport(operationHandler, requestResponsePort, notificationPort), m_guestVmId(guestVmId) 165 | { 166 | } 167 | 168 | wil::unique_socket
HyperVTransport::CreateListenSocket() 169 | { 170 | return CreateSocketWithUse( 171 | SocketUse::Bound, [&]() { return CreateHyperVSocket(m_guestVmId, m_requestResponsePort); }); 172 | } 173 | 174 | wil::unique_socket HyperVTransport::CreateNotificationSocket() 175 | { 176 | return CreateSocketWithUse( 177 | SocketUse::Connected, [&]() { return CreateHyperVSocket(m_guestVmId, m_notificationPort); }); 178 | } 179 | 180 | std::pair HyperVTransport::CreateAcceptSocket() 181 | { 182 | wil::unique_socket sock{WSASocket(AF_HYPERV, SOCK_STREAM, HV_PROTOCOL_RAW, nullptr, 0, 0)}; 183 | if (!sock.is_valid()) 184 | { 185 | THROW_WIN32_MSG(WSAGetLastError(), "Failed to create an hv socket"); 186 | } 187 | return std::make_pair(std::move(sock), sizeof(SOCKADDR_HV)); 188 | } 189 | 190 | TcpTransport::TcpTransport( 191 | std::shared_ptr& operationHandler, unsigned short requestResponsePort, unsigned short notificationPort, const std::string& listenIp) 192 | : Transport(operationHandler, requestResponsePort, notificationPort), m_listenIp(listenIp) 193 | { 194 | const auto conversionResult = inet_pton(AF_INET, listenIp.c_str(), &m_listenIpAddr); // 0: malformed address (no socket error is set), -1: failure reported by WSAGetLastError() 195 | if (conversionResult != 1) 196 | { 197 | THROW_WIN32_MSG(conversionResult == 0 ? ERROR_INVALID_PARAMETER : WSAGetLastError(), "listenIP: %hs", listenIp.c_str()); 198 | } 199 | } 200 | wil::unique_socket TcpTransport::CreateListenSocket() 201 | { 202 | return CreateSocketWithUse( 203 | SocketUse::Bound, [&]() { return CreateTcpSocket(m_listenIpAddr, m_requestResponsePort); }); 204 | } 205 | 206 | wil::unique_socket TcpTransport::CreateNotificationSocket() 207 | { 208 | return CreateSocketWithUse( 209 | SocketUse::Connected, [&]() { return CreateTcpSocket(m_listenIpAddr, m_notificationPort); }); 210 | } 211 | 212 | std::pair TcpTransport::CreateAcceptSocket() 213 | { 214 | wil::unique_socket sock{WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, 0)}; 215 | if (!sock.is_valid()) 216 | { 217 | THROW_WIN32_MSG(WSAGetLastError(), "Failed to create a tcp socket"); 218 | } 219 | return std::make_pair(std::move(sock),
sizeof(sockaddr_in)); 220 | } 221 | 222 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/Transport.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | 12 | #include 13 | 14 | #include "Connection.hpp" 15 | #include "OperationHandler.hpp" 16 | #include "WorkQueue.hpp" 17 | 18 | namespace ProxyWifi { 19 | 20 | /// @brief Represents a transport for exchanging data between the host and 21 | /// guest, using generic sockets. 22 | /// 23 | /// The transport has no awareness of a protocol. It is responsible for accepting 24 | /// connections and exchanging data between two endpoints. 25 | class Transport 26 | { 27 | public: 28 | /// @brief Construct a new Transport Socket object. 29 | /// @param requestResponsePort The request/response port to bind to. 30 | /// @param notificationPort The notification port to connect to. 31 | Transport(std::shared_ptr& operationHandler, unsigned short requestResponsePort, unsigned short notificationPort); 32 | 33 | Transport(const Transport&) = delete; 34 | Transport(Transport&&) = delete; 35 | Transport& operator=(const Transport&) = delete; 36 | Transport& operator=(Transport&&) = delete; 37 | 38 | /// @brief Destroy the transport, calling Shutdown(). Derived classes must also call Shutdown() from their own destructor: the worker threads call the pure virtual socket factories, which no longer dispatch to the derived overrides once this destructor runs. 39 | virtual ~Transport(); 40 | 41 | /// @brief Start accepting connections asynchronously. 42 | void Start(); 43 | 44 | /// @brief Shutdown the transport and stop accepting connections. 45 | void Shutdown(); 46 | 47 | protected: 48 | /// @brief Start accepting connections. 49 | void AcceptConnections(); 50 | 51 | /// @brief Interface function to create the request/response socket. 52 | /// Implementations must return a socket bound to the configured request/response port; 53 | /// AcceptConnections() then calls `listen` on it.
54 | virtual wil::unique_socket CreateListenSocket() = 0; 55 | 56 | /// @brief Interface function to create the notification socket. Implementations must 57 | /// return a socket that is connected to the configured notification port. 58 | virtual wil::unique_socket CreateNotificationSocket() = 0; 59 | 60 | /// @brief Interface function to create a socket for an accepted connection. 61 | /// Implementations must return an unbound, unconnected socket along with the size of its address structure 62 | virtual std::pair CreateAcceptSocket() = 0; 63 | 64 | private: 65 | /// @brief Queue a notification to send it asynchronously 66 | void QueueNotification(Message&& msg); 67 | 68 | /// @brief Send a single protocol message as a notification. 69 | void SendNotification(const Message& message); 70 | 71 | protected: 72 | std::shared_ptr m_operationHandler; 73 | 74 | std::thread m_runnerThread; 75 | wil::unique_event m_shutdownEvent{wil::EventOptions::ManualReset}; 76 | unsigned short m_requestResponsePort; 77 | 78 | SerializedWorkQueue> m_notifQueue; 79 | unsigned short m_notificationPort; 80 | // Set once a guest connection is accepted: QueueNotification drops notifications until then 81 | std::atomic_bool m_guestWasPresent = false; 82 | }; 83 | 84 | /// @brief Proxy transport which uses HyperV (AF_HYPERV) sockets. 85 | /// 86 | /// This facilitates proxying Wi-Fi operations to Hyper-V container endpoints. 87 | class HyperVTransport : public Transport 88 | { 89 | public: 90 | /// @brief Creates a new Hyper-V transport which listens for connections 91 | /// on the specified ports and is restricted to the specified guest vm id 92 | /// @param operationHandler The operation handler for the transport. 93 | /// @param requestResponsePort The HyperV socket port for the request/response communication channel. 94 | /// @param notificationPort The HyperV socket port for the notification communication channel. 95 | /// @param guestVmId The id of the VM the transport is restricted to.
96 | HyperVTransport( 97 | std::shared_ptr& operationHandler, unsigned short requestResponsePort, unsigned short notificationPort, const GUID& guestVmId); 98 | ~HyperVTransport() override { Shutdown(); } // Stop the worker threads while the overrides they call are still valid 99 | private: 100 | wil::unique_socket CreateListenSocket() override; 101 | wil::unique_socket CreateNotificationSocket() override; 102 | std::pair CreateAcceptSocket() override; 103 | 104 | private: 105 | const GUID m_guestVmId; 106 | }; 107 | 108 | /// @brief Proxy transport using TCP/IP sockets. 109 | /// 110 | /// This facilitates proxying Wi-Fi operations to TCP/IP endpoints. This can 111 | /// include remote systems if the transport is bound to a publicly routable 112 | /// TCP/IP address. 113 | class TcpTransport : public Transport 114 | { 115 | public: 116 | /// @brief Construct a new Tcp Transport object 117 | /// @param operationHandler The operation handler for the transport. 118 | /// @param requestResponsePort The TCP/IP port number for the request/response communication channel. 119 | /// @param notificationPort The TCP/IP port number for the notification communication channel. 120 | /// @param listenIp The IPv4 address to listen on for requests, in a format accepted by `inet_pton` with AF_INET. 121 | TcpTransport(std::shared_ptr& operationHandler, unsigned short requestResponsePort, unsigned short notificationPort, const std::string& listenIp); 122 | ~TcpTransport() override { Shutdown(); } // Stop the worker threads while the overrides they call are still valid 123 | private: 124 | wil::unique_socket CreateListenSocket() override; 125 | wil::unique_socket CreateNotificationSocket() override; 126 | std::pair CreateAcceptSocket() override; 127 | 128 | private: 129 | const std::string m_listenIp; 130 | IN_ADDR m_listenIpAddr{}; 131 | }; 132 | 133 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/WlanInterface.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license.
3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "Networks.hpp" 14 | #include "Iee80211Utils.hpp" 15 | 16 | namespace ProxyWifi { 17 | 18 | /// @brief Interface for classes that implement the callbacks called by an `IWlanInterface` 19 | class INotificationHandler 20 | { 21 | public: 22 | virtual ~INotificationHandler() = default; 23 | 24 | /// @brief Must be called by the interfaces when they connect to a network 25 | virtual void OnHostConnection(const GUID& interfaceGuid, const Ssid& ssid, DOT11_AUTH_ALGORITHM authAlgo) = 0; 26 | 27 | /// @brief Must be called by the interfaces when they disconnect from a network 28 | virtual void OnHostDisconnection(const GUID& interfaceGuid, const Ssid& ssid) = 0; 29 | 30 | /// @brief Must be called by the interfaces when the signal quality changes 31 | virtual void OnHostSignalQualityChange(const GUID& interfaceGuid, unsigned long signalQuality) = 0; 32 | 33 | /// @brief Must be called by the interfaces when scan results are available 34 | virtual void OnHostScanResults(const GUID& interfaceGuid, const std::vector& scannedBss, ScanStatus status) = 0; 35 | }; 36 | 37 | /// @brief Interface for classes representing a wlan interface (real or simulated) 38 | /// An `OperationHandler` will use this interface to dispatch requests to the interfaces and collect the results 39 | class IWlanInterface 40 | { 41 | public: 42 | virtual ~IWlanInterface() = default; 43 | 44 | /// @brief Set the handler the interface will notify on specific events 45 | virtual void SetNotificationHandler(INotificationHandler* notificationHandler) = 0; 46 | 47 | /// @brief Access the interface GUID (unique identifier) 48 | virtual const GUID& GetGuid() const noexcept = 0; 49 | 50 | /// @brief Indicate whether the interface is connected to a specific network 51 | /// @param requestedSsid The Ssid of the network 52 | /// @return the BSSID of the connected BSS if it is connected to
the requested network, std::nullopt otherwise 53 | virtual std::optional IsConnectedTo(const Ssid& requestedSsid) noexcept = 0; 54 | 55 | /// @brief Request that the interface connect to a specific network 56 | /// @param requestedSsid The Ssid of the network to connect to 57 | /// @param bssid The bss to connect to. Ignored if all zeros. 58 | /// @param securityInfo The authentication, cipher, key... to use for the connection 59 | /// @return A future indicating whether the connection was successful or not and the connected network when it is ready 60 | virtual std::future> Connect(const Ssid& requestedSsid, const Bssid& bssid, const WlanSecurity& securityInfo) = 0; 61 | 62 | /// @brief Request that the interface disconnect 63 | /// @return A future indicating when the disconnection is complete 64 | virtual std::future Disconnect() = 0; 65 | /// @brief Scan results: the scanned BSS list and the scan status 66 | struct ScanResult 67 | { 68 | std::vector bssList; 69 | ScanStatus status{ScanStatus::Completed}; 70 | }; 71 | /// @brief Request that the interface schedule a scan 72 | /// @param ssid If present, request a targeted scan on this ssid (needed to scan hidden networks) 73 | /// @return A future containing the current scan results, and whether the scan is still running 74 | virtual std::future Scan(std::optional& ssid) = 0; 75 | }; 76 | 77 | 78 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /lib/WlanSvcHelpers.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include "Iee80211Utils.hpp" 16 | 17 | /// @brief Helpers for working with WlanSvc on Windows. 18 | namespace ProxyWifi::Wlansvc { 19 | 20 | /// @brief Get a log-friendly string describing the notification code from a WLAN_NOTIFICATION_DATA structure.
21 | std::string GetWlanNotificationCodeString(const WLAN_NOTIFICATION_DATA& data); 22 | 23 | /// @brief Indicate whether a BSSID is {00-00-00-00-00-00} 24 | bool IsNullBssid(const DOT11_MAC_ADDRESS& bssid); 25 | 26 | /// @brief Build a DOT11_BSSID_LIST from a list of BSSIDs 27 | /// @return A pointer to the created DOT11_BSSID_LIST and a smart pointer to the memory allocated for it 28 | std::pair> BuildBssidList(std::span bssids); 29 | 30 | /// @brief Map an 802.11 cipher suite to the corresponding Windows API enumeration 31 | DOT11_CIPHER_ALGORITHM CipherSuiteToWindowsEnum(CipherSuite cipherSuite); 32 | 33 | /// @brief Convert a link quality in percentage to an RSSI in dBm: the quality is clamped to [0, 100], then mapped to [-100, -50] as quality / 2 - 100 34 | inline constexpr int8_t LinkQualityToRssi(unsigned long signal) 35 | { 36 | signal = std::clamp(signal, 0ul, 100ul); 37 | return static_cast(signal) / 2 - 100; 38 | } 39 | 40 | /// @brief Convert an authentication algorithm to a string for pretty printing 41 | std::wstring AuthAlgoToString(DOT11_AUTH_ALGORITHM authAlgo) noexcept; 42 | 43 | /// @brief Convert a cipher algorithm to a string for pretty printing 44 | std::wstring CipherAlgoToString(DOT11_CIPHER_ALGORITHM cipher) noexcept; 45 | 46 | /// @brief Convert an authentication algorithm to the string used in a wlan profile 47 | std::wstring AuthAlgoToProfileString(DOT11_AUTH_ALGORITHM authAlgo); 48 | 49 | /// @brief Convert a cipher algorithm to the string used in a wlan profile 50 | std::wstring CipherAlgoToProfileString(DOT11_CIPHER_ALGORITHM cipher); 51 | 52 | /// @brief Attempt to create a valid profile name from an SSID 53 | /// If the SSID cannot be converted to a valid string, return a default name 54 | std::wstring ProfileNameFromSSID(const Ssid& ssid); 55 | 56 | /// @brief Build a valid, basic, wlan profile from the parameters 57 | std::wstring MakeConnectionProfile(const Ssid& ssid, DOT11_AUTH_CIPHER_PAIR authCipher, const std::span& key); 58 | 59 | /// @brief Determine whether an authentication/cipher is supported by the lib
for real host connections 60 | bool IsAuthCipherPairSupported(const std::pair& authCipher); 61 | 62 | /// @brief Build a DOT11_AUTH_CIPHER_PAIR from security parameters, throw if the parameters are invalid or unsupported 63 | DOT11_AUTH_CIPHER_PAIR DetermineAuthCipherPair(const WlanSecurity& secInfo); 64 | 65 | } // namespace ProxyWifi::Wlansvc 66 | -------------------------------------------------------------------------------- /lib/WlanSvcWrapper.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "WlanSvcWrapper.hpp" 5 | 6 | #include "LogsHelpers.hpp" 7 | #include "StringUtils.hpp" 8 | #include "WlanSvcHelpers.hpp" 9 | 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | namespace ProxyWifi::Wlansvc { 17 | 18 | WlanApiWrapperImpl::WlanApiWrapperImpl() 19 | { 20 | DWORD negotiatedVersion = 0; 21 | THROW_IF_WIN32_ERROR(m_wlanApi.WlanOpenHandle(WLAN_API_VERSION_2_0, nullptr, &negotiatedVersion, &m_wlanHandle)); // The destructor does not run if this constructor throws: close the handle here if the registration fails 22 | auto closeHandleOnFailure = wil::scope_exit([&] { m_wlanApi.WlanCloseHandle(m_wlanHandle, nullptr); }); THROW_IF_WIN32_ERROR(m_wlanApi.WlanRegisterNotification( 23 | m_wlanHandle, WLAN_NOTIFICATION_SOURCE_ACM | WLAN_NOTIFICATION_SOURCE_MSM, true, OnWlansvcEventCallback, this, nullptr, nullptr)); closeHandleOnFailure.release(); 24 | } 25 | 26 | WlanApiWrapperImpl::~WlanApiWrapperImpl() 27 | { 28 | // Best effort to unregister, wlansvc will do it when the handle is closed otherwise 29 | LOG_IF_WIN32_ERROR(m_wlanApi.WlanRegisterNotification(m_wlanHandle, WLAN_NOTIFICATION_SOURCE_NONE, true, nullptr, nullptr, nullptr, nullptr)); 30 | m_wlanApi.WlanCloseHandle(m_wlanHandle, nullptr); 31 | } 32 | 33 | void WlanApiWrapperImpl::OnWlansvcEventCallback(PWLAN_NOTIFICATION_DATA pNotification, void* pContext) noexcept 34 | try 35 | { 36 | const auto logger = SetThreadWilFailureLogger(); // Keep the returned object alive for the whole callback: a discarded return value is destroyed at the end of the statement 37 | THROW_IF_NULL_ALLOC(pContext); 38 | auto& wlansvc = *static_cast(pContext); 39 | wlansvc.HandleWlansvcNotification(pNotification); 40 | } 41 | CATCH_LOG() 42 | 43 | void
WlanApiWrapperImpl::HandleWlansvcNotification(PWLAN_NOTIFICATION_DATA pNotification) 44 | { 45 | Log::Debug( 46 | L"Receving Wlansvc notification %hs on interface %ws", 47 | Wlansvc::GetWlanNotificationCodeString(*pNotification).c_str(), 48 | GuidToString(pNotification->InterfaceGuid).c_str()); 49 | // Dispatch to the subscriber of this interface, then to the null-GUID subscriber, under the shared lock (a callback must not call Subscribe/Unsubscribe: SRW locks are not recursive) 50 | auto lock = m_callbacksLock.lock_shared(); 51 | const auto callback = m_callbacks.find(pNotification->InterfaceGuid); 52 | if (callback != m_callbacks.end()) 53 | { 54 | callback->second(*pNotification); 55 | } 56 | 57 | // The callback on the null GUID receives all notifications 58 | const auto allIntfCallback = m_callbacks.find(GUID{}); 59 | if (allIntfCallback != m_callbacks.end()) 60 | { 61 | allIntfCallback->second(*pNotification); 62 | } 63 | } 64 | 65 | void WlanApiWrapperImpl::Subscribe(const GUID& interfaceGuid, std::function callback) 66 | { 67 | auto lock = m_callbacksLock.lock_exclusive(); 68 | m_callbacks[interfaceGuid] = std::move(callback); // `callback` is taken by value: move it rather than copying it again 69 | } 70 | 71 | void WlanApiWrapperImpl::Unsubscribe(const GUID& interfaceGuid) 72 | { 73 | auto lock = m_callbacksLock.lock_exclusive(); 74 | m_callbacks.erase(interfaceGuid); 75 | } 76 | 77 | std::vector WlanApiWrapperImpl::EnumerateInterfaces() 78 | { 79 | WLAN_INTERFACE_INFO_LIST* pInterfaces{nullptr}; 80 | auto cleanup = wil::scope_exit([&] { 81 | if (pInterfaces) 82 | { 83 | m_wlanApi.WlanFreeMemory(pInterfaces); 84 | } 85 | }); 86 | 87 | THROW_IF_WIN32_ERROR(m_wlanApi.WlanEnumInterfaces(m_wlanHandle, nullptr, &pInterfaces)); 88 | 89 | if (!pInterfaces || pInterfaces->dwNumberOfItems == 0) 90 | { 91 | return {}; 92 | } 93 | 94 | std::vector result; 95 | std::transform( 96 | pInterfaces->InterfaceInfo, pInterfaces->InterfaceInfo + pInterfaces->dwNumberOfItems, std::back_inserter(result), [](const auto& i) { 97 | return i.InterfaceGuid; 98 | }); 99 | return result; 100 | } 101 | 102 | std::optional WlanApiWrapperImpl::GetCurrentConnection(const GUID& interfaceGuid) 103 | { 104 | DWORD dataSize = 0; 105 |
WLAN_CONNECTION_ATTRIBUTES* pCurrentConnection{nullptr}; 106 | auto cleanup = wil::scope_exit([&] { 107 | if (pCurrentConnection) 108 | { 109 | m_wlanApi.WlanFreeMemory(pCurrentConnection); 110 | } 111 | }); 112 | 113 | const auto err = m_wlanApi.WlanQueryInterface( 114 | m_wlanHandle, &interfaceGuid, wlan_intf_opcode_current_connection, nullptr, &dataSize, reinterpret_cast(&pCurrentConnection), nullptr); 115 | 116 | if (err == ERROR_INVALID_STATE) 117 | { 118 | // This means the interface is not connected 119 | return std::nullopt; 120 | } 121 | else if (err != ERROR_SUCCESS) 122 | { 123 | THROW_WIN32(err); 124 | } 125 | 126 | THROW_IF_NULL_ALLOC(pCurrentConnection); 127 | return *pCurrentConnection; 128 | } 129 | 130 | void WlanApiWrapperImpl::Connect(const GUID& interfaceGuid, const std::wstring& profile, const DOT11_MAC_ADDRESS& bssid) 131 | { 132 | WLAN_CONNECTION_PARAMETERS connectionParameters{}; 133 | connectionParameters.wlanConnectionMode = wlan_connection_mode_temporary_profile; 134 | connectionParameters.strProfile = profile.data(); 135 | connectionParameters.pDot11Ssid = nullptr; 136 | connectionParameters.dot11BssType = dot11_BSS_type_infrastructure; 137 | 138 | // Set the requested BSSID if present in the request 139 | std::unique_ptr bssidListBuffer; 140 | if (!Wlansvc::IsNullBssid(bssid)) 141 | { 142 | std::tie(connectionParameters.pDesiredBssidList, bssidListBuffer) = Wlansvc::BuildBssidList({&bssid, 1}); 143 | } 144 | 145 | THROW_IF_WIN32_ERROR(m_wlanApi.WlanConnect(m_wlanHandle, &interfaceGuid, &connectionParameters, nullptr)); 146 | } 147 | 148 | void WlanApiWrapperImpl::Disconnect(const GUID& interfaceGuid) 149 | { 150 | THROW_IF_WIN32_ERROR(m_wlanApi.WlanDisconnect(m_wlanHandle, &interfaceGuid, nullptr)); 151 | } 152 | 153 | void WlanApiWrapperImpl::Scan(const GUID& interfaceGuid, DOT11_SSID* ssid) 154 | { 155 | THROW_IF_WIN32_ERROR(m_wlanApi.WlanScan(m_wlanHandle, &interfaceGuid, ssid, nullptr, nullptr)); 156 | } 157 | 158 | std::vector
WlanApiWrapperImpl::GetScannedBssList(const GUID& interfaceGuid) 159 | { 160 | WLAN_BSS_LIST* pBssList{nullptr}; 161 | auto cleanup = wil::scope_exit([&] { 162 | if (pBssList) 163 | { 164 | m_wlanApi.WlanFreeMemory(pBssList); 165 | } 166 | }); 167 | THROW_IF_WIN32_ERROR(m_wlanApi.WlanGetNetworkBssList( 168 | m_wlanHandle, &interfaceGuid, nullptr, dot11_BSS_type_infrastructure, false, nullptr, &pBssList)); 169 | THROW_IF_NULL_ALLOC(pBssList); // Same check as GetCurrentConnection before dereferencing the returned list 170 | std::vector scannedBss; 171 | for (const auto& bss : wil::make_range(pBssList->wlanBssEntries, pBssList->dwNumberOfItems)) 172 | { 173 | const auto ieStart = reinterpret_cast(&bss) + bss.ulIeOffset; 174 | scannedBss.emplace_back( 175 | toBssid(bss.dot11Bssid), 176 | bss.dot11Ssid, 177 | bss.usCapabilityInformation, 178 | wil::safe_cast(bss.lRssi), 179 | bss.ulChCenterFrequency, 180 | bss.usBeaconPeriod, 181 | std::vector{ieStart, ieStart + bss.ulIeSize}); 182 | } 183 | return scannedBss; 184 | } 185 | 186 | std::vector WlanApiWrapperImpl::GetScannedNetworkList(const GUID& interfaceGuid) 187 | { 188 | WLAN_AVAILABLE_NETWORK_LIST* pScannedNetworks{nullptr}; 189 | auto cleanup = wil::scope_exit([&] { 190 | if (pScannedNetworks) 191 | { 192 | m_wlanApi.WlanFreeMemory(pScannedNetworks); 193 | } 194 | }); 195 | THROW_IF_WIN32_ERROR( 196 | m_wlanApi.WlanGetAvailableNetworkList(m_wlanHandle, &interfaceGuid, dot11_BSS_type_infrastructure, nullptr, &pScannedNetworks)); 197 | THROW_IF_NULL_ALLOC(pScannedNetworks); 198 | return {pScannedNetworks->Network, pScannedNetworks->Network + pScannedNetworks->dwNumberOfItems}; 199 | } 200 | 201 | } // namespace ProxyWifi::Wlansvc -------------------------------------------------------------------------------- /lib/WlanSvcWrapper.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license.
3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "DynamicFunction.hpp" 15 | #include "GuidUtils.hpp" 16 | #include "Networks.hpp" 17 | 18 | namespace ProxyWifi::Wlansvc { 19 | 20 | constexpr const wchar_t* c_wlanApi = L"wlanApi.dll"; 21 | 22 | /// @brief Wrapper around the Wlansvc API 23 | /// Expose the parts of the Wlansvc API used by the lib in a more C++ compatible way and facilitate unit-testing 24 | class WlanApiWrapper 25 | { 26 | public: 27 | virtual ~WlanApiWrapper() = default; 28 | 29 | /// @brief Provide the GUIDs of all station interfaces on the system 30 | virtual std::vector EnumerateInterfaces() = 0; 31 | 32 | /// @brief Allow a client to subscribe to notifications for a specific interface 33 | /// Use the null guid to receive all notifications 34 | virtual void Subscribe(const GUID& interfaceGuid, std::function callback) = 0; 35 | 36 | /// @brief Allow a client to unsubscribe from notifications for a specific interface 37 | virtual void Unsubscribe(const GUID& interfaceGuid) = 0; 38 | 39 | /// @brief Provide information about the currently connected network on `interfaceGuid` (WlanApiWrapperImpl returns std::nullopt when the interface is not connected) 40 | virtual std::optional GetCurrentConnection(const GUID& interfaceGuid) = 0; 41 | 42 | /// @brief Connect to a wlan network using a temporary profile 43 | virtual void Connect(const GUID& interfaceGuid, const std::wstring& profile, const DOT11_MAC_ADDRESS& bssid) = 0; 44 | 45 | /// @brief Disconnect the interface 46 | virtual void Disconnect(const GUID& interfaceGuid) = 0; 47 | 48 | /// @brief Schedule a scan on the interface 49 | /// @param ssid If non-null, a targeted scan will be done on this ssid (for hidden networks) 50 | virtual void Scan(const GUID& interfaceGuid, DOT11_SSID* ssid = nullptr) = 0; 51 | 52 | /// @brief Provide the list of scanned BSS 53 | /// It is flushed when a scan is scheduled and will be empty until a scan succeeds 54 | virtual std::vector GetScannedBssList(const
GUID& interfaceGuid) = 0; 55 | 56 | /// @brief Provide the list of scanned networks 57 | /// It is flushed when a scan is scheduled and will be empty until a scan succeeds 58 | virtual std::vector GetScannedNetworkList(const GUID& interfaceGuid) = 0; 59 | }; 60 | 61 | /// @brief Implementation of WlanApiWrapper targeting the real Windows Wlan API 62 | class WlanApiWrapperImpl : public WlanApiWrapper 63 | { 64 | public: 65 | WlanApiWrapperImpl(); 66 | ~WlanApiWrapperImpl() override; 67 | 68 | WlanApiWrapperImpl(const WlanApiWrapperImpl&) = delete; 69 | WlanApiWrapperImpl(WlanApiWrapperImpl&&) = delete; 70 | WlanApiWrapperImpl& operator=(const WlanApiWrapperImpl&) = delete; 71 | WlanApiWrapperImpl& operator=(WlanApiWrapperImpl&&) = delete; 72 | 73 | std::vector EnumerateInterfaces() override; 74 | void Subscribe(const GUID& interfaceGuid, std::function callback) override; 75 | void Unsubscribe(const GUID& interfaceGuid) override; 76 | std::optional GetCurrentConnection(const GUID& interfaceGuid) override; 77 | void Connect(const GUID& interfaceGuid, const std::wstring& profile, const DOT11_MAC_ADDRESS& bssid) override; 78 | void Disconnect(const GUID& interfaceGuid) override; 79 | void Scan(const GUID& interfaceGuid, DOT11_SSID* ssid = nullptr) override; 80 | std::vector GetScannedBssList(const GUID& interfaceGuid) override; 81 | std::vector GetScannedNetworkList(const GUID& interfaceGuid) override; 82 | // Static callback registered with WlanRegisterNotification; it forwards each notification to HandleWlansvcNotification 83 | private: 84 | static void OnWlansvcEventCallback(PWLAN_NOTIFICATION_DATA pNotification, void* pContext) noexcept; 85 | void HandleWlansvcNotification(PWLAN_NOTIFICATION_DATA pNotification); 86 | // m_callbacks maps an interface GUID (null GUID: all interfaces) to its subscriber and is guarded by m_callbacksLock 87 | private: 88 | HANDLE m_wlanHandle{}; 89 | wil::srwlock m_callbacksLock; 90 | std::unordered_map> m_callbacks; 91 | // wlanApi.dll (c_wlanApi) entry points wrapped in DynamicFunction, presumably resolved at runtime by their exported names 92 | struct WlanApiDynFunctions 93 | { 94 | DynamicFunction WlanCloseHandle{c_wlanApi, "WlanCloseHandle"}; 95 | DynamicFunction WlanConnect{c_wlanApi, "WlanConnect"}; 96 | DynamicFunction WlanDisconnect{c_wlanApi, "WlanDisconnect"}; 97 |
DynamicFunction WlanEnumInterfaces{c_wlanApi, "WlanEnumInterfaces"}; 98 | DynamicFunction WlanFreeMemory{c_wlanApi, "WlanFreeMemory"}; 99 | DynamicFunction WlanGetAvailableNetworkList{ 100 | c_wlanApi, "WlanGetAvailableNetworkList"}; 101 | DynamicFunction WlanGetNetworkBssList{c_wlanApi, "WlanGetNetworkBssList"}; 102 | DynamicFunction WlanOpenHandle{c_wlanApi, "WlanOpenHandle"}; 103 | DynamicFunction WlanQueryInterface{c_wlanApi, "WlanQueryInterface"}; 104 | DynamicFunction WlanRegisterNotification{c_wlanApi, "WlanRegisterNotification"}; 105 | DynamicFunction WlanScan{c_wlanApi, "WlanScan"}; 106 | }; 107 | WlanApiDynFunctions m_wlanApi{}; 108 | }; 109 | 110 | } // namespace ProxyWifi::Wlansvc -------------------------------------------------------------------------------- /lib/WlansvcOperationHandler.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | #pragma once 4 | 5 | #include "OperationHandler.hpp" 6 | #include "RealWlanInterface.hpp" 7 | #include "WlanSvcWrapper.hpp" 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | namespace ProxyWifi { 14 | 15 | class WlansvcOperationHandler : public OperationHandler 16 | { 17 | public: 18 | WlansvcOperationHandler(ProxyWifiObserver* pObserver, std::vector> wlanInterfaces, std::shared_ptr& wlansvc) 19 | : OperationHandler{pObserver, std::move(wlanInterfaces)}, m_wlansvc{wlansvc} 20 | { 21 | if (!m_wlansvc) 22 | { 23 | return; 24 | } 25 | 26 | m_wlansvc->Subscribe(GUID{}, [this](const auto& n) { 27 | if (n.NotificationSource == WLAN_NOTIFICATION_SOURCE_ACM) 28 | { 29 | switch (n.NotificationCode) 30 | { 31 | case wlan_notification_acm_interface_arrival: 32 | { 33 | // Check the notification is for a primary interface 34 | // (secondary interfaces arrival are notified, but not returned by EnumerateInterface) 35 | auto intf = m_wlansvc->EnumerateInterfaces(); 36 | auto foundIt = std::find(intf.begin(),
intf.end(), n.InterfaceGuid); 37 | if (foundIt != intf.end()) 38 | { 39 | AddInterface([wlansvc = m_wlansvc, guid = n.InterfaceGuid] { 40 | return std::make_unique(wlansvc, guid); 41 | }); 42 | } 43 | break; 44 | } 45 | case wlan_notification_acm_interface_removal: 46 | RemoveInterface(n.InterfaceGuid); 47 | break; 48 | } 49 | } 50 | }); 51 | } 52 | 53 | ~WlansvcOperationHandler() override 54 | { 55 | if (m_wlansvc) 56 | { 57 | m_wlansvc->Unsubscribe(GUID{}); 58 | } 59 | } 60 | 61 | WlansvcOperationHandler(const WlansvcOperationHandler&) = delete; 62 | WlansvcOperationHandler(WlansvcOperationHandler&&) = delete; 63 | WlansvcOperationHandler& operator=(const WlansvcOperationHandler&) = delete; 64 | WlansvcOperationHandler& operator= (WlansvcOperationHandler&&) = delete; 65 | 66 | private: 67 | std::shared_ptr m_wlansvc; 68 | }; 69 | 70 | } // namespace ProxyWifi -------------------------------------------------------------------------------- /test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | include(FetchContent) 5 | 6 | FetchContent_Declare( 7 | Catch2 8 | GIT_REPOSITORY https://github.com/catchorg/Catch2.git 9 | GIT_TAG v2.13.10 10 | ) 11 | 12 | FetchContent_MakeAvailable(Catch2) 13 | 14 | add_executable(proxy-wifi-test) 15 | target_link_libraries(proxy-wifi-test 16 | PRIVATE 17 | Catch2::Catch2 18 | proxy-wifi-util 19 | proxy-wifi 20 | Synchronization.lib 21 | ) 22 | 23 | target_include_directories(proxy-wifi-test 24 | PRIVATE 25 | ../lib 26 | ) 27 | 28 | target_sources(proxy-wifi-test 29 | PRIVATE 30 | main.cpp 31 | TestInit.cpp 32 | TestOpHandler.cpp 33 | TestUtils.cpp 34 | WlansvcMock.hpp 35 | ) 36 | 37 | # Allows CTest to discover Catch2 tests automatically 38 | list(APPEND CMAKE_MODULE_PATH "${catch2_SOURCE_DIR}/contrib") 39 | include(Catch) 40 | catch_discover_tests(proxy-wifi-test) 41 | -------------------------------------------------------------------------------- /test/TestInit.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include 5 | 6 | #include "WlanSvcWrapper.hpp" 7 | #include "OperationHandlerBuilder.hpp" 8 | 9 | using namespace ProxyWifi; 10 | 11 | TEST_CASE("Creating a WlanApiWrapper doesn't cause a crash", "[init]") 12 | { 13 | // The operation can succeed or fail depending on whether wlanapi.dll is available on the SKU 14 | // But the executable must load and not crash 15 | try 16 | { 17 | auto _ = std::make_unique(); 18 | } 19 | catch (...) 
20 | { 21 | } 22 | } 23 | 24 | TEST_CASE("WlanApiWrapper is optionnal to create an OperationHandler", "[init]") 25 | { 26 | const auto opHandler = MakeWlansvcOperationHandler(std::shared_ptr{}, {}, {}); 27 | CHECK(opHandler); 28 | } 29 | -------------------------------------------------------------------------------- /test/TestUtils.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | 4 | #include "catch2/catch.hpp" 5 | #include "StringUtils.hpp" 6 | #include "DynamicFunction.hpp" 7 | #include "WorkQueue.hpp" 8 | 9 | #include 10 | 11 | #include 12 | #include 13 | using namespace std::chrono_literals; 14 | 15 | // Tests for StringUtils.hpp 16 | 17 | TEST_CASE("ByteBufferToHexString format correctly", "[stringUtils]") 18 | { 19 | CHECK(ByteBufferToHexString(std::array{8, 10, 20}) == std::wstring(L"080a14")); 20 | CHECK(ByteBufferToHexString(std::array{}) == std::wstring()); 21 | } 22 | 23 | TEST_CASE("HexStringToByteBuffer parse correctly", "[stringUtils]") 24 | { 25 | CHECK(HexStringToByteBuffer(L"080a14") == std::vector{8, 10, 20}); 26 | CHECK(HexStringToByteBuffer(L"") == std::vector{}); 27 | CHECK_THROWS(HexStringToByteBuffer(L"080a1")); 28 | } 29 | 30 | TEST_CASE("GuidToString format correctly", "[stringUtils]") 31 | { 32 | constexpr GUID guid{0xfef2f808, 0xf267, 0x4728, {0xa0, 0xc5, 0x0a, 0x62, 0x40, 0xd0, 0x1b, 0x33}}; 33 | CHECK(GuidToString(guid) == std::wstring(L"{FEF2F808-F267-4728-A0C5-0A6240D01B33}")); 34 | } 35 | 36 | TEST_CASE("BssidToString format correctly", "[stringUtils]") 37 | { 38 | CHECK(BssidToString(std::array{216, 236, 94, 16, 126, 22}) == std::wstring(L"d8:ec:5e:10:7e:16")); 39 | CHECK(BssidToString(std::array{0, 0, 1, 0, 0, 0}) == std::wstring(L"00:00:01:00:00:00")); 40 | } 41 | 42 | TEST_CASE("SsidToLogString format correctly", "[stringUtils]") 43 | { 44 | CHECK(SsidToLogString(std::array{'m', 'y', ' ', 'w', 'i', 'f', 
'i'}) == std::wstring(L"'my wifi' [226422226d792077696669]")); 45 | CHECK(SsidToLogString(std::array{}) == std::wstring(L"'' []")); 46 | } 47 | 48 | TEST_CASE("ListEnumToHexString format correctly", "[stringUtils]") 49 | { 50 | enum class Breakfast : uint32_t 51 | { 52 | Croissant = 0xabc11100, 53 | Chocolatine = 0xdef22200, 54 | Coffee = 0xabcdef00 55 | }; 56 | 57 | enum class Pizza 58 | { 59 | Cheese, 60 | Peperoni 61 | }; 62 | 63 | { 64 | auto a = std::array{Breakfast::Croissant, Breakfast::Chocolatine}; 65 | CHECK(ListEnumToHexString(std::span(a)) == std::wstring(L"abc11100 def22200")); 66 | } 67 | { 68 | auto v = std::vector{Breakfast::Coffee}; 69 | CHECK(ListEnumToHexString(std::span(v)) == std::wstring(L"abcdef00")); 70 | } 71 | { 72 | auto a = std::vector{}; 73 | CHECK(ListEnumToHexString(std::span(a)) == std::wstring(L"")); 74 | } 75 | { 76 | auto v = std::vector{Pizza::Cheese}; 77 | CHECK(ListEnumToHexString(std::span(v), L"-", 4) == std::wstring(L"0000")); 78 | } 79 | { 80 | auto v = std::vector{Pizza::Peperoni, Pizza::Cheese}; 81 | CHECK(ListEnumToHexString(std::span(v), L"-", 4) == std::wstring(L"0001-0000")); 82 | } 83 | } 84 | 85 | TEST_CASE("Dynamic function basic behavior works", "[dynamicFunction]") 86 | { 87 | SECTION("Loading an valid function from a valid module works") 88 | { 89 | CHECK_NOTHROW([]() { 90 | DynamicFunction dynFun{L"kernel32.dll", "GetTickCount"}; 91 | dynFun(); 92 | }()); 93 | } 94 | 95 | SECTION("Loading from a non-existing module throws") 96 | { 97 | CHECK_THROWS([]() { DynamicFunction dynFun{L"dummy.dll", "GetNativeSystemInfo"}; }()); 98 | } 99 | 100 | SECTION("Loading a non-existing function throws") 101 | { 102 | CHECK_THROWS([]() { DynamicFunction dynFun{L"kernel32.dll", "dummy"}; }()); 103 | } 104 | } 105 | 106 | TEST_CASE("Work queues execute work items", "[workQueue]") 107 | { 108 | 109 | SECTION("A work item is exectuted") 110 | { 111 | SerializedWorkQueue> wq; 112 | wil::slim_event event; 113 | wq.Submit([&] { 
event.SetEvent(); }); 114 | 115 | CHECK(event.wait(50 /* ms */)); 116 | } 117 | 118 | SECTION("Any callable type is supported and return value are ignored") 119 | { 120 | struct Work { 121 | int operator()() 122 | { 123 | event.SetEvent(); 124 | return 42; 125 | } 126 | wil::slim_event& event; 127 | }; 128 | 129 | SerializedWorkQueue wq; 130 | wil::slim_event event; 131 | wq.Submit(Work{event}); 132 | CHECK(event.wait(50 /* ms */)); 133 | } 134 | 135 | SECTION("Work items are executed asychronously in a different thread") 136 | { 137 | SerializedWorkQueue> wq; 138 | wil::slim_event event; 139 | wil::slim_event event2; 140 | wq.Submit([&] { 141 | event2.wait(); 142 | event.SetEvent(); 143 | }); 144 | 145 | event2.SetEvent(); 146 | CHECK(event.wait(50 /* ms */)); 147 | } 148 | } 149 | 150 | TEST_CASE("Work item cancellation works", "[workQueue]") 151 | { 152 | // Check that Cancel wait for currently running work items completion and cancel any pending one 153 | const int workScheduled = 10; 154 | std::atomic_int workStarted{0}; 155 | std::atomic_int workCompleted{0}; 156 | auto f = [&] { 157 | ++workStarted; 158 | std::this_thread::sleep_for(10ms); 159 | ++workCompleted; 160 | }; 161 | 162 | WorkQueue, 2, 2> wq; 163 | for (auto i = 0; i < workScheduled; ++i) 164 | { 165 | wq.Submit(f); 166 | } 167 | wq.Cancel(); 168 | 169 | CHECK(workStarted == workCompleted); 170 | CHECK(workStarted <= workScheduled); 171 | } 172 | 173 | TEST_CASE("Work item are serialized in a serialized queue", "[workQueue]") 174 | { 175 | // Work items are serialized in a serialized queue 176 | SerializedWorkQueue> wq; 177 | wil::slim_event event; 178 | std::atomic_bool firstWorkComplete; 179 | bool testPassed = false; 180 | 181 | wq.Submit([&] { 182 | std::this_thread::sleep_for(10ms); 183 | firstWorkComplete = true; 184 | }); 185 | wq.Submit([&] { 186 | testPassed = firstWorkComplete; 187 | event.SetEvent(); 188 | }); 189 | 190 | CHECK(event.wait(50 /* ms */)); 191 | CHECK(testPassed); 192 | } 
193 | 194 | TEST_CASE("Light stress", "[workQueue]") 195 | { 196 | SECTION("All work item run in light stress situation") 197 | { 198 | WorkQueue, 5, 5> wq; 199 | wil::slim_event event; 200 | std::atomic count; 201 | constexpr auto numTasks = 1000; 202 | // Each task increments the shared counter; the task that brings it to numTasks signals the event. NOTE(review): `count` is not explicitly initialized, which relies on std::atomic value-initializing its value (C++20) 203 | for (auto i = 0; i < numTasks; ++i) 204 | { 205 | wq.Submit([&] { 206 | auto v = ++count; 207 | if (v == numTasks) 208 | { 209 | event.SetEvent(); 210 | } 211 | }); 212 | } 213 | 214 | CHECK(event.wait(500 /* ms */)); 215 | } 216 | 217 | SECTION("All work item are serialized in light stress situation") 218 | { 219 | SerializedWorkQueue> wq; 220 | wil::slim_event event; 221 | auto count = 0; 222 | constexpr auto numTasks = 1000; 223 | bool testPassed = true; 224 | // Each task checks that it runs in submission order: its id must equal the number of tasks already run 225 | for (auto i = 0; i < numTasks; ++i) 226 | { 227 | wq.Submit([&, id = i] { 228 | if (id != count++) 229 | { 230 | testPassed = false; 231 | } 232 | if (count == numTasks) 233 | { 234 | event.SetEvent(); 235 | } 236 | }); 237 | } 238 | 239 | CHECK(event.wait(500 /* ms */)); 240 | CHECK(testPassed); 241 | } 242 | } 243 | // Tests for WorkRunner (WorkQueue.hpp) 244 | TEST_CASE("Work runner basic tests", "[workQueue]") 245 | { 246 | SECTION("Work items are executed") 247 | { 248 | SerializedWorkRunner wq; 249 | wil::slim_event event; 250 | wil::slim_event event2; 251 | wq.Run([&] { 252 | event2.wait(); 253 | event.SetEvent(); 254 | }); 255 | 256 | event2.SetEvent(); 257 | CHECK(event.wait(50 /* ms */)); 258 | } 259 | 260 | SECTION("Return values are ignored if not waited for") 261 | { 262 | SerializedWorkRunner wq; 263 | // The task can return a value, which is ignored 264 | wil::slim_event event; 265 | wq.Run([&] { 266 | event.SetEvent(); 267 | return "pizza"; 268 | }); 269 | 
CHECK(event.wait(50 /* ms */)); 270 | } 271 | 272 | SECTION("One can wait for the return value") 273 | { 274 | SerializedWorkRunner wq; 275 | const auto r = wq.RunAndWait([&] { return 42; }); 276 | CHECK(r == 42); 277 | } 278 | 279 | SECTION("One can wait without a return value") 280 | { 281 | SerializedWorkRunner
wq; 282 | int a = 0; 283 | wq.RunAndWait([&] { a = 42; }); 284 | CHECK(a == 42); 285 | } 286 | } -------------------------------------------------------------------------------- /test/main.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | // Use our own main() instead of the one generated by Catch2 4 | #define CATCH_CONFIG_RUNNER 5 | #include "catch2/catch.hpp" 6 | 7 | #include "ProxyWifi/Logs.hpp" 8 | // Test entry point: registers a logger (see ProxyWifi/Logs.hpp) then runs the Catch2 session with the command-line arguments 9 | int main(int argc, char* argv[]) 10 | { 11 | ProxyWifi::Log::AddLogger(std::make_unique()); 12 | return Catch::Session().run(argc, argv); 13 | } 14 | -------------------------------------------------------------------------------- /util/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | # Header-only helpers: the target only carries usage requirements (INTERFACE include directory and link dependencies) 4 | add_library(proxy-wifi-util INTERFACE "") 5 | # NOTE(review): PRIVATE sources on an INTERFACE library require CMake 3.19 or newer - check cmake_minimum_required 6 | target_sources(proxy-wifi-util 7 | PRIVATE 8 | GuidUtils.hpp 9 | StringUtils.hpp 10 | WorkQueue.hpp 11 | DynamicFunction.hpp 12 | ) 13 | 14 | target_include_directories(proxy-wifi-util 15 | INTERFACE 16 | $ 17 | ) 18 | # rpcrt4 provides UuidHash, used by GuidUtils.hpp 19 | target_link_libraries(proxy-wifi-util 20 | INTERFACE 21 | rpcrt4.lib 22 | WIL 23 | ) 24 | # NOTE(review): PUBLIC_HEADER only installs the files listed in the target's PUBLIC_HEADER property, which is not set in this file 25 | install(TARGETS proxy-wifi-util 26 | EXPORT proxy-wifi-targets 27 | PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/proxy-wifi 28 | ) 29 | 
-------------------------------------------------------------------------------- /util/DynamicFunction.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license.
3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | // Primary template, only declared: the class defined below is a partial specialization of it (presumably on a function signature type) 10 | template 11 | class DynamicFunction; 12 | 13 | /// @brief Wrapper for a function loaded dynamically at runtime. `moduleName` is loaded from System32 only (LOAD_LIBRARY_SEARCH_SYSTEM32) and stays loaded for the lifetime of the object; `functionName` is resolved with GetProcAddress. The constructor throws (THROW_LAST_ERROR_IF) if the module or the function cannot be found 14 | template 15 | class DynamicFunction 16 | { 17 | public: 18 | DynamicFunction(const std::wstring& moduleName, const std::string& functionName) 19 | : m_module{LoadModule(moduleName)}, m_function{LoadFunction(m_module, functionName)} 20 | { 21 | } 22 | /// Call the loaded function, forwarding the arguments 23 | decltype(auto) operator()(Args... args) const 24 | { 25 | return m_function(std::forward(args)...); 26 | } 27 | 28 | private: 29 | static wil::unique_hmodule LoadModule(const std::wstring& name) 30 | { 31 | wil::unique_hmodule module{LoadLibraryEx(name.c_str(), nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32)}; 32 | THROW_LAST_ERROR_IF(!module); 33 | return module; 34 | } 35 | // GetProcAddress returns null for a missing export; a std::function built from a null function pointer is empty, which the check below detects 36 | static std::function LoadFunction(const wil::unique_hmodule& module, const std::string& name) 37 | { 38 | std::function function = reinterpret_cast(GetProcAddress(module.get(), name.c_str())); 39 | THROW_LAST_ERROR_IF(!function); 40 | return function; 41 | } 42 | // m_module must stay declared before m_function: members are initialized in declaration order and LoadFunction reads m_module 43 | private: 44 | wil::unique_hmodule m_module; 45 | std::function m_function; 46 | }; 
-------------------------------------------------------------------------------- /util/GuidUtils.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license.
3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | /// std::hash specialization so that GUIDs can be used as keys of unordered containers (UuidHash takes a non-const pointer, hence the const_cast; its status output is ignored) 9 | template <> 10 | struct std::hash 11 | { 12 | std::size_t operator()(const GUID& guid) const noexcept 13 | { 14 | RPC_STATUS status = RPC_S_OK; 15 | return ::UuidHash(&const_cast(guid), &status); 16 | } 17 | }; 18 | /// Byte-wise ordering of GUIDs, so that they can be used as keys of ordered containers 19 | inline bool operator<(const GUID& lhs, const GUID& rhs) 20 | { 21 | return memcmp(&lhs, &rhs, sizeof lhs) < 0; 22 | } -------------------------------------------------------------------------------- /util/StringUtils.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | 15 | /// @brief Append a buffer of bytes to `stream` as a string of two-digit hexadecimal values, then restore decimal output and a ' ' fill character 16 | /// @example [8, 10, 20] -> "080a14" 17 | inline static void AppendByteBufferAsHexString(std::wostringstream& stream, const std::span& byteBuffer) 18 | { 19 | stream << std::hex << std::setfill(L'0'); 20 | for (const auto& byte : byteBuffer) 21 | { 22 | stream << std::setw(2) << byte; 23 | } 24 | stream << std::dec << std::setfill(L' '); 25 | } 26 | /// @brief Convert a buffer of bytes to a string of two-digit hexadecimal values, e.g. [8, 10, 20] -> "080a14" 27 | inline static std::wstring ByteBufferToHexString(const std::span& byteBuffer) 28 | { 29 | std::wostringstream stream; 30 | AppendByteBufferAsHexString(stream, byteBuffer); 31 | return stream.str(); 32 | } 33 | /// @brief Parse a string of two-digit hexadecimal values into bytes, e.g. 
"080a14" -> [8, 10, 20]. @throws std::invalid_argument if the string size is odd (std::stoul may also throw when a pair does not start with a hexadecimal digit) 34 | inline static std::vector HexStringToByteBuffer(const std::wstring_view& s) 35 | { 36 | // Need an even string size, since two digits = 1 byte (assume leading zeros) 37 | if (s.size() % 2 != 0) 38 | { 39 | throw std::invalid_argument("String size must be even"); 40 | } 41 | 42 | std::vector byteBuffer; 43 | for (auto i = 0u; i < s.size(); i += 2) 44 | { 45 | std::wstring t{s.substr(i, 2)}; 46 | byteBuffer.push_back(static_cast(std::stoul(t, nullptr, 16))); 47 | } 48 | return byteBuffer; 49 | } 50 | 51 | ///
@brief Convert a GUID to a string 52 | /// @example "{FEF2F808-F267-4728-A0C5-0A6240D01B33}" 53 | inline static std::wstring GuidToString(const GUID& guid) noexcept 54 | { 55 | wchar_t guidAsString[wil::guid_string_buffer_length]; 56 | StringFromGUID2(guid, guidAsString, wil::guid_string_buffer_length); 57 | return guidAsString; 58 | } 59 | 60 | /// @brief Convert a bssid to a string, as hexadecimal 61 | /// @example [216, 236, 94, 16, 126, 22] -> "d8:ec:5e:10:7e:16" 62 | inline static std::wstring BssidToString(const std::span bssid) 63 | { 64 | std::wostringstream stream; 65 | stream << std::hex << std::setfill(L'0'); 66 | 67 | stream << std::setw(2) << bssid.front(); 68 | for (const auto& byte : bssid.subspan<1>()) 69 | { 70 | stream << L':' << std::setw(2) << byte; 71 | } 72 | stream << std::dec << std::setfill(L' '); 73 | return stream.str(); 74 | } 75 | 76 | /// @brief Convert a ssid to a string for *logging only*, with ASCII and hexadecimal representations 77 | /// If the ssid isn't an ascii string, invalid character will be replaced by '?' 78 | /// @example ['m', 'y', ' ', 'w', 'i', 'f', 'i'] -> "'my wifi' [226422226d792077696669]" NOTE(review): the implementation does not match the description above - the quoted part copies the bytes verbatim (no '?' replacement), and because `<<` binds tighter than `?:` the loop below streams the decimal std::isprint() results (implementation-specific non-zero values, which presumably explains the "22642222" prefix) ahead of the hex bytes; the intended output is presumably "'my wifi' [6d792077696669]" (the unit test currently pins the existing output) 79 | inline static std::wstring SsidToLogString(const std::span ssid) 80 | { 81 | std::wostringstream stream; 82 | stream << L"'" << std::wstring{ssid.begin(), ssid.end()} << L"' ["; 83 | for (const auto b : ssid) 84 | { 85 | stream << std::isprint(b) ?
b : '?'; 86 | } 87 | AppendByteBufferAsHexString(stream, ssid); 88 | stream << L"]"; 89 | 90 | return stream.str(); 91 | } 92 | /// @brief Format a list of enum values as zero-padded hexadecimal numbers of `width` digits separated by `sep`; an empty list gives an empty string. @example [0xabc11100, 0xdef22200] -> "abc11100 def22200" 93 | template , int> = 1> 94 | inline static std::wstring ListEnumToHexString(const std::span list, std::wstring_view sep = L" ", int width = 8) 95 | { 96 | if (list.empty()) 97 | { 98 | return L""; 99 | } 100 | 101 | std::wostringstream stream; 102 | stream << std::hex << std::setfill(L'0'); 103 | stream << std::setw(width) << WI_EnumValue(list.front()); 104 | for (const auto& e : list.subspan<1>()) 105 | { 106 | stream << sep << std::setw(width) << WI_EnumValue(e); 107 | } 108 | return stream.str(); 109 | } -------------------------------------------------------------------------------- /util/WorkQueue.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | /// @brief Queue of work items run on a private threadpool of minThread to maxThread threads: each Submit schedules one threadpool callback, which runs the oldest pending work item 13 | // Note: we need a WorkItem templated type because `std::function` requires the callable to be copyable, but we need it to 14 | // work for move-only work items (since `Message` is move only) 15 | template, int> = 1> 16 | class WorkQueue 17 | { 18 | public: 19 | WorkQueue() 20 | : m_threadPool(minThread, maxThread), m_threadPoolWork{m_threadPool.CreateWork(WorkCallback, this)} 21 | { 22 | } 23 | 24 | ~WorkQueue() noexcept 25 | { 26 | Cancel(); 27 | } 28 | /// Enqueue a work item and schedule a threadpool callback to run it. 
Fails fast if called after Cancel 29 | void Submit(WorkItem&& workItem) 30 | { 31 | FAIL_FAST_IF(!m_threadPoolWork); 32 | { 33 | const auto lock = m_workQueueLock.lock_exclusive(); 34 | m_workQueue.emplace_back(std::forward(workItem)); 35 | } 36 | ::SubmitThreadpoolWork(m_threadPoolWork.get()); 37 | } 38 | /// Drop all pending work items and wait for the running one(s) to complete. The queue cannot be used afterwards 39 | void Cancel() noexcept 40 | { 41 | if (m_threadPoolWork) 42 | { 43 | // Drop all pending work items: their callbacks will find the queue empty and return 44 | { 45 | const auto lock = m_workQueueLock.lock_exclusive(); 46 | m_workQueue.clear(); 47 | } 48 | // Wait for the running work items to complete
(wil::unique_threadpool_work waits for outstanding callbacks) and close the threadpool work object 49 | m_threadPoolWork.reset(); 50 | } 51 | } 52 | 53 | WorkQueue(const WorkQueue&) = delete; 54 | WorkQueue& operator=(const WorkQueue&) = delete; 55 | WorkQueue(WorkQueue&&) = delete; 56 | WorkQueue& operator=(WorkQueue&&) = delete; 57 | 58 | private: 59 | // Private threadpool and its callback environment, limited to [countMinThread, countMaxThread] threads 60 | struct ThreadPool 61 | { 62 | using unique_tp_env = wil::unique_struct; 63 | unique_tp_env m_threadpoolEnv; 64 | 65 | using unique_tp_pool = wil::unique_any; 66 | unique_tp_pool m_threadPool; 67 | 68 | ThreadPool(DWORD countMinThread, DWORD countMaxThread) 69 | { 70 | ::InitializeThreadpoolEnvironment(&m_threadpoolEnv); 71 | 72 | m_threadPool.reset(::CreateThreadpool(nullptr)); 73 | THROW_LAST_ERROR_IF_NULL(m_threadPool.get()); 74 | 75 | // Set min and max thread counts for custom thread pool 76 | const auto res = ::SetThreadpoolThreadMinimum(m_threadPool.get(), countMinThread); 77 | THROW_LAST_ERROR_IF(!res); 78 | ::SetThreadpoolThreadMaximum(m_threadPool.get(), countMaxThread); 79 | ::SetThreadpoolCallbackPool(&m_threadpoolEnv, m_threadPool.get()); 80 | } 81 | 82 | wil::unique_threadpool_work CreateWork(PTP_WORK_CALLBACK callback, void* context) 83 | { 84 | wil::unique_threadpool_work work(::CreateThreadpoolWork(callback, context, &m_threadpoolEnv)); 85 | THROW_LAST_ERROR_IF_NULL(work.get()); 86 | return work; 87 | } 88 | // NOTE(review): not called within WorkQueue; the members are released by their destructors anyway 89 | void Reset() noexcept 90 | { 91 | m_threadPool.reset(); 92 | m_threadpoolEnv.reset(); 93 | } 94 | }; 95 | // Threadpool callback (one per Submit): runs the oldest pending work item outside the lock, or returns if Cancel emptied the queue 96 | static void CALLBACK WorkCallback(_Inout_ PTP_CALLBACK_INSTANCE, _In_ void* context, _Inout_ PTP_WORK) noexcept 97 | { 98 | auto* pThis = static_cast(context); 99 | std::optional 
workItem; 100 | { 101 | const auto lock = pThis->m_workQueueLock.lock_exclusive(); 102 | if (pThis->m_workQueue.empty()) 103 | { 104 | // `Cancel` has been called, the queue was cleared 105 | return; 106 | } 107 | 108 | workItem.emplace(std::move(pThis->m_workQueue.front())); 109 | pThis->m_workQueue.pop_front(); 110 | } 111 | 112 | (*workItem)(); 113 | } 114 | 115 |
ThreadPool m_threadPool; 116 | wil::unique_threadpool_work m_threadPoolWork; 117 | mutable wil::srwlock m_workQueueLock; 118 | std::deque m_workQueue; 119 | }; 120 | 121 | /// @brief Helper class to run tasks in a work queue (one at a time with SerializedWorkRunner) 122 | /// It handles return values and allows waiting for the task completion 123 | template 124 | class WorkRunner 125 | { 126 | 127 | public: 128 | /// @brief Wait for the currently running task and cancel all others 129 | void Cancel() 130 | { 131 | m_workQueue.Cancel(); 132 | } 133 | 134 | /// @brief Helper to execute a task in the workqueue without waiting for its completion 135 | /// (the task is still serialized with other tasks in a serialized workqueue). Its return value and any exception it throws are discarded; the callable may have a non-const call operator 136 | template, int> = 1> 137 | void Run(F fun) 138 | { 139 | auto t = std::packaged_task{[fun = std::move(fun)]() mutable { 140 | fun(); 141 | return std::any{}; 142 | }}; 143 | m_workQueue.Submit(std::move(t)); 144 | } 145 | 146 | /// @brief Helper to execute a task in the workqueue and wait for its completion. Returns the task's result; an exception thrown by the task is rethrown to the caller (std::future_error with broken_promise if Cancel dropped the task before it ran) 147 | template, int> = 1> 148 | decltype(auto) RunAndWait(F&& fun) 149 | { 150 | using RetType = decltype(fun()); 151 | if constexpr (std::is_void_v) 152 | { 153 | // Special handling for void return type 154 | auto task = std::packaged_task{[&fun] { 155 | fun(); 156 | return std::any{}; 157 | }}; 158 | auto future_answer = task.get_future(); 159 | // get() rather than wait(): like the non-void branch, it rethrows an exception thrown by the task instead of silently dropping it 160 | m_workQueue.Submit(std::move(task)); 161 | future_answer.get(); 162 | return; 163 | } 164 | else 165 | { 166 | // Capture by reference is ok since we wait for the result right after 167 | auto task = 
std::packaged_task{[&fun] { return std::make_any(fun()); }}; 168 | auto future_answer = task.get_future(); 169 | m_workQueue.Submit(std::move(task)); 170 | return std::any_cast(future_answer.get()); 171 | } 172 | } 173 | 174 | private: 175 | WorkQueue, minThread, maxThread> m_workQueue; 176 | }; 177 | /// Work queue / runner executing one work item at a time, in submission order 178 | template 179 | using SerializedWorkQueue = WorkQueue; 180 | using SerializedWorkRunner = WorkRunner<1, 1>; 181 |
--------------------------------------------------------------------------------