| """ |
| 认证API模块 |
| """ |
|
|
| import asyncio |
| import json |
| import secrets |
| import socket |
| import threading |
| import time |
| import uuid |
| from datetime import timezone |
| from http.server import BaseHTTPRequestHandler, HTTPServer |
| from typing import Any, Dict, List, Optional |
| from urllib.parse import parse_qs, urlparse |
|
|
| from config import get_config_value, get_antigravity_api_url, get_code_assist_endpoint |
| from log import log |
|
|
| from .google_oauth_api import ( |
| Credentials, |
| Flow, |
| enable_required_apis, |
| fetch_project_id, |
| get_user_projects, |
| select_default_project, |
| ) |
| from .storage_adapter import get_storage_adapter |
| from .utils import ( |
| ANTIGRAVITY_CLIENT_ID, |
| ANTIGRAVITY_CLIENT_SECRET, |
| ANTIGRAVITY_SCOPES, |
| ANTIGRAVITY_USER_AGENT, |
| CALLBACK_HOST, |
| CLIENT_ID, |
| CLIENT_SECRET, |
| SCOPES, |
| GEMINICLI_USER_AGENT, |
| TOKEN_URL, |
| ) |
|
|
|
|
async def get_callback_port():
    """Return the configured OAuth callback port as an int.

    Reads the ``oauth_callback_port`` config value (env var ``OAUTH_CALLBACK_PORT``),
    passing ``"11451"`` as the fallback value.
    """
    configured_port = await get_config_value("oauth_callback_port", "11451", "OAUTH_CALLBACK_PORT")
    return int(configured_port)
|
|
|
|
def _prepare_credentials_data(credentials: Credentials, project_id: str, mode: str = "geminicli") -> Dict[str, Any]:
    """Build the serializable credential dict stored for an OAuth login.

    ``mode == "antigravity"`` selects the Antigravity client id/secret/scopes; any
    other mode uses the Gemini CLI ones. When ``credentials.expires_at`` is set it
    is written under ``"expiry"`` in ISO-8601 form, with a naive datetime treated
    as UTC.
    """
    is_antigravity = mode == "antigravity"
    creds_data = {
        "client_id": ANTIGRAVITY_CLIENT_ID if is_antigravity else CLIENT_ID,
        "client_secret": ANTIGRAVITY_CLIENT_SECRET if is_antigravity else CLIENT_SECRET,
        "token": credentials.access_token,
        "refresh_token": credentials.refresh_token,
        "scopes": ANTIGRAVITY_SCOPES if is_antigravity else SCOPES,
        "token_uri": TOKEN_URL,
        "project_id": project_id,
    }

    expires_at = credentials.expires_at
    if expires_at:
        # Naive expiry timestamps are assumed to already be in UTC.
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        creds_data["expiry"] = expires_at.isoformat()

    return creds_data
|
|
|
|
def _generate_random_project_id() -> str:
    """Return a placeholder project path for Antigravity mode.

    The id segment is the first 8 hex characters of a fresh UUID4, giving
    ``"projects/random-<8 hex>/locations/global"``.
    """
    # str(uuid4()) is 8-4-4-4-12 hex groups, so the first group is hex[:8].
    short_id = str(uuid.uuid4()).split("-")[0]
    return "projects/random-{}/locations/global".format(short_id)
|
|
|
|
def _cleanup_auth_flow_server(state: str):
    """Stop the callback server of flow *state* (if any) and drop the flow.

    Shutdown is delegated to ``async_shutdown_server``. An error raised while
    starting it is logged at debug level and does not prevent the flow from
    being removed from ``auth_flows``. Unknown states are ignored.
    """
    if state not in auth_flows:
        return

    entry = auth_flows[state]
    try:
        server = entry.get("server")
        if server:
            async_shutdown_server(server, entry.get("callback_port"))
    except Exception as exc:
        log.debug(f"关闭服务器时出错: {exc}")
    del auth_flows[state]
|
|
|
|
class _OAuthLibPatcher:
    """Context manager that relaxes oauthlib's token-parameter validation.

    While active, ``oauthlib.oauth2.rfc6749.parameters.validate_token_parameters``
    is replaced by a wrapper that swallows a ``Warning`` raised by the original
    validator (presumably oauthlib's "scope has changed" warning — confirm) and
    returns ``None`` instead.

    The patch is module-global and the managed blocks ``await`` inside them, so
    several uses may overlap. A class-level depth counter makes the original
    function get saved on the first entry and restored only when the last
    overlapping use exits. Without it, a later entry would save an earlier
    entry's wrapper as "original" and leave it installed forever.
    """

    # Shared across instances: all instances patch the same oauthlib module.
    _lock = threading.Lock()
    _depth = 0
    _saved_validate = None

    def __init__(self):
        import oauthlib.oauth2.rfc6749.parameters
        self.module = oauthlib.oauth2.rfc6749.parameters
        # Kept for backward compatibility; mirrors the function restored on final exit.
        self.original_validate = None

    def __enter__(self):
        cls = type(self)
        with cls._lock:
            if cls._depth == 0:
                cls._saved_validate = self.module.validate_token_parameters
                original = cls._saved_validate

                def patched_validate(params):
                    try:
                        return original(params)
                    except Warning:
                        # Treat the validator's warning as non-fatal.
                        return None

                self.module.validate_token_parameters = patched_validate
            cls._depth += 1
            self.original_validate = cls._saved_validate
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        cls = type(self)
        with cls._lock:
            if cls._depth == 0:
                # Exit without a matching enter: nothing to restore.
                return False
            cls._depth -= 1
            if cls._depth == 0 and cls._saved_validate is not None:
                self.module.validate_token_parameters = cls._saved_validate
                cls._saved_validate = None
        # Never suppress exceptions from the managed block.
        return False
