
Python smart_cond.smart_cond function code examples


This article collects typical usage examples of the Python function tensorflow.python.framework.smart_cond.smart_cond. If you want to know how smart_cond is used in practice, or are looking for examples of it, the hand-picked code samples below may help.



Twenty code examples of smart_cond are shown below, ordered by popularity.

Example 1: categorical_crossentropy

def categorical_crossentropy(y_true,
                             y_pred,
                             from_logits=False,
                             label_smoothing=0):
  """Computes the categorical crossentropy loss.

  Args:
    y_true: tensor of true targets.
    y_pred: tensor of predicted targets.
    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
      we assume that `y_pred` encodes a probability distribution.
    label_smoothing: Float in [0, 1]. If > `0` then smooth the labels.

  Returns:
    Categorical crossentropy loss value.
  """
  y_pred = ops.convert_to_tensor(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  label_smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())

  def _smooth_labels():
    num_classes = math_ops.cast(array_ops.shape(y_true)[1], y_pred.dtype)
    return y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes)

  y_true = smart_cond.smart_cond(label_smoothing,
                                 _smooth_labels, lambda: y_true)
  return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
Developer ID: adit-chandra, Project: tensorflow, Lines of code: 27, Source file: losses.py
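The `_smooth_labels` branch blends each one-hot target with a uniform distribution over the classes, `y * (1 - label_smoothing) + label_smoothing / num_classes`, and `smart_cond` only builds that computation when `label_smoothing` is not statically zero. The pure-Python sketch below reproduces just the arithmetic on nested lists; the helper name `smooth_categorical` and the list-of-rows input format are assumptions for illustration, not TensorFlow API.

```python
def smooth_categorical(y_true, label_smoothing):
    """Hypothetical list-based version of the `_smooth_labels` arithmetic.

    y_true: list of rows, each a one-hot (or probability) vector.
    label_smoothing: float in [0, 1].
    """
    if label_smoothing == 0:
        # Same outcome as the `lambda: y_true` branch: targets are unchanged.
        return y_true
    smoothed = []
    for row in y_true:
        num_classes = len(row)
        # Keep (1 - ls) of the original mass and spread ls evenly over classes.
        smoothed.append([y * (1.0 - label_smoothing) + label_smoothing / num_classes
                         for y in row])
    return smoothed


# One-hot target for class 2 of 3, smoothed with label_smoothing=0.3:
# the entries become roughly [0.1, 0.1, 0.8] and still sum to 1.
rows = smooth_categorical([[0, 0, 1]], 0.3)
```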


Example 2: _contraction

  def _contraction():
    """Performs a contraction."""
    contracted = face_centroid - contraction * (face_centroid -
                                                simplex[worst_index])
    objective_at_contracted = objective_function(contracted)
    is_contracted_acceptable = objective_at_contracted <= worst_objective_value
    def _accept_contraction():
      next_simplex = _replace_at_index(simplex, worst_index, contracted)
      objective_at_next_simplex = _replace_at_index(
          objective_values,
          worst_index,
          objective_at_contracted)
      return (
          False,
          next_simplex,
          objective_at_next_simplex,
          1
      )

    def _reject_contraction():
      return _shrink_towards_best(objective_function, simplex, best_index,
                                  shrinkage, batch_evaluate_objective)

    return smart_cond.smart_cond(is_contracted_acceptable,
                                 _accept_contraction,
                                 _reject_contraction)
Developer ID: lewisKit, Project: probability, Lines of code: 26, Source file: nelder_mead.py
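The contraction step moves the worst vertex part of the way towards the centroid of the opposite face and keeps the new point only if it is no worse than the vertex it replaces; otherwise the simplex is shrunk towards the best vertex. A minimal pure-Python sketch of the accept/reject test follows, assuming points are plain lists of floats and a contraction coefficient of 0.5 (the usual Nelder-Mead default). The function name `contraction_step` is hypothetical, and the shrink branch is deliberately left out: the sketch only reports that the contraction was rejected.

```python
def contraction_step(simplex, values, worst_index, face_centroid, objective,
                     contraction=0.5):
    """Hypothetical pure-Python version of the contraction accept/reject test.

    simplex: list of points (lists of floats); values: objective at each point.
    Returns (accepted, new_simplex, new_values).
    """
    worst = simplex[worst_index]
    # Move the worst vertex towards the centroid of the opposite face.
    contracted = [c - contraction * (c - w) for c, w in zip(face_centroid, worst)]
    value = objective(contracted)
    if value <= values[worst_index]:
        # Accept: replace the worst vertex and its objective value.
        new_simplex = list(simplex)
        new_values = list(values)
        new_simplex[worst_index] = contracted
        new_values[worst_index] = value
        return True, new_simplex, new_values
    # Reject: a full implementation would shrink towards the best vertex here.
    return False, simplex, values


def sphere(p):
    return sum(x * x for x in p)


simplex = [[0.0, 0.0], [1.0, 0.0], [4.0, 4.0]]
values = [sphere(p) for p in simplex]  # [0.0, 1.0, 32.0]
# [0.5, 0.0] is the centroid of the face opposite the worst vertex (index 2).
accepted, simplex, values = contraction_step(
    simplex, values, worst_index=2, face_centroid=[0.5, 0.0], objective=sphere)
```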


Example 3: write

def write(tag, tensor, step=None, metadata=None, name=None):
  """Writes a generic summary to the default SummaryWriter if one exists.

  This exists primarily to support the definition of type-specific summary ops
  like scalar() and image(), and is not intended for direct use unless defining
  a new type-specific summary op.

  Args:
    tag: string tag used to identify the summary (e.g. in TensorBoard), usually
      generated with `tf.summary.summary_scope`
    tensor: the Tensor holding the summary data to write
    step: Explicit `int64`-castable monotonic step value for this summary. If
      omitted, this defaults to `tf.summary.experimental.get_step()`, which must
      not be None.
    metadata: Optional SummaryMetadata, as a proto or serialized bytes
    name: Optional string name for this op.

  Returns:
    True on success, or False if no summary was written because no default
    summary writer was available.

  Raises:
    ValueError: if a default writer exists, but no step was provided and
      `tf.summary.experimental.get_step()` is None.
  """
  with ops.name_scope(name, "write_summary") as scope:
    if context.context().summary_writer is None:
      return constant_op.constant(False)
    if step is None:
      step = get_step()
      if step is None:
        raise ValueError("No step set via 'step' argument or "
                         "tf.summary.experimental.set_step()")
    if metadata is None:
      serialized_metadata = b""
    elif hasattr(metadata, "SerializeToString"):
      serialized_metadata = metadata.SerializeToString()
    else:
      serialized_metadata = metadata

    def record():
      """Record the actual summary and return True."""
      # Note the identity to move the tensor to the CPU.
      with ops.device("cpu:0"):
        write_summary_op = gen_summary_ops.write_summary(
            context.context().summary_writer._resource,  # pylint: disable=protected-access
            step,
            array_ops.identity(tensor),
            tag,
            serialized_metadata,
            name=scope)
        with ops.control_dependencies([write_summary_op]):
          return constant_op.constant(True)

    with ops.device("cpu:0"):
      op = smart_cond.smart_cond(
          _should_record_summaries_v2(), record, _nothing, name="summary_cond")
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op)  # pylint: disable=protected-access
      return op
Developer ID: aritratony, Project: tensorflow, Lines of code: 60, Source file: summary_ops_v2.py
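Stripped of the TensorFlow ops, `write` follows a simple decision sequence: no default writer means nothing is written, a missing step is an error, and otherwise `smart_cond` picks between the `record` branch and `_nothing` depending on whether summaries should currently be recorded. The sketch below models that control flow in plain Python. `HypotheticalWriter`, `write_sketch`, the `should_record` flag and `default_step` are illustrative stand-ins (for the default writer, `write`, `_should_record_summaries_v2()` and `get_step()` respectively), and it assumes that `_nothing`, which is not shown above, returns False.

```python
class HypotheticalWriter:
    """Stand-in for a default summary writer; collects records in a list."""

    def __init__(self):
        self.records = []


def write_sketch(tag, value, step=None, writer=None, should_record=True,
                 default_step=None):
    """Plain-Python model of the control flow in `write` (names hypothetical)."""
    if writer is None:
        # No default writer: nothing is written and False is returned.
        return False
    if step is None:
        step = default_step  # plays the role of get_step()
        if step is None:
            raise ValueError("No step set via 'step' argument or set_step()")
    if not should_record:
        # The `_nothing` branch, assumed to return False.
        return False
    # The `record` branch: write the summary, then report success.
    writer.records.append((tag, step, value))
    return True


w = HypotheticalWriter()
ok = write_sketch("loss", 0.25, step=3, writer=w)
```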


Example 4: summary_writer_function

def summary_writer_function(name, tensor, function, family=None):
  """Helper function to write summaries.

  Args:
    name: name of the summary
    tensor: main tensor to form the summary
    function: function taking a tag and a scope which writes the summary
    family: optional, the summary's family

  Returns:
    The result of writing the summary.
  """
  name_scope = ops.get_name_scope()
  if name_scope:
    # Add a slash to allow reentering the name scope.
    name_scope += "/"
  def record():
    with ops.name_scope(name_scope), summary_op_util.summary_scope(
        name, family, values=[tensor]) as (tag, scope):
      with ops.control_dependencies([function(tag, scope)]):
        return constant_op.constant(True)

  if context.context().summary_writer_resource is None:
    return control_flow_ops.no_op()
  with ops.device("cpu:0"):
    op = smart_cond.smart_cond(
        should_record_summaries(), record, _nothing, name="")
    if not context.executing_eagerly():
      ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op)  # pylint: disable=protected-access
  return op
Developer ID: aeverall, Project: tensorflow, Lines of code: 30, Source file: summary_ops_v2.py


Example 5: testSmartCondTrue

  def testSmartCondTrue(self):
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(2)
        y = constant_op.constant(5)
        z = smart_cond.smart_cond(True, lambda: math_ops.multiply(x, 16),
                                  lambda: math_ops.multiply(y, 5))
        self.assertEqual(z.eval(), 32)
Developer ID: neuroradiology, Project: tensorflow, Lines of code: 8, Source file: smart_cond_test.py


Example 6: testUnknown

  def testUnknown(self):
    with ops.Graph().as_default():
      with session.Session():
        x = array_ops.placeholder(dtype=dtypes.int32)
        y = smart_cond.smart_cond(x > 0, lambda: constant_op.constant(1),
                                  lambda: constant_op.constant(2))
        self.assertEqual(y.eval(feed_dict={x: 1}), 1)
        self.assertEqual(y.eval(feed_dict={x: -1}), 2)
Developer ID: AndrewTwinz, Project: tensorflow, Lines of code: 8, Source file: smart_cond_test.py


Example 7: testSmartCondFalse

  def testSmartCondFalse(self):
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(4)
        y = constant_op.constant(3)
        z = smart_cond.smart_cond(False, lambda: math_ops.multiply(x, 16),
                                  lambda: math_ops.multiply(y, 3))
        self.assertEqual(z.eval(), 9)
Developer ID: neuroradiology, Project: tensorflow, Lines of code: 8, Source file: smart_cond_test.py


Example 8: testPlaceholderWithDefault

  def testPlaceholderWithDefault(self):
    with ops.Graph().as_default():
      with session.Session():
        x = array_ops.placeholder_with_default(1, shape=())
        y = smart_cond.smart_cond(x > 0, lambda: constant_op.constant(1),
                                  lambda: constant_op.constant(2))
        self.assertEqual(y.eval(), 1)
        self.assertEqual(y.eval(feed_dict={x: -1}), 2)
Developer ID: AndrewTwinz, Project: tensorflow, Lines of code: 8, Source file: smart_cond_test.py


Example 9: binary_crossentropy

def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):

  def _smooth_labels():
    return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

  y_true = smart_cond.smart_cond(label_smoothing,
                                 _smooth_labels, lambda: y_true)
  return K.mean(
      K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1)
Developer ID: Wajih-O, Project: tensorflow, Lines of code: 9, Source file: losses.py


Example 10: testEval

  def testEval(self):
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(1)
        y = constant_op.constant(2)
        # x * y > 0 can be evaluated at graph construction time, so the false
        # branch shouldn't be evaluated at all.
        z = smart_cond.smart_cond(x * y > 0, lambda: constant_op.constant(1),
                                  raise_exception)
        self.assertEqual(z.eval(feed_dict={x: 1}), 1)
Developer ID: Wajih-O, Project: tensorflow, Lines of code: 10, Source file: smart_cond_test.py
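The tests above pin down the dispatch rule: a predicate whose value is known while the graph is being built (a Python bool, or an expression such as `x * y > 0` that TensorFlow can constant-fold) selects one branch immediately, so the other callable is never invoked, while an unknown predicate such as a placeholder falls back to `tf.cond`. The sketch below models that rule in plain Python. It treats only Python bools and the integers 0 and 1 as static and does not attempt any constant folding; `smart_cond_sketch`, `run_time_cond` (standing in for `tf.cond`) and `must_not_run` are hypothetical names.

```python
def smart_cond_sketch(pred, true_fn, false_fn, dynamic_cond):
    """Plain-Python model of smart_cond's dispatch rule (hypothetical).

    Static predicates (Python bools, or the ints 0 and 1) pick a branch now;
    anything else is handed to `dynamic_cond`, which stands in for tf.cond.
    """
    if not callable(true_fn) or not callable(false_fn):
        raise TypeError("true_fn and false_fn must be callable")
    if isinstance(pred, int) and pred in (0, 1):  # bool is a subclass of int
        # Only the selected callable runs; the other is never invoked.
        return true_fn() if pred else false_fn()
    return dynamic_cond(pred, true_fn, false_fn)


def run_time_cond(pred, true_fn, false_fn):
    # Stand-in for tf.cond: evaluate the deferred predicate, then branch.
    return true_fn() if pred() else false_fn()


def must_not_run():
    raise RuntimeError("the untaken branch should never be called")


print(smart_cond_sketch(True, lambda: 2 * 16, must_not_run, run_time_cond))  # 32
print(smart_cond_sketch(lambda: -1 > 0, lambda: 1, lambda: 2, run_time_cond))  # 2
```

Here a zero-argument lambda plays the role of a value that is only known at run time, which is why it is routed to `run_time_cond`.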


Example 11: binary_crossentropy

def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):  # pylint: disable=missing-docstring
  y_pred = ops.convert_to_tensor(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  label_smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())

  def _smooth_labels():
    return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

  y_true = smart_cond.smart_cond(label_smoothing,
                                 _smooth_labels, lambda: y_true)
  return K.mean(
      K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1)
Developer ID: adit-chandra, Project: tensorflow, Lines of code: 12, Source file: losses.py
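For binary targets the smoothed label is pulled towards 0.5 instead of towards `1 / num_classes`: `y * (1 - label_smoothing) + 0.5 * label_smoothing`. Unlike Example 9, Example 11 converts `label_smoothing` to a tensor before using it as the predicate. In the TensorFlow versions these snippets come from, `smart_cond` appears to accept only a Tensor, a Python bool, or the integers 0 and 1 as `pred`, so passing a bare nonzero Python float such as 0.1 straight through would presumably raise a TypeError. The list-based sketch below shows only the arithmetic; `smooth_binary` is a hypothetical name.

```python
def smooth_binary(y_true, label_smoothing):
    """Hypothetical list-based version of the binary `_smooth_labels` branch."""
    if label_smoothing == 0:
        # Same outcome as the `lambda: y_true` branch.
        return list(y_true)
    # Each target moves towards 0.5 by a fraction label_smoothing.
    return [y * (1.0 - label_smoothing) + 0.5 * label_smoothing for y in y_true]


smoothed = smooth_binary([1, 0], 0.2)  # roughly [0.9, 0.1]
```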


Example 12: testEval

  def testEval(self):
    # Constant expression evaluation only works with the C API enabled.
    if not ops._USE_C_API: return

    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(1)
        y = constant_op.constant(2)
        # x * y > 0 can be evaluated at graph construction time, so the false
        # branch shouldn't be evaluated at all.
        z = smart_cond.smart_cond(x * y > 0, lambda: constant_op.constant(1),
                                  raise_exception)
        self.assertEqual(z.eval(feed_dict={x: 1}), 1)
Developer ID: AndrewTwinz, Project: tensorflow, Lines of code: 13, Source file: smart_cond_test.py


Example 13: _maybe_convert_labels

def _maybe_convert_labels(y_true):
  """Converts binary labels into -1/1."""
  are_zeros = math_ops.equal(y_true, 0)
  are_ones = math_ops.equal(y_true, 1)
  is_binary = math_ops.reduce_all(math_ops.logical_or(are_zeros, are_ones))

  def _convert_binary_labels():
    # Convert the binary labels to -1 or 1.
    return 2. * y_true - 1.

  updated_y_true = smart_cond.smart_cond(is_binary,
                                         _convert_binary_labels, lambda: y_true)
  return updated_y_true
Developer ID: adit-chandra, Project: tensorflow, Lines of code: 13, Source file: losses.py
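`_maybe_convert_labels` maps {0, 1} targets onto {-1, 1} via `2 * y - 1`, the label convention that hinge-style losses expect, and leaves the labels alone if any value falls outside {0, 1}. Here is a plain-Python sketch of the same check on a flat list. The function name is hypothetical, and unlike the original, which makes the choice inside the graph with `smart_cond`, it simply decides eagerly.

```python
def maybe_convert_labels(y_true):
    """Hypothetical list-based sketch of `_maybe_convert_labels`."""
    # Mirrors reduce_all(logical_or(y == 0, y == 1)).
    is_binary = all(y == 0 or y == 1 for y in y_true)
    if is_binary:
        # {0, 1} -> {-1, 1}
        return [2.0 * y - 1.0 for y in y_true]
    return list(y_true)


print(maybe_convert_labels([0, 1, 1]))  # [-1.0, 1.0, 1.0]
print(maybe_convert_labels([-1, 1]))    # [-1, 1] (already in -1/1 form)
```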


Example 14: result

  def result(self, write_summary=True):
    """Returns the result of the Metric.

    Args:
      write_summary: bool indicating whether to feed the result to the summary
        before returning.
    Returns:
      aggregated metric as float.
    Raises:
      ValueError: if the optional argument is not bool
    """
    # Convert the boolean to tensor for tf.cond, if it is not.
    if not isinstance(write_summary, ops.Tensor):
      write_summary = ops.convert_to_tensor(write_summary)
    t = self.numer / self.denom
    def write_summary_f():
      summary_ops.scalar(name=self.name, tensor=t)
      return t
    smart_cond.smart_cond(write_summary,
                          write_summary_f,
                          lambda: t,
                          name="")
    return t
Developer ID: ahmedsaiduk, Project: tensorflow, Lines of code: 23, Source file: metrics_impl.py


Example 15: write_raw_pb

def write_raw_pb(tensor, step=None, name=None):
  """Writes a summary using raw `tf.compat.v1.Summary` protocol buffers.

  Experimental: this exists to support the usage of V1-style manual summary
  writing (via the construction of a `tf.compat.v1.Summary` protocol buffer)
  with the V2 summary writing API.

  Args:
    tensor: the string Tensor holding one or more serialized `Summary` protobufs
    step: Explicit `int64`-castable monotonic step value for this summary. If
      omitted, this defaults to `tf.summary.experimental.get_step()`, which must
      not be None.
    name: Optional string name for this op.

  Returns:
    True on success, or False if no summary was written because no default
    summary writer was available.

  Raises:
    ValueError: if a default writer exists, but no step was provided and
      `tf.summary.experimental.get_step()` is None.
  """
  with ops.name_scope(name, "write_raw_pb") as scope:
    if context.context().summary_writer is None:
      return constant_op.constant(False)
    if step is None:
      step = get_step()
      if step is None:
        raise ValueError("No step set via 'step' argument or "
                         "tf.summary.experimental.set_step()")

    def record():
      """Record the actual summary and return True."""
      # Note the identity to move the tensor to the CPU.
      with ops.device("cpu:0"):
        raw_summary_op = gen_summary_ops.write_raw_proto_summary(
            context.context().summary_writer._resource,  # pylint: disable=protected-access
            step,
            array_ops.identity(tensor),
            name=scope)
        with ops.control_dependencies([raw_summary_op]):
          return constant_op.constant(True)

    with ops.device("cpu:0"):
      op = smart_cond.smart_cond(
          _should_record_summaries_v2(), record, _nothing, name="summary_cond")
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op)  # pylint: disable=protected-access
      return op
Developer ID: aritratony, Project: tensorflow, Lines of code: 49, Source file: summary_ops_v2.py


Example 16: _expand_and_maybe_replace

  def _expand_and_maybe_replace():
    """Performs the expansion step."""
    expanded = face_centroid + expansion * (reflected - face_centroid)
    expanded_objective_value = objective_function(expanded)
    expanded_is_better = (expanded_objective_value <
                          objective_at_reflected)
    accept_expanded_fn = lambda: (expanded, expanded_objective_value)
    accept_reflected_fn = lambda: (reflected, objective_at_reflected)
    next_pt, next_objective_value = smart_cond.smart_cond(
        expanded_is_better, accept_expanded_fn, accept_reflected_fn)
    next_simplex = _replace_at_index(simplex, worst_index, next_pt)
    next_objective_at_simplex = _replace_at_index(objective_values,
                                                  worst_index,
                                                  next_objective_value)
    return False, next_simplex, next_objective_at_simplex, 1
Developer ID: lewisKit, Project: probability, Lines of code: 15, Source file: nelder_mead.py
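The expansion step pushes the reflected point further from the face centroid, `centroid + expansion * (reflected - centroid)`, and keeps whichever of the expanded and reflected points has the lower objective value; `smart_cond` makes that choice without Python control flow on tensors. A pure-Python sketch, assuming points are lists of floats and the common default expansion coefficient of 2 (`expansion_step` is a hypothetical name):

```python
def expansion_step(face_centroid, reflected, objective_at_reflected, objective,
                   expansion=2.0):
    """Hypothetical pure-Python version of the expansion choice."""
    # Push the reflected point further away from the face centroid.
    expanded = [c + expansion * (r - c) for c, r in zip(face_centroid, reflected)]
    expanded_value = objective(expanded)
    if expanded_value < objective_at_reflected:
        return expanded, expanded_value          # accept_expanded_fn
    return reflected, objective_at_reflected     # accept_reflected_fn


def f(p):
    return (p[0] - 3.0) ** 2


print(expansion_step([0.0], [1.0], f([1.0]), f))  # ([2.0], 1.0)
```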


Example 17: call

  def call(self, x, training=None):
    # We basically want to call this...
    f = functools.partial(self._func, x, **self._arguments)
    # ...but we may also have to pass a Python boolean for `training`.
    if not self._func_wants_training:
      result = f()
    else:
      if training is None:
        training = tf.keras.backend.learning_phase()  # Could be a tensor.
      result = smart_cond.smart_cond(training,
                                     lambda: f(training=True),
                                     lambda: f(training=False))
    # TODO(b/124219898): Polymorphic function should return shaped tensor.
    if hasattr(self, '_output_shape'):
      result.set_shape((x.shape[0],) + self._output_shape)
    return result
Developer ID: Albert-Z-Guo, Project: tensorflow, Lines of code: 16, Source file: util.py
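`call` forwards the `training` flag only when the wrapped function accepts one, falling back to the Keras learning phase when the caller leaves it as `None`; because the learning phase may be a tensor, `smart_cond` is used instead of a plain `if`. The sketch below reproduces the dispatch with plain Python values, so the choice is always static. `call_with_training` and its `learning_phase` and `func_wants_training` parameters are hypothetical stand-ins for the layer method, `tf.keras.backend.learning_phase()` and `self._func_wants_training`, and `scale` is a toy function playing the role of `self._func`.

```python
import functools


def call_with_training(func, x, training=None, learning_phase=lambda: False,
                       func_wants_training=True, **arguments):
    """Plain-Python model of the `training` dispatch in `call` (hypothetical)."""
    f = functools.partial(func, x, **arguments)
    if not func_wants_training:
        return f()
    if training is None:
        # Fall back to a global default, like the Keras learning phase.
        training = learning_phase()
    # With plain Python values the choice is static, as in smart_cond on a bool.
    return f(training=True) if training else f(training=False)


def scale(x, factor=1.0, training=False):
    # Toy function that behaves differently during training.
    return x * factor * (0.5 if training else 1.0)


print(call_with_training(scale, 4.0, training=True))  # 2.0
print(call_with_training(scale, 4.0, factor=3.0))      # 12.0
```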


Example 18: write

def write(tag, tensor, step, metadata=None, name=None):
  """Writes a generic summary to the default SummaryWriter if one exists.

  This exists primarily to support the definition of type-specific summary ops
  like scalar() and image(), and is not intended for direct use unless defining
  a new type-specific summary op.

  Args:
    tag: string tag used to identify the summary (e.g. in TensorBoard), usually
      generated with `tf.summary.summary_scope`
    tensor: the Tensor holding the summary data to write
    step: `int64`-castable monotonic step value for this summary
    metadata: Optional SummaryMetadata, as a proto or serialized bytes
    name: Optional string name for this op.

  Returns:
    True on success, or False if no summary was written because no default
    summary writer was available.
  """
  with ops.name_scope(name, "write_summary") as scope:
    if context.context().summary_writer_resource is None:
      return constant_op.constant(False)
    if metadata is None:
      serialized_metadata = constant_op.constant(b"")
    elif hasattr(metadata, "SerializeToString"):
      serialized_metadata = constant_op.constant(metadata.SerializeToString())
    else:
      serialized_metadata = metadata

    def record():
      """Record the actual summary and return True."""
      # Note the identity to move the tensor to the CPU.
      with ops.device("cpu:0"):
        write_summary_op = gen_summary_ops.write_summary(
            context.context().summary_writer_resource,
            step,
            array_ops.identity(tensor),
            tag,
            serialized_metadata,
            name=scope)
        with ops.control_dependencies([write_summary_op]):
          return constant_op.constant(True)

    return smart_cond.smart_cond(
        _should_record_summaries_v2(), record, _nothing, name="summary_cond")
Developer ID: Wajih-O, Project: tensorflow, Lines of code: 45, Source file: summary_ops_v2.py


Example 19: _apply_gradients_cross_replica

  def _apply_gradients_cross_replica(self, distribution, grads_and_vars, name):
    grads = [g for g, _ in grads_and_vars]
    loss_scale_update_op, should_apply_grads = self._loss_scale.update(grads)

    def apply_fn():
      # We do not want DistributionStrategy to unwrap any MirroredVariables in
      # grads_and_vars, because even in a replica context, the wrapped optimizer
      # expects mirrored variables. So we wrap grads_and_vars with an
      # _UnwrapPreventer, preventing DistributionStrategy from unwrapping the
      # MirroredVariables.
      wrapped_grads_and_vars = _UnwrapPreventer(grads_and_vars)
      return distribution.extended.call_for_each_replica(
          self._apply_gradients, args=(wrapped_grads_and_vars, name))

    # Note: We must call this cond() in a cross-replica context.
    # DistributionStrategy does not support having a cond in a replica context
    # with a branch that calls `merge_call`, and self._optimizer.apply_gradients
    # calls `merge_call`.
    maybe_apply_op = smart_cond.smart_cond(should_apply_grads,
                                           apply_fn,
                                           control_flow_ops.no_op)
    return control_flow_ops.group(maybe_apply_op, loss_scale_update_op)
Developer ID: adit-chandra, Project: tensorflow, Lines of code: 22, Source file: loss_scale_optimizer.py
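`self._loss_scale.update(grads)` returns both an update op and a `should_apply_grads` flag, and `smart_cond` then either applies the gradients or runs `no_op`. With a dynamic loss scale, the flag is typically false when some gradient has overflowed to inf or NaN, in which case the scale is lowered; after a run of finite steps it is raised again. The class below is a hypothetical pure-Python model of that policy, not TensorFlow's loss-scale implementation, and its defaults (initial scale 2**15, increase every 2000 finite steps, factor 2) are assumptions.

```python
import math


class DynamicLossScaleSketch:
    """Hypothetical pure-Python model of a dynamic loss-scaling policy."""

    def __init__(self, initial_scale=2.0 ** 15, increment_period=2000,
                 multiplier=2.0):
        self.scale = initial_scale
        self.increment_period = increment_period
        self.multiplier = multiplier
        self.good_steps = 0

    def update(self, grads):
        """Adjusts the scale; returns True iff all gradients are finite."""
        if all(math.isfinite(g) for g in grads):
            self.good_steps += 1
            if self.good_steps >= self.increment_period:
                # A long run of finite steps: try a larger scale.
                self.scale *= self.multiplier
                self.good_steps = 0
            return True
        # Overflow or NaN: shrink the scale and skip this update.
        self.scale /= self.multiplier
        self.good_steps = 0
        return False


def apply_gradients(loss_scale, grads, apply_fn):
    """Applies grads only when the loss scale says they are usable."""
    should_apply_grads = loss_scale.update(grads)
    # Plays the role of smart_cond(should_apply_grads, apply_fn, no_op).
    if should_apply_grads:
        apply_fn(grads)
    return should_apply_grads


applied = []
ls = DynamicLossScaleSketch(initial_scale=8.0, increment_period=2)
apply_gradients(ls, [1.0, float("inf")], applied.append)  # skipped, scale 4.0
apply_gradients(ls, [1.0, 2.0], applied.append)           # applied
apply_gradients(ls, [0.5, 0.5], applied.append)           # applied, scale 8.0
```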


Example 20: smart_cond

def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.

  If `pred` is a bool or has a constant value, we return either `true_fn()`
  or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.

  Arguments:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  if isinstance(pred, variables.Variable):
    return control_flow_ops.cond(
        pred, true_fn=true_fn, false_fn=false_fn, name=name)
  return smart_module.smart_cond(
      pred, true_fn=true_fn, false_fn=false_fn, name=name)
Developer ID: kylin9872, Project: tensorflow, Lines of code: 24, Source file: tf_utils.py
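Example 20 sends a `variables.Variable` predicate straight to `control_flow_ops.cond` rather than through `smart_module.smart_cond`. One plausible reason for this routing: a variable can be reassigned after the conditional is built, so treating its current value as a constant would bake in a stale branch choice. The toy model below illustrates that difference with a mutable Python cell. `Variable`, `deferred_cond` (standing in for `tf.cond`), `folding_cond` (a deliberately naive constant folder) and `keras_smart_cond` are all hypothetical, and the sketch does not claim to capture every detail of the TensorFlow source.

```python
class Variable:
    """Hypothetical mutable cell standing in for a boolean tf.Variable."""

    def __init__(self, value):
        self.value = value


def deferred_cond(pred, true_fn, false_fn):
    # Stand-in for tf.cond: returns a thunk that reads `pred` only when run.
    return lambda: true_fn() if pred.value else false_fn()


def folding_cond(pred, true_fn, false_fn):
    # Deliberately naive folder: reads whatever value is available right now
    # and fixes the branch choice at construction time.
    value = pred.value if isinstance(pred, Variable) else pred
    result = true_fn() if value else false_fn()
    return lambda: result


def keras_smart_cond(pred, true_fn, false_fn):
    # Mirrors Example 20: variables always take the dynamic path.
    if isinstance(pred, Variable):
        return deferred_cond(pred, true_fn, false_fn)
    return folding_cond(pred, true_fn, false_fn)


training = Variable(True)
folded = folding_cond(training, lambda: "train", lambda: "infer")
routed = keras_smart_cond(training, lambda: "train", lambda: "infer")
training.value = False  # the variable is reassigned after "graph construction"
print(folded())  # train  (stale branch choice)
print(routed())  # infer  (reads the current value)
```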



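Note: The tensorflow.python.framework.smart_cond.smart_cond examples in this article were compiled by 纯净天空 from source code and documentation platforms such as GitHub and MSDocs. The snippets are taken from open-source projects contributed by various developers, and copyright remains with the original authors. Consult the License of the corresponding project before distributing or using the code; do not republish without permission.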

