【Keras+计算机视觉+Tensorflow】OCR文字识别实战(附源码和数据集 超详细必看)

showswoller 2024-07-09 14:01:11 阅读 69

需要源码和数据集请点赞关注收藏后评论区留言私信~~~

一、OCR文字识别简介

文字识别是利用计算机自动识别字符的技术,是模式识别应用的一个重要领域。人们在生产和生活中需要处理大量的文字、报表和文本。为了减轻人们的劳动、提高处理效率,人们从上世纪50年代起就开始探索文字识别方法,并研制出了光学字符识别器。

OCR(Optical Character Recognition)图像文字识别是人工智能的重要分支,赋予计算机人眼的功能,使其可以看图识字,图像文字识别系统流程一般分为图像采集、文字检测、文字识别以及结果输出四部分。

 二、OCR文字识别项目实战

1:数据集简介

MSRA-TD500 数据集共包含 500 张自然场景图像,其分辨率在 1296 × 864 至 920 × 1280 之间,涵盖了室内商场、标识牌、室外街道、广告牌等大多数场景,文本包含中文和英文,有着不同的字体、大小和倾斜方向,部分数据集图像如下图所示。

数据集的目录结构如下,分为训练集和测试集。

2:项目结构

整体项目结构如下:上面是一些算法和模型(比如 CRAFT、CRNN)的定义,下面是测试代码。

CRAFT 算法实现文本行检测的流程如下图所示。首先将完整的文字区域输入 CRAFT 文字检测网络,得到字符级的文字得分热图(Text Score)和字符级的文本连接得分热图(Link Score),最后根据连通域得到每个文本行的位置。

3:效果展示 

开始运行代码

输出运行结果如下,可以放入不同的图片进行测试。

 

 

 

 

三、代码 

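部分代码如下(需要全部代码和数据集请点赞关注收藏后评论区留言私信~~~)。以下代码分为两部分:第一部分使用 keras-ocr 在 MJSynth(Synth90k)合成数据集上多 GPU 训练文字识别模型(Recognizer);第二部分用于生成 fonts.zip(字体)和 backgrounds.zip(背景图片)文件。注意:这段训练代码使用的是 MJSynth 数据集,而不是上文介绍的 MSRA-TD500。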

 

<code>"""This script demonstrates how to train the model

on the SynthText90 using multiple GPUs."""

# pylint: disable=invalid-name

import datetime

import argparse

import math

import random

import string

import functools

import itertools

import os

import tarfile

import urllib.request

import numpy as np

import cv2

import imgaug

import tqdm

import tensorflow as tf

import keras_ocr

# pylint: disable=redefined-outer-name

def get_filepaths(data_path, split):
    """Read the MJSynth annotation file for ``split`` and build image paths.

    Args:
        data_path: Directory the dataset was extracted into; annotation files
            are read from ``mnt/ramdisk/max/90kDICT32px`` beneath it.
        split: Name of the split (``'train'``, ``'val'`` or ``'test'``); it
            selects ``annotation_{split}.txt``.

    Returns:
        One path per annotation line, in file order. For every line the first
        space-separated token is taken, its first two characters (the ``./``
        prefix) are dropped, and the remainder is joined onto the dataset
        directory.
    """
    annotation_path = os.path.join(
        data_path, f'mnt/ramdisk/max/90kDICT32px/annotation_{split}.txt')
    paths = []
    with open(annotation_path, 'r') as annotations:
        for entry in annotations:
            relative_path = entry.split(' ')[0][2:]
            paths.append(os.path.join(data_path, 'mnt/ramdisk/max/90kDICT32px', relative_path))
    return paths

# pylint: disable=redefined-outer-name

def download_extract_and_process_dataset(data_path):
    """Download (if needed) and extract the MJSynth dataset archive.

    The archive is downloaded to ``<data_path>/mjsynth.tar.gz`` unless either
    that file or the extraction directory ``<data_path>/mnt`` already exists.
    The archive is then extracted into ``data_path`` unless ``<data_path>/mnt``
    already exists.

    Args:
        data_path: Directory that holds (or will hold) the archive and its
            extracted contents. It is created if it does not exist.
    """
    archive_filepath = os.path.join(data_path, 'mjsynth.tar.gz')
    extraction_directory = os.path.join(data_path, 'mnt')
    if not os.path.isfile(archive_filepath) and not os.path.isdir(extraction_directory):
        print('Downloading the dataset.')
        os.makedirs(data_path, exist_ok=True)
        # Download under a temporary name and rename only after the transfer
        # completes, so an interrupted download never leaves a truncated file
        # at ``archive_filepath`` that a later run would try to extract.
        partial_filepath = archive_filepath + '.part'
        urllib.request.urlretrieve("https://www.robots.ox.ac.uk/~vgg/data/text/mjsynth.tar.gz",
                                   partial_filepath)
        os.replace(partial_filepath, archive_filepath)
    # NOTE(review): an interrupted extraction leaves a partial ``mnt`` directory,
    # which this check treats as complete; delete it manually in that case.
    if not os.path.isdir(extraction_directory):
        print('Extracting files.')
        with tarfile.open(archive_filepath) as tfile:
            if hasattr(tarfile, 'data_filter'):
                # The archive comes from the network: reject absolute paths,
                # ``..`` components and links escaping ``data_path`` (PEP 706).
                tfile.extractall(data_path, filter='data')
            else:
                # Python versions without extraction filters.
                tfile.extractall(data_path)

def get_image_generator(filepaths, augmenter, width, height):
    """Endlessly yield ``(image, text)`` pairs from MJSynth image files.

    The first pass visits ``filepaths`` in the given order; the list is
    reshuffled after every complete pass. Files that cannot be read are
    reported and skipped.

    Args:
        filepaths: Iterable of image paths. It is copied, so the caller's
            sequence is never modified.
        augmenter: Augmenter whose ``augment_image`` is applied to each fitted
            image, or ``None`` to disable augmentation.
        width: Target width passed to ``keras_ocr.tools.fit``.
        height: Target height passed to ``keras_ocr.tools.fit``.

    Yields:
        ``(image, text)``. ``image`` is the file converted from BGR to RGB and
        fitted to ``width`` x ``height``. ``text`` is the second
        underscore-separated field of the file name, lower-cased (e.g.
        ``182_slinking_71711.jpg`` -> ``'slinking'``).

    Raises:
        ValueError: if a full pass over ``filepaths`` yields no image at all
            (every file unreadable), instead of spinning forever.
    """
    filepaths = list(filepaths)
    if not filepaths:
        # Nothing to cycle over; yield nothing, as itertools.cycle([]) would.
        return
    while True:
        yielded_any = False
        for filepath in filepaths:
            image = cv2.imread(filepath)
            if image is None:
                # cv2.imread reports failure by returning None; skip the file
                # instead of crashing in cvtColor below.
                print(f'An error occurred reading: {filepath}')
                continue
            # os.path.basename also handles '/' separators on Windows, which
            # splitting on os.sep would not.
            text = os.path.basename(filepath).split('_')[1].lower()
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = keras_ocr.tools.fit(
                image,
                width=width,
                height=height,
                # Random padding colour; ``high`` is exclusive, so 256 makes
                # the full 0-255 range reachable.
                cval=np.random.randint(low=0, high=256, size=3).astype('uint8'))
            if augmenter is not None:
                image = augmenter.augment_image(image)
            yielded_any = True
            yield image, text
        if not yielded_any:
            raise ValueError('None of the provided image files could be read.')
        # Shuffle between passes. The loop iterates the list directly because
        # itertools.cycle would replay its own saved copy and ignore the shuffle.
        random.shuffle(filepaths)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Train the keras-ocr text recognizer on the MJSynth (Synth90k) dataset.')
    parser.add_argument('--model_id',
                        default='recognizer',
                        help='The name to use for saving model checkpoints.')
    parser.add_argument(
        '--data_path',
        default='.',
        help=('The path to the directory containing the dataset '
              '(it is downloaded and extracted there if missing).'))
    parser.add_argument(
        '--logs_path',
        default='./logs',
        help=('The path to where logs and checkpoints should be stored. '
              'If a checkpoint matching "model_id" is found, training will resume from that point.'))
    # Without type=int, a value passed on the command line stays a str and the
    # steps-per-epoch division below raises TypeError.
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='The training batch size to use.')
    parser.add_argument('--no-file-verification', dest='verify_files', action='store_false')
    parser.set_defaults(verify_files=True)
    args = parser.parse_args()

    weights_path = os.path.join(args.logs_path, args.model_id + '.h5')
    csv_path = os.path.join(args.logs_path, args.model_id + '.csv')
    download_extract_and_process_dataset(args.data_path)

    with tf.distribute.MirroredStrategy().scope():
        recognizer = keras_ocr.recognition.Recognizer(
            alphabet=string.digits + string.ascii_lowercase,
            height=31,
            width=200,
            stn=False,
            optimizer=tf.keras.optimizers.RMSprop(),
            weights=None)
    if os.path.isfile(weights_path):
        print('Loading saved weights and creating new version.')
        # Load the existing checkpoint *before* redirecting output to the new,
        # timestamped files (which do not exist yet), so earlier runs are kept.
        recognizer.model.load_weights(weights_path)
        # isoformat() would put ':' characters in the file names, which
        # Windows does not allow.
        dt_string = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
        weights_path = os.path.join(args.logs_path, args.model_id + '_' + dt_string + '.h5')
        csv_path = os.path.join(args.logs_path, args.model_id + '_' + dt_string + '.csv')

    augmenter = imgaug.augmenters.Sequential([
        imgaug.augmenters.Multiply((0.9, 1.1)),
        imgaug.augmenters.GammaContrast(gamma=(0.5, 3.0)),
        imgaug.augmenters.Invert(0.25, per_channel=0.5)
    ])
    os.makedirs(args.logs_path, exist_ok=True)

    training_filepaths, validation_filepaths = [
        get_filepaths(data_path=args.data_path, split=split) for split in ['train', 'val']
    ]
    if args.verify_files:
        # Explicit raise rather than assert, so the check still runs under
        # `python -O` and reports which files are missing.
        missing_filepaths = [
            filepath
            for filepath in tqdm.tqdm(training_filepaths + validation_filepaths,
                                      desc='Checking filepaths.')
            if not os.path.isfile(filepath)
        ]
        if missing_filepaths:
            raise FileNotFoundError(
                f'Some files appear to be missing ({len(missing_filepaths)} in total, '
                f'e.g. {missing_filepaths[:5]}).')

    # Image size the recognizer expects. It is used both to fit the images and
    # to declare the tf.data element shapes, so the two cannot drift apart.
    model_height = recognizer.model.input_shape[1]
    model_width = recognizer.model.input_shape[2]
    (training_image_generator, training_steps), (validation_image_generator, validation_steps) = [
        (get_image_generator(filepaths=split_filepaths,
                             augmenter=split_augmenter,
                             width=model_width,
                             height=model_height),
         math.ceil(len(split_filepaths) / args.batch_size))
        # Only the training split is augmented.
        for split_filepaths, split_augmenter in [(training_filepaths, augmenter),
                                                 (validation_filepaths, None)]
    ]

    # Element structure and dtypes must match what recognizer.get_batch_generator
    # yields; the image dimensions follow the model's input size.
    output_types = ((tf.float32, tf.int64, tf.float64, tf.int64), tf.float64)
    output_shapes = ((tf.TensorShape([None, model_height, model_width, 1]),
                      tf.TensorShape([None, recognizer.training_model.input_shape[1][1]]),
                      tf.TensorShape([None, 1]),
                      tf.TensorShape([None, 1])),
                     tf.TensorShape([None, 1]))
    training_generator, validation_generator = [
        tf.data.Dataset.from_generator(
            functools.partial(recognizer.get_batch_generator,
                              image_generator=image_generator,
                              batch_size=args.batch_size),
            output_types=output_types,
            output_shapes=output_shapes)
        for image_generator in [training_image_generator, validation_image_generator]
    ]

    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                         min_delta=0,
                                         patience=10,
                                         restore_best_weights=False),
        tf.keras.callbacks.ModelCheckpoint(weights_path, monitor='val_loss', save_best_only=True),
        tf.keras.callbacks.CSVLogger(csv_path)
    ]
    recognizer.training_model.fit(
        x=training_generator,
        steps_per_epoch=training_steps,
        validation_steps=validation_steps,
        validation_data=validation_generator,
        callbacks=callbacks,
        epochs=1000,
    )

"""This script is what was used to generate the

backgrounds.zip and fonts.zip files.

"""

# pylint: disable=invalid-name,redefined-outer-name

import json

import urllib.request

import urllib.parse

import concurrent

import shutil

import zipfile

import glob

import os

import numpy as np

import tqdm

import cv2

import keras_ocr

if __name__ == '__main__':
    # `import concurrent` alone does not load the `concurrent.futures`
    # submodule; import it explicitly rather than relying on another library
    # having imported it first.
    import concurrent.futures

    # ---------------------------------------------------------------- fonts
    fonts_commit = 'a0726002eab4639ee96056a38cd35f6188011a81'
    fonts_sha256 = 'e447d23d24a5bbe8488200a058cd5b75b2acde525421c2e74dbfb90ceafce7bf'
    fonts_source_zip_filepath = keras_ocr.tools.download_and_verify(
        url=f'https://github.com/google/fonts/archive/{fonts_commit}.zip',
        cache_dir='.',
        sha256=fonts_sha256)
    shutil.rmtree('fonts-raw', ignore_errors=True)
    with zipfile.ZipFile(fonts_source_zip_filepath) as zfile:
        zfile.extractall(path='fonts-raw')

    retained_fonts = []
    # Sets give O(1) duplicate checks while scanning the font repository.
    seen_sha256s = set()
    seen_basenames = set()
    # The blacklist includes fonts that, at least for the English alphabet, were found
    # to be illegible (e.g., thin fonts) or render in unexpected ways (e.g., mathematics
    # fonts).
    blacklist = [
        'AlmendraDisplay-Regular.ttf', 'RedactedScript-Bold.ttf', 'RedactedScript-Regular.ttf',
        'Sevillana-Regular.ttf', 'Mplus1p-Thin.ttf', 'Stalemate-Regular.ttf', 'jsMath-cmsy10.ttf',
        'Codystar-Regular.ttf', 'AdventPro-Thin.ttf', 'RoundedMplus1c-Thin.ttf',
        'EncodeSans-Thin.ttf', 'AlegreyaSans-ThinItalic.ttf', 'AlegreyaSans-Thin.ttf',
        'FiraSans-Thin.ttf', 'FiraSans-ThinItalic.ttf', 'WorkSans-Thin.ttf',
        'Tomorrow-ThinItalic.ttf', 'Tomorrow-Thin.ttf', 'Italianno-Regular.ttf',
        'IBMPlexSansCondensed-Thin.ttf', 'IBMPlexSansCondensed-ThinItalic.ttf',
        'Lato-ExtraLightItalic.ttf', 'LibreBarcode128Text-Regular.ttf',
        'LibreBarcode39-Regular.ttf', 'LibreBarcode39ExtendedText-Regular.ttf',
        'EncodeSansExpanded-ExtraLight.ttf', 'Exo-Thin.ttf', 'Exo-ThinItalic.ttf',
        'DrSugiyama-Regular.ttf', 'Taviraj-ThinItalic.ttf', 'SixCaps.ttf', 'IBMPlexSans-Thin.ttf',
        'IBMPlexSans-ThinItalic.ttf', 'AdobeBlank-Regular.ttf',
        'FiraSansExtraCondensed-ThinItalic.ttf', 'HeptaSlab[wght].ttf', 'Karla-Italic[wght].ttf',
        'Karla[wght].ttf', 'RalewayDots-Regular.ttf', 'FiraSansCondensed-ThinItalic.ttf',
        'jsMath-cmex10.ttf', 'LibreBarcode39Text-Regular.ttf', 'LibreBarcode39Extended-Regular.ttf',
        'EricaOne-Regular.ttf', 'ArimaMadurai-Thin.ttf', 'IBMPlexSerif-ExtraLight.ttf',
        'IBMPlexSerif-ExtraLightItalic.ttf', 'IBMPlexSerif-ThinItalic.ttf', 'IBMPlexSerif-Thin.ttf',
        'Exo2-Thin.ttf', 'Exo2-ThinItalic.ttf', 'BungeeOutline-Regular.ttf', 'Redacted-Regular.ttf',
        'JosefinSlab-ThinItalic.ttf', 'GothicA1-Thin.ttf', 'Kanit-ThinItalic.ttf', 'Kanit-Thin.ttf',
        'AlegreyaSansSC-ThinItalic.ttf', 'AlegreyaSansSC-Thin.ttf', 'Chathura-Thin.ttf',
        'Blinker-Thin.ttf', 'Italiana-Regular.ttf', 'Miama-Regular.ttf', 'Grenze-ThinItalic.ttf',
        'LeagueScript-Regular.ttf', 'BigShouldersDisplay-Thin.ttf', 'YanoneKaffeesatz[wght].ttf',
        'BungeeHairline-Regular.ttf', 'JosefinSans-Thin.ttf', 'JosefinSans-ThinItalic.ttf',
        'Monofett.ttf', 'Raleway-ThinItalic.ttf', 'Raleway-Thin.ttf', 'JosefinSansStd-Light.ttf',
        'LibreBarcode128-Regular.ttf'
    ]
    for filepath in tqdm.tqdm(sorted(glob.glob('fonts-raw/**/**/**/*.ttf')),
                              desc='Filtering fonts.'):
        sha256 = keras_ocr.tools.sha256sum(filepath)
        basename = os.path.basename(filepath)
        # We check the sha256 and filenames because some of the fonts
        # in the repository are duplicated (see TRIVIA.md).
        if sha256 in seen_sha256s or basename in seen_basenames or basename in blacklist:
            continue
        seen_sha256s.add(sha256)
        seen_basenames.add(basename)
        retained_fonts.append(filepath)

    retained_font_set = set(retained_fonts)
    # Sorted so that fonts.zip is written in a reproducible order.
    retained_font_families = sorted({filepath.split(os.sep)[-2] for filepath in retained_fonts})
    added = set()
    with zipfile.ZipFile(file='fonts.zip', mode='w') as zfile:
        for font_family in tqdm.tqdm(retained_font_families, desc='Saving ZIP file.'):
            # We want to keep all the metadata files plus
            # the retained font files. And we don't want
            # to add the same file twice.
            files = [
                input_filepath
                for input_filepath in sorted(glob.glob(f'fonts-raw/**/**/{font_family}/*'))
                if input_filepath not in added and
                (input_filepath in retained_font_set or os.path.splitext(input_filepath)[1] != '.ttf')
            ]
            added.update(files)
            for input_filepath in files:
                zfile.write(filename=input_filepath,
                            arcname=os.path.join(*input_filepath.split(os.sep)[-2:]))
    print('Finished saving fonts file.')

    # ---------------------------------------------------------- backgrounds
    # pylint: disable=line-too-long
    url = (
        'https://commons.wikimedia.org/w/api.php?action=query&generator=categorymembers&gcmtype=file&format=json'
        '&gcmtitle=Category:Featured_pictures_on_Wikimedia_Commons&prop=imageinfo&gcmlimit=50&iiprop=url&iiurlwidth=1024'
    )
    max_responses = 300
    responses = []
    continue_params = None
    for _ in tqdm.tqdm(range(max_responses)):
        current_url = url
        if continue_params is not None:
            # MediaWiki continuation: send back every key/value of the previous
            # response's `continue` object, URL-encoded.
            current_url += '&' + urllib.parse.urlencode(continue_params)
        with urllib.request.urlopen(url=current_url) as response:
            current = json.loads(response.read())
        responses.append(current)
        continue_params = current.get('continue')
        if continue_params is None:
            break
    print('Finished getting list of images.')

    # We want to avoid animated images as well as icon files.
    image_urls = []
    for response in responses:
        for page in response.get('query', {}).get('pages', {}).values():
            # Defensive: skip any page entry without image info or a thumbnail URL.
            imageinfo = page.get('imageinfo')
            if imageinfo and imageinfo[0].get('thumburl'):
                image_urls.append(imageinfo[0]['thumburl'])
    image_urls = [image_url for image_url in image_urls if image_url.lower().endswith('.jpg')]
    # Duplicate URLs would make two workers download to the same path; drop
    # them (keeping first occurrence order) rather than relying on an assert.
    unique_image_urls = list(dict.fromkeys(image_urls))
    if len(unique_image_urls) != len(image_urls):
        print(f'Dropped {len(image_urls) - len(unique_image_urls)} duplicate image URLs.')
    image_urls = unique_image_urls

    shutil.rmtree('backgrounds', ignore_errors=True)
    os.makedirs('backgrounds')
    failed_urls = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        future_to_url = {
            executor.submit(keras_ocr.tools.download_and_verify,
                            url=image_url,
                            cache_dir='./backgrounds',
                            verbose=False): image_url
            for image_url in image_urls
        }
        for future in tqdm.tqdm(concurrent.futures.as_completed(future_to_url),
                                total=len(future_to_url)):
            # Downloads are best-effort, but inspect each outcome so failures
            # are reported instead of silently discarded with the future.
            if future.exception() is not None:
                failed_urls.append(future_to_url[future])
    if failed_urls:
        print(f'Failed to download {len(failed_urls)} images; they will be skipped.')

    for filepath in glob.glob('backgrounds/*.JPG'):
        os.rename(filepath, filepath.lower())
    print('Filtering images by aspect ratio and maximum contiguous contour.')
    image_paths = np.array(sorted(glob.glob('backgrounds/*.jpg')))

    def compute_metrics(filepath):
        """Return ``(shape[0] / shape[1], largest uniform contour area)`` for an image."""
        image = keras_ocr.tools.read(filepath)
        aspect_ratio = image.shape[0] / image.shape[1]
        contour, _ = keras_ocr.tools.get_maximum_uniform_contour(image, fontsize=40)
        area = cv2.contourArea(contour) if contour is not None else 0
        return aspect_ratio, area

    # reshape keeps `metrics` two-dimensional even when no image was downloaded.
    metrics = np.array([compute_metrics(filepath) for filepath in tqdm.tqdm(image_paths)],
                       dtype=float).reshape(-1, 2)
    # Keep images with aspect ratio strictly between 2/3 and 3/2 and a
    # uniform-contour area above 1e6.
    filtered_paths = image_paths[(metrics[:, 0] < 3 / 2) & (metrics[:, 0] > 2 / 3) &
                                 (metrics[:, 1] > 1e6)]

    # Drop candidate backgrounds in which the detector already finds text.
    detector = keras_ocr.detection.Detector()
    paths_with_text = {
        filepath for filepath in tqdm.tqdm(filtered_paths) if len(
            detector.detect(
                images=[keras_ocr.tools.read_and_fit(filepath, width=640, height=640)])[0]) > 0
    }
    filtered_paths = np.array([path for path in filtered_paths if path not in paths_with_text])
    filtered_basenames = [os.path.basename(path) for path in filtered_paths]
    basename_to_url = {
        os.path.basename(urllib.parse.urlparse(image_url).path).lower(): image_url
        for image_url in image_urls
    }
    filtered_urls = [basename_to_url[basename.lower()] for basename in filtered_basenames]

    kept_paths = set(filtered_paths.tolist())
    removed_paths = [filepath for filepath in image_paths if filepath not in kept_paths]
    for filepath in removed_paths:
        os.remove(filepath)
    with open('backgrounds/urls.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(filtered_urls))
    with zipfile.ZipFile(file='backgrounds.zip', mode='w') as zfile:
        for filepath in tqdm.tqdm(filtered_paths.tolist() + ['backgrounds/urls.txt'],
                                  desc='Saving ZIP file.'):
            zfile.write(filename=filepath, arcname=os.path.basename(filepath.lower()))

创作不易 觉得有帮助请点赞关注收藏~~~



声明

本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。