Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


protocol buffers - Is there an example of how to generate protobuf files holding trained TensorFlow graphs?

I am looking at Google's example of how to deploy and use a pre-trained TensorFlow graph (model) on Android. This example uses a .pb file at:

https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip

which is a link to a file that downloads automatically.

The example shows how to load the .pb file into a TensorFlow session and use it to perform classification, but it doesn't seem to mention how to generate such a .pb file after a graph has been trained (e.g., in Python).

Are there any examples on how to do that?



1 Answer


EDIT: The freeze_graph.py script, which is part of the TensorFlow repository, now serves as a tool that generates a protocol buffer representing a "frozen" trained model from an existing TensorFlow GraphDef and a saved checkpoint. It uses the same steps as described below, but it is much easier to use.
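
For a rough idea of what freezing looks like in code, here is a minimal sketch. It does not call freeze_graph.py itself; it uses graph_util.convert_variables_to_constants(), which (as far as I can tell) is the helper the script relies on internally. Assumptions that are mine and not part of the answer: a TensorFlow build that exposes tf.compat.v1, a toy model with hypothetical node names "input", "w" and "output", and variables that are simply initialized rather than restored from a checkpoint.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode (TF 1.x style) API; assumed to be available

graph = tf.Graph()
with graph.as_default():
    # Toy model: output = input @ w. All node names are illustrative.
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="input")
    w = tf1.Variable([[1.0], [2.0]], name="w", use_resource=False)
    y = tf.matmul(x, w, name="output")

    with tf1.Session() as sess:
        # In real use the values would come from training or from
        # tf1.train.Saver().restore(sess, checkpoint_path).
        sess.run(tf1.global_variables_initializer())

        # Replace every variable by a constant holding its current value and
        # keep only the nodes needed to compute "output".
        frozen_graph_def = tf1.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), ["output"])

# Serialize the frozen GraphDef to a binary .pb file.
pb_path = os.path.join(tempfile.mkdtemp(), "frozen_model.pb")
with open(pb_path, "wb") as f:
    f.write(frozen_graph_def.SerializeToString())
```

The resulting file can then be loaded with tf.import_graph_def() in the same way as the inception5h .pb from the question.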


Currently the process isn't very well documented (and is subject to refinement), but the approximate steps are as follows:

  1. Build and train your model as a tf.Graph called g_1.
  2. Fetch the final values of each of the variables and store them as numpy arrays (using Session.run()).
  3. In a new tf.Graph called g_2, create tf.constant() tensors for each of the variables, using the value of the corresponding numpy array fetched in step 2.
  4. Use tf.import_graph_def() to copy nodes from g_1 into g_2, and use the input_map argument to replace each variable in g_1 with the corresponding tf.constant() tensor created in step 3. You may also want to use input_map to specify a new input tensor (e.g. replacing an input pipeline with a tf.placeholder()). Use the return_elements argument to specify the name of the predicted output tensor.
  5. Call g_2.as_graph_def() to get a protocol buffer representation of the graph.
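
The five steps above can be sketched roughly as follows. This is only an illustration under my own assumptions, not a reference implementation: it uses tf.compat.v1, a toy linear model with hypothetical names ("input", "w", "b", "output"), no real training loop, and non-resource variables so that each variable has a "<name>/read:0" tensor that input_map can redirect to a constant. It also strips the leftover nodes with extract_sub_graph(), which later releases expose as tf.compat.v1.graph_util.extract_sub_graph.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode (TF 1.x style) API; assumed to be available

# Step 1: build the model in g_1 (a real training loop is omitted here).
g_1 = tf.Graph()
with g_1.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="input")
    w = tf1.Variable([[1.0], [2.0]], name="w", use_resource=False)
    b = tf1.Variable([0.5], name="b", use_resource=False)
    y = tf.add(tf.matmul(x, w), b, name="output")
    variables = tf1.global_variables()
    init_op = tf1.global_variables_initializer()

# Step 2: fetch the final value of each variable as a numpy array.
with tf1.Session(graph=g_1) as sess:
    sess.run(init_op)
    values = sess.run(variables)

# Steps 3 and 4: create one constant per variable in g_2, then import g_1
# and redirect every read of a variable to the matching constant.
g_2 = tf.Graph()
with g_2.as_default():
    input_map = {}
    for var, value in zip(variables, values):
        const = tf.constant(value, name=var.op.name + "_frozen")
        # Assumption: consumers read the variable through "<name>/read:0".
        input_map[var.op.name + "/read:0"] = const
    (output,) = tf.import_graph_def(
        g_1.as_graph_def(),
        input_map=input_map,
        return_elements=["output:0"],
        name="")

# Step 5: get the GraphDef, keep only what "output" depends on, and save it.
frozen_def = tf1.graph_util.extract_sub_graph(g_2.as_graph_def(), ["output"])
pb_path = os.path.join(tempfile.mkdtemp(), "frozen_model.pb")
with open(pb_path, "wb") as f:
    f.write(frozen_def.SerializeToString())

# Sanity check: g_2 can be evaluated without initializing any variables.
with tf1.Session(graph=g_2) as sess:
    prediction = sess.run(output, feed_dict={"input:0": [[1.0, 1.0]]})
```

Mapping the read tensor instead of the variable op itself is a choice made for this sketch: it keeps the data types of the mapped tensors identical, so the import does not have to reconcile a reference-typed variable output with a plain constant.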

(NOTE: The generated graph will still contain extra nodes that are only needed for training. Although it is not part of the public API, you may wish to use the internal graph_util.extract_sub_graph() function to strip these nodes from the graph.)

