Source code for cleanlab.datalab.issue_manager.label

# Copyright (C) 2017-2023  Cleanlab Inc.
# This file is part of cleanlab.
#
# cleanlab is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cleanlab is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with cleanlab.  If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional

import numpy as np

from cleanlab.classification import CleanLearning
from cleanlab.datalab.issue_manager import IssueManager
from cleanlab.internal.validation import assert_valid_inputs

if TYPE_CHECKING:  # pragma: no cover
    import pandas as pd

    from cleanlab.datalab.datalab import Datalab


[docs]class LabelIssueManager(IssueManager): """Manages label issues in a Datalab. Parameters ---------- datalab : A Datalab instance. clean_learning_kwargs : Keyword arguments to pass to the :py:meth:`CleanLearning <cleanlab.classification.CleanLearning>` constructor. health_summary_parameters : Keyword arguments to pass to the :py:meth:`health_summary <cleanlab.dataset.health_summary>` function. """ description: ClassVar[ str ] = """Examples whose given label is estimated to be potentially incorrect (e.g. due to annotation error) are flagged as having label issues. """ issue_name: ClassVar[str] = "label" verbosity_levels = { 0: [], 1: [], 2: [], 3: ["classes_by_label_quality", "overlapping_classes"], } def __init__( self, datalab: Datalab, clean_learning_kwargs: Optional[Dict[str, Any]] = None, health_summary_parameters: Optional[Dict[str, Any]] = None, **_, ): super().__init__(datalab) self.cl = CleanLearning(**(clean_learning_kwargs or {})) self.health_summary_parameters: Dict[str, Any] = ( health_summary_parameters.copy() if health_summary_parameters else {} ) self._reset() @staticmethod def _process_find_label_issues_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]: """Searches for keyword arguments that are meant for the CleanLearning.find_label_issues method call Examples -------- >>> from cleanlab.datalab.issue_manager.label import LabelIssueManager >>> LabelIssueManager._process_clean_learning_kwargs(thresholds=[0.1, 0.9]) {'thresholds': [0.1, 0.9]} """ accepted_kwargs = [ "thresholds", "noise_matrix", "inverse_noise_matrix", "save_space", "clf_kwargs", "validation_func", ] return {k: v for k, v in kwargs.items() if k in accepted_kwargs and v is not None} def _reset(self) -> None: """Reset the attributes of this manager based on the available datalab info and the keyword arguments stored as instance attributes. This allows the builder to use pre-computed info from the datalab to speed up some computations in the :py:meth:`find_issues` method. """ if not self.health_summary_parameters: statistics_dict = self.datalab.get_info("statistics") self.health_summary_parameters = { "labels": self.datalab._labels, "class_names": list(self.datalab._label_map.values()), "num_examples": statistics_dict.get("num_examples"), "joint": statistics_dict.get("joint", None), "confident_joint": statistics_dict.get("confident_joint", None), "multi_label": statistics_dict.get("multi_label", None), "asymmetric": statistics_dict.get("asymmetric", None), "verbose": False, } self.health_summary_parameters = { k: v for k, v in self.health_summary_parameters.items() if v is not None }
[docs] def find_issues( self, pred_probs: np.ndarray, **kwargs, ) -> None: self.health_summary_parameters.update({"pred_probs": pred_probs}) # Find examples with label issues self.issues = self.cl.find_label_issues( labels=self.datalab._labels, pred_probs=pred_probs, **self._process_find_label_issues_kwargs(kwargs), ) self.issues.rename(columns={"label_quality": self.issue_score_key}, inplace=True) summary_dict = self.get_health_summary(pred_probs=pred_probs) # Get a summarized dataframe of the label issues self.summary = self.make_summary(score=summary_dict["overall_label_health_score"]) # Collect info about the label issues self.info = self.collect_info(issues=self.issues, summary_dict=summary_dict) # Drop columns from issues that are in the info self.issues = self.issues.drop(columns=["given_label", "predicted_label"])
[docs] def get_health_summary(self, pred_probs) -> dict: """Returns a short summary of the health of this Lab.""" from cleanlab.dataset import health_summary # Validate input self._validate_pred_probs(pred_probs) summary_kwargs = self._get_summary_parameters(pred_probs) summary = health_summary(**summary_kwargs) return summary
def _get_summary_parameters(self, pred_probs) -> Dict["str", Any]: """Collects a set of input parameters for the health summary function based on any info available in the datalab. Parameters ---------- pred_probs : The predicted probabilities for each example. kwargs : Keyword arguments to pass to the health summary function. Returns ------- summary_parameters : A dictionary of parameters to pass to the health summary function. """ if "confident_joint" in self.health_summary_parameters: summary_parameters = { "confident_joint": self.health_summary_parameters["confident_joint"] } elif all([x in self.health_summary_parameters for x in ["joint", "num_examples"]]): summary_parameters = { k: self.health_summary_parameters[k] for k in ["joint", "num_examples"] } else: summary_parameters = { "pred_probs": pred_probs, "labels": self.datalab._labels, } summary_parameters["class_names"] = self.health_summary_parameters["class_names"] for k in ["asymmetric", "verbose"]: # Start with the health_summary_parameters, then override with kwargs if k in self.health_summary_parameters: summary_parameters[k] = self.health_summary_parameters[k] return ( summary_parameters # will be called in `dataset.health_summary(**summary_parameters)` )
[docs] def collect_info(self, issues: pd.DataFrame, summary_dict: dict) -> dict: issues_info = { "num_label_issues": sum(issues[f"is_{self.issue_name}_issue"]), "average_label_quality": issues[self.issue_score_key].mean(), "given_label": issues["given_label"].tolist(), "predicted_label": issues["predicted_label"].tolist(), } health_summary_info = { "confident_joint": summary_dict["joint"], "classes_by_label_quality": summary_dict["classes_by_label_quality"], "overlapping_classes": summary_dict["overlapping_classes"], } cl_info = {} for k in self.cl.__dict__: if k not in ["py", "noise_matrix", "inverse_noise_matrix", "confident_joint"]: continue cl_info[k] = self.cl.__dict__[k] info_dict = { **issues_info, **health_summary_info, **cl_info, } return info_dict
def _validate_pred_probs(self, pred_probs) -> None: assert_valid_inputs(X=None, y=self.datalab._labels, pred_probs=pred_probs)