├── .gitignore ├── AsgiHandler ├── AsgiHandler.def ├── CMakeLists.txt └── AsgiHandler.cpp ├── ProcessPool ├── ProcessPool.def ├── CMakeLists.txt ├── ProcessPool.cpp └── process-pool-iis-schema.xml ├── dependencies ├── CMakeLists.txt ├── msgpack-c │ └── CMakeLists.txt ├── googletest │ └── CMakeLists.txt └── hiredis │ └── CMakeLists.txt ├── cli ├── main.cpp ├── CMakeLists.txt └── cli.vcxproj ├── AsgiHandlerTest ├── main.cpp ├── mock_ChannelLayer.h ├── mock_ResponsePump.h ├── mock_HttpRequestHandler.h ├── test_RedisChannelLayer.cpp ├── CMakeLists.txt ├── test_SendToApplicationStep.cpp ├── test_FlushResponseStep.cpp ├── test_WriteResponseStep.cpp ├── test_ReadBodyStep.cpp ├── test_WaitForResponseStep.cpp └── mock_httpserv.h ├── IntegrationTests ├── conftest.py ├── test_config.py ├── requirements.txt ├── fixtures │ ├── __init__.py │ ├── asgi.py │ ├── django_worker.py │ ├── worker.py │ ├── requests.py │ ├── etw.py │ └── iis.py ├── CMakeLists.txt ├── test_pool.py ├── test_django_http.py ├── test_asgi_ws.py └── test_asgi_http.py ├── ProcessPoolTest ├── main.cpp ├── CMakeLists.txt └── test_ProcessPool.cpp ├── SharedUtilsTest ├── main.cpp ├── CMakeLists.txt └── test_Logger.cpp ├── AsgiHandlerLib ├── AsgiWsSendMsg.h ├── HttpRequestHandler.h ├── HttpModuleFactory.h ├── WsReader.h ├── AsgiMsg.h ├── ResponsePump.h ├── ScopedRedisReply.h ├── HttpModule.h ├── AsgiHttpResponseMsg.h ├── WsWriter.h ├── HttpModuleFactory.cpp ├── ChannelLayer.h ├── RedisChannelLayer.h ├── WsRequestHandlerSteps.h ├── AsgiWsConnectMsg.h ├── WsRequestHandler.h ├── CMakeLists.txt ├── AsgiWsReceiveMsg.h ├── AsgiHttpRequestMsg.h ├── RequestHandler.h ├── HttpModule.cpp ├── HttpRequestHandler.cpp ├── WsRequestHandler.cpp ├── WsRequestHandlerSteps.cpp ├── ResponsePump.cpp ├── HttpRequestHandlerSteps.h ├── WsReader.cpp ├── RedisChannelLayer.cpp ├── WsWriter.cpp ├── RequestHandler.cpp ├── HttpRequestHandlerSteps.cpp └── RedisAsgiHandlerLib.vcxproj ├── ProcessPoolLib ├── JobObject.h ├── CMakeLists.txt ├── 
JobObject.cpp ├── GlobalModule.h ├── ProcessPool.h ├── GlobalModule.cpp └── ProcessPool.cpp ├── SharedUtils ├── mock_Logger.h ├── CMakeLists.txt ├── Logger.cpp ├── Logger.h └── ScopedResources.h ├── appveyor.yml ├── CMakeLists.txt └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | */__pycache__/ 3 | 4 | /build/ 5 | -------------------------------------------------------------------------------- /AsgiHandler/AsgiHandler.def: -------------------------------------------------------------------------------- 1 | LIBRARY AsgiHandler 2 | EXPORTS 3 | RegisterModule -------------------------------------------------------------------------------- /ProcessPool/ProcessPool.def: -------------------------------------------------------------------------------- 1 | LIBRARY ProcessPool 2 | EXPORTS 3 | RegisterModule -------------------------------------------------------------------------------- /dependencies/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | include(ExternalProject) 2 | 3 | add_subdirectory(googletest) 4 | add_subdirectory(hiredis) 5 | add_subdirectory(msgpack-c) 6 | -------------------------------------------------------------------------------- /cli/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "ProcessPool.h" 6 | 7 | 8 | void main() 9 | { 10 | std::cin.get(); 11 | } 12 | -------------------------------------------------------------------------------- /AsgiHandlerTest/main.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | 3 | 4 | int main(int argc, char **argv) 5 | { 6 | ::testing::InitGoogleTest(&argc, argv); 7 | return RUN_ALL_TESTS(); 8 | } 9 | -------------------------------------------------------------------------------- /IntegrationTests/conftest.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | from __future__ import unicode_literals, print_function, absolute_import 5 | 6 | from fixtures import * 7 | -------------------------------------------------------------------------------- /ProcessPoolTest/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | int main(int argc, char **argv) 5 | { 6 | ::testing::InitGoogleTest(&argc, argv); 7 | return RUN_ALL_TESTS(); 8 | } 9 | -------------------------------------------------------------------------------- /SharedUtilsTest/main.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | 3 | 4 | int main(int argc, char **argv) 5 | { 6 | ::testing::InitGoogleTest(&argc, argv); 7 | return RUN_ALL_TESTS(); 8 | } 9 | -------------------------------------------------------------------------------- /cli/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # cli.exe -- Just for testing out ideas. 2 | 3 | 4 | add_executable(cli 5 | main.cpp 6 | ) 7 | 8 | target_link_libraries(cli 9 | AsgiHandlerLib 10 | ProcessPoolLib 11 | ) -------------------------------------------------------------------------------- /AsgiHandler/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # AsgiHandler.dll - IIS Request Handler module that allows it to behave as an ASGI interface server. 
2 | 3 | 4 | add_library(AsgiHandler SHARED 5 | AsgiHandler.cpp 6 | AsgiHandler.def 7 | ) 8 | 9 | target_link_libraries(AsgiHandler 10 | AsgiHandlerLib 11 | ) 12 | -------------------------------------------------------------------------------- /AsgiHandlerLib/AsgiWsSendMsg.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | 6 | class AsgiWsSendMsg 7 | { 8 | public: 9 | std::vector bytes; 10 | bool close{false}; 11 | // TODO: Custom unpacker so that we can unpack `text` into `bytes`? 12 | 13 | MSGPACK_DEFINE_MAP(bytes, close); 14 | }; 15 | -------------------------------------------------------------------------------- /ProcessPool/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # ProcessPool.dll - IIS module that creates a pool of processes with the same lifetime as the application pool. 2 | 3 | 4 | add_library(ProcessPool SHARED 5 | ProcessPool.cpp 6 | ProcessPool.def 7 | 8 | process-pool-iis-schema.xml 9 | ) 10 | 11 | target_link_libraries(ProcessPool 12 | ProcessPoolLib 13 | ) 14 | -------------------------------------------------------------------------------- /IntegrationTests/test_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | from __future__ import unicode_literals, print_function 5 | 6 | import pytest 7 | 8 | 9 | def test_doesnt_interfere_with_other_handlers(site, asgi, session): 10 | future = session.get(site.url + site.static_path) 11 | asgi.assert_empty() 12 | resp = future.result(timeout=2) 13 | assert resp.text == 'Hello, world!' 
14 | asgi.assert_empty() 15 | -------------------------------------------------------------------------------- /AsgiHandler/AsgiHandler.cpp: -------------------------------------------------------------------------------- 1 | #define WIN32_LEAN_AND_MEAN 2 | #include 3 | 4 | #include "HttpModuleFactory.h" 5 | 6 | 7 | HRESULT __stdcall RegisterModule( 8 | DWORD iisVersion, IHttpModuleRegistrationInfo* moduleInfo, IHttpServer* httpServer) 9 | { 10 | auto factory = new HttpModuleFactory(moduleInfo->GetId()); 11 | HRESULT hr = moduleInfo->SetRequestNotifications( 12 | factory, RQ_EXECUTE_REQUEST_HANDLER, 0 13 | ); 14 | 15 | return S_OK; 16 | } -------------------------------------------------------------------------------- /ProcessPool/ProcessPool.cpp: -------------------------------------------------------------------------------- 1 | #define WIN32_LEAN_AND_MEAN 2 | #include 3 | 4 | #include "GlobalModule.h" 5 | 6 | 7 | HRESULT __stdcall RegisterModule( 8 | DWORD iis_version, IHttpModuleRegistrationInfo* module_info, IHttpServer* http_server 9 | ) 10 | { 11 | auto module = new GlobalModule(http_server); 12 | auto hr = module_info->SetGlobalNotifications( 13 | module, 14 | GL_APPLICATION_START | GL_APPLICATION_STOP 15 | ); 16 | return hr; 17 | } -------------------------------------------------------------------------------- /IntegrationTests/requirements.txt: -------------------------------------------------------------------------------- 1 | asgi-redis==0.10.0 2 | asgiref==0.11.2 3 | colorama==0.3.7 4 | msgpack-python==0.4.7 5 | py==1.4.31 6 | pytest==2.9.1 7 | redis==2.10.5 8 | requests==2.9.1 9 | six==1.10.0 10 | requests-futures==0.9.7 11 | futures==3.0.5 12 | pypiwin32==219 13 | websocket-client==0.37.0 14 | Django==1.10 15 | channels==0.17.2 16 | daphne==0.14.3 17 | twisted==16.2.0 18 | psutil==4.3.0 19 | 20 | -e git+https://github.com/sebmarchand/pyetw.git@302431b0eecc7698f1b7641ed8cf39f8769beb4b#egg=ETW 21 | 
-------------------------------------------------------------------------------- /ProcessPoolLib/JobObject.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ScopedResources.h" 4 | 5 | 6 | class JobObject 7 | { 8 | public: 9 | JobObject(); 10 | JobObject(JobObject&& other) = default; 11 | JobObject(JobObject&) = delete; 12 | 13 | void operator=(JobObject) = delete; 14 | JobObject& operator=(JobObject&&) = default; 15 | 16 | HANDLE GetHandle() const { return m_handle.Get(); } 17 | 18 | protected: 19 | void Create(); 20 | 21 | private: 22 | ScopedHandle m_handle; 23 | }; 24 | -------------------------------------------------------------------------------- /AsgiHandlerTest/mock_ChannelLayer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "gmock/gmock.h" 4 | 5 | #include "ChannelLayer.h" 6 | 7 | 8 | class AsgiHttpRequestMsg; 9 | 10 | 11 | class MockChannelLayer : public ChannelLayer 12 | { 13 | public: 14 | MOCK_METHOD2(Send, void(const std::string& channel, const msgpack::sbuffer& msg)); 15 | MOCK_METHOD2(ReceiveMany, 16 | std::tuple( 17 | const std::vector& channels, bool blocking 18 | ) 19 | ); 20 | }; 21 | -------------------------------------------------------------------------------- /SharedUtils/mock_Logger.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "gmock/gmock.h" 4 | 5 | #include "Logger.h" 6 | 7 | 8 | class MockLogger : public Logger 9 | { 10 | public: 11 | MOCK_CONST_METHOD1(Log, void(const std::string& msg)); 12 | 13 | // The default GetStream() and level() methods are fine. The LogStream class doesn't do 14 | // anything but call bak to Log(). By not mocking them, we can test them. 
15 | // MOCK_CONST_METHOD1(debug, LoggerStream()); 16 | // MOCK_CONST_METHOD1(GetStream, LoggerStream()); 17 | }; 18 | -------------------------------------------------------------------------------- /ProcessPool/process-pool-iis-schema.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /AsgiHandlerTest/mock_ResponsePump.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "gmock/gmock.h" 4 | 5 | #include "mock_Logger.h" 6 | #include "mock_ChannelLayer.h" 7 | #include "ResponsePump.h" 8 | 9 | 10 | class MockResponsePump : public ResponsePump 11 | { 12 | public: 13 | MockResponsePump() 14 | : ResponsePump(MockLogger(), MockChannelLayer()) 15 | { } 16 | 17 | MOCK_METHOD2(AddChannel, void(const std::string& channel, const ResponseChannelCallback& callback)); 18 | MOCK_METHOD1(RemoveChannel, void(const std::string& channel)); 19 | }; 20 | -------------------------------------------------------------------------------- /AsgiHandlerTest/mock_HttpRequestHandler.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "gmock/gmock.h" 4 | 5 | #include "mock_Logger.h" 6 | #include "mock_ChannelLayer.h" 7 | #include "HttpRequestHandler.h" 8 | 9 | 10 | using ::testing::NiceMock; 11 | 12 | 13 | class MockHttpRequestHandler : public HttpRequestHandler 14 | { 15 | public: 16 | MockHttpRequestHandler(MockResponsePump& response_pump, MockIHttpContext* http_context) 17 | : HttpRequestHandler(response_pump, channels, logger, http_context) 18 | { } 19 | 20 | MockChannelLayer channels; 21 | NiceMock logger; 22 | }; 23 | -------------------------------------------------------------------------------- /AsgiHandlerLib/HttpRequestHandler.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "RequestHandler.h" 4 | #include "HttpRequestHandlerSteps.h" 5 | 6 | 7 | class HttpRequestHandler : public RequestHandler 8 | { 9 | public: 10 | using RequestHandler::RequestHandler; 11 | 12 | virtual REQUEST_NOTIFICATION_STATUS OnExecuteRequestHandler(); 13 | virtual REQUEST_NOTIFICATION_STATUS OnAsyncCompletion(IHttpCompletionInfo* completion_info); 14 | 15 | protected: 16 | bool ReturnError(USHORT status = 500, const std::string& reason = ""); 17 | 18 | std::unique_ptr m_current_step; 19 | }; 20 | -------------------------------------------------------------------------------- /SharedUtils/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # SharedUtils.lib - Static library of shared utilities for ProcessPool.lib and AsgiHandler.dll 2 | 3 | 4 | set(SOURCES 5 | Logger.cpp 6 | ) 7 | 8 | 9 | set(HEADERS 10 | Logger.h 11 | ScopedResources.h 12 | 13 | # This is here because it's used by multiple *Test projects. 
14 | mock_Logger.h 15 | ) 16 | 17 | 18 | add_library(SharedUtils STATIC 19 | ${SOURCES} 20 | ${HEADERS} 21 | ) 22 | 23 | target_link_libraries(SharedUtils 24 | ) 25 | 26 | target_include_directories(SharedUtils PUBLIC 27 | $ 28 | $ 29 | ) 30 | 31 | -------------------------------------------------------------------------------- /SharedUtilsTest/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # SharedUtilsTest.exe - Unit tests for SharedUtils.lib 2 | 3 | 4 | set(SOURCES 5 | main.cpp 6 | 7 | test_Logger.cpp 8 | ) 9 | 10 | 11 | set(HEADERS 12 | ) 13 | 14 | 15 | add_executable(SharedUtilsTest 16 | ${SOURCES} 17 | ${HEADERS} 18 | ) 19 | 20 | target_link_libraries(SharedUtilsTest 21 | AsgiHandlerLib 22 | googletest 23 | googlemock 24 | ) 25 | 26 | 27 | source_group("Header Files\\Mocks" REGULAR_EXPRESSION /mock_[^/]*.h$) 28 | source_group("Source Files\\Tests" REGULAR_EXPRESSION /test_[^/]*.cpp$) 29 | 30 | 31 | add_test(NAME SharedUtilsTest 32 | COMMAND $ 33 | ) 34 | -------------------------------------------------------------------------------- /ProcessPoolLib/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # ProcessPoolLib.lib - For use in ProcessPool.dll, but compiled as a lib so that it can also be used by tests. 
2 | 3 | set(SOURCES 4 | GlobalModule.cpp 5 | 6 | ProcessPool.cpp 7 | JobObject.cpp 8 | ) 9 | 10 | 11 | set(HEADERS 12 | GlobalModule.h 13 | 14 | ProcessPool.h 15 | JobObject.h 16 | ) 17 | 18 | 19 | add_library(ProcessPoolLib STATIC 20 | ${SOURCES} 21 | ${HEADERS} 22 | ) 23 | 24 | target_link_libraries(ProcessPoolLib 25 | SharedUtils 26 | ) 27 | 28 | target_include_directories(ProcessPoolLib PUBLIC 29 | $ 30 | $ 31 | ) 32 | 33 | -------------------------------------------------------------------------------- /ProcessPoolTest/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # ProcessPoolTest.exe - Unit tests for ProcessPoolLib.lib 2 | 3 | 4 | set(SOURCES 5 | main.cpp 6 | 7 | test_ProcessPool.cpp 8 | ) 9 | 10 | 11 | set(HEADERS 12 | ) 13 | 14 | 15 | add_executable(ProcessPoolTest 16 | ${SOURCES} 17 | ${HEADERS} 18 | ) 19 | 20 | target_link_libraries(ProcessPoolTest 21 | ProcessPoolLib 22 | googletest 23 | googlemock 24 | ) 25 | 26 | 27 | source_group("Header Files\\Mocks" REGULAR_EXPRESSION /mock_[^/]*.h$) 28 | source_group("Source Files\\Tests" REGULAR_EXPRESSION /test_[^/]*.cpp$) 29 | 30 | 31 | add_test(NAME ProcessPoolTest 32 | COMMAND $ 33 | ) 34 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | version: 0.1.0.{build} 2 | 3 | platform: 4 | - x64 5 | - Win32 6 | 7 | configuration: 8 | - Release 9 | 10 | services: 11 | - iis 12 | 13 | install: 14 | - choco install redis-64 15 | - redis-server --service-install C:\ProgramData\chocolatey\lib\redis-64\redis.windows.conf 16 | - redis-server --service-start 17 | 18 | build_script: 19 | - set GENERATOR=Visual Studio 14 2015 20 | - if %PLATFORM% == x64 set GENERATOR=%GENERATOR% Win64 21 | - mkdir build 22 | - cd build 23 | - cmake .. -G "%GENERATOR%" 24 | - cmake --build . --config %CONFIGURATION% 25 | - cd .. 
26 | 27 | test_script: 28 | - cd build 29 | - ctest -C %CONFIGURATION% --output-on-failure 30 | -------------------------------------------------------------------------------- /IntegrationTests/fixtures/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | from __future__ import unicode_literals, print_function, absolute_import 5 | 6 | import pytest 7 | 8 | from .asgi import asgi 9 | from .django_worker import django_worker 10 | from .etw import * 11 | from .iis import * 12 | from .requests import * 13 | 14 | 15 | # Taken from pytest docs - makes report available in fixture. 16 | @pytest.hookimpl(tryfirst=True, hookwrapper=True) 17 | def pytest_runtest_makereport(item, call): 18 | # execute all other hooks to obtain the report object 19 | outcome = yield 20 | if call.when == 'call': 21 | rep = outcome.get_result() 22 | setattr(item, "_report", rep) 23 | -------------------------------------------------------------------------------- /SharedUtilsTest/test_Logger.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "gmock/gmock.h" 3 | 4 | #include "mock_Logger.h" 5 | 6 | using ::testing::DoAll; 7 | using ::testing::Return; 8 | using ::testing::SaveArg; 9 | using ::testing::_; 10 | 11 | TEST(LoggerTest, Basic) 12 | { 13 | // We're using MockLogger, but this just mocks out the calls to ETW (mostly). 14 | // We can still use this to test the streams, which is what we're doing here. 
15 | MockLogger logger; 16 | 17 | std::string output; 18 | EXPECT_CALL(logger, Log(_)) 19 | .WillOnce(SaveArg<0>(&output)); 20 | 21 | logger.debug() << "This is a test " << 1 << 2 << 3; 22 | 23 | EXPECT_EQ("This is a test 123", output); 24 | } 25 | -------------------------------------------------------------------------------- /ProcessPoolLib/JobObject.cpp: -------------------------------------------------------------------------------- 1 | #include "JobObject.h" 2 | 3 | #define WIN32_LEAN_AND_MEAN 4 | #include 5 | 6 | 7 | JobObject::JobObject() 8 | { 9 | Create(); 10 | } 11 | 12 | 13 | void JobObject::Create() 14 | { 15 | m_handle = ScopedHandle{ ::CreateJobObject(nullptr, nullptr) }; 16 | 17 | // Tell the Job to terminate all of its processes when the handle 18 | // to it is closed. 19 | auto job_info = JOBOBJECT_EXTENDED_LIMIT_INFORMATION{ 0 }; 20 | job_info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; 21 | ::SetInformationJobObject( 22 | m_handle.Get(), JobObjectExtendedLimitInformation, 23 | &job_info, sizeof(job_info) 24 | ); 25 | } 26 | -------------------------------------------------------------------------------- /AsgiHandlerLib/HttpModuleFactory.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #define WIN32_LEAN_AND_MEAN 7 | #include 8 | 9 | #include "Logger.h" 10 | #include "RedisChannelLayer.h" 11 | #include "ResponsePump.h" 12 | 13 | 14 | class HttpModuleFactory : public IHttpModuleFactory 15 | { 16 | public: 17 | HttpModuleFactory(const HTTP_MODULE_ID& module_id); 18 | virtual ~HttpModuleFactory(); 19 | 20 | virtual HRESULT GetHttpModule(OUT CHttpModule** module, IN IModuleAllocator*); 21 | virtual void Terminate(); 22 | 23 | 24 | const HTTP_MODULE_ID& module_id() const { return m_module_id; } 25 | 26 | 27 | private: 28 | HTTP_MODULE_ID m_module_id; 29 | const EtwLogger logger; // must be declared before other members that rely on it. 
30 | RedisChannelLayer m_channels; 31 | ResponsePump m_response_pump; 32 | }; 33 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8) 2 | 3 | project(IisAsgiHandler) 4 | 5 | enable_testing() 6 | set_property(GLOBAL PROPERTY USE_FOLDERS On) 7 | 8 | 9 | set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT") 10 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MTd") 11 | 12 | 13 | # Check whether we're building for x64 or x86 14 | if("${CMAKE_SIZEOF_VOID_P}" EQUAL "8") 15 | set(TARGET_PLATFORM "x64") 16 | else() 17 | set(TARGET_PLATFORM "x86") 18 | endif() 19 | 20 | 21 | add_subdirectory(dependencies) 22 | add_subdirectory(cli) 23 | add_subdirectory(SharedUtils) 24 | add_subdirectory(SharedUtilsTest) 25 | add_subdirectory(AsgiHandlerLib) 26 | add_subdirectory(AsgiHandlerTest) 27 | add_subdirectory(AsgiHandler) 28 | add_subdirectory(ProcessPoolLib) 29 | add_subdirectory(ProcessPoolTest) 30 | add_subdirectory(ProcessPool) 31 | add_subdirectory(IntegrationTests) 32 | 33 | -------------------------------------------------------------------------------- /AsgiHandlerLib/WsReader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define WIN32_LEAN_AND_MEAN 4 | #include 5 | 6 | #include "AsgiWsReceiveMsg.h" 7 | #include "Logger.h" 8 | 9 | 10 | class WsRequestHandler; 11 | class ChannelLayer; 12 | 13 | 14 | class WsReader 15 | { 16 | public: 17 | WsReader(WsRequestHandler& handler); 18 | 19 | void Start(const std::string& reply_channel, const std::string& request_path); 20 | 21 | private: 22 | void ReadAsync(); 23 | void ReadAsyncComplete(HRESULT hr, DWORD num_bytes, BOOL utf8, BOOL final_fragment, BOOL close); 24 | void SendToApplication(); 25 | 26 | const Logger& logger; 27 | ChannelLayer& m_channels; 28 | IHttpContext* m_http_context; 29 | 
IWebSocketContext* m_ws_context; 30 | AsgiWsReceiveMsg m_msg; 31 | 32 | static void WINAPI ReadCallback(HRESULT hr, VOID *context, DWORD num_bytes, BOOL utf8, BOOL final_fragment, BOOL close); 33 | }; 34 | -------------------------------------------------------------------------------- /AsgiHandlerLib/AsgiMsg.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | 6 | // Base class for ASGI messages that provides some helper functions for packing. 7 | // We need to use these in custom packer functions (rather than relying on MSGPACK_MAP 8 | // and friends) because we need control over whether std::string/std::vector are 9 | // written as binary or text, so that asgi_redis/channels decodes them correctly. 10 | class AsgiMsg 11 | { 12 | protected: 13 | template 14 | static void pack_string(Packer& packer, const std::string& str) 15 | { 16 | packer.pack_str(str.length()); 17 | packer.pack_str_body(str.c_str(), str.length()); 18 | } 19 | 20 | template 21 | static void pack_bytestring(Packer& packer, const std::string& str) 22 | { 23 | packer.pack_bin(str.length()); 24 | packer.pack_bin_body(str.c_str(), str.length()); 25 | } 26 | }; 27 | -------------------------------------------------------------------------------- /ProcessPoolLib/GlobalModule.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #define WIN32_LEAN_AND_MEAN 7 | #include 8 | 9 | #include "ProcessPool.h" 10 | #include "Logger.h" 11 | 12 | 13 | class GlobalModule : public CGlobalModule 14 | { 15 | public: 16 | GlobalModule(IHttpServer *http_server); 17 | 18 | virtual void Terminate() { delete this; } 19 | 20 | virtual GLOBAL_NOTIFICATION_STATUS OnGlobalApplicationStart( 21 | IHttpApplicationStartProvider* provider 22 | ); 23 | virtual GLOBAL_NOTIFICATION_STATUS OnGlobalApplicationStop( 24 | IHttpApplicationStopProvider* provider 25 | ); 26 | 27 | 
protected: 28 | void LoadConfiguration(IHttpApplication *application); 29 | std::wstring GetProperty(IAppHostElement *element, const std::wstring& name); 30 | 31 | private: 32 | const EtwLogger logger; 33 | IHttpServer *m_http_server; 34 | std::vector> m_pools; 35 | }; 36 | -------------------------------------------------------------------------------- /AsgiHandlerLib/ResponsePump.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "Logger.h" 8 | 9 | 10 | class ChannelLayer; 11 | 12 | 13 | class ResponsePump 14 | { 15 | public: 16 | ResponsePump(const Logger& logger, ChannelLayer& channels); 17 | ~ResponsePump(); 18 | 19 | void Start(); 20 | 21 | // TODO: Figure out exactly what this type should look like! 22 | using ResponseChannelCallback = std::function; 23 | 24 | virtual void AddChannel(const std::string& channel, const ResponseChannelCallback& callback); 25 | virtual void RemoveChannel(const std::string& channel); 26 | 27 | private: 28 | void ThreadMain(); 29 | 30 | 31 | const Logger& logger; 32 | ChannelLayer& m_channels; 33 | std::thread m_thread; 34 | bool m_thread_stop; 35 | std::unordered_map m_callbacks; 36 | std::mutex m_callbacks_mutex; 37 | }; 38 | 39 | -------------------------------------------------------------------------------- /AsgiHandlerLib/ScopedRedisReply.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | 6 | class ScopedRedisReply 7 | { 8 | public: 9 | explicit ScopedRedisReply(redisReply *reply) 10 | : m_reply(nullptr) 11 | { 12 | Set(reply); 13 | } 14 | 15 | ScopedRedisReply(ScopedRedisReply&& other) 16 | { 17 | m_reply = other.m_reply; 18 | } 19 | 20 | ScopedRedisReply(const ScopedRedisReply&) = delete; 21 | 22 | 23 | ~ScopedRedisReply() 24 | { 25 | Close(); 26 | } 27 | 28 | redisReply* operator->() 29 | { 30 | return m_reply; 31 | } 32 | 33 | void 
Set(redisReply *reply) 34 | { 35 | if (m_reply != reply) { 36 | Close(); 37 | m_reply = reply; 38 | } 39 | } 40 | 41 | void Close() 42 | { 43 | if (m_reply != nullptr) { 44 | freeReplyObject(m_reply); 45 | m_reply = nullptr; 46 | } 47 | 48 | } 49 | private: 50 | redisReply *m_reply; 51 | }; 52 | -------------------------------------------------------------------------------- /AsgiHandlerLib/HttpModule.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define WIN32_LEAN_AND_MEAN 4 | #include 5 | 6 | #include "HttpModuleFactory.h" 7 | #include "ResponsePump.h" 8 | #include "Logger.h" 9 | 10 | 11 | class HttpModule : public CHttpModule 12 | { 13 | public: 14 | HttpModule( 15 | const HttpModuleFactory& factory, ResponsePump& response_pump, 16 | const Logger& logger 17 | ); 18 | ~HttpModule(); 19 | 20 | virtual REQUEST_NOTIFICATION_STATUS OnExecuteRequestHandler( 21 | IHttpContext* httpContext, IHttpEventProvider* provider 22 | ); 23 | 24 | virtual REQUEST_NOTIFICATION_STATUS OnAsyncCompletion( 25 | IHttpContext* http_context, 26 | DWORD notification, 27 | BOOL post_notification, 28 | IHttpEventProvider* provider, 29 | IHttpCompletionInfo* completion_info 30 | ); 31 | 32 | private: 33 | const HttpModuleFactory& m_factory; 34 | ResponsePump& m_response_pump; 35 | RedisChannelLayer m_channels; 36 | const Logger& logger; 37 | }; 38 | -------------------------------------------------------------------------------- /AsgiHandlerTest/test_RedisChannelLayer.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "gmock/gmock.h" 3 | 4 | #include "mock_ChannelLayer.h" 5 | 6 | 7 | // We use MockChannelLayer, which mocks the abstract methods. It leaves the 8 | // NewChannel() function untouched, which is what we're testing here. 
9 | 10 | TEST(ChannelLayerTest, NewChannelGivesUniqueName) 11 | { 12 | MockChannelLayer channels1, channels2; 13 | 14 | // Different instances give unique names. 15 | EXPECT_NE(channels1.NewChannel("prefix"), channels2.NewChannel("prefix")); 16 | 17 | // Same instance gives a unique name whenever it is called. 18 | EXPECT_NE(channels1.NewChannel("prefix"), channels1.NewChannel("prefix")); 19 | } 20 | 21 | 22 | TEST(ChannelLayerTest, NewChannelRespectsPrefix) 23 | { 24 | MockChannelLayer channels; 25 | 26 | auto prefix = std::string{"myprefix"}; 27 | auto channel_name = channels.NewChannel(prefix); 28 | auto actual_prefix = channel_name.substr(0, prefix.length()); 29 | 30 | EXPECT_EQ(prefix, actual_prefix); 31 | } 32 | -------------------------------------------------------------------------------- /AsgiHandlerTest/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # AsgiHandlerTest.exe - Unit tests for AsgiHandlerLib.lib 2 | 3 | 4 | set(SOURCES 5 | main.cpp 6 | 7 | test_ReadBodyStep.cpp 8 | test_WaitForResponseStep.cpp 9 | test_FlushResponseStep.cpp 10 | test_RedisChannelLayer.cpp 11 | test_WriteResponseStep.cpp 12 | test_SendToApplicationStep.cpp 13 | ) 14 | 15 | 16 | set(HEADERS 17 | mock_httpserv.h 18 | mock_ResponsePump.h 19 | mock_HttpRequestHandler.h 20 | mock_ChannelLayer.h 21 | ) 22 | 23 | 24 | add_executable(AsgiHandlerTest 25 | ${SOURCES} 26 | ${HEADERS} 27 | ) 28 | 29 | target_link_libraries(AsgiHandlerTest 30 | AsgiHandlerLib 31 | googletest 32 | googlemock 33 | ) 34 | 35 | 36 | source_group("Header Files\\Mocks" REGULAR_EXPRESSION /mock_[^/]*.h$) 37 | source_group("Source Files\\Tests" REGULAR_EXPRESSION /test_[^/]*.cpp$) 38 | source_group("Source Files\\Tests\\Steps" REGULAR_EXPRESSION /test_[^/]*Step.cpp$) 39 | 40 | 41 | add_test(NAME AsgiHandlerTest 42 | COMMAND $ 43 | ) 44 | -------------------------------------------------------------------------------- /AsgiHandlerLib/AsgiHttpResponseMsg.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | 6 | /* 7 | status: Integer HTTP status code. 8 | headers: A list of [name, value] pairs, where name is the byte string header name, 9 | and value is the byte string header value. Order should be preserved in 10 | the HTTP response. Header names must be lowercased. 11 | content: Byte string of HTTP body content. Optional, defaults to empty string. 12 | more_content: Boolean value signifying if there is additional content to come (as 13 | part of a Response Chunk message). If False, response will be taken 14 | as complete and closed off, and any further messages on the channel 15 | will be ignored. Optional, defaults to False. 16 | */ 17 | 18 | class AsgiHttpResponseMsg 19 | { 20 | public: 21 | int status{0}; 22 | std::vector> headers; 23 | std::string content; 24 | bool more_content{false}; 25 | 26 | MSGPACK_DEFINE_MAP(status, headers, content, more_content); 27 | }; 28 | -------------------------------------------------------------------------------- /AsgiHandlerLib/WsWriter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define WIN32_LEAN_AND_MEAN 4 | #include 5 | #include 6 | 7 | #include "AsgiWsSendMsg.h" 8 | #include "ResponsePump.h" 9 | #include "Logger.h" 10 | 11 | 12 | class WsRequestHandler; 13 | 14 | class WsWriter 15 | { 16 | public: 17 | WsWriter(WsRequestHandler& handler); 18 | 19 | void Start(const std::string& reply_channel); 20 | 21 | private: 22 | void RegisterChannelsCallback(); 23 | void WriteAsync(); 24 | // Returns whether there's more of the current message to write: 25 | bool WriteAsyncComplete(HRESULT hr, DWORD num_bytes, BOOL utf8, BOOL final_fragment, BOOL close); 26 | 27 | const Logger& logger; 28 | ResponsePump& m_response_pump; 29 | IHttpContext* m_http_context; 30 | IWebSocketContext* m_ws_context; 31 | std::string m_reply_channel; 32 | std::unique_ptr 
m_asgi_send_msg; 33 | size_t m_bytes_written{0}; 34 | 35 | static void WINAPI WriteCallback(HRESULT hr, VOID *context, DWORD num_bytes, BOOL utf8, BOOL final_fragment, BOOL close); 36 | }; 37 | -------------------------------------------------------------------------------- /dependencies/msgpack-c/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Builds msgpack as an ExternalProject and exposes `msgpack-c` as a static library. 2 | 3 | ExternalProject_Add(dep-msgpack-c 4 | # v2.0 is now available, but contains some breaking changes. Consider upgrading. 5 | GIT_REPOSITORY https://github.com/msgpack/msgpack-c.git 6 | GIT_TAG cpp-1.4.2 7 | 8 | INSTALL_COMMAND "" 9 | ) 10 | 11 | 12 | ExternalProject_Get_Property(dep-msgpack-c source_dir) 13 | ExternalProject_Get_Property(dep-msgpack-c binary_dir) 14 | 15 | 16 | # This directory doesn't get created until build time, but CMake will complain if 17 | # it is not there at configure time. Create it. 18 | file(MAKE_DIRECTORY "${source_dir}/include") 19 | 20 | 21 | add_library(msgpack-c IMPORTED STATIC GLOBAL) 22 | set_target_properties(msgpack-c PROPERTIES 23 | IMPORTED_LOCATION_DEBUG "${binary_dir}/Debug/msgpackc${CMAKE_STATIC_LIBRARY_SUFFIX}" 24 | IMPORTED_LOCATION_RELEASE "${binary_dir}/Release/msgpackc${CMAKE_STATIC_LIBRARY_SUFFIX}" 25 | INTERFACE_INCLUDE_DIRECTORIES "${source_dir}/include/" 26 | ) 27 | add_dependencies(msgpack-c dep-msgpack-c) 28 | -------------------------------------------------------------------------------- /SharedUtils/Logger.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #define WIN32_LEAN_AND_MEAN 4 | #include 5 | #include 6 | 7 | #include "Logger.h" 8 | 9 | // Hands the accumulated text to the owning Logger when the stream is destroyed. Log() must not throw here: destructors are implicitly noexcept, so an exception would call std::terminate. 10 | LoggerStream::~LoggerStream() 11 | { 12 | logger.Log(m_sstream.str()); 13 | } 14 | 15 | LoggerStream& LoggerStream::operator<<(const std::wstring& string) 16 | { 17 | // This is a bit sad.
We'll conver the wstring to a utf-8 string 18 | // but then EtwLogger will convert it back to u16 wstring. 19 | std::wstring_convert> utf8_conv; 20 | m_sstream << utf8_conv.to_bytes(string); 21 | return *this; 22 | } 23 | 24 | 25 | EtwLogger::EtwLogger(const GUID& etw_guid) 26 | { 27 | ::EventRegister(&etw_guid, nullptr, nullptr, &m_etw_handle); 28 | } 29 | 30 | 31 | EtwLogger::~EtwLogger() 32 | { 33 | if (m_etw_handle) { 34 | ::EventUnregister(m_etw_handle); 35 | } 36 | } 37 | 38 | 39 | void EtwLogger::Log(const std::string& msg) const 40 | { 41 | std::wstring_convert> utf8_conv; 42 | std::wstring wmsg = utf8_conv.from_bytes(msg); 43 | ::EventWriteString(m_etw_handle, 0, 0, wmsg.c_str()); 44 | } 45 | -------------------------------------------------------------------------------- /AsgiHandlerLib/HttpModuleFactory.cpp: -------------------------------------------------------------------------------- 1 | #include "HttpModuleFactory.h" 2 | 3 | #define WIN32_LEAN_AND_MEAN 4 | #include 5 | 6 | #include "HttpModule.h" 7 | 8 | 9 | namespace { 10 | // {B057F98C-CB95-413D-AFAE-8ED010DB73C5} 11 | static const GUID logger_etw_guid = { 0xb057f98c, 0xcb95, 0x413d,{ 0xaf, 0xae, 0x8e, 0xd0, 0x10, 0xdb, 0x73, 0xc5 } }; 12 | } 13 | 14 | 15 | HttpModuleFactory::HttpModuleFactory(const HTTP_MODULE_ID& module_id) 16 | : logger(logger_etw_guid), m_module_id(module_id), 17 | m_response_pump(logger, m_channels) 18 | { 19 | logger.debug() << "Creating HttpModuleFactory"; 20 | m_response_pump.Start(); 21 | } 22 | 23 | HttpModuleFactory::~HttpModuleFactory() 24 | { 25 | logger.debug() << "Destroying HttpModuleFactory"; 26 | } 27 | 28 | HRESULT HttpModuleFactory::GetHttpModule(OUT CHttpModule ** module, IN IModuleAllocator *) 29 | { 30 | // This is called once for each request. We may want to return the same HttpModule 31 | // to multiple requests, given that the actual per-request state is stored elsewhere. 
32 | *module = new HttpModule(*this, m_response_pump, logger); 33 | return S_OK; 34 | } 35 | 36 | void HttpModuleFactory::Terminate() 37 | { 38 | } 39 | -------------------------------------------------------------------------------- /AsgiHandlerLib/ChannelLayer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | 11 | 12 | class ChannelLayer 13 | { 14 | public: 15 | ChannelLayer() 16 | : m_random_engine(std::random_device{}()) 17 | { } 18 | 19 | // Virtual so that destroying a derived layer (e.g. RedisChannelLayer) through a ChannelLayer* is well-defined. 20 | virtual ~ChannelLayer() = default; 21 | 22 | // Returns `prefix` followed by 10 random alphanumeric characters. 23 | std::string NewChannel(std::string prefix) 24 | { 25 | return prefix + GenerateRandomAscii(10); 26 | } 27 | 28 | virtual void Send(const std::string& channel, const msgpack::sbuffer& buffer) = 0; 29 | virtual std::tuple ReceiveMany( 30 | const std::vector& channels, bool blocking = false 31 | ) = 0; 32 | 33 | protected: 34 | std::default_random_engine m_random_engine; 35 | 36 | // Returns `length` characters, each drawn uniformly from [0-9A-Za-z]. 37 | std::string GenerateRandomAscii(size_t length) 38 | { 39 | static const char charset[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; 40 | static const size_t max_index = strlen(charset) - 1; 41 | // Map engine output uniformly onto every index in [0, max_index]; a plain 42 | // `m_random_engine() % n` is biased and easy to get off-by-one. 43 | std::uniform_int_distribution<size_t> pick(0, max_index); 44 | std::string random_string(length, '0'); 45 | std::generate_n(random_string.begin(), length, [this, &pick]() { return charset[pick(m_random_engine)]; }); 46 | return random_string; 47 | } 48 | }; 49 | -------------------------------------------------------------------------------- /AsgiHandlerLib/RedisChannelLayer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #define WIN32_LEAN_AND_MEAN 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include "ChannelLayer.h" 12 | 13 | 14 | typedef std::unique_ptr> RedisReply; 15 | 16 | 17 | class RedisChannelLayer : public ChannelLayer 18 | { 19 | public: 20 | RedisChannelLayer(std::string ip = "127.0.0.1", int port = 6379, std::string prefix = "asgi:"); 21 | virtual
~RedisChannelLayer(); 22 | 23 | virtual void Send(const std::string& channel, const msgpack::sbuffer& buffer); 24 | virtual std::tuple ReceiveMany(const std::vector& channels, bool blocking = false); 25 | 26 | protected: 27 | // Runs redisCommand() and wraps the reply so freeReplyObject() is called automatically. NOTE(review): hiredis can return NULL on error; callers should check the wrapped pointer. 28 | template 29 | RedisReply ExecuteRedisCommand(std::string format_string, Args... args) 30 | { 31 | auto reply = static_cast(redisCommand(m_redis_ctx, format_string.c_str(), args...)); 32 | RedisReply wrapped_reply(reply, freeReplyObject); 33 | return wrapped_reply; 34 | } 35 | 36 | private: 37 | std::string m_prefix; 38 | int m_expiry; // seconds 39 | redisContext *m_redis_ctx; 40 | }; 41 | -------------------------------------------------------------------------------- /AsgiHandlerLib/WsRequestHandlerSteps.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "RequestHandler.h" 4 | #include "AsgiWsConnectMsg.h" 5 | 6 | 7 | class AcceptWebSocketStep : public RequestHandlerStep 8 | { 9 | public: 10 | AcceptWebSocketStep( 11 | RequestHandler& handler, std::unique_ptr& asgi_connect_msg 12 | ) : RequestHandlerStep(handler), m_asgi_connect_msg(std::move(asgi_connect_msg)) 13 | { } // Moves from asgi_connect_msg, leaving the caller's pointer empty. 14 | 15 | virtual StepResult Enter(); 16 | virtual StepResult OnAsyncCompletion(HRESULT hr, DWORD num_bytes); 17 | virtual std::unique_ptr GetNextStep(); 18 | 19 | private: 20 | std::unique_ptr m_asgi_connect_msg; 21 | }; 22 | 23 | class SendConnectToApplicationStep : public RequestHandlerStep 24 | { 25 | public: 26 | SendConnectToApplicationStep( 27 | RequestHandler& handler, std::unique_ptr& asgi_connect_msg 28 | ) : RequestHandlerStep(handler), m_asgi_connect_msg(std::move(asgi_connect_msg)) 29 | { } // Moves from asgi_connect_msg, leaving the caller's pointer empty. 30 | 31 | virtual StepResult Enter(); 32 | virtual StepResult OnAsyncCompletion(HRESULT hr, DWORD num_bytes); 33 | virtual std::unique_ptr GetNextStep(); 34 | 35 | private: 36 | std::unique_ptr m_asgi_connect_msg; 37 | }; 38 | -------------------------------------------------------------------------------- /IntegrationTests/fixtures/asgi.py:
/IntegrationTests/fixtures/asgi.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | from __future__ import unicode_literals, print_function, absolute_import 5 | 6 | import pytest 7 | import asgi_redis 8 | 9 | 10 | class RedisChannelLayer(object): 11 | 12 | def __init__(self): 13 | # TODO: Use a custom prefix and configure IIS module to use it too. 14 | self.channels = asgi_redis.RedisChannelLayer() 15 | self.channels.flush() 16 | 17 | def _receive_one(self, channel_name): 18 | channel, data = self.channels.receive_many([channel_name], block=True) 19 | assert channel == channel_name 20 | return data 21 | 22 | def receive_request(self): 23 | return self._receive_one('http.request') 24 | 25 | def receive_ws_connect(self): 26 | return self._receive_one('websocket.connect') 27 | 28 | def receive_ws_data(self): 29 | return self._receive_one('websocket.receive') 30 | 31 | def send(self, channel, msg): 32 | self.channels.send(channel, msg) 33 | 34 | def assert_empty(self): 35 | channel, asgi_request = self.channels.receive_many(['http.request'], block=False) 36 | assert channel == None 37 | assert asgi_request == None 38 | 39 | 40 | @pytest.fixture 41 | def asgi(): 42 | return RedisChannelLayer() 43 | -------------------------------------------------------------------------------- /ProcessPoolLib/ProcessPool.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "JobObject.h" 8 | 9 | 10 | class Logger; 11 | 12 | 13 | class ProcessPool 14 | { 15 | public: 16 | ProcessPool( 17 | const Logger& logger, const std::wstring& process, 18 | const std::wstring& arguments, size_t num_processes 19 | ); 20 | ~ProcessPool(); 21 | 22 | void Stop() const; 23 | 24 | protected: 25 | void ThreadMain(); 26 | void CreateProcess(); 27 | static std::wstring EscapeArgument(const std::wstring& argument); 
28 | 29 | private: 30 | const Logger& logger; 31 | JobObject m_job; 32 | 33 | // Process name, argument list, and the two joined as the full command line. 34 | std::wstring m_process; 35 | std::wstring m_arguments; 36 | std::wstring m_command_line; 37 | // Number of processes we should aim to have in the pool. 38 | size_t m_num_processes; 39 | 40 | // The pool has its own thread which it uses to monitor the processes it creates, 41 | // so that it can restart them if they exit. 42 | std::thread m_thread; 43 | ScopedHandle m_thread_exit_event; 44 | 45 | // We store one ScopedHandle for each process that we create. These are owned 46 | // by the thread. 47 | std::vector m_processes; 48 | }; 49 | -------------------------------------------------------------------------------- /SharedUtils/Logger.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #define WIN32_LEAN_AND_MEAN 9 | #include 10 | #include 11 | 12 | 13 | class Logger; 14 | 15 | class LoggerStream 16 | { 17 | public: 18 | LoggerStream(const Logger& logger) 19 | : logger(logger) 20 | { } 21 | 22 | LoggerStream(LoggerStream&) = default; // Defined as deleted: std::ostringstream is not copyable. 23 | LoggerStream(LoggerStream&&) = default; 24 | 25 | virtual ~LoggerStream(); 26 | 27 | template 28 | LoggerStream& operator<<(const T& value) 29 | { 30 | m_sstream << value; 31 | return *this; 32 | } 33 | 34 | LoggerStream& operator<<(const std::wstring& string); 35 | 36 | protected: 37 | std::ostringstream m_sstream; 38 | const Logger& logger; 39 | }; 40 | 41 | 42 | class Logger 43 | { 44 | friend class LoggerStream; 45 | public: 46 | virtual ~Logger() = default; // Virtual so that destroying a derived logger (e.g. EtwLogger) through a Logger* is well-defined. 47 | virtual LoggerStream debug() const { return GetStream(); } 48 | protected: 49 | virtual void Log(const std::string& msg) const = 0; 50 | virtual LoggerStream GetStream() const { return LoggerStream(*this); } 51 | }; 52 | 53 | class EtwLogger : public Logger 54 | { 55 | public: 56 | EtwLogger(const GUID& etw_guid); 57 | virtual
~EtwLogger(); 58 | protected: 59 | virtual void Log(const std::string& msg) const; 60 | private: 61 | ::REGHANDLE m_etw_handle{ 0 }; 62 | }; 63 | -------------------------------------------------------------------------------- /AsgiHandlerLib/AsgiWsConnectMsg.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "AsgiMsg.h" 6 | 7 | 8 | class AsgiWsConnectMsg : private AsgiMsg 9 | { 10 | public: 11 | std::string reply_channel; 12 | std::string scheme; 13 | std::string path; 14 | std::string query_string; 15 | std::string root_path; 16 | std::vector> headers; 17 | // `order` forced to 0 in pack function. 18 | 19 | // TODO: implement `client` and `server` fields. 20 | 21 | template 22 | void msgpack_pack(Packer& packer) const 23 | { 24 | packer.pack_map(7); // Must equal the number of key/value pairs packed below. 25 | 26 | pack_string(packer, "reply_channel"); 27 | pack_string(packer, reply_channel); 28 | pack_string(packer, "scheme"); 29 | pack_string(packer, scheme); 30 | pack_string(packer, "path"); 31 | pack_string(packer, path); 32 | pack_string(packer, "query_string"); 33 | pack_bytestring(packer, query_string); 34 | pack_string(packer, "root_path"); 35 | pack_bytestring(packer, root_path); 36 | 37 | pack_string(packer, "headers"); 38 | packer.pack_array(headers.size()); 39 | for (const auto& header : headers) { // By reference: avoids copying both strings of every header. 40 | packer.pack_array(2); 41 | pack_bytestring(packer, std::get<0>(header)); 42 | pack_bytestring(packer, std::get<1>(header)); 43 | } 44 | 45 | pack_string(packer, "order"); 46 | packer.pack_int(0); 47 | } 48 | }; -------------------------------------------------------------------------------- /IntegrationTests/fixtures/django_worker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | import threading 5 | 6 | import pytest 7 | 8 | from django.conf import settings 9 | import django.urls 10 | 11 | import asgi_redis 12 | import channels.worker 13 |
import channels.asgi 14 | 15 | 16 | class WorkerWrapper(object): 17 | """Wraper channels.worker.Worker and provides a nice interface for tests.""" 18 | 19 | def __init__(self, worker): 20 | self.worker = worker 21 | 22 | def set_urlconfs(self, urlconfs): 23 | class UrlConf(object): 24 | urlpatterns = urlconfs 25 | settings.ROOT_URLCONF = UrlConf() 26 | django.urls.clear_url_caches() 27 | 28 | def start(self): 29 | self.thread = threading.Thread(target=self.worker.run) 30 | self.thread.start() 31 | 32 | def stop(self): 33 | self.worker.termed =True 34 | self.thread.join() 35 | 36 | 37 | @pytest.yield_fixture 38 | def django_worker(request): 39 | if not settings.configured: 40 | settings.configure(DEBUG=True) 41 | 42 | channel_layer = asgi_redis.RedisChannelLayer() 43 | channel_layer_wrapper = channels.asgi.ChannelLayerWrapper( 44 | channel_layer, 'default', dict() 45 | ) 46 | channel_layer_wrapper.router.check_default() 47 | worker = channels.worker.Worker(channel_layer_wrapper, signal_handlers=False) 48 | 49 | worker_wrapper = WorkerWrapper(worker) 50 | try: 51 | yield worker_wrapper 52 | finally: 53 | worker_wrapper.stop() 54 | -------------------------------------------------------------------------------- /IntegrationTests/fixtures/worker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | """Simple worker process for use in test_pool.py""" 5 | 6 | import atexit 7 | import ctypes 8 | import sys 9 | 10 | 11 | WAIT_OBJECT_0 = 0x0 12 | WAIT_TIMEOUT = 0x102 13 | INFINITE = 0xFFFFFFFF 14 | SYNCHRONIZE = 0x00100000 15 | EVENT_MODIFY_STATE = 0x2 16 | 17 | 18 | if __name__ == '__main__': 19 | pool_name = sys.argv[1] 20 | object_prefix = u'Global\\ProcessPool_IntegrationTests_Worker_' + pool_name 21 | 22 | # The tests will signal this event when they want a process to exit. 23 | # The event is set to auto-reset, so only one process will exit each time it is signaled. 
24 | exit_event = ctypes.windll.kernel32.OpenEventW( 25 | SYNCHRONIZE, False, 26 | object_prefix + '_Exit' 27 | ) 28 | assert exit_event, ctypes.GetLastError() 29 | 30 | # Increment a semaphore so the tests can see how many processes have started running. 31 | # As we want to use it as a counter, we start at 0 and each process ReleaseSemaphores() 32 | # to register itself. 33 | semaphore = ctypes.windll.kernel32.OpenSemaphoreW( 34 | SYNCHRONIZE | EVENT_MODIFY_STATE, False, 35 | object_prefix + '_Counter' 36 | ) 37 | assert semaphore, ctypes.GetLastError() 38 | ctypes.windll.kernel32.ReleaseSemaphore(semaphore, 1, None) 39 | 40 | # Continually wait on the exit_event until it is signaled. 41 | while True: 42 | wait_result = ctypes.windll.kernel32.WaitForSingleObject(exit_event, INFINITE) 43 | if wait_result == WAIT_OBJECT_0: 44 | break 45 | 46 | -------------------------------------------------------------------------------- /ProcessPoolTest/test_ProcessPool.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "ProcessPool.h" 4 | 5 | 6 | // Makes the EscapeArgument() function accessible from the tests. 7 | class ProcessPoolPublic : public ProcessPool 8 | { 9 | public: 10 | using ProcessPool::EscapeArgument; 11 | }; 12 | 13 | 14 | TEST(ProcessPoolEscapeArgument, NoSpecialCharacters) 15 | { 16 | EXPECT_EQ(ProcessPoolPublic::EscapeArgument(L"test"), L"test"); 17 | } 18 | 19 | TEST(ProcessPoolEscapeArgument, Spaces) 20 | { 21 | EXPECT_EQ(ProcessPoolPublic::EscapeArgument(L"a space"), L"\"a space\""); 22 | } 23 | 24 | TEST(ProcessPoolEscapeArgument, Quote) 25 | { 26 | // Quotes should be escaped. 27 | EXPECT_EQ(ProcessPoolPublic::EscapeArgument(L"a \"test\""), L"\"a \\\"test\\\"\""); 28 | } 29 | 30 | TEST(ProcessPoolEscapeArgument, Backslashes) 31 | { 32 | // Backslashes within the string should be escaped if followed by a quote. 33 | // (i.e. 1 backslash and a quote become 3 backslashes and a quote). 
34 | EXPECT_EQ(ProcessPoolPublic::EscapeArgument(L"test \\\" str"), L"\"test \\\\\\\" str\""); 35 | 36 | // Backslashes within the string should not be escaped if not followed by 37 | // a quote. 38 | EXPECT_EQ(ProcessPoolPublic::EscapeArgument(L"test\\ str"), L"\"test\\ str\""); 39 | // Which means if there are no spaces, they shouldn't be escaped at all, 40 | // even if they appear at the end. 41 | EXPECT_EQ(ProcessPoolPublic::EscapeArgument(L"test\\"), L"test\\"); 42 | 43 | // If a quoted string ends in backslashes, they must all be doubled so that the 44 | // closing " we add is not escaped (i.e. 2 become 4). 45 | EXPECT_EQ(ProcessPoolPublic::EscapeArgument(L"test str\\\\"), L"\"test str\\\\\\\\\""); 46 | } 47 | -------------------------------------------------------------------------------- /AsgiHandlerLib/WsRequestHandler.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define WIN32_LEAN_AND_MEAN 4 | #include 5 | #include 6 | 7 | #include "RequestHandler.h" 8 | #include "WsReader.h" 9 | #include "WsWriter.h" 10 | 11 | 12 | class ChannelLayer; 13 | 14 | 15 | class WsRequestHandler : public RequestHandler 16 | { 17 | // Having these as friends saves injecting a bunch of parameters into them 18 | // at construction. Their implementation is quite tightly coupled to ours anyway. 19 | // They both have access to the IWebSocketContext* and access it without locking, 20 | // as we assume it is safe to do so. However, it might be safer to expose 21 | // IWebSocketContext* via some private member on this class and take a lock before 22 | // using it? It's difficult to work out what the WebSocket module expects.
23 | friend class WsReader; 24 | friend class WsWriter; 25 | 26 | public: 27 | WsRequestHandler( 28 | ResponsePump& response_pump, ChannelLayer& channels, const Logger& logger, IHttpContext* http_context 29 | ) 30 | : RequestHandler(response_pump, channels, logger, http_context), m_reader(*this), m_writer(*this) 31 | { } 32 | 33 | virtual REQUEST_NOTIFICATION_STATUS OnExecuteRequestHandler(); 34 | virtual REQUEST_NOTIFICATION_STATUS OnAsyncCompletion(IHttpCompletionInfo* completion_info); 35 | 36 | protected: 37 | void StartReaderWriter(); 38 | 39 | // While we're setting up the connection and sending the initial 40 | // `websocket.connect` ASGI message, we run this pipeline. 41 | std::unique_ptr m_current_connect_step; 42 | // Members are initialized in declaration order: m_reader/m_writer receive *this before m_reply_channel and m_request_path exist, so their constructors must not use those two members. 43 | WsReader m_reader; 44 | WsWriter m_writer; 45 | std::string m_reply_channel; 46 | std::string m_request_path; 47 | }; 48 | -------------------------------------------------------------------------------- /AsgiHandlerLib/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # AsgiHandlerLib.lib - For use in AsgiHandler.dll, but compiled as a lib so that it can also be used by tests.
2 | 3 | 4 | set(SOURCES 5 | HttpModule.cpp 6 | HttpModuleFactory.cpp 7 | 8 | RedisChannelLayer.cpp 9 | ResponsePump.cpp 10 | 11 | RequestHandler.cpp 12 | HttpRequestHandler.cpp 13 | HttpRequestHandlerSteps.cpp 14 | 15 | WsRequestHandler.cpp 16 | WsRequestHandlerSteps.cpp 17 | WsReader.cpp 18 | WsWriter.cpp 19 | ) 20 | 21 | 22 | set(HEADERS 23 | HttpModule.h 24 | HttpModuleFactory.h 25 | 26 | ChannelLayer.h 27 | ResponsePump.h 28 | 29 | ScopedRedisReply.h 30 | RedisChannelLayer.h 31 | 32 | RequestHandler.h 33 | HttpRequestHandler.h 34 | HttpRequestHandlerSteps.h 35 | 36 | WsRequestHandler.h 37 | WsRequestHandlerSteps.h 38 | WsReader.h 39 | WsWriter.h 40 | 41 | AsgiMsg.h 42 | AsgiHttpRequestMsg.h 43 | AsgiHttpResponseMsg.h 44 | AsgiWsConnectMsg.h 45 | AsgiWsReceiveMsg.h 46 | AsgiWsSendMsg.h 47 | ) 48 | 49 | 50 | add_library(AsgiHandlerLib STATIC 51 | ${SOURCES} 52 | ${HEADERS} 53 | ) 54 | 55 | target_link_libraries(AsgiHandlerLib 56 | SharedUtils 57 | hiredis 58 | hiredis-interop 59 | msgpack-c 60 | ) 61 | 62 | target_include_directories(AsgiHandlerLib PUBLIC 63 | $ 64 | $ 65 | ) 66 | 67 | # IDE folder grouping only; these regexes do not affect what is built. 68 | source_group("Source Files\\Http" REGULAR_EXPRESSION /Http[^/]*.cpp$) 69 | source_group("Header Files\\Http" REGULAR_EXPRESSION /Http[^/]*.h$) 70 | source_group("Source Files\\Ws" REGULAR_EXPRESSION /Ws[^/]*.cpp$) 71 | source_group("Header Files\\Ws" REGULAR_EXPRESSION /Ws[^/]*.h$) 72 | source_group("Header Files\\Asgi" REGULAR_EXPRESSION /Asgi[^/]*$) 73 | -------------------------------------------------------------------------------- /AsgiHandlerLib/AsgiWsReceiveMsg.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "AsgiMsg.h" 6 | 7 | 8 | class AsgiWsReceiveMsg : private AsgiMsg 9 | { 10 | public: 11 | AsgiWsReceiveMsg() 12 | : data(BUFFER_CHUNK_SIZE) 13 | { } 14 | 15 | // `order` starts at 1 as AsgiWsConnectMsg is the 0th message.
16 | int order{1}; 17 | std::string reply_channel; 18 | std::string path; 19 | 20 | // The buffer for holding the data can be up to BUFFER_CHUNK_SIZE bigger than 21 | // the actual length of the data, as we don't know how big the message is ahead 22 | // of time. `data_size` contains the true size of the message. 23 | std::vector data; 24 | size_t data_size{0}; 25 | bool utf8{true}; 26 | 27 | // The initial size of the buffer and the size that it will be incremented 28 | // by each time the current message gets within BUFFER_CHUNK_INCREASE_THRESHOLD. 29 | static const std::size_t BUFFER_CHUNK_SIZE = 4096; 30 | // Once the free space remaining in the buffer falls within this amount, the 31 | // buffer is expanded. 32 | static const std::size_t BUFFER_CHUNK_INCREASE_THRESHOLD = BUFFER_CHUNK_SIZE / 4; 33 | 34 | template 35 | void msgpack_pack(Packer& packer) const 36 | { 37 | packer.pack_map(4); 38 | 39 | pack_string(packer, "reply_channel"); 40 | pack_string(packer, reply_channel); 41 | 42 | pack_string(packer, "path"); 43 | pack_string(packer, path); 44 | 45 | pack_string(packer, "order"); 46 | packer.pack_int(order); 47 | 48 | // TODO: Convert to utf8. NOTE: `utf8` is not consulted here; the payload is always packed under "bytes".
49 | pack_string(packer, "bytes"); 50 | packer.pack_bin(data_size); 51 | packer.pack_bin_body(data.data(), data_size); 52 | } 53 | }; 54 | -------------------------------------------------------------------------------- /AsgiHandlerTest/test_SendToApplicationStep.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "gtest/gtest.h" 5 | #include "gmock/gmock.h" 6 | 7 | #include "mock_httpserv.h" 8 | #include "mock_ResponsePump.h" 9 | #include "mock_HttpRequestHandler.h" 10 | 11 | 12 | using ::testing::DoAll; 13 | using ::testing::Return; 14 | using ::testing::SetArgPointee; 15 | using ::testing::_; 16 | 17 | 18 | class SendToApplicationStepTest : public ::testing::Test 19 | { 20 | public: 21 | SendToApplicationStepTest() 22 | : handler(response_pump, &http_context), msg(std::make_unique()), 23 | step(handler, msg) 24 | { } 25 | 26 | MockIHttpContext http_context; 27 | MockResponsePump response_pump; 28 | MockHttpRequestHandler handler; 29 | std::unique_ptr msg; 30 | SendToApplicationStep step; 31 | }; 32 | 33 | 34 | ACTION_P3(SetConditionVariable, condition, mutex, posted) 35 | { 36 | { std::lock_guard<std::mutex> lock(*mutex); *posted = true; } // Publish under the waiter's mutex. 37 | condition->notify_all(); return S_OK; 38 | } 39 | 40 | 41 | TEST_F(SendToApplicationStepTest, SendToChannelAndReturnsAsyncPending) 42 | { 43 | std::condition_variable condition; 44 | std::mutex mutex; 45 | bool posted = false; 46 | // Register every expectation before Enter(): the completion is posted asynchronously and may arrive before a later EXPECT_CALL. 47 | EXPECT_CALL(handler.channels, Send("http.request", _)) 48 | .Times(1); 49 | EXPECT_CALL(http_context, PostCompletion(0)) 50 | .WillOnce(SetConditionVariable(&condition, &mutex, &posted)); 51 | 52 | EXPECT_EQ(kStepAsyncPending, step.Enter()); 53 | // The predicate copes with a notify that fires before we start waiting, and with spurious wakeups. 54 | { 55 | std::unique_lock lock(mutex); 56 | condition.wait(lock, [&posted] { return posted; }); 57 | } 58 | } 59 | 60 | 61 | // This test will become more interesting if/when OnAsyncCompletion() has 62 | // to handle errors or cancellation.
63 | TEST_F(SendToApplicationStepTest, OnAsyncCompletionReturnsGotoNext) 64 | { 65 | EXPECT_EQ(kStepGotoNext, step.OnAsyncCompletion(S_OK, 0)); 66 | } 67 | -------------------------------------------------------------------------------- /dependencies/googletest/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Builds googletest as an ExternalProject and exposes `googletest` and `googlemock` as static libraries. 2 | 3 | 4 | ExternalProject_Add(dep-googletest 5 | # Unfortunately, the last released version is quite old and does not have 6 | # support for VS 2015. This commit is from master on 2016/08/07 - not a 7 | # lot of intelligence went into picking it. 8 | GIT_REPOSITORY https://github.com/google/googletest.git 9 | GIT_TAG ec44c6c1675c25b9827aacd08c02433cccde7780 10 | 11 | # Don't install. 12 | INSTALL_COMMAND "" 13 | ) 14 | 15 | ExternalProject_Get_Property(dep-googletest source_dir) 16 | ExternalProject_Get_Property(dep-googletest binary_dir) 17 | 18 | 19 | # dep-googletest doesn't create these directories until it's built, yet CMake will 20 | # complain at configure time if they don't exist. Create them to silence CMake.
21 | file(MAKE_DIRECTORY "${source_dir}/googletest/include") 22 | file(MAKE_DIRECTORY "${source_dir}/googlemock/include") 23 | 24 | # Only Debug and Release locations are provided below; other configurations would need their own mapping. 25 | add_library(googletest IMPORTED STATIC GLOBAL) 26 | set_target_properties(googletest PROPERTIES 27 | IMPORTED_LOCATION_DEBUG "${binary_dir}/googlemock/gtest/Debug/gtest${CMAKE_STATIC_LIBRARY_SUFFIX}" 28 | IMPORTED_LOCATION_RELEASE "${binary_dir}/googlemock/gtest/Release/gtest${CMAKE_STATIC_LIBRARY_SUFFIX}" 29 | INTERFACE_INCLUDE_DIRECTORIES "${source_dir}/googletest/include" 30 | ) 31 | add_dependencies(googletest dep-googletest) 32 | 33 | 34 | add_library(googlemock IMPORTED STATIC GLOBAL) 35 | set_target_properties(googlemock PROPERTIES 36 | IMPORTED_LOCATION_DEBUG "${binary_dir}/googlemock/Debug/gmock${CMAKE_STATIC_LIBRARY_SUFFIX}" 37 | IMPORTED_LOCATION_RELEASE "${binary_dir}/googlemock/Release/gmock${CMAKE_STATIC_LIBRARY_SUFFIX}" 38 | INTERFACE_INCLUDE_DIRECTORIES "${source_dir}/googlemock/include" 39 | ) 40 | add_dependencies(googlemock dep-googletest) 41 | -------------------------------------------------------------------------------- /AsgiHandlerLib/AsgiHttpRequestMsg.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "AsgiMsg.h" 6 | 7 | 8 | class AsgiHttpRequestMsg : private AsgiMsg 9 | { 10 | public: 11 | std::string reply_channel; 12 | std::string http_version; 13 | std::string method; 14 | std::string scheme; 15 | std::string path; 16 | std::string query_string; 17 | std::string root_path; // Currently not packed by msgpack_pack() (see the comment there). 18 | std::vector> headers; 19 | std::vector body; 20 | 21 | // TODO: body_channel and chunking.
22 | 23 | template 24 | void msgpack_pack(Packer& packer) const 25 | { 26 | packer.pack_map(8); // Must equal the number of key/value pairs packed below (root_path is omitted). 27 | 28 | pack_string(packer, "reply_channel"); 29 | pack_string(packer, reply_channel); 30 | pack_string(packer, "http_version"); 31 | pack_string(packer, http_version); 32 | pack_string(packer, "method"); 33 | pack_string(packer, method); 34 | pack_string(packer, "scheme"); 35 | pack_string(packer, scheme); 36 | 37 | // The values of these entries must be byte strings. 38 | pack_string(packer, "path"); 39 | pack_bytestring(packer, path); 40 | pack_string(packer, "query_string"); 41 | pack_bytestring(packer, query_string); 42 | // Django seems to expect a unicode string, not a byte string. 43 | // pack_string(packer, "root_path"); 44 | // pack_bytestring(packer, root_path); 45 | 46 | pack_string(packer, "headers"); 47 | packer.pack_array(headers.size()); 48 | for (const auto& header : headers) { // By reference: avoids copying both strings of every header. 49 | packer.pack_array(2); 50 | pack_bytestring(packer, std::get<0>(header)); 51 | pack_bytestring(packer, std::get<1>(header)); 52 | } 53 | 54 | pack_string(packer, "body"); 55 | packer.pack_bin(body.size()); 56 | packer.pack_bin_body(body.data(), body.size()); 57 | } 58 | }; 59 | -------------------------------------------------------------------------------- /dependencies/hiredis/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Builds hiredis as an ExternalProject and exposes `hiredis` and `hiredis-interop` as static libraries. 2 | # Both must be included by dependent targets. 3 | 4 | 5 | set(solution_dir msvs/vs-solutions/vs2015) 6 | 7 | ExternalProject_Add(dep-hiredis 8 | # We need to use the Microsoft fork of hiredis. It seems to be kept up-to-date 9 | # with redis/hiredis, but does not seem to make releases. Use master as of 2016/08/07.
10 | GIT_REPOSITORY https://github.com/Microsoft/hiredis.git 11 | GIT_TAG 05362104fa8d42a5ad64c042a79cbefc8fcfec33 12 | # NOTE: TARGET_PLATFORM (used below) is not set in this file; the including project must define it. 13 | # The project is checked out into an intermediate directory, so this isn't truly 14 | # building in _our_ source. It's necessary because hiredis doesn't use CMake. 15 | BUILD_IN_SOURCE 1 16 | 17 | CONFIGURE_COMMAND "" 18 | # Always build both configurations. CMake targets will consume the right one. 19 | BUILD_COMMAND msbuild "/p:Configuration=Debug" "/p:Platform=${TARGET_PLATFORM}" "${solution_dir}/hiredis_win.sln" 20 | COMMAND msbuild "/p:Configuration=Release" "/p:Platform=${TARGET_PLATFORM}" "${solution_dir}/hiredis_win.sln" 21 | INSTALL_COMMAND "" 22 | 23 | LOG_BUILD 1 24 | ) 25 | 26 | ExternalProject_Get_Property(dep-hiredis source_dir) 27 | set(binary_dir "${source_dir}/${solution_dir}/bin/${TARGET_PLATFORM}") 28 | 29 | 30 | add_library(hiredis IMPORTED STATIC GLOBAL) 31 | set_target_properties(hiredis PROPERTIES 32 | IMPORTED_LOCATION_DEBUG "${binary_dir}/Debug/hiredis${CMAKE_STATIC_LIBRARY_SUFFIX}" 33 | IMPORTED_LOCATION_RELEASE "${binary_dir}/Release/hiredis${CMAKE_STATIC_LIBRARY_SUFFIX}" 34 | INTERFACE_INCLUDE_DIRECTORIES "${source_dir}/" 35 | ) 36 | add_dependencies(hiredis dep-hiredis) 37 | 38 | 39 | add_library(hiredis-interop IMPORTED STATIC GLOBAL) 40 | set_target_properties(hiredis-interop PROPERTIES 41 | IMPORTED_LOCATION_DEBUG "${binary_dir}/Debug/win32_interop${CMAKE_STATIC_LIBRARY_SUFFIX}" 42 | IMPORTED_LOCATION_RELEASE "${binary_dir}/Release/win32_interop${CMAKE_STATIC_LIBRARY_SUFFIX}" 43 | INTERFACE_INCLUDE_DIRECTORIES "${source_dir}/" 44 | ) 45 | add_dependencies(hiredis-interop dep-hiredis) 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # iis-asgi-handler — [![Build
status](https://ci.appveyor.com/api/projects/status/yuaoo10qojr5825j/branch/master?svg=true)](https://ci.appveyor.com/project/mjkillough/iis-asgi-handler/branch/master) 2 | 3 | Module to allow IIS to act as an ASGI Interface Server for [Django Channels](https://github.com/andrewgodwin/channels/). Supports both HTTP and WebSockets. 4 | 5 | 6 | ## Building 7 | 8 | Requires: 9 | - Visual Studio 2015 10 | - CMake 2.8+ 11 | 12 | To build: 13 | 14 | - `mkdir build && cd build` 15 | - Then either: 16 | - x86: `cmake .. -G "Visual Studio 14 2015"` 17 | - x64: `cmake .. -G "Visual Studio 14 2015 Win64"` 18 | - `cmake --build . --configuration Debug` 19 | - Or open the `.sln` file it creates. 20 | 21 | 22 | ## Installing 23 | 24 | It isn't ready for people to install and use. If you're really keen, you can look at the code in `IntegrationTests/fixtures/` for an idea of how to install it. 25 | 26 | The module can not currently be run on IIS 7.5 and it requires the Web Sockets module to be installed. Eventually we will support IIS 7.5 (without Web Socket support). 27 | 28 | 29 | ## Tests 30 | 31 | There are two sets of tests: 32 | 33 | - Unit tests, in `RedisAsgiHandlerTests/`. These use `googletest` and `googlemock`. 34 | - Integration tests, in `IntegrationTests/`. These use `pytest` and install the module into IIS before each test. 35 | 36 | To run the integration tests you will need the following installed: 37 | - [redis](https://github.com/MSOpenTech/Redis) (or [redis-64](https://chocolatey.org/packages/redis-64) chocolatey package). 38 | - IIS 8+ with the Web Sockets module installed. 39 | 40 | To run all the tests: `ctest -C Debug --output-on-failure`. Append `-R AsgiHandler` to run just the unit tests, or `-R Integration` to run just the integration tests. 
41 | 42 | 43 | ## Dependencies 44 | 45 | Relies on the following libraries, which are downloaded and built automatically: 46 | 47 | - `msgpack-c` 48 | - Microsoft fork of [`hiredis`](https://github.com/Microsoft/hiredis/) 49 | - `googletest` / `googlemock` for unit tests 50 | 51 | There are a number of Python dependencies recorded in `IntegrationTests/requirements.txt`. These are automatically installed into a venv when the project is built. 52 | -------------------------------------------------------------------------------- /AsgiHandlerTest/test_FlushResponseStep.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "gmock/gmock.h" 3 | 4 | #include "mock_httpserv.h" 5 | #include "mock_ResponsePump.h" 6 | #include "mock_HttpRequestHandler.h" 7 | 8 | 9 | using ::testing::DoAll; 10 | using ::testing::NiceMock; 11 | using ::testing::Return; 12 | using ::testing::SetArgPointee; 13 | using ::testing::_; 14 | 15 | 16 | class FlushResponseStepTest : public ::testing::Test 17 | { 18 | public: 19 | FlushResponseStepTest() 20 | : handler(response_pump, &http_context), step(handler, "channel") 21 | { 22 | ON_CALL(http_context, GetResponse()) 23 | .WillByDefault(Return(&response)); 24 | } 25 | 26 | NiceMock http_context; 27 | MockResponsePump response_pump; 28 | MockHttpRequestHandler handler; 29 | MockIHttpResponse response; 30 | FlushResponseStep step; 31 | }; 32 | 33 | 34 | TEST_F(FlushResponseStepTest, ReturnedFinishedOnError) 35 | { 36 | EXPECT_CALL(response, Flush(true, true, _, _)) 37 | .WillOnce(Return(E_ACCESSDENIED)); 38 | 39 | EXPECT_EQ(kStepFinishRequest, step.Enter()); 40 | } 41 | 42 | 43 | TEST_F(FlushResponseStepTest, ReturnsAsyncPending) 44 | { 45 | auto msg = std::make_unique(); 46 | EXPECT_CALL(response, Flush(true, true, _, _)) 47 | .WillOnce(DoAll( 48 | SetArgPointee<3>(TRUE), // completion_expected 49 | Return(S_OK) 50 | )); 51 | 52 | 53 | EXPECT_EQ(kStepAsyncPending, step.Enter()); 
54 | } 55 | 56 | 57 | // If we repeatedly get completion_expected==FALSE, we might 58 | // finish this step without ever explicitly calling OnAsyncCompletion. 59 | TEST_F(FlushResponseStepTest, HandlesCompletionExpectedFalse) 60 | { 61 | // Right now OnAsyncCompletion() will always return kStepFinishRequest, 62 | // as we ignore the num_bytes. (There's nothing we could do with them, 63 | // as we do not know the full size of the response). 64 | EXPECT_CALL(response, Flush(true, true, _, _)) 65 | .WillOnce(DoAll( 66 | SetArgPointee<2>(12), // num_bytes 67 | SetArgPointee<3>(FALSE), // completion_expected 68 | Return(S_OK) 69 | )); 70 | 71 | EXPECT_EQ(kStepGotoNext, step.Enter()); 72 | } 73 | -------------------------------------------------------------------------------- /AsgiHandlerLib/RequestHandler.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define WIN32_LEAN_AND_MEAN 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "ResponsePump.h" 10 | #include "Logger.h" 11 | 12 | 13 | enum StepResult { 14 | kStepAsyncPending, 15 | kStepRerun, // Re-Enter()s the same state. 16 | kStepGotoNext, // Causes us to call GetNextStep(). 17 | kStepFinishRequest 18 | }; 19 | 20 | 21 | class RequestHandler; 22 | 23 | 24 | class RequestHandlerStep 25 | { 26 | public: 27 | RequestHandlerStep(RequestHandler& handler); 28 | 29 | virtual StepResult Enter() = 0; 30 | virtual StepResult OnAsyncCompletion(HRESULT hr, DWORD num_bytes) = 0; 31 | virtual std::unique_ptr GetNextStep() = 0; 32 | 33 | protected: 34 | RequestHandler& m_handler; 35 | // Expose some of m_handler's protected members here, so that they are 36 | // accessible from Step subclasses. 
37 | const Logger& logger; 38 | IHttpContext* m_http_context; 39 | ResponsePump& m_response_pump; 40 | ChannelLayer& m_channels; 41 | }; 42 | 43 | 44 | class RequestHandler : public IHttpStoredContext 45 | { 46 | friend class RequestHandlerStep; 47 | public: 48 | RequestHandler::RequestHandler( 49 | ResponsePump& response_pump, ChannelLayer& channels, const Logger& logger, IHttpContext* http_context 50 | ) 51 | : m_response_pump(response_pump), m_channels(channels), logger(logger), m_http_context(http_context) 52 | { } 53 | 54 | virtual void CleanupStoredContext() { delete this; } 55 | 56 | virtual REQUEST_NOTIFICATION_STATUS OnExecuteRequestHandler() = 0; 57 | virtual REQUEST_NOTIFICATION_STATUS OnAsyncCompletion(IHttpCompletionInfo* completion_info) = 0; 58 | 59 | protected: 60 | REQUEST_NOTIFICATION_STATUS HandlerStateMachine(std::unique_ptr& step, StepResult result); 61 | 62 | IHttpContext* m_http_context; 63 | ResponsePump& m_response_pump; 64 | const Logger& logger; 65 | ChannelLayer& m_channels; 66 | 67 | static std::string GetRequestHttpVersion(const IHttpRequest* request); 68 | static std::string GetRequestScheme(const HTTP_REQUEST* raw_request); 69 | static std::string GetRequestPath(const HTTP_REQUEST* raw_request); 70 | static std::string GetRequestQueryString(const HTTP_REQUEST* raw_request); 71 | static std::vector> GetRequestHeaders(const HTTP_REQUEST* raw_request); 72 | }; 73 | -------------------------------------------------------------------------------- /IntegrationTests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # IntegrationTests - Installs compiled module into IIS and runs tests against it. 2 | 3 | # Find appropriate Python and then find its copy of virtualenv. This is 4 | # very Windows specific, but we only need to run on Windows. Unfortunately, 5 | # we can't rely on virtualenv being in the path. 
6 | find_package(PythonInterp 2 EXACT REQUIRED) 7 | get_filename_component(PYTHON_DIRECTORY ${PYTHON_EXECUTABLE} DIRECTORY) 8 | set(PYTHON_SCRIPTS_DIRECTORY ${PYTHON_DIRECTORY}/Scripts) 9 | set(VIRTUALENV ${PYTHON_SCRIPTS_DIRECTORY}/virtualenv.exe) 10 | if(NOT EXISTS ${VIRTUALENV}) 11 | message(FATAL_ERROR "Could not find `virtualenv` at: ${VIRTUALENV}") 12 | endif() 13 | 14 | # Generate the virtualenv and ensure it's up to date. 15 | add_custom_command( 16 | OUTPUT venv 17 | COMMAND ${VIRTUALENV} venv 18 | ) 19 | add_custom_command( 20 | OUTPUT venv.stamp 21 | DEPENDS venv requirements.txt 22 | COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/requirements.txt requirements.txt 23 | COMMAND ./venv/Scripts/pip.exe install -r requirements.txt --upgrade 24 | ) 25 | 26 | # Build command line to run py.test. We disable the cachedir as it doesn't seem 27 | # possible to cause to happen out of the source tree. 28 | set(PYTEST 29 | ${CMAKE_CURRENT_BINARY_DIR}/venv/Scripts/python.exe 30 | ${CMAKE_CURRENT_BINARY_DIR}/venv/Scripts/py.test.exe 31 | -p no:cacheprovider 32 | ) 33 | 34 | 35 | set(TESTS 36 | test_config.py 37 | test_asgi_http.py 38 | test_asgi_ws.py 39 | test_django_http.py 40 | test_pool.py 41 | ) 42 | 43 | set(FIXTURES 44 | conftest.py 45 | 46 | fixtures/__init__.py 47 | fixtures/asgi.py 48 | fixtures/django_worker.py 49 | fixtures/etw.py 50 | fixtures/iis.py 51 | fixtures/requests.py 52 | fixtures/worker.py 53 | ) 54 | 55 | add_custom_target(IntegrationTests ALL 56 | SOURCES ${TESTS} ${FIXTURES} requirements.txt 57 | DEPENDS venv.stamp 58 | ) 59 | 60 | source_group(fixtures REGULAR_EXPRESSION fixtures/.*) 61 | 62 | 63 | add_test(NAME IntegrationTests 64 | COMMAND 65 | ${PYTEST} 66 | --dll-bitness ${TARGET_PLATFORM} 67 | --asgi-handler-dll $ 68 | --process-pool-dll $ 69 | --process-pool-schema-xml ${CMAKE_CURRENT_SOURCE_DIR}/../ProcessPool/process-pool-iis-schema.xml 70 | ${TESTS} 71 | WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} 72 | ) 73 | 
-------------------------------------------------------------------------------- /AsgiHandlerLib/HttpModule.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define WIN32_LEAN_AND_MEAN 5 | #include 6 | 7 | #include "HttpModule.h" 8 | 9 | #include "HttpModuleFactory.h" 10 | #include "RequestHandler.h" 11 | #include "HttpRequestHandler.h" 12 | #include "WsRequestHandler.h" 13 | 14 | 15 | HttpModule::HttpModule( 16 | const HttpModuleFactory& factory, ResponsePump& response_pump, 17 | const Logger& logger 18 | ) : m_factory(factory), m_response_pump(response_pump), logger(logger) 19 | { 20 | logger.debug() << "Created new HttpModule"; 21 | } 22 | 23 | HttpModule::~HttpModule() 24 | { 25 | logger.debug() << "Destroying HttpModule"; 26 | } 27 | 28 | REQUEST_NOTIFICATION_STATUS HttpModule::OnExecuteRequestHandler( 29 | IHttpContext *http_context, 30 | IHttpEventProvider *provider 31 | ) 32 | { 33 | logger.debug() << "HttpModule::OnExecuteRequestHandler()"; 34 | 35 | // Freed by IIS when the IHttpContext is destroyed, via 36 | // StoredRequestContext::CleanupStoredContext() 37 | RequestHandler *request_handler = nullptr; 38 | 39 | USHORT upgrade_header_length = 0; 40 | http_context->GetRequest()->GetHeader(HttpHeaderUpgrade, &upgrade_header_length); 41 | if (upgrade_header_length > 0) { 42 | logger.debug() << "Creating new WsRequestHandler"; 43 | request_handler = new WsRequestHandler( 44 | m_response_pump, m_channels, logger, http_context 45 | ); 46 | } else { 47 | logger.debug() << "Creating new HttpRequestHandler"; 48 | request_handler = new HttpRequestHandler( 49 | m_response_pump, m_channels, logger, http_context 50 | ); 51 | } 52 | 53 | http_context->GetModuleContextContainer()->SetModuleContext( 54 | request_handler, m_factory.module_id() 55 | ); 56 | 57 | return request_handler->OnExecuteRequestHandler(); 58 | } 59 | 60 | REQUEST_NOTIFICATION_STATUS HttpModule::OnAsyncCompletion( 61 | 
IHttpContext* http_context, DWORD notification, BOOL post_notification, 62 | IHttpEventProvider* provider, IHttpCompletionInfo* completion_info 63 | ) 64 | { 65 | logger.debug() << "HttpModule::OnAsyncCompletion()"; 66 | 67 | // TODO: Assert we have a HttpRequestHandler in the context container? 68 | auto request_handler = static_cast( 69 | http_context->GetModuleContextContainer()->GetModuleContext(m_factory.module_id()) 70 | ); 71 | 72 | return request_handler->OnAsyncCompletion(completion_info); 73 | } 74 | -------------------------------------------------------------------------------- /IntegrationTests/fixtures/requests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | from __future__ import absolute_import 5 | 6 | import concurrent.futures 7 | import logging 8 | 9 | import pytest 10 | 11 | import requests_futures.sessions 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class _ThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor): 17 | """An executor that remembers its futures, so that they can be inspected on test failure.""" 18 | def __init__(self, *args, **kwargs): 19 | super(_ThreadPoolExecutor, self).__init__(*args, **kwargs) 20 | self.futures = [] 21 | 22 | def submit(self, fn, *args, **kwargs): 23 | future = super(_ThreadPoolExecutor, self).submit(fn, *args, **kwargs) 24 | self.futures.append(future) 25 | return future 26 | 27 | @pytest.yield_fixture 28 | def session(request): 29 | # IIS on non-server OS only supports 10 concurrent connections, so there is no point 30 | # starting many more workers. 31 | executor = _ThreadPoolExecutor(max_workers=10) 32 | try: 33 | yield requests_futures.sessions.FuturesSession(executor) 34 | finally: 35 | # If the test failed, append a summary of how many outstanding tasks 36 | # we had. 
This can be useful to identify hung servers and when IIS has 37 | # returned a response when we weren't expecting one (this usually indicates 38 | # an error like a stopped AppPool). 39 | if hasattr(request.node, '_report'): 40 | if not request.node._report.passed: 41 | request.node._report.longrepr.addsection( 42 | 'Summary of requests-futures futures', 43 | '\r\n'.join(repr(future) for future in executor.futures) 44 | ) 45 | for future in executor.futures: 46 | if future.done(): 47 | response = future.result() 48 | request.node._report.longrepr.addsection( 49 | 'Response for request: %s %s' % (response.request.method, response.request.url), 50 | ( 51 | 'Status: %s\r\n' 52 | 'Headers: %s\r\n' 53 | 'Body:\r\n%s' 54 | '\r\n\r\n---\r\n\r\n' 55 | ) % (response.status_code, response.headers, response.text) 56 | ) 57 | executor.shutdown(wait=False) 58 | -------------------------------------------------------------------------------- /AsgiHandlerLib/HttpRequestHandler.cpp: -------------------------------------------------------------------------------- 1 | #define WIN32_LEAN_AND_MEAN 2 | #include 3 | 4 | #include "HttpRequestHandler.h" 5 | #include "HttpRequestHandlerSteps.h" 6 | #include "ChannelLayer.h" 7 | #include "Logger.h" 8 | 9 | 10 | REQUEST_NOTIFICATION_STATUS HttpRequestHandler::OnExecuteRequestHandler() 11 | { 12 | IHttpRequest *request = m_http_context->GetRequest(); 13 | HTTP_REQUEST *raw_request = request->GetRawHttpRequest(); 14 | 15 | auto asgi_request_msg = std::make_unique(); 16 | asgi_request_msg->reply_channel = m_channels.NewChannel("http.request!"); 17 | asgi_request_msg->http_version = GetRequestHttpVersion(request); 18 | asgi_request_msg->method = std::string(request->GetHttpMethod()); 19 | asgi_request_msg->scheme = GetRequestScheme(raw_request); 20 | asgi_request_msg->path = GetRequestPath(raw_request); 21 | asgi_request_msg->query_string = GetRequestQueryString(raw_request); 22 | asgi_request_msg->root_path = ""; // TODO: Same as SCRIPT_NAME in 
WSGI. What r that? 23 | asgi_request_msg->headers = GetRequestHeaders(raw_request); 24 | asgi_request_msg->body.resize(request->GetRemainingEntityBytes()); 25 | 26 | m_current_step = std::make_unique(*this, std::move(asgi_request_msg)); 27 | 28 | logger.debug() << typeid(*m_current_step.get()).name() << "->Enter() being called"; 29 | auto result = m_current_step->Enter(); 30 | logger.debug() << typeid(*m_current_step.get()).name() << "->Enter() = " << result; 31 | 32 | return HandlerStateMachine(m_current_step, result); 33 | } 34 | 35 | REQUEST_NOTIFICATION_STATUS HttpRequestHandler::OnAsyncCompletion(IHttpCompletionInfo* completion_info) 36 | { 37 | HRESULT hr = completion_info->GetCompletionStatus(); 38 | DWORD bytes = completion_info->GetCompletionBytes(); 39 | 40 | logger.debug() << typeid(*m_current_step.get()).name() << "->OnAsyncCompletion() being called"; 41 | auto result = m_current_step->OnAsyncCompletion(hr, bytes); 42 | logger.debug() << typeid(*m_current_step.get()).name() << "->OnAsyncCompletion() = " << result; 43 | 44 | return HandlerStateMachine(m_current_step, result); 45 | } 46 | 47 | bool HttpRequestHandler::ReturnError(USHORT status, const std::string& reason) 48 | { 49 | // TODO: Flush? Pass hr to SetStatus() to give better error message? 50 | IHttpResponse* response = m_http_context->GetResponse(); 51 | response->Clear(); 52 | response->SetStatus(status, reason.c_str()); 53 | return false; // No async pending. 
54 | } 55 | -------------------------------------------------------------------------------- /AsgiHandlerLib/WsRequestHandler.cpp: -------------------------------------------------------------------------------- 1 | #define WIN32_LEAN_AND_MEAN 2 | #include 3 | 4 | #include "WsRequestHandler.h" 5 | #include "WsRequestHandlerSteps.h" 6 | 7 | #include "ChannelLayer.h" 8 | #include "Logger.h" 9 | 10 | 11 | REQUEST_NOTIFICATION_STATUS WsRequestHandler::OnExecuteRequestHandler() 12 | { 13 | auto request = m_http_context->GetRequest(); 14 | auto raw_request = request->GetRawHttpRequest(); 15 | 16 | // We need to remember these so that we can pass them to WsReader. IT'll need 17 | //to send them with every `websocket.receive` message. 18 | m_reply_channel = m_channels.NewChannel("websocket.send!"); 19 | m_request_path = GetRequestPath(raw_request); 20 | 21 | // The HTTP_COOKED_URL seems to contain the wrong scheme for web sockets. 22 | auto scheme = GetRequestScheme(raw_request) == "https" ? "wss" : "ws"; 23 | 24 | auto asgi_connect_msg = std::make_unique(); 25 | asgi_connect_msg->reply_channel = m_reply_channel; 26 | asgi_connect_msg->scheme = scheme; 27 | asgi_connect_msg->path = m_request_path; 28 | asgi_connect_msg->query_string = GetRequestQueryString(raw_request); 29 | asgi_connect_msg->root_path = ""; // TODO: Same as SCRIPT_NAME in WSGI. What r that? 
30 | asgi_connect_msg->headers = GetRequestHeaders(raw_request); 31 | 32 | m_current_connect_step = std::make_unique(*this, std::move(asgi_connect_msg)); 33 | 34 | logger.debug() << typeid(*m_current_connect_step.get()).name() << "->Enter() being called"; 35 | auto result = m_current_connect_step->Enter(); 36 | logger.debug() << typeid(*m_current_connect_step.get()).name() << "->Enter() = " << result; 37 | 38 | return HandlerStateMachine(m_current_connect_step, result); 39 | } 40 | 41 | REQUEST_NOTIFICATION_STATUS WsRequestHandler::OnAsyncCompletion(IHttpCompletionInfo* completion_info) 42 | { 43 | HRESULT hr = completion_info->GetCompletionStatus(); 44 | DWORD bytes = completion_info->GetCompletionBytes(); 45 | 46 | StartReaderWriter(); 47 | 48 | logger.debug() << typeid(*m_current_connect_step.get()).name() << "->OnAsyncCompletion() being called"; 49 | auto result = m_current_connect_step->OnAsyncCompletion(hr, bytes); 50 | logger.debug() << typeid(*m_current_connect_step.get()).name() << "->OnAsyncCompletion() = " << result; 51 | 52 | return HandlerStateMachine(m_current_connect_step, result); 53 | } 54 | 55 | void WsRequestHandler::StartReaderWriter() 56 | { 57 | m_reader.Start(m_reply_channel, m_request_path); 58 | m_writer.Start(m_reply_channel); 59 | } 60 | -------------------------------------------------------------------------------- /SharedUtils/ScopedResources.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #define WIN32_LEAN_AND_MEAN 6 | #include 7 | #include 8 | 9 | 10 | template 11 | class ScopedResourceTraits 12 | { 13 | public: 14 | static constexpr T invalid_value{ nullptr }; 15 | static T Wrap(S v) { return v; } 16 | static void Release(T v) { static_assert(false, "Unknown resource type"); } 17 | }; 18 | 19 | 20 | template 21 | class ScopedResource 22 | { 23 | public: 24 | ScopedResource() = default; 25 | explicit ScopedResource(S value) { Set(value); } 26 | 
ScopedResource(ScopedResource&) = delete; 27 | ScopedResource(ScopedResource&& other) { Set(other.Transfer()); } 28 | 29 | ~ScopedResource() { Release(); } 30 | 31 | void operator=(ScopedResource&) = delete; 32 | ScopedResource& operator=(ScopedResource&& other) 33 | { 34 | if (this != &other) { 35 | Set(other.Transfer()); 36 | return *this; 37 | } 38 | } 39 | 40 | // Convenience for pointer-like Ts. 41 | T operator->() { return m_value; } 42 | 43 | bool IsValid() const { return m_value != traits::invalid_value; } 44 | T Get() const { return m_value; } 45 | void Set(S value) { m_value = traits::Wrap(value); } 46 | T* Receive() { return &m_value; } 47 | 48 | T Transfer() 49 | { 50 | auto tmp = m_value; 51 | m_value = traits::invalid_value; 52 | return tmp; 53 | } 54 | 55 | void Release() 56 | { 57 | if (IsValid()) { 58 | traits::Release(m_value); 59 | m_value = traits::invalid_value; 60 | } 61 | } 62 | 63 | private: 64 | using traits = ScopedResourceTraits; 65 | 66 | T m_value{ traits::invalid_value }; 67 | }; 68 | 69 | 70 | template<> 71 | class ScopedResourceTraits 72 | { 73 | public: 74 | static constexpr HANDLE invalid_value{ nullptr }; 75 | static HANDLE Wrap(HANDLE h) { return h; } 76 | static void Release(HANDLE h) { ::CloseHandle(h); } 77 | }; 78 | using ScopedHandle = ScopedResource; 79 | 80 | 81 | template<> 82 | class ScopedResourceTraits 83 | { 84 | public: 85 | static constexpr BSTR invalid_value{ nullptr }; 86 | static BSTR Wrap(PCWSTR str) { return ::SysAllocString(str); } 87 | static void Release(BSTR b) { ::SysFreeString(b); } 88 | }; 89 | using ScopedBstr = ScopedResource; 90 | 91 | 92 | template 93 | class ScopedResourceTraits 94 | { 95 | public: 96 | static constexpr T* invalid_value{ nullptr }; 97 | static T* Wrap(T* str) { return v; } 98 | static void Release(T* v) { v->Release(); } 99 | }; 100 | template 101 | using ScopedConfig = ScopedResource; 102 | -------------------------------------------------------------------------------- 
/AsgiHandlerLib/WsRequestHandlerSteps.cpp: -------------------------------------------------------------------------------- 1 | // This file contains the steps for WsRequestHandler. These deal with setting up 2 | // the WebSocket by accepting the UPGRADE request and sending the websocket.connect 3 | // ASGI message. They do not deal with the general reading/writing from websockets. 4 | 5 | #define WIN32_LEAN_AND_MEAN 6 | #include 7 | #include 8 | 9 | #include "WsRequestHandlerSteps.h" 10 | #include "RequestHandler.h" 11 | #include "ChannelLayer.h" 12 | 13 | 14 | // Connection Pipeline - AcceptWebSocketStep 15 | 16 | StepResult AcceptWebSocketStep::Enter() 17 | { 18 | // Setting the status as 101. 19 | auto resp = m_http_context->GetResponse(); 20 | resp->SetStatus(101, ""); 21 | 22 | // ... and Flushing causes the IIS WebSocket module to kick into action. 23 | // We have to set fMoreData=true in order for it to work. 24 | DWORD num_bytes = 0; 25 | BOOL completion_expected = false; 26 | HRESULT hr = resp->Flush(true, true, &num_bytes, &completion_expected); 27 | if (FAILED(hr)) { 28 | logger.debug() << "Flush() = " << hr; 29 | return kStepFinishRequest; 30 | } 31 | 32 | if (!completion_expected) { 33 | return OnAsyncCompletion(S_OK, num_bytes); 34 | } 35 | 36 | return kStepAsyncPending; 37 | } 38 | 39 | StepResult AcceptWebSocketStep::OnAsyncCompletion(HRESULT hr, DWORD num_bytes) 40 | { 41 | // Nothing to do but to go to the next step. (We don't know how many bytes 42 | // were in the response, so num_bytes is meaningless to us). 43 | // Our caller will handle FAILED(hr). 
44 | return kStepGotoNext; 45 | } 46 | 47 | std::unique_ptr AcceptWebSocketStep::GetNextStep() 48 | { 49 | return std::make_unique(m_handler, std::move(m_asgi_connect_msg)); 50 | } 51 | 52 | 53 | // Connection Pipeline - SendConnectToApplicationStep 54 | 55 | StepResult SendConnectToApplicationStep::Enter() 56 | { 57 | auto task = concurrency::create_task([this]() { 58 | msgpack::sbuffer buffer; 59 | msgpack::pack(buffer, *m_asgi_connect_msg); 60 | m_channels.Send("websocket.connect", buffer); 61 | }).then([this]() { 62 | logger.debug() << "SendConnectToApplicationStep calling PostCompletion()"; 63 | m_http_context->PostCompletion(0); 64 | }); 65 | 66 | return kStepAsyncPending; 67 | } 68 | 69 | StepResult SendConnectToApplicationStep::OnAsyncCompletion(HRESULT hr, DWORD num_bytes) 70 | { 71 | if (FAILED(hr)) { 72 | // TODO: Call an Error() or something. 73 | return kStepFinishRequest; 74 | } 75 | 76 | // We return kStepAsyncPending so that the request is kept open. We'll 77 | // get callbacks from IWebSocketContext from here onwards (not directly from IIS' 78 | // main request pipeline). 79 | return kStepAsyncPending; 80 | } 81 | 82 | std::unique_ptr SendConnectToApplicationStep::GetNextStep() 83 | { 84 | // TODO: Go to DummyStep in order to make sure we release any resources held 85 | // by this step? 
86 | throw std::runtime_error("SendConnectToApplicationStep::GetNextStep() should never get called."); 87 | } 88 | -------------------------------------------------------------------------------- /IntegrationTests/test_pool.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | from __future__ import unicode_literals, print_function 5 | 6 | import os 7 | import datetime 8 | import sys 9 | 10 | import pytest 11 | 12 | 13 | def wait_until(condition, timeout=2): 14 | """Waits until the specified condition function returns true, or the timeout 15 | (in seconds) elapses.""" 16 | end = datetime.datetime.utcnow() + datetime.timedelta(seconds=timeout) 17 | while datetime.datetime.utcnow() < end: 18 | if condition(): 19 | return True 20 | return False 21 | 22 | 23 | def start_pools(session, site): 24 | """Make a request to cause the Process Pools to get created.""" 25 | session.get(site.url + site.static_path).result(timeout=2) 26 | 27 | 28 | def test_pool_launches_process(site, session): 29 | # count=None checks that our config schema has a default value of '1' for count 30 | pool = site.add_process_pool(count=None) 31 | assert pool.num_started == 0 32 | start_pools(session, site) 33 | assert wait_until(lambda: pool.num_started == 1) 34 | assert pool.num_running == 1 35 | 36 | 37 | def test_pool_terminated_when_site_stopped(site, session): 38 | pool = site.add_process_pool() 39 | start_pools(session, site) 40 | assert wait_until(lambda: pool.num_started == 1) 41 | assert pool.num_running == 1 42 | site.stop_application_pool() 43 | assert wait_until(lambda: pool.num_running == 0) 44 | 45 | 46 | def test_pool_two_pools(site, session): 47 | pool1 = site.add_process_pool() 48 | pool2 = site.add_process_pool() 49 | start_pools(session, site) 50 | assert wait_until(lambda: pool1.num_started == 1) 51 | assert pool1.num_running == 1 52 | assert wait_until(lambda: pool2.num_started == 1) 53 | 
assert pool2.num_running == 1 54 | 55 | 56 | def test_pool_add_to_existing_site(site, session): 57 | pool1 = site.add_process_pool() 58 | start_pools(session, site) 59 | assert wait_until(lambda: pool1.num_started == 1) 60 | assert pool1.num_running == 1 61 | 62 | pool2 = site.add_process_pool() 63 | start_pools(session, site) 64 | # This will cause pool1 to restart, so the num_started will be 2. 65 | assert wait_until(lambda: pool1.num_started == 2) 66 | assert pool1.num_running == 1 67 | assert wait_until(lambda: pool2.num_started == 1) 68 | assert pool2.num_running == 1 69 | 70 | 71 | def test_pool_multiple_processes_in_pool(site, session): 72 | pool1 = site.add_process_pool(count=10) 73 | pool2 = site.add_process_pool(count=1) 74 | start_pools(session, site) 75 | assert wait_until(lambda: pool1.num_started == 10) 76 | assert pool1.num_running == 10 77 | assert wait_until(lambda: pool2.num_started == 1) 78 | assert pool2.num_running == 1 79 | 80 | 81 | def test_pool_exiting_process_restarted(site, session): 82 | pool = site.add_process_pool(count=3) 83 | start_pools(session, site) 84 | assert wait_until(lambda: pool.num_started == 3) 85 | assert pool.num_running == 3 86 | pool.kill_one() 87 | assert wait_until(lambda: pool.num_started == 4) 88 | assert pool.num_running == 3 89 | pool.kill_one() 90 | pool.kill_one() 91 | assert wait_until(lambda: pool.num_started == 6) 92 | assert pool.num_running == 3 93 | -------------------------------------------------------------------------------- /IntegrationTests/test_django_http.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | """Tests to check that we work with Django/Channels. 5 | 6 | Tests in test_asgi_*.py are preferable, as they're much easier to debug. 7 | The tests in here are a bit more complicated, but check we actually work 8 | when run under Django channels. 
9 | """ 10 | 11 | from __future__ import unicode_literals, print_function 12 | 13 | import pytest 14 | 15 | from django import http 16 | from django.conf.urls import url 17 | 18 | 19 | def test_django_http_returns_200_with_body(site, django_worker, session): 20 | body = 'a body' 21 | def view(request): 22 | return http.HttpResponse(body) 23 | django_worker.set_urlconfs([url('^$', view)]) 24 | django_worker.start() 25 | 26 | future = session.get(site.url) 27 | resp = future.result(timeout=2) 28 | assert resp.status_code == 200 29 | assert resp.text == body 30 | 31 | 32 | def test_django_http_returns_redirect(site, django_worker, session): 33 | redirect_url = 'a/path/' 34 | def view(request): 35 | return http.HttpResponseRedirect(redirect_url) 36 | django_worker.set_urlconfs([url('^$', view)]) 37 | django_worker.start() 38 | 39 | future = session.get(site.url, allow_redirects=False) 40 | resp = future.result(timeout=2) 41 | assert resp.status_code == 302 42 | assert 'location' in resp.headers 43 | assert resp.headers['location'] == redirect_url 44 | 45 | 46 | def test_django_http_request_headers(site, django_worker, session): 47 | def view(request): 48 | return http.HttpResponse(request.META.get('HTTP_X_MY_HEADER', '')) 49 | django_worker.set_urlconfs([url('^$', view)]) 50 | django_worker.start() 51 | 52 | expected_value = 'a string' 53 | headers = {'X-My-Header': expected_value} 54 | future = session.get(site.url, headers=headers) 55 | resp = future.result(timeout=2) 56 | assert resp.status_code == 200 57 | assert resp.text == expected_value 58 | 59 | 60 | def test_django_http_response_headers(site, django_worker, session): 61 | header_name = 'X-My-Header' 62 | header_value = 'response string' 63 | def view(request): 64 | resp = http.HttpResponse('') 65 | resp[header_name] = header_value 66 | return resp 67 | django_worker.set_urlconfs([url('^$', view)]) 68 | django_worker.start() 69 | 70 | future = session.get(site.url) 71 | resp = future.result(timeout=2) 72 |
assert resp.status_code == 200 73 | assert header_name in resp.headers 74 | assert resp.headers[header_name] == header_value 75 | 76 | 77 | def test_django_http_two_views(site, django_worker, session): 78 | def view1(request): 79 | return http.HttpResponse('view1') 80 | def view2(request): 81 | return http.HttpResponse('view2') 82 | django_worker.set_urlconfs([ 83 | url('^view1$', view1), 84 | url('^view2$', view2), 85 | ]) 86 | django_worker.start() 87 | 88 | future1 = session.get(site.url + '/view1') 89 | future2 = session.get(site.url + '/view2') 90 | resp1 = future1.result(timeout=2) 91 | resp2 = future2.result(timeout=2) 92 | assert resp1.status_code == 200 93 | assert resp2.status_code == 200 94 | assert resp1.text == 'view1' 95 | assert resp2.text == 'view2' 96 | -------------------------------------------------------------------------------- /AsgiHandlerLib/ResponsePump.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define WIN32_LEAN_AND_MEAN 5 | #include 6 | 7 | #include "ResponsePump.h" 8 | #include "ChannelLayer.h" 9 | #include "Logger.h" 10 | 11 | // ResponsePump polls the channel layer on a background thread and hands each reply to the callback registered for its channel. Callbacks are one-shot: ThreadMain() erases them before dispatching. 12 | ResponsePump::ResponsePump(const Logger& logger, ChannelLayer& channels) 13 | : logger(logger), m_channels(channels), m_thread_stop(false) 14 | { 15 | } 16 | 17 | ResponsePump::~ResponsePump() 18 | { 19 | m_thread_stop = true; 20 | // TODO: Consider m_callbacks.clear() and calling .detach()? 21 | if (m_thread.joinable()) { 22 | m_thread.join(); 23 | } 24 | } 25 | // Spawns the polling thread, which runs ThreadMain() until the destructor sets m_thread_stop. 26 | void ResponsePump::Start() 27 | { 28 | // TODO: Assert we aren't called twice.
29 | m_thread = std::thread(&ResponsePump::ThreadMain, this); 30 | } 31 | // Registers `callback` for replies on `channel`. Callbacks are one-shot: ThreadMain() erases the entry before dispatching it. 32 | void ResponsePump::AddChannel(const std::string& channel, const ResponseChannelCallback& callback) 33 | { 34 | std::lock_guard lock(m_callbacks_mutex); 35 | logger.debug() << "ResponsePump::AddChannel(" << channel << ", _)"; 36 | m_callbacks[channel] = callback; 37 | } 38 | // Unregisters `channel`; a reply arriving afterwards is logged and dropped by ThreadMain(). 39 | void ResponsePump::RemoveChannel(const std::string& channel) 40 | { 41 | std::lock_guard lock(m_callbacks_mutex); 42 | m_callbacks.erase(channel); 43 | } 44 | // Polling loop: snapshot the registered channel names, fetch at most one reply for them, and dispatch it to its callback on a concurrency task. 45 | void ResponsePump::ThreadMain() 46 | { 47 | using namespace std::literals::chrono_literals; 48 | 49 | logger.debug() << "ResponsePump::ThreadMain() starting"; 50 | while (!m_thread_stop) { 51 | std::vector channel_names; 52 | { 53 | std::lock_guard lock(m_callbacks_mutex); 54 | channel_names.reserve(m_callbacks.size()); 55 | for (const auto& it : m_callbacks) { 56 | channel_names.push_back(it.first); 57 | } 58 | } 59 | 60 | // These delay values have been chosen to match those in asgi_redis. 61 | // TODO: Think about whether these make sense for us. Perhaps we should 62 | // have some way of being woken up whilst we're sleeping? 63 | // A 10-50ms latency seems like a pretty big hit. 64 | auto delay = 50ms; 65 | if (!channel_names.empty()) { 66 | delay = 10ms; 67 | 68 | std::string channel, data; 69 | std::tie(channel, data) = m_channels.ReceiveMany(channel_names, false); 70 | if (!channel.empty()) { 71 | delay = 0ms; 72 | 73 | std::lock_guard lock(m_callbacks_mutex); 74 | const auto it = m_callbacks.find(channel); 75 | // If we don't have a callback, do nothing. Assume the request 76 | // has since been closed.
77 | if (it != m_callbacks.end()) { 78 | const auto callback = it->second; 79 | m_callbacks.erase(it); 80 | logger.debug() << "ResponsePump calling callback for channel: " << channel; 81 | concurrency::create_task([data = std::move(data), callback]() mutable { // init-capture + mutable: a plain by-copy capture is const inside the lambda, so the std::move below would silently copy. 82 | callback(std::move(data)); 83 | }); 84 | } else { 85 | logger.debug() << "ResponsePump dropping reply as no callback for channel: " << channel; 86 | } 87 | } 88 | } 89 | 90 | // Don't yield if there's more to dispatch. 91 | if (delay != 0ms) { 92 | std::this_thread::sleep_for(delay); 93 | } 94 | } 95 | logger.debug() << "ResponsePump::ThreadMain() exiting"; 96 | } 97 | -------------------------------------------------------------------------------- /AsgiHandlerLib/HttpRequestHandlerSteps.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #define WIN32_LEAN_AND_MEAN 7 | #include 8 | 9 | #include "AsgiHttpRequestMsg.h" 10 | #include "AsgiHttpResponseMsg.h" 11 | #include "Logger.h" 12 | #include "RedisChannelLayer.h" 13 | #include "ResponsePump.h" 14 | #include "RequestHandler.h" 15 | 16 | // The step constructors below take the message by non-const reference and move from it, so the caller's pointer is null afterwards. 17 | class ReadBodyStep : public RequestHandlerStep 18 | { 19 | public: 20 | ReadBodyStep( 21 | RequestHandler& handler, std::unique_ptr& asgi_request_msg 22 | ) : RequestHandlerStep(handler), m_asgi_request_msg(std::move(asgi_request_msg)), m_body_bytes_read(0) 23 | { } 24 | 25 | virtual StepResult Enter(); 26 | virtual StepResult OnAsyncCompletion(HRESULT hr, DWORD num_bytes); 27 | virtual std::unique_ptr GetNextStep(); 28 | 29 | protected: 30 | std::unique_ptr m_asgi_request_msg; 31 | DWORD m_body_bytes_read; 32 | }; 33 | 34 | 35 | class SendToApplicationStep : public RequestHandlerStep 36 | { 37 | public: 38 | SendToApplicationStep( 39 | RequestHandler& handler, std::unique_ptr& asgi_request_msg 40 | ) : RequestHandlerStep(handler), m_asgi_request_msg(std::move(asgi_request_msg)) 41 | { } 42 | 43 | virtual StepResult Enter(); 44 | virtual StepResult
OnAsyncCompletion(HRESULT hr, DWORD num_bytes); 45 | virtual std::unique_ptr GetNextStep(); 46 | 47 | protected: 48 | std::unique_ptr m_asgi_request_msg; 49 | }; 50 | 51 | 52 | class WaitForResponseStep : public RequestHandlerStep 53 | { 54 | public: 55 | WaitForResponseStep( 56 | RequestHandler& handler, const std::string& reply_channel 57 | ) : RequestHandlerStep(handler), m_reply_channel(reply_channel) 58 | { } 59 | 60 | virtual StepResult Enter(); 61 | virtual StepResult OnAsyncCompletion(HRESULT hr, DWORD num_bytes); 62 | virtual std::unique_ptr GetNextStep(); 63 | 64 | protected: 65 | const std::string m_reply_channel; 66 | std::unique_ptr m_asgi_response_msg; 67 | }; 68 | 69 | 70 | class WriteResponseStep : public RequestHandlerStep 71 | { 72 | public: 73 | WriteResponseStep( 74 | RequestHandler& handler, std::unique_ptr& asgi_response_msg, 75 | std::string reply_channel 76 | ) : RequestHandlerStep(handler), m_asgi_response_msg(std::move(asgi_response_msg)), 77 | m_reply_channel(std::move(reply_channel)), m_resp_bytes_written(0) 78 | { } 79 | 80 | virtual StepResult Enter(); 81 | virtual StepResult OnAsyncCompletion(HRESULT hr, DWORD num_bytes); 82 | 83 | virtual std::unique_ptr GetNextStep(); 84 | 85 | protected: 86 | std::unique_ptr m_asgi_response_msg; 87 | const std::string m_reply_channel; 88 | HTTP_DATA_CHUNK m_resp_chunk; 89 | DWORD m_resp_bytes_written; 90 | }; 91 | 92 | 93 | class FlushResponseStep : public RequestHandlerStep 94 | { 95 | public: 96 | FlushResponseStep( 97 | RequestHandler& handler, std::string reply_channel 98 | ) : RequestHandlerStep(handler), m_reply_channel(std::move(reply_channel)) 99 | { } 100 | 101 | virtual StepResult Enter(); 102 | virtual StepResult OnAsyncCompletion(HRESULT hr, DWORD num_bytes); 103 | 104 | virtual std::unique_ptr GetNextStep(); 105 | 106 | protected: 107 | const std::string m_reply_channel; 108 | }; 109 | --------------------------------------------------------------------------------
/AsgiHandlerTest/test_WriteResponseStep.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "gmock/gmock.h" 3 | 4 | #include "mock_httpserv.h" 5 | #include "mock_ResponsePump.h" 6 | #include "mock_HttpRequestHandler.h" 7 | 8 | 9 | using ::testing::DoAll; 10 | using ::testing::NiceMock; 11 | using ::testing::Return; 12 | using ::testing::SaveArg; 13 | using ::testing::SetArgPointee; 14 | using ::testing::_; 15 | 16 | // Fixture: a MockHttpRequestHandler wired to a NiceMock http context whose GetResponse() returns the `response` mock. 17 | class WriteResponseStepTest : public ::testing::Test 18 | { 19 | public: 20 | WriteResponseStepTest() 21 | : handler(response_pump, &http_context) 22 | { 23 | ON_CALL(http_context, GetResponse()) 24 | .WillByDefault(Return(&response)); 25 | } 26 | 27 | NiceMock http_context; 28 | MockResponsePump response_pump; 29 | MockHttpRequestHandler handler; 30 | MockIHttpResponse response; 31 | // No step here, as each test will need to pass a msg in during construction 32 | }; 33 | 34 | // An error HRESULT from WriteEntityChunks() must finish the request. 35 | TEST_F(WriteResponseStepTest, ReturnedFinishedOnError) 36 | { 37 | auto msg = std::make_unique(); 38 | msg->content = "some content"; 39 | WriteResponseStep step(handler, std::move(msg), ""); 40 | 41 | EXPECT_CALL(response, WriteEntityChunks(_, _, _, _, _, _)) 42 | .WillOnce(Return(E_ACCESSDENIED)); 43 | 44 | EXPECT_EQ(kStepFinishRequest, step.Enter()); 45 | } 46 | 47 | // A write that will complete asynchronously (completion_expected == TRUE) makes Enter() return kStepAsyncPending. 48 | TEST_F(WriteResponseStepTest, ReturnsAsyncPending) 49 | { 50 | auto msg = std::make_unique(); 51 | msg->content = "some content"; 52 | WriteResponseStep step(handler, std::move(msg), ""); 53 | 54 | EXPECT_CALL(response, WriteEntityChunks(_, _, _, _, _, _)) 55 | .WillOnce(DoAll( 56 | SetArgPointee<5>(TRUE), // completion_expected 57 | Return(S_OK) 58 | )); 59 | 60 | EXPECT_EQ(kStepAsyncPending, step.Enter()); 61 | } 62 | 63 | 64 | // If we repeatedly get completion_expected==FALSE, we might 65 | // finish this step without ever explicitly calling OnAsyncCompletion.
66 | TEST_F(WriteResponseStepTest, HandlesCompletionExpectedFalse) 67 | { 68 | auto msg = std::make_unique(); 69 | msg->content = "some content"; 70 | WriteResponseStep step(handler, std::move(msg), ""); 71 | 72 | // Right now OnAsyncCompletion() will always return kStepFinishRequest, 73 | // as we ignore the num_bytes. (See comments there). 74 | EXPECT_CALL(response, WriteEntityChunks(_, _, _, _, _, _)) 75 | .WillOnce(DoAll( 76 | SetArgPointee<4>(12), // num_bytes 77 | SetArgPointee<5>(FALSE), // completion_expected 78 | Return(S_OK) 79 | )); 80 | 81 | EXPECT_EQ(kStepFinishRequest, step.Enter()); 82 | } 83 | 84 | 85 | /* Skipped: OnAsyncCompletion() currently ignores num_bytes. 86 | TEST_F(WriteResponseStepTest, OnAsyncCompletionMoreToWrite) 87 | { 88 | auto msg = std::make_unique(); 89 | msg->content = "some content"; 90 | WriteResponseStep step(handler, std::move(msg), ""); 91 | 92 | EXPECT_EQ(kStepRerun, step.OnAsyncCompletion(S_OK, 4)); 93 | } 94 | */ 95 | 96 | // 12 bytes is all of "some content"; without more_content the request finishes. 97 | TEST_F(WriteResponseStepTest, OnAsyncCompletionNoMoreToWrite) 98 | { 99 | auto msg = std::make_unique(); 100 | msg->content = "some content"; 101 | WriteResponseStep step(handler, std::move(msg), ""); 102 | 103 | EXPECT_EQ(kStepFinishRequest, step.OnAsyncCompletion(S_OK, 12)); 104 | } 105 | 106 | // With more_content set, the step moves on (kStepGotoNext) instead of finishing the request. 107 | TEST_F(WriteResponseStepTest, OnAsyncCompletionChunkedResponse) 108 | { 109 | auto msg = std::make_unique(); 110 | msg->content = "some content"; 111 | msg->more_content = true; 112 | WriteResponseStep step(handler, std::move(msg), "reply_channel"); 113 | 114 | EXPECT_EQ(kStepGotoNext, step.OnAsyncCompletion(S_OK, 12)); 115 | } 116 | -------------------------------------------------------------------------------- /AsgiHandlerLib/WsReader.cpp: -------------------------------------------------------------------------------- 1 | // WsReader reads messages from a WebSocket and sends them to the channels 2 | // layer.
WsReader is owned and started by WsRequestHandler, but it does not receive 3 | // it's callbacks from WsRequestHandler or IIS' normal request pipeline. It registers 4 | // callbacks with IIS' WebSocket module via the IWebSocketContext. 5 | 6 | #define WIN32_LEAN_AND_MEAN 7 | #include 8 | #include 9 | 10 | #include "WsReader.h" 11 | #include "WsRequestHandler.h" 12 | #include "RedisChannelLayer.h" 13 | #include "Logger.h" 14 | 15 | 16 | WsReader::WsReader(WsRequestHandler & handler) 17 | : logger(handler.logger), m_channels(handler.m_channels), 18 | m_http_context(handler.m_http_context) 19 | { } 20 | 21 | 22 | void WsReader::Start(const std::string& reply_channel, const std::string& request_path) 23 | { 24 | auto http_context3 = static_cast(m_http_context); 25 | m_ws_context = static_cast( 26 | http_context3->GetNamedContextContainer()->GetNamedContext(L"websockets") 27 | ); 28 | 29 | m_msg.reply_channel = reply_channel; 30 | m_msg.path = request_path; 31 | 32 | ReadAsync(); 33 | } 34 | 35 | 36 | void WsReader::ReadAsync() 37 | { 38 | logger.debug() << "ReadAsync()"; 39 | 40 | BOOL completion_expected = false; 41 | while (!completion_expected) { 42 | DWORD num_bytes = m_msg.data.size(); 43 | BOOL utf8 = false, final_fragment = false, close = false; 44 | 45 | HRESULT hr = m_ws_context->ReadFragment( 46 | m_msg.data.data(), &num_bytes, true, &utf8, &final_fragment, &close, 47 | ReadCallback, this, &completion_expected 48 | ); 49 | if (FAILED(hr)) { 50 | logger.debug() << "ReadFragment() = " << hr; 51 | // TODO: Call an Error() or something. 52 | // TODO: Figure out how to close the request from here. 
53 | } 54 | 55 | if (!completion_expected) { 56 | logger.debug() << "ReadFragment() returned completion_expected=false"; 57 | ReadAsyncComplete(S_OK, num_bytes, utf8, final_fragment, close); 58 | } 59 | } 60 | 61 | logger.debug() << "ReadAsync() returning"; 62 | } 63 | 64 | 65 | void WsReader::ReadAsyncComplete(HRESULT hr, DWORD num_bytes, BOOL utf8, BOOL final_fragment, BOOL close) 66 | { 67 | logger.debug() << "ReadAsyncComplete() " << num_bytes << " " << utf8 << " " << final_fragment << " " << close; 68 | 69 | if (FAILED(hr)) { 70 | logger.debug() << "ReadAsyncComplete() hr = " << hr; 71 | // TODO: Figure out how to propogate an error from here. 72 | } 73 | // TODO: Handle close. 74 | 75 | m_msg.data_size += num_bytes; 76 | m_msg.utf8 = utf8; 77 | 78 | if (final_fragment) { 79 | // Don't bother re-sizing the buffer to the correct size before sending. 80 | // The msgpack serializer will take care of it. 81 | SendToApplication(); 82 | } else { 83 | // If we're almost at the end of the buffer, increase the buffer size. 84 | if (m_msg.data_size >= (m_msg.data.size() - AsgiWsReceiveMsg::BUFFER_CHUNK_INCREASE_THRESHOLD)) { 85 | m_msg.data.resize(m_msg.data.size() + AsgiWsReceiveMsg::BUFFER_CHUNK_SIZE); 86 | } 87 | } 88 | } 89 | 90 | 91 | void WsReader::SendToApplication() 92 | { 93 | // Send synchronously for now. 94 | msgpack::sbuffer buffer; 95 | msgpack::pack(buffer, m_msg); 96 | m_channels.Send("websocket.receive", buffer); 97 | 98 | // Reset for the next msg. 
99 | m_msg.data.resize(AsgiWsReceiveMsg::BUFFER_CHUNK_SIZE); 100 | m_msg.data_size = 0; 101 | m_msg.order += 1; 102 | } 103 | 104 | 105 | void WINAPI WsReader::ReadCallback(HRESULT hr, VOID* context, DWORD num_bytes, BOOL utf8, BOOL final_fragment, BOOL close) 106 | { 107 | auto reader = static_cast(context); 108 | reader->ReadAsyncComplete(hr, num_bytes, utf8, final_fragment, close); 109 | reader->ReadAsync(); 110 | } 111 | -------------------------------------------------------------------------------- /AsgiHandlerLib/RedisChannelLayer.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #define WIN32_LEAN_AND_MEAN 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "RedisChannelLayer.h" 10 | #include "AsgiHttpRequestMsg.h" 11 | 12 | 13 | RedisChannelLayer::RedisChannelLayer(std::string ip, int port, std::string prefix) 14 | : m_redis_ctx(nullptr), m_prefix(prefix), m_expiry(60) 15 | { 16 | struct timeval timeout = { 1, 500000 }; // 1.5 seconds 17 | m_redis_ctx = redisConnectWithTimeout(ip.c_str(), port, timeout); 18 | if (m_redis_ctx == nullptr || m_redis_ctx->err) { 19 | if (m_redis_ctx == nullptr) { 20 | std::cerr << "Connection error: " << m_redis_ctx->errstr << std::endl; 21 | redisFree(m_redis_ctx); 22 | } else { 23 | std::cerr << "Could not create redis context" << std::endl; 24 | } 25 | // TODO: Some king of exception handling scheme. 26 | ::exit(1); 27 | } 28 | } 29 | 30 | 31 | RedisChannelLayer::~RedisChannelLayer() 32 | { 33 | redisFree(m_redis_ctx); 34 | } 35 | 36 | void RedisChannelLayer::Send(const std::string& channel, const msgpack::sbuffer& buffer) 37 | { 38 | // asgi_redis chooses to store the data in a random key, then add the key to 39 | // the channel. (This allows us to set the data to expire, which we do). 
40 | std::string data_key = m_prefix + GenerateRandomAscii(10); 41 | ExecuteRedisCommand("SET %s %b", data_key.c_str(), buffer.data(), buffer.size()); 42 | ExecuteRedisCommand("EXPIRE %s %i", data_key.c_str(), m_expiry); 43 | 44 | // We also set expiry on the channel. (Subsequent Send()s will bump this expiry 45 | // up). We set to +10, because asgi_redis does... presumably to workaround the 46 | // fact that time has passed since we put the data into the data_key? 47 | std::string channel_key = m_prefix + channel; 48 | ExecuteRedisCommand("RPUSH %s %b", channel_key.c_str(), data_key.c_str(), data_key.length()); 49 | ExecuteRedisCommand("EXPIRE %s %i", channel_key.c_str(), m_expiry + 10); 50 | } 51 | 52 | std::tuple RedisChannelLayer::ReceiveMany(const std::vector& channels, bool blocking) 53 | { 54 | std::vector prefixed_channels(channels.size()); 55 | for (const auto& it : channels) { 56 | prefixed_channels.push_back(m_prefix + it); 57 | } 58 | 59 | // Shuffle to avoid one channel starving the others. 60 | std::shuffle(prefixed_channels.begin(), prefixed_channels.end(), m_random_engine); 61 | 62 | // Build up command for BLPOP 63 | // TODO: Allow a non-blocking version. (Possibly only a non-blocking version?) 64 | // It looks like we might need a custom Lua script. Eek. 65 | std::vector buffer; 66 | std::string blpop("BLPOP"); 67 | buffer.push_back(blpop.c_str()); 68 | for (const auto& it : prefixed_channels) { 69 | buffer.push_back(it.c_str()); 70 | } 71 | std::string timeout("1"); 72 | buffer.push_back(timeout.c_str()); 73 | 74 | // Run the BLPOP. Do this manually for now, rather than trying to figure out a sensible API for 75 | // ExecuteRedisCommand to allow it to use redisCommandArgv(). 
76 | RedisReply reply( 77 | static_cast(redisCommandArgv(m_redis_ctx, buffer.size(), buffer.data(), nullptr)), 78 | freeReplyObject 79 | ); 80 | if (reply->type == REDIS_REPLY_NIL) { 81 | return std::make_tuple("", ""); 82 | } 83 | // Remove the prefix before sharing the channel name with others. 84 | std::string channel(reply->element[0]->str + m_prefix.size(), reply->element[0]->len - m_prefix.size()); 85 | 86 | // The response data is not actually stored in the channel. The channel contains a key. 87 | reply = ExecuteRedisCommand("GET %s", reply->element[1]->str); 88 | if (reply->type == REDIS_REPLY_NIL) { 89 | // This usually means that the message has expired. When this happens, 90 | // asgi_redis will loop and try to pull the next item from the list. 91 | // TODO: We should loop here too. 92 | return std::make_tuple("", ""); 93 | } 94 | 95 | // TODO: Think of a way to avoid extra copies. Perhaps a msgpack::object 96 | // with pointers into the original redisReply. 97 | return std::make_tuple(channel, std::string(reply->str, reply->len)); 98 | } 99 | -------------------------------------------------------------------------------- /ProcessPoolLib/GlobalModule.cpp: -------------------------------------------------------------------------------- 1 | #include "GlobalModule.h" 2 | 3 | #include 4 | #include 5 | 6 | #define WIN32_LEAN_AND_MEAN 7 | #include 8 | 9 | #include "ScopedResources.h" 10 | 11 | 12 | namespace { 13 | // {8FC896E8-4370-4976-855B-19B70C976414} 14 | static const GUID logger_etw_guid = { 0x8fc896e8, 0x4370, 0x4976,{ 0x85, 0x5b, 0x19, 0xb7, 0xc, 0x97, 0x64, 0x14 } }; 15 | } 16 | 17 | 18 | GlobalModule::GlobalModule(IHttpServer *http_server) 19 | : logger(logger_etw_guid), m_http_server(http_server) 20 | { 21 | logger.debug() << "Constructed GlobalModule"; 22 | } 23 | 24 | 25 | GLOBAL_NOTIFICATION_STATUS GlobalModule::OnGlobalApplicationStart( 26 | IHttpApplicationStartProvider* provider 27 | ) 28 | { 29 | auto app_name = std::wstring{ 
provider->GetApplication()->GetApplicationId() }; 30 | logger.debug() << "OnGlobalApplicationStart: " << app_name; 31 | 32 | try { 33 | LoadConfiguration(provider->GetApplication()); 34 | } catch (const std::runtime_error& e) { 35 | logger.debug() << "Error loading configuration for " << app_name << ": " 36 | << ""; 37 | } 38 | 39 | return GL_NOTIFICATION_CONTINUE; 40 | } 41 | 42 | 43 | GLOBAL_NOTIFICATION_STATUS GlobalModule::OnGlobalApplicationStop( 44 | IHttpApplicationStopProvider* provider 45 | ) 46 | { 47 | auto app_name = std::wstring{ provider->GetApplication()->GetApplicationId() }; 48 | logger.debug() << "OnGlobalApplicationStop: " << app_name << " " 49 | << "Terminating running pools."; 50 | 51 | // Destroy all our pools, which has the side-effect of terminating 52 | // their processes. 53 | m_pools.clear(); 54 | 55 | logger.debug() << "OnGlobalApplicationStop finished"; 56 | 57 | return GL_NOTIFICATION_CONTINUE; 58 | } 59 | 60 | 61 | namespace { 62 | 63 | void RaiseOnFailure(HRESULT hr, const std::string& msg) 64 | { 65 | if (FAILED(hr)) { 66 | throw std::runtime_error(msg); 67 | } 68 | } 69 | 70 | } // end anonymous namespace 71 | 72 | 73 | void GlobalModule::LoadConfiguration(IHttpApplication *application) 74 | { 75 | auto section_name = ScopedBstr{ L"system.webServer/processPools" }; 76 | auto config_path = ScopedBstr{ application->GetAppConfigPath() }; 77 | 78 | auto section = ScopedConfig{}; 79 | auto hr = m_http_server->GetAdminManager()->GetAdminSection( 80 | section_name.Get(), config_path.Get(), section.Receive() 81 | ); 82 | if (FAILED(hr)) { 83 | // The schema is probably not installed. When we have logging, 84 | // we should log an error, but shouldn't crash the server.. 85 | logger.debug() << "Could not open system.webServer/processPools. 
" 86 | << "Is the config schema installed?"; 87 | return; 88 | } 89 | 90 | auto collection = ScopedConfig{}; 91 | RaiseOnFailure( 92 | section->get_Collection(collection.Receive()), 93 | "get_Collection()" 94 | ); 95 | 96 | auto element_count = DWORD{ 0 }; 97 | RaiseOnFailure( 98 | collection->get_Count(&element_count), 99 | "get_Count()" 100 | ); 101 | 102 | for (auto element_idx = 0; element_idx < element_count; element_idx++) { 103 | auto element_variant = VARIANT{ 0 }; 104 | element_variant.vt = VT_I4; 105 | element_variant.lVal = element_idx; 106 | 107 | auto element = ScopedConfig{}; 108 | RaiseOnFailure( 109 | collection->get_Item(element_variant, element.Receive()), 110 | "get_Item()" 111 | ); 112 | 113 | // Each element should be a , which will have the following properties: 114 | // TODO: Make arguments and count optional. 115 | auto executable = GetProperty(element.Get(), L"executable"); 116 | auto arguments = GetProperty(element.Get(), L"arguments"); 117 | auto count = std::stoi(GetProperty(element.Get(), L"count")); 118 | 119 | // Create a ProcessPool for each: 120 | m_pools.push_back(std::make_unique(logger, executable, arguments, count)); 121 | } 122 | } 123 | 124 | std::wstring GlobalModule::GetProperty(IAppHostElement *element, const std::wstring& name) 125 | { 126 | auto bstr_name = ScopedBstr{ name.c_str() }; 127 | auto property = ScopedConfig{}; 128 | RaiseOnFailure( 129 | element->GetPropertyByName(bstr_name.Get(), property.Receive()), 130 | "GetPropertyByName()" 131 | ); 132 | 133 | auto bstr_value = BSTR{ nullptr }; 134 | RaiseOnFailure( 135 | property->get_StringValue(&bstr_value), 136 | "get_StringValue()" 137 | ); 138 | 139 | return std::wstring{ bstr_value }; 140 | } 141 | 142 | -------------------------------------------------------------------------------- /AsgiHandlerLib/WsWriter.cpp: -------------------------------------------------------------------------------- 1 | // WsWriter reads AsgiWsSendMsges from the Channels layer and then 
writes them 2 | // to the web socket. WsWriter is owned and started by WsRequestHandler. Like WsReader, 3 | // it doesn't receive any callbacks from IIS' normal request pipeline, and instead 4 | // receives callbacks directly fromt he WebSocket module via the IWebSocketContext. 5 | 6 | #define WIN32_LEAN_AND_MEAN 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #include "WsWriter.h" 14 | #include "AsgiWsSendMsg.h" 15 | #include "WsRequestHandler.h" 16 | #include "ResponsePump.h" 17 | #include "Logger.h" 18 | 19 | 20 | WsWriter::WsWriter(WsRequestHandler & handler) 21 | : logger(handler.logger), m_response_pump(handler.m_response_pump), 22 | m_http_context(handler.m_http_context) 23 | { } 24 | 25 | 26 | void WsWriter::Start(const std::string & reply_channel) 27 | { 28 | auto http_context3 = static_cast(m_http_context); 29 | m_ws_context = static_cast( 30 | http_context3->GetNamedContextContainer()->GetNamedContext(L"websockets") 31 | ); 32 | 33 | m_reply_channel = reply_channel; 34 | 35 | RegisterChannelsCallback(); 36 | } 37 | 38 | 39 | void WsWriter::RegisterChannelsCallback() 40 | { 41 | logger.debug() << "RegisterChannelsCallback()"; 42 | 43 | m_response_pump.AddChannel(m_reply_channel, [this](std::string data) { 44 | logger.debug() << "Received AsgiWsSendMsg on channel " << m_reply_channel; 45 | // This makes me a bit nervous... we're calling WriteAsync() which operates 46 | // on IWebSocketContext from whichever thread our ResponsePump callback 47 | // happens to come on. This is unlikely to be on the IIS (or WebSocket module) 48 | // thread pool and it is difficult to know whether it should be. 49 | m_asgi_send_msg = std::make_unique( 50 | msgpack::unpack(data.data(), data.length()).get().as() 51 | ); 52 | WriteAsync(); 53 | }); 54 | } 55 | 56 | void WsWriter::WriteAsync() 57 | { 58 | // TODO: Handle m_asgi_send_msg->close. 
59 | logger.debug() << "WriteAsync()"; 60 | 61 | BOOL completion_expected = false; 62 | while (!completion_expected) { 63 | DWORD num_bytes = m_asgi_send_msg->bytes.size() - m_bytes_written; 64 | BOOL utf8 = true; // TODO: utf8/bytes mode 65 | BOOL final_fragment = true; 66 | 67 | HRESULT hr = m_ws_context->WriteFragment( 68 | m_asgi_send_msg->bytes.data(), &num_bytes, true, utf8, final_fragment, 69 | WriteCallback, this, &completion_expected 70 | ); 71 | if (FAILED(hr)) { 72 | logger.debug() << "WriteFragment() = " << hr; 73 | // TODO: Call an Error() or something. 74 | // TODO: Figure out how to close the request from here. 75 | } 76 | 77 | if (!completion_expected) { 78 | logger.debug() << "WriteFragment() returned completion_expected=false"; 79 | auto more_to_write = WriteAsyncComplete(S_OK, num_bytes, utf8, final_fragment, false); 80 | if (!more_to_write) { 81 | return; 82 | } 83 | } 84 | } 85 | 86 | logger.debug() << "WriteAsync() returning"; 87 | } 88 | 89 | 90 | bool WsWriter::WriteAsyncComplete(HRESULT hr, DWORD num_bytes, BOOL utf8, BOOL final_fragment, BOOL close) 91 | { 92 | logger.debug() << "WriteAsyncComplete() " << num_bytes << " " << utf8 << " " << final_fragment << " " << close; 93 | 94 | if (FAILED(hr)) { 95 | logger.debug() << "WriteAsyncComplete() hr = " << hr; 96 | // TODO: Figure out how to propogate an error from here. 97 | } 98 | // TODO: Handle close. Will it be given for the WriteFragment() callback, 99 | // given that it isn't in the WriteFragment() signature? 100 | 101 | m_bytes_written += num_bytes; 102 | 103 | // If we're finished, start waiting for the next message. 104 | if (m_bytes_written >= m_asgi_send_msg->bytes.size()) { 105 | // Reset ready for the next message. We do this here so that an idle 106 | // websocket consumes as little memory as possible. 107 | m_asgi_send_msg.release(); 108 | m_bytes_written = 0; 109 | RegisterChannelsCallback(); 110 | return false; // No more to write. 
111 | } else { 112 | return true; // More to write. 113 | } 114 | } 115 | 116 | 117 | void WsWriter::WriteCallback(HRESULT hr, VOID * context, DWORD num_bytes, BOOL utf8, BOOL final_fragment, BOOL close) 118 | { 119 | auto writer = static_cast(context); 120 | auto more_to_write = writer->WriteAsyncComplete(hr, num_bytes, utf8, final_fragment, close); 121 | if (more_to_write) { 122 | writer->WriteAsync(); 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /AsgiHandlerTest/test_ReadBodyStep.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "gmock/gmock.h" 3 | 4 | #include "mock_httpserv.h" 5 | #include "mock_ResponsePump.h" 6 | #include "mock_HttpRequestHandler.h" 7 | 8 | 9 | using ::testing::DoAll; 10 | using ::testing::NiceMock; 11 | using ::testing::Return; 12 | using ::testing::SetArgPointee; 13 | using ::testing::SaveArg; 14 | using ::testing::_; 15 | 16 | 17 | class ReadBodyStepTest : public ::testing::Test 18 | { 19 | public: 20 | ReadBodyStepTest() 21 | : handler(response_pump, &http_context), msg(std::make_unique()), 22 | step(handler, msg) 23 | { 24 | ON_CALL(http_context, GetRequest()) 25 | .WillByDefault(Return(&request)); 26 | } 27 | 28 | NiceMock http_context; 29 | MockResponsePump response_pump; 30 | MockHttpRequestHandler handler; 31 | MockIHttpRequest request; 32 | std::unique_ptr msg; 33 | ReadBodyStep step; 34 | }; 35 | 36 | 37 | TEST_F(ReadBodyStepTest, EnterDoesNothingForNoBody) 38 | { 39 | EXPECT_CALL(request, GetRemainingEntityBytes()) 40 | .WillRepeatedly(Return(0)); 41 | // We shouldn't get as far as calling ReadEntityBody() 42 | EXPECT_CALL(request, ReadEntityBody(_, _, _, _, _)) 43 | .Times(0); 44 | 45 | EXPECT_EQ(kStepGotoNext, step.Enter()); 46 | } 47 | 48 | 49 | TEST_F(ReadBodyStepTest, EnterFinishesRequestOnError) 50 | { 51 | EXPECT_CALL(request, GetRemainingEntityBytes()) 52 | .WillRepeatedly(Return(1)); 53 
| EXPECT_CALL(request, ReadEntityBody(_, _, _, _, _)) 54 | .Times(1) 55 | .WillOnce(Return(E_INVALIDARG)); 56 | 57 | EXPECT_EQ(kStepFinishRequest, step.Enter()); 58 | } 59 | 60 | 61 | // Handle ReadEntityBody() completion synchronously. 62 | TEST_F(ReadBodyStepTest, EnterHandlesSynchronousRead) 63 | { 64 | EXPECT_CALL(request, GetRemainingEntityBytes()) 65 | .WillOnce(Return(1)) 66 | .WillRepeatedly(Return(0)); 67 | 68 | EXPECT_CALL(request, ReadEntityBody(_, _, _, _, _)) 69 | .Times(1) 70 | .WillOnce(DoAll( 71 | SetArgPointee<3>(1), // bytes_read 72 | SetArgPointee<4>(FALSE), // completion_expected 73 | Return(S_OK) 74 | )); 75 | 76 | EXPECT_EQ(kStepGotoNext, step.Enter()); 77 | } 78 | 79 | 80 | // Continues to read body if ReadEntityBody() completes synchronously 81 | // and there's more data to read. Subseqeunt calls to ReadEntityBody() 82 | // give a pointer to a different place in the buffer. 83 | TEST_F(ReadBodyStepTest, EnterContinuesToRead) 84 | { 85 | EXPECT_CALL(request, GetRemainingEntityBytes()) 86 | .WillOnce(Return(2)) 87 | .WillOnce(Return(1)) 88 | .WillRepeatedly(Return(0)); 89 | 90 | VOID* buffer_ptr1 = nullptr; 91 | VOID* buffer_ptr2 = nullptr; 92 | 93 | EXPECT_CALL(request, ReadEntityBody(_, _, _, _, _)) 94 | .Times(2) 95 | .WillOnce(DoAll( 96 | SaveArg<0>(&buffer_ptr1), 97 | SetArgPointee<3>(1), // bytes_read 98 | SetArgPointee<4>(FALSE), // completion_expected 99 | Return(S_OK) 100 | )) 101 | .WillOnce(DoAll( 102 | SaveArg<0>(&buffer_ptr2), 103 | SetArgPointee<3>(1), // bytes_read 104 | SetArgPointee<4>(FALSE), // completion_expected 105 | Return(S_OK) 106 | )); 107 | 108 | EXPECT_EQ(kStepGotoNext, step.Enter()); 109 | EXPECT_EQ(1, (char*)buffer_ptr2 - (char*)buffer_ptr1); 110 | } 111 | 112 | 113 | // Returns kStepAsyncPending when ReadEntityBody() doesn't complete synchronously. 
114 | TEST_F(ReadBodyStepTest, EnterReturnsAsyncPending) 115 | { 116 | EXPECT_CALL(request, GetRemainingEntityBytes()) 117 | .WillRepeatedly(Return(1)); 118 | 119 | // completion_expected == TRUE means the read is still pending, so Enter() must 120 | // return kStepAsyncPending without calling ReadEntityBody() a second time. 121 | EXPECT_CALL(request, ReadEntityBody(_, _, _, _, _)) 122 | .Times(1) 123 | .WillOnce(DoAll( 124 | SetArgPointee<4>(TRUE), // completion_expected 125 | Return(S_OK) 126 | )); 127 | 128 | EXPECT_EQ(kStepAsyncPending, step.Enter()); 129 | } 130 | 131 | // An error HRESULT delivered to OnAsyncCompletion() finishes the request. 132 | TEST_F(ReadBodyStepTest, OnAsyncCompletionFinishesRequestOnError) 133 | { 134 | EXPECT_EQ(kStepFinishRequest, step.OnAsyncCompletion(E_ACCESSDENIED, 1)); 135 | } 136 | 137 | 138 | // Should return kStepRerun when there's more data to read. 139 | TEST_F(ReadBodyStepTest, OnAsyncCompletionMoreData) 140 | { 141 | EXPECT_CALL(request, GetRemainingEntityBytes()) 142 | .WillRepeatedly(Return(1)); 143 | 144 | EXPECT_EQ(kStepRerun, step.OnAsyncCompletion(S_OK, 1)); 145 | } 146 | 147 | 148 | // Should goto next step when there's no more data to read.
149 | TEST_F(ReadBodyStepTest, OnAsyncCompletionNoMoreData) 150 | { 151 | EXPECT_CALL(request, GetRemainingEntityBytes()) 152 | .WillRepeatedly(Return(0)); 153 | 154 | EXPECT_EQ(kStepGotoNext, step.OnAsyncCompletion(S_OK, 1)); 155 | } 156 | -------------------------------------------------------------------------------- /IntegrationTests/fixtures/etw.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | from __future__ import unicode_literals, print_function, absolute_import 5 | 6 | import datetime 7 | import os 8 | import threading 9 | import time 10 | 11 | import pytest 12 | 13 | import etw 14 | import etw.descriptors.event 15 | import etw.descriptors.field 16 | 17 | 18 | def Consumer(guid): 19 | """Defines boiler-plate classes for the default 'string' events.""" 20 | 21 | class EventDefinitions(object): 22 | GUID = guid 23 | StringEvent = (GUID, 0) 24 | 25 | class EventProvider(etw.descriptors.event.EventCategory): 26 | GUID = EventDefinitions.GUID 27 | VERSION = 0 28 | 29 | class StringEventClass(etw.descriptors.event.EventClass): 30 | _event_types_ = [EventDefinitions.StringEvent] 31 | _fields_ = [('data', etw.descriptors.field.WString)] 32 | 33 | class _Consumer(etw.EventConsumer): 34 | def __init__(self, *args, **kwargs): 35 | super(_Consumer, self).__init__(*args, **kwargs) 36 | self.events = [] 37 | 38 | @etw.EventHandler(EventDefinitions.StringEvent) 39 | def on_event_message(self, event): 40 | self.events.append(event) 41 | 42 | def get_report(self): 43 | return '\r\n'.join( 44 | '%s P%04dT%04d %s' % ( 45 | datetime.datetime.fromtimestamp(event.time_stamp).isoformat(), 46 | event.process_id, event.thread_id, event.data 47 | ) 48 | for event in self.events 49 | ) 50 | 51 | return _Consumer() 52 | 53 | 54 | class ConsumerThread(threading.Thread): 55 | def __init__(self, source, *args, **kwargs): 56 | threading.Thread.__init__(self, *args, **kwargs) 57 | self.source = 
source 58 | 59 | def run(self): 60 | self.source.Consume() 61 | 62 | 63 | class EtwSession(object): 64 | 65 | def __init__(self, guid, session_name): 66 | self.guid = guid 67 | self.session_name = session_name 68 | 69 | def start(self): 70 | # We use REAL_TIME and do not specify a LogFile. 71 | props = etw.TraceProperties() 72 | props_struct = props.get() 73 | props_struct.contents.LogFileMode = etw.evntcons.PROCESS_TRACE_MODE_REAL_TIME 74 | props_struct.contents.LogFileNameOffset = 0 75 | self.controller = etw.TraceController() 76 | self.controller.Start(self.session_name, props) 77 | self.controller.EnableProvider( 78 | etw.evntrace.GUID(self.guid), 79 | etw.evntrace.TRACE_LEVEL_VERBOSE 80 | ) 81 | self.consumer = Consumer(self.guid) 82 | self.event_source = etw.TraceEventSource([self.consumer]) 83 | self.event_source.OpenRealtimeSession(self.session_name) 84 | # We need a separate thread to consume the events. 85 | self.thread = ConsumerThread(self.event_source) 86 | self.thread.start() 87 | 88 | def stop(self): 89 | # Calling .Close() will normally return ERROR_CTX_CLOSE_PENDING as we're 90 | # using REAL_TIME. It may take a little while for the final events to come in. 91 | try: 92 | self.event_source.Close() 93 | except WindowsError as e: 94 | if e.errno != 7007: # ERROR_CTX_CLOSE_PENDING 95 | raise 96 | # Wait for it to finish consuming events. 97 | self.thread.join() 98 | self.controller.DisableProvider(etw.evntrace.GUID(self.guid)) 99 | self.controller.Stop() 100 | 101 | def __enter__(self): 102 | self.start() 103 | return self 104 | def __exit__(self, *args): 105 | self.stop() 106 | 107 | 108 | def _add_report_for_fail(request, report_title, etw_session): 109 | # If the test failed, append a report of all the logs we captured. 110 | # This relies on our hook to add the report to the test object. 
111 | if hasattr(request.node, '_report'): 112 | if not request.node._report.passed: 113 | request.node._report.longrepr.addsection( 114 | 'Captured ETW Logs: %s' % report_title, 115 | etw_session.consumer.get_report() 116 | ) 117 | 118 | 119 | asgi_handler_etw_guid = '{b057f98c-cb95-413d-afae-8ed010db73c5}' 120 | process_pool_etw_guid = '{8fc896e8-4370-4976-855b-19b70c976414}' 121 | 122 | @pytest.yield_fixture 123 | def asgi_etw_consumer(request): 124 | session_name = 'AsgiHandlerSession-' + os.urandom(4).encode('hex') 125 | try: 126 | with EtwSession(asgi_handler_etw_guid, session_name) as etw_session: 127 | yield 128 | finally: 129 | _add_report_for_fail(request, 'AsgiHandler', etw_session) 130 | 131 | @pytest.yield_fixture 132 | def pool_etw_consumer(request): 133 | session_name = 'ProcessPoolSession-' + os.urandom(4).encode('hex') 134 | try: 135 | with EtwSession(process_pool_etw_guid, session_name) as etw_session: 136 | yield 137 | finally: 138 | _add_report_for_fail(request, 'ProcessPool', etw_session) 139 | -------------------------------------------------------------------------------- /AsgiHandlerTest/test_WaitForResponseStep.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | #include "gmock/gmock.h" 3 | 4 | #include "mock_httpserv.h" 5 | #include "mock_ResponsePump.h" 6 | #include "mock_HttpRequestHandler.h" 7 | 8 | 9 | using ::testing::MatcherCast; 10 | using ::testing::NiceMock; 11 | using ::testing::Pointee; 12 | using ::testing::Return; 13 | using ::testing::SaveArg; 14 | using ::testing::StrEq; 15 | using ::testing::_; 16 | 17 | 18 | class WaitForResponseStepTest : public ::testing::Test 19 | { 20 | public: 21 | WaitForResponseStepTest() 22 | : handler(response_pump, &http_context), reply_channel("http.request!blah"), 23 | step(handler, reply_channel) 24 | { 25 | ON_CALL(http_context, GetResponse()) 26 | .WillByDefault(Return(&response)); 27 | 28 | // These functions are called 
often, but only some tests are interested in 29 | // them. The interested tests will EXPECT_CALL. 30 | ON_CALL(response, SetStatus(_, _, _, _, _, _)) 31 | .WillByDefault(Return(S_OK)); 32 | } 33 | 34 | std::string GetDummyResponse( // Returns a msgpack-serialized AsgiHttpResponseMsg built from the arguments. 35 | int status, std::vector> headers, std::string content, 36 | bool more_content = false 37 | ) const 38 | { 39 | AsgiHttpResponseMsg msg; 40 | msg.status = status; 41 | msg.headers = headers; 42 | msg.content = content; 43 | msg.more_content = more_content; 44 | msgpack::sbuffer buffer; 45 | msgpack::pack(buffer, msg); 46 | return std::string(buffer.data(), buffer.size()); 47 | } 48 | 49 | NiceMock http_context; 50 | MockResponsePump response_pump; 51 | MockHttpRequestHandler handler; 52 | NiceMock response; 53 | std::string reply_channel; 54 | WaitForResponseStep step; // declared last: its constructor takes handler and reply_channel 55 | }; 56 | 57 | 58 | TEST_F(WaitForResponseStepTest, EnterAddsCallbackAndReturnsAsyncPending) // Enter() registers reply_channel with the pump and goes async. 59 | { 60 | EXPECT_CALL(response_pump, AddChannel(reply_channel, _)) 61 | .Times(1); 62 | 63 | EXPECT_EQ(kStepAsyncPending, step.Enter()); 64 | } 65 | 66 | 67 | TEST_F(WaitForResponseStepTest, CallbackCallsPostCompletion) // Delivering a response through the saved callback should call PostCompletion(0). 68 | { 69 | ResponsePump::ResponseChannelCallback callback; 70 | EXPECT_CALL(response_pump, AddChannel(reply_channel, _)) 71 | .WillOnce(SaveArg<1>(&callback)); 72 | 73 | EXPECT_EQ(kStepAsyncPending, step.Enter()); 74 | 75 | EXPECT_CALL(http_context, PostCompletion(0)) 76 | .Times(1); 77 | 78 | callback(GetDummyResponse(200, { } , "")); 79 | } 80 | 81 | 82 | TEST_F(WaitForResponseStepTest, OnAsyncCompletionReturnsGotoNextIfBody) // A response with a body should return kStepGotoNext. 83 | { 84 | ResponsePump::ResponseChannelCallback callback; 85 | EXPECT_CALL(response_pump, AddChannel(reply_channel, _)) 86 | .WillOnce(SaveArg<1>(&callback)); 87 | 88 | step.Enter(); 89 | 90 | callback(GetDummyResponse(200, { }, "some body")); 91 | 92 | EXPECT_EQ(kStepGotoNext, step.OnAsyncCompletion(S_OK, 0)); 93 | } 94 | 95 | 96 | TEST_F(WaitForResponseStepTest, OnAsyncCompletionReturnsFinishIfNoBody) // A response with neither body nor more_content should finish the request. 97 | { 98 |
ResponsePump::ResponseChannelCallback callback; 99 | EXPECT_CALL(response_pump, AddChannel(reply_channel, _)) 100 | .WillOnce(SaveArg<1>(&callback)); 101 | 102 | step.Enter(); 103 | 104 | callback(GetDummyResponse(200, { }, "")); 105 | 106 | EXPECT_EQ(kStepFinishRequest, step.OnAsyncCompletion(S_OK, 0)); 107 | } 108 | 109 | 110 | TEST_F(WaitForResponseStepTest, OnAsyncCompletionReturnsGotoNextIfNoBodyButMoreChunks) // An empty chunk with more_content=true should still return kStepGotoNext. 111 | { 112 | ResponsePump::ResponseChannelCallback callback; 113 | EXPECT_CALL(response_pump, AddChannel(reply_channel, _)) 114 | .WillOnce(SaveArg<1>(&callback)); 115 | 116 | step.Enter(); 117 | 118 | callback(GetDummyResponse(200, {}, "", true)); 119 | 120 | EXPECT_EQ(kStepGotoNext, step.OnAsyncCompletion(S_OK, 0)); 121 | } 122 | 123 | 124 | TEST_F(WaitForResponseStepTest, OnAsyncCompletionSetsStatus) // SetStatus() is called for a non-zero status and skipped for 0. 125 | { 126 | ResponsePump::ResponseChannelCallback callback; 127 | EXPECT_CALL(response_pump, AddChannel(reply_channel, _)) 128 | .WillOnce(SaveArg<1>(&callback)); 129 | 130 | step.Enter(); 131 | 132 | // SetStatus should be called with the status. 133 | { 134 | callback(GetDummyResponse(404, { }, "")); 135 | 136 | EXPECT_CALL(response, SetStatus(404, _, _, _, _, _)) 137 | .WillOnce(Return(S_OK)); 138 | 139 | step.OnAsyncCompletion(S_OK, 0); 140 | } 141 | 142 | // SetStatus should not be called if the status is 0 (missing in msgpack). 143 | // This happens for chunks in streaming responses. 144 | { 145 | callback(GetDummyResponse(0, { }, "")); 146 | 147 | EXPECT_CALL(response, SetStatus(_, _, _, _, _, _)) 148 | .Times(0); 149 | 150 | step.OnAsyncCompletion(S_OK, 0); 151 | } 152 | } 153 | 154 | 155 | TEST_F(WaitForResponseStepTest, OnAsyncCompletionSetsHeaders) // SetHeader() is called once per header, and not at all for an empty header list. 156 | { 157 | ResponsePump::ResponseChannelCallback callback; 158 | EXPECT_CALL(response_pump, AddChannel(reply_channel, _)) 159 | .WillOnce(SaveArg<1>(&callback)); 160 | 161 | step.Enter(); 162 | 163 | // SetHeader should be called for each header.
164 | { 165 | callback(GetDummyResponse(404, { 166 | std::make_tuple("header1", "value1"), 167 | std::make_tuple("header2", "value-2") 168 | }, "")); 169 | 170 | EXPECT_CALL(response, SetHeader(MatcherCast(StrEq("header1")), StrEq("value1"), 6, TRUE)) // 6 == length of "value1" 171 | .Times(1); 172 | EXPECT_CALL(response, SetHeader(MatcherCast(StrEq("header2")), StrEq("value-2"), 7, TRUE)) // 7 == length of "value-2" 173 | .Times(1); 174 | 175 | step.OnAsyncCompletion(S_OK, 0); 176 | } 177 | 178 | // If there are no headers, SetHeader should not be called at all. 179 | // This happens for chunks in streaming responses. 180 | { 181 | callback(GetDummyResponse(404, { }, "")); 182 | 183 | EXPECT_CALL(response, SetHeader(MatcherCast(_), _, _, _)) 184 | .Times(0); 185 | 186 | step.OnAsyncCompletion(S_OK, 0); 187 | } 188 | } 189 | -------------------------------------------------------------------------------- /IntegrationTests/test_asgi_ws.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | from __future__ import unicode_literals, print_function 5 | 6 | import os 7 | import sys 8 | import time 9 | import urllib 10 | 11 | import pytest 12 | 13 | from websocket import create_connection 14 | 15 | 16 | def test_asgi_ws_order_0(site, asgi): 17 | # websocket.connect['order'] should always be 0 18 | ws1 = create_connection(site.ws_url, timeout=2) 19 | asgi_connect1 = asgi.receive_ws_connect() 20 | assert asgi_connect1['order'] == 0 21 | ws2 = create_connection(site.ws_url, timeout=2) 22 | asgi_connect2 = asgi.receive_ws_connect() 23 | assert asgi_connect2['order'] == 0 24 | 25 | 26 | def test_asgi_ws_connect_scheme_ws(site, asgi): # a connection to site.ws_url should report scheme 'ws' 27 | ws = create_connection(site.ws_url, timeout=2) 28 | asgi_connect = asgi.receive_ws_connect() 29 | assert asgi_connect['scheme'] == 'ws' 30 | 31 | 32 | @pytest.mark.skip(reason='Need to provide TLS certificate when setting up IIS site.') 33 | def test_asgi_ws_connect_scheme_wss(site, asgi): 34 | ws =
create_connection(site.ws_url, timeout=2) 35 | asgi_connect = asgi.receive_ws_connect() 36 | assert asgi_connect['scheme'] == 'wss' 37 | 38 | 39 | @pytest.mark.parametrize('path', [ 40 | '', 41 | '/onedir/', 42 | '/missing-trailing', 43 | '/multi/level/with/file.txt' 44 | ]) 45 | def test_asgi_ws_connect_path(path, site, asgi): # the connect message's 'path' should match the requested path 46 | ws = create_connection(site.ws_url + path, timeout=2) 47 | asgi_connect = asgi.receive_ws_connect() 48 | # IIS normalizes an empty path to '/' 49 | path = path if path else '/' 50 | assert asgi_connect['path'] == path 51 | 52 | 53 | @pytest.mark.parametrize('qs_parts', [ 54 | None, # means 'do not append ? to url' 55 | [], 56 | [('a', 'b')], 57 | [('a', 'b'), ('c', 'd')], 58 | [('a', 'b c')], 59 | ]) 60 | def test_asgi_ws_connect_querystring(qs_parts, site, asgi): # the query string is expected to include its leading '?', except for a bare '?' 61 | qs = '' 62 | if qs_parts is not None: 63 | qs = '?' + urllib.urlencode(qs_parts) 64 | ws = create_connection(site.ws_url + qs, timeout=2) 65 | asgi_connect = asgi.receive_ws_connect() 66 | # IIS gives a query string of '' when provided with '?' 67 | expected_qs = '' if qs == '?' else qs 68 | assert asgi_connect['query_string'] == expected_qs.encode('utf-8') 69 | 70 | 71 | @pytest.mark.parametrize('headers', [ 72 | {'User-Agent': 'Custom User-Agent'}, # 'Known' header 73 | {'X-Custom-Header': 'Value'}, 74 | ]) 75 | def test_asgi_ws_connect_headers(headers, site, asgi): # each request header should appear as a [name, value] byte pair 76 | ws = create_connection(site.ws_url, header=headers, timeout=2) 77 | asgi_connect = asgi.receive_ws_connect() 78 | for name, value in headers.items(): 79 | encoded_header = [name.encode('utf-8'), value.encode('utf-8')] 80 | assert encoded_header in asgi_connect['headers'] 81 | 82 | 83 | def test_asgi_ws_receive(site, asgi): 84 | ws = create_connection(site.ws_url, timeout=2) 85 | ws.send(b'Hello, world!') 86 | asgi_receive = asgi.receive_ws_data() 87 | assert asgi_receive['bytes'] == b'Hello, world!'
88 | 89 | 90 | def test_asgi_ws_receive_has_correct_channel(site, asgi): # connect and receive messages should share the same websocket.send reply channel 91 | ws = create_connection(site.ws_url, timeout=2) 92 | asgi_connect = asgi.receive_ws_connect() 93 | ws.send('some data') 94 | asgi_receive = asgi.receive_ws_data() 95 | assert asgi_connect['reply_channel'].startswith('websocket.send') 96 | assert asgi_connect['reply_channel'] == asgi_receive['reply_channel'] 97 | 98 | 99 | def test_asgi_ws_receive_has_correct_path(site, asgi): # receive messages should carry the same 'path' as the connect message 100 | path = '/a/path/' 101 | ws = create_connection(site.ws_url + path, timeout=2) 102 | asgi_connect = asgi.receive_ws_connect() 103 | ws.send('some data') 104 | asgi_receive = asgi.receive_ws_data() 105 | assert asgi_connect['path'] == path 106 | assert asgi_connect['path'] == asgi_receive['path'] 107 | 108 | 109 | def test_asgi_ws_receive_multiple_times_same_connection(site, asgi): 110 | ws = create_connection(site.ws_url, timeout=2) 111 | # Check we can receive multiple frames from the websocket connection, 112 | # and ensure we can still send more data after we've received some.
113 | for _ in range(2): 114 | for i in range(10): 115 | ws.send(b'%i' % i) 116 | for i in range(10): 117 | asgi_receive = asgi.receive_ws_data() 118 | assert asgi_receive['bytes'] == b'%i' % i 119 | 120 | 121 | def test_asgi_ws_receive_order_increases(site, asgi): 122 | ws = create_connection(site.ws_url, timeout=2) 123 | ws.send(b' ') 124 | asgi_receive1 = asgi.receive_ws_data() 125 | assert asgi_receive1['order'] == 1 126 | ws.send(b' ') 127 | asgi_receive1 = asgi.receive_ws_data() 128 | assert asgi_receive1['order'] == 2 129 | 130 | # TODO: 131 | # - Receive data of various sizes 132 | # - websocket.receive has bytes/text depending on utf8 133 | 134 | 135 | def test_asgi_ws_send(site, asgi): 136 | ws = create_connection(site.ws_url, timeout=2) 137 | asgi_connect = asgi.receive_ws_connect() 138 | asgi.send(asgi_connect['reply_channel'], dict(bytes=b'Hello, world!', close=False)) 139 | data = ws.recv() 140 | assert data == b'Hello, world!' 141 | 142 | 143 | def test_asgi_ws_send_multiple_times_same_connection(site, asgi): 144 | ws = create_connection(site.ws_url, timeout=2) 145 | asgi_connect = asgi.receive_ws_connect() 146 | for i in range(10): 147 | asgi.send(asgi_connect['reply_channel'], dict(bytes=b'%i' % i, close=False)) 148 | for i in range(10): 149 | data = ws.recv() 150 | assert data == b'%i' % i 151 | 152 | 153 | def test_asgi_ws_concurrent_connections(site, asgi): 154 | # TODO: Paramaterize this test and have it create more connections when 155 | # run on Windows Server, where we can have >10 concurrent connections. 156 | wses = [create_connection(site.ws_url, timeout=2) for _ in range(10)] 157 | for i, ws in enumerate(wses): 158 | for j in range(10): 159 | ws.send(b'%i-%i' % (i, j)) 160 | 161 | # Echo each message back. 162 | for i in range(100): 163 | asgi_msg = asgi.receive_ws_data() 164 | asgi.send(asgi_msg['reply_channel'], dict(bytes=asgi_msg['bytes'])) 165 | 166 | # Ensure each ws got the complete set of replies. 
We can rely on the messages 167 | # being in order, as we have just one worker (us) and the interface server 168 | # handles messages on each WebSocket serially. 169 | for i, ws in enumerate(wses): 170 | for j in range(10): 171 | data = ws.recv() 172 | assert data == b'%i-%i' % (i, j) 173 | 174 | 175 | if __name__ == '__main__': 176 | # Should really sys.exit() this, but it causes Visual Studio 177 | # to eat the output. :( 178 | pytest.main(['--ignore', 'env1/', 'test_ws.py']) 179 | -------------------------------------------------------------------------------- /AsgiHandlerLib/RequestHandler.cpp: -------------------------------------------------------------------------------- 1 | #define WIN32_LEAN_AND_MEAN 2 | #include 3 | 4 | #include "RequestHandler.h" 5 | 6 | 7 | namespace { 8 | 9 | // TODO: Check these and see if http.sys gives us them anywhere - it must 10 | // have them! MJK did these on a train based on the values in http.h 11 | // TODO: Lower case these. The ASGI spec specifies we should give headers 12 | // as lower case, so we may as well store them here that way. 
13 | std::unordered_map kKnownHeaderMap({ // Maps http.sys known-header ids to header names; used by RequestHandler::GetRequestHeaders(). 14 | { HttpHeaderCacheControl, "Cache-Control" }, 15 | { HttpHeaderConnection, "Connection" }, 16 | { HttpHeaderDate, "Date" }, 17 | { HttpHeaderKeepAlive, "Keep-Alive" }, 18 | { HttpHeaderPragma, "Pragma" }, 19 | { HttpHeaderTrailer, "Trailer" }, 20 | { HttpHeaderTransferEncoding, "Transfer-Encoding" }, 21 | { HttpHeaderUpgrade, "Upgrade" }, 22 | { HttpHeaderVia, "Via" }, 23 | { HttpHeaderWarning, "Warning" }, 24 | 25 | { HttpHeaderAllow, "Allow" }, 26 | { HttpHeaderContentLength, "Content-Length" }, 27 | { HttpHeaderContentType, "Content-Type" }, 28 | { HttpHeaderContentEncoding, "Content-Encoding" }, 29 | { HttpHeaderContentLanguage, "Content-Language" }, 30 | { HttpHeaderContentLocation, "Content-Location" }, 31 | { HttpHeaderContentMd5, "Content-Md5" }, 32 | { HttpHeaderContentRange, "Content-Range" }, 33 | { HttpHeaderExpires, "Expires" }, 34 | { HttpHeaderLastModified, "Last-Modified" }, 35 | 36 | { HttpHeaderAccept, "Accept" }, 37 | { HttpHeaderAcceptCharset, "Accept-Charset" }, 38 | { HttpHeaderAcceptEncoding, "Accept-Encoding" }, 39 | { HttpHeaderAcceptLanguage, "Accept-Language" }, 40 | { HttpHeaderAuthorization, "Authorization" }, 41 | { HttpHeaderCookie, "Cookie" }, 42 | { HttpHeaderExpect, "Expect" }, 43 | { HttpHeaderFrom, "From" }, 44 | { HttpHeaderHost, "Host" }, 45 | { HttpHeaderIfMatch, "If-Match" }, 46 | 47 | { HttpHeaderIfModifiedSince, "If-Modified-Since" }, 48 | { HttpHeaderIfNoneMatch, "If-None-Match" }, 49 | { HttpHeaderIfRange, "If-Range" }, 50 | { HttpHeaderIfUnmodifiedSince, "If-Unmodified-Since" }, 51 | { HttpHeaderMaxForwards, "Max-Forwards" }, 52 | { HttpHeaderProxyAuthorization, "Proxy-Authorization" }, 53 | { HttpHeaderReferer, "Referer" }, 54 | { HttpHeaderRange, "Range" }, 55 | { HttpHeaderTe, "Te" }, 56 | { HttpHeaderTranslate, "Translate" }, 57 | 58 | { HttpHeaderUserAgent, "User-Agent" } 59 | }); 60 | 61 | } // end anonymous namespace 62 | 63 | 64 |
RequestHandlerStep::RequestHandlerStep(RequestHandler& handler) // Copies references to the handler's shared state so each step can use them directly. 65 | : m_handler(handler), m_http_context(handler.m_http_context), 66 | logger(handler.logger), m_response_pump(handler.m_response_pump), 67 | m_channels(handler.m_channels) 68 | { } 69 | 70 | 71 | REQUEST_NOTIFICATION_STATUS RequestHandler::HandlerStateMachine(std::unique_ptr& step, StepResult result) // Runs step transitions until a step goes async or finishes; returns the IIS notification status. 72 | { 73 | // This won't loop forever. We expect to return AsyncPending fairly often. 74 | while (true) { 75 | switch (result) { 76 | case kStepAsyncPending: 77 | return RQ_NOTIFICATION_PENDING; 78 | case kStepFinishRequest: 79 | return RQ_NOTIFICATION_FINISH_REQUEST; 80 | case kStepRerun: { // re-enter the current step 81 | 82 | logger.debug() << typeid(*step.get()).name() << "->Enter() being called"; 83 | result = step->Enter(); 84 | logger.debug() << typeid(*step.get()).name() << "->Enter() = " << result; 85 | 86 | break; 87 | } 88 | case kStepGotoNext: { // replace step with its successor, then enter it 89 | 90 | logger.debug() << typeid(*step.get()).name() << "->GetNextStep() being called"; 91 | step = step->GetNextStep(); 92 | 93 | logger.debug() << typeid(*step.get()).name() << "->Enter() being called"; 94 | result = step->Enter(); 95 | logger.debug() << typeid(*step.get()).name() << "->Enter() = " << result; 96 | 97 | break; 98 | } 99 | } 100 | } 101 | // Never reached. 102 | } 103 | 104 | 105 | std::string RequestHandler::GetRequestHttpVersion(const IHttpRequest* request) // "1.x" for HTTP/1, otherwise just the major version. 106 | { 107 | USHORT http_ver_major, http_ver_minor; 108 | request->GetHttpVersion(&http_ver_major, &http_ver_minor); 109 | std::string http_version = std::to_string(http_ver_major); 110 | if (http_ver_major == 1) { 111 | http_version += "."
+ std::to_string(http_ver_minor); 112 | } 113 | return http_version; 114 | } 115 | 116 | 117 | std::string RequestHandler::GetRequestScheme(const HTTP_REQUEST* raw_request) // Text before the first ':' of the cooked full URL, as UTF-8 ("" if there is no ':'). 118 | { 119 | std::wstring url( 120 | raw_request->CookedUrl.pFullUrl, 121 | raw_request->CookedUrl.FullUrlLength / sizeof(wchar_t) 122 | ); 123 | auto colon_idx = url.find(':'); 124 | std::wstring scheme_w; 125 | // Don't assume the cooked URL has a scheme, although it should. 126 | if (colon_idx != url.npos) { 127 | scheme_w = url.substr(0, colon_idx); 128 | } 129 | std::wstring_convert> utf8_conv; 130 | return utf8_conv.to_bytes(scheme_w); 131 | } 132 | 133 | 134 | std::string RequestHandler::GetRequestPath(const HTTP_REQUEST* raw_request) // Cooked absolute path as UTF-8 (the length fields are in bytes, hence / sizeof(wchar_t)). 135 | { 136 | std::wstring_convert> utf8_conv; 137 | return utf8_conv.to_bytes(std::wstring( 138 | raw_request->CookedUrl.pAbsPath, 139 | raw_request->CookedUrl.AbsPathLength / sizeof(wchar_t) 140 | )); 141 | } 142 | 143 | std::string RequestHandler::GetRequestQueryString(const HTTP_REQUEST* raw_request) // Cooked query string as UTF-8. 144 | { 145 | std::wstring_convert> utf8_conv; 146 | return utf8_conv.to_bytes(std::wstring( 147 | raw_request->CookedUrl.pQueryString, 148 | raw_request->CookedUrl.QueryStringLength / sizeof(wchar_t) 149 | )); 150 | } 151 | 152 | 153 | std::vector> RequestHandler::GetRequestHeaders(const HTTP_REQUEST* raw_request) // Known headers (named via kKnownHeaderMap) then unknown headers, as (name, value) pairs. 154 | { 155 | std::vector> headers; 156 | auto known_headers = raw_request->Headers.KnownHeaders; 157 | for (USHORT i = 0; i < HttpHeaderRequestMaximum; i++) { 158 | if (known_headers[i].RawValueLength > 0) { 159 | auto value = std::string(known_headers[i].pRawValue, known_headers[i].RawValueLength); 160 | headers.push_back(std::make_tuple(kKnownHeaderMap[i], value)); 161 | } 162 | } 163 | 164 | // TODO: ASGI specifies headers should be lower-cases. Makes sense to do that here? 165 |
auto unknown_headers = raw_request->Headers.pUnknownHeaders; 166 | for (USHORT i = 0; i < raw_request->Headers.UnknownHeaderCount; i++) { 167 | auto name = std::string(unknown_headers[i].pName, unknown_headers[i].NameLength); 168 | auto value = std::string(unknown_headers[i].pRawValue, unknown_headers[i].RawValueLength); 169 | headers.push_back(std::make_tuple(name, value)); 170 | } 171 | 172 | return headers; 173 | } 174 | -------------------------------------------------------------------------------- /ProcessPoolLib/ProcessPool.cpp: -------------------------------------------------------------------------------- 1 | #include "ProcessPool.h" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #define WIN32_LEAN_AND_MEAN 9 | #include 10 | 11 | #include "Logger.h" 12 | 13 | 14 | ProcessPool::ProcessPool( // Builds the command line and starts the monitoring thread (ThreadMain). 15 | const Logger& logger, const std::wstring& process, 16 | const std::wstring& arguments, size_t num_processes 17 | ) 18 | : logger{ logger }, m_process{ process }, m_arguments{ arguments }, 19 | m_num_processes{ num_processes } 20 | { 21 | // Make a combined command line, escaping process as necessary. 22 | m_command_line = EscapeArgument(process) + L" " + arguments; 23 | 24 | logger.debug() << "ProcessPool initialized with command line: " 25 | << m_command_line; 26 | 27 | m_thread_exit_event = ScopedHandle{ ::CreateEvent(nullptr, TRUE, FALSE, nullptr) }; // manual-reset, initially non-signaled; NOTE(review): a NULL result is not checked 28 | m_thread = std::thread{ &ProcessPool::ThreadMain, this }; 29 | } 30 | 31 | ProcessPool::~ProcessPool() 32 | { 33 | // Stop the monitoring thread and wait for it to join. 34 | // We don't explicitly terminate the running processes - they will terminate 35 | // when our job object is destructed. 36 | Stop(); 37 | m_thread.join(); 38 | } 39 | 40 | void ProcessPool::Stop() const 41 | { 42 | // Signal for the thread to stop.
43 | ::SetEvent(m_thread_exit_event.Get()); 44 | } 45 | 46 | void ProcessPool::ThreadMain() 47 | { 48 | logger.debug() << "ProcessPool::ThreadMain(): Enter"; 49 | 50 | while (true) { 51 | logger.debug() << "ProcessPool::ThreadMain(): Loop"; 52 | 53 | // Start any missing processes. 54 | // TODO: Determine what to do if process creation fails. Should we retry 55 | // a certain number of times and then give up? Should this affect 56 | // the timeout we give to ::WaitForMultipleObjects()? 57 | for (auto i = m_processes.size(); i < m_num_processes; ++i) { 58 | CreateProcess(); 59 | } 60 | 61 | // Get the handles for all processes we've already started and wait until one 62 | // of them is signaled. This will indicate that the process has been terminated. 63 | // Also watch m_thread_exit_event to see if we should terminate. 64 | auto handles = std::vector(m_processes.size() + 1); 65 | std::transform( 66 | std::begin(m_processes), std::end(m_processes), 67 | std::begin(handles), 68 | [](const auto& sh) { return sh.Get(); } 69 | ); 70 | handles[m_processes.size()] = m_thread_exit_event.Get(); 71 | auto wait_result = ::WaitForMultipleObjects( 72 | handles.size(), 73 | handles.data(), 74 | FALSE, 75 | INFINITE 76 | ); 77 | auto signaled_idx = wait_result - WAIT_OBJECT_0; 78 | 79 | if (signaled_idx == m_processes.size()) { 80 | logger.debug() << "ProcessPool::ThreadMain(): Exit signaled"; 81 | break; 82 | } 83 | 84 | // If one of the handles was signaled, remove the associated process from our 85 | // list (as it has terminated). We'll attempt to start one in its place when 86 | // we loop around. 87 | if (signaled_idx >= 0 && signaled_idx < m_processes.size()) { 88 | logger.debug() << "ProcessPool::ThreadMain(): Process terminated"; 89 | m_processes.erase(std::begin(m_processes) + signaled_idx); 90 | } 91 | 92 | // Regardless, go round the loop again. If it returned because of error, 93 | // lot it but continue. 
94 | if (wait_result == WAIT_FAILED) { 95 | logger.debug() << "::WaitForMultipleObjects()=" << ::GetLastError(); 96 | } 97 | } 98 | 99 | logger.debug() << "ProcessPool::ThreadMain(): Exit"; 100 | } 101 | 102 | void ProcessPool::CreateProcess() // Starts one suspended process, assigns it to the job object and resumes it; on success it is added to m_processes. 103 | { 104 | logger.debug() << "Creating process with command line: " << m_command_line; 105 | 106 | // We must make a mutable copy of the (wide) command line, as CreateProcessW() can modify it. 107 | auto cmd_buffer = std::vector(m_command_line.length() + 1, '\0'); 108 | std::copy(m_command_line.begin(), m_command_line.end(), cmd_buffer.begin()); 109 | 110 | // We create the process with its main thread suspended, so that we can assign 111 | // it to our job object before it runs. 112 | auto startup_info = STARTUPINFO{ sizeof(STARTUPINFO) }; // cb must hold the structure size; all other fields zero 113 | auto proc_info = PROCESS_INFORMATION{ 0 }; 114 | auto created = ::CreateProcess( 115 | nullptr, cmd_buffer.data(), 116 | nullptr, nullptr, FALSE, CREATE_SUSPENDED, nullptr, nullptr, 117 | &startup_info, &proc_info 118 | ); 119 | if (!created) { 120 | logger.debug() << "Could not create process: " << ::GetLastError(); 121 | return; 122 | } 123 | auto process = ScopedHandle{ proc_info.hProcess }; 124 | auto thread = ScopedHandle{ proc_info.hThread }; 125 | 126 | // Assign the process to the job and resume it. If we fail to do either, 127 | // then terminate the process. 128 |
128 | auto assigned = ::AssignProcessToJobObject(m_job.GetHandle(), process.Get()); 129 | auto resumed = ::ResumeThread(thread.Get()) != -1; 130 | if (!assigned || !resumed) { 131 | logger.debug() << "AssignProcessToJobObject=" << assigned << " ResumeThread=" 132 | << resumed << " GetLastError()=" << GetLastError(); 133 | ::TerminateProcess(process.Get(), 0); 134 | return; 135 | } 136 | 137 | m_processes.push_back(std::move(process)); 138 | } 139 | 140 | 141 | std::wstring ProcessPool::EscapeArgument(const std::wstring& argument) 142 | { 143 | // See https://blogs.msdn.microsoft.com/twistylittlepassagesallalike/2011/04/23/everyone-quotes-command-line-arguments-the-wrong-way/ 144 | // Note: we do not escape in such a way that it can be passed to cmd.exe. 145 | 146 | // Only surround with quotes (and escape contents) if there is whitespace 147 | // or quotes in the argument. 148 | if (argument.find_first_of(L" \t\n\v\"") == argument.npos) { 149 | return argument; 150 | } 151 | 152 | auto escaped = std::wostringstream{ L"\"", std::ios_base::ate }; 153 | auto escaped_it = static_cast>(escaped); 154 | auto num_backslashes = 0; 155 | 156 | for (auto& character : argument) { 157 | switch (character) { 158 | case L'\\': 159 | ++num_backslashes; 160 | continue; 161 | case L'"': 162 | // Escape all of the backslashes, plus add one more escape for 163 | // this quotation mark. 164 | std::fill_n(escaped_it, num_backslashes * 2 + 1, L'\\'); 165 | escaped << L'"'; 166 | break; 167 | default: 168 | // We don't need to escape the backslashes if they're not followed 169 | // by a quote. 170 | std::fill_n(escaped_it, num_backslashes, L'\\'); 171 | escaped << character; 172 | break; 173 | } 174 | num_backslashes = 0; 175 | } 176 | // As we're going to append a final " to the argument, we need to escape 177 | // all the remaining backslashes. 
178 | std::fill_n(escaped_it, num_backslashes * 2, L'\\'); 179 | // We don't escape our final quote: 180 | escaped << L'\"'; 181 | 182 | return escaped.str(); 183 | } 184 | -------------------------------------------------------------------------------- /IntegrationTests/test_asgi_http.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | from __future__ import unicode_literals, print_function 5 | 6 | import os 7 | import sys 8 | import time 9 | import urllib 10 | 11 | import pytest 12 | 13 | 14 | @pytest.mark.parametrize('method', ['GET', 'GeT', 'POST', 'PUT']) 15 | def test_asgi_http_request_method(method, site, asgi, session): # the method should be reported upper-cased, whatever casing was sent 16 | session.request(method, site.url) 17 | asgi_request = asgi.receive_request() 18 | assert asgi_request['method'] == method.upper() 19 | 20 | 21 | def test_asgi_http_request_scheme_http(site, asgi, session): 22 | session.get(site.url) 23 | asgi_request = asgi.receive_request() 24 | assert asgi_request['scheme'] == 'http' 25 | 26 | 27 | @pytest.mark.skip(reason='Need to provide HTTPS certificate when setting up IIS site.') 28 | def test_asgi_http_request_scheme_https(site, asgi, session): 29 | session.get(site.https_url, verify=False) 30 | asgi_request = asgi.receive_request() 31 | assert asgi_request['scheme'] == 'https' 32 | 33 | 34 | @pytest.mark.parametrize('path', [ 35 | '', 36 | '/onedir/', 37 | '/missing-trailing', 38 | '/multi/level/with/file.txt' 39 | ]) 40 | def test_asgi_http_request_path(path, site, asgi, session): 41 | session.get(site.url + path) 42 | asgi_request = asgi.receive_request() 43 | # IIS normalizes an empty path to '/' 44 | path = path if path else '/' 45 | assert asgi_request['path'] == path 46 | 47 | 48 | @pytest.mark.parametrize('qs_parts', [ 49 | None, # means 'do not append ?
to url' 50 | [], 51 | [('a', 'b')], 52 | [('a', 'b'), ('c', 'd')], 53 | [('a', 'b c')], 54 | ]) 55 | def test_asgi_http_request_querystring(qs_parts, site, asgi, session): # the query string is expected to include its leading '?', except for a bare '?' 56 | qs = '' 57 | if qs_parts is not None: 58 | qs = '?' + urllib.urlencode(qs_parts) 59 | session.get(site.url + qs) 60 | asgi_request = asgi.receive_request() 61 | # IIS gives a query string of '' when provided with '?' 62 | expected_qs = '' if qs == '?' else qs 63 | assert asgi_request['query_string'] == expected_qs.encode('utf-8') 64 | 65 | 66 | @pytest.mark.parametrize('body', [ 67 | '', 68 | '☃', 69 | 'this is a body', 70 | ]) 71 | def test_asgi_http_request_body(body, site, asgi, session): # the request body should round-trip as UTF-8 bytes 72 | session.post(site.url, data=body.encode('utf-8')) 73 | asgi_request = asgi.receive_request() 74 | assert asgi_request['body'].decode('utf-8') == body 75 | 76 | 77 | # NOTE: This test checks that requests with large bodies are handled. However, 78 | # they *should* be split into multiple ASGI messages, but they currently are not.
79 | @pytest.mark.parametrize('body_size', [ 80 | 1024, # 1 KB 81 | 1024 * 1024, # 1 MB 82 | 1024 * 1024 * 15, # 15 MB 83 | ]) 84 | def test_asgi_http_request_large_bodies(body_size, site, asgi, session): 85 | body = os.urandom(body_size // 2).encode('hex') 86 | if body_size % 2: 87 | body += 'e' 88 | session.post(site.url, data=body) 89 | asgi_request = asgi.receive_request() 90 | assert asgi_request['body'].decode('utf-8') == body 91 | 92 | 93 | @pytest.mark.parametrize('headers', [ 94 | {'User-Agent': 'Custom User-Agent'}, # 'Known' header 95 | {'X-Custom-Header': 'Value'}, 96 | ]) 97 | def test_asgi_http_request_headers(headers, site, asgi, session): 98 | session.get(site.url, headers=headers) 99 | asgi_request = asgi.receive_request() 100 | for name, value in headers.items(): 101 | encoded_header = [name.encode('utf-8'), value.encode('utf-8')] 102 | assert encoded_header in asgi_request['headers'] 103 | 104 | 105 | @pytest.mark.parametrize('status', [200, 404, 500]) 106 | def test_asgi_http_response_status(status, site, asgi, session): 107 | future = session.get(site.url) 108 | asgi_request = asgi.receive_request() 109 | asgi_response = dict(status=200, headers=[]) 110 | asgi.send(asgi_request['reply_channel'], asgi_response) 111 | resp = future.result(timeout=2) 112 | assert resp.status_code == 200 113 | 114 | 115 | @pytest.mark.parametrize('headers', [ 116 | [('X-Custom-Header', 'Value')], 117 | [('Server', 'Not IIS')], # IIS sets this. We should be able to override it. 118 | ]) 119 | def test_asgi_http_response_status(headers, site, asgi, session): 120 | future = session.get(site.url) 121 | asgi_request = asgi.receive_request() 122 | asgi_response = dict(status=200, headers=headers) 123 | asgi.send(asgi_request['reply_channel'], asgi_response) 124 | resp = future.result(timeout=2) 125 | assert resp.status_code == 200 126 | # IIS will (helpfully!) return other headers too. Assert 127 | # ours are in there with the correct value. 
128 | for name, value in headers: 129 | assert name in resp.headers 130 | assert resp.headers[name] == value 131 | 132 | 133 | @pytest.mark.parametrize('body', [ 134 | '', 135 | 'a noddy body', 136 | ]) 137 | def test_asgi_http_response_body(body, site, asgi, session): 138 | future = session.get(site.url) 139 | asgi_request = asgi.receive_request() 140 | asgi_response = dict(status=200, headers=[], content=body) 141 | asgi.send(asgi_request['reply_channel'], asgi_response) 142 | resp = future.result(timeout=2) 143 | assert resp.status_code == 200 144 | assert resp.text == body 145 | assert resp.headers['Content-Length'] == str(len(body)) 146 | 147 | 148 | # We can't do more than 10 connections on non-server editions of Windows, 149 | # as Windows limits the number of concurrent connections. 150 | # It might be worth a separate test to ensure that extra requests are queued. 151 | @pytest.mark.parametrize('number', [2, 4, 10]) 152 | def test_asgi_multiple_concurrent_http_requests(number, site, asgi, session): 153 | futures = [] 154 | for i in range(number): 155 | futures.append(session.get(site.url, data=b'req%i' % i)) 156 | # Collect all requests after they've been issued. (We don't 157 | # necessarily collect them in the order they were issued). 158 | asgi_requests = [asgi.receive_request() for _ in range(number)] 159 | # Echo the request's body back to it. 160 | for i, asgi_request in enumerate(asgi_requests): 161 | asgi_response = dict(status=200, headers=[], content=asgi_request['body']) 162 | asgi.send(asgi_request['reply_channel'], asgi_response) 163 | # Check each request got the appropriate response. 
164 | for i, future in enumerate(futures): 165 | result = future.result(timeout=2) 166 | assert result.status_code == 200 167 | assert result.text == b'req%i' % i 168 | 169 | 170 | def test_asgi_streaming_response(site, asgi, session): 171 | future = session.get(site.url, stream=True) 172 | asgi_request = asgi.receive_request() 173 | channel = asgi_request['reply_channel'] 174 | # Send first response with a status code 175 | asgi.send(channel, dict(status=200, more_content=True)) 176 | 177 | # Now send a few chunks. 178 | chunks = [b'chunk%i' % i for i in range(5)] 179 | for chunk in chunks: 180 | asgi.send(channel, dict(content=chunk, more_content=True)) 181 | 182 | # Check we've received the chunks. 183 | resp = future.result(timeout=2) 184 | resp_chunks = resp.iter_content(chunk_size=len(chunks[0])) 185 | for chunk in chunks: 186 | assert resp_chunks.next() == chunk 187 | 188 | # Send another, to check we still can. 189 | final_chunk = b'chunk%i' % len(chunks) 190 | asgi.send(channel, dict(content=final_chunk, more_content=False)) 191 | assert resp_chunks.next() == final_chunk 192 | # ... and fin. Check the connection is closed. 193 | with pytest.raises(StopIteration): 194 | resp_chunks.next() 195 | 196 | 197 | if __name__ == '__main__': 198 | # Should really sys.exit() this, but it causes Visual Studio 199 | # to eat the output. 
:( 200 | pytest.main(['--ignore', 'env1/', '-x']) 201 | -------------------------------------------------------------------------------- /AsgiHandlerTest/mock_httpserv.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define WIN32_LEAN_AND_MEAN 4 | #include 5 | 6 | #include "gmock/gmock.h" 7 | 8 | 9 | class MockIHttpRequest : public IHttpRequest { 10 | public: 11 | MOCK_METHOD0(GetRawHttpRequest, HTTP_REQUEST*()); 12 | MOCK_CONST_METHOD0(GetRawHttpRequest, const HTTP_REQUEST*()); 13 | MOCK_CONST_METHOD2(GetHeader, PCSTR(PCSTR, USHORT *)); 14 | MOCK_CONST_METHOD2(GetHeader, PCSTR(HTTP_HEADER_ID, USHORT *)); 15 | MOCK_METHOD4(SetHeader, HRESULT(PCSTR pszHeaderName, PCSTR pszHeaderValue, USHORT cchHeaderValue, BOOL fReplace)); 16 | MOCK_METHOD4(SetHeader, HRESULT(HTTP_HEADER_ID ulHeaderIndex, PCSTR pszHeaderValue, USHORT cchHeaderValue, BOOL fReplace)); 17 | MOCK_METHOD1(DeleteHeader, HRESULT(PCSTR pszHeaderName)); 18 | MOCK_METHOD1(DeleteHeader, HRESULT(HTTP_HEADER_ID ulHeaderIndex)); 19 | MOCK_CONST_METHOD0(GetHttpMethod, PCSTR()); 20 | MOCK_METHOD1(SetHttpMethod, HRESULT(PCSTR pszHttpMethod)); 21 | MOCK_METHOD3(SetUrl, HRESULT(PCWSTR pszUrl, DWORD cchUrl, BOOL fResetQueryString)); 22 | MOCK_METHOD3(SetUrl, HRESULT(PCSTR pszUrl, DWORD cchUrl, BOOL fResetQueryString)); 23 | MOCK_CONST_METHOD0(GetUrlChanged, BOOL()); 24 | MOCK_CONST_METHOD0(GetForwardedUrl, PCWSTR()); 25 | MOCK_CONST_METHOD0(GetLocalAddress, PSOCKADDR()); 26 | MOCK_CONST_METHOD0(GetRemoteAddress, PSOCKADDR()); 27 | MOCK_METHOD5(ReadEntityBody, HRESULT(VOID *, DWORD, BOOL, DWORD *, BOOL *)); 28 | MOCK_METHOD2(InsertEntityBody, HRESULT(VOID * pvBuffer, DWORD cbBuffer)); 29 | MOCK_METHOD0(GetRemainingEntityBytes, DWORD()); 30 | MOCK_CONST_METHOD2(GetHttpVersion, VOID(USHORT * pMajorVersion, USHORT * pMinorVersion)); 31 | MOCK_METHOD2(GetClientCertificate, HRESULT(HTTP_SSL_CLIENT_CERT_INFO ** ppClientCertInfo, BOOL * pfClientCertNegotiated)); 32 
| MOCK_METHOD2(NegotiateClientCertificate, HRESULT(BOOL, BOOL *)); 33 | MOCK_CONST_METHOD0(GetSiteId, DWORD()); 34 | MOCK_METHOD9(GetHeaderChanges, HRESULT(DWORD dwOldChangeNumber, DWORD * pdwNewChangeNumber, PCSTR knownHeaderSnapshot[HttpHeaderRequestMaximum], DWORD * pdwUnknownHeaderSnapshot, PCSTR ** ppUnknownHeaderNameSnapshot, PCSTR ** ppUnknownHeaderValueSnapshot, DWORD diffedKnownHeaderIndices[HttpHeaderRequestMaximum + 1], DWORD * pdwDiffedUnknownHeaders, DWORD ** ppDiffedUnknownHeaderIndices)); 35 | }; 36 | 37 | 38 | class MockIHttpResponse : public IHttpResponse { 39 | public: 40 | MOCK_METHOD0(GetRawHttpResponse, HTTP_RESPONSE*()); 41 | MOCK_CONST_METHOD0(GetRawHttpResponse, const HTTP_RESPONSE*()); 42 | MOCK_METHOD0(GetCachePolicy, IHttpCachePolicy*()); 43 | MOCK_METHOD6(SetStatus, HRESULT(USHORT, PCSTR, USHORT, HRESULT, IAppHostConfigException *, BOOL)); 44 | MOCK_METHOD4(SetHeader, HRESULT(PCSTR pszHeaderName, PCSTR pszHeaderValue, USHORT cchHeaderValue, BOOL fReplace)); 45 | MOCK_METHOD4(SetHeader, HRESULT(HTTP_HEADER_ID ulHeaderIndex, PCSTR pszHeaderValue, USHORT cchHeaderValue, BOOL fReplace)); 46 | MOCK_METHOD1(DeleteHeader, HRESULT(PCSTR pszHeaderName)); 47 | MOCK_METHOD1(DeleteHeader, HRESULT(HTTP_HEADER_ID ulHeaderIndex)); 48 | MOCK_CONST_METHOD2(GetHeader, PCSTR(PCSTR, USHORT *)); 49 | MOCK_CONST_METHOD2(GetHeader, PCSTR(HTTP_HEADER_ID, USHORT *)); 50 | MOCK_METHOD0(Clear, VOID()); 51 | MOCK_METHOD0(ClearHeaders, VOID()); 52 | MOCK_METHOD0(SetNeedDisconnect, VOID()); 53 | MOCK_METHOD0(ResetConnection, VOID()); 54 | MOCK_METHOD1(DisableKernelCache, VOID(ULONG)); 55 | MOCK_CONST_METHOD0(GetKernelCacheEnabled, BOOL()); 56 | MOCK_METHOD0(SuppressHeaders, VOID()); 57 | MOCK_CONST_METHOD0(GetHeadersSuppressed, BOOL()); 58 | MOCK_METHOD4(Flush, HRESULT(BOOL, BOOL, DWORD *, BOOL *)); 59 | MOCK_METHOD3(Redirect, HRESULT(PCSTR, BOOL, BOOL)); 60 | MOCK_METHOD2(WriteEntityChunkByReference, HRESULT(HTTP_DATA_CHUNK *, LONG)); 61 | 
MOCK_METHOD6(WriteEntityChunks, HRESULT(HTTP_DATA_CHUNK *, DWORD, BOOL, BOOL, DWORD *, BOOL *)); 62 | MOCK_METHOD0(DisableBuffering, VOID()); 63 | MOCK_METHOD9(GetStatus, VOID(USHORT *, USHORT *, PCSTR *, USHORT *, HRESULT *, PCWSTR *, DWORD *, IAppHostConfigException**, BOOL *)); 64 | MOCK_METHOD3(SetErrorDescription, HRESULT(PCWSTR, DWORD, BOOL)); 65 | MOCK_METHOD1(GetErrorDescription, PCWSTR(DWORD *)); 66 | MOCK_METHOD9(GetHeaderChanges, HRESULT(DWORD dwOldChangeNumber, DWORD * pdwNewChangeNumber, PCSTR knownHeaderSnapshot[HttpHeaderResponseMaximum], DWORD * pdwUnknownHeaderSnapshot, PCSTR ** ppUnknownHeaderNameSnapshot, PCSTR ** ppUnknownHeaderValueSnapshot, DWORD diffedKnownHeaderIndices[HttpHeaderResponseMaximum + 1], DWORD * pdwDiffedUnknownHeaders, DWORD ** ppDiffedUnknownHeaderIndices)); 67 | MOCK_METHOD0(CloseConnection, VOID()); 68 | }; 69 | 70 | 71 | 72 | class MockIHttpContext : public IHttpContext { 73 | public: 74 | MOCK_METHOD0(GetSite, IHttpSite*()); 75 | MOCK_METHOD0(GetApplication, IHttpApplication*()); 76 | MOCK_METHOD0(GetConnection, IHttpConnection*()); 77 | MOCK_METHOD0(GetRequest, IHttpRequest*()); 78 | MOCK_METHOD0(GetResponse, IHttpResponse*()); 79 | MOCK_CONST_METHOD0(GetResponseHeadersSent, BOOL()); 80 | MOCK_CONST_METHOD0(GetUser, IHttpUser*()); 81 | MOCK_METHOD0(GetModuleContextContainer, IHttpModuleContextContainer*()); 82 | MOCK_METHOD1(IndicateCompletion, VOID(REQUEST_NOTIFICATION_STATUS notificationStatus)); 83 | MOCK_METHOD1(PostCompletion, HRESULT(DWORD cbBytes)); 84 | MOCK_METHOD2(DisableNotifications, VOID(DWORD dwNotifications, DWORD dwPostNotifications)); 85 | MOCK_METHOD5(GetNextNotification, BOOL(REQUEST_NOTIFICATION_STATUS status, DWORD * pdwNotification, BOOL * pfIsPostNotification, CHttpModule ** ppModuleInfo, IHttpEventProvider ** ppRequestOutput)); 86 | 87 | MOCK_METHOD1(GetIsLastNotification, BOOL(REQUEST_NOTIFICATION_STATUS status)); 88 | MOCK_METHOD5(ExecuteRequest, HRESULT(BOOL fAsync, IHttpContext * pHttpContext, 
DWORD dwExecuteFlags, IHttpUser * pHttpUser, BOOL * pfCompletionExpected)); 89 | MOCK_CONST_METHOD0(GetExecuteFlags, DWORD()); 90 | MOCK_METHOD3(GetServerVariable, HRESULT(PCSTR pszVariableName, PCWSTR * ppszValue, DWORD * pcchValueLength)); 91 | MOCK_METHOD3(GetServerVariable, HRESULT(PCSTR pszVariableName, PCSTR * ppszValue, DWORD * pcchValueLength)); 92 | MOCK_METHOD2(SetServerVariable, HRESULT(PCSTR pszVariableName, PCWSTR pszVariableValue)); 93 | MOCK_METHOD1(AllocateRequestMemory, VOID*(DWORD cbAllocation)); 94 | MOCK_METHOD0(GetUrlInfo, IHttpUrlInfo*()); 95 | MOCK_METHOD0(GetMetadata, IMetadataInfo*()); 96 | MOCK_METHOD1(GetPhysicalPath, PCWSTR(DWORD *)); 97 | MOCK_CONST_METHOD1(GetScriptName, PCWSTR(DWORD *)); 98 | MOCK_METHOD1(GetScriptTranslated, PCWSTR(DWORD *)); 99 | MOCK_CONST_METHOD0(GetScriptMap, IScriptMapInfo*()); 100 | MOCK_METHOD0(SetRequestHandled, VOID()); 101 | MOCK_CONST_METHOD0(GetFileInfo, IHttpFileInfo*()); 102 | MOCK_METHOD3(MapPath, HRESULT(PCWSTR pszUrl, PWSTR pszPhysicalPath, DWORD * pcbPhysicalPath)); 103 | MOCK_METHOD2(NotifyCustomNotification, HRESULT(ICustomNotificationProvider * pCustomOutput, BOOL * pfCompletionExpected)); 104 | MOCK_CONST_METHOD0(GetParentContext, IHttpContext*()); 105 | MOCK_CONST_METHOD0(GetRootContext, IHttpContext*()); 106 | MOCK_METHOD2(CloneContext, HRESULT(DWORD dwCloneFlags, IHttpContext ** ppHttpContext)); 107 | MOCK_METHOD0(ReleaseClonedContext, HRESULT()); 108 | MOCK_CONST_METHOD6(GetCurrentExecutionStats, HRESULT(DWORD *, DWORD *, PCWSTR *, DWORD *, DWORD *, DWORD *)); 109 | MOCK_CONST_METHOD0(GetTraceContext, IHttpTraceContext*()); 110 | MOCK_METHOD7(GetServerVarChanges, HRESULT(DWORD dwOldChangeNumber, DWORD * pdwNewChangeNumber, DWORD * pdwVariableSnapshot, PCSTR ** ppVariableNameSnapshot, PCWSTR ** ppVariableValueSnapshot, DWORD * pdwDiffedVariables, DWORD ** ppDiffedVariableIndices)); 111 | MOCK_METHOD0(CancelIo, HRESULT()); 112 | MOCK_METHOD6(MapHandler, HRESULT(DWORD, PCWSTR, PCWSTR, PCSTR, 
IScriptMapInfo**, BOOL)); 113 | 114 | #pragma warning(disable : 4996) 115 | MOCK_METHOD2(GetExtendedInterface, HRESULT(HTTP_CONTEXT_INTERFACE_VERSION version, PVOID * ppInterface)); 116 | }; 117 | -------------------------------------------------------------------------------- /cli/cli.vcxproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Debug 6 | Win32 7 | 8 | 9 | Release 10 | Win32 11 | 12 | 13 | Debug 14 | x64 15 | 16 | 17 | Release 18 | x64 19 | 20 | 21 | 22 | {C414CB0F-CECA-464F-9597-AFFE67045AC6} 23 | Win32Proj 24 | cli 25 | 8.1 26 | 27 | 28 | 29 | Application 30 | true 31 | v140 32 | Unicode 33 | 34 | 35 | Application 36 | false 37 | v140 38 | true 39 | Unicode 40 | 41 | 42 | Application 43 | true 44 | v140 45 | Unicode 46 | 47 | 48 | Application 49 | false 50 | v140 51 | true 52 | Unicode 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | true 74 | 75 | 76 | true 77 | 78 | 79 | false 80 | 81 | 82 | false 83 | 84 | 85 | 86 | 87 | 88 | Level3 89 | Disabled 90 | WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) 91 | true 92 | 93 | 94 | Console 95 | true 96 | 97 | 98 | 99 | 100 | 101 | 102 | Level3 103 | Disabled 104 | _DEBUG;_CONSOLE;%(PreprocessorDefinitions) 105 | true 106 | ..\deps\hiredis;..\deps\msgpack-c\include;..\RedisAsgiHandlerLib;%(AdditionalIncludeDirectories) 107 | MultiThreadedDebug 108 | 109 | 110 | Console 111 | true 112 | %(AdditionalDependencies) 113 | 114 | 115 | 116 | 117 | Level3 118 | 119 | 120 | MaxSpeed 121 | true 122 | true 123 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 124 | true 125 | 126 | 127 | Console 128 | true 129 | true 130 | true 131 | 132 | 133 | 134 | 135 | Level3 136 | 137 | 138 | MaxSpeed 139 | true 140 | true 141 | NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 142 | true 143 | 144 | 145 | Console 146 | true 147 | true 148 | true 149 | 150 | 151 | 152 | 153 | {077b1dce-b453-327f-bf3f-deab170280ed} 154 | 155 | 
156 | {66223674-ce18-4249-b02d-80377e46f114} 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | -------------------------------------------------------------------------------- /AsgiHandlerLib/HttpRequestHandlerSteps.cpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #define WIN32_LEAN_AND_MEAN 7 | #include 8 | #include 9 | 10 | #include "msgpack.hpp" 11 | 12 | #include "HttpRequestHandlerSteps.h" 13 | #include "HttpRequestHandler.h" 14 | 15 | 16 | StepResult ReadBodyStep::Enter() 17 | { 18 | IHttpRequest* request = m_http_context->GetRequest(); 19 | 20 | DWORD remaining_bytes = request->GetRemainingEntityBytes(); 21 | if (remaining_bytes == 0) { 22 | return kStepGotoNext; 23 | } 24 | 25 | BOOL completion_expected = FALSE; 26 | while (!completion_expected) { 27 | DWORD bytes_read; 28 | HRESULT hr = request->ReadEntityBody( 29 | m_asgi_request_msg->body.data() + m_body_bytes_read, remaining_bytes, 30 | true, &bytes_read, &completion_expected 31 | ); 32 | if (FAILED(hr)) { 33 | logger.debug() << "ReadEntityBody() = " << hr; 34 | // TODO: Call an Error() or something. 35 | return kStepFinishRequest; 36 | } 37 | 38 | if (!completion_expected) { 39 | // Operation completed synchronously. 40 | auto result = OnAsyncCompletion(S_OK, bytes_read); 41 | // If we need to read more, we might as well do that here, rather than 42 | // yielding back to the request loop. 43 | if (result != kStepRerun) { 44 | return result; 45 | } 46 | } 47 | 48 | remaining_bytes = request->GetRemainingEntityBytes(); 49 | } 50 | 51 | return kStepAsyncPending; 52 | } 53 | 54 | StepResult ReadBodyStep::OnAsyncCompletion(HRESULT hr, DWORD num_bytes) 55 | { 56 | if (FAILED(hr)) { 57 | // TODO: Call an Error() or something. 
58 | return kStepFinishRequest; 59 | } 60 | 61 | IHttpRequest* request = m_http_context->GetRequest(); 62 | m_body_bytes_read += num_bytes; 63 | if (request->GetRemainingEntityBytes()) { 64 | // There's more to do - ask to be re-run. 65 | return kStepRerun; 66 | } 67 | 68 | return kStepGotoNext; 69 | } 70 | 71 | std::unique_ptr ReadBodyStep::GetNextStep() 72 | { 73 | return std::make_unique( 74 | m_handler, std::move(m_asgi_request_msg) 75 | ); 76 | } 77 | 78 | 79 | StepResult SendToApplicationStep::Enter() 80 | { 81 | // TODO: Split into chunked messages. 82 | auto task = concurrency::create_task([this]() { 83 | msgpack::sbuffer buffer; 84 | msgpack::pack(buffer, *m_asgi_request_msg); 85 | m_channels.Send("http.request", buffer); 86 | }).then([this]() { 87 | logger.debug() << "SendToApplicationStep calling PostCompletion()"; 88 | // The tests rely on this being the last thing that the callback does. 89 | m_http_context->PostCompletion(0); 90 | }); 91 | 92 | return kStepAsyncPending; 93 | } 94 | 95 | StepResult SendToApplicationStep::OnAsyncCompletion(HRESULT hr, DWORD num_bytes) 96 | { 97 | // TODO: Consider whether the channel layer can ever give us an error, and whether 98 | // we should handle it here. (ChannelFull?) 99 | return kStepGotoNext; 100 | } 101 | 102 | std::unique_ptr SendToApplicationStep::GetNextStep() 103 | { 104 | return std::make_unique( 105 | m_handler, m_asgi_request_msg->reply_channel 106 | ); 107 | } 108 | 109 | 110 | StepResult WaitForResponseStep::Enter() 111 | { 112 | m_response_pump.AddChannel(m_reply_channel, [this](std::string data) { 113 | // This is all sorts of wrong. 
114 | m_asgi_response_msg = std::make_unique( 115 | msgpack::unpack(data.data(), data.length()).get().as() 116 | ); 117 | m_http_context->PostCompletion(0); 118 | 119 | logger.debug() << "MessagePump gave us a message; PostCompletion() called"; 120 | }); 121 | 122 | return kStepAsyncPending; 123 | } 124 | 125 | StepResult WaitForResponseStep::OnAsyncCompletion(HRESULT hr, DWORD num_bytes) 126 | { 127 | IHttpResponse *response = m_http_context->GetResponse(); 128 | 129 | // Only set the .status / .headers if the ASGI response contained them. We 130 | // don't distinguish between 'Response' and 'Response Chunk' in the ASGI spec. 131 | // We assume a streamed response sets the status and headers in the first message 132 | // but we do not enforce it. 133 | if (m_asgi_response_msg->status > 0) { 134 | response->SetStatus(m_asgi_response_msg->status, ""); 135 | } 136 | for (auto header : m_asgi_response_msg->headers) { 137 | std::string header_name = std::get<0>(header); 138 | std::string header_value = std::get<1>(header); 139 | response->SetHeader(header_name.c_str(), header_value.c_str(), header_value.length(), true); 140 | } 141 | 142 | if (m_asgi_response_msg->content.length() == 0 && !m_asgi_response_msg->more_content) { 143 | return kStepFinishRequest; 144 | } 145 | 146 | return kStepGotoNext; 147 | } 148 | 149 | std::unique_ptr WaitForResponseStep::GetNextStep() 150 | { 151 | // If the first chunk in a streaming response without data, then we want 152 | // to flush the headers to the client before we wait for more data from 153 | // the application. 154 | if (m_asgi_response_msg->content.length() == 0 && m_asgi_response_msg->more_content) { 155 | return std::make_unique(m_handler, m_reply_channel); 156 | } 157 | 158 | // We tell WriteResponseStep the reply_channel, in case this is a 159 | // streaming response and it needs to go back to waiting for a 160 | // response. 
161 | return std::make_unique( 162 | m_handler, std::move(m_asgi_response_msg), m_reply_channel 163 | ); 164 | } 165 | 166 | 167 | StepResult WriteResponseStep::Enter() 168 | { 169 | IHttpResponse* response = m_http_context->GetResponse(); 170 | 171 | BOOL completion_expected = FALSE; 172 | while (!completion_expected) { 173 | m_resp_chunk.DataChunkType = HttpDataChunkFromMemory; 174 | m_resp_chunk.FromMemory.pBuffer = (PVOID)(m_asgi_response_msg->content.c_str() + m_resp_bytes_written); 175 | m_resp_chunk.FromMemory.BufferLength = m_asgi_response_msg->content.length() - m_resp_bytes_written; 176 | 177 | DWORD bytes_written; 178 | HRESULT hr = response->WriteEntityChunks( 179 | &m_resp_chunk, 1, 180 | true, false, &bytes_written, &completion_expected 181 | ); 182 | if (FAILED(hr)) { 183 | logger.debug() << "WriteEntityChunks() returned hr=" << hr; 184 | // TODO: Call some kind of Error(); 185 | return kStepFinishRequest; 186 | } 187 | 188 | if (!completion_expected) { 189 | // Operation completed synchronously. 190 | auto result = OnAsyncCompletion(S_OK, bytes_written); 191 | // If we need to write more, we might as well do that here, rather than 192 | // yielding back to the request loop. 193 | if (result != kStepRerun) { 194 | return result; 195 | } 196 | } 197 | } 198 | 199 | return kStepAsyncPending; 200 | } 201 | 202 | StepResult WriteResponseStep::OnAsyncCompletion(HRESULT hr, DWORD num_bytes) 203 | { 204 | // num_bytes is always 0, whether the operation completed sync or async. 205 | // This is because IIS has our data buffered. We are safe to destroy 206 | // m_asgi_response_msg. 207 | 208 | // If there's more data to send, explicitly flush the data (as the client 209 | // may be waiting for the initial chunk of the response) before waiting 210 | // for the application to send us more data. 211 | if (m_asgi_response_msg->more_content) { 212 | return kStepGotoNext; 213 | } 214 | 215 | // If this was the final chunk, finish the request. 
There's no need to 216 | // explicitly flush the request. IIS will do that for us. 217 | return kStepFinishRequest; 218 | } 219 | 220 | std::unique_ptr WriteResponseStep::GetNextStep() 221 | { 222 | if (m_asgi_response_msg->more_content) { 223 | // We must flush this response data before going to wait for more. 224 | return std::make_unique(m_handler, m_reply_channel); 225 | } 226 | 227 | // We should never reach here. 228 | throw std::runtime_error("WriteResponseStep does not have a next step."); 229 | } 230 | 231 | StepResult FlushResponseStep::Enter() 232 | { 233 | IHttpResponse* response = m_http_context->GetResponse(); 234 | 235 | BOOL completion_expected = FALSE; 236 | while (!completion_expected) { 237 | DWORD bytes_flushed; 238 | HRESULT hr = response->Flush( 239 | true, true, &bytes_flushed, &completion_expected 240 | ); 241 | if (FAILED(hr)) { 242 | logger.debug() << "Flush() returned hr=" << hr; 243 | // TODO: Call some kind of Error(); 244 | return kStepFinishRequest; 245 | } 246 | 247 | if (!completion_expected) { 248 | // Operation completed synchronously. 249 | auto result = OnAsyncCompletion(S_OK, bytes_flushed); 250 | // If we need to flush more, we might as well do that here, rather than 251 | // yielding back to the request loop. 252 | if (result != kStepRerun) { 253 | return result; 254 | } 255 | } 256 | } 257 | 258 | return kStepAsyncPending; 259 | } 260 | 261 | StepResult FlushResponseStep::OnAsyncCompletion(HRESULT hr, DWORD num_bytes) 262 | { 263 | // We don't actually know how big the request is (headers and body), so 264 | // there's nothing we can usefully do with num_bytes. 265 | return kStepGotoNext; 266 | } 267 | 268 | std::unique_ptr FlushResponseStep::GetNextStep() 269 | { 270 | // Go back to waiting for the application to send us more data. This 271 | // step is only called if more_content=True in the ASGI response. 
272 | return std::make_unique(m_handler, m_reply_channel); 273 | } 274 | -------------------------------------------------------------------------------- /AsgiHandlerLib/RedisAsgiHandlerLib.vcxproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Debug 6 | Win32 7 | 8 | 9 | Release 10 | Win32 11 | 12 | 13 | Debug 14 | x64 15 | 16 | 17 | Release 18 | x64 19 | 20 | 21 | 22 | {66223674-CE18-4249-B02D-80377E46F114} 23 | Win32Proj 24 | RedisAsgiHandlerLib 25 | 8.1 26 | RedisAsgiHandlerLib 27 | 28 | 29 | 30 | DynamicLibrary 31 | true 32 | v140 33 | Unicode 34 | 35 | 36 | DynamicLibrary 37 | false 38 | v140 39 | true 40 | Unicode 41 | 42 | 43 | StaticLibrary 44 | true 45 | v140 46 | Unicode 47 | 48 | 49 | DynamicLibrary 50 | false 51 | v140 52 | true 53 | Unicode 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | true 75 | 76 | 77 | true 78 | 79 | 80 | false 81 | 82 | 83 | false 84 | 85 | 86 | 87 | 88 | 89 | Level3 90 | Disabled 91 | WIN32;_DEBUG;_WINDOWS;_USRDLL;REDISASGIHANDLERLIB_EXPORTS;%(PreprocessorDefinitions) 92 | true 93 | 94 | 95 | Windows 96 | true 97 | 98 | 99 | 100 | 101 | 102 | 103 | Level3 104 | Disabled 105 | _DEBUG;_WINDOWS;_USRDLL;RedisAsgiHandlerLib_EXPORTS;%(PreprocessorDefinitions) 106 | true 107 | ..\deps\hiredis\msvs\deps\;..\deps\hiredis\;..\deps\msgpack-c\include;%(AdditionalIncludeDirectories) 108 | MultiThreadedDebug 109 | 110 | 111 | Windows 112 | true 113 | RedisAsgiHandlerLib.def 114 | rpcrt4.lib;%(AdditionalDependencies) 115 | 116 | 117 | rpcrt4.lib 118 | 119 | 120 | 121 | 122 | Level3 123 | 124 | 125 | MaxSpeed 126 | true 127 | true 128 | WIN32;NDEBUG;_WINDOWS;_USRDLL;RedisAsgiHandlerLib_EXPORTS;%(PreprocessorDefinitions) 129 | true 130 | 131 | 132 | Windows 133 | true 134 | true 135 | true 136 | RedisAsgiHandlerLib.def 137 | 138 | 139 | 140 | 141 | Level3 142 | 143 | 144 | MaxSpeed 145 | true 146 | true 147 | 
NDEBUG;_WINDOWS;_USRDLL;RedisAsgiHandlerLib_EXPORTS;%(PreprocessorDefinitions) 148 | true 149 | ..\deps\hiredis\msvs\deps\;..\deps\hiredis\;..\deps\msgpack-c\include;%(AdditionalIncludeDirectories) 150 | 151 | 152 | Windows 153 | true 154 | true 155 | true 156 | RedisAsgiHandlerLib.def 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | {622cdfc1-956d-4b99-8362-acf14c959ee1} 196 | 197 | 198 | {077b1dce-b453-327f-bf3f-deab170280ed} 199 | 200 | 201 | 202 | 203 | 204 | -------------------------------------------------------------------------------- /IntegrationTests/fixtures/iis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | from __future__ import absolute_import 5 | 6 | import collections 7 | import ctypes 8 | import logging 9 | import os 10 | import re 11 | import shutil 12 | import subprocess 13 | import sys 14 | import tempfile 15 | 16 | import pytest 17 | 18 | import psutil 19 | 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | BITNESS_64BIT = 'x64' # Default value of --dll-bitness; the dll_bitness fixture maps it to '64'. 25 | ASGI_MODULE_NAME = 'AsgiHandler' 26 | POOL_MODULE_NAME = 'ProcessPool' 27 | DEFAULT_ASGI_DLL_PATH = os.path.join( 28 | os.path.dirname(__file__), '..', '..', 'build', 'AsgiHandler', 'Debug', 'AsgiHandler.dll' 29 | ) 30 | DEFAULT_POOL_DLL_PATH = os.path.join( 31 | os.path.dirname(__file__), '..', '..', 'build', 'ProcessPool', 'Debug', 'ProcessPool.dll' 32 | ) 33 | DEFAULT_POOL_SCHEMA_PATH = os.path.join( 34 | # NOTE: Source directory.
35 | os.path.dirname(__file__), '..', '..', 'ProcessPool', 'process-pool-iis-schema.xml' 36 | ) 37 | 38 | IIS_SCHEMA_PATH = os.path.expandvars('%WinDir%\\System32\\inetsrv\\config\\schema\\') 39 | POOL_SCHEMA_INSTALL_PATH = os.path.join(IIS_SCHEMA_PATH, 'process-pool.xml') # Where pool_iis_module installs the ProcessPool schema XML. 40 | 41 | 42 | def pytest_addoption(parser): # Registers the command-line options that the fixtures below read back via getoption(). 43 | parser.addoption( 44 | '--asgi-handler-dll', action='store', default=DEFAULT_ASGI_DLL_PATH, 45 | help='Path to the AsgiHandler.dll that is to be tested' 46 | ) 47 | parser.addoption( 48 | '--process-pool-dll', action='store', default=DEFAULT_POOL_DLL_PATH, 49 | help='Path to the ProcessPool.dll that is to be tested' 50 | ) 51 | parser.addoption( 52 | '--process-pool-schema-xml', action='store', default=DEFAULT_POOL_SCHEMA_PATH, 53 | help='Path to the XML schema for ProcessPool.dll configuration' 54 | ) 55 | parser.addoption( 56 | '--dll-bitness', action='store', default=BITNESS_64BIT, 57 | help='Bitness of DLLs - i.e. x64 or x86' 58 | ) 59 | @pytest.fixture 60 | def asgi_handler_dll(request): # Path from --asgi-handler-dll (default: DEFAULT_ASGI_DLL_PATH). 61 | return request.config.getoption('--asgi-handler-dll') 62 | @pytest.fixture 63 | def process_pool_dll(request): # Path from --process-pool-dll (default: DEFAULT_POOL_DLL_PATH). 64 | return request.config.getoption('--process-pool-dll') 65 | @pytest.fixture 66 | def process_pool_schema_xml(request): # Path from --process-pool-schema-xml (default: DEFAULT_POOL_SCHEMA_PATH). 67 | return request.config.getoption('--process-pool-schema-xml') 68 | @pytest.fixture 69 | def dll_bitness(request): # '64' when --dll-bitness equals BITNESS_64BIT ('x64'), otherwise '32'. 70 | if request.config.getoption('--dll-bitness') == BITNESS_64BIT: 71 | return '64' 72 | else: 73 | return '32' 74 | 75 | 76 | # Install/uninstall our section in applicationHost.config.
77 | def _run_js(js): 78 | handle, path = tempfile.mkstemp(suffix='.js') 79 | os.close(handle) 80 | with open(path, 'w') as f: 81 | f.write(""" 82 | var ahwrite = new ActiveXObject("Microsoft.ApplicationHost.WritableAdminManager"); 83 | var configManager = ahwrite.ConfigManager; 84 | var appHostConfig = configManager.GetConfigFile("MACHINE/WEBROOT/APPHOST"); 85 | var systemWebServer = appHostConfig.RootSectionGroup.Item("system.webServer"); 86 | """ + js + """ 87 | ahwrite.CommitChanges(); 88 | """) 89 | subprocess.check_call(['cscript.exe', path]) 90 | os.remove(path) 91 | def add_section(name): 92 | _run_js("systemWebServer.Sections.AddSection('%s');" % name) 93 | def remove_section(name): 94 | _run_js("systemWebServer.Sections.DeleteSection('%s');" % name) 95 | 96 | 97 | # Install/uninstall schema XML files into IIS' installation directory, for the 98 | # modules that have configuration. 99 | # We need to disable file system redirection, as we're running in 32 bit Python. 100 | class DisableFileSystemRedirection(object): 101 | _disable = ctypes.windll.kernel32.Wow64DisableWow64FsRedirection 102 | _revert = ctypes.windll.kernel32.Wow64RevertWow64FsRedirection 103 | def __enter__(self): 104 | self.old_value = ctypes.c_long() 105 | self.success = self._disable(ctypes.byref(self.old_value)) 106 | def __exit__(self, type, value, traceback): 107 | if self.success: 108 | self._revert(self.old_value) 109 | def install_schema(path, install_path): 110 | with DisableFileSystemRedirection(): 111 | shutil.copy(path, install_path) 112 | def uninstall_schema(install_path): 113 | with DisableFileSystemRedirection(): 114 | os.remove(install_path) 115 | 116 | 117 | def appcmd(*args): 118 | path = os.path.join('C:\\', 'Windows', 'System32', 'inetsrv', 'appcmd.exe') 119 | cmd = [path] + list(args) 120 | logger.debug('Calling: %s', ' '.join(cmd)) 121 | output = subprocess.check_output(cmd) 122 | logger.debug('Returned: %s', output) 123 | return output 124 | 125 | def 
uninstall_module(name): 126 | appcmd('uninstall', 'module', name) 127 | 128 | def install_module(name, path, bitness): 129 | appcmd( 130 | 'install', 'module', 131 | '/name:%s' % name, 132 | '/image:%s' % path, 133 | '/preCondition:bitness%s' % bitness 134 | ) 135 | 136 | 137 | @pytest.yield_fixture 138 | def asgi_iis_module(asgi_handler_dll, dll_bitness, asgi_etw_consumer): 139 | try: 140 | install_module(ASGI_MODULE_NAME, asgi_handler_dll, dll_bitness) 141 | yield 142 | finally: 143 | # Allow errors to propogate, as they could affect the ability of other tests 144 | # to re-install the module. 145 | uninstall_module(ASGI_MODULE_NAME) 146 | 147 | 148 | @pytest.yield_fixture 149 | def pool_iis_module(process_pool_dll, dll_bitness, pool_etw_consumer, process_pool_schema_xml): 150 | try: 151 | install_module(POOL_MODULE_NAME, process_pool_dll, dll_bitness) 152 | install_schema(process_pool_schema_xml, POOL_SCHEMA_INSTALL_PATH) 153 | add_section('processPools') 154 | yield 155 | finally: 156 | # Allow errors to propogate, as they could affect the ability of other tests 157 | # to re-install the module. 
158 | uninstall_module(POOL_MODULE_NAME) 159 | uninstall_schema(POOL_SCHEMA_INSTALL_PATH) 160 | remove_section('processPools') 161 | 162 | 163 | class SECURITY_DESCRIPTOR(ctypes.Structure): 164 | # From jaraco.windows 165 | _fields_ = [ 166 | ('Revision', ctypes.c_ubyte), 167 | ('Sbz1', ctypes.c_ubyte), 168 | ('Control', ctypes.c_ushort), 169 | ('Owner', ctypes.c_void_p), 170 | ('Group', ctypes.c_void_p), 171 | ('Sacl', ctypes.c_void_p), 172 | ('Dacl', ctypes.c_void_p), 173 | ] 174 | 175 | 176 | class SECURITY_ATTRIBUTES(ctypes.Structure): 177 | # From jaraco.windows 178 | _fields_ = [ 179 | ('nLength', ctypes.c_uint32), 180 | ('lpSecurityDescriptor', ctypes.POINTER(SECURITY_DESCRIPTOR)), 181 | ('bInheritHandle', ctypes.c_bool), 182 | ] 183 | 184 | 185 | class _ProcessPool(object): 186 | 187 | _pythonw = sys.executable.replace('python.exe', 'pythonw.exe') 188 | _worker_py = os.path.join(os.path.dirname(__file__), 'worker.py') 189 | 190 | def __init__(self, site): 191 | self.site = site 192 | self.name = os.urandom(4).encode('hex') 193 | self.process = self._pythonw 194 | self.arguments = [self._worker_py, self.name] 195 | self._object_prefix = u'Global\\ProcessPool_IntegrationTests_Worker_' + self.name 196 | 197 | # Create a SECURITY_DESCRIPTOR that grants everyone access: we'll be sharing a sempahore 198 | # and event between user sessions. 199 | sd = SECURITY_DESCRIPTOR() 200 | ctypes.windll.advapi32.InitializeSecurityDescriptor(ctypes.byref(sd), 1) 201 | ctypes.windll.advapi32.SetSecurityDescriptorDacl(ctypes.byref(sd), True, None, False) 202 | sa = SECURITY_ATTRIBUTES() 203 | sa.nLength = ctypes.sizeof(sa) 204 | sa.lpSecurityDescriptor = ctypes.pointer(sd) 205 | sa.bInheritHandle = False 206 | 207 | # Create a shared semaphore which we can use to count how many processes have run and 208 | # an event to stop processes. The event is declared with auto-reset, so that each time 209 | # it is singaled, a single process exits. 
210 | self.semaphore = ctypes.windll.kernel32.CreateSemaphoreW( 211 | ctypes.byref(sa), 0x0, 0xFFFF, 212 | self._object_prefix + '_Counter' 213 | ) 214 | assert self.semaphore, ctypes.GetLastError() 215 | self.exit_event = ctypes.windll.kernel32.CreateEventW( 216 | ctypes.byref(sa), False, False, 217 | self._object_prefix + '_Exit' 218 | ) 219 | assert self.exit_event, ctypes.GetLastError() 220 | 221 | @staticmethod 222 | def escape_argument(arg): 223 | # See https://blogs.msdn.microsoft.com/twistylittlepassagesallalike/2011/04/23/everyone-quotes-command-line-arguments-the-wrong-way/ 224 | if not arg or re.search(r'(["\s])', arg): 225 | return '"' + arg.replace('"', r'\"') + '"' 226 | return arg 227 | 228 | @staticmethod 229 | def get_processes_for_user(user): 230 | procs = [] 231 | for proc in psutil.process_iter(): 232 | try: 233 | # We ignore w3wp.exe (IIS app pool) as it is uninteresting. 234 | if proc.username() == user and proc.name() != 'w3wp.exe': 235 | procs.append(proc) 236 | except psutil.AccessDenied: 237 | pass 238 | return [ 239 | (proc.name(), tuple(proc.cmdline())) 240 | for proc in procs 241 | ] 242 | 243 | @property 244 | def escaped_arguments(self): 245 | return ' '.join(map(self.escape_argument, self.arguments)) 246 | 247 | @property 248 | def num_started(self): 249 | """Returns the number of processes that have been started for this pool. 250 | 251 | Each process increments a semaphore when it starts, so we check the current count. 252 | """ 253 | # We must increment/decrement the semaphore in order to get the current value. 
254 | current_count = ctypes.c_long() 255 | ctypes.windll.kernel32.ReleaseSemaphore(self.semaphore, 1, ctypes.byref(current_count)) 256 | ctypes.windll.kernel32.WaitForSingleObject(self.semaphore, 0) 257 | return current_count.value 258 | 259 | @property 260 | def num_running(self): 261 | """Returns the number of processes that are currently running in the pool.""" 262 | # It is unlikely that there are other processes running under the same user with 263 | # the same command line. 264 | expected = (os.path.basename(self.process), tuple([self.process] + self.arguments)) 265 | matching = [ 266 | process 267 | for process in self.get_processes_for_user(self.site.user) 268 | if process == expected 269 | ] 270 | return len(matching) 271 | 272 | def kill_one(self): 273 | """Signals the event, causing one process to exit""" 274 | assert ctypes.windll.kernel32.SetEvent(self.exit_event), ctypes.GetLastError() 275 | 276 | 277 | class _Site(object): 278 | 279 | pool_name = 'asgi-test-pool' 280 | user = 'IIS APPPOOL\\' + pool_name 281 | site_name = 'asgi-test-site' 282 | http_port = 90 283 | https_port = 91 284 | http_url = 'http://localhost:%i' % http_port 285 | https_url = 'http://localhost:%i' % https_port 286 | ws_url = 'ws://localhost:%i' % http_port 287 | url = http_url 288 | static_path = '/static.html' 289 | 290 | def __init__(self, directory, dll_bitness): 291 | self.directory = directory 292 | self.dll_bitness = dll_bitness 293 | 294 | def create(self): 295 | # Create a web.config which sends most requests through 296 | # our handler. We create a subdirectory that serves static files 297 | # so that we can test our handler doesn't get in the way 298 | # of others. 299 | webconfig = self.directory.join('web.config') 300 | webconfig.write(""" 301 | 302 | 303 | 304 | 305 | 312 | 318 | 319 | 320 | 321 | """) 322 | staticfile = self.directory.join('static.html') 323 | staticfile.write('Hello, world!') 324 | # Ensure IIS can read the directory. 
Use icacls to avoid introducing 325 | # pywin32 dependency. 326 | for path in [self.directory, webconfig, staticfile]: 327 | subprocess.check_call(['icacls', str(path), '/grant', 'Users:R']) 328 | # Add the site with its own app pool. Failing tests can cause the 329 | # pool to stop, so we create a new one for each test. 330 | appcmd('add', 'apppool', 331 | '/name:' + self.pool_name, 332 | '/enable32BitAppOnWin64:%s' % ('true' if self.dll_bitness == '32' else 'false'), 333 | ) 334 | appcmd('add', 'site', 335 | '/name:' + self.site_name, 336 | '/bindings:http://*:%i,https://*:%i' % (self.http_port, self.https_port), 337 | '/physicalPath:' + str(self.directory), 338 | ) 339 | appcmd('set', 'app', self.site_name + '/', '/applicationPool:' + self.pool_name) 340 | appcmd('unlock', 'config', '-section:system.webServer/handlers') 341 | 342 | def destroy(self): 343 | appcmd('delete', 'apppool', self.pool_name) 344 | appcmd('delete', 'site', self.site_name) 345 | 346 | def __enter__(self): 347 | self.create() 348 | return self 349 | def __exit__(self, *args): 350 | self.destroy() 351 | 352 | def stop(self): 353 | appcmd('stop', 'site', self.site_name) 354 | def start(self): 355 | appcmd('start', 'site', self.site_name) 356 | def restart(self): 357 | self.stop() 358 | self.start() 359 | 360 | def stop_application_pool(self): 361 | appcmd('stop', 'apppool', self.pool_name) 362 | def start_application_pool(self): 363 | appcmd('start', 'apppool', self.pool_name) 364 | def restart_application_pool(self): 365 | self.stop_application_pool() 366 | self.start_application_pool() 367 | 368 | def add_process_pool(self, count=1): 369 | pool = _ProcessPool(self) 370 | config = dict(executable=pool.process, arguments=pool.escaped_arguments) 371 | # If count=None, then don't specify and let IIS use the default. 
372 | if count is not None: 373 | config['count'] = str(count) 374 | config_str = ','.join('%s=\'%s\'' % (k, v) for k, v in config.iteritems()) 375 | appcmd( 376 | 'set', 'config', self.site_name, 377 | '/section:system.webServer/processPools', 378 | '/+[%s]' % (config_str) 379 | ) 380 | self.restart() 381 | return pool 382 | 383 | def clear_process_pools(self): 384 | appcmd( 385 | 'clear', 'config', self.site_name, 386 | '/section:system.webServer/processPools', 387 | ) 388 | self.restart() 389 | 390 | @pytest.yield_fixture 391 | def site(tmpdir, asgi_iis_module, pool_iis_module, dll_bitness): 392 | with _Site(tmpdir, dll_bitness) as site: 393 | yield site 394 | --------------------------------------------------------------------------------