Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alternatives #72

Open
wants to merge 4 commits into
base: 66-fix-confidence-computation-and-filtering
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 62 additions & 16 deletions pero_ocr/core/layout.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_full_logprobs(self, zero_logit_value: int = -80):
return log_softmax(dense_logits)

def calculate_confidences(self, default_transcription_confidence=None):
if not self.logits:
if self.logits is None:
logger.warning(f'Error: Unable to calculate confidences for line {self.id} due to missing logits.')
self.character_confidences = None
self.transcription_confidence = None
Expand Down Expand Up @@ -203,7 +203,8 @@ def from_pagexml_parse_custom(self, custom_str):
heights = heights_array
self.heights = heights.tolist()

def to_altoxml(self, text_block, arabic_helper, min_line_confidence, version: ALTOVersion):
def to_altoxml(self, text_block, arabic_helper, min_line_confidence, version: ALTOVersion, next_line=None,
previous_line=None, word_splitters=None):
if self.character_confidences is None or self.transcription_confidence is None:
self.calculate_confidences()

Expand All @@ -222,8 +223,9 @@ def to_altoxml(self, text_block, arabic_helper, min_line_confidence, version: AL
text_line.set("WIDTH", str(int(text_line_width)))

if self.category == 'text':
self.to_altoxml_text(text_line, arabic_helper,
text_line_height, text_line_width, text_line_vpos, text_line_hpos)
self.to_altoxml_text(text_line, arabic_helper, text_line_height, text_line_width, text_line_vpos,
text_line_hpos, next_line=next_line, previous_line=previous_line,
word_splitters=word_splitters)
else:
string = ET.SubElement(text_line, "String")
string.set("CONTENT", self.transcription)
Expand Down Expand Up @@ -253,8 +255,8 @@ def get_labels(self):
labels.append(0)
return np.array(labels)

def to_altoxml_text(self, text_line, arabic_helper,
text_line_height, text_line_width, text_line_vpos, text_line_hpos):
def to_altoxml_text(self, text_line, arabic_helper, text_line_height, text_line_width, text_line_vpos,
text_line_hpos, next_line=None, previous_line=None, word_splitters=None):
arabic_line = False
if arabic_helper.is_arabic_line(self.transcription):
arabic_line = True
Expand All @@ -269,15 +271,34 @@ def to_altoxml_text(self, text_line, arabic_helper,
except (ValueError, IndexError, TypeError) as e:
logger.warning(f'Error: Alto export, unable to align line {self.id} due to exception: {e}.')

average_word_width = (text_line_hpos + text_line_width) / len(self.transcription.split())
for w, word in enumerate(self.transcription.split()):
words = self.transcription.split()
average_word_width = (text_line_hpos + text_line_width) / len(words)
for w, word in enumerate(words):
string = ET.SubElement(text_line, "String")
string.set("CONTENT", word)

string.set("HEIGHT", str(int(text_line_height)))
string.set("WIDTH", str(int(average_word_width)))
string.set("VPOS", str(int(text_line_vpos)))
string.set("HPOS", str(int(text_line_hpos + (w * average_word_width))))

if word_splitters is not None:
if w == 0 and previous_line is not None and previous_line.transcription is not None:
previous_word = previous_line.transcription.split()[-1]
last_char = previous_word[-1]
if last_char in word_splitters:
subs_word = previous_word[:-1] + word
string.set("SUBS_CONTENT", subs_word)
string.set("SUBS_TYPE", "HypPart2")

elif w == len(words) - 1 and next_line is not None and next_line.transcription is not None:
last_char = word[-1]
if last_char in word_splitters:
next_line_first_word = next_line.transcription.split()[0]
subs_word = word[:-1] + next_line_first_word
string.set("SUBS_CONTENT", subs_word)
string.set("SUBS_TYPE", "HypPart1")

else:
crop_engine = EngineLineCropper(poly=2)
line_coords = crop_engine.get_crop_inputs(self.baseline, self.heights, 16)
Expand Down Expand Up @@ -334,6 +355,24 @@ def to_altoxml_text(self, text_line, arabic_helper,
if word_confidence is not None:
string.set("WC", str(round(word_confidence, 2)))

if word_splitters is not None:
current_word = splitted_transcription[w]
if w == 0 and previous_line is not None and previous_line.transcription:
previous_word = previous_line.transcription.split()[-1]
last_char = previous_word[-1]
if last_char in word_splitters:
subs_word = previous_word[:-1] + current_word
string.set("SUBS_CONTENT", subs_word)
string.set("SUBS_TYPE", "HypPart2")

elif w == len(words) - 1 and next_line is not None and next_line.transcription:
last_char = current_word[-1]
if last_char in word_splitters:
next_line_first_word = next_line.transcription.split()[0]
subs_word = current_word[:-1] + next_line_first_word
string.set("SUBS_CONTENT", subs_word)
string.set("SUBS_TYPE", "HypPart1")

if w != (len(self.transcription.split()) - 1):
space = ET.SubElement(text_line, "SP")

Expand Down Expand Up @@ -511,8 +550,8 @@ def from_pagexml(cls, region_element: ET.SubElement, schema):

return layout_region

def to_altoxml(self, print_space, arabic_helper, min_line_confidence,
print_space_coords: Tuple[int, int, int, int], version: ALTOVersion) -> Tuple[int, int, int, int]:
def to_altoxml(self, print_space, arabic_helper, min_line_confidence, print_space_coords: Tuple[int, int, int, int],
version: ALTOVersion, word_splitters=None) -> Tuple[int, int, int, int]:
print_space_height, print_space_width, print_space_vpos, print_space_hpos = print_space_coords

text_block = ET.SubElement(print_space, "TextBlock")
Expand All @@ -531,10 +570,14 @@ def to_altoxml(self, print_space, arabic_helper, min_line_confidence,
print_space_height = print_space_height - print_space_vpos
print_space_width = print_space_width - print_space_hpos

for line in self.lines:
for i, line in enumerate(self.lines):
if not line.transcription or line.transcription.strip() == "":
continue
line.to_altoxml(text_block, arabic_helper, min_line_confidence, version)

previous_line = self.lines[i - 1] if i > 0 else None
next_line = self.lines[i + 1] if i + 1 < len(self.lines) else None
line.to_altoxml(text_block, arabic_helper, min_line_confidence, version, next_line=next_line,
previous_line=previous_line, word_splitters=word_splitters)
return print_space_height, print_space_width, print_space_vpos, print_space_hpos

@classmethod
Expand Down Expand Up @@ -757,7 +800,8 @@ def to_pagexml(self, file_name: str, creator: str = 'Pero OCR',
out_f.write(xml_string)

def to_altoxml_string(self, ocr_processing_element: ET.SubElement = None, page_uuid: str = None,
min_line_confidence: float = 0, version: ALTOVersion = ALTOVersion.ALTO_v2_x):
min_line_confidence: float = 0, version: ALTOVersion = ALTOVersion.ALTO_v2_x,
word_splitters=None):
arabic_helper = ArabicHelper()
NSMAP = {"xlink": 'http://www.w3.org/1999/xlink',
"xsi": 'http://www.w3.org/2001/XMLSchema-instance'}
Expand Down Expand Up @@ -802,7 +846,8 @@ def to_altoxml_string(self, ocr_processing_element: ET.SubElement = None, page_u
print_space_coords = (print_space_height, print_space_width, print_space_vpos, print_space_hpos)

for block in self.regions:
print_space_coords = block.to_altoxml(print_space, arabic_helper, min_line_confidence, print_space_coords, version)
print_space_coords = block.to_altoxml(print_space, arabic_helper, min_line_confidence, print_space_coords,
version, word_splitters=word_splitters)

print_space_height, print_space_width, print_space_vpos, print_space_hpos = print_space_coords

Expand Down Expand Up @@ -834,8 +879,9 @@ def to_altoxml_string(self, ocr_processing_element: ET.SubElement = None, page_u
return ET.tostring(root, pretty_print=True, encoding="utf-8", xml_declaration=True).decode("utf-8")

def to_altoxml(self, file_name: str, ocr_processing_element: ET.SubElement = None, page_uuid: str = None,
version: ALTOVersion = ALTOVersion.ALTO_v2_x):
alto_string = self.to_altoxml_string(ocr_processing_element=ocr_processing_element, page_uuid=page_uuid, version=version)
version: ALTOVersion = ALTOVersion.ALTO_v2_x, word_splitters=None):
alto_string = self.to_altoxml_string(ocr_processing_element=ocr_processing_element, page_uuid=page_uuid,
version=version, word_splitters=word_splitters)
with open(file_name, 'w', encoding='utf-8') as out_f:
out_f.write(alto_string)

Expand Down
13 changes: 10 additions & 3 deletions user_scripts/parse_folder.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -41,6 +41,8 @@ def parse_arguments():
parser.add_argument('--output-transcriptions-file-path', help='')
parser.add_argument('--skipp-missing-xml', action='store_true', help='Skipp images which have missing xml.')

parser.add_argument('--word-splitters', default=None, type=str, help='Word splitters for ALTO XML export.')

parser.add_argument('--device', choices=["gpu", "cpu"], default="gpu")
parser.add_argument('--gpu-id', type=int, default=None, help='If set, the computation runs of the specified GPU, otherwise safe-gpu is used to allocate first unused GPU.')

Expand Down Expand Up @@ -144,7 +146,8 @@ def __call__(self, page_layout: PageLayout, file_id):

class Computator:
def __init__(self, page_parser, input_image_path, input_xml_path, input_logit_path, output_render_path,
output_render_category, output_logit_path, output_alto_path, output_xml_path, output_line_path):
output_render_category, output_logit_path, output_alto_path, output_xml_path, output_line_path,
word_splitters=None):
self.page_parser = page_parser
self.input_image_path = input_image_path
self.input_xml_path = input_xml_path
Expand All @@ -155,6 +158,7 @@ def __init__(self, page_parser, input_image_path, input_xml_path, input_logit_pa
self.output_alto_path = output_alto_path
self.output_xml_path = output_xml_path
self.output_line_path = output_line_path
self.word_splitters = word_splitters

def __call__(self, image_file_name, file_id, index, ids_count):
print(f"Processing {file_id}")
Expand Down Expand Up @@ -191,7 +195,8 @@ def __call__(self, image_file_name, file_id, index, ids_count):
page_layout.save_logits(os.path.join(self.output_logit_path, file_id + '.logits'))

if self.output_alto_path is not None:
page_layout.to_altoxml(os.path.join(self.output_alto_path, file_id + '.xml'))
page_layout.to_altoxml(os.path.join(self.output_alto_path, file_id + '.xml'),
word_splitters=self.word_splitters)

if self.output_line_path is not None and page_layout is not None:
if 'lmdb' in self.output_line_path:
Expand Down Expand Up @@ -343,9 +348,11 @@ def main():
ids_to_process = filtered_ids_to_process
images_to_process = filtered_images_to_process

word_splitters = set(list(args.word_splitters)) if args.word_splitters else None

computator = Computator(page_parser, input_image_path, input_xml_path, input_logit_path, output_render_path,
output_render_category, output_logit_path, output_alto_path, output_xml_path,
output_line_path)
output_line_path, word_splitters=word_splitters)

t_start = time.time()
results = []
Expand Down