# Copyright (C) 2017-2023  Cleanlab Inc.
# This file is part of cleanlab.
#
# cleanlab is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cleanlab is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with cleanlab.  If not, see <https://www.gnu.org/licenses/>.
"""Classes and methods for datasets that are loaded into Datalab."""
import os
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast, TYPE_CHECKING
try:
    import datasets
except ImportError as error:
    raise ImportError(
        "Cannot import datasets package. "
        "Please install it and try again, or just install cleanlab with "
        "all optional dependencies via: `pip install 'cleanlab[all]'`"
    ) from error
import numpy as np
import pandas as pd
from datasets.arrow_dataset import Dataset
from datasets import ClassLabel
from cleanlab.internal.validation import labels_to_array
if TYPE_CHECKING:  # pragma: no cover
    DatasetLike = Union[Dataset, pd.DataFrame, Dict[str, Any], List[Dict[str, Any]], str]
class DatasetDictError(ValueError):
    """Exception raised when a DatasetDict is passed to Datalab.

    Usually, this means that a dataset identifier was passed to Datalab, but
    the dataset is a DatasetDict, which contains multiple splits of the dataset.
    The exception takes no arguments; its message tells the user to select a
    single split before passing the dataset to Datalab.
    """

    def __init__(self):
        message = (
            "Please pass a single dataset, not a DatasetDict. "
            "Try specifying a split, e.g. `dataset = load_dataset('dataset', split='train')` "
            "then pass `dataset` to Datalab."
        )
        super().__init__(message)
class DatasetLoadError(ValueError):
    """Exception raised when a dataset cannot be loaded.

    Parameters
    ----------
    dataset_type: type
        The type of dataset that failed to load. It is interpolated into the
        error message.
    """

    def __init__(self, dataset_type: type):
        message = f"Failed to load dataset from {dataset_type}.\n"
        super().__init__(message)
class DataFormatError(ValueError):
    """Exception raised when the data is not in a format supported by Datalab.

    Parameters
    ----------
    data :
        The object that could not be interpreted as a dataset. It is kept on
        the exception as the ``data`` attribute.
    """

    def __init__(self, data: Any):
        self.data = data
        message = (
            f"Unsupported data type: {type(data)}\n"
            "Supported types: datasets.Dataset, pandas.DataFrame, dict, list, str"
        )
        super().__init__(message)


