To apply softmax and use a cross entropy loss, you have to keep the final output of your network intact, of shape batch_size x 256 x 256 x 33. Therefore you cannot use mean averaging or argmax, because that would destroy the output probabilities of your network.
You have to loop over all the batch_size x 256 x 256 pixels and apply a cross entropy loss to the prediction for each pixel. This is easy with the built-in function tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels).
Some warnings from the doc before applying the code below:
- WARNING: This op expects unscaled logits, since it performs a softmax on logits internally for efficiency. Do not call this op with the output of softmax, as it will produce incorrect results.
- logits must have the shape [batch_size, num_classes] and dtype float32 or float64.
- labels must have the shape [batch_size] and the dtype int64.
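To make the first warning concrete, here is a minimal sketch (reusing the logits and labels names from the code further down) of what to pass and what not to pass:
# correct: pass the raw, unscaled logits straight to the op
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels)
# wrong: do not softmax first, the op already applies softmax internally
# probs = tf.nn.softmax(logits)
# loss = tf.nn.sparse_softmax_cross_entropy_with_logits(probs, labels)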
The trick is to use batch_size * 256 * 256 as the batch size required by the function. We will reshape logits and labels to this format.
Here is the code I use:
inputs = tf.placeholder(tf.float32, [batch_size, 256, 256, 3]) # input images
logits = inference(inputs) # your outputs of shape [batch_size, 256, 256, 33] (no final softmax !!)
labels = tf.placeholder(tf.int64, [batch_size, 256, 256]) # your labels of shape [batch_size, 256, 256] and type int64
reshaped_logits = tf.reshape(logits, [-1, 33]) # shape [batch_size*256*256, 33]
reshaped_labels = tf.reshape(labels, [-1]) # shape [batch_size*256*256]
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(reshaped_logits, reshaped_labels) # per-pixel losses, shape [batch_size*256*256]
loss = tf.reduce_mean(loss) # reduce to a scalar before passing it to the optimizer
You can then apply your optimizer on that loss.
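For completeness, a minimal sketch of that last step; the choice of Adam and the learning rate are mine, any optimizer works:
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)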
Update: v0.10
The documentation of tf.nn.sparse_softmax_cross_entropy_with_logits shows that it now accepts logits of any shape [d_0, ..., d_{r-1}, num_classes] with labels of shape [d_0, ..., d_{r-1}], so there is no need to reshape the tensors (thanks @chillinger):
inputs = tf.placeholder(tf.float32, [batch_size, 256, 256, 3]) # input images
logits = inference(inputs) # your outputs of shape [batch_size, 256, 256, 33] (no final softmax !!)
labels = tf.placeholder(tf.int64, [batch_size, 256, 256]) # your labels of shape [batch_size, 256, 256] and type int64
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels) # per-pixel losses of shape [batch_size, 256, 256]
loss = tf.reduce_mean(loss) # reduce to a scalar before passing it to the optimizer
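If it helps, here is a hedged sketch of a full training step with the updated code, assuming dummy NumPy arrays for the images and the per-pixel labels (the array names and the Adam optimizer are my own choices, not part of the original answer):
import numpy as np

train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss) # same training op as above

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables()) # tf.global_variables_initializer() in newer versions
    images = np.random.rand(batch_size, 256, 256, 3).astype(np.float32) # dummy input batch
    gt = np.random.randint(0, 33, size=(batch_size, 256, 256)) # dummy int64 labels in [0, 33)
    _, loss_value = sess.run([train_op, loss], feed_dict={inputs: images, labels: gt})
    print(loss_value)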