"""Base classes for Models and Schemas."""
import logging
import json
from abc import ABC
from collections import namedtuple
from typing import Mapping, Optional, Type, TypeVar, Union, cast, overload
from typing_extensions import Literal
from marshmallow import Schema, post_dump, pre_load, post_load, ValidationError, EXCLUDE
from ...core.error import BaseError
from ...utils.classloader import ClassLoader
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

# Pair holding the serialized ("ser") and deserialized ("de") forms of one object.
SerDe = namedtuple("SerDe", ["ser", "de"])
def resolve_class(the_cls, relative_cls: Optional[type] = None) -> type:
    """Resolve a class.

    Args:
        the_cls: The class to resolve, either a class object (returned as is)
            or a string that is handed to ``ClassLoader.load_class``
        relative_cls: Relative class to resolve from; when given, its
            ``__module__`` is passed to the class loader as the default module

    Returns:
        The resolved class

    Raises:
        ClassNotFoundError: If the class could not be loaded
        TypeError: If ``the_cls`` is neither a class nor a string

    """
    if isinstance(the_cls, type):
        return the_cls

    if isinstance(the_cls, str):
        # Compare against None explicitly: a class whose metaclass makes it
        # falsy must still contribute its module, not be passed through itself.
        default_module = None if relative_cls is None else relative_cls.__module__
        return ClassLoader.load_class(the_cls, default_module)

    raise TypeError(
        f"Could not resolve class from {the_cls}; incorrect type {type(the_cls)}"
    )
class BaseModelError(BaseError):
    """Base exception class for base model errors.

    Raised by ``BaseModel`` when JSON parsing, schema loading or schema
    dumping fails.
    """
