選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

pico.py 3.8KB

5年前
5年前
5年前
5年前
5年前
5年前
5年前
5年前
5年前
5年前
5年前
5年前
5年前
5年前
5年前
5年前
5年前
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import asyncio
  2. import collections
  3. import datetime
  4. import http
  5. import json
  6. from functools import partial
  7. from pathlib import Path
  8. import websockets
  # Registry of chat rooms, keyed by the websocket request path. Each room is
  # a dict mapping a logged-in username to that user's websocket; a room is
  # created on first lookup thanks to the defaultdict.
  rooms: 'collections.defaultdict[str, dict]' = collections.defaultdict(dict)
  class PicoProtocol(websockets.WebSocketServerProtocol):
      """WebSocket server protocol that also serves static files over plain HTTP.

      Requests that are not websocket upgrades are answered by ``serve_file``;
      upgrade requests continue through the normal websockets handshake.
      """

      # Content types for served files, keyed by lowercase file suffix.
      _CONTENT_TYPES = {
          '.html': 'text/html; charset=utf-8',
          '.js': 'application/javascript; charset=utf-8',
          '.css': 'text/css; charset=utf-8',
          '.svg': 'image/svg+xml; charset=utf-8',
      }
      _DEFAULT_CONTENT_TYPE = 'text/plain; charset=utf-8'

      def serve_file(self, path):
          """Build an HTTP response for the static file named by *path*.

          Only the final component of *path* (query string removed) is used,
          and it is looked up in the directory containing this module, so a
          request cannot reach outside that directory. Names that are not an
          existing file fall back to ``pico.html``.

          Returns:
              ``(status, headers, body)`` — the early-response tuple returned
              from ``process_request``.
          """
          # NOTE(review): any file next to this module is served, including
          # this source file itself — confirm that is intended.
          base_dir = Path(__file__).resolve().parent
          # Drop a query string such as "?v=2" so cache-busted asset URLs
          # still resolve to the file on disk.
          request_path = path.split('?', 1)[0]
          document = base_dir / Path(request_path).name
          if not document.is_file():
              document = base_dir / 'pico.html'
          # Derive the type from the file actually served (not the request
          # path), so e.g. a direct request for /pico.html is sent as HTML.
          content_type = self._CONTENT_TYPES.get(
              document.suffix.lower(), self._DEFAULT_CONTENT_TYPE
          )
          return (
              http.HTTPStatus.OK,
              [('Content-Type', content_type)],
              document.read_bytes(),
          )

      async def process_request(self, path, request_headers):
          """Serve a static file for plain HTTP; defer upgrades to websockets."""
          # The Upgrade token is compared case-insensitively (RFC 6455 §4.2.1),
          # so clients sending e.g. "WebSocket" still get a websocket.
          upgrade = request_headers.get('Upgrade') or ''
          if upgrade.lower() != 'websocket':
              return self.serve_file(path)
          return await super().process_request(path, request_headers)
  async def send_json_many(targets, **data):
      """Send one JSON payload built from *data* to each websocket in *targets*.

      Sends happen one after another. *targets* is copied first because
      callers pass live views (e.g. a room's ``values()``) that may change
      while each send is awaited.
      """
      recipients = list(targets)
      for recipient in recipients:
          await send_json(recipient, **data)
  async def send_json(websocket, **data):
      """Serialize the keyword arguments as a JSON object and send it.

      Delivery is best-effort: if the connection is already closed the
      message is silently dropped.
      """
      payload = json.dumps(data)
      try:
          await websocket.send(payload)
      except websockets.exceptions.ConnectionClosed:
          # Peer went away; nothing useful to do with the message.
          return
  async def recv_json(websocket):
      """Receive one message from *websocket* and decode it as a JSON object.

      Returns:
          The decoded dict. ``{'kind': 'logout'}`` if the connection closed,
          so callers treat a disconnect like an explicit logout. ``{}`` if the
          message is not valid JSON (including undecodable bytes) or is valid
          JSON that is not an object, so callers only ever see a dict.
      """
      try:
          raw = await websocket.recv()
      except websockets.exceptions.ConnectionClosed:
          return {'kind': 'logout'}
      try:
          data = json.loads(raw)
      except ValueError:
          # JSONDecodeError and UnicodeDecodeError (bad bytes) are both
          # ValueError subclasses.
          return {}
      # Arrays, numbers and strings are valid JSON but not messages; passing
      # them on would crash the caller's dict lookups.
      return data if isinstance(data, dict) else {}
  async def core(ws, path, server_name):
      """Run one client's chat session in the room named by *path*.

      Messages are JSON objects read with ``recv_json``:

      * ``{"kind": "login", "value": name}`` registers *name* in the room. The
        client gets the login echoed back with the ``online`` list, other
        members get a ``post`` announcing the join, and the client gets a
        welcome ``post``. Empty, non-string or already-taken names end the
        session with an ``error``.
      * ``{"kind": "logout"}`` unregisters the user, announces it to the room
        and ends the session.
      * Any other kind, once logged in, is relayed to the room with a
        ``source`` field; if ``target`` is present it only goes to the sender
        and the member with that name.
      * A message without ``kind`` gets an ``error`` reply and is skipped.

      Every outgoing message carries a UTC ``ts`` string ending in ``Z``.
      Whatever way the session ends, any names still bound to *ws* are
      removed from the room.

      Args:
          ws: The client's websocket connection.
          path: Request path; used as the room name.
          server_name: Not used by this function; bound by ``start_server``.
      """
      room = rooms[path]
      # Live views: they always reflect the room's current members.
      usernames = room.keys()
      sockets = room.values()
      username = None
      try:
          while True:
              data = await recv_json(ws)
              # Real UTC time, so the 'Z' suffix is truthful.
              ts = datetime.datetime.now(datetime.timezone.utc).isoformat().replace('+00:00', 'Z')
              # Validate before reading data['kind'] so a malformed message
              # gets an error reply instead of crashing the handler.
              if not isinstance(data, dict) or 'kind' not in data:
                  await send_json_many(
                      [ws], kind='error', value='Message without kind is invalid', ts=ts
                  )
                  continue
              kind = data['kind']
              emit = partial(send_json_many, kind=kind, value=data.get('value'), ts=ts)
              broadcast = partial(emit, targets=sockets)
              reply = partial(emit, targets=[ws])
              error = partial(reply, kind='error')
              if kind == 'login':
                  username = data.get('value')
                  # Non-strings (lists, numbers) would be unhashable or odd room keys.
                  if not isinstance(username, str) or not username:
                      await error(value='Username not allowed')
                      break
                  if username in usernames:
                      await error(value='Username taken')
                      break
                  others = list(sockets)  # members before this user joins
                  room[username] = ws
                  online = list(usernames)
                  await reply(online=online)
                  await broadcast(kind='post', value=f'{username} joined', online=online, targets=others)
                  await reply(kind='post', value=f'Welcome to {path}')
              elif username not in room:
                  await error(value='Login required')
                  break
              elif kind == 'logout':
                  del room[username]
                  online = list(usernames)
                  await broadcast(kind='post', value=f'{username} left', online=online)
                  break
              elif 'target' in data:
                  target = data['target']
                  # Compare with == instead of set membership so an unhashable
                  # target cannot raise TypeError.
                  targets = {
                      sock for name, sock in room.items()
                      if name == username or name == target
                  }
                  await broadcast(source=username, targets=targets)
              else:
                  await broadcast(source=username)
      finally:
          # Unregister every name still bound to this socket. Without this, a
          # crash, a cancelled handler, or a second login under another name
          # would leave those names permanently "taken". Other members are not
          # notified here.
          for name in [name for name, sock in room.items() if sock is ws]:
              del room[name]
  async def start_server(host, port, server_name):
      """Start listening on *host*:*port* and return the websockets server.

      Each connection is handled by ``core`` with *server_name* bound, and
      connections use ``PicoProtocol`` so plain HTTP requests get static files.
      """
      handler = partial(core, server_name=server_name)
      server = await websockets.serve(
          handler,
          host,
          port,
          create_protocol=PicoProtocol,
      )
      return server
  if __name__ == '__main__':
      # Script entry point: start the server on a fixed local address and
      # serve until interrupted.
      host, port = 'localhost', 9753
      # asyncio.get_event_loop() without a running loop is deprecated
      # (Python 3.10+); create and install the loop explicitly instead.
      loop = asyncio.new_event_loop()
      asyncio.set_event_loop(loop)
      loop.run_until_complete(start_server(host, port, 'PicoChat'))
      print(f'Running on {host}:{port}')
      loop.run_forever()