
Comments (5)

ZeyuTeng96 commented on June 8, 2024

The finetuning shell script:

#!/bin/bash
#SBATCH --job-name=medical_qa_finetune
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8 # number of gpus
#SBATCH -o /cognitive_comp/wuziwei/task/fs_medical_qa_finetune/%x-%j.log
#SBATCH -e /cognitive_comp/wuziwei/task/fs_medical_qa_finetune/%x-%j.err
#SBATCH -x dgx[050,049]

#export NCCL_DEBUG=INFO

set -x -e

echo "START TIME: $(date)"
MICRO_BATCH_SIZE=1
ROOT_DIR=$(pwd)

config_json="$ROOT_DIR/training_config.json"  # defined but not used below

# NOTE: these trainer flags (--gpus 4, --num_nodes 1) do not match the
# resources requested by the #SBATCH headers above (2 nodes x 8 GPUs each);
# keep the two in sync for a real multi-node run.
TRAINER_ARGS="
--max_epochs 10
--gpus 4
--num_nodes 1
--default_root_dir $ROOT_DIR
--dirpath $ROOT_DIR/ckpt
--save_top_k 3
--monitor train_loss
--mode min
--save_last
"
DATA_DIR=$(pwd)
DATA_ARGS="
--data_dir $DATA_DIR
--train_batchsize $MICRO_BATCH_SIZE
--valid_batchsize $MICRO_BATCH_SIZE
--train_data processed_train_Wenzhong.json
--valid_data processed_val_Wenzhong.json
--test_data processed_test_Wenzhong.json
"

# Local GPT-2 checkpoint, kept for reference; the HF model id below overrides it.
# PRETRAINED_MODEL_PATH=/cognitive_comp/wuziwei/pretrained_model_hf/gpt2

PRETRAINED_MODEL_PATH='IDEA-CCNL/Wenzhong-GPT2-3.5B'
MODEL_ARGS="
--pretrained_model_path ${PRETRAINED_MODEL_PATH}
--output_save_path $ROOT_DIR/predict.json
--learning_rate 1e-4
--weight_decay 0.1
--warmup 0.01
"

SCRIPTS_PATH=$ROOT_DIR/finetune_medicalQA.py

export CMD="
$SCRIPTS_PATH
$TRAINER_ARGS
$MODEL_ARGS
$DATA_ARGS
"

echo $CMD

python $CMD
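
If the script above is saved as, say, finetune_medicalQA.sh (the filename is an assumption), it can either be submitted through SLURM, so the #SBATCH headers take effect, or run directly for a single-node smoke test:

# Submit as a batch job; SLURM honors the #SBATCH headers above.
sbatch finetune_medicalQA.sh

# Or run directly on the current node for a quick smoke test
# (the #SBATCH lines are then ignored as plain comments).
bash finetune_medicalQA.sh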


ZeyuTeng96 commented on June 8, 2024

The finetuning Python script, finetune_medicalQA.py:

from transformers import GPT2LMHeadModel
from medicalQADataset import GPT2QADataModel
from transformers.optimization import get_linear_schedule_with_warmup
from pytorch_lightning import Trainer, loggers
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl
import argparse
import torch
import os
import sys
# No spaces in the device list: some CUDA versions mis-parse entries after whitespace.
os.environ["CUDA_VISIBLE_DEVICES"] = '1,3,4,7'

class GPT2FinetuneMedicalQAModelCheckpoint:
    @staticmethod
    def add_argparse_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')

        parser.add_argument('--monitor', default='train_loss', type=str)
        parser.add_argument('--mode', default='min', type=str)
        parser.add_argument('--dirpath', default='./ckpt/', type=str)
        parser.add_argument(
            '--filename', default='model-{epoch:02d}-{train_loss:.4f}', type=str)
        parser.add_argument('--save_last', action='store_true', default=True)
        parser.add_argument('--save_top_k', default=3, type=int)  # top-k count must be an int
        parser.add_argument('--every_n_train_steps', default=1000, type=int)
        # NOTE: argparse's type=bool treats any non-empty string as True;
        # harmless here since the flag is never passed explicitly.
        parser.add_argument('--save_weights_only', default=True, type=bool)

        return parent_args

    def __init__(self, args):
        self.callbacks = ModelCheckpoint(monitor=args.monitor,
                                         save_top_k=args.save_top_k,
                                         mode=args.mode,
                                         #  every_n_train_steps=args.every_n_train_steps,
                                         save_weights_only=args.save_weights_only,
                                         dirpath=args.dirpath,
                                         filename=args.filename,
                                         save_last=args.save_last)

class GPT2FinetuneMedicalQA(pl.LightningModule):

    @staticmethod
    def add_model_specific_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')
        parser.add_argument('--learning_rate', default=1e-4, type=float)
        parser.add_argument('--weight_decay', default=0.1, type=float)
        parser.add_argument('--warmup', default=0.01, type=float)
        return parent_args

    def __init__(self, args, num_data):
        super().__init__()
        self.args = args
        self.num_data = num_data
        print('num_data:', num_data)
        self.model = GPT2LMHeadModel.from_pretrained(
            args.pretrained_model_path)

    def setup(self, stage) -> None:
        if stage == 'fit':
            # trainer.gpus is deprecated in pytorch-lightning >= 1.7 (see the
            # warning in the log later in this thread); trainer.num_devices is
            # the newer accessor.
            num_gpus = self.trainer.gpus if self.trainer.gpus is not None else 0
            self.total_step = int(self.trainer.max_epochs * self.num_data /
                                  (max(1, num_gpus) * self.trainer.accumulate_grad_batches))
            print('Total training step:', self.total_step)

    def training_step(self, batch, batch_idx):
        output = self.model(input_ids=batch['input_ids'],
                            attention_mask=batch['attention_mask'], labels=batch['labels'])
        # output = self.model(input_ids=batch['input_ids'], labels=batch['labels'])
        # acc = self.comput_metrix(output.logits, batch['labels'])
        self.log('train_loss', output.loss)
        return output.loss

    def comput_metrix(self, logits, labels):
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1,))
        y_true = labels.view(size=(-1,)).float()
        corr = torch.eq(y_pred, y_true)
        # Normalize by the number of tokens, not the batch size,
        # so the accuracy stays in [0, 1].
        acc = torch.sum(corr.float()) / y_true.size(0)
        return acc

    def validation_step(self, batch, batch_idx):
        output = self.model(input_ids=batch['input_ids'],
                            attention_mask=batch['attention_mask'], labels=batch['labels'])
        # output = self.model(input_ids=batch['input_ids'], labels=batch['labels'])
        # acc = self.comput_metrix(output.logits, batch['labels'])
        self.log('val_loss', output.loss)
        # self.log('val_acc', acc)

    def configure_optimizers(self):
        # Exclude biases and LayerNorm weights from weight decay.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        paras = list(
            filter(lambda p: p[1].requires_grad, self.named_parameters()))
        paras = [{
            'params':
            [p for n, p in paras if not any(nd in n for nd in no_decay)],
            'weight_decay': self.args.weight_decay
        }, {
            'params': [p for n, p in paras if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        optimizer = torch.optim.AdamW(paras, lr=self.args.learning_rate)
        # Linear warmup over the first `warmup` fraction of steps, then linear decay.
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(self.total_step * self.args.warmup),
            self.total_step)

        return [{
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }]

def main():
    total_parser = argparse.ArgumentParser("Summary Task")
    total_parser.add_argument(
        '--do_eval_only', action='store_true', default=False)
    total_parser.add_argument(
        '--pretrained_model_path', default=None, type=str)
    total_parser.add_argument('--output_save_path',
                              default='./predict.json', type=str)
    # * Args for data preprocessing
    total_parser = GPT2QADataModel.add_data_specific_args(total_parser)
    # * Args for training
    total_parser = Trainer.add_argparse_args(total_parser)
    total_parser = GPT2FinetuneMedicalQAModelCheckpoint.add_argparse_args(
        total_parser)
    total_parser = GPT2FinetuneMedicalQA.add_model_specific_args(total_parser)
    # * Args for base model
    args = total_parser.parse_args()

    data_model = GPT2QADataModel(args)
    if not args.do_eval_only:
        model = GPT2FinetuneMedicalQA(args, len(data_model.train_dataloader()))
        checkpoint_callback = GPT2FinetuneMedicalQAModelCheckpoint(
            args).callbacks
        logger = loggers.TensorBoardLogger(save_dir=os.path.join(
            args.default_root_dir, 'log/'), name='MedicalQA-GPT2')
        trainer = Trainer.from_argparse_args(args,
                                             logger=logger,
                                             callbacks=[checkpoint_callback]
                                             )
        trainer.fit(model, data_model)

        # result = trainer.predict(model, data_model)
        # with open('test_results.txt', 'wt', encoding='utf-8') as w:
        #     for line in result:
        #         w.writelines(line)

        model.model.save_pretrained(
            '/cognitive_comp/wuziwei/pretrained_model_hf')
    else:
        print('save to hf.....')
        trainer = Trainer.from_argparse_args(args)
        model = GPT2FinetuneMedicalQA(
            args, len(data_model.predict_dataloader()))

        result = trainer.predict(
            model, data_model, ckpt_path='/cognitive_comp/wuziwei/task/fs_medical_qa_finetune/ckpt/last.ckpt')
        # with open('test_results.txt','wt',encoding='utf-8') as w:
        #     for line in result:
        #         w.writelines(line)

        model.model.save_pretrained(
            '/cognitive_comp/wuziwei/pretrained_model_hf')


if __name__ == '__main__':
    main()
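
As a usage note, the --do_eval_only flag defined in main() switches the script into the predict branch, which restores the hard-coded last.ckpt path above. A minimal sketch of that invocation, reusing MODEL_ARGS and DATA_ARGS from the shell script in the first comment:

# Prediction only: skips trainer.fit and runs trainer.predict against
# the checkpoint path hard-coded in main().
python finetune_medicalQA.py \
    --do_eval_only \
    $MODEL_ARGS \
    $DATA_ARGS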


koking0 commented on June 8, 2024

Could the Bus error (core dumped) be a disk problem? I ran into it before because the disk was full.
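
If disk space is the suspect, two quick checks before rerunning (a minimal sketch, run on the training node):

# A full working filesystem or an exhausted /dev/shm can both
# surface as "Bus error (core dumped)" from DataLoader workers.
df -h .
df -h /dev/shm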


Gavingx commented on June 8, 2024

Trying to finetune Wenzhong-GPT2-3.5B raises an error; the full output is below:

Using pad_token, but it is not set yet.
Training set processing progress: 100%|██████████| 3774619/3774619 [00:41<00:00, 90473.26it/s]
Using pad_token, but it is not set yet.
Validation set processing progress: 100%|██████████| 19220/19220 [00:00<00:00, 60371.62it/s]
Using pad_token, but it is not set yet.
Test set processing progress: 100%|██████████| 2409/2409 [00:00<00:00, 67752.58it/s]
num_data: 3774619
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:446: LightningDeprecationWarning: Setting Trainer(gpus=4) is deprecated in v1.7 and will be removed in v2.0. Please use Trainer(accelerator='gpu', devices=4) instead.
  f"Setting Trainer(gpus={gpus!r}) is deprecated in v1.7 and will be removed"
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/conda/lib/python3.7/multiprocessing/semaphore_tracker.py:144: UserWarning: semaphore_tracker: There appear to be 4 leaked semaphores to clean up at shutdown
  len(cache))
Bus error (core dumped)

I ran into this problem too; I was running inside a container. Raising the container's shared-memory limit fixed it.
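
For reference, a hedged sketch of what that looks like when launching the container (the image name and the 16g size are placeholders; Docker's default /dev/shm is only 64 MB):

# Raise the container's shared-memory limit so PyTorch DataLoader
# workers have room in /dev/shm.
docker run --gpus all --shm-size=16g -it your_image:latest bash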


ZeyuTeng96 commented on June 8, 2024

> I ran into this problem too; I was running inside a container. Raising the container's shared-memory limit fixed it.

Thanks! I got it running afterwards, but the model has so many parameters that GPU memory wasn't enough, and after enabling off_load training took far too long, so I didn't go ahead with the finetuning.
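
For anyone hitting the same memory wall: since main() builds its Trainer with Trainer.from_argparse_args, one option (assuming pytorch-lightning 1.7 with DeepSpeed installed, per the log above) is to pick a registered DeepSpeed offload strategy from the command line:

# ZeRO stage 2 with CPU offload: slower per step, but fits a 3.5B
# model into less GPU memory than plain DDP.
python finetune_medicalQA.py \
    --strategy deepspeed_stage_2_offload \
    --accelerator gpu --devices 4 \
    --precision 16 \
    $MODEL_ARGS $DATA_ARGS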

