| | from typing import List, Dict |
| | import requests |
| | import numpy as np |
| | from elasticsearch import Elasticsearch |
| | import urllib3 |
| | from dotenv import load_dotenv |
| | import os |
| |
|
# Load variables from a .env file so the os.getenv() lookups in
# VectorStore.__init__ (PASSWORD, API_KEY, BASE_URL) can find them.
load_dotenv()

# The Elasticsearch client below is built with verify_certs=False; silence
# the InsecureRequestWarning category so it does not flood the output.
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)
| |
|
class VectorStore:
    """Elasticsearch-backed store for text chunks and their embedding vectors.

    Each stored document has three fields: ``content`` (the raw text),
    ``vector`` (the embedding of ``content``) and ``metadata`` (file name,
    source, page and image URL).
    """

    # Model name sent to the embeddings endpoint.
    EMBEDDING_MODEL = "BAAI/bge-m3"
    # Size of the ``dense_vector`` mapping; must equal the length of the
    # vectors returned for EMBEDDING_MODEL.
    EMBEDDING_DIMS = 1024
    # Seconds to wait for the embeddings endpoint before giving up.
    EMBEDDING_TIMEOUT = 30

    def __init__(self):
        """Create the Elasticsearch client and read the embedding API settings.

        Reads ``PASSWORD``, ``API_KEY`` and ``BASE_URL`` from the environment.
        """
        # NOTE(review): verify_certs=False disables TLS certificate
        # verification for the cluster connection — confirm this is intended
        # outside of development.
        self.es = Elasticsearch(
            "https://elastic.aixiao.xyz",
            basic_auth=("elastic", os.getenv("PASSWORD")),
            verify_certs=False,
            request_timeout=30,
            headers={"accept": "application/vnd.elasticsearch+json; compatible-with=8"},
        )
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")

    def get_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for *text* from the SiliconFlow API.

        Sends a POST request to ``{BASE_URL}/embeddings`` using the model
        named by ``EMBEDDING_MODEL``.

        Args:
            text: The text to embed.

        Returns:
            The ``embedding`` list of the first item in the response's
            ``data`` array.

        Raises:
            RuntimeError: If the API responds with a status other than 200.
                The message carries the response body.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {"model": self.EMBEDDING_MODEL, "input": text}

        # Without an explicit timeout, requests waits forever on a stalled
        # connection, which would hang every caller of store().
        response = requests.post(
            f"{self.api_base}/embeddings",
            headers=headers,
            json=payload,
            timeout=self.EMBEDDING_TIMEOUT,
        )

        if response.status_code != 200:
            raise RuntimeError(f"Error getting embedding: {response.text}")
        return response.json()["data"][0]["embedding"]

    def store(self, documents: List[Dict], index_name: str) -> None:
        """Embed *documents* and write them to the Elasticsearch index.

        The index is created first when it does not exist yet. All documents
        are sent in a single bulk request with ``refresh=True``.

        Args:
            documents: Dicts with a ``content`` string and an optional
                ``metadata`` dict (``file_name``, ``source``, ``page``,
                ``img_url``). Missing metadata fields fall back to defaults.
            index_name: Name of the target index.

        Bulk item failures are printed, not raised.
        """
        if not self.es.indices.exists(index=index_name):
            self.create_index(index_name)

        operations = []
        for doc in documents:
            # Tolerate documents without metadata (or with metadata=None):
            # every field below already has a default.
            metadata = doc.get("metadata") or {}

            # No "_id" is given, so Elasticsearch assigns a unique one.
            # Deriving IDs from the current document count reused the IDs of
            # surviving documents once anything had been deleted, silently
            # overwriting them.
            operations.append({"index": {"_index": index_name}})
            operations.append({
                "content": doc["content"],
                "vector": self.get_embedding(doc["content"]),
                "metadata": {
                    "file_name": metadata.get("file_name", "未知文件"),
                    "source": metadata.get("source", ""),
                    "page": metadata.get("page", ""),
                    "img_url": metadata.get("img_url", ""),
                },
            })

        if not operations:
            return

        response = self.es.bulk(operations=operations, refresh=True)
        if response.get("errors"):
            print("批量写入时出现错误:", response)

    def get_files_in_index(self, index_name: str) -> List[str]:
        """Return the sorted distinct ``metadata.file_name`` values of an index.

        Uses a terms aggregation limited to 1000 buckets, so at most 1000
        file names are returned. On any error the error is printed and an
        empty list is returned.
        """
        query = {
            "size": 0,  # only the aggregation is needed, no hits
            "aggs": {
                "unique_files": {
                    "terms": {
                        "field": "metadata.file_name",
                        "size": 1000,
                    }
                }
            },
        }
        try:
            response = self.es.search(index=index_name, body=query)
            buckets = response["aggregations"]["unique_files"]["buckets"]
            return sorted(bucket["key"] for bucket in buckets)
        except Exception as e:
            print(f"获取文件列表时出错: {str(e)}")
            return []

    def create_index(self, index_name: str):
        """Create the index with the content/vector/metadata mapping.

        WARNING: if an index named *index_name* already exists it is deleted
        first, so all documents in it are lost.
        """
        settings = {
            "mappings": {
                "properties": {
                    "content": {"type": "text"},
                    "vector": {
                        "type": "dense_vector",
                        "dims": self.EMBEDDING_DIMS,
                    },
                    "metadata": {
                        "properties": {
                            "file_name": {
                                "type": "keyword",
                                "ignore_above": 256,
                            },
                            "source": {"type": "keyword"},
                            "page": {"type": "keyword"},
                            "img_url": {
                                "type": "keyword",
                                "ignore_above": 2048,
                            },
                        }
                    },
                }
            }
        }

        # Drop any existing index so the new one starts empty with this mapping.
        if self.es.indices.exists(index=index_name):
            self.es.indices.delete(index=index_name)

        self.es.indices.create(index=index_name, body=settings)

    def delete_index(self, index_id: str) -> bool:
        """Delete the index named *index_id*.

        Returns:
            True if the index existed and was deleted; False if it did not
            exist or an error occurred (the error is printed).
        """
        try:
            if self.es.indices.exists(index=index_id):
                self.es.indices.delete(index=index_id)
                return True
            return False
        except Exception as e:
            print(f"删除索引时出错: {str(e)}")
            return False

    def delete_document(self, index_id: str, file_name: str) -> bool:
        """Delete every document whose ``metadata.file_name`` equals *file_name*.

        Returns:
            True if the delete-by-query request completed (even when nothing
            matched); False if it raised (the error is printed).
        """
        query = {"query": {"term": {"metadata.file_name": file_name}}}
        try:
            self.es.delete_by_query(index=index_id, body=query, refresh=True)
            return True
        except Exception as e:
            print(f"删除文档时出错: {str(e)}")
            return False