本文整理汇总了Python中tensorflow.python.framework.ops.get_collection_ref函数的典型用法代码示例。如果您正苦于以下问题:Python get_collection_ref函数的具体用法?Python get_collection_ref怎么用?Python get_collection_ref使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了get_collection_ref函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: test_asset_loading
def test_asset_loading(self):
  """Reloads a V1 asset SavedModel through two further V2 save/load cycles.

  After every load, table initializers are run and the "serving_default"
  signature must map ["gamma", "alpha"] to [2, 0]. Before each reload the
  previous export directory is deleted and the TABLE_INITIALIZERS collection
  is emptied, presumably so the check only passes if each re-export is
  self-contained.
  """
  def assert_serving_output(loaded):
    self.evaluate(lookup_ops.tables_initializer())
    serving_fn = loaded.signatures["serving_default"]
    self.assertAllClose(
        {"output": [2, 0]},
        serving_fn(start=constant_op.constant(["gamma", "alpha"])))

  export_path = self._v1_asset_saved_model()
  # Keep every loaded object referenced until the test ends.
  loaded_models = [load.load(export_path)]
  assert_serving_output(loaded_models[-1])

  for _ in range(2):
    previous = loaded_models[-1]
    next_path = os.path.join(self.get_temp_dir(), "saved_model",
                             str(ops.uid()))
    save.save(previous, next_path, signatures=previous.signatures)
    shutil.rmtree(export_path)
    del ops.get_collection_ref(ops.GraphKeys.TABLE_INITIALIZERS)[:]
    export_path = next_path
    loaded_models.append(load.load(export_path))
    assert_serving_output(loaded_models[-1])
开发者ID:aritratony,项目名称:tensorflow,代码行数:28,代码来源:load_v1_in_v2_test.py
示例2: _call_func
def _call_func(self, args, kwargs):
  """Invokes the wrapped template function and checks for stray variables.

  On the first call the function runs inside
  `checkpointable_util.capture_dependencies(template=self)` (to restore
  variables if necessary), after which `self._variables_created` is set. On
  later calls the function runs directly. The sizes of the TRAINABLE_VARIABLES
  and GLOBAL_VARIABLES collections are then compared with their sizes before
  the call: a new trainable variable raises ValueError, and a new global
  variable is only logged.

  Any exception raised inside (including that ValueError) is re-raised as the
  same object, with the template's definition-site stack trace appended to
  its first argument.
  """
  try:
    global_count_before = len(
        ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES))
    trainable_count_before = len(
        ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES))

    if not self._variables_created:
      # First invocation: restore variables if necessary (via Checkpointable).
      with checkpointable_util.capture_dependencies(template=self):
        result = self._func(*args, **kwargs)
    else:
      result = self._func(*args, **kwargs)

    if not self._variables_created:
      self._variables_created = True
      return result

    # Not the first call. A trainable variable created as a side effect of
    # calling the template again is almost certainly an error.
    trainable_now = ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
    if len(trainable_now) != trainable_count_before:
      raise ValueError("Trainable variable created when calling a template "
                       "after the first time, perhaps you used tf.Variable "
                       "when you meant tf.get_variable: %s" %
                       (trainable_now[trainable_count_before:],))
    # Non-trainable tracking variables are a legitimate (if advanced) reason
    # for new variables, so they are only logged.
    global_now = ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)
    if len(global_now) != global_count_before:
      logging.info("New variables created when calling a template after "
                   "the first time, perhaps you used tf.Variable when you "
                   "meant tf.get_variable: %s",
                   global_now[global_count_before:])
    return result
  except Exception as exc:
    # Re-raise the same exception object, with the original definition site
    # of the template appended to its first argument.
    exc_args = exc.args
    first_arg = exc_args[0] if exc_args else ""
    definition_trace = "".join(
        _skip_common_stack_elements(self._stacktrace,
                                    traceback.format_stack()))
    annotated = "%s\n\noriginally defined at:\n%s" % (first_arg,
                                                      definition_trace)
    exc.args = (annotated,) + tuple(exc_args[1:])
    raise
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:55,代码来源:template.py
示例3: _clear_saved_model_collections
def _clear_saved_model_collections():
  """Empties the collections a SavedModel export expects to start out empty.

  The SavedModel builder uses these collections to track ops necessary to
  restore the graph state, so they must be empty before MetaGraphs are added
  to the builder. Each collection list is cleared in place.
  """
  for collection_key in (constants.ASSETS_KEY,
                         constants.LEGACY_INIT_OP_KEY,
                         constants.MAIN_OP_KEY,
                         constants.TRAIN_OP_KEY):
    ops.get_collection_ref(collection_key)[:] = []
