# Source code for caikit.core.model_management.model_trainer_base

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A Trainer is responsible for managing the execution of a training job for a
given module class.

Configuration for ModelTrainers lives under the config as follows:

model_management:
    trainers:
        <trainer name>:
            type: <trainer type name>
            config:
                <config option>: <value>
"""
# Standard
from typing import Optional, Type, Union
import abc

# Local
from ...interfaces.common.data_model.stream_sources import S3Path
from ..modules import ModuleBase
from .job_base import JobBase, JobFutureBase, JobInfo


class TrainingInfo(JobInfo):
    """Training-specific name for :class:`JobInfo`.

    This subclass adds no fields or behavior of its own.
    """
class ModelTrainerFutureBase(JobFutureBase):
    """Base class for the futures returned by ModelTrainerBase.train.

    On top of :class:`JobFutureBase`, it does three things:
    - accepts the training-flavored keyword aliases ``trainer_name`` and
      ``training_id``;
    - remembers an optional ``save_path``;
    - resolves :meth:`result` by loading the trained model.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the future, translating training-specific keywords.

        ``trainer_name`` is copied onto ``future_name``, and ``training_id``
        is copied onto ``future_id``. Each copy overwrites any value already
        passed under the generic name. The alias keys themselves are left in
        ``kwargs``.

        The ``save_path`` keyword is recorded, or ``None`` when absent. All
        positional and keyword arguments are then forwarded to
        :class:`JobFutureBase`.
        """
        # Map the trainer-facing keyword names onto the generic job names,
        # in the same order the checks have always been applied.
        for alias_key, job_key in (
            ("trainer_name", "future_name"),
            ("training_id", "future_id"),
        ):
            if alias_key in kwargs:
                kwargs[job_key] = kwargs[alias_key]

        # Record the save path (None when not supplied) so it can be exposed
        # via the ``save_path`` property; this happens before the parent
        # initializer runs.
        self._save_path = kwargs.get("save_path")

        super().__init__(*args, **kwargs)

    @property
    def save_path(self) -> Optional[str]:
        """The save path this future was created with, or ``None``.

        When a save path was provided, the future must expose it, including
        any injected background id.
        """
        return self._save_path

    @abc.abstractmethod
    def load(self) -> ModuleBase:
        """Load and return the trained model.

        Implementations must be callable with no additional arguments. This
        is the value that :meth:`result` hands back to callers.
        """

    def result(self) -> ModuleBase:
        """Return the trained model, obtained by delegating to :meth:`load`."""
        trained_model = self.load()
        return trained_model
class ModelTrainerBase(JobBase):
    # NOTE: this class intentionally has no docstring of its own. Because no
    # class docstring binds ``__doc__`` in the class namespace, the
    # right-hand side below resolves to the module-level docstring. The class
    # therefore reuses the module documentation, including the
    # ``model_management.trainers`` config layout. Adding a class docstring
    # here would silently change what this line assigns.
    __doc__ = __doc__

    # Future type named in the return annotations of train() and
    # get_model_future()
    ModelFutureBase = ModelTrainerFutureBase

    @abc.abstractmethod
    def train(
        self,
        module_class: Type[ModuleBase],
        *args,
        save_path: Optional[Union[str, S3Path]] = None,
        save_with_id: bool = False,
        model_name: Optional[str] = None,
        **kwargs,
    ) -> ModelFutureBase:
        """Begin training ``module_class``; return a future for the trained model.

        Args:
            module_class: The module class to train.
            *args: Extra positional training arguments. Their meaning is
                defined by the concrete trainer / module (TODO confirm per
                implementation).
            save_path: Optional string or S3Path, presumably the destination
                for the trained model. Confirm handling in concrete trainers.
            save_with_id: Presumably controls whether the training id is
                injected into ``save_path``. Verify against implementations.
            model_name: Optional name for the trained model.
            **kwargs: Extra keyword training arguments, forwarded as defined
                by the concrete trainer.

        Returns:
            A ModelFutureBase tracking the training job.
        """

    @abc.abstractmethod
    def get_model_future(self, training_id: str) -> ModelFutureBase:
        """Return the model future associated with ``training_id``."""

    ## Shared Utilities ##

    def get_future(self, job_id: str) -> JobFutureBase:
        """Generic job-future lookup; delegates to :meth:`get_model_future`."""
        model_future = self.get_model_future(job_id)
        return model_future

    @classmethod
    def get_trainer_name(cls, training_id: str) -> str:
        """Recover (un-hash) the trainer's instance name from a training id.

        Delegates to the inherited ``get_job_name`` classmethod.
        """
        trainer_name = cls.get_job_name(training_id)
        return trainer_name