|
|
|
|
| |
# In-memory registry of OAuth flows, keyed by the OAuth ``state`` string.
auth_flows: Dict[str, Dict[str, Any]] = {}
# Maximum number of tracked flows; create_auth_url evicts the oldest one at this size.
MAX_AUTH_FLOWS = 20
|
|
|
|
def cleanup_auth_flows_for_memory(keep: int = 10):
    """Purge auth flows to free memory, keeping only the newest *keep* flows.

    Expired flows are removed first via ``cleanup_expired_flows``. If more than
    *keep* flows remain, the oldest ones (by ``created_at``; missing counts as 0)
    are dropped. For each dropped flow, its callback server is shut down and
    its dict is cleared. ``auth_flows`` is pruned in place rather than rebound,
    so every holder of a reference to the registry sees the same state.

    Args:
        keep: number of newest flows to retain (negative values act as 0).

    Returns:
        The number of flows remaining.
    """
    cleanup_expired_flows()

    keep = max(keep, 0)
    if len(auth_flows) > keep:
        newest_first = sorted(
            auth_flows.items(), key=lambda item: item[1].get("created_at", 0), reverse=True
        )
        for state, flow_data in newest_first[keep:]:
            auth_flows.pop(state, None)
            try:
                if flow_data.get("server"):
                    async_shutdown_server(flow_data["server"], flow_data.get("callback_port"))
            except Exception as e:
                # Best-effort: a failed shutdown must not stop the purge.
                log.debug(f"关闭服务器时出错: {e}")
            flow_data.clear()

        log.info(f"强制清理认证流程,保留 {len(auth_flows)} 个最新流程")

    return len(auth_flows)
|
|
|
|
async def find_available_port(start_port: Optional[int] = None, max_attempts: int = 100) -> int:
    """Find a TCP port that can currently be bound on all interfaces.

    Tries ``start_port`` .. ``start_port + max_attempts - 1`` in order. The range
    is clipped to 1..65535: ``bind()`` raises ``OverflowError`` (not ``OSError``)
    for out-of-range ports, and binding port 0 would "succeed" with a random
    port while this function returned 0. If no port in the range is free, the
    OS is asked for an ephemeral port.

    The probe socket is closed before returning, so another process could take
    the port before the caller binds it.

    Args:
        start_port: first port to try; defaults to the configured callback port.
        max_attempts: how many consecutive ports to try before falling back.

    Returns:
        A port number that was bindable at probe time.

    Raises:
        RuntimeError: if not even an OS-assigned port can be bound.
    """
    if start_port is None:
        start_port = await get_callback_port()

    first_port = max(start_port, 1)
    end_port = min(start_port + max_attempts, 65536)
    for port in range(first_port, end_port):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.bind(("0.0.0.0", port))
        except OSError:
            continue
        log.info(f"找到可用端口: {port}")
        return port

    # Nothing free in the preferred range: let the OS pick an ephemeral port.
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("0.0.0.0", 0))
            port = s.getsockname()[1]
    except OSError as e:
        log.error(f"无法找到可用端口: {e}")
        raise RuntimeError("无法找到可用端口") from e
    log.info(f"系统分配可用端口: {port}")
    return port
|
|
|
|
def create_callback_server(port: int) -> HTTPServer:
    """Create (but do not start) the OAuth callback server on ``0.0.0.0:port``.

    ``HTTPServer`` binds inside its constructor and already enables
    ``SO_REUSEADDR`` before binding (``allow_reuse_address``). Setting the option
    afterwards has no effect, so no socket options are changed here.
    ``server.timeout`` is set to 1 second; it only affects ``handle_request()``,
    while ``serve_forever()`` uses its own poll interval.

    Raises:
        OSError: if the port cannot be bound (logged, then re-raised).
    """
    try:
        server = HTTPServer(("0.0.0.0", port), AuthCallbackHandler)
    except OSError as e:
        log.error(f"创建端口{port}的服务器失败: {e}")
        raise

    server.timeout = 1.0
    log.info(f"创建OAuth回调服务器,监听端口: {port}")
    return server
|
|
|
|
class AuthCallbackHandler(BaseHTTPRequestHandler):
    """HTTP handler for the local OAuth redirect; runs in the callback-server thread.

    On ``GET`` it reads ``code`` and ``state`` from the query string. If both are
    present and ``state`` names a flow in ``auth_flows``, the code is stored on
    that flow, the flow is marked completed, and a success page is returned.
    Otherwise the response is HTTP 400.
    """

    _SUCCESS_PAGE = (
        b"<h1>OAuth authentication successful!</h1><p>You can close this window. Please return to the original page and click 'Get Credentials' button.</p>"
    )
    _FAILURE_PAGE = b"<h1>Authentication failed.</h1><p>Please try again.</p>"

    def do_GET(self):
        query_components = parse_qs(urlparse(self.path).query)
        code = query_components.get("code", [None])[0]
        state = query_components.get("state", [None])[0]

        log.info(f"收到OAuth回调: code={'已获取' if code else '未获取'}, state={state}")

        # Look the flow up once: auth_flows is mutated by the event-loop thread, so
        # a separate "in" test followed by indexing could race with a cleanup.
        flow_data = auth_flows.get(state) if code and state else None
        if flow_data is not None:
            flow_data["code"] = code
            flow_data["completed"] = True
            log.info(f"OAuth回调成功处理: state={state}")
            self._send_html(200, self._SUCCESS_PAGE)
        else:
            self._send_html(400, self._FAILURE_PAGE)

    def _send_html(self, status: int, body: bytes):
        """Write a complete HTML response with the given status and body."""
        self.send_response(status)
        self.send_header("Content-type", "text/html")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        # Silence BaseHTTPRequestHandler's default per-request stderr logging.
        pass
