# Copyright (C) 2017-2023  Cleanlab Inc.
# This file is part of cleanlab.
#
# cleanlab is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cleanlab is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with cleanlab.  If not, see <https://www.gnu.org/licenses/>.
"""
Checks to ensure valid inputs for various methods.
"""
from cleanlab.typing import LabelLike, DatasetLike
from cleanlab.internal.constants import FLOATING_POINT_COMPARISON
from typing import Any, List, Optional, Union
import warnings
import numpy as np
import pandas as pd
def assert_valid_class_labels(
    y: np.ndarray,
    allow_missing_classes: bool = True,
    allow_one_class: bool = False,
) -> None:
    """Checks that ``y`` is properly formatted, i.e. a 1D numpy array where labels are zero-based
    integers (not multi-label).

    Parameters
    ----------
    y : np.ndarray
        Labels to validate.
    allow_missing_classes : bool
        If ``False``, the distinct values in ``y`` must be exactly ``0, 1, ..., K-1``
        where ``K`` is the number of distinct values.
    allow_one_class : bool
        If ``False``, ``y`` must contain at least 2 distinct values.

    Raises
    ------
    ValueError
        If ``y`` is not 1D, contains strings, contains non-integer or negative values,
        is empty, or (unless ``allow_one_class``) contains fewer than 2 classes.
    TypeError
        If ``allow_missing_classes`` is ``False`` and some class in ``0, ..., K-1`` is absent.
    """
    if y.ndim != 1:
        raise ValueError("Labels must be 1D numpy array.")
    # Elements of numeric arrays are never ``str``, so the element-wise Python scan is only
    # needed for other dtypes (e.g. object or unicode arrays).
    if y.dtype.kind not in "biufc" and any(isinstance(label, str) for label in y):
        raise ValueError(
            "Labels cannot be strings, they must be zero-indexed integers corresponding to class indices."
        )
    if not np.equal(np.mod(y, 1), 0).all():  # check that labels are integers
        raise ValueError("Labels must be zero-indexed integers corresponding to class indices.")
    # Taking the minimum of an empty array fails with an unhelpful message; report it explicitly.
    if y.size == 0:
        raise ValueError("Labels must not be empty.")
    if y.min() < 0:
        raise ValueError("Labels must be positive integers corresponding to class indices.")
    unique_classes = np.unique(y)
    if (not allow_one_class) and (len(unique_classes) < 2):
        raise ValueError("Labels must contain at least 2 classes.")
    if not allow_missing_classes:
        # np.unique returns sorted values, so they must match 0..K-1 position by position.
        if (unique_classes != np.arange(len(unique_classes))).any():
            msg = "cleanlab requires zero-indexed integer labels (0,1,2,..,K-1), but in "
            msg += f"your case: np.unique(labels) = {unique_classes!s}. "
            msg += "Every class in (0,1,2,..,K-1) must be present in labels as well."
            raise TypeError(msg)
def assert_indexing_works(
    X: DatasetLike, idx: Optional[List[int]] = None, length_X: Optional[int] = None
) -> None:
    """Ensures we can do list-based indexing into ``X`` and ``y``.
    ``length_X`` is an optional argument since sparse matrix ``X``
    does not support: ``len(X)`` and we want this method to work for sparse ``X``
    (in addition to many other types of ``X``).

    Parameters
    ----------
    X : DatasetLike
        Dataset object to check.
    idx : Optional[List[int]]
        Indices to try. If ``None``, ``[0, length_X - 1]`` is used.
    length_X : Optional[int]
        Number of examples in ``X``; only used when ``idx`` is ``None`` (defaults to 2).

    Raises
    ------
    TypeError
        If ``X`` does not support list-based indexing. pandas objects are checked
        via ``X.iloc[idx]``, pytorch Datasets via ``torch.utils.data.Subset``, and
        everything else via ``X[idx]``. tensorflow Datasets are accepted without
        being indexed.
    """
    import sys  # function-scope import: only needed for the optional-framework lookups below

    if idx is None:
        if length_X is None:
            length_X = 2  # pragma: no cover
        idx = [0, length_X - 1]

    msg = "Data features X must support list-based indexing; i.e. one of these must work: \n"
    msg += "1)  X[index_list] where say index_list = [0,1,3,10], or \n"
    msg += "2)  X.iloc[index_list] if X is pandas DataFrame."

    if isinstance(X, (pd.DataFrame, pd.Series)):
        # For pandas objects only positional ``.iloc`` indexing counts. ``X[idx]`` selects
        # columns / labels instead, so falling back to it could pass by accident.
        try:
            _ = X.iloc[idx]  # type: ignore[call-overload]
        except Exception as err:
            raise TypeError(msg) from err
        return

    # X can only be an instance of a torch / tensorflow class if that package has already
    # been imported, so look it up in sys.modules rather than paying for a fresh import.
    torch = sys.modules.get("torch")
    if torch is not None:
        try:
            if isinstance(X, torch.utils.data.Dataset):  # indexing for pytorch Dataset
                _ = torch.utils.data.Subset(X, idx)  # type: ignore[call-overload]
                return
        except Exception:
            pass  # best-effort check of an optional dependency; fall through to generic indexing

    tf = sys.modules.get("tensorflow")
    if tf is not None:
        try:
            if isinstance(X, tf.data.Dataset):
                return  # skip check for tensorflow Dataset (too expensive)
        except Exception:
            pass  # best-effort check of an optional dependency; fall through to generic indexing

    try:
        _ = X[idx]  # type: ignore[call-overload]
    except Exception as err:
        raise TypeError(msg) from err
def labels_to_array(y: Union[LabelLike, np.generic]) -> np.ndarray:
    """Converts different types of label objects to a numpy array.

    Parameters
    ----------
    y : Union[LabelLike, np.generic]
        Labels to convert. Can be a list, numpy array, pandas Series, or pandas DataFrame.

    Returns
    -------
    labels_array : np.ndarray
        Array of labels. The result is 1D for a pandas Series or a single-column
        DataFrame; any other input is passed through ``np.asarray`` unchanged in
        shape, so its dimensionality is not enforced here.

    Raises
    ------
    ValueError
        If ``y`` is a DataFrame with more than one column, or if ``np.asarray(y)`` fails.
    """
    if isinstance(y, pd.Series):
        y_series: np.ndarray = y.to_numpy()
        return y_series
    if isinstance(y, pd.DataFrame):
        y_arr = y.values
        assert isinstance(y_arr, np.ndarray)  # narrows the type for static checkers
        if y_arr.shape[1] != 1:
            raise ValueError("labels must be one dimensional.")
        return y_arr.flatten()
    # y is list, np.ndarray, or some other tuple-like object
    try:
        return np.asarray(y)
    except Exception as err:
        # Catch Exception rather than everything, so KeyboardInterrupt / SystemExit still propagate.
        raise ValueError(
            "List of labels must be convertable to 1D numpy array via: np.asarray(labels)."
        ) from err