Toadoum committed on
Commit
d34b995
Β·
verified Β·
1 Parent(s): be38378

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +448 -0
app.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PlotWeaver Audiobook Generator
3
+ English β†’ Hausa Translation + TTS with Timestamps
4
+
5
+ A POC demonstrating AI-powered audiobook creation for African languages.
6
+ """
7
+
8
+ import gradio as gr
9
+ import torch
10
+ import numpy as np
11
+ import tempfile
12
+ import os
13
+ import re
14
+ import json
15
+ from pathlib import Path
16
+ from datetime import timedelta
17
+ from typing import List, Tuple, Optional
18
+
19
+ # Document processing
20
+ import fitz # PyMuPDF
21
+ from docx import Document
22
+
23
+ # Translation & TTS
24
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, VitsModel
25
+ import scipy.io.wavfile as wavfile
26
+
27
# ============================================
# CONFIGURATION
# ============================================
NLLB_MODEL = "facebook/nllb-200-distilled-600M" # Optimized for speed
TTS_MODEL = "facebook/mms-tts-hau"  # Meta MMS text-to-speech, Hausa voice
SRC_LANG = "eng_Latn"  # NLLB language code: English, Latin script
TGT_LANG = "hau_Latn"  # NLLB language code: Hausa, Latin script
SAMPLE_RATE = 16000  # Hz; assumed to match MMS-TTS output rate — TODO confirm against model config
MAX_CHUNK_LENGTH = 200 # characters per TTS chunk
36
+
37
+ # ============================================
38
+ # MODEL LOADING (Cached)
39
+ # ============================================
40
def load_models():
    """Load the NLLB-200 translation model and the MMS-TTS Hausa model.

    Returns:
        A 4-tuple ``(nllb_model, nllb_tokenizer, tts_model, tts_tokenizer)``.
        Models are moved to the GPU (and fp16 is used for NLLB) when CUDA
        is available; both models are put in eval mode.
    """
    print("πŸ”„ Loading models...")

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    print(f" Device: {device}")

    # Translation model: fp16 on GPU halves memory, fp32 on CPU for accuracy.
    print(" Loading NLLB-200...")
    translator_tokenizer = AutoTokenizer.from_pretrained(NLLB_MODEL, src_lang=SRC_LANG)
    translator = AutoModelForSeq2SeqLM.from_pretrained(
        NLLB_MODEL,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    if use_cuda:
        translator = translator.cuda()
    translator.eval()

    # Text-to-speech model (VITS architecture, Hausa voice).
    print(" Loading MMS-TTS Hausa...")
    speech_model = VitsModel.from_pretrained(TTS_MODEL)
    speech_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL)
    if use_cuda:
        speech_model = speech_model.cuda()
    speech_model.eval()

    print("βœ… Models loaded successfully")
    return translator, translator_tokenizer, speech_model, speech_tokenizer
69
+
70
# Global model loading
# Handles start as None and are populated lazily by initialize_models().
nllb_model, nllb_tokenizer, tts_model, tts_tokenizer = None, None, None, None
72
+
73
def initialize_models():
    """Populate the global model handles on first call; no-op afterwards."""
    global nllb_model, nllb_tokenizer, tts_model, tts_tokenizer
    if nllb_model is not None:
        return
    nllb_model, nllb_tokenizer, tts_model, tts_tokenizer = load_models()
77
+
78
+ # ============================================
79
+ # DOCUMENT EXTRACTION
80
+ # ============================================
81
def extract_text_from_pdf(file_path: str) -> List[dict]:
    """Extract text from a PDF, producing one entry per non-empty page.

    Args:
        file_path: Path to the PDF file.

    Returns:
        A list of ``{"chapter": "Page N", "text": ...}`` dicts; pages with
        no extractable text are skipped.
    """
    doc = fitz.open(file_path)
    pages = [
        {"chapter": f"Page {page_num}", "text": content}
        for page_num, page in enumerate(doc, 1)
        if (content := page.get_text().strip())
    ]
    doc.close()
    return pages
96
+
97
def extract_text_from_docx(file_path: str) -> List[dict]:
    """Extract text from a DOCX file, grouping paragraphs into chapters.

    A paragraph is treated as a chapter heading when it uses a Word
    "Heading" style, or as a fallback heuristic when it is short (< 50
    chars) and entirely upper-case. Body paragraphs are appended to the
    current chapter separated by blank lines.
    """
    document = Document(file_path)
    chapters = []
    chapter_num = 1
    current = {"chapter": "Chapter 1", "text": ""}

    def _is_heading(paragraph, body):
        # Heading style, or a short all-caps line, marks a chapter break.
        return paragraph.style.name.startswith('Heading') or (len(body) < 50 and body.isupper())

    for paragraph in document.paragraphs:
        body = paragraph.text.strip()
        if not body:
            continue

        if _is_heading(paragraph, body):
            # Flush the chapter collected so far before starting a new one.
            if current["text"]:
                chapters.append(current)
                chapter_num += 1
            current = {"chapter": body or f"Chapter {chapter_num}", "text": ""}
        else:
            current["text"] += body + "\n\n"

    if current["text"]:
        chapters.append(current)

    return chapters
122
+
123
def extract_text(file_path: str) -> List[dict]:
    """Dispatch an uploaded file to the right extractor by extension.

    Args:
        file_path: Path to a .pdf, .docx/.doc, or .txt file.

    Returns:
        A list of ``{"chapter": ..., "text": ...}`` dicts.

    Raises:
        ValueError: If the file extension is not supported.
    """
    suffix = Path(file_path).suffix.lower()

    if suffix == ".pdf":
        return extract_text_from_pdf(file_path)
    if suffix in (".docx", ".doc"):
        return extract_text_from_docx(file_path)
    if suffix == ".txt":
        content = Path(file_path).read_text(encoding="utf-8")
        return [{"chapter": "Full Text", "text": content}]
    raise ValueError(f"Unsupported file format: {suffix}")
137
+
138
+ # ============================================
139
+ # TRANSLATION (NLLB-200)
140
+ # ============================================
141
def translate_text(text: str) -> str:
    """Translate English text to Hausa using NLLB-200.

    The input is split into sentences on ., !, ? boundaries; each sentence
    is translated independently with beam search and the translations are
    re-joined with single spaces.

    Args:
        text: English source text.

    Returns:
        The Hausa translation as a single string (empty if *text* has no
        non-blank sentences).
    """
    initialize_models()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Split into sentences for better translation
    sentences = re.split(r'(?<=[.!?])\s+', text)
    translated_sentences = []

    # Get target language token
    # (NLLB steers decoding by forcing a language-code token like "hau_Latn"
    # as the first generated token.)
    tgt_lang_id = nllb_tokenizer.convert_tokens_to_ids(TGT_LANG)

    with torch.no_grad():
        for sentence in sentences:
            if not sentence.strip():
                continue

            # Tokenize
            inputs = nllb_tokenizer(
                sentence,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True
            )

            if device == "cuda":
                inputs = {k: v.cuda() for k, v in inputs.items()}

            # Translate
            # NOTE(review): max_length=256 can truncate the translation of a
            # long sentence even though inputs allow 512 tokens — confirm.
            outputs = nllb_model.generate(
                **inputs,
                forced_bos_token_id=tgt_lang_id,
                max_length=256,
                num_beams=5,
                early_stopping=True
            )

            # Decode
            translated = nllb_tokenizer.decode(outputs[0], skip_special_tokens=True)
            translated_sentences.append(translated)

    return " ".join(translated_sentences)
185
+
186
+ # ============================================
187
+ # TEXT-TO-SPEECH (MMS-TTS)
188
+ # ============================================
189
def split_text_for_tts(text: str, max_length: int = MAX_CHUNK_LENGTH) -> List[str]:
    """Split text into chunks of at most *max_length* characters for TTS.

    Sentences (split on ., !, ? boundaries) are packed greedily into chunks.
    Fixes a bug in the previous version where a single sentence longer than
    *max_length* was passed through whole, producing oversized TTS inputs:
    such sentences are now hard-split on word boundaries. Blank sentences
    are skipped, so empty input yields an empty list.

    Args:
        text: The text to split.
        max_length: Maximum characters per chunk.

    Returns:
        A list of non-empty chunks, each at most *max_length* characters
        (except a single word longer than max_length, which cannot be
        split further).
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)

    # Pre-pass: hard-split any sentence that alone exceeds the budget.
    pieces = []
    for sentence in sentences:
        if not sentence.strip():
            continue
        if len(sentence) <= max_length:
            pieces.append(sentence)
            continue
        part = ""
        for word in sentence.split():
            if part and len(part) + 1 + len(word) > max_length:
                pieces.append(part)
                part = word
            else:
                part = f"{part} {word}" if part else word
        if part:
            pieces.append(part)

    # Greedily pack pieces into chunks (same packing as before).
    chunks = []
    current_chunk = ""
    for piece in pieces:
        if len(current_chunk) + len(piece) <= max_length:
            current_chunk += piece + " "
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = piece + " "

    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks
