rec_ce_loss.py

import paddle
from paddle import nn
import paddle.nn.functional as F


class CELoss(nn.Layer):
    def __init__(self,
                 smoothing=False,
                 with_all=False,
                 ignore_index=-1,
                 **kwargs):
        super(CELoss, self).__init__()
        if ignore_index >= 0:
            self.loss_func = nn.CrossEntropyLoss(
                reduction='mean', ignore_index=ignore_index)
        else:
            self.loss_func = nn.CrossEntropyLoss(reduction='mean')
        self.smoothing = smoothing
        self.with_all = with_all

    def forward(self, pred, batch):
        if isinstance(pred, dict):  # for ABINet
            loss = {}
            loss_sum = []
            for name, logits in pred.items():
                if isinstance(logits, list):
                    # A head may emit one logits tensor per refinement
                    # iteration; tile the targets to match before flattening.
                    logit_num = len(logits)
                    all_tgt = paddle.concat([batch[1]] * logit_num, 0)
                    all_logits = paddle.concat(logits, 0)
                    flt_logits = all_logits.reshape([-1, all_logits.shape[2]])
                    flt_tgt = all_tgt.reshape([-1])
                else:
                    flt_logits = logits.reshape([-1, logits.shape[2]])
                    flt_tgt = batch[1].reshape([-1])
                loss[name + '_loss'] = self.loss_func(flt_logits, flt_tgt)
                loss_sum.append(loss[name + '_loss'])
            loss['loss'] = sum(loss_sum)
            return loss
        else:
            if self.with_all:  # for ViTSTR
                tgt = batch[1]
                pred = pred.reshape([-1, pred.shape[2]])
                tgt = tgt.reshape([-1])
                loss = self.loss_func(pred, tgt)
                return {'loss': loss}
            else:  # for NRTR
                max_len = batch[2].max()
                # Drop the start token and keep max_len + 1 steps so the
                # end token is still covered by the loss.
                tgt = batch[1][:, 1:2 + max_len]
                pred = pred.reshape([-1, pred.shape[2]])
                tgt = tgt.reshape([-1])
                if self.smoothing:
                    # Manual label smoothing: keep 1 - eps on the target
                    # class and spread eps over the remaining classes.
                    eps = 0.1
                    n_class = pred.shape[1]
                    one_hot = F.one_hot(tgt, n_class)
                    one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (
                        n_class - 1)
                    log_prb = F.log_softmax(pred, axis=1)
                    # Padding positions carry label 0; mask them out so
                    # they do not contribute to the mean.
                    non_pad_mask = paddle.not_equal(
                        tgt, paddle.zeros(
                            tgt.shape, dtype=tgt.dtype))
                    loss = -(one_hot * log_prb).sum(axis=1)
                    loss = loss.masked_select(non_pad_mask).mean()
                else:
                    loss = self.loss_func(pred, tgt)
                return {'loss': loss}
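
A minimal usage sketch for the NRTR branch, assuming the conventions the code implies: `pred` is a `[batch, steps, num_classes]` logits tensor with one step per position after the start token, `batch[1]` holds padded target ids of width `max_len + 2` (start token, characters, end token), and `batch[2]` holds the unpadded label lengths. All concrete shapes and values below are illustrative assumptions, not from the original file.

import paddle

# Hypothetical sizes for illustration: 4 samples, longest label of
# 6 characters, 38-symbol charset.
B, max_len, C = 4, 6, 38
pred = paddle.randn([B, max_len + 1, C])   # one logit row per decode step
# Padded targets of width max_len + 2; ids drawn from [1, C) so that
# 0 stays reserved for padding, matching the non_pad_mask above.
tgt = paddle.randint(1, C, shape=[B, max_len + 2])
# Unpadded label lengths; their max drives the target slice in forward().
lengths = paddle.to_tensor([5, 6, 4, 6], dtype='int64')

loss_fn = CELoss(smoothing=True)
# batch[0] (the input images) is never read by the loss, so None stands in.
out = loss_fn(pred, [None, tgt, lengths])
print(out['loss'].item())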