From 7e1a55bd40004fa9a663e7e15260b6bfa6da399b Mon Sep 17 00:00:00 2001 From: Diptanu Gon Choudhury Date: Sat, 1 Feb 2025 22:55:12 -0800 Subject: [PATCH 1/2] created in-memory state --- server/Cargo.lock | 49 ++++++- server/Cargo.toml | 4 +- server/processor/src/graph_processor.rs | 2 +- server/state_store/Cargo.toml | 3 +- .../src/Dataset Abstraction to create.md | 58 -------- server/state_store/src/in_memory_state.rs | 136 ++++++++++++++++++ server/state_store/src/lib.rs | 13 +- server/state_store/src/requests.rs | 2 +- server/state_store/src/scanner.rs | 16 +++ 9 files changed, 216 insertions(+), 67 deletions(-) delete mode 100644 server/state_store/src/Dataset Abstraction to create.md create mode 100644 server/state_store/src/in_memory_state.rs diff --git a/server/Cargo.lock b/server/Cargo.lock index 126f01987..d94fd79bb 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -159,9 +159,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.85" +version = "0.1.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056" +checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d" dependencies = [ "proc-macro2", "quote", @@ -393,6 +393,15 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +[[package]] +name = "bitmaps" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "031043d04099746d8db04daf1fa424b2bc8bd69d92b25962dcde24da39ab64a2" +dependencies = [ + "typenum", +] + [[package]] name = "blob_store" version = "0.1.0" @@ -1510,6 +1519,21 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "im" +version = "15.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0acd33ff0285af998aaf9b57342af478078f53492322fafc47450e09397e0e9" +dependencies = [ + "bitmaps", + "rand_core", + "rand_xoshiro", + "serde", + "sized-chunks", + "typenum", + "version_check", +] + [[package]] name = "indexify-server" version = "0.2.26" @@ -1528,6 +1552,7 @@ dependencies = [ "futures", "hex", "hyper", + "im", "indexify_ui", "indexify_utils", "metrics", @@ -2443,6 +2468,15 @@ dependencies = [ "rand_core", ] +[[package]] +name = "rand_xoshiro" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" +dependencies = [ + "rand_core", +] + [[package]] name = "rayon" version = "1.10.0" @@ -2925,6 +2959,16 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" +[[package]] +name = "sized-chunks" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16d69225bde7a69b235da73377861095455d298f2b970996eec25ddbb42b3d1e" +dependencies = [ + "bitmaps", + "typenum", +] + [[package]] name = "slab" version = "0.4.9" @@ -3034,6 +3078,7 @@ dependencies = [ "bytes", "data_model", "futures", + "im", "indexify_utils", "metrics", "object_store", diff --git a/server/Cargo.toml b/server/Cargo.toml index 3b37a7c30..9a99847f6 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -49,7 +49,7 @@ object_store = {git = "https://github.com/tensorlakeai/arrow-rs", branch="remove futures = "0.3.31" bytes = "1.9.0" pin-project-lite = "0.2.16" -async-trait = "0.1.85" 
+async-trait = "0.1.86" tokio-stream = "0.1.17" slatedb = { git = "https://github.com/diptanu/slatedb" } rust-embed = { version = "8.5.0", features = ["mime-guess"] } @@ -64,6 +64,7 @@ tower-http = { version = "0.6.2", default-features = false, features = [ "cors", "trace", ] } +im = {version = "15.1.0", features = ["serde"]} pin-project = "1.1.8" ciborium = "0.2.2" uuid = { version = "1.12.1", features = ["v4"] } @@ -128,6 +129,7 @@ axum-tracing-opentelemetry = { version = "0.25.0", features = [ "tracing_level_info", ] } tower-otel-http-metrics = { workspace = true } +im = { workspace = true } [dev-dependencies] tempfile = { workspace = true } diff --git a/server/processor/src/graph_processor.rs b/server/processor/src/graph_processor.rs index 4899aebc9..59ce99b01 100644 --- a/server/processor/src/graph_processor.rs +++ b/server/processor/src/graph_processor.rs @@ -283,7 +283,7 @@ fn task_creation_result_to_sm_update( state_change: &StateChange, ) -> StateMachineUpdateRequest { StateMachineUpdateRequest { - payload: RequestPayload::NamespaceProcessorUpdate(NamespaceProcessorUpdateRequest { + payload: RequestPayload::TaskCreatorUpdate(NamespaceProcessorUpdateRequest { namespace: ns.to_string(), compute_graph: compute_graph.to_string(), invocation_id: invocation_id.to_string(), diff --git a/server/state_store/Cargo.toml b/server/state_store/Cargo.toml index da625e115..1053c749c 100644 --- a/server/state_store/Cargo.toml +++ b/server/state_store/Cargo.toml @@ -19,7 +19,8 @@ tracing = { workspace = true } tokio = { workspace = true } tokio-stream = { workspace = true } futures.workspace = true -async-stream = "0.3.5" +async-stream = {workspace = true } +im = { workspace = true } tempfile = { workspace = true } object_store.workspace = true blob_store = { version = "0.1.0", path = "../blob_store" } diff --git a/server/state_store/src/Dataset Abstraction to create.md b/server/state_store/src/Dataset Abstraction to create.md deleted file mode 100644 index f5af8ba65..000000000 --- a/server/state_store/src/Dataset Abstraction to create.md +++ /dev/null @@ -1,58 +0,0 @@ -## Dataset -Abstraction to create indexes from a collection of documents by specifying - -1. Parsing Strategy -2. Embedding Models -3. Database URL - -The dataset can be updated continuously by the user, when there are new documents and the indexes get updated by Tensorlake. - -1. Create a Dataset in your project -2. Add Documents to the dataset with some parsing scheme -3. The dataset gets updated -4. The dataset is written to your database - - -Application Flow is - -1. Read from the database during retrieval. - -## What are Datasets -1. A dataset is a collection of related documents. -2. Datasets can be used to automatically parse, chunk, and embed documents for retrieval by LLM applications. -3. Datasets can be continously updated asynchronously by data ingestion pipelines. -4. Datasets can be exported into databases automatically or using APIs. - - -Two different modalities - -1. Out of the box Datasets abstraction for retrieval from documents. -2. Custom workflows for building end to end pipelines for document understanding use cases with complete control over the transformations and extraction process. - -## Approach 2 - Index -1. Crate an Index, specify the parsing mechanism and the embedding models and indexing schema. -2. Add new documents to the index. - - -Sarma - -1. Each dataset is a collection of documents. -2. Each document has a document id. -3. Each document also has a pointer to workflow outputs. 
This is represented as K/V, document id -> output in workflow. -4. The user can query and check if the dataset has been extracted, transformed and ingested into an index. - -The utility of the dataset abstraction is to check whether a document has made it all the way to an index. - -Datasets needs to provide a `status` API, and provide information about - -1. Total number of documents. -2. Current state/marker of ingestion. - - -## Define the problem of creating indexes for retrieval. -## The general shape of pipelines for building such indexes. -## An out of the box solution as an API -## Levers of the API -### Define the transformations that can be attached to a Dataset -### Talk about the API to ingest into Datasets -### Current Status of the Dataset - - - -## Getting even more flexibility - Write serverless workflows yourself. -1. You can write abstractions such as Datasets yourself on our platform. We have even opensourced the workflow code we use for the Datasets abstrction. You can tweak it to implement new algorithms and deploy this as a service in your account as an API. \ No newline at end of file
diff --git a/server/state_store/src/in_memory_state.rs b/server/state_store/src/in_memory_state.rs new file mode 100644 index 000000000..4e7133bac --- /dev/null +++ b/server/state_store/src/in_memory_state.rs
@@ -0,0 +1,136 @@ +use anyhow::Result; +use data_model::{ComputeGraph, ExecutorMetadata, Task}; +use im::HashSet; + +use crate::{ + requests::{RequestPayload, StateMachineUpdateRequest}, + scanner::StateReader, + state_machine::IndexifyObjectsColumns, +}; + +pub struct InMemoryState { + namespaces: im::HashMap<String, [u8; 0]>, + + // Namespace|CG Name -> ComputeGraph + compute_graphs: im::HashMap<String, ComputeGraph>, + + // ExecutorId -> ExecutorMetadata + executors: im::HashMap<String, ExecutorMetadata>, + + // Executor Id -> List of Task IDs + allocated_tasks: im::HashMap<String, HashSet<String>>, + // Task ID -> Task + tasks: im::HashMap<String, Task>, + + // Task Keys + unallocated_tasks: im::HashMap<String, [u8; 0]>, +} + +impl InMemoryState { + pub fn new(reader: &StateReader) -> Result<Self> { + let mut namespaces = im::HashMap::new(); + let all_ns = reader.get_all_namespaces()?; + for ns in &all_ns { + namespaces.insert(ns.name.clone(), [0; 0]); + } + let mut compute_graphs = im::HashMap::new(); + for ns in &all_ns { + let cgs = reader.list_compute_graphs(&ns.name, None, None)?.0; + for cg in cgs { + compute_graphs.insert(format!("{}|{}", ns.name, cg.name), cg); + } + } + let all_executors = reader.get_all_executors()?; + let mut executors = im::HashMap::new(); + let mut allocated_tasks = im::HashMap::new(); + for executor in &all_executors { + executors.insert(executor.id.get().to_string(), executor.clone()); + let executor_allocated_tasks = reader.get_allocated_tasks(&executor.id)?; + allocated_tasks.insert( + executor.id.get().to_string(), + executor_allocated_tasks + .into_iter() + .map(|t| t.id.to_string()) + .collect(), + ); + } + + let all_tasks: Vec<(String, Task)> = + reader.get_all_rows_from_cf(IndexifyObjectsColumns::Tasks)?; + let mut tasks = im::HashMap::new(); + for (id, task) in all_tasks { + tasks.insert(id, task); + } + let all_unallocated_tasks: Vec<(String, [u8; 0])> = + reader.get_all_rows_from_cf(IndexifyObjectsColumns::UnallocatedTasks)?; + let mut unallocated_tasks = im::HashMap::new(); + for (id, task) in all_unallocated_tasks { + unallocated_tasks.insert(id, task); + } + Ok(Self { + namespaces, + compute_graphs, + executors, + allocated_tasks, + tasks, + unallocated_tasks, + }) + } + + pub fn update_state( + &mut self, + state_machine_update_request: &StateMachineUpdateRequest, + ) -> Result<()> { + match &state_machine_update_request.payload { + RequestPayload::CreateNameSpace(req) => { + self.namespaces.insert(req.name.clone(), [0; 0]); + } + RequestPayload::CreateOrUpdateComputeGraph(req) => { + self.compute_graphs.insert( + format!("{}|{}", req.namespace, req.compute_graph.name), + req.compute_graph.clone(), + ); + } + RequestPayload::DeleteComputeGraphRequest(req) => { + self.compute_graphs + .remove(&format!("{}|{}", req.namespace, req.name)); + } + RequestPayload::FinalizeTask(req) => { + self.allocated_tasks + .entry(req.executor_id.get().to_string()) + .or_default() + .remove(&req.task_id.to_string()); + } + RequestPayload::TaskAllocationProcessorUpdate(req) => { + for allocation in &req.allocations { + self.allocated_tasks + .entry(allocation.executor.get().to_string()) + .or_default() + .insert(allocation.task.id.to_string()); + } + for unplaced_task in &req.unplaced_task_keys { + self.unallocated_tasks + .insert(unplaced_task.to_string(), [0; 0]); + } + } + RequestPayload::TaskCreatorUpdate(req) => { + for task in &req.task_requests { + self.tasks.insert(task.id.to_string(), task.clone()); + } + } + _ => {} + } + Ok(()) + } + + pub fn get_in_memory_state(&self) -> Self { + InMemoryState { + namespaces: self.namespaces.clone(), + compute_graphs: self.compute_graphs.clone(), + executors: self.executors.clone(), + allocated_tasks: self.allocated_tasks.clone(), + tasks: self.tasks.clone(), + unallocated_tasks: self.unallocated_tasks.clone(), + } + } +}
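A note on the data structures this file reaches for: the reason `get_in_memory_state` can hand out a field-by-field clone is that cloning an `im` persistent map is an O(1) structural share rather than a deep copy, so readers get a stable snapshot while the store keeps mutating. A minimal, self-contained sketch of that property (illustrative only, not code from this patch):

    use im::HashMap;

    fn main() {
        let mut state: HashMap<String, u64> = HashMap::new();
        state.insert("task-1".to_string(), 1);

        // clone() is cheap: both maps share structure until one is written to.
        let snapshot = state.clone();
        state.insert("task-2".to_string(), 2);

        // Later writes to the original do not disturb the snapshot.
        assert_eq!(snapshot.len(), 1);
        assert_eq!(state.len(), 2);
    }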
diff --git a/server/state_store/src/lib.rs b/server/state_store/src/lib.rs index b0d90e018..c0ed54c50 100644 --- a/server/state_store/src/lib.rs +++ b/server/state_store/src/lib.rs
@@ -23,6 +23,7 @@ use strum::IntoEnumIterator; use tokio::sync::{broadcast, RwLock}; use tracing::{debug, info, span}; +pub mod in_memory_state; pub mod invocation_events; pub mod kv; pub mod requests;
@@ -92,6 +93,7 @@ pub struct IndexifyState { pub change_events_tx: tokio::sync::watch::Sender<()>, pub change_events_rx: tokio::sync::watch::Receiver<()>, pub metrics: Arc<StateStoreMetrics>, + pub in_memory_state: Arc<RwLock<in_memory_state::InMemoryState>>, // state_metrics: Arc<StateMetrics>, }
@@ -118,8 +120,11 @@ impl IndexifyState { StateMetrics::new(state_store_metrics.clone()); // let state_metrics = Arc::new(StateMetrics::new(state_store_metrics.clone())); let (change_events_tx, change_events_rx) = tokio::sync::watch::channel(()); + let db = Arc::new(db); + let reader = scanner::StateReader::new(db.clone(), state_store_metrics.clone()); + let in_memory_state = Arc::new(RwLock::new(in_memory_state::InMemoryState::new(&reader)?)); let s = Arc::new(Self { - db: Arc::new(db), + db, last_state_change_id: Arc::new(AtomicU64::new(sm_meta.last_change_idx)), executor_states: RwLock::new(HashMap::new()), task_event_tx,
@@ -130,6 +135,7 @@ impl IndexifyState { metrics: state_store_metrics, change_events_tx, change_events_rx, + in_memory_state, // state_metrics, });
@@ -301,7 +307,7 @@ impl IndexifyState { self.gc_tx.send(()).unwrap(); vec![] } - RequestPayload::NamespaceProcessorUpdate(request) => { + RequestPayload::TaskCreatorUpdate(request) => { let new_state_changes = state_changes::change_events_for_namespace_processor_update( &self.last_state_change_id, &request, )?;
@@ -420,6 +426,7 @@ impl IndexifyState { }, )?; txn.commit()?; + self.in_memory_state.write().await.update_state(&request)?; for executor_id in allocated_tasks_by_executor { self.executor_states .write()
@@ -457,7 +464,7 @@ impl IndexifyState { InvocationStateChangeEvent::from_task_finished(task_finished_event.clone()); let _ = self.task_event_tx.send(ev); } - RequestPayload::NamespaceProcessorUpdate(sched_update) => { + RequestPayload::TaskCreatorUpdate(sched_update) => { for task in &sched_update.task_requests { let _ = self .task_event_tx
diff --git a/server/state_store/src/requests.rs b/server/state_store/src/requests.rs index 996b136ff..18f83b305 100644 --- a/server/state_store/src/requests.rs +++ b/server/state_store/src/requests.rs
@@ -29,7 +29,7 @@ pub enum RequestPayload { CreateOrUpdateComputeGraph(CreateOrUpdateComputeGraphRequest), TombstoneComputeGraph(DeleteComputeGraphRequest), TombstoneInvocation(DeleteInvocationRequest), - NamespaceProcessorUpdate(NamespaceProcessorUpdateRequest), + TaskCreatorUpdate(NamespaceProcessorUpdateRequest), TaskAllocationProcessorUpdate(TaskAllocationUpdateRequest), RegisterExecutor(RegisterExecutorRequest), DeregisterExecutor(DeregisterExecutorRequest),
diff --git a/server/state_store/src/scanner.rs b/server/state_store/src/scanner.rs index 75a8081ab..2459d1b33 100644 --- a/server/state_store/src/scanner.rs +++ b/server/state_store/src/scanner.rs
@@ -719,6 +719,22 @@ impl StateReader { Ok(res.items) } + pub fn get_allocated_tasks(&self, executor: &ExecutorId) -> Result<Vec<Task>> { + let kvs = &[KeyValue::new("op", "get_allocated_tasks")]; + let _timer = Timer::start_with_labels(&self.metrics.state_read, kvs); + let prefix = format!("{}|", executor); + let res = self.filter_join_cf( + IndexifyObjectsColumns::TaskAllocations, + IndexifyObjectsColumns::Tasks, + |_| true, + prefix.as_bytes(), + Task::key_from_allocation_key, + None, + None, + )?; + Ok(res.items) + } + pub fn get_all_executors(&self) -> Result<Vec<ExecutorMetadata>> { let kvs = &[KeyValue::new("op", "get_all_executors")]; let _timer = Timer::start_with_labels(&self.metrics.state_read, kvs);
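The `get_allocated_tasks` reader added above leans on `filter_join_cf`, which reads as a prefix scan over the TaskAllocations column family joined against point lookups in Tasks. A rough, self-contained illustration of that access pattern, with `BTreeMap`s standing in for the two column families (the exact key shapes here are assumptions made for the sketch):

    use std::collections::BTreeMap;

    // allocations: "executor_id|task_key" -> (), tasks: "task_key" -> payload
    fn allocated_tasks(
        allocations: &BTreeMap<String, ()>,
        tasks: &BTreeMap<String, String>,
        executor_id: &str,
    ) -> Vec<String> {
        let prefix = format!("{}|", executor_id);
        allocations
            .range(prefix.clone()..)
            .take_while(|(k, _)| k.starts_with(&prefix))
            // Strip the executor prefix to recover the task key, then join.
            .filter_map(|(k, _)| tasks.get(&k[prefix.len()..]).cloned())
            .collect()
    }

    fn main() {
        let mut allocations = BTreeMap::new();
        let mut tasks = BTreeMap::new();
        tasks.insert("ns|graph|inv|fn|task-9".to_string(), "task-9".to_string());
        allocations.insert("executor-1|ns|graph|inv|fn|task-9".to_string(), ());
        assert_eq!(allocated_tasks(&allocations, &tasks, "executor-1"), vec!["task-9"]);
    }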
From 893f31c9f2f5326c8530ac6aa5ea60c1c324c777 Mon Sep 17 00:00:00 2001 From: Diptanu Gon Choudhury Date: Sun, 2 Feb 2025 13:56:42 -0800 Subject: [PATCH 2/2] making task creator use in memory state --- server/processor/src/graph_processor.rs | 92 +++++++--- server/processor/src/task_creator.rs | 199 ++++++++++------------ server/state_store/src/in_memory_state.rs | 109 ++++++++++-- server/state_store/src/lib.rs | 64 +++---- server/state_store/src/requests.rs | 3 + server/state_store/src/state_machine.rs | 153 ++++++----------- 6 files changed, 350 insertions(+), 270 deletions(-)
diff --git a/server/processor/src/graph_processor.rs b/server/processor/src/graph_processor.rs index 4899aebc9..c17a0825a 100644 --- a/server/processor/src/graph_processor.rs +++ b/server/processor/src/graph_processor.rs
@@ -1,7 +1,13 @@ use std::{sync::Arc, vec}; use anyhow::Result; -use data_model::{ChangeType, StateChange}; +use data_model::{ + ChangeType, + GraphInvocationCtx, + StateChange, + TaskOutcome, + TaskOutputsIngestedEvent, +}; use state_store::{ requests::{ DeleteComputeGraphRequest,
@@ -13,8 +19,7 @@ use state_store::{ RequestPayload, StateMachineUpdateRequest, TaskAllocationUpdateRequest, - }, - IndexifyState, + }, IndexifyState }; use tokio::sync::Notify; use tracing::{error, info};
@@ -173,7 +178,10 @@ impl GraphProcessor { info!("invoking compute graph: {:?}", event); let task_creation_result = self .task_creator - .handle_invoke_compute_graph(event.clone()) + .handle_invoke_compute_graph( + event.clone(), + &self.indexify_state.in_memory_state().await, + ) .await?; Ok(task_creation_result_to_sm_update( &event.namespace,
@@ -183,24 +191,33 @@ impl GraphProcessor { &state_change, )) } - ChangeType::TaskOutputsIngested(event) => Ok(StateMachineUpdateRequest { - payload: RequestPayload::FinalizeTask(FinalizeTaskRequest { - namespace: event.namespace.clone(), - compute_graph: event.compute_graph.clone(), - compute_fn: event.compute_fn.clone(), - invocation_id: event.invocation_id.clone(), - task_id: event.task_id.clone(), - task_outcome: event.outcome.clone(), - executor_id: event.executor_id.clone(), - diagnostics: event.diagnostic.clone(), - }), - processed_state_changes: vec![state_change.clone()], - }), + ChangeType::TaskOutputsIngested(event) => { + let graph_ctx = + update_graph_ctx_finalize_task(event, self.indexify_state.clone()).await; + if graph_ctx.is_none() { + return Ok(StateMachineUpdateRequest { + payload: RequestPayload::Noop, + processed_state_changes: vec![state_change.clone()], + }); + } + Ok(StateMachineUpdateRequest { + payload: RequestPayload::FinalizeTask(FinalizeTaskRequest { + namespace: event.namespace.clone(), + compute_graph: event.compute_graph.clone(), + compute_fn: event.compute_fn.clone(), + invocation_id: event.invocation_id.clone(), + task_id: event.task_id.clone(), + task_outcome: event.outcome.clone(), + executor_id: event.executor_id.clone(), + diagnostics: event.diagnostic.clone(), + invocation_ctx: graph_ctx, + }), + processed_state_changes: vec![state_change.clone()], + }) + } ChangeType::TaskFinalized(event) => { - let task_creation_result = self - .task_creator - .handle_task_finished_inner(self.indexify_state.clone(), event) - .await?; + let task_creation_result = + self.task_creator.handle_task_finished_inner(event).await?; Ok(task_creation_result_to_sm_update( &event.namespace, &event.compute_graph,
@@ -292,6 +309,7 @@ fn task_creation_result_to_sm_update( new_reduction_tasks: task_creation_result.new_reduction_tasks, processed_reduction_tasks: task_creation_result.processed_reduction_tasks, }, + invocation_ctx: task_creation_result.invocation_ctx, }), processed_state_changes: vec![state_change.clone()], }
@@ -310,3 +328,35 @@ fn task_placement_result_to_sm_update( processed_state_changes: vec![state_change.clone()], } } + +async fn update_graph_ctx_finalize_task( + event: &TaskOutputsIngestedEvent, + indexify_state: Arc<IndexifyState>, +) -> Option<GraphInvocationCtx> { + let in_memory_state = indexify_state.in_memory_state().await; + let invocation_ctx_map = in_memory_state + .invocation_ctx + .clone(); + + let may_be_invocation_ctx = invocation_ctx_map + .get(&format!( + "{}|{}|{}", + event.namespace, event.compute_graph, event.invocation_id + )); + if may_be_invocation_ctx.is_none() { + error!("namespace: {}, invocation: {}, cg: {}, fn: {}, task: {} no invocation ctx found for task outputs ingested event", + event.namespace, event.invocation_id, event.compute_graph, event.compute_fn, event.task_id); + return None; + } + let mut invocation_ctx = may_be_invocation_ctx.unwrap().clone(); + invocation_ctx.outstanding_tasks -= 1; + invocation_ctx + .fn_task_analytics + .entry(event.compute_fn.clone()) + .and_modify(|e| match event.outcome { + TaskOutcome::Success => e.success(), + TaskOutcome::Failure => e.fail(), + _ => {} + }); + Some(invocation_ctx) +}
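The bookkeeping `update_graph_ctx_finalize_task` does against the snapshot, together with the task creator's fan-out accounting in the next file, maintains a single invariant: `outstanding_tasks` counts tasks created but not yet finalized, and the invocation completes when the counter reaches zero. A schematic of that counter, using an assumed minimal struct rather than the real `GraphInvocationCtx`:

    struct InvocationCtx {
        outstanding_tasks: u64,
        completed: bool,
    }

    impl InvocationCtx {
        fn on_tasks_created(&mut self, n: u64) {
            // Children fanned out by a finishing task enter the running set.
            self.outstanding_tasks += n;
        }
        fn on_task_finalized(&mut self) {
            // A finished task leaves the running set; zero means done.
            self.outstanding_tasks -= 1;
            if self.outstanding_tasks == 0 {
                self.completed = true;
            }
        }
    }

    fn main() {
        let mut ctx = InvocationCtx { outstanding_tasks: 1, completed: false };
        ctx.on_tasks_created(2); // the start task fans out two children
        ctx.on_task_finalized(); // start task done, children still running
        ctx.on_task_finalized();
        ctx.on_task_finalized();
        assert!(ctx.completed);
    }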
diff --git a/server/processor/src/task_creator.rs b/server/processor/src/task_creator.rs index 21d612f47..b69271b43 100644 --- a/server/processor/src/task_creator.rs +++ b/server/processor/src/task_creator.rs
@@ -3,15 +3,18 @@ use std::{sync::Arc, vec}; use anyhow::{anyhow, Result}; use data_model::{ ComputeGraphVersion, + GraphInvocationCtx, + GraphInvocationCtxBuilder, InvokeComputeGraphEvent, Node, OutputPayload, ReduceTask, Task, + TaskAnalytics, TaskFinalizedEvent, TaskOutcome, }; -use state_store::IndexifyState; +use state_store::{in_memory_state, IndexifyState}; use tracing::{error, info, trace}; #[derive(Debug)]
@@ -22,6 +25,7 @@ pub struct TaskCreationResult { pub new_reduction_tasks: Vec<ReduceTask>, pub processed_reduction_tasks: Vec<String>, pub invocation_id: String, + pub invocation_ctx: Option<GraphInvocationCtx>, } impl TaskCreationResult {
@@ -33,6 +37,7 @@ impl TaskCreationResult { tasks: vec![], new_reduction_tasks: vec![], processed_reduction_tasks: vec![], + invocation_ctx: None, } } }
@@ -50,16 +55,12 @@ impl TaskCreator { impl TaskCreator { pub async fn handle_task_finished_inner( &self, - indexify_state: Arc<IndexifyState>, task_finished_event: &TaskFinalizedEvent, ) -> Result<TaskCreationResult> { - let task = indexify_state - .reader() - .get_task_from_finished_event(task_finished_event) - .map_err(|e| { - error!("error getting task from finished event: {:?}", e); - e - })?; + let in_memory_state = self.indexify_state.in_memory_state().await; + let task = in_memory_state .tasks .get(&Task::key_from(&task_finished_event.namespace, &task_finished_event.compute_graph, &task_finished_event.invocation_id, &task_finished_event.compute_fn, &task_finished_event.task_id.to_string())); if task.is_none() { error!( "task not found for task finished event: {}", task_finished_event.task_id );
@@ -73,17 +74,11 @@ impl TaskCreator { } let task = task.ok_or(anyhow!("task not found: {}", task_finished_event.task_id))?; - let compute_graph_version = indexify_state - .reader() - .get_compute_graph_version( - &task.namespace, - &task.compute_graph_name, - &task.graph_version, - ) - .map_err(|e| { - error!("error getting compute graph version: {:?}", e); - e - })?; + let compute_graph_version = in_memory_state.compute_graph_versions.get(&format!( + "{}|{}|{}", + task.namespace, task.compute_graph_name, task.graph_version.0, + )); + if compute_graph_version.is_none() { error!( "compute graph version not found: {:?} {:?}", task.namespace, task.compute_graph_name );
@@ -100,57 +95,22 @@ impl TaskCreator { task.namespace, task.compute_graph_name ))?; - self.handle_task_finished(task, compute_graph_version).await + self.handle_task_finished(task, compute_graph_version, &in_memory_state) + .await } pub async fn handle_invoke_compute_graph( &self, event: InvokeComputeGraphEvent, + in_memory_state: &in_memory_state::InMemoryState, ) -> Result<TaskCreationResult> { - let invocation_ctx = self - .indexify_state - .reader() - .invocation_ctx(&event.namespace, &event.compute_graph, &event.invocation_id) - .map_err(|e| { - anyhow!( - "error getting invocation context for invocation {}: {:?}", - event.invocation_id, - e - ) - })?; - if invocation_ctx.is_none() { - return Ok(TaskCreationResult::no_tasks( - &event.namespace, - &event.compute_graph, - &event.invocation_id, - )); - } - let invocation_ctx = invocation_ctx.ok_or(anyhow!( - "invocation context not found for invocation_id {}", - event.invocation_id - ))?; - - let compute_graph_version = self - .indexify_state - .reader() - .get_compute_graph_version( - &event.namespace, - &event.compute_graph, - &invocation_ctx.graph_version, - ) - .map_err(|e| { - anyhow!( - "error getting compute graph version: {:?} {:?} {:?} {:?}", - event.namespace, - event.compute_graph, - invocation_ctx.graph_version, - e - ) - })?; - if compute_graph_version.is_none() { - info!( - "compute graph version not found: {:?} {:?} {:?}", - event.namespace, event.compute_graph, invocation_ctx.graph_version, + let compute_graph = in_memory_state + .compute_graphs + .get(&format!("{}|{}", event.namespace, event.compute_graph)); + if compute_graph.is_none() { + error!( + "compute graph not found: {:?} {:?}", + event.namespace, event.compute_graph ); return Ok(TaskCreationResult::no_tasks( &event.namespace, &event.compute_graph, &event.invocation_id, )); } - let compute_graph_version = compute_graph_version.ok_or(anyhow!( - "compute graph version not found: {:?} {:?} {:?}", - event.namespace, - event.compute_graph, - invocation_ctx.graph_version, - ))?; - + let compute_graph = compute_graph.unwrap(); + let compute_graph_version = compute_graph.into_version(); // Create a task for the compute graph let task = compute_graph_version.start_fn.create_task( &event.namespace, &event.compute_graph, &event.invocation_id, None, &compute_graph_version.version, )?; - trace!( + info!( task_key = task.key(), "Creating a standard task to start compute graph" ); + let mut graph_ctx = GraphInvocationCtxBuilder::default() + .graph_version(compute_graph.version.clone()) + .invocation_id(event.invocation_id.clone()) + .namespace(event.namespace.clone()) + .outstanding_tasks(1) + .compute_graph_name(event.compute_graph.clone()) + .build(compute_graph.clone())?; + graph_ctx.fn_task_analytics.insert( task.compute_fn_name.clone(), TaskAnalytics { pending_tasks: 1, successful_tasks: 0, failed_tasks: 0, }, ); Ok(TaskCreationResult { namespace: event.namespace.clone(), compute_graph: event.compute_graph.clone(),
@@ -185,43 +156,34 @@ impl TaskCreator { tasks: vec![task], new_reduction_tasks: vec![], processed_reduction_tasks: vec![], + invocation_ctx: Some(graph_ctx), }) } pub async fn handle_task_finished( &self, - task: Task, - compute_graph_version: ComputeGraphVersion, + task: &Task, + compute_graph_version: &ComputeGraphVersion, + in_memory_state: &in_memory_state::InMemoryState, ) -> Result<TaskCreationResult> { - let invocation_ctx = self - .indexify_state - .reader() - .invocation_ctx( - &task.namespace, - &task.compute_graph_name, - &task.invocation_id, - ) - .map_err(|e| { - anyhow!( - "error getting invocation context for invocation {}: {:?}", - task.invocation_id, - e - ) - })?; + let invocation_ctx = in_memory_state.invocation_ctx.get(&format!( + "{}|{}|{}", + task.namespace, task.compute_graph_name, task.invocation_id + )); if invocation_ctx.is_none() { - trace!("no invocation ctx, stopping scheduling of child tasks"); + error!("no invocation ctx, stopping scheduling of child tasks"); return Ok(TaskCreationResult::no_tasks( &task.namespace, &task.compute_graph_name, &task.invocation_id, )); } - let invocation_ctx = invocation_ctx.ok_or(anyhow!( - "invocation context not found for invocation_id {}", - task.invocation_id - ))?; - - trace!("invocation context: {:?}", invocation_ctx); + let mut invocation_ctx = invocation_ctx + .ok_or(anyhow!( + "invocation context not found for invocation_id {}", + task.invocation_id + ))? + .clone(); if task.outcome == TaskOutcome::Failure { trace!("task failed, stopping scheduling of child tasks");
@@ -235,6 +197,7 @@ impl TaskCreator { .indexify_state .reader() .get_task_outputs(&task.namespace, &task.id.to_string())?; + let mut new_tasks = vec![]; // Check if the task has a router output and create new tasks for the router
@@ -262,12 +225,18 @@ impl TaskCreator { None, &invocation_ctx.graph_version, )?; + invocation_ctx + .fn_task_analytics + .entry(compute_fn.name().to_string()) + .or_default() + .pending(); new_tasks.push(new_task); } trace!( task_keys = ?new_tasks.iter().map(|t| t.key()).collect::<Vec<_>>(), "Creating a router edge task", ); + invocation_ctx.outstanding_tasks += new_tasks.len() as u64; return Ok(TaskCreationResult { namespace: task.namespace.clone(), compute_graph: task.compute_graph_name.clone(),
@@ -275,6 +244,7 @@ impl TaskCreator { tasks: new_tasks, new_reduction_tasks: vec![], processed_reduction_tasks: vec![], + invocation_ctx: Some(invocation_ctx.clone()), }); } }
@@ -304,16 +274,12 @@ impl TaskCreator { )); } } - let reduction_task = self - .indexify_state - .reader() - .next_reduction_task( - &task.namespace, - &task.compute_graph_name, - &task.invocation_id, - &compute_fn.name, - ) - .map_err(|e| anyhow!("error getting next reduction task: {:?}", e))?; + let reduction_task = in_memory_state.next_queued_task( + &task.namespace, + &task.compute_graph_name, + &task.invocation_id, + &task.compute_fn_name, + ); if let Some(reduction_task) = reduction_task { // Create a new task for the queued reduction_task let output = outputs.first().unwrap();
@@ -330,6 +296,12 @@ impl TaskCreator { compute_fn_name = new_task.compute_fn_name, "Creating a reduction task from queue", ); + invocation_ctx.outstanding_tasks += 1; + invocation_ctx + .fn_task_analytics + .entry(compute_node.name().to_string()) + .or_default() + .pending(); return Ok(TaskCreationResult { namespace: task.namespace.clone(), compute_graph: task.compute_graph_name.clone(),
@@ -337,6 +309,7 @@ impl TaskCreator { tasks: vec![new_task], new_reduction_tasks: vec![], processed_reduction_tasks: vec![reduction_task.key()], + invocation_ctx: Some(invocation_ctx.clone()), }); } trace!(
@@ -551,6 +524,17 @@ impl TaskCreator { } trace!("tasks: {:?}", new_tasks.len()); + invocation_ctx.outstanding_tasks += new_tasks.len() as u64; + for task in &new_tasks { + invocation_ctx + .fn_task_analytics + .entry(task.compute_fn_name.clone()) + .or_default() + .pending(); + } + if new_tasks.is_empty() && invocation_ctx.outstanding_tasks == 0 { + invocation_ctx.completed = true; + } Ok(TaskCreationResult { namespace: task.namespace.clone(), compute_graph: task.compute_graph_name.clone(),
@@ -558,6 +542,7 @@ impl TaskCreator { tasks: new_tasks, new_reduction_tasks, processed_reduction_tasks: vec![], + invocation_ctx: Some(invocation_ctx.clone()), }) } }
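Several hunks above repeat one idiom: `fn_task_analytics.entry(fn_name).or_default().pending()`, bumping a per-function counter each time a task fans out. A freestanding sketch of the pattern; the real `TaskAnalytics` lives in data_model with the `pending_tasks`/`successful_tasks`/`failed_tasks` fields constructed earlier in this patch, and the exact transition inside `success()` is an assumption here:

    use std::collections::HashMap;

    #[derive(Default, Debug)]
    struct TaskAnalytics {
        pending_tasks: u64,
        successful_tasks: u64,
        failed_tasks: u64,
    }

    impl TaskAnalytics {
        fn pending(&mut self) {
            self.pending_tasks += 1;
        }
        fn success(&mut self) {
            // Assumed: a success releases a pending slot.
            self.pending_tasks = self.pending_tasks.saturating_sub(1);
            self.successful_tasks += 1;
        }
    }

    fn main() {
        let mut per_fn: HashMap<String, TaskAnalytics> = HashMap::new();
        per_fn.entry("map_fn".to_string()).or_default().pending();
        per_fn.entry("map_fn".to_string()).or_default().success();
        println!("{:?}", per_fn["map_fn"]);
    }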
diff --git a/server/state_store/src/in_memory_state.rs b/server/state_store/src/in_memory_state.rs index 4e7133bac..0e180d30c 100644 --- a/server/state_store/src/in_memory_state.rs +++ b/server/state_store/src/in_memory_state.rs
@@ -1,5 +1,12 @@ use anyhow::Result; -use data_model::{ComputeGraph, ExecutorMetadata, Task}; +use data_model::{ + ComputeGraph, + ComputeGraphVersion, + ExecutorMetadata, + GraphInvocationCtx, + ReduceTask, + Task, +}; use im::HashSet; use crate::{
@@ -9,21 +16,31 @@ use crate::{ }; pub struct InMemoryState { - namespaces: im::HashMap<String, [u8; 0]>, + pub namespaces: im::HashMap<String, [u8; 0]>, // Namespace|CG Name -> ComputeGraph - compute_graphs: im::HashMap<String, ComputeGraph>, + pub compute_graphs: im::HashMap<String, ComputeGraph>, + + // Namespace|CG Name|Version -> ComputeGraphVersion + pub compute_graph_versions: im::HashMap<String, ComputeGraphVersion>, // ExecutorId -> ExecutorMetadata - executors: im::HashMap<String, ExecutorMetadata>, + pub executors: im::HashMap<String, ExecutorMetadata>, // Executor Id -> List of Task IDs - allocated_tasks: im::HashMap<String, HashSet<String>>, - // Task ID -> Task - tasks: im::HashMap<String, Task>, + pub allocated_tasks: im::HashMap<String, HashSet<String>>, + + // Task Key -> Task + pub tasks: im::OrdMap<String, Task>, + + // Queued Reduction Tasks + pub queued_reduction_tasks: im::OrdMap<String, ReduceTask>, // Task Keys - unallocated_tasks: im::HashMap<String, [u8; 0]>, + pub unallocated_tasks: im::HashMap<String, [u8; 0]>, + + // Invocation Ctx + pub invocation_ctx: im::OrdMap<String, GraphInvocationCtx>, } impl InMemoryState {
@@ -40,6 +57,12 @@ impl InMemoryState { compute_graphs.insert(format!("{}|{}", ns.name, cg.name), cg); } } + let mut compute_graph_versions = im::HashMap::new(); + let all_cg_versions: Vec<(String, ComputeGraphVersion)> = + reader.get_all_rows_from_cf(IndexifyObjectsColumns::ComputeGraphVersions)?; + for (id, cg) in all_cg_versions { + compute_graph_versions.insert(id, cg); + } let all_executors = reader.get_all_executors()?; let mut executors = im::HashMap::new(); let mut allocated_tasks = im::HashMap::new();
@@ -57,9 +80,9 @@ impl InMemoryState { let all_tasks: Vec<(String, Task)> = reader.get_all_rows_from_cf(IndexifyObjectsColumns::Tasks)?; - let mut tasks = im::HashMap::new(); - for (id, task) in all_tasks { - tasks.insert(id, task); + let mut tasks = im::OrdMap::new(); + for (_id, task) in all_tasks { + tasks.insert(task.key(), task); } let all_unallocated_tasks: Vec<(String, [u8; 0])> = reader.get_all_rows_from_cf(IndexifyObjectsColumns::UnallocatedTasks)?; let mut unallocated_tasks = im::HashMap::new(); for (id, task) in all_unallocated_tasks { unallocated_tasks.insert(id, task); } + let mut invocation_ctx = im::OrdMap::new(); + let all_graph_invocation_ctx: Vec<(String, GraphInvocationCtx)> = + reader.get_all_rows_from_cf(IndexifyObjectsColumns::GraphInvocationCtx)?; + for (id, ctx) in all_graph_invocation_ctx { + invocation_ctx.insert(id, ctx); + } + let mut queued_reduction_tasks = im::OrdMap::new(); + let all_reduction_tasks: Vec<(String, ReduceTask)> = + reader.get_all_rows_from_cf(IndexifyObjectsColumns::ReductionTasks)?; + for (_id, task) in all_reduction_tasks { + queued_reduction_tasks.insert(task.key(), task); + } Ok(Self { namespaces, compute_graphs, + compute_graph_versions, executors, allocated_tasks, tasks, unallocated_tasks, + invocation_ctx, + queued_reduction_tasks, }) }
@@ -90,16 +128,36 @@ impl InMemoryState { format!("{}|{}", req.namespace, req.compute_graph.name), req.compute_graph.clone(), ); + self.compute_graph_versions.insert( + format!( + "{}|{}|{}", + req.namespace, req.compute_graph.name, &req.compute_graph.version.0, + ), + req.compute_graph.into_version().clone(), + ); } RequestPayload::DeleteComputeGraphRequest(req) => { self.compute_graphs .remove(&format!("{}|{}", req.namespace, req.name)); + let key = format!("{}|{}", req.namespace, req.name); + // Version keys are `ns|cg|version` and invocation keys are `ns|cg|invocation`, + // so both are swept by prefix; an exact-key remove or an empty `key..key` + // range would match nothing. + let versions_to_remove: Vec<String> = self + .compute_graph_versions + .keys() + .filter(|k| k.starts_with(&key)) + .cloned() + .collect(); + for k in versions_to_remove { + self.compute_graph_versions.remove(&k); + } + let mut graph_ctx_to_remove = vec![]; + for (k, _v) in self.invocation_ctx.range(key.clone()..) { + if !k.starts_with(&key) { + break; + } + graph_ctx_to_remove.push(k.clone()); + } + for k in graph_ctx_to_remove { + self.invocation_ctx.remove(&k); + } } RequestPayload::FinalizeTask(req) => { self.allocated_tasks .entry(req.executor_id.get().to_string()) .or_default() .remove(&req.task_id.to_string()); + if let Some(updated_invocation_ctx) = &req.invocation_ctx { + self.invocation_ctx = self.invocation_ctx.update(updated_invocation_ctx.key(), updated_invocation_ctx.clone()); + } } RequestPayload::TaskAllocationProcessorUpdate(req) => { for allocation in &req.allocations { self.allocated_tasks
@@ -115,7 +173,17 @@ impl InMemoryState { } RequestPayload::TaskCreatorUpdate(req) => { for task in &req.task_requests { - self.tasks.insert(task.id.to_string(), task.clone()); + self.tasks.insert(task.key(), task.clone()); + } + for task in &req.reduction_tasks.new_reduction_tasks { + self.queued_reduction_tasks.insert(task.key(), task.clone()); + } + for task in &req.reduction_tasks.processed_reduction_tasks { + self.queued_reduction_tasks.remove(task); + } + if let Some(updated_invocation_ctx) = &req.invocation_ctx { + self.invocation_ctx = self.invocation_ctx.update(updated_invocation_ctx.key(), updated_invocation_ctx.clone()); + } } _ => {}
@@ -123,14 +191,31 @@ impl InMemoryState { Ok(()) } + pub fn next_queued_task( + &self, + ns: &str, + cg: &str, + inv: &str, + c_fn: &str, + ) -> Option<ReduceTask> { + let key = format!("{}|{}|{}|{}", ns, cg, inv, c_fn); + // A `key..key` range is empty; take the first queued task whose key + // begins with the prefix. + self.queued_reduction_tasks + .range(key.clone()..) + .take_while(|(k, _)| k.starts_with(&key)) + .next() + .map(|(_, v)| v.clone()) + } + pub fn get_in_memory_state(&self) -> Self { InMemoryState { namespaces: self.namespaces.clone(), compute_graphs: self.compute_graphs.clone(), + compute_graph_versions: self.compute_graph_versions.clone(), executors: self.executors.clone(), allocated_tasks: self.allocated_tasks.clone(), tasks: self.tasks.clone(), unallocated_tasks: self.unallocated_tasks.clone(), + invocation_ctx: self.invocation_ctx.clone(), + queued_reduction_tasks: self.queued_reduction_tasks.clone(), } } }
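`next_queued_task` and the delete sweep above both depend on ordered-map prefix scans, and the detail is easy to get wrong: a half-open range whose start and end are the same key (`key..key`) matches nothing, so the scan has to start at the prefix and stop at the first key that no longer carries it. A standalone example of the pattern:

    use im::OrdMap;

    fn main() {
        let mut queue: OrdMap<String, &str> = OrdMap::new();
        queue.insert("ns|graph|inv|fn|0001".to_string(), "reduce-0");
        queue.insert("ns|graph|inv|fn|0002".to_string(), "reduce-1");
        queue.insert("ns|graph|other".to_string(), "unrelated");

        let prefix = "ns|graph|inv|fn".to_string();
        let first = queue
            .range(prefix.clone()..)
            .take_while(|(k, _)| k.starts_with(&prefix))
            .next();
        assert_eq!(first.map(|(_, v)| *v), Some("reduce-0"));
    }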
diff --git a/server/state_store/src/lib.rs b/server/state_store/src/lib.rs index c0ed54c50..34e206af1 100644 --- a/server/state_store/src/lib.rs +++ b/server/state_store/src/lib.rs
@@ -257,17 +257,35 @@ impl IndexifyState { } } RequestPayload::FinalizeTask(finalize_task) => { - let finalized = state_machine::mark_task_finalized( + tasks_finalized + .entry(finalize_task.executor_id.clone()) + .or_default() + .push(finalize_task.task_id.clone()); + let finalized_task_result = state_machine::mark_task_finalized( self.db.clone(), &txn, finalize_task.clone(), self.metrics.clone(), + finalize_task.invocation_ctx.clone(), )?; - if finalized { - tasks_finalized - .entry(finalize_task.executor_id.clone()) - .or_default() - .push(finalize_task.task_id.clone()); + if let Some(invocation_completion) = finalized_task_result.invocation_completion { + match invocation_completion { + InvocationCompletion::System => { + // Notify the system task handler that it can start new tasks since + // a task was completed + let _ = self.system_tasks_tx.send(()); + } + InvocationCompletion::User => { + let _ = self.task_event_tx + .send(InvocationStateChangeEvent::InvocationFinished( + InvocationFinishedEvent { + id: finalize_task.invocation_id.clone(), + }, + )); + } + } + } + if finalized_task_result.should_notify_graph_processor { state_changes::finalized_task(&self.last_state_change_id, &finalize_task)? } else { vec![] } }
@@ -308,33 +326,17 @@ impl IndexifyState { vec![] } RequestPayload::TaskCreatorUpdate(request) => { - let new_state_changes = + let mut new_state_changes = vec![]; + if !request.task_requests.is_empty() { + let state_changes = state_changes::change_events_for_namespace_processor_update( &self.last_state_change_id, &request, )?; - if let Some(completion) = state_machine::create_tasks( - self.db.clone(), - &txn, - &request.task_requests.clone(), - self.metrics.clone().clone(), - &request.namespace, - &request.compute_graph, - &request.invocation_id, - )? { - let _ = - self.task_event_tx - .send(InvocationStateChangeEvent::InvocationFinished( - InvocationFinishedEvent { - id: request.invocation_id.clone(), - }, - )); - if completion == InvocationCompletion::System { - // Notify the system task handler that it can start new tasks since - // a task was completed - let _ = self.system_tasks_tx.send(()); - } - }; + new_state_changes.extend(state_changes); + state_machine::create_tasks(self.db.clone(), &txn, &request.task_requests, request.invocation_ctx.clone())?; + } + state_machine::processed_reduction_tasks( self.db.clone(), &txn,
@@ -509,6 +511,10 @@ impl IndexifyState { scanner::StateReader::new(self.db.clone(), self.metrics.clone()) } + pub async fn in_memory_state(&self) -> in_memory_state::InMemoryState { + self.in_memory_state.read().await.get_in_memory_state() + } + pub fn task_event_stream(&self) -> broadcast::Receiver<InvocationStateChangeEvent> { self.task_event_tx.subscribe() }
diff --git a/server/state_store/src/requests.rs b/server/state_store/src/requests.rs index 18f83b305..6da2888b9 100644 --- a/server/state_store/src/requests.rs +++ b/server/state_store/src/requests.rs
@@ -2,6 +2,7 @@ use data_model::{ ComputeGraph, ExecutorId, ExecutorMetadata, + GraphInvocationCtx, GraphVersion, InvocationPayload, NodeOutput,
@@ -92,6 +93,7 @@ pub struct FinalizeTaskRequest { pub task_outcome: TaskOutcome, pub diagnostics: Option<TaskDiagnostics>, pub executor_id: ExecutorId, + pub invocation_ctx: Option<GraphInvocationCtx>, } #[derive(Debug, Clone)]
@@ -155,6 +157,7 @@ pub struct NamespaceProcessorUpdateRequest { pub invocation_id: String, pub task_requests: Vec<Task>, pub reduction_tasks: ReductionTasks, + pub invocation_ctx: Option<GraphInvocationCtx>, } #[derive(Debug, Clone)]
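Both request structs now carry an `Option<GraphInvocationCtx>`: the processor computes the updated context against its in-memory snapshot, and the state machine persists whatever it is handed instead of doing a read-modify-write inside the RocksDB transaction. A compressed sketch of that hand-off, with simplified stand-in types:

    use std::collections::HashMap;

    #[derive(Clone)]
    struct GraphInvocationCtx {
        key: String,
        outstanding_tasks: u64,
    }

    struct FinalizeTaskRequest {
        // Precomputed by the processor; None means the context was gone
        // and the write is skipped.
        invocation_ctx: Option<GraphInvocationCtx>,
    }

    // The state-machine side becomes a blind put of the handed-over value.
    fn apply(store: &mut HashMap<String, GraphInvocationCtx>, req: &FinalizeTaskRequest) {
        if let Some(ctx) = &req.invocation_ctx {
            store.insert(ctx.key.clone(), ctx.clone());
        }
    }

    fn main() {
        let mut store = HashMap::new();
        let req = FinalizeTaskRequest {
            invocation_ctx: Some(GraphInvocationCtx {
                key: "ns|graph|inv".to_string(),
                outstanding_tasks: 0,
            }),
        };
        apply(&mut store, &req);
        assert_eq!(store["ns|graph|inv"].outstanding_tasks, 0);
    }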
diff --git a/server/state_store/src/state_machine.rs b/server/state_store/src/state_machine.rs index 39f07d25d..18328e4d9 100644 --- a/server/state_store/src/state_machine.rs +++ b/server/state_store/src/state_machine.rs
@@ -20,7 +20,6 @@ use data_model::{ StateMachineMetadata, SystemTask, Task, - TaskAnalytics, TaskOutputsIngestionStatus, }; use indexify_utils::{get_epoch_time_in_ms, OptionInspectNone};
@@ -825,39 +824,29 @@ pub(crate) enum InvocationCompletion { System, } +pub(crate) struct FinalizeTaskResult { + pub invocation_completion: Option<InvocationCompletion>, + pub should_notify_graph_processor: bool, +} + -// returns whether the invocation was completed or not and whether it was a user -// or system task invocation. +// Persists the given tasks and, when the processor supplies one, the updated +// graph invocation context. pub(crate) fn create_tasks( db: Arc<TransactionDB>, txn: &Transaction<'_, TransactionDB>, tasks: &[Task], - sm_metrics: Arc<StateStoreMetrics>, - namespace: &str, - compute_graph: &str, - invocation_id: &str, -) -> Result<Option<InvocationCompletion>> { - let ctx_key = format!("{}|{}|{}", namespace, compute_graph, invocation_id); - let graph_ctx = txn.get_for_update_cf( + graph_ctx: Option<GraphInvocationCtx>, +) -> Result<()> { + if let Some(graph_ctx) = &graph_ctx { + let serialized_graphctx = JsonEncoder::encode(&graph_ctx)?; + txn.put_cf( &IndexifyObjectsColumns::GraphInvocationCtx.cf_db(&db), - &ctx_key, - true, + graph_ctx.key(), + serialized_graphctx, )?; - if graph_ctx.is_none() { - error!( - "Graph context not found for graph {} and invocation {}", - &compute_graph, &invocation_id - ); - return Ok(None); - } - let graph_ctx = &graph_ctx.ok_or(anyhow!( - "Graph context not found for graph {} and invocation {}", - &compute_graph, - &invocation_id - ))?; - let mut graph_ctx: GraphInvocationCtx = JsonEncoder::decode(&graph_ctx)?; - if graph_ctx.completed { - return Ok(None); } + for task in tasks { let serialized_task = JsonEncoder::encode(&task)?; info!(
@@ -873,37 +862,8 @@ pub(crate) fn create_tasks( task.key(), &serialized_task, )?; - let analytics = graph_ctx - .fn_task_analytics - .entry(task.compute_fn_name.clone()) - .or_insert_with(|| TaskAnalytics::default()); - analytics.pending(); - } - graph_ctx.outstanding_tasks += tasks.len() as u64; - // Subtract reference for completed state change event - graph_ctx.outstanding_tasks -= 1; - let serialized_graphctx = JsonEncoder::encode(&graph_ctx)?; - txn.put_cf( - &IndexifyObjectsColumns::GraphInvocationCtx.cf_db(&db), - ctx_key, - serialized_graphctx, - )?; - info!( - "invocation ctx for invocation : {}, {:?}", - invocation_id, graph_ctx - ); - sm_metrics.task_unassigned(tasks); - if graph_ctx.outstanding_tasks == 0 { - Ok(Some(mark_invocation_finished( - db, - txn, - &namespace, - &compute_graph, - &invocation_id, - )?)) - } else { - Ok(None) } + Ok(()) } pub fn handle_task_allocation_update(
@@ -1042,12 +1002,13 @@ pub fn ingest_task_outputs( } -/// Returns true if the task was marked as finalized. +/// Returns the task finalization result, including whether the invocation completed.
-pub fn mark_task_finalized( +pub(crate) fn mark_task_finalized( db: Arc<TransactionDB>, txn: &Transaction<'_, TransactionDB>, req: FinalizeTaskRequest, sm_metrics: Arc<StateStoreMetrics>, -) -> Result<bool> { + graph_ctx: Option<GraphInvocationCtx>, +) -> Result<FinalizeTaskResult> { info!( "task finalization begin: ns: {}, compute graph: {}, invocation_id: {}, task: {}, outcome: {:?}", req.namespace, req.compute_graph, req.invocation_id, req.task_id, req.task_outcome );
@@ -1066,7 +1027,10 @@ pub fn mark_task_finalized( "task finalization end: task: {}, Compute graph not found: {}", &req.task_id, &req.compute_graph ); - return Ok(false); + return Ok(FinalizeTaskResult { + invocation_completion: Some(InvocationCompletion::User), + should_notify_graph_processor: false, + }); } // Check if the invocation was deleted before the task completes
@@ -1083,7 +1047,10 @@ pub fn mark_task_finalized( "task finalization end: Invocation not found: {}", &req.invocation_id ); - return Ok(false); + return Ok(FinalizeTaskResult { + invocation_completion: Some(InvocationCompletion::User), + should_notify_graph_processor: false, + }); } let task_key = format!( "{}|{}|{}|{}|{}",
@@ -1092,52 +1059,18 @@ pub fn mark_task_finalized( let task = txn.get_for_update_cf(&IndexifyObjectsColumns::Tasks.cf_db(&db), &task_key, true)?; if task.is_none() { error!("task finalization end: Task not found: {}", &task_key); - return Ok(false); + return Ok(FinalizeTaskResult { + invocation_completion: Some(InvocationCompletion::User), + should_notify_graph_processor: false, + }); } let mut task = JsonEncoder::decode::<Task>(&task.unwrap())?; - let graph_ctx_key = format!( - "{}|{}|{}", - req.namespace, req.compute_graph, req.invocation_id - ); - let graph_ctx = txn - .get_for_update_cf( - &IndexifyObjectsColumns::GraphInvocationCtx.cf_db(&db), - &graph_ctx_key, - true, - ) - .map_err(|e| anyhow!("failed to get graph context: {}", e))?; - if graph_ctx.is_none() { - error!( - "task finalization end: Graph context not found, ns: {} compute graph: {} invocation id: {} task: {}", - &req.namespace, &req.compute_graph, &req.invocation_id, &req.task_id - ); - return Ok(false); - } - let mut graph_ctx: GraphInvocationCtx = JsonEncoder::decode(&graph_ctx.ok_or(anyhow!( - "unable to deserialize graph context for task: {}", - &req.task_id - ))?)?; - - let analytics = graph_ctx - .fn_task_analytics - .entry(req.compute_fn.to_string()) - .or_insert_with(|| TaskAnalytics::default()); - match req.task_outcome { - data_model::TaskOutcome::Success => analytics.success(), - data_model::TaskOutcome::Failure => analytics.fail(), - _ => {} - } - info!( - "task finalization graph ctx updated: task: {}, graph ctx: {:?}", - task.key(), - graph_ctx - ); - let graph_ctx = JsonEncoder::encode(&graph_ctx)?; + let serialized_graph_ctx = JsonEncoder::encode(&graph_ctx.clone().unwrap())?; txn.put_cf( &IndexifyObjectsColumns::GraphInvocationCtx.cf_db(&db), - graph_ctx_key, - graph_ctx, + graph_ctx.clone().unwrap().key(), + serialized_graph_ctx, )?; // Delete the task allocation since task is finished.
@@ -1169,7 +1102,25 @@ pub fn mark_task_finalized( )?; sm_metrics.update_task_completion(req.task_outcome, task.clone(), req.executor_id.get()); - Ok(true) + if let Some(graph_ctx) = graph_ctx { + if graph_ctx.outstanding_tasks == 0 { + let result = mark_invocation_finished( + db, + txn, + &req.namespace, + &req.compute_graph, + &req.invocation_id, + )?; + return Ok(FinalizeTaskResult { + invocation_completion: Some(result), + should_notify_graph_processor: true, + }); + } + } + Ok(FinalizeTaskResult { + invocation_completion: None, + should_notify_graph_processor: true, + }) } pub(crate) fn save_state_changes(