Warning
This documentation is for an old version (v2.7.0) of cleanlab. To see the documentation for the latest stable version (v2.7.1), click here.
Source code for cleanlab.datalab.internal.adapter.imagelab
"""An internal wrapper around the Imagelab class from the CleanVision package to incorporate it into Datalab.
This allows low-quality images to be detected alongside other issues in computer vision datasets.
The methods/classes in this module are just intended for internal use.
"""
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, cast, Union
import numpy as np
import numpy.typing as npt
import pandas as pd
from scipy.sparse import csr_matrix
from cleanlab.datalab.internal.adapter.constants import (
DEFAULT_CLEANVISION_ISSUES,
IMAGELAB_ISSUES_MAX_PREVALENCE,
SPURIOUS_CORRELATION_ISSUE,
)
from cleanlab.datalab.internal.data import Data
from cleanlab.datalab.internal.data_issues import DataIssues, _InfoStrategy
from cleanlab.datalab.internal.issue_finder import IssueFinder
from cleanlab.datalab.internal.report import Reporter
from cleanlab.datalab.internal.task import Task
from cleanlab.datalab.internal.spurious_correlation import SpuriousCorrelations
if TYPE_CHECKING: # pragma: no cover
from cleanvision import Imagelab
from datasets.arrow_dataset import Dataset
def create_imagelab(dataset: "Dataset", image_key: Optional[str]) -> Optional["Imagelab"]:
    """Creates Imagelab instance for running CleanVision checks. CleanVision checks are only supported for
    huggingface datasets as of now.

    Parameters
    ----------
    dataset: datasets.Dataset
        Huggingface dataset used by Imagelab
    image_key: str, optional
        Key for image feature in the huggingface dataset. When it is ``None`` or empty,
        no Imagelab is created.

    Returns
    -------
    Optional[Imagelab]
        ``None`` if ``image_key`` is falsy, otherwise an Imagelab built from ``dataset``.

    Raises
    ------
    ImportError
        If the optional ``cleanvision`` / ``datasets`` packages cannot be imported.
    ValueError
        If ``dataset`` is not a huggingface ``Dataset``.
    """
    if not image_key:
        return None

    # Only the optional imports live inside the ``try`` so that the installation hint
    # below is shown exclusively for genuinely missing packages.
    try:
        from cleanvision import Imagelab
        from datasets.arrow_dataset import Dataset
    except ImportError as error:
        raise ImportError(
            "Cannot import required image packages. Please install them via: `pip install cleanlab[image]` or just install cleanlab with "
            "all optional dependencies via: `pip install cleanlab[all]`"
        ) from error

    if not isinstance(dataset, Dataset):
        raise ValueError(
            "For now, only huggingface datasets are supported for running cleanvision checks inside cleanlab. You can easily convert most datasets to the huggingface dataset format."
        )
    return Imagelab(hf_dataset=dataset, image_key=image_key)
