Skip to content
Snippets Groups Projects
Commit aa009c94 authored by E. Madison Bray's avatar E. Madison Bray
Browse files

[refactoring][#1] further split the RenewalRecsystem base class into a new

JSONRPCServerWebsocketClient class that implements the asyncio loops

the long name is to make clear that this implements a websocket client,
which once connected to the websocket server provides a JSON-RPC server

everything in this code is almost completely abstracted from the details
of Renewal, allowing this code to be tested independently of the Renewal
details
parent cd145e02
No related branches found
No related tags found
No related merge requests found
import abc
import asyncio
import logging
import sys
import time
from functools import partial
from urllib.parse import splittype, urljoin
# Third-party modules
import coloredlogs
import objclick as click
import websockets
from jsonrpcserver import async_dispatch as dispatch
from jsonrpcserver.methods import Methods
# Local modules
from .server import JSONRPCServerWebsocketClient
from .utils import FileOrToken, format_rpc_call
class RenewalRecsystem(metaclass=abc.ABCMeta):
class RenewalRecsystem(JSONRPCServerWebsocketClient, metaclass=abc.ABCMeta):
NAME = abc.abstractproperty()
"""
Display name for this recsystem.
......@@ -51,42 +46,28 @@ class RenewalRecsystem(metaclass=abc.ABCMeta):
RECOMMEND_DEFAULT_LIMIT = 30
"""Max number of recommendations to return by default."""
RETRY_RATE = 5 # seconds
"""
Rate in seconds to retry connecting to the API when connection fails or
is dropped.
"""
_initialized = False
EVENT_STREAM_ENDPOINT = '/event_stream'
"""This is the endpoint on the Renewal API to connect to the websocket."""
def __init__(self, api_base_uri=None, token=None, log=None):
if api_base_uri is not None:
if api_base_uri[-1] != '/':
# Add trailing slash to make it easier to join URL fragments
# with urljoin()
api_base_uri += '/'
self.api_base_uri = api_base_uri
else:
# set it to the default
self.api_base_uri = self.RENEWAL_API_BASE_URI
self.token = token
if log is None:
self.log = logging.getLogger(self.NAME)
else:
self.log = log
if self.api_base_uri[-1] != '/':
# Add trailing slash to make it easier to join URL fragments
# with urljoin()
self.api_base_uri += '/'
methods = []
for method_name in self.RPC_METHODS:
method = getattr(self, method_name, None)
if not method:
# TODO: Maybe also validate the signature of each method
self.log.warning(
f'required RPC method {method_name} not implemented')
methods.append(method)
proto, uri = splittype(self.api_base_uri)
ws_proto = 'wss' if proto == 'https' else 'ws'
websocket_uri = urljoin(ws_proto + ':' + uri,
self.EVENT_STREAM_ENDPOINT.lstrip('/'))
self.methods = Methods(*methods)
self.token = token
super().__init__(websocket_uri)
@property
def client_headers(self):
......@@ -103,17 +84,6 @@ class RenewalRecsystem(metaclass=abc.ABCMeta):
else:
return {'Authorization': 'Bearer ' + self.token}
async def initialize_if_needed(self):
if self._initialized:
return
try:
await self.initialize()
except:
raise
else:
self._initialized = True
@abc.abstractmethod
async def initialize(self):
"""
......@@ -188,100 +158,6 @@ class RenewalRecsystem(metaclass=abc.ABCMeta):
See also `assigned_user`.
"""
########## websocket server loops ##########
async def request_loop(self):
"""
Main loop of the recsystem application.
Connects to the event stream websocket and starts a loop to receive and
handle events from the backend.
"""
self.log.info(f'initializing websocket connection to event stream')
uri = urljoin('ws:' + splittype(self.api_base_uri)[1], 'event_stream')
websocket_connect = partial(websockets.connect, uri,
extra_headers=self.client_headers)
async with websocket_connect() as websocket:
self.log.info(f'listening to websocket for events...')
# Incoming RPC requests are added to this queue, and their results are
# popped off the queue and sent; the queue is used as a means of
# serializing responses, otherwise we could have multiple coroutines
# concurrently trying to write to the same websocket
queue = asyncio.Queue()
# Start the incoming and outgoing message handlers; a slight variant of
# this pattern:
# https://websockets.readthedocs.io/en/stable/intro.html#both
await self.multiplex_tasks(self.handle_incoming(websocket, queue),
self.handle_outgoing(websocket, queue))
@staticmethod
async def multiplex_tasks(*tasks):
"""
Run multiple coroutines simultaneously as tasks, exiting as soon as any
one of them raises an exception.
The exception from the coroutine is then re-raised.
"""
done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_EXCEPTION)
try:
for task in done:
# If one of the tasks exited with an exception
# Calling .result() re-raises that exception
task.result()
finally:
for task in pending:
task.cancel()
async def dispatch_incoming(self, queue, request):
"""
Dispatch incoming messages to the JSON-RPC method dispatcher.
When the result is ready it is placed on the outgoing queue.
"""
response = await dispatch(request, methods=self.methods)
self.log.info(format_rpc_call(request, response))
await queue.put(response)
async def handle_incoming(self, websocket, queue):
"""
This coroutine checks the websocket for incoming JSON-RPC requests and
passes them to `dispatch_incoming`.
"""
while True:
request = await websocket.recv()
future = asyncio.ensure_future(
self.dispatch_incoming(queue, request))
def callback(future):
try:
future.result()
except Exception as exc:
self.log.exception(
f'unhandled exception dispatching request: '
f'{request}; this indicates an error in the RPC '
f'method implementation')
future.add_done_callback(callback)
async def handle_outgoing(self, websocket, queue):
"""
This coroutine checks the outgoing response queue for results from
dispatched RPC methods, and sends them on the websocket.
"""
while True:
response = await queue.get()
if response.wanted:
await websocket.send(str(response))
def run(self):
"""
Run the main event loop for the recommendation system.
......@@ -292,35 +168,7 @@ class RenewalRecsystem(metaclass=abc.ABCMeta):
self.log.info(
f'starting up {self.NAME} recsystem on {self.api_base_uri}')
loop = asyncio.get_event_loop()
try:
while True:
try:
loop.run_until_complete(self.initialize_if_needed())
loop.run_until_complete(self.request_loop())
except (websockets.WebSocketException, ConnectionRefusedError):
self.log.warning(
'lost connection to the backend; trying to '
're-establish...')
time.sleep(self.RETRY_RATE)
except KeyboardInterrupt:
return
finally:
# Cancel all pending tasks
for task in asyncio.Task.all_tasks(loop=loop):
task.cancel()
try:
# Give the task a chance to finish up
loop.run_until_complete(task)
except Exception:
# This may result in a CancelledError or other miscellaneous
# exceptions as connections are shut down, but we are exiting
# anyways so ignore them.
pass
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
super().run()
@click.classcommand()
@click.option('-a', '--api-base-uri', default=RENEWAL_API_BASE_URI,
......
import abc
import asyncio
import logging
import time
from functools import partial
# Third-party modules
import websockets
from jsonrpcserver import async_dispatch as dispatch
from jsonrpcserver.methods import Methods
# Local modules
from .utils import format_rpc_call
class JSONRPCServerWebsocketClient(metaclass=abc.ABCMeta):
"""
Server code for running a websocket server that accepts and responds to
JSON-RPC requests made by a remote server that is connected to via a
websocket.
It can handle multiple concurrent requests and response to them out of
order in which they were received.
To be clear, this is a *client* of the websocket protocol (it connects to
a web server that hosts websocket connections) but is a *server* of the
JSON-RPC protocol (it response to RPC requests sent over the websocket
connection).
This class is intended to be subclassed:
* The names of RPC methods implemented by this server must be provided in
the `RPC_METHODS` list, and the subclass should have methods of the same
names.
* An optional `initialize` method may be implemented to perform any tasks
prior to connecting to the websocket.
* An optional `client_headers` property/attribute may be used to pass
additional HTTP headers (e.g. for authorization) when making the
websocket connection.
Parameters
----------
uri : str
Full URI of the websocket to connect to.
"""
RPC_METHODS = abc.abstractproperty()
"""List of method names that implement RPC methods."""
RETRY_RATE = 5 # seconds
"""
Rate in seconds to retry connecting to the API when connection fails or
is dropped.
"""
_initialized = False
def __init__(self, websocket_uri, log=None):
self.websocket_uri = websocket_uri
if log is None:
self.log = logging.getLogger(self.NAME)
else:
self.log = log
methods = []
for method_name in self.RPC_METHODS:
method = getattr(self, method_name, None)
if not method:
# TODO: Maybe also validate the signature of each method
self.log.warning(
f'required RPC method {method_name} not implemented')
methods.append(method)
self.methods = Methods(*methods)
@property
def client_headers(self):
"""
Returns the HTTP headers to be passed when opening the websocket
connection.
"""
return {}
async def initialize_if_needed(self):
if self._initialized:
return
try:
await self.initialize()
except:
raise
else:
self._initialized = True
async def initialize(self):
"""
Start-up tasks to perform before starting the main websocket client
loop.
"""
pass
async def request_loop(self):
"""
Main loop of the recsystem application.
Connects to the event stream websocket and starts a loop to receive and
handle events from the backend.
"""
self.log.info(
f'initializing websocket connection to {self.websocket_uri}')
websocket_connect = partial(websockets.connect, self.websocket_uri,
extra_headers=self.client_headers)
async with websocket_connect() as websocket:
self.log.info(f'listening to websocket for RPC requests...')
# Incoming RPC requests are added to this queue, and their results are
# popped off the queue and sent; the queue is used as a means of
# serializing responses, otherwise we could have multiple coroutines
# concurrently trying to write to the same websocket
queue = asyncio.Queue()
# Start the incoming and outgoing message handlers; a slight variant of
# this pattern:
# https://websockets.readthedocs.io/en/stable/intro.html#both
await self.multiplex_tasks(self.handle_incoming(websocket, queue),
self.handle_outgoing(websocket, queue))
@staticmethod
async def multiplex_tasks(*tasks):
"""
Run multiple coroutines simultaneously as tasks, exiting as soon as any
one of them raises an exception.
The exception from the coroutine is then re-raised.
"""
done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_EXCEPTION)
try:
for task in done:
# If one of the tasks exited with an exception
# Calling .result() re-raises that exception
task.result()
finally:
for task in pending:
task.cancel()
async def dispatch_incoming(self, queue, request):
"""
Dispatch incoming messages to the JSON-RPC method dispatcher.
When the result is ready it is placed on the outgoing queue.
"""
response = await dispatch(request, methods=self.methods)
self.log.info(format_rpc_call(request, response))
await queue.put(response)
async def handle_incoming(self, websocket, queue):
"""
This coroutine checks the websocket for incoming JSON-RPC requests and
passes them to `dispatch_incoming`.
"""
while True:
request = await websocket.recv()
future = asyncio.ensure_future(
self.dispatch_incoming(queue, request))
def callback(future):
try:
future.result()
except Exception as exc:
self.log.exception(
f'unhandled exception dispatching request: '
f'{request}; this indicates an error in the RPC '
f'method implementation')
future.add_done_callback(callback)
async def handle_outgoing(self, websocket, queue):
"""
This coroutine checks the outgoing response queue for results from
dispatched RPC methods, and sends them on the websocket.
"""
while True:
response = await queue.get()
if response.wanted:
await websocket.send(str(response))
def run(self):
"""
Run the main event loop for the recommendation system.
Starts by calling `initialize`, and once that is complete starts up
the websocket connection and runs forever.
"""
loop = asyncio.get_event_loop()
try:
while True:
try:
loop.run_until_complete(self.initialize_if_needed())
loop.run_until_complete(self.request_loop())
except (websockets.WebSocketException, ConnectionRefusedError):
self.log.warning(
'lost connection to the websocket server; trying to '
're-establish...')
time.sleep(self.RETRY_RATE)
except KeyboardInterrupt:
return
finally:
self.log.info('shutting down cleanly...')
# Cancel all pending tasks
for task in asyncio.Task.all_tasks(loop=loop):
task.cancel()
try:
# Give the task a chance to finish up
loop.run_until_complete(task)
except Exception:
# This may result in a CancelledError or other miscellaneous
# exceptions as connections are shut down, but we are exiting
# anyways so ignore them.
pass
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
self.log.info('done')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment