| | |
| | |
| | import gradio as gr |
| | import sys |
| | import threading |
| | import queue |
| | import time |
| | import random |
| | from io import TextIOBase |
| | import datetime |
| | import subprocess |
| | import os |
| | from inference import postprocess_inst_names |
| | from inference import inference_patch |
| | from convert import abc2xml, xml2, pdf2img |
| |
|
| |
|
# HTML header shown at the top of the demo: the project title, four badge
# links (paper, code, weights, web demo) and a one-line description.
# The spacer between the badge row and the description was written as the
# non-existent tag "<bp>"; it is emitted as "<br>" here.
title_html = """
<div class="title-container">
    <h1 class="title-text">NotaGen</h1>
    <!-- ArXiv -->
    <a href="https://arxiv.org/abs/2502.18008">
        <img src="https://img.shields.io/badge/NotaGen_Paper-ArXiv-%23B31B1B?logo=arxiv&logoColor=white" alt="Paper">
    </a>

    <!-- GitHub -->
    <a href="https://github.com/ElectricAlexis/NotaGen">
        <img src="https://img.shields.io/badge/NotaGen_Code-GitHub-%23181717?logo=github&logoColor=white" alt="GitHub">
    </a>

    <!-- HuggingFace -->
    <a href="https://huggingface.co/ElectricAlexis/NotaGen">
        <img src="https://img.shields.io/badge/NotaGen_Weights-HuggingFace-%23FFD21F?logo=huggingface&logoColor=white" alt="Weights">
    </a>

    <!-- Web Demo -->
    <a href="https://electricalexis.github.io/notagen-demo/">
        <img src="https://img.shields.io/badge/NotaGen_Demo-Web-%23007ACC?logo=google-chrome&logoColor=white" alt="Demo">
    </a>
</div>
<br>
<p style="font-size: 1.2em;">NotaGen is a model for generating sheet music in ABC notation format. Select a period, composer, and instrumentation to generate classical-style music!</p>
"""
| |
|
| | |
# Read the prompt list. Each usable line is split on "_" and its first three
# fields are taken as (period, composer, instrumentation).
# UTF-8 is requested explicitly, matching the encoding used when this file
# writes its own .abc output.
with open('prompts.txt', 'r', encoding='utf-8') as f:
    prompts = f.readlines()

# Every (period, composer, instrumentation) triple the UI will accept.
valid_combinations = set()
for prompt in prompts:
    prompt = prompt.strip()
    if not prompt:
        # Blank lines (e.g. a trailing newline at end of file) carry no prompt.
        continue
    parts = prompt.split('_')
    if len(parts) < 3:
        # Not enough fields to form a triple; skip instead of raising IndexError.
        continue
    valid_combinations.add((parts[0], parts[1], parts[2]))

# Sorted distinct values of each field across all valid triples.
periods = sorted({p for p, _, _ in valid_combinations})
composers = sorted({c for _, c, _ in valid_combinations})
instruments = sorted({i for _, _, i in valid_combinations})
| |
|
| | |
def update_components(period, composer):
    """Recompute the composer and instrumentation dropdowns.

    Args:
        period: Currently selected period, or a falsy value when none is chosen.
        composer: Currently selected composer, or a falsy value when none is chosen.

    Returns:
        A two-element list of ``gr.update`` objects: the first for the composer
        dropdown, the second for the instrumentation dropdown.
    """
    if not period:
        # Without a period nothing can be offered: empty and lock both dropdowns.
        return [
            gr.update(choices=[], value=None, interactive=False),
            gr.update(choices=[], value=None, interactive=False),
        ]

    valid_composers = sorted({c for p, c, _ in valid_combinations if p == period})
    if composer:
        valid_instruments = sorted(
            {i for p, c, i in valid_combinations if p == period and c == composer}
        )
    else:
        valid_instruments = []

    return [
        gr.update(
            choices=valid_composers,
            # Keep the current composer only if it still exists for this period.
            value=composer if composer in valid_composers else None,
            interactive=True,
        ),
        gr.update(
            choices=valid_instruments,
            # The instrumentation selection is always cleared on a change.
            value=None,
            interactive=bool(valid_instruments),
        ),
    ]
| |
|
| | |
class RealtimeStream(TextIOBase):
    """Write-only text stream that forwards each written chunk to a queue.

    ``generate_music`` installs an instance as ``sys.stdout`` so text printed
    during inference can be read back from the queue while it is produced.
    """

    def __init__(self, queue):
        """Store the target queue; it only needs a ``put(item)`` method."""
        super().__init__()
        # The parameter name shadows the ``queue`` module inside this method only.
        self.queue = queue

    def writable(self):
        """Report that this stream accepts writes."""
        return True

    def write(self, text):
        """Enqueue *text* unchanged and return the number of characters written."""
        self.queue.put(text)
        return len(text)
