Priyansh Saxena committed on
Commit
962831e
·
1 Parent(s): 4ef5712

fix: make model loading offline-safe for Spaces runtime

Browse files
Files changed (4) hide show
  1. Dockerfile +3 -0
  2. README.md +1 -0
  3. app.py +3 -3
  4. llm_agent.py +22 -9
Dockerfile CHANGED
@@ -11,6 +11,9 @@ RUN mkdir -p /app/data/uploads /app/static/images
11
 
12
  ENV TRANSFORMERS_CACHE=/app/.cache/huggingface/transformers
13
  ENV HF_HOME=/app/.cache/huggingface
 
 
 
14
 
15
  EXPOSE 7860
16
 
 
11
 
12
  ENV TRANSFORMERS_CACHE=/app/.cache/huggingface/transformers
13
  ENV HF_HOME=/app/.cache/huggingface
14
+ ENV HF_HUB_OFFLINE=1
15
+ ENV TRANSFORMERS_OFFLINE=1
16
+ ENV HF_HUB_DISABLE_TELEMETRY=1
17
 
18
  EXPOSE 7860
19
 
README.md CHANGED
@@ -6,6 +6,7 @@ colorTo: purple
6
  sdk: docker
7
  sdk_version: "1.0"
8
  app_file: app.py
 
9
  pinned: false
10
  ---
11
 
 
6
  sdk: docker
7
  sdk_version: "1.0"
8
  app_file: app.py
9
+ app_port: 7860
10
  pinned: false
11
  ---
12
 
app.py CHANGED
@@ -54,12 +54,12 @@ def index():
54
  def models():
55
  return jsonify({
56
  "models": [
57
- {"id": "qwen", "name": "Qwen2.5-1.5B", "provider": "Local (transformers)", "free": True},
58
  {"id": "bart", "name": "BART (fine-tuned)", "provider": "Local (transformers)", "free": True},
59
  {"id": "gemini", "name": "Gemini 2.0 Flash", "provider": "Google AI (API key)", "free": False},
60
  {"id": "grok", "name": "Grok-3 Mini", "provider": "xAI (API key)", "free": False},
61
  ],
62
- "default": "qwen"
63
  })
64
 
65
 
@@ -70,7 +70,7 @@ def plot():
70
  if not data or not data.get('query'):
71
  return jsonify({'error': 'Missing required field: query'}), 400
72
 
73
- logging.info(f"Plot request: model={data.get('model','qwen')} query={data.get('query')[:80]}")
74
  result = agent.process_request(data)
75
  logging.info(f"Plot completed in {time.time() - t0:.2f}s")
76
  return jsonify(result)
 
54
  def models():
55
  return jsonify({
56
  "models": [
57
+ {"id": "qwen", "name": "Qwen2.5-1.5B", "provider": "Local (optional path)", "free": True},
58
  {"id": "bart", "name": "BART (fine-tuned)", "provider": "Local (transformers)", "free": True},
59
  {"id": "gemini", "name": "Gemini 2.0 Flash", "provider": "Google AI (API key)", "free": False},
60
  {"id": "grok", "name": "Grok-3 Mini", "provider": "xAI (API key)", "free": False},
61
  ],
62
+ "default": "bart"
63
  })
64
 
65
 
 
70
  if not data or not data.get('query'):
71
  return jsonify({'error': 'Missing required field: query'}), 400
72
 
73
+ logging.info(f"Plot request: model={data.get('model','bart')} query={data.get('query')[:80]}")
74
  result = agent.process_request(data)
75
  logging.info(f"Plot completed in {time.time() - t0:.2f}s")
76
  return jsonify(result)
llm_agent.py CHANGED
@@ -13,6 +13,10 @@ load_dotenv()
13
 
14
  logger = logging.getLogger(__name__)
15
 
 
 
 
 
16
  # ---------------------------------------------------------------------------
17
  # Prompt templates
18
  # ---------------------------------------------------------------------------
@@ -103,10 +107,13 @@ class LLM_Agent:
103
  def _run_qwen(self, user_msg: str) -> str:
104
  if self._qwen_model is None:
105
  from transformers import AutoModelForCausalLM, AutoTokenizer
106
- model_id = "Qwen/Qwen2.5-1.5B-Instruct"
 
 
 
107
  logger.info("Loading Qwen model (first request)...")
