fix: resolve the bug in the continuation mode of moss_tts_local; update the README of moss_tts_local.

#4
Files changed (3) hide show
  1. README.md +84 -2
  2. modeling_moss_tts.py +0 -103
  3. processing_moss_tts.py +1 -1
README.md CHANGED
@@ -231,6 +231,27 @@ torch.backends.cuda.enable_flash_sdp(True)
231
  torch.backends.cuda.enable_mem_efficient_sdp(True)
232
  torch.backends.cuda.enable_math_sdp(True)
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS-Local-Transformer"
235
  device = "cuda" if torch.cuda.is_available() else "cpu"
236
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
@@ -325,6 +346,25 @@ model = AutoModel.from_pretrained(
325
  ).to(device)
326
  model.eval()
327
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  batch_size = 1
329
 
330
  save_dir = Path(f"inference_root_moss_tts_local_transformer_generation")
@@ -340,7 +380,7 @@ with torch.no_grad():
340
  outputs = model.generate(
341
  input_ids=input_ids,
342
  attention_mask=attention_mask,
343
- max_new_tokens=4096,
344
  )
345
 
346
  for message in processor.decode(outputs):
@@ -348,6 +388,7 @@ with torch.no_grad():
348
  out_path = save_dir / f"sample{sample_idx}.wav"
349
  sample_idx += 1
350
  torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
 
351
  ```
352
 
353
  ### Continuation + Voice Cloning (Prefix Audio + Text)
@@ -367,6 +408,27 @@ torch.backends.cuda.enable_flash_sdp(True)
367
  torch.backends.cuda.enable_mem_efficient_sdp(True)
368
  torch.backends.cuda.enable_math_sdp(True)
369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS-Local-Transformer"
371
  device = "cuda" if torch.cuda.is_available() else "cpu"
372
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
@@ -433,6 +495,25 @@ model = AutoModel.from_pretrained(
433
  ).to(device)
434
  model.eval()
435
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  batch_size = 1
437
 
438
  save_dir = Path("inference_root_moss_tts_local_transformer_continuation")
@@ -448,7 +529,7 @@ with torch.no_grad():
448
  outputs = model.generate(
449
  input_ids=input_ids,
450
  attention_mask=attention_mask,
451
- max_new_tokens=4096,
452
  )
453
 
454
  for message in processor.decode(outputs):
@@ -456,6 +537,7 @@ with torch.no_grad():
456
  out_path = save_dir / f"sample{sample_idx}.wav"
457
  sample_idx += 1
458
  torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
 
459
  ```
460
 
461
 
 
231
  torch.backends.cuda.enable_mem_efficient_sdp(True)
232
  torch.backends.cuda.enable_math_sdp(True)
233
 
234
+ class DelayGenerationConfig(GenerationConfig):
235
+ def __init__(self, **kwargs):
236
+ super().__init__(**kwargs)
237
+ self.layers = kwargs.get("layers", [{} for _ in range(32)])
238
+ self.do_samples = kwargs.get("do_samples", None)
239
+ self.n_vq_for_inference = 32
240
+
241
+ def initial_config(tokenizer, model_name_or_path):
242
+ generation_config = DelayGenerationConfig.from_pretrained(model_name_or_path)
243
+ generation_config.pad_token_id = tokenizer.pad_token_id
244
+ generation_config.eos_token_id = 151653
245
+ generation_config.max_new_tokens = 1000000
246
+ generation_config.temperature = 1.0
247
+ generation_config.top_p = 0.95
248
+ generation_config.top_k = 100
249
+ generation_config.repetition_penalty = 1.1
250
+ generation_config.use_cache = True
251
+ generation_config.do_sample = False
252
+ return generation_config
253
+
254
+
255
  pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS-Local-Transformer"
256
  device = "cuda" if torch.cuda.is_available() else "cpu"
257
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
 
346
  ).to(device)
347
  model.eval()
348
 
349
+ generation_config = initial_config(processor.tokenizer, pretrained_model_name_or_path)
350
+ generation_config.n_vq_for_inference = model.channels - 1
351
+ generation_config.do_samples = [True] * model.channels
352
+ generation_config.layers = [
353
+ {
354
+ "repetition_penalty": 1.0,
355
+ "temperature": 1.5,
356
+ "top_p": 1.0,
357
+ "top_k": 50
358
+ }
359
+ ] + [
360
+ {
361
+ "repetition_penalty": 1.1,
362
+ "temperature": 1.0,
363
+ "top_p": 0.95,
364
+ "top_k": 50
365
+ }
366
+ ] * (model.channels - 1)
367
+
368
  batch_size = 1
