vqa_layoutlm.py

# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from paddle import nn
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction
from paddlenlp.transformers import AutoModel

__all__ = ["LayoutXLMForSer", "LayoutLMForSer"]

# Mapping from backbone class to the pretrained weight name used for each `mode`.
pretrained_model_dict = {
    LayoutXLMModel: {
        "base": "layoutxlm-base-uncased",
        "vi": "vi-layoutxlm-base-uncased",
    },
    LayoutLMModel: {
        "base": "layoutlm-base-uncased",
    },
    LayoutLMv2Model: {
        "base": "layoutlmv2-base-uncased",
        "vi": "vi-layoutlmv2-base-uncased",
    },
}


class NLPBaseModel(nn.Layer):
    """Shared wrapper for the paddlenlp layout-series models.

    Loads a fine-tuned checkpoint when `checkpoints` is given, otherwise builds
    the backbone from the pretrained weights selected by `mode` and attaches a
    SER (token classification) or RE (relation extraction) head.
    """

    def __init__(self,
                 base_model_class,
                 model_class,
                 mode="base",
                 type="ser",
                 pretrained=True,
                 checkpoints=None,
                 **kwargs):
        super(NLPBaseModel, self).__init__()
        if checkpoints is not None:  # load the trained model
            self.model = model_class.from_pretrained(checkpoints)
        else:  # load the pretrained model
            pretrained_model_name = pretrained_model_dict[base_model_class][
                mode]
            if pretrained is True:
                base_model = base_model_class.from_pretrained(
                    pretrained_model_name)
            else:
                # `pretrained` may also be a path to local model weights
                base_model = base_model_class.from_pretrained(pretrained)
            if type == "ser":
                self.model = model_class(
                    base_model, num_classes=kwargs["num_classes"], dropout=None)
            else:
                self.model = model_class(base_model, dropout=None)
        self.out_channels = 1
        self.use_visual_backbone = True


class LayoutLMForSer(NLPBaseModel):
    """LayoutLM backbone for Semantic Entity Recognition (SER).

    LayoutLM has no visual backbone, so the image tensor in the batch is never
    consumed.
    """

    def __init__(self,
                 num_classes,
                 pretrained=True,
                 checkpoints=None,
                 mode="base",
                 **kwargs):
        super(LayoutLMForSer, self).__init__(
            LayoutLMModel,
            LayoutLMForTokenClassification,
            mode,
            "ser",
            pretrained,
            checkpoints,
            num_classes=num_classes)
        self.use_visual_backbone = False

    def forward(self, x):
        x = self.model(
            input_ids=x[0],
            bbox=x[1],
            attention_mask=x[2],
            token_type_ids=x[3],
            position_ids=None,
            output_hidden_states=False)
        return x


class LayoutLMv2ForSer(NLPBaseModel):
    """LayoutLMv2 backbone for SER."""

    def __init__(self,
                 num_classes,
                 pretrained=True,
                 checkpoints=None,
                 mode="base",
                 **kwargs):
        super(LayoutLMv2ForSer, self).__init__(
            LayoutLMv2Model,
            LayoutLMv2ForTokenClassification,
            mode,
            "ser",
            pretrained,
            checkpoints,
            num_classes=num_classes)
        if hasattr(self.model.layoutlmv2, "use_visual_backbone"
                   ) and self.model.layoutlmv2.use_visual_backbone is False:
            self.use_visual_backbone = False

    def forward(self, x):
        # The image tensor (x[4]) is only used when the visual backbone is enabled.
        if self.use_visual_backbone is True:
            image = x[4]
        else:
            image = None
        x = self.model(
            input_ids=x[0],
            bbox=x[1],
            attention_mask=x[2],
            token_type_ids=x[3],
            image=image,
            position_ids=None,
            head_mask=None,
            labels=None)
        if self.training:
            # During training, return the sequence output together with the
            # extra outputs expected by the downstream SER head.
            res = {"backbone_out": x[0]}
            res.update(x[1])
            return res
        else:
            return x


class LayoutXLMForSer(NLPBaseModel):
    """LayoutXLM backbone for SER."""

    def __init__(self,
                 num_classes,
                 pretrained=True,
                 checkpoints=None,
                 mode="base",
                 **kwargs):
        super(LayoutXLMForSer, self).__init__(
            LayoutXLMModel,
            LayoutXLMForTokenClassification,
            mode,
            "ser",
            pretrained,
            checkpoints,
            num_classes=num_classes)
        if hasattr(self.model.layoutxlm, "use_visual_backbone"
                   ) and self.model.layoutxlm.use_visual_backbone is False:
            self.use_visual_backbone = False

    def forward(self, x):
        # The image tensor (x[4]) is only used when the visual backbone is enabled.
        if self.use_visual_backbone is True:
            image = x[4]
        else:
            image = None
        x = self.model(
            input_ids=x[0],
            bbox=x[1],
            attention_mask=x[2],
            token_type_ids=x[3],
            image=image,
            position_ids=None,
            head_mask=None,
            labels=None)
        if self.training:
            # During training, return the sequence output together with the
            # extra outputs expected by the downstream SER head.
            res = {"backbone_out": x[0]}
            res.update(x[1])
            return res
        else:
            return x


class LayoutLMv2ForRe(NLPBaseModel):
    """LayoutLMv2 backbone for Relation Extraction (RE)."""

    def __init__(self, pretrained=True, checkpoints=None, mode="base",
                 **kwargs):
        super(LayoutLMv2ForRe, self).__init__(
            LayoutLMv2Model, LayoutLMv2ForRelationExtraction, mode, "re",
            pretrained, checkpoints)
        if hasattr(self.model.layoutlmv2, "use_visual_backbone"
                   ) and self.model.layoutlmv2.use_visual_backbone is False:
            self.use_visual_backbone = False

    def forward(self, x):
        # Expects the image at x[4] and entities/relations at x[5]/x[6].
        x = self.model(
            input_ids=x[0],
            bbox=x[1],
            attention_mask=x[2],
            token_type_ids=x[3],
            image=x[4],
            position_ids=None,
            head_mask=None,
            labels=None,
            entities=x[5],
            relations=x[6])
        return x


class LayoutXLMForRe(NLPBaseModel):
    """LayoutXLM backbone for Relation Extraction (RE)."""

    def __init__(self, pretrained=True, checkpoints=None, mode="base",
                 **kwargs):
        super(LayoutXLMForRe, self).__init__(
            LayoutXLMModel, LayoutXLMForRelationExtraction, mode, "re",
            pretrained, checkpoints)
        if hasattr(self.model.layoutxlm, "use_visual_backbone"
                   ) and self.model.layoutxlm.use_visual_backbone is False:
            self.use_visual_backbone = False

    def forward(self, x):
        # When the visual backbone is disabled the batch carries no image, so
        # entities/relations move forward to indices 4 and 5.
        if self.use_visual_backbone is True:
            image = x[4]
            entities = x[5]
            relations = x[6]
        else:
            image = None
            entities = x[4]
            relations = x[5]
        x = self.model(
            input_ids=x[0],
            bbox=x[1],
            attention_mask=x[2],
            token_type_ids=x[3],
            image=image,
            position_ids=None,
            head_mask=None,
            labels=None,
            entities=entities,
            relations=relations)
        return x
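

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# It shows how one of the SER wrappers above can be instantiated and fed a
# dummy batch. The sequence length, label count, and 224x224 image size are
# assumptions made for this example, not values defined in this file; running
# it downloads the "layoutxlm-base-uncased" weights via paddlenlp.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import paddle

    batch_size, seq_len = 1, 512  # assumed shapes for the dummy batch
    model = LayoutXLMForSer(num_classes=7, pretrained=True, mode="base")
    model.eval()  # use the inference branch of forward()

    input_ids = paddle.zeros([batch_size, seq_len], dtype="int64")
    bbox = paddle.zeros([batch_size, seq_len, 4], dtype="int64")
    attention_mask = paddle.ones([batch_size, seq_len], dtype="int64")
    token_type_ids = paddle.zeros([batch_size, seq_len], dtype="int64")
    image = paddle.zeros([batch_size, 3, 224, 224], dtype="float32")

    # forward() takes a list; the image at index 4 is consumed only when the
    # visual backbone is enabled for the chosen weights.
    out = model([input_ids, bbox, attention_mask, token_type_ids, image])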