Giter VIP home page Giter VIP logo

pytorch-worker's Introduction

Pytorch Worker

本框架为基于pytorch的模型训练、测试框架。该框架的目的是方便大家快速上手写出pytorch的模型,同时能够定制化属于自己的模型、输出、数据处理和评价指标,方便大家快速完成同任务上的大量模型的实验。

英文的README可以通过这里访问。

目录

运行方法

运行方法分为模型训练和模型测试两个部分。

模型训练

无论是模型的训练还是测试,我们都需要指定本次运行的参数即配置文件,配置文件的详细说明可以可以参考下一节的内容。如果我们想要训练我们的模型,运行方法如下:

python3 train.py --config 配置文件 --gpu GPU列表

例如,如果我们想用编号为2,3,5的三张GPU运行中文Bert的分类任务,我们可以运行如下命令:

python3 train.py --config config/nlp/BasicBert.config --gpu 2,3,5

此外,除了原始的DataParallel的方法以外,我们还提供了使用DistributedDataParallel的多卡并行方法,运行命令为:

python3 -m torch.distributed.launch train.py --config config/nlp/BasicBert.config --gpu 2,3,5

具体的实现方法请参考模型中的写法。

当然,如果你不想使用GPU来运行你的模型的话,你可以去掉--gpu选项来完成这一点,例如:

python3 train.py --config config/nlp/BasicBert.config

这样的运行方式就可以不使用GPU来训练模型。

如果你并不想从头开始训练你的模型,而是希望接着之前某次训练的结果继续运行的话,可以用如下命令:

python3 train.py --config 配置文件 --gpu GPU列表 --checkpoint 模型参数文件

参数checkpoint指向中间某次的训练结果文件,框架会从该文件中读取模型、优化器和训练迭代轮数等信息。(如果在中途修改了优化器并不会导致错误,框架会自动使用新的优化器继续运行)

模型测试

模型测试的运行方法:

python3 test.py --config 配置文件 --gpu GPU列表 --checkpoint 模型参数文件 --result 测试结果保存文件

会将测试的结果以json格式保存到结果文件中。

配置文件

运行逻辑

配置文件是该框架的核心模块之一,绝大部分的运行参数都是通过读取配置文件得到的。我们以一个例子来说明框架从配置文件读取参数的运行逻辑:

python3 train.py --config config/nlp/BasicBert.config

在这段代码里面我们指定了从BasicBert.config中读取我们所需要的参数,在实际运行中我们会总共涉及到三个不同的配置文件:

  1. config/nlp/BasicBert.config
  2. config/default_local.config
  3. config/default.config

当框架尝试读取某个参数的时候,会按照上述三个配置文件的顺序从上至下依次读取。例如如果我们想要读取batch_size这个参数,框架会先尝试从config/nlp/BasicBert.config中读取该参数;如果失败,会再尝试从config/default_local.config读取该参数;如果再次失败,会尝试从config/default.config读取参数。如果在三个配置文件中都没有读取到参数,则会抛出异常。

对于三个配置文件,我们建议每个文件中所需要包含的参数有:

  • config/default.config:在这个文件中,我们建议将一些对于不同模型不变,或者说一些参数的默认值写在该文件中,例如像测试间隔、输出间隔等对于不同模型来说都不会有所改变的参数。
  • config/default_local.config:我们建议将模型所涉及到的路径信息写在该文件中,该文件并不会被同步到git中,更多情况下,该配置文件是用于在不同服务器上进行适配使用的文件。
  • 运行指定的配置文件:我们建议将运行对应模型的相关参数写在该文件里,一些对于不同模型没有变化的参数如数据地址、数据处理方式等参数不建议写在该配置文件里面。

在程序运行中,我们传递的config参数便为所对应的配置文件,支持原版ConfigParser的各种函数包括但不限于get,getint,getboolean等方法。

基本参数说明

配置文件的结构是遵循python下的ConfigParser包进行构建的,文件中[chapter]代表的是不同适用情况的参数,具体结构可以参考config文件夹下的例子。我们接下来将会介绍在基本框架中所涉及到的一些参数的说明,参数分为必要参数(运行所有模型都需要的参数)和可选参数(运行特定模型所需要的参数)。当然,你可以随意的在你自定义的方法里面增加新的参数

