# Source code for caikit.core.data_model.base

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Base classes and functionality for all data structures.
"""
# Standard
from dataclasses import dataclass
from enum import Enum
from io import IOBase
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    NoReturn,
    Optional,
    Tuple,
    Type,
    Union,
    get_type_hints,
)
import base64
import datetime
import json
import types

# Third Party
from google.protobuf import json_format
from google.protobuf.descriptor import (
    Descriptor,
    EnumDescriptor,
    FieldDescriptor,
    OneofDescriptor,
)
from google.protobuf.internal import type_checkers as proto_type_checkers
from google.protobuf.message import Message as ProtoMessageType

# First Party
from py_to_proto.compat_annotated import Annotated, get_args, get_origin
import alog

# Local
from ..exceptions import error_handler
from . import enums, json_dict, timestamp

# if TYPE_CHECKING: # TODO: uncommenting this breaks `tox -e imports` because of a circular import
#     # Local
#     from caikit.core.data_model.data_backends import DataModelBackendBase
#     from caikit.interfaces.common.data_model.file import File

# metaclass-generated field members cannot be detected by pylint
# pylint: disable=no-member
# pylint: disable=too-many-lines


# Log channel shared by everything in this module
log = alog.use_channel("DATAM")
# Error helper bound to ``log``. Calling ``error(code, exception)`` directly is
# annotated NoReturn (it raises); helpers on it such as ``error.value_check``
# are also used below for conditional checks.
error: Callable[..., NoReturn] = error_handler.get(log)


class _DataBaseMetaClass(type):
    """Meta class for all structures in the data model.

    Every class built with this metaclass (i.e. every subclass of DataBase)
    gets its field bookkeeping derived from a protobuf descriptor, a
    ``__slots__`` tuple built from those fields, property getters that can
    fall back to a data-model backend, and (when it does not define its own)
    a generated ``__init__``.
    """

    fields: Tuple
    full_name: str
    fields_enum_map: Dict  # {}
    fields_enum_rev: Dict  # {}
    _fields_oneofs_map: Dict  # {}
    _fields_to_oneof: Dict  # {}
    _fields_to_type: Dict  # {}
    _fields_map: Tuple  # ()
    _fields_message: Tuple  # ()
    _fields_message_repeated: Tuple  # ()
    _fields_enum: Tuple  # ()
    _fields_enum_repeated: Tuple  # ()
    _fields_primitive: Tuple  # ()
    _fields_primitive_repeated: Tuple  # ()
    _proto_class: ClassVar[Type[ProtoMessageType]]

    # store a registry of all classes that use this metaclass, i.e.,
    # all classes that extend DataBase. This is used for constructing new
    # instances by name without having to introspect all modules in data_model.
    class_registry = {}

    # This sentinel value is used to determine whether a given attribute is
    # present on a class without doing `getattr` twice in the case where the
    # attribute does exist.
    _MISSING_ATTRIBUTE = "missing attribute"

    # Special attribute used to communicate that the proto fields are forward
    # declared and will be populated after the metaclass has completed
    # construction.
    _FWD_DECL_FIELDS = "__fwd_decl_fields__"

    # Special instance attributes that an instance of a class derived from
    # DataBase may have. These are added to __slots__.
    _BACKEND_ATTR = "_backend"
    _WHICH_ONEOF_ATTR = "_which_oneof"

    # Special attribute used to indicate which defaults are user provided
    _USER_DEFINED_DEFAULTS = "__user_defined_defaults__"

    # When inferring which field in a oneof a given value should be used for
    # based on the python type, we need to check types in order with bool
    # first, ints next, then floats so that values that fit a "more flexible"
    # type don't accidentally get assigned to the wrong field. This is the
    # bool type followed by every integer type in protobuf. The FIXED/SFIXED
    # encodings are integers too, but their names do not contain "INT", so
    # they are matched explicitly.
    _PROTO_TYPE_ORDER = [FieldDescriptor.TYPE_BOOL] + [
        val
        for name, val in vars(FieldDescriptor).items()
        if name.startswith("TYPE_") and ("INT" in name or "FIXED" in name)
    ]

    # Add property to track if a class supports exporting and importing via a
    # file operation
    supports_file_operations = False

    def __new__(mcs, name, bases, attrs):
        """When constructing a new data model class, we set the 'fields' class
        variable from the protobufs descriptor and then set the '__slots__'
        magic class attribute to fields. This provides two benefits:
            (a) performance is improved since the classes only need to know
                about these attributes
            (b) it helps to enforce that all member variables in these classes
                are described in the protobufs.

        Note: If you want to add a variable for internal use that is not
        described in the protobufs, it can be named in the tuple class
        variable _private_slots and will automatically be added to __slots__.
        """
        # Protobufs fields can be divided into these categories, which are used
        # to automatically determine appropriate behavior in a number of methods
        attrs["full_name"] = name
        attrs["fields_enum_map"] = {}
        attrs["fields_enum_rev"] = {}
        attrs["_fields_oneofs_map"] = {}
        attrs["_fields_to_oneof"] = {}
        attrs["_fields_to_type"] = {}
        attrs["_fields_map"] = ()
        attrs["_fields_message"] = ()
        attrs["_fields_message_repeated"] = ()
        attrs["_fields_enum"] = ()
        attrs["_fields_enum_repeated"] = ()
        attrs["_fields_primitive"] = ()
        attrs["_fields_primitive_repeated"] = ()

        # Look for the set of fields either from a predefined protobuf class or
        # from a forward declaration from @dataobject
        fields = ()
        proto_class = None
        if name not in ["DataBase", "DataObjectBase"]:
            # Look for a precompiled proto class and if found, parse its
            # descriptor
            proto_class = attrs.get("_proto_class")
            if proto_class is not None:
                descriptor = proto_class.DESCRIPTOR
                # Fields that live inside a oneof are stored under the oneof's
                # name, so they are replaced by the oneof names below
                all_oneof_fields = {
                    field.name for oneof in descriptor.oneofs for field in oneof.fields
                }
                fields = tuple(
                    field
                    for field in descriptor.fields_by_name
                    if field not in all_oneof_fields
                ) + tuple(descriptor.oneofs_by_name)

            # Otherwise, we need to get the fields from a "special" attribute
            else:
                fields = attrs.pop(mcs._FWD_DECL_FIELDS, None)
                log.debug4(  # type: ignore
                    "Using dataclass forward declaration fields %s for %s",
                    fields,
                    name,
                )
                error.value_check(
                    "<COR49310991E>",
                    fields is not None,
                    "No proto class found for {}",
                    name,
                )

        attrs["fields"] = fields
        attrs["_proto_class"] = proto_class

        # Look if any private slots are declared as class variables
        private_slots = attrs.setdefault("_private_slots", ())

        # Class slots are fields + private slots, this prevents other
        # member attributes from being set and also improves performance
        attrs["__slots__"] = tuple(
            [f"_{field}" for field in fields]
            + list(private_slots)
            + [mcs._BACKEND_ATTR, mcs._WHICH_ONEOF_ATTR]
        )

        # Create the instance of the type
        instance = super().__new__(mcs, name, bases, attrs)

        # If there's a valid proto class, perform proto descriptor parsing
        if proto_class is not None:
            mcs.parse_proto_descriptor(instance)

        # Return the constructed class instance
        return instance

    @classmethod
    def parse_proto_descriptor(mcs, cls):  # pyright: ignore[reportSelfClsParameterName]
        """Encapsulate the logic for parsing the protobuf descriptor here. This
        allows the parsing to be done as a post-process after metaclass
        initialization.

        Populates the field-category attributes of ``cls``, registers it in
        ``class_registry`` under its fully-qualified proto name, installs the
        field property getters, generates ``__init__`` when needed and sets
        ``supports_file_operations``.
        """
        descriptor = cls._proto_class.DESCRIPTOR

        # use the fully qualified protobuf name to avoid conflicts with
        # nested messages that have matching names
        cls.full_name = descriptor.full_name

        # preserve old fields for _make_property_getter later
        old_fields = cls.fields

        # overwrite to only have proto-specific fields present
        cls.fields = tuple(descriptor.fields_by_name)

        # map from all enum fields to their enum classes
        # note: enums are also primitives, these overlap
        cls.fields_enum_map = {
            field.name: getattr(enums, field.enum_type.name)
            for field in descriptor.fields
            if field.enum_type is not None
        }
        cls.fields_enum_rev = {
            field.name: getattr(enums, field.enum_type.name + "Rev")
            for field in descriptor.fields
            if field.enum_type is not None
        }

        # mapping of all oneofs and the fields that are part of them
        # NOTE: protobuf makes an interesting use of oneof to wrap types that
        #   should be explicitly optional. We don't want to consider these
        #   oneofs in the general oneof handling.
        # Sort the names of the fields in this map to ensure that ordering is
        # correct such that bool < int < float
        cls._fields_oneofs_map = {
            oneof_name: mcs._sorted_oneof_field_names(oneof)
            for oneof_name, oneof in descriptor.oneofs_by_name.items()
            if len(oneof.fields) != 1 or oneof.name != f"_{oneof.fields[0].name}"
        }
        cls._fields_to_oneof = {
            field_name: oneof_name
            for (oneof_name, oneof_fields) in cls._fields_oneofs_map.items()
            for field_name in oneof_fields
        }

        # all repeated fields
        fields_repeated = tuple(
            field.name
            for field in descriptor.fields
            if field.label == field.LABEL_REPEATED
        )

        # all messages, repeated or not
        _fields_message_all = tuple(
            field.name
            for field in descriptor.fields
            if field.type == field.TYPE_MESSAGE
        )

        # all enums, repeated or not
        _fields_enum_all = tuple(
            field.name for field in descriptor.fields if field.enum_type is not None
        )

        # all fields of type map
        cls._fields_map = tuple(
            field.name
            for field in descriptor.fields
            if field.message_type and field.message_type.GetOptions().map_entry
        )

        # all primitives, repeated or not
        _fields_primitive_all = (
            frozenset(cls.fields)
            .difference(cls._fields_map)
            .difference(_fields_message_all)
            .difference(_fields_enum_all)
        )

        # messages that are not repeated
        cls._fields_message = frozenset(_fields_message_all).difference(fields_repeated)

        # messages that are repeated
        cls._fields_message_repeated = frozenset(fields_repeated).intersection(
            _fields_message_all
        )

        # enums that are not repeated
        cls._fields_enum = frozenset(_fields_enum_all).difference(fields_repeated)

        # enums that are repeated
        cls._fields_enum_repeated = frozenset(_fields_enum_all).intersection(
            fields_repeated
        )

        # primitives that are not repeated
        cls._fields_primitive = _fields_primitive_all.difference(fields_repeated)

        # primitives that are repeated
        cls._fields_primitive_repeated = frozenset(fields_repeated).intersection(
            _fields_primitive_all
        )

        # Update the global class and proto registries
        # NOTE: Explicitly not respecting metaclass inheritance so single
        #   registry shared for all
        _DataBaseMetaClass.class_registry[cls.full_name] = cls

        # Add properties that use the underlying backend. Also add fields that
        # existed in old_fields for supporting oneofs
        # see https://github.com/caikit/caikit/pull/107 for details
        for field in set(cls.fields + tuple(old_fields)):
            # If the field is the name of a field within a oneof and it was not
            # in the old fields, the data is held under the oneof's name if this
            # is the set value for the oneof
            if oneof_name := cls._fields_to_oneof.get(field):
                setattr(cls, field, mcs._make_property_getter(field, oneof_name))

            # If the field is a plain field or the name of a oneof, it will be
            # accessed directly
            else:
                setattr(cls, field, mcs._make_property_getter(field))

        # If there is not already an __init__ function defined, make one
        current_init = cls.__init__
        if current_init is None or current_init is DataBase.__init__:
            cls.__init__ = mcs._make_init(cls.fields)

        def _underlying_function(method):
            # A classmethod looked up on a class yields a bound method whose
            # __self__ is that class, so bound methods from different classes
            # never compare equal even when they wrap the same function.
            # Compare the wrapped function instead (plain functions pass
            # through unchanged).
            return getattr(method, "__func__", method)

        # The class supports file operations only if it overrides both of the
        # DataBase file handlers
        cls.supports_file_operations = _underlying_function(
            cls.to_file
        ) is not _underlying_function(DataBase.to_file) and _underlying_function(
            cls.from_file
        ) is not _underlying_function(
            DataBase.from_file
        )

    @classmethod
    def _make_property_getter(
        mcs, field, oneof_name=None  # pyright: ignore[reportSelfClsParameterName]
    ):
        """This helper creates an @property attribute getter for the given field

        NOTE: This needs to live as a standalone function in order for the given
            field name to be properly bound to the closure for the attrs

        Args:
            field (str): Name of the field the property exposes
            oneof_name (Optional[str]): If the field is a member of a oneof, the
                name of that oneof. The value is then read through the oneof
                and only returned when this field is the one set.

        Returns:
            property: The property object to install on the class
        """
        private_name = f"_{field}" if oneof_name is None else oneof_name

        def _property_getter(self):
            # Check to see if the private name is defined and just return it if
            # it is
            current = getattr(self, private_name, mcs._MISSING_ATTRIBUTE)
            if current is not mcs._MISSING_ATTRIBUTE:
                return current

            # If not currently set, delegate to the backend
            backend = self.backend
            if backend is None:
                error(
                    "<COR66616239E>",
                    AttributeError(
                        f"{type(self)} missing attribute {field} and no backend set"
                    ),
                )
            cls = self.__class__
            attr_val = backend.get_attribute(cls, field)
            oneof_field_val = None
            if isinstance(attr_val, cls.OneofFieldVal):
                log.debug2("Got a OneofFieldVal from the backend")  # type: ignore
                assert field in cls._fields_oneofs_map
                oneof_field_val = attr_val
                attr_val = attr_val.val

            # If the backend says that this attribute should be cached, set it
            # as an attribute on the class
            if backend.cache_attribute(field, attr_val):
                setattr(self, field, attr_val)

            # Record which oneof field is set only AFTER caching: setting a
            # oneof by its own name goes through DataBase.__setattr__, which
            # clears any which_oneof entry for that oneof.
            if oneof_field_val is not None:
                self._get_which_oneof_dict()[field] = oneof_field_val.which_oneof

            # Return the value found by the backend
            return attr_val

        # If this is a oneof, add an extra layer of wrapping to check
        # which_oneof before returning a valid result
        if oneof_name:

            def _oneof_property_getter(self):
                if self.which_oneof(oneof_name) == field:
                    return _property_getter(self)
                # A different field (or none) of the oneof is set
                return None

            return property(_oneof_property_getter)
        return property(_property_getter)

    @staticmethod
    def _make_init(fields):
        """This helper creates an __init__ function for a class which has the
        arguments for all the fields and just sets them as instance attributes.

        Args:
            fields (Tuple[str]): Ordered field names; positional arguments are
                assigned in this order

        Returns:
            Callable: The generated __init__ function
        """

        # Format and preserve docstring
        docstring = """Construct with arguments for each field on the object

        Args:
            {}
        """.format(
            "\n            ".join(fields)
        )

        def __init__(self, *args, **kwargs):
            num_args = len(args)
            num_kwargs = len(kwargs)
            num_fields = len(fields)
            used_fields = set()

            # If the proto has oneofs, set up which_oneof
            which_oneof = {}
            cls = self.__class__
            if cls._fields_oneofs_map:
                setattr(self, _DataBaseMetaClass._WHICH_ONEOF_ATTR, which_oneof)

            if num_args + num_kwargs > num_fields:
                error(
                    "<COR71444420E>",
                    TypeError(f"Too many arguments given. Args are: {fields}"),
                )

            # Positional args map onto fields in declaration order; the check
            # above guarantees there are no more args than fields
            for field_name, field_val in zip(fields, args):
                setattr(self, field_name, field_val)
                used_fields.add(field_name)

            for field_name, field_val in kwargs.items():
                # If this is a oneof field, alias to the oneof name
                oneof_name = cls._fields_to_oneof.get(field_name)
                target_name = oneof_name or field_name
                if (
                    target_name not in fields
                    and target_name not in cls._fields_oneofs_map
                ):
                    error("<COR71444421E>", TypeError(f"Unknown field {target_name}"))
                elif target_name in used_fields:
                    error(
                        "<COR71444422E>",
                        TypeError(f"Got multiple values for field {target_name}"),
                    )
                setattr(self, target_name, field_val)
                used_fields.add(target_name)

                # Record the chosen oneof field only after the setattr above:
                # setting a oneof by its own name goes through
                # DataBase.__setattr__, which clears its which_oneof entry
                if oneof_name:
                    which_oneof[oneof_name] = field_name

            # Default all unspecified fields to their User specified defaults or None
            default_values = self.get_field_defaults()
            for field_name in fields:
                if field_name in used_fields or field_name in cls._fields_to_oneof:
                    continue
                default_value = default_values.get(field_name)
                # Callable defaults act as factories (e.g. for mutable values)
                if default_value and callable(default_value):
                    default_value = default_value()
                setattr(self, field_name, default_value)

            # Add type information for all fields. Do this during init to
            # allow for forward refs to be imported. Types already resolved
            # are kept so get_type_hints is not re-run on every construction.
            for field in cls.fields:
                if cls._fields_to_type.get(field) is None:
                    cls._fields_to_type[field] = cls._get_type_for_field(field)

        # Set docstring to the method explicitly
        __init__.__doc__ = docstring
        return __init__

    @classmethod
    def _sorted_oneof_field_names(
        mcs, oneof: OneofDescriptor  # pyright: ignore[reportSelfClsParameterName]
    ) -> List[str]:
        """Helper to get the list of oneof fields while ensuring field names are
        sorted such that bool < int < float. This ensures that when iterating
        fields for which_oneof inference, lower-precedence types take
        precedence.

        Fields whose type is not in _PROTO_TYPE_ORDER sort last, keeping their
        declaration order (sorted is stable).
        """
        return [
            field.name
            for field in sorted(
                oneof.fields,
                key=lambda fld: mcs._PROTO_TYPE_ORDER.index(fld.type)
                if fld.type in mcs._PROTO_TYPE_ORDER
                else len(mcs._PROTO_TYPE_ORDER),
            )
        ]
[docs] class DataBase(metaclass=_DataBaseMetaClass): """Base class for all structures in the data model. Notes: All leaves in the hierarchy of derived classes should have a corresponding protobufs class defined in the interface definitions. If not, an exception will be thrown at runtime. """ # Class constant used to identify protobuf types that are handled with # special logic in the to/from proto conversions PROTO_CONVERSION_SPECIAL_TYPES = [ timestamp.TIMESTAMP_PROTOBUF_NAME, json_dict.STRUCT_PROTOBUF_NAME, ]
[docs] @dataclass class OneofFieldVal: """Helper struct that backends can use to return information about values in oneofs along with which of the oneofs is currently valid """ val: Any which_oneof: str
[docs] def __setattr__(self, name, val): """Handle attribute setting for oneofs and named fields with delegation to backends as needed """ # If setting a oneof directly, remove any oneof information cls = self.__class__ if name in cls._fields_oneofs_map: self._get_which_oneof_dict().pop(name, None) # If this is the name of a oneof field, set the oneof itself if oneof_name := cls._fields_to_oneof.get(name): self._get_which_oneof_dict()[oneof_name] = name name = oneof_name # If attempting to set one of the named fields or a oneof, instead set # the private version of the attribute. if name in cls.fields or name in cls._fields_oneofs_map: super().__setattr__(f"_{name}", val) else: super().__setattr__(name, val)
[docs] @classmethod def get_proto_class(cls) -> Type[ProtoMessageType]: return cls._proto_class
[docs] @classmethod def get_field_defaults(cls) -> Type[ProtoMessageType]: """Get mapping of fields to default values. Mapping will not include fields without defaults""" return getattr(cls, _DataBaseMetaClass._USER_DEFINED_DEFAULTS, {})
[docs] @classmethod def get_field_message_type(cls, field_name: str) -> Optional[type]: """Get the python type for the given field. This function relies on the metaclass to fill cls._fields_to_type. This is to avoid costly computation during runtime Args: field_name (str): Field name to check (AttributeError raised if name is invalid) Returns: field_type: type The data model class type for the given field """ # Dataclass look ups are fast so keep them in to retain interface compatibility if field_name not in cls.fields: raise AttributeError(f"Invalid field {field_name}") # If field_name has not been cached then perform lookup and # save result if field_name not in cls._fields_to_type: cls._fields_to_type[field_name] = cls._get_type_for_field(field_name) return cls._fields_to_type.get(field_name)
[docs] @classmethod def from_backend(cls, backend): instance = cls.__new__(cls) setattr(instance, _DataBaseMetaClass._BACKEND_ATTR, backend) return instance
@property
def backend(self) -> Optional["DataModelBackendBase"]:  # type: ignore # noqa: F821 # see TYPE_CHECKING note at the top
    """The data model backend attached to this instance.

    This is None if no ``_backend`` attribute has been set.
    """
    backend_attr_name = _DataBaseMetaClass._BACKEND_ATTR
    return getattr(self, backend_attr_name, None)
def which_oneof(self, oneof_name: str) -> Optional[str]:
    """Return the name of the field currently set within a oneof.

    Resolution order:
      1. The per-instance "which oneof" cache.
      2. The same cache again after reading the oneof's value. Reading the
         value may populate the cache with knowledge from the backend.
      3. Inference from the python type of the oneof's value. A successful
         inference is stored in the cache.

    Args:
        oneof_name (str): Name of the oneof to inspect

    Returns:
        Optional[str]: The active field name, or None if no field is set or
            none can be inferred
    """
    cache = self._get_which_oneof_dict()
    known_field = cache.get(oneof_name)
    if known_field:
        return known_field

    current_value = getattr(self, oneof_name)

    # Reading the value may have filled the cache, so check again
    known_field = cache.get(oneof_name)
    if known_field:
        return known_field

    inferred_field = self._infer_which_oneof(oneof_name, current_value)
    if inferred_field is not None:
        cache[oneof_name] = inferred_field
    return inferred_field
@classmethod
def _infer_which_oneof(cls, oneof_name: str, oneof_val: Any) -> Optional[str]:
    """Find the first field of a oneof whose type accepts ``oneof_val``.

    NOTE: If several fields in the oneof share a type, the first matching
    field wins. The field list is ordered so that bool fields come before
    int fields, and int fields before float fields. This keeps a value from
    matching a "more flexible" field than it should.

    Args:
        oneof_name (str): Name of the oneof
        oneof_val (Any): Value whose field should be inferred

    Returns:
        Optional[str]: The first matching field name, or None
    """
    candidate_fields = cls._fields_oneofs_map.get(oneof_name, [])
    return next(
        (
            candidate
            for candidate in candidate_fields
            if cls._is_valid_type_for_field(candidate, oneof_val)
        ),
        None,
    )
def _get_which_oneof_dict(self) -> Dict[str, str]:
    """Return this instance's mutable mapping of oneof name -> set field name.

    The mapping is created empty on first access.
    """
    attr_name = _DataBaseMetaClass._WHICH_ONEOF_ATTR
    existing = getattr(self, attr_name, None)
    if existing is not None:
        return existing
    # Go through the parent __setattr__ so this class's own __setattr__
    # (which itself calls this method) is not re-entered
    super().__setattr__(attr_name, {})
    return getattr(self, attr_name)
@classmethod
def _get_type_for_field(cls, field_name: str) -> type:
    """Resolve the python type for a field of this data model class.

    The class's type hints are consulted first. If the field has a truthy
    hint, one layer of ``Optional[X]`` (a Union of exactly ``(X, None)``) or
    ``List[X]`` / ``list[X]`` is unwrapped. Then one layer of
    ``Annotated[X, ...]`` is unwrapped, and the result is returned.

    Without a hint, the type is derived from the field's protobuf
    descriptor:
      - message and enum fields map to their registered data model class
      - scalar fields map to bool / bytes / str / int / float

    Args:
        field_name (str): Name of the field to resolve

    Returns:
        type: The python type (or data model class) for the field

    Raises:
        ValueError: If the field is not in the proto descriptor, or its
            proto type is not one of the handled types
    """
    # NOTE(review): typing.get_type_hints strips Annotated wrappers unless
    # include_extras=True is passed, so the Annotated branch below may never
    # fire for typing.Annotated hints -- confirm.
    cls_type_hints = get_type_hints(cls)
    if type_hint := cls_type_hints.get(field_name):
        # If type is optional or a list then return internal type
        type_args = get_args(type_hint)
        # `and` binds tighter than `or`: this reads as
        # (Union whose args are exactly (X, NoneType)) or (list origin).
        # NOTE(review): Union[None, X] (None first) is not unwrapped, and a
        # bare `List` hint has no args, so type_args[0] would raise IndexError.
        if (
            get_origin(type_hint) == Union
            and type_args
            == (
                type_args[0],
                type(None),
            )
            or get_origin(type_hint) in [list, List]
        ):
            type_hint = type_args[0]
        # If type is Annotated then get the actual type
        if get_origin(type_hint) == Annotated:
            type_hint = get_args(type_hint)[0]
        return type_hint

    # No type hint: fall back to the protobuf field descriptor
    fd = cls._proto_class.DESCRIPTOR.fields_by_name.get(field_name)
    if not fd:
        raise ValueError(f"Unknown field: {field_name}")
    # Convert the fd type into python
    if fd.type == fd.TYPE_MESSAGE:
        return cls.get_class_for_proto(fd.message_type)
    elif fd.type == fd.TYPE_ENUM:
        return cls.get_class_for_proto(fd.enum_type)
    elif fd.type == fd.TYPE_BOOL:
        return bool
    elif fd.type == fd.TYPE_BYTES:
        return bytes
    elif fd.type == fd.TYPE_STRING:
        return str
    elif fd.type in [
        fd.TYPE_FIXED32,
        fd.TYPE_FIXED64,
        fd.TYPE_INT32,
        fd.TYPE_INT64,
        fd.TYPE_SFIXED32,
        fd.TYPE_SFIXED64,
        fd.TYPE_SINT32,
        fd.TYPE_SINT64,
        fd.TYPE_UINT32,
        fd.TYPE_UINT64,
    ]:
        return int
    elif fd.type in [fd.TYPE_FLOAT, fd.TYPE_DOUBLE]:
        return float
    raise ValueError(f"Unknown proto type: {fd.type}")
@classmethod
def _is_valid_type_for_field(cls, field_name: str, val: Any) -> bool:
    """Check whether the given value is valid for the given field.

    This is used to infer which member of a oneof a python value belongs
    to. Any incompatibility is therefore reported as False rather than
    raised, including values of the right python type that the field
    cannot hold (e.g. an int outside the int32 range).

    Args:
        field_name (str): Name of a field in this class's proto descriptor
        val (Any): Candidate value

    Returns:
        bool: True if ``val`` can be stored in the field
    """
    # pylint: disable=too-many-return-statements
    field_descriptor = cls._proto_class.DESCRIPTOR.fields_by_name[field_name]
    if val is None:
        return False

    # If val is a list, this may be a "Union of List" oneof member backed by
    # a *Sequence message. Match the python element type name (e.g. "str")
    # against the message's own name (e.g. "StrSequence"). The unqualified
    # name is used so package segments (e.g. "interfaces" containing "int")
    # cannot produce false matches.
    if isinstance(val, list) and field_name in cls._fields_to_oneof:
        if len(val) == 0:
            log.info("Assuming the type is valid since list is empty")
            return True
        message_type = field_descriptor.message_type
        if not message_type:
            return False
        val_list_type = type(val[0]).__name__
        return val_list_type in message_type.name.lower()

    # If it's a data object or an enum and the descriptors match, it's a
    # good type
    if (
        isinstance(val, DataBase)
        and field_descriptor.message_type == val.get_proto_class().DESCRIPTOR
    ) or (
        isinstance(val, Enum)
        and field_descriptor.enum_type == val.get_proto_class().DESCRIPTOR  # type: ignore
    ):
        return True

    # Any other value for a message or enum field is a bad type
    if field_descriptor.type in [
        field_descriptor.TYPE_MESSAGE,
        field_descriptor.TYPE_ENUM,
    ]:
        return False

    # If the field is a bool field, only accept python bools. Proto is ok to
    # accept ints, but we are stricter than that.
    if field_descriptor.type == field_descriptor.TYPE_BOOL:
        return isinstance(val, bool)

    # Proto doesn't allow non utf-8 bytes fields; however, python does.
    if field_descriptor.type == field_descriptor.TYPE_BYTES:
        return isinstance(val, bytes)

    # Remaining primitives: defer to protobuf's type checkers. They raise
    # TypeError for a wrong python type and ValueError for a right-typed
    # value the field cannot hold (out-of-range ints, invalid UTF-8). Both
    # mean "not valid for this field", so neither may escape oneof inference.
    checker = proto_type_checkers.GetTypeChecker(field_descriptor)  # type: ignore
    try:
        checker.CheckValue(val)
    except (TypeError, ValueError):
        return False
    return True
@classmethod
def from_binary_buffer(cls, buf):
    """Deserialize a data model object from serialized protobufs bytes.

    Args:
        buf: The binary buffer containing a serialized protobufs message of
            this class's proto type

    Returns:
        A data model object built with ``from_proto`` from the message
        parsed out of ``buf``
    """
    message = cls.get_proto_class()()
    message.ParseFromString(buf)
    return cls.from_proto(message)
@classmethod
def from_proto(cls, proto):
    """Build an instance of this data model class from a protobufs message.

    Each field listed in ``cls.fields`` is converted according to its kind:
      - primitives / enums: copied. A oneof member is only taken when it is
        the member that is set.
      - repeated primitives / enums: copied into a list.
      - maps: copied. Message values are converted with ``from_proto``.
      - Struct / Timestamp messages: converted to dict / datetime.
      - "Union of List" oneof members (``*Sequence`` messages): stored under
        the oneof name, unwrapped to their ``values`` for caikit's built-in
        sequence classes.
      - other messages (singular or repeated): converted with the registered
        data model class's ``from_proto``.

    Args:
        proto: A protocol buffer to build from. Its descriptor name must
            match this class's proto class.

    Returns:
        An instance of this class.

    Raises:
        Failures are reported through ``error`` (which does not return): a
        type mismatch on ``proto``, a descriptor-name mismatch (ValueError),
        a missing or unhandled field (AttributeError).
    """
    error.type_check("<COR45207671E>", ProtoMessageType, proto=proto)
    if cls._proto_class.DESCRIPTOR.name != proto.DESCRIPTOR.name:
        error(
            "<COR71783894E>",
            ValueError(
                "class name `{}` does not match protobufs name `{}`".format(
                    cls._proto_class.DESCRIPTOR.name, proto.DESCRIPTOR.name
                )
            ),
        )

    kwargs = {}
    for field in cls.fields:
        try:
            proto_attr = getattr(proto, field)
        except AttributeError:
            error(
                "<COR71783905E>",
                AttributeError(
                    "protobufs `{}` does not have field `{}`".format(
                        proto.DESCRIPTOR.name, field
                    )
                ),
            )

        if field in cls._fields_primitive or field in cls._fields_enum:
            # special case for oneofs: only take the member that is set
            if field not in cls._fields_to_oneof or proto.HasField(field):
                kwargs[field] = proto_attr

        elif (
            field in cls._fields_primitive_repeated
            or field in cls._fields_enum_repeated
        ):
            kwargs[field] = list(proto_attr)

        elif field in cls._fields_map:
            kwargs[field] = {}
            for key, value in proto_attr.items():
                # Message values carry a DESCRIPTOR and must be converted to
                # their data model class; primitive values are kept as-is
                if hasattr(value, "DESCRIPTOR"):
                    contained_class = cls.get_class_for_proto(value)
                    kwargs[field][key] = contained_class.from_proto(value)
                else:
                    kwargs[field][key] = value

        elif field in cls._fields_message:
            if proto.HasField(field):
                message_name = proto_attr.DESCRIPTOR.full_name
                if message_name == json_dict.STRUCT_PROTOBUF_NAME:
                    kwargs[field] = json_dict.struct_to_dict(proto_attr)
                elif message_name == timestamp.TIMESTAMP_PROTOBUF_NAME:
                    kwargs[field] = timestamp.proto_to_datetime(proto_attr)
                elif field in cls._fields_to_oneof and message_name.endswith(
                    "Sequence"
                ):
                    # "Union of List" oneof member: populate the oneof itself.
                    # Guarding on oneof membership avoids a KeyError for
                    # ordinary (non-oneof) fields whose message type name
                    # happens to end in "Sequence"; those use the generic
                    # branch below.
                    oneof = cls._fields_to_oneof[field]
                    contained_class = cls.get_class_for_proto(proto_attr)
                    contained_obj = contained_class.from_proto(proto_attr)
                    if hasattr(contained_obj, "values") and (
                        contained_class.__module__.startswith(
                            (
                                "caikit.core.data_model",
                                "caikit.interfaces.common.data_model",
                            )
                        )
                    ):
                        kwargs[oneof] = contained_obj.values  # type: ignore
                    else:
                        kwargs[oneof] = contained_obj
                else:
                    contained_class = cls.get_class_for_proto(proto_attr)
                    contained_obj = contained_class.from_proto(proto_attr)
                    kwargs[field] = contained_obj

        elif field in cls._fields_message_repeated:
            elements = []
            contained_class = None
            for item in proto_attr:
                item_name = item.DESCRIPTOR.full_name
                if item_name == json_dict.STRUCT_PROTOBUF_NAME:
                    elements.append(json_dict.struct_to_dict(item))
                elif item_name == timestamp.TIMESTAMP_PROTOBUF_NAME:
                    elements.append(timestamp.proto_to_datetime(item))
                else:
                    # A repeated field has one message type, so the data
                    # model class is looked up once
                    if contained_class is None:
                        contained_class = cls.get_class_for_proto(item)
                    elements.append(contained_class.from_proto(item))
            kwargs[field] = elements

        else:
            error(
                "<COR71783815E>",
                AttributeError(
                    "field `{}` is not a protobufs primitive, message, map or "
                    "repeated".format(field)
                ),
            )

    return cls(**kwargs)
@classmethod
def from_json(cls, json_str, ignore_unknown_fields=False):
    """Build a DataBase from JSON text or an already-decoded dict.

    The JSON is parsed into this class's proto class with
    ``google.protobuf.json_format.Parse``. The parsed message is then
    converted with ``from_proto``.

    Args:
        json_str (str or dict): A stringified JSON specification/dict of the
            data_model
        ignore_unknown_fields (bool): If True, ignores unknown JSON fields

    Returns:
        caikit.core.data_model.DataBase: A DataBase object.
    """
    error.type_check("<COR91037250E>", str, dict, json_str=json_str)

    # json_format.Parse consumes text, so re-encode dict input first
    json_text = json.dumps(json_str) if isinstance(json_str, dict) else json_str

    try:
        parsed = json_format.Parse(
            json_text,
            cls.get_proto_class()(),
            ignore_unknown_fields=ignore_unknown_fields,
        )
        return cls.from_proto(parsed)
    except json_format.ParseError as ex:
        error("<COR90619980E>", ValueError(ex))
@classmethod
def from_file(cls, file_obj: IOBase):
    """Build a DataBase from a given file-like object.

    This base implementation does not support file input and always raises.

    Args:
        file_obj (IOBase): A file object that contains some representation
            of the data object

    Returns:
        caikit.core.data_model.DataBase: A DataBase object.

    Raises:
        NotImplementedError: Always, in this base implementation
    """
    message = "from_file not implemented for {}".format(cls)
    raise NotImplementedError(message)
def to_proto(self):
    """Create a new protobufs message populated from this data structure.

    Returns:
        A fresh instance of this class's proto class, filled by
        ``fill_proto``.
    """
    owner = self.__class__
    proto_class = owner.get_proto_class()
    if proto_class is None:
        error(
            "<COR71783827E>",
            AttributeError(f"protobufs not found for class `{owner}`"),
        )
    # fill_proto populates the message in place and returns it
    return self.fill_proto(proto_class())
def to_binary_buffer(self):
    """Serialize this data model object to the protobufs binary format.

    Returns:
        The result of ``SerializeToString()`` on the message built by
        ``to_proto()``.
    """
    message = self.to_proto()
    return message.SerializeToString()
def fill_proto(self, proto):
    """Populate a protobufs with the values from this data model object.

    Fields whose value is None are skipped. Enum values may be given either
    as ``Enum`` members (their ``.value`` is written) or as raw values. This
    holds for both singular and repeated enum fields.

    Args:
        proto: A protocol buffer to be populated.

    Returns:
        protobufs: The filled protobufs.

    Notes:
        The protobufs is filled in place, so the argument and the return
        value are the same at the end of this call.
    """
    for field in self.fields:
        try:
            attr = getattr(self, field)
        except AttributeError:
            error(
                "<COR71783840E>",
                AttributeError(
                    "class `{}` has no attribute `{}` but it is in the protobufs".format(
                        self.__class__.__name__, field
                    )
                ),
            )

        if attr is None:
            continue

        if field in self._fields_primitive:
            setattr(proto, field, attr)

        elif field in self._fields_enum:
            if isinstance(attr, Enum):
                setattr(proto, field, attr.value)
            else:
                setattr(proto, field, attr)

        elif field in self._fields_map:
            subproto = getattr(proto, field)
            for key, value in attr.items():
                # For message-valued maps, indexing the container yields a
                # message (with a DESCRIPTOR) that is filled recursively.
                # Scalar-valued maps are assigned like a normal dictionary.
                if hasattr(subproto[key], "DESCRIPTOR"):
                    value.fill_proto(subproto[key])
                else:
                    subproto[key] = value

        elif field in self._fields_primitive_repeated:
            subproto = getattr(proto, field)
            subproto.extend(attr)

        elif field in self._fields_enum_repeated:
            # Unwrap Enum members exactly like the singular enum branch; the
            # protobufs repeated container only accepts the raw values
            subproto = getattr(proto, field)
            subproto.extend(
                [item.value if isinstance(item, Enum) else item for item in attr]
            )

        elif field in self._fields_message:
            subproto = getattr(proto, field)
            message_name = subproto.DESCRIPTOR.full_name
            if message_name == json_dict.STRUCT_PROTOBUF_NAME:
                subproto.CopyFrom(
                    json_dict.dict_to_struct(attr, subproto.__class__)
                )
            elif message_name == timestamp.TIMESTAMP_PROTOBUF_NAME:
                timestamp_proto = timestamp.datetime_to_proto(attr)
                subproto.CopyFrom(timestamp_proto)
            # check that this is any of the Union of List types
            elif message_name.endswith("Sequence") and not issubclass(
                attr.__class__, DataBase
            ):
                # A plain python list is wrapped in the Sequence message. A
                # TypeError means the list does not fit this Sequence type;
                # it is logged and the field is left unset (best-effort).
                seq_dm = subproto.__class__
                try:
                    subproto.CopyFrom(seq_dm(values=attr))
                    log.debug4("Successfully fill proto for %s", field)  # type: ignore
                except TypeError:
                    log.debug4("not the correct union list type")  # type: ignore
            else:
                attr.fill_proto(subproto)

        elif field in self._fields_message_repeated:
            subproto = getattr(proto, field)
            for item in attr:
                elem_type = subproto.add()
                if isinstance(item, dict):
                    elem_type.CopyFrom(
                        json_dict.dict_to_struct(item, elem_type.__class__)
                    )
                elif isinstance(item, datetime.datetime):
                    elem_type.CopyFrom(timestamp.datetime_to_proto(item))
                else:
                    item.fill_proto(elem_type)

        else:
            error(
                "<COR71783852E>",
                AttributeError(
                    "field `{}` is not a protobufs primitive, message or repeated".format(
                        field
                    )
                ),
            )

    return proto
def to_dict(self) -> dict:
    """Convert to a (recursive) dictionary representation.

    Fields that belong to a oneof are included only if they are the member
    that is currently set. A oneof member whose converted value is a plain
    list is wrapped as ``{"values": [...]}``.
    """
    oneof_lookup = self._fields_to_oneof

    # First decide which fields to emit (oneof members only when active)
    selected_fields = [
        name
        for name in self.fields
        if name not in oneof_lookup or self.which_oneof(oneof_lookup[name]) == name
    ]

    result = {}
    for name in selected_fields:
        value = self._field_to_dict_element(name)
        if (
            name in oneof_lookup
            and isinstance(value, list)
            and not hasattr(value, "values")
        ):
            value = {"values": value}
        result[name] = value
    return result
def to_kwargs(self) -> dict:
    """Convert to a flat dictionary of attribute values.

    Unlike ``to_dict``, this is not recursive: values are returned as
    stored. Fields backed by a oneof are reported under the oneof's
    attribute name instead of the internal proto field name.
    """
    oneof_lookup = self._fields_to_oneof
    attribute_names = [oneof_lookup.get(name, name) for name in self.fields]
    return {attr_name: getattr(self, attr_name) for attr_name in attribute_names}
def to_json(self, **kwargs) -> str:
    """Convert to a JSON string built from ``to_dict()``.

    Keyword arguments are forwarded to ``json.dumps``. Unless the caller
    supplies its own ``default`` handler, bytes are written as base64 text
    and ``datetime.datetime`` values use the protobuf Timestamp JSON format.
    """

    def _fallback(value):
        """JSON ``default`` hook for values json cannot encode natively.

        Handles:
        - bytes
        - datetime.datetime
        """
        if isinstance(value, bytes):
            return base64.b64encode(value).decode("utf-8")
        if isinstance(value, datetime.datetime):
            # Reuse the Timestamp proto's JSON rendering
            return timestamp.datetime_to_proto(value).ToJsonString()
        raise TypeError(f"Object of type {type(value)} is not JSON serializable")

    kwargs.setdefault("default", _fallback)
    return json.dumps(self.to_dict(), **kwargs)
def to_file(
    self, file_obj: IOBase
) -> Optional["File"]:  # type: ignore # noqa: F821 # see TYPE_CHECKING note at the top
    """Export this data object into the file-like object ``file_obj``.

    An implementation with requirements on the file name or file type can
    return them via the optional ``File`` return object. This base
    implementation always raises.

    Args:
        file_obj (IOBase): A file object to be filled

    Returns:
        file_descriptor: Optional[caikit.interfaces.common.data_model.File]

    Raises:
        NotImplementedError: Always, in this base implementation
    """
    raise NotImplementedError(
        "to_file not implemented for {}".format(self.__class__)
    )
def __repr__(self):
    """Human-friendly representation: indented JSON with non-ASCII text
    left unescaped."""
    json_options = {"indent": 2, "ensure_ascii": False}
    return self.to_json(**json_options)
def _field_to_dict_element(self, field):
    """Convert a single field into a value that can be placed in a dict.

    - Unset (None) fields become None.
    - Enum fields, singular and repeated, are reverse-mapped (value -> name)
      through ``fields_enum_rev``. ``Enum`` members are unwrapped to
      ``.value`` first.
    - Primitive fields (singular and repeated) are returned unchanged.
    - Map, message and repeated-message fields are converted recursively,
      calling ``to_dict`` on nested data model objects.
    - Anything else falls back to ``str(value)``.

    Args:
        field (str): Name of the field to convert

    Returns:
        The dictionary-ready representation of the field's value
    """
    try:
        attr = getattr(self, field)
    except AttributeError:
        error(
            "<COR71783864E>",
            AttributeError(
                "class `{}` has no attribute `{}` but it is in the protobufs".format(
                    self.__class__.__name__, field
                )
            ),
        )

    # if field is None, assume it's unset and just return None
    if attr is None:
        return None

    def _enum_lookup_key(value):
        # Reverse maps are keyed by the raw enum value, so unwrap members
        return value.value if isinstance(value, Enum) else value

    if field in self._fields_enum:
        # if field is an enum, do the reverse lookup from int -> str
        enum_rev = self.fields_enum_rev.get(field)
        if enum_rev is not None:
            return enum_rev[_enum_lookup_key(attr)]

    if field in self._fields_enum_repeated:
        # Same reverse lookup per element; previously Enum members were used
        # as keys directly, unlike the singular branch above
        enum_rev = self.fields_enum_rev.get(field)
        if enum_rev is not None:
            return [enum_rev[_enum_lookup_key(item)] for item in attr]

    # if field is a primitive, just return it to be placed in dict
    if field in self._fields_primitive or field in self._fields_primitive_repeated:
        return attr

    def _recursive_to_dict(_attr):
        if isinstance(_attr, dict):
            return {key: _recursive_to_dict(value) for key, value in _attr.items()}
        if isinstance(_attr, list):
            return [_recursive_to_dict(listitem) for listitem in _attr]
        if isinstance(_attr, DataBase):
            return _attr.to_dict()
        return _attr

    # If field is an object in our data model/map/list, call to_dict
    # recursively on each element
    if (
        field in self._fields_map
        or field in self._fields_message
        or field in self._fields_message_repeated
    ):
        return _recursive_to_dict(attr)

    # fallback to the string representation
    return str(attr)
@staticmethod
def get_class_for_proto(
    proto: Union[Descriptor, FieldDescriptor, EnumDescriptor, ProtoMessageType]
) -> Type["DataBase"]:
    """Look up the data model class registered for the given protobuf.

    The lookup key is the protobuf's fully-qualified name:
      - ``proto.full_name`` for descriptor inputs
      - ``proto.DESCRIPTOR.full_name`` for message inputs

    NOTE(review): for a FieldDescriptor, ``full_name`` is the field's own
    qualified name rather than its message type's name. Confirm that callers
    intend that key.

    Args:
        proto (Union[Descriptor, FieldDescriptor, EnumDescriptor,
            ProtoMessageType]): The descriptor or message to look up

    Returns:
        dm_class (Type[DataBase]): The data model class corresponding to the
            given protobuf

    Raises:
        AttributeError (via ``error``): If no data model class is registered
            under that name
    """
    error.type_check(
        "<COR46446770E>",
        Descriptor,
        FieldDescriptor,
        EnumDescriptor,
        ProtoMessageType,
        proto=proto,
    )

    if isinstance(proto, (Descriptor, FieldDescriptor, EnumDescriptor)):
        proto_full_name = proto.full_name
    else:
        proto_full_name = proto.DESCRIPTOR.full_name

    dm_class = _DataBaseMetaClass.class_registry.get(proto_full_name)
    if dm_class is None:
        error(
            "<COR71783879E>",
            AttributeError(
                "no data model class found in registry for protobufs named `{}`".format(
                    proto_full_name
                )
            ),
        )
    return dm_class
@staticmethod
def get_class_for_name(class_name: str) -> Type["DataBase"]:
    """Look up a data model class by fully-qualified or unqualified name.

    An exact match on the fully-qualified protobuf name is tried first. If
    none is found, registry entries whose last dotted component equals
    ``class_name`` are considered. A ValueError is raised (via ``error``)
    if no class matches, or if several classes share the unqualified name.

    Args:
        class_name (str): The name of the class either as a fully-qualified
            protobuf name or as the unqualified class name

    Returns:
        dm_class (Type[DataBase]): The data model class corresponding to the
            given name
    """
    registry = _DataBaseMetaClass.class_registry

    exact_match = registry.get(class_name)
    if exact_match is not None:
        return exact_match

    # Fall back to comparing the unqualified (last dotted) name
    candidates = [
        (qualified_name, registered_class)
        for qualified_name, registered_class in registry.items()
        if qualified_name.rsplit(".", 1)[-1] == class_name
    ]

    if len(candidates) == 1:
        return candidates[0][1]
    if len(candidates) > 1:
        conflicting_names = [qualified_name for qualified_name, _ in candidates]
        error(
            "<COR02514290E>",
            ValueError(
                "Conflicting data model classes for [{}]: {}".format(
                    class_name, conflicting_names
                )
            ),
        )
    error(
        "<COR99562895E>",
        ValueError(f"No data model class match for {class_name}"),
    )