123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
- # Copyright 2020 IBM
- # Author: peter.zhong@au1.ibm.com
- #
- # This is free software; you can redistribute it and/or modify
- # it under the terms of the Apache 2.0 License.
- #
- # This software is distributed in the hope that it will be useful,
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- # Apache 2.0 License for more details.
- from rapidfuzz.distance import Levenshtein
- from apted import APTED, Config
- from apted.helpers import Tree
- from lxml import etree, html
- from collections import deque
- from .parallel import parallel_process
- from tqdm import tqdm
class TableTree(Tree):
    """Tree node representing one element of an HTML table.

    Every node records its ``tag`` and a list of child nodes. The
    ``colspan``, ``rowspan`` and ``content`` attributes are stored as given
    and are only rendered by :meth:`bracket` for ``td`` nodes.
    """

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag, self.colspan, self.rowspan = tag, colspan, rowspan
        self.content = content
        # Keep a private, mutable list so callers can append children later.
        self.children = [*children]

    def bracket(self):
        """Return the subtree rooted at this node in bracket notation."""
        if self.tag != 'td':
            head = '"tag": %s' % self.tag
        else:
            head = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
                self.tag, self.colspan, self.rowspan, self.content)
        # Children are rendered recursively and appended in order.
        nested = ''.join(child.bracket() for child in self.children)
        return "{{{}}}".format(head + nested)
class CustomConfig(Config):
    """APTED configuration whose rename cost compares tags, spans and cell content."""

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        The cost is 1.0 when the tag, colspan or rowspan differ. For two
        matching ``td`` nodes where at least one has non-empty content, the
        cost is the normalized Levenshtein distance between their contents.
        In every other case the cost is 0.0.
        """
        attributes_match = (node1.tag == node2.tag
                            and node1.colspan == node2.colspan
                            and node1.rowspan == node2.rowspan)
        if not attributes_match:
            return 1.
        cell_has_content = node1.tag == 'td' and (node1.content or node2.content)
        if cell_has_content:
            return Levenshtein.normalized_distance(node1.content, node2.content)
        return 0.
class CustomConfig_del_short(Config):
    """APTED configuration that treats very short cell contents as a placeholder."""

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        The cost is 1.0 when the tag, colspan or rowspan differ. For matching
        ``td`` nodes with some content, any content shorter than three items
        is replaced by the single-item placeholder ``['####']`` before the
        normalized Levenshtein distance is computed. Otherwise the cost is 0.0.
        """
        attributes_match = (node1.tag == node2.tag
                            and node1.colspan == node2.colspan
                            and node1.rowspan == node2.rowspan)
        if not attributes_match:
            return 1.
        if node1.tag != 'td' or not (node1.content or node2.content):
            return 0.
        # Contents with fewer than 3 items are swapped for the same
        # placeholder, so two short cells compare as identical.
        left = node1.content if len(node1.content) >= 3 else ['####']
        right = node2.content if len(node2.content) >= 3 else ['####']
        return Levenshtein.normalized_distance(left, right)
