# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the implementation for retrieving information about the
library and services.
"""
# Have pylint ignore Class XXXX has no YYYY member so that we can use gRPC enums.
# pylint: disable=E1101
# Standard
from typing import Any, Dict, List, Optional, Union
# Third Party
from grpc import StatusCode
import importlib_metadata
# First Party
import alog
# Local
from caikit.config import get_config
from caikit.interfaces.runtime.data_model import (
ModelInfo,
ModelInfoRequest,
ModelInfoResponse,
RuntimeInfoResponse,
)
from caikit.runtime.model_management.model_manager import ModelManager
from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException
# Module-level alog channel for this servicer module.
log = alog.use_channel("RI-SERVICR-I")
class InfoServicer:
    """This class contains the implementation for retrieving information about the
    library and services.

    The ``Get*`` methods are the gRPC entry points and return protobuf messages
    (via ``to_proto()``); the ``get_*_dict`` methods serve the HTTP server and
    return plain dicts (via ``to_dict()``). Both share the private ``_get_*``
    helpers so the two transports report identical information.
    """

    def GetModelsInfo(
        self, request: ModelInfoRequest, context  # pylint: disable=unused-argument
    ) -> ModelInfoResponse:
        """Get information on the loaded models for the GRPC server

        Args:
            request: ModelInfoRequest
                Request whose ``model_ids`` select the models to report on; an
                empty list reports every model known to the ModelManager.
            context
                gRPC servicer context (unused).

        Returns:
            models_info: ModelInfoResponse
                DataObject containing the model info, converted with to_proto()

        Raises:
            CaikitRuntimeException: NOT_FOUND if a requested model is not loaded.
        """
        return self._get_models_info(model_ids=request.model_ids).to_proto()

    def get_models_info_dict(
        self, model_ids: Optional[List[str]] = None
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Get information on models for the HTTP server

        Args:
            model_ids: Optional[List[str]]
                Names of the models to report on. None or an empty list reports
                every model known to the ModelManager.

        Returns:
            model_info_dict: Dict[str, List[Dict[str, Any]]]
                Dict representation of ModelInfoResponse

        Raises:
            CaikitRuntimeException: NOT_FOUND if a requested model is not loaded.
        """
        return self._get_models_info(model_ids=model_ids).to_dict()

    def _get_models_info(
        self, model_ids: Optional[List[str]] = None
    ) -> ModelInfoResponse:
        """Helper function to get the list of models

        Args:
            model_ids: Optional[List[str]]
                Names of the models to report on. None or an empty list reports
                every model known to the ModelManager.

        Returns:
            model_info: ModelInfoResponse
                DataObject with model information. Models that are not loaded
                yet are included with only ``name`` set and ``loaded=False``.

        Raises:
            CaikitRuntimeException: StatusCode.NOT_FOUND if any requested model
                is not present in the ModelManager.
        """
        loaded_models = ModelManager.get_instance().loaded_models

        # Get list of models based on input list or all loaded models
        if model_ids:
            loaded_model_list = []
            for model_name in model_ids:
                loaded_model = loaded_models.get(model_name)
                if loaded_model is None:
                    raise CaikitRuntimeException(
                        StatusCode.NOT_FOUND, f"Model {model_name} is not loaded"
                    )
                loaded_model_list.append((model_name, loaded_model))
        else:
            # Snapshot the entries up front: the loop below calls back into the
            # models, and iterating the live items() view would raise
            # "dictionary changed size during iteration" if another thread
            # loads/unloads a model meanwhile (assumes loaded_models is a
            # dict shared with loader threads — confirm in ModelManager).
            loaded_model_list = list(loaded_models.items())

        response = ModelInfoResponse(models=[])
        for name, loaded_module in loaded_model_list:
            if not loaded_module.loaded():
                # Not loaded yet: still report the model, flagged as not loaded
                response.models.append(ModelInfo(loaded=False, name=name))
                continue
            model_instance = loaded_module.model()
            response.models.append(
                ModelInfo(
                    model_path=loaded_module.path(),
                    name=name,
                    size=loaded_module.size(),
                    metadata=model_instance.public_model_info,
                    loaded=loaded_module.loaded(),
                    module_id=model_instance.MODULE_ID,
                    module_metadata=model_instance.module_metadata,
                )
            )
        return response

    def GetRuntimeInfo(
        self, request, context  # pylint: disable=unused-argument
    ) -> RuntimeInfoResponse:
        """Get information on versions of libraries and server for GRPC

        Args:
            request
                Runtime info request (unused).
            context
                gRPC servicer context (unused).

        Returns:
            The runtime info DataObject converted with to_proto()
        """
        return self._get_runtime_info().to_proto()

    def get_version_dict(self) -> Dict[str, Union[str, Dict]]:
        """Get information on versions of libraries and server for HTTP

        Returns:
            Dict representation of RuntimeInfoResponse
        """
        return self._get_runtime_info().to_dict()

    def _get_runtime_info(self) -> RuntimeInfoResponse:
        """Get information on versions of libraries and server from config

        Package versions are seeded from ``runtime.version_info.python_packages``
        in the config (its special ``all`` key is a switch, not a package). Then
        each installed top-level import name is added with the version of its
        first distribution when either ``all`` is truthy, or the name has no dot
        and starts with ``caikit``. Packages whose version cannot be resolved
        are skipped.

        Returns:
            RuntimeInfoResponse with ``python_packages`` and ``runtime_version``
            (taken from the ``runtime_image`` config entry) populated.
        """
        config_version_info = get_config().runtime.version_info or {}
        # A key present with a null value must behave like a missing key
        config_packages = config_version_info.get("python_packages") or {}
        python_packages = {
            package: version
            for package, version in config_packages.items()
            if package != "all"
        }
        all_packages = config_packages.get("all")

        for lib, dist_names in importlib_metadata.packages_distributions().items():
            wanted = all_packages or ("." not in lib and lib.startswith("caikit"))
            if not wanted or not dist_names:
                continue
            version = self._try_lib_version(dist_names[0])
            if version:
                python_packages[lib] = version

        runtime_image = config_version_info.get("runtime_image")
        return RuntimeInfoResponse(
            python_packages=python_packages,
            runtime_version=runtime_image,
        )

    def _try_lib_version(self, name) -> Optional[str]:
        """Get version of python modules

        Args:
            name: Distribution name to look up.

        Returns:
            The installed version string, or None if the distribution is not
            found.
        """
        try:
            return importlib_metadata.version(name)
        except importlib_metadata.PackageNotFoundError:
            return None