Comments (6)
For now there isn't much benefit in separating the optimizer from the module, even though it would be quite easy to do. This is because the code is 100% generated and the method signature makes sense: updating a module takes a mutable reference to an optimizer and the gradients.
We have to keep in mind that the Module trait declares how the parameters are serialized, deserialized, and updated, not how they are used. It's all about handling the state without any boilerplate code (all generated). At some point, it might be good to have a more generic trait to update the weights without an explicit dependency on the Gradients type and the Optimizer trait.
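To make the coupling concrete, here is a minimal sketch of what such a signature looks like. This is a hypothetical illustration, not Burn's actual API: the Gradients, Optimizer, Module, Linear, and Sgd names and signatures below are all assumptions made for the example.

```rust
// Hypothetical sketch (not Burn's actual API): the module's update method
// depends directly on the Gradients type and the Optimizer trait.
pub struct Gradients(pub Vec<f32>);

pub trait Optimizer {
    fn step(&mut self, param: &mut f32, grad: f32);
}

pub trait Module {
    // Updating requires a mutable reference to an optimizer and the gradients.
    fn update_params<O: Optimizer>(&mut self, grads: &Gradients, optim: &mut O);
}

pub struct Linear {
    pub weight: Vec<f32>,
}

impl Module for Linear {
    fn update_params<O: Optimizer>(&mut self, grads: &Gradients, optim: &mut O) {
        // Apply one optimizer step per parameter/gradient pair.
        for (w, &g) in self.weight.iter_mut().zip(grads.0.iter()) {
            optim.step(w, g);
        }
    }
}

pub struct Sgd {
    pub lr: f32,
}

impl Optimizer for Sgd {
    fn step(&mut self, param: &mut f32, grad: f32) {
        // Plain gradient descent: w <- w - lr * g.
        *param -= self.lr * grad;
    }
}
```

A more generic trait would abstract over both Gradients and Optimizer so the module no longer names them explicitly.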
from burn.
Hi,
Burn is still different from PyTorch in its module implementation. The Module trait and derive macro only define a state with its parameters and potentially other fields.
Computing the forward pass doesn't change the state of the module, and you are free to implement it as you choose. It doesn't have to be a method attached to the struct; you can pass the state to other functions if you prefer.
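For example, a forward pass can be a free function that only borrows the module state. This is a hypothetical sketch under that assumption; the Linear struct and forward function here are invented for illustration, not Burn's API.

```rust
// Hypothetical sketch: the forward pass as a free function that borrows the
// module state immutably, rather than a method that mutates it.
pub struct Linear {
    pub weight: f32,
    pub bias: f32,
}

pub fn forward(module: &Linear, x: f32) -> f32 {
    // Reads the state, never mutates it.
    module.weight * x + module.bias
}
```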
In your example, it seems that the Linear struct only contains the hyper-parameters, and the params are owned by the LinearParams struct, which is created by a method on the hyper-parameter struct. You could replicate that pattern with Burn if you wanted to by attaching the forward pass to the config struct as well.
I don't think this is a superior design, because you may need to pass both the state and the hyper-parameter struct around instead of just one, which increases the number of arguments your functions take for no apparent benefit.
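A rough sketch of the pattern being described, to show where the extra argument comes from. The LinearConfig and LinearParams names and the init/forward methods are assumptions made for this illustration, not any library's real API.

```rust
// Hypothetical sketch: the hyper-parameter (config) struct builds a separate
// parameter struct, and the forward pass is attached to the config, so both
// structs must be passed around together.
pub struct LinearConfig {
    pub d_input: usize,
    pub d_output: usize,
}

pub struct LinearParams {
    pub weight: Vec<f32>, // row-major, d_output rows of d_input columns
}

impl LinearConfig {
    pub fn init(&self) -> LinearParams {
        LinearParams {
            weight: vec![0.0; self.d_input * self.d_output],
        }
    }

    // Forward needs the config AND the params: two arguments instead of one.
    pub fn forward(&self, params: &LinearParams, x: &[f32]) -> Vec<f32> {
        (0..self.d_output)
            .map(|o| {
                (0..self.d_input)
                    .map(|i| params.weight[o * self.d_input + i] * x[i])
                    .sum()
            })
            .collect()
    }
}
```

With the Module approach, the config is consumed at construction time and callers only ever handle one value.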
This is still interesting to think about, and there is no apparent limitation on how the Module trait is defined. Note that I just removed the Forward trait from Burn because there seems to be no need for an abstraction to define the forward pass. At the moment these are just plain, simple methods on modules.
You can also call them like functions:

```rust
let params = Forward::new(&config);
Linear::forward(params, x);
```
There are still some mutable methods on modules, namely updating the parameters with an optimizer and updating the current device. It might be better if those methods returned the updated state instead, like pretty much all tensor functions do. Do you think it would be an improvement?
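A minimal sketch of the functional alternative being proposed, assuming an owned-update style; the Linear struct and update signature are hypothetical, invented for this example.

```rust
// Hypothetical sketch: update consumes the module and returns the new state,
// mirroring how tensor functions return new tensors instead of mutating
// in place.
pub struct Linear {
    pub weight: f32,
}

impl Linear {
    pub fn update(self, grad: f32, lr: f32) -> Self {
        // Build and return the updated module instead of taking &mut self.
        Linear {
            weight: self.weight - lr * grad,
        }
    }
}
```

The same pattern would apply to device placement: a to_device method could consume the module and return the relocated one.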
Yes, I think that goes in the direction I imagined. It might be desirable if the Module were independent of the optimizer.
@nathanielsimard does your recent work overhauling the training and module APIs cover anything in this request? Since you had much discussion with James about the functional aspects of Burn, could you please tell us what we will and will not do? With that, we can reach a resolution and either close this issue or keep it open if some parts remain valid. Regardless, we should update this ticket while the work is recent and fresh.
I think we should close this issue, but it doesn't mean there isn't any more work to be done to improve the API. However, this issue is probably too vague to ever be considered complete. We might open more focused issues that link to this one for continuity.