• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python train.gan_loss函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.contrib.gan.python.train.gan_loss函数的典型用法代码示例。如果您正苦于以下问题：Python gan_loss函数的具体用法？Python gan_loss怎么用？Python gan_loss使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了gan_loss函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: _test_acgan_helper

  def _test_acgan_helper(self, create_gan_model_fn):
    """Checks gan_loss with and without the auxiliary-conditioning terms.

    Builds one model and three losses from it: the plain loss, one with
    `aux_cond_generator_weight=1.0` and one with
    `aux_cond_discriminator_weight=1.0`. It then evaluates all generator and
    discriminator losses and checks their types and scalar-ness.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model
        under test.
    """
    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    loss_ac_gen = train.gan_loss(model, aux_cond_generator_weight=1.0)
    loss_ac_dis = train.gan_loss(model, aux_cond_discriminator_weight=1.0)
    # assertIsInstance reports the offending type on failure, unlike
    # assertTrue(isinstance(...)), which only reports "False is not true".
    self.assertIsInstance(loss, namedtuples.GANLoss)
    self.assertIsInstance(loss_ac_gen, namedtuples.GANLoss)
    self.assertIsInstance(loss_ac_dis, namedtuples.GANLoss)

    # Check values.
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run(
          [loss.generator_loss,
           loss_ac_gen.generator_loss,
           loss_ac_dis.generator_loss])
      loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run(
          [loss.discriminator_loss,
           loss_ac_gen.discriminator_loss,
           loss_ac_dis.discriminator_loss])

    # assertLess prints both operands on failure.
    self.assertLess(loss_gen_np, loss_dis_np)
    self.assertTrue(np.isscalar(loss_ac_gen_gen_np))
    self.assertTrue(np.isscalar(loss_ac_dis_gen_np))
    self.assertTrue(np.isscalar(loss_ac_gen_dis_np))
    self.assertTrue(np.isscalar(loss_ac_dis_dis_np))
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:26,代码来源:train_test.py


示例2: test_doesnt_crash_when_in_nested_scope

  def test_doesnt_crash_when_in_nested_scope(self):
    """gan_loss with a gradient penalty builds inside and outside a scope.

    The model is created under the variable scope 'outer_scope'. The
    penalized loss is then built once while that scope is still open and once
    after it has been exited. The test passes if neither call raises.
    """
    penalty_weight = 1.0
    with variable_scope.variable_scope('outer_scope'):
      nested_model = train.gan_model(
          generator_model,
          discriminator_model,
          real_data=array_ops.zeros([1, 2]),
          generator_inputs=random_ops.random_normal([1, 2]))

      # Loss construction while the model's scope is still active.
      train.gan_loss(nested_model, gradient_penalty_weight=penalty_weight)

    # Loss construction after that scope has been exited.
    train.gan_loss(nested_model, gradient_penalty_weight=penalty_weight)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:13,代码来源:train_test.py


示例3: _test_tensor_pool_helper

  def _test_tensor_pool_helper(self, create_gan_model_fn):
    """Builds gan_loss with a size-5 tensor pool and evaluates it repeatedly.

    For InfoGAN models the pool function receives a
    `(generated_data, generator_inputs)` pair. The second element is
    concatenated onto a list here, so it is expected to be a list. The pair
    is flattened for pooling and split back apart afterwards. Other models
    pool their input values as-is. No loss values are asserted; the test
    only checks that building and running the losses does not raise.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model
        under test.
    """
    model = create_gan_model_fn()
    if isinstance(model, namedtuples.InfoGANModel):

      def tensor_pool_fn(input_values):
        generated_data, generator_inputs = input_values
        output_values = random_tensor_pool.tensor_pool(
            [generated_data] + generator_inputs, pool_size=5)
        # Element 0 is the pooled generated data; the rest are the inputs.
        return output_values[0], output_values[1:]

    else:

      def tensor_pool_fn(input_values):
        return random_tensor_pool.tensor_pool(input_values, pool_size=5)

    loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
    self.assertIsInstance(loss, namedtuples.GANLoss)

    # Check values.
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      for _ in range(10):
        sess.run([loss.generator_loss, loss.discriminator_loss])
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:25,代码来源:train_test.py


示例4: _test_grad_penalty_helper

  def _test_grad_penalty_helper(self, create_gan_model_fn):
    """Checks that a gradient penalty only changes the discriminator loss.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model
        under test.
    """
    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    loss_gp = train.gan_loss(model, gradient_penalty_weight=1.0)
    self.assertIsInstance(loss_gp, namedtuples.GANLoss)

    # Check values.
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      # Each plain/penalized pair is fetched in a single run so both values
      # of a comparison are computed in the same step.
      loss_gen_np, loss_gen_gp_np = sess.run(
          [loss.generator_loss, loss_gp.generator_loss])
      loss_dis_np, loss_dis_gp_np = sess.run(
          [loss.discriminator_loss, loss_gp.discriminator_loss])

    # The penalty must leave the generator loss untouched and strictly
    # increase the discriminator loss.
    self.assertEqual(loss_gen_np, loss_gen_gp_np)
    self.assertLess(loss_dis_np, loss_dis_gp_np)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:16,代码来源:train_test.py


示例5: test_train_hooks_exist_in_get_hooks_fn

  def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
    """Both train-hook factories must surface the two sync-optimizer hooks.

    The sequential factory is expected to yield 4 hooks and the joint one 5.
    In each case, exactly two of them must be `_SyncReplicasOptimizerHook`s
    wrapping the generator and discriminator optimizers.
    """
    gan = create_gan_model_fn()
    gan_losses = train.gan_loss(gan)

    gen_optimizer = get_sync_optimizer()
    dis_optimizer = get_sync_optimizer()
    gan_train_ops = train.gan_train_ops(
        gan,
        gan_losses,
        gen_optimizer,
        dis_optimizer,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    expected_optimizers = frozenset((gen_optimizer, dis_optimizer))
    # (hook-factory builder, total number of hooks it should produce)
    factory_cases = (
        (train.get_sequential_train_hooks, 4),
        (train.get_joint_train_hooks, 5),
    )
    for make_hooks_fn, expected_count in factory_cases:
      hooks = make_hooks_fn()(gan_train_ops)
      self.assertLen(hooks, expected_count)
      wrapped = [
          h._sync_optimizer for h in hooks
          if isinstance(h, sync_replicas_optimizer._SyncReplicasOptimizerHook)
      ]
      self.assertLen(wrapped, 2)
      self.assertSetEqual(frozenset(wrapped), expected_optimizers)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:29,代码来源:train_test.py


示例6: test_sync_replicas

  def test_sync_replicas(self, create_gan_model_fn, create_global_step):
    """Checks gan_train_ops when both networks use sync-replica optimizers.

    Builds train ops with two optimizers from `get_sync_optimizer()` and
    verifies three things:
      * no trainable variables were added;
      * exactly two `_SyncReplicasOptimizerHook`s are attached, one per
        optimizer;
      * evaluating either train op leaves the global step unchanged.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model.
      create_global_step: If True, a non-trainable int32 variable
        'custom_gstep' is registered in the GLOBAL_STEP collection before the
        train ops are built.
    """
    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    # Baseline count, compared after gan_train_ops below.
    num_trainable_vars = len(variables_lib.get_trainable_variables())

    if create_global_step:
      gstep = variable_scope.get_variable(
          'custom_gstep', dtype=dtypes.int32, initializer=0, trainable=False)
      ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, gstep)

    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = train.gan_train_ops(
        model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
    self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
    # No new trainable variables should have been added.
    self.assertLen(variables_lib.get_trainable_variables(), num_trainable_vars)

    # Sync hooks should be populated in the GANTrainOps.
    self.assertLen(train_ops.train_hooks, 2)
    for hook in train_ops.train_hooks:
      self.assertIsInstance(
          hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
    # With exactly two hooks, set equality means each optimizer is wrapped
    # by exactly one hook.
    sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks]
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

    # One initial token per optimizer; presumably enough for the single
    # train-op evaluation each one gets below.
    g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
    d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

    # Check that update op is run properly.
    global_step = training_util.get_or_create_global_step()
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()

      g_opt.chief_init_op.run()
      d_opt.chief_init_op.run()

      gstep_before = global_step.eval()

      # Start required queue runner for SyncReplicasOptimizer.
      # NOTE(review): create_threads is called without start=True. If that
      # argument defaults to False, these threads are only created and
      # registered with `coord`, never started. Confirm this is intended.
      coord = coordinator.Coordinator()
      g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
      d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)

      g_sync_init_op.run()
      d_sync_init_op.run()

      train_ops.generator_train_op.eval()
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, global_step.eval())

      train_ops.discriminator_train_op.eval()
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, global_step.eval())

      coord.request_stop()
      coord.join(g_threads + d_threads)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:58,代码来源:train_test.py


