Files changed (3) hide show
  1. README.md +2 -84
  2. modeling_moss_tts.py +103 -0
  3. processing_moss_tts.py +1 -1
README.md CHANGED
@@ -231,27 +231,6 @@ torch.backends.cuda.enable_flash_sdp(True)
231
  torch.backends.cuda.enable_mem_efficient_sdp(True)
232
  torch.backends.cuda.enable_math_sdp(True)
233
 
234
- class DelayGenerationConfig(GenerationConfig):
235
- def __init__(self, **kwargs):
236
- super().__init__(**kwargs)
237
- self.layers = kwargs.get("layers", [{} for _ in range(32)])
238
- self.do_samples = kwargs.get("do_samples", None)
239
- self.n_vq_for_inference = 32
240
-
241
- def initial_config(tokenizer, model_name_or_path):
242
- generation_config = DelayGenerationConfig.from_pretrained(model_name_or_path)
243
- generation_config.pad_token_id = tokenizer.pad_token_id
244
- generation_config.eos_token_id = 151653
245
- generation_config.max_new_tokens = 1000000
246
- generation_config.temperature = 1.0
247
- generation_config.top_p = 0.95
248
- generation_config.top_k = 100
249
- generation_config.repetition_penalty = 1.1
250
- generation_config.use_cache = True
251
- generation_config.do_sample = False
252
- return generation_config
253
-
254
-
255
  pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS-Local-Transformer"
256
  device = "cuda" if torch.cuda.is_available() else "cpu"
257
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
@@ -346,25 +325,6 @@ model = AutoModel.from_pretrained(
346
  ).to(device)
347
  model.eval()
348
 
349
- generation_config = initial_config(processor.tokenizer, pretrained_model_name_or_path)
350
- generation_config.n_vq_for_inference = model.channels - 1
351
- generation_config.do_samples = [True] * model.channels
352
- generation_config.layers = [
353
- {
354
- "repetition_penalty": 1.0,
355
- "temperature": 1.5,
356
- "top_p": 1.0,
357
- "top_k": 50
358
- }
359
- ] + [
360
- {
361
- "repetition_penalty": 1.1,
362
- "temperature": 1.0,
363
- "top_p": 0.95,
364
- "top_k": 50
365
- }
366
- ] * (model.channels - 1)
367
-
368
  batch_size = 1
369
 
370
  save_dir = Path(f"inference_root_moss_tts_local_transformer_generation")
@@ -380,7 +340,7 @@ with torch.no_grad():
380
  outputs = model.generate(
381
  input_ids=input_ids,
382
  attention_mask=attention_mask,
383
- generation_config=generation_config
384
  )
385
 
386
  for message in processor.decode(outputs):
@@ -388,7 +348,6 @@ with torch.no_grad():
388
  out_path = save_dir / f"sample{sample_idx}.wav"
389
  sample_idx += 1
390
  torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
391
-
392
  ```
393
 
394
  ### Continuation + Voice Cloning (Prefix Audio + Text)
@@ -408,27 +367,6 @@ torch.backends.cuda.enable_flash_sdp(True)
408
  torch.backends.cuda.enable_mem_efficient_sdp(True)
409
  torch.backends.cuda.enable_math_sdp(True)
410
 
411
- class DelayGenerationConfig(GenerationConfig):
412
- def __init__(self, **kwargs):
413
- super().__init__(**kwargs)
414
- self.layers = kwargs.get("layers", [{} for _ in range(32)])
415
- self.do_samples = kwargs.get("do_samples", None)
416
- self.n_vq_for_inference = 32
417
-
418
- def initial_config(tokenizer, model_name_or_path):
419
- generation_config = DelayGenerationConfig.from_pretrained(model_name_or_path)
420
- generation_config.pad_token_id = tokenizer.pad_token_id
421
- generation_config.eos_token_id = 151653
422
- generation_config.max_new_tokens = 1000000
423
- generation_config.temperature = 1.0
424
- generation_config.top_p = 0.95
425
- generation_config.top_k = 100
426
- generation_config.repetition_penalty = 1.1
427
- generation_config.use_cache = True
428
- generation_config.do_sample = False
429
- return generation_config
430
-
431
-
432
  pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS-Local-Transformer"
433
  device = "cuda" if torch.cuda.is_available() else "cpu"
434
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
@@ -495,25 +433,6 @@ model = AutoModel.from_pretrained(
495
  ).to(device)
496
  model.eval()
497
 
498
- generation_config = initial_config(processor.tokenizer, pretrained_model_name_or_path)
499
- generation_config.n_vq_for_inference = model.channels - 1
500
- generation_config.do_samples = [True] * model.channels
501
- generation_config.layers = [
502
- {
503
- "repetition_penalty": 1.0,
504
- "temperature": 1.5,
505
- "top_p": 1.0,
506
- "top_k": 50
507
- }
508
- ] + [
509
- {
510
- "repetition_penalty": 1.1,
511
- "temperature": 1.0,
512
- "top_p": 0.95,
513
- "top_k": 50
514
- }
515
- ] * (model.channels - 1)
516
-
517
  batch_size = 1
518
 
519
  save_dir = Path("inference_root_moss_tts_local_transformer_continuation")
@@ -529,7 +448,7 @@ with torch.no_grad():
529
  outputs = model.generate(
530
  input_ids=input_ids,
531
  attention_mask=attention_mask,
532
- generation_config=generation_config
533
  )
534
 
535
  for message in processor.decode(outputs):
@@ -537,7 +456,6 @@ with torch.no_grad():
537
  out_path = save_dir / f"sample{sample_idx}.wav"
538
  sample_idx += 1
539
  torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
540
-
541
  ```
