From 3c93077f4c7d1064c796853c757a4ee99f9f5573 Mon Sep 17 00:00:00 2001 From: Nikita Gorskikh Date: Mon, 4 May 2026 14:12:30 +0000 Subject: [PATCH] Pull request 682: Add an "easy" Windows service interface for TrustTunnel GUI clients Squashed commit of the following: commit e262ed1da2c2aef6a5c5a45ec5d41066ba44979f Author: Nikita Gorskikh Date: Sun Apr 26 18:13:02 2026 +0300 Return more correct error code commit 4a9b56b2b4e746b59e559ec6ce689f961ddd9e2b Author: Nikita Gorskikh Date: Sat Apr 25 00:18:15 2026 +0300 Cleanup commit 7884c91e0e869b9c5bc4b83548aa18cf487d7af2 Author: Nikita Gorskikh Date: Sat Apr 25 00:06:10 2026 +0300 Cleanup + run clang-format commit 7973e7aa5cabbb2e02ee559ea3e7bf9f073d3430 Author: Nikita Gorskikh Date: Fri Apr 24 23:28:24 2026 +0300 Cleanup commit 53e982d626ebe841a8d27bb6dd96cc2ace05d0c4 Author: Nikita Gorskikh Date: Fri Apr 24 22:42:27 2026 +0300 Vibe test commit d4fab62e2838dbecb7d04f811b62d6c4efad6440 Author: Nikita Gorskikh Date: Fri Apr 24 22:24:00 2026 +0300 Service interface and test commit 5a84e3b0216ab0ce4cedcd65cc620e82d5426cd5 Author: Nikita Gorskikh Date: Fri Apr 24 20:29:07 2026 +0300 Allow authenticated users to start/stop service commit 85238d526d447547d42cfe689952315de76b2869 Author: Nikita Gorskikh Date: Fri Apr 24 17:59:13 2026 +0300 Fix vibe code commit 79df07013043a16ca1fad3a9ea4b01af7af314e9 Author: Nikita Gorskikh Date: Fri Apr 24 14:12:47 2026 +0300 Fix vibe code commit e8ced6e1c76f519983ac756ea091ecc04332e33f Author: Nikita Gorskikh Date: Thu Apr 23 22:10:38 2026 +0300 Vibe code more commit af1d789c08f6bf27bed7510b4644ad5e17d3e709 Author: Nikita Gorskikh Date: Thu Apr 23 19:28:20 2026 +0300 Vibe-update AGENTS.md commit b34813d984ca927f8e427b551ed787f036386557 Author: Nikita Gorskikh Date: Thu Apr 23 19:28:09 2026 +0300 Remove non-vibe-coded commit 496dc4057e5e7db21e12f4b5d9ef0f9c9e22df5f Author: Nikita Gorskikh Date: Thu Apr 23 18:52:33 2026 +0300 Vibe test commit a1877a026496dc5919c72d44c7e8d971cda18d8f Author: Nikita Gorskikh Date: Thu Apr 23 15:31:38 2026 +0300 Vibe code commit 9e706843376b71cbf4160fe9edcff62ee9652b8f Author: Nikita Gorskikh Date: Thu Apr 23 00:13:18 2026 +0300 Vibe code commit a641c864fde194f9cb7c7001487801842b1d56b9 Author: Nikita Gorskikh Date: Wed Apr 22 22:58:12 2026 +0300 Vibe code commit 9edb4183a2173154502b00aa63b6f7af688b9624 Author: Nikita Gorskikh Date: Wed Apr 22 22:37:20 2026 +0300 Vibe code commit f3a7b059e8e2bce823410d7c6efcc521a48027fd Author: Nikita Gorskikh Date: Tue Mar 31 20:54:15 2026 +0300 WIP --- .gitignore | 2 + AGENTS.md | 36 +- platform/windows/CMakeLists.txt | 19 +- platform/windows/README.md | 2 +- platform/windows/include/vpn/vpn_easy.h | 23 +- .../windows/include/vpn/vpn_easy_service.h | 137 ++ platform/windows/src/vpn_easy.cpp | 414 +++++- platform/windows/src/vpn_easy_pipe.cpp | 589 +++++++++ platform/windows/src/vpn_easy_pipe.h | 291 +++++ platform/windows/src/vpn_easy_service.cpp | 159 +++ platform/windows/test/vpn_easy_pipe_test.cpp | 1134 +++++++++++++++++ .../windows/test/vpn_easy_service_test.cpp | 166 +++ platform/windows/test/vpn_easy_test.cpp | 11 +- 13 files changed, 2965 insertions(+), 18 deletions(-) create mode 100644 platform/windows/include/vpn/vpn_easy_service.h create mode 100644 platform/windows/src/vpn_easy_pipe.cpp create mode 100644 platform/windows/src/vpn_easy_pipe.h create mode 100644 platform/windows/src/vpn_easy_service.cpp create mode 100644 platform/windows/test/vpn_easy_pipe_test.cpp create mode 100644 platform/windows/test/vpn_easy_service_test.cpp diff --git a/.gitignore b/.gitignore index d963e27..daa74af 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ pyvenv.cfg .clangd/ /.cache/ +target/ + # Conan and conan-cmake create some files in-tree by design conan.lock conanbuildinfo.* diff --git a/AGENTS.md b/AGENTS.md index bf09966..47190fc 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -19,6 +19,12 @@ See [README.md](README.md) for full product details and - **Conan 2.0.5+** — C++ package manager - **Ninja** — build backend - **Clang / LLVM 17+** (LLVM 19 on macOS) — compiler and tooling + (Windows builds use MSVC `cl.exe`, not clang) +- **clang-format 21+** — required by `make lint-cpp` / `make clang-format` +- **Python 3** — `scripts/bootstrap_conan_deps.py`, Conan wrappers + (`requirements.txt`) +- **Ruby / Fastlane** — iOS/Android release automation (`Gemfile`, + `fastlane/`) ## Directory Structure @@ -38,6 +44,10 @@ See [README.md](README.md) for full product details and | `cmake/` | CMake modules: unit test helper, Conan bootstrapping/provider | | `bamboo-specs/` | CI/CD pipeline definitions (Bamboo) | | `integration-tests/` | Docker-based integration test harness | +| `conan/` | Conan user settings and build profiles | +| `fastlane/` | iOS/Android release automation (Fastfile, Matchfile) | +| `.devcontainer/` | Docker-based remote debugging environment (see below) | +| `.github/` | GitHub Actions workflows and PR/issue templates | ### Module Dependency Flow @@ -52,22 +62,31 @@ common ← tcpip ← core ## Build Commands Run `make init` once after cloning to set up git hooks. +Running `make` with no arguments runs the `init` target — use `make all` (or +`make build_libs`) to actually build. | Command | What It Does | | --- | --- | | `make init` | Configure git hooks path to `./scripts/hooks` | +| `make bootstrap_deps` | Export AdGuard Conan recipes to local cache (prerequisite of most build targets) | +| `make setup_cmake` | CMake configure only (accepts `SKIP_BOOTSTRAP=1`) | | `make build_libs` | Bootstrap Conan deps → CMake configure → build `vpnlibs_core` | | `make build_trusttunnel_client` | Build the CLI client binary (depends on `build_libs`) | | `make build_wizard` | Build the setup wizard binary | +| `make build_and_export_bin` | Build binaries and copy to `$(EXPORT_DIR)` (default `bin/`) | | `make all` | Build all binaries (client + wizard) | | `make test` | Run all tests (`test-cpp` + `test-rust`) | | `make test-cpp` | Build libs → build test targets → run `ctest` | | `make test-rust` | `cargo test` on the setup_wizard workspace | | `make lint` | Run all linters (`lint-md` + `lint-rust` + `lint-cpp`) | | `make lint-cpp` | `clang-format` check + `clangd-tidy` | +| `make clang-format` | Explicit `clang-format` check only | +| `make clang-tidy` / `make clangd-tidy` | Run C++ static analysis only | | `make lint-rust` | `cargo clippy` + `cargo fmt --check` | | `make lint-md` | `markdownlint .` | | `make lint-fix` | Auto-fix all fixable linter issues | +| `make lint-fix-cpp` / `lint-fix-rust` / `lint-fix-md` | Granular auto-fix targets | +| `make list-deps-dirs` | List Conan package directories (for finding dep headers) | | `make compile_commands` | Generate `compile_commands.json` for IDE integration | | `make clean` | Clean build artifacts | @@ -92,6 +111,10 @@ Set `BUILD_TYPE=debug` for debug builds (default is `release` → - `UPPER_CASE`: constants, `constexpr` locals, static constants - Private/protected members prefixed with `m_`, globals with `g_` - Use `libc++` (not `libstdc++`) +- Use static storage duration instead of anonymous namespaces for internal linkage where possible + - (e.g. `static const int VALUE = 42;` instead of putting it in an anonymous namespace) +- Function descriptions are written in imperative language + - e.g. "Calculate the sum of two numbers" instead of "Calculates the sum of two numbers" ### Rust @@ -119,6 +142,12 @@ Set `BUILD_TYPE=debug` for debug builds (default is `release` → - Prefer existing patterns over inventing new ones - Keep changes minimal and focused - Tests live in `test/` subdirectories alongside the module they cover +- Logging guidelines: + - Use `DEBUG` level for verbose debug info, `INFO` for high-level events, + `WARN` for recoverable issues, and `ERROR` for critical (unrecoverable) errors + - Very frequent events (e.g. every packet) should be logged at `DEBUG` level, while important state changes (e.g. connection established, error occurred) should be at `INFO` or higher + - Include relevant context in log messages (e.g. connection ID, error code) + - Avoid logging sensitive information (e.g. IP addresses, payload data) ## Docker Debug Environment @@ -141,10 +170,11 @@ Managed via Conan. Key libraries: - **libevent** — async event loop - **nghttp2** — HTTP/2 - **quiche** — HTTP/3 / QUIC (disabled on MIPS) -- **openssl** (BoringSSL) — TLS +- **openssl** (BoringSSL; MIPS falls back to `openssl/3.1.5-quic1@adguard/oss`) — TLS - **nlohmann_json**, **tomlplusplus** — config parsing - **cxxopts** — CLI argument parsing - **magic_enum** — enum reflection +- **gtest** — unit testing Local conan cache is populated by `make bootstrap_deps` which is dependency for many other make commands. @@ -156,10 +186,10 @@ directories, then look in each directory's `include/` subdirectory. You MUST follow the following rules for EVERY task that you perform: -- You MUST verify it with linter, formatter, and TypeScript compiler. +- You MUST verify it with linter, formatter, and C++/Rust compilers. Use the following commands: - - `make` to check if code builds + - `make all` to check if code builds (bare `make` only runs `init`) - `make test` to build and run unit tests - `make lint` to run the linters - `make lint-fix` to fix linting issues that can be fixed automatically diff --git a/platform/windows/CMakeLists.txt b/platform/windows/CMakeLists.txt index d2f7f2f..cebfa2e 100644 --- a/platform/windows/CMakeLists.txt +++ b/platform/windows/CMakeLists.txt @@ -28,12 +28,27 @@ if (NOT TARGET vpnlibs_trusttunnel) add_subdirectory(${ROOT_DIR} ${CMAKE_BINARY_DIR}/trusttunnel) endif () -add_library(vpn_easy_a STATIC EXCLUDE_FROM_ALL src/vpn_easy.cpp) +add_library(vpn_easy_a STATIC src/vpn_easy.cpp src/vpn_easy_pipe.cpp) target_include_directories(vpn_easy_a PUBLIC include) target_link_libraries(vpn_easy_a vpnlibs_trusttunnel) add_library(vpn_easy SHARED src/vpn_easy.cpp) target_link_libraries(vpn_easy vpn_easy_a) -add_executable(vpn_easy_test EXCLUDE_FROM_ALL test/vpn_easy_test.cpp) +add_executable(vpn_easy_test test/vpn_easy_test.cpp) target_link_libraries(vpn_easy_test vpn_easy_a) + +add_executable(vpn_easy_service src/vpn_easy_service.cpp) +target_link_libraries(vpn_easy_service vpn_easy_a) + +add_executable(vpn_easy_service_test test/vpn_easy_service_test.cpp) +target_link_libraries(vpn_easy_service_test vpn_easy_a) +add_dependencies(vpn_easy_service_test vpn_easy_service) + +enable_testing() +include(${ROOT_DIR}/cmake/add_unit_test.cmake) +set(TEST_DIR ${CMAKE_CURRENT_SOURCE_DIR}/test) + +link_libraries(vpn_easy_a) + +add_unit_test(vpn_easy_pipe_test "${TEST_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}/src" TRUE TRUE) diff --git a/platform/windows/README.md b/platform/windows/README.md index 8d9edb8..80f9250 100644 --- a/platform/windows/README.md +++ b/platform/windows/README.md @@ -1,4 +1,4 @@ -# Easy wrapper for AdGuard VPI API +# Easy wrapper for AdGuard VPN API - Basically, the `trusttunnel_client` command line application in the form of a library. - Only two buttons: `start` and `stop`. The first one accepts the configuration in TOML format. diff --git a/platform/windows/include/vpn/vpn_easy.h b/platform/windows/include/vpn/vpn_easy.h index 220db1a..daac392 100644 --- a/platform/windows/include/vpn/vpn_easy.h +++ b/platform/windows/include/vpn/vpn_easy.h @@ -10,7 +10,12 @@ extern "C" { #endif typedef struct vpn_easy_s vpn_easy_t; -typedef void (*on_state_changed_t)(void *arg, int new_state_description); + +/** See `ag::VpnSessionState`. */ +typedef void (*on_state_changed_t)(void *arg, int state); + +/** See `ag::VpnConnectionInfoEvent`. */ +typedef void (*on_connection_info_t)(void *arg, void *connection_info); /** * Start (connect) a VPN client. @@ -29,6 +34,22 @@ WIN_EXPORT void vpn_easy_start( */ WIN_EXPORT void vpn_easy_stop(); +/** + * Start (connect) a VPN client. The callbacks and their arguments passed to this function + * must remain valid throughout the lifetime of the VPN client. + * @param toml_config VPN client parameters in TOML format. + * @param state_changed_cb A function which will be called each time the VPN client's state changes. + * @param state_changed_cb_arg An argument passed to each invocation of the state change function. + * @param connection_info_cb A function called each time a connection is made through the VPN. + * @param connection_info_cb_arg An argument passed to each invocation of the connection info function. + * @return On success, a pointer to the started VPN client instance. On error, a null pointer. + */ +vpn_easy_t *vpn_easy_start_ex(const char *toml_config, on_state_changed_t state_changed_cb, void *state_changed_cb_arg, + on_connection_info_t connection_info_cb, void *connection_info_cb_arg); + +/** Stop (disconnect) a VPN client and free all associated resources. */ +void vpn_easy_stop_ex(vpn_easy_t *vpn); + #ifdef __cplusplus }; // extern "C" #endif diff --git a/platform/windows/include/vpn/vpn_easy_service.h b/platform/windows/include/vpn/vpn_easy_service.h new file mode 100644 index 0000000..ea9e69e --- /dev/null +++ b/platform/windows/include/vpn/vpn_easy_service.h @@ -0,0 +1,137 @@ +#pragma once + +#include + +#include "vpn/platform.h" +#include "vpn/vpn_easy.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Communication with the service is done by sending messages of the form: + * ``` + * struct Message { + * uint32_t type; + * uint32_t length; // The length of the `data` field. + * uint8_t data[0]; // `length` bytes of data. + * }; + * ``` + * over the named pipe configured at service creation time (see `vpn_easy_service_install()`). + * The format of the `data` field is given by the message type. Integers are in network byte order. + */ +typedef enum { + /** + * A request to start (connect) the VPN client. The data field must contain the VPN client configuration + * in TOML format (encoded in UTF-8 as per TOML specification). + * + * If the client is already connecting or connected, it is requested to stop + * first and then start with the new configuration. + * + * If the client fails to start for whatever reason, the service will send a state change + * message with the state `VPN_SS_DISCONNECTED`. + */ + VPN_EASY_SVC_MSG_START = 0, + + /** + * A request to stop (disconnect) the VPN client. The length field must be zero, the data field empty. + * If the client is already stopped, this message is ignored. + */ + VPN_EASY_SVC_MSG_STOP, + + /** + * Sent by the service when the VPN client state changes. `length` is always `4` in network byte order, `data` + * is an `int32_t` in network byte order, one of the `ag::VpnSessionState` values. + * + * The service client should wait for this message after sending a start/stop request. + */ + VPN_EASY_SVC_MSG_STATE_CHANGED, + + /** + * Sent by the service to notify the client of a new connection through the VPN tunnel. The data field + * is a JSON document describing the connection, as returned by `ag::ConnectionInfo::to_json()`. + */ + VPN_EASY_SVC_MSG_CONNECTION_INFO, +} VpnEasyServiceMessageType; + +typedef enum { + /** Access denied. Check if the calling process is running as administrator. */ + VPN_EASY_SVC_ERR_ACCESS = 1, + + /** Service already exists. Uninstall it with `vpn_easy_service_uninstall()` first. */ + VPN_EASY_SVC_ERR_SERVICE_EXISTS, + + /** No service with the given name exists. */ + VPN_EASY_SVC_ERR_NO_SUCH_SERVICE, + + /** An operation on the service took too long. */ + VPN_EASY_SVC_ERR_TIMED_OUT, + + /** Encountered an unexpected error. Probably as a result of API misusage. The log may contain more details. */ + VPN_EASY_SVC_ERR_OTHER, +} VpnEasyServiceError; + +/** + * Create and start a VPN service. This function requires administrator privileges. The service is configured + * to start automatically at system startup. After startup, the service is listening on a named pipe `pipe_name`, + * and can be controlled by connecting and sending messages on that pipe. The protocol details are given by the + * description of `VpnEasyServiceMessageType` enumeration. Anyone can read/write from/to the pipe. + * @param image_path The absolute path to the `vpn_easy_service` executable. + * @param logfile_path The absolute path to the service's log file. Will be created if doesn't exist. + * @param name The service name. At most 256 characters. + * @param pipe_name The name for the named pipe used to communicate with the service. + * A string of at most 256 characters of the form: "\\.\pipe\", where "" + * can include any character except the backslash. + * @param display_name The display name to be used by user interface programs to identify the service. + * At most 256 characters. + * @param description A comment that explains the purpose of the service. + * @return Zero on success, one of `VpnEasyServiceError` constants on failure. + */ +WIN_EXPORT int32_t vpn_easy_service_install(const wchar_t *image_path, const wchar_t *logfile_path, + const wchar_t *pipe_name, const wchar_t *name, const wchar_t *display_name, const wchar_t *description); + +/** + * Stop and delete the VPN service named `name`. This function requires administrator privileges. The service is + * requested to stop and marked for deletion. It will be deleted when it has stopped and all handles to it have + * been closed. If it doesn't stop for some reason, it will be deleted when the system is restarted. + * @param name The service name that was passed to `vpn_easy_service_install()`. + * @return Zero on success, one of `VpnEasyServiceError` constants on failure. + */ +WIN_EXPORT int32_t vpn_easy_service_uninstall(const wchar_t *name); + +/** + * Start the VPN service named `service_name`. + * + * This will start the Windows service if it's not already running, connect to the running service + * through the named pipe and instruct it to start the VPN client with the provided configuration. + * + * It shall then pass the service state change messages to `state_changed_cb`, which must remain + * valid (along with `state_changed_cb_arg`) until the service is stopped with `vpn_easy_service_stop()`. + * The callback is invoked on an unspecified thread, and may be called concurrently with `vpn_easy_service_start()`. + * + * @param service_name The service name that was passed to `vpn_easy_service_install()`. + * @param pipe_name The name of the pipe that was passed to `vpn_easy_service_install()`. + * @param toml_config The VPN client configuration in TOML format (encoded in UTF-8 as per TOML specification). + * @param state_changed_cb A function which will be called each time the VPN client's state changes. + * @param state_changed_cb_arg An argument passed to each invocation of the state change function. + * @return Zero on success, one of `VpnEasyServiceError` constants on failure. + */ +WIN_EXPORT int32_t vpn_easy_service_start(const wchar_t *service_name, const wchar_t *pipe_name, + const char *toml_config, on_state_changed_t state_changed_cb, void *state_changed_cb_arg); + +/** + * Stop the VPN service named `service_name`. + * + * This will stop both the VPN client and the Windows service. + * After this function returns, the state change callback will not be called anymore. + * + * @param service_name The service name that was passed to `vpn_easy_service_install()`. + * @param pipe_name The name of the pipe that was passed to `vpn_easy_service_install()`. + * @return Zero on success, one of `VpnEasyServiceError` constants on failure. + */ +WIN_EXPORT int32_t vpn_easy_service_stop(const wchar_t *service_name, const wchar_t *pipe_name); + +#ifdef __cplusplus +}; // extern "C" +#endif diff --git a/platform/windows/src/vpn_easy.cpp b/platform/windows/src/vpn_easy.cpp index 33007bd..dd4629a 100644 --- a/platform/windows/src/vpn_easy.cpp +++ b/platform/windows/src/vpn_easy.cpp @@ -1,20 +1,35 @@ #include "vpn/vpn_easy.h" +#include "vpn/vpn_easy_service.h" +#include +#include #include +#include #include +#include +#include #include +#include + +#include #include #include +#include +#include + #include "common/logger.h" #include "common/net_utils.h" +#include "common/utils.h" #include "net/tls.h" #include "vpn/event_loop.h" #include "vpn/platform.h" #include "vpn/trusttunnel/auto_network_monitor.h" #include "vpn/trusttunnel/client.h" #include "vpn/trusttunnel/config.h" +#include "vpn/vpn.h" +#include "vpn_easy_pipe.h" static ag::Logger g_logger{"VPN_SIMPLE"}; @@ -84,8 +99,8 @@ struct vpn_easy_s { std::unique_ptr network_monitor; }; -static vpn_easy_t *vpn_easy_start_internal( - const char *toml_config, on_state_changed_t state_changed_cb, void *state_changed_cb_arg) { +vpn_easy_t *vpn_easy_start_ex(const char *toml_config, on_state_changed_t state_changed_cb, void *state_changed_cb_arg, + on_connection_info_t connection_info_cb, void *connection_info_cb_arg) { toml::parse_result parsed_config = toml::parse(toml_config); if (!parsed_config) { warnlog(g_logger, "Failed to parse the TOML config: {}", parsed_config.error().description()); @@ -119,11 +134,17 @@ static vpn_easy_t *vpn_easy_start_internal( state_changed_cb(state_changed_cb_arg, event->state); } }; + if (connection_info_cb) { + callbacks.connection_info_handler = [connection_info_cb, connection_info_cb_arg]( + ag::VpnConnectionInfoEvent *event) { + connection_info_cb(connection_info_cb_arg, event); + }; + } auto vpn = std::make_unique(); std::string bound_if; - if (const auto *tun = std::get_if(&trusttunnel_config->listener)) { + if (const auto *tun = std::get_if(&trusttunnel_config->listener)) { bound_if = tun->bound_if; } @@ -141,7 +162,7 @@ static vpn_easy_t *vpn_easy_start_internal( return vpn.release(); } -static void vpn_easy_stop_internal(vpn_easy_t *vpn) { +void vpn_easy_stop_ex(vpn_easy_t *vpn) { if (!vpn) { return; } @@ -175,7 +196,7 @@ public: warnlog(g_logger, "VPN has been already started"); return; } - m_vpn = vpn_easy_start_internal(config.data(), callback, arg); // blocking + m_vpn = vpn_easy_start_ex(config.data(), callback, arg, nullptr, nullptr); // blocking if (!m_vpn) { errlog(g_logger, "Failed to start VPN!"); return; @@ -193,7 +214,7 @@ public: return; } auto *vpn = std::exchange(m_vpn, nullptr); - vpn_easy_stop_internal(vpn); + vpn_easy_stop_ex(vpn); }); } @@ -215,4 +236,383 @@ void vpn_easy_start(const char *toml_config, on_state_changed_t state_changed_cb void vpn_easy_stop() { VpnEasyManager::instance().stop_async(); -} \ No newline at end of file +} + +static std::wstring escape(const wchar_t *str, const wchar_t *chars_to_escape, wchar_t escape_char) { + std::wstring ret; + ret.reserve(wcslen(str) * 2); + while (*str != L'\0') { + if (wcschr(chars_to_escape, *str)) { + ret += escape_char; + } + ret += *str; + ++str; + } + return ret; +} + +using AutoScHandle = ag::UniquePtr, &CloseServiceHandle>; + +// Grant SERVICE_START and SERVICE_STOP to authenticated users on the given service handle. +// Return true on success, false on any error (logged at DEBUG level). +static bool grant_authenticated_users_start_stop(SC_HANDLE svc) { + // Query the current security descriptor size + DWORD bytes_needed = 0; + if (!QueryServiceObjectSecurity(svc, DACL_SECURITY_INFORMATION, nullptr, 0, &bytes_needed) + && GetLastError() != ERROR_INSUFFICIENT_BUFFER) { + dbglog(g_logger, "QueryServiceObjectSecurity (size): {} ({})", GetLastError(), + ag::sys::strerror(GetLastError())); + return false; + } + + std::vector sd_buf; + sd_buf.resize(bytes_needed); + auto *sd = reinterpret_cast(sd_buf.data()); + if (!QueryServiceObjectSecurity(svc, DACL_SECURITY_INFORMATION, sd, bytes_needed, &bytes_needed)) { + dbglog(g_logger, "QueryServiceObjectSecurity: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return false; + } + + // Retrieve the existing DACL from the security descriptor + BOOL dacl_present = FALSE; + PACL old_dacl = nullptr; + BOOL dacl_defaulted = FALSE; + if (!GetSecurityDescriptorDacl(sd, &dacl_present, &old_dacl, &dacl_defaulted)) { + dbglog(g_logger, "GetSecurityDescriptorDacl: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return false; + } + + // Build the SID for Authenticated Users (S-1-5-11) + SID_IDENTIFIER_AUTHORITY nt_authority = SECURITY_NT_AUTHORITY; + PSID authenticated_users_sid = nullptr; + if (!AllocateAndInitializeSid( + &nt_authority, 1, SECURITY_AUTHENTICATED_USER_RID, 0, 0, 0, 0, 0, 0, 0, &authenticated_users_sid)) { + dbglog(g_logger, "AllocateAndInitializeSid: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return false; + } + + PACL new_dacl = nullptr; + + ag::utils::ScopeExit cleanup{[&] { + LocalFree(new_dacl); + FreeSid(authenticated_users_sid); + }}; + + // Build an EXPLICIT_ACCESS entry granting SERVICE_START | SERVICE_STOP + EXPLICIT_ACCESS_W ea{}; + ea.grfAccessPermissions = SERVICE_START | SERVICE_STOP; + ea.grfAccessMode = SET_ACCESS; + ea.grfInheritance = NO_INHERITANCE; + ea.Trustee.TrusteeForm = TRUSTEE_IS_SID; + ea.Trustee.TrusteeType = TRUSTEE_IS_WELL_KNOWN_GROUP; + ea.Trustee.ptstrName = reinterpret_cast(authenticated_users_sid); + + DWORD result = SetEntriesInAclW(1, &ea, old_dacl, &new_dacl); + if (result != ERROR_SUCCESS) { + dbglog(g_logger, "SetEntriesInAclW: {} ({})", result, ag::sys::strerror(result)); + return false; + } + + // Build a new security descriptor with the updated DACL + SECURITY_DESCRIPTOR new_sd{}; + if (!InitializeSecurityDescriptor(&new_sd, SECURITY_DESCRIPTOR_REVISION)) { + dbglog(g_logger, "InitializeSecurityDescriptor: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return false; + } + if (!SetSecurityDescriptorDacl(&new_sd, TRUE, new_dacl, FALSE)) { + dbglog(g_logger, "SetSecurityDescriptorDacl: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return false; + } + + // Apply the updated security descriptor to the service + if (!SetServiceObjectSecurity(svc, DACL_SECURITY_INFORMATION, &new_sd)) { + dbglog(g_logger, "SetServiceObjectSecurity: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return false; + } + + return true; +} + +int32_t vpn_easy_service_install(const wchar_t *image_path_, const wchar_t *logfile_path_, const wchar_t *pipe_name_, + const wchar_t *name, const wchar_t *display_name, const wchar_t *description) { + std::wstring image_path = escape(image_path_, L"\"", L'\\'); + std::wstring logfile_path = escape(logfile_path_, L"\"", L'\\'); + std::wstring pipe_name = escape(pipe_name_, L"\"", L'\\'); + + std::wstring cmd = fmt::format(L"\"{}\" \"{}\" \"{}\"", image_path, logfile_path, pipe_name); + + AutoScHandle scm{OpenSCManagerW(nullptr, nullptr, SC_MANAGER_CREATE_SERVICE)}; + if (!scm) { + if (ERROR_ACCESS_DENIED == GetLastError()) { + return VPN_EASY_SVC_ERR_ACCESS; + } + dbglog(g_logger, "OpenSCManagerW: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return VPN_EASY_SVC_ERR_OTHER; + } + + AutoScHandle svc{CreateServiceW(scm.get(), name, display_name, SERVICE_ALL_ACCESS, SERVICE_WIN32_OWN_PROCESS, + SERVICE_AUTO_START, SERVICE_ERROR_NORMAL, cmd.c_str(), nullptr, nullptr, nullptr, nullptr, nullptr)}; + if (!svc) { + if (ERROR_SERVICE_EXISTS == GetLastError()) { + return VPN_EASY_SVC_ERR_SERVICE_EXISTS; + } + if (ERROR_ACCESS_DENIED == GetLastError()) { + return VPN_EASY_SVC_ERR_ACCESS; + } + dbglog(g_logger, "CreateServiceW: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return VPN_EASY_SVC_ERR_OTHER; + } + + SERVICE_DESCRIPTIONW desc{.lpDescription = const_cast(description)}; + ChangeServiceConfig2W(svc.get(), SERVICE_CONFIG_DESCRIPTION, &desc); + + if (!grant_authenticated_users_start_stop(svc.get())) { + dbglog(g_logger, "Failed to grant start/stop permissions to authenticated users"); + return VPN_EASY_SVC_ERR_OTHER; + } + + if (!StartServiceW(svc.get(), 0, nullptr)) { + dbglog(g_logger, "StartServiceW: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return VPN_EASY_SVC_ERR_OTHER; + } + + return 0; +} + +int32_t vpn_easy_service_uninstall(const wchar_t *name) { + AutoScHandle scm{OpenSCManagerW(nullptr, nullptr, SC_MANAGER_CONNECT)}; + if (!scm) { + if (ERROR_ACCESS_DENIED == GetLastError()) { + return VPN_EASY_SVC_ERR_ACCESS; + } + dbglog(g_logger, "OpenSCManagerW: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return VPN_EASY_SVC_ERR_OTHER; + } + + AutoScHandle svc{OpenServiceW(scm.get(), name, STANDARD_RIGHTS_DELETE | SERVICE_STOP)}; + if (!svc) { + if (ERROR_ACCESS_DENIED == GetLastError()) { + return VPN_EASY_SVC_ERR_ACCESS; + } + if (ERROR_SERVICE_DOES_NOT_EXIST == GetLastError()) { + return VPN_EASY_SVC_ERR_NO_SUCH_SERVICE; + } + dbglog(g_logger, "OpenServiceW: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return VPN_EASY_SVC_ERR_OTHER; + } + + SERVICE_STATUS status{}; + if (!ControlService(svc.get(), SERVICE_CONTROL_STOP, &status) && ERROR_SERVICE_NOT_ACTIVE != GetLastError()) { + dbglog(g_logger, "ControlService(STOP): {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + } + + if (!DeleteService(svc.get())) { + dbglog(g_logger, "DeleteService: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return VPN_EASY_SVC_ERR_OTHER; + } + + return 0; +} + +static constexpr auto SERVICE_OPERATION_TIMEOUT = std::chrono::seconds{30}; +static constexpr auto SERVICE_POLL_INTERVAL = std::chrono::milliseconds{250}; + +static struct ServiceControllerState { + std::mutex mutex; + HANDLE stop_event = nullptr; + std::unique_ptr pipe_client; + std::thread io_thread; + on_state_changed_t state_changed_cb = nullptr; + void *state_changed_cb_arg = nullptr; + + /// Tear down the pipe session and clear all state. Caller must hold `mutex`. + void reset() { + if (stop_event) { + SetEvent(stop_event); + } + if (io_thread.joinable()) { + io_thread.join(); + } + pipe_client.reset(); + if (stop_event) { + CloseHandle(stop_event); + stop_event = nullptr; + } + state_changed_cb = nullptr; + state_changed_cb_arg = nullptr; + } +} g_svc_state; + +/// Poll a service until it reaches the desired state, or timeout. +/// Return true if the desired state was reached, false on timeout. +static bool wait_for_service_state(SC_HANDLE svc, DWORD desired_state, std::chrono::milliseconds timeout) { + auto deadline = std::chrono::steady_clock::now() + timeout; + for (;;) { + SERVICE_STATUS status{}; + if (!QueryServiceStatus(svc, &status)) { + dbglog(g_logger, "QueryServiceStatus: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return false; + } + if (status.dwCurrentState == desired_state) { + return true; + } + if (std::chrono::steady_clock::now() >= deadline) { + return false; + } + std::this_thread::sleep_for(SERVICE_POLL_INTERVAL); + } +} + +/// Map a Windows error code from an SCM operation to a VpnEasyServiceError. +static int32_t map_scm_error(const char *func_name) { + DWORD err = GetLastError(); + if (err == ERROR_ACCESS_DENIED) { + return VPN_EASY_SVC_ERR_ACCESS; + } + if (err == ERROR_SERVICE_DOES_NOT_EXIST) { + return VPN_EASY_SVC_ERR_NO_SUCH_SERVICE; + } + dbglog(g_logger, "{}: {} ({})", func_name, err, ag::sys::strerror(err)); + return VPN_EASY_SVC_ERR_OTHER; +} + +int32_t vpn_easy_service_start(const wchar_t *service_name, const wchar_t *pipe_name, const char *toml_config, + on_state_changed_t state_changed_cb, void *state_changed_cb_arg) { + std::scoped_lock lock{g_svc_state.mutex}; + + if (g_svc_state.pipe_client) { + warnlog(g_logger, "Service client is already active"); + return VPN_EASY_SVC_ERR_OTHER; + } + + // Save callbacks early so the handler lambda can reference them. + g_svc_state.state_changed_cb = state_changed_cb; + g_svc_state.state_changed_cb_arg = state_changed_cb_arg; + + // ScopeExit: on any error return, clean up everything and clear callbacks. + bool success = false; + ag::utils::ScopeExit cleanup{[&] { + if (success) { + return; + } + g_svc_state.reset(); + }}; + + AutoScHandle scm{OpenSCManagerW(nullptr, nullptr, SC_MANAGER_CONNECT)}; + if (!scm) { + return map_scm_error("OpenSCManagerW"); + } + + AutoScHandle svc{OpenServiceW(scm.get(), service_name, SERVICE_START | SERVICE_QUERY_STATUS)}; + if (!svc) { + return map_scm_error("OpenServiceW"); + } + + SERVICE_STATUS status{}; + if (!QueryServiceStatus(svc.get(), &status)) { + dbglog(g_logger, "QueryServiceStatus: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return VPN_EASY_SVC_ERR_OTHER; + } + + if (status.dwCurrentState != SERVICE_RUNNING) { + if (!StartServiceW(svc.get(), 0, nullptr)) { + if (GetLastError() != ERROR_SERVICE_ALREADY_RUNNING) { + dbglog(g_logger, "StartServiceW: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return VPN_EASY_SVC_ERR_OTHER; + } + } + if (!wait_for_service_state(svc.get(), SERVICE_RUNNING, SERVICE_OPERATION_TIMEOUT)) { + errlog(g_logger, "Service did not reach RUNNING state within timeout"); + return VPN_EASY_SVC_ERR_TIMED_OUT; + } + } + + g_svc_state.stop_event = CreateEventW(nullptr, TRUE, FALSE, nullptr); + if (!g_svc_state.stop_event) { + dbglog(g_logger, "CreateEventW: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return VPN_EASY_SVC_ERR_OTHER; + } + + g_svc_state.pipe_client = std::make_unique( + pipe_name, g_svc_state.stop_event, + [](VpnEasyServiceMessageType what, ag::Uint8View data) { + switch (what) { + case VPN_EASY_SVC_MSG_STATE_CHANGED: { + if (data.size() < sizeof(uint32_t)) { + dbglog(g_logger, "STATE_CHANGED message too short: {} bytes", data.size()); + break; + } + uint32_t net_state = 0; + memcpy(&net_state, data.data(), sizeof(net_state)); + auto state = static_cast(ntohl(net_state)); + if (g_svc_state.state_changed_cb) { + g_svc_state.state_changed_cb(g_svc_state.state_changed_cb_arg, state); + } + break; + } + case VPN_EASY_SVC_MSG_CONNECTION_INFO: + dbglog(g_logger, "Received CONNECTION_INFO (ignored)"); + break; + default: + dbglog(g_logger, "Ignoring unexpected message type: {}", static_cast(what)); + break; + } + }, + std::chrono::duration_cast(SERVICE_OPERATION_TIMEOUT)); + + g_svc_state.io_thread = std::thread([] { + if (!g_svc_state.pipe_client->loop()) { + // Deliver VPN_SS_DISCONNECTED on unexpected exit if callback is still set. + if (g_svc_state.state_changed_cb) { + g_svc_state.state_changed_cb(g_svc_state.state_changed_cb_arg, ag::VPN_SS_DISCONNECTED); + } + } + }); + + if (!g_svc_state.pipe_client->wait_connected()) { + errlog(g_logger, "PipeClient failed to connect within timeout"); + return VPN_EASY_SVC_ERR_TIMED_OUT; + } + + size_t config_len = strlen(toml_config); + g_svc_state.pipe_client->send(VPN_EASY_SVC_MSG_START, {reinterpret_cast(toml_config), config_len}); + + success = true; + return 0; +} + +int32_t vpn_easy_service_stop(const wchar_t *service_name, const wchar_t *pipe_name) { + std::scoped_lock lock{g_svc_state.mutex}; + + if (!g_svc_state.pipe_client) { + return 0; + } + + g_svc_state.pipe_client->send(VPN_EASY_SVC_MSG_STOP, {}); + + g_svc_state.reset(); + + AutoScHandle scm{OpenSCManagerW(nullptr, nullptr, SC_MANAGER_CONNECT)}; + if (!scm) { + return map_scm_error("OpenSCManagerW"); + } + + AutoScHandle svc{OpenServiceW(scm.get(), service_name, SERVICE_STOP | SERVICE_QUERY_STATUS)}; + if (!svc) { + return map_scm_error("OpenServiceW"); + } + + SERVICE_STATUS status{}; + if (!ControlService(svc.get(), SERVICE_CONTROL_STOP, &status)) { + if (GetLastError() != ERROR_SERVICE_NOT_ACTIVE) { + dbglog(g_logger, "ControlService(STOP): {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + } + } + + if (!wait_for_service_state(svc.get(), SERVICE_STOPPED, SERVICE_OPERATION_TIMEOUT)) { + errlog(g_logger, "Service did not stop within timeout"); + return VPN_EASY_SVC_ERR_TIMED_OUT; + } + + return 0; +} diff --git a/platform/windows/src/vpn_easy_pipe.cpp b/platform/windows/src/vpn_easy_pipe.cpp new file mode 100644 index 0000000..384008a --- /dev/null +++ b/platform/windows/src/vpn_easy_pipe.cpp @@ -0,0 +1,589 @@ +#include "vpn_easy_pipe.h" + +#include + +#include +#include +#include +#include +#include + +#include "common/logger.h" +#include "common/system_error.h" +#include "vpn/internal/wire_utils.h" + +namespace ag::vpn_easy { + +static ag::Logger g_server_logger{"PIPE_SERVER"}; +static ag::Logger g_client_logger{"PIPE_CLIENT"}; + +namespace detail { +void free_security_descriptor(SECURITY_DESCRIPTOR *sd) { + LocalFree(sd); +} +} // namespace detail + +// --------------------------------------------------------------------------- +// PipeEndpoint +// --------------------------------------------------------------------------- + +PipeEndpoint::PipeEndpoint(HANDLE stop_event, Handler handler, ag::Logger &logger) + : m_io_event{CreateEventW(nullptr, TRUE, FALSE, nullptr)} + , m_write_event{CreateEventW(nullptr, TRUE, FALSE, nullptr)} + , m_wake_event{CreateEventW(nullptr, FALSE, FALSE, nullptr)} + , m_logger{logger} + , m_stop_event{stop_event} + , m_handler{std::move(handler)} { + assert(m_handler); + m_olr.hEvent = m_io_event; + m_olw.hEvent = m_write_event; + m_input_buf.resize(INPUT_BUF_SIZE); +} + +PipeEndpoint::~PipeEndpoint() { + // Subclass destructor must have already torn down `m_pipe` (so that any virtual + // `teardown_pipe()` would be unreachable from here, where it would not dispatch to the + // derived implementation). + if (m_io_event != nullptr) { + CloseHandle(m_io_event); + } + if (m_write_event != nullptr) { + CloseHandle(m_write_event); + } + if (m_wake_event != nullptr) { + CloseHandle(m_wake_event); + } +} + +void PipeEndpoint::cancel_pending_io() { + CancelIoEx(m_pipe, nullptr); + if (m_write_pending) { + DWORD ignored = 0; + GetOverlappedResult(m_pipe, &m_olw, &ignored, TRUE); + } + if (m_read_pending || !m_connected.load(std::memory_order_relaxed)) { + DWORD ignored = 0; + GetOverlappedResult(m_pipe, &m_olr, &ignored, TRUE); + } +} + +void PipeEndpoint::prepare_for_connect() { + m_connected.store(false, std::memory_order_relaxed); + m_read_pending = false; + m_write_pending = false; + m_input_buf_used = 0; + ResetEvent(m_io_event); + ResetEvent(m_write_event); +} + +std::vector PipeEndpoint::compose_message(VpnEasyServiceMessageType what, ag::Uint8View data) { + assert(data.size() < size_t(UINT32_MAX)); + std::vector ret; + ret.resize(sizeof(uint32_t) + sizeof(uint32_t) + data.size()); + ag::wire_utils::Writer w{{ret.data(), ret.size()}}; + w.put_u32(static_cast(what)); + w.put_u32(static_cast(data.size())); + w.put_data(data); + return ret; +} + +void PipeEndpoint::send(VpnEasyServiceMessageType what, ag::Uint8View data) { + { + std::scoped_lock l{m_pending_writes_lock}; + // disconnect_and_reset() stores `false` BEFORE taking this lock, so any push that + // happens-before disconnect's lock acquisition will be observed (and cleared) by + // disconnect, and any push that happens-after will see `false` here and bail out. + if (!m_connected.load(std::memory_order_relaxed)) { + return; + } + if (m_pending_writes.size() == MAX_PENDING_WRITES) { + static_assert(MAX_PENDING_WRITES > 0); + m_pending_writes.pop_front(); + } + m_pending_writes.push_back(PendingWrite{compose_message(what, data), 0}); + } + SetEvent(m_wake_event); +} + +bool PipeEndpoint::loop() { + if (m_io_event == nullptr || m_write_event == nullptr || m_wake_event == nullptr) { + return false; + } + + if (!start_connect()) { + return false; + } + + HANDLE events[] = {m_stop_event, m_wake_event, m_io_event, m_write_event}; + constexpr DWORD EVENT_COUNT = static_cast(std::size(events)); + for (;;) { + DWORD wait = WaitForMultipleObjects(EVENT_COUNT, events, FALSE, INFINITE); + if (wait >= WAIT_OBJECT_0 + EVENT_COUNT) { + errlog(m_logger, "WaitForMultipleObjects: {:#x}, GetLastError: {} ({})", wait, GetLastError(), + ag::sys::strerror(GetLastError())); + return false; + } + + DWORD idx = wait - WAIT_OBJECT_0; + if (idx == 0) { + // Stop event. + return true; + } + + if (idx == 2) { + // m_io_event: overlapped connect or ReadFile completed. + if (!m_connected.load(std::memory_order_relaxed)) { + if (!finalize_connect()) { + if (auto r = handle_disconnect()) { + return *r; + } + continue; + } + } else if (m_read_pending) { + if (!complete_read()) { + if (auto r = handle_disconnect()) { + return *r; + } + continue; + } + } + } + + if (idx == 3 && m_write_pending) { + // m_write_event: WriteFile completed. + if (!complete_write()) { + if (auto r = handle_disconnect()) { + return *r; + } + continue; + } + } + + // After any wake-up, try to issue a fresh read (if connected and not already pending) and + // pump as many writes as possible. + if (m_connected.load(std::memory_order_relaxed) && !m_read_pending) { + if (!start_read()) { + if (auto r = handle_disconnect()) { + return *r; + } + continue; + } + } + + if (m_connected.load(std::memory_order_relaxed) && !pump_writes()) { + if (auto r = handle_disconnect()) { + return *r; + } + continue; + } + } +} + +std::optional PipeEndpoint::handle_disconnect() { + disconnect_and_reset(); + if (!should_reconnect_on_disconnect()) { + return true; // Graceful peer-initiated close. + } + if (!start_connect()) { + return false; + } + return std::nullopt; +} + +bool PipeEndpoint::start_read() { + if (m_input_buf_used >= m_input_buf.size()) { + // Buffer is full but no complete message could be parsed -- impossible if MAX_MESSAGE_SIZE + // is honored, so this indicates a protocol violation. Drop the connection. + warnlog(m_logger, "input buffer full ({} bytes) with no parsable message; dropping connection", + m_input_buf_used); + return false; + } + DWORD read_size = 0; + BOOL ok = ReadFile(m_pipe, m_input_buf.data() + m_input_buf_used, + static_cast(m_input_buf.size() - m_input_buf_used), &read_size, &m_olr); + if (ok) { + // Synchronous completion. The kernel may also have signaled m_io_event on its own; if so, + // the next WFMO will wake on it but find `!m_read_pending` and just fall through to + // re-entering start_read(). Either way we must wake the loop ourselves so + // that start_read() runs again to drain any further data. + m_input_buf_used += read_size; + if (!handle_input()) { + return false; + } + SetEvent(m_wake_event); + return true; + } + DWORD err = GetLastError(); + if (err == ERROR_IO_PENDING) { + m_read_pending = true; + return true; + } + if (err == ERROR_BROKEN_PIPE || err == ERROR_PIPE_NOT_CONNECTED || err == ERROR_NO_DATA) { + infolog(m_logger, "ReadFile: peer disconnected ({}: {})", err, ag::sys::strerror(err)); + return false; + } + warnlog(m_logger, "ReadFile: {} ({})", err, ag::sys::strerror(err)); + return false; +} + +bool PipeEndpoint::complete_read() { + DWORD read_size = 0; + if (!GetOverlappedResult(m_pipe, &m_olr, &read_size, FALSE)) { + DWORD err = GetLastError(); + if (err == ERROR_BROKEN_PIPE || err == ERROR_PIPE_NOT_CONNECTED || err == ERROR_OPERATION_ABORTED) { + infolog(m_logger, "GetOverlappedResult(read): peer disconnected ({}: {})", err, ag::sys::strerror(err)); + return false; + } + warnlog(m_logger, "GetOverlappedResult(read): {} ({})", err, ag::sys::strerror(err)); + return false; + } + ResetEvent(m_io_event); + m_read_pending = false; + if (read_size == 0) { + infolog(m_logger, "ReadFile: EOF, peer disconnected"); + return false; + } + m_input_buf_used += read_size; + return handle_input(); +} + +bool PipeEndpoint::handle_input() { + for (;;) { + ag::wire_utils::Reader r{{m_input_buf.data(), m_input_buf_used}}; + auto what = r.get_u32(); + auto size = r.get_u32(); + if (!what.has_value() || !size.has_value()) { + return true; // Need more bytes for the header. + } + if (*size > MAX_MESSAGE_SIZE) { + warnlog(m_logger, "incoming message size {} exceeds MAX_MESSAGE_SIZE ({}); dropping connection", *size, + MAX_MESSAGE_SIZE); + return false; + } + auto data = r.get_bytes(*size); + if (!data.has_value()) { + return true; // Need more bytes for the payload. + } + m_handler(static_cast(*what), *data); + ag::Uint8View remaining = r.get_buffer(); + std::memmove(m_input_buf.data(), remaining.data(), remaining.size()); + m_input_buf_used = remaining.size(); + } +} + +bool PipeEndpoint::pump_writes() { + while (!m_write_pending) { + if (!m_inflight_write.has_value()) { + std::scoped_lock l{m_pending_writes_lock}; + if (m_pending_writes.empty()) { + return true; + } + m_inflight_write.emplace(std::move(m_pending_writes.front())); + m_pending_writes.pop_front(); + } + + PendingWrite &w = *m_inflight_write; + DWORD written = 0; + BOOL ok = WriteFile( + m_pipe, w.data.data() + w.written, static_cast(w.data.size() - w.written), &written, &m_olw); + if (ok) { + // Synchronous completion. The kernel may have signaled m_write_event (the docs are + // inconsistent), so reset it here to avoid a spurious wake on the next WFMO. + // For async (ERROR_IO_PENDING), the kernel resets the event itself when queueing. + ResetEvent(m_write_event); + w.written += written; + if (w.written == w.data.size()) { + m_inflight_write.reset(); + } + continue; + } + DWORD err = GetLastError(); + if (err == ERROR_IO_PENDING) { + m_write_pending = true; + return true; + } + if (err == ERROR_BROKEN_PIPE || err == ERROR_PIPE_NOT_CONNECTED || err == ERROR_NO_DATA) { + infolog(m_logger, "WriteFile: peer disconnected ({}: {})", err, ag::sys::strerror(err)); + return false; + } + warnlog(m_logger, "WriteFile: {} ({})", err, ag::sys::strerror(err)); + return false; + } + return true; +} + +bool PipeEndpoint::complete_write() { + DWORD written = 0; + if (!GetOverlappedResult(m_pipe, &m_olw, &written, FALSE)) { + DWORD err = GetLastError(); + if (err == ERROR_BROKEN_PIPE || err == ERROR_PIPE_NOT_CONNECTED || err == ERROR_OPERATION_ABORTED) { + infolog(m_logger, "GetOverlappedResult(write): peer disconnected ({}: {})", err, ag::sys::strerror(err)); + return false; + } + warnlog(m_logger, "GetOverlappedResult(write): {} ({})", err, ag::sys::strerror(err)); + return false; + } + // Consume the kernel-set completion signal so WFMO doesn't keep firing on m_write_event. + ResetEvent(m_write_event); + m_write_pending = false; + if (m_inflight_write.has_value()) { + PendingWrite &w = *m_inflight_write; + w.written += written; + if (w.written == w.data.size()) { + m_inflight_write.reset(); + } + } + return true; +} + +void PipeEndpoint::disconnect_and_reset() { + // Mark disconnected BEFORE taking the lock, so that any send() that subsequently acquires the + // lock observes `false` and bails out. Any send() already holding (or already past) the lock + // happens-before our lock acquisition below, so its push will be cleared by the clear() call. + m_connected.store(false, std::memory_order_relaxed); + { + std::scoped_lock l{m_pending_writes_lock}; + m_pending_writes.clear(); + } + if (m_pipe != INVALID_HANDLE_VALUE) { + CancelIoEx(m_pipe, nullptr); + if (m_read_pending) { + DWORD ignored = 0; + GetOverlappedResult(m_pipe, &m_olr, &ignored, TRUE); + } + if (m_write_pending) { + DWORD ignored = 0; + GetOverlappedResult(m_pipe, &m_olw, &ignored, TRUE); + } + } + // Safe to drop the in-flight buffer now: the kernel has acknowledged the cancel. + m_inflight_write.reset(); + teardown_pipe(); + // Both event handles may have been left signaled by GetOverlappedResult(TRUE) above. + ResetEvent(m_io_event); + ResetEvent(m_write_event); + m_read_pending = false; + m_write_pending = false; + m_input_buf_used = 0; +} + +// --------------------------------------------------------------------------- +// PipeServer +// --------------------------------------------------------------------------- + +SecurityDescriptorPtr PipeServer::for_authenticated_users() { + // SDDL DACL: + // (A;;GA;;;SY) - SYSTEM full + // (A;;GA;;;BA) - BUILTIN\Administrators full + // (A;;GRGW;;;AU) - NT AUTHORITY\Authenticated Users read+write + static constexpr wchar_t SDDL[] = L"D:(A;;GA;;;SY)(A;;GA;;;BA)(A;;GRGW;;;AU)"; + PSECURITY_DESCRIPTOR sd = nullptr; + if (!ConvertStringSecurityDescriptorToSecurityDescriptorW(SDDL, SDDL_REVISION_1, &sd, nullptr)) { + DWORD err = GetLastError(); + errlog(g_server_logger, "ConvertStringSecurityDescriptorToSecurityDescriptorW: {} ({})", err, + ag::sys::strerror(err)); + return {}; + } + return SecurityDescriptorPtr{static_cast(sd)}; +} + +PipeServer::PipeServer( + const wchar_t *pipe_name, HANDLE stop_event, Handler handler, SECURITY_DESCRIPTOR *security_descriptor) + : PipeEndpoint{stop_event, std::move(handler), g_server_logger} { + m_pipe = create_pipe(pipe_name, security_descriptor); +} + +PipeServer::~PipeServer() { + if (m_pipe != INVALID_HANDLE_VALUE) { + cancel_pending_io(); + DisconnectNamedPipe(m_pipe); + CloseHandle(m_pipe); + m_pipe = INVALID_HANDLE_VALUE; + } +} + +HANDLE PipeServer::create_pipe(const wchar_t *pipe_name, SECURITY_DESCRIPTOR *security_descriptor) { + SECURITY_ATTRIBUTES sa{}; + sa.nLength = sizeof(sa); + sa.lpSecurityDescriptor = security_descriptor; + sa.bInheritHandle = FALSE; + + HANDLE h = CreateNamedPipeW(pipe_name, PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED, + PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_WAIT, + 1, // single instance + PIPE_BUFFER_SIZE, PIPE_BUFFER_SIZE, 0, security_descriptor != nullptr ? &sa : nullptr); + if (h == INVALID_HANDLE_VALUE) { + errlog(g_server_logger, "CreateNamedPipeW: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + } + return h; +} + +bool PipeServer::start_connect() { + prepare_for_connect(); + if (m_pipe == INVALID_HANDLE_VALUE) { + // create_pipe() failed in the constructor. + return false; + } + + if (ConnectNamedPipe(m_pipe, &m_olr)) { + // Synchronous success (very rare for overlapped pipes). The OVERLAPPED was not really + // used by the kernel in this case, so do not call finalize_connect (which would call + // GetOverlappedResult on it). Mark connected directly and kick the loop. + ResetEvent(m_io_event); // Defensive: kernel may have signaled on sync completion. + m_connected.store(true, std::memory_order_relaxed); + SetEvent(m_wake_event); + infolog(m_logger, "client connected (sync)"); + return true; + } + DWORD err = GetLastError(); + if (err == ERROR_PIPE_CONNECTED) { + // A client connected between CreateNamedPipe and ConnectNamedPipe. No overlapped op was + // submitted; mark connected directly. + m_connected.store(true, std::memory_order_relaxed); + SetEvent(m_wake_event); + infolog(m_logger, "client connected (already connected)"); + return true; + } + if (err == ERROR_IO_PENDING) { + return true; + } + errlog(m_logger, "ConnectNamedPipe: {} ({})", err, ag::sys::strerror(err)); + return false; +} + +bool PipeServer::finalize_connect() { + DWORD transferred = 0; + if (!GetOverlappedResult(m_pipe, &m_olr, &transferred, FALSE)) { + DWORD err = GetLastError(); + warnlog(m_logger, "GetOverlappedResult(connect): {} ({})", err, ag::sys::strerror(err)); + return false; + } + ResetEvent(m_io_event); + m_connected.store(true, std::memory_order_relaxed); + infolog(m_logger, "client connected"); + return true; +} + +void PipeServer::teardown_pipe() { + // Server reuses the same pipe instance across reconnects: just disconnect the current client. + if (m_pipe != INVALID_HANDLE_VALUE) { + DisconnectNamedPipe(m_pipe); + } +} + +// --------------------------------------------------------------------------- +// PipeClient +// --------------------------------------------------------------------------- + +PipeClient::PipeClient( + const wchar_t *pipe_name, HANDLE stop_event, Handler handler, std::chrono::milliseconds connect_timeout) + : PipeEndpoint{stop_event, std::move(handler), g_client_logger} + , m_pipe_name{pipe_name} + , m_connect_timeout{connect_timeout.count() <= 0 ? DEFAULT_CONNECT_TIMEOUT : connect_timeout} + , m_connected_or_failed_event{CreateEventW(nullptr, TRUE, FALSE, nullptr)} { +} + +PipeClient::~PipeClient() { + if (m_pipe != INVALID_HANDLE_VALUE) { + cancel_pending_io(); + CloseHandle(m_pipe); + m_pipe = INVALID_HANDLE_VALUE; + } + if (m_connected_or_failed_event != nullptr) { + CloseHandle(m_connected_or_failed_event); + } +} + +bool PipeClient::wait_connected() { + if (m_connected_or_failed_event == nullptr) { + return false; + } + HANDLE events[] = {m_connected_or_failed_event, stop_event()}; + intmax_t timeout_ms = INFINITE; + if (m_connect_timeout.count() < 0) { + timeout_ms = 0; + } else if (m_connect_timeout.count() < INFINITE) { + timeout_ms = m_connect_timeout.count(); + } + DWORD r = WaitForMultipleObjects( + static_cast(std::size(events)), events, FALSE, static_cast(timeout_ms)); + if (r != WAIT_OBJECT_0) { + // Stop event won, or timed out, or wait failed. + return false; + } + // The event is signaled both on successful connect and on fatal start failure; the atomic + // disambiguates. + return m_connected.load(std::memory_order_relaxed); +} + +bool PipeClient::start_connect() { + prepare_for_connect(); + ResetEvent(m_connected_or_failed_event); + + // The server is single-instance: when the previous client disconnects there is a brief + // window between the kernel observing the broken pipe and the server having both + // DisconnectNamedPipe()'d the old client AND posted a fresh ConnectNamedPipe() for the new + // one. During that window CreateFileW returns ERROR_PIPE_BUSY (instance still bound to the + // previous client) or ERROR_FILE_NOT_FOUND (no listening instance yet). Retry with a bounded + // total deadline, in short stop-event-interruptible slices. + constexpr auto DEFAULT_SLICE = ag::Millis{10}; + // Clamp the slice so it never overshoots the configured connect timeout. + auto slice = std::max(ag::Millis{1}, std::min(DEFAULT_SLICE, m_connect_timeout)); + DWORD slice_ms = static_cast(slice.count()); + auto deadline = std::chrono::steady_clock::now() + m_connect_timeout; + + for (;;) { + // CreateFileW is synchronous; FILE_FLAG_OVERLAPPED affects only subsequent IO on the handle. + m_pipe = CreateFileW(m_pipe_name.c_str(), GENERIC_READ | GENERIC_WRITE, 0, nullptr, OPEN_EXISTING, + FILE_FLAG_OVERLAPPED, nullptr); + if (m_pipe != INVALID_HANDLE_VALUE) { + break; + } + DWORD err = GetLastError(); + if (err != ERROR_PIPE_BUSY && err != ERROR_FILE_NOT_FOUND) { + errlog(m_logger, "CreateFileW: {} ({})", err, ag::sys::strerror(err)); + SetEvent(m_connected_or_failed_event); + return false; + } + if (std::chrono::steady_clock::now() >= deadline) { + errlog(m_logger, "CreateFileW: timed out waiting for server (last err {}: {})", err, + ag::sys::strerror(err)); + SetEvent(m_connected_or_failed_event); + return false; + } + // Interruptible nap before retry. + if (WaitForSingleObject(stop_event(), slice_ms) == WAIT_OBJECT_0) { + SetEvent(m_connected_or_failed_event); + return false; + } + if (err == ERROR_PIPE_BUSY) { + // Wait for an instance to become available; ignore the result and just retry CreateFileW. + // WaitNamedPipeW is uninterruptible. + WaitNamedPipeW(m_pipe_name.c_str(), slice_ms); + } + } + + // Defensive: ensure byte-mode read semantics regardless of the server's configuration. + DWORD mode = PIPE_READMODE_BYTE; + if (!SetNamedPipeHandleState(m_pipe, &mode, nullptr, nullptr)) { + DWORD err = GetLastError(); + warnlog(m_logger, "SetNamedPipeHandleState: {} ({})", err, ag::sys::strerror(err)); + } + + m_connected.store(true, std::memory_order_relaxed); + SetEvent(m_connected_or_failed_event); + SetEvent(m_wake_event); + infolog(m_logger, "connected to server"); + return true; +} + +void PipeClient::teardown_pipe() { + // Client uses a single-shot handle: close it and let the loop exit + // (should_reconnect_on_disconnect() returns false). + if (m_pipe != INVALID_HANDLE_VALUE) { + CloseHandle(m_pipe); + m_pipe = INVALID_HANDLE_VALUE; + } +} + +} // namespace ag::vpn_easy diff --git a/platform/windows/src/vpn_easy_pipe.h b/platform/windows/src/vpn_easy_pipe.h new file mode 100644 index 0000000..3edbc19 --- /dev/null +++ b/platform/windows/src/vpn_easy_pipe.h @@ -0,0 +1,291 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#define WIN32_LEAN_AND_MEAN +#include + +#include "common/defs.h" +#include "common/logger.h" +#include "vpn/vpn_easy_service.h" + +namespace ag::vpn_easy { + +namespace detail { +/** Free a security descriptor returned by an SDDL helper. Used as the deleter for `SecurityDescriptorPtr`. */ +void free_security_descriptor(SECURITY_DESCRIPTOR *sd); +} // namespace detail + +/** Owning pointer for a security descriptor allocated via `LocalAlloc` (e.g. by SDDL helpers). */ +using SecurityDescriptorPtr = ag::UniquePtr; + +/** + * Asynchronous named-pipe endpoint base class for the VPN easy-service control protocol. + * + * Holds all framing, queueing, overlapped-IO and event-loop machinery shared by both ends of the + * pipe. Subclasses implement only: + * - how a pipe handle is acquired (`start_connect()`), + * - how a posted overlapped connect completion is reaped (`finalize_connect()`), + * - how the pipe handle is torn down on disconnect (`teardown_pipe()`), + * - and the disconnect policy (`should_reconnect_on_disconnect()`). + * + * After construction, the caller drives the IO loop by calling `loop()` on a worker thread; the + * loop returns once the externally-provided `stop_event` becomes signaled, the peer disconnects + * (only for endpoints whose `should_reconnect_on_disconnect()` returns `false`), or a fatal + * error occurs. + * + * `send()` is thread-safe and may be called from any thread, including from inside the receive + * callback. If the endpoint is not currently connected the message is dropped. If the queue + * overflows, the oldest pending messages are dropped. + */ +class PipeEndpoint { +public: + /** + * Callback invoked from `loop()`'s thread for every fully-received message. + * The `data` view is valid only for the duration of the call. + */ + using Handler = std::function; + + virtual ~PipeEndpoint(); + + PipeEndpoint(const PipeEndpoint &) = delete; + PipeEndpoint &operator=(const PipeEndpoint &) = delete; + PipeEndpoint(PipeEndpoint &&) = delete; + PipeEndpoint &operator=(PipeEndpoint &&) = delete; + + /** + * Run the asynchronous IO loop until `stop_event` is signaled, the peer disconnects (for + * non-reconnecting endpoints), or a fatal error occurs. + * @return `true` if stopped via the stop event or a graceful peer disconnect, `false` on + * fatal error. + */ + bool loop(); + + /** + * Enqueue a message to be sent to the currently-connected peer. Thread-safe. + * Drop the message if no peer is connected. If the internal queue is full, drop the oldest + * pending messages. + */ + void send(VpnEasyServiceMessageType what, ag::Uint8View data); + +protected: + /** + * @param stop_event Externally-owned manual-reset event. When signaled, `loop()` returns `true`. + * Ownership is NOT transferred. + * @param handler Message receive callback. Must be non-null. + * @param logger Logger to use for diagnostic messages. + */ + PipeEndpoint(HANDLE stop_event, Handler handler, ag::Logger &logger); + + /** + * Subclass hook: acquire/initiate the pipe connection. On success, either: + * - mark `m_connected` true synchronously and `SetEvent(m_wake_event)`, or + * - leave `m_connected` false and post an overlapped op on `m_olr`/`m_io_event`; the loop + * will then call `finalize_connect()` when `m_io_event` fires. + * Implementations MUST call `prepare_for_connect()` first to clear per-connection state. + * @return `true` on success (sync or pending), `false` on fatal error (`loop()` returns false). + */ + virtual bool start_connect() = 0; + + /** + * Subclass hook: reap the completion of an overlapped connect posted by `start_connect()`, + * called when `m_io_event` fires while not yet connected. Default returns `false` (no + * overlapped connect was posted; an `m_io_event` wake here is unexpected). + */ + virtual bool finalize_connect() { + return false; + } + + /** + * Subclass hook: tear down the pipe handle after pending IO has been cancelled and drained. + * Called from `disconnect_and_reset()`. Implementations typically call `DisconnectNamedPipe` + * (server -- handle is reused) or `CloseHandle` and reset `m_pipe` to `INVALID_HANDLE_VALUE` + * (client -- handle is single-use). + */ + virtual void teardown_pipe() = 0; + + /** + * Subclass hook: report the disconnect policy. Returning `true` (the default) causes the + * loop to invoke `start_connect()` again after each disconnect; returning `false` causes + * `loop()` to return `true` after the first disconnect. `PipeClient` overrides to return + * `false`. + */ + virtual bool should_reconnect_on_disconnect() const { + return true; + } + + /** + * Reset per-connection state and event handles to their initial values. Subclasses MUST call + * this at the top of `start_connect()`. + */ + void prepare_for_connect(); + + /** + * Cancel any pending overlapped IO on `m_pipe` and synchronously wait for the cancellations + * to land. Used by subclass destructors to safely tear down a still-active endpoint. + * Caller must ensure `m_pipe != INVALID_HANDLE_VALUE`. + */ + void cancel_pending_io(); + + // Pipe handle. Owned by the subclass: the server creates it once in its constructor and + // re-uses it across `DisconnectNamedPipe`; the client creates it in `start_connect()` and + // destroys it in `teardown_pipe()`. + HANDLE m_pipe = INVALID_HANDLE_VALUE; + + // Overlapped state, used by both subclass connect logic (`m_olr`) and the shared read/write + // pipeline. Subclasses must not touch these except as documented above. + OVERLAPPED m_olr{}; ///< For overlapped connect (subclass) / `ReadFile` (base). + OVERLAPPED m_olw{}; ///< For `WriteFile` (base only). + HANDLE m_io_event = nullptr; ///< Signaled on overlapped connect or read completion. + HANDLE m_write_event = nullptr; ///< Signaled on write completion. + HANDLE m_wake_event = nullptr; ///< Set by `send()` (and by sync-connect paths) to wake the loop. + + // Connection state. Written by the loop thread; read by `send()` (any thread). + std::atomic m_connected{false}; + + // Logger. Bound to a static ag::Logger owned by the subclass's translation unit. + ag::Logger &m_logger; + + // Externally-owned stop event. Exposed to subclasses so that long-running connect retries + // (e.g. PipeClient::start_connect) can be interrupted promptly when the loop is asked to stop. + HANDLE stop_event() const { + return m_stop_event; + } + +private: + static constexpr size_t MAX_PENDING_WRITES = 100; + // Maximum payload size of a single message. Messages larger than this are rejected and the + // connection is dropped (protocol violation / DoS protection). + static constexpr size_t MAX_MESSAGE_SIZE = 16 * 1024; + // Receive buffer size: large enough to hold one full message plus its 8-byte header. + static constexpr size_t INPUT_BUF_SIZE = MAX_MESSAGE_SIZE + 2 * sizeof(uint32_t); + + struct PendingWrite { + std::vector data; + size_t written; + }; + + HANDLE m_stop_event = nullptr; + Handler m_handler; + bool m_read_pending = false; + bool m_write_pending = false; + + std::vector m_input_buf; + size_t m_input_buf_used = 0; + + std::mutex m_pending_writes_lock; + std::list m_pending_writes; // Guarded by m_pending_writes_lock. + + // Owned exclusively by the loop thread: the message currently being written (possibly with an + // overlapped WriteFile in flight). Moved here from m_pending_writes under the lock and kept + // alive until the write fully completes, so that send() can never free the in-flight buffer. + std::optional m_inflight_write; + + static std::vector compose_message(VpnEasyServiceMessageType what, ag::Uint8View data); + + // Returns nullopt to continue the loop; otherwise the value `loop()` should return. + std::optional handle_disconnect(); + + bool start_read(); + bool complete_read(); + bool handle_input(); + bool pump_writes(); + bool complete_write(); + void disconnect_and_reset(); +}; + +/** + * Server endpoint: owns a single byte-stream named pipe instance, accepts one client at a time, + * and transparently reconnects (waits for a new client) when the current client disconnects. + */ +class PipeServer : public PipeEndpoint { +public: + /** + * Create a security descriptor that grants GENERIC_READ | GENERIC_WRITE to + * NT AUTHORITY\Authenticated Users, and full control to SYSTEM and BUILTIN\Administrators. + * Suitable for a service-side IPC named pipe that must be reachable from any locally + * authenticated user session. Returns null on failure. + */ + static SecurityDescriptorPtr for_authenticated_users(); + + /** + * @param pipe_name Full named-pipe name (e.g. `\\.\pipe\my_pipe`). + * @param stop_event See `PipeEndpoint`. + * @param handler See `PipeEndpoint`. + * @param security_descriptor Optional security descriptor for the pipe. If null (the default), + * the system default DACL is used. The pointer is consumed + * synchronously by the constructor; the caller may destroy the + * descriptor immediately after construction returns. + */ + PipeServer(const wchar_t *pipe_name, HANDLE stop_event, Handler handler, + SECURITY_DESCRIPTOR *security_descriptor = nullptr); + ~PipeServer() override; + +protected: + bool start_connect() override; + bool finalize_connect() override; + void teardown_pipe() override; + // Default `Reconnect` policy is exactly what the server wants. + +private: + static constexpr DWORD PIPE_BUFFER_SIZE = 64 * 1024; + + static HANDLE create_pipe(const wchar_t *pipe_name, SECURITY_DESCRIPTOR *security_descriptor); +}; + +/** + * Client endpoint: opens a connection to an existing named-pipe server. The IO loop exits on + * peer disconnect (returning `true` from `loop()`); the caller should construct a new `PipeClient` to reconnect. + */ +class PipeClient : public PipeEndpoint { +public: + /** Default total timeout used by `start_connect()` when the constructor is passed `0`. */ + static constexpr std::chrono::milliseconds DEFAULT_CONNECT_TIMEOUT{500}; + + /** + * @param pipe_name Full named-pipe name (e.g. `\\.\pipe\my_pipe`). + * @param stop_event See `PipeEndpoint`. + * @param handler See `PipeEndpoint`. + * @param connect_timeout Maximum total time `start_connect()` will spend retrying + * `CreateFileW` while the server is briefly unavailable + * (e.g. mid-reconnect of a previous client). A value of `0` selects + * `DEFAULT_CONNECT_TIMEOUT`. Negative values are treated as `0`. + */ + PipeClient(const wchar_t *pipe_name, HANDLE stop_event, Handler handler, + std::chrono::milliseconds connect_timeout = std::chrono::milliseconds{0}); + ~PipeClient() override; + + /** + * Block until the client has successfully connected to the server (i.e. `start_connect()`, + * driven by `loop()` on another thread, has succeeded), or the externally-supplied stop event + * is signaled -- whichever happens first. Thread-safe; may be called + * from any thread, including before `loop()` has started. + * @return `true` if the connection is established within the timeout, `false` otherwise + * (timeout, stop event signaled, or fatal connect failure). + */ + bool wait_connected(); + +protected: + bool start_connect() override; + void teardown_pipe() override; + bool should_reconnect_on_disconnect() const override { + return false; + } + +private: + std::wstring m_pipe_name; + std::chrono::milliseconds m_connect_timeout; + // Manual-reset event signaled by start_connect() on success and by loop() on fatal start + // failure. Used by wait_connected() so callers can synchronize without polling. Reset at the + // top of every start_connect() attempt so that a fresh PipeClient instance starts clean. + HANDLE m_connected_or_failed_event = nullptr; +}; + +} // namespace ag::vpn_easy diff --git a/platform/windows/src/vpn_easy_service.cpp b/platform/windows/src/vpn_easy_service.cpp new file mode 100644 index 0000000..5153d59 --- /dev/null +++ b/platform/windows/src/vpn_easy_service.cpp @@ -0,0 +1,159 @@ +#include "vpn/vpn_easy_service.h" +#include "vpn/vpn_easy.h" + +#include +#include +#include +#include + +#include "common/defs.h" +#include "common/logger.h" + +#define WIN32_LEAN_AND_MEAN +#include + +#include "common/system_error.h" +#include "vpn/trusttunnel/connection_info.h" +#include "vpn/vpn.h" +#include "vpn_easy_pipe.h" + +using ag::vpn_easy::PipeServer; + +static ag::Logger g_logger{"VPN_EASY_SERVICE"}; + +static std::wstring g_pipe_name; +static SERVICE_STATUS_HANDLE g_status_handle; +static HANDLE g_shutdown_event; +static vpn_easy_t *g_vpn; + +/// Send a `VPN_EASY_SVC_MSG_STATE_CHANGED` message with the given state value. +static void send_state(PipeServer &server, int32_t state) { + uint32_t net_state = htonl(static_cast(state)); + server.send(VPN_EASY_SVC_MSG_STATE_CHANGED, {reinterpret_cast(&net_state), sizeof(net_state)}); +} + +/// Handle an incoming pipe message from a client. +static void pipe_handler(PipeServer &server, VpnEasyServiceMessageType what, ag::Uint8View data) { + switch (what) { + case VPN_EASY_SVC_MSG_START: { + if (g_vpn != nullptr) { + infolog(g_logger, "VPN already running, stopping before restart"); + vpn_easy_stop_ex(g_vpn); + g_vpn = nullptr; + } + std::string toml_config(reinterpret_cast(data.data()), data.size()); + infolog(g_logger, "Starting VPN client"); + g_vpn = vpn_easy_start_ex( + toml_config.c_str(), + [](void *arg, int state) { + send_state(*static_cast(arg), state); + }, + &server, + [](void *arg, void *connection_info) { + std::string json = + ag::ConnectionInfo::to_json(static_cast(connection_info)); + static_cast(arg)->send(VPN_EASY_SVC_MSG_CONNECTION_INFO, + {reinterpret_cast(json.data()), json.size()}); + }, + &server); + if (g_vpn == nullptr) { + warnlog(g_logger, "vpn_easy_start_ex failed"); + send_state(server, ag::VPN_SS_DISCONNECTED); + } + break; + } + case VPN_EASY_SVC_MSG_STOP: { + if (g_vpn == nullptr) { + infolog(g_logger, "VPN already stopped, ignoring STOP"); + return; + } + infolog(g_logger, "Stopping VPN client"); + vpn_easy_stop_ex(g_vpn); + g_vpn = nullptr; + break; + } + case VPN_EASY_SVC_MSG_STATE_CHANGED: + case VPN_EASY_SVC_MSG_CONNECTION_INFO: + warnlog(g_logger, "Ignoring server-to-client message type: {}", static_cast(what)); + break; + default: + warnlog(g_logger, "Unknown message type: {}", static_cast(what)); + break; + } +} + +static void service_set_status(DWORD current_state) { + SERVICE_STATUS status{ + .dwServiceType = SERVICE_WIN32_OWN_PROCESS, + .dwCurrentState = current_state, + .dwControlsAccepted = SERVICE_ACCEPT_STOP | SERVICE_ACCEPT_SHUTDOWN, + }; + SetServiceStatus(g_status_handle, &status); +} + +static void WINAPI service_ctrl_handler(DWORD control) { + switch (control) { + case SERVICE_CONTROL_STOP: + case SERVICE_CONTROL_SHUTDOWN: + SetEvent(g_shutdown_event); + break; + default: + break; + } +} + +static void WINAPI service_main(DWORD /*argc*/, LPWSTR * /*argv*/) { + g_status_handle = RegisterServiceCtrlHandlerW(L"", service_ctrl_handler); + g_shutdown_event = CreateEventW(nullptr, TRUE, FALSE, nullptr); + + service_set_status(SERVICE_START_PENDING); + + PipeServer server{g_pipe_name.c_str(), g_shutdown_event, + [&server](VpnEasyServiceMessageType what, ag::Uint8View data) { + pipe_handler(server, what, data); + }, + PipeServer::for_authenticated_users().get()}; + + service_set_status(SERVICE_RUNNING); + server.loop(); + + if (g_vpn != nullptr) { + infolog(g_logger, "Shutting down: stopping VPN client"); + vpn_easy_stop_ex(g_vpn); + g_vpn = nullptr; + } + + service_set_status(SERVICE_STOPPED); +} + +int wmain(int argc, wchar_t **argv) { + if (argc != 3) { + return 1; + } + + ag::UniquePtr logfile{_wfsopen(argv[1], L"w", _SH_DENYWR)}; + if (logfile) { + setvbuf(logfile.get(), nullptr, _IONBF, 0); + ag::Logger::set_callback(ag::Logger::LogToFile{logfile.get()}); + } + ag::Logger::set_log_level(ag::LOG_LEVEL_INFO); + + g_pipe_name = argv[2]; + + wchar_t svc_name[] = L""; + SERVICE_TABLE_ENTRYW start_table[] = { + {svc_name, service_main}, + {nullptr, nullptr}, + }; + +#ifndef AG_DEBUGGING_VPN_EASY_SERVICE + if (!StartServiceCtrlDispatcherW(start_table)) { + errlog(g_logger, "StartServiceCtrlDispatcherW: {} ({})", GetLastError(), ag::sys::strerror(GetLastError())); + return 3; + } +#else + service_main(0, nullptr); +#endif + + return 0; +} diff --git a/platform/windows/test/vpn_easy_pipe_test.cpp b/platform/windows/test/vpn_easy_pipe_test.cpp new file mode 100644 index 0000000..7531bf6 --- /dev/null +++ b/platform/windows/test/vpn_easy_pipe_test.cpp @@ -0,0 +1,1134 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define WIN32_LEAN_AND_MEAN +#include + +#include "common/logger.h" +#include "vpn/internal/wire_utils.h" + +#include "vpn_easy_pipe.h" + +using namespace ag::vpn_easy; +using namespace std::chrono_literals; + +namespace { + +constexpr auto TEST_TIMEOUT = std::chrono::seconds(5); +constexpr auto JOIN_TIMEOUT = std::chrono::seconds(5); +// Mirrors PipeEndpoint::MAX_MESSAGE_SIZE (private). If that changes, update here. +constexpr size_t MAX_MESSAGE_SIZE = 16 * 1024; +constexpr size_t WIRE_HEADER_SIZE = 8; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +std::wstring unique_pipe_name() { + static std::atomic counter{0}; + auto pid = static_cast(GetCurrentProcessId()); + auto n = counter.fetch_add(1, std::memory_order_relaxed); + return L"\\\\.\\pipe\\agvpn_pipe_test_" + std::to_wstring(pid) + L"_" + std::to_wstring(n); +} + +std::vector make_framed(uint32_t what, std::span payload) { + std::vector ret(WIRE_HEADER_SIZE + payload.size()); + ag::wire_utils::Writer w{{ret.data(), ret.size()}}; + w.put_u32(what); + w.put_u32(static_cast(payload.size())); + w.put_data({payload.data(), payload.size()}); + return ret; +} + +// Build a header whose advertised length differs from the actual payload size; used by the +// oversized-message test. +std::vector make_framed_with_advertised_len(uint32_t what, uint32_t advertised_len) { + std::vector ret(WIRE_HEADER_SIZE); + ag::wire_utils::Writer w{{ret.data(), ret.size()}}; + w.put_u32(what); + w.put_u32(advertised_len); + return ret; +} + +struct ReceivedMessage { + VpnEasyServiceMessageType what; + std::vector payload; +}; + +// Thread-safe receiver of messages delivered via PipeEndpoint::Handler. +class MessageCollector { +public: + PipeEndpoint::Handler make_handler() { + return [this](VpnEasyServiceMessageType what, ag::Uint8View data) { + std::scoped_lock l{m_lock}; + m_messages.push_back({what, std::vector(data.begin(), data.end())}); + m_cv.notify_all(); + }; + } + + bool wait_for_count(size_t n, std::chrono::milliseconds timeout) { + std::unique_lock l{m_lock}; + return m_cv.wait_for(l, timeout, [&] { + return m_messages.size() >= n; + }); + } + + std::vector snapshot() { + std::scoped_lock l{m_lock}; + return m_messages; + } + + size_t count() { + std::scoped_lock l{m_lock}; + return m_messages.size(); + } + +private: + std::mutex m_lock; + std::condition_variable m_cv; + std::vector m_messages; +}; + +// RAII Win32 HANDLE wrapper. +class Handle { +public: + Handle() = default; + explicit Handle(HANDLE h) + : m_h(h) { + } + Handle(const Handle &) = delete; + Handle &operator=(const Handle &) = delete; + Handle(Handle &&o) noexcept + : m_h(std::exchange(o.m_h, INVALID_HANDLE_VALUE)) { + } + Handle &operator=(Handle &&o) noexcept { + reset(); + m_h = std::exchange(o.m_h, INVALID_HANDLE_VALUE); + return *this; + } + ~Handle() { + reset(); + } + + void reset() { + if (m_h != nullptr && m_h != INVALID_HANDLE_VALUE) { + CloseHandle(m_h); + } + m_h = INVALID_HANDLE_VALUE; + } + + HANDLE get() const { + return m_h; + } + explicit operator bool() const { + return m_h != nullptr && m_h != INVALID_HANDLE_VALUE; + } + +private: + HANDLE m_h = INVALID_HANDLE_VALUE; +}; + +// Open a raw (overlapped) client connection to a server pipe, retrying until `timeout` elapses. +Handle open_raw_client(const std::wstring &name, std::chrono::milliseconds timeout = TEST_TIMEOUT) { + auto deadline = std::chrono::steady_clock::now() + timeout; + for (;;) { + HANDLE h = CreateFileW( + name.c_str(), GENERIC_READ | GENERIC_WRITE, 0, nullptr, OPEN_EXISTING, FILE_FLAG_OVERLAPPED, nullptr); + if (h != INVALID_HANDLE_VALUE) { + return Handle{h}; + } + DWORD err = GetLastError(); + if (err != ERROR_FILE_NOT_FOUND && err != ERROR_PIPE_BUSY) { + return Handle{}; + } + if (std::chrono::steady_clock::now() > deadline) { + return Handle{}; + } + std::this_thread::sleep_for(10ms); + } +} + +// Synchronously write `data` to an overlapped handle. +bool write_all(HANDLE h, std::span data) { + OVERLAPPED ol{}; + Handle ev{CreateEventW(nullptr, TRUE, FALSE, nullptr)}; + ol.hEvent = ev.get(); + DWORD written = 0; + BOOL ok = WriteFile(h, data.data(), static_cast(data.size()), &written, &ol); + if (!ok && GetLastError() == ERROR_IO_PENDING) { + ok = GetOverlappedResult(h, &ol, &written, TRUE); + } + return ok && written == data.size(); +} + +// Read exactly `buf.size()` bytes from an overlapped handle, with a deadline. +bool read_exact(HANDLE h, std::span buf, std::chrono::milliseconds timeout) { + Handle ev{CreateEventW(nullptr, TRUE, FALSE, nullptr)}; + OVERLAPPED ol{}; + ol.hEvent = ev.get(); + size_t got = 0; + auto deadline = std::chrono::steady_clock::now() + timeout; + while (got < buf.size()) { + ResetEvent(ol.hEvent); + DWORD n = 0; + BOOL ok = ReadFile(h, buf.data() + got, static_cast(buf.size() - got), &n, &ol); + if (!ok) { + DWORD err = GetLastError(); + if (err != ERROR_IO_PENDING) { + return false; + } + auto now = std::chrono::steady_clock::now(); + if (now >= deadline) { + CancelIoEx(h, &ol); + GetOverlappedResult(h, &ol, &n, TRUE); + return false; + } + auto wait_ms = std::chrono::duration_cast(deadline - now).count(); + DWORD w = WaitForSingleObject(ol.hEvent, static_cast(wait_ms)); + if (w != WAIT_OBJECT_0) { + CancelIoEx(h, &ol); + GetOverlappedResult(h, &ol, &n, TRUE); + return false; + } + if (!GetOverlappedResult(h, &ol, &n, FALSE)) { + return false; + } + } + if (n == 0) { + return false; // EOF + } + got += n; + } + return true; +} + +bool read_framed_message(HANDLE h, ReceivedMessage &out, std::chrono::milliseconds timeout) { + auto t0 = std::chrono::steady_clock::now(); + uint8_t header[WIRE_HEADER_SIZE]; + if (!read_exact(h, {header, WIRE_HEADER_SIZE}, timeout)) { + return false; + } + ag::wire_utils::Reader r{{header, WIRE_HEADER_SIZE}}; + auto what = r.get_u32(); + auto len = r.get_u32(); + if (!what.has_value() || !len.has_value()) { + return false; + } + out.what = static_cast(*what); + out.payload.assign(*len, 0); + if (*len == 0) { + return true; + } + auto remaining = timeout - (std::chrono::steady_clock::now() - t0); + if (remaining < 0ms) { + remaining = 0ms; + } + return read_exact(h, {out.payload.data(), out.payload.size()}, + std::chrono::duration_cast(remaining)); +} + +// Probe whether the server-side has dropped the connection: any IO failure / EOF within the +// timeout is treated as "peer is gone". +bool wait_for_peer_disconnect(HANDLE h, std::chrono::milliseconds timeout) { + uint8_t byte = 0; + return !read_exact(h, {&byte, 1}, timeout); +} + +// RAII wrapper for a `loop()` invocation on a detached worker thread. The destructor signals +// the bound stop event to ask the loop to exit, then returns immediately without joining; the +// detached thread is left to wind down on its own. +// +// Rationale: the test fixture must remain responsive even if `loop()` itself is buggy and never +// observes the stop event. By detaching, the main thread (and the rest of the test suite) can +// always make progress; an actual `loop()` hang manifests as a test failure (because +// `wait_for(JOIN_TIMEOUT)` returns `timeout`) rather than as a deadlocked test process. +// +// CAVEAT: if `loop()` truly fails to exit, the detached thread will outlive the test's local +// `PipeServer`/`PipeClient`, and continued use of those references inside the loop is undefined +// behavior. A hung `loop()` is a real implementation bug that must be diagnosed and fixed; this +// wrapper just keeps the test harness alive long enough to report it. +class LoopRunner { +public: + template + LoopRunner(HANDLE stop_event, F &&f) + : m_stop_event{stop_event} + , m_state{std::make_shared()} { + auto state = m_state; + std::thread([state, fn = std::forward(f)]() mutable { + bool result = fn(); + std::scoped_lock l{state->lock}; + state->result = result; + state->done = true; + state->cv.notify_all(); + }).detach(); + } + + LoopRunner(const LoopRunner &) = delete; + LoopRunner &operator=(const LoopRunner &) = delete; + + ~LoopRunner() { + if (m_stop_event != nullptr) { + SetEvent(m_stop_event); + } + // Intentionally no join: see class-level comment. + } + + // Wait up to `t` for `loop()` to return. Returns the loop's return value if it completed + // within the timeout, or `std::nullopt` if the timeout was reached. + std::optional wait_for(std::chrono::milliseconds t) { + std::unique_lock l{m_state->lock}; + if (!m_state->cv.wait_for(l, t, [&] { + return m_state->done; + })) { + return std::nullopt; + } + return *m_state->result; + } + +private: + struct State { + std::mutex lock; + std::condition_variable cv; + bool done = false; + std::optional result; + }; + + HANDLE m_stop_event = nullptr; + std::shared_ptr m_state; +}; + +class PipeTest : public testing::Test { +protected: + PipeTest() { + ag::Logger::set_log_level(ag::LOG_LEVEL_DEBUG); + m_stop_event = Handle{CreateEventW(nullptr, TRUE, FALSE, nullptr)}; + m_pipe_name = unique_pipe_name(); + } + + void signal_stop() { + SetEvent(m_stop_event.get()); + } + + Handle m_stop_event; + std::wstring m_pipe_name; +}; + +} // namespace + +// --------------------------------------------------------------------------- +// PipeServer tests +// --------------------------------------------------------------------------- + +TEST_F(PipeTest, ServerLoopFailsImmediatelyOnInvalidPipeName) { + // An empty pipe name causes CreateNamedPipeW to fail; loop() must refuse to start. + MessageCollector collector; + PipeServer server{L"", m_stop_event.get(), collector.make_handler()}; + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + auto loop_result = runner.wait_for(JOIN_TIMEOUT); + ASSERT_TRUE(loop_result); + EXPECT_FALSE(*loop_result); +} + +TEST_F(PipeTest, ServerStopEventCausesGracefulExitWithNoClient) { + MessageCollector collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler()}; + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + // Give the server a moment to post the overlapped ConnectNamedPipe. + std::this_thread::sleep_for(50ms); + signal_stop(); + + auto loop_result = runner.wait_for(JOIN_TIMEOUT); + ASSERT_TRUE(loop_result); + EXPECT_TRUE(*loop_result); + EXPECT_EQ(collector.count(), 0u); +} + +TEST_F(PipeTest, ServerReceivesSingleFramedMessage) { + MessageCollector collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler()}; + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + Handle client = open_raw_client(m_pipe_name); + ASSERT_TRUE(client); + + const std::vector payload = {0xDE, 0xAD, 0xBE, 0xEF}; + auto frame = make_framed(VPN_EASY_SVC_MSG_START, payload); + ASSERT_TRUE(write_all(client.get(), frame)); + + ASSERT_TRUE(collector.wait_for_count(1, TEST_TIMEOUT)); + auto msgs = collector.snapshot(); + ASSERT_EQ(msgs.size(), 1u); + EXPECT_EQ(msgs[0].what, VPN_EASY_SVC_MSG_START); + EXPECT_EQ(msgs[0].payload, payload); + + signal_stop(); + auto loop_result = runner.wait_for(JOIN_TIMEOUT); + ASSERT_TRUE(loop_result); + EXPECT_TRUE(*loop_result); +} + +TEST_F(PipeTest, ServerReceivesMultipleConcatenatedMessages) { + MessageCollector collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler()}; + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + Handle client = open_raw_client(m_pipe_name); + ASSERT_TRUE(client); + + // Three messages back-to-back in a single write. + std::vector combined; + auto append = [&](uint32_t what, const std::vector &p) { + auto f = make_framed(what, p); + combined.insert(combined.end(), f.begin(), f.end()); + }; + append(VPN_EASY_SVC_MSG_START, {}); + append(VPN_EASY_SVC_MSG_STOP, {1, 2, 3}); + append(VPN_EASY_SVC_MSG_STATE_CHANGED, {0xFF, 0xEE}); + ASSERT_TRUE(write_all(client.get(), combined)); + + ASSERT_TRUE(collector.wait_for_count(3, TEST_TIMEOUT)); + auto msgs = collector.snapshot(); + ASSERT_EQ(msgs.size(), 3u); + EXPECT_EQ(msgs[0].what, VPN_EASY_SVC_MSG_START); + EXPECT_TRUE(msgs[0].payload.empty()); + EXPECT_EQ(msgs[1].what, VPN_EASY_SVC_MSG_STOP); + EXPECT_EQ(msgs[1].payload, (std::vector{1, 2, 3})); + EXPECT_EQ(msgs[2].what, VPN_EASY_SVC_MSG_STATE_CHANGED); + EXPECT_EQ(msgs[2].payload, (std::vector{0xFF, 0xEE})); + + signal_stop(); + ASSERT_TRUE(runner.wait_for(JOIN_TIMEOUT)); +} + +TEST_F(PipeTest, ServerReassemblesMessageSplitAcrossWrites) { + MessageCollector collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler()}; + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + Handle client = open_raw_client(m_pipe_name); + ASSERT_TRUE(client); + + std::vector payload(128); + for (size_t i = 0; i < payload.size(); ++i) { + payload[i] = static_cast(i); + } + auto frame = make_framed(VPN_EASY_SVC_MSG_CONNECTION_INFO, payload); + + // Header alone, then half the payload, then the rest. The brief sleeps make the test + // deterministic about reassembly across multiple ReadFile completions. + ASSERT_TRUE(write_all(client.get(), {frame.data(), WIRE_HEADER_SIZE})); + std::this_thread::sleep_for(50ms); + size_t half = payload.size() / 2; + ASSERT_TRUE(write_all(client.get(), {frame.data() + WIRE_HEADER_SIZE, half})); + std::this_thread::sleep_for(50ms); + ASSERT_TRUE(write_all(client.get(), {frame.data() + WIRE_HEADER_SIZE + half, payload.size() - half})); + + ASSERT_TRUE(collector.wait_for_count(1, TEST_TIMEOUT)); + auto msgs = collector.snapshot(); + ASSERT_EQ(msgs.size(), 1u); + EXPECT_EQ(msgs[0].what, VPN_EASY_SVC_MSG_CONNECTION_INFO); + EXPECT_EQ(msgs[0].payload, payload); + + signal_stop(); + ASSERT_TRUE(runner.wait_for(JOIN_TIMEOUT)); +} + +TEST_F(PipeTest, ServerReceivesAllMessagesWhenReadFileCompletesSynchronously) { + // Regression: previously, when ReadFile completed synchronously, start_read() returned + // without arming the next read or waking the loop, so subsequent incoming bytes were never + // observed. To exercise the sync-completion path we use a synchronous (non-overlapped) raw + // client and write many small messages before the server gets a chance to post its first + // overlapped ReadFile, so that the data is already buffered in the kernel and ReadFile + // returns synchronously. + MessageCollector collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler()}; + + // Open the synchronous client BEFORE spawning the loop, so the server's ConnectNamedPipe + // returns ERROR_PIPE_CONNECTED on the very first iteration and the kernel pipe buffer is + // already pre-filled when the first ReadFile is issued. + Handle client{CreateFileW(m_pipe_name.c_str(), GENERIC_READ | GENERIC_WRITE, 0, nullptr, OPEN_EXISTING, + 0 /* not FILE_FLAG_OVERLAPPED: synchronous writes */, nullptr)}; + ASSERT_TRUE(client); + + constexpr int MESSAGE_COUNT = 100; + for (int i = 0; i < MESSAGE_COUNT; ++i) { + uint8_t payload[2] = {static_cast(i >> 8), static_cast(i & 0xFF)}; + auto frame = make_framed(VPN_EASY_SVC_MSG_STATE_CHANGED, payload); + DWORD written = 0; + ASSERT_TRUE(WriteFile(client.get(), frame.data(), static_cast(frame.size()), &written, nullptr)); + ASSERT_EQ(written, frame.size()); + } + + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + ASSERT_TRUE(collector.wait_for_count(MESSAGE_COUNT, TEST_TIMEOUT)); + auto msgs = collector.snapshot(); + ASSERT_EQ(msgs.size(), static_cast(MESSAGE_COUNT)); + for (int i = 0; i < MESSAGE_COUNT; ++i) { + ASSERT_EQ(msgs[i].payload.size(), 2u); + EXPECT_EQ((msgs[i].payload[0] << 8) | msgs[i].payload[1], i) << "at index " << i; + } + + signal_stop(); + ASSERT_TRUE(runner.wait_for(JOIN_TIMEOUT)); +} + +TEST_F(PipeTest, ServerAcceptsMaxSizedMessage) { + MessageCollector collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler()}; + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + Handle client = open_raw_client(m_pipe_name); + ASSERT_TRUE(client); + + std::vector payload(MAX_MESSAGE_SIZE, 0xAB); + auto frame = make_framed(VPN_EASY_SVC_MSG_CONNECTION_INFO, payload); + ASSERT_TRUE(write_all(client.get(), frame)); + + ASSERT_TRUE(collector.wait_for_count(1, TEST_TIMEOUT)); + auto msgs = collector.snapshot(); + ASSERT_EQ(msgs.size(), 1u); + EXPECT_EQ(msgs[0].payload.size(), MAX_MESSAGE_SIZE); + + signal_stop(); + ASSERT_TRUE(runner.wait_for(JOIN_TIMEOUT)); +} + +TEST_F(PipeTest, ServerDropsConnectionAndReconnectsOnOversizedMessage) { + MessageCollector collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler()}; + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + // First connection: send an oversized header; expect the server to drop us. + { + Handle client = open_raw_client(m_pipe_name); + ASSERT_TRUE(client); + auto frame = + make_framed_with_advertised_len(VPN_EASY_SVC_MSG_START, static_cast(MAX_MESSAGE_SIZE + 1)); + ASSERT_TRUE(write_all(client.get(), frame)); + EXPECT_TRUE(wait_for_peer_disconnect(client.get(), TEST_TIMEOUT)); + } + + // Second connection: the server should have reconnected and accept fresh traffic. + { + Handle client = open_raw_client(m_pipe_name); + ASSERT_TRUE(client); + const std::vector payload = {0x42}; + auto frame = make_framed(VPN_EASY_SVC_MSG_START, payload); + ASSERT_TRUE(write_all(client.get(), frame)); + ASSERT_TRUE(collector.wait_for_count(1, TEST_TIMEOUT)); + auto msgs = collector.snapshot(); + ASSERT_EQ(msgs.size(), 1u); + EXPECT_EQ(msgs[0].payload, payload); + } + + signal_stop(); + ASSERT_TRUE(runner.wait_for(JOIN_TIMEOUT)); +} + +TEST_F(PipeTest, ServerSendDeliversMessageToClient) { + MessageCollector collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler()}; + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + Handle client = open_raw_client(m_pipe_name); + ASSERT_TRUE(client); + std::this_thread::sleep_for(50ms); // Let the server observe the connection. + + const std::vector payload = {1, 2, 3, 4, 5}; + server.send(VPN_EASY_SVC_MSG_STATE_CHANGED, {payload.data(), payload.size()}); + + ReceivedMessage rx{}; + ASSERT_TRUE(read_framed_message(client.get(), rx, TEST_TIMEOUT)); + EXPECT_EQ(rx.what, VPN_EASY_SVC_MSG_STATE_CHANGED); + EXPECT_EQ(rx.payload, payload); + + signal_stop(); + ASSERT_TRUE(runner.wait_for(JOIN_TIMEOUT)); +} + +TEST_F(PipeTest, ServerSendDropsMessageWhenNoPeerConnected) { + MessageCollector collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler()}; + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + // Send a few messages with no peer connected; they must be dropped. + const std::vector dropped_payload = {0xAA}; + for (int i = 0; i < 5; ++i) { + server.send(VPN_EASY_SVC_MSG_STATE_CHANGED, {dropped_payload.data(), dropped_payload.size()}); + } + std::this_thread::sleep_for(50ms); + + Handle client = open_raw_client(m_pipe_name); + ASSERT_TRUE(client); + std::this_thread::sleep_for(50ms); + + // Send a sentinel via the live connection. Only the sentinel must arrive. + const std::vector sentinel_payload = {0x55}; + server.send(VPN_EASY_SVC_MSG_CONNECTION_INFO, {sentinel_payload.data(), sentinel_payload.size()}); + + ReceivedMessage rx{}; + ASSERT_TRUE(read_framed_message(client.get(), rx, TEST_TIMEOUT)); + EXPECT_EQ(rx.what, VPN_EASY_SVC_MSG_CONNECTION_INFO); + EXPECT_EQ(rx.payload, sentinel_payload); + + // No further messages should arrive (the dropped ones must not have been queued). + ReceivedMessage extra{}; + EXPECT_FALSE(read_framed_message(client.get(), extra, 200ms)); + + signal_stop(); + ASSERT_TRUE(runner.wait_for(JOIN_TIMEOUT)); +} + +TEST_F(PipeTest, ServerSendIsThreadSafe) { + MessageCollector collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler()}; + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + Handle client = open_raw_client(m_pipe_name); + ASSERT_TRUE(client); + std::this_thread::sleep_for(50ms); + + constexpr int THREAD_COUNT = 4; + constexpr int PER_THREAD = 25; + std::vector senders; + senders.reserve(THREAD_COUNT); + for (int t = 0; t < THREAD_COUNT; ++t) { + senders.emplace_back([&, t] { + for (int i = 0; i < PER_THREAD; ++i) { + uint8_t payload[2] = {static_cast(t), static_cast(i)}; + server.send(VPN_EASY_SVC_MSG_STATE_CHANGED, {payload, 2}); + } + }); + } + for (auto &th : senders) { + th.join(); + } + + // Read all messages back. Per-thread submission order must be preserved (within each + // thread's stream); messages from different threads may interleave arbitrarily. + std::vector last_seen(THREAD_COUNT, -1); + for (int i = 0; i < THREAD_COUNT * PER_THREAD; ++i) { + ReceivedMessage rx{}; + ASSERT_TRUE(read_framed_message(client.get(), rx, TEST_TIMEOUT)) << "at message " << i; + ASSERT_EQ(rx.payload.size(), 2u); + int t = rx.payload[0]; + int idx = rx.payload[1]; + ASSERT_GE(t, 0); + ASSERT_LT(t, THREAD_COUNT); + EXPECT_GT(idx, last_seen[t]) << "messages from thread " << t << " out of order"; + last_seen[t] = idx; + } + + signal_stop(); + ASSERT_TRUE(runner.wait_for(JOIN_TIMEOUT)); +} + +TEST_F(PipeTest, ServerHandlerCanCallSend) { + // Handler echoes any incoming message back as VPN_EASY_SVC_MSG_STATE_CHANGED, exercising + // send() being called from the loop's own thread. + PipeServer *server_ptr = nullptr; + PipeEndpoint::Handler echo = [&](VpnEasyServiceMessageType, ag::Uint8View data) { + server_ptr->send(VPN_EASY_SVC_MSG_STATE_CHANGED, data); + }; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), echo}; + server_ptr = &server; + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + Handle client = open_raw_client(m_pipe_name); + ASSERT_TRUE(client); + + const std::vector payload = {0x10, 0x20, 0x30}; + auto frame = make_framed(VPN_EASY_SVC_MSG_START, payload); + ASSERT_TRUE(write_all(client.get(), frame)); + + ReceivedMessage rx{}; + ASSERT_TRUE(read_framed_message(client.get(), rx, TEST_TIMEOUT)); + EXPECT_EQ(rx.what, VPN_EASY_SVC_MSG_STATE_CHANGED); + EXPECT_EQ(rx.payload, payload); + + signal_stop(); + ASSERT_TRUE(runner.wait_for(JOIN_TIMEOUT)); +} + +TEST_F(PipeTest, ServerReconnectsAfterClientDisconnects) { + MessageCollector collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler()}; + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + { + Handle client = open_raw_client(m_pipe_name); + ASSERT_TRUE(client); + auto frame = make_framed(VPN_EASY_SVC_MSG_START, {}); + ASSERT_TRUE(write_all(client.get(), frame)); + ASSERT_TRUE(collector.wait_for_count(1, TEST_TIMEOUT)); + // Client closes here. + } + + { + Handle client = open_raw_client(m_pipe_name); + ASSERT_TRUE(client); + const std::vector payload = {0x77}; + auto frame = make_framed(VPN_EASY_SVC_MSG_STOP, payload); + ASSERT_TRUE(write_all(client.get(), frame)); + ASSERT_TRUE(collector.wait_for_count(2, TEST_TIMEOUT)); + auto msgs = collector.snapshot(); + EXPECT_EQ(msgs[0].what, VPN_EASY_SVC_MSG_START); + EXPECT_EQ(msgs[0].payload.size(), 0); + EXPECT_EQ(msgs[1].what, VPN_EASY_SVC_MSG_STOP); + EXPECT_EQ(msgs[1].payload, payload); + } + + signal_stop(); + ASSERT_TRUE(runner.wait_for(JOIN_TIMEOUT)); +} + +TEST_F(PipeTest, ServerSendQueueFlushedOnDisconnectNewClientSeesNoStaleMessages) { + // Queue several server-to-client messages while client A is connected, then disconnect + // client A. Connect client B and verify it receives only messages sent after it connected, + // not stale messages from client A's session. + MessageCollector collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler()}; + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + // Client A: connect, let the server queue messages, then disconnect without reading them. + { + Handle client_a = open_raw_client(m_pipe_name); + ASSERT_TRUE(client_a); + std::this_thread::sleep_for(50ms); // Let the server observe the connection. + + const std::vector stale_payload = {0xAA, 0xBB}; + for (int i = 0; i < 5; ++i) { + server.send(VPN_EASY_SVC_MSG_STATE_CHANGED, {stale_payload.data(), stale_payload.size()}); + } + // Give the loop a moment to pick up the sends (but don't read from client_a). + std::this_thread::sleep_for(50ms); + // Client A disconnects here. + } + + // Client B: connect and send a sentinel so the server queues a fresh message. + { + Handle client_b = open_raw_client(m_pipe_name); + ASSERT_TRUE(client_b); + std::this_thread::sleep_for(50ms); // Let the server observe the reconnection. + + const std::vector fresh_payload = {0xCC, 0xDD}; + server.send(VPN_EASY_SVC_MSG_CONNECTION_INFO, {fresh_payload.data(), fresh_payload.size()}); + + // Read the first message from client B: it must be the fresh one, not stale. + ReceivedMessage rx{}; + ASSERT_TRUE(read_framed_message(client_b.get(), rx, TEST_TIMEOUT)); + EXPECT_EQ(rx.what, VPN_EASY_SVC_MSG_CONNECTION_INFO); + EXPECT_EQ(rx.payload, fresh_payload); + + // Verify no extra (stale) messages follow. + ReceivedMessage extra{}; + EXPECT_FALSE(read_framed_message(client_b.get(), extra, 200ms)); + } + + signal_stop(); + ASSERT_TRUE(runner.wait_for(JOIN_TIMEOUT)); +} + +TEST_F(PipeTest, ServerStopEventDuringActiveConnectionExitsCleanly) { + MessageCollector collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler()}; + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + Handle client = open_raw_client(m_pipe_name); + ASSERT_TRUE(client); + auto frame = make_framed(VPN_EASY_SVC_MSG_START, {}); + ASSERT_TRUE(write_all(client.get(), frame)); + ASSERT_TRUE(collector.wait_for_count(1, TEST_TIMEOUT)); + + signal_stop(); + auto loop_result = runner.wait_for(JOIN_TIMEOUT); + ASSERT_TRUE(loop_result); + EXPECT_TRUE(*loop_result); +} + +TEST_F(PipeTest, ServerWithAuthenticatedUsersDescriptorAcceptsConnections) { + auto sd = PipeServer::for_authenticated_users(); + ASSERT_TRUE(sd); + + MessageCollector collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler(), sd.get()}; + + // The descriptor is documented as consumed synchronously; freeing it now must be safe. + sd.reset(); + + LoopRunner runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + Handle client = open_raw_client(m_pipe_name); + ASSERT_TRUE(client); + auto frame = make_framed(VPN_EASY_SVC_MSG_START, {}); + ASSERT_TRUE(write_all(client.get(), frame)); + ASSERT_TRUE(collector.wait_for_count(1, TEST_TIMEOUT)); + + signal_stop(); + ASSERT_TRUE(runner.wait_for(JOIN_TIMEOUT)); +} + +// --------------------------------------------------------------------------- +// PipeClient tests +// --------------------------------------------------------------------------- + +TEST_F(PipeTest, ClientLoopFailsImmediatelyWhenServerNotPresent) { + MessageCollector collector; + PipeClient client{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler(), std::chrono::milliseconds(50)}; + LoopRunner runner{m_stop_event.get(), [&] { + return client.loop(); + }}; + auto loop_result = runner.wait_for(JOIN_TIMEOUT); + ASSERT_TRUE(loop_result); + EXPECT_FALSE(*loop_result); +} + +TEST_F(PipeTest, ClientServerExchangeMessages) { + // Server echoes any incoming message back as VPN_EASY_SVC_MSG_STATE_CHANGED. + PipeServer *server_ptr = nullptr; + PipeEndpoint::Handler server_handler = [&](VpnEasyServiceMessageType, ag::Uint8View data) { + server_ptr->send(VPN_EASY_SVC_MSG_STATE_CHANGED, data); + }; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), server_handler}; + server_ptr = &server; + LoopRunner server_runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + Handle client_stop{CreateEventW(nullptr, TRUE, FALSE, nullptr)}; + MessageCollector client_collector; + PipeClient client{m_pipe_name.c_str(), client_stop.get(), client_collector.make_handler()}; + LoopRunner client_runner{client_stop.get(), [&] { + return client.loop(); + }}; + + ASSERT_TRUE(client.wait_connected()); + const std::vector payload = {0xAB, 0xCD, 0xEF}; + client.send(VPN_EASY_SVC_MSG_START, {payload.data(), payload.size()}); + + ASSERT_TRUE(client_collector.wait_for_count(1, TEST_TIMEOUT)); + auto msgs = client_collector.snapshot(); + ASSERT_EQ(msgs.size(), 1u); + EXPECT_EQ(msgs[0].what, VPN_EASY_SVC_MSG_STATE_CHANGED); + EXPECT_EQ(msgs[0].payload, payload); + + SetEvent(client_stop.get()); + auto client_result = client_runner.wait_for(JOIN_TIMEOUT); + ASSERT_TRUE(client_result); + EXPECT_TRUE(*client_result); + + signal_stop(); + ASSERT_TRUE(server_runner.wait_for(JOIN_TIMEOUT)); +} + +TEST_F(PipeTest, ClientLoopExitsOnPeerDisconnect) { + // Build a raw single-instance overlapped server pipe by hand. + Handle server_pipe{CreateNamedPipeW(m_pipe_name.c_str(), PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED, + PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_WAIT, 1, 64 * 1024, 64 * 1024, 0, nullptr)}; + ASSERT_TRUE(server_pipe); + + MessageCollector collector; + PipeClient client{m_pipe_name.c_str(), m_stop_event.get(), collector.make_handler()}; + LoopRunner runner{m_stop_event.get(), [&] { + return client.loop(); + }}; + + // Accept the client's connection via overlapped ConnectNamedPipe. + OVERLAPPED ol{}; + Handle ev{CreateEventW(nullptr, TRUE, FALSE, nullptr)}; + ol.hEvent = ev.get(); + BOOL ok = ConnectNamedPipe(server_pipe.get(), &ol); + DWORD err = GetLastError(); + if (!ok && err == ERROR_IO_PENDING) { + ASSERT_EQ(WaitForSingleObject(ol.hEvent, + static_cast( + std::chrono::duration_cast(TEST_TIMEOUT).count())), + WAIT_OBJECT_0); + DWORD t = 0; + ASSERT_TRUE(GetOverlappedResult(server_pipe.get(), &ol, &t, FALSE)); + } else if (!ok) { + ASSERT_EQ(err, static_cast(ERROR_PIPE_CONNECTED)); + } + + // Tear down the server side; the client's loop() must exit gracefully (returning true). + DisconnectNamedPipe(server_pipe.get()); + server_pipe.reset(); + + auto loop_result = runner.wait_for(JOIN_TIMEOUT); + ASSERT_TRUE(loop_result); + EXPECT_TRUE(*loop_result); +} + +TEST_F(PipeTest, ClientStopEventCausesGracefulExit) { + MessageCollector server_collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), server_collector.make_handler()}; + LoopRunner server_runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + Handle client_stop{CreateEventW(nullptr, TRUE, FALSE, nullptr)}; + MessageCollector client_collector; + PipeClient client{m_pipe_name.c_str(), client_stop.get(), client_collector.make_handler()}; + LoopRunner client_runner{client_stop.get(), [&] { + return client.loop(); + }}; + + std::this_thread::sleep_for(100ms); + SetEvent(client_stop.get()); + auto client_result = client_runner.wait_for(JOIN_TIMEOUT); + ASSERT_TRUE(client_result); + EXPECT_TRUE(*client_result); + + signal_stop(); + ASSERT_TRUE(server_runner.wait_for(JOIN_TIMEOUT)); +} + +TEST_F(PipeTest, ClientCanReconnectViaFreshInstance) { + // The documented pattern: after the client's loop() exits, the caller may construct a new + // PipeClient to reconnect. Verify two successive client instances both connect successfully. + MessageCollector server_collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), server_collector.make_handler()}; + LoopRunner server_runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + auto run_one_client = [&] { + Handle stop{CreateEventW(nullptr, TRUE, FALSE, nullptr)}; + MessageCollector collector; + PipeClient client{m_pipe_name.c_str(), stop.get(), collector.make_handler()}; + LoopRunner runner{stop.get(), [&] { + return client.loop(); + }}; + ASSERT_TRUE(client.wait_connected()); + const std::vector payload = {0x01}; + client.send(VPN_EASY_SVC_MSG_START, {payload.data(), payload.size()}); + server_collector.wait_for_count(server_collector.count() + 1, TEST_TIMEOUT); + SetEvent(stop.get()); + auto loop_result = runner.wait_for(JOIN_TIMEOUT); + ASSERT_TRUE(loop_result); + EXPECT_TRUE(*loop_result); + }; + + run_one_client(); + run_one_client(); + EXPECT_GE(server_collector.count(), 2u); + + signal_stop(); + auto loop_result = server_runner.wait_for(JOIN_TIMEOUT); + ASSERT_TRUE(loop_result); + EXPECT_TRUE(*loop_result); +} + +TEST_F(PipeTest, ClientStartConnectRetriesUntilServerInstanceBecomesAvailable) { + // Regression: PipeClient::start_connect must tolerate the brief race window during which the + // single-instance server is mid-reconnect (CreateFileW returns ERROR_PIPE_BUSY or + // ERROR_FILE_NOT_FOUND). Construct the client BEFORE the server exists; the connect must + // succeed once the server appears within the retry deadline. + Handle client_stop{CreateEventW(nullptr, TRUE, FALSE, nullptr)}; + MessageCollector client_collector; + PipeClient client{m_pipe_name.c_str(), client_stop.get(), client_collector.make_handler(), std::chrono::seconds(2)}; + LoopRunner client_runner{client_stop.get(), [&] { + return client.loop(); + }}; + + // Bring the server up after a short delay -- well within the client's retry budget. + std::this_thread::sleep_for(200ms); + MessageCollector server_collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), server_collector.make_handler()}; + LoopRunner server_runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + // Once connected, the client should be able to send and the server should observe it. + ASSERT_TRUE(client.wait_connected()); + const std::vector payload = {0x99}; + client.send(VPN_EASY_SVC_MSG_START, {payload.data(), payload.size()}); + ASSERT_TRUE(server_collector.wait_for_count(1, TEST_TIMEOUT)); + + SetEvent(client_stop.get()); + auto client_result = client_runner.wait_for(JOIN_TIMEOUT); + ASSERT_TRUE(client_result); + EXPECT_TRUE(*client_result); + + signal_stop(); + auto server_result = server_runner.wait_for(JOIN_TIMEOUT); + ASSERT_TRUE(server_result); + ASSERT_TRUE(*server_result); +} + +TEST_F(PipeTest, ClientStartConnectHonorsStopEventDuringRetry) { + // No server exists; the client should keep retrying until the stop event is signaled, and + // then return promptly (well under the configured connect timeout). + Handle client_stop{CreateEventW(nullptr, TRUE, FALSE, nullptr)}; + MessageCollector client_collector; + PipeClient client{m_pipe_name.c_str(), client_stop.get(), client_collector.make_handler(), std::chrono::seconds(5)}; + LoopRunner client_runner{client_stop.get(), [&] { + return client.loop(); + }}; + + std::this_thread::sleep_for(150ms); // Let the client observe at least one retry slice. + SetEvent(client_stop.get()); + auto client_result = client_runner.wait_for(std::chrono::seconds(2)); + ASSERT_TRUE(client_result); + // start_connect() returning false on stop-event during retry is the contract: the loop never + // truly began, so reporting a fatal start failure is the only available signal. + EXPECT_FALSE(*client_result); +} + +TEST_F(PipeTest, ClientWaitConnectedReturnsTrueAfterSuccessfulConnect) { + MessageCollector server_collector; + PipeServer server{m_pipe_name.c_str(), m_stop_event.get(), server_collector.make_handler()}; + LoopRunner server_runner{m_stop_event.get(), [&] { + return server.loop(); + }}; + + Handle client_stop{CreateEventW(nullptr, TRUE, FALSE, nullptr)}; + MessageCollector client_collector; + PipeClient client{m_pipe_name.c_str(), client_stop.get(), client_collector.make_handler()}; + LoopRunner client_runner{client_stop.get(), [&] { + return client.loop(); + }}; + + EXPECT_TRUE(client.wait_connected()); + + SetEvent(client_stop.get()); + ASSERT_TRUE(client_runner.wait_for(JOIN_TIMEOUT)); + signal_stop(); + ASSERT_TRUE(server_runner.wait_for(JOIN_TIMEOUT)); +} + +TEST_F(PipeTest, ClientWaitConnectedReturnsFalseOnConnectFailure) { + // No server: the client's loop() will fail to connect within the (short) connect timeout. + // wait_connected() must wake on that failure and return false rather than wait the full + // user-supplied connect timeout. + Handle client_stop{CreateEventW(nullptr, TRUE, FALSE, nullptr)}; + MessageCollector client_collector; + PipeClient client{ + m_pipe_name.c_str(), client_stop.get(), client_collector.make_handler(), std::chrono::milliseconds(100)}; + LoopRunner client_runner{client_stop.get(), [&] { + return client.loop(); + }}; + + auto t0 = std::chrono::steady_clock::now(); + EXPECT_FALSE(client.wait_connected()); + auto elapsed = std::chrono::steady_clock::now() - t0; + EXPECT_LT(elapsed, std::chrono::seconds(2)) << "wait_connected did not wake on connect failure"; + + auto client_result = client_runner.wait_for(JOIN_TIMEOUT); + ASSERT_TRUE(client_result); + EXPECT_FALSE(*client_result); +} + +TEST_F(PipeTest, ClientWaitConnectedReturnsFalseOnTimeout) { + // No server, no loop running: wait_connected() must respect its own timeout and return false. + Handle client_stop{CreateEventW(nullptr, TRUE, FALSE, nullptr)}; + MessageCollector client_collector; + PipeClient client{m_pipe_name.c_str(), client_stop.get(), client_collector.make_handler()}; + + auto t0 = std::chrono::steady_clock::now(); + EXPECT_FALSE(client.wait_connected()); + auto elapsed = std::chrono::steady_clock::now() - t0; + EXPECT_GE(elapsed, 100ms); + EXPECT_LT(elapsed, std::chrono::seconds(1)); +} + +TEST_F(PipeTest, ClientWaitConnectedReturnsFalseOnStopEvent) { + // No server, loop running: signal stop while wait_connected is blocked. It must return + // false promptly (well under the connect timeout). + Handle client_stop{CreateEventW(nullptr, TRUE, FALSE, nullptr)}; + MessageCollector client_collector; + PipeClient client{m_pipe_name.c_str(), client_stop.get(), client_collector.make_handler(), std::chrono::seconds(5)}; + LoopRunner client_runner{client_stop.get(), [&] { + return client.loop(); + }}; + + std::thread signaler([&] { + std::this_thread::sleep_for(100ms); + SetEvent(client_stop.get()); + }); + + auto t0 = std::chrono::steady_clock::now(); + EXPECT_FALSE(client.wait_connected()); + auto elapsed = std::chrono::steady_clock::now() - t0; + EXPECT_LT(elapsed, std::chrono::seconds(2)); + + signaler.join(); + auto client_result = client_runner.wait_for(JOIN_TIMEOUT); + ASSERT_TRUE(client_result); +} + +TEST_F(PipeTest, ClientConnectTimeoutZeroSelectsDefault) { + // A connect_timeout of 0 must select PipeClient::DEFAULT_CONNECT_TIMEOUT (500 ms). With no + // server, loop() should fail at roughly that wallclock duration. + Handle client_stop{CreateEventW(nullptr, TRUE, FALSE, nullptr)}; + MessageCollector client_collector; + PipeClient client{ + m_pipe_name.c_str(), client_stop.get(), client_collector.make_handler(), std::chrono::milliseconds(0)}; + LoopRunner client_runner{client_stop.get(), [&] { + return client.loop(); + }}; + + auto t0 = std::chrono::steady_clock::now(); + auto client_result = client_runner.wait_for(std::chrono::seconds(3)); + auto elapsed = std::chrono::steady_clock::now() - t0; + ASSERT_TRUE(client_result); + EXPECT_FALSE(*client_result); + EXPECT_GE(elapsed, std::chrono::milliseconds(400)); + EXPECT_LT(elapsed, std::chrono::seconds(2)); +} + +// --------------------------------------------------------------------------- +// SecurityDescriptorPtr / for_authenticated_users +// --------------------------------------------------------------------------- + +TEST(PipeSecurityDescriptor, ForAuthenticatedUsersReturnsValidDescriptor) { + auto sd = PipeServer::for_authenticated_users(); + ASSERT_TRUE(sd); + ASSERT_NE(sd.get(), nullptr); + EXPECT_TRUE(IsValidSecurityDescriptor(sd.get())); +} diff --git a/platform/windows/test/vpn_easy_service_test.cpp b/platform/windows/test/vpn_easy_service_test.cpp new file mode 100644 index 0000000..b5d0280 --- /dev/null +++ b/platform/windows/test/vpn_easy_service_test.cpp @@ -0,0 +1,166 @@ +#include "vpn/vpn.h" +#include "vpn/vpn_easy.h" +#include "vpn/vpn_easy_service.h" + +#include +#include +#include +#include + +#include +#include + +#include "common/logger.h" + +static constexpr const wchar_t *SERVICE_NAME = L"vpn_easy_service"; +static constexpr const wchar_t *PIPE_NAME = L"\\\\.\\pipe\\TestPipeName"; + +static void state_changed_cb(void *, int state) { + fmt::println(stderr, "VPN state changed: ({}) {}", state, + magic_enum::enum_name(static_cast(state))); +} + +/// Read config.toml into a string. Return empty string on failure. +static std::string read_config() { + std::ifstream in("config.toml"); + std::stringstream buf; + buf << in.rdbuf(); + if (in.fail()) { + fmt::println(stderr, "Failed to read config.toml"); + return {}; + } + return buf.str(); +} + +/// Install the service. If it already exists, uninstall first and retry. +static int32_t install_service() { + auto image = absolute(std::filesystem::path(".") / "vpn_easy_service.exe").wstring(); + auto logfile = absolute(std::filesystem::path(".") / "vpn_easy_service.log").wstring(); + + int32_t ret = vpn_easy_service_install( + image.c_str(), logfile.c_str(), PIPE_NAME, SERVICE_NAME, L"VPN easy service", L"Test description"); + if (ret == VPN_EASY_SVC_ERR_SERVICE_EXISTS) { + fmt::println(stderr, "Service already exists, uninstalling first..."); + vpn_easy_service_uninstall(SERVICE_NAME); + ret = vpn_easy_service_install( + image.c_str(), logfile.c_str(), PIPE_NAME, SERVICE_NAME, L"VPN easy service", L"Test description"); + } + return ret; +} + +/// Test install and uninstall only. +static int test_install_uninstall() { + fmt::println(stderr, "=== test_install_uninstall ==="); + + fmt::println(stderr, "Installing service..."); + int32_t ret = install_service(); + if (ret) { + fmt::println(stderr, "vpn_easy_service_install: {}", ret); + return -1; + } + + fmt::println(stderr, "Type 's' to stop"); + while (getchar() != 's') { + } + + ret = vpn_easy_service_uninstall(SERVICE_NAME); + if (ret) { + fmt::println(stderr, "vpn_easy_service_uninstall: {}", ret); + return -1; + } + + return 0; +} + +/// Test start and stop via the pipe client (requires service to be installed already). +static int test_start_stop() { + fmt::println(stderr, "=== test_start_stop ==="); + + std::string config = read_config(); + if (config.empty()) { + return -1; + } + + fmt::println(stderr, "Starting service..."); + int32_t ret = vpn_easy_service_start(SERVICE_NAME, PIPE_NAME, config.c_str(), state_changed_cb, nullptr); + if (ret) { + fmt::println(stderr, "vpn_easy_service_start: {}", ret); + return -1; + } + fmt::println(stderr, "Service started. Type 's' to stop"); + while (getchar() != 's') { + } + + fmt::println(stderr, "Stopping service..."); + ret = vpn_easy_service_stop(SERVICE_NAME, PIPE_NAME); + if (ret) { + fmt::println(stderr, "vpn_easy_service_stop: {}", ret); + return -1; + } + fmt::println(stderr, "Service stopped."); + + return 0; +} + +/// Test full lifecycle: install, start, stop, uninstall. +static int test_full_lifecycle() { + fmt::println(stderr, "=== test_full_lifecycle ==="); + + std::string config = read_config(); + if (config.empty()) { + return -1; + } + + fmt::println(stderr, "Installing service..."); + int32_t ret = install_service(); + if (ret) { + fmt::println(stderr, "vpn_easy_service_install: {}", ret); + return -1; + } + + fmt::println(stderr, "Starting VPN via service..."); + ret = vpn_easy_service_start(SERVICE_NAME, PIPE_NAME, config.c_str(), state_changed_cb, nullptr); + if (ret) { + fmt::println(stderr, "vpn_easy_service_start: {}", ret); + vpn_easy_service_uninstall(SERVICE_NAME); + return -1; + } + fmt::println(stderr, "VPN started. Type 's' to stop"); + while (getchar() != 's') { + } + + fmt::println(stderr, "Stopping VPN via service..."); + ret = vpn_easy_service_stop(SERVICE_NAME, PIPE_NAME); + if (ret) { + fmt::println(stderr, "vpn_easy_service_stop: {}", ret); + } + + fmt::println(stderr, "Uninstalling service..."); + ret = vpn_easy_service_uninstall(SERVICE_NAME); + if (ret) { + fmt::println(stderr, "vpn_easy_service_uninstall: {}", ret); + return -1; + } + + fmt::println(stderr, "Done."); + return 0; +} + +int main(int argc, char **argv) { + ag::Logger::set_log_level(ag::LOG_LEVEL_DEBUG); + + const char *test = (argc > 1) ? argv[1] : "full"; + + if (strcmp(test, "install") == 0) { + return test_install_uninstall(); + } + if (strcmp(test, "startstop") == 0) { + return test_start_stop(); + } + if (strcmp(test, "full") == 0) { + return test_full_lifecycle(); + } + + fmt::println(stderr, "Usage: {} [install|startstop|full]", argv[0]); + return 1; +} diff --git a/platform/windows/test/vpn_easy_test.cpp b/platform/windows/test/vpn_easy_test.cpp index 1420daf..c9be744 100644 --- a/platform/windows/test/vpn_easy_test.cpp +++ b/platform/windows/test/vpn_easy_test.cpp @@ -1,11 +1,14 @@ +#include "vpn/vpn.h" #include "vpn/vpn_easy.h" #include #include #include -static void state_changed_cb(void *, const char *new_state_description) { - fprintf(stderr, "VPN state changed: %s\n", new_state_description); +#include + +static void state_changed_cb(void *, int state) { + fprintf(stderr, "VPN state changed: (%d) %s\n", state, magic_enum::enum_name((ag::VpnSessionState) state).data()); } int main() { @@ -18,12 +21,12 @@ int main() { } in.close(); - vpn_easy_t *vpn = vpn_easy_start(config.str().c_str(), state_changed_cb, nullptr); + vpn_easy_t *vpn = vpn_easy_start_ex(config.str().c_str(), state_changed_cb, nullptr, nullptr, nullptr); fprintf(stderr, "Type 's' to stop"); while (getchar() != 's') { } - vpn_easy_stop(vpn); + vpn_easy_stop_ex(vpn); return 0; }