# Type variable bound to BaseModel so that classmethods such as
# BaseModel.deserialize can be annotated as returning the subclass they are
# invoked on.
ModelType = TypeVar("ModelType", bound="BaseModel")
class BaseModel(ABC):
    """Base model that provides convenience methods.

    Subclasses are expected to declare an inner ``Meta`` class whose
    ``schema_class`` attribute is (or names) the ``BaseModelSchema`` subclass
    used to serialize and deserialize the model.
    """

    def __init__(self):
        """Initialize BaseModel.

        Raises:
            TypeError: If schema_class is not set on Meta

        """
        # getattr keeps the documented TypeError even when the subclass has no
        # Meta at all, or a Meta that does not define schema_class.
        meta = getattr(self, "Meta", None)
        if not getattr(meta, "schema_class", None):
            raise TypeError(
                "Can't instantiate abstract class {} with no schema_class".format(
                    self.__class__.__name__
                )
            )

    @classmethod
    def _get_schema_class(cls) -> Type["BaseModelSchema"]:
        """Get the schema class.

        Returns:
            The resolved schema class

        Raises:
            TypeError: If the resolved class is not a BaseModelSchema subclass

        """
        resolved = resolve_class(cls.Meta.schema_class, cls)
        if issubclass(resolved, BaseModelSchema):
            return resolved
        raise TypeError(
            f"Resolved class is not a subclass of BaseModelSchema: {resolved}"
        )

    @property
    def Schema(self) -> Type["BaseModelSchema"]:
        """Accessor for the model's schema class.

        Returns:
            The schema class

        """
        return self._get_schema_class()

    @overload
    @classmethod
    def deserialize(
        cls: Type[ModelType],
        obj,
    ) -> ModelType: ...

    @overload
    @classmethod
    def deserialize(
        cls: Type[ModelType],
        obj,
        *,
        unknown: Optional[str] = None,
    ) -> ModelType: ...

    @overload
    @classmethod
    def deserialize(
        cls: Type[ModelType],
        obj,
        *,
        none2none: Literal[False],
        unknown: Optional[str] = None,
    ) -> ModelType: ...

    @overload
    @classmethod
    def deserialize(
        cls: Type[ModelType],
        obj,
        *,
        none2none: Literal[True],
        unknown: Optional[str] = None,
    ) -> Optional[ModelType]: ...

    @classmethod
    def deserialize(
        cls: Type[ModelType],
        obj,
        *,
        unknown: Optional[str] = None,
        none2none: bool = False,
    ) -> Optional[ModelType]:
        """Convert from JSON representation to a model instance.

        Args:
            obj: The dict to load into a model instance; a ``str`` is passed
                to the schema's ``loads`` instead of ``load``
            unknown: Behaviour for unknown attributes; when falsy, the value
                returned by ``resolve_meta_property(schema_cls, "unknown",
                EXCLUDE)`` is used instead
            none2none: Deserialize None to None

        Returns:
            A model instance for this data

        Raises:
            BaseModelError: If loading raises AttributeError or ValidationError

        """
        if obj is None and none2none:
            return None

        schema_cls = cls._get_schema_class()
        schema = schema_cls(
            unknown=unknown or resolve_meta_property(schema_cls, "unknown", EXCLUDE)
        )

        try:
            return cast(
                ModelType,
                schema.loads(obj) if isinstance(obj, str) else schema.load(obj),
            )
        except (AttributeError, ValidationError) as err:
            LOGGER.exception(f"{cls.__name__} message validation error:")
            raise BaseModelError(f"{cls.__name__} schema validation failed") from err

    @overload
    def serialize(
        self,
        *,
        as_string: Literal[True],
        unknown: Optional[str] = None,
    ) -> str: ...

    @overload
    def serialize(
        self,
        *,
        as_string: Literal[False] = False,
        unknown: Optional[str] = None,
    ) -> dict: ...

    def serialize(
        self,
        *,
        as_string: bool = False,
        unknown: Optional[str] = None,
    ) -> Union[str, dict]:
        """Create a JSON-compatible dict representation of the model instance.

        Args:
            as_string: Return a string of JSON instead of a dict
            unknown: Behaviour for unknown attributes; when falsy, the value
                returned by ``resolve_meta_property(schema_cls, "unknown",
                EXCLUDE)`` is used instead

        Returns:
            A dict representation of this model, or a JSON string if as_string is True

        Raises:
            BaseModelError: If dumping raises AttributeError or ValidationError

        """
        schema_cls = self._get_schema_class()
        schema = schema_cls(
            unknown=unknown or resolve_meta_property(schema_cls, "unknown", EXCLUDE)
        )

        try:
            return (
                # Compact separators: no whitespace in the JSON string form.
                schema.dumps(self, separators=(",", ":"))
                if as_string
                else schema.dump(self)
            )
        except (AttributeError, ValidationError) as err:
            LOGGER.exception(f"{self.__class__.__name__} message serialization error:")
            raise BaseModelError(
                f"{self.__class__.__name__} schema validation failed"
            ) from err

    @classmethod
    def serde(cls, obj: Union["BaseModel", Mapping, None]) -> Optional[SerDe]:
        """Return serialized, deserialized representations of input object.

        Args:
            obj: A model instance, a mapping to deserialize, or None

        Returns:
            None for None input; otherwise a SerDe pair whose ``ser`` is the
            serialized form and whose ``de`` is the model instance

        """
        if obj is None:
            return None
        if isinstance(obj, BaseModel):
            return SerDe(obj.serialize(), obj)
        return SerDe(obj, cls.deserialize(obj))

    def validate(self, unknown: Optional[str] = None):
        """Validate a constructed model.

        Args:
            unknown: Behaviour for unknown attributes, passed to the schema

        Returns:
            This model instance, when validation reports no errors

        Raises:
            ValidationError: If the schema reports errors for the serialized model

        """
        schema = self.Schema(unknown=unknown)
        errors = schema.validate(self.serialize())
        if errors:
            raise ValidationError(errors)
        return self

    @classmethod
    def from_json(
        cls: Type[ModelType],
        json_repr: Union[str, bytes],
        unknown: Optional[str] = None,
    ) -> ModelType:
        """Parse a JSON string into a model instance.

        Args:
            json_repr: JSON string
            unknown: Behaviour for unknown attributes, passed to deserialize

        Returns:
            A model instance representation of this JSON

        Raises:
            BaseModelError: If the JSON cannot be parsed or fails schema loading

        """
        try:
            parsed = json.loads(json_repr)
        except ValueError as e:
            LOGGER.exception(f"{cls.__name__} message parse error:")
            raise BaseModelError(f"{cls.__name__} JSON parsing failed") from e
        return cls.deserialize(parsed, unknown=unknown)

    def to_json(self, unknown: Optional[str] = None) -> str:
        """Create a JSON representation of the model instance.

        Args:
            unknown: Behaviour for unknown attributes, passed to serialize

        Returns:
            A JSON representation of this message

        """
        return json.dumps(self.serialize(unknown=unknown))

    def __repr__(self) -> str:
        """Return a human readable representation of this class.

        Instance attributes whose names appear in the ``repr_exclude`` value
        obtained through ``resolve_meta_property`` are left out.

        Returns:
            A human readable string for this class

        """
        exclude = resolve_meta_property(self, "repr_exclude", [])
        items = (
            "{}={}".format(k, repr(v))
            for k, v in self.__dict__.items()
            if k not in exclude
        )
        return "<{}({})>".format(self.__class__.__name__, ", ".join(items))
