|
- import asyncio
- import collections
- import datetime
- import http
- import json
- import uuid
- from functools import partial
- from pathlib import Path
-
- import websockets
-
-
# Registry of chat rooms keyed by the websocket request path. Each room is a
# dict mapping username -> websocket; a missing path yields a new empty room.
rooms: "collections.defaultdict[str, dict]" = collections.defaultdict(dict)
-
-
class PicoProtocol(websockets.WebSocketServerProtocol):
    """WebSocket server protocol that also answers plain HTTP requests.

    ``/`` is redirected to a new random room path. Requests without an
    ``Upgrade: websocket`` header are served static files from the directory
    containing this module. Everything else continues with the normal
    WebSocket handshake.
    """

    # Content-Type by file suffix; unknown suffixes are sent as plain text.
    _CONTENT_TYPES = {
        '.html': 'text/html; charset=utf-8',
        '.js': 'application/javascript; charset=utf-8',
        '.css': 'text/css; charset=utf-8',
        '.svg': 'image/svg+xml; charset=utf-8',
    }
    _DEFAULT_CONTENT_TYPE = 'text/plain; charset=utf-8'
    # Served for any path that does not name a file inside the home directory.
    _FALLBACK_DOCUMENT = 'pico.html'

    def serve_file(self, path):
        """Return an HTTP 200 ``(status, headers, body)`` response for *path*.

        *path* is the raw request path; any query string is ignored. Paths
        that do not name a regular file inside the directory containing this
        module, including ``..`` traversal attempts, are answered with
        ``pico.html`` instead.
        """
        home = Path(__file__).parent.resolve()
        relative = path.split('?', 1)[0].lstrip('/')
        document = (home / relative).resolve()
        # resolve() collapses '..' segments and symlinks, so this containment
        # check stops requests such as '/../../etc/passwd' escaping *home*.
        if home not in document.parents or not document.is_file():
            document = home / self._FALLBACK_DOCUMENT

        content_type = self._CONTENT_TYPES.get(
            document.suffix, self._DEFAULT_CONTENT_TYPE
        )
        return (
            http.HTTPStatus.OK,
            [('Content-Type', content_type)],
            document.read_bytes(),
        )

    def random_redirect(self):
        """Return a 302 response that sends the client to a new random room.

        The room name is the first 12 hex digits of a UUID4, all of which are
        random (48 bits). That makes a fresh room very unlikely to collide
        with, or be guessed as, an existing one.
        """
        new_path = '/' + uuid.uuid4().hex[:12]
        return (
            http.HTTPStatus.FOUND,
            [('Location', new_path)],
            b'',
        )

    async def process_request(self, path, request_headers):
        """Route a request before the WebSocket handshake.

        Returns an HTTP response tuple for ``/`` (redirect) and for
        non-upgrade requests (static file). Otherwise defers to the base
        class so the handshake proceeds.
        """
        if path == '/':
            return self.random_redirect()
        # Upgrade tokens are case-insensitive (RFC 6455, section 4.2.1).
        if request_headers.get('Upgrade', '').lower() != 'websocket':
            return self.serve_file(path)
        return await super().process_request(path, request_headers)
-
-
async def send_json_many(targets, **data):
    """Send the same JSON message (built from *data*) to every socket in *targets*.

    *targets* is copied up front because callers may pass a live dict view
    (``room.values()``) that other handlers can mutate while we await.
    Sends run concurrently, so a peer that is slow to drain does not delay
    delivery to the others. Closed connections are skipped by ``send_json``.
    """
    recipients = list(targets)
    await asyncio.gather(*(send_json(websocket, **data) for websocket in recipients))
-
-
async def send_json(websocket, **data):
    """Serialize the keyword arguments as one JSON object and send it.

    Delivery is best-effort: if the peer has already disconnected, the
    ConnectionClosed error is swallowed and the message is dropped.
    """
    message = json.dumps(data)
    try:
        await websocket.send(message)
    except websockets.exceptions.ConnectionClosed:
        return
