失眠网,内容丰富有趣,生活中的好帮手!
失眠网 > 深度学习框架tensorflow学习与应用6(防止过拟合dropout keep_prob =tf.placeholder(tf.float32))

深度学习框架tensorflow学习与应用6(防止过拟合dropout keep_prob =tf.placeholder(tf.float32))

时间:2021-11-17 15:46:56

相关推荐

深度学习框架tensorflow学习与应用6(防止过拟合dropout keep_prob =tf.placeholder(tf.float32))

import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data# In[3]:# 载入数据集mnist = input_data.read_data_sets("MNIST_data", one_hot=True)# 每个批次的大小batch_size = 100# 计算一共有多少个批次n_batch = mnist.train.num_examples // batch_size# 定义两个placeholderx = tf.placeholder(tf.float32, [None, 784])y = tf.placeholder(tf.float32, [None, 10])keep_prob =tf.placeholder(tf.float32)# 创建一个简单的神经网络W1 = tf.Variable(tf.truncated_normal([784, 2000], stddev=0.1))b1 = tf.Variable(tf.zeros([2000])+0.1)L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)L1_drop = tf.nn.dropout(L1, keep_prob)W2 = tf.Variable(tf.truncated_normal([2000, 2000], stddev=0.1))b2= tf.Variable(tf.zeros([2000])+ 0.1)L2= tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)L2_drop = tf.nn.dropout(L2, keep_prob)W3 = tf.Variable(tf.truncated_normal([2000, 1000], stddev=0.1))b3 = tf.Variable(tf.zeros([1000]) + 0.1)L3 = tf.nn.tanh(tf.matmul(L2_drop, W3) + b3)L3_drop = tf.nn.dropout(L3, keep_prob)W4 = tf.Variable(tf.truncated_normal([1000, 10], stddev=0.1))b4 = tf.Variable(tf.zeros([10]) + 0.1)prediction = tf.nn.softmax(tf.matmul(L3_drop, W4) + b2)# 交叉熵函数loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))# 使用梯度下降法train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)# 初始化变量init = tf.global_variables_initializer()# 结果存放在一个布尔型列表中correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) # argmax返回一维张量中最大的值所在的位置# 求准确率accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))with tf.Session() as sess:sess.run(init)for epoch in range(31):for batch in range(n_batch):batch_xs, batch_ys = mnist.train.next_batch(batch_size)sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc)+",Train Accuracy" + str(train_acc))

再迭代300次 3000次都一样的,基本上提高不了了

Iter 0,Testing Accuracy 0.9208,Training Accuracy 0.91416365

Iter 1,Testing Accuracy 0.9298,Training Accuracy 0.9265636

Iter 2,Testing Accuracy 0.9368,Training Accuracy 0.9341818

Iter 3,Testing Accuracy 0.9423,Training Accuracy 0.9419091

Iter 4,Testing Accuracy 0.9439,Training Accuracy 0.94298184

Iter 5,Testing Accuracy 0.9482,Training Accuracy 0.94812727

Iter 6,Testing Accuracy 0.9482,Training Accuracy 0.95054543

Iter 7,Testing Accuracy 0.9521,Training Accuracy 0.95414543

Iter 8,Testing Accuracy 0.9535,Training Accuracy 0.95536363

Iter 9,Testing Accuracy 0.9543,Training Accuracy 0.95714545

Iter 10,Testing Accuracy 0.9569,Training Accuracy 0.9596909

Iter 11,Testing Accuracy 0.956,Training Accuracy 0.96067274

Iter 12,Testing Accuracy 0.9588,Training Accuracy 0.96314543

Iter 13,Testing Accuracy 0.9607,Training Accuracy 0.9644182

Iter 14,Testing Accuracy 0.9607,Training Accuracy 0.9657818

Iter 15,Testing Accuracy 0.9613,Training Accuracy 0.9667818

Iter 16,Testing Accuracy 0.9636,Training Accuracy 0.96834546

Iter 17,Testing Accuracy 0.9636,Training Accuracy 0.9676909

Iter 18,Testing Accuracy 0.9635,Training Accuracy 0.96983635

Iter 19,Testing Accuracy 0.9654,Training Accuracy 0.97056365

Iter 20,Testing Accuracy 0.9667,Training Accuracy 0.97096366

