Source code for caikit.core.model_manager

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


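"""Central manager for models: training, loading (from directories or zip
archives), prediction jobs, and lazy lookup of the configured
trainer/finder/initializer/job-predictor components.
"""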

# Standard
from contextlib import contextmanager
from io import BytesIO
from threading import Lock
from typing import Dict, List, Optional, Type, Union
import os
import tempfile
import zipfile

# First Party
import alog

# Local
from ..interfaces.common.data_model.stream_sources import S3Path
from .exceptions import error_handler
from .model_management import (
    JobPredictorBase,
    JobPredictorFutureBase,
    ModelFinderBase,
    ModelInitializerBase,
    ModelTrainerBase,
    ModelTrainerFutureBase,
    job_predictor_factory,
    model_finder_factory,
    model_initializer_factory,
    model_trainer_factory,
)
from .model_management.local_model_initializer import LocalModelInitializer
from .module_backends.base import BackendBase
from .modules.base import ModuleBase
from .registries import module_registry
from .toolkit.factory import Factory, FactoryConstructible
from caikit.config import get_config

log = alog.use_channel("MDLMNG")
error = error_handler.get(log)

# restrict functions that are imported so we don't pollute the base module namespace
__all__ = [
    "get_valid_module_ids",
    "ModelManager",
]


def get_valid_module_ids():
    """Build a lookup from every registered module ID to the name of the class
    that implements it.

    Returns:
        dict: ``{module_id: implementing class name}`` with one entry per item
            in the module registry.
    """
    id_to_class_name = {}
    for registered_id, implementation in module_registry().items():
        id_to_class_name[registered_id] = implementation.__name__
    return id_to_class_name
class ModelManager:
    """Manage the models or resources for library.

    The manager lazily constructs the trainers, finders, initializers and job
    predictors named in the ``model_management`` section of the global config
    and uses them to train, find and load models. Models loaded with
    ``load_singleton=True`` are kept in an in-memory cache keyed by the path
    handed to ``_do_load``.
    """

    def __init__(self):
        """Initialize ModelManager."""
        # Map to store module caches, to be used for singleton model lookups
        self._singleton_module_cache = {}
        # Lazily populated component instances, keyed by their config name
        self._trainers = {}
        self._finders = {}
        self._job_predictors = {}
        self._initializers = {}
        # Guards _singleton_module_cache during singleton loads and clears
        self.__singleton_lock = Lock()

    def initialize_components(self):
        """Proactively initialize all configured trainer/finder/initializer/
        job predictor component instances.

        This is a separate call (rather than part of ``__init__``) to enable
        explicit config before the components are constructed.
        """
        mm_config = get_config().model_management
        for trainer in mm_config.get("trainers", {}):
            self.get_trainer(trainer)
        for finder in mm_config.get("finders", {}):
            self.get_finder(finder)
        for initializer in mm_config.get("initializers", {}):
            self.get_initializer(initializer)
        for job_predictor in mm_config.get("job_predictors", {}):
            self.get_predictor(job_predictor)

    ## Public ##################################################################

    def train(
        self,
        module: Union[Type[ModuleBase], str],
        *args,
        trainer: Union[str, ModelTrainerBase] = "default",
        save_path: Optional[Union[str, S3Path]] = None,
        save_with_id: bool = False,
        model_name: Optional[str] = None,
        wait: bool = False,
        **kwargs,
    ) -> ModelTrainerFutureBase:
        """Train an instance of the given module with the given args and kwargs
        using the given trainer.

        Each module's train function encapsulates the code needed to perform
        the training locally. This top-level train function provides the
        wrapper functionality to delegate the execution of the module's train
        function to an alternate framework using a ModelTrainerBase. It also
        allows training to be launched asynchronously.

        Args:
            module (Union[Type[ModuleBase], str]): The module class or guid for
                the module to train
            *args: Additional positional args to pass through to the module's
                train function

        Kwargs:
            trainer (Union[str, ModelTrainerBase]): The trainer to use. If
                given as a string, this is a key in the global config at
                model_management.trainers.
            save_path (Optional[Union[str, S3Path]]): Base path where the model
                should be saved (may be relative to a remote trainer's
                filesystem, or link to S3 storage)
            save_with_id (bool): Inject the training ID into the save path for
                the output model
            model_name (Optional[str]): Name of model that will be appended to
                the end of the save_path
            wait (bool): Wait for training to complete before returning
            **kwargs: Additional keyword arguments to pass through to the
                module's train function

        Returns:
            model_future (ModelTrainerFutureBase): The future handle to the
                model which holds the status of the in-flight training.
        """
        # Resolve the module class when it was given by ID
        if isinstance(module, str):
            module_id = module
            module = module_registry().get(module_id)
            error.value_check(
                "<COR00469102E>",
                module is not None,
                "Unable to train unknown module {}",
                module_id,
            )
        error.subclass_check("<COR05418775E>", module, ModuleBase)

        # Get the trainer to use (lazily constructed if given by name)
        resolved_trainer: ModelTrainerBase = self.get_trainer(trainer)

        # Start the training
        with alog.ContextTimer(log.debug, "Started training in: "):
            model_future = resolved_trainer.train(
                module,
                *args,
                save_path=save_path,
                save_with_id=save_with_id,
                model_name=model_name,
                **kwargs,
            )
            log.debug(
                "Started training %s with save path %s",
                model_future.id,
                model_future.save_path,
            )

        # If requested, wait for the future to complete
        if wait:
            log.debug("Waiting for training %s to complete", model_future.id)
            with alog.ContextTimer(
                log.debug, "Finished training %s in: ", model_future.id
            ):
                model_future.wait()

        # Return a handle to the training
        return model_future

    def start_prediction_job(
        self,
        model: ModuleBase,
        prediction_func_name: str,
        *args,
        predictor: Union[str, JobPredictorBase] = "default",
        wait: bool = False,
        **kwargs,
    ) -> JobPredictorFutureBase:
        """Start a prediction job using a job_predictor.

        Args:
            model (ModuleBase): Loaded model to run prediction on
            prediction_func_name (str): String reference to name of function
                to run
            *args: Positional args passed through to the predictor's
                ``predict`` call
            predictor (Union[str, JobPredictorBase], optional): Which
                job_predictor to use. If given as a string, this is a key in
                the global config at model_management.job_predictors.
                Defaults to "default".
            wait (bool, optional): Whether to wait for the job to finish.
                Defaults to False.
            **kwargs: Keyword args passed through to the predictor's
                ``predict`` call

        Returns:
            JobPredictorFutureBase: Future to track job result
        """
        error.type_check("<COR02418775E>", ModuleBase, model=model)

        # Get the predictor to use (lazily constructed if given by name)
        inferencer: JobPredictorBase = self.get_predictor(predictor)

        # Start the prediction job
        with alog.ContextTimer(log.debug, "Started prediction job in: "):
            prediction_future = inferencer.predict(
                model,
                prediction_func_name,
                *args,
                **kwargs,
            )
            log.debug(
                "Started Prediction Job %s",
                prediction_future.id,
            )

        # If requested, wait for the future to complete
        if wait:
            log.debug("Waiting for prediction %s to complete", prediction_future.id)
            with alog.ContextTimer(
                log.debug, "Finished prediction %s in: ", prediction_future.id
            ):
                prediction_future.wait()

        # Return a handle to the future
        return prediction_future

    def get_model_future(
        self,
        training_id: str,
    ) -> ModelTrainerFutureBase:
        """Get the future handle to an in-progress training

        Args:
            training_id (str): The ID string from the original training
                submission's model future

        Returns:
            model_future (ModelTrainerFutureBase): The future handle to the
                model which holds the status of the in-flight training.
        """
        try:
            trainer = self.get_trainer(ModelTrainerBase.get_trainer_name(training_id))
        # Fall back to the default trainer to try to find this ID
        except ValueError:
            trainer = self.get_trainer("default")
        return trainer.get_model_future(training_id)

    def get_prediction_future(
        self,
        prediction_id: str,
    ) -> JobPredictorFutureBase:
        """Get the future handle to an in-progress prediction job

        Args:
            prediction_id (str): The ID string from the original prediction
                submission's prediction future

        Returns:
            prediction_future (JobPredictorFutureBase): The future handle to
                the job which holds the status of the in-flight prediction.
        """
        try:
            predictor = self.get_predictor(
                JobPredictorBase.get_predictor_name(prediction_id)
            )
        # Fall back to the default predictor to try to find this ID
        except ValueError:
            predictor = self.get_predictor("default")
        return predictor.get_prediction_future(prediction_id)

    def load(
        self,
        module_path: Union[str, BytesIO, bytes],
        *,
        load_singleton: bool = False,
        finder: Union[str, ModelFinderBase] = "default",
        initializer: Union[str, ModelInitializerBase] = "default",
        **kwargs,
    ):
        """Load a model and return an instantiated object on which we can run
        inference.

        Args:
            module_path (str | BytesIO | bytes): A module path (identifier) to
                one of the following:
                1. A directory containing a yaml config file in the top level.
                2. A zip archive containing either a yaml config file in the
                   top level when extracted, or a directory containing a yaml
                   config file in the top level.
                3. A BytesIO object corresponding to a zip archive laid out as
                   in (2).
                4. A bytes object corresponding to a zip archive laid out as
                   in (2).
                5. A string that is understood by the configured
                   finder/initializer

        Kwargs:
            load_singleton (bool): Load this model as a singleton
            finder (Union[str, ModelFinderBase]): Finder to use when loading
                this model. If passed as a string, this names the finder in
                the global config model_management.finders section.
            initializer (Union[str, ModelInitializerBase]): Initializer to use
                when initializing this model. If passed as a string, this is
                the name of the initializer in the global config
                model_management.initializers section.
            **kwargs: Passed through to the finder's ``find_model`` and the
                initializer's ``init``

        Returns:
            model (ModuleBase): Model object that is loaded, configured, and
                ready for prediction.
        """
        error.type_check("<COR98255724E>", bool, load_singleton=load_singleton)

        # This allows a user to load their own model (e.g. model saved to
        # disk): a string path that does not exist as given is retried
        # relative to the configured load_path.
        load_path = get_config().load_path
        if (
            load_path is not None
            and isinstance(module_path, str)
            and not os.path.exists(module_path)
        ):
            full_module_path = os.path.join(load_path, module_path)
            if os.path.exists(full_module_path):
                module_path = full_module_path

        # Ensure that we have a loadable directory.
        error.type_check("<COR98255419E>", str, BytesIO, bytes, module_path=module_path)
        if isinstance(module_path, str):
            # Ensure this path is operating system correct if it isn't already.
            module_path = os.path.normpath(module_path)
        # If we have bytes, convert to a buffer, since we already handle in
        # memory binary streams.
        elif isinstance(module_path, bytes):
            module_path = BytesIO(module_path)

        # Now that we have a file like object | str we can try to load as an
        # archive.
        if zipfile.is_zipfile(module_path):
            return self._load_from_zipfile(
                module_path, load_singleton, finder, initializer, **kwargs
            )
        try:
            return self._do_load(
                module_path, load_singleton, finder, initializer, **kwargs
            )
        except FileNotFoundError:
            error(
                "<COR80419785E>",
                FileNotFoundError(
                    "Module load path `{}` does not contain a `config.yml` file.".format(
                        module_path
                    )
                ),
            )

    def extract(
        self, zip_path: str, model_path: str, force_overwrite: bool = False
    ) -> str:
        """Method to extract a downloaded archive to a specified directory.

        Args:
            zip_path (str): Location of .zip file to extract.
            model_path (str): Model directory where the archive should be
                unzipped.
            force_overwrite (bool): Extract into model_path even if it already
                exists. Existing contents are not removed first. Defaults to
                False.

        Returns:
            str: Absolute output path where the model archive is unzipped.
        """
        model_path = os.path.abspath(model_path)

        # skip if force_overwrite disabled and path already exists
        if not force_overwrite and os.path.exists(model_path):
            log.info(
                "INFO: Skipped extraction. Archive already extracted in directory: %s",
                model_path,
            )
            return model_path

        with zipfile.ZipFile(zip_path, "r") as zip_f:
            zip_f.extractall(model_path)

        # path to model
        return model_path

    def resolve_and_load(
        self, path_or_name_or_model_reference: Union[str, ModuleBase], **kwargs
    ):
        """Try our best to load a model, given a path or a loaded model.
        Simply returns any loaded model passed in.

        This exists to ease the burden on workflow developers who need to
        accept individual modules in their API, where users may hold either a
        loaded module or a path to one.

        Args:
            path_or_name_or_model_reference (str, ModuleBase): Either a
                - Path to a model directory on disk
                - Loaded module
                Any other string (including a path that is not an existing
                directory) is reported as an unknown model.
            **kwargs: Any keyword arguments to pass along to
                ModelManager.load()

        Returns:
            A loaded module

        Examples:
            >>> local_categories_model = manager.resolve_and_load('path/to/categories/model')
            >>> some_custom_model = manager.resolve_and_load(some_custom_model)
        """
        error.type_check(
            "<COR50266694E>",
            str,
            ModuleBase,
            path_or_name_or_model_reference=path_or_name_or_model_reference,
        )

        # If this is already a module, we're good to go
        if isinstance(path_or_name_or_model_reference, ModuleBase):
            log.debug("Returning model %s directly", path_or_name_or_model_reference)
            return path_or_name_or_model_reference

        # Otherwise, only an existing directory can be resolved
        if os.path.isdir(path_or_name_or_model_reference):
            # Try to load from path
            log.debug(
                "Attempting to load model from path %s", path_or_name_or_model_reference
            )
            return self.load(path_or_name_or_model_reference, **kwargs)

        error(
            "<COR50207495E>",
            ValueError(
                "could not find model with name `{}`".format(
                    path_or_name_or_model_reference
                )
            ),
        )

    def get_singleton_model_cache_info(self):
        """Returns information about the singleton cache in
        {cache key: module type} format

        Returns:
            dict: A dictionary mapping each singleton cache key (the path the
                model was loaded from) to the type of the cached model
        """
        # Iterate over a snapshot rather than the live dict: singleton loads on
        # other threads insert into the cache, and iterating a dict while its
        # size changes raises RuntimeError. The lock is deliberately not taken
        # here because it is held for the full duration of a singleton load.
        cache_snapshot = self._singleton_module_cache.copy()
        return {key: type(model) for key, model in cache_snapshot.items()}

    def clear_singleton_cache(self):
        """Clears the cache of singleton models. Useful to release references
        of models, as long as you know that they are no longer held elsewhere
        and you won't be loading them again.

        Returns:
            None
        """
        with self.__singleton_lock:
            self._singleton_module_cache.clear()

    def get_trainer(self, trainer: Union[str, ModelTrainerBase]) -> ModelTrainerBase:
        """Get the configured model trainer or the one passed by value"""
        return self._get_component(
            component=trainer,
            component_dict=self._trainers,
            component_factory=model_trainer_factory,
            component_name="trainer",
            component_cfg=get_config().model_management.trainers,
            component_type=ModelTrainerBase,
        )

    def get_finder(self, finder: Union[str, ModelFinderBase]) -> ModelFinderBase:
        """Get the configured model finder or the one passed by value"""
        return self._get_component(
            component=finder,
            component_dict=self._finders,
            component_factory=model_finder_factory,
            component_name="finder",
            component_cfg=get_config().model_management.finders,
            component_type=ModelFinderBase,
        )

    def get_initializer(
        self, initializer: Union[str, ModelInitializerBase]
    ) -> ModelInitializerBase:
        """Get the configured model initializer or the one passed by value"""
        return self._get_component(
            component=initializer,
            component_dict=self._initializers,
            component_factory=model_initializer_factory,
            component_name="initializer",
            component_cfg=get_config().model_management.initializers,
            component_type=ModelInitializerBase,
        )

    def get_predictor(
        self, inferencer: Union[str, JobPredictorBase]
    ) -> JobPredictorBase:
        """Get the configured job predictor or the one passed by value"""
        return self._get_component(
            component=inferencer,
            component_dict=self._job_predictors,
            component_factory=job_predictor_factory,
            component_name="predictor",
            component_cfg=get_config().model_management.job_predictors,
            component_type=JobPredictorBase,
        )

    def get_module_backends(
        self,
        initialize: bool = True,
    ) -> List[BackendBase]:
        """Convenience method to get access to the configured module backends
        if any have been configured

        Args:
            initialize (bool): Initialize the components from config before
                collecting the backends

        Returns:
            backends (List[BackendBase]): The backend instances held by every
                LocalModelInitializer that has been constructed so far
        """
        if initialize:
            log.debug3("Initializing components to get backends")
            self.initialize_components()
        return [
            backend
            for initializer in self._initializers.values()
            if isinstance(initializer, LocalModelInitializer)
            for backend in initializer.backends
        ]

    ## Implementation Details ##################################################

    def _do_load(self, module_path, load_singleton, finder, initializer, **kwargs):
        """Find and initialize a model, consulting the singleton cache first
        when requested.

        Args:
            module_path (str): Path or identifier handed to the finder. For
                the local case this is a directory with `config.yml` at the
                top level which holds info about the model. Also used as the
                singleton cache key.
            load_singleton (bool): Load this model as a singleton
            finder (Union[str, ModelFinderBase]): Finder to use when loading
                this model. If passed as a string, this names the finder in
                the global config model_management.finders section.
            initializer (Union[str, ModelInitializerBase]): Initializer to use
                when loading this model. If passed as a string, this is the
                name of the initializer in the global config
                model_management.initializers section.
            **kwargs: Passed to both ``finder.find_model`` and
                ``initializer.init``

        Returns:
            subclass of caikit.core.modules.ModuleBase: Model object that is
                loaded, configured, and ready for prediction.
        """
        # The lock (taken only for singleton loads) is held across the whole
        # find/init so that the cache check and cache insert are atomic.
        with self._singleton_lock(load_singleton):
            if singleton_entry := (
                load_singleton and self._singleton_module_cache.get(module_path)
            ):
                log.debug("Found %s in the singleton cache", module_path)
                return singleton_entry

            # Use the given finder to try to find the module config for this
            # module_path
            #
            # NOTE: This will lazily construct named finders if needed
            log.debug("Attempting to find [%s] with finder %s", module_path, finder)
            resolved_finder = self.get_finder(finder)
            log.debug2("Finder type: %s", resolved_finder.name)
            model_config = resolved_finder.find_model(module_path, **kwargs)
            error.value_check(
                "<COR92173495E>",
                model_config is not None,
                "Unable to find a ModuleConfig for {}",
                module_path,
            )

            # Use the given initializer to try to load the model
            #
            # NOTE: This will lazily construct named initializers if needed
            log.debug(
                "Attempting to initialize [%s] with initializer %s",
                module_path,
                initializer,
            )
            resolved_initializer = self.get_initializer(initializer)
            log.debug2("Initializer type: %s", resolved_initializer.name)
            loaded_model = resolved_initializer.init(model_config, **kwargs)
            error.value_check(
                "<COR50207494E>",
                loaded_model is not None,
                "Unable to load model from {} with MODULE_ID {}",
                module_path,
                model_config.module_id,
            )

            # If loading as a singleton, populate the cache
            if load_singleton:
                self._singleton_module_cache[module_path] = loaded_model

            # Return successfully!
            return loaded_model

    def _load_from_zipfile(
        self, module_path, load_singleton, finder, initializer, **kwargs
    ):
        """Load a model from a zip archive.

        The archive is extracted into a temporary directory which is removed
        again once loading has finished.

        Args:
            module_path (str | BytesIO): Path to, or in-memory buffer of, a
                zip archive. `config.yml` must be at the top level of the
                archive or inside a single top-level directory.
            load_singleton (bool): Load this model as a singleton
            finder (Union[str, ModelFinderBase]): Finder to use when loading
                this model. If passed as a string, this names the finder in
                the global config model_management.finders section.
            initializer (Union[str, ModelInitializerBase]): Initializer to use
                when loading this model. If passed as a string, this is the
                name of the initializer in the global config
                model_management.initializers section.
            **kwargs: Passed through to ``_do_load``

        Returns:
            subclass of caikit.core.modules.ModuleBase: Model object that is
                loaded, configured, and ready for prediction.
        """
        # NOTE(review): _do_load keys the singleton cache on the path it is
        # given, which here is the temporary extraction directory. A singleton
        # load of the same archive therefore never hits the cache - confirm
        # whether that is intended.
        with tempfile.TemporaryDirectory() as extract_path:
            with zipfile.ZipFile(module_path, "r") as zip_f:
                zip_f.extractall(extract_path)

            # Depending on the way the zip archive is packaged, our temp
            # directory may unpack to files directly, or it may unpack to a
            # (single) directory containing the files. We expect the former,
            # but fall back to the second if we can't find the config.
            try:
                model = self._do_load(
                    extract_path, load_singleton, finder, initializer, **kwargs
                )
            # NOTE: Error handling is a little gross here, the main reason
            # being that we only want to log to error() if something is fatal,
            # and there are a good amount of things that can go wrong in this
            # process.
            except FileNotFoundError:

                def get_full_path(folder_name):
                    return os.path.join(extract_path, folder_name)

                # Get the contained directories. Omit anything starting with
                # __ to avoid accidentally traversing compression artifacts,
                # e.g., __MACOSX.
                nested_dirs = [
                    get_full_path(f)
                    for f in os.listdir(extract_path)
                    if os.path.isdir(get_full_path(f)) and not f.startswith("__")
                ]

                # If we have zero or multiple dirs, something is probably
                # wrong - this doesn't look like a simple level of nesting as
                # a result of creating the zip.
                if len(nested_dirs) != 1:
                    error(
                        "<COR06761097E>",
                        FileNotFoundError(
                            "Unable to locate archive config due to nested dirs"
                        ),
                    )

                # Otherwise, try again. If we fail again stop, because the zip
                # creation should only create one potential extra layer of
                # nesting around the model directory.
                try:
                    model = self._do_load(
                        nested_dirs[0], load_singleton, finder, initializer, **kwargs
                    )
                except FileNotFoundError:
                    error(
                        "<COR84410081E>",
                        FileNotFoundError(
                            "Unable to locate archive config within top two levels of {}".format(
                                module_path
                            )
                        ),
                    )
        return model

    @contextmanager
    def _singleton_lock(self, load_singleton: bool):
        """Helper contextmanager that will only lock the singleton cache if
        this load is a singleton load

        Args:
            load_singleton (bool): When True the singleton lock is held for
                the duration of the ``with`` body; when False no lock is taken
        """
        if load_singleton:
            with self.__singleton_lock:
                yield
        else:
            yield

    @staticmethod
    def _get_component(
        component: Union[str, FactoryConstructible],
        component_dict: Dict[str, FactoryConstructible],
        component_factory: Factory,
        component_name: str,
        component_cfg: dict,
        component_type: type,
    ) -> FactoryConstructible:
        """Common logic for resolving components from config

        NOTE: This is done lazily to avoid relying on import order and to
            allow for dynamic config changes

        Args:
            component (Union[str, FactoryConstructible]): Either an instance
                of component_type, which is returned as is, or the name of an
                entry in component_cfg
            component_dict (Dict[str, FactoryConstructible]): Cache of already
                constructed components by name; updated in place when a new
                component is constructed
            component_factory (Factory): Factory used to construct a component
                from its config entry
            component_name (str): Human readable kind of component, used in
                error messages
            component_cfg (dict): Config section mapping component names to
                their config dicts
            component_type (type): Base type a by-value component must have

        Returns:
            FactoryConstructible: The resolved component instance
        """
        error.type_check(
            "<COR45466249E>", str, component_type, **{component_name: component}
        )
        # Components passed by value are used directly and never cached
        if isinstance(component, component_type):
            return component
        if component not in component_dict:
            cfg = component_cfg.get(component)
            # A missing or non-dict config entry means the name is unknown
            error.value_check(
                "<COR55057389E>",
                isinstance(cfg, dict),
                "Unknown {}: {}",
                component_name,
                component,
            )
            component_dict[component] = component_factory.construct(cfg, component)
        return component_dict[component]