# Mirror of https://github.com/divkit/divkit.git
# Synced 2026-05-07 20:02:32 +00:00
# 2987d93ba7
# commit_hash: cf5070a543788fa57136adb1a4a7ea42f4490329
# 594 lines, 21 KiB, Python
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import uuid
|
|
import warnings
|
|
from collections import defaultdict
|
|
from functools import reduce
|
|
from types import MappingProxyType
|
|
from typing import (
|
|
Any,
|
|
Dict,
|
|
FrozenSet,
|
|
Iterator,
|
|
List,
|
|
Mapping,
|
|
Optional,
|
|
Sequence,
|
|
Set,
|
|
Tuple,
|
|
Type,
|
|
Union,
|
|
get_args,
|
|
get_origin,
|
|
get_type_hints,
|
|
)
|
|
|
|
from .compat import classproperty
|
|
from .fields import Expr, _Field
|
|
from .schema import ExcludeFieldsType, SchemaType, _field_to_schema
|
|
from .serialization import (
|
|
_cast_value_type,
|
|
_make_exclude_fields,
|
|
_merge_types,
|
|
_update_related_templates,
|
|
dump,
|
|
)
|
|
|
|
# Name of the field holding an entity's type discriminator. Template classes
# (see BaseDiv._inject_type_field) get it re-declared with their template
# name as the default value.
TYPE_FIELD = "type"
|
|
|
|
|
|
class BaseEntity:
    """Base class for declarative DivKit entities.

    Subclasses declare their fields as class attributes holding ``_Field``
    objects, each with a type annotation. When a subclass is created,
    ``__init_subclass__`` registers it under its ``template_name``, collects
    its fields (including those inherited from bases), removes the ``_Field``
    objects from the class namespace, resolves the fields' type hints and
    makes ``BaseEntity.__init__`` the class constructor.
    """

    # Field name -> field descriptor, including fields inherited from bases.
    __fields__: Mapping[str, _Field]
    # Field uid -> field name; used to map ref targets back to field names.
    __field_names__: Mapping[uuid.UUID, str]
    # Field name -> resolved type hint. Left empty while the hints cannot be
    # resolved yet (see ``update_forward_refs``).
    __field_types__: Mapping[str, type]

    # Template name -> class. A single dict mutated in place by every
    # subclass, so it is shared by the whole hierarchy.
    __subclasses_registry__: Dict[str, Type[BaseEntity]] = {}
    __related_templates__: FrozenSet[Type[BaseDiv]] = frozenset({})
    # When set, used as ``template_name`` instead of "<module>.<ClassName>".
    __template_name__: Optional[str] = None

    def __init__(self, **kwargs: Any) -> None:
        """Create an entity from keyword arguments named after its fields.

        Keyword arguments whose value is ``None`` are skipped; fields that
        are not passed at all are set to their declared ``default``.

        Raises:
            TypeError: if the class declares fields whose type hints have
                not been resolved yet.
            KeyError: (raised by ``__setattr__``) for a non-``None`` keyword
                that does not start with ``_`` and is not a declared field.
        """
        if self.__fields__ and not self.__field_types__:
            class_name = self.__class__.__name__
            raise TypeError(
                f"Class {class_name} fields are not prepared yet, you may "
                f"need to call {class_name}.update_forward_refs().",
            )
        self._instance_related_templates: Set[Type[BaseDiv]] = set()
        for field_name, value in kwargs.items():
            if value is not None:
                setattr(self, field_name, value)
        for field_name, field in self.__fields__.items():
            if field_name not in kwargs:
                setattr(self, field_name, field.default)

    @classproperty
    @classmethod
    def template_name(cls) -> str:
        """Name under which the class is registered.

        ``__template_name__`` when it is set, otherwise
        ``"<module>.<ClassName>"``.
        """
        if cls.__template_name__ is not None:
            return cls.__template_name__
        return f"{cls.__module__}.{cls.__name__}"

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
        """Register the new subclass and prepare its field metadata.

        Extra ``args``/``kwargs`` are accepted but not forwarded to the
        parent's ``__init_subclass__``. A ``RuntimeWarning`` is emitted when
        another class is already registered under the same template name;
        the new class replaces it in the registry.
        """
        super().__init_subclass__()
        if cls.template_name in cls.__subclasses_registry__:
            warnings.warn(
                f"Template {cls.template_name!r} already defined in "
                f"{cls.__subclasses_registry__[cls.template_name]!r} "
                f"and will be replaced to {cls!r}",
                RuntimeWarning,
            )
        cls.__subclasses_registry__[cls.template_name] = cls
        cls.__fields__ = cls._extract_fields()
        cls.__field_names__ = cls._extract_field_names(cls.__fields__)
        # Order matters: the _Field objects are collected first, then removed
        # from the class namespace, and only then are type hints resolved
        # (BaseDiv treats class attributes that remain at that point as
        # template values).
        cls._remove_fields_from_cls_attrs()
        cls.update_forward_refs()
        # Entities are always constructed through BaseEntity.__init__.
        if cls.__init__ != BaseEntity.__init__:
            setattr(cls, "__init__", BaseEntity.__init__)

    def __setattr__(self, key: str, value: Any) -> None:
        """Validate and coerce assignments to declared fields.

        Names starting with ``_`` are stored as-is. Any other name must be a
        declared field: unless ``value`` is a ref ``_Field`` it is cast to
        the field's type, and it is then passed to
        ``_update_related_templates`` to collect related templates.

        Raises:
            KeyError: if ``key`` is not a declared field.
            ValueError: if ``key`` is the type field and ``value`` differs
                from the field's non-empty default.
        """
        if not key.startswith("_"):
            field_type = self.__field_types__.get(key)
            if field_type is None:
                raise KeyError(
                    f"'{self.__class__.__name__}' object has no attribute '{key}'",
                )
            if key == TYPE_FIELD:
                self._validate_type_value(value)
            if not (isinstance(value, _Field) and value.ref_to):
                value = _cast_value_type(value, field_type)
            _update_related_templates(value, self._instance_related_templates)
        super().__setattr__(key, value)

    @classmethod
    def _extract_field_names(
        cls,
        fields: Mapping[str, _Field],
    ) -> Mapping[uuid.UUID, str]:
        """Return a read-only mapping of field uid -> field name."""
        return MappingProxyType(
            {field.uid: field_name for field_name, field in fields.items()},
        )

    @classmethod
    def _remove_fields_from_cls_attrs(cls) -> None:
        """Delete every ``_Field`` class attribute defined on ``cls`` itself."""
        # Collect names first: the class __dict__ cannot change while it is
        # being iterated.
        to_remove = []
        for attr_name, attr_value in cls.__dict__.items():
            if isinstance(attr_value, _Field):
                to_remove.append(attr_name)
        for attr_name in to_remove:
            delattr(cls, attr_name)

    def _validate_type_value(self, value: Any) -> None:
        """Ensure ``value`` matches the type field's default, if it has one.

        Raises:
            ValueError: if the default is non-empty and differs from
                ``value``.
        """
        field = self.__fields__[TYPE_FIELD]
        if field.default and field.default != value:
            raise ValueError(
                f"Property '{TYPE_FIELD}' has wrong value: '{value}'. "
                f"Expected is '{field.default}'.",
            )

    @classmethod
    def update_forward_refs(cls) -> None:
        """(Re)resolve the type hints of the class fields.

        If resolution fails with ``NameError`` or ``AttributeError`` (for
        example, a forward reference whose target is not defined yet),
        ``__field_types__`` is left empty so that the call can be repeated
        later. Instances cannot be created until then.
        """
        try:
            cls.__field_types__ = cls._extract_field_types(cls.__fields__)
        except (NameError, AttributeError):
            cls.__field_types__ = MappingProxyType({})

    @classmethod
    def _extract_fields(cls) -> Mapping[str, _Field]:
        """Collect inherited and locally declared fields of the class.

        Fields of ``BaseEntity`` bases come first; ``_Field`` attributes of
        the class body are added (or override) afterwards. A local ``_Field``
        with ``ref_to`` set does not declare a new field: it turns the
        inherited field of the same name into a reference to ``ref_to``.
        Fields without a ``name`` are named after their attribute.

        Raises:
            KeyError: if a ref is declared on a name no base class provides.
            ValueError: if the inherited field is already a ref, or has a
                default value.
        """
        fields: Dict[str, _Field] = {}
        for base in cls.__bases__:
            base_fields = getattr(base, "__fields__", None)
            if issubclass(base, BaseEntity) and base_fields:
                fields.update(base_fields)
        for key, value in cls.__dict__.items():
            if isinstance(value, _Field):
                field_ref = value.ref_to
                if field_ref:
                    if key not in fields:
                        raise KeyError(
                            f"Cannot create a ref for field {key} "
                            f"because no base class declares it",
                        )
                    # NOTE(review): ``fields[key]`` is the very _Field object
                    # stored in the base class's ``__fields__`` (dict.update
                    # copies references), so setting ``ref_to`` below also
                    # changes the base class's field. Confirm this is
                    # intended.
                    value = fields[key]
                    if value.is_ref:
                        raise ValueError("Ref cannot point to another ref")
                    if value.default is not None:
                        raise ValueError(
                            f"Cannot create a ref for field {key} "
                            f"because it has a default value",
                        )
                    value.ref_to = field_ref
                if not value.name:
                    value.name = key
                fields[key] = value
        return MappingProxyType(fields)

    @classmethod
    def _extract_field_types(
        cls,
        fields: Mapping[str, _Field],
    ) -> Mapping[str, type]:
        """Resolve the type hint of every field in ``fields``.

        Types recorded by ``BaseEntity`` bases take precedence over the
        hints found on ``cls``. Ref fields are made ``Optional``.

        Raises:
            NameError: if an annotation refers to an undefined name.
            ValueError: if a field lacks a type hint, or a hint conflicts
                with a base class (see ``_validate_field_types_of_bases``).
        """
        # Bind the class's own name so annotations that mention it resolve.
        cls_hints = get_type_hints(cls, localns={cls.__name__: cls})
        cls._validate_field_types_of_bases(cls_hints)
        cls._validate_field_types(fields, cls_hints)

        field_types = {}
        for key in fields.keys():
            field_types[key] = cls_hints[key]
        for base in cls.__bases__:
            base_field_types = getattr(base, "__field_types__", None)
            if issubclass(base, BaseEntity) and base_field_types:
                field_types.update(base_field_types)

        # Only values are replaced here, so mutating while iterating is safe.
        for field_name, field_type in field_types.items():
            if fields[field_name].ref_to:
                field_types[field_name] = Optional[field_type]

        return MappingProxyType(field_types)

    @classmethod
    def _validate_field_types(
        cls,
        fields: Mapping[str, _Field],
        cls_hints: Mapping[str, Any],
    ) -> None:
        """Check that every field has a type hint its default can be cast to.

        Raises:
            ValueError: if a field has no (or an empty) type hint. Errors
                from ``_cast_value_type`` on a default are propagated.
        """
        for key, field in fields.items():
            hint = cls_hints.get(key)
            if not hint:
                raise ValueError(
                    f"Type hint is missed for {cls.__name__}.{key}",
                )
            if field.default is not None:
                _cast_value_type(field.default, hint)

    @classmethod
    def _validate_field_types_of_bases(
        cls,
        cls_hints: Mapping[str, Any],
    ) -> None:
        """Ensure hints of ``cls`` do not redefine ``_Field`` attributes of bases.

        Raises:
            ValueError: if a hint of ``cls`` is not the very same object as
                the hint of a base-class ``_Field`` attribute.

        NOTE(review): ``_remove_fields_from_cls_attrs`` deletes ``_Field``
        attributes from every class once it is processed, so for
        ``BaseEntity`` bases ``getattr(base, key)`` appears to never be a
        ``_Field`` here and the mismatch check may never fire. The
        ``get_type_hints(base)`` call itself can still raise ``NameError``.
        """
        for base in cls.__bases__:
            base_hints = get_type_hints(base)
            for key, cls_type_hint in cls_hints.items():
                if not isinstance(getattr(base, key, None), _Field):
                    continue
                base_type_hint = base_hints.get(key)
                if base_type_hint and (cls_type_hint is not base_type_hint):
                    raise ValueError(
                        f"Type hint mismatch. _Field "
                        f"{cls.__name__}.{key} "
                        "should match type hint of a parent class "
                        f"{base.__name__}.{key}\n"
                        f"Expected: {base_type_hint}. "
                        f"Actual: {cls_type_hint}",
                    )

    @classmethod
    def _merge_ref_types(
        cls,
        *ref_types_seq: Mapping[uuid.UUID, Any],
    ) -> Mapping[uuid.UUID, Any]:
        """Merge several ref-uid -> type mappings into one.

        Types recorded for the same uid are folded with ``_merge_types``.
        """
        all_refs: Dict[uuid.UUID, List[Any]] = defaultdict(list)
        for ref_types in ref_types_seq:
            for ref_uid, ref_type in ref_types.items():
                all_refs[ref_uid].append(ref_type)
        return MappingProxyType(
            {
                ref_uid: reduce(_merge_types, ref_types)
                for ref_uid, ref_types in all_refs.items()
            },
        )

    @classmethod
    def _extract_ref_types_from_obj(
        cls,
        obj: Any,
    ) -> Mapping[uuid.UUID, Any]:
        """Collect ref types from an entity or a list of entities.

        Any other value (including entities nested deeper than one list
        level) contributes nothing.
        """
        if isinstance(obj, BaseEntity):
            return obj._extract_all_ref_types()
        elif isinstance(obj, list):
            ref_types_seq = []
            for item in obj:
                if isinstance(item, BaseEntity):
                    ref_types_seq.append(item._extract_all_ref_types())
            return cls._merge_ref_types(*ref_types_seq)
        return MappingProxyType({})

    def _extract_ref_types_from_values(
        self,
    ) -> Mapping[uuid.UUID, Any]:
        """Collect ref types from this instance's own field values.

        For every field whose value is a ref ``_Field``, the field's
        constraints are applied to the ref target (a side effect on that
        target) and the field's type is recorded under the target's uid,
        merging types when several fields point to the same target.
        """
        ref_types: Dict[uuid.UUID, Any] = {}
        for field_name, field_type in self.__field_types__.items():
            field_value = getattr(self, field_name, None)
            if isinstance(field_value, _Field) and field_value.ref_to:
                field = self.__fields__[field_name]
                field_value.ref_to.apply_constraints(field.constraints)
                ref_uid = field_value.ref_to.uid
                if ref_uid in ref_types:
                    ref_types[ref_uid] = _merge_types(
                        ref_types[ref_uid],
                        field_type,
                    )
                else:
                    ref_types[ref_uid] = field_type
        return MappingProxyType(ref_types)

    def _extract_all_ref_types(
        self,
    ) -> Mapping[uuid.UUID, Any]:
        """Collect ref types from this instance and from nested entities."""
        ref_types_seq = [self._extract_ref_types_from_values()]
        for field_name in self.__fields__:
            field_value = getattr(self, field_name, None)
            ref_types_seq.append(self._extract_ref_types_from_obj(field_value))
        return self._merge_ref_types(*ref_types_seq)

    def related_templates(self) -> Set[Type[BaseDiv]]:
        """Return templates related to the class plus those of this instance."""
        return {
            *self.__related_templates__,
            *self._instance_related_templates,
        }

    def dict(self) -> Dict[str, Any]:
        """Serialize the entity into a plain dict.

        Keys are each field's ``field_name``; ``None`` values are skipped.
        A field holding a ref is emitted as
        ``"$<field_name>": <target field_name>``; any other value goes
        through ``dump``.
        """
        result: Dict[str, Any] = {}
        for field_name, field in self.__fields__.items():
            field_value = getattr(self, field_name, field.default)
            if field_value is not None:
                if isinstance(field_value, _Field) and field_value.ref_to:
                    result[f"${field.field_name}"] = field_value.ref_to.field_name
                else:
                    result[field.field_name] = dump(field_value)
        return result

    @classmethod
    def _can_add_field_to_schema(
        cls,
        field_name: str,
        field: _Field,
    ) -> bool:
        """Return whether ``field`` belongs in the JSON schema (refs do not)."""
        return not field.ref_to

    @classmethod
    def _build_schema(
        cls,
        definitions: Dict[str, SchemaType],
        exclude: ExcludeFieldsType,
    ) -> SchemaType:
        """Build the JSON object schema for this class.

        Args:
            definitions: shared definitions table; nested schemas built via
                ``_field_to_schema`` may be added to it.
            exclude: per-field exclusions; ``True`` drops a field entirely,
                any other value is passed down for that field.
        """
        properties: Dict[str, Any] = {}
        required: List[str] = []
        for field_name, field in cls.__fields__.items():
            if (
                not cls._can_add_field_to_schema(field_name, field)
                or exclude.get(field_name) is True
            ):
                continue
            field_type = cls.__field_types__[field_name]
            properties[field_name] = _field_to_schema(
                field=field,
                type_=field_type,
                exclude=exclude.get(field_name, {}),
                definitions=definitions,
            )
            # Only fields whose type admits None are optional.
            is_optional = get_origin(field_type) is Union and type(None) in get_args(
                field_type
            )
            if not is_optional:
                required.append(field_name)
        schema: SchemaType = {"type": "object"}
        if properties:
            schema["properties"] = properties
        if required:
            schema["required"] = required
        return schema

    @classmethod
    def schema_as_ref(
        cls,
        definitions: Dict[str, SchemaType],
        exclude: ExcludeFieldsType,
    ) -> SchemaType:
        """Return a ``$ref`` to this class's schema, registering it if needed.

        With a non-empty ``exclude`` the definition name gets the MD5 of the
        serialized exclusions as a suffix, so differently pruned variants of
        the same class get distinct definitions.
        """
        schema_name = cls.__name__
        if exclude:
            serialized_exclude = json.dumps(exclude, sort_keys=True)
            exclude_hash = hashlib.md5(serialized_exclude.encode()).hexdigest()
            schema_name += f"_{exclude_hash}"
        if schema_name not in definitions:
            # Placeholder first, so recursive references to this class
            # resolve to the $ref instead of recursing forever.
            definitions[schema_name] = {}
            schema = cls._build_schema(definitions, exclude)
            definitions[schema_name] = schema
        return {"$ref": f"#/definitions/{schema_name}"}
