tinypytorch

import torch
from tinypytorch.data import get_local_data
from tinypytorch.model import initialize_parameters, Model
from tinypytorch.metrics import accuracy

Dependencies:

  • nbdev
  • torch
  • matplotlib
  • pytest

Install

pip install tinypytorch

Train a neural network

Data

x_train, y_train, x_valid, y_valid = get_local_data()
x_train.shape, y_train.shape
(torch.Size([50000, 784]), torch.Size([50000]))
y_train[1]
tensor(0)
y_train[1:5]
tensor([0, 4, 1, 9])
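Each row of x_train has 784 columns. Assuming these are flattened 28 × 28 grayscale images (784 = 28 × 28; the README does not say which dataset get_local_data loads), a row can be viewed back as an image. This sketch uses random stand-in data instead of calling get_local_data:

```python
import torch

# Random stand-in for a few rows of x_train (assumption: each row is a flattened 28 x 28 image)
x_sample = torch.rand(8, 784)

img = x_sample[0].view(28, 28)  # reinterpret one 784-long row as a 28 x 28 grid, without copying
print(img.shape)                # torch.Size([28, 28])

# Flattening the grid again gives back the original row
assert torch.equal(img.flatten(), x_sample[0])
```

Since matplotlib is already a dependency, `matplotlib.pyplot.imshow(img.numpy(), cmap="gray")` can display such a row as a picture.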

Set hyperparameters and initialize parameters

n, m = x_train.shape # number of rows (samples) and columns (features)
c = y_train.max() + 1 # number of classes
n, m, c
(50000, 784, tensor(10))
nh = 50 # number of hidden units
w1, b1, w2, b2 = initialize_parameters(m, nh)
w1.shape, b1.shape
(torch.Size([784, 50]), torch.Size([50]))
w2.shape, b2.shape
(torch.Size([50, 1]), torch.Size([1]))
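The README does not show how initialize_parameters creates these tensors. One common choice, sketched here as an assumption (init_params_sketch is a hypothetical name, not part of tinypytorch), is to draw weights from a standard normal scaled by 1/sqrt(fan_in) and to start biases at zero. It reproduces the shapes printed above, including the single output unit in w2:

```python
import math
import torch

def init_params_sketch(m, nh, n_out=1):
    """Hypothetical stand-in for tinypytorch's initialize_parameters (not its actual code)."""
    # Standard-normal weights scaled by 1/sqrt(fan_in): for unit-variance inputs,
    # each linear output then starts with roughly unit variance
    w1 = torch.randn(m, nh) / math.sqrt(m)
    b1 = torch.zeros(nh)                      # biases start at zero
    w2 = torch.randn(nh, n_out) / math.sqrt(nh)
    b2 = torch.zeros(n_out)
    return w1, b1, w2, b2

torch.manual_seed(0)
w1, b1, w2, b2 = init_params_sketch(784, 50)
print(w1.shape, b1.shape)  # torch.Size([784, 50]) torch.Size([50])
print(w2.shape, b2.shape)  # torch.Size([50, 1]) torch.Size([1])
```

For layers followed by ReLU, Kaiming initialization uses sqrt(2 / fan_in) instead, to compensate for ReLU zeroing roughly half of the activations.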
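  • Training set (x_train) shape: (50000, 784)
  • First-layer weight (w1) shape: (784, 50)
  • First-layer bias (b1) shape: (50,)

The first layer (Lin) computes x_train @ w1 + b1, i.e. (50000, 784) @ (784, 50) + (50,), giving an output of shape (50000, 50); the bias is broadcast across all 50000 rows.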
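The shape arithmetic above is ordinary matrix multiplication plus broadcasting. The sketch below writes such a layer as a plain function (lin is a hypothetical name; tinypytorch's Lin may be implemented differently) and checks it against torch.nn.functional.linear, which expects the weight as (out_features, in_features) and therefore needs w transposed. It uses 1000 random rows instead of the full 50000 to keep memory low:

```python
import torch
import torch.nn.functional as F

def lin(x, w, b):
    # (n, m) @ (m, nh) -> (n, nh); b has shape (nh,) and is broadcast over all n rows
    return x @ w + b

torch.manual_seed(0)
x = torch.randn(1000, 784)               # stand-in for a slice of x_train
w = torch.randn(784, 50) / 784 ** 0.5    # scaled so outputs stay around unit size
b = torch.randn(50)

out = lin(x, w, b)
print(out.shape)  # torch.Size([1000, 50])

# F.linear computes x @ weight.T + bias, so passing w.t() gives the same result
assert torch.allclose(out, F.linear(x, w.t(), b), atol=1e-5)
```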

model = Model(w1, b1, w2, b2)
loss = model(x_train, y_train)
Model.__call__
l=<tinypytorch.model.Lin object>
Lin.forward
inp=torch.Size([50000, 784])
w=torch.Size([784, 50])
b=torch.Size([50])
output.shape=torch.Size([50000, 50])
x.shape=torch.Size([50000, 50])
Model.__call__
l=<tinypytorch.model.ReLU object>
x.shape=torch.Size([50000, 50])
Model.__call__
l=<tinypytorch.model.Lin object>
Lin.forward
inp=torch.Size([50000, 50])
w=torch.Size([50, 1])
b=torch.Size([1])
output.shape=torch.Size([50000, 1])
x.shape=torch.Size([50000, 1])
loss
tensor(26.1652)
model.backward()
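model.backward() presumably computes the gradients of w1, b1, w2 and b2 by hand. The README does not show the loss function; a mean squared error between the single output column and the integer labels is an assumption consistent with w2 having one output unit. Under that assumption, the sketch below spells out the chain rule for Lin → ReLU → Lin → MSE on random stand-in data and checks it against PyTorch's autograd. It is an illustration, not tinypytorch's code:

```python
import torch

torch.manual_seed(0)
n, m, nh = 256, 784, 50                    # small stand-in batch; nh matches the README
x = torch.randn(n, m)
y = torch.randint(0, 10, (n,)).float()     # integer labels used directly as regression targets
w1 = torch.randn(m, nh) / m ** 0.5
b1 = torch.zeros(nh)
w2 = torch.randn(nh, 1) / nh ** 0.5
b2 = torch.zeros(1)

def forward_and_backward(x, y, w1, b1, w2, b2):
    # Forward pass (assumed architecture): Lin -> ReLU -> Lin -> MSE
    l1 = x @ w1 + b1                         # (n, nh)
    a1 = l1.clamp_min(0.)                    # ReLU
    out = a1 @ w2 + b2                       # (n, 1)
    diff = out.squeeze(-1) - y               # (n,)
    loss = diff.pow(2).mean()
    # Backward pass: chain rule from the loss back to each parameter
    out_g = (2. * diff / y.shape[0]).unsqueeze(-1)  # dL/d(out), (n, 1)
    w2_g = a1.t() @ out_g                    # (nh, 1)
    b2_g = out_g.sum(0)                      # (1,)
    a1_g = out_g @ w2.t()                    # (n, nh)
    l1_g = a1_g * (l1 > 0).float()           # ReLU passes gradient only where its input was positive
    w1_g = x.t() @ l1_g                      # (m, nh)
    b1_g = l1_g.sum(0)                       # (nh,)
    return loss, (w1_g, b1_g, w2_g, b2_g)

loss, grads = forward_and_backward(x, y, w1, b1, w2, b2)

# Cross-check with autograd on fresh copies of the parameters
params = [p.clone().requires_grad_() for p in (w1, b1, w2, b2)]
pw1, pb1, pw2, pb2 = params
ref_loss = (((x @ pw1 + pb1).clamp_min(0.) @ pw2 + pb2).squeeze(-1) - y).pow(2).mean()
ref_loss.backward()
for g, p in zip(grads, params):
    assert torch.allclose(g, p.grad, atol=1e-4)
print("manual gradients match autograd")
```

Each gradient has the same shape as the parameter it belongs to, which is a quick sanity check when writing a backward pass by hand.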

Example 2: a mini-batch

bs = 64 # batch size
xb = x_train[0:bs] # first mini-batch: rows 0 to 63
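Slicing in steps of bs splits the whole training set into mini-batches. Because 50000 is not a multiple of 64, the last batch is smaller: 781 full batches cover 49984 rows, leaving 16. A sketch with random stand-in tensors of the README's shapes (get_batches is a hypothetical helper, not part of tinypytorch):

```python
import torch

def get_batches(x, y, bs):
    """Hypothetical helper: yield consecutive (xb, yb) mini-batches of at most bs rows."""
    for i in range(0, len(x), bs):
        # Slicing past the end is safe in PyTorch; the final batch is simply shorter
        yield x[i:i + bs], y[i:i + bs]

x_train = torch.rand(50000, 784)          # stand-in with the README's shape
y_train = torch.randint(0, 10, (50000,))  # stand-in labels 0..9
bs = 64

batches = list(get_batches(x_train, y_train, bs))
print(len(batches))          # 782
print(batches[0][0].shape)   # torch.Size([64, 784])
print(batches[-1][0].shape)  # torch.Size([16, 784])
```

Slices are views, so building the list does not copy the data. Shuffling the row order each epoch (for example with torch.randperm) is a common refinement.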
