# Source code for aries_cloudagent.admin.server

"""Admin server classes."""

import asyncio
import hmac
import logging
from typing import Callable, Coroutine, Sequence, Set
import uuid

from aiohttp import web
from aiohttp_apispec import docs, response_schema, setup_aiohttp_apispec
import aiohttp_cors

from marshmallow import fields, Schema

from ..config.injection_context import InjectionContext
from ..core.plugin_registry import PluginRegistry
from ..messaging.responder import BaseResponder
from ..transport.queue.basic import BasicMessageQueue
from ..transport.outbound.message import OutboundMessage
from ..utils.stats import Collector
from ..utils.task_queue import TaskQueue
from ..version import __version__

from .base_server import BaseAdminServer
from .error import AdminSetupError

# Module-level logger; its name is this module's import path.
LOGGER = logging.getLogger(__name__)


class AdminModulesSchema(Schema):
    """Response schema for the endpoint that lists admin modules."""

    # The single response member: a list of module names as strings.
    result = fields.List(
        fields.Str(description="admin module"),
        description="List of admin modules",
    )
class AdminStatusSchema(Schema):
    """Response schema for the server status endpoint; it declares no fields."""
class AdminResponder(BaseResponder):
    """Responder that forwards outbound messages and webhooks to callables."""

    def __init__(
        self, context: InjectionContext, send: Coroutine, webhook: Coroutine, **kwargs
    ):
        """
        Initialize an instance of `AdminResponder`.

        Args:
            context: The context handed to `send` as its first argument
            send: Awaitable callable invoked as `send(context, message)`
            webhook: Awaitable callable invoked as `webhook(topic, payload)`
            kwargs: Passed through unchanged to the `BaseResponder` initializer

        """
        super().__init__(**kwargs)
        self._webhook = webhook
        self._send = send
        self._context = context

    async def send_outbound(self, message: OutboundMessage):
        """
        Deliver an outbound message through the configured `send` callable.

        Args:
            message: The `OutboundMessage` to be sent

        """
        deliver = self._send
        await deliver(self._context, message)

    async def send_webhook(self, topic: str, payload: dict):
        """
        Hand a webhook to the configured `webhook` callable.

        Args:
            topic: the webhook topic identifier
            payload: the webhook payload value

        """
        dispatch = self._webhook
        await dispatch(topic, payload)
class WebhookTarget:
    """Class for managing webhook target information."""

    def __init__(
        self,
        endpoint: str,
        topic_filter: Sequence[str] = None,
        max_attempts: int = None,
    ):
        """
        Initialize the webhook target.

        Args:
            endpoint: Where webhooks for this target are delivered
            topic_filter: Topics this target accepts; empty, `None`, or any
                collection containing "*" means no filtering. A single
                string is treated as one topic.
            max_attempts: Stored as given; this class does not interpret it

        """
        self.endpoint = endpoint
        self.max_attempts = max_attempts
        self._topic_filter = None
        self.topic_filter = topic_filter  # call setter

    @property
    def topic_filter(self) -> Set[str]:
        """Accessor for the target's topic filter (`None` when unfiltered)."""
        return self._topic_filter

    @topic_filter.setter
    def topic_filter(self, val: Sequence[str]):
        """
        Setter for the target's topic filter.

        The value is normalized to a set of topics, or to `None` when it is
        empty or contains the wildcard "*".
        """
        if isinstance(val, str):
            # A bare string is itself a sequence of characters; wrap it so it
            # counts as one topic instead of being split into single letters.
            val = [val] if val else None
        topics = set(val) if val else None
        if topics and "*" in topics:
            topics = None
        self._topic_filter = topics
