From 295e3bf83295998d040f9833d23168a63780296d Mon Sep 17 00:00:00 2001 From: skolton Date: Wed, 22 Apr 2026 21:45:51 +0200 Subject: [PATCH] Handle AIMessage.content properly --- splunklib/ai/engines/langchain.py | 109 ++++++++-- splunklib/ai/messages.py | 32 ++- .../unit/ai/engine/test_langchain_backend.py | 190 ++++++++++++++++++ 3 files changed, 317 insertions(+), 14 deletions(-) diff --git a/splunklib/ai/engines/langchain.py b/splunklib/ai/engines/langchain.py index 76fa100b..08e67fa7 100644 --- a/splunklib/ai/engines/langchain.py +++ b/splunklib/ai/engines/langchain.py @@ -77,7 +77,9 @@ AgentResponse, AIMessage, BaseMessage, + ContentBlock, HumanMessage, + OpaqueBlock, OutputT, StructuredOutputCall, StructuredOutputMessage, @@ -87,6 +89,7 @@ SubagentStructuredResult, SubagentTextResult, SystemMessage, + TextBlock, ToolCall, ToolFailureResult, ToolMessage, @@ -951,7 +954,7 @@ async def awrap_tool_call( return LC_ToolMessage( name=_normalize_agent_name(call.name), tool_call_id=call.id, - content=content, + content=_map_content_to_langchain(content), status=status, artifact=sdk_result, ) @@ -1085,7 +1088,10 @@ def _convert_model_response_to_model_result( # This invariant is asserted via ModelResponse.__post_init__ assert len(resp.message.structured_output_calls) <= 1 - lc_message = LC_AIMessage(content=resp.message.content) + lc_message = LC_AIMessage( + content=_map_content_to_langchain(resp.message.content), + additional_kwargs=resp.message.extras or {}, + ) # This field can't be set via __init__() lc_message.tool_calls = [_map_tool_call_to_langchain(c) for c in resp.message.calls] @@ -1160,7 +1166,7 @@ def _convert_tool_message_to_lc( name=name, tool_call_id=message.call_id, status=status, - content=content, + content=_map_content_to_langchain(content), artifact=artifact, ) @@ -1243,9 +1249,10 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe ai_message = model_response structured_response = None + additional_kwargs = cast(dict[str, Any], ai_message.additional_kwargs) return ModelResponse( message=AIMessage( - content=ai_message.content.__str__(), + content=_map_content_from_langchain(ai_message.content), # pyright: ignore[reportUnknownArgumentType] calls=[ _map_tool_call_from_langchain(tc) for tc in ai_message.tool_calls @@ -1260,6 +1267,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe for tc in ai_message.tool_calls if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX) ], + extras=additional_kwargs, ), structured_output=structured_response, ) @@ -1433,7 +1441,10 @@ def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool: async def invoke_agent( message: HumanMessage, thread_id: str | None - ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: + ) -> tuple[ + OutputT | str | list[str | ContentBlock], + SubagentStructuredResult | SubagentTextResult, + ]: result = await agent.invoke([message], thread_id=thread_id) if agent.output_schema: @@ -1452,13 +1463,19 @@ async def invoke_agent( async def _run( # pyright: ignore[reportRedeclaration] content: str, thread_id: str - ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: + ) -> tuple[ + OutputT | str | list[str | ContentBlock], + SubagentStructuredResult | SubagentTextResult, + ]: return await invoke_agent(HumanMessage(content=content), thread_id) else: async def _run( # pyright: ignore[reportRedeclaration] content: str, - ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: + ) -> tuple[ + OutputT | str | list[str | ContentBlock], + SubagentStructuredResult | SubagentTextResult, + ]: return await invoke_agent(HumanMessage(content=content), None) return StructuredTool.from_function( @@ -1471,7 +1488,10 @@ async def _run( # pyright: ignore[reportRedeclaration] async def invoke_agent_structured( content: BaseModel, thread_id: str | None - ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: + ) -> tuple[ + OutputT | str | list[str | ContentBlock], + SubagentStructuredResult | SubagentTextResult, + ]: result = await agent.invoke_with_data( instructions="Follow the system prompt.", data=content.model_dump(), @@ -1492,7 +1512,10 @@ async def invoke_agent_structured( async def _run( **kwargs: Any, # noqa: ANN401 - ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: + ) -> tuple[ + OutputT | str | list[str | ContentBlock], + SubagentStructuredResult | SubagentTextResult, + ]: content: BaseModel = kwargs["content"] thread_id: str = kwargs["thread_id"] return await invoke_agent_structured(content, thread_id) @@ -1512,7 +1535,10 @@ async def _run( async def _run( **kwargs: Any, # noqa: ANN401 - ) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]: + ) -> tuple[ + OutputT | str | list[str | ContentBlock], + SubagentStructuredResult | SubagentTextResult, + ]: content = InputSchema(**kwargs) return await invoke_agent_structured(content, None) @@ -1564,11 +1590,66 @@ def _map_tool_call_to_langchain(call: ToolCall | SubagentCall) -> LC_ToolCall: return LC_ToolCall(id=call.id, name=name, args=args) +def _map_content_from_langchain( + content: str | list[str | dict[str, Any]], +) -> str | list[str | ContentBlock]: + if isinstance(content, str): + return content + + result_content = [_map_content_block_from_langchain(b) for b in content] + + return result_content + + +def _map_content_block_from_langchain( + block: str | dict[str, Any], +) -> str | ContentBlock: + if isinstance(block, str): + return block + + match block.get("type"): + case "text": + return TextBlock( + text=block["text"], + extras=block.get("extras"), + ) + case _: + # NOTE: we return data we're not handling + # as opaque content blocks so they + # are preserved and sent back to the LLM + return OpaqueBlock(data=block) + + +def _map_content_to_langchain( + content: str | list[str | ContentBlock], +) -> str | list[str | dict[str, Any]]: + if isinstance(content, str): + return content + + result_content = [_map_content_block_to_langchain(b) for b in content] + + return result_content + + +def _map_content_block_to_langchain(block: str | ContentBlock) -> str | dict[str, Any]: + if isinstance(block, str): + return block + + match block: + case TextBlock(): + result: dict[str, Any] = {"type": "text", "text": block.text} + if block.extras: + result["extras"] = block.extras + return result + case OpaqueBlock(): + return block.data + + def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage: match message: case LC_AIMessage(): return AIMessage( - content=message.content.__str__(), + content=_map_content_from_langchain(message.content), # pyright: ignore[reportUnknownArgumentType] calls=[ _map_tool_call_from_langchain(tc) for tc in message.tool_calls @@ -1583,6 +1664,7 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage: for tc in message.tool_calls if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX) ], + extras=cast(dict[str, Any], message.additional_kwargs), ) case LC_HumanMessage(): return HumanMessage(content=message.content.__str__()) @@ -1597,7 +1679,10 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage: def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage: match message: case AIMessage(): - lc_message = LC_AIMessage(content=message.content) + lc_message = LC_AIMessage( + content=_map_content_to_langchain(message.content), + additional_kwargs=message.extras or {}, + ) # This field can't be set via constructor lc_message.tool_calls = [ _map_tool_call_to_langchain(c) for c in message.calls diff --git a/splunklib/ai/messages.py b/splunklib/ai/messages.py index 04db32b6..bdacfb93 100644 --- a/splunklib/ai/messages.py +++ b/splunklib/ai/messages.py @@ -21,6 +21,31 @@ from splunklib.ai.tools import ToolType +@dataclass(frozen=True) +class TextBlock: + """Plain text content block returned by a model.""" + + text: str + # TODO: should we have the id here as well? + # Provider-specific extras (e.g. Gemini thought signature on text blocks). + extras: dict[str, Any] | None = field(default=None) + + +@dataclass(frozen=True) +class OpaqueBlock: + """Content block of an unrecognized or unsupported type. + + The raw provider dict is preserved in `data` so it can be sent back + to the model unchanged on subsequent calls. + """ + + data: dict[str, Any] + + +# Type alias for all content block variants. +ContentBlock = TextBlock | OpaqueBlock + + @dataclass(frozen=True) class ToolCall: name: str @@ -85,12 +110,15 @@ class AIMessage(BaseMessage): """ role: Literal["assistant"] = field(default="assistant", init=False) - content: str + content: str | list[str | ContentBlock] calls: Sequence[ToolCall | SubagentCall] structured_output_calls: Sequence[StructuredOutputCall] = field( default_factory=tuple ) + # Backend-specific metadata (e.g. provider additional_kwargs) not + # representable in the standard fields. Opaque to callers. + extras: dict[str, Any] | None = field(default=None) @dataclass(frozen=True) @@ -120,7 +148,7 @@ class SubagentTextResult: Returned by subagent calls that don't have an output schema. """ - content: str + content: str | list[str | ContentBlock] @dataclass(frozen=True) diff --git a/tests/unit/ai/engine/test_langchain_backend.py b/tests/unit/ai/engine/test_langchain_backend.py index c02426bd..d4319896 100644 --- a/tests/unit/ai/engine/test_langchain_backend.py +++ b/tests/unit/ai/engine/test_langchain_backend.py @@ -30,10 +30,12 @@ from splunklib.ai.messages import ( AIMessage, HumanMessage, + OpaqueBlock, SubagentCall, SubagentFailureResult, SubagentMessage, SystemMessage, + TextBlock, ToolCall, ToolFailureResult, ToolMessage, @@ -56,6 +58,95 @@ def test_map_message_from_langchain_ai_with_tool_calls(self) -> None: ToolCall(name="lookup", args={"q": "test"}, id="tc-1", type=ToolType.REMOTE) ] + def test_map_message_from_langchain_ai_with_text_content_block(self) -> None: + text_block = { + "type": "text", + "text": "test-content-block", + "extras": { + # simulate gemini model returning thought signature in extra field of text content block + "signature": "EjQKMgEMOdbHDmsQ+BTM6duYJ43i5npxkpn28Ir0VjD1p6w4fUqIdYszIcWx+XcqAW1a8E+Q" + }, + } + message = LC_AIMessage(content=[text_block], tool_calls=[]) + + mapped = lc._map_message_from_langchain(message) + + assert isinstance(mapped, AIMessage) + assert isinstance(mapped.content[0], TextBlock) + assert mapped.content[0].text == "test-content-block" + + def test_map_message_from_langchain_ai_with_list_of_str(self) -> None: + message = LC_AIMessage(content=["one", "two"], tool_calls=[]) + + mapped = lc._map_message_from_langchain(message) + + assert isinstance(mapped, AIMessage) + assert mapped.content == ["one", "two"] + + def test_map_message_from_langchain_ai_with_other_content_block(self) -> None: + content_block = { + "type": "image", + } + message = LC_AIMessage(content=[content_block], tool_calls=[]) + + mapped = lc._map_message_from_langchain(message) + + assert isinstance(mapped, AIMessage) + assert isinstance(mapped.content[0], OpaqueBlock) + assert mapped.content[0].data == content_block + + def test_map_message_from_langchain_ai_with_mixed_content(self) -> None: + content_block = { + "type": "image", + } + text_block = { + "type": "text", + "text": "test", + } + message = LC_AIMessage( + content=[content_block, text_block, "test"], tool_calls=[] + ) + + mapped = lc._map_message_from_langchain(message) + + assert isinstance(mapped, AIMessage) + assert isinstance(mapped.content[0], OpaqueBlock) + assert mapped.content[0].data == content_block + assert isinstance(mapped.content[1], TextBlock) + assert mapped.content[1].text == "test" + assert mapped.content[2] == "test" + + def test_map_message_from_langchain_ai_tool_call_with_additional_kwargs( + self, + ) -> None: + tool_call = LC_ToolCall( + name=f"__local-startup_time", + args={"q": "test"}, + id="tc-2", + ) + # simulate gemini models returning thought signature in additional kwargs + # when calling tools. + additional_kwargs = { + "function_call": {"name": "__local-startup_time", "arguments": "{}"}, + "__gemini_function_call_thought_signatures__": { + "28e28045-9846-4c9c-ab46-97f33bff5a9c": "EjQKMgEMOdbHH9gTl8BkX2uMM52753GCboanCcnUp9XB896IdThnG42GB8lRSkqGGxVbv5JY" + }, + } + message = LC_AIMessage( + content="done", tool_calls=[tool_call], additional_kwargs=additional_kwargs + ) + mapped = lc._map_message_from_langchain(message) + assert isinstance(mapped, AIMessage) + assert mapped.calls == [ + ToolCall( + name="startup_time", + args={"q": "test"}, + id="tc-2", + type=ToolType.LOCAL, + ) + ] + assert mapped.extras == additional_kwargs + def test_map_message_from_langchain_ai_with_agent_call(self) -> None: tool_call = LC_ToolCall( name=f"{lc.AGENT_PREFIX}assistant", @@ -159,6 +250,69 @@ def test_map_message_to_langchain_ai(self) -> None: assert mapped.content == "hi" assert mapped.tool_calls == [LC_ToolCall(name="lookup", args={}, id="tc-1")] + def test_map_message_to_langchain_ai_with_text_content_block(self) -> None: + extras = { + "signature": "EjQKMgEMOdbHDmsQ+BTM6duYJ43i5npxkpn28Ir0VjD1p6w4fUqIdYszIcWx+XcqAW1a8E+Q" + } + message = AIMessage( + content=[ + TextBlock( + text="test-content-block", + extras=extras, + ) + ], + calls=[], + ) + mapped = lc._map_message_to_langchain(message) + + assert isinstance(mapped, LC_AIMessage) + assert isinstance(mapped.content[0], dict) + assert mapped.content[0]["type"] == "text" + assert mapped.content[0]["text"] == "test-content-block" + assert mapped.content[0]["extras"] == extras + + def test_map_message_to_langchain_ai_with_list_of_str(self) -> None: + message = AIMessage( + content=["one", "two"], + calls=[], + ) + mapped = lc._map_message_to_langchain(message) + + assert isinstance(mapped, LC_AIMessage) + assert mapped.content == ["one", "two"] + + def test_map_message_to_langchain_ai_with_opaque_content_block(self) -> None: + some_data = {"type": "unsupported"} + message = AIMessage( + content=[OpaqueBlock(data=some_data)], + calls=[], + ) + mapped = lc._map_message_to_langchain(message) + + assert isinstance(mapped, LC_AIMessage) + assert isinstance(mapped.content[0], dict) + assert mapped.content[0]["type"] == "unsupported" + + def test_map_message_to_langchain_ai_with_mixed_content_block(self) -> None: + some_data = {"type": "unsupported"} + message = AIMessage( + content=[ + OpaqueBlock(data=some_data), + TextBlock(text="test-content-block"), + "test", + ], + calls=[], + ) + mapped = lc._map_message_to_langchain(message) + + assert isinstance(mapped, LC_AIMessage) + assert isinstance(mapped.content[0], dict) + assert mapped.content[0]["type"] == "unsupported" + assert isinstance(mapped.content[1], dict) + assert mapped.content[1]["type"] == "text" + assert mapped.content[1]["text"] == "test-content-block" + assert mapped.content[2] == "test" + def test_map_message_to_langchain_ai_with_agent_call(self) -> None: message = AIMessage( content="hi", @@ -182,6 +336,42 @@ def test_map_message_to_langchain_ai_with_agent_call(self) -> None: ) ] + def test_map_message_to_langchain_ai_with_tool_call_with_thought_signature( + self, + ) -> None: + extras = { + "function_call": { + "name": "__local-startup_time", + "arguments": '{"q": "test"}', + }, + "__gemini_function_call_thought_signatures__": { + "28e28045-9846-4c9c-ab46-97f33bff5a9c": "EjQKMgEMOdbHH9gTl8BkX2uMM52753GCboanCcnUp9XB896IdThnG42GB8lRSkqGGxVbv5JY" + }, + } + message = AIMessage( + content="hi", + calls=[ + ToolCall( + name="startup_time", + args={"q": "test"}, + id="tc-2", + type=ToolType.LOCAL, + ) + ], + extras=extras, + ) + mapped = lc._map_message_to_langchain(message) + + assert isinstance(mapped, LC_AIMessage) + assert mapped.tool_calls == [ + LC_ToolCall( + name=f"__local-startup_time", + args={"q": "test"}, + id="tc-2", + ) + ] + assert mapped.additional_kwargs == extras + def test_map_message_to_langchain_human(self) -> None: message = HumanMessage(content="hello") mapped = lc._map_message_to_langchain(message)