Source code for caikit.runtime.work_management.abortable_context

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Standard
from queue import SimpleQueue
from typing import Dict, Optional
import abc
import ctypes
import threading
import uuid

# First Party
import alog

# Local
from caikit.runtime.types.aborted_exception import AbortedException

# Logging channel ("ABORT-ACTION") shared by every class in this module
log = alog.use_channel("ABORT-ACTION")


class ActionAborter(abc.ABC):
    """Abstract notification hook telling an action that it has to abort.

    Concrete subclasses can attach themselves to whatever notification
    mechanism is available (a grpc context, for example) and relay the abort
    request to the `AbortableContext` they have been handed.
    """

    @abc.abstractmethod
    def must_abort(self) -> bool:
        """Report whether the action should be aborted.

        Returns:
            bool: True when the action must abort, False otherwise.
        """

    @abc.abstractmethod
    def set_context(self, context: "AbortableContext"):
        """Attach the abortable context that has to be told to abort its work.

        Args:
            context (AbortableContext): The context to notify.
        """

    @abc.abstractmethod
    def unset_context(self):
        """Detach any abortable context currently held.

        The detached context must NOT be notified that work should abort.
        """
class ThreadInterrupter:
    """This class implements a listener which will observe all ongoing work in
    `AbortableContexts` and raise exceptions in the working threads if they
    need to be aborted.

    The implementation spawns a single extra thread to wait on any contexts to
    abort, and interrupt the thread that the context is running in. This keeps
    the total number of running threads much smaller than using a new thread to
    monitor each AbortableContext.
    """

    # Sentinel placed on the queue to tell the watch loop to exit
    _SHUTDOWN_SIGNAL = -1

    def __init__(self):
        # Using a SimpleQueue because we don't need the Queue's task api
        self._queue = SimpleQueue()
        # The watch thread; None until start() has been called
        self._thread: Optional[threading.Thread] = None
        # Maps a registered context id to the ident of the thread running it
        self._context_thread_map: Dict[uuid.UUID, int] = {}
        # Serializes start() / stop() so the watch thread is managed atomically
        self._start_stop_lock = threading.Lock()

    def start(self):
        """Start the watch loop that will abort any registered contexts passed
        to .kill(). Does nothing if the watch thread is already running.
        """
        with self._start_stop_lock:
            if self._thread and self._thread.is_alive():
                log.debug("ThreadInterrupter already started")
                return

            log.debug("Starting ThreadInterrupter")
            self._thread = threading.Thread(target=self._watch_loop)
            self._thread.start()

    def stop(self):
        """Stop the watch loop.

        Safe to call when the interrupter was never started or has already
        stopped; in both cases nothing is queued and nothing is joined.
        """
        with self._start_stop_lock:
            if self._thread is None:
                # Never started: queueing a shutdown signal here would make a
                # later start() exit its watch loop immediately, and there is
                # no thread to join.
                log.debug("ThreadInterrupter was never started")
                return
            if not self._thread.is_alive():
                log.debug("ThreadInterrupter already shut down")
                return

            log.info("Stopping ThreadInterrupter")
            self._queue.put(self._SHUTDOWN_SIGNAL)
            # Bounded wait so that shutdown cannot hang on the watch thread
            self._thread.join(timeout=1)

    def register(self, context_id: uuid.UUID, thread: int) -> None:
        """Record that the context `context_id` is running in the thread whose
        ident is `thread`.
        """
        self._context_thread_map[context_id] = thread

    def unregister(self, context_id: uuid.UUID) -> None:
        """Forget the context `context_id`. Unknown ids are ignored."""
        self._context_thread_map.pop(context_id, None)

    def kill(self, context_id: uuid.UUID) -> None:
        """Request that the thread running context `context_id` be aborted."""
        # Put this context onto the queue for abortion and immediately return
        self._queue.put(context_id, block=False)

    def _watch_loop(self):
        """Body of the watch thread: block on the queue and interrupt the
        thread of every context id received, until the shutdown signal arrives.
        """
        while True:
            log.debug("Waiting on any work to abort")
            context_id = self._queue.get()

            if context_id == self._SHUTDOWN_SIGNAL:
                log.debug("Ending abort watch loop")
                return

            self._kill_thread(context_id)

            # Ensure this context/thread pair is unregistered
            self.unregister(context_id)

    def _kill_thread(self, context_id: uuid.UUID) -> bool:
        """Raise an AbortedException in the thread registered for `context_id`.

        Returns:
            bool: True if the exception was scheduled in exactly one thread,
                False if the context is not registered or the interpreter did
                not accept the request.
        """
        thread_id = self._context_thread_map.get(context_id, None)
        if thread_id is None:
            log.warning("AbortableWork context already unregistered")
            return False

        log.debug("Interrupting thread id: %s", thread_id)
        # This raises an AbortedException asynchronously in the target thread.
        # (We can't just use raise, because this thread is the
        # ThreadInterrupter's watch thread). The exception will only be raised
        # once the target thread regains control of the python interpreter.
        # This means that statements like `time.sleep(9999999)` cannot be
        # interrupted in this manner.
        # The C API takes the thread id as an unsigned long.
        native_thread_id = ctypes.c_ulong(thread_id)
        modified_count = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            native_thread_id, ctypes.py_object(AbortedException)
        )

        # The call returns the number of thread states modified: exactly one
        # is the only successful outcome.
        if modified_count == 1:
            return True

        if modified_count == 0:
            # The thread id was not found, e.g. the thread already finished
            log.warning(
                "Failed to abort thread: thread id %s not found", thread_id
            )
            return False

        # More than one thread state was modified: revert by clearing the
        # pending exception (None is passed to C as a NULL pointer).
        ctypes.pythonapi.PyThreadState_SetAsyncExc(native_thread_id, None)
        log.warning("Failed to abort thread")
        return False
