Source code for caikit.runtime.dump_services

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from typing import Dict, List, Optional, Union
import argparse
import json
import os
import shutil
import sys

# Third Party
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pb2, descriptor_pool

# First Party
from py_to_proto import descriptor_to_file
from py_to_proto.utils import safe_add_fd_to_pool
import alog

# Local
from ..config.config import get_config
from ..core.data_model import render_dataobject_protos
from ..core.data_model.dataobject import get_generated_proto_classes
from ..core.exceptions import error_handler
from .service_factory import ServicePackage, ServicePackageFactory
import caikit

# Log channel shared by every function in this module
log = alog.use_channel("RUNTIME-DUMP-SVC")
# Error handler obtained for that channel; it is called below to report errors
# and to run the type/value checks (error.type_check / error.value_check)
error = error_handler.get(log)

## Public ######################################################################


def dump_grpc_services(
    output_dir: str,
    write_modules_file: bool,
    consolidate: bool = False,
):
    """Render all of the generated gRPC interfaces out as proto files

    Args:
        output_dir (str): Directory that the rendered proto files are written
            into
        write_modules_file (bool): Forwarded to the service package lookup to
            control whether the supported-modules compatibility file is
            written
        consolidate (bool): When True, emit one proto file per package rather
            than the raw per-object files
    """
    packages = _get_grpc_service_packages(write_modules_file)

    if consolidate:
        log.info("Dumping service and data model protos with package consolidation")
        os.makedirs(output_dir, exist_ok=True)

        # Only data model classes whose descriptors live in the default
        # descriptor pool take part in consolidation
        data_model_descriptors = [
            generated_cls.DESCRIPTOR
            for generated_cls in get_generated_proto_classes()
            if generated_cls.DESCRIPTOR.file.pool is descriptor_pool.Default()
        ]
        service_descriptors = [package.descriptor for package in packages]

        consolidated_fd_protos = _get_proto_file_descriptors(
            data_model_descriptors + service_descriptors
        )
        _dump_consolidated_protos(consolidated_fd_protos, output_dir)
        return

    log.info("Dumping raw service and data model protos without package consolidation")
    render_dataobject_protos(output_dir)
    for package in packages:
        package.service.write_proto_file(output_dir)
def dump_http_services(output_dir: str):
    """Dump out the openapi.json for the HTTP server

    The spec is fetched from an in-process RuntimeHTTPServer app through a
    fastapi TestClient and written to ``<output_dir>/openapi.json``.

    Args:
        output_dir (str): Directory to write openapi.json into; created if it
            does not exist yet

    Raises:
        SystemExit: With exit code 1 if the HTTP components cannot be imported
    """
    # Import the HTTP components inside the dump function to avoid requiring
    # them when dumping grpc interfaces without the `runtime-http` optional
    # dependencies installed.
    try:
        # Third Party
        from fastapi.testclient import (  # pylint: disable=import-outside-toplevel
            TestClient,
        )

        # Local
        from .http_server import (  # pylint: disable=import-outside-toplevel
            RuntimeHTTPServer,
        )
    except ModuleNotFoundError as e:
        message = (
            "Error: {} - unable to dump http services. Perhaps you missed"
            " installing the http optional dependencies?".format(e)
        )
        log.error("<DMP76165827E>", message)
        sys.exit(1)

    server = RuntimeHTTPServer()
    with TestClient(server.app) as client:
        response = client.get("/openapi.json")
        # Fail loudly rather than writing an error payload to disk
        response.raise_for_status()

        # Create the output dir if it doesn't exist. exist_ok avoids the
        # check-then-create race and matches how the other dump paths in this
        # module create their directories.
        os.makedirs(output_dir, exist_ok=True)

        with open(
            os.path.join(output_dir, "openapi.json"), "w", encoding="utf-8"
        ) as handle:
            handle.write(json.dumps(response.json(), indent=2))
