Comments (5)
Let me correct one wrong statement above: "In addition, gorgonia has no real capability of doing RNN as it is using a static graph for computation." Instead of following the CharRNN implementation, which doesn't share weights, we can implement an RNN directly by reusing the same weight nodes to construct an unrolled graph for a fixed sequence length. So gorgonia can still do RNNs, provided we construct the computation graph with shared weight nodes and loop over each input in the sequence.
Therefore we can still use gorgonia for RNNs, but it takes a fair amount of manual implementation.
For example (a sketch, with seqLen, inputSize, hiddenSize, outputSize and numEpochs defined elsewhere):
g := gorgonia.NewGraph()
// x holds the input sequence, one row per time step
x := gorgonia.NewMatrix(g, tensor.Float64, gorgonia.WithShape(seqLen, inputSize), gorgonia.WithName("x"))
// yTrue holds the targets, flattened to match the concatenated outputs
yTrue := gorgonia.NewVector(g, tensor.Float64, gorgonia.WithShape(seqLen*outputSize), gorgonia.WithName("yTrue"))
// the initial hidden state (all zeros)
hPrev := gorgonia.NewVector(g, tensor.Float64, gorgonia.WithShape(hiddenSize), gorgonia.WithName("h0"), gorgonia.WithInit(gorgonia.Zeroes()))
// the input-to-hidden weights
Wxh := gorgonia.NewMatrix(g, tensor.Float64, gorgonia.WithShape(inputSize, hiddenSize), gorgonia.WithName("Wxh"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
// the hidden-to-hidden weights
Whh := gorgonia.NewMatrix(g, tensor.Float64, gorgonia.WithShape(hiddenSize, hiddenSize), gorgonia.WithName("Whh"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
// the hidden-to-output weights
Why := gorgonia.NewMatrix(g, tensor.Float64, gorgonia.WithShape(hiddenSize, outputSize), gorgonia.WithName("Why"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
// the hidden biases
bh := gorgonia.NewVector(g, tensor.Float64, gorgonia.WithShape(hiddenSize), gorgonia.WithName("bh"), gorgonia.WithInit(gorgonia.Zeroes()))
// the output biases
by := gorgonia.NewVector(g, tensor.Float64, gorgonia.WithShape(outputSize), gorgonia.WithName("by"), gorgonia.WithInit(gorgonia.Zeroes()))
// Define the forward pass, unrolled over the fixed sequence length.
// Every step reuses the same weight nodes.
ys := make([]*gorgonia.Node, 0, seqLen)
for i := 0; i < seqLen; i++ {
	// Take the i-th input of the sequence
	xI := gorgonia.Must(gorgonia.Slice(x, gorgonia.S(i)))
	// Compute the input-to-hidden activations
	hI := gorgonia.Must(gorgonia.Mul(xI, Wxh))
	// Compute the hidden-to-hidden activations
	hH := gorgonia.Must(gorgonia.Mul(hPrev, Whh))
	// Compute the total hidden activations (Add takes two nodes at a time)
	hTotal := gorgonia.Must(gorgonia.Add(gorgonia.Must(gorgonia.Add(hI, hH)), bh))
	// Compute the hidden activations
	h, err := gorgonia.Tanh(hTotal)
	if err != nil {
		log.Fatal(err)
	}
	// Compute the output activations
	oI := gorgonia.Must(gorgonia.Mul(h, Why))
	yI, err := gorgonia.Sigmoid(gorgonia.Must(gorgonia.Add(oI, by)))
	if err != nil {
		log.Fatal(err)
	}
	ys = append(ys, yI)
	hPrev = h
}
// Join the per-step outputs and compare them with the targets
y := gorgonia.Must(gorgonia.Concat(0, ys...))
loss := gorgonia.Must(gorgonia.Mean(gorgonia.Must(gorgonia.Square(gorgonia.Must(gorgonia.Sub(y, yTrue))))))
// Add the gradient nodes for the shared weights to the graph
if _, err := gorgonia.Grad(loss, Wxh, Whh, Why, bh, by); err != nil {
	log.Fatal(err)
}
solver := gorgonia.NewRMSPropSolver(gorgonia.WithLearnRate(0.001))
vm := gorgonia.NewTapeMachine(g, gorgonia.BindDualValues(Wxh, Whh, Why, bh, by))
defer vm.Close()
for i := 0; i < numEpochs; i++ {
	// Bind the input sequence and its targets
	gorgonia.Let(x, ...)
	gorgonia.Let(yTrue, ...)
	// Perform the forward pass and compute the loss and gradients
	if err := vm.RunAll(); err != nil {
		log.Fatal(err)
	}
	// Back prop: update the shared weights
	if err := solver.Step(gorgonia.NodesToValueGrads(gorgonia.Nodes{Wxh, Whh, Why, bh, by})); err != nil {
		log.Fatal(err)
	}
	vm.Reset()
}
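To see the recurrence on its own, without any gorgonia dependency, here is a minimal plain-Go sketch of the same unrolled forward pass. The `rnnParams` struct and the `vecMat` and `forward` helpers are hypothetical names made up for this illustration; only the formula h_t = tanh(x_t*Wxh + h_{t-1}*Whh + bh) comes from the snippet above.

```go
package main

import (
	"fmt"
	"math"
)

// rnnParams holds the weights that every time step shares.
// The struct is hypothetical; the field names mirror the sketch above.
type rnnParams struct {
	Wxh [][]float64 // inputSize x hiddenSize
	Whh [][]float64 // hiddenSize x hiddenSize
	bh  []float64   // hiddenSize
}

// vecMat returns v * m for a row vector v and a matrix m.
func vecMat(v []float64, m [][]float64) []float64 {
	out := make([]float64, len(m[0]))
	for i, vi := range v {
		for j, mij := range m[i] {
			out[j] += vi * mij
		}
	}
	return out
}

// forward unrolls h_t = tanh(x_t*Wxh + h_{t-1}*Whh + bh) over the whole
// sequence, starting from a zero hidden state, and returns every h_t.
func forward(p rnnParams, xs [][]float64) [][]float64 {
	hPrev := make([]float64, len(p.bh))
	hs := make([][]float64, 0, len(xs))
	for _, x := range xs {
		hI := vecMat(x, p.Wxh)
		hH := vecMat(hPrev, p.Whh)
		h := make([]float64, len(p.bh))
		for j := range h {
			h[j] = math.Tanh(hI[j] + hH[j] + p.bh[j])
		}
		hs = append(hs, h)
		hPrev = h
	}
	return hs
}

func main() {
	p := rnnParams{
		Wxh: [][]float64{{0.5}},
		Whh: [][]float64{{1.0}},
		bh:  []float64{0},
	}
	// The same input at both steps gives different hidden states,
	// because the second step also sees the first step's state.
	hs := forward(p, [][]float64{{1}, {1}})
	fmt.Printf("h1=%.4f h2=%.4f\n", hs[0][0], hs[1][0])
}
```

The point of the example is that `p` is passed once and reused at every step, which is what the shared weight nodes achieve in the graph version.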
from gorgonia.
I am getting the same error when running charRNN with gorgonia.org/gorgonia@…
from gorgonia.
The problem is caused by sampling from log probabilities instead of ordinary probabilities. Since log probs are never positive, their running sum can never exceed a random threshold drawn between 0 and 1, so the sampling always runs out of bounds.
Updating the sample func in examples/charRNN/util.go to convert the log probs back to ordinary probs suppresses this error:
func sample(val Value) int {
	var t tensor.Tensor
	var ok bool
	if t, ok = val.(tensor.Tensor); !ok {
		panic("Expects a tensor")
	}
	// It is logprob, convert it back to prob
	t2, err := tensor.Exp(t, tensor.AsSameType())
	if err != nil {
		panic("Error converting to prob")
	}
	return tensor.SampleIndex(t2)
}
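The failure mode can be reproduced in a few lines of plain Go. The `sampleIndex` function below is a simplified, hypothetical stand-in for cumulative-sum sampling, not the tensor package's actual code, and it takes the random threshold as a parameter so the run is reproducible.

```go
package main

import (
	"fmt"
	"math"
)

// sampleIndex walks the cumulative sum of weights and returns the first
// index whose running total exceeds the threshold r (normally drawn
// uniformly from [0, 1)). It returns -1 when no index qualifies.
// Simplified stand-in for illustration only.
func sampleIndex(weights []float64, r float64) int {
	var sum float64
	for i, w := range weights {
		sum += w
		if sum > r {
			return i
		}
	}
	return -1
}

func main() {
	probs := []float64{0.1, 0.7, 0.2}
	logProbs := make([]float64, len(probs))
	for i, p := range probs {
		logProbs[i] = math.Log(p)
	}

	r := 0.5 // fixed threshold instead of a random draw

	// Log probs are all negative, so the running sum never exceeds r.
	fmt.Println(sampleIndex(logProbs, r)) // prints -1

	// Exponentiating recovers the probabilities, which sum to 1.
	recovered := make([]float64, len(logProbs))
	for i, lp := range logProbs {
		recovered[i] = math.Exp(lp)
	}
	fmt.Println(sampleIndex(recovered, r)) // prints 1
}
```

With the log probs no index is ever selected, which corresponds to the out-of-bounds behaviour described above. After exponentiation the running sum passes 0.5 at index 1.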
However, there are other errors that still make the example crash.
In addition, gorgonia has no real capability of doing RNN as it is using a static graph for computation.
It doesn't have dynamic unrolling to support BPTT (Back Propagation Through Time). Therefore the CharRNN example behaves just like a feed-forward network: unlike its Python counterpart in Karpathy's example at https://gist.github.com/karpathy/d4dee566867f8291f086, it cannot share the LSTM internal weights across time steps to capture the characteristics of the input sequence.
To really support LSTM or other RNNs, gorgonia may need a new Op for RNNs that is capable of doing BPTT.
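For readers unfamiliar with BPTT, the following is a minimal plain-Go sketch of what it computes for a shared weight. It assumes a toy scalar recurrence h_t = tanh(w*h_{t-1} + x_t) with the loss defined as the last hidden state; the recurrence, the loss, and the function names `forward` and `bptt` are all assumptions chosen for illustration and have nothing to do with gorgonia's internals.

```go
package main

import (
	"fmt"
	"math"
)

// forward runs h_t = tanh(w*h_{t-1} + x_t) from h_0 = 0 and returns
// every hidden state, with hs[0] being the initial state.
func forward(w float64, xs []float64) []float64 {
	hs := make([]float64, len(xs)+1)
	for t, x := range xs {
		hs[t+1] = math.Tanh(w*hs[t] + x)
	}
	return hs
}

// bptt returns dL/dw for the loss L = h_T (the last hidden state).
// Because w is shared, its gradient is the sum of the contributions
// from every time step, collected while walking the sequence backwards.
func bptt(w float64, xs []float64) float64 {
	hs := forward(w, xs)
	dh := 1.0 // dL/dh_T
	dw := 0.0
	for t := len(xs); t >= 1; t-- {
		da := dh * (1 - hs[t]*hs[t]) // back through tanh
		dw += da * hs[t-1]           // this step's share of dL/dw
		dh = da * w                  // pass the gradient on to h_{t-1}
	}
	return dw
}

func main() {
	w, xs := 0.8, []float64{0.5, -0.3, 0.9}

	// Central finite difference as an independent check.
	const eps = 1e-6
	last := func(w float64) float64 {
		hs := forward(w, xs)
		return hs[len(hs)-1]
	}
	numeric := (last(w+eps) - last(w-eps)) / (2 * eps)

	fmt.Printf("bptt=%.6f numeric=%.6f\n", bptt(w, xs), numeric)
}
```

The `dw +=` line is the part that requires weight sharing: if each step had its own copy of `w`, there would be nothing to accumulate across time.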
from gorgonia.
Hi @wzzhu, can you suggest a library for doing RNNs (LSTM, GRU), please?
from gorgonia.
@wzzhu sorry for asking so many questions,
but where can I find an example of how to build an NER model in Gorgonia?
Have you perhaps already done something like that?
from gorgonia.