diff --git a/bitreq/src/connection.rs b/bitreq/src/connection.rs index b323c6af3..e3c0a36ae 100644 --- a/bitreq/src/connection.rs +++ b/bitreq/src/connection.rs @@ -29,14 +29,51 @@ use crate::{Error, Method, ResponseLazy}; type UnsecuredStream = TcpStream; -#[cfg(feature = "rustls")] +#[cfg(any( + feature = "https-rustls", + feature = "https-rustls-probe", + feature = "async-https-rustls", + feature = "async-https-rustls-probe" +))] mod rustls_stream; -#[cfg(feature = "rustls")] -type SecuredStream = rustls_stream::SecuredStream; + +#[cfg(any( + all( + feature = "https-native-tls", + not(any(feature = "https-rustls", feature = "https-rustls-probe")) + ), + all( + feature = "async-https-native-tls", + not(any(feature = "async-https-rustls", feature = "async-https-rustls-probe")) + ) +))] +mod native_tls_stream; + +#[cfg(all( + feature = "https-native-tls", + not(any(feature = "https-rustls", feature = "https-rustls-probe")) +))] +use self::native_tls_stream as sync_tls_stream; +#[cfg(all( + feature = "async-https-native-tls", + not(any(feature = "async-https-rustls", feature = "async-https-rustls-probe")) +))] +use self::native_tls_stream as async_tls_stream; +#[cfg(any(feature = "https-rustls", feature = "https-rustls-probe"))] +use self::rustls_stream as sync_tls_stream; +#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] +use self::rustls_stream as async_tls_stream; + +#[cfg(any(feature = "https-rustls", feature = "https-rustls-probe", feature = "https-native-tls"))] +type SecuredStream = sync_tls_stream::SecuredStream; pub(crate) enum HttpStream { Unsecured(UnsecuredStream, Option), - #[cfg(feature = "rustls")] + #[cfg(any( + feature = "https-rustls", + feature = "https-rustls-probe", + feature = "https-native-tls" + ))] Secured(Box, Option), #[cfg(feature = "async")] Buffer(std::io::Cursor>), @@ -81,7 +118,11 @@ impl Read for HttpStream { timeout(inner, *timeout_at)?; inner.read(buf) } - #[cfg(feature = "rustls")] + #[cfg(any( + feature = "https-rustls", + feature = "https-rustls-probe", + feature = "https-native-tls" + ))] HttpStream::Secured(inner, timeout_at) => { timeout(inner.get_ref(), *timeout_at)?; inner.read(buf) @@ -111,7 +152,11 @@ impl Write for HttpStream { set_socket_write_timeout(inner, *timeout_at)?; inner.write(buf) } - #[cfg(feature = "rustls")] + #[cfg(any( + feature = "https-rustls", + feature = "https-rustls-probe", + feature = "https-native-tls" + ))] HttpStream::Secured(inner, timeout_at) => { set_socket_write_timeout(inner.get_ref(), *timeout_at)?; inner.write(buf) @@ -137,7 +182,11 @@ impl Write for HttpStream { set_socket_write_timeout(inner, *timeout_at)?; inner.flush() } - #[cfg(feature = "rustls")] + #[cfg(any( + feature = "https-rustls", + feature = "https-rustls-probe", + feature = "https-native-tls" + ))] HttpStream::Secured(inner, timeout_at) => { set_socket_write_timeout(inner.get_ref(), *timeout_at)?; inner.flush() @@ -158,13 +207,21 @@ impl Write for HttpStream { } } -#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] -type AsyncSecuredStream = rustls_stream::AsyncSecuredStream; +#[cfg(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" +))] +type AsyncSecuredStream = async_tls_stream::AsyncSecuredStream; #[cfg(feature = "async")] pub(crate) enum AsyncHttpStream { Unsecured(AsyncTcpStream), - #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] + #[cfg(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" + ))] Secured(Box), } @@ -177,7 +234,11 @@ impl AsyncRead for AsyncHttpStream { ) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_read(cx, buf), - #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] + #[cfg(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" + ))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_read(cx, buf), } } @@ -192,7 +253,11 @@ impl AsyncWrite for AsyncHttpStream { ) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_write(cx, buf), - #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] + #[cfg(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" + ))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_write(cx, buf), } } @@ -200,7 +265,11 @@ impl AsyncWrite for AsyncHttpStream { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_flush(cx), - #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] + #[cfg(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" + ))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_flush(cx), } } @@ -208,7 +277,11 @@ impl AsyncWrite for AsyncHttpStream { fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_shutdown(cx), - #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] + #[cfg(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" + ))] AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_shutdown(cx), } } @@ -271,13 +344,7 @@ impl AsyncConnection { let socket = Self::connect(params).await?; if params.https { - #[cfg(not(any( - feature = "async-https-rustls", - feature = "async-https-rustls-probe" - )))] - return Err(Error::HttpsFeatureNotEnabled); - #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] - rustls_stream::wrap_async_stream(socket, params.host).await + Self::wrap_async_stream(socket, params.host).await } else { Ok(AsyncHttpStream::Unsecured(socket)) } @@ -301,6 +368,31 @@ impl AsyncConnection { })))) } + #[cfg(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" + ))] + async fn wrap_async_stream( + socket: AsyncTcpStream, + host: &str, + ) -> Result { + async_tls_stream::wrap_async_stream(socket, host).await + } + + /// Error treatment function, should not be called under normal circustances + #[cfg(not(any( + feature = "async-https-rustls", + feature = "async-https-rustls-probe", + feature = "async-https-native-tls" + )))] + async fn wrap_async_stream( + _socket: AsyncTcpStream, + _host: &str, + ) -> Result { + Err(Error::HttpsFeatureNotEnabled) + } + async fn tcp_connect(host: &str, port: u16) -> Result { #[cfg(feature = "log")] log::trace!("Looking up host {host}"); @@ -656,13 +748,18 @@ impl Connection { let socket = Self::connect(params, timeout_at)?; let stream = if params.https { - #[cfg(not(feature = "rustls"))] + #[cfg(not(any( + feature = "https-rustls", + feature = "https-rustls-probe", + feature = "https-native-tls" + )))] return Err(Error::HttpsFeatureNotEnabled); - #[cfg(feature = "rustls")] - { - let tls = rustls_stream::wrap_stream(socket, params.host)?; - HttpStream::Secured(Box::new(tls), timeout_at) - } + #[cfg(any( + feature = "https-rustls", + feature = "https-rustls-probe", + feature = "https-native-tls" + ))] + sync_tls_stream::wrap_stream(socket, params.host)? } else { HttpStream::create_unsecured(socket, timeout_at) }; diff --git a/bitreq/src/connection/native_tls_stream.rs b/bitreq/src/connection/native_tls_stream.rs new file mode 100644 index 000000000..55142f652 --- /dev/null +++ b/bitreq/src/connection/native_tls_stream.rs @@ -0,0 +1,85 @@ +//! Native-TLS connection handling functionality. +//! This module is only compiled when a native-tls HTTPS feature is enabled +//! AND no rustls feature is enabled (mutual exclusion enforced at module level). + +use std::io; +use std::net::TcpStream; +use std::sync::OnceLock; + +use native_tls::{HandshakeError, TlsConnector, TlsStream}; +#[cfg(feature = "async-https-native-tls")] +use tokio_native_tls::TlsConnector as AsyncTlsConnector; + +use super::HttpStream; +#[cfg(feature = "async-https-native-tls")] +use super::{AsyncHttpStream, AsyncTcpStream}; +use crate::Error; + +// === SYNC native-tls === + +pub type SecuredStream = TlsStream; + +static CONNECTOR: OnceLock> = OnceLock::new(); + +fn native_tls_err(e: HandshakeError) -> Error { + match e { + HandshakeError::Failure(err) => Error::NativeTlsCreateConnection(err), + HandshakeError::WouldBlock(_) => { + debug_assert!(false, "We shouldn't hit a blocking error"); + Error::Other("Got a WouldBlock error from native-tls") + } + } +} + +fn build_tls_connector() -> Result { + TlsConnector::builder().build().map_err(Error::from) +} + +pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { + #[cfg(feature = "log")] + log::trace!("Setting up TLS parameters for {host}."); + + // TODO: Once we can `get_or_try_init`, so that instead + // https://github.com/rust-lang/rust/issues/109737 + let connector = match CONNECTOR.get_or_init(build_tls_connector) { + Ok(c) => c.clone(), + Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), + }; + + #[cfg(feature = "log")] + log::trace!("Establishing TLS session to {host}."); + + let tls = connector.connect(host, tcp).map_err(native_tls_err)?; + + Ok(HttpStream::Secured(Box::new(tls), None)) +} + +// === ASYNC native-tls === + +#[cfg(feature = "async-https-native-tls")] +pub type AsyncSecuredStream = tokio_native_tls::TlsStream; + +#[cfg(feature = "async-https-native-tls")] +pub(super) async fn wrap_async_stream( + tcp: AsyncTcpStream, + host: &str, +) -> Result { + #[cfg(feature = "log")] + log::trace!("Setting up TLS parameters for {host}."); + + // TODO: Once we can `get_or_try_init`, so that instead + // https://github.com/rust-lang/rust/issues/109737 + let sync_connector = match CONNECTOR.get_or_init(build_tls_connector) { + Ok(c) => c.clone(), + Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), + }; + + let async_connector = AsyncTlsConnector::from(sync_connector); + + #[cfg(feature = "log")] + log::trace!("Establishing TLS session to {host}."); + + let tls = async_connector.connect(host, tcp).await?; + + Ok(AsyncHttpStream::Secured(Box::new(tls))) +} diff --git a/bitreq/src/connection/rustls_stream.rs b/bitreq/src/connection/rustls_stream.rs index 62602b0f6..43278de48 100644 --- a/bitreq/src/connection/rustls_stream.rs +++ b/bitreq/src/connection/rustls_stream.rs @@ -1,38 +1,28 @@ -//! TLS connection handling functionality - supports both `rustls` and `native-tls` backends. -//! When both features are enabled, rustls takes precedence. +//! Rustls-based TLS connection handling functionality. -#[cfg(feature = "rustls")] use alloc::sync::Arc; use std::io; use std::net::TcpStream; use std::sync::OnceLock; -#[cfg(all(feature = "native-tls", not(feature = "rustls")))] -use native_tls::{HandshakeError, TlsConnector, TlsStream}; -#[cfg(feature = "rustls")] use rustls::pki_types::ServerName; -#[cfg(feature = "rustls")] use rustls::{self, ClientConfig, ClientConnection, RootCertStore, StreamOwned}; -#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] -use tokio_native_tls::TlsConnector as AsyncTlsConnector; #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] use tokio_rustls::{client::TlsStream, TlsConnector}; #[cfg(feature = "rustls-webpki")] use webpki_roots::TLS_SERVER_ROOTS; +use super::HttpStream; #[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] use super::{AsyncHttpStream, AsyncTcpStream}; -#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] -use super::{AsyncHttpStream, AsyncTcpStream}; use crate::Error; -#[cfg(feature = "rustls")] +// === SYNC rustls === + pub type SecuredStream = StreamOwned; -#[cfg(feature = "rustls")] static CONFIG: OnceLock> = OnceLock::new(); -#[cfg(feature = "rustls")] fn build_client_config() -> Arc { let mut root_certificates = RootCertStore::empty(); @@ -49,8 +39,7 @@ fn build_client_config() -> Arc { Arc::new(config) } -#[cfg(feature = "rustls")] -pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { +pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { #[cfg(feature = "log")] log::trace!("Setting up TLS parameters for {host}."); let dns_name = ServerName::try_from(host) @@ -58,13 +47,15 @@ pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result; @@ -89,63 +80,3 @@ pub(super) async fn wrap_async_stream( Ok(AsyncHttpStream::Secured(Box::new(tls))) } - -#[cfg(all(feature = "native-tls", not(feature = "rustls")))] -pub type SecuredStream = TlsStream; - -#[cfg(all(feature = "native-tls", not(feature = "rustls")))] -static CONNECTOR: OnceLock> = OnceLock::new(); - -#[cfg(all(feature = "native-tls", not(feature = "rustls")))] -fn native_tls_err(e: HandshakeError) -> Error { - match e { - HandshakeError::Failure(e) => Error::NativeTlsError(e), - HandshakeError::WouldBlock(_) => { - debug_assert!(false, "We shouldn't hit a blocking error"); - Error::Other("Got a WouldBlock error from native-tls") - } - } -} - -#[cfg(all(feature = "native-tls", not(feature = "rustls")))] -fn build_tls_connector() -> Result { - TlsConnector::builder().build().map_err(Error::NativeTlsError) -} - -#[cfg(all(feature = "native-tls", not(feature = "rustls")))] -pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result { - #[cfg(feature = "log")] - log::trace!("Setting up TLS parameters for {host}."); - - // TODO: Once we can `get_or_try_init`, so that instead - // https://github.com/rust-lang/rust/issues/109737 - let connector = CONNECTOR.get_or_init(build_tls_connector)?; - - #[cfg(feature = "log")] - log::trace!("Establishing TLS session to {host}."); - - connector.connect(host, tcp).map_err(native_tls_err) -} - -#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] -pub type AsyncSecuredStream = tokio_native_tls::TlsStream; - -#[cfg(all(feature = "native-tls", not(feature = "rustls"), feature = "tokio-native-tls"))] -pub(super) async fn wrap_async_stream( - tcp: AsyncTcpStream, - host: &str, -) -> Result { - #[cfg(feature = "log")] - log::trace!("Setting up TLS parameters for {host}."); - - // TODO: Once we can `get_or_try_init`, so that instead - // https://github.com/rust-lang/rust/issues/109737 - let connector = AsyncTlsConnector::from(CONNECTOR.get_or_init(build_tls_connector)?.clone()); - - #[cfg(feature = "log")] - log::trace!("Establishing TLS session to {host}."); - - let tls = connector.connect(host, tcp).await.map_err(native_tls_err)?; - - Ok(AsyncHttpStream::Secured(Box::new(tls))) -} diff --git a/bitreq/src/error.rs b/bitreq/src/error.rs index ca9d1421d..62bbd0959 100644 --- a/bitreq/src/error.rs +++ b/bitreq/src/error.rs @@ -19,10 +19,10 @@ pub enum Error { /// The response body contains invalid UTF-8, so the `as_str()` /// conversion failed. InvalidUtf8InBody(str::Utf8Error), - #[cfg(feature = "rustls")] + #[cfg(any(feature = "https-rustls", feature = "https-rustls-probe"))] /// Ran into a rustls error while creating the connection. RustlsCreateConnection(rustls::Error), - #[cfg(feature = "native-tls")] + #[cfg(feature = "https-native-tls")] /// Ran into a native-tls error while creating the connection. NativeTlsCreateConnection(native_tls::Error), /// Ran into an IO problem while loading the response. @@ -102,10 +102,10 @@ impl fmt::Display for Error { IoError(err) => write!(f, "{}", err), InvalidUrl(err) => write!(f, "failed to parse given URL: {}", err), InvalidUtf8InBody(err) => write!(f, "{}", err), - #[cfg(feature = "rustls")] + #[cfg(any(feature = "https-rustls", feature = "https-rustls-probe"))] RustlsCreateConnection(err) => write!(f, "error creating rustls connection: {}", err), - #[cfg(feature = "native-tls")] - NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {err}"), + #[cfg(feature = "https-native-tls")] + NativeTlsCreateConnection(err) => write!(f, "error creating native-tls connection: {}", err), MalformedChunkLength => write!(f, "non-usize chunk length with transfer-encoding: chunked"), MalformedChunkEnd => write!(f, "chunk did not end after reading the expected amount of bytes"), MalformedContentLength => write!(f, "non-usize content length"), @@ -145,7 +145,7 @@ impl error::Error for Error { IoError(err) => Some(err), InvalidUrl(err) => Some(err), InvalidUtf8InBody(err) => Some(err), - #[cfg(feature = "rustls")] + #[cfg(any(feature = "https-rustls", feature = "https-rustls-probe"))] RustlsCreateConnection(err) => Some(err), _ => None, } @@ -160,3 +160,8 @@ impl From for Error { impl From for Error { fn from(other: UrlParseError) -> Error { Error::InvalidUrl(other) } } + +#[cfg(feature = "https-native-tls")] +impl From for Error { + fn from(err: native_tls::Error) -> Error { Error::NativeTlsCreateConnection(err) } +} diff --git a/bitreq/tests/main.rs b/bitreq/tests/main.rs index 8d357f354..e7e204e94 100644 --- a/bitreq/tests/main.rs +++ b/bitreq/tests/main.rs @@ -8,7 +8,7 @@ use std::io; use self::setup::*; #[tokio::test] -#[cfg(feature = "rustls")] +#[cfg(any(feature = "https-rustls", feature = "https-rustls-probe"))] async fn test_https() { // TODO: Implement this locally. assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); @@ -16,6 +16,39 @@ async fn test_https() { assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); } +#[tokio::test] +#[cfg(all( + feature = "async-https-native-tls", + not(any(feature = "https-rustls", feature = "https-rustls-probe")) +))] +async fn test_https() { + // TODO: Implement this locally. + assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); + // Test reusing the existing connection in client: + assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); +} + +#[tokio::test] +#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))] +async fn test_https_with_client() { + setup(); + let client = bitreq::Client::new(1); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + +#[tokio::test] +#[cfg(all( + feature = "async-https-native-tls", + not(any(feature = "https-rustls", feature = "https-rustls-probe")) +))] +async fn test_https_with_client() { + setup(); + let client = bitreq::Client::new(1); + let response = client.send_async(bitreq::get("https://example.com")).await.unwrap(); + assert_eq!(response.status_code, 200); +} + #[tokio::test] #[cfg(feature = "json-using-serde")] async fn test_json_using_serde() {