class AbortableContext:
    """Context manager for running work inside a context where it's safe to
    abort.

    This is a class instead of a `@contextmanager` function because __exit__
    needs to happen on exception.
    """

    def __init__(self, aborter: ActionAborter, interrupter: ThreadInterrupter):
        """Setup the context.

        The aborter is responsible for notifying this context if the work
        needs to be aborted. The interrupter watches all such events, and
        kills the thread running in this context if the aborter notifies it to
        abort.

        Args:
            aborter (ActionAborter): Notifies this context when to abort.
            interrupter (ThreadInterrupter): Interrupts the working thread.
        """
        self.aborter = aborter
        self.interrupter = interrupter
        # Unique key under which this context is registered with the interrupter
        self.id = uuid.uuid4()

    def __enter__(self):
        """Make the current thread abortable for the duration of the `with`.

        If either the aborter or the interrupter is falsy, nothing is set up
        and the body simply runs without abort support.

        Returns:
            AbortableContext: This context, so `with ... as ctx` binds it.
        """
        if self.aborter and self.interrupter:
            log.debug4("Entering abortable context %s", self.id)
            # Register this context with the interrupter first, so that it
            # knows which thread to kill. Doing this before handing the
            # context to the aborter means an abort that fires as soon as the
            # aborter holds the context cannot be dropped as "unregistered".
            self.interrupter.register(self.id, threading.get_ident())
            try:
                # Set this context on the aborter so that it can notify us
                # when work should be aborted
                self.aborter.set_context(self)
            except BaseException:
                # __exit__ is not called when __enter__ raises, so undo the
                # registration here rather than leaving a stale entry behind
                self.interrupter.unregister(self.id)
                raise
        else:
            log.debug4(
                "Aborter or Interrupter was None, no abortable context created."
            )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Tear down abort support. Returns None, so any exception raised in
        the `with` body propagates to the caller.
        """
        if self.aborter and self.interrupter:
            # On any exit, whether an exception or not, we unregister with the
            # interrupter. This prevents the interrupter from aborting this
            # thread once this context has ended
            self.interrupter.unregister(self.id)
            self.aborter.unset_context()

    def abort(self):
        """Called by the aborter when this context needs to be aborted"""
        if self.interrupter:
            self.interrupter.kill(self.id)