Source code for caikit.core.model_management.multi_model_initializer

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The MultiModelInitializer configures a set of other model initializers that will be used
in sequence to try loading models.

Configuration for MultiModelInitializer lives under the config as follows:

model_management:
    initializers:
        <initializer name>:
            type: MULTI
            config:
                # Sequence of other initializer names to use in priority order
                initializer_priority:
                    - other_initializer1
                    - other_initializer2
"""
# Standard
from typing import Optional

# First Party
import aconfig
import alog

# Local
from ...config import get_config
from ..exceptions import error_handler
from ..modules import ModuleBase, ModuleConfig
from .model_initializer_base import ModelInitializerBase

# NOTE: Top-level import done so that global MODEL_MANAGER can be used at
#   construction time without incurring a circular dependency
import caikit.core

log = alog.use_channel("MINIT")
error = error_handler.get(log)


class MultiModelInitializer(ModelInitializerBase):
    # Reuse the module-level docstring (which describes the config layout) as
    # the class documentation.
    # NOTE: Do not add a literal class docstring here. It would be bound in
    #   the class namespace first and this assignment would then copy it
    #   instead of the module text.
    __doc__ = __doc__

    name = "MULTI"

    def __init__(self, config: aconfig.Config, instance_name: str):
        """Validate the configured priority list and resolve every entry to
        an initializer instance

        Args:
            config (aconfig.Config): The config blob for this initializer.
                initializer_priority is read from it, as is model_manager,
                which falls back to the global caikit.core.MODEL_MANAGER when
                falsy.
            instance_name (str): The name this instance is configured under.
                It may not appear in its own priority list.
        """
        self._instance_name = instance_name
        priority = config.initializer_priority

        # The priority list must be a non-empty list of strings
        error.type_check("<COR47518221E>", list, initializer_priority=priority)
        error.type_check_all("<COR47518222E>", str, initializer_priority=priority)
        error.value_check(
            "<COR05042343E>",
            priority,
            "Must provide at least one valid initializer",
        )

        # Every entry must name an initializer present in the global config
        known_initializers = get_config().model_management.initializers
        unknown_names = []
        for candidate_name in priority:
            if candidate_name not in known_initializers:
                unknown_names.append(candidate_name)
        error.value_check(
            "<COR68252034E>",
            not unknown_names,
            "Invalid initializers given in initializer_priority: {}",
            unknown_names,
        )

        # A multi initializer may not list itself
        refers_to_self = self._instance_name in priority
        error.value_check(
            "<COR54613971E>",
            not refers_to_self,
            "Cannot include self in multi initializer priority",
        )

        # Prefer a manager handed in through the config and otherwise use the
        # global one
        manager = config.model_manager or caikit.core.MODEL_MANAGER
        log.debug2(
            "Setting up %s with initializer priority: %s", self.name, priority
        )

        # Build the full list before binding it so that a failed lookup never
        # leaves a partially populated attribute behind
        resolved = []
        for candidate_name in priority:
            resolved.append(manager.get_initializer(candidate_name))
        self._initializers = resolved

    def init(
        self,
        model_config: ModuleConfig,
        **kwargs,
    ) -> Optional[ModuleBase]:
        """Try each configured initializer in priority order and hand back
        the first truthy result

        Args:
            model_config (ModuleConfig): The config of the model to
                initialize
            **kwargs: Passed through unchanged to each initializer's init

        Returns:
            module (Optional[ModuleBase]): The first truthy value returned by
                an initializer, or None if every initializer either returned
                a falsy value or raised
        """
        for position, candidate in enumerate(self._initializers):
            log.debug2(
                "Trying to init %s with initializer %d of type %s",
                model_config.module_id,
                position,
                candidate.name,
            )
            # Best effort: a failure in one initializer only means that the
            # next one in the sequence gets its turn
            try:
                loaded = candidate.init(model_config, **kwargs)
                if not loaded:
                    log.debug2(
                        "Initializer %d of type %s unable to init %s",
                        position,
                        candidate.name,
                        model_config.module_id,
                    )
                    continue
                log.debug(
                    "Init model %s with initializer %d of type %s",
                    model_config.module_id,
                    position,
                    candidate.name,
                )
                return loaded
            except Exception as err:  # pylint: disable=broad-exception-caught
                log.debug2(
                    "Initializer %d of type %s failed to load %s: %s",
                    position,
                    candidate.name,
                    model_config.module_id,
                    err,
                )
                log.debug4("Initializer error", exc_info=True)

        # Nothing in the sequence produced a module
        log.warning("Unable to init %s with any initializer", model_config.module_id)
        return None