Source code for cleanlab.datalab.internal.adapter.imagelab
"""An internal wrapper around the Imagelab class from the CleanVision package to incorporate it into Datalab.This allows low-quality images to be detected alongside other issues in computer vision datasets.The methods/classes in this module are just intended for internal use."""importwarningsfromtypingimportTYPE_CHECKING,Any,Dict,List,Optional,Type,cast,Unionimportnumpyasnpimportnumpy.typingasnptimportpandasaspdfromscipy.sparseimportcsr_matrixfromcleanlab.datalab.internal.adapter.constantsimport(DEFAULT_CLEANVISION_ISSUES,IMAGELAB_ISSUES_MAX_PREVALENCE,SPURIOUS_CORRELATION_ISSUE,)fromcleanlab.datalab.internal.dataimportDatafromcleanlab.datalab.internal.data_issuesimportDataIssues,_InfoStrategyfromcleanlab.datalab.internal.issue_finderimportIssueFinderfromcleanlab.datalab.internal.reportimportReporterfromcleanlab.datalab.internal.taskimportTaskfromcleanlab.datalab.internal.spurious_correlationimportSpuriousCorrelationsifTYPE_CHECKING:# pragma: no coverfromcleanvisionimportImagelabfromdatasets.arrow_datasetimportDataset
def create_imagelab(dataset: "Dataset", image_key: Optional[str]) -> Optional["Imagelab"]:
    """Creates Imagelab instance for running CleanVision checks.

    CleanVision checks are only supported for huggingface datasets as of now.

    Parameters
    ----------
    dataset: datasets.Dataset
        Huggingface dataset used by Imagelab
    image_key: str, optional
        key for image feature in the huggingface dataset. If falsy (``None`` or ``""``),
        no Imagelab is created.

    Returns
    -------
    Optional[Imagelab]
        An Imagelab wrapping ``dataset``, or ``None`` when ``image_key`` is falsy.

    Raises
    ------
    ImportError
        If the optional image dependencies (``cleanvision`` / ``datasets``) cannot be imported.
    ValueError
        If ``dataset`` is not a huggingface ``datasets.Dataset``.
    """
    if not image_key:
        return None

    # Only the optional-dependency imports are guarded. The original import error is
    # chained as the cause so the missing module is still visible in the traceback.
    try:
        from cleanvision import Imagelab
        from datasets.arrow_dataset import Dataset
    except ImportError as err:
        raise ImportError(
            "Cannot import required image packages. Please install them via: `pip install cleanlab[image]` or just install cleanlab with "
            "all optional dependencies via: `pip install cleanlab[all]`"
        ) from err

    if not isinstance(dataset, Dataset):
        raise ValueError(
            "For now, only huggingface datasets are supported for running cleanvision checks inside cleanlab. You can easily convert most datasets to the huggingface dataset format."
        )
    return Imagelab(hf_dataset=dataset, image_key=image_key)
