Skip to content

Commit

Permalink
Chat rate limiting (#1425)
Browse files Browse the repository at this point in the history
* Chat rate limiting

* Insert 3s delay before showing 429 msg

* Test fix
  • Loading branch information
TamiTakamiya authored Nov 25, 2024
1 parent 55a905a commit 104c09b
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 8 deletions.
39 changes: 36 additions & 3 deletions ansible_ai_connect/ai/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import requests
from django.apps import apps
from django.conf import settings
from django.contrib.auth import get_user_model
from django.test import modify_settings, override_settings
from django.urls import reverse
from langchain_core.runnables import Runnable, RunnableConfig
Expand Down Expand Up @@ -4046,14 +4047,22 @@ def query_with_no_error(self, payload, mock_post):
"requests.post",
side_effect=mocked_requests_post,
)
@override_settings(CHATBOT_URL="")
def query_without_chat_config(self, payload, mock_post):
return self.client.post(reverse("chat"), payload, format="json")

def assert_test(
self, payload, expected_status_code=200, expected_exception=None, expected_log_message=None
self,
payload,
expected_status_code=200,
expected_exception=None,
expected_log_message=None,
user=None,
):
mocked_client = Mock()
self.client.force_authenticate(user=self.user)
if user is None:
user = self.user
self.client.force_authenticate(user=user)
with (
patch.object(
apps.get_app_config("ai"),
Expand All @@ -4062,7 +4071,7 @@ def assert_test(
),
self.assertLogs(logger="root", level="DEBUG") as log,
):
self.client.force_authenticate(user=self.user)
self.client.force_authenticate(user=user)

if expected_exception == ChatbotNotEnabledException:
r = self.query_without_chat_config(payload)
Expand Down Expand Up @@ -4260,3 +4269,27 @@ def test_operational_telemetry_anonymizer(self):
segment_events[0]["properties"]["chat_prompt"],
"Hello [email protected]",
)

def test_chat_rate_limit(self):
    """Check that the chat rate limit is counted across users, not per user.

    Five requests are sent as ``self.user`` and five more as a freshly
    created second user; all ten are expected to succeed. The eleventh
    request is expected to be rejected with HTTP 429.
    """
    # Bind the attribute before entering the try block so that the cleanup
    # in ``finally`` cannot raise AttributeError (and hide the real error)
    # when creating the second user fails.
    self.user2 = None

    # Call chat API five times using self.user
    for _ in range(5):
        self.assert_test(TestChatView.VALID_PAYLOAD)
    try:
        username = "u" + "".join(random.choices(string.digits, k=5))
        password = "secret"
        email = "[email protected]"
        self.user2 = get_user_model().objects.create_user(
            username=username,
            email=email,
            password=password,
        )
        (org, _) = Organization.objects.get_or_create(id=123, telemetry_opt_out=False)
        self.user2.organization = org
        # Call chat API five times using self.user2
        for _ in range(5):
            self.assert_test(TestChatView.VALID_PAYLOAD, user=self.user2)
        # The next chat API call should be the 11th from two users and should receive a 429.
        self.assert_test(TestChatView.VALID_PAYLOAD, expected_status_code=429, user=self.user2)
    finally:
        # Remove the temporary user only if it was actually created.
        if self.user2 is not None:
            self.user2.delete()
5 changes: 5 additions & 0 deletions ansible_ai_connect/ai/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
)
from ansible_ai_connect.users.models import User

