rec_srn_loss.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. from paddle import nn
  19. class SRNLoss(nn.Layer):
  20. def __init__(self, **kwargs):
  21. super(SRNLoss, self).__init__()
  22. self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum")
  23. def forward(self, predicts, batch):
  24. predict = predicts['predict']
  25. word_predict = predicts['word_out']
  26. gsrm_predict = predicts['gsrm_out']
  27. label = batch[1]
  28. casted_label = paddle.cast(x=label, dtype='int64')
  29. casted_label = paddle.reshape(x=casted_label, shape=[-1, 1])
  30. cost_word = self.loss_func(word_predict, label=casted_label)
  31. cost_gsrm = self.loss_func(gsrm_predict, label=casted_label)
  32. cost_vsfd = self.loss_func(predict, label=casted_label)
  33. cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1])
  34. cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
  35. cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
  36. sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
  37. return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}