
python - Keras ModelCheckpoint not saving but EarlyStopping is working fine with the same monitor argument

I've built a model and I'm using a custom callback for validation. The problem is that my custom validation callback saves the validation accuracy in the logs dict, but Keras ModelCheckpoint somehow can't see it, while EarlyStopping with the same monitor argument works fine.

Here's the code for the validation class:

import numpy as np
from sklearn.metrics import accuracy_score
from tensorflow import keras
from tensorflow.keras import backend as K

class ValidateModel(keras.callbacks.Callback):
    
    def __init__(self, validation_data, loss_fnc):
        super().__init__()
        self.validation_data = validation_data
        self.loss_fnc = loss_fnc
    
    def on_epoch_end(self, epoch, logs={}):
        
        th = 0.5
        
        features = self.validation_data[0]
        y_true = self.validation_data[1].reshape((-1,1))     
        
        y_pred = np.asarray(self.model.predict(features)).reshape((-1,1))
        
        #Computing the validation loss.
        y_true_tensor = K.constant(y_true)
        y_pred_tensor = K.constant(y_pred)
        
        val_loss = K.eval(self.loss_fnc(y_true_tensor, y_pred_tensor))
        
        #Rounding the predicted values based on the threshold value.
        #Values lesser than th are rounded to 0, while values greater than th are rounded to 1.
        y_pred_rounded = y_pred / th
        y_pred_rounded = np.clip(np.floor(y_pred_rounded).astype(int),0,1)
        y_pred_rounded_tensor = K.constant(y_pred_rounded)
        
        val_acc = accuracy_score(y_true, y_pred_rounded)
        
        logs['val_loss'] = val_loss
        logs['val_acc'] = val_acc
        
        print(f'\nval_loss: {val_loss} - val_acc: {val_acc}')

And here's the function that I use to train the model:

def train_generator_model(model):
    steps = int(train_df.shape[0] / TRAIN_BATCH_SIZE)

    cb_validation = ValidateModel([validation_X, validation_y], iou)
    cb_early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_acc', 
                                                     patience=3, 
                                                     mode='max', 
                                                     verbose = 1)
    cb_model_checkpoint = tf.keras.callbacks.ModelCheckpoint('/kaggle/working/best_generator_model.hdf5',
                                                             monitor='val_acc',
                                                             save_best_only=True,
                                                             mode='max',
                                                             verbose=1)

    history = model.fit(
        x = train_datagen, 
        epochs = 2, ##Setting to 2 to test.
        callbacks = [cb_validation, cb_model_checkpoint, cb_early_stop], 
        verbose = 1,
        steps_per_epoch = steps)
    
    #model = tf.keras.models.load_model('/kaggle/working/best_generator_model.hdf5', custom_objects = {'iou':iou})
    #model.load_weights('/kaggle/working/best_generator_model.hdf5')
    
    return history

If I set the ModelCheckpoint parameter save_best_only to False, the model is saved perfectly. When training is over and I inspect history.history, I can see that val_loss and val_acc are being logged, as follows:

{'loss': [0.13096405565738678, 0.11926634609699249], 'binary_accuracy': [0.9692355990409851, 0.9716895818710327], 'val_loss': [0.23041087, 0.18325138], 'val_acc': [0.9453247578938803, 0.956172612508138]}

I'm using TensorFlow 2.3.1 and importing keras from tensorflow.

Any help is appreciated. Thank you!

question from: https://stackoverflow.com/questions/65891168/keras-modelcheckpoint-not-saving-but-earlystopping-is-working-fine-with-the-same


1 Answer


I've checked the TensorFlow code and found an incompatibility between TensorFlow and Keras. Inside the tensorflow.keras.callbacks file, there's the following code:

from keras.utils import tf_utils

The problem is that there's no tf_utils in keras.utils (at least not in Keras 2.4.3, which I was using). Strangely, no exception was thrown.
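
If you want to check whether your environment has the same mismatch, printing both versions is enough (a small sketch; it assumes the standalone keras package is installed alongside TensorFlow, which is what produces the conflict described above):

import tensorflow as tf
print(tf.__version__)          # e.g. 2.3.1 in my case

try:
    import keras               # the standalone Keras package, not tf.keras
    print(keras.__version__)   # e.g. 2.4.3
except ImportError:
    print('standalone keras is not installed')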

Fix #1: Add the following code to your program:

import tensorflow as tf

class ModelCheckpoint_tweaked(tf.keras.callbacks.ModelCheckpoint):
    def __init__(self,
                 filepath,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 save_weights_only=False,
                 mode='auto',
                 save_freq='epoch',
                 options=None,
                 **kwargs):

        # Change the tf_utils source package before the parent class initializes.
        from tensorflow.python.keras.utils import tf_utils

        super(ModelCheckpoint_tweaked, self).__init__(filepath,
                                                      monitor,
                                                      verbose,
                                                      save_best_only,
                                                      save_weights_only,
                                                      mode,
                                                      save_freq,
                                                      options,
                                                      **kwargs)

And then use this new class as the ModelCheckpoint callback:

cb_model_checkpoint = ModelCheckpoint_tweaked(file_name,
                                              monitor='val_acc',
                                              save_best_only=True,
                                              mode='max',
                                              verbose=1)

Fix #2:

Update TensorFlow to version 2.4.0. If you are using a custom callback to compute the monitored metric, add the following line to the custom callback's __init__() function:

self._supports_tf_logs = True

If you don't add this line, the logs won't be persisted between the callbacks.
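
Applied to the ValidateModel callback from the question, that looks roughly like this (only __init__() changes; on_epoch_end() stays exactly as posted):

from tensorflow import keras

class ValidateModel(keras.callbacks.Callback):

    def __init__(self, validation_data, loss_fnc):
        super().__init__()
        self.validation_data = validation_data
        self.loss_fnc = loss_fnc
        # Tell TF 2.4+ that this callback works on the original logs object,
        # so the val_loss/val_acc it adds remain visible to ModelCheckpoint
        # and EarlyStopping.
        self._supports_tf_logs = True

    # on_epoch_end(...) unchanged from the question.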

