本文整理汇总了Python中tensorflow.contrib.framework.python.ops.variables.get_trainable_variables函数的典型用法代码示例。如果您正苦于以下问题：Python get_trainable_variables函数的具体用法？Python get_trainable_variables怎么用？Python get_trainable_variables使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了get_trainable_variables函数的9个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: test_sync_replicas
def test_sync_replicas(self, create_gan_model_fn, create_global_step):
  """Checks GAN train ops built with sync-replicas optimizers.

  Builds a GAN model and its loss, optionally registers a custom global step,
  and constructs train ops using one sync-replicas optimizer for the generator
  and one for the discriminator. Verifies that:
    * no new trainable variables are created by building the train ops,
    * exactly one `_SyncReplicasOptimizerHook` per optimizer is attached,
    * running each train op once leaves the global step unchanged.

  Args:
    create_gan_model_fn: Zero-argument callable returning the GAN model.
    create_global_step: If truthy, create an int32 `custom_gstep` variable and
      add it to the GLOBAL_STEP collection before building the train ops.
  """
  model = create_gan_model_fn()
  loss = train.gan_loss(model)
  trainable_count_before = len(variables_lib.get_trainable_variables())

  if create_global_step:
    custom_step = variable_scope.get_variable(
        'custom_gstep', dtype=dtypes.int32, initializer=0, trainable=False)
    ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, custom_step)

  gen_opt = get_sync_optimizer()
  dis_opt = get_sync_optimizer()
  train_ops = train.gan_train_ops(
      model, loss, generator_optimizer=gen_opt, discriminator_optimizer=dis_opt)
  self.assertIsInstance(train_ops, namedtuples.GANTrainOps)

  # Building the train ops must not introduce extra trainable variables.
  self.assertLen(variables_lib.get_trainable_variables(),
                 trainable_count_before)

  # One sync hook per optimizer is expected in the returned GANTrainOps.
  hooks = train_ops.train_hooks
  self.assertLen(hooks, 2)
  for hook in hooks:
    self.assertIsInstance(
        hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
  hooked_optimizers = frozenset(h._sync_optimizer for h in hooks)
  self.assertSetEqual(hooked_optimizers, frozenset((gen_opt, dis_opt)))

  gen_tokens_op = gen_opt.get_init_tokens_op(num_tokens=1)
  dis_tokens_op = dis_opt.get_init_tokens_op(num_tokens=1)

  # Run each train op once and confirm the global step is not advanced.
  global_step = training_util.get_or_create_global_step()
  with self.test_session(use_gpu=True) as sess:
    variables.global_variables_initializer().run()
    variables.local_variables_initializer().run()
    gen_opt.chief_init_op.run()
    dis_opt.chief_init_op.run()
    step_before = global_step.eval()

    # Chief queue-runner threads for both optimizers, sharing one coordinator.
    # NOTE(review): `create_threads` is called without `start=True`; if its
    # default does not start the threads, they are only created here and then
    # joined below -- confirm this is intended.
    coord = coordinator.Coordinator()
    runner_threads = (
        gen_opt.get_chief_queue_runner().create_threads(sess, coord) +
        dis_opt.get_chief_queue_runner().create_threads(sess, coord))

    gen_tokens_op.run()
    dis_tokens_op.run()

    for op in (train_ops.generator_train_op, train_ops.discriminator_train_op):
      op.eval()
      # Neither train op should increment the global step.
      self.assertEqual(step_before, global_step.eval())

    coord.request_stop()
    coord.join(runner_threads)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:58,代码来源:train_test.py
示例2: _make_prediction_gan_model
def _make_prediction_gan_model(input_data, input_data_domain_label,
                               generator_fn, generator_scope):
  """Make a `StarGANModel` from just the generator.

  Only the generator subgraph is built. The reconstruction and discriminator
  fields of the returned model, as well as `input_data_domain_label`, are
  `None`.

  Args:
    input_data: Generator input(s); converted with
      `tfgan_train._convert_tensor_or_l_or_d` before use.
    input_data_domain_label: Domain label(s) passed to the generator as its
      second argument; also converted with `_convert_tensor_or_l_or_d`, and
      stored as `generated_data_domain_target` on the returned model.
    generator_fn: Callable `(input_data, domain_label) -> generated_data`. If it
      declares a `mode` parameter (positional or keyword-only), it is called
      with `mode=ModeKeys.PREDICT`.
    generator_scope: Variable scope to build the generator in.

  Returns:
    A `StarGANModel` with only the generator-related fields populated.
  """
  # If `generator_fn` has an argument `mode`, pass mode to it.
  # `inspect.getargspec` was removed in Python 3.11 and raises ValueError for
  # functions with keyword-only parameters or annotations, so prefer
  # `getfullargspec` when it exists (falling back for Python 2).
  argspec_fn = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
  argspec = argspec_fn(generator_fn)
  arg_names = list(argspec.args)
  # Keyword-only parameters only exist on Python 3 argspecs.
  arg_names.extend(getattr(argspec, 'kwonlyargs', None) or [])
  if 'mode' in arg_names:
    generator_fn = functools.partial(
        generator_fn, mode=model_fn_lib.ModeKeys.PREDICT)

  with variable_scope.variable_scope(generator_scope) as gen_scope:
    # pylint:disable=protected-access
    input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
    input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
        input_data_domain_label)
    # pylint:enable=protected-access
    generated_data = generator_fn(input_data, input_data_domain_label)
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.StarGANModel(
      input_data=input_data,
      input_data_domain_label=None,
      generated_data=generated_data,
      generated_data_domain_target=input_data_domain_label,
      reconstructed_data=None,
      discriminator_input_data_source_predication=None,
      discriminator_generated_data_source_predication=None,
      discriminator_input_data_domain_predication=None,
      discriminator_generated_data_domain_predication=None,
      generator_variables=generator_variables,
      # Store the captured VariableScope (as the GANModel variant and
      # `stargan_model` do) rather than the raw, possibly-string argument.
      generator_scope=gen_scope,
      generator_fn=generator_fn,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
开发者ID:ahmedsaiduk,项目名称:tensorflow,代码行数:32,代码来源:stargan_estimator_impl.py
示例3: _make_prediction_gan_model
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
  """Build a prediction-only `GANModel` that contains just the generator.

  Args:
    generator_inputs: Generator input(s); converted with
      `tfgan_train._convert_tensor_or_l_or_d` before being passed on.
    generator_fn: Callable mapping the converted inputs to generated data.
    generator_scope: Variable scope to build the generator in.

  Returns:
    A `GANModel` whose real-data and discriminator fields are all `None`.
  """
  with variable_scope.variable_scope(generator_scope) as scope:
    # pylint:disable=protected-access
    converted_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)
    # pylint:enable=protected-access
    outputs = generator_fn(converted_inputs)
  trainable = variable_lib.get_trainable_variables(scope)

  # Nothing discriminator-related exists at prediction time.
  no_discriminator = dict(
      real_data=None,
      discriminator_real_outputs=None,
      discriminator_gen_outputs=None,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
  return tfgan_tuples.GANModel(
      converted_inputs, outputs, trainable, scope, generator_fn,
      **no_discriminator)
开发者ID:DjangoPeng,项目名称:tensorflow,代码行数:19,代码来源:gan_estimator_impl.py
示例4: _make_prediction_gan_model
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
  """Make a `GANModel` from just the generator.

  Args:
    generator_inputs: Generator input(s); converted with
      `tfgan_train._convert_tensor_or_l_or_d` before being passed on.
    generator_fn: Callable mapping the converted inputs to generated data. If
      it declares a `mode` parameter (positional or keyword-only), it is
      called with `mode=ModeKeys.PREDICT`.
    generator_scope: Variable scope to build the generator in.

  Returns:
    A `GANModel` whose real-data and discriminator fields are all `None`.
  """
  # If `generator_fn` has an argument `mode`, pass mode to it.
  # `inspect.getargspec` was removed in Python 3.11 and raises ValueError for
  # functions with keyword-only parameters or annotations, so prefer
  # `getfullargspec` when it exists (falling back for Python 2).
  argspec_fn = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
  argspec = argspec_fn(generator_fn)
  arg_names = list(argspec.args)
  # Keyword-only parameters only exist on Python 3 argspecs.
  arg_names.extend(getattr(argspec, 'kwonlyargs', None) or [])
  if 'mode' in arg_names:
    generator_fn = functools.partial(generator_fn,
                                     mode=model_fn_lib.ModeKeys.PREDICT)

  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)  # pylint:disable=protected-access
    generated_data = generator_fn(generator_inputs)
  generator_variables = variable_lib.get_trainable_variables(gen_scope)

  return tfgan_tuples.GANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data=None,
      discriminator_real_outputs=None,
      discriminator_gen_outputs=None,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:23,代码来源:gan_estimator_impl.py
