|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147 |
- import asyncio
- import collections
- import datetime
- import http
- import json
- import uuid
- from functools import partial
- from pathlib import Path
- from urllib.parse import urlsplit, urlunsplit
-
- import websockets
-
-
# Chat rooms, created on first access: maps a room key (the request path
# with its query string stripped, as computed in ``handle``) to a dict of
# username -> websocket for the users logged in to that room.
rooms = collections.defaultdict(dict)
-
-
def ignore_querystring(path):
    """Return *path* with its query string and fragment removed.

    >>> ignore_querystring('/abcd?x=1#top')
    '/abcd'
    """
    parts = urlsplit(path)
    return parts._replace(query='', fragment='').geturl()
-
-
class PicoProtocol(websockets.WebSocketServerProtocol):
    """Websocket protocol that also serves the chat's static files over HTTP.

    ``process_request`` runs before the websocket handshake. It redirects
    ``/`` to a random room and redirects mixed-case paths to their lowercase
    form. Every request that is not a websocket upgrade is answered with a
    file from this module's directory.
    """

    # Content-Type by file suffix; unknown suffixes are served as plain text.
    _CONTENT_TYPES = {
        '.html': 'text/html; charset=utf-8',
        '.js': 'application/javascript; charset=utf-8',
        '.css': 'text/css; charset=utf-8',
        '.svg': 'image/svg+xml; charset=utf-8',
    }
    _DEFAULT_CONTENT_TYPE = 'text/plain; charset=utf-8'

    def serve_file(self, path):
        """Return an HTTP 200 response tuple for the file named by *path*.

        *path* is the raw request target, and its query string is ignored.
        Only regular files inside this module's directory are served.
        Anything else falls back to ``pico.html``: missing files,
        directories, and paths that escape the directory via ``..`` or
        symlinks.

        Returns a ``(status, headers, body)`` tuple.
        """
        home = Path(__file__).parent.resolve()
        fallback = home.joinpath('pico.html')
        try:
            # Drop the query string so e.g. '/app.js?v=2' finds 'app.js'.
            requested = urlsplit(path).path.lstrip('/')
            document = home.joinpath(requested).resolve()
        except (OSError, ValueError):
            document = fallback
        # The path is client-controlled: refuse anything that resolves
        # outside *home* (e.g. '/../secret'), not only missing files.
        if home not in document.parents or not document.is_file():
            document = fallback

        content_type = self._CONTENT_TYPES.get(
            document.suffix, self._DEFAULT_CONTENT_TYPE
        )
        return (
            http.HTTPStatus.OK,
            [('Content-Type', content_type)],
            document.read_bytes(),
        )

    def redirect_to(self, path):
        """Return a ``302 Found`` response tuple pointing at *path*."""
        return (
            http.HTTPStatus.FOUND,
            [('Location', path)],
            b'',
        )

    async def process_request(self, path, request_headers):
        """Intercept the request before the websocket handshake.

        Returns an HTTP response tuple for redirects and plain HTTP
        requests. Otherwise it defers to the base class, which lets the
        handshake proceed.
        """
        if path == '/':
            # Short random room name taken from the second group of a UUID4.
            random_path = str(uuid.uuid4()).split('-')[1]
            return self.redirect_to('/' + random_path)
        if path.lower() != path:
            # Use an absolute target. A relative one resolves against the
            # current directory, so '/A/B' would redirect to '/A/a/b',
            # then '/A/a/a/b', and so on forever.
            return self.redirect_to(path.lower())
        # RFC 6455: the Upgrade value is compared ASCII case-insensitively.
        if (request_headers.get('Upgrade') or '').lower() != 'websocket':
            return self.serve_file(path)
        return await super().process_request(path, request_headers)
-
-
async def send_json_many(targets, **data):
    """Send the same JSON message *data* to every websocket in *targets*.

    *targets* is copied before sending because callers may pass a live
    dict view (a room's ``values()``) that can change while this coroutine
    is suspended in a send. Messages go out one at a time, in iteration
    order.
    """
    recipients = tuple(targets)
    for recipient in recipients:
        await send_json(recipient, **data)
-
-
async def send_json(websocket, **data):
    """Serialize the keyword arguments as a JSON object and send it.

    Best-effort: if *websocket* is already closed, the message is dropped
    silently instead of raising.
    """
    payload = json.dumps(data)
    try:
        await websocket.send(payload)
    except websockets.exceptions.ConnectionClosed:
        # One disconnected recipient must not abort a loop sending to others.
        pass
-
-
async def recv_json(websocket):
    """Receive one message from *websocket* and decode it as a JSON object.

    Returns ``{'kind': 'logout'}`` when the connection has closed, so the
    caller can treat a dropped connection like an explicit logout. Any
    message that is not a JSON *object* returns ``{}``. This covers
    malformed JSON, undecodable bytes, and valid JSON such as a list or a
    number.
    """
    try:
        message = await websocket.recv()
    except websockets.exceptions.ConnectionClosed:
        return {'kind': 'logout'}
    try:
        # ValueError covers JSONDecodeError and also the UnicodeDecodeError
        # that json.loads raises for a non-UTF-8 binary frame.
        data = json.loads(message)
    except ValueError:
        return {}
    # Callers use the result as a dict (``'kind' in data``, ``data.get``),
    # so a bare list/str/number must not leak through.
    return data if isinstance(data, dict) else {}
