Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 97 additions & 12 deletions splunklib/ai/engines/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@
AgentResponse,
AIMessage,
BaseMessage,
ContentBlock,
HumanMessage,
OpaqueBlock,
OutputT,
StructuredOutputCall,
StructuredOutputMessage,
Expand All @@ -87,6 +89,7 @@
SubagentStructuredResult,
SubagentTextResult,
SystemMessage,
TextBlock,
ToolCall,
ToolFailureResult,
ToolMessage,
Expand Down Expand Up @@ -951,7 +954,7 @@ async def awrap_tool_call(
return LC_ToolMessage(
name=_normalize_agent_name(call.name),
tool_call_id=call.id,
content=content,
content=_map_content_to_langchain(content),
status=status,
artifact=sdk_result,
)
Expand Down Expand Up @@ -1085,7 +1088,10 @@ def _convert_model_response_to_model_result(
# This invariant is asserted via ModelResponse.__post_init__
assert len(resp.message.structured_output_calls) <= 1

lc_message = LC_AIMessage(content=resp.message.content)
lc_message = LC_AIMessage(
content=_map_content_to_langchain(resp.message.content),
additional_kwargs=resp.message.extras or {},
)
# This field can't be set via __init__()
lc_message.tool_calls = [_map_tool_call_to_langchain(c) for c in resp.message.calls]

Expand Down Expand Up @@ -1160,7 +1166,7 @@ def _convert_tool_message_to_lc(
name=name,
tool_call_id=message.call_id,
status=status,
content=content,
content=_map_content_to_langchain(content),
artifact=artifact,
)

Expand Down Expand Up @@ -1243,9 +1249,10 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
ai_message = model_response
structured_response = None

additional_kwargs = cast(dict[str, Any], ai_message.additional_kwargs)
return ModelResponse(
message=AIMessage(
content=ai_message.content.__str__(),
content=_map_content_from_langchain(ai_message.content), # pyright: ignore[reportUnknownArgumentType]
calls=[
_map_tool_call_from_langchain(tc)
for tc in ai_message.tool_calls
Expand All @@ -1260,6 +1267,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
for tc in ai_message.tool_calls
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)
],
extras=additional_kwargs,
),
structured_output=structured_response,
)
Expand Down Expand Up @@ -1433,7 +1441,10 @@ def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool:

async def invoke_agent(
message: HumanMessage, thread_id: str | None
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
result = await agent.invoke([message], thread_id=thread_id)

if agent.output_schema:
Expand All @@ -1452,13 +1463,19 @@ async def invoke_agent(

async def _run( # pyright: ignore[reportRedeclaration]
content: str, thread_id: str
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
return await invoke_agent(HumanMessage(content=content), thread_id)
else:

async def _run( # pyright: ignore[reportRedeclaration]
content: str,
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
return await invoke_agent(HumanMessage(content=content), None)

return StructuredTool.from_function(
Expand All @@ -1471,7 +1488,10 @@ async def _run( # pyright: ignore[reportRedeclaration]

async def invoke_agent_structured(
content: BaseModel, thread_id: str | None
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
result = await agent.invoke_with_data(
instructions="Follow the system prompt.",
data=content.model_dump(),
Expand All @@ -1492,7 +1512,10 @@ async def invoke_agent_structured(

async def _run(
**kwargs: Any, # noqa: ANN401
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
content: BaseModel = kwargs["content"]
thread_id: str = kwargs["thread_id"]
return await invoke_agent_structured(content, thread_id)
Expand All @@ -1512,7 +1535,10 @@ async def _run(

async def _run(
**kwargs: Any, # noqa: ANN401
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str | list[str | ContentBlock],
SubagentStructuredResult | SubagentTextResult,
]:
content = InputSchema(**kwargs)
return await invoke_agent_structured(content, None)

Expand Down Expand Up @@ -1564,11 +1590,66 @@ def _map_tool_call_to_langchain(call: ToolCall | SubagentCall) -> LC_ToolCall:
return LC_ToolCall(id=call.id, name=name, args=args)


def _map_content_from_langchain(
content: str | list[str | dict[str, Any]],
) -> str | list[str | ContentBlock]:
if isinstance(content, str):
return content

result_content = [_map_content_block_from_langchain(b) for b in content]

return result_content


def _map_content_block_from_langchain(
block: str | dict[str, Any],
) -> str | ContentBlock:
if isinstance(block, str):
return block

match block.get("type"):
case "text":
return TextBlock(
text=block["text"],
extras=block.get("extras"),
)
case _:
# NOTE: we return data we're not handling
# as opaque content blocks so they
# are preserved and sent back to the LLM
return OpaqueBlock(data=block)


def _map_content_to_langchain(
content: str | list[str | ContentBlock],
) -> str | list[str | dict[str, Any]]:
if isinstance(content, str):
return content

result_content = [_map_content_block_to_langchain(b) for b in content]

return result_content


def _map_content_block_to_langchain(block: str | ContentBlock) -> str | dict[str, Any]:
if isinstance(block, str):
return block

match block:
case TextBlock():
result: dict[str, Any] = {"type": "text", "text": block.text}
if block.extras:
result["extras"] = block.extras
return result
case OpaqueBlock():
return block.data


def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
match message:
case LC_AIMessage():
return AIMessage(
content=message.content.__str__(),
content=_map_content_from_langchain(message.content), # pyright: ignore[reportUnknownArgumentType]
calls=[
_map_tool_call_from_langchain(tc)
for tc in message.tool_calls
Expand All @@ -1583,6 +1664,7 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
for tc in message.tool_calls
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)
],
extras=cast(dict[str, Any], message.additional_kwargs),
)
case LC_HumanMessage():
return HumanMessage(content=message.content.__str__())
Expand All @@ -1597,7 +1679,10 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
match message:
case AIMessage():
lc_message = LC_AIMessage(content=message.content)
lc_message = LC_AIMessage(
content=_map_content_to_langchain(message.content),
additional_kwargs=message.extras or {},
)
# This field can't be set via constructor
lc_message.tool_calls = [
_map_tool_call_to_langchain(c) for c in message.calls
Expand Down
32 changes: 30 additions & 2 deletions splunklib/ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,31 @@
from splunklib.ai.tools import ToolType


@dataclass(frozen=True)
class TextBlock:
"""Plain text content block returned by a model."""

text: str
# TODO: should we have the id here as well?
# Provider-specific extras (e.g. Gemini thought signature on text blocks).
extras: dict[str, Any] | None = field(default=None)


@dataclass(frozen=True)
class OpaqueBlock:
    """Wrapper for a content block whose type is unrecognized or unsupported.

    `data` keeps the raw provider dict as received, so the block can be
    passed back to the model unchanged on subsequent calls.
    """

    # Raw provider dict, stored as-is.
    data: dict[str, Any]


# Type alias covering every content block variant: a block is either a
# `TextBlock` or an `OpaqueBlock`.
ContentBlock = TextBlock | OpaqueBlock


@dataclass(frozen=True)
class ToolCall:
name: str
Expand Down Expand Up @@ -85,12 +110,15 @@ class AIMessage(BaseMessage):
"""

role: Literal["assistant"] = field(default="assistant", init=False)
content: str
content: str | list[str | ContentBlock]

calls: Sequence[ToolCall | SubagentCall]
structured_output_calls: Sequence[StructuredOutputCall] = field(
default_factory=tuple
)
# Backend-specific metadata (e.g. provider additional_kwargs) not
# representable in the standard fields. Opaque to callers.
extras: dict[str, Any] | None = field(default=None)


@dataclass(frozen=True)
Expand Down Expand Up @@ -120,7 +148,7 @@ class SubagentTextResult:
Returned by subagent calls that don't have an output schema.
"""

content: str
content: str | list[str | ContentBlock]


@dataclass(frozen=True)
Expand Down
Loading
Loading