# Source code for caikit.core.modules.loader

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Local loader implementation for loading modules from disk.
Contains recursive functions for loading modules saved inside modules.
"""

# Standard
from typing import Union
import os

# First Party
import alog

# Local
from ..exceptions import error_handler
from .config import ModuleConfig
from caikit import core

log = alog.use_channel("MODULE_LOAD")
error = error_handler.get(log)


class ModuleLoader:
    """Load a model's config from disk and the nested modules it references.

    The config is loaded once at construction time through
    ``ModuleConfig.load``. Nested modules are located via the relative paths
    stored under the config's ``module_paths`` entry and are joined onto
    ``self.model_path`` before being handed to ``core.load``.
    """

    # Name of the config entry that maps keys to nested module paths
    MODULE_PATHS_KEY = "module_paths"

    def __init__(self, model_path: Union[str, ModuleConfig]):
        """Construct a new module loader.

        Args:
            model_path (Union[str, ModuleConfig]): The path to the directory
                where the model is to be loaded from, or a preloaded config.
        """
        self.config = ModuleConfig.load(model_path)
        self.model_path = self.config.model_path

    def load_arg(self, arg):
        """Extract a single arg value from the loaded model's config.

        Args:
            arg (str): Name of the config attribute to read.

        Returns:
            The value of ``getattr(self.config, arg)``.
        """
        return getattr(self.config, arg)

    def load_args(self, *args):
        """Extract several values from the loaded model's config.

        Args:
            *args (str): Names of the config attributes to read.

        Returns:
            tuple: The attribute values, in the same order as ``args``.
        """
        return tuple(getattr(self.config, arg) for arg in args)

    def load_module(self, module_paths_key, load_singleton=False):
        """Load a CaikitCore module from a config.module_paths specification.

        Args:
            module_paths_key (str): key in `config.module_paths` looked at to
                load a module
            load_singleton (bool): singleton load flag to pass to individual
                module loads

        Returns:
            The module returned by ``core.load`` for the resolved path.

        Raises:
            KeyError: reported through ``error`` when the config has no
                ``module_paths`` entry, or when ``module_paths`` lacks
                ``module_paths_key``.
        """
        # Load module from a given relative path
        if self.MODULE_PATHS_KEY not in self.config:
            error(
                "<COR08580509E>", KeyError("Missing `module_paths` in module config!")
            )
        if module_paths_key not in self.config.module_paths:
            error(
                "<COR22069088E>",
                KeyError(
                    "Missing required {} key in config.module_paths!".format(
                        module_paths_key
                    )
                ),
            )
        # Per os.path.join semantics, an absolute entry discards model_path
        module_path = os.path.join(
            self.model_path, self.config.module_paths[module_paths_key]
        )
        return core.load(module_path, load_singleton=load_singleton)

    def load_module_list(self, module_paths_key, load_singleton=False):
        """Load a list of CaikitCore modules from a config.module_paths specification.

        Args:
            module_paths_key (str): key in `config.module_paths` looked at to
                load a list of modules
            load_singleton (bool): singleton load flag to pass to each
                individual module load (defaults to False)

        Returns:
            list: list of loaded modules, in the same order as the paths in
                the config

        Raises:
            KeyError: reported through ``error`` when the config has no
                ``module_paths`` entry, or when ``module_paths`` lacks
                ``module_paths_key``.
            TypeError: reported through ``error.type_check`` when the entry
                is not a list (presumably; confirm against error_handler).
        """
        # Load module from a given relative path
        # Can be updated to load from a module key
        if self.MODULE_PATHS_KEY not in self.config:
            error(
                "<COR52619266E>", KeyError("Missing `module_paths` in workflow config!")
            )
        if module_paths_key not in self.config.module_paths:
            error(
                "<COR75976687E>",
                KeyError(
                    "Missing required {} key in config.module_paths!".format(
                        module_paths_key
                    )
                ),
            )
        module_list = self.config.module_paths[module_paths_key]
        error.type_check("<COR21790391E>", list, module_list=module_list)

        # Resolve each entry against the model directory and load it in order
        return [
            core.load(
                os.path.join(self.model_path, module), load_singleton=load_singleton
            )
            for module in module_list
        ]