|
- import asyncio
- import collections
- import datetime
- import http
- import json
- import functools
- from pathlib import Path
-
- import websockets
-
-
# Chat rooms keyed by request path. Each value is the set of connections that
# have logged in to that room; defaultdict creates an empty room on first use.
rooms = collections.defaultdict(set)
-
-
class PicoProtocol(websockets.WebSocketServerProtocol):
    """WebSocket server protocol that also serves the chat's static files.

    Plain HTTP requests (no ``Upgrade: websocket``) are answered with a file
    from the directory containing this module. WebSocket upgrade requests
    continue through the normal handshake.
    """

    # Content-Type by file suffix; any other existing file is sent as text.
    _CONTENT_TYPES = {
        '.html': 'text/html; charset=utf-8',
        '.js': 'application/javascript; charset=utf-8',
        '.css': 'text/css; charset=utf-8',
    }
    _DEFAULT_CONTENT_TYPE = 'text/plain; charset=utf-8'

    def serve_file(self, path):
        """Build an HTTP response ``(status, headers, body)`` for *path*.

        Only the last component of *path* is used, so a request cannot reach
        outside this module's directory. When no such file exists,
        ``pico.html`` is served instead.
        """
        # NOTE(review): every file in this directory (including this module's
        # own source) can be fetched by name; consider an allow-list.
        static_dir = Path(__file__).resolve().parent
        document = static_dir / Path(path).name
        if not document.is_file():
            document = static_dir / 'pico.html'
        # Pick the type from the file actually served, so an explicit request
        # for pico.html gets text/html just like the fallback does.
        content_type = self._CONTENT_TYPES.get(document.suffix, self._DEFAULT_CONTENT_TYPE)
        return (
            http.HTTPStatus.OK,
            [('Content-Type', content_type)],
            document.read_bytes(),
        )

    async def process_request(self, path, request_headers):
        """Serve a static file for plain HTTP; let upgrade requests proceed.

        The ``Upgrade`` value is compared case-insensitively (RFC 6455
        section 4.2.1), so clients sending e.g. ``WebSocket`` still get the
        handshake.
        """
        upgrade = request_headers.get('Upgrade') or ''
        if upgrade.lower() != 'websocket':
            return self.serve_file(path)
        return await super().process_request(path, request_headers)
-
-
def get_usernames(sockets):
    """Collect the ``username`` of every connection in *sockets*, sorted ascending."""
    names = [connection.username for connection in sockets]
    names.sort()
    return names
-
-
async def send_json_many(targets, **data):
    """Send the same JSON message (built from *data*) to every socket in *targets*.

    *targets* is copied first, so the caller may keep mutating it. The sends
    run concurrently, so one recipient whose ``send()`` is waiting (for
    example on a full write buffer) does not hold up delivery to the others.
    Closed connections are skipped by ``send_json``.
    """
    await asyncio.gather(*(send_json(websocket, **data) for websocket in list(targets)))
-
-
async def send_json(websocket, **data):
    """Serialize the keyword arguments as one JSON object and send it on *websocket*.

    Best-effort: if the connection has already closed, the message is dropped.
    """
    message = json.dumps(data)
    try:
        await websocket.send(message)
    except websockets.exceptions.ConnectionClosed:
        # Peer is gone; dropping the message is intended.
        pass
-
-
async def recv_json(websocket):
    """Receive one message from *websocket* and decode it as a JSON object.

    Returns:
        The decoded dict. Returns ``{'action': 'logout'}`` if the connection
        is closed. Returns ``{}`` if the message is not valid JSON (including
        bytes that are not valid UTF-8) or decodes to something other than
        a JSON object.
    """
    try:
        message = await websocket.recv()
    except websockets.exceptions.ConnectionClosed:
        # A dropped connection is reported like an explicit logout.
        return {'action': 'logout'}
    try:
        data = json.loads(message)
    except ValueError:
        # Covers JSONDecodeError and UnicodeDecodeError (both ValueError).
        return {}
    # Callers index the result like a dict; reject arrays, strings, numbers
    # and null here instead of letting them crash the caller.
    return data if isinstance(data, dict) else {}
