
java-glm's Introduction

Generalized Linear Model implementation in Java

This package implements the generalized linear model (GLM) in Java.


Install

Add the following to the dependencies section of your pom file:

<dependency>
  <groupId>com.github.chen0040</groupId>
  <artifactId>java-glm</artifactId>
  <version>1.0.6</version>
</dependency>

Features

The current implementation of GLM supports as many distribution families as the glm package in R:

  • Normal
  • Exponential
  • Gamma
  • InverseGaussian
  • Poisson
  • Bernoulli
  • Binomial
  • Categorical
  • Multinomial

For the solvers, the current implementation of GLM supports a number of variants of the iteratively re-weighted least squares estimation algorithm:

  • IRLS
  • IRLS with QR factorization
  • IRLS with SVD factorization
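To make the idea behind these solvers concrete, below is a minimal, self-contained sketch of IRLS for a one-feature logistic regression in plain Java. It does not depend on java-glm at all; the toy data, iteration cap, and tolerance are made up for the illustration, and a production solver would replace the Cramer's-rule solve with a QR or SVD factorization, which is what the variants above provide.

```java
public class IrlsSketch {
    // Fit intercept b0 and slope b1 of a logistic model by IRLS.
    public static double[] fit() {
        double[] x = {-2, -1, -0.5, 0.5, 1, 2};
        double[] y = { 0,  0,    1,   0, 1, 1};   // overlapping classes, so the MLE is finite
        double b0 = 0, b1 = 0;
        for (int iter = 0; iter < 50; iter++) {
            // accumulate X^T W X (2x2) and X^T W z (2x1), with W = diag(mu * (1 - mu))
            double s00 = 0, s01 = 0, s11 = 0, t0 = 0, t1 = 0;
            for (int i = 0; i < x.length; i++) {
                double eta = b0 + b1 * x[i];               // linear predictor
                double mu = 1.0 / (1.0 + Math.exp(-eta));  // inverse link (sigmoid)
                double w = mu * (1 - mu);                  // IRLS weight
                double z = eta + (y[i] - mu) / w;          // working response
                s00 += w;     s01 += w * x[i];     s11 += w * x[i] * x[i];
                t0  += w * z; t1  += w * x[i] * z;
            }
            // solve the 2x2 weighted normal equations (Cramer's rule; toy-sized only)
            double det = s00 * s11 - s01 * s01;
            double nb0 = (s11 * t0 - s01 * t1) / det;
            double nb1 = (s00 * t1 - s01 * t0) / det;
            boolean converged = Math.abs(nb0 - b0) + Math.abs(nb1 - b1) < 1e-10;
            b0 = nb0; b1 = nb1;
            if (converged) break;
        }
        return new double[]{b0, b1};
    }

    public static void main(String[] args) {
        double[] beta = fit();
        System.out.println("intercept=" + beta[0] + " slope=" + beta[1]);
    }
}
```

Each iteration recomputes the weights mu * (1 - mu) and solves a weighted least squares problem; that re-weighted solve is the defining step of IRLS.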

Usage

Step 1: Create and train the GLM against the training data

Suppose you want to create a logistic regression model with GLM and train it against a data frame:

import com.github.chen0040.glm.solvers.Glm;
import com.github.chen0040.glm.enums.GlmSolverType;

DataFrame trainingData = loadTrainingData();

Glm glm = Glm.logistic();
glm.setSolverType(GlmSolverType.GlmIrls);
glm.fit(trainingData);

The "trainingData" is a data frame (please refer to this link on how to create a data frame from a file or from scratch).

The line "Glm.logistic()" creates the logistic regression model; other regression models can be created the same way (for example, calling "Glm.linear()" creates a linear regression model).

The line "glm.fit(..)" performs the GLM training.

Step 2: Use the trained regression model to predict on new data

The trained GLM can then be run on the testing data; below is a Java code example for logistic regression:

DataFrame testingData = loadTestingData();
for(int i = 0; i < testingData.rowCount(); ++i){
    boolean predicted = glm.transform(testingData.row(i)) > 0.5;
    boolean actual = testingData.row(i).target() > 0.5;
    System.out.println("predicted(Irls): " + predicted + "\texpected: " + actual);
}

The "testingData" is a data frame.

The line "glm.transform(..)" performs the regression.

Sample code

Sample code for linear regression

The sample code below shows a linear regression example:

DataQuery.DataFrameQueryBuilder schema = DataQuery.blank()
      .newInput("x1")
      .newInput("x2")
      .newOutput("y")
      .end();

// y = 4 + 0.5 * x1 + 0.2 * x2, where x1 ~ index and x2 ~ index^2
// (randn() draws a sample from the standard normal distribution)
Sampler.DataSampleBuilder sampler = new Sampler()
      .forColumn("x1").generate((name, index) -> randn() * 0.3 + index)
      .forColumn("x2").generate((name, index) -> randn() * 0.3 + index * index)
      .forColumn("y").generate((name, index) -> 4 + 0.5 * index + 0.2 * index * index + randn() * 0.3)
      .end();

DataFrame trainingData = schema.build();

trainingData = sampler.sample(trainingData, 200);

System.out.println(trainingData.head(10));

DataFrame crossValidationData = schema.build();

crossValidationData = sampler.sample(crossValidationData, 40);

Glm glm = Glm.linear();
glm.setSolverType(GlmSolverType.GlmIrlsQr);
glm.fit(trainingData);

for(int i = 0; i < crossValidationData.rowCount(); ++i){
    double predicted = glm.transform(crossValidationData.row(i));
    double actual = crossValidationData.row(i).target();
    System.out.println("predicted: " + predicted + "\texpected: " + actual);
}

System.out.println("Coefficients: " + glm.getCoefficients());

Sample code for logistic regression

The sample code below performs binary classification using logistic regression:

InputStream inputStream = new FileInputStream("heart_scale.txt");
DataFrame dataFrame = DataQuery.libsvm().from(inputStream).build();

for(int i=0; i < dataFrame.rowCount(); ++i){
    DataRow row = dataFrame.row(i);
    String targetColumn = row.getTargetColumnNames().get(0);
    row.setTargetCell(targetColumn, row.getTargetCell(targetColumn) == -1 ? 0 : 1); // change output from (-1, +1) to (0, 1)
}

TupleTwo<DataFrame, DataFrame> miniFrames = dataFrame.shuffle().split(0.9);
DataFrame trainingData = miniFrames._1();
DataFrame crossValidationData = miniFrames._2();

Glm algorithm = Glm.logistic();
algorithm.setSolverType(GlmSolverType.GlmIrlsQr);
algorithm.fit(trainingData);

// find the lowest predicted probability among the positive training rows
double threshold = 1.0;
for(int i = 0; i < trainingData.rowCount(); ++i){
    double prob = algorithm.transform(trainingData.row(i));
    if(trainingData.row(i).target() == 1 && prob < threshold){
        threshold = prob;
    }
}
System.out.println("threshold: " + threshold);


BinaryClassifierEvaluator evaluator = new BinaryClassifierEvaluator();

for(int i = 0; i < crossValidationData.rowCount(); ++i){
    double prob = algorithm.transform(crossValidationData.row(i));
    boolean predicted = prob > 0.5;
    boolean actual = crossValidationData.row(i).target() > 0.5;
    evaluator.evaluate(actual, predicted);
    System.out.println("probability of positive: " + prob);
    System.out.println("predicted: " + predicted + "\tactual: " + actual);
}

evaluator.report();
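A report like the one above is typically derived from confusion-matrix counts. The sketch below is a hypothetical, library-independent illustration of that bookkeeping; the class and method names are made up for the example and are not java-glm's actual BinaryClassifierEvaluator API.

```java
public class EvalSketch {
    int tp, fp, tn, fn;   // confusion-matrix counts

    // tally one (actual, predicted) pair into the confusion matrix
    public void evaluate(boolean actual, boolean predicted) {
        if (predicted) { if (actual) tp++; else fp++; }
        else           { if (actual) fn++; else tn++; }
    }

    public double accuracy()  { return (double) (tp + tn) / (tp + tn + fp + fn); }
    public double precision() { return (double) tp / (tp + fp); }
    public double recall()    { return (double) tp / (tp + fn); }

    public static void main(String[] args) {
        EvalSketch e = new EvalSketch();
        boolean[] actual    = {true, true,  false, false, true};
        boolean[] predicted = {true, false, false, true,  true};
        for (int i = 0; i < actual.length; i++) e.evaluate(actual[i], predicted[i]);
        System.out.println("accuracy=" + e.accuracy()
                + " precision=" + e.precision() + " recall=" + e.recall());
    }
}
```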

Sample code for multi-class classification

The sample code below performs multi-class classification, using the logistic regression model as the base binary classifier:

InputStream irisStream = FileUtils.getResource("iris.data");
DataFrame irisData = DataQuery.csv(",")
      .from(irisStream)
      .selectColumn(0).asNumeric().asInput("Sepal Length")
      .selectColumn(1).asNumeric().asInput("Sepal Width")
      .selectColumn(2).asNumeric().asInput("Petal Length")
      .selectColumn(3).asNumeric().asInput("Petal Width")
      .selectColumn(4).asCategory().asOutput("Iris Type")
      .build();

TupleTwo<DataFrame, DataFrame> parts = irisData.shuffle().split(0.9);

DataFrame trainingData = parts._1();
DataFrame crossValidationData = parts._2();

System.out.println(crossValidationData.head(10));

OneVsOneGlmClassifier multiClassClassifier = Glm.oneVsOne(Glm::logistic);
multiClassClassifier.fit(trainingData);

ClassifierEvaluator evaluator = new ClassifierEvaluator();

for(int i=0; i < crossValidationData.rowCount(); ++i) {
    String predicted = multiClassClassifier.classify(crossValidationData.row(i));
    String actual = crossValidationData.row(i).categoricalTarget();
    System.out.println("predicted: " + predicted + "\tactual: " + actual);
    evaluator.evaluate(actual, predicted);
}

evaluator.report();

Background on GLM

Introduction

A generalized linear model (GLM) models a response from the exponential family of distributions as b = g(a), where g is the inverse link function.

Therefore, for a regression characterized by inverse link function g, the problem can be formulated as looking for a model coefficient vector x in

$$g(A * x) = b + e$$

The objective is to find the x that minimizes

$$(g(A * x) - b).transpose * W * (g(A * x) - b)$$

If we assume that e consists of uncorrelated noise variables with identical variance, then W = sigma^(-2) * I, and the objective above reduces to the OLS form:

$$min || g(A * x) - b ||^2$$
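A quick numeric sanity check of this reduction, using a made-up residual vector: when W = sigma^(-2) * I, the weighted objective equals sigma^(-2) times the squared residual norm, so the two objectives share the same minimizer.

```java
public class OlsReduction {
    // r^T W r with W = sigma^{-2} * I (diagonal, so a simple weighted sum)
    public static double weightedObjective(double[] r, double sigma) {
        double s = 0;
        for (double ri : r) s += ri * (1.0 / (sigma * sigma)) * ri;
        return s;
    }

    // sigma^{-2} * ||r||^2, the scaled OLS objective
    public static double scaledOls(double[] r, double sigma) {
        double s = 0;
        for (double ri : r) s += ri * ri;
        return s / (sigma * sigma);
    }

    public static void main(String[] args) {
        double[] r = {1.0, -3.0, 0.5};   // made-up residual vector g(A*x) - b
        System.out.println(weightedObjective(r, 2.0) + " == " + scaledOls(r, 2.0));
    }
}
```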

Iteratively Re-weighted Least Squares estimation (IRLS)

In regression, we try to find a set of model coefficients such that:

$$A * x = b + e$$

Here A is known as the model matrix, b as the response vector, and e as the vector of error terms.

In OLS (Ordinary Least Squares), we assume that the variance-covariance matrix of the error terms is

$$V(e) = sigma^2 * W$$

where W is a symmetric positive definite (diagonal) matrix, and sigma is the standard error of e.

In OLS, the objective is to find x_bar such that e.transpose * W * e is minimized (note that since W is positive definite, e.transpose * W * e is always positive). In other words, we are looking for x_bar such that (A * x_bar - b).transpose * W * (A * x_bar - b) is minimized.

Let

$$y = (A * x - b).transpose * W * (A * x - b)$$

Now differentiating y with respect to x, we have

$$dy / dx = 2 * A.transpose * W * (A * x - b)$$

To find the minimum of y, set dy / dx = 0 at x = x_bar; we have

$$A.transpose * W * (A * x_bar - b) = 0$$

Rearranging this, we have

$$A.transpose * W * A * x_bar = A.transpose * W * b$$

Multiplying both sides by (A.transpose * W * A).inverse, we have

$$x_bar = (A.transpose * W * A).inverse * A.transpose * W * b$$

This system is commonly solved using IRLS.

The implementation of Glm in this package is based on iteratively re-weighted least squares estimation (IRLS).
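As a plain-Java sketch of the closed-form solution x_bar = (A.transpose * W * A).inverse * A.transpose * W * b (no dependency on java-glm; the feature values, weights, and responses are made-up example data generated exactly from coefficients 3 and 0.5, so the solve should recover them):

```java
public class NormalEquations {
    // Solve the 2-parameter weighted normal equations for a design matrix
    // whose columns are an intercept column and one feature column x.
    public static double[] solve(double[] x, double[] w, double[] b) {
        // accumulate A^T W A (2x2, symmetric) and A^T W b (2x1)
        double s00 = 0, s01 = 0, s11 = 0, t0 = 0, t1 = 0;
        for (int i = 0; i < x.length; i++) {
            s00 += w[i]; s01 += w[i] * x[i]; s11 += w[i] * x[i] * x[i];
            t0  += w[i] * b[i]; t1 += w[i] * x[i] * b[i];
        }
        // invert the 2x2 system via Cramer's rule
        double det = s00 * s11 - s01 * s01;
        return new double[]{ (s11 * t0 - s01 * t1) / det,
                             (s00 * t1 - s01 * t0) / det };
    }

    public static void main(String[] args) {
        double[] x = {0, 1, 2, 3};
        double[] w = {1, 2, 1, 0.5};         // positive weights -> W positive definite
        double[] b = {3.0, 3.5, 4.0, 4.5};   // exactly 3 + 0.5 * x
        double[] xbar = solve(x, w, b);
        System.out.println("x_bar = [" + xbar[0] + ", " + xbar[1] + "]");
        // prints x_bar = [3.0, 0.5]
    }
}
```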

java-glm's People

Contributors

chen0040

java-glm's Issues

Distributions (Gamma, Poisson)

Hi
Thanks for creating the repo (nicely maintained and good to read). I have a problem, however.

I am trying to use other distribution families, like Poisson or Gamma. So I create the GLM as follows:

Glm glm = new Glm(GlmSolverType.GlmIrlsSvd, GlmDistributionFamily.Gamma);
glm.fit(dMatrix);

or

Glm glm = new Glm();
glm.distributionFamily = GlmDistributionFamily.Poisson;
glm.fit(dMatrix);

I only get NaNs for predictions in either case. Is there anything I am missing here? If I change the family to linear, I get non-NaN results on my test dataset (iris).

For some values of the distribution family, such as InverseGaussian, I even get null pointer exceptions from GlmAlgorithm (state is null):

protected TerminationEvaluationMethod shouldTerminate = (state, iteration) -> {
    if (state.improved() && state.improvement() >= this.mTol) {
        return iteration >= this.maxIters;
    } else {
        return false;
    }
};

Help appreciated.
Thanks again
Ben.

GLM Results

Dear Xianshun Chen,

First of all, thank you very much for the only Java GLM library that is clear and easy to use and integrate!

I would like to ask you if you are going to continue the development of this great library.
An example of a missing feature is the p-value calculation. Also, the adjusted McFadden R2 is always > 1, while the plain McFadden R2 is OK (0 < R2 < 1).

Thanks again!
Wish you all the best.

Aram Paronikyan,
Yerevan, Armenia

Likelihood

This is a really nice project!

One feature request: could the minimised log likelihood be reported? This would let users compute their own fit statistics, deviance, AIC, etc. It would also indirectly resolve the R2 issues mentioned in another issue/feature request.
