From 530af3ea507a3f57df1237c9b1c38c0cb1311471 Mon Sep 17 00:00:00 2001
From: Monalisha Mishra
Date: Wed, 12 Feb 2025 22:08:44 +0530
Subject: [PATCH] add song generator agent

---
Review notes (ignored by `git am`, not part of the commit message):
- fetch_url is now looked up in MODELS_LAB_FETCH_URLS; it previously reused
  MODELS_LAB_URLS, so status polling was sent to the generation endpoint.
- The nested double quotes inside the "Fetch result success" f-string were
  replaced with single quotes; the original form only parses on Python 3.12+.
- The file_type default is no longer touched (the only change was a stray
  space), so that line is now plain context.
- Whitespace / import-order cleanups; models_labs.py now ends with a newline.
- The blob ids on the "index" lines predate these revisions.

 cookbook/playground/multimodal_agents.py | 23 ++++++
 libs/agno/agno/models/response.py        |  1 +
 libs/agno/agno/tools/models_labs.py      | 93 ++++++++++++++++++++++--
 3 files changed, 112 insertions(+), 5 deletions(-)

diff --git a/cookbook/playground/multimodal_agents.py b/cookbook/playground/multimodal_agents.py
index 10cf48870a..9697a98cb3 100644
--- a/cookbook/playground/multimodal_agents.py
+++ b/cookbook/playground/multimodal_agents.py
@@ -57,6 +57,28 @@
     ),
 )
 
