# Author: BrianFrankenstein
# Commit: a54df69 ("Upload 6 files", verified)
import base64
import io
import logging
import mimetypes
import re
import time

import requests
from langchain_community.document_loaders import WikipediaLoader, YoutubeLoader
from langchain_community.tools import TavilySearchResults, tool
from langchain_core.messages import SystemMessage
from PIL import Image
from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound

from state import QuestionState
# --- Module configuration ---
# Base URL of the remote API this agent talks to.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

# Root logging at INFO level, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# System prompt placed in front of the conversation messages when the LLM is
# asked to write a search query (used by the search tools below).
_SEARCH_INSTRUCTIONS_TEXT = "Search the internet to find relevant answers to queries"
search_instructions = SystemMessage(content=_SEARCH_INSTRUCTIONS_TEXT)
def search_web(state: QuestionState):
    """Retrieve documents from a Tavily web search for the current question.

    The LLM (with structured output ``SearchQuery``) turns the conversation in
    ``state['messages']`` into a search query. The top 3 Tavily hits are then
    rendered as ``<Document>`` snippets.

    Args:
        state: graph state; only ``state['messages']`` is read.

    Returns:
        ``{"context": [text]}``: a one-element list holding every hit joined
        with ``---`` separators, or a failure message if the search tool
        returned an error string instead of results.
    """
    # NOTE(review): `llm` and `SearchQuery` are not defined or imported in this
    # module as shown; they must be injected elsewhere or this raises NameError.
    logger.info("Tool called: search_web")
    tavily_search = TavilySearchResults(max_results=3)

    # Let the LLM condense the conversation into a single search query.
    structured_llm = llm.with_structured_output(SearchQuery)
    search_query = structured_llm.invoke([search_instructions] + state['messages'])

    search_docs = tavily_search.invoke(search_query.search_query)
    # TavilySearchResults may hand back an error string (e.g. on API failure)
    # instead of a list of result dicts; iterating that string would crash.
    if isinstance(search_docs, str):
        logger.warning(f"Web search returned an error instead of results: {search_docs}")
        return {"context": [f"Web search failed: {search_docs}"]}

    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document href="{doc.get("url", "")}"/>\n{doc.get("content", "")}\n</Document>'
        for doc in search_docs
    )
    return {"context": [formatted_search_docs]}
def search_wikipedia(state: QuestionState):
    """Look up Wikipedia pages relevant to the current question.

    The conversation in ``state['messages']`` (prefixed with the search system
    prompt) is turned into a query by the structured-output LLM. Up to two
    pages are loaded with ``WikipediaLoader``.

    Returns:
        ``{"context": [text]}`` where ``text`` is every page rendered as a
        ``<Document source=... page=...>`` snippet, joined by ``---`` lines.
    """
    logger.info("Tool called: search_wikipedia")

    # NOTE(review): `llm` and `SearchQuery` are not defined in the visible
    # module; they must come from elsewhere.
    query = llm.with_structured_output(SearchQuery).invoke(
        [search_instructions] + state['messages']
    ).search_query

    pages = WikipediaLoader(query=query, load_max_docs=2).load()

    snippets = []
    for page in pages:
        meta = page.metadata
        snippets.append(
            f'<Document source="{meta["source"]}" page="{meta.get("page", "")}"/>\n{page.page_content}\n</Document>'
        )
    return {"context": ["\n\n---\n\n".join(snippets)]}
def get_image_attachment(state: QuestionState):
    """Download the current question's image attachment as a base64 data URL.

    Args:
        state: graph state; ``state["attachment_url"]`` is the file to fetch.

    Returns:
        ``"data:<mime>;base64,<payload>"`` on success, or ``None`` if the
        download or encoding failed.
    """
    logger.info("Tool called: get_image_attachment")
    url = state["attachment_url"]
    response = _download_with_retries(url)
    if response is None:
        logger.error(f"Failed to download image after retries: {url}")
        return None
    try:
        # Requires the module-level `import base64`; without it the NameError
        # landed in this except-branch and every call silently returned None.
        image_data = base64.b64encode(response.content).decode("utf-8")
    except Exception as e:
        logger.error(f"An error occurred while trying to process the image attachment: {e}")
        return None
    # Prefer the server's header, then a guess from the URL's extension,
    # and finally assume JPEG.
    content_type = (
        response.headers.get('content-type')
        or mimetypes.guess_type(url)[0]
        or 'image/jpeg'
    )
    return f"data:{content_type};base64,{image_data}"
