Skip to content

Commit

Permalink
Fix causal masks dtype (#443)
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler authored Jun 16, 2024
1 parent d355714 commit 3a79137
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion mistralrs-core/src/models/quantized_phi2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ impl ModelWeights {
let mask = CausalMasker.make_causal_mask_as_attn_bias(
input_ids,
&cache,
xs.dtype(),
DType::F32,
self.layers[0].n_head,
)?;
for (i, layer) in self.layers.iter_mut().enumerate() {
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-core/src/models/quantized_phi3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ impl ModelWeights {
input_ids,
&cache,
Some(self.max_seq_len),
xs.dtype(),
DType::F32,
self.layers[0].n_head,
)?;
for (i, layer) in self.layers.iter_mut().enumerate() {
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-core/src/xlora_models/quantized_llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ impl ModelWeights {
let mask = CausalMasker.make_causal_mask_as_attn_bias(
x,
&cache,
x.dtype(),
DType::F32,
self.layers[0].n_head,
)?;
for (i, layer) in self.layers.iter_mut().enumerate() {
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-core/src/xlora_models/quantized_phi3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ impl ModelWeights {
input_ids,
&cache,
Some(self.max_seq_len),
xs.dtype(),
DType::F32,
self.layers[0].n_head,
)?;
for (i, layer) in self.layers.iter_mut().enumerate() {
Expand Down

0 comments on commit 3a79137

Please sign in to comment.