Add Payload and handle connections

This commit is contained in:
Riley Winkler 2025-02-11 23:59:52 -06:00
parent ae1699ee9f
commit d18b1bf00d

View File

@ -1,25 +1,65 @@
import asyncio import asyncio
import os import os
from typing import Iterable, TypedDict, Optional
import aiohttp import aiohttp
import json import json
import websockets import websockets
from websockets import broadcast
from .events import ee from .events import ee
class PayloadType(TypedDict):
type: str
data: dict
token: Optional[str]
class Payload:
def __init__(self, _type: str, data: dict, token: Optional[str] = None):
self.type = _type
self.data = data
self.token = token
def to_dict(self) -> PayloadType:
return {
'type': self.type,
'data': self.data,
'token': self.token,
}
def to_json(self) -> str:
return json.dumps(self.to_dict())
@classmethod
def from_dict(cls, data: PayloadType):
return cls(data['type'], data['data'], data.get('token'))
@classmethod
def from_json(cls, data: str):
return cls.from_dict(json.loads(data))
class WebSocketConnection: class WebSocketConnection:
def __init__(self, websocket: websockets.ServerConnection): def __init__(self, websocket: websockets.ServerConnection):
self.websocket = websocket self.websocket = websocket
self.code = None self.code = None
self.token = None self.token = None
async def send(self, data): async def send(self, _type: str, data: dict):
await self.send_payload(Payload(_type, data))
async def send_payload(self, payload: Payload):
await self.send_raw(payload.to_json())
async def send_raw(self, data):
await self.websocket.send(data) await self.websocket.send(data)
async def recv(self): async def recv(self):
return await self.websocket.recv() return await self.websocket.recv()
async def handler(self, data: Payload):
await ee.emit_async(data.type, self, data)
class WebSocketServer: class WebSocketServer:
def __init__(self, port): def __init__(self, port):
@ -32,8 +72,15 @@ class WebSocketServer:
self.connections.add(websocket) self.connections.add(websocket)
try: try:
connection = WebSocketConnection(websocket) connection = WebSocketConnection(websocket)
if await self.authorization(connection): if not await self.authorization(connection):
return
await ee.emit_async('websocket_authorized', connection) await ee.emit_async('websocket_authorized', connection)
while True:
msg = await connection.recv()
if msg is None:
break
msg = Payload.from_json(msg)
await connection.handler(msg)
finally: finally:
self.connections.remove(websocket) self.connections.remove(websocket)
@ -52,9 +99,14 @@ class WebSocketServer:
return False return False
websocket.token = token websocket.token = token
await websocket.send(json.dumps({'type': 'authorization', 'status': 200, 'token': token})) await websocket.send_raw(json.dumps({'type': 'authorization', 'status': 200, 'token': token}))
return True return True
def broadcast(self, sockets: Iterable[WebSocketConnection], _type: str, data: dict):
connections = [socket.websocket for socket in sockets]
data = Payload(_type, data).to_json()
broadcast(connections, data)
async def start(self): async def start(self):
async with websockets.serve(self.handler, "", self.port): async with websockets.serve(self.handler, "", self.port):
await asyncio.Future() await asyncio.Future()