File size: 2,978 Bytes
ae1d809 c5b5c64 ae1d809 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 | import gradio as gr
import cv2
from pathlib import Path
from fastcdm import FastCDM
CHROMEDRIVER_PATH = Path("driver/chromedriver")
def _wrap_latex(s: str) -> str:
s = s or ""
return s if s.strip().startswith("$$") else f"$$ {s} $$"
def preview_latex_gt(gt: str) -> str:
return _wrap_latex(gt)
def preview_latex_pred(pred: str) -> str:
return _wrap_latex(pred)
def compute_fastcdm(gt: str, pred: str):
print("-" * 20)
print(" gt:", gt)
print("pred:", pred)
print("-" * 20)
driver_path = str(CHROMEDRIVER_PATH) if CHROMEDRIVER_PATH.exists() else None
fastcdm = FastCDM(chromedriver=driver_path)
try:
f1, recall, precision, vis_img = fastcdm.compute(gt, pred, visualize=True)
if vis_img is not None:
vis_rgb = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
else:
vis_rgb = None
metrics_md = f"**CDM得分(F1)**: {f1:.4f} \n**召回率**: {recall:.4f} \n**准确率**: {precision:.4f}"
return metrics_md, vis_rgb
finally:
fastcdm.close()
with gr.Blocks(title="FastCDM 可视化") as demo:
gr.Markdown("# FastCDM 可视化")
with gr.Row():
with gr.Column():
gt_input = gr.Textbox(
label="GT (LaTeX)",
lines=4,
placeholder="输入GT公式,例如: \\frac{1}{2}",
)
gt_md = gr.Markdown(
value="",
latex_delimiters=[
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
{"left": "\\(", "right": "\\)", "display": False},
{"left": "\\[", "right": "\\]", "display": True},
],
)
pred_input = gr.Textbox(
label="Pred (LaTeX)",
lines=4,
placeholder="输入Pred公式,例如: \\frac{1}{2}",
)
pred_md = gr.Markdown(
value="",
latex_delimiters=[
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
{"left": "\\(", "right": "\\)", "display": False},
{"left": "\\[", "right": "\\]", "display": True},
],
)
submit_btn = gr.Button("提交并评估")
with gr.Column():
metrics_out = gr.Markdown(label="评估指标")
vis_out = gr.Image(type="numpy", label="匹配可视化", format="png")
gt_input.change(fn=preview_latex_gt, inputs=gt_input, outputs=gt_md)
pred_input.change(fn=preview_latex_pred, inputs=pred_input, outputs=pred_md)
submit_btn.click(
fn=compute_fastcdm,
inputs=[gt_input, pred_input],
outputs=[metrics_out, vis_out],
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)
|