| |
| """Get exact NeMo streaming inference output for comparison with Swift.""" |
|
|
| import os |
# Must stay above `import torch`: the variable is set before any of the heavy
# numeric libraries below are imported.
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort seen when torch and another dependency each bundle their own libiomp
# -- confirm it is still required on the target machine.
| os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" |
|
|
| import torch |
| import numpy as np |
| import librosa |
| import json |
|
|
| from nemo.collections.asr.models import SortformerEncLabelModel |
|
|
def main():
    """Run NeMo streaming Sortformer inference on one file and dump a reference.

    Steps performed:

    1. Restore the ``.nemo`` checkpoint on CPU, switch it to eval mode and set
       the featurizer's ``dither`` to 0.0 when that attribute exists.
    2. Override the streaming settings on ``model.sortformer_modules``.
    3. Load the audio as 16 kHz mono and compute features with
       ``model.process_signal``.
    4. Walk the features in chunks of ``chunk_len * subsampling_factor``
       frames, adding left/right context clipped at the signal boundaries,
       and feed each chunk to ``model.forward_streaming_step``.
    5. Print a per-frame timeline and a per-speaker activity summary, then
       write the flattened probabilities and the configuration to a JSON file.

    Takes no arguments and returns ``None``; all paths and settings are the
    constants defined at the top of the function body.
    """
    model_path = 'diar_streaming_sortformer_4spk-v2.nemo'
    audio_path = "../audio.wav"
    output_path = "/tmp/nemo_streaming_reference.json"
    sample_rate = 16000
    num_speakers = 4
    frame_sec = 0.08  # the timeline header below states 80ms per frame
    threshold = 0.55

    print("Loading NeMo model...")
    model = SortformerEncLabelModel.restore_from(model_path, map_location='cpu')
    model.eval()

    # Zero the dither if the preprocessor exposes it.
    # NOTE(review): presumably so features are reproducible for the Swift
    # comparison -- confirm.
    featurizer = getattr(model.preprocessor, 'featurizer', None)
    if featurizer is not None and hasattr(featurizer, 'dither'):
        featurizer.dither = 0.0

    # Streaming configuration under test; the same values are echoed into the
    # JSON "config" section at the end.
    modules = model.sortformer_modules
    modules.chunk_len = 6
    modules.chunk_left_context = 1
    modules.chunk_right_context = 7
    modules.fifo_len = 40
    modules.spkcache_len = 188
    modules.spkcache_update_period = 31

    print(f"Config: chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, right_ctx={modules.chunk_right_context}")
    print(f" fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}")

    audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
    print(f"Loaded audio: {len(audio)} samples ({len(audio)/sample_rate:.2f}s)")

    # Add a leading batch dimension of size 1.
    waveform = torch.from_numpy(audio).unsqueeze(0).float()

    with torch.no_grad():
        audio_len = torch.tensor([waveform.shape[1]])
        features, feat_len = model.process_signal(
            audio_signal=waveform, audio_signal_length=audio_len
        )

    # Drop any frames beyond the longest reported feature length.
    features = features[:, :, :feat_len.max()]
    print(f"Features: {features.shape} (batch, mel, time)")

    subsampling = modules.subsampling_factor
    chunk_len = modules.chunk_len
    left_context = modules.chunk_left_context
    right_context = modules.chunk_right_context
    core_frames = chunk_len * subsampling
    left_ctx_frames = left_context * subsampling
    right_ctx_frames = right_context * subsampling

    total_mel_frames = features.shape[2]
    print(f"Total mel frames: {total_mel_frames}")
    print(f"Core frames per chunk: {core_frames}")

    streaming_state = modules.init_streaming_state(device=features.device)

    # Accumulator handed to and returned by every streaming step; starts with
    # zero frames along dim 1.
    total_preds = torch.zeros((1, 0, num_speakers), device=features.device)

    with torch.no_grad():
        stt_feat = 0
        while stt_feat < total_mel_frames:
            end_feat = min(stt_feat + core_frames, total_mel_frames)

            # Context is clipped so the slice never leaves the feature range.
            left_offset = min(left_ctx_frames, stt_feat)
            right_offset = min(right_ctx_frames, total_mel_frames - end_feat)

            chunk = features[:, :, stt_feat - left_offset:end_feat + right_offset]
            chunk_t = chunk.transpose(1, 2)  # swap the last two dims: time before mel
            chunk_len_tensor = torch.tensor(
                [chunk_t.shape[1]], dtype=torch.long, device=features.device
            )

            streaming_state, total_preds = model.forward_streaming_step(
                processed_signal=chunk_t,
                processed_signal_length=chunk_len_tensor,
                streaming_state=streaming_state,
                total_preds=total_preds,
                left_offset=left_offset,
                right_offset=right_offset,
            )

            stt_feat = end_feat

    all_preds = total_preds[0].detach().cpu().numpy()
    num_frames = all_preds.shape[0]
    print(f"\nTotal output frames: {num_frames}")
    print(f"Predictions shape: {all_preds.shape}")

    print("\n=== NeMo Streaming Timeline (80ms per frame, threshold=0.55) ===")
    print("Frame Time Spk0 Spk1 Spk2 Spk3 | Visual")
    print("-" * 60)

    for frame, probs in enumerate(all_preds):
        time_sec = frame * frame_sec
        probs_text = " ".join(f"{p:.3f}" for p in probs)
        marks = "".join('■' if p > threshold else '·' for p in probs)
        print(f"{frame:5d} {time_sec:5.2f}s {probs_text} | [{marks}]")

    print("-" * 60)

    print("\n=== Speaker Activity Summary ===")
    total_time = num_frames * frame_sec
    active_counts = (all_preds > threshold).sum(axis=0)
    for spk in range(all_preds.shape[1]):
        active_time = active_counts[spk] * frame_sec
        # Zero output frames would otherwise yield 0/0 (NaN plus a warning).
        percent = active_time / total_time * 100 if num_frames else 0.0
        print(f"Speaker_{spk}: {active_time:.1f}s active ({percent:.1f}%)")

    output = {
        "total_frames": int(num_frames),
        "frame_duration_seconds": frame_sec,
        # Flattened in C order: all speakers of frame 0, then frame 1, ...
        "probabilities": all_preds.flatten().tolist(),
        "config": {
            "chunk_len": chunk_len,
            "chunk_left_context": left_context,
            "chunk_right_context": right_context,
            "fifo_len": modules.fifo_len,
            "spkcache_len": modules.spkcache_len,
        },
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"\nSaved to {output_path}")
|
|
# Run the reference export only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|