本文整理汇总了 Python 中 tensorflow.python.keras.backend.image_data_format 函数的典型用法代码示例。该函数返回当前 Keras 后端默认的图像数据格式，取值为 'channels_first' 或 'channels_last'。如果您想了解：Python 中 image_data_format 函数的具体用法是什么？应该怎么用？有哪些使用示例？那么这里精选的函数代码示例或许可以为您提供帮助。
下文共展示了 image_data_format 函数的 20 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或认为有用的代码点赞，您的评价将有助于我们的系统推荐出更优质的 Python 代码示例。
示例1: save_img
def save_img(path,
             x,
             data_format=None,
             file_format=None,
             scale=True,
             **kwargs):
    """Write an image held in a NumPy array to a path or file object.

    Thin wrapper around `image.save_img` that substitutes the backend's
    default image data format when none is given.

    Arguments:
        path: Destination path or file object.
        x: NumPy array holding the image.
        data_format: Image data format, either "channels_first" or
            "channels_last". When None, `backend.image_data_format()` is used.
        file_format: Optional file format override. If omitted, the format
            is determined from the filename extension. If a file object is
            passed instead of a filename, this parameter should always be
            supplied.
        scale: Whether to rescale image values to be within `[0, 255]`.
        **kwargs: Additional keyword arguments forwarded to `image.save_img`
            (documented upstream as being passed to `PIL.Image.save()`).
    """
    resolved_format = (backend.image_data_format()
                       if data_format is None else data_format)
    image.save_img(path,
                   x,
                   data_format=resolved_format,
                   file_format=file_format,
                   scale=scale,
                   **kwargs)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:27,代码来源:image.py
示例2: preprocess_input
def preprocess_input(x, data_format=None, mode='caffe'):
    """Preprocess a tensor or NumPy array encoding a batch of images.

    NumPy arrays are routed to `_preprocess_numpy_input`; anything else is
    treated as a symbolic tensor and routed to `_preprocess_symbolic_input`.

    Arguments:
        x: Input NumPy array or symbolic tensor, 3D or 4D.
        data_format: Data format of the image tensor/array. When None,
            `K.image_data_format()` is used.
        mode: One of "caffe", "tf".
            - caffe: convert the images from RGB to BGR, then zero-center
              each color channel with respect to the ImageNet dataset,
              without scaling.
            - tf: scale pixels between -1 and 1, sample-wise.

    Returns:
        Preprocessed tensor or NumPy array.

    Raises:
        ValueError: In case of unknown `data_format` argument.
    """
    if data_format is None:
        data_format = K.image_data_format()
    valid_formats = {'channels_first', 'channels_last'}
    if data_format not in valid_formats:
        raise ValueError('Unknown data_format ' + str(data_format))
    preprocess = (_preprocess_numpy_input if isinstance(x, np.ndarray)
                  else _preprocess_symbolic_input)
    return preprocess(x, data_format=data_format, mode=mode)
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:29,代码来源:imagenet_utils.py
示例3: __init__
def __init__(self, x, y, image_data_generator,
             batch_size=32,
             shuffle=False,
             sample_weight=None,
             seed=None,
             data_format=None,
             save_to_dir=None,
             save_prefix='',
             save_format='png',
             subset=None,
             dtype=None):
    """Initialize the iterator, filling in backend defaults first.

    `data_format` defaults to `backend.image_data_format()`. `dtype`
    (defaulting to `backend.floatx()`) is forwarded only when the parent
    `image.NumpyArrayIterator.__init__` declares a `dtype` argument.
    All other arguments are passed straight to the parent initializer.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    parent_arg_names = tf_inspect.getfullargspec(
        image.NumpyArrayIterator.__init__)[0]
    extra = {}
    if 'dtype' in parent_arg_names:
        extra['dtype'] = backend.floatx() if dtype is None else dtype
    super(NumpyArrayIterator, self).__init__(
        x, y, image_data_generator,
        batch_size=batch_size,
        shuffle=shuffle,
        sample_weight=sample_weight,
        seed=seed,
        data_format=data_format,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        subset=subset,
        **extra)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:31,代码来源:image.py
示例4: get_data_shape
def get_data_shape():
    """Return the per-example input image shape for the active data format.

    Returns:
        `(1, 28, 28)` when `backend.image_data_format()` is
        'channels_first', otherwise `(28, 28, 1)`.
    """
    rows, cols = 28, 28  # input image dimensions
    channels_first = backend.image_data_format() == 'channels_first'
    return (1, rows, cols) if channels_first else (rows, cols, 1)
开发者ID:aritratony,项目名称:tensorflow,代码行数:7,代码来源:mnist_multi_worker.py
示例5: array_to_img
def array_to_img(x, data_format=None, scale=True, dtype=None):
    """Convert a 3D NumPy array to a PIL Image instance.

    Delegates to `image.array_to_img`. `dtype` (defaulting to
    `backend.floatx()`) is forwarded only when that function declares a
    `dtype` argument.

    Arguments:
        x: Input NumPy array.
        data_format: Image data format, either "channels_first" or
            "channels_last". When None, `backend.image_data_format()` is used.
        scale: Whether to rescale image values to be within `[0, 255]`.
        dtype: Dtype to use.

    Returns:
        A PIL Image instance.

    Raises:
        ImportError: if PIL is not available.
        ValueError: if invalid `x` or `data_format` is passed.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    accepted_args = tf_inspect.getfullargspec(image.array_to_img)[0]
    extra = {}
    if 'dtype' in accepted_args:
        extra['dtype'] = backend.floatx() if dtype is None else dtype
    return image.array_to_img(x, data_format=data_format, scale=scale, **extra)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:27,代码来源:image.py
