# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
! NOTE ! This file should not import any extra dependencies. It is intended for
use by client libraries that do not necessarily use a specific runtime server
type.
"""
# Standard
from enum import Enum
from typing import Optional, Type, Union
import re
# Third Party
from grpc import StatusCode
# First Party
import alog
# Local
from caikit.config import get_config
from caikit.core.exceptions.caikit_core_exception import CaikitCoreStatusCode
from caikit.core.modules import ModuleBase
from caikit.core.task import TaskBase
from caikit.core.toolkit.name_tools import camel_to_snake_case, snake_to_upper_camel
from caikit.interfaces.runtime.data_model import (
DeployModelRequest,
ModelInfo,
ModelInfoRequest,
ModelInfoResponse,
RuntimeInfoRequest,
RuntimeInfoResponse,
TrainingInfoRequest,
TrainingStatusResponse,
UndeployModelRequest,
)
# Logging channel used by the naming helpers in this module
log = alog.use_channel("RNTM-NAMES")
################################# Model Management Names #######################
# NOTE(review): presumably the model "type" string reported for models that are
# loaded locally rather than through a remote backend -- confirm against callers
LOCAL_MODEL_TYPE = "standalone-model"
# NOTE(review): presumably the names of the default model loader and model sizer
# entries in the model management configuration -- confirm against callers
DEFAULT_LOADER_NAME = "default"
DEFAULT_SIZER_NAME = "default"
################################# Service Names ################################
[docs]
class ServiceType(Enum):
    """Common class for describing service types

    Each member identifies one kind of service. The helpers in this module use
    the member to look up the service's package name, service name and gRPC
    route.
    """

    INFERENCE = 1  # Inference service for the GlobalPredictServicer
    TRAINING = 2  # Training service for the GlobalTrainServicer
    TRAINING_MANAGEMENT = 3  # Fixed service described by TRAINING_MANAGEMENT_SERVICE_SPEC
    INFO = 4  # Fixed service described by INFO_SERVICE_SPEC
    MODEL_MANAGEMENT = 5  # Fixed service described by MODEL_MANAGEMENT_SERVICE_SPEC
    JOB_INFERENCE = 6  # Inference service for background (job based) predictions
############################ Service Name Generation ###########################
## Service Package Descriptors
[docs]
def get_ai_domain() -> str:
    """Resolve the name of the AI domain served by this runtime

    A truthy ``runtime.service_generation.domain`` config value takes
    precedence. Otherwise the name is derived from ``runtime.library`` by
    dropping every "caikit_" occurrence and converting the remainder to upper
    camel case.

    Returns:
        domain(str): The domain for this service
    """
    runtime_cfg = get_config().runtime
    # Derive the fallback first, exactly as the library name dictates
    derived_domain = snake_to_upper_camel(runtime_cfg.library.replace("caikit_", ""))
    configured_domain = runtime_cfg.service_generation.domain
    return configured_domain or derived_domain
[docs]
def get_service_package_name(service_type: Optional[ServiceType] = None) -> str:
    """Look up the package name that a service is generated under

    Args:
        service_type Optional[ServiceType]: The service type whose package name
            is wanted. When omitted (or not one of the fixed services) the
            runtime package is returned.

    Returns:
        str: The name of the service package
    """
    # The info and management services live in fixed, non-configurable packages
    fixed_packages = (
        (ServiceType.INFO, INFO_SERVICE_PACKAGE),
        (ServiceType.TRAINING_MANAGEMENT, TRAINING_MANAGEMENT_SERVICE_PACKAGE),
        (ServiceType.MODEL_MANAGEMENT, MODEL_MANAGEMENT_SERVICE_PACKAGE),
    )
    for fixed_type, fixed_package in fixed_packages:
        if service_type == fixed_type:
            return fixed_package

    # Everything else uses the configured package, falling back to a name
    # generated from the AI domain
    config = get_config()
    generated_default = f"caikit.runtime.{get_ai_domain()}"
    return config.runtime.service_generation.package or generated_default
[docs]
def get_service_name(service_type: ServiceType) -> str:
    """This helper will get the name of the service

    Inference, job inference and training service names are derived from the
    AI domain; the training management, info and model management services
    have fixed names.

    Args:
        service_type ServiceType: The Service Type whose name to fetch

    Returns:
        str: The name of the service. A value that matches no ServiceType
            member yields None, as it always has.
    """
    if service_type == ServiceType.INFERENCE:
        return f"{get_ai_domain()}Service"
    if service_type == ServiceType.JOB_INFERENCE:
        return f"{get_ai_domain()}JobService"
    if service_type == ServiceType.TRAINING:
        return f"{get_ai_domain()}TrainingService"
    if service_type == ServiceType.TRAINING_MANAGEMENT:
        return TRAINING_MANAGEMENT_SERVICE_NAME
    if service_type == ServiceType.INFO:
        return INFO_SERVICE_NAME
    # Without this branch the model management service resolved to None, which
    # produced gRPC routes such as "/caikit.runtime.models.None/DeployModel"
    if service_type == ServiceType.MODEL_MANAGEMENT:
        return MODEL_MANAGEMENT_SERVICE_NAME
    # Unknown service types keep the historical behavior of returning None
    return None
## Service RPC Descriptors
[docs]
def get_train_rpc_name(module_class: Type[ModuleBase]) -> str:
    """Build the name of the training RPC for a module class

    The name is the upper camel case form of
    ``<first task name>_<module class name>_Train``.

    Args:
        module_class (Type[ModuleBase]): The module whose train RPC is named

    Returns:
        str: The name of the train RPC
    """
    # CAUTION: this naming scheme probably needs to change. Only the first
    # entry of the module's `tasks` (from the `@caikit.module` decorator) is
    # used, which is:
    #   - Fragile: re-ordering the task list is otherwise a valid change, yet it
    #     would silently rename the training service api
    #   - Unintuitive: a module supporting several tasks gets a training
    #     endpoint that mentions only one of them
    module_tasks = module_class.tasks
    primary_task = next(iter(module_tasks))
    rpc_name = snake_to_upper_camel(
        f"{primary_task.__name__}_{module_class.__name__}_Train"
    )
    if len(module_tasks) <= 1:
        return rpc_name

    log.warning(
        "<RUN35134050W>",
        "Multiple tasks detected for training rpc. "
        "Module: [%s], Tasks: [%s], RPC name: %s ",
        module_class,
        module_tasks,
        rpc_name,
    )
    return rpc_name
[docs]
def get_task_predict_rpc_name(
    task_or_module_class: Type[Union[ModuleBase, TaskBase]],
    input_streaming: bool = False,
    output_streaming: bool = False,
) -> str:
    """Build the name of the predict RPC for a task

    Args:
        task_or_module_class (Type[Union[ModuleBase, TaskBase]]): A task class,
            or a module class (in which case its first task is used)
        input_streaming (bool): Whether the RPC takes a stream of inputs
        output_streaming (bool): Whether the RPC returns a stream of outputs

    Returns:
        str: The name of the predict RPC
    """
    if issubclass(task_or_module_class, ModuleBase):
        task_class = next(iter(task_or_module_class.tasks))
    else:
        task_class = task_or_module_class

    # The streaming flavor only changes the prefix of the RPC name
    if input_streaming and output_streaming:
        streaming_prefix = "BidiStreaming"
    elif output_streaming:
        streaming_prefix = "ServerStreaming"
    elif input_streaming:
        streaming_prefix = "ClientStreaming"
    else:
        streaming_prefix = ""
    return snake_to_upper_camel(f"{streaming_prefix}{task_class.__name__}_Predict")
[docs]
def get_task_predict_job_rpc_name(
    task_or_module_class: Type[Union[ModuleBase, TaskBase]],
) -> str:
    """Build the name of the RPC that starts a prediction job for a task

    A module class is resolved to the first of its tasks.
    """
    if issubclass(task_or_module_class, ModuleBase):
        task_class = next(iter(task_or_module_class.tasks))
    else:
        task_class = task_or_module_class
    raw_name = f"{task_class.__name__}_StartPredictionJob"
    return snake_to_upper_camel(raw_name)
[docs]
def get_task_predict_job_status_rpc_name(
    task_or_module_class: Type[Union[ModuleBase, TaskBase]],
) -> str:
    """Build the name of the RPC that reports a prediction job's status

    A module class is resolved to the first of its tasks.
    """
    if issubclass(task_or_module_class, ModuleBase):
        task_class = next(iter(task_or_module_class.tasks))
    else:
        task_class = task_or_module_class
    raw_name = f"{task_class.__name__}_GetPredictionJobStatus"
    return snake_to_upper_camel(raw_name)
[docs]
def get_task_predict_job_cancel_rpc_name(
    task_or_module_class: Type[Union[ModuleBase, TaskBase]],
) -> str:
    """Build the name of the RPC that cancels a prediction job for a task

    A module class is resolved to the first of its tasks.
    """
    if issubclass(task_or_module_class, ModuleBase):
        task_class = next(iter(task_or_module_class.tasks))
    else:
        task_class = task_or_module_class
    raw_name = f"{task_class.__name__}_CancelPredictionJob"
    return snake_to_upper_camel(raw_name)
[docs]
def get_task_predict_job_result_rpc_name(
    task_or_module_class: Type[Union[ModuleBase, TaskBase]],
) -> str:
    """Build the name of the RPC that fetches a prediction job's result

    A module class is resolved to the first of its tasks.
    """
    if issubclass(task_or_module_class, ModuleBase):
        task_class = next(iter(task_or_module_class.tasks))
    else:
        task_class = task_or_module_class
    raw_name = f"{task_class.__name__}_GetPredictionJobResult"
    return snake_to_upper_camel(raw_name)
## Service DataModel Name Descriptors
[docs]
def get_train_request_name(module_class: Type[ModuleBase]) -> str:
    """Name of the request message for a module's train RPC

    This is the train RPC name with a "Request" suffix.
    """
    train_rpc = get_train_rpc_name(module_class)
    return f"{train_rpc}Request"
[docs]
def get_train_parameter_name(module_class: Type[ModuleBase]) -> str:
    """Name of the inner parameters message for a module's train RPC

    This is the train RPC name with a "Parameters" suffix.
    """
    train_rpc = get_train_rpc_name(module_class)
    return f"{train_rpc}Parameters"
[docs]
def get_task_predict_request_name(
    task_or_module_class: Type[Union[ModuleBase, TaskBase]],
    input_streaming: bool = False,
    output_streaming: bool = False,
) -> str:
    """Build the name of the request data type for a task's predict RPC

    Args:
        task_or_module_class (Type[Union[ModuleBase, TaskBase]]): A task class,
            or a module class (in which case its first task is used)
        input_streaming (bool): Whether the RPC takes a stream of inputs
        output_streaming (bool): Whether the RPC returns a stream of outputs

    Returns:
        str: The name of the request data type
    """
    if issubclass(task_or_module_class, ModuleBase):
        task_class = next(iter(task_or_module_class.tasks))
    else:
        task_class = task_or_module_class

    # The streaming flavor only changes the prefix of the request name
    if input_streaming and output_streaming:
        streaming_prefix = "BidiStreaming"
    elif output_streaming:
        streaming_prefix = "ServerStreaming"
    elif input_streaming:
        streaming_prefix = "ClientStreaming"
    else:
        streaming_prefix = ""
    return snake_to_upper_camel(f"{streaming_prefix}{task_class.__name__}_Request")
## Service Definitions
# Fixed name and package of the training management service
TRAINING_MANAGEMENT_SERVICE_NAME = "TrainingManagement"
TRAINING_MANAGEMENT_SERVICE_PACKAGE = "caikit.runtime.training"
# Spec of the training management service. Each rpc entry gives the RPC name
# plus the fully-qualified proto message names of its input and output types.
TRAINING_MANAGEMENT_SERVICE_SPEC = {
    "service": {
        "rpcs": [
            {
                "name": "GetTrainingStatus",
                "input_type": TrainingInfoRequest.get_proto_class().DESCRIPTOR.full_name,
                "output_type": TrainingStatusResponse.get_proto_class().DESCRIPTOR.full_name,
            },
            {
                "name": "CancelTraining",
                "input_type": TrainingInfoRequest.get_proto_class().DESCRIPTOR.full_name,
                "output_type": TrainingStatusResponse.get_proto_class().DESCRIPTOR.full_name,
            },
        ]
    }
}
# Fixed name and package of the info service
INFO_SERVICE_NAME = "InfoService"
INFO_SERVICE_PACKAGE = "caikit.runtime.info"
# Spec of the info service. Each rpc entry gives the RPC name plus the
# fully-qualified proto message names of its input and output types.
INFO_SERVICE_SPEC = {
    "service": {
        "rpcs": [
            {
                "name": "GetRuntimeInfo",
                "input_type": RuntimeInfoRequest.get_proto_class().DESCRIPTOR.full_name,
                "output_type": RuntimeInfoResponse.get_proto_class().DESCRIPTOR.full_name,
            },
            {
                "name": "GetModelsInfo",
                "input_type": ModelInfoRequest.get_proto_class().DESCRIPTOR.full_name,
                "output_type": ModelInfoResponse.get_proto_class().DESCRIPTOR.full_name,
            },
        ]
    }
}
# Fixed name and package of the model management service
MODEL_MANAGEMENT_SERVICE_NAME = "ModelManagement"
MODEL_MANAGEMENT_SERVICE_PACKAGE = "caikit.runtime.models"
# Spec of the model management service. Each rpc entry gives the RPC name plus
# the fully-qualified proto message names of its input and output types.
MODEL_MANAGEMENT_SERVICE_SPEC = {
    "service": {
        "rpcs": [
            {
                "name": "DeployModel",
                "input_type": DeployModelRequest.get_proto_class().DESCRIPTOR.full_name,
                "output_type": ModelInfo.get_proto_class().DESCRIPTOR.full_name,
            },
            {
                "name": "UndeployModel",
                "input_type": UndeployModelRequest.get_proto_class().DESCRIPTOR.full_name,
                # NOTE(review): the output type is the request message itself;
                # presumably the request is echoed back -- confirm this is
                # intentional
                "output_type": UndeployModelRequest.get_proto_class().DESCRIPTOR.full_name,
            },
        ]
    }
}
################################# Server Names #################################
# Invocation metadata key for the model ID, provided by Model Mesh
MODEL_MESH_MODEL_ID_KEY = "mm-model-id"
## HTTP Server
# Endpoint to use for health checks
HEALTH_ENDPOINT = "/health"
# Endpoint to use for server info, with sub-routes for the runtime version
# information and for the models information
INFO_ENDPOINT = "/info"
RUNTIME_INFO_ENDPOINT = f"{INFO_ENDPOINT}/version"
MODELS_INFO_ENDPOINT = f"{INFO_ENDPOINT}/models"
# Endpoints to use for resource management, with sub-routes for models and
# for trainings
MANAGEMENT_ENDPOINT = "/management"
MODEL_MANAGEMENT_ENDPOINT = f"{MANAGEMENT_ENDPOINT}/models"
TRAINING_MANAGEMENT_ENDPOINT = f"{MANAGEMENT_ENDPOINT}/trainings"
# These keys are used to define the logical sections of the request and response
# data structures.
REQUIRED_INPUTS_KEY = "inputs"
OPTIONAL_INPUTS_KEY = "parameters"
# NOTE(review): presumably the request key that carries the target model's id
# -- confirm against the HTTP server
MODEL_ID = "model_id"
# NOTE(review): presumably the key under which extra OpenAPI schema content is
# supplied -- confirm against the HTTP server
EXTRA_OPENAPI_KEY = "extra_openapi"
# Key representing the acknowledgement header sent in case of bi-directional streaming
ACK_HEADER_STRING = "acknowledgement"
# Stream event type for HTTP output streaming
[docs]
class StreamEventTypes(Enum):
    """Stream event types for HTTP output streaming"""

    # NOTE(review): presumably a regular streamed result -- confirm against the
    # HTTP server
    MESSAGE = "message"
    # NOTE(review): presumably an error raised while streaming -- confirm
    # against the HTTP server
    ERROR = "error"
[docs]
def get_http_route_name(rpc_name: str) -> str:
    """Compute the http route that serves a given RPC

    Predict and StartPredictionJob RPCs map to
    ``<route_prefix>/task/<kebab-case task name>`` and Train RPCs map to
    ``<route_prefix>/<rpc_name>``. The result always starts with a "/".

    Args:
        rpc_name (str): The name of the Caikit RPC

    Raises:
        NotImplementedError: If the RPC name does not end in Predict,
            StartPredictionJob or Train

    Returns:
        str: The name of the http route for RPC
    """

    def _prefixed(*segments: str) -> str:
        # Join under the configured route prefix and guarantee a leading slash
        path = "/".join([get_config().runtime.http.route_prefix, *segments])
        return path if path[0] == "/" else "/" + path

    # Plain predictions and prediction jobs share the same task based route
    for task_suffix, suffix_pattern in (
        ("Predict", "Predict$"),
        ("StartPredictionJob", "StartPredictionJob$"),
    ):
        if rpc_name.endswith(task_suffix):
            # Strip the RPC suffix first, then a trailing "Task" if present
            bare_name = re.sub("Task$", "", re.sub(suffix_pattern, "", rpc_name))
            return _prefixed("task", camel_to_snake_case(bare_name, kebab_case=True))

    if rpc_name.endswith("Train"):
        return _prefixed(rpc_name)

    raise NotImplementedError(f"Unknown RPC type for rpc name {rpc_name}")
[docs]
def get_http_prediction_job_route_name(rpc_name: str) -> str:
    """Compute the http route for the prediction job of a given RPC

    The route is the RPC's regular http route followed by "/job".

    Args:
        rpc_name (str): The name of the Caikit RPC

    Raises:
        NotImplementedError: If no regular http route exists for the RPC name

    Returns:
        str: The name of the http job route for RPC
    """
    return f"{get_http_route_name(rpc_name)}/job"
[docs]
def get_http_prediction_job_result_route_name(rpc_name: str) -> str:
    """Compute the http route for the prediction job results of a given RPC

    The route is the RPC's prediction job route followed by "/results".

    Args:
        rpc_name (str): The name of the Caikit RPC

    Raises:
        NotImplementedError: If no regular http route exists for the RPC name

    Returns:
        str: The name of the http job results route for RPC
    """
    return f"{get_http_prediction_job_route_name(rpc_name)}/results"
## GRPC Server
[docs]
def get_grpc_route_name(service_type: ServiceType, rpc_name: str) -> str:
    """Compute the full GRPC route for an RPC of a given service type

    The route has the form ``/<package>.<service>/<rpc_name>``.

    Args:
        service_type (ServiceType): The type of service that hosts the RPC
        rpc_name (str): The name of the Caikit RPC

    Returns:
        str: The name of the GRPC route for RPC
    """
    package = get_service_package_name(service_type)
    service = get_service_name(service_type)
    return f"/{package}.{service}/{rpc_name}"
## Status Code Mappings
# Lookup table from a status code (either a grpc StatusCode or a
# CaikitCoreStatusCode) to the HTTP status code reported for it.
STATUS_CODE_TO_HTTP = {
    # Mapping from GRPC codes to their corresponding HTTP codes
    # pylint: disable=line-too-long
    # CITE: https://chromium.googlesource.com/external/github.com/grpc/grpc/+/refs/tags/v1.21.4-pre1/doc/statuscodes.md
    StatusCode.OK: 200,
    StatusCode.INVALID_ARGUMENT: 400,
    StatusCode.FAILED_PRECONDITION: 400,
    StatusCode.OUT_OF_RANGE: 400,
    StatusCode.UNAUTHENTICATED: 401,
    StatusCode.PERMISSION_DENIED: 403,
    StatusCode.NOT_FOUND: 404,
    StatusCode.ALREADY_EXISTS: 409,
    StatusCode.ABORTED: 409,
    StatusCode.RESOURCE_EXHAUSTED: 429,
    StatusCode.CANCELLED: 499,
    StatusCode.UNKNOWN: 500,
    # INTERNAL is what CaikitCoreStatusCode.FATAL translates to in
    # CAIKIT_STATUS_CODE_TO_GRPC, so it needs an entry here as well. It is
    # listed before DATA_LOSS on purpose: the reverse lookup built from this
    # table keeps the last grpc code per HTTP code, so 500 resolves as before.
    StatusCode.INTERNAL: 500,
    StatusCode.DATA_LOSS: 500,
    StatusCode.UNIMPLEMENTED: 501,
    # UNAVAILABLE is "503 Service Unavailable" in the canonical gRPC to HTTP
    # mapping; 501 is reserved for UNIMPLEMENTED.
    StatusCode.UNAVAILABLE: 503,
    StatusCode.DEADLINE_EXCEEDED: 504,
    # Mapping from CaikitCore StatusCodes codes to their corresponding HTTP codes
    CaikitCoreStatusCode.INVALID_ARGUMENT: 400,
    CaikitCoreStatusCode.UNAUTHORIZED: 401,
    CaikitCoreStatusCode.FORBIDDEN: 403,
    CaikitCoreStatusCode.NOT_FOUND: 404,
    CaikitCoreStatusCode.CONNECTION_ERROR: 500,
    CaikitCoreStatusCode.UNKNOWN: 500,
    CaikitCoreStatusCode.FATAL: 500,
}
# Reverse lookup from HTTP code to status code, built from STATUS_CODE_TO_HTTP.
# grpc.StatusCodes are preferred over CaikitCoreStatusCode because
# CaikitRuntimeExceptions expect StatusCode and not the caikit version.
HTTP_TO_STATUS_CODE = {}
for key, val in STATUS_CODE_TO_HTTP.items():
    if isinstance(key, StatusCode):
        # A grpc code always claims the slot, so the last one listed wins
        HTTP_TO_STATUS_CODE[val] = key
    else:
        # Any other code only fills a slot that is still empty
        HTTP_TO_STATUS_CODE.setdefault(val, key)
# Mapping from CaikitCore StatusCodes codes to their corresponding GRPC status codes
CAIKIT_STATUS_CODE_TO_GRPC = {
    CaikitCoreStatusCode.INVALID_ARGUMENT: StatusCode.INVALID_ARGUMENT,
    # The remaining caikit codes have no identically named grpc code and are
    # translated to another grpc code, apart from NOT_FOUND and UNKNOWN
    CaikitCoreStatusCode.UNAUTHORIZED: StatusCode.UNAUTHENTICATED,
    CaikitCoreStatusCode.FORBIDDEN: StatusCode.PERMISSION_DENIED,
    CaikitCoreStatusCode.NOT_FOUND: StatusCode.NOT_FOUND,
    CaikitCoreStatusCode.CONNECTION_ERROR: StatusCode.UNAVAILABLE,
    CaikitCoreStatusCode.UNKNOWN: StatusCode.UNKNOWN,
    CaikitCoreStatusCode.FATAL: StatusCode.INTERNAL,
}