# Copyright (C) 2017-2023 Cleanlab Inc.
# This file is part of cleanlab.
#
# cleanlab is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cleanlab is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with cleanlab. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional

import numpy as np

from cleanlab.classification import CleanLearning
from cleanlab.datalab.issue_manager import IssueManager
from cleanlab.internal.validation import assert_valid_inputs

if TYPE_CHECKING:  # pragma: no cover
    import pandas as pd

    from cleanlab.datalab.datalab import Datalab
class LabelIssueManager(IssueManager):
    """Manages label issues in a Datalab.

    Parameters
    ----------
    datalab :
        A Datalab instance.

    clean_learning_kwargs :
        Keyword arguments to pass to the
        :py:class:`CleanLearning <cleanlab.classification.CleanLearning>` constructor.

    health_summary_parameters :
        Keyword arguments to pass to the
        :py:func:`health_summary <cleanlab.dataset.health_summary>` function.
    """

    description: ClassVar[str] = (
        "Examples whose given label is estimated to be potentially incorrect\n"
        "(e.g. due to annotation error) are flagged as having label issues.\n"
    )
    issue_name: ClassVar[str] = "label"
    # Extra ``info`` keys (see ``collect_info``) associated with each verbosity
    # level; presumably consumed by the base IssueManager's reporting — only the
    # highest level lists the per-class summaries.
    verbosity_levels = {
        0: [],
        1: [],
        2: [],
        3: ["classes_by_label_quality", "overlapping_classes"],
    }

    def __init__(
        self,
        datalab: Datalab,
        clean_learning_kwargs: Optional[Dict[str, Any]] = None,
        health_summary_parameters: Optional[Dict[str, Any]] = None,
        **_,
    ):
        super().__init__(datalab)
        self.cl = CleanLearning(**(clean_learning_kwargs or {}))
        # Copy so that later mutation (e.g. in ``find_issues``) never touches
        # the caller's dict.
        self.health_summary_parameters: Dict[str, Any] = (
            health_summary_parameters.copy() if health_summary_parameters else {}
        )
        self._reset()

    @staticmethod
    def _process_find_label_issues_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Select the keyword arguments meant for ``CleanLearning.find_label_issues``.

        Keys outside the accepted set, and keys whose value is ``None``, are dropped.

        Parameters
        ----------
        kwargs :
            Arbitrary keyword arguments, as a dict.

        Returns
        -------
        find_label_issues_kwargs :
            The filtered dict, in the original key order.

        Examples
        --------
        >>> from cleanlab.datalab.issue_manager.label import LabelIssueManager
        >>> LabelIssueManager._process_find_label_issues_kwargs(
        ...     {"thresholds": [0.1, 0.9], "unrelated": 1, "noise_matrix": None}
        ... )
        {'thresholds': [0.1, 0.9]}
        """
        accepted_kwargs = {
            "thresholds",
            "noise_matrix",
            "inverse_noise_matrix",
            "save_space",
            "clf_kwargs",
            "validation_func",
        }
        return {k: v for k, v in kwargs.items() if k in accepted_kwargs and v is not None}

    def _reset(self) -> None:
        """Reset the attributes of this manager based on the available datalab info
        and the keyword arguments stored as instance attributes.

        If no ``health_summary_parameters`` were supplied, defaults are filled in
        from the datalab's labels, label map and stored "statistics" info, which lets
        :py:meth:`find_issues` reuse pre-computed quantities. Entries whose value is
        ``None`` are always removed.
        """
        if not self.health_summary_parameters:
            statistics_dict = self.datalab.get_info("statistics")
            self.health_summary_parameters = {
                "labels": self.datalab._labels,
                "class_names": list(self.datalab._label_map.values()),
                "num_examples": statistics_dict.get("num_examples"),
                "joint": statistics_dict.get("joint", None),
                "confident_joint": statistics_dict.get("confident_joint", None),
                "multi_label": statistics_dict.get("multi_label", None),
                "asymmetric": statistics_dict.get("asymmetric", None),
                "verbose": False,
            }
        self.health_summary_parameters = {
            k: v for k, v in self.health_summary_parameters.items() if v is not None
        }

    def find_issues(
        self,
        pred_probs: np.ndarray,
        **kwargs,
    ) -> None:
        """Find label issues and populate ``self.issues``, ``self.summary`` and ``self.info``.

        Parameters
        ----------
        pred_probs :
            The predicted probabilities for each example. Used both for
            ``CleanLearning.find_label_issues`` and for the health summary.
        **kwargs :
            Options for ``CleanLearning.find_label_issues``. Only the keys accepted by
            :py:meth:`_process_find_label_issues_kwargs` are forwarded; the rest are
            ignored.
        """
        # Kept on the instance for reference; ``_get_summary_parameters`` uses the
        # ``pred_probs`` argument directly.
        self.health_summary_parameters.update({"pred_probs": pred_probs})

        # Find examples with label issues
        self.issues = self.cl.find_label_issues(
            labels=self.datalab._labels,
            pred_probs=pred_probs,
            **self._process_find_label_issues_kwargs(kwargs),
        )
        self.issues.rename(columns={"label_quality": self.issue_score_key}, inplace=True)

        summary_dict = self.get_health_summary(pred_probs=pred_probs)

        # Get a summarized dataframe of the label issues
        self.summary = self.make_summary(score=summary_dict["overall_label_health_score"])

        # Collect info about the label issues (this reads the label columns, so it
        # must run before they are dropped below).
        self.info = self.collect_info(issues=self.issues, summary_dict=summary_dict)

        # Drop columns from issues that are in the info
        self.issues = self.issues.drop(columns=["given_label", "predicted_label"])

    def get_health_summary(self, pred_probs) -> dict:
        """Return the output of ``cleanlab.dataset.health_summary`` for this datalab.

        ``pred_probs`` is validated against the datalab's labels first; the
        remaining arguments come from :py:meth:`_get_summary_parameters`.
        """
        from cleanlab.dataset import health_summary

        # Validate input
        self._validate_pred_probs(pred_probs)

        summary_kwargs = self._get_summary_parameters(pred_probs)
        summary = health_summary(**summary_kwargs)
        return summary

    def _get_summary_parameters(self, pred_probs) -> Dict[str, Any]:
        """Collects a set of input parameters for the health summary function based on
        any info available in the datalab.

        The source of the class-level statistics is chosen in priority order: a
        stored ``confident_joint``, else a stored ``joint`` together with
        ``num_examples``, else ``pred_probs`` and the datalab's labels.

        Parameters
        ----------
        pred_probs :
            The predicted probabilities for each example.

        Returns
        -------
        summary_parameters :
            A dictionary of parameters to pass to the health summary function.
        """
        params = self.health_summary_parameters

        if "confident_joint" in params:
            summary_parameters = {"confident_joint": params["confident_joint"]}
        elif all(x in params for x in ["joint", "num_examples"]):
            summary_parameters = {k: params[k] for k in ["joint", "num_examples"]}
        else:
            summary_parameters = {
                "pred_probs": pred_probs,
                "labels": self.datalab._labels,
            }

        class_names = params.get("class_names")
        if class_names is None:
            # ``_reset`` only fills in defaults when no health_summary_parameters were
            # supplied, so a user-provided dict may lack "class_names"; fall back to
            # the same default ``_reset`` would have used instead of raising KeyError.
            class_names = list(self.datalab._label_map.values())
        summary_parameters["class_names"] = class_names

        # Optional flags are forwarded only when present.
        for k in ["asymmetric", "verbose"]:
            if k in params:
                summary_parameters[k] = params[k]

        return (
            summary_parameters  # will be called in `dataset.health_summary(**summary_parameters)`
        )

    def collect_info(self, issues: pd.DataFrame, summary_dict: dict) -> dict:
        """Gather label-issue statistics into a single info dict.

        Parameters
        ----------
        issues :
            The issues dataframe returned by ``find_label_issues``, with the quality
            column already renamed to ``self.issue_score_key`` and still containing
            the "given_label" and "predicted_label" columns.
        summary_dict :
            The dict returned by :py:meth:`get_health_summary`.

        Returns
        -------
        info_dict :
            Issue counts and label lists, selected health-summary entries, and any of
            CleanLearning's ``py``, ``noise_matrix``, ``inverse_noise_matrix``,
            ``confident_joint`` instance attributes that exist.
        """
        issues_info = {
            "num_label_issues": sum(issues[f"is_{self.issue_name}_issue"]),
            "average_label_quality": issues[self.issue_score_key].mean(),
            "given_label": issues["given_label"].tolist(),
            "predicted_label": issues["predicted_label"].tolist(),
        }

        # NOTE: the health summary's "joint" is stored under "confident_joint"; if
        # CleanLearning has its own ``confident_joint`` attribute, the merge below
        # overrides it (cl_info is unpacked last).
        health_summary_info = {
            "confident_joint": summary_dict["joint"],
            "classes_by_label_quality": summary_dict["classes_by_label_quality"],
            "overlapping_classes": summary_dict["overlapping_classes"],
        }

        cl_keys = {"py", "noise_matrix", "inverse_noise_matrix", "confident_joint"}
        cl_info = {k: v for k, v in self.cl.__dict__.items() if k in cl_keys}

        info_dict = {
            **issues_info,
            **health_summary_info,
            **cl_info,
        }

        return info_dict

    def _validate_pred_probs(self, pred_probs) -> None:
        """Check ``pred_probs`` against the datalab's labels via ``assert_valid_inputs``."""
        assert_valid_inputs(X=None, y=self.datalab._labels, pred_probs=pred_probs)