-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconfig.py
131 lines (127 loc) · 3.92 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# -*- coding: utf-8 -*-
# @Author : lishouxian
# @Email : gzlishouxian@gmail.com
# @File : config.py
# @Software: VScode
# 模式
# train: 训练分类器
# interactive_predict: 交互模式
# test: 跑测试集
# convert_onnx: 将torch模型保存onnx文件
# show_model_info: 打印模型参数
mode = 'train'
# 使用GPU设备
use_cuda = True
cuda_device = 0
configure = {
# 训练数据集
'train_file': 'data/example_datasets2/train_data.json',
# 验证数据集
'dev_file': '',
# 使用交叉验证
'kfold': False,
'fold_splits': 5,
# 没有验证集时,从训练集抽取验证集比例
'validation_rate': 0.1,
# 测试数据集
'test_file': '',
# 单词还是方块字
# 西方语法单词: word
# 中日韩等不需要空格的方块字: char
'token_level': 'char',
# 存放词表的地方
'token_file': 'data/example_datasets2/token2id.txt',
# 使用的预训练模型,这个地方的模型路径是和huggingface上的路径对应的
'ptm': 'hfl/chinese-bert-wwm-ext',
# 'ptm': 'Davlan/bert-base-multilingual-cased-ner-hrl',
# 使用的方法
# sequence_tag:序列标注
# span:方式
'method': 'span',
# 使用的模型
# sequence label方式:
# ptm crf: ptm_crf
# ptm bilstm crf: ptm_bilstm_crf
# ptm idcnn crf: ptm_idcnn_crf
# idcnn crf: idcnn_crf
# bilstm crf: bilstm_crf
# ptm: ptm
# span方式:
# binary pointer: ptm_bp
# global pointer: ptm_gp
'model_type': 'ptm_gp',
# 选择lstm时,隐藏层大小
'hidden_dim': 200,
# Embedding向量维度
'embedding_dim': 300,
# 选择idcnn时filter的个数
'filter_nums': 64,
# 模型保存的文件夹
'checkpoints_dir': 'checkpoints/example_datasets2',
# 模型名字
'model_name': 'bert.pkl',
# 类别列表
'span_classes': ['PER', 'ORG', 'LOC'],
'sequence_tag_classes': ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'],
# 'sequence_tag_classes': ['O', 'B-DATE', 'I-DATE', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'],
# decision_threshold
'decision_threshold': 0.5,
# 使用bp和gp的时候是否使用苏神的多标签分类的损失函数,默认使用BCELoss
'use_multilabel_categorical_cross_entropy': True,
# 使用对抗学习
'use_gan': False,
# fgsm:Fast Gradient Sign Method
# fgm:Fast Gradient Method
# pgd:Projected Gradient Descent
# awp: Weighted Adversarial Perturbation
'gan_method': 'pgd',
# 对抗次数
'attack_round': 3,
# 使用Multisample Dropout
# 使用Multisample Dropout后dropout会失效
'multisample_dropout': False,
'dropout_round': 5,
# 随机种子
'seed': 3407,
# 预训练模型是否前置加入Noisy
'noisy_tune': False,
'noise_lambda': 0.12,
# 是否进行warmup
'warmup': False,
# 是否进行随机权重平均swa
'swa': False,
'swa_start_step': 5000,
'swa_lr': 1e-6,
# 每个多久平均一次
'anneal_epochs': 1,
# 使用EMA
'ema': False,
# warmup方法,可选:linear、cosine
'scheduler_type': 'linear',
# warmup步数,-1自动推断为总步数的0.1
'num_warmup_steps': -1,
# 句子最大长度
'max_sequence_length': 64,
# epoch
'epoch': 50,
# batch_size
'batch_size': 24,
# dropout rate
'dropout_rate': 0.5,
# 每print_per_batch打印损失函数
'print_per_batch': 100,
# learning_rate
'learning_rate': 5e-5,
# 优化器选择
'optimizer': 'AdamW',
# 执行权重初始化,仅限于非微调
'init_network': False,
# 权重初始化方式,可选:xavier、kaiming、normal
'init_network_method': 'xavier',
# fp16混合精度训练,仅在Cuda支持下使用
'use_fp16': False,
# 训练是否提前结束微调
'is_early_stop': True,
# 训练阶段的patient
'patient': 5,
}