class ImagelabDataIssuesAdapter(DataIssues):
    """
    Class that collects and stores information and statistics on issues found in a dataset.

    Parameters
    ----------
    data :
        The data object for which the issues are being collected.
    strategy :
        Strategy used for processing info dictionaries.

    Attributes
    ----------
    issues : pd.DataFrame
        Stores information about each individual issue found in the data,
        on a per-example basis.
    issue_summary : pd.DataFrame
        Summarizes the overall statistics for each issue type.
    info : dict
        A dictionary that contains information and statistics about the data and each issue type.
    """

    def __init__(self, data: Data, strategy: Type[_InfoStrategy]) -> None:
        super().__init__(data, strategy)

    def _update_issues_imagelab(self, imagelab: "Imagelab", overlapping_issues: List[str]) -> None:
        """Merge Imagelab's per-example columns into ``self.issues``.

        Columns belonging to ``overlapping_issues`` (``is_<type>_issue`` and ``<type>_score``)
        are dropped first (with a warning) so Imagelab's versions replace them.
        Imagelab columns not already present are then outer-joined on the index.
        """
        overwrite_columns = [f"is_{issue_type}_issue" for issue_type in overlapping_issues]
        overwrite_columns.extend(f"{issue_type}_score" for issue_type in overlapping_issues)
        if overwrite_columns:
            warnings.warn(
                f"Overwriting columns {overwrite_columns} in self.issues with "
                f"columns from imagelab."
            )
            self.issues.drop(columns=overwrite_columns, inplace=True)

        # Preserve Imagelab's column order. A set difference would make the resulting
        # column order depend on string hashing and thus vary between interpreter runs.
        existing_columns = set(self.issues.columns)
        new_columns = [col for col in imagelab.issues.columns if col not in existing_columns]
        self.issues = self.issues.join(imagelab.issues[new_columns], how="outer")

    def filter_based_on_max_prevalence(self, issue_summary: pd.DataFrame, max_num: int):
        """Drop issue types that were flagged in more than ``max_num`` images.

        Parameters
        ----------
        issue_summary : pd.DataFrame
            Summary table with (at least) ``issue_type`` and ``num_images`` columns.
        max_num : int
            Largest ``num_images`` value that is kept.

        Returns
        -------
        pd.DataFrame
            A copy of the rows with ``num_images <= max_num``. The input is not modified.
        """
        too_prevalent = issue_summary["num_images"] > max_num
        removed_issues = issue_summary.loc[too_prevalent, "issue_type"].tolist()
        if removed_issues:
            print(
                f"Removing {', '.join(removed_issues)} from potential issues in the dataset as it exceeds max_prevalence={IMAGELAB_ISSUES_MAX_PREVALENCE}"
            )
        # Explicit `<=` (rather than `~too_prevalent`) so rows with missing counts are dropped.
        return issue_summary[issue_summary["num_images"] <= max_num].copy()

    def collect_issues_from_imagelab(self, imagelab: "Imagelab", issue_types: List[str]) -> None:
        """
        Collect results from Imagelab and update datalab.issues and datalab.issue_summary

        Parameters
        ----------
        imagelab: Imagelab
            Imagelab instance that run all the checks for image issue types
        issue_types: List[str]
            Names of the image issue types checked by Imagelab. Any iterable of strings
            is accepted; it is materialized once because it is traversed more than once.
        """
        issue_types = list(issue_types)

        # Deterministic (input-ordered, de-duplicated) list of issue types that already
        # have a row in the summary and will be replaced by Imagelab's results.
        summary_issue_types = set(self.issue_summary["issue_type"])
        overlapping_issues = [
            issue_type for issue_type in dict.fromkeys(issue_types) if issue_type in summary_issue_types
        ]
        self._update_issues_imagelab(imagelab, overlapping_issues)

        if overlapping_issues:
            warnings.warn(
                f"Overwriting {overlapping_issues} rows in self.issue_summary from imagelab."
            )
        self.issue_summary = self.issue_summary[
            ~self.issue_summary["issue_type"].isin(overlapping_issues)
        ]

        # filter_based_on_max_prevalence never mutates its input and returns a copy.
        imagelab_summary = self.filter_based_on_max_prevalence(
            imagelab.issue_summary,
            int(IMAGELAB_ISSUES_MAX_PREVALENCE * len(self.issues)),
        )
        imagelab_summary.rename({"num_images": "num_issues"}, axis=1, inplace=True)
        self.issue_summary = pd.concat(
            [self.issue_summary, imagelab_summary], axis=0, ignore_index=True
        )
        for issue_type in issue_types:
            self._update_issue_info(issue_type, imagelab.info[issue_type])

    def get_info(self, issue_name: Optional[str] = None) -> Dict[str, Any]:
        """Return the info dict for ``issue_name``.

        Extends the parent method for ``"spurious_correlations"``: that entry is read
        from ``self.info`` and a ``ValueError`` is raised if it is missing or empty.
        All other names are delegated to the parent class.
        """
        if issue_name != "spurious_correlations":
            return super().get_info(issue_name)

        correlations_info = self.info.get("spurious_correlations", {})
        if not correlations_info:
            raise ValueError(
                "Spurious correlations have not been calculated. Run find_issues() first."
            )
        return correlations_info
class CorrelationVisualizer:
    """Class to visualize images corresponding to the extreme (minimum and maximum) individual scores
    for each of the detected correlated properties.
    """

    # NOTE(review): CorrelationReporter calls ``self.visualizer.visualize(images=..., title_info=...)``,
    # but no ``visualize`` method is defined on this class in the code under review. Confirm it exists.

    def __init__(self):
        # Wrapper for VizManager that's from the optional cleanvision dependency.
        # The import is deferred so cleanvision is only needed when a visualizer is built.
        try:
            from cleanvision.utils.viz_manager import VizManager
        except ImportError as err:
            raise ImportError(
                "cleanvision is required for correlation visualization. Please install it to use this feature."
            ) from err
        # Stores the VizManager class itself (no instance is created here).
        self.viz_manager = VizManager
