Giter VIP home page Giter VIP logo

Comments (32)

zkawfanx avatar zkawfanx commented on August 21, 2024 2

是的,这也是选择DAVIS视频数据集合成数据来进行benchmark的原因,需要给video-based方法提供训练数据

from stablellve.

zkawfanx avatar zkawfanx commented on August 21, 2024 2
  1. 可以
  2. 是的
    建议先把他论文看一遍,他是在RAW数据上进行训练的,这样可以帮助你理解其代码实现

from stablellve.

zkawfanx avatar zkawfanx commented on August 21, 2024

SMOID论文中的网络为3D-UNet,在合成数据集上从头训练即可,本质上是比较不同论文中所使用网络在DAVIS合成数据集上的性能

from stablellve.

Ac2333 avatar Ac2333 commented on August 21, 2024

原来如此 所以基于视频的方法(MBLLVEN和SMOID)都是需要在我们自己的合成数据上重新训练后再进行测试对吗

from stablellve.

Ac2333 avatar Ac2333 commented on August 21, 2024

那前面几种基于图像的方法我们是否还需要利用DAVIS合成数据重新训练一个模型呢?还是直接使用官方已经训练好的模型在DAVIS数据集上测试就行?

from stablellve.

zkawfanx avatar zkawfanx commented on August 21, 2024

论文中对比方法模型都是用源代码重新训练的。

from stablellve.

Ac2333 avatar Ac2333 commented on August 21, 2024

论文中对比方法模型都是用源代码重新训练的。

明白了!十分感谢~

from stablellve.

Ac2333 avatar Ac2333 commented on August 21, 2024

SMOID论文中的网络为3D-UNet,在合成数据集上从头训练即可,本质上是比较不同论文中所使用网络在DAVIS合成数据集上的性能

我在合成数据集上重新训练的SMOID模型不能增强图像,请问您是直接用图片进行训练,还是将图片转成视频进行训练的?如果可以的话,能否提供您的SMOID的train的代码? 不胜感激

from stablellve.

zkawfanx avatar zkawfanx commented on August 21, 2024

你使用的是自己实现的代码还是怎样训练的,SMOID的有官方代码,我用的就是其公布的源码

from stablellve.

Ac2333 avatar Ac2333 commented on August 21, 2024

你使用的是自己实现的代码还是怎样训练的,SMOID的有官方代码,我用的就是其公布的源码

我使用的是SMOID的官方代码,您是直接使用自己的图片数据集进行训练的吗?还是将图片转成了视频然后再训练的

from stablellve.

zkawfanx avatar zkawfanx commented on August 21, 2024

为什么需要将图片转成视频,源码中训练也只是16帧或者32帧一起送入网络这样实现的

from stablellve.

Ac2333 avatar Ac2333 commented on August 21, 2024

为什么需要将图片转成视频,源码中训练也只是16帧或者32帧一起送入网络这样实现的
抱歉 我想起来了,因为官方发布的SMOID的代码进行训练时 首先需要从download_VGG_models.py 中下载VGG模型,但是我直接运行这个代码 无法下载这个VGG模型。所以我后来是找的一篇非官方代码进行训练。您是如何下载的官方提供的那个VGG模型的呢,它在Goole云盘的链接好像失效了

from stablellve.

zkawfanx avatar zkawfanx commented on August 21, 2024

我不记得SMOID的训练里有用到VGG loss呀,你再确认下是不是找错代码了,他们论文里也是说用L1 loss进行训练的。

from stablellve.

Ac2333 avatar Ac2333 commented on August 21, 2024

我不记得SMOID的训练里有用到VGG loss呀,你再确认下是不是找错代码了,他们论文里也是说用L1 loss进行训练的。

这是链接 https://github.com/cchen156/Seeing-Motion-in-the-Dark
进行训练的时候 官方提示先下载一个vgg model
image

from stablellve.

zkawfanx avatar zkawfanx commented on August 21, 2024

我对比的SMOID是ICCV19的《Learning to see moving objects in the dark》,《Seeing Motion in the Dark》训练只能用静态视频,否则随机选不同帧时他的loss约束是不成立的。

from stablellve.

Ac2333 avatar Ac2333 commented on August 21, 2024

所以是不能利用我们的DAVIS 合成数据集来训练 《Seeing Motion in the Dark》这个模型了吗

from stablellve.

zkawfanx avatar zkawfanx commented on August 21, 2024

原理上是不行的,它需要从同一段视频中随机抽取不同帧来做一致性约束,对于动态场景不适用

from stablellve.

Ac2333 avatar Ac2333 commented on August 21, 2024

原理上是不行的,它需要从同一段视频中随机抽取不同帧来做一致性约束,对于动态场景不适用

