Skip to content

Commit

Permalink
Merge conversation and chat/assistant (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicovank authored Feb 21, 2024
1 parent 12ec783 commit c7ea352
Show file tree
Hide file tree
Showing 7 changed files with 252 additions and 367 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ authors = [
{ name="Emery Berger", email="[email protected]" },
{ name="Kyla Levin", email="[email protected]" },
{ name="Nicolas van Kempen", email="[email protected]" },
{ name="Stephen Freund", email="[email protected]" }
{ name="Stephen Freund", email="[email protected]" },
]
dependencies = [
"llm-utils>=0.2.6",
Expand All @@ -18,7 +18,8 @@ dependencies = [
"ansicolors>=1.1.8",
"traitlets>=5.14.1",
"ipdb>=0.13.13",
"ipython>=8.21.0"
"ipython>=8.21.0",
"litellm>=1.26.6",
]
description = "AI-assisted debugging. Uses AI to answer 'why'."
readme = "README.md"
Expand Down
141 changes: 141 additions & 0 deletions src/chatdbg/assistant/lite_assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import json
import sys
import time

import litellm
import llm_utils
import openai


class LiteAssistant:
def __init__(self, instructions, model="gpt-4", timeout=30, debug=False):
if debug:
self._log = open(f"chatdbg.log", "w")
else:
self._log = None

self._functions = {}
self._instructions = instructions
self._model = model

def add_function(self, function):
"""
Add a new function to the list of function tools.
The function should have the necessary json spec as its pydoc string.
"""
schema = json.loads(function.__doc__)
assert "name" in schema, "Bad JSON in pydoc for function tool."
self._functions[schema["name"]] = {
"function": function,
"schema": schema,
}

def _make_call(self, tool_call) -> str:
name = tool_call.function.name
args = json.loads(tool_call.function.arguments)
function = self._functions[name]["function"]
return function(**args)

def _print_message(self, message, indent=4, wrap=120) -> None:
def _print_to_file(file, indent):

tool_calls = None
if "tool_calls" in message:
tool_calls = message["tool_calls"]
elif hasattr(message, "tool_calls"):
tool_calls = message.tool_calls

content = None
if "content" in message:
content = message["content"]
elif hasattr(message, "content"):
content = message.content

assert bool(tool_calls) != bool(content)

# The longest role string is 'assistant'.
max_role_length = 9
# We add 3 characters for the brackets and space.
subindent = indent + max_role_length + 3

role = message["role"].upper()
role_indent = max_role_length - len(role)

if tool_calls:
print(
f"{' ' * indent}[{role}]{' ' * role_indent} Function calls:",
file=file,
)
for tool_call in tool_calls:
arguments = json.loads(tool_call.function.arguments)
print(
f"{' ' * (subindent + 4)}{tool_call.function.name}({', '.join([f'{k}={v}' for k, v in arguments.items()])})",
file=file,
)
else:
content = llm_utils.word_wrap_except_code_blocks(
content, wrap - len(role) - indent - 3
)
first, *rest = content.split("\n")
print(f"{' ' * indent}[{role}]{' ' * role_indent} {first}", file=file)
for line in rest:
print(f"{' ' * subindent}{line}", file=file)
print("\n\n", file=file)

# None is the default file value for print().
_print_to_file(None, indent)
if self._log:
_print_to_file(self._log, 0)

def run(self, prompt: str) -> None:
start = time.time()
cost = 0

try:
conversation = [
{"role": "system", "content": self._instructions},
{"role": "user", "content": prompt},
]

for message in conversation:
self._print_message(message)
while True:
completion = litellm.completion(
model=self._model,
messages=conversation,
tools=[
{"type": "function", "function": f["schema"]}
for f in self._functions.values()
],
)

cost += litellm.completion_cost(completion)

choice = completion.choices[0]
self._print_message(choice.message)

if choice.finish_reason == "tool_calls":
responses = []
for tool_call in choice.message.tool_calls:
function_response = self._make_call(tool_call)
response = {
"tool_call_id": tool_call.id,
"role": "tool",
"name": tool_call.function.name,
"content": function_response,
}
responses.append(response)
self._print_message(response)
conversation.append(choice.message)
conversation.extend(responses)
elif choice.finish_reason == "stop":
break
else:
print(f"Not found: {choice.finish_reason}.")
sys.exit(1)

elapsed = time.time() - start
print(f"Elapsed time: {elapsed:.2f} seconds")
print(f"Total cost: {cost:.2f}$")
except openai.OpenAIError as e:
print(f"*** OpenAI Error: {e}")
Loading

0 comments on commit c7ea352

Please sign in to comment.