Source code for caikit.core.data_model.json_dict

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module holds common utilities for managing arbitrary JSON serializable
dicts as protobuf Struct objects
"""
# Standard
from typing import Dict, List, Optional, Union

# Third Party
from google.protobuf import descriptor, message_factory, struct_pb2

# Type hints for JSON serializable dicts
JsonDictValue = Union[
    int,
    float,
    str,
    bool,
    type(None),
    List["JsonDictValue"],
    "JsonDict",
]
JsonDict = Dict[str, JsonDictValue]

STRUCT_PROTOBUF_NAME = "google.protobuf.Struct"


[docs] def dict_to_struct( dictionary: JsonDict, struct_class: Optional[type] = None, value_class: Optional[type] = None, list_value_class: Optional[type] = None, ) -> struct_pb2.Struct: """Convert a python dict to a protobuf Struct""" if struct_class is None: struct_class = struct_pb2.Struct value_class = struct_pb2.Value list_value_class = struct_pb2.ListValue else: if value_class is None: value_class = _get_message_class( struct_class.DESCRIPTOR.file.pool.FindMessageTypeByName( "google.protobuf.Value" ) ) if list_value_class is None: list_value_class = _get_message_class( struct_class.DESCRIPTOR.file.pool.FindMessageTypeByName( "google.protobuf.ListValue" ) ) return struct_class( fields={ key: _value_to_struct_value( value, struct_class=struct_class, value_class=value_class, list_value_class=list_value_class, ) for key, value in dictionary.items() } )
[docs] def struct_to_dict(struct: struct_pb2.Struct) -> JsonDict: """Convert a struct into the equivalent json dict""" return {key: _struct_value_to_py(val) for key, val in struct.fields.items()}
## Implementation Details ######################################################
[docs] def _value_to_struct_value(value, struct_class, value_class, list_value_class): """Recursive helper to convert python values to struct values""" if value is None: struct_value = value_class(null_value=struct_pb2.NullValue.NULL_VALUE) elif isinstance(value, dict): struct_value = value_class( struct_value=dict_to_struct( value, struct_class=struct_class, value_class=value_class, list_value_class=list_value_class, ) ) elif isinstance(value, list): struct_value = value_class( list_value=list_value_class( values=( _value_to_struct_value( item, struct_class=struct_class, value_class=value_class, list_value_class=list_value_class, ) for item in value ) ) ) elif isinstance(value, bool): struct_value = value_class(bool_value=value) elif isinstance(value, (int, float)): struct_value = value_class(number_value=value) elif isinstance(value, str): struct_value = value_class(string_value=value) else: raise ValueError(f"Unsupported value type: {type(value)}") return struct_value
[docs] def _struct_value_to_py(struct_value: struct_pb2.Value) -> JsonDictValue: """Recursive helper to convert struct values to python values""" which = struct_value.WhichOneof("kind") if which in [None, "null_value"]: return None if which == "number_value": val = struct_value.number_value if int(val) == val: return int(val) return val if which in ["string_value", "bool_value"]: return getattr(struct_value, which) if which == "struct_value": return struct_to_dict(struct_value.struct_value) if which == "list_value": return [_struct_value_to_py(item) for item in struct_value.list_value.values]
[docs] def _get_message_class( desc: descriptor.Descriptor, ) -> message_factory.message.Message: """Helper to get the concrete protobuf class from a descriptor. This supports compatibility between protobuf 3.X and 4.X """ if hasattr(message_factory, "GetMessageClass"): return message_factory.GetMessageClass(desc) # pragma: no cover return message_factory.MessageFactory().GetPrototype(desc) # pragma: no cover