This is an experimental feature, and the API may change in the future.
Replicate works with any machine learning framework, but it includes a callback that makes it easier to use with Keras.
`ReplicateCallback` behaves like TensorFlow's `ModelCheckpoint` callback, but in addition to exporting a model at the end of each epoch, it also:

- calls `replicate.init()` at the start of training to create an experiment, and
- calls `experiment.checkpoint()` after saving the model at the end of the epoch. All metrics are saved, along with any data from the `logs` dictionary passed to the callback's `on_epoch_end` method.

Here is a simple example:
```python
import tensorflow as tf
from tensorflow import keras
from replicate.keras_callback import ReplicateCallback

dense_size = 784
learning_rate = 0.01

# from https://www.tensorflow.org/guide/keras/custom_callback
model = keras.Sequential()
model.add(keras.layers.Dense(1, input_dim=dense_size))
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=learning_rate),
    loss="mean_squared_error",
    metrics=["mean_absolute_error"],
)

# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

model.fit(
    x_train[:1000],
    y_train[:1000],
    batch_size=128,
    epochs=20,
    validation_split=0.5,
    callbacks=[
        ReplicateCallback(
            params={"dense_size": dense_size, "learning_rate": learning_rate},
            primary_metric=("mean_absolute_error", "minimize"),
        ),
    ],
)
```
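To make the lifecycle described above easier to follow, here is a framework-free sketch of the two hooks involved. It is not Replicate's actual implementation: `FakeExperiment` and `LifecycleSketchCallback` are hypothetical stand-ins, and the use of `on_train_begin` as the "start of training" hook is an assumption based on the standard Keras callback interface.

```python
# Framework-free sketch of the callback lifecycle. Every class here is a
# hypothetical stand-in, not Replicate's or Keras's real implementation.


class FakeExperiment:
    """Stand-in for the experiment object created at the start of training."""

    def __init__(self, params):
        self.params = params
        self.checkpoints = []

    def checkpoint(self, path, step, metrics):
        # Record what would be stored alongside the exported model.
        self.checkpoints.append(
            {"path": path, "step": step, "metrics": dict(metrics)}
        )


class LifecycleSketchCallback:
    """Mimics the two hooks described in the text."""

    def __init__(self, filepath="model.hdf5", params=None):
        self.filepath = filepath
        self.params = params or {}
        self.experiment = None

    def on_train_begin(self, logs=None):
        # Assumed to correspond to replicate.init() at the start of training.
        self.experiment = FakeExperiment(self.params)

    def on_epoch_end(self, epoch, logs=None):
        # A real callback would export the model to self.filepath first.
        # Here we only record everything found in `logs` as metrics.
        self.experiment.checkpoint(
            path=self.filepath, step=epoch, metrics=logs or {}
        )


# Drive the hooks by hand, the way a training loop would.
callback = LifecycleSketchCallback(params={"learning_rate": 0.01})
callback.on_train_begin()
for epoch, mae in enumerate([0.42, 0.31, 0.27]):
    callback.on_epoch_end(epoch, logs={"mean_absolute_error": mae})

print(len(callback.experiment.checkpoints))
print(callback.experiment.checkpoints[-1])
```

The point of the sketch is only the ordering: one experiment per training run, then one checkpoint per epoch carrying whatever the training loop placed in `logs`.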
The `ReplicateCallback` class takes the following arguments, all optional:
- `filepath`: The path where the exported model is saved. This path is also saved by `experiment.checkpoint()` at the end of each epoch. If it is `None`, the model is not saved, and the callback just gathers code and metrics. Default: `model.hdf5`
- `params`: A dictionary of hyperparameters that will be recorded to the experiment at the start of training.
- `primary_metric`: A tuple in the format `(metric_name, goal)`, where `goal` is either `minimize` or `maximize`. For example, `("mean_absolute_error", "minimize")`.
- `save_freq`: `"epoch"` or an integer. With `"epoch"`, the callback saves the model after each epoch. With an integer, the callback saves the model at the end of that many batches. Default: `"epoch"`
- `save_weights_only`: If `True`, only the model's weights are saved (`model.save_weights(filepath)`); otherwise the full model is saved (`model.save(filepath)`). Default: `False`
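As an illustration of how `save_freq` and `primary_metric` can be read, the sketch below models both with plain Python. The helpers `should_save` and `best_checkpoint` are hypothetical and are not part of Replicate or Keras. Using `primary_metric` to rank checkpoints is an assumption about how the `(metric_name, goal)` tuple might be consumed, not something this page states.

```python
# Hypothetical helpers illustrating the argument semantics above.
# They are not part of Replicate or Keras.


def should_save(save_freq, batches_seen, epoch_ended):
    """Return True when a model export would be due."""
    if save_freq == "epoch":
        return epoch_ended
    # Integer: save at the end of every `save_freq` batches.
    return batches_seen > 0 and batches_seen % save_freq == 0


def best_checkpoint(checkpoints, primary_metric):
    """Pick the checkpoint that is best under (metric_name, goal)."""
    metric_name, goal = primary_metric
    if goal not in ("minimize", "maximize"):
        raise ValueError(f"goal must be 'minimize' or 'maximize', got {goal!r}")
    choose = min if goal == "minimize" else max
    return choose(checkpoints, key=lambda c: c["metrics"][metric_name])


# Made-up metric values, for illustration only.
checkpoints = [
    {"step": 0, "metrics": {"mean_absolute_error": 0.42}},
    {"step": 1, "metrics": {"mean_absolute_error": 0.27}},
    {"step": 2, "metrics": {"mean_absolute_error": 0.31}},
]

best = best_checkpoint(checkpoints, ("mean_absolute_error", "minimize"))
print(best["step"])
print(should_save("epoch", batches_seen=5, epoch_ended=True))
print(should_save(100, batches_seen=200, epoch_ended=False))
```

With `"minimize"`, the checkpoint with the lowest `mean_absolute_error` is selected; switching the goal to `"maximize"` would select the highest instead.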