# Source code for aries_cloudagent.core.protocol_registry

"""Handle registration and publication of supported protocols."""

import logging

from typing import Mapping, Sequence

from ..config.injection_context import InjectionContext
from ..utils.classloader import ClassLoader
from .error import ProtocolMinorVersionNotSupported

# Module-level logger named after this module (not referenced by the code visible here).
LOGGER = logging.getLogger(__name__)


class ProtocolRegistry:
    """Protocol registry for indexing message families."""

    def __init__(self):
        """Initialize a `ProtocolRegistry` instance."""
        # Message family (protocol) -> controller class or dotted class path
        self._controllers = {}
        # Full message type string -> message class or dotted class path
        self._typemap = {}
        # Major version -> list of entries describing versioned message types,
        # filled by `register_message_types` when a version definition is given
        self._versionmap = {}

    @property
    def protocols(self) -> Sequence[str]:
        """Accessor for all registered message protocols.

        A protocol is the portion of a registered message type that precedes
        its last ``/``. The value is returned as a ``set``.
        """
        prots = set()
        for message_type in self._typemap:
            pos = message_type.rfind("/")
            # pos > 0 skips types without a "/" (or with only a leading one)
            if pos > 0:
                prots.add(message_type[:pos])
        return prots

    @property
    def message_types(self) -> Sequence[str]:
        """Accessor for a tuple of all registered message types."""
        return tuple(self._typemap.keys())

    @property
    def controllers(self) -> Mapping[str, str]:
        """Accessor for a shallow copy of the protocol -> controller mapping."""
        return self._controllers.copy()

    def protocols_matching_query(self, query: str) -> Sequence[str]:
        """Return the message protocols matching a query string.

        Args:
            query: ``"*"`` or ``None`` to match every protocol, a string ending
                in ``*`` to match by prefix, or an exact protocol name

        Returns:
            The matching protocols, or an empty tuple when nothing matches

        """
        all_types = self.protocols
        result = None
        if query == "*" or query is None:
            result = all_types
        elif query:
            if query.endswith("*"):
                match = query[:-1]
                result = tuple(k for k in all_types if k.startswith(match))
            elif query in all_types:
                result = (query,)
        return result or ()

    def parse_type_string(self, message_type: str) -> dict:
        """Parse a message type string into its components.

        The last three ``/``-separated segments are interpreted as the
        protocol name, a ``MAJOR.MINOR`` version, and the message name.

        Args:
            message_type: Message type string to parse

        Returns:
            Dict with ``protocol_name``, ``message_name``, ``major_version``
            and ``minor_version`` (versions as ints)

        Raises:
            ValueError: If there are fewer than three segments, or the version
                is not two integers separated by a single dot

        """
        tokens = message_type.split("/")
        if len(tokens) < 3:
            raise ValueError(f"Malformed message type: {message_type}")
        protocol_name, version_string, message_name = tokens[-3:]

        version_string_tokens = version_string.split(".")
        # Explicit check rather than `assert`, which is stripped under `python -O`
        if len(version_string_tokens) != 2:
            raise ValueError(f"Malformed message type version: {version_string}")

        return {
            "protocol_name": protocol_name,
            "message_name": message_name,
            # int() raises ValueError for non-numeric version parts
            "major_version": int(version_string_tokens[0]),
            "minor_version": int(version_string_tokens[1]),
        }

    def register_message_types(self, *typesets, version_definition=None):
        """
        Add new supported message types.

        Args:
            typesets: Mappings of message types to register
            version_definition: Optional version definition dict; when given,
                each registered type is also indexed under
                ``version_definition["major_version"]`` for version routing

        """
        # Maintain support for versionless protocol modules
        for typeset in typesets:
            self._typemap.update(typeset)

        # Track versioned modules for version routing
        if version_definition:
            for typeset in typesets:
                for message_type_string, module_path in typeset.items():
                    parsed_type_string = self.parse_type_string(message_type_string)
                    self._versionmap.setdefault(
                        version_definition["major_version"], []
                    ).append(
                        {
                            "parsed_type_string": parsed_type_string,
                            "version_definition": version_definition,
                            "message_module": module_path,
                        }
                    )

    def register_controllers(self, *controller_sets, version_definition=None):
        """
        Add new controllers.

        Args:
            controller_sets: Mappings of message families to controllers
                (classes or dotted class paths)
            version_definition: Accepted for signature symmetry with
                `register_message_types`; currently unused

        """
        for controlset in controller_sets:
            self._controllers.update(controlset)

    def resolve_message_class(self, message_type: str) -> type:
        """
        Resolve a message_type to a message class.

        Given a message type identifier, this method returns the corresponding
        registered message class. An exact registration wins; otherwise the
        type is routed by protocol name, message name and major version to a
        versioned registration.

        Args:
            message_type: Message type to resolve

        Returns:
            The resolved message class, or None if the type is unknown or
            cannot be parsed

        Raises:
            ProtocolMinorVersionNotSupported: If a versioned match exists but
                the requested minor version is below the minimum supported

        """
        # Try and retrieve from direct mapping
        msg_cls = self._typemap.get(message_type)
        if isinstance(msg_cls, str):
            return ClassLoader.load_class(msg_cls)
        # Support registered modules (not path as string)
        if msg_cls:
            return msg_cls

        # Route via major/minor version matching; a malformed type string
        # simply cannot be resolved
        try:
            parsed_type_string = self.parse_type_string(message_type)
        except ValueError:
            return None

        version_supported_protos = self._versionmap.get(
            parsed_type_string["major_version"]
        )
        if not version_supported_protos:
            return None

        for proto in version_supported_protos:
            registered = proto["parsed_type_string"]
            if (
                registered["protocol_name"] != parsed_type_string["protocol_name"]
                or registered["message_name"] != parsed_type_string["message_name"]
            ):
                continue

            min_minor = proto["version_definition"]["minimum_minor_version"]
            if parsed_type_string["minor_version"] < min_minor:
                raise ProtocolMinorVersionNotSupported(
                    "Minimum supported minor version is "
                    + f"{proto['version_definition']['minimum_minor_version']}."
                    + f" Received {parsed_type_string['minor_version']}."
                )

            message_module = proto["message_module"]
            if isinstance(message_module, str):
                # Load the registered dotted path (previously loaded `msg_cls`,
                # which is always None here)
                return ClassLoader.load_class(message_module)
            if message_module:
                return message_module

        return None

    async def prepare_disclosed(
        self, context: "InjectionContext", protocols: Sequence[str]
    ):
        """Call controllers and return publicly supported message families and roles.

        Args:
            context: Injection context passed to controller hooks
            protocols: Protocol identifiers to consider

        Returns:
            A list of ``{"pid": protocol}`` dicts, with ``"roles"`` added when
            the controller reports any. Protocols whose controller's
            ``check_access`` returns a falsy value are omitted.

        """
        published = []
        for protocol in protocols:
            result = {"pid": protocol}
            if protocol in self._controllers:
                ctl_cls = self._controllers[protocol]
                if isinstance(ctl_cls, str):
                    ctl_cls = ClassLoader.load_class(ctl_cls)
                ctl_instance = ctl_cls(protocol)
                if hasattr(ctl_instance, "check_access"):
                    allowed = await ctl_instance.check_access(context)
                    if not allowed:
                        # remove from published
                        continue
                if hasattr(ctl_instance, "determine_roles"):
                    roles = await ctl_instance.determine_roles(context)
                    if roles:
                        result["roles"] = list(roles)
            published.append(result)
        return published

    def __repr__(self) -> str:
        """Return a string representation for this class."""
        return "<{}>".format(self.__class__.__name__)