Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 15 additions & 43 deletions src/aws_durable_execution_sdk_python/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,32 +69,6 @@ def from_json_dict(input_dict: MutableMapping[str, Any]) -> InitialExecutionStat
next_marker=input_dict.get("NextMarker", ""),
)

def get_execution_operation(self) -> Operation | None:
if not self.operations:
# Due to payload size limitations we may have an empty operations list.
# This will only happen when loading the initial page of results and is
# expected behaviour. We don't fail, but instead return None
# as the execution operation does not exist
msg: str = "No durable operations found in initial execution state."
logger.debug(msg)
return None

candidate = self.operations[0]
if candidate.operation_type is not OperationType.EXECUTION:
msg = f"First operation in initial execution state is not an execution operation: {candidate.operation_type}"
raise DurableExecutionsError(msg)

return candidate

def get_input_payload(self) -> str | None:
# It is possible that backend will not provide an execution operation
# for the initial page of results.
if not (operations := self.get_execution_operation()):
return None
if not (execution_details := operations.execution_details):
return None
return execution_details.input_payload

def to_dict(self) -> MutableMapping[str, Any]:
return {
"Operations": [op.to_dict() for op in self.operations],
Expand Down Expand Up @@ -275,23 +249,6 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
else LambdaClient.initialize_client()
)

raw_input_payload: str | None = (
invocation_input.initial_execution_state.get_input_payload()
)

# Python RIC LambdaMarshaller just uses standard json deserialization for event
# https://github.com/aws/aws-lambda-python-runtime-interface-client/blob/main/awslambdaric/lambda_runtime_marshaller.py#L46
input_event: MutableMapping[str, Any] = {}
if raw_input_payload and raw_input_payload.strip():
try:
input_event = json.loads(raw_input_payload)
except json.JSONDecodeError:
logger.exception(
"Failed to parse input payload as JSON: payload: %r",
raw_input_payload,
)
raise

execution_state: ExecutionState = ExecutionState(
durable_execution_arn=invocation_input.durable_execution_arn,
initial_checkpoint_token=invocation_input.checkpoint_token,
Expand All @@ -309,6 +266,21 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
invocation_input.initial_execution_state.next_marker,
)

raw_input_payload: str | None = execution_state.get_input_payload()

# Python RIC LambdaMarshaller just uses standard json deserialization for event
# https://github.com/aws/aws-lambda-python-runtime-interface-client/blob/main/awslambdaric/lambda_runtime_marshaller.py#L46
input_event: MutableMapping[str, Any] = {}
if raw_input_payload and raw_input_payload.strip():
try:
input_event = json.loads(raw_input_payload)
except json.JSONDecodeError:
logger.exception(
"Failed to parse input payload as JSON: payload: %r",
raw_input_payload,
)
raise

