-
Notifications
You must be signed in to change notification settings - Fork 15
Added initial commit for getTaskDetails #279
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f2d9c41
3a0aec3
457d645
b95e253
ec37d18
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,11 +60,13 @@ | |
| I3dInference, | ||
| I3d, | ||
| IGetResponseRequest, | ||
| IGetTaskDetailsRequest, | ||
| IUploadImageRequest, | ||
| IUploadMediaRequest, | ||
| ITextInference, | ||
| IText, | ||
| ITextInputs, | ||
| ITaskDetails, | ||
| ) | ||
| from .types import IImage, IError, SdkType, ListenerType | ||
| from .utils import ( | ||
|
|
@@ -256,10 +258,16 @@ async def _handle_pending_operation_message(self, item: "Dict[str, Any]") -> boo | |
| task_uuid = item.get("taskUUID") | ||
| if not task_uuid: | ||
| return False | ||
| task_type = item.get("taskType") | ||
|
|
||
| on_partial_callback = None | ||
| async with self._operations_lock: | ||
| op = self._pending_operations.get(task_uuid) | ||
| operation_key = task_uuid | ||
| if task_type: | ||
| typed_key = f"{task_uuid}:{task_type}" | ||
| if typed_key in self._pending_operations: | ||
| operation_key = typed_key | ||
| op = self._pending_operations.get(operation_key) | ||
| if op is None: | ||
| return False | ||
|
|
||
|
|
@@ -314,7 +322,13 @@ async def _handle_pending_operation_error(self, error: "Dict[str, Any]") -> bool | |
| on_partial_callback = None | ||
| error_obj = None | ||
| async with self._operations_lock: | ||
| op = self._pending_operations.get(task_uuid) | ||
| operation_key = task_uuid | ||
| task_type = error.get("taskType") | ||
| if task_type: | ||
| typed_key = f"{task_uuid}:{task_type}" | ||
| if typed_key in self._pending_operations: | ||
| operation_key = typed_key | ||
| op = self._pending_operations.get(operation_key) | ||
| if op is None: | ||
| return False | ||
|
|
||
|
|
@@ -2189,6 +2203,145 @@ async def _getResponse( | |
| number_results=request_model.numberResults, | ||
| ) | ||
|
|
||
| async def getTaskDetails(self, taskUUID: str) -> ITaskDetails: | ||
| async with self._request_semaphore: | ||
| request = IGetTaskDetailsRequest(taskUUID=taskUUID) | ||
| return await self._retry_async_with_reconnect( | ||
| self._requestTaskDetails, | ||
| request, | ||
| task_type=ETaskType.GET_TASK_DETAILS.value, | ||
| ) | ||
|
Comment on lines
+2206
to
+2213
|
||
|
|
||
| async def _requestTaskDetails(self, request_model: IGetTaskDetailsRequest) -> ITaskDetails: | ||
| await self.ensureConnection() | ||
| request_object = { | ||
| "taskType": ETaskType.GET_TASK_DETAILS.value, | ||
| "taskUUID": request_model.taskUUID, | ||
| } | ||
| return await self._handleTaskDetailsResponse( | ||
| request_object=request_object, | ||
| task_uuid=request_model.taskUUID, | ||
| ) | ||
|
|
||
| async def _handleTaskDetailsResponse( | ||
| self, | ||
| request_object: Dict[str, Any], | ||
| task_uuid: str, | ||
| ) -> ITaskDetails: | ||
| operation_key = f"{task_uuid}:{ETaskType.GET_TASK_DETAILS.value}" | ||
| future, should_send = await self._register_pending_operation( | ||
| operation_key, | ||
| expected_results=1, | ||
| complete_predicate=lambda r: ( | ||
| r.get("taskType") == ETaskType.GET_TASK_DETAILS.value | ||
| and "request" in r | ||
| and "response" in r | ||
| ), | ||
| result_filter=lambda r: r.get("taskType") == ETaskType.GET_TASK_DETAILS.value, | ||
| ) | ||
| try: | ||
| if should_send: | ||
| await self.send([request_object]) | ||
| await self._mark_operation_sent(operation_key) | ||
|
|
||
| results = await asyncio.wait_for(future, timeout=self._timeout / 1000) | ||
| if not results: | ||
| raise ValueError(f"No task details found for taskUUID={task_uuid}") | ||
| task_details = instantiateDataclass(ITaskDetails, results[0]) | ||
| task_details.request = self._normalizeTaskDetailsRequest(task_details.request) | ||
| task_details.response = self._normalizeTaskDetailsResponse(task_details.response) | ||
| return task_details | ||
| except asyncio.TimeoutError: | ||
| raise Exception( | ||
| f"Timeout waiting for task details | TaskUUID: {task_uuid} | " | ||
| f"Timeout: {self._timeout}ms" | ||
| ) | ||
| finally: | ||
| await self._unregister_pending_operation(operation_key) | ||
|
|
||
| def _normalizeTaskDetailsRequest(self, request_items: List[Any]) -> List[Any]: | ||
| task_type_map = { | ||
| ETaskType.IMAGE_INFERENCE.value: IImageInference, | ||
| ETaskType.PHOTO_MAKER.value: IPhotoMaker, | ||
| ETaskType.IMAGE_CAPTION.value: IImageCaption, | ||
| ETaskType.IMAGE_BACKGROUND_REMOVAL.value: IImageBackgroundRemoval, | ||
| ETaskType.IMAGE_UPSCALE.value: IImageUpscale, | ||
| ETaskType.PROMPT_ENHANCE.value: IPromptEnhance, | ||
| ETaskType.MODEL_SEARCH.value: IModelSearch, | ||
| ETaskType.VIDEO_INFERENCE.value: IVideoInference, | ||
| ETaskType.VIDEO_CAPTION.value: IVideoCaption, | ||
| ETaskType.VIDEO_BACKGROUND_REMOVAL.value: IVideoBackgroundRemoval, | ||
| ETaskType.VIDEO_UPSCALE.value: IVideoUpscale, | ||
| ETaskType.AUDIO_INFERENCE.value: IAudioInference, | ||
| ETaskType.INFERENCE_3D.value: I3dInference, | ||
| ETaskType.TEXT_INFERENCE.value: ITextInference, | ||
| ETaskType.GET_RESPONSE.value: IGetResponseRequest, | ||
| ETaskType.GET_TASK_DETAILS.value: IGetTaskDetailsRequest, | ||
| ETaskType.IMAGE_VECTORIZE.value: IVectorize, | ||
| } | ||
| return self._normalizeTaskDetailsItems( | ||
| request_items, | ||
| task_type_map, | ||
| lambda cls, item: instantiateDataclass(cls, item), | ||
| ) | ||
|
|
||
| def _normalizeTaskDetailsResponse(self, response_payload: Any) -> Any: | ||
| if not isinstance(response_payload, dict): | ||
| return response_payload if isinstance(response_payload, list) else [response_payload] | ||
|
|
||
| data_items = response_payload.get("data") | ||
| if isinstance(data_items, list): | ||
| response_type_map = { | ||
| ETaskType.AUDIO_INFERENCE.value: IAudio, | ||
| ETaskType.VIDEO_CAPTION.value: IVideoToText, | ||
| ETaskType.IMAGE_CAPTION.value: IImageToText, | ||
| ETaskType.IMAGE_INFERENCE.value: IImage, | ||
| ETaskType.PHOTO_MAKER.value: IImage, | ||
| ETaskType.IMAGE_UPSCALE.value: IImage, | ||
| ETaskType.IMAGE_VECTORIZE.value: IImage, | ||
| ETaskType.IMAGE_BACKGROUND_REMOVAL.value: IImage, | ||
| ETaskType.VIDEO_INFERENCE.value: IVideo, | ||
| ETaskType.VIDEO_BACKGROUND_REMOVAL.value: IVideo, | ||
| ETaskType.VIDEO_UPSCALE.value: IVideo, | ||
| ETaskType.INFERENCE_3D.value: I3d, | ||
| ETaskType.TEXT_INFERENCE.value: IText, | ||
| ETaskType.PROMPT_ENHANCE.value: IEnhancedPrompt, | ||
| ETaskType.GET_TASK_DETAILS.value: ITaskDetails, | ||
| } | ||
| return self._normalizeTaskDetailsItems( | ||
| data_items, | ||
| response_type_map, | ||
| lambda cls, item: instantiateDataclass(cls, item), | ||
| ) | ||
|
|
||
| error_items = response_payload.get("errors") | ||
| if isinstance(error_items, list): | ||
| return error_items | ||
|
|
||
| return [response_payload] | ||
|
|
||
| def _normalizeTaskDetailsItems( | ||
| self, | ||
| items: List[Any], | ||
| task_type_map: Dict[str, Any], | ||
| instantiate_fn: Callable[[Any, Dict[str, Any]], Any], | ||
| ) -> List[Any]: | ||
| normalized: List[Any] = [] | ||
| for item in items: | ||
| if not isinstance(item, dict): | ||
| normalized.append(item) | ||
| continue | ||
| task_type = item.get("taskType") | ||
| target_cls = task_type_map.get(task_type) | ||
| if target_cls is None: | ||
| normalized.append(item) | ||
| continue | ||
| try: | ||
| normalized.append(instantiate_fn(target_cls, item)) | ||
| except Exception: | ||
| normalized.append(item) | ||
| return normalized | ||
|
|
||
| async def _requestVideo(self, requestVideo: "IVideoInference") -> "Union[List[IVideo], IAsyncTaskResponse]": | ||
| if requestVideo.frameImages: | ||
| requestVideo.frameImages = await self._process_media_list( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -921,11 +921,53 @@ def instantiateDataclass(dataclass_type: Type[Any], data: dict) -> Any: | |
| continue | ||
|
|
||
| field_type = hints.get(k) | ||
|
|
||
| # Unwrap Optional[X] -> X | ||
| union_args = [] | ||
| if get_origin(field_type) is Union: | ||
| args = [a for a in get_args(field_type) if a is not type(None)] | ||
| field_type = args[0] if args else field_type | ||
| union_args = [a for a in get_args(field_type) if a is not type(None)] | ||
|
|
||
|
|
||
| matched_union_type = False | ||
| for arg in union_args: | ||
| if isinstance(arg, type) and isinstance(v, arg): | ||
| filtered_data[k] = v | ||
| matched_union_type = True | ||
| break | ||
| if matched_union_type: | ||
| continue | ||
|
|
||
|
|
||
| if isinstance(v, dict): | ||
| dataclass_args = [ | ||
| arg for arg in union_args | ||
| if isinstance(arg, type) and is_dataclass(arg) | ||
| ] | ||
| for dataclass_arg in dataclass_args: | ||
| try: | ||
| filtered_data[k] = instantiateDataclass(dataclass_arg, v) | ||
| matched_union_type = True | ||
| break | ||
| except Exception: | ||
| continue | ||
| if matched_union_type: | ||
| continue | ||
|
|
||
| has_dict_branch = any( | ||
| arg is dict or get_origin(arg) is dict | ||
| for arg in union_args | ||
| ) | ||
| if has_dict_branch: | ||
| filtered_data[k] = v | ||
| continue | ||
|
|
||
|
|
||
| if isinstance(v, list): | ||
| list_arg = next((arg for arg in union_args if get_origin(arg) is list), None) | ||
| if list_arg is not None: | ||
| field_type = list_arg | ||
| else: | ||
| field_type = union_args[0] if union_args else field_type | ||
| else: | ||
| field_type = union_args[0] if union_args else field_type | ||
|
Sirsho1997 marked this conversation as resolved.
Comment on lines
923
to
+970
|
||
|
|
||
| # Nested dataclass | ||
| if is_dataclass(field_type) and isinstance(v, dict): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.