Source code for caikit.core.task

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Standard
from inspect import isclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypeVar, Union
import collections
import dataclasses
import typing

# First Party
from alog import alog

# Local
from .data_model import DataStream
from .data_model.base import DataBase
from .exceptions import error_handler
from .signature_parsing import CaikitMethodSignature

log = alog.use_channel("TASK_BASE")
error = error_handler.get(log)

ProtoableInputTypes = Type[Union[int, float, str, bytes, bool, DataBase]]
ValidInputTypes = Union[
    ProtoableInputTypes, List[ProtoableInputTypes], DataStream[ProtoableInputTypes]
]

_InferenceMethodBaseT = TypeVar("_InferenceMethodBaseT", bound=Callable)

_STREAM_OUT_ANNOTATION = "__streaming_output_type"
_STREAM_PARAMS_ANNOTATION = "__streaming_params"
_UNARY_OUT_ANNOTATION = "__unary_output_type"
_UNARY_PARAMS_ANNOTATION = "__unary_params"
_VISIBLE_ANNOTATION = "__visible"
_METADATA_ANNOTATION = "__metadata"


[docs] class TaskBase: """The TaskBase defines the interface for an abstract AI task An AI task is a logical function signature which, when implemented, performs a task in some AI domain. The key property of a task is that the set of required input argument types and the output value type are consistent across all implementations of the task. """
[docs] @dataclasses.dataclass class InferenceMethodPtr: """Little container class that holds a method name and its flavor of streaming. i.e. the args to a `@TaskClass.taskmethod` decoration. """ method_name: str # the simple name of a method, like "run" input_streaming: bool output_streaming: bool context_arg: Optional[ str ] # The name of the request context to pass if one is provided
deferred_method_decorators: Dict[ Type["TaskBase"], Dict[str, List["TaskBase.InferenceMethodPtr"]] ] = {}
[docs] @classmethod def taskmethod( cls, input_streaming: bool = False, output_streaming: bool = False, context_arg: Optional[str] = None, ) -> Callable[[_InferenceMethodBaseT], _InferenceMethodBaseT]: """Decorates a module instancemethod and indicates whether the inputs and outputs should be handled as streams. This will trigger validation that the signature of this method is compatible with the task's definition of input and output types. The actual handling of validating the method and registering it is deferred until after the module class is created, which happens outside the context of this decoration. """ def decorator(inference_method: _InferenceMethodBaseT) -> _InferenceMethodBaseT: cls.deferred_method_decorators.setdefault(cls, {}) fq_mod_name = ".".join( [ inference_method.__module__, *inference_method.__qualname__.split(".")[0:-1], ] ) cls.deferred_method_decorators[cls].setdefault(fq_mod_name, []) cls.deferred_method_decorators[cls][fq_mod_name].append( TaskBase.InferenceMethodPtr( method_name=inference_method.__name__, input_streaming=input_streaming, output_streaming=output_streaming, context_arg=context_arg, ) ) return inference_method return decorator
[docs] @classmethod def deferred_method_decoration(cls, module: Type): """Runs the actual decoration logic that `taskmethod` would have run if the module class existed during its lifetime. Validates that all decorated methods match the task's API expectations, and stores the signatures on the module class for access later. """ if cls.has_inference_method_decorators(module): keyname = _make_keyname_for_module(module) deferred_decorations = cls.deferred_method_decorators[cls][keyname] for decoration in deferred_decorations: signature = CaikitMethodSignature( module, decoration.method_name, decoration.context_arg ) cls.validate_run_signature( signature, decoration.input_streaming, decoration.output_streaming ) module._TASK_INFERENCE_SIGNATURES.setdefault(cls, []).append( (decoration.input_streaming, decoration.output_streaming, signature) )
[docs] @classmethod def has_inference_method_decorators(cls, module_class: Type) -> bool: """Utility that returns true iff a module has any `@TaskClass.taskmethod` decorations""" if cls not in cls.deferred_method_decorators: return False return ( _make_keyname_for_module(module_class) in cls.deferred_method_decorators[cls] )
[docs] @classmethod def validate_run_signature( cls, signature: CaikitMethodSignature, input_streaming: bool, output_streaming: bool, ) -> None: """Validates that the provided method signature meets the api constraints defined in this task, for the given streaming flavors. Raises: ValueError if no type annotations were provided on the method TypeError if the type annotations do not meet the task's api constraints """ if not signature.parameters: raise ValueError( "Task could not be validated, no .run parameters were provided" ) if signature.return_type is None: raise ValueError( "Task could not be validated, no .run return type was provided" ) missing_required_params = [ parameter_name for parameter_name in cls.get_required_parameters(input_streaming) if parameter_name not in signature.parameters ] if missing_required_params: raise TypeError( f"Required parameters {missing_required_params} not in signature for module: " f"{signature.module}" ) type_mismatch_errors = [] for parameter_name, parameter_type in cls.get_required_parameters( input_streaming ).items(): signature_type = signature.parameters[parameter_name] if parameter_type != signature_type: if typing.get_origin(signature_type) == typing.Union and ( # Either our parameter type is not a union & is part of the union signature parameter_type in typing.get_args(signature_type) # Or our parameter type is a union that's a subset of the union signature or ( typing.get_origin(parameter_type) == typing.Union and set(typing.get_args(parameter_type)).issubset( set(typing.get_args(signature_type)) ) ) ): continue if input_streaming and cls._is_iterable_type(parameter_type): streaming_type = typing.get_args(parameter_type)[0] for iterable_type in typing.get_args(signature_type): if not cls._subclass_check(iterable_type, streaming_type): raise TypeError( f"Wrong input type for {parameter_name}, expected {parameter_type} \ but got {signature_type}" ) else: type_mismatch_errors.append( f"Parameter {parameter_name} has type {signature_type} but type \ {parameter_type} is required" ) if type_mismatch_errors: raise TypeError( f"Wrong types provided for parameters to {signature.module}: {type_mismatch_errors}" ) cls._raise_on_wrong_output_type( signature.return_type, signature.module, output_streaming )
[docs] @classmethod def get_required_parameters( cls, input_streaming: bool ) -> Dict[str, Union[ValidInputTypes, Type[Iterable[ValidInputTypes]]]]: """Get the set of input types required by this task""" if not input_streaming: if _UNARY_PARAMS_ANNOTATION not in cls.__annotations__: raise ValueError("No unary inputs are specified for this task") return cls.__annotations__[_UNARY_PARAMS_ANNOTATION] if _STREAM_PARAMS_ANNOTATION not in cls.__annotations__: raise ValueError("No streaming inputs are specified for this task") return cls.__annotations__[_STREAM_PARAMS_ANNOTATION]
[docs] @classmethod def get_output_type(cls, output_streaming: bool) -> Type[DataBase]: """Get the output type for this task NOTE: This method is automatically configured by the @task decorator and should not be overwritten by child classes. """ if not output_streaming: if _UNARY_OUT_ANNOTATION not in cls.__annotations__: raise ValueError("No unary outputs are specified for this task") return cls.__annotations__[_UNARY_OUT_ANNOTATION] if _STREAM_OUT_ANNOTATION not in cls.__annotations__: raise ValueError("No streaming outputs are specified for this task") return cls.__annotations__[_STREAM_OUT_ANNOTATION]
[docs] @classmethod def get_visibility(cls) -> bool: """Get the visibility for this task. NOTE: defaults to True even if visibility wasn't provided""" return cls.__annotations__.get(_VISIBLE_ANNOTATION, True)
[docs] @classmethod def get_metadata(cls) -> Dict[str, Any]: """Get any metadata defined for this task NOTE: defaults to an empty dict if one wasn't provided""" return cls.__annotations__.get(_METADATA_ANNOTATION, {})
[docs] @classmethod def _raise_on_wrong_output_type(cls, output_type, module, output_streaming: bool): task_output_type = cls.get_output_type(output_streaming) if cls._subclass_check(output_type, task_output_type): # Basic case, same type or subclass of it return if typing.get_origin(output_type) == Union: for union_type in typing.get_args(output_type): if cls._subclass_check(union_type, task_output_type): # Something in the union has an acceptable type return # Do some streaming checks if output_streaming and cls._is_iterable_type(output_type): # task_output_type is already guaranteed to be Iterable[T] streaming_type = typing.get_args(task_output_type)[0] for iterable_type in typing.get_args(output_type): if cls._subclass_check(iterable_type, streaming_type): return raise TypeError( f"Wrong output type for module {module}: " f"Found {output_type} but expected {task_output_type}" )
[docs] @staticmethod def _subclass_check(this_type, that_type): """Wrapper around issubclass that first checks if both args are classes. Returns True if the types are the same, or they are both classes and this_type is a subclass of that_type """ if this_type == that_type: return True if isclass(this_type) and isclass(that_type): return issubclass(this_type, that_type) return False
[docs] @staticmethod def _is_iterable_type(typ: Type) -> bool: """Returns True if typ is an iterable type. Does not work for types like `list`, `tuple`, but we're interested here in `List[T]` etc. This is implemented this way to support older python versions where isinstance(typ, typing.Iterable) does not work """ try: iter(typ) return True except TypeError: return False
[docs] def task( unary_parameters: Dict[str, ValidInputTypes] = None, streaming_parameters: Dict[str, Type[Iterable[ValidInputTypes]]] = None, unary_output_type: Type[DataBase] = None, streaming_output_type: Type[Iterable[Type[DataBase]]] = None, visible: bool = True, metadata: Optional[Dict[str, Any]] = None, **kwargs, ) -> Callable[[Type[TaskBase]], Type[TaskBase]]: """The decorator for AI Task classes. This defines an output data model type for the task, and a minimal set of required inputs that all public models implementing this task must accept. As an example, the `caikit.interfaces.nlp.SentimentTask` might look like:: @task( unary_parameters={ "raw_document": caikit.interfaces.nlp.RawDocument }, streaming_parameters={ "raw_documents": Iterable[caikit.interfaces.nlp.RawDocument] } unary_output_type=caikit.interfaces.nlp.SentimentPrediction streaming_output_type=Iterable[caikit.interfaces.nlp.SentimentPrediction] ) class SentimentTask(caikit.TaskBase): pass and a module that implements this task might have methods like:: @module(id="b9d98408-84c2-488c-8385-9d698effe60b", task=SentimentTask) class MyModule(ModuleBase): @SentimentTask.taskmethod() def run(raw_document: caikit.interfaces.nlp.RawDocument, inference_mode: str = "fast") -> caikit.interfaces.nlp.SentimentPrediction: # impl @SentimentTask.taskmethod(input_streaming=True, output_streaming=True) def run_bidi_stream(raw_documents: DataStream[caikit.interfaces.nlp.RawDocument]) -> DataStream[caikit.interfaces.nlp.SentimentPrediction]: # impl Note the run function may include other arguments beyond the minimal required inputs for the task. Args: unary_parameters (Dict[str, ValidInputTypes]): The required parameters that all module's unary-input inference methods must contain. A dictionary of parameter name to parameter type, where the types can be in the set of: - Python primitives - Caikit data models - Iterable containers of the above - Caikit model references (maybe?) streaming_parameters: The same as unary_parameters, but for streaming-input inference methods. All types must be in the form `Iterable[T]` unary_output_type (Type[DataBase]): The unary output type of the task, which all modules' unary-output inference methods must return. This must be a caikit data model type. streaming_output_type (Type[Iterable[Type[DataBase]]]): The streaming output type of the task, which all modules' streaming-output inference methods must return. This must be in the form Iterable[T]. visible (bool): If this task should be exposed to the end user in documentation or if it should only be used internally metadata (Optional[Dict[str, Any]]): Any additional metadata that should be included in the documentation for this task Returns: A decorator function for the task class, registering it with caikit's core registry of tasks. """ def decorator(cls: Type[TaskBase]) -> Type[TaskBase]: error.subclass_check("<COR19436440E>", cls, TaskBase) # NB: python <= 3.9 safe way of setting class annotations cls_annotations = cls.__dict__.get("__annotations__", None) if cls_annotations is None: cls.__annotations__ = {} cls_annotations = cls.__dict__.get("__annotations__", None) if unary_parameters: cls_annotations[_UNARY_PARAMS_ANNOTATION] = unary_parameters if streaming_parameters: cls_annotations[_STREAM_PARAMS_ANNOTATION] = streaming_parameters if unary_output_type: cls_annotations[_UNARY_OUT_ANNOTATION] = unary_output_type if streaming_output_type: cls_annotations[_STREAM_OUT_ANNOTATION] = streaming_output_type cls_annotations[_VISIBLE_ANNOTATION] = visible cls_annotations[_METADATA_ANNOTATION] = metadata or {} # Backwards compatibility with old-style @tasks if "required_parameters" in kwargs and not unary_parameters: cls_annotations[_UNARY_PARAMS_ANNOTATION] = kwargs["required_parameters"] if "output_type" in kwargs and not unary_output_type: output_type = kwargs["output_type"] if cls._is_iterable_type(output_type): cls_annotations[_STREAM_OUT_ANNOTATION] = kwargs["output_type"] else: cls_annotations[_UNARY_OUT_ANNOTATION] = kwargs["output_type"] # End Backwards compatibility error.value_check( "<COR12671910E>", _UNARY_PARAMS_ANNOTATION in cls_annotations or _STREAM_PARAMS_ANNOTATION in cls_annotations, "At least one input type must be set on a task", ) error.value_check( "<COR12671910E>", _UNARY_OUT_ANNOTATION in cls_annotations or _STREAM_OUT_ANNOTATION in cls_annotations, "At least one output type must be set on a task", ) if _UNARY_OUT_ANNOTATION in cls_annotations: error.subclass_check( "<COR12766440E>", cls.get_output_type(output_streaming=False), DataBase ) if _STREAM_OUT_ANNOTATION in cls_annotations: if typing.get_origin(cls.get_output_type(output_streaming=True)) is None: raise TypeError( f"subclass check failed: {cls.get_output_type(output_streaming=True)} is \ not a subclass of (<class 'collections.abc.Iterable'>" ) error.subclass_check( "<COR12766440E>", typing.get_origin(cls.get_output_type(output_streaming=True)), collections.abc.Iterable, ) if _UNARY_PARAMS_ANNOTATION in cls_annotations: params_dict = cls.get_required_parameters(input_streaming=False) error.type_check("<COR19906440E>", dict, params_dict=params_dict) error.type_check_all( "<COR00123440E>", str, params_dict_keys=params_dict.keys() ) # TODO: check proto-ability of things if _STREAM_PARAMS_ANNOTATION in cls_annotations: params_dict = cls.get_required_parameters(input_streaming=True) error.type_check("<COR19556230E>", dict, params_dict=params_dict) error.type_check_all( "<COR58796465E>", str, params_dict_keys=params_dict.keys() ) for v in params_dict.values(): error.subclass_check( "<COR52740295E>", typing.get_origin(v), collections.abc.Iterable ) return cls return decorator
[docs] def _make_keyname_for_module(module_class: Type) -> str: return ".".join([module_class.__module__, module_class.__qualname__])