Make Gemini return structured outputs
dirkbrnd committed Feb 12, 2025
1 parent 551aadb commit 989f275
Showing 3 changed files with 49 additions and 27 deletions.
10 changes: 7 additions & 3 deletions libs/agno/agno/agent/agent.py
@@ -882,9 +882,11 @@ def run(
                    import time

                    time.sleep(delay)

-        # If we get here, all retries failed
-        raise Exception(f"Failed after {num_attempts} attempts. Last error: {str(last_exception)}")
+        if last_exception is not None:
+            raise Exception(f"Failed after {num_attempts} attempts. Last error using {last_exception.model_name}({last_exception.model_id}): {str(last_exception)}")
+        else:
+            raise Exception(f"Failed after {num_attempts} attempts.")

    async def _arun(
        self,
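
For context, the only behavioral change in this hunk is the failure message: when the last exception is a ModelProviderError, the raise now names the provider and model that failed. A minimal, self-contained sketch of that pattern follows; the ModelProviderError stand-in and the run_with_retries helper are illustrative assumptions, not agno's actual classes, and agno's real run() also sleeps between attempts as shown above.

class ModelProviderError(Exception):
    # Stand-in: assumes the real exception carries model_name and model_id,
    # as the attribute access in the diff above implies.
    def __init__(self, exc: Exception, model_name: str, model_id: str):
        super().__init__(str(exc))
        self.model_name = model_name
        self.model_id = model_id


def run_with_retries(call, num_attempts: int = 3):
    # Simplified version of the retry/raise flow shown in the diff.
    last_exception = None
    for _ in range(num_attempts):
        try:
            return call()
        except ModelProviderError as e:
            last_exception = e  # remember the most recent provider failure
    if last_exception is not None:
        raise Exception(
            f"Failed after {num_attempts} attempts. Last error using "
            f"{last_exception.model_name}({last_exception.model_id}): {str(last_exception)}"
        )
    raise Exception(f"Failed after {num_attempts} attempts.")
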
@@ -1421,11 +1423,13 @@ def update_model(self) -> None:

        # Update the response_format on the Model
        if self.response_model is not None:
+            # This will pass the pydantic model to the model
            if self.structured_outputs and self.model.supports_structured_outputs:
                logger.debug("Setting Model.response_format to Agent.response_model")
                self.model.response_format = self.response_model
                self.model.structured_outputs = True
            else:
+                # Otherwise we just want JSON
                self.model.response_format = {"type": "json_object"}
        else:
            self.model.response_format = None
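
Taken together with the Gemini changes below, the structured-output path would be exercised roughly like this. This is a hedged usage sketch: the import paths, the gemini-2.0-flash model id, and the MovieScript schema are illustrative assumptions, not something this diff shows.

from pydantic import BaseModel

from agno.agent import Agent  # assumed import paths for illustration
from agno.models.google import Gemini


class MovieScript(BaseModel):  # hypothetical response schema
    title: str
    genre: str
    logline: str


# With structured_outputs=True and a model whose supports_structured_outputs is True,
# update_model() passes the Pydantic class itself as Model.response_format;
# otherwise the agent falls back to plain JSON mode ({"type": "json_object"}).
agent = Agent(
    model=Gemini(id="gemini-2.0-flash"),  # model id is an assumption
    response_model=MovieScript,
    structured_outputs=True,
)

run_response = agent.run("Write a one-line pitch for a heist movie set in Venice.")
print(run_response.content)  # expected to be a MovieScript instance when structured outputs apply
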
53 changes: 29 additions & 24 deletions libs/agno/agno/models/google/gemini.py
@@ -165,6 +165,8 @@ class Gemini(Model):
name: str = "Gemini"
provider: str = "Google"

supports_structured_outputs: bool = True

assistant_message_role: str = "model"

# Request parameters
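
The new supports_structured_outputs flag advertises that this model class can take a response schema natively rather than falling back to JSON mode. For reference only, and assuming the google-genai client that the GenerateContentResponse/Content type hints in this file suggest, native structured output at the SDK level looks roughly like this; none of it is part of the commit.

from google import genai
from pydantic import BaseModel


class Recipe(BaseModel):  # hypothetical schema
    name: str
    ingredients: list[str]


client = genai.Client(api_key="...")  # auth detail is an assumption
result = client.models.generate_content(
    model="gemini-2.0-flash",  # model id is an assumption
    contents="Give me a simple pancake recipe.",
    config={
        "response_mime_type": "application/json",
        "response_schema": Recipe,  # a Pydantic model can be passed as the schema
    },
)
print(result.text)  # JSON string conforming to Recipe
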
@@ -608,32 +610,35 @@ def parse_provider_response(self, response: GenerateContentResponse) -> ModelResponse:
        model_response = ModelResponse()

-        # Get response message
-        response_message: Content = response.candidates[0].content
-
-        # Add role
-        if response_message.role is not None:
-            model_response.role = response_message.role
-
-        # Add content
-        if response_message.parts is not None:
-            for part in response_message.parts:
-                # Extract text if present
-                if hasattr(part, "text") and part.text is not None:
-                    model_response.content = part.text
+        if response.candidates is not None:
+            response_message: Content = response.candidates[0].content
+
+            # Add role
+            if response_message.role is not None:
+                model_response.role = response_message.role
+
+            # Add content
+            if response_message.parts is not None:
+                for part in response_message.parts:
+                    # Extract text if present
+                    if hasattr(part, "text") and part.text is not None:
+                        model_response.content = part.text

-                # Extract function call if present
-                if hasattr(part, "function_call") and part.function_call is not None:
-                    tool_call = {
-                        "type": "function",
-                        "function": {
-                            "name": part.function_call.name,
-                            "arguments": json.dumps(part.function_call.args)
-                            if part.function_call.args is not None
-                            else "",
-                        },
-                    }
-
-                    model_response.tool_calls.append(tool_call)
+                    # Extract function call if present
+                    if hasattr(part, "function_call") and part.function_call is not None:
+                        tool_call = {
+                            "type": "function",
+                            "function": {
+                                "name": part.function_call.name,
+                                "arguments": json.dumps(part.function_call.args)
+                                if part.function_call.args is not None
+                                else "",
+                            },
+                        }
+
+                        model_response.tool_calls.append(tool_call)

        # Extract usage metadata if present
        if hasattr(response, "usage_metadata"):
13 changes: 13 additions & 0 deletions libs/agno/agno/models/groq/groq.py
@@ -14,6 +14,7 @@
try:
    from groq import AsyncGroq as AsyncGroqClient
    from groq import Groq as GroqClient
+    from groq import APIError, APIConnectionError, APITimeoutError, APIStatusError
    from groq.types.chat import ChatCompletion
    from groq.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, ChoiceDeltaToolCall
except (ModuleNotFoundError, ImportError):
@@ -228,6 +229,9 @@ def invoke(self, messages: List[Message]) -> ChatCompletion:
                messages=[format_message(m) for m in messages],  # type: ignore
                **self.request_kwargs,
            )
+        except (APIError, APIConnectionError, APITimeoutError, APIStatusError) as e:
+            logger.error(f"Error calling Groq API: {str(e)}")
+            raise ModelProviderError(e, self.name, self.id) from e
        except Exception as e:
            logger.error(f"Unexpected error calling Groq API: {str(e)}")
            raise ModelProviderError(e, self.name, self.id) from e
@@ -248,6 +252,9 @@ async def ainvoke(self, messages: List[Message]) -> ChatCompletion:
                messages=[format_message(m) for m in messages],  # type: ignore
                **self.request_kwargs,
            )
+        except (APIError, APIConnectionError, APITimeoutError, APIStatusError) as e:
+            logger.error(f"Error calling Groq API: {str(e)}")
+            raise ModelProviderError(e, self.name, self.id) from e
        except Exception as e:
            logger.error(f"Unexpected error calling Groq API: {str(e)}")
            raise ModelProviderError(e, self.name, self.id) from e
@@ -269,6 +276,9 @@ def invoke_stream(self, messages: List[Message]) -> Iterator[ChatCompletionChunk]:
                stream=True,
                **self.request_kwargs,
            )
+        except (APIError, APIConnectionError, APITimeoutError, APIStatusError) as e:
+            logger.error(f"Error calling Groq API: {str(e)}")
+            raise ModelProviderError(e, self.name, self.id) from e
        except Exception as e:
            logger.error(f"Unexpected error calling Groq API: {str(e)}")
            raise ModelProviderError(e, self.name, self.id) from e
@@ -293,6 +303,9 @@ async def ainvoke_stream(self, messages: List[Message]) -> Any:
            )
            async for chunk in stream:  # type: ignore
                yield chunk
+        except (APIError, APIConnectionError, APITimeoutError, APIStatusError) as e:
+            logger.error(f"Error calling Groq API: {str(e)}")
+            raise ModelProviderError(e, self.name, self.id) from e
        except Exception as e:
            logger.error(f"Unexpected error calling Groq API: {str(e)}")
            raise ModelProviderError(e, self.name, self.id) from e
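
All four hunks in this file apply the same translation: known Groq SDK errors are logged and re-raised as ModelProviderError, which is what lets Agent.run() (first hunk in agent.py above) report the failing provider and model. A compressed sketch of the pattern; invoke_groq and the ModelProviderError stand-in are hypothetical, while the exception classes come from the groq SDK exactly as imported above.

from groq import APIConnectionError, APIError, APIStatusError, APITimeoutError, Groq


class ModelProviderError(Exception):  # hypothetical stand-in for agno's exception, as in the earlier sketch
    def __init__(self, exc: Exception, model_name: str, model_id: str):
        super().__init__(str(exc))
        self.model_name, self.model_id = model_name, model_id


def invoke_groq(client: Groq, model_id: str, messages: list):
    """Call Groq and normalize SDK errors so callers can retry and report uniformly."""
    try:
        return client.chat.completions.create(model=model_id, messages=messages)
    except (APIError, APIConnectionError, APITimeoutError, APIStatusError) as e:
        # Known SDK failure modes: wrap with provider/model context.
        raise ModelProviderError(e, "Groq", model_id) from e
    except Exception as e:
        # Anything unexpected gets the same treatment.
        raise ModelProviderError(e, "Groq", model_id) from e
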