本文整理汇总了Python中tensorflow.python.framework.graph_util.convert_variables_to_constants函数的典型用法代码示例。如果您正苦于以下问题:Python convert_variables_to_constants函数的具体用法?Python convert_variables_to_constants怎么用?Python convert_variables_to_constants使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了convert_variables_to_constants函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: testConvertVariablesToConsts
def testConvertVariablesToConsts(self):
    """Freezes a one-variable graph with a whitelist, without one, and with a blacklist."""
    with ops.Graph().as_default():
        variable_node = variables.Variable(1.0, name="variable_node")
        # This variable never feeds the output node.
        _ = variables.Variable(1.0, name="unused_variable_node")
        output_node = math_ops_lib.multiply(
            variable_node, 2.0, name="output_node")
        with session.Session() as sess:
            # Only the variable that feeds the output is initialized here.
            sess.run(variables.initialize_variables([variable_node]))
            self.assertNear(2.0, sess.run(output_node), 0.00001)
            variable_graph_def = sess.graph.as_graph_def()

            # Freeze with a whitelist first: without one the call would fail
            # because unused_variable_node has not been initialized yet.
            frozen_with_whitelist = graph_util.convert_variables_to_constants(
                sess,
                variable_graph_def,
                ["output_node"],
                variable_names_whitelist={"variable_node"})

            # Initialize everything, then freeze again with no whitelist.
            sess.run(variables.global_variables_initializer())
            frozen_without_whitelist = (
                graph_util.convert_variables_to_constants(
                    sess, variable_graph_def, ["output_node"]))

            # The unused variable is dropped either way, so both frozen
            # graphs must serialize to the same text.
            self.assertEqual(
                str(frozen_with_whitelist), str(frozen_without_whitelist))

            # A blacklisted variable must stay a variable in the result.
            sess.run(variables.global_variables_initializer())
            frozen_with_blacklist = graph_util.convert_variables_to_constants(
                sess,
                variable_graph_def,
                ["output_node"],
                variable_names_blacklist={"variable_node"})
            matches = [
                candidate for candidate in frozen_with_blacklist.node
                if candidate.name == "variable_node"
            ]
            variable_node = matches[-1] if matches else None
            self.assertIsNotNone(variable_node)
            self.assertEqual(variable_node.op, "VariableV2")

    # Re-import the frozen graph: it must hold no variable ops and still
    # evaluate to the same output value.
    with ops.Graph().as_default():
        _ = importer.import_graph_def(frozen_with_whitelist, name="")
        self.assertEqual(4, len(frozen_with_whitelist.node))
        for frozen_node in frozen_with_whitelist.node:
            self.assertNotEqual("Variable", frozen_node.op)
            self.assertNotEqual("VariableV2", frozen_node.op)
        with session.Session() as sess:
            result_tensor = sess.graph.get_tensor_by_name("output_node:0")
            self.assertNear(2.0, sess.run(result_tensor), 0.00001)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:60,代码来源:graph_util_test.py