class Data:
    """
    Class that holds and validates datasets for Datalab.

    Internally, the data is stored as a datasets.Dataset object and the labels
    are integers (ranging from 0 to K-1, where K is the number of classes) stored
    in a numpy array.

    Parameters
    ----------
    data :
        Dataset to be audited by Datalab.
        Several formats are supported, which will internally be converted to a Dataset object.

        Supported formats:
            - datasets.Dataset
            - pandas.DataFrame
            - dict
                - keys are strings
                - values are arrays or lists of equal length
            - list
                - list of dictionaries with the same keys
            - str
                - path to a local file
                    - Text (.txt)
                    - CSV (.csv)
                    - JSON (.json)
                - or a dataset identifier on the Hugging Face Hub
            It checks if the string is a path to a file that exists locally, and if not,
            it assumes it is a dataset identifier on the Hugging Face Hub.

    label_name : str, optional
        Name of the label column in the dataset. If ``None``, no labels are
        extracted and ``has_labels`` is ``False``.

    Raises
    ------
    DatasetDictError
        If a ``DatasetDict`` (multiple splits) is passed or loaded instead of a
        single dataset.
    DataFormatError
        If ``data`` is not one of the supported types.
    DatasetLoadError
        If a supported input cannot be converted to a ``Dataset``.

    Warnings
    --------
    Optional dependencies:

    - datasets :
        Dataset, DatasetDict and load_dataset are imported from datasets.
        This is an optional dependency of cleanlab, but is required for
        :py:class:`Datalab <cleanlab.datalab.datalab.Datalab>` to work.
    """

    def __init__(self, data: "DatasetLike", label_name: Optional[str] = None) -> None:
        self._validate_data(data)
        self._data = self._load_data(data)
        # Cached once; used by both __eq__ and __hash__ of this wrapper.
        self._data_hash = hash(self._data)
        self.labels = Label(data=self._data, label_name=label_name)

    def _load_data(self, data: "DatasetLike") -> Dataset:
        """Checks the type of dataset and returns it converted to a Dataset
        using the loader that matches that type.

        Raises
        ------
        DataFormatError
            If ``data`` is not an instance of any supported type.
        """
        dataset_factory_map: Dict[type, Callable[..., Dataset]] = {
            Dataset: lambda x: x,
            pd.DataFrame: Dataset.from_pandas,
            dict: self._load_dataset_from_dict,
            list: self._load_dataset_from_list,
            str: self._load_dataset_from_string,
        }
        # Dispatch with isinstance rather than an exact `type(data)` lookup so
        # that subclasses of the supported types (e.g. collections.OrderedDict)
        # reach a loader instead of failing with a bare KeyError.
        for supported_type, factory in dataset_factory_map.items():
            if isinstance(data, supported_type):
                return factory(data)
        raise DataFormatError(data)

    def __len__(self) -> int:
        """Number of rows in the underlying dataset."""
        return len(self._data)

    def __eq__(self, other) -> bool:
        """Two Data objects are equal when their dataset hashes and labels match."""
        if not isinstance(other, Data):
            return False
        hashes_are_equal = self._data_hash == other._data_hash
        labels_are_equal = self.labels == other.labels
        return hashes_are_equal and labels_are_equal

    def __hash__(self) -> int:
        return self._data_hash

    @property
    def class_names(self) -> List[str]:
        """Class names reported by the label container."""
        return self.labels.class_names

    @property
    def has_labels(self) -> bool:
        """Check if labels are available."""
        return self.labels.is_available

    @staticmethod
    def _validate_data(data) -> None:
        """Reject DatasetDict inputs and any unsupported input type."""
        # DatasetDict must be checked first so it gets its dedicated error.
        if isinstance(data, datasets.DatasetDict):
            raise DatasetDictError()
        if not isinstance(data, (Dataset, pd.DataFrame, dict, list, str)):
            raise DataFormatError(data)

    @staticmethod
    def _load_dataset_from_dict(data_dict: Dict[str, Any]) -> Dataset:
        """Build a Dataset from a column-name -> values mapping."""
        try:
            return Dataset.from_dict(data_dict)
        except Exception as error:
            raise DatasetLoadError(dict) from error

    @staticmethod
    def _load_dataset_from_list(data_list: List[Dict[str, Any]]) -> Dataset:
        """Build a Dataset from a list of row dictionaries."""
        try:
            return Dataset.from_list(data_list)
        except Exception as error:
            raise DatasetLoadError(list) from error

    @staticmethod
    def _load_dataset_from_string(data_string: str) -> Dataset:
        """Load a Dataset from a local file path or a dataset identifier.

        If ``data_string`` is not an existing local path it is passed to
        ``datasets.load_dataset``. Otherwise the loader is picked from the
        file extension (.txt, .csv, .json; matched case-insensitively).

        Raises
        ------
        DatasetLoadError
            If ``load_dataset`` fails or the file extension is unsupported.
        DatasetDictError
            If ``load_dataset`` returns a DatasetDict rather than a single dataset.
        """
        if not os.path.exists(data_string):
            try:
                dataset = datasets.load_dataset(data_string)
            except Exception as error:
                raise DatasetLoadError(str) from error
            # `cast` is a no-op at runtime, so a DatasetDict would otherwise be
            # returned as if it were a Dataset. Raise the dedicated error instead.
            if isinstance(dataset, datasets.DatasetDict):
                raise DatasetDictError()
            return cast(Dataset, dataset)

        factory: Dict[str, Callable[[str], Any]] = {
            ".txt": Dataset.from_text,
            ".csv": Dataset.from_csv,
            ".json": Dataset.from_json,
        }
        extension = os.path.splitext(data_string)[1].lower()
        if extension not in factory:
            raise DatasetLoadError(type(data_string))

        dataset = factory[extension](data_string)
        return cast(Dataset, dataset)
