# Files
# p-mosein 2987d93ba7 pydivkit. migrate to uv+hatch, decompose entities.py, fix Python 3.14 compat
# commit_hash:cf5070a543788fa57136adb1a4a7ea42f4490329
# 2026-02-10 16:34:59 +03:00
#
# 594 lines
# 21 KiB
# Python

from __future__ import annotations

import copy
import hashlib
import json
import uuid
import warnings
from collections import defaultdict
from functools import reduce
from types import MappingProxyType
from typing import (
    Any,
    Dict,
    FrozenSet,
    Iterator,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
    get_args,
    get_origin,
    get_type_hints,
)

from .compat import classproperty
from .fields import Expr, _Field
from .schema import ExcludeFieldsType, SchemaType, _field_to_schema
from .serialization import (
    _cast_value_type,
    _make_exclude_fields,
    _merge_types,
    _update_related_templates,
    dump,
)
# Name of the field holding an entity's ``type`` value. Templates inject a
# default for it (their template name) and emit their parent's type under
# this key in the built template.
TYPE_FIELD = "type"
class BaseEntity:
    """Base class for declarative DivKit entities.

    Subclasses declare properties as annotated ``_Field`` class attributes.
    When a subclass is created, its own and inherited fields are collected,
    removed from the class namespace and their type hints resolved, so that
    instances can cast values on assignment and serialize via :meth:`dict`.
    """

    # Field name -> field object (inherited fields included).
    __fields__: Mapping[str, _Field]
    # Field uid -> field name; used to map refs back to field names.
    __field_names__: Mapping[uuid.UUID, str]
    # Field name -> resolved type; empty while forward refs are unresolved.
    __field_types__: Mapping[str, type]
    # A single dict shared by the whole hierarchy: template name -> class.
    __subclasses_registry__: Dict[str, Type[BaseEntity]] = {}
    __related_templates__: FrozenSet[Type[BaseDiv]] = frozenset({})
    # When set, replaces the default ``module.ClassName`` template name.
    __template_name__: Optional[str] = None

    def __init__(self, **kwargs: Any) -> None:
        """Create an instance from keyword arguments named after fields.

        Keyword arguments whose value is ``None`` are skipped entirely (the
        attribute is left unset); fields that are not passed at all are set
        to their declared default.

        Raises:
            TypeError: if the class has fields whose types are not resolved
                yet (see :meth:`update_forward_refs`).
            KeyError: if a non-``None`` keyword argument not starting with
                ``_`` does not name a field.
        """
        if self.__fields__ and not self.__field_types__:
            class_name = self.__class__.__name__
            raise TypeError(
                f"Class {class_name} fields are not prepared yet, you may "
                f"need to call {class_name}.update_forward_refs().",
            )
        self._instance_related_templates: Set[Type[BaseDiv]] = set()
        for field_name, value in kwargs.items():
            if value is not None:
                setattr(self, field_name, value)
        for field_name, field in self.__fields__.items():
            if field_name not in kwargs:
                setattr(self, field_name, field.default)

    @classproperty
    @classmethod
    def template_name(cls) -> str:
        """Return ``__template_name__`` if set, else ``module.ClassName``."""
        if cls.__template_name__ is not None:
            return cls.__template_name__
        return f"{cls.__module__}.{cls.__name__}"

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
        """Register the subclass and prepare its field metadata.

        Extra class keyword arguments are accepted but not forwarded. Any
        ``__init__`` defined on the subclass is replaced, so instances are
        always built by ``BaseEntity.__init__``.

        Warns:
            RuntimeWarning: if another class already registered the same
                template name; the new class replaces it in the registry.
        """
        super().__init_subclass__()
        if cls.template_name in cls.__subclasses_registry__:
            warnings.warn(
                f"Template {cls.template_name!r} already defined in "
                f"{cls.__subclasses_registry__[cls.template_name]!r} "
                f"and will be replaced to {cls!r}",
                RuntimeWarning,
            )
        cls.__subclasses_registry__[cls.template_name] = cls
        cls.__fields__ = cls._extract_fields()
        cls.__field_names__ = cls._extract_field_names(cls.__fields__)
        cls._remove_fields_from_cls_attrs()
        cls.update_forward_refs()
        if cls.__init__ != BaseEntity.__init__:
            setattr(cls, "__init__", BaseEntity.__init__)

    def __setattr__(self, key: str, value: Any) -> None:
        """Validate and cast values assigned to public (field) attributes.

        Names starting with ``_`` are stored as-is. For other names the value
        is cast to the field's resolved type (refs are stored uncast), and
        ``_update_related_templates`` collects the value's related templates
        into ``_instance_related_templates``.

        Raises:
            KeyError: if ``key`` is not among the resolved field types.
            ValueError: if ``key`` is the type field and ``value`` differs
                from its (truthy) default.
        """
        if not key.startswith("_"):
            field_type = self.__field_types__.get(key)
            if field_type is None:
                raise KeyError(
                    f"'{self.__class__.__name__}' object has no attribute '{key}'",
                )
            if key == TYPE_FIELD:
                self._validate_type_value(value)
            # Ref placeholders are kept as-is rather than cast.
            if not (isinstance(value, _Field) and value.ref_to):
                value = _cast_value_type(value, field_type)
            _update_related_templates(value, self._instance_related_templates)
        super().__setattr__(key, value)

    @classmethod
    def _extract_field_names(
        cls,
        fields: Mapping[str, _Field],
    ) -> Mapping[uuid.UUID, str]:
        """Map each field's ``uid`` to its attribute name."""
        return MappingProxyType(
            {field.uid: field_name for field_name, field in fields.items()},
        )

    @classmethod
    def _remove_fields_from_cls_attrs(cls) -> None:
        """Delete the ``_Field`` attributes defined directly on this class.

        They are already captured in ``__fields__`` at this point.
        """
        to_remove = [
            attr_name
            for attr_name, attr_value in cls.__dict__.items()
            if isinstance(attr_value, _Field)
        ]
        # Collected first: the namespace must not change while iterated.
        for attr_name in to_remove:
            delattr(cls, attr_name)

    def _validate_type_value(self, value: Any) -> None:
        """Reject a ``type`` value that differs from the field's default.

        Raises:
            ValueError: if the type field has a truthy default and ``value``
                is not equal to it.
        """
        field = self.__fields__[TYPE_FIELD]
        if field.default and field.default != value:
            raise ValueError(
                f"Property '{TYPE_FIELD}' has wrong value: '{value}'. "
                f"Expected is '{field.default}'.",
            )

    @classmethod
    def update_forward_refs(cls) -> None:
        """(Re)resolve the type hints of all fields.

        If a hint cannot be resolved yet (``NameError``/``AttributeError``),
        ``__field_types__`` is set empty and instantiation raises
        ``TypeError`` until this method is called again successfully.
        """
        try:
            cls.__field_types__ = cls._extract_field_types(cls.__fields__)
        except (NameError, AttributeError):
            cls.__field_types__ = MappingProxyType({})

    @classmethod
    def _extract_fields(cls) -> Mapping[str, _Field]:
        """Collect inherited and own ``_Field`` declarations.

        A class attribute that is a ref (a ``_Field`` with ``ref_to``) turns
        the inherited field of the same name into a ref. The inherited field
        is shallow-copied first, so the base class's field object, which is
        shared with the base and its other subclasses, is left untouched.

        Raises:
            KeyError: if a ref is declared for a name with no inherited field.
            ValueError: if the inherited field is already a ref or has a
                default value.
        """
        fields: Dict[str, _Field] = {}
        for base in cls.__bases__:
            base_fields = getattr(base, "__fields__", None)
            if issubclass(base, BaseEntity) and base_fields:
                fields.update(base_fields)
        for key, value in cls.__dict__.items():
            if not isinstance(value, _Field):
                continue
            field_ref = value.ref_to
            if field_ref:
                base_field = fields[key]
                if base_field.is_ref:
                    raise ValueError("Ref cannot point to another ref")
                if base_field.default is not None:
                    raise ValueError(
                        f"Cannot create a ref for field {key} "
                        f"because it has a default value",
                    )
                # Setting ``ref_to`` on ``base_field`` itself would turn the
                # base class's field into a ref as well. A shallow copy keeps
                # the other attributes (presumably including ``uid``, which
                # ``__field_names__`` keys on).
                value = copy.copy(base_field)
                value.ref_to = field_ref
            if not value.name:
                value.name = key
            fields[key] = value
        return MappingProxyType(fields)

    @classmethod
    def _extract_field_types(
        cls,
        fields: Mapping[str, _Field],
    ) -> Mapping[str, type]:
        """Resolve the type hint of every field in ``fields``.

        Types already resolved on ``BaseEntity`` bases override the hints
        found here, and ref fields are wrapped in ``Optional``.

        Raises:
            NameError: if a hint references a name not defined yet
                (propagated from ``get_type_hints``).
            ValueError: if a hint is missing or conflicts with a base class.
        """
        cls_hints = get_type_hints(cls, localns={cls.__name__: cls})
        cls._validate_field_types_of_bases(cls_hints)
        cls._validate_field_types(fields, cls_hints)
        field_types: Dict[str, Any] = {key: cls_hints[key] for key in fields}
        # Base types may already be Optional-wrapped (refs, template values).
        for base in cls.__bases__:
            base_field_types = getattr(base, "__field_types__", None)
            if issubclass(base, BaseEntity) and base_field_types:
                field_types.update(base_field_types)
        for field_name, field_type in field_types.items():
            if fields[field_name].ref_to:
                field_types[field_name] = Optional[field_type]
        return MappingProxyType(field_types)

    @classmethod
    def _validate_field_types(
        cls,
        fields: Mapping[str, _Field],
        cls_hints: Mapping[str, Any],
    ) -> None:
        """Require a type hint per field and check defaults against it.

        Raises:
            ValueError: if a field has no type hint.
        """
        for key, field in fields.items():
            hint = cls_hints.get(key)
            if not hint:
                raise ValueError(
                    f"Type hint is missed for {cls.__name__}.{key}",
                )
            if field.default is not None:
                _cast_value_type(field.default, hint)

    @classmethod
    def _validate_field_types_of_bases(
        cls,
        cls_hints: Mapping[str, Any],
    ) -> None:
        """Ensure hints match those of bases that define the same ``_Field``.

        Only bases that still expose a ``_Field`` attribute under the name
        are checked; ``BaseEntity`` subclasses drop theirs in
        ``_remove_fields_from_cls_attrs``.

        Raises:
            ValueError: if a hint differs from the base class's hint.
        """
        for base in cls.__bases__:
            base_hints = get_type_hints(base)
            for key, cls_type_hint in cls_hints.items():
                if not isinstance(getattr(base, key, None), _Field):
                    continue
                base_type_hint = base_hints.get(key)
                # Compare by equality: equal typing constructs need not be
                # the same object (Python 3.14, for instance, stopped
                # caching unions), so an identity test can misfire.
                if base_type_hint and cls_type_hint != base_type_hint:
                    raise ValueError(
                        f"Type hint mismatch. _Field "
                        f"{cls.__name__}.{key} "
                        "should match type hint of a parent class "
                        f"{base.__name__}.{key}\n"
                        f"Expected: {base_type_hint}. "
                        f"Actual: {cls_type_hint}",
                    )

    @classmethod
    def _merge_ref_types(
        cls,
        *ref_types_seq: Mapping[uuid.UUID, Any],
    ) -> Mapping[uuid.UUID, Any]:
        """Combine ref-uid -> type mappings, folding types with ``_merge_types``."""
        all_refs: Dict[uuid.UUID, List[Any]] = defaultdict(list)
        for ref_types in ref_types_seq:
            for ref_uid, ref_type in ref_types.items():
                all_refs[ref_uid].append(ref_type)
        return MappingProxyType(
            {
                ref_uid: reduce(_merge_types, ref_types)
                for ref_uid, ref_types in all_refs.items()
            },
        )

    @classmethod
    def _extract_ref_types_from_obj(
        cls,
        obj: Any,
    ) -> Mapping[uuid.UUID, Any]:
        """Collect ref types from an entity or the entities in a list.

        Any other value (including entities nested deeper than one list)
        yields an empty mapping.
        """
        if isinstance(obj, BaseEntity):
            return obj._extract_all_ref_types()
        if isinstance(obj, list):
            return cls._merge_ref_types(
                *(
                    item._extract_all_ref_types()
                    for item in obj
                    if isinstance(item, BaseEntity)
                ),
            )
        return MappingProxyType({})

    def _extract_ref_types_from_values(
        self,
    ) -> Mapping[uuid.UUID, Any]:
        """Map referenced-field uid -> type for refs assigned on this instance.

        For every ref value this also calls ``apply_constraints`` on the
        referenced field with the constraints of the field holding the ref.
        """
        ref_types: Dict[uuid.UUID, Any] = {}
        for field_name, field_type in self.__field_types__.items():
            field_value = getattr(self, field_name, None)
            if not (isinstance(field_value, _Field) and field_value.ref_to):
                continue
            field = self.__fields__[field_name]
            field_value.ref_to.apply_constraints(field.constraints)
            ref_uid = field_value.ref_to.uid
            if ref_uid in ref_types:
                ref_types[ref_uid] = _merge_types(ref_types[ref_uid], field_type)
            else:
                ref_types[ref_uid] = field_type
        return MappingProxyType(ref_types)

    def _extract_all_ref_types(
        self,
    ) -> Mapping[uuid.UUID, Any]:
        """Merge this instance's ref types with those of nested entities."""
        ref_types_seq = [self._extract_ref_types_from_values()]
        for field_name in self.__fields__:
            field_value = getattr(self, field_name, None)
            ref_types_seq.append(self._extract_ref_types_from_obj(field_value))
        return self._merge_ref_types(*ref_types_seq)

    def related_templates(self) -> Set[Type[BaseDiv]]:
        """Return class-level related templates plus those from values."""
        return {
            *self.__related_templates__,
            *self._instance_related_templates,
        }

    def dict(self) -> Dict[str, Any]:
        """Serialize the instance into a plain dict.

        Keys are each field's ``field_name``; ``None`` values are omitted.
        A ref value is emitted as ``"$<field_name>": <referenced field_name>``;
        other values go through ``dump``.
        """
        result: Dict[str, Any] = {}
        for field_name, field in self.__fields__.items():
            field_value = getattr(self, field_name, field.default)
            if field_value is None:
                continue
            if isinstance(field_value, _Field) and field_value.ref_to:
                result[f"${field.field_name}"] = field_value.ref_to.field_name
            else:
                result[field.field_name] = dump(field_value)
        return result

    @classmethod
    def _can_add_field_to_schema(
        cls,
        field_name: str,
        field: _Field,
    ) -> bool:
        """Ref fields are left out of the schema."""
        return not field.ref_to

    @classmethod
    def _build_schema(
        cls,
        definitions: Dict[str, SchemaType],
        exclude: ExcludeFieldsType,
    ) -> SchemaType:
        """Build an ``object`` schema describing this class's fields.

        Fields rejected by ``_can_add_field_to_schema`` or excluded with
        ``True`` are skipped. Fields whose type is not a ``typing.Union``
        containing ``None`` are listed as required. ``definitions`` is
        passed through to ``_field_to_schema`` (presumably to collect
        nested definitions).
        """
        properties: Dict[str, Any] = {}
        required: List[str] = []
        for field_name, field in cls.__fields__.items():
            if (
                not cls._can_add_field_to_schema(field_name, field)
                or exclude.get(field_name) is True
            ):
                continue
            field_type = cls.__field_types__[field_name]
            properties[field_name] = _field_to_schema(
                field=field,
                type_=field_type,
                exclude=exclude.get(field_name, {}),
                definitions=definitions,
            )
            # NOTE(review): before Python 3.14, ``X | None`` hints have origin
            # ``types.UnionType`` and would be reported as required here.
            is_optional = get_origin(field_type) is Union and type(None) in get_args(
                field_type
            )
            if not is_optional:
                required.append(field_name)
        schema: SchemaType = {"type": "object"}
        if properties:
            schema["properties"] = properties
        if required:
            schema["required"] = required
        return schema

    @classmethod
    def schema_as_ref(
        cls,
        definitions: Dict[str, SchemaType],
        exclude: ExcludeFieldsType,
    ) -> SchemaType:
        """Return a ``$ref`` to this class's schema, registering it if absent.

        With a non-empty ``exclude``, the definition name gets the MD5 of the
        JSON-serialized exclusions appended, so differently filtered schemas
        of the same class get distinct names.
        """
        schema_name = cls.__name__
        if exclude:
            serialized_exclude = json.dumps(exclude, sort_keys=True)
            exclude_hash = hashlib.md5(serialized_exclude.encode()).hexdigest()
            schema_name += f"_{exclude_hash}"
        if schema_name not in definitions:
            # Placeholder first, so a recursive reference to this class
            # finds the name and stops instead of recursing forever.
            definitions[schema_name] = {}
            definitions[schema_name] = cls._build_schema(definitions, exclude)
        return {"$ref": f"#/definitions/{schema_name}"}