| |
|
def convert_files(abc_content, period, composer, instrumentation):
    """Write the ABC text to disk and convert it to XML, PDF, MIDI, MP3 and PNG pages.

    Two ABC files are written into the current working directory: the raw text
    (``<base>.abc``) and a variant passed through ``postprocess_inst_names``
    (``<base>_postinst.abc``), where ``<base>`` is
    ``<timestamp>_<period>_<composer>_<instrumentation>``.

    The PDF and page images come from the raw variant; the XML, MIDI and MP3
    paths that are returned point at the post-processed variant.

    Args:
        abc_content: ABC notation text to save and convert.
        period, composer, instrumentation: Prompt fields; all must be truthy.

    Returns:
        dict with keys ``'abc'``, ``'xml'``, ``'pdf'``, ``'mid'``, ``'mp3'``
        (file paths), ``'pages'`` (number of page images), ``'current_page'``
        (always 0) and ``'base'`` (the raw filename base).

    Raises:
        gr.Error: If a prompt field is missing or any conversion step fails.
    """
    if not all([period, composer, instrumentation]):
        raise gr.Error("Please complete a valid generation first before saving")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    prompt_str = f"{period}_{composer}_{instrumentation}"
    filename_base = f"{timestamp}_{prompt_str}"

    abc_filename = f"{filename_base}.abc"
    with open(abc_filename, "w", encoding="utf-8") as f:
        f.write(abc_content)

    # Second copy with instrument names post-processed.
    postprocessed_inst_abc = postprocess_inst_names(abc_content)
    filename_base_postinst = f"{filename_base}_postinst"
    with open(filename_base_postinst + ".abc", "w", encoding="utf-8") as f:
        f.write(postprocessed_inst_abc)

    file_paths = {'abc': abc_filename}
    try:
        # NOTE(review): the converters come from the project's `convert` module;
        # the paths returned below assume each writes "<base>.<ext>" next to the
        # .abc file — confirm against that module.
        abc2xml(filename_base)
        abc2xml(filename_base_postinst)

        xml2(filename_base, 'pdf')

        xml2(filename_base, 'mid')
        xml2(filename_base_postinst, 'mid')

        xml2(filename_base, 'mp3')
        xml2(filename_base_postinst, 'mp3')

        # One PNG per PDF page, numbered from 1.
        images = pdf2img(filename_base)
        for i, image in enumerate(images):
            image.save(f"{filename_base}_page_{i+1}.png", "PNG")

        file_paths.update({
            'xml': f"{filename_base_postinst}.xml",
            'pdf': f"{filename_base}.pdf",
            'mid': f"{filename_base_postinst}.mid",
            'mp3': f"{filename_base_postinst}.mp3",
            'pages': len(images),
            'current_page': 0,
            'base': filename_base,
        })

    except Exception as e:
        # Chain the cause so the original traceback is kept for debugging.
        raise gr.Error(f"File processing failed: {str(e)}") from e

    return file_paths
| |
|
| |
|
| | |
def update_page(direction, data):
    """Move the sheet-music preview one page backward or forward.

    Args:
        direction: ``"prev"`` or ``"next"``; any other value leaves the page
            index unchanged.
        data: Page-navigation state with keys ``'pages'``, ``'current_page'``
            and ``'base'``, or a falsy value when nothing has been generated.
            It is updated in place.

    Returns:
        Tuple ``(image_path, prev_button_update, next_button_update, data)``.
        ``image_path`` is ``None`` when there is no page to show.
    """
    # No state yet, or a conversion that produced no page images: there is
    # nothing to display, so disable both buttons rather than point at a
    # page file that was never written.
    if not data or data.get('pages', 0) <= 0:
        return None, gr.update(interactive=False), gr.update(interactive=False), data

    last_index = data['pages'] - 1
    if direction == "prev" and data['current_page'] > 0:
        data['current_page'] -= 1
    elif direction == "next" and data['current_page'] < last_index:
        data['current_page'] += 1

    current_page_index = data['current_page']

    # Page image files are numbered from 1, the stored index from 0.
    new_image = f"{data['base']}_page_{current_page_index+1}.png"

    prev_btn_state = gr.update(interactive=(current_page_index > 0))
    next_btn_state = gr.update(interactive=(current_page_index < last_index))

    return new_image, prev_btn_state, next_btn_state, data
