# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Standard
from functools import cached_property, partial
from glob import glob
from typing import Any, Callable, Dict, List, Optional, Type, Union
import abc
import os
import sys
# Third Party
import grpc
# First Party
from py_to_proto.dataclass_to_proto import Annotated, FieldNumber, OneofField
import aconfig
import alog
# Local
from caikit.config import get_config
from caikit.core.data_model.base import DataBase
from caikit.core.data_model.dataobject import _make_oneof_init, make_dataobject
from caikit.core.data_model.streams.data_stream import DataStream
from caikit.core.exceptions import error_handler
from caikit.core.toolkit.factory import FactoryConstructible, ImportableFactory
from caikit.interfaces.common.data_model.stream_sources import (
Directory,
FileReference,
ListOfFileReferences,
S3Files,
)
from caikit.runtime.names import get_service_package_name
from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException
import caikit
# import common explicitly since this module needs it
import caikit.interfaces.common
# Module-level cache mapping each element type to the DataStreamSource class
# that make_data_stream_source generated for it, so that the same message is
# not recreated unnecessarily
_DATA_STREAM_SOURCE_TYPES = {}
# Logging channel shared by everything in this module
log = alog.use_channel("DSTRM-SRC")
# Error handler bound to this module's logging channel
error = error_handler.get(log)
## Plugin Bases ################################################################
[docs]
class DataStreamSourcePlugin(FactoryConstructible):
    """Base class for pluggable data stream sources.

    A plugin declares the data object ("source message") that describes where
    the data lives, and implements the code that turns an instance of that
    message into a DataStream.
    """

    def __init__(self, config: aconfig.Config, instance_name: str):
        """Keep the factory-provided arguments around for use by subclasses

        Args:
            config (aconfig.Config): The config handed to this plugin instance
            instance_name (str): The name this instance was constructed with
        """
        self._instance_name = instance_name
        self._config = config

    ## Abstract Interface ##

    @abc.abstractmethod
    def get_stream_message_type(self, element_type: type) -> Type[DataBase]:
        """Return the dataobject class that carries this plugin's source
        information for streams of the given element type
        """

    @abc.abstractmethod
    def to_data_stream(
        self, source_message: Type[DataBase], element_type: type
    ) -> DataStream:
        """Build a DataStream from an instance of the source message type"""

    @abc.abstractmethod
    def get_field_number(self) -> int:
        """Return the field number for this plugin. Implementations may derive
        it from self._config
        """

    ## Public Methods ##

    def get_field_name(self, element_type: type) -> str:
        """Name of this plugin's field in the source oneof: the lower-cased
        class name of the plugin's stream message type
        """
        message_type = self.get_stream_message_type(element_type)
        return message_type.__name__.lower()

    ## Shared Impl ##

    @staticmethod
    def _to_element_type(element_type: type, raw_element: Any) -> Any:
        """Convert one raw element to the element type. Only DataBase
        subclasses are converted; anything else is passed through untouched.
        """
        if not issubclass(element_type, DataBase):
            return raw_element
        # Unknown fields are ignored so that inputs carrying extra fields
        # (e.g. training data) that the data object does not need still parse
        return element_type.from_json(raw_element, ignore_unknown_fields=True)

    @staticmethod
    def _to_element_partial(element_type: type) -> Callable:
        """Bind element_type into _to_element_type so the result can be mapped
        over a stream of raw elements
        """
        converter = partial(DataStreamSourcePlugin._to_element_type, element_type)
        return converter
