# Source code for asyncwebsockets.websocket

"""
Base class for all websockets.
"""
# pylint: disable=R0913

from io import BytesIO, StringIO
from typing import List, Optional, Union

import anyio
from anyio.abc import SocketAttribute
from wsproto import ConnectionType, WSConnection
from wsproto.events import (
    AcceptConnection,
    RejectConnection,
    BytesMessage,
    CloseConnection,
    Message,
    Request,
    TextMessage,
)


class Websocket:
    """
    Base class for all websockets.

    Wraps a stream socket together with a ``wsproto`` ``WSConnection``.
    A connection is set up with :meth:`start_client` or :meth:`start_server`;
    afterwards, events are read by iterating over the object (``async for``)
    and data is written with :meth:`send`.
    """

    # The underlying stream. Set by start_client()/start_server() and reset
    # to None by close().
    _sock = None
    # The wsproto connection object. Created by start_client()/start_server().
    _connection = None

    def __init__(self):
        # Fragments of a not-yet-finished message are collected here, one
        # buffer per payload type.
        self._byte_buffer = BytesIO()
        self._string_buffer = StringIO()
        # Serialises send() calls.
        self._send_lock = anyio.Lock()

    async def __ainit__(
        self, addr, path: str, headers: Optional[List] = None, subprotocols=None, **connect_kw
    ):
        """
        Open a TCP connection to ``addr`` and start a client connection on it.

        :param addr: Sequence unpacked into the positional arguments of
            ``anyio.connect_tcp``; its first item is also used as the host
            of the handshake request.
        :param path: The request target.
        :param headers: Extra headers for the handshake request.
        :param subprotocols: Subprotocols to offer.
        :param connect_kw: Passed on to ``anyio.connect_tcp``. An
            ``autostart_tls`` entry is dropped first.
        """
        connect_kw.pop("autostart_tls", None)
        sock = await anyio.connect_tcp(*addr, **connect_kw)
        await self.start_client(sock, addr, path=path, headers=headers, subprotocols=subprotocols)

    @property
    def raw_socket(self):
        """The ``SocketAttribute.raw_socket`` extra attribute of the stream."""
        return self._sock.extra(SocketAttribute.raw_socket)

    @staticmethod
    async def _send_bytes(sock, data):
        """
        Write ``data`` to ``sock``.

        Tries ``sock.send_all`` first and falls back to ``sock.send`` when
        that raises ``AttributeError``, so that streams offering either
        method are supported.
        """
        try:
            await sock.send_all(data)
        except AttributeError:
            await sock.send(data)

    async def start_client(
        self,
        sock: anyio.abc.SocketStream,
        addr,
        path: str,
        headers: Optional[List] = None,
        subprotocols: Optional[List[str]] = None,
    ):
        """Start a client WS connection on this socket.

        :param sock: The stream to use. This object must not already have one.
        :param addr: Its first item is used as the ``host`` of the request.
        :param path: The request target.
        :param headers: Extra headers for the request; defaults to none.
        :param subprotocols: Subprotocols to offer; defaults to none.
        :raises ConnectionError: if the first event received is not an
            ``AcceptConnection``. The offending event is attached.

        Returns: the AcceptConnection message.
        """
        assert self._sock is None
        self._sock = sock
        self._connection = WSConnection(ConnectionType.CLIENT)
        if headers is None:
            headers = []
        if subprotocols is None:
            subprotocols = []
        data = self._connection.send(
            Request(host=addr[0], target=path, extra_headers=headers, subprotocols=subprotocols)
        )
        await self._send_bytes(self._sock, data)

        event = await self._next_event()
        if not isinstance(event, AcceptConnection):
            raise ConnectionError("Failed to establish a connection", event)
        return event

    async def start_server(
        self, sock: anyio.abc.SocketStream, filter=None
    ):  # pylint: disable=W0622
        """Start a server WS connection on this socket.

        Filter: an async callable that gets passed the initial Request.
        It may return an AcceptConnection message, a bool, or a string (the
        subprotocol to use). Any falsy result rejects the connection.
        Without a filter (or when it returns ``True``), the connection is
        accepted with the first subprotocol the client offered, if any.

        :raises ConnectionError: if the first event received is not a
            ``Request``, or (after the reply has been sent) if the reply was
            not an ``AcceptConnection``.

        Returns: the Request message.
        """
        assert self._sock is None
        self._sock = sock
        self._connection = WSConnection(ConnectionType.SERVER)

        event = await self._next_event()
        if not isinstance(event, Request):
            raise ConnectionError("Failed to establish a connection", event)
        msg = None
        if filter is not None:
            msg = await filter(event)
            if not msg:
                msg = RejectConnection()
            elif msg is True:
                # Plain acceptance: fall through to the default reply below.
                msg = None
            elif isinstance(msg, str):
                msg = AcceptConnection(subprotocol=msg)
        if not msg:
            subprotocol = event.subprotocols[0] if event.subprotocols else None
            msg = AcceptConnection(subprotocol=subprotocol)
        data = self._connection.send(msg)
        await self._send_bytes(self._sock, data)
        if not isinstance(msg, AcceptConnection):
            raise ConnectionError("Not accepted", msg)
        # The docstring has always promised the Request; actually hand it back.
        return event

    async def _next_event(self):
        """
        Gets the next event.

        Message fragments are buffered until the message is finished and
        then returned as one complete message; any other event is returned
        unchanged. When the socket is gone or yields no data, a
        ``CloseConnection`` with code 500 is returned instead.
        """
        while True:
            for event in self._connection.events():
                if not isinstance(event, Message):
                    return event
                if event.message_finished:
                    return self._wrap_data(self._gather_buffers(event))
                # Partial message: store it and keep draining the events
                # that are already available. Only read from the socket once
                # there are none left; otherwise we could block on a read
                # although the rest of the message has already arrived.
                self._buffer(event)

            if self._sock is None:
                return CloseConnection(code=500, reason="Socket closed")
            try:
                data = await self._sock.receive_some(4096)
            except AttributeError:
                data = await self._sock.receive(4096)
            if not data:
                return CloseConnection(code=500, reason="Socket closed")
            self._connection.receive_data(data)

    async def close(self, code: int = 1006, reason: str = "Connection closed"):
        """
        Closes the websocket.

        Sends a ``CloseConnection`` with the given ``code`` and ``reason``
        and then closes the stream. Does nothing if there is no stream.
        The stream is closed even if sending the close message fails; the
        failure is still propagated to the caller.
        """
        sock, self._sock = self._sock, None
        if not sock:
            return

        try:
            data = self._connection.send(CloseConnection(code=code, reason=reason))
            await self._send_bytes(sock, data)
        finally:
            # No, we don't wait for the correct reply
            try:
                await sock.close()
            except AttributeError:
                await sock.aclose()

    async def send(self, data: Union[bytes, str], final: bool = True):
        """
        Sends some data down the connection.

        :param data: ``str`` is sent as a ``TextMessage``, anything else as
            a ``BytesMessage``.
        :param final: Passed as ``message_finished``.
        """
        MsgType = TextMessage if isinstance(data, str) else BytesMessage
        data = MsgType(data=data, message_finished=final)
        async with self._send_lock:
            data = self._connection.send(event=data)
            await self._send_bytes(self._sock, data)

    def _buffer(self, event: Message):
        """
        Buffers an event, if applicable.
        """
        if isinstance(event, BytesMessage):
            self._byte_buffer.write(event.data)
        elif isinstance(event, TextMessage):
            self._string_buffer.write(event.data)

    def _gather_buffers(self, event: Message):
        """
        Gathers all the data from a buffer.

        Appends the data of ``event`` to the buffer matching its type,
        returns the complete buffer contents and empties the buffer.
        """
        if isinstance(event, BytesMessage):
            buf = self._byte_buffer
        else:
            buf = self._string_buffer

        buf.write(event.data)
        data = buf.getvalue()
        # Reset the buffer for the next message.
        buf.seek(0)
        buf.truncate()
        return data

    @staticmethod
    def _wrap_data(data: Union[str, bytes]):
        """
        Wraps data into the right event.
        """
        MsgType = TextMessage if isinstance(data, str) else BytesMessage
        return MsgType(data=data, frame_finished=True, message_finished=True)

    async def __aiter__(self):
        """Yield events until a ``CloseConnection`` is seen."""
        while True:
            msg = await self._next_event()
            if isinstance(msg, CloseConnection):
                return
            yield msg