Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context kwarg for major fns #1014

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 53 additions & 32 deletions src/marvin/fns/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,42 +33,51 @@ async def cast_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
) -> T:
"""Converts input data into a single entity of the specified target type asynchronously.
"""Asynchronously transforms input data into a specific type using a language model.

This function uses a language model to analyze the input data and convert
it into a single entity of the specified type, preserving semantic meaning
where possible.
This function uses a language model to analyze the input data and transform it
into the specified target type, maintaining as much semantic meaning as possible.

Args:
data: The input data to convert. Can be any type.
target: The type to convert the data into. Defaults to str.
instructions: Optional additional instructions to guide the conversion.
Used to provide specific guidance about how to interpret or
transform the data. Required when target is str.
agent: Optional custom agent to use for conversion. If not provided,
data: The input data to transform. Can be any type.
target: The type to transform the data into. Defaults to str.
instructions: Optional additional instructions to guide the transformation.
Used to provide specific guidance about how to interpret or transform
the data.
agent: Optional custom agent to use for transformation. If not provided,
the default agent will be used.
thread: Optional thread for maintaining conversation context. Can be
either a Thread object or a string thread ID.
context: Optional dictionary of additional context to include in the task.

Returns:
A single entity of type T.
The transformed data of type T.

Raises:
ValueError: If target is str and no instructions are provided.
Examples:
>>> # Cast to string
>>> await cast_async(123, str)
'123'

"""
if target is str and instructions is None:
raise ValueError("Instructions are required when target type is str.")
>>> # Cast to float with instructions
>>> await cast_async("three point five", float, instructions="Convert words to numbers")
3.5

>>> # Cast to bool
>>> await cast_async("yes", bool)
True

context = {"Data to transform": data}
"""
task_context = context or {}
task_context["Data to transform"] = data
if instructions:
context["Additional instructions"] = instructions
task_context["Additional instructions"] = instructions

task = marvin.Task[target](
name="Cast Task",
instructions=PROMPT,
context=context,
context=task_context,
result_type=target,
agents=[agent] if agent else None,
)
Expand All @@ -82,29 +91,40 @@ def cast(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
) -> T:
"""Converts input data into a single entity of the specified target type.
"""Transforms input data into a specific type using a language model.

This function uses a language model to analyze the input data and convert
it into a single entity of the specified type, preserving semantic meaning
where possible.
This function uses a language model to analyze the input data and transform it
into the specified target type, maintaining as much semantic meaning as possible.

Args:
data: The input data to convert. Can be any type.
target: The type to convert the data into. Defaults to str.
instructions: Optional additional instructions to guide the conversion.
Used to provide specific guidance about how to interpret or
transform the data. Required when target is str.
agent: Optional custom agent to use for conversion. If not provided,
data: The input data to transform. Can be any type.
target: The type to transform the data into. Defaults to str.
instructions: Optional additional instructions to guide the transformation.
Used to provide specific guidance about how to interpret or transform
the data.
agent: Optional custom agent to use for transformation. If not provided,
the default agent will be used.
thread: Optional thread for maintaining conversation context. Can be
either a Thread object or a string thread ID.
context: Optional dictionary of additional context to include in the task.

Returns:
A single entity of type T.
The transformed data of type T.

Examples:
>>> # Cast to string
>>> cast(123, str)
'123'

>>> # Cast to float with instructions
>>> cast("three point five", float, instructions="Convert words to numbers")
3.5

Raises:
ValueError: If target is str and no instructions are provided.
>>> # Cast to bool
>>> cast("yes", bool)
True

"""
return run_sync(
Expand All @@ -114,5 +134,6 @@ def cast(
instructions=instructions,
agent=agent,
thread=thread,
context=context,
),
)
16 changes: 13 additions & 3 deletions src/marvin/fns/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ async def classify_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
) -> T: ...


Expand All @@ -40,6 +41,7 @@ async def classify_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
) -> list[T]: ...


Expand All @@ -50,6 +52,7 @@ async def classify_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
) -> T | list[T]:
"""Asynchronously classifies input data into one or more predefined labels using a language model.