示例2: freeze_graph
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices, initializer_nodes):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph: Path of the GraphDef file to freeze.
        input_saver: Optional path of a SaverDef file used to restore weights.
        input_binary: True if the input files are binary protos, False if
            they are text protos.
        input_checkpoint: Checkpoint path (or prefix) holding variable values.
        output_node_names: Comma-separated names of the output nodes.
        restore_op_name: Restore op to run when no saver file is given.
        filename_tensor_name: Tensor fed with the checkpoint path for that op.
        output_graph: Path the frozen GraphDef is written to.
        clear_devices: If true, strip device placements from every node.
        initializer_nodes: Optional fetches run after restoring.

    Returns:
        -1 when an input is missing or no output node name was supplied;
        otherwise None, after the frozen graph has been written.
    """

    def _as_text(data):
        # A text-mode read may hand back either bytes or str depending on the
        # Python version; calling .decode() on str raises on Python 3.
        return data.decode("utf-8") if isinstance(data, bytes) else data

    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1
    if input_saver and not tf.gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1
    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not tf.train.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1
    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(_as_text(f.read()), input_graph_def)

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""
    _ = tf.import_graph_def(input_graph_def, name="")

    with tf.Session() as sess:
        if input_saver:
            with tf.gfile.FastGFile(input_saver, mode) as f:
                saver_def = tf.train.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(_as_text(f.read()), saver_def)
                saver = tf.train.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name],
                     {filename_tensor_name: input_checkpoint})
            # NOTE(review): the nesting of this check was lost in the source
            # listing; it is placed in the restore-op branch -- confirm.
            if initializer_nodes:
                sess.run(initializer_nodes)

        # The blacklist comes from a command-line flag, not from a parameter.
        variable_names_blacklist = (
            FLAGS.variable_names_blacklist.split(",")
            if FLAGS.variable_names_blacklist else None)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names.split(","),
            variable_names_blacklist=variable_names_blacklist)

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
开发者ID:curtiszimmerman,项目名称:tensorflow,代码行数:60,代码来源:freeze_graph.py
示例3: graph_def_from_checkpoint
def graph_def_from_checkpoint(checkpoint_dir, output_node_names):
    """Builds a frozen GraphDef from the newest checkpoint in a directory.

    The meta graph stored next to the latest checkpoint is imported, its
    variables are restored, and they are then converted to constants.

    Args:
        checkpoint_dir: Path to the checkpoints.
        output_node_names: List of name strings for the result nodes of the
            graph.

    Returns:
        A GraphDef from the latest checkpoint.

    Raises:
        ValueError: if no checkpoint is found.
    """
    latest = saver_lib.latest_checkpoint(checkpoint_dir)
    if latest is None:
        raise ValueError(
            'Could not find a checkpoint at: {0}.'.format(checkpoint_dir))

    restorer = saver_lib.import_meta_graph(
        latest + '.meta', clear_devices=True)
    with session.Session() as sess:
        restorer.restore(sess, latest)
        current_graph_def = ops.get_default_graph().as_graph_def()
        return graph_util.convert_variables_to_constants(
            sess, current_graph_def, output_node_names)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:30,代码来源:strip_pruning_vars_lib.py
示例4: freeze_graph_with_def_protos
def freeze_graph_with_def_protos(
        input_graph_def,
        input_saver_def,
        input_checkpoint,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        clear_devices,
        initializer_nodes,
        variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph_def: GraphDef proto to freeze (device fields are cleared
            in place when `clear_devices` is set).
        input_saver_def: Optional SaverDef proto used to restore weights.
        input_checkpoint: Checkpoint path (or prefix) with variable values.
        output_node_names: Comma-separated names of the output nodes.
        restore_op_name: Unused; kept for signature compatibility.
        filename_tensor_name: Unused; kept for signature compatibility.
        clear_devices: If true, strip device placements from every node.
        initializer_nodes: Optional fetches run after restoring.
        variable_names_blacklist: Comma-separated variable names that must
            not be converted; empty means none.

    Returns:
        The frozen GraphDef.

    Raises:
        ValueError: if the checkpoint is missing or no output node is named.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        raise ValueError(
            'Input checkpoint "' + input_checkpoint + '" does not exist!')
    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')

    # Dropping explicit device assignments makes the graph more portable.
    if clear_devices:
        for graph_node in input_graph_def.node:
            graph_node.device = ''
    _ = importer.import_graph_def(input_graph_def, name='')

    with session.Session() as sess:
        if input_saver_def:
            saver_lib.Saver(saver_def=input_saver_def).restore(
                sess, input_checkpoint)
        else:
            ckpt_reader = pywrap_tensorflow.NewCheckpointReader(
                input_checkpoint)
            restorable = {}
            for var_name in ckpt_reader.get_variable_to_shape_map():
                try:
                    restorable[var_name] = sess.graph.get_tensor_by_name(
                        var_name + ':0')
                except KeyError:
                    # Checkpoint entries with no matching tensor in the graph
                    # (e.g. 'global_step' housekeeping) are skipped.
                    pass
            saver_lib.Saver(var_list=restorable).restore(
                sess, input_checkpoint)
            # NOTE(review): the nesting of this check was lost in the source
            # listing; it is placed in the reader-based branch -- confirm.
            if initializer_nodes:
                sess.run(initializer_nodes)

        if variable_names_blacklist:
            blacklist = variable_names_blacklist.split(',')
        else:
            blacklist = None
        return graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(','),
            variable_names_blacklist=blacklist)
开发者ID:chenxiang204,项目名称:code,代码行数:60,代码来源:exporter.py
示例5: from_session
def from_session(cls,
                 sess,
                 input_tensors,
                 output_tensors,
                 freeze_variables=False):
    """Creates a TocoConverter class from a TensorFlow Session.

    Args:
        sess: TensorFlow Session.
        input_tensors: List of input tensors. Type and shape are computed
            using `foo.get_shape()` and `foo.dtype`.
        output_tensors: List of output tensors (only .name is used from this).
        freeze_variables: Boolean indicating whether the variables need to be
            converted into constants first. (default False)

    Returns:
        TocoConverter class.
    """
    if not freeze_variables:
        # Nothing to freeze: hand the session's GraphDef over untouched.
        return cls(sess.graph_def, input_tensors, output_tensors)

    # Freezing path: (re)initialize the global variables, then convert them.
    sess.run(global_variables_initializer())
    output_arrays = [tensor_name(out) for out in output_tensors]
    frozen_graph_def = tf_graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_arrays)
    return cls(frozen_graph_def, input_tensors, output_tensors)
开发者ID:jinxin0924,项目名称:tensorflow,代码行数:31,代码来源:lite.py
示例6: testConvertVariablesToConstsWithEmbeddings
def testConvertVariablesToConstsWithEmbeddings(self):
    """Freezes a graph with embeddings."""
    input_data = np.array(np.random.random_sample([1, 1]), dtype=np.int32)

    # Build a minimal Keras model: an int32 input feeding an Embedding layer.
    state_input = keras.layers.Input(
        shape=(1,), name="state_input", dtype="int32")
    embedding_layer = keras.layers.Embedding(
        output_dim=16, input_dim=100, input_length=1, name="state")
    output = embedding_layer(state_input)
    model = keras.models.Model(inputs=[state_input], outputs=[output])
    model.compile(
        loss={"state": "sparse_categorical_crossentropy"}, optimizer="adam")

    # Freeze the graph held by the Keras backend session.
    sess = keras.backend.get_session()
    variable_graph_def = sess.graph_def
    output_node_names = [out.name.split(":")[0] for out in model.outputs]
    constant_graph_def = graph_util.convert_variables_to_constants(
        sess, variable_graph_def, output_node_names)

    # No variable-style op may survive freezing.
    variable_ops = ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"]
    for frozen_node in constant_graph_def.node:
        self.assertNotIn(frozen_node.op, variable_ops)

    # The frozen graph must predict the same values as the live model.
    expected_value = model.predict(input_data)
    actual_value = self._evaluate_graph_def(
        constant_graph_def, model.inputs, model.outputs, [input_data])
    np.testing.assert_almost_equal(
        np.array([expected_value]), actual_value, 5)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:31,代码来源:graph_util_test.py
示例7: freeze_session
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.
    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs. The passed
                        sequence is not modified.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        global_var_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(
            set(global_var_names).difference(keep_var_names or []))
        # Build a fresh list: the previous in-place `+=` appended every
        # variable name to the list object owned by the caller.
        all_output_names = list(output_names or []) + global_var_names
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        # The fourth positional argument is the variable-name whitelist.
        frozen_graph = convert_variables_to_constants(
            session, input_graph_def, all_output_names, freeze_var_names)
        return frozen_graph
开发者ID:DXZ,项目名称:git_test,代码行数:28,代码来源:keras_to_tfsevring_final.py
示例8: freeze_graph
def freeze_graph(sess, input_tensors, output_tensors):
    """Returns a frozen GraphDef.

    A Grappler pass is run to inline the functions in the graph, and a graph
    that still contains variables is then frozen. If the session's graph is
    already frozen, its GraphDef is returned as is. When OpHints are present
    the OpHint conversion path is taken instead.

    Args:
        sess: TensorFlow Session.
        input_tensors: List of input tensors.
        output_tensors: List of output tensors (only .name is used from this).

    Returns:
        Frozen GraphDef.
    """
    # Grappler's function inlining would break the OpHints transformation,
    # so a hinted graph is converted directly and returned.
    hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
    if len(hinted_outputs_nodes) > 0:  # pylint: disable=g-explicit-length-test
        return _convert_op_hints_if_present(sess, output_tensors)

    # Inline any functions in the graph via a function-only Grappler pass.
    grappler_config = get_grappler_config(function_only=True)
    optimized_graph_def = run_graph_optimizations(
        sess.graph_def,
        input_tensors,
        output_tensors,
        grappler_config,
        graph=sess.graph)

    if is_frozen_graph(sess):
        return sess.graph_def

    output_arrays = [get_tensor_name(out) for out in output_tensors]
    return tf_graph_util.convert_variables_to_constants(
        sess, optimized_graph_def, output_arrays)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:33,代码来源:util.py
示例9: testConvertVariablesToConstsWithFunctions
def testConvertVariablesToConstsWithFunctions(self):
    """Checks that freezing leaves the graph's function library unchanged."""

    @function.Defun(dtypes.float32)
    def plus_one(x):
        return x + 1.0

    with ops.Graph().as_default():
        variable_node = variables.Variable(1.0, name="variable_node")
        # This variable never feeds the output node.
        _ = variables.Variable(1.0, name="unused_variable_node")
        output_node = math_ops_lib.multiply(
            plus_one(variable_node), 2.0, name="output_node")

        with session.Session() as sess:
            # Only the variable that feeds the output is initialized.
            sess.run(variables.initialize_variables([variable_node]))
            self.assertNear(4.0, sess.run(output_node), 0.00001)
            variable_graph_def = sess.graph.as_graph_def()

            # The whitelist is required here: without it the call would fail
            # because unused_variable_node is not initialized.
            frozen_graph_def = graph_util.convert_variables_to_constants(
                sess,
                variable_graph_def,
                ["output_node"],
                variable_names_whitelist={"variable_node"})
            self.assertEqual(
                variable_graph_def.library, frozen_graph_def.library)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:29,代码来源:graph_util_test.py
