| | import httpx |
| | import os |
| | import time |
| | import subprocess |
| | import uuid |
| | from loguru import logger |
| | from typing import Any, List, Tuple, Type, Literal, Optional, Union, Dict |
| | import httpx |
| | import os |
| | import time |
| | import subprocess |
| | import uuid |
| | import streamlit as st |
| | from openai import OpenAI |
| | import base64 |
| | from tqdm import tqdm |
| |
|
| | from app.config import app_settings |
| |
|
| | from app.qdrant_db import MyQdrantClient |
| |
|
| | from app.vdr_utils import ( |
| | get_text_embedding, |
| | get_image_embedding, |
| | pdf_folder_to_images, |
| | scale_image, |
| | pil_image_to_base64, |
| | load_images, |
| | ) |
| |
|
class VDRSession:
    """Per-user visual document retrieval (VDR) session.

    Wraps an OpenAI-compatible API client plus a Qdrant vector store to
    index PDF pages as images and answer questions over the retrieved
    pages with a vision-language model (VLM).
    """

    def __init__(self):
        # OpenAI-compatible client; populated by a successful set_api_key().
        self.client = None
        self.api_key = None
        self.base_url = app_settings.GLOBAL_API_BASE
        # Lazily-created per-session scratch folder (see set_context()).
        self.SAVE_DIR = None
        self.db_collection = None
        # Short random id used to namespace folders and collection names.
        self.session_id = str(uuid.uuid4())[:5]
        # Images already embedded + upserted, in upsert order; vector-db hit
        # indices map back into this list (see search_images()).
        self.indexed_images = []
        self.model_name_list = []
        self.vector_db_client = None

    def set_api_key(self, api_key: str) -> bool:
        """Validate *api_key* against the backend and keep the client on success.

        Returns:
            True when the key can list models, False otherwise.
        """
        if api_key is not None and len(api_key) > 10:
            try:
                api_key = api_key.strip()
                client = OpenAI(api_key=api_key, base_url=self.base_url)
                # Listing models is a cheap round-trip that proves the key works.
                models = client.models.list()
                if models:
                    self.api_key = api_key
                    self.client = client
                    return True
            except Exception as e:
                logger.debug(f'Incorrect API Key: {e}')

        self.client = None
        return False

    def set_context(self, embed_model: str) -> bool:
        """Prepare scratch folders and the Qdrant collection for *embed_model*.

        Idempotent: the folder, client and collection are each created once.
        Creation errors are logged (as before); the collection name is now
        only remembered after a successful creation, so a failed attempt is
        retried on the next call instead of leaving the session pointing at
        a missing collection.
        """
        self.embed_model = embed_model

        if not self.SAVE_DIR:
            self.SAVE_DIR = os.path.join('./temp_data', self.session_id)
            os.makedirs(self.SAVE_DIR, exist_ok=True)
            self.SAVE_IMAGE_DIR = os.path.join(self.SAVE_DIR, 'images')
            logger.debug(f'Created folder: {self.SAVE_DIR} and {self.SAVE_IMAGE_DIR}')

        if not self.vector_db_client:
            self.vector_db_client = MyQdrantClient(path=self.SAVE_DIR)

        if not self.db_collection:
            collection_name = f"qd-{embed_model}-{self.session_id}"
            try:
                # Each embedding model needs a matching vector layout
                # (multi-vector ColBERT vs. a single dense vector).
                if self.embed_model == "tsi-embedding-colqwen2-2b-v1":
                    self.vector_db_client.create_collection(collection_name, vector_dim=128, vector_type="colbert")
                elif self.embed_model == "jina-embedding-clip-v1":
                    self.vector_db_client.create_collection(collection_name, vector_dim=768, vector_type="dense")
                else:
                    raise ValueError(f"Embedding model {self.embed_model} not supported")
                # Assign only after creation succeeded (fixes stale state).
                self.db_collection = collection_name
            except Exception as e:
                logger.error(f"Error while creating collection: {e}")

        return True

    def get_available_vlms(self) -> List[str]:
        """Return (and cache) backend model ids usable as vision LLMs.

        Only ids containing one of the known VLM markers are kept; the
        result is cached on the session after the first successful call.
        """
        assert self.client is not None

        if self.model_name_list:
            return self.model_name_list
        # Substrings identifying vision-capable chat models on this backend.
        vlm_markers = ('gemini-2.0', 'claude', 'Qwen2.5-VL-72B-Instruct')
        try:
            models = self.client.models.list()
            self.model_name_list.extend(
                model.id for model in models.data
                if any(marker in model.id for marker in vlm_markers)
            )
        except Exception as e:
            logger.error(f"Error while query all models: {e}")
            raise e
        return self.model_name_list

    def get_available_image_embeds(self) -> List[str]:
        """Return backend model ids usable as image-embedding models."""
        assert self.client is not None
        # Substrings identifying embedding models on this backend.
        embed_markers = ('tsi-embedding', 'clip')
        try:
            models = self.client.models.list()
            return [
                model.id for model in models.data
                if any(marker in model.id for marker in embed_markers)
            ]
        except Exception as e:
            logger.error(f"Error while query all models: {e}")
            raise e

    def search_images(self, text: str, top_k: int = 5) -> list[str]:
        """Embed *text* and return the top_k matching pages as data-URL strings.

        Returns an empty list for queries shorter than 2 characters (the old
        code returned ``False`` despite the declared ``list[str]`` type; both
        are falsy, so truthiness checks in callers are unaffected).

        Raises:
            Exception: when nothing has been indexed yet, or on backend errors.
        """
        assert self.client is not None
        assert self.vector_db_client is not None
        try:
            if not self.indexed_images:
                raise Exception("No indexed images found. You need to click on 'Add selected context' button to index images.")
            text = text.strip()
            if len(text) < 2:
                return []

            embeddings = get_text_embedding(
                texts=text,
                openai_client=self.client,
                model=self.embed_model
            )[0]

            index_results = self.vector_db_client.query_multivector(
                multivector_input=embeddings,
                collection_name=self.db_collection,
                top_k=top_k
            )
            # Hits are indices into indexed_images (upsert order).
            image_list = [self.indexed_images[i] for i in index_results]
            return [
                f"data:image/png;base64,{pil_image_to_base64(img)}"
                for img in image_list
            ]
        except Exception as e:
            logger.error(f"Error while generating image: {e}")
            raise e

    def ask(self, query: str, model: str, prompt_template: str, retrieved_context: Any, modality: str = "image", stream: bool = False):
        """Ask *model* about *query*, grounded on *retrieved_context* images.

        Because this function body contains ``yield`` it is ALWAYS a
        generator: iterate it to receive the answer. With ``stream=True``
        chunks arrive incrementally; with ``stream=False`` the full answer
        is yielded once. (The old non-streaming branch used ``return``,
        whose value is swallowed by the generator protocol, so callers
        received nothing.)

        Args:
            query: the user question, substituted into *prompt_template*.
            prompt_template: format string with a ``{user_question}`` slot.
            retrieved_context: iterable of base64 data-URL image strings.
            modality: only ``"image"`` is supported.
        """
        assert self.client is not None
        assert query is not None
        assert prompt_template is not None
        assert retrieved_context is not None

        try:
            prompt = prompt_template.format(user_question=query)
            if modality != "image":
                # Previously this fell through to a NameError on `messages`.
                raise ValueError(f"Unsupported modality: {modality}")

            context = [
                {"type": "image_url", "image_url": {"url": base64_image}}
                for base64_image in retrieved_context
            ]
            content = [{"type": "text", "text": prompt}] + context
            messages = [{"role": "user", "content": content}]

            chat_response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.1,
                max_tokens=2048,
                stream=stream,
            )
            if not stream:
                # Yield (not return) so the value survives the generator protocol.
                yield chat_response.choices[0].message.content
            else:
                for chunk in chat_response:
                    if chunk.choices and chunk.choices[0].delta.content is not None:
                        yield chunk.choices[0].delta.content
        except Exception as e:
            logger.error(f"Error while asking: {e}")
            raise e

    def indexing(self, uploaded_files: list[str], embed_model: str, indexing_bar: Optional["st.progress"] = None) -> bool:
        """Persist *uploaded_files*, render them to page images and index them.

        Args:
            uploaded_files: Streamlit UploadedFile objects (use .name / .getvalue()).
            embed_model: embedding model id, forwarded to set_context().
            indexing_bar: optional Streamlit progress bar to update.
        """
        self.set_context(embed_model)

        assert self.client is not None
        assert self.db_collection is not None
        assert self.SAVE_DIR is not None
        assert self.embed_model is not None
        assert len(uploaded_files) > 0

        for file in uploaded_files:
            path = os.path.join(self.SAVE_DIR, file.name)
            if os.path.exists(path):
                logger.debug(f"File {file.name} already exists, skipping upload")
                continue
            with open(path, "wb") as f:
                f.write(file.getvalue())

        # NOTE(review): this converts every PDF currently in SAVE_DIR, so
        # previously-uploaded documents may be rendered and re-indexed —
        # confirm pdf_folder_to_images deduplicates, or convert per file.
        image_path_list = pdf_folder_to_images(pdf_folder=self.SAVE_DIR, output_folder=self.SAVE_IMAGE_DIR)
        logger.debug(f"Extracted {len(image_path_list)} images from {len(uploaded_files)} files.")

        indexed_images = self.index_from_images(image_path_list, indexing_bar=indexing_bar)
        logger.debug(f"Indexed {len(indexed_images)} images.")

        self.indexed_images.extend(indexed_images)
        return True

    def clear_context(self) -> bool:
        """Drop the vector collection, indexed images and scratch folder."""
        import shutil  # stdlib; local import, used only here

        self.indexed_images = []
        # Guard: clear_context() may run (e.g. from __del__) before any
        # collection/client was ever created.
        if self.vector_db_client is not None and self.db_collection is not None:
            self.vector_db_client.delete_collection(self.db_collection)
        self.db_collection = None
        self.vector_db_client = None

        if self.SAVE_DIR:
            if os.path.exists(self.SAVE_DIR):
                # shutil.rmtree replaces the old `rm -rf` subprocess call:
                # portable and no extra process spawn.
                shutil.rmtree(self.SAVE_DIR, ignore_errors=True)
                logger.debug(f'Removed folder: {self.SAVE_DIR}')
            self.SAVE_DIR = None
        return True

    def __del__(self):
        # Best-effort cleanup; never let destructor errors escape at GC or
        # interpreter shutdown (globals may already be torn down).
        try:
            self.clear_context()
            logger.debug('VDR session is cleaned up.')
        except Exception:
            pass

    def index_from_images(self,
                          images_path_list: list,
                          batch_size: int = 5,
                          indexing_bar: Optional["st.progress"] = None
                          ):
        """Embed page images in batches and upsert them into the collection.

        Failed batches are logged and skipped. Returns the list of (scaled)
        images that were successfully indexed.
        """
        try:
            indexed_images = []
            total_len = len(images_path_list)
            with tqdm(total=total_len, desc="Indexing Progress") as pbar:
                for i in range(0, total_len, batch_size):
                    try:
                        batch = images_path_list[i:min(i + batch_size, total_len)]
                        # Downscale to bound the embedding payload size.
                        batch = [scale_image(x, 768) for x in batch]

                        embeddings = get_image_embedding(
                            image_list=batch,
                            openai_client=self.client,
                            model=self.embed_model
                        )
                        # NOTE(review): `index=i` restarts at 0 on every call;
                        # verify upsert_multivector does not collide ids when
                        # indexing is invoked more than once per session.
                        self.vector_db_client.upsert_multivector(
                            index=i,
                            multivector_input_list=embeddings,
                            collection_name=self.db_collection
                        )

                        indexed_images.extend(batch)

                        done = i + len(batch)
                        # The last batch may be short: advance by the actual
                        # batch size (the old code over-counted by batch_size).
                        pbar.update(len(batch))
                        if indexing_bar is not None:
                            indexing_bar.progress(done / total_len, text=f"Indexing {done}/{total_len}")
                    except Exception as e:
                        logger.exception(f"Error during indexing: {e}")
                        continue
            logger.debug("Indexing complete!")  # was unreachable after `return`
            return indexed_images
        except Exception as e:
            raise Exception(f"Error during indexing: {e}") from e
|
| | |