好的,谢谢您的解答! 对于《Learning to see moving objects in the dark》SMOID,在它的train代码 我有2个疑问:

  1. 下面这个 in_image的最后一个维度是4,对于我们的DAVIS合成数据集 应该是RGB 3通道吧? 所以是否应该把 in_image 最后的一个维度从4 改成 3呢?
    image

2.训练代码给的注释是16 bit,但我们的DAVIS数据集 应该是24位深度的图片, 所以训练代码中的 / 65535.0 应该得改成 / 255.0 对吗?
image

from stablellve.

Ac2333 avatar Ac2333 commented on August 21, 2024
  1. 可以
  2. 是的
    建议先把他论文看一遍,他是在RAW数据上进行训练的,这样可以帮助你理解其代码实现

您好,下面这个是我用DAVIS数据集训练的SMOID模型对弱光视频增强的效果。感觉我训练的SMOID还是有问题,请问您可以提供SMOID的训练代码吗?
image

from stablellve.

zkawfanx avatar zkawfanx commented on August 21, 2024

你先检查下图片输入输出时RGB/BGR的格式,这个看起来是通道反了

from stablellve.

Ac2333 avatar Ac2333 commented on August 21, 2024

你先检查下图片输入输出时RGB/BGR的格式,这个看起来是通道反了

我对network进行了修改,因为不修改的话无法训练起来,原来的network是用于处理raw格式的视频,下面这个是我修改后network的代码,我修改的正确吗?