|
|
|
|
async def create_auth_url(
    project_id: Optional[str] = None, user_session: Optional[str] = None, mode: str = "geminicli"
) -> Dict[str, Any]:
    """Start a new OAuth flow and return its authorization URL.

    Steps:
    - Choose a free port and start a dedicated callback ``HTTPServer`` on it in a
      daemon thread.
    - Build a ``Flow`` for the selected OAuth client.
    - Register the flow in ``auth_flows`` under a fresh ``state``. The oldest
      flow is evicted first when ``MAX_AUTH_FLOWS`` is reached.

    If anything fails before the flow is registered, the callback server is
    closed so its thread and port are not leaked.

    Args:
        project_id: target project; ``None`` marks the flow for automatic
            project detection at completion time.
        user_session: optional session id prefixed to ``state`` so the flow can
            later be matched to the user that started it.
        mode: ``"antigravity"`` selects the Antigravity OAuth client; anything
            else uses the Gemini CLI client.

    Returns:
        On success: ``auth_url``, ``state``, ``callback_port``, ``success=True``,
        ``auto_project_detection``, ``detected_project_id``.
        On failure: ``{"success": False, "error": ...}``.
    """
    callback_server = None
    callback_port = None
    registered = False
    try:
        callback_port = await find_available_port()
        callback_url = f"http://{CALLBACK_HOST}:{callback_port}"

        try:
            callback_server = create_callback_server(callback_port)
            server_thread = threading.Thread(
                target=callback_server.serve_forever,
                daemon=True,
                name=f"OAuth-Server-{callback_port}",
            )
            server_thread.start()
        except Exception as e:
            log.error(f"启动回调服务器失败: {e}")
            if callback_server is not None:
                # serve_forever() never ran, so shutdown() would block forever;
                # only release the listening socket.
                try:
                    callback_server.server_close()
                except Exception as close_err:
                    log.debug(f"关闭服务器时出错: {close_err}")
                callback_server = None
            return {
                "success": False,
                "error": f"无法启动OAuth回调服务器,端口{callback_port}: {str(e)}",
            }
        log.info(f"OAuth回调服务器已启动,端口: {callback_port}")

        if mode == "antigravity":
            client_id, client_secret, scopes = (
                ANTIGRAVITY_CLIENT_ID,
                ANTIGRAVITY_CLIENT_SECRET,
                ANTIGRAVITY_SCOPES,
            )
        else:
            client_id, client_secret, scopes = CLIENT_ID, CLIENT_SECRET, SCOPES

        flow = Flow(
            client_id=client_id,
            client_secret=client_secret,
            scopes=scopes,
            redirect_uri=callback_url,
        )

        state = f"{user_session}_{uuid.uuid4()}" if user_session else str(uuid.uuid4())
        auth_url = flow.get_auth_url(state=state)

        # Bound the registry: evict the oldest flow (and stop its server).
        if len(auth_flows) >= MAX_AUTH_FLOWS:
            oldest_state = min(auth_flows, key=lambda k: auth_flows[k].get("created_at", 0))
            try:
                old_flow = auth_flows[oldest_state]
                if old_flow.get("server"):
                    async_shutdown_server(old_flow["server"], old_flow.get("callback_port"))
            except Exception as e:
                log.warning(f"Failed to cleanup old auth flow {oldest_state}: {e}")
            del auth_flows[oldest_state]
            log.debug(f"Removed oldest auth flow: {oldest_state}")

        auth_flows[state] = {
            "flow": flow,
            "project_id": project_id,
            "user_session": user_session,
            "callback_port": callback_port,
            "callback_url": callback_url,
            "server": callback_server,
            "server_thread": server_thread,
            "code": None,
            "completed": False,
            "created_at": time.time(),
            "auto_project_detection": project_id is None,
            "mode": mode,
        }
        # From here on the flow owns the server; cleanup paths will stop it.
        registered = True

        cleanup_expired_flows()

        log.info(f"OAuth流程已创建: state={state}, project_id={project_id}")
        log.info(f"用户需要访问认证URL,然后OAuth会回调到 {callback_url}")
        log.info(f"为此认证流程分配的端口: {callback_port}")

        return {
            "auth_url": auth_url,
            "state": state,
            "callback_port": callback_port,
            "success": True,
            "auto_project_detection": project_id is None,
            "detected_project_id": project_id,
        }

    except Exception as e:
        log.error(f"创建认证URL失败: {e}")
        if callback_server is not None and not registered:
            # The server is running but no flow references it; stop it to free the port.
            try:
                async_shutdown_server(callback_server, callback_port)
            except Exception as shutdown_err:
                log.debug(f"关闭服务器时出错: {shutdown_err}")
        return {"success": False, "error": str(e)}
|
|
|
|
def wait_for_callback_sync(state: str, timeout: int = 300, poll_interval: float = 0.5) -> Optional[str]:
    """Block the calling thread until the OAuth callback for *state* delivers a code.

    Polls the flow's ``code`` field, which ``AuthCallbackHandler`` sets from the
    callback-server thread, every *poll_interval* seconds. Time is measured with
    a monotonic clock so wall-clock changes do not shorten or extend the wait.
    This function sleeps synchronously; run it in a worker thread rather than
    directly on an asyncio event loop.

    Returns:
        The authorization code, or ``None`` if the flow is unknown, is removed
        from ``auth_flows`` while waiting (e.g. expired or evicted), or
        *timeout* seconds elapse.
    """
    flow_data = auth_flows.get(state)
    if flow_data is None:
        log.error(f"未找到状态为 {state} 的认证流程")
        return None

    callback_port = flow_data.get("callback_port")
    log.info(f"等待OAuth回调完成,端口: {callback_port}")

    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        code = flow_data.get("code")
        if code:
            log.info("OAuth回调成功完成")
            return code
        if auth_flows.get(state) is not flow_data:
            # The flow was cleaned up or replaced; no callback can reach it any more.
            log.warning(f"认证流程 {state} 已被清理,停止等待")
            return None
        time.sleep(poll_interval)

    log.warning(f"等待OAuth回调超时 ({timeout}秒)")
    return None
|
|
|
|
def _find_flow_for_completion(project_id: Optional[str], user_session: Optional[str]):
    """Return ``(state, flow_data)`` for the flow ``complete_auth_flow`` should finish.

    Flows registered for *project_id* are searched first, then flows created with
    automatic project detection. Within each group a flow whose ``user_session``
    equals *user_session* wins; otherwise the first match in ``auth_flows``
    iteration order is used. Returns ``(None, None)`` when nothing matches.
    """

    def first_match(predicate):
        fallback = (None, None)
        for candidate_state, data in auth_flows.items():
            if not predicate(data):
                continue
            if user_session and data.get("user_session") == user_session:
                return candidate_state, data
            if fallback[0] is None:
                fallback = (candidate_state, data)
        return fallback

    state, flow_data = None, None
    if project_id:
        state, flow_data = first_match(lambda data: data.get("project_id") == project_id)
    if not state:
        state, flow_data = first_match(lambda data: data.get("auto_project_detection", False))
    return state, flow_data


