Comments (2)
Has anyone done this? How do you call Baichuan's streaming API?
```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from transformers.generation.utils import GenerationConfig
>>> tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-13B-Chat", use_fast=False, trust_remote_code=True)
>>> model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan2-13B-Chat", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
>>> model.generation_config = GenerationConfig.from_pretrained("baichuan-inc/Baichuan2-13B-Chat")
>>> messages = []
>>> messages.append({"role": "user", "content": "解释一下“温故而知新”"})
>>> response = model.chat(tokenizer, messages)
>>> print(response)
"温故而知新"是一句古代的成语,出自《论语·为政》篇。这句话的意思是:通过回顾过去,我们可以发现新的知识和理解。换句话说,学习历史和经验可以让我们更好地理解现在和未来。
```
Or can this code be changed into a streaming version?
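For reference, Baichuan2's `chat()` itself accepts `stream=True`, in which case it yields the cumulative response text (the server code below relies on exactly this behavior). A minimal sketch reusing `model`, `tokenizer`, and `messages` from the snippet above:

```python
# Streaming variant of the snippet above: with stream=True, model.chat()
# returns a generator yielding the response accumulated so far, so we print
# only the newly generated suffix each time.
position = 0
for partial in model.chat(tokenizer, messages, stream=True):
    print(partial[position:], end="", flush=True)
    position = len(partial)
print()
```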
```python
# coding=utf-8
# Implements an OpenAI-format API for Baichuan2-13B-Chat, adapted from the
# ChatGLM2-6B example. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py
# Visit http://localhost:8000/docs for documents.
import time
import torch
import uvicorn
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
from transformers import AutoTokenizer, AutoModelForCausalLM
from sse_starlette.sse import EventSourceResponse
import json
import logging
from transformers.generation.utils import GenerationConfig


@asynccontextmanager
async def lifespan(app: FastAPI):  # collect GPU memory on shutdown
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


logging.basicConfig(level=logging.WARN, filemode='a',
                    filename='./log_stream_answer.log',
                    format='%(asctime)s - %(levelname)s: %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')

app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    """Free cached GPU memory between requests."""
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(id="gpt-3.5-turbo")
    return ModelList(data=[model_card])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    if request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content

    # query/history are retained from the ChatGLM version of this script;
    # Baichuan's chat() takes the full message list instead.
    prev_messages = request.messages[:-1]
    if len(prev_messages) > 0 and prev_messages[0].role == "system":
        query = prev_messages.pop(0).content + query
    history = []

    messages_input = [{"role": m.role, "content": m.content} for m in request.messages]
    logging.warning("---------------------------------------------------")
    logging.warning(messages_input)

    if request.stream:
        generate = predict(query, history, request.model, messages_input)
        torch_gc()
        return EventSourceResponse(generate, media_type="text/event-stream")

    response = model.chat(tokenizer, messages_input, stream=False)
    logging.warning(response)
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop"
    )
    torch_gc()
    return ChatCompletionResponse(model=request.model, choices=[choice_data],
                                  object="chat.completion")


async def predict(query: str, history: List[List[str]], model_id: str, messages: List[dict]):
    global model, tokenizer

    # First chunk carries only the assistant role, per the OpenAI stream format.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data],
                                   object="chat.completion.chunk")
    yield json.dumps(chunk.dict(exclude_unset=True), ensure_ascii=False)

    # Baichuan's chat(..., stream=True) yields the cumulative response text,
    # so emit only the newly generated suffix as each delta.
    current_length = 0
    new_response = ""
    for new_response in model.chat(tokenizer, messages, stream=True):
        if len(new_response) == current_length:
            continue
        new_text = new_response[current_length:]
        current_length = len(new_response)

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=new_text),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data],
                                       object="chat.completion.chunk")
        yield json.dumps(chunk.dict(exclude_unset=True), ensure_ascii=False)
    logging.warning(new_response)

    # Final chunk signals completion, followed by the [DONE] sentinel.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data],
                                   object="chat.completion.chunk")
    yield json.dumps(chunk.dict(exclude_unset=True), ensure_ascii=False)
    yield '[DONE]'


if __name__ == "__main__":
    model_path = '/home/checkpoints/Baichuan2-13B-Chat-4bits'
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    model.generation_config = GenerationConfig.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=False,
        trust_remote_code=True
    )
    uvicorn.run(app, host='0.0.0.0', port=7842, workers=1)
```
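To consume the streaming endpoint, here is a minimal client sketch (not from the thread), assuming the server above is running on `localhost:7842`; `EventSourceResponse` frames each yielded string as an SSE `data:` line:

```python
# Hypothetical client for the server above: POST with stream=True and read
# the SSE stream line by line, printing each content delta as it arrives.
import json
import requests

resp = requests.post(
    "http://localhost:7842/v1/chat/completions",
    json={
        "model": "gpt-3.5-turbo",  # the server does not validate the model id
        "messages": [{"role": "user", "content": "解释一下“温故而知新”"}],
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data:"):
        continue  # skip blank SSE separator lines
    payload = line[len("data:"):].strip()
    if payload == "[DONE]":
        break
    delta = json.loads(payload)["choices"][0]["delta"]
    print(delta.get("content", ""), end="", flush=True)
print()
```

Because the response shape follows the OpenAI chat-completions format, pointing the official `openai` client at this base URL should also work.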
Related Issues (20)
- Haven't even started fine-tuning baichuan13b, and plain chat already talks to itself? It keeps leaking Human:/Assistant: dialogue data
- baichuan-13b-chat SFT fine-tuning loss does not decrease
- How to deploy offline?
- ValueError: Tokenizer class BaichuanTokenizer does not exist or is not currently imported.
- ValueError: The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder` for them. Alternatively, make sure you have `safetensors` installed if the model you are using offers the weights in this format.
- Does this model not support multi-GPU mode?
- Does everyone rent GPU servers to run large models?
- Problem reproducing baichuan2 MMLU results
- baichuan-13b-chat batch generation example
- Local deployment version problem
- Can a V100 deploy Baichuan-13B-Base?
- After fine-tuning baichuan2-13b I got a .pth file; how do I run inference?
- Roughly how much VRAM does fine-tuning with the official fine-tune script need? An A6000 (48G) reports out of memory.
- Loading the trained model into web.demo raises an error
- feat: function calling
- Baichuan13B vllm output quality is very poor
- What algorithm does the Baichuan 2 quantized version use?
- Solved
- How to speed up model inference?
- NPU deployment