ct_metric.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. from scipy import io
  19. import numpy as np
  20. from ppocr.utils.e2e_metric.Deteval import combine_results, get_score_C
  class CTMetric(object):
      """Collect per-sample detection scores and aggregate them on demand.

      Each call scores one sample with ``get_score_C`` and stores the
      result; ``get_metric`` merges everything stored so far with
      ``combine_results`` and then clears the store.
      """

      def __init__(self, main_indicator, delimiter='\t', **kwargs):
          """Store the configuration and start with an empty result list.

          Args:
              main_indicator: Kept on the instance as ``main_indicator``;
                  presumably the metric name the caller tracks -- confirm
                  against the training/eval loop.
              delimiter: Kept on the instance as ``delimiter``; not used by
                  any method of this class.
              **kwargs: Accepted and ignored.
          """
          self.main_indicator = main_indicator
          self.delimiter = delimiter
          self.reset()

      def reset(self):
          """Drop every accumulated result by binding a fresh empty list."""
          self.results = list()

      def __call__(self, preds, batch, **kwargs):
          """Score one sample and append the score to ``self.results``.

          Args:
              preds: Sequence holding exactly one prediction; the first
                  element must provide a ``'points'`` entry.
              batch: Indexable batch; items 2 and 3 are forwarded to
                  ``get_score_C`` as its first and second arguments
                  (label and text in the scoring helper's terms).
              **kwargs: Accepted and ignored.

          Raises:
              AssertionError: If ``preds`` does not contain exactly one item.
          """
          # Only a batch size of 1 is supported: labels of different samples
          # have unequal lengths, so they cannot be batched together.
          assert len(preds) == 1, \
              "CentripetalText test now only suuport batch_size=1."

          gt_label = batch[2]
          gt_text = batch[3]
          pred_points = preds[0]['points']

          score = get_score_C(gt_label, gt_text, pred_points)
          self.results.append(score)

      def get_metric(self):
          """Aggregate the stored results, clear the store, return the metrics.

          The aggregation is delegated to ``combine_results`` with
          ``rec_flag=False``. The store is reset only after the aggregation
          has completed.

          Returns:
              Whatever ``combine_results`` returns for the collected results.
          """
          combined = combine_results(self.results, rec_flag=False)
          self.reset()
          return combined