Comments (7)

tarekgh commented on June 2, 2024

This is the issue we can use for the tokenizers design review.

CC @ericstj @michaelgsharp @stephentoub @luisquintanilla @terrajobst @JakeRadMSFT @LittleLittleCloud

from machinelearning.

tarekgh commented on June 2, 2024

@georg-jung thanks for the feedback.

IIRC, it's easiest to feed Memory<int>/<long> into onnxruntime. If EncodeToIds returns IReadOnlyList<int>, can we avoid copying the memory?

I'll think more about whether we can return Memory<T> instead. The challenge is that during encoding we need to store the produced ids, which requires allocating the array beforehand. We need to estimate the initial array size correctly to avoid allocating unused memory and to avoid reallocating if the array turns out to be too small. Anyway, the idea is worth thinking about.
Also, in the future we can expose new APIs that take a destination Span/Memory and let the callers decide how to manage the allocation.
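
To make the destination-buffer idea concrete, here is a minimal sketch of what such an overload could look like. Everything in it is an assumption for illustration: `SpanEncodingSketch`, `TryEncodeToIds`, and the toy "one id per whitespace-separated word" encoding are hypothetical and are not part of the proposed Microsoft.ML.Tokenizers surface. Only the shape matters: the caller owns the buffer and the method reports how many ids it wrote.

```csharp
using System;

// Hypothetical toy tokenizer, NOT the Microsoft.ML.Tokenizers API.
// Each whitespace-separated word becomes one "id" (here simply the word length).
public static class SpanEncodingSketch
{
    // Writes ids into a caller-owned buffer.
    // Returns false if the buffer is too small; idsWritten then holds the ids that fit.
    public static bool TryEncodeToIds(ReadOnlySpan<char> text, Span<int> destination, out int idsWritten)
    {
        idsWritten = 0;
        int i = 0;
        while (i < text.Length)
        {
            // Skip whitespace between words.
            while (i < text.Length && char.IsWhiteSpace(text[i])) i++;
            int start = i;
            while (i < text.Length && !char.IsWhiteSpace(text[i])) i++;
            if (i == start) break; // only trailing whitespace was left

            if (idsWritten == destination.Length) return false;
            destination[idsWritten++] = i - start; // toy id: the word length
        }
        return true;
    }
}
```

A caller could then pass a stack-allocated or pooled buffer, for example `Span<int> ids = stackalloc int[16];`, and decide for itself how to react when the method returns false.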

If Normalize(ReadOnlySpan<char> original) returns string, wouldn't we do more allocations than needed for some normalizations?

Yes. We are designing the library for the generic use case; if there is normalization, we need allocations. We always return the normalized string so callers can make subsequent calls on it without further allocations or redoing the normalization. Anyone who wants to optimize and avoid the allocations can do the normalization manually before calling the tokenizer.

If the EncodeToIds overload with maxTokenCount has out string normalizedText, wouldn't that require us to construct and allocate the normalized string even if the caller might not be interested in it?

normalizedText is allocated only if the tokenizer was created with a normalizer object and the caller of EncodeToIds passes considerNormalization = true. The caller has full control over the operation.

Couldn't we do something like pretokenize -> try to encode -> on failure: normalize -> try to encode and thereby skip the normalization work for most of the input?

I don't think we can do that. You never know whether the encode failed. Encoding a string that is not normalized can succeed and return unexpected values. Users can choose to do that themselves by calling the tokenizer with considerNormalization = false, checking the result, and calling again with true if needed.
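
The caller-driven variant described above can be sketched as follows. This is an illustration under stated assumptions, not the library's behavior: `TwoPassSketch`, its tiny vocabulary, the `UnknownId` marker, and the "result contains an unknown id" check are all hypothetical. As noted above, an un-normalized encode can also succeed with unexpected values, so a real caller has to pick a check that fits its tokenizer.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for a tokenizer with an optional normalization step.
public static class TwoPassSketch
{
    public const int UnknownId = -1;

    // Toy vocabulary: only lowercase words are known.
    private static readonly Dictionary<string, int> Vocab = new()
    {
        ["hello"] = 0,
        ["world"] = 1,
    };

    // Mimics the shape EncodeToIds(text, considerNormalization: ...).
    public static List<int> EncodeToIds(string text, bool considerNormalization)
    {
        if (considerNormalization)
            text = text.ToLowerInvariant(); // the only "normalization" in this sketch

        var ids = new List<int>();
        foreach (string word in text.Split(' ', StringSplitOptions.RemoveEmptyEntries))
            ids.Add(Vocab.TryGetValue(word, out int id) ? id : UnknownId);
        return ids;
    }

    // First try without normalization; pay for it only when the caller's check fails.
    public static List<int> EncodeWithFallback(string text)
    {
        List<int> ids = EncodeToIds(text, considerNormalization: false);
        return ids.Contains(UnknownId)
            ? EncodeToIds(text, considerNormalization: true)
            : ids;
    }
}
```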

How would MapTokenToId handle tokens that map to multiple ids?

It doesn't map to multiple ids. It returns null if it cannot map the token to a single id. Users can call EncodeToIds if they need the full results.

Would it be able to handle a token that is only valid as a suffix, e.g. by calling it with "##suffix"?

This depends on the tokenizer. If the tokenizer supports that, it should be able to map it. Note that we support many tokenizers, so each tokenizer can decide how to map tokens to and from ids.

Can CountTokens (and IndexOfTokenCount and LastIndexOfTokenCount) have faster implementations than e.g. EncodeToIds().Count? Not sure, but I think for BERT the cost might be similar.

Yes. We don't implement CountTokens by calling EncodeToIds().Count; we try to optimize for such cases. Sometimes we need to create a cache, though, which helps subsequent calls and speeds them up further.

Padding input_ids is out of scope?
Combining multiple strings as a single model input (e.g. paragraph + question about paragraph) is out of scope?
Batching is out of scope (e.g. for parallelization)?
Encoding texts longer than maxTokenCount with a stride is out of scope?

For the current version, these are out of scope until we see enough demand for such features. We need to get the main features in first, which cover the majority of the scenarios we are seeing so far.

If you think I can contribute to/help with any of this please let me know :-)

Sure, we appreciate all the help!

CC @stephentoub

georg-jung commented on June 2, 2024

Hey @tarekgh,

Thank you for sharing the design! I've had the chance to take a look and have some thoughts - fwiw/probably a bit from a BERT perspective ;-).

  • IIRC, it's easiest to feed Memory<int>/<long> into onnxruntime. If EncodeToIds returns IReadOnlyList<int>, can we avoid copying the memory?
  • If Normalize(ReadOnlySpan<char> original) returns string, wouldn't we do more allocations than needed for some normalizations?
  • If the EncodeToIds overload with maxTokenCount has out string normalizedText, wouldn't that require us to construct and allocate the normalized string even if the caller might not be interested in it?
  • Couldn't we do something like pretokenize -> try to encode -> on failure: normalize -> try to encode and thereby skip the normalization work for most of the input?
    • Some normalizers require allocating a string (unicode normalization), others don't (lowercasing, stripping control chars). If a Tokenizer has one normalizer in its API, a possible way to implement a concrete normalizer could be e.g. BertNormalize = RemoveDiacritics(UnicodeNormalize(RemoveControlChars(Lowercase(input)))). Now consider an input that is mostly ascii/latin1/... chars but not lowercase. Wouldn't we pay for the allocations for the whole input, because UnicodeNormalize is part of the normalization, but lowercasing would often be sufficient and could work alloc-free on Span<char>?
    • If a Tokenizer had List<Normalizer> or something similar, wouldn't it be able to save allocations and normalization operations? E.g. by doing pretokenize -> try to encode -> on failure: normalize[0] -> try to encode -> on failure: normalize[1] -> try to encode -> ...?
  • How would MapTokenToId handle tokens that map to multiple ids? Would it be able to handle a token that is only valid as a suffix, e.g. by calling it with "##suffix"?
  • Can CountTokens (and IndexOfTokenCount and LastIndexOfTokenCount) have faster implementations than e.g. EncodeToIds().Count? Not sure, but I think for BERT the cost might be similar.
  • Padding input_ids is out of scope?
  • Combining multiple strings as a single model input (e.g. paragraph + question about paragraph) is out of scope?
  • Batching is out of scope (e.g. for parallelization)?
  • Encoding texts longer than maxTokenCount with a stride is out of scope?

If you think I can contribute to/help with any of this please let me know :-)

georg-jung commented on June 2, 2024

Thanks for the detailed response!

The challenge is that during encoding we need to store the produced ids, which requires allocating the array beforehand. We need to estimate the initial array size correctly to avoid allocating unused memory and to avoid reallocating if the array turns out to be too small.

An overload that writes to a passed-in Span would of course be easy in that regard, as it's then up to the caller :D. An approach I took with my BERT tokenizer is to re-use an internal buffer for subsequent calls, so the allocation cost becomes less relevant when encoding multiple documents.
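
As a rough illustration of the buffer-reuse idea (not the actual code of any existing tokenizer), the sketch below keeps one internal array and only grows it when an input needs more room. `ReusingEncoderSketch` and its toy encoding, one id per UTF-16 code unit, are hypothetical. The trade-off is that the returned memory is overwritten by the next call, so callers must consume or copy it first.

```csharp
using System;

// Hypothetical sketch: an encoder that reuses one internal id buffer across calls.
public sealed class ReusingEncoderSketch
{
    private int[] _ids = new int[16];

    // The returned memory is only valid until the next Encode call.
    public ReadOnlyMemory<int> Encode(ReadOnlySpan<char> text)
    {
        // Grow geometrically so repeated calls rarely allocate.
        if (_ids.Length < text.Length)
            _ids = new int[Math.Max(text.Length, _ids.Length * 2)];

        for (int i = 0; i < text.Length; i++)
            _ids[i] = text[i]; // toy id: the UTF-16 code unit

        return _ids.AsMemory(0, text.Length);
    }
}
```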

Maybe I'm mistaken, but regarding IReadOnlyList<int> specifically: if one wanted to pass the result as an input to onnxruntime, wouldn't one always need to write something similar to

var res = EncodeToIds(...);
// Zero-copy only if the list happens to be an array. List<int> has no AsMemory(), so every other case copies.
Memory<int> modelInput = res is int[] i ? i.AsMemory() : (int[])[..res];

if there is normalization, then we need to have allocations.

I was thinking of something like MemoryExtensions.ToLower or maybe RemoveControlAndReplacement(ReadOnlySpan<char> text, out ReadOnlySpan<char> cleaned) with a re-used internal buffer or similar. Couldn't then at least many normalizations be alloc-free, e.g. lowercasing, uppercasing, stripping control chars and, probably most important, the "no-op" normalization, where the input already is normalized according to the normalizer it is passed to?
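
A minimal sketch of the allocation-free idea for the lowercasing case, assuming the caller supplies a scratch buffer. `AllocFreeNormalizeSketch` and `Lowercase` are hypothetical names; the only framework call relied on is `MemoryExtensions.ToLowerInvariant(ReadOnlySpan<char>, Span<char>)`, which returns -1 when the destination is too small. The pre-check implements the "no-op" case: input that is already lowercase is returned as-is without touching the buffer.

```csharp
using System;

// Hypothetical sketch of an allocation-free lowercasing normalizer.
public static class AllocFreeNormalizeSketch
{
    // Returns the input itself when nothing needs to change;
    // otherwise lowercases into the caller-provided scratch buffer.
    public static ReadOnlySpan<char> Lowercase(ReadOnlySpan<char> text, Span<char> scratch)
    {
        bool needsWork = false;
        foreach (char c in text)
        {
            // Surrogates are conservatively sent down the slow path,
            // because a per-char check cannot judge supplementary characters.
            if (char.IsSurrogate(c) || char.ToLowerInvariant(c) != c)
            {
                needsWork = true;
                break;
            }
        }
        if (!needsWork) return text; // "no-op" normalization: no copy, no allocation

        int written = text.ToLowerInvariant(scratch); // -1 if scratch is too small
        if (written < 0)
            throw new ArgumentException("Scratch buffer is too small.", nameof(scratch));
        return scratch.Slice(0, written);
    }
}
```

Normalizations that can change the length of the text, such as Unicode normalization forms, do not fit this fixed-buffer shape as easily, which is part of the trade-off discussed in this thread.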

I'm in a bit of a hurry and will think about the other points soon... Thanks for always taking the time to discuss this, I think it is really interesting!

tarekgh commented on June 2, 2024

I think trying to optimize for normalization scenarios will complicate the tokenizer interfaces. What is proposed should be enough for users to decide how far they need to optimize. If they really need allocation-free normalization, they can do it themselves before calling the tokenizer. If they are OK with the allocation but want to avoid the allocations and processing on subsequent calls that use the normalized string, they can use considerNormalization = false. Trying to optimize for normalization while supporting all possible scenarios would be very challenging and would make the APIs more complicated than typical users want.

bartonjs commented on June 2, 2024

Video

  • We reshaped the abstracts and virtuals to reduce to one abstract member per method group. This entailed creating EncodeResults and EncodeSettings structs (which are currently only used in protected members).
  • We renamed Encode to EncodeToTokens so it didn't seem "better" than EncodeToIds
  • We renamed (Last)IndexOfTokenCount to GetIndexByTokenCount(FromEnd)
  • We were discussing the need for an OperationStatus-based Decode when we ran out of time (and did not make it beyond there)
namespace Microsoft.ML.Tokenizers
{
    public struct EncodeResults<T>
    {
        public IReadOnlyList<T> Tokens { get; set; }
        public string? NormalizedText { get; set; }
    }

    public struct EncodeSettings
    {
        public bool ConsiderNormalization { get; set; }
        public bool ConsiderPreTokenization { get; set; }
        public bool ProduceNormalizedString { get; set; }
        public int MaxTokenCount { get; set; }
    }

    public abstract partial class Tokenizer
    {
        protected Tokenizer() { }

        public virtual Normalizer? Normalizer { get { throw null; } }

        public virtual PreTokenizer? PreTokenizer { get { throw null; } }

        public IReadOnlyList<int> EncodeToIds(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);

        public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedText, out int textLength, bool considerPreTokenization = true, bool considerNormalization = true);

        protected abstract EncodeResults<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);

        public IReadOnlyList<Token> EncodeToTokens(string text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public IReadOnlyList<Token> EncodeToTokens(ReadOnlySpan<char> text, out string? normalizedString, bool considerPreTokenization = true, bool considerNormalization = true);

        protected abstract EncodeResults<Token> EncodeToTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);

        public int CountTokens(string text, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public int CountTokens(ReadOnlySpan<char> text, bool considerPreTokenization = true, bool considerNormalization = true);

        protected abstract int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings);

        public int GetIndexByTokenCount(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public int GetIndexByTokenCount(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        public int GetIndexByTokenCountFromEnd(string text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }
        public int GetIndexByTokenCountFromEnd(ReadOnlySpan<char> text, int maxTokenCount, out string? normalizedString, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { throw null; }

        protected abstract int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedString, out int tokenCount);

        public abstract string? Decode(IEnumerable<int> ids);
        public abstract OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, out int idsConsumed, out int charsWritten);

        public virtual int? MapTokenToId(string token) { throw null; }
        public abstract int? MapTokenToId(ReadOnlySpan<char> token);

        public abstract string? MapIdToToken(int? id);

       //
       // Factory methods
       // 

        public static Task<Tokenizer> CreateTiktokenAsync(Stream vocabStream, PreTokenizer? preTokenizer, Normalizer? normalizer, IReadOnlyDictionary<string, int> specialTokens = null,
                                                                                                  int cacheSize = 8192, Threading.CancellationToken cancellationToken = default) { throw null; }

        public static Task<Tokenizer> CreateTiktokenAsync(string vocabFilePath, PreTokenizer? preTokenizer, Normalizer? normalizer, IReadOnlyDictionary<string, int> specialTokensEncoder = null,
                                                                                                  int cacheSize = 8192, Threading.CancellationToken cancellationToken = default) { throw null; }

        public static Tokenizer CreateTiktokenForEncoding(string encodingName, IReadOnlyDictionary<string, int> extraSpecialTokens = null, Normalizer? normalizer = null) { throw null; }

        public static Tokenizer CreateTiktokenForModel(string modelName, IReadOnlyDictionary<string, int> extraSpecialTokens = null, Normalizer? normalizer = null) { throw null; }

        public static Tokenizer CreateTiktokenForModel(string modelName, Stream vocabStream, IReadOnlyDictionary<string, int> extraSpecialTokens = null, 
                                                                                                    int cacheSize = 8192, Normalizer? normalizer = null) { throw null; }

        public static Task<Tokenizer> CreateTiktokenForModelAsync(string modelName, Stream vocabStream, IReadOnlyDictionary<string, int> extraSpecialTokens = null,
                                                                                                   int cacheSize = 8192, Normalizer? normalizer = null, Threading.CancellationToken cancellationToken = default) { throw null; }

        public static Tokenizer CreateLlama(Stream modelStream, bool addBeginOfSentence = true, bool addEndOfSentence = false) { throw null; }

        public static Tokenizer CreateCodeGen(Stream vocabStream, Stream mergesStream, bool addPrefixSpace = false, bool addBeginOfSentence = false, bool addEndOfSentence = false) { throw null; }

        public static Tokenizer CreatePhi2(Stream vocabStream, Stream mergesStream, bool addPrefixSpace = false, bool addBeginOfSentence = false, bool addEndOfSentence = false) { throw null; }
    }
}
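
Since the review stopped while discussing the OperationStatus-based Decode, its exact semantics are not settled in this thread. The sketch below only illustrates how such a method is commonly consumed, under assumptions that are mine: `DecodeSketch`, its three-entry vocabulary, and the rule that `idsConsumed` and `charsWritten` count only fully decoded tokens are all hypothetical. `OperationStatus` itself is the existing `System.Buffers` enum.

```csharp
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Text;

// Hypothetical sketch, NOT the reviewed Tokenizer class.
public static class DecodeSketch
{
    private static readonly string[] Vocab = { "hello", " ", "world" };

    // Assumed semantics: writes whole tokens only and reports how far it got.
    public static OperationStatus Decode(IEnumerable<int> ids, Span<char> destination,
                                         out int idsConsumed, out int charsWritten)
    {
        idsConsumed = 0;
        charsWritten = 0;
        foreach (int id in ids)
        {
            if ((uint)id >= (uint)Vocab.Length) return OperationStatus.InvalidData;
            string token = Vocab[id];
            if (token.Length > destination.Length - charsWritten)
                return OperationStatus.DestinationTooSmall;
            token.AsSpan().CopyTo(destination.Slice(charsWritten));
            charsWritten += token.Length;
            idsConsumed++;
        }
        return OperationStatus.Done;
    }

    // Drives Decode with a small reusable chunk, resuming after each partial result.
    public static string DecodeAll(IReadOnlyList<int> ids, int chunkSize = 8)
    {
        var sb = new StringBuilder();
        char[] chunk = new char[Math.Max(1, chunkSize)];
        int offset = 0;
        while (true)
        {
            OperationStatus status = Decode(ids.Skip(offset), chunk, out int consumed, out int written);
            sb.Append(chunk, 0, written);
            offset += consumed;
            switch (status)
            {
                case OperationStatus.Done:
                    return sb.ToString();
                case OperationStatus.DestinationTooSmall:
                    // No progress means a single token is longer than the chunk: grow it.
                    if (consumed == 0) chunk = new char[chunk.Length * 2];
                    break;
                default:
                    throw new FormatException($"Cannot decode id at index {offset}.");
            }
        }
    }
}
```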
