# MahjongGameDesigner / reference_retriever.py
# zhongchuyi
# feature:优化 2026_01_13
# e43bcf1
"""
参考玩法检索注入(RAG-lite)
目标:
- 用本地示例玩法库(.md 主真理 + _mGDL_v1.3.txt 辅语法翻译)为当前用户需求挑选少量最相关参考
- 避免把所有示例全量注入导致注意力稀释
"""
import os
import re
from typing import Dict, List, Optional, Tuple
from cache_manager import file_cache
# Matches "<name>_mGDL_v1.3.txt" and captures the variant name.
_MGDL_SUFFIX_RE = re.compile(r"^(?P<name>.+?)_mGDL_v1\.3\.txt$")
# Matches maximal runs of CJK unified ideographs.
_CJK_RUN_RE = re.compile(r"[\u4e00-\u9fff]+")
# Loose stop-term list: reduces the noise that "generic" terms add to
# keyword recall (extend as needed).
_STOP_TERMS = {
    "麻将", "玩法", "规则", "玩家", "游戏", "进行", "阶段", "流程", "说明", "机制",
    "可以", "允许", "是否", "如果", "那么", "以及", "但是", "因为", "所以", "同时",
    "庄家", "闲家", "手牌", "摸牌", "打牌", "出牌", "胡牌", "自摸", "点炮", "一炮",
    "结算", "得分", "倍数", "番型", "番数", "牌墙", "弃牌", "顺序", "回合", "开始", "结束",
    "默认", "配置", "支持", "包含", "采用", "需要", "必须", "不得",
    # Extra "function/filler" words (kept from becoming anchors)
    "一个", "做一", "做个", "加入", "增加", "带有", "希望", "想要", "想做", "更快", "更刺激",
}
# "Domain hint" characters for anchor terms: bias anchors toward
# mechanic/rule nouns rather than incidental generic phrases.
# NOTE: this is a heuristic, but more general than hard-coding a
# specific variant or term.
_ANCHOR_HINT_CHARS = set("鸟马赖鬼杠胡听鸡豆缺换海捞承包庄风中发白炮分番倍封顶")
def _cjk_ngrams(text: str, min_n: int = 2, max_n: int = 4) -> List[str]:
    """
    Extract CJK character n-grams from *text* (keyword recall with no
    external dependencies; suited to a small local sample library).

    Terms listed in _STOP_TERMS are skipped; the result keeps the unique
    n-grams in first-seen order.
    """
    source = (text or "").strip()
    if not source:
        return []
    collected: List[str] = []
    for segment in _CJK_RUN_RE.findall(source):
        if not segment:
            continue
        # Cap very long runs so we do not construct an excessive number
        # of n-grams.
        segment = segment[:2000]
        length = len(segment)
        for start in range(length):
            for size in range(min_n, max_n + 1):
                end = start + size
                if end > length:
                    continue
                gram = segment[start:end]
                if gram not in _STOP_TERMS:
                    collected.append(gram)
    # De-duplicate while preserving first-seen order (small scale is fine).
    return list(dict.fromkeys(collected))
def _base_dir() -> str:
return os.path.dirname(__file__)
def _read_text(path: str) -> str:
    """
    Read *path* as UTF-8 text through the shared file cache.

    Best-effort: any failure (missing file, cache error, ...) yields ""
    instead of raising.
    """
    hit = file_cache.get(path)
    if hit is not None:
        return hit
    try:
        with open(path, "r", encoding="utf-8", errors="ignore") as fh:
            content = fh.read()
        file_cache.set(path, content)
        return content
    except Exception:
        # Deliberate best-effort read; callers treat "" as "no content".
        return ""
