Skip to content

Commit

Permalink
[SYCL] Fixed minor bug when enabling FP16 for non intel targets (gger…
Browse files Browse the repository at this point in the history
…ganov#6464)

* moved INTEL_MKL guard from gemm_impl to gemm (wrapper)

* Update ggml-sycl.cpp

Co-authored-by: AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com>

---------

Co-authored-by: AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com>
  • Loading branch information
2 people authored and tybalex committed Apr 17, 2024
1 parent 365a4fd commit d2acf2a
Showing 1 changed file with 2 additions and 19 deletions.
21 changes: 2 additions & 19 deletions ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1664,24 +1664,6 @@ namespace dpct
const void *alpha, const void *a, int lda, const void *b,
int ldb, const void *beta, void *c, int ldc)
{
#ifndef __INTEL_MKL__
GGML_UNUSED(q);
GGML_UNUSED(a_trans);
GGML_UNUSED(b_trans);
GGML_UNUSED(m);
GGML_UNUSED(n);
GGML_UNUSED(k);
GGML_UNUSED(alpha);
GGML_UNUSED(a);
GGML_UNUSED(lda);
GGML_UNUSED(b);
GGML_UNUSED(ldb);
GGML_UNUSED(beta);
GGML_UNUSED(c);
GGML_UNUSED(ldc);
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
"Project does not support this API.");
#else
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
auto data_a = get_memory<const Ta>(a);
Expand All @@ -1690,7 +1672,6 @@ namespace dpct
oneapi::mkl::blas::column_major::gemm(
q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
data_b, ldb, beta_value, data_c, ldc);
#endif
}

template <typename VecT, class BinaryOperation, class = void>
Expand Down Expand Up @@ -2330,6 +2311,7 @@ namespace dpct
lda, b, ldb, beta, c, ldc);
break;
}
#ifdef __INTEL_MKL__
case detail::get_type_combination_id(
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
library_data_t::real_float, library_data_t::real_float):
Expand Down Expand Up @@ -2391,6 +2373,7 @@ namespace dpct
q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
break;
}
#endif // __INTEL_MKL__
default:
throw std::runtime_error("the combination of data type is unsupported");
}
Expand Down

0 comments on commit d2acf2a

Please sign in to comment.