本文整理汇总了Python中tensorflow.python.util.nest.flatten函数的典型用法代码示例。如果您正苦于以下问题：Python flatten函数的具体用法？Python flatten怎么用？Python flatten使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了flatten函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: convert_to_generator_like
def convert_to_generator_like(data,
                              batch_size=None,
                              steps_per_epoch=None,
                              epochs=1,
                              shuffle=False):
  """Turn NumPy or EagerTensor inputs into a batch-yielding generator.

  Arguments:
    data: Either a generator or `keras.utils.data_utils.Sequence` object or
      `Dataset` or `EagerIterator` or a {1,2,3}-tuple of NumPy arrays or
      EagerTensors. If a tuple, the elements represent
      `(x, y, sample_weights)` and may be `None` or `[None]`.
    batch_size: Used when creating a generator out of tuples of NumPy arrays
      or EagerTensors.
    steps_per_epoch: Steps of the generator to run each epoch.
    epochs: Total number of epochs to run.
    shuffle: Whether the data should be shuffled.

  Returns:
    A `(data_source, steps_per_epoch)` pair. `data_source` is the input itself
    for generators, `Sequence`s and `EagerIterator`s, a one-shot iterator for
    a `DatasetV2`, and a newly built generator otherwise.

  Raises:
    ValueError: If `batch_size` is not provided for NumPy or EagerTensor
      inputs.
  """
  if isinstance(data, tuple):
    # Drop entries (typically `targets` / `sample_weights`) whose flattened
    # contents are all `None`.
    kept = []
    for element in data:
      if any(leaf is not None for leaf in nest.flatten(element)):
        kept.append(element)
    data = tuple(kept)
    if len(data) == 1:
      data = data[0]

  already_iterable = (
      data_utils.is_generator_or_sequence(data) or
      isinstance(data, iterator_ops.EagerIterator))
  if already_iterable:
    if isinstance(data, data_utils.Sequence):
      # A `Sequence` knows its own length, which overrides the argument.
      steps_per_epoch = len(data)
    return data, steps_per_epoch

  if isinstance(data, dataset_ops.DatasetV2):
    return dataset_ops.make_one_shot_iterator(data), steps_per_epoch

  # Remaining case: build a generator from NumPy or EagerTensor input. The
  # sample count is read from the leading dimension of the first leaf.
  first_leaf = nest.flatten(data)[0]
  num_samples = int(first_leaf.shape[0])
  if batch_size is None:
    raise ValueError('You must specify `batch_size`')
  steps_per_epoch = int(math.ceil(num_samples / batch_size))

  def _gen(data):
    """Yields batches shaped like `data` for `epochs` passes."""
    sample_order = np.arange(num_samples)
    for _ in range(epochs):
      if shuffle:
        np.random.shuffle(sample_order)
      for start, stop in generic_utils.make_batches(num_samples, batch_size):
        ids = sample_order[start:stop]
        sliced = training_utils.slice_arrays(
            nest.flatten(data), ids, contiguous=(not shuffle))
        yield nest.pack_sequence_as(data, sliced)

  return _gen(data), steps_per_epoch