class CustomConfig_del_block(Config):
    """APTED configuration that ignores single-space items in cell content."""

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        The cost is 1.0 when the tag, colspan or rowspan differ. For matching
        ``td`` nodes with some content, every ``' '`` item is dropped from
        both contents before the normalized Levenshtein distance is computed.
        Otherwise the cost is 0.0.

        The nodes are not modified: the comparison works on filtered copies,
        so the trees stay intact for any later use.
        """
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td' and (node1.content or node2.content):
            # Filter into new lists in one pass instead of popping from the
            # node's own list, which would mutate the tree being compared.
            node1_content = [item for item in node1.content if item != ' ']
            node2_content = [item for item in node2.content if item != ' ']
            return Levenshtein.normalized_distance(node1_content, node2_content)
        return 0.
class TEDS(object):
    """Tree Edit Distance based Similarity between two HTML tables.

    Parameters
    ----------
    structure_only : bool
        When true, ``td`` cells are compared with empty content, so only the
        table structure (tags and spans) contributes to the score.
    n_jobs : int
        Number of jobs handed to ``parallel_process`` by the batch methods;
        with ``1`` the samples are evaluated sequentially in-process.
    ignore_nodes : iterable of str or None
        Tag names stripped (via ``etree.strip_tags``) from both tables before
        they are compared.
    """

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        assert isinstance(n_jobs, int) and (
            n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        # Scratch buffer filled by tokenize(); reset for every cell.
        self.__tokens__ = []

    def tokenize(self, node):
        """Append the tokens of ``node`` and its descendants to ``self.__tokens__``.

        Tags become single ``<tag>`` / ``</tag>`` tokens, while text and tail
        strings are split into individual characters. No closing token is
        emitted for ``unk`` nodes, and the tail of a ``td`` node is skipped.
        """
        self.__tokens__.append('<%s>' % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        # Iterating an element yields its children (getchildren() is deprecated).
        for child in node:
            self.tokenize(child)
        if node.tag != 'unk':
            self.__tokens__.append('</%s>' % node.tag)
        if node.tag != 'td' and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        """Convert an HTML element tree to the ``TableTree`` format used by APTED.

        ``td`` elements become leaves carrying their colspan, rowspan
        (both defaulting to 1) and tokenized content; their children are not
        converted into nodes. All other elements become plain tag nodes and
        are descended into recursively.

        Returns the new root node when ``parent`` is ``None``; otherwise the
        new node is appended to ``parent.children`` and ``None`` is returned.
        """
        if node.tag == 'td':
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                # Drop the enclosing <td> and </td> tokens; the slice is
                # already an independent copy of the scratch buffer.
                cell = self.__tokens__[1:-1]
            new_node = TableTree(node.tag,
                                 int(node.attrib.get('colspan', '1')),
                                 int(node.attrib.get('rowspan', '1')),
                                 cell)
        else:
            new_node = TableTree(node.tag, None, None, None)
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != 'td':
            for child in node:
                self.load_html_tree(child, new_node)
        if parent is None:
            return new_node

    def evaluate(self, pred, true):
        """Compute the TEDS score between a prediction and its ground truth.

        Both arguments are HTML documents. The score is 0.0 when either input
        is empty or when either parsed document has no ``body/table``
        element. Otherwise it is ``1 - distance / n_nodes``, where
        ``distance`` is the APTED edit distance under ``CustomConfig`` and
        ``n_nodes`` is the larger descendant count of the two tables.
        """
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        pred_tables = pred.xpath('body/table')
        true_tables = true.xpath('body/table')
        if not (pred_tables and true_tables):
            return 0.0
        pred = pred_tables[0]
        true = true_tables[0]
        if self.ignore_nodes:
            etree.strip_tags(pred, *self.ignore_nodes)
            etree.strip_tags(true, *self.ignore_nodes)
        n_nodes_pred = len(pred.xpath(".//*"))
        n_nodes_true = len(true.xpath(".//*"))
        n_nodes = max(n_nodes_pred, n_nodes_true)
        if n_nodes == 0:
            # Both tables are bare <table> elements without descendants, so
            # the trees are identical; avoid dividing by zero below.
            return 1.0
        tree_pred = self.load_html_tree(pred)
        tree_true = self.load_html_tree(true)
        distance = APTED(tree_pred, tree_true,
                         CustomConfig()).compute_edit_distance()
        return 1.0 - (float(distance) / n_nodes)

    def batch_evaluate(self, pred_json, true_json):
        """Compute TEDS scores for a batch of samples keyed by file name.

        @params pred_json: {'FILENAME': 'HTML CODE', ...}
        @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
        @output: {'FILENAME': 'TEDS SCORE', ...}

        The ground-truth keys define the samples; a file name missing from
        ``pred_json`` is evaluated against an empty prediction (score 0.0).
        """
        samples = true_json.keys()
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_json.get(
                filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
        else:
            inputs = [{'pred': pred_json.get(
                filename, ''), 'true': true_json[filename]['html']} for filename in samples]
            # NOTE(review): assumes parallel_process returns results in input
            # order, since they are zipped with the sample names below.
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        scores = dict(zip(samples, scores))
        return scores

    def batch_evaluate_html(self, pred_htmls, true_htmls):
        """Compute TEDS scores for paired sequences of HTML strings.

        ``pred_htmls`` and ``true_htmls`` are zipped together, so any extra
        items in the longer sequence are ignored. Returns the scores produced
        for each pair (a list when ``n_jobs == 1``).
        """
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_html, true_html) for (
                pred_html, true_html) in zip(pred_htmls, true_htmls)]
        else:
            inputs = [{"pred": pred_html, "true": true_html} for (
                pred_html, true_html) in zip(pred_htmls, true_htmls)]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        return scores
if __name__ == '__main__':
    # Demo entry point: score the bundled sample predictions against the
    # sample ground truth and pretty-print the per-file TEDS scores.
    import json
    import pprint

    with open('sample_pred.json') as pred_file:
        pred_json = json.load(pred_file)
    with open('sample_gt.json') as gt_file:
        true_json = json.load(gt_file)

    teds = TEDS(n_jobs=4)
    scores = teds.batch_evaluate(pred_json, true_json)
    pprint.PrettyPrinter().pprint(scores)
|