本文整理汇总了Python中tensorflow.python.ops.array_ops.squeeze函数的典型用法代码示例。如果您正苦于以下问题：Python squeeze函数的具体用法？Python squeeze怎么用？Python squeeze使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了squeeze函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: GetParams
def GetParams(self):
  """Build a graph that TF-TRT should split into two engines.

  The op chain is add -> mul -> add1 -> incompatible1 -> add2 -> mul1 -> add3.
  `self.trt_incompatible_op` sits in the middle, so the convertible ops on
  either side of it are expected to form two separate TRT segments.
  """
  in_name = "input"
  in_shape = [2, 32, 32, 3]
  graph = ops.Graph()
  with graph.as_default():
    net = array_ops.placeholder(
        dtype=dtypes.float32, shape=in_shape, name=in_name)
    with graph.device("/GPU:0"):
      one = constant_op.constant(1.0, name="c")
      # First convertible run.
      net = math_ops.add(net, one, name="add")
      net = math_ops.mul(net, net, name="mul")
      net = math_ops.add(net, net, name="add1")
      # Breaks the chain for TF-TRT.
      net = self.trt_incompatible_op(net, name="incompatible1")
      # Second convertible run.
      net = math_ops.add(net, one, name="add2")
      net = math_ops.mul(net, net, name="mul1")
      net = math_ops.add(net, net, name="add3")
    array_ops.squeeze(net, name=self.output_name)
  engines = {
      "my_trt_op_0": ["add2", "add3", "mul1"],
      # The segment ["add", "add1", "mul"] receives id 1 rather than 0: its
      # parent is really the const node 'c', which is later removed because a
      # const output of a segment is not allowed.
      "my_trt_op_1": ["add", "add1", "mul"]
  }
  return trt_test.TfTrtIntegrationTestParams(
      gdef=graph.as_graph_def(),
      input_names=[in_name],
      input_dims=[in_shape],
      expected_engines=engines,
      expected_output_dims=tuple(in_shape),
      allclose_atol=1.e-06,
      allclose_rtol=1.e-06)
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:34,代码来源:base_test.py
示例2: GetParams
def GetParams(self):
  """Create a graph containing multiple segments.

  Two chains of add/mul ops are separated by `self.trt_incompatible_op`
  ("incompatible1"). Control dependencies on the constant `c` and on a second
  incompatible op `d` keep `c` from being contracted into any TRT segment
  that depends on it.
  """
  input_name = "input"
  input_dims = [2, 32, 32, 3]
  output_name = "output"
  g = ops.Graph()
  with g.as_default():
    inp = array_ops.placeholder(
        dtype=dtypes.float32, shape=input_dims, name=input_name)
    with g.device("/GPU:0"):
      n = inp
      c = constant_op.constant(1.0, name="c")
      # Adds control dependency from the constant op to a trt incompatible op,
      # and adds control dependency from the trt incompatible op to all other
      # ops, to make sure the constant op cannot be contracted with any trt
      # segment that depends on it.
      with g.control_dependencies([c]):
        d = self.trt_incompatible_op(n, name="incompatible")
      # First convertible chain (add -> mul -> add1), gated on `d`.
      with g.control_dependencies([d]):
        n = math_ops.add(n, c, name="add")
        n = math_ops.mul(n, n, name="mul")
        n = math_ops.add(n, n, name="add1")
      # Unconvertible op separating the two chains.
      n = self.trt_incompatible_op(n, name="incompatible1")
      # Second convertible chain (add2 -> mul1 -> add3), also gated on `d`.
      with g.control_dependencies([d]):
        n = math_ops.add(n, c, name="add2")
        n = math_ops.mul(n, n, name="mul1")
        n = math_ops.add(n, n, name="add3")
    array_ops.squeeze(n, name=output_name)
  return trt_test.TfTrtIntegrationTestParams(
      gdef=g.as_graph_def(),
      input_names=[input_name],
      input_dims=[input_dims],
      output_names=[output_name],
      # All ops are elementwise and no input dim is 1, so squeeze keeps the
      # input shape.
      expected_output_dims=[tuple(input_dims)])
