Skip to content
Snippets Groups Projects
Commit a09220d5 authored by E. Madison Bray's avatar E. Madison Bray
Browse files

[refactoring][#1] split the baseline recsystem in a RenewalRecsystem baseclass

the base class implements the low-level (e.g. websocket client) functionality,
a basic CLI, basic RPC methods (actually just 'ping' for the moment)

the BaselineRecsystem need only implement the initialize() method and the
remaining RPC methods, as well as slightly extend main() to add the --mode
option

contestants may either subclass RenewalRecsystem to start with bare functionality
and build up from scratch, though they may be able to get started quicker by
subclassing BaselineRecsystem and modifying it (especially the recommend() method)
with their algorithms
parent d224db97
Branches
Tags
No related merge requests found
# Python standard library modules # Python standard library modules
import asyncio
import logging
import random import random
import sys
import time
from collections import defaultdict from collections import defaultdict
from functools import partial from functools import partial
from urllib.parse import splittype, urljoin from urllib.parse import urljoin
# Third-party modules # Third-party modules
import aiohttp import aiohttp
import coloredlogs
import objclick as click import objclick as click
import websockets
from jsonrpcserver import async_dispatch as dispatch
from jsonrpcserver.methods import Methods
from jsonrpcserver.response import DictResponse
# Local modules
from .articles import ArticleCollection from .articles import ArticleCollection
from .utils import FileOrToken, format_rpc_call from .recsystem import RenewalRecsystem
from .utils import FileOrToken
class BaselineRecsystem: class BaselineRecsystem(RenewalRecsystem):
NAME = 'baseline' NAME = 'baseline'
ENVVAR_PREFIX = 'RENEWAL'
RENEWAL_API_BASE_URI = 'https://api.renewal-research.com/v1/'
RPC_METHODS = [
'article_interaction',
'assigned_user',
'new_article',
'recommend',
'ping',
'unassigned_user'
]
"""List of method names that implement RPC methods."""
INITIAL_ARTICLES = 1000 INITIAL_ARTICLES = 1000
"""Number of articles to initialize the in-memory article cache with.""" """Number of articles to initialize the in-memory article cache with."""
RECOMMEND_DEFAULT_LIMIT = 30 RECOMMEND_DEFAULT_LIMIT = RenewalRecsystem.RECOMMEND_DEFAULT_LIMIT
"""Max number of recommendations to return by default."""
RETRY_RATE = 5 # seconds
"""
Rate in seconds to retry connecting to the API when connection fails or
is dropped.
"""
articles = None articles = None
"""Articles cache; initialized in `initialize`.""" """Articles cache; initialized in `initialize`."""
...@@ -61,65 +33,11 @@ class BaselineRecsystem: ...@@ -61,65 +33,11 @@ class BaselineRecsystem:
_initialized = False _initialized = False
def __init__(self, api_base_uri=None, token=None, recommendation_mode=None, def __init__(self, api_base_uri=None, token=None, log=None, mode=None):
log=None): super().__init__(api_base_uri=api_base_uri, token=token, log=log)
if api_base_uri is not None:
if api_base_uri[-1] != '/':
# Add trailing slash to make it easier to join URL fragments
# with urljoin()
api_base_uri += '/'
self.api_base_uri = api_base_uri
else:
# set it to the default
self.api_base_uri = self.RENEWAL_API_BASE_URI
self.token = token
if recommendation_mode is not None:
self.recommendation_mode = recommendation_mode
if log is None:
self.log = logging.getLogger(self.NAME)
else:
self.log = log
methods = []
for method_name in self.RPC_METHODS:
method = getattr(self, method_name, None)
if not method:
# TODO: Maybe also validate the signature of each method
self.log.warning(
f'required RPC method {method_name} not implemented')
methods.append(method)
self.methods = Methods(*methods) if mode is not None:
self.recommendation_mode = mode
@property
def client_headers(self):
"""
Returns the headers passed on HTTP(S) requests to the Renewal API.
"""
if self.token is None:
self.log.warning(
f'no authentication token provided; most requests to the '
f'backend with be returned unauthorize except when testing '
f'against a development server')
return {}
else:
return {'Authorization': 'Bearer ' + self.token}
async def initialize_if_needed(self):
if self._initialized:
return
try:
await self.initialize()
except:
raise
else:
self._initialized = True
async def initialize(self): async def initialize(self):
""" """
...@@ -152,15 +70,6 @@ class BaselineRecsystem: ...@@ -152,15 +70,6 @@ class BaselineRecsystem:
########## RPC methods ########## ########## RPC methods ##########
# WARNING: Don't forget to make these functions async even if they # WARNING: Don't forget to make these functions async even if they
# don't use await, otherwise the async_dispatch gets confused. # don't use await, otherwise the async_dispatch gets confused.
async def ping(self):
"""
Recsystem heartbeat test.
Must just return with the value ``'pong'``.
"""
return 'pong'
async def new_article(self, article): async def new_article(self, article):
"""Called when a new article was made available from the backend.""" """Called when a new article was made available from the backend."""
...@@ -310,173 +219,45 @@ class BaselineRecsystem: ...@@ -310,173 +219,45 @@ class BaselineRecsystem:
return 1 return 1
return max(1, m['clicks']) * max(1, (m['likes'] - m['dislikes'])) return max(1, m['clicks']) * max(1, (m['likes'] - m['dislikes']))
########## websocket server loops ##########
async def request_loop(self):
"""
Main loop of the recsystem application.
Connects to the event stream websocket and starts a loop to receive and
handle events from the backend.
"""
self.log.info(f'initializing websocket connection to event stream')
uri = urljoin('ws:' + splittype(self.api_base_uri)[1], 'event_stream')
websocket_connect = partial(websockets.connect, uri,
extra_headers=self.client_headers)
async with websocket_connect() as websocket:
self.log.info(f'listening to websocket for events...')
# Incoming RPC requests are added to this queue, and their results are
# popped off the queue and sent; the queue is used as a means of
# serializing responses, otherwise we could have multiple coroutines
# concurrently trying to write to the same websocket
queue = asyncio.Queue()
# Start the incoming and outgoing message handlers; a slight variant of
# this pattern:
# https://websockets.readthedocs.io/en/stable/intro.html#both
await self.multiplex_tasks(self.handle_incoming(websocket, queue),
self.handle_outgoing(websocket, queue))
@staticmethod
async def multiplex_tasks(*tasks):
"""
Run multiple coroutines simultaneously as tasks, exiting as soon as any
one of them raises an exception.
The exception from the coroutine is then re-raised.
"""
done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_EXCEPTION)
try:
for task in done:
# If one of the tasks exited with an exception
# Calling .result() re-raises that exception
task.result()
finally:
for task in pending:
task.cancel()
async def dispatch_incoming(self, queue, request):
"""
Dispatch incoming messages to the JSON-RPC method dispatcher.
When the result is ready it is placed on the outgoing queue.
"""
response = await dispatch(request, methods=self.methods)
self.log.info(format_rpc_call(request, response))
await queue.put(response)
async def handle_incoming(self, websocket, queue):
"""
This coroutine checks the websocket for incoming JSON-RPC requests and
passes them to `dispatch_incoming`.
"""
while True:
request = await websocket.recv()
future = asyncio.ensure_future(
self.dispatch_incoming(queue, request))
def callback(future):
try:
future.result()
except Exception as exc:
self.log.exception(
f'unhandled exception dispatching request: '
f'{request}; this indicates an error in the RPC '
f'method implementation')
future.add_done_callback(callback)
async def handle_outgoing(self, websocket, queue):
"""
This coroutine checks the outgoing response queue for results from
dispatched RPC methods, and sends them on the websocket.
"""
while True:
response = await queue.get()
if response.wanted:
await websocket.send(str(response))
def run(self):
"""
Run the main event loop for the recommendation system.
Starts by calling `initialize`, and once that is complete starts up
the websocket connection and runs forever.
"""
self.log.info(
f'starting up {self.NAME} recsystem on {self.api_base_uri}')
loop = asyncio.get_event_loop()
try:
while True:
try:
loop.run_until_complete(self.initialize_if_needed())
loop.run_until_complete(self.request_loop())
except (websockets.WebSocketException, ConnectionRefusedError):
self.log.warning(
'lost connection to the backend; trying to '
're-establish...')
time.sleep(5)
except KeyboardInterrupt:
return
finally:
# Cancel all pending tasks
for task in asyncio.Task.all_tasks(loop=loop):
task.cancel()
try:
# Give the task a chance to finish up
loop.run_until_complete(task)
except Exception:
# This may result in a CancelledError or other miscellaneous
# exceptions as connections are shut down, but we are exiting
# anyways so ignore them.
pass
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
@click.classcommand() @click.classcommand()
@click.option('-a', '--api-base-uri', default=RENEWAL_API_BASE_URI, @click.option('-a', '--api-base-uri',
default=RenewalRecsystem.RENEWAL_API_BASE_URI,
help='URI for the Renewal HTTP API') help='URI for the Renewal HTTP API')
@click.option('-t', '--token', required=True, type=FileOrToken(), @click.option('-t', '--token', required=True, type=FileOrToken(),
help='authentication token for the recsystem; if a valid ' help='authentication token for the recsystem; if a valid '
'filename is given the token is read from a file ' 'filename is given the token is read from a file '
'instead') 'instead')
@click.option('--log-level', default='INFO',
type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'ERROR'],
case_sensitive=False),
help='minimum log level to output')
@click.option('-m', '--mode', type=click.Choice(['random', 'popular']), @click.option('-m', '--mode', type=click.Choice(['random', 'popular']),
default='random', default='random',
help='the recommendation mode: random simply returns a ' help='the recommendation mode: random simply returns a '
'random selection of articles, whereas popular returns ' 'random selection of articles, whereas popular returns '
'the most popular (in terms of rating and clicks) ' 'the most popular (in terms of rating and clicks) '
'articles in the requested range') 'articles in the requested range')
@click.option('--log-level', default='INFO', def main(cls, api_base_uri, token, log_level, mode, **kwargs):
type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'ERROR'], """
case_sensitive=False), Recsystem command line interface.
help='minimum log level to output')
def main(cls, api_base_uri, token, mode, log_level): The arguments ``api_base_uri``, ``token``, and ``log_level`` and
"""Recsystem command line interface.""" ``mode`` are passed by `click` via the command line interface.
Additional ``kwargs`` may be passed by subclasses that add additional
# Initialize logging options to the CLI, and their values are passed to the constructor of
logging.basicConfig(level=log_level) the subclass.
log = logging.getLogger(cls.NAME)
log.setLevel(log_level) .. note::
coloredlogs.install(level=log_level, logger=log)
# Log all uncaught exceptions When overriding the `main` function in a subclass it is necessary
sys.excepthook = lambda *exc_info: log.exception( to copy all of the decorators above it, including the
'an uncaught exception occurred', exc_info=exc_info) ``@click.classcommand()`` and all ``@click.option()`` decorators
in addition to any new options. This may be improved in a future
token = token.read().strip() version.
recsystem = cls(api_base_uri=api_base_uri, token=token, """
recommendation_mode=mode, log=log)
recsystem.run() return super().main.callback(api_base_uri=api_base_uri, token=token,
log_level=log_level, mode=mode, **kwargs)
if __name__ == '__main__': if __name__ == '__main__':
......
import abc
import asyncio
import logging
import sys
import time
from functools import partial
from urllib.parse import splittype, urljoin
# Third-party modules
import coloredlogs
import objclick as click
import websockets
from jsonrpcserver import async_dispatch as dispatch
from jsonrpcserver.methods import Methods
# Local modules
from .utils import FileOrToken, format_rpc_call
class RenewalRecsystem(metaclass=abc.ABCMeta):
NAME = abc.abstractproperty()
"""
Display name for this recsystem.
Must be assigned by subclasses.
"""
ENVVAR_PREFIX = 'RENEWAL'
"""
Prefix for environment variables that can be used to pass arguments to
the recsystem's CLI.
E.g. If the CLI has an argument named ``token`` this can be set either
by passing ``--token=<token>`` at the command line, or by setting the
environment variable ``RENEWAL_TOKEN=<token>``.
"""
RENEWAL_API_BASE_URI = 'https://api.renewal-research.com/v1/'
"""Default URI for the Renewal API."""
RPC_METHODS = [
'article_interaction',
'assigned_user',
'new_article',
'recommend',
'ping',
'unassigned_user'
]
"""List of method names that implement RPC methods."""
RECOMMEND_DEFAULT_LIMIT = 30
"""Max number of recommendations to return by default."""
RETRY_RATE = 5 # seconds
"""
Rate in seconds to retry connecting to the API when connection fails or
is dropped.
"""
_initialized = False
def __init__(self, api_base_uri=None, token=None, log=None):
if api_base_uri is not None:
if api_base_uri[-1] != '/':
# Add trailing slash to make it easier to join URL fragments
# with urljoin()
api_base_uri += '/'
self.api_base_uri = api_base_uri
else:
# set it to the default
self.api_base_uri = self.RENEWAL_API_BASE_URI
self.token = token
if log is None:
self.log = logging.getLogger(self.NAME)
else:
self.log = log
methods = []
for method_name in self.RPC_METHODS:
method = getattr(self, method_name, None)
if not method:
# TODO: Maybe also validate the signature of each method
self.log.warning(
f'required RPC method {method_name} not implemented')
methods.append(method)
self.methods = Methods(*methods)
@property
def client_headers(self):
"""
Returns the headers passed on HTTP(S) requests to the Renewal API.
"""
if self.token is None:
self.log.warning(
f'no authentication token provided; most requests to the '
f'backend with be returned unauthorize except when testing '
f'against a development server')
return {}
else:
return {'Authorization': 'Bearer ' + self.token}
async def initialize_if_needed(self):
if self._initialized:
return
try:
await self.initialize()
except:
raise
else:
self._initialized = True
@abc.abstractmethod
async def initialize(self):
"""
Start-up tasks to perform before starting the main client loop.
* Download a initial set of recent articles to work with.
* Download the current set of users assigned to the recsystem.
Must be implemented by subclasses.
"""
########## RPC methods ##########
# WARNING: Don't forget to make these functions async even if they
# don't use await, otherwise the async_dispatch gets confused.
async def ping(self):
"""
Recsystem heartbeat test.
Must just return with the value ``'pong'``.
"""
return 'pong'
@abc.abstractmethod
async def new_article(self, article):
"""Called when a new article was made available from the backend."""
@abc.abstractmethod
async def article_interaction(self, interaction):
"""
Called when a user interacts with an article--this is received for all
users, including users not currently assigned to the recsystem.
This allows maintaining the recsystem's own up-to-date metrics on
interactions with each article in its local database of articles.
Interactions currently include:
* Likes
* Dislikes
* Bookmarks
* Clicks (the user clicked on the article)
with more to be added later.
"""
@abc.abstractmethod
async def recommend(self, user_id, limit=RECOMMEND_DEFAULT_LIMIT,
since_id=None, max_id=None):
"""
Return recommendations for the specified user and article ID range.
"""
@abc.abstractmethod
async def assigned_user(self, user_id):
"""
Called when the controller has assigned a new user to the recsystem.
The recsystem will in general only receive recommendation requests for
users it is actively assigned to (though it may be sent requests for
unassigned users for test purposes). However, the recsystem can use
this to maintain a set of users for whom it should be actively
processing data.
"""
@abc.abstractmethod
async def unassigned_user(self, user_id):
"""
Called when the controller removes a user assignment from the
recsystem.
See also `assigned_user`.
"""
########## websocket server loops ##########
async def request_loop(self):
"""
Main loop of the recsystem application.
Connects to the event stream websocket and starts a loop to receive and
handle events from the backend.
"""
self.log.info(f'initializing websocket connection to event stream')
uri = urljoin('ws:' + splittype(self.api_base_uri)[1], 'event_stream')
websocket_connect = partial(websockets.connect, uri,
extra_headers=self.client_headers)
async with websocket_connect() as websocket:
self.log.info(f'listening to websocket for events...')
# Incoming RPC requests are added to this queue, and their results are
# popped off the queue and sent; the queue is used as a means of
# serializing responses, otherwise we could have multiple coroutines
# concurrently trying to write to the same websocket
queue = asyncio.Queue()
# Start the incoming and outgoing message handlers; a slight variant of
# this pattern:
# https://websockets.readthedocs.io/en/stable/intro.html#both
await self.multiplex_tasks(self.handle_incoming(websocket, queue),
self.handle_outgoing(websocket, queue))
@staticmethod
async def multiplex_tasks(*tasks):
"""
Run multiple coroutines simultaneously as tasks, exiting as soon as any
one of them raises an exception.
The exception from the coroutine is then re-raised.
"""
done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_EXCEPTION)
try:
for task in done:
# If one of the tasks exited with an exception
# Calling .result() re-raises that exception
task.result()
finally:
for task in pending:
task.cancel()
async def dispatch_incoming(self, queue, request):
"""
Dispatch incoming messages to the JSON-RPC method dispatcher.
When the result is ready it is placed on the outgoing queue.
"""
response = await dispatch(request, methods=self.methods)
self.log.info(format_rpc_call(request, response))
await queue.put(response)
async def handle_incoming(self, websocket, queue):
"""
This coroutine checks the websocket for incoming JSON-RPC requests and
passes them to `dispatch_incoming`.
"""
while True:
request = await websocket.recv()
future = asyncio.ensure_future(
self.dispatch_incoming(queue, request))
def callback(future):
try:
future.result()
except Exception as exc:
self.log.exception(
f'unhandled exception dispatching request: '
f'{request}; this indicates an error in the RPC '
f'method implementation')
future.add_done_callback(callback)
async def handle_outgoing(self, websocket, queue):
"""
This coroutine checks the outgoing response queue for results from
dispatched RPC methods, and sends them on the websocket.
"""
while True:
response = await queue.get()
if response.wanted:
await websocket.send(str(response))
def run(self):
"""
Run the main event loop for the recommendation system.
Starts by calling `initialize`, and once that is complete starts up
the websocket connection and runs forever.
"""
self.log.info(
f'starting up {self.NAME} recsystem on {self.api_base_uri}')
loop = asyncio.get_event_loop()
try:
while True:
try:
loop.run_until_complete(self.initialize_if_needed())
loop.run_until_complete(self.request_loop())
except (websockets.WebSocketException, ConnectionRefusedError):
self.log.warning(
'lost connection to the backend; trying to '
're-establish...')
time.sleep(self.RETRY_RATE)
except KeyboardInterrupt:
return
finally:
# Cancel all pending tasks
for task in asyncio.Task.all_tasks(loop=loop):
task.cancel()
try:
# Give the task a chance to finish up
loop.run_until_complete(task)
except Exception:
# This may result in a CancelledError or other miscellaneous
# exceptions as connections are shut down, but we are exiting
# anyways so ignore them.
pass
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
@click.classcommand()
@click.option('-a', '--api-base-uri', default=RENEWAL_API_BASE_URI,
help='URI for the Renewal HTTP API')
@click.option('-t', '--token', required=True, type=FileOrToken(),
help='authentication token for the recsystem; if a valid '
'filename is given the token is read from a file '
'instead')
@click.option('--log-level', default='INFO',
type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'ERROR'],
case_sensitive=False),
help='minimum log level to output')
def main(cls, api_base_uri, token, log_level, **kwargs):
"""
Recsystem command line interface.
The arguments ``api_base_uri``, ``token``, and ``log_level`` are
passed by `click` via the command line interface. Additional
``kwargs`` may be passed by subclasses that add additional options to
the CLI, and their values are passed to the constructor of the
subclass.
.. note::
When overriding the `main` function in a subclass it is necessary
to copy all of the decorators above it, including the
``@click.classcommand()`` and all ``@click.option()`` decorators
in addition to any new options. This may be improved in a future
version.
"""
# Initialize logging
logging.basicConfig(level=log_level)
log = logging.getLogger(cls.NAME)
log.setLevel(log_level)
coloredlogs.install(level=log_level, logger=log)
# Log all uncaught exceptions
sys.excepthook = lambda *exc_info: log.exception(
'an uncaught exception occurred', exc_info=exc_info)
token = token.read().strip()
recsystem = cls(api_base_uri=api_base_uri, token=token, log=log,
**kwargs)
recsystem.run()
...@@ -3,6 +3,7 @@ import json ...@@ -3,6 +3,7 @@ import json
import jwt import jwt
import objclick as click import objclick as click
from jsonrpcserver.response import DictResponse
def format_rpc_call(request, response=None): def format_rpc_call(request, response=None):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment