# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module is a central entrypoint for running a single synchronous training
job using caikit.core.train
"""
# Standard
from pathlib import Path
from typing import Type
import argparse
import importlib
import json
import os
import sys
import traceback
# Third Party
from google.protobuf import json_format
# First Party
import alog
# Local
from ..core import ModuleBase, train
from ..core.data_model import TrainingStatus
from ..core.exceptions import error_handler
from ..core.registries import module_registry
from ..core.toolkit.logging import configure as config_logging
from .names import get_service_package_name
from .service_factory import ServicePackageFactory
from .utils.servicer_util import build_caikit_library_request_dict
# Log channel used by everything in this module
log = alog.use_channel("TRAIN")
# Error-checking helper bound to the log channel above
error = error_handler.get(log)
# USER_ERROR_EXIT_CODE is the process exit code used when the process must
# exit as a result of a user input error. User-related exit codes should be
# >= 1 and <= 127 due to how some kubernetes operators interpret them.
USER_ERROR_EXIT_CODE = 1
# INTERNAL_ERROR_EXIT_CODE is the process exit code used when training
# terminates abnormally and it is not clearly the fault of the user.
# System-level exit codes should be >= 128 and <= 254.
INTERNAL_ERROR_EXIT_CODE = 203
[docs]
class ArgumentParserError(Exception):
    """Signals that command line argument parsing failed.

    Raised in place of argparse's default behavior of printing usage and
    exiting the process.
    """
[docs]
class TrainArgumentParser(argparse.ArgumentParser):
    """ArgumentParser whose parse errors are raised as exceptions."""

    def error(self, message):
        """Raise an ArgumentParserError describing the problem rather than
        terminating the process.
        """
        formatted = f"{self.prog}: error: {message}"
        raise ArgumentParserError(formatted)
[docs]
def write_termination_log(text: str, log_file: str, enabled: bool):
    """Append a termination message to the termination log file.

    This is best-effort: any failure to open or write the file is logged as
    a warning and otherwise ignored.

    Args:
        text (str): The message to append
        log_file (str): Path of the file to append to
        enabled (bool): When falsy, nothing is written
    """
    if enabled:
        try:
            with open(log_file, "a") as out_stream:
                out_stream.write(text)
        except Exception as err:
            log.warning(
                "<COR96300323W>",
                "Unable to write termination log due to error %s",
                err,
            )
# Final tasks before exiting the container
[docs]
def exit_complete(
    exit_code: int,
    save_path: str,
    message: str,
    termination_log_file: str,
    enable_termination_log: bool,
):
    """Perform the final bookkeeping for a training job and exit the process.

    Args:
        exit_code (int): The exit code for the process. A non-zero value
            causes the message to be written to the termination log.
        save_path (str): Directory in which to create the ".complete" marker
            file. If falsy, no marker file is created.
        message (str): Message for the termination log (only used when
            exit_code is non-zero)
        termination_log_file (str): Path of the termination log file
        enable_termination_log (bool): Whether the termination log should be
            written

    Raises:
        SystemExit: Always, carrying exit_code
    """
    if exit_code != 0:
        write_termination_log(message, termination_log_file, enable_termination_log)
    if save_path:
        # Creating the marker file is best-effort; a failure here must not
        # mask the real exit code
        try:
            complete_path = os.path.join(save_path, ".complete")
            log.info("Creating completion file at: %s", complete_path)
            Path(complete_path).touch()
        except Exception as e:
            log.warning("Unable to write completion file due to execption: %s", e)
    # Use sys.exit rather than the builtin exit(): the latter is injected by
    # the site module for interactive use and is not guaranteed to exist
    # (e.g. with `python -S` or in frozen applications)
    sys.exit(exit_code)
[docs]
def _str_to_bool(value, default: bool = True) -> bool:
    """Interpret an environment-variable / command-line value as a boolean.

    Args:
        value: A bool (returned as-is), None (yields default), or any value
            whose string form is checked case-insensitively
        default (bool): Result for None and for unrecognized strings

    Returns:
        bool: False for "", "0", "false", "f", "no", "n" and "off"; True for
            "1", "true", "t", "yes", "y" and "on"; otherwise default
    """
    if isinstance(value, bool):
        return value
    if value is None:
        return default
    normalized = str(value).strip().lower()
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    return default


def _report_failure_and_exit(
    log_code: str,
    message: str,
    exit_code: int,
    save_path,
    termination_log_file: str,
    enable_termination_log: bool,
):
    """Log a failure with the active exception's stack trace, then exit.

    Must be called from inside an `except` block so that the traceback of the
    exception being handled is available.
    """
    log.warning(
        {
            "log_code": log_code,
            "message": message,
            "stack_trace": traceback.format_exc(),
        },
        exc_info=True,
    )
    exit_complete(
        exit_code,
        save_path,
        message,
        termination_log_file,
        enable_termination_log,
    )


def main() -> int:
    """Main entrypoint for running training jobs

    Parses the command line, imports any requested libraries, resolves the
    module to train, converts the json training kwargs to the types expected
    by the module's train signature, and runs a single synchronous training.

    Every outcome is reported through exit_complete: USER_ERROR_EXIT_CODE for
    problems found before training starts, INTERNAL_ERROR_EXIT_CODE for
    failures during training and 0 on success.
    """
    parser = TrainArgumentParser(description=__doc__)

    # Set default values for the termination log in case parsing the
    # arguments fails later on. The env var arrives as a string, so it is
    # converted explicitly; otherwise a value such as "false" would be truthy.
    enable_termination_log = _str_to_bool(
        os.environ.get("ENABLE_TERMINATION_LOG"), default=True
    )
    termination_log_file = os.environ.get(
        "TERMINATION_LOG_FILE", "/dev/termination-log"
    )

    # Required Args
    parser.add_argument(
        "--training-kwargs",
        "-k",
        required=True,
        help="Json string or json file pointer with keyword args for the training job",
    )
    parser.add_argument(
        "--module",
        "-m",
        required=True,
        help="Module name (package.Class) or UID to train",
    )
    parser.add_argument(
        "--model-name",
        "-n",
        required=True,
        help="Name to save the model under",
    )

    # Optional args
    parser.add_argument(
        "--save-path",
        "-s",
        default=".",
        help="Path to save the output model to",
    )
    parser.add_argument(
        "--library",
        "-l",
        nargs="*",
        help="Libraries that need to be imported to register the module to train",
    )
    parser.add_argument(
        "--trainer",
        "-t",
        default=None,
        help="Trainer config name to use",
    )
    parser.add_argument(
        "--save-with-id",
        "-i",
        action="store_true",
        default=False,
        help="Include the training ID in the save path",
    )
    parser.add_argument(
        "--termination-log-file",
        "-f",
        default=termination_log_file,
        help="Location of where to write a termination error message",
    )
    parser.add_argument(
        "--enable-termination-log",
        "-e",
        type=_str_to_bool,
        default=enable_termination_log,
        help="Whether to enable writing to termination log when training fails",
    )

    try:
        args = parser.parse_args()
        config_logging()

        # Previously we grabbed the termination log values from env variables
        # (if present). Here, the parsed values override them. If the parser
        # raised while parsing any of the args, the values captured above are
        # used instead. The enable flag is already a bool at this point, so a
        # parsed False is honored.
        enable_termination_log = args.enable_termination_log
        if args.termination_log_file:
            termination_log_file = args.termination_log_file

        # Initialize top-level kwargs
        train_kwargs = {
            "save_path": args.save_path,
            "save_with_id": args.save_with_id,
            "model_name": args.model_name,
        }
        if args.trainer is not None:
            train_kwargs["trainer"] = args.trainer
    except Exception as e:
        # We couldn't parse args, so save_path cannot be passed in
        _report_failure_and_exit(
            "<COR39662029E>",
            f"Exception raised during training. This may be a problem with your input: {e}",
            USER_ERROR_EXIT_CODE,
            None,
            termination_log_file,
            enable_termination_log,
        )

    # Import libraries to register modules
    try:
        for library in args.library or []:
            log.info("<COR88091092I>", "Importing library %s", library)
            importlib.import_module(library)
    except Exception:
        _report_failure_and_exit(
            "<COR17776539E>",
            f"Unable to import module {library}",
            USER_ERROR_EXIT_CODE,
            args.save_path,
            termination_log_file,
            enable_termination_log,
        )

    # Try to import the root library of the provided module. It's ok if this
    # fails since the module may be a UID
    try:
        mod_root_lib = args.module.split(".")[0]
        importlib.import_module(mod_root_lib)
    except (ImportError, ValueError):
        log.debug("Unable to import module root lib: %s", mod_root_lib)

    # Figure out the module to train, looking it up first by registry key and
    # then by fully qualified "package.Class" name
    try:
        mod_reg = module_registry()
        mod_pkg_to_mod = {
            f"{mod.__module__}.{mod.__name__}": mod for mod in mod_reg.values()
        }
        module: Type[ModuleBase] = mod_reg.get(
            args.module, mod_pkg_to_mod.get(args.module)
        )
        error.value_check(
            "<COR03876205E>",
            module is not None,
            "Unable to find module {} to train",
            args.module,
        )
    except Exception:
        _report_failure_and_exit(
            "<COR17476539E>",
            f"Unable to find module {args.module} to train",
            USER_ERROR_EXIT_CODE,
            args.save_path,
            termination_log_file,
            enable_termination_log,
        )

    # Read training kwargs
    try:
        if os.path.isfile(args.training_kwargs):
            with open(args.training_kwargs, encoding="utf-8") as handle:
                training_kwargs = json.load(handle)
        else:
            training_kwargs = json.loads(args.training_kwargs)

        # Convert datatypes to match the training API
        training_service = ServicePackageFactory.get_service_package(
            ServicePackageFactory.ServiceType.TRAINING,
        )
        train_rpcs = [
            rpc
            for rpc in training_service.caikit_rpcs.values()
            if rpc.module_list == [module]
        ]
        error.value_check(
            "<COR11978965E>",
            len(train_rpcs) == 1,
            "Unable to find a unique train signature",
        )
        package_name = get_service_package_name(
            ServicePackageFactory.ServiceType.TRAINING
        )
        train_rpc_req = (
            train_rpcs[0].create_request_data_model(package_name).get_proto_class()
        )
        request_proto = json_format.Parse(
            json.dumps({"parameters": training_kwargs}),
            train_rpc_req(),
        )
        req_kwargs = build_caikit_library_request_dict(
            request_proto.parameters, module.TRAIN_SIGNATURE
        )
        train_kwargs.update(req_kwargs)
        log.debug3("All train kwargs: %s", train_kwargs)
    except json.decoder.JSONDecodeError:
        # NOTE: JSONDecodeError is a ValueError subclass, so this handler
        # must stay ahead of the ValueError handler below
        _report_failure_and_exit(
            "<COR65834760E>",
            "training-kwargs must be valid json or point to a valid json file",
            USER_ERROR_EXIT_CODE,
            args.save_path,
            termination_log_file,
            enable_termination_log,
        )
    except ValueError as e:
        # Invalid parameters are a user error: exit rather than attempting
        # to train with incomplete kwargs
        _report_failure_and_exit(
            "<COR65474760E>",
            f"Invalid value for one or more input parameters: {e}",
            USER_ERROR_EXIT_CODE,
            args.save_path,
            termination_log_file,
            enable_termination_log,
        )
    except Exception:
        _report_failure_and_exit(
            "<COR17776549E>",
            "Exception encountered when attempting to parse input parameters",
            USER_ERROR_EXIT_CODE,
            args.save_path,
            termination_log_file,
            enable_termination_log,
        )

    try:
        # Run the training
        with alog.ContextTimer(
            log.info,
            "Finished training %s in: ",
            args.model_name,
        ):
            future = train(module, wait=True, **train_kwargs)

        info = future.get_info()
        if info.status == TrainingStatus.COMPLETED:
            log.info(
                {
                    "log_code": "<COR74526958I>",
                    "message": "Training finished successfully",
                }
            )
            exit_complete(0, args.save_path, None, None, None)
        else:
            log.warning(
                {
                    "log_code": "<COR72523958E>",
                    "message": "Training finished unsuccessfully",
                }
            )
            for err in info.errors or []:
                log.error(err)
            exit_complete(
                INTERNAL_ERROR_EXIT_CODE,
                args.save_path,
                "Training finished unsuccessfully",
                termination_log_file,
                enable_termination_log,
            )
    except MemoryError:
        _report_failure_and_exit(
            "<COR04280062E>",
            "OOM error during training",
            INTERNAL_ERROR_EXIT_CODE,
            args.save_path,
            termination_log_file,
            enable_termination_log,
        )
    except Exception:
        _report_failure_and_exit(
            "<COR04280062E>",
            "Unhandled exception during training",
            INTERNAL_ERROR_EXIT_CODE,
            args.save_path,
            termination_log_file,
            enable_termination_log,
        )
# Script entrypoint: run the training job and hand main()'s result to sys.exit
if __name__ == "__main__":  # pragma: no cover
    exit_status = main()
    sys.exit(exit_status)