示例7: test_discriminator_only_sees_pool

 def test_discriminator_only_sees_pool(self):
   """Checks that discriminator only sees pooled values.

   The generator emits a constant while the pool function returns random
   tensors. The checking discriminator asserts that its input is not a
   constant, so raw generator output reaching it would fail the test.
   """
   def checker_gen_fn(_):
     return constant_op.constant(0.0)

   def tensor_pool_fn(_):
     # Ignore the incoming values and return fresh, non-constant tensors.
     return (random_ops.random_uniform([]), random_ops.random_uniform([]))

   def checker_dis_fn(inputs, _):
     """Discriminator that checks that it only sees pooled Tensors."""
     self.assertFalse(constant_op.is_constant(inputs))
     return inputs

   base_model = train.gan_model(
       checker_gen_fn,
       discriminator_model,
       real_data=array_ops.zeros([]),
       generator_inputs=random_ops.random_normal([]))
   # The checker is swapped in only after gan_model has built the model.
   # Presumably this is because gan_model itself already runs the
   # discriminator on the raw constant output.
   checked_model = base_model._replace(discriminator_fn=checker_dis_fn)
   train.gan_loss(checked_model, tensor_pool_fn=tensor_pool_fn)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:18,代码来源:train_test.py


示例8: test_grad_penalty

  def test_grad_penalty(self, create_gan_model_fn, one_sided):
    """Test gradient penalty option.

    Compares a plain loss against one with `gradient_penalty_weight=1.0`
    (`one_sided` selects the one-sided variant). The generator loss must be
    unchanged; the discriminator loss must be strictly larger.
    """
    gan = create_gan_model_fn()
    plain = train.gan_loss(gan)
    penalized = train.gan_loss(
        gan,
        gradient_penalty_weight=1.0,
        gradient_penalty_one_sided=one_sided)
    self.assertIsInstance(penalized, namedtuples.GANLoss)

    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      # Each plain/penalized pair is evaluated in one run.
      gen_plain, gen_penalized = sess.run(
          [plain.generator_loss, penalized.generator_loss])
      dis_plain, dis_penalized = sess.run(
          [plain.discriminator_loss, penalized.discriminator_loss])

    self.assertEqual(gen_plain, gen_penalized)
    self.assertLess(dis_plain, dis_penalized)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:20,代码来源:train_test.py


