# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Standard
from typing import Callable, Dict, Optional, Union
import traceback
# Third Party
from grpc._utilities import RpcMethodHandler
from prometheus_client import Gauge
import grpc
# First Party
import alog
# Local
from caikit.core.data_model import DataBase
from caikit.runtime.names import (
ACK_HEADER_STRING,
ServiceType,
get_service_package_name,
)
from caikit.runtime.service_factory import ServicePackage
from caikit.runtime.service_generation.rpcs import CaikitRPCBase
from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException
# Module-wide logging channel used by every log call in this file
log = alog.use_channel("SERVER-WRAPR")

# Prometheus gauge labelled by "rpc_name"; it is incremented when a wrapped
# RPC starts and decremented when that RPC finishes (successfully or not).
IN_PROGRESS_GAUGE = Gauge(
    name="rpc_in_progress_gauge",
    documentation="Total number of in-flight requests to caikit-runtime",
    labelnames=["rpc_name"],
)
class CaikitRuntimeServerWrapper(grpc.Server):
    """This class wraps an underlying gRPC server for the purpose of
    intercepting the binding of servicers (e.g., the CaikitRuntimeServicer) to
    the server so that the RPC handlers that are registered to the server
    can optionally be replaced with a generic global predict RPC handler
    instead.
    """

    def __init__(
        self,
        server,
        rpc_callable: Union[Callable, Dict[CaikitRPCBase, Callable]],
        intercepted_svc_package: ServicePackage,
        service_type: ServiceType,
    ):
        """Initialize a new CaikitRuntimeServerWrapper

        Args:
            server(grpc.Server): The server that is being wrapped
            rpc_callable(Union[Callable, Dict[CaikitRPCBase, Callable]]): Either a function that
                will accept an arbitrary gRPC request message and a grpc.ServicerContext, and
                return a suitable gRPC response message or a mapping of CaikitRPC's to functions
                that have the same limitation. When a mapping is given it is looked up by
                ``type(caikit_rpc)``, i.e. it is keyed by CaikitRPCBase subclasses.
            intercepted_svc_package(ServicePackage): The service package whose descriptor names
                the service and methods to intercept, and whose ``caikit_rpcs`` maps each
                method name to its Caikit RPC description.
            service_type(ServiceType): The service type used to resolve the package name when
                building the request/response (de)serializers for re-routed RPCs.
        """
        self._server = server
        self._service_type = service_type
        self._rpc_callable = rpc_callable
        self._intercepted_svc_package = intercepted_svc_package
        self._intercepted_methods = []

        for method in self._intercepted_svc_package.descriptor.methods:
            # Take the method short name (e.g., 'SyntaxIzumoPredict') and
            # concatenate it with the intercepted service name to produce
            # a fully qualified RPC method name that we wish to intercept
            # (e.g., '/natural_language_understanding.CaikitRuntime/SyntaxIzumoPredict')
            fqm = f"/{self._intercepted_svc_package.descriptor.full_name}/{method.name}"
            log.info("<RUN81194024I>", "Intercepting RPC method %s", fqm)
            self._intercepted_methods.append((method.name, fqm))

    # **************************************************************************
    # Custom methods
    # **************************************************************************

    def intercepted_service(self):
        """Get the fully-qualified name of the intercepted service

        Returns:
            string:
                The fully-qualified name of the service whose RPC handlers are
                intercepted by this server wrapper
        """
        return self._intercepted_svc_package.descriptor.full_name

    def intercepted_methods(self):
        """Get the list of intercepted predict RPC methods

        Returns:
            list((string, string)):
                A list of two-element tuples containing the short name (e.g.,
                'SyntaxIzumoPredict') and fully-qualified name (e.g.,
                '/natural_language_understanding.CaikitRuntime/SyntaxIzumoPredict')
                of every RPC method intercepted by this server wrapper
        """
        return self._intercepted_methods

    @staticmethod
    def safe_rpc_wrapper(rpc: Callable, caikit_rpc: Optional[CaikitRPCBase] = None):
        """This wrapper should be used to safely invoke an RPC. If used, it adds automatic error
        handling and conversion to the appropriate response for gRPC, as well as logging indicating
        if the error was intentional (i.e., thrown as CaikitRuntimeException directly) or
        unexpected (i.e., thrown as a non GRPC error).

        Args:
            rpc(Function): Method attached to a servicer instance to be invoked in a safe manner.
            caikit_rpc(Optional[CaikitRPCBase]): When given, it is forwarded to ``rpc`` as the
                ``caikit_rpc`` keyword argument, and an acknowledgement header is sent first if
                the RPC is bi-directionally streaming.

        Returns:
            A function that takes a gRPC request message and ServicerContext, and safely invokes
            the provided RPC.

        Raises:
            CaikitRuntimeException: If ``rpc`` is None.
        """
        if rpc is None:
            message = "Programming error, RPC is None!"
            log.error("<RUN33322123E>", message)
            raise CaikitRuntimeException(grpc.StatusCode.INTERNAL, message)

        # Already wrapped: do not wrap a second time
        if rpc.__name__ == "safe_rpc_call":
            return rpc

        log.info(
            "<RUN33333123I>",
            "Wrapping safe rpc for %s",
            rpc.__name__,
        )

        def safe_rpc_call(request, context):
            """This function should be used to safely invoke an RPC. If used, it adds automatic
            error handling and conversion to the appropriate response for gRPC, as well as logging
            indicating if the error was intentional (i.e., thrown as CaikitRuntimeException
            directly) or unexpected (i.e., thrown as a non GRPC error).

            Args:
                request(message.Message): gRPC request object normally received by the rpc.
                context(grpc.ServicerContext): gRPC context object normally received by the rpc.

            Returns:
                gRPC response object return by the invoked RPC or None (aborted context).
            """
            with alog.ContextLog(log.debug, "[Safe RPC]: %s", rpc.__name__):
                try:
                    IN_PROGRESS_GAUGE.labels(rpc_name=rpc.__name__).inc()
                    if caikit_rpc:
                        # Enable sending acknowledgement for bi-directional streaming cases
                        # Note: we are not enabling it for every rpc, since it may create confusion
                        # on client side
                        if getattr(caikit_rpc, "_input_streaming", False) and getattr(
                            caikit_rpc, "_output_streaming", False
                        ):
                            # Send an acknowledgement in metadata
                            context.send_initial_metadata(((ACK_HEADER_STRING, "ok"),))

                        # Pass through the CaikitRPCBase rpc description to the global handlers
                        return rpc(request, context, caikit_rpc=caikit_rpc)

                    return rpc(request, context)

                except CaikitRuntimeException as e:
                    # Intentional error: surface its own status code and message
                    log_dict = {"log_code": "<RUN89011375W>", "message": e.message}
                    log.warning({**log_dict, **e.metadata})
                    context.abort(e.status_code, e.message)

                # Any other exception type is an unexpected error. Catching only a
                # single type (e.g. ValueError) would let every other failure skip
                # the logging and the UNKNOWN status conversion below.
                except Exception as e:  # pylint: disable=broad-exception-caught
                    message = repr(e)
                    log.error("<RUN33333333E>", message)
                    log.error("<RUN33333334E>", str(traceback.format_exc()))
                    context.abort(grpc.StatusCode.UNKNOWN, message)

                finally:
                    # Always release the in-flight count, whatever the outcome
                    IN_PROGRESS_GAUGE.labels(rpc_name=rpc.__name__).dec()

        return safe_rpc_call

    # **************************************************************************
    # Overridden grpc.Server methods
    # **************************************************************************

    def add_generic_rpc_handlers(self, generic_rpc_handlers):
        """Registers GenericRpcHandlers with this Server.

        This method will intercept the generic_rpc_handlers: handlers for the
        intercepted service are rebuilt around the configured rpc callable(s),
        while handlers for any other service are kept but wrapped in a safe RPC
        call. Each handler is registered with the wrapped server exactly once.

        Args:
            generic_rpc_handlers: An iterable of GenericRpcHandlers that will be
                used to service RPCs.

        Raises:
            ValueError: If an intercepted method has no matching Caikit RPC.
        """

        class DummyHandlerCallDetails(grpc.HandlerCallDetails):
            """Dummy class for constructing a grpc.HandlerCallDetails object"""

            def __init__(self, method):
                super().__init__()
                self.method = method

        # Iterate over each grpc.ServiceRpcHandler...
        for handler in generic_rpc_handlers:
            # ...and check if this is the service we wish to intercept
            if handler.service_name() == self.intercepted_service():
                # This is the service whose RPC handlers we wish to intercept
                # and re-route. We now need to iterate over each method that
                # we wish to re-route, get the original RPC handler for that
                # method (see
                # caikit_runtime_pb2_grpc.add_CaikitRuntimeServicer_to_server
                # for a dict of the rpc_method_handlers we wish to re-route)
                rerouted_rpc_method_handlers = {}
                for method, fqm in self.intercepted_methods():
                    # Get the original grpc.RpcMethodHandler for this RPC method
                    original_rpc_handler = handler.service(DummyHandlerCallDetails(fqm))

                    # Find the Caikit RPC that maps to this rpc
                    caikit_rpc = self._intercepted_svc_package.caikit_rpcs.get(
                        method, None
                    )
                    if not caikit_rpc:
                        raise ValueError(f"No Caikit RPC Found for method: {method}")

                    # Now, swap out the original unary-unary callable with our
                    # generic predict method, and add this newly re-routed RPC
                    # method handler to the dict of (method, handler) pairs
                    safe_rpc_handler = self._make_new_handler(
                        original_rpc_handler, caikit_rpc
                    )
                    rerouted_rpc_method_handlers[method] = safe_rpc_handler
                    log.info(
                        "<RUN30032825I>",
                        "Re-routing RPC %s from %s to %s",
                        fqm,
                        self._get_handler_fn(original_rpc_handler),
                        self._get_handler_fn(rerouted_rpc_method_handlers[method]),
                    )

                # Now that we have re-rerouted all the original RPC method
                # handlers to the global predict RPC method handler, it is time
                # to bind them to the underlying server that we are wrapping
                generic_handler = grpc.method_handlers_generic_handler(
                    self.intercepted_service(), rerouted_rpc_method_handlers
                )
                self._server.add_generic_rpc_handlers((generic_handler,))
                log.info(
                    "<RUN24924908I>",
                    "Interception of service %s complete",
                    self.intercepted_service(),
                )
            else:
                # This is not the service whose RPC handlers we wish to
                # intercept, so keep its own RPC handlers and pass them along
                # to the underlying gRPC server we are wrapping
                assert isinstance(handler, grpc._utilities.DictionaryGenericHandler)
                for method in handler._method_handlers:
                    # Wrap the RPC handler for this method in a safe RPC call,
                    # but do not replace the handler with a global handler
                    original_rpc_handler = handler._method_handlers[method]
                    safe_rpc_handler = self._make_new_handler(original_rpc_handler)
                    handler._method_handlers[method] = safe_rpc_handler

                # Register only this handler. Forwarding the whole input
                # iterable here would register every handler once per
                # non-intercepted entry, could bind the un-rerouted intercepted
                # handler ahead of the re-routed one, and would break on
                # one-shot iterables that the loop has already consumed.
                self._server.add_generic_rpc_handlers((handler,))

    def _make_new_handler(
        self,
        original_rpc_handler: RpcMethodHandler,
        caikit_rpc: Optional[CaikitRPCBase] = None,
    ):
        """Build a replacement RpcMethodHandler whose behavior is wrapped by safe_rpc_wrapper

        Args:
            original_rpc_handler(RpcMethodHandler): The handler being replaced. Its streaming
                kind is preserved, and its (de)serializers are reused unless overridden below.
            caikit_rpc(Optional[CaikitRPCBase]): When given, the configured rpc callable is
                used as the behavior instead of the original handler function. If the rpc
                callable is a dict, the entry for ``type(caikit_rpc)`` is used and the
                (de)serializers are rebuilt from the RPC's input/output types.

        Returns:
            The new gRPC method handler of the same streaming kind as the original.

        Raises:
            ValueError: If the rpc callable is a dict with no entry for ``type(caikit_rpc)``.
        """
        request_deserializer = original_rpc_handler.request_deserializer
        response_serializer = original_rpc_handler.response_serializer

        if caikit_rpc:
            # If the rpc callable is a dict then it must be a mapping of CaikitRPCBase's to their
            # callables
            if isinstance(self._rpc_callable, dict):
                # If this rpc type wasn't in the callable map raise an error
                if type(caikit_rpc) not in self._rpc_callable:
                    raise ValueError(
                        f"Unknown rpc type {type(caikit_rpc)} passed to MultiFuncWrapper"
                    )

                rpc_func = self._rpc_callable[type(caikit_rpc)]
                behavior = self.safe_rpc_wrapper(rpc_func, caikit_rpc)

                # Fetch the input/output objects to determine the correct serializer
                package_name = get_service_package_name(self._service_type)
                rpc_json = caikit_rpc.create_rpc_json(package_name)
                input_class = DataBase.get_class_for_name(rpc_json["input_type"])
                output_class = DataBase.get_class_for_name(rpc_json["output_type"])

                request_deserializer = input_class.get_proto_class().FromString
                response_serializer = output_class.get_proto_class().SerializeToString
            else:
                behavior = self.safe_rpc_wrapper(self._rpc_callable, caikit_rpc)
        else:
            # No Caikit RPC: keep the original handler function, just made safe
            behavior = self.safe_rpc_wrapper(self._get_handler_fn(original_rpc_handler))

        # Pick the constructor matching the original handler's streaming kind
        if original_rpc_handler.unary_unary:
            handler_constructor = grpc.unary_unary_rpc_method_handler
        elif original_rpc_handler.unary_stream:
            handler_constructor = grpc.unary_stream_rpc_method_handler
        elif original_rpc_handler.stream_unary:
            handler_constructor = grpc.stream_unary_rpc_method_handler
        else:
            handler_constructor = grpc.stream_stream_rpc_method_handler

        return handler_constructor(
            behavior=behavior,
            request_deserializer=request_deserializer,
            response_serializer=response_serializer,
        )

    # **************************************************************************
    # Pass-through (i.e., unchanged) grpc.Server methods
    # **************************************************************************

    def add_insecure_port(self, address):
        """Opens an insecure port for accepting RPCs.

        This method may only be called before starting the server.

        Args:
            address: The address for which to open a port.
                if the port is 0, or not specified in the address, then gRPC runtime
                will choose a port.

        Returns:
            integer:
                An integer port on which server will accept RPC requests.
        """
        return self._server.add_insecure_port(address)

    def add_secure_port(self, address, server_credentials):
        """Opens a secure port for accepting RPCs.

        This method may only be called before starting the server.

        Args:
            address: The address for which to open a port.
                if the port is 0, or not specified in the address, then gRPC
                runtime will choose a port.
            server_credentials: A ServerCredentials object.

        Returns:
            integer:
                An integer port on which server will accept RPC requests.
        """
        return self._server.add_secure_port(address, server_credentials)

    def start(self):
        """Starts this Server.

        This method may only be called once. (i.e. it is not idempotent).
        """
        self._server.start()

    def stop(self, grace):
        """Stops this Server.

        This method immediately stop service of new RPCs in all cases.
        If a grace period is specified, this method returns immediately
        and all RPCs active at the end of the grace period are aborted.
        If a grace period is not specified (by passing None for `grace`),
        all existing RPCs are aborted immediately and this method
        blocks until the last RPC handler terminates.

        This method is idempotent and may be called at any time.
        Passing a smaller grace value in a subsequent call will have
        the effect of stopping the Server sooner (passing None will
        have the effect of stopping the server immediately). Passing
        a larger grace value in a subsequent call *will not* have the
        effect of stopping the server later (i.e. the most restrictive
        grace value is used).

        Args:
            grace: A duration of time in seconds or None.

        Returns:
            A threading.Event that will be set when this Server has completely
            stopped, i.e. when running RPCs either complete or are aborted and
            all handlers have terminated.
        """
        return self._server.stop(grace)

    def wait_for_termination(self, timeout=None):
        """Block current thread until the server stops.

        The wait will not consume computational resources during blocking,
        and it will block until one of the two following conditions are met:

        1. The server is stopped or terminated;
        2. A timeout occurs if timeout is not None.

        The timeout argument works in the same way as threading.Event.wait().

        Args:
            timeout: A floating point number specifying a timeout for the operation in seconds.

        Returns:
            A bool indicates if the operation times out.
        """
        return self._server.wait_for_termination(timeout)

    @staticmethod
    def _get_handler_fn(handler: RpcMethodHandler) -> Callable:
        """Return the first populated behavior function of a method handler

        The unary_unary, unary_stream and stream_unary attributes are checked in
        that order; stream_stream is returned when none of them is set.
        """
        if handler.unary_unary:
            return handler.unary_unary
        if handler.unary_stream:
            return handler.unary_stream
        if handler.stream_unary:
            return handler.stream_unary
        return handler.stream_stream