# Hugging Face Spaces capture: the Space reported "Runtime error"; the recovered source follows.
# Standard library
import os
from pprint import pprint

# Third-party
import google.generativeai as genai
from dotenv import load_dotenv
from IPython.display import Image, display
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode, tools_condition

# Local
from state import QuestionState
from tools import (
    search_web,
    search_wikipedia,
    get_image_attachment,
    get_audio_attachment,
    get_youtube_transcript,
    get_excel_attachment,
    get_attachment,
)
# Load variables from a local .env file into the process environment.
load_dotenv()

# Set up the Gemini API key from the environment; fail fast if it is missing.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("Please set the GEMINI_API_KEY environment variable.")
# Reuse the already-validated key instead of re-reading the environment
# (the original called os.getenv("GEMINI_API_KEY") a second time here).
genai.configure(api_key=GEMINI_API_KEY)

# Debug helper — uncomment to list the models available to this key:
# models = genai.list_models()
# for m in models:
#     print(m.name)
# ---------------------------------------------------------------------------
# LLM configuration
# ---------------------------------------------------------------------------
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-pro-preview-03-25",  # or "gemini-1.5-pro", "gemini-2.0-flash-001", etc.
    temperature=0.7,
    google_api_key=GEMINI_API_KEY,  # or set the GOOGLE_API_KEY environment variable
)

# ---------------------------------------------------------------------------
# Tools — bound to the LLM so it can emit tool calls for them.
# ---------------------------------------------------------------------------
tools = [
    search_web,
    search_wikipedia,
    get_image_attachment,
    get_audio_attachment,
    get_youtube_transcript,
    get_excel_attachment,
    get_attachment,
]
llm_with_tools = llm.bind_tools(tools)

# GAIA-style answering instructions; the {attachment_url} placeholder is
# filled in per-question by the assistant node.
sys_msg = SystemMessage(content="You are answering questions from the GAIA benchmark. You have access to tools to search the web and wikipedia to answer these questions. Answer with the shortest most concise response possible.You are a general AI assistant. I will ask you a question. Your answer should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. If the question refers to an attachment or media not included in the question, such as 'what is in this picture' get the attachment and use it to answer the question from this url {attachment_url}")
# GRAPH
# Node
def assistant(state: QuestionState):
    """Assistant node: invoke the tool-enabled LLM on the conversation.

    Fills the system prompt's {attachment_url} placeholder from the state,
    then calls the LLM with the system prompt followed by the message
    history. Returns a partial state update appending the LLM's reply.
    """
    # BUG FIX: the formatted prompt must be wrapped in a SystemMessage.
    # The original passed a bare str, which LangChain coerces into a
    # *human* message — silently dropping its system-instruction role.
    system_prompt = SystemMessage(
        content=sys_msg.content.format(attachment_url=state["attachment_url"])
    )
    return {"messages": [llm_with_tools.invoke([system_prompt] + state["messages"])]}
# Build graph: START -> assistant <-> tools, ending when no tool call is made.
builder = StateGraph(QuestionState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
# tools_condition routes to "tools" when the latest assistant message is a
# tool call, and to END otherwise.
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
graph = builder.compile()
# Manual smoke test: ask the agent about an image attachment.
message = {
    "role": "user",
    "content": [
        {
            "type": "text",
            "text": "What is in the attached image?",
        }
    ],
}
testState = QuestionState(
    # Typo fix: was "What is int the attached image?" — now consistent
    # with the message text above.
    question="What is in the attached image?",
    attachment_url="https://img.freepik.com/premium-photo/cyberpunk-anime-robot-girl_1162089-424.jpg",
    messages=[message],  # HumanMessage(content="What is in the attached image?")]
)
result = graph.invoke(testState)
print(result['messages'][-1].content)
# View
# display(Image(graph.get_graph().draw_mermaid_png()))
# Dead example code kept for reference. The triple-quoted string below is a
# no-op expression at import time and is never executed.
# NOTE(review): the original indentation inside this string was lost in the
# paste (loop/function bodies appear flattened) — restore before reusing.
'''
messages = [HumanMessage(content="Hello, what is 2 multiplied by 2?")]
messages = graph.invoke({"messages": messages})
for m in messages['messages']:
m.pretty_print()
messages = [AIMessage(content=f"So you said you were researching ocean mammals?", name="Model")]
messages.append(HumanMessage(content=f"Yes, that's right.",name="Lance"))
messages.append(AIMessage(content=f"Great, what would you like to learn about.", name="Model"))
messages.append(HumanMessage(content=f"I want to learn about the best place to see Orcas in the US.", name="Lance"))
for m in messages:
m.pretty_print()
result = llm.invoke(messages)
type(result)
result.pretty_print()
def gemini_llm(state):
prompt = state["question"]
model = genai.GenerativeModel("gemini-pro")
response = model.generate_content(prompt)
return {"question": prompt, "answer": response.text}
# Define the state schema
state_schema = {"question": str, "answer": str}
# Build the LangGraph graph
graph = StateGraph(state_schema)
graph.add_node("gemini", gemini_llm)
graph.set_entry_point("gemini")
graph.add_edge("gemini", END)
graph = graph.compile()
if __name__ == "__main__":
user_prompt = input("Enter your prompt: ")
result = graph.invoke({"question": user_prompt, "answer": ""})
print("Gemini response:", result["answer"])
'''