
Comments (20)

shchur commented on June 27, 2024

One more important detail: make sure that you apply the relevant transformations to the samples before feeding them into the RNN. Under default settings, we transform the RNN input in_times by applying logarithm (code), and also additionally normalize the values to have zero mean and unit standard deviation (code) using the statistics of the training set.
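
Under these defaults the input transformation looks roughly like this (a sketch with made-up values; `mean_in_train` and `std_in_train` stand for the statistics computed on the training set):

```python
import torch

# Sketch of the default input transformation: log-transform the raw
# inter-event times, then standardize using training-set statistics.
# The values below are made up for illustration.
raw_in_times = torch.tensor([0.5, 1.2, 3.0, 0.8])

log_times = (raw_in_times + 1e-8).log()   # logarithm of inter-event times
mean_in_train = log_times.mean()          # statistics of the training set
std_in_train = log_times.std()

# zero mean, unit standard deviation
rnn_input = (log_times - mean_in_train) / std_in_train
```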

from ifl-tpp.

shchur commented on June 27, 2024

There is no real reason; I guess you could just call .detach() and get the same result. You could also wrap the entire call to sample in a torch.no_grad() context if you don't need to differentiate the samples afterwards.
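
For illustration, both options produce samples that are cut off from the autograd graph (using a LogNormal as a stand-in for the model's learned inter-event time distribution; this is a sketch, not the actual sample implementation):

```python
import torch

# Placeholder distribution; in the real model this would be the learned
# inter-event time distribution whose parameters require gradients.
loc = torch.zeros(3, requires_grad=True)
dist = torch.distributions.LogNormal(loc=loc, scale=torch.ones(3))

# Option 1: detach the sampled tensor afterwards.
tau_detached = dist.rsample().detach()

# Option 2: disable gradient tracking around the whole sampling call.
with torch.no_grad():
    tau_no_grad = dist.rsample()
```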

shchur commented on June 27, 2024

Hi, please have a look at #3 and #5.

shchur commented on June 27, 2024

Please have a look at my answer #5 (comment)

KristenMoore commented on June 27, 2024

@Kouin - did you figure it out? I'm trying to do the same.

shchur commented on June 27, 2024

I provided the code for generating predictions in the thread for another issue, #5 (comment), and it seems to work for the original poster there.

KristenMoore commented on June 27, 2024

Thanks, I was thrown off by the comment in #3 (comment) about the step() function in RNNLayer being needed when generating new sequences. Is it used in #5 (comment)?

shchur commented on June 27, 2024

I'm not sure if I understand what exactly you want to do. Can you describe it in more detail?

If you want to sample new trajectories from the TPP, you will need to use RNNLayer.step. At each step you will sample the next inter-event time \tau_{i} from p(\tau_i | History_i) and feed it into the RNN to obtain the parameters for p(\tau_{i+1} | History_{i+1}).

If you want to get the predictions one step into the future (i.e., you want to compute the expected time until the next event \mathbb{E}[\tau_i | History_i]), you should use the code from #5 that I referenced. The code that I wrote there computes the expected time until the next event for all events in the batch. You can use it to, for example, compute the mean squared error or mean absolute error in the event time prediction task.

KristenMoore commented on June 27, 2024

Thanks a lot for the explanation - I want to do the former. I'll follow these directions.

KristenMoore commented on June 27, 2024

Thanks a lot!

cjchristopher commented on June 27, 2024

Hi @shchur - apologies if this is a silly question! Just on the above instructions for sampling new trajectories with RNNLayer.step - is the suggestion that this would be done as model.rnn.step() in a sampling loop? I'm running into problems trying to do this correctly after training the model as per your example in the interactive notebook. We're interested in sampling new trajectories from i=0, or extending one of the input samples. Any assistance getting us on the right track is appreciated!

dl_train = torch.utils.data.DataLoader(d_train, batch_size=1, shuffle=False, collate_fn=collate)
for x in dl_train:
    break
y, h = model.rnn.step(x, model.rnn(x))

shchur commented on June 27, 2024

Hi @cjchristopher, here is my implementation of sampling entire trajectories:

next_in_time = torch.zeros(1, 1, 1)   # RNN input for the first step
h = torch.zeros(1, 1, history_size)   # initial hidden state
inter_times = []
t_max = 1000                          # time horizon for the sampled sequence
with torch.no_grad():
    while sum(inter_times) < t_max:
        _, h = model.rnn.step(next_in_time, h)  # advance the RNN by one event
        tau = model.decoder.sample(1, h)        # sample the next inter-event time
        inter_times.append(tau.item())
        # apply the same log + normalization transform used for the training inputs
        next_in_time = ((tau + 1e-8).log() - mean_in_train) / std_in_train

avs123 commented on June 27, 2024

Hi @shchur, thanks for sharing the code. I need your clarification on how to de-normalize the generated samples. The transformation you applied, next_in_time = ((tau + 1e-8).log() - mean_in_train) / std_in_train, doesn't seem to fit my data: my input has discrete integer times, while the generated outputs are fractional. Please assist by providing de-normalization code that works with your model. Thanks :)

shchur commented on June 27, 2024

Hi @avs123, do I understand it correctly that you want the model to generate discrete inter-event times? This is currently not supported, as the model is learning a continuous probability distribution for the inter-event times, so the sampled inter-event times tau will all be continuous. A potential hacky solution is to discretize the inter-event times after they are sampled with torch.ceil(tau).
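
As a concrete sketch of that workaround (with made-up sampled values):

```python
import torch

# Continuous inter-event times sampled from the model (values made up
# for illustration).
tau = torch.tensor([0.3, 1.7, 2.0, 4.01])

# Hacky discretization: round each inter-event time up to the nearest integer.
tau_discrete = torch.ceil(tau)
```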

avs123 commented on June 27, 2024

Thanks for the quick response, @shchur. Can you please confirm: is the generated tau the actual inter-event time, or do we need to apply some transformation to bring it back to the same time scale as the input sequence? Is anything needed to undo the log and normalization transformations that are applied to the input? Please elaborate.

shchur commented on June 27, 2024

The transformations applied to the RNN input are not related to the transformations applied to the output, so tau should already be the correct inter-event time.

KristenMoore commented on June 27, 2024

Thanks for the refactored code, @shchur .
To sample in the new framework, do I just need to implement sampling from inter_time_dist?

features = self.get_features(batch)  
context = self.get_context(features)   
inter_time_dist = self.get_inter_time_dist(context)

Any tips would be appreciated. Thanks.
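
For context, here is roughly what I have in mind, with a LogNormal as a placeholder for whatever distribution get_inter_time_dist returns (just a sketch of the idea, not the real API):

```python
import torch

# Placeholder for the distribution returned by get_inter_time_dist(context);
# the distribution family and batch shape are made up for illustration.
inter_time_dist = torch.distributions.LogNormal(
    loc=torch.zeros(2, 5), scale=torch.ones(2, 5)
)

with torch.no_grad():
    tau = inter_time_dist.sample()  # one inter-event time per sequence position
```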

shchur commented on June 27, 2024

Hey @KristenMoore, I have just implemented sampling for the new code. I checked it on a few datasets (see interactive.ipynb) and it seems that the generated sequences look fine. I haven't tested it on marked sequences, though. The code seems simple enough, so I hope that there are no mistakes, but let me know if anything seems odd.

I also realized that there was a very serious bug that I introduced while refactoring. The slicing for context embedding was off-by-one, which means that the model was peeking into the future. This is fixed by the last commit.

KristenMoore commented on June 27, 2024

Great - thanks @shchur!
I will let you know if I notice anything that seems odd.

KristenMoore commented on June 27, 2024

Hi @shchur - just one question about the sampling.
Why is the torch.no_grad() only applied to this one line: https://github.com/shchur/ifl-tpp/blob/master/code/dpp/models/recurrent_tpp.py#L178

Thanks.
