Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 124 additions & 27 deletions bitreq/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,51 @@ use crate::{Error, Method, ResponseLazy};

type UnsecuredStream = TcpStream;

#[cfg(feature = "rustls")]
#[cfg(any(
feature = "https-rustls",
feature = "https-rustls-probe",
feature = "async-https-rustls",
feature = "async-https-rustls-probe"
))]
mod rustls_stream;
#[cfg(feature = "rustls")]
type SecuredStream = rustls_stream::SecuredStream;

#[cfg(any(
all(
feature = "https-native-tls",
not(any(feature = "https-rustls", feature = "https-rustls-probe"))
),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we intend to allow selection of a separate sync-tls-provider and async one.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is your suggestion?
The alternatives I came up were either convoluted or do not pass in tests like this one:

cargo --locked build --no-default-features "--features=async-https-native-tls https async-https-rustls native-tls rustls rustls-webpki std tokio webpki-roots"

I was using this and it failed:

    #[cfg[all(
        any(feature = "async-https-native-tls", feature = "https-native-tls"),
        not(any(feature = "https-rustls", feature = "https-rustls-probe"))
    )]

Because it was opening all async gates(async backend present) but rustls priority was taking over.

The simplest solution I can think of without a large rework on all async gates(many) would be just ignore sync on priorities when an async feature is on. But that would change in behavior we intended(rustls priority over native-tls). Otherwise we could simply not compile, but that's again a change in behavior.

But, I bet most of the users won't mix with both configs at the same time. So if they set a async feature it will be set for both sync/async. The gate was necessary to work on edge cases like the one above..

all(
feature = "async-https-native-tls",
not(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))
)
))]
mod native_tls_stream;

#[cfg(all(
feature = "https-native-tls",
not(any(feature = "https-rustls", feature = "https-rustls-probe"))
))]
use self::native_tls_stream as sync_tls_stream;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its probably simpler to just have a sub-module in rustls_stream (maybe renaming it) and re-export. That way there's a single API to connection.rs and we arent conditional-compiling which module to use here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

really? using a single file to store code for both backends is a missing opportunity IMO. You mean like the enabled/disabled stuff I've initially tough for the tls configs on client.rs? That's doable, but I don't see what's are we winning here.

#[cfg(all(
feature = "async-https-native-tls",
not(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))
))]
use self::native_tls_stream as async_tls_stream;
#[cfg(any(feature = "https-rustls", feature = "https-rustls-probe"))]
use self::rustls_stream as sync_tls_stream;
#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))]
use self::rustls_stream as async_tls_stream;

#[cfg(any(feature = "https-rustls", feature = "https-rustls-probe", feature = "https-native-tls"))]
type SecuredStream = sync_tls_stream::SecuredStream;

pub(crate) enum HttpStream {
Unsecured(UnsecuredStream, Option<Instant>),
#[cfg(feature = "rustls")]
#[cfg(any(
feature = "https-rustls",
feature = "https-rustls-probe",
feature = "https-native-tls"
))]
Secured(Box<SecuredStream>, Option<Instant>),
#[cfg(feature = "async")]
Buffer(std::io::Cursor<Vec<u8>>),
Expand Down Expand Up @@ -81,7 +118,11 @@ impl Read for HttpStream {
timeout(inner, *timeout_at)?;
inner.read(buf)
}
#[cfg(feature = "rustls")]
#[cfg(any(
feature = "https-rustls",
feature = "https-rustls-probe",
feature = "https-native-tls"
))]
HttpStream::Secured(inner, timeout_at) => {
timeout(inner.get_ref(), *timeout_at)?;
inner.read(buf)
Expand Down Expand Up @@ -111,7 +152,11 @@ impl Write for HttpStream {
set_socket_write_timeout(inner, *timeout_at)?;
inner.write(buf)
}
#[cfg(feature = "rustls")]
#[cfg(any(
feature = "https-rustls",
feature = "https-rustls-probe",
feature = "https-native-tls"
))]
HttpStream::Secured(inner, timeout_at) => {
set_socket_write_timeout(inner.get_ref(), *timeout_at)?;
inner.write(buf)
Expand All @@ -137,7 +182,11 @@ impl Write for HttpStream {
set_socket_write_timeout(inner, *timeout_at)?;
inner.flush()
}
#[cfg(feature = "rustls")]
#[cfg(any(
feature = "https-rustls",
feature = "https-rustls-probe",
feature = "https-native-tls"
))]
HttpStream::Secured(inner, timeout_at) => {
set_socket_write_timeout(inner.get_ref(), *timeout_at)?;
inner.flush()
Expand All @@ -158,13 +207,21 @@ impl Write for HttpStream {
}
}