def _build_variant_index() -> Dict[str, Dict[str, str]]:
    """
    Scan the module directory and map each variant name to its source files.

    Returns, e.g.:
        {
            "疯狂血战": {"md_path": "...", "mgdl_path": "..."},
            ...
        }
    Entries may carry either or both paths; on any scan error an empty
    mapping is returned (best-effort).
    """
    base = _base_dir()
    index: Dict[str, Dict[str, str]] = {}
    try:
        for filename in os.listdir(base):
            full = os.path.join(base, filename)
            if filename.endswith(".md") and filename != "README.md":
                index.setdefault(filename[:-3], {})["md_path"] = full
            elif filename.endswith("_mGDL_v1.3.txt"):
                matched = _MGDL_SUFFIX_RE.match(filename)
                if matched:
                    index.setdefault(matched.group("name"), {})["mgdl_path"] = full
    except Exception:
        return {}
    # Keep only entries that carry at least one of md/mgdl.
    result: Dict[str, Dict[str, str]] = {}
    for name, paths in index.items():
        picked = {key: paths[key] for key in ("md_path", "mgdl_path") if paths.get(key)}
        if picked:
            result[name] = picked
    return result
def list_variant_names() -> List[str]:
    """
    List the variant names available in the local library (including
    "麻将机制说明"), longest names first.
    """
    return sorted(_build_variant_index(), key=len, reverse=True)
def match_variants_in_text(text: str) -> List[str]:
    """
    Return the variant names mentioned in the user text (longer names
    take precedence).
    """
    return _find_mentions(text, list_variant_names())
def load_variant_md(name: str) -> str:
    """
    Return the .md content for variant *name*, or "" when it is missing.
    """
    entry = _build_variant_index().get(name) or {}
    path = entry.get("md_path")
    if path and os.path.exists(path):
        return _read_text(path) or ""
    return ""
def _find_mentions(text: str, candidates: List[str]) -> List[str]:
"""
朴素子串匹配(中文玩法名通常稳定),返回按“更长优先”的去重命中列表。
"""
s = (text or "").strip()
if not s:
return []
hits: List[str] = []
# 长词优先,避免“血战”误命中“疯狂血战”
for name in sorted(candidates, key=lambda x: len(x), reverse=True):
if name and name in s:
hits.append(name)
# 去重保持顺序
seen = set()
uniq: List[str] = []
for h in hits:
if h not in seen:
seen.add(h)
uniq.append(h)
return uniq
# Lazily-built module-level caches (populated by the helpers below).
_TERM_CACHE: Optional[Dict[str, set]] = None  # variant name -> keyword set
_TERM_DF: Optional[Dict[str, int]] = None  # term -> number of variants containing it
_TERM_POSTINGS: Optional[Dict[str, List[str]]] = None  # term -> variants containing it
_DOMAIN_TERMS: Optional[set] = None  # union of all variants' keyword sets
def _build_domain_terms() -> set:
    """
    Build the "domain vocabulary" used to filter user n-grams (lowers the
    odds of a generic phrase being treated as an anchor).

    Strategy:
    - no external dependencies (such as jieba)
    - the vocabulary is simply the union of the keyword sets extracted from
      every variant .md file, which generalizes well enough here
    """
    global _DOMAIN_TERMS
    if _DOMAIN_TERMS is None:
        vocab: set = set()
        for name, entry in _build_variant_index().items():
            if name == "麻将机制说明":
                continue
            path = entry.get("md_path")
            if not path or not os.path.exists(path):
                continue
            text = _read_text(path)
            if text.strip():
                vocab |= set(_cjk_ngrams(text, min_n=2, max_n=4))
        _DOMAIN_TERMS = vocab
    return _DOMAIN_TERMS
