perf: vectorize KV cache prefix matching with numpy#2179

Open
nausicaalii wants to merge 2 commits into abetlen:main from nausicaalii:perf/vectorize-prefix-match

Conversation

@nausicaalii nausicaalii commented Apr 11, 2026

Summary

  • Replace the Python for-loop in generate() and longest_token_prefix() with a numpy vectorized element-wise comparison
  • Use np.argmin on a boolean equality array to find the first mismatch position in a single vectorized pass
  • Deduplicate the inline prefix matching in generate() into a single longest_token_prefix() call
  • Remove unnecessary .tolist() conversions in _create_completion() so numpy arrays are compared directly

Motivation

The current prefix matching iterates token-by-token in Python to find where the cached prompt diverges from the new prompt. This is fine for short prompts, but becomes a bottleneck as conversation history grows — multi-turn chat sessions can accumulate 10K–100K+ tokens in input_ids, and the linear Python loop runs on every generate() call.

Numpy's vectorized comparison runs in optimized C/SIMD, giving significant speedup for large token sequences while preserving identical behavior.
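
The np.argmin trick deserves a note: on a boolean array, argmin returns the index of the first False (the first mismatch), but when every element is True it also returns 0, so the result must be disambiguated by checking the value at that index. A minimal standalone illustration:

```python
import numpy as np

# On a boolean array, np.argmin returns the index of the first False
# (False < True), i.e. the first mismatch in an equality comparison.
eq = np.array([5, 7, 9]) == np.array([5, 7, 2])  # [True, True, False]
print(int(np.argmin(eq)))  # -> 2

# Edge case: if every element matches, argmin also returns 0, because
# all values are True and ties resolve to the first index. The caller
# must check eq[mismatch] to tell "full match" apart from "mismatch at 0".
eq_all = np.array([5, 7, 9]) == np.array([5, 7, 9])
idx = int(np.argmin(eq_all))  # 0, but eq_all[0] is True
print(len(eq_all) if eq_all[idx] else idx)  # -> 3 (whole prefix matches)
```

This is why longest_token_prefix_after in the benchmark script below returns n when eq[mismatch] is True.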

Profiling results

Benchmarked on Apple M3 Pro, Python 3.12, numpy 2.2. Mismatch placed at 90% through the sequence. 500 trials.

generate() hot path

self._input_ids (numpy array) vs tokens (Python list):

| Tokens | Metric | Before | After | Speedup |
|-------:|--------|-------:|------:|--------:|
| 100 | avg | 4.7 µs | 3.4 µs | 1.4x |
| | p50 | 4.7 µs | 3.4 µs | 1.4x |
| | p90 | 4.8 µs | 3.5 µs | 1.4x |
| 1,000 | avg | 46.6 µs | 19.9 µs | 2.3x |
| | p50 | 47.2 µs | 19.5 µs | 2.4x |
| | p90 | 48.9 µs | 21.2 µs | 2.3x |
| 10,000 | avg | 429.9 µs | 176.2 µs | 2.4x |
| | p50 | 428.2 µs | 171.4 µs | 2.5x |
| | p90 | 453.8 µs | 189.4 µs | 2.4x |
| 50,000 | avg | 2,142.1 µs | 886.4 µs | 2.4x |
| | p50 | 2,128.9 µs | 885.6 µs | 2.4x |
| | p90 | 2,225.3 µs | 908.8 µs | 2.4x |
| 100,000 | avg | 4,248.8 µs | 1,791.4 µs | 2.4x |
| | p50 | 4,225.9 µs | 1,782.2 µs | 2.4x |
| | p90 | 4,381.4 µs | 1,839.1 µs | 2.4x |

_create_completion() cache lookup

Both inputs are numpy arrays (eliminated .tolist() conversion):

| Tokens | Metric | Before | After | Speedup |
|-------:|--------|-------:|------:|--------:|
| 100 | avg | 1.9 µs | 1.0 µs | 1.9x |
| | p50 | 1.8 µs | 1.0 µs | 1.8x |
| | p90 | 2.0 µs | 1.2 µs | 1.7x |
| 1,000 | avg | 20.8 µs | 1.1 µs | 18.9x |
| | p50 | 21.5 µs | 1.0 µs | 21.5x |
| | p90 | 21.7 µs | 1.2 µs | 18.1x |
| 10,000 | avg | 214.3 µs | 1.8 µs | 119.1x |
| | p50 | 215.0 µs | 1.8 µs | 119.4x |
| | p90 | 227.5 µs | 2.0 µs | 113.8x |
| 50,000 | avg | 1,090.0 µs | 5.4 µs | 201.9x |
| | p50 | 1,076.5 µs | 5.3 µs | 203.1x |
| | p90 | 1,160.5 µs | 5.9 µs | 196.7x |
| 100,000 | avg | 2,135.3 µs | 9.7 µs | 220.1x |
| | p50 | 2,136.6 µs | 9.5 µs | 224.9x |
| | p90 | 2,175.2 µs | 10.5 µs | 207.2x |
Benchmark script
import time
import numpy as np
from typing import Sequence


def longest_token_prefix_before(a, b):
    longest_prefix = 0
    for _a, _b in zip(a, b):
        if _a == _b:
            longest_prefix += 1
        else:
            break
    return longest_prefix


def longest_token_prefix_after(a, b):
    n = min(len(a), len(b))
    if n == 0:
        return 0
    eq = np.asarray(a[:n]) == np.asarray(b[:n])
    mismatch = np.argmin(eq)
    return int(n) if eq[mismatch] else int(mismatch)


def bench(func, args, warmup=10, trials=500):
    for _ in range(warmup):
        func(*args)
    times = []
    for _ in range(trials):
        t0 = time.perf_counter_ns()
        func(*args)
        t1 = time.perf_counter_ns()
        times.append(t1 - t0)
    times.sort()
    avg = sum(times) / len(times)
    p50 = times[len(times) // 2]
    p90 = times[int(len(times) * 0.9)]
    return avg, p50, p90


sizes = [100, 1_000, 10_000, 50_000, 100_000]

# Scenario 1: generate() hot path — cached is numpy, tokens is list
print("generate() hot path")
for n in sizes:
    rng = np.random.default_rng(42)
    shared = rng.integers(0, 32000, size=n, dtype=np.intc)
    mismatch_pos = int(n * 0.9)
    cached_np = shared.copy()
    tokens_list = shared.tolist()
    tokens_list[mismatch_pos] = (tokens_list[mismatch_pos] + 1) % 32000

    b = bench(longest_token_prefix_before, (cached_np, tokens_list))
    a = bench(longest_token_prefix_after, (cached_np, tokens_list))
    print(f"{n:>10,} | avg {b[0]/1e3:.1f} -> {a[0]/1e3:.1f} µs | "
          f"p50 {b[1]/1e3:.1f} -> {a[1]/1e3:.1f} µs | "
          f"p90 {b[2]/1e3:.1f} -> {a[2]/1e3:.1f} µs")

# Scenario 2: _create_completion() — both numpy (no more .tolist())
print("\n_create_completion() cache lookup")
for n in sizes:
    rng = np.random.default_rng(42)
    shared = rng.integers(0, 32000, size=n, dtype=np.intc)
    mismatch_pos = int(n * 0.9)
    a_np = shared.copy()
    b_np = shared.copy()
    b_np[mismatch_pos] = (b_np[mismatch_pos] + 1) % 32000

    a_list = a_np.tolist()
    b_list = b_np.tolist()
    b = bench(longest_token_prefix_before, (a_list, b_list))
    a = bench(longest_token_prefix_after, (a_np, b_np))
    print(f"{n:>10,} | avg {b[0]/1e3:.1f} -> {a[0]/1e3:.1f} µs | "
          f"p50 {b[1]/1e3:.1f} -> {a[1]/1e3:.1f} µs | "
          f"p90 {b[2]/1e3:.1f} -> {a[2]/1e3:.1f} µs")

Test plan

  • Verified longest_token_prefix correctness across edge cases: empty sequences, full match, partial match, single element, no match, different lengths, large sequences (10K tokens)
  • test_real_model — passes (low-level batch decode)
  • test_real_llama — passes (multiple sequential create_completion calls that exercise prefix matching)
  • test_real_llama_embeddings — passes
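
The edge cases in the first bullet can be checked with a short script against the vectorized helper (this mirrors longest_token_prefix_after from the benchmark script above):

```python
import numpy as np


def longest_token_prefix(a, b):
    # Vectorized prefix match: compare element-wise up to the shorter
    # length, then locate the first mismatch with np.argmin.
    n = min(len(a), len(b))
    if n == 0:
        return 0
    eq = np.asarray(a[:n]) == np.asarray(b[:n])
    mismatch = np.argmin(eq)
    return int(n) if eq[mismatch] else int(mismatch)


# Edge cases from the test plan:
assert longest_token_prefix([], []) == 0                  # empty sequences
assert longest_token_prefix([1, 2, 3], [1, 2, 3]) == 3    # full match
assert longest_token_prefix([1, 2, 3], [1, 2, 9]) == 2    # partial match
assert longest_token_prefix([1], [1]) == 1                # single element
assert longest_token_prefix([1, 2], [9, 2]) == 0          # no match
assert longest_token_prefix([1, 2, 3], [1, 2]) == 2       # different lengths
assert longest_token_prefix(np.arange(10_000), np.arange(10_000)) == 10_000
print("all edge cases pass")
```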

🤖 Generated with Claude Code

nausicaalii and others added 2 commits April 11, 2026 15:51
Replace O(n) Python for-loop in KV cache prefix matching and
longest_token_prefix() with numpy vectorized comparison.

The element-wise numpy comparison runs in optimized C/SIMD
instead of Python's interpreter loop, which matters as
conversation history grows (10K+ tokens).

No change in behavior — both paths find the first position
where cached and new token sequences diverge.

Replace the inline prefix matching in generate() with a call to
longest_token_prefix(). Remove .tolist() conversions in
_create_completion() so numpy arrays are compared directly, avoiding
list conversion overhead.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>