Comments (4)
It seems I've managed to solve the problem.
The problem is that XLM-RoBERTa-like models have (token_type_embeddings): Embedding(1, 1024)
That means the token type embedding table has a single row (of 1024-dimensional vectors), so token_type_ids may only contain the value 0. Any id of 1 is out of range, whatever the sequence length.
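A minimal standalone reproduction of that failure, using a plain torch.nn.Embedding with the same shape as the module above rather than a real XLM-RoBERTa checkpoint (the variable names are only for illustration):

```python
import torch
import torch.nn as nn

# Same shape as XLM-RoBERTa's token_type_embeddings: 1 token type, 1024-dim vectors
token_type_embeddings = nn.Embedding(1, 1024)

# Token type 0 is valid for any sequence length (8 here)
ok = token_type_embeddings(torch.zeros(1, 8, dtype=torch.long))
print(ok.shape)  # torch.Size([1, 8, 1024])

# A single token type id of 1 falls outside the 1-row table
error_message = None
try:
    token_type_embeddings(torch.tensor([[0, 0, 0, 1]]))
except IndexError as err:
    error_message = str(err)
print(error_message)
```

On CPU this raises the same IndexError reported in this thread; on a GPU the same mistake typically surfaces as a CUDA device-side assertion instead.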
However, explainer.py builds token_type_ids in the following function:
def _make_input_reference_token_type_pair(
    self, input_ids: torch.Tensor, sep_idx: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns two tensors indicating the corresponding token types for the `input_ids`
    and a corresponding all zero reference token type tensor.
    Args:
        input_ids (torch.Tensor): Tensor of text converted to `input_ids`
        sep_idx (int, optional): Defaults to 0.
    Returns:
        Tuple[torch.Tensor, torch.Tensor]
    """
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([0 if i <= sep_idx else 1 for i in range(seq_len)], device=self.device).expand_as(
        input_ids
    )
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=self.device).expand_as(input_ids)
    return (token_type_ids, ref_token_type_ids)
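To make the labeling concrete, here is the list comprehension from that function in plain Python (no torch needed); the helper name and the example values seq_len=6 and sep_idx=4 are made up for illustration:

```python
from typing import List


def original_token_types(seq_len: int, sep_idx: int) -> List[int]:
    # Same rule as the function above: 0 up to and including sep_idx, 1 afterwards
    return [0 if i <= sep_idx else 1 for i in range(seq_len)]


print(original_token_types(6, sep_idx=4))  # [0, 0, 0, 0, 0, 1]
print(original_token_types(6, sep_idx=0))  # [0, 1, 1, 1, 1, 1]
```

Whenever sep_idx is smaller than the last index, at least one position gets token type 1, which a one-row embedding table cannot look up.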
The tensor is filled with 0 up to and including sep_idx and with 1 after it (in my case, only the last token got 1). So just change the function to something like this:
def _make_input_reference_token_type_pair(
    self, input_ids: torch.Tensor, sep_idx: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns two tensors indicating the corresponding token types for the `input_ids`
    and a corresponding all zero reference token type tensor.
    Args:
        input_ids (torch.Tensor): Tensor of text converted to `input_ids`
        sep_idx (int, optional): Defaults to 0.
    Returns:
        Tuple[torch.Tensor, torch.Tensor]
    """
    seq_len = input_ids.size(1)
    if self.model.config.model_type == 'xlm-roberta':
        # XLM-RoBERTa has a single token type, so every id must be 0
        token_type_ids = torch.zeros(seq_len, dtype=torch.long, device=self.device).expand_as(input_ids)
    else:
        token_type_ids = torch.tensor([0 if i <= sep_idx else 1 for i in range(seq_len)], device=self.device).expand_as(
            input_ids
        )
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=self.device).expand_as(input_ids)
    return (token_type_ids, ref_token_type_ids)
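The model_type check above only matches 'xlm-roberta', so other single-token-type models such as CamemBERT would still go down the 0/1 branch. A possible generalization, sketched here as a standalone function rather than a patch to explainer.py, keys off the config's type_vocab_size instead. The function name is hypothetical, and it assumes the model config exposes type_vocab_size (BERT-style configs do):

```python
from typing import Tuple

import torch


def make_token_type_pair(
    input_ids: torch.Tensor, sep_idx: int, type_vocab_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: token type ids plus an all-zero reference."""
    seq_len = input_ids.size(1)
    if type_vocab_size < 2:
        # Only token type 0 exists, so every position must be 0
        token_type_ids = torch.zeros_like(input_ids)
    else:
        # Original rule: 0 up to and including sep_idx, 1 afterwards
        types = torch.tensor(
            [0 if i <= sep_idx else 1 for i in range(seq_len)],
            dtype=torch.long,
            device=input_ids.device,
        )
        token_type_ids = types.expand_as(input_ids)
    ref_token_type_ids = torch.zeros_like(token_type_ids)
    return token_type_ids, ref_token_type_ids


ids = torch.tensor([[0, 100, 200, 2]])  # illustrative ids, batch of 1
print(make_token_type_pair(ids, sep_idx=2, type_vocab_size=1)[0].tolist())  # [[0, 0, 0, 0]]
print(make_token_type_pair(ids, sep_idx=2, type_vocab_size=2)[0].tolist())  # [[0, 0, 0, 1]]
```

Inside the method, the third argument could come from something like getattr(self.model.config, "type_vocab_size", 2), where the fallback of 2 is an assumption. This sketch has not been checked against XLNet, whose failure may have a different cause.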
from transformers-interpret.
I'm getting the same error on CamemBERT (which is based on RoBERTa) as well.
I really hope this gets fixed.
@nishantgurunath, have you managed to solve this issue?
I am getting the same error on XLNet base.
Here is how to reproduce the error:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "xlnet-base-cased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(device)
# AutoTokenizer has no device argument, so it is dropped here
tokenizer = AutoTokenizer.from_pretrained(model_name)

cls_explainer = SequenceClassificationExplainer(model, tokenizer)
word_attributions = cls_explainer("I love you and I like you.")
I'm having the same issue with a fine-tuned XLM-RoBERTa: IndexError: index out of range in self