# Source code for caikit.runtime.servicers.training_management_servicer

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Have pylint ignore Class XXXX has no YYYY member so that we can use gRPC enums.
# pylint: disable=E1101

# Third Party
import grpc

# First Party
import alog

# Local
from caikit.core import MODEL_MANAGER
from caikit.core.data_model import TrainingStatus
from caikit.core.exceptions.caikit_core_exception import (
    CaikitCoreException,
    CaikitCoreStatusCode,
)
from caikit.interfaces.runtime.data_model import TrainingStatusResponse
from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException
from caikit.runtime.utils.servicer_util import raise_caikit_runtime_exception

log = alog.use_channel("TM-SERVICR-I")


class TrainingManagementServicerImpl:
    """Servicer implementing the training management RPCs.

    Exposes status lookup and cancellation for trainings, both as gRPC
    handlers (``GetTrainingStatus`` / ``CancelTraining``) and as
    interface-agnostic methods that return ``TrainingStatusResponse`` objects.
    """

    #######################
    ## gRPC Service Impl ##
    #######################

    def GetTrainingStatus(self, request, context):  # pylint: disable=unused-argument
        """Get the status of a training by ID.

        Args:
            request: Request message; only ``request.training_id`` is read.
            context: gRPC context (unused).

        Returns:
            The protobuf form of the ``TrainingStatusResponse`` produced by
            ``get_training_status``.
        """
        return self.get_training_status(request.training_id).to_proto()

    def CancelTraining(self, request, context):  # pylint: disable=unused-argument
        """Cancel a training future.

        Args:
            request: Request message; only ``request.training_id`` is read.
            context: gRPC context (unused).

        Returns:
            The protobuf form of the ``TrainingStatusResponse`` produced by
            ``cancel_training``.
        """
        return self.cancel_training(request.training_id).to_proto()

    ####################################
    ## Interface-agnostic entrypoints ##
    ####################################

    def get_training_status(self, training_id: str) -> TrainingStatusResponse:
        """Get the status of a training by ID.

        Args:
            training_id: Identifier of the training to look up.

        Returns:
            A ``TrainingStatusResponse`` carrying the training's state, any
            error reasons (as strings), and its submission/completion
            timestamps.

        Raises:
            CaikitRuntimeException: With ``grpc.StatusCode.INTERNAL`` when an
                unexpected error occurs. ``CaikitCoreException`` errors are
                handed to ``raise_caikit_runtime_exception`` for translation.
        """
        # The lookup is deliberately outside the try block: it performs its
        # own error translation.
        model_future = self._get_model_future(training_id, operation="get_status")
        try:
            training_info = model_future.get_info()
            return TrainingStatusResponse(
                training_id=training_id,
                state=training_info.status,
                reasons=self._error_reasons(training_info),
                submission_timestamp=training_info.submission_time,
                completion_timestamp=training_info.completion_time,
            )
        except CaikitCoreException as err:
            raise_caikit_runtime_exception(exception=err)
        except Exception as err:
            # Log the underlying error like the sibling methods do, so the
            # cause is not lost behind the generic INTERNAL message.
            log.debug2(
                "Unexpected error trying to get status for training id %s: [%s]",
                training_id,
                err,
            )
            raise CaikitRuntimeException(
                grpc.StatusCode.INTERNAL,
                "Failed to get status for training id {}".format(
                    training_id,
                ),
            ) from err

    def cancel_training(self, training_id: str) -> TrainingStatusResponse:
        """Cancel a training future.

        Args:
            training_id: Identifier of the training to cancel.

        Returns:
            A ``TrainingStatusResponse`` with the state reported after the
            cancel call, or with ``TrainingStatus.CANCELED`` when the
            training can no longer be found.

        Raises:
            CaikitRuntimeException: With ``grpc.StatusCode.INTERNAL`` when an
                unexpected error occurs. ``CaikitCoreException`` errors other
                than ``NOT_FOUND`` are handed to
                ``raise_caikit_runtime_exception`` for translation.
        """
        model_future = self._get_model_future(training_id, operation="cancel")
        try:
            model_future.cancel()
            # Read the info only after cancelling so the response reflects
            # the post-cancel state.
            training_info = model_future.get_info()
            return TrainingStatusResponse(
                training_id=model_future.id,
                state=training_info.status,
                reasons=self._error_reasons(training_info),
            )
        except CaikitCoreException as err:
            # In the case that we get a `NOT_FOUND`, we assume that the training was canceled.
            # This is to handle stateful trainers that implement `cancel` by fully deleting
            # the training. NB: Future `GetTrainingStatus` calls for these canceled trainings
            # would raise a not found error to the user.
            if err.status_code == CaikitCoreStatusCode.NOT_FOUND:
                return TrainingStatusResponse(
                    training_id=training_id,
                    state=TrainingStatus.CANCELED,
                )
            raise_caikit_runtime_exception(exception=err)
        except Exception as err:
            log.debug2(
                "Unexpected error trying to cancel training id %s: [%s]",
                training_id,
                err,
            )
            raise CaikitRuntimeException(
                grpc.StatusCode.INTERNAL,
                "Failed to cancel training id {}".format(
                    training_id,
                ),
            ) from err

    ############################
    ## Implementation Details ##
    ############################

    @staticmethod
    def _error_reasons(training_info) -> list:
        """Render the errors recorded on a training info object as strings.

        Returns an empty list when ``training_info.errors`` is empty or unset.
        """
        if not training_info.errors:
            return []
        return [str(error) for error in training_info.errors]

    @staticmethod
    def _get_model_future(training_id: str, operation: str):
        """Look up the model future for a training ID.

        Wrapped here so that we only catch errors directly in the
        `MODEL_MANAGER.get_model_future` call.

        Args:
            training_id: Identifier of the training to look up.
            operation: Name of the operation being attempted; only used in
                the error message.

        Returns:
            The model future returned by ``MODEL_MANAGER.get_model_future``.

        Raises:
            CaikitRuntimeException: With ``grpc.StatusCode.INTERNAL`` when an
                unexpected error occurs. ``CaikitCoreException`` errors are
                handed to ``raise_caikit_runtime_exception`` for translation
                (presumably to the matching runtime status, e.g. not found --
                confirm against that helper).
        """
        try:
            return MODEL_MANAGER.get_model_future(training_id=training_id)
        except CaikitCoreException as err:
            raise_caikit_runtime_exception(exception=err)
        except Exception as err:
            log.debug2(
                "Caught unexpected exception while trying to look up model future for id %s: [%s]",
                training_id,
                err,
            )
            raise CaikitRuntimeException(
                grpc.StatusCode.INTERNAL,
                "Unexpected error with training id {}. Could not perform {}".format(
                    training_id, operation
                ),
            ) from err