
deeprecommender's Introduction

Deep AutoEncoders for Collaborative Filtering

This is not an official NVIDIA product. It is a research project described in "Training Deep AutoEncoders for Collaborative Filtering" (https://arxiv.org/abs/1708.01715).

The model

The model is based on deep AutoEncoders.

[Figure: AutoEncoder architecture diagram]
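To make the model description concrete, here is a minimal pure-Python sketch of a deep autoencoder forward pass. It is not the repository's code (the real model is implemented in PyTorch in reco_encoder/model/model.py), and the helper names `build_autoencoder` and `forward` are hypothetical. Assumptions: the decoder mirrors the encoder's layer sizes with its own (unconstrained) weights, SELU is applied after every layer, and dropout is applied only to the innermost code layer during training.

```python
import math
import random

# SELU constants from "Self-Normalizing Neural Networks" (Klambauer et al., 2017).
SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805


def selu(x: float) -> float:
    """Scaled exponential linear unit for a single value."""
    return SELU_SCALE * (x if x > 0 else SELU_ALPHA * (math.exp(x) - 1.0))


def init_layer(n_in, n_out, rng):
    """Xavier-uniform weights (n_out rows of n_in columns) and zero biases."""
    bound = math.sqrt(6.0 / (n_in + n_out))
    weights = [[rng.uniform(-bound, bound) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def dense_selu(x, layer):
    """Affine map followed by SELU."""
    weights, biases = layer
    return [selu(sum(w * v for w, v in zip(row, x)) + b) for row, b in zip(weights, biases)]


def build_autoencoder(n_items, hidden_layers, seed=0):
    """Encoder n -> h1 -> ... -> hk; decoder mirrors the sizes back to n."""
    rng = random.Random(seed)
    sizes = [n_items] + list(hidden_layers)
    encoder = [init_layer(a, b, rng) for a, b in zip(sizes, sizes[1:])]
    back = sizes[::-1]
    decoder = [init_layer(a, b, rng) for a, b in zip(back, back[1:])]
    return encoder, decoder


def forward(x, encoder, decoder, drop_prob=0.0, training=False, rng=None):
    """Reconstruct a (sparse) rating vector x of length n_items."""
    for layer in encoder:
        x = dense_selu(x, layer)
    if training and drop_prob > 0.0:
        rng = rng or random.Random()
        keep = 1.0 - drop_prob
        # Inverted dropout on the code: zero each unit with probability drop_prob
        # and scale survivors by 1/keep so the expected activation is unchanged.
        x = [v / keep if rng.random() < keep else 0.0 for v in x]
    for layer in decoder:
        x = dense_selu(x, layer)
    return x
```

The input is a user's rating vector with zeros for unrated items, and the output is a dense vector of the same length, which is what makes the model usable for predicting missing ratings.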

Requirements

  • Python 3.6
  • PyTorch (install with: pipenv install)
  • CUDA (recommended version >= 8.0)

Training using mixed precision with Tensor Cores

Getting Started

Run unit tests first

The code is intended to run on a GPU. The last test can take a minute or two.

$ python -m unittest test/data_layer_tests.py
$ python -m unittest test/test_model.py

Tutorial

Check out this tutorial by miguelgfierro.

Get the data

Note: Run all of these commands from within your DeepRecommender folder.

Netflix prize

  • Download the dataset from here into your DeepRecommender folder, then unpack and convert it:
$ tar -xvf nf_prize_dataset.tar.gz
$ tar -xf download/training_set.tar
$ python ./data_utils/netflix_data_convert.py training_set Netflix
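After conversion, training works on one sparse rating vector per user, whose length is the number of distinct items. A minimal sketch of that loading step, assuming each line of a converted file is `user<TAB>item<TAB>rating` (check the converter's actual output before relying on this). `load_user_vectors` and `to_dense` are hypothetical helpers, not the repository's data layer: raw item IDs are mapped to dense vector positions in order of first appearance, and each user's ratings become a sparse index-to-rating dict.

```python
from collections import defaultdict


def load_user_vectors(lines, delimiter="\t"):
    """Group 'user<delim>item<delim>rating' lines into sparse per-user vectors.

    Returns (vectors, item_id_map): vectors[user] maps dense item index -> rating,
    and item_id_map maps each raw item ID to its dense index (0..num_items-1).
    """
    item_id_map = {}
    vectors = defaultdict(dict)
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        user, item, rating = line.split(delimiter)[:3]
        # Assign dense indices in order of first appearance.
        idx = item_id_map.setdefault(int(item), len(item_id_map))
        vectors[int(user)][idx] = float(rating)
    return dict(vectors), item_id_map


def to_dense(sparse, dim):
    """Expand a sparse vector into a list of length dim; unrated items are 0.0."""
    dense = [0.0] * dim
    for idx, rating in sparse.items():
        dense[idx] = rating
    return dense
```

Usage: `vectors, item_map = load_user_vectors(open("Netflix/NF_TRAIN/<file>"))`, where the file name is a placeholder for whatever the converter produced.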

Data stats

Dataset           Netflix 3 months          Netflix 6 months          Netflix 1 year            Netflix full
Ratings train     13,675,402                29,179,009                41,451,832                98,074,901
Users train       311,315                   390,795                   345,855                   477,412
Items train       17,736                    17,757                    16,907                    17,768
Time range train  2005-09-01 to 2005-11-31  2005-06-01 to 2005-11-31  2004-06-01 to 2005-05-31  1999-12-01 to 2005-11-31
----------------  ------------------------  ------------------------  ------------------------  ------------------------
Ratings test      2,082,559                 2,175,535                 3,888,684                 2,250,481
Users test        160,906                   169,541                   197,951                   173,482
Items test        17,261                    17,290                    16,506                    17,305
Time range test   2005-12-01 to 2005-12-31  2005-12-01 to 2005-12-31  2005-06-01 to 2005-06-31  2005-12-01 to 2005-12-31
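The table also shows how sparse these rating matrices are. As a quick arithmetic check on the training-set figures above, density is the number of ratings divided by users × items: about 1.16% for Netflix full and about 0.25% for Netflix 3 months. The helper name is illustrative.

```python
def density(num_ratings: int, num_users: int, num_items: int) -> float:
    """Fraction of the user x item matrix that holds an observed rating."""
    return num_ratings / (num_users * num_items)


# Training-set figures from the table above.
full = density(98_074_901, 477_412, 17_768)
three_months = density(13_675_402, 311_315, 17_736)
print(f"Netflix full: {full:.2%}, Netflix 3 months: {three_months:.2%}")
```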

Train the model

In this example, the model will be trained for 12 epochs. In the paper, we train for 102.

python run.py --gpu_ids 0 \
--path_to_train_data Netflix/NF_TRAIN \
--path_to_eval_data Netflix/NF_VALID \
--hidden_layers 512,512,1024 \
--non_linearity_type selu \
--batch_size 128 \
--logdir model_save \
--drop_prob 0.8 \
--optimizer momentum \
--lr 0.005 \
--weight_decay 0 \
--aug_step 1 \
--noise_prob 0 \
--num_epochs 12 \
--summary_frequency 1000
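The `--aug_step 1` flag in the command above appears to enable one step of dense (iterative output) re-feeding per iteration. A minimal framework-free sketch of the losses involved for a single example, assuming: the masked MSE only counts entries where the target is non-zero (observed ratings), and, following the paper's description, each re-feeding step uses the dense output f(x) as both the new input and the new target for f(f(x)). `model` stands for any callable that maps a rating vector to its reconstruction; the backward pass and weight update that follow each loss in real training, and the `--noise_prob` option, are omitted. The function names are hypothetical.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def masked_mse(target: Vector, prediction: Vector) -> Tuple[float, int]:
    """Sum of squared errors over entries where target != 0, plus the count of such entries.

    Returning sum and count separately lets callers aggregate over many examples
    before dividing (RMSE = sqrt(total_sse / total_count)).
    """
    sse, count = 0.0, 0
    for r, y in zip(target, prediction):
        if r != 0.0:  # mask: only observed ratings contribute
            sse += (r - y) ** 2
            count += 1
    return sse, count


def refeeding_losses(model: Callable[[Vector], Vector], x: Vector, aug_step: int) -> List[float]:
    """Per-step masked MSE for one sparse example x with `aug_step` re-feeding steps.

    Step 0 compares f(x) against the sparse input x. Each re-feeding step treats the
    previous dense output as both the new input and the new target.
    """
    losses = []
    target, output = x, model(x)
    for step in range(aug_step + 1):
        if step > 0:
            target = output          # dense f(x) becomes the label ...
            output = model(target)   # ... and f(f(x)) the new prediction
        sse, count = masked_mse(target, output)
        losses.append(sse / count if count else 0.0)
    return losses
```

Because f(x) is dense, almost every entry passes the mask during re-feeding, so the second loss covers nearly all items rather than only the observed ratings.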

Note that you can run TensorBoard in parallel:

$ tensorboard --logdir=model_save

Run inference on the Test set

python infer.py \
--path_to_train_data Netflix/NF_TRAIN \
--path_to_eval_data Netflix/NF_TEST \
--hidden_layers 512,512,1024 \
--non_linearity_type selu \
--save_path model_save/model.epoch_11 \
--drop_prob 0.8 \
--predictions_path preds.txt

Compute Test RMSE

python compute_RMSE.py --path_to_predictions=preds.txt

After 12 epochs you should get an RMSE of around 0.927. Train longer to get below 0.92.
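compute_RMSE.py reports the root mean squared error over the test predictions. A minimal sketch of that computation, assuming each line of preds.txt is tab-separated with the predicted rating and the actual rating in the third and fourth columns; that layout is an assumption, so the column indices are parameters and should be checked against your infer.py output. `rmse_from_lines` is a hypothetical helper.

```python
import math


def rmse_from_lines(lines, pred_col=2, actual_col=3, delimiter="\t"):
    """Root mean squared error between predicted and actual ratings in delimited lines."""
    sse, n = 0.0, 0
    for line in lines:
        parts = line.strip().split(delimiter)
        if len(parts) <= max(pred_col, actual_col):
            continue  # skip blank or malformed lines
        diff = float(parts[pred_col]) - float(parts[actual_col])
        sse += diff * diff
        n += 1
    if n == 0:
        raise ValueError("no predictions found")
    return math.sqrt(sse / n)


# Usage with the file produced by infer.py:
# with open("preds.txt") as f:
#     print(rmse_from_lines(f))
```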

Results

It should be possible to achieve the following results (exact numbers will vary due to randomization). Iterative output re-feeding should be applied once during each iteration.

Dataset           RMSE    Model Architecture
Netflix 3 months  0.9373  n,128,256,256,dp(0.65),256,128,n
Netflix 6 months  0.9207  n,256,256,512,dp(0.8),256,256,n
Netflix 1 year    0.9225  n,256,256,512,dp(0.8),256,256,n
Netflix full      0.9099  n,512,512,1024,dp(0.8),512,512,n
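In the Model Architecture column, n is the input dimension (number of items), the encoder sizes run up to the dropout marker dp(p), and the decoder mirrors the encoder back to n. The Netflix full row's encoder matches --hidden_layers 512,512,1024 with --drop_prob 0.8 from the training command. A small hypothetical helper that renders this notation from those two flag values, assuming that mirroring rule:

```python
def architecture(hidden_layers: str, drop_prob: float) -> str:
    """Render the table's notation from --hidden_layers and --drop_prob values.

    Encoder: n followed by the hidden sizes; dropout on the innermost (code) layer;
    decoder: the encoder sizes in reverse, excluding the code layer, ending at n.
    """
    sizes = ["n"] + hidden_layers.split(",")
    decoder = sizes[-2::-1]  # every size except the code layer, reversed
    return ",".join(sizes + [f"dp({drop_prob:g})"] + decoder)
```

For example, `architecture("128,256,256", 0.65)` gives the Netflix 3 months row.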

deeprecommender's People

Contributors

alexgrig, amoussawi, david30907d, juhope, lenguyenthedat, miguelgfierro, okuchaiev, paulhendricks


deeprecommender's Issues

Implicit feedback

How would I go about applying this model to implicit feedback, for example just visited/not visited, rather than numerical ratings?

Data issue

I am new to deep learning and Python. I am unable to replicate the Netflix data. Can you help me with the input data format, the file and folder structure, and the similarities and differences between the valid and train files? When I use the same data for train and valid, it runs, but when I use separate data I get multiple errors. Please help.

Tricks to reach 0.909?

I tried to follow the example on the full dataset and enabled the dense re-feeding trick on every iteration. I trained for 36 iterations (it does not seem to change much after that), but my best result on the validation set is 0.914, a gap from the reported 0.91. Are there other tricks to reproduce the result? Thanks!

Question about dense re-feeding.

In dense re-feeding, what do we show the model as the target? For example, let's say B is our training batch. For model input X and label y, both will be equal to B. Let's say the model gives y_hat as output for batch B. In dense re-feeding, we then give this y_hat as input to the model. My question is: what will the labels be in dense re-feeding? Will they be y_hat or B?

Expected range of labels

I would like to use other datasets with varying ranges. Does the recommender only work with a certain range, or are all ranges valid, including those with float values (e.g. 0.0-10.0, -10-10)?

Autoencoder shape

What's the reasoning behind having the first layer smaller than the middle layers, unlike what the picture shows? Is it to reduce the number of parameters and overfitting, or simply the best configuration from the experiments?

Can I use this code for 0,1 rating data?

In my task, the ratings are 0 or 1, and I want to use this model for it.

I know you said that this model is designed for "explicit" feedback, but if I change something, do you think it could be used for my task?

training on own data, and RMSE is nan

Hey @okuchaiev,

I have been trying to train on my own data.
The dataset consists of 539278 user IDs and 1551731 items. The data is extremely sparse.
While training, my RMSE is nan. Should I take the absolute value of the MSE loss?

I have PyTorch 0.4, Cuda 9.0. Training on gtx 1080ti.

Using GPUs: [0]
Doing epoch 0 of 12
run.py:198: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
  t_loss += loss.data[0]
[0, 0] RMSE: 8.0848995
run.py:212: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
  total_epoch_loss += loss.data[0]
[0, 1000] RMSE: nan
[0, 2000] RMSE: nan
[0, 3000] RMSE: nan
[0, 4000] RMSE: nan
[0, 5000] RMSE: nan
[0, 6000] RMSE: nan
[0, 7000] RMSE: nan
[0, 8000] RMSE: nan
Total epoch 0 finished in 1966.838391304016 seconds with TRAINING RMSE loss: nan
run.py:74: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
  total_epoch_loss += loss.data[0]
run.py:75: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
  denom += num_ratings.data[0]
Epoch 0 EVALUATION LOSS: nan
Saving model to model_save/model.epoch_0
Doing epoch 1 of 12
[1, 0] RMSE: nan
[1, 1000] RMSE: nan
[1, 2000] RMSE: nan
[1, 3000] RMSE: nan
[1, 4000] RMSE: nan
[1, 5000] RMSE: nan
[1, 6000] RMSE: nan
[1, 7000] RMSE: nan
[1, 8000] RMSE: nan

Could you please help me out?

key error when loading the validation dataset

I'm trying to run the model on MovieLens. After downloading MovieLens 20M and running the data transformation, I run the recommender:
python run.py --gpu_ids 0 --path_to_train_data movielens/train --path_to_eval_data movielens/val --hidden_layers 512,512,1024 --non_linearity_type selu --batch_size 128 --logdir model_movielens --drop_prob 0.8 --optimizer momentum --lr 0.005 --weight_decay 0 --aug_step 1 --noise_prob 0 --num_epochs 1 --summary_frequency 1000
I get an error:

Loading training data
Data loaded
Total items found: 138493
Vector dim: 20668
Loading eval data
Traceback (most recent call last):
  File "run.py", line 236, in <module>
    main()
  File "run.py", line 116, in main
    item_id_map=data_layer.itemIdMap)
  File "/home/hoaphumanoid/notebooks/repos/DeepRecommender/reco_encoder/data/input_layer.py", line 48, in __init__
    value = minor_map[int(parts[self._minor_ind])]
KeyError: 18223

It looks like it is trying to look up a key from the evaluation dataset that is not in the training dataset?
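If that is the cause, one workaround is to drop evaluation ratings whose user or item never appears in the training file before passing the file to run.py. A minimal sketch, assuming both files use the same tab-separated user, item, rating layout and that IDs are compared as raw strings. `filter_to_known` is a hypothetical helper, and whether discarding these ratings is acceptable depends on your evaluation protocol.

```python
def filter_to_known(train_lines, eval_lines, delimiter="\t"):
    """Keep only eval lines whose user and item both occur in the training lines.

    Returns (kept_lines, num_dropped).
    """
    users, items = set(), set()
    for line in train_lines:
        parts = line.strip().split(delimiter)
        if len(parts) >= 2:
            users.add(parts[0])
            items.add(parts[1])
    kept, dropped = [], 0
    for line in eval_lines:
        parts = line.strip().split(delimiter)
        if len(parts) < 2:
            continue  # skip blank or malformed lines
        if parts[0] in users and parts[1] in items:
            kept.append(line)
        else:
            dropped += 1
    return kept, dropped
```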

Item Based CF

Hi,

I am trying to execute both user based and item based recommendations and compare the results.
If my understanding is right, we do this by setting major to 'users' or 'items':

def main():
    params['major'] = 'items'  # or 'users'

Is this the correct way to do it? Thanks!
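Conceptually, the major dimension decides which axis forms the training examples: with users, each example is one user's vector of item ratings (user-based); with items, each example is one item's vector of user ratings (item-based), which amounts to transposing the rating matrix. A minimal sketch of that grouping over (user, item, rating) triples; `group_by_major` is hypothetical and is not the repository's data layer.

```python
from collections import defaultdict


def group_by_major(triples, major="users"):
    """Group (user, item, rating) triples into sparse vectors keyed by the major ID.

    major="users": {user: {item: rating}}  (user-based: one example per user)
    major="items": {item: {user: rating}}  (item-based: one example per item)
    """
    if major not in ("users", "items"):
        raise ValueError("major must be 'users' or 'items'")
    vectors = defaultdict(dict)
    for user, item, rating in triples:
        if major == "users":
            vectors[user][item] = rating
        else:
            vectors[item][user] = rating
    return dict(vectors)
```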

Add MovieLens example

Provide a recipe for training a model on MovieLens data (20M and 1M).

This should include the following:

  1. Data converter. Consider fixing/adjusting this script.
    • Train/Eval/Test split. Training-set ratings should come before any Eval and Test ratings. See section 3.1 (Experiment setup) of this paper for details, and the Netflix converter for an example.
  2. An example train.sh and test.sh for the 20M and 1M MovieLens data sets, which will train/eval and test the model, respectively.
  3. Choose a model architecture and hyperparameters that give the best eval RMSE you can get.
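The Train/Eval/Test split requirement above is a chronological split. A minimal sketch, assuming each rating carries a timestamp (integers or ISO date strings both compare correctly) and that you pick the cutoff yourself. How the later period is divided into Eval and Test is a separate choice; this sketch divides it at random with a fixed seed. Both function names are hypothetical.

```python
import random


def time_split(ratings, cutoff):
    """Split (user, item, rating, timestamp) tuples so that every training rating
    is strictly earlier than every held-out rating."""
    train = [r for r in ratings if r[3] < cutoff]
    holdout = [r for r in ratings if r[3] >= cutoff]
    return train, holdout


def split_holdout(holdout, eval_fraction=0.5, seed=0):
    """Randomly divide held-out ratings into Eval and Test sets (reproducible via seed)."""
    rows = list(holdout)
    random.Random(seed).shuffle(rows)
    n_eval = int(len(rows) * eval_fraction)
    return rows[:n_eval], rows[n_eval:]
```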

Deprecation warnings migrating to 0.4

We see the following deprecation warnings when using PyTorch Version 0.4.0 with the DeepRecommender package. These deprecation warnings will result in errors in PyTorch Version 0.5.0. Specifically, the migration involves two changes:

  • Changing nn.init.xavier_uniform to nn.init.xavier_uniform_,
  • Instead of using loss.data[0] syntax, use loss.item().

These changes are breaking: when using PyTorch Version 0.3.0, nn.init.xavier_uniform_ and loss.item() result in AttributeError: module 'torch.nn.init' has no attribute 'xavier_uniform_' and 'Variable' object has no attribute 'item', respectively.

Deprecation warnings:

DeepRecommender/reco_encoder/model/model.py:60: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.
  weight_init.xavier_uniform(w)

DeepRecommender/reco_encoder/model/model.py:72: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.
  weight_init.xavier_uniform(w)

DeepRecommender/run.py:198: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
  t_loss += loss.data[0]
[0,     0] RMSE: 3.8345311

DeepRecommender/run.py:212: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
  total_epoch_loss += loss.data[0]

DeepRecommender/run.py:74: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
  total_epoch_loss += loss.data[0]

DeepRecommender/run.py:75: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
  denom += num_ratings.data[0]
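One way to keep a single code base working on both 0.3 and 0.4 during the migration is a small compatibility shim that prefers the new names and falls back to the old ones. A minimal sketch covering only the two changes listed above; the helper names are hypothetical, and the functions take the init module or loss object as arguments, so the sketch itself does not import torch.

```python
def xavier_uniform_init(init_module):
    """Return the Xavier-uniform initializer from a torch.nn.init-like module.

    Prefers the in-place name (PyTorch >= 0.4) and falls back to the old name (0.3).
    """
    fn = getattr(init_module, "xavier_uniform_", None)
    return fn if fn is not None else init_module.xavier_uniform


def scalar(value):
    """Convert a one-element loss or count to a Python number.

    PyTorch >= 0.4: value.item(). PyTorch 0.3 Variables have no .item(), so fall
    back to value.data[0], the indexing that 0.4 deprecates for 0-dim tensors.
    """
    if hasattr(value, "item"):
        return value.item()
    return value.data[0]


# Usage inside the existing code (sketch):
#   xavier_uniform_init(weight_init)(w)
#   t_loss += scalar(loss)
#   denom += scalar(num_ratings)
```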

Movielens experiments

Hey @okuchaiev, I've been doing some experiments with the MovieLens 20M dataset, all on a single P100 GPU, with wd=0, SELU, SGD with momentum, and the other default options. Here are the results:

300 epochs:
* h256.256.256.512_lr0.005_dp0.8_bs64_aug1_ep300, Process time 34708s, RMSE: 0.8270190993497112
* h256.256.256.512_lr0.002_dp0.8_bs64_aug1_ep300, Process time 35469s, RMSE: 0.8298143399008006
* h128.128.256_lr0.005_dp0.8_bs64_aug1_ep300, Process time 30648s, RMSE: 0.832070992094558
* h256.256.512_lr0.005_dp0.8_bs128_aug1_ep300, Process time 29872s, RMSE: 0.8335073607999096

50 epochs:
* h128.128.256_lr0.005_dp0.8_bs64_aug1_ep50, Process time 7316s, RMSE: 0.8373085857776336
* h128.128.256_lr0.002_dp0.8_bs64_aug1_ep50, Process time 7255s, RMSE: 0.8505965955389251
* h128.128.256_lr0.005_dp0.8_bs128_aug1_ep50, Process time 6559s, RMSE: 0.8463281991349552
* h128.128.256_lr0.002_dp0.8_bs128_aug1_ep50, Process time 6550s, RMSE: 0.8666786635557636
* h128.128.128.256_lr0.005_dp0.8_bs64_aug1_ep50, Process time 5893s, RMSE: 0.8356554165814366
* h128.128.128.256_lr0.005_dp0.8_bs128_aug1_ep50, Process time 5892s, RMSE: 0.8426986399494479 
* h128.128.128.256_lr0.002_dp0.8_bs64_aug1_ep50, Process time 6084s, RMSE: 0.8461920729794343
* h128.128.128.256_lr0.002_dp0.8_bs128_aug1_ep50, Process time 5903s, RMSE: 0.8607027585932623

* h256.512_lr0.005_dp0.8_bs64_aug1_ep50, Process time 6166s, RMSE: 0.8420912025890906
* h256.512_lr0.002_dp0.8_bs64_aug1_ep50, Process time 6263s, RMSE: 0.8615509368380064
* h256.512_lr0.005_dp0.8_bs128_aug1_ep50, Process time 6450s, RMSE: 0.8581210973782695
* h256.512_lr0.002_dp0.8_bs128_aug1_ep50, Process time 6526s, RMSE: 0.8901831164436558
* h256.256.512_lr0.005_dp0.8_bs256_aug1_ep50, Process time 7761s, RMSE: 0.8505812808299635
* h256.256.512_lr0.001_dp0.8_bs256_aug1_ep50, Process time 8048s, RMSE: 0.9072564425137314
* h256.256.512_lr0.005_dp0.8_bs128_aug1_ep50, Process time 7313s, RMSE: 0.8383932386367762
* h256.256.512_lr0.001_dp0.8_bs128_aug1_ep50, Process time 7485s, RMSE: 0.8805132787428274
* h256.256.256.128_lr0.005_dp0.8_bs64_aug1_ep50, Process time 6832s, RMSE: 0.8826071700724554
* h256.256.256.128_lr0.002_dp0.8_bs64_aug1_ep50, Process time 6924s, RMSE: 0.8518409650015333
* h256.256.256.128_lr0.005_dp0.8_bs128_aug1_ep50, Process time 6670s, RMSE: 0.8516106274521363
* h256.256.256.128_lr0.002_dp0.8_bs128_aug1_ep50, Process time 6640s, RMSE: 0.876183417003062
* h256.256.256.256_lr0.005_dp0.8_bs64_aug0_ep50_wd0.00001, Process time 5681s, RMSE: 0.8247678613140474
* h256.256.256.256_lr0.005_dp0.8_bs32_aug0_ep50_wd0.00001_elu, Process time 5366s, RMSE: 0.82389932562926
* h256.256.256.256_lr0.005_dp0.8_bs64_aug0_ep50_wd0.00001_elu, Process time 5017s, RMSE: 0.8282986409964643
* h256.256.256.256_lr0.005_dp0.8_bs64_aug1_ep50_wd0.00001_elu, Process time 5136s, RMSE: 0.8324252316444559
* h256.256.256.512_lr0.005_dp0.8_bs64_aug1_ep50, Process time 8286s, RMSE: 0.8292924646574259
* h256.256.256.512_lr0.002_dp0.8_bs64_aug1_ep50, Process time 8493s, RMSE: 0.8357998662652054
* h256.256.256.512_lr0.005_dp0.8_bs128_aug1_ep50, Process time 8877s, RMSE: 0.8342839521305517
* h256.256.256.512_lr0.002_dp0.8_bs128_aug1_ep50, Process time 9235s, RMSE: 0.8479914043369364
* h256.256.256.512_lr0.005_dp0.8_bs64_aug2_ep50, Process time 6194s, RMSE: 0.8314955237915321
* h256.256.256.512_lr0.005_dp0.6_bs64_aug1_ep50, Process time 5849s, RMSE: 0.8381541416170412
* h256.256.256.512_lr0.005_dp0.8_bs64_aug1_ep50_relu, Process time 5202s, RMSE: 2.165232282658598
* h256.256.256.512_lr0.005_dp0.8_bs64_aug1_ep50_wd0.001, Process time 6358s, RMSE: 0.8615138222322402
* h256.256.256.512_lr0.005_dp0.8_bs32_aug1_ep50, Process time 6601s, RMSE: 0.8286968416622136 
* h256.256.256.512_lr0.005_dp0.8_bs64_aug3_ep50, Process time 6700s, RMSE: 0.8325802227208278 
* h256.256.256.512_lr0.005_dp0.8_bs64_aug1_ep50_elu, Process time 5196s, RMSE: 0.829355669108844
* h256.256.256.512_lr0.005_dp0.8_bs64_aug1_ep50_wd0.00001, Process time 6092s, RMSE: 0.8286978067933191
* h256.256.256.512_lr0.005_dp0.8_bs32_aug1_ep50_wd0.00001_lrelu, Process time 5745s, RMSE: nan
* h256.256.256.512_lr0.005_dp0.9_bs32_aug1_ep50_wd0.00001, Process time 6752s, RMSE: 0.8358882991657209
* h256.256.256.512_lr0.005_dp0.8_bs32_aug1_ep50_wd0.00001_swish, Process time 6438s, RMSE: nan
* h256.256.256.512_lr0.005_dp0.8_bs32_aug1_ep50_wd0.00001_relu6, Process time 5806s, RMSE: 1.5870073433725758
* h256.256.256.512_lr0.005_dp0.8_bs32_aug1_ep50_wd0.00001, Process time 8147s, RMSE: 0.8278444900229919
* h256.256.256.512_lr0.005_dp0.5_bs64_aug1_ep50_wd0.00001, Process time 7146s, RMSE: 0.8458491087391462
* h256.256.256.512_lr0.005_dp0.8_bs64_aug1_ep50_wd0.00001_sigmoid, Process time 5855s, RMSE: 2.6946873485097624
* h256.256.256.512_lr0.005_dp0.8_bs64_aug1_ep50_wd0.00001_tanh, Process time 5879s, RMSE: 2.6892056428673965
* h256.256.256.512_lr0.005_dp0.8_bs64_aug1_ep50_wd0.00001_elu, Process time 6614s, RMSE: 0.829774144023921
* h256.256.256.512_lr0.005_dp0.8_bs64_aug1_ep50_wd0.00001_none, Process time 6069s, RMSE: nan
* h256.256.256.512_lr0.005_dp0.8_bs64_aug0_ep50_wd0.00001_elu, Process time 7231s, RMSE: 0.8246256263673024
* h256.256.256.512_lr0.005_dp0.8_bs64_aug1_ep50_wd0.00001_elu_cons, Process time 7623s, RMSE: nan
* h256.256.256.512_lr0.005_dp0.8_bs64_aug0_ep50_wd0.00001_elu_skp, Process time 5056s, RMSE: 0.8256991938919997
* h256.256.256.512_lr0.005_dp0.8_bs64_aug0_ep50_wd0.00001_elu_np0.2, Process time 5475s, RMSE: 0.8249846126834408
* h256.256.256.512_lr0.005_dp0.8_bs64_aug0_ep50_wd0.00001_elu_np0.4, Process time 5404s, RMSE: 0.8249381947311919
* h256.256.256.512_lr0.005_dp0.8_bs64_aug0_ep50_wd0.00001_elu_np0.6, Process time 5336s, RMSE: 0.8244053170054507
* h256.256.256.256.512_lr0.005_dp0.8_bs64_aug0_ep50_wd0.00001, Process time 8364s, RMSE: 0.8252511286962865
* h256.256.256.256.512_lr0.005_dp0.8_bs32_aug0_ep50_wd0.00001_elu, Process time 7876s, RMSE: 0.8208981665592793
* h256.256.256.256.512_lr0.005_dp0.8_bs64_aug0_ep50_wd0.00001_elu, Process time 6841s, RMSE: 0.8232621691107652
* h256.256.256.256.512_lr0.004_dp0.8_bs64_aug0_ep50_wd0.00001_elu, Process time 7459s, RMSE: 0.8250180634812827

* h512.12.12_lr0.005_dp0.8_bs256_aug1_ep50, Process time 8197s, RMSE: nan
* h512.128.12_lr0.005_dp0.8_bs256_aug1_ep50, Process time 5436s, RMSE: 1.0141237434398267
* h512.128.12_lr0.01_dp0.8_bs256_aug1_ep50, Process time 7107s, RMSE: 1.669601678946766
* h512.128.12_lr0.005_dp0.8_bs128_aug1_ep50, Process time 7304s, RMSE: 1.9644852247335254
* h512.128.12_lr0.01_dp0.8_bs128_aug1_ep50, Process time 8256s, RMSE: 2.2515750498298295
* h512.256.128_lr0.005_dp0.8_bs64_aug1_ep50, Process time 8330s, RMSE: nan
* h512.256.128_lr0.002_dp0.8_bs64_aug1_ep50, Process time 9167s, RMSE: 0.8532720004964209
* h512.256.128_lr0.005_dp0.8_bs128_aug1_ep50, Process time 7964s, RMSE: 0.8514060297120034
* h512.256.128_lr0.002_dp0.8_bs128_aug1_ep50, Process time 7808s, RMSE: 0.8771772641279514
* h512.512.512.1024_lr0.005_dp0.8_bs256_aug1_ep50, Process time 7631s, RMSE: 0.8408605201075207
* h512.512.512.1024_lr0.001_dp0.8_bs256_aug1_ep50, Process time 9081s, RMSE: 0.8851269050050165
* h512.512.512.1024_lr0.005_dp0.8_bs128_aug1_ep50, Process time 8221s, RMSE: 0.8390253067561858
* h512.512.512.1024_lr0.001_dp0.8_bs128_aug1_ep50, Process time 9201s, RMSE: 0.85770380406226
* h512.512.512_lr0.005_dp0.8_bs256_aug1_ep50, Process time 6015s, RMSE: 0.8490931741924439
* h512.512.512_lr0.001_dp0.8_bs256_aug1_ep50, Process time 6013s, RMSE: 0.8994081473635861
* h512.512.512_lr0.005_dp0.8_bs128_aug1_ep50, Process time 6276s, RMSE: 0.8435522435037354
* h512.512.512_lr0.001_dp0.8_bs128_aug1_ep50, Process time 6347s, RMSE: 0.8727244815655271
* h512.512.1024_lr0.005_dp0.8_bs256_aug1_ep50, Process time 7118s, RMSE: 0.8515680809349422
* h512.512.1024_lr0.001_dp0.8_bs256_aug1_ep50, Process time 8119s, RMSE: 0.9003549734935724
* h512.512.1024_lr0.005_dp0.8_bs128_aug1_ep50, Process time 8581s, RMSE: 0.8412896519786348
* h512.512.1024_lr0.001_dp0.8_bs128_aug1_ep50, Process time 8284s, RMSE: 0.8725194757170811

* h1024.512.512_lr0.005_dp0.8_bs256_aug1_ep50, Process time 8366s, RMSE: 0.848893039387492
* h1024.512.512_lr0.001_dp0.8_bs256_aug1_ep50, Process time 8487s, RMSE: 0.9071224941082034
* h1024.512.512_lr0.005_dp0.8_bs128_aug1_ep50, Process time 8055s, RMSE: 0.8402467531636137
* h1024.512.512_lr0.001_dp0.8_bs128_aug1_ep50, Process time 7878s, RMSE: 0.8799644549035793
* h2048.1024.512_lr0.005_dp0.8_bs256_aug1_ep50, Process time 9201s, RMSE: 0.8531879835494411
* h2048.1024.512_lr0.001_dp0.8_bs256_aug1_ep50, Process time 9343s, RMSE: 0.9083370416845887
* h2048.1024.512_lr0.005_dp0.8_bs128_aug1_ep50, Process time 9071s, RMSE: 0.8457299178975725
* h2048.1024.512_lr0.001_dp0.8_bs128_aug1_ep50, Process time 9159s, RMSE: 0.8802430695913057

Some observations:

  • It looks like LR=0.005 is the winner, with 0.002 sometimes giving good results.
  • A smaller BS yields better accuracy (which makes sense), but there is no massive difference in training time.
  • The architectures that seem to work better are inverted pyramids (256, 256, 512) rather than the traditional pyramid autoencoder shape (512, 256, 128). Could this be because the input is extremely sparse, so a smaller first layer is better? However, it is odd that adding bigger internal layers seems to work better.
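The run names above encode the hyperparameters (h = hidden layer sizes, lr, dp = dropout probability, bs = batch size, aug = re-feeding steps, ep = epochs), followed by optional suffixes such as wd (weight decay) or the activation. A small hypothetical parser makes comparisons like the ones above easier to automate, for example finding the best RMSE per learning rate. The field meanings are inferred from the names alone.

```python
import math
import re

# Leading fields of names like "h256.512_lr0.005_dp0.8_bs64_aug1_ep50".
RUN_RE = re.compile(
    r"h(?P<hidden>[\d.]+)_lr(?P<lr>[\d.]+)_dp(?P<dp>[\d.]+)_bs(?P<bs>\d+)_aug(?P<aug>\d+)_ep(?P<ep>\d+)"
)
LINE_RE = re.compile(r"\*\s*([^,\s]+), Process time (\d+)s, RMSE: (\S+)")


def parse_run(line):
    """Parse '* <name>, Process time <s>s, RMSE: <value>' into a dict, or None."""
    m = LINE_RE.search(line)
    if not m:
        return None
    name = m.group(1)
    fields = RUN_RE.match(name)
    if not fields:
        return None
    return {
        "hidden": tuple(int(h) for h in fields["hidden"].split(".")),
        "lr": float(fields["lr"]),
        "dp": float(fields["dp"]),
        "bs": int(fields["bs"]),
        "aug": int(fields["aug"]),
        "epochs": int(fields["ep"]),
        "extra": name[fields.end():].lstrip("_"),  # e.g. "wd0.00001_elu"
        "seconds": int(m.group(2)),
        "rmse": float(m.group(3)),  # nan for diverged runs
    }


def best_by(runs, key):
    """Lowest finite RMSE for each value of `key` (e.g. "lr" or "bs")."""
    best = {}
    for run in runs:
        if run is None or math.isnan(run["rmse"]):
            continue
        k = run[key]
        if k not in best or run["rmse"] < best[k]:
            best[k] = run["rmse"]
    return best
```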

Did you have time to do some experiments on your side?

How to get prediction from a model?

I have followed your guide to train a model on MovieLens data. Now I want to get predictions from my trained model, but I'm totally new to PyTorch. Can you guide me on how to get predictions?
