Spaces:
Sleeping
Sleeping
"""Async queue processor for message trees.

Handles the async processing lifecycle of tree nodes.
"""
| import asyncio | |
| from collections.abc import Awaitable, Callable | |
| from loguru import logger | |
| from providers.common import get_user_facing_error_message | |
| from .data import MessageNode, MessageState, MessageTree | |
class TreeQueueProcessor:
    """
    Handles async queue processing for a single tree.

    Separates the async processing logic from the data management: the
    tree object passed to each method holds the queue and the node state,
    while this class decides when a node's processor is started, records
    processor failures on the node, and advances to the next queued node.

    Two optional async callbacks can be registered:

    * ``queue_update_callback(tree)`` - awaited after a queued node has
      been started, so queue positions can be refreshed.
    * ``node_started_callback(tree, node_id)`` - awaited when a queued
      node starts processing.

    Exceptions raised by either callback are logged and swallowed; they
    never interrupt queue processing.
    """

    def __init__(
        self,
        queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None,
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]]
        | None = None,
    ):
        """Store the optional notification callbacks.

        Args:
            queue_update_callback: Awaited with the tree after a queued
                node has been started. ``None`` disables the notification.
            node_started_callback: Awaited with the tree and the node id
                when a queued node starts. ``None`` disables the
                notification.
        """
        self._queue_update_callback = queue_update_callback
        self._node_started_callback = node_started_callback

    def set_queue_update_callback(
        self,
        queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None,
    ) -> None:
        """Update the callback used to refresh queue positions."""
        self._queue_update_callback = queue_update_callback

    def set_node_started_callback(
        self,
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None,
    ) -> None:
        """Update the callback used when a queued node starts processing."""
        self._node_started_callback = node_started_callback

    async def _notify_queue_updated(self, tree: MessageTree) -> None:
        """Invoke queue update callback if set.

        Best effort: a failing callback is logged as a warning and ignored.
        """
        if not self._queue_update_callback:
            return
        try:
            await self._queue_update_callback(tree)
        except Exception as e:
            logger.warning(f"Queue update callback failed: {e}")

    async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None:
        """Invoke node started callback if set.

        Best effort: a failing callback is logged as a warning and ignored.
        """
        if not self._node_started_callback:
            return
        try:
            await self._node_started_callback(tree, node_id)
        except Exception as e:
            logger.warning(f"Node started callback failed: {e}")

    def _start_node_task(
        self,
        tree: MessageTree,
        node: MessageNode,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> None:
        """Spawn ``process_node`` for ``node`` as a task and register it on the tree."""
        tree.set_current_task(
            asyncio.create_task(self.process_node(tree, node, processor))
        )

    async def process_node(
        self,
        tree: MessageTree,
        node: MessageNode,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> None:
        """Process a single node and then check the queue.

        Args:
            tree: The tree the node belongs to.
            node: The node to process.
            processor: Async function awaited as
                ``processor(node.node_id, node)``.

        Raises:
            asyncio.CancelledError: Re-raised when the task is cancelled
                while the processor is running. The current node is still
                cleared and the next queued node is still started before
                the cancellation propagates.
        """
        # Skip if already in terminal state (e.g. from error propagation)
        if node.state == MessageState.ERROR:
            logger.info(
                f"Skipping node {node.node_id} as it is already in state {node.state}"
            )
            # Still need to check for next messages
            await self._process_next(tree, processor)
            return

        try:
            await processor(node.node_id, node)
        except asyncio.CancelledError:
            logger.info(f"Task for node {node.node_id} was cancelled")
            raise
        except Exception as e:
            # A failing processor must not stop the queue: record the
            # failure on the node and carry on with the next one.
            logger.error(f"Error processing node {node.node_id}: {e}")
            await tree.update_state(
                node.node_id,
                MessageState.ERROR,
                error_message=get_user_facing_error_message(e),
            )
        finally:
            async with tree.with_lock():
                tree.clear_current_node()
            # Check if there are more messages in the queue
            await self._process_next(tree, processor)

    async def _process_next(
        self,
        tree: MessageTree,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> None:
        """Process the next message in queue, if any.

        Dequeues ids until one resolves to a node, starts that node as a
        new task and fires the notifications. When the queue runs empty
        the tree is marked as free instead.

        Queued ids that no longer resolve to a node are skipped with a
        warning. Stopping on such an id would leave the tree marked as
        processing with no running task, so nothing queued behind it
        would ever run.
        """
        async with tree.with_lock():
            while True:
                next_node_id = await tree.dequeue()

                if not next_node_id:
                    tree.set_processing_state(None, False)
                    logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
                    return

                node = tree.get_node(next_node_id)
                if node:
                    break

                logger.warning(
                    f"Queued node {next_node_id} not found in tree "
                    f"{tree.root_id}, skipping"
                )

            tree.set_processing_state(next_node_id, True)
            logger.info(f"Processing next queued node {next_node_id}")
            self._start_node_task(tree, node, processor)

        # Notify that this node has started processing and refresh queue positions.
        await self._notify_node_started(tree, next_node_id)
        await self._notify_queue_updated(tree)

    async def enqueue_and_start(
        self,
        tree: MessageTree,
        node_id: str,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> bool:
        """
        Enqueue a node or start processing immediately.

        Args:
            tree: The message tree
            node_id: Node to process
            processor: Async function to process the node

        Returns:
            True if queued, False if processing immediately. False is also
            returned when the tree is idle but ``node_id`` does not resolve
            to a node; in that case nothing is started and the tree is left
            idle.
        """
        async with tree.with_lock():
            if tree.is_processing:
                tree.put_queue_unlocked(node_id)
                queue_size = tree.get_queue_size()
                logger.info(f"Queued node {node_id}, position {queue_size}")
                return True

            # Resolve the node before marking the tree busy. Marking first
            # and then failing the lookup would leave the tree permanently
            # "processing" with no task, so later messages would only ever
            # be queued.
            node = tree.get_node(node_id)
            if not node:
                logger.warning(
                    f"Cannot start node {node_id}: not found in tree {tree.root_id}"
                )
                return False

            tree.set_processing_state(node_id, True)

        # Process outside the lock
        self._start_node_task(tree, node, processor)
        return False

    def cancel_current(self, tree: MessageTree) -> bool:
        """Cancel the currently running task in a tree.

        Returns:
            Whatever ``tree.cancel_current_task()`` returns.
        """
        return tree.cancel_current_task()