#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))]
type AsyncSecuredStream = rustls_stream::AsyncSecuredStream;
#[cfg(any(
feature = "async-https-rustls",
feature = "async-https-rustls-probe",
feature = "async-https-native-tls"
))]
type AsyncSecuredStream = async_tls_stream::AsyncSecuredStream;

#[cfg(feature = "async")]
pub(crate) enum AsyncHttpStream {
Unsecured(AsyncTcpStream),
#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))]
#[cfg(any(
feature = "async-https-rustls",
feature = "async-https-rustls-probe",
feature = "async-https-native-tls"
))]
Secured(Box<AsyncSecuredStream>),
}

Expand All @@ -177,7 +234,11 @@ impl AsyncRead for AsyncHttpStream {
) -> Poll<io::Result<()>> {
match &mut *self {
AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_read(cx, buf),
#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))]
#[cfg(any(
feature = "async-https-rustls",
feature = "async-https-rustls-probe",
feature = "async-https-native-tls"
))]
AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_read(cx, buf),
}
}
Expand All @@ -192,23 +253,35 @@ impl AsyncWrite for AsyncHttpStream {
) -> Poll<io::Result<usize>> {
match &mut *self {
AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_write(cx, buf),
#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))]
#[cfg(any(
feature = "async-https-rustls",
feature = "async-https-rustls-probe",
feature = "async-https-native-tls"
))]
AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_write(cx, buf),
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_flush(cx),
#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))]
#[cfg(any(
feature = "async-https-rustls",
feature = "async-https-rustls-probe",
feature = "async-https-native-tls"
))]
AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_flush(cx),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
AsyncHttpStream::Unsecured(inner) => Pin::new(inner).poll_shutdown(cx),
#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))]
#[cfg(any(
feature = "async-https-rustls",
feature = "async-https-rustls-probe",
feature = "async-https-native-tls"
))]
AsyncHttpStream::Secured(inner) => Pin::new(inner).poll_shutdown(cx),
}
}
Expand Down Expand Up @@ -271,13 +344,7 @@ impl AsyncConnection {
let socket = Self::connect(params).await?;

if params.https {
#[cfg(not(any(
feature = "async-https-rustls",
feature = "async-https-rustls-probe"
)))]
return Err(Error::HttpsFeatureNotEnabled);
#[cfg(any(feature = "async-https-rustls", feature = "async-https-rustls-probe"))]
rustls_stream::wrap_async_stream(socket, params.host).await
Self::wrap_async_stream(socket, params.host).await
} else {
Ok(AsyncHttpStream::Unsecured(socket))
}
Expand All @@ -301,6 +368,31 @@ impl AsyncConnection {
}))))
}

#[cfg(any(
feature = "async-https-rustls",
feature = "async-https-rustls-probe",
feature = "async-https-native-tls"
))]
async fn wrap_async_stream(
socket: AsyncTcpStream,
host: &str,
) -> Result<AsyncHttpStream, Error> {
async_tls_stream::wrap_async_stream(socket, host).await
}

/// Error treatment function, should not be called under normal circustances
#[cfg(not(any(
feature = "async-https-rustls",
feature = "async-https-rustls-probe",
feature = "async-https-native-tls"
)))]
async fn wrap_async_stream(
_socket: AsyncTcpStream,
_host: &str,
) -> Result<AsyncHttpStream, Error> {
Err(Error::HttpsFeatureNotEnabled)
}

