Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share

tensorflow - Save keras model as .h5

I want to save my trained Keras model as an .h5 file. This should be straightforward. Short example:

#%%
import tensorflow as tf
import numpy as np
from tensorflow.keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt

print('TF version: ',tf.__version__)

#%%
#########################
# BATCH SIZE
BATCH_SIZE=100
########################

# create training data
X_train_set = np.random.random(size=(10000,10))
y_train_set = np.random.random(size=(10000))

# create validation data
X_val_set = np.random.random(size=(100,10))
y_val_set = np.random.random(size=(100))

# convert np.array to dataset
train_dataset = tf.data.Dataset.from_tensor_slices((X_train_set, y_train_set))
val_dataset = tf.data.Dataset.from_tensor_slices((X_val_set, y_val_set))

# batching
train_dataset = train_dataset.batch(BATCH_SIZE)
val_dataset = val_dataset.batch(BATCH_SIZE)

# set up the model
my_model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(10,)),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(1)
])

#%%
# custom optimizer with learning rate
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-2,
    decay_steps=10000,
    decay_rate=0.9)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)


# compile the model
my_model.compile(optimizer=optimizer, loss='mse')

# define a checkpoint
checkpoint = ModelCheckpoint('./tf.keras_test',
                             monitor='val_loss',
                             verbose=1,
                             save_best_only=True,
                             mode='min',
                             save_freq='epoch')

callbacks = [checkpoint]

#%%
# train with datasets
history = my_model.fit(train_dataset,
                       validation_data=val_dataset,
                       #validation_steps=100,
                       #callbacks=callbacks,
                       epochs=10)

# save as .h5
my_model.save('my_model.h5',save_format='h5')

However, my_model.save gives me a TypeError:

Traceback (most recent call last):
  File "/home/max/.local/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3343, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-11-a369340a62e1>", line 1, in <module>
    my_model.save('my_model.h5',save_format='h5')
  File "/home/max/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py", line 975, in save
    signatures, options)
  File "/home/max/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/save.py", line 112, in save_model
    model, filepath, overwrite, include_optimizer)
  File "/home/max/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py", line 109, in save_model_to_hdf5
    save_weights_to_hdf5_group(model_weights_group, model_layers)
  File "/home/max/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py", line 631, in save_weights_to_hdf5_group
    param_dset = g.create_dataset(name, val.shape, dtype=val.dtype)
  File "/usr/local/lib/python3.6/dist-packages/h5py/_hl/group.py", line 143, in create_dataset
    if '/' in name:
TypeError: a bytes-like object is required, not 'str'

I'm not sure what the problem is. Is it a TF2 issue? I never had problems saving as .h5 with TF 1.x, and I can still save the model as a .pb graph. However, I'd like to have it as .h5.

1 Answer

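This seems to be a bug in the h5py library. create_dataset should accept the dataset name as either bytes or a unicode str, but in h5py 3.0.0 it fails when the name is bytes, which is what TensorFlow passes here: the failing check if '/' in name: tests a str against a bytes name. It should be fixed in the next release.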
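
The failing line from the traceback can be reproduced in plain Python, without TensorFlow or h5py: testing whether a str is contained in a bytes object raises exactly this TypeError. This is a sketch; the helper name reproduce and the weight name b"dense/kernel:0" are made-up examples, and the assumption that TensorFlow passes the name as bytes is inferred from the error message rather than from the h5py source.

```python
def reproduce(name):
    """Run the same membership test as the failing h5py line, `if '/' in name:`.

    Returns the TypeError message if the test fails, or None if it succeeds.
    `reproduce` is a hypothetical helper for illustration only.
    """
    try:
        _ = "/" in name  # str needle; fails if `name` is bytes
    except TypeError as exc:
        return str(exc)
    return None


# Hypothetical encoded weight name, similar in shape to what TensorFlow passes.
print(reproduce(b"dense/kernel:0"))  # bytes name -> TypeError message
print(reproduce("dense/kernel:0"))   # str name -> no error, prints None
```

Comparing like with like (b"/" in a bytes name, or "/" in a decoded str) avoids the error, which is why accepting either type in create_dataset would fix it.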

As a workaround, downgrade h5py in your local installation, for example with pip install "h5py<3.0.0". The problem was introduced in version 3.0.0, so earlier versions should work.
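
If you want a script to check which side of that boundary the installed h5py is on before calling model.save, a small version comparison is enough. This is a sketch, assuming only what this answer states (versions below 3.0.0 are unaffected); the helper name predates_h5py_3 is hypothetical, and it deliberately does not guess which later release contains the fix.

```python
def predates_h5py_3(version):
    """Return True if an h5py version string is older than 3.0.0.

    Only the leading numeric part of the first three dot-separated fields
    is compared, so "3.0.0rc1" is treated as 3.0.0. `predates_h5py_3` is a
    hypothetical helper, not part of h5py.
    """
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    # Pad short versions such as "2.9" to three fields.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts) < (3, 0, 0)


# Usage with an installed h5py (not needed to run this snippet):
#   import h5py
#   if not predates_h5py_3(h5py.__version__):
#       print("h5py", h5py.__version__, "may hit this TypeError; "
#             "try: pip install 'h5py<3.0.0'")
print(predates_h5py_3("2.10.0"))
print(predates_h5py_3("3.0.0"))
```

Note that a False result only means "3.0.0 or newer", not "affected"; whether a later release still has the bug depends on when the fix ships.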
