Neo Zhang Jianyu commited on
Commit
6901743
·
1 Parent(s): 6ccb5a5

fix memcpy() crash, add missed cmd in guide, fix softmax (llama/6622)

Browse files

* disable mmap to fix memcpy crash, add missed cmd in guide, fix softmax

* refactor to disable mmap for SYCL backend

* fix compile error in other os

* refactor the solution, use host buf to fix it, instead of disable mmap

* keep to support mmap()

* use host buff to reduce malloc times

* revert to malloc/free solution, for threaad safe

Files changed (1) hide show
  1. ggml-sycl.cpp +14 -5
ggml-sycl.cpp CHANGED
@@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)(
3154
  #define SYCL_SCALE_BLOCK_SIZE 256
3155
  #define SYCL_CLAMP_BLOCK_SIZE 256
3156
  #define SYCL_ROPE_BLOCK_SIZE 256
3157
- #define SYCL_SOFT_MAX_BLOCK_SIZE 1024
3158
  #define SYCL_ALIBI_BLOCK_SIZE 32
3159
  #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
3160
  #define SYCL_QUANTIZE_BLOCK_SIZE 256
@@ -13080,11 +13079,13 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
13080
  const int nrows_y, const float scale, const float max_bias,
13081
  dpct::queue_ptr stream) {
13082
  int nth = WARP_SIZE;
13083
- while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
 
 
 
13084
  const sycl::range<3> block_dims(1, 1, nth);
13085
  const sycl::range<3> block_nums(1, 1, nrows_x);
13086
  const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
13087
- static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
13088
 
13089
  const uint32_t n_head_kv = nrows_x/nrows_y;
13090
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
@@ -13094,6 +13095,12 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
13094
 
13095
  const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
13096
  if (n_local_scratch*sizeof(float) < local_mem_size) {
 
 
 
 
 
 
13097
  switch (ncols_x) {
13098
  case 32:
13099
  soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
@@ -16814,11 +16821,13 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
16814
  const dpct::queue_ptr stream = g_syclStreams[ctx->device][0];
16815
  SYCL_CHECK(
16816
  CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
16817
-
 
16818
  SYCL_CHECK(
16819
  CHECK_TRY_ERROR((*stream)
16820
- .memcpy((char *)tensor->data + offset, data, size)
16821
  .wait()));
 
16822
  }
16823
  catch (sycl::exception const &exc) {
16824
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
 
3154
  #define SYCL_SCALE_BLOCK_SIZE 256
3155
  #define SYCL_CLAMP_BLOCK_SIZE 256
3156
  #define SYCL_ROPE_BLOCK_SIZE 256
 
3157
  #define SYCL_ALIBI_BLOCK_SIZE 32
3158
  #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
3159
  #define SYCL_QUANTIZE_BLOCK_SIZE 256
 
13079
  const int nrows_y, const float scale, const float max_bias,
13080
  dpct::queue_ptr stream) {
13081
  int nth = WARP_SIZE;
13082
+ int max_block_size = g_work_group_size;
13083
+ while (nth < ncols_x && nth < max_block_size) nth *= 2;
13084
+ if (nth>max_block_size) nth = max_block_size;
13085
+
13086
  const sycl::range<3> block_dims(1, 1, nth);
13087
  const sycl::range<3> block_nums(1, 1, nrows_x);
13088
  const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
 
13089
 
13090
  const uint32_t n_head_kv = nrows_x/nrows_y;
13091
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
 
13095
 
13096
  const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
13097
  if (n_local_scratch*sizeof(float) < local_mem_size) {
13098
+ if (ncols_x > max_block_size) {
13099
+ soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
13100
+ max_bias, m0, m1, n_head_log2, block_nums,
13101
+ block_dims, n_local_scratch, stream);
13102
+ return;
13103
+ }
13104
  switch (ncols_x) {
13105
  case 32:
13106
  soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
 
16821
  const dpct::queue_ptr stream = g_syclStreams[ctx->device][0];
16822
  SYCL_CHECK(
16823
  CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
16824
+ char* host_buf = (char*)malloc(size);
16825
+ memcpy(host_buf, data, size);
16826
  SYCL_CHECK(
16827
  CHECK_TRY_ERROR((*stream)
16828
+ .memcpy((char *)tensor->data + offset, host_buf, size)
16829
  .wait()));
16830
+ free(host_buf);
16831
  }
16832
  catch (sycl::exception const &exc) {
16833
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__