| try: |
| from IPython.core.magic import ( |
| needs_local_scope, |
| register_cell_magic, |
| ) |
| from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring |
| except ImportError: |
| pass |
|
|
| import gradio as gr |
| from gradio.routes import App |
| from gradio.utils import BaseReloader |
|
|
|
|
class CellIdTracker:
    """Determines the most recently run cell in the notebook.

    Needed to keep track of which demo the user is updating.
    """

    def __init__(self, ipython):
        """Subscribe to the shell's ``pre_run_cell`` event.

        Args:
            ipython: The interactive shell. It must expose
                ``events.register(event_name, callback)``.
        """
        ipython.events.register("pre_run_cell", self.pre_run_cell)
        self.shell = ipython
        # Id of the cell currently being executed; "" until a cell has run.
        self.current_cell: str = ""

    def pre_run_cell(self, info):
        """Record the id of the cell that is about to run.

        Args:
            info: Object handed to ``pre_run_cell`` callbacks; its
                ``cell_id`` attribute is read when present.
        """
        # Write to the public ``current_cell`` attribute (the one initialised
        # in ``__init__``). Storing the id under a different name would leave
        # ``current_cell`` stuck at "" and every cell would look identical.
        # NOTE(review): some front ends / IPython versions may not supply a
        # cell id; fall back to "" so the attribute stays a string.
        cell_id = getattr(info, "cell_id", None)
        self.current_cell = "" if cell_id is None else cell_id
|
|
|
|
class JupyterReloader(BaseReloader):
    """Swap a running blocks class in a notebook with the latest cell contents.

    Launched demos are kept in a dict keyed by the cell tracker's
    ``current_cell`` value.
    """

    def __init__(self, ipython) -> None:
        """Set up the cell tracker and an empty registry of demos.

        Args:
            ipython: The interactive shell, forwarded to ``CellIdTracker``.
        """
        super().__init__()
        # Maps a cell key to the demo that was tracked for it.
        self._running: dict[str, gr.Blocks] = {}
        self._cell_tracker = CellIdTracker(ipython)

    @property
    def current_cell(self):
        """Cell key reported by the underlying cell tracker."""
        tracker = self._cell_tracker
        return tracker.current_cell

    @property
    def running_demo(self):
        """Demo tracked for the current cell (``KeyError`` if none is tracked)."""
        key = self.current_cell
        return self._running[key]

    @property
    def running_app(self) -> App:
        """App served by the current cell's demo.

        Raises:
            RuntimeError: If the demo's ``server`` attribute is falsy.
        """
        server = self.running_demo.server
        if server:
            return server.running_app
        raise RuntimeError("Server not running")

    def demo_tracked(self) -> bool:
        """Return whether a demo is already tracked for the current cell."""
        key = self.current_cell
        return key in self._running

    def track(self, demo: gr.Blocks):
        """Remember ``demo`` as the demo belonging to the current cell."""
        key = self.current_cell
        self._running[key] = demo
|
|
|
|
def load_ipython_extension(ipython):
    """Create a reloader for ``ipython`` and define the ``blocks`` cell magic.

    Args:
        ipython: The interactive shell, forwarded to ``JupyterReloader``.
    """
    jupyter_reloader = JupyterReloader(ipython)

    @magic_arguments()
    @argument("--demo-name", default="demo", help="Name of gradio blocks instance.")
    @argument(
        "--share",
        nargs="?",
        const=True,
        default=False,
        help="Whether to launch with sharing. Will slow down reloading.",
    )
    @register_cell_magic
    @needs_local_scope
    def blocks(line, cell, local_ns):
        """Launch a demo defined in a cell in reload mode."""
        parsed = parse_argstring(blocks, line)

        # Run the cell body with ``local_ns`` as its local namespace so the
        # demo object it defines can be looked up by name afterwards.
        exec(cell, None, local_ns)
        cell_demo: gr.Blocks = local_ns[parsed.demo_name]

        if jupyter_reloader.demo_tracked():
            # A demo is already tracked for this cell: swap in the new blocks.
            jupyter_reloader.swap_blocks(cell_demo)
        else:
            # First run for this cell: launch the demo, then start tracking it.
            cell_demo.launch(share=parsed.share)
            jupyter_reloader.track(cell_demo)

        return jupyter_reloader.running_demo.artifact
|
|