keras
Wrapper class you can use to make any Keras model compatible with CleanLearning
and sklearn.
Use KerasWrapperModel to wrap existing functional API code for keras.Model objects, and KerasWrapperSequential to wrap existing tf.keras.models.Sequential objects.
Most of the instance methods of these classes work the same as those of the wrapped Keras model; see the Keras documentation for details.
This is a good example of making any bespoke neural network compatible with cleanlab.
You must have TensorFlow 2 installed (only compatible with Python versions >= 3.7).
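The core idea behind these wrappers can be illustrated without TensorFlow. The sketch below is a hypothetical adapter, not cleanlab's implementation: BespokeNet stands in for any model with a non-sklearn interface (here assumed to be train and forward), and BespokeWrapper exposes the methods listed for the wrapper classes (fit, predict_proba, predict, get_params, set_params). The toy "network" only learns smoothed class frequencies and emits their logarithms as logits, so the example runs with the standard library alone.

```python
import math


class BespokeNet:
    """Hypothetical model with a non-sklearn API (train/forward)."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.counts = [0] * num_classes

    def train(self, features, labels):
        # "Training" here just counts how often each label occurs.
        for label in labels:
            self.counts[label] += 1

    def forward(self, features):
        # Emits the same logits (log of smoothed class frequencies) for every row.
        total = sum(self.counts)
        logits = [math.log((c + 1) / (total + self.num_classes)) for c in self.counts]
        return [list(logits) for _ in features]


class BespokeWrapper:
    """Hypothetical adapter exposing sklearn-style methods around BespokeNet."""

    def __init__(self, num_classes=2):
        self.num_classes = num_classes
        self.net = None

    def get_params(self, deep=True):
        return {"num_classes": self.num_classes}

    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        return self

    def fit(self, X, y=None):
        self.net = BespokeNet(self.num_classes)  # fresh, untrained model on every fit
        self.net.train(X, y)
        return self

    def predict_proba(self, X):
        # Convert the network's logits to probabilities with a softmax.
        probs = []
        for row in self.net.forward(X):
            exps = [math.exp(v) for v in row]
            total = sum(exps)
            probs.append([e / total for e in exps])
        return probs

    def predict(self, X):
        # Predicted label = index of the largest probability in each row.
        return [max(range(len(p)), key=p.__getitem__) for p in self.predict_proba(X)]


X = [[0.1], [0.4], [0.7], [0.9]]
y = [1, 1, 1, 0]
clf = BespokeWrapper(num_classes=2).fit(X, y)
print(clf.predict(X))  # [1, 1, 1, 1]
```

Building a fresh BespokeNet inside fit() is meant to echo the wrappers' choice of taking a constructor rather than a trained model, so repeated fits (for example on different cross-validation folds) never share state.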
Tips:
- If this class lacks certain functionality, you can alternatively try scikeras.
- Unlike scikeras, our KerasWrapper classes can operate directly on tensorflow.data.Dataset objects (like regular Keras models).
- To call fit() on a TensorFlow Dataset object with a Keras model, the Dataset should already be batched.
- Check out our example using this class: huggingface_keras_imdb. Our unit tests also provide basic usage examples.
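To see what "already batched" means: each element of an unbatched dataset is a single (features, label) example, while Keras fit() on a Dataset expects each element to be a whole batch. In TensorFlow this is typically done with something like tf.data.Dataset.from_tensor_slices((X, y)).batch(32). The pure-Python sketch below reproduces that chunking so the resulting structure is easy to inspect; batch_pairs is a hypothetical helper, not part of cleanlab or TensorFlow.

```python
from typing import Iterator, List, Sequence, Tuple


def batch_pairs(
    features: Sequence, labels: Sequence, batch_size: int
) -> Iterator[Tuple[List, List]]:
    """Yield (features, labels) chunks, mimicking the elements of a batched dataset."""
    if len(features) != len(labels):
        raise ValueError("features and labels must have the same length")
    for start in range(0, len(features), batch_size):
        stop = start + batch_size
        # The final batch may be smaller than batch_size.
        yield list(features[start:stop]), list(labels[start:stop])


batches = list(batch_pairs([[1], [2], [3], [4], [5]], [0, 1, 0, 1, 1], batch_size=2))
print(len(batches))  # 3
print(batches[-1])   # ([[5]], [1])
```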
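Classes:
| Class | Description |
|---|---|
| KerasWrapperModel | Takes in a callable function to instantiate a Keras Model (using Keras functional API) that is compatible with CleanLearning and sklearn. |
| KerasWrapperSequential | Makes any tf.keras.models.Sequential object compatible with CleanLearning and sklearn. |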
- class cleanlab.models.keras.KerasWrapperModel(model, model_kwargs={}, compile_kwargs={'loss': <keras.losses.SparseCategoricalCrossentropy object>}, params=None)
Bases: object
Takes in a callable function to instantiate a Keras Model (using Keras functional API) that is compatible with CleanLearning and sklearn.
The instance methods of this class work in the same way as those of any keras.Model object; see the Keras documentation for details. To use the Keras Sequential API instead of the functional API, see the KerasWrapperSequential class.
Parameters:
- model (Callable) – A callable function that constructs the Keras Model (using the functional API). Pass in the function here, not the constructed model!
For example:

```python
import tensorflow as tf


def model(num_features, num_classes):
    # A single dense layer that outputs one logit per class.
    inputs = tf.keras.Input(shape=(num_features,))
    outputs = tf.keras.layers.Dense(num_classes)(inputs)
    return tf.keras.Model(inputs=inputs, outputs=outputs, name="my_keras_model")
```
- model_kwargs (dict, default = {}) – Dict of optional keyword arguments to pass into model() when instantiating the keras.Model.
- compile_kwargs (dict, default = {"loss": tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)}) – Dict of optional keyword arguments to pass into model.compile() for declaring loss, metrics, optimizer, etc.
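The parameter descriptions imply a simple construction flow: call model(**model_kwargs) to build the network, then call .compile(**compile_kwargs) on it. The sketch below walks through that flow with hypothetical stand-ins (FakeModel, build_model, instantiate) so it runs without TensorFlow; the order of operations is an assumption drawn from the parameter descriptions, not cleanlab's code. It also shows one plausible reason the docs ask for the function rather than a constructed model: every call produces a new, untrained instance.

```python
class FakeModel:
    """Hypothetical stand-in for keras.Model that records how it was compiled."""

    def __init__(self, num_features, num_classes):
        self.num_features = num_features
        self.num_classes = num_classes
        self.compile_kwargs = None

    def compile(self, **kwargs):
        self.compile_kwargs = kwargs


def build_model(num_features, num_classes):
    # Plays the role of the `model` callable: returns a new object on each call.
    return FakeModel(num_features, num_classes)


def instantiate(model_fn, model_kwargs, compile_kwargs):
    # Assumed flow: model(**model_kwargs), then model.compile(**compile_kwargs).
    net = model_fn(**model_kwargs)
    net.compile(**compile_kwargs)
    return net


model_kwargs = {"num_features": 4, "num_classes": 3}
compile_kwargs = {"loss": "sparse_categorical_crossentropy", "optimizer": "adam"}
first = instantiate(build_model, model_kwargs, compile_kwargs)
second = instantiate(build_model, model_kwargs, compile_kwargs)
print(first is second)       # False: each call builds a fresh model
print(first.compile_kwargs)  # {'loss': 'sparse_categorical_crossentropy', 'optimizer': 'adam'}
```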
Methods:
| Method | Description |
|---|---|
| get_params([deep]) | Returns the parameters of the Keras model. |
| set_params(**params) | Sets the parameters of the Keras model. |
| fit(X[, y]) | Trains a Keras model. |
| predict_proba(X, *[, apply_softmax]) | Predicts class probabilities for all classes using the wrapped Keras model. |
| predict(X, **kwargs) | Predicts class labels using the wrapped Keras model. |
| summary(**kwargs) | Returns the summary of the Keras model. |
- fit(X, y=None, **kwargs)
Trains a Keras model.
Parameters:
- X (tf.Dataset or np.array or pd.DataFrame) – If X is a TensorFlow dataset object, it must already contain the labels, as required by standard Keras fit().
- y (np.array or pd.DataFrame, default = None) – If X is a TensorFlow dataset object, you can optionally provide the labels again here as argument y for compatibility with sklearn, but they are ignored. If X is a numpy array or pandas DataFrame, the labels must be passed in using this argument.
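The rules for X and y above amount to a small dispatch on the type of X. This sketch encodes them with a hypothetical FakeDataset class standing in for a batched tf.data.Dataset and a hypothetical resolve_fit_inputs helper; the real wrapper's type checks and error handling may differ.

```python
from typing import Any, Optional, Tuple


class FakeDataset:
    """Hypothetical stand-in for a batched tf.data.Dataset of (features, labels)."""

    def __init__(self, batches):
        self.batches = batches


def resolve_fit_inputs(X: Any, y: Optional[Any] = None) -> Tuple[Any, Optional[Any]]:
    """Apply the documented X/y rules for fit() (illustrative only)."""
    if isinstance(X, FakeDataset):
        # Labels already live inside the dataset, so any y passed here is ignored.
        return X, None
    if y is None:
        # Arrays and DataFrames carry no labels, so y is mandatory.
        raise ValueError("y is required when X is an array or DataFrame")
    return X, y


ds = FakeDataset([([[0.1], [0.2]], [0, 1])])
print(resolve_fit_inputs(ds, y=[0, 1])[1])  # None: y is ignored for datasets
```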
- predict_proba(X, *, apply_softmax=True, **kwargs)
Predict class probabilities for all classes using the wrapped Keras model. Set the extra argument apply_softmax to True (the default) to indicate that your network outputs only logits, not probabilities.
Parameters:
- X (tf.Dataset or np.array or pd.DataFrame) – Data in the same format as the original X provided to fit().
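The docstring does not spell out how apply_softmax converts logits; softmax is the standard transform, and it fits the default loss, SparseCategoricalCrossentropy(from_logits=True), which expects a network that outputs logits. The NumPy sketch below assumes that transform and also assumes predict() returns the argmax of the probabilities; both are assumptions for illustration, not cleanlab's code.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract each row's max for numerical stability; the result is unchanged.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=1, keepdims=True)


logits = np.array([[2.0, 0.0, -1.0], [0.5, 0.5, 3.0]])
probs = softmax(logits)
print(probs.sum(axis=1))     # [1. 1.]
print(probs.argmax(axis=1))  # [0 2]
```

Softmax is not idempotent: applying it to outputs that are already probabilities flattens them, which is presumably why apply_softmax should be False when the network itself ends in a softmax layer.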
- class cleanlab.models.keras.KerasWrapperSequential(layers=None, name=None, compile_kwargs={'loss': <keras.losses.SparseCategoricalCrossentropy object>}, params=None)
Bases: object
Makes any tf.keras.models.Sequential object compatible with CleanLearning and sklearn.
KerasWrapperSequential is instantiated in the same way as a Keras Sequential object, except for the optional extra compile_kwargs argument. Instantiate this object exactly as you would your tf.keras.models.Sequential object (rather than passing in an existing Sequential object). The instance methods of this class work in the same way as those of any Keras Sequential object; see the Keras documentation for details.
Parameters:
- layers (list) – A list containing the layers to add to the Keras Sequential model (same as for tf.keras.models.Sequential).
- name (str, default = None) – Name for the Keras model (same as for tf.keras.models.Sequential).
- compile_kwargs (dict, default = {"loss": tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)}) – Dict of optional keyword arguments to pass into model.compile() for declaring loss, metrics, optimizer, etc.
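Because KerasWrapperSequential takes the same arguments as the Sequential constructor, sklearn-style tooling can, in principle, recreate an equivalent unfitted wrapper from those arguments. The sketch below illustrates that round trip with a hypothetical SequentialStyleWrapper whose get_params() returns its constructor arguments (the sklearn convention) and a simplified clone_like_sklearn helper that approximates sklearn.base.clone. Layers are represented by placeholder strings, and the exact contents of the real get_params() output are not documented here.

```python
import copy


class SequentialStyleWrapper:
    """Hypothetical wrapper taking the same arguments as a Sequential constructor."""

    def __init__(self, layers=None, name=None, compile_kwargs=None):
        self.layers = layers
        self.name = name
        self.compile_kwargs = compile_kwargs if compile_kwargs is not None else {}

    def get_params(self, deep=True):
        # sklearn convention: return the constructor arguments.
        return {"layers": self.layers, "name": self.name, "compile_kwargs": self.compile_kwargs}

    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        return self


def clone_like_sklearn(estimator):
    # Simplified take on sklearn.base.clone: rebuild from deep-copied constructor args.
    params = copy.deepcopy(estimator.get_params(deep=False))
    return type(estimator)(**params)


original = SequentialStyleWrapper(
    layers=["dense_64", "dense_3"], name="clf", compile_kwargs={"optimizer": "adam"}
)
cloned = clone_like_sklearn(original)
print(cloned is original, cloned.get_params() == original.get_params())  # False True
```

The sketch uses None rather than a dict literal as the default for compile_kwargs so that instances never share one mutable default.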
Methods:
| Method | Description |
|---|---|
| get_params([deep]) | Returns the parameters of the Keras model. |
| set_params(**params) | Sets the parameters of the Keras model. |
| fit(X[, y]) | Trains a Sequential Keras model. |
| predict_proba(X, *[, apply_softmax]) | Predicts class probabilities for all classes using the wrapped Keras model. |
| predict(X, **kwargs) | Predicts class labels using the wrapped Keras model. |
| summary(**kwargs) | Returns the summary of the Keras model. |
- fit(X, y=None, **kwargs)
Trains a Sequential Keras model.
Parameters:
- X (tf.Dataset or np.array or pd.DataFrame) – If X is a TensorFlow dataset object, it must already contain the labels, as required by standard Keras fit().
- y (np.array or pd.DataFrame, default = None) – If X is a TensorFlow dataset object, you can optionally provide the labels again here as argument y for compatibility with sklearn, but they are ignored. If X is a numpy array or pandas DataFrame, the labels must be passed in using this argument.
- predict_proba(X, *, apply_softmax=True, **kwargs)
Predict class probabilities for all classes using the wrapped Keras model. Set the extra argument apply_softmax to True (the default) to indicate that your network outputs only logits, not probabilities.
Parameters:
- X (tf.Dataset or np.array or pd.DataFrame) – Data in the same format as the original X provided to fit().