cartoon-caption-generation's Introduction

Humor in AI: Massive Scale Crowd-Sourced Preferences and Benchmarks for Cartoon Captioning

Code License Data License Python 3.10+ Code style: black

Dataset

Read our paper on arXiv (arXiv:2406.10522).

See past hosted contests at this website.

This repository contains code for:

  • Finetuning preference models
  • Evaluation with language models such as GPT-4
  • Results exploration and diversity investigation

Examples and Tutorials

Worked examples are provided as notebooks and sample outputs:

  • generation/example.ipynb: how we use GPT4 to generate captions and descriptions
  • ranking/example_rank.ipynb: how we evaluate the reliability of GPT4's rankings (human captions ranked in the top 10 vs. those ranked 1000-1009)
  • ranking/example_rank_more.ipynb: how we evaluate model-generated outputs
  • examples/generations: captions generated by our finetuned models

Evaluation

We present a novel multimodal preference dataset for creative tasks, consisting of over 250 million human ratings on more than 2.2 million captions, collected through crowdsourcing rating data for The New Yorker's weekly cartoon caption contest over the past eight years. This unique dataset supports the development and evaluation of multimodal large language models and preference-based fine-tuning algorithms for humorous caption generation. We propose novel benchmarks for judging the quality of model-generated captions, utilizing both GPT4 and human judgments to establish ranking-based evaluation strategies. Our experimental results highlight the limitations of current fine-tuning methods, such as RLHF and DPO, when applied to creative tasks. Furthermore, we demonstrate that even state-of-the-art models like GPT4 and Claude currently underperform top human contestants in generating humorous captions. As we conclude this extensive data collection effort, we release the entire preference dataset to the research community, fostering further advancements in AI humor generation and evaluation.

Finetuning

Before finetuning, change into the finetuning directory:

cd finetuning

Then create the train/test datasets for each generation method with the following command:

python preprocess.py 

SFT

python humor_sft.py --output_dir /your/output/dir/  --dataset_dir /your/dataset/dir/

DPO

Our experiments show that starting DPO from an SFT checkpoint trained with the simple prompt works better than training from scratch or starting from an SFT checkpoint trained with the long prompt. To produce such a checkpoint, run:

python humor_sft.py --output_dir /your/output/dir/  --dataset_dir /your/dataset/dir/ --new_padding_token --simple_prompt

Including --new_padding_token produces a similar model, but it is required to obtain an SFT checkpoint that can be further finetuned with DPO.
Including --simple_prompt makes SFT use the simple prompt (the same prompt used for DPO).

You can then finetune the DPO model from the resulting SFT checkpoint.

python humor_dpo.py --dataset_dir /your/dataset/dir/ --model_name mistralai/Mistral-7B-instruct-v0.1 --run_name full-instruct-dpo-warmup  --do_train --do_eval --output_dir /your/output/dir/ \
    --model_checkpoint_name /your/sft/checkpoint/with/simple/prompt/
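
For readers new to DPO, the objective optimized in this stage can be sketched in a few lines. This is the standard per-pair DPO loss written in plain Python for illustration only; it is not the code in humor_dpo.py, and the function name, the beta value, and the log-probabilities used in the usage note are assumptions.

```python
import math

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Standard DPO loss for one (chosen, rejected) caption pair.

    Inputs are summed token log-probabilities of each full caption under
    the policy being trained and under the frozen reference (SFT) model.
    beta=0.1 is an illustrative value, not taken from this repository.
    """
    # How much more the policy favors each caption than the reference does.
    chosen_margin = policy_chosen_logp - ref_chosen_logp
    rejected_margin = policy_rejected_logp - ref_rejected_logp
    logits = beta * (chosen_margin - rejected_margin)
    # -log(sigmoid(logits)), computed in a numerically stable way.
    if logits >= 0:
        return math.log1p(math.exp(-logits))
    return -logits + math.log1p(math.exp(logits))
```

When the policy equals the reference model, the loss is log 2 for every pair; for example, dpo_loss(-5.0, -7.0, -5.0, -7.0) evaluates to about 0.693. The loss falls as the policy raises the chosen caption's likelihood relative to the rejected one.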

Reward Modeling

Because humorous text generation is typically not covered by the training data of public reward models, we need to finetune the reward model ourselves. Use the following command:

python humor_reward_modeling.py --dataset_dir /your/dataset/dir/ --model_name weqweasdas/RM-Mistral-7B --run_name rm --do_train  --do_eval --output_dir /your/output/dir/ --max_steps 5000

You can also choose a different base reward model from RewardBench to finetune.

PPO

Our PPO model is finetuned directly from mistralai/Mistral-7B-instruct-v0.1. You need to finetune a reward model first before running PPO.

python humor_ppo.py --dataset_dir /your/dataset/dir --run_name ppo --output_dir /your/output/dir --target_kl 80 --reward_model /your/finetuned/reward/model

LLaVA finetune

To finetune LLaVA, first clone the original LLaVA repository.

git clone https://github.com/haotian-liu/LLaVA/

Then, from the top level of the cloned repository, run the following command.

deepspeed --include localhost:2 llava/train/train_mem.py \
    --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
    --version v1 \
    --data_path /your/dataset/dir/llava_sft_dataset/train_llava_sft_dataset.json \
    --image_folder /your/dataset/dir/cartoons/ \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir /your/output/dir/llava_sft/ \
    --num_train_epochs 1 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 10 \
    --learning_rate 2e-4 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb
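
The file passed to --data_path is produced by preprocess.py. As a rough orientation, the sketch below builds one record in the conversation layout that LLaVA's custom-data finetuning expects. The helper name, the id, the image file name, the prompt, and the caption are all made up for illustration, and the exact fields that preprocess.py writes may differ.

```python
import json

def make_llava_record(sample_id, image_file, prompt, caption):
    """Build one training record in LLaVA's conversation-style JSON layout."""
    return {
        "id": sample_id,
        "image": image_file,  # resolved relative to --image_folder
        "conversations": [
            # The "<image>" placeholder marks where the image features go.
            {"from": "human", "value": "<image>\n" + prompt},
            {"from": "gpt", "value": caption},
        ],
    }

# Hypothetical example values, not taken from the dataset.
records = [
    make_llava_record(
        "contest-0",
        "0.jpg",
        "Write a funny caption for this cartoon.",
        "A hypothetical caption.",
    )
]
print(json.dumps(records, indent=2))
```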

Generating results from pretrained and finetuned language models

You can generate sample captions from a trained model using the following commands.

# Save ZS result
python save_results.py --method zs --dataset_dir /your/dataset/dir --output_dir /your/output/dir --model_name mistralai/Mistral-7B-Instruct-v0.1 --num_generation 10
# Save SFT result
CUDA_VISIBLE_DEVICES=5 python save_results.py --method sft --dataset_dir /your/dataset/dir --output_dir /your/output/dir --model_name mistralai/Mistral-7B-Instruct-v0.1 --model_checkpoint /your/output/dir/sft/new_pad --num_generation 10 --new_padding_token 
# Save dpo result
CUDA_VISIBLE_DEVICES=5 python save_results.py --method dpo --dataset_dir /your/dataset/dir --output_dir /your/output/dir --model_name mistralai/Mistral-7B-Instruct-v0.1 --model_checkpoint /your/output/dir/sft/new_pad --num_generation 10 --new_padding_token 
# Save ppo result
python save_results.py --method ppo --dataset_dir /your/dataset/dir --output_dir /your/output/dir --model_name mistralai/Mistral-7B-Instruct-v0.1 --model_checkpoint mistralai/Mistral-7B-Instruct-v0.1 --num_generation 10
# Save llava result
python save_results.py --method llava --dataset_dir /your/dataset/dir --output_dir /your/output/dir --model_name llava-hf/llava-v1.6-mistral-7b-hf --num_generation 10 --device cuda:5
# Save llava sft result
python save_results.py --method llava --dataset_dir /your/dataset/dir --output_dir /your/output/dir --model_name llava-hf/llava-v1.6-mistral-7b-hf --model_checkpoint your/llava/sft/checkpoint --num_generation 10 --device cuda:5
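
If you run several of these commands in sequence, a small driver script can keep the shared arguments in one place. The sketch below only assembles and prints the command lines; it uses the flags shown above with the same placeholder paths, and the helper name save_results_cmd is hypothetical.

```python
import shlex

# Arguments shared by every command above; the paths are placeholders.
COMMON = [
    "--dataset_dir", "/your/dataset/dir",
    "--output_dir", "/your/output/dir",
    "--num_generation", "10",
]
MISTRAL = "mistralai/Mistral-7B-Instruct-v0.1"

def save_results_cmd(method, model_name, model_checkpoint=None, extra=()):
    """Assemble the argument list for one save_results.py invocation."""
    cmd = ["python", "save_results.py", "--method", method, "--model_name", model_name]
    if model_checkpoint is not None:
        cmd += ["--model_checkpoint", model_checkpoint]
    return cmd + COMMON + list(extra)

commands = [
    save_results_cmd("zs", MISTRAL),
    save_results_cmd("sft", MISTRAL, "/your/output/dir/sft/new_pad", ["--new_padding_token"]),
]

for cmd in commands:
    # Printing only; pass `cmd` to subprocess.run(cmd, check=True) to execute it.
    print(shlex.join(cmd))
```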

To obtain best-of-N generations, first generate more captions than you need. As a rule of thumb, we recommend generating 5 times as many captions as the final number. You can then select the best captions from these generations with a finetuned reward model.

python save_results.py --method zs --dataset_dir /your/dataset/dir --output_dir /your/output/dir --model_name mistralai/Mistral-7B-Instruct-v0.1 --num_generation 50
python save_bon_results.py --reward_model /your/reward/model/ --dataset_dir /your/dataset/dir --generation_file /your/output/dir/generation/zs_gen10.csv --model_name mistralai/Mistral-7B-Instruct-v0.1 --num_generation 10
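
The selection step itself is simple: score every candidate and keep the highest-scoring ones. The sketch below illustrates the idea with a toy scoring function standing in for the finetuned reward model. It is not the implementation in save_bon_results.py, and the names best_of_n and toy_reward are hypothetical.

```python
from typing import Callable, List

def best_of_n(captions: List[str], score: Callable[[str], float], keep: int) -> List[str]:
    """Return the `keep` highest-scoring captions, best first."""
    if keep > len(captions):
        raise ValueError("cannot keep more captions than were generated")
    return sorted(captions, key=score, reverse=True)[:keep]

def toy_reward(caption: str) -> float:
    """Toy stand-in for a reward model: shorter captions score higher."""
    return -float(len(caption))

# 50 candidates and 10 kept captions follow the 5x rule of thumb above.
candidates = [f"caption {'x' * i}" for i in range(50)]
top10 = best_of_n(candidates, toy_reward, keep=10)
```

In practice, `score` would wrap a forward pass of the reward model on the prompt and caption.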

You can also check out our already generated captions in examples/generations. To further evaluate these generations, see finetuning/generation_evaluation.ipynb or ranking/example_rank_more.ipynb.

Download Checkpoints

Since finetuning can take from one day up to a week on an A100, we provide checkpoints for all finetuned models. Model checkpoints can be found here. They include:

  • reward
  • sft
  • dpo
  • ppo
  • llava_sft

You can also find our sample caption generations on the test split in examples/generations:

  • claude.csv: 10 captions generated from Claude-3-Opus
  • gpt4o.csv: 10 captions generated from GPT-4o Vision
  • zs.csv: 10 captions generated from Mistral-Instruct-7B in a zero shot manner
  • zs_BoN.csv: We first generated 50 captions using Mistral-Instruct-7B in a zero-shot manner, then used our finetuned reward model to pick the best 10 captions.
  • sft.csv: 10 captions generated from the SFT model of Mistral-Instruct-7B
  • dpo.csv: 10 captions generated from DPO finetuned model of Mistral-Instruct-7B
  • ppo.csv: 10 captions generated from PPO finetuned model of Mistral-Instruct-7B
  • llava.csv: 10 captions generated from LLaVA pretrained model (llava-v1.6-mistral-7b-hf)
  • llava_sft.csv: 10 captions generated from the LLaVA model finetuned from llava-v1.6-mistral-7b-hf
  • human_top.csv, human_200.csv, human_1000.csv, human_median.csv: 10 captions from human contestants ranked 1-10, 200-209, 1000-1009, and at the median, respectively
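
To get a quick overview of these files before opening the notebooks, a short script can list each CSV with its columns and row count. The sketch below makes no assumption about the column names; the helper name summarize_generations is hypothetical.

```python
import csv
from pathlib import Path

def summarize_generations(directory):
    """Map each CSV file name to (column names, number of data rows)."""
    summary = {}
    for path in sorted(Path(directory).glob("*.csv")):
        with path.open(newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            rows = list(reader)  # reading the rows also loads the header
            summary[path.name] = (reader.fieldnames or [], len(rows))
    return summary

# Example, run from the repository root:
# for name, (columns, n_rows) in summarize_generations("examples/generations").items():
#     print(f"{name}: {n_rows} rows, columns={columns}")
```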

Citation

Please consider citing our work if you use the code or data in this repository:

@article{zhang2024humor,
  title={Humor in AI: Massive Scale Crowd-Sourced Preferences and Benchmarks for Cartoon Captioning},
  author={Zhang, Jifan and Jain, Lalit and Guo, Yang and Chen, Jiayi and Zhou, Kuan Lok and Suresh, Siddharth and Wagenmaker, Andrew and Sievert, Scott and Rogers, Timothy and Jamieson, Kevin and others},
  journal={arXiv preprint arXiv:2406.10522},
  year={2024}
}

cartoon-caption-generation's People

Contributors

yguooo, klz8029, jifanz, jiayi-6