class AdminServer(BaseAdminServer):
    """Admin HTTP server class."""

    def __init__(
        self,
        host: str,
        port: int,
        context: InjectionContext,
        outbound_message_router: Coroutine,
        webhook_router: Callable,
        task_queue: TaskQueue = None,
        conductor_stats: Coroutine = None,
    ):
        """
        Initialize an AdminServer instance.

        Args:
            host: Host to listen on
            port: Port to listen on
            context: The application context instance
            outbound_message_router: Coroutine for delivering outbound messages
            webhook_router: Callable for delivering webhooks
            task_queue: An optional task queue for handlers
            conductor_stats: An optional coroutine function; when set, its
                awaited result is reported under "conductor" by the status
                endpoint

        """
        self.app = None
        self.host = host
        self.port = port
        self.conductor_stats = conductor_stats
        self.loaded_modules = []
        self.task_queue = task_queue
        self.webhook_router = webhook_router
        self.webhook_targets = {}
        self.websocket_queues = {}
        self.site = None

        # The server works in its own "admin" scope of the given context and
        # registers its responder there as the `BaseResponder` instance.
        self.context = context.start_scope("admin")
        self.responder = AdminResponder(
            self.context, outbound_message_router, self.send_webhook
        )
        self.context.injector.bind_instance(BaseResponder, self.responder)

    @staticmethod
    def _api_keys_match(expected, provided) -> bool:
        """
        Compare the configured API key with a client-supplied one.

        The comparison is constant-time for strings so that response timing
        does not reveal how much of a guessed key was correct.
        """
        if not isinstance(expected, str) or not isinstance(provided, str):
            # Not text on both sides: keep plain equality semantics.
            return expected == provided
        # Compare as bytes because `hmac.compare_digest` rejects non-ASCII
        # `str` arguments; "surrogateescape" tolerates lone surrogates that a
        # lenient header decoder may have produced.
        return hmac.compare_digest(
            expected.encode("utf-8", "surrogateescape"),
            provided.encode("utf-8", "surrogateescape"),
        )

    async def make_application(self) -> web.Application:
        """
        Get the aiohttp application instance.

        Builds the middleware chain (API key check, then either the task
        queue limiter or the stats collector), registers the built-in and
        plugin routes, enables CORS on every route and sets up the API docs.

        Returns:
            The configured `web.Application`

        Raises:
            AssertionError: If the admin API key and admin insecure mode
                settings are both set or both missing

        """
        middlewares = []

        admin_api_key = self.context.settings.get("admin.admin_api_key")
        admin_insecure_mode = self.context.settings.get("admin.admin_insecure_mode")

        # The admin API key and admin insecure mode are mutually exclusive
        # and exactly one of them is required. This should be enforced during
        # parameter parsing but to be sure, we check here. The checks raise
        # explicitly (rather than using `assert`) so they still run under
        # `python -O`, where a stripped assert would allow an unprotected
        # admin API to start.
        if not (admin_insecure_mode or admin_api_key):
            raise AssertionError(
                "Either an admin API key or admin insecure mode must be set"
            )
        if admin_insecure_mode and admin_api_key:
            raise AssertionError(
                "An admin API key and admin insecure mode are mutually exclusive"
            )

        # If admin_api_key is None, then admin_insecure_mode must be set so
        # we can safely enable the admin server with no security
        if admin_api_key:

            @web.middleware
            async def check_token(request, handler):
                header_admin_api_key = request.headers.get("x-api-key")
                if not header_admin_api_key:
                    raise web.HTTPUnauthorized()

                if not self._api_keys_match(admin_api_key, header_admin_api_key):
                    raise web.HTTPUnauthorized()
                return await handler(request)

            middlewares.append(check_token)

        collector: Collector = await self.context.inject(Collector, required=False)

        if self.task_queue:

            @web.middleware
            async def apply_limiter(request, handler):
                # Run the handler through the task queue and wait on the
                # task it hands back.
                task = await self.task_queue.put(handler(request))
                return await task

            middlewares.append(apply_limiter)

        elif collector:

            @web.middleware
            async def collect_stats(request, handler):
                handler = collector.wrap_coro(handler, [handler.__qualname__])
                return await handler(request)

            middlewares.append(collect_stats)

        app = web.Application(middlewares=middlewares)
        app["request_context"] = self.context
        app["outbound_message_router"] = self.responder.send

        app.add_routes(
            [
                web.get("/", self.redirect_handler),
                web.get("/plugins", self.plugins_handler),
                web.get("/status", self.status_handler),
                web.post("/status/reset", self.status_reset_handler),
                web.get("/ws", self.websocket_handler),
            ]
        )

        plugin_registry: PluginRegistry = await self.context.inject(
            PluginRegistry, required=False
        )
        if plugin_registry:
            await plugin_registry.register_admin_routes(app)

        cors = aiohttp_cors.setup(
            app,
            defaults={
                "*": aiohttp_cors.ResourceOptions(
                    allow_credentials=True,
                    expose_headers="*",
                    allow_headers="*",
                    allow_methods="*",
                )
            },
        )
        for route in app.router.routes():
            cors.add(route)

        # get agent label
        agent_label = self.context.settings.get("default_label")
        version_string = f"v{__version__}"

        setup_aiohttp_apispec(
            app=app, title=agent_label, version=version_string, swagger_path="/api/doc"
        )
        app.on_startup.append(self.on_startup)
        return app

    async def start(self) -> None:
        """
        Start the webserver.

        Raises:
            AdminSetupError: If there was an error starting the webserver

        """
        self.app = await self.make_application()
        runner = web.AppRunner(self.app)
        await runner.setup()
        self.site = web.TCPSite(runner, host=self.host, port=self.port)

        try:
            await self.site.start()
        except OSError as err:
            # The site never came up: forget it so `stop()` does not try to
            # stop it, and release what `runner.setup()` allocated.
            self.site = None
            await runner.cleanup()
            raise AdminSetupError(
                "Unable to start webserver with host "
                + f"'{self.host}' and port '{self.port}'\n"
            ) from err

    async def stop(self) -> None:
        """Stop the webserver and every open websocket message queue."""
        for queue in self.websocket_queues.values():
            queue.stop()
        if self.site:
            await self.site.stop()
            self.site = None

    async def on_startup(self, app: web.Application):
        """Perform webserver startup actions."""

    @docs(tags=["server"], summary="Fetch the list of loaded plugins")
    @response_schema(AdminModulesSchema(), 200)
    async def plugins_handler(self, request: web.BaseRequest):
        """
        Request handler for the loaded plugins list.

        Args:
            request: aiohttp request object

        Returns:
            The module list response

        """
        registry: PluginRegistry = await self.context.inject(
            PluginRegistry, required=False
        )
        # No registry available means no plugins to report.
        plugins = sorted(registry.plugin_names) if registry else []
        return web.json_response({"result": plugins})

    @docs(tags=["server"], summary="Fetch the server status")
    @response_schema(AdminStatusSchema(), 200)
    async def status_handler(self, request: web.BaseRequest):
        """
        Request handler for the server status information.

        Args:
            request: aiohttp request object

        Returns:
            The web response

        """
        status = {"version": __version__}
        collector: Collector = await self.context.inject(Collector, required=False)
        if collector:
            status["timing"] = collector.results
        if self.conductor_stats:
            status["conductor"] = await self.conductor_stats()
        return web.json_response(status)

    @docs(tags=["server"], summary="Reset statistics")
    @response_schema(AdminStatusSchema(), 200)
    async def status_reset_handler(self, request: web.BaseRequest):
        """
        Request handler for resetting the timing statistics.

        Args:
            request: aiohttp request object

        Returns:
            The web response

        """
        collector: Collector = await self.context.inject(Collector, required=False)
        if collector:
            collector.reset()
        return web.json_response({})

    async def redirect_handler(self, request: web.BaseRequest):
        """Perform redirect to documentation."""
        raise web.HTTPFound("/api/doc")

    async def websocket_handler(self, request):
        """
        Send notifications to admin client over websocket.

        A per-connection queue is registered in `websocket_queues`, an
        initial "settings" message is queued, and queued messages are then
        relayed to the client until the socket closes or the handler is
        cancelled. A "ping" message is sent whenever a dequeue times out.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        socket_id = str(uuid.uuid4())
        queue = BasicMessageQueue()
        loop = asyncio.get_event_loop()

        receive = None
        send = None

        try:
            self.websocket_queues[socket_id] = queue
            await queue.enqueue(
                {
                    "topic": "settings",
                    "payload": {
                        "label": self.context.settings.get("default_label"),
                        "endpoint": self.context.settings.get("default_endpoint"),
                        "no_receive_invites": self.context.settings.get(
                            "admin.no_receive_invites", False
                        ),
                        "help_link": self.context.settings.get("admin.help_link"),
                    },
                }
            )

            closed = False
            receive = loop.create_task(ws.receive())
            send = loop.create_task(queue.dequeue(timeout=5.0))

            while not closed:
                try:
                    await asyncio.wait(
                        (receive, send), return_when=asyncio.FIRST_COMPLETED
                    )
                    if ws.closed:
                        closed = True

                    if receive.done():
                        # incoming messages are ignored; just keep listening
                        if not closed:
                            receive = loop.create_task(ws.receive())

                    if send.done():
                        try:
                            msg = send.result()
                        except asyncio.TimeoutError:
                            msg = None

                        if msg is None:
                            # we send fake pings because the JS client
                            # can't detect real ones
                            msg = {"topic": "ping"}
                        if not closed:
                            if msg:
                                await ws.send_json(msg)
                            send = loop.create_task(queue.dequeue(timeout=5.0))

                except asyncio.CancelledError:
                    closed = True

        finally:
            # Always cancel the helper tasks and unregister the queue, also
            # when an error (e.g. a failed send) escapes the loop above.
            for task in (receive, send):
                if task is not None and not task.done():
                    task.cancel()
            self.websocket_queues.pop(socket_id, None)

        return ws

    def add_webhook_target(
        self,
        target_url: str,
        topic_filter: Sequence[str] = None,
        max_attempts: int = None,
    ):
        """
        Add a webhook target, replacing any existing one for the same URL.

        Args:
            target_url: The target endpoint, also used as the registry key
            topic_filter: Optional topics to restrict delivery to
            max_attempts: Optional value passed on to the webhook router

        """
        self.webhook_targets[target_url] = WebhookTarget(
            target_url, topic_filter, max_attempts
        )

    def remove_webhook_target(self, target_url: str):
        """Remove a webhook target; unknown URLs are ignored."""
        self.webhook_targets.pop(target_url, None)

    async def send_webhook(self, topic: str, payload: dict):
        """
        Add a webhook to the queue, to send to all registered targets.

        Every connected websocket client also receives the topic and
        payload, whether or not a webhook router is configured.

        Args:
            topic: the webhook topic identifier
            payload: the webhook payload value

        """
        if self.webhook_router:
            for target in self.webhook_targets.values():
                if not target.topic_filter or topic in target.topic_filter:
                    self.webhook_router(
                        topic, payload, target.endpoint, target.max_attempts
                    )

        # Iterate over a snapshot: a websocket client connecting or
        # disconnecting while an `enqueue` is awaited would otherwise change
        # the dict in the middle of the loop.
        for queue in list(self.websocket_queues.values()):
            await queue.enqueue({"topic": topic, "payload": payload})