# Copyright (C) 2017-2023  Cleanlab Inc.
# This file is part of cleanlab.
#
# cleanlab is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cleanlab is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with cleanlab.  If not, see <https://www.gnu.org/licenses/>.
"""
A cleanlab-compatible PyTorch ConvNet classifier that can be used to find
label issues in image data.
This is a good example to reference for making your own bespoke model compatible with cleanlab.
You must have PyTorch installed: https://pytorch.org/get-started/locally/
"""
from sklearn.base import BaseEstimator
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
# Number of examples in each MNIST split. CNN uses these to size the label
# array built in fit() and as the default test batch size.
MNIST_TRAIN_SIZE = 60000
MNIST_TEST_SIZE = 10000
# get_sklearn_digits_dataset() holds out the last SKLEARN_DIGITS_TEST_SIZE
# examples as the test split. The train size is presumably the remainder of
# the sklearn digits data (1247 + 550 = 1797 examples) -- keep the two in sync.
SKLEARN_DIGITS_TRAIN_SIZE = 1247
SKLEARN_DIGITS_TEST_SIZE = 550
def get_mnist_dataset(loader):  # pragma: no cover
    """Download MNIST (if necessary) and return it as a PyTorch dataset.

    Parameters
    ----------
    loader : str (values: 'train' or 'test')
        ``'train'`` selects the training split. Any other value selects the
        test split; no validation is performed here.

    Returns
    -------
    torchvision.datasets.MNIST
        Dataset rooted at ``../data`` whose images are converted to tensors
        and then normalized with mean 0.1307 and standard deviation 0.3081.
    """
    preprocessing = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    return datasets.MNIST(
        root="../data",
        train=(loader == "train"),
        download=True,
        transform=preprocessing,
    )
def get_sklearn_digits_dataset(loader):
    """Return the sklearn handwritten digits data as a PyTorch dataset.

    The last ``SKLEARN_DIGITS_TEST_SIZE`` examples form the test split and
    every example before them forms the training split. This split is
    hard-coded -- do not change it.

    Parameters
    ----------
    loader : str (values: 'train' or 'test')
        Which split to return.

    Returns
    -------
    torch.utils.data.Dataset
        Dataset yielding ``(image, label)`` pairs. Each 8x8 image is resized
        to 28x28, converted to a tensor and normalized with mean 0.1307 and
        standard deviation 0.3081.

    Raises
    ------
    ValueError
        If ``loader`` is neither ``'train'`` nor ``'test'``.
    """
    # Reject bad input before doing any loading work.
    if loader not in ("train", "test"):  # pragma: no cover
        raise ValueError("loader must be either str 'train' or str 'test'.")

    from torch.utils.data import Dataset
    from sklearn.datasets import load_digits

    class TorchDataset(Dataset):
        """Abstracts a numpy array as a PyTorch dataset."""

        def __init__(self, data, targets, transform=None):
            self.data = torch.from_numpy(data).float()
            self.targets = torch.from_numpy(targets).long()
            self.transform = transform

        def __getitem__(self, index):
            x = self.data[index]
            y = self.targets[index]
            if self.transform is not None:
                x = self.transform(x)
            return x, y

        def __len__(self):
            return len(self.data)

    # Resize the 8x8 digits to 28x28 so they fit the same network as MNIST,
    # and reuse the same normalization constants as get_mnist_dataset().
    transform = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize(28),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )

    # Get sklearn digits dataset; rows are flat 64-vectors, reshape to 8x8.
    X_all, y_all = load_digits(return_X_y=True)
    X_all = X_all.reshape((len(X_all), 8, 8))

    # Only slice out the split that was actually requested.
    if loader == "train":
        split = slice(None, -SKLEARN_DIGITS_TEST_SIZE)
    else:
        split = slice(-SKLEARN_DIGITS_TEST_SIZE, None)
    return TorchDataset(X_all[split], y_all[split], transform=transform)