async def complete_auth_flow(
    project_id: Optional[str] = None, user_session: Optional[str] = None
) -> Dict[str, Any]:
    """Finish a pending OAuth flow and save its credentials in ``"geminicli"`` mode.

    Steps:
    - Locate the flow with ``_find_flow_for_completion``.
    - If no callback code has arrived yet, wait for it with
      ``wait_for_callback_sync``. That function sleeps, so it runs in the default
      executor to keep the event loop responsive.
    - Exchange the code for tokens.
    - For auto-detection flows without a project id, pick the project from the
      user's project list: the only project, or ``select_default_project``;
      otherwise ask the caller to choose.

    Returns:
        ``{"success": True, "credentials", "file_path", "auto_detected_project"}``
        on success; otherwise ``{"success": False, "error": ...}``, possibly with
        ``requires_manual_project_id`` or ``requires_project_selection`` plus
        ``available_projects``.
    """
    try:
        state, flow_data = _find_flow_for_completion(project_id, user_session)
        if not state or not flow_data:
            return {"success": False, "error": "未找到对应的认证流程,请先点击获取认证链接"}

        auto_detect = flow_data.get("auto_project_detection", False)
        if not project_id:
            project_id = flow_data.get("project_id")
        # Auto-detection flows have no project yet; it is resolved after the token
        # exchange below, so only non-auto flows must already have one here.
        if not project_id and not auto_detect:
            return {
                "success": False,
                "error": "缺少项目ID,请指定项目ID",
                "requires_manual_project_id": True,
            }

        flow = flow_data["flow"]

        auth_code = flow_data.get("code")
        if not auth_code:
            log.info(f"等待用户完成OAuth授权 (state: {state})")
            loop = asyncio.get_running_loop()
            auth_code = await loop.run_in_executor(None, wait_for_callback_sync, state)
            if not auth_code:
                return {
                    "success": False,
                    "error": "未接收到授权回调,请确保完成了浏览器中的OAuth认证",
                }
            # The flow may have been removed from auth_flows meanwhile; update our reference.
            flow_data["code"] = auth_code
            flow_data["completed"] = True

        with _OAuthLibPatcher():
            try:
                credentials = await flow.exchange_code(auth_code)

                if auto_detect and not project_id:
                    log.info("尝试通过API获取用户项目列表...")
                    log.info(f"使用的token: {credentials.access_token[:20]}...")
                    log.info(f"Token过期时间: {credentials.expires_at}")
                    user_projects = await get_user_projects(credentials)

                    if not user_projects:
                        return {
                            "success": False,
                            "error": "无法获取您的项目列表,请手动指定项目ID",
                            "requires_manual_project_id": True,
                        }

                    if len(user_projects) == 1:
                        project_id = user_projects[0].get("projectId")
                        if project_id:
                            flow_data["project_id"] = project_id
                            log.info(f"自动选择唯一项目: {project_id}")
                    else:
                        project_id = await select_default_project(user_projects)
                        if not project_id:
                            return {
                                "success": False,
                                "error": "请从以下项目中选择一个",
                                "requires_project_selection": True,
                                "available_projects": [
                                    {
                                        "project_id": p.get("projectId"),
                                        "name": p.get("displayName") or p.get("projectId"),
                                        "projectNumber": p.get("projectNumber"),
                                    }
                                    for p in user_projects
                                ],
                            }
                        flow_data["project_id"] = project_id
                        log.info(f"自动选择默认项目: {project_id}")

                if not project_id:
                    return {
                        "success": False,
                        "error": "缺少项目ID,请指定项目ID",
                        "requires_manual_project_id": True,
                    }

                saved_filename = await save_credentials(credentials, project_id)
                creds_data = _prepare_credentials_data(credentials, project_id, mode="geminicli")
                _cleanup_auth_flow_server(state)

                log.info("OAuth认证成功,凭证已保存")
                return {
                    "success": True,
                    "credentials": creds_data,
                    "file_path": saved_filename,
                    "auto_detected_project": auto_detect,
                }

            except Exception as e:
                log.error(f"获取凭证失败: {e}")
                return {"success": False, "error": f"获取凭证失败: {str(e)}"}

    except Exception as e:
        log.error(f"完成认证流程失败: {e}")
        return {"success": False, "error": str(e)}
|
|
|
|
def _select_async_auth_flow(project_id: Optional[str], user_session: Optional[str]):
    """Return ``(state, flow_data)`` for ``asyncio_complete_auth_flow``.

    Order of preference:
    1. Flows registered for *project_id*. A ``user_session`` match wins;
       otherwise the first match in ``auth_flows`` order is used.
    2. Auto-detection flows of this *user_session* that already hold a code
       (newest by ``created_at``).
    3. The newest auto-detection flow of this *user_session*, or of any
       session when *user_session* is empty.

    Returns ``(None, None)`` when nothing matches.
    """
    state, flow_data = None, None

    if project_id:
        log.info(f"尝试匹配指定的项目ID: {project_id}")
        for candidate_state, data in auth_flows.items():
            if data.get("project_id") != project_id:
                continue
            if user_session and data.get("user_session") == user_session:
                state, flow_data = candidate_state, data
                log.info(f"找到匹配的用户会话: {candidate_state}")
                break
            if not state:
                state, flow_data = candidate_state, data
                log.info(f"找到匹配的项目ID: {candidate_state}")

    if state:
        return state, flow_data

    log.info("没有找到指定项目的流程,查找自动检测流程")
    # (state, flow_data, created_at) for every auto-detection flow.
    auto_flows = [
        (s, data, data.get("created_at", 0))
        for s, data in auth_flows.items()
        if data.get("auto_project_detection", False)
    ]

    completed_flows = [
        entry
        for entry in auto_flows
        if user_session and entry[1].get("user_session") == user_session and entry[1].get("code")
    ]
    if completed_flows:
        completed_flows.sort(key=lambda x: x[2], reverse=True)
        state, flow_data, _ = completed_flows[0]
        log.info(f"找到已完成的最新认证流程: {state}")
        return state, flow_data

    pending_flows = [
        entry
        for entry in auto_flows
        if not user_session or entry[1].get("user_session") == user_session
    ]
    if pending_flows:
        pending_flows.sort(key=lambda x: x[2], reverse=True)
        state, flow_data, _ = pending_flows[0]
        log.info(f"找到最新的待完成认证流程: {state}")
    return state, flow_data


async def _wait_for_auth_code(state: str, flow_data: Dict[str, Any], max_wait_time: int = 60, wait_interval: int = 1):
    """Poll (without blocking the loop) until *flow_data* holds a code or time runs out.

    After each sleep, ``flow_data`` is re-read from ``auth_flows`` if *state* is
    still registered. Returns ``(flow_data, waited_seconds)``; the caller checks
    ``flow_data.get("code")``.
    """
    waited = 0
    while waited < max_wait_time:
        if flow_data.get("code"):
            log.info(f"检测到OAuth授权码,开始处理凭证 (等待时间: {waited}秒)")
            break

        if waited % 5 == 0 and waited > 0:
            log.info(f"仍在等待OAuth授权... ({waited}/{max_wait_time}秒)")
            log.debug(f"当前state: {state}, flow_data keys: {list(flow_data.keys())}")

        await asyncio.sleep(wait_interval)
        waited += wait_interval

        if state in auth_flows:
            flow_data = auth_flows[state]
    return flow_data, waited


