# Copyright (C) 2017-2023  Cleanlab Inc.
# This file is part of cleanlab.
#
# cleanlab is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cleanlab is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with cleanlab.  If not, see <https://www.gnu.org/licenses/>.
"""
Wrapper class you can use to make any Keras model compatible with :py:class:`CleanLearning <cleanlab.classification.CleanLearning>` and sklearn.
Use :py:class:`KerasWrapperModel<cleanlab.experimental.keras.KerasWrapperModel>` to wrap existing functional API code for ``keras.Model`` objects,
and :py:class:`KerasWrapperSequential<cleanlab.experimental.keras.KerasWrapperSequential>` to wrap existing ``tf.keras.models.Sequential`` objects.
Most of the instance methods of this class work the same as the ones for the wrapped Keras model,
see the `Keras documentation <https://keras.io/>`_ for details.
This is a good example of making any bespoke neural network compatible with cleanlab.
You must have `Tensorflow 2 installed <https://www.tensorflow.org/install>`_ (only compatible with Python versions >= 3.7).
This wrapper class is only fully compatible with ``tensorflow<2.11``. If you are using ``tensorflow>=2.11``,
please replace your Optimizer class with the legacy Optimizer `here <https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/legacy/Optimizer>`_.
Tips:
* If this class lacks certain functionality, you can alternatively try `scikeras <https://github.com/adriangb/scikeras>`_.
* Unlike scikeras, our `KerasWrapper` classes can operate directly on ``tensorflow.data.Dataset`` objects (like regular Keras models).
* To call ``fit()`` on a tensorflow ``Dataset`` object with a Keras model, the ``Dataset`` should already be batched.
* Check out our example using this class: `huggingface_keras_imdb <https://github.com/cleanlab/examples/blob/master/huggingface_keras_imdb/huggingface_keras_imdb.ipynb>`_
* Our `unit tests <https://github.com/cleanlab/cleanlab/blob/master/tests/test_frameworks.py>`_ also provide basic usage examples.
"""
import tensorflow as tf
import keras  # type: ignore
import numpy as np
from typing import Callable, Optional
class KerasWrapperModel:
    """Takes in a callable function to instantiate a Keras Model (using Keras functional API)
    that is compatible with :py:class:`CleanLearning <cleanlab.classification.CleanLearning>` and sklearn.

    The instance methods of this class work in the same way as those of any ``keras.Model`` object, see the `Keras documentation <https://keras.io/>`_ for details.
    For using Keras sequential instead of functional API, see the :py:class:`KerasWrapperSequential<cleanlab.experimental.keras.KerasWrapperSequential>` class.

    Parameters
    ----------
    model: Callable
        A callable function to construct the Keras Model (using functional API). Pass in the function here, not the constructed model!

        For example::

            def model(num_features, num_classes):
                inputs = tf.keras.Input(shape=(num_features,))
                outputs = tf.keras.layers.Dense(num_classes)(inputs)
                return tf.keras.Model(inputs=inputs, outputs=outputs, name="my_keras_model")

    model_kwargs: dict, default = None
        Dict of optional keyword arguments to pass into ``model()`` when instantiating the ``keras.Model``.
        ``None`` is treated as an empty dict.

    compile_kwargs: dict, default = None
        Dict of optional keyword arguments to pass into ``model.compile()`` for declaring loss, metrics, optimizer, etc.
        ``None`` is replaced by ``{"loss": tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)}``.

    params: dict, default = None
        Dict of keyword arguments forwarded to the wrapped model's ``fit()`` on every call to :meth:`fit`.
        This is the dict that :meth:`set_params` updates. ``None`` is treated as an empty dict.
    """

    def __init__(
        self,
        model: Callable,
        model_kwargs: Optional[dict] = None,
        compile_kwargs: Optional[dict] = None,
        params: Optional[dict] = None,
    ):
        # Default containers are built per instance: a dict literal in the signature would be
        # evaluated once and then shared (together with the loss object) by every wrapper.
        if model_kwargs is None:
            model_kwargs = {}
        if compile_kwargs is None:
            compile_kwargs = {
                "loss": tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
            }
        if params is None:
            params = {}
        self.model = model
        self.model_kwargs = model_kwargs
        self.compile_kwargs = compile_kwargs
        self.params = params
        # The Keras model itself is created lazily, see ``_ensure_net``.
        self.net = None

    def _ensure_net(self):
        """Instantiate and compile the Keras model if that has not happened yet."""
        if self.net is None:
            self.net = self.model(**self.model_kwargs)
            self.net.compile(**self.compile_kwargs)

    def get_params(self, deep=True):
        """Returns the constructor arguments of this wrapper as a dict.

        ``deep`` is accepted so the signature matches sklearn's ``get_params``; it is not used.
        """
        return {
            "model": self.model,
            "model_kwargs": self.model_kwargs,
            "compile_kwargs": self.compile_kwargs,
            "params": self.params,
        }

    def set_params(self, **params):
        """Merges the given keyword arguments into ``self.params`` (the extra arguments
        forwarded to the Keras ``fit()`` call) and returns ``self``."""
        self.params.update(params)
        return self

    def fit(self, X, y=None, **kwargs):
        """Trains a Keras model.

        The model is instantiated and compiled on the first call; later calls reuse it.

        Parameters
        ----------
        X : tf.Dataset or np.array or pd.DataFrame
            If ``X`` is a tensorflow dataset object, it must already contain the labels as is required for standard Keras fit.

        y : np.array or pd.DataFrame, default = None
            If ``X`` is a tensorflow dataset object, you can optionally provide the labels again here as argument `y` to be compatible with sklearn,
            but they are ignored.
            If ``X`` is a numpy array or pandas dataframe, the labels have to be passed in using this argument.

        **kwargs
            Extra keyword arguments for the Keras ``fit()`` call. They take precedence over
            entries of the same name in ``self.params``.

        Returns
        -------
        self
            This wrapper, so that calls can be chained as with sklearn estimators.
        """
        self._ensure_net()

        # TODO: check for generators
        # Datasets and Sequences already carry their labels, so ``y`` is only forwarded otherwise.
        if y is not None and not isinstance(X, (tf.data.Dataset, keras.utils.Sequence)):
            kwargs["y"] = y

        self.net.fit(X, **{**self.params, **kwargs})
        return self

    def predict_proba(self, X, *, apply_softmax=True, **kwargs):
        """Predict class probabilities for all classes using the wrapped Keras model.
        Set extra argument `apply_softmax` to True to indicate your network only outputs logits not probabilities.

        Parameters
        ----------
        X : tf.Dataset or np.array or pd.DataFrame
            Data in the same format as the original ``X`` provided to ``fit()``.

        apply_softmax : bool, default = True
            If True, the output of the network is passed through ``tf.nn.softmax`` along axis 1.
            If False, the output of the network's ``predict()`` is returned unchanged.

        **kwargs
            Extra keyword arguments for the Keras ``predict()`` call.

        Raises
        ------
        ValueError
            If the model has not been instantiated yet (neither ``fit()`` nor ``summary()`` was called).
        """
        if self.net is None:
            raise ValueError("must call fit() before predict()")
        pred_probs = self.net.predict(X, **kwargs)
        if apply_softmax:
            pred_probs = tf.nn.softmax(pred_probs, axis=1)
        return pred_probs

    def predict(self, X, **kwargs):
        """Predict class labels using the wrapped Keras model.

        The label is the ``np.argmax`` along axis 1 of the output of :meth:`predict_proba`;
        all keyword arguments (including ``apply_softmax``) are forwarded to it.

        Parameters
        ----------
        X : tf.Dataset or np.array or pd.DataFrame
            Data in the same format as the original ``X`` provided to ``fit()``.
        """
        pred_probs = self.predict_proba(X, **kwargs)
        return np.argmax(pred_probs, axis=1)

    def summary(self, **kwargs):
        """Returns the summary of the Keras model.

        The model is instantiated and compiled first if that has not happened yet.
        Keyword arguments are forwarded to the Keras ``summary()`` call.
        """
        self._ensure_net()
        return self.net.summary(**kwargs)
