| """ |
| Base API Client - 共用的 API 客户端基础功能 |
| 提供错误处理、自动封禁、重试逻辑等共同功能 |
| """ |
|
|
| import asyncio |
| import json |
| from datetime import datetime, timezone |
| from typing import Any, Dict, Optional |
|
|
| from fastapi import Response |
|
|
| from config import ( |
| get_auto_ban_enabled, |
| get_auto_ban_error_codes, |
| get_retry_429_enabled, |
| get_retry_429_interval, |
| get_retry_429_max_retries, |
| ) |
| from log import log |
| from src.credential_manager import CredentialManager |
|
|
|
|
| |
|
|
async def check_should_auto_ban(status_code: int) -> bool:
    """
    Decide whether an HTTP status code should trigger credential auto-ban.

    Args:
        status_code: HTTP status code returned by the upstream API

    Returns:
        bool: True when auto-ban is enabled and the code is in the
        configured auto-ban error-code list
    """
    # Check the cheap toggle first; only fetch the code list when enabled.
    if not await get_auto_ban_enabled():
        return False
    ban_codes = await get_auto_ban_error_codes()
    return status_code in ban_codes
|
|
|
|
async def handle_auto_ban(
    credential_manager: CredentialManager,
    status_code: int,
    credential_name: str,
    mode: str = "geminicli"
) -> None:
    """
    Apply auto-ban by disabling the offending credential directly.

    Args:
        credential_manager: credential manager instance
        status_code: HTTP status code that triggered the ban
        credential_name: name of the credential to disable
        mode: operating mode ("geminicli" or "antigravity")
    """
    # Nothing to do without both a manager and a concrete credential name.
    if not credential_manager or not credential_name:
        return
    log.warning(
        f"[{mode.upper()} AUTO_BAN] Status {status_code} triggers auto-ban for credential: {credential_name}"
    )
    await credential_manager.set_cred_disabled(credential_name, True, mode=mode)
|
|
|
|
async def handle_error_with_retry(
    credential_manager: CredentialManager,
    status_code: int,
    credential_name: str,
    retry_enabled: bool,
    attempt: int,
    max_retries: int,
    retry_interval: float,
    mode: str = "geminicli"
) -> bool:
    """
    Unified error-handling and retry decision.

    Automatic retry happens only in two situations:
    1. A 429 (rate-limit) error
    2. An error that bans the credential (AUTO_BAN_ERROR_CODES config)

    Args:
        credential_manager: credential manager instance
        status_code: HTTP status code of the failed call
        credential_name: credential that made the call
        retry_enabled: whether retrying is enabled at all
        attempt: current retry attempt number (0-based)
        max_retries: maximum number of retries
        retry_interval: seconds to sleep before retrying
        mode: operating mode ("geminicli" or "antigravity")

    Returns:
        bool: True when the caller should retry, False otherwise
    """
    # Retry budget is the same for both branches below.
    can_retry = retry_enabled and attempt < max_retries

    if await check_should_auto_ban(status_code):
        # Disable the credential first, then decide whether to try the next one.
        await handle_auto_ban(credential_manager, status_code, credential_name, mode)
        if can_retry:
            log.info(
                f"[{mode.upper()} RETRY] Retrying with next credential after auto-ban "
                f"(status {status_code}, attempt {attempt + 1}/{max_retries})"
            )
            await asyncio.sleep(retry_interval)
            return True
        return False

    # Not a ban-worthy error: only 429 rate limits are retried.
    if status_code == 429 and can_retry:
        log.info(
            f"[{mode.upper()} RETRY] 429 rate limit encountered, retrying "
            f"(attempt {attempt + 1}/{max_retries})"
        )
        await asyncio.sleep(retry_interval)
        return True

    return False
|
|
|
|
| |
|
|
async def get_retry_config() -> Dict[str, Any]:
    """
    Fetch the current 429-retry configuration.

    Returns:
        Dict with keys "retry_enabled", "max_retries", "retry_interval".
    """
    enabled = await get_retry_429_enabled()
    retries = await get_retry_429_max_retries()
    interval = await get_retry_429_interval()
    return {
        "retry_enabled": enabled,
        "max_retries": retries,
        "retry_interval": interval,
    }
