From 3ff066f278d616e09670cc4d4debd50e31ccdb8a Mon Sep 17 00:00:00 2001 From: Connor Shorten Date: Tue, 21 Jan 2025 10:45:35 -0500 Subject: [PATCH] agent fixes --- notebooks/bigquery-pydantic-simple.ipynb | 154 +++++++++++++---- src/lm/pydantic_agents/agent.py | 53 +++++- src/lm/pydantic_agents/executors.py | 204 +++++++++++++---------- src/lm/pydantic_agents/nodes.py | 19 ++- src/test_gorilla/test_pydantic_agent.py | 2 + 5 files changed, 298 insertions(+), 134 deletions(-) diff --git a/notebooks/bigquery-pydantic-simple.ipynb b/notebooks/bigquery-pydantic-simple.ipynb index 8b67ee6..6f41fdf 100644 --- a/notebooks/bigquery-pydantic-simple.ipynb +++ b/notebooks/bigquery-pydantic-simple.ipynb @@ -6,7 +6,9 @@ "source": [ "# Pydantic is the new SQL\n", "\n", - "For decades, SQL was our go-to tool for shaping and constraining data interactions—querying relational databases with crystal-clear precision. In the new era of Large Language Models (LLMs), we need a similar level of control and structure over our AI-generated outputs. Enter Pydantic: by defining typed models that validate data at runtime, we can harness LLMs to produce strictly formatted outputs—much like how SQL structures our database queries. In this notebook, we’ll explore how `Pydantic is the new SQL`, demonstrating how to build models that specify the exact shape and rules of the data we need from an LLM. We’ll walk through an example using Google BigQuery to show how this approach seamlessly translates natural language prompts into reliably structured queries—and how it can transform the way we interact with generative AI systems." + "This notebook will illustrate the challenge of text-to-SQL approaches with LLMs,\n", + "\n", + "> and why we believe **structured outputs with Pydantic** are the best way to query databases with LLMs." ] }, { @@ -15,7 +17,7 @@ "source": [ "### DSPy setup\n", "\n", - "DSPy is an open-source framework for programming Large Language Model applications developed by Stanford University. At the time of this writing, DSPy has achieved 21,000 stars on GitHub and 166 academic citations." + "This notebook does not use any particular features of DSPy -- it is just wrapping the structured output API from OpenAI."
] }, { @@ -82,12 +84,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Google Data Marketplace\n", - "\n", - "From Google Cloud, \n", - "```\n", - "\"Enhance your analytics and AI initiatives with pre-built data solutions and valuable datasets powered by BigQuery, Cloud Storage, Earth Engine, and other Google Cloud services.\"\n", - "```" + "### Google Data Marketplace" ] }, { @@ -119,6 +116,34 @@ " print(row)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Execute BigQuery Query" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": {}, + "outputs": [], + "source": [ + "def execute_query(query: str, client: bigquery.Client) -> bigquery.table.RowIterator:\n", + " \"\"\"\n", + " Execute a BigQuery SQL query using the provided client\n", + " \n", + " Args:\n", + " query: SQL query string to execute\n", + " client: Authenticated BigQuery client\n", + " \n", + " Returns:\n", + " Iterator of query results\n", + " \"\"\"\n", + " query_job = client.query(query)\n", + " return query_job.result()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -167,7 +192,90 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### BigQuery Pydantic model" + "### SQL as a Text Output" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [], + "source": [ + "class SQLWriter(dspy.Signature):\n", + " \"\"\"Translate a natural language information need into a BigQuery query. In your rationale please be very clear about why you chose the specific query operators you did and why you do not need the operators you did not choose.\"\"\"\n", + "\n", + " nl_command: str = dspy.InputField(desc=\"A natural language command with an underlying information need your db_query should answer.\")\n", + " db_schema: str = dspy.InputField(desc=\"The database schema you can query.\")\n", + " db_sql_query: str = dspy.OutputField()" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "```sql\n", + "SELECT name, SUM(number) AS total_count\n", + "FROM `bigquery-public-data.usa_names.usa_1910_2013`\n", + "WHERE state = 'TX'\n", + "GROUP BY name\n", + "ORDER BY total_count DESC\n", + "LIMIT 5\n", + "```\n" + ] + } + ], + "source": [ + "bigquery_writer = dspy.ChainOfThought(SQLWriter)\n", + "texas_names_schema = get_table_schema(\"bigquery-public-data.usa_names.usa_1910_2013\")\n", + "\n", + "generated_query = bigquery_writer(\n", + " nl_command = \"What are the 5 most common names in Texas?\",\n", + " db_schema = texas_names_schema\n", + ")\n", + "\n", + "db_sql_query = generated_query.db_sql_query\n", + "\n", + "print(db_sql_query)" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [ + { + "ename": "BadRequest", + "evalue": "400 Syntax error: Unexpected identifier `` at [1:1]; reason: invalidQuery, location: query, message: Syntax error: Unexpected identifier `` at [1:1]\n\nLocation: US\nJob ID: 5ca32c57-14ea-468c-9948-d8c163bec990\n", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mBadRequest\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[94], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mexecute_query\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdb_sql_query\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mbigquery_client\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[93], line 13\u001b[0m, in \u001b[0;36mexecute_query\u001b[0;34m(query, client)\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124;03mExecute a BigQuery SQL query using the provided client\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124;03m\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;124;03m Iterator of query results\u001b[39;00m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 12\u001b[0m query_job \u001b[38;5;241m=\u001b[39m client\u001b[38;5;241m.\u001b[39mquery(query)\n\u001b[0;32m---> 13\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mquery_job\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/google/cloud/bigquery/job/query.py:1681\u001b[0m, in \u001b[0;36mQueryJob.result\u001b[0;34m(self, page_size, max_results, retry, timeout, start_index, job_retry)\u001b[0m\n\u001b[1;32m 1676\u001b[0m remaining_timeout \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1678\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m remaining_timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1679\u001b[0m \u001b[38;5;66;03m# Since is_job_done() calls jobs.getQueryResults, which is a\u001b[39;00m\n\u001b[1;32m 1680\u001b[0m \u001b[38;5;66;03m# long-running API, don't delay the next request at all.\u001b[39;00m\n\u001b[0;32m-> 1681\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[43mis_job_done\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 1682\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m 1683\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1684\u001b[0m \u001b[38;5;66;03m# Use a monotonic clock since we don't actually care about\u001b[39;00m\n\u001b[1;32m 1685\u001b[0m \u001b[38;5;66;03m# daylight savings or similar, just the elapsed time.\u001b[39;00m\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/google/api_core/retry/retry_unary.py:293\u001b[0m, in \u001b[0;36mRetry.__call__..retry_wrapped_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 289\u001b[0m target \u001b[38;5;241m=\u001b[39m functools\u001b[38;5;241m.\u001b[39mpartial(func, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 290\u001b[0m sleep_generator \u001b[38;5;241m=\u001b[39m exponential_sleep_generator(\n\u001b[1;32m 291\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_initial, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_maximum, multiplier\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_multiplier\n\u001b[1;32m 292\u001b[0m )\n\u001b[0;32m--> 293\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mretry_target\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 294\u001b[0m \u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 295\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_predicate\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 296\u001b[0m \u001b[43m 
\u001b[49m\u001b[43msleep_generator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 297\u001b[0m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_timeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 298\u001b[0m \u001b[43m \u001b[49m\u001b[43mon_error\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mon_error\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 299\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/google/api_core/retry/retry_unary.py:153\u001b[0m, in \u001b[0;36mretry_target\u001b[0;34m(target, predicate, sleep_generator, timeout, on_error, exception_factory, **kwargs)\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[38;5;66;03m# pylint: disable=broad-except\u001b[39;00m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;66;03m# This function explicitly must deal with broad exceptions.\u001b[39;00m\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exc:\n\u001b[1;32m 152\u001b[0m \u001b[38;5;66;03m# defer to shared logic for handling errors\u001b[39;00m\n\u001b[0;32m--> 153\u001b[0m \u001b[43m_retry_error_helper\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[43m \u001b[49m\u001b[43mexc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[43m \u001b[49m\u001b[43mdeadline\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 156\u001b[0m \u001b[43m \u001b[49m\u001b[43msleep\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 157\u001b[0m \u001b[43m \u001b[49m\u001b[43merror_list\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 158\u001b[0m \u001b[43m \u001b[49m\u001b[43mpredicate\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 159\u001b[0m \u001b[43m \u001b[49m\u001b[43mon_error\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 160\u001b[0m \u001b[43m \u001b[49m\u001b[43mexception_factory\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 161\u001b[0m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 162\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;66;03m# if exception not raised, sleep before next attempt\u001b[39;00m\n\u001b[1;32m 164\u001b[0m time\u001b[38;5;241m.\u001b[39msleep(sleep)\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/google/api_core/retry/retry_base.py:212\u001b[0m, in \u001b[0;36m_retry_error_helper\u001b[0;34m(exc, deadline, next_sleep, error_list, predicate_fn, on_error_fn, exc_factory_fn, original_timeout)\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m predicate_fn(exc):\n\u001b[1;32m 207\u001b[0m final_exc, source_exc \u001b[38;5;241m=\u001b[39m exc_factory_fn(\n\u001b[1;32m 208\u001b[0m error_list,\n\u001b[1;32m 209\u001b[0m RetryFailureReason\u001b[38;5;241m.\u001b[39mNON_RETRYABLE_ERROR,\n\u001b[1;32m 210\u001b[0m original_timeout,\n\u001b[1;32m 211\u001b[0m )\n\u001b[0;32m--> 212\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m final_exc \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msource_exc\u001b[39;00m\n\u001b[1;32m 213\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m on_error_fn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 214\u001b[0m on_error_fn(exc)\n", + "File 
\u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/google/api_core/retry/retry_unary.py:144\u001b[0m, in \u001b[0;36mretry_target\u001b[0;34m(target, predicate, sleep_generator, timeout, on_error, exception_factory, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m sleep \u001b[38;5;129;01min\u001b[39;00m sleep_generator:\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 144\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mtarget\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m inspect\u001b[38;5;241m.\u001b[39misawaitable(result):\n\u001b[1;32m 146\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(_ASYNC_RETRY_WARNING)\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/google/cloud/bigquery/job/query.py:1630\u001b[0m, in \u001b[0;36mQueryJob.result..is_job_done\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1607\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m job_failed_exception \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1608\u001b[0m \u001b[38;5;66;03m# Only try to restart the query job if the job failed for\u001b[39;00m\n\u001b[1;32m 1609\u001b[0m \u001b[38;5;66;03m# a retriable reason. For example, don't restart the query\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1627\u001b[0m \u001b[38;5;66;03m# into an exception that can be processed by the\u001b[39;00m\n\u001b[1;32m 1628\u001b[0m \u001b[38;5;66;03m# `job_retry` predicate.\u001b[39;00m\n\u001b[1;32m 1629\u001b[0m restart_query_job \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m-> 1630\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m job_failed_exception\n\u001b[1;32m 1631\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1632\u001b[0m \u001b[38;5;66;03m# Make sure that the _query_results are cached so we\u001b[39;00m\n\u001b[1;32m 1633\u001b[0m \u001b[38;5;66;03m# can return a complete RowIterator.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1639\u001b[0m \u001b[38;5;66;03m# making any extra API calls if the previous loop\u001b[39;00m\n\u001b[1;32m 1640\u001b[0m \u001b[38;5;66;03m# iteration fetched the finished job.\u001b[39;00m\n\u001b[1;32m 1641\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reload_query_results(\n\u001b[1;32m 1642\u001b[0m retry\u001b[38;5;241m=\u001b[39mretry, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mreload_query_results_kwargs\n\u001b[1;32m 1643\u001b[0m )\n", + "\u001b[0;31mBadRequest\u001b[0m: 400 Syntax error: Unexpected identifier `` at [1:1]; reason: invalidQuery, location: query, message: Syntax error: Unexpected identifier `` at [1:1]\n\nLocation: US\nJob ID: 5ca32c57-14ea-468c-9948-d8c163bec990\n" + ] + } + ], + "source": [ + "execute_query(db_sql_query, bigquery_client)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Pydantic is the new SQL" ] }, { @@ -217,34 +325,6 @@ " return query" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### BigQuery Query Executor" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def execute_query(query: str, client: bigquery.Client) -> bigquery.table.RowIterator:\n", - " \"\"\"\n", - " Execute a BigQuery SQL query using the provided client\n", - " \n", - " Args:\n", - " query: 
SQL query string to execute\n", - " client: Authenticated BigQuery client\n", - " \n", - " Returns:\n", - " Iterator of query results\n", - " \"\"\"\n", - " query_job = client.query(query)\n", - " return query_job.result()" - ] - }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/lm/pydantic_agents/agent.py b/src/lm/pydantic_agents/agent.py index e81da47..4ba8f98 100644 --- a/src/lm/pydantic_agents/agent.py +++ b/src/lm/pydantic_agents/agent.py @@ -461,11 +461,60 @@ def _stringify_aggregation_results( """Stringify the results of the aggregation executor.""" print(f"{Fore.GREEN}Stringifying {len(results)} aggregation results{Style.RESET_ALL}") stringified_responses = [] + for idx, (result, query) in enumerate(zip(results, queries)): response = f"Aggregation Result {idx+1} (Query: {query}):\n" - for prop in result.properties: - response += f"{prop}:{result.properties[prop]}\n" + + # Handle grouped results (AggregateGroupByReturn) + if hasattr(result, 'groups'): + response += f"Found {len(result.groups)} groups:\n" + for group in result.groups: + response += f"\nGroup: {group.grouped_by.prop} = {group.grouped_by.value}\n" + if hasattr(group, 'total_count'): + response += f"Total Count: {group.total_count}\n" + + if hasattr(group, 'properties'): + for prop_name, metrics in group.properties.items(): + response += f"\n{prop_name} metrics:\n" + + # Handle text metrics with top occurrences + if hasattr(metrics, 'top_occurrences') and metrics.top_occurrences: + response += "Top occurrences:\n" + for occurrence in metrics.top_occurrences: + response += f" {occurrence.value}: {occurrence.count}\n" + + # Handle numeric metrics + if hasattr(metrics, 'count') and metrics.count is not None: + response += f" count: {metrics.count}\n" + if hasattr(metrics, 'minimum') and metrics.minimum is not None: + response += f" minimum: {metrics.minimum}\n" + if hasattr(metrics, 'maximum') and metrics.maximum is not None: + response += f" maximum: {metrics.maximum}\n" + if hasattr(metrics, 'mean') and metrics.mean is not None: + response += f" mean: {metrics.mean:.2f}\n" + if hasattr(metrics, 'sum_') and metrics.sum_ is not None: + response += f" sum: {metrics.sum_}\n" + + # Handle non-grouped results (AggregateReturn) + elif hasattr(result, 'properties'): + for prop_name, metrics in result.properties.items(): + response += f"\n{prop_name} metrics:\n" + if hasattr(metrics, 'count') and metrics.count is not None: + response += f" count: {metrics.count}\n" + if hasattr(metrics, 'mean') and metrics.mean is not None: + response += f" mean: {metrics.mean:.2f}\n" + if hasattr(metrics, 'maximum') and metrics.maximum is not None: + response += f" maximum: {metrics.maximum}\n" + if hasattr(metrics, 'minimum') and metrics.minimum is not None: + response += f" minimum: {metrics.minimum}\n" + if hasattr(metrics, 'sum_') and metrics.sum_ is not None: + response += f" sum: {metrics.sum_}\n" + + if hasattr(result, 'total_count'): + response += f"\nTotal Count: {result.total_count}\n" + stringified_responses.append(response) + return stringified_responses def _update_usage(self, usage: Usage): diff --git a/src/lm/pydantic_agents/executors.py b/src/lm/pydantic_agents/executors.py index 90f7f37..84fc9bf 100644 --- a/src/lm/pydantic_agents/executors.py +++ b/src/lm/pydantic_agents/executors.py @@ -107,71 +107,102 @@ def create_filter(f: PropertyFilter) -> _FilterByProperty | None: async def process_results(aggregation_results: List[AggregateReturn], search_results: List[QueryReturn]): """Process both aggregation and search 
results.""" + print(f"\nReceived {len(search_results)} search results and {len(aggregation_results)} aggregation results") + # Process search results if search_results: for i, result in enumerate(search_results, 1): print(f"\nSearch Result Set {i}:") + print(f"Number of objects in result set: {len(result.objects)}") for item in result.objects: - print(f"Menu Item: {item.properties['menuItem']}") - print(f"Price: ${item.properties['price']:.2f}") - print(f"Vegetarian: {item.properties['isVegetarian']}") + print(f"\nObject properties: {list(item.properties.keys())}") + for prop, value in item.properties.items(): + print(f"{prop}: {value}") print("---") + else: + print("No search results to process") # Process aggregation results if aggregation_results: - print("\nAggregation Results:") - for agg_result in aggregation_results: + print(f"\nAggregation Results (count: {len(aggregation_results)}):") + for idx, agg_result in enumerate(aggregation_results): + print(f"\nProcessing aggregation result {idx + 1}") + print(f"Result type: {type(agg_result)}") + print(f"Available attributes: {dir(agg_result)}") + try: - # Handle AggregateGroupByReturn object - if hasattr(agg_result, 'groups') and agg_result.groups: # Check if groups exist and is not empty - for group in agg_result.groups: - print(f"\nGroup: {group.grouped_by.prop} = {group.grouped_by.value}") - print(f"Count: {group.total_count}") + # Handle grouped results (AggregateGroupByReturn) + if hasattr(agg_result, 'groups'): + print(f"Found grouped result with {len(agg_result.groups)} groups") + for group_idx, group in enumerate(agg_result.groups): + print(f"\nProcessing group {group_idx + 1}") + print(f"Group attributes: {dir(group)}") + print(f"Group: {group.grouped_by.prop} = {group.grouped_by.value}") + + if hasattr(group, 'total_count'): + print(f"Total Count: {group.total_count}") - for prop_name, metrics in group.properties.items(): - print(f"{prop_name} metrics:") - if isinstance(metrics, (AggregateInteger, AggregateNumber)): - if metrics.mean is not None: - print(f" mean: {metrics.mean:.2f}") - if metrics.maximum is not None: - print(f" maximum: {metrics.maximum}") - if metrics.minimum is not None: - print(f" minimum: {metrics.minimum}") - if metrics.count is not None: - print(f" count: {metrics.count}") - if metrics.sum_ is not None: - print(f" sum: {metrics.sum_}") - else: - # Convert metrics object to dictionary, filtering None values - metrics_dict = {k: v for k, v in vars(metrics).items() if not k.startswith('_') and v is not None} - for metric_name, value in metrics_dict.items(): - print(f" {metric_name}: {value}") + if hasattr(group, 'properties'): + print(f"Properties in group: {list(group.properties.keys())}") + for prop_name, metrics in group.properties.items(): + print(f"\n{prop_name} metrics (type: {type(metrics)}):") + print(f"Metrics attributes: {dir(metrics)}") + + # Handle AggregateText specifically + if hasattr(metrics, 'top_occurrences') and metrics.top_occurrences: + print(" Top occurrences:") + for occurrence in metrics.top_occurrences: + print(f" {occurrence.value}: {occurrence.count}") + + # Handle numeric metrics + if isinstance(metrics, (AggregateInteger, AggregateNumber)): + print(f" Available numeric metrics: count={metrics.count}, " + f"min={metrics.minimum}, max={metrics.maximum}, " + f"mean={metrics.mean}, sum={metrics.sum_}") + if metrics.count is not None: + print(f" count: {metrics.count}") + if metrics.minimum is not None: + print(f" minimum: {metrics.minimum}") + if metrics.maximum is not None: + 
print(f" maximum: {metrics.maximum}") + if metrics.mean is not None: + print(f" mean: {metrics.mean:.2f}") + if metrics.sum_ is not None: + print(f" sum: {metrics.sum_}") - # Handle AggregateReturn object - if hasattr(agg_result, 'properties'): + # Handle non-grouped results (AggregateReturn) + elif hasattr(agg_result, 'properties'): + print(f"Found non-grouped result with properties: {list(agg_result.properties.keys())}") for prop_name, metrics in agg_result.properties.items(): - print(f"\n{prop_name} metrics:") + print(f"\n{prop_name} metrics (type: {type(metrics)}):") + print(f"Metrics attributes: {dir(metrics)}") if isinstance(metrics, (AggregateInteger, AggregateNumber)): + print(f" Available numeric metrics: count={metrics.count}, " + f"min={metrics.minimum}, max={metrics.maximum}, " + f"mean={metrics.mean}, sum={metrics.sum_}") + if metrics.count is not None: + print(f" count: {metrics.count}") if metrics.mean is not None: print(f" mean: {metrics.mean:.2f}") if metrics.maximum is not None: print(f" maximum: {metrics.maximum}") if metrics.minimum is not None: print(f" minimum: {metrics.minimum}") - if metrics.count is not None: - print(f" count: {metrics.count}") if metrics.sum_ is not None: print(f" sum: {metrics.sum_}") - else: - metrics_dict = {k: v for k, v in vars(metrics).items() if not k.startswith('_') and v is not None} - for metric_name, value in metrics_dict.items(): - print(f" {metric_name}: {value}") - + else: + print(f"Warning: Unexpected aggregation result type: {type(agg_result)}") + print(f"Available attributes: {dir(agg_result)}") + if hasattr(agg_result, 'total_count'): print(f"\nTotal Count: {agg_result.total_count}") - + except Exception as e: print(f"Error processing aggregation result: {str(e)}") + print(f"Result type: {type(agg_result)}") + print(f"Error details: {repr(e)}") + print(f"Available attributes: {dir(agg_result)}") + raise # Re-raise the exception to see the full stack trace async def aggregate( @@ -229,9 +260,12 @@ def _build_return_metrics( ) ) elif isinstance(agg, IntegerPropertyAggregation): + metric_name = agg.metrics.value.lower() + if metric_name == "sum": + metric_name = "sum_" metrics_list.append( wc.query.Metrics(agg.property_name).integer( - **{agg.metrics.value.lower(): True} + **{metric_name: True} ) ) elif isinstance(agg, TextPropertyAggregation): @@ -253,57 +287,47 @@ def _build_return_metrics( async def main(): - """Example usage of the query executor.""" - import os - import weaviate - from weaviate.classes.init import Auth - from coordinator.src.agents.nodes.query import QueryAgentDeps, query_agent - - client = weaviate.use_async_with_weaviate_cloud( - cluster_url=os.getenv("WEAVIATE_URL"), - auth_credentials=Auth.api_key(os.getenv("WEAVIATE_API_KEY")), - headers={"X-Openai-Api-Key": os.getenv("OPENAI_APIKEY")}, - ) - - try: - collection = client.collections.get("Menus") - - # Example schema for menu items - schema = { - "properties": { - "menuItem": "string", - "itemDescription": "string", - "price": "number", - "isVegetarian": "boolean" - } - } - - query_agent_deps = QueryAgentDeps(collection_schema=schema) - - # Example query - query_result = await query_agent.run( - "Find vegetarian menu items under $20", - deps=query_agent_deps, - ) - - # Connect the client before searching - await client.connect() - - # Execute search - search_results = await search(collection, query_result.data, limit=10) + """Test AggregateGroupByReturn object parsing.""" + from dataclasses import dataclass + + @dataclass + class GroupedBy: + prop: str + 
value: str + + @dataclass + class AggregateInteger: + count: int + maximum: None + mean: None + median: None + minimum: None + mode: None + sum_: None + + @dataclass + class AggregateGroup: + grouped_by: GroupedBy + properties: dict + total_count: int - # Execute aggregation - agg_results = [] - if hasattr(query_result.data, 'aggregations'): - for agg in query_result.data.aggregations: - agg_result = await aggregate(collection, agg) - agg_results.append(agg_result) - - # Process and display results - await process_results(agg_results, search_results) - - finally: - await client.close() + @dataclass + class AggregateGroupByReturn: + groups: list + + # Create test object + test_results = [ + AggregateGroupByReturn(groups=[ + AggregateGroup( + grouped_by=GroupedBy(prop='openNow', value='true'), + properties={'name': AggregateInteger(count=13, maximum=None, mean=None, median=None, minimum=None, mode=None, sum_=None)}, + total_count=13 + ) + ]) + ] + + # Process results + await process_results(test_results, []) if __name__ == "__main__": diff --git a/src/lm/pydantic_agents/nodes.py b/src/lm/pydantic_agents/nodes.py index b7df52d..d1bd50a 100644 --- a/src/lm/pydantic_agents/nodes.py +++ b/src/lm/pydantic_agents/nodes.py @@ -103,7 +103,7 @@ class NumericMetrics(str, Enum): MEDIAN = "MEDIAN" MIN = "MINIMUM" MODE = "MODE" - SUM = "SUM" + SUM_ = "SUM_" TYPE = "TYPE" class TextMetrics(str, Enum): @@ -191,14 +191,22 @@ def build_prompt(self) -> str: You are a search query analyzer that determines which collections need which types of search actions. For each query, analyze which collections need: - 1. semantic search (searches): Used when you need to find or match specific items/documents + 1. aggregations (aggregations): Used when you need to compute statistics, counts, averages, etc. + - Each aggregation should be focused on one specific aspect of the original query + - IMPORTANT: Any questions about "how many", counts, or numbers mean that you should use aggregations + - Common aggregation triggers: + * "how many..." + * "count of..." + * "number of..." + * "total..." + + 2. semantic search (searches): Used when you need to find or match specific items/documents - Break down complex queries into specific, focused search queries - Each search_query should target one specific aspect or question - For example, "what are laptop prices and battery life" should become two queries: * "what are laptop prices" * "what is laptop battery life" - 2. aggregations (aggregations): Used when you need to compute statistics, counts, averages, etc. 
- - Each aggregation should be focused on one specific aspect of the original query + Key Decision Rules: - A collection can appear in both searches and aggregations if needed @@ -211,6 +219,7 @@ def build_prompt(self) -> str: - Calculate statistics (avg, sum, count) - Group or summarize data - Analyze trends or patterns + - Count number of items or documents Example Response Format: {{ @@ -466,7 +475,7 @@ def build_prompt(self) -> str: - MEAN: Average value - MEDIAN: Middle value - MODE: Most frequent value - - SUM: Sum of all values + - SUM_: Sum of all values - TYPE: Data type information Text properties support: diff --git a/src/test_gorilla/test_pydantic_agent.py b/src/test_gorilla/test_pydantic_agent.py index 7131813..99b7012 100644 --- a/src/test_gorilla/test_pydantic_agent.py +++ b/src/test_gorilla/test_pydantic_agent.py @@ -180,6 +180,8 @@ async def run_queries(client, queries_data: List[Dict], collections: Dict): # Print results print("\nQuery results:") print("Original Query:", result.original_query) + print("Search Queries sent:", result.searches) + print("Aggregation Queries sent:", result.aggregations) print("Final Answer:", result.final_answer) print("Usage Stats:", result.usage)
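For context on the direction of this patch: the notebook replaces free-form text-to-SQL generation (which produced the fenced ```sql output and the BadRequest above) with a structured output that is validated before execution. The following is a minimal, illustrative sketch of that pattern only; `BigQueryQuery` and `to_sql` are assumed names for this example, not the model actually defined in the notebook cells outside this hunk.

# Illustrative sketch: a Pydantic model constrains what the LLM may emit, and a
# renderer assembles the validated object into a BigQuery SQL string.
# Names (BigQueryQuery, to_sql) are assumptions, not taken from the notebook.
from typing import List, Optional
from pydantic import BaseModel, Field

class BigQueryQuery(BaseModel):
    """Structured representation of a simple BigQuery SELECT statement."""
    table: str = Field(description="Fully qualified table, e.g. bigquery-public-data.usa_names.usa_1910_2013")
    select_columns: List[str] = Field(description="Columns or aggregate expressions to select")
    where_clause: Optional[str] = Field(default=None, description="Optional WHERE predicate")
    group_by: Optional[List[str]] = Field(default=None, description="Optional GROUP BY columns")
    order_by: Optional[str] = Field(default=None, description="Optional ORDER BY expression")
    limit: Optional[int] = Field(default=None, description="Optional row limit")

def to_sql(query: BigQueryQuery) -> str:
    """Render the validated object into plain SQL -- nothing to strip or parse."""
    sql = f"SELECT {', '.join(query.select_columns)} FROM `{query.table}`"
    if query.where_clause:
        sql += f" WHERE {query.where_clause}"
    if query.group_by:
        sql += f" GROUP BY {', '.join(query.group_by)}"
    if query.order_by:
        sql += f" ORDER BY {query.order_by}"
    if query.limit:
        sql += f" LIMIT {query.limit}"
    return sql

Because the query is assembled from typed, validated fields rather than parsed out of free text, the failure mode shown in the notebook output (a markdown fence arriving at position [1:1] of the query string) does not arise, and constraints such as a required row limit can be enforced at the model level.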