Upload inference_trained_model.py
Browse files
inference_trained_model.py
CHANGED
|
@@ -203,22 +203,22 @@ def generate(model, tokenizer, midi_tokenizer, metadata, prompt, max_midi_tokens
|
|
| 203 |
def save_midi(midi_score, output_path: str):
|
| 204 |
"""Save a MidiTok decoded score to a MIDI file.
|
| 205 |
|
| 206 |
-
MidiTok v3 returns a symusic
|
| 207 |
instead of the old miditoolkit MidiFile.dump() method.
|
| 208 |
"""
|
| 209 |
out = Path(output_path)
|
| 210 |
out.parent.mkdir(parents=True, exist_ok=True)
|
| 211 |
|
| 212 |
-
#
|
| 213 |
-
if hasattr(midi_score, "
|
| 214 |
-
midi_score.
|
| 215 |
# Fallback for older miditoolkit MidiFile objects
|
| 216 |
elif hasattr(midi_score, "dump"):
|
| 217 |
midi_score.dump(str(out))
|
| 218 |
else:
|
| 219 |
raise AttributeError(
|
| 220 |
f"Cannot save MIDI object of type {type(midi_score)}. "
|
| 221 |
-
"
|
| 222 |
)
|
| 223 |
|
| 224 |
size = out.stat().st_size
|
|
|
|
| 203 |
def save_midi(midi_score, output_path: str):
|
| 204 |
"""Save a MidiTok decoded score to a MIDI file.
|
| 205 |
|
| 206 |
+
MidiTok v3 returns a symusic ScoreTick object which uses dump_midi()
|
| 207 |
instead of the old miditoolkit MidiFile.dump() method.
|
| 208 |
"""
|
| 209 |
out = Path(output_path)
|
| 210 |
out.parent.mkdir(parents=True, exist_ok=True)
|
| 211 |
|
| 212 |
+
# symusic ScoreTick (MidiTok v3) uses dump_midi()
|
| 213 |
+
if hasattr(midi_score, "dump_midi"):
|
| 214 |
+
midi_score.dump_midi(str(out))
|
| 215 |
# Fallback for older miditoolkit MidiFile objects
|
| 216 |
elif hasattr(midi_score, "dump"):
|
| 217 |
midi_score.dump(str(out))
|
| 218 |
else:
|
| 219 |
raise AttributeError(
|
| 220 |
f"Cannot save MIDI object of type {type(midi_score)}. "
|
| 221 |
+
f"Available attrs: {[a for a in dir(midi_score) if not a.startswith('_')]}"
|
| 222 |
)
|
| 223 |
|
| 224 |
size = out.stat().st_size
|