| import torch |
| import numpy as np |
| import pandas as pd |
| import gradio as gr |
| import tempfile |
| import subprocess |
| from matplotlib.animation import FFMpegWriter, PillowWriter |
| import matplotlib.pyplot as plt |
| from matplotlib import animation |
| from config import MAX_TEXT_LEN |
| from data import selective_smoothing, GLOBAL_MEAN_T, GLOBAL_STD_T |
| from model import TextToPoseSeq2Seq |
| from transformers import BertTokenizer |
| |
| from faster_whisper import WhisperModel |
|
|
| |
# IndoBERT tokenizer; its vocab size is passed to the model constructor below.
tokenizer = BertTokenizer.from_pretrained("indobenchmark/indobert-base-p2")


def _report_ffmpeg_version():
    """Print the first line of ``ffmpeg -version``, or why it could not run.

    Diagnostic only: a missing or broken ffmpeg must not stop the app from
    starting, because predict() falls back to GIF output when the MP4 save
    fails.
    """
    try:
        result = subprocess.run(["ffmpeg", "-version"], capture_output=True, timeout=30)
    except (OSError, subprocess.SubprocessError) as exc:
        # FileNotFoundError (an OSError) when ffmpeg is not on PATH;
        # TimeoutExpired (a SubprocessError) if it hangs.
        print("FFmpeg check: unavailable -", exc)
        return
    lines = result.stdout.decode(errors="replace").splitlines()
    print("FFmpeg check:", lines[0] if lines else "no version output")


_report_ffmpeg_version()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TextToPoseSeq2Seq(tokenizer.vocab_size).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only host.
# load_state_dict copies values into the parameters that already live on
# `device`, so a second .to(device) is unnecessary.
model.load_state_dict(torch.load("best_seq2seq_model_mask.pth", map_location=device))
model.eval()  # inference-only app
print("Loaded pretrained weights.")
|
|
| |
# Stripped, lower-cased entries of annotated.csv's "text" column. Missing
# cells are dropped so the set only holds strings.
annot_df = pd.read_csv("annotated.csv")
annotated_words = set(annot_df["text"].dropna().str.strip().str.lower().unique())

# Normalised text -> reference video URL from annotated_vid_link.csv.
# Rows whose text or URL is missing (NaN) are skipped, because callers use
# str methods on the URL (e.g. to build the YouTube embed). As with
# dict(zip(...)), a later duplicate key overwrites an earlier one.
video_df = pd.read_csv("annotated_vid_link.csv")
video_df["text_clean"] = video_df["text"].str.strip().str.lower()
video_lookup = {
    text: url
    for text, url in zip(video_df["text_clean"], video_df["Video URL"])
    if isinstance(text, str) and isinstance(url, str)
}

# faster-whisper "small" checkpoint with int8 compute; used by transcribe_audio().
whisper_model = WhisperModel("small", compute_type="int8")
|
|
def transcribe_audio(audio_path):
    """Transcribe a Malay audio file with the module-level Whisper model.

    Args:
        audio_path: path to the audio file (the Gradio Audio input uses
            ``type="filepath"``). It may be None or "" when Transcribe is
            clicked without any audio.

    Returns:
        The stripped text of every segment joined with single spaces, or ""
        when there is no audio or transcription fails (the error is printed).
    """
    if not audio_path:
        # Nothing to transcribe; avoid logging a misleading Whisper error.
        return ""
    try:
        segments, _ = whisper_model.transcribe(audio_path, language="ms", beam_size=5)
        # faster-whisper yields segments lazily, so decoding errors can
        # surface while iterating; keep the join inside the try.
        full_text = " ".join(segment.text.strip() for segment in segments)
        return full_text.strip()
    except Exception as e:
        print("Whisper Error:", e)
        return ""
|
|
def get_youtube_link(input_text):
    """Look up the reference video URL for a word or phrase.

    The lookup key is *input_text* stripped and lower-cased, the same
    normalisation used to build the ``video_lookup`` keys. Returns None when
    there is no entry.
    """
    normalised = input_text.strip().lower()
    return video_lookup.get(normalised)