开发者ID:aeverall,项目名称:tensorflow,代码行数:60,代码来源:training_generator.py
示例2: _eager_metrics_fn
def _eager_metrics_fn(model,
                      outputs,
                      targets,
                      sample_weights=None,
                      masks=None,
                      return_stateful_result=True):
  """Compute the mean of every metric result for the model's outputs.

  Arguments:
    model: The model on which metrics are being calculated.
    outputs: The outputs of the given model.
    targets: The predictions or targets of the given model.
    sample_weights: Optional list of sample weights for each output.
    masks: Optional list of masks for each output.
    return_stateful_result: Boolean, indicates whether the stateful
      (aggregated)/stateless metric result should be returned.

  Returns:
    A list holding `backend.mean` of each result returned by
    `model._handle_metrics`.
  """
  flat_outputs = nest.flatten(outputs)
  flat_targets = nest.flatten(targets)
  # TODO(psv): Consider supporting skip target indices in eager mode?
  raw_results = model._handle_metrics(
      flat_outputs,
      targets=flat_targets,
      sample_weights=sample_weights,
      masks=masks,
      return_stateful_result=return_stateful_result)
  averaged = []
  for result in raw_results:
    averaged.append(backend.mean(result))
  return averaged
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:30,代码来源:training_eager.py
示例3: _create_multi_lstm_cell_ops
def _create_multi_lstm_cell_ops(batch_size, num_units, input_depth,
                                num_layers, max_time, compiled):
  """Build a stacked-LSTM `dynamic_rnn` graph plus its gradient ops.

  Args:
    batch_size: Batch dimension of the randomly initialised `inputs` variable.
    num_units: Number of units passed to every `LSTMCell`.
    input_depth: Last dimension of the `inputs` variable.
    num_layers: Number of LSTM cells stacked in the `MultiRNNCell`.
    max_time: Leading (time) dimension of the `inputs` variable; the RNN is
      run with `time_major=True`.
    compiled: If truthy, each cell is wrapped in `rnn_cell.CompiledWrapper`.

  Returns:
    A dict with keys `"outputs"`, `"final_state"` (flattened),
    `"outputs_grad"` and `"final_state_grad"`.
  """

  def _maybe_compile(cell):
    """Wraps `cell` in a `CompiledWrapper` when `compiled` is set."""
    return rnn_cell.CompiledWrapper(cell) if compiled else cell

  with variable_scope.variable_scope(
      "root",
      initializer=init_ops.random_uniform_initializer(-0.1, 0.1, seed=2)):
    inputs = variable_scope.get_variable(
        "inputs",
        initializer=random_ops.random_uniform(
            (max_time, batch_size, input_depth), seed=1))

    layers = [
        _maybe_compile(core_rnn_cell_impl.LSTMCell(num_units))
        for _ in range(num_layers)
    ]
    cell = core_rnn_cell_impl.MultiRNNCell(layers)
    initial_state = cell.zero_state(
        batch_size=batch_size, dtype=dtypes.float32)
    outputs, final_state = rnn.dynamic_rnn(
        cell=cell,
        inputs=inputs,
        initial_state=initial_state,
        time_major=True)

    flat_final_state = nest.flatten(final_state)
    # Gradients are taken w.r.t. the trainable variables, the inputs and every
    # tensor of the initial state.
    grad_sources = (
        variables.trainable_variables() + [inputs] +
        nest.flatten(initial_state))
    outputs_grad = gradients_impl.gradients([outputs], grad_sources)
    final_state_grad = gradients_impl.gradients(flat_final_state, grad_sources)

    return {
        "outputs": outputs,
        "final_state": flat_final_state,
        "outputs_grad": outputs_grad,
        "final_state_grad": final_state_grad,
    }
开发者ID:Jackhuang945,项目名称:tensorflow,代码行数:30,代码来源:rnn_cell_test.py
示例4: call
def call(self, inputs, mask=None):
  """Apply the network graph to new inputs.

  In this case `call` just reapplies all ops in the graph to the new inputs
  (e.g. build a new computational graph from the provided inputs).

  Arguments:
    inputs: A tensor or list of tensors.
    mask: A mask or list of masks. A mask can be either a tensor or None
      (no mask).

  Returns:
    A tensor if there is a single output, or a list of tensors if there are
    more than one outputs.
  """
  flat_inputs = nest.flatten(inputs)
  # Without an explicit mask, every input gets a `None` mask.
  masks = [None] * len(flat_inputs) if mask is None else nest.flatten(mask)

  if context.in_graph_mode():
    # Reuse cached outputs if the layer has already been called on these
    # exact inputs and masks.
    inputs_uid = layers_util.object_list_uid(flat_inputs)
    masks_uid = layers_util.object_list_uid(masks)
    cache_key = inputs_uid + '_' + masks_uid
    if cache_key in self._output_tensor_cache:
      return self._output_tensor_cache[cache_key]

  # No cache hit: actually apply the network graph to the new inputs.
  outputs, _ = self._run_internal_graph(flat_inputs, masks)
  return outputs