class ImagelabDataIssuesAdapter(DataIssues):
    """
    Class that collects and stores information and statistics on issues found in a dataset,
    and that can merge in the results held by an Imagelab instance.

    Parameters
    ----------
    data :
        The data object for which the issues are being collected.
    strategy :
        Strategy used for processing info dictionaries.

    Attributes
    ----------
    issues : pd.DataFrame
        Stores information about each individual issue found in the data,
        on a per-example basis.
    issue_summary : pd.DataFrame
        Summarizes the overall statistics for each issue type.
    info : dict
        A dictionary that contains information and statistics about the data and each issue type.
    """

    def __init__(self, data: Data, strategy: Type[_InfoStrategy]) -> None:
        super().__init__(data, strategy)

    def _update_issues_imagelab(self, imagelab: "Imagelab", overlapping_issues: List[str]) -> None:
        """Replace the columns of overlapping issue types in ``self.issues`` and join in
        every column of ``imagelab.issues`` that ``self.issues`` does not have yet."""
        flag_columns = [f"is_{name}_issue" for name in overlapping_issues]
        score_columns = [f"{name}_score" for name in overlapping_issues]
        stale_columns = flag_columns + score_columns
        if stale_columns:
            warnings.warn(
                f"Overwriting columns {stale_columns} in self.issues with "
                f"columns from imagelab."
            )
            self.issues.drop(columns=stale_columns, inplace=True)
        fresh_columns = list(set(imagelab.issues.columns).difference(self.issues.columns))
        self.issues = self.issues.join(imagelab.issues[fresh_columns], how="outer")

    def filter_based_on_max_prevalence(self, issue_summary: pd.DataFrame, max_num: int):
        """Return a copy of ``issue_summary`` restricted to the rows whose ``num_images``
        does not exceed ``max_num``; the issue types that were dropped are printed."""
        exceeds_limit = issue_summary["num_images"] > max_num
        removed_issues = issue_summary[exceeds_limit]["issue_type"].tolist()
        if removed_issues:
            print(
                f"Removing {', '.join(removed_issues)} from potential issues in the dataset as it exceeds max_prevalence={IMAGELAB_ISSUES_MAX_PREVALENCE}"
            )
        within_limit = issue_summary["num_images"] <= max_num
        return issue_summary[within_limit].copy()

    def collect_issues_from_imagelab(self, imagelab: "Imagelab", issue_types: List[str]) -> None:
        """
        Collect results from Imagelab and update datalab.issues and datalab.issue_summary

        Parameters
        ----------
        imagelab: Imagelab
            Imagelab instance that run all the checks for image issue types
        issue_types:
            Names of the image issue types whose results are collected.
        """
        overlapping_issues = list(set(self.issue_summary["issue_type"]) & set(issue_types))
        self._update_issues_imagelab(imagelab, overlapping_issues)
        if overlapping_issues:
            warnings.warn(
                f"Overwriting {overlapping_issues} rows in self.issue_summary from imagelab."
            )
        is_overlapping = self.issue_summary["issue_type"].isin(overlapping_issues)
        self.issue_summary = self.issue_summary[~is_overlapping]

        # The prevalence cap is a fraction of the number of rows currently in self.issues.
        max_images = int(IMAGELAB_ISSUES_MAX_PREVALENCE * len(self.issues))
        imagelab_summary = self.filter_based_on_max_prevalence(
            imagelab.issue_summary.copy(), max_images
        )
        imagelab_summary = imagelab_summary.rename(columns={"num_images": "num_issues"})
        self.issue_summary = pd.concat(
            [self.issue_summary, imagelab_summary], axis=0, ignore_index=True
        )
        for issue_type in issue_types:
            self._update_issue_info(issue_type, imagelab.info[issue_type])

    def get_info(self, issue_name: Optional[str] = None) -> Dict[str, Any]:
        """Like the parent ``get_info``, except that ``"spurious_correlations"`` is read
        directly from ``self.info`` and raises ``ValueError`` when nothing is stored for it."""
        if issue_name == "spurious_correlations":
            correlations_info = self.info.get("spurious_correlations", {})
            if correlations_info:
                return correlations_info
            raise ValueError(
                "Spurious correlations have not been calculated. Run find_issues() first."
            )
        return super().get_info(issue_name)
class CorrelationVisualizer:
    """Class to visualize images corresponding to the extreme (minimum and maximum) individual
    scores for each of the detected correlated properties.

    Raises
    ------
    ImportError
        On construction, if ``cleanvision`` cannot be imported.
    """

    def __init__(self):
        # Wrapper for VizManager that's from the optional cleanvision dependency
        try:
            from cleanvision.utils.viz_manager import VizManager
        except ImportError as error:
            raise ImportError(
                "cleanvision is required for correlation visualization. Please install it to use this feature."
            ) from error
        self.viz_manager = VizManager

    def visualize(
        self, images: List, title_info: Dict, ncols: int = 2, cell_size: tuple = (2, 2)
    ) -> None:
        """Forward the given images and display options to ``viz_manager.individual_images``."""
        self.viz_manager.individual_images(
            images=images,
            title_info=title_info,
            ncols=ncols,
            cell_size=cell_size,
        )