开发者ID:aeverall,项目名称:tensorflow,代码行数:34,代码来源:base_test.py
示例3: _test_squeeze
def _test_squeeze(data, squeeze_dims=None):
""" One iteration of squeeze """
if squeeze_dims is None:
squeeze_dims = []
# see relay/frontend/tflite.py convert_squeeze more detail of channel first rule
if len(data.shape) == 1 or len(data.shape) == 2:
tvm_data = data
elif len(data.shape) == 3:
tvm_data = np.transpose(data, axes=(0, 2, 1))
elif len(data.shape) == 4:
tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
else:
raise NotImplementedError("Not support input shape {} of reshape : ".
format(str(len(data.shape))))
tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
if squeeze_dims:
out = array_ops.squeeze(in_data, squeeze_dims)
else:
out = array_ops.squeeze(in_data)
compare_tflite_with_tvm(data, tvm_data, 'Placeholder:0', [in_data], [out])
开发者ID:bddppq,项目名称:tvm,代码行数:28,代码来源:test_forward.py
示例4: call
def call(self, inputs):
  """Apply 1D pooling by temporarily lifting the input to 4D.

  TF has no 1D pooling op, so a singleton spatial axis is inserted
  (NWC -> NHWC or NCW -> NCHW), `self.pool_function` pools over the width
  axis only, and the inserted axis is squeezed away again.
  """
  channels_last = self.data_format == 'channels_last'
  # Position of the inserted singleton "H" axis.
  dummy_axis = 1 if channels_last else 2
  lifted = array_ops.expand_dims(inputs, dummy_axis)
  if channels_last:
    window = (1, 1) + self.pool_size + (1,)
    step = (1, 1) + self.strides + (1,)
    layout = 'NHWC'
  else:
    window = (1, 1, 1) + self.pool_size
    step = (1, 1, 1) + self.strides
    layout = 'NCHW'
  pooled = self.pool_function(
      lifted,
      ksize=window,
      strides=step,
      padding=self.padding.upper(),
      data_format=layout)
  return array_ops.squeeze(pooled, dummy_axis)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:28,代码来源:pooling.py
示例5: _statistics
def _statistics(x, axes):
  """Return the mean and the mean of squares of `x` over `axes`.

  Adapted from `tf.nn.moments`: the mean is accumulated relative to a
  gradient-stopped shift, then the shift is added back.

  Args:
    x: A `Tensor`.
    axes: Array of ints. Axes along which to compute mean and
      variance.

  Returns:
    Two `Tensor` objects, `mean` and `square mean`, with `axes` squeezed out.
    For float16 input both are cast back to float16.
  """
  is_half = x.dtype == dtypes.float16
  # float16 lacks the dynamic range to accumulate these statistics, so the
  # reductions run in float32 and the results are cast back at the end.
  y = math_ops.cast(x, dtypes.float32) if is_half else x

  def _reduce(t):
    # keepdims=True so `shift` broadcasts against `y`.
    return math_ops.reduce_mean(t, axes, keepdims=True)

  shift = array_ops.stop_gradient(_reduce(y))
  mean = _reduce(y - shift) + shift
  mean_squared = _reduce(math_ops.square(y))
  # NOTE(review): squeeze with an empty `axes` would drop every size-1 dim;
  # presumably callers always pass at least one axis -- confirm.
  mean = array_ops.squeeze(mean, axes)
  mean_squared = array_ops.squeeze(mean_squared, axes)
  if not is_half:
    return (mean, mean_squared)
  return (math_ops.cast(mean, dtypes.float16),
          math_ops.cast(mean_squared, dtypes.float16))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:32,代码来源:virtual_batchnorm_impl.py
