Source code for caikit.core.modules.decorator

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module holds the @module decorator which is used to decorate a
caikit.module
"""

# Standard
from typing import Dict, List, Optional, Type, Union
import collections

# Third Party
import semver

# First Party
import alog

# Local
from ..data_model import ProducerId
from ..exceptions import error_handler
from ..registries import module_backend_registry, module_backend_types, module_registry
from ..signature_parsing import CaikitMethodSignature
from ..task import TaskBase
from .base import ModuleBase
import caikit.core

log = alog.use_channel("MODULE_DEC")
error = error_handler.get(log)


# This is used by individual caikit.module implementations,
# to define what backend type models they can support at load time.
SUPPORTED_LOAD_BACKENDS_VAR_NAME = "SUPPORTED_LOAD_BACKENDS"


[docs] def module( id=None, # pylint: disable=redefined-builtin name=None, version=None, task: Type[TaskBase] = None, tasks: Optional[List[Type[TaskBase]]] = None, backend_type="LOCAL", base_module: Union[str, Type[ModuleBase]] = None, backend_config_override: Optional[Dict] = None, ): """Apply this decorator to any class that should be treated as a caikit module (i.e., extends`{caikit.core.ModuleBase}) and registered with caikit.core so that the library "knows" the class is a caikit module and is capable of loading instances of the module. Args: id: str A UUID to use when registering this module with caikit.core Not required if based on another caikit module using `base_module` name: str A human-readable name for the module Not required if based on another caikit module using `base_module` version: str A SemVer for the module Not required if based on another caikit module using `base_module` task: Type[TaskBase] An ML task class that this module is an implementation for Not required if based on another caikit module using `base_module`, or if multiple tasks are specified using `tasks`. tasks: Optional[List[Type[TaskBase]] List of ML task classes that this module implements. backend_type: backend_type Associated backend type for the module. Default: `LOCAL` base_module: str | ModuleBase If this module is based on a different caikit module, provide name of the base module. Default: None backend_config_override: Dict Dictionary containing configuration required for the specific backend. Default: None Returns: A decorated version of the class to which it was applied, after registering the class as a valid module with caikit.core """ base_module_class = None # Flag to store if the current module is a backend implementation # of an existing module or not backend_module_impl = False # No mutable default backend_config_override = backend_config_override or {} if task and tasks: error( "<COR34125316E>", ValueError("Specify either task or tasks parameter, not both."), ) if tasks: error.type_check( "<COR34125317E>", list, allow_none=True, tasks=tasks, ) if any([id is None, version is None or name is None]): error.type_check( "<COR87944440E>", str, type(ModuleBase), allow_none=False, base_module=base_module, ) error.type_check( "<COR60584425E>", dict, allow_none=True, backend_config_override=backend_config_override, ) # If the base_module is a string, assume that it is the module_id of the # base module if isinstance(base_module, str): module_id = base_module error.value_check( "<COR09479833E>", module_id in module_registry(), "Unknown base module id: {}", module_id, ) base_module_class = module_registry()[module_id] # If base_module is a type, validate that it derives from ModuleBase and # use its MODULE_ID elif isinstance(base_module, type): if not issubclass(base_module, ModuleBase): error( "<COR20161747E>", f"base_module [{base_module}] does not derive from ModuleBase", ) base_module_class = base_module # TODO: Add support for inheritance of backend implementation # i.e if a module inherits from base_module id = base_module_class.MODULE_ID version = base_module_class.MODULE_VERSION name = base_module_class.MODULE_NAME tasks = base_module_class._TASK_CLASSES backend_module_impl = True if task is not None: tasks = [task] if tasks is None: tasks = [] error.type_check("<COR54118928E>", str, id=id, name=name, version=version) for t in tasks: error.subclass_check("<COR90789722E>", t, TaskBase, allow_none=True) semver.VersionInfo.parse(version) # Make sure this is a valid SemVer def decorator(cls_): # Verify this is a valid module type (inherits from the wrapped base class) if not issubclass(cls_, caikit.core.ModuleBase): error( "<COR32401861E>", TypeError( f"`{cls_.__name__}` does not extend ModuleBase", ), ) # Add attributes to the implementation class cls_.MODULE_ID = id # Module ID == Module Type ID cls_.MODULE_NAME = name # Module Name == Module Type Name cls_.MODULE_VERSION = version # Module Version == Module Type Version classname = f"{cls_.__module__}.{cls_.__qualname__}" cls_.MODULE_CLASS = classname cls_.PRODUCER_ID = ProducerId(cls_.MODULE_NAME, cls_.MODULE_VERSION) # Parse the `train` and `run` signatures cls_.RUN_SIGNATURE = CaikitMethodSignature(cls_, "run") cls_.TRAIN_SIGNATURE = CaikitMethodSignature(cls_, "train") cls_._TASK_INFERENCE_SIGNATURES = {} # If the module has tasks, validate them: task_classes = tasks for t in task_classes: if not t.has_inference_method_decorators(module_class=cls_): # Hackity hack hack - make sure at least one flavor is supported validated = False validation_errs = [] for input_streaming, output_streaming in [ [False, False], [True, True], [False, True], ]: try: t.validate_run_signature( cls_.RUN_SIGNATURE, input_streaming, output_streaming ) validated = True cls_._TASK_INFERENCE_SIGNATURES.setdefault(t, []).append( (input_streaming, output_streaming, cls_.RUN_SIGNATURE) ) break except (ValueError, TypeError) as e: validation_errs.append(e) if not validated: raise validation_errs[0] t.deferred_method_decoration(cls_) # Check to see if a super-class has any tasks. # These will have been validated by the superclass decorator already. tasks_in_hierarchy = [] for class_ in cls_.mro(): if hasattr(class_, "_TASK_CLASSES") and class_ is not cls_: tasks_in_hierarchy.extend(class_._TASK_CLASSES) if tasks_in_hierarchy: task_classes += tasks_in_hierarchy # Make sure the tasks are unique. Note that the order here is important # so that iterating the list of tasks is deterministic, unique, and the # tasks given in the class' module list are shown before tasks inherited # from parent classes. cls_._TASK_CLASSES = [] for task in task_classes: if task not in cls_._TASK_CLASSES: cls_._TASK_CLASSES.append(task) # If no backend support described in the class, add current backend # as the only backend that can load models trained by this module cls_.SUPPORTED_LOAD_BACKENDS = getattr( cls_, SUPPORTED_LOAD_BACKENDS_VAR_NAME, [backend_type] ) # Set its own backend_type as an attribute cls_.BACKEND_TYPE = backend_type # Verify UUID and add this module to the module registry if not backend_module_impl: if cls_.MODULE_ID in module_registry(): error( "<COR30607646E>", RuntimeError( "MODULE_ID `{}` conflicts for classes `{}` and `{}`".format( cls_.MODULE_ID, cls_.__name__, module_registry()[cls_.MODULE_ID].__name__, ) ), ) module_registry()[cls_.MODULE_ID] = cls_ # Register backend _register_module_implementation( cls_, backend_type, cls_.MODULE_ID, backend_config_override ) return cls_ return decorator
## Implementation Details ######################################################
[docs] def _register_module_implementation( implementation_class: type, backend_type: str, module_id: str, backend_config_override: Dict = None, ): """This function will register the mapping for the given module_id and backend_type to the implementation class Args: implementation_class (type): The class that is used to implement this backend type for the given module_id backend_type (str): Value from MODULE_BACKEND_TYPES that indicates the backend that this class implements module_id (str): The module_id from the caikit.core module registry that this class overloads backend_config_override (Dict): Dictionary containing essential overrides for the backend config. This will get stored with the implementation_class class name and will automatically get picked up and merged with other such configs for a specific backend """ backend_config_override = backend_config_override or {} log.debug( "Registering backend [%s] implementation for module [%s]", backend_type, module_id, ) error.value_check( "<COR86780140E>", backend_type in module_backend_types(), "Cannot override implementation of {} for unknown backend type {}", module_id, backend_type, ) core_class = module_registry().get(module_id) if core_class is None: # TODO! Inject a dummy entry that will raise on usage # Info level log: this is normal behavior if backend module is imported # before the base module log.info( "<COR86369940I>", "No core class found for new backend %s with module ID %s", implementation_class.__name__, module_id, ) # Do the registration! module_type_mapping = module_backend_registry().setdefault(module_id, {}) # Make sure this is not an overwrite of an existing registration existing_type = module_type_mapping.get(backend_type) assert ( existing_type is None or existing_type is implementation_class ), f"Registration conflict! ({module_id}, {backend_type}) already registered as {existing_type}" BackendConfig = collections.namedtuple( "BackendConfig", "impl_class backend_config_override" ) module_type_mapping[backend_type] = BackendConfig( impl_class=implementation_class, backend_config_override=backend_config_override )