208
+
209
def generate_audio(text: str) -> Tuple[np.ndarray, List[dict]]:
    """Synthesize Hausa speech for *text*.

    The text is chunked via split_text_for_tts(), each chunk is synthesized
    with the MMS-TTS model, and the per-chunk start/end offsets within the
    concatenated waveform are recorded.

    Returns:
        ``(waveform, timestamps)`` where *waveform* is a 1-D float array and
        *timestamps* is a list of ``{"start", "end", "text"}`` dicts with
        HH:MM:SS.mmm strings.
    """
    initialize_models()

    on_gpu = torch.cuda.is_available()
    segments = []
    markers = []
    elapsed = 0.0

    for piece in split_text_for_tts(text):
        if not piece.strip():
            continue

        # Tokenize this chunk; move tensors to the GPU when available.
        encoded = tts_tokenizer(piece, return_tensors="pt")
        if on_gpu:
            encoded = {k: v.cuda() for k, v in encoded.items()}

        # Inference only — no gradients needed.
        with torch.no_grad():
            waveform = tts_model(**encoded).waveform

        samples = waveform.squeeze().cpu().numpy()
        segments.append(samples)

        # Record where this chunk sits inside the final audio stream.
        seconds = len(samples) / SAMPLE_RATE
        markers.append({
            "start": format_timestamp(elapsed),
            "end": format_timestamp(elapsed + seconds),
            "text": piece
        })
        elapsed += seconds

    # Fall back to one second of silence when nothing was synthesized.
    if segments:
        audio = np.concatenate(segments)
    else:
        audio = np.zeros(SAMPLE_RATE)

    return audio, markers
