tust-machine-learning/test.py

130 lines
2.9 KiB
Python
Executable File

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
import cv2, codecs, os, argparse
import numpy as np
from typing import Sequence
def crop_image(img):
    """Crop *img* to the bounding box of its positive content.

    A column (or row) counts as content when the sum of its values is
    greater than zero.  If no column (or row) qualifies, the full extent
    along that axis is kept, so an all-zero image is returned unchanged.

    Args:
        img: NumPy array whose first two axes are rows and columns.  Any
            further axes (e.g. colour channels) are summed into the profiles.

    Returns:
        The slice ``img[top:bottom + 1, left:right + 1]``.
    """
    height, width = img.shape[0], img.shape[1]

    # Collapse every axis except the one being profiled; for a 2-D image this
    # is exactly the per-column / per-row sum.
    column_axes = tuple(axis for axis in range(img.ndim) if axis != 1)
    row_axes = tuple(axis for axis in range(img.ndim) if axis != 0)
    filled_columns = np.flatnonzero(np.sum(img, axis=column_axes) > 0)
    filled_rows = np.flatnonzero(np.sum(img, axis=row_axes) > 0)

    # First / last filled column give the left / right edge of the glyph.
    if filled_columns.size:
        left, right = filled_columns[0], filled_columns[-1]
    else:
        left, right = 0, width - 1

    # First / last filled row give the top / bottom edge of the glyph.
    if filled_rows.size:
        top, bottom = filled_rows[0], filled_rows[-1]
    else:
        top, bottom = 0, height - 1

    return img[top: bottom + 1, left: right + 1]
