Source code for aries_cloudagent.core.protocol_registry

"""Handle registration and publication of supported protocols."""

from dataclasses import dataclass
import logging

from typing import Any, Dict, Mapping, Optional, Sequence, Union

from ..config.injection_context import InjectionContext
from ..utils.classloader import ClassLoader, DeferLoad
from ..messaging.message_type import MessageType, MessageVersion, ProtocolIdentifier

from .error import ProtocolMinorVersionNotSupported, ProtocolDefinitionValidationError

# Module-level logger, named after this module via ``__name__``.
LOGGER = logging.getLogger(__name__)


@dataclass
class VersionDefinition:
    """Minimum and current supported versions of a protocol."""

    min: MessageVersion  # lowest supported version
    current: MessageVersion  # highest (current) supported version

    @classmethod
    def from_dict(cls, data: dict) -> "VersionDefinition":
        """Build a version definition from a plain dict.

        Args:
            data: Dict providing ``major_version``, ``minimum_minor_version``
                and ``current_minor_version`` entries.

        Returns:
            A ``VersionDefinition`` whose ``min`` and ``current`` share the
            given major version and use the minimum / current minor versions.

        Raises:
            KeyError: If any of the required keys is missing.
        """
        major = data["major_version"]
        lowest = MessageVersion(major, data["minimum_minor_version"])
        highest = MessageVersion(major, data["current_minor_version"])
        return cls(min=lowest, current=highest)
@dataclass
class ProtocolDefinition:
    """Protocol metadata used to register and resolve message types.

    Construction validates that ``min`` and ``current`` share a major version
    and that ``min`` is not newer than ``current``.
    """

    ident: ProtocolIdentifier  # protocol identifier this definition describes
    min: MessageVersion  # lowest supported version
    current: MessageVersion  # highest (current) supported version
    # optional controller reference; not read by any code in this class
    controller: Optional[str] = None

    @property
    def minor_versions_supported(self) -> bool:
        """Report whether minor-version expansion applies.

        True when the current minor version is at least 1 and is not below
        the minimum minor version.
        """
        cur_minor = self.current.minor
        return bool(cur_minor >= 1 and cur_minor >= self.min.minor)

    def __post_init__(self):
        """Validate the version range right after dataclass initialization.

        Raises:
            ProtocolDefinitionValidationError: If the major versions differ or
                the minimum minor version exceeds the current minor version.
        """
        lo, hi = self.min, self.current
        if lo.major != hi.major:
            raise ProtocolDefinitionValidationError(
                f"Major version mismatch: {lo.major} != {hi.major}"
            )
        if lo.minor > hi.minor:
            raise ProtocolDefinitionValidationError(
                "Minimum minor version greater than current minor version: "
                f"{lo.minor} > {hi.minor}"
            )