| |
|
| |
|
| | |
def generate_music(period, composer, instrumentation):
    """Run inference for the selected prompt and stream progress to the UI.

    This is a generator; every ``yield`` produces the same six values:
        1) process_output (text printed during inference so far)
        2) final_output (status text, then the final ABC or an error message)
        3) pdf_image (path to the PNG of the first page, or None)
        4) audio_player (mp3 path, or None)
        5) pdf_state (state dict for page navigation, or None)
        6) download_files (``gr.update`` for the downloadable file list)

    Raises:
        gr.Error: If (period, composer, instrumentation) is not a valid
            combination.
    """
    # Seed the random generators from the clock so each run differs.
    random_seed = int(time.time()) % 10000
    random.seed(random_seed)

    # numpy and torch are seeded only when they are installed.
    try:
        import numpy as np
        np.random.seed(random_seed)
    except ImportError:
        pass

    try:
        import torch
        torch.manual_seed(random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(random_seed)
    except ImportError:
        pass

    if (period, composer, instrumentation) not in valid_combinations:
        raise gr.Error("Invalid prompt combination! Please re-select from the period options")

    output_queue = queue.Queue()
    result_container = []
    error_container = []

    def run_inference():
        # sys.stdout is process-wide: while inference runs, everything printed
        # anywhere in the process is routed into output_queue. The swap is done
        # here, next to its `finally`, so stdout is always restored.
        original_stdout = sys.stdout
        sys.stdout = RealtimeStream(output_queue)
        try:
            result = inference_patch(period, composer, instrumentation)
            result_container.append(result)
        except Exception as exc:
            # Keep the failure for the caller, then re-raise so the thread's
            # default exception hook still reports the traceback.
            error_container.append(exc)
            raise
        finally:
            sys.stdout = original_stdout

    thread = threading.Thread(target=run_inference)
    thread.start()

    process_output = ""
    final_output_abc = ""
    pdf_image = None
    audio_file = None
    pdf_state = None

    # Stream captured output while inference is running. The short timeout
    # lets the loop notice when the thread has finished.
    while thread.is_alive():
        try:
            text = output_queue.get(timeout=0.1)
        except queue.Empty:
            continue
        process_output += text
        yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)

    # Collect anything written between the last read and thread exit.
    while not output_queue.empty():
        text = output_queue.get()
        process_output += text

    # Inference failed: report it instead of converting an empty result.
    if error_container:
        yield process_output, f"Error during generation: {str(error_container[0])}", None, None, None, gr.update(value=None, visible=False)
        return

    final_result = result_container[0] if result_container else ""

    final_output_abc = "Converting files..."
    yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)

    try:
        file_paths = convert_files(final_result, period, composer, instrumentation)
        final_output_abc = final_result

        if file_paths['pages'] > 0:
            pdf_image = f"{file_paths['base']}_page_1.png"
        # NOTE(review): the original indentation was lost; audio and page state
        # are set regardless of the page count here — confirm this is intended.
        audio_file = file_paths['mp3']
        pdf_state = file_paths

        # Offer only the files that were actually produced, in a fixed order.
        download_list = [
            file_paths[kind]
            for kind in ('abc', 'xml', 'pdf', 'mid', 'mp3')
            if kind in file_paths and os.path.exists(file_paths[kind])
        ]
    except Exception as e:
        yield process_output, f"Error converting files: {str(e)}", None, None, None, gr.update(value=None, visible=False)
        return

    yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=download_list, visible=True)
| |
|
| |
|
def get_file(file_type, period, composer, instrumentation):
    """Return the most recently modified ``*.<file_type>`` entry in the working directory.

    Args:
        file_type: Extension without the leading dot (e.g. ``"mp3"``).
        period, composer, instrumentation: Accepted but not used for the lookup.

    Returns:
        The bare entry name (no directory part), or ``None`` when nothing in
        the current directory ends with ``.<file_type>``.
    """
    suffix = f'.{file_type}'
    candidates = [name for name in os.listdir('.') if name.endswith(suffix)]
    if not candidates:
        return None

    # A single pass is enough to find the newest entry. Scanning in reverse
    # makes ties on mtime resolve to the later listing entry, which is what
    # a stable ascending sort followed by [-1] would pick.
    return max(reversed(candidates), key=os.path.getmtime)