def image_to_sample(img):
    """Turn an image into a 36-element density feature vector.

    Every pixel at ``(i, j)`` with a value greater than zero is counted in
    cell ``(i // 5) * 6 + j // 5``, i.e. a 6x6 grid of 5x5-pixel cells.  The
    counts are then divided by the total number of counted pixels.

    Args:
        img: 2-D array-like image.  The cell indexing only covers a 30x30
            image; larger inputs can index past the 36 cells.

    Returns:
        A 1x36 ``numpy.matrix`` of per-cell fractions.  An image without any
        positive pixel yields an all-zero vector instead of dividing by zero.
    """
    arr = np.asarray(img)
    cell_counts = [0] * 36
    for row, col in np.argwhere(arr > 0):
        cell_counts[(row // 5) * 6 + col // 5] += 1

    total = sum(cell_counts)
    if total == 0:
        # Blank image: there is nothing to normalise by.
        return np.asmatrix([0.0] * len(cell_counts))

    # np.asmatrix rather than np.mat: the latter was removed in NumPy 2.0.
    return np.asmatrix([count / total for count in cell_counts])
def normalize_image_size(img, width, height):
    """Return *img* resized by ``cv2.resize`` to ``(width, height)``."""
    target_size = (width, height)
    return cv2.resize(img, target_size)
# Linear classifier
class Classifier:
    """Two-class linear classifier deciding between ``id0`` and ``id1``.

    The decision value for a sample ``x`` is ``x * w.T + w0``; a strictly
    positive value selects ``id0``, anything else selects ``id1``.
    """

    def __init__(self, w, w0: float, id0, id1):
        """Store the weight row vector *w*, the bias *w0* and the two ids."""
        self.__w = w
        self.__w0 = w0
        self.id0 = id0
        self.id1 = id1

    def do(self, x):
        """Classify the sample *x*.

        Args:
            x: Sample with the same shape as the weight vector (a 1xN
                matrix when the weights are a 1xN matrix).

        Returns:
            ``id0`` if ``x * w.T + w0`` is greater than zero, else ``id1``.

        Raises:
            ValueError: If *x* and the weight vector differ in shape.
        """
        # Compare full shapes: len() of a 1xN matrix is always 1, so a
        # length check cannot detect a wrong number of features.
        if np.shape(x) != np.shape(self.__w):
            raise ValueError(
                "sample shape %s does not match weight shape %s"
                % (np.shape(x), np.shape(self.__w))
            )
        score = x * self.__w.T + self.__w0
        # .item() extracts the single element explicitly; float() on a
        # non-0-d array is deprecated in recent NumPy releases.
        return self.id0 if float(score.item()) > 0 else self.id1
def load_train(filename):
    """Load the trained pairwise classifiers from *filename*.

    Each non-blank line has the form ``id0,id1:w0:w1,w2,...``: the two class
    ids, the bias, and the comma-separated weights.

    Args:
        filename: Path of the text file to read.

    Returns:
        A list with one ``Classifier`` per non-blank line, in file order.
        The ids are kept as the strings read from the file.
    """
    loaded = []
    # The context manager closes the file even if a line fails to parse.
    with open(filename, "r") as fp:
        for line in fp:
            if not line.strip():
                # Tolerate blank lines, e.g. a trailing newline at end of file.
                continue
            elems = line.split(':')
            char_pair = elems[0].split(',')
            w0 = float(elems[1])
            # np.asmatrix rather than np.mat: the latter was removed in NumPy 2.0.
            w_star = np.asmatrix([[float(num) for num in elems[2].split(',')]])
            loaded.append(Classifier(w_star, w0, char_pair[0], char_pair[1]))
    return loaded
def get_label_dict(filename):
    """Read a UTF-8 label file into a ``{line_index: label}`` dictionary.

    Args:
        filename: Path of a text file holding one label per line.

    Returns:
        A dict mapping the 0-based line number to that line's text with the
        line ending removed.
    """
    # Built-in open() translates "\r\n" to "\n", so labels from files with
    # Windows line endings no longer keep a stray "\r"; the context manager
    # also closes the handle.
    with open(filename, 'r', encoding='utf-8') as f:
        return {index: line.rstrip('\n') for index, line in enumerate(f)}
def recognise(classfier_list, img):
    """Classify one image by majority vote over the pairwise classifiers.

    The image is inverse-thresholded at 127, cropped to its content, resized
    to 30x30 and converted to a feature vector; every classifier then votes
    for one of its two ids.

    Args:
        classfier_list: Iterable of objects whose ``do(sample)`` returns the
            id they vote for.
        img: Image as accepted by ``cv2.threshold``.

    Returns:
        ``(best_id, votes)`` for the id with the most votes; on a tie the id
        that received its first vote earliest wins.  ``("", 0)`` when
        *classfier_list* is empty.
    """
    _, binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)
    glyph = normalize_image_size(crop_image(binary), 30, 30)
    sample = image_to_sample(glyph)
    # NOTE(review): the original thresholded the resized image again *after*
    # extracting the features and never used the result; that dead call is
    # dropped.  Confirm against the training pipeline that it was not meant
    # to run before image_to_sample().

    # Vote with the classifiers passed in, not with a module-level global.
    votes = {}
    for classifier in classfier_list:
        winner = classifier.do(sample)
        votes[winner] = votes.get(winner, 0) + 1

    if not votes:
        return "", 0
    # max() returns the first maximal key in insertion order.
    best_id = max(votes, key=votes.get)
    return best_id, votes[best_id]
if __name__ == "__main__":
    # Command-line entry point: recognise every image in a directory and
    # print "<path> --> '<label>'" for each one.
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t", required=True,
        help="trained classifier file, one 'id0,id1:w0:w1,w2,...' per line")
    parser.add_argument(
        "-l", required=True,
        help="UTF-8 label file, one label per line")
    parser.add_argument("image_dir", help="directory with the images to recognise")
    args = parser.parse_args()

    classifiers = load_train(args.t)
    label = get_label_dict(args.l)
    path = args.image_dir
    for file in sorted(os.listdir(path)):
        file = os.path.join(path, file)
        img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread returns None for entries it cannot decode
            # (sub-directories, non-image files); skip them instead of
            # crashing the whole run.
            print("%s: not a readable image, skipped" % file, file=sys.stderr)
            continue
        class_id, _score = recognise(classifiers, img)
        print("%s --> '%s'" % (file, label[int(class_id)]))