havinashpatil commited on
Commit
18261cc
Β·
1 Parent(s): 0c0a8ff

Add Hugging Face Inference API fallback for AI fixer

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -0
  2. server/ai_fixer.py +38 -21
requirements.txt CHANGED
@@ -11,3 +11,4 @@ datasets
11
  trl
12
  accelerate
13
  bitsandbytes
 
 
11
  trl
12
  accelerate
13
  bitsandbytes
14
+ huggingface_hub
server/ai_fixer.py CHANGED
@@ -34,23 +34,39 @@ def check_tgi_availability(tgi_url: str = TGI_BASE_URL) -> bool:
34
  return TGI_AVAILABLE
35
 
36
 
37
- def fix_with_tgi(code: str, tgi_url: str = TGI_BASE_URL) -> Optional[str]:
38
- """Use TGI for advanced code fixing."""
39
- if not TGI_AVAILABLE and not check_tgi_availability(tgi_url):
40
- return None
 
 
 
41
 
42
- prompt = f"""You are an expert competitive programmer.
43
 
44
- Fix the following Python code:
45
- - Remove syntax errors
46
- - Ensure correct logic
47
- - Optimize to O(n) if possible
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- Code:
50
- {code}
 
 
 
51
 
52
- Return ONLY the corrected code without any explanation:
53
- """
54
 
55
  try:
56
  response = httpx.post(
@@ -61,21 +77,22 @@ Return ONLY the corrected code without any explanation:
61
  "max_tokens": 500,
62
  "temperature": 0.3
63
  },
64
- timeout=30.0
65
  )
66
  response.raise_for_status()
67
  result = response.json()
68
  fixed_code = result["choices"][0]["message"]["content"].strip()
69
 
70
- # Clean up the response
71
- if "Return ONLY the corrected code" in fixed_code:
72
- fixed_code = fixed_code.split("Return ONLY the corrected code")[-1].strip()
73
-
74
- return fixed_code if fixed_code else None
75
 
76
  except Exception as e:
77
  print(f"TGI fix error: {e}", file=sys.stderr)
78
- return None
 
79
 
80
 
81
  # ─── Pattern-Based Fixes ─────────────────────────────────────────────────────
@@ -517,7 +534,7 @@ def generate_fix(
517
  Returns: { fixed_code, method, success, explanation }
518
  """
519
  if use_tgi:
520
- fixed_code = fix_with_tgi(code, tgi_url=tgi_url)
521
  if fixed_code and validate_code(fixed_code):
522
  # Log complexity vs reward for research tracking
523
  complexity = detect_complexity(fixed_code)
 
34
  return TGI_AVAILABLE
35
 
36
 
37
+ def fix_with_hf_api(code: str, error_log: str = "") -> Optional[str]:
38
+ """Use Hugging Face Serverless Inference API as a fallback."""
39
+ try:
40
+ from huggingface_hub import InferenceClient
41
+ model = "Qwen/Qwen2.5-Coder-3B-Instruct"
42
+ token = os.environ.get("HF_TOKEN")
43
+ client = InferenceClient(model=model, token=token)
44
 
45
+ prompt = f"You are an expert competitive programmer.\n\nFix the following Python code:\n- Remove syntax errors\n- Ensure correct logic\n- Optimize to O(n) if possible\n\nPrevious Error:\n{error_log}\n\nCode:\n{code}\n\nReturn ONLY the corrected code without any explanation wrapped in ```python ... ```:"
46
 
47
+ response = client.chat_completion(
48
+ messages=[{"role": "user", "content": prompt}],
49
+ max_tokens=500,
50
+ temperature=0.3
51
+ )
52
+ result = response.choices[0].message.content.strip()
53
+
54
+ import re
55
+ code_match = re.search(r'```python\n(.*?)\n```', result, re.DOTALL)
56
+ if not code_match:
57
+ code_match = re.search(r'```(.*?)```', result, re.DOTALL)
58
+ return code_match.group(1).strip() if code_match else result.replace("```", "").strip()
59
+ except Exception as e:
60
+ print(f"HF API fix error: {e}", file=sys.stderr)
61
+ return None
62
 
63
+ def fix_with_tgi(code: str, tgi_url: str = TGI_BASE_URL, error_log: str = "") -> Optional[str]:
64
+ """Use TGI for advanced code fixing. Fallbacks to HF Serverless API if unavailable."""
65
+ if not TGI_AVAILABLE and not check_tgi_availability(tgi_url):
66
+ print("TGI unavailable. Falling back to HF Serverless Inference API...", file=sys.stderr)
67
+ return fix_with_hf_api(code, error_log)
68
 
69
+ prompt = f"You are an expert competitive programmer.\n\nFix the following Python code:\n- Remove syntax errors\n- Ensure correct logic\n- Optimize to O(n) if possible\n\nCode:\n{code}\n\nReturn ONLY the corrected code without any explanation wrapped in ```python ... ```:"
 
70
 
71
  try:
72
  response = httpx.post(
 
77
  "max_tokens": 500,
78
  "temperature": 0.3
79
  },
80
+ timeout=10.0
81
  )
82
  response.raise_for_status()
83
  result = response.json()
84
  fixed_code = result["choices"][0]["message"]["content"].strip()
85
 
86
+ import re
87
+ code_match = re.search(r'```python\n(.*?)\n```', fixed_code, re.DOTALL)
88
+ if not code_match:
89
+ code_match = re.search(r'```(.*?)```', fixed_code, re.DOTALL)
90
+ return code_match.group(1).strip() if code_match else fixed_code.replace("```", "").strip()
91
 
92
  except Exception as e:
93
  print(f"TGI fix error: {e}", file=sys.stderr)
94
+ print("Falling back to HF Serverless Inference API...", file=sys.stderr)
95
+ return fix_with_hf_api(code, error_log)
96
 
97
 
98
  # ─── Pattern-Based Fixes ─────────────────────────────────────────────────────
 
534
  Returns: { fixed_code, method, success, explanation }
535
  """
536
  if use_tgi:
537
+ fixed_code = fix_with_tgi(code, tgi_url=tgi_url, error_log=error_log)
538
  if fixed_code and validate_code(fixed_code):
539
  # Log complexity vs reward for research tracking
540
  complexity = detect_complexity(fixed_code)