Spaces:
Sleeping
Sleeping
Romain Biessy
committed on
Commit
·
d89f6fa
1
Parent(s):
fe06fa2
SYCL: Rename oneMKL to oneMath (llama/12192)
Browse files
* Rename oneMKL Interface to oneMath
* Use oneMath for Intel vendor
* Rename occurrences to mkl
* clang-format
* Silence verbose warnings
* Set oneMath HIP_TARGETS
* Fix silence warnings
* Remove step to build oneMath from build instructions
* Use fixed oneMath version
* Remove INTEL_CPU
* Fold CMake oneDNN conditions
* Use Intel oneMKL for Intel devices
* Improve CMake message
* Link against MKL::MKL_SYCL::BLAS only
* Move oneMath documentation to Nvidia and AMD sections
- ggml/src/ggml-sycl/CMakeLists.txt +87 -22
- ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- ggml/src/ggml-sycl/ggml-sycl.cpp +11 -22
- ggml/src/ggml-sycl/outprod.cpp +4 -14
ggml/src/ggml-sycl/CMakeLists.txt
CHANGED
|
@@ -23,6 +23,23 @@ ggml_add_backend_library(ggml-sycl
|
|
| 23 |
../../include/ggml-sycl.h
|
| 24 |
)
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
find_package(DNNL)
|
| 27 |
set(GGML_SYCL_DNNL 0)
|
| 28 |
if(DNNL_FOUND)
|
|
@@ -62,8 +79,6 @@ if (GGML_SYCL_F16)
|
|
| 62 |
add_compile_definitions(GGML_SYCL_F16)
|
| 63 |
endif()
|
| 64 |
|
| 65 |
-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl")
|
| 66 |
-
|
| 67 |
if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
|
| 68 |
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
|
| 69 |
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
|
|
@@ -76,34 +91,84 @@ else()
|
|
| 76 |
add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
|
| 77 |
endif()
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
| 86 |
find_package(MKL REQUIRED)
|
| 87 |
-
target_link_libraries(ggml-sycl PRIVATE
|
|
|
|
| 88 |
else()
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
endif()
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
| 98 |
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
|
| 99 |
if (NOT GGML_SYCL_DEVICE_ARCH)
|
| 100 |
message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.")
|
| 101 |
endif()
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
endif()
|
|
|
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
endif()
|
|
|
|
| 23 |
../../include/ggml-sycl.h
|
| 24 |
)
|
| 25 |
|
| 26 |
+
file(GLOB GGML_HEADERS_SYCL "*.hpp")
|
| 27 |
+
file(GLOB GGML_SOURCES_SYCL "*.cpp")
|
| 28 |
+
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
|
| 29 |
+
|
| 30 |
+
find_package(IntelSYCL)
|
| 31 |
+
if (IntelSYCL_FOUND)
|
| 32 |
+
# Use oneAPI CMake when possible
|
| 33 |
+
target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX)
|
| 34 |
+
else()
|
| 35 |
+
# Fallback to the simplest way of enabling SYCL when using intel/llvm nightly for instance
|
| 36 |
+
target_compile_options(ggml-sycl PRIVATE "-fsycl")
|
| 37 |
+
target_link_options(ggml-sycl PRIVATE "-fsycl")
|
| 38 |
+
endif()
|
| 39 |
+
|
| 40 |
+
target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
|
| 41 |
+
|
| 42 |
+
# Link against oneDNN
|
| 43 |
find_package(DNNL)
|
| 44 |
set(GGML_SYCL_DNNL 0)
|
| 45 |
if(DNNL_FOUND)
|
|
|
|
| 79 |
add_compile_definitions(GGML_SYCL_F16)
|
| 80 |
endif()
|
| 81 |
|
|
|
|
|
|
|
| 82 |
if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
|
| 83 |
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
|
| 84 |
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
|
|
|
|
| 91 |
add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
|
| 92 |
endif()
|
| 93 |
|
| 94 |
+
if (GGML_SYCL_GRAPH)
|
| 95 |
+
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH)
|
| 96 |
+
endif()
|
|
|
|
| 97 |
|
| 98 |
+
# Link against Intel oneMKL or oneMath
|
| 99 |
+
if (GGML_SYCL_TARGET STREQUAL "INTEL")
|
| 100 |
+
# Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically
|
| 101 |
+
# See https://github.com/uxlfoundation/oneMath/issues/654
|
| 102 |
find_package(MKL REQUIRED)
|
| 103 |
+
target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS)
|
| 104 |
+
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL)
|
| 105 |
else()
|
| 106 |
+
find_package(oneMath QUIET)
|
| 107 |
+
if (NOT oneMath_FOUND)
|
| 108 |
+
message(STATUS "oneMath not found: oneMath will be automatically downloaded")
|
| 109 |
+
# Use FetchContent to automatically pull and build oneMath
|
| 110 |
+
include(FetchContent)
|
| 111 |
+
set(BUILD_FUNCTIONAL_TESTS False)
|
| 112 |
+
set(BUILD_EXAMPLES False)
|
| 113 |
+
set(TARGET_DOMAINS blas)
|
| 114 |
+
if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
|
| 115 |
+
set(ENABLE_MKLCPU_BACKEND False)
|
| 116 |
+
set(ENABLE_MKLGPU_BACKEND False)
|
| 117 |
+
set(ENABLE_CUBLAS_BACKEND True)
|
| 118 |
+
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
|
| 119 |
+
set(ENABLE_MKLCPU_BACKEND False)
|
| 120 |
+
set(ENABLE_MKLGPU_BACKEND False)
|
| 121 |
+
set(ENABLE_ROCBLAS_BACKEND True)
|
| 122 |
+
# Ensure setting a string variable here is not overridden by oneMath CACHE variables
|
| 123 |
+
cmake_policy(SET CMP0126 NEW)
|
| 124 |
+
# Setting the device architecture is only needed and useful for AMD devices in oneMath
|
| 125 |
+
set(HIP_TARGETS ${GGML_SYCL_DEVICE_ARCH} CACHE STRING "oneMath HIP target" FORCE)
|
| 126 |
+
endif()
|
| 127 |
+
FetchContent_Declare(
|
| 128 |
+
ONEMATH
|
| 129 |
+
GIT_REPOSITORY https://github.com/uxlfoundation/oneMath.git
|
| 130 |
+
GIT_TAG c255b1b4c41e2ee3059455c1f96a965d6a62568a
|
| 131 |
+
)
|
| 132 |
+
FetchContent_MakeAvailable(ONEMATH)
|
| 133 |
+
# Create alias to match with find_package targets name
|
| 134 |
+
function(onemath_alias target)
|
| 135 |
+
if (TARGET ${target}_obj)
|
| 136 |
+
# Silence verbose warnings from external libraries
|
| 137 |
+
target_compile_options(${target}_obj PRIVATE -w)
|
| 138 |
+
endif()
|
| 139 |
+
if (TARGET ${target})
|
| 140 |
+
add_library(ONEMATH::${target} ALIAS ${target})
|
| 141 |
+
endif()
|
| 142 |
+
endfunction()
|
| 143 |
+
onemath_alias(onemath)
|
| 144 |
+
onemath_alias(onemath_blas_mklcpu)
|
| 145 |
+
onemath_alias(onemath_blas_mklgpu)
|
| 146 |
+
onemath_alias(onemath_blas_cublas)
|
| 147 |
+
onemath_alias(onemath_blas_rocblas)
|
| 148 |
endif()
|
| 149 |
+
|
| 150 |
+
# Below oneMath compile-time dispatching is used for better performance
|
| 151 |
+
if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
|
| 152 |
+
target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_cublas)
|
| 153 |
+
target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda")
|
| 154 |
+
target_link_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda")
|
| 155 |
+
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_NVIDIA)
|
| 156 |
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
|
| 157 |
if (NOT GGML_SYCL_DEVICE_ARCH)
|
| 158 |
message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.")
|
| 159 |
endif()
|
| 160 |
+
target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_rocblas)
|
| 161 |
+
target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa")
|
| 162 |
+
target_link_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa")
|
| 163 |
+
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_AMD)
|
| 164 |
+
else()
|
| 165 |
+
# Fallback to oneMath runtime dispatcher
|
| 166 |
+
target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath)
|
| 167 |
+
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GENERIC)
|
| 168 |
endif()
|
| 169 |
+
endif()
|
| 170 |
|
| 171 |
+
if (GGML_SYCL_DEVICE_ARCH)
|
| 172 |
+
target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH})
|
| 173 |
+
target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH})
|
| 174 |
endif()
|
ggml/src/ggml-sycl/dpct/helper.hpp
CHANGED
|
@@ -16,9 +16,18 @@
|
|
| 16 |
#include <sycl/sycl.hpp>
|
| 17 |
#include <sycl/half_type.hpp>
|
| 18 |
#include <syclcompat/math.hpp>
|
| 19 |
-
#include <oneapi/mkl.hpp>
|
| 20 |
#include <map>
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
#include "ggml.h"
|
| 23 |
|
| 24 |
#if defined(__linux__)
|
|
@@ -83,13 +92,32 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
|
|
| 83 |
}
|
| 84 |
|
| 85 |
template <typename Ts> struct matrix_info_t {
|
| 86 |
-
oneapi::
|
| 87 |
Ts value_info[2];
|
| 88 |
std::int64_t size_info[3];
|
| 89 |
std::int64_t ld_info[3];
|
| 90 |
std::int64_t groupsize_info;
|
| 91 |
};
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
namespace dpct
|
| 94 |
{
|
| 95 |
typedef sycl::queue *queue_ptr;
|
|
@@ -1686,26 +1714,18 @@ namespace dpct
|
|
| 1686 |
|
| 1687 |
namespace detail
|
| 1688 |
{
|
| 1689 |
-
|
| 1690 |
-
|
| 1691 |
-
|
| 1692 |
-
|
| 1693 |
-
|
| 1694 |
-
|
| 1695 |
-
|
| 1696 |
-
|
| 1697 |
-
|
| 1698 |
-
|
| 1699 |
-
|
| 1700 |
-
|
| 1701 |
-
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
|
| 1702 |
-
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
| 1703 |
-
beta_value, data_c, ldc);
|
| 1704 |
-
#else
|
| 1705 |
-
oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
| 1706 |
-
beta_value, data_c, ldc);
|
| 1707 |
-
#endif
|
| 1708 |
-
}
|
| 1709 |
|
| 1710 |
template <typename VecT, class BinaryOperation, class = void>
|
| 1711 |
class vectorized_binary
|
|
@@ -1735,7 +1755,7 @@ namespace dpct
|
|
| 1735 |
};
|
| 1736 |
|
| 1737 |
template <class Ta, class Tb, class Tc, class Ts>
|
| 1738 |
-
inline void gemm_batch_impl(sycl::queue & q, oneapi::
|
| 1739 |
int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
|
| 1740 |
int ldb, const void * beta, void ** c, int ldc, int batch_size,
|
| 1741 |
matrix_info_t<float> * matrix_info) {
|
|
@@ -1754,48 +1774,28 @@ namespace dpct
|
|
| 1754 |
matrix_info->ld_info[2] = ldc;
|
| 1755 |
matrix_info->groupsize_info = batch_size;
|
| 1756 |
|
| 1757 |
-
|
| 1758 |
-
|
| 1759 |
-
|
| 1760 |
-
matrix_info->
|
| 1761 |
-
|
| 1762 |
-
reinterpret_cast<
|
| 1763 |
-
matrix_info->ld_info + 1,
|
| 1764 |
-
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
| 1765 |
-
#else
|
| 1766 |
-
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
| 1767 |
-
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
| 1768 |
-
matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
|
| 1769 |
-
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
| 1770 |
-
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
|
| 1771 |
-
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
| 1772 |
-
#endif
|
| 1773 |
}
|
| 1774 |
|
| 1775 |
template <class Ta, class Tb, class Tc, class Ts>
|
| 1776 |
-
inline void
|
| 1777 |
-
|
| 1778 |
-
|
| 1779 |
-
|
| 1780 |
-
long long int stride_a, const void *b, int ldb,
|
| 1781 |
-
long long int stride_b, const void *beta, void *c,
|
| 1782 |
-
int ldc, long long int stride_c, int batch_size)
|
| 1783 |
-
{
|
| 1784 |
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
| 1785 |
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
| 1786 |
auto data_a = get_memory<const Ta>(a);
|
| 1787 |
auto data_b = get_memory<const Tb>(b);
|
| 1788 |
auto data_c = get_memory<Tc>(c);
|
| 1789 |
-
|
| 1790 |
-
|
| 1791 |
-
|
| 1792 |
-
alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
|
| 1793 |
-
batch_size);
|
| 1794 |
-
#else
|
| 1795 |
-
oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
| 1796 |
-
stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
|
| 1797 |
-
stride_c, batch_size);
|
| 1798 |
-
#endif
|
| 1799 |
}
|
| 1800 |
|
| 1801 |
} // namespace detail
|
|
@@ -2259,13 +2259,10 @@ namespace dpct
|
|
| 2259 |
sycl::range<3>(x, y, 1), direction);
|
| 2260 |
}
|
| 2261 |
|
| 2262 |
-
inline void gemm(sycl::queue &q, oneapi::
|
| 2263 |
-
|
| 2264 |
-
const void *
|
| 2265 |
-
|
| 2266 |
-
const void *beta, void *c, library_data_t c_type, int ldc,
|
| 2267 |
-
library_data_t scaling_type)
|
| 2268 |
-
{
|
| 2269 |
if (scaling_type == library_data_t::real_float &&
|
| 2270 |
c_type == library_data_t::complex_float)
|
| 2271 |
{
|
|
@@ -2329,9 +2326,8 @@ namespace dpct
|
|
| 2329 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2330 |
library_data_t::real_float, library_data_t::real_float):
|
| 2331 |
{
|
| 2332 |
-
detail::gemm_impl<oneapi::
|
| 2333 |
-
|
| 2334 |
-
ldb, beta, c, ldc);
|
| 2335 |
break;
|
| 2336 |
}
|
| 2337 |
case detail::get_type_combination_id(
|
|
@@ -2369,8 +2365,7 @@ namespace dpct
|
|
| 2369 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2370 |
library_data_t::real_bfloat16, library_data_t::real_float):
|
| 2371 |
{
|
| 2372 |
-
detail::gemm_impl<oneapi::
|
| 2373 |
-
oneapi::mkl::bfloat16, float>(
|
| 2374 |
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
| 2375 |
break;
|
| 2376 |
}
|
|
@@ -2390,7 +2385,7 @@ namespace dpct
|
|
| 2390 |
default:
|
| 2391 |
throw std::runtime_error("the combination of data type is unsupported");
|
| 2392 |
}
|
| 2393 |
-
}
|
| 2394 |
|
| 2395 |
/// Computes a batch of matrix-matrix product with general matrices.
|
| 2396 |
/// \param [in] q The queue where the routine should be executed.
|
|
@@ -2412,7 +2407,7 @@ namespace dpct
|
|
| 2412 |
/// \param [in] ldc Leading dimension of C.
|
| 2413 |
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
| 2414 |
/// \param [in] scaling_type Data type of the scaling factors.
|
| 2415 |
-
inline void gemm_batch(sycl::queue & q, oneapi::
|
| 2416 |
int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
|
| 2417 |
const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
|
| 2418 |
library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
|
|
@@ -2450,7 +2445,7 @@ namespace dpct
|
|
| 2450 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2451 |
library_data_t::real_bfloat16, library_data_t::real_float):
|
| 2452 |
{
|
| 2453 |
-
detail::gemm_batch_impl<oneapi::
|
| 2454 |
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
| 2455 |
break;
|
| 2456 |
}
|
|
@@ -2458,7 +2453,7 @@ namespace dpct
|
|
| 2458 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2459 |
library_data_t::real_float, library_data_t::real_float):
|
| 2460 |
{
|
| 2461 |
-
detail::gemm_batch_impl<oneapi::
|
| 2462 |
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
| 2463 |
break;
|
| 2464 |
}
|
|
@@ -2534,15 +2529,11 @@ namespace dpct
|
|
| 2534 |
/// \param [in] stride_c Stride between the different C matrices.
|
| 2535 |
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
| 2536 |
/// \param [in] scaling_type Data type of the scaling factors.
|
| 2537 |
-
inline void gemm_batch(sycl::queue &q, oneapi::
|
| 2538 |
-
|
| 2539 |
-
|
| 2540 |
-
|
| 2541 |
-
|
| 2542 |
-
const void *beta, void *c, library_data_t c_type,
|
| 2543 |
-
int ldc, long long int stride_c, int batch_size,
|
| 2544 |
-
library_data_t scaling_type)
|
| 2545 |
-
{
|
| 2546 |
if (scaling_type == library_data_t::real_float &&
|
| 2547 |
c_type == library_data_t::complex_float)
|
| 2548 |
{
|
|
@@ -2611,20 +2602,18 @@ namespace dpct
|
|
| 2611 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2612 |
library_data_t::real_bfloat16, library_data_t::real_float):
|
| 2613 |
{
|
| 2614 |
-
detail::gemm_batch_impl<oneapi::
|
| 2615 |
-
|
| 2616 |
-
|
| 2617 |
-
beta, c, ldc, stride_c, batch_size);
|
| 2618 |
break;
|
| 2619 |
}
|
| 2620 |
case detail::get_type_combination_id(
|
| 2621 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2622 |
library_data_t::real_float, library_data_t::real_float):
|
| 2623 |
{
|
| 2624 |
-
detail::gemm_batch_impl<oneapi::
|
| 2625 |
-
|
| 2626 |
-
|
| 2627 |
-
stride_c, batch_size);
|
| 2628 |
break;
|
| 2629 |
}
|
| 2630 |
#endif
|
|
|
|
| 16 |
#include <sycl/sycl.hpp>
|
| 17 |
#include <sycl/half_type.hpp>
|
| 18 |
#include <syclcompat/math.hpp>
|
|
|
|
| 19 |
#include <map>
|
| 20 |
|
| 21 |
+
#ifdef GGML_SYCL_USE_INTEL_ONEMKL
|
| 22 |
+
#include <oneapi/mkl.hpp>
|
| 23 |
+
// Allow using the same namespace for Intel oneMKL and oneMath
|
| 24 |
+
namespace oneapi {
|
| 25 |
+
namespace math = mkl;
|
| 26 |
+
}
|
| 27 |
+
#else
|
| 28 |
+
#include <oneapi/math.hpp>
|
| 29 |
+
#endif
|
| 30 |
+
|
| 31 |
#include "ggml.h"
|
| 32 |
|
| 33 |
#if defined(__linux__)
|
|
|
|
| 92 |
}
|
| 93 |
|
| 94 |
template <typename Ts> struct matrix_info_t {
|
| 95 |
+
oneapi::math::transpose transpose_info[2];
|
| 96 |
Ts value_info[2];
|
| 97 |
std::int64_t size_info[3];
|
| 98 |
std::int64_t ld_info[3];
|
| 99 |
std::int64_t groupsize_info;
|
| 100 |
};
|
| 101 |
|
| 102 |
+
inline auto get_onemath_backend(sycl::queue& queue)
|
| 103 |
+
#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
|
| 104 |
+
-> sycl::queue&
|
| 105 |
+
#endif
|
| 106 |
+
{
|
| 107 |
+
// If the backend is known at compile-time, use oneMath backend_selector to use
|
| 108 |
+
// compile-time dispatching and avoid the need to dlopen libraries. Otherwise
|
| 109 |
+
// fallback to runtime dispatching.
|
| 110 |
+
#if defined(GGML_SYCL_NVIDIA)
|
| 111 |
+
return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
|
| 112 |
+
#elif defined(GGML_SYCL_AMD)
|
| 113 |
+
return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
|
| 114 |
+
#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
|
| 115 |
+
return queue;
|
| 116 |
+
#else
|
| 117 |
+
static_assert(false, "Unsupported backend");
|
| 118 |
+
#endif
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
namespace dpct
|
| 122 |
{
|
| 123 |
typedef sycl::queue *queue_ptr;
|
|
|
|
| 1714 |
|
| 1715 |
namespace detail
|
| 1716 |
{
|
| 1717 |
+
template <class Ta, class Tb, class Tc, class Ts>
|
| 1718 |
+
inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
|
| 1719 |
+
int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
|
| 1720 |
+
const void * beta, void * c, int ldc) {
|
| 1721 |
+
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
| 1722 |
+
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
| 1723 |
+
auto data_a = get_memory<const Ta>(a);
|
| 1724 |
+
auto data_b = get_memory<const Tb>(b);
|
| 1725 |
+
auto data_c = get_memory<Tc>(c);
|
| 1726 |
+
oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a,
|
| 1727 |
+
lda, data_b, ldb, beta_value, data_c, ldc);
|
| 1728 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1729 |
|
| 1730 |
template <typename VecT, class BinaryOperation, class = void>
|
| 1731 |
class vectorized_binary
|
|
|
|
| 1755 |
};
|
| 1756 |
|
| 1757 |
template <class Ta, class Tb, class Tc, class Ts>
|
| 1758 |
+
inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
|
| 1759 |
int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
|
| 1760 |
int ldb, const void * beta, void ** c, int ldc, int batch_size,
|
| 1761 |
matrix_info_t<float> * matrix_info) {
|
|
|
|
| 1774 |
matrix_info->ld_info[2] = ldc;
|
| 1775 |
matrix_info->groupsize_info = batch_size;
|
| 1776 |
|
| 1777 |
+
sycl::event e = oneapi::math::blas::column_major::gemm_batch(
|
| 1778 |
+
get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1,
|
| 1779 |
+
matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
|
| 1780 |
+
reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
|
| 1781 |
+
reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
| 1782 |
+
reinterpret_cast<Ts *>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),
|
| 1783 |
+
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1784 |
}
|
| 1785 |
|
| 1786 |
template <class Ta, class Tb, class Tc, class Ts>
|
| 1787 |
+
inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
|
| 1788 |
+
int m, int n, int k, const void * alpha, const void * a, int lda,
|
| 1789 |
+
long long int stride_a, const void * b, int ldb, long long int stride_b,
|
| 1790 |
+
const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1791 |
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
| 1792 |
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
| 1793 |
auto data_a = get_memory<const Ta>(a);
|
| 1794 |
auto data_b = get_memory<const Tb>(b);
|
| 1795 |
auto data_c = get_memory<Tc>(c);
|
| 1796 |
+
oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value,
|
| 1797 |
+
data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
|
| 1798 |
+
data_c, ldc, stride_c, batch_size);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1799 |
}
|
| 1800 |
|
| 1801 |
} // namespace detail
|
|
|
|
| 2259 |
sycl::range<3>(x, y, 1), direction);
|
| 2260 |
}
|
| 2261 |
|
| 2262 |
+
inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
|
| 2263 |
+
int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
|
| 2264 |
+
library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
|
| 2265 |
+
library_data_t scaling_type) {
|
|
|
|
|
|
|
|
|
|
| 2266 |
if (scaling_type == library_data_t::real_float &&
|
| 2267 |
c_type == library_data_t::complex_float)
|
| 2268 |
{
|
|
|
|
| 2326 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2327 |
library_data_t::real_float, library_data_t::real_float):
|
| 2328 |
{
|
| 2329 |
+
detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
|
| 2330 |
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
|
|
|
| 2331 |
break;
|
| 2332 |
}
|
| 2333 |
case detail::get_type_combination_id(
|
|
|
|
| 2365 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2366 |
library_data_t::real_bfloat16, library_data_t::real_float):
|
| 2367 |
{
|
| 2368 |
+
detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
|
|
|
|
| 2369 |
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
| 2370 |
break;
|
| 2371 |
}
|
|
|
|
| 2385 |
default:
|
| 2386 |
throw std::runtime_error("the combination of data type is unsupported");
|
| 2387 |
}
|
| 2388 |
+
} // gemm()
|
| 2389 |
|
| 2390 |
/// Computes a batch of matrix-matrix product with general matrices.
|
| 2391 |
/// \param [in] q The queue where the routine should be executed.
|
|
|
|
| 2407 |
/// \param [in] ldc Leading dimension of C.
|
| 2408 |
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
| 2409 |
/// \param [in] scaling_type Data type of the scaling factors.
|
| 2410 |
+
inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
|
| 2411 |
int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
|
| 2412 |
const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
|
| 2413 |
library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
|
|
|
|
| 2445 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2446 |
library_data_t::real_bfloat16, library_data_t::real_float):
|
| 2447 |
{
|
| 2448 |
+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
|
| 2449 |
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
| 2450 |
break;
|
| 2451 |
}
|
|
|
|
| 2453 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2454 |
library_data_t::real_float, library_data_t::real_float):
|
| 2455 |
{
|
| 2456 |
+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
|
| 2457 |
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
| 2458 |
break;
|
| 2459 |
}
|
|
|
|
| 2529 |
/// \param [in] stride_c Stride between the different C matrices.
|
| 2530 |
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
| 2531 |
/// \param [in] scaling_type Data type of the scaling factors.
|
| 2532 |
+
inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
|
| 2533 |
+
int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
|
| 2534 |
+
long long int stride_a, const void * b, library_data_t b_type, int ldb,
|
| 2535 |
+
long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
|
| 2536 |
+
long long int stride_c, int batch_size, library_data_t scaling_type) {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2537 |
if (scaling_type == library_data_t::real_float &&
|
| 2538 |
c_type == library_data_t::complex_float)
|
| 2539 |
{
|
|
|
|
| 2602 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2603 |
library_data_t::real_bfloat16, library_data_t::real_float):
|
| 2604 |
{
|
| 2605 |
+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
|
| 2606 |
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
|
| 2607 |
+
batch_size);
|
|
|
|
| 2608 |
break;
|
| 2609 |
}
|
| 2610 |
case detail::get_type_combination_id(
|
| 2611 |
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
| 2612 |
library_data_t::real_float, library_data_t::real_float):
|
| 2613 |
{
|
| 2614 |
+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
|
| 2615 |
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
|
| 2616 |
+
batch_size);
|
|
|
|
| 2617 |
break;
|
| 2618 |
}
|
| 2619 |
#endif
|
ggml/src/ggml-sycl/ggml-sycl.cpp
CHANGED
|
@@ -2059,8 +2059,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
|
| 2059 |
const sycl::half alpha_f16 = 1.0f;
|
| 2060 |
const sycl::half beta_f16 = 0.0f;
|
| 2061 |
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|
| 2062 |
-
*stream, oneapi::
|
| 2063 |
-
oneapi::
|
| 2064 |
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
|
| 2065 |
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
|
| 2066 |
dst_f16.get(), dpct::library_data_t::real_half, ldc,
|
|
@@ -2097,17 +2097,10 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
|
| 2097 |
#if !GGML_SYCL_DNNL
|
| 2098 |
const float alpha = 1.0f;
|
| 2099 |
const float beta = 0.0f;
|
| 2100 |
-
|
| 2101 |
-
|
| 2102 |
-
|
| 2103 |
-
|
| 2104 |
-
ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
|
| 2105 |
-
# else
|
| 2106 |
-
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
|
| 2107 |
-
*stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
| 2108 |
-
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
|
| 2109 |
-
dst_dd_i, ldc)));
|
| 2110 |
-
# endif
|
| 2111 |
#else
|
| 2112 |
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
|
| 2113 |
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
|
|
@@ -2836,14 +2829,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
|
| 2836 |
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
| 2837 |
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
| 2838 |
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
| 2839 |
-
*main_stream, oneapi::
|
| 2840 |
-
|
| 2841 |
-
(const char *)
|
| 2842 |
-
|
| 2843 |
-
(const char *)src1_f16, dpct::library_data_t::real_half,
|
| 2844 |
-
nb11 / nb10, nb12 / nb10, beta,
|
| 2845 |
-
(char *)dst_t, cu_data_type, ne01, nb2 / nb0,
|
| 2846 |
-
ne12 * ne13, cu_compute_type)));
|
| 2847 |
} else {
|
| 2848 |
const int ne23 = ne12*ne13;
|
| 2849 |
|
|
@@ -2878,7 +2867,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
|
| 2878 |
});
|
| 2879 |
}
|
| 2880 |
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
| 2881 |
-
*main_stream, oneapi::
|
| 2882 |
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
|
| 2883 |
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
|
| 2884 |
(void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
|
|
|
|
| 2059 |
const sycl::half alpha_f16 = 1.0f;
|
| 2060 |
const sycl::half beta_f16 = 0.0f;
|
| 2061 |
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|
| 2062 |
+
*stream, oneapi::math::transpose::trans,
|
| 2063 |
+
oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
|
| 2064 |
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
|
| 2065 |
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
|
| 2066 |
dst_f16.get(), dpct::library_data_t::real_half, ldc,
|
|
|
|
| 2097 |
#if !GGML_SYCL_DNNL
|
| 2098 |
const float alpha = 1.0f;
|
| 2099 |
const float beta = 0.0f;
|
| 2100 |
+
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
|
| 2101 |
+
get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
|
| 2102 |
+
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
|
| 2103 |
+
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2104 |
#else
|
| 2105 |
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
|
| 2106 |
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
|
|
|
|
| 2829 |
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
| 2830 |
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
| 2831 |
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
| 2832 |
+
*main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
| 2833 |
+
(const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
|
| 2834 |
+
(const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
|
| 2835 |
+
cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2836 |
} else {
|
| 2837 |
const int ne23 = ne12*ne13;
|
| 2838 |
|
|
|
|
| 2867 |
});
|
| 2868 |
}
|
| 2869 |
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
| 2870 |
+
*main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
| 2871 |
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
|
| 2872 |
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
|
| 2873 |
(void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
|
ggml/src/ggml-sycl/outprod.cpp
CHANGED
|
@@ -1,8 +1,5 @@
|
|
| 1 |
-
#include <sycl/sycl.hpp>
|
| 2 |
-
#include <oneapi/mkl.hpp>
|
| 3 |
#include "outprod.hpp"
|
| 4 |
|
| 5 |
-
|
| 6 |
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
| 7 |
const ggml_tensor *src0 = dst->src[0];
|
| 8 |
const ggml_tensor *src1 = dst->src[1];
|
|
@@ -34,20 +31,13 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
| 34 |
|
| 35 |
// Handle transposition of src1
|
| 36 |
const bool src1_T = ggml_is_transposed(src1);
|
| 37 |
-
const oneapi::
|
| 38 |
-
src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
|
| 39 |
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
|
| 40 |
|
| 41 |
try {
|
| 42 |
-
// Perform matrix multiplication using
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
|
| 46 |
-
ne00, src1_d, ldb, beta, dst_d, ne0);
|
| 47 |
-
#else
|
| 48 |
-
oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
|
| 49 |
-
src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
|
| 50 |
-
#endif
|
| 51 |
}
|
| 52 |
catch (sycl::exception const& exc) {
|
| 53 |
std::cerr << exc.what() << std::endl;
|
|
|
|
|
|
|
|
|
|
| 1 |
#include "outprod.hpp"
|
| 2 |
|
|
|
|
| 3 |
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
| 4 |
const ggml_tensor *src0 = dst->src[0];
|
| 5 |
const ggml_tensor *src1 = dst->src[1];
|
|
|
|
| 31 |
|
| 32 |
// Handle transposition of src1
|
| 33 |
const bool src1_T = ggml_is_transposed(src1);
|
| 34 |
+
const oneapi::math::transpose src1_op = src1_T ? oneapi::math::transpose::nontrans : oneapi::math::transpose::trans;
|
|
|
|
| 35 |
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
|
| 36 |
|
| 37 |
try {
|
| 38 |
+
// Perform matrix multiplication using oneMath GEMM
|
| 39 |
+
oneapi::math::blas::column_major::gemm(get_onemath_backend(*stream), oneapi::math::transpose::nontrans, src1_op,
|
| 40 |
+
ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
}
|
| 42 |
catch (sycl::exception const& exc) {
|
| 43 |
std::cerr << exc.what() << std::endl;
|