From 71942d53eb2fbd14981c52c40861530ccf63112f Mon Sep 17 00:00:00 2001 From: Andrew Kaster Date: Thu, 20 Feb 2025 04:18:37 -0700 Subject: [PATCH] LibWebSocket+RequestServer: Add a WebSocketImpl using libcurl This implementation can be better improved in the future by ripping out a lot of the manual logic in LibWebSocket and rely on libcurl to parse our message payloads. But for now, this uses the 'raw mode' of curl websockets in connect-only mode to allow for somewhat seamless integration into our event loop. --- Libraries/LibWebSocket/Impl/WebSocketImpl.h | 2 + Libraries/LibWebSocket/WebSocket.cpp | 11 +- Services/RequestServer/CMakeLists.txt | 1 + Services/RequestServer/ConnectionFromClient.h | 2 + Services/RequestServer/WebSocketImplCurl.cpp | 187 ++++++++++++++++++ Services/RequestServer/WebSocketImplCurl.h | 45 +++++ vcpkg.json | 4 +- 7 files changed, 248 insertions(+), 4 deletions(-) create mode 100644 Services/RequestServer/WebSocketImplCurl.cpp create mode 100644 Services/RequestServer/WebSocketImplCurl.h diff --git a/Libraries/LibWebSocket/Impl/WebSocketImpl.h b/Libraries/LibWebSocket/Impl/WebSocketImpl.h index 7ce8b56ef16..f5c28e822dc 100644 --- a/Libraries/LibWebSocket/Impl/WebSocketImpl.h +++ b/Libraries/LibWebSocket/Impl/WebSocketImpl.h @@ -27,6 +27,8 @@ public: virtual bool eof() = 0; virtual void discard_connection() = 0; + virtual bool handshake_complete_when_connected() const { return false; } + Function on_connected; Function on_connection_error; Function on_ready_to_read; diff --git a/Libraries/LibWebSocket/WebSocket.cpp b/Libraries/LibWebSocket/WebSocket.cpp index 6d59aaf1a5e..9024e5410c7 100644 --- a/Libraries/LibWebSocket/WebSocket.cpp +++ b/Libraries/LibWebSocket/WebSocket.cpp @@ -42,9 +42,14 @@ void WebSocket::start() m_impl->on_connected = [this] { if (m_state != WebSocket::InternalState::EstablishingProtocolConnection) return; - set_state(WebSocket::InternalState::SendingClientHandshake); - send_client_handshake(); - drain_read(); + if (m_impl->handshake_complete_when_connected()) { + set_state(WebSocket::InternalState::Open); + notify_open(); + } else { + set_state(WebSocket::InternalState::SendingClientHandshake); + send_client_handshake(); + drain_read(); + } }; m_impl->on_ready_to_read = [this] { drain_read(); diff --git a/Services/RequestServer/CMakeLists.txt b/Services/RequestServer/CMakeLists.txt index 4d983f422e6..097511207b9 100644 --- a/Services/RequestServer/CMakeLists.txt +++ b/Services/RequestServer/CMakeLists.txt @@ -4,6 +4,7 @@ set(CMAKE_AUTOUIC OFF) set(SOURCES ConnectionFromClient.cpp + WebSocketImplCurl.cpp ) if (ANDROID) diff --git a/Services/RequestServer/ConnectionFromClient.h b/Services/RequestServer/ConnectionFromClient.h index 429133e8616..cde4d05bbeb 100644 --- a/Services/RequestServer/ConnectionFromClient.h +++ b/Services/RequestServer/ConnectionFromClient.h @@ -77,4 +77,6 @@ private: NonnullRefPtr m_resolver; }; +constexpr inline uintptr_t websocket_private_tag = 0x1; + } diff --git a/Services/RequestServer/WebSocketImplCurl.cpp b/Services/RequestServer/WebSocketImplCurl.cpp new file mode 100644 index 00000000000..3d6d7d1e25d --- /dev/null +++ b/Services/RequestServer/WebSocketImplCurl.cpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2025, Andrew Kaster + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#include + +namespace RequestServer { + +NonnullRefPtr WebSocketImplCurl::create(CURLM* multi_handle) +{ + return adopt_ref(*new WebSocketImplCurl(multi_handle)); +} + +WebSocketImplCurl::WebSocketImplCurl(CURLM* multi_handle) + : m_multi_handle(multi_handle) +{ +} + +WebSocketImplCurl::~WebSocketImplCurl() +{ + if (m_read_notifier) + m_read_notifier->close(); + if (m_error_notifier) + m_error_notifier->close(); + + if (m_easy_handle) { + curl_multi_remove_handle(m_multi_handle, m_easy_handle); + curl_easy_cleanup(m_easy_handle); + } + + for (auto* list : m_curl_string_lists) { + curl_slist_free_all(list); + } +} + +void WebSocketImplCurl::connect(WebSocket::ConnectionInfo const& info) +{ + VERIFY(!m_easy_handle); + VERIFY(on_connected); + VERIFY(on_connection_error); + VERIFY(on_ready_to_read); + + m_easy_handle = curl_easy_init(); + VERIFY(m_easy_handle); // FIXME: Allow failure, and return ENOMEM + + auto set_option = [this](auto option, auto value) -> bool { + auto result = curl_easy_setopt(m_easy_handle, option, value); + if (result == CURLE_OK) + return true; + dbgln("WebSocketImplCurl::connect: Failed to set curl option {}={}: {}", to_underlying(option), value, curl_easy_strerror(result)); + return false; + }; + + set_option(CURLOPT_PRIVATE, reinterpret_cast(this) | websocket_private_tag); + set_option(CURLOPT_WS_OPTIONS, CURLWS_RAW_MODE); + set_option(CURLOPT_CONNECT_ONLY, 2); // WebSocket mode + + // FIXME: Add a header function to validate the Sec-WebSocket headers that curl currently doesn't validate + + auto const& url = info.url(); + set_option(CURLOPT_URL, url.to_byte_string().characters()); + set_option(CURLOPT_PORT, url.port_or_default()); + + if (auto root_certs = info.root_certificates_path(); root_certs.has_value()) + set_option(CURLOPT_CAINFO, root_certs->characters()); + + auto const origin_header = ByteString::formatted("Origin: {}", info.origin()); + curl_slist* curl_headers = curl_slist_append(nullptr, origin_header.characters()); + + for (auto const& [name, value] : info.headers().headers()) { + // curl will discard headers with empty values unless we pass the header name followed by a semicolon. + ByteString header_string; + if (value.is_empty()) + header_string = ByteString::formatted("{};", name); + else + header_string = ByteString::formatted("{}: {}", name, value); + curl_headers = curl_slist_append(curl_headers, header_string.characters()); + } + + if (auto const& protocols = info.protocols(); !protocols.is_empty()) { + StringBuilder protocol_builder; + protocol_builder.append("Sec-WebSocket-Protocol: "sv); + protocol_builder.append(ByteString::join(","sv, protocols)); + curl_headers = curl_slist_append(curl_headers, protocol_builder.to_byte_string().characters()); + } + + if (auto const& extensions = info.extensions(); !extensions.is_empty()) { + StringBuilder protocol_builder; + protocol_builder.append("Sec-WebSocket-Extensions: "sv); + protocol_builder.append(ByteString::join(","sv, extensions)); + curl_headers = curl_slist_append(curl_headers, protocol_builder.to_byte_string().characters()); + } + + set_option(CURLOPT_HTTPHEADER, curl_headers); + m_curl_string_lists.append(curl_headers); + + CURLMcode const err = curl_multi_add_handle(m_multi_handle, m_easy_handle); + VERIFY(err == CURLM_OK); +} + +bool WebSocketImplCurl::can_read_line() +{ + VERIFY_NOT_REACHED(); +} + +ErrorOr WebSocketImplCurl::read(int max_size) +{ + auto buffer = TRY(ByteBuffer::create_uninitialized(max_size)); + auto const read_bytes = TRY(m_read_buffer.read_some(buffer)); + return buffer.slice(0, read_bytes.size()); +} + +ErrorOr WebSocketImplCurl::read_line(size_t) +{ + VERIFY_NOT_REACHED(); +} + +bool WebSocketImplCurl::send(ReadonlyBytes bytes) +{ + size_t sent = 0; + CURLcode result = CURLE_OK; + do { + sent = 0; + result = curl_easy_send(m_easy_handle, bytes.data(), bytes.size(), &sent); + bytes = bytes.slice(sent); + } while (bytes.size() > 0 && (result == CURLE_OK || result == CURLE_AGAIN)); + + return result == CURLE_OK; +} + +bool WebSocketImplCurl::eof() +{ + return m_read_buffer.is_eof(); +} + +void WebSocketImplCurl::discard_connection() +{ + if (m_read_notifier) { + m_read_notifier->close(); + m_read_notifier = nullptr; + } + if (m_error_notifier) { + m_error_notifier->close(); + m_error_notifier = nullptr; + } + if (m_easy_handle) { + curl_multi_remove_handle(m_multi_handle, m_easy_handle); + curl_easy_cleanup(m_easy_handle); + m_easy_handle = nullptr; + } +} + +void WebSocketImplCurl::did_connect() +{ + curl_socket_t socket_fd = CURL_SOCKET_BAD; + auto res = curl_easy_getinfo(m_easy_handle, CURLINFO_ACTIVESOCKET, &socket_fd); + VERIFY(res == CURLE_OK && socket_fd != CURL_SOCKET_BAD); + + m_read_notifier = Core::Notifier::construct(socket_fd, Core::Notifier::Type::Read); + m_read_notifier->on_activation = [this] { + u8 buffer[65536]; + size_t nread = 0; + CURLcode const result = curl_easy_recv(m_easy_handle, buffer, sizeof(buffer), &nread); + if (result == CURLE_AGAIN) + return; + + if (result != CURLE_OK) { + dbgln("Failed to read from WebSocket: {}", curl_easy_strerror(result)); + on_connection_error(); + } + + if (auto const err = m_read_buffer.write_until_depleted({ buffer, nread }); err.is_error()) + on_connection_error(); + + on_ready_to_read(); + }; + m_error_notifier = Core::Notifier::construct(socket_fd, Core::Notifier::Type::Error | Core::Notifier::Type::HangUp); + m_error_notifier->on_activation = [this] { + on_connection_error(); + }; + + on_connected(); +} + +} diff --git a/Services/RequestServer/WebSocketImplCurl.h b/Services/RequestServer/WebSocketImplCurl.h new file mode 100644 index 00000000000..1ee959781a9 --- /dev/null +++ b/Services/RequestServer/WebSocketImplCurl.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2025, Andrew Kaster + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include +#include +#include +#include + +namespace RequestServer { + +class WebSocketImplCurl final : public WebSocket::WebSocketImpl { +public: + virtual ~WebSocketImplCurl() override; + + static NonnullRefPtr create(CURLM*); + + virtual void connect(WebSocket::ConnectionInfo const&) override; + virtual bool can_read_line() override; + virtual ErrorOr read_line(size_t) override; + virtual ErrorOr read(int max_size) override; + virtual bool send(ReadonlyBytes) override; + virtual bool eof() override; + virtual void discard_connection() override; + + virtual bool handshake_complete_when_connected() const override { return true; } + + void did_connect(); + +private: + explicit WebSocketImplCurl(CURLM*); + + CURLM* m_multi_handle { nullptr }; + CURL* m_easy_handle { nullptr }; + RefPtr m_read_notifier; + RefPtr m_error_notifier; + Vector m_curl_string_lists; + AllocatingMemoryStream m_read_buffer; +}; + +} diff --git a/vcpkg.json b/vcpkg.json index d941a980aa0..19f1f59e047 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -9,7 +9,9 @@ "name": "curl", "features": [ "brotli", - "http2" + "http2", + "ssl", + "websockets" ] }, {