Source code for cleanlab.datalab.internal.adapter.imagelab

"""An internal wrapper around the Imagelab class from the CleanVision package to incorporate it into Datalab.
This allows low-quality images to be detected alongside other issues in computer vision datasets.
The methods/classes in this module are just intended for internal use.
"""

import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, cast, Union

import numpy as np
import numpy.typing as npt
import pandas as pd
from scipy.sparse import csr_matrix

from cleanlab.datalab.internal.adapter.constants import (
    DEFAULT_CLEANVISION_ISSUES,
    IMAGELAB_ISSUES_MAX_PREVALENCE,
    SPURIOUS_CORRELATION_ISSUE,
)
from cleanlab.datalab.internal.data import Data
from cleanlab.datalab.internal.data_issues import DataIssues, _InfoStrategy
from cleanlab.datalab.internal.issue_finder import IssueFinder
from cleanlab.datalab.internal.report import Reporter
from cleanlab.datalab.internal.task import Task
from cleanlab.datalab.internal.spurious_correlation import SpuriousCorrelations

if TYPE_CHECKING:  # pragma: no cover
    from cleanvision import Imagelab
    from datasets.arrow_dataset import Dataset


def create_imagelab(dataset: "Dataset", image_key: Optional[str]) -> Optional["Imagelab"]:
    """Create an Imagelab instance for running CleanVision checks.

    CleanVision checks are only supported for huggingface datasets as of now.

    Parameters
    ----------
    dataset: datasets.Dataset
        Huggingface dataset used by Imagelab.

    image_key: str, optional
        Key for the image feature in the huggingface dataset.
        When it is ``None`` or empty, no Imagelab is created.

    Returns
    -------
    Imagelab or None
        An Imagelab built from ``dataset`` and ``image_key``, or ``None`` when
        no ``image_key`` was given.

    Raises
    ------
    ImportError
        If ``cleanvision`` or ``datasets`` cannot be imported.
    ValueError
        If ``dataset`` is not a huggingface ``Dataset``.
    """
    if not image_key:
        return None

    # The image dependencies are optional, so they are imported lazily and
    # only the imports themselves are guarded.
    try:
        from cleanvision import Imagelab
        from datasets.arrow_dataset import Dataset
    except ImportError as error:
        raise ImportError(
            "Cannot import required image packages. Please install them via: `pip install cleanlab[image]` or just install cleanlab with "
            "all optional dependencies via: `pip install cleanlab[all]`"
        ) from error

    if not isinstance(dataset, Dataset):
        raise ValueError(
            "For now, only huggingface datasets are supported for running cleanvision checks inside cleanlab. You can easily convert most datasets to the huggingface dataset format."
        )
    return Imagelab(hf_dataset=dataset, image_key=image_key)
