Skip to content

Commit

Permalink
Updated interact_strat based on comments in issue thu-coai#12
Browse files Browse the repository at this point in the history
  • Loading branch information
youralien committed Sep 27, 2022
1 parent 9d375f5 commit 1266372
Showing 1 changed file with 28 additions and 3 deletions.
31 changes: 28 additions & 3 deletions codes_zcj/interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,18 @@ def cut_seq_to_eos(sentence, eos, remove_id=None):
'eos_token_id': eos,
}

id2strategy = {
0: "Question",
1: "Restatement or Paraphrasing",
2: "Reflection of feelings",
3: "Self-disclosure",
4: "Affirmation and Reassurance",
5: "Providing Suggestions",
6: "Information",
7: "Others"
}


eof_once = False
history = {'dialog': [],}
print('\n\nA new conversation starts!')
Expand Down Expand Up @@ -164,6 +176,7 @@ def cut_seq_to_eos(sentence, eos, remove_id=None):
history['dialog'].append({ # dummy tgt
'text': 'n/a',
'speaker': 'sys',
'strategy': 'Others'
})
inputs = inputter.convert_data_to_inputs(history, toker, **dataloader_kwargs)
inputs = inputs[-1:]
Expand All @@ -173,15 +186,27 @@ def cut_seq_to_eos(sentence, eos, remove_id=None):
batch.update(generation_kwargs)
encoded_info, generations = model.generate(**batch)


# out = generations[0].tolist()
# out = cut_seq_to_eos(out, eos)
# text = toker.decode(out).encode('ascii', 'ignore').decode('ascii').strip()
# print(" AI: " + text)

out = generations[0].tolist()
out = cut_seq_to_eos(out, eos)
text = toker.decode(out).encode('ascii', 'ignore').decode('ascii').strip()
print(" AI: " + text)

# OLD WAY
# strat_id_out = encoded_info['pred_strat_id_top3'].tolist()[0][0] # 取top1 策略id

# AUTHORS SUGGESTION
strat_id_out = encoded_info['pred_strat_id'][0]
strategy = id2strategy[strat_id_out]
print(" AI: " + "[" + strategy + "] " + text)

history['dialog'].pop()
history['dialog'].append({
'text': text,
'speaker': 'sys',
'strategy': strategy
})


0 comments on commit 1266372

Please sign in to comment.