async def _enable_apis_best_effort(credentials, project_id: str):
    """Call ``enable_required_apis``; log a failure instead of raising.

    The authorization code has already been exchanged at this point and cannot
    be reused, so aborting here would force the user to restart the whole
    login. This matches ``complete_auth_flow_from_callback_url``.
    """
    try:
        await enable_required_apis(credentials, project_id)
    except Exception as e:
        log.warning(f"启用API服务失败: {e}")


async def asyncio_complete_auth_flow(
    project_id: Optional[str] = None, user_session: Optional[str] = None, mode: str = "geminicli"
) -> Dict[str, Any]:
    """Complete an OAuth flow asynchronously, with optional project auto-detection.

    Steps:
    - Pick the flow with ``_select_async_auth_flow``.
    - Poll up to 60 seconds for the callback code.
    - Exchange the code for tokens.

    Antigravity mode (taken from the flow's ``mode``, else *mode*): the project
    id comes from the Antigravity API, falling back to a random placeholder.

    Standard mode:
    - Use the project id supplied by the caller or stored on the flow.
    - Otherwise detect it via ``fetch_project_id``, falling back to the user's
      project list (the only project, or ``select_default_project``).
    - Enabling the required APIs is best-effort.

    Returns:
        ``{"success": True, "credentials", "file_path", "auto_detected_project"}``
        (plus ``"mode": "antigravity"`` in that mode), or
        ``{"success": False, "error": ...}``, possibly with
        ``requires_manual_project_id`` or ``requires_project_selection`` plus
        ``available_projects``.
    """
    try:
        log.info(
            f"asyncio_complete_auth_flow开始执行: project_id={project_id}, user_session={user_session}"
        )
        log.debug(f"当前所有auth_flows: {list(auth_flows.keys())}")

        state, flow_data = _select_async_auth_flow(project_id, user_session)
        if not state or not flow_data:
            log.error(f"未找到认证流程: state={state}, flow_data存在={bool(flow_data)}")
            log.debug(f"当前所有flow_data: {list(auth_flows.keys())}")
            return {"success": False, "error": "未找到对应的认证流程,请先点击获取认证链接"}

        log.info(f"找到认证流程: state={state}")
        log.info(
            f"flow_data内容: project_id={flow_data.get('project_id')}, auto_project_detection={flow_data.get('auto_project_detection')}"
        )
        log.info(f"传入的project_id参数: {project_id}")

        log.info(
            f"检查auto_project_detection条件: auto_project_detection={flow_data.get('auto_project_detection', False)}, not project_id={not project_id}"
        )
        if flow_data.get("auto_project_detection", False) and not project_id:
            # The project is resolved after the token exchange.
            log.info("跳过自动检测项目ID,进入等待阶段")
        elif not project_id:
            log.info("进入project_id检查分支")
            project_id = flow_data.get("project_id")
            if not project_id:
                log.error("缺少项目ID,返回错误")
                return {
                    "success": False,
                    "error": "缺少项目ID,请指定项目ID",
                    "requires_manual_project_id": True,
                }
        else:
            log.info(f"使用提供的项目ID: {project_id}")

        log.info("开始检查OAuth授权码...")
        log.info(f"等待state={state}的授权回调,回调端口: {flow_data.get('callback_port')}")
        log.info(f"当前flow_data状态: completed={flow_data.get('completed')}, code存在={bool(flow_data.get('code'))}")
        flow_data, waited = await _wait_for_auth_code(state, flow_data, max_wait_time=60, wait_interval=1)

        if not flow_data.get("code"):
            log.error(f"等待OAuth回调超时,等待了{waited}秒")
            return {
                "success": False,
                "error": "等待OAuth回调超时,请确保完成了浏览器中的认证并看到成功页面",
            }

        flow = flow_data["flow"]
        auth_code = flow_data["code"]
        log.info(f"开始使用授权码获取凭证: code={'***' + auth_code[-4:] if auth_code else 'None'}")

        with _OAuthLibPatcher():
            try:
                log.info("调用flow.exchange_code...")
                credentials = await flow.exchange_code(auth_code)
                log.info(
                    f"成功获取凭证,token前缀: {credentials.access_token[:20] if credentials.access_token else 'None'}..."
                )
                log.info(
                    f"检查是否需要项目检测: auto_project_detection={flow_data.get('auto_project_detection')}, project_id={project_id}"
                )

                cred_mode = flow_data.get("mode", "geminicli") if flow_data.get("mode") else mode
                if cred_mode == "antigravity":
                    log.info("Antigravity模式:从API获取project_id...")
                    antigravity_url = await get_antigravity_api_url()
                    project_id = await fetch_project_id(
                        credentials.access_token,
                        ANTIGRAVITY_USER_AGENT,
                        antigravity_url
                    )
                    if project_id:
                        log.info(f"成功从API获取project_id: {project_id}")
                    else:
                        log.warning("无法从API获取project_id,回退到随机生成")
                        project_id = _generate_random_project_id()
                        log.info(f"生成的随机project_id: {project_id}")

                    saved_filename = await save_credentials(credentials, project_id, mode="antigravity")
                    creds_data = _prepare_credentials_data(credentials, project_id, mode="antigravity")
                    _cleanup_auth_flow_server(state)

                    log.info("Antigravity OAuth认证成功,凭证已保存")
                    return {
                        "success": True,
                        "credentials": creds_data,
                        "file_path": saved_filename,
                        "auto_detected_project": False,
                        "mode": "antigravity",
                    }

                if flow_data.get("auto_project_detection", False) and not project_id:
                    log.info("标准模式:从API获取project_id...")
                    code_assist_url = await get_code_assist_endpoint()
                    project_id = await fetch_project_id(
                        credentials.access_token,
                        GEMINICLI_USER_AGENT,
                        code_assist_url
                    )
                    if project_id:
                        flow_data["project_id"] = project_id
                        log.info(f"成功从API获取project_id: {project_id}")
                    else:
                        log.warning("无法从API获取project_id,回退到项目列表获取方式")
                        user_projects = await get_user_projects(credentials)

                        if not user_projects:
                            return {
                                "success": False,
                                "error": "无法获取您的项目列表,请手动指定项目ID",
                                "requires_manual_project_id": True,
                            }

                        if len(user_projects) == 1:
                            project_id = user_projects[0].get("projectId")
                            if project_id:
                                flow_data["project_id"] = project_id
                                log.info(f"自动选择唯一项目: {project_id}")
                        else:
                            project_id = await select_default_project(user_projects)
                            if not project_id:
                                return {
                                    "success": False,
                                    "error": "请从以下项目中选择一个",
                                    "requires_project_selection": True,
                                    "available_projects": [
                                        {
                                            "project_id": p.get("projectId"),
                                            "name": p.get("displayName") or p.get("projectId"),
                                            "projectNumber": p.get("projectNumber"),
                                        }
                                        for p in user_projects
                                    ],
                                }
                            flow_data["project_id"] = project_id
                            log.info(f"自动选择默认项目: {project_id}")

                    if project_id:
                        log.info("正在自动启用必需的API服务...")
                        await _enable_apis_best_effort(credentials, project_id)
                elif project_id:
                    log.info("正在为已提供的项目ID自动启用必需的API服务...")
                    await _enable_apis_best_effort(credentials, project_id)

                if not project_id:
                    return {
                        "success": False,
                        "error": "缺少项目ID,请指定项目ID",
                        "requires_manual_project_id": True,
                    }

                saved_filename = await save_credentials(credentials, project_id)
                creds_data = _prepare_credentials_data(credentials, project_id, mode="geminicli")
                _cleanup_auth_flow_server(state)

                log.info("OAuth认证成功,凭证已保存")
                return {
                    "success": True,
                    "credentials": creds_data,
                    "file_path": saved_filename,
                    "auto_detected_project": flow_data.get("auto_project_detection", False),
                }

            except Exception as e:
                log.error(f"获取凭证失败: {e}")
                return {"success": False, "error": f"获取凭证失败: {str(e)}"}

    except Exception as e:
        log.error(f"异步完成认证流程失败: {e}")
        return {"success": False, "error": str(e)}
