postprocess_op.h 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. // Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
  15. #include "include/clipper.h"
  16. #include "include/utility.h"
  17. namespace PaddleOCR {
  18. class DBPostProcessor {
  19. public:
  20. void GetContourArea(const std::vector<std::vector<float>> &box,
  21. float unclip_ratio, float &distance);
  22. cv::RotatedRect UnClip(std::vector<std::vector<float>> box,
  23. const float &unclip_ratio);
  24. float **Mat2Vec(cv::Mat mat);
  25. std::vector<std::vector<int>>
  26. OrderPointsClockwise(std::vector<std::vector<int>> pts);
  27. std::vector<std::vector<float>> GetMiniBoxes(cv::RotatedRect box,
  28. float &ssid);
  29. float BoxScoreFast(std::vector<std::vector<float>> box_array, cv::Mat pred);
  30. float PolygonScoreAcc(std::vector<cv::Point> contour, cv::Mat pred);
  31. std::vector<std::vector<std::vector<int>>>
  32. BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
  33. const float &box_thresh, const float &det_db_unclip_ratio,
  34. const std::string &det_db_score_mode);
  35. std::vector<std::vector<std::vector<int>>>
  36. FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
  37. float ratio_h, float ratio_w, cv::Mat srcimg);
  38. private:
  39. static bool XsortInt(std::vector<int> a, std::vector<int> b);
  40. static bool XsortFp32(std::vector<float> a, std::vector<float> b);
  41. std::vector<std::vector<float>> Mat2Vector(cv::Mat mat);
  42. inline int _max(int a, int b) { return a >= b ? a : b; }
  43. inline int _min(int a, int b) { return a >= b ? b : a; }
  44. template <class T> inline T clamp(T x, T min, T max) {
  45. if (x > max)
  46. return max;
  47. if (x < min)
  48. return min;
  49. return x;
  50. }
  51. inline float clampf(float x, float min, float max) {
  52. if (x > max)
  53. return max;
  54. if (x < min)
  55. return min;
  56. return x;
  57. }
  58. };
  // Post-processing stage for table-structure predictions. Both public
  // members are declared here and defined elsewhere.
  class TablePostProcessor {
  public:
    // Set-up call taking a label file path; `merge_no_span_structure`
    // defaults to true.
    // NOTE(review): presumably fills `label_list_` from `label_path` --
    // confirm against the definition.
    void init(std::string label_path, bool merge_no_span_structure = true);

    // All nine arguments are non-const references, so the signature alone
    // does not show which are inputs and which are outputs.
    // NOTE(review): the `rec_*` arguments look like the outputs -- confirm.
    void Run(std::vector<float> &loc_preds,
             std::vector<float> &structure_probs,
             std::vector<float> &rec_scores,
             std::vector<int> &loc_preds_shape,
             std::vector<int> &structure_probs_shape,
             std::vector<std::vector<std::string>> &rec_html_tag_batch,
             std::vector<std::vector<std::vector<int>>> &rec_boxes_batch,
             std::vector<int> &width_list,
             std::vector<int> &height_list);

  private:
    std::vector<std::string> label_list_;
    // Marker strings; presumably the end / begin sequence tokens -- confirm.
    std::string end{"eos"};
    std::string beg{"sos"};
  };
  73. class PicodetPostProcessor {
  74. public:
  75. void init(std::string label_path, const double score_threshold = 0.4,
  76. const double nms_threshold = 0.5,
  77. const std::vector<int> &fpn_stride = {8, 16, 32, 64});
  78. void Run(std::vector<StructurePredictResult> &results,
  79. std::vector<std::vector<float>> outs, std::vector<int> ori_shape,
  80. std::vector<int> resize_shape, int eg_max);
  81. std::vector<int> fpn_stride_ = {8, 16, 32, 64};
  82. private:
  83. StructurePredictResult disPred2Bbox(std::vector<float> bbox_pred, int label,
  84. float score, int x, int y, int stride,
  85. std::vector<int> im_shape, int reg_max);
  86. void nms(std::vector<StructurePredictResult> &input_boxes,
  87. float nms_threshold);
  88. std::vector<std::string> label_list_;
  89. double score_threshold_ = 0.4;
  90. double nms_threshold_ = 0.5;
  91. int num_class_ = 5;
  92. };
  93. } // namespace PaddleOCR