开发者ID:AnddyWang,项目名称:tensorflow,代码行数:33,代码来源:network.py
示例5: _eager_metrics_fn
def _eager_metrics_fn(model, outputs, targets, sample_weights=None, masks=None):
  """Collect metric results for each output of the given model.

  Arguments:
    model: The model on which metrics are being calculated.
    outputs: The outputs of the given model.
    targets: The predictions or targets of the given model.
    sample_weights: Optional list of sample weights for each output.
    masks: Optional list of masks for each output.

  Returns:
    The results from `model._handle_metrics` (skipped when there are no
    targets) followed by `result()` of every metric in `model.metrics` that
    is not in `model._compile_metric_functions`.
  """
  flat_outputs = nest.flatten(outputs)
  flat_targets = nest.flatten(targets)
  # TODO(psv): Consider supporting skip target indices in eager mode?

  # Invoke all (weighted and unweighted) compiled metrics, if any targets.
  if flat_targets:
    results = model._handle_metrics(
        flat_outputs,
        targets=flat_targets,
        sample_weights=sample_weights,
        masks=masks,
        return_weighted_and_unweighted_metrics=True)
  else:
    results = []

  # Append metric results coming from the `add_metric` metrics.
  added_metric_results = []
  for metric in model.metrics:
    if metric not in model._compile_metric_functions:
      added_metric_results.append(metric.result())
  results.extend(added_metric_results)
  return results
开发者ID:aritratony,项目名称:tensorflow,代码行数:33,代码来源:training_eager.py
示例6: _get_cached_states
def _get_cached_states(self, times):
  """Look up cached states for a batch of times.

  Args:
    times: Batch of times; its leading dimension is used to replicate the
      start state.

  Returns:
    A structure shaped like the looked-up state in which entries belonging
    to the starting chunk are replaced by the model's start state.
  """
  read_chunk_numbers = self._get_chunk_number(times)
  looked_up_state = tuple(
      self._cached_states.lookup(
          math_ops.cast(read_chunk_numbers, dtypes.int64)))

  # We need to special-case the first chunk in a series to explicitly rely on
  # the model's starting state so that gradients flow back to it. Otherwise it
  # would affect only initialization, and would not be read from or updated
  # during training. Not doing this also isolates that part of the graph,
  # leading to errors on model reload if there are trainable variables
  # affecting a model's start state.
  start_time = (0 if self._input_statistics is None
                else self._input_statistics.start_time)
  set_to_start_state = math_ops.equal(
      read_chunk_numbers, self._get_chunk_number(start_time))

  replicated_start = nest.flatten(
      math_utils.replicate_state(self._start_state,
                                 array_ops.shape(times)[0]))
  cached_flat = nest.flatten(looked_up_state)
  merged = [
      array_ops.where(set_to_start_state, start_value, cached_value)
      for start_value, cached_value in zip(replicated_start, cached_flat)
  ]
  return nest.pack_sequence_as(looked_up_state, merged)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:30,代码来源:state_management.py