class CorrelationReporter:
    """Class to report spurious correlations between image features and class labels detected in the data.

    If no spurious correlations are found, the class will not report anything.
    """

    def __init__(self, data_issues: "DataIssues", imagelab: "Imagelab"):
        """
        Raises
        ------
        ValueError
            If no spurious-correlation threshold is stored in ``data_issues``.
        ImportError
            Propagated from ``CorrelationVisualizer`` when cleanvision is unavailable.
        """
        self.imagelab: "Imagelab" = imagelab
        self.data_issues = data_issues
        self.threshold = data_issues.get_info("spurious_correlations").get("threshold")
        # Compare against None: a threshold of 0.0 is a legitimate (if strict) value.
        if self.threshold is None:
            raise ValueError(
                "Spurious correlations have not been calculated. Run find_issues() first."
            )
        self.visualizer = CorrelationVisualizer()

    def report(self) -> None:
        """Reports spurious correlations between image features and class labels detected in the data,
        if any are found.
        """
        correlated_properties = self._get_correlated_properties()
        if not correlated_properties:
            return

        self._print_correlation_summary()
        correlations_df = cast(
            pd.DataFrame,
            self.data_issues.get_info("spurious_correlations").get("correlations_df"),
        )
        filtered_correlations_df = self._get_filtered_correlated_properties(
            correlations_df, correlated_properties
        )
        print(filtered_correlations_df.to_string(index=False) + "\n")
        self._visualize_extremes(correlated_properties, self.data_issues)

    def _print_correlation_summary(self) -> None:
        """Print the header that precedes the correlation table."""
        print("\n\n")
        report_correlation_header = "Summary of (potentially spurious) correlations between image properties and class labels detected in the data:\n\n"
        report_correlation_metric = "Lower scores below correspond to images properties that are more strongly correlated with the class labels.\n\n"
        print(report_correlation_header + report_correlation_metric)

    def _visualize_extremes(self, correlated_properties: List[str], data_issues: "DataIssues") -> None:
        """Show, per correlated property, the examples with the lowest and highest score."""
        report_extremal_images = "Here are the images corresponding to the extreme (minimum and maximum) individual scores for each of the detected correlated properties:\n\n"
        print(report_extremal_images)
        issues = data_issues.get_issues()
        # Index labels of the min- and max-scoring rows for each property column.
        correlated_indices = {
            prop: [issues[prop].idxmin(), issues[prop].idxmax()] for prop in correlated_properties
        }
        self._visualize(correlated_indices, issues)

    def _visualize(self, correlated_indices: Dict[str, List[Any]], issues: pd.DataFrame) -> None:
        """Print a caption and render the selected images for each property."""
        for prop, image_ids in correlated_indices.items():
            print(
                f"Images with minimum and maximum individual scores for {prop.replace('_score', '')} issue:\n"
            )
            title_info = {
                "scores": [f"score: {issues.loc[image_id, prop]:.4f}" for image_id in image_ids]
            }
            # NOTE(review): labels from ``issues`` are used to index ``imagelab._dataset``;
            # this assumes both share the same (positional) indexing.
            self.visualizer.visualize(
                images=[self.imagelab._dataset[image_id] for image_id in image_ids],
                title_info=title_info,
            )

    def _get_correlated_properties(self) -> List[str]:
        """Return the ``property`` names whose ``score`` is strictly below ``self.threshold``."""
        correlations_df = self.data_issues.get_info("spurious_correlations").get("correlations_df")
        if correlations_df is None or correlations_df.empty:
            return []
        below_threshold = correlations_df["score"] < self.threshold
        return correlations_df.loc[below_threshold, "property"].tolist()

    def _get_filtered_correlated_properties(
        self, correlations_df: pd.DataFrame, correlated_properties: List[str]
    ) -> pd.DataFrame:
        """Return the rows for ``correlated_properties`` with ``_score`` stripped from the names.

        ``correlations_df`` (which lives in ``data_issues.info``) is never modified.
        """
        # Take an explicit copy before assigning. Assigning into a filtered frame triggers
        # pandas' chained-assignment / SettingWithCopyWarning behaviour.
        mask = correlations_df["property"].isin(correlated_properties)
        filtered_correlations_df = correlations_df.loc[mask].copy()
        filtered_correlations_df["property"] = filtered_correlations_df["property"].apply(
            lambda name: name.replace("_score", "")
        )
        return filtered_correlations_df