# 3D-Conv-2D-Pool UNet
def network(input, depth=3, channel=32, prefix=''):
    depth = min(max(depth, 2), 4)

    conv1 = slim.conv3d(input, channel, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv1_1')
    conv1 = slim.conv3d(conv1, channel, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv1_2')
    pool1 = tf.expand_dims(slim.max_pool2d(conv1[0], [2, 2], padding='SAME'), axis=0)

    conv2 = slim.conv3d(pool1, channel * 2, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv2_1')
    conv2 = slim.conv3d(conv2, channel * 2, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv2_2')
    pool2 = tf.expand_dims(slim.max_pool2d(conv2[0], [2, 2], padding='SAME'), axis=0)

    conv3 = slim.conv3d(pool2, channel * 4, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv3_1')
    conv3 = slim.conv3d(conv3, channel * 4, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv3_2')
    if depth == 2:
        up8 = upsample_and_concat(conv3, conv2, channel * 2, channel * 4)
    else:
        pool3 = tf.expand_dims(slim.max_pool2d(conv3[0], [2, 2], padding='SAME'), axis=0)

        conv4 = slim.conv3d(pool3, channel * 8, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv4_1')
        conv4 = slim.conv3d(conv4, channel * 8, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv4_2')
        if depth == 3:
            up7 = upsample_and_concat(conv4, conv3, channel * 4, channel * 8)
        else:
            pool4 = tf.expand_dims(slim.max_pool2d(conv4[0], [2, 2], padding='SAME'), axis=0)

            conv5 = slim.conv3d(pool4, channel * 16, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv5_1')
            conv5 = slim.conv3d(conv5, channel * 16, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv5_2')

            up6 = upsample_and_concat(conv5, conv4, channel * 8, channel * 16)
            conv6 = slim.conv3d(up6, channel * 8, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv6_1')
            conv6 = slim.conv3d(conv6, channel * 8, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv6_2')

            up7 = upsample_and_concat(conv6, conv3, channel * 4, channel * 8)
        conv7 = slim.conv3d(up7, channel * 4, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv7_1')
        conv7 = slim.conv3d(conv7, channel * 4, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv7_2')

        up8 = upsample_and_concat(conv7, conv2, channel * 2, channel * 4)
    conv8 = slim.conv3d(up8, channel * 2, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv8_1')
    conv8 = slim.conv3d(conv8, channel * 2, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv8_2')

    up9 = upsample_and_concat(conv8, conv1, channel, channel * 2)
    conv9 = slim.conv3d(up9, channel, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv9_1')
    conv9 = slim.conv3d(conv9, channel, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv9_2')

    conv10 = slim.conv3d(conv9, 3, [1, 1, 1], rate=1, activation_fn=None, scope=prefix + 'g_conv10')

    # out = tf.concat([tf.expand_dims(tf.depth_to_space(conv10[:, i, :, :, :], 2), axis=1) for i in range(conv10.shape[1])], axis=1)
    # Directly use conv10 as final output for RGB images
    out = conv10
    if DEBUG:
        print('[DEBUG] (network.py) conv10.shape, out.shape:', conv10.shape, out.shape)
    return out

我将原来的slim.conv3d中的12改为了3 并且省略了tf.concat的步骤

from stablellve.

Ac2333 avatar Ac2333 commented on August 21, 2024

你先检查下图片输入输出时RGB/BGR的格式,这个看起来是通道反了

我是通过 img = cv2.imread( )来读取图片,所以是要加上 img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 这个过程来将 BGR的格式转化为RGB格式吗?

from stablellve.

zkawfanx avatar zkawfanx commented on August 21, 2024

你先检查下图片输入输出时RGB/BGR的格式,这个看起来是通道反了

我对network进行了修改,因为不修改的话无法训练起来,原来的network是用于处理raw格式的视频,下面这个是我修改后network的代码,我修改的正确吗?

# 3D-Conv-2D-Pool UNet
def network(input, depth=3, channel=32, prefix=''):
    depth = min(max(depth, 2), 4)

    conv1 = slim.conv3d(input, channel, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv1_1')
    conv1 = slim.conv3d(conv1, channel, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv1_2')
    pool1 = tf.expand_dims(slim.max_pool2d(conv1[0], [2, 2], padding='SAME'), axis=0)

    conv2 = slim.conv3d(pool1, channel * 2, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv2_1')
    conv2 = slim.conv3d(conv2, channel * 2, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv2_2')
    pool2 = tf.expand_dims(slim.max_pool2d(conv2[0], [2, 2], padding='SAME'), axis=0)

    conv3 = slim.conv3d(pool2, channel * 4, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv3_1')
    conv3 = slim.conv3d(conv3, channel * 4, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv3_2')
    if depth == 2:
        up8 = upsample_and_concat(conv3, conv2, channel * 2, channel * 4)
    else:
        pool3 = tf.expand_dims(slim.max_pool2d(conv3[0], [2, 2], padding='SAME'), axis=0)

        conv4 = slim.conv3d(pool3, channel * 8, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv4_1')
        conv4 = slim.conv3d(conv4, channel * 8, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv4_2')
        if depth == 3:
            up7 = upsample_and_concat(conv4, conv3, channel * 4, channel * 8)
        else:
            pool4 = tf.expand_dims(slim.max_pool2d(conv4[0], [2, 2], padding='SAME'), axis=0)

            conv5 = slim.conv3d(pool4, channel * 16, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv5_1')
            conv5 = slim.conv3d(conv5, channel * 16, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv5_2')

            up6 = upsample_and_concat(conv5, conv4, channel * 8, channel * 16)
            conv6 = slim.conv3d(up6, channel * 8, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv6_1')
            conv6 = slim.conv3d(conv6, channel * 8, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv6_2')

            up7 = upsample_and_concat(conv6, conv3, channel * 4, channel * 8)
        conv7 = slim.conv3d(up7, channel * 4, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv7_1')
        conv7 = slim.conv3d(conv7, channel * 4, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv7_2')

        up8 = upsample_and_concat(conv7, conv2, channel * 2, channel * 4)
    conv8 = slim.conv3d(up8, channel * 2, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv8_1')
    conv8 = slim.conv3d(conv8, channel * 2, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv8_2')

    up9 = upsample_and_concat(conv8, conv1, channel, channel * 2)
    conv9 = slim.conv3d(up9, channel, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv9_1')
    conv9 = slim.conv3d(conv9, channel, [3, 3, 3], rate=1, activation_fn=lrelu, scope=prefix + 'g_conv9_2')

    conv10 = slim.conv3d(conv9, 3, [1, 1, 1], rate=1, activation_fn=None, scope=prefix + 'g_conv10')

    # out = tf.concat([tf.expand_dims(tf.depth_to_space(conv10[:, i, :, :, :], 2), axis=1) for i in range(conv10.shape[1])], axis=1)
    # Directly use conv10 as final output for RGB images
    out = conv10
    if DEBUG:
        print('[DEBUG] (network.py) conv10.shape, out.shape:', conv10.shape, out.shape)
    return out

我将原来的slim.conv3d中的12改为了3 并且省略了tf.concat的步骤

我记得原代码中应该是首先将单通道的raw数据并预处理为4通道的格式,然后gt和输出都是4通道。

你直接用在DAVIS上的话输入输出都改为3通道应该就可以了

from stablellve.

zkawfanx avatar zkawfanx commented on August 21, 2024

你先检查下图片输入输出时RGB/BGR的格式,这个看起来是通道反了

我是通过 img = cv2.imread( )来读取图片,所以是要加上 img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 这个过程来将 BGR的格式转化为RGB格式吗?

你可以把你之前那种测试结果不正常的图片读入并且转为RGB再可视化看看,我感觉可能是通道反了

from stablellve.

Ac2333 avatar Ac2333 commented on August 21, 2024

你先检查下图片输入输出时RGB/BGR的格式,这个看起来是通道反了

我是通过 img = cv2.imread( )来读取图片,所以是要加上 img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 这个过程来将 BGR的格式转化为RGB格式吗?

你可以把你之前那种测试结果不正常的图片读入并且转为RGB再可视化看看,我感觉可能是通道反了

下面这个图是我新训练的结果,这次训练读取的图片我已经将 BGR的格式转化为RGB格式,但是下面这个测试结果看起来还是有点问题
image

from stablellve.

Ac2333 avatar Ac2333 commented on August 21, 2024

您能否提供一下SMOID的训练代码呢 十分感谢!!
下面这个是我的训练代码,我已经将输入输出都改为3通道,并且将14位深度修改为24位深度。不过与原代码不同的是,我修改了原代码的crop函数,不修改的话裁剪会出现0的情况。

import time, glob

import cv2
import numpy as np
import tensorflow as tf
import tf_slim as slim
from skvideo.io import vwrite

from network import network
from config import *
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "9"
tf.compat.v1.disable_eager_execution()


# get train IDs
with open(FILE_LIST) as f:
    text = f.readlines()
train_files = text

train_ids = [line.strip().split(' ')[0] for line in train_files]
gt_files = [line.strip().split(' ')[1] for line in train_files]
in_files = [line.strip().split(' ')[2] for line in train_files]

raw = np.load(in_files[0])
# F = raw.shape[0]
# H = raw.shape[1]
# W = raw.shape[2]

#def crop(raw, gt_raw, start_frame=0):
#    # inputs must be in a form of [batch_num, frame_num, height, width, channel_num]
#    tt = start_frame
#    xx = np.random.randint(0, W - CROP_WIDTH)
#    yy = np.random.randint(0, H - CROP_HEIGHT)
#
#    input_patch = raw[:, tt:tt + CROP_FRAME, yy:yy + CROP_HEIGHT, xx:xx + CROP_WIDTH, :]
#    gt_patch = gt_raw[:, tt:tt + CROP_FRAME, yy * 2:(yy + CROP_HEIGHT) * 2, xx * 2:(xx + CROP_WIDTH) * 2, :]
#    return input_patch, gt_patch

def crop(raw, gt_raw, start_frame=0):
    # inputs must be in a form of [batch_num, frame_num, height, width, channel_num]
    _, _, H, W, _ = raw.shape
    tt = start_frame
    xx = np.random.randint(0, W - CROP_WIDTH)
    yy = np.random.randint(0, H - CROP_HEIGHT)

    input_patch = raw[:, tt:tt + CROP_FRAME, yy:yy + CROP_HEIGHT, xx:xx + CROP_WIDTH, :]
    gt_patch = gt_raw[:, tt:tt + CROP_FRAME, yy:yy + CROP_HEIGHT, xx:xx + CROP_WIDTH, :]
    return input_patch, gt_patch

def flip(input_patch, gt_patch):
    # inputs must be in a form of [batch_num, frame_num, height, width, channel_num]
    if np.random.randint(2, size=1)[0] == 1:  # random flip
        input_patch = np.flip(input_patch, axis=1)
        gt_patch = np.flip(gt_patch, axis=1)
    if np.random.randint(2, size=1)[0] == 1:
        input_patch = np.flip(input_patch, axis=2)
        gt_patch = np.flip(gt_patch, axis=2)
    if np.random.randint(2, size=1)[0] == 1:
        input_patch = np.flip(input_patch, axis=3)
        gt_patch = np.flip(gt_patch, axis=3)
    if np.random.randint(2, size=1)[0] == 1:  # random transpose
        input_patch = np.transpose(input_patch, (0, 1, 3, 2, 4))
        gt_patch = np.transpose(gt_patch, (0, 1, 3, 2, 4))
    return input_patch, gt_patch

def count_images_in_folders(root_folder):
 
    video_frame_counts = {}
    for folder_name in os.listdir(root_folder):
        folder_path = os.path.join(root_folder, folder_name)
        if os.path.isdir(folder_path):
            image_count = len([name for name in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, name))])
            video_frame_counts[folder_name] = image_count
    return video_frame_counts

def main():
    sess = tf.compat.v1.Session()
    if tf.test.gpu_device_name():
        print(('Default GPU Device: {}'.format(tf.test.gpu_device_name())))
    else:
        print("Please install GPU version of TF")
    root_folder = '../data/gt'
    video_frame_counts = count_images_in_folders(root_folder)
    
    in_image = tf.compat.v1.placeholder(tf.float32, [None, CROP_FRAME, None, None, 3])
    gt_image = tf.compat.v1.placeholder(tf.float32, [None, CROP_FRAME, None, None, 3])
    out_image = network(in_image)
    if DEBUG:
        print('[DEBUG] out_image shape:', out_image.shape)


    G_loss = tf.reduce_mean(input_tensor=tf.abs(out_image - gt_image))
    v_loss = tf.Variable(0.0)

    # tensorboard summary
    tf.compat.v1.summary.scalar('loss', v_loss)
    # tf.summary.scalar('validation loss', v_loss)
    summary_op = tf.compat.v1.summary.merge_all()
    writer = tf.compat.v1.summary.FileWriter(os.path.join(LOGS_DIR, TRAIN_LOG_DIR), graph=tf.compat.v1.get_default_graph())
    writer_val = tf.compat.v1.summary.FileWriter(os.path.join(LOGS_DIR, VAL_LOG_DIR), graph=tf.compat.v1.get_default_graph())

    t_vars = tf.compat.v1.trainable_variables()
    lr = tf.compat.v1.placeholder(tf.float32)
    G_opt = tf.compat.v1.train.AdamOptimizer(learning_rate=lr).minimize(G_loss)

    saver = tf.compat.v1.train.Saver()
    sess.run(tf.compat.v1.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
    if ckpt:
        print(('loaded ' + ckpt.model_checkpoint_path))
        saver.restore(sess, ckpt.model_checkpoint_path)
    if not os.path.isdir(CHECKPOINT_DIR):
        os.makedirs(CHECKPOINT_DIR)

    # Raw data takes long time to load. Keep them in memory after loaded.
    gt_images = [None] * len(train_ids)
    input_images = [None] * len(train_ids)

    g_loss = np.zeros((len(train_ids), 1))

    lastepoch = 0
    if not os.path.isdir(RESULT_DIR):
        os.makedirs(RESULT_DIR)
    else:
        all_items = glob.glob(os.path.join(RESULT_DIR, '*'))
        all_folders = [os.path.basename(d) for d in all_items if os.path.isdir(d) and os.path.basename(d).isdigit()]
        for folder in all_folders:
            lastepoch = np.maximum(lastepoch, int(folder))

    learning_rate = INIT_LR

    np.random.seed(ord('c') + 137)
    count = 0
    for epoch in range(lastepoch + 1, MAX_EPOCH + 1):
        if epoch % SAVE_FREQ == 0:
            save_results = True
            if not os.path.isdir(RESULT_DIR + '%04d' % epoch):
                os.makedirs(RESULT_DIR + '%04d' % epoch)
        else:
            save_results = False
        cnt = 0
        if epoch > DECAY_EPOCH:
            learning_rate = DECAY_LR

        N = len(train_ids)
        all_order = np.random.permutation(N)
        last_group = (N // GROUP_NUM) * GROUP_NUM
        split_order = np.split(all_order[:last_group], (N // GROUP_NUM))
        split_order.append(all_order[last_group:])
        for order in split_order:
            gt_images = [None] * len(train_ids)
            input_images = [None] * len(train_ids)
            # order_frame = [(one, y) for y in [t for t in np.random.permutation(ALL_FRAME - CROP_FRAME) if t % FRAME_FREQ == 0] for one in order]
            order_frame = []
            for one in order:
                video_name = train_ids[one]
                frame_count = video_frame_counts[video_name]
                available_frames = [t for t in np.random.permutation(frame_count - CROP_FRAME) if t % FRAME_FREQ == 0]
                for y in available_frames:
                    order_frame.append((one, y))
                    
            index = np.random.permutation(len(order_frame))
            for idx in index:
                ind, start_frame = order_frame[idx]
                start_frame += np.random.randint(FRAME_FREQ)
                # get the path from image id
                train_id = train_ids[ind] + '_start_frame_' + str(start_frame)
                in_path = in_files[ind]

                gt_path = gt_files[ind]

                st = time.time()
                cnt += 1

                if input_images[ind] is None:
                    read_in = np.load(in_path)
                    # 16 bit
                    input_images[ind] = np.expand_dims(np.float32(read_in) / 255.0, axis=0)
                raw = input_images[ind]
                # raw = np.expand_dims(raw / 65535.0, axis=0)

                if gt_images[ind] is None:
                    gt_images[ind] = np.expand_dims(np.float32(np.load(gt_path) / 255.0), axis=0)
                gt_raw = gt_images[ind]
                # gt_raw = np.expand_dims(np.float32(gt_raw / 255.0), axis=0)

                input_patch, gt_patch = crop(raw, gt_raw, start_frame)

                input_patch, gt_patch = flip(input_patch, gt_patch)

                input_patch = np.minimum(input_patch, 1.0)

                _, G_current, output = sess.run([G_opt, G_loss, out_image], feed_dict={in_image: input_patch, gt_image: gt_patch, lr: learning_rate})
                output = np.minimum(np.maximum(output, 0), 1)
                g_loss[ind] = G_current




                # save loss
                summary = sess.run(summary_op, feed_dict={v_loss:G_current})
                writer.add_summary(summary, count)
                count += 1

                if save_results and start_frame in SAVE_FRAMES:
                    temp = np.concatenate((gt_patch[0, :, ::-1, :, :], output[0, :, ::-1, :, :]), axis=2)
                    try:
                        vwrite((RESULT_DIR + '%04d/%s_train.avi' % (epoch, train_id)), (temp * 255).astype('uint8'))
                    except OSError as e:
                        print(('\t', e, 'Skip saving.'))



                print(("%d %d Loss=%.8f Time=%.3f" % (epoch, cnt, np.mean(g_loss[np.where(g_loss)]), time.time() - st)), train_id)
            
        saver.save(sess, CHECKPOINT_DIR + 'model.ckpt')
        if save_results:
            saver.save(sess, RESULT_DIR + '%04d/' % epoch + 'model.ckpt')


if __name__ == '__main__':
    main()

from stablellve.

zkawfanx avatar zkawfanx commented on August 21, 2024

您能否提供一下SMOID的训练代码呢 十分感谢!! 下面这个是我的训练代码,我已经将输入输出都改为3通道,并且将14位深度修改为24位深度。不过与原代码不同的是,我修改了原代码的crop函数,不修改的话裁剪会出现0的情况。

import time, glob

import cv2
import numpy as np
import tensorflow as tf
import tf_slim as slim
from skvideo.io import vwrite

from network import network
from config import *
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "9"
tf.compat.v1.disable_eager_execution()


# get train IDs
with open(FILE_LIST) as f:
    text = f.readlines()
train_files = text

train_ids = [line.strip().split(' ')[0] for line in train_files]
gt_files = [line.strip().split(' ')[1] for line in train_files]
in_files = [line.strip().split(' ')[2] for line in train_files]

raw = np.load(in_files[0])
# F = raw.shape[0]
# H = raw.shape[1]
# W = raw.shape[2]

#def crop(raw, gt_raw, start_frame=0):
#    # inputs must be in a form of [batch_num, frame_num, height, width, channel_num]
#    tt = start_frame
#    xx = np.random.randint(0, W - CROP_WIDTH)
#    yy = np.random.randint(0, H - CROP_HEIGHT)
#
#    input_patch = raw[:, tt:tt + CROP_FRAME, yy:yy + CROP_HEIGHT, xx:xx + CROP_WIDTH, :]
#    gt_patch = gt_raw[:, tt:tt + CROP_FRAME, yy * 2:(yy + CROP_HEIGHT) * 2, xx * 2:(xx + CROP_WIDTH) * 2, :]
#    return input_patch, gt_patch

def crop(raw, gt_raw, start_frame=0):
    # inputs must be in a form of [batch_num, frame_num, height, width, channel_num]
    _, _, H, W, _ = raw.shape
    tt = start_frame
    xx = np.random.randint(0, W - CROP_WIDTH)
    yy = np.random.randint(0, H - CROP_HEIGHT)

    input_patch = raw[:, tt:tt + CROP_FRAME, yy:yy + CROP_HEIGHT, xx:xx + CROP_WIDTH, :]
    gt_patch = gt_raw[:, tt:tt + CROP_FRAME, yy:yy + CROP_HEIGHT, xx:xx + CROP_WIDTH, :]
    return input_patch, gt_patch

def flip(input_patch, gt_patch):
    # inputs must be in a form of [batch_num, frame_num, height, width, channel_num]
    if np.random.randint(2, size=1)[0] == 1:  # random flip
        input_patch = np.flip(input_patch, axis=1)
        gt_patch = np.flip(gt_patch, axis=1)
    if np.random.randint(2, size=1)[0] == 1:
        input_patch = np.flip(input_patch, axis=2)
        gt_patch = np.flip(gt_patch, axis=2)
    if np.random.randint(2, size=1)[0] == 1:
        input_patch = np.flip(input_patch, axis=3)
        gt_patch = np.flip(gt_patch, axis=3)
    if np.random.randint(2, size=1)[0] == 1:  # random transpose
        input_patch = np.transpose(input_patch, (0, 1, 3, 2, 4))
        gt_patch = np.transpose(gt_patch, (0, 1, 3, 2, 4))
    return input_patch, gt_patch

def count_images_in_folders(root_folder):
 
    video_frame_counts = {}
    for folder_name in os.listdir(root_folder):
        folder_path = os.path.join(root_folder, folder_name)
        if os.path.isdir(folder_path):
            image_count = len([name for name in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, name))])
            video_frame_counts[folder_name] = image_count
    return video_frame_counts

def main():
    sess = tf.compat.v1.Session()
    if tf.test.gpu_device_name():
        print(('Default GPU Device: {}'.format(tf.test.gpu_device_name())))
    else:
        print("Please install GPU version of TF")
    root_folder = '../data/gt'
    video_frame_counts = count_images_in_folders(root_folder)
    
    in_image = tf.compat.v1.placeholder(tf.float32, [None, CROP_FRAME, None, None, 3])
    gt_image = tf.compat.v1.placeholder(tf.float32, [None, CROP_FRAME, None, None, 3])
    out_image = network(in_image)
    if DEBUG:
        print('[DEBUG] out_image shape:', out_image.shape)


    G_loss = tf.reduce_mean(input_tensor=tf.abs(out_image - gt_image))
    v_loss = tf.Variable(0.0)

    # tensorboard summary
    tf.compat.v1.summary.scalar('loss', v_loss)
    # tf.summary.scalar('validation loss', v_loss)
    summary_op = tf.compat.v1.summary.merge_all()
    writer = tf.compat.v1.summary.FileWriter(os.path.join(LOGS_DIR, TRAIN_LOG_DIR), graph=tf.compat.v1.get_default_graph())
    writer_val = tf.compat.v1.summary.FileWriter(os.path.join(LOGS_DIR, VAL_LOG_DIR), graph=tf.compat.v1.get_default_graph())

    t_vars = tf.compat.v1.trainable_variables()
    lr = tf.compat.v1.placeholder(tf.float32)
    G_opt = tf.compat.v1.train.AdamOptimizer(learning_rate=lr).minimize(G_loss)

    saver = tf.compat.v1.train.Saver()
    sess.run(tf.compat.v1.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
    if ckpt:
        print(('loaded ' + ckpt.model_checkpoint_path))
        saver.restore(sess, ckpt.model_checkpoint_path)
    if not os.path.isdir(CHECKPOINT_DIR):
        os.makedirs(CHECKPOINT_DIR)

    # Raw data takes long time to load. Keep them in memory after loaded.
    gt_images = [None] * len(train_ids)
    input_images = [None] * len(train_ids)

    g_loss = np.zeros((len(train_ids), 1))

    lastepoch = 0
    if not os.path.isdir(RESULT_DIR):
        os.makedirs(RESULT_DIR)
    else:
        all_items = glob.glob(os.path.join(RESULT_DIR, '*'))
        all_folders = [os.path.basename(d) for d in all_items if os.path.isdir(d) and os.path.basename(d).isdigit()]
        for folder in all_folders:
            lastepoch = np.maximum(lastepoch, int(folder))

    learning_rate = INIT_LR

    np.random.seed(ord('c') + 137)
    count = 0
    for epoch in range(lastepoch + 1, MAX_EPOCH + 1):
        if epoch % SAVE_FREQ == 0:
            save_results = True
            if not os.path.isdir(RESULT_DIR + '%04d' % epoch):
                os.makedirs(RESULT_DIR + '%04d' % epoch)
        else:
            save_results = False
        cnt = 0
        if epoch > DECAY_EPOCH:
            learning_rate = DECAY_LR

        N = len(train_ids)
        all_order = np.random.permutation(N)
        last_group = (N // GROUP_NUM) * GROUP_NUM
        split_order = np.split(all_order[:last_group], (N // GROUP_NUM))
        split_order.append(all_order[last_group:])
        for order in split_order:
            gt_images = [None] * len(train_ids)
            input_images = [None] * len(train_ids)
            # order_frame = [(one, y) for y in [t for t in np.random.permutation(ALL_FRAME - CROP_FRAME) if t % FRAME_FREQ == 0] for one in order]
            order_frame = []
            for one in order:
                video_name = train_ids[one]
                frame_count = video_frame_counts[video_name]
                available_frames = [t for t in np.random.permutation(frame_count - CROP_FRAME) if t % FRAME_FREQ == 0]
                for y in available_frames:
                    order_frame.append((one, y))
                    
            index = np.random.permutation(len(order_frame))
            for idx in index:
                ind, start_frame = order_frame[idx]
                start_frame += np.random.randint(FRAME_FREQ)
                # get the path from image id
                train_id = train_ids[ind] + '_start_frame_' + str(start_frame)
                in_path = in_files[ind]

                gt_path = gt_files[ind]

                st = time.time()
                cnt += 1

                if input_images[ind] is None:
                    read_in = np.load(in_path)
                    # 16 bit
                    input_images[ind] = np.expand_dims(np.float32(read_in) / 255.0, axis=0)
                raw = input_images[ind]
                # raw = np.expand_dims(raw / 65535.0, axis=0)

                if gt_images[ind] is None:
                    gt_images[ind] = np.expand_dims(np.float32(np.load(gt_path) / 255.0), axis=0)
                gt_raw = gt_images[ind]
                # gt_raw = np.expand_dims(np.float32(gt_raw / 255.0), axis=0)

                input_patch, gt_patch = crop(raw, gt_raw, start_frame)

                input_patch, gt_patch = flip(input_patch, gt_patch)

                input_patch = np.minimum(input_patch, 1.0)

                _, G_current, output = sess.run([G_opt, G_loss, out_image], feed_dict={in_image: input_patch, gt_image: gt_patch, lr: learning_rate})
                output = np.minimum(np.maximum(output, 0), 1)
                g_loss[ind] = G_current




                # save loss
                summary = sess.run(summary_op, feed_dict={v_loss:G_current})
                writer.add_summary(summary, count)
                count += 1

                if save_results and start_frame in SAVE_FRAMES:
                    temp = np.concatenate((gt_patch[0, :, ::-1, :, :], output[0, :, ::-1, :, :]), axis=2)
                    try:
                        vwrite((RESULT_DIR + '%04d/%s_train.avi' % (epoch, train_id)), (temp * 255).astype('uint8'))
                    except OSError as e:
                        print(('\t', e, 'Skip saving.'))



                print(("%d %d Loss=%.8f Time=%.3f" % (epoch, cnt, np.mean(g_loss[np.where(g_loss)]), time.time() - st)), train_id)
            
        saver.save(sess, CHECKPOINT_DIR + 'model.ckpt')
        if save_results:
            saver.save(sess, RESULT_DIR + '%04d/' % epoch + 'model.ckpt')


if __name__ == '__main__':
    main()

看样子颜色通道异常的问题已经解决了,我觉得现在的结果看起来知识有些过曝,你可以尝试把网络输出归一化或者截断到[0,1]范围内再可视化看看。

黑条应该是你修改crop没改对,你可以试试原来的crop,裁出0应该不会影响训练,很多时候对于尺寸不对的输入,都是通过pad 0来预处理的。

这个方法只是那篇论文里我选择对比的其中之一,现在要找到当时的代码比较麻烦。你也许可以发issue问问原作者。

from stablellve.

zkawfanx avatar zkawfanx commented on August 21, 2024

@Ac2333 你需要理解raw格式的像素排列和rgb的对应关系,源代码中crop应该是对rggb排列的raw数据进行裁剪,也就是每2x2个像素对应rgb中的一个pixel,所以crop的时候*2了,你的修改可能有点问题。

其次,raw数据选用16bit只是因为数据本身是14或者12bit的,通过16bit的格式保存了。DAVIS数据本身是sRGB,只有8bit,你不需要改成24bit等操作,而且你所说的24深度也有歧义,有可能指的是8bit*3一共24深度,这和low light image enhancement中提到的bit位数有些差别,你需要注意区分。

而且从你最新的结果图可以看出,网络是可以正常输出的,low light图是能够被网络提高亮度的,只是你的预处理和后处理可能还存在问题,你排查一下,相信对你理解这整个处理流程也有帮助。

from stablellve.

Ac2333 avatar Ac2333 commented on August 21, 2024

@Ac2333 你需要理解raw格式的像素排列和rgb的对应关系,源代码中crop应该是对rggb排列的raw数据进行裁剪,也就是每2x2个像素对应rgb中的一个pixel,所以crop的时候*2了,你的修改可能有点问题。

其次,raw数据选用16bit只是因为数据本身是14或者12bit的,通过16bit的格式保存了。DAVIS数据本身是sRGB,只有8bit,你不需要改成24bit等操作,而且你所说的24深度也有歧义,有可能指的是8bit*3一共24深度,这和low light image enhancement中提到的bit位数有些差别,你需要注意区分。

而且从你最新的结果图可以看出,网络是可以正常输出的,low light图是能够被网络提高亮度的,只是你的预处理和后处理可能还存在问题,你排查一下,相信对你理解这整个处理流程也有帮助。

好的 十分感谢您的帮助!

from stablellve.

TuHaiqing avatar TuHaiqing commented on August 21, 2024

你好,请问大佬是否有Learning to See Moving Objects in the Dark论文中的SMOID数据集呢?现在从论文作者的github网址https://github.com/MichaelHYJiang/Learning-to-See-Moving-Objects-in-the-Dark执行python download_dataset.py
下载下来的数据集压缩包似乎不完整。如果有数据方便分享一下吗?谢谢。

from stablellve.

zkawfanx avatar zkawfanx commented on August 21, 2024

你好,请问大佬是否有Learning to See Moving Objects in the Dark论文中的SMOID数据集呢?现在从论文作者的github网址https://github.com/MichaelHYJiang/Learning-to-See-Moving-Objects-in-the-Dark执行python download_dataset.py 下载下来的数据集压缩包似乎不完整。如果有数据方便分享一下吗?谢谢。

抱歉,我没有SMOID数据集,你也许可以给通讯作者Yinqiang Zheng发信问一下。

from stablellve.

TuHaiqing avatar TuHaiqing commented on August 21, 2024

你好,请问大佬是否有Learning to See Moving Objects in the Dark论文中的SMOID数据集呢?现在从论文作者的github网址https://github.com/MichaelHYJiang/Learning-to-See-Moving-Objects-in-the-Dark执行python download_dataset.py 下载下来的数据集压缩包似乎不完整。如果有数据方便分享一下吗?谢谢。

抱歉,我没有SMOID数据集,你也许可以给通讯作者Yinqiang Zheng发信问一下。

感谢回复,我在论文作者的github项目主页上提问没有回复,我尝试给通讯作者发信问一下。

from stablellve.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.