开发者ID:AnishShah,项目名称:tensorflow,代码行数:11,代码来源:saved_model_estimator.py
示例4: testBasicMemory
def testBasicMemory(self):
  """Checks the memory report for a small add_n graph built on the CPU."""
  with test_util.device(use_gpu=False):
    a = constant_op.constant(10, name="a")
    b = constant_op.constant(20, name="b")
    c = math_ops.add_n([a, b], name="c")
    d = math_ops.add_n([b, c], name="d")
    # Register d in the TRAIN_OP collection, presumably so the analyzer
    # treats it as the graph's fetch node.
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(d)
  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  report = cost_analyzer.GenerateMemoryReport(mg)
  # Print the report to make it easier to debug.
  print(report)
  # assertIn shows the full report on failure, unlike assertTrue(x in y),
  # which only reports "False is not true".
  self.assertIn(
      "Peak usage for device /job:localhost/replica:0/task:0/device:CPU:0: "
      "16 bytes", report)
  self.assertIn(" a:0 uses 4 bytes", report)
  self.assertIn(" b:0 uses 4 bytes", report)
  self.assertIn(" c:0 uses 4 bytes", report)
  self.assertIn(" d:0 uses 4 bytes", report)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:25,代码来源:cost_analyzer_test.py
示例5: build
def build(self, inputs_shape):
  """Builds the parent LSTM weights plus pruning mask and threshold variables.

  The mask has shape [input_depth + h_depth, 4 * h_depth] and is initialized
  to ones. The threshold is a scalar initialized to zero. Both are
  non-trainable. The masked kernel (mask * kernel) and the related tensors
  are registered in the pruning collections once per mask.

  Args:
    inputs_shape: shape of the cell input; `inputs_shape[1].value` is used as
      the input depth (assumes a TF1 `TensorShape` with `Dimension` entries).
  """
  # Call the build method of the parent class.
  super(MaskedBasicLSTMCell, self).build(inputs_shape)
  # The parent build marks the layer as built. Reset the flag while the
  # pruning variables are added; it is set back to True at the end.
  self.built = False

  input_depth = inputs_shape[1].value
  h_depth = self._num_units
  self._mask = self.add_variable(
      name="mask",
      shape=[input_depth + h_depth, 4 * h_depth],
      initializer=init_ops.ones_initializer(),
      trainable=False,
      dtype=self.dtype)
  self._threshold = self.add_variable(
      name="threshold",
      shape=[],
      initializer=init_ops.zeros_initializer(),
      trainable=False,
      dtype=self.dtype)
  # Add masked_weights in the weights namescope so as to make it easier
  # for the quantization library to add quant ops.
  self._masked_kernel = math_ops.multiply(self._mask, self._kernel,
                                          core_layers.MASKED_WEIGHT_NAME)

  # Register the pruning tensors only once per mask. Membership is tested by
  # identity: a plain `in` would call Variable.__eq__, which can be
  # element-wise (and not usable as a bool) when TF2 tensor equality is on.
  mask_collection = ops.get_collection_ref(core_layers.MASK_COLLECTION)
  if not any(existing is self._mask for existing in mask_collection):
    ops.add_to_collection(core_layers.MASK_COLLECTION, self._mask)
    ops.add_to_collection(core_layers.MASKED_WEIGHT_COLLECTION,
                          self._masked_kernel)
    ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold)
    ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel)

  self.built = True
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:32,代码来源:rnn_cells.py
示例6: record_summaries_every_n_global_steps
def record_summaries_every_n_global_steps(n):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  Yields once. While suspended at the yield, the
  `_SHOULD_RECORD_SUMMARIES_NAME` collection holds the single boolean tensor
  `equal(global_step % n, 0)`. Its previous contents are restored when the
  generator resumes or is closed, including when an exception is thrown in
  at the yield.

  The predicate is built with `math_ops.equal` rather than `==`. In TF1,
  `Tensor.__eq__` compares object identity, so `global_step % n == 0` would
  always be the Python value False.

  Args:
    n: recording period, in global steps.
  """
  from tensorflow.python.ops import math_ops  # pylint: disable=g-import-not-at-top

  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  try:
    collection_ref[:] = [
        math_ops.equal(training_util.get_global_step() % n, 0)]
    yield
  finally:
    # Restore unconditionally so the override cannot leak past the block.
    collection_ref[:] = old
开发者ID:benoitsteiner,项目名称:tensorflow-opencl,代码行数:7,代码来源:summary_ops.py
示例7: never_record_summaries
def never_record_summaries():
  """Sets the should_record_summaries Tensor to always false.

  Yields once. While suspended at the yield, the
  `_SHOULD_RECORD_SUMMARIES_NAME` collection holds `[False]`. Its previous
  contents are restored when the generator resumes or is closed, including
  when an exception is thrown in at the yield.
  """
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  try:
    collection_ref[:] = [False]
    yield
  finally:
    # Restore unconditionally so the override cannot leak past the block.
    collection_ref[:] = old
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:7,代码来源:summary_ops.py
示例8: testSimpleSwap
def testSimpleSwap(self):
  """Check that the swap annotations are followed."""
  var_a = variables.Variable(10, name='a')
  var_b = variables.Variable(20, name='b')
  sum_c = math_ops.add_n([var_a, var_b], name='c')
  sum_d = math_ops.add_n([var_b, sum_c], name='d')
  ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(sum_d)
  # Annotate d with a `_swap_to_host` request for its input 0.
  sum_d.op.node_def.attr['_swap_to_host'].i = 0

  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  nodes_before = len(mg.graph_def.node)

  config = rewriter_config_pb2.RewriterConfig(
      disable_model_pruning=True,
      memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
  optimized = tf_optimizer.OptimizeGraph(config, mg)

  # The optimizer is expected to add exactly a swap-out and a swap-in node.
  self.assertEqual(len(optimized.node), nodes_before + 2)
  node_names = {node.name for node in optimized.node}
  self.assertTrue(node_names > {'a', 'b', 'c', 'd',
                                'swap_in_d_0', 'swap_out_d_0'})

  nodes = {node.name: node for node in optimized.node}
  self.assertEqual('swap_out_d_0', nodes['swap_in_d_0'].input[0])
  self.assertEqual('^b/read', nodes['swap_in_d_0'].input[1])
  self.assertEqual('b/read', nodes['swap_out_d_0'].input[0])
  self.assertEqual('swap_in_d_0', nodes['d'].input[0])
  self.assertEqual('c', nodes['d'].input[1])
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:32,代码来源:memory_optimizer_test.py
示例9: apply_mask
def apply_mask(x, scope=''):
  """Apply mask to a given weight tensor.

  Obtains the mask and threshold variables for `x` (via
  `_weight_mask_variable` / `_weight_threshold_variable`) and multiplies the
  mask into `x`. The first time a given mask is seen, the threshold, mask,
  masked weights and `x` are registered in the pruning collections.

  Args:
    x: Input weight tensor
    scope: The current variable scope. Defaults to ""

  Returns:
    Tensor representing masked_weights
  """
  mask = _weight_mask_variable(x, scope)
  threshold = _weight_threshold_variable(x, scope)
  # Add masked_weights in the weights namescope so as to make it easier
  # for the quantization library to add quant ops.
  masked_weights = math_ops.multiply(mask, x, _MASKED_WEIGHT_NAME)

  # Make sure the mask for a given variable is not added multiple times to the
  # collection. This is particularly important when applying mask to RNN's
  # weight variables. Membership is tested by identity: a plain `in` would
  # call Variable.__eq__, which can be element-wise (and not usable as a
  # bool) when TF2 tensor equality is enabled.
  mask_collection = ops.get_collection_ref(_MASK_COLLECTION)
  if not any(existing is mask for existing in mask_collection):
    ops.add_to_collection(_THRESHOLD_COLLECTION, threshold)
    ops.add_to_collection(_MASK_COLLECTION, mask)
    ops.add_to_collection(_MASKED_WEIGHT_COLLECTION, masked_weights)
    ops.add_to_collection(_WEIGHT_COLLECTION, x)
  return masked_weights
开发者ID:SylChan,项目名称:tensorflow,代码行数:25,代码来源:pruning.py
示例10: testUpdates
def testUpdates(self):
  """tf_item changes only when the metagraph's node placement changes.

  Re-reading tf_item, or re-assigning an identical placement, must not
  change it.
  """
  with ops.Graph().as_default() as g:
    total = constant_op.constant(10) + constant_op.constant(20)
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
    grappler_item = item.Item(meta_graph.create_meta_graph_def(graph=g))

    def place_all_nodes_on_cpu():
      for node in grappler_item.metagraph.graph_def.node:
        node.device = '/cpu:0'

    initial_tf_item = grappler_item.tf_item
    # Reading tf_item again without touching the metagraph.
    self.assertEqual(initial_tf_item, grappler_item.tf_item)

    # Modify the placement.
    place_all_nodes_on_cpu()
    placed_tf_item = grappler_item.tf_item
    self.assertNotEqual(initial_tf_item, placed_tf_item)

    # Assign the same placement again.
    place_all_nodes_on_cpu()
    self.assertEqual(placed_tf_item, grappler_item.tf_item)
开发者ID:AnddyWang,项目名称:tensorflow,代码行数:25,代码来源:item_test.py
示例11: testFromGenerator
def testFromGenerator(self):
  """IteratorGetNext's inferred shape matches from_generator's output_shapes."""
  def make_generator(tensor):
    def generator():
      yield tensor
    return generator

  cases = [
      (0, tensor_shape.TensorShape([])),
      (np.array([1, 2, 3]), tensor_shape.TensorShape([3])),
      (np.array([[1, 2, 3]]), tensor_shape.TensorShape([1, 3])),
  ]
  for value, expected_shape in cases:
    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.from_generator(
          make_generator(value), dtypes.int64, output_shapes=expected_shape)
      get_next = dataset.make_one_shot_iterator().get_next()
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(get_next)
      mg = meta_graph.create_meta_graph_def(graph=g)
      op_properties = item.Item(mg).GetOpProperties()
      self.assertEqual(expected_shape,
                       op_properties['IteratorGetNext'][0].shape)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:35,代码来源:datasets_test.py
示例12: testPruning
def testPruning(self):
  """Checks the Enter-node count of the optimized while loop.

  With only the loop counter fetched, the optimized graph has one Enter node.
  After the stacked TensorList is fetched as well, it has two.
  """
  x = constant_op.constant(1)
  tensor_list = list_ops.empty_tensor_list(
      element_dtype=x.dtype, element_shape=x.shape)

  def Cond(counter, acc):
    del acc  # Unused for Cond.
    return counter < 5

  def Body(counter, acc):
    return counter + 1, list_ops.tensor_list_push_back(acc, counter)

  outputs = while_loop_v1(Cond, Body, [x, tensor_list])
  fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  fetches.append(outputs[0])

  def count_enter_nodes():
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    config = rewriter_config_pb2.RewriterConfig(
        constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
        memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
    optimized = tf_optimizer.OptimizeGraph(config, mg)
    return sum(1 for node in optimized.node if node.op == "Enter")

  self.assertEqual(count_enter_nodes(), 1)

  stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
  fetches.append(stack)
  self.assertEqual(count_enter_nodes(), 2)
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:32,代码来源:while_v2_test.py
示例13: testMap
def testMap(self):
  """IteratorGetNext's inferred shape reflects a transposing `map`."""
  cases = [
      (0, []),
      (np.array([1, 2, 3]), [3]),
      (np.array([[1, 2, 3]]), [3, 1]),
      (np.array([[[1, 2, 3], [4, 5, 6]]]), [3, 2, 1]),
  ]
  for value, dims in cases:
    expected_shape = tensor_shape.TensorShape(dims)
    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.from_tensors(value).map(
          array_ops.transpose)
      get_next = dataset.make_one_shot_iterator().get_next()
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(get_next)
      mg = meta_graph.create_meta_graph_def(graph=g)
      op_properties = item.Item(mg).GetOpProperties()
      self.assertEqual(expected_shape,
                       op_properties['IteratorGetNext'][0].shape)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:28,代码来源:datasets_test.py