示例6: GetParams
def GetParams(self):
  """Check TF-TRT on a graph that combines a rank-2 and a rank-4 input.

  Each input feeds its own add/abs chain that ends in a reciprocal. The
  rank-2 branch is expanded twice on the last axis (12x5 -> 12x5x1x1) so it
  broadcasts against the rank-4 branch when the two are added.
  """
  input_names = ["input", "input2"]
  # Branch 0 is rank 2, branch 1 is rank 4.
  input_dims = [[12, 5], [12, 5, 2, 2]]
  output_name = "output"
  graph = ops.Graph()
  with graph.as_default():
    branches = []
    for idx, (in_name, in_shape) in enumerate(zip(input_names, input_dims)):
      node = array_ops.placeholder(
          dtype=dtypes.float32, shape=in_shape, name=in_name)
      const = constant_op.constant(1.0, name="c%d_1" % idx)
      node = math_ops.add(node, const, name="add%d_1" % idx)
      node = math_ops.abs(node, name="abs%d_1" % idx)
      const = constant_op.constant(2.2, name="c%d_2" % idx)
      node = math_ops.add(node, const, name="add%d_2" % idx)
      node = math_ops.abs(node, name="abs%d_2" % idx)
      const = constant_op.constant(3.0, name="c%d_3" % idx)
      node = math_ops.add(node, const, name="add%d_3" % idx)
      if idx == 0:
        # Lift the rank-2 branch to rank 4.
        for j in range(2):
          node = array_ops.expand_dims(node, -1, name="expand%d_%d" % (idx, j))
      node = gen_math_ops.reciprocal(node, name="reciprocal%d" % idx)
      branches.append(node)
    # Merge both branches.
    merged = math_ops.add(branches[0], branches[1], name="add")
    array_ops.squeeze(merged, name=output_name)
  return trt_test.TfTrtIntegrationTestParams(
      gdef=graph.as_graph_def(),
      input_names=input_names,
      input_dims=input_dims,
      output_names=[output_name],
      expected_output_dims=[tuple(input_dims[1])])
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:34,代码来源:rank_two_test.py
示例7: get_simple_graph_def
def get_simple_graph_def(self):
  """Return the GraphDef of a small conv -> bias -> relu -> max-pool graph.

  The input is a (None, 24, 24, 2) float32 placeholder named "input"; the
  graph ends in a squeeze named "output".
  """
  graph = ops.Graph()
  with graph.as_default():
    images = aops.placeholder(
        dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input")
    kernel = cop.constant(
        [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
        name="weights",
        dtype=dtypes.float32)
    convolved = nn.conv2d(
        input=images, filter=kernel, strides=[1, 2, 2, 1], padding="SAME",
        name="conv")
    bias = cop.constant(
        [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32)
    activated = nn.relu(nn.bias_add(convolved, bias, name="biasAdd"), "relu")
    passthrough = aops.identity(activated, "ID")
    pooled = nn_ops.max_pool(
        passthrough, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
    aops.squeeze(pooled, name="output")
  return graph.as_graph_def()
开发者ID:ebrevdo,项目名称:tensorflow,代码行数:25,代码来源:tf_trt_integration_test.py
示例8: testSqueezeMatrix
def testSqueezeMatrix(self):
  """squeeze on a 1x3 matrix: axis 0 is removable, axis 1 is not."""
  matrix = [[1, 2, 3]]
  matrix_squeezed = array_ops.squeeze(matrix, [0])
  # `(3,)` is the one-element shape tuple; a bare `(3)` is just the int 3.
  self.assertEqual(matrix_squeezed.get_shape(), (3,))
  # Axis 1 has size 3, so squeezing it must raise.
  with self.assertRaises(ValueError):
    array_ops.squeeze(matrix, [1])
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:7,代码来源:array_ops_test.py
示例9: GetParams
def GetParams(self):
  """Check TF-TRT engine wiring around a neighbouring unconvertible op.

  An NCHW conv feeds both a convertible `mul` and `self.trt_incompatible_op`;
  the two results are recombined with `sub`.
  """
  float_type = dtypes.float32
  in_name = "input"
  in_shape = [2, 3, 7, 5]
  out_name = "output"
  graph = ops.Graph()
  with graph.as_default():
    images = array_ops.placeholder(
        dtype=float_type, shape=in_shape, name=in_name)
    kernel = constant_op.constant(
        np.random.normal(.3, 0.05, [3, 2, 3, 4]),
        name="weights",
        dtype=float_type)
    convolved = nn.conv2d(
        input=images,
        filter=kernel,
        data_format="NCHW",
        strides=[1, 1, 1, 1],
        padding="VALID",
        name="conv")
    scale = constant_op.constant(
        np.random.normal(1.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=float_type)
    scaled = math_ops.mul(convolved, scale, name="mul")
    blocker = self.trt_incompatible_op(convolved, name="incompatible")
    difference = math_ops.sub(scaled, blocker, name="sub")
    array_ops.squeeze(difference, name=out_name)
  return trt_test.TfTrtIntegrationTestParams(
      gdef=graph.as_graph_def(),
      input_names=[in_name],
      input_dims=[in_shape],
      output_names=[out_name],
      expected_output_dims=[(2, 4, 5, 4)])
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:30,代码来源:neighboring_engine_test.py
示例10: test_virtual_statistics
def test_virtual_statistics(self):
  """Check that `_virtual_statistics` gives same result as `nn.moments`.

  The statistics of `single_example` against the reference `partial_batch`
  are compared with `nn.moments` over the concatenation of both, for each
  choice of non-batch reduction axis.
  """
  # Fix the graph-level seed so the random tensors below are reproducible.
  random_seed.set_random_seed(1234)
  batch_axis = 0
  partial_batch = random_ops.random_normal([4, 5, 7, 3])
  single_example = random_ops.random_normal([1, 5, 7, 3])
  # Reference batch plus the extra example, stacked on the batch axis.
  full_batch = array_ops.concat([partial_batch, single_example], axis=0)
  for reduction_axis in range(1, 4):
    # Get `nn.moments` on the full batch.
    reduction_axes = list(range(4))
    del reduction_axes[reduction_axis]
    mom_mean, mom_variance = nn.moments(full_batch, reduction_axes)

    # Get virtual batch statistics.
    vb_reduction_axes = list(range(4))
    del vb_reduction_axes[reduction_axis]
    # reduction_axis >= 1 was removed first, so index 0 is still batch_axis.
    del vb_reduction_axes[batch_axis]
    vbn = virtual_batchnorm.VBN(partial_batch, reduction_axis)
    vb_mean, mean_sq = vbn._virtual_statistics(
        single_example, vb_reduction_axes)
    # Variance = E[x^2] - E[x]^2.
    vb_variance = mean_sq - math_ops.square(vb_mean)
    # Remove singleton batch dim for easy comparisons.
    vb_mean = array_ops.squeeze(vb_mean, batch_axis)
    vb_variance = array_ops.squeeze(vb_variance, batch_axis)

    with self.cached_session(use_gpu=True) as sess:
      vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([
          vb_mean, vb_variance, mom_mean, mom_variance])

    self.assertAllClose(mom_mean_np, vb_mean_np)
    self.assertAllClose(mom_var_np, vb_var_np)
开发者ID:Albert-Z-Guo,项目名称:tensorflow,代码行数:33,代码来源:virtual_batchnorm_test.py
示例11: GetParams
def GetParams(self):
  """Check TF-TRT engine wiring when a conv output fans out to `tan`.

  The conv result is scaled by a constant and, separately, passed through
  `gen_math_ops.tan`; the difference of the two is the output. Two engines
  are expected (NOTE(review): presumably because `tan` is not convertible in
  the targeted TF-TRT version -- confirm).
  """
  float_type = dtypes.float32
  in_name = "input"
  in_shape = [2, 3, 7, 5]
  graph = ops.Graph()
  with graph.as_default():
    images = array_ops.placeholder(
        dtype=float_type, shape=in_shape, name=in_name)
    kernel = constant_op.constant(
        np.random.normal(.3, 0.05, [3, 2, 3, 4]),
        name="weights",
        dtype=float_type)
    convolved = nn.conv2d(
        input=images,
        filter=kernel,
        data_format="NCHW",
        strides=[1, 1, 1, 1],
        padding="VALID",
        name="conv")
    scale = constant_op.constant(
        np.random.normal(1.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=float_type)
    scaled = convolved * scale
    tangent = gen_math_ops.tan(convolved)
    result = scaled - tangent
    array_ops.squeeze(result, name=self.output_name)
  return trt_test.TfTrtIntegrationTestParams(
      gdef=graph.as_graph_def(),
      input_names=[in_name],
      input_dims=[in_shape],
      num_expected_engines=2,
      expected_output_dims=(2, 4, 5, 4),
      allclose_atol=1.e-03,
      allclose_rtol=1.e-03)
开发者ID:StephenOman,项目名称:tensorflow,代码行数:31,代码来源:neighboring_engine_test.py
示例12: GetParams
def GetParams(self):
  """Run a single VGG-style block through TF-TRT conversion.

  Pipeline: fused batch norm -> 1x1 conv (stride 2, SAME) -> bias_add ->
  relu -> identity -> 2x2 VALID max pool, giving (5, 2, 2, 6) from a
  (5, 8, 8, 2) input.
  """
  float_type = dtypes.float32
  in_name = "input"
  in_shape = [5, 8, 8, 2]
  out_name = "output"
  graph = ops.Graph()
  with graph.as_default():
    images = array_ops.placeholder(
        dtype=float_type, shape=in_shape, name=in_name)
    normalized, _, _ = nn_impl.fused_batch_norm(
        images, [1.0, 1.0], [0.0, 0.0],
        mean=[0.5, 0.5],
        variance=[1.0, 1.0],
        is_training=False)
    kernel = constant_op.constant(
        np.random.randn(1, 1, 2, 6), name="weights", dtype=float_type)
    convolved = nn.conv2d(
        input=normalized,
        filter=kernel,
        strides=[1, 2, 2, 1],
        padding="SAME",
        name="conv")
    bias = constant_op.constant(
        np.random.randn(6), name="bias", dtype=float_type)
    activated = nn.relu(nn.bias_add(convolved, bias, name="biasAdd"), "relu")
    passthrough = array_ops.identity(activated, "ID")
    pooled = nn_ops.max_pool(
        passthrough, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
    array_ops.squeeze(pooled, name=out_name)
  return trt_test.TfTrtIntegrationTestParams(
      gdef=graph.as_graph_def(),
      input_names=[in_name],
      input_dims=[in_shape],
      output_names=[out_name],
      expected_output_dims=[(5, 2, 2, 6)])
开发者ID:aeverall,项目名称:tensorflow,代码行数:31,代码来源:vgg_block_test.py
示例13: GetMultiEngineGraphDef
def GetMultiEngineGraphDef(dtype=dtypes.float32):
  """Return the GraphDef of a graph intended to yield several TRT segments.

  A strided conv output fans out into two elementwise branches that are tied
  together through `math_ops.sin` (NOTE(review): presumably chosen because it
  splits the graph for TF-TRT -- confirm). The final tensor is squeezed into
  OUTPUT_NAME.
  """
  graph = ops.Graph()
  with graph.as_default():
    images = array_ops.placeholder(
        dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME)
    with graph.device("/GPU:0"):
      kernel = constant_op.constant(
          [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
          name="weights",
          dtype=dtype)
      convolved = nn.conv2d(
          input=images,
          filter=kernel,
          strides=[1, 2, 2, 1],
          padding="SAME",
          name="conv")
      # NOTE(review): the 12x12x6 constants assume the conv output has that
      # spatial/channel shape for INPUT_DIMS -- confirm.
      factor = constant_op.constant(
          np.random.randn(INPUT_DIMS[0], 12, 12, 6), dtype=dtype)
      left = convolved * factor
      divisor = constant_op.constant(
          np.random.randn(INPUT_DIMS[0], 12, 12, 6), dtype=dtype)
      right = convolved / divisor
      edge = math_ops.sin(right)
      edge = edge / edge
      doubled = edge + edge
      left = left - edge
      right = right * edge
      total = left + right
      total = total - doubled
    array_ops.squeeze(total, name=OUTPUT_NAME)
  return graph.as_graph_def()
开发者ID:Eagle732,项目名称:tensorflow,代码行数:34,代码来源:tf_trt_integration_test.py
示例14: GetParams
def GetParams(self):
  """Test TF-TRT conversion of two back-to-back 1x1 convolutions.

  The input has 3 channels, and the convolutions map 3 -> 5 -> 10 channels
  with VALID padding and stride 1, so the spatial size (15x15) is preserved.
  NOTE(review): the odd 5-channel intermediate presumably exercises memory
  alignment of engine buffers -- confirm intent.
  """
  dtype = dtypes.float32
  input_name = "input"
  input_dims = [2, 15, 15, 3]
  g = ops.Graph()
  with g.as_default():
    # Batch dimension left dynamic.
    inp = array_ops.placeholder(
        dtype=dtype, shape=[None] + input_dims[1:], name=input_name)
    with g.device("/GPU:0"):
      e1 = constant_op.constant(
          np.random.randn(1, 1, 3, 5), name="kernel_1", dtype=dtype)
      e2 = constant_op.constant(
          np.random.randn(1, 1, 5, 10), name="kernel_2", dtype=dtype)
      conv = nn.conv2d(
          input=inp,
          filter=e1,
          strides=[1, 1, 1, 1],
          padding="VALID",
          name="conv")
      out = nn.conv2d(
          input=conv,
          filter=e2,
          strides=[1, 1, 1, 1],
          padding="VALID",
          name="conv_2")
    array_ops.squeeze(out, name=self.output_name)
  return trt_test.TfTrtIntegrationTestParams(
      gdef=g.as_graph_def(),
      input_names=[input_name],
      input_dims=[input_dims],
      expected_engines=["my_trt_op_0"],
      expected_output_dims=(2, 15, 15, 10),
      allclose_atol=1.e-02,
      allclose_rtol=1.e-02)
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:35,代码来源:memory_alignment_test.py
示例15: GetSingleEngineGraphDef
def GetSingleEngineGraphDef(dtype=dtypes.float32):
  """Return the GraphDef of a graph intended to form one TRT segment.

  conv -> bias_add -> relu -> identity -> max_pool, all on /GPU:0, followed
  by a squeeze named OUTPUT_NAME.
  """
  graph = ops.Graph()
  with graph.as_default():
    images = array_ops.placeholder(
        dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME)
    with graph.device("/GPU:0"):
      kernel = constant_op.constant(
          [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
          name="weights",
          dtype=dtype)
      convolved = nn.conv2d(
          input=images,
          filter=kernel,
          strides=[1, 2, 2, 1],
          padding="SAME",
          name="conv")
      offsets = constant_op.constant(
          [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype)
      activated = nn.relu(
          nn.bias_add(convolved, offsets, name="bias_add"), "relu")
      passthrough = array_ops.identity(activated, "identity")
      pooled = nn_ops.max_pool(
          passthrough, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
    array_ops.squeeze(pooled, name=OUTPUT_NAME)
  return graph.as_graph_def()
开发者ID:Eagle732,项目名称:tensorflow,代码行数:26,代码来源:tf_trt_integration_test.py
示例16: average_impurity
def average_impurity(self):
  """Build ops evaluating the average leaf impurity of the tree.

  Leaves are the rows of `self.variables.tree` whose first column equals
  `constants.LEAF_NODE`; their `node_sums` rows are passed to
  `self._weighted_gini`. The docstring contract is leaf variance in
  regression mode and gini impurity in classification mode (NOTE(review):
  this body always calls `_weighted_gini`; presumably regression overrides
  it -- confirm).

  Returns:
    The last op in the graph.
  """
  first_column = array_ops.slice(self.variables.tree, [0, 0], [-1, 1])
  children = array_ops.squeeze(first_column, squeeze_dims=[1])
  is_leaf = math_ops.equal(constants.LEAF_NODE, children)
  leaf_ids = math_ops.to_int32(
      array_ops.squeeze(array_ops.where(is_leaf), squeeze_dims=[1]))
  counts = array_ops.gather(self.variables.node_sums, leaf_ids)
  gini = self._weighted_gini(counts)
  has_leaves = math_ops.greater(array_ops.shape(leaf_ids)[0], 0)
  # Early on (e.g. step 1) there may be no leaves. Because this value can be
  # used as a loss, report a huge impurity then so the loss only decreases.
  return control_flow_ops.cond(
      has_leaves,
      lambda: gini,
      lambda: array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000.)
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:25,代码来源:tensor_forest.py
示例17: testSqueezeMatrix
def testSqueezeMatrix(self):
  """squeeze on a 1x3 matrix: axis 0 is removable, axis 1 is not."""
  matrix = [[1, 2, 3]]
  matrix_squeezed = array_ops.squeeze(matrix, [0])
  # `(3,)` is the one-element shape tuple; a bare `(3)` is just the int 3.
  self.assertEqual(matrix_squeezed.get_shape(), (3,))
  # Axis 1 has size 3, so squeezing it must fail with this message.
  with self.assertRaisesRegexp(
      Exception, "Can not squeeze dim.1., expected a dimension of 1, got 3"):
    array_ops.squeeze(matrix, [1])
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:8,代码来源:array_ops_test.py
示例18: _recall_at_threshold
def _recall_at_threshold(labels, predictions, weights, threshold, name=None):
  """Return recall at a single `threshold` together with its update op.

  Wraps `metrics_lib.recall_at_thresholds` with a one-element threshold
  tuple, which yields one value per threshold; `squeeze` then removes that
  length-1 dimension.

  Args:
    labels: Ground-truth labels forwarded to `recall_at_thresholds`.
    predictions: Predictions forwarded to `recall_at_thresholds`.
    weights: Optional weights forwarded to `recall_at_thresholds`.
    threshold: The single threshold to evaluate.
    name: Optional name scope; defaults to 'recall_at_<threshold>'.

  Returns:
    A `(recall, update_op)` pair, both squeezed.
  """
  with ops.name_scope(
      name, 'recall_at_%s' % threshold,
      (predictions, labels, weights, threshold)) as scope:
    recall_tensor, update_op = metrics_lib.recall_at_thresholds(
        labels=labels, predictions=predictions, thresholds=(threshold,),
        weights=weights, name=scope)
    return array_ops.squeeze(recall_tensor), array_ops.squeeze(update_op)
开发者ID:vaccine,项目名称:tensorflow,代码行数:8,代码来源:head.py
示例19: remove_squeezable_dimensions
def remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=0, name=None):
  """Squeeze last dim if ranks differ from expected by exactly 1.

  In the common case where we expect shapes to match, `expected_rank_diff`
  defaults to 0, and we squeeze the last dimension of the larger rank if they
  differ by 1.

  But, for example, if `labels` contains class IDs and `predictions` contains 1
  probability per class, we expect `predictions` to have 1 more dimension than
  `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
  `labels` if `rank(predictions) - rank(labels) == 0`, and
  `predictions` if `rank(predictions) - rank(labels) == 2`.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
    Tuple of `labels` and `predictions`, possibly with last dim squeezed.
  """
  with ops.name_scope(name, 'remove_squeezable_dimensions',
                      [labels, predictions]):
    predictions = ops.convert_to_tensor(predictions)
    labels = ops.convert_to_tensor(labels)
    predictions_shape = predictions.get_shape()
    predictions_rank = predictions_shape.ndims
    labels_shape = labels.get_shape()
    labels_rank = labels_shape.ndims
    if (labels_rank is not None) and (predictions_rank is not None):
      # Use static rank.
      rank_diff = predictions_rank - labels_rank
      if rank_diff == expected_rank_diff + 1:
        predictions = array_ops.squeeze(predictions, [-1])
      elif rank_diff == expected_rank_diff - 1:
        labels = array_ops.squeeze(labels, [-1])
      return labels, predictions

    # Use dynamic rank. Both `cond` branches are built at graph-construction
    # time, so the squeezing branch is only added when the last dim could be
    # 1. A statically scalar tensor has no last dim at all; it is skipped
    # (indexing `dims[-1]` on it would raise IndexError).
    rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
    if (predictions_rank is None) or (
        predictions_rank > 0 and
        predictions_shape.dims[-1].is_compatible_with(1)):
      predictions = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff + 1, rank_diff),
          lambda: array_ops.squeeze(predictions, [-1]),
          lambda: predictions)
    if (labels_rank is None) or (
        labels_rank > 0 and
        labels_shape.dims[-1].is_compatible_with(1)):
      labels = control_flow_ops.cond(
          math_ops.equal(expected_rank_diff - 1, rank_diff),
          lambda: array_ops.squeeze(labels, [-1]),
          lambda: labels)
    return labels, predictions
开发者ID:aritratony,项目名称:tensorflow,代码行数:58,代码来源:confusion_matrix.py
示例20: crf_decode
def crf_decode(potentials, transition_params, sequence_length):
  """Decode the highest scoring sequence of tags in TensorFlow.

  Graph-mode implementation operating on `Tensor` inputs: a forward RNN pass
  computes scores and backpointers, and a backward RNN pass follows the
  backpointers to recover the tags.

  Args:
    potentials: A [batch_size, max_seq_len, num_tags] tensor of
        unary potentials.
    transition_params: A [num_tags, num_tags] matrix of
        binary potentials.
    sequence_length: A [batch_size] vector of true sequence lengths.

  Returns:
    decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
                Contains the highest scoring tag indices.
    best_score: A [batch_size] vector, containing the score of `decode_tags`.
  """
  # For simplicity, in shape comments, denote:
  # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
  # NOTE(review): `.value` is None when the tag dimension is not statically
  # known -- confirm callers always provide a static num_tags.
  num_tags = potentials.get_shape()[2].value

  # Computes forward decoding. Get last score and backpointers.
  crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
  # The first time step's potentials seed the forward recursion.
  initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
  initial_state = array_ops.squeeze(initial_state, axis=[1])  # [B, O]
  # The remaining T-1 steps are the RNN inputs.
  # NOTE(review): when max_seq_len == 1 this slice has zero time steps, and a
  # sequence_length of 0 makes `sequence_length - 1` negative -- confirm both
  # cases are handled or excluded by callers.
  inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])  # [B, T-1, O]
  backpointers, last_score = rnn.dynamic_rnn(
      crf_fwd_cell,
      inputs=inputs,
      sequence_length=sequence_length - 1,
      initial_state=initial_state,
      time_major=False,
      dtype=dtypes.int32)  # [B, T - 1, O], [B, O]
  # Reverse each sequence within its valid length so the backward pass walks
  # the backpointers from the last valid step towards the first.
  backpointers = gen_array_ops.reverse_sequence(
      backpointers, sequence_length - 1, seq_dim=1)  # [B, T-1, O]

  # Computes backward decoding. Extract tag indices from backpointers.
  crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
  # Start from the argmax tag of the last forward score.
  initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),
                                dtype=dtypes.int32)  # [B]
  initial_state = array_ops.expand_dims(initial_state, axis=-1)  # [B, 1]
  decode_tags, _ = rnn.dynamic_rnn(
      crf_bwd_cell,
      inputs=backpointers,
      sequence_length=sequence_length - 1,
      initial_state=initial_state,
      time_major=False,
      dtype=dtypes.int32)  # [B, T - 1, 1]
  decode_tags = array_ops.squeeze(decode_tags, axis=[2])  # [B, T - 1]
  # Prepend the starting tag; tags are still in reversed order here.
  decode_tags = array_ops.concat([initial_state, decode_tags], axis=1)  # [B, T]
  # Undo the reversal within each sequence's valid length.
  decode_tags = gen_array_ops.reverse_sequence(
      decode_tags, sequence_length, seq_dim=1)  # [B, T]

  best_score = math_ops.reduce_max(last_score, axis=1)  # [B]
  return decode_tags, best_score
开发者ID:SylChan,项目名称:tensorflow,代码行数:55,代码来源:crf.py
注:本文中的tensorflow.python.ops.array_ops.squeeze函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。 |
请发表评论