From 52537f291bf9622a2e439bbb1aabe4a4d32d1ebf Mon Sep 17 00:00:00 2001 From: Robert Blackhart Date: Sun, 16 Feb 2020 20:58:24 -0500 Subject: [PATCH 1/3] implement an asyncio version of the websocket class --- examples/aio_echo_websocket_org.py | 26 +++ uaiowebsockets/client.py | 74 +++++++++ uaiowebsockets/protocol.py | 248 +++++++++++++++++++++++++++++ 3 files changed, 348 insertions(+) create mode 100644 examples/aio_echo_websocket_org.py create mode 100644 uaiowebsockets/client.py create mode 100644 uaiowebsockets/protocol.py diff --git a/examples/aio_echo_websocket_org.py b/examples/aio_echo_websocket_org.py new file mode 100644 index 0000000..d78cba6 --- /dev/null +++ b/examples/aio_echo_websocket_org.py @@ -0,0 +1,26 @@ +from uaiowebsockets import client +import uasyncio + + +async def connect_and_receive(loop): + websocket = await client.connect("ws://echo.websocket.org") + + loop.create_task(writer(1, websocket)) + loop.create_task(writer(2, websocket)) + + while True: + print(await websocket.recv()) + + +async def writer(index, websocket): + message_num = 0 + while True: + message_num += 1 + msg = "[%s]: Hello from writer %s." % (message_num, index) + await websocket.send(msg) + await uasyncio.sleep(5) + + +loop = uasyncio.get_event_loop() +loop.create_task(connect_and_receive(loop)) +loop.run_forever() diff --git a/uaiowebsockets/client.py b/uaiowebsockets/client.py new file mode 100644 index 0000000..d90d4ef --- /dev/null +++ b/uaiowebsockets/client.py @@ -0,0 +1,74 @@ +""" +Websockets client for micropython + +Based very heavily off +https://github.com/aaugustin/websockets/blob/master/websockets/client.py +""" + +import logging +import usocket as socket +import uasyncio as asyncio +import ubinascii as binascii +import urandom as random +import ussl + +from .protocol import Websocket, urlparse + +LOGGER = logging.getLogger(__name__) + + +class WebsocketClient(Websocket): + is_client = True + +async def connect(uri): + """ + Connect a websocket. + """ + + uri = urlparse(uri) + assert uri + + if __debug__: LOGGER.debug("open connection %s:%s", + uri.hostname, uri.port) + + sock = socket.socket() + addr = socket.getaddrinfo(uri.hostname, uri.port) + sock.connect(addr[0][4]) + if uri.protocol == 'wss': + sock = ussl.wrap_socket(sock) + + sreader = asyncio.StreamReader(sock) + swriter = asyncio.StreamWriter(sock, {}) + + async def send_header(header, *args): + if __debug__: LOGGER.debug(str(header), *args) + await swriter.awrite(header % args + '\r\n') + + # Sec-WebSocket-Key is 16 bytes of random base64 encoded + key = binascii.b2a_base64(bytes(random.getrandbits(8) + for _ in range(16)))[:-1] + + await send_header(b'GET %s HTTP/1.1', uri.path or '/') + await send_header(b'Host: %s:%s', uri.hostname, uri.port) + await send_header(b'Connection: Upgrade') + await send_header(b'Upgrade: websocket') + await send_header(b'Sec-WebSocket-Key: %s', key) + await send_header(b'Sec-WebSocket-Version: 13') + await send_header(b'Origin: http://{hostname}:{port}'.format( + hostname=uri.hostname, + port=uri.port) + ) + await send_header(b'') + + response = await sreader.readline() + header = response[:-2] + assert header.startswith(b'HTTP/1.1 101 '), header + + # We don't (currently) need these headers + # FIXME: should we check the return key? + while header: + if __debug__: LOGGER.debug(str(header)) + response = await sreader.readline() + header = response[:-2] + + return WebsocketClient(sock, sreader, swriter) diff --git a/uaiowebsockets/protocol.py b/uaiowebsockets/protocol.py new file mode 100644 index 0000000..89f2366 --- /dev/null +++ b/uaiowebsockets/protocol.py @@ -0,0 +1,248 @@ +""" +Websockets protocol +""" + +import logging +import ure as re +import ustruct as struct +import urandom as random +import usocket as socket +from ucollections import namedtuple + +LOGGER = logging.getLogger(__name__) + +# Opcodes +OP_CONT = const(0x0) +OP_TEXT = const(0x1) +OP_BYTES = const(0x2) +OP_CLOSE = const(0x8) +OP_PING = const(0x9) +OP_PONG = const(0xa) + +# Close codes +CLOSE_OK = const(1000) +CLOSE_GOING_AWAY = const(1001) +CLOSE_PROTOCOL_ERROR = const(1002) +CLOSE_DATA_NOT_SUPPORTED = const(1003) +CLOSE_BAD_DATA = const(1007) +CLOSE_POLICY_VIOLATION = const(1008) +CLOSE_TOO_BIG = const(1009) +CLOSE_MISSING_EXTN = const(1010) +CLOSE_BAD_CONDITION = const(1011) + +URL_RE = re.compile(r'(wss|ws)://([A-Za-z0-9-\.]+)(?:\:([0-9]+))?(/.+)?') +URI = namedtuple('URI', ('protocol', 'hostname', 'port', 'path')) + +class NoDataException(Exception): + pass + +class ConnectionClosed(Exception): + pass + +def urlparse(uri): + """Parse ws:// URLs""" + match = URL_RE.match(uri) + if match: + protocol = match.group(1) + host = match.group(2) + port = match.group(3) + path = match.group(4) + + if protocol == 'wss': + if port is None: + port = 443 + elif protocol == 'ws': + if port is None: + port = 80 + else: + raise ValueError('Scheme {} is invalid'.format(protocol)) + + return URI(protocol, host, int(port), path) + + +class Websocket: + """ + Basis of the Websocket protocol. + + This can probably be replaced with the C-based websocket module, but + this one currently supports more options. + """ + is_client = False + + def __init__(self, sock, sreader, swriter): + self.sock = sock + self.sreader = sreader + self.swriter = swriter + self.open = True + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + + def settimeout(self, timeout): + self.sock.settimeout(timeout) + + async def read_frame(self, max_size=None): + """ + Read a frame from the socket. + See https://tools.ietf.org/html/rfc6455#section-5.2 for the details. + """ + + # Frame header + two_bytes = await self.sreader.read(2) + + if not two_bytes: + raise NoDataException + + byte1, byte2 = struct.unpack('!BB', two_bytes) + + # Byte 1: FIN(1) _(1) _(1) _(1) OPCODE(4) + fin = bool(byte1 & 0x80) + opcode = byte1 & 0x0f + + # Byte 2: MASK(1) LENGTH(7) + mask = bool(byte2 & (1 << 7)) + length = byte2 & 0x7f + + if length == 126: # Magic number, length header is 2 bytes + length, = struct.unpack('!H', self.sock.read(2)) + elif length == 127: # Magic number, length header is 8 bytes + length, = struct.unpack('!Q', self.sock.read(8)) + + if mask: # Mask is 4 bytes + mask_bits = await self.sreader.read(4) + + try: + data = await self.sreader.read(length) + except MemoryError: + # We can't receive this many bytes, close the socket + if __debug__: LOGGER.debug("Frame of length %s too big. Closing", + length) + self.close(code=CLOSE_TOO_BIG) + return True, OP_CLOSE, None + + if mask: + data = bytes(b ^ mask_bits[i % 4] + for i, b in enumerate(data)) + + return fin, opcode, data + + async def write_frame(self, opcode, data=b''): + """ + Write a frame to the socket. + See https://tools.ietf.org/html/rfc6455#section-5.2 for the details. + """ + fin = True + mask = self.is_client # messages sent by client are masked + + length = len(data) + + # Frame header + # Byte 1: FIN(1) _(1) _(1) _(1) OPCODE(4) + byte1 = 0x80 if fin else 0 + byte1 |= opcode + + # Byte 2: MASK(1) LENGTH(7) + byte2 = 0x80 if mask else 0 + + if length < 126: # 126 is magic value to use 2-byte length header + byte2 |= length + await self.swriter.awrite(struct.pack('!BB', byte1, byte2)) + + elif length < (1 << 16): # Length fits in 2-bytes + byte2 |= 126 # Magic code + await self.srwriter.awrite(struct.pack('!BBH', byte1, byte2, length)) + + elif length < (1 << 64): + byte2 |= 127 # Magic code + await self.swriter.awrite(struct.pack('!BBQ', byte1, byte2, length)) + + else: + raise ValueError() + + if mask: # Mask is 4 bytes + mask_bits = struct.pack('!I', random.getrandbits(32)) + await self.swriter.awrite(mask_bits) + + data = bytes(b ^ mask_bits[i % 4] + for i, b in enumerate(data)) + + await self.swriter.awrite(data) + + async def recv(self): + """ + Receive data from the websocket. + + This is slightly different from 'websockets' in that it doesn't + fire off a routine to process frames and put the data in a queue. + If you don't call recv() sufficiently often you won't process control + frames. + """ + assert self.open + + while self.open: + try: + fin, opcode, data = await self.read_frame() + except NoDataException: + return '' + except ValueError: + LOGGER.debug("Failed to read frame. Socket dead.") + self._close() + raise ConnectionClosed() + + if not fin: + raise NotImplementedError() + + if opcode == OP_TEXT: + return data.decode('utf-8') + elif opcode == OP_BYTES: + return data + elif opcode == OP_CLOSE: + self._close() + return + elif opcode == OP_PONG: + # Ignore this frame, keep waiting for a data frame + continue + elif opcode == OP_PING: + # We need to send a pong frame + if __debug__: LOGGER.debug("Sending PONG") + await self.write_frame(OP_PONG, data) + # And then wait to receive + continue + elif opcode == OP_CONT: + # This is a continuation of a previous frame + raise NotImplementedError(opcode) + else: + raise ValueError(opcode) + + async def send(self, buf): + """Send data to the websocket.""" + + assert self.open + + if isinstance(buf, str): + opcode = OP_TEXT + buf = buf.encode('utf-8') + elif isinstance(buf, bytes): + opcode = OP_BYTES + else: + raise TypeError() + + await self.write_frame(opcode, buf) + + async def close(self, code=CLOSE_OK, reason=''): + """Close the websocket.""" + if not self.open: + return + + buf = struct.pack('!H', code) + reason.encode('utf-8') + + await self.write_frame(OP_CLOSE, buf) + self._close() + + def _close(self): + if __debug__: LOGGER.debug("Connection closed") + self.open = False + self.sock.close() From 2d73f6a72e03d4f5ec92244aa91b9be042478699 Mon Sep 17 00:00:00 2001 From: Robert Blackhart Date: Sat, 22 Feb 2020 17:21:56 -0500 Subject: [PATCH 2/3] use aenter and aexit on the websocket object --- examples/aio_websocket_context.py | 13 +++++++++++++ uaiowebsockets/protocol.py | 6 +++--- 2 files changed, 16 insertions(+), 3 deletions(-) create mode 100644 examples/aio_websocket_context.py diff --git a/examples/aio_websocket_context.py b/examples/aio_websocket_context.py new file mode 100644 index 0000000..720dff9 --- /dev/null +++ b/examples/aio_websocket_context.py @@ -0,0 +1,13 @@ +from uaiowebsockets import client +import uasyncio +import sys + +async def main(): + async with await client.connect("ws://echo.websocket.org") as websocket: + await websocket.send("hello, websocket") + print(await websocket.recv()) + sys.exit(0) + +loop = uasyncio.get_event_loop() +loop.create_task(main()) +loop.run_forever() diff --git a/uaiowebsockets/protocol.py b/uaiowebsockets/protocol.py index 89f2366..4880fc1 100644 --- a/uaiowebsockets/protocol.py +++ b/uaiowebsockets/protocol.py @@ -75,11 +75,11 @@ def __init__(self, sock, sreader, swriter): self.swriter = swriter self.open = True - def __enter__(self): + async def __aenter__(self): return self - def __exit__(self, exc_type, exc, tb): - self.close() + async def __aexit__(self, exc_type, exc, tb): + await self.close() def settimeout(self, timeout): self.sock.settimeout(timeout) From 864b1c7ce408cacbca5aea03acf843926fc8bc7a Mon Sep 17 00:00:00 2001 From: Robert Blackhart Date: Tue, 3 Jun 2025 15:07:15 -0400 Subject: [PATCH 3/3] Update uaiowebsockets/protocol.py Co-authored-by: Timothy Ellis <3098078+TimAEllis@users.noreply.github.com> --- uaiowebsockets/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uaiowebsockets/protocol.py b/uaiowebsockets/protocol.py index 4880fc1..59e1b25 100644 --- a/uaiowebsockets/protocol.py +++ b/uaiowebsockets/protocol.py @@ -153,7 +153,7 @@ async def write_frame(self, opcode, data=b''): elif length < (1 << 16): # Length fits in 2-bytes byte2 |= 126 # Magic code - await self.srwriter.awrite(struct.pack('!BBH', byte1, byte2, length)) + await self.swriter.awrite(struct.pack('!BBH', byte1, byte2, length)) elif length < (1 << 64): byte2 |= 127 # Magic code