示例7: _apply_exogenous_update
def _apply_exogenous_update(
    self, current_times, step_number, state, raw_features,
    embedded_exogenous_regressors):
  """Conditionally update `state` from this step's exogenous features.

  Args:
    current_times: Times for the current step; forwarded to the exogenous
      input step and to the update condition.
    step_number: Index selecting the current step along axis 1.
    state: Model state (possibly nested) before the exogenous update.
    raw_features: Dict of raw features; entries other than the state tuple,
      times and values are sliced at `step_number` for the update condition.
    embedded_exogenous_regressors: Embedded regressors indexed as
      `[:, step_number, :]`, or None when there are no exogenous features.

  Returns:
    `state` unchanged when there are no regressors; otherwise the updated
    state, merged element-wise with the original state according to
    `self._exogenous_update_condition` when that condition is set.
  """
  if embedded_exogenous_regressors is None:
    return state

  step_regressors = embedded_exogenous_regressors[:, step_number, :]
  exogenous_updated_state = self._exogenous_input_step(
      current_times=current_times,
      current_exogenous_regressors=step_regressors,
      state=state)
  if self._exogenous_update_condition is None:
    return exogenous_updated_state

  # Features that are not exogenous inputs are hidden from the condition.
  reserved_keys = [PredictionFeatures.STATE_TUPLE,
                   TrainEvalFeatures.TIMES,
                   TrainEvalFeatures.VALUES]
  step_raw_features = {}
  for key, value in raw_features.items():
    if key not in reserved_keys:
      step_raw_features[key] = value[:, step_number]

  merged_flat = []
  pairs = zip(nest.flatten(exogenous_updated_state), nest.flatten(state))
  for updated_element, original_element in pairs:
    # The condition is evaluated once per state element.
    use_update = self._exogenous_update_condition(
        times=current_times, features=step_raw_features)
    merged_flat.append(
        array_ops.where(use_update, updated_element, original_element))
  return nest.pack_sequence_as(state, merged_flat)
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:33,代码来源:model.py
示例8: _Update
def _Update(struct_acc, struct_x, t):
  """Writes `struct_x` into the t-th row of the accumulators.

  Args:
    struct_acc: The accumulators. A structure of tensors.
    struct_x: The new values. A structure of tensors congruent to
      `struct_acc`.
    t: A scalar integer. Performance is better if `t` is on the device
      memory.

  Returns:
    A structure of tensors. Say, ret is a returned dictionary. Then, for
    each key, we have:
      ret[key] = struct_acc[key];
      ret[key][t, :] = struct_x[key]
  """
  # NOTE: this set is never populated here, so the skip branch below is
  # currently not taken.
  to_skip_update = set()
  row_index = math_ops.to_int32([t])  # tf.to_int32 casts on-device tensors.
  updated = []
  for acc, x in zip(nest.flatten(struct_acc), nest.flatten(struct_x)):
    if acc in to_skip_update:
      # Until b/62105730 is fixed, we need to avoid inplace update for tensors
      # of rank 1. could reshape to handle it, but we don't really need the
      # values applied to these, so just skip their modification.
      updated.append(acc)
      continue
    new_row = array_ops.expand_dims(x, 0)
    updated.append(alias_inplace_update(acc, row_index, new_row))
  return nest.pack_sequence_as(struct_acc, updated)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:29,代码来源:recurrent.py
