bbkdevops's picture
download
raw
7.54 kB
"""
TinyMind Reasoning Engine — Chain-of-Thought + Self-Consistency
ทำให้โมเดลฉลาดขึ้นด้วย inference-time compute:
1. <think>...</think> CoT — โมเดลคิดก่อนตอบ
2. Self-consistency — sample N คำตอบ แล้ว majority vote
3. Best-of-N — เลือกคำตอบที่ score สูงสุด
4. Verification loop — ตรวจคำตอบแล้ว refine ถ้าล้มเหลว
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Callable
# Literal tag strings delimiting the reasoning and final-answer sections.
# NOTE(review): the functions below hard-code these same tags in their regexes
# and f-strings rather than referencing these constants.
THINK_OPEN = "<think>"
THINK_CLOSE = "</think>"
ANSWER_TAG = "<answer>"
ANSWER_END = "</answer>"
# Thai system prompt that makes the model think before answering: it asks for
# step-by-step analysis inside <think>...</think>, a clear final answer inside
# <answer>...</answer>, and shows a worked example of that layout.
COT_SYSTEM_PROMPT = """คุณคือ TinyMind ผู้ช่วย AI ที่ฉลาดและคิดอย่างเป็นระบบ
วิธีตอบ:
1. คิดวิเคราะห์ปัญหาอย่างละเอียดใน <think>...</think>
2. ให้คำตอบสุดท้ายที่ชัดเจนใน <answer>...</answer>
ตัวอย่าง:
<think>
ปัญหาถามเกี่ยวกับ X... ต้องพิจารณา A, B, C...
ถ้า A แล้ว... ถ้า B แล้ว... ดังนั้น...
</think>
<answer>คำตอบที่ถูกต้องและครบถ้วน</answer>"""
# English counterpart of COT_SYSTEM_PROMPT with the same structure.
COT_SYSTEM_EN = """You are TinyMind, an AI assistant that thinks systematically before answering.
How to respond:
1. Analyze the problem thoroughly inside <think>...</think>
2. Give a clear final answer inside <answer>...</answer>
Example:
<think>
The question asks about X... I need to consider A, B, C...
If A then... If B then... Therefore...
</think>
<answer>The correct and complete answer</answer>"""
def build_cot_prompt(question: str, lang: str = "auto") -> str:
    """Build a generation prompt that primes the model to reason first.

    The prompt is ``<bos>``, a ``<system>`` turn holding the Thai or English
    chain-of-thought instructions, a ``<user>`` turn with *question*, and an
    open ``<assistant>`` turn that already starts with ``<think>`` so that
    generation begins inside the reasoning block.

    Args:
        question: The user's question.
        lang: ``"th"`` forces the Thai system prompt; ``"auto"`` uses Thai
            when ``_detect_thai(question)`` is true; any other value selects
            the English system prompt.
    """
    if lang == "th":
        use_thai = True
    elif lang == "auto":
        use_thai = _detect_thai(question)
    else:
        use_thai = False

    system = COT_SYSTEM_PROMPT if use_thai else COT_SYSTEM_EN
    turns = [
        f"<bos><system>{system}</system>",
        f"<user>{question}</user>",
        "<assistant><think>",
    ]
    return "\n".join(turns)
def _detect_thai(text: str) -> bool:
thai_chars = sum(1 for c in text if "฀" <= c <= "๿")
return thai_chars / max(len(text), 1) > 0.15
def extract_thinking(text: str) -> tuple[str, str]:
    """Split a model response into its reasoning trace and its final answer.

    Thinking:
        The content of the first ``<think>...</think>`` block. If there is no
        such block but a ``</think>`` is present, everything before the first
        ``</think>`` is taken as the thinking trace. That is the shape of a
        raw continuation of ``build_cot_prompt``, whose prompt already ends
        with the opening ``<think>``.

    Answer:
        The content of the first ``<answer>...</answer>`` pair found after the
        thinking block (falling back to anywhere in the text). Without a
        complete pair, it is the text after the thinking block (or the whole
        text when there is none), with a dangling leading ``<answer>`` removed.

    Tag matching is case-insensitive.

    Returns:
        ``(thinking, answer)``, both whitespace-stripped; ``thinking`` is
        ``""`` when no thinking section is found.
    """
    think_match = re.search(r"<think>([\s\S]*?)</think>", text, re.IGNORECASE)
    if think_match:
        thinking = think_match.group(1).strip()
        rest_start = think_match.end()
    else:
        # Continuations of a prompt ending in "<think>" only contain the closer.
        close_match = re.search(r"</think>", text, re.IGNORECASE)
        if close_match:
            thinking = text[: close_match.start()].strip()
            rest_start = close_match.end()
        else:
            thinking = ""
            rest_start = 0

    answer_re = re.compile(r"<answer>([\s\S]*?)</answer>", re.IGNORECASE)
    # Prefer an answer that follows the reasoning so an <answer> tag quoted
    # inside the thinking is not mistaken for the final answer.
    answer_match = answer_re.search(text, rest_start) or answer_re.search(text)
    if answer_match:
        return thinking, answer_match.group(1).strip()

    # Fallback: whatever follows the thinking, minus an unclosed <answer> opener
    # (e.g. a generation truncated before </answer>).
    answer = text[rest_start:].strip()
    answer = re.sub(r"^<answer>", "", answer, flags=re.IGNORECASE).strip()
    return thinking, answer
# ─── Self-Consistency ─────────────────────────────────────────────────────────
@dataclass
class SampledAnswer:
    """One sampled model response, its parsed parts, and a ranking score."""

    raw: str  # presumably the unparsed response text as sampled — confirm with caller
    thinking: str  # reasoning trace for this sample
    answer: str  # final answer; this is what majority_vote / best_of_n compare and score
    score: float = 0.0  # overwritten in place by majority_vote and best_of_n
def _normalize(text: str) -> str:
text = text.lower().strip()
text = re.sub(r"\s+", " ", text)
text = re.sub(r"[^\wก-๙ ]", "", text)
return text
def _jaccard(a: str, b: str) -> float:
    """Jaccard similarity between the normalized word sets of *a* and *b*.

    Both strings are passed through ``_normalize`` and split on whitespace.
    The result is |intersection| / |union|, or 0.0 when either side has no
    tokens.
    """
    left, right = (set(_normalize(s).split()) for s in (a, b))
    if left and right:
        return len(left & right) / len(left | right)
    return 0.0
def majority_vote(samples: list[SampledAnswer], min_agreement: float = 0.3) -> SampledAnswer:
    """Return the sample whose answer agrees most with the others (soft majority).

    Agreement between two samples is the ``_jaccard`` similarity of their
    answers. Every sample's ``score`` is overwritten in place with its mean
    agreement against all *other* samples. The sample with the highest mean
    is returned; on ties the earliest one wins.

    Args:
        samples: Candidates to vote over.
        min_agreement: Accepted by the signature but not read anywhere in
            this function.

    Returns:
        The winning sample. With exactly one sample it is returned as-is and
        its ``score`` is left untouched.

    Raises:
        ValueError: If ``samples`` is empty.
    """
    # NOTE(review): min_agreement is unused — decide whether a best score
    # below it should be flagged or rejected.
    if not samples:
        raise ValueError("no samples")
    total = len(samples)
    if total == 1:
        return samples[0]

    # Compute every mean first; scores are written only once all are known.
    mean_agreement: list[float] = []
    for cand_idx, candidate in enumerate(samples):
        running = 0.0
        for peer_idx, peer in enumerate(samples):
            if peer_idx == cand_idx:
                continue  # a sample never votes for itself
            running += _jaccard(candidate.answer, peer.answer)
        mean_agreement.append(running / (total - 1))

    for sample, value in zip(samples, mean_agreement):
        sample.score = value

    winner = max(range(total), key=mean_agreement.__getitem__)
    return samples[winner]
def best_of_n(
    samples: list[SampledAnswer],
    score_fn: Callable[[str, str], float],
    question: str,
) -> SampledAnswer:
    """Score every sample with *score_fn* and return the highest-scoring one.

    Each sample's ``score`` attribute is overwritten in place with
    ``score_fn(question, sample.answer)``. On ties the earliest sample wins.

    Args:
        samples: Candidates to rank.
        score_fn: Called as ``score_fn(question, answer)``; higher is better.
        question: The question the samples answer, forwarded to *score_fn*.

    Raises:
        ValueError: If ``samples`` is empty ("no samples", matching
            ``majority_vote``).
    """
    if not samples:
        # Explicit message instead of max()'s "arg is an empty sequence".
        raise ValueError("no samples")
    for sample in samples:
        sample.score = score_fn(question, sample.answer)
    return max(samples, key=lambda s: s.score)
# ─── Reasoning Trainer Data Format ───────────────────────────────────────────
def format_cot_training_example(
    question: str,
    thinking: str,
    answer: str,
    lang: str = "th",
) -> dict:
    """Build a supervised training record whose target contains a thinking trace.

    The ``"answer"`` field is the full target response:
    ``<think>`` + *thinking* + ``</think>`` followed by *answer* wrapped in
    ``<answer>`` tags. The raw *thinking* and *answer* are also kept separately.
    ``"context"`` holds the system prompt wrapped in ``<system>`` tags: Thai
    when ``lang == "th"``, English otherwise.
    """
    system_prompt = COT_SYSTEM_EN if lang != "th" else COT_SYSTEM_PROMPT
    target = f"<think>\n{thinking}\n</think>\n<answer>{answer}</answer>"
    return dict(
        question=question,
        answer=target,
        thinking=thinking,
        final_answer=answer,
        lang=lang,
        source="cot_reasoning",
        context=f"<system>{system_prompt}</system>",
    )
def format_dpo_example(
    question: str,
    chosen_thinking: str,
    chosen_answer: str,
    rejected_answer: str,
    lang: str = "th",
) -> dict:
    """Build a DPO preference pair: a reasoned answer versus a direct one.

    The *chosen* completion wraps *chosen_thinking* in ``<think>`` tags and
    *chosen_answer* in ``<answer>`` tags. The *rejected* completion is
    *rejected_answer* verbatim, i.e. an answer given without reasoning.
    ``lang == "th"`` selects the Thai system prompt; anything else selects
    English.
    """
    if lang == "th":
        system_prompt = COT_SYSTEM_PROMPT
    else:
        system_prompt = COT_SYSTEM_EN
    chosen_text = "<think>\n{}\n</think>\n<answer>{}</answer>".format(
        chosen_thinking, chosen_answer
    )
    pair: dict = {}
    pair["question"] = question
    pair["system"] = system_prompt
    pair["chosen"] = chosen_text
    pair["rejected"] = rejected_answer
    pair["lang"] = lang
    pair["source"] = "dpo_cot"
    return pair
# ─── Verification Loop ────────────────────────────────────────────────────────
@dataclass
class VerificationPass:
    """Record of a single attempt in the verify-and-refine loop."""

    attempt: int  # attempt counter (presumably starting at 1 — confirm with caller)
    answer: str  # final answer produced on this attempt
    thinking: str  # reasoning trace produced on this attempt
    passed: bool  # whether this attempt's answer passed verification
    issues: list[str] = field(default_factory=list)  # problems found; a fresh list per instance
def build_refinement_prompt(
    question: str,
    previous_answer: str,
    issues: list[str],
    lang: str = "th",
) -> str:
    """Build a prompt asking the model to revise an answer that failed checks.

    The prompt replays *question* as a ``<user>`` turn, followed by an
    ``<assistant>`` turn that quotes *previous_answer*, lists *issues* as
    bullet lines, and asks for a correction. It ends by opening a new
    ``<assistant>`` turn with ``<think>``, the same way ``build_cot_prompt``
    ends, so generation starts inside a reasoning block of an assistant turn.

    Args:
        question: The original user question.
        previous_answer: The answer that needs correcting.
        issues: Problems found; each becomes a ``"- "`` bullet line.
        lang: ``"th"`` for Thai labels; any other value uses English.
    """
    issue_text = "\n".join(f"- {issue}" for issue in issues)
    if lang == "th":
        prev_label = "คำตอบก่อนหน้า:"
        issues_label = "ปัญหาที่พบ:"
        instruction = "กรุณาคิดใหม่และแก้ไข:"
    else:
        prev_label = "Previous answer:"
        issues_label = "Issues found:"
        instruction = "Please reconsider and correct:"
    return (
        f"<user>{question}</user>\n"
        f"<assistant>{prev_label} {previous_answer}\n\n"
        f"{issues_label}\n{issue_text}\n\n"
        f"{instruction}</assistant>\n"
        # Open a fresh assistant turn; previously <think> dangled outside any role tag.
        "<assistant><think>"
    )

Xet Storage Details

Size:
7.54 kB
·
Xet hash:
7466aba4c13f0f65b65a36a165e9fd353076cfe566405ab9fada483bc25e3d4c

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.