OnlyCheeini committed on
Commit
51356d9
·
verified ·
1 Parent(s): 161df1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -62
app.py CHANGED
@@ -1,109 +1,151 @@
1
  import gradio as gr
2
  import torch
3
  import json
 
4
  from pathlib import Path
5
  from model import (
6
  GreesyGPT,
7
  generate_moderation,
 
 
 
8
  ReasoningMode,
9
  OutputFormat,
10
  DEVICE,
11
- describe_reasoning_modes
 
12
  )
13
 
14
# 1. Initialize Model
model = GreesyGPT()  # global instance shared by all UI callbacks
weights_path = Path("greesy_gpt.pt")  # checkpoint expected next to app.py

# Best-effort checkpoint restore; a fresh (randomly initialized) model still
# lets the UI run, it just produces untrained output.
if weights_path.exists():
    model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
    print(f"Loaded weights from {weights_path}")
else:
    print("Warning: No trained weights found. Using fresh initialization.")

model.to(DEVICE)
model.eval()  # inference-only app: freeze dropout/batch-norm behavior
26
 
 
27
def moderate(text, mode_str, format_str):
    """Run the moderation model on *text*.

    Args:
        text: Message to analyze.
        mode_str: Reasoning mode name (matched case-insensitively to ReasoningMode).
        format_str: Output format name (matched case-insensitively to OutputFormat).

    Returns:
        Tuple of (verdict markdown/JSON string, thinking-process text) for the
        two Gradio output components.
    """
    if not text.strip():
        # Empty input: return a friendly message and clear the thinking pane.
        return "Please enter some text to analyze.", ""

    mode = ReasoningMode(mode_str.lower())
    fmt = OutputFormat(format_str.lower())

    result = generate_moderation(
        model,
        prompt=text,
        mode=mode,
        output_format=fmt
    )

    verdict_output = result["verdict_fmt"]
    if fmt == OutputFormat.JSON:
        # Wrap the raw structure in a fenced block so gr.Markdown renders it.
        verdict_output = f"```json\n{json.dumps(verdict_output, indent=2)}\n```"

    thinking_process = result.get("thinking", "No reasoning generated.")
    return verdict_output, thinking_process
48
 
49
# 2. Build Gradio UI
theme = gr.themes.Soft(primary_hue="orange", secondary_hue="gray")

with gr.Blocks(theme=theme, title="GreesyGPT Content Moderation") as demo:
    gr.Markdown("# 🛡️ GreesyGPT Content Moderation")
    gr.Markdown("Reasoning-based safety model using chain-of-thought deliberation.")

    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="Message to Review",
                placeholder="Type the message you want to moderate here...",
                lines=5
            )

            with gr.Row():
                mode_dropdown = gr.Dropdown(
                    choices=[m.value for m in ReasoningMode],
                    value="low",
                    label="Reasoning Mode"
                )
                format_dropdown = gr.Dropdown(
                    choices=[f.value for f in OutputFormat],
                    value="markdown",
                    label="Output Format"
                )

            submit_btn = gr.Button("Analyze Content", variant="primary")

        # Right column: verdict plus collapsible chain-of-thought.
        with gr.Column(scale=3):
            output_verdict = gr.Markdown(label="Verdict")

            # FIXED: Changed Expander to Accordion
            with gr.Accordion("View Internal Reasoning (Thinking Process)", open=False):
                output_thinking = gr.Textbox(
                    label="Chain of Thought",
                    interactive=False,
                    lines=10
                )

    # Clickable example rows that pre-fill all three inputs.
    gr.Examples(
        examples=[
            ["You're so stupid, nobody likes you.", "medium", "markdown"],
            ["How do I fix a bug in my Python code?", "none", "markdown"],
            ["CONGRATULATIONS! You won a $1000 gift card! Click here!", "low", "json"],
        ],
        inputs=[input_text, mode_dropdown, format_dropdown]
    )

    # FIXED: Changed Expander to Accordion
    with gr.Accordion("System Information / Reasoning Mode Definitions", open=False):
        gr.Code(describe_reasoning_modes(), language="text")

    # Wire the button to the inference callback.
    submit_btn.click(
        fn=moderate,
        inputs=[input_text, mode_dropdown, format_dropdown],
        outputs=[output_verdict, output_thinking]
    )
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
# Script entry point for local development.
if __name__ == "__main__":
    demo.launch()
 
 
