Keras custom loss calculation is incorrect
I’m trying to use a custom loss function in Keras. My implementation looks like this:
class LossFunction:
    ...
    def loss(self, y_true, y_pred):
        ...
        localization_loss = self._localization_loss()
        confidence_loss = self._object_confidence_loss()
        category_loss = self._category_loss()
        self.loc_loss = localization_loss
        self.obj_conf_loss = confidence_loss
        self.category_loss = category_loss
        tot_loss = localization_loss + confidence_loss + category_loss
        self.tot_loss = tot_loss
        return tot_loss
Then I define custom metrics that read back the stored tensors, for example:
class MetricContainer:
    def __init__(self, loss_obj):
        self.loss = loss_obj

    def local_loss(self, y_true, y_pred):
        return self.loss.loc_loss

    def confidence_loss(self, y_true, y_pred):
        return self.loss.obj_conf_loss

    def category_loss(self, y_true, y_pred):
        return self.loss.category_loss

    def tot_loss(self, y_true, y_pred):
        return self.loss.tot_loss
Then I compile my model with this command:
model.compile('adam',
              loss=loss_obj.loss,
              metrics=[metric_container.local_loss,
                       metric_container.confidence_loss,
                       metric_container.category_loss,
                       metric_container.tot_loss])
When I train the model (on a very small training set), I get the following output:
Epoch 1/2
1/2 [==============>...............] - ETA: 76s - loss: 482.6910 - category_loss: 28.1100 - confidence_loss: 439.9192 - local_loss: 13.1180 - tot_loss: 481.1472
2/2 [==============================] - 96s - loss: 324.6292 - category_loss: 18.1967 - confidence_loss: 296.0593 - local_loss: 8.8204 - tot_loss: 323.0764 - val_loss: 408.1170 - val_category_loss: 0.0000e+00 - val_confidence_loss: 400.0000 - val_local_loss: 6.5036 - val_tot_loss: 406.5036
For some reason, tot_loss and loss don't match, even though both should be the same value.
Any idea why this is the case? Does Keras modify the loss after my function returns it?
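The logged numbers can be checked with a few lines of plain Python. The values below are copied from the output above (the row labels are my own). The sketch confirms that tot_loss is exactly the sum of the three component metrics, and that loss exceeds it by roughly 1.5 to 1.6 in every row, which points to an extra term added outside the custom loss function.

```python
# Values copied from the training log above:
# (loss, category_loss, confidence_loss, local_loss, tot_loss)
rows = {
    "epoch 1, batch 1": (482.6910, 28.1100, 439.9192, 13.1180, 481.1472),
    "epoch 1, batch 2": (324.6292, 18.1967, 296.0593, 8.8204, 323.0764),
    "validation": (408.1170, 0.0, 400.0000, 6.5036, 406.5036),
}

gaps = {}
for name, (loss, cat, conf, loc, tot) in rows.items():
    # tot_loss should equal the sum of its three components (up to log rounding)
    assert abs((cat + conf + loc) - tot) < 1e-3, name
    # Gap between the loss Keras reports and the custom tot_loss metric
    gaps[name] = round(loss - tot, 4)

for name, gap in gaps.items():
    print(f"{name}: loss - tot_loss = {gap}")
```

Because the components add up exactly, the custom metrics are internally consistent; whatever Keras adds to loss comes from somewhere else.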
Solution
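The loss that Keras reports is the value returned by your loss function plus any regularization terms. So if your model uses any kind of regularization (for example a kernel_regularizer on a layer), those penalties are added to loss, while your tot_loss metric only returns the value computed inside your loss function. That is why the two numbers differ.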
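To make the arithmetic concrete, here is a plain-Python model of that rule (not Keras's actual code). It assumes a single layer with an L2 kernel regularizer whose penalty is factor * sum(w ** 2), with an assumed factor of 0.01; the weights and the tot_loss value are hypothetical numbers chosen for illustration, and l2_penalty and reported_loss are hypothetical helper names.

```python
# Plain-Python model of how the reported "loss" is formed (not Keras code).
# Assumption: one layer with an L2 kernel regularizer, penalty = l2 * sum(w ** 2).


def l2_penalty(weights, l2=0.01):
    """L2 regularization penalty for a flat list of weights."""
    return l2 * sum(w * w for w in weights)


def reported_loss(custom_loss_value, penalties):
    """The logged loss: the compiled loss value plus every regularization penalty."""
    return custom_loss_value + sum(penalties)


# Hypothetical numbers for illustration only
kernel = [3.0, -4.0, 12.0]       # sum of squares = 9 + 16 + 144 = 169
custom_total = 481.0             # what the tot_loss metric would show
penalty = l2_penalty(kernel)     # 0.01 * 169 = 1.69
loss = reported_loss(custom_total, [penalty])

print(f"tot_loss metric: {custom_total}")
print(f"regularization:  {penalty:.2f}")
print(f"reported loss:   {loss:.2f}")
```

Removing the regularizers (or setting their factor to 0) should make loss and tot_loss agree. In recent Keras versions the individual penalty tensors are exposed as model.losses, which is a convenient place to inspect them.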