def get_audio_attachment(state: QuestionState):
    """Download the current question's audio attachment as a base64 data URL.

    Args:
        state: graph state; ``state["attachment_url"]`` is the file to fetch.

    Returns:
        ``"data:<mime>;base64,<payload>"`` on success, or ``None`` if the
        download failed.
    """
    url = state["attachment_url"]
    logger.info(f"Tool called: get_audio_attachment at {url}")
    response = _download_with_retries(url, stream=True)
    if response is None:
        logger.error(f"Failed to download audio after retries: {url}")
        return None
    # Header first, then a guess from the URL. Fall back to the generic binary
    # type so the data URL never reads "data:None;...".
    content_type = (
        response.headers.get('content-type')
        or mimetypes.guess_type(url)[0]
        or 'application/octet-stream'
    )
    # The old message used `{response.content-type}`, a set literal computing
    # `bytes - type`, which raised TypeError on every successful download.
    logger.info(f"The Audio file ({content_type}) downloaded successfully")
    audio_data = base64.b64encode(response.content).decode("utf-8")
    return f"data:{content_type};base64,{audio_data}"
def get_excel_attachment(state: QuestionState):
    """Fetch the current question's spreadsheet attachment as raw bytes.

    Args:
        state: graph state; ``state["attachment_url"]`` is the file to fetch.

    Returns:
        ``(bytes, content_type)``, where ``content_type`` is the response's
        Content-Type header (``None`` if absent), or ``(None, None)`` when the
        download failed.
    """
    logger.info("Tool called: get_excel_attachment")
    url = state["attachment_url"]
    resp = _download_with_retries(url, stream=True)
    if resp is not None:
        return resp.content, resp.headers.get('Content-Type')
    logger.error(f"Failed to download excel after retries: {url}")
    return None, None
def get_attachment(state: QuestionState):
    """Generic attachment fetcher for file types without a dedicated tool.

    Args:
        state: graph state; ``state["attachment_url"]`` is the file to fetch.

    Returns:
        ``(payload_bytes, content_type_header)``, or ``(None, None)`` when the
        download could not be completed.
    """
    logger.info("Tool called: get_attachment")
    attachment_url = state["attachment_url"]
    downloaded = _download_with_retries(attachment_url, stream=True)
    if downloaded is None:
        logger.error(f"Failed to download attachment after retries: {attachment_url}")
        return None, None
    payload = downloaded.content
    mime = downloaded.headers.get('Content-Type')
    return payload, mime
# --- Helper Function to Download Files with Retries ---
def _download_with_retries(url, stream=False, retries=5, timeout=10, backoff=0.5):
    """Download ``url`` with up to ``retries`` attempts, logging each failure.

    Args:
        url: address to GET.
        stream: passed through to ``requests.get``.
        retries: maximum number of attempts in total (not additional retries).
        timeout: per-request timeout in seconds, passed to ``requests.get``.
        backoff: base delay in seconds between attempts; it doubles after each
            failed attempt. Use 0 to retry immediately (the old behavior).

    Returns:
        The successful ``requests.Response``, or ``None`` if every attempt
        failed or the server answered with a non-retryable client error.
    """
    for attempt in range(1, retries + 1):
        response = None
        try:
            logger.info(f"Attempt {attempt} downloading: {url}")
            response = requests.get(url, stream=stream, timeout=timeout)
            response.raise_for_status()
            return response
        except Exception as e:  # best-effort: callers expect None, never a raise
            logger.warning(f"Download failed (attempt {attempt}) for {url}: {e}")
            if response is not None:
                status = response.status_code
                # Release the connection held by the failed (possibly streamed) response.
                response.close()
                # Client errors such as 404 will not fix themselves on retry;
                # 408 (timeout) and 429 (rate limited) may.
                if 400 <= status < 500 and status not in (408, 429):
                    logger.error(f"Non-retryable HTTP {status} for download: {url}")
                    return None
        if attempt < retries and backoff > 0:
            time.sleep(backoff * 2 ** (attempt - 1))
    logger.error(f"All {retries} attempts failed for download: {url}")
    return None
def extract_video_id(url: str) -> str | None:
    """Extract the 11-character YouTube video ID from a URL or a bare ID.

    Recognised forms (scheme and ``www.``/``m.`` prefix optional):
      * ``youtube.com/watch?...v=ID``: ``v`` may follow other query params
      * ``youtu.be/ID``
      * ``youtube.com/embed/ID`` (also ``youtube-nocookie.com``),
        ``/v/ID``, ``/shorts/ID``, ``/live/ID``
      * the bare ``ID`` itself (the whole input, surrounding whitespace ignored)

    Returns:
        The video ID, or ``None`` if the input matches none of the forms.
    """
    if not url:
        logger.warning(f"Could not extract video ID from URL: {url}")
        return None

    id_group = r'([a-zA-Z0-9_-]{11})'
    prefix = r'(?:https?://)?(?:www\.|m\.)?'
    patterns = [
        # Standard watch URL; `v` may appear after other query parameters.
        prefix + r'youtube\.com/watch\?(?:[^#]*&)?v=' + id_group,
        # Shortened youtu.be URL
        prefix + r'youtu\.be/' + id_group,
        # Embed (incl. privacy-enhanced domain), legacy /v/, Shorts and Live paths
        prefix + r'youtube(?:-nocookie)?\.com/(?:embed|v|shorts|live)/' + id_group,
    ]

    candidate = url.strip()
    for pattern in patterns:
        match = re.search(pattern, candidate)
        if match:
            logger.info(f"Extracted video ID: {match.group(1)}")
            return match.group(1)

    # A bare ID must be the entire input. The previous unanchored fallback
    # returned any 11-character run, e.g. "invalid-url" or "youtube-noc".
    if re.fullmatch(r'[a-zA-Z0-9_-]{11}', candidate):
        logger.info(f"Extracted video ID: {candidate}")
        return candidate

    logger.warning(f"Could not extract video ID from URL: {url}")
    return None