|
|
|
|
| |
|
|
async def record_api_call_success(
    credential_manager: CredentialManager,
    credential_name: str,
    mode: str = "geminicli",
    model_key: Optional[str] = None
) -> None:
    """
    Record a successful API call for a credential.

    Args:
        credential_manager: credential manager instance
        credential_name: credential name
        mode: operating mode ("geminicli" or "antigravity")
        model_key: model key (used for per-model cooldown)
    """
    # Skip when there is no manager or no credential to attribute the call to.
    if not credential_manager or not credential_name:
        return
    await credential_manager.record_api_call_result(
        credential_name, True, mode=mode, model_key=model_key
    )
|
|
|
|
async def record_api_call_error(
    credential_manager: CredentialManager,
    credential_name: str,
    status_code: int,
    cooldown_until: Optional[float] = None,
    mode: str = "geminicli",
    model_key: Optional[str] = None
) -> None:
    """
    Record a failed API call for a credential.

    Args:
        credential_manager: credential manager instance
        credential_name: credential name
        status_code: HTTP status code of the failure
        cooldown_until: cooldown deadline (Unix timestamp), if known
        mode: operating mode ("geminicli" or "antigravity")
        model_key: model key (used for per-model cooldown)
    """
    # Skip when there is no manager or no credential to attribute the call to.
    if not credential_manager or not credential_name:
        return
    await credential_manager.record_api_call_result(
        credential_name,
        False,
        status_code,
        cooldown_until=cooldown_until,
        mode=mode,
        model_key=model_key
    )
|
|
|
|
| |
|
|
async def parse_and_log_cooldown(
    error_text: str,
    mode: str = "geminicli"
) -> Optional[float]:
    """
    Parse and log the quota cooldown time from an error response body.

    Args:
        error_text: raw error response text (expected to be JSON)
        mode: operating mode ("geminicli" or "antigravity")

    Returns:
        Cooldown deadline as a Unix timestamp, or None when parsing fails.
    """
    try:
        payload = json.loads(error_text)
        reset_ts = parse_quota_reset_timestamp(payload)
        if reset_ts:
            log.info(
                f"[{mode.upper()}] 检测到quota冷却时间: "
                f"{datetime.fromtimestamp(reset_ts, timezone.utc).isoformat()}"
            )
            return reset_ts
    except Exception as parse_err:
        # Best-effort: an unparseable body simply yields no cooldown.
        log.debug(f"[{mode.upper()}] Failed to parse cooldown time: {parse_err}")
    return None