示例14: testFromStringHandle
def testFromStringHandle(self):
  """IteratorGetNext's shape matches from_string_handle's output_shapes."""
  for dims in ([], [3], [1, 2], [1, 2, 3]):
    expected_shape = tensor_shape.TensorShape(dims)
    with ops.Graph().as_default() as g:
      structure_iterator = iterator_ops.Iterator.from_structure(dtypes.int64)
      handle = structure_iterator.string_handle()
      handle_iterator = iterator_ops.Iterator.from_string_handle(
          handle, dtypes.int64, output_shapes=expected_shape)
      get_next = handle_iterator.get_next()
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(get_next)
      mg = meta_graph.create_meta_graph_def(graph=g)
      op_properties = item.Item(mg).GetOpProperties()
      self.assertEqual(expected_shape,
                       op_properties['IteratorGetNext'][0].shape)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:25,代码来源:datasets_test.py
示例15: testInterleave
def testInterleave(self):
  """IteratorGetNext's inferred shape survives an `interleave` of constants."""
  def make_dataset(tensor):
    def dataset_fn(n):
      return dataset_ops.Dataset.from_tensors(tensor).repeat(n)
    return dataset_fn

  cases = [
      (0, []),
      (np.array([1, 2, 3]), [3]),
      (np.array([[1, 2, 3]]), [1, 3]),
  ]
  for value, dims in cases:
    expected_shape = tensor_shape.TensorShape(dims)
    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.range(42).interleave(
          make_dataset(value), cycle_length=42)
      get_next = dataset.make_one_shot_iterator().get_next()
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(get_next)
      mg = meta_graph.create_meta_graph_def(graph=g)
      op_properties = item.Item(mg).GetOpProperties()
      self.assertEqual(expected_shape,
                       op_properties['IteratorGetNext'][0].shape)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:34,代码来源:datasets_test.py