1
  import gradio as gr
2
  import torch
3
  import json
4
+ import os
5
  from pathlib import Path
6
  from model import (
7
  GreesyGPT,
8
  generate_moderation,
9
+ GreesyTrainer,
10
+ get_dataset,
11
+ get_sample_dataset,
12
  ReasoningMode,
13
  OutputFormat,
14
  DEVICE,
15
+ describe_reasoning_modes,
16
+ DATASET_JSON_PATH
17
  )
18
 
19
# 1. Initialize Model Global Instance
model = GreesyGPT()  # shared by the inference AND training tabs
weights_path = Path("greesy_gpt.pt")  # checkpoint read at startup, written after training
22
 
23
def load_weights():
    """Load checkpoint weights into the global ``model``, if a checkpoint exists.

    Returns:
        A human-readable status string (also displayed in the System Info tab).
    """
    if weights_path.exists():
        # weights_only=True: a state_dict is pure tensors/containers, so
        # restrict torch.load's unpickling surface — a plain torch.load would
        # execute arbitrary code embedded in a tampered checkpoint file.
        state = torch.load(weights_path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state)
        return f"Loaded weights from {weights_path}"
    return "No weights found. Model initialized with random parameters."
28
 
29
load_weights()  # best-effort restore at import time; status string is discarded here
model.to(DEVICE)
 
31
 
32
# --- Inference Logic ---
def moderate(text, mode_str, format_str):
    """Moderate *text* and return (formatted verdict, chain-of-thought text)."""
    if not text.strip():
        # Nothing to analyze — keep both output panes populated.
        return "Please enter text.", ""

    # The training tab may have left the shared model in train() mode.
    model.eval()
    reasoning = ReasoningMode(mode_str.lower())
    out_fmt = OutputFormat(format_str.lower())

    result = generate_moderation(
        model, prompt=text, mode=reasoning, output_format=out_fmt
    )

    verdict = result["verdict_fmt"]
    if out_fmt == OutputFormat.JSON:
        # Fence the raw structure so gr.Markdown renders it as a JSON block.
        verdict = f"```json\n{json.dumps(verdict, indent=2)}\n```"

    return verdict, result.get("thinking", "No reasoning generated.")
49
 
50
# --- Training Logic ---
def start_training(epochs, batch_size, grad_accum):
    """Generator callback: fine-tune the global model, streaming log text.

    Yields the accumulated log after every epoch so the Gradio textbox
    updates live; any failure is reported in the log instead of crashing.
    """
    try:
        # Prefer the on-disk dataset; otherwise fall back to the built-in sample.
        if DATASET_JSON_PATH.exists():
            dataset = get_dataset()
            data_source = "dataset.json"
        else:
            dataset = get_sample_dataset()
            data_source = "Hardcoded Sample Data"

        trainer = GreesyTrainer(
            model=model,
            train_dataset=dataset,
            batch_size=int(batch_size),   # sliders deliver numbers, coerce to int
            grad_accum=int(grad_accum)
        )

        log_lines = [f"Starting training on {DEVICE} using {data_source}..."]

        for epoch in range(1, int(epochs) + 1):
            epoch_loss = trainer.train_epoch(epoch)
            log_lines.append(f"Epoch {epoch}: Loss = {epoch_loss:.4f}")
            yield "\n".join(log_lines)

        # Persist the fine-tuned weights so the next restart picks them up.
        torch.save(model.state_dict(), weights_path)
        log_lines.append(f"Success: Weights saved to {weights_path}")
        yield "\n".join(log_lines)

    except Exception as e:
        # Surface the failure in the log pane rather than crashing the UI.
        yield f"Error during training: {str(e)}"
