Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python 3.x - How to graph a tf.keras model in TensorFlow 2.0?

I upgraded to TensorFlow 2.0, and tf.summary.FileWriter("tf_graphs", sess.graph) no longer exists. Other StackOverflow questions on this suggest falling back to tf.compat.v1.summary and similar. Surely there is a native way to graph and visualize a tf.keras model in TensorFlow 2. What is it? I'm looking for TensorBoard output like the one below. Thank you!

[Image: TensorBoard graph view]



1 Answer


You can visualize the graph of any tf.function-decorated function, but you first have to trace its execution.
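To make "tracing" concrete, here is a minimal sketch that traces a small tf.function and lists the operations in the resulting graph. The function name add_and_scale is hypothetical and used only for illustration; the exact op names TensorFlow emits may vary between versions.

```python
import tensorflow as tf


@tf.function
def add_and_scale(x):
    # Hypothetical example function, not part of the original answer.
    return (x + 1.0) * 2.0


# Requesting a concrete function forces a trace for this input signature.
concrete = add_and_scale.get_concrete_function(
    tf.TensorSpec(shape=(None,), dtype=tf.float32)
)

# The traced graph is an ordinary tf.Graph; list the op types it contains.
op_types = [op.type for op in concrete.graph.get_operations()]
print(op_types)
```

The same graph object is what ends up in TensorBoard once a trace is exported, so printing its operations is a quick way to check what was captured before opening the UI.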

Visualizing the graph of a Keras model means visualizing its call method.

By default, this method is not decorated with tf.function, so you have to wrap the model call in a properly decorated function and execute it.

import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)


@tf.function
def traceme(x):
    return model(x)


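logdir = "log"
writer = tf.summary.create_file_writer(logdir)

# Start recording. profiler=True also collects profiling data;
# set it to False if you only need the graph.
tf.summary.trace_on(graph=True, profiler=True)

# Forward pass: the first call traces traceme and builds the graph.
# The input shape matches Flatten(input_shape=(28, 28)) plus a batch dimension.
traceme(tf.zeros((1, 28, 28)))

# Write the recorded trace to logdir, then view it with: tensorboard --logdir log
with writer.as_default():
    tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=logdir)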
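
If you only need the graph and not the profiler data, a graph-only variant might look like the sketch below. It is an illustration under a few assumptions, not part of the original answer: it uses a smaller model, writes to a throwaway temporary directory instead of "log", and simply checks that an event file was produced.

```python
import glob
import os
import tempfile

import tensorflow as tf

# Smaller stand-in model (assumption: any tf.keras model works the same way).
model = tf.keras.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)


@tf.function
def traceme(x):
    return model(x)


logdir = tempfile.mkdtemp()  # throwaway directory, chosen for illustration
writer = tf.summary.create_file_writer(logdir)

tf.summary.trace_on(graph=True)  # record the graph only, profiler left off
traceme(tf.zeros((1, 28, 28)))   # the first call triggers the trace
with writer.as_default():
    tf.summary.trace_export(name="model_trace", step=0)
writer.flush()

# TensorBoard reads the event files written into logdir.
event_files = glob.glob(os.path.join(logdir, "events.out.tfevents.*"))
print(event_files)
```

Pointing TensorBoard at that directory (tensorboard --logdir followed by the path) should then show the trace under the Graphs tab.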


...