import asyncio
import contextlib
import hashlib
import json
import logging
import os
import shutil
import tempfile
from datetime import datetime
from typing import List, Dict

from fastapi import APIRouter, Request, UploadFile, File, Form, HTTPException

from helloAgents.tools.async_executor import AsyncToolExecutor
from helloAgents.tools.builtin.rag_tool import RAGTool
from helloAgents.tools.builtin.memory_tool import MemoryTool
from helloAgents.tools.registry import global_registry

from redis_config import get_redis, QUEUE_RAG_QDRANT, QUEUE_RAG_NEO4J, QUEUE_MEMORY
# Redis client shared by the whole module; obtained once, at import time.
redis = get_redis()


# Logger named after this module.
logger = logging.getLogger(__name__)
# Router on which the POST /new_update_file endpoint in this file is registered.
new_updateFile_router = APIRouter()
|
|
| |
| |
| |
# Name of the JSON file, stored inside each namespace directory, that maps
# stored file names to the SHA-256 hash of their content.
MANIFEST_FILE = ".file_manifest.json"
# Number of bytes read per iteration while hashing a file or an upload.
CHUNK_SIZE = 8192
# Size of the semaphore that limits how many uploads of one request are
# saved at the same time.
MAX_PARALLEL_FILES = 4
|
|
| |
| |
| |
|
|
def load_manifest(ns_path: str) -> Dict[str, str]:
    """Load the file-name -> content-hash manifest stored in *ns_path*.

    Args:
        ns_path: Directory of the namespace whose manifest should be read.

    Returns:
        The parsed manifest. An empty dict is returned when the manifest
        file does not exist, cannot be read, is not valid JSON, or does not
        contain a JSON object.
    """
    manifest_path = os.path.join(ns_path, MANIFEST_FILE)
    try:
        with open(manifest_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # Nothing has been stored in this namespace yet.
        return {}
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(f"读取 manifest 失败: {manifest_path}")
        return {}

    if not isinstance(data, dict):
        # Callers iterate the manifest with .items(); any other JSON value
        # (list, null, number, ...) is treated as a corrupt manifest.
        logger.warning(f"读取 manifest 失败: {manifest_path}")
        return {}
    return data
|
|
def save_manifest(ns_path: str, manifest: Dict[str, str]) -> None:
    """Persist *manifest* as JSON inside *ns_path* using write-then-rename.

    The JSON is first written to a sibling ``.tmp`` file and only then moved
    over the real manifest with ``os.replace``, so the manifest file is only
    ever replaced by a completely written one.

    Args:
        ns_path: Existing directory of the namespace.
        manifest: Mapping of stored file name to content hash.

    Raises:
        OSError: If the temporary file cannot be written or renamed.
        TypeError: If *manifest* contains a value JSON cannot serialize.
    """
    manifest_path = os.path.join(ns_path, MANIFEST_FILE)
    temp_path = manifest_path + ".tmp"
    try:
        with open(temp_path, "w", encoding="utf-8") as f:
            json.dump(manifest, f, indent=2, ensure_ascii=False)
        os.replace(temp_path, manifest_path)
    except BaseException:
        # Do not leave a partially written temp file behind; the previous
        # manifest (if any) is untouched at this point.
        with contextlib.suppress(OSError):
            os.remove(temp_path)
        raise
|
|
def calculate_file_hash_stream(file_path: str) -> str:
    """Return the SHA-256 hex digest of the file at *file_path*.

    The file is opened in binary mode and consumed in pieces of CHUNK_SIZE
    bytes, so its content is never held in memory all at once.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:
        while True:
            piece = stream.read(CHUNK_SIZE)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()
|
|
def calculate_upload_hash_stream(file_obj) -> str:
    """Return the SHA-256 hex digest of an uploaded file's content.

    *file_obj* must expose its underlying stream as ``file_obj.file``. The
    stream is rewound before hashing and rewound again afterwards, so the
    caller can still read the complete content once this returns.
    """
    digest = hashlib.sha256()
    file_obj.file.seek(0)
    while True:
        piece = file_obj.file.read(CHUNK_SIZE)
        if not piece:
            break
        digest.update(piece)
    # Leave the stream positioned at the start for the next reader.
    file_obj.file.seek(0)
    return digest.hexdigest()
|
|
async def save_file_with_version_and_deduplicate(
    upload_file: UploadFile,
    user_id: str,
    save_dir: str = "./knowledge_base"
) -> dict:
    """Store one upload under ``save_dir/user_id`` with de-duplication and versioning.

    The upload is hashed first. If the namespace manifest already contains
    that hash, nothing is written and the existing file is reported as a
    duplicate. Otherwise the content is copied to a temporary file inside the
    namespace directory, moved into place, and recorded in the manifest. When
    the file name is already taken, ``_v2``, ``_v3``, ... is appended to the
    stem until a free name is found.

    The body does blocking file I/O and never awaits, so calls scheduled on
    the same event loop run one after another; the manifest read-modify-write
    below relies on that.
    NOTE(review): moving this I/O to a thread, or running several worker
    processes, would need a per-namespace lock around the manifest update.

    Args:
        upload_file: The uploaded file; ``filename`` and ``file`` are used.
        user_id: Namespace; becomes the sub-directory below *save_dir*.
        save_dir: Root directory of the knowledge base.

    Returns:
        A dict that always has ``success`` and ``filename``. On success it
        also has ``is_duplicate``, ``file_path`` and ``message`` (plus
        ``hash`` for newly stored files); on failure it has ``error``. No
        exception derived from ``Exception`` escapes this function.
    """
    temp_file_path = None
    try:
        # The client chooses the file name, so keep only its last path
        # component. "\\" is normalized first because os.path.basename only
        # splits on "/" on POSIX systems.
        raw_name = upload_file.filename or ""
        original_name = os.path.basename(raw_name.replace("\\", "/")).strip()
        if original_name in ("", ".", ".."):
            raise ValueError("文件名无效")

        # The namespace is client input as well: refuse anything that would
        # resolve outside save_dir (".." components or an absolute path).
        ns_path = os.path.join(save_dir, user_id)
        root = os.path.abspath(save_dir)
        if os.path.commonpath([root, os.path.abspath(ns_path)]) != root:
            raise ValueError("命名空间非法")

        os.makedirs(ns_path, exist_ok=True)
        manifest = load_manifest(ns_path)
        new_file_hash = calculate_upload_hash_stream(upload_file)

        # Same content already stored (under any name): skip the write.
        for existing_name, existing_hash in manifest.items():
            if existing_hash == new_file_hash:
                return {
                    "success": True,
                    "is_duplicate": True,
                    "filename": original_name,
                    "file_path": os.path.join(ns_path, existing_name),
                    "message": f"内容已存在(同名:{existing_name}),自动去重"
                }

        # The manifest and the temp file save_manifest writes live in the
        # same directory; an upload must never be stored under those names.
        reserved = {MANIFEST_FILE, MANIFEST_FILE + ".tmp"}
        name, ext = os.path.splitext(original_name)
        target_name = original_name
        version = 1
        while target_name in manifest or target_name in reserved:
            version += 1
            target_name = f"{name}_v{version}{ext}"

        final_path = os.path.join(ns_path, target_name)

        # Write next to the destination so os.replace stays on one filesystem.
        with tempfile.NamedTemporaryFile(dir=ns_path, delete=False) as tmp_file:
            temp_file_path = tmp_file.name
            shutil.copyfileobj(upload_file.file, tmp_file)

        os.replace(temp_file_path, final_path)
        temp_file_path = None  # moved into place; nothing left to clean up

        manifest[target_name] = new_file_hash
        save_manifest(ns_path, manifest)

        return {
            "success": True,
            "is_duplicate": False,
            "filename": original_name,
            "file_path": final_path,
            "hash": new_file_hash,
            "message": f"已保存(版本 v{version})" if version > 1 else "已保存"
        }

    except Exception as e:
        logger.error(f"保存失败 {upload_file.filename}: {str(e)}")
        return {
            "success": False,
            "filename": upload_file.filename,
            "error": str(e)
        }
    finally:
        # Remove the temp file when the copy or the rename did not complete.
        if temp_file_path is not None:
            with contextlib.suppress(OSError):
                os.remove(temp_file_path)
|
|
| |
| |
| |
async def process_uploaded_files(files: List[UploadFile], user_id: str) -> dict:
    """Save a batch of uploads and enqueue the follow-up indexing work.

    Steps:
      1. Save every file through save_file_with_version_and_deduplicate,
         with at most MAX_PARALLEL_FILES saves in flight.
      2. Push a textual summary of the batch to QUEUE_MEMORY.
      3. If at least one new file was stored, push one task to
         QUEUE_RAG_QDRANT and one to QUEUE_RAG_NEO4J. In both tasks the
         ``file_path`` field carries the list of per-file result dicts of
         the newly stored files.

    NOTE(review): ``redis.rpush`` is called without ``await``; this assumes
    get_redis() returns a synchronous client - confirm.

    Args:
        files: Uploaded files of the request.
        user_id: Namespace the files are stored under.

    Returns:
        When nothing new was stored: ``success`` (True only if there were no
        failures), ``msg``, ``duplicate_count``, ``results``, ``errors`` and
        ``duplicates``. Otherwise: ``success``, ``msg`` and ``data`` holding
        ``total``, ``saved``, ``success``, ``duplicate_files``, ``errors``
        and ``namespace``.

    Raises:
        HTTPException: 400 when *files* is empty or *user_id* is blank.
    """
    if not files:
        raise HTTPException(status_code=400, detail="请上传至少一个文件")
    if not user_id.strip():
        raise HTTPException(status_code=400, detail="命名空间不能为空")

    semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)

    async def save_with_limit(f):
        async with semaphore:
            return await save_file_with_version_and_deduplicate(f, user_id)

    # return_exceptions=True: one file blowing up must not abort the batch.
    save_results = await asyncio.gather(
        *(save_with_limit(f) for f in files),
        return_exceptions=True
    )

    saved_files = []
    save_errors = []
    duplicate_files = []

    # gather() keeps input order, so results line up with *files*.
    for upload, res in zip(files, save_results):
        if isinstance(res, BaseException):
            if not isinstance(res, Exception):
                raise res  # cancellation and similar must keep propagating
            logger.error(f"保存失败 {upload.filename}: {str(res)}")
            res = {
                "success": False,
                "filename": upload.filename,
                "error": str(res)
            }
        if not res["success"]:
            save_errors.append(res)
        elif res.get("is_duplicate"):
            duplicate_files.append(res)
        else:
            saved_files.append(res)

    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    total = len(files)
    success_cnt = len(saved_files)
    dup_cnt = len(duplicate_files)
    fail_cnt = len(save_errors)

    success_files = "\n".join([f"- {f['filename']} | 路径:{f['file_path']}" for f in saved_files])
    dup_files_list = "\n".join([f"- {d['filename']}" for d in duplicate_files])
    err_files_list = "\n".join([f"- {e['filename']} | 原因:{e['error']}" for e in save_errors])

    # Human-readable report of the batch, stored as an episodic memory.
    summary_content = "\n".join([
        f"【文件上传总结】{now} | 用户:{user_id}",
        f"总文件数:{total}",
        f"✅ 上传成功:{success_cnt} 个",
        success_files if success_files else '无',
        "",
        f"⚠️ 重复文件:{dup_cnt} 个",
        dup_files_list if dup_files_list else '无',
        "",
        f"❌ 上传失败:{fail_cnt} 个",
        err_files_list if err_files_list else '无',
    ])
    summary_tasks = {
        "action": "add",
        "user_id": user_id,
        "memory_type": "episodic",
        "content": summary_content.strip(),
        "importance": 0.8,
        "session_id": None
    }
    redis.rpush(QUEUE_MEMORY, json.dumps(summary_tasks))

    # Nothing new on disk: there is nothing to index.
    if not saved_files:
        return {
            "success": fail_cnt == 0,
            "msg": f"去重 {dup_cnt} 个 | 失败 {fail_cnt} 个",
            "duplicate_count": dup_cnt,
            "results": [],
            "errors": save_errors,
            "duplicates": duplicate_files
        }

    task = {
        "action": "add_document",
        "file_path": saved_files,
        "user_id": user_id
    }
    redis.rpush(QUEUE_RAG_QDRANT, json.dumps(task))

    task_neo4j = {
        "action": "add_neo4j_document",
        "file_path": saved_files,
        "user_id": user_id
    }
    redis.rpush(QUEUE_RAG_NEO4J, json.dumps(task_neo4j))

    # Partial failures must stay visible to the caller even though the
    # batch as a whole counts as a success.
    msg = f"成功 {success_cnt} | 去重 {dup_cnt}"
    if save_errors:
        msg += f" | 失败 {fail_cnt}"

    return {
        "success": True,
        "msg": msg,
        "data": {
            "total": total,
            "saved": success_cnt,
            "success": saved_files,
            "duplicate_files": duplicate_files,
            "errors": save_errors,
            "namespace": user_id
        }
    }
|
|
| |
| |
| |
@new_updateFile_router.post("/new_update_file")
async def new_update_file(
    request: Request,
    files: List[UploadFile] = File(..., description="支持 txt, md, pdf, docx, doc, json"),
    namespace: str = Form(..., description="命名空间")
) -> dict:
    # POST /new_update_file: multipart upload of one or more files plus the
    # target namespace. `request` is accepted but not used by the handler.
    # All work is delegated to process_uploaded_files, which receives the
    # namespace as its user_id.
    outcome = await process_uploaded_files(files, namespace)
    return outcome