Comments (4)

Tomik292 commented on June 30, 2024

It seems I've managed to solve the problem.
The problem is that xlm-roberta-like models have (token_type_embeddings): Embedding(1, 1024).
That means there is only one token type embedding (1024 is the embedding dimension, not a maximum length), so token_type_ids can only consist of 0 values.
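To make the constraint concrete, here is a minimal sketch that builds a standalone layer with the same shape as the one quoted above and looks up both an all-zero and a non-zero token type. The variable names are only illustrative, and the check runs on CPU (on a GPU the failure shows up differently).

```python
import torch
from torch import nn

# Same shape as the quoted layer: 1 row (one token type), 1024 columns.
token_type_embeddings = nn.Embedding(1, 1024)

# All-zero token types are fine.
ok = token_type_embeddings(torch.zeros(1, 5, dtype=torch.long))
print(ok.shape)  # torch.Size([1, 5, 1024])

# A single 1 asks for a second row that does not exist.
try:
    token_type_embeddings(torch.tensor([[0, 0, 0, 0, 1]]))
    raised = False
except IndexError as err:
    raised = True
    print("IndexError:", err)
```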

But in the code of explainer.py in the function

def _make_input_reference_token_type_pair(
        self, input_ids: torch.Tensor, sep_idx: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns two tensors indicating the corresponding token types for the `input_ids`
        and a corresponding all zero reference token type tensor.
        Args:
            input_ids (torch.Tensor): Tensor of text converted to `input_ids`
            sep_idx (int, optional):  Defaults to 0.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]
        """
        seq_len = input_ids.size(1)
        token_type_ids = torch.tensor([0 if i <= sep_idx else 1 for i in range(seq_len)], device=self.device).expand_as(
            input_ids
        )
        ref_token_type_ids = torch.zeros_like(token_type_ids, device=self.device).expand_as(input_ids)

        return (token_type_ids, ref_token_type_ids)
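To see which positions end up with a 1, the list comprehension from the function above can be evaluated on its own. This is just that expression pulled out into a hypothetical helper (token_types is not part of the library), without torch.

```python
def token_types(seq_len, sep_idx=0):
    # Same expression as in the function above: 0 up to and including
    # sep_idx, 1 for every later position.
    return [0 if i <= sep_idx else 1 for i in range(seq_len)]

print(token_types(6, sep_idx=4))  # [0, 0, 0, 0, 0, 1]
print(token_types(6))             # [0, 1, 1, 1, 1, 1]
```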

The tensor is 0 up to and including the sep_idx position and 1 for every position after it (in my case that was only the last token). A value of 1 is out of range for an embedding with a single row, so just change the function to something like this.

def _make_input_reference_token_type_pair(
        self, input_ids: torch.Tensor, sep_idx: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns two tensors indicating the corresponding token types for the `input_ids`
        and a corresponding all zero reference token type tensor.
        Args:
            input_ids (torch.Tensor): Tensor of text converted to `input_ids`
            sep_idx (int, optional):  Defaults to 0.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]
        """
        seq_len = input_ids.size(1)
        
        if self.model.config.model_type == 'xlm-roberta':
            token_type_ids = torch.zeros(seq_len, dtype=torch.int, device=self.device).expand_as(input_ids)
        else:
            token_type_ids = torch.tensor([0 if i <= sep_idx else 1 for i in range(seq_len)], device=self.device).expand_as(
                input_ids
            )
        ref_token_type_ids = torch.zeros_like(token_type_ids, device=self.device).expand_as(input_ids)

        return (token_type_ids, ref_token_type_ids)
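A possible variation, sketched here as a standalone function rather than as the library's actual method: instead of checking for one model type by name, branch on how many token types the model supports. The name make_token_type_pair and the type_vocab_size parameter are hypothetical. The sketch assumes the value would come from something like getattr(model.config, "type_vocab_size", 2), which BERT- and RoBERTa-style configs expose; models whose config lacks that attribute may need separate handling.

```python
from typing import Tuple

import torch


def make_token_type_pair(
    input_ids: torch.Tensor, sep_idx: int = 0, type_vocab_size: int = 2
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical standalone version of the token type helper."""
    seq_len = input_ids.size(1)
    if type_vocab_size < 2:
        # Only token type 0 exists, so every position must be 0.
        token_type_ids = torch.zeros(seq_len, dtype=torch.long, device=input_ids.device)
    else:
        # 0 up to and including sep_idx, 1 afterwards (same as the original).
        token_type_ids = (torch.arange(seq_len, device=input_ids.device) > sep_idx).long()
    token_type_ids = token_type_ids.expand_as(input_ids)
    ref_token_type_ids = torch.zeros_like(token_type_ids)
    return token_type_ids, ref_token_type_ids


ids = torch.ones(1, 6, dtype=torch.long)
single_type, _ = make_token_type_pair(ids, sep_idx=3, type_vocab_size=1)
two_types, ref = make_token_type_pair(ids, sep_idx=3, type_vocab_size=2)
print(single_type.tolist())  # [[0, 0, 0, 0, 0, 0]]
print(two_types.tolist())    # [[0, 0, 0, 0, 1, 1]]
```

Under that assumption this would also cover other RoBERTa-based models with a single token type without listing each model_type by hand.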

from transformers-interpret.

MarianneAK commented on June 30, 2024

I'm also getting the same error on CamemBERT (which is based on RoBERTa).
Really hope we get a fix for this.

ameshrky commented on June 30, 2024

@nishantgurunath, have you managed to solve this issue?

I am getting the same error on XLNet base.

Here is how to reproduce the error:

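import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

# `device` has to be defined before use; pick the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "xlnet-base-cased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(device)
# The tokenizer does not need a device; only the model is moved.
tokenizer = AutoTokenizer.from_pretrained(model_name)

cls_explainer = SequenceClassificationExplainer(model, tokenizer)

# This call triggers the error.
word_attributions = cls_explainer("I love you and I like you.")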

Lavriz commented on June 30, 2024

Having the same issue for the fine-tuned xlm-roberta: IndexError: index out of range in self