class BaseModelSchema(Schema):
    """BaseModel schema.

    Subclasses are expected to set ``model_class`` on their inner ``Meta``
    class; it is (or names) the model class built after loading.
    """

    def __init__(self, *args, **kwargs):
        """Initialize BaseModelSchema.

        Raises:
            TypeError: If model_class is not set on Meta

        """
        super().__init__(*args, **kwargs)
        # getattr keeps the documented TypeError when Meta exists but does not
        # define model_class at all.
        if not getattr(self.Meta, "model_class", None):
            raise TypeError(
                "Can't instantiate abstract class {} with no model_class".format(
                    self.__class__.__name__
                )
            )

    @classmethod
    def _get_model_class(cls):
        """Get the model class.

        Returns:
            The model class

        """
        return resolve_class(cls.Meta.model_class, cls)

    @property
    def Model(self) -> type:
        """Accessor for the schema's model class.

        Returns:
            The model class

        """
        return self._get_model_class()

    @pre_load
    def skip_dump_only(self, data, **kwargs):
        """Skip fields that are only expected during serialization.

        Note that the keys are deleted from ``data`` in place, so the object
        passed in by the caller is modified.

        Args:
            data: The incoming data to clean

        Returns:
            The modified data

        """
        if not data:
            return data

        # Dump-only fields are looked up by their external key (data_key) when
        # one is set, otherwise by the field name.
        to_remove = {
            field_obj.data_key or field_name
            for field_name, field_obj in self.fields.items()
            if field_obj.dump_only
        }
        for field_name in to_remove:
            if field_name in data:
                del data[field_name]
        return data

    @post_load
    def make_model(self, data: dict, **kwargs):
        """Return model instance after loading.

        If the model constructor raises a TypeError whose message mentions
        ``_type`` and the loaded data holds a ``_type`` key, that value is
        moved to ``msg_type`` and construction is retried once.

        Args:
            data: The loaded data, used as keyword arguments for the model

        Returns:
            A model instance

        Raises:
            TypeError: If the model cannot be constructed from the data

        """
        try:
            cls_inst = self.Model(**data)
        except TypeError as err:
            if "_type" in str(err) and "_type" in data:
                data["msg_type"] = data.pop("_type")
                cls_inst = self.Model(**data)
            else:
                # Surface the real constructor error instead of falling through
                # to an UnboundLocalError on the return below.
                raise
        return cls_inst

    @post_dump
    def remove_skipped_values(self, data, **kwargs):
        """Remove values that are marked to skip.

        The values to skip come from ``resolve_meta_property(self,
        "skip_values", [])``.

        Args:
            data: The dumped data

        Returns:
            Returns this modified data

        """
        skip_vals = resolve_meta_property(self, "skip_values", [])
        return {key: value for key, value in data.items() if value not in skip_vals}