示例16: testPaddedBatch
def testPaddedBatch(self):
  """IteratorGetNext's inferred shape after `padded_batch(42, ...)`.

  The batch dimension (expected None) is only checked for compatibility. The
  padded dimensions must match exactly.
  """
  cases = [
      (0, [None]),
      (np.array([1, 2, 3]), [None, 4]),
      (np.array([[1, 2, 3]]), [None, 2, 4]),
  ]
  for value, dims in cases:
    expected_shape = tensor_shape.TensorShape(dims)
    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.from_tensors(value).padded_batch(
          42, padded_shapes=expected_shape[1:])
      get_next = dataset.make_one_shot_iterator().get_next()
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(get_next)
      mg = meta_graph.create_meta_graph_def(graph=g)
      op_properties = item.Item(mg).GetOpProperties()
      inferred_shape = self.as_tensor_shape(
          op_properties['IteratorGetNext'][0].shape)
      batch_dim = expected_shape.dims[0]
      self.assertTrue(batch_dim.is_compatible_with(inferred_shape[0]))
      self.assertEqual(expected_shape[1:], inferred_shape[1:])
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:28,代码来源:datasets_test.py
示例17: testVirtualCluster
def testVirtualCluster(self):
  """Measures costs and estimates performance on a virtual single-GPU cluster."""
  with ops.Graph().as_default() as g:
    total = (random_ops.random_uniform(shape=()) +
             random_ops.random_uniform(shape=()))
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
    grappler_item = item.Item(meta_graph.create_meta_graph_def(graph=g))

    gpu_device = device_properties_pb2.NamedDevice(
        name='/GPU:0',
        properties=device_properties_pb2.DeviceProperties(
            type='GPU',
            frequency=1000,
            num_cores=60,
            environment={'architecture': '7'}))
    grappler_cluster = cluster.Cluster(devices=[gpu_device])

    op_perfs, run_time, _ = grappler_cluster.MeasureCosts(grappler_item)
    self.assertGreater(run_time, 0)
    self.assertEqual(len(op_perfs), 15)
    self.assertEqual(7680.0,
                     grappler_cluster.EstimatePerformance(gpu_device))
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:25,代码来源:cluster_test.py
示例18: testSimpleSwap
def testSimpleSwap(self):
  """Check that the swap annotations are followed."""
  const_a = constant_op.constant(10, name='a')
  const_b = constant_op.constant(20, name='b')
  sum_c = math_ops.add_n([const_a, const_b], name='c')
  sum_d = math_ops.add_n([const_b, sum_c], name='d')
  ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(sum_d)
  # Annotate d with a `_swap_to_host` request for its input 0.
  sum_d.op.node_def.attr['_swap_to_host'].i = 0

  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  config = rewriter_config_pb2.RewriterConfig(
      memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
  optimized = tf_optimizer.OptimizeGraph(config, mg)

  self.assertEqual(len(optimized.node), 6)
  self.assertItemsEqual([node.name for node in optimized.node],
                        ['a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0'])

  nodes = {node.name: node for node in optimized.node}
  self.assertEqual('swap_out_d_0', nodes['swap_in_d_0'].input[0])
  self.assertEqual('^b', nodes['swap_in_d_0'].input[1])
  self.assertEqual('b', nodes['swap_out_d_0'].input[0])
  self.assertEqual('swap_in_d_0', nodes['d'].input[0])
  self.assertEqual('c', nodes['d'].input[1])
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:35,代码来源:memory_optimizer_test.py
示例19: record_summaries_every_n_global_steps
def record_summaries_every_n_global_steps(n):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  Yields once. While suspended at the yield, the
  `_SHOULD_RECORD_SUMMARIES_NAME` collection holds the single predicate
  `equal(global_step % n, 0)`, built under `ops.device("cpu:0")`. Its previous
  contents are restored when the generator resumes or is closed, including
  when an exception is thrown in at the yield.

  Args:
    n: recording period, in global steps.
  """
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  try:
    with ops.device("cpu:0"):
      collection_ref[:] = [
          math_ops.equal(training_util.get_global_step() % n, 0)]
    yield
  finally:
    # Restore unconditionally so the override cannot leak past the block.
    collection_ref[:] = old
开发者ID:dyoung418,项目名称:tensorflow,代码行数:8,代码来源:summary_ops.py
示例20: always_record_summaries
def always_record_summaries():
  """Sets the should_record_summaries Tensor to always true.

  Yields once. While suspended at the yield, the
  `_SHOULD_RECORD_SUMMARIES_NAME` collection holds `[True]`. The saved
  contents are written back in a `finally`, so they are restored even when an
  exception is thrown in at the yield.
  """
  summary_flags = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  saved_flags = list(summary_flags)
  try:
    summary_flags[:] = [True]
    yield
  finally:
    summary_flags[:] = saved_flags
开发者ID:Jackiefan,项目名称:tensorflow,代码行数:9,代码来源:summary_ops_v2.py
注:本文中的tensorflow.python.framework.ops.get_collection_ref函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论