"""An internal wrapper around the Imagelab class from the CleanVision package to incorporate it into Datalab.
This allows low-quality images to be detected alongside other issues in computer vision datasets.
The methods/classes in this module are just intended for internal use.
"""
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, cast, Union
import numpy as np
import numpy.typing as npt
import pandas as pd
from scipy.sparse import csr_matrix
from cleanlab.datalab.internal.adapter.constants import (
DEFAULT_CLEANVISION_ISSUES,
IMAGELAB_ISSUES_MAX_PREVALENCE,
SPURIOUS_CORRELATION_ISSUE,
)
from cleanlab.datalab.internal.data import Data
from cleanlab.datalab.internal.data_issues import DataIssues, _InfoStrategy
from cleanlab.datalab.internal.issue_finder import IssueFinder
from cleanlab.datalab.internal.report import Reporter
from cleanlab.datalab.internal.task import Task
from cleanlab.datalab.internal.spurious_correlation import SpuriousCorrelations
if TYPE_CHECKING: # pragma: no cover
from cleanvision import Imagelab
from datasets.arrow_dataset import Dataset
def create_imagelab(dataset: "Dataset", image_key: Optional[str]) -> Optional["Imagelab"]:
    """Creates Imagelab instance for running CleanVision checks. CleanVision checks are only supported for
    huggingface datasets as of now.

    Parameters
    ----------
    dataset: datasets.Dataset
        Huggingface dataset used by Imagelab
    image_key: str
        key for image feature in the huggingface dataset. When it is ``None`` or empty,
        no Imagelab is created.

    Returns
    -------
    Imagelab
        The Imagelab instance wrapping ``dataset``, or ``None`` when ``image_key`` is falsy.

    Raises
    ------
    ImportError
        If the optional image packages (``cleanvision``, ``datasets``) cannot be imported.
    ValueError
        If ``dataset`` is not a huggingface ``Dataset``.
    """
    if not image_key:
        return None

    try:
        # Both packages are optional dependencies, so they are imported lazily here.
        from cleanvision import Imagelab
        from datasets.arrow_dataset import Dataset

        if not isinstance(dataset, Dataset):
            raise ValueError(
                "For now, only huggingface datasets are supported for running cleanvision checks inside cleanlab. You can easily convert most datasets to the huggingface dataset format."
            )
        return Imagelab(hf_dataset=dataset, image_key=image_key)
    except ImportError as error:
        # Chain the original error so the package that actually failed to import stays visible.
        raise ImportError(
            "Cannot import required image packages. Please install them via: `pip install cleanlab[image]` or just install cleanlab with "
            "all optional dependencies via: `pip install cleanlab[all]`"
        ) from error
