ChartInstruct: Instruction Tuning for Chart Comprehension and Reasoning

  • Authors: Ahmed Masry, Mehrad Shahmohammadi, Md Rizwan Parvez, Enamul Hoque, Shafiq Joty (*equal contribution)
  • Paper Link: ChartInstruct
  • Venue: ACL 2024 Findings

ChartInstruct Model Checkpoints

We release the checkpoints for our pretrained models on Hugging Face.

Model                      Checkpoint Path
ChartInstruct-Llama2       ahmed-masry/ChartInstruct-LLama2
ChartInstruct-Flan-T5-XL   ahmed-masry/ChartInstruct-FlanT5-XL

IMPORTANT: Please note that we have changed the alignment module from a linear layer (as described in the paper) to a 2-layer MLP to improve compatibility with Hugging Face's LLaVA codebase. This makes our models easy to run and finetune with a few lines of code, as shown in the Inference section.
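
To give a rough sense of what the switch from a linear layer to a 2-layer MLP means in size, the sketch below counts the parameters of both alignment modules. The dimensions (a 1024-dimensional vision feature and a 4096-dimensional language-model hidden size) and the MLP layout (vision to text, then text to text) are assumptions for illustration, not values read from the released checkpoints.

```python
def linear_params(d_in: int, d_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return d_in * d_out + d_out


def projector_params(d_vision: int, d_text: int, two_layer_mlp: bool) -> int:
    """Parameter count of the alignment module.

    Assumed layouts (illustrative only):
      linear layer: d_vision -> d_text
      2-layer MLP:  d_vision -> d_text -> d_text (the activation adds no parameters)
    """
    first = linear_params(d_vision, d_text)
    if not two_layer_mlp:
        return first
    return first + linear_params(d_text, d_text)


# Hypothetical sizes, chosen only to make the comparison concrete.
D_VISION, D_TEXT = 1024, 4096

print(projector_params(D_VISION, D_TEXT, two_layer_mlp=False))  # 4198400
print(projector_params(D_VISION, D_TEXT, two_layer_mlp=True))   # 20979712
```

Under these assumed sizes the MLP is about five times larger than the single linear layer, which is still small next to the language model itself.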

Web Demo

If you wish to quickly try our models, you can access our public web demos, hosted on the Hugging Face Spaces platform with a friendly interface!

Model                      Web Demo
ChartInstruct-Llama2       ChartInstruct-Llama2
ChartInstruct-Flan-T5-XL   ChartInstruct-FlanT5-XL

Inference

You can easily use our models for inference with the Hugging Face Transformers library! You just need to do the following:

  • Change image_path to the path of your chart image on your system.

  • Write the input_text.

We recommend using beam search with a beam size of 4 for better results, but if your machine's GPU has low memory, you can remove num_beams from the generate call.
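
As a small sketch of the two choices described above, the helpers below build the prompt string and the keyword arguments for generate. The function names build_prompt and generation_kwargs are hypothetical and not part of the released code; the prompt template and the values 4 and 512 are the ones used in the inference snippets in this README.

```python
def build_prompt(question: str) -> str:
    """Wrap a question in the prompt template used by the inference snippets."""
    return f"<image>\n Question: {question} Answer: "


def generation_kwargs(low_memory: bool = False, max_new_tokens: int = 512) -> dict:
    """Return keyword arguments for model.generate().

    Beam search with 4 beams is the recommended default; on a
    memory-constrained GPU, num_beams is simply left out.
    """
    kwargs = {"max_new_tokens": max_new_tokens}
    if not low_memory:
        kwargs["num_beams"] = 4
    return kwargs


print(repr(build_prompt("What is the highest value?")))
print(generation_kwargs())                 # recommended setting
print(generation_kwargs(low_memory=True))  # fallback without beam search
```

The resulting dictionary can be unpacked into the call, for example model.generate(**inputs, **generation_kwargs(low_memory=True)).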

ChartInstruct LLama2

from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
import torch

torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png', 'chart_example_1.png')

image_path = "chart_example_1.png"  # path of the file downloaded above; change it to your own chart image
input_text = "What is the share of respondents who prefer Whatsapp in the 18-29 age group?"

input_prompt = f"<image>\n Question: {input_text} Answer: "

device = "cuda" if torch.cuda.is_available() else "cpu"
# Use fp16 on GPU; fall back to fp32 on CPU, where some fp16 kernels (e.g. LayerNorm) may not be implemented.
dtype = torch.float16 if device == "cuda" else torch.float32

model = LlavaForConditionalGeneration.from_pretrained("ahmed-masry/ChartInstruct-LLama2", torch_dtype=dtype)
processor = AutoProcessor.from_pretrained("ahmed-masry/ChartInstruct-LLama2")
model.to(device)

image = Image.open(image_path).convert('RGB')
inputs = processor(text=input_prompt, images=image, return_tensors="pt")
# Move all input tensors to the same device as the model.
inputs = {k: v.to(device) for k, v in inputs.items()}

# Cast pixel_values to the model dtype; input_ids and attention_mask stay integer tensors.
inputs['pixel_values'] = inputs['pixel_values'].to(dtype)
prompt_length = inputs['input_ids'].shape[1]

# Generate
generate_ids = model.generate(**inputs, num_beams=4, max_new_tokens=512)
output_text = processor.batch_decode(generate_ids[:, prompt_length:], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(output_text)

Does your GPU have low memory? Is the above code slow on your machine? We've got you covered! Use the following code, which loads a quantized version of the model. Just make sure to install the following pip packages: bitsandbytes, bitsandbytes-cuda112, accelerate

from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
import torch

torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png', 'chart_example_1.png')

image_path = "chart_example_1.png"  # path of the file downloaded above; change it to your own chart image
input_text = "What is the share of respondents who prefer Whatsapp in the 18-29 age group?"

input_prompt = f"<image>\n Question: {input_text} Answer: "

# 4-bit NF4 quantization with fp16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16
)

model = LlavaForConditionalGeneration.from_pretrained("ahmed-masry/ChartInstruct-LLama2", torch_dtype=torch.float16, quantization_config=bnb_config)
processor = AutoProcessor.from_pretrained("ahmed-masry/ChartInstruct-LLama2")

image = Image.open(image_path).convert('RGB')

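inputs = processor(text=input_prompt, images=image, return_tensors="pt")
# Move the inputs to the device the quantized model was loaded on.
inputs = {k: v.to(model.device) for k, v in inputs.items()}

# Cast pixel_values to fp16 to match the compute dtype.
inputs['pixel_values'] = inputs['pixel_values'].to(torch.float16)
prompt_length = inputs['input_ids'].shape[1]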


# Generate
generate_ids = model.generate(**inputs, num_beams=4, max_new_tokens=512)
output_text = processor.batch_decode(generate_ids[:, prompt_length:], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(output_text)
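
To see why 4-bit loading helps on a low-memory GPU, the following back-of-the-envelope sketch estimates the memory taken by the language-model weights alone. The parameter count of roughly 7 billion is an assumption for illustration; the estimate ignores the vision encoder, activations, the beam search cache, and quantization overhead, so real usage will be higher.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate memory for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 7e9  # assumption: roughly 7 billion parameters

fp16_gb = weight_memory_gb(NUM_PARAMS, 16)
four_bit_gb = weight_memory_gb(NUM_PARAMS, 4)

print(f"fp16 weights:  ~{fp16_gb:.1f} GB")
print(f"4-bit weights: ~{four_bit_gb:.1f} GB")
```

Under that assumption, the fp16 weights come close to filling a 16 GB card on their own, while the 4-bit weights leave most of that memory free.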

ChartInstruct Flan-T5

from PIL import Image
from transformers import AutoProcessor, AutoModelForSeq2SeqLM
import torch

torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png', 'chart_example_1.png')

image_path = "chart_example_1.png"  # path of the file downloaded above; change it to your own chart image
input_text = "What is the share of respondents who prefer Whatsapp in the 18-29 age group?"

input_prompt = f"<image>\n Question: {input_text} Answer: "

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Use fp16 on GPU; fall back to fp32 on CPU, where some fp16 kernels (e.g. LayerNorm) may not be implemented.
dtype = torch.float16 if device.type == "cuda" else torch.float32

model = AutoModelForSeq2SeqLM.from_pretrained("ahmed-masry/ChartInstruct-FlanT5-XL", torch_dtype=dtype, trust_remote_code=True)
processor = AutoProcessor.from_pretrained("ahmed-masry/ChartInstruct-FlanT5-XL")
model.to(device)

image = Image.open(image_path).convert('RGB')

inputs = processor(text=input_prompt, images=image, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

# Cast pixel_values to the model dtype; input_ids and attention_mask stay integer tensors.
inputs['pixel_values'] = inputs['pixel_values'].to(dtype)

# Generate
generate_ids = model.generate(**inputs, num_beams=4, max_new_tokens=512)
output_text = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(output_text)
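
Note that the LLama2 snippets decode generate_ids[:, prompt_length:] while the Flan-T5 snippet decodes generate_ids directly. A decoder-only model returns the prompt tokens followed by the new tokens, whereas an encoder-decoder model returns only the decoder's output. The toy sketch below illustrates the difference with plain Python lists; the token IDs are made up and strip_prompt is a hypothetical helper.

```python
def strip_prompt(generated_ids: list, prompt_length: int) -> list:
    """Drop the first prompt_length tokens of every sequence in the batch."""
    return [sequence[prompt_length:] for sequence in generated_ids]


prompt_ids = [101, 102, 103]  # made-up prompt token IDs
new_tokens = [7, 8, 9]        # made-up generated token IDs

# Decoder-only output (LLama2 style): prompt followed by the new tokens.
decoder_only_output = [prompt_ids + new_tokens]
# Encoder-decoder output (Flan-T5 style): only the decoder's tokens.
seq2seq_output = [new_tokens]

print(strip_prompt(decoder_only_output, len(prompt_ids)))  # [[7, 8, 9]]
print(seq2seq_output)                                      # [[7, 8, 9]]
```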

Finetuning

ChartInstruct LLama2

Check out the example Colab notebook in the repo, which shows how to finetune the model on the ChartQA dataset. The training code is optimized so that you can train on a T4 GPU, which is free on Colab. The notebook has three different setups: LoRA, QLoRA, and full finetuning. You can switch between them based on your machine's GPU. This notebook was adapted from NielsRogge Tutorials.
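
As a rough illustration of why the LoRA and QLoRA setups fit on smaller GPUs than full finetuning, the sketch below compares the trainable parameters of one weight matrix under both approaches. The 4096 x 4096 matrix size and the rank of 8 are assumptions for illustration and are not taken from the notebook.

```python
def full_finetune_params(d_in: int, d_out: int) -> int:
    """Full finetuning updates every entry of the d_out x d_in weight matrix."""
    return d_in * d_out


def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """LoRA trains two low-rank factors: A (rank x d_in) and B (d_out x rank)."""
    return rank * (d_in + d_out)


# Hypothetical layer size and rank, chosen only for illustration.
D_IN, D_OUT, RANK = 4096, 4096, 8

full = full_finetune_params(D_IN, D_OUT)
lora = lora_trainable_params(D_IN, D_OUT, RANK)

print(full)                 # 16777216
print(lora)                 # 65536
print(f"{lora / full:.2%}")  # 0.39%
```

QLoRA keeps the same small set of trainable factors and additionally stores the frozen base weights in 4-bit precision.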

Contact

If you have any questions about this work, please contact Ahmed Masry using the following email addresses: [email protected] or [email protected].

Reference

Please cite our paper if you use our models in your research.

@misc{masry2024chartinstruct,
      title={ChartInstruct: Instruction Tuning for Chart Comprehension and Reasoning}, 
      author={Ahmed Masry and Mehrad Shahmohammadi and Md Rizwan Parvez and Enamul Hoque and Shafiq Joty},
      year={2024},
      eprint={2403.09028},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

chartinstruct's Issues

dataset release?

I am sorry if I missed it somewhere, but I could not find any information about your data release.

Could you please point me to where the data is? If it has not been released yet, would you mind sharing the plan for the release?

Thank you very much!

Error while running locally on CPU host

For context, this is a duplicate of a post on HF: https://huggingface.co/spaces/ahmed-masry/ChartInstruct-LLama2/discussions/1
Posting here as well in case you're more active on GitHub.

First off, thanks for this project! It works great and I'm looking forward to the ChartInstruct-Flan-T5-XL model.

The demo was working until a couple days ago and now Hugging Face shows "Error" for queries without much further details.

I also tried running the app.py file locally by downloading the model but I get the following error.

Does this model have to be run on a GPU?

  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/gradio/queueing.py", line 541, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/gradio/route_utils.py", line 276, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/gradio/blocks.py", line 1928, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/gradio/blocks.py", line 1514, in call_function
    prediction = await anyio.to_thread.run_sync(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/anyio/_backends/_asyncio.py", line 2177, in run_sync_in_worker_thread
    return await future
           ^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/anyio/_backends/_asyncio.py", line 859, in run
    result = context.run(func, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/gradio/utils.py", line 833, in wrapper
    response = f(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^
  File "/Users/shivi/Library/CloudStorage/[email protected]/My Drive/Waverly AI/ChartInstruct-LLama2/app.py", line 34, in predict
    generate_ids = model.generate(**inputs, num_beams=4, max_new_tokens=512)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/transformers/generation/utils.py", line 1953, in generate
    result = self._beam_search(
             ^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/transformers/generation/utils.py", line 2914, in _beam_search
    outputs = self(
              ^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/transformers/models/llava/modeling_llava.py", line 424, in forward
    image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/transformers/models/donut/modeling_donut_swin.py", line 965, in forward
    embedding_output, input_dimensions = self.embeddings(
                                         ^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/transformers/models/donut/modeling_donut_swin.py", line 210, in forward
    embeddings = self.norm(embeddings)
                 ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/torch/nn/modules/normalization.py", line 201, in forward
    return F.layer_norm(
           ^^^^^^^^^^^^^
  File "/usr/local/anaconda3/envs/chart_instruct_testing/lib/python3.12/site-packages/torch/nn/functional.py", line 2546, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: "LayerNormKernelImpl" not implemented for 'Half'

I would appreciate any support and would be happy to answer any further questions.

Error in FineTuning

The processor does not return image sizes.
  File "/home/ec2-user/anaconda3/envs/gcc/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/SageMaker/fineTune_chart_new.py", line 241, in train_collate_fn
    image_sizes = batch["image_sizes"]
                  ~~~~~^^^^^^^^^^^^^^^
  File "/home/ec2-user/anaconda3/envs/gcc/lib/python3.12/site-packages/transformers/feature_extraction_utils.py", line 87, in __getitem__
    return self.data[item]
           ~~~~~~~~~^^^^^^
KeyError: 'image_sizes'
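
One possible reading of this traceback, offered as an assumption rather than a confirmed diagnosis: the collate function expects an image_sizes entry, which some processors return (for example, the LLaVA-NeXT one) and others do not. A defensive sketch is to forward image_sizes only when it is present. The helper name select_model_inputs and the key lists are hypothetical, and a plain dict stands in for the processor output.

```python
def select_model_inputs(
    batch,
    required=("input_ids", "attention_mask", "pixel_values"),
    optional=("image_sizes",),
):
    """Pick the model inputs from a processor output, tolerating missing optional keys."""
    missing = [key for key in required if key not in batch]
    if missing:
        raise KeyError(f"processor output is missing required keys: {missing}")
    inputs = {key: batch[key] for key in required}
    # Optional keys are forwarded only if the processor produced them.
    inputs.update({key: batch[key] for key in optional if key in batch})
    return inputs


# Stand-in for a processor output without image_sizes (made-up values).
batch = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]], "pixel_values": [[0.0]]}
print(sorted(select_model_inputs(batch)))  # ['attention_mask', 'input_ids', 'pixel_values']
```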