|
|
|
|
| |
|
|
async def collect_streaming_response(stream_generator) -> Response:
    """
    Collect a Gemini streaming response into one complete non-streaming response.

    Args:
        stream_generator: stream generator yielding "data: {json}" formatted
            lines (str or bytes) or Response objects (passed through as errors)

    Returns:
        Response: the merged complete response object

    Example:
        >>> async for line in stream_generator:
        ...     # line format: "data: {...}" or Response object
        >>> response = await collect_streaming_response(stream_generator)
    """
    # Skeleton of the merged payload; candidate metadata and usage are
    # overwritten as chunks arrive, parts are filled in at the end.
    merged_response = {
        "response": {
            "candidates": [{
                "content": {
                    "parts": [],
                    "role": "model"
                },
                "finishReason": None,
                "safetyRatings": [],
                "citationMetadata": None
            }],
            "usageMetadata": {
                "promptTokenCount": 0,
                "candidatesTokenCount": 0,
                "totalTokenCount": 0
            }
        }
    }

    collected_text = []          # regular model text fragments, in arrival order
    collected_thought_text = []  # fragments from parts flagged "thought": True
    collected_other_parts = []   # non-text parts (inlineData/fileData/executableCode/...)
    has_data = False             # True once at least one chunk parsed as JSON
    line_count = 0

    log.debug("[STREAM COLLECTOR] Starting to collect streaming response")

    try:
        async for line in stream_generator:
            line_count += 1

            # An upstream error surfaces as a ready-made Response: return it as-is.
            if isinstance(line, Response):
                log.debug(f"[STREAM COLLECTOR] 收到错误Response,状态码: {line.status_code}")
                return line

            # Normalize bytes/str lines; anything else is skipped.
            if isinstance(line, bytes):
                line_str = line.decode('utf-8', errors='ignore')
                log.debug(f"[STREAM COLLECTOR] Processing bytes line {line_count}: {line_str[:200] if line_str else 'empty'}")
            elif isinstance(line, str):
                line_str = line
                log.debug(f"[STREAM COLLECTOR] Processing line {line_count}: {line_str[:200] if line_str else 'empty'}")
            else:
                log.debug(f"[STREAM COLLECTOR] Skipping non-string/bytes line: {type(line)}")
                continue

            # Only SSE data lines carry payload.
            if not line_str.startswith("data: "):
                log.debug(f"[STREAM COLLECTOR] Skipping line without 'data: ' prefix: {line_str[:100]}")
                continue

            raw = line_str[6:].strip()
            if raw == "[DONE]":
                log.debug("[STREAM COLLECTOR] Received [DONE] marker")
                break

            try:
                log.debug(f"[STREAM COLLECTOR] Parsing JSON: {raw[:200]}")
                chunk = json.loads(raw)
                has_data = True
                log.debug(f"[STREAM COLLECTOR] Chunk keys: {chunk.keys() if isinstance(chunk, dict) else type(chunk)}")

                # Chunks may or may not be wrapped in a "response" envelope.
                response_obj = chunk.get("response", {})
                if not response_obj:
                    log.debug("[STREAM COLLECTOR] No 'response' key in chunk, trying direct access")
                    response_obj = chunk

                candidates = response_obj.get("candidates", [])
                log.debug(f"[STREAM COLLECTOR] Found {len(candidates)} candidates")
                if not candidates:
                    log.debug(f"[STREAM COLLECTOR] No candidates in chunk, chunk structure: {list(chunk.keys()) if isinstance(chunk, dict) else type(chunk)}")
                    continue

                # Only the first candidate is merged.
                candidate = candidates[0]

                content = candidate.get("content", {})
                parts = content.get("parts", [])
                log.debug(f"[STREAM COLLECTOR] Processing {len(parts)} parts from candidate")

                for part in parts:
                    if not isinstance(part, dict):
                        continue

                    text = part.get("text", "")
                    if text:
                        # Thought text is kept separate so it can be emitted
                        # as its own part ahead of the regular text.
                        if part.get("thought", False):
                            collected_thought_text.append(text)
                            log.debug(f"[STREAM COLLECTOR] Collected thought text: {text[:100]}")
                        else:
                            collected_text.append(text)
                            log.debug(f"[STREAM COLLECTOR] Collected regular text: {text[:100]}")

                    elif "inlineData" in part or "fileData" in part or "executableCode" in part or "codeExecutionResult" in part:
                        collected_other_parts.append(part)
                        log.debug(f"[STREAM COLLECTOR] Collected non-text part: {list(part.keys())}")

                # Last-writer-wins for candidate-level metadata.
                if candidate.get("finishReason"):
                    merged_response["response"]["candidates"][0]["finishReason"] = candidate["finishReason"]

                if candidate.get("safetyRatings"):
                    merged_response["response"]["candidates"][0]["safetyRatings"] = candidate["safetyRatings"]

                if candidate.get("citationMetadata"):
                    merged_response["response"]["candidates"][0]["citationMetadata"] = candidate["citationMetadata"]

                # Usage metadata is cumulative upstream, so updating with the
                # latest chunk keeps the final totals.
                usage = response_obj.get("usageMetadata", {})
                if usage:
                    merged_response["response"]["usageMetadata"].update(usage)

            except json.JSONDecodeError as e:
                log.debug(f"[STREAM COLLECTOR] Failed to parse JSON chunk: {e}")
                continue
            except Exception as e:
                log.debug(f"[STREAM COLLECTOR] Error processing chunk: {e}")
                continue

    except Exception as e:
        # Failure while iterating the stream itself (not a single bad chunk).
        log.error(f"[STREAM COLLECTOR] Error collecting stream after {line_count} lines: {e}")
        return Response(
            content=json.dumps({"error": f"收集流式响应失败: {str(e)}"}),
            status_code=500,
            media_type="application/json"
        )

    log.debug(f"[STREAM COLLECTOR] Finished iteration, has_data={has_data}, line_count={line_count}")

    # A stream that produced no parseable chunk at all is an error.
    if not has_data:
        log.error(f"[STREAM COLLECTOR] No data collected from stream after {line_count} lines")
        return Response(
            content=json.dumps({"error": "No data collected from stream"}),
            status_code=500,
            media_type="application/json"
        )

    # Assemble final parts: thought text first, then regular text, then
    # any non-text parts in arrival order.
    final_parts = []

    if collected_thought_text:
        final_parts.append({
            "text": "".join(collected_thought_text),
            "thought": True
        })

    if collected_text:
        final_parts.append({
            "text": "".join(collected_text)
        })

    final_parts.extend(collected_other_parts)

    # Guarantee at least one part so the response shape stays valid.
    if not final_parts:
        final_parts.append({"text": ""})

    merged_response["response"]["candidates"][0]["content"]["parts"] = final_parts

    log.info(f"[STREAM COLLECTOR] Collected {len(collected_text)} text chunks, {len(collected_thought_text)} thought chunks, and {len(collected_other_parts)} other parts")

    # The skeleton above always wraps the payload in "response" without a
    # top-level "candidates" key, so this condition holds and the envelope
    # is stripped before serialization.
    if "response" in merged_response and "candidates" not in merged_response:
        log.debug(f"[STREAM COLLECTOR] 展开response包装")
        merged_response = merged_response["response"]

    return Response(
        content=json.dumps(merged_response, ensure_ascii=False).encode('utf-8'),
        status_code=200,
        headers={},
        media_type="application/json"
    )
