
dlrm_ssm's Introduction

Semantically Constrained Memory Allocation (SCMA) for Embeddings in Efficient Recommendation Systems

Description:

This project applies similarity-based shared-memory (SCMA) embeddings to DLRM. It is modified from Facebook's DLRM repository (https://github.com/facebookresearch/dlrm).

Implementation

DLRM + SCMA is implemented in PyTorch; currently we only support SCMA for PyTorch.

   dlrm_s_pytorch.py - Main script for the DLRM model.
   lsh_pp_pretraining.py - Pretraining script for SCMA; run this file to generate the min-hash table.
   min_hash_generator.py - Min-hash table generator. Most of the SCMA + DLRM experiments in the paper, including the hyperparameter-tuning experiments, use min-hash tables generated by this script.
   min_hash_generator_universal_hash.py - Min-hash table generator using updated universal hash functions. The 540M-parameter experiment on the Criteo dataset uses a min-hash table generated by this updated script.
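The updated generator above is described as using universal hash functions. As a rough illustration only (the function name and constants here are not from the repo), a Carter-Wegman style universal hash family looks like this:

```python
import random

def make_universal_hash(num_buckets, prime=2_147_483_647, seed=None):
    """Sketch of a Carter-Wegman universal hash: h(x) = ((a*x + b) mod p) mod m.

    a and b are drawn randomly per hash function; p is a large prime.
    """
    rng = random.Random(seed)
    a = rng.randrange(1, prime)
    b = rng.randrange(0, prime)
    return lambda x: ((a * x + b) % prime) % num_buckets

# Fixing the seed fixes (a, b), so the same function maps the same input
# to the same bucket every time.
h = make_universal_hash(num_buckets=1000, seed=0)
bucket = h(42)
```

Drawing a fresh (a, b) pair per hash function is what makes collisions between any fixed pair of inputs uniformly unlikely, which is the property MinHash relies on.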

How to run our code?

Prerequisites:

pytorch-nightly (6/10/19)

scikit-learn

numpy

onnx (optional)

pydot (optional)

torchviz (optional)

cuda-dev (optional)

tqdm

Data preparation

Download the dataset from https://www.kaggle.com/c/criteo-display-ad-challenge and put the train.txt and test.txt files in /dlrm_ssm/input/.

Run SCMA embeddings with DLRM:

Note: This code provides a proof of concept of LMA-DLRM. It precomputes the LSH-based mapping and uses it in the code. We are in the process of releasing the efficient version, in which the mapping is computed on the fly.

1. Pre-training

Use tricks/lsh_pp_pretraining.py to generate the min_hash_table for SCMA embeddings. Three hyperparameters are needed:

   EMBEDDING - embedding dimension
   NUM_HASH - number of hash functions used to generate the min_hash_table
   NUM_PT - number of datapoints used to represent the dataset

For example:

python tricks/lsh_pp_pretraining.py 128 2 125000

will generate a min-hash table with embedding dimension 128, using 2 hash functions and 125k datapoints.
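To make the NUM_HASH parameter concrete, here is a minimal sketch of a MinHash signature (names and constants are illustrative, not the repo's implementation): each categorical value is represented by the set of datapoints it appears in, and each of the NUM_HASH hash functions contributes one minimum to the signature.

```python
import random

def minhash_signature(datapoint_ids, num_hash, prime=2_147_483_647, seed=0):
    """Sketch: MinHash signature of the set of datapoint ids a value occurs in.

    Values whose datapoint sets overlap heavily get matching signature entries
    with high probability, so they can share an embedding row.
    """
    rng = random.Random(seed)
    sig = []
    for _ in range(num_hash):
        a = rng.randrange(1, prime)
        b = rng.randrange(0, prime)
        # the minimum of a random permutation of the set is the MinHash value
        sig.append(min((a * x + b) % prime for x in datapoint_ids))
    return tuple(sig)

sig_a = minhash_signature({1, 2, 3, 4, 5}, num_hash=2)
sig_b = minhash_signature({1, 2, 3, 4, 5}, num_hash=2)
```

More hash functions (a larger NUM_HASH) give a finer-grained similarity estimate at the cost of a longer pretraining pass.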

2. Training

Use dlrm_s_pytorch.py to train the DLRM model with SCMA embeddings.

command line arguments

   lsh-emb-flag - enable SSM embedding for DLRM
   lsh-emb-compression-rate (float) - 1 / expansion rate
   rand-emb-flag - enable HashNet embedding for DLRM
   rand-emb-compression-rate (float) - 1 / expansion rate
   arch-sparse-feature-size (string) - we use "13-512-256-128", the default value for DLRM
   arch-mlp-bot (string) - we use "1024-1024-512-256-1", the default value for DLRM
   data-generation (string) - use "dataset" to run our code; the default is "random"
   data-set (string) - use "kaggle" to run our code
   raw-data-file (string) - please use "./input/train.txt"
   processed-data-file (string) - please use "./input/kaggleAdDisplayChallenge_processed.npz"
   loss-function (string) - please use "bce"
   round-targets (boolean) - please use "True"
   learning-rate (double) - the default learning rate is 0.1
   mini-batch-size (int) - the batch size is set to 2048 in our experiments
   nepochs (int) - the number of epochs is set to 15 in our experiments
   print-freq (int) - how frequently to print the log
   print-time - print timing information
   test-mini-batch-size (int) - the batch size for testing; please use 16384 to run our code
   test-num-workers (int) - we use 16 in our experiments
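The compression-rate flags above are described as 1 / expansion rate. A quick arithmetic sketch (function names here are illustrative, not from the repo) of what a given rate implies for the shared table:

```python
def shared_rows(num_rows, compression_rate):
    """Rows in the shared embedding table at a given compression rate (1 / expansion rate)."""
    return max(1, int(num_rows * compression_rate))

def table_megabytes(rows, dim, bytes_per_float=4):
    """Memory footprint of a dense float32 embedding table, in MB."""
    return rows * dim * bytes_per_float / 1e6

# e.g. a 10M-row table with dimension 128 at compression rate 0.001
full_mb = table_megabytes(10_000_000, 128)                       # 5120.0 MB
small_mb = table_megabytes(shared_rows(10_000_000, 0.001), 128)  # 5.12 MB
```

So a compression rate of 0.001 shrinks a 10M-row table to 10k shared rows, a 1000x reduction in embedding memory.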

A sample run of DLRM with SSM embedding code

A sample that runs our code is in:

bench/dlrm_s_criteo_kaggle.sh

Run HashNet embeddings with DLRM:

1. Install HashNet

cuda-dev is a prerequisite for our implementation of HashNet embeddings.

python tricks/hashedEmbeddingBag/setup.py install

will install HashNet.

We also have a separate HashNet repo here: https://github.com/wanmeihuali/HashedEmbeddingBag

2. Run DLRM with HashNet embedding

Use --rand-emb-flag in bench/dlrm_s_criteo_kaggle.sh to enable HashNet embeddings, and use --rand-emb-compression-rate to set your expansion rate. Other settings can stay the same as for SSM embeddings.
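For intuition on what a HashNet-style hashed embedding bag does, here is a minimal sketch (in numpy, with illustrative names; the repo's CUDA implementation differs): every (row, dimension) coordinate of the virtual embedding table is hashed into a small shared parameter pool, so the full num_rows x dim table is never materialized.

```python
import numpy as np

def hashnet_lookup(pool, row, dim):
    """Sketch: read one 'virtual' embedding row from a shared 1-D parameter pool.

    Each (row, d) coordinate is hashed to an index into the pool; collisions
    are what give the compression.
    """
    return np.array([pool[hash((row, d)) % pool.size] for d in range(dim)])

rng = np.random.default_rng(0)
pool = rng.standard_normal(10_000)   # shared pool, far smaller than num_rows * dim
vec = hashnet_lookup(pool, row=123_456, dim=16)
```

Since the hash is deterministic, the same row always maps to the same pool entries, and gradients flow back into the shared pool during training.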

License

This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree.

dlrm_ssm's People

Contributors

apd10, dkorchevgithub, dmudiger, hjmshi, huwan, hyuen, jeng1220, lucky096, mnaumovfb, mneilly-et, rachithayp, rohithkrn, taylanbil, tayo, tginart, tgrel, wanmeihuali, xulunfan, xw285cornell, yanzhoupan


dlrm_ssm's Issues

How to use this on the Criteo Terabyte Dataset?

FYI: looking at @yanzhoupan's code, I see you only support the small Criteo Kaggle Display Advertising Challenge dataset. The original Facebook DLRM repo also supports the Criteo Terabyte dataset, so will you support it as well? Also, could you explain how to use lsh_pp_pretraining.py on the Criteo Terabyte dataset? The training data is almost 621 GB after processing by data_utils.py, and directly loading the whole file will OOM. Thanks!
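One common way around loading such a file at once, sketched here with np.memmap (this assumes a flat fixed-width binary layout, which may not match the actual format produced by data_utils.py), is to stream fixed-size row chunks from disk:

```python
import numpy as np

def iter_row_chunks(path, row_len, chunk_rows=1_000_000, dtype=np.int32):
    """Sketch: yield row chunks of a large flat binary file without loading it all.

    np.memmap keeps the data on disk; only the slices actually touched
    are paged into memory.
    """
    arr = np.memmap(path, dtype=dtype, mode="r").reshape(-1, row_len)
    for start in range(0, arr.shape[0], chunk_rows):
        yield np.asarray(arr[start:start + chunk_rows])
```

Each chunk can then be fed to the min-hash pretraining pass independently, keeping peak memory bounded by chunk_rows rather than the full dataset size.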
