diff --git a/README.md b/README.md index c939db5..02f7555 100644 --- a/README.md +++ b/README.md @@ -241,6 +241,29 @@ async def main() -> None: - Use `getResponse(taskUUID)` to retrieve results at any time - `deliveryMethod="sync"` waits for complete results (may timeout for long-running tasks) +### Retrieving Original Task Request/Response + +To inspect the original request payload and response for a past task, use `getTaskDetails(taskUUID)`. +Known request task types are parsed into SDK request objects when possible; unknown task types remain raw dictionaries. +`details.response` is normalized to a list: +- success: parsed items from `response.data[]` (typed when known) +- failure: items from `response.errors[]` + +```python +from runware import Runware + +async def main() -> None: + runware = Runware(api_key=RUNWARE_API_KEY) + await runware.connect() + + details = await runware.getTaskDetails( + taskUUID="a770f077-f413-47de-9dac-be0b26a35da6" + ) + + print("Original request:", details.request) + print("Original response:", details.response) +``` + ### Enhancing Prompts To enhance prompts using the Runware API, you can use the `promptEnhance` method of the `Runware` class. Here's an example: diff --git a/runware/base.py b/runware/base.py index 7f2bd6f..8cd0ee4 100644 --- a/runware/base.py +++ b/runware/base.py @@ -60,11 +60,13 @@ I3dInference, I3d, IGetResponseRequest, + IGetTaskDetailsRequest, IUploadImageRequest, IUploadMediaRequest, ITextInference, IText, ITextInputs, + ITaskDetails, ) from .types import IImage, IError, SdkType, ListenerType from .utils import ( @@ -256,10 +258,16 @@ async def _handle_pending_operation_message(self, item: "Dict[str, Any]") -> boo task_uuid = item.get("taskUUID") if not task_uuid: return False + task_type = item.get("taskType") on_partial_callback = None async with self._operations_lock: - op = self._pending_operations.get(task_uuid) + operation_key = task_uuid + if task_type: + typed_key = f"{task_uuid}:{task_type}" + if typed_key in self._pending_operations: + operation_key = typed_key + op = self._pending_operations.get(operation_key) if op is None: return False @@ -314,7 +322,13 @@ async def _handle_pending_operation_error(self, error: "Dict[str, Any]") -> bool on_partial_callback = None error_obj = None async with self._operations_lock: - op = self._pending_operations.get(task_uuid) + operation_key = task_uuid + task_type = error.get("taskType") + if task_type: + typed_key = f"{task_uuid}:{task_type}" + if typed_key in self._pending_operations: + operation_key = typed_key + op = self._pending_operations.get(operation_key) if op is None: return False @@ -2189,6 +2203,145 @@ async def _getResponse( number_results=request_model.numberResults, ) + async def getTaskDetails(self, taskUUID: str) -> ITaskDetails: + async with self._request_semaphore: + request = IGetTaskDetailsRequest(taskUUID=taskUUID) + return await self._retry_async_with_reconnect( + self._requestTaskDetails, + request, + task_type=ETaskType.GET_TASK_DETAILS.value, + ) + + async def _requestTaskDetails(self, request_model: IGetTaskDetailsRequest) -> ITaskDetails: + await self.ensureConnection() + request_object = { + "taskType": ETaskType.GET_TASK_DETAILS.value, + "taskUUID": request_model.taskUUID, + } + return await self._handleTaskDetailsResponse( + request_object=request_object, + task_uuid=request_model.taskUUID, + ) + + async def _handleTaskDetailsResponse( + self, + request_object: Dict[str, Any], + task_uuid: str, + ) -> ITaskDetails: + operation_key = f"{task_uuid}:{ETaskType.GET_TASK_DETAILS.value}" + future, should_send = await self._register_pending_operation( + operation_key, + expected_results=1, + complete_predicate=lambda r: ( + r.get("taskType") == ETaskType.GET_TASK_DETAILS.value + and "request" in r + and "response" in r + ), + result_filter=lambda r: r.get("taskType") == ETaskType.GET_TASK_DETAILS.value, + ) + try: + if should_send: + await self.send([request_object]) + await self._mark_operation_sent(operation_key) + + results = await asyncio.wait_for(future, timeout=self._timeout / 1000) + if not results: + raise ValueError(f"No task details found for taskUUID={task_uuid}") + task_details = instantiateDataclass(ITaskDetails, results[0]) + task_details.request = self._normalizeTaskDetailsRequest(task_details.request) + task_details.response = self._normalizeTaskDetailsResponse(task_details.response) + return task_details + except asyncio.TimeoutError: + raise Exception( + f"Timeout waiting for task details | TaskUUID: {task_uuid} | " + f"Timeout: {self._timeout}ms" + ) + finally: + await self._unregister_pending_operation(operation_key) + + def _normalizeTaskDetailsRequest(self, request_items: List[Any]) -> List[Any]: + task_type_map = { + ETaskType.IMAGE_INFERENCE.value: IImageInference, + ETaskType.PHOTO_MAKER.value: IPhotoMaker, + ETaskType.IMAGE_CAPTION.value: IImageCaption, + ETaskType.IMAGE_BACKGROUND_REMOVAL.value: IImageBackgroundRemoval, + ETaskType.IMAGE_UPSCALE.value: IImageUpscale, + ETaskType.PROMPT_ENHANCE.value: IPromptEnhance, + ETaskType.MODEL_SEARCH.value: IModelSearch, + ETaskType.VIDEO_INFERENCE.value: IVideoInference, + ETaskType.VIDEO_CAPTION.value: IVideoCaption, + ETaskType.VIDEO_BACKGROUND_REMOVAL.value: IVideoBackgroundRemoval, + ETaskType.VIDEO_UPSCALE.value: IVideoUpscale, + ETaskType.AUDIO_INFERENCE.value: IAudioInference, + ETaskType.INFERENCE_3D.value: I3dInference, + ETaskType.TEXT_INFERENCE.value: ITextInference, + ETaskType.GET_RESPONSE.value: IGetResponseRequest, + ETaskType.GET_TASK_DETAILS.value: IGetTaskDetailsRequest, + ETaskType.IMAGE_VECTORIZE.value: IVectorize, + } + return self._normalizeTaskDetailsItems( + request_items, + task_type_map, + lambda cls, item: instantiateDataclass(cls, item), + ) + + def _normalizeTaskDetailsResponse(self, response_payload: Any) -> Any: + if not isinstance(response_payload, dict): + return response_payload if isinstance(response_payload, list) else [response_payload] + + data_items = response_payload.get("data") + if isinstance(data_items, list): + response_type_map = { + ETaskType.AUDIO_INFERENCE.value: IAudio, + ETaskType.VIDEO_CAPTION.value: IVideoToText, + ETaskType.IMAGE_CAPTION.value: IImageToText, + ETaskType.IMAGE_INFERENCE.value: IImage, + ETaskType.PHOTO_MAKER.value: IImage, + ETaskType.IMAGE_UPSCALE.value: IImage, + ETaskType.IMAGE_VECTORIZE.value: IImage, + ETaskType.IMAGE_BACKGROUND_REMOVAL.value: IImage, + ETaskType.VIDEO_INFERENCE.value: IVideo, + ETaskType.VIDEO_BACKGROUND_REMOVAL.value: IVideo, + ETaskType.VIDEO_UPSCALE.value: IVideo, + ETaskType.INFERENCE_3D.value: I3d, + ETaskType.TEXT_INFERENCE.value: IText, + ETaskType.PROMPT_ENHANCE.value: IEnhancedPrompt, + ETaskType.GET_TASK_DETAILS.value: ITaskDetails, + } + return self._normalizeTaskDetailsItems( + data_items, + response_type_map, + lambda cls, item: instantiateDataclass(cls, item), + ) + + error_items = response_payload.get("errors") + if isinstance(error_items, list): + return error_items + + return [response_payload] + + def _normalizeTaskDetailsItems( + self, + items: List[Any], + task_type_map: Dict[str, Any], + instantiate_fn: Callable[[Any, Dict[str, Any]], Any], + ) -> List[Any]: + normalized: List[Any] = [] + for item in items: + if not isinstance(item, dict): + normalized.append(item) + continue + task_type = item.get("taskType") + target_cls = task_type_map.get(task_type) + if target_cls is None: + normalized.append(item) + continue + try: + normalized.append(instantiate_fn(target_cls, item)) + except Exception: + normalized.append(item) + return normalized + async def _requestVideo(self, requestVideo: "IVideoInference") -> "Union[List[IVideo], IAsyncTaskResponse]": if requestVideo.frameImages: requestVideo.frameImages = await self._process_media_list( diff --git a/runware/types.py b/runware/types.py index 3f1dc14..2279dc4 100644 --- a/runware/types.py +++ b/runware/types.py @@ -48,6 +48,7 @@ class ETaskType(Enum): VIDEO_CAPTION = "caption" MEDIA_STORAGE = "mediaStorage" GET_RESPONSE = "getResponse" + GET_TASK_DETAILS = "getTaskDetails" IMAGE_VECTORIZE = "vectorize" @@ -138,6 +139,11 @@ class IGetResponseRequest: numberResults: int = 1 +@dataclass +class IGetTaskDetailsRequest: + taskUUID: str + + @dataclass class IUploadImageRequest: file: Union[File, str] @@ -2091,6 +2097,37 @@ class IVideoToText: cost: Optional[float] = None +@dataclass +class ITaskDetails: + taskType: str + taskUUID: str + request: List[ + Union[ + IImageInference, + IPhotoMaker, + IImageCaption, + IImageBackgroundRemoval, + IImageUpscale, + IPromptEnhance, + IModelSearch, + IVideoInference, + IVideoCaption, + IVideoBackgroundRemoval, + IVideoUpscale, + IAudioInference, + I3dInference, + ITextInference, + IGetResponseRequest, + IGetTaskDetailsRequest, + IVectorize, + Dict[str, Any], + ] + ] + response: List[ + Union[IImage, IVideo, IAudio, IVideoToText, IImageToText, I3d, IText, IEnhancedPrompt, Dict[str, Any]] + ] + + # The GetWithPromiseCallBackType is defined using the Callable type from the typing module. It represents a function that takes a dictionary # with specific keys and returns either a boolean or None. # The dictionary should have the following keys: diff --git a/runware/utils.py b/runware/utils.py index 52160e6..b20fe94 100644 --- a/runware/utils.py +++ b/runware/utils.py @@ -921,11 +921,53 @@ def instantiateDataclass(dataclass_type: Type[Any], data: dict) -> Any: continue field_type = hints.get(k) - - # Unwrap Optional[X] -> X + union_args = [] if get_origin(field_type) is Union: - args = [a for a in get_args(field_type) if a is not type(None)] - field_type = args[0] if args else field_type + union_args = [a for a in get_args(field_type) if a is not type(None)] + + + matched_union_type = False + for arg in union_args: + if isinstance(arg, type) and isinstance(v, arg): + filtered_data[k] = v + matched_union_type = True + break + if matched_union_type: + continue + + + if isinstance(v, dict): + dataclass_args = [ + arg for arg in union_args + if isinstance(arg, type) and is_dataclass(arg) + ] + for dataclass_arg in dataclass_args: + try: + filtered_data[k] = instantiateDataclass(dataclass_arg, v) + matched_union_type = True + break + except Exception: + continue + if matched_union_type: + continue + + has_dict_branch = any( + arg is dict or get_origin(arg) is dict + for arg in union_args + ) + if has_dict_branch: + filtered_data[k] = v + continue + + + if isinstance(v, list): + list_arg = next((arg for arg in union_args if get_origin(arg) is list), None) + if list_arg is not None: + field_type = list_arg + else: + field_type = union_args[0] if union_args else field_type + else: + field_type = union_args[0] if union_args else field_type # Nested dataclass if is_dataclass(field_type) and isinstance(v, dict):