|
|
|
|
def parse_quota_reset_timestamp(error_response: dict) -> Optional[float]:
    """
    Extract the quota reset timestamp from a Google API error response.

    Args:
        error_response: error response dict as returned by the Google API

    Returns:
        Unix timestamp in seconds, or None when it cannot be parsed.

    Example error response:
        {
            "error": {
                "code": 429,
                "message": "You have exhausted your capacity...",
                "status": "RESOURCE_EXHAUSTED",
                "details": [
                    {
                        "@type": "type.googleapis.com/google.rpc.ErrorInfo",
                        "reason": "QUOTA_EXHAUSTED",
                        "metadata": {
                            "quotaResetTimeStamp": "2025-11-30T14:57:24Z",
                            "quotaResetDelay": "13h19m1.20964964s"
                        }
                    }
                ]
            }
        }
    """
    try:
        detail_entries = error_response.get("error", {}).get("details", [])

        for entry in detail_entries:
            # Only ErrorInfo details carry the quota metadata.
            if entry.get("@type") != "type.googleapis.com/google.rpc.ErrorInfo":
                continue

            ts_text = entry.get("metadata", {}).get("quotaResetTimeStamp")
            if not ts_text:
                continue

            # fromisoformat on older Pythons rejects the trailing "Z".
            if ts_text.endswith("Z"):
                ts_text = ts_text.replace("Z", "+00:00")

            parsed = datetime.fromisoformat(ts_text)
            # A naive timestamp is treated as UTC.
            if parsed.tzinfo is None:
                parsed = parsed.replace(tzinfo=timezone.utc)

            return parsed.astimezone(timezone.utc).timestamp()

        return None

    except Exception:
        # Best-effort parsing: any malformed input yields None.
        return None
|
|
def get_model_group(model_name: str) -> str:
    """
    Resolve the model group used by the GCLI cooldown mechanism.

    Args:
        model_name: model name

    Returns:
        "pro" or "flash"

    Notes:
        - pro group: gemini-2.5-pro and gemini-3-pro-preview share quota
        - flash group: gemini-2.5-flash has its own quota
    """
    # Anything mentioning "flash" (case-insensitive) is flash; all other
    # models fall into the shared pro group.
    return "flash" if "flash" in model_name.lower() else "pro"