Skip to content

Commit

Permalink
Fix all style
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkbrnd committed Feb 11, 2025
1 parent f50a794 commit 88025bb
Show file tree
Hide file tree
Showing 20 changed files with 146 additions and 125 deletions.
8 changes: 7 additions & 1 deletion cookbook/models/google/gemini/async_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
from agno.agent import Agent, RunResponse # noqa
from agno.models.google import Gemini

agent = Agent(model=Gemini(id="gemini-2.0-flash-exp"), markdown=True)
agent = Agent(
model=Gemini(
id="gemini-2.0-flash-exp",
instructions=["You are a basic agent that writes short stories."],
),
markdown=True,
)

# Get the response in a variable
# run: RunResponse = agent.run("Share a 2 sentence horror story")
Expand Down
4 changes: 2 additions & 2 deletions libs/agno/agno/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class Agent:
# A list of tools provided to the Model.
# Tools are functions the model may generate JSON inputs for.
# If you provide a dict, it is not called by the model.
tools: Optional[List[Union[Toolkit, Callable, Function, Dict]]] = None
tools: Optional[List[Union[Toolkit, Callable, Function]]] = None
# Show tool calls in Agent response.
show_tool_calls: bool = False
# Maximum number of tool calls allowed.
Expand Down Expand Up @@ -251,7 +251,7 @@ def __init__(
references_format: Literal["json", "yaml"] = "json",
storage: Optional[AgentStorage] = None,
extra_data: Optional[Dict[str, Any]] = None,
tools: Optional[List[Union[Toolkit, Callable, Function, Dict]]] = None,
tools: Optional[List[Union[Toolkit, Callable, Function]]] = None,
show_tool_calls: bool = False,
tool_call_limit: Optional[int] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
Expand Down
4 changes: 2 additions & 2 deletions libs/agno/agno/embedder/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _response(self, text: str) -> Union[EmbeddingDict, BatchEmbeddingDict]:
def get_embedding(self, text: str) -> List[float]:
response = self._response(text=text)
try:
return response.get("embedding", [])
return response.get("embedding", []) # type: ignore
except Exception as e:
logger.warning(e)
return []
Expand All @@ -67,7 +67,7 @@ def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict
response = self._response(text=text)
usage = None
try:
return response.get("embedding", []), usage
return response.get("embedding", []), usage # type: ignore
except Exception as e:
logger.warning(e)
return [], usage
2 changes: 1 addition & 1 deletion libs/agno/agno/memory/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def add_tools_to_model(self, model: Model) -> None:
try:
function_name = tool.__name__
if function_name not in self._functions_for_model:
func = Function.from_callable(tool)
func = Function.from_callable(tool) # type: ignore
self._functions_for_model[func.name] = func
self._tools_for_model.append({"type": "function", "function": func.to_dict()})
logger.debug(f"Included function {func.name}")
Expand Down
4 changes: 2 additions & 2 deletions libs/agno/agno/models/anthropic/claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,8 @@ def parse_provider_response_delta(

elif isinstance(response, ContentBlockStopEvent):
# Handle tool calls
if isinstance(response.content_block, ToolUseBlock):
tool_use = response.content_block
if isinstance(response.content_block, ToolUseBlock): # type: ignore
tool_use = response.content_block # type: ignore
tool_name = tool_use.name
tool_input = tool_use.input

Expand Down
106 changes: 54 additions & 52 deletions libs/agno/agno/models/aws/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from abc import ABC
from dataclasses import dataclass
from os import getenv
from typing import Any, Dict, Iterator, List, Optional, Tuple
Expand All @@ -19,7 +18,14 @@


@dataclass
class AwsBedrock(Model, ABC):
class AwsBedrockResponseUsage:
input_tokens: int = 0
output_tokens: int = 0
total_tokens: int = 0


@dataclass
class AwsBedrock(Model):
"""
AWS Bedrock model.
Expand All @@ -35,6 +41,7 @@ class AwsBedrock(Model, ABC):
aws_access_key_id (Optional[str]): The AWS access key ID to use.
aws_secret_access_key (Optional[str]): The AWS secret access key to use.
"""

id: str = "mistral.mistral-small-2402-v1:0"
name: str = "AwsBedrock"
provider: str = "AwsBedrock"
Expand Down Expand Up @@ -62,7 +69,9 @@ def get_client(self) -> AwsClient:

if not self.aws_access_key_id or not self.aws_secret_access_key:
raise ModelProviderError(
"AWS credentials not found. Please set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables."
"AWS credentials not found. Please set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables.",
model_name=self.name,
model_id=self.id,
)

self.client = AwsClient(
Expand All @@ -75,34 +84,35 @@ def get_client(self) -> AwsClient:

def _format_tools_for_request(self) -> List[Dict[str, Any]]:
tools = []
for f_name, function in self._functions.items():
properties = {}
required = []

for param_name, param_info in function.parameters.get("properties", {}).items():
param_type = param_info.get("type")
if isinstance(param_type, list):
param_type = [t for t in param_type if t != "null"][0]

properties[param_name] = {
"type": param_type or "string",
"description": param_info.get("description") or "",
}
if self._functions is not None:
for f_name, function in self._functions.items():
properties = {}
required = []

for param_name, param_info in function.parameters.get("properties", {}).items():
param_type = param_info.get("type")
if isinstance(param_type, list):
param_type = [t for t in param_type if t != "null"][0]

properties[param_name] = {
"type": param_type or "string",
"description": param_info.get("description") or "",
}

if "null" not in (
param_info.get("type") if isinstance(param_info.get("type"), list) else [param_info.get("type")]
):
required.append(param_name)

tools.append(
{
"toolSpec": {
"name": f_name,
"description": function.description or "",
"inputSchema": {"json": {"type": "object", "properties": properties, "required": required}},
if "null" not in (
param_info.get("type") if isinstance(param_info.get("type"), list) else [param_info.get("type")]
):
required.append(param_name)

tools.append(
{
"toolSpec": {
"name": f_name,
"description": function.description or "",
"inputSchema": {"json": {"type": "object", "properties": properties, "required": required}},
}
}
}
)
)

return tools

Expand All @@ -118,13 +128,13 @@ def _get_inference_config(self) -> Dict[str, Any]:
return request_kwargs

def _format_messages(self, messages: List[Message]) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]:
formatted_messages = []
formatted_messages: List[Dict[str, Any]] = []
system_message = None
for message in messages:
if message.role == "system":
system_message = [{"text": message.content}]
else:
formatted_message = {"role": message.role}
formatted_message: Dict[str, Any] = {"role": message.role}
formatted_message["content"] = []
# Handle tool results
if isinstance(message.content, list):
Expand Down Expand Up @@ -225,11 +235,7 @@ def invoke(self, messages: List[Message]) -> Dict[str, Any]:
}
body = {k: v for k, v in body.items() if v is not None}

return self.get_client().converse(
modelId=self.id,
messages=formatted_messages,
**body
)
return self.get_client().converse(modelId=self.id, messages=formatted_messages, **body)
except ClientError as e:
logger.error(f"Unexpected error calling Bedrock API: {str(e)}")
raise ModelProviderError(e, self.name, self.id) from e
Expand Down Expand Up @@ -262,11 +268,7 @@ def invoke_stream(self, messages: List[Message]) -> Iterator[Dict[str, Any]]:
}
body = {k: v for k, v in body.items() if v is not None}

return self.get_client().converse_stream(
modelId=self.id,
messages=formatted_messages,
**body
)["stream"]
return self.get_client().converse_stream(modelId=self.id, messages=formatted_messages, **body)["stream"]
except ClientError as e:
logger.error(f"Unexpected error calling Bedrock API: {str(e)}")
raise ModelProviderError(e, self.name, self.id) from e
Expand Down Expand Up @@ -335,11 +337,11 @@ def parse_provider_response(self, response: Dict[str, Any]) -> ModelResponse:
model_response.content = content

if "usage" in response:
model_response.usage = {
"input_tokens": response["usage"]["inputTokens"],
"output_tokens": response["usage"]["outputTokens"],
"total_tokens": response["usage"]["totalTokens"],
}
model_response.response_usage = AwsBedrockResponseUsage(
input_tokens=response["usage"]["inputTokens"],
output_tokens=response["usage"]["outputTokens"],
total_tokens=response["usage"]["totalTokens"],
)

return model_response

Expand Down Expand Up @@ -401,11 +403,11 @@ def process_response_stream(
elif "messageStop" in response_delta:
if "usage" in response_delta["messageStop"]:
usage = response_delta["messageStop"]["usage"]
model_response.usage = {
"input_tokens": usage.get("inputTokens", 0),
"output_tokens": usage.get("outputTokens", 0),
"total_tokens": usage.get("totalTokens", 0),
}
model_response.response_usage = AwsBedrockResponseUsage(
input_tokens=usage.get("inputTokens", 0),
output_tokens=usage.get("outputTokens", 0),
total_tokens=usage.get("totalTokens", 0),
)

# Update metrics
assistant_message.metrics.completion_tokens += 1
Expand Down Expand Up @@ -433,7 +435,7 @@ def process_response_stream(
if tool_ids:
stream_data.extra["tool_ids"] = tool_ids

def parse_provider_response_delta(self, response_delta: Dict[str, Any]) -> ModelResponse:
def parse_provider_response_delta(self, response_delta: Dict[str, Any]) -> ModelResponse: # type: ignore
pass

async def ainvoke(self, *args, **kwargs) -> Any:
Expand Down
8 changes: 4 additions & 4 deletions libs/agno/agno/models/aws/claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def to_dict(self) -> Dict[str, Any]:
_dict["stop_sequences"] = self.stop_sequences
return _dict

client: Optional[AnthropicBedrock] = None
async_client: Optional[AsyncAnthropicBedrock] = None
client: Optional[AnthropicBedrock] = None # type: ignore
async_client: Optional[AsyncAnthropicBedrock] = None # type: ignore

def get_client(self):
if self.client is not None:
Expand All @@ -74,7 +74,7 @@ def get_client(self):
client_params.update(self.client_params)

self.client = AnthropicBedrock(
**client_params,
**client_params, # type: ignore
)
return self.client

Expand All @@ -91,7 +91,7 @@ def get_async_client(self):
client_params.update(self.client_params)

self.async_client = AsyncAnthropicBedrock(
**client_params,
**client_params, # type: ignore
)
return self.async_client

Expand Down
9 changes: 4 additions & 5 deletions libs/agno/agno/models/azure/ai_foundry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
from os import getenv
Expand Down Expand Up @@ -127,13 +126,13 @@ def _get_request_kwargs(self) -> Dict[str, Any]:
)
)
)
base_params["tools"] = tools
base_params["tools"] = tools # type: ignore
if self.tool_choice:
base_params["tool_choice"] = self.tool_choice