class KerasWrapperSequential:
    """Makes any ``tf.keras.models.Sequential`` object compatible with :py:class:`CleanLearning <cleanlab.classification.CleanLearning>` and sklearn.

    `KerasWrapperSequential` is instantiated in the same way as a keras ``Sequential``  object, except for optional extra `compile_kwargs` argument.
    Just instantiate this object in the same way as your ``tf.keras.models.Sequential`` object (rather than passing in an existing ``Sequential`` object).
    The instance methods of this class work in the same way as those of any keras ``Sequential`` object, see the `Keras documentation <https://keras.io/>`_ for details.

    Parameters
    ----------
    layers: list
        A list containing the layers to add to the keras ``Sequential`` model (same as for ``tf.keras.models.Sequential``).

    name: str, default = None
        Name for the Keras model (same as for ``tf.keras.models.Sequential``).

    compile_kwargs: dict, default = None
        Dict of optional keyword arguments to pass into ``model.compile()`` for declaring loss, metrics, optimizer, etc.
        ``None`` is replaced by ``{"loss": tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)}``.

    params: dict, default = None
        Dict of keyword arguments forwarded to the wrapped model's ``fit()`` on every call to :meth:`fit`.
        This is the dict that :meth:`set_params` updates. ``None`` is treated as an empty dict.
    """

    def __init__(
        self,
        layers: Optional[list] = None,
        name: Optional[str] = None,
        compile_kwargs: Optional[dict] = None,
        params: Optional[dict] = None,
    ):
        # The default compile arguments are built per instance: a dict literal in the signature
        # would be evaluated once and then shared (with its loss object) by every wrapper.
        if compile_kwargs is None:
            compile_kwargs = {
                "loss": tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
            }
        if params is None:
            params = {}
        self.layers = layers
        self.name = name
        self.compile_kwargs = compile_kwargs
        self.params = params
        # The Keras model itself is created lazily, see ``_ensure_net``.
        self.net = None

    def _ensure_net(self):
        """Instantiate and compile the ``Sequential`` model if that has not happened yet."""
        if self.net is None:
            self.net = tf.keras.models.Sequential(self.layers, self.name)
            self.net.compile(**self.compile_kwargs)

    def get_params(self, deep=True):
        """Returns the constructor arguments of this wrapper as a dict.

        ``deep`` is accepted so the signature matches sklearn's ``get_params``; it is not used.
        """
        return {
            "layers": self.layers,
            "name": self.name,
            "compile_kwargs": self.compile_kwargs,
            "params": self.params,
        }

    def set_params(self, **params):
        """Merges the given keyword arguments into ``self.params`` (the extra arguments
        forwarded to the Keras ``fit()`` call) and returns ``self``."""
        self.params.update(params)
        return self

    def fit(self, X, y=None, **kwargs):
        """Trains a Sequential Keras model.

        The model is instantiated and compiled on the first call; later calls reuse it.

        Parameters
        ----------
        X : tf.Dataset or np.array or pd.DataFrame
            If ``X`` is a tensorflow dataset object, it must already contain the labels as is required for standard Keras fit.

        y : np.array or pd.DataFrame, default = None
            If ``X`` is a tensorflow dataset object, you can optionally provide the labels again here as argument `y` to be compatible with sklearn,
            but they are ignored.
            If ``X`` is a numpy array or pandas dataframe, the labels have to be passed in using this argument.

        **kwargs
            Extra keyword arguments for the Keras ``fit()`` call. They take precedence over
            entries of the same name in ``self.params``.

        Returns
        -------
        self
            This wrapper, so that calls can be chained as with sklearn estimators.
        """
        self._ensure_net()

        # TODO: check for generators
        # Datasets and Sequences already carry their labels, so ``y`` is only forwarded otherwise.
        if y is not None and not isinstance(X, (tf.data.Dataset, keras.utils.Sequence)):
            kwargs["y"] = y

        self.net.fit(X, **{**self.params, **kwargs})
        return self

    def predict_proba(self, X, *, apply_softmax=True, **kwargs):
        """Predict class probabilities for all classes using the wrapped Keras model.
        Set extra argument `apply_softmax` to True to indicate your network only outputs logits not probabilities.

        Parameters
        ----------
        X : tf.Dataset or np.array or pd.DataFrame
            Data in the same format as the original ``X`` provided to ``fit()``.

        apply_softmax : bool, default = True
            If True, the output of the network is passed through ``tf.nn.softmax`` along axis 1.
            If False, the output of the network's ``predict()`` is returned unchanged.

        **kwargs
            Extra keyword arguments for the Keras ``predict()`` call.

        Raises
        ------
        ValueError
            If the model has not been instantiated yet (neither ``fit()`` nor ``summary()`` was called).
        """
        if self.net is None:
            raise ValueError("must call fit() before predict()")
        pred_probs = self.net.predict(X, **kwargs)
        if apply_softmax:
            pred_probs = tf.nn.softmax(pred_probs, axis=1)
        return pred_probs

    def predict(self, X, **kwargs):
        """Predict class labels using the wrapped Keras model.

        The label is the ``np.argmax`` along axis 1 of the output of :meth:`predict_proba`;
        all keyword arguments (including ``apply_softmax``) are forwarded to it.

        Parameters
        ----------
        X : tf.Dataset or np.array or pd.DataFrame
            Data in the same format as the original ``X`` provided to ``fit()``.
        """
        pred_probs = self.predict_proba(X, **kwargs)
        return np.argmax(pred_probs, axis=1)

    def summary(self, **kwargs):
        """Returns the summary of the Keras model.

        The model is instantiated and compiled first if that has not happened yet.
        Keyword arguments are forwarded to the Keras ``summary()`` call.
        """
        self._ensure_net()
        return self.net.summary(**kwargs)