# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Standard
from contextlib import contextmanager
from importlib.metadata import version
from typing import Any, Dict, Iterable, Optional, Set, Union
import itertools
import traceback
# Third Party
from google.protobuf.descriptor import FieldDescriptor
from google.protobuf.message import Message as ProtobufMessage
from grpc import RpcError, ServicerContext, StatusCode
from prometheus_client import Counter, Summary
# First Party
import alog
# Local
from caikit import get_config
from caikit.core import MODEL_MANAGER, ModuleBase, TaskBase
from caikit.core.data_model import DataBase, DataStream
from caikit.core.exceptions.caikit_core_exception import CaikitCoreException
from caikit.core.signature_parsing import CaikitMethodSignature
from caikit.interfaces.runtime.data_model import PredictionJob, RuntimeServerContextType
from caikit.runtime import trace
from caikit.runtime.metrics.rpc_meter import RPCMeter
from caikit.runtime.model_management.model_manager import ModelManager
from caikit.runtime.names import MODEL_MESH_MODEL_ID_KEY
from caikit.runtime.service_factory import ServicePackage
from caikit.runtime.service_generation.rpcs import TaskPredictionJobRPC, TaskPredictRPC
from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException
from caikit.runtime.utils.import_util import clean_lib_names
from caikit.runtime.utils.servicer_util import (
build_caikit_library_request_dict,
build_proto_response,
build_proto_stream,
get_metadata,
raise_caikit_runtime_exception,
validate_data_model,
)
from caikit.runtime.work_management.abortable_context import (
AbortableContext,
ThreadInterrupter,
)
from caikit.runtime.work_management.rpc_aborter import RpcAborter
# Prometheus metrics ###########################################################

# Number of predict RPCs handled, labelled by request type, gRPC status code
# name and model
PREDICT_RPC_COUNTER = Counter(
    name="predict_rpc_count",
    documentation="Count of global predict-managed RPC calls",
    labelnames=["grpc_request", "code", "model_id"],
)

# Number of prediction jobs started, labelled like PREDICT_RPC_COUNTER
JOB_PREDICT_RPC_COUNTER = Counter(
    name="predict_job_rpc_count",
    documentation="Count of global predict-managed RPC jobs started",
    labelnames=["grpc_request", "code", "model_id"],
)

# Time spent converting the inbound request into inference arguments
PREDICT_FROM_PROTO_SUMMARY = Summary(
    name="predict_from_proto_duration_seconds",
    documentation="Histogram of predict request unmarshalling duration (in seconds)",
    labelnames=["grpc_request", "model_id"],
)

# Time spent inside the model's inference function
PREDICT_CAIKIT_LIBRARY_SUMMARY = Summary(
    name="predict_caikit_library_duration_seconds",
    documentation="Histogram of predict Caikit Library run duration (in seconds)",
    labelnames=["grpc_request", "model_id"],
)

# Time spent converting the inference result into the response message(s)
PREDICT_TO_PROTO_SUMMARY = Summary(
    name="predict_to_proto_duration_seconds",
    documentation="Histogram of predict response marshalling duration (in seconds)",
    labelnames=["grpc_request", "model_id"],
)

# Logging ######################################################################

log = alog.use_channel("GP-SERVICR-I")