示例10: _freeze_graph_with_def_protos
def _freeze_graph_with_def_protos(input_graph_def, output_node_names,
                                  initializer_names, shared_init_op_name,
                                  input_saver_def, input_checkpoint):
    """Converts all variables in a graph and checkpoint into constants.

    Certain initializer nodes (e.g. table initializers) have to survive
    freezing. Rather than working out which dependencies of the shared
    initializer node (e.g. group_deps) to keep, the initializers are kept as
    extra outputs and then reconnected to the shared node afterwards.

    Args:
        input_graph_def: A GraphDef proto to be frozen.
        output_node_names: Names of output nodes.
        initializer_names: Names of initializer nodes to keep.
        shared_init_op_name: The name of the shared initializer node to
            connect the nodes in initializer names to.
        input_saver_def: A SaverDef proto used for restoring a checkpoint.
        input_checkpoint: A path to a checkpoint to restore.

    Returns:
        A frozen GraphDef.
    """
    with _ops.Graph().as_default():
        _ = _importer.import_graph_def(input_graph_def, name='')
        with _session.Session() as sess:
            _saver_lib.Saver(saver_def=input_saver_def).restore(
                sess, input_checkpoint)
            # Initializers are listed as outputs so pruning keeps them.
            nodes_to_keep = output_node_names + initializer_names
            frozen_graph_def = _graph_util.convert_variables_to_constants(
                sess, input_graph_def, nodes_to_keep)
            _connect_to_shared_init_op(
                frozen_graph_def, shared_init_op_name, initializer_names)
    return frozen_graph_def
