"""JSON Web Encryption utilities."""
import binascii
import json
from collections import OrderedDict
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
from marshmallow import Schema, ValidationError, fields
from ..wallet.util import b64_to_bytes, bytes_to_b64
# Member names looked up in (and written to) JWE JSON structures by this module.
IDENT_ENC_KEY = "encrypted_key"  # a recipient's encrypted key (also used in flattened form)
IDENT_HEADER = "header"  # a recipient's own header mapping
IDENT_PROTECTED = "protected"  # base64-URL encoded JSON of the protected headers
IDENT_RECIPIENTS = "recipients"  # list of recipient entries
def b64url(value: Union[bytes, str]) -> str:
    """Return the unpadded, URL-safe base64 encoding of ``value``.

    Text input is converted to its UTF-8 bytes before being encoded; bytes
    input is encoded as-is.
    """
    raw = value.encode("utf-8") if isinstance(value, str) else value
    return bytes_to_b64(raw, urlsafe=True, pad=False)
def from_b64url(value: str) -> bytes:
    """Decode an unpadded base64-URL value.

    Args:
        value: The base64-URL encoded text.

    Returns:
        The decoded bytes.

    Raises:
        ValidationError: If the value cannot be decoded.
    """
    try:
        return b64_to_bytes(value, urlsafe=True)
    except ValueError as err:
        # binascii.Error is a ValueError subclass, so malformed base64 is still
        # caught here. NOTE(review): assuming b64_to_bytes wraps the stdlib
        # base64 decoders, non-ASCII text raises a plain ValueError, which is
        # now reported as a validation failure as well.
        raise ValidationError("Error decoding base64 value") from err
class B64Value(fields.Str):
    """A marshmallow-compatible wrapper for base64-URL values.

    Values are ``bytes`` on the Python side and unpadded base64-URL strings in
    the serialized form.
    """

    def _serialize(self, value, attr, obj, **kwargs) -> Optional[str]:
        """Encode a bytes value as unpadded base64-URL text.

        Returns:
            The encoded string, or ``None`` when ``value`` is ``None``.

        Raises:
            TypeError: If ``value`` is neither ``None`` nor ``bytes``.
        """
        if value is None:
            return None
        if not isinstance(value, bytes):
            # Raise (rather than return) the error so that a wrongly typed
            # value cannot end up in the serialized output as an exception
            # object.
            raise TypeError("Expected bytes")
        return b64url(value)

    def _deserialize(self, value, attr, data, **kwargs) -> bytes:
        """Validate the input as a string field, then decode it to bytes.

        Raises:
            ValidationError: If the value is not a valid string or cannot be
                decoded as base64-URL.
        """
        value = super()._deserialize(value, attr, data, **kwargs)
        return from_b64url(value)
class JweSchema(Schema):
    """Marshmallow schema describing a JWE envelope in JSON form.

    The binary members are declared as ``B64Value`` fields; the remaining
    members are plain strings, dicts, or lists of dicts.
    """

    # Headers: ``protected`` stays an (encoded) string at this level.
    protected = fields.Str(required=True)
    unprotected = fields.Dict(required=False)

    # General form: a list with one entry per recipient.
    recipients = fields.List(fields.Dict(), required=False)

    # Encrypted payload and its parameters.
    ciphertext = B64Value(required=True)
    iv = B64Value(required=True)
    tag = B64Value(required=True)
    aad = B64Value(required=False)

    # Flattened form: the single recipient's members appear at the top level.
    header = fields.Dict(required=False)
    encrypted_key = B64Value(required=False)
class JweRecipientSchema(Schema):
    """Marshmallow schema for one entry of a JWE ``recipients`` list."""

    # The recipient's encrypted key; mandatory for every entry.
    encrypted_key = B64Value(required=True)

    # Optional header mapping belonging to this recipient.
    header = fields.Dict(required=False, metadata={"many": True})
