This article collects typical usage examples of the Python function tensorflow.python.ops.check_ops.assert_equal. If you are wondering what assert_equal does, how to call it, or what real uses of it look like, the curated code samples below should help.
The following shows 20 code examples of the assert_equal function, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code samples.
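Before the examples, here is a minimal, hedged sketch of typical usage (TF 1.x-style graph API; tensor values are illustrative). `check_ops.assert_equal` is also exposed publicly as `tf.assert_equal` / `tf.debugging.assert_equal`: when both inputs are statically known and unequal it raises `ValueError` at graph-construction time, otherwise it returns an op that raises `InvalidArgumentError` when the graph runs.

import tensorflow as tf
from tensorflow.python.ops import check_ops

# Dynamic check: values are only known when the graph runs.
x = tf.placeholder(tf.int32, shape=[None], name="x")
y = tf.placeholder(tf.int32, shape=[None], name="y")
assert_op = check_ops.assert_equal(x, y, message="x and y must match")
with tf.control_dependencies([assert_op]):
  out = tf.identity(x)

with tf.Session() as sess:
  print(sess.run(out, feed_dict={x: [1, 2], y: [1, 2]}))  # passes: [1 2]
  # sess.run(out, feed_dict={x: [1, 2], y: [3, 4]})  # raises InvalidArgumentError

# Static check: constants are compared at graph-construction time.
# check_ops.assert_equal(tf.constant([1]), tf.constant([2]))  # raises ValueError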
Example 1: call
def call(self, labels, predictions, weights=None):
  """Accumulate accuracy statistics.

  For example, if labels is [1, 2, 3, 4] and predictions is [0, 2, 3, 4]
  then the accuracy is 3/4 or .75. If the weights were specified as
  [1, 1, 0, 0] then the accuracy would be 1/2 or .5.

  `labels` and `predictions` should have the same shape and type.

  Args:
    labels: Tensor with the true labels for each example. One example
      per element of the Tensor.
    predictions: Tensor with the predicted label for each example.
    weights: Optional weighting of each example. Defaults to 1.

  Returns:
    The arguments, for easy chaining.
  """
  check_ops.assert_equal(
      array_ops.shape(labels), array_ops.shape(predictions),
      message="Shapes of labels and predictions are unequal")
  matches = math_ops.equal(labels, predictions)
  matches = math_ops.cast(matches, self.dtype)
  super(Accuracy, self).call(matches, weights=weights)
  if weights is None:
    return labels, predictions
  return labels, predictions, weights
Author: JonathanRaiman, Project: tensorflow, Lines: 27, Source: metrics_impl.py
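One thing worth noting about the snippet above: the op returned by `check_ops.assert_equal` is not wired in as a control dependency. That is fine for the eager-style metrics this snippet appears to come from (the check raises immediately when executing eagerly), but in graph mode an assertion only fires if something depends on it. A hedged sketch of a graph-mode variant of the same shape check (names and values are illustrative, not part of metrics_impl.py):

import tensorflow as tf

labels = tf.placeholder(tf.int64, shape=[None], name="labels")
predictions = tf.placeholder(tf.int64, shape=[None], name="predictions")
shape_check = tf.assert_equal(
    tf.shape(labels), tf.shape(predictions),
    message="Shapes of labels and predictions are unequal")
with tf.control_dependencies([shape_check]):
  matches = tf.cast(tf.equal(labels, predictions), tf.float32)
accuracy = tf.reduce_mean(matches)  # evaluating this also runs the shape check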
Example 2: _check_shapes_dynamic
def _check_shapes_dynamic(self, operator, v, diag):
  """Return (v, diag) with Assert dependencies, which check shape."""
  checks = []
  with ops.op_scope([operator, v, diag], 'check_shapes'):
    s_v = array_ops.shape(v)
    r_op = operator.rank()
    r_v = array_ops.rank(v)
    if diag is not None:
      s_d = array_ops.shape(diag)
      r_d = array_ops.rank(diag)

    # Check tensor rank.
    checks.append(check_ops.assert_rank(v, r_op))
    if diag is not None:
      checks.append(check_ops.assert_rank(diag, r_op - 1))

    # Check batch shape
    checks.append(check_ops.assert_equal(
        operator.batch_shape(), array_ops.slice(s_v, [0], [r_v - 2])))
    if diag is not None:
      checks.append(check_ops.assert_equal(
          operator.batch_shape(), array_ops.slice(s_d, [0], [r_d - 1])))

    # Check event shape
    checks.append(check_ops.assert_equal(
        operator.vector_space_dimension(), array_ops.gather(s_v, r_v - 2)))
    if diag is not None:
      checks.append(check_ops.assert_equal(
          array_ops.gather(s_v, r_v - 1), array_ops.gather(s_d, r_d - 1)))

    v = control_flow_ops.with_dependencies(checks, v)
    if diag is not None:
      diag = control_flow_ops.with_dependencies(checks, diag)
    return v, diag
Author: 10imaging, Project: tensorflow, Lines: 34, Source: operator_pd_vdvt_update.py
Example 3: _check_mu
def _check_mu(self, mu):
  """Return `mu` after validity checks and possibly with assertations."""
  mu = ops.convert_to_tensor(mu)
  cov = self._cov

  if mu.dtype != cov.dtype:
    raise TypeError(
        "mu and cov must have the same dtype. Found mu.dtype = %s, "
        "cov.dtype = %s"
        % (mu.dtype, cov.dtype))

  if not self.strict:
    return mu
  else:
    assert_compatible_shapes = control_flow_ops.group(
        check_ops.assert_equal(
            array_ops.rank(mu) + 1,
            cov.rank(),
            data=["mu should have rank 1 less than cov. Found: rank(mu) = ",
                  array_ops.rank(mu), " rank(cov) = ", cov.rank()],
        ),
        check_ops.assert_equal(
            array_ops.shape(mu),
            cov.vector_shape(),
            data=["mu.shape and cov.shape[:-1] should match. "
                  "Found: shape(mu) = ",
                  array_ops.shape(mu), " shape(cov) = ", cov.shape()],
        ),
    )
    return control_flow_ops.with_dependencies([assert_compatible_shapes], mu)
Author: 363158858, Project: tensorflow, Lines: 29, Source: mvn.py
Example 4: _get_sparse_tensors
def _get_sparse_tensors(self, inputs, weight_collections=None,
                        trainable=None):
  sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)
  id_tensor = sparse_tensors.id_tensor
  weight_tensor = sparse_tensors.weight_tensor

  # Expands final dimension, so that embeddings are not combined during
  # embedding lookup.
  check_id_rank = check_ops.assert_equal(
      array_ops.rank(id_tensor), 2,
      data=[
          'Column {} expected ID tensor of rank 2. '.format(self.name),
          'id_tensor shape: ', array_ops.shape(id_tensor)])
  with ops.control_dependencies([check_id_rank]):
    id_tensor = sparse_ops.sparse_reshape(
        id_tensor,
        shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0))

  if weight_tensor is not None:
    check_weight_rank = check_ops.assert_equal(
        array_ops.rank(weight_tensor), 2,
        data=[
            'Column {} expected weight tensor of rank 2.'.format(self.name),
            'weight_tensor shape:', array_ops.shape(weight_tensor)])
    with ops.control_dependencies([check_weight_rank]):
      weight_tensor = sparse_ops.sparse_reshape(
          weight_tensor,
          shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))

  return fc._CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
Author: DILASSS, Project: tensorflow, Lines: 27, Source: sequential_feature_column.py
Example 5: _kl_independent
def _kl_independent(a, b, name="kl_independent"):
  """Batched KL divergence `KL(a || b)` for Independent distributions.

  We can leverage the fact that
  ```
  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
  ```
  where the sum is over the `reinterpreted_batch_ndims`.

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default "kl_independent".

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the event space for `a` and `b`, or their underlying
      distributions don't match.
  """
  p = a.distribution
  q = b.distribution

  # The KL between any two (non)-batched distributions is a scalar.
  # Given that the KL between two factored distributions is the sum, i.e.
  # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
  # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
  if a.event_shape.is_fully_defined() and b.event_shape.is_fully_defined():
    if a.event_shape == b.event_shape:
      if p.event_shape == q.event_shape:
        num_reduce_dims = a.event_shape.ndims - p.event_shape.ndims
        reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]
        return math_ops.reduce_sum(
            kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
      else:
        raise NotImplementedError("KL between Independents with different "
                                  "event shapes not supported.")
    else:
      raise ValueError("Event shapes do not match.")
  else:
    with ops.control_dependencies([
        check_ops.assert_equal(a.event_shape_tensor(), b.event_shape_tensor()),
        check_ops.assert_equal(p.event_shape_tensor(), q.event_shape_tensor())
    ]):
      num_reduce_dims = (
          array_ops.shape(a.event_shape_tensor()[0]) -
          array_ops.shape(p.event_shape_tensor()[0]))
      reduce_dims = math_ops.range(-num_reduce_dims - 1, -1, 1)
      return math_ops.reduce_sum(
          kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
Author: didukhle, Project: tensorflow, Lines: 52, Source: independent.py
Example 6: test_raises_when_less
def test_raises_when_less(self):
  with self.test_session():
    # Static check
    static_small = constant_op.constant([3, 1], name="small")
    static_big = constant_op.constant([4, 2], name="big")
    with self.assertRaisesRegexp(ValueError, "fail"):
      check_ops.assert_equal(static_big, static_small, message="fail")

    # Dynamic check
    small = array_ops.placeholder(dtypes.int32, name="small")
    big = array_ops.placeholder(dtypes.int32, name="big")
    with ops.control_dependencies([check_ops.assert_equal(small, big)]):
      out = array_ops.identity(small)
    with self.assertRaisesOpError("small.*big"):
      out.eval(feed_dict={small: [3, 1], big: [4, 2]})
Author: 1000sprites, Project: tensorflow, Lines: 14, Source: check_ops_test.py
Example 7: _check_mu
def _check_mu(self, mu):
  """Return `mu` after validity checks and possibly with assertations."""
  mu = ops.convert_to_tensor(mu)
  cov = self._cov

  if mu.dtype != cov.dtype:
    raise TypeError(
        "mu and cov must have the same dtype. Found mu.dtype = %s, "
        "cov.dtype = %s" % (mu.dtype, cov.dtype)
    )

  # Try to validate with static checks.
  mu_shape = mu.get_shape()
  cov_shape = cov.get_shape()
  if mu_shape.is_fully_defined() and cov_shape.is_fully_defined():
    if mu_shape != cov_shape[:-1]:
      raise ValueError(
          "mu.shape and cov.shape[:-1] should match. Found: mu.shape=%s, "
          "cov.shape=%s" % (mu_shape, cov_shape)
      )
    else:
      return mu

  # Static checks could not be run, so possibly do dynamic checks.
  if not self.validate_args:
    return mu
  else:
    assert_same_rank = check_ops.assert_equal(
        array_ops.rank(mu) + 1,
        cov.rank(),
        data=[
            "mu should have rank 1 less than cov. Found: rank(mu) = ",
            array_ops.rank(mu),
            " rank(cov) = ",
            cov.rank(),
        ],
    )
    with ops.control_dependencies([assert_same_rank]):
      assert_same_shape = check_ops.assert_equal(
          array_ops.shape(mu),
          cov.vector_shape(),
          data=[
              "mu.shape and cov.shape[:-1] should match. Found: shape(mu) = ",
              array_ops.shape(mu),
              " shape(cov) = ",
              cov.shape(),
          ],
      )
      return control_flow_ops.with_dependencies([assert_same_shape], mu)
Author: damienmg, Project: tensorflow, Lines: 48, Source: mvn.py
Example 8: _maybe_check_matching_sizes
def _maybe_check_matching_sizes(self, event_shape_in, event_shape_out,
                                validate_args=False):
  """Check that prod(event_shape_in)==prod(event_shape_out)."""

  def _get_size_from_shape(shape):
    """Computes size from a shape `Tensor`, statically if possible."""
    s = tensor_util.constant_value(shape)
    if s is not None:
      return [np.int32(np.prod(s))]*2
    return None, math_ops.reduce_prod(shape, name="size")

  # Ensure `event_shape_in` is compatible with `event_shape_out`.
  event_size_in_, event_size_in = _get_size_from_shape(  # pylint: disable=unbalanced-tuple-unpacking
      event_shape_in)
  event_size_out_, event_size_out = _get_size_from_shape(  # pylint: disable=unbalanced-tuple-unpacking
      event_shape_out)

  assertions = []
  if event_size_in_ is not None and event_size_out_ is not None:
    if event_size_in_ != event_size_out_:
      raise ValueError(
          "Input `event_size` ({}) does not match output `event_size` ({}).".
          format(event_size_in, event_size_out_))
  elif validate_args:
    assertions.append(check_ops.assert_equal(
        event_size_in, event_size_out,
        message="Input/output `event_size`s do not match."))

  return assertions
Author: SylChan, Project: tensorflow, Lines: 29, Source: reshape_impl.py
Example 9: _check_labels
def _check_labels(labels, expected_labels_dimension):
  """Check labels type and shape."""
  with ops.name_scope(None, 'labels', (labels,)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
    if isinstance(labels, sparse_tensor.SparseTensor):
      raise ValueError('SparseTensor labels are not supported.')
    labels_shape = array_ops.shape(labels)
    err_msg = 'labels shape must be [batch_size, {}]'.format(
        expected_labels_dimension)
    assert_rank = check_ops.assert_rank(labels, 2, message=err_msg)
    with ops.control_dependencies([assert_rank]):
      static_shape = labels.shape
      if static_shape is not None:
        dim1 = static_shape[1]
        if (dim1 is not None) and (dim1 != expected_labels_dimension):
          raise ValueError(
              'Mismatched label shape. '
              'Classifier configured with n_classes=%s. Received %s. '
              'Suggested Fix: check your n_classes argument to the estimator '
              'and/or the shape of your label.' %
              (expected_labels_dimension, dim1))
      assert_dimension = check_ops.assert_equal(
          expected_labels_dimension, labels_shape[1], message=err_msg)
      with ops.control_dependencies([assert_dimension]):
        return array_ops.identity(labels, name=scope)
Author: cneeruko, Project: tensorflow, Lines: 25, Source: head.py
Example 10: assert_splits_match
def assert_splits_match(nested_splits_lists):
  """Checks that the given splits lists are identical.

  Performs static tests to ensure that the given splits lists are identical,
  and returns a list of control dependency op tensors that check that they are
  fully identical.

  Args:
    nested_splits_lists: A list of nested_splits_lists, where each split_list is
      a list of `splits` tensors from a `RaggedTensor`, ordered from outermost
      ragged dimension to innermost ragged dimension.

  Returns:
    A list of control dependency op tensors.

  Raises:
    ValueError: If the splits are not identical.
  """
  error_msg = "Inputs must have identical ragged splits"
  for splits_list in nested_splits_lists:
    if len(splits_list) != len(nested_splits_lists[0]):
      raise ValueError(error_msg)
  return [
      check_ops.assert_equal(s1, s2, message=error_msg)
      for splits_list in nested_splits_lists[1:]
      for (s1, s2) in zip(nested_splits_lists[0], splits_list)
  ]
Author: JonathanRaiman, Project: tensorflow, Lines: 26, Source: ragged_util.py
Example 11: assert_close
def assert_close(
    x, y, data=None, summarize=None, message=None, name="assert_close"):
  """Assert that that x and y are within machine epsilon of each other.

  Args:
    x: Numeric `Tensor`
    y: Numeric `Tensor`
    data: The tensors to print out if the condition is `False`. Defaults to
      error message and first few entries of `x` and `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if |x - y| > machine epsilon.
  """
  message = message or ""
  x = ops.convert_to_tensor(x, name="x")
  y = ops.convert_to_tensor(y, name="y")

  if x.dtype.is_integer:
    return check_ops.assert_equal(
        x, y, data=data, summarize=summarize, message=message, name=name)

  with ops.name_scope(name, "assert_close", [x, y, data]):
    tol = np.finfo(x.dtype.as_numpy_dtype).resolution
    if data is None:
      data = [
          message,
          "Condition x ~= y did not hold element-wise: x = ", x.name, x, "y = ",
          y.name, y
      ]
    condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol))
    return control_flow_ops.Assert(
        condition, data, summarize=summarize)
Author: Nishant23, Project: tensorflow, Lines: 35, Source: distribution_util.py
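A small, hedged usage sketch of the helper above (values are illustrative, and it assumes the `assert_close` function defined above is in scope): integer tensors are delegated straight to `check_ops.assert_equal`, while floating-point tensors may differ by up to the dtype's resolution.

import numpy as np
import tensorflow as tf

x = tf.constant([1.0, 2.0], dtype=tf.float32)
y = x + np.finfo(np.float32).resolution / 2.0  # still "close" for float32
check = assert_close(x, y, message="x and y drifted apart")
with tf.control_dependencies([check]):
  out = tf.identity(x)  # evaluating `out` runs the closeness check first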
Example 12: __init__
def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"):
  """Instantiates the `AbsoluteValue` bijector.

  Args:
    event_ndims: Python scalar indicating the number of dimensions associated
      with a particular draw from the distribution. Currently only zero is
      supported.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str` name given to ops managed by this object.

  Raises:
    ValueError: If `event_ndims` is not zero.
  """
  self._graph_parents = []
  self._name = name

  event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
  event_ndims_const = tensor_util.constant_value(event_ndims)
  if event_ndims_const is not None and event_ndims_const not in (0,):
    raise ValueError("event_ndims(%s) was not 0" % event_ndims_const)
  else:
    if validate_args:
      event_ndims = control_flow_ops.with_dependencies(
          [check_ops.assert_equal(
              event_ndims, 0, message="event_ndims was not 0")],
          event_ndims)

  with self._name_scope("init"):
    super(AbsoluteValue, self).__init__(
        event_ndims=event_ndims,
        validate_args=validate_args,
        name=name)
Author: Mazecreator, Project: tensorflow, Lines: 33, Source: absolute_value_impl.py
Example 13: test_doesnt_raise_when_both_empty
def test_doesnt_raise_when_both_empty(self):
  with self.test_session():
    larry = constant_op.constant([])
    curly = constant_op.constant([])
    with ops.control_dependencies([check_ops.assert_equal(larry, curly)]):
      out = array_ops.identity(larry)
    out.eval()
Author: 1000sprites, Project: tensorflow, Lines: 7, Source: check_ops_test.py
Example 14: test_doesnt_raise_when_equal_and_broadcastable_shapes
def test_doesnt_raise_when_equal_and_broadcastable_shapes(self):
  with self.test_session():
    small = constant_op.constant([1, 2], name="small")
    small_2 = constant_op.constant([1, 2], name="small_2")
    with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
      out = array_ops.identity(small)
    out.eval()
Author: 1000sprites, Project: tensorflow, Lines: 7, Source: check_ops_test.py
Example 15: maybe_check_quadrature_param
def maybe_check_quadrature_param(param, name, validate_args):
  """Helper which checks validity of `loc` and `scale` init args."""
  with ops.name_scope(name="check_" + name, values=[param]):
    assertions = []
    if param.shape.ndims is not None:
      if param.shape.ndims == 0:
        raise ValueError("Mixing params must be a (batch of) vector; "
                         "{}.rank={} is not at least one.".format(
                             name, param.shape.ndims))
    elif validate_args:
      assertions.append(check_ops.assert_rank_at_least(
          param, 1,
          message=("Mixing params must be a (batch of) vector; "
                   "{}.rank is not at least one.".format(
                       name))))

    # TODO(jvdillon): Remove once we support k-mixtures.
    if param.shape.with_rank_at_least(1)[-1] is not None:
      if param.shape[-1].value != 1:
        raise NotImplementedError("Currently only bimixtures are supported; "
                                  "{}.shape[-1]={} is not 1.".format(
                                      name, param.shape[-1].value))
    elif validate_args:
      assertions.append(check_ops.assert_equal(
          array_ops.shape(param)[-1], 1,
          message=("Currently only bimixtures are supported; "
                   "{}.shape[-1] is not 1.".format(name))))

    if assertions:
      return control_flow_ops.with_dependencies(assertions, param)
    return param
Author: bikong2, Project: tensorflow, Lines: 31, Source: vector_diffeomixture.py
Example 16: assert_integer_form
def assert_integer_form(
    x, data=None, summarize=None, message=None,
    int_dtype=None, name="assert_integer_form"):
  """Assert that x has integer components (or floats equal to integers).

  Args:
    x: Floating-point `Tensor`
    data: The tensors to print out if the condition is `False`. Defaults to
      error message and first few entries of `x` and `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    int_dtype: A `tf.dtype` used to cast the float to. The default (`None`)
      implies the smallest possible signed int will be used for casting.
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if `cast(x, int_dtype) != x`.
  """
  with ops.name_scope(name, values=[x, data]):
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.is_integer:
      return control_flow_ops.no_op()
    message = message or "{} has non-integer components".format(x.op.name)
    if int_dtype is None:
      try:
        int_dtype = {
            dtypes.float16: dtypes.int16,
            dtypes.float32: dtypes.int32,
            dtypes.float64: dtypes.int64,
        }[x.dtype.base_dtype]
      except KeyError:
        raise TypeError("Unrecognized type {}".format(x.dtype.name))
    return check_ops.assert_equal(
        x, math_ops.cast(math_ops.cast(x, int_dtype), x.dtype),
        data=data, summarize=summarize, message=message, name=name)
Author: 1000sprites, Project: tensorflow, Lines: 35, Source: util.py
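The core trick above is a cast round trip: a float tensor has integer components exactly when `cast(cast(x, int_dtype), x.dtype) == x`. A minimal, hedged sketch of the same check written against the public TF 1.x API (values are illustrative):

import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.5])
round_trip = tf.cast(tf.cast(x, tf.int32), x.dtype)
check = tf.assert_equal(x, round_trip, message="x has non-integer components")
with tf.control_dependencies([check]):
  out = tf.identity(x)
# Evaluating `out` would raise InvalidArgumentError here, because 3.5 != 3.0
# after the round trip.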
Example 17: zero_state
def zero_state(self, batch_size, dtype):
  with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
    if self._initial_cell_state is not None:
      cell_state = self._initial_cell_state
    else:
      cell_state = self._cell.zero_state(batch_size, dtype)
    error_message = (
        "When calling zero_state of AttentionWrapper %s: " % self._base_name +
        "Non-matching batch sizes between the memory "
        "(encoder output) and the requested batch size. Are you using "
        "the BeamSearchDecoder? If so, make sure your encoder output has "
        "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and "
        "the batch_size= argument passed to zero_state is "
        "batch_size * beam_width.")
    with ops.control_dependencies(
        [check_ops.assert_equal(batch_size,
                                self._attention_mechanism.batch_size,
                                message=error_message)]):
      cell_state = nest.map_structure(
          lambda s: array_ops.identity(s, name="checked_cell_state"),
          cell_state)
    if self._alignment_history:
      alignment_history = tensor_array_ops.TensorArray(
          dtype=dtype, size=0, dynamic_size=True)
    else:
      alignment_history = ()
    return AttentionWrapperState(
        cell_state=cell_state,
        time=array_ops.zeros([], dtype=dtypes.int32),
        attention=_zero_state_tensors(self._attention_layer_size, batch_size,
                                      dtype),
        alignments=self._attention_mechanism.initial_alignments(
            batch_size, dtype),
        alignment_history=alignment_history)
Author: ajaybhat, Project: tensorflow, Lines: 34, Source: attention_wrapper.py
Example 18: _model_fn_ops
def _model_fn_ops(
    expected_features, expected_labels, actual_features, actual_labels, mode):
  assert_ops = tuple([
      check_ops.assert_equal(
          expected_features[k], actual_features[k], name='assert_%s' % k)
      for k in expected_features
  ] + [
      check_ops.assert_equal(
          expected_labels, actual_labels, name='assert_labels')
  ])
  with ops.control_dependencies(assert_ops):
    return model_fn.ModelFnOps(
        mode=mode,
        predictions=constant_op.constant(0.),
        loss=constant_op.constant(0.),
        train_op=constant_op.constant(0.))
Author: AlbertXiebnu, Project: tensorflow, Lines: 16, Source: estimator_test.py
Example 19: calculate_reshape
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
  """Calculates the reshaped dimensions (replacing up to one -1 in reshape)."""
  batch_shape_static = tensor_util.constant_value_as_shape(new_shape)
  if batch_shape_static.is_fully_defined():
    return np.int32(batch_shape_static.as_list()), batch_shape_static, []
  with ops.name_scope(name, "calculate_reshape", [original_shape, new_shape]):
    original_size = math_ops.reduce_prod(original_shape)
    implicit_dim = math_ops.equal(new_shape, -1)
    size_implicit_dim = (
        original_size // math_ops.maximum(1, -math_ops.reduce_prod(new_shape)))
    new_ndims = array_ops.shape(new_shape)
    expanded_new_shape = array_ops.where(  # Assumes exactly one `-1`.
        implicit_dim, array_ops.fill(new_ndims, size_implicit_dim), new_shape)
    validations = [] if not validate else [
        check_ops.assert_rank(
            original_shape, 1, message="Original shape must be a vector."),
        check_ops.assert_rank(
            new_shape, 1, message="New shape must be a vector."),
        check_ops.assert_less_equal(
            math_ops.count_nonzero(implicit_dim, dtype=dtypes.int32),
            1,
            message="At most one dimension can be unknown."),
        check_ops.assert_positive(
            expanded_new_shape, message="Shape elements must be >=-1."),
        check_ops.assert_equal(
            math_ops.reduce_prod(expanded_new_shape),
            original_size,
            message="Shape sizes do not match."),
    ]
    return expanded_new_shape, batch_shape_static, validations
Author: Ajaycs99, Project: tensorflow, Lines: 30, Source: batch_reshape.py
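The arithmetic used above for the implicit `-1` dimension can be checked by hand: since at most one entry of `new_shape` is `-1`, `reduce_prod(new_shape)` is minus the product of the known dimensions, so `original_size // max(1, -reduce_prod(new_shape))` is the size of the missing dimension. A small, hedged NumPy sketch with illustrative shapes:

import numpy as np

original_shape = np.array([2, 3, 4])   # 24 elements in total
new_shape = np.array([4, -1])          # one implicit dimension
original_size = np.prod(original_shape)                            # 24
size_implicit_dim = original_size // max(1, -np.prod(new_shape))   # 24 // 4 = 6
expanded_new_shape = np.where(new_shape == -1, size_implicit_dim, new_shape)
print(expanded_new_shape)  # [4 6], and 4 * 6 == 24, as the final assert_equal requires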
Example 20: _verify_input
def _verify_input(tensor_list, labels, probs_list):
  """Verify that batched inputs are well-formed."""
  checked_probs_list = []
  for probs in probs_list:
    # Since number of classes shouldn't change at runtime, probabilities shape
    # should be fully defined.
    probs.get_shape().assert_is_fully_defined()

    # Probabilities must be 1D.
    probs.get_shape().assert_has_rank(1)

    # Probabilities must be nonnegative and sum to one.
    tol = 1e-6
    prob_sum = math_ops.reduce_sum(probs)
    checked_probs = control_flow_ops.with_dependencies([
        check_ops.assert_non_negative(probs),
        check_ops.assert_less(prob_sum, 1.0 + tol),
        check_ops.assert_less(1.0 - tol, prob_sum)
    ], probs)
    checked_probs_list.append(checked_probs)

  # All probabilities should be the same length.
  prob_length = checked_probs_list[0].get_shape().num_elements()
  for checked_prob in checked_probs_list:
    if checked_prob.get_shape().num_elements() != prob_length:
      raise ValueError('Probability parameters must have the same length.')

  # Labels tensor should only have batch dimension.
  labels.get_shape().assert_has_rank(1)

  for tensor in tensor_list:
    # Data tensor should have a batch dimension.
    shape = tensor.get_shape().with_rank_at_least(1)

    # Data and label batch dimensions must be compatible.
    tensor_shape.dimension_at_index(shape, 0).assert_is_compatible_with(
        labels.get_shape()[0])

  # Data and labels must have the same, strictly positive batch size. Since we
  # can't assume we know the batch size at graph creation, add runtime checks.
  labels_batch_size = array_ops.shape(labels)[0]
  lbl_assert = check_ops.assert_positive(labels_batch_size)

  # Make each tensor depend on its own checks.
  labels = control_flow_ops.with_dependencies([lbl_assert], labels)
  tensor_list = [
      control_flow_ops.with_dependencies([
          lbl_assert,
          check_ops.assert_equal(array_ops.shape(x)[0], labels_batch_size)
      ], x) for x in tensor_list
  ]

  # Label's classes must be integers 0 <= x < num_classes.
  labels = control_flow_ops.with_dependencies([
      check_ops.assert_integer(labels), check_ops.assert_non_negative(labels),
      check_ops.assert_less(labels, math_ops.cast(prob_length, labels.dtype))
  ], labels)

  return tensor_list, labels, checked_probs_list
Author: Ajaycs99, Project: tensorflow, Lines: 59, Source: sampling_ops.py
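A minimal, hedged sketch of the "nonnegative and sums to one" pattern used above, written with the public TF 1.x API and an illustrative probability vector:

import tensorflow as tf

probs = tf.constant([0.2, 0.3, 0.5])
tol = 1e-6
prob_sum = tf.reduce_sum(probs)
checks = [
    tf.assert_non_negative(probs),
    tf.assert_less(prob_sum, 1.0 + tol),
    tf.assert_less(1.0 - tol, prob_sum),
]
with tf.control_dependencies(checks):
  checked_probs = tf.identity(probs)  # using this tensor triggers the checks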
Note: the tensorflow.python.ops.check_ops.assert_equal examples in this article were compiled by 纯净天空 from source-code and documentation hosting platforms such as GitHub and MSDocs. The code snippets are taken from open-source projects contributed by various developers; copyright remains with the original authors, and distribution and use should follow each project's license. Do not reproduce without permission.