|
|
|
|
async def complete_auth_flow_from_callback_url(
    callback_url: str, project_id: Optional[str] = None, mode: str = "geminicli"
) -> Dict[str, Any]:
    """Complete an OAuth flow from a pasted redirect URL, without the local server.

    The ``state`` and ``code`` query parameters of *callback_url* identify a
    flow registered by ``create_auth_url``. The code is exchanged for tokens
    under ``_OAuthLibPatcher``, like the other completion paths.

    Antigravity mode (taken from the flow's ``mode``, else *mode*): the project
    id comes from the Antigravity API, falling back to a random placeholder.

    Otherwise:
    - Use *project_id* if given.
    - Else detect it via ``fetch_project_id``, falling back to the user's
      project list; the first project wins when several exist.
    - Enabling the required APIs is best-effort.

    The full callback URL is never logged because its query string carries the
    single-use authorization code.

    Returns:
        ``{"success": True, "credentials", "file_path", "auto_detected_project"}``
        (plus ``"mode": "antigravity"`` in that mode) or
        ``{"success": False, "error": ...}``.
    """
    try:
        parsed_url = urlparse(callback_url)
        query_params = parse_qs(parsed_url.query)
        log.info(f"开始从回调URL完成认证: {parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}")

        # Per RFC 6749 §4.1.2.1 a failed authorization redirects with ?error=... instead of ?code=...
        if "code" not in query_params and "error" in query_params:
            return {"success": False, "error": f"OAuth授权失败: {query_params['error'][0]}"}

        if "state" not in query_params or "code" not in query_params:
            return {"success": False, "error": "回调URL缺少必要参数 (state 或 code)"}

        state = query_params["state"][0]
        code = query_params["code"][0]
        log.info(f"从URL解析到: state={state}, code=xxx...")

        flow_data = auth_flows.get(state)
        if flow_data is None:
            return {
                "success": False,
                "error": f"未找到对应的认证流程,请先启动认证 (state: {state})",
            }

        flow = flow_data["flow"]
        log.info(f"使用redirect_uri: {flow.redirect_uri}")

        try:
            with _OAuthLibPatcher():
                credentials = await flow.exchange_code(code)
            log.info("成功获取访问令牌")

            cred_mode = flow_data.get("mode", "geminicli") if flow_data.get("mode") else mode
            if cred_mode == "antigravity":
                log.info("Antigravity模式(从回调URL):从API获取project_id...")
                antigravity_url = await get_antigravity_api_url()
                project_id = await fetch_project_id(
                    credentials.access_token,
                    ANTIGRAVITY_USER_AGENT,
                    antigravity_url
                )
                if project_id:
                    log.info(f"成功从API获取project_id: {project_id}")
                else:
                    log.warning("无法从API获取project_id,回退到随机生成")
                    project_id = _generate_random_project_id()
                    log.info(f"生成的随机project_id: {project_id}")

                saved_filename = await save_credentials(credentials, project_id, mode="antigravity")
                creds_data = _prepare_credentials_data(credentials, project_id, mode="antigravity")
                _cleanup_auth_flow_server(state)

                log.info("从回调URL完成Antigravity OAuth认证成功,凭证已保存")
                return {
                    "success": True,
                    "credentials": creds_data,
                    "file_path": saved_filename,
                    "auto_detected_project": False,
                    "mode": "antigravity",
                }

            detected_project_id = None
            auto_detected = False

            if not project_id:
                try:
                    log.info("标准模式:从API获取project_id...")
                    code_assist_url = await get_code_assist_endpoint()
                    detected_project_id = await fetch_project_id(
                        credentials.access_token,
                        GEMINICLI_USER_AGENT,
                        code_assist_url
                    )
                    if detected_project_id:
                        auto_detected = True
                        log.info(f"成功从API获取project_id: {detected_project_id}")
                    else:
                        log.warning("无法从API获取project_id,回退到项目列表获取方式")
                        projects = await get_user_projects(credentials)
                        if not projects:
                            return {
                                "success": False,
                                "error": "未检测到可访问的项目,请检查权限或手动指定项目ID",
                                "requires_manual_project_id": True,
                            }
                        # With several projects the first one is chosen automatically.
                        detected_project_id = projects[0]["projectId"]
                        auto_detected = True
                        if len(projects) == 1:
                            log.info(f"自动检测到唯一项目ID: {detected_project_id}")
                        else:
                            log.info(
                                f"检测到{len(projects)}个项目,自动选择第一个: {detected_project_id}"
                            )
                            log.debug(f"其他可用项目: {[p['projectId'] for p in projects[1:]]}")
                except Exception as e:
                    log.warning(f"自动检测项目ID失败: {e}")
                    return {
                        "success": False,
                        "error": f"自动检测项目ID失败: {str(e)},请手动指定项目ID",
                        "requires_manual_project_id": True,
                    }
            else:
                detected_project_id = project_id

            if detected_project_id:
                try:
                    log.info(f"正在为项目 {detected_project_id} 启用必需的API服务...")
                    await enable_required_apis(credentials, detected_project_id)
                except Exception as e:
                    log.warning(f"启用API服务失败: {e}")

            saved_filename = await save_credentials(credentials, detected_project_id)
            creds_data = _prepare_credentials_data(credentials, detected_project_id, mode="geminicli")
            _cleanup_auth_flow_server(state)

            log.info("从回调URL完成OAuth认证成功,凭证已保存")
            return {
                "success": True,
                "credentials": creds_data,
                "file_path": saved_filename,
                "auto_detected_project": auto_detected,
            }

        except Exception as e:
            log.error(f"从回调URL获取凭证失败: {e}")
            return {"success": False, "error": f"获取凭证失败: {str(e)}"}

    except Exception as e:
        log.error(f"从回调URL完成认证流程失败: {e}")
        return {"success": False, "error": str(e)}
