tensorflow-haskell-deptyped

This repo is dedicated to experimenting with adding dependent types to TensorFlow Haskell.

Beware! The API is not yet stable. This code is in alpha stage.

This repo may be merged into TensorFlow Haskell in the future. In the meantime, it is just a playground for testing dependent types with TensorFlow.

How to run

Make sure the required GHC toolchain is installed:

stack setup

Build the project:

stack build

Run the examples:

stack exec -- tensorflow-haskell-deptyped
stack exec -- tf-example

Examples

Some simple examples can be found in Main.hs.

A more complete example using MNIST can be found in tensorflow-minst-deptyped.

Following the presentation shown in https://github.com/tensorflow/haskell, below is a minimal example of using TensorFlow in Haskell (with dependent types) [full code]:

{-# LANGUAGE DataKinds, TypeApplications, ScopedTypeVariables #-}

import Control.Monad (replicateM_)
import System.Random (randomIO)
import Test.HUnit (assertBool)

import qualified TensorFlow.Core as TF
import qualified TensorFlow.Minimize as TF

import qualified TensorFlow.DepTyped as TFD
import           Data.Vector.Sized (Vector)
import qualified Data.Vector.Sized as VS (replicateM, map)

import           GHC.TypeLits (KnownNat)

main :: IO ()
main = do
    -- Generate data where `y = x*3 + 8`.
    xData <- VS.replicateM @100 randomIO
    let yData = VS.map (\x->x*3 + 8) xData
    -- Fit linear regression model.
    (w, b) <- fit xData yData
    assertBool "w == 3" (abs (3 - w) < 0.001)
    assertBool "b == 8" (abs (8 - b) < 0.001)

fit :: forall n. KnownNat n => Vector n Float -> Vector n Float -> IO (Float, Float)
fit xData yData = TF.runSession $ do
    -- Create tensorflow constants for x and y.
    let x = TFD.constant @'[n] xData
        y = TFD.constant @'[n] yData
    -- Create scalar variables for slope and intercept.
    w <- TFD.initializedVariable @'[1] 0
    b <- TFD.initializedVariable @'[1] 0
    -- Define the loss function.
    let yHat = (x `TFD.mul` TFD.readValue w) `TFD.add` TFD.readValue b
        loss = TFD.square (yHat `TFD.sub` y)
    -- Optimize with gradient descent.
    trainStep <- TFD.minimizeWith (TF.gradientDescent 0.001) loss [TFD.unVariable w, TFD.unVariable b]
    replicateM_ 1000 $ do
      () <- TFD.run trainStep -- this is necessary for haskell to select the right instance of `TFD.run`
      return ()               -- alternatively, you could annotate `replicateM_` with `Int -> IO () -> IO ()`
    -- Return the learned parameters.
    TF.Scalar w' <- TFD.run (TFD.readValue w)
    TF.Scalar b' <- TFD.run (TFD.readValue b)
    return (w', b')
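
The example above tags every tensor with a type-level shape such as `'[n]`, so operations on mismatched shapes are rejected by the compiler rather than failing at run time. The following is a minimal sketch of that idea using only `base`. The names `Vec`, `mkVec` and `addVec` are hypothetical and are not part of `TensorFlow.DepTyped`; the library's actual representation may differ.

```haskell
{-# LANGUAGE DataKinds, KindSignatures, ScopedTypeVariables, TypeApplications #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Hypothetical stand-in for a shape-indexed tensor: a plain list tagged
-- with its length at the type level. Not the TensorFlow.DepTyped type.
newtype Vec (n :: Nat) a = Vec [a]
  deriving (Eq, Show)

-- Smart constructor: accepts the list only if its runtime length
-- matches the length promised by the type.
mkVec :: forall n a. KnownNat n => [a] -> Maybe (Vec n a)
mkVec xs
  | length xs == fromIntegral (natVal (Proxy :: Proxy n)) = Just (Vec xs)
  | otherwise = Nothing

-- Element-wise addition. Both arguments share the same `n`, so adding
-- vectors of different lengths is a compile-time type error.
addVec :: Num a => Vec n a -> Vec n a -> Vec n a
addVec (Vec xs) (Vec ys) = Vec (zipWith (+) xs ys)

main :: IO ()
main = do
  let a = mkVec @3 [1, 2, 3 :: Int]
      b = mkVec @3 [10, 20, 30]
  print (addVec <$> a <*> b)        -- prints: Just (Vec [11,22,33])
  print (mkVec @3 [1, 2 :: Int])    -- prints: Nothing (length mismatch)
  -- The next line would not compile, because 3 and 2 are different types:
  -- print (addVec <$> a <*> mkVec @2 [1, 2])
```

In this sketch the length check happens once, at construction. After that, `addVec` needs no runtime check because the type already guarantees that both arguments have the same length.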

LICENSE

This project is dual licensed under Apache 2.0 and BSD3.

Contributors

helq, mluszczyk
