"""PicoChat: a tiny room-based chat server over WebSockets.

Plain HTTP requests are answered with static files from this module's
directory; WebSocket upgrades are routed to ``handle``, which keeps one
``{username: websocket}`` mapping per request path (a "room").
"""

import asyncio
import collections
import datetime
import http
import json
import uuid
from functools import partial
from pathlib import Path

import websockets

# Maps request path -> {username: websocket}. A room dict is created
# implicitly (defaultdict) the first time a connection arrives on a path.
rooms = collections.defaultdict(dict)

# Content-Type for static files, keyed by exact file suffix; any other
# suffix is served as plain text.
_CONTENT_TYPES = {
    '.html': 'text/html; charset=utf-8',
    '.js': 'application/javascript; charset=utf-8',
    '.css': 'text/css; charset=utf-8',
    '.svg': 'image/svg+xml; charset=utf-8',
}
_DEFAULT_CONTENT_TYPE = 'text/plain; charset=utf-8'


class PicoProtocol(websockets.WebSocketServerProtocol):
    """Server protocol that also answers non-WebSocket HTTP requests."""

    def serve_file(self, path):
        """Return a ``(status, headers, body)`` 200 response for *path*.

        The query string is ignored. Any path that does not resolve to a
        regular file inside this module's directory -- including attempts to
        escape it with ``..`` -- falls back to ``pico.html``.
        """
        # NOTE(review): every regular file beside this module (including its
        # own source) is servable -- confirm that is intended.
        home = Path(__file__).resolve().parent
        fallback = home / 'pico.html'
        relative = path.partition('?')[0].lstrip('/')
        try:
            document = home.joinpath(relative).resolve()
            # relative_to raises ValueError if the resolved path left ``home``.
            document.relative_to(home)
            if not document.is_file():
                document = fallback
        except (OSError, ValueError):
            document = fallback
        content_type = _CONTENT_TYPES.get(document.suffix, _DEFAULT_CONTENT_TYPE)
        return (
            http.HTTPStatus.OK,
            [('Content-Type', content_type)],
            document.read_bytes(),
        )

    def random_redirect(self):
        """Return a 302 response redirecting to a random 4-hex-digit room."""
        new_path = str(uuid.uuid4()).split('-')[1]
        return (
            http.HTTPStatus.FOUND,
            [('Location', new_path)],
            b'',
        )

    async def process_request(self, path, request_headers):
        """Redirect ``/``, serve files for plain HTTP, else upgrade.

        Returning a response tuple short-circuits the WebSocket handshake;
        delegating to ``super()`` lets the handshake proceed.
        """
        if path == '/':
            return self.random_redirect()
        # The Upgrade value is an ASCII case-insensitive token (RFC 6455).
        if request_headers.get('Upgrade', '').lower() != 'websocket':
            return self.serve_file(path)
        return await super().process_request(path, request_headers)


async def send_json_many(targets, **data):
    """Send *data* as JSON to each socket in *targets*, one after another.

    *targets* is copied first because the underlying room dict may change
    while this coroutine awaits.
    """
    for websocket in list(targets):
        await send_json(websocket, **data)


async def send_json(websocket, **data):
    """Send *data* as a JSON text frame, ignoring already-closed sockets."""
    try:
        await websocket.send(json.dumps(data))
    except websockets.exceptions.ConnectionClosed:
        pass


async def recv_json(websocket):
    """Receive one message and decode it as a JSON object.

    Returns ``{'kind': 'logout'}`` when the connection closes, and ``{}`` for
    frames that are not valid JSON, not decodable, or not a JSON object.
    """
    try:
        message = await websocket.recv()
    except websockets.exceptions.ConnectionClosed:
        return {'kind': 'logout'}
    try:
        data = json.loads(message)
    except (ValueError, RecursionError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError (bytes
        # frames); RecursionError covers absurdly nested payloads.
        return {}
    return data if isinstance(data, dict) else {}


def _utc_timestamp():
    """Return the current UTC time in ISO 8601 form with a trailing ``Z``."""
    now = datetime.datetime.now(datetime.timezone.utc)
    # Drop tzinfo so isoformat() omits "+00:00" and the "Z" suffix is correct.
    return now.replace(tzinfo=None).isoformat() + 'Z'


async def handle(ws, path, server_name):
    """Run the chat protocol for one WebSocket connection in room *path*.

    Messages are JSON objects with a ``kind``. ``login`` claims a username,
    ``logout`` leaves the room, and any other kind is relayed to the room
    (or only to the sender and ``target`` when a ``target`` key is present).
    Every name this connection still holds is released when the handler
    exits, however it exits. *server_name* is accepted but currently unused.
    """
    room = rooms[path]
    usernames = room.keys()   # live views over the room dict
    sockets = room.values()
    username = None
    try:
        while True:
            data = await recv_json(ws)
            kind = data.get('kind')
            emit = partial(send_json_many, kind=kind, ts=_utc_timestamp())
            broadcast = partial(emit, targets=sockets)
            reply = partial(emit, targets=[ws])
            error = partial(reply, kind='error')
            if 'kind' not in data:
                await error(value='Message without kind is invalid')
            elif kind == 'login':
                username = data.get('value')
                if not username or not isinstance(username, str):
                    await error(value='Username not allowed')
                    break
                if username in usernames:
                    await error(value='Username taken')
                    break
                room[username] = ws
                online = list(usernames)
                await reply(kind='state', username=username)
                await broadcast(kind='state', online=online)
            elif username not in room:
                await error(value='Login required')
                break
            elif kind == 'logout':
                del room[username]
                online = list(usernames)
                await broadcast(kind='state', online=online)
                break
            else:
                value = data.get('value')
                if 'target' in data:
                    target = data['target']
                    # Only string names can match a room key; anything else
                    # (including unhashable JSON values) just echoes back.
                    names = (username, target) if isinstance(target, str) else (username,)
                    targets = {room[name] for name in names if name in room}
                    await broadcast(source=username, value=value, targets=targets)
                else:
                    await broadcast(source=username, value=value)
    finally:
        # Release names left behind by error exits or exceptions so they do
        # not stay "taken" forever, and refresh everyone's online list.
        stale = [name for name, sock in room.items() if sock is ws]
        for name in stale:
            del room[name]
        if stale:
            await send_json_many(
                sockets, kind='state', ts=_utc_timestamp(), online=list(usernames)
            )


async def start_server(host, port, server_name):
    """Start serving ``handle`` on *host*:*port* using ``PicoProtocol``."""
    bound_handle = partial(handle, server_name=server_name)
    bound_serve = partial(websockets.serve, create_protocol=PicoProtocol)
    return await bound_serve(bound_handle, host, port)


if __name__ == '__main__':
    host, port = 'localhost', 9753
    # get_event_loop() without a running loop is deprecated; create one.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(start_server(host, port, 'PicoChat'))
    print(f'Running on {host}:{port}')
    loop.run_forever()