# Protobuf field types that are not primitives (messages and enums)
# Ref: https://developers.google.com/protocol-buffers/docs/reference/cpp/google.protobuf.descriptor
NON_PRIMITIVE_TYPES = [FieldDescriptor.TYPE_MESSAGE, FieldDescriptor.TYPE_ENUM]
## Servicer ####################################################################
class GlobalPredictServicer:
    """Servicer for the global (task-based) inference RPCs of the Caikit Runtime.

    These RPCs are not part of the Model Runtime proto definition. Each
    request is dispatched, based on its message type and on the model ID read
    from the request metadata, to the inference method of a model obtained
    from the runtime's model manager.
    """

    # Input size in code points, provided by orchestrator
    INPUT_SIZE_KEY = "input-length"

    # Attribute set on a CaikitRuntimeException once it has been logged and
    # counted, so that nested `_handle_predict_exceptions` contexts (e.g.
    # Predict -> predict_model) do not log/count the same failure twice
    _ERROR_RECORDED_ATTR = "_caikit_predict_error_recorded"

    def __init__(
        self,
        inference_service: ServicePackage,
        interrupter: Optional[ThreadInterrupter] = None,
    ):
        """Set up the servicer.

        Args:
            inference_service (ServicePackage):
                The generated inference service whose descriptor is validated
                and whose RPCs determine which model tasks are supported
            interrupter (Optional[ThreadInterrupter]):
                If given, inference calls are run inside an abortable context
                using this interrupter
        """
        self._model_manager = ModelManager.get_instance()

        # `rpc_meter` is always defined (None when metering is disabled) so
        # that later metering updates can never raise an AttributeError
        self.rpc_meter = None
        self._started_metering = False
        if get_config().runtime.metering.enabled:
            self.rpc_meter = RPCMeter()
            self._started_metering = True
            log.info(
                "<RUN76773775I>",
                "Metering is enabled, to disable set `metering.enabled` in config to false",
            )
        else:
            log.info(
                "<RUN76773776I>",
                "Metering is disabled, to enable set `metering.enabled` in config to true",
            )

        self._interrupter = interrupter
        self._inference_service = inference_service

        # Validate that the Caikit Library CDM is compatible with our service descriptor
        validate_data_model(self._inference_service.descriptor)
        log.info("<RUN76773778I>", "Validated Caikit Library CDM successfully")

        # Duplicate code in global_train_servicer
        # pylint: disable=duplicate-code
        library = clean_lib_names(get_config().runtime.library)[0]
        try:
            lib_version = version(library)
        except Exception:  # pylint: disable=broad-exception-caught
            # The version is only used for the log line below, so any lookup
            # failure is tolerated
            lib_version = "unknown"

        # Set up shared tracer
        self._tracer = trace.get_tracer(__name__)

        log.info(
            "<RUN76884779I>",
            "Constructed inference service for library: %s, version: %s",
            library,
            lib_version,
        )

    def Predict(
        self,
        request: Union[ProtobufMessage, Iterable[ProtobufMessage]],
        context: ServicerContext,
        caikit_rpc: TaskPredictRPC,
        *_,
        **__,
    ) -> Union[ProtobufMessage, Iterable[ProtobufMessage]]:
        """Global predict RPC -- run inference on a loaded model.

        Args:
            request (Union[ProtobufMessage, Iterable[ProtobufMessage]]):
                A deserialized RPC request message, or a stream of them when
                the RPC is input-streaming
            context (ServicerContext):
                Context object (contains request metadata, etc)
            caikit_rpc (TaskPredictRPC):
                The RPC being served; provides the task and streaming flavor
            *_, **__: Extra arguments are accepted and ignored

        Returns:
            response (Union[ProtobufMessage, Iterable[ProtobufMessage]]):
                The response message (unary) or stream of messages
                (output-streaming RPC)

        Raises:
            CaikitRuntimeException: For any failure; see
                `_handle_predict_exceptions` for the status code mapping
        """
        # Make sure the request has a model before doing anything
        model_id = get_metadata(context, MODEL_MESH_MODEL_ID_KEY)
        request_name = caikit_rpc.request.name

        with self._handle_predict_exceptions(model_id, request_name), alog.ContextLog(
            log.debug, "GlobalPredictServicer.Predict:%s", request_name
        ):
            (
                model,
                inference_signature,
                caikit_library_request,
            ) = self._prepare_inference(
                model_id, request_name, request, context, caikit_rpc
            )

            response = self.predict_model(
                request_name,
                model_id,
                input_streaming=caikit_rpc.input_streaming,
                output_streaming=caikit_rpc.output_streaming,
                task=caikit_rpc.task,
                aborter=RpcAborter(context) if self._interrupter else None,
                context=context,
                context_arg=inference_signature.context_arg,
                model=model,
                **caikit_library_request,
            )

            # Marshall the response to the necessary return type
            with PREDICT_TO_PROTO_SUMMARY.labels(
                grpc_request=request_name, model_id=model_id
            ).time():
                if caikit_rpc.output_streaming:
                    response_proto = build_proto_stream(response, context)
                else:
                    response_proto = build_proto_response(response)
            return response_proto

    def StartPredictionJob(
        self,
        request: Union[ProtobufMessage, Iterable[ProtobufMessage]],
        context: ServicerContext,
        caikit_rpc: TaskPredictionJobRPC,
        *_,
        **__,
    ) -> ProtobufMessage:
        """StartPredictionJob -- start a background prediction job on a loaded
        model via ModelManager.start_prediction_job().

        Args:
            request (Union[ProtobufMessage, Iterable[ProtobufMessage]]):
                A deserialized RPC request message
            context (ServicerContext):
                Context object (contains request metadata, etc)
            caikit_rpc (TaskPredictionJobRPC):
                The RPC being served; provides the task used to look up the
                model's inference method
            *_, **__: Extra arguments are accepted and ignored

        Returns:
            response (ProtobufMessage):
                The protobuf form of the PredictionJob describing the started
                job

        Raises:
            CaikitRuntimeException: For any failure; see
                `_handle_predict_exceptions` for the status code mapping
        """
        # Make sure the request has a model before doing anything
        model_id = get_metadata(context, MODEL_MESH_MODEL_ID_KEY)
        request_name = caikit_rpc.request.name

        # Errors raised while fetching the model or unmarshalling the request
        # are converted the same way as in Predict
        with self._handle_predict_exceptions(model_id, request_name), alog.ContextLog(
            log.debug,
            "GlobalBackgroundPredictServicer.BackgroundPredict:%s",
            request_name,
        ):
            (
                model,
                inference_signature,
                caikit_library_request,
            ) = self._prepare_inference(
                model_id, request_name, request, context, caikit_rpc
            )

            response = self.run_prediction_job(
                request_name,
                model_id,
                task=caikit_rpc.task,
                prediction_func_name=inference_signature.method_name,
                context=context,
                context_arg=inference_signature.context_arg,
                model=model,
                wait=False,
                **caikit_library_request,
            )
            response_proto = build_proto_response(response)
            return response_proto

    def run_prediction_job(
        self,
        request_name: str,
        model_id: str,
        prediction_func_name: str = "run",
        task: Optional[TaskBase] = None,
        context: Optional[RuntimeServerContextType] = None,  # noqa: F821
        context_arg: Optional[str] = None,
        model: Optional[ModuleBase] = None,
        **kwargs,
    ) -> PredictionJob:
        """Start a prediction job against the given model using the raw
        arguments to the model's run function.

        Args:
            request_name (str):
                The name of the request message (used for metrics and tracing)
            model_id (str):
                The ID of the loaded model
            prediction_func_name (str):
                Name of the prediction function handed to the model manager
            task (Optional[TaskBase])
                The task to use for inference (if multitask model)
            context (Optional[RuntimeServerContextType]):
                The context object from the inbound request
            context_arg (Optional[str]):
                The arg name to the model inference method where the context
                should be passed. This value is replaced by the one from the
                model's unary inference signature.
            model (Optional[ModuleBase]):
                Pre-fetched model object; fetched by model_id when None
            **kwargs: Keyword arguments to pass to the model's run function

        Returns:
            PredictionJob: Holds the ID of the started prediction job

        Raises:
            CaikitRuntimeException: For any failure, including a model that
                has no unary inference method for the task
        """
        trace.set_tracer(context, self._tracer)
        trace_context = trace.get_trace_context(context)
        trace_span_name = f"{__name__}.GlobalPredictServicer.run_prediction_job"
        with self._handle_predict_exceptions(
            model_id, request_name
        ), self._tracer.start_as_current_span(
            trace_span_name,
            context=trace_context,
        ) as trace_span:
            # Set trace attributes available before checking anything
            trace_span.set_attribute("calling", trace_span_name)
            trace_span.set_attribute("model_id", model_id)
            trace_span.set_attribute("request_name", request_name)
            trace_span.set_attribute("task", getattr(task, "__name__", str(task)))

            if model is None:
                model = self._model_manager.retrieve_model(model_id)
            self._verify_model_task(model)

            inference_sig = self._require_inference_signature(
                model,
                input_streaming=False,
                output_streaming=False,
                task=task,
            )
            inference_func_name = inference_sig.method_name
            context_arg = inference_sig.context_arg
            log.debug2(
                "Deduced inference function name: %s and context_arg: %s",
                inference_func_name,
                context_arg,
            )
            trace_span.set_attribute("inference_func_name", inference_func_name)

            # If a context arg was supplied then add the context
            if context_arg:
                kwargs[context_arg] = context

            # NOTE(review): the job is started with the caller-supplied
            # `prediction_func_name`, while the span records the deduced
            # `inference_func_name` -- confirm which one is intended
            model_future = MODEL_MANAGER.start_prediction_job(
                model=model, prediction_func_name=prediction_func_name, **kwargs
            )

            # Update Prometheus metrics
            JOB_PREDICT_RPC_COUNTER.labels(
                grpc_request=request_name, code=StatusCode.OK.name, model_id=model_id
            ).inc()
            self._update_metering(model)
            return PredictionJob(prediction_id=model_future.id)

    def predict_model(
        self,
        request_name: str,
        model_id: str,
        inference_func_name: str = "run",
        input_streaming: Optional[bool] = None,
        output_streaming: Optional[bool] = None,
        task: Optional[TaskBase] = None,
        aborter: Optional[RpcAborter] = None,
        context: Optional[RuntimeServerContextType] = None,  # noqa: F821
        context_arg: Optional[str] = None,
        model: Optional[ModuleBase] = None,
        **kwargs,
    ) -> Union[DataBase, Iterable[DataBase]]:
        """Run a prediction against the given model using the raw arguments to
        the model's run function.

        Args:
            request_name (str):
                The name of the request message (used for metrics and tracing)
            model_id (str):
                The ID of the loaded model
            inference_func_name (str):
                Explicit name of the inference function to predict (ignored if
                input_streaming and output_streaming set)
            input_streaming (Optional[bool]):
                Use the task function with input streaming
            output_streaming (Optional[bool]):
                Use the task function with output streaming
            task (Optional[TaskBase])
                The task to use for inference (if multitask model)
            aborter (Optional[RpcAborter]):
                If using abortable calls, this is the aborter to use
            context (Optional[RuntimeServerContextType]):
                The context object from the inbound request
            context_arg (Optional[str]):
                The arg name to the model inference method where the context
                should be passed (ignored if input_streaming and
                output_streaming set)
            model (Optional[ModuleBase]):
                Pre-fetched model object; fetched by model_id when None
            **kwargs: Keyword arguments to pass to the model's run function

        Returns:
            response (Union[DataBase, Iterable[DataBase]]):
                The object (unary) or objects (output stream) produced by the
                inference request

        Raises:
            CaikitRuntimeException: For any failure, including a model that
                has no inference method for the requested streaming flavor
        """
        trace.set_tracer(context, self._tracer)
        trace_context = trace.get_trace_context(context)
        trace_span_name = f"{__name__}.GlobalPredictServicer.predict_model"
        with self._handle_predict_exceptions(
            model_id, request_name
        ), self._tracer.start_as_current_span(
            trace_span_name,
            context=trace_context,
        ) as trace_span:
            # Set trace attributes available before checking anything
            trace_span.set_attribute("calling", trace_span_name)
            trace_span.set_attribute("model_id", model_id)
            trace_span.set_attribute("request_name", request_name)
            trace_span.set_attribute("task", getattr(task, "__name__", str(task)))

            if model is None:
                model = self._model_manager.retrieve_model(model_id)
            self._verify_model_task(model)

            # When both streaming flags are given, the function name and
            # context arg come from the model's signature rather than from
            # the explicit arguments
            if input_streaming is not None and output_streaming is not None:
                inference_sig = self._require_inference_signature(
                    model,
                    input_streaming=input_streaming,
                    output_streaming=output_streaming,
                    task=task,
                )
                inference_func_name = inference_sig.method_name
                context_arg = inference_sig.context_arg
                log.debug2(
                    "Deduced inference function name: %s and context_arg: %s",
                    inference_func_name,
                    context_arg,
                )
            trace_span.set_attribute("inference_func_name", inference_func_name)

            # If a context arg was supplied then add the context
            if context_arg:
                kwargs[context_arg] = context

            model_run_fn = getattr(model, inference_func_name)
            # NB: we previously recorded the size of the request, and timed this module to
            # provide a rudimentary throughput metric of size / time
            # 🌶️🌶️🌶️ The `AbortableContext` will only abort if both `self._interrupter` and
            # `aborter` are set
            with alog.ContextLog(
                log.debug,
                "GlobalPredictServicer.Predict.caikit_library_run:%s",
                request_name,
            ), PREDICT_CAIKIT_LIBRARY_SUMMARY.labels(
                grpc_request=request_name, model_id=model_id
            ).time(), AbortableContext(
                aborter, self._interrupter
            ):
                response = model_run_fn(**kwargs)

            # Update Prometheus metrics
            PREDICT_RPC_COUNTER.labels(
                grpc_request=request_name, code=StatusCode.OK.name, model_id=model_id
            ).inc()
            self._update_metering(model)
            return response

    def stop_metering(self):
        """Flush the RPC meter and stop its writer thread (no-op when metering
        was never started or was already stopped)"""
        if self._started_metering:
            self.rpc_meter.flush_metrics()
            self.rpc_meter.end_writer_thread()
            self._started_metering = False

    def notify_backends_with_context(
        self,
        model_id: str,
        context: RuntimeServerContextType,
    ):
        """Utility to notify all configured backends of the request context"""
        for backend in MODEL_MANAGER.get_module_backends():
            log.debug3(
                "Notifying backend type %s of with context of type %s",
                type(backend),
                type(context),
            )
            backend.handle_runtime_context(model_id, context)

    ## Implementation Details ##################################################

    def _prepare_inference(
        self,
        model_id: str,
        request_name: str,
        request: Union[ProtobufMessage, Iterable[ProtobufMessage]],
        context: ServicerContext,
        caikit_rpc: TaskPredictRPC,
    ):
        """Fetch the model and unmarshal the request into inference kwargs.

        Returns:
            Tuple of (model, inference signature, kwargs dict for the model's
            inference function)

        Raises:
            CaikitRuntimeException: If the model class has no inference
                signature matching the RPC
        """
        # Before retrieving the model, which can trigger lazy backend
        # initialization, we notify all backends of the context for this
        # request which may update how the discovery logic works.
        self.notify_backends_with_context(model_id, context)

        # Retrieve the model from the model manager
        log.debug("<RUN52259129D>", "Retrieving model '%s'", model_id)
        model = self._model_manager.retrieve_model(model_id)
        model_class = type(model)

        # Little hackity hack: Calling _verify_model_task upfront here as well to
        # short-circuit requests where the model is _totally_ unsupported
        self._verify_model_task(model)

        # Unmarshall the request object into the required module run argument(s)
        with PREDICT_FROM_PROTO_SUMMARY.labels(
            grpc_request=request_name, model_id=model_id
        ).time():
            inference_signature = model_class.get_inference_signature(
                input_streaming=caikit_rpc.input_streaming,
                output_streaming=caikit_rpc.output_streaming,
                task=caikit_rpc.task,
            )
            if not inference_signature:
                raise CaikitRuntimeException(
                    StatusCode.INVALID_ARGUMENT,
                    f"Model class {model_class} does not support {caikit_rpc.name}",
                )
            if caikit_rpc.input_streaming:
                caikit_library_request = self._build_caikit_library_request_stream(
                    request, inference_signature, caikit_rpc
                )
            else:
                caikit_library_request = build_caikit_library_request_dict(
                    request,
                    inference_signature,
                )
        return model, inference_signature, caikit_library_request

    @staticmethod
    def _require_inference_signature(
        model: ModuleBase,
        input_streaming: bool,
        output_streaming: bool,
        task: Optional[TaskBase],
    ) -> CaikitMethodSignature:
        """Look up the model's inference signature, raising INVALID_ARGUMENT
        (rather than failing later on a missing attribute) when there is none"""
        signature = model.get_inference_signature(
            output_streaming=output_streaming,
            input_streaming=input_streaming,
            task=task,
        )
        if not signature:
            raise CaikitRuntimeException(
                StatusCode.INVALID_ARGUMENT,
                f"Model class {type(model)} has no inference method for task {task} "
                f"(input_streaming={input_streaming}, "
                f"output_streaming={output_streaming})",
            )
        return signature

    def _update_metering(self, model: ModuleBase):
        """Record the model type with the RPC meter when metering is enabled
        and a meter was created at construction time"""
        if get_config().runtime.metering.enabled and self.rpc_meter is not None:
            self.rpc_meter.update_metrics(str(type(model)))

    def _recorded_error(
        self,
        request_name: str,
        model_id: str,
        status_code: StatusCode,
        message: str,
    ) -> CaikitRuntimeException:
        """Count a failure in PREDICT_RPC_COUNTER and build the matching
        CaikitRuntimeException, marked so outer handlers don't count it again"""
        PREDICT_RPC_COUNTER.labels(
            grpc_request=request_name,
            # Tolerate a status code object without a `.name`
            code=getattr(status_code, "name", str(status_code)),
            model_id=model_id,
        ).inc()
        error = CaikitRuntimeException(status_code, message)
        setattr(error, self._ERROR_RECORDED_ATTR, True)
        return error

    @contextmanager
    def _handle_predict_exceptions(self, model_id: str, request_name: str):
        """Context manager converting any error raised in its body into a
        logged and counted CaikitRuntimeException.

        Mapping:
            CaikitRuntimeException -> re-raised unchanged
            CaikitCoreException    -> delegated to raise_caikit_runtime_exception
            TypeError / ValueError -> INVALID_ARGUMENT
            RpcError               -> the error's own code and details
            anything else          -> INTERNAL

        Each failure is logged and counted once even when these contexts are
        nested.
        """
        try:
            yield

        except CaikitRuntimeException as e:
            # An inner handler may already have logged and counted this error
            if not getattr(e, self._ERROR_RECORDED_ATTR, False):
                log_dict = {
                    "log_code": "<RUN50530380W>",
                    "message": e.message,
                    "model_id": model_id,
                    "error_id": e.id,
                }
                log.warning({**log_dict, **e.metadata})
                PREDICT_RPC_COUNTER.labels(
                    grpc_request=request_name,
                    code=e.status_code.name,
                    model_id=model_id,
                ).inc()
                setattr(e, self._ERROR_RECORDED_ATTR, True)
            raise

        # Duplicate code in global_train_servicer
        # pylint: disable=duplicate-code
        except CaikitCoreException as e:
            raise_caikit_runtime_exception(exception=e)

        except (TypeError, ValueError) as e:
            log_dict = {
                "log_code": "<RUN490439039W>",
                "message": repr(e),
                "model_id": model_id,
                "stack_trace": traceback.format_exc(),
            }
            log.warning(log_dict)
            raise self._recorded_error(
                request_name,
                model_id,
                StatusCode.INVALID_ARGUMENT,
                f"{e}",
            ) from e

        # NOTE: Specifically handling RpcError here is to pass through
        # grpc client errors, since we expect those clients to be common
        except RpcError as e:
            log_dict = {
                "log_code": "<RUN29029171W>",
                "message": repr(e),
                "model_id": model_id,
            }
            log.warning(log_dict)
            raise self._recorded_error(
                request_name,
                model_id,
                e.code(),
                e.details(),
            ) from e

        except Exception as e:
            log_dict = {
                "log_code": "<RUN49049070W>",
                "message": repr(e),
                "model_id": model_id,
                "stack_trace": traceback.format_exc(),
            }
            log.warning(log_dict)
            raise self._recorded_error(
                request_name,
                model_id,
                StatusCode.INTERNAL,
                f"{e}",
            ) from e

    def _verify_model_task(self, model: ModuleBase):
        """Raise if the model is not supported for the task

        The model is supported when at least one RPC of the inference service
        targets a task listed in the model class's `tasks`.

        Raises:
            CaikitRuntimeException: INVALID_ARGUMENT if no RPC matches
        """
        supported = any(
            rpc.task in model.__class__.tasks
            for rpc in self._inference_service.caikit_rpcs.values()
        )
        if not supported:
            raise CaikitRuntimeException(
                status_code=StatusCode.INVALID_ARGUMENT,
                message=f"Inference for model class {type(model)} not supported by this runtime",
            )

    def _build_caikit_library_request_stream(
        self,
        request_stream: Iterable[ProtobufMessage],
        module_signature: CaikitMethodSignature,
        caikit_rpc: TaskPredictRPC,
    ) -> Dict[str, Any]:
        """Builds the kwargs dict to pass to a caikit module.

        Specifically handles the case of constructing input `DataStreams` for some parameters
        which are meant to be streamed in.

        See caikit.runtime.build_caikit_library_request_dict

        Raises:
            CaikitRuntimeException: INVALID_ARGUMENT if the request stream
                yields no messages
        """

        def call_build_request_dict(request: ProtobufMessage) -> Dict[str, Any]:
            """This is instead of using a lambda to map each request in the stream"""
            return build_caikit_library_request_dict(request, module_signature)

        def build_getter_from_request_dict(param_name: str) -> Any:
            # This builder is required to correctly closure the `param_name` of the streaming
            # parameter that we're interested in
            def get_fn(request_dict):
                # Return this parameter out of the request dict
                return request_dict.get(param_name)

            return get_fn

        streaming_params = caikit_rpc.task.get_required_parameters(input_streaming=True)

        # We need n+1 streams because the first stream is peeked in order to read all the
        # non-streaming parameters off of the first message
        peek_stream, *param_streams = itertools.tee(
            request_stream, 1 + len(streaming_params)
        )

        # Read the non-streaming parameters off of the first message in the stream.
        # An empty stream is reported explicitly instead of leaking StopIteration.
        no_message = object()
        first_message = next(peek_stream, no_message)
        if first_message is no_message:
            raise CaikitRuntimeException(
                StatusCode.INVALID_ARGUMENT,
                "Input stream did not contain any request messages",
            )
        kwargs_dict = build_caikit_library_request_dict(first_message, module_signature)

        # For each "streaming" parameter, grab one of the tee'd streams and map it to return
        # a `DataStream` of that individual parameter
        for param, param_source in zip(streaming_params, param_streams):
            kwargs_dict[param] = (
                DataStream.from_iterable(param_source)
                .map(call_build_request_dict)
                .map(build_getter_from_request_dict(param_name=param))
            )
        return kwargs_dict