|
|
|
|
|
|
class BaseDiv(BaseEntity):
    """Base class for DivKit components and templates built on them.

    Direct subclasses of ``BaseDiv`` are plain components. A subclass of a
    class whose ``type`` field has a default is a template:

    * its ``type`` field is re-declared with the template name as default;
    * plain (non-``_Field``) class attributes named after fields become
      template values (``__tpl_values__``);
    * ``_Field`` attributes with ``ref_to`` set turn inherited fields into
      references to other fields.
    """

    # Field name -> value fixed by the template (from plain class attributes).
    __tpl_values__: Mapping[str, Any]
    # Declared here but not assigned anywhere in this class.
    __refs__: Mapping[str, Set[Optional[str]]]
    # Names of ref fields declared directly in this class body.
    __local_referred_fields__: FrozenSet[str]

    # ``type`` default of the parent class when this class is a template,
    # otherwise None.
    __base_type__: Optional[str] = None
    # Cache for ``template()``; reset for every subclass.
    __template__: Optional[Dict[str, Any]] = None

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
        # These steps must run before BaseEntity.__init_subclass__: it collects
        # _Field objects from the class namespace (so the injected ``type``
        # field must already be there) and then removes them (so the local
        # ref fields must be recorded beforehand).
        cls.__base_type__ = cls._get_base_type()
        cls.__local_referred_fields__ = cls._extract_local_referred_fields()
        cls.__template__ = None
        cls._inject_type_field()
        super().__init_subclass__(*args, **kwargs)

    @classmethod
    def _extract_local_referred_fields(cls) -> FrozenSet[str]:
        """Names of ref ``_Field`` attributes declared in this class body."""
        return frozenset(
            {
                field_name
                for field_name, field in cls.__dict__.items()
                if isinstance(field, _Field) and field.ref_to
            },
        )

    def __setattr__(self, key: str, value: Any) -> None:
        """Ignore ``None`` for fields fixed by the template; else delegate."""
        if (value is None) and (key in self.__tpl_values__):
            return
        super().__setattr__(key, value)

    @classmethod
    def update_forward_refs(cls) -> None:
        """Resolve field types, then extract and validate template data.

        While the field type hints cannot be resolved yet (the base
        implementation leaves ``__field_types__`` empty), ref types are not
        validated: doing so would look up types that do not exist yet and
        fail with ``KeyError`` instead of deferring. Call this method again
        once the forward references are defined.

        Raises:
            TypeError: if a ref's type does not match the referenced field.
        """
        super().update_forward_refs()
        cls.__tpl_values__ = cls._extract_tpl_values(cls.__field_types__)
        cls.__field_types__ = cls._make_fields_optional_from_tpl_values(
            cls.__field_types__,
            cls.__tpl_values__,
        )
        cls.__related_templates__ = cls._extract_related_templates(
            cls.__tpl_values__,
        )
        if cls.__fields__ and not cls.__field_types__:
            # Type hints are unresolved for now; validate on a later call.
            return
        cls._validate_ref_types()

    @classmethod
    def _make_fields_optional_from_tpl_values(
        cls,
        field_types: Mapping[str, type],
        tpl_values: Mapping[str, Any],
    ) -> Mapping[str, type]:
        """Wrap the types of fields fixed by the template in ``Optional``."""
        new_field_types: Dict[str, Any] = {**field_types}
        for field_name in tpl_values:
            new_field_types[field_name] = Optional[new_field_types[field_name]]
        return MappingProxyType(new_field_types)

    @classmethod
    def _extract_related_templates(
        cls,
        tpl_values: Mapping[str, Any],
    ) -> FrozenSet[Type[BaseDiv]]:
        """Collect templates this class depends on.

        Includes the class itself when it is a template, the related
        templates of its ``BaseEntity`` bases, and whatever
        ``_update_related_templates`` finds in the template values.
        """
        related_templates: Set[Type[BaseDiv]] = set()
        if cls.__base_type__:
            related_templates.add(cls)
        for base_cls in cls.__bases__:
            if issubclass(base_cls, BaseEntity):
                related_templates.update(base_cls.__related_templates__)
        for tpl_value in tpl_values.values():
            _update_related_templates(tpl_value, related_templates)
        return frozenset(related_templates)

    @classmethod
    def _extract_tpl_values(
        cls,
        field_types: Mapping[str, Any],
    ) -> Mapping[str, Any]:
        """Move template values from class attributes into a mapping.

        Every class attribute named after a field is cast to the field's
        type (``None`` values are dropped) and deleted from the class.

        Values extracted by an earlier call are kept: their attributes have
        already been deleted, so without this a repeated
        ``update_forward_refs()`` would silently lose them.
        """
        tpl_values: Dict[str, Any] = {**cls.__dict__.get("__tpl_values__", {})}
        for name, field_type in field_types.items():
            if not hasattr(cls, name):
                continue
            value = getattr(cls, name)
            if value is not None:
                tpl_values[name] = _cast_value_type(value, field_type)
            delattr(cls, name)
        return MappingProxyType(tpl_values)

    @classmethod
    def _extract_ref_types_from_tpl_values(
        cls,
    ) -> Mapping[uuid.UUID, Any]:
        """Collect ref types from entities used as template values."""
        return cls._merge_ref_types(
            *(
                cls._extract_ref_types_from_obj(tpl_value)
                for tpl_value in cls.__tpl_values__.values()
            ),
        )

    @classmethod
    def _extract_ref_types_from_fields(
        cls,
    ) -> Mapping[uuid.UUID, Any]:
        """Collect ref types from the ref fields of this class.

        Each ref field's constraints are applied to its target (a side
        effect on that target), and the field's type is recorded under the
        target's uid, merging types when several fields share a target.
        """
        ref_types: Dict[uuid.UUID, Any] = {}
        for field_name, field in cls.__fields__.items():
            if not field.ref_to:
                continue
            field.ref_to.apply_constraints(field.constraints)
            ref_uid = field.ref_to.uid
            field_type = cls.__field_types__[field_name]
            if ref_uid in ref_types:
                ref_types[ref_uid] = _merge_types(
                    ref_types[ref_uid],
                    field_type,
                )
            else:
                ref_types[ref_uid] = field_type
        return MappingProxyType(ref_types)

    @staticmethod
    def _make_union(type: Type[Any]) -> Set[Type[Any]]:
        """Return the members of a ``Union`` type, or ``{type}`` otherwise."""
        if get_origin(type) is Union:
            return set(get_args(type))
        return {type}

    @classmethod
    def _validate_subclass(
        cls,
        ref_type: Type[Any],
        expected_field_type: Type[Any],
    ) -> bool:
        """Loosely check that ``ref_type`` is compatible with the target type.

        Both types are expanded into their ``Union`` members. The check
        passes when any member of one has the same ``get_origin`` as any
        member of the other. Plain classes all have origin ``None``, so they
        always match each other.

        NOTE(review): a previous version also appeared to compare the
        members' type arguments via ``issubclass``, but it applied
        ``get_args`` to an already-extracted argument tuple, which always
        returns ``()``. That comparison therefore never affected the result,
        and this form is behaviorally identical. Tightening the check would
        reject templates that are accepted today, so confirm before doing
        so.
        """
        return any(
            get_origin(expect) == get_origin(ref)
            for ref in cls._make_union(ref_type)
            for expect in cls._make_union(expected_field_type)
        )

    @classmethod
    def _validate_ref_types(cls) -> None:
        """Check every ref's type against the type of the field it targets.

        Raises:
            TypeError: if a ref type is neither the target type (optionally
                with ``None`` and/or ``Expr`` added) nor accepted by
                ``_validate_subclass``.
        """
        ref_types = cls._merge_ref_types(
            cls._extract_ref_types_from_fields(),
            cls._extract_ref_types_from_tpl_values(),
        )
        for ref_uid, ref_type in ref_types.items():
            field_name = cls.__field_names__[ref_uid]
            expected_field_type = cls.__field_types__[field_name]
            if (
                ref_type != expected_field_type
                and ref_type != Optional[expected_field_type]
                and ref_type != Union[expected_field_type, Expr]
                and ref_type != Union[expected_field_type, Expr, None]
                and not cls._validate_subclass(ref_type, expected_field_type)
            ):
                raise TypeError(
                    f"Type of attribute '{field_name}' does "
                    f"not match ref type "
                    f"{expected_field_type} != {ref_type}",
                )

    @classmethod
    def _get_base_type(cls) -> Optional[str]:
        """Return the parent's ``type`` default, or None if not a template.

        Raises:
            TypeError: if the class has several bases, or its base is not a
                ``BaseDiv`` subclass.
        """
        if len(cls.__bases__) > 1:
            raise TypeError(
                "Types conflict: base class cannot be uniquely identified",
            )
        base_cls, *_ = cls.__bases__
        if not issubclass(base_cls, BaseDiv):
            raise TypeError(
                "Types conflict: class must be derived from the BaseDiv class",
            )
        type_field = base_cls.__fields__.get(TYPE_FIELD)
        if (base_cls is BaseDiv) or not (type_field and type_field.default):
            return None
        return type_field.default

    @classmethod
    def _inject_type_field(cls) -> None:
        """For templates, declare ``type`` with the template name as default."""
        if cls.__base_type__:
            type_value = cls.template_name
            setattr(cls, TYPE_FIELD, _Field(default=type_value))

    @classmethod
    def template(cls) -> Dict[str, Any]:
        """Return the template definition of this class (built once, cached).

        Raises:
            TypeError: if the class is not a template.
        """
        if not cls.__base_type__:
            raise TypeError(f"Component {cls.__name__} is not a template")
        if cls.__template__ is None:
            cls.__template__ = cls._build_template()
        return cls.__template__

    @classmethod
    def _build_template(cls) -> Dict[str, Any]:
        """Build the template dict.

        It contains the base type, the dumped template values, and
        ``"$<field_name>": <target field_name>`` entries for ref fields
        declared directly in this class.
        """
        template = {TYPE_FIELD: cls.__base_type__}
        for field_name, tpl_field_value in cls.__tpl_values__.items():
            field = cls.__fields__[field_name]
            template[field.field_name] = dump(tpl_field_value)
        for field_name, field in cls.__fields__.items():
            if field.ref_to and field_name in cls.__local_referred_fields__:
                template[f"${field.field_name}"] = field.ref_to.field_name
        return template

    @classmethod
    def _can_add_field_to_schema(
        cls,
        field_name: str,
        field: _Field,
    ) -> bool:
        """Exclude refs and fields fixed by the template from the schema."""
        return (
            super()._can_add_field_to_schema(field_name, field)
            and field_name not in cls.__tpl_values__
        )

    @classmethod
    def schema(
        cls,
        exclude_fields: Optional[Sequence[str]] = None,
    ) -> SchemaType:
        """Return the JSON schema of this class with all its definitions.

        ``exclude_fields`` is converted by ``_make_exclude_fields`` into the
        exclusion mapping used while building the schema.
        """
        definitions: Dict[str, SchemaType] = {}
        exclude = _make_exclude_fields(exclude_fields)
        schema = cls._build_schema(definitions, exclude)
        schema["definitions"] = definitions
        return schema
|