开发者ID:bikong2,项目名称:tensorflow,代码行数:35,代码来源:meta_graph_transform.py
示例11: freeze_graph_def
def freeze_graph_def(sess, input_graph_def, output_node_names):
    """Rewrites training-only ops in place, then freezes the graph.

    `RefSwitch`, `AssignSub` and `AssignAdd` nodes of `input_graph_def` are
    replaced by `Switch`, `Sub` and `Add`, and only nodes under a fixed set
    of name prefixes are eligible for conversion to constants.

    Args:
        sess: Session holding the current variable values.
        input_graph_def: GraphDef to freeze; its nodes are modified in place.
        output_node_names: Comma-separated names of the output nodes.

    Returns:
        The frozen GraphDef.
    """
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # enumerate replaces the Python-2-only xrange index loop; only
            # existing entries are reassigned, so the length never changes.
            for index, input_name in enumerate(node.input):
                if 'moving_' in input_name:
                    node.input[index] = input_name + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    # Get the list of important nodes
    important_prefixes = ('InceptionResnetV1', 'embeddings', 'phase_train',
                          'Bottleneck', 'Logits')
    whitelist_names = [
        node.name for node in input_graph_def.node
        if node.name.startswith(important_prefixes)
    ]

    # Replace all the variables in the graph with constants of the same values
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
    return output_graph_def
