__init__.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import paddle
  16. import paddle.nn as nn
  17. # basic_loss
  18. from .basic_loss import LossFromOutput
  19. # det loss
  20. from .det_db_loss import DBLoss
  21. from .det_east_loss import EASTLoss
  22. from .det_sast_loss import SASTLoss
  23. from .det_pse_loss import PSELoss
  24. from .det_fce_loss import FCELoss
  25. from .det_ct_loss import CTLoss
  26. from .det_drrg_loss import DRRGLoss
  27. # rec loss
  28. from .rec_ctc_loss import CTCLoss
  29. from .rec_att_loss import AttentionLoss
  30. from .rec_srn_loss import SRNLoss
  31. from .rec_ce_loss import CELoss
  32. from .rec_sar_loss import SARLoss
  33. from .rec_aster_loss import AsterLoss
  34. from .rec_pren_loss import PRENLoss
  35. from .rec_multi_loss import MultiLoss
  36. from .rec_vl_loss import VLLoss
  37. from .rec_spin_att_loss import SPINAttentionLoss
  38. from .rec_rfl_loss import RFLLoss
  39. from .rec_can_loss import CANLoss
  40. # cls loss
  41. from .cls_loss import ClsLoss
  42. # e2e loss
  43. from .e2e_pg_loss import PGLoss
  44. from .kie_sdmgr_loss import SDMGRLoss
  45. # basic loss function
  46. from .basic_loss import DistanceLoss
  47. # combined loss function
  48. from .combined_loss import CombinedLoss
  49. # table loss
  50. from .table_att_loss import TableAttentionLoss, SLALoss
  51. from .table_master_loss import TableMasterLoss
  52. # vqa token loss
  53. from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
  54. # sr loss
  55. from .stroke_focus_loss import StrokeFocusLoss
  56. from .text_focus_loss import TelescopeLoss
  57. def build_loss(config):
  58. support_dict = [
  59. 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
  60. 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
  61. 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
  62. 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
  63. 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
  64. 'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss', 'TelescopeLoss'
  65. ]
  66. config = copy.deepcopy(config)
  67. module_name = config.pop('name')
  68. assert module_name in support_dict, Exception('loss only support {}'.format(
  69. support_dict))
  70. module_class = eval(module_name)(**config)
  71. return module_class