# Source code for caikit.runtime.client.remote_module_base

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The RemoteModuleBase is a base class that can be mutated to have the same task methods
as a ModuleBase but submit requests to a remote runtime instead of loading locally. By 
design this class/factory does not use any references to the original Module class.
"""
# Standard
from collections import OrderedDict
from datetime import datetime, timedelta
from threading import Lock
from typing import Any, Callable, Dict, Generator, List, Optional, Type, Union
import copy
import inspect
import json
import uuid

# Third Party
from requests import HTTPError, RequestException, Session
import grpc

# First Party
import alog

# Local
from caikit.core.data_model import DataBase, DataStream
from caikit.core.exceptions import error_handler
from caikit.core.modules import ModuleBase, module
from caikit.core.task import TaskBase
from caikit.interfaces.common.data_model import ConnectionInfo, Sequence
from caikit.runtime.client.remote_config import RemoteModuleConfig, RemoteRPCDescriptor
from caikit.runtime.client.utils import (
    construct_grpc_channel,
    construct_requests_session,
)
from caikit.runtime.names import (
    HTTP_TO_STATUS_CODE,
    MODEL_ID,
    OPTIONAL_INPUTS_KEY,
    REQUIRED_INPUTS_KEY,
    ServiceType,
    get_grpc_route_name,
    get_http_route_name,
)
from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException

# alog logging channel for this module, named "RMBASE"
log = alog.use_channel("RMBASE")
# Error handler bound to the logging channel above (used e.g. via error.value_check)
error = error_handler.get(log)


class RemoteModuleBase(ModuleBase):
    """Class to act as the base for remote modules. This class will be subclassed and
    mutated by construct_remote_module_class to make it have the same functions and
    parameters as the source module."""

    def __init__(
        self,
        connection_info: ConnectionInfo,
        protocol: str,
        model_key: str,
        model_name: str,
    ):
        """Store the settings needed to reach a model served by a remote runtime.

        No connection is opened here; the gRPC channel / HTTP session is built
        lazily the first time it is needed.

        Args:
            connection_info (ConnectionInfo): Hostname, port, TLS, timeout, retry,
                option and max_session_age settings for the remote runtime
            protocol (str): Transport to use, "grpc" or "http"
            model_key (str): gRPC metadata key under which the model name is sent
            model_name (str): Name of the model on the remote runtime

        Raises:
            The error raised by error.value_check when protocol is "grpc", TLS is
            enabled and tls.insecure_verify is set.
        """
        # Initialize module base
        super().__init__()
        self._model_name = model_name

        # Load connection parameters
        self._connection = connection_info
        self._tls = self._connection.tls
        self._protocol = protocol
        self._model_key = model_key

        # Channel/session cache state; the lock guards _conn_channel and
        # _current_conn_time
        self._channel_lock = Lock()
        self._conn_channel: Optional[Union[grpc.Channel, Session]] = None
        self._current_conn_time: Optional[datetime] = None
        self._max_conn_delta = timedelta(seconds=self._connection.max_session_age)

        # Assert parameter values
        if self._protocol == "grpc" and self._tls.enabled:
            error.value_check(
                "<COR74451567E>",
                not self._tls.insecure_verify,
                # The trailing space keeps the two sentences apart once the
                # implicitly-concatenated literals are joined
                "GRPC does not support insecure TLS connections. "
                "Please provide a valid CA certificate",
            )

    def __del__(self):
        """Destructor to ensure channel/session is cleaned up on deletion"""
        # __init__ may have raised before the lock was created; there is then
        # nothing to clean up, and raising from __del__ would only emit an
        # "Exception ignored" warning
        channel_lock = getattr(self, "_channel_lock", None)
        if channel_lock is None:
            return
        with channel_lock:
            conn_channel = getattr(self, "_conn_channel", None)
            if conn_channel:
                conn_channel.close()

    ### Method Factories

    @classmethod
    def generate_train_function(cls, method: RemoteRPCDescriptor) -> Callable:
        """Factory function to construct a train function that will then be set as an
        attribute.

        Args:
            method (RemoteRPCDescriptor): Descriptor of the remote train RPC

        Returns:
            Callable: function that submits a training request to the remote
                runtime and returns a new instance of ``cls`` pointing at the
                model name reported in the training response
        """

        def train_func(self, *args, **kwargs) -> method.signature.return_type:
            train_kwargs = {}
            if "_output_path" in kwargs:
                train_kwargs["output_path"] = kwargs.pop("_output_path")
            train_kwargs["model_name"] = kwargs.pop(
                "_model_name", f"{self._model_name}-{uuid.uuid4()}"
            )

            # 🌶️🌶️🌶️ This code marshals the train function arguments/kwargs into the
            # desired TrainParameters dataobject. Use signature parsing to ensure all
            # args are mapped to the correct name. Also use string replacement as
            # names.get_train_parameter_name requires a ref to the Module
            bound_args = method.signature.method_signature.bind(*args, **kwargs)

            # Only swap the trailing "Request" suffix so module names that contain
            # "Request" elsewhere are not corrupted
            request_dm_name = method.request_dm_name
            if request_dm_name.endswith("Request"):
                parameters_dm_name = request_dm_name[: -len("Request")] + "Parameters"
            else:
                parameters_dm_name = request_dm_name.replace("Request", "Parameters")
            train_parameter_class = DataBase.get_class_for_name(parameters_dm_name)
            train_kwargs["parameters"] = train_parameter_class(**bound_args.arguments)

            # Set return type to TrainType. NOTE: this mutates the shared descriptor
            method.response_dm_name = "TrainingJob"
            training_response = self.remote_method_request(
                method, ServiceType.TRAINING, **train_kwargs
            )
            return cls(
                self._connection,
                self._protocol,
                self._model_key,
                training_response.model_name,
            )

        # Override train function name attributes and signature
        train_func.__name__ = method.signature.method_name
        train_func.__qualname__ = method.signature.qualified_name
        train_func.__signature__ = method.signature.method_signature

        return train_func

    @classmethod
    def generate_inference_function(
        cls, task: Type[TaskBase], method: RemoteRPCDescriptor
    ) -> Callable:
        """Factory function to construct inference functions that will be set as an
        attribute.

        Args:
            task (Type[TaskBase]): Task the inference method belongs to
            method (RemoteRPCDescriptor): Descriptor of the remote inference RPC

        Returns:
            Callable: the inference function wrapped with
                ``task.taskmethod(method.input_streaming, method.output_streaming)``
        """

        def infer_func(self, *args, **kwargs) -> method.signature.return_type:
            return self.remote_method_request(
                method,
                ServiceType.INFERENCE,
                *args,
                **kwargs,
            )

        # Override infer function name attributes and signature
        infer_func.__name__ = method.signature.method_name
        infer_func.__qualname__ = method.signature.qualified_name
        infer_func.__signature__ = method.signature.method_signature

        # Wrap infer function with task method to ensure internal attributes are
        # properly set
        task_wrapped_infer_func = task.taskmethod(
            method.input_streaming, method.output_streaming
        )(infer_func)
        return task_wrapped_infer_func

    ### Remote Interface

    def remote_method_request(
        self, method: RemoteRPCDescriptor, service_type: ServiceType, *args, **kwargs
    ) -> Any:
        """Function to run a remote request based on the data stored in
        RemoteRPCDescriptor, dispatching on the protocol given at construction.

        Raises:
            NotImplementedError: if the protocol is neither "grpc" nor "http"
        """
        if self._protocol == "grpc":
            return self._request_via_grpc(method, service_type, *args, **kwargs)
        elif self._protocol == "http":
            return self._request_via_http(method, service_type, *args, **kwargs)

        raise NotImplementedError(f"Unknown protocol {self._protocol}")

    ### HTTP Helper Functions

    def _request_via_http(
        self,
        method: RemoteRPCDescriptor,
        service_type: ServiceType,
        *args,
        **kwargs,
    ) -> Any:
        """Send a request to the remote runtime over HTTP and parse the response.

        Args:
            method (RemoteRPCDescriptor): Descriptor of the RPC to call
            service_type (ServiceType): ServiceType.INFERENCE or ServiceType.TRAINING
            *args, **kwargs: Arguments used to build the request data model

        Returns:
            Any: the response data model, or a DataStream of response data models
                when the method streams its output

        Raises:
            CaikitRuntimeException: on connection errors, non-2xx HTTP statuses
                (mapped through HTTP_TO_STATUS_CODE), or errors while streaming
            NotImplementedError: if service_type is not INFERENCE or TRAINING
        """
        # Get request data model
        request_dm = DataBase.get_class_for_name(method.request_dm_name)(
            *args, **kwargs
        )
        # ! This is a hack to ensure all fields/types have been json encoded
        # (bytes/datetime/etc).
        request_dm_dict = json.loads(request_dm.to_json())
        # ! This is another hack to ensure all Union types match the oneOf generated
        # by pydantic
        request_dm_dict = self._rename_union_sequence_types(
            request_dm_dict, request_dm.__class__
        )

        # Parse generic Request type into HttpRequest format
        if service_type == ServiceType.INFERENCE:
            http_request_dict = {
                REQUIRED_INPUTS_KEY: {},
                OPTIONAL_INPUTS_KEY: {},
                MODEL_ID: self._model_name,
            }
            for param in method.signature.parameters:
                value = request_dm_dict.get(param)

                # If param doesn't have a default then add it to inputs
                if param not in method.signature.default_parameters:
                    http_request_dict[REQUIRED_INPUTS_KEY][param] = value

                # If the param differs from the default then add it to parameters
                elif value != method.signature.default_parameters.get(param):
                    http_request_dict[OPTIONAL_INPUTS_KEY][param] = value

            # If there is only one input then collapse down the value
            if len(http_request_dict[REQUIRED_INPUTS_KEY]) == 1:
                http_request_dict[REQUIRED_INPUTS_KEY] = list(
                    http_request_dict[REQUIRED_INPUTS_KEY].values()
                )[0]
        elif service_type == ServiceType.TRAINING:
            # Strip all null values. Only None is removed: legitimate falsy values
            # such as 0, False or "" must still reach the server.
            def _remove_null_values(_attr):
                if isinstance(_attr, dict):
                    return {
                        key: _remove_null_values(value)
                        for key, value in _attr.items()
                        if value is not None
                    }
                if isinstance(_attr, list):
                    return [
                        _remove_null_values(listitem)
                        for listitem in _attr
                        if listitem is not None
                    ]
                return _attr

            http_request_dict = _remove_null_values(request_dm_dict)
        else:
            raise NotImplementedError(f"Unsupported service type {service_type}")

        request_url = (
            f"{self._get_remote_target()}{get_http_route_name(method.rpc_name)}"
        )

        # Send request while capturing any errors and reporting them as
        # CaikitRuntimeExceptions
        try:
            response = self._http_session.post(
                request_url, json=http_request_dict, stream=method.output_streaming
            )
        except RequestException as err:
            raise CaikitRuntimeException(
                grpc.StatusCode.UNKNOWN, "Unknown exception while connecting to runtime"
            ) from err

        if response.status_code != 200:
            # Capture any HTTP errors and return them with the proper Caikit Status
            # mapping
            try:
                response.raise_for_status()
            except HTTPError as err:
                raise CaikitRuntimeException(
                    HTTP_TO_STATUS_CODE.get(
                        response.status_code, grpc.StatusCode.UNKNOWN
                    ),
                    f"Received status {response.status_code} from remote server: {response.text}",
                ) from err

        # Parse response data model either as file or json
        response_dm_class = DataBase.get_class_for_name(method.response_dm_name)

        if method.output_streaming:

            def stream_parser():
                """Helper Generator to parse SSE events"""
                try:
                    for raw_line in response.iter_lines():
                        # iter_lines() yields bytes unless decode_unicode is set, so
                        # decode before any text comparison. Event streams are UTF-8,
                        # which is used when requests cannot determine an encoding.
                        line = (
                            raw_line.decode(response.encoding or "utf-8")
                            if isinstance(raw_line, bytes)
                            else raw_line
                        )
                        # Only "data:" lines carry payloads; skip blank, event, id
                        # and comment lines
                        if not line.startswith("data:"):
                            continue
                        # Remove only the field name and the single optional space
                        # after it, so payloads that contain "data: " stay intact
                        payload = line[len("data:") :]
                        if payload.startswith(" "):
                            payload = payload[1:]
                        yield response_dm_class.from_json(payload)
                except RequestException as err:
                    raise CaikitRuntimeException(
                        grpc.StatusCode.UNKNOWN,
                        "Received unknown exception from remote server while streaming results",
                    ) from err

            # Attach reference of this response to the returned DataStream. This
            # ensures that requests stream won't get closed until after the
            # DataStream has been cleaned up.
            # NOTE(review): accessing response.content reads the entire body before
            # returning (and is what lets the DataStream be re-iterated); confirm
            # this buffering is intended for output-streaming methods.
            return_stream = DataStream(stream_parser)
            return_stream._source = response.content
            return return_stream

        # If the response_dm_class supports file operations then the HTTP server
        # would've returned with to_file instead of to_json. Thus for the client we
        # need to return from_file instead of from_json
        if response_dm_class.supports_file_operations:
            return response_dm_class.from_file(response.text)

        return response_dm_class.from_json(response.text)

    ### GRPC Helper Functions

    def _request_via_grpc(
        self,
        method: RemoteRPCDescriptor,
        service_type: ServiceType,
        *args,
        **kwargs,
    ) -> Any:
        """Helper function to send a grpc request.

        Args:
            method (RemoteRPCDescriptor): Descriptor of the RPC to call
            service_type (ServiceType): Service the RPC route belongs to
            *args, **kwargs: Arguments used to build the request data model(s)

        Returns:
            Any: the response data model, or a DataStream of response data models
                when the method streams its output

        Raises:
            CaikitRuntimeException: when the RPC fails with grpc.RpcError
        """
        # Get the request types
        request_dm_class = DataBase.get_class_for_name(method.request_dm_name)
        request_protobuf_class = request_dm_class.get_proto_class()

        # Get the response types
        response_dm_class = DataBase.get_class_for_name(method.response_dm_name)
        response_protobuf_class = response_dm_class.get_proto_class()

        # Get the RPC route
        grpc_route = get_grpc_route_name(service_type, method.rpc_name)

        # Construct the service_rpc and serializers; the channel factory to use
        # depends only on which side(s) of the RPC stream
        if method.input_streaming:
            rpc_factory_name = (
                "stream_stream" if method.output_streaming else "stream_unary"
            )
        else:
            rpc_factory_name = (
                "unary_stream" if method.output_streaming else "unary_unary"
            )
        service_rpc = getattr(self._grpc_channel, rpc_factory_name)(
            grpc_route,
            request_serializer=request_protobuf_class.SerializeToString,
            response_deserializer=response_protobuf_class.FromString,
        )

        # Construct request object
        if method.input_streaming:
            # Bind the args and kwargs to the signature for parsing. Use None for
            # the self argument
            bound_args = method.signature.method_signature.bind(None, *args, **kwargs)
            bound_args.arguments.pop("self")

            # Gather all iterable parameters as these should be streamed
            streaming_kwargs = OrderedDict()
            for name in self._get_streaming_arguments(**bound_args.arguments):
                streaming_kwargs[name] = bound_args.arguments.pop(name)

            def input_stream_parser():
                """Helper function to iterate over a datastream and stream requests"""
                for stream_tuple in DataStream.zip(*streaming_kwargs.values()):
                    stream_arguments = copy.deepcopy(bound_args)
                    for streaming_key, sub_value in zip(
                        streaming_kwargs.keys(), stream_tuple
                    ):
                        stream_arguments.arguments[streaming_key] = sub_value

                    yield request_dm_class(
                        *stream_arguments.args, **stream_arguments.kwargs
                    ).to_proto()

            grpc_request = input_stream_parser()
        else:
            # If not streaming then construct a simple request
            grpc_request = request_dm_class(*args, **kwargs).to_proto()

        request_kwargs = {
            "metadata": [(self._model_key, self._model_name)],
        }
        if self._connection.timeout:
            request_kwargs["timeout"] = self._connection.timeout

        # Send RPC request with or without streaming
        if method.output_streaming:

            def output_stream_parser():
                """Helper function to stream result objects"""
                try:
                    for proto in service_rpc(grpc_request, **request_kwargs):
                        yield response_dm_class.from_proto(proto)

                except grpc.RpcError as err:
                    raise CaikitRuntimeException(
                        err.code() if hasattr(err, "code") else grpc.StatusCode.UNKNOWN,
                        "Error received while streaming GRPC result",
                    ) from err

            # Attach reference of this RemoteModuleClass to the returned DataStream.
            # This ensures the GRPC Channel won't get closed until after the
            # DataStream has been cleaned up
            return_stream = DataStream(output_stream_parser)
            return_stream._source = self
            return return_stream
        else:
            try:
                response = service_rpc(grpc_request, **request_kwargs)
            except grpc.RpcError as err:
                raise CaikitRuntimeException(
                    err.code() if hasattr(err, "code") else grpc.StatusCode.UNKNOWN,
                    "Error received from GRPC request",
                ) from err

            return response_dm_class.from_proto(response)

    @property
    def _grpc_channel(self) -> grpc.Channel:
        """Helper function to construct a GRPC channel with correct credentials and
        TLS settings."""

        def grpc_channel_construction_fn():
            target = self._get_remote_target()
            options = list(self._connection.options.items())
            return construct_grpc_channel(
                target,
                options,
                self._tls,
                self._connection.retries,
                self._connection.retry_options,
            )

        return self._get_remote_object(grpc_channel_construction_fn)

    @property
    def _http_session(self) -> Session:
        """Helper function to construct a requests Session with correct credentials
        and TLS settings."""

        def session_construction_fn():
            return construct_requests_session(
                self._connection.options,
                self._tls,
                self._connection.timeout,
                self._connection.retries,
                self._connection.retry_options,
            )

        return self._get_remote_object(session_construction_fn)

    ### Generic Helper Functions

    def _get_remote_object(
        self, construction_fn: Callable[[], Union[grpc.Channel, Session]]
    ) -> Union[grpc.Channel, Session]:
        """Helper function to control construction of a grpc channel or http session.

        When ``max_session_age`` is 0 a new object is built on every call and is not
        cached. Otherwise the object is cached on the instance; when
        ``max_session_age`` is greater than 0 it is rebuilt once it is older than
        ``max_session_age`` seconds.

        Args:
            construction_fn (Callable[[], Union[grpc.Channel, Session]]): zero-argument
                factory that builds a new channel or session

        Returns:
            Union[grpc.Channel, Session]: the channel or session to use
        """
        # If max_session_age is 0 then always return a new session/channel
        if self._connection.max_session_age == 0:
            return construction_fn()

        with self._channel_lock:
            # If there isn't a channel then construct a new one and set the time
            if not self._conn_channel:
                self._current_conn_time = datetime.now()
                self._conn_channel = construction_fn()

            # If the max session age is greater than 0 then check if the conn
            # channel is older than the delta. If so construct a new channel.
            # NOTE(review): the replaced channel/session is not closed here; it is
            # left to garbage collection (possibly so in-flight streams survive).
            elif (
                self._connection.max_session_age > 0
                and datetime.now() - self._current_conn_time > self._max_conn_delta
            ):
                log.debug2("Creating new client channel due to max_session_age value")
                self._current_conn_time = datetime.now()
                self._conn_channel = construction_fn()

            return self._conn_channel

    def _get_remote_target(self) -> str:
        """Get the current remote target: ``host:port`` for gRPC, otherwise an
        ``http://`` or ``https://`` URL depending on whether TLS is enabled"""
        target_string = f"{self._connection.hostname}:{self._connection.port}"
        if self._protocol == "grpc":
            return target_string
        if self._tls.enabled:
            return f"https://{target_string}"
        return f"http://{target_string}"

    @staticmethod
    def _get_streaming_arguments(**kwargs: Dict[str, Any]) -> List[str]:
        """Helper function to detect which kwargs are streaming (DataStream or
        generator values), preserving keyword order"""
        return [
            name
            for name, value in kwargs.items()
            if isinstance(value, (DataStream, Generator))
        ]

    @staticmethod
    def _rename_union_sequence_types(obj: Any, dm_type: type):
        """Helper function that renames all references in a dictionary to match the
        oneOf value of the DataModel and to collapse all Primitive sequences. This is
        required to match the format of http requests.

        For example:
            {
                "union_str": "test",
                "ints": {
                    "values":[1,2,3]
                }
            }
        Becomes:
            {
                "union": "test",
                "ints":[1,2,3]
            }

        Raises:
            ValueError: if a dict is paired with a dm_type that is not a DataBase
                subclass
        """
        if isinstance(obj, list):
            # If list contains DataObjects then recurse. Else return primitive list
            if inspect.isclass(dm_type) and issubclass(dm_type, DataBase):
                return [
                    RemoteModuleBase._rename_union_sequence_types(sub_obj, dm_type)
                    for sub_obj in obj
                ]
            return obj

        if isinstance(obj, dict):
            # Ensure dm_type is a DataObject
            if not (inspect.isclass(dm_type) and issubclass(dm_type, DataBase)):
                raise ValueError("Dict object must map to DataBase")

            # If instance is a sequence then collapse down the values
            if issubclass(dm_type, Sequence):
                return obj.get("values", [])

            output_dict = {}
            for key, val in obj.items():
                # If key is a part of a Union then replace the field name with
                # the union name. E.g. data_str -> data
                dest_key = dm_type._fields_to_oneof.get(key, key)

                val_type = dm_type.get_field_message_type(key)
                output_dict[dest_key] = RemoteModuleBase._rename_union_sequence_types(
                    val, val_type
                )

            return output_dict

        # If object is a primitive then return it directly
        return obj