示例9: testFlattenAndPack
def testFlattenAndPack(self):
  """Checks `flatten` / `pack_sequence_as` on tuples, namedtuples, scalars."""
  # Nested plain tuples.
  nested = ((3, 4), 5, (6, 7, (9, 10), 8))
  letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
  self.assertEqual(nest.flatten(nested), [3, 4, 5, 6, 7, 9, 10, 8])
  expected_packed = (("a", "b"), "c", ("d", "e", ("f", "g"), "h"))
  self.assertEqual(nest.pack_sequence_as(nested, letters), expected_packed)

  # Namedtuples are flattened field by field and rebuilt with their type.
  point = collections.namedtuple("Point", ["x", "y"])
  nested = (point(x=4, y=2), ((point(x=1, y=0),),))
  leaves = [4, 2, 1, 0]
  self.assertEqual(nest.flatten(nested), leaves)
  repacked = nest.pack_sequence_as(nested, leaves)
  self.assertEqual(repacked, nested)
  outer_point = repacked[0]
  inner_point = repacked[1][0][0]
  self.assertEqual(outer_point.x, 4)
  self.assertEqual(outer_point.y, 2)
  self.assertEqual(inner_point.x, 1)
  self.assertEqual(inner_point.y, 0)

  # Scalars flatten to a one-element list and pack back to the element.
  self.assertEqual([5], nest.flatten(5))
  self.assertEqual([np.array([5])], nest.flatten(np.array([5])))
  self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
  self.assertEqual(
      np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

  # Mismatched structures and bad flat sequences are rejected.
  with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
    nest.pack_sequence_as("scalar", [4, 5])
  with self.assertRaisesRegexp(TypeError, "flat_sequence"):
    nest.pack_sequence_as([4, 5], "bad_sequence")
  with self.assertRaises(ValueError):
    nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:33,代码来源:nest_test.py
示例10: _hierarchical_pad
def _hierarchical_pad(input_, output, control):
  """Pad and flatten hierarchical inputs, outputs, and controls.

  Args:
    input_: Hierarchical input segments.
    output: Hierarchical output segments.
    control: Hierarchical control segments; left untouched when empty.

  Returns:
    A tuple `(input_, output, control, length)` where `length` holds the
    pre-padding length of every flattened input segment.
  """
  outer_lengths = self._max_lengths[:-1]
  innermost_length = self._max_lengths[-1]

  def _pad_and_join(segments):
    """Pads each segment with zeros to `innermost_length`, then joins."""
    return np.concatenate(
        [pad_with_value(seg, innermost_length, 0) for seg in segments])

  # Pad empty segments with end tokens and flatten hierarchy.
  input_end = data.np_onehot([self.end_token], self.input_depth)
  input_ = nest.flatten(pad_with_element(input_, outer_lengths, input_end))
  output_end = data.np_onehot([self.end_token], self.output_depth)
  output = nest.flatten(pad_with_element(output, outer_lengths, output_end))

  # Segment lengths must be recorded before the segments are zero-padded.
  length = np.squeeze(np.array([len(x) for x in input_], np.int32))

  # Pad and concatenate flatten hierarchy.
  input_ = _pad_and_join(input_)
  output = _pad_and_join(output)

  if np.size(control):
    control_pad = data.np_onehot(
        [self._control_pad_token], self.control_depth)
    control = nest.flatten(
        pad_with_element(control, outer_lengths, control_pad))
    control = _pad_and_join(control)

  return input_, output, control, length
开发者ID:cghawthorne,项目名称:magenta,代码行数:26,代码来源:data_hierarchical.py
示例11: testNoProjNoShardingNestedTupleStateSaver
def testNoProjNoShardingNestedTupleStateSaver(self):
  """Checks that a state-saving RNN saves every nested LSTM state tensor."""
  num_units = 3
  input_size = 5
  batch_size = 2
  max_length = 8
  num_layers = 4
  with self.test_session(graph=tf.Graph()) as sess:
    initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)

    # Layer i has state entries "c<i>" and "m<i>" of size num_units + i.
    state_sizes = {}
    for i in range(num_layers):
      state_sizes["c%d" % i] = num_units + i
      state_sizes["m%d" % i] = num_units + i
    state_saver = TestStateSaver(batch_size, state_sizes)

    def _cell(i):
      return tf.contrib.rnn.LSTMCell(
          num_units + i,
          use_peepholes=False,
          initializer=initializer,
          state_is_tuple=True)

    # This creates a state tuple which has 4 sub-tuples of length 2 each.
    cell = tf.contrib.rnn.MultiRNNCell(
        [_cell(i) for i in range(num_layers)], state_is_tuple=True)
    self.assertEqual(len(cell.state_size), 4)
    for layer_state_size in cell.state_size:
      self.assertEqual(len(layer_state_size), 2)

    # The same placeholder object is reused for every time step.
    inputs = max_length * [
        tf.placeholder(tf.float32, shape=(batch_size, input_size))]

    state_names = tuple(
        ("c%d" % i, "m%d" % i) for i in range(num_layers))
    with tf.variable_scope("share_scope"):
      outputs, state = tf.contrib.rnn.static_state_saving_rnn(
          cell, inputs, state_saver=state_saver, state_name=state_names)
    self.assertEqual(len(outputs), len(inputs))

    # Final output comes from _cell(3) which has state size num_units + 3
    expected_output_shape = [batch_size, num_units + 3]
    for out in outputs:
      self.assertEqual(out.get_shape().as_list(), expected_output_shape)

    tf.global_variables_initializer().run()
    input_value = np.random.randn(batch_size, input_size)
    feed = {inputs[0]: input_value}
    last_states = sess.run(nest.flatten(state), feed_dict=feed)
    saved_states = sess.run(
        list(state_saver.saved_state.values()), feed_dict=feed)
    self.assertEqual(8, len(last_states))
    self.assertEqual(8, len(saved_states))

    # Every final state tensor must equal the saved tensor of the same name.
    named_saved_states = dict(
        zip(state_saver.saved_state.keys(), saved_states))
    flat_state_names = nest.flatten(state_names)
    for last_state, name in zip(last_states, flat_state_names):
      self.assertAllEqual(last_state, named_saved_states[name])
