
Comments (5)

wzzhu commented on June 6, 2024

Hi @wzzhu, can you suggest a library for doing RNNs (LSTM, GRU), please?

I'd like to correct one wrong statement I made above: "In addition, gorgonia has no real capability of doing RNN as it is using a static graph for computation." If we do not follow the CharRNN example's implementation (which does not share weights), we can implement an RNN directly by reusing shared weight nodes to construct an unrolled graph for a fixed sequence length. So Gorgonia can still do RNN, on the condition that we build the computation graph with shared weight nodes and loop over each input in the sequence.

Therefore we can still use Gorgonia for RNN, but it requires a fair amount of manual implementation.

e.g.,

g := gorgonia.NewGraph()

// the input sequence (one row per time step) and the target sequence
x := gorgonia.NewMatrix(g, tensor.Float64, gorgonia.WithShape(seqLen, inputSize), gorgonia.WithName("x"))
yTrue := gorgonia.NewMatrix(g, tensor.Float64, gorgonia.WithShape(seqLen, outputSize), gorgonia.WithName("yTrue"))
// the initial hidden state
h0 := gorgonia.NewVector(g, tensor.Float64, gorgonia.WithShape(hiddenSize), gorgonia.WithName("h0"))

// the input-to-hidden weights
Wxh := gorgonia.NewMatrix(g, tensor.Float64, gorgonia.WithShape(inputSize, hiddenSize), gorgonia.WithName("Wxh"))
// the hidden-to-hidden weights
Whh := gorgonia.NewMatrix(g, tensor.Float64, gorgonia.WithShape(hiddenSize, hiddenSize), gorgonia.WithName("Whh"))
// the hidden-to-output weights
Why := gorgonia.NewMatrix(g, tensor.Float64, gorgonia.WithShape(hiddenSize, outputSize), gorgonia.WithName("Why"))
// the hidden biases
bh := gorgonia.NewVector(g, tensor.Float64, gorgonia.WithShape(hiddenSize), gorgonia.WithName("bh"))
// the output biases
by := gorgonia.NewVector(g, tensor.Float64, gorgonia.WithShape(outputSize), gorgonia.WithName("by"))

// Define the forward pass. The same weight nodes are reused at every time step,
// so the unrolled graph shares its weights across the whole sequence.
hs := make([]*gorgonia.Node, seqLen)
ys := make([]*gorgonia.Node, seqLen)
losses := make([]*gorgonia.Node, seqLen)
hPrev := h0
for i := 0; i < seqLen; i++ {
    // Slice out the i-th input and target rows
    xi := gorgonia.Must(gorgonia.Slice(x, gorgonia.S(i)))
    yi := gorgonia.Must(gorgonia.Slice(yTrue, gorgonia.S(i)))

    // Compute the input-to-hidden activations
    hI := gorgonia.Must(gorgonia.Mul(xi, Wxh))

    // Compute the hidden-to-hidden activations
    hH := gorgonia.Must(gorgonia.Mul(hPrev, Whh))

    // Compute the total hidden pre-activations (Add is binary, so nest it)
    hTotal := gorgonia.Must(gorgonia.Add(gorgonia.Must(gorgonia.Add(hI, hH)), bh))

    // Compute the hidden activations
    hs[i] = gorgonia.Must(gorgonia.Tanh(hTotal))

    // Compute the output activations
    oI := gorgonia.Must(gorgonia.Mul(hs[i], Why))
    ys[i] = gorgonia.Must(gorgonia.Sigmoid(gorgonia.Must(gorgonia.Add(oI, by))))

    // Per-step mean squared error
    diff := gorgonia.Must(gorgonia.Sub(ys[i], yi))
    losses[i] = gorgonia.Must(gorgonia.Mean(gorgonia.Must(gorgonia.Square(diff))))

    hPrev = hs[i]
}

// Total loss is the sum of the per-step losses
loss := losses[0]
for i := 1; i < seqLen; i++ {
    loss = gorgonia.Must(gorgonia.Add(loss, losses[i]))
}

// Build the symbolic gradients of the loss w.r.t. the shared weights
if _, err := gorgonia.Grad(loss, Wxh, Whh, Why, bh, by); err != nil {
    log.Fatal(err)
}

solver := gorgonia.NewRMSPropSolver(gorgonia.WithLearnRate(0.001))

learnables := gorgonia.Nodes{Wxh, Whh, Why, bh, by}
vm := gorgonia.NewTapeMachine(g, gorgonia.BindDualValues(learnables...))
defer vm.Close()

for i := 0; i < numEpochs; i++ {
    // Bind this epoch's input and target values, then run the forward/backward pass
    gorgonia.Let(x, ...)
    gorgonia.Let(yTrue, ...)
    if err := vm.RunAll(); err != nil {
        log.Fatal(err)
    }
    // Update the shared weights with the accumulated gradients
    if err := solver.Step(gorgonia.NodesToValueGrads(learnables)); err != nil {
        log.Fatal(err)
    }
    vm.Reset()
}
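As a follow-up, here is a minimal sketch of how the weights could be initialized and how one epoch's data could be bound to the placeholders. The WithInit option and the xData backing slice are assumptions for illustration, not part of the snippet above:

// Hypothetical: the same Wxh constructor as above, plus a Glorot initializer
Wxh := gorgonia.NewMatrix(g, tensor.Float64,
    gorgonia.WithShape(inputSize, hiddenSize),
    gorgonia.WithName("Wxh"),
    gorgonia.WithInit(gorgonia.GlorotN(1.0)))

// Hypothetical: xData is a []float64 of length seqLen*inputSize holding one epoch's inputs
xVal := tensor.New(tensor.WithShape(seqLen, inputSize), tensor.WithBacking(xData))
if err := gorgonia.Let(x, xVal); err != nil {
    log.Fatal(err)
}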


artheus commented on June 6, 2024

I am getting the same error running charRNN with gorgonia.org/[email protected]


wzzhu commented on June 6, 2024

The problem was caused by sampling from a log-probability distribution instead of a normal probability distribution. Since log probs are negative numbers, the sampling will always go out of bounds, because the cumulative sum can never exceed a positive threshold between 0 and 1.

Updating the sample func in examples/charRNN/util.go to convert the log probs back to normal probs will suppress this error.


func sample(val Value) int {
	var t tensor.Tensor
	var ok bool
	if t, ok = val.(tensor.Tensor); !ok {
		panic("Expects a tensor")
	}

	// It is logprob, convert it back to prob
	t2, err := tensor.Exp(t, tensor.AsSameType())
	if err != nil {
		panic("Error converting to prob")
	}

	return tensor.SampleIndex(t2)
}
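To make the failure mode concrete: a typical multinomial sampler draws a uniform r in [0, 1) and walks the cumulative sum of the probabilities until it passes r. With raw log probs every entry is negative, so the running sum can never reach r and the loop falls off the end of the distribution. A minimal sketch of such a sampler, as a hypothetical helper using math/rand (not the actual tensor.SampleIndex implementation):

// sampleFromProbs draws an index from a probability distribution by
// accumulating probabilities until the running sum exceeds a uniform draw.
// With log probs (all negative), acc never reaches r, so without the final
// guard the loop would fall through and produce an out-of-range index.
func sampleFromProbs(probs []float64) int {
	r := rand.Float64()
	acc := 0.0
	for i, p := range probs {
		acc += p
		if acc >= r {
			return i
		}
	}
	return len(probs) - 1 // guard against rounding error
}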

However, there are other errors that will still make the example crash.

In addition, gorgonia has no real capability of doing RNN, as it uses a static graph for computation.
It doesn't have dynamic unrolling to support BPTT (Back Propagation Through Time). Therefore the CharRNN example behaves just like a feed-forward network: it cannot share the LSTM's internal weights the way the Python counterpart in Karpathy's example at https://gist.github.com/karpathy/d4dee566867f8291f086 does, so it cannot capture the characteristics of the input sequence.

To really support LSTM or other RNNs, it may be necessary to create a new Op for RNNs that is capable of doing BPTT.


chirmicci commented on June 6, 2024

Hi @wzzhu, can you suggest a library for doing RNNs (LSTM, GRU), please?


chirmicci commented on June 6, 2024

@wzzhu sorry for asking so many questions,
but where can I find an example of how to build an NER model in Gorgonia?
Have you perhaps already done something like that?

