# Source code for caikit.runtime.service_generation.create_service

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script auto-generates the `caikit-runtime.proto` RPC definitions for a
collection of caikit.core derived libraries
"""

# Standard
from typing import Dict, List, Type

# First Party
from aconfig import aconfig
import alog

# Local
from .rpcs import (
    CaikitRPCBase,
    ModuleClassTrainRPC,
    TaskPredictionCancelRPC,
    TaskPredictionJobRPC,
    TaskPredictionResultRPC,
    TaskPredictionStatusRPC,
    TaskPredictRPC,
)
from caikit.core import ModuleBase, TaskBase
from caikit.core.exceptions import error_handler
from caikit.core.signature_parsing.module_signature import CaikitMethodSignature

log = alog.use_channel("CREATE-RPCS")
error = error_handler.get(log)

## Globals #####################################################################

TRAIN_FUNCTION_NAME = "train"

## Utilities ###################################################################


def assert_compatible(modules: List[str], previous_modules: List[str]):
    """Check that no previously supported module has been dropped

    Args:
        modules: module IDs being considered for this service generation
        previous_modules: module IDs that the previous service version
            supported

    Raises:
        Whatever ``error.value_check`` raises when at least one ID in
        ``previous_modules`` is absent from ``modules``
    """
    # Anything supported before but missing now is a regression
    dropped = set(previous_modules) - set(modules)

    if dropped:
        log.error(
            "BREAKING CHANGE FOUND! These modules became unsupported. These models were "
            "on the supported list in previous version, but now are no longer supported."
        )
        # One log line per regressed module so each is easy to spot
        for module_id in dropped:
            log.error("Regressed module: %s", module_id)

    # Always invoked; the condition only holds when nothing regressed
    error.value_check(
        "<SVC68235724E>",
        not dropped,
        "BREAKING CHANGE! Found unsupported module(s) that were previously supported: {}",
        dropped,
    )
def create_inference_rpcs(
    modules: List[Type[ModuleBase]], caikit_config: aconfig.Config = None
) -> List[CaikitRPCBase]:
    """Build one inference RPC per task and streaming combination

    Args:
        modules: module classes whose tasks should be exposed
        caikit_config: optional config; when given, the ``included`` and
            ``excluded`` lists under
            ``runtime.service_generation.task_types`` filter the tasks

    Returns:
        The RPCs that could be generated, sorted by their ``name``. A
        combination whose RPC construction raises is logged as a warning
        and left out of the result.
    """
    # Walk the config one level at a time. Any falsy level short-circuits the
    # chain, in which case both filters fall back to an empty list.
    service_gen_cfg = caikit_config and caikit_config.runtime.service_generation
    task_types_cfg = service_gen_cfg and service_gen_cfg.task_types
    included_task_types = (task_types_cfg and task_types_cfg.included) or []
    excluded_task_types = (task_types_cfg and task_types_cfg.excluded) or []

    grouped = _group_modules_by_task(
        modules, included_task_types, excluded_task_types
    )

    # Create the RPC for each task
    generated = []
    for task, methods_by_streaming in grouped.items():
        with alog.ContextLog(log.debug, "Generating task RPC for %s", task):
            for (
                input_streaming,
                output_streaming,
            ), method_signatures in methods_by_streaming.items():
                # Best effort: one bad task must not block the others
                try:
                    task_rpc = TaskPredictRPC(
                        task, method_signatures, input_streaming, output_streaming
                    )
                    generated.append(task_rpc)
                    log.debug("Successfully generated task RPC for %s", task)
                except Exception as err:  # pylint: disable=broad-exception-caught
                    log.warning(
                        "Cannot generate task rpc for %s: %s",
                        task,
                        err,
                        exc_info=True,
                    )

    generated.sort(key=lambda rpc: rpc.name)
    return generated
def create_job_inference_rpcs(
    modules: List[Type[ModuleBase]], caikit_config: aconfig.Config = None
) -> List[CaikitRPCBase]:
    """Build the prediction-job RPCs for every unary task

    Args:
        modules: module classes whose tasks should be exposed
        caikit_config: optional config; when given, the ``included`` and
            ``excluded`` lists under
            ``runtime.service_generation.task_types`` filter the tasks

    Returns:
        The job, result, status and cancel RPCs for each task flavor that
        streams neither its input nor its output, sorted by ``name``
    """
    # Walk the config one level at a time. Any falsy level short-circuits the
    # chain, in which case both filters fall back to an empty list.
    service_gen_cfg = caikit_config and caikit_config.runtime.service_generation
    task_types_cfg = service_gen_cfg and service_gen_cfg.task_types
    included_task_types = (task_types_cfg and task_types_cfg.included) or []
    excluded_task_types = (task_types_cfg and task_types_cfg.excluded) or []

    grouped = _group_modules_by_task(
        modules, included_task_types, excluded_task_types
    )

    # Create the RPC for each task
    job_rpcs = []
    for task, methods_by_streaming in grouped.items():
        with alog.ContextLog(log.debug, "Generating task RPC for %s", task):
            for (
                input_streaming,
                output_streaming,
            ), method_signatures in methods_by_streaming.items():
                # Only unary tasks get the generic prediction job services
                if input_streaming or output_streaming:
                    continue
                job_rpcs.append(TaskPredictionJobRPC(task, method_signatures))
                job_rpcs.append(TaskPredictionResultRPC(task, method_signatures))
                job_rpcs.append(TaskPredictionStatusRPC(task, method_signatures))
                job_rpcs.append(TaskPredictionCancelRPC(task, method_signatures))

    job_rpcs.sort(key=lambda rpc: rpc.name)
    return job_rpcs
def create_training_rpcs(modules: List[Type[ModuleBase]]) -> List[CaikitRPCBase]:
    """Build a train RPC for every trainable module

    Args:
        modules: module classes to inspect

    Returns:
        The train RPCs that could be generated, sorted by ``name``. Modules
        with no tasks, modules whose train function still looks like the
        ``ModuleBase`` one, and modules whose RPC construction raises are
        all left out.
    """
    # Prefix of the string repr of the un-overridden base train function
    base_train_repr = f"<bound method ModuleBase.{TRAIN_FUNCTION_NAME}"

    train_rpcs = []
    for ck_module in modules:
        if not ck_module.tasks:
            log.debug("Skipping module %s with no tasks", ck_module)
            continue

        # If this train function has not been changed from the base, skip it
        # as a module that can't be trained
        #
        # HACK alert! There is no clean way to identify this condition yet,
        # so for now the string repr of the bound method is inspected
        train_fn = getattr(ck_module, TRAIN_FUNCTION_NAME)
        if str(train_fn).startswith(base_train_repr):
            log.debug(
                "Skipping train API for %s with no %s function",
                ck_module,
                TRAIN_FUNCTION_NAME,
            )
            continue

        signature = ck_module.TRAIN_SIGNATURE
        log.debug(
            "Function signature for %s::%s [%s -> %s]",
            ck_module,
            TRAIN_FUNCTION_NAME,
            signature.parameters,
            signature.return_type,
        )

        with alog.ContextLog(log.debug, "Generating train RPC for %s", ck_module):
            # Best effort: one bad module must not block the others
            try:
                train_rpc = ModuleClassTrainRPC(signature)
                train_rpcs.append(train_rpc)
                log.debug("Successfully generated train RPC for %s", ck_module)
            except Exception as err:  # pylint: disable=broad-exception-caught
                log.warning(
                    "Cannot generate train rpc for %s: %s",
                    ck_module,
                    err,
                    exc_info=True,
                )

    train_rpcs.sort(key=lambda rpc: rpc.name)
    return train_rpcs
def _group_modules_by_task(
    modules: List[Type[ModuleBase]],
    included_tasks: List[str],
    excluded_tasks: List[str],
) -> Dict[Type[TaskBase], Dict[tuple, List[CaikitMethodSignature]]]:
    """Group the inference signatures of all modules by task and streaming flavor

    Args:
        modules: module classes to inspect; they are processed in
            ``MODULE_ID`` order
        included_tasks: task class names to keep. When non-empty, any task
            whose ``__name__`` is not listed is skipped.
        excluded_tasks: task class names to drop. When non-empty, any task
            whose ``__name__`` is listed is skipped.

    Returns:
        A nested mapping ``{task_class: {(input_streaming, output_streaming):
        [signature, ...]}}``. A task only appears if at least one signature
        was returned for it by ``get_inference_signatures``.
    """
    task_groups = {}

    # Sort modules so the order of modules processed is deterministic
    for ck_module in sorted(modules, key=lambda mod: mod.MODULE_ID):
        for task_class in ck_module.tasks:
            # The filters hold task names, so compare against the class name
            task_name = task_class.__name__
            if (included_tasks and task_name not in included_tasks) or (
                excluded_tasks and task_name in excluded_tasks
            ):
                continue

            for (
                input_streaming,
                output_streaming,
                signature,
            ) in ck_module.get_inference_signatures(task_class):
                # Entries are created lazily so tasks without signatures
                # never show up in the result
                task_groups.setdefault(task_class, {}).setdefault(
                    (input_streaming, output_streaming), []
                ).append(signature)

    return task_groups