一、前文回顾
前面我们说了那么多理论,什么神经网络,激活函数,Softmax函数等等,那么深度学习到底能干什么?这一节我们就来尝试让小D识别手写数字,看看效果怎么样。这里用到的手写数字就是来自于MNIST手写数字数据集。
二、MNIST手写数字数据集
1.什么是MNIST手写数字数据集
・0~9的数字组成
・训练图像:6万张
・测试图像:1万张
・28*28=784像素的灰度图像(单通道),各像素值0~255之间
简而言之,就是有7万用手写的数字图片,每个图片大小为28*28。就像下面这样子↓

据说这都是不同的人手写的数字,说实话有些我都不认识。今天就用这个来验证小D,看它的识别率能达到多少。
2.MNIST手写数字数据集下载
http://yann.lecun.com/exdb/mnist/
上面就可以手动下载。(Python下载略)
三、重构神经网络
1.为什么要重构神经网络
上一节我们给小D设计的神经网络有两个输入,两个输出。但是今天我们要让小D识别0到9的十个数字,那么两个输出肯定不够,需要十个输出,每个输出代表0到9之间的一个数字,这样子小D经过神经网络的识别以后就会输出每个数字的概率,概率最高的就是小D认为正确的数字。同样的道理,因为每幅图像有28*28=784个像素,所以我们的输入神经元也需要784个,每个神经元接受图片里面的一个像素值。
2.重构什么样的神经网络
重构以后的小D大概长这样子。

现在小D有784个输入神经元(x1~x784),中间层只有一层50个神经元(s1~s50),输出层有10个神经元(y0~y9),分别代表0到9十个手写数字。每一层的每一个神经元都和下一层的每个神经元连接(专业术语叫全连接网络)。
※中间层为什么只有一层50个神经元?因为我懒的弄那么多,画图很麻烦。
四、识别数字
Python实现神经网络
import os
import pickle
import sys
import numpy as np
# sigmoid和softmax参考以前的内容
from common.functions import sigmoid, softmax
# 这个模块代码太长,想要的发消息
from dataset.mnist import load_mnist
sys.path.append(os.pardir)
def get_data():
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
return x_test, t_test
def init_network():
with open("params_two_layer.pkl", 'rb') as f:
network = pickle.load(f)
return network
def predict(network, x):
# 读取权重参数和偏置项参数
W1, W2 = network['W1'], network['W2']
b1, b2 = network['b1'], network['b2']
# 第一层网络 a1=W1*x+b1
a1 = np.dot(x, W1) + b1
# 用Sigmoid激活(非线性变换)
z1 = sigmoid(a1)
# 第二层网络 a2=W2*z1+b2
a2 = np.dot(z1, W2) + b2
# 输出层用·softmax函数分类
y = softmax(a2)
return y
if __name__ == "__main__":
# 读取测试数据和测试数据标签
x, t = get_data()
# 初始化神经网络(预先训练好的)
network = init_network()
# 批处理数,每次处理100张图片
batch_size = 100
accuracy_cnt = 0
for i in range(0, len(x), batch_size):
# 从测试数据里面每次拿出100张图片进行处理
x_batch = x[i:i + batch_size]
# 取出来的100张图片输出神经网络进行识别,会返回每个数字的概率(因为是每次识别100张图片,所以返回的是100*10的数组)
y_batch = predict(network, x_batch)
# 找出概率最大的那个数字
p = np.argmax(y_batch, axis=1)
# 识别出来的数字和测试标签比较,如果正确就把accuracy_cnt+正确个数(100张图片有97张是正确的话,accuracy_cnt累加97)
accuracy_cnt += np.sum(p == t[i:i + batch_size])
# 输出最终的识别率
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
总结一下识别手写数字的步骤
①加载MNIST手写数字数据集
因为我们只是测试识别率,所以只需要数据集里面的10000张测试图像,以及对应的标签。(标签就相当于正确答案,是为了比对我们的神经网络识别的结果是否正确)
②加载权重参数和偏置项参数
这里用的参数是我提前训练好的参数,我们只需要把它装进小D的大脑里面就OK
③输入图像进行识别
把每张图片都平展开,变成一维数组输入到小D的784个输入神经元,然后从小D十个输出神经元里面选择最大值(也就是概率最大),并和标签(正确答案)进行比对,如果正确我们的正确计数就累积加1,等1万张图片全部识别完以后,计算识别率。
注意:我们这里是每次输出100张图片,因为这样子速度更快。毕竟CPU的计算速度比起加载数据要快多了,如果一张一张加载的话,那么执行时间都用来加载图片了,导致程序运行很慢。
执行一下看看结果:Accuracy:0.947,正确率94.7%,还不错。
五、总结
给小D看了10000张不同的人手写的数字,它可以认出其中的94.7%,是一个很不错的成绩,但是小D是怎么学会认识数字的?重点来了,下一节进入神经网络最难也最有趣的地方:【教小D识数字】
Kommentarer