
Comments (4)

taehoonlee commented on May 19, 2024

@haddis3, can you share a snippet that reproduces the error?

from tensornets.

haddis3 commented on May 19, 2024

```python
import os
import time

import tensorflow as tf
import tensornets as nets

# `_parse_function`, `loss`, `subnet.Part_based_convolution`, `add_gradient_summary`
# and `FLAGS` are defined elsewhere in the project (not included in this snippet).


def run_training():

    is_training = tf.placeholder(tf.bool)
    steps_per_epoch = 3000
    # learning_rate = tf.placeholder(dtype=tf.float32)
    train_filenames = ['./train.tfrecords']
    validation_filenames = ['./valid.tfrecords']
    train_datasets = tf.data.TFRecordDataset(filenames=train_filenames)
    validation_datasets = tf.data.TFRecordDataset(filenames=validation_filenames)
    train_datasets = train_datasets.map(_parse_function).batch(16).repeat()
    validation_datasets = validation_datasets.map(_parse_function).batch(1).repeat(1)
    # one reinitializable iterator shared by the training and validation datasets
    iterator = tf.data.Iterator.from_structure(train_datasets.output_types,
                                               train_datasets.output_shapes)
    image, label = iterator.get_next()
    training_init_op = iterator.make_initializer(train_datasets)
    validation_init_op = iterator.make_initializer(validation_datasets)

    model = nets.ResNet101(image, stem=True, is_training=is_training, classes=13)  # mind the input image preprocessing
    # model = nets.DenseNet121()
    res_feature = model.get_outputs()[128]
    scale_feature = tf.image.resize_bilinear(res_feature, [16, 16])
    fc7_ = subnet.Part_based_convolution(scale_feature)
    # x1 = tf.layers.conv2d(res_feature, filters=2048, kernel_size=3, strides=2)
    # x2 = tf.layers.conv2d(res_feature, filters=2048, kernel_size=3, strides=1)
    # compute the loss
    glob = tf.reduce_mean(model, axis=[1, 2])  # global average pooling over height and width
    fc7 = tf.layers.dense(glob, units=13,
                          kernel_initializer=tf.contrib.layers.variance_scaling_initializer())
    out = fc7 + fc7_
    # recalibrate the weights
    # reweight_fc7 = tf.layers.dense(fc7, units=13, activation=tf.nn.sigmoid,
    #               kernel_initializer=tf.contrib.layers.variance_scaling_initializer())
    # fc8 = tf.multiply(fc7, reweight_fc7)
    #

    cost = loss(out, label)
    # optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
    optimizer = tf.train.AdamOptimizer(learning_rate=0.0001)
    # run the batch-norm moving-average updates together with each training step
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        grads = optimizer.compute_gradients(cost)
        # record the gradient of each variable
        for grad, var in grads:
            add_gradient_summary(grad, var)
        train_op = optimizer.apply_gradients(grads)

    summary_op = tf.summary.merge_all()
    saver = tf.train.Saver(max_to_keep=10)
    lr = 0.0001
    init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    with tf.Session() as sess:

        sess.run(init)
        summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)
        sess.run(model.pretrained())  # load the pretrained ResNet101 weights
        step = 0

        for _ in range(18):
            sess.run(training_init_op)  # the iterator must be initialized before reading data
            for _ in range(3125):

                # if step / steps_per_epoch == 5:
                #     lr = 0.00001
                # elif step / steps_per_epoch == (FLAGS.epoch * 0.5):
                #     lr = 0.000001
                # elif step / steps_per_epoch == (FLAGS.epoch * 0.8):
                #     lr = 0.0000001

                start_time = time.time()
                _, loss_value = sess.run([train_op, cost], {is_training: True})
                duration = time.time() - start_time

                if step % 100 == 0:
                    print('step %d: loss = %.6f ( %.3f sec)' % (step, loss_value, duration))
                    # summary_str = sess.run(summary_op, {is_training: False})
                    # summary_writer.add_summary(summary_str, step)

                # save after the last step of each epoch
                if (step + 1) % 3125 == 0:
                    checkpoint_path = os.path.join(FLAGS.model_dir, 'model.ckpt')
                    saver.save(sess, save_path=checkpoint_path, global_step=step)

                step += 1
```
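A check of the form `step % 3125 == 0`, evaluated after each training step with `step` starting at 0, fires right after the *first* step of every epoch rather than after the last one, so the final stretch of training is never checkpointed. The sketch below simulates only the step counter (18 epochs of 3125 steps, the loop bounds used in this thread) in plain Python, with no TensorFlow involved, and compares that condition with `(step + 1) % 3125 == 0`, which fires after the final step of each epoch. The names `EPOCHS`, `STEPS_PER_EPOCH`, and `checkpoint_steps` are illustrative only.

```python
# Simulate only the step counter of the training loop: 18 epochs x 3125 steps.
EPOCHS = 18
STEPS_PER_EPOCH = 3125  # inner-loop length used in the snippet


def checkpoint_steps(condition):
    """Return the step indices at which `condition(step)` would trigger a save."""
    saved = []
    step = 0
    for _ in range(EPOCHS):
        for _ in range(STEPS_PER_EPOCH):
            # (a training step would run here)
            if condition(step):
                saved.append(step)
            step += 1
    return saved


# Fires after the first step of each epoch: 0, 3125, 6250, ...
start_of_epoch = checkpoint_steps(lambda s: s % STEPS_PER_EPOCH == 0)
# Fires after the last step of each epoch: 3124, 6249, 9374, ...
end_of_epoch = checkpoint_steps(lambda s: (s + 1) % STEPS_PER_EPOCH == 0)

total_steps = EPOCHS * STEPS_PER_EPOCH
print(start_of_epoch[:3], start_of_epoch[-1])  # [0, 3125, 6250] 53125
print(end_of_epoch[:3], end_of_epoch[-1])      # [3124, 6249, 9374] 56249
print("last step:", total_steps - 1)           # last step: 56249
```

With the first condition the newest checkpoint is taken at step 53125, so the remaining 3124 steps of the final epoch are lost unless another save is added after the loop.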


taehoonlee commented on May 19, 2024

@haddis3, can you elaborate on how to reproduce the error? I can't find anything in the snippet above that would raise one.


taehoonlee commented on May 19, 2024

@haddis3, I'll close this issue for now. Feel free to reopen it whenever you have time to check. I'm always ready :)

