Comments (3)
Glad that it helped! I agree, a hypothetical "mainstream" implementation of TabR should probably provide a more friendly interface for implementing such business rules. However, I should admit that it is unlikely that we will implement this in this repository.
Patching `candidate_indices` is also a solution! However, it may limit the potential of TabR, since the number of candidates becomes significantly smaller (at least on this specific task).
Also, I see a minor "bug" in the new implementation. In the original code, there are two variables: `is_train` and `training`. They have different meanings:
- `is_train` is about the dataset split, as defined here. It is used here to add the training batch itself to the list of candidates (which would not be valid for validation/test data, at least not by default).
- `training` is about whether we do training or evaluation.
It seems that the only consequence of the new implementation is that the reported evaluation performance on the training data may be slightly worse than it could be. This should affect neither the training process nor the validation/test scores. That said, I would recommend fixing this in future runs just in case.
from tabular-dl-tabr.
Hi! This is a great comment, thank you for bringing attention to this topic.
Problem
More generally, this phenomenon can be formulated as follows:
Machine learning models are known to suffer from distribution shifts that happen during deployment, where "distribution" means the distribution of objects (i.e. the joint distribution of features and labels). TabR, however, also depends on the distribution of object-object interactions, so changes in this distribution can also be a problem for TabR.
The above issue is especially relevant to problems with time-like components (as you mentioned): without additional measures, TabR can learn to retrieve context objects from the future, which is not practical. Importantly, preventing the model from retrieving from the future is not enough. In fact, TabR should also be prevented from retrieving from the most recent `lag_size` observations, where `lag_size` should be defined individually for each task, based on business intuition or to match the deployment setup.
A similar problem can happen when a dataset has some kind of identifiers. For example, if the data is split by user ids, then, for a given user, TabR can learn to retrieve records related to the same user (even if identifiers are not explicitly present!), which may not be possible during deployment.
Solution
One has to modify the original code in a way that explicitly prevents TabR from learning impractical algorithms for a given task. In certain situations this can mean that `faiss` will not be applicable anymore and one will have to manually reimplement the similarity search (i.e. compute distances and apply `torch.topk`). The idea is as follows:
- with `faiss`, more than `context_size` objects should be retrieved and then filtered based on the business case.
- with a manual implementation, after computing the similarities, some of them should be set to `-inf` (before applying `torch.topk`) based on the business case.
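The first option (over-retrieve, then filter) can be sketched roughly as follows. This is a sketch under assumptions, not code from the repository: `sim` holds precomputed similarities of one query to all candidates, `cand_times` their timestamps, and plain NumPy stands in for the `faiss` search itself.

```python
import numpy as np

def filtered_retrieval(sim, cand_times, query_time, context_size, lag_size, extra=4):
    # Over-retrieve `extra * context_size` candidates with the highest similarity.
    k = min(extra * context_size, sim.shape[0])
    top = np.argsort(-sim)[:k]
    # Keep only candidates that are strictly older than `query_time - lag_size`
    # (the business rule discussed above), preserving the similarity order.
    valid = top[cand_times[top] < query_time - lag_size]
    # Truncate to the requested context size; if too few survive, one would
    # need to over-retrieve more aggressively.
    return valid[:context_size]
```

If fewer than `context_size` candidates survive the filter, the retrieval has to be retried with a larger `extra`, which is one reason the manual `-inf` masking can be more convenient.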
The relevant code:
- the similarity search happens here, and the main "output" of this code block is the `context_idx` variable, which defines the context objects. If I remember everything correctly, it should be enough to customize how `context_idx` is computed.
- chances are that you will need to pass some additional information to the forward method.
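The second option, computing `context_idx` manually, might look like the following sketch. The names are hypothetical (not from the repository): `similarities` is a `(batch, n_candidates)` tensor and `allowed` is a boolean mask of the same shape encoding the business rule.

```python
import torch

def masked_context_idx(similarities, allowed, context_size):
    # Disallowed candidates get similarity -inf, so torch.topk can never
    # select them, no matter how close they are in the embedding space.
    masked = similarities.masked_fill(~allowed, float('-inf'))
    _, context_idx = torch.topk(masked, context_size, dim=1)
    return context_idx
```

This assumes every row of `allowed` has at least `context_size` `True` entries; otherwise `topk` would return indices of `-inf` candidates, which should be guarded against.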
Does this help?
Overall, we are aware of this phenomenon (and we suspect that the bad performance on the "MI" benchmark is caused by this, but we did not verify that), and we are planning to mention it in the "Limitation" section in the next revision. Thank you once again for the report.
from tabular-dl-tabr.
Thanks for your fast help!
I have now changed the following function in ./bin/tabr.py:
```python
def apply_model(part: str, idx: Tensor, training: bool):
    CREATIONDATE_column = 0  # column index of the CREATIONDATE feature
    # Why is `training` unused in the original, and why are `training` and
    # `part == 'train'` sometimes not equal (an assert fired during training)?
    training = part == 'train'
    x, y = get_Xy(part, idx)
    # CREATIONDATE values for the whole batch specified by idx
    current_creation_dates = x['num'][:, CREATIONDATE_column]
    if training:
        full_x, _ = get_Xy('train', None)
        # Take the k-th smallest CREATIONDATE in the batch: ~99% of the 1024
        # batch elements then only see candidates from the past, most of them
        # from the distant past.
        k = 10
        percentile_threshold, _ = torch.kthvalue(current_creation_dates, k)
        # During training, restrict the candidate set to training data that
        # satisfies the CREATIONDATE condition.
        candidate_mask = full_x['num'][:, CREATIONDATE_column] < percentile_threshold
        assert candidate_mask.sum() > 0
        candidate_indices = train_indices[candidate_mask]
    else:
        # During evaluation, use the entire training set as candidates.
        candidate_indices = train_indices
    candidate_x, candidate_y = get_Xy('train', candidate_indices)
    return model(
        x_=x,
        y=y if training else None,
        candidate_x_=candidate_x,
        candidate_y=candidate_y,
        context_size=C.context_size,
        is_train=training,
    ).squeeze(-1)
```
in order to have (for 99% of the batch) only similar examples from the past. This is still a bit hacky, since the minimum is found per batch, which on average will only consider maybe the oldest 5,000 out of 900,000 training examples during training. One would need to run this code with batch size 1 or reimplement it properly (I will try some of the other approaches that you wrote in your reply next).
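The per-batch `torch.kthvalue` approximation could be avoided with a per-example candidate mask, applied to the similarities before `torch.topk` (which, as noted in the reply above, requires the manual similarity search instead of `faiss`). A minimal sketch with hypothetical names:

```python
import torch

def per_example_candidate_mask(query_dates, candidate_dates):
    # mask[i, j] is True iff candidate j strictly precedes example i in time,
    # so each example gets its own past-only candidate set instead of one
    # batch-wide threshold.
    return candidate_dates.unsqueeze(0) < query_dates.unsqueeze(1)
```

The resulting mask could then be used to set the similarities of disallowed candidates to `-inf` (e.g. via `masked_fill`) before selecting the top `context_size` candidates per example.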
But now we get test accuracy 0.66 (train 0.77, val 0.75), compared to the previous test accuracy of 0.49 (train 0.88, val 0.88). (The ensemble model previously reached 0.53, so perhaps +4pp is to be expected here as well; training is still running.) So the changes really help, and TabR would probably benefit from having such a mechanism implemented in general.
from tabular-dl-tabr.