示例5: combine_adversarial_loss
def combine_adversarial_loss(main_loss,
                             adversarial_loss,
                             weight_factor=None,
                             gradient_ratio=None,
                             gradient_ratio_epsilon=1e-6,
                             variables=None,
                             scalar_summaries=True,
                             gradient_summaries=True,
                             scope=None):
  """Combine a main loss with an adversarial loss.

  One of two weighting schemes is used:

  1) `weight_factor`: a fixed coefficient on the adversarial loss.
  2) `gradient_ratio`: the adversarial coefficient is derived from the ratio of
     gradient magnitudes, which is often used to make both losses affect the
     weights roughly equally, as in https://arxiv.org/pdf/1705.05823.

  Scalar summaries of the losses and of their gradient magnitudes can
  optionally be emitted.

  Args:
    main_loss: A floating scalar Tensor holding the main loss.
    adversarial_loss: A floating scalar Tensor holding the adversarial loss.
    weight_factor: If not `None`, the coefficient multiplying the adversarial
      loss. Exactly one of this and `gradient_ratio` must be non-None.
    gradient_ratio: If not `None`, the desired ratio of gradient magnitudes:
        gradient_ratio = grad_mag(main_loss) / grad_mag(adversarial_loss)
      Exactly one of this and `weight_factor` must be non-None.
    gradient_ratio_epsilon: Added to the adversarial gradient magnitude in the
      coefficient's denominator to avoid division by zero.
    variables: Variables to differentiate with respect to. When `None`, all
      trainable variables are used.
    scalar_summaries: Whether to emit scalar summaries of both losses.
    gradient_summaries: Whether to emit summaries of both losses' gradient
      magnitudes.
    scope: Optional name scope.

  Returns:
    A floating scalar Tensor holding the combined loss.

  Raises:
    ValueError: Malformed input.
  """
  _validate_args([main_loss, adversarial_loss], weight_factor, gradient_ratio)
  if variables is None:
    variables = contrib_variables_lib.get_trainable_variables()

  with ops.name_scope(scope, 'adversarial_loss',
                      values=[main_loss, adversarial_loss]):
    # Gradient magnitudes are only needed for the gradient summaries or for the
    # gradient-ratio scheme.
    needs_grad_mags = gradient_summaries or gradient_ratio is not None
    if needs_grad_mags:
      main_grad_mag, adv_grad_mag = [
          _numerically_stable_global_norm(
              gradients_impl.gradients(loss, variables))
          for loss in (main_loss, adversarial_loss)]

    if scalar_summaries:
      summary.scalar('main_loss', main_loss)
      summary.scalar('adversarial_loss', adversarial_loss)
    if gradient_summaries:
      summary.scalar('main_loss_gradients', main_grad_mag)
      summary.scalar('adversarial_loss_gradients', adv_grad_mag)

    # When the weight in use (as reported by `_used_weight`) is exactly 0, the
    # adversarial term is omitted and the main loss is returned as-is.
    if _used_weight((weight_factor, gradient_ratio)) == 0:
      combined = main_loss
    elif weight_factor is not None:
      adv_weight = array_ops.stop_gradient(weight_factor)
      combined = main_loss + adv_weight * adversarial_loss
    elif gradient_ratio is not None:
      # The coefficient is wrapped in stop_gradient so it acts as a constant
      # scale on the adversarial loss.
      grad_mag_ratio = main_grad_mag / (adv_grad_mag + gradient_ratio_epsilon)
      adv_coeff = grad_mag_ratio / gradient_ratio
      summary.scalar('adversarial_coefficient', adv_coeff)
      combined = (main_loss +
                  array_ops.stop_gradient(adv_coeff) * adversarial_loss)
  return combined
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:82,代码来源:losses_impl.py
示例6: gan_model
def gan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Build a GAN's generator and discriminator and collect their variables.

  The generator is built once. The discriminator is first built on the
  generated data, then rebuilt on the real data inside the same scope with
  `reuse=True` so both calls share variables.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the conditional
      GAN case, this might be the generator's conditioning.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that the generator output shape is
      compatible with the shape of `real_data`. Otherwise, skip this check.

  Returns:
    A GANModel namedtuple.

  Raises:
    ValueError: If `check_shapes` is true and the generator output shape is
      not compatible with the shape of `real_data`.
  """
  # Generator.
  with variable_scope.variable_scope(generator_scope) as g_scope:
    g_inputs = _convert_tensor_or_l_or_d(generator_inputs)
    fake = generator_fn(g_inputs)

  # Discriminator on generated data (creates the variables) ...
  with variable_scope.variable_scope(discriminator_scope) as d_scope:
    d_fake = discriminator_fn(fake, g_inputs)
  # ... and on real data, reusing the same variables.
  with variable_scope.variable_scope(d_scope, reuse=True):
    real = ops.convert_to_tensor(real_data)
    d_real = discriminator_fn(real, g_inputs)

  # Short-circuit keeps `.shape` unevaluated when the check is disabled.
  if check_shapes and not fake.shape.is_compatible_with(real.shape):
    raise ValueError(
        'Generator output shape (%s) must be the same shape as real data '
        '(%s).' % (fake.shape, real.shape))

  return namedtuples.GANModel(
      g_inputs,
      fake,
      variables_lib.get_trainable_variables(g_scope),
      g_scope,
      generator_fn,
      real,
      d_real,
      d_fake,
      variables_lib.get_trainable_variables(d_scope),
      d_scope,
      discriminator_fn)
开发者ID:andrewharp,项目名称:tensorflow,代码行数:70,代码来源:train.py
示例7: acgan_model
def acgan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    one_hot_labels,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Return an ACGANModel containing all the pieces needed for ACGAN training.

  This is the same as `gan_model` except that the discriminator also outputs
  logits that classify its input (real or generated). The discriminator must
  therefore return a 2-tuple of (real/fake logits, classification logits), and
  the model also carries `one_hot_labels`. See https://arxiv.org/abs/1610.09585
  for more details.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a tuple consisting of two Tensors:
        (1) real/fake logits in the range [-inf, inf]
        (2) classification logits in the range [-inf, inf]
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the conditional
      GAN case, this might be the generator's conditioning.
    one_hot_labels: A Tensor holding one-hot labels for the batch. Needed by
      acgan_loss.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that the generator output shape is
      compatible with the shape of `real_data`. Otherwise, skip this check.

  Returns:
    An ACGANModel namedtuple.

  Raises:
    ValueError: If `check_shapes` is true and the generator output shape is
      not compatible with the shape of `real_data`.
    TypeError: If the discriminator does not output a tuple consisting of
      (discrimination logits, classification logits).
  """
  # Generator.
  with variable_scope.variable_scope(generator_scope) as g_scope:
    g_inputs = _convert_tensor_or_l_or_d(generator_inputs)
    fake = generator_fn(g_inputs)

  # Discriminator on generated data (creates the variables) ...
  with variable_scope.variable_scope(discriminator_scope) as d_scope:
    d_fake, d_fake_class_logits = _validate_acgan_discriminator_outputs(
        discriminator_fn(fake, g_inputs))
  # ... and on real data, reusing the same variables.
  with variable_scope.variable_scope(d_scope, reuse=True):
    real = ops.convert_to_tensor(real_data)
    d_real, d_real_class_logits = _validate_acgan_discriminator_outputs(
        discriminator_fn(real, g_inputs))

  # Short-circuit keeps `.shape` unevaluated when the check is disabled.
  if check_shapes and not fake.shape.is_compatible_with(real.shape):
    raise ValueError(
        'Generator output shape (%s) must be the same shape as real data '
        '(%s).' % (fake.shape, real.shape))

  return namedtuples.ACGANModel(
      g_inputs, fake, variables_lib.get_trainable_variables(g_scope), g_scope,
      generator_fn, real, d_real, d_fake,
      variables_lib.get_trainable_variables(d_scope), d_scope,
      discriminator_fn, one_hot_labels, d_real_class_logits,
      d_fake_class_logits)