开发者ID:citysir,项目名称:facenet,代码行数:26,代码来源:freeze_graph.py
示例12: saveData
def saveData(self, step):
    """Saves a checkpoint, exports a frozen .pb graph, zips and announces it.

    Args:
        step: Training step used to label the exported files. It is passed
            through `str()` wherever it becomes part of a file name.
    """
    print('{} Saving checkpoint file to: {}'.format(
        datetime.datetime.now().strftime('%m-%d %H:%M:%S'),
        self.output_dir))
    # Save the graph weights.
    self.saver.save(
        self.sess, self.ckpt_file, global_step=self.global_step)
    # Save the graph structure.
    tf.train.write_graph(self.sess.graph_def,
                         os.path.join(cfg.OUTPUT_DIR, cfg.DATA_VERSION, 'model'),
                         'train.pbtxt')
    # Freeze the weights into the graph to produce a .pb file usable on
    # Android.
    graph_def = tf.get_default_graph().as_graph_def()
    print("global_variables are")
    variables_to_save = []
    for graph_node in self.sess.graph_def.node:
        print("{}:{}".format(str(graph_node.name), type(graph_node)))
        variables_to_save.append(str(graph_node.name).split(':')[0])
    print("--------------")
    # TODO: every node is currently listed as an output; narrow this to the
    # real output node names (the original author noted e.g. "predictions").
    output_graph_def = graph_util.convert_variables_to_constants(
        self.sess,
        graph_def,
        variables_to_save)
    # The name was previously built as 'train.' + step + str(step) + '.pb',
    # which raises TypeError for an int step and repeats a str step twice.
    pb_path = os.path.join(cfg.OUTPUT_DIR, cfg.DATA_VERSION, 'model',
                           'train.' + str(step) + '.pb')
    with tf.gfile.GFile(pb_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))

    freezetime = datetime.datetime.now().strftime('%m-%d-%H-%M-%S')
    zu = ZipUtil()
    zipfilename = cfg.DATA_UploadZipFileName + '.' + str(step) + '.' + freezetime
    # The step argument lets the zip cover only the output of this training
    # step instead of everything.
    zu.zip_dir(os.path.join(cfg.OUTPUT_DIR, cfg.DATA_VERSION),
               step,
               zipfilename)
    uploader = Uploader()
    # NOTE(security): credentials are hard-coded in source; they should be
    # rotated and loaded from configuration or the environment instead.
    uploader.setQiniuKEY('mMQxjyif6Uk8nSGIn9ZD3I19MBMEK3IUGngcX8_p',
                         'J5gFhdpQ-1O1rkCnlqYnzPiH3XTst2Szlv9GlmQM')
    # The direct upload call is disabled; only a notification describing the
    # prepared archive is sent.
    sendData = {"state": "prepared",
                "filename": str(zipfilename),
                "filepath": os.path.join(cfg.OUTPUT_DIR, cfg.DATA_VERSION, zipfilename),
                "step": step,
                }
    uploader.notifyForTrans(sendData)