542
 
543
 
 
231
  torch.backends.cuda.enable_mem_efficient_sdp(True)
232
  torch.backends.cuda.enable_math_sdp(True)
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS-Local-Transformer"
235
  device = "cuda" if torch.cuda.is_available() else "cpu"
236
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
 
325
  ).to(device)
326
  model.eval()
327
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  batch_size = 1
329
 
330
  save_dir = Path(f"inference_root_moss_tts_local_transformer_generation")
 
340
  outputs = model.generate(
341
  input_ids=input_ids,
342
  attention_mask=attention_mask,
343
+ max_new_tokens=4096,
344
  )
345
 
346
  for message in processor.decode(outputs):
 
348
  out_path = save_dir / f"sample{sample_idx}.wav"
349
  sample_idx += 1
350
  torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
 
351
  ```
352
 
353
  ### Continuation + Voice Cloning (Prefix Audio + Text)
 
367
  torch.backends.cuda.enable_mem_efficient_sdp(True)
368
  torch.backends.cuda.enable_math_sdp(True)
369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS-Local-Transformer"
371
  device = "cuda" if torch.cuda.is_available() else "cpu"
372
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
 
433
  ).to(device)
434
  model.eval()
435
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  batch_size = 1
437
 
438
  save_dir = Path("inference_root_moss_tts_local_transformer_continuation")
 
448
  outputs = model.generate(
449
  input_ids=input_ids,
450
  attention_mask=attention_mask,
451
+ max_new_tokens=4096,
452
  )
453
 
454
  for message in processor.decode(outputs):
 
456
  out_path = save_dir / f"sample{sample_idx}.wav"
457
  sample_idx += 1
458
  torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
 
459
  ```
460
 
461
 
modeling_moss_tts.py CHANGED
@@ -616,6 +616,109 @@ class MossTTSDelayModel(MosiTTSPretrainedModel, CustomMixin):
616
  def can_generate(self):
617
  return True
618
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619
  # def tie_weights(self):
620
  # ...
621
  # for i in range(self.config.channels):
 
616
  def can_generate(self):
617
  return True
618
 
