Source code for caikit.runtime.client.remote_config

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The RemoteModuleConfig is a ModuleConfig subclass used to describe a Module's
interface without referencing the source ModuleBase. 
"""
# Standard
from dataclasses import dataclass
from typing import List, Tuple, Type, Union, get_args, get_origin
import inspect

# First Party
import alog

# Local
from caikit.core.exceptions import error_handler
from caikit.core.modules.base import ModuleBase
from caikit.core.modules.config import ModuleConfig
from caikit.core.modules.meta import _ModuleBaseMeta
from caikit.core.registries import module_registry
from caikit.core.signature_parsing.module_signature import CaikitMethodSignature
from caikit.core.task import TaskBase
from caikit.interfaces.common.data_model.remote import ConnectionInfo
from caikit.runtime.names import (
    get_task_predict_request_name,
    get_task_predict_rpc_name,
    get_train_request_name,
    get_train_rpc_name,
)

log = alog.use_channel("REM_MODULE_CFG")
error = error_handler.get(log)


## RemoteRPC Descriptor ########################################################


@dataclass
class RemoteRPCDescriptor:
    """Helper dataclass to store information about a Remote RPC.

    One descriptor holds the method signature together with the names needed
    to address a single RPC: the request and response data model names and
    the RPC name itself. The two streaming flags default to False.
    """

    # Full signature for this RPC
    signature: CaikitMethodSignature

    # Names of the request and response data model objects for this RPC
    request_dm_name: str
    response_dm_name: str

    # The name of the RPC
    rpc_name: str

    # Only used for infer RPC types
    input_streaming: bool = False
    output_streaming: bool = False
### Remote Module Config
class RemoteModuleConfig(ModuleConfig):
    """Helper class to differentiate a local ModuleConfig from a
    RemoteModuleConfig.

    A RemoteModuleConfig is expected to contain the fields annotated below.
    """

    ## Connection Info
    # Remote information for how to access the server.
    connection: ConnectionInfo
    protocol: str
    # The name of the metadata field to use for model information
    # default is defined in runtime.names and is mm-model-id
    model_key: str

    ## Method Information
    # use list and tuples instead of a dictionary to avoid aconfig.Config error
    task_methods: List[Tuple[Type[TaskBase], List[RemoteRPCDescriptor]]]
    # Stays None when load_from_module finds no usable train signature
    train_method: RemoteRPCDescriptor

    ## Target Module Information
    # Model_path is repurposed in RemoteConfig to be the name of the
    # model running on the remote
    model_path: str
    # Module id and name are passed directly to the @module() decorator
    module_id: str
    module_name: str

    # Reset reserved_keys, so we can manually add model_path
    reserved_keys = []

    @classmethod
    def load_from_module(
        cls,
        module_reference: Union[str, Type[ModuleBase], ModuleBase],
        connection_info: ConnectionInfo,
        protocol: str,
        model_key: str,
        model_path: str,
    ) -> "RemoteModuleConfig":
        """Construct a new remote module configuration from an existing local
        Module

        Args:
            module_reference (Union[str, Type[ModuleBase], ModuleBase]):
                Either the id of a locally registered module, a module class,
                or a module instance (its class is used)
            connection_info (ConnectionInfo):
                The connection information of the remote to use
            protocol (str):
                The protocol to connect with
            model_key (str):
                The model key to use when sending GRPC requests. An example
                is mm-model-id
            model_path (str):
                Stored as the config's model_path, which for a remote config
                is the name of the model running on the remote

        Returns:
            model_config (RemoteModuleConfig): Instantiated config (of type
                ``cls``) describing the inference and train RPCs of the
                referenced module.

        Raises:
            KeyError: If module_reference is a string that is not present in
                the module registry.
        """

        def _type_name(type_hint) -> str:
            """Return a printable name for a return-type annotation."""
            # Plain classes carry __name__; annotations without one (some
            # typing aliases) fall back to the name of their origin type.
            if hasattr(type_hint, "__name__"):
                return type_hint.__name__
            return get_origin(type_hint).__name__

        # Validate model path arg
        error.type_check("<COR71170339E>", str, model_path=model_path)

        # Get local module reference
        error.type_check(
            "<COR71270339E>",
            str,
            ModuleBase,
            _ModuleBaseMeta,
            module_reference=module_reference,
        )
        if isinstance(module_reference, ModuleBase):
            local_module_class = module_reference.__class__
        elif inspect.isclass(module_reference) and issubclass(
            module_reference, ModuleBase
        ):
            local_module_class = module_reference
        else:
            # Anything else is treated as a module id; fetch the registry once
            registry = module_registry()
            if module_reference not in registry:
                raise KeyError(f"Unknown module reference {module_reference}")
            local_module_class = registry.get(module_reference)

        # Construct model config dict
        remote_config_dict = {
            # Connection info
            "connection": connection_info,
            "protocol": protocol,
            "model_key": model_key,
            # Method info
            "task_methods": [],
            "train_method": None,
            # Source module info
            "model_path": model_path,
            "module_id": f"{local_module_class.MODULE_ID}-remote",
            "module_name": f"{local_module_class.MODULE_NAME} Remote",
        }

        # Parse inference methods signatures
        for task_class in local_module_class.tasks:
            task_methods = []
            for (
                input_streaming,
                output_streaming,
                signature,
            ) in local_module_class.get_inference_signatures(task_class):
                # Don't get the actual DataBaseObject as the ServicePackage
                # might not have been generated
                request_class_name = get_task_predict_request_name(
                    task_class, input_streaming, output_streaming
                )
                task_request_name = get_task_predict_rpc_name(
                    task_class, input_streaming, output_streaming
                )

                task_return_type = _type_name(signature.return_type)

                # Get the underlying DataBaseObject for stream types
                if output_streaming and get_args(signature.return_type):
                    # Use [0] as there will only be one internal type for
                    # DataStreams
                    task_return_type = get_args(signature.return_type)[0].__name__

                # Generate the rpc name and task type
                task_methods.append(
                    RemoteRPCDescriptor(
                        signature=signature,
                        request_dm_name=request_class_name,
                        response_dm_name=task_return_type,
                        rpc_name=task_request_name,
                        input_streaming=input_streaming,
                        output_streaming=output_streaming,
                    )
                )

            remote_config_dict["task_methods"].append((task_class, task_methods))

        # parse train method signature if there is one
        train_signature = local_module_class.TRAIN_SIGNATURE
        if train_signature and (
            train_signature.return_type is not None
            and train_signature.parameters is not None
        ):
            remote_config_dict["train_method"] = RemoteRPCDescriptor(
                signature=train_signature,
                request_dm_name=get_train_request_name(local_module_class),
                # Resolve the name the same way as for inference methods
                response_dm_name=_type_name(train_signature.return_type),
                rpc_name=get_train_rpc_name(local_module_class),
            )

        # Build through cls so subclasses get an instance of their own type
        return cls(remote_config_dict)