class BaseDiv(BaseEntity):
    """Base class for DivKit components and templates.

    A subclass of a component whose ``type`` field has a default becomes a
    template: its ``type`` default is replaced with its template name, and
    :meth:`template` builds its definition. Class attributes named after
    fields (that are not ``_Field`` objects) are extracted as template values.
    """

    # Field name -> value fixed at class level, cast to the field type.
    __tpl_values__: Mapping[str, Any]
    # NOTE(review): not referenced anywhere in this module.
    __refs__: Mapping[str, Set[Optional[str]]]
    # Names of ref fields declared directly in this class body.
    __local_referred_fields__: FrozenSet[str]
    # Parent component's ``type`` default; None unless this is a template.
    __base_type__: Optional[str] = None
    # Cached result of ``_build_template``.
    __template__: Optional[Dict[str, Any]] = None

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
        """Compute template metadata, then run the base preparation.

        Local refs are read before ``BaseEntity.__init_subclass__`` removes
        ``_Field`` attributes from the namespace, and the ``type`` field is
        injected before fields are collected.
        """
        cls.__base_type__ = cls._get_base_type()
        cls.__local_referred_fields__ = cls._extract_local_referred_fields()
        cls.__template__ = None
        cls._inject_type_field()
        super().__init_subclass__(*args, **kwargs)

    @classmethod
    def _extract_local_referred_fields(cls) -> FrozenSet[str]:
        """Return names of ref fields declared in this class's own body."""
        return frozenset(
            field_name
            for field_name, field in cls.__dict__.items()
            if isinstance(field, _Field) and field.ref_to
        )

    def __setattr__(self, key: str, value: Any) -> None:
        """Ignore ``None`` for fields fixed by template values.

        Every other assignment goes through ``BaseEntity.__setattr__``.
        """
        if value is None and key in self.__tpl_values__:
            return
        super().__setattr__(key, value)

    @classmethod
    def update_forward_refs(cls) -> None:
        """Resolve field types and derive template data from them.

        Besides resolving types, this extracts template values, makes their
        fields ``Optional``, recomputes related templates and validates ref
        types. It may be called repeatedly:

        * template values from an earlier call are kept (their class
          attributes were deleted on extraction, so they cannot be
          re-extracted);
        * ref validation is skipped while forward refs are unresolved,
          because the field types it looks up are not known yet.
        """
        super().update_forward_refs()
        previous_tpl_values: Mapping[str, Any] = cls.__dict__.get(
            "__tpl_values__",
            {},
        )
        cls.__tpl_values__ = MappingProxyType(
            {
                **previous_tpl_values,
                **cls._extract_tpl_values(cls.__field_types__),
            },
        )
        cls.__field_types__ = cls._make_fields_optional_from_tpl_values(
            cls.__field_types__,
            cls.__tpl_values__,
        )
        cls.__related_templates__ = cls._extract_related_templates(
            cls.__tpl_values__,
        )
        # Template values may have changed; rebuild the template lazily.
        cls.__template__ = None
        # Same "not prepared" condition that BaseEntity.__init__ checks.
        forward_refs_resolved = bool(cls.__field_types__) or not cls.__fields__
        if forward_refs_resolved:
            cls._validate_ref_types()

    @classmethod
    def _make_fields_optional_from_tpl_values(
        cls,
        field_types: Mapping[str, type],
        tpl_values: Mapping[str, Any],
    ) -> Mapping[str, type]:
        """Wrap the types of fields that have template values in ``Optional``."""
        new_field_types: Dict[str, Any] = {**field_types}
        for field_name in tpl_values:
            new_field_types[field_name] = Optional[new_field_types[field_name]]
        return MappingProxyType(new_field_types)

    @classmethod
    def _extract_related_templates(
        cls,
        tpl_values: Mapping[str, Any],
    ) -> FrozenSet[Type[BaseDiv]]:
        """Collect templates this class depends on.

        Includes the class itself when it is a template, the related
        templates of ``BaseEntity`` bases, and those found in ``tpl_values``.
        """
        related_templates: Set[Type[BaseDiv]] = set()
        if cls.__base_type__:
            related_templates.add(cls)
        for base_cls in cls.__bases__:
            if issubclass(base_cls, BaseEntity):
                related_templates.update(base_cls.__related_templates__)
        for tpl_value in tpl_values.values():
            _update_related_templates(tpl_value, related_templates)
        return frozenset(related_templates)

    @classmethod
    def _extract_tpl_values(
        cls,
        field_types: Mapping[str, Any],
    ) -> Mapping[str, Any]:
        """Move non-``None`` class attributes named after fields into a mapping.

        Each value is cast to its field type and the class attribute is
        deleted; attributes whose value is ``None`` are left in place.
        """
        tpl_values: Dict[str, Any] = {}
        for name, field_type in field_types.items():
            if not hasattr(cls, name):
                continue
            value = getattr(cls, name)
            if value is not None:
                tpl_values[name] = _cast_value_type(value, field_type)
                delattr(cls, name)
        return MappingProxyType(tpl_values)

    @classmethod
    def _extract_ref_types_from_tpl_values(
        cls,
    ) -> Mapping[uuid.UUID, Any]:
        """Merge the ref types found inside this class's template values."""
        return cls._merge_ref_types(
            *(
                cls._extract_ref_types_from_obj(tpl_value)
                for tpl_value in cls.__tpl_values__.values()
            ),
        )

    @classmethod
    def _extract_ref_types_from_fields(
        cls,
    ) -> Mapping[uuid.UUID, Any]:
        """Map referenced-field uid -> type for the ref fields of this class.

        For every ref field this also calls ``apply_constraints`` on the
        referenced field with the ref field's constraints.
        """
        ref_types: Dict[uuid.UUID, Any] = {}
        for field_name, field in cls.__fields__.items():
            if not field.ref_to:
                continue
            field.ref_to.apply_constraints(field.constraints)
            ref_uid = field.ref_to.uid
            field_type = cls.__field_types__[field_name]
            if ref_uid in ref_types:
                ref_types[ref_uid] = _merge_types(ref_types[ref_uid], field_type)
            else:
                ref_types[ref_uid] = field_type
        return MappingProxyType(ref_types)

    @staticmethod
    def _make_union(type: Type[Any]) -> Set[Type[Any]]:
        """Return the members of a ``typing.Union``, or ``{type}`` otherwise."""
        if get_origin(type) is Union:
            return set(get_args(type))
        return {type}

    @classmethod
    def _validate_subclass(
        cls,
        ref_type: Type[Any],
        expected_field_type: Type[Any],
    ) -> bool:
        """Loose compatibility check between a ref type and a field type.

        NOTE(review): ``ref_args``/``expect_args`` are already argument
        tuples and ``get_args`` of a tuple is ``()``, so the first equality
        test below is always true. As written, this returns ``True`` as soon
        as some member of ``ref_type``'s union and some member of
        ``expected_field_type``'s union share the same ``get_origin`` (e.g.
        any two plain classes), and the ``issubclass`` branch is never
        reached. Fixing it would make ``_validate_ref_types`` stricter, and
        ``issubclass`` can raise on non-class arguments, so the logic is
        left unchanged here.
        """
        def _check_origins() -> Iterator[Tuple[Tuple[Any, ...], Tuple[Any, ...]]]:
            for ref in cls._make_union(ref_type):
                for expect in cls._make_union(expected_field_type):
                    if get_origin(expect) == get_origin(ref):
                        yield get_args(ref), get_args(expect)
            return None

        for origins in _check_origins():
            if origins is None:
                return False
            ref_args, expect_args = origins
            if get_args(ref_args) == get_args(expect_args):
                return True
            if not any(
                issubclass(exp, get_args(expect_args))
                for exp in get_args(ref_args)
                if get_origin(exp) is not Union
            ):
                continue
            return True
        return False

    @classmethod
    def _validate_ref_types(cls) -> None:
        """Check every referenced field's type against the types of its refs.

        A ref type is accepted if it equals the field type, its ``Optional``,
        ``Union[type, Expr]`` (with or without ``None``), or passes
        ``_validate_subclass``.

        Raises:
            TypeError: on a mismatch.
            KeyError: presumably when a ref points to a field this class does
                not have (``__field_names__`` lookup).
        """
        ref_types = cls._merge_ref_types(
            cls._extract_ref_types_from_fields(),
            cls._extract_ref_types_from_tpl_values(),
        )
        for ref_uid, ref_type in ref_types.items():
            field_name = cls.__field_names__[ref_uid]
            expected_field_type = cls.__field_types__[field_name]
            if (
                ref_type != expected_field_type
                and ref_type != Optional[expected_field_type]
                and ref_type != Union[expected_field_type, Expr]
                and ref_type != Union[expected_field_type, Expr, None]
                and not cls._validate_subclass(ref_type, expected_field_type)
            ):
                raise TypeError(
                    f"Type of attribute '{field_name}' does "
                    f"not match ref type "
                    f"{expected_field_type} != {ref_type}",
                )

    @classmethod
    def _get_base_type(cls) -> Optional[str]:
        """Return the parent's ``type`` default, or None for non-templates.

        Returns None when the parent is ``BaseDiv`` itself or its ``type``
        field has no default.

        Raises:
            TypeError: if the class has several bases or its base is not a
                ``BaseDiv`` subclass.
        """
        if len(cls.__bases__) > 1:
            raise TypeError(
                "Types conflict: base class cannot be uniquely identified",
            )
        base_cls, *_ = cls.__bases__
        if not issubclass(base_cls, BaseDiv):
            raise TypeError(
                "Types conflict: class must be derived from the BaseDiv class",
            )
        type_field = base_cls.__fields__.get(TYPE_FIELD)
        if (base_cls is BaseDiv) or not (type_field and type_field.default):
            return None
        return type_field.default

    @classmethod
    def _inject_type_field(cls) -> None:
        """For templates, add a ``type`` field defaulting to the template name."""
        if cls.__base_type__:
            type_value = cls.template_name
            setattr(cls, TYPE_FIELD, _Field(default=type_value))

    @classmethod
    def template(cls) -> Dict[str, Any]:
        """Return this template's definition, building and caching it once.

        Raises:
            TypeError: if the class is not a template (no base type).
        """
        if not cls.__base_type__:
            raise TypeError(f"Component {cls.__name__} is not a template")
        if cls.__template__ is None:
            cls.__template__ = cls._build_template()
        return cls.__template__

    @classmethod
    def _build_template(cls) -> Dict[str, Any]:
        """Build the template dict.

        It contains the parent type, the dumped template values, and a
        ``"$<field_name>": <referenced field_name>`` entry for each ref
        declared in this class's own body (inherited refs are skipped).
        """
        template = {TYPE_FIELD: cls.__base_type__}
        for field_name, tpl_field_value in cls.__tpl_values__.items():
            field = cls.__fields__[field_name]
            template[field.field_name] = dump(tpl_field_value)
        for field_name, field in cls.__fields__.items():
            if field.ref_to and field_name in cls.__local_referred_fields__:
                template[f"${field.field_name}"] = field.ref_to.field_name
        return template

    @classmethod
    def _can_add_field_to_schema(
        cls,
        field_name: str,
        field: _Field,
    ) -> bool:
        """Also leave out fields fixed by template values."""
        return (
            super()._can_add_field_to_schema(field_name, field)
            and field_name not in cls.__tpl_values__
        )

    @classmethod
    def schema(
        cls,
        exclude_fields: Optional[Sequence[str]] = None,
    ) -> SchemaType:
        """Return this class's schema with all definitions under ``"definitions"``.

        ``exclude_fields`` is converted with ``_make_exclude_fields`` before
        building.
        """
        definitions: Dict[str, SchemaType] = {}
        exclude = _make_exclude_fields(exclude_fields)
        schema = cls._build_schema(definitions, exclude)
        schema["definitions"] = definitions
        return schema