369
 
370
  save_dir = Path(f"inference_root_moss_tts_local_transformer_generation")
 
380
  outputs = model.generate(
381
  input_ids=input_ids,
382
  attention_mask=attention_mask,
383
+ generation_config=generation_config
384
  )
385
 
386
  for message in processor.decode(outputs):
 
388
  out_path = save_dir / f"sample{sample_idx}.wav"
389
  sample_idx += 1
390
  torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
391
+
392
  ```
393
 
394
  ### Continuation + Voice Cloning (Prefix Audio + Text)
 
408
  torch.backends.cuda.enable_mem_efficient_sdp(True)
409
  torch.backends.cuda.enable_math_sdp(True)
410
 
411
+ class DelayGenerationConfig(GenerationConfig):
412
+ def __init__(self, **kwargs):
413
+ super().__init__(**kwargs)
414
+ self.layers = kwargs.get("layers", [{} for _ in range(32)])
415
+ self.do_samples = kwargs.get("do_samples", None)
416
+ self.n_vq_for_inference = 32
417
+
418
+ def initial_config(tokenizer, model_name_or_path):
419
+ generation_config = DelayGenerationConfig.from_pretrained(model_name_or_path)
420
+ generation_config.pad_token_id = tokenizer.pad_token_id
421
+ generation_config.eos_token_id = 151653
422
+ generation_config.max_new_tokens = 1000000
423
+ generation_config.temperature = 1.0
424
+ generation_config.top_p = 0.95
425
+ generation_config.top_k = 100
426
+ generation_config.repetition_penalty = 1.1
427
+ generation_config.use_cache = True
428
+ generation_config.do_sample = False
429
+ return generation_config
430
+
431
+
432
  pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS-Local-Transformer"
433
  device = "cuda" if torch.cuda.is_available() else "cpu"
434
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
 
495
  ).to(device)
496
  model.eval()
497
 
498
+ generation_config = initial_config(processor.tokenizer, pretrained_model_name_or_path)
499
+ generation_config.n_vq_for_inference = model.channels - 1
500
+ generation_config.do_samples = [True] * model.channels
501
+ generation_config.layers = [
502
+ {
503
+ "repetition_penalty": 1.0,
504
+ "temperature": 1.5,
505
+ "top_p": 1.0,
506
+ "top_k": 50
507
+ }
508
+ ] + [
509
+ {
510
+ "repetition_penalty": 1.1,
511
+ "temperature": 1.0,
512
+ "top_p": 0.95,
513
+ "top_k": 50
514
+ }
515
+ ] * (model.channels - 1)
516
+
517
  batch_size = 1
518
 
519
  save_dir = Path("inference_root_moss_tts_local_transformer_continuation")
 
529
  outputs = model.generate(
530
  input_ids=input_ids,
531
  attention_mask=attention_mask,
532
+ generation_config=generation_config
533
  )
534
 
535
  for message in processor.decode(outputs):
 
537
  out_path = save_dir / f"sample{sample_idx}.wav"
538
  sample_idx += 1
539
  torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
540
+
541
  ```
542
 
543
 
modeling_moss_tts.py CHANGED
@@ -616,109 +616,6 @@ class MossTTSDelayModel(MosiTTSPretrainedModel, CustomMixin):
616
  def can_generate(self):
617
  return True
618
 
