Source code for caikit.core.toolkit.quality_evaluation

# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluate quality of models.
"""

# Standard
from dataclasses import dataclass
from typing import Dict, Optional
import enum
import math
import re

# First Party
import alog

# Local
from ..exceptions import error_handler

# Module-level log channel ("TLKIT"); used below for info/warning messages
log = alog.use_channel("TLKIT")
# Error-reporting helper bound to ``log``; called below as error(<log code>, <exception>)
error = error_handler.get(log)


class EvalTypes(enum.Enum):
    """Closed set of evaluation modes that ``QualityEvaluator.run`` dispatches on."""

    # Exactly one gold label and one predicted label per example
    SINGLELABEL_MULTICLASS = 1
    # Any number of gold / predicted annotations per example
    MULTILABEL_MULTICLASS = 2
    # Multi-label evaluation repeated for each level of a label hierarchy
    MULTILABEL_MULTICLASS_HIERARCHICAL = 3
@dataclass
class F1Metrics:
    """Counts and scores for a single class, or for a macro/micro aggregate.

    Every field defaults to None, so a caller can fill in only what it tracks.
    The evaluation code in this module constructs instances with explicit zeros
    for the fields it accumulates into.
    """

    # Raw confusion-matrix counts
    true_positive: Optional[int] = None
    false_positive: Optional[int] = None
    false_negative: Optional[int] = None
    # Scores derived from the counts above
    precision: Optional[float] = None
    recall: Optional[float] = None
    f1: Optional[float] = None
@dataclass
class F1MetricsContainer:
    """Bundle of per-class and aggregate F1 statistics.

    Instances are produced by ``QualityEvaluator.calc_metrics_from_confusion_matrix``.
    """

    # Per-class statistics, keyed by class label
    per_class_confusion_matrix: Dict[str, F1Metrics]
    # Unweighted mean of the per-class precision / recall / f1
    macro_metrics: F1Metrics
    # Scores computed from counts pooled over all classes
    micro_metrics: F1Metrics
class QualityEvaluator:
    """Class that holds all evaluation logic for now. May eventually be broken up into
    subclasses.

    An evaluator is built from parallel sequences of gold (reference) and predicted
    annotations; each ``*_evaluation`` method documents the element format it expects.
    """

    def __init__(self, gold, pred):
        """
        Args:
            gold: Gold (reference) annotations, one entry per example.
            pred: Model predictions, one entry per example, aligned with ``gold``.
        """
        self.gold = gold
        self.pred = pred

    def run(
        self,
        evaluation_type,
        find_label_func=None,
        find_label_data_func=None,
        detailed_metrics=False,
        labels=None,
        partial_match_metrics=False,
        max_hierarchy_levels=3,
    ):
        """Main entry point for evaluation; dispatches on ``evaluation_type``.

        Args:
            evaluation_type (EvalTypes): Which type of evaluation to run. Only the
                members of EvalTypes are supported.
            find_label_func: function to fetch labels from any one prediction, used in
                multiclass multilabel evaluation. eg: if a prediction is of form
                (token, label), this function should be able to tell us how to extract
                the class labels from the prediction, in this case return the second
                element of the tuple. For hierarchical evaluation this argument is
                instead a builder: a function taking a hierarchy level (or None for the
                full label) and returning such a function.
            find_label_data_func: function to fetch predictions that belong to a certain
                label, e.g., if predictions for a data example look like
                [(tok1, labX), (tok2, labY), (tok3, labX)], then the function should
                return all predictions with a given label - for labX the return should
                look like [(tok1, labX), (tok3, labX)]. For hierarchical evaluation this
                argument is instead a builder taking a hierarchy level (or None).
            detailed_metrics: flag to indicate whether or not you want detailed metrics
                (currently only for multiclass multilabel eval type). Detailed metrics
                give us metrics for every example, and metrics using a custom partial
                match function.
            labels: list (Optional, defaults to None)
                Optional list of class labels to evaluate quality on. By default
                evaluation is done over all class labels. Using this, you can explicitly
                mention only a subset of labels to include in the quality evaluation.
            partial_match_metrics: flag to indicate whether or not you want partial
                match micro avg metrics (currently only for multiclass multilabel eval
                type).
            max_hierarchy_levels (int): Used in hierarchical multilabel multiclass
                evaluation only. The number of levels in the hierarchy to run model
                evaluation on, in addition to complete matches.

            Note: hierarchical evaluation does not use detailed_metrics, labels or
            partial_match_metrics.

        Returns:
            dict: Full results from evaluation on dataset and model.

        Raises:
            ValueError: If evaluation_type is not a supported EvalTypes member
                (reported through the module's ``error`` handler).
        """
        if evaluation_type == EvalTypes.MULTILABEL_MULTICLASS:
            return self.multilabel_multiclass_evaluation(
                find_label_func,
                find_label_data_func,
                labels,
                detailed_metrics,
                partial_match_metrics,
            )
        if evaluation_type == EvalTypes.SINGLELABEL_MULTICLASS:
            return self.singlelabel_multiclass_evaluation(labels)
        if evaluation_type == EvalTypes.MULTILABEL_MULTICLASS_HIERARCHICAL:
            # In this mode the two "func" arguments are per-level builders
            return self.multilabel_multiclass_hierarchical_evaluation(
                find_label_func, find_label_data_func, max_hierarchy_levels
            )
        error(
            "<COR81451123E>",
            ValueError(f"Unknown evaluation_type: {evaluation_type}"),
        )

    def singlelabel_multiclass_evaluation(self, labels=None) -> dict:
        """Obtain results of evaluation for a single-label, multi-class model.

        Args:
            Note: here class should be initialized with gold and pred in the following
            format
            self.gold (list): list of gold set labels for every example, where each
                example can have only one label eg: ['label1','label2', 'label3','label4']
            self.pred (list): Predicted-by-the-model set labels for every example.
            labels: list (Optional, defaults to None)
                Optional list of class labels to evaluate quality on. By default
                evaluation is done over all class labels. Using this, you can explicitly
                mention only a subset of labels to include in the quality evaluation.

        Returns:
            dict: Dictionary looks like:
                {
                    'per_class_confusion_matrix': {'entity_type': {'true_positive': int ...}}
                    'macro_precision': 0 <= float <= 1,
                    'macro_recall': 0 <= float <= 1,
                    'macro_f1': 0 <= float <= 1,
                    'micro_precision': 0 <= float <= 1,
                    'micro_recall': 0 <= float <= 1,
                    'micro_f1': 0 <= float <= 1,
                    'overall_tp': int,
                    'overall_fp': int,
                    'overall_fn': int
                }

        Raises:
            AssertionError: If gold and pred have different lengths.
        """
        gold, pred = self.gold, self.pred
        assert len(gold) == len(
            pred
        ), "Length of gold and predicted datasets does not match"

        def _in_scope(label):
            # labels=None means "evaluate every label that appears"
            return labels is None or label in labels

        per_class_confusion_matrix = {}
        for gold_label, pred_label in zip(gold, pred):
            # Register in-scope labels (gold first, then pred) the first time they appear
            for label in (gold_label, pred_label):
                if label not in per_class_confusion_matrix and _in_scope(label):
                    per_class_confusion_matrix[label] = QualityEvaluator._empty_f1_metrics()

            if gold_label == pred_label and _in_scope(gold_label):
                # true positive
                per_class_confusion_matrix[gold_label].true_positive += 1
            else:
                # A miss is a false positive for the predicted class and a false
                # negative for the gold class (each only if that class is in scope)
                if _in_scope(pred_label):
                    per_class_confusion_matrix[pred_label].false_positive += 1
                if _in_scope(gold_label):
                    per_class_confusion_matrix[gold_label].false_negative += 1

        calc_metrics = QualityEvaluator.calc_metrics_from_confusion_matrix(
            per_class_confusion_matrix
        )
        return QualityEvaluator.convert_F1MetricsContainer_to_dict(calc_metrics)

    def multilabel_multiclass_evaluation(
        self,
        find_label_func,
        find_label_data_func,
        labels=None,
        detailed_metrics=False,
        partial_match_metrics=False,
        use_labels_for_matching=False,
    ) -> dict:
        """Obtain results of evaluation for a multi-label, multi-class model.

        Args:
            Note: here class should be initialized with gold and pred in the following
            format
            self.gold (list(list)): list of gold set labels for every example
                eg: [['label1','label2'], ['label1', 'label4']]
            self.pred (list(list)): Predicted-by-the-model set labels for every example.
            find_label_func: function to fetch labels from any one prediction
            find_label_data_func: function to fetch data that belongs to a certain class
            labels: list (Optional, defaults to None)
                Optional list of class labels to evaluate quality on. By default
                evaluation is done over all class labels. Using this, you can explicitly
                mention only a subset of labels to include in the quality evaluation.
            detailed_metrics: flag to indicate whether or not you want detailed metrics.
                Detailed metrics give us metrics for every example, and metrics using a
                custom partial match function.
            partial_match_metrics: flag to indicate whether or not you want partial
                match micro avg metrics.
            use_labels_for_matching (bool): Indicates whether or not we should use the
                output of find_label_func for metric computations, or the raw data
                tuples. Forced to False if find_label_func raises NotImplementedError.

        Returns:
            dict: Dictionary looks like:
                {
                    'per_class_confusion_matrix': {'entity_type': {'true_positive': int ...}}
                    'macro_precision': 0 <= float <= 1,
                    'macro_recall': 0 <= float <= 1,
                    'macro_f1': 0 <= float <= 1,
                    'micro_precision': micro_precision,
                    'micro_recall': micro_recall,
                    'micro_f1': micro_f1,
                    'detailed_metrics' : {'exact_match_precision'..,'partial_match_precision'}
                    'micro_precision_partial_match': 0 <= float <= 1,
                    'micro_recall_partial_match': 0 <= float <= 1,
                    'micro_f1_partial_match': 0 <= float <= 1
                }
            The partial-match keys are 0.0 and 'detailed_metrics' is empty unless the
            corresponding flags are set.

        Raises:
            AssertionError: If gold and pred have different lengths.
            NotImplementedError: If labels is given but find_label_func is not
                implemented (reported through the module's ``error`` handler).
        """
        gold, pred = self.gold, self.pred
        assert len(gold) == len(
            pred
        ), "Length of gold and predicted datasets does not match"

        detailed_output = []
        per_class_confusion_matrix = {}
        num_preds_partial_matched = 0
        num_gold_partial_matched = 0
        total_pred = 0
        total_gold = 0
        overall_tp = 0
        overall_fp = 0
        overall_fn = 0
        micro_metrics_partial = F1Metrics(precision=0.0, recall=0.0, f1=0.0)
        # The label-extraction fallback is logged once per call, not once per example
        fallback_logged = False

        if labels:
            # Drop every annotation whose label is outside the requested subset
            try:
                gold = [
                    [item for item in gold_ex if find_label_func(item) in labels]
                    for gold_ex in gold
                ]
                pred = [
                    [item for item in pred_ex if find_label_func(item) in labels]
                    for pred_ex in pred
                ]
            except NotImplementedError:
                error(
                    "<COR19114599E>",
                    NotImplementedError(
                        "find_label_func must be implemented to use [labels]"
                    ),
                )

        for gold_ex, pred_ex in zip(gold, pred):
            if detailed_metrics:
                precision, recall, f1 = QualityEvaluator.calc_f1_score(gold_ex, pred_ex)
                (
                    partial_precision,
                    partial_recall,
                    partial_f1,
                ) = QualityEvaluator.calc_f1_score(
                    gold_ex, pred_ex, QualityEvaluator.find_partial_matches
                )
                detailed_output.append(
                    {
                        "exact_match_precision": precision,
                        "exact_match_recall": recall,
                        "exact_match_f1": f1,
                        "partial_match_precision": partial_precision,
                        "partial_match_recall": partial_recall,
                        "partial_match_f1": partial_f1,
                    }
                )

            # Start each example from an empty label set so a failed label lookup can
            # never reuse the labels computed for a previous example.
            all_labels = set()
            try:
                gold_labels = set(map(find_label_func, gold_ex))
                pred_labels = set(map(find_label_func, pred_ex))
                # get per-class information if possible
                all_labels = gold_labels.union(pred_labels)
            except NotImplementedError:
                # If find_label_func raises NotImplementedError, we can't do label-based
                # matching. In this case we need to fall back to set operations on the
                # raw data tuples.
                if not fallback_logged:
                    log.info(
                        "INFO: find_label_func not implemented for this module type - falling back "
                        "to tuple match!!"
                    )
                    fallback_logged = True
                use_labels_for_matching = False

            # Build the per-class confusion matrix from the data belonging to each label
            for label in all_labels:
                if label not in per_class_confusion_matrix:
                    per_class_confusion_matrix[label] = QualityEvaluator._empty_f1_metrics()
                class_metrics = per_class_confusion_matrix[label]
                pred_label_data = set(find_label_data_func(pred_ex, label))
                gold_label_data = set(find_label_data_func(gold_ex, label))
                class_metrics.true_positive += len(gold_label_data & pred_label_data)
                class_metrics.false_positive += len(pred_label_data - gold_label_data)
                class_metrics.false_negative += len(gold_label_data - pred_label_data)

            if use_labels_for_matching:
                gold_ex_set, pred_ex_set = gold_labels, pred_labels
            else:
                # In case the user did not specify how to obtain class labels,
                # we can still calculate micro avg using sum of true positives,
                # false positives etc over all examples (over all classes)
                # We should deprecate this section in next major release
                gold_ex_set = set(gold_ex)
                pred_ex_set = set(pred_ex)
            overall_tp += len(gold_ex_set & pred_ex_set)
            overall_fp += len(pred_ex_set - gold_ex_set)
            overall_fn += len(gold_ex_set - pred_ex_set)

            if partial_match_metrics:
                gold_matched, preds_matched = QualityEvaluator.find_partial_matches(
                    gold_ex_set, pred_ex_set
                )
                num_preds_partial_matched += len(preds_matched)
                num_gold_partial_matched += len(gold_matched)
                total_pred += len(pred_ex_set)
                total_gold += len(gold_ex_set)

        # Micro averages depend only on the final pooled counts, so compute them once
        micro_precision, micro_recall, micro_f1 = QualityEvaluator._precision_recall_f1(
            overall_tp, overall_fp, overall_fn
        )

        calc_metrics = QualityEvaluator.calc_metrics_from_confusion_matrix(
            per_class_confusion_matrix
        )

        # This section should be deprecated with future refactors
        if not use_labels_for_matching:
            log.warning(
                "WARNING: Only Micro_avg metrics could be calculated based on the information "
                "available for this module type."
            )
        # Micro metrics come from the per-example set comparison above (labels or raw
        # tuples, depending on use_labels_for_matching)
        calc_metrics.micro_metrics.precision = micro_precision
        calc_metrics.micro_metrics.recall = micro_recall
        calc_metrics.micro_metrics.f1 = micro_f1
        calc_metrics.micro_metrics.true_positive = overall_tp
        calc_metrics.micro_metrics.false_positive = overall_fp
        calc_metrics.micro_metrics.false_negative = overall_fn
        metrics_out = QualityEvaluator.convert_F1MetricsContainer_to_dict(calc_metrics)

        # This flag only controls calculation of micro average partial match metrics
        # Detailed metrics flag calculates partial match metrics per data row
        if partial_match_metrics:
            # Micro precision = fraction of retrieved instances that are relevant
            if total_pred > 0:
                micro_metrics_partial.precision = num_preds_partial_matched / total_pred
            # Micro recall = fraction of relevant instances that are retrieved
            if total_gold > 0:
                micro_metrics_partial.recall = num_gold_partial_matched / total_gold
            # Micro avg F1 = harmonic mean of precision and recall
            if not math.isclose(
                micro_metrics_partial.precision + micro_metrics_partial.recall, 0
            ):
                micro_metrics_partial.f1 = (
                    2.0 * micro_metrics_partial.precision * micro_metrics_partial.recall
                ) / (micro_metrics_partial.precision + micro_metrics_partial.recall)

        metrics_out["detailed_metrics"] = detailed_output
        metrics_out["micro_precision_partial_match"] = micro_metrics_partial.precision
        metrics_out["micro_recall_partial_match"] = micro_metrics_partial.recall
        metrics_out["micro_f1_partial_match"] = micro_metrics_partial.f1
        return metrics_out

    def multilabel_multiclass_hierarchical_evaluation(
        self,
        find_label_func_builder,
        find_label_data_func_builder,
        max_hierarchy_levels=3,
    ) -> dict:
        """Evaluate multilabel/multiclass over a hierarchy, e.g., for ESA categories.

        This method evaluates over a set number of hierarchy levels. Because each level
        in the hierarchy needs to be able to compare and extract differently, we use
        builder funcs that create the appropriate functions for a given level of the
        hierarchy.

        Args:
            find_label_func_builder (function): A function that takes in a level number
                (or None if full hierarchy) and returns a find_label_func for this level
                that can be passed to the multilabel multiclass evaluator.
            find_label_data_func_builder (function): A function that takes in a level
                number (or None if full hierarchy) and returns a find_label_data_func for
                this level that can be passed to the multilabel multiclass evaluator.
            max_hierarchy_levels (int): The number of levels to run in the hierarchy, in
                addition to complete match.

        Returns:
            dict: Maps 'level_all' (full labels) and 'level_1' ... 'level_<n>' to the
                dict returned by multilabel_multiclass_evaluation for that level.
        """
        metrics = {}
        # Levels are None [FULL], then 1...n, where n is the deepest level to evaluate
        # (for now, needs to be manually determined by the user).
        levels = [None, *range(1, max_hierarchy_levels + 1)]
        for level in levels:
            # Get the find_label_func/find_label_data_func for this level in the hierarchy
            find_label_func = find_label_func_builder(level)
            find_label_data_func = find_label_data_func_builder(level)
            # "level_all" is a bit more descriptive than "level_None"
            dict_key = "level_all" if level is None else f"level_{level}"
            # NOTE: We use label matching for computing our metrics here. This means that
            # we compare the outputs of find_label_func on gold/pred examples to get our
            # metrics instead of the gold/pred example tuples themselves. The examples
            # have the full labels, but we need to slice out just part of the
            # hierarchical label to consider each level.
            metrics[dict_key] = self.multilabel_multiclass_evaluation(
                find_label_func, find_label_data_func, use_labels_for_matching=True
            )
        return metrics

    @staticmethod
    def calc_f1_score(gold, pred, match_fun=None):
        """Calculates precision, recall and F1 for one example.

        Args:
            gold (list): List of gold annotations
            pred (list): List of predictions
            match_fun: Optional function (gold, pred) -> (matched_gold, matched_preds).
                Without it, exact set intersection is used.

        Returns:
            tuple: Precision, Recall, F1 score (each 0.0 when undefined)

        Raises:
            ValueError: If match_fun raises ValueError/TypeError, e.g. because it does
                not return a 2-tuple (reported through the module's ``error`` handler).
        """
        if match_fun:
            # In case of partial match, matched predictions need not equal matched gold:
            # two predictions can match one gold, while another gold may not be matched
            try:
                matched_gold, matched_preds = match_fun(gold, pred)
            except (ValueError, TypeError):
                error(
                    "<COR19474599E>",
                    ValueError("Match function not returning expected tuple format"),
                )
        else:
            matched_preds = set(gold).intersection(set(pred))
            matched_gold = matched_preds

        num_correct_preds = len(matched_preds)
        num_correct_gold = len(matched_gold)
        num_pred = len(pred)
        num_gold = len(gold)

        # precision == fraction of relevant instances among retrieved instances
        # If we could match 3 predictions with gold out of 4 predictions, precision = 3/4
        precision = num_correct_preds / num_pred if num_pred > 0 else 0.0
        # recall == fraction of relevant instances that were retrieved
        # If we could match/retrieve only one gold out of 3, recall = 1/3
        recall = num_correct_gold / num_gold if num_gold > 0 else 0.0
        # f1 == harmonic_mean(precision, recall)
        f1 = (
            (2 * precision * recall) / (precision + recall)
            if precision != 0 and recall != 0
            else 0.0
        )
        return precision, recall, f1

    @staticmethod
    def find_partial_matches(groundtruth, prediction):
        """Find partial matches between predicted phrases and the ground truth.

        A partial match means a complete predicted phrase occurs, as whole words and
        case-insensitively, inside a ground truth phrase, or a complete ground truth
        phrase occurs inside a predicted phrase. Overlaps are not considered.

        Args:
            groundtruth (list): Groundtruth phrases (strings)
            prediction (list): Predicted phrases (strings) returned by the model

        Returns:
            tuple: (gold_matched, pred_matched), two sets holding the gold annotations
                that were matched and the predictions that partially or fully matched
                the groundtruth.
        """
        gold_matched = set()
        pred_matched = set()
        for ground_truth_phrase in groundtruth:
            for predicted_phrase in prediction:
                pd_compiler = re.compile(
                    r"\b{}\b".format(re.escape(predicted_phrase)), re.I
                )
                # Checks if prediction is part of groundtruth
                if pd_compiler.search(ground_truth_phrase):
                    pred_matched.add(predicted_phrase)
                    gold_matched.add(ground_truth_phrase)
                else:
                    # Checks if groundtruth is part of prediction
                    gt_compiler = re.compile(
                        r"\b{}\b".format(re.escape(ground_truth_phrase)), re.I
                    )
                    if gt_compiler.search(predicted_phrase):
                        pred_matched.add(predicted_phrase)
                        gold_matched.add(ground_truth_phrase)
        return gold_matched, pred_matched

    @staticmethod
    def calc_metrics_from_confusion_matrix(
        per_class_confusion_matrix: Dict[str, F1Metrics]
    ) -> F1MetricsContainer:
        """Calculate precision, recall and F1 from per-class confusion counts.

        Args:
            per_class_confusion_matrix (Dict[str, F1Metrics]): Statistics per class
                label, keyed by label. Each F1Metrics must hold the true_positive,
                false_positive and false_negative counts for its class, plus numeric
                precision/recall/f1 starting values (these are kept when the matching
                denominator is zero), e.g.
                per_class_confusion_matrix[label] = F1Metrics(true_positive=..,
                    false_positive=.., false_negative=.., precision=0.0, recall=0.0,
                    f1=0.0)

        Note: the F1Metrics values are updated in place and the same dict is stored in
        the returned container.

        Returns:
            F1MetricsContainer: per-class, macro-averaged and micro-averaged metrics.
        """
        macro_metrics = F1Metrics(precision=0.0, recall=0.0, f1=0.0)
        micro_metrics = QualityEvaluator._empty_f1_metrics()
        num_classes = len(per_class_confusion_matrix)

        # Compute metrics per label
        for class_metrics in per_class_confusion_matrix.values():
            tp = class_metrics.true_positive
            fp = class_metrics.false_positive
            fn = class_metrics.false_negative

            # Precision of X = no. of correct predictions of X / no. of predictions of X
            if not math.isclose(tp + fp, 0):
                class_metrics.precision = tp / (tp + fp)
            # Recall of X = no. of correct predictions of X / no. of true samples of X
            if not math.isclose(tp + fn, 0):
                class_metrics.recall = tp / (tp + fn)
            prec = class_metrics.precision
            recall = class_metrics.recall
            # F1 of X = harmonic mean of precision and recall of X
            if not math.isclose(prec + recall, 0):
                class_metrics.f1 = (2.0 * prec * recall) / (prec + recall)

            micro_metrics.true_positive += tp
            micro_metrics.false_positive += fp
            micro_metrics.false_negative += fn
            macro_metrics.precision += class_metrics.precision
            macro_metrics.recall += class_metrics.recall
            macro_metrics.f1 += class_metrics.f1

        # Macro average metrics = average of metrics over all classes
        if num_classes > 0:
            macro_metrics.precision = macro_metrics.precision / num_classes
            macro_metrics.recall = macro_metrics.recall / num_classes
            macro_metrics.f1 = macro_metrics.f1 / num_classes

        # Micro average metrics are computed from the counts pooled over all classes
        (
            micro_metrics.precision,
            micro_metrics.recall,
            micro_metrics.f1,
        ) = QualityEvaluator._precision_recall_f1(
            micro_metrics.true_positive,
            micro_metrics.false_positive,
            micro_metrics.false_negative,
        )

        return F1MetricsContainer(
            per_class_confusion_matrix, macro_metrics, micro_metrics
        )

    @staticmethod
    def convert_F1MetricsContainer_to_dict(metrics_summary: F1MetricsContainer) -> dict:
        """Flatten an F1MetricsContainer into the plain dict returned by the evaluators.

        The container is not modified: per-class entries are copied into new dicts, so
        the same container can safely be converted more than once.

        Args:
            metrics_summary (F1MetricsContainer): An object of dataclass
                F1MetricsContainer

        Returns:
            dict: Dictionary looks like:
                {
                    'per_class_confusion_matrix': {'entity_type': {'true_positive': int ...}}
                    'macro_precision': 0 <= float <= 1,
                    'macro_recall': 0 <= float <= 1,
                    'macro_f1': 0 <= float <= 1,
                    'micro_precision': 0 <= float <= 1,
                    'micro_recall': 0 <= float <= 1,
                    'micro_f1': 0 <= float <= 1,
                    'overall_tp': int,
                    'overall_fp': int,
                    'overall_fn': int
                }
            A 'macro_<field>' key is emitted for every F1Metrics field, so
            macro_true_positive / macro_false_positive / macro_false_negative are also
            present (None when produced by calc_metrics_from_confusion_matrix).
        """
        out = {
            "per_class_confusion_matrix": {
                label: dict(vars(class_metrics))
                for label, class_metrics in metrics_summary.per_class_confusion_matrix.items()
            }
        }
        for k, v in vars(metrics_summary.macro_metrics).items():
            out[f"macro_{k}"] = v
        micro = metrics_summary.micro_metrics
        out["micro_precision"] = micro.precision
        out["micro_recall"] = micro.recall
        out["micro_f1"] = micro.f1
        out["overall_tp"] = micro.true_positive
        out["overall_fp"] = micro.false_positive
        out["overall_fn"] = micro.false_negative
        return out

    @staticmethod
    def _empty_f1_metrics() -> F1Metrics:
        """Return an F1Metrics with every count at 0 and every score at 0.0."""
        return F1Metrics(
            true_positive=0,
            false_positive=0,
            false_negative=0,
            precision=0.0,
            recall=0.0,
            f1=0.0,
        )

    @staticmethod
    def _precision_recall_f1(tp, fp, fn):
        """Return (precision, recall, f1) for pooled counts; 0.0 wherever undefined."""
        # Micro precision = no. of correct predictions / no. of total predictions
        precision = tp / (tp + fp) if not math.isclose(tp + fp, 0) else 0.0
        # Micro recall = no. of correct predictions / no. of true samples
        recall = tp / (tp + fn) if not math.isclose(tp + fn, 0) else 0.0
        # Micro avg F1 = harmonic mean of precision and recall
        if math.isclose(precision + recall, 0):
            return precision, recall, 0.0
        return precision, recall, (2.0 * precision * recall) / (precision + recall)