class ImagelabDataIssuesAdapter(DataIssues):
    """
    Class that collects and stores information and statistics on issues found in a dataset.

    Parameters
    ----------
    data :
        The data object for which the issues are being collected.

    strategy :
        Strategy used for processing info dictionaries.

    Attributes
    ----------
    issues : pd.DataFrame
        Stores information about each individual issue found in the data,
        on a per-example basis.
    issue_summary : pd.DataFrame
        Summarizes the overall statistics for each issue type.
    info : dict
        A dictionary that contains information and statistics about the data and each issue type.
    """

    def __init__(self, data: Data, strategy: Type[_InfoStrategy]) -> None:
        super().__init__(data, strategy)

    def _update_issues_imagelab(self, imagelab: "Imagelab", overlapping_issues: List[str]) -> None:
        """Merge the per-example columns of ``imagelab.issues`` into ``self.issues``.

        The ``is_<issue>_issue`` and ``<issue>_score`` columns of every issue type
        in ``overlapping_issues`` are dropped from ``self.issues`` first (with a
        warning), so the Imagelab columns replace them.
        """
        overwrite_columns = [f"is_{issue_type}_issue" for issue_type in overlapping_issues]
        overwrite_columns.extend([f"{issue_type}_score" for issue_type in overlapping_issues])

        if overwrite_columns:
            warnings.warn(
                f"Overwriting columns {overwrite_columns} in self.issues with "
                f"columns from imagelab."
            )
            self.issues.drop(columns=overwrite_columns, inplace=True)

        # Keep the column order of imagelab.issues so the joined frame has a
        # deterministic layout (a set difference would not guarantee one).
        existing_columns = set(self.issues.columns)
        new_columns = [
            column for column in imagelab.issues.columns if column not in existing_columns
        ]
        self.issues = self.issues.join(imagelab.issues[new_columns], how="outer")

    def filter_based_on_max_prevalence(
        self, issue_summary: pd.DataFrame, max_num: int
    ) -> pd.DataFrame:
        """Return a copy of ``issue_summary`` without the too-prevalent issue types.

        Rows whose ``num_images`` exceeds ``max_num`` are removed, and the removed
        issue types are printed.
        """
        too_prevalent = issue_summary["num_images"] > max_num
        removed_issues = issue_summary.loc[too_prevalent, "issue_type"].tolist()
        if removed_issues:
            print(
                f"Removing {', '.join(removed_issues)} from potential issues in the dataset as it exceeds max_prevalence={IMAGELAB_ISSUES_MAX_PREVALENCE}"
            )
        return issue_summary[issue_summary["num_images"] <= max_num].copy()

    def collect_issues_from_imagelab(self, imagelab: "Imagelab", issue_types: List[str]) -> None:
        """
        Collect results from Imagelab and update datalab.issues and datalab.issue_summary

        Parameters
        ----------
        imagelab: Imagelab
            Imagelab instance that run all the checks for image issue types

        issue_types: List[str]
            Names of the image issue types whose results are collected. Each
            name is looked up in ``imagelab.info``.
        """
        # Preserve the caller's order so the warning text and the dropped
        # columns are deterministic.
        summarized_issue_types = set(self.issue_summary["issue_type"])
        overlapping_issues = [
            issue_type
            for issue_type in dict.fromkeys(issue_types)
            if issue_type in summarized_issue_types
        ]
        self._update_issues_imagelab(imagelab, overlapping_issues)

        if overlapping_issues:
            warnings.warn(
                f"Overwriting {overlapping_issues} rows in self.issue_summary from imagelab."
            )
        # With no overlap this keeps every row, so it is safe to run unconditionally.
        self.issue_summary = self.issue_summary[
            ~self.issue_summary["issue_type"].isin(overlapping_issues)
        ]

        imagelab_summary_copy = self.filter_based_on_max_prevalence(
            imagelab.issue_summary.copy(),
            int(IMAGELAB_ISSUES_MAX_PREVALENCE * len(self.issues)),
        )
        # issue_summary uses the column name "num_issues" where Imagelab uses "num_images".
        imagelab_summary_copy.rename({"num_images": "num_issues"}, axis=1, inplace=True)
        self.issue_summary = pd.concat(
            [self.issue_summary, imagelab_summary_copy], axis=0, ignore_index=True
        )
        for issue_type in issue_types:
            self._update_issue_info(issue_type, imagelab.info[issue_type])

    def get_info(self, issue_name: Optional[str] = None) -> Dict[str, Any]:
        """Fetch info for ``issue_name``, with special handling for spurious correlations.

        Every name other than ``"spurious_correlations"`` is delegated to the
        parent class.

        Raises
        ------
        ValueError
            If ``"spurious_correlations"`` is requested but ``self.info`` holds
            no (or empty) data under that key.
        """
        if issue_name != "spurious_correlations":
            return super().get_info(issue_name)

        correlations_info = self.info.get("spurious_correlations", {})
        if not correlations_info:
            raise ValueError(
                "Spurious correlations have not been calculated. Run find_issues() first."
            )
        return correlations_info
class CorrelationVisualizer:
    """Class to visualize images corresponding to the extreme (minimum and maximum) individual scores
    for each of the detected correlated properties.
    """

    def __init__(self) -> None:
        """Store cleanvision's ``VizManager`` class on the instance.

        Raises
        ------
        ImportError
            If ``cleanvision`` (an optional dependency) cannot be imported.
        """
        # Wrapper for VizManager that's from the optional cleanvision dependency
        try:
            from cleanvision.utils.viz_manager import VizManager
        except ImportError as error:
            raise ImportError(
                "cleanvision is required for correlation visualization. Please install it to use this feature."
            ) from error
        self.viz_manager = VizManager

    def visualize(
        self, images: List, title_info: Dict, ncols: int = 2, cell_size: tuple = (2, 2)
    ) -> None:
        """Forward all arguments, as keywords, to ``viz_manager.individual_images``."""
        self.viz_manager.individual_images(
            images=images,
            title_info=title_info,
            ncols=ncols,
            cell_size=cell_size,
        )