[docs]
class FilePluginBase(DataStreamSourcePlugin):
    """Intermediate base class for file-based plugins with helper utilities"""

    @classmethod
    def _create_data_stream_from_file(
        cls, fname: str, element_type: type
    ) -> DataStream:
        """Create a data stream object by deducing file extension
        and reading the file accordingly

        Args:
            fname (str): Path of the file to read, resolved against the
                configured file source base (if any)
            element_type (type): The type each element is converted to

        Returns:
            DataStream: A stream over the elements of the file

        Raises:
            CaikitRuntimeException: (INVALID_ARGUMENT) if the file name is
                missing, the file does not exist, or the extension is not
                one of .json, .csv or .jsonl
        """
        # Validate the name before handing it to os.path: splitext raises a
        # TypeError on None, and an empty name has no extension so it would
        # otherwise be sent down the format-guessing path and never be
        # reported as an invalid file
        if not fname:
            raise CaikitRuntimeException(
                grpc.StatusCode.INVALID_ARGUMENT,
                f"Invalid data source file: {fname}",
            )

        _, extension = os.path.splitext(fname)
        if not extension:
            return cls._load_from_file_without_extension(fname, element_type)

        full_fname = cls._get_resolved_source_path(fname)
        log.debug3("Pulling data stream from %s file [%s]", extension, full_fname)
        if not os.path.isfile(full_fname):
            raise CaikitRuntimeException(
                grpc.StatusCode.INVALID_ARGUMENT,
                f"Invalid {extension} data source file: {fname}",
            )

        to_element_type = cls._to_element_partial(element_type)
        if extension == ".json":
            stream = DataStream.from_json_array(full_fname).map(to_element_type)
            # Iterate once to make sure this is a json array
            stream.peek()
            return stream
        if extension == ".csv":
            return DataStream.from_header_csv(full_fname).map(to_element_type)
        if extension == ".jsonl":
            return DataStream.from_jsonl(full_fname).map(to_element_type)
        raise CaikitRuntimeException(
            grpc.StatusCode.INVALID_ARGUMENT,
            f"Extension not supported! {extension}",
        )

    @classmethod
    def _load_from_file_without_extension(cls, fname, element_type: type) -> DataStream:
        """Similar to _create_data_stream_from_file, but we don't have a file extension to work
        with. Attempt to create a data stream using one of a few well-known formats.

        🌶🌶🌶️ on ordering here:
        File formats are loosely arranged in order of least-to-most-sketchy format validation.
        1. .json/.jsonl are pretty straightforward
        2. multipart files are a little iffy- the content-type header line can be omitted, in
            which case we check for a `--` string and roll our own boundary parser. This could
            cause problems in the future for multi-yaml files that begin with `---`
        3. CSV support simply assumes the first line of the file has the column headers, and may
            confidently return a stream even if that's not the case.

        Raises:
            CaikitRuntimeException: (INVALID_ARGUMENT) if none of the formats
                could produce a first element from the file
        """
        full_fname = cls._get_resolved_source_path(fname)
        to_element_type = cls._to_element_partial(element_type)
        log.debug3("Attempting to guess file type for file: %s", full_fname)
        for factory_method in (
            DataStream.from_json_array,
            DataStream.from_jsonl,
            DataStream.from_multipart_file,
            DataStream.from_header_csv,
        ):
            try:
                stream = factory_method(full_fname).map(to_element_type)
                # Iterate once and assume we have the correct file type if this
                # works
                stream.peek()
                return stream
            except Exception as e:  # pylint: disable=broad-exception-caught
                # Catch any exception: it's hard to know which all could be
                # thrown by any of the formatters
                log.debug3(
                    "Failed to load file %s using data stream factory method %s: %s",
                    full_fname,
                    factory_method,
                    e,
                    exc_info=True,
                )
        raise CaikitRuntimeException(
            grpc.StatusCode.INVALID_ARGUMENT,
            f"Could not load input file with no extension: {full_fname}",
        )

    @staticmethod
    def _get_resolved_source_path(input_path: str) -> str:
        """Get a fully resolved path, including any shared prefix"""
        # Get any configured prefix
        source_pfx = caikit.get_config().data_streams.file_source_base
        # If a prefix is configured, use it, otherwise return the path as is
        # NOTE: os.path.join will ignore the prefix if input_path is absolute
        return os.path.join(source_pfx, input_path) if source_pfx else input_path
