wusaifei / garbage_classify


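This project has added tutorials on classification, detection, face swapping and more, along with assorted hyperparameter-tuning techniques and tricks, detailed visual explanations of convolution structures, and explained attention-mechanism code. The goal of this garbage classification challenge is to build a deep-learning image classification model that accurately recognizes the category of garbage in images. Following the Shenzhen garbage sorting standard, the competition uses four categories: recyclables, kitchen waste, hazardous waste and other waste. The project includes a complete classification network, data augmentation, SVM and other classification enhancement strategies, and new classification techniques will continue to be added.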

Python 100.00%

garbage_classify's Introduction

Preface

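The classification workflow described here may seem cumbersome, because it follows the submission format of the Huawei Cloud competition. For a simpler version of the classification code, see: https://github.com/wusaifei/HWCC_image_classification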

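1. More image classification tricks (attention mechanisms in Keras, TensorFlow and PyTorch, etc.): Image classification competition tricks: the "Huawei Cloud Cup" 2019 AI Innovation Application Competition

2. If you are interested in object detection competitions, see my detailed write-up of tricks for them: Tricks for object detection competitions (updated with more code walkthroughs)

3. Object detection competition notes: Object detection competition notes

4. If you are interested in face swapping, click here: deepfakes/faceswap: a detailed, step-by-step face-swapping tutorial that is quick and easy to get started with!

5. Through a lot of trial and error in day-to-day hyperparameter tuning, I drew on other people's experience and accumulated some effective methods of my own, which I have gradually summarized. I hope it helps new algorithm engineers: For novice model trainers: the 2021 handbook of tuning for higher scores

6. A comprehensive introduction to the different types of convolution in deep learning: 2D convolution, 3D convolution, transposed convolution, dilated convolution, separable convolution, flattened convolution, grouped convolution, shuffled grouped convolution, pointwise grouped convolution, etc.

7. Essential classification knowledge: the differences and connections between the Softmax and Sigmoid functions; the effect of learning rate and batch size on model accuracy; precision, recall, F-measure, average precision and IoU; visualizing a CNN layer by layer in Python, using ResNet50 as an example

8. PyTorch notes: fine-tuning EfficientNet

9. Adding attention mechanisms in Keras and TensorFlow; adding attention (CBAM) in PyTorch, using ResNet as an example. Analysis: should you use ImageNet pre-training at all, and how do you load pre-trained parameters?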

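Additions

The code has been modified so that it runs locally.

Changes:

1. In save_model.py, train.py, eval.py and run.py, all moxing.framework.file calls were replaced with os.path and shutil.copy functions, because the moxing framework is not available in a standard Python environment.

2. Comment out the following lines in run.py: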

# FLAGS.tmp = os.path.join(FLAGS.local_data_root, 'tmp/')
# print(FLAGS.tmp)
# if not os.path.exists(FLAGS.tmp):
#     os.mkdir(FLAGS.tmp)
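Item 1 amounts to swapping cloud file operations for local ones. A minimal sketch of what such local stand-ins might look like; the helper names `file_exists` and `copy_file` are hypothetical and are not taken from the repo or from the moxing API.

```python
import os
import shutil


def file_exists(path):
    """Hypothetical local replacement for a cloud 'exists' check."""
    return os.path.exists(path)


def copy_file(src, dst):
    """Hypothetical local replacement for a cloud file copy.

    Creates the destination folder if needed, then copies the file with shutil.
    """
    parent = os.path.dirname(dst)
    if parent:
        os.makedirs(parent, exist_ok=True)
    shutil.copy(src, dst)
    return dst
```

Creating the parent folder first mirrors the common expectation that a copy to a new model directory just works, which plain `shutil.copy` does not do on its own.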

Added SVM, decision tree and random forest classifier sections at the end of the README.

Environment

python 3.6

tensorflow 1.13.1

keras 2.2.4

Newer versions may fail to run.

garbage_classify

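Competition background

Competition link: Huawei Cloud AI Competition · Garbage Classification Challenge Cup

Garbage sorting has become a hot social topic. On April 26, 2019, China's Ministry of Housing and Urban-Rural Development and other departments issued the Notice on Comprehensively Launching Household Waste Sorting in Cities at Prefecture Level and Above Nationwide, which launched household waste sorting in all cities at prefecture level and above starting in 2019. By the end of 2020, the 46 key cities were to have basically established household waste sorting and treatment systems.

Sorting by individuals at the point of disposal is the first step of waste treatment, but the stage that handles waste at scale is the treatment plant. At present, however, domestic treatment plants mostly sort waste by hand on assembly lines, which means harsh working conditions, high labor intensity and low efficiency. Faced with huge volumes of waste, manual sorting can recover only a very small share of recyclable and hazardous waste; most of the rest ends up in landfill, wasting resources and creating serious risks of environmental pollution.

The application and development of deep learning in computer vision opens up the possibility of sorting garbage automatically with AI: a camera photographs the garbage, the category of the garbage in the image is detected, and a machine can then sort it automatically, greatly improving sorting efficiency.

Huawei Cloud therefore organized this garbage classification competition for talent from all sectors of society, hoping to jointly explore AI techniques for garbage sorting and contribute to this national undertaking that benefits the country and its people.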

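Task description

This task follows the Shenzhen garbage sorting standard. The task is to classify garbage images: first identify the category of the item in the image (for example, a pop-top can or fruit peel), then look up the sorting rules and output which of the four classes the item belongs to: recyclables, kitchen waste, hazardous waste or other waste. Example model output: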

{
    "result": "可回收物/易拉罐"
}

(可回收物/易拉罐 means recyclables/pop-top can.)

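The 40 garbage classes

Labels keep the dataset's original Chinese strings in the form major category/item. Major categories: 其他垃圾 (other waste), 厨余垃圾 (kitchen waste), 可回收物 (recyclables), 有害垃圾 (hazardous waste). Items by index: 0 disposable fast-food box, 1 soiled plastic, 2 cigarette butt, 3 toothpick, 4 broken flowerpot or dishes, 5 bamboo chopsticks, 6 leftover food, 7 large bone, 8 fruit peel, 9 fruit pulp, 10 tea dregs, 11 vegetable leaves and roots, 12 eggshell, 13 fish bone, 14 power bank, 15 bag, 16 cosmetics bottle, 17 plastic toy, 18 plastic bowl or basin, 19 plastic clothes hanger, 20 paper courier bag, 21 plug and cable, 22 old clothes, 23 pop-top can, 24 pillow, 25 plush toy, 26 shampoo bottle, 27 drinking glass, 28 leather shoes, 29 cutting board, 30 cardboard box, 31 condiment bottle, 32 wine bottle, 33 metal food can, 34 pot, 35 cooking-oil container, 36 beverage bottle, 37 dry battery, 38 ointment, 39 expired medicine.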

{
    "0": "其他垃圾/一次性快餐盒",
    "1": "其他垃圾/污损塑料",
    "2": "其他垃圾/烟蒂",
    "3": "其他垃圾/牙签",
    "4": "其他垃圾/破碎花盆及碟碗",
    "5": "其他垃圾/竹筷",
    "6": "厨余垃圾/剩饭剩菜",
    "7": "厨余垃圾/大骨头",
    "8": "厨余垃圾/水果果皮",
    "9": "厨余垃圾/水果果肉",
    "10": "厨余垃圾/茶叶渣",
    "11": "厨余垃圾/菜叶菜根",
    "12": "厨余垃圾/蛋壳",
    "13": "厨余垃圾/鱼骨",
    "14": "可回收物/充电宝",
    "15": "可回收物/包",
    "16": "可回收物/化妆品瓶",
    "17": "可回收物/塑料玩具",
    "18": "可回收物/塑料碗盆",
    "19": "可回收物/塑料衣架",
    "20": "可回收物/快递纸袋",
    "21": "可回收物/插头电线",
    "22": "可回收物/旧衣服",
    "23": "可回收物/易拉罐",
    "24": "可回收物/枕头",
    "25": "可回收物/毛绒玩具",
    "26": "可回收物/洗发水瓶",
    "27": "可回收物/玻璃杯",
    "28": "可回收物/皮鞋",
    "29": "可回收物/砧板",
    "30": "可回收物/纸板箱",
    "31": "可回收物/调料瓶",
    "32": "可回收物/酒瓶",
    "33": "可回收物/金属食品罐",
    "34": "可回收物/锅",
    "35": "可回收物/食用油桶",
    "36": "可回收物/饮料瓶",
    "37": "有害垃圾/干电池",
    "38": "有害垃圾/软膏",
    "39": "有害垃圾/过期药物"
}
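A minimal sketch of how a predicted class index can be turned into the required output format. It assumes the table above is held as a dict keyed by the string index (the deploy script's `label_id_name_dict` is used that way in `_inference`); only two entries are copied here, and the score vector is made up.

```python
import json

import numpy as np

# Excerpt of the class table above; the full mapping has 40 keys, "0" to "39".
label_id_name_dict = {
    "23": "可回收物/易拉罐",
    "37": "有害垃圾/干电池",
}


def format_result(scores, id_to_name):
    """Turn a score vector into the {"result": "category/item"} output format."""
    pred_label = int(np.argmax(scores))
    return {"result": id_to_name[str(pred_label)]}


scores = np.zeros(40)
scores[23] = 0.9  # pretend the model is most confident about class 23
result = format_result(scores, label_id_name_dict)
print(json.dumps(result, ensure_ascii=False))  # {"result": "可回收物/易拉罐"}
major_category = result["result"].split("/")[0]  # "可回收物" (recyclables)
```

Because every label carries its major category as a prefix, the four-way answer falls out of the 40-way prediction with a single split.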

EfficientNet default parameters

    # (width_coefficient, depth_coefficient, resolution, dropout_rate)
    'efficientnet-b0': (1.0, 1.0, 224, 0.2),
    'efficientnet-b1': (1.0, 1.1, 240, 0.2),
    'efficientnet-b2': (1.1, 1.2, 260, 0.3),
    'efficientnet-b3': (1.2, 1.4, 300, 0.3),
    'efficientnet-b4': (1.4, 1.8, 380, 0.4),
    'efficientnet-b5': (1.6, 2.2, 456, 0.4),
    'efficientnet-b6': (1.8, 2.6, 528, 0.5),
    'efficientnet-b7': (2.0, 3.1, 600, 0.5),

EfficientNet paper: https://arxiv.org/pdf/1905.11946.pdf
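The width and depth coefficients in the table scale the channel counts and the number of repeated blocks of the B0 baseline. A sketch of the rounding rules, modeled on the helper functions in the official EfficientNet reference code; the divisor of 8 is that code's default and is an assumption here, not something stated in this repo.

```python
import math


def round_filters(filters, width_coefficient, depth_divisor=8):
    """Scale a channel count by width_coefficient and round to a multiple of depth_divisor."""
    filters *= width_coefficient
    new_filters = max(depth_divisor, int(filters + depth_divisor / 2) // depth_divisor * depth_divisor)
    if new_filters < 0.9 * filters:  # never round down by more than 10%
        new_filters += depth_divisor
    return int(new_filters)


def round_repeats(repeats, depth_coefficient):
    """Scale the number of times a block is repeated, rounding up."""
    return int(math.ceil(depth_coefficient * repeats))


# EfficientNet-B5 uses width 1.6 and depth 2.2 (see the table above)
print(round_filters(32, 1.6))    # stem channels: 48
print(round_filters(1280, 1.6))  # head channels: 2048
print(round_repeats(3, 2.2))     # a stage with 3 blocks in B0 gets 7
```

The resolution column is applied directly (456 for B5), which is why the deploy script sets `self.input_size = 456`.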

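Code walkthrough

Baseline improvements

1. Comparison experiments with several models: ResNet50, SE-ResNet50, Xception, SE-Xception, EfficientNet-B5

2. Group normalization (GroupNormalization) instead of batch normalization, to counter the accuracy drop caused by a small batch size: when the batch size is below 16, BN's error rate rises steadily. In train.py: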

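from Groupnormalization import GroupNormalization  # local module Groupnormalization.py

# Caveat: this only replaces an entry in the list returned by model.layers; it does not
# rewire the already-built graph, so the model still computes with its BN layers.
# To really use GN, build the network with GroupNormalization layers in place of BN.
for i, layer in enumerate(model.layers):
    if "batch_normalization" in layer.name:
        model.layers[i] = GroupNormalization(groups=32, axis=-1, epsilon=0.00001)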
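Group normalization computes its statistics per sample over groups of channels, so its output does not depend on the batch size, which is the reason it helps when batches are small. A NumPy sketch of the forward pass for channels-last input (matching `axis=-1` above); it illustrates the technique and is not the code in Groupnormalization.py.

```python
import numpy as np


def group_norm(x, groups=32, epsilon=1e-5, gamma=None, beta=None):
    """Group Normalization forward pass for an NHWC array."""
    n, h, w, c = x.shape
    assert c % groups == 0, "channels must be divisible by groups"
    xg = x.reshape(n, h, w, groups, c // groups)
    # mean/variance per sample and per group, over spatial positions and the channels in the group
    mean = xg.mean(axis=(1, 2, 4), keepdims=True)
    var = xg.var(axis=(1, 2, 4), keepdims=True)
    y = ((xg - mean) / np.sqrt(var + epsilon)).reshape(n, h, w, c)
    if gamma is not None:  # optional learned per-channel scale
        y = y * gamma
    if beta is not None:   # optional learned per-channel shift
        y = y + beta
    return y


x = np.random.default_rng(0).normal(3.0, 2.0, size=(4, 8, 8, 64))
y = group_norm(x, groups=32)
# Each sample is normalized on its own, so a batch of 1 gives the same result.
print(np.allclose(group_norm(x[:1], groups=32), y[:1]))  # True
```

Batch normalization, by contrast, averages over the batch axis too, so its statistics get noisy when only a handful of images fit on the GPU.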

3. Nadam optimizer

from keras.optimizers import Nadam

optimizer = Nadam(lr=FLAGS.learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08, schedule_decay=0.004)
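Nadam is Adam with Nesterov momentum; `schedule_decay` controls a slowly changing momentum schedule. A NumPy sketch of one update step, modeled on the Keras 2.x Nadam implementation as I understand it (not code from this repo), applied to a toy one-dimensional problem.

```python
import numpy as np


def nadam_step(p, g, state, lr=0.002, beta_1=0.9, beta_2=0.999, epsilon=1e-8, schedule_decay=0.004):
    """One Nadam update modeled on Keras 2.x; returns the new parameter value."""
    state['t'] += 1
    t = state['t']
    # momentum schedule: the effective momentum mu_t grows slowly with t
    mu_t = beta_1 * (1.0 - 0.5 * 0.96 ** (t * schedule_decay))
    mu_next = beta_1 * (1.0 - 0.5 * 0.96 ** ((t + 1) * schedule_decay))
    m_schedule_new = state['m_schedule'] * mu_t
    m_schedule_next = m_schedule_new * mu_next
    state['m_schedule'] = m_schedule_new

    g_prime = g / (1.0 - m_schedule_new)
    state['m'] = beta_1 * state['m'] + (1.0 - beta_1) * g
    m_prime = state['m'] / (1.0 - m_schedule_next)
    state['v'] = beta_2 * state['v'] + (1.0 - beta_2) * g ** 2
    v_prime = state['v'] / (1.0 - beta_2 ** t)
    # Nesterov-style mix of the current gradient and the look-ahead momentum
    m_bar = (1.0 - mu_t) * g_prime + mu_next * m_prime
    return p - lr * m_bar / (np.sqrt(v_prime) + epsilon)


# Minimize f(p) = (p - 3)^2 starting from p = 0
p = np.array(0.0)
state = {'t': 0, 'm': 0.0, 'v': 0.0, 'm_schedule': 1.0}
for _ in range(500):
    grad = 2.0 * (p - 3.0)
    p = nadam_step(p, grad, state, lr=0.1)
print(float(p))
```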

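4. Custom learning-rate schedule: cosine annealing (as in SGDR) with linear warm-up

from warmup_cosine_decay_scheduler import WarmUpCosineDecayScheduler

sample_count = len(train_sequence) * FLAGS.batch_size  # approximate number of training images
epochs = FLAGS.max_epochs
warmup_epoch = 5                                        # epochs of linear warm-up
batch_size = FLAGS.batch_size
learning_rate_base = FLAGS.learning_rate                # peak learning rate, reached after warm-up
total_steps = int(epochs * sample_count / batch_size)   # total number of batches over training
warmup_steps = int(warmup_epoch * sample_count / batch_size)

warm_up_lr = WarmUpCosineDecayScheduler(learning_rate_base=learning_rate_base,
                                        total_steps=total_steps,
                                        warmup_learning_rate=0,
                                        warmup_steps=warmup_steps,
                                        hold_base_rate_steps=0,
                                        )
# warm_up_lr is a Keras callback: pass it in the callbacks list of model.fit_generator(...)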
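The repo's WarmUpCosineDecayScheduler applies its schedule inside a Keras callback. As a standalone sketch of the same kind of schedule (linear warm-up, optional hold, then cosine decay to zero), here is a plain function; the parameter names mirror the constructor arguments above, but this is not the repo's code.

```python
import math


def warmup_cosine_lr(step, base_lr, total_steps, warmup_steps=0, warmup_lr=0.0, hold_steps=0):
    """Learning rate at a global step: linear warm-up, optional hold, then cosine decay to 0."""
    if step < warmup_steps:
        # linear ramp from warmup_lr to base_lr
        return warmup_lr + (base_lr - warmup_lr) * step / warmup_steps
    if step < warmup_steps + hold_steps:
        return base_lr
    if step >= total_steps:
        return 0.0
    progress = (step - warmup_steps - hold_steps) / (total_steps - warmup_steps - hold_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


# e.g. 1000 total steps, 100 warm-up steps, peak learning rate 0.01
for step in (0, 50, 100, 550, 999):
    print(step, round(warmup_cosine_lr(step, 0.01, 1000, warmup_steps=100), 6))
```

Warm-up avoids large, destabilizing updates while the freshly attached classifier head is still random; the cosine tail then lowers the rate smoothly instead of in the discrete drops of ReduceLROnPlateau.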

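5. Data augmentation: random horizontal flip, random vertical flip, random rotation by 90°, 180° or 270° with some probability, random crop (0-10%), etc. (see aug.py and data_gen.py for the full code)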

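import random
from keras.preprocessing.image import ImageDataGenerator

def img_aug(self, img):
    # img: a single H x W x C image array
    data_gen = ImageDataGenerator()
    dic_parameter = {'flip_horizontal': random.choice([True, False]),
                     'flip_vertical': random.choice([True, False]),
                     # 0 appears three times, so half of the images are not rotated
                     'theta': random.choice([0, 0, 0, 90, 180, 270])
                    }
    img_aug = data_gen.apply_transform(img, transform_parameters=dic_parameter)
    return img_aug


from imgaug import augmenters as iaa
import imgaug as ia

def augumentor(image):
    # wrap an augmenter so that it is applied to only 50% of the images
    sometimes = lambda aug: iaa.Sometimes(0.5, aug)
    seq = iaa.Sequential(
        [
            iaa.Fliplr(0.5),                                        # horizontal flip, p = 0.5
            iaa.Flipud(0.5),                                        # vertical flip, p = 0.5
            iaa.Affine(rotate=(-10, 10)),                           # small random rotation
            sometimes(iaa.Crop(percent=(0, 0.1), keep_size=True)),  # crop 0-10%, resize back
        ],
        random_order=True  # apply the augmenters in a random order
    )

    image_aug = seq.augment_image(image)

    return image_aug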
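For readers without Keras or imgaug at hand, a dependency-light NumPy sketch of the same kind of pipeline (random flips, 90° rotations, 0-10% crop). It approximates imgaug's crop-and-resize with nearest-neighbour sampling, assumes square images, and is not the code in aug.py.

```python
import numpy as np


def random_flip_rot_crop(img, rng, max_crop=0.1):
    """Random flips, a random multiple-of-90 rotation and a random crop for a square H x W x C image."""
    if rng.random() < 0.5:
        img = img[:, ::-1]           # horizontal flip
    if rng.random() < 0.5:
        img = img[::-1, :]           # vertical flip
    k = int(rng.choice([0, 0, 0, 1, 2, 3]))
    img = np.rot90(img, k)           # rotate by k * 90 degrees (half the time not at all)
    h, w = img.shape[:2]
    # crop up to max_crop of each side, then resize back with nearest-neighbour sampling
    top, bottom = (int(rng.uniform(0, max_crop) * h) for _ in range(2))
    left, right = (int(rng.uniform(0, max_crop) * w) for _ in range(2))
    crop = img[top:h - bottom, left:w - right]
    rows = np.arange(h) * crop.shape[0] // h
    cols = np.arange(w) * crop.shape[1] // w
    return crop[rows][:, cols]


rng = np.random.default_rng(42)
image = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
augmented = random_flip_rot_crop(image, rng)
print(augmented.shape)  # (64, 64, 3)
```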

6. Label smoothing (data_gen.py)

def smooth_labels(y, smooth_factor=0.1):
    # y: one-hot labels of shape (batch, num_classes) with a float dtype; modified in place
    assert len(y.shape) == 2
    if 0 <= smooth_factor <= 1:
        # label smoothing ref: https://www.robots.ox.ac.uk/~vgg/rg/papers/reinception.pdf
        y *= 1 - smooth_factor
        y += smooth_factor / y.shape[1]
    else:
        raise Exception(
            'Invalid label smoothing factor: ' + str(smooth_factor))
    return y
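The update in `smooth_labels` is the standard label-smoothing formula, with $K$ the number of classes and $\varepsilon$ the smoothing factor. Worked through for this task ($K = 40$, $\varepsilon = 0.1$), the smoothed targets still sum to 1:

```latex
\begin{aligned}
y_k^{\text{smooth}} &= (1 - \varepsilon)\, y_k + \frac{\varepsilon}{K} \\
\text{true class: } &(1 - 0.1)\cdot 1 + \tfrac{0.1}{40} = 0.9 + 0.0025 = 0.9025 \\
\text{each of the other 39 classes: } &(1 - 0.1)\cdot 0 + \tfrac{0.1}{40} = 0.0025 \\
\text{sum: } &0.9025 + 39 \times 0.0025 = 0.9025 + 0.0975 = 1
\end{aligned}
```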

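7. Data normalization: collect the paths of all images with Save_path.py, then compute the per-channel mean and standard deviation over all images with mean_std.py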

normMean = [0.56719673, 0.5293289, 0.48351972]
normStd = [0.20874391, 0.21455203, 0.22451781]


import numpy as np

img = np.asarray(img, np.float32) / 255.0   # scale pixels to [0, 1]
mean = np.array([0.56719673, 0.5293289, 0.48351972], np.float32)
std = np.array([0.20874391, 0.21455203, 0.22451781], np.float32)
img = (img - mean) / std                     # per channel, broadcast over the last axis
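mean_std.py itself is not reproduced here. A minimal sketch of computing per-channel mean and standard deviation in one streaming pass, assuming the images are already loaded as H x W x 3 uint8 arrays (loading them from the paths produced by Save_path.py is left out). It pools all pixels together; averaging per-image statistics would give slightly different numbers.

```python
import numpy as np


def channel_mean_std(images):
    """Per-channel mean and std over all pixels of all images, in [0, 1] units."""
    total = np.zeros(3, np.float64)
    total_sq = np.zeros(3, np.float64)
    count = 0
    for img in images:
        x = np.asarray(img, np.float64).reshape(-1, 3) / 255.0  # one row per pixel
        total += x.sum(axis=0)
        total_sq += (x ** 2).sum(axis=0)
        count += x.shape[0]
    mean = total / count
    std = np.sqrt(np.maximum(total_sq / count - mean ** 2, 0.0))  # Var = E[x^2] - E[x]^2
    return mean, std


black = np.zeros((2, 2, 3), np.uint8)
white = np.full((2, 2, 3), 255, np.uint8)
mean, std = channel_mean_std([black, white])
print(mean, std)  # [0.5 0.5 0.5] [0.5 0.5 0.5]
```

Streaming sums keep memory flat however many images there are; the resulting numbers are what get hard-coded as `mean` and `std` above.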

What each part of the code does

  • deploy_scripts: inference files, which need to be modified:

    1. self.input_size = 456  (the EfficientNet-B5 resolution from the table above; it must match the input size used in training)


    2. The _inference method:

    def _inference(self, data):
        """
        Model inference function.
        This is an inference example for ResNet; if you use another model, please modify this function.
        """
        img = data[self.input_key_1]
        img = img[np.newaxis, :, :, :]  # add a batch axis; with input_size = 456 the input shape is [?, 456, 456, 3]
        img = np.asarray(img, np.float32) / 255.0
        # normalize with the same per-channel mean and std as in training
        mean = [0.56719673, 0.5293289, 0.48351972]
        std = [0.20874391, 0.21455203, 0.22451781]
        img[..., 0] -= mean[0]
        img[..., 1] -= mean[1]
        img[..., 2] -= mean[2]
        img[..., 0] /= std[0]
        img[..., 1] /= std[1]
        img[..., 2] /= std[2]
        pred_score = self.sess.run([self.output_score], feed_dict={self.input_images: img})
        if pred_score is not None:
            pred_label = np.argmax(pred_score[0], axis=1)[0]  # index of the highest score
            result = {'result': self.label_id_name_dict[str(pred_label)]}
        else:
            result = {'result': 'predict score is None'}
        return result

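  • aug.py: image augmentation code (imgaug functions)

  • data_gen.py: data preprocessing, including data augmentation, label smoothing, and the train/val split

  • eval.py: evaluation

  • Groupnormalization.py: group normalization

  • mean_std.py: per-channel image mean and standard deviation

  • Network.py: ResNet50, SE-ResNet50, Xception, SE-Xception, EfficientNet-B5

  • run.py: main script

  • save_model.py: saves the model

  • Save_path.py: image path information

  • train.py: training, including the network, loss, optimizer, etc.

  • warmup_cosine_decay_scheduler.py: cosine-annealing learning rate with warm-up

  • pip-requirements.txt: the other required libraries; install them with: pip install -r pip-requirements.txt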

Usage

Preparation

Running

  • Run Save_path.py to get the image paths

  • Run mean_std.py to get the per-channel mean and standard deviation of the images

  • run.py: training

    python run.py --data_url='./garbage_classify/train_data' --train_url='./model_snapshots' --deploy_script_path='./deploy_scripts'

  • run.py: export the model as .pb

    python run.py --mode=save_pb --deploy_script_path='./deploy_scripts' --freeze_weights_file_path='./model_snapshots/weights_024_0.9470.h5' --num_classes=40

  • run.py: evaluation

    python run.py --mode=eval --eval_pb_path='./model_snapshots/model' --test_data_url='./garbage_classify/train_data'
    

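Adding an SVM classifier

Once the model is trained, run it over the training data and collect its outputs in arrays. Then train an SVC on these outputs, and finally use the trained classifier to classify the features extracted from the test data.

The code is as follows: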

import numpy as np
import torch
from sklearn.svm import SVC
from tqdm import tqdm

# model, device, trian_dataloaders_dict and test_loader are defined elsewhere in the training script.

# Step 1: run the trained CNN over the training data and collect its outputs as features
model.eval()
target_pre_con = []
target_con = []
with torch.no_grad():
    for i, data in tqdm(enumerate(trian_dataloaders_dict['all_data'])):
        input, target = data
        input, target = input.to(device), target.to(device)
        target_pre = model(input)  # network outputs, used as features

        target_pre_con.extend(target_pre.cpu().numpy())
        target_con.extend(target.cpu().numpy())

target_pre_con = np.asarray(target_pre_con)
target_con = np.asarray(target_con)

print(target_pre_con.shape)
print(target_con.shape)

# Step 2: fit an SVM on the extracted features
clf = SVC(kernel='rbf', gamma='auto')
clf.fit(target_pre_con, target_con)

# Step 3: extract test features with the CNN and classify them with the SVM
labels = []
for i, (input, filepath) in tqdm(enumerate(test_loader)):
    with torch.no_grad():
        image_var = input.to(device)
        y_pred = model(image_var)
        label = y_pred.cpu().data.numpy()
        label = clf.predict(label)
        labels.append(label)
labels = np.concatenate(labels)  # one predicted class index per test image
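To see the features-then-SVM idea in isolation, a self-contained scikit-learn sketch on synthetic data: `make_blobs` vectors stand in for the 40-dimensional CNN outputs, so the accuracy it prints says nothing about the real task.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Synthetic stand-in for CNN outputs: 40-dimensional vectors from 4 hypothetical classes
features, targets = make_blobs(n_samples=400, centers=4, n_features=40, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    features, targets, test_size=0.25, random_state=0)

# Same SVC settings as in the snippet above
clf = SVC(kernel='rbf', gamma='auto')
clf.fit(X_train, y_train)     # train on "training features"
pred = clf.predict(X_test)    # classify "test features"
accuracy = float(np.mean(pred == y_test))
print('held-out accuracy:', accuracy)
```

The SVM only replaces the last decision step; whether it beats the network's own softmax depends on the features, so it is worth comparing both on a validation split.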

Decision tree and random forest classifiers

Simply replace clf with DecisionTreeClassifier() or RandomForestClassifier():

from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

# use one of the two in place of SVC
clf = DecisionTreeClassifier()
# or
clf = RandomForestClassifier()
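Since the three classifiers are drop-in replacements for each other, they can be compared with the same cross-validation loop. A sketch on synthetic stand-in features (again `make_blobs`, not real CNN outputs); with real features you would pass the `target_pre_con` and `target_con` arrays built above instead.

```python
from sklearn.datasets import make_blobs
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

features, targets = make_blobs(n_samples=400, centers=4, n_features=40, random_state=0)

classifiers = {
    'svm': SVC(kernel='rbf', gamma='auto'),
    'decision_tree': DecisionTreeClassifier(random_state=0),
    'random_forest': RandomForestClassifier(n_estimators=100, random_state=0),
}
# mean accuracy over 5 folds for each classifier
scores = {name: cross_val_score(clf, features, targets, cv=5).mean()
          for name, clf in classifiers.items()}
for name, score in scores.items():
    print(f'{name}: {score:.3f}')
```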

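Experimental results

  • Network choice (accuracy, no data augmentation): ResNet50 0.689704, SE-ResNet50 0.83259, Xception 0.879003, EfficientNet-B5 0.924113

  • Data augmentation: accuracy improved from 0.924113 to 0.934721

  • Label smoothing, data normalization, and switching the learning-rate strategy from ReduceLROnPlateau to WarmUpCosineDecayScheduler: final accuracy around 95%

You can also add test-time augmentation (TTA) to the classification code; the code is in the tta_wrapper folder, which includes detailed explanations and test cases.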
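The tta_wrapper code is not shown here. A minimal sketch of the idea: predict on a few flipped or rotated copies of the image and average the class probabilities. `predict_fn` is a hypothetical callable mapping a batch of images to probabilities (a Keras model's `model.predict` fits that role), and `dummy_predict` exists only to make the example runnable.

```python
import numpy as np


def tta_predict(predict_fn, img):
    """Average class probabilities over the image and three flipped/rotated copies."""
    views = [
        img,
        img[:, ::-1],       # horizontal flip
        img[::-1, :],       # vertical flip
        np.rot90(img, 2),   # 180-degree rotation
    ]
    probs = predict_fn(np.stack(views))  # shape (4, num_classes)
    return probs.mean(axis=0)


def dummy_predict(batch):
    """Hypothetical stand-in for model.predict: fixed probabilities for each view."""
    return np.tile([0.1, 0.7, 0.2], (len(batch), 1))


image = np.zeros((8, 8, 3), np.float32)
print(tta_predict(dummy_predict, image))
```

The views should be transforms the model already saw in training (the flips and rotations in aug.py), otherwise averaging can hurt rather than help.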

Future work

  1. Add model ensembling (voting).

  2. Test-time augmentation.

  3. Data augmentation strategies such as Cutout, Mixup and CutMix.

  4. Label smoothing.
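Of the planned augmentation strategies, Mixup is easy to sketch. A minimal NumPy version, assuming one-hot (or smoothed) label vectors; `alpha = 0.2` is an assumed default, not a value taken from this project.

```python
import numpy as np


def mixup_batch(images, labels, alpha=0.2, rng=None):
    """Mixup: blend each sample (and its label vector) with a randomly chosen partner."""
    rng = np.random.default_rng() if rng is None else rng
    lam = rng.beta(alpha, alpha)           # mixing weight in (0, 1)
    perm = rng.permutation(len(images))    # partner index for every sample
    mixed_images = lam * images + (1.0 - lam) * images[perm]
    mixed_labels = lam * labels + (1.0 - lam) * labels[perm]
    return mixed_images, mixed_labels, lam


rng = np.random.default_rng(0)
x = rng.random((8, 32, 32, 3)).astype(np.float32)
y = np.eye(40, dtype=np.float32)[rng.integers(0, 40, size=8)]  # one-hot labels, 40 classes
mx, my, lam = mixup_batch(x, y, rng=rng)
print(mx.shape, my.shape, round(float(lam), 3))
```

Because labels are mixed with the same weight as pixels, the loss must accept soft targets, which categorical cross-entropy on probability vectors already does.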

garbage_classify's People

Contributors

saifeiwu, wusaifei


garbage_classify's Issues

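A small question about accuracy

May I ask: is your accuracy the number of correctly recognized images divided by the size of the whole dataset, or the average of the per-class accuracies over the 40 classes?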
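The two metrics the question contrasts can differ a lot on imbalanced data. A sketch computing both; which one the author reports is not stated in this thread.

```python
import numpy as np


def overall_accuracy(y_true, y_pred):
    """Correct predictions divided by the total number of images."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return float(np.mean(y_true == y_pred))


def mean_per_class_accuracy(y_true, y_pred):
    """Accuracy computed separately for each class, then averaged over classes."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    per_class = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return float(np.mean(per_class))


# 4 images of class 0 and 1 of class 1; the model always predicts class 0
y_true = [0, 0, 0, 0, 1]
y_pred = [0, 0, 0, 0, 0]
print(overall_accuracy(y_true, y_pred), mean_per_class_accuracy(y_true, y_pred))
```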

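What does the SVM do? Why add an SVM?

I'm a complete beginner and don't quite understand the reasoning here.
Does it mean using EfficientNet to extract features first, and then using an SVM instead of the original Linear layer for classification? How much does that improve accuracy?

Hello, I keep getting errors during setup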

Traceback (most recent call last):
File "/tmp/pycharm_project_329/run.py", line 166, in <module>
tf.compat.v1.app.run()
File "/home/jx/.local/lib/python3.5/site-packages/tensorflow/python/platform/app.py", line 40, in run
_run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
File "/home/jx/.local/lib/python3.5/site-packages/absl/app.py", line 299, in run
_run_main(main, args)
File "/home/jx/.local/lib/python3.5/site-packages/absl/app.py", line 250, in _run_main
sys.exit(main(argv))
File "/tmp/pycharm_project_329/run.py", line 122, in main
check_args(FLAGS)
File "/tmp/pycharm_project_329/run.py", line 61, in check_args
raise Exception('FLAGS.num_classes error, '
Exception: FLAGS.num_classes error, should be a positive number associated with your classification task

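Could you take a look for me? I'm using the environment on someone else's server: TensorFlow 1.14, NumPy 1.14.5, Python 3.5. I'm worried that creating a new environment would mean reinstalling packages or running into all sorts of missing-dependency trouble. Thanks.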

Error after training one epoch: TypeError: must be real number, not NoneType

Hi, I replaced model_fn with my own model, but after one epoch of training I get this error:
Traceback (most recent call last):
File "run.py", line 154, in <module>
tf.app.run()
File "/home/dell/anaconda3/envs/tensorflow-1.14-python-3.6/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 40, in run
_run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
File "/home/dell/anaconda3/envs/tensorflow-1.14-python-3.6/lib/python3.6/site-packages/absl/app.py", line 299, in run
_run_main(main, args)
File "/home/dell/anaconda3/envs/tensorflow-1.14-python-3.6/lib/python3.6/site-packages/absl/app.py", line 250, in _run_main
sys.exit(main(argv))
File "run.py", line 145, in main
train_model(FLAGS)
File "/data2/hlf/garbage_classify/train.py", line 277, in train_model
File "/home/dell/anaconda3/envs/tensorflow-1.14-python-3.6/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "/home/dell/anaconda3/envs/tensorflow-1.14-python-3.6/lib/python3.6/site-packages/keras/engine/training.py", line 1732, in fit_generator
initial_epoch=initial_epoch)
File "/home/dell/anaconda3/envs/tensorflow-1.14-python-3.6/lib/python3.6/site-packages/keras/engine/training_generator.py", line 260, in fit_generator
callbacks.on_epoch_end(epoch, epoch_logs)
File "/home/dell/anaconda3/envs/tensorflow-1.14-python-3.6/lib/python3.6/site-packages/keras/callbacks/callbacks.py", line 152, in on_epoch_end
callback.on_epoch_end(epoch, logs)
File "/data2/hlf/garbage_classify/train.py", line 210, in on_epoch_end
TypeError: must be real number, not NoneType
What could the problem be? I'd really appreciate some help!

def on_epoch_end(self, epoch, logs={}):
    self.losses.append(logs.get('loss'))
    self.val_losses.append(logs.get('val_loss'))
    save_path = os.path.join(self.FLAGS.train_local, 'weights_%03d_%.4f.h5' % (epoch, logs.get('val_acc')))  # line 210
    self.model.save_weights(save_path)
    if self.FLAGS.train_url.startswith('s3://'):
        save_url = os.path.join(self.FLAGS.train_url, 'weights_%03d_%.4f.h5' % (epoch, logs.get('val_acc')))
        shutil.copyfile(save_path, save_url)
    print('save weights file', save_path)
    if self.FLAGS.keep_weights_file_num > -1:
        weights_files = glob(os.path.join(self.FLAGS.train_local, '*.h5'))
        if len(weights_files) >= self.FLAGS.keep_weights_file_num:
            weights_files.sort(key=lambda file_name: os.stat(file_name).st_ctime, reverse=True)
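`'%.4f' % None` raises exactly "TypeError: must be real number, not NoneType", so `logs.get('val_acc')` returned None at line 210. One plausible cause (an assumption, not confirmed in the thread): the traceback runs through keras/callbacks/callbacks.py, i.e. Keras 2.3.x, which reports the metric under the name passed to compile, so `metrics=['accuracy']` yields `val_accuracy` rather than `val_acc`. A sketch of a tolerant lookup; `get_val_accuracy` and `weights_filename` are hypothetical helpers.

```python
def get_val_accuracy(logs):
    """Return validation accuracy from a Keras logs dict, trying both key spellings."""
    for key in ('val_acc', 'val_accuracy'):
        value = logs.get(key)
        if value is not None:
            return float(value)
    raise KeyError('no validation accuracy in logs; available keys: %s' % sorted(logs))


def weights_filename(epoch, logs):
    """Build the same 'weights_%03d_%.4f.h5' name used in on_epoch_end above."""
    return 'weights_%03d_%.4f.h5' % (epoch, get_val_accuracy(logs))


print(weights_filename(24, {'val_loss': 0.21, 'val_accuracy': 0.947}))
```

A missing validation set or a model compiled without an accuracy metric produces the same symptom; the KeyError message lists the keys that are actually present.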

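Question about data augmentation

Hi, isn't data augmentation supposed to enlarge the original dataset through operations such as rotation and translation? I see that aug.py and data_gen.py do perform augmentation, but they directly replace the original images as they are loaded, so the total size of the dataset stays the same. I'm a bit confused about this.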
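The behavior the question describes is on-the-fly augmentation: each epoch contains the same number of images, but every epoch draws fresh random transforms, so over many epochs the model sees many different variants without storing them. A minimal sketch of that pattern with random horizontal flips (an illustration, not the repo's data_gen.py).

```python
import numpy as np


def epoch_batches(images, batch_size, rng):
    """Yield one epoch of batches, each image randomly flipped on the fly.

    The number of images per epoch never changes; only their random variants do.
    """
    order = rng.permutation(len(images))
    for start in range(0, len(images), batch_size):
        batch = images[order[start:start + batch_size]]
        flip = rng.random(len(batch)) < 0.5
        # flip the width axis (NHWC layout) for the selected images only
        yield np.where(flip[:, None, None, None], batch[:, :, ::-1], batch)


rng = np.random.default_rng(0)
dataset = rng.random((10, 4, 4, 3))
for epoch in range(2):
    seen = sum(len(b) for b in epoch_batches(dataset, 4, rng))
    print('epoch', epoch, 'images seen:', seen)  # 10 every epoch
```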

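About mean_std.py

Could you explain: the result of mean_std.py isn't saved anywhere, so how does it affect the final recognition accuracy? I don't understand what mean_std.py does. I'd appreciate an answer, thank you!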

Error in the callbacks after the first epoch

After one epoch of training, the following error appears when the callbacks run:

Traceback (most recent call last):
File "run.py", line 166, in <module>
tf.app.run()
File "e:\Anaconda3\lib\site-packages\tensorflow\python\platform\app.py", line 125, in run
_sys.exit(main(argv))
File "run.py", line 157, in main
train_model(FLAGS)
File "D:\garbage_classification\garbage_classify-master\train.py", line 134, in train_model
shuffle=False
File "e:\Anaconda3\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "e:\Anaconda3\lib\site-packages\keras\engine\training.py", line 1732, in fit_generator
initial_epoch=initial_epoch)
File "e:\Anaconda3\lib\site-packages\keras\engine\training_generator.py", line 242, in fit_generator
workers=0)
File "e:\Anaconda3\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "e:\Anaconda3\lib\site-packages\keras\engine\training.py", line 1791, in evaluate_generator
verbose=verbose)
File "e:\Anaconda3\lib\site-packages\keras\engine\training_generator.py", line 341, in evaluate_generator
callbacks._call_begin_hook('test')
File "e:\Anaconda3\lib\site-packages\keras\callbacks\callbacks.py", line 105, in _call_begin_hook
self.on_test_begin()
File "e:\Anaconda3\lib\site-packages\keras\callbacks\callbacks.py", line 239, in on_test_begin
callback.on_test_begin(logs)
AttributeError: 'WarmUpCosineDecayScheduler' object has no attribute 'on_test_begin'

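My environment is the same as the author's. I ran the code on a laptop, so I commented out the GPU code.
The error says that WarmUpCosineDecayScheduler has no on_test_begin. I checked the author's rewritten callback class and it indeed doesn't have it, but running the class on its own works fine. I don't know how to fix it; any help is much appreciated!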
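The traceback runs through keras/callbacks/callbacks.py, a module layout from Keras 2.3, while the README pins Keras 2.2.4, so a version mismatch is one plausible cause (an assumption, not confirmed in the thread). The cleaner fix is to install the pinned version, or make the scheduler subclass `keras.callbacks.Callback` from the same keras package that runs training. As a stopgap, a sketch that attaches no-op versions of missing hooks to a callback object; the hook list and all names here are assumptions for illustration.

```python
# Hook names that a Keras 2.3-style CallbackList may call (assumed list)
KERAS_CALLBACK_HOOKS = (
    'on_train_begin', 'on_train_end',
    'on_epoch_begin', 'on_epoch_end',
    'on_batch_begin', 'on_batch_end',
    'on_train_batch_begin', 'on_train_batch_end',
    'on_test_begin', 'on_test_end',
    'on_test_batch_begin', 'on_test_batch_end',
    'on_predict_begin', 'on_predict_end',
    'on_predict_batch_begin', 'on_predict_batch_end',
)


def _no_op(*args, **kwargs):
    return None


def add_missing_hooks(callback):
    """Give a callback object no-op versions of any hooks it lacks; existing hooks are kept."""
    for name in KERAS_CALLBACK_HOOKS:
        if not hasattr(callback, name):
            setattr(callback, name, _no_op)
    return callback


class OldStyleScheduler:
    """Hypothetical callback that only implements an older hook."""
    def __init__(self):
        self.epochs_seen = 0

    def on_epoch_end(self, epoch, logs=None):
        self.epochs_seen += 1


cb = add_missing_hooks(OldStyleScheduler())
cb.on_test_begin(None)  # no longer raises AttributeError
cb.on_epoch_end(0, {})
print(cb.epochs_seen)   # 1
```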

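Official src_v2 baseline reaches 98.5% validation accuracy after 5 epochs?

Hello! Huawei Cloud did not release the test set after the competition ended. I trained the official baseline code (the code in src_v2.zip) on garbage_classify_v2.zip. The baseline was trained for only 5 epochs with a train:validation split of 0.75:0.25, and train_acc reached 99% and val_acc 98.5%. I don't know why; could you explain?

Error running run.py

I'm running run.py on Windows and get: OSError: Unable to open file (unable to open file: name = '/home/work/user-job-dir/src/efficientnet-b5_notop.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)
The file can't be found; please help.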

Problem running run.py

When running it I got the following:
File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
File "h5py/h5f.pyx", line 78, in h5py.h5f.open
OSError: Unable to open file (file signature not found)
How can I fix this?
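"file signature not found" means the file does not begin with the HDF5 magic bytes. A common reason (an assumption for this particular case) is an incomplete download, or an HTML error page saved under the .h5 name. A quick check, with the hypothetical helper `looks_like_hdf5`:

```python
HDF5_SIGNATURE = b'\x89HDF\r\n\x1a\n'  # first 8 bytes of a standard HDF5 file


def looks_like_hdf5(path):
    """Return True if the file starts with the HDF5 signature.

    Only offset 0 is checked; HDF5 also allows the signature at 512, 1024, ... bytes
    when a user block is present, which Keras weight files normally do not use.
    """
    with open(path, 'rb') as f:
        return f.read(8) == HDF5_SIGNATURE
```

If it returns False for the weights file, re-downloading the file is the usual remedy.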

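Running garbage classification locally

Hello, I ran your run.py locally with the efficentnet-b5_notop.h5 weights, input_size=228 (not 456, because of GPU limits), batch_size=8 and epoch=30. The loss decreases very slowly; training accuracy reaches 0.95, but test accuracy stays around 70% with little change. What could be the reason?

After training, running evaluation with python run.py --mode=eval --eval_pb_path='./model_snapshots/model' --test_data_url='./test_img' gives an error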

2020-03-10 10:37:34.998582: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1003] 1: N N
2020-03-10 10:37:34.998708: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10312 MB memory) -> physical GPU (device: 0, name: GeForce RTX 2080 Ti, pci bus id: 0000:01:00.0, compute capability: 7.5)
2020-03-10 10:37:34.998966: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:1 with 10312 MB memory) -> physical GPU (device: 1, name: GeForce RTX 2080 Ti, pci bus id: 0000:02:00.0, compute capability: 7.5)
Traceback (most recent call last):
File "/opt/shakey/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1092, in _run
subfeed, allow_tensor=True, allow_operation=False)
File "/opt/shakey/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3478, in as_graph_element
return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
File "/opt/shakey/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3557, in _as_graph_element_locked
raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("input_1:0", shape=(?, 456, 456, 3), dtype=float32) is not an element of this graph.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "run.py", line 166, in <module>
tf.app.run()
File "/opt/shakey/.local/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 125, in run
_sys.exit(main(argv))
File "run.py", line 163, in main
eval_model(FLAGS)
File "/opt/shakey/imageclass/garbage_classify-master/eval.py", line 215, in eval_model
test_single_model(FLAGS)
File "/opt/shakey/imageclass/garbage_classify-master/eval.py", line 159, in test_single_model
pred_score = sess1.run([output_score], feed_dict={input_images: img})
File "/opt/shakey/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 929, in run
run_metadata_ptr)
File "/opt/shakey/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1095, in _run
'Cannot interpret feed_dict key as Tensor: ' + e.args[0])
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("input_1:0", shape=(?, 456, 456, 3), dtype=float32) is not an element of this graph.

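About accuracy and loss values

There are 10 classes in total, with a few hundred photos each. Validation accuracy reaches 1.0 by the third epoch and stays there. Test accuracy rises over the first few epochs, reaching 0.69 at epoch 8, then oscillates and declines; at epoch 50 it is 0.3. The loss decreases slowly, and after epoch 10 it barely changes at all. What could be the cause?

Accuracy below expectations

Training on the dataset from the link, I get 0.825 accuracy after 30 epochs, not as high as reported. I set the batch size to 4 because of memory limits and changed nothing else. Could you help me find the cause?

The specified module could not be found

Hello, I installed the whole environment as you described, but training fails because a specified module cannot be found. What's going on?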

Traceback (most recent call last):
File "run.py", line 17, in <module>
import tensorflow as tf
File "C:\ProgramData\Anaconda3\envs\tf_1_13_1\lib\site-packages\tensorflow\__init__.py", line 24, in <module>
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
File "C:\ProgramData\Anaconda3\envs\tf_1_13_1\lib\site-packages\tensorflow\python\__init__.py", line 49, in <module>
from tensorflow.python import pywrap_tensorflow
File "C:\ProgramData\Anaconda3\envs\tf_1_13_1\lib\site-packages\tensorflow\python\pywrap_tensorflow.py", line 74, in <module>
raise ImportError(msg)
ImportError: Traceback (most recent call last):
File "C:\ProgramData\Anaconda3\envs\tf_1_13_1\lib\site-packages\tensorflow\python\pywrap_tensorflow.py", line 58, in <module>
from tensorflow.python.pywrap_tensorflow_internal import *
File "C:\ProgramData\Anaconda3\envs\tf_1_13_1\lib\site-packages\tensorflow\python\pywrap_tensorflow_internal.py", line 28, in <module>
_pywrap_tensorflow_internal = swig_import_helper()
File "C:\ProgramData\Anaconda3\envs\tf_1_13_1\lib\site-packages\tensorflow\python\pywrap_tensorflow_internal.py", line 24, in swig_import_helper
_mod = imp.load_module('_pywrap_tensorflow_internal', fp, pathname, description)
File "C:\ProgramData\Anaconda3\envs\tf_1_13_1\lib\imp.py", line 243, in load_module
return load_dynamic(name, filename, file)
File "C:\ProgramData\Anaconda3\envs\tf_1_13_1\lib\imp.py", line 343, in load_dynamic
return _load(spec)
ImportError: DLL load failed: The specified module could not be found.

Also, what does --train_url='./model_snapshots' in the training command mean? I couldn't find this folder.

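Multi-GPU parallel training problem

Below is the loss from multi-GPU parallel training:

[Image: loss curves from multi-GPU training]

In the first epoch the loss and the corresponding accuracy are normal, but things go wrong in the second epoch. Could there be a problem when the parameters are merged?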

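Questions about GPU configuration and data augmentation

Hi,
I found that on an Alibaba Cloud server with an NVIDIA P100 (16 GB memory), a batch size of 8 doesn't seem to work; the maximum is 6. I'm on Ubuntu. What OS and hardware did you use for training?

Also, is data augmentation used only during training, or can it also be applied during final model evaluation, like data normalization?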

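Loading efficientnet-b5_notop.h5

Hi, do you still have this .h5 weight file? I've been trying to download it for several days without success and need your help, thanks! QQ598770323

About the .h5 files

Hello, there are many .h5 files referenced in the code. Do all of them need to be downloaded? Thank you

Also, how should this part be handled? Does it need to be downloaded? Thank you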

BASE_WEIGHTS_PATH = (
    'https://github.com/Callidior/keras-applications/'
    'releases/download/efficientnet/')

WEIGHTS_HASHES = {
    'efficientnet-b0': ('163292582f1c6eaca8e7dc7b51b01c61'
                        '5b0dbc0039699b4dcd0b975cc21533dc',
                        'c1421ad80a9fc67c2cc4000f666aa507'
                        '89ce39eedb4e06d531b0c593890ccff3'),
    'efficientnet-b1': ('d0a71ddf51ef7a0ca425bab32b7fa7f1'
                        '6043ee598ecee73fc674d9560c8f09b0',
                        '75de265d03ac52fa74f2f510455ba64f'
                        '9c7c5fd96dc923cd4bfefa3d680c4b68'),
    'efficientnet-b2': ('bb5451507a6418a574534aa76a91b106'
                        'f6b605f3b5dde0b21055694319853086',
                        '433b60584fafba1ea3de07443b74cfd3'
                        '2ce004a012020b07ef69e22ba8669333'),

ModuleNotFoundError: No module named 'keras'

File "run.py", line 160, in main
from train import train_model
File "G:\garbage_classify-master\train.py", line 6, in
import keras.backend
ModuleNotFoundError: No module named 'keras'

Has anyone run into this? I have confirmed that keras is installed.
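A frequent cause of "No module named 'keras'" despite keras being installed (an assumption for this case) is that it was installed for a different Python interpreter than the one running run.py, for example outside the conda environment. A small diagnostic sketch; `module_report` is a hypothetical helper built on the standard library.

```python
import importlib.util
import sys


def module_report(name):
    """Show which interpreter is running and whether it can see the named module."""
    spec = importlib.util.find_spec(name)
    return {
        'python': sys.executable,
        'module': name,
        'found': spec is not None,
        'location': getattr(spec, 'origin', None),
    }


print(module_report('keras'))
```

Run it with the same `python` command used for run.py; if keras is not found there, `python -m pip install keras==2.2.4` installs it into exactly that interpreter.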

TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("input_1:0", shape=(?, 456, 456, 3), dtype=float32) is not an element of this graph.

(tfenv) D:\DeepLearnning\Code\RubbishSort\garbage_classify-master>python run.py --mode=eval --eval_pb_path=./model_snapshots/model --test_data_url=./datasets/garbage_classify/train_data --num_classes=40
Using TensorFlow backend.
2020-03-12 10:09:15.415268: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
2020-03-12 10:09:16.536963: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1433] Found device 0 with properties:
name: GeForce MX250 major: 6 minor: 1 memoryClockRate(GHz): 1.582
pciBusID: 0000:06:00.0
totalMemory: 2.00GiB freeMemory: 1.62GiB
2020-03-12 10:09:16.571051: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1512] Adding visible gpu devices: 0
2020-03-12 10:09:25.865212: I tensorflow/core/common_runtime/gpu/gpu_device.cc:984] Device interconnect StreamExecutor with strength 1 edge matrix:
2020-03-12 10:09:25.879351: I tensorflow/core/common_runtime/gpu/gpu_device.cc:990] 0
2020-03-12 10:09:25.885838: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1003] 0: N
2020-03-12 10:09:25.956745: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 1364 MB memory) -> physical GPU (device: 0, name: GeForce MX250, pci bus id: 0000:06:00.0, compute capability: 6.1)
WARNING:tensorflow:From D:\DeepLearnning\Code\RubbishSort\garbage_classify-master\eval.py:131: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
WARNING:tensorflow:From D:\DeepLearnning\anaconda3\envs\tfenv\lib\site-packages\tensorflow\python\training\saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
2020-03-12 10:09:38.781264: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1512] Adding visible gpu devices: 0
2020-03-12 10:09:38.846031: I tensorflow/core/common_runtime/gpu/gpu_device.cc:984] Device interconnect StreamExecutor with strength 1 edge matrix:
2020-03-12 10:09:38.874307: I tensorflow/core/common_runtime/gpu/gpu_device.cc:990] 0
2020-03-12 10:09:38.876455: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1003] 0: N
2020-03-12 10:09:38.889822: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 1364 MB memory) -> physical GPU (device: 0, name: GeForce MX250, pci bus id: 0000:06:00.0, compute capability: 6.1)
Traceback (most recent call last):
File "D:\DeepLearnning\anaconda3\envs\tfenv\lib\site-packages\tensorflow\python\client\session.py", line 1092, in _run
subfeed, allow_tensor=True, allow_operation=False)
File "D:\DeepLearnning\anaconda3\envs\tfenv\lib\site-packages\tensorflow\python\framework\ops.py", line 3478, in as_graph_element
return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
File "D:\DeepLearnning\anaconda3\envs\tfenv\lib\site-packages\tensorflow\python\framework\ops.py", line 3557, in _as_graph_element_locked
raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("input_1:0", shape=(?, 456, 456, 3), dtype=float32) is not an element of this graph.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "run.py", line 169, in <module>
tf.app.run()
File "D:\DeepLearnning\anaconda3\envs\tfenv\lib\site-packages\tensorflow\python\platform\app.py", line 125, in run
_sys.exit(main(argv))
File "run.py", line 165, in main
eval_model(FLAGS)
File "D:\DeepLearnning\Code\RubbishSort\garbage_classify-master\eval.py", line 215, in eval_model
test_single_model(FLAGS)
File "D:\DeepLearnning\Code\RubbishSort\garbage_classify-master\eval.py", line 159, in test_single_model
pred_score = sess1.run([output_score], feed_dict={input_images: img})
File "D:\DeepLearnning\anaconda3\envs\tfenv\lib\site-packages\tensorflow\python\client\session.py", line 929, in run
run_metadata_ptr)
File "D:\DeepLearnning\anaconda3\envs\tfenv\lib\site-packages\tensorflow\python\client\session.py", line 1095, in _run
'Cannot interpret feed_dict key as Tensor: ' + e.args[0])
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("input_1:0", shape=(?, 456, 456, 3), dtype=float32) is not an element of this graph.
I ran into this problem when evaluating with --mode=eval. Could the author please help?

Hello, when running save_pb I get an error whether or not the .h5 file exists. Why is that?

I looked at the freeze_weights_file_path code in run.py and don't understand what it does:
if not os.path.exists(FLAGS.freeze_weights_file_path):
    raise Exception('FLAGS.freeze_weights_file_path: %s is not exist' %
                    FLAGS.freeze_weights_file_path)
if os.path.isdir(FLAGS.freeze_weights_file_path):
    raise Exception('FLAGS.freeze_weights_file_path must be a file path, not a directory, %s ' %
                    FLAGS.freeze_weights_file_path)
if os.path.exists(FLAGS.freeze_weights_file_path.rsplit('/', 1)[0] + '/model'):
    raise Exception('a model directory is already exist in ' +
                    FLAGS.freeze_weights_file_path.rsplit('/', 1)[0]
                    + ', please rename or remove the model directory ')
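Read in order, the checks above fail in three cases: the weights file is missing, the path is a directory, or a `model` folder already exists next to the weights file (for example, left over from an earlier save_pb run). So an existing .h5 can still fail on the third check. A sketch of the same checks with specific exception types; it uses `os.path.dirname` instead of `rsplit('/', 1)[0]`, an assumed change that also handles Windows backslash paths. `check_freeze_weights_path` is hypothetical.

```python
import os


def check_freeze_weights_path(path):
    """Validate a weights file for export; return the 'model' directory that will be created."""
    if not os.path.exists(path):
        raise FileNotFoundError('weights file does not exist: %s' % path)
    if os.path.isdir(path):
        raise IsADirectoryError('expected a file, got a directory: %s' % path)
    model_dir = os.path.join(os.path.dirname(path), 'model')
    if os.path.exists(model_dir):
        raise FileExistsError('%s already exists; rename or remove it' % model_dir)
    return model_dir
```

In practice, deleting or renaming `model_snapshots/model` before re-running save_pb clears the third error.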

Hello, I followed your steps and ran python run.py --data_url='../datasets/garbage_classify/train_data' --train_url='./model_snapshots' --deploy_script_path='./deploy_scripts' after changing into the directory containing run.py

Hello, I followed your steps and ran python run.py --data_url='../datasets/garbage_classify/train_data' --train_url='./model_snapshots' --deploy_script_path='./deploy_scripts' after changing into the directory containing run.py, and got:
File "run.py", line 68, in check_args
raise Exception('FLAGS.data_url: %s is not exist' % FLAGS.data_url)
Exception: FLAGS.data_url: './garbage_classify/train_data' is not exist
But the training set is in that directory. How can I fix this?
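In the error message the path is printed together with its single quotes, which suggests (an assumption) that the shell passed the quotes through as part of the value; Windows cmd.exe does not strip single quotes, so there the path is better given in double quotes or unquoted. A defensive sketch of flag handling that tolerates this; `normalize_cli_path` is hypothetical and not part of run.py.

```python
import os


def normalize_cli_path(value):
    """Strip one pair of matching surrounding quotes, then normalize path separators."""
    if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
        value = value[1:-1]
    return os.path.normpath(value)


print(normalize_cli_path("'./garbage_classify/train_data'"))
```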

After training, running eval.py for testing fails with TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("input_1:0", shape=(?, 456, 456, 3), dtype=float32) is not an element of this graph.

Using TensorFlow backend.
2020-03-23 14:30:41.084890: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-03-23 14:30:44.398776: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x5638b17ca080 executing computations on platform CUDA. Devices:
2020-03-23 14:30:44.398829: I tensorflow/compiler/xla/service/service.cc:158] StreamExecutor device (0): GeForce RTX 2080 Ti, Compute Capability 7.5
2020-03-23 14:30:44.398840: I tensorflow/compiler/xla/service/service.cc:158] StreamExecutor device (1): GeForce RTX 2080 Ti, Compute Capability 7.5
2020-03-23 14:30:44.405480: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2499995000 Hz
2020-03-23 14:30:44.408096: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x5638b193fdc0 executing computations on platform Host. Devices:
2020-03-23 14:30:44.408128: I tensorflow/compiler/xla/service/service.cc:158] StreamExecutor device (0): ,
2020-03-23 14:30:44.408312: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1433] Found device 0 with properties:
name: GeForce RTX 2080 Ti major: 7 minor: 5 memoryClockRate(GHz): 1.545
pciBusID: 0000:01:00.0
totalMemory: 10.76GiB freeMemory: 10.60GiB
2020-03-23 14:30:44.408403: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1433] Found device 1 with properties:
name: GeForce RTX 2080 Ti major: 7 minor: 5 memoryClockRate(GHz): 1.545
pciBusID: 0000:02:00.0
totalMemory: 10.76GiB freeMemory: 10.60GiB
2020-03-23 14:30:44.408541: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1512] Adding visible gpu devices: 0, 1
2020-03-23 14:30:44.411191: I tensorflow/core/common_runtime/gpu/gpu_device.cc:984] Device interconnect StreamExecutor with strength 1 edge matrix:
2020-03-23 14:30:44.411219: I tensorflow/core/common_runtime/gpu/gpu_device.cc:990] 0 1
2020-03-23 14:30:44.411233: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1003] 0: N N
2020-03-23 14:30:44.411244: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1003] 1: N N
2020-03-23 14:30:44.411375: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10312 MB memory) -> physical GPU (device: 0, name: GeForce RTX 2080 Ti, pci bus id: 0000:01:00.0, compute capability: 7.5)
2020-03-23 14:30:44.411927: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:1 with 10312 MB memory) -> physical GPU (device: 1, name: GeForce RTX 2080 Ti, pci bus id: 0000:02:00.0, compute capability: 7.5)
WARNING:tensorflow:From /opt/shakey/imageclass/garbage_classify-master/eval.py:134: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
WARNING:tensorflow:From /opt/shakey/.local/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
2020-03-23 14:31:22.029904: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1512] Adding visible gpu devices: 0, 1
2020-03-23 14:31:22.030413: I tensorflow/core/common_runtime/gpu/gpu_device.cc:984] Device interconnect StreamExecutor with strength 1 edge matrix:
2020-03-23 14:31:22.030436: I tensorflow/core/common_runtime/gpu/gpu_device.cc:990] 0 1
2020-03-23 14:31:22.030449: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1003] 0: N N
2020-03-23 14:31:22.030459: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1003] 1: N N
2020-03-23 14:31:22.030628: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10312 MB memory) -> physical GPU (device: 0, name: GeForce RTX 2080 Ti, pci bus id: 0000:01:00.0, compute capability: 7.5)
2020-03-23 14:31:22.030908: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:1 with 10312 MB memory) -> physical GPU (device: 1, name: GeForce RTX 2080 Ti, pci bus id: 0000:02:00.0, compute capability: 7.5)
Traceback (most recent call last):
  File "/opt/shakey/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1092, in _run
    subfeed, allow_tensor=True, allow_operation=False)
  File "/opt/shakey/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3478, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/opt/shakey/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3557, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("input_1:0", shape=(?, 456, 456, 3), dtype=float32) is not an element of this graph.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "run.py", line 168, in <module>
    tf.app.run()
  File "/opt/shakey/.local/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 125, in run
    _sys.exit(main(argv))
  File "run.py", line 165, in main
    eval_model(FLAGS)
  File "/opt/shakey/imageclass/garbage_classify-master/eval.py", line 218, in eval_model
    test_single_model(FLAGS)
  File "/opt/shakey/imageclass/garbage_classify-master/eval.py", line 162, in test_single_model
    pred_score = sess1.run([output_score], feed_dict={input_images: img})
  File "/opt/shakey/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 929, in run
    run_metadata_ptr)
  File "/opt/shakey/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1095, in _run
    'Cannot interpret feed_dict key as Tensor: ' + e.args[0])
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("input_1:0", shape=(?, 456, 456, 3), dtype=float32) is not an element of this graph.

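Question about running run.py

I'm on a laptop with a GTX 1050 graphics card. When I run it, I get "ImportError: DLL load failed: The paging file is too small for this operation to complete." Does this mean my hardware can't handle it?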
