Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

tensorflow - How does the Flatten layer work in Keras?

I am using the TensorFlow backend.

I am applying a convolution, max-pooling, flatten and a dense layer sequentially. The convolution requires a 3D input (height, width, color_channels_depth).

After the convolution, this becomes (height, width, Number_of_filters).

After max-pooling, the height and width change. But what exactly happens after applying the flatten layer? For example, if the input before flatten is (24, 24, 32), how does it get flattened?

Is it sequential, like (24 * 24) values of height and width for each filter in turn, or is it done some other way? An example with actual values would be appreciated.

1 Answer

The Flatten() operator unrolls the values beginning at the last dimension. (I verified this only with Theano, which is "channels first" rather than "channels last" like TensorFlow; I can't run TensorFlow in my environment.) This is equivalent to numpy.reshape with 'C' ordering:

‘C’ means to read / write the elements using C-like index order, with the last axis index changing fastest, back to the first axis index changing slowest.
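
To make the ordering concrete, here is a minimal NumPy-only sketch that contrasts 'C' order with 'F' order on a small array. The shape (2, 3, 4) and the variable names are illustrative assumptions, not part of the original answer.

```python
import numpy as np

# Small illustrative array; think of it as (height, width, channels).
H, W, C = 2, 3, 4
x = np.arange(H * W * C).reshape(H, W, C)

# 'C' order: the last axis changes fastest, the first axis slowest.
flat_c = x.reshape(-1, order="C")

# 'F' order, shown only for contrast: the first axis changes fastest.
flat_f = x.reshape(-1, order="F")

print(flat_c[:5])  # consecutive entries walk along the last axis
print(flat_f[:5])  # consecutive entries walk along the first axis
```

With 'C' order, neighbouring entries of the flattened vector come from neighbouring positions along the last axis, which is the behaviour described above for Flatten().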

Here is a standalone example illustrating the Flatten operator with the Keras functional API. You should be able to adapt it easily to your environment.

import numpy as np
from keras.layers import Input, Flatten
from keras.models import Model

# Each sample has shape (3, 2, 4); the batch dimension is implicit
inputs = Input(shape=(3, 2, 4))

# Define a model consisting only of the Flatten operation
prediction = Flatten()(inputs)
model = Model(inputs=inputs, outputs=prediction)

X = np.arange(0,24).reshape(1,3,2,4)
print(X)
#[[[[ 0  1  2  3]
#   [ 4  5  6  7]]
#
#  [[ 8  9 10 11]
#   [12 13 14 15]]
#
#  [[16 17 18 19]
#   [20 21 22 23]]]]
model.predict(X)
#array([[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
#         11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,
#         22.,  23.]], dtype=float32)
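
Applied to the (24, 24, 32) shape from the question, the same rule can be sketched with NumPy alone. This assumes that Flatten behaves like a 'C'-order reshape, as described above, and that the data is laid out as (height, width, filters); the helper `flat_index` is a hypothetical name introduced only for this illustration.

```python
import numpy as np

# Shape from the question: (height, width, number_of_filters).
H, W, F = 24, 24, 32
feature_map = np.arange(H * W * F).reshape(H, W, F)

# NumPy's default reshape order is 'C'.
flat = feature_map.reshape(-1)

def flat_index(h, w, f):
    """Position of element (h, w, f) in the flattened vector (hypothetical helper)."""
    return (h * W + w) * F + f

# The first 32 entries are the 32 filter values at pixel (0, 0),
# the next 32 are the filter values at pixel (0, 1), and so on.
first_pixel = flat[:F]
second_pixel = flat[F:2 * F]
```

Under these assumptions the vector is not 24 * 24 values for filter 0 followed by 24 * 24 values for filter 1. Instead, all 32 filter values of one spatial position are stored together before moving on to the next position.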