class ImagelabReporterAdapter(Reporter):
    """Reporter that adds Imagelab output and (when available) spurious-correlation
    results on top of the base ``Reporter`` report."""

    def __init__(
        self,
        data_issues: "DataIssues",
        imagelab: "Imagelab",
        task: Task,
        verbosity: int = 1,
        include_description: bool = True,
        show_summary_score: bool = False,
        show_all_issues: bool = False,
    ):
        super().__init__(
            data_issues=data_issues,
            task=task,
            verbosity=verbosity,
            include_description=include_description,
            show_summary_score=show_summary_score,
            show_all_issues=show_all_issues,
        )
        self.imagelab = imagelab
        self.correlation_reporter: Optional[CorrelationReporter] = None
        try:
            self.correlation_reporter = CorrelationReporter(data_issues, imagelab)
        except Exception:
            # Deliberate best effort: spurious correlations have not been calculated (or the
            # optional visualization dependency is missing), so that section is omitted.
            # `Exception` rather than a bare `except:` so KeyboardInterrupt/SystemExit propagate.
            self.correlation_reporter = None

    def report(self, num_examples: int) -> None:
        """Print the base report, then call ``_report_imagelab`` (defined elsewhere),
        then the spurious-correlation report if a correlation reporter was created."""
        super().report(num_examples)
        self._report_imagelab(num_examples)
        # Only report spurious correlations if they've been calculated & detected
        if self.correlation_reporter is not None:
            self.correlation_reporter.report()
def find_issues(
    self,
    *,
    pred_probs: Optional[np.ndarray] = None,
    features: Optional[npt.NDArray] = None,
    knn_graph: Optional[csr_matrix] = None,
    issue_types: Optional[Dict[str, Any]] = None,
) -> None:
    """Run the parent issue checks, then CleanVision image checks, then the spurious-correlation check.

    Parameters
    ----------
    pred_probs, features, knn_graph :
        Forwarded unchanged to the parent ``find_issues``.
    issue_types :
        Mapping of issue type to keyword arguments. The ``"image_issue_types"`` and
        ``"spurious_correlations"`` keys are handled here and are not forwarded to the parent.
        When ``issue_types`` is ``None``/empty, the spurious-correlation check is attempted with
        the default threshold. Otherwise it only runs if ``"spurious_correlations"`` is present.

    Notes
    -----
    Failures in the image checks and in the spurious-correlation computation are printed
    rather than raised, so the rest of the run is not aborted.
    """
    issue_types_to_ignore_in_datalab = {"image_issue_types", "spurious_correlations"}
    datalab_issue_types = (
        {k: v for k, v in issue_types.items() if k not in issue_types_to_ignore_in_datalab}
        if issue_types
        else issue_types
    )
    super().find_issues(
        pred_probs=pred_probs,
        features=features,
        knn_graph=knn_graph,
        issue_types=datalab_issue_types,
    )

    issue_types_copy = self._get_imagelab_issue_types(issue_types)
    if issue_types_copy:
        try:
            if self.verbosity:
                print(f'Finding {", ".join(issue_types_copy.keys())} images ...')
            self.imagelab.find_issues(issue_types=issue_types_copy, verbose=False)
            self.datalab.data_issues.collect_statistics(self.imagelab)
            self.datalab.data_issues.collect_issues_from_imagelab(
                self.imagelab, list(issue_types_copy.keys())
            )
        except Exception as e:
            print(f"Error in checking for image issues: {e}")

    # If issue_types is neither None nor an empty dict (i.e. non-trivial) but does not
    # mention 'spurious_correlations', that check was not requested.
    if issue_types and "spurious_correlations" not in issue_types:
        return

    # Skip only if *none* of the default CleanVision property scores were computed.
    # If at least one is present, the check proceeds with what is available.
    imagelab_columns = set(self.imagelab.issues.columns)
    if not any(
        default_cleanvision_issue + "_score" in imagelab_columns
        for default_cleanvision_issue in DEFAULT_CLEANVISION_ISSUES.keys()
    ):
        print("Skipping spurious correlations check: Image property scores not available.")
        print("To include this check, run find_issues() without parameters to compute all scores.")
        return

    # Gated on verbosity, consistent with the "Finding ... images" message above.
    if self.verbosity:
        print("Finding spurious correlation issues in the dataset ...")

    # With falsy issue_types, use the defaults. Otherwise the early return above guarantees
    # the 'spurious_correlations' key exists; a None value is treated as "no options".
    if issue_types:
        spurious_correlation_issue_types = issue_types["spurious_correlations"] or {}
    else:
        spurious_correlation_issue_types = SPURIOUS_CORRELATION_ISSUE["spurious_correlations"]

    # If threshold is not explicitly given
    # (e.g. lab.find_issues(issue_types={"spurious_correlations": {}})),
    # fall back to the default value from SPURIOUS_CORRELATION_ISSUE.
    spurious_correlation_issue_threshold = spurious_correlation_issue_types.get(
        "threshold", SPURIOUS_CORRELATION_ISSUE["spurious_correlations"]["threshold"]
    )
    try:
        if self.datalab.has_labels:
            self.datalab.data_issues.info["spurious_correlations"] = handle_spurious_correlations(
                imagelab_issues=self.imagelab.issues,
                labels=self.datalab.labels,
                threshold=spurious_correlation_issue_threshold,
            )
    except Exception as e:
        print(f"Error in checking for spurious correlations: {e}")