Skip to content

Commit

Permalink
[CLEANUP]
Browse files Browse the repository at this point in the history
  • Loading branch information
kyegomez committed Dec 29, 2024
1 parent 26667f7 commit f783f89
Show file tree
Hide file tree
Showing 11 changed files with 404 additions and 252 deletions.
95 changes: 59 additions & 36 deletions api/api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import json
import os
import secrets
import sqlite3
from datetime import datetime
from typing import List, Optional

from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Header, HTTPException, Request
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from loguru import logger
Expand Down Expand Up @@ -66,47 +65,57 @@


# Add this after the patients table creation
cursor.execute("""
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS rate_limits (
ip_address TEXT PRIMARY KEY,
last_daily_reset TEXT,
last_hourly_reset TEXT,
daily_requests_remaining INTEGER DEFAULT 1000,
hourly_requests_remaining INTEGER DEFAULT 100
)
""")
"""
)


async def check_rate_limit(request: Request):
"""Rate limiting middleware based on IP address."""
client_ip = request.client.host
now = datetime.utcnow()

try:
connection = sqlite3.connect(db_path)
cursor = connection.cursor()

# Get or create rate limit record for IP
cursor.execute(
"""INSERT OR IGNORE INTO rate_limits
(ip_address, last_daily_reset, last_hourly_reset,
daily_requests_remaining, hourly_requests_remaining)
VALUES (?, ?, ?, ?, ?)""",
(client_ip, now.isoformat(), now.isoformat(), 1000, 100)
(client_ip, now.isoformat(), now.isoformat(), 1000, 100),
)

# Get current limits
cursor.execute(
"""SELECT daily_requests_remaining, hourly_requests_remaining,
last_daily_reset, last_hourly_reset
FROM rate_limits WHERE ip_address = ?""",
(client_ip,)
(client_ip,),
)
row = cursor.fetchone()

daily_remaining, hourly_remaining, last_daily_reset, last_hourly_reset = row


(
daily_remaining,
hourly_remaining,
last_daily_reset,
last_hourly_reset,
) = row

# Check and reset daily quota
last_daily_reset_time = datetime.fromisoformat(last_daily_reset)
last_daily_reset_time = datetime.fromisoformat(
last_daily_reset
)
if now.date() > last_daily_reset_time.date():
daily_remaining = 1000
last_daily_reset = now.isoformat()
Expand All @@ -115,11 +124,13 @@ async def check_rate_limit(request: Request):
SET daily_requests_remaining = ?,
last_daily_reset = ?
WHERE ip_address = ?""",
(daily_remaining, last_daily_reset, client_ip)
(daily_remaining, last_daily_reset, client_ip),
)

# Check and reset hourly quota
last_hourly_reset_time = datetime.fromisoformat(last_hourly_reset)
last_hourly_reset_time = datetime.fromisoformat(
last_hourly_reset
)
if (now - last_hourly_reset_time).total_seconds() >= 3600:
hourly_remaining = 100
last_hourly_reset = now.isoformat()
Expand All @@ -128,19 +139,19 @@ async def check_rate_limit(request: Request):
SET hourly_requests_remaining = ?,
last_hourly_reset = ?
WHERE ip_address = ?""",
(hourly_remaining, last_hourly_reset, client_ip)
(hourly_remaining, last_hourly_reset, client_ip),
)

# Check remaining quotas
if daily_remaining <= 0:
raise HTTPException(
status_code=429,
detail="Daily rate limit exceeded. Reset occurs at midnight UTC."
status_code=429,
detail="Daily rate limit exceeded. Reset occurs at midnight UTC.",
)
if hourly_remaining <= 0:
raise HTTPException(
status_code=429,
detail="Hourly rate limit exceeded. Please try again next hour."
status_code=429,
detail="Hourly rate limit exceeded. Please try again next hour.",
)

# Deduct from both quotas
Expand All @@ -149,13 +160,15 @@ async def check_rate_limit(request: Request):
SET daily_requests_remaining = daily_requests_remaining - 1,
hourly_requests_remaining = hourly_requests_remaining - 1
WHERE ip_address = ?""",
(client_ip,)
(client_ip,),
)
connection.commit()

except sqlite3.Error as e:
logger.error(f"Error checking rate limit: {e}")
raise HTTPException(status_code=500, detail="Internal Server Error")
raise HTTPException(
status_code=500, detail="Internal Server Error"
)
finally:
if connection:
connection.close()
Expand Down Expand Up @@ -304,10 +317,11 @@ async def general_exception_handler(request, exc):
logger.error(f"Unexpected error: {exc}")
return JSONResponse(
status_code=500,
content={"detail": "An unexpected error occurred. Please try again later."},
content={
"detail": "An unexpected error occurred. Please try again later."
},
)




@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
Expand All @@ -316,6 +330,7 @@ async def rate_limit_middleware(request: Request, call_next):
await check_rate_limit(request)
return await call_next(request)


@app.get("/rate-limits")
async def get_rate_limits(request: Request):
"""Get current rate limit status for an IP address."""
Expand All @@ -327,33 +342,41 @@ async def get_rate_limits(request: Request):
"""SELECT daily_requests_remaining, hourly_requests_remaining,
last_daily_reset, last_hourly_reset
FROM rate_limits WHERE ip_address = ?""",
(client_ip,)
(client_ip,),
)
row = cursor.fetchone()

if not row:
return {
"daily_requests_remaining": 1000,
"hourly_requests_remaining": 100,
"last_daily_reset": None,
"last_hourly_reset": None
"last_hourly_reset": None,
}

daily_remaining, hourly_remaining, last_daily_reset, last_hourly_reset = row


(
daily_remaining,
hourly_remaining,
last_daily_reset,
last_hourly_reset,
) = row

return {
"daily_requests_remaining": daily_remaining,
"hourly_requests_remaining": hourly_remaining,
"last_daily_reset": last_daily_reset,
"last_hourly_reset": last_hourly_reset
"last_hourly_reset": last_hourly_reset,
}
except sqlite3.Error as e:
logger.error(f"Error fetching rate limits: {e}")
raise HTTPException(status_code=500, detail="Internal Server Error")
raise HTTPException(
status_code=500, detail="Internal Server Error"
)
finally:
if connection:
connection.close()


@app.post("/v1/medical-coder/run", response_model=QueryResponse)
def run_medical_coder(
patient_case: PatientCase,
Expand Down Expand Up @@ -401,7 +424,7 @@ def run_medical_coder(
response_model=QueryResponse,
)
def get_patient_data(
patient_id: str,
patient_id: str,
):
"""
Retrieve patient data by patient ID.
Expand Down Expand Up @@ -509,7 +532,7 @@ def health_check():

@app.delete("/v1/medical-coder/patient/{patient_id}")
def delete_patient_data(
patient_id: str,
patient_id: str,
):
"""
Delete a patient's data by patient ID.
Expand Down Expand Up @@ -539,7 +562,7 @@ def delete_patient_data(

# @app.delete("/v1/medical-coder/patients")
# def delete_all_patients(

# ):
# """
# Delete all patient data.
Expand Down
Loading

0 comments on commit f783f89

Please sign in to comment.