252
+
253
def format_timestamp(seconds: float) -> str:
    """Format a duration in seconds as HH:MM:SS.mmm.

    Fixes a bug in the previous timedelta-based version: ``timedelta.seconds``
    silently discards whole days, so durations of 24h or more wrapped the
    hour field back to zero. Hours now grow without wrapping.

    Args:
        seconds: Non-negative duration in seconds.

    Returns:
        The duration as a zero-padded "HH:MM:SS.mmm" string.
    """
    # Round to whole microseconds first (matching timedelta's behavior),
    # then truncate to milliseconds.
    total_ms = round(seconds * 1_000_000) // 1000
    hours, rem_ms = divmod(total_ms, 3_600_000)
    minutes, rem_ms = divmod(rem_ms, 60_000)
    secs, milliseconds = divmod(rem_ms, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{milliseconds:03d}"
260
+
261
+ # ============================================
262
+ # MAIN PIPELINE
263
+ # ============================================
264
def process_document(file, progress=gr.Progress()) -> Tuple[Optional[str], str, str, str]:
    """Run the full pipeline: document -> English text -> Hausa -> audiobook.

    Args:
        file: The uploaded file; either a plain path string (what
            ``gr.File(type="filepath")`` yields in Gradio 4.x) or a
            tempfile-like object with a ``.name`` attribute.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(audio_path, transcript_markdown, timestamps_text, status)``;
        *audio_path* is None when processing fails or no input was given.
    """
    if file is None:
        return None, "", "", "⚠️ Please upload a document"

    try:
        # BUGFIX: gr.File(type="filepath") returns a str, on which `.name`
        # raised AttributeError; accept both the string and object forms.
        file_path = file if isinstance(file, str) else file.name

        progress(0.1, desc="πŸ“„ Extracting text...")
        chapters = extract_text(file_path)

        if not chapters:
            return None, "", "", "⚠️ No text found in document"

        # Combine all text (for POC, limit to first 2000 chars)
        full_text = "\n\n".join([c["text"] for c in chapters])[:2000]

        progress(0.3, desc="🌍 Translating to Hausa...")
        translated_text = translate_text(full_text)

        progress(0.6, desc="πŸŽ™οΈ Generating audio...")
        audio, timestamps = generate_audio(translated_text)

        progress(0.9, desc="πŸ’Ύ Saving audiobook...")

        # Clip to [-1, 1] before int16 conversion to avoid integer
        # wrap-around distortion on samples outside the nominal range.
        pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            wavfile.write(f.name, SAMPLE_RATE, pcm)
            audio_path = f.name

        # Human-readable "[start -> end] text" lines for the timestamps tab.
        timestamps_text = "\n".join([
            f"[{t['start']} β†’ {t['end']}] {t['text']}"
            for t in timestamps
        ])

        # Side-by-side transcript: truncated English source + full Hausa.
        transcript = f"""## Original (English)
{full_text[:500]}{'...' if len(full_text) > 500 else ''}

## Translation (Hausa)
{translated_text}
"""

        progress(1.0, desc="βœ… Complete!")

        return audio_path, transcript, timestamps_text, "βœ… Audiobook generated successfully!"

    except Exception as e:
        # Surface the failure in the UI status box rather than crashing.
        return None, "", "", f"❌ Error: {str(e)}"
316
+
317
+ # ============================================
318
+ # GRADIO INTERFACE
319
+ # ============================================
320
def create_interface():
    """Build and return the Gradio Blocks UI for the audiobook generator."""

    with gr.Blocks(
        title="PlotWeaver Audiobook Generator",
        theme=gr.themes.Soft(
            primary_hue="orange",
            secondary_hue="blue",
        ),
        # Lightweight styling for the header/subtitle/output panels.
        css="""
        .main-title {
            text-align: center;
            margin-bottom: 1rem;
        }
        .subtitle {
            text-align: center;
            color: #666;
            margin-bottom: 2rem;
        }
        .output-panel {
            border: 1px solid #ddd;
            border-radius: 8px;
            padding: 1rem;
        }
        """
    ) as demo:

        # Header
        gr.HTML("""
        <div class="main-title">
            <h1>🎧 PlotWeaver Audiobook Generator</h1>
        </div>
        <div class="subtitle">
            <p><strong>Transform English documents into Hausa audiobooks with timestamps</strong></p>
            <p>Powered by NLLB-200 Translation + MMS-TTS</p>
        </div>
        """)

        with gr.Row():
            # Input Column: upload control, trigger button, status line.
            with gr.Column(scale=1):
                gr.Markdown("### πŸ“ Upload Document")

                file_input = gr.File(
                    label="Upload PDF, DOCX, or TXT",
                    file_types=[".pdf", ".docx", ".doc", ".txt"],
                    type="filepath"
                )

                generate_btn = gr.Button(
                    "πŸš€ Generate Audiobook",
                    variant="primary",
                    size="lg"
                )

                # Read-only status line, filled by process_document().
                status_output = gr.Textbox(
                    label="Status",
                    interactive=False,
                    lines=1
                )

                gr.Markdown("""
                ---
                ### ℹ️ How it works
                1. **Upload** your English document
                2. **AI translates** to Hausa using NLLB-200
                3. **TTS generates** natural Hausa audio
                4. **Download** your audiobook with timestamps

                ---
                ### 🌍 Supported Languages
                - πŸ‡¬πŸ‡§ English β†’ πŸ‡³πŸ‡¬ Hausa
                - *More languages coming soon!*
                """)

            # Output Column: audio player plus transcript/timestamps tabs.
            with gr.Column(scale=2):
                gr.Markdown("### 🎧 Generated Audiobook")

                audio_output = gr.Audio(
                    label="Hausa Audiobook",
                    type="filepath",
                    interactive=False
                )

                with gr.Tabs():
                    with gr.Tab("πŸ“œ Transcript"):
                        transcript_output = gr.Markdown(
                            label="Translation",
                            value="*Upload a document to see the transcript*"
                        )

                    with gr.Tab("⏱️ Timestamps"):
                        timestamps_output = gr.Textbox(
                            label="Timestamps",
                            lines=10,
                            interactive=False,
                            placeholder="Timestamps will appear here..."
                        )

        # Footer
        gr.HTML("""
        <div style="text-align: center; margin-top: 2rem; padding: 1rem; background: #f8f9fa; border-radius: 8px;">
            <p><strong>PlotWeaver</strong> - AI-Powered African Language Technology</p>
            <p style="color: #666; font-size: 0.9rem;">
                Democratizing content access across Africa through voice technology
            </p>
        </div>
        """)

        # Event handlers
        # The button drives the whole pipeline and fills all four outputs.
        generate_btn.click(
            fn=process_document,
            inputs=[file_input],
            outputs=[audio_output, transcript_output, timestamps_output, status_output],
            show_progress=True
        )

    return demo
438
+
439
+ # ============================================
440
+ # MAIN
441
+ # ============================================
442
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces; port 7860 is the
    # Hugging Face Spaces convention.
    demo = create_interface()
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860
    )
+ )