class ProtocolRegistry:
    """Protocol registry for indexing message families.

    The registry tracks:

    * protocol definitions, keyed by ``ProtocolIdentifier.stem``;
    * message type strings mapped to message classes (or ``DeferLoad``
      wrappers for dotted class paths);
    * protocol controllers, consulted by ``prepare_disclosed`` to filter and
      annotate the protocols that are disclosed.
    """

    def __init__(self):
        """Initialize an empty `ProtocolRegistry` instance."""
        self._definitions: Dict[str, ProtocolDefinition] = {}
        self._type_to_message_cls: Dict[str, Union[DeferLoad, type]] = {}
        # Mapping[protocol identifier, controller class or its dotted class path]
        self._controllers = {}

    @property
    def protocols(self) -> Sequence[str]:
        """Accessor for a list of all message protocols.

        Each definition contributes one protocol identifier string per minor
        version, from its minimum through its current minor version inclusive.
        """
        return [
            str(definition.ident.with_version((definition.min.major, minor)))
            for definition in self._definitions.values()
            for minor in range(definition.min.minor, definition.current.minor + 1)
        ]

    @property
    def message_types(self) -> Sequence[str]:
        """Accessor for a tuple of all registered message type strings."""
        return tuple(self._type_to_message_cls.keys())

    def protocols_matching_query(self, query: Optional[str]) -> Sequence[str]:
        """Return the message protocols matching a query string.

        Args:
            query: ``"*"`` or ``None`` to match every protocol; a string ending
                in ``"*"`` to match protocols by prefix; any other non-empty
                string to match one protocol identifier exactly.

        Returns:
            The matching protocol identifiers, or an empty tuple when nothing
            matches (including for an empty query string).
        """
        all_protocols = self.protocols
        if query is None or query == "*":
            return all_protocols or ()
        if not query:
            return ()
        if query.endswith("*"):
            prefix = query[:-1]
            return tuple(p for p in all_protocols if p.startswith(prefix))
        if query in all_protocols:
            return (query,)
        return ()

    def register_message_types(
        self,
        typeset: Mapping[str, Union[str, type]],
        version_definition: Optional[Union[dict[str, Any], VersionDefinition]] = None,
    ):
        """Add new supported message types.

        A protocol definition is created the first time a protocol stem is
        seen and reused afterwards, so ``version_definition`` only takes effect
        for definitions created by this call. When a definition supports minor
        versions, every message type is also registered under each minor
        version from the definition's minimum through its current one.
        All changes are collected first and applied to the registry only after
        the whole ``typeset`` has been processed.

        Args:
            typeset: Mapping of message type strings to message classes or
                dotted class paths (the latter are wrapped in ``DeferLoad``)
            version_definition: Optional version range, either a
                ``VersionDefinition`` or a dict accepted by
                ``VersionDefinition.from_dict``
        """
        # isinstance() already rejects None, so no separate None check is needed.
        if isinstance(version_definition, dict):
            version_definition = VersionDefinition.from_dict(version_definition)

        definitions_to_add: Dict[str, ProtocolDefinition] = {}
        type_to_message_cls_to_add: Dict[str, Union[DeferLoad, type]] = {}

        for message_type, message_cls in typeset.items():
            parsed = MessageType.from_str(message_type)
            protocol = ProtocolIdentifier.from_message_type(parsed)

            if protocol.stem in definitions_to_add:
                definition = definitions_to_add[protocol.stem]
            elif protocol.stem in self._definitions:
                definition = self._definitions[protocol.stem]
            else:
                if version_definition is not None:
                    min_version = version_definition.min
                    current_version = version_definition.current
                else:
                    # No explicit range: the protocol supports only the
                    # version found in this message type.
                    min_version = current_version = protocol.version
                definition = ProtocolDefinition(
                    ident=protocol, min=min_version, current=current_version
                )
                definitions_to_add[protocol.stem] = definition

            if isinstance(message_cls, str):
                message_cls = DeferLoad(message_cls)

            type_to_message_cls_to_add[message_type] = message_cls

            if definition.minor_versions_supported:
                # Register under every supported minor version so that lookups
                # for any of them resolve with a direct dict hit.
                for minor_version in range(
                    definition.min.minor, definition.current.minor + 1
                ):
                    updated_type = parsed.with_version(
                        (parsed.version.major, minor_version)
                    )
                    type_to_message_cls_to_add[str(updated_type)] = message_cls

        self._type_to_message_cls.update(type_to_message_cls_to_add)
        self._definitions.update(definitions_to_add)

    def register_controllers(self, *controller_sets):
        """Add new controllers.

        Entries from later mappings replace earlier ones for the same protocol.

        Args:
            controller_sets: Mappings of protocol identifiers to controller
                classes or their dotted class paths
        """
        for controller_set in controller_sets:
            self._controllers.update(controller_set)

    def resolve_message_class(
        self, message_type: str
    ) -> Optional[Union[DeferLoad, type]]:
        """Resolve a message_type to a message class.

        An exact match on the registered type string is returned first.
        Otherwise, if the message's protocol is known, the type is re-targeted
        to that protocol's current version and looked up again.

        Args:
            message_type: Message type to resolve

        Returns:
            The resolved message class (possibly a ``DeferLoad``), or ``None``
            if no registered class matches

        Raises:
            ProtocolMinorVersionNotSupported: If the protocol is known but the
                requested minor version is below its minimum minor version
        """
        if (message_cls := self._type_to_message_cls.get(message_type)) is not None:
            return message_cls

        parsed = MessageType.from_str(message_type)
        protocol = ProtocolIdentifier.from_message_type(parsed)
        if definition := self._definitions.get(protocol.stem):
            if parsed.version.minor < definition.min.minor:
                raise ProtocolMinorVersionNotSupported(
                    f"Minimum supported minor version is {definition.min.minor}."
                    f" Received {parsed.version.minor}."
                )

            # Types for supported minor versions are normally registered
            # verbatim, so reaching this point typically means the requested
            # minor version is newer than the current one, or the message name
            # is not registered under this protocol. Fall back to the same
            # message at the protocol's current version.
            current_type = str(parsed.with_version(definition.current))
            message_cls = self._type_to_message_cls.get(current_type)
            if message_cls is not None:
                return message_cls

        return None

    async def prepare_disclosed(
        self, context: InjectionContext, protocols: Sequence[str]
    ):
        """Call controllers and return publicly supported message families and roles.

        For each protocol that has a registered controller, the controller
        class (loaded first if registered as a dotted path) is instantiated
        with the protocol identifier. If the controller has ``check_access``
        and it returns a falsy value, the protocol is left out. If it has
        ``determine_roles`` and that returns a non-empty value, the roles are
        attached.

        Args:
            context: Injection context handed to the controller hooks
            protocols: Protocol identifiers to consider for disclosure

        Returns:
            A list of dicts, each with a ``"pid"`` key and, when available,
            a ``"roles"`` list
        """
        published = []
        for protocol in protocols:
            result: Dict[str, Any] = {"pid": protocol}
            if protocol in self._controllers:
                ctl_cls = self._controllers[protocol]
                if isinstance(ctl_cls, str):
                    ctl_cls = ClassLoader.load_class(ctl_cls)
                ctl_instance = ctl_cls(protocol)
                if hasattr(ctl_instance, "check_access"):
                    allowed = await ctl_instance.check_access(context)
                    if not allowed:
                        # Controller denied access: do not publish this protocol
                        continue
                if hasattr(ctl_instance, "determine_roles"):
                    roles = await ctl_instance.determine_roles(context)
                    if roles:
                        result["roles"] = list(roles)
            published.append(result)
        return published

    def __repr__(self) -> str:
        """Return a string representation for this class."""
        return f"<{self.__class__.__name__}>"