Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abstract Function Calling as a framework #49

Merged
merged 13 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
183 changes: 177 additions & 6 deletions notebooks/pydantic-ai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand All @@ -14,7 +14,7 @@
"| Span Processor: SimpleSpanProcessor\n",
"| Collector Endpoint: https://app.phoenix.arize.com/v1/traces\n",
"| Transport: HTTP\n",
"| Transport Headers: {'api_key': '****'}\n",
"| Transport Headers: {'api_key': '****', 'authorization': '****'}\n",
"| \n",
"| Using a default SpanProcessor. `add_span_processor` will overwrite this default.\n",
"| \n",
Expand All @@ -40,7 +40,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -51,7 +51,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -83,7 +83,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -97,7 +97,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"greeting=\"Hello, Alice! It's great to see you!\"\n"
"greeting=\"Hello, Alice! Hope you're having a wonderful day!\"\n"
]
}
],
Expand Down Expand Up @@ -160,6 +160,177 @@
"{\"messages\": [{\"role\": \"system\", \"content\": \"You are a personalized greeter AI. Return a short greeting for the user.\"}, {\"role\": \"system\", \"content\": \"The user's name is 'Alice'.\"}, {\"role\": \"user\", \"content\": \"Hi, can you greet me?\"}], \"model\": \"gpt-4o\", \"n\": 1, \"parallel_tool_calls\": true, \"stream\": false, \"tool_choice\": \"required\", \"tools\": [{\"type\": \"function\", \"function\": {\"name\": \"final_result\", \"description\": \"Structured output from the AI.\", \"parameters\": {\"properties\": {\"greeting\": {\"description\": \"A short greeting to the user\", \"title\": \"Greeting\", \"type\": \"string\"}}, \"required\": [\"greeting\"], \"title\": \"GreetingResult\", \"type\": \"object\"}}}]}\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Step 1] Analyzing the user's problem to find relevant database queries...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Failed to export batch code: 204, reason: \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Analysis result (queries):\n",
"{\n",
" \"queries\": [\n",
" \"SELECT name, rating, distance FROM restaurants WHERE location = 'current location' ORDER BY rating DESC, distance ASC;\"\n",
" ]\n",
"}\n",
"\n",
"[Step 2] Formatting the queries into Pydantic API requests...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Failed to export batch code: 204, reason: \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Formatting result (API requests):\n",
"{\n",
" \"requests\": [\n",
" {\n",
" \"query_text\": \"SELECT name, rating, distance FROM restaurants WHERE location = 'current location' ORDER BY rating DESC, distance ASC;\",\n",
" \"endpoint\": \"/execute-query\"\n",
" }\n",
" ]\n",
"}\n"
]
}
],
"source": [
"import nest_asyncio\n",
"nest_asyncio.apply()\n",
"\n",
"import asyncio\n",
"import json\n",
"from dataclasses import dataclass\n",
"from typing import List\n",
"\n",
"from pydantic import BaseModel, Field\n",
"from pydantic_ai import Agent, RunContext\n",
"\n",
"\n",
"@dataclass\n",
"class ProblemContext:\n",
" user_name: str\n",
"\n",
"\n",
"class DatabaseQueryAnalysis(BaseModel):\n",
" queries: List[str] = Field(default_factory=list)\n",
"\n",
"\n",
"analysis_agent = Agent(\n",
" model=\"openai:gpt-4o\",\n",
" deps_type=ProblemContext,\n",
" result_type=DatabaseQueryAnalysis,\n",
" system_prompt=(\n",
" \"You are an AI that, given a user's problem, identifies what database queries \"\n",
" \"would be needed to retrieve information that solves the user's problem.\"\n",
" ),\n",
")\n",
"\n",
"\n",
"@analysis_agent.system_prompt\n",
"async def analysis_agent_system_prompt(ctx: RunContext[ProblemContext]) -> str:\n",
" return (\n",
" f\"The user's name is {ctx.deps.user_name!r}. \"\n",
" \"Analyze the user's input and suggest relevant database queries.\"\n",
" )\n",
"\n",
"\n",
"class QueryAPIRequest(BaseModel):\n",
" query_text: str = Field(description=\"The raw query to be executed.\")\n",
" endpoint: str = Field(\n",
" default=\"/execute-query\",\n",
" description=\"The endpoint where the query should be sent.\"\n",
" )\n",
"\n",
"\n",
"class QueryAPIRequests(BaseModel):\n",
" requests: List[QueryAPIRequest] = Field(default_factory=list)\n",
"\n",
"\n",
"formatting_agent = Agent(\n",
" model=\"openai:gpt-4o\",\n",
" deps_type=ProblemContext,\n",
" result_type=QueryAPIRequests,\n",
" system_prompt=(\n",
" \"You are an AI that formats database queries into the provided Pydantic BaseModels for API requests.\"\n",
" ),\n",
")\n",
"\n",
"\n",
"@formatting_agent.system_prompt\n",
"async def formatting_agent_prompt(ctx: RunContext[ProblemContext]) -> str:\n",
" return (\n",
" f\"The user's name is {ctx.deps.user_name!r}. \"\n",
" \"Convert the list of database queries into `QueryAPIRequest` objects, \"\n",
" \"wrapped in the `QueryAPIRequests` model.\"\n",
" )\n",
"\n",
"\n",
"async def main():\n",
" deps = ProblemContext(user_name=\"Alice\")\n",
"\n",
" user_problem = (\n",
" \"I need to find the top-rated restaurants near me. \"\n",
" \"Show me some options sorted by rating and distance.\"\n",
" )\n",
"\n",
" print(\"[Step 1] Analyzing the user's problem to find relevant database queries...\")\n",
" analysis_result = await analysis_agent.run(user_problem, deps=deps)\n",
"\n",
" # Convert to dict and JSON-serialize with indentation\n",
" print(\"Analysis result (queries):\")\n",
" print(json.dumps(analysis_result.data.model_dump(), indent=2))\n",
"\n",
" queries_to_format = analysis_result.data.queries\n",
" second_input = (\n",
" \"Here are the queries to format:\\n\" + \"\\n\".join(queries_to_format)\n",
" )\n",
"\n",
" print(\"\\n[Step 2] Formatting the queries into Pydantic API requests...\")\n",
" formatting_result = await formatting_agent.run(second_input, deps=deps)\n",
"\n",
" # Again, convert to dict and JSON-serialize with indentation\n",
" print(\"Formatting result (API requests):\")\n",
" print(json.dumps(formatting_result.data.model_dump(), indent=2))\n",
"\n",
"\n",
"# If you're in an environment that supports top-level await (e.g., Jupyter):\n",
"await main()\n",
"\n",
"# If you're in a standard Python script:\n",
"# if __name__ == \"__main__\":\n",
"# asyncio.run(main())\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Empty file added src/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions src/lm/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
# together models are accessed through the openai SDK with a different base URL
LMModelProvider = Literal["ollama", "openai", "anthropic", "cohere", "together"]

# need to add models to this...!
# fixes test_lm.py

class LMService():
def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/lm/pydantic_agent/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Tie together the nodes and executors here
1 change: 1 addition & 0 deletions src/lm/pydantic_agent/executors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Weaviate Query Executor
1 change: 1 addition & 0 deletions src/lm/pydantic_agent/nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Nodes for prompts
Loading
Loading