-
-
async def recv_json(websocket):
    """Receive one message from *websocket* and decode it as a JSON object.

    Returns the decoded dict. A closed connection is reported as
    ``{'kind': 'logout'}``, which the chat handler treats like an explicit
    logout. A frame that is not decodable JSON text, or that decodes to
    anything other than a JSON object, yields ``{}``. Callers can therefore
    always apply dict operations to the result.
    """
    try:
        raw = await websocket.recv()
    except websockets.exceptions.ConnectionClosed:
        return {'kind': 'logout'}

    try:
        data = json.loads(raw)
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError (undecodable
        # binary frames); both subclass ValueError.
        return {}
    # Valid JSON such as [1], 5 or null would break the caller's dict lookups.
    return data if isinstance(data, dict) else {}
-
-
def _utc_timestamp():
    """Return the current UTC time as an ISO-8601 string ending in ``Z``."""
    now = datetime.datetime.now(datetime.timezone.utc)
    # Drop tzinfo so isoformat() adds no '+00:00'; the 'Z' suffix marks UTC.
    return now.replace(tzinfo=None).isoformat() + 'Z'


async def _leave_room(path, username, ws):
    """Remove *username* (bound to *ws*) from room *path*; drop the room if empty.

    Remaining members receive a ``state`` message with the updated
    ``online`` list. Does nothing for a client that never logged in
    (*username* is None) or for a room that no longer exists.
    """
    room = rooms.get(path)
    if room is None:
        return
    if username is not None and room.get(username) is ws:
        del room[username]
        await send_json_many(
            room.values(), kind='state', ts=_utc_timestamp(), online=list(room)
        )
    # Re-check after the await: another handler may have joined this room or
    # already discarded it in the meantime.
    if not room and rooms.get(path) is room:
        del rooms[path]


async def handle(ws, path, server_name):
    """Run the chat protocol for one websocket connection in room *path*.

    A client must first send ``login`` with a non-empty string username.
    After that, ``logout`` ends the session. Any other message kind is
    relayed to the whole room, or only to the sender and ``target`` when a
    target is given. However the loop ends (logout, protocol error, closed
    connection or unexpected exception), the user is removed from the room
    and the remaining members are told who is still online.

    *server_name* is accepted for the caller's benefit but is not used here.
    """
    username = None  # set only once a login has succeeded
    try:
        while True:
            data = await recv_json(ws)
            if not isinstance(data, dict):
                data = {}  # defensive: only dicts are handled below
            kind = data.get('kind')
            ts = _utc_timestamp()
            # Fetch the room per message rather than caching it: _leave_room
            # may discard an empty room, and a cached dict would be orphaned.
            room = rooms[path]

            emit = partial(send_json_many, kind=kind, ts=ts)
            broadcast = partial(emit, targets=room.values())
            reply = partial(emit, targets=[ws])
            error = partial(reply, kind='error')

            if 'kind' not in data:
                await error(value='Message without kind is invalid')

            elif kind == 'login':
                if username is not None:
                    # A second login would otherwise leak the first username.
                    await error(value='Already logged in')
                    continue
                candidate = data.get('value')
                # Usernames must be non-empty strings (hashable, JSON-safe).
                if not candidate or not isinstance(candidate, str):
                    await error(value='Username not allowed')
                    break
                if candidate in room:
                    await error(value='Username taken')
                    break
                username = candidate
                room[username] = ws
                await reply(kind='state', username=username)
                await broadcast(kind='state', online=list(room))

            elif username is None:
                await error(value='Login required')
                break

            elif kind == 'logout':
                break  # removal and the 'state' broadcast happen in finally

            else:
                # NOTE(review): the client-chosen kind is relayed verbatim, so
                # a client can send e.g. 'state' or 'error' to its peers.
                value = data.get('value')
                if 'target' in data:
                    target = data['target']
                    # Compare with == instead of building a set, so an
                    # unhashable target cannot crash the handler.
                    targets = {
                        sock for name, sock in room.items()
                        if name == username or name == target
                    }
                    await broadcast(source=username, value=value, targets=targets)
                else:
                    await broadcast(source=username, value=value)
    finally:
        await _leave_room(path, username, ws)
-
-
async def start_server(host, port, server_name):
    """Start serving PicoChat on *host*:*port*.

    Each connection is driven by ``handle`` (with *server_name* bound),
    using ``PicoProtocol`` so plain HTTP requests get the static client.
    Returns whatever awaiting ``websockets.serve`` yields (the running server).
    """
    handler = partial(handle, server_name=server_name)
    server = await websockets.serve(
        handler, host, port, create_protocol=PicoProtocol
    )
    return server
-
-
if __name__ == '__main__':
    host, port = 'localhost', 9753
    # Create the loop explicitly: asyncio.get_event_loop() with no running
    # loop is deprecated since Python 3.10.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server = None
    try:
        server = loop.run_until_complete(start_server(host, port, 'PicoChat'))
        print(f'Running on {host}:{port}')
        loop.run_forever()
    except KeyboardInterrupt:
        pass  # Ctrl-C is the normal way to stop the server
    finally:
        if server is not None:
            # Close listening sockets and open connections before the loop.
            server.close()
            loop.run_until_complete(server.wait_closed())
        loop.close()
|