Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement Claude 3 tool use support (#506) #671

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions src/llm/claude_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional, List, Dict, Any
from anthropic import Anthropic

from src.config import Config
from src.llm.tools import AVAILABLE_TOOLS, Tool

class Claude:
def __init__(self):
Expand All @@ -9,18 +11,47 @@ def __init__(self):
self.client = Anthropic(
api_key=api_key,
)
self.tool_schemas = [
{
"name": tool.name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": {
name: {
"type": param.type,
"description": param.description,
**({"enum": param.enum} if param.enum else {})
}
for name, param in tool.parameters.items()
},
"required": tool.required
}
}
for tool in AVAILABLE_TOOLS
]

def inference(self, model_id: str, prompt: str) -> str:
message = self.client.messages.create(
max_tokens=4096,
messages=[
def inference(
self,
model_id: str,
prompt: str,
tools: Optional[List[Dict[str, Any]]] = None
) -> str:
kwargs = {
"max_tokens": 4096,
"messages": [
{
"role": "user",
"content": prompt.strip(),
}
],
model=model_id,
temperature=0
)
"model": model_id,
"temperature": 0
}

# Add tool schemas for Claude 3 models
if "claude-3" in model_id:
kwargs["tools"] = tools or self.tool_schemas

message = self.client.messages.create(**kwargs)
return message.content[0].text
93 changes: 93 additions & 0 deletions src/llm/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Tool schemas for Claude 3 function calling.

This module defines the tool schemas for Claude 3's function calling capabilities.
Each tool follows Claude's function calling schema format.
"""

from typing import Dict, List, Optional
from dataclasses import dataclass


@dataclass
class ToolParameter:
"""Parameter definition for a tool."""
type: str
description: str
enum: Optional[List[str]] = None
required: bool = True


@dataclass
class Tool:
"""Tool definition following Claude's schema."""
name: str
description: str
parameters: Dict[str, ToolParameter]
required: List[str]


# Core tool definitions
BROWSE_TOOL = Tool(
name="browse_web",
description="Browse a web page and extract its content",
parameters={
"url": ToolParameter(
type="string",
description="The URL to browse"
)
},
required=["url"]
)

READ_FILE_TOOL = Tool(
name="read_file",
description="Read the contents of a file",
parameters={
"path": ToolParameter(
type="string",
description="The path to the file to read"
)
},
required=["path"]
)

WRITE_FILE_TOOL = Tool(
name="write_file",
description="Write content to a file",
parameters={
"path": ToolParameter(
type="string",
description="The path to write the file to"
),
"content": ToolParameter(
type="string",
description="The content to write to the file"
)
},
required=["path", "content"]
)

RUN_CODE_TOOL = Tool(
name="run_code",
description="Execute code in a sandboxed environment",
parameters={
"code": ToolParameter(
type="string",
description="The code to execute"
),
"language": ToolParameter(
type="string",
description="The programming language",
enum=["python", "javascript", "bash"]
)
},
required=["code", "language"]
)

# List of all available tools
AVAILABLE_TOOLS = [
BROWSE_TOOL,
READ_FILE_TOOL,
WRITE_FILE_TOOL,
RUN_CODE_TOOL
]
143 changes: 143 additions & 0 deletions tests/test_claude_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Tests for Claude 3 tool use functionality."""

import pytest
from typing import Dict, Any

from src.llm.tools import (
Tool,
ToolParameter,
BROWSE_TOOL,
READ_FILE_TOOL,
WRITE_FILE_TOOL,
RUN_CODE_TOOL,
AVAILABLE_TOOLS
)
from src.llm.claude_client import Claude


def test_tool_parameter_creation():
"""Test creating tool parameters with various configurations."""
param = ToolParameter(
type="string",
description="Test parameter",
enum=["a", "b", "c"],
required=True
)
assert param.type == "string"
assert param.description == "Test parameter"
assert param.enum == ["a", "b", "c"]
assert param.required is True

# Test without optional fields
basic_param = ToolParameter(
type="integer",
description="Basic parameter"
)
assert basic_param.type == "integer"
assert basic_param.description == "Basic parameter"
assert basic_param.enum is None
assert basic_param.required is True


def test_tool_creation():
"""Test creating tools with parameters."""
tool = Tool(
name="test_tool",
description="Test tool",
parameters={
"param1": ToolParameter(
type="string",
description="Parameter 1"
)
},
required=["param1"]
)
assert tool.name == "test_tool"
assert tool.description == "Test tool"
assert len(tool.parameters) == 1
assert "param1" in tool.parameters
assert tool.required == ["param1"]


def test_browse_tool_schema():
"""Test browse tool schema structure."""
assert BROWSE_TOOL.name == "browse_web"
assert "url" in BROWSE_TOOL.parameters
assert BROWSE_TOOL.parameters["url"].type == "string"
assert BROWSE_TOOL.required == ["url"]


def test_read_file_tool_schema():
"""Test read file tool schema structure."""
assert READ_FILE_TOOL.name == "read_file"
assert "path" in READ_FILE_TOOL.parameters
assert READ_FILE_TOOL.parameters["path"].type == "string"
assert READ_FILE_TOOL.required == ["path"]


def test_write_file_tool_schema():
"""Test write file tool schema structure."""
assert WRITE_FILE_TOOL.name == "write_file"
assert "path" in WRITE_FILE_TOOL.parameters
assert "content" in WRITE_FILE_TOOL.parameters
assert WRITE_FILE_TOOL.required == ["path", "content"]


def test_run_code_tool_schema():
"""Test run code tool schema structure."""
assert RUN_CODE_TOOL.name == "run_code"
assert "code" in RUN_CODE_TOOL.parameters
assert "language" in RUN_CODE_TOOL.parameters
assert RUN_CODE_TOOL.parameters["language"].enum == ["python", "javascript", "bash"]
assert RUN_CODE_TOOL.required == ["code", "language"]


def test_claude_client_tool_schemas():
"""Test Claude client tool schema generation."""
client = Claude()

# Verify tool schemas are properly formatted for Claude API
assert len(client.tool_schemas) == len(AVAILABLE_TOOLS)

# Check schema structure for first tool
schema = client.tool_schemas[0]
assert isinstance(schema, dict)
assert "name" in schema
assert "description" in schema
assert "parameters" in schema
assert schema["parameters"]["type"] == "object"
assert "properties" in schema["parameters"]
assert "required" in schema["parameters"]


@pytest.mark.parametrize("model_id,should_have_tools", [
("claude-3-opus-20240229", True),
("claude-3-sonnet-20240229", True),
("claude-2.1", False),
("claude-2.0", False),
])
def test_claude_inference_tool_inclusion(model_id: str, should_have_tools: bool):
"""Test tool inclusion in Claude inference based on model."""
client = Claude()
prompt = "Test prompt"

# Mock the create method to capture kwargs
def mock_create(**kwargs) -> Dict[str, Any]:
class MockResponse:
content = [type("Content", (), {"text": "Mock response"})]

# Verify tools presence based on model
if should_have_tools:
assert "tools" in kwargs
assert isinstance(kwargs["tools"], list)
assert len(kwargs["tools"]) == len(AVAILABLE_TOOLS)
else:
assert "tools" not in kwargs

return MockResponse()

# Replace create method with mock
client.client.messages.create = mock_create

# Run inference
client.inference(model_id=model_id, prompt=prompt)