class JweRecipient:
    """A single message recipient.

    Attributes:
        encrypted_key: The recipient's encrypted key, as raw bytes.
        header: The recipient's header mapping (an empty dict when none given).
    """

    def __init__(self, *, encrypted_key: bytes, header: Optional[dict] = None) -> None:
        """Initialize the JWE recipient.

        Args:
            encrypted_key: The encrypted key for this recipient.
            header: Optional header mapping for this recipient.
        """
        self.encrypted_key = encrypted_key
        # A missing or empty header is normalised to a fresh empty dict.
        self.header = header or {}

    @classmethod
    def deserialize(cls, entry: Mapping[str, Any]) -> "JweRecipient":
        """Deserialize a JWE recipient from a mapping.

        Args:
            entry: The serialized recipient, validated with JweRecipientSchema.

        Returns:
            A new recipient built from the loaded values.
        """
        vals = JweRecipientSchema().load(entry)
        return cls(**vals)

    def serialize(self) -> dict:
        """Serialize the JWE recipient to a mapping.

        Returns:
            A mapping with the base64-URL encoded ``encrypted_key`` and, only
            when the header is non-empty, the ``header`` mapping.
        """
        ret = OrderedDict([("encrypted_key", b64url(self.encrypted_key))])
        if self.header:
            ret["header"] = self.header
        return ret
