Skip to content

特征淘汰

songyue1104 edited this page Aug 12, 2019 · 1 revision

I. FilterHook

如果某个ID长时间没被更新,那说明这个ID在模型中已经处于不太重要的地位,XDL提供了删除这些ID的功能,使用方法如下:

emb1 = xdl.embedding('emb1', batch['sparse0'], xdl.TruncatedNormal(stddev=0.001), 8, 1024, vtype='hash')
emb2 = xdl.embedding('emb2', batch['sparse1'], xdl.TruncatedNormal(stddev=0.001), 8, 1024, vtype='hash')

hooks = []

vars = ["emb1", "emb2"]
mark_hook1 = xdl.GlobalStepMarkHook("emb1", batch["sparse0"].ids)
mark_hook2 = xdl.GlobalStepMarkHook("emb2", batch["sparse1"].ids)
hooks.append(mark_hook1)
hooks.append(mark_hook2)
if xdl.get_task_index() == 0:
   filter_hook = xdl.GlobalStepFilterHook(vars, 30, 10)
   hooks.append(filter_hook)

方法说明:

  1. "emb1"和"emb2"是两路需要进行ID退出的特征,针对这两路特征创建两个GlobalStepMarkHook,具体创建方法参见上述代码;
  2. 选取一个worker(这里选择worker0),创建一个GlobalStepFilterHook,第一个参数"vars"为需要进行ID退出的变量名称集合,第二个参数"30"表示global_step每隔30步,进行一次ID退出的动作,第三个参数"10",表示如果某个ID超过10步没有被更新,在下一次的ID退出动作时,这个ID就会被删除。

II. 简单的特征删除API

除了XDL提供的基于更新时间的特征删除功能之外,用户还可以自定义删除逻辑,用法如下。

例子:

xdl.execute(xdl.hash_filter(var, "filter.py", "filter", {"x":np.array(100)})

# filter.py:
import numpy as np

def filter(data_, x):
  print x
  print data_
  return data_[:, 0] < x
# 如上代码将所有样本第一列小于100的数据删除

xdl.hash_filter参数说明:

  • 第一个参数是一个variable,需要提供一个哈希的Variable。
  • 第二个参数是filter的代码文件名,可以以相对路径表示。
  • 第三个参数是filter在第二参数指定的文件中的函数名。
  • 第四个参数是payload,用一个dict表示,默认为空。
  • 返回值为被删除的数量

filter函数说明

  • filter接受多种输入参数,通过参数名hash_filter 将决定将何种数据传入filter函数中。
  • data_作为一个特殊的关键字,表示原始数据。
  • 如果参数名可以在slot中被找到,则传入对应的slot。
  • 如果参数名可以在payload中被找到,传入对应的payload。

III. 带slot记忆的特征删除API

xdl.execute(xdl.hash_slot_filter(var, "filter.py", "filter", "filter_slot", 1, {"x":np.array(100)})

#### filter.py:
import numpy as np

def filter(x, filter_slot):
  return filter_slot[:, 0] > x, filter_slot + 1
# 如上代码将所有filter_slot超过100次的数据删除,并filter_slot加1

xdl.hash_slot_filter参数说明:

  • xdl.hash_slot_filter 与xdl.hash_slot 类似,但是加入了一个slot作为临时存储变量。
  • hash_slot_filter的第四个参数为临时存储的slot的名字,而第五个参数为slot的大小。
  • slot只能为一个二维的,float类型的变量。
  • filter的第二个返回值为hash_slot_filter的新值。