Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - How to save Tensorflow 2 Object Detection Model including all weights?

I am working on object detection using the TensorFlow 2 Object Detection API in Python, and it works great so far. However, to save the model I use exporter_main_v2.py, which exports a graph (.pb) and a checkpoint (checkpoint, ckpt-0.data, ckpt-0.index). The graph does not include any weights, so I always need the checkpoint to work with the saved model. Is there any way to save all weights into the Protobuf (.pb) file?

Here's what I've tried:

  • Save a frozen model: TF2 obviously no longer supports frozen graphs. export_inference_graph.py, which would freeze the graph including all weights, does not work under TF2.
  • The same goes for freeze_graph.py: it only works with TF1.

1 Answer

You can still use the freezing technique from TF1 in TF2, using the compat.v1 module:

In the following snippet, I assume that you have a pretrained model whose weights were saved in the TF2 SavedModel format with tf.saved_model.save.

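import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    sess = tf.compat.v1.Session()
    with sess.as_default():
        # creating the model/loading it from a TF2 SavedModel directory
        # (if you have a Keras model, you can use
        # `tf.keras.models.load_model` instead).
        model = tf.saved_model.load("/path/to/model")
        # in graph mode, the loaded variables must be initialized in the session
        sess.run(tf.compat.v1.global_variables_initializer())

        # the default signature might be different.
        sign = model.signatures["serving_default"]
        # sign.outputs live in the function's own graph, so call the signature
        # on placeholders to create its output nodes in `graph`
        _, input_specs = sign.structured_input_signature
        inputs = {name: tf.compat.v1.placeholder(spec.dtype, spec.shape, name=name)
                  for name, spec in input_specs.items()}
        # if using keras, just use model.outputs
        tensor_out_names = [t.op.name for t in sign(**inputs).values()]

        # replace every variable feeding the outputs by a constant holding its value
        graphdef = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), tensor_out_names
        )
        # the following is optional, use only if no more training is required
        graphdef = tf.compat.v1.graph_util.remove_training_nodes(graphdef)

# write the frozen GraphDef as a single binary protobuf file
tf.io.write_graph(graphdef, "/path/to", "frozengraph.pb", as_text=False)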
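
A TF2-native variant of the same freezing idea, without sessions, is to turn the variables captured by a concrete function into constants with convert_variables_to_constants_v2. Note that this function lives in the internal tensorflow.python.framework.convert_to_constants module, so it is not a public API and may move between releases. The following is only a minimal sketch under those assumptions: TinyModel, freeze_to_pb and load_frozen are hypothetical names, the model is a toy y = x @ w + b, and for an Object Detection export you would presumably pass model.signatures["serving_default"] (already a concrete function) instead, which is untested here.

```python
import os
import tempfile

import tensorflow as tf
# Internal (non-public) TF module: its location may change between releases.
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)


class TinyModel(tf.Module):
    """Hypothetical stand-in for a trained model: y = x @ w + b."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable([[2.0], [3.0]])
        self.b = tf.Variable([1.0])

    @tf.function(input_signature=[tf.TensorSpec([None, 2], tf.float32, name="x")])
    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b


def freeze_to_pb(model, out_dir, name="frozen_graph.pb"):
    """Inline the variables of model.__call__ as constants and write one .pb file."""
    concrete = model.__call__.get_concrete_function()
    frozen = convert_variables_to_constants_v2(concrete)
    tf.io.write_graph(frozen.graph.as_graph_def(), out_dir, name, as_text=False)
    return os.path.join(out_dir, name), frozen


def load_frozen(pb_path, input_name, output_name):
    """Read a frozen GraphDef back and wrap it as a callable TF2 function."""
    graph_def = tf.compat.v1.GraphDef()
    with open(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    def _import():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped = tf.compat.v1.wrap_function(_import, [])
    # Keep only the subgraph between the given input and output tensors.
    return wrapped.prune(
        wrapped.graph.get_tensor_by_name(input_name),
        wrapped.graph.get_tensor_by_name(output_name),
    )


pb_path, frozen = freeze_to_pb(TinyModel(), tempfile.mkdtemp())
run = load_frozen(pb_path, frozen.inputs[0].name, frozen.outputs[0].name)
print(run(tf.constant([[1.0, 1.0]])).numpy())
```

Reading the .pb back uses tf.compat.v1.wrap_function plus prune, which turns the imported GraphDef into a TF2 callable between the chosen input and output tensors; no checkpoint is involved because the weights are now Const nodes.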

However, I would refrain from doing this except for compatibility with an old tool. The compat module might be deprecated one day, and as far as I can tell, there is little value in having a single file containing both the graph and the weights rather than keeping them separate.
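
To illustrate why keeping them separate is usually fine: a SavedModel directory already bundles the graph (saved_model.pb) and the weights (variables/) as one artifact, and reloading it needs only the directory path, no separate checkpoint. This is a minimal sketch with a hypothetical toy module (TinyModel, y = x @ w + b) and a temporary export directory; the exact file list can vary slightly between TF versions.

```python
import os
import tempfile

import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical stand-in for a trained model: y = x @ w + b."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable([[2.0], [3.0]])
        self.b = tf.Variable([1.0])

    @tf.function(input_signature=[tf.TensorSpec([None, 2], tf.float32, name="x")])
    def __call__(self, x):
        return {"y": tf.matmul(x, self.w) + self.b}


export_dir = os.path.join(tempfile.mkdtemp(), "tiny_model")
model = TinyModel()
# Register the traced function as the "serving_default" signature.
tf.saved_model.save(model, export_dir, signatures={"serving_default": model.__call__})

# List what was written: the graph (saved_model.pb) plus the weights (variables/).
saved_files = sorted(
    os.path.relpath(os.path.join(root, f), export_dir)
    for root, _, files in os.walk(export_dir)
    for f in files
)
print(saved_files)

# Reloading needs only the directory, not a separate checkpoint path.
loaded = tf.saved_model.load(export_dir)
result = loaded.signatures["serving_default"](x=tf.constant([[1.0, 1.0]]))["y"]
print(result.numpy())
```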