Skip to content

InMemoryTransport

Bases: Transport

Source code in eggai/transport/inmemory.py
class InMemoryTransport(Transport):
    """
    Transport that moves messages through in-process ``asyncio.Queue`` objects.

    Queues and subscription callbacks are stored on the class, so they are
    shared by every instance. Each instance only tracks the consume-loop
    tasks it created itself.
    """

    # One queue per (channel, group_id). Each consumer group gets its own queue.
    _CHANNELS: Dict[str, Dict[str, asyncio.Queue]] = defaultdict(dict)
    # For each channel and group_id, store a list of subscription callbacks.
    _SUBSCRIPTIONS: Dict[str, Dict[str, List[Callable[[Dict[str, Any]], "asyncio.Future"]]]] = defaultdict(
        lambda: defaultdict(list)
    )

    def __init__(self):
        """Create a disconnected transport with no running consume loops."""
        self._connected = False
        # Keep references to consume tasks keyed by (channel, group_id)
        self._consume_tasks: Dict[Tuple[str, str], asyncio.Task] = {}

    async def connect(self):
        """Marks the transport as connected and starts consume loops for existing subscriptions."""
        self._connected = True
        for channel, group_map in InMemoryTransport._SUBSCRIPTIONS.items():
            for group_id in group_map:
                key = (channel, group_id)
                # Only start a loop for pairs this instance is not consuming yet.
                if key not in self._consume_tasks:
                    self._consume_tasks[key] = asyncio.create_task(self._consume_loop(channel, group_id))

    async def disconnect(self):
        """Cancels all consume loops and marks the transport as disconnected."""
        # Request cancellation of every loop first, then wait for each to finish.
        for task in self._consume_tasks.values():
            task.cancel()
        for task in self._consume_tasks.values():
            try:
                await task
            except asyncio.CancelledError:
                pass
        self._consume_tasks.clear()
        self._connected = False

    async def publish(self, channel: str, message: Union[Dict[str, Any], BaseMessage]):
        """
        Publishes a message to the given channel.
        The message is put into all queues for that channel so each consumer group receives it.

        Objects exposing ``model_dump_json`` are serialized with it; anything
        else is serialized with ``json.dumps``.

        Raises:
            RuntimeError: if the transport has not been connected.
        """
        if not self._connected:
            raise RuntimeError("Transport not connected. Call `connect()` first.")
        if channel not in InMemoryTransport._CHANNELS:
            InMemoryTransport._CHANNELS[channel] = {}

        if hasattr(message, "model_dump_json"):
            data = message.model_dump_json()
        else:
            data = json.dumps(message)

        # Only the queues are needed here; the group ids are irrelevant.
        for queue in InMemoryTransport._CHANNELS[channel].values():
            await queue.put(data)

    async def subscribe(self, channel: str, callback: Callable[[Dict[str, Any]], "asyncio.Future"], **kwargs):
        """
        Subscribes to a channel with the provided group_id.
        If no group_id is given, a unique one is generated, ensuring that each subscription
        gets its own consumer (broadcast mode).

        Keyword Args:
            group_id: consumer group the callback joins.
            filter_by_message: predicate called with the decoded message; the
                callback is awaited only when it returns a truthy value.
        """
        group_id = kwargs.get("group_id", uuid.uuid4().hex)
        # Same (channel, group_id) ordering as connect(); the reversed tuple
        # never matched existing entries, so a duplicate consume loop could be
        # started for a queue that already had one.
        key = (channel, group_id)
        if "filter_by_message" in kwargs:
            async def filtered_callback(data):
                if kwargs["filter_by_message"](data):
                    await callback(data)
            InMemoryTransport._SUBSCRIPTIONS[channel][group_id].append(filtered_callback)
        else:
            InMemoryTransport._SUBSCRIPTIONS[channel][group_id].append(callback)
        if group_id not in InMemoryTransport._CHANNELS[channel]:
            InMemoryTransport._CHANNELS[channel][group_id] = asyncio.Queue()
        # When already connected, start consuming immediately (once per pair).
        if self._connected and key not in self._consume_tasks:
            self._consume_tasks[key] = asyncio.create_task(self._consume_loop(channel, group_id))

    async def _consume_loop(self, channel: str, group_id: str):
        """
        Continuously pulls messages from the queue for (channel, group_id)
        and dispatches them to all registered callbacks.
        """
        queue = InMemoryTransport._CHANNELS[channel].get(group_id)
        if queue is None:
            # No queue registered yet for this pair: create one so publishers can reach it.
            queue = asyncio.Queue()
            InMemoryTransport._CHANNELS[channel][group_id] = queue

        # NOTE(review): the handlers wrap the whole loop, so the first exception
        # from json.loads or a callback ends this loop while its task stays
        # registered in _consume_tasks. Confirm whether that is intended.
        try:
            while True:
                msg = await queue.get()
                data = json.loads(msg)
                callbacks = InMemoryTransport._SUBSCRIPTIONS[channel].get(group_id, [])
                # Callbacks run sequentially, in registration order.
                for cb in callbacks:
                    await cb(data)
        except asyncio.CancelledError:
            # Cancellation is swallowed so the task finishes without raising.
            pass
        except Exception as e:
            print(f"InMemoryTransport consume loop error on channel={channel}, group={group_id}: {e}")