开发者ID:wanglikang,项目名称:zzuARTensorflow2,代码行数:59,代码来源:train.py
示例13: train_network
def train_network(graph, batch_size, num_epochs, pb_file_path):
    """Trains the network, reports accuracy/loss, and exports a frozen .pb.

    Args:
        graph: Dict of graph handles; the keys 'x', 'y', 'optimize', 'cost'
            and 'correct_times_in_batch' are read here.
        batch_size: Multiplier used when accumulating cost and sample counts.
        num_epochs: Number of passes over the training samples.
        pb_file_path: Path the frozen graph is written to.
    """

    def build_feed(images, labels, idx):
        # One sample per step, with its label one-hot encoded over 2 classes.
        return {
            graph['x']: np.reshape(images[idx], (1, 224, 224, 3)),
            graph['y']: ([[1, 0]] if labels[idx] == 0 else [[0, 1]])
        }

    def evaluate(sess, images, labels, sample_count):
        # Returns (batches seen, correct predictions, accumulated cost).
        batches = 0
        correct = 0
        cost = 0.
        for idx in range(sample_count):
            correct_in_batch = sess.run(
                graph['correct_times_in_batch'],
                feed_dict=build_feed(images, labels, idx))
            mean_cost_in_batch = sess.run(
                graph['cost'], feed_dict=build_feed(images, labels, idx))
            batches += 1
            correct += correct_in_batch
            cost += (mean_cost_in_batch * batch_size)
        return batches, correct, cost

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        report_every = 2
        for epoch_index in range(num_epochs):
            for idx in range(12):
                sess.run([graph['optimize']],
                         feed_dict=build_feed(x_train, y_train, idx))

            if epoch_index % report_every != 0:
                continue

            train_batches, train_correct, train_cost = evaluate(
                sess, x_train, y_train, 12)
            test_batches, test_correct, test_cost = evaluate(
                sess, x_val, y_val, 3)
            acy_on_test = test_correct / float(test_batches * batch_size)
            acy_on_train = train_correct / float(train_batches * batch_size)
            print('Epoch - {:2d}, acy_on_test:{:6.2f}%({}/{}),loss_on_test:{:6.2f}, acy_on_train:{:6.2f}%({}/{}),loss_on_train:{:6.2f}'.format(
                epoch_index,
                acy_on_test * 100.0,
                test_correct,
                test_batches * batch_size,
                test_cost,
                acy_on_train * 100.0,
                train_correct,
                train_batches * batch_size,
                train_cost))

            # NOTE(review): nesting was lost in the source listing; the
            # export is assumed to belong to the reporting branch -- confirm.
            constant_graph = graph_util.convert_variables_to_constants(
                sess, sess.graph_def, ["output"])
            with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
                f.write(constant_graph.SerializeToString())
开发者ID:JGyoung33,项目名称:tensorflow-vgg16-train-and-test,代码行数:57,代码来源:train_vgg.py
示例14: save_graph_to_file
def save_graph_to_file(graph, graph_file_name, model_info, class_count):
    """Freezes the evaluation graph and writes it to `graph_file_name`.

    Args:
        graph: Ignored; it is replaced by the evaluation session's graph.
        graph_file_name: Destination path of the serialized GraphDef.
        model_info: Passed through to `build_eval_session`.
        class_count: Passed through to `build_eval_session`.
    """
    eval_sess, _, _, _, _ = build_eval_session(model_info, class_count)
    graph = eval_sess.graph
    frozen_graph_def = graph_util.convert_variables_to_constants(
        eval_sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
    serialized = frozen_graph_def.SerializeToString()
    with gfile.FastGFile(graph_file_name, 'wb') as out_file:
        out_file.write(serialized)
开发者ID:google,项目名称:makerfaire-2016,代码行数:9,代码来源:retrain.py
示例15: _convert_op_hints_if_present
def _convert_op_hints_if_present(sess, output_tensors):
    """Freezes the session graph and converts its OpHints to stubs.

    Args:
        sess: TensorFlow Session whose graph must not be frozen yet.
        output_tensors: List of output tensors.

    Returns:
        The frozen GraphDef with OpHints stubbed and training nodes removed.

    Raises:
        ValueError: if the session's graph is already frozen.
    """
    if is_frozen_graph(sess):
        raise ValueError("Try to convert op hints, needs unfrozen graph.")

    hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
    output_arrays = [get_tensor_name(out) for out in output_tensors]
    # Hinted outputs are kept as extra outputs alongside the requested ones.
    frozen = tf_graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_arrays + hinted_outputs_nodes)
    stubbed = convert_op_hints_to_stubs(graph_def=frozen)
    return tf_graph_util.remove_training_nodes(stubbed)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:10,代码来源:util.py
