| | from typing import List, Dict, Callable, Optional |
| | from langchain_text_splitters import RecursiveCharacterTextSplitter |
| | from langchain_community.document_loaders import ( |
| | DirectoryLoader, |
| | UnstructuredMarkdownLoader, |
| | PyPDFLoader, |
| | TextLoader |
| | ) |
| | import os |
| | import requests |
| | import base64 |
| | from PIL import Image |
| | import io |
| |
|
class DocumentLoader:
    """Generic document loader.

    Dispatches on the file extension to load Markdown, PDF, plain-text and
    image files into LangChain ``Document`` objects.  Images are described by
    a SiliconFlow-hosted VLM; the generated description becomes the document
    text so images can be indexed alongside text documents.
    """

    # Supported image extensions mapped to their MIME types, so the data URL
    # sent to the VLM is correct (previously everything was sent as JPEG).
    _IMAGE_MIME = {
        '.png': 'image/png',
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.gif': 'image/gif',
        '.bmp': 'image/bmp',
    }

    def __init__(self, file_path: str, original_filename: Optional[str] = None):
        """
        Args:
            file_path: Path of the file on disk (may be a temp file whose
                name differs from the upload's).
            original_filename: The upload's original name (may contain
                non-ASCII characters); defaults to the basename of
                ``file_path``.
        """
        self.file_path = file_path
        self.original_filename = original_filename or os.path.basename(file_path)
        # Take the extension from the ORIGINAL name: the on-disk temp file
        # may have lost or changed it.
        self.extension = os.path.splitext(self.original_filename)[1].lower()
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")

    def process_image(self, image_path: str) -> str:
        """Describe an image with the SiliconFlow VLM model.

        Returns the model's textual description, or the fixed placeholder
        string on any failure (deliberate best-effort: callers still get a
        usable, indexable document).
        """
        try:
            with open(image_path, 'rb') as image_file:
                base64_image = base64.b64encode(image_file.read()).decode('utf-8')

            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            # Use the real MIME type for the data URL; fall back to JPEG for
            # any extension not in the map.
            mime = self._IMAGE_MIME.get(
                os.path.splitext(image_path)[1].lower(), 'image/jpeg'
            )

            response = requests.post(
                f"{self.api_base}/chat/completions",
                headers=headers,
                json={
                    "model": "Qwen/Qwen2.5-VL-72B-Instruct",
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:{mime};base64,{base64_image}",
                                        "detail": "high"
                                    }
                                },
                                {
                                    "type": "text",
                                    "text": "请详细描述这张图片的内容,包括主要对象、场景、活动、颜色、布局等关键信息。"
                                }
                            ]
                        }
                    ],
                    "temperature": 0.7,
                    "max_tokens": 500
                },
                # Without a timeout a stalled endpoint would hang forever.
                timeout=60,
            )

            if response.status_code != 200:
                raise Exception(f"图片处理API调用失败: {response.text}")

            return response.json()["choices"][0]["message"]["content"]

        except Exception as e:
            print(f"处理图片时出错: {str(e)}")
            return "图片处理失败"

    def load(self):
        """Load the file into a list of ``Document`` objects.

        Text formats are tried as UTF-8 first, then GBK on decode errors.
        Images are converted to a text description via :meth:`process_image`.

        Raises:
            ValueError: For unsupported file extensions.
        """
        try:
            print(f"正在加载文件: {self.file_path}, 原始文件名: {self.original_filename}, 扩展名: {self.extension}")

            if self.extension == '.md':
                try:
                    loader = UnstructuredMarkdownLoader(self.file_path, encoding='utf-8')
                    return loader.load()
                except UnicodeDecodeError:
                    # Retry with GBK for legacy Chinese-encoded files.
                    loader = UnstructuredMarkdownLoader(self.file_path, encoding='gbk')
                    return loader.load()
            elif self.extension == '.pdf':
                loader = PyPDFLoader(self.file_path)
                return loader.load()
            elif self.extension == '.txt':
                try:
                    loader = TextLoader(self.file_path, encoding='utf-8')
                    return loader.load()
                except UnicodeDecodeError:
                    # Retry with GBK for legacy Chinese-encoded files.
                    loader = TextLoader(self.file_path, encoding='gbk')
                    return loader.load()
            elif self.extension in ['.png', '.jpg', '.jpeg', '.gif', '.bmp']:
                # Images become a single Document whose text is the VLM's
                # description; keep the absolute path so the UI can render it.
                description = self.process_image(self.file_path)

                from langchain.schema import Document
                doc = Document(
                    page_content=description,
                    metadata={
                        'source': self.file_path,
                        'file_name': self.original_filename,
                        'img_url': os.path.abspath(self.file_path)
                    }
                )
                return [doc]
            else:
                print(f"不支持的文件扩展名: {self.extension}")
                raise ValueError(f"不支持的文件格式: {self.extension}")

        except UnicodeDecodeError:
            # Last-resort GBK retry for decode errors that escaped the
            # per-format handlers above.
            print(f"文件编码错误,尝试其他编码: {self.file_path}")
            if self.extension in ['.md', '.txt']:
                try:
                    loader = TextLoader(self.file_path, encoding='gbk')
                    return loader.load()
                except Exception as e:
                    print(f"尝试GBK编码也失败: {str(e)}")
                    raise
        except Exception as e:
            print(f"加载文件 {self.file_path} 时出错: {str(e)}")
            import traceback
            traceback.print_exc()
            raise