def construct_remote_module_class(
    model_config: RemoteModuleConfig,
    model_class: Type[RemoteModuleBase] = RemoteModuleBase,
) -> Type[ModuleBase]:
    """Factory function to construct unique Remote Module Class.

    A fresh subclass of ``model_class`` is created for every call. A generated
    train function and one generated inference function per task method are
    attached to it, and it is then passed through the ``module`` decorator.

    Args:
        model_config (RemoteModuleConfig): Module id/name, optional train method
            and (task, method descriptors) pairs describing the remote module
        model_class (Type[RemoteModuleBase]): Base class for the generated class

    Returns:
        Type[ModuleBase]: the decorated remote module class
    """
    # Build a brand-new subclass seeded with a copy of the base class namespace so
    # the attributes set below never touch model_class itself
    base_namespace = dict(model_class.__dict__)
    remote_cls: Type[RemoteModuleBase] = type(
        "RemoteModelClass", (model_class,), base_namespace
    )

    # Attach the train method, if the remote module supports training
    train_descriptor = model_config.train_method
    if train_descriptor:
        generated_train = remote_cls.generate_train_function(train_descriptor)
        setattr(remote_cls, train_descriptor.signature.method_name, generated_train)

    # Attach one inference method per descriptor and record each task seen
    supported_tasks = []
    for task, descriptors in model_config.task_methods:
        supported_tasks.append(task)
        for descriptor in descriptors:
            generated_infer = remote_cls.generate_inference_function(task, descriptor)
            setattr(remote_cls, descriptor.signature.method_name, generated_infer)

    # The module decorator sets the module metadata attributes on the class.
    # We should make a remote backend that just stores signatures
    module_decorator = module(
        id=model_config.module_id,
        name=model_config.module_name,
        version="0.0.0",
        tasks=supported_tasks,
        backend_type="LOCAL",
    )
    return module_decorator(remote_cls)