class CorrelationReporter:
    """Class to report spurious correlations between image features and class labels detected in the data.

    If no spurious correlations are found, the class will not report anything.

    Raises
    ------
    ValueError
        On construction, if no spurious-correlation threshold is stored in ``data_issues``.
    """

    def __init__(self, data_issues: "DataIssues", imagelab: "Imagelab"):
        self.imagelab: "Imagelab" = imagelab
        self.data_issues = data_issues
        self.threshold = data_issues.get_info("spurious_correlations").get("threshold")
        # Only a missing threshold means "not computed"; 0 is a valid (if strict) threshold.
        if self.threshold is None:
            raise ValueError(
                "Spurious correlations have not been calculated. Run find_issues() first."
            )
        self.visualizer = CorrelationVisualizer()

    def report(self) -> None:
        """Reports spurious correlations between image features and class labels detected in the data,
        if any are found.
        """
        correlated_properties = self._get_correlated_properties()
        if not correlated_properties:
            return

        self._print_correlation_summary()

        correlations_df = cast(
            pd.DataFrame, self.data_issues.get_info("spurious_correlations").get("correlations_df")
        )
        filtered_correlations_df = self._get_filtered_correlated_properties(
            correlations_df, correlated_properties
        )
        print(filtered_correlations_df.to_string(index=False) + "\n")

        self._visualize_extremes(correlated_properties, self.data_issues)

    def _print_correlation_summary(self) -> None:
        """Print the header that introduces the table of correlation scores."""
        print("\n\n")
        report_correlation_header = "Summary of (potentially spurious) correlations between image properties and class labels detected in the data:\n\n"
        report_correlation_metric = "Lower scores below correspond to images properties that are more strongly correlated with the class labels.\n\n"
        print(report_correlation_header + report_correlation_metric)

    def _visualize_extremes(
        self, correlated_properties: List[str], data_issues: "DataIssues"
    ) -> None:
        """Show, for every correlated property, the images with the lowest and highest score."""
        report_extremal_images = "Here are the images corresponding to the extreme (minimum and maximum) individual scores for each of the detected correlated properties:\n\n"
        print(report_extremal_images)
        issues = data_issues.get_issues()
        correlated_indices = {
            prop: [issues[prop].idxmin(), issues[prop].idxmax()] for prop in correlated_properties
        }
        self._visualize(correlated_indices, issues)

    def _visualize(self, correlated_indices: Dict[str, List[Any]], issues: pd.DataFrame) -> None:
        """Display the images identified by ``correlated_indices``, titled with their scores."""
        for prop, image_ids in correlated_indices.items():
            print(
                f"Images with minimum and maximum individual scores for {prop.replace('_score', '')} issue:\n"
            )
            title_info = {
                "scores": [f"score: {issues.loc[image_id, prop]:.4f}" for image_id in image_ids]
            }
            self.visualizer.visualize(
                images=[self.imagelab._dataset[image_id] for image_id in image_ids],
                title_info=title_info,
            )

    def _get_correlated_properties(self) -> List[str]:
        """Return the properties whose correlation score lies strictly below the threshold."""
        correlations_df = self.data_issues.get_info("spurious_correlations").get("correlations_df")
        if correlations_df is None or correlations_df.empty:
            return []
        below_threshold = correlations_df["score"] < self.threshold
        return correlations_df.loc[below_threshold, "property"].tolist()

    def _get_filtered_correlated_properties(
        self, correlations_df: pd.DataFrame, correlated_properties: List[str]
    ) -> pd.DataFrame:
        """Return the rows of ``correlations_df`` for ``correlated_properties`` with the
        ``"_score"`` text removed from the property names. The input frame is left untouched."""
        selected = correlations_df["property"].isin(correlated_properties)
        # Work on an explicit copy so the assignment below neither warns about chained
        # assignment nor writes through to the caller's frame.
        filtered_correlations_df = correlations_df[selected].copy()
        filtered_correlations_df["property"] = filtered_correlations_df["property"].apply(
            lambda name: name.replace("_score", "")
        )
        return filtered_correlations_df
class ImagelabReporterAdapter(Reporter):
    """Reporter that, after the regular report, also prints the Imagelab report and,
    when one could be set up, the spurious-correlation report."""

    def __init__(
        self,
        data_issues: "DataIssues",
        imagelab: "Imagelab",
        task: Task,
        verbosity: int = 1,
        include_description: bool = True,
        show_summary_score: bool = False,
        show_all_issues: bool = False,
    ):
        super().__init__(
            data_issues=data_issues,
            task=task,
            verbosity=verbosity,
            include_description=include_description,
            show_summary_score=show_summary_score,
            show_all_issues=show_all_issues,
        )
        self.imagelab = imagelab
        self.correlation_reporter: Optional[CorrelationReporter] = None
        try:
            self.correlation_reporter = CorrelationReporter(data_issues, imagelab)
        except Exception:
            # Spurious correlations have not been calculated (or cannot be reported);
            # reporting them is best-effort. ``Exception`` rather than a bare ``except``
            # keeps KeyboardInterrupt / SystemExit propagating.
            self.correlation_reporter = None

    def report(self, num_examples: int) -> None:
        """Print the regular report, then the Imagelab report, then any spurious correlations."""
        super().report(num_examples)
        self._report_imagelab(num_examples)
        # Only report spurious correlations if they've been calculated & detected
        if self.correlation_reporter is not None:
            self.correlation_reporter.report()

    def _report_imagelab(self, num_examples):
        """Print the Imagelab report for up to ``num_examples`` images per issue type."""
        print("\n\n")
        self.imagelab.report(
            num_images=num_examples,
            max_prevalence=IMAGELAB_ISSUES_MAX_PREVALENCE,
            print_summary=False,
            verbosity=0,
            show_id=True,
        )
