Source code for caikit.runtime.servicers.model_runtime_servicer

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Have pylint ignore Class XXXX has no YYYY member so that we can use gRPC enums.
# pylint: disable=E1101

# First Party
import alog

# Local
from caikit import get_config
from caikit.runtime.model_management.model_manager import ModelManager
from caikit.runtime.protobufs import model_runtime_pb2, model_runtime_pb2_grpc
from caikit.runtime.types.aborted_exception import AbortedException
from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException
from caikit.runtime.work_management.abortable_context import (
    AbortableContext,
    ThreadInterrupter,
)
from caikit.runtime.work_management.rpc_aborter import RpcAborter

# alog channel used for all log messages emitted by this module.
log = alog.use_channel("MR-SERVICR-I")


class ModelRuntimeServicerImpl(model_runtime_pb2_grpc.ModelRuntimeServicer):
    """This class contains the implementation of all of the RPCs that are
    required to run a service in Model Mesh as a Model-Runtime."""

    def __init__(self, interrupter: ThreadInterrupter = None):
        """
        Args:
            interrupter (ThreadInterrupter, optional): Passed to the
                AbortableContext that wraps model loading. An RpcAborter for
                the incoming request is only created when this is provided.
        """
        self.model_manager = ModelManager.get_instance()
        self.interrupter = interrupter

    def loadModel(self, request, context):
        """Model loading.

        Args:
            request (model_runtime_pb2.LoadModelRequest): gRPC request, gen
                from model-runtime.proto
            context (grpc.ServicerContext): Context object (contains request
                metadata, etc)

        Returns:
            model_runtime_pb2.LoadModelResponse: Gen from model-runtime.proto.
                Carries the loaded model's size and the max concurrency
                configured for the request's model type.

        Raises:
            AbortedException: If the RPC is aborted while loading. The model
                is unloaded before the exception is re-raised.
            CaikitRuntimeException: If the model manager fails to load the
                model or compute its size.
        """
        try:
            log.info(
                {
                    "log_code": "<RUN10000106I>",
                    "message": f"Loading model '{request.modelId}'",
                    "model_id": request.modelId,
                }
            )
            # Only wire up RPC-abort detection when an interrupter exists to
            # act on it.
            aborter = RpcAborter(context) if self.interrupter else None
            with AbortableContext(aborter=aborter, interrupter=self.interrupter):
                loaded_model = self.model_manager.load_model(
                    request.modelId, request.modelPath, request.modelType
                )
            model_size = loaded_model.size()
            log.info(
                {
                    "log_code": "<RUN10000107I>",
                    "message": f"Model '{request.modelId}' loaded! Model size [{model_size}]",
                    "model_id": request.modelId,
                }
            )
        except AbortedException as e:
            log.warning(
                {
                    "log_code": "<RUN82590507W>",
                    "message": f"Model '{request.modelId}' was not loaded due to the rpc aborting",
                    "model_id": request.modelId,
                    "error_id": e.id,
                }
            )
            # Unload the model in case it had actually finished loading
            self.model_manager.unload_model(request.modelId)
            raise
        except CaikitRuntimeException as e:
            log.warning(
                {
                    "log_code": "<RUN84720101W>",
                    "message": (
                        f"Model '{request.modelId}' could not be loaded! "
                        f"Reason: [{e.message}]"
                    ),
                    "model_id": request.modelId,
                    "error_id": e.id,
                }
            )
            raise

        return model_runtime_pb2.LoadModelResponse(
            sizeInBytes=model_size,
            maxConcurrency=self._get_max_concurrency(request.modelType),
        )

    @staticmethod
    def _get_max_concurrency(model_type):
        """Return the max concurrency for model_type.

        Uses the per-type override from the model_mesh config when one exists
        for model_type, otherwise the global max_model_concurrency.
        """
        model_mesh_config = get_config().inference_plugin.model_mesh
        per_type = model_mesh_config.max_model_concurrency_per_type
        if model_type in per_type:
            return per_type[model_type]
        return model_mesh_config.max_model_concurrency

    def unloadModel(self, request, context):
        """Model unloading.

        Args:
            request (model_runtime_pb2.UnloadModelRequest): gRPC request, gen
                from model-runtime.proto
            context (grpc.ServicerContext): Context object (contains request
                metadata, etc)

        Returns:
            model_runtime_pb2.UnloadModelResponse: Gen from model-runtime.proto

        Raises:
            CaikitRuntimeException: If the model manager fails to unload the
                model.
        """
        try:
            log.info(
                {
                    "log_code": "<RUN10000110I>",
                    "message": f"Unloading model '{request.modelId}'",
                    "model_id": request.modelId,
                }
            )
            model_size = self.model_manager.unload_model(request.modelId)
            log.info(
                {
                    "log_code": "<RUN10000111I>",
                    "message": f"Unloaded model '{request.modelId}' (Reclaimed size: {model_size})",
                    "model_id": request.modelId,
                }
            )
        except CaikitRuntimeException as e:
            log.warning(
                {
                    "log_code": "<RUN18471838W>",
                    "message": (
                        f"Model '{request.modelId}' could not be unloaded! "
                        f"Reason: [{e.message}]"
                    ),
                    "model_id": request.modelId,
                    "error_id": e.id,
                }
            )
            raise

        return model_runtime_pb2.UnloadModelResponse()

    def predictModelSize(self, request, context):
        """Predict size of not-yet-loaded model.

        Args:
            request (model_runtime_pb2.PredictModelSizeRequest): gRPC request,
                gen from model-runtime.proto
            context (grpc.ServicerContext): Context object (contains request
                metadata, etc)

        Returns:
            model_runtime_pb2.PredictModelSizeResponse: Gen from
                model-runtime.proto

        Raises:
            CaikitRuntimeException: If the model manager fails to estimate the
                model's size.
        """
        try:
            log.info(
                {
                    "log_code": "<RUN10000120I>",
                    "message": f"Predicting size of model '{request.modelId}'",
                    "model_id": request.modelId,
                }
            )
            predicted_size = self.model_manager.estimate_model_size(
                request.modelId, request.modelPath, request.modelType
            )
            log.info(
                {
                    "log_code": "<RUN10000123I>",
                    "message": f"Predicted model '{request.modelId}' size: [{predicted_size}]",
                    "model_id": request.modelId,
                }
            )
        except CaikitRuntimeException as e:
            log.warning(
                {
                    "log_code": "<RUN14920102W>",
                    "message": (
                        f"Model '{request.modelId}' size could not be predicted! "
                        f"Reason: [{e.message}]"
                    ),
                    "model_id": request.modelId,
                    "error_id": e.id,
                }
            )
            raise

        return model_runtime_pb2.PredictModelSizeResponse(sizeInBytes=predicted_size)

    def modelSize(self, request, context):
        """Compute size (memory consumption) of currently-loaded model.

        Args:
            request (model_runtime_pb2.ModelSizeRequest): gRPC request, gen
                from model-runtime.proto
            context (grpc.ServicerContext): Context object (contains request
                metadata, etc)

        Returns:
            model_runtime_pb2.ModelSizeResponse: Gen from model-runtime.proto

        Raises:
            CaikitRuntimeException: If the model manager fails to compute the
                model's size.
        """
        try:
            log.info(
                {
                    "log_code": "<RUN10000121I>",
                    "message": f"Computing size of model '{request.modelId}'",
                    "model_id": request.modelId,
                }
            )
            model_size = self.model_manager.get_model_size(request.modelId)
            log.info(
                {
                    "log_code": "<RUN10000122I>",
                    "message": f"Computed model '{request.modelId}' size: [{model_size}]",
                    "model_id": request.modelId,
                }
            )
        except CaikitRuntimeException as e:
            log.warning(
                {
                    "log_code": "<RUN14440122W>",
                    "message": (
                        f"Failed to calculate model '{request.modelId}' size! "
                        f"Reason: [{e.message}]"
                    ),
                    "model_id": request.modelId,
                    "error_id": e.id,
                }
            )
            raise

        return model_runtime_pb2.ModelSizeResponse(sizeInBytes=model_size)

    def runtimeStatus(self, request, context):
        """Runtime status checking.

        Always reports READY, together with the capacity and loading limits
        read from the inference_plugin.model_mesh config.

        Args:
            request (model_runtime_pb2.RuntimeStatusRequest): gRPC request,
                gen from model-runtime.proto
            context (grpc.ServicerContext): Context object (contains request
                metadata, etc)

        Returns:
            model_runtime_pb2.RuntimeStatusResponse: Gen from
                model-runtime.proto
        """
        model_mesh_config = get_config().inference_plugin.model_mesh
        # NOTE(review): despite the wording, this is logged on every
        # runtimeStatus call, not only once at startup.
        log.info(
            "<RUN25209721I>",
            "Starting Model Runtime version: %s",
            model_mesh_config.runtime_version,
        )
        return model_runtime_pb2.RuntimeStatusResponse(
            status=model_runtime_pb2.RuntimeStatusResponse.READY,
            capacityInBytes=model_mesh_config.capacity,
            maxLoadingConcurrency=model_mesh_config.max_loading_concurrency,
            modelLoadingTimeoutMs=model_mesh_config.model_loading_timeout_ms,
            defaultModelSizeInBytes=model_mesh_config.default_model_size,
            runtimeVersion=model_mesh_config.runtime_version,
            numericRuntimeVersion=model_mesh_config.numeric_runtime_version,
            limitModelConcurrency=model_mesh_config.latency_based_autoscaling_enabled,
        )