Source code for caikit.core.model_management.local_model_trainer

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The LocalModelTrainer uses a local thread to launch and manage each training job

model_management:
    trainers:
        <predictor name>:
            type: LOCAL
            config:
                # ! Inherits config from LocalJobBase
                # Whether to train in a subprocess or different thread
                use_subprocess: <bool: False>
                # Type of subprocess to use if use_subprocess=True
                subprocess_start_method: <string: spawn>
                
"""

# Standard
from datetime import datetime
from typing import Optional, Type
import os
import uuid

# First Party
import aconfig
import alog

# Local
from ..exceptions import error_handler
from ..exceptions.caikit_core_exception import CaikitCoreException, CaikitCoreStatusCode
from ..modules import ModuleBase
from ..toolkit.logging import configure as configure_logging
from .local_job_base import LocalJobBase
from .model_trainer_base import ModelTrainerBase, ModelTrainerFutureBase
import caikit

log = alog.use_channel("LOC-TRNR")
error = error_handler.get(log)


[docs] class LocalModelTrainer(LocalJobBase, ModelTrainerBase): __doc__ = __doc__
[docs] class LocalModelTrainFuture(LocalJobBase.LocalJobFuture, ModelTrainerFutureBase): """LocalModelTrainFuture takes a module class and training request and runs it in a thread or subprocess"""
[docs] def run(self, *args, **kwargs): """Function that will run in the worker thread""" # If running in a spawned subprocess, reconfigure logging if self._use_subprocess and self._subprocess_start_method != "fork": configure_logging() with alog.ContextTimer(log.debug, "Training %s finished in: ", self.id): trained_model = self._module_class.train(*args, **kwargs) if self.save_path is not None: log.debug("Saving training %s to %s", self.id, self.save_path) with alog.ContextTimer(log.debug, "Training %s saved in: ", self.id): trained_model.save(self.save_path) self._completion_time = self._completion_time or datetime.now() log.debug2("Completion time for %s: %s", self.id, self._completion_time) return trained_model
[docs] def load(self) -> ModuleBase: """Wait for the training to complete, then return the resulting model or raise any errors that happened during training. """ self.wait() if self._use_subprocess: log.debug2("Loading model saved in subprocess") error.value_check( "<COR16745216E>", self.save_path, "Unable to load model from training {} " + "trained in subprocess without a save_path", self.id, ) error.value_check( "<COR59551640E>", os.path.exists(self.save_path), "Unable to load model from training {} " + "saved in subprocess, path does not exist: {}", self.id, self.save_path, ) result = caikit.load(self.save_path) else: result = self._worker.get_or_throw() return result
LocalModelFuture = LocalModelTrainFuture name = "LOCAL" ## Interface ## def __init__(self, config: aconfig.Config, instance_name: str): """Initialize with a shared dict of all trainings""" self._use_subprocess = config.get("use_subprocess", False) self._subprocess_start_method = config.get("subprocess_start_method", "spawn") super().__init__(config, instance_name)
[docs] def train( self, module_class: Type[ModuleBase], *args, save_path: Optional[str] = None, save_with_id: bool = False, external_training_id: Optional[str] = None, model_name: Optional[str] = None, **kwargs, ) -> "LocalModelFuture": """Start training the given module and return a future to the trained model instance """ # Always purge old futures self._purge_old_futures() # Wrap any models in the kwargs for safe spawning if needed if self._use_subprocess and self._subprocess_start_method != "fork": wrapped_models = { key: _SpawnProcessModelWrapper(val) for key, val in kwargs.items() if isinstance(val, ModuleBase) } log.debug2("Subprocess wrapped models: %s", wrapped_models.keys()) kwargs.update(wrapped_models) # If there's an external ID, make sure it's not currently running before # launching the job if external_training_id and ( current_future := self._futures.get(external_training_id) ): error.value_check( "<COR79850511E>", current_future.get_info().status.is_terminal, "Cannot restart training {} that is currently running", external_training_id, ) # Update kwargs with required information future_id = external_training_id or str(uuid.uuid4()) # Create the new future model_future = self.LocalModelFuture( self._instance_name, module_class, save_path=save_path, future_id=future_id, save_with_id=save_with_id, use_subprocess=self._use_subprocess, subprocess_start_method=self._subprocess_start_method, model_name=model_name, args=args, kwargs=kwargs, ) # Lock the global futures dict and add it to the dict with self._futures_lock: if current_future := self._futures.get(model_future.id): error.value_check( "<COR35431427E>", current_future.get_info().status.is_terminal, "UUID collision for model future {}", model_future.id, ) self._futures[model_future.id] = model_future # Return the future return model_future
[docs] def get_model_future(self, training_id: str) -> "LocalModelFuture": """Look up the model future for the given id. Use custom implementation to retain error codes""" self._purge_old_futures() if model_future := self._futures.get(training_id): return model_future raise CaikitCoreException( status_code=CaikitCoreStatusCode.NOT_FOUND, message=f"Unknown training_id: {training_id}", )
[docs] class _SpawnProcessModelWrapper(ModuleBase): """This class wraps up a model to make it safe to pass to a spawned subprocess. It will not be efficient, but it will be safe! """ def __init__(self, model: ModuleBase): super().__init__() self._model = model
[docs] def __getattr__(self, name): """Forward attributes that are not found on the base class to the model NOTE: This does _not_ forward base class attributes since those are resolved before __getattr__ is called. """ return getattr(self._model, name)
[docs] def save(self, *args, **kwargs): """Directly forward save to the model so that it is not called by the base class """ return self._model.save(*args, **kwargs)
[docs] def run(self, *args, **kwargs): """Directly forward run to the model so that it is not called by the base class """ return self._model.run(*args, **kwargs)
[docs] def __getstate__(self) -> bytes: """When pickling, only send the serialized model body for non-fork. This is not a general-purpose pickle solution for models, but makes them safe for training jobs that need to move models between processes. """ return self._model.as_bytes()
[docs] def __setstate__(self, pickled: bytes): """When unpickling, deserialize the body if the model is not already loaded in the model manager. This must be used in conjunction with the above __getstate__ across a process boundary and should not be used as a general-purpose deserialization for models. """ retrieved_model = caikit.core.load(pickled) self._model = retrieved_model