+ml_song_agent = Agent(
+    name="ModelsLab Song Generator Agent",
+    agent_id="ml_song_agent",
+    model=OpenAIChat(id="gpt-4o"),
+    tools=[ModelsLabTools(wait_for_completion=True, file_type=FileType.MP3)],
+    description="You are an AI agent that can generate songs using the ModelsLabs API.",
+    instructions=[
+        "When the user asks you to generate audio, use the `generate_audio` tool to generate the audio.",
+        # "You'll generate the appropriate prompt to send to the tool to generate audio.",
+        # "You don't need to find the appropriate voice first, I already specified the voice to user."
+        "Don't return file name or file url in your response or markdown just tell the audio was created successfully.",
+        "The audio should be long and detailed.",
+    ],
+    markdown=True,
+    debug_mode=True,
+    add_history_to_messages=True,
+    add_datetime_to_instructions=True,
+    storage=SqliteAgentStorage(
+        table_name="ml_song_agent", db_file=image_agent_storage_file
+    ),
+)
+
 ml_video_agent = Agent(
     name="ModelsLab Video Agent",
     agent_id="ml_video_agent",
@@ -164,6 +186,7 @@
     agents=[
         image_agent,
         ml_gif_agent,
+        ml_song_agent,
         ml_video_agent,
         fal_agent,
         gif_agent,
diff --git a/libs/agno/agno/models/response.py b/libs/agno/agno/models/response.py
index 9e40171e1e..01fb4fc847 100644
--- a/libs/agno/agno/models/response.py
+++ b/libs/agno/agno/models/response.py
@@ -29,3 +29,4 @@ class ModelResponse:
 class FileType(str, Enum):
     MP4 = "mp4"
     GIF = "gif"
+    MP3 = "mp3"
diff --git a/libs/agno/agno/tools/models_labs.py b/libs/agno/agno/tools/models_labs.py
index 0fb9b3c1be..8f7f7e1125 100644
--- a/libs/agno/agno/tools/models_labs.py
+++ b/libs/agno/agno/tools/models_labs.py
@@ -5,7 +5,7 @@
 from uuid import uuid4
 
 from agno.agent import Agent
-from agno.media import ImageArtifact, VideoArtifact
+from agno.media import AudioArtifact, ImageArtifact, VideoArtifact
 from agno.models.response import FileType
 from agno.tools import Toolkit
 from agno.utils.log import logger
@@ -15,25 +15,34 @@
 except ImportError:
     raise ImportError("`requests` not installed. Please install using `pip install requests`")
 
+MODELS_LAB_URLS = {
+    "MP4": "https://modelslab.com/api/v6/video/text2video",
+    "MP3": "https://modelslab.com/api/v6/voice/music_gen",
+    "GIF": "https://modelslab.com/api/v6/video/text2video",
+}
+MODELS_LAB_FETCH_URLS = {
+    "MP4": "https://modelslab.com/api/v6/video/fetch",
+    "MP3": "https://modelslab.com/api/v6/voice/fetch",
+    "GIF": "https://modelslab.com/api/v6/video/fetch",
+}
+
 
 class ModelsLabTools(Toolkit):
     def __init__(
         self,
         api_key: Optional[str] = None,
-        url: str = "https://modelslab.com/api/v6/video/text2video",
-        fetch_url: str = "https://modelslab.com/api/v6/video/fetch",
         # Whether to wait for the video to be ready
         wait_for_completion: bool = False,
         # Time to add to the ETA to account for the time it takes to fetch the video
         add_to_eta: int = 15,
         # Maximum time to wait for the video to be ready
         max_wait_time: int = 60,
         file_type: FileType = FileType.MP4,
     ):
         super().__init__(name="models_labs")
 
-        self.url = url
-        self.fetch_url = fetch_url
+        self.url = MODELS_LAB_URLS[file_type.value.upper()]
+        self.fetch_url = MODELS_LAB_FETCH_URLS[file_type.value.upper()]
         self.wait_for_completion = wait_for_completion
         self.add_to_eta = add_to_eta
         self.max_wait_time = max_wait_time
@@ -43,6 +52,7 @@ def __init__(
             logger.error("MODELS_LAB_API_KEY not set. Please set the MODELS_LAB_API_KEY environment variable.")
 
         self.register(self.generate_media)
+        self.register(self.generate_audio)
 
     def generate_media(self, agent: Agent, prompt: str) -> str:
         """Use this function to generate a video or image given a prompt.
@@ -122,3 +132,76 @@ def generate_media(self, agent: Agent, prompt: str) -> str:
         except Exception as e:
             logger.error(f"Failed to generate video: {e}")
             return f"Error: {e}"
+
+    def generate_audio(self, agent: Agent, prompt: str) -> str:
+        """Use this function to generate audio given a prompt.
+
+        Args:
+            prompt (str): A text description of the desired audio.
+
+        Returns:
+            str: A message indicating if the audio has been generated successfully or an error message.
+        """
+        if not self.api_key:
+            return "Please set the MODELS_LAB_API_KEY"
+
+        try:
+            payload = json.dumps(
+                {
+                    "key": self.api_key,
+                    "prompt": prompt,
+                    "base64": False,
+                    "temp": False,
+                    "webhook": None,
+                    "track_id": None,
+                }
+            )
+
+            headers = {"Content-Type": "application/json"}
+            logger.debug(f"Generating audio for prompt: {prompt}")
+            response = requests.request("POST", self.url, data=payload, headers=headers)
+            response.raise_for_status()
+
+            result = response.json()
+            if "error" in result:
+                logger.error(f"Failed to generate audio: {result['error']}")
+                return f"Error: {result['error']}"
+
+            eta = result["eta"]
+            url_links = result["future_links"]
+            logger.info(f"Media will be ready in {eta} seconds")
+            logger.info(f"Media URLs: {url_links}")
+
+            audio_id = str(result["id"])
+
+            logger.debug(f"Result: {result}")
+
+            for media_url in url_links:
+                agent.add_audio(AudioArtifact(id=audio_id, url=media_url))
+
+            if self.wait_for_completion and isinstance(eta, int):
+                audio_ready = False
+                seconds_waited = 0
+                time_to_wait = min(eta + self.add_to_eta, self.max_wait_time)
+                logger.info(f"Waiting for {time_to_wait} seconds for audio to be ready")
+                while not audio_ready and seconds_waited < time_to_wait:
+                    time.sleep(1)
+                    seconds_waited += 1
+                    # Fetch the audio from the ModelsLabs API
+                    fetch_payload = json.dumps({"key": self.api_key})
+                    fetch_headers = {"Content-Type": "application/json"}
+                    logger.debug(f"Fetching audio from {self.fetch_url}/{audio_id}")
+                    fetch_response = requests.request(
+                        "POST", f"{self.fetch_url}/{audio_id}", data=fetch_payload, headers=fetch_headers
+                    )
+                    fetch_result = fetch_response.json()
+                    logger.debug(f"Fetch result: {fetch_result}")
+                    if fetch_result.get("status") == "success":
+                        logger.debug(f"Fetch result success: {fetch_result.get('output')}")
+                        audio_ready = True
+                        break
+
+            return f"Audio has been generated successfully and will be ready in {eta} seconds"
+        except Exception as e:
+            logger.error(f"Failed to generate audio: {e}")
+            return f"Error: {e}"