Ouadie EL FAROUKI commited on
Commit
08501f8
·
1 Parent(s): 8411e3c

Enabled more data types for oneMKL gemm_batch (llama/8236)

Browse files
Files changed (1) hide show
  1. ggml/src/ggml-sycl.cpp +0 -30
ggml/src/ggml-sycl.cpp CHANGED
@@ -3493,10 +3493,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
3493
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
3494
  queue_ptr main_stream = ctx.stream();;
3495
 
3496
- bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_cuda ||
3497
- main_stream->get_backend() == sycl::backend::ext_oneapi_hip;
3498
-
3499
-
3500
  void * src0_ddq = src0->data;
3501
  sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
3502
  float * src1_ddf = (float *) src1->data;
@@ -3514,15 +3510,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
3514
  sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
3515
  : src1_f16_alloc.get();
3516
 
3517
- ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
3518
  char * dst_t;
3519
 
3520
  dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
3521
  dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
3522
- if (no_mixed_dtypes) {
3523
- cu_compute_type = dpct::library_data_t::real_half;
3524
- cu_data_type = dpct::library_data_t::real_half;
3525
- }
3526
 
3527
  // dst strides
3528
  size_t nbd2 = dst->nb[2];
@@ -3531,26 +3522,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
3531
  const float alpha_f32 = 1.0f;
3532
  const float beta_f32 = 0.0f;
3533
 
3534
- const sycl::half alpha_f16 = 1.0f;
3535
- const sycl::half beta_f16 = 0.0f;
3536
-
3537
  const void * alpha = &alpha_f32;
3538
  const void * beta = &beta_f32;
3539
- if (no_mixed_dtypes) {
3540
- alpha = &alpha_f16;
3541
- beta = &beta_f16;
3542
- }
3543
-
3544
- // TODO: Renable (dst->op_params[0] =! GGML_PREC_DEFAULT) pathway
3545
- // when oneMKL open source supports half, half, float, float: datatypes
3546
 
3547
  dst_t = (char *) dst_ddf;
3548
- if (no_mixed_dtypes) {
3549
- dst_t = (char *) dst_f16.alloc(ne_dst);
3550
-
3551
- nbd2 /= sizeof(float) / sizeof(sycl::half);
3552
- nbd3 /= sizeof(float) / sizeof(sycl::half);
3553
- }
3554
 
3555
  GGML_ASSERT(ne12 % ne02 == 0);
3556
  GGML_ASSERT(ne13 % ne03 == 0);
@@ -3612,11 +3587,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
3612
  (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
3613
  cu_compute_type)));
3614
  }
3615
-
3616
- if (no_mixed_dtypes) {
3617
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
3618
- to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream);
3619
- }
3620
  }
3621
  catch (sycl::exception const &exc) {
3622
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
 
3493
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
3494
  queue_ptr main_stream = ctx.stream();;
3495
 
 
 
 
 
3496
  void * src0_ddq = src0->data;
3497
  sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
3498
  float * src1_ddf = (float *) src1->data;
 
3510
  sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
3511
  : src1_f16_alloc.get();
3512
 
 
3513
  char * dst_t;
3514
 
3515
  dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
3516
  dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
 
 
 
 
3517
 
3518
  // dst strides
3519
  size_t nbd2 = dst->nb[2];
 
3522
  const float alpha_f32 = 1.0f;
3523
  const float beta_f32 = 0.0f;
3524
 
 
 
 
3525
  const void * alpha = &alpha_f32;
3526
  const void * beta = &beta_f32;
 
 
 
 
 
 
 
3527
 
3528
  dst_t = (char *) dst_ddf;
 
 
 
 
 
 
3529
 
3530
  GGML_ASSERT(ne12 % ne02 == 0);
3531
  GGML_ASSERT(ne13 % ne03 == 0);
 
3587
  (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
3588
  cu_compute_type)));
3589
  }
 
 
 
 
 
3590
  }
3591
  catch (sycl::exception const &exc) {
3592
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__