# File size: 8,114 Bytes
# 8b383ad
from fastapi import APIRouter, Request, UploadFile, File, Form, HTTPException
from typing import List, Dict
from datetime import datetime
import logging
import os
import asyncio
import hashlib
import json
import tempfile
import shutil
from helloAgents.tools.async_executor import AsyncToolExecutor
from helloAgents.tools.builtin.rag_tool import RAGTool
from helloAgents.tools.builtin.memory_tool import MemoryTool
from helloAgents.tools.registry import global_registry
from redis_config import get_redis, QUEUE_RAG_QDRANT, QUEUE_RAG_NEO4J, QUEUE_MEMORY
# Shared Redis client from redis_config; used below to push JSON task
# payloads onto the QUEUE_* lists.
redis = get_redis()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router that carries the upload endpoint declared at the bottom of this file.
new_updateFile_router = APIRouter()
# ------------------------------
# ⚙️ Enterprise-grade configuration
# ------------------------------
# Name of the per-namespace JSON file that maps stored filename -> SHA-256 hex digest.
MANIFEST_FILE = ".file_manifest.json"
# Number of bytes read per iteration when hashing a file or an upload stream.
CHUNK_SIZE = 8192
# Size of the semaphore that bounds concurrently running save coroutines.
MAX_PARALLEL_FILES = 4
# ------------------------------
# 🛠️ 核心工具函数
# ------------------------------
def load_manifest(ns_path: str) -> Dict[str, str]:
    """Load the filename -> content-hash manifest stored in ``ns_path``.

    Args:
        ns_path: Directory of a single namespace.

    Returns:
        The parsed manifest. An empty dict is returned when the manifest
        file does not exist, cannot be read, is not valid JSON, or does not
        contain a JSON object.
    """
    manifest_path = os.path.join(ns_path, MANIFEST_FILE)
    try:
        with open(manifest_path, "r", encoding="utf-8") as f:
            manifest = json.load(f)
    except FileNotFoundError:
        # No manifest yet: nothing has been stored in this namespace.
        return {}
    except (OSError, ValueError):
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        logger.warning(f"读取 manifest 失败: {manifest_path}")
        return {}
    if not isinstance(manifest, dict):
        # Valid JSON of the wrong shape (list, string, ...) would break the
        # callers that iterate ``manifest.items()``; treat it as corrupt.
        logger.warning(f"读取 manifest 失败: {manifest_path}")
        return {}
    return manifest
def save_manifest(ns_path: str, manifest: Dict[str, str]):
    """Write ``manifest`` as JSON into ``ns_path`` via a temp file and rename.

    The data is first dumped to ``<manifest>.tmp`` and then moved over the
    real manifest with ``os.replace`` so readers never see a half-written
    file.

    Args:
        ns_path: Directory of a single namespace.
        manifest: Mapping of stored filename -> content hash.

    Raises:
        OSError: If the temp file cannot be written or renamed.
        TypeError: If ``manifest`` contains values JSON cannot serialize.
    """
    manifest_path = os.path.join(ns_path, MANIFEST_FILE)
    temp_path = manifest_path + ".tmp"
    try:
        with open(temp_path, "w", encoding="utf-8") as f:
            json.dump(manifest, f, indent=2, ensure_ascii=False)
        os.replace(temp_path, manifest_path)
    except BaseException:
        # Do not leave a partially written temp file next to the manifest.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise
