Source code for caikit.core.toolkit.concurrency.destroyable_thread

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Standard
from typing import Optional
import ctypes
import sys
import threading
import traceback

# First Party
import alog

# Local
from .destroyable import Destroyable

log = alog.use_channel("DESTROY-THRD")


[docs] class ThreadDestroyedException(RuntimeError): """Exception raised inside a DestroyableThread when it is destroyed by the thread managing its lifecycle.""" def __init__(self): super().__init__( "Work thread intentionally destroyed by its lifecycle manager. " "This exception was not raised by the code running in this thread." )
# pylint: disable=too-many-instance-attributes
[docs] class DestroyableThread(threading.Thread, Destroyable): """A class for Destroyable Threads. When work is delegated to a thread but may need to be canceled while in progress, we use this class which allows us to raise an exception inside the work thread. Exceptions raised this way are asynchronous and they will not interrupt the python instruction that the thread is currently executing. E.g. a time.sleep() will finish sleeping before the exception is raised. This class may be initialized with a threading event, which it will set when the thread finishes executing, whether nominally or by raising an exception. """ # The exception we'll throw to kill the thread __exception = ThreadDestroyedException def __init__( self, runnable_func, *runnable_args, work_done_event: Optional[threading.Event] = None, **runnable_kwargs, ): threading.Thread.__init__(self) self.work_done_event = work_done_event or threading.Event() # These describe the work to be done self.runnable_func = runnable_func self.runnable_args = runnable_args self.runnable_kwargs = runnable_kwargs # These describe what happened with the work self.__runnable_result = None self.__runnable_exception = None self.__threw = False self.__started = False self.__ran = False # In case `destroy` is called before Python has actually started the thread, we need to # know to not do the work self.__destroyed = False @property def destroyed(self) -> bool: return self.__destroyed @property def canceled(self) -> bool: return self.destroyed and ( (self.__started and (self.threw or not self.ran)) or not self.__started ) @property def ran(self) -> bool: return self.__ran @property def threw(self) -> bool: return self.__threw # Run wraps the supplied function with logic to set the event when it finishes, and save any # result or raised error
[docs] def run(self) -> None: """ Overrides Thread.run() *Do not call* Returns: None """ # Raise immediately if the thread was destroyed before if self.__destroyed: log.info( "<COR14653273I>", "Not starting work for %s, thread already cancelled", self.runnable_func, ) self.__raise() self.__started = True try: self.__runnable_result = self.runnable_func( *self.runnable_args, **self.runnable_kwargs ) self.__threw = False except: # noqa: E722 # bare-except # PEP8 complains, but in this case we really do want to re-throw _any_ exception that # occurred. In the interest of transparently wrapping any work in these threads, we # want to keep exception signatures identical. E.g. if I expect this thread to throw a # CaikitRuntimeException, I want to be able to catch a CaikitRuntimeException. # Rethrowing from `sys.exc_info()[1]` should retain all stack trace info later. e = sys.exc_info()[1] self.__runnable_exception = e self.__threw = True # Add a little bit of visibility to know why work failed if self.__destroyed: log.info( { "log_code": "<COR15827563I>", "message": "Work for {} was aborted and threw".format( self.runnable_func ), "stack_trace": traceback.format_exc(), } ) else: log.warning( { "log_code": "<COR16788843W>", "message": "Work for {} threw exception: {}".format( self.runnable_func, e ), "stack_trace": traceback.format_exc(), } ) finally: # Before setting the synchronization event, flag that the work was done self.__ran = True self.work_done_event.set()
[docs] def get_or_throw(self): """ After the thread has completed it's work, call this to get the output. Returns: The resulting value of runnable_func(*runnable_args, **runnable_kwargs) Raises: Any exception raised by runnable_func(*runnable_args, **runnable_kwargs) """ if self.destroyed: log.error( "<COR14653274E>", "get_or_throw called on destroyed thread for %s, no value to return", self.runnable_func, ) if not self.ran: log.error( "<COR14653275E>", "get_or_throw called on thread for %s, but it has not finished running", self.runnable_func, ) if self.threw: raise self.__runnable_exception return self.__runnable_result
[docs] def destroy(self) -> None: """ Cancel any in-progress work and kill the thread if it is alive. Otherwise, prevent the thread from running at all. Returns: None """ # Set the destroyed flag in case the thread has not started yet. # If it has, we should be able to kill it with the async exception below. self.__destroyed = True # The thread has already finished or is not yet alive, so we cannot kill it thread_id = self.__get_id() if thread_id is None or not self.is_alive(): log.debug( "<COR14653276D>", "Destroying thread that is not currently alive: %s", self.runnable_func, ) return # This is the code that raises an async exception in the target thread # (We can't just use raise, because the parent thread is in this control flow) async_exception_result = ctypes.pythonapi.PyThreadState_SetAsyncExc( ctypes.c_long(thread_id), ctypes.py_object(self.__exception) ) if async_exception_result > 1: log.error( "<COR14653277E>", "Could not raise async exception on destroyable thread for %s. Result code: %s", self.runnable_func, async_exception_result, )
@property def error(self) -> Optional[Exception]: if isinstance(self.__runnable_exception, Exception): return self.__runnable_exception def __get_id(self): # Returns the thread if if the thread is running for thread in threading.enumerate(): if thread is self: return thread.ident # Otherwise, the thread has completed or has not started return None def __raise(self): # __exception is just a type, we need to be sure to initialize a value of it raise self.__exception()