
tensorflow-cnn-tutorial's Introduction

Tensorflow-CNN-Tutorial

This is a hands-on tutorial on building a convolutional neural network (CNN) for image classification with TensorFlow. The complete code can be downloaded from GitHub: https://github.com/hujunxianligong/Tensorflow-CNN-Tutorial. Instead of the MNIST dataset, the tutorial works with real image files, and the code covers saving and loading the model, so it should be a useful reference for anyone who wants to use TensorFlow in day-to-day projects.

Overview

  • The code uses a convolutional network to perform image classification.
  • After training, the model is saved under the model directory and can be used directly for online classification.
  • The same script covers both training and testing; set the train flag to True or False to switch between the two.

Data Preparation

The tutorial's images come from the CIFAR dataset: download_cifar.py extracts a subset of CIFAR from the copy bundled with Keras and converts it to jpg files (a rough sketch of this step is shown at the end of this section).

By default, 3 classes are selected from CIFAR, with 50 images per class:

  • 0 => airplane
  • 1 => car
  • 2 => bird

All images live in the data folder and are named label_id.jpg; for example, 2_111.jpg is an image of class 2 (bird) with id 111.
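
For reference, here is a rough sketch of what a script like download_cifar.py might look like. This is an assumption about its contents rather than the actual file: it pulls CIFAR-10 via Keras, keeps 50 images each of classes 0-2, and writes them out as jpg files.

from keras.datasets import cifar10
from PIL import Image
import os

os.makedirs("data", exist_ok=True)
(x_train, y_train), _ = cifar10.load_data()

# keep 50 images each of classes 0 (airplane), 1 (car), 2 (bird)
counts = {0: 0, 1: 0, 2: 0}
for i, (img, label) in enumerate(zip(x_train, y_train.flatten())):
    label = int(label)
    if label in counts and counts[label] < 50:
        Image.fromarray(img).save(os.path.join("data", "{}_{}.jpg".format(label, i)))
        counts[label] += 1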

Importing the Required Libraries

Besides TensorFlow, this tutorial also needs pillow (PIL); on Windows, PIL may need to be installed with conda.

If you use download_cifar.py to build the dataset yourself, you also need to install keras.

import os
# image reading library
from PIL import Image
# matrix operations library
import numpy as np
import tensorflow as tf

Configuration

A few variables are defined to make the program more flexible. Image files are stored in the data_dir folder, train indicates whether the current run is training or testing, and model_path specifies where the model is saved.

# data folder
data_dir = "data"
# train or test
train = True
# model file path
model_path = "model/image_model"

Reading the Data

Read the images from the image folder into numpy arrays. A few details to note:

  • Pixel values read by pillow are in the range 0-255 and need to be normalized.
  • While reading the image data and labels, also record each image's path, which makes debugging easier later on.
# Read images and labels from the folder into numpy arrays
# The label is encoded in the file name, e.g. 1_40.jpg means the image's label is 1
def read_data(data_dir):
    datas = []
    labels = []
    fpaths = []
    for fname in os.listdir(data_dir):
        fpath = os.path.join(data_dir, fname)
        fpaths.append(fpath)
        image = Image.open(fpath)
        data = np.array(image) / 255.0
        label = int(fname.split("_")[0])
        datas.append(data)
        labels.append(label)

    datas = np.array(datas)
    labels = np.array(labels)

    print("shape of datas: {}\tshape of labels: {}".format(datas.shape, labels.shape))
    return fpaths, datas, labels


fpaths, datas, labels = read_data(data_dir)

# count how many image classes there are
num_classes = len(set(labels))

Defining Placeholders (Containers)

Besides the image data and labels, the dropout rate also goes into a placeholder, because training and testing require different dropout rates.

# Define placeholders for the inputs and labels
datas_placeholder = tf.placeholder(tf.float32, [None, 32, 32, 3])
labels_placeholder = tf.placeholder(tf.int32, [None])

# Placeholder for the dropout rate: 0.25 during training, 0 during testing
dropout_placeholdr = tf.placeholder(tf.float32)

Defining the Convolutional Network (Convolution and Pooling)

# Convolutional layer: 20 filters, kernel size 5, ReLU activation
conv0 = tf.layers.conv2d(datas_placeholder, 20, 5, activation=tf.nn.relu)
# Max-pooling layer: 2x2 pooling window, 2x2 stride
pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2])

# Convolutional layer: 40 filters, kernel size 4, ReLU activation
conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu)
# Max-pooling layer: 2x2 pooling window, 2x2 stride
pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])
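
For reference, with the 32x32x3 inputs defined above and the default "valid" padding, the intermediate shapes work out as shown in the comments below; printing them is a quick sanity check.

# conv0: (?, 28, 28, 20)    pool0: (?, 14, 14, 20)
# conv1: (?, 11, 11, 40)    pool1: (?, 5, 5, 40)
print(conv0.shape, pool0.shape)
print(conv1.shape, pool1.shape)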

Defining the Fully Connected Layers

# Flatten the 3-D features into a 1-D vector
flatten = tf.layers.flatten(pool1)

# Fully connected layer, producing a feature vector of length 400
fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)

# Apply dropout to prevent overfitting
dropout_fc = tf.layers.dropout(fc, dropout_placeholdr)

# Output layer (logits, not yet activated)
logits = tf.layers.dense(dropout_fc, num_classes)

predicted_labels = tf.argmax(logits, 1)

Defining the Loss Function and Optimizer

One trick here: there is no need to pass the averaged loss to the optimizer; handing it the unaveraged per-example losses works just as well, since the resulting gradients only differ by a constant scale factor.

# Cross-entropy loss
losses = tf.nn.softmax_cross_entropy_with_logits(
    labels=tf.one_hot(labels_placeholder, num_classes),
    logits=logits
)
# Mean loss
mean_loss = tf.reduce_mean(losses)

# Define the optimizer and the loss it should minimize
optimizer = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(losses)

Defining the Model Saver/Loader

If you train for a long time on a larger dataset, it is a good idea to save the model periodically (see the sketch after the snippet below).

# Used to save and restore the model
saver = tf.train.Saver()
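
As a sketch of periodic saving inside a longer training loop (the step count and interval here are illustrative, not part of the tutorial's code):

for step in range(10000):
    sess.run(optimizer, feed_dict=train_feed_dict)
    # write a checkpoint every 1000 steps; global_step appends the step number to the file name
    if step % 1000 == 0:
        saver.save(sess, model_path, global_step=step)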

Entering the Training/Testing Execution Phase

with tf.Session() as sess:

The execution phase has two branches:

  • If train is True, run training. Training requires initializing the parameters with sess.run(tf.global_variables_initializer()); once training finishes, the parameters are saved with saver.save(sess, model_path).
  • If train is False, run testing, which requires loading the saved parameters with saver.restore(sess, model_path).

Running the Training Phase

if train:
    print("Training mode")
    # In training, initialize the parameters
    sess.run(tf.global_variables_initializer())
    # Define the inputs and labels that fill the placeholders; dropout is 0.25 during training
    train_feed_dict = {
        datas_placeholder: datas,
        labels_placeholder: labels,
        dropout_placeholdr: 0.25
    }
    for step in range(150):
        _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)
        if step % 10 == 0:
            print("step = {}\tmean loss = {}".format(step, mean_loss_val))
    saver.save(sess, model_path)
    print("Training finished, model saved to {}".format(model_path))

Running the Test Phase

else:
    print("Test mode")
    # In testing, load the saved parameters
    saver.restore(sess, model_path)
    print("Loaded model from {}".format(model_path))
    # Mapping from label id to label name
    label_name_dict = {
        0: "airplane",
        1: "car",
        2: "bird"
    }
    # Define the inputs and labels that fill the placeholders; dropout is 0 during testing
    test_feed_dict = {
        datas_placeholder: datas,
        labels_placeholder: labels,
        dropout_placeholdr: 0
    }
    predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)
    # True label vs. predicted label
    for fpath, real_label, predicted_label in zip(fpaths, labels, predicted_labels_val):
        # Convert label ids to label names
        real_label_name = label_name_dict[real_label]
        predicted_label_name = label_name_dict[predicted_label]
        print("{}\t{} => {}".format(fpath, real_label_name, predicted_label_name))

tensorflow-cnn-tutorial's People

Contributors

hujunxianligong

tensorflow-cnn-tutorial's Issues

Hello, I have a question for you

I used your code to train on my own dataset,
but I get the following error:

ValueError                                Traceback (most recent call last)
in <module>
     37
     38
---> 39 fpaths, datas, labels = read_data(data_dir)
     40

in read_data(data_dir)
     30         labels.append(label)
     31
---> 32     datas = np.array(datas)
     33     labels = np.array(labels)
     34

ValueError: could not broadcast input array from shape (300,300,3) into shape (300,300)

The images I load are in color, 300*300 pixels. What is causing this error, and how should I fix it?
Thanks in advance for your reply.
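
That particular error usually means the images do not all have the same shape, for example when some of them are single-channel (grayscale). Assuming that is the cause, one possible fix is to force every image to RGB inside read_data; for 300x300 inputs, the placeholder shape would also need to become [None, 300, 300, 3]. A minimal sketch:

# inside read_data(): convert each image to 3-channel RGB before turning it into an array
image = Image.open(fpath).convert("RGB")
data = np.array(image) / 255.0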

Using separate test images

I noticed that the images you test on are some of the same images used for training. Have you tried holding out part of the images, keeping them out of training, and testing only on that held-out set? When I do, the accuracy is not very high.
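
For reference, accuracy can be quantified from the tutorial's test branch by comparing the predicted and true label arrays; this is a sketch reusing predicted_labels_val and labels from the code above, not part of the original script:

import numpy as np

# fraction of images whose predicted label matches the true label
accuracy = np.mean(predicted_labels_val == labels)
print("accuracy = {:.2%}".format(accuracy))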

Hello, I have a question.

This is a line from the code:

datas_placeholder = tf.placeholder(tf.float32, [None, 32, 32, 3])

What do the parameters of this placeholder's shape mean?

Can you answer me?
Thank you!

where should I decide the batch-size

Hi there, good tutorial. Yet I have a question.
In the code you seem to load all the images at once in read_data() and feed them to feed_dict in one go. Do I understand that right, or is the batch size decided somewhere else?
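
As far as the code shows, yes: the whole dataset is fed as a single batch, and no batch size appears anywhere. A minimal mini-batching sketch, assuming you simply slice the arrays each step (the batch size of 32 is illustrative):

batch_size = 32
num_samples = len(datas)
for step in range(150):
    # draw a random mini-batch of indices each step
    idx = np.random.choice(num_samples, batch_size, replace=False)
    batch_feed_dict = {
        datas_placeholder: datas[idx],
        labels_placeholder: labels[idx],
        dropout_placeholdr: 0.25
    }
    _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=batch_feed_dict)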

A beginner's question

I want to ask about this classifier, which is divided into two parts: training and testing. I ran training first (once) and then ran testing, and found that the results all look like this:

data\2_82.jpg bird => car
data\2_86.jpg bird => car
data\2_88.jpg bird => car
data\2_90.jpg bird => car
data\2_92.jpg bird => car
data\2_95.jpg bird => car

Where did I go wrong?

Hello, I have a few questions for you

I downloaded this code and changed the dataset, but my dataset consists of grayscale images, so I know my images are not RGB. When I change the input channel to 1, the code still does not run. Can you give me some advice?
This is my error:
ValueError: Cannot feed value of shape (150, 32, 32) for Tensor 'Placeholder:0', which has shape '(?, 32, 32, 3)'
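
A likely explanation, assuming the placeholder was left at 3 channels: a grayscale image read by pillow becomes an array of shape (32, 32), with no channel axis at all. A sketch of one way to make the shapes line up while staying with single-channel inputs:

# single-channel placeholder
datas_placeholder = tf.placeholder(tf.float32, [None, 32, 32, 1])

# in read_data(), add the missing channel axis to each grayscale image
data = np.array(image) / 255.0          # shape (32, 32)
data = np.expand_dims(data, axis=-1)    # shape (32, 32, 1)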

Some questions about the test sets.

Thank you for all your assistance.

Is the format of the test set the same as the training set?

Is the naming scheme the same?

Are there any requirements on the test set, such as the number of pictures it contains?

Sorry to interrupt you

What are the requirements for the images in the training set, such as their resolution and number of channels?

Please accept my best thanks.
Enjoy yourself every day.

Sorry to interrupt you again.

Thank you for your kind cooperation.

My dataset includes two classes of pictures; each class contains 660 pictures.
If I use your code to train the model, what do I need to modify?

For example, do I need to change for step in range(150): to for step in range(1320):?

And is there anything else I need to change?

err, "a Variable name or other graph key that is missing")

Hello, I ran into an error when trying to test. After training I changed train to False, and this error occurred:

NotFoundError: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key beta1_power_1 not found in checkpoint
[[Node: save_2/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_2/Const_0_0, save_2/RestoreV2/tensor_names, save_2/RestoreV2/shape_and_slices)]]

Caused by op 'save_2/RestoreV2', defined at:
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\runpy.py", line 193, in _run_module_as_main
"main", mod_spec)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\runpy.py", line 85, in run_code
exec(code, run_globals)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\spyder_kernels\console_main
.py", line 11, in
start.main()
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\spyder_kernels\console\start.py", line 310, in main
kernel.start()
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelapp.py", line 505, in start
self.io_loop.start()
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tornado\platform\asyncio.py", line 132, in start
self.asyncio_loop.run_forever()
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\asyncio\base_events.py", line 427, in run_forever
self._run_once()
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\asyncio\base_events.py", line 1440, in _run_once
handle._run()
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\asyncio\events.py", line 145, in _run
self._callback(*self._args)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tornado\ioloop.py", line 758, in _run_callback
ret = callback()
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tornado\stack_context.py", line 300, in null_wrapper
return fn(*args, **kwargs)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tornado\gen.py", line 1233, in inner
self.run()
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tornado\gen.py", line 1147, in run
yielded = self.gen.send(value)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 357, in process_one
yield gen.maybe_future(dispatch(*args))
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tornado\gen.py", line 326, in wrapper
yielded = next(result)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 267, in dispatch_shell
yield gen.maybe_future(handler(stream, idents, msg))
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tornado\gen.py", line 326, in wrapper
yielded = next(result)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 534, in execute_request
user_expressions, allow_stdin,
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tornado\gen.py", line 326, in wrapper
yielded = next(result)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\ipkernel.py", line 294, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\zmqshell.py", line 536, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2843, in run_cell
raw_cell, store_history, silent, shell_futures)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2869, in _run_cell
return runner(coro)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\async_helpers.py", line 67, in _pseudo_sync_runner
coro.send(None)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 3044, in run_cell_async
interactivity=interactivity, compiler=compiler, result=result)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 3215, in run_ast_nodes
if (yield from self.run_code(code, result)):
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 3291, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "", line 1, in
runfile('D:/Graduate/Tensorflow-CNN-Tutorial-master/cnn.py', wdir='D:/Graduate/Tensorflow-CNN-Tutorial-master')
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 786, in runfile
execfile(filename, namespace)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 110, in execfile
exec(compile(f.read(), filename, 'exec'), namespace)
File "D:/Graduate/Tensorflow-CNN-Tutorial-master/cnn.py", line 92, in
saver = tf.train.Saver()
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py", line 1281, in init
self.build()
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py", line 1293, in build
self._build(self._filename, build_save=True, build_restore=True)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py", line 1330, in _build
build_save=build_save, build_restore=build_restore)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py", line 778, in _build_internal
restore_sequentially, reshape)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py", line 397, in _AddRestoreOps
restore_sequentially)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py", line 829, in bulk_restore
return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\gen_io_ops.py", line 1546, in restore_v2
shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\util\deprecation.py", line 454, in new_func
return func(*args, **kwargs)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 3155, in create_op
op_def=op_def)
File "C:\Users\acer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 1717, in init
self._traceback = tf_stack.extract_stack()

NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key beta1_power_1 not found in checkpoint
[[Node: save_2/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_2/Const_0_0, save_2/RestoreV2/tensor_names, save_2/RestoreV2/shape_and_slices)]]
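
The save_2 prefix in the node name suggests the graph was built more than once in the same interpreter session (for example by re-running the script in Spyder), so the Saver expects optimizer variables such as beta1_power_1 that were never written to the checkpoint. Assuming that is the cause, a common workaround is to reset the default graph before rebuilding the model and restoring:

import tensorflow as tf

# start from a clean graph, then re-run the model definition, tf.train.Saver(), and saver.restore()
tf.reset_default_graph()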