|
|
|
|
| |
# Indices kept from the full per-frame keypoint vector: 0-24, then two
# 21-point blocks 501-521 and 522-542 (presumably the two hands - confirm
# against the data pipeline).
selected_keypoint_indices = list(
    np.concatenate([np.arange(lo, hi) for lo, hi in ((0, 25), (501, 522), (522, 543))])
)
NUM_KEYPOINTS = len(selected_keypoint_indices)
# Three values per keypoint (pose tensors are reshaped to (-1, NUM_KEYPOINTS, 3)).
POSE_DIM = 3 * NUM_KEYPOINTS
|
|
| |
# Skeleton edges as (i, j) index pairs into the selected keypoints. The
# numbering appears to follow MediaPipe Pose landmarks (face 0-10, arms
# 11-16, hips 23-24) - confirm against the data pipeline. Hand edges are
# added separately.
mediapipe_connections = [
    # face
    (0, 1), (1, 2), (2, 3), (3, 7),
    (0, 4), (4, 5), (5, 6), (6, 8),
    (9, 10),
    # shoulders and arms
    (11, 12),
    (12, 14), (14, 16),
    (11, 13), (13, 15),
    # hips and torso sides
    (23, 24),
    (11, 23), (12, 24),
]
|
|
def add_hand_connections(base_index):
    """Return the 20 edges of a 21-point hand, offset by *base_index*.

    Point 0 is the root. Each of the five fingers is a 4-point chain
    (1-4, 5-8, 9-12, 13-16, 17-20) connected back to the root, and the edges
    are emitted finger by finger, root edge first.
    """
    local_edges = []
    for finger in range(5):
        first = 1 + 4 * finger
        chain = [0, first, first + 1, first + 2, first + 3]
        local_edges.extend(zip(chain, chain[1:]))
    return [(base_index + a, base_index + b) for a, b in local_edges]
|
|
# Positions of the two 21-point blocks (starting at original indices 501 and
# 522) within the selected keypoints; each gets the hand edge set appended
# to the skeleton in place, first block first.
hand1, hand2 = (selected_keypoint_indices.index(first) for first in (501, 522))
mediapipe_connections.extend(add_hand_connections(hand1) + add_hand_connections(hand2))
|
|
| |
def concatenate_and_smooth_sequences(sentence, tokenizer, model, device, GLOBAL_MEAN_T, GLOBAL_STD_T):
    """Generate one continuous, smoothed pose sequence for a word or sentence.

    If the stripped, lower-cased input is an annotated entry, it is fed to the
    model as a single unit. Otherwise each whitespace-separated word is fed
    separately and the per-word outputs are concatenated along dimension 0.
    The result is passed through ``selective_smoothing`` and de-normalised
    with ``GLOBAL_STD_T`` / ``GLOBAL_MEAN_T``.

    Returns:
        ``(pose, conf, labels)``:
        - ``pose``: NumPy array reshaped to ``(-1, NUM_KEYPOINTS, 3)``
          (presumably one row per frame).
        - ``conf``: the concatenated confidence output as a NumPy array.
        - ``labels``: ``(word, first_frame, last_frame)`` tuples with
          inclusive frame indices.

        Returns ``(None, None, None)`` when the input has no words.
    """
    text = sentence.strip()
    units = [text] if text.lower() in annotated_words else text.split()

    poses, confs, labels = [], [], []
    offset = 0

    model.eval()
    with torch.no_grad():
        for unit in units:
            encoded = tokenizer(
                unit,
                padding="max_length",
                truncation=True,
                max_length=MAX_TEXT_LEN,
                return_tensors="pt",
            )
            pose_norm, conf = model(
                encoded.input_ids.to(device),
                attention_mask=encoded.attention_mask.to(device),
            )
            # Assumes batch-first outputs: item [0] is kept and shape[1] is
            # treated as this unit's frame count.
            n_frames = pose_norm.shape[1]
            poses.append(pose_norm[0])
            confs.append(conf[0])
            labels.append((unit, offset, offset + n_frames - 1))
            offset += n_frames

    if not poses:
        return None, None, None

    stacked_pose = torch.cat(poses, dim=0).unsqueeze(0)
    stacked_conf = torch.cat(confs, dim=0)
    denorm = selective_smoothing(stacked_pose).squeeze(0) * GLOBAL_STD_T + GLOBAL_MEAN_T
    return (
        denorm.view(-1, NUM_KEYPOINTS, 3).cpu().numpy(),
        stacked_conf.cpu().numpy(),
        labels,
    )