示例9: test_tensor_pool

  def test_tensor_pool(self, create_gan_model_fn):
    """Test tensor pool option.

    Builds the loss with a size-5 tensor pool and evaluates both losses ten
    times. Only checks that this runs without raising.
    """
    def tensor_pool_fn(input_values):
      return random_tensor_pool.tensor_pool(input_values, pool_size=5)

    gan = create_gan_model_fn()
    pooled_loss = train.gan_loss(gan, tensor_pool_fn=tensor_pool_fn)
    self.assertIsInstance(pooled_loss, namedtuples.GANLoss)

    fetches = [pooled_loss.generator_loss, pooled_loss.discriminator_loss]
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      for _ in range(10):
        sess.run(fetches)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:12,代码来源:train_test.py


示例10: _test_regularization_helper

  def _test_regularization_helper(self, get_gan_model_fn):
    """Checks that scoped regularization losses are added to the GAN losses.

    Two constants are added to REGULARIZATION_LOSSES: 3.0 under the
    generator's scope name and 2.0 under the discriminator's. The regularized
    generator and discriminator losses must then exceed the unregularized
    ones by those amounts.

    Args:
      get_gan_model_fn: Zero-argument callable returning the GAN model. It is
        called several times; the scope names it reports are presumably the
        same on every call (TODO confirm against the caller).
    """
    # Evaluate losses without regularization.
    no_reg_loss = train.gan_loss(get_gan_model_fn())
    with self.test_session(use_gpu=True):
      no_reg_loss_gen_np = no_reg_loss.generator_loss.eval()
      no_reg_loss_dis_np = no_reg_loss.discriminator_loss.eval()

    with ops.name_scope(get_gan_model_fn().generator_scope.name):
      ops.add_to_collection(
          ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0))
    with ops.name_scope(get_gan_model_fn().discriminator_scope.name):
      ops.add_to_collection(
          ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0))

    # Check that losses now include the correct regularization values.
    reg_loss = train.gan_loss(get_gan_model_fn())
    with self.test_session(use_gpu=True):
      reg_loss_gen_np = reg_loss.generator_loss.eval()
      reg_loss_dis_np = reg_loss.discriminator_loss.eval()

    # assertTrue(3.0, x) treats `x` as the failure message and always passes,
    # so compare the differences explicitly. The discriminator was given 2.0,
    # not 3.0. The losses are floats, so compare approximately.
    self.assertAlmostEqual(
        3.0, reg_loss_gen_np - no_reg_loss_gen_np, places=5)
    self.assertAlmostEqual(
        2.0, reg_loss_dis_np - no_reg_loss_dis_np, places=5)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:22,代码来源:train_test.py


示例11: _test_tensor_pool_helper

  def _test_tensor_pool_helper(self, create_gan_model_fn):
    """Builds gan_loss with a size-5 tensor pool and evaluates it repeatedly.

    InfoGAN models get the InfoGAN-specific pool function; all other models
    get the generic one. No loss values are asserted; the test only checks
    that building and running the losses does not raise.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model
        under test.
    """
    model = create_gan_model_fn()
    if isinstance(model, namedtuples.InfoGANModel):
      tensor_pool_fn = get_tensor_pool_fn_for_infogan(pool_size=5)
    else:
      tensor_pool_fn = get_tensor_pool_fn(pool_size=5)
    loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(loss, namedtuples.GANLoss)

    # Check values.
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      for _ in range(10):
        sess.run([loss.generator_loss, loss.discriminator_loss])
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:14,代码来源:train_test.py


示例12: _test_run_helper

  def _test_run_helper(self, create_gan_model_fn):
    """Trains a GAN end to end for two steps and checks the returned step.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model
        under test.
    """
    random_seed.set_random_seed(1234)
    gan = create_gan_model_fn()
    gan_losses = train.gan_loss(gan)

    learning_rate = 1.0
    gen_optimizer = gradient_descent.GradientDescentOptimizer(learning_rate)
    dis_optimizer = gradient_descent.GradientDescentOptimizer(learning_rate)
    ops_to_train = train.gan_train_ops(
        gan, gan_losses, gen_optimizer, dis_optimizer)

    stop_hook = basic_session_run_hooks.StopAtStepHook(num_steps=2)
    last_step = train.gan_train(ops_to_train, logdir='', hooks=[stop_hook])
    self.assertTrue(np.isscalar(last_step))
    self.assertEqual(2, last_step)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:15,代码来源:train_test.py


