rahuldshetty commited on
Commit
9c1b10f
·
verified ·
1 Parent(s): a8414bc

Upload inference_trained_model.py

Browse files
Files changed (1) hide show
  1. inference_trained_model.py +5 -5
inference_trained_model.py CHANGED
@@ -203,22 +203,22 @@ def generate(model, tokenizer, midi_tokenizer, metadata, prompt, max_midi_tokens
203
  def save_midi(midi_score, output_path: str):
204
  """Save a MidiTok decoded score to a MIDI file.
205
 
206
- MidiTok v3 returns a symusic Score object which uses .write_midi()
207
  instead of the old miditoolkit MidiFile.dump() method.
208
  """
209
  out = Path(output_path)
210
  out.parent.mkdir(parents=True, exist_ok=True)
211
 
212
- # MidiTok v3 (symusic backend) uses write_midi()
213
- if hasattr(midi_score, "write_midi"):
214
- midi_score.write_midi(str(out))
215
  # Fallback for older miditoolkit MidiFile objects
216
  elif hasattr(midi_score, "dump"):
217
  midi_score.dump(str(out))
218
  else:
219
  raise AttributeError(
220
  f"Cannot save MIDI object of type {type(midi_score)}. "
221
- "Expected ScoreTick with .write_midi() or MidiFile with .dump()"
222
  )
223
 
224
  size = out.stat().st_size
 
203
  def save_midi(midi_score, output_path: str):
204
  """Save a MidiTok decoded score to a MIDI file.
205
 
206
+ MidiTok v3 returns a symusic ScoreTick object which uses dump_midi()
207
  instead of the old miditoolkit MidiFile.dump() method.
208
  """
209
  out = Path(output_path)
210
  out.parent.mkdir(parents=True, exist_ok=True)
211
 
212
+ # symusic ScoreTick (MidiTok v3) uses dump_midi()
213
+ if hasattr(midi_score, "dump_midi"):
214
+ midi_score.dump_midi(str(out))
215
  # Fallback for older miditoolkit MidiFile objects
216
  elif hasattr(midi_score, "dump"):
217
  midi_score.dump(str(out))
218
  else:
219
  raise AttributeError(
220
  f"Cannot save MIDI object of type {type(midi_score)}. "
221
+ f"Available attrs: {[a for a in dir(midi_score) if not a.startswith('_')]}"
222
  )
223
 
224
  size = out.stat().st_size