113 changes: 84 additions & 29 deletions bitreq/src/client.rs
@@ -8,6 +8,7 @@

use std::collections::{hash_map, HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::Instant;

use crate::connection::AsyncConnection;
use crate::request::{OwnedConnectionParams as ConnectionKey, ParsedRequest};
@@ -16,7 +17,10 @@ use crate::{Error, Request, Response};
/// A client that caches connections for reuse.
///
/// The client maintains a pool of up to `capacity` connections, evicting
/// the least recently used connection when the cache is full.
/// the least recently used connection when the cache is full. Pooled
/// connections are validated on every acquire: an entry whose keep-alive
/// deadline has passed — or whose underlying socket has been poisoned by
/// a previous failure — is dropped and a fresh connection is opened.
///
/// # Example
///
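/// (The example body is collapsed in this diff view; the sketch below is
/// reconstructed from the API the PR's own tests exercise, namely
/// `Client::new(capacity)`, `bitreq::get`, and `send_async`, and may
/// differ from the crate's actual doc example.)
///
/// ```no_run
/// # async fn run() -> Result<(), bitreq::Error> {
/// // Pool up to two connections; a second request to the same
/// // host:port pair reuses the cached socket instead of reconnecting.
/// let client = bitreq::Client::new(2);
/// let response = client.send_async(bitreq::get("http://example.com")).await?;
/// assert_eq!(response.status_code, 200);
/// # Ok(())
/// # }
/// ```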
@@ -62,39 +66,90 @@ impl Client {
pub async fn send_async(&self, request: Request) -> Result<Response, Error> {
let parsed_request = ParsedRequest::new(request)?;
let key = parsed_request.connection_params();
let owned_key = key.into();
let owned_key: ConnectionKey = key.into();

// Try to get cached connection
let conn_opt = {
let state = self.r#async.lock().unwrap();

if let Some(conn) = state.connections.get(&owned_key) {
Some(Arc::clone(conn))
} else {
None
let conn = match self.acquire_pooled(&owned_key) {
Some(conn) => conn,
None => {
// On a miss, pre-insert the fresh `Arc` so concurrent
// callers arriving before this send completes can clone
// it and share the socket for pipelining. A send failure
// or non-keep-alive response will evict in the post-send
// check below; the `reusable_until` probe in
// `acquire_pooled` keeps subsequent callers from using a
// poisoned `Arc` even during that window.
let conn = Arc::new(AsyncConnection::new(key, parsed_request.timeout_at).await?);
self.insert_if_vacant(owned_key.clone(), Arc::clone(&conn));
conn
}
};
let conn = if let Some(conn) = conn_opt {
conn
} else {
let connection = AsyncConnection::new(key, parsed_request.timeout_at).await?;
let connection = Arc::new(connection);

let mut state = self.r#async.lock().unwrap();
if let hash_map::Entry::Vacant(entry) = state.connections.entry(owned_key) {
entry.insert(Arc::clone(&connection));
state.lru_order.push_back(key.into());
if state.connections.len() > state.capacity {
if let Some(oldest_key) = state.lru_order.pop_front() {
state.connections.remove(&oldest_key);
}
}

let result = conn.send(parsed_request).await;

// Evict when the send poisoned the connection — covers write /
// read errors, `Connection: close`, and malformed `Keep-Alive`,
// all of which `AsyncConnection::send` signals by setting
// `next_request_id = usize::MAX`.
if conn.reusable_until().is_none() {
self.evict(&owned_key);
}

result
}

/// Returns a pooled connection for `key` if one is present and still
/// reusable per its own [`AsyncConnection::reusable_until`] — no
/// sidecar expiry needs to be tracked because the connection already
/// refreshes its `socket_new_requests_timeout` from the server's
/// `Keep-Alive: timeout=N` header on every successful response.
/// Otherwise evicts the stale entry and returns `None`.
fn acquire_pooled(&self, key: &ConnectionKey) -> Option<Arc<AsyncConnection>> {
let mut state = self.r#async.lock().unwrap();
let conn = state.connections.get(key)?;
let reusable = conn.reusable_until().is_some_and(|t| t > Instant::now());
if !reusable {
state.connections.remove(key);
if let Some(pos) = state.lru_order.iter().position(|k| k == key) {
state.lru_order.remove(pos);
}
return None;
}
let connection = Arc::clone(conn);
// Refresh LRU position so this hit is treated as the most recent use.
if let Some(pos) = state.lru_order.iter().position(|k| k == key) {
state.lru_order.remove(pos);
}
state.lru_order.push_back(key.clone());
Some(connection)
}

/// Inserts `connection` under `key` only if the slot is vacant. On a
/// pool-hit the entry is already there (we cloned the `Arc` during
/// acquire), so this is a no-op on that path. On a pool-miss, a
/// concurrent caller may have raced us and already placed a different
/// `Arc` under this key — "first writer wins," and we drop ours.
fn insert_if_vacant(&self, key: ConnectionKey, connection: Arc<AsyncConnection>) {
let mut state = self.r#async.lock().unwrap();
if let hash_map::Entry::Vacant(entry) = state.connections.entry(key.clone()) {
entry.insert(connection);
state.lru_order.push_back(key);
while state.connections.len() > state.capacity {
let oldest = match state.lru_order.pop_front() {
Some(k) => k,
None => break,
};
state.connections.remove(&oldest);
}
}
}

/// Removes any pool entry for `key`. No-op if the slot is already empty.
fn evict(&self, key: &ConnectionKey) {
let mut state = self.r#async.lock().unwrap();
state.connections.remove(key);

Contributor:

I think this should evict by identity, not just by ConnectionKey.

There is a race where an older poisoned Arc can finish later and remove a newer healthy entry for the same key:

  1. tasks A and B are both using the old pooled connection for K
  2. A fails and marks the old connection dead
  3. task C opens and inserts a fresh connection for K
  4. B reaches this cleanup path later and remove(key) drops C's fresh entry

That doesn't break in-flight requests, but it does create avoidable reconnect churn under load. I think evict needs to compare the currently pooled Arc against the one that just finished and only remove on Arc::ptr_eq.
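
Roughly, as a sketch against the fields this diff already shows (the extra `finished` parameter and the call-site change below are my suggestion, not code from this PR):

```rust
/// Removes the pool entry for `key` only if it still holds the exact
/// connection that just finished sending; a newer entry inserted for
/// the same key by a concurrent caller is left untouched.
fn evict(&self, key: &ConnectionKey, finished: &Arc<AsyncConnection>) {
    let mut state = self.r#async.lock().unwrap();
    let same = state
        .connections
        .get(key)
        .is_some_and(|pooled| Arc::ptr_eq(pooled, finished));
    if same {
        state.connections.remove(key);
        if let Some(pos) = state.lru_order.iter().position(|k| k == key) {
            state.lru_order.remove(pos);
        }
    }
}
```

`send_async` would then call `self.evict(&owned_key, &conn)` after its `reusable_until` check.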

if let Some(pos) = state.lru_order.iter().position(|k| k == key) {
state.lru_order.remove(pos);
}
}
}

19 changes: 19 additions & 0 deletions bitreq/src/connection.rs
@@ -298,6 +298,25 @@ impl AsyncConnection {
}))))
}

/// Returns the deadline until which this connection may accept further requests,
/// or `None` if the inner socket has been poisoned and must not be reused.
///
/// A `None` result means the connection's `next_request_id` has been set to
/// `usize::MAX` — every failure path in [`AsyncConnection::send`] (write error,
/// read error, `Connection: close`, malformed `Keep-Alive`) raises that flag,
/// so callers can treat `None` as "drop from the pool". A `Some(instant)`
/// result is the current value of `socket_new_requests_timeout`, which
/// [`AsyncConnection::send`] refreshes from the server's `Keep-Alive: timeout=N`
/// header.
pub(crate) fn reusable_until(&self) -> Option<Instant> {
let state = Arc::clone(&*self.0.lock().unwrap());
if state.next_request_id.load(Ordering::Acquire) == usize::MAX {
None
} else {
Some(*state.socket_new_requests_timeout.lock().unwrap())
}
}

async fn tcp_connect(host: &str, port: u16) -> Result<AsyncTcpStream, Error> {
#[cfg(feature = "log")]
log::trace!("Looking up host {host}");
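
The `Keep-Alive: timeout=N` refresh that `reusable_until` builds on is plain deadline arithmetic. A standalone sketch of that derivation (an illustration of the mechanism, not bitreq's actual header parsing code):

```rust
use std::time::{Duration, Instant};

/// Derives a reuse deadline from a `Keep-Alive` header value such as
/// `"timeout=60, max=100"`. Returns `None` when no valid `timeout`
/// parameter is present, letting the caller keep its previous deadline.
fn keep_alive_deadline(header_value: &str, now: Instant) -> Option<Instant> {
    header_value
        .split(',')
        .filter_map(|param| {
            let (name, value) = param.trim().split_once('=')?;
            name.eq_ignore_ascii_case("timeout").then_some(value)
        })
        .find_map(|value| value.trim().parse::<u64>().ok())
        .map(|secs| now + Duration::from_secs(secs))
}
```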
107 changes: 107 additions & 0 deletions bitreq/tests/async_pool_lifecycle.rs
@@ -0,0 +1,107 @@
//! Regression test for the async [`Client`](bitreq::Client) pool's LRU
//! bookkeeping: a cache hit must move the entry to the most-recently-used
//! slot, otherwise capacity-driven eviction drops still-warm keys.

#![cfg(feature = "async")]

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};

async fn bind_ephemeral() -> (TcpListener, String) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let base_url = format!("http://127.0.0.1:{}", port);
(listener, base_url)
}

/// Reads bytes from `stream` until the HTTP header terminator `\r\n\r\n`
/// is seen. Returns the accumulated buffer. Assumes no request body, which
/// is true for the GETs issued by this test.
async fn read_request_headers(stream: &mut TcpStream) -> std::io::Result<Vec<u8>> {
let mut buf = Vec::with_capacity(512);
let mut chunk = [0u8; 256];
loop {
let n = stream.read(&mut chunk).await?;
if n == 0 {
return Err(std::io::ErrorKind::UnexpectedEof.into());
}
buf.extend_from_slice(&chunk[..n]);
if buf.windows(4).any(|w| w == b"\r\n\r\n") {
return Ok(buf);
}
}
}

const KEEP_ALIVE_RESPONSE: &[u8] =
b"HTTP/1.1 200 OK\r\nContent-Length: 3\r\nConnection: keep-alive\r\nKeep-Alive: timeout=60\r\n\r\nok\n";

#[tokio::test]
async fn pool_hit_refreshes_lru_position() {
// Capacity = 2, three distinct hosts (= three distinct `ConnectionKey`s
// because the port differs). Request order: a, b, a, c, a.
//
// A correct LRU refresh-on-hit moves `a` to the most-recent slot at
// step 3, so step 4's capacity-driven eviction drops `b`, and step 5
// is a cache hit on `a` — three TCP accepts total. A pool that does
// not refresh LRU on hit still has `a` as the oldest entry after
// step 3, so step 4 evicts `a` instead, and step 5 is a miss —
// four TCP accepts total.
async fn run_server(listener: TcpListener, accepts: Arc<AtomicUsize>) {
loop {
let (mut stream, _) = match listener.accept().await {
Ok(s) => s,
Err(_) => return,
};
accepts.fetch_add(1, Ordering::SeqCst);
tokio::spawn(async move {
loop {
if read_request_headers(&mut stream).await.is_err() {
return;
}
if stream.write_all(KEEP_ALIVE_RESPONSE).await.is_err() {
return;
}
}
});
}
}

let (listener_a, url_a) = bind_ephemeral().await;
let (listener_b, url_b) = bind_ephemeral().await;
let (listener_c, url_c) = bind_ephemeral().await;

let accepts_a = Arc::new(AtomicUsize::new(0));
let accepts_b = Arc::new(AtomicUsize::new(0));
let accepts_c = Arc::new(AtomicUsize::new(0));

let srv_a = tokio::spawn(run_server(listener_a, Arc::clone(&accepts_a)));
let srv_b = tokio::spawn(run_server(listener_b, Arc::clone(&accepts_b)));
let srv_c = tokio::spawn(run_server(listener_c, Arc::clone(&accepts_c)));

let client = bitreq::Client::new(2);
for url in [&url_a, &url_b, &url_a, &url_c, &url_a] {
let response = client.send_async(bitreq::get(format!("{}/x", url))).await.unwrap();
assert_eq!(response.status_code, 200);
assert_eq!(response.as_bytes(), b"ok\n");
}

srv_a.abort();
srv_b.abort();
srv_c.abort();
let _ = tokio::join!(srv_a, srv_b, srv_c);

let total = accepts_a.load(Ordering::SeqCst)
+ accepts_b.load(Ordering::SeqCst)
+ accepts_c.load(Ordering::SeqCst);
assert_eq!(
total, 3,
"request sequence a,b,a,c,a with capacity=2 must refresh a's LRU \
position on the second hit, keeping it warm past the c-driven \
eviction — expected 3 accepts (miss a, miss b, miss c), got {}",
total,
);
assert_eq!(accepts_a.load(Ordering::SeqCst), 1, "a must be reused, not re-opened");
}