update sampling
www committed Jan 21, 2022
1 parent 0a8d555 commit 5b77f74
Showing 2 changed files with 22 additions and 38 deletions.
16 changes: 8 additions & 8 deletions run.py
@@ -27,19 +27,19 @@

 RUN_DEVICE = 'gpu' # gpu, dml, or cpu
 
-MODEL_NAME = 'model/wangwen-2021-12-11' # model name
-WORD_NAME = 'model/wangwen-2021-12-11' # change this one as well
+MODEL_NAME = 'model/wangwen-2022-01-09' # model name
+WORD_NAME = 'model/wangwen-2022-01-09' # change this one as well
 
 NUM_OF_RUNS = 9999 # how many pieces to generate
 LENGTH_OF_EACH = 200 # how many characters per piece
 
-min_p_ratio = 0.02 # range 0 to 1; larger = tamer output, smaller = more variation. Try 0, 0.1, and 1.0 to see the difference for yourself
+top_p = 0.8 # range 0 to 1; larger = more variation, smaller = tamer output. Try 0, 0.5, and 1.0 to see the difference for yourself
+top_p_newline = 0.9
 
 # The opening matters a lot: it should set up a plot point, and the better its prose, the better the continuation. A sloppy opening gets a sloppy continuation.
 # Set the opening like this:
 # context = "我"  # "I"
 # context = "他"  # "he"
 # context = "她"  # "she"
 # context = "魔法"  # "magic"
 # context = "“区区"  # '"A mere'
 # context = "三体舰队"  # "The Trisolaran fleet"
 context = "这是一颗"  # "This is a ..."
 # context = "众人一惊,没想到这林黛玉的剑法竟如此精妙,只见在那剑影下,剑尖朝着伏地魔的脖子探去,眼见避无可避,伏地魔情急,大喊"  # "Everyone gasped; no one had expected Lin Daiyu's swordsmanship to be so exquisite. Under the flickering blade, the sword tip darted toward Voldemort's throat. Seeing no way to dodge, Voldemort cried out in desperation,"
@@ -148,9 +148,9 @@
         pos = -1 if real_len >= ctx_len else real_len - 1
 
         if train_dataset.itos[int(x[real_len-1])] == '\n':
-            char = src.utils.sample_logits(out, pos, temperature=1.0, top_p=0.995)
+            char = src.utils.sample_logits(out, pos, temperature=1.0, top_p=top_p_newline)
         else:
-            char = src.utils.sample_logits(out, pos, temperature=1.0, min_p_pow=2.0, min_p_ratio=min_p_ratio)
+            char = src.utils.sample_logits(out, pos, temperature=1.0, top_p=top_p)
 
         x = np.append(x, char)
         real_len += 1
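The comments in run.py describe `top_p` only qualitatively: larger means more variation, smaller means tamer output. A minimal sketch (not part of the commit, using made-up logits) makes the mechanism concrete by counting how many tokens survive nucleus truncation at different values of p:

import torch
from torch.nn import functional as F

# Made-up logits standing in for one step of model output.
logits = torch.tensor([3.0, 2.5, 1.0, 0.5, 0.0, -1.0, -2.0])
probs = F.softmax(logits, dim=-1)

sorted_probs, _ = torch.sort(probs, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=-1)

for p in (0.0, 0.5, 0.8, 1.0):
    # Nucleus (top-p) sampling keeps the smallest prefix of tokens whose
    # cumulative probability exceeds p and drops the rest.
    kept = min(int((cumulative <= p).sum()) + 1, len(probs))
    print(f"top_p={p}: {kept} of {len(probs)} tokens remain sampleable")

With p near 0 this collapses to greedy sampling of the single most likely character; with p near 1 nearly the whole vocabulary stays in play. That is also why the loop above passes the higher top_p_newline right after a '\n' is emitted: line openings get extra variety while ordinary text stays tamer.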
44 changes: 14 additions & 30 deletions src/utils.py
@@ -8,40 +8,24 @@
 import torch.nn as nn
 from torch.nn import functional as F
 
-def top_k_logits(logits, k):
-    v, ix = torch.topk(logits, k)
-    out = logits.clone()
-    out[out < v[:, [-1]]] = -float('Inf')
-    return out
+def to_float(x):
+    return x.cpu().detach().numpy().flatten()[0].astype(float)
 
-def top_p_probs(probs, p):
-    out = probs.clone()
-
-    sorted_probs, sorted_indices = torch.sort(out, descending=True)
-    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
-    sorted_indices_to_remove = cumulative_probs > p
-    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-    sorted_indices_to_remove[..., 0] = 0
-    indices_to_remove = sorted_indices[sorted_indices_to_remove]
-    out[indices_to_remove] = 0
-
-    return out
-
-# top-p + top-k + pow&ratio sampling
-def sample_logits(logits, pos, temperature=1.0, top_k=None, top_p=None, min_p_pow=None, min_p_ratio=None):
-    logits = logits[:, pos, :] / temperature
-    probs = F.softmax(logits, dim=-1)
-    if min_p_ratio is not None:
-        limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
-        logits[probs < limit] = -float('Inf')
-    if top_k is not None:
-        logits = top_k_logits(logits, top_k)
+def sample_logits(logits, pos, temperature=1.0, top_p=None):
+    logits = logits[0][pos, :]
+    probs = F.softmax(logits, dim=-1)
 
     if top_p is not None:
-        probs[0] = top_p_probs(probs[0], top_p)
+        out = probs.clone()
+        sorted_probs, _ = torch.sort(out, descending=True)
+        cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
+        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)].cpu())
+        probs[probs < cutoff] = 0
 
+    if temperature != 1.0:
+        probs = probs.pow(1.0 / temperature)
     ix = torch.multinomial(probs, num_samples=1)
 
-    return ix[0][0].cpu()
+    return ix[0].cpu()
 
 def set_seed(seed):
     random.seed(seed)
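For experimenting outside the repo, the `+` lines above can be assembled into one self-contained sketch of the new sampler. The input shape ([1, seq_len, vocab_size], inferred from `logits[0][pos, :]`) and the random-logits usage at the bottom are assumptions, not part of the commit:

import numpy as np
import torch
from torch.nn import functional as F

def sample_logits(logits, pos, temperature=1.0, top_p=None):
    logits = logits[0][pos, :]  # assumed shape [1, seq_len, vocab_size]
    probs = F.softmax(logits, dim=-1)

    if top_p is not None:
        sorted_probs, _ = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
        # argmax on a boolean array returns the first True, i.e. the first
        # position where the cumulative mass exceeds top_p.
        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)].cpu())
        probs[probs < cutoff] = 0  # zero out everything outside the nucleus

    if temperature != 1.0:
        # Changed behavior: temperature is now applied to the truncated
        # probabilities instead of dividing the raw logits.
        probs = probs.pow(1.0 / temperature)
    # multinomial renormalizes the surviving unnormalized weights itself.
    ix = torch.multinomial(probs, num_samples=1)
    return ix[0].cpu()

# Hypothetical usage with random values standing in for real model output.
fake_out = torch.randn(1, 16, 500)
print(int(sample_logits(fake_out, pos=-1, temperature=1.0, top_p=0.8)))

One design note: because torch.multinomial normalizes its weights internally, the truncated distribution never needs explicit renormalization, so the temperature branch can simply reshape the surviving mass with a power.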
