import asyncio
import collections
import datetime
import http
import json
from functools import partial
from pathlib import Path

import websockets

# Chat rooms keyed by request path; each room maps username -> websocket.
# Rooms are created on first access and discarded again once they are empty.
rooms = collections.defaultdict(dict)

# Content types for static files served over plain HTTP, keyed by suffix.
_CONTENT_TYPES = {
    '.html': 'text/html; charset=utf-8',
    '.js': 'application/javascript; charset=utf-8',
    '.css': 'text/css; charset=utf-8',
}
_DEFAULT_CONTENT_TYPE = 'text/plain; charset=utf-8'
# Page returned for any path that does not name an existing sibling file.
_FALLBACK_PAGE = 'pico.html'


class PicoProtocol(websockets.WebSocketServerProtocol):
    """WebSocket protocol that also serves static files to plain HTTP requests.

    NOTE(review): built on the legacy ``websockets`` server API
    (``WebSocketServerProtocol`` / ``process_request(path, headers)``), which
    newer releases deprecate -- confirm against the pinned version.
    """

    def serve_file(self, path):
        """Return an HTTP ``(status, headers, body)`` tuple for *path*.

        Only the last path component (with any query string removed) is used,
        so a request cannot reach outside this module's directory. Names that
        do not refer to an existing file fall back to ``pico.html``.

        NOTE(review): every file in this directory is servable, including the
        server's own source; consider an explicit allow-list.
        """
        here = Path(__file__).resolve().parent
        name = Path(path.partition('?')[0]).name
        document = here / name
        if not name or not document.is_file():
            document = here / _FALLBACK_PAGE
        content_type = _CONTENT_TYPES.get(document.suffix, _DEFAULT_CONTENT_TYPE)
        return (
            http.HTTPStatus.OK,
            [('Content-Type', content_type)],
            document.read_bytes(),
        )

    async def process_request(self, path, request_headers):
        """Serve a static file unless the request asks for a WebSocket upgrade."""
        # The Upgrade value is case-insensitive ("WebSocket" is common).
        upgrade = request_headers.get('Upgrade', '')
        if upgrade.lower() != 'websocket':
            return self.serve_file(path)
        return await super().process_request(path, request_headers)


async def send_json_many(targets, **data):
    """Send *data* as JSON to every websocket in *targets*, one after another."""
    # Snapshot first: *targets* may be a live view of a room that changes
    # while we await each send.
    for websocket in list(targets):
        await send_json(websocket, **data)


async def send_json(websocket, **data):
    """Send *data* as a JSON object; sockets that have closed are skipped."""
    try:
        await websocket.send(json.dumps(data))
    except websockets.exceptions.ConnectionClosed:
        pass  # best-effort delivery: the peer is already gone


async def recv_json(websocket):
    """Receive one message and decode it as a JSON object.

    Returns ``{'kind': 'logout'}`` once the connection has closed, and ``{}``
    for payloads that are not valid JSON or not a JSON object.
    """
    try:
        message = await websocket.recv()
    except websockets.exceptions.ConnectionClosed:
        return {'kind': 'logout'}
    try:
        data = json.loads(message)
    except ValueError:  # covers JSONDecodeError and undecodable bytes frames
        return {}
    # Lists, numbers and strings would break the dict checks in ``core``.
    return data if isinstance(data, dict) else {}


def _utc_timestamp():
    """Return the current UTC time in ISO 8601 form with a trailing 'Z'."""
    now = datetime.datetime.now(datetime.timezone.utc)
    return now.replace(tzinfo=None).isoformat() + 'Z'


def _leave_room(path, username, ws):
    """Unregister *username* from room *path* if it maps to *ws*; drop an empty room."""
    room = rooms.get(path)
    if room is None:
        return
    # Identity check: never evict a different connection using the same name.
    if username is not None and room.get(username) is ws:
        del room[username]
    if not room:
        del rooms[path]


async def core(ws, path, server_name):
    """Run the chat protocol for one WebSocket connection.

    The request *path* names the room. Each incoming JSON object has a
    ``kind``:

    * ``login`` -- ``value`` is the requested username. Empty, non-string or
      already-taken names get an error and end the session. Otherwise the user
      joins, receives the online list and a welcome post, and the rest of the
      room is told they joined.
    * ``logout`` -- leave the room, announce it, and end the session.
    * anything else -- requires a prior login. The message is relayed with
      ``source`` set to the sender, either to the whole room or, when
      ``target`` is given, only to the sender and that user.

    Every outgoing message carries ``kind``, ``value`` and a UTC ``ts``.
    *server_name* is accepted for the handler signature but currently unused.
    """
    username = None
    try:
        while True:
            data = await recv_json(ws)
            # Look the room up afresh for every message: empty rooms are
            # discarded, and a later login must land in the same dict as every
            # other client on this path.
            room = rooms[path]
            usernames = room.keys()
            sockets = room.values()
            kind = data.get('kind')
            emit = partial(
                send_json_many,
                kind=kind,
                value=data.get('value'),
                ts=_utc_timestamp(),
            )
            broadcast = partial(emit, targets=sockets)
            reply = partial(emit, targets=[ws])
            error = partial(reply, kind='error')

            if 'kind' not in data:
                await error(value='Message without kind is invalid')
            elif kind == 'login':
                if username in room:
                    # A second login would register this socket under two names.
                    await error(value='Already logged in')
                    continue
                candidate = data.get('value')
                if not isinstance(candidate, str) or not candidate:
                    await error(value='Username not allowed')
                    break
                if candidate in usernames:
                    await error(value='Username taken')
                    break
                username = candidate
                others = list(sockets)  # everyone already present, excluding us
                room[username] = ws
                online = list(usernames)
                await reply(online=online)
                await broadcast(
                    kind='post',
                    value=f'{username} joined',
                    online=online,
                    targets=others,
                )
                await reply(kind='post', value=f'Welcome to {path}')
            elif username not in room:
                await error(value='Login required')
                break
            elif kind == 'logout':
                _leave_room(path, username, ws)
                online = list(usernames)
                await broadcast(kind='post', value=f'{username} left', online=online)
                break
            elif 'target' in data:
                # Private message: deliver only to the sender and the named user.
                # Tuple membership compares by equality, so an unhashable JSON
                # target cannot raise TypeError.
                target = data['target']
                recipients = {
                    sock for name, sock in room.items() if name in (username, target)
                }
                await broadcast(source=username, targets=recipients)
            else:
                await broadcast(source=username)
    finally:
        # Runs on every exit path (errors, cancellation, rejected logins) so a
        # dead connection never keeps its username reserved or its room alive.
        _leave_room(path, username, ws)


async def start_server(host, port, server_name):
    """Start the chat server on *host*:*port* and return the server object."""
    # server_name is forwarded to every connection handler.
    bound_core = partial(core, server_name=server_name)
    return await websockets.serve(bound_core, host, port, create_protocol=PicoProtocol)


if __name__ == '__main__':
    host, port = 'localhost', 9753
    # asyncio.get_event_loop() with no running loop is deprecated; create one.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(start_server(host, port, 'PicoChat'))
    print(f'Running on {host}:{port}')
    loop.run_forever()