Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ test = [
"pytest-cov>=7.1.0",
"pytest-asyncio>=1.3.0",
"python-dotenv>=1.2.2",
"vcrpy>=8.1.1",
]
release = ["build>=1.4.3", "jinja2>=3.1.6", "sphinx>=9.1.0", "twine>=6.2.0"]
lint = ["basedpyright>=1.39.0", "ruff>=0.15.10"]
Expand Down
10 changes: 7 additions & 3 deletions splunklib/ai/engines/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@
_testing_force_tool_strategy = False


def _thread_id_new_uuid() -> str:
    """Return a freshly generated random (version 4) UUID as a string.

    Thread-id generation is routed through this single helper so that every
    call site shares one source of new ids.
    """
    fresh_id = uuid.uuid4()
    return str(fresh_id)


def _supports_provider_strategy(model: BaseChatModel) -> bool:
return (
model.profile is not None
Expand Down Expand Up @@ -365,16 +369,16 @@ async def awrap_model_call(
# LLM hallucinated a thread_id, start a new conversation instead.
# This should not happen, since we provide an enum above, but just
# in case.
args.thread_id = str(uuid.uuid4())
args.thread_id = _thread_id_new_uuid()

if args.thread_id and args.thread_id in called_thread_ids:
# LLM ignored the instruction not to issue multiple calls to the
# same thread_id, start a new conversation instead.
args.thread_id = str(uuid.uuid4())
args.thread_id = _thread_id_new_uuid()

if not args.thread_id:
# Generate thread_id for a new conversation.
args.thread_id = str(uuid.uuid4())
args.thread_id = _thread_id_new_uuid()

called_thread_ids.add(args.thread_id)
call["args"] = asdict(args)
Expand Down
2 changes: 2 additions & 0 deletions tests/ai_test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ async def _buildInternalAIModel(
auth=(client_id, client_secret),
)

response.raise_for_status()

token = _TokenResponse.model_validate_json(response.text).access_token

auth_handler = _InternalAIAuth(token)
Expand Down
160 changes: 159 additions & 1 deletion tests/ai_testlib.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
from typing import override
import functools
import inspect
import json
import os
from collections.abc import Callable, Coroutine
from typing import Any, override
from unittest.mock import patch
from urllib import parse

import vcr
from vcr.config import RecordMode
from vcr.request import Request

from splunklib.ai.model import PredefinedModel
from tests.ai_test_model import InternalAIModel, TestLLMSettings, create_model
from tests.testlib import SDKTestCase
Expand Down Expand Up @@ -42,3 +54,149 @@ async def model(self) -> PredefinedModel:
model = await create_model(self.test_llm_settings)
self._model = model
return model


def ai_snapshot_test() -> Callable[
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] the name suggests it can be used only for AI. What do you think about record_snapshot name?

Copy link
Copy Markdown
Member Author

@mateusz834 mateusz834 Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the name suggests it can be used only for AI.

We do filtering here, so I think it only works for AI 😄.

[Callable[..., Coroutine[Any, Any, None]]], Callable[..., Coroutine[Any, Any, None]]
]:
def decorator(
fn: Callable[..., Coroutine[Any, Any, None]],
) -> Callable[..., Coroutine[Any, Any, None]]:
source_file = inspect.getfile(fn)
test_dir = os.path.dirname(source_file)
test_file = os.path.splitext(os.path.basename(source_file))[0]

snapshot_dir = os.path.join(test_dir, "snapshots", test_file)
snapshot_filename = f"{fn.__qualname__}.json"

@functools.wraps(fn)
async def wrapper(self: AITestCase, *args: Any, **kwargs: Any) -> None:
settings = self.test_llm_settings
assert settings.internal_ai is not None

internal_ai_hostname = parse.urlparse(
settings.internal_ai.base_url
).hostname
assert internal_ai_hostname is not None

REDACTED_APP_KEY = "[[[--APPKEY-REDACTED-]]]"

class _JSONFriendlySerializer:
    """Cassette serializer producing readable, secret-free JSON snapshots.

    In the stored form, request and response bodies are kept as nested JSON
    values instead of encoded strings, the internal AI hostname in request
    URIs is replaced by the ``internal-ai-host`` placeholder, and the app
    key is replaced by ``REDACTED_APP_KEY``. ``deserialize`` reverses all
    of these using the current test settings.
    """

    def deserialize(self, serialized: str) -> Any:
        """Turn stored snapshot text back into a cassette dict.

        The redacted app key and the hostname placeholder are substituted
        with the values from the current settings, and every body is
        re-encoded as a JSON string (the response body is wrapped as
        ``{"string": ...}``).
        """
        ai_settings = settings.internal_ai
        assert ai_settings is not None
        restored = serialized.replace(REDACTED_APP_KEY, ai_settings.app_key)

        cassette = json.loads(restored)
        for entry in cassette.get("interactions", []):
            request = entry["request"]
            # Only the first occurrence is swapped, mirroring serialize().
            request["uri"] = request["uri"].replace(
                "internal-ai-host", internal_ai_hostname, 1
            )
            request["body"] = json.dumps(request["body"])

            response = entry["response"]
            response["body"] = {"string": json.dumps(response["body"])}

        return cassette

    def serialize(self, dict: Any) -> str:
        """Render a cassette dict as indented JSON text safe to publish.

        Mutates ``dict`` in place: request URIs get the hostname
        placeholder and string bodies are parsed into JSON values. The app
        key is redacted in the final text, which is then checked for any
        remaining sensitive setting before being returned.
        """
        for entry in dict.get("interactions", []):
            request = entry["request"]
            request["uri"] = request["uri"].replace(
                internal_ai_hostname, "internal-ai-host", 1
            )
            request["body"] = json.loads(request["body"])

            response = entry["response"]
            response["body"] = json.loads(response["body"]["string"])

        out = json.dumps(dict, indent=4) + "\n"
        ai_settings = settings.internal_ai
        assert ai_settings is not None
        out = out.replace(ai_settings.app_key, REDACTED_APP_KEY)

        # Assert that nothing is leaking into the public snapshots.
        lowered = out.lower()
        assert internal_ai_hostname not in lowered
        for sensitive in (
            ai_settings.app_key,
            ai_settings.base_url,
            ai_settings.token_url,
            ai_settings.client_id,
            ai_settings.client_secret,
        ):
            assert sensitive.lower() not in lowered

        return out

def _before_record_request(request: Request) -> Request | None:
url = parse.urlparse(request.uri) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
if url.hostname == internal_ai_hostname:
request.headers = {}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we keep some specific headers, such as Authorization, and keep comparing them?

return request
return None

def _before_record_response(response: Any) -> Any:
    """Strip every header from a response before it is recorded.

    The response mapping is modified in place and the same object is
    returned, so no response headers end up in the stored snapshot.
    """
    # Replace the headers wholesale rather than filtering individual keys.
    response["headers"] = {}
    return response

def _json_body_matcher(r1: Any, r2: Any) -> None:
    """Match two requests by the JSON content of their bodies.

    Bodies are compared after JSON decoding, so differences in key order
    or whitespace do not count as a mismatch. A body that cannot be
    decoded (``None``, empty, or not JSON) is compared in its raw form
    instead of aborting the comparison with a decoding error.

    Args:
        r1: First request; only its ``body`` attribute is read.
        r2: Second request; only its ``body`` attribute is read.

    Raises:
        AssertionError: If the (decoded) bodies differ.
    """

    def _decode(body: Any) -> Any:
        # TypeError covers None and other non-text bodies; ValueError covers
        # json.JSONDecodeError and undecodable bytes. Either way, fall back
        # to the raw value so a mismatch surfaces as AssertionError below.
        try:
            return json.loads(body)
        except (TypeError, ValueError):
            return body

    b1 = _decode(r1.body)
    b2 = _decode(r2.body)
    if b1 != b2:
        raise AssertionError(f"Body mismatch:\n{b1}\n!=\n{b2}")

my_vcr = vcr.VCR(
cassette_library_dir=snapshot_dir,
serializer="json-friendly",
record_mode=RecordMode.ONCE,
match_on=[
"method",
"scheme",
"host",
"port",
"path",
"query",
"jsonbody",
],
before_record_request=_before_record_request,
before_record_response=_before_record_response,
# record_on_exception=False,
# drop_unused_requests=True,
)
my_vcr.register_serializer("json-friendly", _JSONFriendlySerializer())
my_vcr.register_matcher("jsonbody", _json_body_matcher)

with my_vcr.use_cassette(snapshot_filename): # pyright: ignore[reportGeneralTypeIssues]
await fn(self, *args, **kwargs)

return wrapper

return decorator


def deterministic_thread_ids() -> Callable[
    [Callable[..., Coroutine[Any, Any, None]]], Callable[..., Coroutine[Any, Any, None]]
]:
    """Build a decorator that makes generated thread ids predictable.

    While the decorated async test runs,
    ``splunklib.ai.engines.langchain._thread_id_new_uuid`` is patched so
    that successive calls return ``00000000-0000-0000-0000-000000000000``,
    ``00000000-0000-0000-0000-000000000001`` and so on. The sequence
    restarts from zero on every invocation of the decorated test.
    """

    def decorator(
        fn: Callable[..., Coroutine[Any, Any, None]],
    ) -> Callable[..., Coroutine[Any, Any, None]]:
        @functools.wraps(fn)
        async def wrapper(self: AITestCase, *args: Any, **kwargs: Any) -> None:
            def _sequential_ids():
                # Endless stream of zero-padded, UUID-shaped identifiers.
                index = 0
                while True:
                    yield f"00000000-0000-0000-0000-{index:012d}"
                    index += 1

            id_stream = _sequential_ids()

            with patch(
                "splunklib.ai.engines.langchain._thread_id_new_uuid",
                side_effect=lambda: next(id_stream),
            ):
                await fn(self, *args, **kwargs)

        return wrapper

    return decorator
Loading
Loading