Giter VIP home page Giter VIP logo

siamese-tf2's Introduction

Siamese:孪生神经网络在tf2(tensorflow2)当中的实现


目录

  1. 仓库更新 Top News
  2. 注意事项 Attention
  3. 所需环境 Environment
  4. 文件下载 Download
  5. 预测步骤 How2predict
  6. 训练步骤 How2train
  7. 参考资料 Reference

Top News

2022-04:进行了大幅度的更新,支持step、cos学习率下降法、支持adam、sgd优化器选择、支持学习率根据batch_size自适应调整。
BiliBili视频中的原仓库地址为:https://github.com/bubbliiiing/siamese-tf2/tree/bilibili

注意事项

训练Omniglot数据集和训练自己的数据集可以采用两种不同的格式。需要注意格式的摆放噢!

该仓库实现了孪生神经网络(Siamese network),该网络常常用于检测输入进来的两张图片的相似性。该仓库所使用的主干特征提取网络(backbone)为VGG16。

所需环境

tensorflow-gpu==2.2.0

文件下载

训练所需的vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5可在百度网盘中下载。
链接: https://pan.baidu.com/s/1NH3wcVr98vyJLYhYBTglvg 提取码: xyg2

Omniglot数据集下载地址为:
链接: https://pan.baidu.com/s/1pYp6vqiLLRFLn1tVeRk8ZQ 提取码: 5sa7

人脸数据集下载地址为(格式还需要简单修改一下才可以使用,请参考下方“训练自己相似性比较的模型”的格式进行修改):
链接: https://pan.baidu.com/s/1OvEFXTUZrvu4T5qSPkHOJw 提取码: aqhg

我一共会提供两个权重,分别是vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5和Omniglot_vgg.h5。
其中:
Omniglot_vgg.h5是Omniglot训练好的权重,可直接使用进行下面的预测步骤。
vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5是vgg的权重,可以用于训练其它的数据集。

预测步骤

a、使用预训练权重

  1. 下载完库后解压,在百度网盘下载Omniglot_vgg.h5,放入model_data,运行predict.py,依次输入
img/Angelic_01.png
img/Angelic_02.png

b、使用自己训练的权重

  1. 按照训练步骤训练。
  2. 在siamese.py文件里面,在如下部分修改model_path使其对应训练好的文件;model_path对应logs文件夹下面的权值文件
_defaults = {
    "model_path": 'model_data/Omniglot_vgg.h5',
    "input_shape" : (105, 105, 3),
}
  1. 运行predict.py,输入
img/Angelic_01.png
img/Angelic_02.png

训练步骤

可参考我的CSDN博客https://blog.csdn.net/weixin_44791964/article/details/107343394

a、训练Omniglot例子

Omniglot数据集中数据存放格式有三级:

- image_background
	- Alphabet_of_the_Magi
		- character01
			- 0709_01.png
			- 0709_02.png
			- ……
		- character02
		- character03
		- ……
	- Anglo-Saxon_Futhorc
	- ……

训练步骤为:

  1. 下载数据集,放在根目录下的dataset文件夹下。
  2. 运行train.py开始训练。

b、训练自己相似性比较的模型

如果大家想要训练自己的数据集,可以将数据集按照如下格式进行摆放。

- image_background
	- character01
		- 0709_01.png
		- 0709_02.png
		- ……
	- character02
	- character03
	- ……

相比Omniglot少了一级。每一个chapter里面放同类型的图片。
训练步骤为:

  1. 按上述格式放置数据集,放在根目录下的dataset文件夹下。
  2. 之后将train.py当中的train_own_data设置成True。
  3. 运行train.py开始训练。

Reference

https://github.com/tensorfreitas/Siamese-Networks-for-One-Shot-Learning

siamese-tf2's People

Contributors

bubbliiiing avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar

siamese-tf2's Issues

使用自己準備的資料集跑完一個epoch後就沒反應

