Skip to content

Commit

Permalink
change json encoding behavior (#1070)
Browse files Browse the repository at this point in the history
  • Loading branch information
diptanu authored Dec 3, 2024
1 parent ce2db51 commit 5776e0a
Show file tree
Hide file tree
Showing 15 changed files with 274 additions and 323 deletions.
203 changes: 40 additions & 163 deletions python-sdk/indexify/executor/agent.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,21 @@
import asyncio
import json
import traceback
from concurrent.futures.process import BrokenProcessPool
from importlib.metadata import version
from pathlib import Path
from typing import Dict, List, Optional

import structlog
from httpx_sse import aconnect_sse
from pydantic import BaseModel
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from rich.theme import Theme

from indexify.common_util import get_httpx_client
from indexify.functions_sdk.data_objects import (
FunctionWorkerOutput,
IndexifyData,
)
from indexify.functions_sdk.graph_definition import ComputeGraphMetadata
from indexify.http_client import IndexifyClient

from ..functions_sdk.image import ImageInformation
from . import image_dependency_installer
from .api_objects import ExecutorMetadata, Task
from .downloader import DownloadedInputs, Downloader
from .executor_tasks import DownloadGraphTask, DownloadInputTask, ExtractTask
Expand All @@ -31,16 +24,7 @@
from .task_reporter import TaskReporter
from .task_store import CompletedTask, TaskStore

custom_theme = Theme(
{
"info": "cyan",
"warning": "yellow",
"error": "red",
"success": "green",
}
)

console = Console(theme=custom_theme)
logging = structlog.get_logger(module=__name__)


class FunctionInput(BaseModel):
Expand Down Expand Up @@ -68,21 +52,9 @@ def __init__(
self._config_path = config_path
self._probe = RuntimeProbes()

runtime_probe: ProbeInfo = self._probe.probe()
self._require_image_bootstrap = (
True
if (runtime_probe.is_default_executor and self.name_alias is not None)
else False
)
self._executor_bootstrap_failed = False

console.print(
f"Require Bootstrap? {self._require_image_bootstrap}", style="cyan bold"
)

self.num_workers = num_workers
if config_path:
console.print("Running the extractor with TLS enabled", style="cyan bold")
logging.info("running the extractor with TLS enabled")
self._protocol = "https"
else:
self._protocol = "http"
Expand Down Expand Up @@ -111,10 +83,8 @@ def __init__(
)

async def task_completion_reporter(self):
console.print(Text("Starting task completion reporter", style="bold cyan"))
logging.info("starting task completion reporter")
# We should copy only the keys and not the values
url = f"{self._protocol}://{self._server_addr}/write_content"

while True:
outcomes = await self._task_store.task_outcomes()
for task_outcome in outcomes:
Expand All @@ -129,32 +99,26 @@ async def task_completion_reporter(self):
if "fail" in outcome
else f"[bold green] {outcome} [/]"
)
console.print(
Panel(
f"Reporting outcome of task: {task_outcome.task.id}, function: {task_outcome.task.compute_fn}\n"
f"Outcome: {style_outcome}\n"
f"Num Fn Outputs: {len(task_outcome.outputs or [])}\n"
f"Router Output: {task_outcome.router_output}\n"
f"Retries: {task_outcome.reporting_retries}",
title="Task Completion",
border_style="info",
)
logging.info(
"reporting_task_outcome",
task_id=task_outcome.task.id,
fn_name=task_outcome.task.compute_fn,
num_outputs=len(task_outcome.outputs or []),
router_output=task_outcome.router_output,
outcome=task_outcome.task_outcome,
retries=task_outcome.reporting_retries,
)

try:
# Send task outcome to the server
self._task_reporter.report_task_outcome(completed_task=task_outcome)
except Exception as e:
# The connection was dropped in the middle of the reporting, process, retry
console.print(
Panel(
f"Failed to report task {task_outcome.task.id}\n"
f"Exception: {type(e).__name__}({e})\n"
f"Retries: {task_outcome.reporting_retries}\n"
"Retrying...",
title="Reporting Error",
border_style="error",
)
logging.error(
"failed_to_report_task",
task_id=task_outcome.task.id,
exception=f"exception: {type(e).__name__}({e})",
retries=task_outcome.reporting_retries,
)
task_outcome.reporting_retries += 1
await asyncio.sleep(5)
Expand All @@ -176,44 +140,6 @@ async def task_launcher(self):
fn: FunctionInput
for fn in fn_queue:
task: Task = self._task_store.get_task(fn.task_id)

if self._executor_bootstrap_failed:
completed_task = CompletedTask(
task=task,
outputs=[],
task_outcome="failure",
)
self._task_store.complete(outcome=completed_task)

continue

# Bootstrap this executor. Fail the task if we can't.
if self._require_image_bootstrap:
try:
image_info = await _get_image_info_for_compute_graph(
task, self._protocol, self._server_addr, self._config_path
)
image_dependency_installer.executor_image_builder(
image_info, self.name_alias, self.image_version
)
self._require_image_bootstrap = False
except Exception as e:
console.print(
Text("Failed to bootstrap the executor ", style="red bold")
+ Text(f"Exception: {traceback.format_exc()}", style="red")
)

self._executor_bootstrap_failed = True

completed_task = CompletedTask(
task=task,
outputs=[],
task_outcome="failure",
)
self._task_store.complete(outcome=completed_task)

continue

async_tasks.append(
ExtractTask(
function_worker=self._function_worker,
Expand All @@ -233,12 +159,9 @@ async def task_launcher(self):
for async_task in done:
if async_task.get_name() == "get_runnable_tasks":
if async_task.exception():
console.print(
Text("Task Launcher Error: ", style="red bold")
+ Text(
f"Failed to get runnable tasks: {async_task.exception()}",
style="red",
)
logging.error(
"task_launcher_error, failed to get runnable tasks",
exception=async_task.exception(),
)
continue
result: Dict[str, Task] = await async_task
Expand All @@ -255,12 +178,9 @@ async def task_launcher(self):
)
elif async_task.get_name() == "download_graph":
if async_task.exception():
console.print(
Text(
f"Failed to download graph for task {async_task.task.id}\n",
style="red bold",
)
+ Text(f"Exception: {async_task.exception()}", style="red")
logging.error(
"task_launcher_error, failed to download graph",
exception=async_task.exception(),
)
completed_task = CompletedTask(
task=async_task.task,
Expand All @@ -276,12 +196,9 @@ async def task_launcher(self):
)
elif async_task.get_name() == "download_input":
if async_task.exception():
console.print(
Text(
f"Failed to download input for task {async_task.task.id}\n",
style="red bold",
)
+ Text(f"Exception: {async_task.exception()}", style="red")
logging.error(
"task_launcher_error, failed to download input",
exception=str(async_task.exception()),
)
completed_task = CompletedTask(
task=async_task.task,
Expand Down Expand Up @@ -334,12 +251,10 @@ async def task_launcher(self):
self._task_store.retriable_failure(async_task.task.id)
continue
except Exception as e:
console.print(
Text(
f"Failed to execute task {async_task.task.id}\n",
style="red bold",
)
+ Text(f"Exception: {e}", style="red")
logging.error(
"failed to execute task",
task_id=async_task.task.id,
exception=str(e),
)
completed_task = CompletedTask(
task=async_task.task,
Expand All @@ -360,12 +275,6 @@ async def run(self):
self._should_run = True
while self._should_run:
url = f"{self._protocol}://{self._server_addr}/internal/executors/{self._executor_id}/tasks"
print(f"calling url: {url}")

def to_sentence_case(snake_str):
words = snake_str.split("_")
return words[0].capitalize() + "" + " ".join(words[1:])

runtime_probe: ProbeInfo = self._probe.probe()

executor_version = version("indexify")
Expand All @@ -391,16 +300,7 @@ def to_sentence_case(snake_str):
labels=runtime_probe.labels,
).model_dump()

panel_content = "\n".join(
[f"{to_sentence_case(key)}: {value}" for key, value in data.items()]
)
console.print(
Panel(
panel_content,
title="attempting to Register Executor",
border_style="cyan",
)
)
logging.info("registering_executor", executor_id=self._executor_id)
try:
async with get_httpx_client(self._config_path, True) as client:
async with aconnect_sse(
Expand All @@ -412,11 +312,15 @@ def to_sentence_case(snake_str):
) as event_source:
if not event_source.response.is_success:
resp = await event_source.response.aread().decode("utf-8")
console.print(f"failed to register: {str(resp)}")
logging.error(
f"failed to register",
resp=str(resp),
status_code=event_source.response.status_code,
)
await asyncio.sleep(5)
continue
console.print(
Text("executor registered successfully", style="bold green")
logging.info(
"executor_registered", executor_id=self._executor_id
)
async for sse in event_source.aiter_sse():
data = json.loads(sse.data)
Expand All @@ -427,43 +331,16 @@ def to_sentence_case(snake_str):
)
self._task_store.add_tasks(tasks)
except Exception as e:
console.print(
Text("registration Error: ", style="red bold")
+ Text(f"failed to register: {e}", style="red")
)
logging.error(f"failed to register: {e}")
await asyncio.sleep(5)
continue

async def _shutdown(self, loop):
console.print(Text("shutting down agent...", style="bold yellow"))
logging.info("shutting_down")
self._should_run = False
for task in asyncio.all_tasks(loop):
task.cancel()

def shutdown(self, loop):
self._function_worker.shutdown()
loop.create_task(self._shutdown(loop))


async def _get_image_info_for_compute_graph(
task: Task, protocol, server_addr, config_path: str
) -> ImageInformation:
namespace = task.namespace
graph_name: str = task.compute_graph
compute_fn_name: str = task.compute_fn

http_client = IndexifyClient(
service_url=f"{protocol}://{server_addr}",
namespace=namespace,
config_path=config_path,
)
compute_graph: ComputeGraphMetadata = http_client.graph(graph_name)

console.print(
Text(
f"Compute_fn name {compute_fn_name}, ComputeGraph {compute_graph} \n",
style="red yellow",
)
)

return compute_graph.nodes[compute_fn_name].compute_fn.image_information
Loading

0 comments on commit 5776e0a

Please sign in to comment.