[train]:训练用参数

  • epoch:必要参数,代表需要训练的轮数。
  • batch_size:必要参数,代表训练时一次计算的数据量。
  • shuffle:必要参数,代表是否需要随机打乱数据。
  • reader_num:必要参数,需要多少个进程处理训练数据。
  • optimizer:必要参数,选择的优化器。
  • learning_rate:必要参数,学习率。
  • weight_decay:必要参数,权值正则化参数。
  • step_sizelr_multiplier:必要参数,学习率每过step_size个epoch变为原来的lr_multiplier倍。

[eval]:测试用参数

  • batch_size:必要参数,代表测试时一次计算的数据量。
  • shuffle:必要参数,代表是否需要随机打乱数据。
  • reader_num:必要参数,需要多少个进程处理测试数据。

[data]:数据用参数

  • train_dataset_type,valid_dataset_type,test_dataset_type:必要参数,分别代表训练、验证、测试时使用的数据读取器类型。如果验证和测试的参数没有指定,则默认使用训练的类型。
  • train_formatter_type,valid_formatter_type,test_formatter_type:必要参数,分别代表训练、验证、测试时使用的数据处理器类型。如果验证和测试的参数没有指定,则默认使用训练的类型。
  • train_data_path,valid_data_path,test_data_path:可选参数(仅用于框架已实现的数据读取器),分别代表训练、验证、测试时的数据位置。
  • train_file_list,valid_file_list,test_file_list:可选参数(仅用于框架已实现的数据读取器),分别代表训练、验证、测试对应的数据位置下,哪些文件或者文件夹属于数据。即真正的数据地址应该是train_data_path+train_file_list
  • recursive:可选参数(仅用于FilenameOnly,JsonFromFiles两种数据读取器),代表如果对应的数据地址是一个文件夹,是否需要递归地向下搜索更多的数据。
  • json_format:可选参数(仅用于ImageFromJson,JsonFromFiles两种数据读取器),代表对应的json文件的格式。如果为line代表一行一个json数据;如果为single代表整个文件为一个json数据。
  • load_into_mem:可选参数(仅用于ImageFromJson,JsonFromFiles两种数据读取器),代表是否提前把所有数据加载到内存中。
  • prefix:可选参数(仅用于ImageFromJson数据读取器),代表图片的相对路径起始位置。

[distributed]: 多卡参数

  • use:必要参数,代表是否使用DistributedDataParallel
  • backend:可选参数(仅当useTrue生效),代表所使用的backend,如果你不了解这是什么东西保持默认即可。

[model]:模型用参数

  • model_name:必要参数,代表训练的模型类型。
  • bert_path:可选参数(仅用于BasicBert模型),代表bert模型参数的地址。
  • output_dim:可选参数(仅用于框架已实现的模型),代表分类问题中模型所需要输出的种类数量。

[output]:输出用参数

  • output_time:必要参数,代表每多少次运行模型后输出一次结果。
  • test_time:必要参数,代表每多少个epoch进行一次验证。
  • model_path:必要参数,模型结果文件保存的地址。
  • model_name:必要参数,模型保存的名字。
  • tensorboard_path:必要参数,tensorboard存储的地址。(暂未实现)
  • accuracy_method:必要参数,计算模型好坏程度的指标函数
  • output_function:必要参数,用来产生中间指标输出的函数。
  • output_value:可选参数(仅用于Basic版本的指标输出函数),用来选择要输出的指标。

新方法的添加和已有方法

我们的框架中除开配置文件读取器以外,剩下的绝大部分模块都是可定制的,包括是:数据读取器数据处理器模型指标函数指标输出,这里每一个部分都可以添加你自己需要的方法或者模型,我们将在依次介绍每个模块的实现方法和功能。

数据读取器

模块功能:用于从文件中读取数据,存进pytorch的dataset中。

实现方法:如果需要实现新的数据读取器,我们需要在dataset文件加中新建一个文件来实现我们新的数据读取器,需要按照下列方法实现:

from torch.utils.data import Dataset

class DatasetName(Dataset):
    def __init__(self, config, mode, *args, **params):
        # 在这里进行初始化
        # config为读取的配置文件
        # mode为读取器的模式,包括train、valid和test三种模式
        pass

    def __getitem__(self, item):
        # 返回第item个数据
        pass

    def __len__(self):
        # 返回数据的总量
        return len(self.file_list)

在实现好我们的数据读取器之后,再将实现的数据读取器添加到dataset/__init__.py的列表即可使用。你也可以通过已实现的方法来学习如何实现一个数据读取器。