## Implementation Details ######################################################
def _try_find_file_by_name(
    name: str,
    pool: descriptor_pool.DescriptorPool,
) -> Optional[_descriptor.FileDescriptor]:
    """Look up a file descriptor by name in the given pool

    Args:
        name (str): File name to look up
        pool (descriptor_pool.DescriptorPool): Pool to search

    Returns:
        The matching file descriptor, or None when the lookup raises KeyError
    """
    try:
        found_fd = pool.FindFileByName(name)
    except KeyError:
        found_fd = None
    return found_fd
def _recursive_safe_add_to_pool(
    fd_proto: descriptor_pb2.FileDescriptorProto,
    fd_protos_to_add: Dict[str, descriptor_pb2.FileDescriptorProto],
    dpool: descriptor_pool.DescriptorPool,
) -> _descriptor.FileDescriptor:
    """Add a file descriptor proto to the pool, adding any of its dependencies
    that the pool does not already hold first.

    Args:
        fd_proto (descriptor_pb2.FileDescriptorProto): The file to add
        fd_protos_to_add (Dict[str, descriptor_pb2.FileDescriptorProto]): All
            of the protos that are waiting to be added; searched (by file
            name) for missing dependencies before the default pool is
        dpool (descriptor_pool.DescriptorPool): The pool to add into

    Returns:
        The file descriptor found in dpool under fd_proto's name after adding
    """
    pending_by_file_name = {
        pending.name: pending for pending in fd_protos_to_add.values()
    }

    for dep_name in fd_proto.dependency:
        # Dependencies already known to the target pool need no work
        if _try_find_file_by_name(dep_name, dpool):
            continue

        # First choice: a proto that is itself waiting to be added
        dep_fd_proto = pending_by_file_name.get(dep_name)
        if not dep_fd_proto:
            # Second choice: copy the file out of the default pool
            default_fd = _try_find_file_by_name(dep_name, descriptor_pool.Default())
            if not default_fd:
                error(
                    "<COR25660790E>",
                    ValueError(
                        f"Can't add {fd_proto.name}: dependency {dep_name} not found"
                    ),
                )
                # Move on to the next dependency if the error handler returns
                continue
            dep_fd_proto = descriptor_pb2.FileDescriptorProto()
            default_fd.CopyToProto(dep_fd_proto)

        _recursive_safe_add_to_pool(dep_fd_proto, fd_protos_to_add, dpool)

    safe_add_fd_to_pool(fd_proto, dpool)
    return dpool.FindFileByName(fd_proto.name)
def _descriptor_to_proto(
    descriptor: Union[
        _descriptor.Descriptor,
        _descriptor.EnumDescriptor,
        _descriptor.ServiceDescriptor,
    ],
) -> Union[
    descriptor_pb2.DescriptorProto,
    descriptor_pb2.EnumDescriptorProto,
    descriptor_pb2.ServiceDescriptorProto,
]:
    """Copy a message, enum, or service descriptor into the matching
    descriptor_pb2 proto message so that descriptors can be compared by content
    instead of by object identity.

    Args:
        descriptor: The message, enum, or service descriptor to convert

    Returns:
        A new DescriptorProto, EnumDescriptorProto, or ServiceDescriptorProto
        populated from the given descriptor
    """
    # Ordered pairs of (descriptor type, proto type it is copied into)
    conversions = (
        (_descriptor.Descriptor, descriptor_pb2.DescriptorProto),
        (_descriptor.EnumDescriptor, descriptor_pb2.EnumDescriptorProto),
        (_descriptor.ServiceDescriptor, descriptor_pb2.ServiceDescriptorProto),
    )
    error.type_check(
        "<COR46719006E>",
        *(descriptor_type for descriptor_type, _ in conversions),
        descriptor=descriptor,
    )

    proto_type = next(
        (
            candidate_proto_type
            for descriptor_type, candidate_proto_type in conversions
            if isinstance(descriptor, descriptor_type)
        ),
        None,
    )
    assert proto_type

    converted = proto_type()
    descriptor.CopyToProto(converted)
    return converted