示例6: load_data
def load_data(label_mode='fine'):
    """Load the CIFAR100 dataset, downloading it if necessary.

    Labels are reshaped to column vectors of shape `(num_samples, 1)`.
    When the backend image data format is 'channels_last', axis 1 of the
    image arrays is moved to the end via `transpose(0, 2, 3, 1)`
    (presumably `load_batch` yields channels-first arrays; confirm there).

    Arguments:
        label_mode: one of "fine", "coarse".

    Returns:
        Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.

    Raises:
        ValueError: in case of invalid `label_mode`.
    """
    if label_mode not in ['fine', 'coarse']:
        raise ValueError('`label_mode` must be one of `"fine"`, `"coarse"`.')
    path = get_file(
        'cifar-100-python',
        origin='https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz',
        untar=True)
    label_key = label_mode + '_labels'
    x_train, y_train = load_batch(os.path.join(path, 'train'),
                                  label_key=label_key)
    x_test, y_test = load_batch(os.path.join(path, 'test'),
                                label_key=label_key)
    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))
    if K.image_data_format() == 'channels_last':
        x_train = x_train.transpose(0, 2, 3, 1)
        x_test = x_test.transpose(0, 2, 3, 1)
    return (x_train, y_train), (x_test, y_test)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:33,代码来源:cifar100.py