|
|
|
|
async def save_credentials(creds: Credentials, project_id: str, mode: str = "geminicli") -> str:
    """Persist OAuth credentials through the unified storage adapter.

    The record is named ``"<project>-<unix_ts>.json"``, prefixed with ``"ag_"``
    in Antigravity mode. In the *filename only*, characters outside
    ``[A-Za-z0-9._-]`` in the project id are replaced by ``"_"``. This keeps ids
    such as the ``"projects/random-xxxx/locations/global"`` fallback from
    introducing path separators. The stored credential keeps the original
    ``project_id``. After a successful store, a default state record is written
    on a best-effort basis.

    Returns:
        The storage filename.

    Raises:
        RuntimeError: if the storage adapter reports that storing failed.
    """
    import re

    timestamp = int(time.time())
    safe_project_id = re.sub(r"[^A-Za-z0-9._-]", "_", str(project_id))
    prefix = "ag_" if mode == "antigravity" else ""
    filename = f"{prefix}{safe_project_id}-{timestamp}.json"

    creds_data = _prepare_credentials_data(creds, project_id, mode)

    storage_adapter = await get_storage_adapter()
    success = await storage_adapter.store_credential(filename, creds_data, mode=mode)
    if not success:
        raise RuntimeError(f"保存凭证失败: {filename}")

    try:
        default_state = {
            "error_codes": [],
            "disabled": False,
            "last_success": time.time(),
            "user_email": None,
        }
        await storage_adapter.update_credential_state(filename, default_state, mode=mode)
        log.info(f"凭证和状态已保存到: {filename} (mode={mode})")
    except Exception as e:
        # The credential itself is stored; a missing state record is non-fatal.
        log.warning(f"创建默认状态记录失败 {filename}: {e}")

    return filename
|
|
|
|
def async_shutdown_server(server, port):
    """Shut down an OAuth callback server in the background without blocking the caller.

    A daemon thread starts a second daemon worker that calls ``server.shutdown()``
    and then ``server.server_close()``. The first thread waits up to 5 seconds
    for the worker and logs a warning on timeout. ``server_close()`` runs even if
    ``shutdown()`` raises, so the listening socket is not leaked.

    ``shutdown()`` waits for ``serve_forever()`` to stop. For a server whose
    serve loop never started it would block indefinitely, but since it runs in
    a daemon thread it cannot hang the caller.

    Args:
        server: the ``HTTPServer`` to stop.
        port: its port, used only in log messages.
    """

    def shutdown_server_async():
        try:
            shutdown_completed = threading.Event()

            def do_shutdown():
                try:
                    try:
                        server.shutdown()
                    finally:
                        # Always release the listening socket, even if shutdown() failed.
                        server.server_close()
                    log.info(f"已关闭端口 {port} 的OAuth回调服务器")
                except Exception as e:
                    log.debug(f"关闭服务器时出错: {e}")
                finally:
                    shutdown_completed.set()

            shutdown_worker = threading.Thread(target=do_shutdown, daemon=True)
            shutdown_worker.start()

            if shutdown_completed.wait(timeout=5):
                log.debug(f"端口 {port} 服务器关闭完成")
            else:
                log.warning(f"端口 {port} 服务器关闭超时,但不阻塞主流程")

        except Exception as e:
            log.debug(f"异步关闭服务器时出错: {e}")

    shutdown_thread = threading.Thread(target=shutdown_server_async, daemon=True)
    shutdown_thread.start()
    log.debug(f"开始异步关闭端口 {port} 的OAuth回调服务器")
|
|
|
|
def cleanup_expired_flows(expiry_seconds: float = 600):
    """Remove auth flows older than *expiry_seconds* (default 600 s).

    For each expired flow:
    - pop it from ``auth_flows``;
    - shut down its callback server (if any) via ``async_shutdown_server``;
    - clear its dict.

    A flow without ``created_at`` is treated as created at epoch 0 and so
    expires, matching how ``cleanup_auth_flows_for_memory`` orders flows. It
    no longer raises ``KeyError`` and aborts the sweep. ``gc.collect()`` runs
    when more than ``MAX_AUTH_FLOWS`` flows remain.
    """
    current_time = time.time()

    expired_states = [
        state
        for state, flow_data in auth_flows.items()
        if current_time - flow_data.get("created_at", 0) > expiry_seconds
    ]

    cleaned_count = 0
    for state in expired_states:
        flow_data = auth_flows.pop(state, None)
        if flow_data is None:
            continue
        try:
            if flow_data.get("server"):
                async_shutdown_server(flow_data["server"], flow_data.get("callback_port"))
        except Exception as e:
            log.debug(f"清理过期流程时启动异步关闭服务器失败: {e}")
        # Drop references (Flow, server, thread) held by the dict.
        flow_data.clear()
        cleaned_count += 1

    if cleaned_count > 0:
        log.info(f"清理了 {cleaned_count} 个过期的认证流程")

    if len(auth_flows) > MAX_AUTH_FLOWS:
        import gc

        gc.collect()
        log.debug(f"触发垃圾回收,当前活跃认证流程数: {len(auth_flows)}")
|
|
|
|
def get_auth_status(project_id: str) -> Dict[str, Any]:
    """Report the status of the first registered auth flow for *project_id*.

    Returns ``{"status": "completed" | "pending", "state", "created_at"}`` for
    the first matching flow in ``auth_flows`` order, or ``{"status": "not_found"}``.
    """
    match = next(
        ((state, data) for state, data in auth_flows.items() if data["project_id"] == project_id),
        None,
    )
    if match is None:
        return {"status": "not_found"}

    state, data = match
    return {
        "status": "completed" if data["completed"] else "pending",
        "state": state,
        "created_at": data["created_at"],
    }