class ImagelabIssueFinderAdapter(IssueFinder):
    """IssueFinder that additionally runs the Imagelab checks and the spurious-correlation
    check on top of the regular issue search."""

    def __init__(self, datalab, task, verbosity):
        super().__init__(datalab, task, verbosity)
        self.imagelab = self.datalab._imagelab

    def _get_imagelab_issue_types(self, issue_types, **kwargs):
        """Resolve the image issue types to run.

        Returns the defaults when ``issue_types`` is ``None``, ``None`` when no
        ``"image_issue_types"`` entry is given, and otherwise the requested types with
        falsy parameters replaced by the defaults of that type.
        """
        if issue_types is None:
            return DEFAULT_CLEANVISION_ISSUES
        if "image_issue_types" not in issue_types:
            return None
        return {
            name: params if params else DEFAULT_CLEANVISION_ISSUES[name]
            for name, params in issue_types["image_issue_types"].items()
        }

    def find_issues(
        self,
        *,
        pred_probs: Optional[np.ndarray] = None,
        features: Optional[npt.NDArray] = None,
        knn_graph: Optional[csr_matrix] = None,
        issue_types: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run the regular checks, then the image checks, then the spurious-correlation check.

        Errors raised by the image checks or by the spurious-correlation computation are
        printed instead of propagated.
        """
        # These two keys are handled here and must not reach the parent implementation.
        handled_here = ["image_issue_types", "spurious_correlations"]
        if issue_types:
            datalab_issue_types = {
                name: params for name, params in issue_types.items() if name not in handled_here
            }
        else:
            datalab_issue_types = issue_types
        super().find_issues(
            pred_probs=pred_probs,
            features=features,
            knn_graph=knn_graph,
            issue_types=datalab_issue_types,
        )

        image_issue_types = self._get_imagelab_issue_types(issue_types)
        if image_issue_types:
            try:
                if self.verbosity:
                    print(f'Finding {", ".join(image_issue_types.keys())} images ...')

                self.imagelab.find_issues(issue_types=image_issue_types, verbose=False)

                self.datalab.data_issues.collect_statistics(self.imagelab)
                self.datalab.data_issues.collect_issues_from_imagelab(
                    self.imagelab, image_issue_types.keys()
                )
            except Exception as e:
                print(f"Error in checking for image issues: {e}")

        # A non-empty issue_types that does not mention 'spurious_correlations'
        # means the caller did not ask for that check.
        if issue_types and "spurious_correlations" not in issue_types:
            return

        # The check needs at least one of the default image property scores.
        available_columns = self.imagelab.issues.columns.tolist()
        has_any_score = any(
            issue_name + "_score" in available_columns
            for issue_name in DEFAULT_CLEANVISION_ISSUES.keys()
        )
        if not has_any_score:
            print("Skipping spurious correlations check: Image property scores not available.")
            print(
                "To include this check, run find_issues() without parameters to compute all scores."
            )
            return

        # Spurious correlation part must be run
        print("Finding spurious correlation issues in the dataset ...")

        # A truthy issue_types is guaranteed to contain the 'spurious_correlations' key here.
        if issue_types:
            correlation_params = issue_types["spurious_correlations"]
        else:
            correlation_params = SPURIOUS_CORRELATION_ISSUE["spurious_correlations"]

        # If threshold is not explicitly given (e.g. lab.find_issues(issue_types={"spurious_correlations": {}}))
        # we extract the default value from SPURIOUS_CORRELATION_ISSUE
        threshold = correlation_params.get(
            "threshold", SPURIOUS_CORRELATION_ISSUE["spurious_correlations"]["threshold"]
        )
        try:
            if self.datalab.has_labels:
                self.datalab.data_issues.info["spurious_correlations"] = (
                    handle_spurious_correlations(
                        imagelab_issues=self.imagelab.issues,
                        labels=self.datalab.labels,
                        threshold=threshold,
                    )
                )
        except Exception as e:
            print(f"Error in checking for spurious correlations: {e}")
def handle_spurious_correlations(
    *,
    imagelab_issues: pd.DataFrame,
    labels: Union[np.ndarray, List[List[int]]],
    threshold: float,
    **_,
) -> Dict[str, Any]:
    """Compute correlations between the ``*_score`` columns of ``imagelab_issues`` and ``labels``.

    Returns
    -------
    dict
        ``"correlations_df"`` holds the result of ``SpuriousCorrelations.calculate_correlations()``
        and ``"threshold"`` echoes the given ``threshold``. Extra keyword arguments are ignored.
    """
    score_columns = [
        column for column in imagelab_issues.columns.tolist() if column.endswith("_score")
    ]
    calculator = SpuriousCorrelations(data=imagelab_issues[score_columns], labels=labels)
    result: Dict[str, Any] = {}
    result["correlations_df"] = calculator.calculate_correlations()
    result["threshold"] = threshold
    return result