class JweEnvelope:
    """JWE envelope instance.

    Holds the protected and unprotected headers, the encrypted payload
    (ciphertext, iv, tag and optional aad) and the list of recipients, and
    converts between this in-memory form and the JSON representation.
    """

    def __init__(
        self,
        *,
        protected: Optional[dict] = None,
        protected_b64: Optional[str] = None,
        unprotected: Optional[dict] = None,
        ciphertext: Optional[bytes] = None,
        iv: Optional[bytes] = None,
        tag: Optional[bytes] = None,
        aad: Optional[bytes] = None,
        with_protected_recipients: bool = False,
        with_flatten_recipients: bool = True,
    ):
        """Initialize a new JWE envelope instance.

        Args:
            protected: The decoded protected headers.
            protected_b64: The base64-URL encoded protected headers (text).
            unprotected: The unprotected headers.
            ciphertext: The encrypted payload.
            iv: The initialization vector (nonce).
            tag: The authentication tag.
            aad: Optional additional authenticated data.
            with_protected_recipients: Store the recipients inside the
                protected headers rather than at the top level.
            with_flatten_recipients: Use the flattened form when there is
                exactly one recipient.
        """
        self.protected = protected
        self.protected_b64 = protected_b64
        self.unprotected = unprotected or OrderedDict()
        self.ciphertext = ciphertext
        self.iv = iv
        self.tag = tag
        self.aad = aad
        self.with_protected_recipients = with_protected_recipients
        self.with_flatten_recipients = with_flatten_recipients
        self._recipients: List["JweRecipient"] = []

    @classmethod
    def from_json(cls, message: Union[bytes, str]) -> "JweEnvelope":
        """Decode a JWE envelope from a JSON string or bytes value.

        Raises:
            ValidationError: If the message is not JSON or is not a valid JWE.
        """
        try:
            parsed = JweSchema().loads(message)
        except json.JSONDecodeError:
            raise ValidationError("Invalid JWE: not JSON") from None
        return cls._deserialize(parsed)

    @classmethod
    def deserialize(cls, message: Mapping[str, Any]) -> "JweEnvelope":
        """Deserialize a JWE envelope from a mapping.

        Raises:
            ValidationError: If the mapping is not a valid JWE.
        """
        return cls._deserialize(JweSchema().load(message))

    @classmethod
    def _deserialize(cls, parsed: Mapping[str, Any]) -> "JweEnvelope":
        """Build an envelope from data already loaded through JweSchema.

        Raises:
            ValidationError: If the protected headers cannot be decoded, a
                header name is duplicated, or the recipient information is
                missing or contradictory.
        """
        protected_b64 = parsed[IDENT_PROTECTED]
        try:
            protected = json.loads(from_b64url(protected_b64))
        except (json.JSONDecodeError, UnicodeDecodeError):
            # json.loads on bytes raises UnicodeDecodeError for undecodable
            # input; report it the same way as malformed JSON.
            raise ValidationError(
                "Invalid JWE: invalid JSON for protected headers"
            ) from None
        if not isinstance(protected, dict):
            # Valid JSON that is not an object (a list, number, ...) would
            # otherwise fail below with an AttributeError.
            raise ValidationError(
                "Invalid JWE: protected headers must be a JSON object"
            )
        unprotected = parsed.get("unprotected") or {}
        if protected.keys() & unprotected.keys():
            raise ValidationError("Invalid JWE: duplicate header")

        recipients = None
        encrypted_key = None
        header = None
        protected_recipients = False
        flat_recipients = False

        # General form: a recipients list, either protected or top-level.
        if IDENT_RECIPIENTS in protected:
            recipients = protected.pop(IDENT_RECIPIENTS)
            if IDENT_RECIPIENTS in parsed:
                raise ValidationError("Invalid JWE: duplicate recipients block")
            protected_recipients = True
        elif IDENT_RECIPIENTS in parsed:
            recipients = parsed[IDENT_RECIPIENTS]

        # Flattened form: a single encrypted_key (plus optional header).
        if IDENT_ENC_KEY in protected:
            # Inside the protected JSON the key is still base64-URL text.
            encrypted_key = from_b64url(protected.pop(IDENT_ENC_KEY))
            header = protected.pop(IDENT_HEADER) if IDENT_HEADER in protected else None
            protected_recipients = True
        elif IDENT_ENC_KEY in parsed:
            # The top-level value was already decoded by the schema.
            encrypted_key = parsed[IDENT_ENC_KEY]
            header = parsed.get(IDENT_HEADER)

        if recipients:
            if encrypted_key:
                raise ValidationError("Invalid JWE: flattened form with 'recipients'")
            recipients = [JweRecipient.deserialize(recip) for recip in recipients]
        elif encrypted_key:
            recipients = [
                JweRecipient(
                    encrypted_key=encrypted_key,
                    header=header,
                )
            ]
            flat_recipients = True
        else:
            raise ValidationError("Invalid JWE: no recipients")

        inst = cls(
            protected=protected,
            protected_b64=protected_b64,
            unprotected=unprotected,
            ciphertext=parsed["ciphertext"],
            iv=parsed.get("iv"),
            tag=parsed["tag"],
            aad=parsed.get("aad"),
            with_protected_recipients=protected_recipients,
            with_flatten_recipients=flat_recipients,
        )
        # A recipient header may not repeat a name used by the outer headers.
        all_h = protected.keys() | unprotected.keys()
        for recip in recipients:
            if recip.header and recip.header.keys() & all_h:
                raise ValidationError("Invalid JWE: duplicate header")
            inst.add_recipient(recip)

        return inst

    def serialize(self) -> dict:
        """Serialize the JWE envelope to a mapping.

        Raises:
            ValidationError: If the protected headers, ciphertext, iv or tag
                have not been set, or if recipients are to be written at the
                top level and there are none.
        """
        if self.protected_b64 is None:
            raise ValidationError("Missing protected: use set_protected")
        if self.ciphertext is None:
            raise ValidationError("Missing ciphertext for JWE")
        if self.iv is None:
            raise ValidationError("Missing iv (nonce) for JWE")
        if self.tag is None:
            raise ValidationError("Missing tag for JWE")
        env = OrderedDict()
        env["protected"] = self.protected_b64
        if self.unprotected:
            env["unprotected"] = self.unprotected.copy()
        # Protected recipients are already part of protected_b64.
        if not self.with_protected_recipients:
            recipients = self.recipients_json
            if self.with_flatten_recipients and len(recipients) == 1:
                for k in recipients[0]:
                    env[k] = recipients[0][k]
            elif recipients:
                env[IDENT_RECIPIENTS] = recipients
            else:
                raise ValidationError("Missing message recipients")
        env["iv"] = b64url(self.iv)
        env["ciphertext"] = b64url(self.ciphertext)
        env["tag"] = b64url(self.tag)
        if self.aad:
            env["aad"] = b64url(self.aad)
        return env

    def to_json(self) -> str:
        """Serialize the JWE envelope to a JSON string."""
        return json.dumps(self.serialize())

    def add_recipient(self, recip: "JweRecipient"):
        """Add a recipient to the JWE envelope."""
        self._recipients.append(recip)

    def set_protected(
        self,
        protected: Mapping[str, Any],
    ):
        """Set the protected headers of the JWE envelope.

        When recipients are stored in the protected headers they are merged
        into the encoded value, so they must be added before calling this.

        Raises:
            ValidationError: If protected recipients are enabled and no
                recipient has been added.
        """
        headers = OrderedDict(protected.items())
        encoded = OrderedDict(headers)
        if self.with_protected_recipients:
            recipients = self.recipients_json
            if self.with_flatten_recipients and len(recipients) == 1:
                encoded.update(recipients[0])
            elif recipients:
                encoded[IDENT_RECIPIENTS] = recipients
            else:
                raise ValidationError("Missing message recipients")
        self.protected_b64 = b64url(json.dumps(encoded))
        # Keep the decoded headers in step with the encoded form so that
        # `recipients` and `get_recipient` see them. Recipient data is left
        # out, matching what deserialization stores in `protected`.
        self.protected = headers

    @property
    def protected_bytes(self) -> Optional[bytes]:
        """Access the protected data encoded as bytes.

        This value is used in the additional authenticated data when encrypting.
        Returns ``None`` when the protected headers have not been set.
        """
        return (
            self.protected_b64.encode("utf-8")
            if self.protected_b64 is not None
            else None
        )

    def set_payload(self, ciphertext: bytes, iv: bytes, tag: bytes, aad: bytes = None):
        """Set the payload of the JWE envelope."""
        self.ciphertext = ciphertext
        self.iv = iv
        self.tag = tag
        self.aad = aad

    def _outer_headers(self) -> dict:
        """Return a fresh merge of the protected and unprotected headers."""
        merged = self.protected.copy() if self.protected is not None else {}
        merged.update(self.unprotected or {})
        return merged

    @property
    def recipients(self) -> Iterable["JweRecipient"]:
        """Accessor for an iterator over the JWE recipients.

        The headers for each recipient include protected and unprotected headers from the
        outer envelope.
        """
        outer = self._outer_headers()
        for recip in self._recipients:
            # Each recipient gets its own dict so that changing one yielded
            # header cannot affect another.
            merged = outer.copy()
            if recip.header:
                merged.update(recip.header)
            yield JweRecipient(encrypted_key=recip.encrypted_key, header=merged)

    @property
    def recipients_json(self) -> List[Dict[str, Any]]:
        """Encode the current recipients for JSON."""
        return [recip.serialize() for recip in self._recipients]

    @property
    def recipient_key_ids(self) -> Iterable[str]:
        """Accessor for an iterator over the JWE recipient key identifiers.

        Only recipients whose own header has a "kid" entry are included.
        """
        for recip in self._recipients:
            if recip.header and "kid" in recip.header:
                yield recip.header["kid"]

    def get_recipient(self, kid: str) -> Optional["JweRecipient"]:
        """Find a recipient by key ID.

        Returns:
            A recipient whose header combines the envelope headers with the
            recipient's own, or ``None`` when no recipient header has this
            "kid".
        """
        for recip in self._recipients:
            if recip.header and recip.header.get("kid") == kid:
                header = self._outer_headers()
                header.update(recip.header)
                return JweRecipient(encrypted_key=recip.encrypted_key, header=header)
        return None

    @property
    def combined_aad(self) -> Optional[bytes]:
        """Accessor for the additional authenticated data.

        This is the encoded protected headers, followed by "." and the
        base64-URL encoded ``aad`` when one is set. Returns ``None`` when
        neither the protected headers nor an ``aad`` are set.

        Raises:
            ValidationError: If ``aad`` is set but the protected headers are not.
        """
        aad = self.protected_bytes
        if self.aad:
            if aad is None:
                raise ValidationError("Missing protected: use set_protected")
            aad += b"." + b64url(self.aad).encode("utf-8")
        return aad