开发者ID:andrewharp,项目名称:tensorflow,代码行数:84,代码来源:train.py
示例8: infogan_model
def infogan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    unstructured_generator_inputs,
    structured_generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Returns an InfoGAN model outputs and variables.

  See https://arxiv.org/abs/1606.03657 for more details.

  Args:
    generator_fn: A python lambda that takes a list of Tensors as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a 2-tuple of (logits, distribution_list).
      `logits` are in the range [-inf, inf], and `distribution_list` is a list
      of Tensorflow distributions representing the predicted noise distribution
      of the ith structure noise.
    real_data: A Tensor representing the real data.
    unstructured_generator_inputs: A list of Tensors to the generator.
      These tensors represent the unstructured noise or conditioning.
    structured_generator_inputs: A list of Tensors to the generator.
      These tensors must have high mutual information with the recognizer.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True` (the default, matching the previous unconditional
      behavior), check that the generator output shape is compatible with the
      shape of `real_data`. Otherwise, skip this check. Mirrors the option of
      the same name on `gan_model` and `acgan_model`.

  Returns:
    An InfoGANModel namedtuple.

  Raises:
    ValueError: If `check_shapes` is true and the generator output shape is
      not compatible with the shape of `real_data`.
    ValueError: If the discriminator output is malformed.
  """
  # Create models
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    unstructured_generator_inputs = _convert_tensor_or_l_or_d(
        unstructured_generator_inputs)
    structured_generator_inputs = _convert_tensor_or_l_or_d(
        structured_generator_inputs)
    # With list inputs (as documented) `+` concatenates: the generator sees the
    # unstructured inputs followed by the structured ones.
    generator_inputs = (
        unstructured_generator_inputs + structured_generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as disc_scope:
    dis_gen_outputs, predicted_distributions = discriminator_fn(
        generated_data, generator_inputs)
  _validate_distributions(predicted_distributions, structured_generator_inputs)
  # Rebuild the discriminator on real data, sharing its variables; the
  # predicted distributions for real data are not needed.
  with variable_scope.variable_scope(disc_scope, reuse=True):
    real_data = ops.convert_to_tensor(real_data)
    dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs)

  if check_shapes:
    if not generated_data.get_shape().is_compatible_with(
        real_data.get_shape()):
      raise ValueError(
          'Generator output shape (%s) must be the same shape as real data '
          '(%s).' % (generated_data.get_shape(), real_data.get_shape()))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(
      disc_scope)

  return namedtuples.InfoGANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data,
      dis_real_outputs,
      dis_gen_outputs,
      discriminator_variables,
      disc_scope,
      lambda x, y: discriminator_fn(x, y)[0],  # conform to non-InfoGAN API
      structured_generator_inputs,
      predicted_distributions,
      discriminator_fn)
开发者ID:andrewharp,项目名称:tensorflow,代码行数:83,代码来源:train.py
示例9: stargan_model
def stargan_model(generator_fn,
                  discriminator_fn,
                  input_data,
                  input_data_domain_label,
                  generator_scope='Generator',
                  discriminator_scope='Discriminator'):
  """Build a StarGAN model and return its outputs and variables.

  See https://arxiv.org/abs/1711.09020 for more details.

  The generator translates `input_data` to randomly drawn target domains and
  then, reusing its variables, translates the result back to the original
  domains. The discriminator is applied to the generated data and, reusing its
  variables, to the input data.

  Args:
    generator_fn: A python lambda that takes `inputs` and `targets` as inputs
      and returns 'generated_data' as the transformed version of `input` based
      on the `target`. `input` has shape (n, h, w, c), `targets` has shape
      (n, num_domains), and `generated_data` has the same shape as `input`.
    discriminator_fn: A python lambda that takes `inputs` and `num_domains` as
      inputs and returns a tuple (`source_prediction`, `domain_prediction`).
      `source_prediction` is the discriminator's real/generated prediction
      with shape (n), and `domain_prediction` is its domain classification
      with shape (n, num_domains).
    input_data: Tensor or a list of tensors of shape (n, h, w, c) representing
      the real input images. Lists/tuples are concatenated along axis 0.
    input_data_domain_label: Tensor or a list of tensors of shape
      (batch_size, num_domains) representing the domain labels of the real
      images. Lists/tuples are concatenated along axis 0.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    A StarGANModel namedtuple holding the tensors needed to compute the loss.

  Raises:
    ValueError: If the shape of `input_data_domain_label` is not rank 2 or is
      not fully defined in every dimension.
  """

  def _concat_if_sequence(value):
    # Concatenate a list/tuple of tensors along axis 0; pass anything else
    # through unchanged.
    if not isinstance(value, (list, tuple)):
      return value
    return array_ops.concat([ops.convert_to_tensor(x) for x in value], 0)

  input_data = _convert_tensor_or_l_or_d(input_data)
  input_data_domain_label = _convert_tensor_or_l_or_d(input_data_domain_label)
  input_data = _concat_if_sequence(input_data)
  input_data_domain_label = _concat_if_sequence(input_data_domain_label)

  # batch_size and num_domains come from the (static) label shape.
  label_shape = input_data_domain_label.shape
  label_shape.assert_has_rank(2)
  label_shape.assert_is_fully_defined()
  batch_size, num_domains = label_shape.as_list()

  # Generator: translate to random target domains ...
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    target_domains = _generate_stargan_random_domain_target(
        batch_size, num_domains)
    generated_data = generator_fn(input_data, target_domains)
  # ... then back to the original domains with the same variables.
  with variable_scope.variable_scope(gen_scope, reuse=True):
    reconstructed_data = generator_fn(generated_data, input_data_domain_label)

  # Discriminator on generated data (creates the variables) ...
  with variable_scope.variable_scope(discriminator_scope) as dis_scope:
    gen_source_pred, gen_domain_pred = discriminator_fn(
        generated_data, num_domains)
  # ... and on the input data, reusing the same variables.
  with variable_scope.variable_scope(dis_scope, reuse=True):
    input_source_pred, input_domain_pred = discriminator_fn(
        input_data, num_domains)

  return namedtuples.StarGANModel(
      input_data=input_data,
      input_data_domain_label=input_data_domain_label,
      generated_data=generated_data,
      generated_data_domain_target=target_domains,
      reconstructed_data=reconstructed_data,
      discriminator_input_data_source_predication=input_source_pred,
      discriminator_generated_data_source_predication=gen_source_pred,
      discriminator_input_data_domain_predication=input_domain_pred,
      discriminator_generated_data_domain_predication=gen_domain_pred,
      generator_variables=variables_lib.get_trainable_variables(gen_scope),
      generator_scope=gen_scope,
      generator_fn=generator_fn,
      discriminator_variables=variables_lib.get_trainable_variables(dis_scope),
      discriminator_scope=dis_scope,
      discriminator_fn=discriminator_fn)
开发者ID:zakizhou,项目名称:tensorflow,代码行数:100,代码来源:train.py
注：本文中的tensorflow.contrib.framework.python.ops.variables.get_trainable_variables函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论