# Source code for caikit.core.data_model.enums

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Enumeration data structures map from strings to integers and back.
"""

# Standard
from enum import Enum
from typing import Dict, Optional, Tuple, Type

# Third Party
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper
import google
import munch

# First Party
import alog

# Local
from ..exceptions import error_handler

# Module-level logging channel ("DATAM") used by the functions in this file.
log = alog.use_channel("DATAM")
# Error reporter obtained from error_handler for the channel above; import_enum
# calls it with a log code and an exception instance when validation fails.
error = error_handler.get(log)


@classmethod
def to_dict(cls) -> Dict[str, int]:
    """Map every member name of the enum class to its value.

    The mapping is built on the first call and stored on the class as
    ``__dict_repr__``; subsequent calls hand back that same dict object.
    """
    # Fast path: a previous call already cached the mapping on the class
    if hasattr(cls, "__dict_repr__"):
        return cls.__dict_repr__

    name_to_value = {}
    for member in cls:  # pylint: disable=not-an-iterable
        name_to_value[member.name] = member.value
    cls.__dict_repr__ = name_to_value
    return cls.__dict_repr__


@classmethod
def to_munch(cls) -> munch.Munch:
    """Return the enum's name -> value mapping wrapped in a ``munch.Munch``.

    The Munch is created from ``cls.to_dict()`` on the first call and cached on
    the class as ``__munch_repr__``; later calls return the cached object.
    """
    # Fast path: reuse the Munch cached by an earlier call
    if hasattr(cls, "__munch_repr__"):
        return cls.__munch_repr__

    munched = munch.Munch(cls.to_dict())
    cls.__munch_repr__ = munched
    return cls.__munch_repr__


# Names exported on a wildcard import. import_enum extends this list at runtime
# with the name of each imported enum and the name of its reverse mapping.
__all__ = ["import_enums", "import_enum"]


def import_enum(
    proto_enum: EnumTypeWrapper, enum_class: Optional[Type[Enum]] = None
) -> Tuple[str, str]:
    """Import a single enum into the global enum module by name

    The enum class is stored in this module's globals under the protobuf
    enum's descriptor name, and a reverse (value -> name) ``munch.Munch`` is
    stored under that name with a ``Rev`` suffix. Both names are added to this
    module's ``__all__``.

    Args:
        proto_enum (EnumTypeWrapper): The enum to import
        enum_class (Optional[Type[Enum]]): A pre-existing enum class that this
            proto enum binds to. When omitted, a new Enum class is created
            from the proto enum's (name, value) items.

    Returns:
        name: str
            The name of the enum global
        rev_name: str
            The name of the reversed enum global
    """
    if not isprotobufenum(proto_enum):
        error(
            "<COR71783964E>",
            AttributeError(f"`{proto_enum}` is not a valid protobuf enumeration"),
        )

    name = proto_enum.DESCRIPTOR.name
    log.debug2("Importing enum named %s", name)
    if enum_class is None:
        log.debug2("Creating Enum class for %s", name)
        # Use the public functional API rather than the private
        # Enum._create_ hook; this also attributes the new class to this
        # module, which is where it is registered below.
        enum_class = Enum(name, proto_enum.items())

    # Add extra utility functions
    enum_class.to_dict = to_dict
    enum_class.to_munch = to_munch

    # Register the enum and its value -> name reverse mapping as module globals
    rev_name = name + "Rev"
    globals()[name] = enum_class
    globals()[rev_name] = munch.Munch({v: k for k, v in proto_enum.items()})

    # Importing the same enum more than once must not duplicate __all__ entries
    for exported in (name, rev_name):
        if exported not in __all__:
            __all__.append(exported)
    return name, rev_name
def import_enums(current_globals):
    """Add all enums and their reverse enum mappings to a module's global symbol
    table. Note that we also update this module's __all__ (via import_enum). In
    general, __all__ controls the stuff that comes with a wild (*) import.

    Examples tend to make stuff like this easier to understand. Let's say the
    first name we hit is the Entity Mention Type. Then, after the first cycle
    through the loop below, you'll see something like:

        '__all__': ['import_enums', 'import_enum',
                    'EntityMentionType', 'EntityMentionTypeRev']
        'EntityMentionType': { "MENTT_UNSET": 0, "MENTT_NAM": 1, ... , "MENTT_NONE": 4}
        'EntityMentionTypeRev': { 0: "MENTT_UNSET", 1: "MENTT_NAM", ... , 4: "MENTT_NONE"}

    Calling this function syncs your enums (as importable from this file) with
    the data model.

    If current_globals has no "protobufs" entry, or that object has no
    "all_enum_names" attribute, nothing is imported.

    Args:
        current_globals (dict): global dictionary from your data model package
            __init__ file.
    """
    # Like the proto imports, we'd one day like to do this with introspection using something
    # like below, but can't because our wheel is compiled. If you can think of a cleaner way
    # to do this, open a PR!
    # caller = inspect.stack()[1]
    # caller_module = inspect.getmodule(caller[0])
    # current_globals = caller_module.__dict__

    # Add the str->int (EnumBase) and int->str (EnumRevBase) mapping for each enum
    # to the calling module's symbol table, then update __all__ to include the names
    # for the added objects.
    protobufs = current_globals.get("protobufs")
    all_enum_names = getattr(protobufs, "all_enum_names", [])
    for proto_name in all_enum_names:
        proto_enum = getattr(protobufs, proto_name)
        # import_enum registers the enum in this module's globals and returns
        # the names it used; mirror those entries into the caller's globals.
        name, rev_name = import_enum(proto_enum)
        current_globals[name] = globals()[name]
        current_globals[rev_name] = globals()[rev_name]
def isprotobufenum(obj):
    """Returns True if obj is a protobufs enum.

    Args:
        obj (Any): The object to check

    Returns:
        bool: True if obj is an instance of protobuf's EnumTypeWrapper
    """
    # Use the class imported at the top of the file rather than reaching it
    # through the `google` package attribute chain.
    return isinstance(obj, EnumTypeWrapper)