CUDA: faster softmax via shared memory + fp16 math (llama/4742)
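Both the existing f32 kernel and the new f16 kernel in this patch compute the same scaled, optionally masked, numerically stable softmax over each row of the input; only the arithmetic precision and the buffering strategy differ. For reference, with s the `scale` op parameter and m the optional mask `y` (broadcast across rows), the quantity computed per row is:

\[ z_i = s\,x_i + m_i, \qquad \operatorname{softmax}(z)_i = \frac{e^{\,z_i - \max_j z_j}}{\sum_k e^{\,z_k - \max_j z_j}} \]

Subtracting the row maximum before exponentiating is what the `max_val` reductions below are for: it keeps `expf`/`h2exp` from overflowing without changing the result.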
ggml-cuda.cu CHANGED (+303 -24)
@@ -116,6 +116,7 @@
 #include "ggml.h"
 #include "ggml-backend-impl.h"
 
+#define CC_PASCAL     600
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA      700
 #define CC_OFFSET_AMD 1000000
@@ -556,11 +557,12 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 
 struct cuda_device_capabilities {
     int     cc;                 // compute capability
+    size_t  smpb;               // max. shared memory per block
     bool    vmm;                // virtual memory support
     size_t  vmm_granularity;    // granularity of virtual memory
 };
 
-static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
+static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} };
 
 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 0; // disabled by default
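The new `smpb` field is populated later in this commit from the CUDA device properties (see the `ggml_init_cublas` hunk further down) and is what the softmax launchers compare their shared-memory request against. As a minimal, self-contained sketch of where that number comes from (illustrative only, not part of the patch; the variable names here are hypothetical):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int n_devices = 0;
    cudaGetDeviceCount(&n_devices);
    for (int id = 0; id < n_devices; ++id) {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, id);
        // default per-block shared memory limit in bytes; this is the value
        // the patch caches in g_device_caps[id].smpb
        printf("device %d: %zu bytes shared memory per block\n", id, prop.sharedMemPerBlock);
    }
    return 0;
}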
@@ -593,6 +595,19 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }
 
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    (void) a;
+    bad_arch();
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
@@ -601,6 +616,19 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+    (void) x;
+    bad_arch();
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
     GGML_UNUSED(a);
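The two half2 helpers above mirror the existing float reductions: five XOR-shuffle steps (mask = 16, 8, 4, 2, 1) fold the 32 lanes of a warp into a single value without touching shared memory, and with half2 each lane carries two packed values per step. A toy kernel, illustrative only and not part of the patch, showing the sum variant in isolation (it needs a GPU with half2 math, which is why the patch gates these helpers on CC_PASCAL):

#include <cuda_fp16.h>

// Sum 32 half2 values, one per lane, with the same butterfly pattern as above.
__global__ void warp_sum_demo(const half2 * in, half2 * out) {
    half2 v = in[threadIdx.x];
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        // each step folds in the partial sum of the lane whose id differs in one bit
        v = __hadd2(v, __shfl_xor_sync(0xffffffff, v, mask, 32));
    }
    if (threadIdx.x == 0) {
        *out = v; // after the loop every lane holds the full sum; lane 0 writes it
    }
}
// launch as: warp_sum_demo<<<1, 32>>>(d_in, d_out);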
@@ -5385,75 +5413,233 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
+template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
+static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+    const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
+    const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
+
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    extern __shared__ half data_soft_max_f16[];
+    half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
+    // (shared memory) buffer to cache values between iterations:
+    half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);
+    // if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
+    // in that case col_smem == col_data must be enforced to avoid race conditions
+
+    half2 max_val = make_half2(-INFINITY, -INFINITY);
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
+        const int col_smem = vals_smem ? col0 + tid : col_data;
+
+        const int ix = rowx*ncols_data + col_data;
+        const int iy = rowy*ncols_data + col_data;
+
+        half2 val;
+        if (need_check && col_data + 0 >= ncols_data) {
+            val.x = -INFINITY;
+        } else {
+            val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
+        }
+        if (need_check && col_data + WARP_SIZE >= ncols_data) {
+            val.y = -INFINITY;
+        } else {
+            val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
+        }
+        if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
+            vals[col_smem] = val;
+        }
+        max_val = __hmax2(max_val, val);
+    }
+
+    // find the max value in the block
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
+        }
+        __syncthreads();
+
+        max_val = __half2half2(buf_iw[lane_id]);
+        max_val = warp_reduce_max(max_val);
+    } else {
+        max_val = __half2half2(__hmax(max_val.x, max_val.y));
+    }
+
+    half2 tmp = make_half2(0.0f, 0.0f); // partial sums
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id;
+
+        if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
+            break;
+        }
+
+        const half2 val = h2exp(vals[col_smem] - max_val);
+
+        tmp += val;
+        vals[col_smem] = val;
+    }
+
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = 0.0f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = tmp.x + tmp.y;
+        }
+        __syncthreads();
+
+        tmp = __half2half2(buf_iw[lane_id]);
+        tmp = warp_reduce_sum(tmp);
+    } else {
+        tmp = __half2half2(tmp.x + tmp.y);
+    }
+
+    const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
+        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
+        const int col_smem = vals_smem ? col0 + tid : col_data;
+
+        const int idst = rowx*ncols_data + col_data;
+        const half2 result = vals[col_smem] * inv_sum;
+
+        if (need_check && col_data + 0 >= ncols_data) {
+            return;
+        }
+        dst[idst] = result.x;
+
+        if (need_check && col_data + WARP_SIZE >= ncols_data) {
+            return;
+        }
+
+        dst[idst + WARP_SIZE] = result.y;
+    }
+#else
+    (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
+    bad_arch();
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
 
-    const int block_size = blockDim.x;
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
+    extern __shared__ float data_soft_max_f32[];
+    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
+    // shared memory buffer to cache values between iterations:
+    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
 
     float max_val = -INFINITY;
 
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
+
+        const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
+        vals[col] = val;
+        max_val = max(max_val, val);
     }
 
     // find the max value in the block
     max_val = warp_reduce_max(max_val);
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
+            buf_iw[lane_id] = -INFINITY;
         }
         __syncthreads();
 
         if (lane_id == 0) {
+            buf_iw[warp_id] = max_val;
         }
         __syncthreads();
 
+        max_val = buf_iw[lane_id];
         max_val = warp_reduce_max(max_val);
     }
 
+    float tmp = 0.0f; // partial sum
 
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const float val = expf(vals[col] - max_val);
         tmp += val;
+        vals[col] = val;
     }
 
     // find the sum of exps in the block
     tmp = warp_reduce_sum(tmp);
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
+            buf_iw[lane_id] = 0.0f;
         }
         __syncthreads();
 
         if (lane_id == 0) {
+            buf_iw[warp_id] = tmp;
         }
         __syncthreads();
 
+        tmp = buf_iw[lane_id];
         tmp = warp_reduce_sum(tmp);
     }
 
+    const float inv_sum = 1.0f / tmp;
 
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
+
+        const int idst = rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
     }
 }
 
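The kernels above make three passes over each row (find the max, exponentiate and sum, normalize) and keep the row resident in a dynamic shared-memory buffer between passes instead of re-reading x from global memory; only when the row does not fit in the per-block limit does the vals_smem == false specialization fall back to using dst as the temporary buffer. A minimal sketch of the dynamic shared-memory mechanism itself, with hypothetical names and assuming one row of ncols floats per block:

// Illustrative only: dynamic shared memory as a per-row cache.
__global__ void row_cache_demo(const float * x, float * dst, const int ncols) {
    extern __shared__ float cache[]; // size is set by the launch, not here
    const int row = blockIdx.x;

    // first pass: stage the row once
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        cache[col] = x[row*ncols + col];
    }
    __syncthreads();

    // later passes reuse cache[] without further global loads
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        dst[row*ncols + col] = 2.0f*cache[col];
    }
}

// host side: the third launch parameter is the dynamic shared-memory size in bytes
// row_cache_demo<<<nrows, 256, ncols*sizeof(float), stream>>>(x, dst, ncols);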
@@ -6752,12 +6938,90 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
+static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth, 1, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
+    const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+    if (shmem <= g_device_caps[g_main_device].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 64:
+                soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 128:
+                soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 256:
+                soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 512:
+                soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 1024:
+                soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 2048:
+                soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 4096:
+                soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            default:
+                soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(half);
+        soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    }
+}
+
 static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
+    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+    if (shmem < g_device_caps[g_main_device].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 64:
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 128:
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 256:
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 512:
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 1024:
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 2048:
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            case 4096:
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+            default:
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(float);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+    }
 }
 
 static void im2col_f32_f16_cuda(const float* x, half* dst,
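A quick size comparison shows why the fp16 path matters here: for ncols_x = 4096 (with WARP_SIZE = 32) the f16 launcher requests (GGML_PAD(4096, 64) + 32) * sizeof(half) = 4128 * 2 = 8256 bytes of shared memory per block, while the f32 launcher needs (GGML_PAD(4096, 32) + 32) * sizeof(float) = 4128 * 4 = 16512 bytes, roughly twice as much. If the device limit cached in smpb fell between those two figures, only the f16 kernel would get the shared-memory fast path and the f32 kernel would take the dst-as-scratch fallback. The f16 launcher also needs only half as many threads per block, since each thread handles a half2 pair of columns (nth grows toward ncols_x/2 instead of ncols_x).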
@@ -7072,6 +7336,7 @@ void ggml_init_cublas() {
 #else
         g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+        g_device_caps[id].smpb = prop.sharedMemPerBlock;
     }
     for (int id = 0; id < g_device_count; ++id) {
         g_tensor_split[id] /= total_vram;
@@ -8087,7 +8352,21 @@ static void ggml_cuda_op_soft_max(
     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));
 
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    const bool use_f16_soft_max = false;
+#else
+#ifdef GGML_CUDA_F16
+    const bool use_f16_soft_max = true;
+#else
+    const bool use_f16_soft_max = false;
+#endif // GGML_CUDA_F16
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+    if (use_f16_soft_max) {
+        soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    } else {
+        soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    }
 
     (void) dst;
 }