def _build_variant_term_cache() -> Dict[str, set]:
    """
    Build per-variant keyword features from each variant's .md file, and
    cache them — together with document frequencies and postings — at
    module level.

    Returns: {variant_name: {term1, term2, ...}}
    """
    global _TERM_CACHE, _TERM_DF, _TERM_POSTINGS
    if _TERM_CACHE is not None and _TERM_DF is not None:
        return _TERM_CACHE
    features: Dict[str, set] = {}
    for name, entry in _build_variant_index().items():
        # The mechanism glossary is injected separately; it is not a
        # "reference variant" candidate.
        if name == "麻将机制说明":
            continue
        path = entry.get("md_path")
        if not path or not os.path.exists(path):
            continue
        text = _read_text(path)
        if not text.strip():
            continue
        # Keywords come from the .md only: it is the content "source of
        # truth" and closer to the user's own terminology than the mGDL.
        features[name] = set(_cjk_ngrams(text, min_n=2, max_n=4))
    # Document frequency (in how many variants each term appears) feeds the
    # IDF-lite weighting; postings map each term back to its variants.
    freq: Dict[str, int] = {}
    postings: Dict[str, List[str]] = {}
    for variant, terms in features.items():
        for term in terms:
            freq[term] = freq.get(term, 0) + 1
            postings.setdefault(term, []).append(variant)
    _TERM_CACHE = features
    _TERM_DF = freq
    _TERM_POSTINGS = postings
    return features
def _score_by_terms(message: str, term_cache: Dict[str, set]) -> List[Tuple[str, int]]:
    """
    Score variants by keyword overlap with *message* (higher = more relevant).

    Each overlapping term contributes int(len(term) * 100 / df), so rare
    terms (low document frequency) weigh more: an IDF-lite scheme that
    keeps generic overlaps from drowning out distinctive terms such as
    "扎鸟".

    Returns (variant, score) pairs sorted by descending score; empty list
    when no user term survives the domain-vocabulary filter.
    """
    # Fix: the original declared `global _TERM_DF` (unneeded for reads) and
    # rebuilt `(_TERM_DF or {})` for every term of every variant; the lookup
    # table is hoisted once here.
    domain = _build_domain_terms()
    # Keep only user n-grams that appear in the domain vocabulary; otherwise
    # many generic n-grams would pollute the overlap.
    user_terms = {t for t in _cjk_ngrams(message, min_n=2, max_n=4) if t in domain}
    if not user_terms:
        return []
    df_map = _TERM_DF or {}  # loop-invariant; default df=9999 ~ "no weight"
    scored: List[Tuple[str, int]] = []
    for variant, terms in term_cache.items():
        if not terms:
            continue
        overlap = user_terms & terms
        if not overlap:
            continue
        score = sum(
            int((len(term) * 100) / max(1, df_map.get(term, 9999)))
            for term in overlap
        )
        scored.append((variant, score))
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored
def _pick_anchor_terms(message: str, max_terms: int = 3) -> List[str]:
    """
    Pick high-discrimination "anchor" terms from the user input (generic,
    explainable, no hard-coded variant/term names).

    Rules:
    - only terms found in the domain vocabulary are eligible
    - rarer (lower df) and longer (more specific) terms rank first
    """
    _build_variant_term_cache()
    domain = _build_domain_terms()
    df_map = _TERM_DF or {}
    candidates = {
        t
        for t in _cjk_ngrams(message, min_n=2, max_n=4)
        if t in domain and t not in _STOP_TERMS
    }
    if not candidates:
        return []

    def _informative(term: str) -> bool:
        # Anchors must carry at least one domain-hint character, biasing
        # them toward mechanic/rule nouns.
        if not term or term in _STOP_TERMS:
            return False
        return any(ch in _ANCHOR_HINT_CHARS for ch in term)

    ranked = [
        (term, df_map.get(term), len(term))
        for term in candidates
        if df_map.get(term) and _informative(term)
    ]
    # df ascending (df=1 is best), length descending, then lexical for
    # deterministic tie-breaking.
    ranked.sort(key=lambda item: (item[1], -item[2], item[0]))
    return [term for term, _, _ in ranked[: max(1, max_terms)]]
