generated from kyegomez/Python-Package-Template
-
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Your Name
committed
Oct 25, 2024
1 parent
9d23de0
commit b5c32a2
Showing
9 changed files
with
676 additions
and
212 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
import os | ||
from dotenv import load_dotenv | ||
import cv2 | ||
import spotipy | ||
from spotipy.oauth2 import SpotifyOAuth | ||
from swarms import Agent | ||
from time import sleep | ||
from typing import Optional | ||
from loguru import logger | ||
from tenacity import retry, wait_fixed, stop_after_attempt | ||
from model import VisionAPIWrapper | ||
|
||
# Load environment variables | ||
load_dotenv() | ||
|
||
# Set up logging | ||
logger.add("dj_agent_log.log", rotation="1 MB", level="DEBUG") | ||
|
||
# Initialize the Spotify API for music control | ||
spotify: spotipy.Spotify = spotipy.Spotify( | ||
auth_manager=SpotifyOAuth( | ||
client_id=os.environ.get("SPOTIFY_CLIENT_ID"), | ||
client_secret=os.environ.get("SPOTIFY_CLIENT_SECRET"), | ||
redirect_uri="http://localhost:8080", | ||
scope="user-read-playback-state user-modify-playback-state", | ||
) | ||
) | ||
|
||
# Initialize the vision model for crowd analysis (GPT-4 Vision or OpenCV) | ||
vision_llm = VisionAPIWrapper( | ||
api_key="", | ||
max_tokens=500, | ||
) | ||
|
||
# Define the crowd analysis task for acid techno selection | ||
task: str = ( | ||
"Analyze this real-time image of a crowd and determine the overall energy level " | ||
"based on body movement, facial expressions, and crowd density. " | ||
"If the crowd appears highly energetic, output a high-intensity acid techno song " | ||
"with a fast tempo and heavy bass. If the energy is lower, output a slightly slower " | ||
"but still engaging acid techno song. Provide the song recommendation to match the " | ||
"crowd’s energy, focusing on high-energy acid techno music." | ||
) | ||
|
||
|
||
# Retry decorator to handle API call failures | ||
@retry(wait=wait_fixed(5), stop=stop_after_attempt(3)) | ||
def get_spotify_track(track_name: str) -> Optional[str]: | ||
""" | ||
Fetches a track from Spotify by its name. | ||
Retries up to 3 times with a 5-second wait in case of failure. | ||
:param track_name: The name of the track to search for. | ||
:return: Spotify track ID if successful, None otherwise. | ||
""" | ||
logger.info(f"Fetching Spotify track for: {track_name}") | ||
try: | ||
results = spotify.search(q=track_name, type="track", limit=1) | ||
track = results["tracks"]["items"][0]["id"] | ||
logger.debug( | ||
f"Found track ID: {track} for track name: {track_name}" | ||
) | ||
return track | ||
except Exception as e: | ||
logger.error(f"Error fetching Spotify track: {e}") | ||
raise | ||
|
||
|
||
def analyze_crowd(frame: str) -> str: | ||
""" | ||
Analyze the provided image frame of the crowd and return a description | ||
of the energy level and a song recommendation. | ||
:param frame: The path to the image frame to analyze. | ||
:return: A string containing the recommended song based on crowd energy. | ||
""" | ||
logger.info("Analyzing crowd energy from image.") | ||
try: | ||
# Initialize the workflow agent | ||
agent: Agent = Agent( | ||
agent_name="AcidTechnoDJ_CrowdAnalyzer", | ||
# system_prompt=task, | ||
llm=vision_llm, | ||
max_loops=1, | ||
# autosave=True, | ||
# dashboard=True, | ||
# multi_modal=True, | ||
) | ||
|
||
# Analyze the frame | ||
response: str = agent.run(task, frame) | ||
logger.debug(f"Crowd analysis result: {response}") | ||
return response # This should be a recommended song title | ||
except Exception as e: | ||
logger.error(f"Error analyzing crowd: {e}") | ||
raise | ||
|
||
|
||
def play_song(song_name: str) -> None: | ||
""" | ||
Fetches the song by name from Spotify and plays it. | ||
:param song_name: The name of the song to play. | ||
""" | ||
logger.info(f"Attempting to play song: {song_name}") | ||
track_id: Optional[str] = get_spotify_track(song_name) | ||
if track_id: | ||
spotify.start_playback(uris=[f"spotify:track:{track_id}"]) | ||
logger.info(f"Now playing track: {track_id}") | ||
else: | ||
logger.error(f"Could not play song: {song_name}") | ||
|
||
|
||
def save_frame(frame: cv2.Mat, path: str = "temp_frame.jpg") -> str: | ||
""" | ||
Saves a frame from the video feed as an image file. | ||
:param frame: The OpenCV frame (image) to save. | ||
:param path: The file path to save the image to (default: 'temp_frame.jpg'). | ||
:return: The path to the saved image file. | ||
""" | ||
cv2.imwrite(path, frame) | ||
logger.info(f"Frame saved to {path}") | ||
return path | ||
|
||
|
||
@retry(wait=wait_fixed(5), stop=stop_after_attempt(3)) | ||
def capture_video_feed() -> Optional[cv2.VideoCapture]: | ||
""" | ||
Initializes and returns the video capture feed. | ||
:return: OpenCV video capture object if successful, None otherwise. | ||
""" | ||
logger.info("Attempting to capture video feed.") | ||
try: | ||
camera: cv2.VideoCapture = cv2.VideoCapture(0) | ||
if not camera.isOpened(): | ||
raise RuntimeError("Failed to open camera.") | ||
logger.debug("Camera feed opened successfully.") | ||
return camera | ||
except Exception as e: | ||
logger.error(f"Error capturing video feed: {e}") | ||
raise | ||
|
||
|
||
def run_dj_agent(): | ||
""" | ||
Runs the DJ agent in a loop, analyzing the crowd's energy level in real-time, | ||
recommending acid techno tracks based on the analysis, and playing the recommended | ||
song if applicable. | ||
""" | ||
logger.info("Starting Acid Techno DJ Agent...") | ||
|
||
# Capture the video feed | ||
camera: Optional[cv2.VideoCapture] = capture_video_feed() | ||
|
||
try: | ||
while True: | ||
# Capture a frame every 5 seconds | ||
ret, frame = camera.read() | ||
if not ret: | ||
logger.warning( | ||
"Failed to capture frame from video feed." | ||
) | ||
break | ||
|
||
# Save the frame as an image file | ||
frame_path: str = save_frame(frame) | ||
|
||
# Analyze the current crowd state and get a song recommendation | ||
recommended_song: str = analyze_crowd(frame_path) | ||
logger.info(f"Recommended song: {recommended_song}") | ||
|
||
# Play the recommended song based on the analysis | ||
play_song(recommended_song) | ||
|
||
# Wait for 5 seconds before capturing the next frame | ||
sleep(5) | ||
|
||
except Exception as e: | ||
logger.error(f"DJ Agent encountered an error: {e}") | ||
finally: | ||
# Release the camera feed when done | ||
if camera: | ||
camera.release() | ||
cv2.destroyAllWindows() | ||
logger.info("DJ Agent has stopped.") | ||
|
||
|
||
# Run the DJ agent | ||
if __name__ == "__main__": | ||
run_dj_agent() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from openai import OpenAI | ||
import base64 | ||
from loguru import logger | ||
from typing import Any, Optional | ||
|
||
|
||
class VisionAPIWrapper: | ||
def __init__( | ||
self, | ||
api_key: str = None, | ||
system_prompt: str = None, | ||
model: str = "gpt-4o-mini", | ||
max_tokens: int = 300, | ||
temperature: float = 0.7, | ||
): | ||
""" | ||
Initializes the API wrapper with the system prompt and configuration. | ||
Args: | ||
system_prompt (str): The system prompt for the model. | ||
model (str): The OpenAI model to use. | ||
max_tokens (int): Maximum number of tokens to generate. | ||
temperature (float): Sampling temperature for the model. | ||
""" | ||
self.client = OpenAI(api_key=api_key) | ||
self.system_prompt = system_prompt | ||
self.model = model | ||
self.max_tokens = max_tokens | ||
self.temperature = temperature | ||
|
||
@staticmethod | ||
def encode_image(image_path: str) -> str: | ||
""" | ||
Encodes the image to base64 format. | ||
Args: | ||
image_path (str): Path to the image file. | ||
Returns: | ||
str: Base64 encoded image string. | ||
""" | ||
with open(image_path, "rb") as image_file: | ||
return base64.b64encode(image_file.read()).decode("utf-8") | ||
|
||
# @retry(stop=stop_after_attempt(3), wait=wait_fixed(2)) | ||
def run(self, task: str, img: Optional[str] = None) -> Any: | ||
""" | ||
Sends a request to the OpenAI API with a task and an optional image. | ||
Args: | ||
task (str): Task to send to the model. | ||
img (Optional[str]): Path to the image to be analyzed by the model (optional). | ||
Returns: | ||
Any: The response from the model. | ||
""" | ||
messages = [{"role": "system", "content": self.system_prompt}] | ||
|
||
user_message = { | ||
"role": "user", | ||
"content": [{"type": "text", "text": task}], | ||
} | ||
|
||
if img: | ||
base64_image = self.encode_image(img) | ||
image_message = { | ||
"type": "image_url", | ||
"image_url": { | ||
"url": f"data:image/jpeg;base64,{base64_image}" | ||
}, | ||
} | ||
user_message["content"].append(image_message) | ||
|
||
messages.append(user_message) | ||
|
||
logger.info( | ||
f"Sending request to OpenAI with task: {task} and image: {img}" | ||
) | ||
|
||
try: | ||
response = self.client.chat.completions.create( | ||
model=self.model, | ||
messages=messages, | ||
max_tokens=self.max_tokens, | ||
temperature=self.temperature, | ||
) | ||
logger.info("Received response successfully.") | ||
return response.choices[0].message.content | ||
except Exception as e: | ||
logger.error( | ||
f"An error occurred while making the API request: {e}" | ||
) | ||
raise | ||
|
||
def __call__(self, task: str, img: Optional[str] = None) -> Any: | ||
""" | ||
Makes the object callable and returns the result of the run method. | ||
Args: | ||
task (str): Task to send to the model. | ||
img (Optional[str]): Path to the image (optional). | ||
Returns: | ||
Any: The response from the model. | ||
""" | ||
return self.run(task, img) |
Oops, something went wrong.