loss.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import paddle
  2. import paddle.nn.functional as F
  3. class Loss(object):
  4. """
  5. Loss
  6. """
  7. def __init__(self, class_dim=1000, epsilon=None):
  8. assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
  9. self._class_dim = class_dim
  10. if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
  11. self._epsilon = epsilon
  12. self._label_smoothing = True
  13. else:
  14. self._epsilon = None
  15. self._label_smoothing = False
  16. def _labelsmoothing(self, target):
  17. if target.shape[-1] != self._class_dim:
  18. one_hot_target = F.one_hot(target, self._class_dim)
  19. else:
  20. one_hot_target = target
  21. soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)
  22. soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
  23. return soft_target
  24. def _crossentropy(self, input, target, use_pure_fp16=False):
  25. if self._label_smoothing:
  26. target = self._labelsmoothing(target)
  27. input = -F.log_softmax(input, axis=-1)
  28. cost = paddle.sum(target * input, axis=-1)
  29. else:
  30. cost = F.cross_entropy(input=input, label=target)
  31. if use_pure_fp16:
  32. avg_cost = paddle.sum(cost)
  33. else:
  34. avg_cost = paddle.mean(cost)
  35. return avg_cost
  36. def __call__(self, input, target):
  37. return self._crossentropy(input, target)
  38. def build_loss(config, epsilon=None):
  39. class_dim = config['class_dim']
  40. loss_func = Loss(class_dim=class_dim, epsilon=epsilon)
  41. return loss_func
  42. class LossDistill(Loss):
  43. def __init__(self, model_name_list, class_dim=1000, epsilon=None):
  44. assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
  45. self._class_dim = class_dim
  46. if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
  47. self._epsilon = epsilon
  48. self._label_smoothing = True
  49. else:
  50. self._epsilon = None
  51. self._label_smoothing = False
  52. self.model_name_list = model_name_list
  53. assert len(self.model_name_list) > 1, "error"
  54. def __call__(self, input, target):
  55. losses = {}
  56. for k in self.model_name_list:
  57. inp = input[k]
  58. losses[k] = self._crossentropy(inp, target)
  59. return losses
  60. class KLJSLoss(object):
  61. def __init__(self, mode='kl'):
  62. assert mode in ['kl', 'js', 'KL', 'JS'
  63. ], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
  64. self.mode = mode
  65. def __call__(self, p1, p2, reduction="mean"):
  66. p1 = F.softmax(p1, axis=-1)
  67. p2 = F.softmax(p2, axis=-1)
  68. loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
  69. if self.mode.lower() == "js":
  70. loss += paddle.multiply(
  71. p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
  72. loss *= 0.5
  73. if reduction == "mean":
  74. loss = paddle.mean(loss)
  75. elif reduction == "none" or reduction is None:
  76. return loss
  77. else:
  78. loss = paddle.sum(loss)
  79. return loss
  80. class DMLLoss(object):
  81. def __init__(self, model_name_pairs, mode='js'):
  82. self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
  83. self.kljs_loss = KLJSLoss(mode=mode)
  84. def _check_model_name_pairs(self, model_name_pairs):
  85. if not isinstance(model_name_pairs, list):
  86. return []
  87. elif isinstance(model_name_pairs[0], list) and isinstance(
  88. model_name_pairs[0][0], str):
  89. return model_name_pairs
  90. else:
  91. return [model_name_pairs]
  92. def __call__(self, predicts, target=None):
  93. loss_dict = dict()
  94. for pairs in self.model_name_pairs:
  95. p1 = predicts[pairs[0]]
  96. p2 = predicts[pairs[1]]
  97. loss_dict[pairs[0] + "_" + pairs[1]] = self.kljs_loss(p1, p2)
  98. return loss_dict
  99. # def build_distill_loss(config, epsilon=None):
  100. # class_dim = config['class_dim']
  101. # loss = LossDistill(model_name_list=['student', 'student1'], )
  101. # return loss