
quaterion's Introduction

Quaterion

Blazing fast framework for fine-tuning Similarity Learning models


A dwarf on a giant's shoulders sees farther of the two

Quaterion is a framework for fine-tuning similarity learning models. The framework closes the "last mile" problem in training models for semantic search, recommendations, anomaly detection, extreme classification, matching engines, etc.

It is designed to combine the performance of pre-trained models with specialization for the custom task while avoiding slow and costly training.

Features

  • 🌀 Warp-speed fast: With the built-in caching mechanism, Quaterion enables you to train thousands of epochs with huge batch sizes even on a laptop GPU.

Regular vs Cached Fine-Tuning

  • 🐈 Small data compatible: Pre-trained models with specially designed head layers allow you to benefit even from a dataset you can label in one day.

  • 🏗️ Customizable: Quaterion allows you to re-define any part of the framework, making it flexible even for large-scale and sophisticated training pipelines.

  • 🌌 Scalable: Quaterion is built on top of PyTorch Lightning and inherits all its scalability, cost-efficiency, and reliability perks.

Installation

TL;DR:

For training:

pip install quaterion

For inference service:

pip install quaterion-models

Quaterion framework consists of two packages - quaterion and quaterion-models.

Since it is not always possible or convenient to represent a model in ONNX format (although it is supported), Quaterion keeps a minimal collection of model classes that might be required for model inference in a separate package.

This allows you to avoid installing heavy training dependencies into your inference infrastructure: pip install quaterion-models

At the same time, when you need the full arsenal of tools for training and debugging models, it is all available in one package: pip install quaterion
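
For example, a typical inference-only setup needs nothing but quaterion-models. A minimal sketch, assuming the model was exported during training with TrainableModel.save_servable() and that the package exposes SimilarityModel with load()/encode() as described in its docs:

from quaterion_models import SimilarityModel

# Load the servable model (encoders + head only, no training machinery).
model = SimilarityModel.load("path/to/saved/model")

# Encode raw objects into embeddings ready for a vector search engine.
embeddings = model.encode(["first raw object", "second raw object"])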

Docs 📓

For a more in-depth dive, check out our end-to-end tutorials:

Tutorials for advanced features of the framework:

Community

License

Quaterion is licensed under the Apache License, Version 2.0. View a copy of the License file.


quaterion's Issues

Generate documentation from branch in netlify for preview

Currently, documentation is generated based on the published package.
In this case, a preview on pull requests does not make sense.
To fix this, we need to update docs/generate_docs_netlify.sh and generate the docs based on the code in the repository.

Introduce CLI command to create a project template

Based on the discussion in #38

  • quaterion new project-name to create a basic template with cookie-cutter.
  • This template may include the basic proper structure, e.g., files such as encoders.py, training.py, inference.py and pyproject.toml.
  • It may be helpful for easy experiments, a faster way from experiment to deployment, reproducible research, and reusable and shareable work, among others.

Step-by-step tutorial with image dataset

We need to create a tutorial for users who have never used Quaterion and want to go through it step-by-step.

In the tutorial we need to:

  • Describe the use case of Quaterion - why not just use e.g. pytorch-metric-learning
  • Use a proper project layout - encoders are defined in a separate module so that they can be used after training
  • Describe which components serve which purposes
    • why we use PyTorch Lightning - what is our responsibility and what is theirs (with links to their docs)
    • what the difference is between a trainable model and a MetricModel
    • why we need a separate quaterion-models
    • how to properly initialize pre-trained encoders

Choose sensible defaults for losses

Currently, ContrastiveLoss has a default margin value of 1.0 with cosine as a default distance metric, which does not make sense. We need to choose sensible defaults for losses.

Do not autogenerate documentation on each commit

Currently, documentation is generated on each commit, which leads to excess files in the pull request diff and makes it harder to determine which files should be reviewed.

One possible approach could be to restrict the build job by placing an if: github.event.pull_request.merged == 'true' statement right after it. The workflow would then be triggered in the same cases as it is now, but if the pull request is not merged, no jobs will be launched.

Another approach is to trigger the workflow on push to specified branches instead of on pull_request.
E.g. every time we push to master or to docs* branches, documentation will be generated.

Discussion regarding this problem.

Bug: Caching side effect's on multi-worker DataLoader on Windows

Using a multi-worker DataLoader together with caching sets the DataLoader's multiprocessing context to fork on Windows:

train_loader = PairsSimilarityDataLoader(..., num_workers=2)

Then, when subclassing TrainableModel:

    def configure_caches(self) -> CacheConfig:
        return CacheConfig(CacheType.AUTO, batch_size=32)

And the error message is:

  File "C:\Users\Yusuf\coding\quaterion\quaterion\train\cache_mixin.py", line 261, in _wrap_cache_dataloader
    cls._switch_multiprocessing_context(dataloader)
  File "C:\Users\Yusuf\coding\quaterion\quaterion\train\cache_mixin.py", line 321, in _switch_multiprocessing_context
    dataloader.multiprocessing_context = cls.CACHE_MULTIPROCESSING_CONTEXT
  File "C:\ProgramData\Anaconda3\envs\torch\lib\site-packages\torch\utils\data\dataloader.py", line 342, in __setattr__
    super(DataLoader, self).__setattr__(attr, val)
  File "C:\ProgramData\Anaconda3\envs\torch\lib\site-packages\torch\utils\data\dataloader.py", line 318, in multiprocessing_context
    raise ValueError(
ValueError: multiprocessing_context option should specify a valid start method in ['spawn'], but got multiprocessing_context='fork'

Do we really need such a side effect?
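
Until this is resolved, one possible workaround (an untested sketch, assuming the context switch only happens when worker processes are used) is to fall back to a single-process DataLoader on Windows; train_dataset is a placeholder for your pair dataset:

import sys

from quaterion.dataset.similarity_data_loader import PairsSimilarityDataLoader

# "fork" is unavailable on Windows, so avoid worker processes there.
num_workers = 0 if sys.platform == "win32" else 2
train_loader = PairsSimilarityDataLoader(train_dataset, batch_size=32, num_workers=num_workers)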

Add pair and triplet samplers

Why?

  • Most loss functions work on pairs or triplets.
  • Many datasets do not offer ready-to-use pairs and/or triplets, so they need to be formed on the fly.

How?

  • PyTorch has a built-in torch.utils.data.Sampler class that can be passed to a DataLoader alongside a Dataset. We may subclass it to form pairs and triplets; see the sketch below.
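
A rough, framework-agnostic sketch of such a sampler (class and attribute names are illustrative, not Quaterion API): given per-item labels, it yields indices so that every two consecutive indices share a label and can be collated into (anchor, positive) pairs.

import random
from collections import defaultdict
from typing import Iterator, List

from torch.utils.data import Sampler


class PairSampler(Sampler):
    """Yields dataset indices so that consecutive pairs share the same label."""

    def __init__(self, labels: List[int]):
        self.index_by_label = defaultdict(list)
        for idx, label in enumerate(labels):
            self.index_by_label[label].append(idx)

    def __iter__(self) -> Iterator[int]:
        pairs = []
        for indices in self.index_by_label.values():
            random.shuffle(indices)
            # an odd leftover element is simply dropped
            pairs.extend(zip(indices[::2], indices[1::2]))
        random.shuffle(pairs)
        for anchor, positive in pairs:
            yield anchor
            yield positive

    def __len__(self) -> int:
        return 2 * sum(len(indices) // 2 for indices in self.index_by_label.values())

A triplet variant would additionally pick a negative index from a different label group for each pair.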

[docs] Caching tutorial

  • Detailed explanation of an example which uses caching.
  • Visual proof that caching improves speed significantly (e.g. a side-by-side gif screencast comparison)
  • Explanation of caching limits and advanced options

Implement a `nn.Sequential`-like head layer

This will allow users to compose custom head layers easily by appending arbitrary nn.Modules to it.

For example, they will be able to add a dropout before the actual head layer, and so on.
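
A rough sketch of the idea, not the final API (a real implementation would likely subclass the head base class from quaterion-models so it remains servable):

import torch
from torch import nn


class SequentialLikeHead(nn.Module):
    """Composes arbitrary modules into a single head layer."""

    def __init__(self, *layers: nn.Module, output_size: int):
        super().__init__()
        self.layers = nn.Sequential(*layers)
        self.output_size = output_size

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        return self.layers(embeddings)


# e.g. a dropout applied before the actual projection layer
head = SequentialLikeHead(nn.Dropout(p=0.1), nn.Linear(768, 256), output_size=256)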

[metrics] Handle metrics updates inside TrainableModel

Currently, it is required to write the metrics-update logic manually in a custom implementation of TrainableModel.
However, a large group of metrics is updated in a similar way, so it is possible to relieve the user of this.

Proposition

Implement metrics configuration similarly to loss configuration: create an abstract method configure_metrics in the base class.
If the overridden configure_metrics returns a non-empty dict, TrainableModel should register the metric updates in the related hooks; see the sketch below.
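
An illustrative sketch of what the proposed API could look like (the metric class names are placeholders, not existing Quaterion classes):

from quaterion.train.trainable_model import TrainableModel


class MyModel(TrainableModel):
    def configure_metrics(self) -> dict:
        # TrainableModel would call .update() and .compute() on these objects
        # in the relevant training/validation hooks automatically.
        return {
            "rp@1": RetrievalPrecision(k=1),    # placeholder metric class
            "mrr": RetrievalReciprocalRank(),   # placeholder metric class
        }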

Caching may have additional side effects

Currently, TrainableModel.unwrap_cache() is called at the end of Quaterion.fit, but this may have unpredictable side effects.

For example, when we want to re-use data loaders after training just like here on line 89, it won't work anymore.

We may need to let users skip unwrapping with a boolean keyword argument, but in any case we need to document caching and its side effects very clearly and in a visible way. It took me hours to realize this 😢

Separate cache dataloader

Create an internal dataloader for caching to reduce repeated calculations.

E.g., the dataloader provides batches of the following form:

[
    [anchor, positive],
    [anchor, negative_0],
    [anchor, negative_1],
    ...,
    [anchor, negative_n]
]

Assume these negative samples are random ones; then they can appear in batches with different anchors, and each time we would need to recalculate their embeddings, which may be really costly.

Another example is when you want to perform hard-negative mining and, to do this, you want to search through the whole dataset. If your dataset is small enough, you can load it into one batch. It will then be a batch of the form shown in the first example, but the number of negatives will be N - 1, where N is the size of the dataset. You would need to calculate embeddings N^2 times, which is really expensive.

To avoid these extra calculations, a separate internal dataloader should be implemented, via which the cache can be filled in linear time.

Bug in caching: Boolean value of tensor comparison is ambiguous

I think this is an edge case. When we do if sample.obj not in unique_objects in fetch_unique_objects of the data loader classes, the comparison may return a tensor of Trues and Falses if the individual elements of the tensor (sample.obj) are equal at some indexes and not at others. When that tensor is used in a conditional expression, Python tries to reduce it to a single boolean value, but a combination of multiple Trues and Falses is ambiguous, so it throws a runtime error:

...
  File "C:\Users\Yusuf\coding\quaterion\quaterion\dataset\cache_data_loader.py", line 68, in cache_collate_fn
    unique_objects = self.unique_objects_extractor(batch)
  File "C:\Users\Yusuf\coding\quaterion\quaterion\dataset\similarity_data_loader.py", line 42, in fetch_unique_objects
    if sample.obj_b not in unique_objects:
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
Predicting:   0%|          | 0/387 [00:00<?, ?it/s]

Possible solution

  1. Create another container for unique hashes in addition to unique_objects, e.g. unique_hashes.
  2. If hash(obj) is not in the new container, add the hash there and add the object to unique_objects.
  3. Return unique_objects as usual.
    We could even return the list of hashes for further use later, maybe. A sketch of this approach follows below.
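
A sketch of the proposed fix (attribute names taken from the pair dataloader; the exact integration point may differ):

def fetch_unique_objects(batch):
    # Track hashes instead of comparing objects directly, so tensor-valued
    # objects never go through `in` / `==`, which is ambiguous for tensors.
    unique_objects = []
    unique_hashes = set()
    for sample in batch:
        for obj in (sample.obj_a, sample.obj_b):
            obj_hash = hash(obj)
            if obj_hash not in unique_hashes:
                unique_hashes.add(obj_hash)
                unique_objects.append(obj)
    return unique_objects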

[metrics] Integrate metric calculation into Quaterion

Allow the user to specify "nearest-neighbor-based accuracy metrics" in TrainableModel, similarly to loss functions.

Similar to https://kevinmusgrave.github.io/pytorch-metric-learning/testers/

Variant examples:

  • For group samples - encode the train dataset with the current version of the model and evaluate whether validation samples are properly classified by retrieval (rp@1)
  • For the pair sampler - encode one object of each pair and, on the second run, estimate precision@N

Perform metrics calculation during the eval epoch.

Explicitly set 'fork' start method in multiprocessing for cache

Basically, to obtain keys for the cache we apply hash to some hashable input.

When we iterate through the dataloader, it starts new worker processes, in which we have to produce the same hash values as in the parent process.

The spawn method does not inherit the internal seed that is crucial for generating the randomized salt used when hashing str and bytes objects, but the fork method does inherit it. (doc)

Since Python 3.8, the default process start method in multiprocessing on macOS is spawn (doc), so we need to explicitly replace it with fork when the cache is utilized.
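
A minimal sketch of the intended fix, assuming we only touch dataloaders that actually use worker processes:

import torch
from torch.utils.data import DataLoader


def enforce_fork_context(dataloader: DataLoader) -> None:
    # Workers started with "fork" inherit the parent's hash randomization seed,
    # so hash() of str/bytes keys stays consistent across processes.
    # Caveat: "fork" is not available on Windows (see the issue above).
    if dataloader.num_workers > 0:
        dataloader.multiprocessing_context = torch.multiprocessing.get_context("fork")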

Roadmap to the initial public release

Updating this one as per @generall's comments.

  • Implement evaluation methods and metrics such as MRR etc. WIP #61
  • Create a step-by-step getting-started tutorial for documentation. Needs extra work, but we already have some basic examples.
  • Standardize docstrings in a well known format, and generate API docs as a part of CI/CD.
  • Start a changelog, and keep it up-to-date. This one needs attention.
  • Proofread the documentation, and fix minor issues.
  • Finalize #10, and merge it.
  • Tagline issue, #53.

Please add any other issues that you consider relevant. I'm also pinning this issue to prioritize it.

Synchronize model checkpointing and servable model saving

Currently, TrainableModel.save_servable() is called by the user at the end of the training loop. This is problematic because we may end up saving an overfitted state of the model even if we are trying to monitor an evaluation metric with pl.callbacks.ModelCheckpoint. So we need to come up with a way to synchronize the two.

Possible solution

  • We may need to subclass ModelCheckpoint inside quaterion for synchronization; a sketch of such a callback follows below.
  • We may accept additional keyword arguments in Quaterion.fit to automatically save servable checkpoints to a specified directory at a specified interval.
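
A rough sketch of the first option (the hook choice and attribute access are illustrative; save_servable() is the existing Quaterion method):

from pytorch_lightning.callbacks import ModelCheckpoint


class ServableModelCheckpoint(ModelCheckpoint):
    """Also exports a servable model whenever the monitored metric improves."""

    def __init__(self, trainable_model, servable_dir: str, **kwargs):
        super().__init__(**kwargs)
        self.trainable_model = trainable_model
        self.servable_dir = servable_dir

    def on_validation_end(self, trainer, pl_module):
        previous_best = self.best_model_score
        super().on_validation_end(trainer, pl_module)
        improved = self.best_model_score is not None and (
            previous_best is None or bool(self.best_model_score != previous_best)
        )
        if improved:
            # the best checkpoint changed -> export a servable snapshot too
            self.trainable_model.save_servable(self.servable_dir)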

Experiments with Triplet Loss to prevent vector space collapsing

Margin-based losses have a tendency toward vector space collapse. We introduced a simple trick in #44, but it might be interesting to try more sophisticated tricks, such as encouraging variance in the embeddings or in the distance matrix directly. It might turn out to be a research article as well.

[docs] Readme

  • Badges
  • Quick intro
  • Motivation and comparison with other approaches
  • Features list
  • Installation / components (quaterion-models)
  • Quick Start
  • Link to docs, other products, community

Review and update docstrings

Currently, some docstrings are not descriptive enough and should be reviewed and updated as necessary.

Core entities should contain exhaustive examples.
Examples of such entities could be TrainableModel or Quaterion and their corresponding methods.

Enforce code style of the project

  • Add CI/CD for tests + linter

  • Use type annotations in all functions

  • For NN model-related code:

Add a shape comment for each tensor transformation (see the illustration below) - example: https://github.com/generall/EntityCategoryPrediction/blob/master/category_prediction/model/multi_head_attention.py#L92

It might seem like overkill, but to me it looks like a game changer. I only decided to use AllenNLP because they were head and shoulders above their competitors from fast.ai (terrible, terrible project!).
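
A small, project-agnostic illustration of the convention:

import torch
from torch import nn


class DotProductAttention(nn.Module):
    def forward(self, query: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
        # query: (batch, dim), keys: (batch, seq_len, dim)
        scores = torch.einsum("bd,bsd->bs", query, keys)     # (batch, seq_len)
        weights = torch.softmax(scores, dim=-1)               # (batch, seq_len)
        context = torch.einsum("bs,bsd->bd", weights, keys)   # (batch, dim)
        return context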

Update documentation content and formatting before release

@generall @monatis let's collect all things which should be done in documentation before release.

I think it would be better if this issue contains a bunch of comments with task lists like this:

Epic:

  • Create start page
  • Configure links to quaterion-models doc (probably could be done via extension)

Routine:

  • Hide indexing dataset module page
  • Add params description to TrainCollater
  • Update formatting for note in multiple_negatives_ranking_loss
  • Make mentions of quaterion members in docstrings clickable. An example of a mention to fix is SiameseDistanceMetric in pairwise_loss
  • Get rid of excess hyphens and line breaks, e.g. in GroupSimilarityDataLoader.collate_labels
  • Try to find a way to get rid of inherited documented class members like training: bool in contrastive_loss (this seems to me like a bug in Sphinx)
  • Try to get rid of duplicate generation as in quaterion.train.cache.html and quaterion.train.cache.cache_config.html.
  • Provide links to custom type definitions in signatures
  • Document constants
  • Fix the TrainCollater signature (unhashable type: TypeAliasForwardRef). Seems like a bug in Sphinx.

Optional:

  • Look into Sphinx warnings and fix those that are possible to fix

Optionally cache encoder outputs

Encoder outputs should be optionally cached to make training faster and more memory-efficient.

  • Implement an abstract class for caching.
  • Subclass it to implement a simple in-memory cache.
  • Other storage options might be mmapped Numpy files or HDF5.
  • It should be possible to turn caching on or off with a boolean flag; a sketch of a possible interface follows below.
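
A rough sketch of a possible interface (names are illustrative): an abstract cache plus a trivial in-memory implementation; mmapped NumPy or HDF5 backends would implement the same interface.

from abc import ABC, abstractmethod
from typing import Dict, Hashable

import torch


class EmbeddingCache(ABC):
    @abstractmethod
    def put(self, key: Hashable, embedding: torch.Tensor) -> None:
        ...

    @abstractmethod
    def get(self, key: Hashable) -> torch.Tensor:
        ...

    @abstractmethod
    def __contains__(self, key: Hashable) -> bool:
        ...


class InMemoryCache(EmbeddingCache):
    def __init__(self):
        self._storage: Dict[Hashable, torch.Tensor] = {}

    def put(self, key: Hashable, embedding: torch.Tensor) -> None:
        self._storage[key] = embedding.detach().cpu()

    def get(self, key: Hashable) -> torch.Tensor:
        return self._storage[key]

    def __contains__(self, key: Hashable) -> bool:
        return key in self._storage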

warning during test

There is a PyTorch Lightning warning reproducible via tests:

tests/cache/test_cache_dataloader.py::test_cache_dataloader
  /home/generall/projects/vector_search/quaterion/venv/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py:175: UserWarning: Lightning couldn't infer the indices fetched for your dataloader.
    warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")

Come up with a better tagline / motto

The current one, "A dwarf on a giant's shoulders sees farther of the two," is quite generic and a cliché. I suggest coming up with a new tagline that better communicates the overall purpose of the package.

Make imports less verbose

  • Import all public classes in each module's __init__.py
    • Reason: Multiple classes from the same module should be importable in one line.
  • Create the pl.Trainer instance in quaterion.Quaterion.fit, and accept keyword arguments to do so (see the sketch below).
    • Reason: Currently, pl.Trainer is created by the user just to pass it to quaterion.Quaterion.fit, and we can merge the two calls into one.
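
An illustrative sketch of the desired call (argument names such as trainer_kwargs are not final; model and the dataloaders are assumed to be defined elsewhere):

from quaterion import Quaterion

Quaterion.fit(
    trainable_model=model,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    # forwarded to an internally created pl.Trainer
    trainer_kwargs=dict(max_epochs=10, accelerator="gpu", devices=1),
)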
