JohannesGaessler committed on
Commit 52c45b9 · unverified · 1 Parent(s): b1e29bc

CUDA: faster softmax via shared memory + fp16 math (llama/4742)

Files changed (1):
  1. ggml-cuda.cu +303 -24
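
For orientation before the diff: both new softmax kernels compute a per-row maximum and a per-row sum of exponentials across the whole thread block by combining warp shuffles with a small shared-memory buffer for inter-warp partials. The standalone CUDA sketch below is not part of the patch; the kernel name block_max_demo and the buffer name are made up for illustration, and it assumes blockDim.x is a multiple of WARP_SIZE. It shows the same two-level reduction pattern in isolation.

// Illustrative sketch only, not code from this commit.
#include <cfloat>
#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE 32

// Intra-warp max via butterfly shuffles, same shape as warp_reduce_max in the patch.
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
}

// Hypothetical demo kernel: each warp reduces locally, lane 0 of each warp writes
// its partial into a small shared buffer, and a final warp-level pass reduces those.
static __global__ void block_max_demo(const float * x, float * out, const int n) {
    const int tid     = threadIdx.x;
    const int warp_id = tid / WARP_SIZE;
    const int lane_id = tid % WARP_SIZE;

    __shared__ float buf_iw[WARP_SIZE]; // inter-warp partials, one slot per warp

    float max_val = tid < n ? x[tid] : -FLT_MAX;
    max_val = warp_reduce_max(max_val);

    if (blockDim.x > WARP_SIZE) {
        if (warp_id == 0) {
            buf_iw[lane_id] = -FLT_MAX; // clear slots of warps that don't exist
        }
        __syncthreads();
        if (lane_id == 0) {
            buf_iw[warp_id] = max_val;
        }
        __syncthreads();
        max_val = warp_reduce_max(buf_iw[lane_id]);
    }

    if (tid == 0) {
        *out = max_val;
    }
}

int main() {
    const int n = 256;
    float h_x[n];
    for (int i = 0; i < n; ++i) {
        h_x[i] = (float) (i % 97);
    }

    float * d_x   = nullptr;
    float * d_out = nullptr;
    cudaMalloc(&d_x, n*sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_x, h_x, n*sizeof(float), cudaMemcpyHostToDevice);

    block_max_demo<<<1, n>>>(d_x, d_out, n);

    float h_out = 0.0f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("block max = %g (expected 96)\n", h_out);

    cudaFree(d_x);
    cudaFree(d_out);
    return 0;
}
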
ggml-cuda.cu CHANGED
@@ -116,6 +116,7 @@
 #include "ggml.h"
 #include "ggml-backend-impl.h"
 
+#define CC_PASCAL 600
 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA 700
 #define CC_OFFSET_AMD 1000000
@@ -556,11 +557,12 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 
 struct cuda_device_capabilities {
     int cc; // compute capability
+    size_t smpb; // max. shared memory per block
     bool vmm; // virtual memory support
     size_t vmm_granularity; // granularity of virtual memory
 };
 
-static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
+static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} };
 
 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 0; // disabled by default
@@ -593,6 +595,19 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }
 
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    (void) a;
+    bad_arch();
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
@@ -601,6 +616,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    (void) x;
+    bad_arch();
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
     GGML_UNUSED(a);
@@ -5385,75 +5413,233 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
+static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+    const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
+    const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
+
+    const int tid = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    extern __shared__ half data_soft_max_f16[];
+    half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
+    // (shared memory) buffer to cache values between iterations:
+    half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);
+    // if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
+    // in that case col_smem == col_data must be enforced to avoid race conditions
+
+    half2 max_val = make_half2(-INFINITY, -INFINITY);
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
+        const int col_smem = vals_smem ? col0 + tid : col_data;
+
+        const int ix = rowx*ncols_data + col_data;
+        const int iy = rowy*ncols_data + col_data;
+
+        half2 val;
+        if (need_check && col_data + 0 >= ncols_data) {
+            val.x = -INFINITY;
+        } else {
+            val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
+        }
+        if (need_check && col_data + WARP_SIZE >= ncols_data) {
+            val.y = -INFINITY;
+        } else {
+            val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
+        }
+        if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
+            vals[col_smem] = val;
+        }
+        max_val = __hmax2(max_val, val);
+    }
+
+    // find the max value in the block
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
+        }
+        __syncthreads();
+
+        max_val = __half2half2(buf_iw[lane_id]);
+        max_val = warp_reduce_max(max_val);
+    } else {
+        max_val = __half2half2(__hmax(max_val.x, max_val.y));
+    }
+
+    half2 tmp = make_half2(0.0f, 0.0f); // partial sums
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id;
+
+        if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
+            break;
+        }
+
+        const half2 val = h2exp(vals[col_smem] - max_val);
+
+        tmp += val;
+        vals[col_smem] = val;
+    }
+
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = 0.0f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = tmp.x + tmp.y;
+        }
+        __syncthreads();
+
+        tmp = __half2half2(buf_iw[lane_id]);
+        tmp = warp_reduce_sum(tmp);
+    } else {
+        tmp = __half2half2(tmp.x + tmp.y);
+    }
+
+    const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
+        const int col_smem = vals_smem ? col0 + tid : col_data;
+
+        const int idst = rowx*ncols_data + col_data;
+        const half2 result = vals[col_smem] * inv_sum;
+
+        if (need_check && col_data + 0 >= ncols_data) {
+            return;
+        }
+        dst[idst] = result.x;
+
+        if (need_check && col_data + WARP_SIZE >= ncols_data) {
+            return;
+        }
+
+        dst[idst + WARP_SIZE] = result.y;
+    }
+#else
+    (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
+    bad_arch();
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
     const int tid = threadIdx.x;
     const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
 
-    const int block_size = blockDim.x;
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
+    extern __shared__ float data_soft_max_f32[];
+    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
+    // shared memory buffer to cache values between iterations:
+    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
 
     float max_val = -INFINITY;
 
-    for (int col = tid; col < ncols; col += block_size) {
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
+
+        const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
+        vals[col] = val;
+        max_val = max(max_val, val);
     }
 
     // find the max value in the block
     max_val = warp_reduce_max(max_val);
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
-            buf[lane_id] = -INFINITY;
+            buf_iw[lane_id] = -INFINITY;
         }
         __syncthreads();
 
         if (lane_id == 0) {
-            buf[warp_id] = max_val;
+            buf_iw[warp_id] = max_val;
         }
         __syncthreads();
 
-        max_val = buf[lane_id];
+        max_val = buf_iw[lane_id];
         max_val = warp_reduce_max(max_val);
     }
 
-    float tmp = 0.f;
+    float tmp = 0.0f; // partial sum
 
-    for (int col = tid; col < ncols; col += block_size) {
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
-        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const float val = expf(vals[col] - max_val);
         tmp += val;
-        dst[ix] = val;
+        vals[col] = val;
     }
 
     // find the sum of exps in the block
     tmp = warp_reduce_sum(tmp);
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
-            buf[lane_id] = 0.f;
+            buf_iw[lane_id] = 0.0f;
         }
         __syncthreads();
 
         if (lane_id == 0) {
-            buf[warp_id] = tmp;
+            buf_iw[warp_id] = tmp;
         }
         __syncthreads();
 
-        tmp = buf[lane_id];
+        tmp = buf_iw[lane_id];
         tmp = warp_reduce_sum(tmp);
     }
 
-    const float inv_tmp = 1.f / tmp;
+    const float inv_sum = 1.0f / tmp;
 
-    for (int col = tid; col < ncols; col += block_size) {
-        const int i = rowx*ncols + col;
-        dst[i] *= inv_tmp;
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
+
+        const int idst = rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
     }
 }
 
@@ -6752,12 +6938,90 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
+static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth, 1, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
+    const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+    if (shmem <= g_device_caps[g_main_device].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 64:
+                soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 128:
+                soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 256:
+                soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 512:
+                soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 1024:
+                soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 2048:
+                soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 4096:
+                soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            default:
+                soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(half);
+        soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    }
+}
+
 static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+    if (shmem < g_device_caps[g_main_device].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 64:
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 128:
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 256:
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 512:
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 1024:
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 2048:
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 4096:
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            default:
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(float);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    }
 }
 
 static void im2col_f32_f16_cuda(const float* x, half* dst,
@@ -7072,6 +7336,7 @@ void ggml_init_cublas() {
 #else
         g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+        g_device_caps[id].smpb = prop.sharedMemPerBlock;
     }
     for (int id = 0; id < g_device_count; ++id) {
         g_tensor_split[id] /= total_vram;
@@ -8087,7 +8352,21 @@ static void ggml_cuda_op_soft_max(
     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));
 
-    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    const bool use_f16_soft_max = false;
+#else
+#ifdef GGML_CUDA_F16
+    const bool use_f16_soft_max = true;
+#else
+    const bool use_f16_soft_max = false;
+#endif // GGML_CUDA_F16
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+    if (use_f16_soft_max) {
+        soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    } else {
+        soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    }
 
     (void) dst;
 }
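
The launcher functions in the patch size the shared-memory cache as one element per padded column plus WARP_SIZE slots for inter-warp partials, and fall back to the vals_smem = false kernel variant, which reuses dst as scratch, whenever that exceeds the per-block limit recorded in g_device_caps[...].smpb. The host-side sketch below is not from the patch: softmax_row_fits_in_smem and the PAD_TO macro are hypothetical stand-ins (PAD_TO mirrors what GGML_PAD does in ggml); it only illustrates the check against cudaDeviceProp::sharedMemPerBlock.

// Illustrative sketch only, not code from this commit.
#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE 32
// local stand-in for ggml's GGML_PAD: round x up to a multiple of n
#define PAD_TO(x, n) (((x) + (n) - 1) / (n) * (n))

// Hypothetical helper: would a cached softmax row of ncols floats
// (plus the inter-warp buffer) fit into this device's shared memory per block?
static bool softmax_row_fits_in_smem(const int ncols, const int device) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);

    const size_t shmem = (PAD_TO(ncols, WARP_SIZE) + WARP_SIZE)*sizeof(float);
    return shmem < prop.sharedMemPerBlock;
}

int main() {
    const int tests[] = {512, 4096, 32768};
    for (int i = 0; i < 3; ++i) {
        printf("ncols = %5d -> %s\n", tests[i],
               softmax_row_fits_in_smem(tests[i], 0) ? "cache row in shared memory"
                                                     : "fall back to dst as scratch");
    }
    return 0;
}

On a typical 48 KiB-per-block GPU the first two sizes fit in shared memory and the last one takes the fallback path, which is exactly the distinction the soft_max_*_cuda launchers make above.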
 
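The "fp16 math" half of the commit title refers to the f16 kernel packing two columns into one half2 so that intrinsics such as __hadd2, __hmax2, and h2exp process two elements per instruction. The sketch below is not from the patch: half2_exp_demo is a made-up name, it pairs adjacent elements for simplicity instead of the col_data / col_data + WARP_SIZE pairing used by soft_max_f16, and it assumes a device of compute capability 6.0 or newer.

// Illustrative sketch only, not code from this commit.
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Hypothetical demo kernel: each thread packs two adjacent floats into a half2,
// exponentiates both halves with a single h2exp call, then unpacks to float.
static __global__ void half2_exp_demo(const float * x, float * dst, const int n) {
    const int i = 2*(blockIdx.x*blockDim.x + threadIdx.x);
    if (i + 1 >= n) {
        return;
    }
    const half2 val = __floats2half2_rn(x[i + 0], x[i + 1]); // pack two values
    const half2 e   = h2exp(val);                            // exp on both halves at once
    dst[i + 0] = __low2float(e);
    dst[i + 1] = __high2float(e);
}

int main() {
    const int n = 8;
    float h_x[n] = {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.5f, -0.5f, 3.0f};

    float * d_x   = nullptr;
    float * d_dst = nullptr;
    cudaMalloc(&d_x, n*sizeof(float));
    cudaMalloc(&d_dst, n*sizeof(float));
    cudaMemcpy(d_x, h_x, n*sizeof(float), cudaMemcpyHostToDevice);

    half2_exp_demo<<<1, n/2>>>(d_x, d_dst, n);

    float h_dst[n];
    cudaMemcpy(h_dst, d_dst, n*sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) {
        printf("exp(%5.2f) ~ %8.4f\n", h_x[i], h_dst[i]);
    }

    cudaFree(d_x);
    cudaFree(d_dst);
    return 0;
}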