class CorrelationReporter:
    """Class to report spurious correlations between image features and class labels detected in the data.

    If no spurious correlations are found, the class will not report anything.
    """

    def __init__(self, data_issues: "DataIssues", imagelab: "Imagelab"):
        """Read the correlation threshold from ``data_issues`` and set up the visualizer.

        Raises
        ------
        ValueError
            If the ``"spurious_correlations"`` info holds no ``"threshold"``.
        """
        self.imagelab: "Imagelab" = imagelab
        self.data_issues = data_issues
        self.threshold = data_issues.get_info("spurious_correlations").get("threshold")
        # Compare against None explicitly: a threshold of 0 is falsy but valid.
        if self.threshold is None:
            raise ValueError(
                "Spurious correlations have not been calculated. Run find_issues() first."
            )
        self.visualizer = CorrelationVisualizer()

    def report(self) -> None:
        """Reports spurious correlations between image features and class labels detected in the data,
        if any are found.
        """
        correlated_properties = self._get_correlated_properties()
        if not correlated_properties:
            return

        self._print_correlation_summary()

        correlations_df = cast(
            pd.DataFrame, self.data_issues.get_info("spurious_correlations").get("correlations_df")
        )
        filtered_correlations_df = self._get_filtered_correlated_properties(
            correlations_df, correlated_properties
        )
        print(filtered_correlations_df.to_string(index=False) + "\n")

        self._visualize_extremes(correlated_properties, self.data_issues)

    def _print_correlation_summary(self) -> None:
        """Print the header that precedes the table of correlated properties."""
        print("\n\n")
        report_correlation_header = "Summary of (potentially spurious) correlations between image properties and class labels detected in the data:\n\n"
        report_correlation_metric = "Lower scores below correspond to images properties that are more strongly correlated with the class labels.\n\n"
        print(report_correlation_header + report_correlation_metric)

    def _visualize_extremes(
        self, correlated_properties: List[str], data_issues: "DataIssues"
    ) -> None:
        """Show, per correlated property, the examples with the lowest and highest score."""
        report_extremal_images = "Here are the images corresponding to the extreme (minimum and maximum) individual scores for each of the detected correlated properties:\n\n"
        print(report_extremal_images)
        issues = data_issues.get_issues()
        # Index labels of the minimum and maximum score of each property column.
        correlated_indices = {
            prop: [issues[prop].idxmin(), issues[prop].idxmax()] for prop in correlated_properties
        }
        self._visualize(correlated_indices, issues)

    def _visualize(self, correlated_indices: Dict[str, List[Any]], issues: pd.DataFrame) -> None:
        """Print a caption and display the selected images for every property."""
        for prop, image_ids in correlated_indices.items():
            print(
                f"Images with minimum and maximum individual scores for {prop.replace('_score', '')} issue:\n"
            )
            title_info = {
                "scores": [f"score: {issues.loc[image_id, prop]:.4f}" for image_id in image_ids]
            }
            self.visualizer.visualize(
                images=[self.imagelab._dataset[image_id] for image_id in image_ids],
                title_info=title_info,
            )

    def _get_correlated_properties(self) -> List[str]:
        """Return the properties whose correlation score is below ``self.threshold``.

        Returns an empty list when no correlations table is stored or it is empty.
        """
        correlations_df = self.data_issues.get_info("spurious_correlations").get("correlations_df")
        if correlations_df is None or correlations_df.empty:
            return []
        below_threshold = correlations_df["score"] < self.threshold
        return correlations_df.loc[below_threshold, "property"].tolist()

    def _get_filtered_correlated_properties(
        self, correlations_df: pd.DataFrame, correlated_properties: List[str]
    ) -> pd.DataFrame:
        """Return the rows for ``correlated_properties`` with ``_score`` stripped from the names.

        ``correlations_df`` itself is left untouched.
        """
        selected = correlations_df["property"].isin(correlated_properties)
        # Work on an explicit copy so the assignment below neither touches the
        # caller's frame nor triggers pandas' SettingWithCopyWarning.
        filtered_correlations_df = correlations_df[selected].copy()
        filtered_correlations_df["property"] = filtered_correlations_df["property"].apply(
            lambda name: name.replace("_score", "")
        )
        return filtered_correlations_df
