# det_metric.py (5.7 KB)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. __all__ = ['DetMetric', 'DetFCEMetric']
  18. from .eval_det_iou import DetectionIoUEvaluator
  19. class DetMetric(object):
  20. def __init__(self, main_indicator='hmean', **kwargs):
  21. self.evaluator = DetectionIoUEvaluator()
  22. self.main_indicator = main_indicator
  23. self.reset()
  24. def __call__(self, preds, batch, **kwargs):
  25. '''
  26. batch: a list produced by dataloaders.
  27. image: np.ndarray of shape (N, C, H, W).
  28. ratio_list: np.ndarray of shape(N,2)
  29. polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
  30. ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
  31. preds: a list of dict produced by post process
  32. points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
  33. '''
  34. gt_polyons_batch = batch[2]
  35. ignore_tags_batch = batch[3]
  36. for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch,
  37. ignore_tags_batch):
  38. # prepare gt
  39. gt_info_list = [{
  40. 'points': gt_polyon,
  41. 'text': '',
  42. 'ignore': ignore_tag
  43. } for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)]
  44. # prepare det
  45. det_info_list = [{
  46. 'points': det_polyon,
  47. 'text': ''
  48. } for det_polyon in pred['points']]
  49. result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
  50. self.results.append(result)
  51. def get_metric(self):
  52. """
  53. return metrics {
  54. 'precision': 0,
  55. 'recall': 0,
  56. 'hmean': 0
  57. }
  58. """
  59. metrics = self.evaluator.combine_results(self.results)
  60. self.reset()
  61. return metrics
  62. def reset(self):
  63. self.results = [] # clear results
  64. class DetFCEMetric(object):
  65. def __init__(self, main_indicator='hmean', **kwargs):
  66. self.evaluator = DetectionIoUEvaluator()
  67. self.main_indicator = main_indicator
  68. self.reset()
  69. def __call__(self, preds, batch, **kwargs):
  70. '''
  71. batch: a list produced by dataloaders.
  72. image: np.ndarray of shape (N, C, H, W).
  73. ratio_list: np.ndarray of shape(N,2)
  74. polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
  75. ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
  76. preds: a list of dict produced by post process
  77. points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
  78. '''
  79. gt_polyons_batch = batch[2]
  80. ignore_tags_batch = batch[3]
  81. for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch,
  82. ignore_tags_batch):
  83. # prepare gt
  84. gt_info_list = [{
  85. 'points': gt_polyon,
  86. 'text': '',
  87. 'ignore': ignore_tag
  88. } for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)]
  89. # prepare det
  90. det_info_list = [{
  91. 'points': det_polyon,
  92. 'text': '',
  93. 'score': score
  94. } for det_polyon, score in zip(pred['points'], pred['scores'])]
  95. for score_thr in self.results.keys():
  96. det_info_list_thr = [
  97. det_info for det_info in det_info_list
  98. if det_info['score'] >= score_thr
  99. ]
  100. result = self.evaluator.evaluate_image(gt_info_list,
  101. det_info_list_thr)
  102. self.results[score_thr].append(result)
  103. def get_metric(self):
  104. """
  105. return metrics {'heman':0,
  106. 'thr 0.3':'precision: 0 recall: 0 hmean: 0',
  107. 'thr 0.4':'precision: 0 recall: 0 hmean: 0',
  108. 'thr 0.5':'precision: 0 recall: 0 hmean: 0',
  109. 'thr 0.6':'precision: 0 recall: 0 hmean: 0',
  110. 'thr 0.7':'precision: 0 recall: 0 hmean: 0',
  111. 'thr 0.8':'precision: 0 recall: 0 hmean: 0',
  112. 'thr 0.9':'precision: 0 recall: 0 hmean: 0',
  113. }
  114. """
  115. metrics = {}
  116. hmean = 0
  117. for score_thr in self.results.keys():
  118. metric = self.evaluator.combine_results(self.results[score_thr])
  119. # for key, value in metric.items():
  120. # metrics['{}_{}'.format(key, score_thr)] = value
  121. metric_str = 'precision:{:.5f} recall:{:.5f} hmean:{:.5f}'.format(
  122. metric['precision'], metric['recall'], metric['hmean'])
  123. metrics['thr {}'.format(score_thr)] = metric_str
  124. hmean = max(hmean, metric['hmean'])
  125. metrics['hmean'] = hmean
  126. self.reset()
  127. return metrics
  128. def reset(self):
  129. self.results = {
  130. 0.3: [],
  131. 0.4: [],
  132. 0.5: [],
  133. 0.6: [],
  134. 0.7: [],
  135. 0.8: [],
  136. 0.9: []
  137. } # clear results