
DKT+

This repository contains the code for the paper "Addressing Two Problems in Deep Knowledge Tracing via Prediction-Consistent Regularization" (ACM, pdf).

If you find this repository useful, please cite:

@inproceedings{LS2018_Yeung_DKTP,
  title={Addressing two problems in deep knowledge tracing via prediction-consistent regularization},
  author={Yeung, Chun Kit and Yeung, Dit Yan},
  year={2018},
  booktitle = {{Proceedings of the 5th ACM Conference on Learning @ Scale}},
  pages = {5:1--5:10},
  publisher = {ACM},
}

Abstract

Knowledge tracing is one of the key research areas for empowering personalized education. It is the task of modeling students' mastery level of a knowledge component (KC) based on their historical learning trajectories. In recent years, a recurrent neural network model called deep knowledge tracing (DKT) has been proposed to handle the knowledge tracing task, and the literature has shown that DKT generally outperforms traditional methods. However, through extensive experimentation, we have noticed two major problems in the DKT model. The first problem is that the model fails to reconstruct the observed input. As a result, even when a student performs well on a KC, the predicted mastery level of that KC decreases instead, and vice versa. Second, the predicted performance across time-steps is not consistent. This is undesirable and unreasonable because a student's performance is expected to transition gradually over time. To address these problems, we introduce regularization terms that correspond to reconstruction and waviness into the loss function of the original DKT model to enhance the consistency in prediction. Experiments show that the regularized loss function effectively alleviates the two problems without degrading the original task of DKT.
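As a rough illustration of how the reconstruction and waviness terms extend the DKT loss, the following NumPy sketch computes all four terms for a batch of equal-length sequences. It is a sketch under stated assumptions, not the repository's TensorFlow implementation: the function names are hypothetical, padding is ignored, waviness is measured as the mean absolute and mean squared change of the whole prediction vector between consecutive steps, and the weights reuse the CLI names lambda_o, lambda_w1 and lambda_w2 on the assumption that lambda_o weights the reconstruction term.

```python
import numpy as np


def bce(p, target, eps=1e-7):
    """Element-wise binary cross-entropy with clipping for numerical safety."""
    p = np.clip(p, eps, 1 - eps)
    return -(target * np.log(p) + (1 - target) * np.log(1 - p))


def dkt_plus_loss(pred, problems, corrects, lambda_o=0.0, lambda_w1=0.0, lambda_w2=0.0):
    """Hypothetical DKT+-style loss (sketch only).

    pred:     (N, T, M) predicted P(correct) for every KC after each step
    problems: (N, T) integer KC ids attempted at each step
    corrects: (N, T) 0/1 responses
    ASSUMPTION: all N sequences have full length T (no padding mask).
    ASSUMPTION: lambda_o weights the reconstruction term.
    """
    n, t, _ = pred.shape
    rows = np.arange(n)[:, None]

    # Original DKT objective: the prediction after step t, read at the KC
    # attempted at step t+1, is scored against the response at step t+1.
    p_next = pred[rows, np.arange(t - 1)[None, :], problems[:, 1:]]
    loss_dkt = bce(p_next, corrects[:, 1:]).mean()

    # Reconstruction: the same prediction, read at the KC just attempted
    # at step t, should agree with the observed response at step t.
    p_curr = pred[rows, np.arange(t)[None, :], problems]
    loss_rec = bce(p_curr, corrects).mean()

    # Waviness: average L1 and squared-L2 change of the prediction vector
    # between consecutive steps (normalised by N * (T - 1) * M via .mean()).
    diff = pred[:, 1:, :] - pred[:, :-1, :]
    w1 = np.abs(diff).mean()
    w2 = np.square(diff).mean()

    total = loss_dkt + lambda_o * loss_rec + lambda_w1 * w1 + lambda_w2 * w2
    return total, {"dkt": loss_dkt, "reconstruction": loss_rec, "w1": w1, "w2": w2}
```

With constant predictions the waviness terms vanish and only the two cross-entropy terms remain; a prediction vector that jumps back and forth between steps is penalized by w1 and w2 even when each individual prediction looks plausible.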

Requirements

I used TensorFlow to develop the deep knowledge tracing model with the following packages:

tensorflow==1.2.0 (or tensorflow-gpu==1.3.0)
scikit-learn==0.18.1
scipy==0.19.0
numpy==1.13.3

The packages used to visualize the student knowledge state are:

seaborn
matplotlib

Data Format

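Each student is described by three lines. The first line is the number of exercises the student attempted, the second line is the sequence of exercise tags, and the third line is the sequence of responses (1 for correct, 0 for incorrect). For example: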

15
1,1,1,1,7,7,9,10,10,10,10,11,11,45,54
0,1,1,1,1,1,0,0,1,1,1,1,1,0,0
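A minimal reader for this three-line format might look like the sketch below. parse_students is a hypothetical helper, not a function from the repository; it assumes comma-separated integers, skips blank lines, tolerates trailing commas, and checks that the count on the first line matches both sequences.

```python
from typing import List, Tuple


def parse_students(text: str) -> List[Tuple[List[int], List[int]]]:
    """Parse the three-lines-per-student format into (problems, corrects) pairs."""
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if len(lines) % 3 != 0:
        raise ValueError("expected a multiple of three non-empty lines")

    students = []
    for i in range(0, len(lines), 3):
        n = int(lines[i])
        # `if x` drops empty fields produced by trailing commas.
        problems = [int(x) for x in lines[i + 1].split(",") if x.strip()]
        corrects = [int(x) for x in lines[i + 2].split(",") if x.strip()]
        if len(problems) != n or len(corrects) != n:
            raise ValueError(f"length mismatch for the student starting at line {i + 1}")
        students.append((problems, corrects))
    return students
```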

Program Usage

Run the experiment

python main.py

Detailed hyperparameters for the program:

usage: main.py [-h]
               [-hl [HIDDEN_LAYER_STRUCTURE [HIDDEN_LAYER_STRUCTURE ...]]]
               [-cell {LSTM,GRU,BasicRNN,LayerNormBasicLSTM}]
               [-lr LEARNING_RATE] [-kp KEEP_PROB] [-mgn MAX_GRAD_NORM]
               [-lw1 LAMBDA_W1] [-lw2 LAMBDA_W2] [-lo LAMBDA_O]
               [--num_runs NUM_RUNS] [--num_epochs NUM_EPOCHS]
               [--batch_size BATCH_SIZE] [--data_dir DATA_DIR]
               [--train_file TRAIN_FILE] [--test_file TEST_FILE]
               [-csd CKPT_SAVE_DIR] [--dataset DATASET]

optional arguments:
  -h, --help            show this help message and exit
  -hl [HIDDEN_LAYER_STRUCTURE [HIDDEN_LAYER_STRUCTURE ...]], --hidden_layer_structure [HIDDEN_LAYER_STRUCTURE [HIDDEN_LAYER_STRUCTURE ...]]
                        The hidden layer structure of the RNN. For example,
                        for two hidden layers of 200 and 50 units, type
                        '-hl 200 50'.
  -cell {LSTM,GRU,BasicRNN,LayerNormBasicLSTM}, --rnn_cell {LSTM,GRU,BasicRNN,LayerNormBasicLSTM}
                        The RNN cell used in the graph.
  -lr LEARNING_RATE, --learning_rate LEARNING_RATE
                        The learning rate when training the model.
  -kp KEEP_PROB, --keep_prob KEEP_PROB
                        Keep probability when training the network.
  -mgn MAX_GRAD_NORM, --max_grad_norm MAX_GRAD_NORM
                        The maximum gradient norm allowed when clipping.
  -lw1 LAMBDA_W1, --lambda_w1 LAMBDA_W1
                        The lambda coefficient for the waviness regularization
                        with the L1-norm.
  -lw2 LAMBDA_W2, --lambda_w2 LAMBDA_W2
                        The lambda coefficient for the waviness regularization
                        with the L2-norm.
  -lo LAMBDA_O, --lambda_o LAMBDA_O
                        The lambda coefficient for the regularization
                        objective.
  --num_runs NUM_RUNS   Number of runs to repeat the experiment.
  --num_epochs NUM_EPOCHS
                        Maximum number of epochs to train the network.
  --batch_size BATCH_SIZE
                        The mini-batch size used when training the network.
  --data_dir DATA_DIR   The data directory, default './data/'.
  --train_file TRAIN_FILE
                        Train data file, default 'skill_id_train.csv'.
  --test_file TEST_FILE
                        Test data file, default 'skill_id_test.csv'.
  -csd CKPT_SAVE_DIR, --ckpt_save_dir CKPT_SAVE_DIR
                        Checkpoint save directory.
  --dataset DATASET
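The help text above has the shape of Python argparse output. The sketch below is a hypothetical reconstruction of such a parser from that text alone, not the repository's main.py: the numeric types are inferred, and only the defaults that the help text states (data_dir, train_file, test_file) are filled in, with all others left as None.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented flags (sketch only)."""
    p = argparse.ArgumentParser(prog="main.py")
    # nargs="*" accepts zero or more layer sizes, e.g. -hl 200 50.
    p.add_argument("-hl", "--hidden_layer_structure", nargs="*", type=int)
    p.add_argument("-cell", "--rnn_cell",
                   choices=["LSTM", "GRU", "BasicRNN", "LayerNormBasicLSTM"])
    p.add_argument("-lr", "--learning_rate", type=float)
    p.add_argument("-kp", "--keep_prob", type=float)
    p.add_argument("-mgn", "--max_grad_norm", type=float)
    p.add_argument("-lw1", "--lambda_w1", type=float)
    p.add_argument("-lw2", "--lambda_w2", type=float)
    p.add_argument("-lo", "--lambda_o", type=float)
    p.add_argument("--num_runs", type=int)
    p.add_argument("--num_epochs", type=int)
    p.add_argument("--batch_size", type=int)
    # Only these three defaults are documented in the help text.
    p.add_argument("--data_dir", default="./data/")
    p.add_argument("--train_file", default="skill_id_train.csv")
    p.add_argument("--test_file", default="skill_id_test.csv")
    p.add_argument("-csd", "--ckpt_save_dir")
    p.add_argument("--dataset")
    return p


args = build_parser().parse_args(["-hl", "200", "50", "-cell", "GRU", "-lo", "0.1"])
```

With nargs="*", '-hl 200 50' is collected into the list [200, 50], which matches the two-layer example in the help text, and an unsupported value for -cell is rejected by choices.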

Contributors

ckyeungac


deep-knowledge-tracing-plus's Issues

Why does correct_seqs_oh look like that?

Hello guys,

Congratulations on your "Addressing Two Problems in Deep Knowledge Tracing via Prediction-Consistent Regularization" paper. It is very interesting, and I enjoyed it very much. To get a deeper understanding of it, I started analysing your code and found something that is not clear to me. I would be very grateful if you could clarify it for me.

In the OriginalInputProcessor class you have a process_problems_and_corrects function. Let's assume that:

num_problems = 10
max_seq_length = 5

and one of the sequence pairs looks as follows (same as in the comment in your script):

problem_seq = [1,3,2]
correct_seq = [1,0,1]

As I understand it, this function does the following:

  1. Pad with -1 from the right.

problem_seqs_pad = [1,3,2,-1,-1]
correct_seqs_pad = [1,0,1,-1,-1]

  2. Transform correct_seq.

[1,3,2,-1,-1] * [1,0,1,-1,-1] * [1,0,1,-1,-1] = [1,0,2,-1,-1] # problem_seqs_pad * correct_seqs_pad * correct_seqs_pad
correct_seqs_pad = [1,-1,2,-1,-1] # replaced 0s with -1s

  3. One-hot encode both variables:

problem_seqs_oh = array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
                         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
                         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

correct_seqs_oh = array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

  4. You then concatenate problem_seqs_oh and correct_seqs_oh to get X (after first removing the last element of each sequence).
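The four steps can be reproduced with a short NumPy sketch. It is a reconstruction from the description above, not the repository's process_problems_and_corrects; the helper names pad_right, one_hot and encode_pair are hypothetical, and step 2 is written with np.where instead of the triple product, which yields the same [1,-1,2,-1,-1] for this example.

```python
import numpy as np


def pad_right(seq, length, value=-1):
    """Pad a list on the right with `value` up to `length`."""
    return np.array(list(seq) + [value] * (length - len(seq)))


def one_hot(ids, depth):
    """One-hot encode ids; any id outside [0, depth), e.g. -1, gives an all-zero row."""
    ids = np.asarray(ids)
    out = np.zeros((len(ids), depth))
    rows = np.nonzero((ids >= 0) & (ids < depth))[0]
    out[rows, ids[rows]] = 1.0
    return out


def encode_pair(problem_seq, correct_seq, num_problems, max_seq_length):
    problems_pad = pad_right(problem_seq, max_seq_length)  # [1, 3, 2, -1, -1]
    corrects_pad = pad_right(correct_seq, max_seq_length)  # [1, 0, 1, -1, -1]
    # Keep the problem id where the answer was correct and -1 elsewhere
    # (incorrect answers and padding): [1, -1, 2, -1, -1].
    corrects_ids = np.where(corrects_pad == 1, problems_pad, -1)
    problem_oh = one_hot(problems_pad, num_problems)
    correct_oh = one_hot(corrects_ids, num_problems)
    # Drop the last time step, then place the two blocks side by side.
    x = np.concatenate([problem_oh[:-1], correct_oh[:-1]], axis=1)
    return problem_oh, correct_oh, x


problem_oh, correct_oh, x = encode_pair([1, 3, 2], [1, 0, 1], num_problems=10, max_seq_length=5)
```

Masking on the response rather than on zeros in the product means a correctly answered problem with id 0 would not be confused with an incorrect answer; whether the original code needs that distinction depends on whether its problem ids start at 0.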

What I don't understand is why correct_seqs_oh looks this way. I would have expected it to look like this:

correct_seqs_oh = [[1],
[0],
[1],
[0],
[0]]

that is, just whether the particular problem was solved correctly. This would make sense to me because the problem id is already encoded in problem_seq, so one could argue that it is not needed again in correct_seqs_oh, but I feel I might be missing something here.

Could you please explain why you decided to make correct_seqs_oh look this way?

Best,
MO
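The relationship between the two candidate encodings can be checked numerically. The sketch below uses the padded values from the example above (other variable names are illustrative) and shows that correct_seqs_oh is exactly problem_seqs_oh with each row scaled by the 0/1 correctness column proposed in the question, with padding treated as 0. In other words, the one-hot form stores the product of problem identity and correctness in a num_problems-wide block instead of a single flag column.

```python
import numpy as np

problems = np.array([1, 3, 2, -1, -1])  # padded problem ids from the example
corrects = np.array([1, 0, 1, -1, -1])  # padded responses from the example
num_problems = 10

# One-hot rows for valid problem ids; padded steps (-1) stay all-zero.
valid = problems >= 0
problem_oh = np.zeros((len(problems), num_problems))
problem_oh[np.nonzero(valid)[0], problems[valid]] = 1.0

# The single correctness column proposed in the question (padding -> 0).
correct_col = (corrects == 1).astype(float)[:, None]  # shape (5, 1)

# Scaling each problem row by its correctness bit reproduces correct_seqs_oh.
correct_oh = problem_oh * correct_col
```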

Should a separate instance of the model be created for each student in actual use?

Congratulations on your paper. It is very interesting, and I enjoyed it very much.

While reading your paper, I came up with the following question:
Should a separate instance of the model be created for each student in actual use? Or do the interaction sequences of all students go through a single model in order (according to timestamp)?

m1 and m2 calculation

I have a quick question about how you calculated the m1 and m2 metrics. You wrote in your paper that these two metrics aim "to measure the consistency between the observed input and the change in the corresponding prediction". So if we are interested in how the input affects the prediction, why do you use y (y_corr_batch and y_seq_batch, i.e. the target values) instead of the corresponding x values when calculating coefficient?

I have in mind this part of your code:

coefficient = np.sum( (np.power(base, 1 - y_corr_batch[:, 1:, :])) * y_seq_batch[:, 1:, :], axis=2)

m1 = np.sum(
    coefficient * np.sign(np.sum(
        (pred_seqs[:, 1:, :] - pred_seqs[:, :-1, :]) * y_seq_batch[:, 1:, :], #y_t-y_{t-1}
        axis=2
        ))
    )
m2 = np.sum(
    coefficient * np.sum(
        (pred_seqs[:, 1:, :] - pred_seqs[:, :-1, :]) * y_seq_batch[:, 1:, :],
        axis=2
        )
    )
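To make the quoted computation concrete, here is a runnable version on toy arrays of shape (batch, time, num_problems). base is not defined in the snippet; the sketch assumes base = -1 purely for illustration, which makes coefficient +1 where the selected problem is marked correct and -1 where it is marked incorrect. The array values are invented, and the sketch takes no position on the alignment between y and x raised in the question.

```python
import numpy as np

base = -1.0  # ASSUMPTION: `base` is undefined in the quoted snippet; -1 is illustrative

# Toy arrays of shape (batch, time, num_problems); values are made up.
y_seq_batch = np.array([[[1, 0], [0, 1], [1, 0]]], dtype=float)   # one-hot problem ids
y_corr_batch = np.array([[[1, 0], [0, 0], [1, 0]]], dtype=float)  # 1 where marked correct
pred_seqs = np.array([[[0.5, 0.5], [0.6, 0.4], [0.7, 0.3]]])

# base ** (1 - y_corr) is 1 at a correctly answered problem and `base` otherwise;
# multiplying by y_seq and summing over problems picks out that value.
coefficient = np.sum(np.power(base, 1 - y_corr_batch[:, 1:, :]) * y_seq_batch[:, 1:, :], axis=2)

# Change in the predicted probability of the selected problem between steps.
delta = np.sum((pred_seqs[:, 1:, :] - pred_seqs[:, :-1, :]) * y_seq_batch[:, 1:, :], axis=2)

m1 = np.sum(coefficient * np.sign(delta))  # counts direction agreements
m2 = np.sum(coefficient * delta)           # weights by the size of the change
```

Here the prediction for the selected problem falls at the step marked incorrect and rises at the step marked correct, so both m1 and m2 come out positive; a change in the wrong direction would contribute negatively to both.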