if self.response_format is not None and self.structured_outputs:
if isinstance(self.response_format, type) and issubclass(self.response_format, BaseModel):
base_params["response_format"] = (
base_params["response_format"] = ( # type: ignore
JsonSchemaFormat(
name=self.response_format.__name__,
schema=self.response_format.model_json_schema(),
Expand Down Expand Up @@ -283,7 +282,7 @@ async def ainvoke_stream(self, messages: List[Message]) -> AsyncIterator[Any]:
stream=True,
**self._get_request_kwargs(),
)
async for chunk in stream:
async for chunk in stream: # type: ignore
yield chunk

except HttpResponseError as e:
Expand Down Expand Up @@ -400,7 +399,7 @@ def parse_provider_response_delta(self, response_delta: StreamingChatCompletions

# Add tool calls if present
if delta.tool_calls and len(delta.tool_calls) > 0:
model_response.tool_calls = delta.tool_calls
model_response.tool_calls = delta.tool_calls # type: ignore
# Add usage metrics if present
if response_delta.usage is not None:
model_response.response_usage = {
Expand Down
15 changes: 9 additions & 6 deletions libs/agno/agno/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from types import GeneratorType
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Tuple, Union
from typing import Any, AsyncGenerator, AsyncIterator, Dict, Iterator, List, Literal, Optional, Tuple, Union

from agno.exceptions import AgentRunException
from agno.media import AudioOutput
Expand Down Expand Up @@ -108,7 +108,7 @@ def invoke_stream(self, *args, **kwargs) -> Iterator[Any]:
pass

@abstractmethod
async def ainvoke_stream(self, *args, **kwargs) -> Any:
async def ainvoke_stream(self, *args, **kwargs) -> AsyncGenerator[Any, None]:
pass

@abstractmethod
Expand Down Expand Up @@ -505,7 +505,7 @@ async def aprocess_response_stream(
"""
Process a streaming response from the model.
"""
async for response_delta in self.ainvoke_stream(messages=messages):
async for response_delta in self.ainvoke_stream(messages=messages): # type: ignore
model_response_delta = self.parse_provider_response_delta(response_delta)
if model_response_delta:
for model_response in self._populate_stream_data_and_assistant_message(
Expand Down Expand Up @@ -596,7 +596,6 @@ def _populate_stream_data_and_assistant_message(
"""Update the stream data and assistant message with the model response."""

# Update metrics
assistant_message.metrics.output_tokens += 1
if not assistant_message.metrics.time_to_first_token:
assistant_message.metrics.set_time_to_first_token()

Expand Down Expand Up @@ -922,12 +921,16 @@ def _show_tool_calls(self, function_calls_to_run: List[FunctionCall], model_resp
Show tool calls in the model response.
"""
if len(function_calls_to_run) == 1:
if len(model_response.content) > 0 and model_response.content[-1] != "\n":
if model_response.content and len(model_response.content) > 0 and model_response.content[-1] != "\n":
model_response.content += "\n\n"
else:
model_response.content = ""
model_response.content += f" - Running: {function_calls_to_run[0].get_call_str()}\n\n"
elif len(function_calls_to_run) > 1:
if len(model_response.content) > 0 and model_response.content[-1] != "\n":
if model_response.content and len(model_response.content) > 0 and model_response.content[-1] != "\n":
model_response.content += "\n\n"
else:
model_response.content = ""
model_response.content += "Running:"
for _f in function_calls_to_run:
model_response.content += f"\n - {_f.get_call_str()}"
Expand Down
Loading

0 comments on commit 88025bb

Please sign in to comment.