From 5776e0a4b01c7375b3ae1d5398d7fe0a5d4ae04b Mon Sep 17 00:00:00 2001 From: Diptanu Choudhury Date: Tue, 3 Dec 2024 07:56:44 -0800 Subject: [PATCH] change json encoding behavior (#1070) --- python-sdk/indexify/executor/agent.py | 203 ++++-------------- python-sdk/indexify/executor/downloader.py | 79 +++---- python-sdk/indexify/executor/task_reporter.py | 43 ++-- python-sdk/indexify/functions_sdk/graph.py | 27 +-- .../functions_sdk/graph_definition.py | 12 +- .../functions_sdk/indexify_functions.py | 47 ++-- .../functions_sdk/object_serializer.py | 24 ++- python-sdk/indexify/http_client.py | 14 +- python-sdk/poetry.lock | 73 +++++-- python-sdk/pyproject.toml | 2 +- python-sdk/tests/test_graph_behaviours.py | 22 +- server/data_model/src/lib.rs | 14 +- server/data_model/src/test_objects.rs | 3 +- server/src/http_objects.rs | 32 ++- server/src/routes/invoke.rs | 2 +- 15 files changed, 274 insertions(+), 323 deletions(-) diff --git a/python-sdk/indexify/executor/agent.py b/python-sdk/indexify/executor/agent.py index 8b8083460..0a62b9a82 100644 --- a/python-sdk/indexify/executor/agent.py +++ b/python-sdk/indexify/executor/agent.py @@ -1,28 +1,21 @@ import asyncio import json -import traceback from concurrent.futures.process import BrokenProcessPool from importlib.metadata import version from pathlib import Path from typing import Dict, List, Optional +import structlog from httpx_sse import aconnect_sse from pydantic import BaseModel -from rich.console import Console -from rich.panel import Panel -from rich.text import Text -from rich.theme import Theme from indexify.common_util import get_httpx_client from indexify.functions_sdk.data_objects import ( FunctionWorkerOutput, IndexifyData, ) -from indexify.functions_sdk.graph_definition import ComputeGraphMetadata from indexify.http_client import IndexifyClient -from ..functions_sdk.image import ImageInformation -from . import image_dependency_installer from .api_objects import ExecutorMetadata, Task from .downloader import DownloadedInputs, Downloader from .executor_tasks import DownloadGraphTask, DownloadInputTask, ExtractTask @@ -31,16 +24,7 @@ from .task_reporter import TaskReporter from .task_store import CompletedTask, TaskStore -custom_theme = Theme( - { - "info": "cyan", - "warning": "yellow", - "error": "red", - "success": "green", - } -) - -console = Console(theme=custom_theme) +logging = structlog.get_logger(module=__name__) class FunctionInput(BaseModel): @@ -68,21 +52,9 @@ def __init__( self._config_path = config_path self._probe = RuntimeProbes() - runtime_probe: ProbeInfo = self._probe.probe() - self._require_image_bootstrap = ( - True - if (runtime_probe.is_default_executor and self.name_alias is not None) - else False - ) - self._executor_bootstrap_failed = False - - console.print( - f"Require Bootstrap? {self._require_image_bootstrap}", style="cyan bold" - ) - self.num_workers = num_workers if config_path: - console.print("Running the extractor with TLS enabled", style="cyan bold") + logging.info("running the extractor with TLS enabled") self._protocol = "https" else: self._protocol = "http" @@ -111,10 +83,8 @@ def __init__( ) async def task_completion_reporter(self): - console.print(Text("Starting task completion reporter", style="bold cyan")) + logging.info("starting task completion reporter") # We should copy only the keys and not the values - url = f"{self._protocol}://{self._server_addr}/write_content" - while True: outcomes = await self._task_store.task_outcomes() for task_outcome in outcomes: @@ -129,16 +99,14 @@ async def task_completion_reporter(self): if "fail" in outcome else f"[bold green] {outcome} [/]" ) - console.print( - Panel( - f"Reporting outcome of task: {task_outcome.task.id}, function: {task_outcome.task.compute_fn}\n" - f"Outcome: {style_outcome}\n" - f"Num Fn Outputs: {len(task_outcome.outputs or [])}\n" - f"Router Output: {task_outcome.router_output}\n" - f"Retries: {task_outcome.reporting_retries}", - title="Task Completion", - border_style="info", - ) + logging.info( + "reporting_task_outcome", + task_id=task_outcome.task.id, + fn_name=task_outcome.task.compute_fn, + num_outputs=len(task_outcome.outputs or []), + router_output=task_outcome.router_output, + outcome=task_outcome.task_outcome, + retries=task_outcome.reporting_retries, ) try: @@ -146,15 +114,11 @@ async def task_completion_reporter(self): self._task_reporter.report_task_outcome(completed_task=task_outcome) except Exception as e: # The connection was dropped in the middle of the reporting, process, retry - console.print( - Panel( - f"Failed to report task {task_outcome.task.id}\n" - f"Exception: {type(e).__name__}({e})\n" - f"Retries: {task_outcome.reporting_retries}\n" - "Retrying...", - title="Reporting Error", - border_style="error", - ) + logging.error( + "failed_to_report_task", + task_id=task_outcome.task.id, + exception=f"exception: {type(e).__name__}({e})", + retries=task_outcome.reporting_retries, ) task_outcome.reporting_retries += 1 await asyncio.sleep(5) @@ -176,44 +140,6 @@ async def task_launcher(self): fn: FunctionInput for fn in fn_queue: task: Task = self._task_store.get_task(fn.task_id) - - if self._executor_bootstrap_failed: - completed_task = CompletedTask( - task=task, - outputs=[], - task_outcome="failure", - ) - self._task_store.complete(outcome=completed_task) - - continue - - # Bootstrap this executor. Fail the task if we can't. - if self._require_image_bootstrap: - try: - image_info = await _get_image_info_for_compute_graph( - task, self._protocol, self._server_addr, self._config_path - ) - image_dependency_installer.executor_image_builder( - image_info, self.name_alias, self.image_version - ) - self._require_image_bootstrap = False - except Exception as e: - console.print( - Text("Failed to bootstrap the executor ", style="red bold") - + Text(f"Exception: {traceback.format_exc()}", style="red") - ) - - self._executor_bootstrap_failed = True - - completed_task = CompletedTask( - task=task, - outputs=[], - task_outcome="failure", - ) - self._task_store.complete(outcome=completed_task) - - continue - async_tasks.append( ExtractTask( function_worker=self._function_worker, @@ -233,12 +159,9 @@ async def task_launcher(self): for async_task in done: if async_task.get_name() == "get_runnable_tasks": if async_task.exception(): - console.print( - Text("Task Launcher Error: ", style="red bold") - + Text( - f"Failed to get runnable tasks: {async_task.exception()}", - style="red", - ) + logging.error( + "task_launcher_error, failed to get runnable tasks", + exception=async_task.exception(), ) continue result: Dict[str, Task] = await async_task @@ -255,12 +178,9 @@ async def task_launcher(self): ) elif async_task.get_name() == "download_graph": if async_task.exception(): - console.print( - Text( - f"Failed to download graph for task {async_task.task.id}\n", - style="red bold", - ) - + Text(f"Exception: {async_task.exception()}", style="red") + logging.error( + "task_launcher_error, failed to download graph", + exception=async_task.exception(), ) completed_task = CompletedTask( task=async_task.task, @@ -276,12 +196,9 @@ async def task_launcher(self): ) elif async_task.get_name() == "download_input": if async_task.exception(): - console.print( - Text( - f"Failed to download input for task {async_task.task.id}\n", - style="red bold", - ) - + Text(f"Exception: {async_task.exception()}", style="red") + logging.error( + "task_launcher_error, failed to download input", + exception=str(async_task.exception()), ) completed_task = CompletedTask( task=async_task.task, @@ -334,12 +251,10 @@ async def task_launcher(self): self._task_store.retriable_failure(async_task.task.id) continue except Exception as e: - console.print( - Text( - f"Failed to execute task {async_task.task.id}\n", - style="red bold", - ) - + Text(f"Exception: {e}", style="red") + logging.error( + "failed to execute task", + task_id=async_task.task.id, + exception=str(e), ) completed_task = CompletedTask( task=async_task.task, @@ -360,12 +275,6 @@ async def run(self): self._should_run = True while self._should_run: url = f"{self._protocol}://{self._server_addr}/internal/executors/{self._executor_id}/tasks" - print(f"calling url: {url}") - - def to_sentence_case(snake_str): - words = snake_str.split("_") - return words[0].capitalize() + "" + " ".join(words[1:]) - runtime_probe: ProbeInfo = self._probe.probe() executor_version = version("indexify") @@ -391,16 +300,7 @@ def to_sentence_case(snake_str): labels=runtime_probe.labels, ).model_dump() - panel_content = "\n".join( - [f"{to_sentence_case(key)}: {value}" for key, value in data.items()] - ) - console.print( - Panel( - panel_content, - title="attempting to Register Executor", - border_style="cyan", - ) - ) + logging.info("registering_executor", executor_id=self._executor_id) try: async with get_httpx_client(self._config_path, True) as client: async with aconnect_sse( @@ -412,11 +312,15 @@ def to_sentence_case(snake_str): ) as event_source: if not event_source.response.is_success: resp = await event_source.response.aread().decode("utf-8") - console.print(f"failed to register: {str(resp)}") + logging.error( + f"failed to register", + resp=str(resp), + status_code=event_source.response.status_code, + ) await asyncio.sleep(5) continue - console.print( - Text("executor registered successfully", style="bold green") + logging.info( + "executor_registered", executor_id=self._executor_id ) async for sse in event_source.aiter_sse(): data = json.loads(sse.data) @@ -427,15 +331,12 @@ def to_sentence_case(snake_str): ) self._task_store.add_tasks(tasks) except Exception as e: - console.print( - Text("registration Error: ", style="red bold") - + Text(f"failed to register: {e}", style="red") - ) + logging.error(f"failed to register: {e}") await asyncio.sleep(5) continue async def _shutdown(self, loop): - console.print(Text("shutting down agent...", style="bold yellow")) + logging.info("shutting_down") self._should_run = False for task in asyncio.all_tasks(loop): task.cancel() @@ -443,27 +344,3 @@ async def _shutdown(self, loop): def shutdown(self, loop): self._function_worker.shutdown() loop.create_task(self._shutdown(loop)) - - -async def _get_image_info_for_compute_graph( - task: Task, protocol, server_addr, config_path: str -) -> ImageInformation: - namespace = task.namespace - graph_name: str = task.compute_graph - compute_fn_name: str = task.compute_fn - - http_client = IndexifyClient( - service_url=f"{protocol}://{server_addr}", - namespace=namespace, - config_path=config_path, - ) - compute_graph: ComputeGraphMetadata = http_client.graph(graph_name) - - console.print( - Text( - f"Compute_fn name {compute_fn_name}, ComputeGraph {compute_graph} \n", - style="red yellow", - ) - ) - - return compute_graph.nodes[compute_fn_name].compute_fn.image_information diff --git a/python-sdk/indexify/executor/downloader.py b/python-sdk/indexify/executor/downloader.py index 75ac9bbdc..65d3fd8aa 100644 --- a/python-sdk/indexify/executor/downloader.py +++ b/python-sdk/indexify/executor/downloader.py @@ -2,10 +2,8 @@ from typing import Optional import httpx +import structlog from pydantic import BaseModel -from rich.console import Console -from rich.panel import Panel -from rich.theme import Theme from indexify.functions_sdk.data_objects import IndexifyData @@ -13,15 +11,7 @@ from ..functions_sdk.object_serializer import JsonSerializer, get_serializer from .api_objects import Task -custom_theme = Theme( - { - "info": "cyan", - "warning": "yellow", - "error": "red", - } -) - -console = Console(theme=custom_theme) +logger = structlog.get_logger(module=__name__) class DownloadedInputs(BaseModel): @@ -42,26 +32,21 @@ async def download_graph(self, namespace: str, name: str, version: int) -> str: if os.path.exists(path): return path - console.print( - Panel( - f"Downloading graph: {name}\nPath: {path}", - title="downloader", - border_style="cyan", - ) + logger.info( + "downloading graph", namespace=namespace, name=name, version=version ) - response = self._client.get( f"{self.base_url}/internal/namespaces/{namespace}/compute_graphs/{name}/code" ) try: response.raise_for_status() except httpx.HTTPStatusError as e: - console.print( - Panel( - f"Failed to download graph: {name}\nError: {response.text}", - title="downloader error", - border_style="error", - ) + logger.error( + "failed to download graph", + namespace=namespace, + name=name, + version=version, + error=response.text, ) raise @@ -81,25 +66,17 @@ async def download_input(self, task: Task) -> DownloadedInputs: if task.reducer_output_id: reducer_url = f"{self.base_url}/namespaces/{task.namespace}/compute_graphs/{task.compute_graph}/invocations/{task.invocation_id}/fn/{task.compute_fn}/output/{task.reducer_output_id}" - console.print( - Panel( - f"downloading input\nURL: {url} \n reducer input URL: {reducer_url}", - title="downloader", - border_style="cyan", - ) - ) - + logger.info("downloading input", url=url, reducer_url=reducer_url) response = self._client.get(url) try: response.raise_for_status() except httpx.HTTPStatusError as e: - console.print( - Panel( - f"failed to download input: {task.input_key}\nError: {response.text}", - title="downloader error", - border_style="error", - ) + logger.error( + "failed to download input", + url=url, + reducer_url=reducer_url, + error=response.text, ) raise @@ -108,8 +85,6 @@ async def download_input(self, task: Task) -> DownloadedInputs: if response.headers["content-type"] == JsonSerializer.content_type else "cloudpickle" ) - serializer = get_serializer(encoder) - if task.invocation_id == input_id: return DownloadedInputs( input=IndexifyData( @@ -117,26 +92,24 @@ async def download_input(self, task: Task) -> DownloadedInputs: ), ) - deserialized_content = serializer.deserialize(response.content) + input_payload = response.content if reducer_url: - init_value = self._client.get(reducer_url) + response = self._client.get(reducer_url) try: - init_value.raise_for_status() + response.raise_for_status() + init_value = response.content except httpx.HTTPStatusError as e: - console.print( - Panel( - f"failed to download reducer output: {task.reducer_output_id}\nError: {init_value.text}", - title="downloader error", - border_style="error", - ) + logger.error( + "failed to download reducer output", + url=reducer_url, + error=response.text, ) raise - init_value = serializer.deserialize(init_value.content) return DownloadedInputs( input=IndexifyData( input_id=task.invocation_id, - payload=deserialized_content, + payload=input_payload, encoder=encoder, ), init_value=IndexifyData( @@ -147,7 +120,7 @@ async def download_input(self, task: Task) -> DownloadedInputs: return DownloadedInputs( input=IndexifyData( input_id=task.invocation_id, - payload=deserialized_content, + payload=input_payload, encoder=encoder, ) ) diff --git a/python-sdk/indexify/executor/task_reporter.py b/python-sdk/indexify/executor/task_reporter.py index c905eb94e..6d160c360 100644 --- a/python-sdk/indexify/executor/task_reporter.py +++ b/python-sdk/indexify/executor/task_reporter.py @@ -1,10 +1,9 @@ -import io from typing import Optional import nanoid +import structlog from httpx import Timeout from pydantic import BaseModel -from rich import print from indexify.common_util import get_httpx_client from indexify.executor.api_objects import RouterOutput as ApiRouterOutput @@ -12,6 +11,8 @@ from indexify.executor.task_store import CompletedTask from indexify.functions_sdk.object_serializer import get_serializer +logger = structlog.get_logger(__name__) + # https://github.com/psf/requests/issues/1081#issuecomment-428504128 class ForceMultipartDict(dict): @@ -46,15 +47,14 @@ def report_task_outcome(self, completed_task: CompletedTask): fn_outputs = [] for output in completed_task.outputs or []: serializer = get_serializer(output.encoder) - serialized_output = serializer.serialize(output.payload) fn_outputs.append( ( "node_outputs", - (nanoid.generate(), serialized_output, serializer.content_type), + (nanoid.generate(), output.payload, serializer.content_type), ) ) report.output_count += 1 - report.output_total_bytes += len(serialized_output) + report.output_total_bytes += len(output.payload) if completed_task.stdout: fn_outputs.append( @@ -109,14 +109,17 @@ def report_task_outcome(self, completed_task: CompletedTask): + report.stderr_total_bytes ) - print( - f"[bold]task-reporter[/bold] reporting task outcome " - f"task_id={completed_task.task.id} retries={completed_task.reporting_retries} " - f"total_bytes={total_bytes} total_files={report.output_count + report.stdout_count + report.stderr_count} " - f"output_files={report.output_count} output_bytes={total_bytes} " - f"stdout_bytes={report.stdout_total_bytes} stderr_bytes={report.stderr_total_bytes} " + logger.info( + "reporting task outcome", + task_id=completed_task.task.id, + retries=completed_task.reporting_retries, + total_bytes=total_bytes, + total_files=report.output_count + report.stdout_count + report.stderr_count, + output_files=report.output_count, + output_bytes=total_bytes, + stdout_bytes=report.stdout_total_bytes, + stderr_bytes=report.stderr_total_bytes, ) - # kwargs = { "data": {"task_result": task_result_data}, @@ -137,15 +140,23 @@ def report_task_outcome(self, completed_task: CompletedTask): **kwargs, ) except Exception as e: - print( - f"[bold]task-reporter[/bold] failed to report task outcome retries={completed_task.reporting_retries} {type(e).__name__}({e})" + logger.error( + "failed to report task outcome", + task_id=completed_task.task.id, + retries=completed_task.reporting_retries, + error=type(e).__name__, + message=str(e), ) raise e try: response.raise_for_status() except Exception as e: - print( - f"[bold]task-reporter[/bold] failed to report task outcome retries={completed_task.reporting_retries} {response.text}" + logger.error( + "failed to report task outcome", + task_id=completed_task.task.id, + retries=completed_task.reporting_retries, + status_code=response.status_code, + response_text=response.text, ) raise e diff --git a/python-sdk/indexify/functions_sdk/graph.py b/python-sdk/indexify/functions_sdk/graph.py index f19ea78a1..c2d688ae4 100644 --- a/python-sdk/indexify/functions_sdk/graph.py +++ b/python-sdk/indexify/functions_sdk/graph.py @@ -1,3 +1,4 @@ +import json import sys from collections import defaultdict from queue import deque @@ -101,9 +102,7 @@ def add_node( return self if issubclass(indexify_fn, IndexifyFunction) and indexify_fn.accumulate: - self.accumulator_zero_values[indexify_fn.name] = ( - indexify_fn.accumulate().model_dump() - ) + self.accumulator_zero_values[indexify_fn.name] = indexify_fn.accumulate() self.nodes[indexify_fn.name] = indexify_fn return self @@ -167,7 +166,8 @@ def definition(self) -> ComputeGraphMetadata: reducer=is_reducer, image_name=start_node.image._image_name, image_information=start_node.image.to_image_information(), - encoder=start_node.encoder, + input_encoder=start_node.input_encoder, + output_encoder=start_node.output_encoder, ) metadata_edges = self.edges.copy() metadata_nodes = {} @@ -179,7 +179,8 @@ def definition(self) -> ComputeGraphMetadata: description=node.description or "", source_fn=node_name, target_fns=self.routers[node_name], - encoder=node.encoder, + input_encoder=node.input_encoder, + output_encoder=node.output_encoder, image_name=node.image._image_name, image_information=node.image.to_image_information(), ) @@ -193,7 +194,8 @@ def definition(self) -> ComputeGraphMetadata: reducer=node.accumulate is not None, image_name=node.image._image_name, image_information=node.image.to_image_information(), - encoder=node.encoder, + input_encoder=node.input_encoder, + output_encoder=node.output_encoder, ) ) @@ -212,19 +214,19 @@ def definition(self) -> ComputeGraphMetadata: def run(self, block_until_done: bool = False, **kwargs) -> str: self.validate_graph() start_node = self.nodes[self._start_node] - serializer = get_serializer(start_node.encoder) + serializer = get_serializer(start_node.input_encoder) input = IndexifyData( id=generate(), payload=serializer.serialize(kwargs), - encoder=start_node.encoder, + encoder=start_node.input_encoder, ) print(f"[bold] Invoking {self._start_node}[/bold]") outputs = defaultdict(list) for k, v in self.accumulator_zero_values.items(): node = self.nodes[k] - serializer = get_serializer(node.encoder) + serializer = get_serializer(node.input_encoder) self._accumulator_values[k] = IndexifyData( - payload=serializer.serialize(v), encoder=node.encoder + payload=serializer.serialize(v), encoder=node.input_encoder ) self._results[input.id] = outputs ctx = GraphInvocationContext( @@ -287,7 +289,8 @@ def _run( fn_outputs = function_outputs.ser_outputs print(f"ran {node_name}: num outputs: {len(fn_outputs)}") if self._accumulator_values.get(node_name, None) is not None: - self._accumulator_values[node_name] = fn_outputs[-1].model_copy() + acc_output = fn_outputs[-1].copy() + self._accumulator_values[node_name] = acc_output outputs[node_name] = [] if fn_outputs: outputs[node_name].extend(fn_outputs) @@ -339,7 +342,7 @@ def output( raise ValueError(f"no results found for fn {fn_name} on graph {self.name}") fn = self.nodes[fn_name] fn_model = self.get_function(fn_name).get_output_model() - serializer = get_serializer(fn.encoder) + serializer = get_serializer(fn.output_encoder) outputs = [] for result in results[fn_name]: payload_dict = serializer.deserialize(result.payload) diff --git a/python-sdk/indexify/functions_sdk/graph_definition.py b/python-sdk/indexify/functions_sdk/graph_definition.py index f645a68b6..ecb3ae851 100644 --- a/python-sdk/indexify/functions_sdk/graph_definition.py +++ b/python-sdk/indexify/functions_sdk/graph_definition.py @@ -14,7 +14,8 @@ class FunctionMetadata(BaseModel): reducer: bool = False image_name: str image_information: ImageInformation - encoder: str = "cloudpickle" + input_encoder: str = "cloudpickle" + output_encoder: str = "cloudpickle" class RouterMetadata(BaseModel): @@ -24,7 +25,8 @@ class RouterMetadata(BaseModel): target_fns: List[str] image_name: str image_information: ImageInformation - encoder: str = "cloudpickle" + input_encoder: str = "cloudpickle" + output_encoder: str = "cloudpickle" class NodeMetadata(BaseModel): @@ -49,12 +51,12 @@ class ComputeGraphMetadata(BaseModel): replaying: bool = False def get_input_payload_serializer(self): - return get_serializer(self.start_node.compute_fn.encoder) + return get_serializer(self.start_node.compute_fn.input_encoder) def get_input_encoder(self) -> str: if self.start_node.compute_fn: - return self.start_node.compute_fn.encoder + return self.start_node.compute_fn.input_encoder elif self.start_node.dynamic_router: - return self.start_node.dynamic_router.encoder + return self.start_node.dynamic_router.input_encoder raise ValueError("start node is not set on the graph") diff --git a/python-sdk/indexify/functions_sdk/indexify_functions.py b/python-sdk/indexify/functions_sdk/indexify_functions.py index b57db40ba..ec39d1092 100644 --- a/python-sdk/indexify/functions_sdk/indexify_functions.py +++ b/python-sdk/indexify/functions_sdk/indexify_functions.py @@ -83,7 +83,8 @@ class IndexifyFunction: image: Optional[Image] = DEFAULT_IMAGE_3_10 placement_constraints: List[PlacementConstraints] = [] accumulate: Optional[Type[Any]] = None - encoder: Optional[str] = "cloudpickle" + input_encoder: Optional[str] = "cloudpickle" + output_encoder: Optional[str] = "cloudpickle" def run(self, *args, **kwargs) -> Union[List[Any], Any]: pass @@ -95,7 +96,7 @@ def partial(self, **kwargs) -> Callable: @classmethod def deserialize_output(cls, output: IndexifyData) -> Any: - serializer = get_serializer(cls.encoder) + serializer = get_serializer(cls.output_encoder) return serializer.deserialize(output.payload) @@ -104,7 +105,8 @@ class IndexifyRouter: description: str = "" image: Optional[Image] = DEFAULT_IMAGE_3_10 placement_constraints: List[PlacementConstraints] = [] - encoder: Optional[str] = "cloudpickle" + input_encoder: Optional[str] = "cloudpickle" + output_encoder: Optional[str] = "cloudpickle" def run(self, *args, **kwargs) -> Optional[List[IndexifyFunction]]: pass @@ -120,7 +122,8 @@ def indexify_router( description: Optional[str] = "", image: Optional[Image] = DEFAULT_IMAGE_3_10, placement_constraints: List[PlacementConstraints] = [], - encoder: Optional[str] = "cloudpickle", + input_encoder: Optional[str] = "cloudpickle", + output_encoder: Optional[str] = "cloudpickle", ): def construct(fn): # Get function signature using inspect.signature @@ -144,7 +147,8 @@ def run(self, *args, **kwargs): ), "image": image, "placement_constraints": placement_constraints, - "encoder": encoder, + "input_encoder": input_encoder, + "output_encoder": output_encoder, "run": run, } @@ -158,7 +162,8 @@ def indexify_function( description: Optional[str] = "", image: Optional[Image] = DEFAULT_IMAGE_3_10, accumulate: Optional[Type[BaseModel]] = None, - encoder: Optional[str] = "cloudpickle", + input_encoder: Optional[str] = "cloudpickle", + output_encoder: Optional[str] = "cloudpickle", placement_constraints: List[PlacementConstraints] = [], ): def construct(fn): @@ -184,7 +189,8 @@ def run(self, *args, **kwargs): "image": image, "placement_constraints": placement_constraints, "accumulate": accumulate, - "encoder": encoder, + "input_encoder": input_encoder, + "output_encoder": output_encoder, "run": run, } @@ -231,6 +237,18 @@ def get_output_model(self) -> Any: ) return return_type + def get_input_types(self) -> Dict[str, Any]: + if not isinstance(self.indexify_function, IndexifyFunction): + raise TypeError("Input must be an instance of IndexifyFunction") + + extract_method = self.indexify_function.run + type_hints = get_type_hints(extract_method) + return { + k: v + for k, v in type_hints.items() + if k != "return" and not is_pydantic_model_from_annotation(v) + } + def run_router( self, input: Union[Dict, Type[BaseModel]] ) -> Tuple[List[str], Optional[str]]: @@ -280,20 +298,17 @@ def invoke_fn_ser( self, name: str, input: IndexifyData, acc: Optional[Any] = None ) -> FunctionCallResult: input = self.deserialize_input(name, input) - serializer = get_serializer(self.indexify_function.encoder) + input_serializer = get_serializer(self.indexify_function.input_encoder) + output_serializer = get_serializer(self.indexify_function.output_encoder) if acc is not None: - acc = self.indexify_function.accumulate.model_validate( - serializer.deserialize(acc.payload) - ) + acc = input_serializer.deserialize(acc.payload) if acc is None and self.indexify_function.accumulate is not None: - acc = self.indexify_function.accumulate.model_validate( - self.indexify_function.accumulate() - ) + acc = self.indexify_function.accumulate() outputs, err = self.run_fn(input, acc=acc) ser_outputs = [ IndexifyData( - payload=serializer.serialize(output), - encoder=self.indexify_function.encoder, + payload=output_serializer.serialize(output), + encoder=self.indexify_function.output_encoder, ) for output in outputs ] diff --git a/python-sdk/indexify/functions_sdk/object_serializer.py b/python-sdk/indexify/functions_sdk/object_serializer.py index 708cd147a..45e48e881 100644 --- a/python-sdk/indexify/functions_sdk/object_serializer.py +++ b/python-sdk/indexify/functions_sdk/object_serializer.py @@ -1,7 +1,7 @@ -from typing import Any, List +import json +from typing import Any, List, Type import cloudpickle -import jsonpickle def get_serializer(serializer_type: str) -> Any: @@ -22,19 +22,29 @@ class JsonSerializer: @staticmethod def serialize(data: Any) -> str: - return jsonpickle.encode(data) + try: + return json.dumps(data) + except Exception as e: + raise ValueError(f"failed to serialize data with json: {e}") @staticmethod def deserialize(data: str) -> Any: - return jsonpickle.decode(data) + try: + if isinstance(data, bytes): + data = data.decode("utf-8") + return json.loads(data) + except Exception as e: + raise ValueError(f"failed to deserialize data with json: {e}") @staticmethod def serialize_list(data: List[Any]) -> str: - return jsonpickle.encode(data) + return json.dumps(data) @staticmethod - def deserialize_list(data: str) -> List[Any]: - return jsonpickle.decode(data) + def deserialize_list(data: str, t: Type) -> List[Any]: + if isinstance(data, bytes): + data = data.decode("utf-8") + return json.loads(data) class CloudPickleSerializer: diff --git a/python-sdk/indexify/http_client.py b/python-sdk/indexify/http_client.py index b2a080cb4..7688bb0a9 100644 --- a/python-sdk/indexify/http_client.py +++ b/python-sdk/indexify/http_client.py @@ -274,10 +274,10 @@ def invoke_graph_with_object( self, graph: str, block_until_done: bool = False, - serializer: str = "cloudpickle", + input_encoding: str = "cloudpickle", **kwargs, ) -> str: - serializer = get_serializer(serializer) + serializer = get_serializer(input_encoding) ser_input = serializer.serialize(kwargs) params = {"block_until_finish": block_until_done} kwargs = { @@ -351,11 +351,11 @@ def _download_output( ) response.raise_for_status() content_type = response.headers.get("Content-Type") - serializer = get_serializer(content_type) - decoded_response = serializer.deserialize(response.content) - return IndexifyData( - id=output_id, payload=decoded_response, encoder=serializer.encoding_type - ) + if content_type == "application/octet-stream": + encoding = "cloudpickle" + else: + encoding = "json" + return IndexifyData(id=output_id, payload=response.content, encoder=encoding) def graph_outputs( self, diff --git a/python-sdk/poetry.lock b/python-sdk/poetry.lock index ef9bfbcfa..2bfa2f462 100644 --- a/python-sdk/poetry.lock +++ b/python-sdk/poetry.lock @@ -444,24 +444,6 @@ files = [ [package.extras] colors = ["colorama (>=0.4.6)"] -[[package]] -name = "jsonpickle" -version = "4.0.0" -description = "jsonpickle encodes/decodes any Python object to/from JSON" -optional = false -python-versions = ">=3.8" -files = [ - {file = "jsonpickle-4.0.0-py3-none-any.whl", hash = "sha256:53730b9e094bc41f540bfdd25eaf6e6cf43811590e9e1477abcec44b866ddcd9"}, - {file = "jsonpickle-4.0.0.tar.gz", hash = "sha256:fc670852b204d77601b08f8f9333149ac37ab6d3fe4e6ed3b578427291f63736"}, -] - -[package.extras] -cov = ["pytest-cov"] -dev = ["black", "pyupgrade"] -docs = ["furo", "rst.linker (>=1.9)", "sphinx (>=3.5)"] -packaging = ["build", "setuptools (>=61.2)", "setuptools-scm[toml] (>=6.0)", "twine"] -testing = ["PyYAML", "atheris (>=2.3.0,<2.4.0)", "bson", "ecdsa", "feedparser", "gmpy2", "numpy", "pandas", "pymongo", "pytest (>=6.0,!=8.1.*)", "pytest-benchmark", "pytest-benchmark[histogram]", "pytest-checkdocs (>=1.2.3)", "pytest-enabler (>=1.0.1)", "pytest-ruff (>=0.2.1)", "scikit-learn", "scipy", "scipy (>=1.9.3)", "simplejson", "sqlalchemy", "ujson"] - [[package]] name = "lazy-object-proxy" version = "1.10.0" @@ -954,15 +936,62 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "structlog" +version = "24.4.0" +description = "Structured Logging for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "structlog-24.4.0-py3-none-any.whl", hash = "sha256:597f61e80a91cc0749a9fd2a098ed76715a1c8a01f73e336b746504d1aad7610"}, + {file = "structlog-24.4.0.tar.gz", hash = "sha256:b27bfecede327a6d2da5fbc96bd859f114ecc398a6389d664f62085ee7ae6fc4"}, +] + +[package.extras] +dev = ["freezegun (>=0.2.8)", "mypy (>=1.4)", "pretend", "pytest (>=6.0)", "pytest-asyncio (>=0.17)", "rich", "simplejson", "twisted"] +docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-mermaid", "sphinxext-opengraph", "twisted"] +tests = ["freezegun (>=0.2.8)", "pretend", "pytest (>=6.0)", "pytest-asyncio (>=0.17)", "simplejson"] +typing = ["mypy (>=1.4)", "rich", "twisted"] + [[package]] name = "tomli" -version = "2.1.0" +version = "2.2.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" files = [ - {file = "tomli-2.1.0-py3-none-any.whl", hash = "sha256:a5c57c3d1c56f5ccdf89f6523458f60ef716e210fc47c4cfb188c5ba473e0391"}, - {file = "tomli-2.1.0.tar.gz", hash = "sha256:3f646cae2aec94e17d04973e4249548320197cfabdf130015d023de4b74d8ab8"}, + {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, + {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8"}, + {file = "tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff"}, + {file = "tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e"}, + {file = "tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98"}, + {file = "tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4"}, + {file = "tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7"}, + {file = "tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744"}, + {file = "tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec"}, + {file = "tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69"}, + {file = "tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc"}, + {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"}, ] [[package]] @@ -1098,4 +1127,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "10e74e0011f33bc047a3a11d4b91005eedc82fda725bc37d20b95a288070ea5d" +content-hash = "31628f9c231567383a7d91a62d61aa2aa51477010443d7e7343a5191539969f7" diff --git a/python-sdk/pyproject.toml b/python-sdk/pyproject.toml index 52331c64d..75655b886 100644 --- a/python-sdk/pyproject.toml +++ b/python-sdk/pyproject.toml @@ -22,7 +22,7 @@ nanoid = "^2.0.0" docker = "^7.1.0" typer = "^0.13.0" httpx-sse = "^0.4.0" -jsonpickle = "^4.0.0" +structlog = "^24.4.0" [tool.poetry.dev-dependencies] black = "^24.10.0" diff --git a/python-sdk/tests/test_graph_behaviours.py b/python-sdk/tests/test_graph_behaviours.py index 73c2b819c..85f999d10 100644 --- a/python-sdk/tests/test_graph_behaviours.py +++ b/python-sdk/tests/test_graph_behaviours.py @@ -4,6 +4,7 @@ from parameterized import parameterized from pydantic import BaseModel +from typing_extensions import TypedDict from indexify import ( Graph, @@ -33,12 +34,12 @@ def simple_function_multiple_inputs(x: MyObject, y: int) -> MyObject: return MyObject(x=x.x + suf) -@indexify_function(encoder="json") +@indexify_function(input_encoder="json", output_encoder="json") def simple_function_with_json_encoder(x: str) -> str: return x + "b" -@indexify_function(encoder="json") +@indexify_function(input_encoder="json", output_encoder="json") def simple_function_multiple_inputs_json(x: str, y: int) -> str: suf = "".join(["b" for _ in range(y)]) return x + suf @@ -49,7 +50,7 @@ def simple_function_with_str_as_input(x: str) -> str: return x + "cc" -@indexify_function(encoder="invalid") +@indexify_function(input_encoder="invalid") def simple_function_with_invalid_encoder(x: MyObject) -> MyObject: return MyObject(x=x.x + "b") @@ -104,7 +105,7 @@ def square(x: int) -> int: return x * x -@indexify_function(encoder="json") +@indexify_function(input_encoder="json", output_encoder="json") def square_with_json_encoder(x: int) -> int: return x * x @@ -119,9 +120,14 @@ def sum_of_squares(init_value: Sum, x: int) -> Sum: return init_value -@indexify_function(accumulate=Sum, encoder="json") -def sum_of_squares_with_json_encoding(init_value: Sum, x: int) -> Sum: - init_value.val += x +class JsonSum(TypedDict): + val: int + + +@indexify_function(accumulate=JsonSum, input_encoder="json") +def sum_of_squares_with_json_encoding(init_value: JsonSum, x: int) -> JsonSum: + val = init_value.get("val", 0) + init_value["val"] = val + x return init_value @@ -331,7 +337,7 @@ def test_map_reduce_operation_with_json_encoding(self, is_remote): output_sum_sq_with_json_encoding = graph.output( invocation_id, "sum_of_squares_with_json_encoding" ) - self.assertEqual(output_sum_sq_with_json_encoding, [Sum(val=9)]) + self.assertEqual(output_sum_sq_with_json_encoding, [{"val": 9}]) @parameterized.expand([(False), (True)]) def test_graph_with_different_encoders(self, is_remote=False): diff --git a/server/data_model/src/lib.rs b/server/data_model/src/lib.rs index b0b378971..e59716e65 100644 --- a/server/data_model/src/lib.rs +++ b/server/data_model/src/lib.rs @@ -119,13 +119,20 @@ impl ImageInformation { } } +fn default_data_encoder() -> String { + "cloudpickle".to_string() +} + #[derive(Default, Debug, Clone, Serialize, Deserialize, Builder, PartialEq, Eq)] pub struct DynamicEdgeRouter { pub name: String, pub description: String, pub source_fn: String, pub target_functions: Vec, - pub encoder: String, + #[serde(default = "default_data_encoder")] + pub input_encoder: String, + #[serde(default = "default_data_encoder")] + pub output_encoder: String, pub image_name: String, pub image_information: ImageInformation, } @@ -137,7 +144,10 @@ pub struct ComputeFn { pub placement_constraints: LabelsFilter, pub fn_name: String, pub reducer: bool, - pub encoder: String, + #[serde(default = "default_data_encoder")] + pub input_encoder: String, + #[serde(default = "default_data_encoder")] + pub output_encoder: String, pub image_name: String, pub image_information: ImageInformation, } diff --git a/server/data_model/src/test_objects.rs b/server/data_model/src/test_objects.rs index 456cb3f0d..d1f60e2d5 100644 --- a/server/data_model/src/test_objects.rs +++ b/server/data_model/src/test_objects.rs @@ -200,7 +200,8 @@ pub mod tests { description: "description router_x".to_string(), source_fn: "fn_a".to_string(), target_functions: vec!["fn_b".to_string(), "fn_c".to_string()], - encoder: "cloudpickle".to_string(), + input_encoder: "cloudpickle".to_string(), + output_encoder: "cloudpickle".to_string(), image_name: TEST_EXECUTOR_IMAGE_NAME.to_string(), image_information: ImageInformation { image_name: "test-image".to_string(), diff --git a/server/src/http_objects.rs b/server/src/http_objects.rs index a6c8e910e..9749c6993 100644 --- a/server/src/http_objects.rs +++ b/server/src/http_objects.rs @@ -126,13 +126,20 @@ impl From for ImageInformation { } } +fn default_encoder() -> String { + "cloudpickle".to_string() +} + #[derive(Debug, Serialize, Deserialize, ToSchema, Clone)] pub struct ComputeFn { pub name: String, pub fn_name: String, pub description: String, pub reducer: bool, - pub encoder: String, + #[serde(default = "default_encoder")] + pub input_encoder: String, + #[serde(default = "default_encoder")] + pub output_encoder: String, pub image_name: String, pub image_information: ImageInformation, } @@ -145,7 +152,8 @@ impl From for data_model::ComputeFn { description: val.description.clone(), placement_constraints: Default::default(), reducer: val.reducer, - encoder: val.encoder.clone(), + input_encoder: val.input_encoder.clone(), + output_encoder: val.output_encoder.clone(), image_name: val.image_name.clone(), image_information: val.image_information.into(), } @@ -159,7 +167,8 @@ impl From for ComputeFn { fn_name: c.fn_name, description: c.description, reducer: c.reducer, - encoder: c.encoder, + input_encoder: c.input_encoder, + output_encoder: c.output_encoder, image_name: c.image_name, image_information: c.image_information.into(), } @@ -172,7 +181,10 @@ pub struct DynamicRouter { pub source_fn: String, pub description: String, pub target_fns: Vec, - pub encoder: String, + #[serde(default = "default_encoder")] + pub input_encoder: String, + #[serde(default = "default_encoder")] + pub output_encoder: String, pub image_name: String, pub image_information: ImageInformation, } @@ -184,7 +196,8 @@ impl From for data_model::DynamicEdgeRouter { source_fn: val.source_fn.clone(), description: val.description.clone(), target_functions: val.target_fns.clone(), - encoder: val.encoder.clone(), + input_encoder: val.input_encoder.clone(), + output_encoder: val.output_encoder.clone(), image_name: val.image_name.clone(), image_information: val.image_information.clone().into(), } @@ -198,7 +211,8 @@ impl From for DynamicRouter { source_fn: d.source_fn, description: d.description, target_fns: d.target_functions, - encoder: d.encoder, + input_encoder: d.input_encoder, + output_encoder: d.output_encoder, image_name: d.image_name, image_information: d.image_information.into(), } @@ -559,7 +573,7 @@ mod tests { // Don't delete this. It makes it easier // to test the deserialization of the ComputeGraph struct // from the python side - let json = r#"{"name":"test","description":"test","start_node":{"compute_fn":{"name":"extractor_a","fn_name":"extractor_a","description":"Random description of extractor_a", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "encoder":"cloudpickle", "image_name": "default_image"}},"nodes":{"extractor_a":{"compute_fn":{"name":"extractor_a","fn_name":"extractor_a","description":"Random description of extractor_a", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "encoder":"cloudpickle", "image_name": "default_image"}},"extractor_b":{"compute_fn":{"name":"extractor_b","fn_name":"extractor_b","description":"", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "encoder":"cloudpickle", "image_name": "default_image"}},"extractor_c":{"compute_fn":{"name":"extractor_c","fn_name":"extractor_c","description":"", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "encoder":"cloudpickle", "image_name": "default_image"}}},"edges":{"extractor_a":["extractor_b"],"extractor_b":["extractor_c"]},"runtime_information": {"major_version": 3, "minor_version": 10}}"#; + let json = r#"{"name":"test","description":"test","start_node":{"compute_fn":{"name":"extractor_a","fn_name":"extractor_a","description":"Random description of extractor_a", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "input_encoder":"cloudpickle", "output_encoder":"cloudpickle", "image_name": "default_image"}},"nodes":{"extractor_a":{"compute_fn":{"name":"extractor_a","fn_name":"extractor_a","description":"Random description of extractor_a", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "input_encoder":"cloudpickle", "output_encoder":"cloudpickle","image_name": "default_image"}},"extractor_b":{"compute_fn":{"name":"extractor_b","fn_name":"extractor_b","description":"", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "input_encoder":"cloudpickle", "output_encoder":"cloudpickle", "image_name": "default_image"}},"extractor_c":{"compute_fn":{"name":"extractor_c","fn_name":"extractor_c","description":"", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "input_encoder":"cloudpickle", "output_encoder":"cloudpickle", "image_name": "default_image"}}},"edges":{"extractor_a":["extractor_b"],"extractor_b":["extractor_c"]},"runtime_information": {"major_version": 3, "minor_version": 10}}"#; let mut json_value: serde_json::Value = serde_json::from_str(json).unwrap(); json_value["namespace"] = serde_json::Value::String("test".to_string()); let _: super::ComputeGraph = serde_json::from_value(json_value).unwrap(); @@ -567,7 +581,7 @@ mod tests { #[test] fn test_compute_graph_with_router_deserialization() { - let json = r#"{"name":"graph_a_router","description":"description of graph_a","start_node":{"compute_fn":{"name":"extractor_a","fn_name":"extractor_a","description":"Random description of extractor_a", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "encoder":"cloudpickle", "image_name": "default_image"}},"nodes":{"extractor_a":{"compute_fn":{"name":"extractor_a","fn_name":"extractor_a","description":"Random description of extractor_a", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "encoder":"cloudpickle", "image_name": "default_image"}},"router_x":{"dynamic_router":{"name":"router_x","description":"","source_fn":"router_x","target_fns":["extractor_y","extractor_z"], "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "encoder":"cloudpickle", "image_name": "default_image"}},"extractor_y":{"compute_fn":{"name":"extractor_y","fn_name":"extractor_y","description":"", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "encoder":"cloudpickle", "image_name": "default_image"}},"extractor_z":{"compute_fn":{"name":"extractor_z","fn_name":"extractor_z","description":"", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "encoder":"cloudpickle", "image_name": "default_image"}},"extractor_c":{"compute_fn":{"name":"extractor_c","fn_name":"extractor_c","description":"", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "encoder":"cloudpickle", "image_name": "default_image"}}},"edges":{"extractor_a":["router_x"],"extractor_y":["extractor_c"],"extractor_z":["extractor_c"]},"runtime_information": {"major_version": 3, "minor_version": 10}}"#; + let json = r#"{"name":"graph_a_router","description":"description of graph_a","start_node":{"compute_fn":{"name":"extractor_a","fn_name":"extractor_a","description":"Random description of extractor_a", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "input_encoder":"cloudpickle", "output_encoder":"cloudpickle", "image_name": "default_image"}},"nodes":{"extractor_a":{"compute_fn":{"name":"extractor_a","fn_name":"extractor_a","description":"Random description of extractor_a", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "input_encoder":"cloudpickle", "output_encoder":"cloudpickle", "image_name": "default_image"}},"router_x":{"dynamic_router":{"name":"router_x","description":"","source_fn":"router_x","target_fns":["extractor_y","extractor_z"], "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "input_encoder":"cloudpickle", "output_encoder":"cloudpickle", "image_name": "default_image"}},"extractor_y":{"compute_fn":{"name":"extractor_y","fn_name":"extractor_y","description":"", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "input_encoder":"cloudpickle", "output_encoder":"cloudpickle", "image_name": "default_image"}},"extractor_z":{"compute_fn":{"name":"extractor_z","fn_name":"extractor_z","description":"", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "input_encoder":"cloudpickle", "output_encoder":"cloudpickle", "image_name": "default_image"}},"extractor_c":{"compute_fn":{"name":"extractor_c","fn_name":"extractor_c","description":"", "reducer": false, "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "input_encoder":"cloudpickle", "output_encoder":"cloudpickle", "image_name": "default_image"}}},"edges":{"extractor_a":["router_x"],"extractor_y":["extractor_c"],"extractor_z":["extractor_c"]},"runtime_information": {"major_version": 3, "minor_version": 10}}"#; let mut json_value: serde_json::Value = serde_json::from_str(json).unwrap(); json_value["namespace"] = serde_json::Value::String("test".to_string()); let _: super::ComputeGraph = serde_json::from_value(json_value).unwrap(); @@ -575,7 +589,7 @@ mod tests { #[test] fn test_compute_fn_deserialization() { - let json = r#"{"name": "one", "fn_name": "two", "description": "desc", "reducer": true, "image_name": "im1", "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "encoder": "clouds"}"#; + let json = r#"{"name": "one", "fn_name": "two", "description": "desc", "reducer": true, "image_name": "im1", "image_information": {"image_name": "name1", "tag": "tag1", "base_image": "base1", "run_strs": ["tuff", "life", "running", "docker"]}, "input_encoder": "cloudpickle", "output_encoder":"cloudpickle"}"#; let compute_fn: ComputeFn = serde_json::from_str(json).unwrap(); println!("{:?}", compute_fn); } diff --git a/server/src/routes/invoke.rs b/server/src/routes/invoke.rs index 4be220476..bc8fb535c 100644 --- a/server/src/routes/invoke.rs +++ b/server/src/routes/invoke.rs @@ -144,7 +144,7 @@ pub async fn invoke_with_file( #[utoipa::path( post, path = "/namespaces/{namespace}/compute_graphs/{compute_graph}/invoke_object", - request_body(content_type = "application/cbor", content = inline(serde_json::Value)), + request_body(content_type = "application/json", content = inline(serde_json::Value)), tag = "ingestion", responses( (status = 200, description = "invocation successful"),