class ImagelabDataIssuesAdapter(DataIssues):
    """
    Class that collects and stores information and statistics on issues found in a dataset,
    extended with helpers that merge the results of an Imagelab run into it.

    Parameters
    ----------
    data :
        The data object for which the issues are being collected.

    strategy :
        Strategy used for processing info dictionaries.

    Attributes
    ----------
    issues : pd.DataFrame
        Stores information about each individual issue found in the data,
        on a per-example basis.
    issue_summary : pd.DataFrame
        Summarizes the overall statistics for each issue type.
    info : dict
        A dictionary that contains information and statistics about the data and each issue type.
    """

    def __init__(self, data: Data, strategy: Type[_InfoStrategy]) -> None:
        super().__init__(data, strategy)

    def _update_issues_imagelab(self, imagelab: "Imagelab", overlapping_issues: List[str]) -> None:
        """Join the per-example Imagelab columns into ``self.issues``.

        Columns belonging to ``overlapping_issues`` (``is_<issue>_issue`` and ``<issue>_score``)
        are dropped from ``self.issues`` first, with a warning, so the Imagelab versions replace them.
        """
        flag_columns = [f"is_{name}_issue" for name in overlapping_issues]
        score_columns = [f"{name}_score" for name in overlapping_issues]
        stale_columns = flag_columns + score_columns

        if stale_columns:
            warnings.warn(
                f"Overwriting columns {stale_columns} in self.issues with columns from imagelab."
            )
            self.issues.drop(columns=stale_columns, inplace=True)

        # Only bring over the columns that self.issues does not already contain.
        fresh_columns = list(set(imagelab.issues.columns).difference(self.issues.columns))
        self.issues = self.issues.join(imagelab.issues[fresh_columns], how="outer")

    def filter_based_on_max_prevalence(self, issue_summary: pd.DataFrame, max_num: int):
        """Return a copy of ``issue_summary`` without the rows whose ``num_images`` exceeds ``max_num``.

        The names of the removed issue types, if any, are printed.
        """
        exceeds_limit = issue_summary["num_images"] > max_num
        removed_issues = issue_summary[exceeds_limit]["issue_type"].tolist()
        if removed_issues:
            print(
                f"Removing {', '.join(removed_issues)} from potential issues in the dataset as it exceeds max_prevalence={IMAGELAB_ISSUES_MAX_PREVALENCE}"
            )
        within_limit = issue_summary["num_images"] <= max_num
        return issue_summary[within_limit].copy()

    def collect_issues_from_imagelab(self, imagelab: "Imagelab", issue_types: List[str]) -> None:
        """
        Collect results from Imagelab and update datalab.issues and datalab.issue_summary

        Parameters
        ----------
        imagelab: Imagelab
            Imagelab instance that run all the checks for image issue types
        issue_types:
            Names of the image issue types whose results and info are collected.
        """
        already_summarized = set(self.issue_summary["issue_type"])
        overlapping_issues = list(already_summarized & set(issue_types))
        self._update_issues_imagelab(imagelab, overlapping_issues)

        if overlapping_issues:
            warnings.warn(
                f"Overwriting {overlapping_issues} rows in self.issue_summary from imagelab."
            )
        is_overlapping = self.issue_summary["issue_type"].isin(overlapping_issues)
        self.issue_summary = self.issue_summary[~is_overlapping]

        # Drop overly prevalent issue types, then align the count column name with issue_summary.
        max_num_images = int(IMAGELAB_ISSUES_MAX_PREVALENCE * len(self.issues))
        imagelab_summary = self.filter_based_on_max_prevalence(
            imagelab.issue_summary.copy(), max_num_images
        )
        imagelab_summary = imagelab_summary.rename(columns={"num_images": "num_issues"})
        self.issue_summary = pd.concat(
            [self.issue_summary, imagelab_summary], axis=0, ignore_index=True
        )

        for issue_type in issue_types:
            self._update_issue_info(issue_type, imagelab.info[issue_type])

    def get_info(self, issue_name: Optional[str] = None) -> Dict[str, Any]:
        """Fetch info for ``issue_name``, additionally supporting ``"spurious_correlations"``.

        Raises
        ------
        ValueError
            If ``"spurious_correlations"`` is requested but no (non-empty) info is stored for it.
        """
        if issue_name == "spurious_correlations":
            correlations_info = self.info.get("spurious_correlations", {})
            if correlations_info:
                return correlations_info
            raise ValueError(
                "Spurious correlations have not been calculated. Run find_issues() first."
            )
        return super().get_info(issue_name)
class CorrelationVisualizer:
    """Class to visualize images corresponding to the extreme (minimum and maximum) individual
    scores for each of the detected correlated properties.
    """

    def __init__(self):
        """Bind ``VizManager`` from the optional cleanvision dependency.

        Raises
        ------
        ImportError
            If cleanvision's ``VizManager`` cannot be imported.
        """
        # Wrapper for VizManager that's from the optional cleanvision dependency
        try:
            from cleanvision.utils.viz_manager import VizManager
        except ImportError as error:
            # Chain the original error so the underlying import failure is not lost.
            raise ImportError(
                "cleanvision is required for correlation visualization. Please install it to use this feature."
            ) from error
        self.viz_manager = VizManager

    def visualize(
        self, images: List, title_info: Dict, ncols: int = 2, cell_size: tuple = (2, 2)
    ) -> None:
        """Forward the images and their title info to ``viz_manager.individual_images``.

        Parameters
        ----------
        images:
            Images to display.
        title_info:
            Title information passed through unchanged to the viz manager.
        ncols:
            Number of columns, passed through to the viz manager.
        cell_size:
            Cell size, passed through to the viz manager.
        """
        self.viz_manager.individual_images(
            images=images,
            title_info=title_info,
            ncols=ncols,
            cell_size=cell_size,
        )
