
MASR

This repo contains source code for our paper: "Adversarial Mahalanobis Distance-based Attentive Song Recommender for Automatic Playlist Continuation" published in SIGIR 2019.

Data Format:

  • *.rating data file: [user_id]:[playlist_id] \t [track_id] \t [random-position-number] \t [1]
  • *.negative data file: ([user_id]:[playlist_id],[track_id]) \t [negative_track_id1] \t [negative_track_id2] ...
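Here is a minimal Python sketch for reading the two file formats. It assumes every id is an integer and that the square brackets above are placeholders, not literal characters. The names `Rating`, `NegativeRecord`, `parse_rating_line` and `parse_negative_line` are hypothetical and are not taken from the repo's own data loader.

```python
from typing import List, NamedTuple


class Rating(NamedTuple):
    user_id: int
    playlist_id: int
    track_id: int
    position: int  # the "random-position-number" column
    label: int     # always 1 for observed (positive) interactions


class NegativeRecord(NamedTuple):
    user_id: int
    playlist_id: int
    track_id: int         # positive track these negatives are paired with
    negatives: List[int]  # negative_track_id1, negative_track_id2, ...


def parse_rating_line(line: str) -> Rating:
    """Parse '<user>:<playlist>\t<track>\t<position>\t1' (ids assumed to be ints)."""
    user_playlist, track, position, label = line.rstrip("\n").split("\t")
    user, playlist = user_playlist.split(":")
    return Rating(int(user), int(playlist), int(track), int(position), int(label))


def parse_negative_line(line: str) -> NegativeRecord:
    """Parse '(<user>:<playlist>,<track>)\t<neg1>\t<neg2>...'."""
    head, *negs = line.rstrip("\n").split("\t")
    user_playlist, track = head.strip("()").split(",")
    user, playlist = user_playlist.split(":")
    return NegativeRecord(int(user), int(playlist), int(track), [int(n) for n in negs])


if __name__ == "__main__":
    print(parse_rating_line("7:42\t1003\t5\t1"))
    print(parse_negative_line("(7:42,1003)\t2001\t2002\t2003"))
```

Splitting on tabs first and only then on ":" and "," keeps the parser strict: a line with the wrong number of columns raises a `ValueError` instead of being silently misread.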

Hyper-parameters:

  • saved_path: path to save the output checkpoints. Default is the [chk_points] folder.
  • load_best_chkpoint: whether to load the best saved checkpoint [1 = Yes, 0 = No]. Default is 0.
  • path: base path to the data directory. Default is the [data] folder.
  • dataset: dataset name. Default is [demo].
  • epochs: Number of training epochs. Default is 50.
  • batch_size: Batch size. Default is 256.
  • num_factors: Number of hidden factors. Default is 64.
  • reg_mdr: regularization term of MDR model.
  • reg_mass: regularization term of MASS model.
  • num_neg: number of negative items to be sampled to compare with each positive item. Default is 4.
  • max_seq_len: maximum number of tracks in a playlist to consider. Default is -1, meaning all tracks in the target playlist are considered.
  • dropout: Dropout rate. Default is 0.2.
  • act_func: Activation function for the MASS model. Three options: [none, relu, tanh], where "none" means the identity activation function.
  • act_func_mdr: Activation function for MDR model. Default is "none".
  • model: model to train. 3 options ["mdr", "mass", "masr"].
  • beta: contribution of MASS in MASR. Default is 0.5
  • out: whether to save output checkpoints. Default is 1, meaning a checkpoint is saved for each epoch.
  • cuda: whether to use a GPU. 1 = GPU, 0 = CPU. Default is 0.
  • data_type: which input information to use: user + playlist + track ["upt"], user + track only ["ut"], or playlist + track only ["pt"]. This parameter is used when training MDR or MASS. For MDR, choose ["upt"]; for MASS, choose ["ut"].
  • data_type_mdr: data type of the MDR model. Used when training MASR to select which pre-trained MDR model to use, identified by the input data type it was trained with.
  • data_type_mass: data type of the MASS model. Used when training MASR to select which pre-trained MASS model to use, identified by the input data type it was trained with.
  • adv: whether to use flexible adversarial training [1 = Yes, 0 = No]. For the best results, first train plain MDR or MASS to obtain the best initial checkpoint, then continue with adversarial training. Training with adv=1 from scratch can lead to lower results.
  • reg_noise: regularization weight of the noise term (\lambda_\delta in Equation (23)). Default is 1.0.
  • eps: noise magnitude (\epsilon in Equation (24)). Default is 1.0. With a small number of hidden factors (e.g., 8), this can be set to 0.5 if adversarial training does not improve the results.
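To build intuition for `eps` and `reg_noise`, the sketch below shows the standard fast-gradient adversarial perturbation used in adversarial personalized ranking. The noise is delta = eps * g / ||g||, where g is the loss gradient with respect to an embedding. The objective then adds reg_noise times the loss under that noise. Two assumptions are loud here. First, the sketch uses a fixed eps and a toy squared-Euclidean loss in place of the Mahalanobis distance. Second, it does not reproduce the paper's flexible noise magnitude from Equation (24). All function names are hypothetical.

```python
import math
from typing import List, Tuple

Vector = List[float]


def sq_distance(u: Vector, v: Vector) -> float:
    """Toy loss (assumption): squared Euclidean distance between user and positive item."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def grad_wrt_user(u: Vector, v: Vector) -> Vector:
    """Analytic gradient of sq_distance with respect to u: 2 * (u - v)."""
    return [2.0 * (a - b) for a, b in zip(u, v)]


def adversarial_perturbation(grad: Vector, eps: float) -> Vector:
    """Fast-gradient noise: the step of L2 length eps along the gradient direction."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm == 0.0:
        return [0.0] * len(grad)  # no ascent direction, so no noise
    return [eps * g / norm for g in grad]


def adversarial_loss(u: Vector, v: Vector, eps: float, reg_noise: float) -> Tuple[float, float]:
    """Return (clean_loss, total_loss), total = clean + reg_noise * loss with noisy user."""
    clean = sq_distance(u, v)
    delta = adversarial_perturbation(grad_wrt_user(u, v), eps)
    noisy_u = [a + d for a, d in zip(u, delta)]
    total = clean + reg_noise * sq_distance(noisy_u, v)
    return clean, total


if __name__ == "__main__":
    for eps in (1.0, 0.5):
        print(eps, adversarial_loss([1.0, 0.0], [0.0, 0.0], eps=eps, reg_noise=1.0))
```

A smaller eps shrinks the adversarial term, which is consistent with the advice above to lower the noise magnitude when the embedding space is small.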

Demo example:

Training MDR and AMDR:

Training with MDR:

python -u main.py --cuda 1 --dropout 0.2 --dataset demo --epochs 50 --load_best_chkpoint 0 --model mdr --num_factors 64 --reg_mdr 0.0 --adv 0 --act_func_mdr none --data_type upt

Training with AMDR:

After training MDR, the best checkpoint is saved in chk_points. With --load_best_chkpoint 1, the script automatically loads the best checkpoint w.r.t. the validation set and uses it as the starting point for adversarial learning. Running adversarial learning from scratch, without this initial MDR training, can give lower results.

python main.py --dataset demo --data_type upt --model mdr --num_factors  64 --reg_mdr 0.0 --load_best_chkpoint 1 --cuda 1 --epochs 50 --adv 1 --reg_noise 1.0 --eval 0 --lr 1e-3 
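The two-stage recipe can be scripted: first pre-train MDR, then continue with adversarial training from its best checkpoint. This hypothetical Python driver only assembles the MDR and AMDR demo commands with the same flags. Unless `dry_run` is set, it runs them in order with `subprocess.run(..., check=True)`, so AMDR never starts if MDR pre-training fails. It assumes `main.py` is in the current working directory.

```python
import shlex
import subprocess
import sys
from typing import List

# Flags shared by the MDR and AMDR demo commands.
SHARED = ["--dataset", "demo", "--data_type", "upt", "--model", "mdr",
          "--num_factors", "64", "--reg_mdr", "0.0", "--epochs", "50"]


def build_commands(cuda: int = 1) -> List[List[str]]:
    """Return [MDR pre-training command, AMDR command]."""
    gpu = ["--cuda", str(cuda)]
    # Stage 1: plain MDR from scratch; its best checkpoint ends up in chk_points.
    mdr = [sys.executable, "-u", "main.py", *SHARED, *gpu,
           "--dropout", "0.2", "--act_func_mdr", "none",
           "--load_best_chkpoint", "0", "--adv", "0"]
    # Stage 2: reload that best checkpoint and continue with adversarial training.
    amdr = [sys.executable, "main.py", *SHARED, *gpu,
            "--load_best_chkpoint", "1", "--adv", "1",
            "--reg_noise", "1.0", "--eval", "0", "--lr", "1e-3"]
    return [mdr, amdr]


def run_pipeline(cuda: int = 1, dry_run: bool = True) -> None:
    for cmd in build_commands(cuda):
        print(shlex.join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError, so stage 2 is skipped if stage 1 fails.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_pipeline(cuda=0, dry_run=True)  # set dry_run=False to actually train
```

The same pattern applies to MASS and AMASS by swapping in the MASS flags (--model mass, --data_type ut, --act_func relu, --reg_mass 1e-6).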

Training with MASS:

python -u main.py --act_func relu --cuda 1 --dropout 0.2 --dataset demo --epochs 50 --load_best_chkpoint 0 --model mass --num_factors 64 --reg_mass 1e-6 --adv 0 --data_type ut

Training with AMASS:

python main.py --act_func relu --dataset demo --data_type ut --model mass --num_factors  64 --reg_mass 1e-6 --load_best_chkpoint 1 --cuda 1 --epochs 50 --adv 1 --reg_noise 1.0 --eval 0 --lr 1e-3 

Training with MASR:

python main.py --act_func relu --dataset demo --model masr --num_factors  64 --reg_mass 1e-6 --reg_mdr 0.0 --load_best_chkpoint 1 --cuda 1 --epochs 50 --adv 0 --reg_noise 1.0 --eval 0 --lr 1e-3 --act_func_mdr none --data_type_mdr upt --data_type_mass ut --beta 0.5
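`--beta` sets the contribution of MASS in MASR. As a rough illustration only, the sketch below assumes a convex combination of the two sub-models' per-track distances, beta * MASS + (1 - beta) * MDR, where a smaller value means a better match. Check the paper for the exact formulation. `combine_scores` and `rank_tracks` are hypothetical helpers, not functions from the repo.

```python
from typing import Dict, List


def combine_scores(mdr: Dict[int, float], mass: Dict[int, float],
                   beta: float = 0.5) -> Dict[int, float]:
    """Assumed MASR score per candidate track: beta * MASS + (1 - beta) * MDR."""
    if not 0.0 <= beta <= 1.0:
        raise ValueError("beta must lie in [0, 1]")
    return {t: beta * mass[t] + (1.0 - beta) * mdr[t] for t in mdr}


def rank_tracks(distances: Dict[int, float], k: int) -> List[int]:
    """Smaller distance = better match, so sort ascending and keep the top-k track ids."""
    return sorted(distances, key=distances.get)[:k]


if __name__ == "__main__":
    mdr = {1: 0.2, 2: 1.0, 3: 0.6}
    mass = {1: 0.6, 2: 0.2, 3: 0.4}
    for beta in (0.0, 0.5, 1.0):
        print(beta, rank_tracks(combine_scores(mdr, mass, beta), k=3))
```

With beta = 0 the ranking comes from MDR alone, with beta = 1 from MASS alone, and beta = 0.5 (the demo setting) weights both equally.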

Training with AMASR:

python main.py --act_func relu --dataset demo --model masr --num_factors  64 --reg_mass 1e-6 --reg_mdr 0.0 --load_best_chkpoint 1 --cuda 1 --epochs 50 --adv 1 --reg_noise 1.0 --eval 0 --lr 1e-3 --act_func_mdr none --data_type_mdr upt --data_type_mass ut --beta 0.5

If you don't have a GPU, set --cuda 0. Please enjoy the boosted performance from adversarial training with our flexible noise magnitude.

If you find our work helpful, please cite our paper:

@inproceedings{tran2019adversarial,
  title={Adversarial Mahalanobis Distance-based Attentive Song Recommender for Automatic Playlist Continuation},
  author={Tran, Thanh and Sweeney, Renee and Lee, Kyumin},
  booktitle={Proceedings of the 42nd International ACM SIGIR Conference on Research and Development in Information Retrieval},
  pages={245-254},
  year={2019},
  organization={ACM}
}
