# Source code for caikit.core.modules.meta
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains an implementation of a metaclass that hijacks all
${class}.load(...) invocations.
For a good read on metaclasses, see https://realpython.com/python-metaclasses/
Our goal is to have every instance of a `caikit.core.ModuleBase` automatically populated with
metadata when it is constructed, without requiring anything of the module authors.
The source of this metadata will be the `config.yml` file that resides within the module directory
to be loaded. Since we require the path to that file in order to read the model's metadata, we
cannot simply define a base constructor. Instead, we patch over the module's .load() function,
which is guaranteed to be called with a path to a serialized module.
Additionally, the naive solution of manually patching .load functions does not work when
inheritance is involved. For example::

    import caikit.core

    class ParentModule(caikit.core.ModuleBase):
        @classmethod
        def load(cls, module_path):
            return cls()

    class ChildModule(ParentModule):
        @classmethod
        def load(cls, module_path):
            return super().load()

    # This is fine!
    assert isinstance(ChildModule.load(), ChildModule)

    def injector(load_fn):
        def injected_load(*args):
            module = load_fn(*args)
            module.metadata = {"stuff"}
            return module
        return classmethod(injected_load)

    for clz in (ParentModule, ChildModule):
        # But this line binds the new load function directly to each class :(
        clz.load = injector(clz.load)

    # And this will now raise since ParentModule is returned
    assert isinstance(ChildModule.load(), ChildModule)

Instead of binding new metadata-injecting load functions directly to a class at import time, we
need to bind the new load function at construction time, when the class hierarchy with inheritance
is known.
"""
# Standard
from typing import TYPE_CHECKING, List
import abc
import functools
# First Party
import alog
# Local
from ..exceptions import error_handler
from .config import ModuleConfig
if TYPE_CHECKING:
# Local
from caikit.core import TaskBase
# Logging channel used by every log call in this module
log = alog.use_channel("METADATA_INJECT")
# Error handler obtained from `..exceptions.error_handler` for the channel above.
# NOTE(review): not referenced in the visible part of this module — confirm whether it is
# used elsewhere before removing.
error = error_handler.get(log)
[docs]
class _ModuleBaseMeta(abc.ABCMeta):
"""This is the metaclass used by `caikit.core.ModuleBase`.
This metaclass populates the `metadata` property of any module that is created by invoking
a `load` classmethod on a derived class.
"""
# pylint: disable=arguments-differ
def __new__(mcs, name, bases, attrs):
real_load = attrs.get("load")
if real_load is not None:
log.debug3("Wrapping a load function on class %s", mcs)
@alog.logged_function(log.trace)
def metadata_injecting_load(clz, *args, **kwargs):
"""This function is the replacement for the module's original `.load`"""
path = None
module_config = None
# Many load functions rename what the `path` argument is.
# Usually, it's just the first positional argument.
# But, `load` may be called with kwargs-only, like:
# MyModelClass.load(model_path="/some/path")
log.debug3(
"Attempting to find model path from load args: [%s] [%s]",
args,
kwargs,
)
if len(args) > 0:
log.debug3(
"Using first positional argument to load as model path: %s",
args[0],
)
path = args[0]
else:
for kw, arg in kwargs.items():
if "path" in kw:
log.debug3(
"Using named keyword argument to load as model path: [%s: %s]",
kw,
arg,
)
path = arg
break
try:
if path:
module_config = ModuleConfig.load(path)
except FileNotFoundError:
log.error(
"Could not load module metadata while loading with %s %s",
args,
kwargs,
)
# Call the original .load function to load the module
module = real_load.__func__(clz, *args, **kwargs)
# defer any "is this really a module" logic until after the load call
if hasattr(module, "metadata") and module_config:
module.metadata.update(module_config)
return module
# Wrap the load function so that the final method appears the same
# as the original
metadata_injecting_load = functools.wraps(real_load.__func__)(
metadata_injecting_load
)
attrs["load"] = classmethod(metadata_injecting_load)
return super().__new__(mcs, name, bases, attrs)
@property
def tasks(cls) -> List["TaskBase"]:
return [task for task in cls._TASK_CLASSES]
[docs]
def __setattr__(cls, name, val):
"""Overwrite __setattr__ to warn on any dynamic updates to the load function.
We'd rather not lose all the work we did to wrap `.load` with metadata injection
in the constructor!
"""
if name == "load":
# NB: warn instead of throw because some libraries will mock out .load during
# unit testing where it's easier than trying to build a quick-loading dummy model
log.warning("Overwriting load on a module will break metadata persistence!")
return super().__setattr__(name, val)