# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains the data model objects that describe how to connect to remote servers
"""
# Standard
from dataclasses import field
from http.client import HTTP_PORT, HTTPS_PORT
from pathlib import Path
from typing import Optional, Union
# First Party
import alog
# Local
from caikit.core.data_model import PACKAGE_COMMON, DataObjectBase, dataobject
from caikit.core.data_model.json_dict import JsonDict
from caikit.core.exceptions import error_handler
# Logging channel for this module
log = alog.use_channel("CNNCTDM")
# Error helper bound to the channel above; the data objects in this module use it
# for their ``error.type_check`` argument validation
error = error_handler.get(log)
@dataobject(PACKAGE_COMMON)
class ConnectionTlsInfo(DataObjectBase):
    """Helper dataclass to store TLS settings for a remote connection.

    The contents of the configured files are read lazily through the
    ``ca_data``, ``cert_data`` and ``key_data`` properties and cached on the
    instance. When ``enabled`` is True the files are read and validated eagerly
    in ``__post_init__``.
    """

    # If TLS is enabled
    enabled: bool = False
    # NOTE(review): originally documented as "Whether to verify server CA bundle",
    # but the name suggests that True *skips* verification — confirm with consumers
    insecure_verify: bool = False
    # Paths of the TLS CA bundle, certificate and key files (opened via Path below)
    ca_file: Optional[str]
    cert_file: Optional[str]
    key_file: Optional[str]

    # Don't use cached_property as DataBase does not contain a __dict__ object.
    # The cache attributes therefore have to be declared as private slots on DataBase.
    _private_slots = ("_ca_data", "_cert_data", "_key_data")

    @property
    def mtls_enabled(self) -> bool:
        """True when both a certificate file and a key file are configured (mTLS)."""
        # bool() so callers get an actual bool rather than the key_file string / None
        return bool(self.cert_file and self.key_file)

    def _read_cached_file(
        self, cache_attr: str, file_path: Optional[str]
    ) -> Optional[bytes]:
        """Return the bytes cached in ``cache_attr``, reading ``file_path`` on first use.

        Args:
            cache_attr: Name of the private slot holding the cached bytes.
            file_path: Path of the file to read; may be None/empty.

        Returns:
            The cached or freshly read file contents. Returns None when no path is
            configured or the path does not exist. In that case nothing is cached,
            so a later call will retry.
        """
        data = getattr(self, cache_attr)
        # ``is None`` rather than truthiness so an empty file is only read once
        if data is None and file_path and Path(file_path).exists():
            data = Path(file_path).read_bytes()
            setattr(self, cache_attr, data)
        return data

    @property
    def ca_data(self) -> Optional[bytes]:
        """Contents of ``ca_file``, or None if it is unset or does not exist."""
        return self._read_cached_file("_ca_data", self.ca_file)

    @property
    def key_data(self) -> Optional[bytes]:
        """Contents of ``key_file``, or None if it is unset or does not exist."""
        return self._read_cached_file("_key_data", self.key_file)

    @property
    def cert_data(self) -> Optional[bytes]:
        """Contents of ``cert_file``, or None if it is unset or does not exist."""
        return self._read_cached_file("_cert_data", self.cert_file)

    def __post_init__(self):
        """Post init function to verify field types and arguments.

        Also initializes the file caches and, when TLS is enabled, verifies that
        the configured files can be read.

        Raises:
            FileNotFoundError: (from verify_ssl_data) TLS is enabled and a
                configured file could not be read.
            ValueError: (from verify_ssl_data) TLS is enabled and only one of
                cert_file / key_file is set.
        """
        # NOTE(review): bytes are accepted here, but the *_data properties hand
        # these values to pathlib.Path, which rejects bytes — confirm intent
        error.type_check(
            "<COR734221567E>",
            str,
            bytes,
            allow_none=True,
            tls_ca=self.ca_file,
            tls_cert=self.cert_file,
            key_file=self.key_file,
        )
        error.type_check(
            "<COR74322567E>",
            bool,
            tls_enabled=self.enabled,
            insecure_verify=self.insecure_verify,
        )

        # Initialize the lazily-populated file caches declared in _private_slots
        self._ca_data = None
        self._cert_data = None
        self._key_data = None

        # Read and validate the file data up front only when TLS is enabled
        if self.enabled:
            self.verify_ssl_data()

    def verify_ssl_data(self):
        """Helper function to verify all TLS data was read correctly.

        Raises:
            FileNotFoundError: If any of the TLS files was provided but no data
                could be read from it (missing, or an empty file).
            ValueError: If only one of ``cert_file`` / ``key_file`` is provided.
        """
        if self.ca_file and not self.ca_data:
            raise FileNotFoundError(f"Unable to find TLS CA File {self.ca_file}")
        if self.key_file and not self.key_data:
            raise FileNotFoundError(f"Unable to find TLS Key File {self.key_file}")
        if self.cert_file and not self.cert_data:
            raise FileNotFoundError(f"Unable to find TLS Cert File {self.cert_file}")

        # Logical XOR to ensure if one is provided so is the other
        if bool(self.cert_file) != bool(self.key_file):
            raise ValueError(
                "Invalid TLS values. Both cert and key must be provided: "
                f"{self.cert_file=}, {self.key_file=}"
            )
@dataobject(PACKAGE_COMMON)
class ConnectionInfo(DataObjectBase):
    """DataClass to store information regarding an external connection. This includes the
    hostname, port, tls, timeout, retry and extra option settings"""

    # Generic Host settings
    hostname: str
    port: Optional[int] = None

    # TLS Settings
    tls: Optional[ConnectionTlsInfo] = field(default_factory=ConnectionTlsInfo)

    # Connection timeout settings (in seconds)
    timeout: Optional[int] = 60

    # Any extra options for the connection
    options: Optional[JsonDict] = field(default_factory=dict)

    # Number of retries to perform
    retries: Optional[int] = 1

    # Runtime specific retry options
    retry_options: Optional[JsonDict] = field(default_factory=dict)

    # Maximum age for a client channel. Values less than 0 are infinite while 0 means a new
    # channel/session for every request
    max_session_age: Union[float, int] = -1

    def __post_init__(self):
        """Post init function to normalize ``tls``, set a default ``port`` and
        type check all fields.

        Raises:
            Whatever ``error.type_check`` raises (presumably TypeError) when a
            field or an options/retry_options value has an unexpected type.
        """
        # If tls is a plain dict (e.g. from parsed config) convert it to ConnectionTlsInfo
        if isinstance(self.tls, dict):
            self.tls = ConnectionTlsInfo(**self.tls)

        # Set default port. Utilize the standard HTTP ports as the majority of protocols
        # use http under the hood like grpc and s3. ``tls`` is Optional, so a None value
        # is treated as "TLS disabled" instead of failing on ``None.enabled``.
        if not self.port:
            tls_enabled = self.tls is not None and self.tls.enabled
            self.port = HTTPS_PORT if tls_enabled else HTTP_PORT

        # Type check all arguments
        error.type_check(
            "<COR734221567E>",
            str,
            hostname=self.hostname,
        )
        error.type_check(
            "<COR734224567E>",
            int,
            port=self.port,
            timeout=self.timeout,
            retries=self.retries,
        )
        error.type_check(
            "<COR730224567E>",
            float,
            int,
            max_session_age=self.max_session_age,
        )

        # Option values must be scalar str/int/float
        if self.options:
            error.type_check("<COR734424567E>", str, int, float, **self.options)
        if self.retry_options:
            error.type_check("<COR734424567E>", str, int, float, **self.retry_options)