"""Handle registration and publication of supported protocols."""
import logging
import re
from typing import Mapping, Sequence
from ..config.injection_context import InjectionContext
from ..utils.classloader import ClassLoader
from .error import ProtocolMinorVersionNotSupported, ProtocolDefinitionValidationError
LOGGER = logging.getLogger(__name__)
[docs]class ProtocolRegistry:
"""Protocol registry for indexing message families."""
def __init__(self):
"""Initialize a `ProtocolRegistry` instance."""
self._controllers = {}
self._typemap = {}
self._versionmap = {}
@property
def protocols(self) -> Sequence[str]:
"""Accessor for a list of all message protocols."""
prots = set()
for message_type in self._typemap.keys():
pos = message_type.rfind("/")
if pos > 0:
family = message_type[:pos]
prots.add(family)
return prots
@property
def message_types(self) -> Sequence[str]:
"""Accessor for a list of all message types."""
return tuple(self._typemap.keys())
@property
def controllers(self) -> Mapping[str, str]:
"""Accessor for a list of all protocol controller functions."""
return self._controllers.copy()
[docs] def protocols_matching_query(self, query: str) -> Sequence[str]:
"""Return a list of message protocols matching a query string."""
all_types = self.protocols
result = None
if query == "*" or query is None:
result = all_types
elif query:
if query.endswith("*"):
match = query[:-1]
result = tuple(k for k in all_types if k.startswith(match))
elif query in all_types:
result = (query,)
return result or ()
[docs] def parse_type_string(self, message_type):
"""Parse message type string and return dict with info."""
tokens = message_type.split("/")
protocol_name = tokens[-3]
version_string = tokens[-2]
message_name = tokens[-1]
version_string_tokens = version_string.split(".")
assert len(version_string_tokens) == 2
return {
"protocol_name": protocol_name,
"message_name": message_name,
"major_version": int(version_string_tokens[0]),
"minor_version": int(version_string_tokens[1]),
}
[docs] def create_msg_types_for_minor_version(self, typesets, version_definition):
"""
Return mapping of message type to module path for minor versions.
Args:
typesets: Mappings of message types to register
version_definition: Optional version definition dict
Returns:
Typesets mapping
"""
updated_typeset = {}
curr_minor_version = version_definition["current_minor_version"]
min_minor_version = version_definition["minimum_minor_version"]
major_version = version_definition["major_version"]
if curr_minor_version >= min_minor_version:
for version_index in range(min_minor_version, curr_minor_version + 1):
to_check = f"{str(major_version)}.{str(version_index)}"
updated_typeset.update(
self._get_updated_typeset_dict(typesets, to_check, updated_typeset)
)
else:
raise ProtocolDefinitionValidationError(
"min_minor_version is greater than curr_minor_version for the"
f" following typeset: {str(typesets)}"
)
return (updated_typeset,)
def _get_updated_typeset_dict(self, typesets, to_check, updated_typeset) -> dict:
for typeset in typesets:
for msg_type_string, module_path in typeset.items():
updated_msg_type_string = re.sub(
r"(\d+\.)?(\*|\d+)", to_check, msg_type_string
)
updated_typeset[updated_msg_type_string] = module_path
return updated_typeset
def _message_type_check_for_minor_verssion(self, version_definition) -> bool:
if not version_definition:
return False
curr_minor_version = version_definition["current_minor_version"]
min_minor_version = version_definition["minimum_minor_version"]
return bool(curr_minor_version >= 1 and curr_minor_version >= min_minor_version)
def _create_and_register_updated_typesets(self, typesets, version_definition):
updated_typesets = self.create_msg_types_for_minor_version(
typesets, version_definition
)
update_flag = False
for typeset in updated_typesets:
if typeset:
self._typemap.update(typeset)
update_flag = True
if update_flag:
return updated_typesets
else:
return None
def _update_version_map(self, message_type_string, module_path, version_definition):
parsed_type_string = self.parse_type_string(message_type_string)
if version_definition["major_version"] not in self._versionmap:
self._versionmap[version_definition["major_version"]] = []
self._versionmap[version_definition["major_version"]].append(
{
"parsed_type_string": parsed_type_string,
"version_definition": version_definition,
"message_module": module_path,
}
)
[docs] def register_message_types(self, *typesets, version_definition=None):
"""
Add new supported message types.
Args:
typesets: Mappings of message types to register
version_definition: Optional version definition dict
"""
# Maintain support for versionless protocol modules
updated_typesets = None
minor_versions_supported = self._message_type_check_for_minor_verssion(
version_definition
)
if not minor_versions_supported:
for typeset in typesets:
self._typemap.update(typeset)
# Track versioned modules for version routing
if version_definition:
# create updated typesets for minor versions and register them
if minor_versions_supported:
updated_typesets = self._create_and_register_updated_typesets(
typesets, version_definition
)
if updated_typesets:
typesets = updated_typesets
for typeset in typesets:
for message_type_string, module_path in typeset.items():
self._update_version_map(
message_type_string, module_path, version_definition
)
[docs] def register_controllers(self, *controller_sets, version_definition=None):
"""
Add new controllers.
Args:
controller_sets: Mappings of message families to coroutines
"""
for controlset in controller_sets:
self._controllers.update(controlset)
[docs] def resolve_message_class(self, message_type: str) -> type:
"""
Resolve a message_type to a message class.
Given a message type identifier, this method
returns the corresponding registered message class.
Args:
message_type: Message type to resolve
Returns:
The resolved message class
"""
# Try and retrieve from direct mapping
msg_cls = self._typemap.get(message_type)
if isinstance(msg_cls, str):
return ClassLoader.load_class(msg_cls)
# Support registered modules (not path as string)
elif msg_cls:
return msg_cls
# Try and route via min/maj version matching
if not msg_cls:
parsed_type_string = self.parse_type_string(message_type)
major_version = parsed_type_string["major_version"]
version_supported_protos = self._versionmap.get(major_version)
if not version_supported_protos:
return None
for proto in version_supported_protos:
if (
proto["parsed_type_string"]["protocol_name"]
== parsed_type_string["protocol_name"]
and proto["parsed_type_string"]["message_name"]
== parsed_type_string["message_name"]
):
if (
parsed_type_string["minor_version"]
< proto["version_definition"]["minimum_minor_version"]
):
raise ProtocolMinorVersionNotSupported(
"Minimum supported minor version is "
+ f"{proto['version_definition']['minimum_minor_version']}."
+ f" Received {parsed_type_string['minor_version']}."
)
if isinstance(proto["message_module"], str):
return ClassLoader.load_class(msg_cls)
elif proto["message_module"]:
return proto["message_module"]
return None
[docs] async def prepare_disclosed(
self, context: InjectionContext, protocols: Sequence[str]
):
"""Call controllers and return publicly supported message families and roles."""
published = []
for protocol in protocols:
result = {"pid": protocol}
if protocol in self._controllers:
ctl_cls = self._controllers[protocol]
if isinstance(ctl_cls, str):
ctl_cls = ClassLoader.load_class(ctl_cls)
ctl_instance = ctl_cls(protocol)
if hasattr(ctl_instance, "check_access"):
allowed = await ctl_instance.check_access(context)
if not allowed:
# remove from published
continue
if hasattr(ctl_instance, "determine_roles"):
roles = await ctl_instance.determine_roles(context)
if roles:
result["roles"] = list(roles)
published.append(result)
return published
def __repr__(self) -> str:
"""Return a string representation for this class."""
return "<{}>".format(self.__class__.__name__)