async fn tcp_connect(host: &str, port: u16) -> Result<AsyncTcpStream, Error> {
#[cfg(feature = "log")]
log::trace!("Looking up host {host}");
Expand Down Expand Up @@ -656,13 +748,18 @@ impl Connection {
let socket = Self::connect(params, timeout_at)?;

let stream = if params.https {
#[cfg(not(feature = "rustls"))]
#[cfg(not(any(
feature = "https-rustls",
feature = "https-rustls-probe",
feature = "https-native-tls"
)))]
return Err(Error::HttpsFeatureNotEnabled);
#[cfg(feature = "rustls")]
{
let tls = rustls_stream::wrap_stream(socket, params.host)?;
HttpStream::Secured(Box::new(tls), timeout_at)
}
#[cfg(any(
feature = "https-rustls",
feature = "https-rustls-probe",
feature = "https-native-tls"
))]
sync_tls_stream::wrap_stream(socket, params.host)?
} else {
HttpStream::create_unsecured(socket, timeout_at)
};
Expand Down
85 changes: 85 additions & 0 deletions bitreq/src/connection/native_tls_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
//! Native-TLS connection handling functionality.
//! This module is only compiled when a native-tls HTTPS feature is enabled
//! AND no rustls feature is enabled (mutual exclusion enforced at module level).

use std::io;
use std::net::TcpStream;
use std::sync::OnceLock;

use native_tls::{HandshakeError, TlsConnector, TlsStream};
#[cfg(feature = "async-https-native-tls")]
use tokio_native_tls::TlsConnector as AsyncTlsConnector;

use super::HttpStream;
#[cfg(feature = "async-https-native-tls")]
use super::{AsyncHttpStream, AsyncTcpStream};
use crate::Error;

// === SYNC native-tls ===

pub type SecuredStream = TlsStream<TcpStream>;

static CONNECTOR: OnceLock<Result<TlsConnector, Error>> = OnceLock::new();

fn native_tls_err<S>(e: HandshakeError<S>) -> Error {
match e {
HandshakeError::Failure(err) => Error::NativeTlsCreateConnection(err),
HandshakeError::WouldBlock(_) => {
debug_assert!(false, "We shouldn't hit a blocking error");
Error::Other("Got a WouldBlock error from native-tls")
}
}
}

fn build_tls_connector() -> Result<TlsConnector, Error> {
TlsConnector::builder().build().map_err(Error::from)
}

pub(super) fn wrap_stream(tcp: TcpStream, host: &str) -> Result<HttpStream, Error> {
#[cfg(feature = "log")]
log::trace!("Setting up TLS parameters for {host}.");

// TODO: Once we can `get_or_try_init`, so that instead
// https://github.com/rust-lang/rust/issues/109737
let connector = match CONNECTOR.get_or_init(build_tls_connector) {
Ok(c) => c.clone(),
Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))),
};

#[cfg(feature = "log")]
log::trace!("Establishing TLS session to {host}.");

let tls = connector.connect(host, tcp).map_err(native_tls_err)?;

Ok(HttpStream::Secured(Box::new(tls), None))
}

// === ASYNC native-tls ===
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't commit claude's useless code section definition comments for a tiny file.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TheBlueMatt come on, do you even read what you write? I'm trying to help here but it's been really hard to work with you.

You commited the same type o comment in the master. this one is no different:
https://github.com/rust-bitcoin/corepc/blame/46c7bf4284748ee3da549d4463841e67db3bd00e/bitreq/src/connection/rustls_stream.rs#L69
image

I don't even use the fking Claude. I'm spending hours on all those pull requests. Come on. Let's be a bit more professionals here.

I agree 100% about removing all those useless comments, but this is the useless pattern you've already used there..


#[cfg(feature = "async-https-native-tls")]
pub type AsyncSecuredStream = tokio_native_tls::TlsStream<tokio::net::TcpStream>;

#[cfg(feature = "async-https-native-tls")]
pub(super) async fn wrap_async_stream(
tcp: AsyncTcpStream,
host: &str,
) -> Result<AsyncHttpStream, Error> {
#[cfg(feature = "log")]
log::trace!("Setting up TLS parameters for {host}.");

// TODO: Once we can `get_or_try_init`, so that instead
// https://github.com/rust-lang/rust/issues/109737
let sync_connector = match CONNECTOR.get_or_init(build_tls_connector) {
Ok(c) => c.clone(),
Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))),
};

let async_connector = AsyncTlsConnector::from(sync_connector);

#[cfg(feature = "log")]
log::trace!("Establishing TLS session to {host}.");

let tls = async_connector.connect(host, tcp).await?;

Ok(AsyncHttpStream::Secured(Box::new(tls)))
}
Loading
Loading