def _get_proto_file_descriptors(
    object_descriptors: List[
        Union[
            _descriptor.Descriptor,
            _descriptor.EnumDescriptor,
            _descriptor.ServiceDescriptor,
        ]
    ],
) -> Dict[str, descriptor_pb2.FileDescriptorProto]:
    """Get a dict mapping package names to consolidated FileDescriptorProto
    objects built from the files of the given object descriptors.

    Args:
        object_descriptors: The message, enum, and service descriptors whose
            files should be consolidated by package

    Returns:
        A dict keyed by package name. Each value is a FileDescriptorProto named
        ``<package with "." replaced by "_">.proto`` whose dependencies on the
        other consolidated packages point at their package-level files.

    Raises:
        Whatever ``error.value_check`` raises when two descriptors share a type
        and full name but serialize to different content.
    """
    # Deduplicate object descriptors. Candidates are grouped by type + full
    # name and keyed by id() so the same instance is only counted once.
    dup_candidates = {}
    for obj_desc in object_descriptors:
        dup_candidates.setdefault(f"{type(obj_desc)}/{obj_desc.full_name}", {})[
            id(obj_desc)
        ] = obj_desc

    # A group is only a conflict if its members differ by serialized content
    dups = {
        dup_name: obj_descs
        for dup_name, obj_descs in dup_candidates.items()
        if len(
            {
                _descriptor_to_proto(obj_desc).SerializeToString()
                for obj_desc in obj_descs.values()
            }
        )
        > 1
    }
    error.value_check(
        "<COR01018988E>",
        not dups,
        "Found conflicting definitions of protobuf objects: {}",
        list(dups.keys()),
    )
    # Keep the first descriptor of each group, ordered by (short) name
    object_descriptors = sorted(
        [list(obj_descs.values())[0] for obj_descs in dup_candidates.values()],
        key=lambda obj_desc: obj_desc.name,
    )

    # Collect the auto-gen protos by package
    file_descriptor_protos = {}
    for obj_desc in object_descriptors:
        file_descriptor_proto = file_descriptor_protos.setdefault(
            obj_desc.file.package, descriptor_pb2.FileDescriptorProto()
        )
        obj_desc.file.CopyToProto(file_descriptor_proto)

    # Update the file names to be package-level
    for pkg_name, pkg_fd in file_descriptor_protos.items():
        file_safe_pkg_name = pkg_name.replace(".", "_")
        pkg_fd.name = f"{file_safe_pkg_name}.proto"

    # Update the dependencies for each package-level file descriptor proto
    for pkg_name, pkg_fd in file_descriptor_protos.items():
        # Figure out the remaining set of deps for this file as all external
        # deps and all generated package-level files that aren't this one
        pkg_deps = set()
        for candidate_pkg_name in file_descriptor_protos:
            if candidate_pkg_name != pkg_name and any(
                dep.startswith(candidate_pkg_name) for dep in pkg_fd.dependency
            ):
                pkg_deps.add(candidate_pkg_name)

        # Clear out existing object-level file deps. Iterate over a copy since
        # the dependency list is mutated inside the loop.
        for existing_dep in list(pkg_fd.dependency):
            if any(
                existing_dep.startswith(candidate_pkg_name)
                for candidate_pkg_name in file_descriptor_protos
            ):
                pkg_fd.dependency.remove(existing_dep)

        # Add package-level dependency files
        pkg_fd.dependency.extend(
            sorted([file_descriptor_protos[pkg].name for pkg in pkg_deps])
        )

        # Remove duplicate dependencies. This is due to a proto3 bug in
        # CopyToProto which includes all dependencies even if they already
        # exist. dict.fromkeys keeps the first occurrence of each name in
        # order, so the result is the same from run to run (iterating a set of
        # strings is not, because string hashing is randomized per process).
        pruned_deps = list(dict.fromkeys(pkg_fd.dependency))
        del pkg_fd.dependency[:]
        pkg_fd.dependency.extend(pruned_deps)

    return file_descriptor_protos