class SimpleNet(nn.Module):
    """Basic PyTorch CNN for MNIST-like data.

    Two 5x5 convolutions, each followed by 2x2 max-pooling and a ReLU, feed
    two fully connected layers that emit log-probabilities over 10 classes.
    The flatten step hard-codes 320 features per example, which is what a
    single-channel 28x28 input produces (20 channels x 4 x 4).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x, T=1.0):
        """Compute per-class log-probabilities for a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Input images; the layers above expect one channel and a spatial
            size that flattens to 320 features (e.g. ``(N, 1, 28, 28)``).
        T : float, default=1.0
            Currently unused; it has no effect on the output and is kept so
            callers that pass it keep working.

        Returns
        -------
        torch.Tensor
            Log-softmax outputs of shape ``(N, 10)``.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Functional dropout must be told explicitly whether we are training.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
class CNN(BaseEstimator):  # Inherits sklearn classifier
    """Wraps a PyTorch CNN for the MNIST dataset within an sklearn template

    Defines ``.fit()``, ``.predict()``, and ``.predict_proba()`` functions. This
    template enables the PyTorch CNN to flexibly be used within the sklearn
    architecture -- meaning it can be passed into functions like
    cross_val_predict as if it were an sklearn model. The cleanlab library
    requires that all models adhere to this basic sklearn template and thus,
    this class allows a PyTorch CNN to be used for learning with noisy
    labels among other things.

    Parameters
    ----------
    batch_size: int
      Batch size used by fit() and by predict_proba() on the 'train' loader.
    epochs: int
      Number of passes over the training indices in fit().
    log_interval: int
      Print the training loss every ``log_interval`` batches. Set to None to
      not print.
    lr: float
      Learning rate of the SGD optimizer.
    momentum: float
      Momentum of the SGD optimizer.
    no_cuda: bool
      If True, never use CUDA even when it is available.
    seed: int
      Seed passed to ``torch.manual_seed`` (and the CUDA seed when CUDA is used).
    test_batch_size: int, default=None
      Batch size used by predict_proba() on the 'test' loader. Defaults to the
      size of the test set of ``dataset``.
    dataset: {'mnist', 'sklearn-digits'}
    loader: {'train', 'test'}
      Set to 'test' to force fit() and predict_proba() on test_set

    Note
    ----
    Be careful setting the ``loader`` param, it will override every other loader.
    If you set this to 'test', but call .predict(loader = 'train')
    then .predict() will still predict on test!

    Attributes
    ----------
    batch_size: int
    epochs: int
    log_interval: int
    lr: float
    momentum: float
    no_cuda: bool
    seed: int
    test_batch_size: int, default=None
    dataset: {'mnist', 'sklearn-digits'}
    loader: {'train', 'test'}
      Set to 'test' to force fit() and predict_proba() on test_set

    Methods
    -------
    fit
      fits the model to data.
    predict
      get the fitted model's prediction on test data
    predict_proba
      get the fitted model's probability distribution over classes for test data
    """

    def __init__(
        self,
        batch_size=64,
        epochs=6,
        log_interval=50,  # Set to None to not print
        lr=0.01,
        momentum=0.5,
        no_cuda=False,
        seed=1,
        test_batch_size=None,
        dataset="mnist",
        loader=None,
    ):
        self.batch_size = batch_size
        self.epochs = epochs
        self.log_interval = log_interval
        self.lr = lr
        self.momentum = momentum
        self.no_cuda = no_cuda
        self.seed = seed

        self.cuda = not self.no_cuda and torch.cuda.is_available()
        torch.manual_seed(self.seed)
        if self.cuda:  # pragma: no cover
            torch.cuda.manual_seed(self.seed)

        # Instantiate PyTorch model
        self.model = SimpleNet()
        if self.cuda:  # pragma: no cover
            self.model.cuda()

        self.loader_kwargs = {"num_workers": 1, "pin_memory": True} if self.cuda else {}
        self.loader = loader
        self._set_dataset(dataset)
        # Without an explicit value, evaluate the whole test set in one batch.
        if test_batch_size is not None:
            self.test_batch_size = test_batch_size
        else:
            self.test_batch_size = self.test_size

    def _set_dataset(self, dataset):
        """Select the dataset getter and split sizes that match ``dataset``.

        Raises
        ------
        ValueError
            If ``dataset`` is not 'mnist' or 'sklearn-digits'.
        """
        self.dataset = dataset
        if dataset == "mnist":
            # pragma: no cover
            self.get_dataset = get_mnist_dataset
            self.train_size = MNIST_TRAIN_SIZE
            self.test_size = MNIST_TEST_SIZE
        elif dataset == "sklearn-digits":
            self.get_dataset = get_sklearn_digits_dataset
            self.train_size = SKLEARN_DIGITS_TRAIN_SIZE
            self.test_size = SKLEARN_DIGITS_TEST_SIZE
        else:  # pragma: no cover
            raise ValueError("dataset must be 'mnist' or 'sklearn-digits'.")

    # XXX this is a pretty weird sklearn estimator that does data loading
    # internally in `fit`, and it supports multiple datasets and is aware of
    # which dataset it's using; if we weren't doing this, we wouldn't need to
    # override `get_params` / `set_params`
    def get_params(self, deep=True):
        """Return every constructor parameter as a dict.

        ``seed`` and ``loader`` are included so that re-creating the estimator
        from these parameters (as ``sklearn.base.clone`` does) does not
        silently reset them to their defaults. ``deep`` is accepted for
        sklearn compatibility and is not used.
        """
        return {
            "batch_size": self.batch_size,
            "epochs": self.epochs,
            "log_interval": self.log_interval,
            "lr": self.lr,
            "momentum": self.momentum,
            "no_cuda": self.no_cuda,
            "seed": self.seed,
            "test_batch_size": self.test_batch_size,
            "dataset": self.dataset,
            "loader": self.loader,
        }

    def set_params(self, **parameters):  # pragma: no cover
        """Set parameters as attributes and return ``self``.

        ``dataset`` is routed through ``_set_dataset`` so that the dataset
        getter and the split sizes stay consistent with it.
        """
        for parameter, value in parameters.items():
            if parameter != "dataset":
                setattr(self, parameter, value)
        if "dataset" in parameters:
            self._set_dataset(parameters["dataset"])
        return self

    def fit(self, train_idx, train_labels=None, sample_weight=None, loader="train"):
        """Train the model on the examples selected by ``train_idx``.

        This function adheres to sklearn's "fit(X, y)" format for
        compatibility with scikit-learn. All inputs should be numpy arrays,
        not PyTorch Tensors. ``train_idx`` is not X, but instead a list of
        indices for X (and y if ``train_labels`` is None); the data itself is
        loaded internally from the configured dataset.

        Parameters
        ----------
        train_idx : array-like of int
            Indices of the examples to train on.
        train_labels : array-like of int, optional
            Labels for ``train_idx``, in the same order. If None, the labels
            stored in the dataset are used.
        sample_weight : array-like of float, optional
            Per-example weights aligned with ``train_labels``. They are reduced
            to per-class weights by taking the weight of the first example of
            each distinct label (in sorted label order); this presumably
            assumes every class occurs in ``train_labels``.
        loader : {'train', 'test'}, default='train'
            Which split to train on; overridden by ``self.loader`` when set.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If the lengths of the inputs disagree, or ``sample_weight`` is
            given without ``train_labels``.
        """
        if self.loader is not None:
            loader = self.loader
        if train_labels is not None and len(train_idx) != len(train_labels):
            raise ValueError("Check that train_idx and train_labels are the same length.")

        class_weight = None
        if sample_weight is not None:  # pragma: no cover
            if train_labels is None:
                raise ValueError("train_labels must be provided when sample_weight is given.")
            if len(sample_weight) != len(train_labels):
                raise ValueError(
                    "Check that train_labels and sample_weight are the same length."
                )
            # Index of the first occurrence of each distinct label.
            _, first_seen = np.unique(train_labels, return_index=True)
            class_weight = torch.from_numpy(np.asarray(sample_weight)[first_seen]).float()
            if self.cuda:
                class_weight = class_weight.cuda()

        train_dataset = self.get_dataset(loader)

        # Use provided labels if not None o.w. use the dataset's own labels
        if train_labels is not None:
            # Create a full-length label array with (-1)s for labels not
            # in train_idx. We avoid train_data[idx] because train_data may
            # be very large, i.e. ImageNet
            num_examples = self.train_size if loader == "train" else self.test_size
            sparse_labels = np.full(num_examples, -1, dtype=int)
            sparse_labels[train_idx] = train_labels
            train_dataset.targets = sparse_labels

        # The sampler restricts training to train_idx, so the -1 placeholder
        # labels are never drawn.
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            sampler=SubsetRandomSampler(train_idx),
            batch_size=self.batch_size,
            **self.loader_kwargs
        )

        optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum)

        # Train for self.epochs epochs
        for epoch in range(1, self.epochs + 1):
            # Enable dropout and batch norm layers
            self.model.train()
            for batch_idx, (data, target) in enumerate(train_loader):
                if self.cuda:  # pragma: no cover
                    data, target = data.cuda(), target.cuda()
                target = target.long()
                optimizer.zero_grad()
                output = self.model(data)
                loss = F.nll_loss(output, target, class_weight)
                loss.backward()
                optimizer.step()
                if self.log_interval is not None and batch_idx % self.log_interval == 0:
                    print(
                        "TrainEpoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                            epoch,
                            batch_idx * len(data),
                            len(train_idx),
                            100.0 * batch_idx / len(train_loader),
                            loss.item(),
                        ),
                    )
        return self

    def predict(self, idx=None, loader=None):
        """Get predicted labels from trained model."""
        # get the index of the max probability
        probs = self.predict_proba(idx, loader)
        return probs.argmax(axis=1)

    def predict_proba(self, idx=None, loader=None):
        """Return predicted class probabilities as an ``N x K`` numpy array.

        Parameters
        ----------
        idx : array-like of int, optional
            Indices of the examples to predict on, in the order the rows
            should be returned. If None, the whole split is used.
        loader : {'train', 'test'}, optional
            Which split to predict on; overridden by ``self.loader`` when set.
            If still None, 'test' is chosen only when ``idx`` is exactly
            ``0, 1, ..., test_size - 1``; otherwise 'train' is used.
        """
        if self.loader is not None:
            loader = self.loader
        if loader is None:
            is_test_idx = (
                idx is not None
                and len(idx) == self.test_size
                and np.all(np.array(idx) == np.arange(self.test_size))
            )
            loader = "test" if is_test_idx else "train"

        dataset = self.get_dataset(loader)

        # Filter by idx. Always index, even when idx is as long as the whole
        # split: it may be a permutation (or mask) rather than 0..N-1, and the
        # returned rows must follow idx.
        if idx is not None:
            dataset.data = dataset.data[idx]
            dataset.targets = dataset.targets[idx]

        eval_batch_size = self.batch_size if loader == "train" else self.test_batch_size
        data_loader = torch.utils.data.DataLoader(
            dataset=dataset, batch_size=eval_batch_size, **self.loader_kwargs
        )

        # sets model.train(False) inactivating dropout and batch-norm layers
        self.model.eval()

        # Run forward pass on model to compute outputs
        outputs = []
        with torch.no_grad():
            for data, _ in data_loader:
                if self.cuda:  # pragma: no cover
                    data = data.cuda()
                outputs.append(self.model(data))

        # Outputs are log_softmax (log probabilities)
        outputs = torch.cat(outputs, dim=0)

        # Convert to probabilities and return the numpy array of shape N x K
        return np.exp(outputs.cpu().numpy())