您好:
請問我使用自己準備的梅爾頻譜圖資料集(圖片大小432x288)在訓練完一個epoch後就會沒有反應停在原地,可能是發生了甚麼狀況呢?

Number of devices: 1
Configurations:
----------------------------------------------------------------------
|                     keys |                                   values|
----------------------------------------------------------------------
|               model_path | model_data/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5|
|              input_shape |                               [105, 105]|
|               Init_Epoch |                                        0|
|                    Epoch |                                      100|
|               batch_size |                                       64|
|                  Init_lr |                                     0.01|
|                   Min_lr |                                   0.0001|
|           optimizer_type |                                      sgd|
|                 momentum |                                      0.9|
|            lr_decay_type |                                      cos|
|              save_period |                                       10|
|                 save_dir |                                     logs|
|              num_workers |                                        1|
|                num_train |                                      900|
|                  num_val |                                      100|
----------------------------------------------------------------------

[Warning] 使用sgd优化器时,建议将训练总步长设置到30000以上。
[Warning] 本次运行的总训练数据量为900,batch_size为64,共训练100个Epoch,计算出总训练步长为1400。
[Warning] 由于总训练步长为1400,小于建议总步长30000,建议设置总世代为2143。
C:\Users\little7\anaconda3\envs\tf2-gpu\lib\site-packages\keras\optimizers\optimizer_v2\adam.py:110: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  super(Adam, self).__init__(name, **kwargs)
C:\Users\little7\anaconda3\envs\tf2-gpu\lib\site-packages\keras\optimizers\optimizer_v2\gradient_descent.py:108: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  super(SGD, self).__init__(name, **kwargs)
Train on 900 samples, val on 100 samples, with batch size 64.

Epoch 1: LearningRateScheduler setting learning rate to 0.001.
Epoch 1/100
 6/14 [===========>..................] - ETA: 10s - loss: 0.7094 - binary_accuracy: 0.4935WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.4824s vs `on_train_batch_end` time: 0.6985s). Check your callbacks.
14/14 [==============================] - ETA: 0s - loss: 0.7034 - binary_accuracy: 0.4933

手動Ctrl+c終止後的訊息:

Traceback (most recent call last):
  File "c:\Users\little7\Downloads\Siamese-tf2-master\train.py", line 283, in <module>
    model.fit(
  File "C:\Users\little7\anaconda3\envs\tf2-gpu\lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
    return fn(*args, **kwargs)
  File "C:\Users\little7\anaconda3\envs\tf2-gpu\lib\site-packages\keras\engine\training.py", line 1432, in fit 
    self._eval_data_handler = data_adapter.get_data_handler(
  File "C:\Users\little7\anaconda3\envs\tf2-gpu\lib\site-packages\keras\engine\data_adapter.py", line 1401, in get_data_handler
    return DataHandler(*args, **kwargs)
  File "C:\Users\little7\anaconda3\envs\tf2-gpu\lib\site-packages\keras\engine\data_adapter.py", line 1151, in __init__
    self._adapter = adapter_cls(
  File "C:\Users\little7\anaconda3\envs\tf2-gpu\lib\site-packages\keras\engine\data_adapter.py", line 926, in __init__
    super(KerasSequenceAdapter, self).__init__(
  File "C:\Users\little7\anaconda3\envs\tf2-gpu\lib\site-packages\keras\engine\data_adapter.py", line 800, in __init__
    peek, x = self._peek_and_restore(x)
  File "C:\Users\little7\anaconda3\envs\tf2-gpu\lib\site-packages\keras\engine\data_adapter.py", line 937, in _peek_and_restore
    return x[0], x
  File "c:\Users\little7\Downloads\Siamese-tf2-master\utils\dataloader.py", line 44, in __getitem__
    selected_path   = self.train_lines[self.train_labels[:] == c]
KeyboardInterrupt

疑问

请问可不可以减少一下计算量呢,达到差不多的效果

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.