Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

typing progress #1020

Open
wants to merge 3 commits into
base: marvin-3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion examples/hello_fn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
This DX is supported but not encouraged for those who want type safety. See `cast` or `extract` for more.
"""

from datetime import date
from typing import TypedDict

Expand All @@ -12,7 +16,7 @@ class CulturalReference(TypedDict):
@fn
def pop_culture_related_to_sum(x: int, y: int) -> CulturalReference:
"""Given two numbers, return a cultural reference related to the sum of the two numbers."""
return f"the sum of {x} and {y} is {x + y}" # type: ignore
return f"the sum of {x} and {y} is {x + y}" # type: ignore[reportReturnType]


if __name__ == "__main__":
Expand Down
6 changes: 4 additions & 2 deletions src/marvin/agents/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

if TYPE_CHECKING:
from marvin.agents.team import Team
from marvin.engine.end_turn import EndTurn


@dataclass(kw_only=True)
Expand Down Expand Up @@ -53,7 +54,7 @@ def get_agentlet(
self,
result_types: list[type],
tools: list[Callable[..., Any]] | None = None,
**kwargs,
**kwargs: Any,
) -> pydantic_ai.Agent[Any, Any]:
raise NotImplementedError("Subclass must implement get_agentlet")

Expand All @@ -67,7 +68,7 @@ def get_tools(self) -> list[Callable[..., Any]]:
"""A list of tools that this actor can use during its turn."""
return []

def get_end_turn_tools(self) -> list[type["marvin.engine.end_turn.EndTurn"]]:
def get_end_turn_tools(self) -> list[type["EndTurn"]]:
"""A list of `EndTurn` tools that this actor can use to end its turn."""
return []

Expand Down Expand Up @@ -107,6 +108,7 @@ async def say_async(
thread: Thread | str | None = None,
):
"""Responds to a user message in a conversational way."""

return await marvin.say_async(
message=message,
instructions=instructions,
Expand Down
2 changes: 1 addition & 1 deletion src/marvin/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_agentlet(
self,
result_types: list[type],
tools: list[Callable[..., Any]] | None = None,
**kwargs,
**kwargs: Any,
) -> pydantic_ai.Agent[Any, Any]:
if len(result_types) == 1:
result_type = result_types[0]
Expand Down
2 changes: 1 addition & 1 deletion src/marvin/agents/team.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_agentlet(
self,
result_types: list[type],
tools: list[Callable[..., Any]] | None = None,
**kwargs,
**kwargs: Any,
) -> pydantic_ai.Agent[Any, Any]:
return self.active_agent.get_agentlet(
tools=self.tools + self.get_end_turn_tools() + (tools or []),
Expand Down
10 changes: 5 additions & 5 deletions src/marvin/engine/end_turn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic, Literal, TypeVar
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar

from pydantic_ai import ModelRetry

Expand Down Expand Up @@ -35,7 +35,7 @@ class MarkTaskSuccessful(TaskStateEndTurn, Generic[TaskResult]):
result: TaskResult

async def run(self, orchestrator: "Orchestrator") -> None:
tasks = {t.id: t for t in orchestrator.get_all_tasks()}
tasks: dict[str, "Task[Any]"] = {t.id: t for t in orchestrator.get_all_tasks()}
if self.task_id not in tasks:
raise ModelRetry(f"Task ID {self.task_id} not found in tasks")

Expand All @@ -48,7 +48,7 @@ async def run(self, orchestrator: "Orchestrator") -> None:
task.mark_successful(self.result)

@classmethod
def prepare_for_task(cls, task: "Task") -> None:
def prepare_for_task(cls, task: "Task[Any]") -> type[Any]:
"""
We could let the LLM fill out the task_id itself, but Pydantic AI doesn't support multiple calls
to final tools with the same name, which prevents parallel end turn calls.
Expand Down Expand Up @@ -89,7 +89,7 @@ async def run(self, orchestrator: "Orchestrator") -> None:
task.mark_failed(self.message)

@classmethod
def prepare_for_task(cls, task: "Task") -> None:
def prepare_for_task(cls, task: "Task[Any]") -> None:
"""
Create a custom class for this task to support parallel end turn calls.
"""
Expand Down Expand Up @@ -120,7 +120,7 @@ async def run(self, orchestrator: "Orchestrator") -> None:
task.mark_skipped()