已实现的方法

  • FilenameOnly:只获取所有数据所对应的绝对路径的数据读取器。
  • ImageFromJson:通过一个json文件获取图片的地址和标签的数据读取器。json文件需要包含一个数组,数组里每个元素需要包括path(图片相对路径)和label(图片标签)两个字段。所谓相对路径,是以配置文件中[data] prefix所指定的路径为基础路径来说的。另外,该方法还可以通过改变[data] load_into_mem的值来决定是否提前将所有数据载入内存。
  • JsonFromFiles:从多个json文件中读取文本信息和标签的数据读取器。首先你可以通过设定[data] json_format来指定对应json文件的格式(见基本参数说明)。对于每条json数据,需要包含text(文本信息)和label(标签信息)。你同时也可以通过[data] recursive[data] load_into_mem来决定是否递归查找文件和是否提前将数据加载进内存。注意,由于文本数据的特殊性,我们建议:只有当提前将数据加载至内存时才对数据进行打乱操作,否则不要进行打乱操作,不然会大大降低数据读取的速度

数据处理器

模块功能:将数据读取器读取的数据处理成更够交给模型运行的格式。

实现方法:如果需要实现新的数据处理器,我们需要在formatter文件加中新建一个文件来实现我们新的数据处理器,需要按照下列方法实现:

class FormatterName:
    def __init__(self, config, mode, *args, **params):
        # 在这里进行初始化
        # config为读取的配置文件
        # mode为处理器的模式,包括train、valid和test三种模式
        pass

    def process(self, data, config, mode, *args, **params):
        # 对给定的数据data进行处理
        # data的格式为大小为batch_size的数组(在test模式下,最后一个batch的大小可能小于batch_size),里面每个元素即为从数据处理器的__getitem__中返回的格式
        # config和mode参数的设置同上
        # 这里我们返回的数据类型要求必须为python的dict格式,且需要把模型需要的字段处理为Tensor的格式
        pass

在实现好我们的数据处理器之后,再将实现的数据处理器添加到formatter/__init__.py的列表即可使用。你也可以通过已实现的方法来学习如何实现一个数据处理器。

已实现的方法

  • Basic:啥也不干的数据处理器,提供一个最基本的格式。
  • BasicBert:将JsonFromFiles读取的数据进行处理,把text转换为BasicBert模型所需要的token。

模型

模块功能:运行数据,产生结果。

实现方法:如果需要实现新的模型,我们需要在model文件加中新建一个文件来实现我们新的模型,需要按照下列方法实现:

class ModelName(nn.Module):
    def __init__(self, config, gpu_list, *args, **params):
        # 模型初始化部分
        # config为读取的配置文件
        # gpu_list为在运行的时候指定的gpu_id的列表
        super(ModelName, self).__init__()

    def init_multi_gpu(self, device, config, *args, **params):
        # 多卡初始化部分,用于将模型放置于多卡上
        # 如果没有多卡的需求,则不需要实现该函数
        # device为gpu_id的列表
        # config为读取的配置文件
        pass

    def forward(self, data, config, gpu_list, acc_result, mode):
        # 模型运行的核心部分
        # data为数据处理器处理好的数据,已自动将其中的Tensor进行gpu化
        # config为读取的配置文件
        # gpu_list为在运行的时候指定的gpu_id的列表
        # acc_result为类型的评价指标结果,如已经运行的所有结果中的准确率、召回率等信息,由之后的指标函数所决定
        # mode为模型的模式,包括train、valid和test三种模式
        # 返回格式为要求为python的dict
        # 在train和valid模式中,由于需要衡量模型和优化模型,返回的结果中必须包含loss和acc_result两个字段,分别代表损失函数的结果和累计的指标量。acc_result的计算在这里并不是必须的,但是如果想从多维的角度评判模型请一定使用
        # 在test模式中,会将output字段作为结果进行记录,所以需要保证output字段的类型必须为list,且其中的内容能够被json化
        pass

在实现好我们的模型之后,再将实现的模型添加到model/__init__.py的列表即可使用。你也可以通过已实现的方法来学习如何实现一个模型。

已实现的方法

  • BasicBert:基础的Bert单标签分类器。

指标函数

模块功能:产生除了损失以外的其他指标,用于衡量模型的水平。

实现方法:如果需要实现新的指标函数,我们需要在tools/accuracy_tool.py文件中新建一个方法来实现我们新的指标函数,需要按照下列方法实现:

