Spaces:
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -28,18 +28,25 @@ SYSTEM_PROMPT = "You are MathBioAgent, an expert AI assistant specialized in mat
|
|
| 28 |
|
| 29 |
@spaces.GPU(duration=60)
|
| 30 |
def chat(message, history):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
| 32 |
for h in history:
|
| 33 |
if isinstance(h, dict):
|
| 34 |
-
messages.append(h)
|
| 35 |
elif isinstance(h, (list, tuple)) and len(h) == 2:
|
| 36 |
-
messages.append({"role": "user", "content": h[0]})
|
| 37 |
-
messages.append({"role": "assistant", "content": h[1]})
|
| 38 |
-
messages.append({"role": "user", "content": message})
|
| 39 |
-
|
| 40 |
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 41 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 42 |
-
|
| 43 |
with torch.no_grad():
|
| 44 |
outputs = model.generate(
|
| 45 |
**inputs,
|
|
|
|
| 28 |
|
| 29 |
@spaces.GPU(duration=60)
|
| 30 |
def chat(message, history):
|
| 31 |
+
def extract_text(content):
|
| 32 |
+
if isinstance(content, str):
|
| 33 |
+
return content
|
| 34 |
+
if isinstance(content, list):
|
| 35 |
+
return " ".join(c.get("text", "") if isinstance(c, dict) else str(c) for c in content)
|
| 36 |
+
return str(content)
|
| 37 |
+
|
| 38 |
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
| 39 |
for h in history:
|
| 40 |
if isinstance(h, dict):
|
| 41 |
+
messages.append({"role": h["role"], "content": extract_text(h.get("content", ""))})
|
| 42 |
elif isinstance(h, (list, tuple)) and len(h) == 2:
|
| 43 |
+
messages.append({"role": "user", "content": extract_text(h[0])})
|
| 44 |
+
messages.append({"role": "assistant", "content": extract_text(h[1])})
|
| 45 |
+
messages.append({"role": "user", "content": extract_text(message)})
|
| 46 |
+
|
| 47 |
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 48 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 49 |
+
|
| 50 |
with torch.no_grad():
|
| 51 |
outputs = model.generate(
|
| 52 |
**inputs,
|