@classmethod
def prepare_for_task(cls, task: "Task") -> None:
def prepare_for_task(cls, task: "Task[Any]") -> None:
"""
Create a custom class for this task to support parallel end turn calls.
"""
Expand Down
53 changes: 27 additions & 26 deletions src/marvin/engine/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ class OrchestratorPrompt(Template):
@dataclass(kw_only=True)
class Orchestrator:
tasks: list[Task[Any]]
agents: list[Actor] = None
thread: Thread
handlers: list[Handler | AsyncHandler] = None
agents: list[Actor] | None = None
thread: Thread | str | None = None
handlers: list[Handler | AsyncHandler] | None = None

team: marvin.agents.team.Team = field(init=False, repr=False)
_staged_delegate: tuple[marvin.agents.team.Team, Actor] | None = field(
Expand Down Expand Up @@ -97,7 +97,7 @@ async def handle_event(self, event: Event):
if marvin.settings.log_events:
logger.debug(f"Handling event: {event.__class__.__name__}\n{event}")

for handler in self.handlers:
for handler in self.handlers or []:
if isinstance(handler, AsyncHandler):
await handler.handle(event)
else:
Expand All @@ -119,28 +119,28 @@ async def _run_turn(self):
await self.handle_event(AgentStartTurnEvent(agent=self.team))

# --- get tools
tools = set()
tools: set[Callable[..., Any]] = set()
for t in tasks:
tools.update(t.get_tools())
tools = list(tools)
_tools = list(tools)

# --- get end turn tools
end_turn_tools = set()
end_turn_tools: set[EndTurn] = set()

for t in tasks:
end_turn_tools.update(t.get_end_turn_tools())

if self.get_delegates():
end_turn_tools.add(DelegateToAgent)
end_turn_tools.update(self.team.get_end_turn_tools())
end_turn_tools = list(end_turn_tools)
_end_turn_tools = list(end_turn_tools)

# --- prepare messages
orchestrator_prompt = OrchestratorPrompt(
orchestrator=self,
tasks=self.get_all_tasks(),
instructions=get_instructions(),
end_turn_tools=end_turn_tools,
end_turn_tools=_end_turn_tools,
).render()

messages = await self.thread.get_messages_async()
Expand All @@ -149,7 +149,7 @@ async def _run_turn(self):
] + messages

# --- run agent
agentlet = self._get_agentlet(tools=tools, end_turn_tools=end_turn_tools)
agentlet = self._get_agentlet(tools=_tools, end_turn_tools=_end_turn_tools)

result = await agentlet.run("", message_history=all_messages)

Expand Down Expand Up @@ -250,7 +250,7 @@ def _fn(*args, **kwargs):

@agentlet.result_validator
async def validate_end_turn(result: EndTurn):
if isinstance(result, EndTurn):
if isinstance(result, EndTurn): # type: ignore[unnecessaryIsinstance]
try:
await result.run(orchestrator=self)
except pydantic_ai.ModelRetry as e:
Expand All @@ -262,7 +262,7 @@ async def validate_end_turn(result: EndTurn):
# return the original result
return result