## Source Plugins ##############################################################
[docs]
class FileDataStreamSourcePlugin(FilePluginBase):
    """Source plugin that reads the stream from one referenced file"""

    name = "FileData"

    def get_stream_message_type(self, *_, **__) -> Type[DataBase]:
        """The source message is always a FileReference, whatever the element
        type
        """
        return FileReference

    def get_field_name(self, element_type: type) -> str:
        """Fixed oneof field name. Kept as "file" partly for backwards
        compatibility and partly so FileReference is named consistently with
        the list-of-files and directory sources
        """
        return "file"

    def get_field_number(self) -> int:
        """Field number of this source in the oneof"""
        return 2

    def to_data_stream(
        self, source_message: FileReference, element_type: type
    ) -> DataStream:
        """Read the file named by the message into a stream of element_type"""
        filename = source_message.filename
        return self._create_data_stream_from_file(
            fname=filename, element_type=element_type
        )
[docs]
class ListOfFilesDataStreamSourcePlugin(FilePluginBase):
    """Source plugin that reads the stream from several referenced files"""

    name = "ListOfFiles"

    def get_stream_message_type(self, *_, **__) -> Type[DataBase]:
        """The source message is always a ListOfFileReferences, whatever the
        element type
        """
        return ListOfFileReferences

    def get_field_name(self, element_type: type) -> str:
        """Fixed oneof field name. Kept as "list_of_files" partly for
        backwards compatibility and partly for consistency with the file and
        directory sources
        """
        return "list_of_files"

    def get_field_number(self) -> int:
        """Field number of this source in the oneof"""
        return 3

    def to_data_stream(
        self, source_message: ListOfFileReferences, element_type: type
    ) -> DataStream:
        """Build one stream per file in the message, then combine them into a
        single flattened stream
        """
        per_file_streams = [
            self._create_data_stream_from_file(
                fname=file_name, element_type=element_type
            )
            for file_name in source_message.files
        ]
        # NOTE(review): chain is invoked on the class with the list as its
        # only argument -- confirm this matches DataStream.chain's signature
        return DataStream.chain(per_file_streams).flatten()
[docs]
class DirectoryDataStreamSourcePlugin(FilePluginBase):
    """Plugin for a directory holding files"""

    name = "Directory"

    def get_stream_message_type(self, *_, **__) -> Type[DataBase]:
        """The source message is always a Directory, whatever the element
        type
        """
        return Directory

    def to_data_stream(
        self, source_message: Directory, element_type: type
    ) -> DataStream:
        """Stream the elements of every file in a directory that has the
        requested extension

        Args:
            source_message (Directory): Names the directory and, optionally,
                the file extension to read ("json" when not given)
            element_type (type): The type each element is converted to

        Returns:
            DataStream: A stream over the elements of the matching files

        Raises:
            CaikitRuntimeException: (INVALID_ARGUMENT) if the directory name
                is missing or not a directory, no file with the extension
                exists, or the extension is not json, csv or jsonl
        """
        dirname = source_message.dirname
        extension = source_message.extension or "json"

        # Reject a missing name before resolving it: with a configured source
        # prefix, resolving None would hit os.path.join(prefix, None) and
        # surface as a TypeError instead of an INVALID_ARGUMENT error
        if not dirname:
            raise CaikitRuntimeException(
                grpc.StatusCode.INVALID_ARGUMENT,
                f"Invalid {extension} directory source file: {dirname}",
            )
        full_dirname = self._get_resolved_source_path(dirname)
        if not os.path.isdir(full_dirname):
            raise CaikitRuntimeException(
                grpc.StatusCode.INVALID_ARGUMENT,
                f"Invalid {extension} directory source file: {full_dirname}",
            )

        # make sure at least 1 file with the given extension exists
        files_with_ext = glob(os.path.join(full_dirname, "*." + extension))
        if not files_with_ext:
            raise CaikitRuntimeException(
                grpc.StatusCode.INVALID_ARGUMENT,
                f"directory {dirname} contains no source files with extension {extension}",
            )

        to_element_type = self._to_element_partial(element_type)
        if extension == "json":
            return DataStream.from_json_collection(full_dirname, extension).map(
                to_element_type
            )
        if extension == "csv":
            return DataStream.from_csv_collection(full_dirname).map(to_element_type)
        if extension == "jsonl":
            return DataStream.from_jsonl_collection(full_dirname).map(to_element_type)
        raise CaikitRuntimeException(
            grpc.StatusCode.INVALID_ARGUMENT,
            f"Extension not supported! {extension}",
        )

    def get_field_number(self) -> int:
        """Field number of this source in the oneof"""
        return 4
