Source code for aries_cloudagent.admin.server

"""Admin server classes."""

import asyncio
import hmac
import logging
from typing import Coroutine, Sequence, Set
import uuid

from aiohttp import web, ClientSession
from aiohttp_apispec import docs, response_schema, setup_aiohttp_apispec
import aiohttp_cors

from marshmallow import fields, Schema

from ..classloader import ClassLoader
from ..config.base import ConfigError
from ..config.injection_context import InjectionContext
from ..messaging.outbound_message import OutboundMessage
from ..messaging.responder import BaseResponder
from ..stats import Collector
from ..task_processor import TaskProcessor
from ..transport.outbound.queue.base import BaseOutboundMessageQueue

from .base_server import BaseAdminServer
from .error import AdminSetupError
from .routes import register_module_routes

# Module-level logger, named after this module's import path.
LOGGER = logging.getLogger(name=__name__)


class AdminModulesSchema(Schema):
    """Response body schema for the ``/modules`` endpoint."""

    # list of strings, returned under the "result" key
    result = fields.List(cls_or_instance=fields.Str())
class AdminStatusSchema(Schema):
    """Response body schema for the ``/status`` endpoints."""

    # NOTE(review): no fields are declared; the status payload is documented
    # as a free-form object.
class AdminResponder(BaseResponder):
    """Responder that hands handler output to admin-server supplied callbacks."""

    def __init__(self, send: Coroutine, webhook: Coroutine, **kwargs):
        """
        Create an `AdminResponder`.

        Args:
            send: Coroutine function used to deliver outbound messages
            webhook: Coroutine function used to dispatch webhook events
            kwargs: Passed through unchanged to `BaseResponder.__init__`
        """
        super().__init__(**kwargs)
        self._send, self._webhook = send, webhook

    async def send_outbound(self, message: OutboundMessage):
        """
        Deliver an outbound message through the configured send callback.

        Args:
            message: The `OutboundMessage` to deliver
        """
        deliver = self._send
        await deliver(message)

    async def send_webhook(self, topic: str, payload: dict):
        """
        Hand a webhook event to the configured webhook callback.

        Args:
            topic: Identifier of the webhook topic
            payload: Value to send for that topic
        """
        dispatch = self._webhook
        await dispatch(topic, payload)
class WebhookTarget:
    """Class for managing webhook target information."""

    def __init__(
        self, endpoint: str, topic_filter: Sequence[str] = None, retries: int = None
    ):
        """
        Initialize the webhook target.

        Args:
            endpoint: URL of the webhook target
            topic_filter: Topics this target accepts. ``None``, an empty
                sequence, or any sequence containing ``"*"`` means "accept all
                topics". A single topic may be given as a plain string.
            retries: Retry count for this target; ``None`` means no
                per-target override
        """
        self.endpoint = endpoint
        self._topic_filter = None
        self.retries = retries
        # go through the setter so the filter is normalized
        self.topic_filter = topic_filter

    @property
    def topic_filter(self) -> Set[str]:
        """Accessor for the target's topic filter (``None`` accepts all topics)."""
        return self._topic_filter

    @topic_filter.setter
    def topic_filter(self, val: Sequence[str]):
        """Setter for the target's topic filter."""
        if val and isinstance(val, str):
            # a bare string is one topic, not a sequence of one-char topics
            val = [val]
        topics = set(val) if val else None
        if topics and "*" in topics:
            # wildcard: collapse to "no filter"
            topics = None
        self._topic_filter = topics
