I want to get the tensor of the last layer at the inference stage. The batch size is 1. I use the following two methods:

1:

model = LlamaForCausalLM.from_pretrained(args.base_model, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map=device_map)
output = model.module.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=10, prefix_allowed_tokens_fn=prefix_allowed_tokens, num_beams=20, num_return_sequences=20, output_scores=True, return_dict_in_generate=True, early_stopping=True, output_hidden_states=True)
last_output = output['hidden_states'][0][-1][0, -1, :]

2:

model = LlamaForCausalLM.from_pretrained(args.base_model, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map=device_map)
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], output_hidden_states=True)
last_output = outputs.hidden_states[-1][0, -1, :]

I found that the tensor values obtained by the two methods are different. I'd like to know what the potential reasons are.
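For reference on what the two indexing expressions select, here is an annotated sketch (not part of the original report; it assumes output and outputs are the results of snippets 1 and 2 above):

# Snippet 2: plain forward pass. outputs.hidden_states is a tuple of
# (num_hidden_layers + 1) tensors -- the embedding output plus one per decoder
# layer -- each of shape (batch_size, prompt_len, hidden_size).
forward_last = outputs.hidden_states[-1][0, -1, :]        # last layer, sample 0, last prompt token

# Snippet 1: generate() with beam search. output['hidden_states'] is a tuple with
# one entry per generation step; entry 0 comes from the forward pass over the
# prompt and is itself a tuple over layers. Its batch dimension is expanded for
# beam search, so index 0 here selects the first beam rather than "the only sample".
generate_last = output['hidden_states'][0][-1][0, -1, :]  # step 0, last layer, beam 0, last prompt token

print(forward_last.shape, generate_last.shape)            # both should be (hidden_size,)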
Hey @xxyyzztlb. Please take a look at this comment. TL;DR: we cannot guarantee exactly the same hidden-state values, due to numerical precision errors. Given that your code snippet is using bfloat16, the error may be even larger than with float32.
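For instance, one way to check whether the difference stays within what bfloat16 precision allows, rather than expecting bit-identical values, is a tolerance-based comparison (a minimal sketch, not from the original thread; it assumes the two vectors from the snippets above are available as last_output_generate and last_output_forward on the same device):

import torch

# Compare in float32 so the comparison itself does not add extra bfloat16 rounding.
a = last_output_generate.float()
b = last_output_forward.float()

max_abs_diff = (a - b).abs().max().item()
print(f"max abs diff: {max_abs_diff:.6f}")

# bfloat16 keeps only ~8 bits of mantissa, so a relatively loose tolerance is reasonable.
print("close within tolerance:", torch.allclose(a, b, rtol=1e-2, atol=1e-3))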
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.