Websocket with Redis implementation get stuck at pubsub.reset()

Hi!

I’m following the implementation of the Websocket system using Redis at https://gist.github.com/ahopkins/5b6d380560d8e9d49e25281ff964ed81, but everytime I restart the server (after changing a line of code), if there was a connection to the websocket, the shutdown process get stuck at the pubsub.reset() call.

Here’s my code :

@app.websocket('')
async def streamable(request, ws: WebsocketImplProtocol):
    # Potentially, to avoid sending the organization at every new connections =>
    # is_new_request = False
    # session_id = request.headers.get('X-Client-Session', None)
    # if not session_id:
    #     session_id = uuid.uuid4().hex
    #     is_new_request = True

    token = request.args.get('token')
    if not token:
        return await ws.close(code=1011, reason=json.dumps({'code': 401, 'error': 'Please authenticate yourself.'}))

    agent = None
    async with db:
        session = await Session.find_by_token(token)
        if not session:
            return await ws.close(code=1011, reason=json.dumps({'code': 401, 'error': 'Please authenticate yourself.'}))

        agent = await session.get_agent()

    if not agent:
        return await ws.close(code=1011, reason=json.dumps({'code': 401, 'error': 'Please authenticate yourself.'}))

    request.app.add_task(ws.keepalive_ping())

    channel_name = redis.get_channel(agent)  # custom code to generate a channel name based on the user

    # First, we retrieve the channel used for that agent
    client = await Channel.get_client(channel_name, request.app, agent, ws)

    try:
        await client.read()
    finally:
        await Channel.leave(channel_name, client)


@app.signal('server.shutdown.before')
async def unregister_clients(app, loop):
    print('Got server.shutdown.before event!')
    await Channel.close()
    print('Event finished')



class WSClient:
    def __init__(self, agent, ws: WebsocketImplProtocol):
        self.agent = agent
        self.ws = ws

    async def write(self, data):
        if isinstance(data, dict):
            data = json.dumps(data, indent=2)

        await self.ws.send(data)

    async def read(self):
        while True:
            message = await self.ws.recv()
            if not message:
                break

            try:
                event = json.loads(message)
                action = event.get('action')
                document = event.get('value', None)

                await EventHandler.dispatch(action, document, self.agent)  # Custom code to dispatch actions (call functions)
            except json.decoder.JSONDecodeError:
                continue

    async def close(self, force=False):
        code = 1006 if force else 1000
        await self.ws.close(code)


class Channel:
    cache = {}

    def __init__(self):
        self.clients: Set[WSClient] = set()
        self.lock = Lock()
        self.pubsub = redis.pubsub()

    async def acquire_lock(self) -> None:
        if not self.lock.locked():
            await self.lock.acquire()

    async def listen(self):
        """
        Receives messages from the PubSub system at Redis
        Responsible to send them back to the client
        """
        while True:
            try:
                raw = await self.pubsub.get_message(ignore_subscribe_messages=True)
                if not raw:
                    continue

                for client in self.clients:
                    await client.write(raw['data'].decode())
            except PubSubError:
                break
            except asyncio.CancelledError:
                print('Cancel error')
                break

    @classmethod
    async def get_client(cls, channel_name: str, app, agent, ws: WebsocketImplProtocol):
        channel = None
        if channel_name in cls.cache:
            channel = cls.cache[channel_name]
            await channel.acquire_lock()
        else:
            # Is not present, we create it
            channel = cls()
            await channel.acquire_lock()

            cls.cache[channel_name] = channel
            await channel.pubsub.subscribe(channel_name)

            # If it's the first, we register a receiver callback
            # That will be triggered everytime a new data from Redis arrive
            # We only need this once to avoid sending the same data multiple times
            app.add_task(channel.listen())

        client = WSClient(agent, ws)
        channel.clients.add(client)
        return client

    @classmethod
    async def leave(cls, channel_name, client):
        if channel_name not in cls.cache:
            return

        channel = cls.cache[channel_name]
        if client in channel.clients:
            await client.close()
            channel.clients.remove(client)

        if not channel.clients:  # No clients anymore
            channel.lock.release()
            if channel.lock.locked():
                print('Lock is locked ?!')
                return

            await channel.pubsub.unsubscribe(channel_name)
            # await channel.pubsub.close()
            del channel.__class__.cache[channel_name]

    @classmethod
    async def close(cls):
        clients = []
        for channel in cls.cache.values():
            for client in channel.clients:
                clients.append(client)

        while clients:
            await clients.pop().close(True)

Do you have any ideao why the restart hangs at the .close() call from the pubsub?
(I have to do a kill -9)

Thanks in advance!

For information, adding a timeout on the get_message call fixes the above issues. Not sure if the problem is related to Sanic or AIORedis, but it might help:

raw = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)

1 Like

thanks, I’ll add thst to the gist

1 Like