-
-
async def _leave_room(ws, sockets):
    """Remove *ws* from *sockets* and send the room an updated user list.

    Does nothing if *ws* never joined, so sessions that end before a
    successful login announce nothing.
    """
    if ws not in sockets:
        return
    # Snapshot before removing, so the departing socket is notified too
    # (as on an explicit logout); closed sockets are skipped by send_json.
    targets = set(sockets)
    sockets.discard(ws)
    ts = datetime.datetime.now().isoformat()
    await send_json_many(
        targets=targets,
        ts=ts,
        kind='update',
        users=get_usernames(sockets),
        info=f'{ws.username} left',
    )


async def core(ws, path, server_name):
    """Run the chat session for one WebSocket connection in room *path*.

    Messages from *ws* are processed until the client logs out, the
    connection closes, or a login is rejected. Supported actions:

    * ``login``: join the room as ``username``, which must be a non-empty
      string not already in use;
    * ``post``: send ``text`` to the whole room, or only to the sender and
      the user named ``target`` when that key is present;
    * ``logout``: leave the room and end the session.

    However the session ends (including by an exception), a socket that
    joined the room is removed from it and the remaining members receive an
    updated user list, so no stale username is left behind.

    Args:
        ws: the client's WebSocket connection.
        path: request path, used as the room name.
        server_name: currently unused.
    """
    sockets = rooms[path]
    try:
        while True:
            data = await recv_json(ws)
            if not isinstance(data, dict):
                # Non-object payloads carry no action; treat them as empty.
                data = {}
            ts = datetime.datetime.now().isoformat()
            reply = functools.partial(send_json, websocket=ws, ts=ts)
            error = functools.partial(reply, kind='error')
            broadcast = functools.partial(send_json_many, targets=set(sockets), ts=ts)

            if 'action' not in data:
                await error(info='Message without action is invalid')
                continue
            action = data['action']

            if action == 'login':
                if ws in sockets:
                    # Renaming in place would desync the room's user list.
                    await error(info='Already logged in')
                    continue
                username = data.get('username')
                # A non-string name would make get_usernames' sort fail for
                # the whole room.
                if not isinstance(username, str) or not username:
                    await error(info='Username not allowed')
                    break
                if username in get_usernames(sockets):
                    await error(info='Username taken')
                    break

                ws.username = username
                sockets.add(ws)
                online = get_usernames(sockets)
                await reply(kind='update', users=online, info=f'Welcome to {path}', username=ws.username)
                await broadcast(kind='update', users=online, info=f'{ws.username} joined')

            elif action == 'post':
                if ws not in sockets:
                    await error(info='Login required')
                    continue
                text = data.get('text')
                if not text:
                    continue
                if 'target' in data:
                    # Tuple membership compares with ==, so an unhashable
                    # target value cannot raise here.
                    wanted = (ws.username, data['target'])
                    targets = {member for member in sockets if member.username in wanted}
                    await send_json_many(ts=ts, targets=targets, kind='post', source=ws.username, text=text)
                else:
                    await broadcast(kind='post', source=ws.username, text=text)

            elif action == 'logout':
                break
    finally:
        await _leave_room(ws, sockets)
-
-
async def start_server(host, port, server_name):
    """Start listening on *host*:*port* and return the websockets server object.

    Each connection is handled by ``core`` with *server_name* bound in.
    ``PicoProtocol`` is used as the connection protocol so plain HTTP
    requests are served static files.
    """
    handler = functools.partial(core, server_name=server_name)
    server = await websockets.serve(handler, host, port, create_protocol=PicoProtocol)
    return server
-
-
if __name__ == '__main__':
    host, port = 'localhost', 9753
    # Calling asyncio.get_event_loop() with no current loop is deprecated (a
    # RuntimeError as of Python 3.14); create and install a loop explicitly.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(start_server(host, port, 'PicoChat'))
    print(f'Running on {host}:{port}')
    loop.run_forever()
|