From e40afcfec3735e9431c0581c9e1cd4d0086e8838 Mon Sep 17 00:00:00 2001 From: airton Date: Sat, 25 Apr 2026 07:33:29 +0200 Subject: [PATCH 1/4] enable native-tls https --- bitreq/src/connection.rs | 90 +++++++++++++++++++------- bitreq/src/connection/rustls_stream.rs | 59 +++++++++++------ bitreq/src/error.rs | 11 +++- bitreq/tests/main.rs | 27 ++++++++ 4 files changed, 138 insertions(+), 49 deletions(-) diff --git a/bitreq/src/connection.rs b/bitreq/src/connection.rs index b323c6af3..801ca11dd 100644 --- a/bitreq/src/connection.rs +++ b/bitreq/src/connection.rs @@ -29,14 +29,14 @@ use crate::{Error, Method, ResponseLazy}; type UnsecuredStream = TcpStream; -#[cfg(feature = "rustls")] +#[cfg(any(feature = "rustls", feature = "https-native-tls"))] mod rustls_stream; -#[cfg(feature = "rustls")] +#[cfg(any(feature = "rustls", feature = "https-native-tls"))] type SecuredStream = rustls_stream::SecuredStream; pub(crate) enum HttpStream { Unsecured(UnsecuredStream, Option), - #[cfg(feature = "rustls")] + #[cfg(any(feature = "rustls", feature = "https-native-tls"))] Secured(Box, Option), #[cfg(feature = "async")] Buffer(std::io::Cursor>), @@ -81,7 +81,7 @@ impl Read for HttpStream { timeout(inner, *timeout_at)?; inner.read(buf) } - #[cfg(feature = "rustls")] + #[cfg(any(feature = "rustls", feature = "https-native-tls"))] HttpStream::Secured(inner, timeout_at) => { timeout(inner.get_ref(), *timeout_at)?; inner.read(buf) @@ -111,7 +111,7 @@ impl Write for HttpStream { set_socket_write_timeout(inner, *timeout_at)?; inner.write(buf) } - #[cfg(feature = "rustls")] + #[cfg(any(feature = "rustls", feature = "https-native-tls"))] HttpStream::Secured(inner, timeout_at) => { set_socket_write_timeout(inner.get_ref(), *timeout_at)?; inner.write(buf) @@ -137,7 +137,7 @@ impl Write for HttpStream { set_socket_write_timeout(inner, *timeout_at)?; inner.flush() } - #[cfg(feature = "rustls")] + #[cfg(any(feature = "rustls", feature = "https-native-tls"))] HttpStream::Secured(inner, timeout_at) => { set_socket_write_timeout(inner.get_ref(), *timeout_at)?; inner.flush() @@ -158,13 +158,21 @@ impl Write for HttpStream { } } -#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] +#[cfg(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" +))] type AsyncSecuredStream = rustls_stream::AsyncSecuredStream; #[cfg(feature = "async")] pub(crate) enum AsyncHttpStream { Unsecured(AsyncTcpStream), - #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] + #[cfg(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" + ))] Secured(Box), } @@ -177,7 +185,11 @@ impl AsyncRead for AsyncHttpStream { ) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_read(cx, buf), - #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] + #[cfg(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" + ))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_read(cx, buf), } } @@ -192,7 +204,11 @@ impl AsyncWrite for AsyncHttpStream { ) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_write(cx, buf), - #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] + #[cfg(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" + ))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_write(cx, buf), } } @@ -200,7 +216,11 @@ impl AsyncWrite for AsyncHttpStream { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_flush(cx), - #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] + #[cfg(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" + ))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_flush(cx), } } @@ -208,7 +228,11 @@ impl AsyncWrite for AsyncHttpStream { fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_shutdown(cx), - #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] + #[cfg(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" + ))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_shutdown(cx), } } @@ -271,13 +295,7 @@ impl AsyncConnection { let socket = Self::connect(params).await?; if params.https { - #[cfg(not(any( - feature = "async-https-rustls", - feature = "async-https-rustls-probe" - )))] - return Err(Error::HttpsFeatureNotEnabled); - #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] - rustls_stream::wrap_async_stream(socket, params.host).await + Self::wrap_async_stream(socket, params.host).await } else { Ok(AsyncHttpStream::Unsecured(socket)) } @@ -301,6 +319,31 @@ impl AsyncConnection { })))) } + #[cfg(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" + ))] + async fn wrap_async_stream( + socket: AsyncTcpStream, + host: &str, + ) -> Result { + rustls_stream::wrap_async_stream(socket, host).await + } + + /// Error treatment function, should not be called under normal circustances + #[cfg(not(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" + )))] + async fn wrap_async_stream( + _socket: AsyncTcpStream, + _host: &str, + ) -> Result { + Err(Error::HttpsFeatureNotEnabled) + } + async fn tcp_connect(host: &str, port: u16) -> Result { #[cfg(feature = "log")] log::trace!("Looking up host {host}"); @@ -656,13 +699,10 @@ impl Connection { let socket = Self::connect(params, timeout_at)?; let stream = if params.https { - #[cfg(not(feature = "rustls"))] + #[cfg(not(any(feature = "rustls", feature = "https-native-tls")))] return Err(Error::HttpsFeatureNotEnabled); - #[cfg(feature = "rustls")] - { - let tls = rustls_stream::wrap_stream(socket, params.host)?; - HttpStream::Secured(Box::new(tls), timeout_at) - } + #[cfg(any(feature = "rustls", feature = "https-native-tls"))] + rustls_stream::wrap_stream(socket, params.host)? } else { HttpStream::create_unsecured(socket, timeout_at) }; diff --git a/bitreq/src/connection/rustls_stream.rs b/bitreq/src/connection/rustls_stream.rs index 62602b0f6..fc71b4d1c 100644 --- a/bitreq/src/connection/rustls_stream.rs +++ b/bitreq/src/connection/rustls_stream.rs @@ -3,26 +3,31 @@ #[cfg(feature = "rustls")] use alloc::sync::Arc; +#[cfg(any(feature = "rustls", feature = "https-native-tls"))] use std::io; use std::net::TcpStream; use std::sync::OnceLock; -#[cfg(all(feature = "native-tls", not(feature = "rustls")))] +#[cfg(all(feature = "https-native-tls", not(feature = "rustls")))] use native_tls::{HandshakeError, TlsConnector, TlsStream}; #[cfg(feature = "rustls")] use rustls::pki_types::ServerName; #[cfg(feature = "rustls")] use rustls::{self, ClientConfig, ClientConnection, RootCertStore, StreamOwned}; -#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +#[cfg(all(feature = "async-https-native-tls", not(feature = "rustls")))] use tokio_native_tls::TlsConnector as AsyncTlsConnector; #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] use tokio_rustls::{client::TlsStream, TlsConnector}; #[cfg(feature = "rustls-webpki")] use webpki_roots::TLS_SERVER_ROOTS; -#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] -use super::{AsyncHttpStream, AsyncTcpStream}; -#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +#[cfg(any(feature = "rustls", feature = "https-native-tls"))] +use super::HttpStream; +#[cfg(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" +))] use super::{AsyncHttpStream, AsyncTcpStream}; use crate::Error; @@ -50,7 +55,7 @@ fn build_client_config() -> Arc { } #[cfg(feature = "rustls")] -pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { +pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { #[cfg(feature = "log")] log::trace!("Setting up TLS parameters for {host}."); let dns_name = ServerName::try_from(host) @@ -58,10 +63,12 @@ pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result; -#[cfg(all(feature = "native-tls", not(feature = "rustls")))] +#[cfg(all(feature = "https-native-tls", not(feature = "rustls")))] static CONNECTOR: OnceLock> = OnceLock::new(); -#[cfg(all(feature = "native-tls", not(feature = "rustls")))] +#[cfg(all(feature = "https-native-tls", not(feature = "rustls")))] fn native_tls_err(e: HandshakeError) -> Error { match e { - HandshakeError::Failure(e) => Error::NativeTlsError(e), + HandshakeError::Failure(err) => Error::NativeTlsCreateConnection(err), HandshakeError::WouldBlock(_) => { debug_assert!(false, "We shouldn't hit a blocking error"); Error::Other("Got a WouldBlock error from native-tls") @@ -107,30 +114,35 @@ fn native_tls_err(e: HandshakeError) -> Error { } } -#[cfg(all(feature = "native-tls", not(feature = "rustls")))] +#[cfg(all(feature = "https-native-tls", not(feature = "rustls")))] fn build_tls_connector() -> Result { - TlsConnector::builder().build().map_err(Error::NativeTlsError) + TlsConnector::builder().build().map_err(Error::from) } -#[cfg(all(feature = "native-tls", not(feature = "rustls")))] -pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { +#[cfg(all(feature = "https-native-tls", not(feature = "rustls")))] +pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { #[cfg(feature = "log")] log::trace!("Setting up TLS parameters for {host}."); // TODO: Once we can `get_or_try_init`, so that instead // https://github.com/rust-lang/rust/issues/109737 - let connector = CONNECTOR.get_or_init(build_tls_connector)?; + let connector = match CONNECTOR.get_or_init(build_tls_connector) { + Ok(c) => c.clone(), + Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), + }; #[cfg(feature = "log")] log::trace!("Establishing TLS session to {host}."); - connector.connect(host, tcp).map_err(native_tls_err) + let tls = connector.connect(host, tcp).map_err(native_tls_err)?; + + Ok(HttpStream::Secured(Box::new(tls), None)) } -#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +#[cfg(all(feature = "async-https-native-tls", not(feature = "rustls")))] pub type AsyncSecuredStream = tokio_native_tls::TlsStream; -#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] +#[cfg(all(feature = "async-https-native-tls", not(feature = "rustls")))] pub(super) async fn wrap_async_stream( tcp: AsyncTcpStream, host: &str, @@ -140,12 +152,17 @@ pub(super) async fn wrap_async_stream( // TODO: Once we can `get_or_try_init`, so that instead // https://github.com/rust-lang/rust/issues/109737 - let connector = AsyncTlsConnector::from(CONNECTOR.get_or_init(build_tls_connector)?.clone()); + let sync_connector = match CONNECTOR.get_or_init(build_tls_connector) { + Ok(c) => c.clone(), + Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), + }; + + let async_connector = AsyncTlsConnector::from(sync_connector); #[cfg(feature = "log")] log::trace!("Establishing TLS session to {host}."); - let tls = connector.connect(host, tcp).await.map_err(native_tls_err)?; + let tls = async_connector.connect(host, tcp).await?; Ok(AsyncHttpStream::Secured(Box::new(tls))) } diff --git a/bitreq/src/error.rs b/bitreq/src/error.rs index ca9d1421d..9a2b3c7a6 100644 --- a/bitreq/src/error.rs +++ b/bitreq/src/error.rs @@ -22,7 +22,7 @@ pub enum Error { #[cfg(feature = "rustls")] /// Ran into a rustls error while creating the connection. RustlsCreateConnection(rustls::Error), - #[cfg(feature = "native-tls")] + #[cfg(feature = "https-native-tls")] /// Ran into a native-tls error while creating the connection. NativeTlsCreateConnection(native_tls::Error), /// Ran into an IO problem while loading the response. @@ -104,8 +104,8 @@ impl fmt::Display for Error { InvalidUtf8InBody(err) => write!(f, "{}", err), #[cfg(feature = "rustls")] RustlsCreateConnection(err) => write!(f, "error creating rustls connection: {}", err), - #[cfg(feature = "native-tls")] - NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {err}"), + #[cfg(feature = "https-native-tls")] + NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {}", err), MalformedChunkLength => write!(f, "non-usize chunk length with transfer-encoding: chunked"), MalformedChunkEnd => write!(f, "chunk did not end after reading the expected amount of bytes"), MalformedContentLength => write!(f, "non-usize content length"), @@ -160,3 +160,8 @@ impl From for Error { impl From for Error { fn from(other: UrlParseError) -> Error { Error::InvalidUrl(other) } } + +#[cfg(feature = "https-native-tls")] +impl From for Error { + fn from(err: native_tls::Error) -> Error { Error::NativeTlsCreateConnection(err) } +} diff --git a/bitreq/tests/main.rs b/bitreq/tests/main.rs index 8d357f354..2dae034f6 100644 --- a/bitreq/tests/main.rs +++ b/bitreq/tests/main.rs @@ -16,6 +16,33 @@ async fn test_https() { assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); } +#[tokio::test] +#[cfg(all(feature = "async-https-native-tls", not(feature = "rustls")))] +async fn test_https() { + // TODO: Implement this locally. + assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); + // Test reusing the existing connection in client: + assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); +} + +#[tokio::test] +#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +async fn test_https_with_client() { + setup(); + let client = bitreq::Client::new(1); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + +#[tokio::test] +#[cfg(all(feature = "async-https-native-tls", not(feature = "rustls")))] +async fn test_https_with_client() { + setup(); + let client = bitreq::Client::new(1); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + #[tokio::test] #[cfg(feature = "json-using-serde")] async fn test_json_using_serde() { From 1c95cf69bfa3ca95cbfaa599737c174349e0ea60 Mon Sep 17 00:00:00 2001 From: airton Date: Sat, 25 Apr 2026 09:18:21 +0200 Subject: [PATCH 2/4] refactor: add native-tls module --- bitreq/src/connection.rs | 63 ++++++++++++--- bitreq/src/connection/native_tls_stream.rs | 85 ++++++++++++++++++++ bitreq/src/connection/rustls_stream.rs | 91 ++-------------------- 3 files changed, 143 insertions(+), 96 deletions(-) create mode 100644 bitreq/src/connection/native_tls_stream.rs diff --git a/bitreq/src/connection.rs b/bitreq/src/connection.rs index 801ca11dd..ee226295d 100644 --- a/bitreq/src/connection.rs +++ b/bitreq/src/connection.rs @@ -29,14 +29,33 @@ use crate::{Error, Method, ResponseLazy}; type UnsecuredStream = TcpStream; -#[cfg(any(feature = "rustls", feature = "https-native-tls"))] +#[cfg(any(feature = "https-rustls", feature = "https-rustls-probe"))] mod rustls_stream; -#[cfg(any(feature = "rustls", feature = "https-native-tls"))] -type SecuredStream = rustls_stream::SecuredStream; + +#[cfg(all( + feature = "https-native-tls", + not(any(feature = "https-rustls", feature = "https-rustls-probe")) +))] +mod native_tls_stream; + +#[cfg(all( + feature = "https-native-tls", + not(any(feature = "https-rustls", feature = "https-rustls-probe")) +))] +use self::native_tls_stream as tls_stream; +#[cfg(any(feature = "https-rustls", feature = "https-rustls-probe"))] +use self::rustls_stream as tls_stream; + +#[cfg(any(feature = "https-rustls", feature = "https-rustls-probe", feature = "https-native-tls"))] +type SecuredStream = tls_stream::SecuredStream; pub(crate) enum HttpStream { Unsecured(UnsecuredStream, Option), - #[cfg(any(feature = "rustls", feature = "https-native-tls"))] + #[cfg(any( + feature = "https-rustls", + feature = "https-rustls-probe", + feature = "https-native-tls" + ))] Secured(Box, Option), #[cfg(feature = "async")] Buffer(std::io::Cursor>), @@ -81,7 +100,11 @@ impl Read for HttpStream { timeout(inner, *timeout_at)?; inner.read(buf) } - #[cfg(any(feature = "rustls", feature = "https-native-tls"))] + #[cfg(any( + feature = "https-rustls", + feature = "https-rustls-probe", + feature = "https-native-tls" + ))] HttpStream::Secured(inner, timeout_at) => { timeout(inner.get_ref(), *timeout_at)?; inner.read(buf) @@ -111,7 +134,11 @@ impl Write for HttpStream { set_socket_write_timeout(inner, *timeout_at)?; inner.write(buf) } - #[cfg(any(feature = "rustls", feature = "https-native-tls"))] + #[cfg(any( + feature = "https-rustls", + feature = "https-rustls-probe", + feature = "https-native-tls" + ))] HttpStream::Secured(inner, timeout_at) => { set_socket_write_timeout(inner.get_ref(), *timeout_at)?; inner.write(buf) @@ -137,7 +164,11 @@ impl Write for HttpStream { set_socket_write_timeout(inner, *timeout_at)?; inner.flush() } - #[cfg(any(feature = "rustls", feature = "https-native-tls"))] + #[cfg(any( + feature = "https-rustls", + feature = "https-rustls-probe", + feature = "https-native-tls" + ))] HttpStream::Secured(inner, timeout_at) => { set_socket_write_timeout(inner.get_ref(), *timeout_at)?; inner.flush() @@ -163,7 +194,7 @@ impl Write for HttpStream { feature = "async-https-rustls-probe", feature = "async-https-native-tls" ))] -type AsyncSecuredStream = rustls_stream::AsyncSecuredStream; +type AsyncSecuredStream = tls_stream::AsyncSecuredStream; #[cfg(feature = "async")] pub(crate) enum AsyncHttpStream { @@ -328,7 +359,7 @@ impl AsyncConnection { socket: AsyncTcpStream, host: &str, ) -> Result { - rustls_stream::wrap_async_stream(socket, host).await + tls_stream::wrap_async_stream(socket, host).await } /// Error treatment function, should not be called under normal circustances @@ -699,10 +730,18 @@ impl Connection { let socket = Self::connect(params, timeout_at)?; let stream = if params.https { - #[cfg(not(any(feature = "rustls", feature = "https-native-tls")))] + #[cfg(not(any( + feature = "https-rustls", + feature = "https-rustls-probe", + feature = "https-native-tls" + )))] return Err(Error::HttpsFeatureNotEnabled); - #[cfg(any(feature = "rustls", feature = "https-native-tls"))] - rustls_stream::wrap_stream(socket, params.host)? + #[cfg(any( + feature = "https-rustls", + feature = "https-rustls-probe", + feature = "https-native-tls" + ))] + tls_stream::wrap_stream(socket, params.host)? } else { HttpStream::create_unsecured(socket, timeout_at) }; diff --git a/bitreq/src/connection/native_tls_stream.rs b/bitreq/src/connection/native_tls_stream.rs new file mode 100644 index 000000000..55142f652 --- /dev/null +++ b/bitreq/src/connection/native_tls_stream.rs @@ -0,0 +1,85 @@ +//! Native-TLS connection handling functionality. +//! This module is only compiled when a native-tls HTTPS feature is enabled +//! AND no rustls feature is enabled (mutual exclusion enforced at module level). + +use std::io; +use std::net::TcpStream; +use std::sync::OnceLock; + +use native_tls::{HandshakeError, TlsConnector, TlsStream}; +#[cfg(feature = "async-https-native-tls")] +use tokio_native_tls::TlsConnector as AsyncTlsConnector; + +use super::HttpStream; +#[cfg(feature = "async-https-native-tls")] +use super::{AsyncHttpStream, AsyncTcpStream}; +use crate::Error; + +// === SYNC native-tls === + +pub type SecuredStream = TlsStream; + +static CONNECTOR: OnceLock> = OnceLock::new(); + +fn native_tls_err(e: HandshakeError) -> Error { + match e { + HandshakeError::Failure(err) => Error::NativeTlsCreateConnection(err), + HandshakeError::WouldBlock(_) => { + debug_assert!(false, "We shouldn't hit a blocking error"); + Error::Other("Got a WouldBlock error from native-tls") + } + } +} + +fn build_tls_connector() -> Result { + TlsConnector::builder().build().map_err(Error::from) +} + +pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { + #[cfg(feature = "log")] + log::trace!("Setting up TLS parameters for {host}."); + + // TODO: Once we can `get_or_try_init`, so that instead + // https://github.com/rust-lang/rust/issues/109737 + let connector = match CONNECTOR.get_or_init(build_tls_connector) { + Ok(c) => c.clone(), + Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), + }; + + #[cfg(feature = "log")] + log::trace!("Establishing TLS session to {host}."); + + let tls = connector.connect(host, tcp).map_err(native_tls_err)?; + + Ok(HttpStream::Secured(Box::new(tls), None)) +} + +// === ASYNC native-tls === + +#[cfg(feature = "async-https-native-tls")] +pub type AsyncSecuredStream = tokio_native_tls::TlsStream; + +#[cfg(feature = "async-https-native-tls")] +pub(super) async fn wrap_async_stream( + tcp: AsyncTcpStream, + host: &str, +) -> Result { + #[cfg(feature = "log")] + log::trace!("Setting up TLS parameters for {host}."); + + // TODO: Once we can `get_or_try_init`, so that instead + // https://github.com/rust-lang/rust/issues/109737 + let sync_connector = match CONNECTOR.get_or_init(build_tls_connector) { + Ok(c) => c.clone(), + Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), + }; + + let async_connector = AsyncTlsConnector::from(sync_connector); + + #[cfg(feature = "log")] + log::trace!("Establishing TLS session to {host}."); + + let tls = async_connector.connect(host, tcp).await?; + + Ok(AsyncHttpStream::Secured(Box::new(tls))) +} diff --git a/bitreq/src/connection/rustls_stream.rs b/bitreq/src/connection/rustls_stream.rs index fc71b4d1c..fef124a0e 100644 --- a/bitreq/src/connection/rustls_stream.rs +++ b/bitreq/src/connection/rustls_stream.rs @@ -1,36 +1,29 @@ -//! TLS connection handling functionality - supports both `rustls` and `native-tls` backends. -//! When both features are enabled, rustls takes precedence. +//! Rustls-based TLS connection handling functionality. #[cfg(feature = "rustls")] use alloc::sync::Arc; -#[cfg(any(feature = "rustls", feature = "https-native-tls"))] +#[cfg(feature = "rustls")] use std::io; use std::net::TcpStream; use std::sync::OnceLock; -#[cfg(all(feature = "https-native-tls", not(feature = "rustls")))] -use native_tls::{HandshakeError, TlsConnector, TlsStream}; #[cfg(feature = "rustls")] use rustls::pki_types::ServerName; #[cfg(feature = "rustls")] use rustls::{self, ClientConfig, ClientConnection, RootCertStore, StreamOwned}; -#[cfg(all(feature = "async-https-native-tls", not(feature = "rustls")))] -use tokio_native_tls::TlsConnector as AsyncTlsConnector; #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] use tokio_rustls::{client::TlsStream, TlsConnector}; #[cfg(feature = "rustls-webpki")] use webpki_roots::TLS_SERVER_ROOTS; -#[cfg(any(feature = "rustls", feature = "https-native-tls"))] +#[cfg(feature = "rustls")] use super::HttpStream; -#[cfg(any( - feature = "async-https-rustls", - feature = "async-https-rustls-probe", - feature = "async-https-native-tls" -))] +#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] use super::{AsyncHttpStream, AsyncTcpStream}; use crate::Error; +// === SYNC rustls === + #[cfg(feature = "rustls")] pub type SecuredStream = StreamOwned; @@ -71,7 +64,7 @@ pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result; @@ -96,73 +89,3 @@ pub(super) async fn wrap_async_stream( Ok(AsyncHttpStream::Secured(Box::new(tls))) } - -#[cfg(all(feature = "https-native-tls", not(feature = "rustls")))] -pub type SecuredStream = TlsStream; - -#[cfg(all(feature = "https-native-tls", not(feature = "rustls")))] -static CONNECTOR: OnceLock> = OnceLock::new(); - -#[cfg(all(feature = "https-native-tls", not(feature = "rustls")))] -fn native_tls_err(e: HandshakeError) -> Error { - match e { - HandshakeError::Failure(err) => Error::NativeTlsCreateConnection(err), - HandshakeError::WouldBlock(_) => { - debug_assert!(false, "We shouldn't hit a blocking error"); - Error::Other("Got a WouldBlock error from native-tls") - } - } -} - -#[cfg(all(feature = "https-native-tls", not(feature = "rustls")))] -fn build_tls_connector() -> Result { - TlsConnector::builder().build().map_err(Error::from) -} - -#[cfg(all(feature = "https-native-tls", not(feature = "rustls")))] -pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { - #[cfg(feature = "log")] - log::trace!("Setting up TLS parameters for {host}."); - - // TODO: Once we can `get_or_try_init`, so that instead - // https://github.com/rust-lang/rust/issues/109737 - let connector = match CONNECTOR.get_or_init(build_tls_connector) { - Ok(c) => c.clone(), - Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), - }; - - #[cfg(feature = "log")] - log::trace!("Establishing TLS session to {host}."); - - let tls = connector.connect(host, tcp).map_err(native_tls_err)?; - - Ok(HttpStream::Secured(Box::new(tls), None)) -} - -#[cfg(all(feature = "async-https-native-tls", not(feature = "rustls")))] -pub type AsyncSecuredStream = tokio_native_tls::TlsStream; - -#[cfg(all(feature = "async-https-native-tls", not(feature = "rustls")))] -pub(super) async fn wrap_async_stream( - tcp: AsyncTcpStream, - host: &str, -) -> Result { - #[cfg(feature = "log")] - log::trace!("Setting up TLS parameters for {host}."); - - // TODO: Once we can `get_or_try_init`, so that instead - // https://github.com/rust-lang/rust/issues/109737 - let sync_connector = match CONNECTOR.get_or_init(build_tls_connector) { - Ok(c) => c.clone(), - Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), - }; - - let async_connector = AsyncTlsConnector::from(sync_connector); - - #[cfg(feature = "log")] - log::trace!("Establishing TLS session to {host}."); - - let tls = async_connector.connect(host, tcp).await?; - - Ok(AsyncHttpStream::Secured(Box::new(tls))) -} From 9cb51d8058548adda428da312504763cb15a4184 Mon Sep 17 00:00:00 2001 From: airton Date: Sat, 25 Apr 2026 11:21:24 +0200 Subject: [PATCH 3/4] refactor: adjust rustls flags --- bitreq/src/connection/rustls_stream.rs | 9 --------- bitreq/src/error.rs | 6 +++--- bitreq/tests/main.rs | 14 ++++++++++---- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/bitreq/src/connection/rustls_stream.rs b/bitreq/src/connection/rustls_stream.rs index fef124a0e..43278de48 100644 --- a/bitreq/src/connection/rustls_stream.rs +++ b/bitreq/src/connection/rustls_stream.rs @@ -1,22 +1,17 @@ //! Rustls-based TLS connection handling functionality. -#[cfg(feature = "rustls")] use alloc::sync::Arc; -#[cfg(feature = "rustls")] use std::io; use std::net::TcpStream; use std::sync::OnceLock; -#[cfg(feature = "rustls")] use rustls::pki_types::ServerName; -#[cfg(feature = "rustls")] use rustls::{self, ClientConfig, ClientConnection, RootCertStore, StreamOwned}; #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] use tokio_rustls::{client::TlsStream, TlsConnector}; #[cfg(feature = "rustls-webpki")] use webpki_roots::TLS_SERVER_ROOTS; -#[cfg(feature = "rustls")] use super::HttpStream; #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] use super::{AsyncHttpStream, AsyncTcpStream}; @@ -24,13 +19,10 @@ use crate::Error; // === SYNC rustls === -#[cfg(feature = "rustls")] pub type SecuredStream = StreamOwned; -#[cfg(feature = "rustls")] static CONFIG: OnceLock> = OnceLock::new(); -#[cfg(feature = "rustls")] fn build_client_config() -> Arc { let mut root_certificates = RootCertStore::empty(); @@ -47,7 +39,6 @@ fn build_client_config() -> Arc { Arc::new(config) } -#[cfg(feature = "rustls")] pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { #[cfg(feature = "log")] log::trace!("Setting up TLS parameters for {host}."); diff --git a/bitreq/src/error.rs b/bitreq/src/error.rs index 9a2b3c7a6..62bbd0959 100644 --- a/bitreq/src/error.rs +++ b/bitreq/src/error.rs @@ -19,7 +19,7 @@ pub enum Error { /// The response body contains invalid UTF-8, so the `as_str()` /// conversion failed. InvalidUtf8InBody(str::Utf8Error), - #[cfg(feature = "rustls")] + #[cfg(any(feature = "https-rustls", feature = "https-rustls-probe"))] /// Ran into a rustls error while creating the connection. RustlsCreateConnection(rustls::Error), #[cfg(feature = "https-native-tls")] @@ -102,7 +102,7 @@ impl fmt::Display for Error { IoError(err) => write!(f, "{}", err), InvalidUrl(err) => write!(f, "failed to parse given URL: {}", err), InvalidUtf8InBody(err) => write!(f, "{}", err), - #[cfg(feature = "rustls")] + #[cfg(any(feature = "https-rustls", feature = "https-rustls-probe"))] RustlsCreateConnection(err) => write!(f, "error creating rustls connection: {}", err), #[cfg(feature = "https-native-tls")] NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {}", err), @@ -145,7 +145,7 @@ impl error::Error for Error { IoError(err) => Some(err), InvalidUrl(err) => Some(err), InvalidUtf8InBody(err) => Some(err), - #[cfg(feature = "rustls")] + #[cfg(any(feature = "https-rustls", feature = "https-rustls-probe"))] RustlsCreateConnection(err) => Some(err), _ => None, } diff --git a/bitreq/tests/main.rs b/bitreq/tests/main.rs index 2dae034f6..e7e204e94 100644 --- a/bitreq/tests/main.rs +++ b/bitreq/tests/main.rs @@ -8,7 +8,7 @@ use std::io; use self::setup::*; #[tokio::test] -#[cfg(feature = "rustls")] +#[cfg(any(feature = "https-rustls", feature = "https-rustls-probe"))] async fn test_https() { // TODO: Implement this locally. assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); @@ -17,7 +17,10 @@ async fn test_https() { } #[tokio::test] -#[cfg(all(feature = "async-https-native-tls", not(feature = "rustls")))] +#[cfg(all( + feature = "async-https-native-tls", + not(any(feature = "https-rustls", feature = "https-rustls-probe")) +))] async fn test_https() { // TODO: Implement this locally. assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); @@ -26,7 +29,7 @@ async fn test_https() { } #[tokio::test] -#[cfg(all(feature = "rustls", feature = "tokio-rustls"))] +#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] async fn test_https_with_client() { setup(); let client = bitreq::Client::new(1); @@ -35,7 +38,10 @@ async fn test_https_with_client() { } #[tokio::test] -#[cfg(all(feature = "async-https-native-tls", not(feature = "rustls")))] +#[cfg(all( + feature = "async-https-native-tls", + not(any(feature = "https-rustls", feature = "https-rustls-probe")) +))] async fn test_https_with_client() { setup(); let client = bitreq::Client::new(1); From 3fe473b17460d4d78f7ec4fea08d06d5519d7726 Mon Sep 17 00:00:00 2001 From: airton Date: Sat, 25 Apr 2026 20:30:13 +0200 Subject: [PATCH 4/4] refactor: adjust sync/async tls streams and cfg gates --- bitreq/src/connection.rs | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/bitreq/src/connection.rs b/bitreq/src/connection.rs index ee226295d..e3c0a36ae 100644 --- a/bitreq/src/connection.rs +++ b/bitreq/src/connection.rs @@ -29,12 +29,23 @@ use crate::{Error, Method, ResponseLazy}; type UnsecuredStream = TcpStream; -#[cfg(any(feature = "https-rustls", feature = "https-rustls-probe"))] +#[cfg(any( + feature = "https-rustls", + feature = "https-rustls-probe", + feature = "async-https-rustls", + feature = "async-https-rustls-probe" +))] mod rustls_stream; -#[cfg(all( - feature = "https-native-tls", - not(any(feature = "https-rustls", feature = "https-rustls-probe")) +#[cfg(any( + all( + feature = "https-native-tls", + not(any(feature = "https-rustls", feature = "https-rustls-probe")) + ), + all( + feature = "async-https-native-tls", + not(any(feature = "async-https-rustls", feature = "async-https-rustls-probe")) + ) ))] mod native_tls_stream; @@ -42,12 +53,19 @@ mod native_tls_stream; feature = "https-native-tls", not(any(feature = "https-rustls", feature = "https-rustls-probe")) ))] -use self::native_tls_stream as tls_stream; +use self::native_tls_stream as sync_tls_stream; +#[cfg(all( + feature = "async-https-native-tls", + not(any(feature = "async-https-rustls", feature = "async-https-rustls-probe")) +))] +use self::native_tls_stream as async_tls_stream; #[cfg(any(feature = "https-rustls", feature = "https-rustls-probe"))] -use self::rustls_stream as tls_stream; +use self::rustls_stream as sync_tls_stream; +#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] +use self::rustls_stream as async_tls_stream; #[cfg(any(feature = "https-rustls", feature = "https-rustls-probe", feature = "https-native-tls"))] -type SecuredStream = tls_stream::SecuredStream; +type SecuredStream = sync_tls_stream::SecuredStream; pub(crate) enum HttpStream { Unsecured(UnsecuredStream, Option), @@ -194,7 +212,7 @@ impl Write for HttpStream { feature = "async-https-rustls-probe", feature = "async-https-native-tls" ))] -type AsyncSecuredStream = tls_stream::AsyncSecuredStream; +type AsyncSecuredStream = async_tls_stream::AsyncSecuredStream; #[cfg(feature = "async")] pub(crate) enum AsyncHttpStream { @@ -359,7 +377,7 @@ impl AsyncConnection { socket: AsyncTcpStream, host: &str, ) -> Result { - tls_stream::wrap_async_stream(socket, host).await + async_tls_stream::wrap_async_stream(socket, host).await } /// Error treatment function, should not be called under normal circustances @@ -741,7 +759,7 @@ impl Connection { feature = "https-rustls-probe", feature = "https-native-tls" ))] - tls_stream::wrap_stream(socket, params.host)? + sync_tls_stream::wrap_stream(socket, params.host)? } else { HttpStream::create_unsecured(socket, timeout_at) };