class ImagelabReporterAdapter(Reporter):
    """Reporter that appends the Imagelab report, and the spurious-correlation
    report when one can be built, to the parent class's report.
    """

    def __init__(
        self,
        data_issues: "DataIssues",
        imagelab: "Imagelab",
        task: Task,
        verbosity: int = 1,
        include_description: bool = True,
        show_summary_score: bool = False,
        show_all_issues: bool = False,
    ):
        super().__init__(
            data_issues=data_issues,
            task=task,
            verbosity=verbosity,
            include_description=include_description,
            show_summary_score=show_summary_score,
            show_all_issues=show_all_issues,
        )
        self.imagelab = imagelab
        self.correlation_reporter: Optional[CorrelationReporter] = None
        # Building the correlation reporter is best-effort: it fails when
        # spurious correlations have not been calculated. Catch Exception
        # rather than everything so KeyboardInterrupt/SystemExit still propagate.
        try:
            self.correlation_reporter = CorrelationReporter(data_issues, imagelab)
        except Exception:
            self.correlation_reporter = None

    def report(self, num_examples: int) -> None:
        """Print the parent report, then the Imagelab report, then spurious correlations."""
        super().report(num_examples)
        self._report_imagelab(num_examples)
        # Only report spurious correlations if they've been calculated & detected
        if self.correlation_reporter is not None:
            self.correlation_reporter.report()

    def _report_imagelab(self, num_examples: int) -> None:
        """Print Imagelab's own report, limited to ``num_examples`` images."""
        print("\n\n")
        self.imagelab.report(
            num_images=num_examples,
            max_prevalence=IMAGELAB_ISSUES_MAX_PREVALENCE,
            print_summary=False,
            verbosity=0,
            show_id=True,
        )