示例16: _get_initial_outputs
def _get_initial_outputs(self, output_tensor_names_list):
    """Prunes the initial graph, records its outputs, and freezes it.

    The frozen GraphDef is stored on `self.initial_graph_def`.

    Args:
        output_tensor_names_list: Names of the tensors to evaluate.

    Returns:
        The outputs computed from the initial graph.
    """
    with self.session(graph=self.initial_graph) as sess1:
        self._prune_model(sess1)
        reference_outputs = self._get_outputs(
            sess1, self.initial_graph, output_tensor_names_list)
        current_graph_def = sess1.graph.as_graph_def()
        output_node_names = _get_node_names(output_tensor_names_list)
        self.initial_graph_def = graph_util.convert_variables_to_constants(
            sess1, current_graph_def, output_node_names)
        return reference_outputs
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:10,代码来源:strip_pruning_vars_test.py
示例17: freeze_saved_model
def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                       output_arrays, tag_set, signature_key):
    """Converts a SavedModel to a frozen graph.

    Args:
        saved_model_dir: SavedModel directory to convert.
        input_arrays: List of input tensors to freeze graph with. Uses input
            arrays from SignatureDef when none are provided.
        input_shapes: Dict of strings representing input tensor names to list
            of integers representing input shapes
            (e.g., {"foo": [1, 16, 16, 3]}). Automatically determined when
            input shapes is None (e.g., {"foo": None}).
        output_arrays: List of output tensors to freeze graph with. Uses
            output arrays from SignatureDef when none are provided.
        tag_set: Set of tags identifying the MetaGraphDef within the
            SavedModel to analyze. All tags in the tag set must be present.
        signature_key: Key identifying SignatureDef containing inputs and
            outputs.

    Returns:
        frozen_graph_def: Frozen GraphDef.
        in_tensors: List of input tensors for the graph.
        out_tensors: List of output tensors for the graph.

    Raises:
        ValueError:
            SavedModel doesn't contain a MetaGraphDef identified by tag_set.
            signature_key is not in the MetaGraphDef.
            assets/ directory is in the MetaGraphDef.
            input_shapes does not match the length of input_arrays.
            input_arrays or output_arrays are not valid.
    """
    # Resolve the MetaGraphDef and the signature's input/output names.
    meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
    signature_def = _get_signature_def(meta_graph, signature_key)
    inputs, outputs = _get_inputs_outputs(signature_def)

    # SavedModels that carry an assets collection are rejected.
    if constants.ASSETS_KEY in meta_graph.collection_def:
        raise ValueError(
            "SavedModels with assets/ directory are not supported.")

    graph = ops.Graph()
    with session.Session(graph=graph) as sess:
        loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)

        # Gets input and output tensors.
        # TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
        in_tensors = _get_tensors(graph, inputs, input_arrays)
        out_tensors = _get_tensors(graph, outputs, output_arrays)
        set_tensor_shapes(in_tensors, input_shapes)

        # The frozen outputs are the signature's outputs, reduced to node
        # names by dropping the ":<index>" suffix.
        output_names = [name.partition(":")[0] for name in outputs]
        frozen_graph_def = tf_graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), output_names)
        return frozen_graph_def, in_tensors, out_tensors