def calculate_file_hash_stream(file_path: str) -> str:
    """Return the SHA-256 hex digest of the file at ``file_path``.

    The file is opened in binary mode and consumed ``CHUNK_SIZE`` bytes at a
    time, so the whole content is never held in memory at once.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:
        # iter() with a sentinel stops at the first empty read (EOF).
        for block in iter(lambda: stream.read(CHUNK_SIZE), b""):
            digest.update(block)
    return digest.hexdigest()
def calculate_upload_hash_stream(file_obj) -> str:
    """Return the SHA-256 hex digest of an upload's underlying stream.

    ``file_obj`` must expose a seekable binary stream as ``file_obj.file``.
    The stream is rewound before hashing and rewound again afterwards so the
    caller can read the content from the start.
    """
    stream = file_obj.file
    digest = hashlib.sha256()
    stream.seek(0)
    while True:
        block = stream.read(CHUNK_SIZE)
        if not block:
            break
        digest.update(block)
    # Leave the stream positioned at the beginning for the next reader.
    stream.seek(0)
    return digest.hexdigest()
async def save_file_with_version_and_deduplicate(
    upload_file: UploadFile,
    user_id: str,
    save_dir: str = "./knowledge_base"
) -> dict:
    """Store an upload under ``save_dir/user_id`` with de-duplication and versioning.

    The upload is hashed; if the namespace manifest already contains that
    hash, nothing is written and the existing file is reported. Otherwise the
    content is copied to a temp file inside the namespace directory, renamed
    to its final name, and recorded in the manifest. When the desired name is
    already taken, ``_v2``, ``_v3``, ... is appended before the extension.

    Args:
        upload_file: The uploaded file; ``filename`` and ``file`` are read.
        user_id: Namespace; used as a sub-directory of ``save_dir``.
        save_dir: Root directory of the knowledge base.

    Returns:
        A dict with ``success`` and ``filename``. On success it also carries
        ``is_duplicate``, ``file_path`` and ``message`` (plus ``hash`` for
        newly stored files). On failure it carries ``error``. This function
        does not raise for ordinary errors; they are logged and reported in
        the returned dict.
    """
    temp_file_path = None
    try:
        raw_name = upload_file.filename
        if raw_name is None:
            raise ValueError("文件名不能为空")
        original_name = raw_name.strip()

        # The filename is client-controlled: keep only its last path
        # component so it cannot point outside the namespace directory.
        safe_name = os.path.basename(original_name.replace("\\", "/"))
        reserved_names = (MANIFEST_FILE, MANIFEST_FILE + ".tmp")
        if safe_name in ("", ".", "..") or safe_name in reserved_names:
            raise ValueError(f"非法文件名: {original_name!r}")
        name, ext = os.path.splitext(safe_name)

        # The namespace is client-controlled as well: refuse any value that
        # resolves to a location outside ``save_dir``.
        base_dir = os.path.abspath(save_dir)
        ns_path = os.path.join(save_dir, user_id)
        if os.path.commonpath([base_dir, os.path.abspath(ns_path)]) != base_dir:
            raise ValueError(f"非法命名空间: {user_id!r}")

        os.makedirs(ns_path, exist_ok=True)
        manifest = load_manifest(ns_path)

        new_file_hash = calculate_upload_hash_stream(upload_file)
        for existing_name, existing_hash in manifest.items():
            if existing_hash == new_file_hash:
                return {
                    "success": True,
                    "is_duplicate": True,
                    "filename": original_name,
                    "file_path": os.path.join(ns_path, existing_name),
                    "message": f"内容已存在(同名:{existing_name}),自动去重"
                }

        target_name = safe_name
        version = 1
        # Also consult the disk: a file that is missing from the manifest
        # (e.g. after a corrupt manifest was reset) must not be overwritten.
        while target_name in manifest or os.path.exists(
            os.path.join(ns_path, target_name)
        ):
            version += 1
            target_name = f"{name}_v{version}{ext}"
        final_path = os.path.join(ns_path, target_name)

        # Write next to the destination so os.replace stays on one filesystem.
        with tempfile.NamedTemporaryFile(dir=ns_path, delete=False) as tmp_file:
            temp_file_path = tmp_file.name
            shutil.copyfileobj(upload_file.file, tmp_file)
        os.replace(temp_file_path, final_path)
        temp_file_path = None  # ownership moved to final_path

        manifest[target_name] = new_file_hash
        save_manifest(ns_path, manifest)
        return {
            "success": True,
            "is_duplicate": False,
            "filename": original_name,
            "file_path": final_path,
            "hash": new_file_hash,
            "message": f"已保存(版本 v{version})" if version > 1 else "已保存"
        }
    except Exception as e:
        logger.error(f"保存失败 {upload_file.filename}: {str(e)}")
        return {
            "success": False,
            "filename": upload_file.filename,
            "error": str(e)
        }
    finally:
        # Remove the temp file if it was created but never renamed.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
# ------------------------------
# 主处理逻辑(已补全记忆)
# ------------------------------
async def process_uploaded_files(files: List[UploadFile], user_id: str) -> dict:
    """Save a batch of uploads and enqueue the follow-up tasks.

    Steps:
      1. Save every file through ``save_file_with_version_and_deduplicate``,
         with at most ``MAX_PARALLEL_FILES`` save coroutines in flight.
      2. Push one summary record describing the batch onto ``QUEUE_MEMORY``.
      3. If at least one new file was stored, push an ingestion task onto
         ``QUEUE_RAG_QDRANT`` and ``QUEUE_RAG_NEO4J``.

    Args:
        files: Uploaded files; must not be empty.
        user_id: Namespace; must not be blank.

    Returns:
        When nothing new was stored: a dict with ``success``, ``msg``,
        ``duplicate_count``, ``results``, ``errors`` and ``duplicates``.
        Otherwise: a dict with ``success``, ``msg`` and ``data`` (totals,
        the saved and duplicate results, any per-file errors, namespace).

    Raises:
        HTTPException: 400 when ``files`` is empty or ``user_id`` is blank.
    """
    if not files:
        raise HTTPException(status_code=400, detail="请上传至少一个文件")
    if not user_id.strip():
        raise HTTPException(status_code=400, detail="命名空间不能为空")

    semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)

    async def save_with_limit(f):
        async with semaphore:
            return await save_file_with_version_and_deduplicate(f, user_id)

    # return_exceptions=True: one upload that raises must not abort the
    # whole batch; it is reported as a per-file error instead.
    outcomes = await asyncio.gather(
        *(save_with_limit(f) for f in files),
        return_exceptions=True,
    )

    saved_files = []
    save_errors = []
    duplicate_files = []
    for upload, res in zip(files, outcomes):
        if isinstance(res, BaseException):
            logger.error(f"保存失败 {upload.filename}: {res}")
            save_errors.append({
                "success": False,
                "filename": upload.filename,
                "error": str(res)
            })
        elif not res["success"]:
            save_errors.append(res)
        elif res.get("is_duplicate"):
            duplicate_files.append(res)
        else:
            saved_files.append(res)

    # ============================
    # ✅ Slim variant: store a single summary record for the whole batch
    #    (all file paths + per-file status).
    # ============================
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    total = len(files)
    success_cnt = len(saved_files)
    dup_cnt = len(duplicate_files)
    fail_cnt = len(save_errors)

    # Build a readable listing for each outcome category.
    success_files = "\n".join(
        f"- {f['filename']} | 路径:{f['file_path']}" for f in saved_files
    )
    dup_files_list = "\n".join(f"- {d['filename']}" for d in duplicate_files)
    err_files_list = "\n".join(
        f"- {e['filename']} | 原因:{e['error']}" for e in save_errors
    )

    summary_content = "\n".join([
        f"【文件上传总结】{now} | 用户:{user_id}",
        f"总文件数:{total}",
        f"✅ 上传成功:{success_cnt} 个",
        success_files if success_files else '无',
        f"⚠️ 重复文件:{dup_cnt} 个",
        dup_files_list if dup_files_list else '无',
        f"❌ 上传失败:{fail_cnt} 个",
        err_files_list if err_files_list else '无',
    ])
    summary_tasks = {
        "action": "add",
        "user_id": user_id,
        "memory_type": "episodic",
        "content": summary_content.strip(),
        "importance": 0.8,
        "session_id": None
    }
    # NOTE(review): rpush is called without await, which assumes get_redis()
    # returns a synchronous client — confirm against redis_config.
    redis.rpush(QUEUE_MEMORY, json.dumps(summary_tasks))

    if not saved_files:
        return {
            "success": len(save_errors) == 0,
            "msg": f"去重 {len(duplicate_files)} 个 | 失败 {len(save_errors)} 个",
            "duplicate_count": len(duplicate_files),
            "results": [],
            "errors": save_errors,
            "duplicates": duplicate_files
        }

    # RAG ingestion task for qdrant.
    # NOTE(review): "file_path" carries the full list of save-result dicts,
    # not bare paths — presumably what the queue consumer expects; verify.
    task = {
        "action": "add_document",
        "file_path": saved_files,
        "user_id": user_id
    }
    redis.rpush(QUEUE_RAG_QDRANT, json.dumps(task))

    # RAG ingestion task for neo4j.
    task_neo4j = {
        "action": "add_neo4j_document",
        "file_path": saved_files,
        "user_id": user_id
    }
    redis.rpush(QUEUE_RAG_NEO4J, json.dumps(task_neo4j))

    return {
        "success": True,
        "msg": f"成功 {len(saved_files)} | 去重 {len(duplicate_files)}",
        "data": {
            "total": len(files),
            "saved": len(saved_files),
            "success": saved_files,
            "duplicate_files": duplicate_files,
            # Surface partial failures instead of silently dropping them.
            "errors": save_errors,
            "namespace": user_id
        }
    }
# ------------------------------
# 路由
# ------------------------------
@new_updateFile_router.post("/new_update_file")
async def new_update_file(
    request: Request,
    files: List[UploadFile] = File(..., description="支持 txt, md, pdf, docx, doc, json"),
    namespace: str = Form(..., description="命名空间")
) -> dict:
    """POST /new_update_file: hand the uploaded files to the batch processor.

    ``files`` and ``namespace`` are multipart form fields; ``namespace`` is
    forwarded as the ``user_id`` of ``process_uploaded_files``. ``request``
    is accepted but not used by this handler.
    """
    outcome = await process_uploaded_files(files, namespace)
    return outcome