示例13: _test_output_type_helper

  def _test_output_type_helper(self, create_gan_model_fn):
    """Checks that gan_train_ops returns a GANTrainOps namedtuple.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model
        under test.
    """
    model = create_gan_model_fn()
    loss = train.gan_loss(model)

    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
    d_opt = gradient_descent.GradientDescentOptimizer(1.0)
    train_ops = train.gan_train_ops(
        model,
        loss,
        g_opt,
        d_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:15,代码来源:train_test.py


示例14: test_patchgan

  def test_patchgan(self, create_gan_model_fn):
    """Ensure that patch-based discriminators work end-to-end.

    Seeds the graph, trains for two steps, and checks that gan_train returns
    the scalar final step 2.
    """
    random_seed.set_random_seed(1234)
    patch_model = create_gan_model_fn()
    patch_loss = train.gan_loss(patch_model)

    gen_sgd = gradient_descent.GradientDescentOptimizer(1.0)
    dis_sgd = gradient_descent.GradientDescentOptimizer(1.0)
    patch_train_ops = train.gan_train_ops(
        patch_model, patch_loss, gen_sgd, dis_sgd)

    num_steps = 2
    final_step = train.gan_train(
        patch_train_ops,
        logdir='',
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=num_steps)])
    self.assertTrue(np.isscalar(final_step))
    self.assertEqual(num_steps, final_step)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:16,代码来源:train_test.py


示例15: test_unused_update_ops

  def test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
    """Checks how gan_train_ops handles update ops that neither network owns.

    Registers one counter-increment update op in the generator scope and one
    in the discriminator scope, plus a stray constant op outside both. The
    test then checks two things:
      * gan_train_ops raises ValueError when
        `check_for_unused_update_ops=True`;
      * with the check disabled, evaluating each train op increments only
        its own network's counter.

    Args:
      create_gan_model_fn: Zero-argument callable returning the GAN model.
      provide_update_ops: If True, all three ops are passed explicitly via
        the `update_ops` kwarg. Otherwise the stray op is only added to the
        UPDATE_OPS collection (presumably read by gan_train_ops by default).
    """
    model = create_gan_model_fn()
    loss = train.gan_loss(model)

    # Add generator and discriminator update ops.
    with variable_scope.variable_scope(model.generator_scope):
      gen_update_count = variable_scope.get_variable('gen_count', initializer=0)
      gen_update_op = gen_update_count.assign_add(1)
      ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, gen_update_op)
    with variable_scope.variable_scope(model.discriminator_scope):
      dis_update_count = variable_scope.get_variable('dis_count', initializer=0)
      dis_update_op = dis_update_count.assign_add(1)
      ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, dis_update_op)

    # Add an update op outside the generator and discriminator scopes.
    if provide_update_ops:
      kwargs = {
          'update_ops': [
              constant_op.constant(1.0), gen_update_op, dis_update_op
          ]
      }
    else:
      ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, constant_op.constant(1.0))
      kwargs = {}

    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
    d_opt = gradient_descent.GradientDescentOptimizer(1.0)

    # With the check enabled, the out-of-scope op is presumably what gets
    # reported as unused.
    with self.assertRaisesRegexp(ValueError, 'There are unused update ops:'):
      train.gan_train_ops(
          model, loss, g_opt, d_opt, check_for_unused_update_ops=True, **kwargs)
    # Same inputs with the check disabled must build successfully.
    train_ops = train.gan_train_ops(
        model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs)

    with self.test_session(use_gpu=True) as sess:
      sess.run(variables.global_variables_initializer())
      self.assertEqual(0, gen_update_count.eval())
      self.assertEqual(0, dis_update_count.eval())

      # The generator train op must run only the generator's update op.
      train_ops.generator_train_op.eval()
      self.assertEqual(1, gen_update_count.eval())
      self.assertEqual(0, dis_update_count.eval())

      # The discriminator train op must run only the discriminator's.
      train_ops.discriminator_train_op.eval()
      self.assertEqual(1, gen_update_count.eval())
      self.assertEqual(1, dis_update_count.eval())
