Skip to content

Commit

Permalink
[CLEANUP]
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Oct 25, 2024
1 parent 9d23de0 commit b5c32a2
Show file tree
Hide file tree
Showing 9 changed files with 676 additions and 212 deletions.
192 changes: 192 additions & 0 deletions dj_swarm/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import os
from dotenv import load_dotenv
import cv2
import spotipy
from spotipy.oauth2 import SpotifyOAuth
from swarms import Agent
from time import sleep
from typing import Optional
from loguru import logger
from tenacity import retry, wait_fixed, stop_after_attempt
from model import VisionAPIWrapper

# Load environment variables
load_dotenv()

# Set up logging
logger.add("dj_agent_log.log", rotation="1 MB", level="DEBUG")

# Initialize the Spotify API for music control
spotify: spotipy.Spotify = spotipy.Spotify(
auth_manager=SpotifyOAuth(
client_id=os.environ.get("SPOTIFY_CLIENT_ID"),
client_secret=os.environ.get("SPOTIFY_CLIENT_SECRET"),
redirect_uri="http://localhost:8080",
scope="user-read-playback-state user-modify-playback-state",
)
)

# Initialize the vision model for crowd analysis (GPT-4 Vision or OpenCV)
vision_llm = VisionAPIWrapper(
api_key="",
max_tokens=500,
)

# Define the crowd analysis task for acid techno selection
task: str = (
"Analyze this real-time image of a crowd and determine the overall energy level "
"based on body movement, facial expressions, and crowd density. "
"If the crowd appears highly energetic, output a high-intensity acid techno song "
"with a fast tempo and heavy bass. If the energy is lower, output a slightly slower "
"but still engaging acid techno song. Provide the song recommendation to match the "
"crowd’s energy, focusing on high-energy acid techno music."
)


# Retry decorator to handle API call failures
@retry(wait=wait_fixed(5), stop=stop_after_attempt(3))
def get_spotify_track(track_name: str) -> Optional[str]:
"""
Fetches a track from Spotify by its name.
Retries up to 3 times with a 5-second wait in case of failure.
:param track_name: The name of the track to search for.
:return: Spotify track ID if successful, None otherwise.
"""
logger.info(f"Fetching Spotify track for: {track_name}")
try:
results = spotify.search(q=track_name, type="track", limit=1)
track = results["tracks"]["items"][0]["id"]
logger.debug(
f"Found track ID: {track} for track name: {track_name}"
)
return track
except Exception as e:
logger.error(f"Error fetching Spotify track: {e}")
raise


def analyze_crowd(frame: str) -> str:
"""
Analyze the provided image frame of the crowd and return a description
of the energy level and a song recommendation.
:param frame: The path to the image frame to analyze.
:return: A string containing the recommended song based on crowd energy.
"""
logger.info("Analyzing crowd energy from image.")
try:
# Initialize the workflow agent
agent: Agent = Agent(
agent_name="AcidTechnoDJ_CrowdAnalyzer",
# system_prompt=task,
llm=vision_llm,
max_loops=1,
# autosave=True,
# dashboard=True,
# multi_modal=True,
)

# Analyze the frame
response: str = agent.run(task, frame)
logger.debug(f"Crowd analysis result: {response}")
return response # This should be a recommended song title
except Exception as e:
logger.error(f"Error analyzing crowd: {e}")
raise


def play_song(song_name: str) -> None:
"""
Fetches the song by name from Spotify and plays it.
:param song_name: The name of the song to play.
"""
logger.info(f"Attempting to play song: {song_name}")
track_id: Optional[str] = get_spotify_track(song_name)
if track_id:
spotify.start_playback(uris=[f"spotify:track:{track_id}"])
logger.info(f"Now playing track: {track_id}")
else:
logger.error(f"Could not play song: {song_name}")


def save_frame(frame: cv2.Mat, path: str = "temp_frame.jpg") -> str:
"""
Saves a frame from the video feed as an image file.
:param frame: The OpenCV frame (image) to save.
:param path: The file path to save the image to (default: 'temp_frame.jpg').
:return: The path to the saved image file.
"""
cv2.imwrite(path, frame)
logger.info(f"Frame saved to {path}")
return path


