# distillation_loss.py

  1. #copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. #Licensed under the Apache License, Version 2.0 (the "License");
  4. #you may not use this file except in compliance with the License.
  5. #You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. #Unless required by applicable law or agreed to in writing, software
  10. #distributed under the License is distributed on an "AS IS" BASIS,
  11. #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. #See the License for the specific language governing permissions and
  13. #limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import numpy as np
  17. import cv2
  18. from .rec_ctc_loss import CTCLoss
  19. from .rec_sar_loss import SARLoss
  20. from .basic_loss import DMLLoss
  21. from .basic_loss import DistanceLoss
  22. from .basic_loss import LossFromOutput
  23. from .det_db_loss import DBLoss
  24. from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
  25. from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
  26. def _sum_loss(loss_dict):
  27. if "loss" in loss_dict.keys():
  28. return loss_dict
  29. else:
  30. loss_dict["loss"] = 0.
  31. for k, value in loss_dict.items():
  32. if k == "loss":
  33. continue
  34. else:
  35. loss_dict["loss"] += value
  36. return loss_dict
  37. class DistillationDMLLoss(DMLLoss):
  38. """
  39. """
  40. def __init__(self,
  41. model_name_pairs=[],
  42. act=None,
  43. use_log=False,
  44. key=None,
  45. multi_head=False,
  46. dis_head='ctc',
  47. maps_name=None,
  48. name="dml"):
  49. super().__init__(act=act, use_log=use_log)
  50. assert isinstance(model_name_pairs, list)
  51. self.key = key
  52. self.multi_head = multi_head
  53. self.dis_head = dis_head
  54. self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
  55. self.name = name
  56. self.maps_name = self._check_maps_name(maps_name)
  57. def _check_model_name_pairs(self, model_name_pairs):
  58. if not isinstance(model_name_pairs, list):
  59. return []
  60. elif isinstance(model_name_pairs[0], list) and isinstance(
  61. model_name_pairs[0][0], str):
  62. return model_name_pairs
  63. else:
  64. return [model_name_pairs]
  65. def _check_maps_name(self, maps_name):
  66. if maps_name is None:
  67. return None
  68. elif type(maps_name) == str:
  69. return [maps_name]
  70. elif type(maps_name) == list:
  71. return [maps_name]
  72. else:
  73. return None
  74. def _slice_out(self, outs):
  75. new_outs = {}
  76. for k in self.maps_name:
  77. if k == "thrink_maps":
  78. new_outs[k] = outs[:, 0, :, :]
  79. elif k == "threshold_maps":
  80. new_outs[k] = outs[:, 1, :, :]
  81. elif k == "binary_maps":
  82. new_outs[k] = outs[:, 2, :, :]
  83. else:
  84. continue
  85. return new_outs
  86. def forward(self, predicts, batch):
  87. loss_dict = dict()
  88. for idx, pair in enumerate(self.model_name_pairs):
  89. out1 = predicts[pair[0]]
  90. out2 = predicts[pair[1]]
  91. if self.key is not None:
  92. out1 = out1[self.key]
  93. out2 = out2[self.key]
  94. if self.maps_name is None:
  95. if self.multi_head:
  96. loss = super().forward(out1[self.dis_head],
  97. out2[self.dis_head])
  98. else:
  99. loss = super().forward(out1, out2)
  100. if isinstance(loss, dict):
  101. for key in loss:
  102. loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
  103. idx)] = loss[key]
  104. else:
  105. loss_dict["{}_{}".format(self.name, idx)] = loss
  106. else:
  107. outs1 = self._slice_out(out1)
  108. outs2 = self._slice_out(out2)
  109. for _c, k in enumerate(outs1.keys()):
  110. loss = super().forward(outs1[k], outs2[k])
  111. if isinstance(loss, dict):
  112. for key in loss:
  113. loss_dict["{}_{}_{}_{}_{}".format(key, pair[
  114. 0], pair[1], self.maps_name, idx)] = loss[key]
  115. else:
  116. loss_dict["{}_{}_{}".format(self.name, self.maps_name[
  117. _c], idx)] = loss
  118. loss_dict = _sum_loss(loss_dict)
  119. return loss_dict
  120. class DistillationCTCLoss(CTCLoss):
  121. def __init__(self,
  122. model_name_list=[],
  123. key=None,
  124. multi_head=False,
  125. name="loss_ctc"):
  126. super().__init__()
  127. self.model_name_list = model_name_list
  128. self.key = key
  129. self.name = name
  130. self.multi_head = multi_head
  131. def forward(self, predicts, batch):
  132. loss_dict = dict()
  133. for idx, model_name in enumerate(self.model_name_list):
  134. out = predicts[model_name]
  135. if self.key is not None:
  136. out = out[self.key]
  137. if self.multi_head:
  138. assert 'ctc' in out, 'multi head has multi out'
  139. loss = super().forward(out['ctc'], batch[:2] + batch[3:])
  140. else:
  141. loss = super().forward(out, batch)
  142. if isinstance(loss, dict):
  143. for key in loss:
  144. loss_dict["{}_{}_{}".format(self.name, model_name,
  145. idx)] = loss[key]
  146. else:
  147. loss_dict["{}_{}".format(self.name, model_name)] = loss
  148. return loss_dict
  149. class DistillationSARLoss(SARLoss):
  150. def __init__(self,
  151. model_name_list=[],
  152. key=None,
  153. multi_head=False,
  154. name="loss_sar",
  155. **kwargs):
  156. ignore_index = kwargs.get('ignore_index', 92)
  157. super().__init__(ignore_index=ignore_index)
  158. self.model_name_list = model_name_list
  159. self.key = key
  160. self.name = name
  161. self.multi_head = multi_head
  162. def forward(self, predicts, batch):
  163. loss_dict = dict()
  164. for idx, model_name in enumerate(self.model_name_list):
  165. out = predicts[model_name]
  166. if self.key is not None:
  167. out = out[self.key]
  168. if self.multi_head:
  169. assert 'sar' in out, 'multi head has multi out'
  170. loss = super().forward(out['sar'], batch[:1] + batch[2:])
  171. else:
  172. loss = super().forward(out, batch)
  173. if isinstance(loss, dict):
  174. for key in loss:
  175. loss_dict["{}_{}_{}".format(self.name, model_name,
  176. idx)] = loss[key]
  177. else:
  178. loss_dict["{}_{}".format(self.name, model_name)] = loss
  179. return loss_dict
  180. class DistillationDBLoss(DBLoss):
  181. def __init__(self,
  182. model_name_list=[],
  183. balance_loss=True,
  184. main_loss_type='DiceLoss',
  185. alpha=5,
  186. beta=10,
  187. ohem_ratio=3,
  188. eps=1e-6,
  189. name="db",
  190. **kwargs):
  191. super().__init__()
  192. self.model_name_list = model_name_list
  193. self.name = name
  194. self.key = None
  195. def forward(self, predicts, batch):
  196. loss_dict = {}
  197. for idx, model_name in enumerate(self.model_name_list):
  198. out = predicts[model_name]
  199. if self.key is not None:
  200. out = out[self.key]
  201. loss = super().forward(out, batch)
  202. if isinstance(loss, dict):
  203. for key in loss.keys():
  204. if key == "loss":
  205. continue
  206. name = "{}_{}_{}".format(self.name, model_name, key)
  207. loss_dict[name] = loss[key]
  208. else:
  209. loss_dict["{}_{}".format(self.name, model_name)] = loss
  210. loss_dict = _sum_loss(loss_dict)
  211. return loss_dict
  212. class DistillationDilaDBLoss(DBLoss):
  213. def __init__(self,
  214. model_name_pairs=[],
  215. key=None,
  216. balance_loss=True,
  217. main_loss_type='DiceLoss',
  218. alpha=5,
  219. beta=10,
  220. ohem_ratio=3,
  221. eps=1e-6,
  222. name="dila_dbloss"):
  223. super().__init__()
  224. self.model_name_pairs = model_name_pairs
  225. self.name = name
  226. self.key = key
  227. def forward(self, predicts, batch):
  228. loss_dict = dict()
  229. for idx, pair in enumerate(self.model_name_pairs):
  230. stu_outs = predicts[pair[0]]
  231. tch_outs = predicts[pair[1]]
  232. if self.key is not None:
  233. stu_preds = stu_outs[self.key]
  234. tch_preds = tch_outs[self.key]
  235. stu_shrink_maps = stu_preds[:, 0, :, :]
  236. stu_binary_maps = stu_preds[:, 2, :, :]
  237. # dilation to teacher prediction
  238. dilation_w = np.array([[1, 1], [1, 1]])
  239. th_shrink_maps = tch_preds[:, 0, :, :]
  240. th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
  241. dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
  242. for i in range(th_shrink_maps.shape[0]):
  243. dilate_maps[i] = cv2.dilate(
  244. th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
  245. th_shrink_maps = paddle.to_tensor(dilate_maps)
  246. label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
  247. 1:]
  248. # calculate the shrink map loss
  249. bce_loss = self.alpha * self.bce_loss(
  250. stu_shrink_maps, th_shrink_maps, label_shrink_mask)
  251. loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
  252. label_shrink_mask)
  253. # k = f"{self.name}_{pair[0]}_{pair[1]}"
  254. k = "{}_{}_{}".format(self.name, pair[0], pair[1])
  255. loss_dict[k] = bce_loss + loss_binary_maps
  256. loss_dict = _sum_loss(loss_dict)
  257. return loss_dict
  258. class DistillationDistanceLoss(DistanceLoss):
  259. """
  260. """
  261. def __init__(self,
  262. mode="l2",
  263. model_name_pairs=[],
  264. key=None,
  265. name="loss_distance",
  266. **kargs):
  267. super().__init__(mode=mode, **kargs)
  268. assert isinstance(model_name_pairs, list)
  269. self.key = key
  270. self.model_name_pairs = model_name_pairs
  271. self.name = name + "_l2"
  272. def forward(self, predicts, batch):
  273. loss_dict = dict()
  274. for idx, pair in enumerate(self.model_name_pairs):
  275. out1 = predicts[pair[0]]
  276. out2 = predicts[pair[1]]
  277. if self.key is not None:
  278. out1 = out1[self.key]
  279. out2 = out2[self.key]
  280. loss = super().forward(out1, out2)
  281. if isinstance(loss, dict):
  282. for key in loss:
  283. loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
  284. key]
  285. else:
  286. loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
  287. idx)] = loss
  288. return loss_dict
  289. class DistillationVQASerTokenLayoutLMLoss(VQASerTokenLayoutLMLoss):
  290. def __init__(self,
  291. num_classes,
  292. model_name_list=[],
  293. key=None,
  294. name="loss_ser"):
  295. super().__init__(num_classes=num_classes)
  296. self.model_name_list = model_name_list
  297. self.key = key
  298. self.name = name
  299. def forward(self, predicts, batch):
  300. loss_dict = dict()
  301. for idx, model_name in enumerate(self.model_name_list):
  302. out = predicts[model_name]
  303. if self.key is not None:
  304. out = out[self.key]
  305. loss = super().forward(out, batch)
  306. loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
  307. return loss_dict
  308. class DistillationLossFromOutput(LossFromOutput):
  309. def __init__(self,
  310. reduction="none",
  311. model_name_list=[],
  312. dist_key=None,
  313. key="loss",
  314. name="loss_re"):
  315. super().__init__(key=key, reduction=reduction)
  316. self.model_name_list = model_name_list
  317. self.name = name
  318. self.dist_key = dist_key
  319. def forward(self, predicts, batch):
  320. loss_dict = dict()
  321. for idx, model_name in enumerate(self.model_name_list):
  322. out = predicts[model_name]
  323. if self.dist_key is not None:
  324. out = out[self.dist_key]
  325. loss = super().forward(out, batch)
  326. loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
  327. return loss_dict
  328. class DistillationSERDMLLoss(DMLLoss):
  329. """
  330. """
  331. def __init__(self,
  332. act="softmax",
  333. use_log=True,
  334. num_classes=7,
  335. model_name_pairs=[],
  336. key=None,
  337. name="loss_dml_ser"):
  338. super().__init__(act=act, use_log=use_log)
  339. assert isinstance(model_name_pairs, list)
  340. self.key = key
  341. self.name = name
  342. self.num_classes = num_classes
  343. self.model_name_pairs = model_name_pairs
  344. def forward(self, predicts, batch):
  345. loss_dict = dict()
  346. for idx, pair in enumerate(self.model_name_pairs):
  347. out1 = predicts[pair[0]]
  348. out2 = predicts[pair[1]]
  349. if self.key is not None:
  350. out1 = out1[self.key]
  351. out2 = out2[self.key]
  352. out1 = out1.reshape([-1, out1.shape[-1]])
  353. out2 = out2.reshape([-1, out2.shape[-1]])
  354. attention_mask = batch[2]
  355. if attention_mask is not None:
  356. active_output = attention_mask.reshape([-1, ]) == 1
  357. out1 = out1[active_output]
  358. out2 = out2[active_output]
  359. loss_dict["{}_{}".format(self.name, idx)] = super().forward(out1,
  360. out2)
  361. return loss_dict
  362. class DistillationVQADistanceLoss(DistanceLoss):
  363. def __init__(self,
  364. mode="l2",
  365. model_name_pairs=[],
  366. key=None,
  367. index=None,
  368. name="loss_distance",
  369. **kargs):
  370. super().__init__(mode=mode, **kargs)
  371. assert isinstance(model_name_pairs, list)
  372. self.key = key
  373. self.index = index
  374. self.model_name_pairs = model_name_pairs
  375. self.name = name + "_l2"
  376. def forward(self, predicts, batch):
  377. loss_dict = dict()
  378. for idx, pair in enumerate(self.model_name_pairs):
  379. out1 = predicts[pair[0]]
  380. out2 = predicts[pair[1]]
  381. attention_mask = batch[2]
  382. if self.key is not None:
  383. out1 = out1[self.key]
  384. out2 = out2[self.key]
  385. if self.index is not None:
  386. out1 = out1[:, self.index, :, :]
  387. out2 = out2[:, self.index, :, :]
  388. if attention_mask is not None:
  389. max_len = attention_mask.shape[-1]
  390. out1 = out1[:, :max_len]
  391. out2 = out2[:, :max_len]
  392. out1 = out1.reshape([-1, out1.shape[-1]])
  393. out2 = out2.reshape([-1, out2.shape[-1]])
  394. if attention_mask is not None:
  395. active_output = attention_mask.reshape([-1, ]) == 1
  396. out1 = out1[active_output]
  397. out2 = out2[active_output]
  398. loss = super().forward(out1, out2)
  399. if isinstance(loss, dict):
  400. for key in loss:
  401. loss_dict["{}_{}nohu_{}".format(self.name, key,
  402. idx)] = loss[key]
  403. else:
  404. loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
  405. idx)] = loss
  406. return loss_dict