combined_loss.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn

from .rec_ctc_loss import CTCLoss
from .center_loss import CenterLoss
from .ace_loss import ACELoss
from .rec_sar_loss import SARLoss

from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationSARLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
from .distillation_loss import DistillationVQASerTokenLayoutLMLoss, DistillationSERDMLLoss
from .distillation_loss import DistillationLossFromOutput
from .distillation_loss import DistillationVQADistanceLoss

class CombinedLoss(nn.Layer):
    """
    CombinedLoss:
        a combination of loss functions
    """

    def __init__(self, loss_config_list=None):
        super().__init__()
        self.loss_func = []
        self.loss_weight = []
        assert isinstance(loss_config_list, list), (
            'operator config should be a list')
        for config in loss_config_list:
            assert isinstance(
                config, dict) and len(config) == 1, "yaml format error"
            name = list(config)[0]
            param = config[name]
            assert "weight" in param, \
                "weight must be in param, but param only contains {}".format(
                    param.keys())
            self.loss_weight.append(param.pop("weight"))
            # instantiate the loss class named in the config with the
            # remaining parameters
            self.loss_func.append(eval(name)(**param))

    def forward(self, input, batch, **kargs):
        loss_dict = {}
        loss_all = 0.
        for idx, loss_func in enumerate(self.loss_func):
            loss = loss_func(input, batch, **kargs)
            if isinstance(loss, paddle.Tensor):
                # wrap a bare tensor into a dict so every loss is handled
                # uniformly below
                loss = {"loss_{}_{}".format(str(loss), idx): loss}
            weight = self.loss_weight[idx]
            # apply the configured weight to every entry of this loss
            loss = {key: loss[key] * weight for key in loss}
            if "loss" in loss:
                loss_all += loss["loss"]
            else:
                loss_all += paddle.add_n(list(loss.values()))
            loss_dict.update(loss)
        loss_dict["loss"] = loss_all
        return loss_dict
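
A minimal usage sketch (not part of the original file): `loss_config_list` mirrors the loss section of a training YAML, i.e. a list of single-key dicts mapping a loss class name to its constructor parameters, which must include a `weight` entry. The weights and loss choices below are illustrative assumptions, not a recommended configuration.

# Hypothetical example: combine a CTC loss and a SAR loss with equal
# weights. Both classes are imported at the top of this file; here they
# are assumed to be constructible with default arguments.
loss_config_list = [
    {"CTCLoss": {"weight": 1.0}},
    {"SARLoss": {"weight": 1.0}},
]
combined = CombinedLoss(loss_config_list=loss_config_list)

# combined(input, batch) returns a dict with one weighted entry per
# sub-loss plus their sum under the key "loss"; `input` and `batch` are
# whatever the configured sub-losses expect from the model and dataloader.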