示例7: conv_block
def conv_block(x, growth_rate, name):
    """Build one DenseNet building block and concatenate it onto its input.

    Pipeline: BN -> ReLU -> 1x1 Conv2D(4 * growth_rate) -> BN -> ReLU ->
    3x3 'same' Conv2D(growth_rate); the result is concatenated with `x`
    along the channel axis.

    Arguments:
        x: input tensor.
        growth_rate: number of filters produced by the block (used as the
            Conv2D filter count, so it should be an integer).
        name: string, block label used as the prefix of every layer name.

    Returns:
        Output tensor for the block.
    """
    bn_axis = 3 if K.image_data_format() == 'channels_last' else 1
    h = BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                           name=name + '_0_bn')(x)
    h = Activation('relu', name=name + '_0_relu')(h)
    h = Conv2D(4 * growth_rate, 1, use_bias=False, name=name + '_1_conv')(h)
    h = BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                           name=name + '_1_bn')(h)
    h = Activation('relu', name=name + '_1_relu')(h)
    h = Conv2D(growth_rate, 3, padding='same', use_bias=False,
               name=name + '_2_conv')(h)
    return Concatenate(axis=bn_axis, name=name + '_concat')([x, h])
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:26,代码来源:densenet.py
示例8: __dense_block
def __dense_block(x, nb_layers, nb_filter, growth_rate, bottleneck=False, dropout_rate=None, weight_decay=1e-4,
                  grow_nb_filters=True, return_concat_list=False):
    """Build a dense block in which each conv block feeds all later ones.

    Each iteration runs `__conv_block` on the running tensor and
    concatenates its output onto it along the channel axis.

    Args:
        x: keras tensor.
        nb_layers: the number of conv blocks to append to the model.
        nb_filter: number of filters (running total, returned updated).
        growth_rate: growth rate (filters produced by each conv block).
        bottleneck: whether each conv block uses a bottleneck.
        dropout_rate: dropout rate.
        weight_decay: weight decay factor.
        grow_nb_filters: if True, add `growth_rate` to `nb_filter` per layer.
        return_concat_list: also return the list of feature maps (the input
            followed by every conv block's output).

    Returns:
        `(x, nb_filter)`, or `(x, nb_filter, feature_maps)` when
        `return_concat_list` is True.
    """
    concat_axis = 1 if K.image_data_format() == 'channels_first' else -1
    feature_maps = [x]
    for _ in range(nb_layers):
        new_features = __conv_block(x, growth_rate, bottleneck, dropout_rate, weight_decay)
        feature_maps.append(new_features)
        x = concatenate([x, new_features], axis=concat_axis)
        if grow_nb_filters:
            nb_filter += growth_rate
    if return_concat_list:
        return x, nb_filter, feature_maps
    return x, nb_filter
开发者ID:AlexBlack2202,项目名称:ImageAI,代码行数:32,代码来源:densenet.py
示例9: __conv_block
def __conv_block(ip, nb_filter, bottleneck=False, dropout_rate=None, weight_decay=1e-4):
    """Apply BatchNorm, ReLU and a 3x3 Conv2D, with optional bottleneck and dropout.

    Layer order: BN -> ReLU -> [1x1 Conv2D(4 * nb_filter) -> BN -> ReLU when
    `bottleneck`] -> 3x3 'same' Conv2D(nb_filter) -> [Dropout when
    `dropout_rate` is truthy].

    Args:
        ip: Input keras tensor.
        nb_filter: number of filters of the final 3x3 convolution.
        bottleneck: if True, insert a 1x1 bottleneck convolution first.
        dropout_rate: dropout rate; falsy values (None, 0) skip dropout.
        weight_decay: L2 regularization factor applied to the kernel of every
            convolution in the block.

    Returns:
        Keras tensor with batch_norm, relu and convolution2d added
        (optional bottleneck).
    """
    concat_axis = 1 if K.image_data_format() == 'channels_first' else -1
    x = BatchNormalization(axis=concat_axis, epsilon=1.1e-5)(ip)
    x = Activation('relu')(x)
    if bottleneck:
        # Obtained from https://github.com/liuzhuang13/DenseNet/blob/master/densenet.lua
        inter_channel = nb_filter * 4
        x = Conv2D(inter_channel, (1, 1), kernel_initializer='he_normal', padding='same', use_bias=False,
                   kernel_regularizer=l2(weight_decay))(x)
        x = BatchNormalization(axis=concat_axis, epsilon=1.1e-5)(x)
        x = Activation('relu')(x)
    # Regularize the 3x3 kernel as well: previously `weight_decay` was
    # silently ignored here, and therefore entirely when bottleneck=False.
    x = Conv2D(nb_filter, (3, 3), kernel_initializer='he_normal', padding='same', use_bias=False,
               kernel_regularizer=l2(weight_decay))(x)
    if dropout_rate:
        x = Dropout(dropout_rate)(x)
    return x
