# Source code for cleanlab.segmentation.filter

# Copyright (C) 2017-2023  Cleanlab Inc.
# This file is part of cleanlab.
#
# cleanlab is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cleanlab is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with cleanlab.  If not, see <https://www.gnu.org/licenses/>.

"""
Methods to find label issues in image semantic segmentation datasets, where each pixel in an image receives its own class label.

"""

from typing import Optional, Tuple

import numpy as np

from cleanlab.experimental.label_issues_batched import LabelInspector
from cleanlab.internal.segmentation_utils import _check_input, _get_valid_optional_params


def find_label_issues(
    labels: np.ndarray,
    pred_probs: np.ndarray,
    *,
    batch_size: Optional[int] = None,
    n_jobs: Optional[int] = None,
    verbose: bool = True,
    **kwargs,
) -> np.ndarray:
    """
    Returns a boolean mask for the entire dataset, per pixel, where ``True`` marks a pixel
    identified as having a label issue and ``False`` marks a pixel that appears correctly labeled.

    * N - Number of images in the dataset
    * K - Number of classes in the dataset
    * H - Height of each image
    * W - Width of each image

    Tip
    ---
    If you encounter the error "pred_probs is not defined", try setting ``n_jobs=1``.

    Parameters
    ----------
    labels:
      A discrete array of shape ``(N,H,W,)`` of noisy labels for a semantic segmentation dataset,
      i.e. some labels may be erroneous.

      *Format requirements*: For a dataset with K classes, each pixel must be labeled using an
      integer in 0, 1, ..., K-1.

      Tip
      ---
      If your labels are one hot encoded you can do: ``labels = np.argmax(labels_one_hot, axis=1)``
      assuming that `labels_one_hot` is of dimension ``(N,K,H,W)``, in order to get properly
      formatted `labels`.

    pred_probs:
      An array of shape ``(N,K,H,W,)`` of model-predicted class probabilities,
      ``P(label=k|x)`` for each pixel ``x``. The prediction for each pixel is an array
      corresponding to the estimated likelihood that this pixel belongs to each of the ``K``
      classes. The 2nd dimension of `pred_probs` must be ordered such that these probabilities
      correspond to class 0, 1, ..., K-1.

    batch_size:
      Optional size of the mini-batches used for computing the label issues in a streaming
      fashion (does not affect results, just the runtime and memory requirements). It is measured
      in array elements: each streaming batch holds ``max(batch_size // (K*H*W), 1)`` whole images,
      where H and W are the (possibly downsampled) image dimensions. It is also the chunk size used
      when mapping flagged pixels back into the output mask. To maximize efficiency, try to use
      the largest `batch_size` your memory allows. If not provided, a good default is used.

    n_jobs:
      Optional number of processes for multiprocessing (default value = 1). Only used on Linux.
      If `n_jobs=None`, will use either the number of: physical cores if psutil is installed,
      or logical cores otherwise.

    verbose:
      Set to ``False`` to suppress all print statements.

    **kwargs:
      * downsample: int,
        Optional factor by which to shrink `labels` and `pred_probs` along both spatial dimensions
        before searching for issues. Default ``1`` (no downsampling). Must be a positive integer
        that evenly divides both ``H`` and ``W``. Each ``downsample x downsample`` block of
        `pred_probs` is averaged and renormalized, and each block of `labels` is averaged and
        rounded. Larger values of `downsample` produce faster runtimes but potentially less
        accurate results due to over-compression. Pixels flagged at the reduced resolution are
        mapped back to full resolution and re-checked against the original `labels` and
        `pred_probs`.

    Returns
    -------
    label_issues: np.ndarray
      Returns a boolean **mask** for the entire dataset of shape ``(N,H,W)`` where ``True``
      represents a pixel label issue and ``False`` represents a pixel that is correctly labeled.

    Raises
    ------
    ValueError
      If `downsample` is not a positive integer, or if it does not evenly divide both ``H`` and
      ``W``.
    """
    batch_size, n_jobs = _get_valid_optional_params(batch_size, n_jobs)
    downsample = kwargs.get("downsample", 1)

    def downsample_arrays(
        labels: np.ndarray, pred_probs: np.ndarray, factor: int = 1
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Shrink both arrays along H and W by averaging non-overlapping ``factor x factor`` blocks."""
        if factor == 1:
            return labels, pred_probs
        num_image, num_classes, h, w = pred_probs.shape
        # Check that the block reshape below is possible.
        if h % factor != 0 or w % factor != 0:
            raise ValueError(
                f"Height {h} and width {w} not divisible by downsample value of {factor}. Set kwarg downsample to 1 to avoid downsampling."
            )
        # NOTE(review): labels are block-averaged as numbers and then rounded, so a block mixing
        # e.g. classes 0 and 2 can become class 1. This is a heuristic, not a majority vote.
        small_labels = np.round(
            labels.reshape((num_image, h // factor, factor, w // factor, factor)).mean((4, 2))
        )
        small_pred_probs = pred_probs.reshape(
            (num_image, num_classes, h // factor, factor, w // factor, factor)
        ).mean((5, 3))
        # Renormalize so each pixel's class probabilities sum to 1 again.
        row_sums = small_pred_probs.sum(axis=1)
        renorm_small_pred_probs = small_pred_probs / np.expand_dims(row_sums, 1)
        return small_labels, renorm_small_pred_probs

    def flatten_and_preprocess_masks(
        labels: np.ndarray, pred_probs: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Flatten to one row per pixel: labels -> (n*H*W,), pred_probs -> (n*H*W, K)."""
        _, num_classes, _, _ = pred_probs.shape
        labels_flat = labels.flatten().astype(int)
        # Moving the class axis first and reshaping keeps the (image, row, col) C-order of
        # `labels.flatten()`, so row r of the transposed result belongs to label r.
        pred_probs_flat = np.moveaxis(pred_probs, 0, 1).reshape(num_classes, -1)
        return labels_flat, pred_probs_flat.T

    _check_input(labels, pred_probs)

    # Reject factors that would otherwise fail later with a ZeroDivisionError (0) or an
    # obscure reshape error (negative / non-integer values).
    if not isinstance(downsample, (int, np.integer)) or downsample < 1:
        raise ValueError(f"downsample must be a positive integer, got {downsample!r}.")

    pre_labels, pre_pred_probs = downsample_arrays(labels, pred_probs, downsample)
    num_image, num_classes, h, w = pre_pred_probs.shape

    # Adapted from find_label_issues_batched(): every (possibly downsampled) pixel is treated as
    # an independent example and streamed through a LabelInspector in two passes, first to
    # update its confident thresholds and then to score label quality.
    lab = LabelInspector(
        num_class=num_classes,
        verbose=verbose,
        n_jobs=n_jobs,
        quality_score_kwargs=None,
        num_issue_kwargs=None,
    )
    n = len(pre_labels)

    # batch_size counts array elements; convert it into a whole number of images per batch.
    image_size = np.prod(pre_pred_probs.shape[1:])
    images_per_batch = max(batch_size // image_size, 1)

    def stream_batches(consume, desc: str) -> None:
        """Feed flattened mini-batches of images to `consume`, showing a progress bar if verbose."""
        pbar = None
        if verbose:
            from tqdm.auto import tqdm

            pbar = tqdm(desc=desc, total=n)
        for start_index in range(0, n, images_per_batch):
            end_index = min(start_index + images_per_batch, n)
            labels_batch, pred_probs_batch = flatten_and_preprocess_masks(
                pre_labels[start_index:end_index], pre_pred_probs[start_index:end_index]
            )
            consume(labels_batch, pred_probs_batch)
            if pbar is not None:
                pbar.update(end_index - start_index)
        if pbar is not None:
            pbar.close()

    stream_batches(
        lab.update_confident_thresholds,
        "number of examples processed for estimating thresholds",
    )
    stream_batches(
        lab.score_label_quality,
        "number of examples processed for checking labels",
    )
    # Decoded below as flat indices into the (num_image, h, w) pixel grid.
    # NOTE(review): assumes LabelInspector indexes examples cumulatively across batches.
    ranked_label_issues = lab.get_label_issues()

    # Flags at the working (possibly downsampled) resolution, while carefully maintaining indices.
    label_issues = np.full((num_image, h, w), False)

    def unflag_correct_predictions(
        image_idx: np.ndarray, rows: np.ndarray, cols: np.ndarray
    ) -> None:
        """Clear flags at full-resolution pixels whose argmax prediction equals the given label."""
        # Advanced indices separated by a slice put the indexed axis first: shape (len(image_idx), K).
        pred_argmax = np.argmax(pred_probs[image_idx, :, rows, cols], axis=1)
        mask = pred_argmax == labels[image_idx, rows, cols]
        label_issues[image_idx[mask], rows[mask], cols[mask]] = False

    # Only call a pixel an error if pred_probs disagrees with its label there.
    for start in range(0, ranked_label_issues.shape[0], batch_size):
        image_idx, row_idx, col_idx = _get_indexes_from_ranked_issues(
            ranked_label_issues[start : start + batch_size], h, w
        )
        label_issues[image_idx, row_idx, col_idx] = True
        if downsample == 1:
            unflag_correct_predictions(image_idx, row_idx, col_idx)

    if downsample != 1:
        # Expand each flagged block back to full resolution, then re-check every pixel inside it
        # against the original (full-resolution) labels and pred_probs.
        label_issues = label_issues.repeat(downsample, axis=1).repeat(downsample, axis=2)
        for start in range(0, ranked_label_issues.shape[0], batch_size):
            image_idx, row_idx, col_idx = _get_indexes_from_ranked_issues(
                ranked_label_issues[start : start + batch_size], h, w
            )
            top = row_idx * downsample
            left = col_idx * downsample
            for di in range(downsample):
                for dj in range(downsample):
                    unflag_correct_predictions(image_idx, top + di, left + dj)

    return label_issues
def _get_indexes_from_ranked_issues(
    ranked_label_issues: np.ndarray, h: int, w: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Decode flat pixel indices into ``(image, row, col)`` coordinate arrays.

    Each entry of `ranked_label_issues` is interpreted as an index into a C-order flattening
    of an ``(N, h, w)`` pixel grid.

    Parameters
    ----------
    ranked_label_issues: np.ndarray
      Integer array of flat pixel indices.
    h: int
      Height of each image in the grid the indices refer to.
    w: int
      Width of each image in the grid the indices refer to.

    Returns
    -------
    image_batch, pixel_coor_i, pixel_coor_j: Tuple[np.ndarray, np.ndarray, np.ndarray]
      Image index, row index and column index of each entry, in input order.
    """
    # Split into (which image, offset within that image) in one vectorized step.
    image_idx, offset_in_image = np.divmod(ranked_label_issues, h * w)
    row_idx, col_idx = np.unravel_index(offset_in_image, (h, w))
    return image_idx, row_idx, col_idx