Skip to content
Snippets Groups Projects
Select Git revision
  • ca68c1a90a0897c4def24d70d895367b03a50138
  • master default protected
  • embray/issue-22
  • documentation/init
  • ci/tag-master
  • V1.0
6 results

baseline.py

Blame
  • recsystem.py 7.14 KiB
    import abc
    import logging
    import sys
    from urllib.parse import splittype, urljoin, urlsplit, urlunsplit

    # Third-party modules
    import coloredlogs
    import objclick as click

    # Local modules
    from .server import JSONRPCServerWebsocketClient
    from .utils import FileOrToken, format_rpc_call
    
    
    class RenewalRecsystem(JSONRPCServerWebsocketClient, metaclass=abc.ABCMeta):
        @property
        @abc.abstractmethod
        def NAME(self):
            """
            Display name for this recsystem.

            Must be assigned by subclasses; a plain class attribute such as
            ``NAME = 'my-recsystem'`` is sufficient to override it.
            """

        ENVVAR_PREFIX = 'RENEWAL'
        """
        Prefix for environment variables that can be used to pass arguments to
        the recsystem's CLI.

        E.g. If the CLI has an argument named ``token`` this can be set either
        by passing ``--token=<token>`` at the command line, or by setting the
        environment variable ``RENEWAL_TOKEN=<token>``.
        """

        RENEWAL_API_BASE_URI = 'https://api.renewal-research.com/v1/'
        """Default URI for the Renewal API."""

        RPC_METHODS = [
            'article_interaction',
            'assigned_user',
            'new_article',
            'recommend',
            'ping',
            'unassigned_user'
        ]
        """List of method names that implement RPC methods."""

        RECOMMEND_DEFAULT_LIMIT = 30
        """Max number of recommendations to return by default."""

        EVENT_STREAM_ENDPOINT = '/event_stream'
        """This is the endpoint on the Renewal API to connect to the websocket."""

        def __init__(self, api_base_uri=None, token=None, log=None):
            """
            Configure the recsystem client.

            Parameters
            ----------
            api_base_uri : str, optional
                Base URI of the Renewal HTTP API.  Defaults to
                `RENEWAL_API_BASE_URI`.  A trailing slash is appended if it is
                missing.
            token : str, optional
                Authentication token, sent as a bearer token by
                `client_headers`.
            log : logging.Logger, optional
                Logger to use for this recsystem.  When omitted, whatever
                ``log`` the base class provides is kept.
            """

            if api_base_uri is None:
                api_base_uri = self.RENEWAL_API_BASE_URI

            # Add trailing slash to make it easier to join URL fragments
            # with urljoin()
            if not api_base_uri.endswith('/'):
                api_base_uri += '/'

            self.api_base_uri = api_base_uri

            # Join the endpoint while the URI still has its http(s) scheme.
            # urljoin() only resolves relative references for schemes listed
            # in urllib.parse.uses_relative, and older Pythons do not list
            # ws/wss there.  Swap to the websocket scheme afterwards.
            http_uri = urljoin(api_base_uri,
                               self.EVENT_STREAM_ENDPOINT.lstrip('/'))
            parts = urlsplit(http_uri)
            ws_proto = 'wss' if parts.scheme == 'https' else 'ws'
            websocket_uri = urlunsplit(parts._replace(scheme=ws_proto))

            self.token = token
            super().__init__(websocket_uri)

            if log is not None:
                # ``self.log`` is not assigned anywhere in this class, so it
                # presumably comes from the base class.  Assign the
                # caller-supplied logger (e.g. the one configured by `main`)
                # after super().__init__() so that it takes precedence.
                self.log = log

        @property
        def client_headers(self):
            """
            Returns the headers passed on HTTP(S) requests to the Renewal API.

            If no token was provided, a warning is logged and an empty dict is
            returned.  Otherwise an ``Authorization: Bearer <token>`` header is
            returned.
            """

            if self.token is None:
                self.log.warning(
                    'no authentication token provided; most requests to the '
                    'backend will be returned unauthorized except when '
                    'testing against a development server')
                return {}

            return {'Authorization': 'Bearer ' + self.token}

        @abc.abstractmethod
        async def initialize(self):
            """
            Start-up tasks to perform before starting the main client loop.

            * Download an initial set of recent articles to work with.
            * Download the current set of users assigned to the recsystem.

            Must be implemented by subclasses.
            """

        ########## RPC methods ##########
        # WARNING: Don't forget to make these functions async even if they
        # don't use await, otherwise the async_dispatch gets confused.
        async def ping(self):
            """
            Recsystem heartbeat test.

            Must just return with the value ``'pong'``.
            """

            return 'pong'

        @abc.abstractmethod
        async def new_article(self, article):
            """
            Called when a new article was made available from the backend.

            Must be implemented by subclasses.
            """

        @abc.abstractmethod
        async def article_interaction(self, interaction):
            """
            Called when a user interacts with an article--this is received for all
            users, including users not currently assigned to the recsystem.

            This allows maintaining the recsystem's own up-to-date metrics on
            interactions with each article in its local database of articles.

            Interactions currently include:

            * Likes
            * Dislikes
            * Bookmarks
            * Clicks (the user clicked on the article)

            with more to be added later.

            Must be implemented by subclasses.
            """

        @abc.abstractmethod
        async def recommend(self, user_id, limit=RECOMMEND_DEFAULT_LIMIT,
                            since_id=None, max_id=None):
            """
            Return recommendations for the specified user and article ID range.

            ``limit`` caps the number of recommendations returned.  It defaults
            to `RECOMMEND_DEFAULT_LIMIT`.  ``since_id`` and ``max_id``
            presumably bound the range of article IDs to consider; confirm this
            against the backend protocol.

            Must be implemented by subclasses.
            """

        @abc.abstractmethod
        async def assigned_user(self, user_id):
            """
            Called when the controller has assigned a new user to the recsystem.

            The recsystem will in general only receive recommendation requests for
            users it is actively assigned to (though it may be sent requests for
            unassigned users for test purposes).  However, the recsystem can use
            this to maintain a set of users for whom it should be actively
            processing data.
            """

        @abc.abstractmethod
        async def unassigned_user(self, user_id):
            """
            Called when the controller removes a user assignment from the
            recsystem.

            See also `assigned_user`.
            """

        def run(self):
            """
            Run the main event loop for the recommendation system.

            Starts by calling `initialize`, and once that is complete starts up
            the websocket connection and runs forever.
            """

            self.log.info(
                f'starting up {self.NAME} recsystem on {self.api_base_uri}')
            super().run()

        @click.classcommand()
        @click.option('-a', '--api-base-uri', default=RENEWAL_API_BASE_URI,
                      help='URI for the Renewal HTTP API')
        @click.option('-t', '--token', required=True, type=FileOrToken(),
                      help='authentication token for the recsystem; if a valid '
                           'filename is given the token is read from a file '
                           'instead')
        @click.option('--log-level', default='INFO',
                      type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                                        case_sensitive=False),
                      help='minimum log level to output')
        def main(cls, api_base_uri, token, log_level, **kwargs):
            """
            Recsystem command line interface.

            The arguments ``api_base_uri``, ``token``, and ``log_level`` are
            passed by `click` via the command line interface.  Additional
            ``kwargs`` may be passed by subclasses that add additional options to
            the CLI, and their values are passed to the constructor of the
            subclass.

            .. note::

                When overriding the `main` function in a subclass it is necessary
                to copy all of the decorators above it, including the
                ``@click.classcommand()`` and all ``@click.option()`` decorators
                in addition to any new options.  This may be improved in a future
                version.
            """

            # Some click versions return a case-insensitive Choice exactly as
            # typed (e.g. 'debug').  The logging module only accepts the
            # upper-case level names, so normalize the value here.
            log_level = log_level.upper()

            # Initialize logging
            logging.basicConfig(level=log_level)
            log = logging.getLogger(cls.NAME)
            log.setLevel(log_level)
            coloredlogs.install(level=log_level, logger=log)

            def log_uncaught_exception(exc_type, exc_value, exc_tb):
                """Log all uncaught exceptions, except Ctrl-C."""

                if issubclass(exc_type, KeyboardInterrupt):
                    # Let Ctrl-C go to the default hook rather than being
                    # reported as an error
                    sys.__excepthook__(exc_type, exc_value, exc_tb)
                    return

                log.error('an uncaught exception occurred',
                          exc_info=(exc_type, exc_value, exc_tb))

            sys.excepthook = log_uncaught_exception

            token = token.read().strip()
            recsystem = cls(api_base_uri=api_base_uri, token=token, log=log,
                            **kwargs)
            recsystem.run()