File size: 2,978 Bytes
ae1d809
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5b5c64
 
 
 
 
 
 
 
 
 
 
 
ae1d809
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio as gr
import cv2
from pathlib import Path
from fastcdm import FastCDM


# Optional bundled ChromeDriver binary. The path is relative, so it is looked
# up from the current working directory; compute_fastcdm only uses it if it exists.
CHROMEDRIVER_PATH = Path("driver") / "chromedriver"


# Prefixes meaning the text is already math-delimited. The preview Markdown
# components recognize $$...$$, $...$, \(...\) and \[...\]; "$" also covers "$$".
_MATH_PREFIXES = ("$", "\\(", "\\[")


def _wrap_latex(s: str) -> str:
    """Return *s* as Markdown that renders as display math.

    Blank input (``None``, ``""`` or whitespace only) yields ``""`` so that a
    cleared textbox clears its preview instead of rendering an empty
    ``$$  $$`` pair. Input that already starts with a math delimiter (after
    leading whitespace) is returned unchanged, because wrapping it again would
    nest math delimiters. Anything else is wrapped in ``$$ ... $$``.
    """
    s = s or ""
    stripped = s.strip()
    if not stripped:
        return ""
    if stripped.startswith(_MATH_PREFIXES):
        return s
    return f"$$ {s} $$"


def preview_latex_gt(gt: str) -> str:
    """Build the Markdown preview for the GT formula textbox."""
    return _wrap_latex(gt)


def preview_latex_pred(pred: str) -> str:
    """Build the Markdown preview for the Pred formula textbox."""
    return _wrap_latex(pred)


def compute_fastcdm(gt: str, pred: str):
    """Score *pred* against *gt* with FastCDM and build the UI outputs.

    Args:
        gt: Ground-truth LaTeX formula from the GT textbox.
        pred: Predicted LaTeX formula from the Pred textbox.

    Returns:
        A ``(metrics_md, vis_rgb)`` pair for the metrics Markdown and the
        image component. ``vis_rgb`` is FastCDM's visualization converted
        from BGR to RGB. It is ``None`` when FastCDM returned no image or
        when either input was blank (in that case ``metrics_md`` is a prompt
        asking for both formulas).
    """
    print("-" * 20)
    print("  gt:", gt)
    print("pred:", pred)
    print("-" * 20)

    # A blank submission has nothing to score. Skip launching a ChromeDriver
    # session and ask the user for both formulas instead.
    if not (gt or "").strip() or not (pred or "").strip():
        return "请同时输入GT和Pred公式。", None

    # Use the bundled driver when present; otherwise pass None (presumably
    # FastCDM then locates a driver itself — confirm against FastCDM docs).
    driver_path = str(CHROMEDRIVER_PATH) if CHROMEDRIVER_PATH.exists() else None
    fastcdm = FastCDM(chromedriver=driver_path)
    try:
        f1, recall, precision, vis_img = fastcdm.compute(gt, pred, visualize=True)
    finally:
        # Release the driver as soon as scoring is done, even if compute() raises.
        fastcdm.close()

    # The visualization is treated as BGR; the numpy Image output is shown as RGB.
    vis_rgb = None if vis_img is None else cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)

    metrics_md = f"**CDM得分(F1)**: {f1:.4f}  \n**召回率**: {recall:.4f}  \n**准确率**: {precision:.4f}"
    return metrics_md, vis_rgb


# Math delimiters recognized by both formula previews. They are defined once
# so the GT and Pred previews cannot drift apart.
LATEX_DELIMITERS = [
    {"left": "$$", "right": "$$", "display": True},
    {"left": "$", "right": "$", "display": False},
    {"left": "\\(", "right": "\\)", "display": False},
    {"left": "\\[", "right": "\\]", "display": True},
]

# Two-column UI: formula inputs with live previews on the left, FastCDM
# metrics and the match visualization on the right.
with gr.Blocks(title="FastCDM 可视化") as demo:
    gr.Markdown("# FastCDM 可视化")

    with gr.Row():
        with gr.Column():
            gt_input = gr.Textbox(
                label="GT (LaTeX)",
                lines=4,
                placeholder="输入GT公式,例如: \\frac{1}{2}",
            )
            # Live rendered preview of the GT formula.
            gt_md = gr.Markdown(value="", latex_delimiters=LATEX_DELIMITERS)

            pred_input = gr.Textbox(
                label="Pred (LaTeX)",
                lines=4,
                placeholder="输入Pred公式,例如: \\frac{1}{2}",
            )
            # Live rendered preview of the Pred formula.
            pred_md = gr.Markdown(value="", latex_delimiters=LATEX_DELIMITERS)

            submit_btn = gr.Button("提交并评估")

        with gr.Column():
            metrics_out = gr.Markdown(label="评估指标")
            vis_out = gr.Image(type="numpy", label="匹配可视化", format="png")

    # Re-render each preview whenever its textbox changes.
    gt_input.change(fn=preview_latex_gt, inputs=gt_input, outputs=gt_md)
    pred_input.change(fn=preview_latex_pred, inputs=pred_input, outputs=pred_md)
    # Scoring runs only on explicit submit, since compute_fastcdm starts a FastCDM/ChromeDriver session.
    submit_btn.click(
        fn=compute_fastcdm,
        inputs=[gt_input, pred_input],
        outputs=[metrics_out, vis_out],
    )


if __name__ == "__main__":
    # Bind to all interfaces on port 7860 so the demo is reachable from other
    # hosts, not just localhost.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)