Skip to content

Commit

Permalink
faster
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed Oct 27, 2021
1 parent 3365862 commit d1365ac
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 37 deletions.
31 changes: 26 additions & 5 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# src.utils.set_seed(42) # 是否固定随机数(固定后每次运行的生成结果都一样)

print('\nAI人工智障写作 https://github.com/BlinkDL/AI-Writer')
print('请关注我的知乎 https://zhuanlan.zhihu.com/p/394766831')
print('请关注我的知乎 https://zhuanlan.zhihu.com/p/423646620')
print('\n声明:模型的训练数据全部来自网文,缺乏生活常识。生成的文字仅供娱乐。请遵守法律法规。')

# gpu:只支持 nvidia 显卡,需要 cuda+cudnn
Expand Down Expand Up @@ -77,16 +77,37 @@
import onnxruntime as rt
sess_options = rt.SessionOptions()
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = rt.ExecutionMode.ORT_SEQUENTIAL
sess_options.enable_mem_pattern = False
rt_session = rt.InferenceSession(MODEL_NAME + '.onnx', sess_options=sess_options)
rt_session.set_providers(['DmlExecutionProvider'])
rt_session.set_providers(['DmlExecutionProvider'])
else:
model = GPT(GPTConfig(vocab_size, ctx_len, n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))
m2 = torch.load(MODEL_NAME + '.pth', map_location='cpu').state_dict()
for i in range(n_layer):
prefix = f'blocks.{i}.attn.'
time_w = m2[prefix + 'time_w']
time_alpha = m2[prefix + 'time_alpha']
time_beta = m2[prefix + 'time_beta']
mask = m2[prefix + 'mask']

TT = ctx_len
T = ctx_len
w = F.pad(time_w, (0, TT))
w = torch.tile(w, [TT])
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
w = w[:, :, TT-1:]
w = w[:, :T, :T] * time_alpha[:, :, :T] * time_beta[:, :T, :]
w = w.masked_fill(mask[:T, :T] == 0, 0)

m2[prefix + 'time_ww'] = w
del m2[prefix + 'time_w']
del m2[prefix + 'time_alpha']
del m2[prefix + 'time_beta']
del m2[prefix + 'mask']
if RUN_DEVICE == 'gpu':
model = model.cuda()
model.load_state_dict(torch.load(MODEL_NAME + '.pth').state_dict())
else:
model.load_state_dict(torch.load(MODEL_NAME + '.pth', map_location='cpu').state_dict())
model.load_state_dict(m2)

print('done:', MODEL_NAME, '&', WORD_NAME)

Expand Down
66 changes: 52 additions & 14 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
_DEBUG_LEVEL_ = 2 # 2 = full, 1 = partial, 0 = none
PORT_NUM = 8266

RUN_DEVICE = 'gpu' # gpu 或 cpu
# gpu:只支持 nvidia 显卡,需要 cuda+cudnn
# dml:支持 amd 和 intel 显卡,需要不同的模型和一些包
# cpu:没有显卡就选它
RUN_DEVICE = 'gpu' # gpu 或 dml 或 cpu

MODEL_NAME = 'model/xuanhuan-2021-10-26'
WORD_NAME = 'model/xuanhuan-2021-10-26'
Expand Down Expand Up @@ -177,14 +180,41 @@ def train_dataset(): return None
train_dataset.itos = {int(k): v for k, v in word_table.items()}
UNKNOWN_CHAR = train_dataset.stoi['\ue083']

model = GPT(GPTConfig(vocab_size, ctx_len, n_layer=n_layer,
n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))
if RUN_DEVICE == 'gpu':
model = model.cuda()
model.load_state_dict(torch.load(MODEL_NAME + '.pth').state_dict())
if RUN_DEVICE == 'dml':
import onnxruntime as rt
sess_options = rt.SessionOptions()
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = rt.ExecutionMode.ORT_SEQUENTIAL
sess_options.enable_mem_pattern = False
rt_session = rt.InferenceSession(MODEL_NAME + '.onnx', sess_options=sess_options)
rt_session.set_providers(['DmlExecutionProvider'])
else:
model.load_state_dict(torch.load(
MODEL_NAME + '.pth', map_location='cpu').state_dict())
model = GPT(GPTConfig(vocab_size, ctx_len, n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn))
m2 = torch.load(MODEL_NAME + '.pth', map_location='cpu').state_dict()
for i in range(n_layer):
prefix = f'blocks.{i}.attn.'
time_w = m2[prefix + 'time_w']
time_alpha = m2[prefix + 'time_alpha']
time_beta = m2[prefix + 'time_beta']
mask = m2[prefix + 'mask']

