您的位置 首页 编程知识

TensorFlow MNIST手写数字分类:训练集准确率极低,问题出在哪儿?

TensorFlow MNIST手写数字分类:低训练集准确率的根本原因及修复方案 在使用TensorFlow进…

TensorFlow MNIST手写数字分类:训练集准确率极低,问题出在哪儿?

TensorFlow MNIST手写数字分类:低训练集准确率的根本原因及修复方案

在使用TensorFlow进行MNIST手写数字分类时,许多开发者会遇到一个难题:即使对训练集和测试集进行了像素归一化,训练集的准确率仍然异常低。本文将深入分析此问题,并结合代码示例提供有效的解决方案。

问题根源在于原始代码中y_p的计算方式。代码中y_pred = tf.nn.softmax(tf.matmul(X, W) + B)这一行,错误地将softmax函数应用于未经softmax处理的预测结果。tf.nn.softmax_cross_entropy_with_los函数期望输入的是未经softmax处理的预测值(logits)。原始代码却将softmax后的结果传入该函数,导致交叉熵损失函数计算错误,最终影响模型训练效果,导致训练集准确率极低。

为了解决这个问题,我们需要调整y_pred的计算方式以及准确率的计算方式。正确的做法是在损失函数计算后应用softmax函数获取最终的预测概率,而损失函数计算则使用未经softmax处理的预测值。

修正后的代码如下:

# 导入必要的库 import numpy as np import tensorflow as tf import matplotlib.pyplot as plt from tensorflow.examples.tutorials.mnist import input_data import os import pickle  # 超参数设置 numClasses = 10 inputSize = 784 batch_size = 64 learning_rate = 0.05  # 下载数据集 mnist = input_data.read_data_sets('original_data/', one_hot=True)  train_img = mnist.train.images train_label = mnist.train.labels test_img = mnist.test.images test_label = mnist.test.labels train_img /= 255.0 test_img /= 255.0   X = tf.compat.v1.placeholder(tf.float32, shape=[None, inputSize]) y = tf.compat.v1.placeholder(tf.float32, shape=[None, numClasses]) W = tf.Variable(tf.random_normal([inputSize, numClasses], stddev=0.1)) B = tf.Variable(tf.constant(0.1), [numClasses]) y_pred = tf.matmul(X, W) + B  # 修正:移除softmax  loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_pred)) + 0.01 * tf.nn.l2_loss(W) opt = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)  correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(tf.nn.softmax(y_pred), 1))  # 修正:在计算准确率时应用softmax accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  saver = tf.train.Saver() multiclass_parameters = {}  # 运行 with tf.Session() as sess:     sess.run(tf.global_variables_initializer())      # 开始训练     for epoch in range(20):         total_batch = int(len(train_img) / batch_size)          for batch in range(total_batch):             batch_input = train_img[batch * batch_size: (batch + 1) * batch_size]             batch_label = train_label[batch * batch_size: (batch + 1) * batch_size]              _, trainingLoss = sess.run([opt, loss], feed_dict={X: batch_input, y: batch_label})          train_acc = sess.run(accuracy, feed_dict={X: train_img, y: train_label})         print("Epoch %d Training Accuracy %g" % (epoch + 1, train_acc))
登录后复制

通过以上修正,tf.nn.softmax_cross_entropy_with_logits函数能够正确计算损失,模型得以有效训练,最终显著提升训练集准确率。 请注意,在计算最终预测概率时,仍然需要使用tf.nn.softmax函数。

以上就是TensorFlow MNIST手写数字分类:训练集准确率极低,问题出在哪儿?的详细内容,更多请关注php中文网其它相关文章!

本文来自网络,不代表四平甲倪网络网站制作专家立场,转载请注明出处:http://www.elephantgpt.cn/7889.html

作者: nijia

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注

联系我们

联系我们

18844404989

在线咨询: QQ交谈

邮箱: 641522856@qq.com

工作时间:周一至周五,9:00-17:30,节假日休息

关注微信
微信扫一扫关注我们

微信扫一扫关注我们

关注微博
返回顶部