diff --git a/HeliosBackend/websockets.py b/HeliosBackend/websockets.py index e0a5203..05f30e0 100644 --- a/HeliosBackend/websockets.py +++ b/HeliosBackend/websockets.py @@ -1,25 +1,65 @@ import asyncio import os +from typing import Iterable, TypedDict, Optional import aiohttp import json import websockets +from websockets import broadcast from .events import ee +class PayloadType(TypedDict): + type: str + data: dict + token: Optional[str] + +class Payload: + def __init__(self, _type: str, data: dict, token: Optional[str] = None): + self.type = _type + self.data = data + self.token = token + + def to_dict(self) -> PayloadType: + return { + 'type': self.type, + 'data': self.data, + 'token': self.token, + } + + def to_json(self) -> str: + return json.dumps(self.to_dict()) + + @classmethod + def from_dict(cls, data: PayloadType): + return cls(data['type'], data['data'], data.get('token')) + + @classmethod + def from_json(cls, data: str): + return cls.from_dict(json.loads(data)) + class WebSocketConnection: def __init__(self, websocket: websockets.ServerConnection): self.websocket = websocket self.code = None self.token = None - async def send(self, data): + async def send(self, _type: str, data: dict): + await self.send_payload(Payload(_type, data)) + + async def send_payload(self, payload: Payload): + await self.send_raw(payload.to_json()) + + async def send_raw(self, data): await self.websocket.send(data) async def recv(self): return await self.websocket.recv() + async def handler(self, data: Payload): + await ee.emit_async(data.type, self, data) + class WebSocketServer: def __init__(self, port): @@ -32,8 +72,15 @@ class WebSocketServer: self.connections.add(websocket) try: connection = WebSocketConnection(websocket) - if await self.authorization(connection): - await ee.emit_async('websocket_authorized', connection) + if not await self.authorization(connection): + return + await ee.emit_async('websocket_authorized', connection) + while True: + msg = await connection.recv() + if msg is None: + break + msg = Payload.from_json(msg) + await connection.handler(msg) finally: self.connections.remove(websocket) @@ -52,9 +99,14 @@ class WebSocketServer: return False websocket.token = token - await websocket.send(json.dumps({'type': 'authorization', 'status': 200, 'token': token})) + await websocket.send_raw(json.dumps({'type': 'authorization', 'status': 200, 'token': token})) return True + def broadcast(self, sockets: Iterable[WebSocketConnection], _type: str, data: dict): + connections = [socket.websocket for socket in sockets] + data = Payload(_type, data).to_json() + broadcast(connections, data) + async def start(self): async with websockets.serve(self.handler, "", self.port): await asyncio.Future()