from ...users.throttling import EndpointRateThrottle
from ..feature_flags import FeatureFlags
from .data.data_model import ContentMatchPayloadData, ContentMatchResponseDto
from .model_pipelines.exceptions import ModelTimeoutError
Expand Down Expand Up @@ -1117,11 +1118,15 @@ class Chat(APIView):
Send a message to the backend chatbot service and get a reply.
"""

# Throttle used by the chat endpoint. The "chat" scope name matches the "chat"
# rate entry (CHAT_RATE_THROTTLE) registered in the settings.
class ChatEndpointThrottle(EndpointRateThrottle):
scope = "chat"

permission_classes = [
permissions.IsAuthenticated,
IsAuthenticatedOrTokenHasScope,
]
required_scopes = ["read", "write"]
throttle_classes = [ChatEndpointThrottle]

def __init__(self):
self.chatbot_enabled = (
Expand Down
2 changes: 2 additions & 0 deletions ansible_ai_connect/main/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def is_ssl_enabled(value: str) -> bool:
# Integer read from the ME_USER_CACHE_TIMEOUT_SEC environment variable (default 30).
ME_USER_CACHE_TIMEOUT_SEC = int(os.environ.get("ME_USER_CACHE_TIMEOUT_SEC", 30))
# Rate string used for the "me" throttle scope; an unset or empty environment
# variable falls back to "50/minute".
ME_USER_RATE_THROTTLE = os.environ.get("ME_USER_RATE_THROTTLE") or "50/minute"
SPECIAL_THROTTLING_GROUPS = ["test"]
# Rate string used for the "chat" throttle scope; an unset or empty environment
# variable falls back to "10/minute".
CHAT_RATE_THROTTLE = os.environ.get("CHAT_RATE_THROTTLE") or "10/minute"

AMS_ORG_CACHE_TIMEOUT_SEC = int(os.environ.get("AMS_ORG_CACHE_TIMEOUT_SEC", 60 * 60 * 24))
AMS_SUBSCRIPTION_CACHE_TIMEOUT_SEC = int(
Expand All @@ -304,6 +305,7 @@ def is_ssl_enabled(value: str) -> bool:
"user": COMPLETION_USER_RATE_THROTTLE,
"test": "100000/minute",
"me": ME_USER_RATE_THROTTLE,
"chat": CHAT_RATE_THROTTLE,
},
"PAGE_SIZE": 10,
"DEFAULT_AUTHENTICATION_CLASSES": [
Expand Down
23 changes: 23 additions & 0 deletions ansible_ai_connect/users/throttling.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,26 @@ def format_rate(num_requests, duration):
86400: "day",
}[duration]
return f"{num_requests}/{duration_unit}"


class EndpointRateThrottle(GroupSpecificThrottle):
    """
    Throttle that counts calls from all authenticated users against a single
    shared limit for the endpoint. Whenever the base class resolves a scope
    other than "user", or the caller is not authenticated, the behaviour of
    GroupSpecificThrottle is used unchanged.
    """

    def get_scope(self, request, view):
        """Return this throttle's own scope in place of the generic "user" scope."""
        base_scope = super().get_scope(request, view)
        if base_scope == "user":
            return self.scope
        return base_scope

    def get_cache_key(self, request, view):
        """Return one shared cache key for every authenticated "user"-scope caller."""
        base_scope = super().get_scope(request, view)
        shares_endpoint_bucket = base_scope == "user" and request.user.is_authenticated
        if not shares_endpoint_bucket:
            # Other scopes and unauthenticated callers keep the key that
            # GroupSpecificThrottle builds for them.
            return super().get_cache_key(request, view)

        # A constant ident makes all authenticated users share one counter.
        return self.cache_format % {"scope": self.scope, "ident": "user"}
51 changes: 50 additions & 1 deletion ansible_ai_connect_chatbot/src/App.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { MemoryRouter } from "react-router-dom";
import { App } from "./App";
import { ColorThemeSwitch } from "./ColorThemeSwitch/ColorThemeSwitch";
import userEvent from "@testing-library/user-event";
import axios, { AxiosError } from "axios";
import axios, { AxiosError, AxiosHeaders } from "axios";

