I cannot train my model with a custom loss function when batch_size > 1 because of tensor shape errors.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K

# current_timestep_matrix and output_train are defined elsewhere

class API_Network(object):
    def __init__(self):
        self.model = self.build_model()

    def build_model(self):
        # AutoEncoder model
        input_img = keras.Input((len(current_timestep_matrix),
                                 len(current_timestep_matrix[0]),
                                 len(current_timestep_matrix[0][0])))
        encoded = layers.Dense(8, activation='tanh')(input_img)
        encoded = layers.AveragePooling2D(pool_size=(2, 2))(encoded)
        encoded = layers.Dense(16, activation='tanh')(encoded)
        encoded = layers.AveragePooling2D(pool_size=(2, 2))(encoded)
        encoded = layers.Dense(1, activation='tanh')(encoded)
        encoded = layers.Flatten()(encoded)
        decoded = layers.Dense(len(output_train[0]), activation='linear')(encoded)

        model = keras.Model(input_img, decoded)
        model.compile(loss=my_forward_div_loss_fn(input_img),
                      optimizer=keras.optimizers.Adam())
        model.summary()
        return model


def my_forward_div_loss_fn(input_img):
    def loss(y_true, y_pred):
        batch_size, width, height, channels = input_img.get_shape().as_list()
        u_field = K.abs(y_true - y_pred)
        v_field = input_img[:, :, :, 0]
        x_field = input_img[:, :, :, 1]
        y_field = input_img[:, :, :, 2]
        # reshape the prediction error back into a single (width, height) image
        u_field = K.reshape(u_field, shape=[width, height])
        u_field = K.expand_dims(u_field, 0)
        # forward finite differences of u w.r.t. x and of v w.r.t. y
        forward_dudx = ((u_field[batch_size, 1:-1, 2:] - u_field[batch_size, 1:-1, 1:-1])
                        / (x_field[batch_size, 1:-1, 2:] - x_field[batch_size, 1:-1, 1:-1]))
        forward_dvdy = ((v_field[batch_size, 1:-1, 1:-1] - v_field[batch_size, 2:, 1:-1])
                        / (y_field[batch_size, :-2, 1:-1] - y_field[batch_size, 2:, 1:-1]))
        forward_divergence = forward_dudx + forward_dvdy  # computes forward divergence
        forward_divergence = tf.where(tf.math.is_nan(forward_divergence),
                                      tf.zeros_like(forward_divergence),
                                      forward_divergence)
        return K.square(forward_divergence)
    return loss
When I change the batch size used for training (e.g. to batch_size=2), I get the following error:
InvalidArgumentError: Input to reshape is a tensor with 55778 values, but the requested shape has 27889 [[{{node loss/dense_3_loss/Reshape}}]]
The model compiles fine, and I notice that 55778 is exactly twice the requested 27889, so the tensor being reshaped apparently still carries the batch dimension while my reshape target does not. How can I get past this issue when training the network?
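For reference, here is an untested batch-aware variant of the loss I have been sketching. It assumes the only change needed is to keep the batch dimension through the reshape (using -1) and to slice over all samples with : instead of indexing with batch_size; I am not sure this is the right approach:

def my_forward_div_loss_fn_batched(input_img):
    def loss(y_true, y_pred):
        # batch dimension is left dynamic; only the spatial sizes are used
        _, width, height, channels = input_img.get_shape().as_list()
        u_field = K.abs(y_true - y_pred)
        v_field = input_img[:, :, :, 0]
        x_field = input_img[:, :, :, 1]
        y_field = input_img[:, :, :, 2]
        # -1 keeps the batch dimension instead of folding it into width * height
        u_field = K.reshape(u_field, shape=[-1, width, height])
        # slice over all samples with ':' rather than indexing with batch_size
        forward_dudx = ((u_field[:, 1:-1, 2:] - u_field[:, 1:-1, 1:-1])
                        / (x_field[:, 1:-1, 2:] - x_field[:, 1:-1, 1:-1]))
        forward_dvdy = ((v_field[:, 1:-1, 1:-1] - v_field[:, 2:, 1:-1])
                        / (y_field[:, :-2, 1:-1] - y_field[:, 2:, 1:-1]))
        forward_divergence = forward_dudx + forward_dvdy
        forward_divergence = tf.where(tf.math.is_nan(forward_divergence),
                                      tf.zeros_like(forward_divergence),
                                      forward_divergence)
        return K.square(forward_divergence)
    return loss

Is keeping the batch dimension like this the correct way to write the loss, or is there something else I am missing?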