def _dump_consolidated_protos(
    fd_protos: Dict[str, descriptor_pb2.FileDescriptorProto],
    interfaces_dir: str,
):
    """Dump all protobuf interfaces consolidated by package

    Args:
        fd_protos (Dict[str, descriptor_pb2.FileDescriptorProto]): The
            package-level file descriptor protos to render
        interfaces_dir (str): Existing directory that one file per entry in
            fd_protos is written into, named after the file descriptor
    """
    # Build the files in a scratch pool rather than the default pool
    temp_dpool = descriptor_pool.DescriptorPool()
    for fd_proto in fd_protos.values():
        fd = _recursive_safe_add_to_pool(fd_proto, fd_protos, temp_dpool)
        target_file = os.path.join(interfaces_dir, fd.name)
        # Pin the encoding so the output does not depend on the platform
        # locale, as the openapi.json dump in this module already does
        with open(target_file, "w", encoding="utf-8") as handle:
            handle.write(descriptor_to_file(fd))
def _get_grpc_service_packages(
    write_modules_file: bool = False,
) -> List[ServicePackage]:
    """Collect the grpc service packages that the runtime config enables

    Args:
        write_modules_file (bool): Forwarded to the inference and
            job-inference package lookups as ``write_modules_file``

    Returns:
        The service packages, in the order: inference, job inference,
        training, training management, info
    """
    service_generation = get_config().runtime.service_generation
    inference_on = service_generation.enable_inference
    inference_jobs_on = service_generation.enable_inference_jobs
    training_on = service_generation.enable_training

    service_types = ServicePackageFactory.ServiceType
    packages = []

    if inference_on:
        packages.append(
            ServicePackageFactory.get_service_package(
                service_types.INFERENCE,
                write_modules_file=write_modules_file,
            )
        )
        # Job inference is only generated on top of plain inference
        if inference_jobs_on:
            packages.append(
                ServicePackageFactory.get_service_package(
                    service_types.JOB_INFERENCE,
                    write_modules_file=write_modules_file,
                )
            )

    # NOTE(review): training management is grouped with training and the info
    # service is treated as unconditional -- confirm against the upstream file
    if training_on:
        for training_type in (
            service_types.TRAINING,
            service_types.TRAINING_MANAGEMENT,
        ):
            packages.append(ServicePackageFactory.get_service_package(training_type))

    packages.append(ServicePackageFactory.get_service_package(service_types.INFO))
    return packages
## Main ########################################################################
def main():
    """Command line entry point: parse the arguments, prepare the output
    directory, then dump the grpc and/or http interfaces depending on which
    servers are enabled in the runtime config.
    """
    parser = argparse.ArgumentParser(
        description="Dump grpc and http services for inference and train"
    )
    parser.add_argument(
        "output_dir",
        type=str,
        help="Path to the output directory for service(s)' proto files",
    )
    parser.add_argument(
        "-j",
        "--write-modules-json",
        default=False,
        action="store_true",
        help="Whether the modules.json (of supported modules) should be output?",
    )
    parser.add_argument(
        "-c",
        "--clean",
        default=False,
        action="store_true",
        help="Clean out existing content in output dir",
    )
    parser.add_argument(
        "-p",
        "--consolidate-packages",
        default=False,
        action="store_true",
        help="Consolidate protobufs by package",
    )
    args = parser.parse_args()

    # Set up logging so users can set LOG_LEVEL etc
    caikit.core.toolkit.logging.configure()

    # Make sure the output dir exists and optionally clean it out
    out_dir = args.output_dir
    if args.clean and os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir, exist_ok=True)

    if get_config().runtime.grpc.enabled:
        dump_grpc_services(
            out_dir,
            args.write_modules_json,
            args.consolidate_packages,
        )
    if get_config().runtime.http.enabled:
        dump_http_services(out_dir)


if __name__ == "__main__":
    main()