Giter VIP home page Giter VIP logo

Comments (5)

hamuchu avatar hamuchu commented on May 14, 2024

Finally, the issue is resolved.
I will paste the final source code later.
Thanks.

from tensorpack.

ppwwyyxx avatar ppwwyyxx commented on May 14, 2024

From the log it looks like your repo was not up-to-date. It actually runs well with latest commit.

from tensorpack.

hamuchu avatar hamuchu commented on May 14, 2024

Thanks. Here is the source code.
This source code works well.

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: cifar-convnet.py
# Author: Yuxin Wu <[email protected]>
import tensorflow as tf
import argparse
import numpy as np
import os

from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import *
from dorefa import get_dorefa

from tensorpack.tfutils.symbolic_functions import *

"""
A small convnet model for Cifar10 or Cifar100 dataset.

Cifar10:
    90% validation accuracy after 40k step.
    91% accuracy after 80k step.
    19.3 step/s on Tesla M40

Not a good model for Cifar100, just for demonstration.
"""
# Bit widths handed to get_dorefa(BITW, BITA, BITG) when the graph is built.
# Presumably weights / activations / gradients respectively - confirm against
# the dorefa module, which is not shown here.
BITW = 1
BITA = 2  # Model switches its nonlinearity to plain ReLU only when this is 32
BITG = 6
# NOTE(review): BATCH_SIZE is not referenced anywhere in the code shown;
# get_data() batches with a literal 128 instead - confirm which is intended.
BATCH_SIZE = 32
class Model(ModelDesc):
    """Quantized convnet for CIFAR built with the helpers from get_dorefa().

    The graph takes 30x30x3 float images and int32 labels, stacks six 3x3
    convolutions and three fully-connected layers, and exposes the total
    training cost as ``self.cost``.
    """

    def __init__(self, cifar_classnum):
        """Remember the number of classes (output size of the last layer)."""
        super(Model, self).__init__()
        self.cifar_classnum = cifar_classnum

    def _get_input_vars(self):
        """Declare the graph inputs: an image batch and its label vector."""
        return [InputVar(tf.float32, [None, 30, 30, 3], 'input'),
                InputVar(tf.int32, [None], 'label')]

    def _build_graph(self, input_vars, is_training):
        """Build the network, the loss and the summaries.

        Args:
            input_vars: the ``(image, label)`` pair matching
                ``_get_input_vars``.
            is_training: truthy when building the training graph; controls
                dropout, batch-norm statistics and the image summary.

        Side effects:
            Sets ``self.cost`` and creates the named ops ``output`` (softmax
            probabilities) and ``wrong`` (number of misclassified samples).
            ``tf.get_variable`` is temporarily replaced while the layers are
            created and is always restored afterwards.
        """
        image, label = input_vars
        image = image / 4.0     # just to make range smaller

        fw, fa, fg = get_dorefa(BITW, BITA, BITG)

        old_get_variable = tf.get_variable

        def new_get_variable(name, shape=None, **kwargs):
            """Wrap tf.get_variable so that selected 'W' variables go through fw."""
            v = old_get_variable(name, shape, **kwargs)
            # don't binarize first and last layer
            # NOTE(review): no layer built below has 'conv0' or 'fct' in its
            # name (they are conv1.1 ... linear), so it looks like every 'W'
            # is passed through fw - confirm whether that is intended.
            if name != 'W' or 'conv0' in v.op.name or 'fct' in v.op.name:
                return v
            logger.info("Binarizing weight {}".format(v.op.name))
            return fw(v)

        def nonlin(x):
            if BITA == 32:
                return tf.nn.relu(x)    # still use relu for 32bit cases
            return tf.clip_by_value(x, 0.0, 1.0)

        def activate(x):
            return fa(nonlin(x))

        keep_prob = tf.constant(0.5 if is_training else 1.0)

        if is_training:
            tf.image_summary("train_image", image, 10)

        logger.info("Building graph, is_training={}".format(is_training))

        # monkey-patch tf.get_variable to apply fw; the finally clause makes
        # sure the global is restored even if graph construction raises.
        tf.get_variable = new_get_variable
        try:
            with  argscope(FullyConnected, use_bias=False, nl=tf.identity), \
                  argscope(Conv2D, nl=BNReLU(is_training), use_bias=False, kernel_shape=3):
                logits = LinearWrap(image) \
                        .Conv2D('conv1.1', out_channel=64)\
                        .apply(activate)\
                        .Conv2D('conv1.2', out_channel=64) \
                        .apply(fg)\
                        .BatchNorm('bn0',use_local_stat=is_training)\
                        .MaxPooling('pool1', 3, stride=2, padding='SAME') \
                        .apply(activate)\
                        .Conv2D('conv2.1', out_channel=128)\
                        .apply(fg)\
                        .BatchNorm('bn2',use_local_stat=is_training)\
                        .apply(activate)\
                        .Conv2D('conv2.2', out_channel=128)\
                        .apply(fg)\
                        .BatchNorm('bn3',use_local_stat=is_training)\
                        .MaxPooling('pool2', 3, stride=2, padding='SAME') \
                        .apply(activate)\
                        .Conv2D('conv3.1', out_channel=128, padding='VALID') \
                        .apply(fg)\
                        .BatchNorm('bn4',use_local_stat=is_training)\
                        .apply(activate)\
                        .Conv2D('conv3.2', out_channel=128, padding='VALID') \
                        .apply(fg)\
                        .BatchNorm('bn5',use_local_stat=is_training)\
                        .apply(activate)\
                        .FullyConnected('fc0', 1024 + 512,
                               b_init=tf.constant_initializer(0.1)) \
                        .tf.nn.dropout(keep_prob) \
                        .apply(fg)\
                        .BatchNorm('bn6',use_local_stat=is_training)\
                        .FullyConnected('fc1', 512,
                               b_init=tf.constant_initializer(0.1)) \
                        .apply(fg)\
                        .BatchNorm('bn7',use_local_stat=is_training)\
                        .apply(nonlin)\
                        .FullyConnected('linear', out_dim=self.cifar_classnum, nl=tf.identity)()
        finally:
            tf.get_variable = old_get_variable

        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
        cost = tf.reduce_mean(cost, name='cross_entropy_loss')

        # Created only for its name: exposes the class probabilities as 'output'.
        tf.nn.softmax(logits, name='output')

        # compute the number of failed samples, for ClassificationError to use at test time
        wrong = symbf.prediction_incorrect(logits, label)
        tf.reduce_sum(wrong, name='wrong')
        # monitor training error
        add_moving_summary(tf.reduce_mean(wrong, name='train_error'))

        # weight decay on all W of fc layers
        wd_cost = tf.mul(0.004,
                         regularize_cost('fc.*/W', tf.nn.l2_loss),
                         name='regularize_loss')
        add_moving_summary(cost, wd_cost)

        add_param_summary([('.*/W', ['histogram'])])   # monitor W
        self.cost = tf.add_n([cost, wd_cost], name='cost')

def get_data(train_or_test, cifar_classnum):
    """Return the batched CIFAR dataflow for one split.

    Args:
        train_or_test: split name; ``'train'`` enables the random
            augmentations and prefetching, anything else uses a center crop.
        cifar_classnum: 10 selects ``dataset.Cifar10``; any other value
            selects ``dataset.Cifar100``.

    Returns:
        The dataflow after augmentation and batching (batch size 128); the
        last partial batch is kept only for non-training splits.
    """
    training = (train_or_test == 'train')

    cifar_cls = dataset.Cifar10 if cifar_classnum == 10 else dataset.Cifar100
    flow = cifar_cls(train_or_test)

    crop_shape = (30, 30)
    if not training:
        pipeline = [imgaug.CenterCrop(crop_shape)]
    else:
        pipeline = [
            imgaug.RandomCrop(crop_shape),
            imgaug.Flip(horiz=True),
            imgaug.Brightness(63),
            imgaug.Contrast((0.2, 1.8)),
            imgaug.GaussianDeform(
                [(0.2, 0.2), (0.2, 0.8), (0.8, 0.8), (0.8, 0.2)],
                crop_shape, 0.2, 3),
        ]
    # Both splits finish with the same normalization step.
    pipeline.append(imgaug.MeanVarianceNormalize(all_channel=True))

    flow = AugmentImageComponent(flow, pipeline)
    flow = BatchData(flow, 128, remainder=not training)
    return PrefetchData(flow, 3, 2) if training else flow
def get_config(cifar_classnum):
    """Assemble the TrainConfig for a CIFAR training run.

    Args:
        cifar_classnum: number of classes, forwarded to ``get_data`` and
            ``Model``.

    Returns:
        A ``TrainConfig`` with the training dataflow, an Adam optimizer on a
        staircase-decayed learning rate, and callbacks that print stats, save
        the model and report the classification error on the test split.
    """
    logger.auto_set_dir()

    # prepare dataset
    dataset_train = get_data('train', cifar_classnum)
    step_per_epoch = dataset_train.size()
    dataset_test = get_data('test', cifar_classnum)

    sess_config = get_default_sess_config(0.5)

    nr_gpu = get_nr_gpu()
    # Scientific notation reminder: 1e-5 is 10 to the power -5, i.e. 0.00001.
    # The rate is halved every 30 epochs on one GPU, every 20 otherwise.
    lr = tf.train.exponential_decay(
        learning_rate=1e-4,
        global_step=get_global_step_var(),
        decay_steps=step_per_epoch * (30 if nr_gpu == 1 else 20),
        decay_rate=0.5, staircase=True, name='learning_rate')
    tf.scalar_summary('learning_rate', lr)

    return TrainConfig(
        dataset=dataset_train,
        optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-4),
        callbacks=Callbacks([
            StatPrinter(),
            ModelSaver(),
            # Evaluate on the held-out split; the original ran inference on
            # dataset_train and left dataset_test unused.
            InferenceRunner(dataset_test, ClassificationError())
        ]),
        session_config=sess_config,
        model=Model(cifar_classnum),
        step_per_epoch=step_per_epoch,
        max_epoch=250,
    )

if __name__ == '__main__':
    # Command-line entry point: parse options, pick the GPU(s), then train.
    cli = argparse.ArgumentParser()
    cli.add_argument('--gpu',
                     help='comma separated list of GPU(s) to use.')  # nargs='*' in multi mode
    cli.add_argument('--load', help='load model')
    cli.add_argument('--classnum', type=int, default=10,
                     help='10 for cifar10 or 100 for cifar100')
    args = cli.parse_args()

    # Fall back to device 0 when --gpu is missing or empty.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu if args.gpu else '0'

    with tf.Graph().as_default():
        config = get_config(args.classnum)
        if args.load:
            config.session_init = SaverRestore(args.load)
        if args.gpu:
            # One tower per comma-separated device id.
            config.nr_tower = args.gpu.count(',') + 1
        # SimpleTrainer is used here; QueueInputTrainer(config).train() was
        # the alternative the author left commented out.
        SimpleTrainer(config).train()

from tensorpack.

RyuLee avatar RyuLee commented on May 14, 2024

Hi, when I run your code I got an error like this:

(py35) D:\workplace\comp 535\src>Traceback (most recent call last):
File "", line 1, in
File "D:\Software\Anaconda\envs\py35\lib\multiprocessing\spawn.py", line 106, in spawn_main
exitcode = _main(fd)
File "D:\Software\Anaconda\envs\py35\lib\multiprocessing\spawn.py", line 116, in _main
self = pickle.load(from_parent)
EOFError: Ran out of input

any suggestions?

from tensorpack.

hamuchu avatar hamuchu commented on May 14, 2024

Thank you for pointing out the bug in the code.
But I have no idea how to fix it.
Sorry.

from tensorpack.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.