开发者ID:Hwhitetooth,项目名称:tensorflow,代码行数:59,代码来源:core_rnn_test.py
示例12: body
def body(time, elements_finished, current_input, emit_ta, state, loop_state):
  """Internal while loop body for raw_rnn.

  Args:
    time: time scalar.
    elements_finished: batch-size vector.
    current_input: possibly nested tuple of input tensors.
    emit_ta: possibly nested tuple of output TensorArrays.
    state: possibly nested tuple of state tensors.
    loop_state: possibly nested tuple of loop state tensors.

  Returns:
    Tuple having the same size as Args but with updated values.
  """
  next_output, cell_state = cell(current_input, state)

  nest.assert_same_structure(state, cell_state)
  nest.assert_same_structure(cell.output_size, next_output)

  next_time = time + 1
  loop_fn_result = loop_fn(next_time, next_output, cell_state, loop_state)
  (next_finished, next_input, next_state, emit_output,
   next_loop_state) = loop_fn_result

  nest.assert_same_structure(state, next_state)
  nest.assert_same_structure(current_input, next_input)
  nest.assert_same_structure(emit_ta, emit_output)

  # If loop_fn returns None for next_loop_state, just reuse the
  # previous one.
  if next_loop_state is not None:
    loop_state = next_loop_state

  def _copy_some_through(current, candidate):
    """Picks `current` where `elements_finished` is set, else `candidate`."""
    # pylint: disable=g-long-lambda,cell-var-from-loop
    chosen_flat = [
        _on_device(
            lambda: array_ops.where(elements_finished, cur_i, cand_i),
            device=cand_i.op.device)
        for (cur_i, cand_i) in zip(
            nest.flatten(current), nest.flatten(candidate))
    ]
    # pylint: enable=g-long-lambda,cell-var-from-loop
    return nest.pack_sequence_as(structure=current, flat_sequence=chosen_flat)

  # Both calls must run before `elements_finished` is updated below, because
  # the selection reads its current value.
  emit_output = _copy_some_through(zero_emit, emit_output)
  next_state = _copy_some_through(state, next_state)

  emit_output_flat = nest.flatten(emit_output)
  emit_ta_flat = nest.flatten(emit_ta)

  elements_finished = math_ops.logical_or(elements_finished, next_finished)

  written_ta_flat = []
  for ta, emit in zip(emit_ta_flat, emit_output_flat):
    written_ta_flat.append(ta.write(time, emit))
  emit_ta = nest.pack_sequence_as(
      structure=emit_structure, flat_sequence=written_ta_flat)

  return (next_time, elements_finished, next_input, emit_ta, next_state,
          loop_state)
开发者ID:ygoverdhan,项目名称:tensorflow,代码行数:59,代码来源:rnn.py
示例13: _run_targets
def _run_targets(self, targets1, targets2=None, run_init=True):
  """Evaluate two (possibly nested) groups of targets in a single call.

  Args:
    targets1: First structure of targets.
    targets2: Optional second structure; when non-empty it must flatten to
      the same number of elements as `targets1`.
    run_init: Whether to run the global variables initializer first.

  Returns:
    The result of evaluating the flattened `targets1` followed by the
    flattened `targets2`.
  """
  first = nest.flatten(targets1)
  second = nest.flatten(targets2) if targets2 is not None else []
  assert not second or len(first) == len(second)
  if run_init:
    self.evaluate(variables.global_variables_initializer())
  return self.evaluate(first + second)
开发者ID:LongJun123456,项目名称:tensorflow,代码行数:8,代码来源:control_flow_ops_test.py
示例14: testFlattenDictOrder
def testFlattenDictOrder(self):
  """`flatten` orders dicts by key, including OrderedDicts."""
  items = [("d", 3), ("b", 1), ("a", 0), ("c", 2)]
  expected = [0, 1, 2, 3]
  ordered_flat = nest.flatten(collections.OrderedDict(items))
  plain_flat = nest.flatten({"d": 3, "b": 1, "a": 0, "c": 2})
  self.assertEqual(expected, ordered_flat)
  self.assertEqual(expected, plain_flat)
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:8,代码来源:nest_test.py
示例15: _assert_same_shape
def _assert_same_shape(input1, input2, double=False):
  """Assert that corresponding tensors of two structures share a shape.

  Args:
    input1: Structure of tensors providing the expected shapes.
    input2: Structure of tensors whose shapes are checked.
    double: If True, dimension 1 of each expected shape is doubled before
      the comparison.
  """
  pairs = zip(nest.flatten(input1), nest.flatten(input2))
  for reference, actual in pairs:
    expected_shape = reference.get_shape().as_list()
    if double:
      expected_shape[1] *= 2
    self.assertEqual(expected_shape, actual.get_shape().as_list())