开发者ID:AlexBlack2202,项目名称:ImageAI,代码行数:28,代码来源:densenet.py
示例10: load_data
def load_data():
    """Load the CIFAR10 dataset, downloading it if necessary.

    The five training batches (10000 samples each) are copied into
    preallocated uint8 arrays of shape `(50000, 3, 32, 32)` and
    `(50000,)`. Labels are reshaped to column vectors, and when the backend
    image data format is 'channels_last' the images are transposed with
    `transpose(0, 2, 3, 1)`.

    Returns:
        Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    path = get_file(
        'cifar-10-batches-py',
        origin='https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz',
        untar=True)
    num_train_samples = 50000
    samples_per_batch = 10000
    x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8')
    y_train = np.empty((num_train_samples,), dtype='uint8')
    for batch_index in range(1, 6):
        start = (batch_index - 1) * samples_per_batch
        stop = start + samples_per_batch
        batch_path = os.path.join(path, 'data_batch_' + str(batch_index))
        x_train[start:stop, :, :, :], y_train[start:stop] = load_batch(batch_path)
    x_test, y_test = load_batch(os.path.join(path, 'test_batch'))
    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))
    if K.image_data_format() == 'channels_last':
        x_train = x_train.transpose(0, 2, 3, 1)
        x_test = x_test.transpose(0, 2, 3, 1)
    return (x_train, y_train), (x_test, y_test)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:31,代码来源:cifar10.py
示例11: prepare_simple_model
def prepare_simple_model(input_tensor, loss_name, target):
    """Build and compile a single 1x1-Conv2D model for the given loss.

    The number of output channels is derived from `target`: for the sparse
    loss it is `max(target) + 1`, otherwise the size of `target` along the
    channel axis. Losses are computed along the channel axis (1 for
    'channels_first', -1 otherwise).

    Args:
        input_tensor: Keras input tensor of the model.
        loss_name: one of 'sparse_categorical_crossentropy',
            'categorical_crossentropy', 'binary_crossentropy'.
        target: NumPy array of targets, used only to size the output.

    Returns:
        The compiled `keras.models.Model` (optimizer 'rmsprop').

    Raises:
        ValueError: if `loss_name` is not one of the supported losses.
    """
    axis = 1 if K.image_data_format() == 'channels_first' else -1
    if loss_name == 'sparse_categorical_crossentropy':
        loss = lambda y_true, y_pred: K.sparse_categorical_crossentropy(  # pylint: disable=g-long-lambda
            y_true, y_pred, axis=axis)
        num_channels = np.amax(target) + 1
        activation = 'softmax'
    elif loss_name == 'categorical_crossentropy':
        loss = lambda y_true, y_pred: K.categorical_crossentropy(  # pylint: disable=g-long-lambda
            y_true, y_pred, axis=axis)
        num_channels = target.shape[axis]
        activation = 'softmax'
    elif loss_name == 'binary_crossentropy':
        loss = lambda y_true, y_pred: K.binary_crossentropy(y_true, y_pred)  # pylint: disable=unnecessary-lambda
        num_channels = target.shape[axis]
        activation = 'sigmoid'
    else:
        # Previously fell through with loss/num_channels = None and failed
        # obscurely inside Conv2D; fail fast with a clear message instead.
        raise ValueError('Unsupported loss_name: ' + str(loss_name))
    predictions = Conv2D(num_channels,
                         1,
                         activation=activation,
                         kernel_initializer='ones',
                         bias_initializer='ones')(input_tensor)
    simple_model = keras.models.Model(inputs=input_tensor,
                                      outputs=predictions)
    simple_model.compile(optimizer='rmsprop', loss=loss)
    return simple_model
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:28,代码来源:training_gpu_test.py
示例12: convert
def convert(in_path, out_path):
    """Convert any Keras model to the frugally-deep model format.

    Requires the TensorFlow backend, float32 and 'channels_last'. Loads the
    model from `in_path`, generates test data, probes convolution/pooling
    offsets, collects shapes and weights, and writes everything as JSON
    (sorted keys, indent 2, NaN disallowed) to `out_path`.
    """
    assert K.backend() == "tensorflow"
    assert K.floatx() == "float32"
    assert K.image_data_format() == 'channels_last'

    print('loading {}'.format(in_path))
    model = load_model(in_path)
    # Compiling forces creation of the underlying functional model
    # (see https://github.com/fchollet/keras/issues/8136). Loss and optimizer
    # are irrelevant because the model is never trained here.
    model.compile(loss='mse', optimizer='sgd')
    model = convert_sequential_to_model(model)
    test_data = gen_test_data(model)

    json_output = {
        'architecture': json.loads(model.to_json()),
        'image_data_format': K.image_data_format(),
    }

    # Convolution offsets for depths 1 and 2; probe order matches the
    # original: conv valid, conv same, separable valid, separable same.
    conv_ops = (('conv2d', offset_conv2d_eval),
                ('separable_conv2d', offset_sep_conv2d_eval))
    for depth in (1, 2):
        for op_name, eval_fn in conv_ops:
            for padding in ('valid', 'same'):
                key = '{}_{}_offset_depth_{}'.format(op_name, padding, depth)
                json_output[key] = check_operation_offset(depth, eval_fn, padding)

    # Pooling offsets, always probed at depth 1.
    pool_ops = (('max_pooling_2d', conv2d_offset_max_pool_eval),
                ('average_pooling_2d', conv2d_offset_average_pool_eval))
    for op_name, eval_fn in pool_ops:
        for padding in ('valid', 'same'):
            key = '{}_{}_offset'.format(op_name, padding)
            json_output[key] = check_operation_offset(1, eval_fn, padding)

    json_output['input_shapes'] = get_shapes(test_data['inputs'])
    json_output['output_shapes'] = get_shapes(test_data['outputs'])
    json_output['tests'] = [test_data]
    json_output['trainable_params'] = get_all_weights(model)

    print('writing {}'.format(out_path))
    serialized = json.dumps(json_output, allow_nan=False, indent=2, sort_keys=True)
    write_text_file(out_path, serialized)
开发者ID:Telecommunication-Telemedia-Assessment,项目名称:V-BMS360,代码行数:47,代码来源:convert_model.py
示例13: _conv_block
def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):
    """Add the initial convolution layer (with batch normalization and relu6).

    The input is zero-padded by one pixel on each side, then passed through
    an unbiased 'valid' Conv2D named 'conv1', BatchNormalization over the
    channel axis ('conv1_bn') and a relu6 activation ('conv1_relu').

    Arguments:
        inputs: Input tensor of shape `(rows, cols, 3)` (with
            `channels_last` data format) or `(3, rows, cols)` (with
            `channels_first` data format). It should have exactly 3 input
            channels, and width and height should be no smaller than 32,
            e.g. `(224, 224, 3)`.
        filters: Integer, the dimensionality of the output space before
            width scaling; the layer actually uses `int(filters * alpha)`.
        alpha: Width multiplier of the network. `alpha < 1.0`
            proportionally decreases and `alpha > 1.0` proportionally
            increases the number of filters in each layer; `alpha == 1`
            uses the default number of filters from the paper.
        kernel: An integer or tuple/list of 2 integers, the width and
            height of the 2D convolution window (a single integer applies
            to both spatial dimensions).
        strides: An integer or tuple/list of 2 integers, the strides of the
            convolution along width and height (a single integer applies
            to both spatial dimensions).

    Input shape:
        4D tensor `(samples, channels, rows, cols)` if
        data_format='channels_first', or `(samples, rows, cols, channels)`
        if data_format='channels_last'.

    Output shape:
        4D tensor `(samples, filters, new_rows, new_cols)` if
        data_format='channels_first', or
        `(samples, new_rows, new_cols, filters)` if
        data_format='channels_last'. `rows` and `cols` may change due to
        stride.

    Returns:
        Output tensor of block.
    """
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
    scaled_filters = int(filters * alpha)
    padded = ZeroPadding2D(padding=(1, 1), name='conv1_pad')(inputs)
    conv = Conv2D(scaled_filters,
                  kernel,
                  padding='valid',
                  use_bias=False,
                  strides=strides,
                  name='conv1')(padded)
    normalized = BatchNormalization(axis=channel_axis, name='conv1_bn')(conv)
    return Activation(relu6, name='conv1_relu')(normalized)
开发者ID:LiuCKind,项目名称:tensorflow,代码行数:58,代码来源:mobilenet.py
示例14: set_model
def set_model(self, model):
    """Sets Keras model and creates summary ops.

    Stores `model` and the current backend session. When `histogram_freq`
    is truthy and no merged summary exists yet, registers summaries for
    every layer:
      - a histogram of each weight, plus an image summary of it when
        `write_images` is set;
      - a histogram of each trainable weight's gradients when
        `write_grads` is set;
      - a histogram of the layer output when the layer has an `output`.
    All summaries are then merged into `self.merged`, and a writer is
    created for `log_dir` (with the session graph if `write_graph`).
    """
    # NOTE(review): nesting below was reconstructed from a flattened
    # listing; `merge_all` is assumed to sit at method level (as in
    # upstream Keras) — confirm against the original file.
    self.model = model
    self.sess = K.get_session()
    # only make histogram summary op if it hasn't already been made
    if self.histogram_freq and self.merged is None:
        for layer in self.model.layers:
            for weight in layer.weights:
                # ':' is replaced so the tensor name is usable as a summary name.
                mapped_weight_name = weight.name.replace(':', '_')
                tf_summary.histogram(mapped_weight_name, weight)
                if self.write_images:
                    # Reshape the squeezed weight into a 4-D
                    # (batch, height, width, channels) tensor for an image summary.
                    w_img = array_ops.squeeze(weight)
                    shape = K.int_shape(w_img)
                    if len(shape) == 2:  # dense layer kernel case
                        # Transpose so the image is wider than it is tall.
                        if shape[0] > shape[1]:
                            w_img = array_ops.transpose(w_img)
                            shape = K.int_shape(w_img)
                        w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1])
                    elif len(shape) == 3:  # convnet case
                        # NOTE(review): assumes the last axis of a squeezed
                        # 3-D kernel is the per-kernel axis under
                        # channels_last — confirm.
                        if K.image_data_format() == 'channels_last':
                            # switch to channels_first to display
                            # every kernel as a separate image
                            w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
                            shape = K.int_shape(w_img)
                        w_img = array_ops.reshape(w_img,
                                                  [shape[0], shape[1], shape[2], 1])
                    elif len(shape) == 1:  # bias case
                        w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
                    else:
                        # not possible to handle 3D convnets etc.
                        continue
                    shape = K.int_shape(w_img)
                    # Image summaries accept 1 (grayscale), 3 (RGB) or 4 (RGBA) channels.
                    assert len(shape) == 4 and shape[-1] in [1, 3, 4]
                    tf_summary.image(mapped_weight_name, w_img)
            if self.write_grads:
                for weight in layer.trainable_weights:
                    mapped_weight_name = weight.name.replace(':', '_')
                    grads = model.optimizer.get_gradients(model.total_loss, weight)

                    def is_indexed_slices(grad):
                        # Matched by class name rather than isinstance.
                        return type(grad).__name__ == 'IndexedSlices'

                    # Sparse (IndexedSlices) gradients are replaced by their
                    # `.values` tensor before histogramming.
                    grads = [grad.values if is_indexed_slices(grad) else grad
                             for grad in grads]
                    tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)
            if hasattr(layer, 'output'):
                tf_summary.histogram('{}_out'.format(layer.name), layer.output)
    self.merged = tf_summary.merge_all()
    if self.write_graph:
        self.writer = self._writer_class(self.log_dir, self.sess.graph)
    else:
        self.writer = self._writer_class(self.log_dir)
开发者ID:LongJun123456,项目名称:tensorflow,代码行数:57,代码来源:callbacks.py
示例15: __init__
def __init__(self, rate, data_format=None, **kwargs):
    """Initialize the layer and validate its data format.

    Arguments:
        rate: dropout rate, passed to the parent initializer.
        data_format: 'channels_last' or 'channels_first'; None means
            `K.image_data_format()`.
        **kwargs: forwarded to the parent initializer.

    Raises:
        ValueError: if `data_format` is not one of the two accepted values.
    """
    super(SpatialDropout3D, self).__init__(rate, **kwargs)
    resolved = K.image_data_format() if data_format is None else data_format
    if resolved not in {'channels_last', 'channels_first'}:
        raise ValueError('data_format must be in '
                         '{"channels_last", "channels_first"}')
    self.data_format = resolved
    self.input_spec = InputSpec(ndim=5)  # expects 5-D inputs
开发者ID:yanchen036,项目名称:tensorflow,代码行数:9,代码来源:core.py
示例16: normalize_data_format
def normalize_data_format(value):
    """Validate an image data format string and return it lower-cased.

    Args:
        value: "channels_first" or "channels_last" in any letter case, or
            None to use `backend.image_data_format()`.

    Returns:
        The lower-cased data format string.

    Raises:
        ValueError: if the lower-cased value is not one of the two formats.
    """
    if value is None:
        value = backend.image_data_format()
    normalized = value.lower()
    if normalized in {'channels_first', 'channels_last'}:
        return normalized
    raise ValueError('The `data_format` argument must be one of '
                     '"channels_first", "channels_last". Received: ' +
                     str(value))
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:9,代码来源:conv_utils.py
示例17: identity_block
def identity_block(input_tensor, kernel_size, filters, stage, block):
    """Residual block whose shortcut is the identity (no conv on the shortcut).

    Main path: three Conv2D -> BatchNorm stages ('2a' 1x1, '2b'
    `kernel_size` with 'same' padding, '2c' 1x1), with ReLU after the
    first two. The result is added to `input_tensor` and passed through a
    final ReLU.

    # Arguments
        input_tensor: input tensor.
        kernel_size: default 3, the kernel size of the middle conv layer
            on the main path.
        filters: list of 3 integers, the filters of the 3 conv layers on
            the main path.
        stage: integer, current stage label, used for generating layer names.
        block: 'a', 'b'..., current block label, used for generating layer
            names.

    # Returns
        Output tensor for the block.
    """
    filters1, filters2, filters3 = filters
    bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    # (suffix, filters, kernel, padding); 'valid' is the Conv2D default.
    main_path = (('2a', filters1, (1, 1), 'valid'),
                 ('2b', filters2, kernel_size, 'same'),
                 ('2c', filters3, (1, 1), 'valid'))
    x = input_tensor
    for suffix, n_filters, size, padding in main_path:
        x = layers.Conv2D(n_filters, size,
                          padding=padding,
                          use_bias=False,
                          kernel_initializer='he_normal',
                          kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
                          name=conv_name_base + suffix)(x)
        x = layers.BatchNormalization(axis=bn_axis,
                                      momentum=BATCH_NORM_DECAY,
                                      epsilon=BATCH_NORM_EPSILON,
                                      name=bn_name_base + suffix)(x)
        # The last stage is activated only after the residual addition.
        if suffix != '2c':
            x = layers.Activation('relu')(x)

    x = layers.add([x, input_tensor])
    return layers.Activation('relu')(x)
开发者ID:Exscotticus,项目名称:models,代码行数:55,代码来源:resnet_model.py
示例18: _separable_conv_block
def _separable_conv_block(ip,
                          filters,
                          kernel_size=(3, 3),
                          strides=(1, 1),
                          block_id=None):
    """Add 2 blocks of [relu -> separable conv -> batchnorm].

    Only the first separable convolution uses `strides`; the second uses
    the layer's default stride. All layers are created inside the name
    scope 'separable_conv_block_<block_id>'.

    Arguments:
        ip: Input tensor.
        filters: Number of output filters per layer.
        kernel_size: Kernel size of the separable convolutions.
        strides: Strides of the first separable convolution (downsampling).
        block_id: String block_id used in layer names.

    Returns:
        A Keras tensor.
    """
    channel_dim = 1 if K.image_data_format() == 'channels_first' else -1

    def relu_sep_bn(tensor, conv_name, bn_name, **conv_kwargs):
        # One relu -> SeparableConv2D -> BatchNormalization unit.
        tensor = Activation('relu')(tensor)
        tensor = SeparableConv2D(
            filters,
            kernel_size,
            name=conv_name,
            padding='same',
            use_bias=False,
            kernel_initializer='he_normal',
            **conv_kwargs)(tensor)
        return BatchNormalization(
            axis=channel_dim,
            momentum=0.9997,
            epsilon=1e-3,
            name=bn_name)(tensor)

    with K.name_scope('separable_conv_block_%s' % block_id):
        x = relu_sep_bn(ip,
                        'separable_conv_1_%s' % block_id,
                        'separable_conv_1_bn_%s' % (block_id),
                        strides=strides)
        x = relu_sep_bn(x,
                        'separable_conv_2_%s' % block_id,
                        'separable_conv_2_bn_%s' % (block_id))
    return x
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:52,代码来源:nasnet.py
示例19: __init__
def __init__(self, pool_function, pool_size, strides,
             padding='valid', data_format='channels_last',
             name=None, **kwargs):
    """Configure a 1D pooling layer.

    Arguments:
        pool_function: pooling function stored on the layer.
        pool_size: integer or 1-tuple, normalized via `conv_utils`.
        strides: integer or 1-tuple; None means "same as `pool_size`".
        padding: padding mode, normalized via `conv_utils.normalize_padding`.
        data_format: data format; None means `backend.image_data_format()`.
        name: layer name.
        **kwargs: forwarded to the parent initializer.
    """
    super(Pooling1D, self).__init__(name=name, **kwargs)
    data_format = (backend.image_data_format()
                   if data_format is None else data_format)
    strides = pool_size if strides is None else strides
    self.pool_function = pool_function
    self.pool_size = conv_utils.normalize_tuple(pool_size, 1, 'pool_size')
    self.strides = conv_utils.normalize_tuple(strides, 1, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.input_spec = InputSpec(ndim=3)  # expects 3-D inputs
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:14,代码来源:pooling.py
示例20: get_input_datasets
def get_input_datasets(use_bfloat16=False):
    """Download the MNIST dataset and create train and eval dataset objects.

    Images are reshaped to `get_data_shape()` per example and scaled to
    [0, 1] as float32. Labels are one-hot encoded with `NUM_CLASSES`
    classes.

    Args:
        use_bfloat16: Boolean to determine if input should be cast to
            bfloat16 (otherwise float32) inside the dataset pipeline.

    Returns:
        `(train_ds, eval_ds)`. Each dataset is passed through
        `maybe_shard_dataset`, repeats indefinitely, and yields batches of
        64 `(image, one_hot_label)` pairs with incomplete batches dropped,
        so elements DO carry a leading batch dimension.
    """
    cast_dtype = dtypes.bfloat16 if use_bfloat16 else dtypes.float32

    def _make_dataset(features, labels):
        # Shard, repeat, cast features and batch one split of the data.
        ds = dataset_ops.Dataset.from_tensor_slices((features, labels))
        # TODO(rchao): Remove maybe_shard_dataset() once auto-sharding is done.
        ds = maybe_shard_dataset(ds)
        ds = ds.repeat()
        ds = ds.map(lambda x, y: (math_ops.cast(x, cast_dtype), y))
        return ds.batch(64, drop_remainder=True)

    # the data, split between train and test sets
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # get_data_shape() already orders the axes for the active image data
    # format, so a single reshape serves both channels_first and
    # channels_last (the former if/else had identical branches).
    x_train = x_train.reshape((x_train.shape[0],) + get_data_shape())
    x_test = x_test.reshape((x_test.shape[0],) + get_data_shape())
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    # convert class vectors to binary class matrices
    y_train = utils.to_categorical(y_train, NUM_CLASSES)
    y_test = utils.to_categorical(y_test, NUM_CLASSES)

    train_ds = _make_dataset(x_train, y_train)
    eval_ds = _make_dataset(x_test, y_test)
    return train_ds, eval_ds
开发者ID:aritratony,项目名称:tensorflow,代码行数:50,代码来源:mnist_multi_worker.py
注:本文中的tensorflow.python.keras.backend.image_data_format函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。 |
请发表评论