for tool in agentlet._function_tools.values():
for tool in agentlet._function_tools.values(): # type: ignore[reportPrivateUsage]
# Wrap the tool run function to emit events for each call / result
async def run(
message: ToolCallPart,
Expand All @@ -285,7 +285,7 @@ async def run(
return agentlet

def get_delegates(self) -> list[Actor]:
delegates = []
delegates: list[Actor] = []
current = self.team

# Follow active_agent chain, collecting delegates at each level
Expand Down Expand Up @@ -315,7 +315,8 @@ def stage_delegate(self, agent_id: str) -> None:

# walk active_agents to find the delegate
current = self.team
while agent_id not in {a.id for a in current.agents}:
active_delegates = {a.id for a in current.agents}
while agent_id not in active_delegates:
if not isinstance(current, marvin.agents.team.Team):
raise ValueError(f"Agent ID {agent_id} not found in delegates")
current = current.active_agent
Expand All @@ -333,18 +334,18 @@ def end_turn(self) -> None:

def get_all_tasks(
self, filter: Literal["incomplete", "ready", "assigned"] | None = None
) -> list[Task]:
) -> list[Task[Any]]:
"""Get all tasks, optionally filtered by status.

Filters:
- incomplete: tasks that are not yet complete
- ready: tasks that are ready to be run
- assigned: tasks that are ready and assigned to the active agents
"""
all_tasks: set[Task] = set()
ordered_tasks: list[Task] = []
all_tasks: set[Task[Any]] = set()
ordered_tasks: list[Task[Any]] = []

def collect_tasks(task: Task) -> list[Task]:
def collect_tasks(task: Task[Any]) -> None:
if task in all_tasks:
return

Expand Down Expand Up @@ -383,15 +384,15 @@ def collect_tasks(task: Task) -> list[Task]:
return ordered_tasks

async def run(
self, raise_on_failure: bool = True, max_turns: int | None = None
self, raise_on_failure: bool = True, max_turns: int | float | None = None
) -> list[RunResult]:
if max_turns is None:
max_turns = marvin.settings.max_agent_turns
if max_turns is None:
max_turns = math.inf

results = []
incomplete_tasks: set[Task] = {t for t in self.tasks if t.is_incomplete()}
results: list[RunResult] = []
incomplete_tasks: set[Task[Any]] = {t for t in self.tasks if t.is_incomplete()}
token = _current_orchestrator.set(self)
try:
with self.thread:
Expand All @@ -407,7 +408,7 @@ async def run(
# incomplete dependencies, they will be evaluated as part of
# the orchestrator logic, but not considered part of the
# termination condition.
while incomplete_tasks and turns < max_turns:
while incomplete_tasks and (max_turns is None or turns < max_turns):
result = await self._run_turn()
results.append(result)
turns += 1
Expand All @@ -420,7 +421,7 @@ async def run(
)
incomplete_tasks = {t for t in self.tasks if t.is_incomplete()}

if turns >= max_turns:
if max_turns and turns >= max_turns:
raise ValueError("Max agent turns reached")

except (Exception, KeyboardInterrupt, CancelledError) as e:
Expand Down Expand Up @@ -451,15 +452,15 @@ def active_actors(self) -> list[Actor]:
orchestrator's team and following the team hierarchy to the active
agent.
"""
actors = []
actors: list[Actor] = []
actor = self.team
while not isinstance(actor, Agent):
actors.append(actor)
actor = actor.active_agent
actors.append(actor)
return actors

def get_agent_tree(self) -> dict:
def get_agent_tree(self) -> dict[str, Any]:
"""Returns a tree structure representing the hierarchy of teams and agents.

Returns:
Expand All @@ -473,7 +474,7 @@ def get_agent_tree(self) -> dict:
"""
active_actors = self.active_actors()

def _build_tree(node: Agent | marvin.agents.team.Team) -> dict:
def _build_tree(node: Agent | marvin.agents.team.Team) -> dict[str, Any]:
if isinstance(node, marvin.agents.team.Team):
return {
"type": "team",
Expand Down
4 changes: 2 additions & 2 deletions src/marvin/fns/cast.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, TypeVar
from typing import Any, TypeVar, get_args

import marvin
from marvin.agents.agent import Agent
Expand Down Expand Up @@ -78,7 +78,7 @@ async def cast_async(
name="Cast Task",
instructions=PROMPT,
context=task_context,
result_type=target,
result_type=t[0] if (t := get_args(target)) else target,
agents=[agent] if agent else None,
)

Expand Down
44 changes: 37 additions & 7 deletions src/marvin/fns/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,24 @@ async def classify_async(
) -> list[T]: ...


@overload
async def classify_async(
data: Any,
labels: Sequence[T] | type[T],
multi_label: bool = False,
*,
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict[str, Any] | None = None,
) -> T | list[T]: ...


async def classify_async(
data: Any,
labels: Sequence[T] | type[T],
multi_label: bool = False,
*,
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
Expand Down Expand Up @@ -110,22 +124,24 @@ async def classify_async(
if instructions:
task_context["Additional instructions"] = instructions

# Convert Enum class to sequence of values if needed
# Handle bool/enum types specially for correct typing
if labels is bool or issubclass_safe(labels, enum.Enum):
if multi_label:
result_type = list[labels]
else:
result_type = labels
# For bool/enum, we need list[labels] for multi-label
result_type = list[labels] if multi_label else labels # Runtime type
ReturnType = list[T] if multi_label else T # Generic type
else:
result_type = Labels(labels, many=multi_label)
# For sequences, we use Labels for runtime validation
result_type = Labels(labels, many=multi_label) # Runtime type
ReturnType = list[T] if multi_label else T # Generic type

task = marvin.Task[result_type](
task = marvin.Task[ReturnType](
name="Classification Task",
instructions=PROMPT,
context=task_context,
result_type=result_type,
agents=[agent] if agent else None,
)

return await task.run_async(thread=thread, handlers=[])


Expand Down Expand Up @@ -155,10 +171,24 @@ def classify(
) -> list[T]: ...


@overload
def classify(
data: Any,
labels: Sequence[T] | type[T],
multi_label: bool = False,
*,
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict[str, Any] | None = None,
) -> T | list[T]: ...


def classify(
data: Any,
labels: Sequence[T] | type[T],
multi_label: bool = False,
*,
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
Expand Down
Loading
Loading