collate_fn.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import numbers
  16. import numpy as np
  17. from collections import defaultdict
  18. class DictCollator(object):
  19. """
  20. data batch
  21. """
  22. def __call__(self, batch):
  23. # todo:support batch operators
  24. data_dict = defaultdict(list)
  25. to_tensor_keys = []
  26. for sample in batch:
  27. for k, v in sample.items():
  28. if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
  29. if k not in to_tensor_keys:
  30. to_tensor_keys.append(k)
  31. data_dict[k].append(v)
  32. for k in to_tensor_keys:
  33. data_dict[k] = paddle.to_tensor(data_dict[k])
  34. return data_dict
  35. class ListCollator(object):
  36. """
  37. data batch
  38. """
  39. def __call__(self, batch):
  40. # todo:support batch operators
  41. data_dict = defaultdict(list)
  42. to_tensor_idxs = []
  43. for sample in batch:
  44. for idx, v in enumerate(sample):
  45. if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
  46. if idx not in to_tensor_idxs:
  47. to_tensor_idxs.append(idx)
  48. data_dict[idx].append(v)
  49. for idx in to_tensor_idxs:
  50. data_dict[idx] = paddle.to_tensor(data_dict[idx])
  51. return list(data_dict.values())
  52. class SSLRotateCollate(object):
  53. """
  54. bach: [
  55. [(4*3xH*W), (4,)]
  56. [(4*3xH*W), (4,)]
  57. ...
  58. ]
  59. """
  60. def __call__(self, batch):
  61. output = [np.concatenate(d, axis=0) for d in zip(*batch)]
  62. return output
  63. class DyMaskCollator(object):
  64. """
  65. batch: [
  66. image [batch_size, channel, maxHinbatch, maxWinbatch]
  67. image_mask [batch_size, channel, maxHinbatch, maxWinbatch]
  68. label [batch_size, maxLabelLen]
  69. label_mask [batch_size, maxLabelLen]
  70. ...
  71. ]
  72. """
  73. def __call__(self, batch):
  74. max_width, max_height, max_length = 0, 0, 0
  75. bs, channel = len(batch), batch[0][0].shape[0]
  76. proper_items = []
  77. for item in batch:
  78. if item[0].shape[1] * max_width > 1600 * 320 or item[0].shape[
  79. 2] * max_height > 1600 * 320:
  80. continue
  81. max_height = item[0].shape[1] if item[0].shape[
  82. 1] > max_height else max_height
  83. max_width = item[0].shape[2] if item[0].shape[
  84. 2] > max_width else max_width
  85. max_length = len(item[1]) if len(item[
  86. 1]) > max_length else max_length
  87. proper_items.append(item)
  88. images, image_masks = np.zeros(
  89. (len(proper_items), channel, max_height, max_width),
  90. dtype='float32'), np.zeros(
  91. (len(proper_items), 1, max_height, max_width), dtype='float32')
  92. labels, label_masks = np.zeros(
  93. (len(proper_items), max_length), dtype='int64'), np.zeros(
  94. (len(proper_items), max_length), dtype='int64')
  95. for i in range(len(proper_items)):
  96. _, h, w = proper_items[i][0].shape
  97. images[i][:, :h, :w] = proper_items[i][0]
  98. image_masks[i][:, :h, :w] = 1
  99. l = len(proper_items[i][1])
  100. labels[i][:l] = proper_items[i][1]
  101. label_masks[i][:l] = 1
  102. return images, image_masks, labels, label_masks