connect() async

Marks the transport as connected and starts consume loops for existing subscriptions.

Source code in eggai/transport/inmemory.py
async def connect(self):
    """
    Flag the transport as connected and launch the missing consume loops.

    Walks every (channel, group_id) pair in the class-level subscription table
    and creates a consume task for each pair this instance does not track yet.
    """
    self._connected = True
    for channel, groups in InMemoryTransport._SUBSCRIPTIONS.items():
        for group_id in groups:
            task_key = (channel, group_id)
            if task_key in self._consume_tasks:
                # A loop for this pair was already started by this instance.
                continue
            consumer = self._consume_loop(channel, group_id)
            self._consume_tasks[task_key] = asyncio.create_task(consumer)

disconnect() async

Cancels all consume loops and marks the transport as disconnected.

Source code in eggai/transport/inmemory.py
async def disconnect(self):
    """
    Stop every consume loop owned by this instance and mark it disconnected.

    All tasks are asked to cancel before any of them is awaited; the task
    registry is emptied once each one has finished.
    """
    running = self._consume_tasks
    # First pass: request cancellation everywhere.
    for loop_task in running.values():
        loop_task.cancel()
    # Second pass: wait until each task has actually stopped.
    for loop_task in running.values():
        try:
            await loop_task
        except asyncio.CancelledError:
            # Expected result of the cancel() issued above.
            continue
    running.clear()
    self._connected = False

publish(channel, message) async

Publishes a message to the given channel. The message is put into all queues for that channel so each consumer group receives it.

Source code in eggai/transport/inmemory.py
async def publish(self, channel: str, message: Union[Dict[str, Any], BaseMessage]):
    """
    Deliver a message to every consumer-group queue of ``channel``.

    The message is serialized once: objects exposing ``model_dump_json`` use
    it, anything else goes through ``json.dumps``. The resulting string is
    then put on each queue registered for the channel.

    Raises:
        RuntimeError: if the transport has not been connected.
    """
    if not self._connected:
        raise RuntimeError("Transport not connected. Call `connect()` first.")

    channels = InMemoryTransport._CHANNELS
    if channel not in channels:
        # Register the channel even when nobody is subscribed yet.
        channels[channel] = {}

    payload = message.model_dump_json() if hasattr(message, "model_dump_json") else json.dumps(message)

    for group_queue in channels[channel].values():
        await group_queue.put(payload)

subscribe(channel, callback, **kwargs) async

Subscribes to a channel with the provided group_id. If no group_id is given, a unique one is generated, ensuring that each subscription gets its own consumer (broadcast mode).

Source code in eggai/transport/inmemory.py
async def subscribe(self, channel: str, callback: Callable[[Dict[str, Any]], "asyncio.Future"], **kwargs):
    """
    Subscribes to a channel with the provided group_id.
    If no group_id is given, a unique one is generated, ensuring that each subscription
    gets its own consumer (broadcast mode).

    Keyword Args:
        group_id: consumer group the callback joins.
        filter_by_message: predicate called with the decoded message; the
            callback is awaited only when it returns a truthy value.
    """
    group_id = kwargs.get("group_id", uuid.uuid4().hex)
    # The task registry is keyed by (channel, group_id). The reversed tuple
    # never matched existing entries, so a duplicate consume loop could be
    # started for a queue that already had one.
    key = (channel, group_id)
    if "filter_by_message" in kwargs:
        async def filtered_callback(data):
            if kwargs["filter_by_message"](data):
                await callback(data)
        InMemoryTransport._SUBSCRIPTIONS[channel][group_id].append(filtered_callback)
    else:
        InMemoryTransport._SUBSCRIPTIONS[channel][group_id].append(callback)
    if group_id not in InMemoryTransport._CHANNELS[channel]:
        InMemoryTransport._CHANNELS[channel][group_id] = asyncio.Queue()
    # When already connected, start consuming immediately (once per pair).
    if self._connected and key not in self._consume_tasks:
        self._consume_tasks[key] = asyncio.create_task(self._consume_loop(channel, group_id))