[docs]
class JsonDataStreamSourcePlugin(DataStreamSourcePlugin):
    """Source plugin for inline data, where the elements are given directly as
    a list.

    Unlike the file-based plugins, the message type depends on the stream's
    element type: a data model holding List[element_type] is generated for
    each element type on demand.
    """

    name = "JsonData"

    # Shared across all instances so that a data model class is generated at
    # most once per element type
    stream_source_type_cache: Dict[Type[DataBase], Type[DataBase]] = {}

    def get_stream_message_type(self, element_type: type) -> Type[DataBase]:
        """Return the JsonData class for element_type, generating and caching
        it on first request
        """
        cache = self.__class__.stream_source_type_cache
        cached_type = cache.get(element_type)
        if cached_type:
            return cached_type

        package = get_service_package_name()
        cls_name = _make_data_stream_source_type_name(element_type)
        json_data_type = make_dataobject(
            package=package,
            proto_name=f"{cls_name}JsonData",
            name="JsonData",
            attrs={"__qualname__": f"{cls_name}.JsonData"},
            annotations={"data": List[element_type]},
        )
        cache[element_type] = json_data_type
        return json_data_type

    def to_data_stream(self, source_message: Type[DataBase], *_, **__) -> DataStream:
        """Stream the inline elements. source_message is expected to be an
        instance of the type returned by get_stream_message_type, and so to
        carry its elements in a `data` attribute
        """
        inline_elements = source_message.data
        return DataStream.from_iterable(inline_elements)

    def get_field_number(self) -> int:
        """Field number of this source in the oneof"""
        return 1
[docs]
class S3FilesDataStreamSourcePlugin(DataStreamSourcePlugin):
    """Placeholder for S3 file sources. The message type is declared, but
    converting it to a data stream is unimplemented!
    """

    name = "S3Files"

    def get_stream_message_type(self, *_, **__) -> Type[DataBase]:
        """The source message is always S3Files, whatever the element type"""
        return S3Files

    def get_field_number(self) -> int:
        """Field number of this source in the oneof"""
        return 5

    def to_data_stream(self, *_, **__) -> DataStream:
        """Not supported: hands a NotImplementedError to the module's error
        handler
        """
        not_implemented = NotImplementedError(
            "S3Files are not implemented as stream sources in this runtime."
        )
        error("<RUN80419785E>", not_implemented)
## DataStreamPluginFactory #####################################################
[docs]
class DataStreamPluginFactory(ImportableFactory):
    """The DataStreamPluginFactory is responsible for holding a registry of
    plugin instances that will be used to create and manage data stream sources
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base factory and start with no
        constructed plugins
        """
        super().__init__(*args, **kwargs)
        # Lazily populated by get_plugins; None means "not built yet"
        self._plugins = None

    def get_plugins(
        self, plugins_config: Optional[aconfig.Config] = None
    ) -> List[DataStreamSourcePlugin]:
        """Build (once) and return the plugin instances for data stream sources

        Args:
            plugins_config (Optional[aconfig.Config]): Mapping of plugin
                instance name to plugin config. When omitted, the
                data_streams.source_plugins section of the global config is
                used. Only consulted on the call that builds the plugins.

        Returns:
            List[DataStreamSourcePlugin]: The constructed plugins. The same
                list is returned on every later call.
        """
        if self._plugins is not None:
            return self._plugins

        if plugins_config is None:
            plugins_config = get_config().data_streams.source_plugins
        plugins = [self.construct(cfg, name) for name, cfg in plugins_config.items()]

        # Make sure field numbers are unique
        field_numbers = [plugin.get_field_number() for plugin in plugins]
        duplicate_field_number_names = [
            plugin.name
            for plugin, field_number in zip(plugins, field_numbers)
            if field_numbers.count(field_number) > 1
        ]
        error.value_check(
            "<RUN69189361E>",
            not duplicate_field_number_names,
            "Duplicate plugin field numbers found for plugins: {}",
            duplicate_field_number_names,
        )

        # Cache only after validation has passed, so that a rejected plugin
        # set is never handed out by a later call
        self._plugins = plugins
        return self._plugins