|
|
| |
def animate_pose(pred_pose, pred_conf=None, frame_labels=None, interval=150, conf_threshold=0.3, show_indices=True):
    """Build a matplotlib stick-figure animation of a predicted pose sequence.

    Args:
        pred_pose: array indexed as ``pred_pose[frame, keypoint, coord]``.
            Coordinate 0 is drawn as x and coordinate 1 as y. y is negated
            before plotting, so larger values appear lower. The array is not
            modified.
        pred_conf: optional confidences indexed as ``pred_conf[frame][keypoint]``.
            A keypoint whose confidence is not strictly greater than
            ``conf_threshold`` is hidden, along with every bone touching it.
            When None, everything is drawn.
        frame_labels: optional ``(word, start, end)`` tuples with inclusive
            frame ranges. The first range containing the current frame sets
            the axes title.
        interval: delay between frames in milliseconds (FuncAnimation).
        conf_threshold: confidence cut-off used with ``pred_conf``.
        show_indices: when True (the default, as before), draw each visible
            keypoint's index next to it.

    Returns:
        The ``FuncAnimation``. Its figure is left open so the animation can
        be saved; the caller should ``plt.close`` it afterwards.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    fig.subplots_adjust(top=0.85)

    # Fixed limits covering every frame. The data range is padded by 10% plus
    # 0.1, so a zero-width range still gets a visible window.
    all_x = pred_pose[:, :, 0].flatten()
    all_y = pred_pose[:, :, 1].flatten()
    x_buf = (all_x.max() - all_x.min()) * 0.1 + 0.1
    y_buf = (all_y.max() - all_y.min()) * 0.1 + 0.1
    ax.set_xlim(all_x.min() - x_buf, all_x.max() + x_buf)
    ax.set_ylim(-all_y.max() - y_buf, -all_y.min() + y_buf)
    ax.set_aspect('equal')
    ax.axis('off')

    pred_lines = [ax.plot([], [], color='red', lw=2)[0] for _ in mediapipe_connections]
    pred_pts = ax.plot([], [], 'ko', markersize=3)[0]
    # Text artists created for the frame currently shown. They are tracked
    # explicitly rather than by iterating ax.texts while removing from it,
    # which skips elements where ax.texts is a plain list (older matplotlib).
    index_labels = []

    def init():
        for line in pred_lines:
            line.set_data([], [])
        pred_pts.set_data([], [])
        return pred_lines + [pred_pts]

    def update(frame):
        # np.array(..., dtype=float) always copies. A plain slice would be a
        # view, and the NaN masking below would then write into pred_pose
        # (or fail for integer input).
        px = np.array(pred_pose[frame, :, 0], dtype=float)
        py = -np.array(pred_pose[frame, :, 1], dtype=float)
        if pred_conf is None:
            visible = np.ones(px.shape, dtype=bool)
        else:
            visible = np.asarray(pred_conf[frame]) > conf_threshold
            px[~visible] = np.nan
            py[~visible] = np.nan

        pred_pts.set_data(px, py)

        for label in index_labels:
            label.remove()
        index_labels.clear()
        if show_indices:
            for i in range(NUM_KEYPOINTS):
                # Hidden (NaN) points get no label; matplotlib cannot place
                # text at a non-finite position.
                if np.isfinite(px[i]) and np.isfinite(py[i]):
                    index_labels.append(ax.text(px[i], py[i], str(i), fontsize=5, color='black'))

        for line, (a, b) in zip(pred_lines, mediapipe_connections):
            if visible[a] and visible[b]:
                line.set_data([px[a], px[b]], [py[a], py[b]])
            else:
                line.set_data([], [])

        if frame_labels:
            for word, start, end in frame_labels:
                if start <= frame <= end:
                    ax.set_title(f'Prediction: “{word}” (Frames {start}–{end})', fontsize=12, pad=15)
                    break

        return pred_lines + [pred_pts]

    ani = animation.FuncAnimation(fig, update, frames=len(pred_pose), init_func=init, blit=True, interval=interval)
    # Lay out this figure explicitly, not whichever figure pyplot considers
    # "current" (shared state in a multi-request server).
    fig.tight_layout(pad=2.0)
    return ani
|
|
| |
def save_animation(anim, format="mp4"):
    """Render *anim* into a new temporary file and return the file's path.

    Args:
        anim: a matplotlib animation (anything with ``save(path, writer=...)``).
        format: ``"mp4"`` renders with FFMpegWriter (fps=10, bitrate=1800).
            Any other value renders a GIF with PillowWriter (fps=10). The name
            shadows the builtin but is kept because callers pass it by keyword.

    Returns:
        The temporary file path on success, or None if rendering failed (the
        error is printed). On success the file is kept (``delete=False``) for
        the caller to serve. On failure the empty or partial file is removed.
    """
    import os
    from contextlib import suppress

    ext = ".mp4" if format == "mp4" else ".gif"
    path = None
    try:
        writer = FFMpegWriter(fps=10, bitrate=1800) if format == "mp4" else PillowWriter(fps=10)
        # Close the handle before writing so the writer can reopen the path
        # (required on Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as f:
            path = f.name
        anim.save(path, writer=writer)
        return path
    except Exception as e:
        print(f"{format.upper()} save failed:", e)
        if path is not None:
            with suppress(OSError):
                os.remove(path)
        return None
|
|
def find_all_matches(text, video_lookup, max_phrase_len=3):
    """Greedily split *text* into the longest phrases found in *video_lookup*.

    The text is stripped, lower-cased and split on whitespace. Scanning left
    to right, the longest span of up to ``max_phrase_len`` words (at least
    two) that is a key of *video_lookup* is taken as one match. Otherwise the
    single word is emitted, paired with its value if present.

    Args:
        text: input sentence.
        video_lookup: mapping from lower-cased phrase to a value (a URL).
        max_phrase_len: longest multi-word span to try. The default of 3
            matches the previous fixed behaviour.

    Returns:
        A list of ``(phrase, value_or_None)`` tuples covering every word once,
        in order.
    """
    words = text.strip().lower().split()
    matched = []
    i = 0
    while i < len(words):
        # Longest span first, never running past the end of the input.
        for span in range(min(max_phrase_len, len(words) - i), 1, -1):
            phrase = " ".join(words[i:i + span])
            if phrase in video_lookup:
                matched.append((phrase, video_lookup[phrase]))
                i += span
                break
        else:
            word = words[i]
            matched.append((word, video_lookup[word] if word in video_lookup else None))
            i += 1
    return matched
|
|
|
|
def predict(text, threshold, show_videos=True):
    """Gradio handler: turn text into a pose animation plus a match report.

    Args:
        text: Malay word or sentence (None or blank is rejected).
        threshold: confidence threshold for showing joints in the animation.
        show_videos: whether to embed reference YouTube videos.

    Returns:
        ``(animation_path_or_None, markdown_report, video_html)``. On any
        failure the path is None and the report carries the error message.
    """
    if not text or not text.strip():
        return None, "⚠️ Please enter valid text.", ""

    anim = None
    try:
        pose, conf, labels = concatenate_and_smooth_sequences(text, tokenizer, model, device, GLOBAL_MEAN_T, GLOBAL_STD_T)
        if pose is None:
            return None, "⚠️ No pose predicted.", ""
        anim = animate_pose(pose, pred_conf=conf, frame_labels=labels, conf_threshold=threshold)
        # MP4 needs ffmpeg; fall back to GIF when it is unavailable or fails.
        path = save_animation(anim, format="mp4") or save_animation(anim, format="gif")
        if not path:
            return None, "❌ Failed to save animation.", ""

        cleaned_text = text.strip().lower()
        result_text, video_html = build_result_with_video_links(cleaned_text, video_lookup, show_videos)
        return path, result_text, video_html

    except Exception as e:
        # UI boundary: report the error to the user instead of crashing the handler.
        import traceback

        print("Error during prediction:", e)
        traceback.print_exc()
        return None, f"❌ Runtime error: {e}", ""
    finally:
        # pyplot keeps every figure alive until it is closed. Without this, a
        # long-running server accumulates one figure per request.
        fig = getattr(anim, "_fig", None)
        if fig is not None:
            plt.close(fig)
|
|
def _youtube_video_id(url):
    """Return the YouTube video id contained in *url*, or None.

    Handles ``youtube.com/watch?v=<id>`` (extra query parameters such as
    ``&t=`` are ignored), ``youtu.be/<id>`` and path-style links such as
    ``/embed/<id>`` or ``/shorts/<id>``. A bare id is returned unchanged.
    Non-string values (e.g. NaN from a CSV) give None.
    """
    from urllib.parse import parse_qs, urlparse

    if not isinstance(url, str) or not url.strip():
        return None
    parsed = urlparse(url.strip())
    if parsed.netloc.lower().endswith("youtu.be"):
        video_id = parsed.path.strip("/").split("/")[0]
    else:
        video_id = parse_qs(parsed.query).get("v", [""])[0]
        if not video_id:
            video_id = parsed.path.rstrip("/").rsplit("/", 1)[-1]
    return video_id or None


def _youtube_iframe(url, width, height):
    """Return an embed ``<iframe>`` for *url*, or "" if no video id is found."""
    import html

    video_id = _youtube_video_id(url)
    if video_id is None:
        return ""
    return (
        f'<iframe width="{width}" height="{height}" '
        f'src="https://www.youtube.com/embed/{html.escape(video_id, quote=True)}" '
        'frameborder="0" allowfullscreen></iframe>'
    )


def build_result_with_video_links(cleaned_text, video_lookup, show_videos=True):
    """Build the Markdown match report and the reference-video HTML.

    The whole-phrase check uses ``annotated_words``, the same test
    concatenate_and_smooth_sequences() uses to decide whether the input is
    generated as one unit. The report therefore describes what was actually
    generated. Videos come from *video_lookup* whenever an entry exists.

    Args:
        cleaned_text: the user input, stripped and lower-cased.
        video_lookup: mapping from normalised text to a YouTube URL.
        show_videos: when False, no iframes are built and "" is returned as HTML.

    Returns:
        ``(markdown_report, html)``: the report lines joined with newlines,
        and the iframes joined with ``<br>`` (or "" when ``show_videos`` is
        False).
    """
    checks = ["**Match Check (Phrase + Word Level):**"]
    html_blocks = []

    def add_video(key, width, height):
        if show_videos and key in video_lookup:
            iframe = _youtube_iframe(video_lookup[key], width, height)
            if iframe:
                html_blocks.append(iframe)

    if cleaned_text in annotated_words:
        checks.append(f'- “{cleaned_text}” ✅ in dataset')
        add_video(cleaned_text, 480, 270)
    else:
        checks.append(f'- “{cleaned_text}” ⚠️ not in dataset → broken into words')
        for word in cleaned_text.split():
            # Two leading spaces nest the item under the phrase bullet in
            # Markdown; a single space would render it as a sibling.
            if word in annotated_words:
                checks.append(f'  - “{word}” ✅ in dataset')
                add_video(word, 240, 135)
            else:
                checks.append(f'  - “{word}” ❌ not found → generated by approximation')

    return "\n".join(checks), ("<br>".join(html_blocks) if show_videos else "")
|
|
|
|
| |
with gr.Blocks() as demo:
    gr.Markdown("# Text-to-Malay Sign Pose Generator")
    gr.Markdown("Generate Malaysian Sign Language (BIM) pose animation from Malay text. Checks which words were seen in training and shows reference YouTube video (if available).")

    with gr.Row():
        # Left column: text/speech input and generation controls.
        with gr.Column(scale=1):
            with gr.Tab("Text Input"):
                text_input = gr.Textbox(label="Enter Malay Word or Sentence")
            with gr.Tab("Speech Input"):
                audio_input = gr.Audio(type="filepath", label="Upload or Record Malay Audio")
                audio_transcript = gr.Textbox(label="Transcribed Text", interactive=True)
                transcribe_btn = gr.Button("Transcribe")

            threshold_slider = gr.Slider(0.0, 1.0, value=0.05, step=0.05, label="Confidence Threshold (for displaying joints)")
            show_video_toggle = gr.Checkbox(label="Show Video Previews", value=True)
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")

        # Right column: generated animation, match report, reference videos.
        with gr.Column(scale=1):
            video_output = gr.Video(label="Generated Pose Animation", height=270)
            text_output = gr.Markdown(label="Match Check (Phrase + Word Level):")
            youtube_output = gr.HTML()

    text_predict_inputs = [text_input, threshold_slider, show_video_toggle]
    speech_predict_inputs = [audio_transcript, threshold_slider, show_video_toggle]
    prediction_outputs = [video_output, text_output, youtube_output]

    submit_btn.click(fn=predict, inputs=text_predict_inputs, outputs=prediction_outputs)
    text_input.submit(fn=predict, inputs=text_predict_inputs, outputs=prediction_outputs)

    # Transcribe, then generate from the transcript. Manual edits to the
    # transcript re-run prediction on Enter (.submit) instead of .change, so
    # typing does not re-run the model and re-encode a video per keystroke.
    transcribe_btn.click(fn=transcribe_audio, inputs=audio_input, outputs=audio_transcript).then(
        fn=predict, inputs=speech_predict_inputs, outputs=prediction_outputs
    )
    audio_transcript.submit(fn=predict, inputs=speech_predict_inputs, outputs=prediction_outputs)

    # A gr.Video output is cleared with None; "" is not a valid video value.
    clear_btn.click(lambda: (None, "", ""),
                    inputs=[],
                    outputs=prediction_outputs)

demo.launch(debug=True)