-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge conversation and chat/assistant (#35)
- Loading branch information
Showing
7 changed files
with
252 additions
and
367 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ authors = [ | |
{ name="Emery Berger", email="[email protected]" }, | ||
{ name="Kyla Levin", email="[email protected]" }, | ||
{ name="Nicolas van Kempen", email="[email protected]" }, | ||
{ name="Stephen Freund", email="[email protected]" } | ||
{ name="Stephen Freund", email="[email protected]" }, | ||
] | ||
dependencies = [ | ||
"llm-utils>=0.2.6", | ||
|
@@ -18,7 +18,8 @@ dependencies = [ | |
"ansicolors>=1.1.8", | ||
"traitlets>=5.14.1", | ||
"ipdb>=0.13.13", | ||
"ipython>=8.21.0" | ||
"ipython>=8.21.0", | ||
"litellm>=1.26.6", | ||
] | ||
description = "AI-assisted debugging. Uses AI to answer 'why'." | ||
readme = "README.md" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import json | ||
import sys | ||
import time | ||
|
||
import litellm | ||
import llm_utils | ||
import openai | ||
|
||
|
||
class LiteAssistant:
    """A minimal chat assistant built on litellm's completion API.

    Maintains a registry of Python functions exposed to the model as
    "tools". `run` drives the conversation loop, executing tool calls
    as the model requests them until it produces a final answer.
    """

    def __init__(self, instructions, model="gpt-4", timeout=30, debug=False):
        """
        Args:
            instructions: system prompt sent as the first message.
            model: model identifier forwarded to litellm.
            timeout: per-request timeout in seconds, forwarded to
                litellm.completion.
            debug: when True, mirror every message to "chatdbg.log".
        """
        if debug:
            # Fix: was open(f"chatdbg.log", ...) — f-string with no
            # placeholders.
            # NOTE(review): this handle is never explicitly closed; it is
            # released only at GC / interpreter shutdown — consider a
            # close() or context-manager API.
            self._log = open("chatdbg.log", "w")
        else:
            self._log = None

        self._functions = {}
        self._instructions = instructions
        self._model = model
        # Bug fix: `timeout` was accepted but silently ignored; store it
        # so run() can forward it to litellm.completion.
        self._timeout = timeout

    def add_function(self, function):
        """
        Add a new function to the list of function tools.
        The function should have the necessary json spec as its pydoc string.
        """
        schema = json.loads(function.__doc__)
        assert "name" in schema, "Bad JSON in pydoc for function tool."
        self._functions[schema["name"]] = {
            "function": function,
            "schema": schema,
        }

    def _make_call(self, tool_call) -> str:
        """Execute the registered tool named in `tool_call`, decoding its
        JSON-encoded arguments, and return the tool's string result."""
        name = tool_call.function.name
        args = json.loads(tool_call.function.arguments)
        function = self._functions[name]["function"]
        return function(**args)

    def _print_message(self, message, indent=4, wrap=120) -> None:
        """Pretty-print one conversation message to stdout and, when debug
        logging is enabled, to the log file (unindented).

        `message` may be a plain dict or an object exposing attributes
        (e.g. a litellm/openai message model); both access styles are
        handled for `tool_calls` and `content`.
        NOTE(review): `message["role"]` below assumes subscript access
        always works — confirm for attribute-style message objects.
        """

        def _print_to_file(file, indent):
            # A message carries either tool calls or text content,
            # never both (asserted below).
            tool_calls = None
            if "tool_calls" in message:
                tool_calls = message["tool_calls"]
            elif hasattr(message, "tool_calls"):
                tool_calls = message.tool_calls

            content = None
            if "content" in message:
                content = message["content"]
            elif hasattr(message, "content"):
                content = message.content

            assert bool(tool_calls) != bool(content)

            # The longest role string is 'assistant'.
            max_role_length = 9
            # We add 3 characters for the brackets and space.
            subindent = indent + max_role_length + 3

            role = message["role"].upper()
            role_indent = max_role_length - len(role)

            if tool_calls:
                print(
                    f"{' ' * indent}[{role}]{' ' * role_indent} Function calls:",
                    file=file,
                )
                for tool_call in tool_calls:
                    arguments = json.loads(tool_call.function.arguments)
                    print(
                        f"{' ' * (subindent + 4)}{tool_call.function.name}({', '.join([f'{k}={v}' for k, v in arguments.items()])})",
                        file=file,
                    )
            else:
                # Wrap prose to the remaining width after the role prefix.
                content = llm_utils.word_wrap_except_code_blocks(
                    content, wrap - len(role) - indent - 3
                )
                first, *rest = content.split("\n")
                print(f"{' ' * indent}[{role}]{' ' * role_indent} {first}", file=file)
                for line in rest:
                    print(f"{' ' * subindent}{line}", file=file)
            print("\n\n", file=file)

        # None is the default file value for print().
        _print_to_file(None, indent)
        if self._log:
            _print_to_file(self._log, 0)

    def run(self, prompt: str) -> None:
        """Run one user prompt to completion, executing tool calls as the
        model requests them, then print elapsed time and total cost.

        Exits the process on an unexpected finish reason; prints (and
        swallows) any openai.OpenAIError.
        """
        start = time.time()
        cost = 0

        try:
            conversation = [
                {"role": "system", "content": self._instructions},
                {"role": "user", "content": prompt},
            ]

            for message in conversation:
                self._print_message(message)
            while True:
                completion = litellm.completion(
                    model=self._model,
                    messages=conversation,
                    tools=[
                        {"type": "function", "function": f["schema"]}
                        for f in self._functions.values()
                    ],
                    # Bug fix: forward the constructor's timeout
                    # (previously accepted but never used).
                    timeout=self._timeout,
                )

                cost += litellm.completion_cost(completion)

                choice = completion.choices[0]
                self._print_message(choice.message)

                if choice.finish_reason == "tool_calls":
                    # Execute every requested tool, echoing each result,
                    # then feed the assistant message + tool outputs back
                    # into the conversation for the next round.
                    responses = []
                    for tool_call in choice.message.tool_calls:
                        function_response = self._make_call(tool_call)
                        response = {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "name": tool_call.function.name,
                            "content": function_response,
                        }
                        responses.append(response)
                        self._print_message(response)
                    conversation.append(choice.message)
                    conversation.extend(responses)
                elif choice.finish_reason == "stop":
                    break
                else:
                    # Fix: message previously read "Not found: ...",
                    # which did not describe the actual condition.
                    print(f"Unexpected finish reason: {choice.finish_reason}.")
                    sys.exit(1)

            elapsed = time.time() - start
            print(f"Elapsed time: {elapsed:.2f} seconds")
            print(f"Total cost: {cost:.2f}$")
        except openai.OpenAIError as e:
            print(f"*** OpenAI Error: {e}")
Oops, something went wrong.