# --- Direct Transcript Fetching Function ---
def get_youtube_transcript(youtube_url: str) -> str | None:
    """
    Retrieve the transcript of a YouTube video via youtube-transcript-api.

    Works with both the >=1.0 instance API (``YouTubeTranscriptApi().fetch``)
    and the older static ``YouTubeTranscriptApi.get_transcript``.

    Args:
        youtube_url: The URL (or bare ID) of the YouTube video.
    Returns:
        The transcript as a single space-joined string, or None if the ID
        cannot be extracted or the transcript cannot be fetched.
    """
    logger.info("Tool called: get_youtube_transcript")
    video_id = extract_video_id(youtube_url)
    if not video_id:
        logger.error("Invalid YouTube URL or could not extract Video ID.")
        return None  # None signals failure to the caller
    try:
        logger.info(f"Fetching transcript for video ID: {video_id}")
        if hasattr(YouTubeTranscriptApi, "fetch"):
            # youtube-transcript-api >= 1.0: iterable of snippet objects with `.text`.
            texts = [snippet.text for snippet in YouTubeTranscriptApi().fetch(video_id)]
        else:
            # Older releases: static method returning a list of dicts.
            texts = [item['text'] for item in YouTubeTranscriptApi.get_transcript(video_id)]
        transcript = " ".join(texts)
        logger.info(f"Transcript fetched successfully (length: {len(transcript)} chars).")
        return transcript
    except TranscriptsDisabled:
        logger.error(f"Transcripts are disabled for video: {youtube_url}")
        return None
    except NoTranscriptFound:
        logger.error(f"No transcript found for video: {youtube_url}. Might be unavailable or in an unsupported language.")
        return None
    except Exception as e:
        # Any other failure (network, API changes, etc.) is logged with traceback.
        logger.error(f"An unexpected error occurred fetching transcript for {youtube_url}: {e}", exc_info=True)
        return None
# test_url_with_transcript = "https://www.youtube.com/watch?v=dQw4w9WgXcQ" # Example (Rick Astley)
# test_url_no_transcript = "https://www.youtube.com/watch?v=some_video_without_transcripts" # Placeholder
# test_url_invalid = "htp:/invalid-url"
# print(f"\nTesting URL: {test_url_with_transcript}")
# transcript1 = get_youtube_transcript(test_url_with_transcript)
# if transcript1:
# print("Transcript (first 500 chars):", transcript1[:500])
# else:
# print("Failed to get transcript.")
# print(f"\nTesting URL: {test_url_no_transcript}") # Uncomment to test known non-transcript video
# transcript2 = get_youtube_transcript(test_url_no_transcript)
# if transcript2:
# print("Transcript:", transcript2[:500])
# else:
# print("Failed to get transcript.")
# print(f"\nTesting URL: {test_url_invalid}")
# transcript3 = get_youtube_transcript(test_url_invalid)
# if transcript3:
# print("Transcript:", transcript3[:500])
# else:
# print("Failed to get transcript.")
# NOTE: the triple-quoted string below is a bare string expression, so its
# contents are never executed. It preserves an earlier get_audio_attachment
# that returned (audio_bytes, content_type) without retries; it does not
# shadow the active get_audio_attachment defined above.
"""
def get_audio_attachment(state: QuestionState):
response = requests.get(state["attachment_url"], stream=True)
response.raise_for_status()
audio_bytes = response.content
return audio_bytes, response.headers.get('Content-Type')
"""
# def load_attachment_for_llm(url):
# response = requests.get(url)
# content_type = response.headers.get('content-type') or mimetypes.guess_type(url)[0]
# if content_type:
# if content_type.startswith('image/'):
# return Image.open(io.BytesIO(response.content))
# elif content_type.startswith('audio/'):
# return io.BytesIO(response.content)
# elif content_type.startswith('text/'):
# return response.text
# # Add more handlers as needed (e.g., PDF, Excel)
# # Fallback: return bytes
# return io.BytesIO(response.content)
# def get_attachment(state: QuestionState):
# # """Retrieves and loads the attachment for the current question."""
# api_url = DEFAULT_API_URL
# attachment_url = f"{api_url}/files/{state.task_id}"
# # Store the URL in the state
# state.attachment_url = attachment_url
# # Load the attachment (image, audio, text, etc.)
# attachment = load_attachment_for_llm(attachment_url)
# # Store the loaded attachment in the state
# state.attachment = attachment
# # Return updated fields as a dict (LangGraph expects this)
# return {
# "attachment_url": attachment_url,
# "attachment": attachment
# }
# return {"attachment": io.BytesIO(response.content)}