|
|
|
|
| |
# Issued panel login tokens: token string -> creation time (time.time()).
auth_tokens: Dict[str, float] = {}
# Token lifetime in seconds.
TOKEN_EXPIRY = 3600
|
|
|
|
async def verify_password(password: str) -> bool:
    """Check a panel login password against the configured panel password.

    Uses ``secrets.compare_digest`` on UTF-8 bytes, so the comparison time does
    not reveal how many leading characters matched. Non-string inputs never
    verify; in particular an unset (``None``) configured password no longer
    accepts a ``None`` password.
    """
    from config import get_panel_password

    correct_password = await get_panel_password()
    if not isinstance(password, str) or not isinstance(correct_password, str):
        return False
    return secrets.compare_digest(password.encode("utf-8"), correct_password.encode("utf-8"))
|
|
|
|
def generate_auth_token() -> str:
    """Issue a new random panel login token.

    Expired tokens are purged first. The new URL-safe token (32 random bytes)
    is recorded in ``auth_tokens`` with its creation time and returned.
    """
    cleanup_expired_tokens()

    new_token = secrets.token_urlsafe(32)
    issued_at = time.time()
    auth_tokens[new_token] = issued_at
    return new_token
|
|
|
|
def verify_auth_token(token: str) -> bool:
    """Return True if *token* was issued and is younger than ``TOKEN_EXPIRY`` seconds.

    An expired token is removed from ``auth_tokens`` as a side effect.
    """
    if not token:
        return False

    issued_at = auth_tokens.get(token)
    if issued_at is None:
        return False

    expired = time.time() - issued_at > TOKEN_EXPIRY
    if expired:
        del auth_tokens[token]
    return not expired
|
|
|
|
def cleanup_expired_tokens():
    """Delete every token in ``auth_tokens`` older than ``TOKEN_EXPIRY`` seconds."""
    now = time.time()
    stale_tokens = list(
        filter(lambda tok: now - auth_tokens[tok] > TOKEN_EXPIRY, auth_tokens)
    )

    for tok in stale_tokens:
        del auth_tokens[tok]

    if stale_tokens:
        log.debug(f"清理了 {len(stale_tokens)} 个过期的认证令牌")
|
|
|
|
def invalidate_auth_token(token: str):
    """Revoke *token*; unknown tokens are ignored."""
    auth_tokens.pop(token, None)
|
|
|
|
| |
def validate_credential_content(content: str) -> Dict[str, Any]:
    """Validate the text of an uploaded credential file.

    The content must be a JSON *object* containing ``client_id``,
    ``client_secret``, ``refresh_token`` and ``token_uri``. Other JSON values
    are rejected: previously a JSON string containing those names passed the
    substring ``in`` test. A missing ``project_id`` only logs a warning.

    Returns:
        ``{"valid": True, "data": <dict>}`` or ``{"valid": False, "error": <message>}``.
    """
    try:
        creds_data = json.loads(content)

        if not isinstance(creds_data, dict):
            return {"valid": False, "error": "凭证内容必须是JSON对象"}

        required_fields = ["client_id", "client_secret", "refresh_token", "token_uri"]
        missing_fields = [field for field in required_fields if field not in creds_data]

        if missing_fields:
            return {"valid": False, "error": f'缺少必要字段: {", ".join(missing_fields)}'}

        if "project_id" not in creds_data:
            log.warning("认证文件缺少project_id字段")

        return {"valid": True, "data": creds_data}

    except json.JSONDecodeError as e:
        return {"valid": False, "error": f"JSON格式错误: {str(e)}"}
    except Exception as e:
        return {"valid": False, "error": f"文件验证失败: {str(e)}"}
|
|
|
|
def _safe_upload_basename(original_filename: str) -> str:
    """Reduce an untrusted upload filename to a bare name without directory or extension.

    Both ``/`` and ``\\`` are treated as separators, so ``"../../x.json"`` and
    ``"C:\\dir\\x.json"`` become ``"x"``. Empty, ``"."`` and ``".."`` results
    fall back to ``"credential"``.
    """
    import os

    name = (original_filename or "").replace("\\", "/").rsplit("/", 1)[-1]
    base = os.path.splitext(name)[0].strip()
    if base in ("", ".", ".."):
        return "credential"
    return base


async def save_uploaded_credential(content: str, original_filename: str) -> Dict[str, Any]:
    """Validate an uploaded credential and store it through the unified storage system.

    The stored name is ``"<basename>-<unix_ts>.json"``. The basename is derived
    from *original_filename* with any directory components removed, since the
    client-supplied name must not steer where the credential is written.

    Returns:
        ``{"success": True, "file_path", "project_id"}`` (``project_id`` is
        ``"unknown"`` when absent) or ``{"success": False, "error": ...}``.
    """
    try:
        validation = validate_credential_content(content)
        if not validation["valid"]:
            return {"success": False, "error": validation["error"]}

        creds_data = validation["data"]
        project_id = creds_data.get("project_id", "unknown")
        timestamp = int(time.time())

        base_name = _safe_upload_basename(original_filename)
        filename = f"{base_name}-{timestamp}.json"

        storage_adapter = await get_storage_adapter()
        success = await storage_adapter.store_credential(filename, creds_data)

        if success:
            log.info(f"凭证文件已上传保存: {filename}")
            return {"success": True, "file_path": filename, "project_id": project_id}
        return {"success": False, "error": "保存到存储系统失败"}

    except Exception as e:
        log.error(f"保存上传文件失败: {e}")
        return {"success": False, "error": str(e)}
|
|
|
|
async def batch_upload_credentials(files_data: List[Dict[str, str]]) -> Dict[str, Any]:
    """Upload several credential files sequentially via ``save_uploaded_credential``.

    Each entry supplies ``filename`` (default ``"unknown.json"``) and ``content``
    (default ``""``). Each per-file result is annotated with its ``filename``.

    Returns:
        ``{"uploaded_count", "total_count", "results"}``.
    """
    results = []
    uploaded = 0

    for entry in files_data:
        name = entry.get("filename", "unknown.json")
        outcome = await save_uploaded_credential(entry.get("content", ""), name)
        outcome["filename"] = name
        results.append(outcome)
        uploaded += bool(outcome["success"])

    return {"uploaded_count": uploaded, "total_count": len(files_data), "results": results}
|
|