开发者ID:Hwhitetooth,项目名称:tensorflow,代码行数:8,代码来源:core_rnn_test.py
示例16: state_barrier_context
def state_barrier_context(state):
  """Return a context manager that prevents interior ops from running
  unless the whole state has been computed.

  This is to prevent assign race conditions.

  Args:
    state: A (possibly nested) structure. Entries that are `tf.Tensor`
      instances, and entries exposing a `flow` attribute, become control
      dependencies; other entries are ignored.

  Returns:
    The context manager produced by `tf.control_dependencies`.
  """
  # Flatten once and reuse the result for both filters.
  flat_state = nest.flatten(state)
  # `isinstance` (rather than an exact type comparison) also picks up
  # subclasses of `tf.Tensor`.
  tensors = [x for x in flat_state if isinstance(x, tf.Tensor)]
  tarray = [x.flow for x in flat_state if hasattr(x, "flow")]
  return tf.control_dependencies(tensors + tarray)
开发者ID:ALISCIFP,项目名称:models,代码行数:9,代码来源:utils.py
示例17: wrap_state
def wrap_state(self, state):
  """Wrap a cell state in the beam decoder's own state structure.

  Args:
    state: A single tensor or a nested structure of tensors. The batch size
      and dtype are read from the tensor itself, or from the first flattened
      element when `state` is a sequence.

  Returns:
    The result of `BeamDecoderCellWrapper._create_state` built around
    `state`.
  """
  dummy = BeamDecoderCellWrapper(
      None, self.num_classes, self.max_len, self.stop_token, self.beam_size)
  reference = nest.flatten(state)[0] if nest.is_sequence(state) else state
  batch_size = tf.shape(reference)[0]
  dtype = reference.dtype
  return dummy._create_state(batch_size, dtype, cell_state=state)
开发者ID:anair13,项目名称:tensorflow,代码行数:9,代码来源:beam_decoder.py
示例18: _get_grads_lists_curvature_prop
def _get_grads_lists_curvature_prop(self, tensors):
  """Compute per-tensor gradient tuples seeded with transformed random signs.

  Args:
    tensors: Structure of tensors to differentiate with respect to.

  Returns:
    A tuple with one 1-tuple `(grad,)` per top-level entry of `tensors`.
  """
  loss_inputs = [loss.inputs for loss in self._layers.losses]
  transformed_random_signs = self._get_transformed_random_signs()
  flat_grads = gradients_impl.gradients(
      nest.flatten(loss_inputs),
      nest.flatten(tensors),
      grad_ys=nest.flatten(transformed_random_signs))
  structured_grads = nest.pack_sequence_as(tensors, flat_grads)
  return tuple((grad,) for grad in structured_grads)