def FunctionName(outputs, label, config, result):
    # 这只是一个示例,实际上由于不同模型使用的评价指标都是不一样的,这里你可以随意改造参数,我们只以已经实现好的几个方法的参数进行说明
    # outputs为模型预测的结果
    # label为标签
    # config为读取的配置文件
    # result为历史累计的评价指标结果
    # 返回值为新的评价指标结果
    pass

在实现好我们的指标函数之后,再将实现的指标函数添加到tools/accuracy_init.py的列表即可使用。你也可以通过已实现的方法来学习如何实现一个指标函数。

已实现的方法

  • Null:什么也不做的指标函数。
  • SingleLabelTop1:单标签分类问题的指标函数,用于计算每一类的TP,TN,FP,FN的值。
  • MultiLabel:多标签分类问题的指标函数,用于计算每一类的TP,TN,FP,FN的值。

指标输出

模块功能:通过指标函数产生的结果,产生用于打印至终端的评价指标。

实现方法:如果需要实现新的指标输出函数,我们需要在tools/output_tool.py文件中新建一个方法来实现我们新的指标输出函数,需要按照下列方法实现:

def FunctionName(data, config, *args, **params):
    # data为我们使用指标函数产生的结果
    # config为读取的配置文件
    # 返回值为需要输出的指标结果,要求类型为字符串

在实现好我们的指标输出函数之后,再将实现的指标输出函数添加到tools/output_init.py的列表即可使用。你也可以通过已实现的方法来学习如何实现一个指标输出函数。

已实现的方法

  • Null:什么也不做的指标输出函数。
  • Basic:分类方法的指标输出函数,可以选择输出micro_precision,micro_recall,micro_f1,macro_precision,macro_recall,macro_f1这六个指标中的任意多个。

框架运行逻辑

  1. 读取配置文件。

  2. 进行初始化操作。

    1)初始化数据处理器。

    2)初始化数据读取器。

    3)初始化模型,并多gpu化。

    4)初始化优化器。

    5)如果需要加载checkpoint,则进行加载。(加载出错只会显示warning)

  3. 开始训练。

    训练分为训练和验证两个步骤,其中训练每次迭代的逻辑为:

    1)从数据读取器读取数据。

    2)将数据交给数据处理器进行处理。

    3)模型运行数据,产生损失和评价指标,并优化模型。

    4)如果需要输出评价指标,利用指标输出函数产生输出。

    5)返回第一步,完成一次迭代。

    一个epoch完成之后,保存模型,且会检查是否需要进行验证,验证流程与训练大致相同,不过只会在全部完成之后产生一次评价指标。

依赖库

请参考requirements.txt

未来计划

  1. 添加可定制化的tensorboard显示。
  2. 添加对lr_scheduler的可定制化支持。
  3. 在各个可定制化模块中增加更多常用方法。

更新记录

V1.0.1

2020.04.22 加入了多显卡内存更平衡运行速度更快的方法。

V1.0.0

2020.01.01 完成最基本的框架。

作者与致谢

暂无。

pytorch-worker's People

Contributors

haoxizhong avatar wangyurzee7 avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar

pytorch-worker's Issues

求助。

06/18/2020 16:52:29 - INFO - tools.train_tool - Training start....
Traceback (most recent call last):
File "D:/workspace/CAIL2020/sfks/baseline/train.py", line 59, in
train(parameters, config, gpu_list, do_test)
File "D:\workspace\CAIL2020\sfks\baseline\tools\train_tool.py", line 91, in train
for step, data in enumerate(dataset):
File "D:\ProgramData\Anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 279, in iter
return _MultiProcessingDataLoaderIter(self)
File "D:\ProgramData\Anaconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 719, in init
w.start()
File "D:\ProgramData\Anaconda3\envs\pytorch\lib\multiprocessing\process.py", line 112, in start
self._popen = self._Popen(self)
File "D:\ProgramData\Anaconda3\envs\pytorch\lib\multiprocessing\context.py", line 223, in _Popen
return _default_context.get_context().Process._Popen(process_obj)
File "D:\ProgramData\Anaconda3\envs\pytorch\lib\multiprocessing\context.py", line 322, in _Popen
return Popen(process_obj)
File "D:\ProgramData\Anaconda3\envs\pytorch\lib\multiprocessing\popen_spawn_win32.py", line 89, in init
reduction.dump(process_obj, to_child)
File "D:\ProgramData\Anaconda3\envs\pytorch\lib\multiprocessing\reduction.py", line 60, in dump
ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'init_formatter..train_collate_fn'

请问大佬们运行train时有没有遇到这个问题?如何解决呀?

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.