Spaces:
Sleeping
Sleeping
Ouadie EL FAROUKI
AidanBeltonS
commited on
Fixed minor bug when enabling FP16 for non intel targets (llama/6464)
Browse files* moved INTEL_MKL guard from gemm_impl to gemm (wrapper)
* Update ggml-sycl.cpp
Co-authored-by: AidanBeltonS <[email protected]>
---------
Co-authored-by: AidanBeltonS <[email protected]>
- ggml-sycl.cpp +2 -19
ggml-sycl.cpp
CHANGED
|
@@ -1664,24 +1664,6 @@ namespace dpct
|
|
| 1664 |
const void *alpha, const void *a, int lda, const void *b,
|
| 1665 |
int ldb, const void *beta, void *c, int ldc)
|
| 1666 |
{
|
| 1667 |
-
#ifndef __INTEL_MKL__
|
| 1668 |
-
GGML_UNUSED(q);
|
| 1669 |
-
GGML_UNUSED(a_trans);
|
| 1670 |
-
GGML_UNUSED(b_trans);
|
| 1671 |
-
GGML_UNUSED(m);
|
| 1672 |
-
GGML_UNUSED(n);
|
| 1673 |
-
GGML_UNUSED(k);
|
| 1674 |
-
GGML_UNUSED(alpha);
|
| 1675 |
-
GGML_UNUSED(a);
|
| 1676 |
-
GGML_UNUSED(lda);
|
| 1677 |
-
GGML_UNUSED(b);
|
| 1678 |
-
GGML_UNUSED(ldb);
|
| 1679 |
-
GGML_UNUSED(beta);
|
| 1680 |
-
GGML_UNUSED(c);
|
| 1681 |
-
GGML_UNUSED(ldc);
|
| 1682 |
-
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
|
| 1683 |
-
"Project does not support this API.");
|
| 1684 |
-
#else
|
| 1685 |
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
| 1686 |
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
| 1687 |
auto data_a = get_memory<const Ta>(a);
|
|
@@ -1690,7 +1672,6 @@ namespace dpct
|
|
| 1690 |
oneapi::mkl::blas::column_major::gemm(
|
| 1691 |
q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
| 1692 |
data_b, ldb, beta_value, data_c, ldc);
|
| 1693 |
-
#endif
|
| 1694 |
}
|
| 1695 |
|
| 1696 |
template <typename VecT, class BinaryOperation, class = void>
|
|
@@ -2330,6 +2311,7 @@ namespace dpct
|
|
| 2330 |
lda, b, ldb, beta, c, ldc);
|
| 2331 |
break;
|
| 2332 |
}
|
|
|
|
| 2333 |
case detail::get_type_combination_id(
|
| 2334 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2335 |
library_data_t::real_float, library_data_t::real_float):
|
|
@@ -2391,6 +2373,7 @@ namespace dpct
|
|
| 2391 |
q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
|
| 2392 |
break;
|
| 2393 |
}
|
|
|
|
| 2394 |
default:
|
| 2395 |
throw std::runtime_error("the combination of data type is unsupported");
|
| 2396 |
}
|
|
|
|
| 1664 |
const void *alpha, const void *a, int lda, const void *b,
|
| 1665 |
int ldb, const void *beta, void *c, int ldc)
|
| 1666 |
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1667 |
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
| 1668 |
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
| 1669 |
auto data_a = get_memory<const Ta>(a);
|
|
|
|
| 1672 |
oneapi::mkl::blas::column_major::gemm(
|
| 1673 |
q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
| 1674 |
data_b, ldb, beta_value, data_c, ldc);
|
|
|
|
| 1675 |
}
|
| 1676 |
|
| 1677 |
template <typename VecT, class BinaryOperation, class = void>
|
|
|
|
| 2311 |
lda, b, ldb, beta, c, ldc);
|
| 2312 |
break;
|
| 2313 |
}
|
| 2314 |
+
#ifdef __INTEL_MKL__
|
| 2315 |
case detail::get_type_combination_id(
|
| 2316 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2317 |
library_data_t::real_float, library_data_t::real_float):
|
|
|
|
| 2373 |
q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
|
| 2374 |
break;
|
| 2375 |
}
|
| 2376 |
+
#endif // __INTEL_MKL__
|
| 2377 |
default:
|
| 2378 |
throw std::runtime_error("the combination of data type is unsupported");
|
| 2379 |
}
|