diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..e14c371
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,17 @@
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..ccc934e
--- /dev/null
+++ b/README.md
@@ -0,0 +1,42 @@
+WoW Screenshot OCR
+==============
+
+Deep learning OCR models that read text from World of Warcraft screenshots: a detector locates text regions in a screenshot, and a recognizer reads the text inside each detected region. Supported text sources include:
+
+- Chat
+- Combat log
+- Nameplates
+- UI frames
+- Map
+
+Usage
+----
+
+The models ship with pre-trained weights, so you don't have to train anything.
+
+```
+import wow_ocr
+
+# Init the pipeline: detector and recognizer models with pre-trained weights
+pipeline = wow_ocr.pipeline.Pipeline()
+
+
+# Example screenshots
+images = [
+    wow_ocr.tools.read(url)
+    for url in [
+        "https://image_url.com/1.jpg",
+        "https://image_url.com/2.jpg",
+    ]
+]
+
+# Results - Image to Text
+prediction_groups = pipeline.recognize(images)
+# Each list of predictions in prediction_groups is a list of
+# (word, box) tuples.
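+
+# Illustrative only: each box is a 4x2 numpy array of (x, y) corner points,
+# so you can, for example, print every recognized word with its first corner.
+for image_predictions in prediction_groups:
+    for word, box in image_predictions:
+        print(word, box[0])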
+
+```
+
+![](p1.webp)
+![](p2.webp)
+
diff --git a/p1.webp b/p1.webp
new file mode 100644
index 0000000..6cb6d51
Binary files /dev/null and b/p1.webp differ
diff --git a/p2.webp b/p2.webp
new file mode 100644
index 0000000..3f8131d
Binary files /dev/null and b/p2.webp differ
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..2bca2ac
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,36 @@
+[build-system]
+requires = ["setuptools>=61.0.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "wow-ocr"
+version = "0.0.0"
+description = "A packaged OCR model to read text from WoW screenshots"
+readme = "README.md"
+authors = [{ name = "Geo", email = "geoffrey.menon38@gmail.com" }]
+license = { file = "LICENSE" }
+classifiers = [
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: Python",
+    "Programming Language :: Python :: 3",
+]
+keywords = ["ocr", "wow", "screenshot"]
+dependencies = [
+    "validators",
+    "essential_generators",
+    "tqdm",
+    "imgaug",
+    "fonttools",
+    "editdistance",
+    "pyclipper",
+    "shapely",
+    "efficientnet",
+    "tensorflow",
+]
+requires-python = ">=3.9"
+
+[project.optional-dependencies]
+dev = ["black", "pip-tools", "pytest", "types-requests", "types-pkg-resources"]
+
+[project.urls]
+Homepage = "https://github.com/geo-tp/wow-ocr"
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..196f3ae
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,8 @@
+
+# Config for projects that do not yet support pyproject.toml (PEP-518)
+# https://www.python.org/dev/peps/pep-0518/
+
+# mypy: https://github.com/python/mypy/issues/5205
+[mypy]
+ignore_missing_imports = True
+check_untyped_defs = True
\ No newline at end of file
diff --git a/wow_ocr/__init__.py b/wow_ocr/__init__.py
new file mode 100644
index 0000000..8d730f0
--- /dev/null
+++ b/wow_ocr/__init__.py
@@ -0,0 +1,3 @@
+from . import detector, recognizer, tools, pipeline
+
+__version__ = "0.0.0"
diff --git a/wow_ocr/detector.py b/wow_ocr/detector.py
new file mode 100644
index 0000000..9b52f6b
--- /dev/null
+++ b/wow_ocr/detector.py
@@ -0,0 +1,820 @@
+import typing
+import cv2
+import numpy as np
+import tensorflow as tf
+import efficientnet.tfkeras as efficientnet
+from tensorflow import keras
+
+from . 
import tools + + +def compute_input(image): + # should be RGB order + image = image.astype("float32") + mean = np.array([0.485, 0.456, 0.406]) + variance = np.array([0.229, 0.224, 0.225]) + + image -= mean * 255 + image /= variance * 255 + return image + + +def get_gaussian_heatmap(size=512, distanceRatio=3.34): + v = np.abs(np.linspace(-size / 2, size / 2, num=size)) + x, y = np.meshgrid(v, v) + g = np.sqrt(x**2 + y**2) + g *= distanceRatio / (size / 2) + g = np.exp(-(1 / 2) * (g**2)) + g *= 255 + return g.clip(0, 255).astype("uint8") + + +def compute_maps(heatmap, image_height, image_width, lines): + assert image_height % 2 == 0, "Height must be an even number" + assert image_width % 2 == 0, "Width must be an even number" + + textmap = np.zeros((image_height // 2, image_width // 2)).astype("float32") + linkmap = np.zeros((image_height // 2, image_width // 2)).astype("float32") + + src = np.array( + [ + [0, 0], + [heatmap.shape[1], 0], + [heatmap.shape[1], heatmap.shape[0]], + [0, heatmap.shape[0]], + ] + ).astype("float32") + + for line in lines: + line, orientation = tools.fix_line(line) + previous_link_points = None + for [(x1, y1), (x2, y2), (x3, y3), (x4, y4)], c in line: + x1, y1, x2, y2, x3, y3, x4, y4 = map( + lambda v: max(v, 0), [x1, y1, x2, y2, x3, y3, x4, y4] + ) + if c == " ": + previous_link_points = None + continue + yc = (y4 + y1 + y3 + y2) / 4 + xc = (x1 + x2 + x3 + x4) / 4 + if orientation == "horizontal": + current_link_points = ( + np.array( + [ + [(xc + (x1 + x2) / 2) / 2, (yc + (y1 + y2) / 2) / 2], + [(xc + (x3 + x4) / 2) / 2, (yc + (y3 + y4) / 2) / 2], + ] + ) + / 2 + ) + else: + current_link_points = ( + np.array( + [ + [(xc + (x1 + x4) / 2) / 2, (yc + (y1 + y4) / 2) / 2], + [(xc + (x2 + x3) / 2) / 2, (yc + (y2 + y3) / 2) / 2], + ] + ) + / 2 + ) + character_points = ( + np.array([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]).astype("float32") / 2 + ) + # pylint: disable=unsubscriptable-object + if previous_link_points is not None: + if orientation == "horizontal": + link_points = np.array( + [ + previous_link_points[0], + current_link_points[0], + current_link_points[1], + previous_link_points[1], + ] + ) + else: + link_points = np.array( + [ + previous_link_points[0], + previous_link_points[1], + current_link_points[1], + current_link_points[0], + ] + ) + ML = cv2.getPerspectiveTransform( + src=src, + dst=link_points.astype("float32"), + ) + linkmap += cv2.warpPerspective( + heatmap, ML, dsize=(linkmap.shape[1], linkmap.shape[0]) + ).astype("float32") + MA = cv2.getPerspectiveTransform( + src=src, + dst=character_points, + ) + textmap += cv2.warpPerspective( + heatmap, MA, dsize=(textmap.shape[1], textmap.shape[0]) + ).astype("float32") + # pylint: enable=unsubscriptable-object + previous_link_points = current_link_points + return ( + np.concatenate( + [textmap[..., np.newaxis], linkmap[..., np.newaxis]], axis=2 + ).clip(0, 255) + / 255 + ) + + +def getBoxes( + y_pred, + detection_threshold=0.7, + text_threshold=0.4, + link_threshold=0.4, + size_threshold=10, +): + box_groups = [] + for y_pred_cur in y_pred: + # Prepare data + textmap = y_pred_cur[..., 0].copy() + linkmap = y_pred_cur[..., 1].copy() + img_h, img_w = textmap.shape + + _, text_score = cv2.threshold( + textmap, thresh=text_threshold, maxval=1, type=cv2.THRESH_BINARY + ) + _, link_score = cv2.threshold( + linkmap, thresh=link_threshold, maxval=1, type=cv2.THRESH_BINARY + ) + n_components, labels, stats, _ = cv2.connectedComponentsWithStats( + np.clip(text_score + link_score, 0, 1).astype("uint8"), 
connectivity=4 + ) + boxes = [] + for component_id in range(1, n_components): + # Filter by size + size = stats[component_id, cv2.CC_STAT_AREA] + + if size < size_threshold: + continue + + # If the maximum value within this connected component is less than + # text threshold, we skip it. + if np.max(textmap[labels == component_id]) < detection_threshold: + continue + + # Make segmentation map. It is 255 where we find text, 0 otherwise. + segmap = np.zeros_like(textmap) + segmap[labels == component_id] = 255 + segmap[np.logical_and(link_score, text_score)] = 0 + x, y, w, h = [ + stats[component_id, key] + for key in [ + cv2.CC_STAT_LEFT, + cv2.CC_STAT_TOP, + cv2.CC_STAT_WIDTH, + cv2.CC_STAT_HEIGHT, + ] + ] + + # Expand the elements of the segmentation map + niter = int(np.sqrt(size * min(w, h) / (w * h)) * 2) + sx, sy = max(x - niter, 0), max(y - niter, 0) + ex, ey = min(x + w + niter + 1, img_w), min(y + h + niter + 1, img_h) + segmap[sy:ey, sx:ex] = cv2.dilate( + segmap[sy:ey, sx:ex], + cv2.getStructuringElement(cv2.MORPH_RECT, (1 + niter, 1 + niter)), + ) + + # Make rotated box from contour + contours = cv2.findContours( + segmap.astype("uint8"), + mode=cv2.RETR_TREE, + method=cv2.CHAIN_APPROX_SIMPLE, + )[-2] + contour = contours[0] + box = cv2.boxPoints(cv2.minAreaRect(contour)) + + # Check to see if we have a diamond + w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2]) + box_ratio = max(w, h) / (min(w, h) + 1e-5) + if abs(1 - box_ratio) <= 0.1: + l, r = contour[:, 0, 0].min(), contour[:, 0, 0].max() + t, b = contour[:, 0, 1].min(), contour[:, 0, 1].max() + box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32) + else: + # Make clock-wise order + box = np.array(np.roll(box, 4 - box.sum(axis=1).argmin(), 0)) + boxes.append(2 * box) + box_groups.append(np.array(boxes)) + return box_groups + + +def build_keras_model(weights_path: str = None, backbone_name="vgg"): + inputs = keras.layers.Input((None, None, 3)) + + if backbone_name == "vgg": + s1, s2, s3, s4 = build_vgg_backbone(inputs) + elif "efficientnet" in backbone_name.lower(): + s1, s2, s3, s4 = build_efficientnet_backbone( + inputs=inputs, backbone_name=backbone_name, imagenet=weights_path is None + ) + else: + raise NotImplementedError + + s5 = keras.layers.MaxPooling2D( + pool_size=3, strides=1, padding="same", name="basenet.slice5.0" + )(s4) + s5 = keras.layers.Conv2D( + 1024, + kernel_size=(3, 3), + padding="same", + strides=1, + dilation_rate=6, + name="basenet.slice5.1", + )(s5) + s5 = keras.layers.Conv2D( + 1024, kernel_size=1, strides=1, padding="same", name="basenet.slice5.2" + )(s5) + + y = keras.layers.Concatenate()([s5, s4]) + y = upconv(y, n=1, filters=512) + y = UpsampleLike()([y, s3]) + y = keras.layers.Concatenate()([y, s3]) + y = upconv(y, n=2, filters=256) + y = UpsampleLike()([y, s2]) + y = keras.layers.Concatenate()([y, s2]) + y = upconv(y, n=3, filters=128) + y = UpsampleLike()([y, s1]) + y = keras.layers.Concatenate()([y, s1]) + features = upconv(y, n=4, filters=64) + + y = keras.layers.Conv2D( + filters=32, kernel_size=3, strides=1, padding="same", name="conv_cls.0" + )(features) + y = keras.layers.Activation("relu", name="conv_cls.1")(y) + y = keras.layers.Conv2D( + filters=32, kernel_size=3, strides=1, padding="same", name="conv_cls.2" + )(y) + y = keras.layers.Activation("relu", name="conv_cls.3")(y) + y = keras.layers.Conv2D( + filters=16, kernel_size=3, strides=1, padding="same", name="conv_cls.4" + )(y) + y = keras.layers.Activation("relu", name="conv_cls.5")(y) + y = 
keras.layers.Conv2D( + filters=16, kernel_size=1, strides=1, padding="same", name="conv_cls.6" + )(y) + y = keras.layers.Activation("relu", name="conv_cls.7")(y) + y = keras.layers.Conv2D( + filters=2, kernel_size=1, strides=1, padding="same", name="conv_cls.8" + )(y) + if backbone_name != "vgg": + y = keras.layers.Activation("sigmoid")(y) + model = keras.models.Model(inputs=inputs, outputs=y) + if weights_path is not None: + if weights_path.endswith(".h5"): + model.load_weights(weights_path) + elif weights_path.endswith(".pth"): + assert ( + backbone_name == "vgg" + ), "PyTorch weights only allowed with VGG backbone." + load_torch_weights(model=model, weights_path=weights_path) + else: + raise NotImplementedError(f"Cannot load weights from {weights_path}") + return model + + +class UpsampleLike(keras.layers.Layer): + """Keras layer for upsampling a Tensor to be the same shape as another Tensor.""" + + # pylint:disable=unused-argument + def call(self, inputs, **kwargs): + source, target = inputs + target_shape = keras.backend.shape(target) + if keras.backend.image_data_format() == "channels_first": + raise NotImplementedError + else: + # pylint: disable=no-member + return tf.compat.v1.image.resize_bilinear( + source, size=(target_shape[1], target_shape[2]), half_pixel_centers=True + ) + + def compute_output_shape(self, input_shape): + if keras.backend.image_data_format() == "channels_first": + raise NotImplementedError + else: + return (input_shape[0][0],) + input_shape[1][1:3] + (input_shape[0][-1],) + + +def upconv(x, n, filters): + x = keras.layers.Conv2D( + filters=filters, kernel_size=1, strides=1, name=f"upconv{n}.conv.0" + )(x) + x = keras.layers.BatchNormalization( + epsilon=1e-5, momentum=0.9, name=f"upconv{n}.conv.1" + )(x) + x = keras.layers.Activation("relu", name=f"upconv{n}.conv.2")(x) + x = keras.layers.Conv2D( + filters=filters // 2, + kernel_size=3, + strides=1, + padding="same", + name=f"upconv{n}.conv.3", + )(x) + x = keras.layers.BatchNormalization( + epsilon=1e-5, momentum=0.9, name=f"upconv{n}.conv.4" + )(x) + x = keras.layers.Activation("relu", name=f"upconv{n}.conv.5")(x) + return x + + +def make_vgg_block(x, filters, n, prefix, pooling=True): + x = keras.layers.Conv2D( + filters=filters, + strides=(1, 1), + kernel_size=(3, 3), + padding="same", + name=f"{prefix}.{n}", + )(x) + x = keras.layers.BatchNormalization( + momentum=0.1, epsilon=1e-5, axis=-1, name=f"{prefix}.{n+1}" + )(x) + x = keras.layers.Activation("relu", name=f"{prefix}.{n+2}")(x) + if pooling: + x = keras.layers.MaxPooling2D( + pool_size=(2, 2), padding="valid", strides=(2, 2), name=f"{prefix}.{n+3}" + )(x) + return x + + +def build_vgg_backbone(inputs): + x = make_vgg_block(inputs, filters=64, n=0, pooling=False, prefix="basenet.slice1") + x = make_vgg_block(x, filters=64, n=3, pooling=True, prefix="basenet.slice1") + x = make_vgg_block(x, filters=128, n=7, pooling=False, prefix="basenet.slice1") + x = make_vgg_block(x, filters=128, n=10, pooling=True, prefix="basenet.slice1") + x = make_vgg_block(x, filters=256, n=14, pooling=False, prefix="basenet.slice2") + x = make_vgg_block(x, filters=256, n=17, pooling=False, prefix="basenet.slice2") + x = make_vgg_block(x, filters=256, n=20, pooling=True, prefix="basenet.slice3") + x = make_vgg_block(x, filters=512, n=24, pooling=False, prefix="basenet.slice3") + x = make_vgg_block(x, filters=512, n=27, pooling=False, prefix="basenet.slice3") + x = make_vgg_block(x, filters=512, n=30, pooling=True, prefix="basenet.slice4") + x = make_vgg_block(x, filters=512, 
n=34, pooling=False, prefix="basenet.slice4") + x = make_vgg_block(x, filters=512, n=37, pooling=False, prefix="basenet.slice4") + x = make_vgg_block(x, filters=512, n=40, pooling=True, prefix="basenet.slice4") + vgg = keras.models.Model(inputs=inputs, outputs=x) + return [ + vgg.get_layer(slice_name).output + for slice_name in [ + "basenet.slice1.12", + "basenet.slice2.19", + "basenet.slice3.29", + "basenet.slice4.38", + ] + ] + + +def build_efficientnet_backbone(inputs, backbone_name, imagenet): + backbone = getattr(efficientnet, backbone_name)( + include_top=False, input_tensor=inputs, weights="imagenet" if imagenet else None + ) + return [ + backbone.get_layer(slice_name).output + for slice_name in [ + "block2a_expand_activation", + "block3a_expand_activation", + "block4a_expand_activation", + "block5a_expand_activation", + ] + ] + + +def build_keras_model(weights_path: str = None, backbone_name="vgg"): + inputs = keras.layers.Input((None, None, 3)) + + if backbone_name == "vgg": + s1, s2, s3, s4 = build_vgg_backbone(inputs) + elif "efficientnet" in backbone_name.lower(): + s1, s2, s3, s4 = build_efficientnet_backbone( + inputs=inputs, backbone_name=backbone_name, imagenet=weights_path is None + ) + else: + raise NotImplementedError + + s5 = keras.layers.MaxPooling2D( + pool_size=3, strides=1, padding="same", name="basenet.slice5.0" + )(s4) + s5 = keras.layers.Conv2D( + 1024, + kernel_size=(3, 3), + padding="same", + strides=1, + dilation_rate=6, + name="basenet.slice5.1", + )(s5) + s5 = keras.layers.Conv2D( + 1024, kernel_size=1, strides=1, padding="same", name="basenet.slice5.2" + )(s5) + + y = keras.layers.Concatenate()([s5, s4]) + y = upconv(y, n=1, filters=512) + y = UpsampleLike()([y, s3]) + y = keras.layers.Concatenate()([y, s3]) + y = upconv(y, n=2, filters=256) + y = UpsampleLike()([y, s2]) + y = keras.layers.Concatenate()([y, s2]) + y = upconv(y, n=3, filters=128) + y = UpsampleLike()([y, s1]) + y = keras.layers.Concatenate()([y, s1]) + features = upconv(y, n=4, filters=64) + + y = keras.layers.Conv2D( + filters=32, kernel_size=3, strides=1, padding="same", name="conv_cls.0" + )(features) + y = keras.layers.Activation("relu", name="conv_cls.1")(y) + y = keras.layers.Conv2D( + filters=32, kernel_size=3, strides=1, padding="same", name="conv_cls.2" + )(y) + y = keras.layers.Activation("relu", name="conv_cls.3")(y) + y = keras.layers.Conv2D( + filters=16, kernel_size=3, strides=1, padding="same", name="conv_cls.4" + )(y) + y = keras.layers.Activation("relu", name="conv_cls.5")(y) + y = keras.layers.Conv2D( + filters=16, kernel_size=1, strides=1, padding="same", name="conv_cls.6" + )(y) + y = keras.layers.Activation("relu", name="conv_cls.7")(y) + y = keras.layers.Conv2D( + filters=2, kernel_size=1, strides=1, padding="same", name="conv_cls.8" + )(y) + if backbone_name != "vgg": + y = keras.layers.Activation("sigmoid")(y) + model = keras.models.Model(inputs=inputs, outputs=y) + if weights_path is not None: + if weights_path.endswith(".h5"): + model.load_weights(weights_path) + elif weights_path.endswith(".pth"): + assert ( + backbone_name == "vgg" + ), "PyTorch weights only allowed with VGG backbone." 
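+            # load_torch_weights (defined below) copies the original CRAFT PyTorch
+            # state dict into this Keras graph layer by layer, transposing Conv2D
+            # kernels into the Keras layout.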
+ load_torch_weights(model=model, weights_path=weights_path) + else: + raise NotImplementedError(f"Cannot load weights from {weights_path}") + return model + + +# pylint: disable=import-error +def load_torch_weights(model, weights_path): + import torch + + pretrained = torch.load(weights_path, map_location=torch.device("cpu")) + layer_names = list( + set( + ".".join(k.split(".")[1:-1]) + for k in pretrained.keys() + if k.split(".")[-1] != "num_batches_tracked" + ) + ) + for layer_name in layer_names: + try: + layer = model.get_layer(layer_name) + except Exception: # pylint: disable=broad-except + print("Skipping", layer.name) + continue + if isinstance(layer, keras.layers.BatchNormalization): + gamma, beta, running_mean, running_std = [ + pretrained[k].numpy() + for k in [ + f"module.{layer_name}.weight", + f"module.{layer_name}.bias", + f"module.{layer_name}.running_mean", + f"module.{layer_name}.running_var", + ] + ] + layer.set_weights([gamma, beta, running_mean, running_std]) + elif isinstance(layer, keras.layers.Conv2D): + weights, bias = [ + pretrained[k].numpy() + for k in [f"module.{layer_name}.weight", f"module.{layer_name}.bias"] + ] + layer.set_weights([weights.transpose(2, 3, 1, 0), bias]) + + else: + raise NotImplementedError + + for layer in model.layers: + if isinstance(layer, (keras.layers.BatchNormalization, keras.layers.Conv2D)): + assert layer.name in layer_names + + +# pylint: disable=import-error,too-few-public-methods +def build_torch_model(weights_path=None): + from collections import namedtuple, OrderedDict + + import torch + import torchvision + + def init_weights(modules): + for m in modules: + if isinstance(m, torch.nn.Conv2d): + torch.nn.init.xavier_uniform_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, torch.nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, torch.nn.Linear): + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + class vgg16_bn(torch.nn.Module): + def __init__(self, pretrained=True, freeze=True): + super().__init__() + # We don't bother loading the pretrained VGG + # because we're going to use the weights + # at weights_path. 
+ vgg_pretrained_features = torchvision.models.vgg16_bn( + pretrained=False + ).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(12): # conv2_2 + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 19): # conv3_3 + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(19, 29): # conv4_3 + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(29, 39): # conv5_3 + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + + # fc6, fc7 without atrous conv + self.slice5 = torch.nn.Sequential( + torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + torch.nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6), + torch.nn.Conv2d(1024, 1024, kernel_size=1), + ) + + if not pretrained: + init_weights(self.slice1.modules()) + init_weights(self.slice2.modules()) + init_weights(self.slice3.modules()) + init_weights(self.slice4.modules()) + + init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7 + + if freeze: + for param in self.slice1.parameters(): # only first conv + param.requires_grad = False + + def forward(self, X): # pylint: disable=arguments-differ + h = self.slice1(X) + h_relu2_2 = h + h = self.slice2(h) + h_relu3_2 = h + h = self.slice3(h) + h_relu4_3 = h + h = self.slice4(h) + h_relu5_3 = h + h = self.slice5(h) + h_fc7 = h + vgg_outputs = namedtuple( + "vgg_outputs", ["fc7", "relu5_3", "relu4_3", "relu3_2", "relu2_2"] + ) + out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2) + return out + + class double_conv(torch.nn.Module): + def __init__(self, in_ch, mid_ch, out_ch): + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1), + torch.nn.BatchNorm2d(mid_ch), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1), + torch.nn.BatchNorm2d(out_ch), + torch.nn.ReLU(inplace=True), + ) + + def forward(self, x): # pylint: disable=arguments-differ + x = self.conv(x) + return x + + class CRAFT(torch.nn.Module): + def __init__(self, pretrained=False, freeze=False): + super().__init__() + # Base network + self.basenet = vgg16_bn(pretrained, freeze) + # U network + self.upconv1 = double_conv(1024, 512, 256) + self.upconv2 = double_conv(512, 256, 128) + self.upconv3 = double_conv(256, 128, 64) + self.upconv4 = double_conv(128, 64, 32) + + num_class = 2 + self.conv_cls = torch.nn.Sequential( + torch.nn.Conv2d(32, 32, kernel_size=3, padding=1), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(32, 32, kernel_size=3, padding=1), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(32, 16, kernel_size=3, padding=1), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(16, 16, kernel_size=1), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(16, num_class, kernel_size=1), + ) + + init_weights(self.upconv1.modules()) + init_weights(self.upconv2.modules()) + init_weights(self.upconv3.modules()) + init_weights(self.upconv4.modules()) + init_weights(self.conv_cls.modules()) + + def forward(self, x): # pylint: disable=arguments-differ + # Base network + sources = self.basenet(x) + # U network + # pylint: disable=E1101 + y = torch.cat([sources[0], sources[1]], dim=1) + + y = self.upconv1(y) + + y = torch.nn.functional.interpolate( + y, size=sources[2].size()[2:], mode="bilinear", align_corners=False + ) + y = torch.cat([y, sources[2]], dim=1) + y = 
self.upconv2(y) + + y = torch.nn.functional.interpolate( + y, size=sources[3].size()[2:], mode="bilinear", align_corners=False + ) + y = torch.cat([y, sources[3]], dim=1) + y = self.upconv3(y) + + y = torch.nn.functional.interpolate( + y, size=sources[4].size()[2:], mode="bilinear", align_corners=False + ) + y = torch.cat([y, sources[4]], dim=1) + # pylint: enable=E1101 + feature = self.upconv4(y) + + y = self.conv_cls(feature) + + return y.permute(0, 2, 3, 1), feature + + def copyStateDict(state_dict): + if list(state_dict.keys())[0].startswith("module"): + start_idx = 1 + else: + start_idx = 0 + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = ".".join(k.split(".")[start_idx:]) + new_state_dict[name] = v + return new_state_dict + + model = CRAFT(pretrained=True).eval() + if weights_path is not None: + model.load_state_dict( + copyStateDict(torch.load(weights_path, map_location=torch.device("cpu"))) + ) + return model + + +PRETRAINED_WEIGHTS = { + ("clovaai_general", True): { + "url": "https://github.com/faustomorales/keras-ocr/releases/download/v0.8.4/craft_mlt_25k.pth", + "filename": "craft_mlt_25k.pth", + "sha256": "4a5efbfb48b4081100544e75e1e2b57f8de3d84f213004b14b85fd4b3748db17", + }, + ("clovaai_general", False): { + "url": "https://github.com/faustomorales/keras-ocr/releases/download/v0.8.4/craft_mlt_25k.h5", + "filename": "craft_mlt_25k.h5", + "sha256": "7283ce2ff05a0617e9740c316175ff3bacdd7215dbdf1a726890d5099431f899", + }, +} + + +class Detector: + """A text detector using the CRAFT architecture. + + Args: + weights: The weights to use for the model. Currently, only `clovaai_general` + is supported. + load_from_torch: Whether to load the weights from the original PyTorch weights. + optimizer: The optimizer to use for training the model. + backbone_name: The backbone to use. Currently, only 'vgg' is supported. + """ + + def __init__( + self, + weights="clovaai_general", + load_from_torch=False, + optimizer="adam", + backbone_name="vgg", + ): + if weights is not None: + pretrained_key = (weights, load_from_torch) + assert backbone_name == "vgg", "Pretrained weights available only for VGG." + assert ( + pretrained_key in PRETRAINED_WEIGHTS + ), "Selected weights configuration not found." + weights_config = PRETRAINED_WEIGHTS[pretrained_key] + weights_path = tools.download_and_verify( + url=weights_config["url"], + filename=weights_config["filename"], + sha256=weights_config["sha256"], + ) + else: + weights_path = None + self.model = build_keras_model( + weights_path=weights_path, backbone_name=backbone_name + ) + self.model.compile(loss="mse", optimizer=optimizer) + + def get_batch_generator( + self, + image_generator, + batch_size=8, + heatmap_size=512, + heatmap_distance_ratio=1.5, + ): + """Get a generator of X, y batches to train the detector. + + Args: + image_generator: A generator with the same signature as + keras_ocr.tools.get_image_generator. Optionally, a third + entry in the tuple (beyond image and lines) can be provided + which will be interpreted as the sample weight. + batch_size: The size of batches to generate. + heatmap_size: The size of the heatmap to pass to get_gaussian_heatmap + heatmap_distance_ratio: The distance ratio to pass to + get_gaussian_heatmap. The larger the value, the more tightly + concentrated the heatmap becomes. 
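+
+        Yields:
+            (X, y) batches, or (X, y, sample_weights) when the image generator
+            also supplies a per-sample weight.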
+ """ + heatmap = get_gaussian_heatmap( + size=heatmap_size, distanceRatio=heatmap_distance_ratio + ) + while True: + batch = [next(image_generator) for n in range(batch_size)] + images = np.array([entry[0] for entry in batch]) + line_groups = [entry[1] for entry in batch] + X = compute_input(images) + # pylint: disable=unsubscriptable-object + y = np.array( + [ + compute_maps( + heatmap=heatmap, + image_height=images.shape[1], + image_width=images.shape[2], + lines=lines, + ) + for lines in line_groups + ] + ) + # pylint: enable=unsubscriptable-object + if len(batch[0]) == 3: + sample_weights = np.array([sample[2] for sample in batch]) + yield X, y, sample_weights + else: + yield X, y + + def detect( + self, + images: typing.List[typing.Union[np.ndarray, str]], + detection_threshold=0.7, + text_threshold=0.4, + link_threshold=0.4, + size_threshold=10, + **kwargs, + ): + """Recognize the text in a set of images. + + Args: + images: Can be a list of numpy arrays of shape HxWx3 or a list of + filepaths. + link_threshold: This is the same as `text_threshold`, but is applied to the + link map instead of the text map. + detection_threshold: We want to avoid including boxes that may have + represented large regions of low confidence text predictions. To do this, + we do a final check for each word box to make sure the maximum confidence + value exceeds some detection threshold. This is the threshold used for + this check. + text_threshold: When the text map is processed, it is converted from confidence + (float from zero to one) values to classification (0 for not text, 1 for + text) using binary thresholding. The threshold value determines the + breakpoint at which a value is converted to a 1 or a 0. For example, if + the threshold is 0.4 and a value for particular point on the text map is + 0.5, that value gets converted to a 1. The higher this value is, the less + likely it is that characters will be merged together into a single word. + The lower this value is, the more likely it is that non-text will be detected. + Therein lies the balance. + size_threshold: The minimum area for a word. + """ + images = [compute_input(tools.read(image)) for image in images] + boxes = getBoxes( + self.model.predict(np.array(images), **kwargs), + detection_threshold=detection_threshold, + text_threshold=text_threshold, + link_threshold=link_threshold, + size_threshold=size_threshold, + ) + return boxes diff --git a/wow_ocr/pipeline.py b/wow_ocr/pipeline.py new file mode 100644 index 0000000..6fb1571 --- /dev/null +++ b/wow_ocr/pipeline.py @@ -0,0 +1,76 @@ +# pylint: disable=too-few-public-methods +import numpy as np +from . import detector, recognizer, tools + + +class Pipeline: + """A wrapper for a combination of detector and recognizer. + + Args: + detector: The detector to use + recognizer: The recognizer to use + scale: The scale factor to apply to input images + max_size: The maximum single-side dimension of images for + inference. + """ + + def __init__( + self, detector_model=None, recognizer_model=None, scale=2, max_size=2048 + ): + if detector_model is None: + detector_model = detector.Detector() + if recognizer_model is None: + recognizer_model = recognizer.Recognizer() + self.scale = scale + self.detector = detector_model + self.recognizer = recognizer_model + self.max_size = max_size + + def recognize(self, images, detection_kwargs=None, recognition_kwargs=None): + """Run the pipeline on one or multiples images. 
+ + Args: + images: The images to parse (can be a list of actual images or a list of filepaths) + detection_kwargs: Arguments to pass to the detector call + recognition_kwargs: Arguments to pass to the recognizer call + + Returns: + A list of lists of (text, box) tuples. + """ + + # Make sure we have an image array to start with. + if not isinstance(images, np.ndarray): + images = [tools.read(image) for image in images] + # This turns images into (image, scale) tuples temporarily + images = [ + tools.resize_image(image, max_scale=self.scale, max_size=self.max_size) + for image in images + ] + max_height, max_width = np.array( + [image.shape[:2] for image, scale in images] + ).max(axis=0) + scales = [scale for _, scale in images] + images = np.array( + [ + tools.pad(image, width=max_width, height=max_height) + for image, _ in images + ] + ) + if detection_kwargs is None: + detection_kwargs = {} + if recognition_kwargs is None: + recognition_kwargs = {} + box_groups = self.detector.detect(images=images, **detection_kwargs) + prediction_groups = self.recognizer.recognize_from_boxes( + images=images, box_groups=box_groups, **recognition_kwargs + ) + box_groups = [ + tools.adjust_boxes(boxes=boxes, boxes_format="boxes", scale=1 / scale) + if scale != 1 + else boxes + for boxes, scale in zip(box_groups, scales) + ] + return [ + list(zip(predictions, boxes)) + for predictions, boxes in zip(prediction_groups, box_groups) + ] diff --git a/wow_ocr/recognizer.py b/wow_ocr/recognizer.py new file mode 100644 index 0000000..5f6f521 --- /dev/null +++ b/wow_ocr/recognizer.py @@ -0,0 +1,521 @@ +import string +import typing +import tensorflow as tf +from tensorflow import keras +import numpy as np +import cv2 +from . import tools +import sys + +DEFAULT_BUILD_PARAMS = { + "height": 31, + "width": 200, + "color": False, + "filters": (64, 128, 256, 256, 512, 512, 512), + "rnn_units": (128, 128), + "dropout": 0.25, + "rnn_steps_to_discard": 2, + "pool_size": 2, + "stn": True, +} + +DEFAULT_ALPHABET = string.digits + string.ascii_lowercase + +PRETRAINED_WEIGHTS: typing.Dict[str, typing.Any] = { + "wow_ocr": { + "alphabet": DEFAULT_ALPHABET, + "build_params": DEFAULT_BUILD_PARAMS, + "weights": { + "url": "https://github.com/faustomorales/keras-ocr/releases/download/v0.8.4/crnn_kurapan.h5", + "filename": "crnn_kurapan.h5", + "sha256": "a7d8086ac8f5c3d6a0a828f7d6fbabcaf815415dd125c32533013f85603be46d", + }, + }, +} + + +def _repeat(x, num_repeats): + ones = tf.ones((1, num_repeats), dtype="int32") + x = tf.reshape(x, shape=(-1, 1)) + x = tf.matmul(x, ones) + return tf.reshape(x, [-1]) + + +def _meshgrid(height, width): + x_linspace = tf.linspace(-1.0, 1.0, width) + y_linspace = tf.linspace(-1.0, 1.0, height) + x_coordinates, y_coordinates = tf.meshgrid(x_linspace, y_linspace) + x_coordinates = tf.reshape(x_coordinates, shape=(1, -1)) + y_coordinates = tf.reshape(y_coordinates, shape=(1, -1)) + ones = tf.ones_like(x_coordinates) + indices_grid = tf.concat([x_coordinates, y_coordinates, ones], 0) + return indices_grid + + +# pylint: disable=too-many-statements +def _transform(inputs): + locnet_x, locnet_y = inputs + output_size = locnet_x.shape[1:] + batch_size = tf.shape(locnet_x)[0] + height = tf.shape(locnet_x)[1] + width = tf.shape(locnet_x)[2] + num_channels = tf.shape(locnet_x)[3] + + locnet_y = tf.reshape(locnet_y, shape=(batch_size, 2, 3)) + + locnet_y = tf.reshape(locnet_y, (-1, 2, 3)) + locnet_y = tf.cast(locnet_y, "float32") + + output_height = output_size[0] + output_width = output_size[1] + indices_grid 
= _meshgrid(output_height, output_width) + indices_grid = tf.expand_dims(indices_grid, 0) + indices_grid = tf.reshape(indices_grid, [-1]) # flatten? + indices_grid = tf.tile(indices_grid, tf.stack([batch_size])) + indices_grid = tf.reshape(indices_grid, tf.stack([batch_size, 3, -1])) + + transformed_grid = tf.matmul(locnet_y, indices_grid) + x_s = tf.slice(transformed_grid, [0, 0, 0], [-1, 1, -1]) + y_s = tf.slice(transformed_grid, [0, 1, 0], [-1, 1, -1]) + x = tf.reshape(x_s, [-1]) + y = tf.reshape(y_s, [-1]) + + # Interpolate + height_float = tf.cast(height, dtype="float32") + width_float = tf.cast(width, dtype="float32") + + output_height = output_size[0] + output_width = output_size[1] + + x = tf.cast(x, dtype="float32") + y = tf.cast(y, dtype="float32") + x = 0.5 * (x + 1.0) * width_float + y = 0.5 * (y + 1.0) * height_float + + x0 = tf.cast(tf.floor(x), "int32") + x1 = x0 + 1 + y0 = tf.cast(tf.floor(y), "int32") + y1 = y0 + 1 + + max_y = tf.cast(height - 1, dtype="int32") + max_x = tf.cast(width - 1, dtype="int32") + zero = tf.zeros([], dtype="int32") + + x0 = tf.clip_by_value(x0, zero, max_x) + x1 = tf.clip_by_value(x1, zero, max_x) + y0 = tf.clip_by_value(y0, zero, max_y) + y1 = tf.clip_by_value(y1, zero, max_y) + + flat_image_dimensions = width * height + pixels_batch = tf.range(batch_size) * flat_image_dimensions + flat_output_dimensions = output_height * output_width + base = _repeat(pixels_batch, flat_output_dimensions) + base_y0 = base + y0 * width + base_y1 = base + y1 * width + indices_a = base_y0 + x0 + indices_b = base_y1 + x0 + indices_c = base_y0 + x1 + indices_d = base_y1 + x1 + + flat_image = tf.reshape(locnet_x, shape=(-1, num_channels)) + flat_image = tf.cast(flat_image, dtype="float32") + pixel_values_a = tf.gather(flat_image, indices_a) + pixel_values_b = tf.gather(flat_image, indices_b) + pixel_values_c = tf.gather(flat_image, indices_c) + pixel_values_d = tf.gather(flat_image, indices_d) + + x0 = tf.cast(x0, "float32") + x1 = tf.cast(x1, "float32") + y0 = tf.cast(y0, "float32") + y1 = tf.cast(y1, "float32") + + area_a = tf.expand_dims(((x1 - x) * (y1 - y)), 1) + area_b = tf.expand_dims(((x1 - x) * (y - y0)), 1) + area_c = tf.expand_dims(((x - x0) * (y1 - y)), 1) + area_d = tf.expand_dims(((x - x0) * (y - y0)), 1) + transformed_image = tf.add_n( + [ + area_a * pixel_values_a, + area_b * pixel_values_b, + area_c * pixel_values_c, + area_d * pixel_values_d, + ] + ) + # Finished interpolation + + transformed_image = tf.reshape( + transformed_image, shape=(batch_size, output_height, output_width, num_channels) + ) + return transformed_image + + +def CTCDecoder(): + def decoder(y_pred): + input_shape = tf.keras.backend.shape(y_pred) + input_length = tf.ones(shape=input_shape[0]) * tf.keras.backend.cast( + input_shape[1], "float32" + ) + unpadded = tf.keras.backend.ctc_decode(y_pred, input_length)[0][0] + unpadded_shape = tf.keras.backend.shape(unpadded) + padded = tf.pad( + unpadded, + paddings=[[0, 0], [0, input_shape[1] - unpadded_shape[1]]], + constant_values=-1, + ) + return padded + + return tf.keras.layers.Lambda(decoder, name="decode") + + +def build_model( + alphabet, + height, + width, + color, + filters, + rnn_units, + dropout, + rnn_steps_to_discard, + pool_size, + stn=True, +): + """Build a Keras CRNN model for character recognition. 
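+
+    Returns a (backbone, model, training_model, prediction_model) tuple of Keras
+    models; the prediction model wraps the recognition output with a CTC decoder.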
+ + Args: + height: The height of cropped images + width: The width of cropped images + color: Whether the inputs should be in color (RGB) + filters: The number of filters to use for each of the 7 convolutional layers + rnn_units: The number of units for each of the RNN layers + dropout: The dropout to use for the final layer + rnn_steps_to_discard: The number of initial RNN steps to discard + pool_size: The size of the pooling steps + stn: Whether to add a Spatial Transformer layer + """ + assert len(filters) == 7, "7 CNN filters must be provided." + assert len(rnn_units) == 2, "2 RNN filters must be provided." + inputs = keras.layers.Input((height, width, 3 if color else 1)) + x = keras.layers.Permute((2, 1, 3))(inputs) + x = keras.layers.Lambda(lambda x: x[:, :, ::-1])(x) + x = keras.layers.Conv2D( + filters[0], (3, 3), activation="relu", padding="same", name="conv_1" + )(x) + x = keras.layers.Conv2D( + filters[1], (3, 3), activation="relu", padding="same", name="conv_2" + )(x) + x = keras.layers.Conv2D( + filters[2], (3, 3), activation="relu", padding="same", name="conv_3" + )(x) + x = keras.layers.BatchNormalization(name="bn_3")(x) + x = keras.layers.MaxPooling2D(pool_size=(pool_size, pool_size), name="maxpool_3")(x) + x = keras.layers.Conv2D( + filters[3], (3, 3), activation="relu", padding="same", name="conv_4" + )(x) + x = keras.layers.Conv2D( + filters[4], (3, 3), activation="relu", padding="same", name="conv_5" + )(x) + x = keras.layers.BatchNormalization(name="bn_5")(x) + x = keras.layers.MaxPooling2D(pool_size=(pool_size, pool_size), name="maxpool_5")(x) + x = keras.layers.Conv2D( + filters[5], (3, 3), activation="relu", padding="same", name="conv_6" + )(x) + x = keras.layers.Conv2D( + filters[6], (3, 3), activation="relu", padding="same", name="conv_7" + )(x) + x = keras.layers.BatchNormalization(name="bn_7")(x) + if stn: + # pylint: disable=pointless-string-statement + """Spatial Transformer Layer + Implements a spatial transformer layer as described in [1]_. + Borrowed from [2]_: + downsample_fator : float + A value of 1 will keep the orignal size of the image. + Values larger than 1 will down sample the image. Values below 1 will + upsample the image. + example image: height= 100, width = 200 + downsample_factor = 2 + output image will then be 50, 100 + References + ---------- + .. [1] Spatial Transformer Networks + Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu + Submitted on 5 Jun 2015 + .. [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py + .. 
[3] https://github.com/EderSantana/seya/blob/keras1/seya/layers/attention.py + """ + stn_input_output_shape = ( + width // pool_size**2, + height // pool_size**2, + filters[6], + ) + stn_input_layer = keras.layers.Input(shape=stn_input_output_shape) + locnet_y = keras.layers.Conv2D(16, (5, 5), padding="same", activation="relu")( + stn_input_layer + ) + locnet_y = keras.layers.Conv2D(32, (5, 5), padding="same", activation="relu")( + locnet_y + ) + locnet_y = keras.layers.Flatten()(locnet_y) + locnet_y = keras.layers.Dense(64, activation="relu")(locnet_y) + locnet_y = keras.layers.Dense( + 6, + weights=[ + np.zeros((64, 6), dtype="float32"), + np.array([[1, 0, 0], [0, 1, 0]], dtype="float32").flatten(), + ], + )(locnet_y) + localization_net = keras.models.Model(inputs=stn_input_layer, outputs=locnet_y) + x = keras.layers.Lambda(_transform, output_shape=stn_input_output_shape)( + [x, localization_net(x)] + ) + x = keras.layers.Reshape( + target_shape=( + width // pool_size**2, + (height // pool_size**2) * filters[-1], + ), + name="reshape", + )(x) + + x = keras.layers.Dense(rnn_units[0], activation="relu", name="fc_9")(x) + + rnn_1_forward = keras.layers.LSTM( + rnn_units[0], + kernel_initializer="he_normal", + return_sequences=True, + name="lstm_10", + )(x) + rnn_1_back = keras.layers.LSTM( + rnn_units[0], + kernel_initializer="he_normal", + go_backwards=True, + return_sequences=True, + name="lstm_10_back", + )(x) + rnn_1_add = keras.layers.Add()([rnn_1_forward, rnn_1_back]) + rnn_2_forward = keras.layers.LSTM( + rnn_units[1], + kernel_initializer="he_normal", + return_sequences=True, + name="lstm_11", + )(rnn_1_add) + rnn_2_back = keras.layers.LSTM( + rnn_units[1], + kernel_initializer="he_normal", + go_backwards=True, + return_sequences=True, + name="lstm_11_back", + )(rnn_1_add) + x = keras.layers.Concatenate()([rnn_2_forward, rnn_2_back]) + backbone = keras.models.Model(inputs=inputs, outputs=x) + x = keras.layers.Dropout(dropout, name="dropout")(x) + x = keras.layers.Dense( + len(alphabet) + 1, + kernel_initializer="he_normal", + activation="softmax", + name="fc_12", + )(x) + x = keras.layers.Lambda(lambda x: x[:, rnn_steps_to_discard:])(x) + model = keras.models.Model(inputs=inputs, outputs=x) + + prediction_model = keras.models.Model( + inputs=inputs, outputs=CTCDecoder()(model.output) + ) + labels = keras.layers.Input( + name="labels", shape=[model.output_shape[1]], dtype="float32" + ) + label_length = keras.layers.Input(shape=[1]) + input_length = keras.layers.Input(shape=[1]) + loss = keras.layers.Lambda( + lambda inputs: keras.backend.ctc_batch_cost( + y_true=inputs[0], + y_pred=inputs[1], + input_length=inputs[2], + label_length=inputs[3], + ) + )([labels, model.output, input_length, label_length]) + training_model = keras.models.Model( + inputs=[model.input, labels, input_length, label_length], outputs=loss + ) + return backbone, model, training_model, prediction_model + + +class Recognizer: + """A text detector using the CRNN architecture. + + Args: + alphabet: The alphabet the model should recognize. + build_params: A dictionary of build parameters for the model. + See `keras_ocr.recognition.build_model` for details. + weights: The starting weight configuration for the model. + include_top: Whether to include the final classification layer in the model (set + to False to use a custom alphabet). + """ + + def __init__(self, alphabet=None, weights="wow_ocr", build_params=None): + assert ( + alphabet or weights + ), "At least one of alphabet or weights must be provided." 
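+        # Resolve the configuration: explicit arguments take precedence, then the
+        # preset attached to the pretrained weights, then the package defaults.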
+ if weights is not None: + build_params = build_params or PRETRAINED_WEIGHTS[weights]["build_params"] + alphabet = alphabet or PRETRAINED_WEIGHTS[weights]["alphabet"] + build_params = build_params or DEFAULT_BUILD_PARAMS + if alphabet is None: + alphabet = DEFAULT_ALPHABET + self.alphabet = alphabet + self.blank_label_idx = len(alphabet) + ( + self.backbone, + self.model, + self.training_model, + self.prediction_model, + ) = build_model(alphabet=alphabet, **build_params) + if weights is not None: + weights_dict = PRETRAINED_WEIGHTS[weights] + self.model.load_weights( + tools.download_and_verify( + url=weights_dict["weights"]["url"], + filename=weights_dict["weights"]["filename"], + sha256=weights_dict["weights"]["sha256"], + ) + ) + + def get_batch_generator(self, image_generator, batch_size=8, lowercase=False): + """ + Generate batches of training data from an image generator. The generator + should yield tuples of (image, sentence) where image contains a single + line of text and sentence is a string representing the contents of + the image. If a sample weight is desired, it can be provided as a third + entry in the tuple, making each tuple an (image, sentence, weight) tuple. + + Args: + image_generator: An image / sentence tuple generator. The images should + be in color even if the OCR is setup to handle grayscale as they + will be converted here. + batch_size: How many images to generate at a time. + lowercase: Whether to convert all characters to lowercase before + encoding. + """ + y = np.zeros((batch_size, 1)) + if self.training_model is None: + raise Exception("You must first call create_training_model().") + max_string_length = self.training_model.input_shape[1][1] + while True: + batch = [sample for sample, _ in zip(image_generator, range(batch_size))] + images: typing.Union[typing.List[np.ndarray], np.ndarray] + if not self.model.input_shape[-1] == 3: + images = [ + cv2.cvtColor(sample[0], cv2.COLOR_RGB2GRAY)[..., np.newaxis] + for sample in batch + ] + else: + images = [sample[0] for sample in batch] + images = np.array([image.astype("float32") / 255 for image in images]) + sentences = [sample[1].strip() for sample in batch] + if lowercase: + sentences = [sentence.lower() for sentence in sentences] + for c in "".join(sentences): + assert c in self.alphabet, f"Found illegal character: {c}" + assert all(sentences), "Found a zero length sentence." + assert all( + len(sentence) <= max_string_length for sentence in sentences + ), "A sentence is longer than this model can predict." + assert all(" " not in sentence for sentence in sentences), ( + "Strings with multiple sequential spaces are not permitted. " + "See https://github.com/faustomorales/keras-ocr/issues/54" + ) + label_length = np.array([len(sentence) for sentence in sentences])[ + :, np.newaxis + ] + labels = np.array( + [ + [self.alphabet.index(c) for c in sentence] + + [-1] * (max_string_length - len(sentence)) + for sentence in sentences + ] + ) + input_length = np.ones((batch_size, 1)) * max_string_length + if len(batch[0]) == 3: + sample_weights = np.array([sample[2] for sample in batch]) + yield (images, labels, input_length, label_length), y, sample_weights + else: + yield (images, labels, input_length, label_length), y + + def recognize(self, image): + """Recognize text from a single image. 
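+
+        Returns the recognized string.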
+ + Args: + image: A pre-cropped image containing characters + """ + image = tools.read_and_fit( + filepath_or_array=image, + width=self.prediction_model.input_shape[2], + height=self.prediction_model.input_shape[1], + cval=0, + ) + if self.prediction_model.input_shape[-1] == 1 and image.shape[-1] == 3: + # Convert color to grayscale + image = cv2.cvtColor(image, code=cv2.COLOR_RGB2GRAY)[..., np.newaxis] + image = image.astype("float32") / 255 + return "".join( + [ + self.alphabet[idx] + for idx in self.prediction_model.predict(image[np.newaxis])[0] + if idx not in [self.blank_label_idx, -1] + ] + ) + + def recognize_from_boxes( + self, images, box_groups, **kwargs + ) -> typing.List[typing.List[str]]: + """Recognize text from images using lists of bounding boxes. + + Args: + images: A list of input images, supplied as numpy arrays with shape + (H, W, 3). + boxes: A list of groups of boxes, one for each image + """ + assert len(box_groups) == len( + images + ), "You must provide the same number of box groups as images." + crops = [] + start_end: typing.List[typing.Tuple[int, int]] = [] + for image, boxes in zip(images, box_groups): + image = tools.read(image) + if self.prediction_model.input_shape[-1] == 1 and image.shape[-1] == 3: + # Convert color to grayscale + image = cv2.cvtColor(image, code=cv2.COLOR_RGB2GRAY) + for box in boxes: + crops.append( + tools.warpBox( + image=image, + box=box, + target_height=self.model.input_shape[1], + target_width=self.model.input_shape[2], + ) + ) + start = 0 if not start_end else start_end[-1][1] + start_end.append((start, start + len(boxes))) + if not crops: + return [[]] * len(images) + X = np.array(crops, dtype="float32") / 255 + if len(X.shape) == 3: + X = X[..., np.newaxis] + predictions = [ + "".join( + [ + self.alphabet[idx] + for idx in row + if idx not in [self.blank_label_idx, -1] + ] + ) + for row in self.prediction_model.predict(X, **kwargs) + ] + return [predictions[start:end] for start, end in start_end] + + def compile(self, *args, **kwargs): + """Compile the training model.""" + if "optimizer" not in kwargs: + kwargs["optimizer"] = "RMSprop" + if "loss" not in kwargs: + kwargs["loss"] = lambda _, y_pred: y_pred + self.training_model.compile(*args, **kwargs) diff --git a/wow_ocr/tools.py b/wow_ocr/tools.py new file mode 100644 index 0000000..3e0b2eb --- /dev/null +++ b/wow_ocr/tools.py @@ -0,0 +1,125 @@ +import os +import hashlib +import numpy as np +import cv2 +import typing +import io +import validators +import urllib.request +import urllib.parse + + +def read(filepath_or_buffer: typing.Union[str, io.BytesIO, np.ndarray]): + """Read a file into an image object + + Args: + filepath_or_buffer: The path to the file, a URL, or any object + with a `read` method (such as `io.BytesIO`) + """ + if isinstance(filepath_or_buffer, np.ndarray): + return filepath_or_buffer + if hasattr(filepath_or_buffer, "read"): + image = np.asarray(bytearray(filepath_or_buffer.read()), dtype=np.uint8) # type: ignore + image = cv2.imdecode(image, cv2.IMREAD_UNCHANGED) + elif isinstance(filepath_or_buffer, str): + if validators.url(filepath_or_buffer): + return read(urllib.request.urlopen(filepath_or_buffer)) + assert os.path.isfile(filepath_or_buffer), ( + "Could not find image at path: " + filepath_or_buffer + ) + image = cv2.imread(filepath_or_buffer) + return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + +def sha256sum(filename): + """Compute the sha256 hash for a file.""" + h = hashlib.sha256() + b = bytearray(128 * 1024) + mv = memoryview(b) + with 
open(filename, "rb", buffering=0) as f: + for n in iter(lambda: f.readinto(mv), 0): # type: ignore + h.update(mv[:n]) + return h.hexdigest() + + +def get_default_cache_dir(): + return os.environ.get( + "KERAS_OCR_CACHE_DIR", os.path.expanduser(os.path.join("~", ".keras-ocr")) + ) + + +def download_and_verify(url, sha256=None, cache_dir=None, verbose=True, filename=None): + """Download a file to a cache directory and verify it with a sha256 + hash. + + Args: + url: The file to download + sha256: The sha256 hash to check. If the file already exists and the hash + matches, we don't download it again. + cache_dir: The directory in which to cache the file. The default is + `~/.keras-ocr`. + verbose: Whether to log progress + filename: The filename to use for the file. By default, the filename is + derived from the URL. + """ + if cache_dir is None: + cache_dir = get_default_cache_dir() + if filename is None: + filename = os.path.basename(urllib.parse.urlparse(url).path) + filepath = os.path.join(cache_dir, filename) + os.makedirs(os.path.split(filepath)[0], exist_ok=True) + if verbose: + print("Looking for " + filepath) + if not os.path.isfile(filepath) or (sha256 and sha256sum(filepath) != sha256): + if verbose: + print("Downloading " + filepath) + urllib.request.urlretrieve(url, filepath) + assert sha256 is None or sha256 == sha256sum( + filepath + ), "Error occurred verifying sha256." + return filepath + + +def pad(image, width: int, height: int, cval: int = 255): + """Pad an image to a desired size. Raises an exception if image + is larger than desired size. + + Args: + image: The input image + width: The output width + height: The output height + cval: The value to use for filling the image. + """ + output_shape: typing.Union[typing.Tuple[int, int, int], typing.Tuple[int, int]] + if len(image.shape) == 3: + output_shape = (height, width, image.shape[-1]) + else: + output_shape = (height, width) + assert height >= output_shape[0], "Input height must be less than output height." + assert width >= output_shape[1], "Input width must be less than output width." + padded = np.zeros(output_shape, dtype=image.dtype) + cval + padded[: image.shape[0], : image.shape[1]] = image + return padded + + +def resize_image(image, max_scale, max_size): + """Obtain the optimal resized image subject to a maximum scale + and maximum size. + + Args: + image: The input image + max_scale: The maximum scale to apply + max_size: The maximum size to return + """ + if max(image.shape) * max_scale > max_size: + # We are constrained by the maximum size + scale = max_size / max(image.shape) + else: + # We are contrained by scale + scale = max_scale + return ( + cv2.resize( + image, dsize=(int(image.shape[1] * scale), int(image.shape[0] * scale)) + ), + scale, + ) diff --git a/wow_ocr/weights/recognizer/wow_ocr.h5 b/wow_ocr/weights/recognizer/wow_ocr.h5 new file mode 100644 index 0000000..3a001da Binary files /dev/null and b/wow_ocr/weights/recognizer/wow_ocr.h5 differ