| |
|
| |
|
# Stylesheet passed to gr.Blocks(css=...). Selectors refer to element ids and
# classes assigned through elem_id / elem_classes and to the markup in
# title_html.
css = """
/* Compact button style */
button[size="sm"] {
    padding: 4px 8px !important;
    margin: 2px !important;
    min-width: 60px;
}

/* PDF preview area */
#pdf-preview {
    border-radius: 8px; /* Rounded corners */
    box-shadow: 0 2px 8px rgba(0,0,0,0.1); /* Shadow */
}

.page-btn {
    padding: 12px !important; /* Increase clickable area */
    margin: auto !important; /* Vertical center */
}

/* Button hover effect */
.page-btn:hover {
    background: #f0f0f0 !important;
    transform: scale(1.05);
}

/* Layout adjustment */
.gr-row {
    gap: 10px !important; /* Element spacing */
}

/* Audio player */
.audio-panel {
    margin-top: 15px !important;
    max-width: 400px;
}

#audio-preview audio {
    height: 200px !important;
}

/* Save functionality area */
.save-as-row {
    margin-top: 15px;
    padding: 10px;
    border-top: 1px solid #eee;
}

/* Download files styling */
.download-files {
    margin-top: 15px;
    border-radius: 8px;
    box-shadow: 0 2px 8px rgba(0,0,0,0.1);
}

/* Social icons styling */
.title-container {
    display: flex;
    align-items: center;
    gap: 15px;
    margin-bottom: 10px;
}

.title-text {
    margin: 0;
    font-size: 1.8em;
}

.social-icons {
    display: flex;
    gap: 10px;
}

.social-icon {
    display: inline-flex;
    align-items: center;
    justify-content: center;
    width: 32px;
    height: 32px;
    border-radius: 50%;
    background-color: #f5f5f5;
    text-decoration: none;
    transition: transform 0.2s, background-color 0.2s;
}

.social-icon:hover {
    transform: scale(1.1);
    background-color: #e0e0e0;
}

.social-icon img {
    width: 20px;
    height: 20px;
}

"""
| |
|
# Build the Gradio UI and wire its events.
# NOTE(review): the original indentation was lost, so the container nesting
# below is reconstructed (inputs and audio in a left column, preview and page
# buttons in a right column, downloads underneath) — confirm against the
# intended layout.
with gr.Blocks(css=css) as demo:
    gr.HTML(title_html)

    # Page-navigation state shared by generate_music and update_page.
    pdf_state = gr.State()

    with gr.Column():
        with gr.Row():
            # Left column: prompt selection, progress, ABC text and audio.
            with gr.Column():
                with gr.Row():
                    period_dd = gr.Dropdown(
                        choices=periods,
                        value=None,
                        label="Period",
                        interactive=True
                    )
                    # Composer and instrumentation start empty and locked;
                    # update_components fills them in.
                    composer_dd = gr.Dropdown(
                        choices=[],
                        value=None,
                        label="Composer",
                        interactive=False
                    )
                    instrument_dd = gr.Dropdown(
                        choices=[],
                        value=None,
                        label="Instrumentation",
                        interactive=False
                    )

                generate_btn = gr.Button("Generate!", variant="primary")

                process_output = gr.Textbox(
                    label="Generation process",
                    interactive=False,
                    lines=2,
                    max_lines=2,
                    placeholder="Generation progress will be shown here..."
                )

                final_output = gr.Textbox(
                    label="Post-processed ABC notation scores",
                    interactive=True,
                    lines=8,
                    max_lines=8,
                    placeholder="Post-processed ABC scores will be shown here..."
                )

                # Audio playback of the generated mp3.
                audio_player = gr.Audio(
                    label="Audio Preview",
                    format="mp3",
                    interactive=False,
                )

            # Right column: sheet-music preview and page navigation.
            with gr.Column():
                pdf_image = gr.Image(
                    label="Sheet Music Preview",
                    show_label=False,
                    height=650,
                    type="filepath",
                    elem_id="pdf-preview",
                    interactive=False,
                    show_download_button=False
                )

                with gr.Row():
                    prev_btn = gr.Button(
                        "⬅️ Last Page",
                        variant="secondary",
                        size="sm",
                        elem_classes="page-btn"
                    )
                    next_btn = gr.Button(
                        "Next Page ➡️",
                        variant="secondary",
                        size="sm",
                        elem_classes="page-btn"
                    )

        # Download list; hidden until generate_music supplies files.
        with gr.Column():
            gr.Markdown("**Download Files:**")
            download_files = gr.Files(
                label="Generated Files",
                visible=False,
                elem_classes="download-files",
                type="filepath"
            )

    # Cascading dropdowns: a change to period or composer recomputes the
    # composer and instrumentation choices.
    period_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )
    composer_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )

    # The six outputs match the six values of each generate_music yield.
    generate_btn.click(
        generate_music,
        inputs=[period_dd, composer_dd, instrument_dd],
        outputs=[process_output, final_output, pdf_image, audio_player, pdf_state, download_files]
    )

    # Hidden textboxes holding the constant direction strings handed to
    # update_page as its first argument.
    prev_signal = gr.Textbox(value="prev", visible=False)
    next_signal = gr.Textbox(value="next", visible=False)

    prev_btn.click(
        update_page,
        inputs=[prev_signal, pdf_state],
        outputs=[pdf_image, prev_btn, next_btn, pdf_state]
    )

    next_btn.click(
        update_page,
        inputs=[next_signal, pdf_state],
        outputs=[pdf_image, prev_btn, next_btn, pdf_state]
    )
| |
|
| |
|
# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(
        # "0.0.0.0" listens on every network interface, not just localhost.
        server_name="0.0.0.0",
        server_port=7860,
    )