class CorrelationReporter:
    """Class to report spurious correlations between image features and class labels detected in the data.

    If no spurious correlations are found, the class will not report anything.
    """

    def __init__(self, data_issues: "DataIssues", imagelab: "Imagelab"):
        """Read the correlation threshold from ``data_issues`` and set up the visualizer.

        Raises
        ------
        ValueError
            If no threshold is stored in the ``"spurious_correlations"`` info.
        """
        self.imagelab: "Imagelab" = imagelab
        self.data_issues = data_issues
        self.threshold = data_issues.get_info("spurious_correlations").get("threshold")
        # Compare against None: a threshold of 0 is falsy but still means the check has run.
        if self.threshold is None:
            raise ValueError(
                "Spurious correlations have not been calculated. Run find_issues() first."
            )
        self.visualizer = CorrelationVisualizer()

    def report(self) -> None:
        """Reports spurious correlations between image features and class labels detected in the data,
        if any are found.
        """
        correlated_properties = self._get_correlated_properties()
        if not correlated_properties:
            return

        self._print_correlation_summary()

        correlations_df = cast(
            pd.DataFrame, self.data_issues.get_info("spurious_correlations").get("correlations_df")
        )
        filtered_correlations_df = self._get_filtered_correlated_properties(
            correlations_df, correlated_properties
        )
        print(filtered_correlations_df.to_string(index=False) + "\n")

        self._visualize_extremes(correlated_properties, self.data_issues)

    def _print_correlation_summary(self) -> None:
        """Print the header that precedes the table of correlated properties."""
        print("\n\n")
        report_correlation_header = "Summary of (potentially spurious) correlations between image properties and class labels detected in the data:\n\n"
        report_correlation_metric = "Lower scores below correspond to images properties that are more strongly correlated with the class labels.\n\n"
        print(report_correlation_header + report_correlation_metric)

    def _visualize_extremes(
        self, correlated_properties: List[str], data_issues: "DataIssues"
    ) -> None:
        """Show, for each correlated property, the examples with the minimum and maximum score."""
        report_extremal_images = "Here are the images corresponding to the extreme (minimum and maximum) individual scores for each of the detected correlated properties:\n\n"
        print(report_extremal_images)
        issues = data_issues.get_issues()
        # Index labels of the lowest- and highest-scoring example per property column.
        correlated_indices = {
            prop: [issues[prop].idxmin(), issues[prop].idxmax()] for prop in correlated_properties
        }
        self._visualize(correlated_indices, issues)

    def _visualize(self, correlated_indices: Dict[str, List[Any]], issues: pd.DataFrame) -> None:
        """Display the images at the given indices, titled with their individual scores."""
        for prop, image_ids in correlated_indices.items():
            issue_name = prop.replace("_score", "")
            print(
                f"Images with minimum and maximum individual scores for {issue_name} issue:\n"
            )
            title_info = {
                "scores": [f"score: {issues.loc[image_id, prop]:.4f}" for image_id in image_ids]
            }
            self.visualizer.visualize(
                images=[self.imagelab._dataset[image_id] for image_id in image_ids],
                title_info=title_info,
            )

    def _get_correlated_properties(self) -> List[str]:
        """Return the properties whose correlation score is strictly below ``self.threshold``."""
        correlations_df = self.data_issues.get_info("spurious_correlations").get("correlations_df")
        if correlations_df is None or correlations_df.empty:
            return []
        below_threshold = correlations_df["score"] < self.threshold
        return correlations_df.loc[below_threshold, "property"].tolist()

    def _get_filtered_correlated_properties(
        self, correlations_df: pd.DataFrame, correlated_properties: List[str]
    ) -> pd.DataFrame:
        """Return the rows for ``correlated_properties`` with ``"_score"`` removed from their names.

        The input frame is left untouched; the renaming happens on a copy.
        """
        is_correlated = correlations_df["property"].isin(correlated_properties)
        # Work on an explicit copy so the assignment below neither warns nor leaks into the input.
        filtered_correlations_df = correlations_df.loc[is_correlated].copy()
        filtered_correlations_df["property"] = filtered_correlations_df["property"].apply(
            lambda name: name.replace("_score", "")
        )
        return filtered_correlations_df
