# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module holds the Pydantic wrapping required by the REST server. It can
convert between Pydantic models and our DataObjects in both directions.
"""
# Standard
from datetime import date, datetime, time, timedelta
from typing import Any, Callable, Dict, List, Type, Union, get_args
import base64
import dataclasses
import enum
import inspect
import json
# Third Party
from fastapi import Request, status
from fastapi.datastructures import FormData
from fastapi.exceptions import HTTPException, RequestValidationError
from pydantic.fields import Field
from pydantic.functional_validators import BeforeValidator
from starlette.datastructures import UploadFile
from typing_extensions import Doc, get_type_hints
import numpy as np
import pydantic
# First Party
from py_to_proto.dataclass_to_proto import ( # Imported here for 3.8 compat
Annotated,
get_origin,
)
import alog
# Local
from caikit.core.data_model.base import DataBase
from caikit.interfaces.common.data_model import File
from caikit.interfaces.common.data_model.primitive_sequences import (
BoolSequence,
FloatSequence,
IntSequence,
StrSequence,
)
from caikit.runtime.http_server.utils import update_dict_at_dot_path
# Logging channel used by this module
log = alog.use_channel("SERVR-HTTP-PYDNTC")

# PYDANTIC_TO_DM_MAPPING is essentially a 2-way map of DMs <-> Pydantic models:
# you give it a pydantic model and it gives you back a DM class; you give it a
# DM class and you get back a pydantic model. It doubles as the cache for
# dataobject_to_pydantic, which registers both directions for every model it
# builds. The entries seeded here are one-directional: they map the primitive
# sequence DM classes straight to plain typing lists.
PYDANTIC_TO_DM_MAPPING = {
    # Map primitive sequences to lists
    StrSequence: List[str],
    IntSequence: List[int],
    FloatSequence: List[float],
    BoolSequence: List[bool],
}
[docs]
def pydantic_to_dataobject(pydantic_model: pydantic.BaseModel) -> DataBase:
    """Build the DM object that mirrors the given pydantic model instance

    The DM class to construct is looked up in PYDANTIC_TO_DM_MAPPING using the
    concrete type of ``pydantic_model``. Field values whose type is registered
    in the mapping are converted recursively. A list is converted element by
    element only when every element's type is registered; any other value
    (including a mixed list) is handed to the DM constructor untouched.

    Args:
        pydantic_model: Instance of a pydantic model that was registered in
            PYDANTIC_TO_DM_MAPPING

    Returns:
        An instance of the DM class registered for the model's type
    """

    def _is_registered(value: Any) -> bool:
        # The mapping is keyed by concrete classes, so an exact type lookup is
        # intended here rather than isinstance
        # pylint: disable=unidiomatic-typecheck
        return type(value) in PYDANTIC_TO_DM_MAPPING

    def _to_dm_value(value: Any) -> Any:
        # field could be a DM
        if _is_registered(value):
            return pydantic_to_dataobject(value)
        # ... or a list made up entirely of DMs
        if isinstance(value, list) and all(_is_registered(item) for item in value):
            return [pydantic_to_dataobject(item) for item in value]
        # Anything else is passed through as-is
        return value

    target_dm_class = PYDANTIC_TO_DM_MAPPING.get(type(pydantic_model))
    converted_kwargs = {
        name: _to_dm_value(value) for name, value in pydantic_model
    }
    return target_dm_class(**converted_kwargs)
[docs]
def dataobject_to_pydantic(dm_class: Type[DataBase]) -> Type[pydantic.BaseModel]:
    """Make a pydantic model that mirrors the given data model class

    The model is derived from the DM class's type annotations. Results are
    cached in PYDANTIC_TO_DM_MAPPING (in both directions), so repeated calls
    with the same class return the same pydantic model.

    Args:
        dm_class: The data model class to mirror

    Returns:
        The entry registered for ``dm_class`` in PYDANTIC_TO_DM_MAPPING if
        there is one, otherwise a newly created pydantic model class
    """
    # Anything already registered (or pre-seeded) is returned directly
    if dm_class in PYDANTIC_TO_DM_MAPPING:
        return PYDANTIC_TO_DM_MAPPING[dm_class]

    # Local namespace used to resolve type hints. This is needed for pydantic
    # to have a handle on JsonDict and JsonDictValue while creating its model
    hint_namespace = {"JsonDict": dict, "JsonDictValue": dict}

    # Lookups used while walking the fields
    hints_with_extras = get_type_hints(
        dm_class, localns=hint_namespace, include_extras=True
    )
    fields_by_name = {
        dc_field.name: dc_field for dc_field in dataclasses.fields(dm_class)
    }
    defaults_by_name = dm_class.get_field_defaults()

    # Map every field name to its (pydantic type, FieldInfo) pair
    model_fields = {}
    plain_hints = get_type_hints(dm_class, localns=hint_namespace)
    for name, hint in plain_hints.items():
        hint_with_extras = hints_with_extras.get(name)
        converted_type = _get_pydantic_type(hint)
        field_kwargs = {}

        # Carry the DM default over. Callable defaults become factories; the
        # default is bound as a lambda argument so each field keeps its own.
        default_value = defaults_by_name.get(name)
        if isinstance(default_value, Callable):
            field_kwargs["default_factory"] = (
                lambda factory=default_value: _conditionally_convert_dataobject(
                    factory()
                )
            )
        elif default_value is not None:
            field_kwargs["default"] = _conditionally_convert_dataobject(
                default_value
            )
        else:
            # With no default, fall back to None so the parameter isn't
            # required and caikit's own default logic applies. A
            # default_factory is used to retain type info in swagger.
            field_kwargs["default_factory"] = lambda: None

        # Fields holding a DataBase object get an explicit title
        # NOTE(review): the title comes from dm_class (the enclosing class),
        # not from the field's own type -- confirm this is intended
        if inspect.isclass(hint) and issubclass(hint, DataBase):
            field_kwargs["title"] = dm_class.get_proto_class().DESCRIPTOR.full_name

        # Any dataclass metadata on the field is forwarded to the pydantic
        # Field kwargs
        dc_field = fields_by_name.get(name)
        if dc_field:
            field_kwargs.update(dc_field.metadata)

        # A Doc annotation supplies the description (the last one wins)
        if get_origin(hint_with_extras) is Annotated:
            doc_notes = [
                note for note in get_args(hint_with_extras) if isinstance(note, Doc)
            ]
            if doc_notes:
                field_kwargs["description"] = doc_notes[-1].documentation

        model_fields[name] = (converted_type, Field(**field_kwargs))

    # Forbid extra attributes when instantiating the model; this makes sure
    # any oneofs can be correctly inferred by pydantic
    model_config = pydantic.ConfigDict(extra="forbid", protected_namespaces=())

    # Use create_model so all pydantic internals are set up correctly. The
    # pydantic class is explicitly named after the grpc buffer.
    new_model = pydantic.create_model(
        dm_class.get_proto_class().DESCRIPTOR.full_name,
        __config__=model_config,
        **model_fields,
    )

    # The doc message can only be attached once the class exists
    new_model.__doc__ = getattr(dm_class, "__doc__", "")

    # Register both directions: DM -> pydantic for the cache check above, and
    # pydantic -> DM for converting parsed requests back into DMs
    PYDANTIC_TO_DM_MAPPING[dm_class] = new_model
    PYDANTIC_TO_DM_MAPPING[new_model] = dm_class
    return new_model
# pylint: disable=too-many-return-statements
[docs]
def _safe_issubclass(candidate: Any, parent: type) -> bool:
    """issubclass() that answers False instead of raising on non-class input"""
    try:
        return issubclass(candidate, parent)
    except TypeError:
        # Raised when candidate is not a real class (typing constructs, and
        # builtin generic aliases such as list[int] on python 3.9/3.10)
        return False


def _get_pydantic_type(field_type: type) -> type:
    """Recursive helper to get a valid pydantic type for every field type

    Args:
        field_type: A type annotation taken from a data model class

    Returns:
        The equivalent type for use in a pydantic model: numpy scalar types
        collapse to int/float, bytes gains a base64 decoding validator, data
        model classes become their pydantic models, and Annotated / Union /
        list / dict annotations are rebuilt around their converted arguments.

    Raises:
        TypeError: If the annotation is none of the supported kinds
    """
    # pylint: disable=too-many-return-statements

    # Leaves: we should have primitive types and enums
    # NB: np.issubclass_ was removed in numpy 2.0, so the same "False on
    # TypeError" behavior is provided by _safe_issubclass instead
    if _safe_issubclass(field_type, np.integer):
        return int
    if _safe_issubclass(field_type, np.floating):
        return float
    if field_type is bytes:
        return Annotated[bytes, BeforeValidator(_from_base64)]
    if field_type in (
        int,
        float,
        bool,
        str,
        dict,
        type(None),
        date,
        datetime,
        time,
        timedelta,
    ):
        return field_type
    if isinstance(field_type, type) and _safe_issubclass(field_type, enum.Enum):
        return field_type

    # These can be nested within other data models
    if (
        isinstance(field_type, type)
        and _safe_issubclass(field_type, DataBase)
        and not _safe_issubclass(field_type, pydantic.BaseModel)
    ):
        # NB: for data models we're calling the data model conversion fn
        return dataobject_to_pydantic(field_type)

    # And then all of these types can be nested in other type annotations
    origin = get_origin(field_type)
    if origin is Annotated:
        return _get_pydantic_type(get_args(field_type)[0])
    if origin is Union:
        return Union[  # type: ignore
            tuple(_get_pydantic_type(arg_type) for arg_type in get_args(field_type))
        ]
    if origin is list:
        return List[_get_pydantic_type(get_args(field_type)[0])]
    if origin is dict:
        key_type, value_type = get_args(field_type)[:2]
        return Dict[_get_pydantic_type(key_type), _get_pydantic_type(value_type)]

    raise TypeError(f"Cannot get pydantic type for type [{field_type}]")
[docs]
def _conditionally_convert_dataobject(obj: Any) -> Any:
    """Convert data model classes and instances to their pydantic equivalents

    Args:
        obj: Any value, typically a field default taken from a data model

    Returns:
        The pydantic model class if ``obj`` is a DataBase subclass, a pydantic
        model instance if ``obj`` is a DataBase instance, and ``obj`` itself
        unchanged for anything else.
    """
    # Classes have to be handled before the instance check: a DataBase
    # subclass is not itself an instance of DataBase, so testing isinstance
    # first would return early and leave this branch unreachable
    if inspect.isclass(obj) and issubclass(obj, DataBase):
        return dataobject_to_pydantic(obj)

    if not isinstance(obj, DataBase):
        return obj

    # Instances are round-tripped through json into the matching pydantic model
    pydantic_class = dataobject_to_pydantic(obj.__class__)
    return pydantic_class.model_validate_json(obj.to_json())
[docs]
def _from_base64(data: Union[bytes, str]) -> bytes:
    """Decode base64 text into bytes, passing non-str input through untouched

    Args:
        data: Either a base64 encoded string or a value that is already bytes

    Returns:
        The decoded bytes for str input, otherwise ``data`` itself
    """
    # Only strings are treated as base64 payloads
    if not isinstance(data, str):
        return data
    encoded = data.encode("utf-8")
    return base64.b64decode(encoded)
[docs]
async def pydantic_from_request(
    pydantic_model: Type[pydantic.BaseModel], request: Request
):
    """Function to convert a fastapi request into a given pydantic model. This
    function parses the request's Content-Type and then correctly decodes the
    data. The currently supported Content-Types are `application/json`
    and `multipart/form-data`

    Args:
        pydantic_model: The pydantic model class to parse the request into
        request: The incoming fastapi request

    Returns:
        An instance of ``pydantic_model`` built from the request payload

    Raises:
        RequestValidationError: If a json body fails pydantic validation
        HTTPException: With status 415 if the Content-Type header is missing
            or is not one of the supported types
    """
    raw_content_type = request.headers.get("Content-Type")
    log.debug("Detected request using %s type", raw_content_type)

    # The header can be absent entirely, and media types are case-insensitive
    # and may carry parameters (e.g. "application/json; charset=utf-8"), so
    # normalize before comparing
    content_type = (raw_content_type or "").lower()
    media_type = content_type.split(";", 1)[0].strip()

    # If content type is json use pydantic to parse
    if media_type == "application/json":
        raw_content = await request.body()
        try:
            return pydantic_model.model_validate_json(raw_content)
        except pydantic.ValidationError as err:
            raise RequestValidationError(errors=err.errors()) from err

    # Elif content is form-data then parse the form
    if "multipart/form-data" in content_type:
        # Get the raw form data
        raw_form = await request.form()
        return _parse_form_data_to_pydantic(pydantic_model, raw_form)

    # Anything else (including a missing header) is unsupported
    raise HTTPException(
        status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
        f"Unsupported media type: {raw_content_type}.",
    )
[docs]
def _get_pydantic_subtypes(
    pydantic_model: Type[pydantic.BaseModel], keys: List[str]
) -> List[type]:
    """Recursive helper to get the type hint(s) for a (possibly nested) field

    Args:
        pydantic_model: The model (or nested type) to look the first key up on
        keys: Field names forming the path to the target field, outermost
            first. The list is not modified.

    Returns:
        The candidate types for the field at the end of the path: every member
        of a Union if the path ends on a Union, the list type itself if it
        ends on a list, otherwise the single resolved type. An empty list is
        returned if the path cannot be resolved.
    """
    if not keys:
        return [pydantic_model]

    # Split off the current key without mutating the caller's list. Each Union
    # alternative below must be tried against the same remaining path, which
    # popping from a shared list would break.
    current_key, remaining_keys = keys[0], keys[1:]

    # Get the type hints for the current key
    current_type = get_type_hints(pydantic_model).get(current_key)
    if not current_type:
        return []

    origin = get_origin(current_type)
    if origin is Union:
        union_args = get_args(current_type)
        # If we're trying to capture a union then return the entire union result
        if not remaining_keys:
            return list(union_args)

        # Get the first arg which can resolve the rest of the path
        for arg in union_args:
            if result := _get_pydantic_subtypes(arg, remaining_keys):
                return result

        # No member of the union matched
        return []

    # If object is a list then recurse on its element type
    if origin is list:
        if not remaining_keys:
            return [current_type]
        return _get_pydantic_subtypes(get_args(current_type)[0], remaining_keys)

    return _get_pydantic_subtypes(current_type, remaining_keys)