Skip to content

Commit

Permalink
Add system_prompt to chat endpoint (#1480)
Browse files Browse the repository at this point in the history
  • Loading branch information
TamiTakamiya authored Jan 7, 2025
1 parent 4608df7 commit c2b631a
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 1 deletion.
3 changes: 3 additions & 0 deletions ansible_ai_connect/ai/api/model_pipelines/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ class ChatBotParameters:
provider: str
model_id: str
conversation_id: Optional[str]
system_prompt: str

@classmethod
def init(
Expand All @@ -200,13 +201,15 @@ def init(
provider: Optional[str] = None,
model_id: Optional[str] = None,
conversation_id: Optional[str] = None,
system_prompt: Optional[str] = None,
):
return cls(
request=request,
query=query,
provider=provider,
model_id=model_id,
conversation_id=conversation_id,
system_prompt=system_prompt,
)


Expand Down
5 changes: 5 additions & 0 deletions ansible_ai_connect/ai/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,11 @@ class ChatRequestSerializer(serializers.Serializer):
label="Provider name",
help_text=("A name that identifies a LLM provider."),
)
system_prompt = serializers.CharField(
required=False,
label="System prompt",
help_text=("An optional non-default system prompt to be used on LLM (debug mode only)."),
)


class ReferencedDocumentsSerializer(serializers.Serializer):
Expand Down
3 changes: 3 additions & 0 deletions ansible_ai_connect/ai/api/telemetry/schema1.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ class ChatBotResponseDocsReferences:
@define
class ChatBotBaseEvent:
chat_prompt: str = field(validator=validators.instance_of(str), converter=str, default="")
chat_system_prompt: str = field(
validator=validators.instance_of(str), converter=str, default=""
)
chat_response: str = field(validator=validators.instance_of(str), converter=str, default="")
chat_truncated: bool = field(
validator=validators.instance_of(bool), converter=bool, default=False
Expand Down
47 changes: 46 additions & 1 deletion ansible_ai_connect/ai/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3964,6 +3964,11 @@ class TestChatView(WisdomServiceAPITestCaseBase):
"provider": "non_default_provider",
}

# Chat request payload that carries the optional "system_prompt" field,
# used to exercise the non-default system prompt override path (debug mode).
PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE = {
    "query": "Payload with a system prompt override",
    "system_prompt": "System prompt override",
}

JSON_RESPONSE = {
"response": "AAP 2.5 introduces an updated, unified UI.",
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
Expand Down Expand Up @@ -4031,7 +4036,10 @@ def json(self):
json_response = {
"detail": "Internal server error",
}
elif kwargs["json"]["query"] == TestChatView.PAYLOAD_WITH_MODEL_AND_PROVIDER["query"]:
elif (
kwargs["json"]["query"] == TestChatView.PAYLOAD_WITH_MODEL_AND_PROVIDER["query"]
or kwargs["json"]["query"] == TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE["query"]
):
status_code = 200
json_response["response"] = input
return MockResponse(json_response, status_code)
Expand Down Expand Up @@ -4160,6 +4168,10 @@ def test_chat_with_model_and_provider(self):
self.assertIn('"model": "non_default_model"', r.data["response"])
self.assertIn('"provider": "non_default_provider"', r.data["response"])

def test_chat_with_system_prompt_override(self):
    """A chat request with a system_prompt override succeeds and echoes the query."""
    payload = TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE
    response = self.assert_test(payload)
    # The mocked backend echoes the request back, so the original query
    # must appear somewhere in the returned response text.
    self.assertIn(payload["query"], response.data["response"])

@override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE")
def test_operational_telemetry(self):
self.client.force_authenticate(user=self.user)
Expand Down Expand Up @@ -4268,6 +4280,39 @@ def test_operational_telemetry_anonymizer(self):
"Hello [email protected]",
)

@override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE")
def test_operational_telemetry_with_system_prompt_override(self):
    """Segment telemetry must record chat_system_prompt when an override is sent.

    Runs a chat request through the HTTP chatbot pipeline (mocked) and checks
    that the first emitted Segment event carries the prompt, model, response,
    truncation flag, referenced documents, and the system prompt override.
    """
    self.client.force_authenticate(user=self.user)
    payload = TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE
    # Replace the app's model pipeline with the HTTP chatbot pipeline so the
    # request exercises the same code path production uses.
    pipeline_mock = Mock(return_value=HttpChatBotPipeline(mock_pipeline_config("http")))
    with (
        patch.object(apps.get_app_config("ai"), "get_model_pipeline", pipeline_mock),
        self.assertLogs(logger="root", level="DEBUG") as log,
    ):
        response = self.query_with_no_error(payload)
        self.assertEqual(response.status_code, HTTPStatus.OK)
        events = self.extractSegmentEventsFromLog(log)
        props = events[0]["properties"]
        self.assertEqual(props["chat_prompt"], payload["query"])
        self.assertEqual(props["modelName"], "granite-8b")
        # The mocked backend echoes the query inside its response body.
        self.assertIn(payload["query"], props["chat_response"])
        self.assertEqual(props["chat_truncated"], TestChatView.JSON_RESPONSE["truncated"])
        self.assertEqual(len(props["chat_referenced_documents"]), 0)
        self.assertEqual(props["chat_system_prompt"], payload["system_prompt"])

def test_chat_rate_limit(self):
# Call chat API five times using self.user
for i in range(5):
Expand Down
9 changes: 9 additions & 0 deletions ansible_ai_connect/ai/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,7 @@ def post(self, request) -> Response:

data = {}
req_query = "<undefined>"
req_system_prompt = "<undefined>"
req_model_id = "<undefined>"
req_provider = "<undefined>"
duration = 0
Expand All @@ -1169,6 +1170,11 @@ def post(self, request) -> Response:
raise ChatbotInvalidRequestException()

req_query = request_serializer.validated_data["query"]
req_system_prompt = (
request_serializer.validated_data["system_prompt"]
if "system_prompt" in request_serializer.validated_data
else None
)
req_model_id = (
request_serializer.validated_data["model"]
if "model" in request_serializer.validated_data
Expand All @@ -1195,6 +1201,7 @@ def post(self, request) -> Response:
ChatBotParameters.init(
request=request,
query=req_query,
system_prompt=req_system_prompt,
model_id=req_model_id,
provider=req_provider,
conversation_id=conversation_id,
Expand All @@ -1210,6 +1217,7 @@ def post(self, request) -> Response:

operational_event = ChatBotOperationalEvent(
chat_prompt=req_query,
chat_system_prompt=req_system_prompt,
chat_response=data["response"],
chat_truncated=bool(data["truncated"]),
chat_referenced_documents=[
Expand All @@ -1234,6 +1242,7 @@ def post(self, request) -> Response:
detail = data.get("detail", "")
operational_event = ChatBotOperationalEvent(
chat_prompt=req_query,
chat_system_prompt=req_system_prompt,
provider_id=req_provider,
modelName=req_model_id,
rh_user_org_id=rh_user_org_id,
Expand Down
4 changes: 4 additions & 0 deletions tools/openapi-schema/ansible-ai-connect-service.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,10 @@ components:
type: string
title: Provider name
description: A name that identifies a LLM provider.
system_prompt:
type: string
description: An optional non-default system prompt to be used on LLM (debug
mode only).
required:
- query
ChatResponse:
Expand Down

0 comments on commit c2b631a

Please sign in to comment.