开发者ID:AnishShah,项目名称:tensorflow,代码行数:55,代码来源:convert_saved_model.py
示例18: testConvertVariablesToConsts
def testConvertVariablesToConsts(self):
    """Freezes a one-variable graph with and without a whitelist."""
    with tf.Graph().as_default():
        variable_node = tf.Variable(1.0, name="variable_node")
        # This variable never feeds the output node.
        _ = tf.Variable(1.0, name="unused_variable_node")
        output_node = tf.mul(variable_node, 2.0, name="output_node")
        with tf.Session() as sess:
            # Only the variable that feeds the output is initialized here.
            sess.run(tf.initialize_variables([variable_node]))
            self.assertNear(2.0, sess.run(output_node), 0.00001)
            variable_graph_def = sess.graph.as_graph_def()

            # Freeze with a whitelist first: without one the call would fail
            # because unused_variable_node has not been initialized yet.
            frozen_with_whitelist = graph_util.convert_variables_to_constants(
                sess,
                variable_graph_def,
                ["output_node"],
                variable_names_whitelist={"variable_node"})

            # Initialize everything, then freeze again with no whitelist.
            sess.run(tf.initialize_all_variables())
            frozen_without_whitelist = (
                graph_util.convert_variables_to_constants(
                    sess, variable_graph_def, ["output_node"]))

            # The unused variable is dropped either way, so both frozen
            # graphs must serialize to the same text.
            self.assertEqual(
                str(frozen_with_whitelist), str(frozen_without_whitelist))

    # Re-import the frozen graph: it must hold no Variable op and still
    # evaluate to the same output value.
    with tf.Graph().as_default():
        _ = tf.import_graph_def(frozen_with_whitelist, name="")
        self.assertEqual(4, len(frozen_with_whitelist.node))
        for frozen_node in frozen_with_whitelist.node:
            self.assertNotEqual("Variable", frozen_node.op)
        with tf.Session() as sess:
            result_tensor = sess.graph.get_tensor_by_name("output_node:0")
            self.assertNear(2.0, sess.run(result_tensor), 0.00001)
开发者ID:2020zyc,项目名称:tensorflow,代码行数:41,代码来源:graph_util_test.py
示例19: convert_tf_hub_module
def convert_tf_hub_module(module_path, output_dir,
                          signature='default', quantization_dtype=None,
                          skip_op_check=False, strip_debug_ops=False):
    """Freeze the TF-Hub module and check compatibility with Tensorflow.js.

    Optimize and convert the TF-Hub module to Tensorflow.js format, if it
    passes the compatibility check.

    Args:
        module_path: string Path to the module.
        output_dir: string The name of the output directory. The directory
            will consist of
            - a file named 'tensorflowjs_model.pb'
            - a JSON weights manifest file named 'weights_manifest.json'
            - possibly sharded binary weight files.
        signature: string Signature to load.
        quantization_dtype: Forwarded unchanged to `optimize_graph`.
        skip_op_check: Bool whether to skip the op check.
        strip_debug_ops: Bool whether to strip debug ops.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    graph, sess, inputs, outputs = load_and_initialize_hub_module(
        module_path, signature)

    # Node names are tensor names with the ":<index>" suffix removed.
    input_node_names = [
        tensor.name.split(':')[0] for tensor in inputs.values()
    ]
    output_node_names = [
        tensor.name.split(':')[0] for tensor in outputs.values()
    ]
    print('Creating a model with inputs %s and outputs %s.' % (
        input_node_names, output_node_names))

    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), output_node_names)

    output_graph = os.path.join(output_dir, DEFAULT_MODEL_PB_FILENAME)
    # The frozen graph goes through a temporary file that is removed again
    # whether or not optimization succeeds.
    frozen_file = output_graph + '.frozen'
    try:
        with tf.gfile.GFile(frozen_file, 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())
        graph = load_graph(frozen_file, ','.join(output_node_names))
        optimize_graph(graph, output_graph,
                       quantization_dtype=quantization_dtype,
                       skip_op_check=skip_op_check,
                       strip_debug_ops=strip_debug_ops)
    finally:
        # Clean up the temp files.
        if os.path.exists(frozen_file):
            os.remove(frozen_file)
开发者ID:oveddan,项目名称:tfjs-converter,代码行数:53,代码来源:tf_saved_model_conversion.py
示例20: freeze_graph
def freeze_graph(session, outputs):
    """Freeze the current graph.

    Args:
        session: Tensorflow sessions containing the graph
        outputs: List of output tensors

    Returns:
        The frozen graph_def.
    """
    session_graph_def = session.graph.as_graph_def()
    # Output nodes are identified by the name of each tensor's producing op.
    output_op_names = [tensor.op.name for tensor in outputs]
    return tf_graph_util.convert_variables_to_constants(
        session, session_graph_def, output_op_names)
开发者ID:Kongsea,项目名称:tensorflow,代码行数:12,代码来源:generate_examples.py
注:本文中的tensorflow.python.framework.graph_util.convert_variables_to_constants函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论