Skip to content

Commit

Permalink
feat: add output schema support to model initialization and improve message handling
Browse files Browse the repository at this point in the history
  • Loading branch information
chakravarthik27 committed Jan 21, 2025
1 parent 7bd46d4 commit a011ba0
Showing 1 changed file with 8 additions and 19 deletions.
27 changes: 8 additions & 19 deletions langtest/modelhandler/llm_modelhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,15 @@ def load_model(cls, hub: str, path: str, *args, **kwargs) -> "PretrainedModelForQA":
if hub == "openai":
from langchain_openai.chat_models import ChatOpenAI

model = ChatOpenAI(
model=path, *args, **filtered_kwargs
).with_structured_output(output_schema)
model = ChatOpenAI(model=path, *args, **filtered_kwargs)
elif hub == "azure-openai":
from langchain_openai.chat_models import AzureChatOpenAI

model = AzureChatOpenAI(
model=path, *args, **filtered_kwargs
).with_structured_output(output_schema)
model = AzureChatOpenAI(model=path, *args, **filtered_kwargs)

# adding output schema to the model if provided
if output_schema:
model = model.with_structured_output(output_schema)

return cls(hub, model, *args, **filtered_kwargs)

Expand Down Expand Up @@ -196,18 +196,9 @@ def predict(self, text: Union[str, dict], prompt: dict, *args, **kwargs):
try:
# loading a prompt manager
from langtest.prompts import PromptManager
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from langchain_core.language_models.llms import BaseLLM
from langchain_core.language_models.chat_models import BaseChatModel
from pydantic import BaseModel

# output parsing
output_parser = self.kwargs.get("output_schema", None)
if output_parser and issubclass(output_parser, BaseModel):
output_parser = output_parser
# else:
# from langchain.output_parsers import PydanticOutputParser
# output_parser = PydanticOutputParser(pydantic_object=output_parser)

# prompt configuration
prompt_manager = PromptManager()
Expand All @@ -224,9 +215,7 @@ def predict(self, text: Union[str, dict], prompt: dict, *args, **kwargs):

output = llmchain.invoke(text)

if isinstance(output, dict):
return output.get(llmchain.output_key, "")
elif isinstance(output, AIMessage):
if isinstance(output, BaseMessage):
return output.content

return output
Expand Down

0 comments on commit a011ba0

Please sign in to comment.