def pick_reference_variants(
    message: str,
    max_variants: int = 3,
    fallback: Optional[List[str]] = None,
) -> List[str]:
    """
    Pick reference variant names for the user input (names only; no file
    content is read here).

    Preference order:
    1. variants explicitly mentioned in the message
    2. keyword recall (overlap score plus anchor-term boost)
    3. caller-supplied fallback
    4. a default pool spanning different rule families
    """
    index = _build_variant_index()
    names = list(index.keys())
    limit = max(1, max_variants)
    explicit = _find_mentions(message, names)
    if explicit:
        return explicit[:limit]
    # No explicit variant name: use keyword recall (e.g. the user mentioned
    # terms such as "扎鸟" / "买马" / "承包").
    term_cache = _build_variant_term_cache()
    base_scores = dict(_score_by_terms(message, term_cache))
    # Anchor hits act as a weighted boost rather than being forced to the
    # front and truncated.
    postings = _TERM_POSTINGS or {}
    df_map = _TERM_DF or {}
    boost_scores: Dict[str, int] = {}
    for term in _pick_anchor_terms(message, max_terms=3):
        weight = int((len(term) * 100) / max(1, df_map.get(term, 9999)))
        for variant in postings.get(term, []):
            boost_scores[variant] = boost_scores.get(variant, 0) + weight
    if base_scores or boost_scores:
        combined = [
            (variant, base_scores.get(variant, 0) + boost_scores.get(variant, 0))
            for variant in set(base_scores) | set(boost_scores)
        ]
        combined.sort(key=lambda item: item[1], reverse=True)
        return [variant for variant, _ in combined[:limit]]
    if fallback:
        return fallback[:limit]
    # Last resort: a pool covering different rule families, giving the model
    # a baseline reference surface.
    default_pool = [
        "疯狂血战",
        "疯狂血流",
        "广东100张",
        "贵州捉鸡麻将",
        "妙手七星",
    ]
    return [n for n in default_pool if n in names][:limit]
def build_reference_pack(
    message: str,
    max_variants: int = 3,
    include_mechanism_library: bool = True,
    include_mgdl: bool = False,
) -> Dict[str, str]:
    """
    Assemble the text sections that can be injected into the system message.

    Keys that may appear in the result:
    - "mechanism_library": shared mechanism glossary (when enabled and present)
    - "reference_md": concatenated .md files of the picked variants
    - "reference_mgdl": concatenated mGDL files (when include_mgdl is True)
    - "picked_names": comma-separated picked variant names (always present)
    """
    index = _build_variant_index()
    picked = pick_reference_variants(message, max_variants=max_variants)
    sections: Dict[str, str] = {}
    if include_mechanism_library:
        mech_path = os.path.join(_base_dir(), "麻将机制说明.md")
        if os.path.exists(mech_path):
            mech_txt = _read_text(mech_path).strip()
            if mech_txt:
                sections["mechanism_library"] = mech_txt

    def _chunk(path: str) -> str:
        # One labeled chunk per source file; "" when the file is empty.
        body = _read_text(path).strip()
        if not body:
            return ""
        return "\n# FILE: {0}\n{1}\n".format(os.path.basename(path), body)

    md_chunks: List[str] = []
    mgdl_chunks: List[str] = []
    for name in picked:
        entry = index.get(name) or {}
        md_path = entry.get("md_path")
        if md_path and os.path.exists(md_path):
            piece = _chunk(md_path)
            if piece:
                md_chunks.append(piece)
        if include_mgdl:
            mgdl_path = entry.get("mgdl_path")
            if mgdl_path and os.path.exists(mgdl_path):
                piece = _chunk(mgdl_path)
                if piece:
                    mgdl_chunks.append(piece)
    if md_chunks:
        sections["reference_md"] = "\n".join(md_chunks).strip()
    if mgdl_chunks:
        sections["reference_mgdl"] = "\n".join(mgdl_chunks).strip()
    # Pass through which references were picked so the model can stay focused.
    sections["picked_names"] = ", ".join(picked)
    return sections