class ImagelabIssueFinderAdapter(IssueFinder):
    """Issue finder that, after the parent's checks, runs the Imagelab image
    checks and the spurious-correlation check.
    """

    def __init__(self, datalab, task, verbosity):
        super().__init__(datalab, task, verbosity)
        self.imagelab = self.datalab._imagelab

    def _get_imagelab_issue_types(self, issue_types, **kwargs):
        """Resolve the issue types (and their parameters) to hand to Imagelab.

        Returns ``DEFAULT_CLEANVISION_ISSUES`` when ``issue_types`` is ``None``,
        and ``None`` when ``issue_types`` has no ``"image_issue_types"`` key.
        Otherwise returns a new dict in which every issue type given with empty
        parameters takes its parameters from ``DEFAULT_CLEANVISION_ISSUES``.
        """
        if issue_types is None:
            return DEFAULT_CLEANVISION_ISSUES

        if "image_issue_types" not in issue_types:
            return None

        return {
            issue_type: params if params else DEFAULT_CLEANVISION_ISSUES[issue_type]
            for issue_type, params in issue_types["image_issue_types"].items()
        }

    def find_issues(
        self,
        *,
        pred_probs: Optional[np.ndarray] = None,
        features: Optional[npt.NDArray] = None,
        knn_graph: Optional[csr_matrix] = None,
        issue_types: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run the parent's checks, then the image checks, then the spurious-correlation check.

        The ``"image_issue_types"`` and ``"spurious_correlations"`` entries of
        ``issue_types`` are handled here and are not forwarded to the parent.
        The spurious-correlation check runs when ``issue_types`` is ``None`` or
        empty, or when it contains the ``"spurious_correlations"`` key.
        """
        issue_types_to_ignore_in_datalab = ["image_issue_types", "spurious_correlations"]
        datalab_issue_types = (
            {k: v for k, v in issue_types.items() if k not in issue_types_to_ignore_in_datalab}
            if issue_types
            else issue_types
        )
        super().find_issues(
            pred_probs=pred_probs,
            features=features,
            knn_graph=knn_graph,
            issue_types=datalab_issue_types,
        )

        self._run_imagelab_checks(self._get_imagelab_issue_types(issue_types))

        # if issue_types is neither 'None' nor empty dictionary (non-trivial) but
        # there is no mention of 'spurious_correlations', we return.
        if issue_types and "spurious_correlations" not in issue_types:
            return

        self._run_spurious_correlation_check(issue_types)

    def _run_imagelab_checks(self, imagelab_issue_types: Optional[Dict[str, Any]]) -> None:
        """Run the Imagelab checks and collect their results into ``datalab.data_issues``.

        Does nothing when ``imagelab_issue_types`` is ``None`` or empty. Any
        error is printed instead of raised.
        """
        if not imagelab_issue_types:
            return

        try:
            if self.verbosity:
                print(f'Finding {", ".join(imagelab_issue_types.keys())} images ...')

            self.imagelab.find_issues(issue_types=imagelab_issue_types, verbose=False)

            self.datalab.data_issues.collect_statistics(self.imagelab)
            self.datalab.data_issues.collect_issues_from_imagelab(
                self.imagelab, list(imagelab_issue_types)
            )
        except Exception as e:
            print(f"Error in checking for image issues: {e}")

    def _run_spurious_correlation_check(self, issue_types: Optional[Dict[str, Any]]) -> None:
        """Compute spurious correlations and store them in ``datalab.data_issues.info``.

        ``issue_types`` must be ``None``/empty or contain the
        ``"spurious_correlations"`` key. The check is skipped when Imagelab has
        no score column for any default CleanVision issue. Nothing is stored
        when the datalab has no labels. Errors during the computation are
        printed instead of raised.
        """
        # Skip only when *none* of the default image property scores exist.
        imagelab_columns = set(self.imagelab.issues.columns)
        if not any(
            f"{issue_name}_score" in imagelab_columns for issue_name in DEFAULT_CLEANVISION_ISSUES
        ):
            print("Skipping spurious correlations check: Image property scores not available.")
            print(
                "To include this check, run find_issues() without parameters to compute all scores."
            )
            return

        # Spurious correlation part must be run
        print("Finding spurious correlation issues in the dataset ...")

        default_params = SPURIOUS_CORRELATION_ISSUE["spurious_correlations"]
        # A non-empty issue_types is guaranteed by the caller to contain the
        # 'spurious_correlations' key; a None value is treated like {}.
        if issue_types:
            params = issue_types["spurious_correlations"] or {}
        else:
            params = default_params

        # If threshold is not explicitly given
        # (e.g. lab.find_issues(issue_types={"spurious_correlations": {}}))
        # we use the default value from SPURIOUS_CORRELATION_ISSUE.
        threshold = params.get("threshold", default_params["threshold"])

        try:
            if self.datalab.has_labels:
                self.datalab.data_issues.info["spurious_correlations"] = (
                    handle_spurious_correlations(
                        imagelab_issues=self.imagelab.issues,
                        labels=self.datalab.labels,
                        threshold=threshold,
                    )
                )
        except Exception as e:
            print(f"Error in checking for spurious correlations: {e}")
def handle_spurious_correlations(
    *,
    imagelab_issues: pd.DataFrame,
    labels: Union[np.ndarray, List[List[int]]],
    threshold: float,
    **_,
) -> Dict[str, Any]:
    """Compute correlations between the image property scores and the labels.

    Parameters
    ----------
    imagelab_issues : pd.DataFrame
        Issues table; only its columns whose name ends with ``_score`` are used.

    labels :
        Labels passed to ``SpuriousCorrelations`` together with the score columns.

    threshold : float
        Returned unchanged in the result; it is not used in the computation here.

    **_ :
        Extra keyword arguments are accepted and ignored.

    Returns
    -------
    dict
        ``"correlations_df"`` holds the result of
        ``SpuriousCorrelations.calculate_correlations()`` and ``"threshold"``
        holds the given threshold.
    """
    score_columns = [
        name for name in imagelab_issues.columns.tolist() if name.endswith("_score")
    ]
    property_scores = imagelab_issues[score_columns]
    calculator = SpuriousCorrelations(data=property_scores, labels=labels)
    correlations_df = calculator.calculate_correlations()
    return dict(correlations_df=correlations_df, threshold=threshold)