Dua Rajper committed on
Commit
fd320b1
·
verified ·
1 Parent(s): 8bad5cf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -0
app.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from groq import Groq
4
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, pipeline
5
+ from espnet2.bin.tts_inference import Text2Speech
6
+ import soundfile as sf
7
+ from pydub import AudioSegment
8
+ import io
9
+ from streamlit_webrtc import webrtc_streamer, WebRtcMode, AudioProcessorBase
10
+ import av
11
+ import numpy as np
12
+
13
# Read the Groq API key from the environment (injected as a Space secret);
# abort the app early with a visible error if it is missing.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("Groq API key not found. Please add it as a secret.")
    st.stop()

# Client used for all chat completions below.
groq_client = Groq(api_key=GROQ_API_KEY)
21
+
22
@st.cache_resource  # cache heavyweight model objects across reruns
def load_models():
    """Load and cache the speech-to-text pipeline and the text-to-speech model.

    Returns:
        tuple: (ASR pipeline built on Whisper-small, ESPnet Text2Speech model).
    """
    # Speech-to-text: Whisper-small wrapped in a HF ASR pipeline.
    whisper_processor = AutoProcessor.from_pretrained("openai/whisper-small")
    whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small")
    asr = pipeline(
        task="automatic-speech-recognition",
        model=whisper_model,
        tokenizer=whisper_processor.tokenizer,
        feature_extractor=whisper_processor.feature_extractor,
        return_timestamps=True,  # required for long-form (>30 s) audio
    )

    # Text-to-speech (ESPnet pretrained model).
    tts = Text2Speech.from_pretrained("espnet/espnet_tts_vctk_espnet_spk_voxceleb12_rawnet")

    return asr, tts


stt_pipe, tts_model = load_models()
43
class AudioRecorder(AudioProcessorBase):
    """Collects incoming WebRTC audio frames as numpy arrays for later processing."""

    def __init__(self):
        # One ndarray per received frame, in arrival order.
        self.audio_frames = []

    def recv(self, frame: av.AudioFrame) -> av.AudioFrame:
        """Store the frame's samples and pass the frame through unchanged."""
        samples = frame.to_ndarray()
        self.audio_frames.append(samples)
        return frame
51
+
52
# Streamlit app
st.title("Voice and Text Chatbot")

# Sidebar for mode selection
mode = st.sidebar.radio("Select Mode", ["Text Chatbot", "Voice Chatbot"])


def _chat_response(prompt):
    """Send *prompt* to the Groq chat API and return the assistant's reply text."""
    # NOTE(review): "mixtral-8x7b-32768" has been decommissioned by Groq —
    # confirm this model id is still served, otherwise swap in a current one.
    chat_completion = groq_client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="mixtral-8x7b-32768",
        temperature=0.5,
        max_tokens=1024,
    )
    return chat_completion.choices[0].message.content


def _speak(text):
    """Synthesize *text* with the ESPnet TTS model, save it, and play it in the UI.

    Bug fix: the original code did ``speech, *_ = tts_model(text,
    spembs=tts_model.spembs[0])``.  ``Text2Speech`` has no ``spembs`` attribute,
    and its ``__call__`` returns a dict, so tuple-unpacking bound a dict *key*
    (a string) to ``speech`` — which ``sf.write`` cannot serialize.  The
    waveform is under the ``"wav"`` key of the returned dict.
    """
    output = tts_model(text)
    wav = output["wav"]
    # Torch tensor -> numpy for soundfile.
    if hasattr(wav, "cpu"):
        wav = wav.cpu().numpy()
    # Use the model's native sampling rate instead of a hard-coded 22050 Hz.
    sample_rate = getattr(tts_model, "fs", None) or 22050
    sf.write("response.wav", wav, sample_rate)
    st.audio("response.wav")


def _respond(user_text):
    """Generate a chat reply for *user_text*, display it, and speak it aloud.

    Shared by the text and voice branches (the original duplicated this logic
    verbatim in both).
    """
    try:
        response = _chat_response(user_text)
        st.write("Generated Response:", response)
        _speak(response)
    except Exception as e:
        # Surface API/TTS failures in the UI instead of crashing the app.
        st.error(f"Error generating response: {e}")


if mode == "Text Chatbot":
    st.header("Text Chatbot")
    user_input = st.text_input("Enter your message:")

    if user_input:
        _respond(user_input)

elif mode == "Voice Chatbot":
    st.header("Voice Chatbot")

    # Browser-side audio capture via WebRTC; frames accumulate in AudioRecorder.
    st.write("Record your voice:")
    webrtc_ctx = webrtc_streamer(
        key="audio-recorder",
        mode=WebRtcMode.SENDONLY,
        audio_processor_factory=AudioRecorder,
        media_stream_constraints={"audio": True, "video": False},
    )

    if webrtc_ctx.audio_processor:
        st.write("Recording... Press 'Stop' to finish recording.")

        if st.button("Stop and Process Recording"):
            audio_frames = webrtc_ctx.audio_processor.audio_frames
            if audio_frames:
                # Combine the per-frame arrays into one buffer.
                # NOTE(review): av audio frames are typically shaped
                # (channels, samples) and WebRTC audio usually arrives at
                # 48 kHz, not 16 kHz — confirm the layout and sample rate;
                # a mismatch here degrades transcription quality.
                audio_data = np.concatenate(audio_frames)
                sf.write("recorded_audio.wav", audio_data, samplerate=16000)
                st.success("Recording saved as recorded_audio.wav")

                # Transcribe the recording (timestamps enabled in the pipeline).
                speech, _ = sf.read("recorded_audio.wav")
                output = stt_pipe(speech)

                # Display the full transcribed text.
                st.write("Transcribed Text:", output['text'])

                # Optional per-chunk timestamps.
                if 'chunks' in output:
                    st.write("Transcribed Text with Timestamps:")
                    for chunk in output['chunks']:
                        st.write(f"{chunk['timestamp'][0]:.2f} - {chunk['timestamp'][1]:.2f}: {chunk['text']}")

                # Reply to the transcript exactly as the text branch would.
                _respond(output['text'])
            else:
                st.error("No audio recorded. Please try again.")