Source code for caikit.runtime.client.utils

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Helper utils for GRPC and HTTP connections
"""
# Standard
from typing import Dict, List, Optional, Tuple
import json

# Third Party
from requests import Session
from requests.adapters import HTTPAdapter, Retry

# Third party
import grpc

# Local
from caikit.interfaces.common.data_model import ConnectionTlsInfo


[docs] def construct_grpc_channel( target: str, options: Optional[List[Tuple[str, str]]] = None, tls: Optional[ConnectionTlsInfo] = None, retries: Optional[int] = None, retry_options: Optional[Dict[str, str]] = None, ) -> grpc.Channel: """Helper function to construct a grpc Channel with the given TLS config Args: target (str): The target hostname options (Optional[List[Tuple[str, str]]], optional): List of tuples representing GRPC options. Defaults to None. tls (Optional[ConnectionTlsInfo], optional): The TLS information for this channel. Defaults to None. retries (Optional[int], optional): The max number of retries to attempt. Defaults to None. retry_options (Optional[Dict[str, str]], optional): Dictionary to override fields in the GRPC retry service config. Defaults to None. Returns: grpc.Channel: The constructed channel """ # Add retry option if one was provided if retries and retries > 1: options.append(("grpc.enable_retries", 1)) # Only add service_config if it wasn't already added to the GRPC option # this stops us from overriding an advanced config options_contain_service_config = False for option_name, _ in options: if option_name == "grpc.service_config": options_contain_service_config = True break if not options_contain_service_config: service_config = { "methodConfig": [ { "name": [{}], "retryPolicy": { "maxAttempts": retries, "initialBackoff": "0.1s", "maxBackoff": "5s", "backoffMultiplier": 2, "retryableStatusCodes": [ "UNAVAILABLE", "UNKNOWN", "INTERNAL", "RESOURCE_EXHAUSTED", ], **retry_options, }, } ] } options.append(("grpc.service_config", json.dumps(service_config))) if tls and tls.enabled: grpc_credentials = grpc.ssl_channel_credentials( root_certificates=tls.ca_data, private_key=tls.key_data, certificate_chain=tls.cert_data, ) return grpc.secure_channel( target, credentials=grpc_credentials, options=options ) return grpc.insecure_channel(target, options=options)
[docs] def construct_requests_session( options: Optional[Dict[str, str]] = None, tls: Optional[ConnectionTlsInfo] = None, timeout: Optional[int] = None, retries: Optional[int] = None, retry_options: Optional[Dict[str, str]] = None, ) -> Session: """Helper function to construct a requests Session object with the given TLS config Args: options (Optional[Dict[str, str]], optional): Dictionary of request kwargs to pass to session creation. Defaults to None. tls (Optional[ConnectionTlsInfo], optional): The TLS information for this session. Defaults to None. retries (Optional[int], optional): The max number of retries to attempt. Defaults to None. retry_options (Optional[Dict[str, str]], optional): Dictionary to override kwargs passed to the Retry object construction Returns: Session: _description_ """ session = Session() session.headers["Content-type"] = "application/json" # Gather request SSL configuration if tls.enabled: # Configure the TLS CA settings if tls.insecure_verify: session.verify = False else: session.verify = tls.ca_file or True # Configure MTLS if its enabled if tls.mtls_enabled: session.cert = ( tls.cert_file, tls.key_file, ) # Update request options and timeout variables if options: session.params.update(options) if timeout: session.params["timeout"] = timeout # Mount retry object if options were provided if retries: default_status_codes = list(Retry.RETRY_AFTER_STATUS_CODES) + [500, 502, 504] requests_retry = Retry( total=retries, allowed_methods=None, status_forcelist=default_status_codes, **(retry_options or {}) ) session.mount("http://", HTTPAdapter(max_retries=requests_retry)) session.mount("https://", HTTPAdapter(max_retries=requests_retry)) return session