Iter 21,Testing Accuracy 0.9666,Training Accuracy 0.9719091

Iter 22,Testing Accuracy 0.9677,Training Accuracy 0.97285455

Iter 23,Testing Accuracy 0.9666,Training Accuracy 0.9730545

Iter 24,Testing Accuracy 0.968,Training Accuracy 0.9742182

Iter 25,Testing Accuracy 0.9684,Training Accuracy 0.97458184

Iter 26,Testing Accuracy 0.9696,Training Accuracy 0.9758545

Iter 27,Testing Accuracy 0.9688,Training Accuracy 0.9760909

Iter 28,Testing Accuracy 0.9699,Training Accuracy 0.9764364

Iter 29,Testing Accuracy 0.971,Training Accuracy 0.9768909

Iter 30,Testing Accuracy 0.9705,Training Accuracy 0.97776365

使用dropout进行防止过拟合

但是用dropout收敛速度会变慢,到第30次才能达到97.

Iter 0,Testing Accuracy 0.915,Training Accuracy 0.91123635

Iter 1,Testing Accuracy 0.9298,Training Accuracy 0.92745453

Iter 2,Testing Accuracy 0.9366,Training Accuracy 0.9358364

Iter 3,Testing Accuracy 0.9415,Training Accuracy 0.9400909

Iter 4,Testing Accuracy 0.9444,Training Accuracy 0.94485456

Iter 5,Testing Accuracy 0.9461,Training Accuracy 0.9476909

Iter 6,Testing Accuracy 0.9498,Training Accuracy 0.9505091

Iter 7,Testing Accuracy 0.9528,Training Accuracy 0.9532727

Iter 8,Testing Accuracy 0.9531,Training Accuracy 0.9565091

Iter 9,Testing Accuracy 0.955,Training Accuracy 0.9577818

Iter 10,Testing Accuracy 0.9555,Training Accuracy 0.95861816

Iter 11,Testing Accuracy 0.9579,Training Accuracy 0.9612909

Iter 12,Testing Accuracy 0.9595,Training Accuracy 0.96385455

Iter 13,Testing Accuracy 0.9604,Training Accuracy 0.96523637

Iter 14,Testing Accuracy 0.9609,Training Accuracy 0.96592724

Iter 15,Testing Accuracy 0.9611,Training Accuracy 0.96647274

Iter 16,Testing Accuracy 0.9618,Training Accuracy 0.9676727

Iter 17,Testing Accuracy 0.9627,Training Accuracy 0.9693091

Iter 18,Testing Accuracy 0.9642,Training Accuracy 0.96983635

Iter 19,Testing Accuracy 0.9656,Training Accuracy 0.97049093

Iter 20,Testing Accuracy 0.9637,Training Accuracy 0.9705273

Iter 21,Testing Accuracy 0.9659,Training Accuracy 0.97261816

Iter 22,Testing Accuracy 0.9672,Training Accuracy 0.9735818

Iter 23,Testing Accuracy 0.9673,Training Accuracy 0.9735091

Iter 24,Testing Accuracy 0.9677,Training Accuracy 0.97409093

Iter 25,Testing Accuracy 0.9704,Training Accuracy 0.9747818

Iter 26,Testing Accuracy 0.9696,Training Accuracy 0.9760727

Iter 27,Testing Accuracy 0.9699,Training Accuracy 0.9764

Iter 28,Testing Accuracy 0.9687,Training Accuracy 0.9764364

Iter 29,Testing Accuracy 0.968,Training Accuracy 0.9772

Iter 30,Testing Accuracy 0.9707,Training Accuracy 0.97778183

如果使用tensorbored就可以看到曲线图。

为什么要用dropout?

因为从第二段代码我们可以看出,其在测试集的准确率和在训练集的准确率差不多

如果是googlenet和alexnet这么大的网络进行自己分类时,如果我们用goolenet训练500张照片分5类,那么就会过导致过拟合的

因为样本很少,但是我们dropout百分50的神经元那么就会得到好的效果。

用复杂的网络去训练小样本时我们就可以看到dropout的重要性了。

如果觉得《深度学习框架tensorflow学习与应用6(防止过拟合dropout keep_prob =tf.placeholder(tf.float32))》对你有帮助,请点赞、收藏,并留下你的观点哦!

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。