| |
|
class DocumentProcessor:
    """Loads documents from a path and splits them into indexable chunks."""

    def __init__(self):
        # Fixed-size character splitter; the 200-char overlap preserves
        # context across chunk boundaries.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )

    def get_index_name(self, path: str) -> str:
        """Derive an index name (``rag_<lowercased name>``) from a path.

        Directories use their basename; files use their basename without
        the extension.
        """
        base = os.path.basename(path)
        if not os.path.isdir(path):
            base = os.path.splitext(base)[0]
        return f"rag_{base.lower()}"

    def _load_directory(self, path: str, progress_callback: Optional[Callable]):
        """Walk *path* loading every file best-effort.

        Returns a (documents, bytes_seen) pair; bytes_seen feeds the final
        progress report.  Files that fail to load are logged and skipped.
        """
        documents = []
        total_files = sum(len(names) for _, _, names in os.walk(path))
        seen_files = 0
        seen_bytes = 0

        for root, _, names in os.walk(path):
            for name in names:
                file_path = os.path.join(root, name)
                try:
                    if progress_callback:
                        seen_bytes += os.path.getsize(file_path)
                        seen_files += 1
                        progress_callback(seen_bytes, f"处理文件 {seen_files}/{total_files}: {name}")

                    docs = DocumentLoader(file_path, original_filename=name).load()
                    # Tag every document with its display name for the UI.
                    for doc in docs:
                        doc.metadata['file_name'] = name
                    documents.extend(docs)
                except Exception as e:
                    print(f"警告:加载文件 {file_path} 时出错: {str(e)}")
        return documents, seen_bytes

    def _load_single_file(self, path: str, progress_callback: Optional[Callable],
                          original_filename: Optional[str]):
        """Load one file; failures are logged and re-raised."""
        try:
            display_name = original_filename or os.path.basename(path)
            if progress_callback:
                size = os.path.getsize(path)
                progress_callback(size * 0.3, f"加载文件: {display_name}")

            documents = DocumentLoader(path, original_filename=original_filename).load()

            if progress_callback:
                progress_callback(size * 0.6, "处理文件内容...")

            # Tag every document with its display name for the UI.
            for doc in documents:
                doc.metadata['file_name'] = display_name
            return documents
        except Exception as e:
            print(f"加载文件时出错: {str(e)}")
            raise

    def process(self, path: str, progress_callback: Optional[Callable] = None, original_filename: str = None) -> List[Dict]:
        """
        加载并处理文档,支持目录或单个文件
        参数:
            path: 文档路径
            progress_callback: 进度回调函数,用于报告处理进度
            original_filename: 原始文件名(包括中文)
        返回:处理后的文档列表
        """
        if os.path.isdir(path):
            documents, done_bytes = self._load_directory(path, progress_callback)
        else:
            documents = self._load_single_file(path, progress_callback, original_filename)
            done_bytes = None  # marks the single-file case for the report below

        splits = self.text_splitter.split_documents(documents)

        if progress_callback:
            if done_bytes is not None:
                progress_callback(done_bytes, f"文档分块完成,共{len(splits)}个文档片段")
            else:
                progress_callback(os.path.getsize(path) * 0.9, f"文档分块完成,共{len(splits)}个文档片段")

        # Flatten chunks into plain dicts for the indexing layer.
        return [
            {
                'id': f'doc_{i}',
                'content': chunk.page_content,
                'metadata': chunk.metadata,
            }
            for i, chunk in enumerate(splits)
        ]