-
-
def _utc_timestamp():
    """Return the current UTC time in ISO-8601 format with a trailing 'Z'."""
    # Read the clock in UTC, then drop tzinfo so isoformat() emits no
    # '+00:00' offset before the 'Z' is appended.
    now = datetime.datetime.now(datetime.timezone.utc)
    return now.replace(tzinfo=None).isoformat() + 'Z'


async def _leave_room(room_key, username, ws):
    """Unregister *username* if it is still bound to *ws*; drop an empty room."""
    room = rooms.get(room_key)
    if room is None:
        return
    if username is not None and room.get(username) is ws:
        # Abnormal exit (exception or cancellation): tell the others.
        del room[username]
        await send_json_many(
            room.values(), kind='state', ts=_utc_timestamp(), online=list(room)
        )
    # Identity check: during the await above, the room may have been dropped
    # and a new one created under the same key.
    if not room and rooms.get(room_key) is room:
        del rooms[room_key]


async def handle(ws, path, server_name):
    """Run the chat protocol for one websocket connection.

    The room is selected by *path* with its query string removed. Each
    incoming message is expected to be a JSON object with a ``kind``:

    * ``login`` with a non-empty string ``value`` registers that username
      in the room, replies with ``username`` and broadcasts ``online``.
    * ``logout`` unregisters the user and broadcasts the new ``online``
      list. ``recv_json`` also produces this message when the connection
      closes.
    * Any other kind, once logged in, is relayed with ``source``, ``value``
      and ``ts``. It goes to the whole room, or only to the user named by
      ``target``.

    These cases get an ``error`` reply and then the connection loop ends:

    * an invalid username;
    * a username that is already taken;
    * a message sent before login.

    However the loop ends, the user is removed from the room and an empty
    room is dropped from ``rooms``. *server_name* is accepted for interface
    compatibility and is currently unused.
    """
    room_key = ignore_querystring(path)
    username = None  # set only while this socket is registered in the room

    try:
        while True:
            data = await recv_json(ws)
            if not isinstance(data, dict):
                data = {}  # valid JSON that is not an object
            ts = _utc_timestamp()
            # Look the room up after every receive: it is dropped from
            # ``rooms`` once empty and recreated here on demand.
            room = rooms[room_key]

            # Use .get so that a message without 'kind' reaches the check
            # below instead of raising KeyError here.
            emit = partial(send_json_many, kind=data.get('kind'), ts=ts)
            broadcast = partial(emit, targets=room.values())
            reply = partial(emit, targets=[ws])
            error = partial(reply, kind='error')

            if 'kind' not in data:
                await error(value='Message without kind is invalid')

            elif data['kind'] == 'login':
                if username is not None:
                    # Re-login would register this socket under two names.
                    await error(value='Already logged in')
                    continue
                candidate = data.get('value')
                # Require a non-empty str: other JSON values may be falsy or
                # unhashable, and an unhashable one would crash the lookup.
                if not isinstance(candidate, str) or not candidate:
                    await error(value='Username not allowed')
                    break
                if candidate in room:
                    await error(value='Username taken')
                    break
                # No await between the lookup above and this insert, so no
                # other handler can claim the name in between.
                username = candidate
                room[username] = ws
                await reply(kind='state', username=username)
                await broadcast(kind='state', online=list(room))

            elif username not in room:
                await error(value='Login required')
                break

            elif data['kind'] == 'logout':
                del room[username]
                username = None
                await broadcast(kind='state', online=list(room))
                break

            else:
                value = data.get('value')
                if 'target' in data:
                    target = data['target']
                    # Usernames are str, so a non-str target matches nobody.
                    recipient = room.get(target) if isinstance(target, str) else None
                    targets = [] if recipient is None else [recipient]
                    await broadcast(source=username, value=value, targets=targets)
                else:
                    await broadcast(source=username, value=value)
    finally:
        await _leave_room(room_key, username, ws)
-
-
async def start_server(host, port, server_name):
    """Start listening on *host*:*port* and return the websockets server.

    Every connection is handled by ``handle``, with *server_name* bound as
    a keyword argument. Connections use ``PicoProtocol``, so plain HTTP
    requests are answered as well.
    """
    connection_handler = partial(handle, server_name=server_name)
    server = await websockets.serve(
        connection_handler, host, port, create_protocol=PicoProtocol
    )
    return server
-
-
if __name__ == '__main__':
    host, port = 'localhost', 9753
    # asyncio.get_event_loop() is deprecated when no loop is running (and
    # slated to become an error), so create and install a loop explicitly.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(start_server(host, port, 'PicoChat'))
    print(f'Running on {host}:{port}')
    loop.run_forever()
|