@retry(wait=wait_fixed(5), stop=stop_after_attempt(3))
def capture_video_feed() -> Optional[cv2.VideoCapture]:
"""
Initializes and returns the video capture feed.
:return: OpenCV video capture object if successful, None otherwise.
"""
logger.info("Attempting to capture video feed.")
try:
camera: cv2.VideoCapture = cv2.VideoCapture(0)
if not camera.isOpened():
raise RuntimeError("Failed to open camera.")
logger.debug("Camera feed opened successfully.")
return camera
except Exception as e:
logger.error(f"Error capturing video feed: {e}")
raise


def run_dj_agent():
"""
Runs the DJ agent in a loop, analyzing the crowd's energy level in real-time,
recommending acid techno tracks based on the analysis, and playing the recommended
song if applicable.
"""
logger.info("Starting Acid Techno DJ Agent...")

# Capture the video feed
camera: Optional[cv2.VideoCapture] = capture_video_feed()

try:
while True:
# Capture a frame every 5 seconds
ret, frame = camera.read()
if not ret:
logger.warning(
"Failed to capture frame from video feed."
)
break

# Save the frame as an image file
frame_path: str = save_frame(frame)

# Analyze the current crowd state and get a song recommendation
recommended_song: str = analyze_crowd(frame_path)
logger.info(f"Recommended song: {recommended_song}")

# Play the recommended song based on the analysis
play_song(recommended_song)

# Wait for 5 seconds before capturing the next frame
sleep(5)

except Exception as e:
logger.error(f"DJ Agent encountered an error: {e}")
finally:
# Release the camera feed when done
if camera:
camera.release()
cv2.destroyAllWindows()
logger.info("DJ Agent has stopped.")


# Run the DJ agent
if __name__ == "__main__":
run_dj_agent()
106 changes: 106 additions & 0 deletions dj_swarm/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from openai import OpenAI
import base64
from loguru import logger
from typing import Any, Optional


class VisionAPIWrapper:
def __init__(
self,
api_key: str = None,
system_prompt: str = None,
model: str = "gpt-4o-mini",
max_tokens: int = 300,
temperature: float = 0.7,
):
"""
Initializes the API wrapper with the system prompt and configuration.
Args:
system_prompt (str): The system prompt for the model.
model (str): The OpenAI model to use.
max_tokens (int): Maximum number of tokens to generate.
temperature (float): Sampling temperature for the model.
"""
self.client = OpenAI(api_key=api_key)
self.system_prompt = system_prompt
self.model = model
self.max_tokens = max_tokens
self.temperature = temperature

@staticmethod
def encode_image(image_path: str) -> str:
"""
Encodes the image to base64 format.
Args:
image_path (str): Path to the image file.
Returns:
str: Base64 encoded image string.
"""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")

# @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def run(self, task: str, img: Optional[str] = None) -> Any:
"""
Sends a request to the OpenAI API with a task and an optional image.
Args:
task (str): Task to send to the model.
img (Optional[str]): Path to the image to be analyzed by the model (optional).
Returns:
Any: The response from the model.
"""
messages = [{"role": "system", "content": self.system_prompt}]

user_message = {
"role": "user",
"content": [{"type": "text", "text": task}],
}

if img:
base64_image = self.encode_image(img)
image_message = {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
},
}
user_message["content"].append(image_message)

messages.append(user_message)

logger.info(
f"Sending request to OpenAI with task: {task} and image: {img}"
)

try:
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
max_tokens=self.max_tokens,
temperature=self.temperature,
)
logger.info("Received response successfully.")
return response.choices[0].message.content
except Exception as e:
logger.error(
f"An error occurred while making the API request: {e}"
)
raise

def __call__(self, task: str, img: Optional[str] = None) -> Any:
"""
Makes the object callable and returns the result of the run method.
Args:
task (str): Task to send to the model.
img (Optional[str]): Path to the image (optional).
Returns:
Any: The response from the model.
"""
return self.run(task, img)
Loading

0 comments on commit b5c32a2

Please sign in to comment.