class ImagelabReporterAdapter(Reporter):
    """Reporter that also prints the Imagelab report and, when available, the
    spurious-correlations report after the regular report.
    """

    def __init__(
        self,
        data_issues: "DataIssues",
        imagelab: "Imagelab",
        task: Task,
        verbosity: int = 1,
        include_description: bool = True,
        show_summary_score: bool = False,
        show_all_issues: bool = False,
    ):
        super().__init__(
            data_issues=data_issues,
            task=task,
            verbosity=verbosity,
            include_description=include_description,
            show_summary_score=show_summary_score,
            show_all_issues=show_all_issues,
        )
        self.imagelab = imagelab
        self.correlation_reporter: Optional[CorrelationReporter] = None
        try:
            self.correlation_reporter = CorrelationReporter(data_issues, imagelab)
        except Exception:
            # Spurious correlations have not been calculated (or cannot be reported).
            # `except Exception` rather than a bare `except` so KeyboardInterrupt and
            # SystemExit still propagate.
            self.correlation_reporter = None

    def report(self, num_examples: int) -> None:
        """Print the regular report, then the Imagelab report, then spurious correlations if any.

        Parameters
        ----------
        num_examples:
            Forwarded to the parent report and used as the number of images in the Imagelab report.
        """
        super().report(num_examples)
        self._report_imagelab(num_examples)
        # Only report spurious correlations if they've been calculated & detected
        if self.correlation_reporter is not None:
            self.correlation_reporter.report()

    def _report_imagelab(self, num_examples: int) -> None:
        """Print the Imagelab report for ``num_examples`` images."""
        print("\n\n")
        self.imagelab.report(
            num_images=num_examples,
            max_prevalence=IMAGELAB_ISSUES_MAX_PREVALENCE,
            print_summary=False,
            verbosity=0,
            show_id=True,
        )