durable_context: DurableContext = DurableContext.from_lambda_context(
state=execution_state, lambda_context=context
)
Expand Down
27 changes: 27 additions & 0 deletions src/aws_durable_execution_sdk_python/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,33 @@ def fetch_paginated_operations(
with self._operations_lock:
self.operations.update({op.operation_id: op for op in all_operations})

def get_input_payload(self) -> str | None:
# It is possible that backend will not provide an execution operation
# for the initial page of results.
if not (operations := self.get_execution_operation()):
return None
if not (execution_details := operations.execution_details):
return None
return execution_details.input_payload

def get_execution_operation(self) -> Operation | None:
# invocation id is id of execution operation
invocation_id = self.durable_execution_arn.split("/")[-1]
candidate = self.operations.get(invocation_id)
if not candidate:
# Due to payload size limitations we may have an empty operations list.
# This will only happen when loading the initial page of results and is
# expected behaviour. We don't fail, but instead return None
# as the execution operation does not exist
msg: str = "No durable operations found in execution state."
logger.debug(msg)
return None
if candidate.operation_type is not OperationType.EXECUTION:
msg = f"The execution operation in execution state does not have EXECUTION type: {candidate.operation_type}"
raise DurableExecutionsError(msg)

return candidate

def track_replay(self, operation_id: str) -> None:
"""Check if operation exists with completed status; if not, transition to NEW status.

Expand Down
47 changes: 1 addition & 46 deletions tests/execution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,51 +774,6 @@ def test_handler(event: Any, context: DurableContext) -> dict:
mock_lambda_client.initialize_client.assert_called_once()


def test_initial_execution_state_get_execution_operation_no_operations():
"""Test get_execution_operation logs debug and returns None when no operations exist."""
state = InitialExecutionState(operations=[], next_marker="")

with patch("aws_durable_execution_sdk_python.execution.logger") as mock_logger:
result = state.get_execution_operation()

assert result is None
mock_logger.debug.assert_called_once_with(
"No durable operations found in initial execution state."
)


def test_initial_execution_state_get_execution_operation_wrong_type():
"""Test get_execution_operation raises error when first operation is not EXECUTION."""
operation = Operation(
operation_id="step1",
operation_type=OperationType.STEP,
status=OperationStatus.STARTED,
)

state = InitialExecutionState(operations=[operation], next_marker="")

with pytest.raises(
Exception,
match="First operation in initial execution state is not an execution operation",
):
state.get_execution_operation()


def test_initial_execution_state_get_input_payload_none():
"""Test get_input_payload returns None when execution_details is None."""
operation = Operation(
operation_id="exec1",
operation_type=OperationType.EXECUTION,
status=OperationStatus.STARTED,
execution_details=None,
)

state = InitialExecutionState(operations=[operation], next_marker="")

result = state.get_input_payload()
assert result is None


def test_durable_handler_empty_input_payload():
"""Test durable_handler handles empty input payload correctly."""
mock_client = Mock(spec=DurableServiceClient)
Expand Down Expand Up @@ -916,7 +871,7 @@ def test_handler(event: Any, context: DurableContext) -> dict:
initial_state = InitialExecutionState(operations=[operation], next_marker="")

invocation_input = DurableExecutionInvocationInputWithClient(
durable_execution_arn="arn:test:execution",
durable_execution_arn="arn:test:execution/exec1",
checkpoint_token="token123", # noqa: S106
initial_execution_state=initial_state,
service_client=mock_client,
Expand Down
89 changes: 88 additions & 1 deletion tests/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
import unittest.mock
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock, call
from unittest.mock import Mock, call, patch

import pytest

Expand Down Expand Up @@ -3562,3 +3562,90 @@ def test_collect_checkpoint_batch_first_empty_counts_toward_limit():
) # Only the leading empty; trailing deferred to next batch
# op_2 and trailing empties remain in the queue
assert state._checkpoint_queue.qsize() == 51


def test_execution_state_get_execution_operation_no_operations():
"""Test get_execution_operation logs debug and returns None when no operations exist."""
mock_lambda_client = Mock(spec=LambdaClient)
config = CheckpointBatcherConfig(
max_batch_size_bytes=10 * 1024 * 1024,
max_batch_time_seconds=10.0,
max_batch_operations=2,
)
state = ExecutionState(
durable_execution_arn="test_arn",
initial_checkpoint_token="token123", # noqa: S106
operations={},
service_client=mock_lambda_client,
batcher_config=config,
)

with patch("aws_durable_execution_sdk_python.state.logger") as mock_logger:
result = state.get_execution_operation()

assert result is None
mock_logger.debug.assert_called_once_with(
"No durable operations found in execution state."
)


def test_initial_execution_state_get_execution_operation_wrong_type():
"""Test get_execution_operation raises error when first operation is not EXECUTION."""
operation = Operation(
operation_id="step1",
operation_type=OperationType.STEP,
status=OperationStatus.STARTED,
)

mock_lambda_client = Mock(spec=LambdaClient)
config = CheckpointBatcherConfig(
max_batch_size_bytes=10 * 1024 * 1024,
max_batch_time_seconds=10.0,
max_batch_operations=2,
)
state = ExecutionState(
durable_execution_arn="test_arn/step1",
initial_checkpoint_token="token123", # noqa: S106
operations={"step1": operation},
service_client=mock_lambda_client,
batcher_config=config,
)

with pytest.raises(
Exception,
match="The execution operation in execution state does not have EXECUTION type: OperationType.STEP",
):
state.get_execution_operation()


def test_initial_execution_state_get_input_payload_none():
"""Test get_input_payload returns None when execution_details is None."""
operation = Operation(
operation_id="exec1",
operation_type=OperationType.EXECUTION,
status=OperationStatus.STARTED,
execution_details=None,
)

operation = Operation(
operation_id="step1",
operation_type=OperationType.STEP,
status=OperationStatus.STARTED,
)

mock_lambda_client = Mock(spec=LambdaClient)
config = CheckpointBatcherConfig(
max_batch_size_bytes=10 * 1024 * 1024,
max_batch_time_seconds=10.0,
max_batch_operations=2,
)
state = ExecutionState(
durable_execution_arn="test_arn/exec1",
initial_checkpoint_token="token123", # noqa: S106
operations={"step1": operation},
service_client=mock_lambda_client,
batcher_config=config,
)

result = state.get_input_payload()
assert result is None
Loading