Expand All @@ -70,6 +73,7 @@ async def classify_async(
the default agent will be used.
thread: Optional thread for maintaining conversation context. Can be
either a Thread object or a string thread ID.
context: Optional dictionary of additional context to include in the task.

Returns:
- If labels is a Sequence[T]:
Expand Down Expand Up @@ -101,9 +105,10 @@ async def classify_async(
True

"""
context = {"Data to classify": data}
task_context = context or {}
task_context["Data to classify"] = data
if instructions:
context["Additional instructions"] = instructions
task_context["Additional instructions"] = instructions

# Convert Enum class to sequence of values if needed
if labels is bool or issubclass_safe(labels, enum.Enum):
Expand All @@ -117,7 +122,7 @@ async def classify_async(
task = marvin.Task[result_type](
name="Classification Task",
instructions=PROMPT,
context=context,
context=task_context,
result_type=result_type,
agents=[agent] if agent else None,
)
Expand All @@ -133,6 +138,7 @@ def classify(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
) -> T: ...


Expand All @@ -145,6 +151,7 @@ def classify(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
) -> list[T]: ...


Expand All @@ -155,6 +162,7 @@ def classify(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
) -> T | list[T]:
"""Classifies input data into one or more predefined labels using a language model.

Expand All @@ -175,6 +183,7 @@ def classify(
the default agent will be used.
thread: Optional thread for maintaining conversation context. Can be
either a Thread object or a string thread ID.
context: Optional dictionary of additional context to include in the task.

Returns:
- If labels is a Sequence[T]:
Expand Down Expand Up @@ -214,5 +223,6 @@ def classify(
instructions=instructions,
agent=agent,
thread=thread,
context=context,
),
)
14 changes: 10 additions & 4 deletions src/marvin/fns/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ async def extract_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
) -> list[T]:
"""Extracts entities of a specific type from the provided data asynchronously.
"""Extracts entities of a specific type from the provided data.

This function uses a language model to identify and extract entities of the
specified type from the input data. The extracted entities are returned as a
Expand All @@ -45,6 +46,7 @@ async def extract_async(
the default agent will be used.
thread: Optional thread for maintaining conversation context. Can be
either a Thread object or a string thread ID.
context: Optional dictionary of additional context to include in the task.

Returns:
A list of extracted entities of type T.
Expand All @@ -56,14 +58,15 @@ async def extract_async(
if target is str and instructions is None:
raise ValueError("Instructions are required when target type is str.")

context = {"Data to extract": data}
task_context = context or {}
task_context["Data to extract"] = data
if instructions:
context["Additional instructions"] = instructions
task_context["Additional instructions"] = instructions

task = marvin.Task[list[target]](
name="Extraction Task",
instructions=PROMPT,
context=context,
context=task_context,
result_type=list[target],
agents=[agent] if agent else None,
)
Expand All @@ -77,6 +80,7 @@ def extract(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
) -> list[T]:
"""Extracts entities of a specific type from the provided data.

Expand All @@ -94,6 +98,7 @@ def extract(
the default agent will be used.
thread: Optional thread for maintaining conversation context. Can be
either a Thread object or a string thread ID.
context: Optional dictionary of additional context to include in the task.

Returns:
A list of extracted entities of type T.
Expand All @@ -109,5 +114,6 @@ def extract(
instructions=instructions,
agent=agent,
thread=thread,
context=context,
),
)
14 changes: 10 additions & 4 deletions src/marvin/fns/generate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, TypeVar, cast
from typing import TypeVar, cast

import marvin
from marvin.agents.agent import Agent
Expand Down Expand Up @@ -32,6 +32,7 @@ async def generate_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
) -> list[T]:
"""Generates examples of a specific type or matching a description asynchronously.

Expand All @@ -48,6 +49,7 @@ async def generate_async(
the default agent will be used.
thread: Optional thread for maintaining conversation context. Can be
either a Thread object or a string thread ID.
context: Optional dictionary of additional context to include in the task.

Returns:
A list of n generated entities of type T.
Expand All @@ -56,14 +58,15 @@ async def generate_async(
if target is str and instructions is None:
raise ValueError("Instructions are required when target type is str.")

context: dict[str, Any] = {"Number to generate": n}
task_context = context or {}
task_context["Number to generate"] = n
if instructions:
context["Additional instructions"] = instructions
task_context["Additional instructions"] = instructions

task = marvin.Task[list[target]](
name="Generation Task",
instructions=PROMPT,
context=context,
context=task_context,
result_type=list[target],
agents=[agent] if agent else None,
)
Expand All @@ -77,6 +80,7 @@ def generate(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
) -> list[T]:
"""Generates examples of a specific type or matching a description.

Expand All @@ -93,6 +97,7 @@ def generate(
the default agent will be used.
thread: Optional thread for maintaining conversation context. Can be
either a Thread object or a string thread ID.
context: Optional dictionary of additional context to include in the task.

Returns:
A list of n generated entities of type T.
Expand All @@ -105,5 +110,6 @@ def generate(
instructions=instructions,
agent=agent,
thread=thread,
context=context,
),
)
Loading
Loading