diff --git a/api/discord_api.py b/api/discord_api.py deleted file mode 100644 index b1d1c90..0000000 --- a/api/discord_api.py +++ /dev/null @@ -1,525 +0,0 @@ -import json -from datetime import datetime -from typing import Dict, List - -import discord -from discord.ext import commands -from fastapi import ( - BackgroundTasks, - Depends, - FastAPI, - HTTPException, - Response, - status, -) -from fastapi.middleware.cors import CORSMiddleware -from loguru import logger -from sqlalchemy import ( - Column, - DateTime, - ForeignKey, - Integer, - String, - Text, -) -from sqlalchemy.ext.asyncio import ( - AsyncSession, - async_sessionmaker, - create_async_engine, -) -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.future import select -from sqlalchemy.orm import relationship - -from mcs.main import MedicalCoderSwarm - -# Configure logging -logger.add( - "discord_bot.log", - rotation="500 MB", - retention="10 days", - level="INFO", - backtrace=True, - diagnose=True, -) - -# FastAPI instance with CORS -app = FastAPI(title="Discord Medical Bot API", version="1.0.0") -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -# Configuration -class BotConfig: - def __init__(self): - self.DISCORD_TOKEN = "YOUR_DISCORD_BOT_TOKEN" - self.DATABASE_URL = "sqlite+aiosqlite:///./discord_bot.db" - self.MAX_HISTORY_LENGTH = 100 - self.MAX_RETRIES = 3 - self.RATE_LIMIT_WINDOW = 60 - self.MAX_REQUESTS_PER_MINUTE = 30 - self.COMMAND_PREFIX = "!" - - -config = BotConfig() - -# Database setup -Base = declarative_base() - - -class Conversation(Base): - __tablename__ = "conversations" - - id = Column(Integer, primary_key=True) - user_id = Column(String(100), unique=True, index=True) - messages = relationship( - "Message", - back_populates="conversation", - cascade="all, delete-orphan", - ) - - -class Message(Base): - __tablename__ = "messages" - - id = Column(Integer, primary_key=True) - conversation_id = Column(Integer, ForeignKey("conversations.id")) - content = Column(Text) - timestamp = Column(DateTime, default=datetime.utcnow) - role = Column(String(50)) # 'user' or 'assistant' - - conversation = relationship( - "Conversation", back_populates="messages" - ) - - -# Create async engine -engine = create_async_engine(config.DATABASE_URL, echo=True) -AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False) - - -async def init_db(): - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - - -# Dependency for database sessions -async def get_db(): - async with AsyncSessionLocal() as session: - try: - yield session - finally: - await session.close() - - -# Database operations -class DatabaseOps: - @staticmethod - async def store_message( - db: AsyncSession, user_id: str, content: str, role: str - ): - """Store a message in conversation history.""" - stmt = select(Conversation).where( - Conversation.user_id == str(user_id) - ) - result = await db.execute(stmt) - conversation = result.scalar_one_or_none() - - if not conversation: - conversation = Conversation(user_id=str(user_id)) - db.add(conversation) - await db.flush() - - message = Message( - conversation_id=conversation.id, - content=content, - role=role, - ) - db.add(message) - - # Maintain message limit - stmt = ( - select(Message) - .where(Message.conversation_id == conversation.id) - .order_by(Message.timestamp) - ) - result = await db.execute(stmt) - messages = result.scalars().all() - - if len(messages) > config.MAX_HISTORY_LENGTH: - for msg in messages[: -config.MAX_HISTORY_LENGTH]: - await db.delete(msg) - - await db.commit() - - @staticmethod - async def get_conversation_history( - db: AsyncSession, user_id: str - ) -> List[Dict]: - """Get conversation history for a user.""" - stmt = select(Conversation).where( - Conversation.user_id == str(user_id) - ) - result = await db.execute(stmt) - conversation = result.scalar_one_or_none() - - if not conversation: - return [] - - stmt = ( - select(Message) - .where(Message.conversation_id == conversation.id) - .order_by(Message.timestamp) - ) - result = await db.execute(stmt) - messages = result.scalars().all() - - return [ - { - "content": msg.content, - "timestamp": msg.timestamp, - "role": msg.role, - } - for msg in messages - ] - - @staticmethod - async def clear_history(db: AsyncSession, user_id: str): - """Clear conversation history for a user.""" - stmt = select(Conversation).where( - Conversation.user_id == str(user_id) - ) - result = await db.execute(stmt) - conversation = result.scalar_one_or_none() - - if conversation: - await db.delete(conversation) - await db.commit() - - -# Medical Coder Swarm with context -class ContextAwareMedicalSwarm: - def __init__(self, user_id: str): - self.user_id = user_id - self.swarm = MedicalCoderSwarm( - patient_id=user_id, max_loops=1, patient_documentation="" - ) - - async def process_with_context( - self, current_message: str, db: AsyncSession - ) -> str: - """Process message with conversation context.""" - try: - history = await DatabaseOps.get_conversation_history( - db, self.user_id - ) - - context = "\n".join( - [ - f"{msg['role']}: {msg['content']}" - for msg in history - ] - ) - - full_context = f"{context}\nUser: {current_message}" - - response = self.swarm.run( - task=current_message, context=full_context - ) - - return response - - except Exception as e: - logger.error( - f"Swarm processing error for user {self.user_id}: {str(e)}" - ) - return "I apologize, but I couldn't process your request at this time." - - -# Discord bot class -class MedicalBot(commands.Bot): - def __init__(self): - intents = discord.Intents.default() - intents.message_content = True - intents.dm_messages = True - - super().__init__( - command_prefix=commands.when_mentioned_or( - config.COMMAND_PREFIX - ), - intents=intents, - ) - - self.rate_limits: Dict[str, List[float]] = {} - self.db_session = AsyncSessionLocal - - async def setup_hook(self): - await self.tree.sync() - - def check_rate_limit(self, user_id: str) -> bool: - """Check if user has exceeded rate limit.""" - now = datetime.now().timestamp() - if user_id not in self.rate_limits: - self.rate_limits[user_id] = [] - - self.rate_limits[user_id] = [ - ts - for ts in self.rate_limits[user_id] - if now - ts < config.RATE_LIMIT_WINDOW - ] - - return ( - len(self.rate_limits[user_id]) - < config.MAX_REQUESTS_PER_MINUTE - ) - - def add_rate_limit_timestamp(self, user_id: str): - """Add timestamp for rate limiting.""" - if user_id not in self.rate_limits: - self.rate_limits[user_id] = [] - self.rate_limits[user_id].append(datetime.now().timestamp()) - - -# Create bot instance -bot = MedicalBot() - - -# Command handlers -@bot.tree.command( - name="help", description="Show available commands and usage" -) -async def help_command(interaction: discord.Interaction): - """Handle help command.""" - help_embed = discord.Embed( - title="Medical Coding Assistant Help", - description="Here are the available commands:", - color=discord.Color.blue(), - ) - - help_embed.add_field( - name="/help", value="Show this help message", inline=False - ) - - help_embed.add_field( - name="/analyze ", - value="Analyze medical text for coding", - inline=False, - ) - - help_embed.add_field( - name="/clear", - value="Clear your conversation history", - inline=False, - ) - - help_embed.add_field( - name="DM Functionality", - value="You can also DM me directly for a natural conversation with memory!", - inline=False, - ) - - await interaction.response.send_message(embed=help_embed) - - -@bot.tree.command( - name="analyze", description="Analyze medical text for coding" -) -async def analyze_command( - interaction: discord.Interaction, text: str -): - """Handle analyze command.""" - user_id = str(interaction.user.id) - - if not bot.check_rate_limit(user_id): - await interaction.response.send_message( - "You're sending requests too quickly. Please wait a moment.", - ephemeral=True, - ) - return - - async with bot.db_session() as db: - try: - # Store user message - await DatabaseOps.store_message(db, user_id, text, "user") - - # Process with swarm - swarm = ContextAwareMedicalSwarm(user_id) - response = await swarm.process_with_context(text, db) - - # Store bot response - await DatabaseOps.store_message( - db, user_id, response, "assistant" - ) - - # Send response - await interaction.response.send_message(response) - bot.add_rate_limit_timestamp(user_id) - - except Exception as e: - logger.error( - f"Error processing analyze command: {str(e)}" - ) - await interaction.response.send_message( - "I encountered an error processing your request. Please try again later.", - ephemeral=True, - ) - - -@bot.tree.command( - name="clear", description="Clear your conversation history" -) -async def clear_command(interaction: discord.Interaction): - """Handle clear command.""" - async with bot.db_session() as db: - try: - await DatabaseOps.clear_history( - db, str(interaction.user.id) - ) - await interaction.response.send_message( - "Your conversation history has been cleared! 🧹", - ephemeral=True, - ) - except Exception as e: - logger.error(f"Error clearing history: {str(e)}") - await interaction.response.send_message( - "Failed to clear conversation history. Please try again later.", - ephemeral=True, - ) - - -# DM handler -@bot.event -async def on_message(message: discord.Message): - """Handle direct messages.""" - # Ignore bot messages and non-DM messages - if message.author.bot or not isinstance( - message.channel, discord.DMChannel - ): - return - - user_id = str(message.author.id) - - if not bot.check_rate_limit(user_id): - await message.reply( - "You're sending messages too quickly. Please wait a moment." - ) - return - - async with bot.db_session() as db: - try: - # Store user message - await DatabaseOps.store_message( - db, user_id, message.content, "user" - ) - - # Process with swarm - swarm = ContextAwareMedicalSwarm(user_id) - response = await swarm.process_with_context( - message.content, db - ) - - # Store bot response - await DatabaseOps.store_message( - db, user_id, response, "assistant" - ) - - # Send response - await message.reply(response) - bot.add_rate_limit_timestamp(user_id) - - except Exception as e: - logger.error(f"Error processing DM: {str(e)}") - await message.reply( - "I encountered an error processing your message. " - "Please try again later." - ) - - -# FastAPI endpoints -@app.post("/start") -async def start_bot(background_tasks: BackgroundTasks): - """Start the Discord bot.""" - try: - # Initialize database - await init_db() - - # Start bot - background_tasks.add_task(bot.start, config.DISCORD_TOKEN) - return {"status": "Bot started successfully"} - except Exception as e: - logger.error(f"Failed to start bot: {str(e)}") - raise HTTPException( - status_code=500, detail="Failed to start bot" - ) - - -@app.get("/conversations/{user_id}") -async def get_conversation( - user_id: str, db: AsyncSession = Depends(get_db) -): - """Get conversation history for a user.""" - try: - history = await DatabaseOps.get_conversation_history( - db, user_id - ) - return {"user_id": user_id, "messages": history} - except Exception as e: - logger.error(f"Failed to fetch conversation: {str(e)}") - raise HTTPException( - status_code=500, detail="Failed to fetch conversation" - ) - - -@app.delete("/conversations/{user_id}") -async def clear_conversation( - user_id: str, db: AsyncSession = Depends(get_db) -): - """Clear conversation history for a user.""" - try: - await DatabaseOps.clear_history(db, user_id) - return {"status": "Conversation cleared successfully"} - except Exception as e: - logger.error(f"Failed to clear conversation: {str(e)}") - raise HTTPException( - status_code=500, detail="Failed to clear conversation" - ) - - -@app.get("/health") -async def health_check(db: AsyncSession = Depends(get_db)): - """Health check endpoint.""" - try: - # Check database connection - await db.execute("SELECT 1") - - # Check Discord bot connection - if not bot.is_ready(): - raise Exception("Discord bot is not connected") - - return {"status": "healthy"} - except Exception as e: - logger.error(f"Health check failed: {str(e)}") - return Response( - content=json.dumps( - {"status": "unhealthy", "error": str(e)} - ), - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - ) - - -if __name__ == "__main__": - import uvicorn - - logger.info("Starting Discord Bot API server...") - uvicorn.run( - "main:app", - host="0.0.0.0", - port=8000, - reload=True, - workers=1, # Use 1 worker for Discord bot - ) diff --git a/api/requirements.txt b/api/requirements.txt index 646e55c..dc215cd 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -3,4 +3,5 @@ pydantic mcs cryptography uvicorn -loguru \ No newline at end of file +loguru +swarms-tools \ No newline at end of file diff --git a/api/skypilot.yaml b/api/skypilot.yaml deleted file mode 100644 index 55d6743..0000000 --- a/api/skypilot.yaml +++ /dev/null @@ -1,45 +0,0 @@ -name: mcs-api - -service: - readiness_probe: - path: /docs - initial_delay_seconds: 300 - timeout_seconds: 30 - - replica_policy: - min_replicas: 1 - max_replicas: 50 - target_qps_per_replica: 5 - upscale_delay_seconds: 180 - downscale_delay_seconds: 600 - - -envs: - WORKSPACE_DIR: "agent_workspace" - OPENAI_API_KEY: "" - MASTER_KEY: "278327837287384572" - -resources: - # cloud: aws # The cloud to use (optional). - ports: 8000 # FastAPI default port - cpus: 16 - memory: 64 - disk_size: 50 - use_spot: true - -workdir: . - -setup: | - git clone https://github.com/The-Swarm-Corporation/MedicalCoderSwarm.git - cd MedicalCoderSwarm/api - pip install -r requirements.txt - pip install swarms swarm-models loguru pydantic - -run: | - uvicorn main:app --host 0.0.0.0 --port 8000 --workers 4 - -# env: -# PYTHONPATH: /app/swarms -# LOG_LEVEL: "INFO" -# MAX_WORKERS: "4" - diff --git a/cloud-run.yaml b/cloud-run.yaml deleted file mode 100644 index ae3b985..0000000 --- a/cloud-run.yaml +++ /dev/null @@ -1,6 +0,0 @@ -readinessProbe: - httpGet: - path: /health - port: 8000 - initialDelaySeconds: 100 - periodSeconds: 100 \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index 45163a5..0000000 --- a/docker-compose.yml +++ /dev/null @@ -1,20 +0,0 @@ -version: "3.9" - -services: - mcs: - build: - context: . - dockerfile: api/Dockerfile - ports: - - "8000:8000" - environment: - - OPENAI_API_KEY=${OPENAI_API_KEY} - - DB_PATH=/app/medical_coder.db - - WORKSPACE_DIR="agent_workspace" - volumes: - - mcs-data:/app - restart: always - -volumes: - mcs-data: - driver: local diff --git a/example.py b/example.py index 99567b5..90deaf9 100644 --- a/example.py +++ b/example.py @@ -1,25 +1,74 @@ from mcs.main import MedicalCoderSwarm if __name__ == "__main__": - # Example patient case + # Extended Example Patient Case patient_case = """ - Patient: 45-year-old White Male - Location: New York, NY + Patient Information: + - Name: John Doe + - Age: 45 + - Gender: Male + - Ethnicity: White + - Location: New York, NY + - BMI: 28.5 (Overweight) + - Occupation: Office Worker + + Presenting Complaints: + - Persistent fatigue for 3 months + - Swelling in lower extremities + - Difficulty concentrating (brain fog) + - Increased frequency of urination + + Medical History: + - Hypertension (diagnosed 5 years ago, poorly controlled) + - Type 2 Diabetes Mellitus (diagnosed 2 years ago, HbA1c: 8.2%) + - Family history of chronic kidney disease (mother) + + Current Medications: + - Lisinopril 20 mg daily + - Metformin 1000 mg twice daily + - Atorvastatin 10 mg daily Lab Results: - - egfr - - 59 ml / min / 1.73 - - non african-american - + - eGFR: 59 ml/min/1.73m² (Non-African American) + - Serum Creatinine: 1.5 mg/dL + - BUN: 22 mg/dL + - Potassium: 4.8 mmol/L + - HbA1c: 8.2% + - Urinalysis: Microalbuminuria detected (300 mg/g creatinine) + + Vital Signs: + - Blood Pressure: 145/90 mmHg + - Heart Rate: 78 bpm + - Respiratory Rate: 16 bpm + - Temperature: 98.6°F + - Oxygen Saturation: 98% + + Differential Diagnoses to Explore: + 1. Chronic Kidney Disease (CKD) Stage 3 + 2. Diabetic Nephropathy + 3. Secondary Hypertension (due to CKD) + 4. Fatigue related to poorly controlled diabetes + + Specialist Consultations Needed: + - Nephrologist + - Endocrinologist + - Dietitian for diabetic and CKD management + + Initial Management Recommendations: + - Optimize blood pressure control (<130/80 mmHg target for CKD) + - Glycemic control improvement (target HbA1c <7%) + - Lifestyle modifications: low-sodium, renal-friendly diet + - Referral to nephrologist for further evaluation """ + # Initialize the MedicalCoderSwarm with the detailed patient case swarm = MedicalCoderSwarm( - patient_id="323u29382938293829382382398", + patient_id="Patient-001", max_loops=1, - output_type="json", - patient_documentation="", ) - swarm.run(task=patient_case) + # Run the swarm on the patient case + output = swarm.run(task=patient_case) - # print(json.dumps(swarm.to_dict())) + # Print the system's state after processing + print(output.model_dump_json(indent=4)) diff --git a/mcs_bounty.pdf b/examples/mcs_bounty.pdf similarity index 100% rename from mcs_bounty.pdf rename to examples/mcs_bounty.pdf diff --git a/mcs_client_test.py b/examples/mcs_client_test.py similarity index 100% rename from mcs_client_test.py rename to examples/mcs_client_test.py diff --git a/gcr-service-policy.yaml b/gcr-service-policy.yaml deleted file mode 100644 index 46d62af..0000000 --- a/gcr-service-policy.yaml +++ /dev/null @@ -1,4 +0,0 @@ -bindings: -- members: - - allUsers - role: roles/run.invoker \ No newline at end of file diff --git a/mcs/main.py b/mcs/main.py index 0980b4f..9286655 100644 --- a/mcs/main.py +++ b/mcs/main.py @@ -1,14 +1,15 @@ import json import os +import time import uuid from datetime import datetime, timedelta from typing import Any, Callable, Dict, List, Optional -from fastapi import requests -from pydantic import BaseModel, Field -from swarm_models import GPT4VisionAPI, OpenAIChat -from swarms import Agent, AgentRearrange +from pydantic import BaseModel +from swarm_models import OpenAIChat +from swarms import Agent from swarms.telemetry.capture_sys_data import log_agent_data +from mcs.rag_api import ChromaQueryClient from mcs.security import ( KeyRotationPolicy, @@ -29,35 +30,6 @@ def patient_id_uu(): return str(uuid.uuid4().hex) -class RAGAPI: - """ - Class to interact with the RAG API. - """ - - def __init__( - self, - base_url: str = None, - ): - """ - Initialize the RAG API with a base URL. - """ - self.base_url = base_url - - def query_rag(self, query: str): - """ - Query the RAG API with a given prompt. - """ - try: - response = requests.post( - f"{self.base_url}/query", - json={"query": query}, - ) - return str(response.json()) - except Exception as e: - print(f"An error occurred during the RAG query: {e}") - return None - - chief_medical_officer = Agent( agent_name="Chief Medical Officer", system_prompt=""" @@ -438,30 +410,19 @@ def query_rag(self, query: str): treatment_agent, ] -# Define diagnostic flow -flow = f"""{medical_coder.agent_name} -> {synthesizer.agent_name}, {treatment_agent.agent_name}""" - -class MedicalCoderSwarmInput(BaseModel): - mcs_id: Optional[str] = uuid.uuid4().hex - patient_id: Optional[str] - task: Optional[str] - img: Optional[str] - patient_docs: Optional[str] - summarization: Optional[bool] +class MCSAgentOutputs(BaseModel): + agent_id: Optional[str] = str(uuid.uuid4().hex) + agent_name: Optional[str] = None + agent_output: Optional[str] = None + timestamp: Optional[str] = time.strftime("%Y-%m-%d %H:%M:%S") -class MedicalCoderSwarmOutput(BaseModel): - input: Optional[MedicalCoderSwarmInput] - run_id: Optional[str] = Field(default=uuid.uuid4().hex) - patient_id: Optional[str] - agent_outputs: Optional[str] - summarization: Optional[str] - - -class ManyMedicalCoderSwarmOutput(BaseModel): - runs_id: Optional[str] = uuid.uuid4().hex - runs: Optional[List[MedicalCoderSwarmOutput]] +class MCSOutput(BaseModel): + run_id: Optional[str] = str(uuid.uuid4().hex) + agent_outputs: Optional[List[MCSAgentOutputs]] = None + summary: Optional[str] + timestamp: Optional[str] = time.strftime("%Y-%m-%d %H:%M:%S") class MedicalCoderSwarm: @@ -474,67 +435,41 @@ def __init__( name: str = "Medical-coding-diagnosis-swarm", description: str = "Comprehensive medical diagnosis and coding system", agents: list = agents, - flow: str = flow, patient_id: str = "001", max_loops: int = 1, - output_type: str = "all", output_folder_path: str = "reports", patient_documentation: str = None, agent_outputs: list = any, - rag_enabled: bool = False, - rag_url: str = None, user_name: str = "User", key_storage_path: str = None, summarization: bool = False, - vision_enabled: bool = False, + rag_on: bool = False, + rag_url: str = None, + rag_api_key: str = None, *args, **kwargs, ): self.name = name self.description = description self.agents = agents - self.flow = flow self.patient_id = patient_id self.max_loops = max_loops - self.output_type = output_type self.output_folder_path = output_folder_path self.patient_documentation = patient_documentation self.agent_outputs = agent_outputs - self.rag_enabled = rag_enabled - self.rag_url = rag_url self.user_name = user_name self.key_storage_path = key_storage_path self.summarization = summarization - self.vision_enabled = vision_enabled + self.rag_on = rag_on + self.rag_url = rag_url + self.rag_api_key = rag_api_key self.agent_outputs = [] self.patient_id = patient_id_uu() - if self.vision_enabled: - self.change_agent_llm() - - self.diagnosis_system = AgentRearrange( - name="Medical-coding-diagnosis-swarm", - description="Comprehensive medical diagnosis and coding system", - agents=agents, - flow=flow, - max_loops=max_loops, - output_type=output_type, - *args, - **kwargs, - ) - - if self.rag_enabled: - self.diagnosis_system.memory_system = RAGAPI( - base_url=rag_url - ) - self.output_file_path = ( f"medical_diagnosis_report_{patient_id}.md", ) - # Change the user name for all agents in the swarm - self.change_agent_user_name(user_name) - # Initialize with production configuration self.secure_handler = SecureDataHandler( master_key=os.environ["MASTER_KEY"], @@ -546,27 +481,15 @@ def __init__( auto_rotate=True, ) - def change_agent_llm(self): - """ - Change the language model for all agents in the swarm. - """ - model = GPT4VisionAPI( - openai_api_key=os.getenv("OPENAI_API_KEY"), - model_name="gpt-4o", - max_tokens=4000, - ) - - for agent in self.agents: - agent.llm = model + # Output schema + self.output_schema = MCSOutput(agent_outputs=[], summary="") - def change_agent_user_name(self, user_name: str): - """ - Change the user name for all agents in the swarm. - """ - for agent in self.agents: - self.user_name = user_name + def rag_query(self, query: str): + client = ChromaQueryClient( + api_key=self.rag_api_key, base_url=self.rag_url + ) - return agents + return client.query(query) def _run( self, task: str = None, img: str = None, *args, **kwargs @@ -577,19 +500,54 @@ def _run( try: log_agent_data(self.to_dict()) - case_info = f"Patient Information: {self.patient_id} \n Timestamp: {datetime.now()} \n Patient Documentation {self.patient_documentation} \n Task: {task}" + if self.rag_on is True: + db_data = self.rag_query(task) - output = self.diagnosis_system.run( - task=case_info, img=img, *args, **kwargs + case_info = f"Patient Information: {self.patient_id} \n Timestamp: {datetime.now()} \n Patient Documentation {self.patient_documentation} \n Task: {task} " + + if self.rag_on: + case_info = f"{db_data}{case_info}" + + medical_coder_output = medical_coder.run(case_info) + + # Append output to schema + self.output_schema.agent_outputs.append( + MCSAgentOutputs( + agent_name=medical_coder.agent_name, + agent_output=medical_coder_output, + ) + ) + + # Next agent + synthesizer_output = synthesizer.run( + f"From {medical_coder.agent_name} {medical_coder_output}" + ) + self.output_schema.agent_outputs.append( + MCSAgentOutputs( + agent_name=synthesizer.agent_name, + agent_output=synthesizer_output, + ) + ) + + # Next agent + treatment_agent_output = treatment_agent.run( + f"From {synthesizer.agent_name} {synthesizer_output}" + ) + self.output_schema.agent_outputs.append( + MCSAgentOutputs( + agent_name=treatment_agent.agent_name, + agent_output=treatment_agent_output, + ) ) if self.summarization is True: - output = summarizer_agent.run(output) + output = summarizer_agent.run(treatment_agent_output) + self.output_schema.summary = output - self.agent_outputs.append(output) log_agent_data(self.to_dict()) - return output + return self.output_schema + except Exception as e: log_agent_data(self.to_dict()) print( @@ -598,13 +556,7 @@ def _run( def run(self, task: str = None, img: str = None, *args, **kwargs): try: - - if self.secure_handler: - return self.secure_run( - task=task, img=img, *args, **kwargs - ) - else: - return self._run(task, img, *args, **kwargs) + return self._run(task, img, *args, **kwargs) except Exception as e: log_agent_data(self.to_dict()) print( diff --git a/mcs/rag_api.py b/mcs/rag_api.py new file mode 100644 index 0000000..0fcb329 --- /dev/null +++ b/mcs/rag_api.py @@ -0,0 +1,61 @@ +import os +import requests +from swarms_tools.utils.formatted_string import ( + format_object_to_string, +) + + +class ChromaQueryClient: + def __init__( + self, + api_key: str = os.getenv("RAG_API_URL"), + base_url: str = os.getenv("RAG_API_URL"), + ): + """ + Initializes the ChromaQueryClient with the API key and base URL. + + :param api_key: The API key for authentication. + :param base_url: The base URL for the Chroma API. + """ + self.api_key = api_key + self.base_url = base_url + + def query(self, query: str, n_results: int, doc_limit: int): + """ + Sends a POST request to the Chroma API to perform a query. + + :param query: The query string to search for. + :param n_results: The number of results to return. + :param doc_limit: The document limit for each result. + :return: The JSON response from the API. + """ + url = f"{self.base_url}/query" + headers = { + "accept": "application/json", + "X-API-Key": self.api_key, + "Content-Type": "application/json", + } + payload = { + "query": query, + "n_results": n_results, + "doc_limit": doc_limit, + } + + response = requests.post(url, headers=headers, json=payload) + + if response.status_code == 200: + result = response.json() + result = format_object_to_string(result) + + return result + + +# # Example usage +# client = ChromaQueryClient() + +# # try: +# # result = client.query(query="back pain", n_results=5, doc_limit=4) +# # result = format_object_to_string(result) +# # print(result) +# # except requests.exceptions.RequestException as e: +# # print("An error occurred:", e) diff --git a/pyproject.toml b/pyproject.toml index cb9e677..4355d02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "mcs" -version = "0.0.9" +version = "0.1.0" description = "Paper - Pytorch" license = "MIT" authors = ["Kye Gomez "] @@ -27,6 +27,7 @@ swarms = "*" loguru = "*" swarms-models = "*" cryptography = "*" +swarms-tools = "*" [tool.poetry.group.lint.dependencies] ruff = "^0.1.6" diff --git a/requirements.txt b/requirements.txt index b0783ca..ebdf4e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ swarm-models cryptography python-dotenv pkg_resources -setuptools \ No newline at end of file +setuptools +swarms-tools \ No newline at end of file diff --git a/service.yaml b/service.yaml deleted file mode 100644 index 527dabc..0000000 --- a/service.yaml +++ /dev/null @@ -1,29 +0,0 @@ -apiVersion: serving.knative.dev/v1 -kind: Service -metadata: - name: mcs -spec: - template: - spec: - containers: - - image: mcs - env: - - name: OPENAI_API_KEY - valueFrom: - secretKeyRef: - name: OPENAI_API_KEY - key: latest - - - name: WORKSPACE_DIR - valueFrom: - secretKeyRef: - name: WORKSPACE_DIR - key: latest - - - name: MASTER_KEY - valueFrom: - secretKeyRef: - name: MASTER_KEY - key: latest - ports: - - containerPort: 8000 \ No newline at end of file