Source code for caikit.core.augmentors.base

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Standard
import random

# First Party
import alog

# Local
from ..exceptions import error_handler

log = alog.use_channel("AUGMENT_BASE")
error = error_handler.get(log)


class AugmentorBase:
    """The base class from which all Augmentors inherit.

    An augmentor is a bit different from a module in that it is subject to
    the following constraints / design considerations.

    - An augmentor must have the same input type as output type. Currently,
      this should only be data model objects / strings only.
    - An augmentor only takes one argument at run time, which is the object
      being augmented. Everything else should be contained within the class.

    This is necessary so that augmentors can be merged. Augmentors can only
    be combined if they have the same input/output object type.

    This base class never assigns ``augmentor_type`` or defines
    ``_augment``; ``augment`` reads both from ``self``, so concrete
    augmentors must provide them.
    """

    # Class-level defaults so that reset() is a safe no-op on instances that
    # never captured a random state (previously reset() raised AttributeError
    # whenever the seeding branch of __init__ had not run).
    _is_random = False
    _init_state = None

    def __init__(self, random_seed, produces_none=False):
        """Validate the arguments and snapshot the seeded random state.

        Args:
            random_seed (int): Seed applied to Python's builtin ``random``
                module. The state right after seeding is stored so that
                ``reset`` can restore it later.
            produces_none (bool | None): Whether ``augment`` may return
                ``None`` instead of an object of the augmentor type.
        """
        error.type_check("<COR73155110E>", int, random_seed=random_seed)
        error.type_check(
            "<COR73301070E>", bool, allow_none=True, produces_none=produces_none
        )
        if random_seed is not None:
            # NOTE: this seeds the process-wide generator of the `random`
            # module, not a generator private to this augmentor.
            random.seed(random_seed)
            self._is_random = True
            self._init_state = random.getstate()
        self.produces_none = produces_none

    def augment(self, inp_obj):
        """Take an object in, give an object back. Calls ._augment in the
        subclass.

        Args:
            inp_obj (str | caikit.core.data_model.DataBase): Object to be
                augmented.

        Returns:
            str | caikit.core.data_model.DataBase: Augmented object of same
                type as input inp_obj. May be ``None`` when the augmentor was
                created with ``produces_none`` set to a truthy value.

        Raises:
            TypeError: If ``inp_obj`` is not an instance of
                ``self.augmentor_type``. A ``TypeError`` describing an output
                type mismatch is handed to the module's error handler.
        """
        if not isinstance(inp_obj, self.augmentor_type):
            raise TypeError(
                "Input type [{}] is misaligned with augmentor type [{}]".format(
                    type(inp_obj), self.augmentor_type
                )
            )
        out_obj = self._augment(inp_obj)
        # Only some augmentors are allowed to return None.
        if out_obj is None and self.produces_none:
            return out_obj
        if not isinstance(out_obj, self.augmentor_type):
            error(
                "<COR72611702E>",
                TypeError(
                    "Output type [{}] is misaligned with augmentor type [{}]".format(
                        type(out_obj), self.augmentor_type
                    )
                ),
            )
        return out_obj

    def reset(self):
        """Reset random number generation for the current augmentor.

        Note that this currently assumes the augmentor is using the builtin
        random generator leveraged by Python; if you end up using something
        else, you may want to override this or restructure this base class to
        allow resetting of random states based on seed type.

        Does nothing when no random state was captured at construction time.
        """
        if self._is_random:
            random.setstate(self._init_state)