From 1266372322ee570d0db60face0811222741c5cf9 Mon Sep 17 00:00:00 2001 From: youralien Date: Tue, 27 Sep 2022 15:04:16 -0500 Subject: [PATCH] Updated interact_strat based on comments in issue #12 --- codes_zcj/interact.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/codes_zcj/interact.py b/codes_zcj/interact.py index d04d4c2..a382b34 100644 --- a/codes_zcj/interact.py +++ b/codes_zcj/interact.py @@ -127,6 +127,18 @@ def cut_seq_to_eos(sentence, eos, remove_id=None): 'eos_token_id': eos, } +id2strategy = { + 0: "Question", + 1: "Restatement or Paraphrasing", + 2: "Reflection of feelings", + 3: "Self-disclosure", + 4: "Affirmation and Reassurance", + 5: "Providing Suggestions", + 6: "Information", + 7: "Others" + } + + eof_once = False history = {'dialog': [],} print('\n\nA new conversation starts!') @@ -164,6 +176,7 @@ def cut_seq_to_eos(sentence, eos, remove_id=None): history['dialog'].append({ # dummy tgt 'text': 'n/a', 'speaker': 'sys', + 'strategy': 'Others' }) inputs = inputter.convert_data_to_inputs(history, toker, **dataloader_kwargs) inputs = inputs[-1:] @@ -173,15 +186,27 @@ def cut_seq_to_eos(sentence, eos, remove_id=None): batch.update(generation_kwargs) encoded_info, generations = model.generate(**batch) + + # out = generations[0].tolist() + # out = cut_seq_to_eos(out, eos) + # text = toker.decode(out).encode('ascii', 'ignore').decode('ascii').strip() + # print(" AI: " + text) + out = generations[0].tolist() out = cut_seq_to_eos(out, eos) text = toker.decode(out).encode('ascii', 'ignore').decode('ascii').strip() - print(" AI: " + text) + # OLD WAY + # strat_id_out = encoded_info['pred_strat_id_top3'].tolist()[0][0] # 取top1 策略id + + # AUTHORS SUGGESTION + strat_id_out = encoded_info['pred_strat_id'][0] + strategy = id2strategy[strat_id_out] + print(" AI: " + "[" + strategy + "] " + text) + history['dialog'].pop() history['dialog'].append({ 'text': text, 'speaker': 'sys', + 'strategy': strategy }) - -