| | from PIL import Image |
| | import io |
| | import streamlit as st |
| | import google.generativeai as genai |
| |
|
| | safety_settings = [ |
| | { |
| | "category": "HARM_CATEGORY_HARASSMENT", |
| | "threshold": "BLOCK_NONE" |
| | }, |
| | { |
| | "category": "HARM_CATEGORY_HATE_SPEECH", |
| | "threshold": "BLOCK_NONE" |
| | }, |
| | { |
| | "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", |
| | "threshold": "BLOCK_NONE" |
| | }, |
| | { |
| | "category": "HARM_CATEGORY_DANGEROUS_CONTENT", |
| | "threshold": "BLOCK_NONE" |
| | }, |
| | ] |
| |
|
| |
|
| | password_placeholder = st.empty() |
| | password = password_placeholder.text_input("пасскод", type="password") |
| | if password == st.secrets["real_password"]: |
| | password_placeholder.empty() |
| | |
| |
|
| | with st.sidebar: |
| | st.title("Gemini Pro") |
| | |
| | CONFIG = { |
| | "temperature": 0.5, |
| | "top_p": 1, |
| | "top_k": 32, |
| | "max_output_tokens": 4096, |
| | } |
| | |
| | genai.configure(api_key=st.secrets["api_key"]) |
| |
|
| | uploaded_image = st.file_uploader( |
| | label="загрузи изображение", |
| | label_visibility="visible", |
| | help="если загружено изображение - можно спрашивать по нему что-то, если нет - будет обычный чат", |
| | accept_multiple_files=False, |
| | type=["png", "jpg"], |
| | ) |
| |
|
| | if uploaded_image: |
| | image_bytes = uploaded_image.read() |
| |
|
| |
|
| | def get_response(messages, model="gemini-pro"): |
| | try: |
| | model = genai.GenerativeModel(model, generation_config=genai.GenerationConfig(candidate_count=1, max_output_tokens=4096, temperature=0.6)) |
| | res = model.generate_content(messages, stream=True, safety_settings=safety_settings) |
| | return res |
| | except: |
| | return "Извини, но запрос не прошел цензуру." |
| |
|
| |
|
| | if "messages" not in st.session_state: |
| | st.session_state["messages"] = [] |
| | messages = st.session_state["messages"] |
| |
|
| | if messages: |
| | for item in messages: |
| | role, parts = item.values() |
| | if role == "user": |
| | st.chat_message("user").markdown(parts[0]) |
| | elif role == "model": |
| | st.chat_message("assistant").markdown(parts[0]) |
| |
|
| | chat_message = st.chat_input("Спроси что-нибудь!") |
| |
|
| | if chat_message: |
| | st.chat_message("user").markdown(chat_message) |
| | res_area = st.chat_message("assistant").empty() |
| |
|
| | if "image_bytes" in globals(): |
| | vision_message = [chat_message, Image.open(io.BytesIO(image_bytes))] |
| | res = get_response(vision_message, model="gemini-pro-vision") |
| | else: |
| | vision_message = [{"role": "user", "parts": [chat_message]}] |
| | res = get_response(vision_message) |
| |
|
| | res_text = "" |
| | try: |
| | for chunk in res: |
| | res_text += chunk.text |
| | res_area.markdown(res_text) |
| | except: |
| | res_text += f"запрос не прошел цензуру:\n{str(res.prompt_feedback)}" |
| | res_area.markdown(res_text) |
| |
|
| |
|
| | messages.append({"role": "model", "parts": [res_text]}) |
| | else: |
| | st.warning("неправильный пароль, увы...") |
| |
|