class AdminServer(BaseAdminServer):
    """Admin HTTP server class."""

    def __init__(
        self,
        host: str,
        port: int,
        context: InjectionContext,
        outbound_message_router: Coroutine,
    ):
        """
        Initialize an AdminServer instance.

        Args:
            host: Host to listen on
            port: Port to listen on
            context: Injection context; an "admin" scope is started from it
                and used for everything this server injects
            outbound_message_router: Coroutine the admin responder uses to
                send outbound messages
        """
        self.app = None
        self.host = host
        self.port = port
        self.loaded_modules = []
        self.webhook_queue = None
        # default retry count for webhook targets without their own setting
        self.webhook_retries = 5
        self.webhook_session: ClientSession = None
        self.webhook_targets = {}
        self.webhook_task = None
        self.webhook_processor: TaskProcessor = None
        self.websocket_queues = {}
        self.site = None

        self.context = context.start_scope("admin")
        self.responder = AdminResponder(outbound_message_router, self.send_webhook)
        self.context.injector.bind_instance(BaseResponder, self.responder)

    async def make_application(self) -> web.Application:
        """Get the aiohttp application instance."""

        middlewares = []

        admin_api_key = self.context.settings.get("admin.admin_api_key")
        admin_insecure_mode = self.context.settings.get("admin.admin_insecure_mode")

        # admin_api_key and admin_insecure_mode are mutually exclusive, and one
        # of them is required. This should be enforced during parameter
        # parsing; these asserts are only a sanity check (they are stripped
        # when Python runs with -O).
        assert admin_insecure_mode or admin_api_key
        assert not (admin_insecure_mode and admin_api_key)

        # If admin_api_key is None, then admin_insecure_mode must be set so
        # we can safely enable the admin server with no security
        if admin_api_key:
            expected_key = str(admin_api_key).encode("utf-8")

            def api_key_matches(candidate: str) -> bool:
                """Compare a presented API key with the configured one."""
                try:
                    candidate_bytes = candidate.encode("utf-8")
                except UnicodeEncodeError:
                    # a value that cannot be UTF-8 encoded cannot equal the
                    # UTF-8 encoded configured key
                    return False
                # constant-time comparison: '==' would leak key prefixes via timing
                return hmac.compare_digest(candidate_bytes, expected_key)

            @web.middleware
            async def check_token(request, handler):
                header_admin_api_key = request.headers.get("x-api-key")
                if header_admin_api_key and api_key_matches(header_admin_api_key):
                    return await handler(request)
                raise web.HTTPUnauthorized()

            middlewares.append(check_token)

        stats: Collector = await self.context.inject(Collector, required=False)
        if stats:

            @web.middleware
            async def collect_stats(request, handler):
                handler = stats.wrap_coro(
                    handler, [handler.__qualname__, "any-admin-request"]
                )
                return await handler(request)

            middlewares.append(collect_stats)

        app = web.Application(middlewares=middlewares)
        app["request_context"] = self.context
        app["outbound_message_router"] = self.responder.send

        app.add_routes(
            [
                web.get("/", self.redirect_handler),
                web.get("/modules", self.modules_handler),
                web.get("/status", self.status_handler),
                web.post("/status/reset", self.status_reset_handler),
                web.get("/ws", self.websocket_handler),
            ]
        )
        await register_module_routes(app)

        for protocol_module_path in self.context.settings.get("external_protocols", []):
            try:
                routes_module = ClassLoader.load_module(
                    f"{protocol_module_path}.routes"
                )
                await routes_module.register(app)
            except Exception as e:
                raise ConfigError(
                    f"Failed to load external protocol module '{protocol_module_path}'."
                ) from e

        cors = aiohttp_cors.setup(
            app,
            defaults={
                "*": aiohttp_cors.ResourceOptions(
                    allow_credentials=True,
                    expose_headers="*",
                    allow_headers="*",
                    allow_methods="*",
                )
            },
        )
        for route in app.router.routes():
            cors.add(route)
        setup_aiohttp_apispec(
            app=app, title="Aries Cloud Agent", version="v1", swagger_path="/api/doc"
        )
        app.on_startup.append(self.on_startup)
        return app

    async def start(self) -> None:
        """
        Start the webserver.

        Raises:
            AdminSetupError: If there was an error starting the webserver

        """
        self.app = await self.make_application()
        runner = web.AppRunner(self.app)
        await runner.setup()
        self.site = web.TCPSite(runner, host=self.host, port=self.port)

        try:
            await self.site.start()
        except OSError as err:
            # chain the OSError so the real cause (e.g. port in use) is kept
            raise AdminSetupError(
                "Unable to start webserver with host "
                + f"'{self.host}' and port '{self.port}'\n"
            ) from err

    async def stop(self) -> None:
        """Stop the webserver, websocket queues and webhook machinery."""
        for queue in self.websocket_queues.values():
            queue.stop()
        if self.site:
            await self.site.stop()
            self.site = None
        if self.webhook_queue:
            self.webhook_queue.stop()
            self.webhook_queue = None
        if self.webhook_session:
            await self.webhook_session.close()
            self.webhook_session = None

    async def on_startup(self, app: web.Application):
        """Perform webserver startup actions."""

    @docs(tags=["server"], summary="Fetch the list of loaded modules")
    @response_schema(AdminModulesSchema(), 200)
    async def modules_handler(self, request: web.BaseRequest):
        """
        Request handler for the loaded modules list.

        Args:
            request: aiohttp request object

        Returns:
            The module list response

        """
        return web.json_response({"result": self.loaded_modules})

    @docs(tags=["server"], summary="Fetch the server status")
    @response_schema(AdminStatusSchema(), 200)
    async def status_handler(self, request: web.BaseRequest):
        """
        Request handler for the server status information.

        Args:
            request: aiohttp request object

        Returns:
            The web response; includes "timing" only when a Collector is bound

        """
        status = {}
        collector: Collector = await self.context.inject(Collector, required=False)
        if collector:
            status["timing"] = collector.results
        return web.json_response(status)

    @docs(tags=["server"], summary="Reset statistics")
    @response_schema(AdminStatusSchema(), 200)
    async def status_reset_handler(self, request: web.BaseRequest):
        """
        Request handler for resetting the timing statistics.

        Args:
            request: aiohttp request object

        Returns:
            The web response (an empty JSON object)

        """
        collector: Collector = await self.context.inject(Collector, required=False)
        if collector:
            collector.reset()
        return web.json_response({})

    async def redirect_handler(self, request: web.BaseRequest):
        """Perform redirect to documentation."""
        raise web.HTTPFound("/api/doc")

    async def websocket_handler(self, request):
        """Send notifications to admin client over websocket."""

        ws = web.WebSocketResponse()
        await ws.prepare(request)
        socket_id = str(uuid.uuid4())
        queue = await self.context.inject(BaseOutboundMessageQueue)

        try:
            self.websocket_queues[socket_id] = queue
            # first message tells the client about this agent's settings
            await queue.enqueue(
                {
                    "topic": "settings",
                    "payload": {
                        "label": self.context.settings.get("default_label"),
                        "endpoint": self.context.settings.get("default_endpoint"),
                        "no_receive_invites": self.context.settings.get(
                            "admin.no_receive_invites", False
                        ),
                        "help_link": self.context.settings.get("admin.help_link"),
                    },
                }
            )

            closed = False
            while not closed:
                try:
                    msg = await queue.dequeue(timeout=5.0)
                    if msg is None:
                        # we send fake pings because the JS client
                        # can't detect real ones
                        msg = {"topic": "ping"}
                    if ws.closed:
                        closed = True
                    if msg and not closed:
                        await ws.send_json(msg)
                except asyncio.CancelledError:
                    closed = True

        finally:
            del self.websocket_queues[socket_id]

        return ws

    def add_webhook_target(
        self, target_url: str, topic_filter: Sequence[str] = None, retries: int = None
    ):
        """Add a webhook target, replacing any existing one for the same URL."""
        self.webhook_targets[target_url] = WebhookTarget(
            target_url, topic_filter, retries
        )

    def remove_webhook_target(self, target_url: str):
        """Remove a webhook target (no-op if it is not registered)."""
        if target_url in self.webhook_targets:
            del self.webhook_targets[target_url]

    async def send_webhook(self, topic: str, payload: dict):
        """Add a webhook to the queue, to send to all registered targets."""
        if not self.webhook_queue:
            # lazily create the queue and start the background dispatcher
            self.webhook_queue = await self.context.inject(BaseOutboundMessageQueue)
            self.webhook_task = asyncio.ensure_future(self._process_webhooks())
        await self.webhook_queue.enqueue((topic, payload))

    async def _process_webhooks(self):
        """Continuously poll webhook queue and dispatch to targets.

        Each event is forwarded to every open websocket queue. Events other
        than ``connections_activity`` are also posted, with retries, to each
        registered webhook target whose topic filter accepts the topic.
        """
        # Bind locally: stop() may reset self.webhook_queue to None (and a
        # later send_webhook may install a new queue) while this task is
        # still draining the queue it started with.
        webhook_queue = self.webhook_queue
        self.webhook_session = ClientSession()
        processor = TaskProcessor(max_pending=5)
        self.webhook_processor = processor
        async for topic, payload in webhook_queue:
            # iterate a copy: websockets may connect/disconnect during awaits
            for queue in list(self.websocket_queues.values()):
                await queue.enqueue({"topic": topic, "payload": payload})

            # connections activity is only sent to websockets, not to targets
            if topic != "connections_activity":
                # iterate a copy: targets may be added/removed during awaits
                for target in list(self.webhook_targets.values()):
                    if target.topic_filter and topic not in target.topic_filter:
                        continue
                    retries = (
                        self.webhook_retries
                        if target.retries is None
                        else target.retries
                    )
                    await processor.run_retry(
                        self._make_webhook_sender(target.endpoint, topic, payload),
                        ident=(target.endpoint, topic),
                        retries=retries,
                    )
            webhook_queue.task_done()

    def _make_webhook_sender(self, endpoint: str, topic: str, payload: dict):
        """Return a retry callback that posts one webhook to one endpoint.

        Building the callback in its own scope freezes ``endpoint``, ``topic``
        and ``payload``. An inline lambda in the dispatch loop would instead
        read whatever the loop variables hold when the task processor calls
        it, which can be after the loop has moved on (e.g. on a retry).
        """
        return lambda pending: self._perform_send_webhook(
            endpoint, topic, payload, pending.attempts + 1
        )

    async def _perform_send_webhook(
        self, target_url: str, topic: str, payload: dict, attempt: int = None
    ):
        """Dispatch a webhook to a specific endpoint.

        Args:
            target_url: Base URL of the webhook target
            topic: Webhook topic, posted to ``<target_url>/topic/<topic>/``
            payload: JSON-serializable request body
            attempt: Attempt number, used only for logging

        Raises:
            Exception: If the endpoint responds with a non-2xx status
        """
        full_webhook_url = f"{target_url}/topic/{topic}/"
        attempt_str = f" (attempt {attempt})" if attempt else ""
        LOGGER.debug("Sending webhook to : %s%s", full_webhook_url, attempt_str)
        async with self.webhook_session.post(
            full_webhook_url, json=payload
        ) as response:
            if not 200 <= response.status <= 299:
                raise Exception(f"Unexpected response status {response.status}")

    async def complete_webhooks(self):
        """Wait for all pending webhooks to be dispatched, used in testing."""
        if self.webhook_queue:
            await self.webhook_queue.join()
            self.webhook_queue.stop()
        if self.webhook_processor:
            await self.webhook_processor.wait_done()