table_metric.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. # Copyright 2020 IBM
  2. # Author: peter.zhong@au1.ibm.com
  3. #
  4. # This is free software; you can redistribute it and/or modify
  5. # it under the terms of the Apache 2.0 License.
  6. #
  7. # This software is distributed in the hope that it will be useful,
  8. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. # Apache 2.0 License for more details.
  11. from rapidfuzz.distance import Levenshtein
  12. from apted import APTED, Config
  13. from apted.helpers import Tree
  14. from lxml import etree, html
  15. from collections import deque
  16. from .parallel import parallel_process
  17. from tqdm import tqdm
  18. class TableTree(Tree):
  19. def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
  20. self.tag = tag
  21. self.colspan = colspan
  22. self.rowspan = rowspan
  23. self.content = content
  24. self.children = list(children)
  25. def bracket(self):
  26. """Show tree using brackets notation"""
  27. if self.tag == 'td':
  28. result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
  29. (self.tag, self.colspan, self.rowspan, self.content)
  30. else:
  31. result = '"tag": %s' % self.tag
  32. for child in self.children:
  33. result += child.bracket()
  34. return "{{{}}}".format(result)
  35. class CustomConfig(Config):
  36. def rename(self, node1, node2):
  37. """Compares attributes of trees"""
  38. #print(node1.tag)
  39. if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
  40. return 1.
  41. if node1.tag == 'td':
  42. if node1.content or node2.content:
  43. #print(node1.content, )
  44. return Levenshtein.normalized_distance(node1.content, node2.content)
  45. return 0.
  46. class CustomConfig_del_short(Config):
  47. def rename(self, node1, node2):
  48. """Compares attributes of trees"""
  49. if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
  50. return 1.
  51. if node1.tag == 'td':
  52. if node1.content or node2.content:
  53. #print('before')
  54. #print(node1.content, node2.content)
  55. #print('after')
  56. node1_content = node1.content
  57. node2_content = node2.content
  58. if len(node1_content) < 3:
  59. node1_content = ['####']
  60. if len(node2_content) < 3:
  61. node2_content = ['####']
  62. return Levenshtein.normalized_distance(node1_content, node2_content)
  63. return 0.
  64. class CustomConfig_del_block(Config):
  65. def rename(self, node1, node2):
  66. """Compares attributes of trees"""
  67. if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
  68. return 1.
  69. if node1.tag == 'td':
  70. if node1.content or node2.content:
  71. node1_content = node1.content
  72. node2_content = node2.content
  73. while ' ' in node1_content:
  74. print(node1_content.index(' '))
  75. node1_content.pop(node1_content.index(' '))
  76. while ' ' in node2_content:
  77. print(node2_content.index(' '))
  78. node2_content.pop(node2_content.index(' '))
  79. return Levenshtein.normalized_distance(node1_content, node2_content)
  80. return 0.
  81. class TEDS(object):
  82. ''' Tree Edit Distance basead Similarity
  83. '''
  84. def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
  85. assert isinstance(n_jobs, int) and (
  86. n_jobs >= 1), 'n_jobs must be an integer greather than 1'
  87. self.structure_only = structure_only
  88. self.n_jobs = n_jobs
  89. self.ignore_nodes = ignore_nodes
  90. self.__tokens__ = []
  91. def tokenize(self, node):
  92. ''' Tokenizes table cells
  93. '''
  94. self.__tokens__.append('<%s>' % node.tag)
  95. if node.text is not None:
  96. self.__tokens__ += list(node.text)
  97. for n in node.getchildren():
  98. self.tokenize(n)
  99. if node.tag != 'unk':
  100. self.__tokens__.append('</%s>' % node.tag)
  101. if node.tag != 'td' and node.tail is not None:
  102. self.__tokens__ += list(node.tail)
  103. def load_html_tree(self, node, parent=None):
  104. ''' Converts HTML tree to the format required by apted
  105. '''
  106. global __tokens__
  107. if node.tag == 'td':
  108. if self.structure_only:
  109. cell = []
  110. else:
  111. self.__tokens__ = []
  112. self.tokenize(node)
  113. cell = self.__tokens__[1:-1].copy()
  114. new_node = TableTree(node.tag,
  115. int(node.attrib.get('colspan', '1')),
  116. int(node.attrib.get('rowspan', '1')),
  117. cell, *deque())
  118. else:
  119. new_node = TableTree(node.tag, None, None, None, *deque())
  120. if parent is not None:
  121. parent.children.append(new_node)
  122. if node.tag != 'td':
  123. for n in node.getchildren():
  124. self.load_html_tree(n, new_node)
  125. if parent is None:
  126. return new_node
  127. def evaluate(self, pred, true):
  128. ''' Computes TEDS score between the prediction and the ground truth of a
  129. given sample
  130. '''
  131. if (not pred) or (not true):
  132. return 0.0
  133. parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
  134. pred = html.fromstring(pred, parser=parser)
  135. true = html.fromstring(true, parser=parser)
  136. if pred.xpath('body/table') and true.xpath('body/table'):
  137. pred = pred.xpath('body/table')[0]
  138. true = true.xpath('body/table')[0]
  139. if self.ignore_nodes:
  140. etree.strip_tags(pred, *self.ignore_nodes)
  141. etree.strip_tags(true, *self.ignore_nodes)
  142. n_nodes_pred = len(pred.xpath(".//*"))
  143. n_nodes_true = len(true.xpath(".//*"))
  144. n_nodes = max(n_nodes_pred, n_nodes_true)
  145. tree_pred = self.load_html_tree(pred)
  146. tree_true = self.load_html_tree(true)
  147. distance = APTED(tree_pred, tree_true,
  148. CustomConfig()).compute_edit_distance()
  149. return 1.0 - (float(distance) / n_nodes)
  150. else:
  151. return 0.0
  152. def batch_evaluate(self, pred_json, true_json):
  153. ''' Computes TEDS score between the prediction and the ground truth of
  154. a batch of samples
  155. @params pred_json: {'FILENAME': 'HTML CODE', ...}
  156. @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
  157. @output: {'FILENAME': 'TEDS SCORE', ...}
  158. '''
  159. samples = true_json.keys()
  160. if self.n_jobs == 1:
  161. scores = [self.evaluate(pred_json.get(
  162. filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
  163. else:
  164. inputs = [{'pred': pred_json.get(
  165. filename, ''), 'true': true_json[filename]['html']} for filename in samples]
  166. scores = parallel_process(
  167. inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
  168. scores = dict(zip(samples, scores))
  169. return scores
  170. def batch_evaluate_html(self, pred_htmls, true_htmls):
  171. ''' Computes TEDS score between the prediction and the ground truth of
  172. a batch of samples
  173. '''
  174. if self.n_jobs == 1:
  175. scores = [self.evaluate(pred_html, true_html) for (
  176. pred_html, true_html) in zip(pred_htmls, true_htmls)]
  177. else:
  178. inputs = [{"pred": pred_html, "true": true_html} for(
  179. pred_html, true_html) in zip(pred_htmls, true_htmls)]
  180. scores = parallel_process(
  181. inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
  182. return scores
  183. if __name__ == '__main__':
  184. import json
  185. import pprint
  186. with open('sample_pred.json') as fp:
  187. pred_json = json.load(fp)
  188. with open('sample_gt.json') as fp:
  189. true_json = json.load(fp)
  190. teds = TEDS(n_jobs=4)
  191. scores = teds.batch_evaluate(pred_json, true_json)
  192. pp = pprint.PrettyPrinter()
  193. pp.pprint(scores)