class Label:
    """
    Class to represent labels in a dataset.

    Parameters
    ----------
    data :
        Dataset object that holds the label column.
    label_name : str, optional
        Name of the label column in ``data``. If ``None``, ``labels`` stays an
        empty array, ``label_map`` stays empty and ``is_available`` is ``False``.

    Raises
    ------
    ValueError
        If ``label_name`` is given but is not a column of ``data``.
    """

    def __init__(self, *, data: Dataset, label_name: Optional[str] = None) -> None:
        self._data = data
        self.label_name = label_name
        self.labels = labels_to_array([])
        self.label_map: Mapping[str, Any] = {}
        if label_name is not None:
            # Validate before extracting: extraction indexes the column, so a
            # missing column would otherwise fail there and the descriptive
            # error in _validate_labels would never be raised.
            self._validate_labels()
            self.labels, self.label_map = _extract_labels(data, label_name)

    def __len__(self) -> int:
        if self.labels is None:
            return 0
        return len(self.labels)

    def __eq__(self, __value: object) -> bool:
        """Labels are equal when arrays, column names and label maps all match."""
        if isinstance(__value, Label):
            labels_are_equal = np.array_equal(self.labels, __value.labels)
            names_are_equal = self.label_name == __value.label_name
            maps_are_equal = self.label_map == __value.label_map
            return all([labels_are_equal, names_are_equal, maps_are_equal])
        return False

    def __getitem__(self, __index: Union[int, slice, np.ndarray]) -> np.ndarray:
        return self.labels[__index]

    def __bool__(self) -> bool:
        return self.is_available

    @property
    def class_names(self) -> List[str]:
        """A list of class names that are present in the dataset.

        Without labels, this will return an empty list.
        """
        return list(self.label_map.values())

    @property
    def is_available(self) -> bool:
        """Check if labels are available."""
        empty_labels = self.labels is None or len(self.labels) == 0
        empty_label_map = self.label_map is None or len(self.label_map) == 0
        return not (empty_labels or empty_label_map)

    def _validate_labels(self) -> None:
        """Check that the label column exists and has one entry per row.

        Raises
        ------
        ValueError
            If ``self.label_name`` is not a column of the dataset.
        """
        if self.label_name not in self._data.column_names:
            raise ValueError(f"Label column '{self.label_name}' not found in dataset.")
        labels = self._data[self.label_name]
        # Sanity checks on what the dataset returns for a column lookup.
        assert isinstance(labels, (np.ndarray, list))
        assert len(labels) == len(self._data)
def _extract_labels(data: Dataset, label_name: str) -> Tuple[np.ndarray, Mapping]:
    """
    Picks out labels from the dataset and formats them to be [0, 1, ..., K-1]
    where K is the number of classes. Also returns a mapping from the formatted
    labels to the original labels in the dataset.

    Note: This function is not meant to be used directly. It is used by
    the ``Label`` class to extract the formatted labels from the dataset
    and store them as attributes.

    Parameters
    ----------
    data : Dataset
        Dataset that contains the label column.

    label_name : str
        Name of the column in the dataset that contains the labels.

    Returns
    -------
    formatted_labels : np.ndarray
        Labels in the format [0, 1, ..., K-1] where K is the number of classes.

    inverse_map : dict
        Mapping from the formatted labels to the original labels in the dataset.

    Raises
    ------
    ValueError
        If the label column is not one-dimensional.
    """
    labels = labels_to_array(data[label_name])  # type: ignore[assignment]
    if labels.ndim != 1:
        raise ValueError("labels must be 1D numpy array.")

    label_name_feature = data.features[label_name]
    if isinstance(label_name_feature, ClassLabel):
        # The column is used unchanged; the map is built from the feature's
        # own name <-> integer mapping.
        label_map = {label: label_name_feature.str2int(label) for label in label_name_feature.names}
        formatted_labels = labels
        inverse_map = {i: label for label, i in label_map.items()}
    else:
        # np.unique returns the sorted distinct labels and, for every entry,
        # the index of its label within that sorted array. Unlike
        # np.vectorize(dict.get), this also works for an empty label column.
        unique_labels, formatted_labels = np.unique(labels, return_inverse=True)
        inverse_map = dict(enumerate(unique_labels))

    return formatted_labels, inverse_map