Source code for caikit.core.toolkit.concurrency.destroyable_process

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A DestroyableProcess implements a multiprocessing Process that captures errors
and communicates them to its parent with clean semantics for being destroyed.

NOTE: The "fork" start method is used for two reasons:
    1. It's faster
    2. For auto-generated classes (e.g. stream sources), "spawn" can result in
        missing classes since it performs a full re-import, but may not
        regenerate these classes.
"""
# Standard
from functools import partial
from typing import Any, Callable, Optional, Tuple
import multiprocessing
import os
import sys

# First Party
import alog

# Local
from .destroyable import Destroyable
from .pickling_exception import ExceptionPickler
from caikit.core.exceptions import error_handler

log = alog.use_channel("DESTROY-PROC")
error = error_handler.get(log)


OOM_EXIT_CODES = [137, 9, -9]

FORK_CTX = None if sys.platform == "win32" else multiprocessing.get_context("fork")
SPAWN_CTX = multiprocessing.get_context("spawn")
FORKSERVER_CTX = (
    None if sys.platform == "win32" else multiprocessing.get_context("forkserver")
)


[docs] class _DestroyableProcess( multiprocessing.process.BaseProcess, Destroyable, ): # pylint: disable=too-many-instance-attributes """The _DestroyableProcess base class implements a context-agnostic process class that manages the subprocess and allows it to be destroyed """ def __init__( self, target: Optional[Callable] = None, completion_event: Optional[multiprocessing.Event] = None, args: Optional[Tuple] = None, kwargs: Optional[dict] = None, destroy_grace_period: float = 10, return_result: bool = False, **_kwargs, ): """Initialize with an event to use to signal completion""" self._parent_conn, self._child_conn = self.__class__._MP_CTX.Pipe() self._completion_event = completion_event or self.__class__._MP_CTX.Event() self._destroy_grace_period = destroy_grace_period self._return_result = return_result # This will hold the terminal result of the process which will either be # an Exception or some other value self.__result = None # Descriptions of completion state self.__started = False self.__destroyed = False self.__canceled = False # Bind the target to the target wrapper wrapped_target = partial( self._target_wrapper, target=target, args=args, kwargs=kwargs, ) super().__init__(target=wrapped_target, **_kwargs) ## Destroyable Interface ## @property def destroyed(self) -> bool: return self.__destroyed @property def canceled(self) -> bool: return self.__canceled @property def ran(self) -> bool: return self.__started and not self.is_alive() @property def threw(self) -> bool: return self.error is not None
[docs] def get_or_throw(self) -> Any: """Get the result of the execution or raise an error if one occurred""" if self.destroyed: log.error( "<COR24981767E>", "get_or_throw called on destroyed process, no value to return", ) if not self.ran: log.error( "<COR12037430E>", "get_or_throw called on process, but it has not finished running", ) # Update the result and throw if it's an error err = self.error if err is not None: raise err return self.__result
[docs] def destroy(self): """Cancel any in-progress work""" self.__destroyed = True if self.is_alive() or not self.__started: self.__canceled = True if self.__started: self.terminate() self.join(self._destroy_grace_period) self.kill() self.join() self._completion_event.set()
## Process Interface ##
[docs] def start(self): if self.destroyed: err_msg = "Not starting work on pre-canceled process" log.warning("<COR42191929W>", err_msg) self._completion_event.set() self.__result = RuntimeError(err_msg) return self.__started = True return super().start()
[docs] def join(self, *args, **kwargs): if self.destroyed and not self.__started: return return super().join(*args, **kwargs)
# NOTE: This functionality is covered by the unit tests, but pytest-cov is # not correctly collecting the coverage info since it executes inside the # subprocess
[docs] def run(self): # pragma: no cover try: # Run and indicate to the parent that no super().run() # Catch any errors thrown within a subprocess so that they can be # forwarded to the parent # pylint: disable=broad-exception-caught except Exception as err: err_str = repr(err) log.error( "<COR69863806E>", "Caught exception in destroyable process: %s", err_str, exc_info=True, ) # Wrap error for safe pickling pickler = ExceptionPickler(err) self._child_conn.send(pickler) finally: self._completion_event.set()
## Impl ##
[docs] def _update_result(self): if self._parent_conn.poll(): self.__result = self._parent_conn.recv()
@property def error(self) -> Optional[Exception]: self._update_result() if isinstance(self.__result, ExceptionPickler): return self.__result.get() if isinstance(self.__result, Exception): return self.__result if self.exitcode and self.exitcode != os.EX_OK: if self.exitcode in OOM_EXIT_CODES: return MemoryError("Training process died with OOM error!") if not self.canceled: return RuntimeError( f"Training process died with exit code {self.exitcode}" ) @property def completion_event(self) -> multiprocessing.Event: return self._completion_event # NOTE: This functionality is covered by the unit tests, but pytest-cov is # not correctly collecting the coverage info since it executes inside the # subprocess
[docs] def _target_wrapper(self, target, args, kwargs): # pragma: no cover result = target(*(args or []), **(kwargs or {})) log.debug3("Process target result: %s", result) if self._return_result: self._child_conn.send(result) self._completion_event.set()
_PROCESS_TYPES = {} if FORK_CTX is not None:
[docs] class _ForkDestroyableProcess(FORK_CTX.Process, _DestroyableProcess): _MP_CTX = FORK_CTX
_PROCESS_TYPES["fork"] = _ForkDestroyableProcess if SPAWN_CTX is not None:
[docs] class _SpawnDestroyableProcess(SPAWN_CTX.Process, _DestroyableProcess): _MP_CTX = SPAWN_CTX
_PROCESS_TYPES["spawn"] = _SpawnDestroyableProcess if FORKSERVER_CTX is not None:
[docs] class _ForkserverDestroyableProcess(FORKSERVER_CTX.Process, _DestroyableProcess): _MP_CTX = FORKSERVER_CTX
_PROCESS_TYPES["forkserver"] = _ForkserverDestroyableProcess
[docs] def DestroyableProcess(start_method: str, *args, **kwargs): """Class wrapper that returns the appropriate process type based on the requested start method. NOTE: Naming intentionally looks like a class! """ error.value_check( "<COR16699811E>", start_method in _PROCESS_TYPES, "Unsupported start_method on {}: {}. Supported options are {}", sys.platform, start_method, list(_PROCESS_TYPES.keys()), ) return _PROCESS_TYPES[start_method](*args, **kwargs)