caikit.core.toolkit.quality_evaluation
Evaluate quality of models.
Attributes
- error
Classes
- EvalTypes: Enum that contains the set of all possible evaluation types.
- F1Metrics
- QualityEvaluator: Class that holds all evaluation logic for now. May eventually be broken up into subclasses.
Module Contents
- caikit.core.toolkit.quality_evaluation.error
- class caikit.core.toolkit.quality_evaluation.EvalTypes(*args, **kwds)
Bases: enum.Enum
Enum that contains the set of all possible evaluation types.
- SINGLELABEL_MULTICLASS = 1
- MULTILABEL_MULTICLASS = 2
- MULTILABEL_MULTICLASS_HIERARCHICAL = 3
- class caikit.core.toolkit.quality_evaluation.F1Metrics
- true_positive: int | None = None
- false_positive: int | None = None
- false_negative: int | None = None
- precision: float | None = None
- recall: float | None = None
- f1: float | None = None
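F1Metrics keeps the raw counts next to the scores derived from them. The sketch below uses a hypothetical stand-in dataclass (F1MetricsSketch) and helper (fill_scores), not part of this module, to show the standard relationships precision = TP / (TP + FP), recall = TP / (TP + FN) and F1 = 2PR / (P + R). Returning 0.0 when a denominator is zero is an assumption of the sketch; the library may handle that case differently.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class F1MetricsSketch:
    """Hypothetical stand-in for F1Metrics: raw counts plus derived scores."""

    true_positive: Optional[int] = None
    false_positive: Optional[int] = None
    false_negative: Optional[int] = None
    precision: Optional[float] = None
    recall: Optional[float] = None
    f1: Optional[float] = None


def fill_scores(m: F1MetricsSketch) -> F1MetricsSketch:
    """Fill precision, recall and F1 from the counts; undefined ratios become 0.0 (assumption)."""
    tp = m.true_positive or 0
    fp = m.false_positive or 0
    fn = m.false_negative or 0
    m.precision = tp / (tp + fp) if tp + fp else 0.0
    m.recall = tp / (tp + fn) if tp + fn else 0.0
    total = m.precision + m.recall
    # F1 is the harmonic mean of precision and recall.
    m.f1 = 2 * m.precision * m.recall / total if total else 0.0
    return m


m = fill_scores(F1MetricsSketch(true_positive=8, false_positive=2, false_negative=4))
print(m.precision, round(m.recall, 4), round(m.f1, 4))  # → 0.8 0.6667 0.7273
```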
- class caikit.core.toolkit.quality_evaluation.QualityEvaluator(gold, pred)
Class that holds all evaluation logic for now. May eventually be broken up into subclasses.
- gold
- pred
- run(evaluation_type, find_label_func=None, find_label_data_func=None, detailed_metrics=False, labels=None, partial_match_metrics=False, max_hierarchy_levels=3)
Main entry point for evaluation.
- Args:
- evaluation_type (str): Which type of evaluation to run. Only a few types are currently supported; EvalTypes lists all possible evaluation types.
- find_label_func: function to fetch the labels from any one prediction, used in multiclass multilabel evaluation. For example, if a prediction has the form (token, label), this function should tell the evaluator how to extract the class label from the prediction; in this case, it returns the second element of the tuple.
- find_label_data_func: function to fetch the predictions that belong to a certain label, used only in the multiclass multilabel eval type. For example, if the predictions for a data example look like [(tok1, labX), (tok2, labY), (tok3, labX)], the function should return all predictions with a given label; for labX it should return [(tok1, labX), (tok3, labX)].
- detailed_metrics: flag indicating whether to compute detailed metrics (currently only for the multiclass multilabel eval type). Detailed metrics give metrics for every example, as well as metrics computed with a custom partial match function.
- labels: list (Optional, defaults to None). Optional list of class labels to evaluate quality on. By default, evaluation is done over all class labels; use this to restrict the quality evaluation to a subset of labels.
- partial_match_metrics: flag indicating whether to compute partial match micro-averaged metrics (currently only for the multiclass multilabel eval type).
- max_hierarchy_levels (int): Used in hierarchical multilabel multiclass evaluation only. The number of levels in the hierarchy to run model evaluation on, in addition to complete matches.
- Returns:
dict: Full results from evaluation on dataset and model.
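The two helper callables described above can be illustrated for predictions of the form (token, label). This is a minimal sketch: the argument order of find_label_data_func (the predictions of one example, then a label) is an assumption, so check how the evaluator actually calls it before relying on it. Functions like these would be passed as the find_label_func and find_label_data_func arguments of run().

```python
from typing import List, Tuple

# A prediction is assumed to be a (token, label) tuple, as in the docstring example.
Prediction = Tuple[str, str]


def find_label_func(prediction: Prediction) -> str:
    """Extract the class label from one (token, label) prediction."""
    return prediction[1]


def find_label_data_func(predictions: List[Prediction], label: str) -> List[Prediction]:
    """Return every prediction of one example that carries the given label."""
    # The (predictions, label) argument order is an assumption for this sketch.
    return [p for p in predictions if find_label_func(p) == label]


example = [("tok1", "labX"), ("tok2", "labY"), ("tok3", "labX")]
print(find_label_data_func(example, "labX"))
# → [('tok1', 'labX'), ('tok3', 'labX')]
```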
- singlelabel_multiclass_evaluation(labels=None) → dict
Obtain results of evaluation for a single-label, multi-class model.
- Args:
Note: the class should be initialized with gold and pred in the following format:
self.gold (list): list of gold labels for every example, where each example can have only one label, e.g., ['label1', 'label2', 'label3', 'label4']
self.pred (list): labels predicted by the model for every example.
- labels: list (Optional, defaults to None). Optional list of class labels to evaluate quality on. By default, evaluation is done over all class labels; use this to restrict the quality evaluation to a subset of labels.
- Returns:
- dict: Dictionary of the form:
{
  'per_class_confusion_matrix': {'entity_type': {'true_positive': int, ...}},
  'macro_precision': 0 <= float <= 1,
  'macro_recall': 0 <= float <= 1,
  'macro_f1': 0 <= float <= 1,
  'micro_precision': 0 <= float <= 1,
  'micro_recall': 0 <= float <= 1,
  'micro_f1': 0 <= float <= 1,
  'overall_tp': int,
  'overall_fp': int,
  'overall_fn': int
}
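The per_class_confusion_matrix entry can be pictured with a small counting sketch for the single-label case. per_class_counts is a hypothetical helper, not the library's implementation; ignoring labels outside the requested subset is an assumption about how the labels argument behaves.

```python
from collections import Counter
from typing import Dict, List, Optional

KEYS = ("true_positive", "false_positive", "false_negative")


def per_class_counts(
    gold: List[str], pred: List[str], labels: Optional[List[str]] = None
) -> Dict[str, Dict[str, int]]:
    """Count TP/FP/FN per label when every example has exactly one label."""
    if len(gold) != len(pred):
        raise ValueError("gold and pred must hold one label per example")
    # Evaluate all observed labels unless a subset is requested.
    keep = set(labels) if labels is not None else set(gold) | set(pred)
    counts = {label: Counter() for label in keep}
    for g, p in zip(gold, pred):
        if g == p:
            if g in keep:
                counts[g]["true_positive"] += 1
            continue
        # A wrong prediction is a false positive for the predicted label
        # and a false negative for the gold label.
        if p in keep:
            counts[p]["false_positive"] += 1
        if g in keep:
            counts[g]["false_negative"] += 1
    return {label: {k: c[k] for k in KEYS} for label, c in counts.items()}


gold = ["label1", "label2", "label1", "label3"]
pred = ["label1", "label1", "label1", "label3"]
print(per_class_counts(gold, pred)["label1"])
# → {'true_positive': 2, 'false_positive': 1, 'false_negative': 0}
```

With every label kept, each wrong prediction adds exactly one false positive and one false negative, so micro-averaged precision, recall and F1 all coincide with accuracy (3/4 in this example).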
- multilabel_multiclass_evaluation(find_label_func, find_label_data_func, labels=None, detailed_metrics=False, partial_match_metrics=False, use_labels_for_matching=False) → dict
Obtain results of evaluation for a multi-label, multi-class model.
- Args:
Note: the class should be initialized with gold and pred in the following format:
self.gold (list(list)): list of gold labels for every example, e.g., [['label1', 'label2'], ['label1', 'label4']]
self.pred (list(list)): labels predicted by the model for every example.
- find_label_func: function to fetch the labels from any one prediction.
- find_label_data_func: function to fetch the data that belongs to a certain class.
- labels: list (Optional, defaults to None). Optional list of class labels to evaluate quality on. By default, evaluation is done over all class labels; use this to restrict the quality evaluation to a subset of labels.
- detailed_metrics: flag indicating whether to compute detailed metrics. Detailed metrics give metrics for every example, as well as metrics computed with a custom partial match function.
- partial_match_metrics: flag indicating whether to compute partial match micro-averaged metrics.
- use_labels_for_matching (bool): Indicates whether to use the output of find_label_func for metric computations instead of the raw data tuples.
- Returns:
- dict: Dictionary of the form:
{
  'per_class_confusion_matrix': {'entity_type': {'true_positive': int, ...}},
  'macro_precision': 0 <= float <= 1,
  'macro_recall': 0 <= float <= 1,
  'macro_f1': 0 <= float <= 1,
  'micro_precision': 0 <= float <= 1,
  'micro_recall': 0 <= float <= 1,
  'micro_f1': 0 <= float <= 1,
  'detailed_metrics': {'exact_match_precision': ..., ..., 'partial_match_precision': ...},
  'micro_precision_partial_match': 0 <= float <= 1,
  'micro_recall_partial_match': 0 <= float <= 1,
  'micro_f1_partial_match': 0 <= float <= 1
}
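In the multi-label case, each example's gold and predicted labels can be compared as sets. The sketch below assumes each example is a list of plain label strings, as in the self.gold format above, and collapses duplicate labels within an example; multilabel_counts is a hypothetical helper, and the library's handling of raw data tuples may differ.

```python
from typing import Dict, List


def multilabel_counts(
    gold: List[List[str]], pred: List[List[str]]
) -> Dict[str, Dict[str, int]]:
    """Per-label TP/FP/FN when each example may carry several labels."""
    if len(gold) != len(pred):
        raise ValueError("gold and pred must describe the same examples")
    counts: Dict[str, Dict[str, int]] = {}

    def bump(label: str, key: str) -> None:
        row = counts.setdefault(
            label, {"true_positive": 0, "false_positive": 0, "false_negative": 0}
        )
        row[key] += 1

    for gold_labels, pred_labels in zip(gold, pred):
        g, p = set(gold_labels), set(pred_labels)
        for label in g & p:  # predicted and correct
            bump(label, "true_positive")
        for label in p - g:  # predicted but not in gold
            bump(label, "false_positive")
        for label in g - p:  # in gold but missed
            bump(label, "false_negative")
    return counts


gold = [["label1", "label2"], ["label1", "label4"]]
pred = [["label1"], ["label1", "label3"]]
print(multilabel_counts(gold, pred)["label1"])
# → {'true_positive': 2, 'false_positive': 0, 'false_negative': 0}
```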
- multilabel_multiclass_hierarchical_evaluation(find_label_func_builder, find_label_data_func_builder, max_hierarchy_levels=3) → dict
Evaluate multilabel/multiclass predictions over a hierarchy, e.g., for ESA categories. This method evaluates over a set number of hierarchy levels.
Because each level in the hierarchy needs to be able to compare and extract differently, we use builder funcs that create the appropriate functions for a given level of the hierarchy.
- Args:
- find_label_func_builder (function): A function that takes in a level number (or None for the full hierarchy) and returns a find_label_func for this level that can be passed to the multilabel multiclass evaluator.
- find_label_data_func_builder (function): A function that takes in a level number (or None for the full hierarchy) and returns a find_label_data_func for this level that can be passed to the multilabel multiclass evaluator.
- max_hierarchy_levels (int): The number of hierarchy levels to run, in addition to the complete match.
- Returns:
- dict: Dictionary where each key is a level number or 'FULL', and maps to the dict returned by multilabel_multiclass_evaluation for that level of the hierarchy.
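The builder pattern can be sketched for hierarchical labels written as path strings. Everything here is hypothetical: the "/" separator, the assumption that each prediction is itself a label path, and the (predictions, label) argument order of the built find_label_data_func. A level number keeps that many leading levels, and None keeps the full path.

```python
from typing import Callable, List, Optional

SEP = "/"  # hypothetical separator between hierarchy levels


def find_label_func_builder(level: Optional[int]) -> Callable[[str], str]:
    """Build a find_label_func that keeps the first `level` levels (None = full path)."""

    def find_label(prediction: str) -> str:
        parts = prediction.split(SEP)
        return SEP.join(parts if level is None else parts[:level])

    return find_label


def find_label_data_func_builder(
    level: Optional[int],
) -> Callable[[List[str], str], List[str]]:
    """Build a find_label_data_func that selects predictions whose truncated label matches."""
    find_label = find_label_func_builder(level)

    def find_label_data(predictions: List[str], label: str) -> List[str]:
        return [p for p in predictions if find_label(p) == label]

    return find_label_data


print(find_label_func_builder(1)("science/physics/optics"))  # → science
print(find_label_func_builder(None)("science/physics/optics"))  # → science/physics/optics
print(find_label_data_func_builder(2)(["science/physics/optics", "science/biology"], "science/physics"))
# → ['science/physics/optics']
```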
- static calc_f1_score(gold, pred, match_fun=None)
Calculates the F1 score.
- Args:
gold (list): List of gold annotations.
pred (list): List of predictions.
match_fun: Function that finds the matches and returns a tuple of (matched gold, matched preds).
- Returns:
tuple: Precision, Recall, F1 score
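The match_fun contract, returning the matched gold items and the matched predictions, is what lets one F1 computation serve both exact and partial matching. A minimal sketch follows; f1_from_matches and exact_matches are hypothetical names, and the sketch assumes the annotations in each list are unique and hashable, so set sizes and list lengths agree.

```python
from typing import Callable, Hashable, List, Optional, Set, Tuple

MatchFun = Callable[[List[Hashable], List[Hashable]], Tuple[Set, Set]]


def exact_matches(gold: List[Hashable], pred: List[Hashable]) -> Tuple[Set, Set]:
    """Default matcher: an item matches only if it appears in both lists."""
    common = set(gold) & set(pred)
    return common, common


def f1_from_matches(
    gold: List[Hashable], pred: List[Hashable], match_fun: Optional[MatchFun] = None
) -> Tuple[float, float, float]:
    """Precision, recall and F1 given a matcher returning (matched gold, matched preds)."""
    gold_matched, pred_matched = (match_fun or exact_matches)(gold, pred)
    precision = len(pred_matched) / len(pred) if pred else 0.0
    recall = len(gold_matched) / len(gold) if gold else 0.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0
    return precision, recall, f1


p, r, f = f1_from_matches(["a", "b", "c", "d"], ["a", "b", "x"])
print(round(p, 4), r, round(f, 4))  # → 0.6667 0.5 0.5714
```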
- static find_partial_matches(groundtruth, prediction)
- Find partial matches between predicted phrases and the ground truth.
A partial match means that a complete predicted phrase is part of some ground-truth phrase, or a complete ground-truth phrase is part of some predicted phrase. Mere overlaps are not considered matches.
- Args:
groundtruth (list): Ground-truth data.
prediction (list): Predictions returned by the model.
- Returns:
- tuple: (gold_matched, pred_matched), where gold_matched (set) holds the gold annotations that were matched and pred_matched (set) holds the predictions that partially or fully matched the ground truth.
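The containment rule can be sketched on plain strings. partial_matches is a hypothetical helper, and using Python's substring test is an assumption: the library may compare token spans or offsets instead. Character-level containment would, for instance, also match "apple" inside "pineapple", which a token-based comparison would avoid.

```python
from typing import List, Set, Tuple


def partial_matches(
    groundtruth: List[str], prediction: List[str]
) -> Tuple[Set[str], Set[str]]:
    """Return (gold_matched, pred_matched) using whole-phrase containment."""
    gold_matched: Set[str] = set()
    pred_matched: Set[str] = set()
    for g in groundtruth:
        for p in prediction:
            # One phrase must lie entirely inside the other; a shared fragment,
            # such as "City" in "New York City" and "City Hall", is only an overlap.
            if p in g or g in p:
                gold_matched.add(g)
                pred_matched.add(p)
    return gold_matched, pred_matched


gold_m, pred_m = partial_matches(
    ["New York City", "Apple"], ["York", "Apple Inc", "City Hall"]
)
print(sorted(gold_m), sorted(pred_m))
# → ['Apple', 'New York City'] ['Apple Inc', 'York']
```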
- static calc_metrics_from_confusion_matrix(per_class_confusion_matrix: Dict[str, F1Metrics]) → F1MetricsContainer
- Function to calculate precision, recall, and F1 metrics from a confusion matrix containing statistics per class label.
- Args:
- per_class_confusion_matrix (Dict[str, F1Metrics]): Dictionary of statistics per class label, keyed by class label. For each class label, there should be an F1Metrics object whose true_positive, false_positive, and false_negative values hold the corresponding counts for that class. The dictionary looks like: per_class_confusion_matrix[label] = F1Metrics(true_positive=val1, false_positive=val2, false_negative=val3)
- Returns:
metrics_summary (F1MetricsContainer): An instance of the F1MetricsContainer dataclass containing a summary of the F1 metrics.
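Macro averages weight every label equally, while micro averages pool the counts across labels first. The sketch below works on plain count dictionaries rather than F1Metrics objects, and prf and summarize are hypothetical names. Here macro_f1 is the mean of the per-label F1 scores; some toolkits instead take the harmonic mean of macro precision and macro recall, and which convention this module uses is not stated here.

```python
from typing import Dict, Tuple

Counts = Dict[str, int]


def prf(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    """Precision, recall and F1 from raw counts (0.0 when undefined)."""
    p = tp / (tp + fp) if tp + fp else 0.0
    r = tp / (tp + fn) if tp + fn else 0.0
    f = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f


def summarize(per_class: Dict[str, Counts]) -> Dict[str, float]:
    """Macro averages treat every label equally; micro averages pool the counts."""
    scores = [
        prf(c["true_positive"], c["false_positive"], c["false_negative"])
        for c in per_class.values()
    ]
    n = len(scores) or 1  # avoid dividing by zero for an empty matrix
    tp = sum(c["true_positive"] for c in per_class.values())
    fp = sum(c["false_positive"] for c in per_class.values())
    fn = sum(c["false_negative"] for c in per_class.values())
    micro_p, micro_r, micro_f1 = prf(tp, fp, fn)
    return {
        "macro_precision": sum(s[0] for s in scores) / n,
        "macro_recall": sum(s[1] for s in scores) / n,
        "macro_f1": sum(s[2] for s in scores) / n,
        "micro_precision": micro_p,
        "micro_recall": micro_r,
        "micro_f1": micro_f1,
        "overall_tp": tp,
        "overall_fp": fp,
        "overall_fn": fn,
    }


summary = summarize({
    "A": {"true_positive": 3, "false_positive": 1, "false_negative": 0},
    "B": {"true_positive": 1, "false_positive": 0, "false_negative": 1},
})
print(summary["macro_precision"], summary["overall_tp"])  # → 0.875 4
```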
- convert_F1MetricsContainer_to_dict() → dict
- Args:
- metrics_summary (F1MetricsContainer): An object of dataclass F1MetricsContainer.
- Returns:
dict: Dictionary of the form:
{
  'per_class_confusion_matrix': {'entity_type': {'true_positive': int, ...}},
  'macro_precision': 0 <= float <= 1,
  'macro_recall': 0 <= float <= 1,
  'macro_f1': 0 <= float <= 1,
  'micro_precision': 0 <= float <= 1,
  'micro_recall': 0 <= float <= 1,
  'micro_f1': 0 <= float <= 1,
  'overall_tp': int,
  'overall_fp': int,
  'overall_fn': int
}
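A dataclass-based container maps naturally onto the nested dictionary layout above. The sketch uses hypothetical stand-ins (ClassCounts for F1Metrics and MetricsSummary for F1MetricsContainer, each with only a few fields) and the standard-library dataclasses.asdict, which recursively converts dataclasses nested inside dict values. Whether convert_F1MetricsContainer_to_dict works this way is not documented here.

```python
from dataclasses import asdict, dataclass, field
from typing import Dict, Optional


@dataclass
class ClassCounts:  # hypothetical stand-in for F1Metrics
    true_positive: Optional[int] = None
    false_positive: Optional[int] = None
    false_negative: Optional[int] = None


@dataclass
class MetricsSummary:  # hypothetical stand-in for F1MetricsContainer
    per_class_confusion_matrix: Dict[str, ClassCounts] = field(default_factory=dict)
    micro_precision: float = 0.0
    micro_recall: float = 0.0
    micro_f1: float = 0.0


summary = MetricsSummary(
    per_class_confusion_matrix={"label1": ClassCounts(2, 1, 0)},
    micro_precision=2 / 3,
    micro_recall=1.0,
    micro_f1=0.8,
)
# asdict recurses into the dict of ClassCounts, producing plain nested dicts.
as_dict = asdict(summary)
print(as_dict["per_class_confusion_matrix"]["label1"])
# → {'true_positive': 2, 'false_positive': 1, 'false_negative': 0}
```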