
Comments (3)

Yura52 commented on June 10, 2024

Glad that it helped! I agree, a hypothetical "mainstream" implementation of TabR should probably provide a friendlier interface for implementing such business rules. However, I must admit it is unlikely that we will implement this in this repository.

Patching candidate_indices is also a solution! However, it may limit the potential of TabR, since the number of candidates becomes significantly smaller (at least on this specific task).

Also, I see a minor "bug" in the new implementation. In the original code, there are two variables: is_train and training. They have different meanings:

  • is_train is about the dataset split, as defined here. It is used here to add the training batch itself to the list of candidates (which would not be valid for validation/test data, at least not by default).
  • training is about whether we do training or evaluation.

It seems that the only consequence of the new implementation is that the reported evaluation performance on the training data may be slightly worse than it could be. This should affect neither the training process nor the validation/test scores. That said, I would recommend fixing this in future runs just in case.

from tabular-dl-tabr.

Yura52 commented on June 10, 2024

Hi! This is a great comment, thank you for bringing attention to this topic.

Problem

More generally, this phenomenon can be formulated as follows:

Machine learning models are known to suffer from distribution shifts happening during deployment, where "distribution" means the distribution of objects (i.e. the joint distribution of features and labels). TabR, however, also depends on the distribution of object-object interactions, so changes in this distribution can also be a problem for TabR.

The above issue is especially relevant to problems with time-like components (as you mentioned): without additional measures, TabR can learn to retrieve context objects from the future, which is not practical. Importantly, preventing the model from retrieving from the future is not enough. In fact, TabR should also be prevented from retrieving from the most recent lag_size observations, where lag_size should be defined individually for each task, based on business intuition or to match the deployment setup.

A similar problem can happen when a dataset has some kind of identifiers. For example, if the data is split by user ids, then, for a given user, TabR can learn to retrieve records related to the same user (even if the identifiers are not explicitly present!), which may not be possible during deployment.

Solution

One has to modify the original code in a way that explicitly prevents TabR from learning impractical algorithms for the given task. In certain situations, this means that faiss will no longer be applicable, and the similarity search will have to be reimplemented manually (i.e. compute distances and apply torch.topk). The idea is as follows:

  • with faiss, more than context_size objects should be retrieved and then filtered based on the business case.
  • with manual implementation, after computing the similarities, some of them should be set to -inf (before applying torch.topk) based on the business case.
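The manual variant can be sketched as follows. This is not the original TabR code: the function and argument names are illustrative, and the negative squared L2 distance stands in for whichever similarity the model uses. The key step is setting forbidden entries to -inf before torch.topk, so they can never enter the context.

```python
import torch

def retrieve_context(queries, candidates, invalid_mask, context_size):
    """Similarity search with business-rule filtering (a sketch).
    invalid_mask[i, j] is True when candidate j must not be retrieved
    for query i (e.g. it lies in the future)."""
    # Negative squared L2 distance as the similarity.
    similarities = -torch.cdist(queries, candidates) ** 2
    # Forbidden candidates can never survive top-k once set to -inf.
    similarities = similarities.masked_fill(invalid_mask, float('-inf'))
    context_similarities, context_idx = similarities.topk(context_size, dim=1)
    return context_similarities, context_idx

# Toy usage: 2 queries, 5 candidates; block candidates 3 and 4 for query 0.
queries = torch.randn(2, 8)
candidates = torch.randn(5, 8)
invalid = torch.zeros(2, 5, dtype=torch.bool)
invalid[0, 3:] = True
sims, idx = retrieve_context(queries, candidates, invalid, context_size=2)
```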

The relevant code:

  • the similarity search happens here, and the main "output" of this code block is the context_idx variable, which defines the context objects. If I remember correctly, it should be enough to customize how context_idx is computed.
  • chances are that you will need to pass some additional information to the forward method.
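For the faiss variant, the post-filtering step (applied to the index array that e.g. index.search(queries, k) with k > context_size would return) could look like this. A sketch: the strictly-older-date rule stands in for whatever the business case requires, and all names are illustrative.

```python
import numpy as np

def filter_retrieved(idx, candidate_dates, query_dates, context_size):
    """Keep, per query, only retrieved candidates strictly older than
    the query, truncated to context_size (a sketch)."""
    context_idx = []
    for i, row in enumerate(idx):
        # faiss pads missing neighbours with -1, hence the j >= 0 check.
        valid = [int(j) for j in row
                 if j >= 0 and candidate_dates[j] < query_dates[i]]
        # If fewer than context_size survive, increase k and search again.
        context_idx.append(valid[:context_size])
    return context_idx

# Toy usage: 6 neighbours were retrieved for one query; keep 2 past-only ones.
idx = np.array([[0, 5, 2, 7, 1, 9]])
candidate_dates = np.arange(10)   # candidate j was created at time j
query_dates = np.array([6])       # the query happens at time 6
print(filter_retrieved(idx, candidate_dates, query_dates, context_size=2))
# → [[0, 5]]
```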

Does this help?


Overall, we are aware of this phenomenon (and we suspect that the poor performance on the "MI" benchmark is caused by it, but we have not verified that), and we are planning to mention it in the "Limitations" section in the next revision. Thank you once again for the report.


MichaelMedek commented on June 10, 2024

Thanks for your fast help!

I have now changed the following function in ./bin/tabr.py:

def apply_model(part: str, idx: Tensor, training: bool):
    CREATIONDATE_column = 0  # column index of CREATIONDATE in x['num']

    # Why is the `training` argument not used in the original, and why are
    # `training` and `part == 'train'` sometimes not equal? (An assert on
    # their equality failed during training.)
    training = part == 'train'

    x, y = get_Xy(part, idx)

    # CREATIONDATE values for the whole batch selected by idx.
    current_creation_dates = x['num'][:, CREATIONDATE_column]

    if training:
        full_x, _ = get_Xy('train', None)

        # Find the k-th smallest CREATIONDATE in the batch. With k = 10 and
        # batch size 1024, ~99% of the batch sees only candidates from the
        # past, most of them from the distant past.
        k = 10
        percentile_threshold, _ = torch.kthvalue(current_creation_dates, k)

        # During training, restrict the candidate set to training objects
        # satisfying the CREATIONDATE condition.
        candidate_mask = full_x['num'][:, CREATIONDATE_column] < percentile_threshold
        assert candidate_mask.sum() > 0
        candidate_indices = train_indices[candidate_mask]
    else:
        # During evaluation, use the entire training set as candidates.
        candidate_indices = train_indices

    candidate_x, candidate_y = get_Xy('train', candidate_indices)

    return model(
        x_=x,
        y=y if training else None,
        candidate_x_=candidate_x,
        candidate_y=candidate_y,
        context_size=C.context_size,
        is_train=training,
    ).squeeze(-1)

so that (for 99% of the batch) the retrieved examples come only from the past. This is still a bit hacky, since the minimum is found per batch, which means that on average only roughly the oldest 5,000 out of 900,000 training examples are considered during training. One would need to run this code with batch size 1 or reimplement it properly (next, I will try some of the other approaches that you described in your reply).

But now we get test accuracy 0.66 (train 0.77, val 0.75), compared to the previous test accuracy of 0.49 (train 0.88, val 0.88)
(the ensemble model finally reached 0.53, so maybe a similar +4pp can be expected now; training is still running). So the changes really help, and TabR would probably benefit from having such a mechanism implemented in general.
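A proper per-example version of this date filter (instead of the per-batch kthvalue threshold) could be applied directly to the similarity matrix, as suggested in the reply above. A minimal sketch with hypothetical names:

```python
import torch

def mask_future_candidates(similarities, query_dates, candidate_dates):
    """Per-example causal masking (a sketch, not the original TabR code).
    similarities has shape (batch, n_candidates); entry (i, j) is erased
    when candidate j is not strictly older than query i."""
    # Broadcast to a (batch, n_candidates) boolean mask of forbidden pairs.
    future = candidate_dates[None, :] >= query_dates[:, None]
    return similarities.masked_fill(future, float('-inf'))

# Toy usage: 2 queries at times 3 and 1, candidates at times 0..3.
sims = torch.zeros(2, 4)
masked = mask_future_candidates(
    sims, torch.tensor([3.0, 1.0]), torch.tensor([0.0, 1.0, 2.0, 3.0])
)
```

After this masking, torch.topk over the last dimension would never select a same-time or future candidate for any example in the batch, regardless of batch size.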

