
mgbdt's People

Contributors

kingfengji


mgbdt's Issues

Environment

I'd expect code written for Python 3.5 to be compatible with other Python 3 versions. Are you sure a build under 3.5 specifically is necessary?

Cannot find the UCI dataset

Hi,
I want to run the uci_year and uci_adult demos, but I can't find the get_data.sh files mentioned in the README. Could you please upload them, or describe the data format so I can prepare the data myself?
I also noticed that the code reads a features file, which is missing from the repository as well.

Problem with Pop

```
/opt/conda/lib/python3.7/site-packages/joblib/parallel.py in <listcomp>(.0)
    254         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    255             return [func(*args, **kwargs)
--> 256                     for func, args, kwargs in self.items]
    257
    258     def __len__(self):

/kaggle/working/mGBDT/lib/mgbdt/model/online_xgb.py in fit_increment(self, X, y, num_boost_round, params)
     13         for k, v in extra_params.items():
     14             params[k] = v
---> 15         params.pop("n_estimators")
     16
     17         if callable(self.objective):

KeyError: 'n_estimators'
```
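For what it's worth, the traceback suggests that params no longer contains an "n_estimators" key by the time pop is called. A minimal defensive fix (my guess, not an official patch) would be to give pop a default so a missing key is simply ignored:

```python
params = {"max_depth": 5, "learning_rate": 0.1}  # no "n_estimators" key

# In mgbdt/model/online_xgb.py, fit_increment: passing a default to
# dict.pop makes the call a no-op when "n_estimators" is absent,
# instead of raising KeyError.
params.pop("n_estimators", None)
```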

Performance of your model on regression tasks

Description

@kingfengji Thanks for making the code available. I believe multi-layered gradient boosted decision trees are a very elegant and powerful approach! I applied your model to the Boston housing dataset but wasn't able to outperform a baseline xgboost model.

Details

To compare your approach to several alternatives, I ran a small benchmark study using the following approaches, where all models share the same hyper-parameters:

  • baseline xgboost model (xgboost)
  • mGBDT with xgboost for hidden and output layer (mGBDT_XGBoost)
  • mGBDT with xgboost for hidden but with linear model for output layer (mGBDT_Linear)
  • linear model as implemented here (Linear)

I use PyTorch's L1Loss for model training and MAE for evaluation; all models are trained in serial mode. Results are as follows:

[Figure: MAE results for the four models (xgboost, mGBDT_XGBoost, mGBDT_Linear, Linear)]
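For reference, the baseline is set up roughly as follows (a sketch with stand-in data; the exact hyper-parameters and the Boston housing loading are in the attached notebook):

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor

# Stand-in regression data; the notebook uses the Boston housing dataset.
X, y = make_regression(n_samples=506, n_features=13, noise=10.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Baseline: a single xgboost regressor, evaluated with MAE.
model = XGBRegressor(n_estimators=100, max_depth=5, learning_rate=0.1)
model.fit(X_train, y_train)
print("baseline MAE:", mean_absolute_error(y_test, model.predict(X_test)))
```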

In particular, I observe the following

  • irrespective of the hyper-parameters and the number of epochs, a baseline xgboost model tends to outperform your approach
  • with an increasing number of epochs, the runtime per epoch grows considerably. Any idea why this happens?
  • using mGBDT_Linear,
    • I wasn't able to use PyTorch's MSELoss, since the loss exploded after some iterations even after normalizing X. Should we, as with neural networks, also scale y to avoid exploding gradients? (A sketch of what I mean follows this list.)
    • the training loss starts at exceptionally high values, then decreases before it starts to increase again
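Concretely, by scaling y I mean something like the following (a minimal sketch with toy data; whether this actually stabilizes MSELoss here is part of my question):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

y_train = np.array([10.0, 200.0, 3000.0])  # toy targets with a wide range

# Standardize y the same way X is standardized. StandardScaler expects a
# 2-D array, hence the reshape/ravel round trip.
y_scaler = StandardScaler()
y_train_scaled = y_scaler.fit_transform(y_train.reshape(-1, 1)).ravel()

# ... fit the model on the scaled targets, then map predictions back
# to the original scale before computing the MAE:
y_pred_scaled = y_train_scaled  # stand-in for model.predict(X_test)
y_pred = y_scaler.inverse_transform(y_pred_scaled.reshape(-1, 1)).ravel()
```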

Additional Questions

  • Given that you have mostly used your approach for classification tasks, is there anything we need to change before using it for regression tasks, besides the PyTorch loss?
  • Besides the loss of F, can we also track how well the target propagation is working by evaluating the reconstruction loss of G?
  • When using mGBDT with a linear output layer, would we generally expect better results than with xgboost for the output layer?
  • What is the benefit of a linear output layer compared to an xgboost layer?
  • For training F and G, you currently use the MSELoss for the xgboost models. Do you have any experience with modifying this loss? (See the sketch after this list for the kind of modification I have in mind.)
  • What is the effect of the number of iterations used to initialize the model before training?
  • What is the relationship between the number of boosting iterations (for xgboost training) and the number of epochs (for mGBDT training)?
  • In Section 4 of your paper you state: "The experiments for this section is mainly designed to empirically examine if it is feasible to jointly train the multi-layered structure proposed by this work. That is, we make no claims that the current structure can outperform CNNs in computer vision tasks." Does that mean your intention is not to outperform existing deep learning models (say, CNNs) or existing GBM models (like XGBoost), but rather to show that a decision-tree-based model can also learn meaningful representations that can then be used for downstream tasks?
  • Connected to the previous question: gradient boosting models are already very strong learners that obtain very good results in many applications. What is your motivation for stacking multiple layers of such a model? Could it even happen that, given the implicit error-correction mechanism of GBMs, training several of them leads to a drop in accuracy?
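To make the question about modifying the loss concrete: what I have in mind is passing a custom objective (gradient and hessian) to xgboost, for example a pseudo-Huber loss instead of plain MSE. This is only a sketch against xgboost's public API, not something the repo currently exposes for F and G:

```python
import numpy as np
import xgboost as xgb

def pseudo_huber(preds, dtrain, delta=1.0):
    """Pseudo-Huber objective: quadratic near zero, linear for large
    residuals. Returns the per-example gradient and hessian."""
    residual = preds - dtrain.get_label()
    scale = 1.0 + (residual / delta) ** 2
    grad = residual / np.sqrt(scale)
    hess = 1.0 / (scale * np.sqrt(scale))
    return grad, hess

# Toy data just to make the snippet self-contained.
rng = np.random.RandomState(0)
X = rng.rand(100, 5)
y = X.sum(axis=1)
dtrain = xgb.DMatrix(X, label=y)
booster = xgb.train({"max_depth": 3, "eta": 0.1}, dtrain,
                    num_boost_round=50, obj=pseudo_huber)
```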

Code

To reproduce the results, you can use the attached notebook.

ModelComparison.zip

@kingfengji I would highly appreciate your feedback. Many thanks.