619
- def _build_generation_config(
620
- self,
621
- generation_config: Optional[GenerationConfig] = None,
622
- max_new_tokens: Optional[int] = None,
623
- text_temperature: Optional[float] = None,
624
- text_top_p: Optional[float] = None,
625
- text_top_k: Optional[int] = None,
626
- text_repetition_penalty: Optional[float] = None,
627
- audio_temperature: Optional[float] = None,
628
- audio_top_p: Optional[float] = None,
629
- audio_top_k: Optional[int] = None,
630
- audio_repetition_penalty: Optional[float] = None,
631
- n_vq_for_inference: Optional[int] = None,
632
- ) -> GenerationConfig:
633
- config = copy.deepcopy(generation_config or self.generation_config)
634
-
635
- text_temperature = 1.5 if text_temperature is None else float(text_temperature)
636
- text_top_p = 1.0 if text_top_p is None else float(text_top_p)
637
- text_top_k = 50 if text_top_k is None else int(text_top_k)
638
- text_repetition_penalty = 1.0 if text_repetition_penalty is None else float(text_repetition_penalty)
639
- audio_temperature = 1.0 if audio_temperature is None else float(audio_temperature)
640
- audio_top_p = 0.95 if audio_top_p is None else float(audio_top_p)
641
- audio_top_k = 50 if audio_top_k is None else int(audio_top_k)
642
- audio_repetition_penalty = 1.1 if audio_repetition_penalty is None else float(audio_repetition_penalty)
643
-
644
- text_do_sample = text_temperature > 0
645
- if not text_do_sample:
646
- text_temperature = 1.0
647
- audio_do_sample = audio_temperature > 0
648
- if not audio_do_sample:
649
- audio_temperature = 1.0
650
-
651
- if max_new_tokens is not None:
652
- config.max_new_tokens = int(max_new_tokens)
653
- elif getattr(config, "max_new_tokens", None) is None:
654
- config.max_new_tokens = 100000 # about 2.2 hours , can be overridden by user input, you can set to a smaller value for faster generation during debugging
655
-
656
- if getattr(config, "pad_token_id", None) is None:
657
- config.pad_token_id = self.config.pad_token_id
658
- config.eos_token_id = self.config.audio_end_token_id
659
- config.use_cache = True
660
- config.do_sample = text_do_sample or audio_do_sample
661
-
662
- resolved_n_vq = self.channels - 1 if n_vq_for_inference is None else int(n_vq_for_inference)
663
- resolved_n_vq = max(1, min(self.channels - 1, resolved_n_vq))
664
- config.n_vq_for_inference = resolved_n_vq
665
- config.do_samples = [text_do_sample] + [audio_do_sample] * (self.channels - 1)
666
- config.layers = [
667
- {
668
- "repetition_penalty": text_repetition_penalty,
669
- "temperature": text_temperature,
670
- "top_p": text_top_p,
671
- "top_k": text_top_k,
672
- }
673
- ] + [
674
- {
675
- "repetition_penalty": audio_repetition_penalty,
676
- "temperature": audio_temperature,
677
- "top_p": audio_top_p,
678
- "top_k": audio_top_k,
679
- }
680
- for _ in range(self.channels - 1)
681
- ]
682
- return config
683
-
684
- @torch.inference_mode()
685
- def generate(
686
- self,
687
- input_ids: torch.LongTensor,
688
- attention_mask: Optional[torch.Tensor] = None,
689
- generation_config: Optional[GenerationConfig] = None,
690
- max_new_tokens: Optional[int] = None,
691
- text_temperature: Optional[float] = None,
692
- text_top_p: Optional[float] = None,
693
- text_top_k: Optional[int] = None,
694
- text_repetition_penalty: Optional[int] = None,
695
- audio_temperature: Optional[float] = None,
696
- audio_top_p: Optional[float] = None,
697
- audio_top_k: Optional[int] = None,
698
- audio_repetition_penalty: Optional[float] = None,
699
- n_vq_for_inference: Optional[int] = None,
700
- **kwargs,
701
- ):
702
- resolved_generation_config = self._build_generation_config(
703
- generation_config=generation_config,
704
- max_new_tokens=max_new_tokens,
705
- text_temperature=text_temperature,
706
- text_top_p=text_top_p,
707
- text_top_k=text_top_k,
708
- text_repetition_penalty=text_repetition_penalty,
709
- audio_temperature=audio_temperature,
710
- audio_top_p=audio_top_p,
711
- audio_top_k=audio_top_k,
712
- audio_repetition_penalty=audio_repetition_penalty,
713
- n_vq_for_inference=n_vq_for_inference,
714
- )
715
- return super().generate(
716
- input_ids=input_ids,
717
- attention_mask=attention_mask,
718
- generation_config=resolved_generation_config,
719
- **kwargs,
720
- )
721
-
722
  # def tie_weights(self):
723
  # ...
724
  # for i in range(self.config.channels):
 
616
  def can_generate(self):
617
  return True
618
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619
  # def tie_weights(self):
620
  # ...
621
  # for i in range(self.config.channels):
processing_moss_tts.py CHANGED
@@ -621,7 +621,7 @@ class MossTTSDelayProcessor(ProcessorMixin):
621
  prefix_idx = audio_end_idx
622
 
623
  if truncation:
624
- ...
625
  else:
626
  last_audio_end_idx = int(audio_end_indices[-1].item())
627
  pad_codes = torch.full(
 
621
  prefix_idx = audio_end_idx
622
 
623
  if truncation:
624
+ raise RuntimeError("Truncation generation is not supported at present")
625
  else:
626
  last_audio_end_idx = int(audio_end_indices[-1].item())
627
  pad_codes = torch.full(