TT = ctx_len
T = ctx_len
w = F.pad(time_w, (0, TT))
w = torch.tile(w, [TT])
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
w = w[:, :, TT-1:]
w = w[:, :T, :T] * time_alpha[:, :, :T] * time_beta[:, :T, :]
w = w.masked_fill(mask[:T, :T] == 0, 0)

m2[prefix + 'time_ww'] = w
del m2[prefix + 'time_w']
del m2[prefix + 'time_alpha']
del m2[prefix + 'time_beta']
del m2[prefix + 'mask']
if RUN_DEVICE == 'gpu':
model = model.cuda()
model.load_state_dict(m2)

print('done:', MODEL_NAME, '&', WORD_NAME)

Expand Down Expand Up @@ -219,12 +249,20 @@ def train_dataset(): return None
print_begin = real_len

with torch.no_grad():
xxx = torch.tensor(
x[-ctx_len:], dtype=torch.long)[None, ...]
if RUN_DEVICE == 'gpu':
xxx = xxx.cuda()
out, _ = model(xxx)
out[:, :, UNKNOWN_CHAR] = -float('Inf')
if RUN_DEVICE == 'dml':
if real_len < ctx_len:
xxx = np.pad(x, (0, ctx_len - real_len))
else:
xxx = x
out = rt_session.run(None, {rt_session.get_inputs()[0].name: [xxx[-ctx_len:]]})
out = torch.tensor(out[0])
else:
xxx = torch.tensor(x[-ctx_len:], dtype=torch.long)[None,...]
if RUN_DEVICE == 'gpu':
xxx = xxx.cuda()
out, _ = model(xxx)
out[:, :, UNKNOWN_CHAR] = -float('Inf')

pos = -1 if real_len >= ctx_len else real_len - 1

if train_dataset.itos[int(x[real_len-1])] == '\n':
Expand Down
21 changes: 3 additions & 18 deletions src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,9 @@ def __init__(self, config, layer_id):
self.n_head = config.n_head
self.head_size = config.n_attn // config.n_head

ww = torch.ones(config.n_head, config.ctx_len)
self.time_w = nn.Parameter(ww)

self.time_alpha = nn.Parameter(
torch.ones(self.n_head, 1, config.ctx_len))
self.time_beta = nn.Parameter(
torch.ones(self.n_head, config.ctx_len, 1))
self.time_ww = nn.Parameter(
torch.ones(config.n_head, config.ctx_len, config.ctx_len))
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
self.register_buffer("mask", torch.tril(
torch.ones(config.ctx_len, config.ctx_len)))

self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

Expand All @@ -44,13 +37,6 @@ def __init__(self, config, layer_id):

def forward(self, x):
B, T, C = x.size()
TT = self.ctx_len
w = F.pad(self.time_w, (0, TT))
w = torch.tile(w, [TT])
w = w[:, :-TT].reshape(-1, TT, 2 * TT - 1)
w = w[:, :, TT-1:]
w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
w = w.masked_fill(self.mask[:T, :T] == 0, 0)

x = torch.cat(
[self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim=-1)
Expand All @@ -65,15 +51,14 @@ def forward(self, x):

kv = (k * v).view(B, T, self.n_head, self.head_size)

wkv = (torch.einsum('htu,buhc->bthc', w, kv)
wkv = (torch.einsum('htu,buhc->bthc', self.time_ww[:,:T,:T], kv)
).contiguous().view(B, T, -1)

rwkv = torch.sigmoid(r) * wkv / sum_k

rwkv = self.output(rwkv)
return rwkv * self.time_gamma[:T, :]


class RWKV_ChannelMix(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
Expand Down

0 comments on commit d1365ac

Please sign in to comment.