108
- self._qwen_tokenizer = AutoTokenizer.from_pretrained(model_id)
109
- self._qwen_model = AutoModelForCausalLM.from_pretrained(model_id)
110
  logger.info("Qwen model loaded.")
111
  messages = [
112
  {"role": "system", "content": _SYSTEM_PROMPT},
@@ -155,10 +162,10 @@ class LLM_Agent:
155
  def _run_bart(self, query: str) -> str:
156
  if self._bart_model is None:
157
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
158
- model_id = "ArchCoder/fine-tuned-bart-large"
159
  logger.info("Loading BART model (first request)...")
160
- self._bart_tokenizer = AutoTokenizer.from_pretrained(model_id)
161
- self._bart_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
162
  logger.info("BART model loaded.")
163
  inputs = self._bart_tokenizer(
164
  query, return_tensors="pt", max_length=512, truncation=True
@@ -172,7 +179,7 @@ class LLM_Agent:
172
  t0 = time.time()
173
  query = data.get("query", "")
174
  data_path = data.get("file_path")
175
- model = data.get("model", "qwen")
176
 
177
  if data_path and os.path.exists(data_path):
178
  self.data_processor = DataProcessor(data_path)
@@ -194,8 +201,14 @@ class LLM_Agent:
194
  user_msg = _user_message(query, columns, dtypes, sample_rows)
195
  if model == "gemini": raw_text = self._run_gemini(user_msg)
196
  elif model == "grok": raw_text = self._run_grok(user_msg)
197
- elif model == "bart": raw_text = self._run_bart(query)
198
- else: raw_text = self._run_qwen(user_msg)
 
 
 
 
 
 
199
 
200
  logger.info(f"LLM [{model}] output: {raw_text}")
201
  parsed = _parse_output(raw_text)
 
13
 
14
  logger = logging.getLogger(__name__)
15
 
16
+
17
+ def _model_dir(dirname: str) -> str:
18
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), dirname)
19
+
20
  # ---------------------------------------------------------------------------
21
  # Prompt templates
22
  # ---------------------------------------------------------------------------
 
107
  def _run_qwen(self, user_msg: str) -> str:
108
  if self._qwen_model is None:
109
  from transformers import AutoModelForCausalLM, AutoTokenizer
110
+ # Prefer a local model path in Spaces to avoid any runtime network dependency.
111
+ model_id = os.getenv("QWEN_LOCAL_PATH", "")
112
+ if not model_id:
113
+ raise ValueError("Qwen local model is not configured in this Space")
114
  logger.info("Loading Qwen model (first request)...")
115
+ self._qwen_tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
116
+ self._qwen_model = AutoModelForCausalLM.from_pretrained(model_id, local_files_only=True)
117
  logger.info("Qwen model loaded.")
118
  messages = [
119
  {"role": "system", "content": _SYSTEM_PROMPT},
 
162
  def _run_bart(self, query: str) -> str:
163
  if self._bart_model is None:
164
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
165
+ model_id = os.getenv("BART_LOCAL_PATH", _model_dir("fine-tuned-bart-large"))
166
  logger.info("Loading BART model (first request)...")
167
+ self._bart_tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
168
+ self._bart_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, local_files_only=True)
169
  logger.info("BART model loaded.")
170
  inputs = self._bart_tokenizer(
171
  query, return_tensors="pt", max_length=512, truncation=True
 
179
  t0 = time.time()
180
  query = data.get("query", "")
181
  data_path = data.get("file_path")
182
+ model = data.get("model", "bart")
183
 
184
  if data_path and os.path.exists(data_path):
185
  self.data_processor = DataProcessor(data_path)
 
201
  user_msg = _user_message(query, columns, dtypes, sample_rows)
202
  if model == "gemini": raw_text = self._run_gemini(user_msg)
203
  elif model == "grok": raw_text = self._run_grok(user_msg)
204
+ elif model == "qwen":
205
+ try:
206
+ raw_text = self._run_qwen(user_msg)
207
+ except Exception as qwen_exc:
208
+ logger.warning(f"Qwen unavailable, falling back to BART: {qwen_exc}")
209
+ raw_text = self._run_bart(query)
210
+ else:
211
+ raw_text = self._run_bart(query)
212
 
213
  logger.info(f"LLM [{model}] output: {raw_text}")
214
  parsed = _parse_output(raw_text)