开发者ID:SylChan,项目名称:tensorflow,代码行数:9,代码来源:estimator.py
示例19: run_steps_on_dataset
def run_steps_on_dataset(self, fn, iterator, iterations):
  """Build ops that feed `iterator` to the TPU and run `fn` repeatedly.

  Args:
    fn: Callable invoked with one dequeued element per TPU iteration.
    iterator: Input iterator; all of its output shapes must be fully defined.
    iterations: Number of times to enqueue inputs and to repeat `fn`.

  Returns:
    A grouped op combining the TPU computation and the enqueue loop.

  Raises:
    ValueError: If any output shape of `iterator` is not fully defined.
  """
  # Enqueue ops
  flat_shapes = nest.flatten(iterator.output_shapes)
  if not all(shape.is_fully_defined() for shape in flat_shapes):
    raise ValueError(
        'TPU currently requires fully defined shapes. Either use '
        'set_shape() on the input tensors or use '
        'dataset.apply(map_and_batch(..., drop_remainder=True)).')
  flat_types = nest.flatten(iterator.output_types)

  def enqueue_ops_fn():
    """Enqueue ops for one iteration."""
    ordering_deps = []
    per_core_inputs = []
    with ops.device(self._host):
      for _ in range(self._num_cores_per_host):
        # Use control dependencies to ensure a deterministic ordering.
        with ops.control_dependencies(ordering_deps):
          core_inputs = nest.flatten(iterator.get_next())
          ordering_deps.extend(core_inputs)
          per_core_inputs.append(core_inputs)

    return [
        tpu_ops.infeed_enqueue_tuple(
            inputs=core_inputs, shapes=flat_shapes, device_ordinal=core_id)
        for core_id, core_inputs in enumerate(per_core_inputs)
    ]

  def enqueue_ops_loop_body(i):
    with ops.control_dependencies(enqueue_ops_fn()):
      return i + 1

  with ops.device(self._host):
    enqueue_ops = control_flow_ops.while_loop(
        lambda i: i < iterations,
        enqueue_ops_loop_body,
        [constant_op.constant(0)],
        parallel_iterations=1)

  # Dequeue ops
  def dequeue_fn():
    dequeued = tpu.infeed_dequeue_tuple(dtypes=flat_types, shapes=flat_shapes)
    return nest.pack_sequence_as(iterator.output_shapes, dequeued)

  # Wrap `fn` for repeat.
  def run_fn():
    return fn(dequeue_fn())

  # Repeat
  def iterate_on_tpu():
    return tpu.repeat(iterations, run_fn, [])

  # Re-write and distribute computation.
  tpu_result = tpu.batch_parallel(
      iterate_on_tpu, [], num_shards=self._num_cores_per_host)

  return control_flow_ops.group(tpu_result, enqueue_ops)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:57,代码来源:tpu_strategy.py
示例20: _check_default_value
def _check_default_value(shape, default_value, dtype, key):
  """Returns default value as tuple if it's valid, otherwise raises errors.

  This function verifies that `default_value` is compatible with both `shape`
  and `dtype`. If it is not compatible, it raises an error. If it is
  compatible, it casts default_value to a tuple and returns it. `key` is used
  only for error message.

  Args:
    shape: An iterable of integers specifies the shape of the `Tensor`.
    default_value: If a single value is provided, the same value will be
      applied as the default value for every item. If an iterable of values
      is provided, the shape of the `default_value` should be equal to the
      given `shape`.
    dtype: defines the type of values; its `is_floating` attribute decides
      whether float defaults are accepted.
    key: A string providing key to look up corresponding `Tensor`.

  Returns:
    A tuple which will be used as default value, or None when
    `default_value` is None.

  Raises:
    ValueError: if `default_value` is a sequence whose shape is not
      compatible with `shape`.
    TypeError: if `default_value` is not compatible with `dtype`.
  """
  if default_value is None:
    return None

  # Scalars: ints are always accepted, floats only for floating dtypes.
  scalar_ok = isinstance(default_value, int) or (
      isinstance(default_value, float) and dtype.is_floating)
  if scalar_ok:
    return _create_tuple(shape, default_value)

  to_list = getattr(default_value, 'tolist', None)
  if callable(to_list):  # Handles numpy arrays
    default_value = default_value.tolist()

  if nest.is_sequence(default_value):
    if not _is_shape_and_default_value_compatible(default_value, shape):
      raise ValueError(
          'The shape of default_value must be equal to given shape. '
          'default_value: {}, shape: {}, key: {}'.format(
              default_value, shape, key))
    # Check if the values in the list are all integers or are convertible to
    # floats.
    leaves = nest.flatten(default_value)
    if all(isinstance(leaf, int) for leaf in leaves):
      return _as_tuple(default_value)
    if any(isinstance(leaf, float) for leaf in leaves) and dtype.is_floating:
      return _as_tuple(default_value)

  raise TypeError('default_value must be compatible with dtype. '
                  'default_value: {}, dtype: {}, key: {}'.format(
                      default_value, dtype, key))
开发者ID:finardi,项目名称:tensorflow,代码行数:57,代码来源:feature_column.py
注：本文中的tensorflow.python.util.nest.flatten函数示例由纯净天空整理自GitHub/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论