class ImagelabIssueFinderAdapter(IssueFinder):
    """Issue finder that, after the regular checks, also runs the Imagelab image checks
    and the spurious-correlations check.
    """

    def __init__(self, datalab, task, verbosity):
        super().__init__(datalab, task, verbosity)
        self.imagelab = self.datalab._imagelab

    def _get_imagelab_issue_types(self, issue_types, **kwargs):
        """Resolve which image issue types (and parameters) Imagelab should check.

        Returns ``DEFAULT_CLEANVISION_ISSUES`` when ``issue_types`` is ``None``, ``None`` when
        ``issue_types`` has no ``"image_issue_types"`` entry, and otherwise a new dict in which
        issue types given without parameters fall back to their default parameters.
        """
        if issue_types is None:
            return DEFAULT_CLEANVISION_ISSUES
        if "image_issue_types" not in issue_types:
            return None
        return {
            issue_type: params if params else DEFAULT_CLEANVISION_ISSUES[issue_type]
            for issue_type, params in issue_types["image_issue_types"].items()
        }

    def find_issues(
        self,
        *,
        pred_probs: Optional[np.ndarray] = None,
        features: Optional[npt.NDArray] = None,
        knn_graph: Optional[csr_matrix] = None,
        issue_types: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run the regular checks, then the image checks, then the spurious-correlations check.

        Parameters
        ----------
        pred_probs, features, knn_graph:
            Forwarded unchanged to the parent ``find_issues``.
        issue_types:
            Issue types to check. The ``"image_issue_types"`` and ``"spurious_correlations"``
            entries are handled here and are not forwarded to the parent ``find_issues``.

        Notes
        -----
        Errors raised by the image checks or by the spurious-correlations computation are
        printed rather than propagated.
        """
        issue_types_to_ignore_in_datalab = ["image_issue_types", "spurious_correlations"]
        datalab_issue_types = (
            {k: v for k, v in issue_types.items() if k not in issue_types_to_ignore_in_datalab}
            if issue_types
            else issue_types
        )
        super().find_issues(
            pred_probs=pred_probs,
            features=features,
            knn_graph=knn_graph,
            issue_types=datalab_issue_types,
        )

        issue_types_copy = self._get_imagelab_issue_types(issue_types)
        if issue_types_copy:
            try:
                if self.verbosity:
                    print(f'Finding {", ".join(issue_types_copy.keys())} images ...')

                self.imagelab.find_issues(issue_types=issue_types_copy, verbose=False)

                self.datalab.data_issues.collect_statistics(self.imagelab)
                # Pass a real list: the collector declares `issue_types: List[str]`.
                self.datalab.data_issues.collect_issues_from_imagelab(
                    self.imagelab, list(issue_types_copy)
                )
            except Exception as e:
                print(f"Error in checking for image issues: {e}")

        # if issue_types is neither 'None' nor empty dictionary (non-trivial) but
        # there is no mention of 'spurious_correlations', we return.
        if issue_types and "spurious_correlations" not in issue_types:
            return

        # Skip the check only when none of the default image property scores are available.
        imagelab_columns = self.imagelab.issues.columns.tolist()
        if all(
            default_cleanvision_issue + "_score" not in imagelab_columns
            for default_cleanvision_issue in DEFAULT_CLEANVISION_ISSUES.keys()
        ):
            print("Skipping spurious correlations check: Image property scores not available.")
            print(
                "To include this check, run find_issues() without parameters to compute all scores."
            )
            return

        # Spurious correlation part must be run
        print("Finding spurious correlation issues in the dataset ...")

        default_spurious_params = SPURIOUS_CORRELATION_ISSUE["spurious_correlations"]
        # At this point a non-empty issue_types must contain the 'spurious_correlations' key.
        requested_spurious_params = issue_types["spurious_correlations"] if issue_types else None
        # Missing parameters (None or an empty dict) fall back to the defaults instead of
        # failing on `.get`, mirroring how image issue types without parameters are handled.
        spurious_correlation_issue_types = requested_spurious_params or default_spurious_params

        # If threshold is not explicitly given (e.g. lab.find_issues(issue_types={"spurious_correlations": {}}))
        # we extract the default value from SPURIOUS_CORRELATION_ISSUE
        spurious_correlation_issue_threshold = spurious_correlation_issue_types.get(
            "threshold", default_spurious_params["threshold"]
        )
        try:
            if self.datalab.has_labels:
                self.datalab.data_issues.info["spurious_correlations"] = (
                    handle_spurious_correlations(
                        imagelab_issues=self.imagelab.issues,
                        labels=self.datalab.labels,
                        threshold=spurious_correlation_issue_threshold,
                    )
                )
        except Exception as e:
            print(f"Error in checking for spurious correlations: {e}")
def handle_spurious_correlations(
    *,
    imagelab_issues: pd.DataFrame,
    labels: Union[np.ndarray, List[List[int]]],
    threshold: float,
    **_,
) -> Dict[str, Any]:
    """Compute correlations between the image score columns and the labels.

    Parameters
    ----------
    imagelab_issues:
        Issues dataframe; only its columns whose name ends with ``"_score"`` are used.
    labels:
        Labels handed to ``SpuriousCorrelations`` together with the score columns.
    threshold:
        Returned unchanged alongside the computed correlations.
    **_:
        Any additional keyword arguments are accepted and ignored.

    Returns
    -------
    dict
        ``"correlations_df"`` holds the result of ``SpuriousCorrelations.calculate_correlations()``
        and ``"threshold"`` holds the given threshold.
    """
    score_columns = [
        name for name in imagelab_issues.columns.tolist() if name.endswith("_score")
    ]
    correlation_finder = SpuriousCorrelations(data=imagelab_issues[score_columns], labels=labels)
    info: Dict[str, Any] = {
        "correlations_df": correlation_finder.calculate_correlations(),
        "threshold": threshold,
    }
    return info