Comments (13)

MikeDean2367 commented on June 17, 2024

Hello, I don't quite follow: the code you provided is training code, but your description sounds like inference.

MikeDean2367 commented on June 17, 2024

Regarding batch prediction: the generate() method in the transformers package supports batched inputs natively.

First load the model and tokenizer to obtain the model and tokenizer variables, then feed your batch of inputs through the tokenizer. That gives you a tensor of shape batch_size x length, which you then pass to model.generate().
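A minimal sketch of that flow (the model path is a placeholder; setting a pad token, left-padding, and passing the attention_mask alongside input_ids so pad positions are masked are assumptions carried over from the tested snippet later in this thread):

from transformers import LlamaForCausalLM, LlamaTokenizer
import torch

model = LlamaForCausalLM.from_pretrained("model_name", device_map="auto")  # placeholder path
tokenizer = LlamaTokenizer.from_pretrained("model_name")
tokenizer.pad_token_id = tokenizer.unk_token_id  # LLaMA ships without a pad token
tokenizer.padding_side = "left"                  # left-pad so generation continues from real tokens

# padding=True pads every prompt to the longest one: shape (batch_size, length)
batch = tokenizer(["first input", "second input"], return_tensors="pt", padding=True).to(model.device)
with torch.no_grad():
    outputs = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],  # mask the pad positions
        max_new_tokens=32,
    )
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))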

zxlzr commented on June 17, 2024

Has your issue been resolved?

githubgtl commented on June 17, 2024

> Has your issue been resolved?

I can get a list of input_ids via data.map(), but feeding it directly into model.generate() raises a dimension error.
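(A note for context: data.map() leaves each example's input_ids as a variable-length Python list, while model.generate() expects a rectangular batch_size x length tensor, hence the dimension error. A minimal hedged sketch of the bridge, assuming a loaded model and tokenizer and a hypothetical ragged list list_of_input_ids:)

import torch

# list_of_input_ids is hypothetical: the ragged lists produced by data.map(),
# e.g. [[1, 529, 13], [1, 903, 42, 7, 2]]
tokenizer.padding_side = "left"
batch = tokenizer.pad({"input_ids": list_of_input_ids}, padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(
        input_ids=batch["input_ids"].to(model.device),
        attention_mask=batch["attention_mask"].to(model.device),
        max_new_tokens=32,
    )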

MikeDean2367 commented on June 17, 2024

Hello, since you didn't share your code, here is some pseudocode:

from transformers import LlamaForCausalLM, LlamaTokenizer

model = LlamaForCausalLM.from_pretrained("model_name")
tokenizer = LlamaTokenizer.from_pretrained("model_name")
# set the tokenizer's padding token (LLaMA has none by default) and left-pad for generation
tokenizer.pad_token_id = tokenizer.unk_token_id
tokenizer.padding_side = "left"
inputs = [
    "input1",
    "input3"
]
input_ids = tokenizer(inputs, return_tensors="pt", padding=True)['input_ids']
outputs = model.generate(input_ids=input_ids, max_new_tokens=10)
ans = [tokenizer.decode(x) for x in outputs.detach().cpu().numpy().tolist()]
print(ans)

The model will then process the two inputs in parallel (make sure you have enough GPU memory; additional decoding parameters can be passed to model.generate()). Let me know if anything is unclear :)
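(To illustrate the "additional decoding parameters" remark: generate() accepts standard sampling controls directly as keyword arguments; the values below are illustrative only:)

outputs = model.generate(
    input_ids=input_ids,
    max_new_tokens=10,
    do_sample=True,    # sample instead of greedy decoding
    temperature=0.7,   # soften the next-token distribution
    top_p=0.9,         # nucleus sampling
)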

githubgtl commented on June 17, 2024

import json
import os
import sys
from typing import List

import fire
import torch
import transformers
from datasets import load_dataset
from tqdm import tqdm

"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
from peft import PeftModel
from utils.prompter import Prompter

def train(
    base_model: str = "/data/home/scv3183/knowlm",  # the only required argument
    data_path: str = "/data/home/scv3183/deepke/example/llm/finetune/lora/data/RE/valid.json",
    cutoff_len: int = 512,
    val_set_size: int = 2000,
    train_on_inputs: bool = False,  # if False, masks out inputs in loss
    prompt_template_name: str = "alpaca",  # the prompt template to use, defaults to alpaca
):
    lora_weights = "../../finetune/lora/checkpoint/checkpoint-10"
    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    output_dir = "../results/knowlm-valid-ori-1.json"
    from tqdm import tqdm
    prompter = Prompter(prompt_template_name)
    if torch.cuda.is_available():
        print("device is cuda")
        device = "cuda"
    else:
        device = "cpu"

    try:
        if torch.backends.mps.is_available():
            device = "mps"
    except:
        pass

    def tokenize(prompt, add_eos_token=True):
        """
        :param prompt:
        :param add_eos_token:
        :return:
        """
        return tokenizer(
            prompt,
            return_tensors="pt",
        )

    def generate_and_tokenize_prompt(data_point):
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["input"],
        )
        tokenized_full_prompt = tokenize(full_prompt)
        return tokenized_full_prompt

    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_path)
        print(f"data includes: {data_path}")
    else:
        """is folder"""
        data_paths = []
        data_path = data_path if data_path[-1] == "/" else data_path + "/"
        for i in os.listdir(data_path):
            data_paths.append(os.path.join(data_path, i))
        print(f"data includes: {data_paths}")
        data = load_dataset(data_paths)

    if val_set_size > 0:
        test_data = data.map(generate_and_tokenize_prompt)
    input_ids = test_data.data['train']['input_ids']

    if device == "cuda":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=False,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            torch_dtype=torch.float16,
        )
        model.config.pad_token_id = tokenizer.pad_token_id = 0  # same as unk token id
        model.config.bos_token_id = tokenizer.bos_token_id = 1
        model.config.eos_token_id = tokenizer.eos_token_id = 2
        print("loaded model")
    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    generation_config = GenerationConfig(
        temperature=1.0,
        top_p=1.0,
        top_k=50,
        num_beams=1,
        max_new_tokens=256,
        pad_token_id=tokenizer.pad_token_id,
        return_dict_in_generate=True,
        output_scores=True
    )
    with torch.no_grad():
        outputs = model.generate(input_ids=input_ids, generation_config=generation_config)

    # decode the generated text
    decoded_texts = [tokenizer.decode(output.sequences[0], skip_special_tokens=True) for output in outputs]
    with open(output_dir, "w") as writer:
        for i, result in tqdm(enumerate(decoded_texts)):
            print(i, result)
            writer.write(result + '\n')

if __name__ == "__main__":
    fire.Fire(train)

The logic feels the same to me.

MikeDean2367 commented on June 17, 2024

Hello, please modify your code to follow the code I provided earlier.

githubgtl commented on June 17, 2024

input_ids = tokenizer(inputs, return_tensors="pt", padding=True)['input_ids']

(screenshot of the resulting error)

MikeDean2367 commented on June 17, 2024

Hello, please use the following code (tested and working):

from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig
import torch

device = "cuda"
repo = "llama2-7b-chat"
model = LlamaForCausalLM.from_pretrained(repo, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = LlamaTokenizer.from_pretrained(repo)
tokenizer.pad_token_id = tokenizer.unk_token_id  # LLaMA has no pad token; reuse unk
tokenizer.padding_side = 'left'  # left-pad so decoder-only generation starts from real tokens

inputs = [
    "who are you?",
    "do you know zhejiang university?"
]
input_ids = tokenizer(inputs, return_tensors="pt", padding=True)['input_ids'].to(device)
outputs = model.generate(input_ids=input_ids, max_new_tokens=100)
ans = [tokenizer.decode(x) for x in outputs.detach().cpu().numpy().tolist()]
print(ans)
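One note on the output: since both prompts are left-padded to a common length, decoding with tokenizer.batch_decode(outputs, skip_special_tokens=True) strips the pad tokens for you; the per-item tokenizer.decode loop above works just as well.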

Please change the repo variable accordingly. Also, if your GPU supports bf16, keep torch_dtype=torch.bfloat16; otherwise change it to torch.float16 or torch.float32.
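(A small hedged helper for that dtype choice, reusing the repo variable from the snippet above; assuming a CUDA build of PyTorch, torch.cuda.is_bf16_supported() reports whether the GPU handles bf16:)

import torch

# pick the widest dtype the hardware supports: bf16, then fp16, then fp32
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
elif torch.cuda.is_available():
    dtype = torch.float16
else:
    dtype = torch.float32
model = LlamaForCausalLM.from_pretrained(repo, device_map="auto", torch_dtype=dtype)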

If you have any other questions, please let me know :)

zxlzr commented on June 17, 2024

Has your issue been resolved?

githubgtl commented on June 17, 2024

> Has your issue been resolved?

Yes, it's resolved.