619
+ def _build_generation_config(
620
+ self,
621
+ generation_config: Optional[GenerationConfig] = None,
622
+ max_new_tokens: Optional[int] = None,
623
+ text_temperature: Optional[float] = None,
624
+ text_top_p: Optional[float] = None,
625
+ text_top_k: Optional[int] = None,
626
+ text_repetition_penalty: Optional[float] = None,
627
+ audio_temperature: Optional[float] = None,
628
+ audio_top_p: Optional[float] = None,
629
+ audio_top_k: Optional[int] = None,
630
+ audio_repetition_penalty: Optional[float] = None,
631
+ n_vq_for_inference: Optional[int] = None,
632
+ ) -> GenerationConfig:
633
+ config = copy.deepcopy(generation_config or self.generation_config)
634
+
635
+ text_temperature = 1.5 if text_temperature is None else float(text_temperature)
636
+ text_top_p = 1.0 if text_top_p is None else float(text_top_p)
637
+ text_top_k = 50 if text_top_k is None else int(text_top_k)
638
+ text_repetition_penalty = 1.0 if text_repetition_penalty is None else float(text_repetition_penalty)
639
+ audio_temperature = 1.0 if audio_temperature is None else float(audio_temperature)
640
+ audio_top_p = 0.95 if audio_top_p is None else float(audio_top_p)
641
+ audio_top_k = 50 if audio_top_k is None else int(audio_top_k)
642
+ audio_repetition_penalty = 1.1 if audio_repetition_penalty is None else float(audio_repetition_penalty)
643
+
644
+ text_do_sample = text_temperature > 0
645
+ if not text_do_sample:
646
+ text_temperature = 1.0
647
+ audio_do_sample = audio_temperature > 0
648
+ if not audio_do_sample:
649
+ audio_temperature = 1.0
650
+
651
+ if max_new_tokens is not None:
652
+ config.max_new_tokens = int(max_new_tokens)
653
+ elif getattr(config, "max_new_tokens", None) is None:
654
+ config.max_new_tokens = 100000 # about 2.2 hours , can be overridden by user input, you can set to a smaller value for faster generation during debugging
655
+
656
+ if getattr(config, "pad_token_id", None) is None:
657
+ config.pad_token_id = self.config.pad_token_id
658
+ config.eos_token_id = self.config.audio_end_token_id
659
+ config.use_cache = True
660
+ config.do_sample = text_do_sample or audio_do_sample
661
+
662
+ resolved_n_vq = self.channels - 1 if n_vq_for_inference is None else int(n_vq_for_inference)
663
+ resolved_n_vq = max(1, min(self.channels - 1, resolved_n_vq))
664
+ config.n_vq_for_inference = resolved_n_vq
665
+ config.do_samples = [text_do_sample] + [audio_do_sample] * (self.channels - 1)
666
+ config.layers = [
667
+ {
668
+ "repetition_penalty": text_repetition_penalty,
669
+ "temperature": text_temperature,
670
+ "top_p": text_top_p,
671
+ "top_k": text_top_k,
672
+ }
673
+ ] + [
674
+ {
675
+ "repetition_penalty": audio_repetition_penalty,
676
+ "temperature": audio_temperature,
677
+ "top_p": audio_top_p,
678
+ "top_k": audio_top_k,
679
+ }
680
+ for _ in range(self.channels - 1)
681
+ ]
682
+ return config
683
+
684
+ @torch.inference_mode()
685
+ def generate(
686
+ self,
687
+ input_ids: torch.LongTensor,
688
+ attention_mask: Optional[torch.Tensor] = None,
689
+ generation_config: Optional[GenerationConfig] = None,
690
+ max_new_tokens: Optional[int] = None,
691
+ text_temperature: Optional[float] = None,
692
+ text_top_p: Optional[float] = None,
693
+ text_top_k: Optional[int] = None,
694
+ text_repetition_penalty: Optional[int] = None,
695
+ audio_temperature: Optional[float] = None,
696
+ audio_top_p: Optional[float] = None,
697
+ audio_top_k: Optional[int] = None,
698
+ audio_repetition_penalty: Optional[float] = None,
699
+ n_vq_for_inference: Optional[int] = None,
700
+ **kwargs,
701
+ ):
702
+ resolved_generation_config = self._build_generation_config(
703
+ generation_config=generation_config,
704
+ max_new_tokens=max_new_tokens,
705
+ text_temperature=text_temperature,
706
+ text_top_p=text_top_p,
707
+ text_top_k=text_top_k,
708
+ text_repetition_penalty=text_repetition_penalty,
709
+ audio_temperature=audio_temperature,
710
+ audio_top_p=audio_top_p,
711
+ audio_top_k=audio_top_k,
712
+ audio_repetition_penalty=audio_repetition_penalty,
713
+ n_vq_for_inference=n_vq_for_inference,
714
+ )
715
+ return super().generate(
716
+ input_ids=input_ids,
717
+ attention_mask=attention_mask,
718
+ generation_config=resolved_generation_config,
719
+ **kwargs,
720
+ )
721
+
722
  # def tie_weights(self):
723
  # ...
724
  # for i in range(self.config.channels):
processing_moss_tts.py CHANGED
@@ -621,7 +621,7 @@ class MossTTSDelayProcessor(ProcessorMixin):
621
  prefix_idx = audio_end_idx
622
 
623
  if truncation:
624
- raise RuntimeError("Truncation generation is not supported at present")
625
  else:
626
  last_audio_end_idx = int(audio_end_indices[-1].item())
627
  pad_codes = torch.full(
 
621
  prefix_idx = audio_end_idx
622
 
623
  if truncation:
624
+ ...
625
  else:
626
  last_audio_end_idx = int(audio_end_indices[-1].item())
627
  pad_codes = torch.full(