diff --git a/src/aws_durable_execution_sdk_python/execution.py b/src/aws_durable_execution_sdk_python/execution.py index e2daa46..403111d 100644 --- a/src/aws_durable_execution_sdk_python/execution.py +++ b/src/aws_durable_execution_sdk_python/execution.py @@ -69,32 +69,6 @@ def from_json_dict(input_dict: MutableMapping[str, Any]) -> InitialExecutionStat next_marker=input_dict.get("NextMarker", ""), ) - def get_execution_operation(self) -> Operation | None: - if not self.operations: - # Due to payload size limitations we may have an empty operations list. - # This will only happen when loading the initial page of results and is - # expected behaviour. We don't fail, but instead return None - # as the execution operation does not exist - msg: str = "No durable operations found in initial execution state." - logger.debug(msg) - return None - - candidate = self.operations[0] - if candidate.operation_type is not OperationType.EXECUTION: - msg = f"First operation in initial execution state is not an execution operation: {candidate.operation_type}" - raise DurableExecutionsError(msg) - - return candidate - - def get_input_payload(self) -> str | None: - # It is possible that backend will not provide an execution operation - # for the initial page of results. - if not (operations := self.get_execution_operation()): - return None - if not (execution_details := operations.execution_details): - return None - return execution_details.input_payload - def to_dict(self) -> MutableMapping[str, Any]: return { "Operations": [op.to_dict() for op in self.operations], @@ -275,23 +249,6 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: else LambdaClient.initialize_client() ) - raw_input_payload: str | None = ( - invocation_input.initial_execution_state.get_input_payload() - ) - - # Python RIC LambdaMarshaller just uses standard json deserialization for event - # https://github.com/aws/aws-lambda-python-runtime-interface-client/blob/main/awslambdaric/lambda_runtime_marshaller.py#L46 - input_event: MutableMapping[str, Any] = {} - if raw_input_payload and raw_input_payload.strip(): - try: - input_event = json.loads(raw_input_payload) - except json.JSONDecodeError: - logger.exception( - "Failed to parse input payload as JSON: payload: %r", - raw_input_payload, - ) - raise - execution_state: ExecutionState = ExecutionState( durable_execution_arn=invocation_input.durable_execution_arn, initial_checkpoint_token=invocation_input.checkpoint_token, @@ -309,6 +266,21 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: invocation_input.initial_execution_state.next_marker, ) + raw_input_payload: str | None = execution_state.get_input_payload() + + # Python RIC LambdaMarshaller just uses standard json deserialization for event + # https://github.com/aws/aws-lambda-python-runtime-interface-client/blob/main/awslambdaric/lambda_runtime_marshaller.py#L46 + input_event: MutableMapping[str, Any] = {} + if raw_input_payload and raw_input_payload.strip(): + try: + input_event = json.loads(raw_input_payload) + except json.JSONDecodeError: + logger.exception( + "Failed to parse input payload as JSON: payload: %r", + raw_input_payload, + ) + raise + durable_context: DurableContext = DurableContext.from_lambda_context( state=execution_state, lambda_context=context ) diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index 16bd31d..48076c9 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -290,6 +290,33 @@ def fetch_paginated_operations( with self._operations_lock: self.operations.update({op.operation_id: op for op in all_operations}) + def get_input_payload(self) -> str | None: + # It is possible that backend will not provide an execution operation + # for the initial page of results. + if not (operations := self.get_execution_operation()): + return None + if not (execution_details := operations.execution_details): + return None + return execution_details.input_payload + + def get_execution_operation(self) -> Operation | None: + # invocation id is id of execution operation + invocation_id = self.durable_execution_arn.split("/")[-1] + candidate = self.operations.get(invocation_id) + if not candidate: + # Due to payload size limitations we may have an empty operations list. + # This will only happen when loading the initial page of results and is + # expected behaviour. We don't fail, but instead return None + # as the execution operation does not exist + msg: str = "No durable operations found in execution state." + logger.debug(msg) + return None + if candidate.operation_type is not OperationType.EXECUTION: + msg = f"The execution operation in execution state does not have EXECUTION type: {candidate.operation_type}" + raise DurableExecutionsError(msg) + + return candidate + def track_replay(self, operation_id: str) -> None: """Check if operation exists with completed status; if not, transition to NEW status. diff --git a/tests/execution_test.py b/tests/execution_test.py index 485d400..eeaf949 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -774,51 +774,6 @@ def test_handler(event: Any, context: DurableContext) -> dict: mock_lambda_client.initialize_client.assert_called_once() -def test_initial_execution_state_get_execution_operation_no_operations(): - """Test get_execution_operation logs debug and returns None when no operations exist.""" - state = InitialExecutionState(operations=[], next_marker="") - - with patch("aws_durable_execution_sdk_python.execution.logger") as mock_logger: - result = state.get_execution_operation() - - assert result is None - mock_logger.debug.assert_called_once_with( - "No durable operations found in initial execution state." - ) - - -def test_initial_execution_state_get_execution_operation_wrong_type(): - """Test get_execution_operation raises error when first operation is not EXECUTION.""" - operation = Operation( - operation_id="step1", - operation_type=OperationType.STEP, - status=OperationStatus.STARTED, - ) - - state = InitialExecutionState(operations=[operation], next_marker="") - - with pytest.raises( - Exception, - match="First operation in initial execution state is not an execution operation", - ): - state.get_execution_operation() - - -def test_initial_execution_state_get_input_payload_none(): - """Test get_input_payload returns None when execution_details is None.""" - operation = Operation( - operation_id="exec1", - operation_type=OperationType.EXECUTION, - status=OperationStatus.STARTED, - execution_details=None, - ) - - state = InitialExecutionState(operations=[operation], next_marker="") - - result = state.get_input_payload() - assert result is None - - def test_durable_handler_empty_input_payload(): """Test durable_handler handles empty input payload correctly.""" mock_client = Mock(spec=DurableServiceClient) @@ -916,7 +871,7 @@ def test_handler(event: Any, context: DurableContext) -> dict: initial_state = InitialExecutionState(operations=[operation], next_marker="") invocation_input = DurableExecutionInvocationInputWithClient( - durable_execution_arn="arn:test:execution", + durable_execution_arn="arn:test:execution/exec1", checkpoint_token="token123", # noqa: S106 initial_execution_state=initial_state, service_client=mock_client, diff --git a/tests/state_test.py b/tests/state_test.py index 1f49838..19d641e 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -9,7 +9,7 @@ import time import unittest.mock from concurrent.futures import ThreadPoolExecutor -from unittest.mock import Mock, call +from unittest.mock import Mock, call, patch import pytest @@ -3562,3 +3562,90 @@ def test_collect_checkpoint_batch_first_empty_counts_toward_limit(): ) # Only the leading empty; trailing deferred to next batch # op_2 and trailing empties remain in the queue assert state._checkpoint_queue.qsize() == 51 + + +def test_execution_state_get_execution_operation_no_operations(): + """Test get_execution_operation logs debug and returns None when no operations exist.""" + mock_lambda_client = Mock(spec=LambdaClient) + config = CheckpointBatcherConfig( + max_batch_size_bytes=10 * 1024 * 1024, + max_batch_time_seconds=10.0, + max_batch_operations=2, + ) + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_lambda_client, + batcher_config=config, + ) + + with patch("aws_durable_execution_sdk_python.state.logger") as mock_logger: + result = state.get_execution_operation() + + assert result is None + mock_logger.debug.assert_called_once_with( + "No durable operations found in execution state." + ) + + +def test_initial_execution_state_get_execution_operation_wrong_type(): + """Test get_execution_operation raises error when first operation is not EXECUTION.""" + operation = Operation( + operation_id="step1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + + mock_lambda_client = Mock(spec=LambdaClient) + config = CheckpointBatcherConfig( + max_batch_size_bytes=10 * 1024 * 1024, + max_batch_time_seconds=10.0, + max_batch_operations=2, + ) + state = ExecutionState( + durable_execution_arn="test_arn/step1", + initial_checkpoint_token="token123", # noqa: S106 + operations={"step1": operation}, + service_client=mock_lambda_client, + batcher_config=config, + ) + + with pytest.raises( + Exception, + match="The execution operation in execution state does not have EXECUTION type: OperationType.STEP", + ): + state.get_execution_operation() + + +def test_initial_execution_state_get_input_payload_none(): + """Test get_input_payload returns None when execution_details is None.""" + operation = Operation( + operation_id="exec1", + operation_type=OperationType.EXECUTION, + status=OperationStatus.STARTED, + execution_details=None, + ) + + operation = Operation( + operation_id="step1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + ) + + mock_lambda_client = Mock(spec=LambdaClient) + config = CheckpointBatcherConfig( + max_batch_size_bytes=10 * 1024 * 1024, + max_batch_time_seconds=10.0, + max_batch_operations=2, + ) + state = ExecutionState( + durable_execution_arn="test_arn/exec1", + initial_checkpoint_token="token123", # noqa: S106 + operations={"step1": operation}, + service_client=mock_lambda_client, + batcher_config=config, + ) + + result = state.get_input_payload() + assert result is None