describe("App tests", () => {
const renderApp = (debug = false) => {
Expand All @@ -27,13 +27,47 @@ describe("App tests", () => {
},
);
};

// Returns a promise that resolves after `ms` milliseconds (via setTimeout).
const delay = (ms: number) => new Promise((res) => setTimeout(res, ms));

/**
 * Build an AxiosError for use in tests.
 *
 * The error always carries the given message, the code "SOME_ERR", a minimal
 * request config and a stub request object. A response object with the given
 * HTTP status is attached only when `status` is greater than zero.
 */
const createError = (message: string, status: number): AxiosError => {
  const headers = new AxiosHeaders({ "Content-Type": "application/json" });
  const config = { url: "http://localhost:8000", headers };
  const error = new AxiosError(message, "SOME_ERR", config, { path: "/chat" });

  if (status > 0) {
    // Attach a response so callers can inspect `error.response.status`.
    error.response = { data: {}, status, statusText: "", config, headers };
  }
  return error;
};

const mockAxios = (status: number, reject = false, timeout = false) => {
const spy = vi.spyOn(axios, "post");
if (reject) {
if (timeout) {
spy.mockImplementationOnce(() =>
Promise.reject(new AxiosError("timeout of 28000ms exceeded")),
);
} else if (status === 429) {
spy.mockImplementationOnce(() =>
Promise.reject(
createError("Request failed with status code 429", 429),
),
);
} else {
spy.mockImplementationOnce(() =>
Promise.reject(new Error("mocked error")),
Expand Down Expand Up @@ -170,6 +204,21 @@ describe("App tests", () => {
).toBeInTheDocument();
});

// When the chat API rejects with HTTP 429, the "too many requests" bot
// message should eventually be rendered.
it("Chat service returns 429 Too Many Requests error", async () => {
mockAxios(429, true);
renderApp();
const textArea = screen.getByLabelText("Send a message...");
await act(async () => userEvent.type(textArea, "Hello"));
const sendButton = screen.getByLabelText("Send button");
await act(async () => fireEvent.click(sendButton));
// Wait slightly longer than the 3000 ms delay that useChatbot inserts before
// it adds the 429 message.
await delay(3100);
expect(
screen.getByText("Chatbot service is busy with too many requests. ", {
exact: false,
}),
).toBeInTheDocument();
});

it("Chat service returns an unexpected error", async () => {
mockAxios(-1, true);
const view = renderApp();
Expand Down
4 changes: 4 additions & 0 deletions ansible_ai_connect_chatbot/src/Constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ export const TIMEOUT_MSG =
"_Chatbot service is taking too long to respond to your query. " +
"Try to submit a different query or try again later._";

/* Too many requests message */
export const TOO_MANY_REQUESTS_MSG =
"_Chatbot service is busy with too many requests. Please try again later._";

/* Footnote label */
export const FOOTNOTE_LABEL = "Lightspeed uses AI. Check for mistakes.";

Expand Down
26 changes: 22 additions & 4 deletions ansible_ai_connect_chatbot/src/useChatbot/useChatbot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
GITHUB_NEW_ISSUE_URL,
Sentiment,
TIMEOUT_MSG,
TOO_MANY_REQUESTS_MSG,
} from "../Constants";

const userName = document.getElementById("user_name")?.innerText ?? "User";
Expand Down Expand Up @@ -52,12 +53,14 @@ export const inDebugMode = () => {
};

const isTimeoutError = (e: any) =>
e?.name === "AxiosError" &&
e?.message === `timeout of ${API_TIMEOUT}ms exceeded`;
axios.isAxiosError(e) && e.message === `timeout of ${API_TIMEOUT}ms exceeded`;

export const timeoutMessage = (): MessageProps => ({
// True when `e` is an Axios error whose response has HTTP status 429.
const isTooManyRequestsError = (e: any) =>
axios.isAxiosError(e) && e.response?.status === 429;

export const fixedMessage = (content: string): MessageProps => ({
role: "bot",
content: TIMEOUT_MSG,
content,
name: botName,
avatar: logo,
timestamp: getTimestamp(),
Expand Down Expand Up @@ -90,6 +93,12 @@ export const feedbackMessage = (f: ChatFeedback): MessageProps => ({
],
});

// Bot messages for the timeout and too-many-requests cases; each passes its
// message constant to fixedMessage.
export const timeoutMessage = (): MessageProps => fixedMessage(TIMEOUT_MSG);
export const tooManyRequestsMessage = (): MessageProps =>
fixedMessage(TOO_MANY_REQUESTS_MSG);

// Returns a promise that resolves after `ms` milliseconds (via setTimeout).
const delay = (ms: number) => new Promise((res) => setTimeout(res, ms));

type AlertMessage = {
title: string;
message: string;
Expand Down Expand Up @@ -263,6 +272,15 @@ export const useChatbot = () => {
...timeoutMessage(),
};
addMessage(newBotMessage);
} else if (isTooManyRequestsError(e)) {
// Insert a 3-second delay before showing the "Too Many Requests" message
// to reduce the number of chat requests sent while the server is busy.
await delay(3000);
const newBotMessage = {
referenced_documents: [],
...tooManyRequestsMessage(),
};
addMessage(newBotMessage);
} else {
setAlertMessage({
title: "Error",
Expand Down

0 comments on commit 104c09b

Please sign in to comment.