# Single default factory instance. It is the default `plugin_factory` argument
# of make_data_stream_source, and every source plugin defined in this module
# is registered with it
PluginFactory = DataStreamPluginFactory("DataStreamSource")
PluginFactory.register(JsonDataStreamSourcePlugin)
PluginFactory.register(FileDataStreamSourcePlugin)
PluginFactory.register(ListOfFilesDataStreamSourcePlugin)
PluginFactory.register(DirectoryDataStreamSourcePlugin)
PluginFactory.register(S3FilesDataStreamSourcePlugin)
## DataStreamSourceBase ########################################################
[docs]
class DataStreamSourceBase(DataStream):
    """Common base for the dynamically generated data stream source classes.
    It lets those classes be recognised programmatically and makes each
    instance usable as a DataStream over whichever source is set on it.
    """

    def __init__(self):
        """Initialize the DataStream side with this instance's own generator"""
        super().__init__(self._generator)

    def _generator(self):
        """Produce a fresh generator by re-invoking the generator function of
        the cached underlying stream
        """
        stream = self._stream
        return stream.generator_func(
            *stream.generator_args, **stream.generator_kwargs
        )

    def __getstate__(self) -> bytes:
        """Pickle a DataStreamSource as its serialized source representation.
        This is what allows data streams to be shared with subprocesses, e.g.
        to run training in an isolated process.
        """
        return self.to_binary_buffer()

    def __setstate__(self, pickle_bytes: bytes):
        """Unpickle by deserializing the source representation and copying the
        set oneof member onto this instance. The oneof is represented
        strangely in __dict__, so the member is set explicitly, and the
        DataStream generator attributes are then re-established.
        """
        restored = self.__class__.from_binary_buffer(pickle_bytes)
        # NOTE(review): if which_oneof can return None for a source with no
        # member set, this setattr would fail -- confirm against DataBase
        set_member = restored.which_oneof("data_stream")
        setattr(self, set_member, restored.data_stream)
        self.generator_func = self._generator
        self.generator_args = tuple()
        self.generator_kwargs = {}

    @cached_property
    def name_to_plugin_map(self):
        """Map each plugin's oneof field name (for this class's element type)
        to the plugin itself
        """
        element_type = self.ELEMENT_TYPE
        return {plugin.get_field_name(element_type): plugin for plugin in self.PLUGINS}

    @cached_property
    def _stream(self):
        """The stream produced by to_data_stream, cached so that repeated reads
        reuse it rather than invoking to_data_stream on every pass
        """
        return self.to_data_stream()

    # pylint: disable=too-many-return-statements
    def to_data_stream(self) -> DataStream:
        """Convert to the target data stream based on which source is set.

        Returns an empty stream when no source field is set; otherwise the
        plugin registered for the set field builds the stream.
        """
        # Find the single source field that holds a value
        active_field = None
        for candidate in self.get_proto_class().DESCRIPTOR.fields_by_name:
            if getattr(self, candidate) is None:
                continue
            error.value_check(
                "<RUN80421785E>",
                active_field is None,
                "Found DataStreamSource with multiple sources set: {} and {}",
                active_field,
                candidate,
            )
            error.value_check(
                "<RUN80420785E>",
                candidate in self.name_to_plugin_map,
                "no data stream plugin found for field: {}",
                candidate,
            )
            active_field = candidate

        # Nothing set means there is nothing to stream
        if active_field is None:
            log.debug3("Returning empty data stream")
            return DataStream.from_iterable([])

        # Hand the set source message and the element type to its plugin
        plugin = self.name_to_plugin_map[active_field]
        return plugin.to_data_stream(getattr(self, active_field), self.ELEMENT_TYPE)
