Source code for caikit.core.modules.saver

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Local saver implementation for saving modules to disk.
Contains recursive functions for loading modules saved inside modules.
"""

# Standard
from importlib import metadata
import datetime
import os
import shutil
import uuid

# First Party
import alog

# Local
from ..exceptions import error_handler
from ..modules.config import ModuleConfig
from ..toolkit import ObjectSerializer
from ..toolkit.wip_decorator import TempDisableWIP
from .base import ModuleBase
from .loader import ModuleLoader
from caikit.config import get_config

# Log channel used by this file for save-related messages
log = alog.use_channel("MODULE_SAVE")
# Error handler bound to the channel above. It is invoked below as
# `error(<log code>, <exception instance>)` and `error.type_check(...)`.
error = error_handler.get(log)


class ModuleSaver:
    """A module saver that provides common functionality used for saving modules and
    also a context manager that cleans up in case an error is encountered during the
    save process for a model_path that did not already exist.
    """

    SAVED_KEY_NAME = "saved"
    CREATED_KEY_NAME = "created"
    TRACKING_KEY_NAME = "tracking_id"
    MODULE_VERSION_KEY_NAME = "version"
    MODULE_ID_KEY_NAME = "module_id"
    MODULE_CLASS_KEY_NAME = "module_class"

    def __init__(self, module: "ModuleBase", model_path, exist_ok=True):
        """Construct a new module saver.

        Args:
            module (caikit.core.module.Module): The instance of the module to be
                saved.
            model_path (str): The absolute path to the directory where the model
                will be saved. If this directory does not exist, it will be
                created.
            exist_ok (bool): Allow to overwrite existing model_path files.
        """
        self.model_path = os.path.normpath(model_path)
        self.exist_ok = exist_ok

        # Get possibly nested caikit library path: the first configured library
        # whose module_path is a prefix of this module's import path wins.
        # Otherwise fall back to the top-level package name, which assumes no
        # nested module path.
        module_path = module.__module__
        configured_libs = get_config().libraries
        self.library_name = next(
            (
                lib_name
                for lib_name, lib_cfg in configured_libs.items()
                if module_path.startswith(lib_cfg.module_path)
            ),
            module_path.split(".")[0],
        )

        # Prefer the installed distribution's version, then a version declared in
        # the caikit config, then a fixed placeholder.
        try:
            self.library_version = metadata.version(self.library_name)
        except metadata.PackageNotFoundError:
            log.debug("<COR25991305D>", "No library version found")
            if (
                self.library_name in get_config().libraries
                and "version" in get_config().libraries[self.library_name]
            ):
                self.library_version = get_config().libraries[self.library_name].version
            else:
                self.library_version = "0.0.0"

        self.config = {
            self.library_name + "_version": self.library_version,
            self.CREATED_KEY_NAME: str(datetime.datetime.now()),
            self.SAVED_KEY_NAME: str(datetime.datetime.now()),
            "name": module.MODULE_NAME,
            self.TRACKING_KEY_NAME: str(uuid.uuid4()),
            self.MODULE_ID_KEY_NAME: module.MODULE_ID,
            self.MODULE_CLASS_KEY_NAME: module.MODULE_CLASS,
            self.MODULE_VERSION_KEY_NAME: module.MODULE_VERSION,
        }

        # Temp disable wip for following invocation to not log warnings for
        # downstream usage of ModuleSaver
        with TempDisableWIP():
            # Get metadata back about this module and add it to the config.
            # Work on a shallow copy so that the sanitizing below does not strip
            # keys from the metadata object held by the module being saved.
            stored_config = dict(module.metadata)

            # Sanitize some things off of the config:
            # Remove the old `saved` timestamp:
            stored_config.pop(self.SAVED_KEY_NAME, None)

            # Remove any reserved keys, these will be set by the `ModuleConfig`
            # class
            for key in ModuleConfig.reserved_keys:
                stored_config.pop(key, None)

            # Values carried by the module's metadata override the defaults above
            self.config.update(stored_config)

    def add_dir(self, relative_path, base_relative_path=""):
        """Create a directory inside the `model_path` for this saver.

        Args:
            relative_path (str): A path relative to this saver's `model_path`
                denoting the directory to create.
            base_relative_path (str): A path, relative to this saver's
                `model_path`, in which `relative_path` will be created.

        Returns:
            str, str: A tuple containing both the `relative_path` and
                `absolute_path` to the directory created.

        Notes:
            With the default (empty) `base_relative_path`, the returned relative
            path carries a leading `.` component (e.g. `./word_embeddings`)
            because the empty base normalizes to `.`.

        Examples:
            >>> with ModuleSaver(module, '/path/to/model') as saver:
            >>>     rel_path, abs_path = saver.add_dir('word_embeddings', 'model_data')
            >>> print(rel_path)
            model_data/word_embeddings
            >>> print(abs_path)
            /path/to/model/model_data/word_embeddings
        """
        base_relative_path = os.path.normpath(base_relative_path)
        relative_path = os.path.normpath(relative_path)
        relative_path = os.path.join(base_relative_path, relative_path)
        absolute_path = os.path.join(self.model_path, relative_path)

        # Creating an already existing directory is not an error here
        os.makedirs(absolute_path, exist_ok=True)

        return relative_path, absolute_path

    def copy_file(self, file_path, relative_path=""):
        """Copy an external file into a subdirectory of the `model_path` for this
        saver.

        Args:
            file_path (str): Absolute path to the external file to copy.
            relative_path (str): The relative path inside of `model_path` where
                the file will be copied to. If set to the empty string (default)
                then the file will be placed directly in the `model_path`
                directory.

        Returns:
            str, str: A tuple containing both the `relative_path` and
                `absolute_path` to the copied file.
        """
        file_path = os.path.normpath(file_path)

        if not os.path.isfile(file_path):
            error(
                "<COR80954473E>",
                FileNotFoundError(
                    f"Attempted to add `{file_path}` but is not a regular file."
                ),
            )

        # file_path is already normalized above
        filename = os.path.basename(file_path)

        relative_path, absolute_path = self.add_dir(relative_path)

        relative_file_path = os.path.join(relative_path, filename)
        absolute_file_path = os.path.join(absolute_path, filename)

        shutil.copyfile(file_path, absolute_file_path)

        return relative_file_path, absolute_file_path

    def save_object(self, obj, filename, serializer, relative_path=""):
        """Save a Python object using the provided ObjectSerializer.

        Args:
            obj (any): The Python object to save
            filename (str): The filename to use for the saved object
            serializer (ObjectSerializer): An ObjectSerializer instance (e.g.,
                YAMLSerializer) that should be used to serialize the object
            relative_path (str): The relative path inside of `model_path` where
                the object will be saved

        Returns:
            str, str: A tuple containing both the normalized relative path and
                the normalized absolute path of the serialized file.
        """
        if not isinstance(serializer, ObjectSerializer):
            error(
                "<COR85655282E>",
                TypeError(
                    f"`{serializer.__class__.__name__}` does not extend `ObjectSerializer`"
                ),
            )

        relative_path, absolute_path = self.add_dir(relative_path)

        # Normalize any '././' structure that may come from relative paths
        relative_file_path = os.path.normpath(os.path.join(relative_path, filename))
        absolute_file_path = os.path.normpath(os.path.join(absolute_path, filename))

        serializer.serialize(obj, absolute_file_path)

        return relative_file_path, absolute_file_path

    def update_config(self, additional_config):
        """Add items to this saver's config dictionary.

        Args:
            additional_config (dict): A dictionary of config options to add to
                this saver's configuration.

        Notes:
            The behavior of this method matches `dict.update` and is equivalent
            to calling `saver.config.update`. The `saver.config` dictionary may
            be accessed directly for more sophisticated manipulation of the
            configuration.
        """
        self.config.update(additional_config)

    def save_module(self, module, relative_path, **kwargs):
        """Save a CaikitCore module within a workflow artifact and add a
        reference to the config.

        Args:
            module (caikit.core.ModuleBase): The CaikitCore module to save as
                part of this workflow
            relative_path (str): The relative path inside of `model_path` where
                the module will be saved
            **kwargs: dict
                key-value pair of parameters to be passed to module.save

        Returns:
            str, str: A tuple containing both the relative path and the absolute
                path of the directory the module was saved into.
        """
        if not isinstance(module, ModuleBase):
            error(
                "<COR30664151E>",
                TypeError(
                    f"`{module.__class__.__name__}` does not extend `ModuleBase`"
                ),
            )

        rel_path, abs_path = self.add_dir(relative_path)

        # Save this module at the specified location
        module.save(abs_path, **kwargs)

        # Record the module location, keyed by the path the caller supplied
        self.config.setdefault(ModuleLoader.MODULE_PATHS_KEY, {}).update(
            {relative_path: rel_path}
        )
        return rel_path, abs_path

    def save_module_list(self, modules, config_key, **kwargs):
        """Save a list of CaikitCore modules within a workflow artifact and add a
        reference to the config.

        Args:
            modules (dict{str -> caikit.core.ModuleBase}): A dict with module
                relative path as key and a CaikitCore module as value to save as
                part of this workflow
            config_key (str): The config key inside of `model_path` where the
                modules' relative path with be referenced
            **kwargs: dict
                key-value pair of parameters to be passed to module.save

        Returns:
            list_of_rel_path: list(str)
                List of relative paths where the modules are saved
            list_of_abs_path: list(str)
                List of absolute paths where the modules are saved
        """
        # validate type of input parameters
        error.type_check("<COR44644420E>", dict, modules=modules)
        error.type_check("<COR54316176E>", str, config_key=config_key)

        list_of_rel_path = []
        list_of_abs_path = []

        # iterate through the dict and serialize the modules in its corresponding
        # paths; each entry is validated right before it is saved
        for relative_path, module in modules.items():
            if not isinstance(module, ModuleBase):
                error(
                    "<COR67834055E>",
                    TypeError(
                        f"`{module.__class__.__name__}` does not extend `ModuleBase`"
                    ),
                )
            error.type_check("<COR48984754E>", str, relative_path=relative_path)

            rel_path, abs_path = self.add_dir(relative_path)

            # Save this module at the specified location
            module.save(abs_path, **kwargs)

            # append relative and absolute path to a list that will be returned
            list_of_rel_path.append(rel_path)
            list_of_abs_path.append(abs_path)

        # update the config with config key and list of relative path
        self.config.setdefault(ModuleLoader.MODULE_PATHS_KEY, {}).update(
            {config_key: list_of_rel_path}
        )
        return list_of_rel_path, list_of_abs_path

    def __enter__(self):
        """Enter the module saver context. This creates the `model_path`
        directory. If this context successfully exits, then the model
        configuration and all files it contains will be written and saved to disk
        inside the `model_path` directory. If `exist_ok` is False, an exception
        will be raised before touching existing `model_path` files. If any
        uncaught exceptions are thrown inside this context, and `exist_ok` is
        False, then this new `model_path` will be removed. If `exist_ok` is True,
        the files will be kept and may include incomplete updates.
        """
        os.makedirs(self.model_path, exist_ok=self.exist_ok)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the module saver context. If this context successfully exits,
        then the model configuration and all files it contains will be written
        and saved to disk inside the `model_path` directory. If any uncaught
        exceptions are thrown inside this context, and `exist_ok` is False, then
        this new `model_path` will be removed. If `exist_ok` is True, the files
        will be kept and may include incomplete updates.
        """
        if exc_type is not None:
            if not self.exist_ok:
                # Presume it is okay to rmtree: with exist_ok=False, __enter__
                # only succeeds when it created model_path itself
                shutil.rmtree(self.model_path, ignore_errors=True)
            # Returning None lets the exception propagate to the caller
            return

        ModuleConfig(self.config).save(self.model_path)