Spaces:
Sleeping
Sleeping
| import time | |
| import logging | |
| from typing import List | |
| from google import genai | |
| from src.recommendation_engine.config import ( | |
| GEMINI_API_KEY, | |
| MODEL_CANDIDATES, | |
| IDEA_TEMPERATURE, | |
| FEATURE_TEMPERATURE, | |
| CHAT_TEMPERATURE, | |
| INTENT_TEMPERATURE, | |
| IDEA_MAX_TOKENS, | |
| FEATURE_MAX_TOKENS, | |
| CHAT_MAX_TOKENS, | |
| INTENT_MAX_TOKENS, | |
| FULL_PROJECT_MAX_TOKENS, | |
| TOP_P, | |
| TOP_K, | |
| MAX_RETRIES, | |
| RETRY_DELAY_SECONDS, | |
| ENABLE_LOGGING | |
| ) | |
| from src.recommendation_engine.validator import validate_generated_list | |
| logger = logging.getLogger(__name__) | |
| if ENABLE_LOGGING: | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s | %(levelname)s | %(message)s" | |
| ) | |
| class LLMProviderError(Exception): | |
| def __init__(self, message: str, status_code: int = 503): | |
| super().__init__(message) | |
| self.message = message | |
| self.status_code = status_code | |
| def classify_provider_error(error: Exception): | |
| text = str(error).lower() | |
| if ( | |
| "reported as leaked" in text | |
| or "permission_denied" in text | |
| or "api key" in text | |
| or "403" in text | |
| ): | |
| return LLMProviderError( | |
| "Gemini API key was rejected. Create a new key, update .env, and restart the server.", | |
| status_code=503 | |
| ) | |
| if ( | |
| "resource_exhausted" in text | |
| or "quota" in text | |
| or "rate limit" in text | |
| or "429" in text | |
| ): | |
| return LLMProviderError( | |
| "Gemini quota or rate limit is exhausted. Try again later or use another API key/project.", | |
| status_code=429 | |
| ) | |
| return None | |
| client = genai.Client(api_key=GEMINI_API_KEY) | |
| def extract_text(response) -> str: | |
| if not response: | |
| return "" | |
| text = getattr(response, "text", None) | |
| if text: | |
| return text.strip() | |
| try: | |
| candidates = getattr(response, "candidates", []) | |
| if candidates: | |
| parts = candidates[0].content.parts | |
| return " ".join( | |
| p.text for p in parts if hasattr(p, "text") | |
| ).strip() | |
| except Exception: | |
| pass | |
| return "" | |
| def get_temperature(task: str) -> float: | |
| return { | |
| "idea": IDEA_TEMPERATURE, | |
| "feature": FEATURE_TEMPERATURE, | |
| "intent": INTENT_TEMPERATURE, | |
| }.get(task, CHAT_TEMPERATURE) | |
| def get_max_tokens(task: str) -> int: | |
| return { | |
| "idea": IDEA_MAX_TOKENS, | |
| "feature": FEATURE_MAX_TOKENS, | |
| "intent": INTENT_MAX_TOKENS, | |
| "full_project": FULL_PROJECT_MAX_TOKENS, | |
| }.get(task, CHAT_MAX_TOKENS) | |
| def safe_prompt(prompt: str, max_chars: int = 12000) -> str: | |
| return prompt[-max_chars:] | |
| def is_bad_response(text: str) -> bool: | |
| if not text: | |
| return True | |
| text = text.strip() | |
| if len(text) < 3: | |
| return True | |
| bad_phrases = [ | |
| "as an ai", | |
| "i can help you", | |
| "let me know" | |
| ] | |
| lower = text.lower() | |
| if all(p in lower for p in bad_phrases): | |
| return True | |
| return False | |
| def generate_text( | |
| prompt: str, | |
| task: str = "chat", | |
| temperature=None | |
| ) -> str: | |
| prompt = safe_prompt(prompt) | |
| if temperature is None: | |
| temperature = get_temperature(task) | |
| max_tokens = get_max_tokens(task) | |
| for model_name in MODEL_CANDIDATES: | |
| for attempt in range(MAX_RETRIES): | |
| try: | |
| logger.info( | |
| f"[LLM] model={model_name} | task={task} | attempt={attempt+1}" | |
| ) | |
| response = client.models.generate_content( | |
| model=model_name, | |
| contents=prompt, | |
| config={ | |
| "temperature": temperature, | |
| "top_p": TOP_P, | |
| "top_k": TOP_K, | |
| "max_output_tokens": max_tokens | |
| } | |
| ) | |
| text = extract_text(response) | |
| if is_bad_response(text): | |
| logger.warning("[LLM] Weak response, using anyway") | |
| return text | |
| return text | |
| except Exception as e: | |
| logger.warning(f"[LLM ERROR] {e}") | |
| provider_error = classify_provider_error(e) | |
| if provider_error: | |
| if provider_error.status_code == 429 and attempt < MAX_RETRIES - 1: | |
| sleep_time = (RETRY_DELAY_SECONDS * 5) * (attempt + 1) | |
| logger.info(f"[LLM 429] Rate limited. Retrying in {sleep_time}s...") | |
| time.sleep(sleep_time) | |
| continue | |
| raise provider_error | |
| time.sleep(RETRY_DELAY_SECONDS * (attempt + 1)) | |
| logger.info(f"[LLM] switching model...") | |
| logger.error("All LLM models failed") | |
| return "" | |
| def generate_list(prompt: str, task="chat") -> List[str]: | |
| text = generate_text(prompt, task=task) | |
| return validate_generated_list(text) | |