## make_data_stream_source #####################################################
[docs]
def make_data_stream_source(
    data_element_type: type,
    plugin_factory: DataStreamPluginFactory = PluginFactory,
    plugins_config: Optional[aconfig.Config] = None,
) -> Type[DataBase]:
    """Get (creating on first use) the data stream source message type for an
    element type. The generated type holds a oneof with one member per
    data stream source plugin and can be read as a stream of the element type.

    Args:
        data_element_type (type): The type of the elements in the stream
        plugin_factory (DataStreamPluginFactory): Factory that supplies the
            source plugins
        plugins_config (Optional[aconfig.Config]): Plugin config passed
            through to the factory

    Returns:
        Type[DataBase]: The source class for this element type. Repeated
            calls with the same element type return the same class.
    """
    log.debug2("Looking for DataStreamSource[%s]", data_element_type)

    # A class generated by an earlier call is reused as-is
    if data_element_type in _DATA_STREAM_SOURCE_TYPES:
        return _DATA_STREAM_SOURCE_TYPES[data_element_type]

    cls_name = _make_data_stream_source_type_name(data_element_type)
    package = get_service_package_name()
    log.debug("Creating DataStreamSource[%s] -> %s", data_element_type, cls_name)

    plugins = plugin_factory.get_plugins(plugins_config)

    # Two plugins must not claim the same oneof field name
    field_name_by_plugin = {
        plugin: plugin.get_field_name(data_element_type) for plugin in plugins
    }
    claimed_names = list(field_name_by_plugin.values())
    duplicates = {
        plugin.name: field_name
        for plugin, field_name in field_name_by_plugin.items()
        if claimed_names.count(field_name) > 1
    }
    error.value_check(
        "<RUN66854455E>",
        not duplicates,
        "Duplicate plugin field names found for type {}: {}",
        data_element_type,
        duplicates,
    )

    # Resolve each plugin's message type a single time up front, since doing
    # so can be expensive
    stream_message_types = {
        plugin.name: plugin.get_stream_message_type(data_element_type)
        for plugin in plugins
    }

    # One annotated oneof member per plugin; together they form the Union
    # that describes the `data_stream` oneof
    oneof_members = [
        Annotated[
            stream_message_types[plugin.name],
            OneofField(plugin.get_field_name(data_element_type)),
            FieldNumber(plugin.get_field_number()),
        ]
        for plugin in plugins
    ]
    data_stream_type_union = Union[tuple(oneof_members)]

    # Expose every source message type as an attribute of the generated
    # class, e.g. with the `JsonData` plugin enabled:
    # >>> make_data_stream_source(some_type).JsonData
    type_attrs = {
        msg_type.__name__: msg_type for msg_type in stream_message_types.values()
    }
    class_attrs = {
        "ELEMENT_TYPE": data_element_type,
        "PLUGINS": plugins,
        **type_attrs,
    }
    data_object = make_dataobject(
        package=package,
        name=cls_name,
        bases=(DataStreamSourceBase,),
        attrs=class_attrs,
        annotations={"data_stream": data_stream_type_union},
    )

    # Publish the class on the common data model and on the module it
    # reports as its own
    setattr(caikit.interfaces.common.data_model, cls_name, data_object)
    setattr(sys.modules[data_object.__module__], cls_name, data_object)

    # Replace __init__ so that the DataBase-side init runs first and the
    # DataStreamSourceBase init runs after it
    orig_init = _make_oneof_init(data_object)

    def __init__(self, *args, **kwargs):
        try:
            orig_init(self, *args, **kwargs)
        except TypeError as err:
            raise CaikitRuntimeException(
                grpc.StatusCode.INVALID_ARGUMENT, str(err)
            ) from err
        DataStreamSourceBase.__init__(self)

    data_object.__init__ = __init__

    _DATA_STREAM_SOURCE_TYPES[data_element_type] = data_object
    return data_object
[docs]
def _make_data_stream_source_type_name(data_element_type: Type) -> str:
"""Make the name for data stream source class that wraps the given type"""
element_name = data_element_type.__name__
return "DataStreamSource{}".format(element_name[0].upper() + element_name[1:])