-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add system_prompt to chat endpoint (#1480)
- Loading branch information
1 parent
4608df7
commit c2b631a
Showing
6 changed files
with
70 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3964,6 +3964,11 @@ class TestChatView(WisdomServiceAPITestCaseBase): | |
"provider": "non_default_provider", | ||
} | ||
|
||
PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE = { | ||
"query": "Payload with a system prompt override", | ||
"system_prompt": "System prompt override", | ||
} | ||
|
||
JSON_RESPONSE = { | ||
"response": "AAP 2.5 introduces an updated, unified UI.", | ||
"conversation_id": "123e4567-e89b-12d3-a456-426614174000", | ||
|
@@ -4031,7 +4036,10 @@ def json(self): | |
json_response = { | ||
"detail": "Internal server error", | ||
} | ||
elif kwargs["json"]["query"] == TestChatView.PAYLOAD_WITH_MODEL_AND_PROVIDER["query"]: | ||
elif ( | ||
kwargs["json"]["query"] == TestChatView.PAYLOAD_WITH_MODEL_AND_PROVIDER["query"] | ||
or kwargs["json"]["query"] == TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE["query"] | ||
): | ||
status_code = 200 | ||
json_response["response"] = input | ||
return MockResponse(json_response, status_code) | ||
|
@@ -4160,6 +4168,10 @@ def test_chat_with_model_and_provider(self): | |
self.assertIn('"model": "non_default_model"', r.data["response"]) | ||
self.assertIn('"provider": "non_default_provider"', r.data["response"]) | ||
|
||
def test_chat_with_system_prompt_override(self): | ||
r = self.assert_test(TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE) | ||
self.assertIn(TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE["query"], r.data["response"]) | ||
|
||
@override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE") | ||
def test_operational_telemetry(self): | ||
self.client.force_authenticate(user=self.user) | ||
|
@@ -4268,6 +4280,39 @@ def test_operational_telemetry_anonymizer(self): | |
"Hello [email protected]", | ||
) | ||
|
||
@override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE") | ||
def test_operational_telemetry_with_system_prompt_override(self): | ||
self.client.force_authenticate(user=self.user) | ||
with ( | ||
patch.object( | ||
apps.get_app_config("ai"), | ||
"get_model_pipeline", | ||
Mock(return_value=HttpChatBotPipeline(mock_pipeline_config("http"))), | ||
), | ||
self.assertLogs(logger="root", level="DEBUG") as log, | ||
): | ||
r = self.query_with_no_error(TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE) | ||
self.assertEqual(r.status_code, HTTPStatus.OK) | ||
segment_events = self.extractSegmentEventsFromLog(log) | ||
self.assertEqual( | ||
segment_events[0]["properties"]["chat_prompt"], | ||
TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE["query"], | ||
) | ||
self.assertEqual(segment_events[0]["properties"]["modelName"], "granite-8b") | ||
self.assertIn( | ||
TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE["query"], | ||
segment_events[0]["properties"]["chat_response"], | ||
) | ||
self.assertEqual( | ||
segment_events[0]["properties"]["chat_truncated"], | ||
TestChatView.JSON_RESPONSE["truncated"], | ||
) | ||
self.assertEqual(len(segment_events[0]["properties"]["chat_referenced_documents"]), 0) | ||
self.assertEqual( | ||
segment_events[0]["properties"]["chat_system_prompt"], | ||
TestChatView.PAYLOAD_WITH_SYSTEM_PROMPT_OVERRIDE["system_prompt"], | ||
) | ||
|
||
def test_chat_rate_limit(self): | ||
# Call chat API five times using self.user | ||
for i in range(5): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters