# Copyright (C) 2017-2022 Cleanlab Inc.
# This file is part of cleanlab.
#
# cleanlab is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cleanlab is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with cleanlab. If not, see <https://www.gnu.org/licenses/>.
"""
Checks to ensure valid inputs for various methods.
"""
import sys
import warnings
from typing import Any, List, Optional, Union

import numpy as np
import pandas as pd

from cleanlab.typing import LabelLike, DatasetLike
def assert_valid_class_labels(
    y: np.ndarray,
    allow_missing_classes: bool = True,
    allow_one_class: bool = False,
) -> None:
    """Checks that ``labels`` is properly formatted, i.e. a 1D numpy array where labels are zero-based
    integers (not multi-label).

    Parameters
    ----------
    y : np.ndarray
        Array of class labels to validate.
    allow_missing_classes : bool, default True
        If False, the unique values in ``y`` must be exactly ``0, 1, ..., K-1``
        where K is the number of unique values, i.e. no class index may be skipped.
    allow_one_class : bool, default False
        If False, ``y`` must contain at least two distinct classes.

    Raises
    ------
    ValueError
        If ``y`` is not 1D, is empty, contains strings, contains non-integer or
        negative values, or has fewer than two classes while ``allow_one_class``
        is False.
    TypeError
        If ``allow_missing_classes`` is False and some class index in
        ``0, 1, ..., K-1`` is absent from ``y``.
    """
    if y.ndim != 1:
        raise ValueError("Labels must be 1D numpy array.")
    # Reject empty input explicitly; otherwise the min() check below fails with
    # an unrelated "empty sequence" error.
    if y.size == 0:
        raise ValueError("Labels cannot be empty.")
    if any(isinstance(label, str) for label in y):
        raise ValueError(
            "Labels cannot be strings, they must be zero-indexed integers corresponding to class indices."
        )
    if not np.equal(np.mod(y, 1), 0).all():  # check that labels are integers
        raise ValueError("Labels must be zero-indexed integers corresponding to class indices.")
    # 0 is a valid class index, so the requirement is non-negativity.
    if y.min() < 0:
        raise ValueError("Labels must be non-negative integers corresponding to class indices.")
    unique_classes = np.unique(y)
    if (not allow_one_class) and (len(unique_classes) < 2):
        raise ValueError("Labels must contain at least 2 classes.")
    if not allow_missing_classes:
        # np.unique returns sorted values, so they equal 0..K-1 only if no class is skipped.
        if (unique_classes != np.arange(len(unique_classes))).any():
            msg = "cleanlab requires zero-indexed integer labels (0,1,2,..,K-1), but in "
            msg += "your case: np.unique(labels) = {}. ".format(str(unique_classes))
            msg += "Every class in (0,1,2,..,K-1) must be present in labels as well."
            raise TypeError(msg)
def assert_indexing_works(
    X: DatasetLike, idx: Optional[List[int]] = None, length_X: Optional[int] = None
) -> None:
    """Ensures we can do list-based indexing into ``X`` and ``y``.

    ``length_X`` is an optional argument since sparse matrix ``X``
    does not support: ``len(X)`` and we want this method to work for sparse ``X``
    (in addition to many other types of ``X``).

    Parameters
    ----------
    X : DatasetLike
        Data features to check.
    idx : list of int, optional
        Indices used for the trial indexing. Defaults to ``[0, length_X - 1]``.
    length_X : int, optional
        Number of examples in ``X``; only used to build the default ``idx``.
        Treated as 2 when not provided.

    Raises
    ------
    TypeError
        If ``X`` does not support list-based indexing. The error raised by the
        failed indexing attempt is chained as ``__cause__``.
    """
    if idx is None:
        if length_X is None:
            length_X = 2  # pragma: no cover
        idx = [0, length_X - 1]

    if isinstance(X, (pd.DataFrame, pd.Series)):
        try:
            _ = X.iloc[idx]  # type: ignore[call-overload]
            return
        except Exception:
            pass  # fall through to the remaining checks

    # An instance of a torch / tensorflow Dataset can only exist if that library
    # has already been imported, so look the modules up in sys.modules instead of
    # importing them: this keeps torch/tensorflow optional and avoids paying their
    # (multi-second) import cost when X is e.g. a numpy array or sparse matrix.
    torch_data = sys.modules.get("torch.utils.data")
    if torch_data is not None:
        try:
            if isinstance(X, torch_data.Dataset):  # indexing for pytorch Dataset
                _ = torch_data.Subset(X, idx)  # type: ignore[call-overload]
                return
        except Exception:
            pass

    tf = sys.modules.get("tensorflow")
    if tf is not None:
        try:
            if isinstance(X, tf.data.Dataset):
                return  # skip check for tensorflow Dataset (too expensive)
        except Exception:
            pass

    try:
        _ = X[idx]  # type: ignore[call-overload]
    except Exception as err:
        msg = (
            "Data features X must support list-based indexing; i.e. one of these must work: \n"
        )
        msg += "1) X[index_list] where say index_list = [0,1,3,10], or \n"
        msg += "2) X.iloc[index_list] if X is pandas DataFrame."
        raise TypeError(msg) from err
def labels_to_array(y: Union[LabelLike, np.generic]) -> np.ndarray:
    """Converts different types of label objects to 1D numpy array and checks their validity.

    Parameters
    ----------
    y : Union[LabelLike, np.generic]
        Labels to convert to 1D numpy array. Can be a list, numpy array, pandas Series, or pandas DataFrame.

    Returns
    -------
    labels_array : np.ndarray
        Numpy array of labels. pandas inputs are guaranteed to yield a 1D array;
        for list / array-like inputs the result of ``np.asarray`` is returned
        as-is and its dimensionality is not checked here.

    Raises
    ------
    ValueError
        If ``y`` is a DataFrame that does not have exactly one column, or if
        ``y`` cannot be converted with ``np.asarray``.
    """
    if isinstance(y, pd.Series):
        y_series: np.ndarray = y.to_numpy()
        return y_series
    if isinstance(y, pd.DataFrame):
        y_arr = y.values
        assert isinstance(y_arr, np.ndarray)  # narrows the type for static checkers
        if y_arr.shape[1] != 1:
            raise ValueError("labels must be one dimensional.")
        return y_arr.flatten()
    # y is a list, np.ndarray, or some other tuple-like object.
    # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still
    # propagate, and chain the conversion error for easier debugging.
    try:
        return np.asarray(y)
    except Exception as err:
        raise ValueError(
            "List of labels must be convertible to 1D numpy array via: np.asarray(labels)."
        ) from err