"""PicoChat: a tiny chat server over WebSockets with one room per request path.

Plain HTTP requests are answered with static files from this script's
directory. WebSocket connections join the room named by their request path
and exchange JSON objects carrying an ``action`` key (``login``, ``post``,
``logout``).
"""
import asyncio
import collections
import datetime
import functools
import http
import json
from pathlib import Path

import websockets

# Logged-in connections per room, keyed by the WebSocket request path.
# NOTE: empty rooms are intentionally not deleted: a handler that has not
# logged in yet may already hold a reference to the set for its path.
rooms = collections.defaultdict(set)

# Static files the HTTP handler is willing to serve, by (lower-cased) suffix.
# Anything else, including this server's own source, gets the chat page.
STATIC_CONTENT_TYPES = {
    '.html': 'text/html; charset=utf-8',
    '.js': 'application/javascript; charset=utf-8',
    '.css': 'text/css; charset=utf-8',
}
FALLBACK_PAGE = 'pico.html'


class PicoProtocol(websockets.WebSocketServerProtocol):
    """Server protocol that also answers plain HTTP requests with static files."""

    def serve_file(self, path):
        """Build an HTTP response for the request path *path*.

        Only the last path component is used, so ``..`` cannot escape the
        script directory. Any query string is ignored. Files that do not exist
        or whose suffix is not in ``STATIC_CONTENT_TYPES`` are replaced by
        ``pico.html``.

        Returns:
            A ``(status, headers, body)`` tuple for the legacy
            ``process_request`` hook.
        """
        base_dir = Path(__file__).resolve().parent
        name = Path(path.partition('?')[0]).name
        suffix = Path(name).suffix.lower()
        document = base_dir / name
        if suffix not in STATIC_CONTENT_TYPES or not document.is_file():
            document = base_dir / FALLBACK_PAGE
            suffix = '.html'
        return (
            http.HTTPStatus.OK,
            [('Content-Type', STATIC_CONTENT_TYPES[suffix])],
            document.read_bytes(),
        )

    async def process_request(self, path, request_headers):
        """Serve a static file unless the client requests a WebSocket upgrade."""
        # The Upgrade token is ASCII case-insensitive (RFC 6455, section 4.2.1).
        upgrade = request_headers.get('Upgrade') or ''
        if upgrade.lower() != 'websocket':
            return self.serve_file(path)
        return await super().process_request(path, request_headers)


def get_usernames(sockets):
    """Return the sorted usernames of the given logged-in connections."""
    return sorted(ws.username for ws in sockets)


async def send_json_many(targets, **data):
    """Send *data* as a JSON object to every connection in *targets*, in turn."""
    # Iterate over a copy: the underlying set may change while we await sends.
    for websocket in list(targets):
        await send_json(websocket, **data)


async def send_json(websocket, **data):
    """Send *data* as a JSON object, skipping connections that are already closed."""
    try:
        await websocket.send(json.dumps(data))
    except websockets.exceptions.ConnectionClosed:
        pass


async def recv_json(websocket):
    """Receive one message and decode it as a JSON object.

    Returns:
        The decoded dict. ``{'action': 'logout'}`` if the connection closed.
        ``{}`` for anything that is not a JSON object: invalid JSON, a binary
        frame that is not valid UTF-8, arrays, numbers, and so on. Callers can
        therefore always treat the result as a dict.
    """
    try:
        data = json.loads(await websocket.recv())
    except websockets.exceptions.ConnectionClosed:
        return {'action': 'logout'}
    except ValueError:  # JSONDecodeError, or UnicodeDecodeError for bytes
        return {}
    return data if isinstance(data, dict) else {}


async def core(ws, path, server_name):
    """Run the chat session for one connection in the room named by *path*.

    *server_name* is bound by ``start_server`` but not used here. However the
    loop ends (logout, rejected login, closed socket or an exception), the
    connection is removed from its room. Remaining members are told it left
    only if it had actually joined.
    """
    sockets = rooms[path]
    try:
        while True:
            data = await recv_json(ws)
            ts = datetime.datetime.now().isoformat()
            reply = functools.partial(send_json, websocket=ws, ts=ts)
            error = functools.partial(reply, kind='error')
            # Snapshot taken before any membership change in this iteration.
            broadcast = functools.partial(send_json_many, targets=set(sockets), ts=ts)
            action = data.get('action')
            if 'action' not in data:
                await error(info='Message without action is invalid')
            elif action == 'login':
                if ws in sockets:
                    await error(info='Already logged in')
                    continue
                username = data.get('username')
                if not isinstance(username, str) or not username:
                    await error(info='Username not allowed')
                    break
                if username in get_usernames(sockets):
                    await error(info='Username taken')
                    break
                ws.username = username
                sockets.add(ws)
                online = get_usernames(sockets)
                await reply(kind='update', users=online,
                            info=f'Welcome to {path}', username=ws.username)
                await broadcast(kind='update', users=online,
                                info=f'{ws.username} joined')
            elif action == 'post':
                if ws not in sockets:
                    await error(info='Login required')
                    continue
                text = data.get('text')
                if not text:
                    continue
                if 'target' in data:
                    # Private message: deliver to the sender and the named user.
                    # A tuple test uses ==, so an unhashable target cannot raise.
                    target = data['target']
                    targets = {ss for ss in sockets
                               if ss.username in (ws.username, target)}
                    await send_json_many(ts=ts, targets=targets, kind='post',
                                         source=ws.username, text=text)
                else:
                    await broadcast(kind='post', source=ws.username, text=text)
            elif action == 'logout':
                break
            # Any other action is ignored, as before.
    finally:
        if ws in sockets:
            # Include the leaving connection in the notice, as the original
            # logout broadcast did; a closed socket is skipped by send_json.
            targets = set(sockets)
            sockets.discard(ws)
            await send_json_many(
                targets,
                ts=datetime.datetime.now().isoformat(),
                kind='update',
                users=get_usernames(sockets),
                info=f'{ws.username} left',
            )


async def start_server(host, port, server_name):
    """Start serving chat on *host*:*port* and return what ``websockets.serve`` yields."""
    bound_core = functools.partial(core, server_name=server_name)
    return await websockets.serve(bound_core, host, port, create_protocol=PicoProtocol)


if __name__ == '__main__':
    host, port = 'localhost', 9753
    # asyncio.get_event_loop() without a running loop is deprecated since
    # Python 3.10, so create and install the loop explicitly.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(start_server(host, port, 'PicoChat'))
    print(f'Running on {host}:{port}')
    loop.run_forever()