82
+
83
# --- UI Layout ---
# NOTE(review): `theme` is not passed to gr.Blocks() here; the author passes it
# to demo.launch() in the __main__ guard on the assumption that Gradio 6.0+
# accepts it there — confirm against the installed Gradio version.
theme = gr.themes.Soft(primary_hue="orange", secondary_hue="gray")

with gr.Blocks() as demo:
    gr.Markdown("# 🛡️ GreesyGPT Control Center")

    with gr.Tabs():
        # TAB 1: MODERATION (INFERENCE)
        with gr.Tab("Moderation Interface"):
            with gr.Row():
                # Left column: inputs and controls.
                with gr.Column(scale=2):
                    input_text = gr.Textbox(label="Message to Review", lines=5)
                    with gr.Row():
                        mode_dropdown = gr.Dropdown(
                            choices=[m.value for m in ReasoningMode],
                            value="low", label="Reasoning Mode"
                        )
                        format_dropdown = gr.Dropdown(
                            choices=[f.value for f in OutputFormat],
                            value="markdown", label="Output Format"
                        )
                    submit_btn = gr.Button("Analyze Content", variant="primary")

                # Right column: verdict plus collapsible chain-of-thought.
                with gr.Column(scale=3):
                    output_verdict = gr.Markdown(label="Verdict")
                    with gr.Accordion("Internal Reasoning (Thinking)", open=False):
                        output_thinking = gr.Textbox(label="", interactive=False, lines=10)

        # TAB 2: TRAINING
        with gr.Tab("Model Training"):
            gr.Markdown("### Fine-tune GreesyGPT")
            with gr.Row():
                epoch_slider = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Epochs")
                batch_slider = gr.Slider(minimum=1, maximum=8, value=2, step=1, label="Batch Size")
                accum_slider = gr.Slider(minimum=1, maximum=16, value=4, step=1, label="Grad Accumulation")

            train_btn = gr.Button("🚀 Start Training Session", variant="stop")
            train_logs = gr.Textbox(label="Training Logs", interactive=False, lines=10)

        # TAB 3: SYSTEM INFO
        with gr.Tab("System Info"):
            gr.Markdown("### Reasoning Mode Definitions")
            gr.Code(describe_reasoning_modes(), language="markdown")
            # NOTE(review): this calls load_weights() a second time at UI-build
            # time (weights were already loaded at import) just to obtain the
            # status string to display.
            status_msg = gr.Textbox(value=load_weights(), label="Model Status", interactive=False)

    # --- Event Handlers ---
    submit_btn.click(
        fn=moderate,
        inputs=[input_text, mode_dropdown, format_dropdown],
        outputs=[output_verdict, output_thinking]
    )

    # start_training is a generator, so train_logs streams updates per epoch.
    train_btn.click(
        fn=start_training,
        inputs=[epoch_slider, batch_slider, accum_slider],
        outputs=[train_logs]
    )

    # Clickable example rows that pre-fill the moderation inputs.
    gr.Examples(
        examples=[
            ["You're so stupid, nobody likes you.", "medium", "markdown"],
            ["CONGRATULATIONS! You won a $1000 prize!", "low", "json"],
        ],
        inputs=[input_text, mode_dropdown, format_dropdown]
    )
148
+
149
if __name__ == "__main__":
    # In Gradio 6.0+, theme and title are passed here (per the author).
    # NOTE(review): confirm against the installed Gradio version — pre-6
    # releases accept theme/title on gr.Blocks(), not on launch(), and this
    # call would raise a TypeError there.
    demo.launch(theme=theme, title="GreesyGPT Content Moderation")