开发者ID:AnishShah,项目名称:tensorflow,代码行数:46,代码来源:train_test.py


示例16: test_output_type

  def test_output_type(self, create_gan_model_fn):
    """gan_train_ops returns GANTrainOps and populates no training hooks.

    Uses plain GradientDescentOptimizers, so the `train_hooks` field is
    expected to stay empty.
    """
    gan = create_gan_model_fn()
    gan_losses = train.gan_loss(gan)

    gen_optimizer = gradient_descent.GradientDescentOptimizer(1.0)
    dis_optimizer = gradient_descent.GradientDescentOptimizer(1.0)
    result = train.gan_train_ops(
        gan,
        gan_losses,
        gen_optimizer,
        dis_optimizer,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    self.assertIsInstance(result, namedtuples.GANTrainOps)
    # No hooks should have been added accidentally.
    self.assertEmpty(result.train_hooks)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:18,代码来源:train_test.py


示例17: test_is_chief_in_train_hooks

  def test_is_chief_in_train_hooks(self, is_chief):
    """The `is_chief` flag must reach both sync-replica training hooks."""
    gan = create_gan_model()
    gan_losses = train.gan_loss(gan)
    gen_optimizer = get_sync_optimizer()
    dis_optimizer = get_sync_optimizer()
    result = train.gan_train_ops(
        gan,
        gan_losses,
        gen_optimizer,
        dis_optimizer,
        is_chief=is_chief,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    hooks = result.train_hooks
    self.assertLen(hooks, 2)
    for h in hooks:
      self.assertIsInstance(
          h, sync_replicas_optimizer._SyncReplicasOptimizerHook)
    self.assertListEqual([h._is_chief for h in hooks], [is_chief] * 2)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:21,代码来源:train_test.py


示例18: test_output_type

 def test_output_type(self, get_gan_model_fn):
   """gan_loss returns a GANLoss and, with add_summaries, adds summaries.

   Passes if at least one op lands in the SUMMARIES collection.
   """
   gan_losses = train.gan_loss(get_gan_model_fn(), add_summaries=True)
   self.assertIsInstance(gan_losses, namedtuples.GANLoss)
   summary_ops = ops.get_collection(ops.GraphKeys.SUMMARIES)
   self.assertGreater(len(summary_ops), 0)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:5,代码来源:train_test.py


示例19: test_mutual_info_penalty

 def test_mutual_info_penalty(self, create_gan_model_fn):
   """Test mutual information penalty option.

   Smoke test: building the loss with a tensor-valued
   `mutual_information_penalty_weight` must not raise. No values are checked.
   """
   model_under_test = create_gan_model_fn()
   penalty_weight = constant_op.constant(1.0)
   train.gan_loss(
       model_under_test, mutual_information_penalty_weight=penalty_weight)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:5,代码来源:train_test.py


示例20: _test_mutual_info_penalty_helper

 def _test_mutual_info_penalty_helper(self, create_gan_model_fn):
   """Builds gan_loss with a constant-tensor mutual-information weight.

   Passes as long as loss construction does not raise.
   """
   model_under_test = create_gan_model_fn()
   weight = constant_op.constant(1.0)
   train.gan_loss(model_under_test, mutual_information_penalty_weight=weight)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:3,代码来源:train_test.py



注：本文中的tensorflow.contrib.gan.python.train.gan_loss函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python train.gan_model函数代码示例 发布时间：2022-05-27
下一篇:
Python variables.model_variable函数代码示例 发布时间：2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap