Comments (20)

NickleDave commented on June 19, 2024

Thank you for the feedback. To me your option of adding a separate page for experienced TF users makes the most sense. I will make sure I read the contributor agreement, and then I hope to find time in the next couple of weeks to work something up.

NickleDave commented on June 19, 2024

Thanks for inviting feedback! Overall it looks good to me, and I think this will definitely help orient someone coming from a TensorFlow background.

One general comment:
do you think it would be helpful when talking about the difference between Nengo models and deep learning models to include a link to the Nengo-core "Getting started" and "Examples" sections?
i.e. either at the beginning or end of this section, say something like

For a more in-depth explanation of Nengo models, see :ref:`getting-started` and :ref:`examples`

or something like that? Just in case someone coming from TF wants to learn more but isn't sure where to start. I looked for this but didn't see it. Sorry if I just missed it.

I noticed one place where there might be some Nengo-lingo that gets used without defining it, I'll comment directly on the PR.

I promise I have had "a slightly more detailed TensorNode class tutorial" on my to-do list but I just keep not managing to get to it. Will find time tomorrow to put together an initial attempt.

NickleDave commented on June 19, 2024

Hmm, guess I can only comment on diff lines, nvm...

but I noticed this passage:

"This is essentially equivalent to the TensorFlow function `tf.layers.batch_normalization(b_rate.neurons, momentum=0.9)`, except it works with Nengo objects. For example, `b_rate` is a `nengo.Ensemble` in this case, and we can add Probes or Connections to `batch_norm` in the same way as any other Nengo object.\n",
It talks about adding Probes and Connections to a tensor_layer, but I don't think it's been stated (aside from code comments) what a Probe does. Minor detail, but maybe worth emphasizing to a TF user why this is useful/important.
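
For instance, something like this might make it concrete (illustrative sketch only, using the tensor_layer API and the b_rate ensemble from the notebook):

with nengo.Network() as net:
    # ... b_rate ensemble defined as in the notebook ...
    batch_norm = nengo_dl.tensor_layer(
        b_rate.neurons, tf.layers.batch_normalization, momentum=0.9)
    # a Probe tells the simulator to record this object's output over the
    # course of the simulation, so it can be read back from sim.data
    p = nengo.Probe(batch_norm)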

drasmuss commented on June 19, 2024

Actually the issue here is just that you need a `return probabilities` at the end of your `__call__` function (the `__call__` function is returning None at the moment, which is what causes that error). That's definitely a cryptic error message though; I'll add something that performs better validation on the TensorNode outputs to catch that kind of thing.
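
i.e., the fix is just (a sketch, assuming your final layer is named fc8 and the softmax tensor is probabilities, as in your snippet):

def __call__(self, t, x):
    # ... build the network ops as before ...
    probabilities = tf.nn.softmax(fc8)
    return probabilities  # TensorNode uses the returned Tensor as the node output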

NickleDave commented on June 19, 2024

Ah right of course. Functions should have return statements. Sorry, that was after a too-long day of coding.

By the way, the original issue I was trying to debug was that I had the initializer for the bias set to just a tf.zeros statement, like so:

        fc1b = tf.get_variable(name='fc1_biases',
                               initializer=tf.zeros(shape=[n_output_units]))

but when I did that I got an error:

ValueError: Initializer for variable fc1_biases/ is from inside a control-flow construct, such as a loop or conditional. When creating a variable inside a loop or conditional, use a lambda as the initializer.

I think this is because there's a tf.while_loop in nengo_dl.tensor_graph.build_loop ... but it's kind of opaque to the user? I can at least get the graph to build if I just obey the error message and use a lambda, like in the first code snippet I posted (the one missing the return statement), although I did have to work out what arguments I needed for the lambda by trial and error.
Not sure if there's something I'm missing about how best to use TensorFlow models with the simulator. Is it always the case that variables in a TensorNode have to be initialized with a lambda, because that graph will be a "child" of the "parent" tf.while_loop op?
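
For reference, here's the lambda version that at least builds for me (I'm not positive this is the right signature in general -- it may vary by TF version, so treat it as a sketch):

        fc1b = tf.get_variable(name='fc1_biases',
                               # wrap the initial value in a lambda to defer its
                               # creation, per the error message
                               initializer=lambda: tf.zeros(shape=[n_output_units]))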

drasmuss commented on June 19, 2024

Yes, NengoDL models always run within a while loop. Unlike "standard" feed-forward deep networks (like convolutional classification networks), Nengo models almost always need to be simulated over time. This is because they have complex/recurrent connectivity patterns, temporal neural models (e.g. spiking neurons), or other temporal dynamics like synaptic filters. And if you want to be able to optimize a network over time like that, then it needs to run within a tf.while_loop.
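
To make that concrete, here's a minimal sketch (assuming some Nengo network net has been defined): every simulator run is a sequence of timesteps, and all of those steps execute inside a single tf.while_loop.

with nengo_dl.Simulator(net) as sim:
    # each of these 50 timesteps is one iteration of the underlying
    # tf.while_loop that NengoDL builds around the model
    sim.run_steps(50)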

If you wanted to avoid the lambda approach, you could create your variables during the pre_build stage. Operations defined in the pre_build function will be built outside of the while loop. Importantly, this means that they won't run each simulation timestep, but in the case of variable creation that's what you want. So your SimpleNode would look something like

class SimpleNode:
    def pre_build(self, shape_in, shape_out):
        self.fc1W = tf.get_variable(name='fc1W_weights',
                                    shape=(shape_in[-1], shape_out[-1]))
        self.fc1b = tf.get_variable(name='fc1W_biases',
                                    initializer=tf.zeros_initializer(),
                                    shape=[shape_out[-1]])
        ...

    def __call__(self, t, x):
        ...

        fc1 = tf.nn.relu_layer(
            tf.reshape(maxpool1, [-1, int(np.prod(maxpool1.get_shape()[1:]))]),
            self.fc1W, self.fc1b)

drasmuss commented on June 19, 2024

Added better validation for TensorNode outputs, so hopefully these errors will be easier to diagnose in the future!

NickleDave commented on June 19, 2024

@drasmuss thank you for adding the validation and thank you for your example--I hadn't seen it in the throes of getting my code to work but that's essentially what I ended up doing.

I think it might be helpful to have a more detailed example in the docs of using a TensorNode, targeted at people coming from a TensorFlow background. They will have several questions that arise from thinking in terms of TF abstractions, e.g. "Do I need to instantiate a Session? Am I supposed to use the session with a context manager within the post_build method; is that why it's passed as one of the arguments by default? How do I properly load weights?" My guess is you left out some of those details because your aim was to have the examples be useful but concise. But if you want to convert TensorFlow users (and maybe you don't 😂), it could be good to help them over those mental speedbumps.

As an example, because I trained the network separately from any Nengo model, I found that I needed to explicitly name the variables I wanted to reload, since there was a mismatch between the graph I had built and the graph that Nengo-DL built with the TensorNode inserted.

So I ended up with something like the following, which I'm pasting here in case it's useful to anyone else.
If you have feedback on the approach here, I'd appreciate it. Maybe this is a bass ackward way to do it.
And I'd be happy to work up a more generic example for the docs as a PR if you agree it would be helpful. (I'd make it more Pythonic, less verbose + cryptic -- I used someone else's code for AlexNet and didn't spend a lot of time making it readable.)

import os
from urllib.request import urlretrieve

import numpy as np
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

from .myalexnet_forward_newtf import conv

weights_file = "bvlc_alexnet.npy"
if not os.path.isfile(weights_file):
    print("downloading weights for alexnet")
    urlretrieve("http://www.cs.toronto.edu/~guerzhoy/tf_alexnet/bvlc_alexnet.npy",
                weights_file)

with open(weights_file, "rb") as weights:
    # use item() to get the dictionary saved inside a 0-d numpy array
    net_data = np.load(weights, encoding="latin1").item()

def load(saver, sess, path, epoch, name='alexnet-model.ckpt-%d'):
    print('Loading model from %s' % path)
    saver.restore(sess, os.path.join(path, name % epoch))

class AlexnetNode:
    """AlexNet as a node for Nengo models"""

    def __init__(self, alexnet_input_shape, checkpoint_filename, epochs):
        """__init__ function
        Parameters
        ----------
        alexnet_input_shape : tuple
            the vector will be reshaped to this shape 
        checkpoint_filename : str
            Base name of checkpoint files. If defaults were used when
             saving, will be of the form:
            '~/dir/subdir/directory_with_saved_models/alexnet-model.ckpt'
        epochs : number
            Number of epochs after which checkpoint was saved, will be
            appended to end of checkpoint filename
        """
        self.alexnet_input_shape = alexnet_input_shape
        self.checkpoint_filename = checkpoint_filename
        self.epochs = epochs

    def pre_build(self, input_shape, output_shape):
        """steps executed before build.
        Used to define variables that do not need to be initialized
        every time step of the model.
        Parameters
        ----------
        input_shape : tuple
            not used
        output_shape : tuple
            not used
        Returns
        -------
        None
        Adds vars_dict to properties
        """

        reader = pywrap_tensorflow.NewCheckpointReader(
            self.checkpoint_filename + '-' + str(self.epochs))
        var_to_shape_map = reader.get_variable_to_shape_map()

        vars_dict = {}

        weights_shape = var_to_shape_map['fc6W_weights']
        vars_dict['fc6W_weights'] = tf.get_variable(name='fc6W_weights',
                                                    shape=weights_shape)
        vars_dict['fc6W_biases'] = tf.get_variable(name='fc6W_biases',
                                                   initializer=tf.zeros(
                                                       shape=weights_shape[1]))

        weights_shape = var_to_shape_map['fc7W_weights']
        vars_dict['fc7W_weights'] = tf.get_variable(name='fc7W_weights',
                                                    shape=weights_shape)
        vars_dict['fc7W_biases'] = tf.get_variable(name='fc7W_biases',
                                                   initializer=tf.zeros(
                                                       shape=weights_shape[1]))

        # ... other variables here
        self.vars_dict = vars_dict
        self.saver = tf.train.Saver(var_list=self.vars_dict)

    def post_build(self, sess, rng):
        """steps executed post-build. Loads weights from saved checkpoint.
        Parameters
        ----------
        sess : tf.Session object
            the session in which the model was built; used to restore
            variables from the saved checkpoint
        rng : numpy rng object
            not used
        Returns
        -------
        None
        Loads trained weights into the built graph
        """

        self.saver.restore(sess=sess,
                           save_path=self.checkpoint_filename + '-' + str(self.epochs))

    def __call__(self, t, x):
        """executed each timestep while the network is running
        Parameters
        ----------
        t : tensorflow tensor
            scalar containing the current simulation time
        x : tensorflow tensor
            input to the node each timestep
        Returns
        -------
        probabilities : tensorflow tensor
            softmax on last fully-connected layer of AlexNet
        """

        # convert our input vector to the shape/dtype of the input image
        image = tf.reshape(tf.cast(x, tf.float32),
                           (-1,) + self.alexnet_input_shape[1:])

        k_h = 11; k_w = 11; c_o = 96; s_h = 4; s_w = 4
        conv1W = tf.Variable(net_data["conv1"][0])
        conv1b = tf.Variable(net_data["conv1"][1])
        conv1_in = conv(image, conv1W, conv1b, k_h, k_w, c_o, s_h, s_w,
                        padding="SAME", group=1)
        conv1 = tf.nn.relu(conv1_in)

        # ... more layers

        fc7 = tf.nn.relu_layer(fc6,
                               self.vars_dict['fc7W_weights'],
                               self.vars_dict['fc7W_biases'])

        fc8 = tf.nn.xw_plus_b(fc7,
                              self.vars_dict['fc8W_weights'],
                              self.vars_dict['fc8W_biases'])

        probabilities = tf.nn.softmax(fc8, name='probabilities')
        return probabilities 
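
And here's roughly how I'm plugging that into a model (just a sketch; it assumes nengo and nengo_dl are imported, and size_out=1000 is a placeholder for the actual number of classes):

with nengo.Network() as net:
    # input node presenting a flattened image (constant zeros here for brevity)
    inp = nengo.Node(np.zeros(int(np.prod(alexnet_input_shape[1:]))))
    alexnet = nengo_dl.TensorNode(
        AlexnetNode(alexnet_input_shape, checkpoint_filename, epochs),
        size_in=int(np.prod(alexnet_input_shape[1:])),
        size_out=1000)
    nengo.Connection(inp, alexnet, synapse=None)
    p = nengo.Probe(alexnet)  # record the softmax output over time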

drasmuss commented on June 19, 2024

That looks like a sensible way to do things, no obvious issues jumped out at me! And yeah, it's definitely true that the existing example (the inception one https://www.nengo.ai/nengo-dl/v1.1.0/examples/pretrained_model.html) hides a lot of the TensorFlow details, mainly by making heavy use of the slim package, which handles a lot of things automatically (like variable naming/loading).

So one possibility would be to expand that example a bit, doing things more explicitly so that people can see how to do things without slim. We do have a bit of a tricky balancing act though. On the one hand we have users coming from a TensorFlow background, who are comfortable with TensorFlow concepts and wondering how to fit those into Nengo (as you point out). On the other hand, we have users coming from a Nengo/neuroscience background, who may not be familiar with TensorFlow at all, and just want to know, e.g., how to add a pretrained vision system to their Nengo model. We don't want to overwhelm that second group with a lot of TensorFlow details either. So the question would be whether we can add some of those more explicit TensorFlow details into that example without making it too complex.

Another option would be to add a separate page/example explicitly for experienced TensorFlow users coming to NengoDL (leaving the existing example targeting the more inexperienced TensorFlow users). That'd provide more freedom to dive into the details, although it'd be adding a lot more new material to the documentation.

We definitely welcome contributions, if either of those sounds like something you want to work on (or if you have some other ideas)! We can discuss more, or you can just go off and work something up, whatever works for you. Do make sure you check out the contributor agreement first. Nothing crazy in there; it'll look familiar if you've signed, e.g., TensorFlow's contributor agreement. I just point it out because I want to make sure you've seen it before you start working on a PR!

drasmuss commented on June 19, 2024

I took a stab at improving the introduction for TensorFlow users in #77 (by splitting the introduction into two tutorials, one for Nengo users and one for TensorFlow users). Feel free to take a look if you have any comments/suggestions!

NickleDave commented on June 19, 2024

Here's a (very) rough draft of a tutorial that walks through using a TensorNode to insert a net built with keras into a Nengo model:
https://github.com/NickleDave/nengo-dl/blob/add-keras-model-to-examples/docs/examples/keras-model.ipynb

I went with Keras because:

  • the tf devs seem to be moving towards that API for everything
  • all their (neural net) tutorials use it
  • seemed to provide a way to explain the same concepts such as "create tensors in pre-build and load weights into initialized variables in post-build" without using objects from the low-level API like tf.Variables

but it's actually not clear to me whether this would even be necessary with a Keras model (could I just build and compile the whole thing within __call__?) or, alternatively, whether it actually works... you can see I got an error when I tried to instantiate the simulator, and I'm not sure where I went wrong.

But I thought I'd bring up the general idea here before I fight with it more.

drasmuss commented on June 19, 2024

Sorry I haven't had a chance to take a look at this yet. Just wanted to let you know it's on my plate; I've just been busy with other things! I'm hoping to get to it by the end of the week (end of next week at the latest).

NickleDave commented on June 19, 2024

Completely understand, no worries

drasmuss commented on June 19, 2024

Incorporated those suggestions into #77, and had a chance to take a look at your tutorial. I definitely like the idea to show how to insert a keras model into NengoDL. I think I would probably try to incorporate that right into the pretrained-model tutorial (showing two similar-but-different ways of accomplishing those goals is a nice way to reinforce the concepts). I can take care of that reworking of the pretrained-model tutorial afterwards though.

As to the error you're getting, I'm guessing that's because model.predict(images) expects images to be a numpy array, whereas in this case it is a tf.Tensor. I don't use keras myself so I'm just basing that on what I see in the documentation. I'm not sure how this is done in keras, but what you want is to build the operations represented in self.model into the TensorFlow graph (as opposed to model.predict, which takes a built model and runs an inference pass with some inputs).

NickleDave commented on June 19, 2024

Glad to hear you found the suggestions helpful, thanks for incorporating them.

And thanks for pointing out the source of the error--I should've read the traceback a little more carefully. You must be right. It's not obvious to me whether a Keras model will behave as one would hope and return a tensor if I just add it like any other op in the graph built within the __call__ function, i.e. if I do something like

def __call__(self, t, x):
    images = tf.reshape(x, (-1,) + image_shape)
    return self.model(images)

Trying to meet a deadline at my day job, but I hope to have a chance to get back to this on the weekend.

NickleDave commented on June 19, 2024

I had another go at this, and managed to get the __call__ function to return a Tensor from the Keras model, but it still crashes when I run the simulator. I think it's because Keras builds the model in a separate graph from the one Nengo-DL is using.

Looks like I'm not the first person to run into this error:
https://stackoverflow.com/questions/51588186/keras-tensorflow-typeerror-cannot-interpret-feed-dict-key-as-tensor
keras-team/keras#6462
but all of the solutions there look either really fragile and/or like they would cause issues for Nengo-DL.

drasmuss commented on June 19, 2024

Keras uses the default session when loading weights, so if we want to load weights into the right graph/session we need to set the default session to the one NengoDL is using. You can do this in your notebook via

def post_build(self, sess, rng):
    # load checkpoint file into model
    with sess.as_default():
        self.model.load_weights(model_weights)

That does make me think though that maybe we should do that by default (set the default session before calling the post_build function), so that this would just work automatically.

drasmuss commented on June 19, 2024

It was a quick change so I went ahead and updated it so that the default session is set as described above. So now your original code should work, without the with sess.as_default() addition.

NickleDave commented on June 19, 2024

Ah, of course--there was a reason the sess got passed. Just out of curiosity: is there a case when you would not want to set the session the simulator is using as the default?

I removed that line and fleshed out the rest of the notebook, and went ahead and submitted a PR; trying to get it to pass CI checks now. Happy to hear any feedback, including "maybe this is too much to be a separate thing and I'll just use Keras in the pre-trained notebook". And/or, if it looks good, I can try to squash the commits; I tend to get a little too granular sometimes, I think.

drasmuss commented on June 19, 2024

Just out of curiosity: is there a case when you would not want to set the session the simulator is using as the default?

It seems unlikely, but in case someone did they could basically use the same idea as above to override the nengo-dl default (i.e. with my_special_session.as_default(): <do stuff>).

I'm going to close this issue and continue the discussion in the PR (#84), just so that we don't get confused and end up talking in two different places 😄 .
