Spaces:
Running
Running
CUDA: add FP32 FlashAttention vector kernel (llama/7188)
Browse files* CUDA: add FP32 FlashAttention vector kernel
* fixup! CUDA: add FP32 FlashAttention vector kernel
* fixup! fixup! CUDA: add FP32 FlashAttention vector kernel
* fixup! fixup! fixup! CUDA: add FP32 FlashAttention vector kernel
- ggml-cuda.cu +10 -1
- ggml-cuda/common.cuh +4 -0
- ggml-cuda/fattn-common.cuh +47 -0
- ggml-cuda/fattn-vec-f16.cu +430 -0
- ggml-cuda/fattn-vec-f16.cuh +5 -0
- ggml-cuda/fattn-vec-f32.cu +384 -0
- ggml-cuda/fattn-vec-f32.cuh +3 -0
- ggml-cuda/fattn.cu +15 -453
ggml-cuda.cu
CHANGED
|
@@ -2713,6 +2713,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|
| 2713 |
}
|
| 2714 |
|
| 2715 |
GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
|
|
|
|
| 2716 |
switch (op->op) {
|
| 2717 |
case GGML_OP_UNARY:
|
| 2718 |
switch (ggml_get_unary_op(op)) {
|
|
@@ -2840,8 +2841,16 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|
| 2840 |
case GGML_OP_ARANGE:
|
| 2841 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 2842 |
case GGML_OP_LEAKY_RELU:
|
| 2843 |
-
case GGML_OP_FLASH_ATTN_EXT:
|
| 2844 |
return true;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2845 |
default:
|
| 2846 |
return false;
|
| 2847 |
}
|
|
|
|
| 2713 |
}
|
| 2714 |
|
| 2715 |
GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
|
| 2716 |
+
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
|
| 2717 |
switch (op->op) {
|
| 2718 |
case GGML_OP_UNARY:
|
| 2719 |
switch (ggml_get_unary_op(op)) {
|
|
|
|
| 2841 |
case GGML_OP_ARANGE:
|
| 2842 |
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 2843 |
case GGML_OP_LEAKY_RELU:
|
|
|
|
| 2844 |
return true;
|
| 2845 |
+
case GGML_OP_FLASH_ATTN_EXT:
|
| 2846 |
+
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
| 2847 |
+
return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
|
| 2848 |
+
#else
|
| 2849 |
+
if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
|
| 2850 |
+
return true;
|
| 2851 |
+
}
|
| 2852 |
+
return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
|
| 2853 |
+
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
| 2854 |
default:
|
| 2855 |
return false;
|
| 2856 |
}
|
ggml-cuda/common.cuh
CHANGED
|
@@ -321,6 +321,10 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
|
|
| 321 |
|
| 322 |
#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
| 323 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
static bool fp16_mma_available(const int cc) {
|
| 325 |
return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
|
| 326 |
}
|
|
|
|
| 321 |
|
| 322 |
#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
| 323 |
|
| 324 |
+
static bool fast_fp16_available(const int cc) {
|
| 325 |
+
return cc >= CC_PASCAL && cc != 610;
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
static bool fp16_mma_available(const int cc) {
|
| 329 |
return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
|
| 330 |
}
|
ggml-cuda/fattn-common.cuh
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#define FATTN_KQ_STRIDE 256
|
| 2 |
+
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
|
| 3 |
+
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
|
| 4 |
+
|
| 5 |
+
template<int D, int parallel_blocks> // D == head size
|
| 6 |
+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
| 7 |
+
__launch_bounds__(D, 1)
|
| 8 |
+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
| 9 |
+
static __global__ void flash_attn_combine_results(
|
| 10 |
+
const float * __restrict__ VKQ_parts,
|
| 11 |
+
const float2 * __restrict__ VKQ_meta,
|
| 12 |
+
float * __restrict__ dst) {
|
| 13 |
+
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
|
| 14 |
+
VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
|
| 15 |
+
dst += D * gridDim.y*blockIdx.x;
|
| 16 |
+
|
| 17 |
+
const int tid = threadIdx.x;
|
| 18 |
+
__builtin_assume(tid < D);
|
| 19 |
+
|
| 20 |
+
__shared__ float2 meta[parallel_blocks];
|
| 21 |
+
if (tid < 2*parallel_blocks) {
|
| 22 |
+
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
__syncthreads();
|
| 26 |
+
|
| 27 |
+
float kqmax = meta[0].x;
|
| 28 |
+
#pragma unroll
|
| 29 |
+
for (int l = 1; l < parallel_blocks; ++l) {
|
| 30 |
+
kqmax = max(kqmax, meta[l].x);
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
float VKQ_numerator = 0.0f;
|
| 34 |
+
float VKQ_denominator = 0.0f;
|
| 35 |
+
#pragma unroll
|
| 36 |
+
for (int l = 0; l < parallel_blocks; ++l) {
|
| 37 |
+
const float diff = meta[l].x - kqmax;
|
| 38 |
+
const float KQ_max_scale = expf(diff);
|
| 39 |
+
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
| 40 |
+
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
| 41 |
+
|
| 42 |
+
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
|
| 43 |
+
VKQ_denominator += KQ_max_scale * meta[l].y;
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
|
| 47 |
+
}
|
ggml-cuda/fattn-vec-f16.cu
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "common.cuh"
|
| 2 |
+
#include "fattn-common.cuh"
|
| 3 |
+
#include "fattn-vec-f16.cuh"
|
| 4 |
+
|
| 5 |
+
template<int D, int ncols, int parallel_blocks> // D == head size
|
| 6 |
+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
| 7 |
+
__launch_bounds__(D, 1)
|
| 8 |
+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
| 9 |
+
static __global__ void flash_attn_vec_ext_f16(
|
| 10 |
+
const char * __restrict__ Q,
|
| 11 |
+
const char * __restrict__ K,
|
| 12 |
+
const char * __restrict__ V,
|
| 13 |
+
const char * __restrict__ mask,
|
| 14 |
+
float * __restrict__ dst,
|
| 15 |
+
float2 * __restrict__ dst_meta,
|
| 16 |
+
const float scale,
|
| 17 |
+
const float max_bias,
|
| 18 |
+
const float m0,
|
| 19 |
+
const float m1,
|
| 20 |
+
const uint32_t n_head_log2,
|
| 21 |
+
const int ne00,
|
| 22 |
+
const int ne01,
|
| 23 |
+
const int ne02,
|
| 24 |
+
const int ne03,
|
| 25 |
+
const int ne10,
|
| 26 |
+
const int ne11,
|
| 27 |
+
const int ne12,
|
| 28 |
+
const int ne13,
|
| 29 |
+
const int ne31,
|
| 30 |
+
const int nb31,
|
| 31 |
+
const int nb01,
|
| 32 |
+
const int nb02,
|
| 33 |
+
const int nb03,
|
| 34 |
+
const int nb11,
|
| 35 |
+
const int nb12,
|
| 36 |
+
const int nb13,
|
| 37 |
+
const int ne0,
|
| 38 |
+
const int ne1,
|
| 39 |
+
const int ne2,
|
| 40 |
+
const int ne3) {
|
| 41 |
+
#if FP16_AVAILABLE
|
| 42 |
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 43 |
+
|
| 44 |
+
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
| 45 |
+
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
| 46 |
+
|
| 47 |
+
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 48 |
+
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
|
| 49 |
+
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
|
| 50 |
+
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
| 51 |
+
const half * maskh = (const half *) mask + ne11*ic0;
|
| 52 |
+
|
| 53 |
+
const int stride_KV = nb11 / sizeof(half);
|
| 54 |
+
const int stride_KV2 = nb11 / sizeof(half2);
|
| 55 |
+
|
| 56 |
+
half slopeh = __float2half(1.0f);
|
| 57 |
+
|
| 58 |
+
// ALiBi
|
| 59 |
+
if (max_bias > 0.0f) {
|
| 60 |
+
const int h = blockIdx.y;
|
| 61 |
+
|
| 62 |
+
const float base = h < n_head_log2 ? m0 : m1;
|
| 63 |
+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 64 |
+
|
| 65 |
+
slopeh = __float2half(powf(base, exph));
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
| 69 |
+
constexpr int nwarps = D / WARP_SIZE;
|
| 70 |
+
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
| 71 |
+
__builtin_assume(tid < D);
|
| 72 |
+
|
| 73 |
+
__shared__ half KQ[ncols*D];
|
| 74 |
+
#pragma unroll
|
| 75 |
+
for (int j = 0; j < ncols; ++j) {
|
| 76 |
+
KQ[j*D + tid] = -HALF_MAX_HALF;
|
| 77 |
+
}
|
| 78 |
+
half2 * KQ2 = (half2 *) KQ;
|
| 79 |
+
|
| 80 |
+
half kqmax[ncols];
|
| 81 |
+
#pragma unroll
|
| 82 |
+
for (int j = 0; j < ncols; ++j) {
|
| 83 |
+
kqmax[j] = -HALF_MAX_HALF;
|
| 84 |
+
}
|
| 85 |
+
half kqsum[ncols] = {0.0f};
|
| 86 |
+
|
| 87 |
+
__shared__ half kqmax_shared[ncols][WARP_SIZE];
|
| 88 |
+
__shared__ half kqsum_shared[ncols][WARP_SIZE];
|
| 89 |
+
#pragma unroll
|
| 90 |
+
for (int j = 0; j < ncols; ++j) {
|
| 91 |
+
if (threadIdx.y == 0) {
|
| 92 |
+
kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
|
| 93 |
+
kqsum_shared[j][threadIdx.x] = 0.0f;
|
| 94 |
+
}
|
| 95 |
+
}
|
| 96 |
+
__syncthreads();
|
| 97 |
+
|
| 98 |
+
// Convert Q to half2 and store in registers:
|
| 99 |
+
half2 Q_h2[ncols][D/(2*WARP_SIZE)];
|
| 100 |
+
#pragma unroll
|
| 101 |
+
for (int j = 0; j < ncols; ++j) {
|
| 102 |
+
#pragma unroll
|
| 103 |
+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
| 104 |
+
const int i = i0 + threadIdx.x;
|
| 105 |
+
|
| 106 |
+
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
|
| 107 |
+
Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
| 112 |
+
|
| 113 |
+
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
|
| 114 |
+
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
|
| 115 |
+
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 116 |
+
|
| 117 |
+
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
|
| 118 |
+
// see https://github.com/ggerganov/llama.cpp/pull/7061 .
|
| 119 |
+
// Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
|
| 120 |
+
half kqmax_new = kqmax[0];
|
| 121 |
+
half kqmax_new_arr[ncols];
|
| 122 |
+
#pragma unroll
|
| 123 |
+
for (int j = 0; j < ncols; ++j) {
|
| 124 |
+
kqmax_new_arr[j] = kqmax[j];
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
#pragma unroll
|
| 128 |
+
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
|
| 129 |
+
const int i_KQ = i_KQ_0 + threadIdx.y;
|
| 130 |
+
|
| 131 |
+
if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
|
| 132 |
+
break;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
half2 sum2[ncols] = {{0.0f, 0.0f}};
|
| 136 |
+
#pragma unroll
|
| 137 |
+
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
| 138 |
+
const int k_KQ = k_KQ_0 + threadIdx.x;
|
| 139 |
+
|
| 140 |
+
const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
| 141 |
+
#pragma unroll
|
| 142 |
+
for (int j = 0; j < ncols; ++j) {
|
| 143 |
+
sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE];
|
| 144 |
+
}
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
#pragma unroll
|
| 148 |
+
for (int j = 0; j < ncols; ++j) {
|
| 149 |
+
sum2[j] = warp_reduce_sum(sum2[j]);
|
| 150 |
+
half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
|
| 151 |
+
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
| 152 |
+
|
| 153 |
+
if (ncols == 1) {
|
| 154 |
+
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
|
| 155 |
+
} else {
|
| 156 |
+
kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
if (threadIdx.x == 0) {
|
| 160 |
+
KQ[j*D + i_KQ] = sum;
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
#pragma unroll
|
| 166 |
+
for (int j = 0; j < ncols; ++j) {
|
| 167 |
+
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
|
| 168 |
+
|
| 169 |
+
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
| 170 |
+
if (threadIdx.x == 0) {
|
| 171 |
+
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
| 172 |
+
}
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
__syncthreads();
|
| 176 |
+
|
| 177 |
+
#pragma unroll
|
| 178 |
+
for (int j = 0; j < ncols; ++j) {
|
| 179 |
+
half kqmax_new_j = kqmax_shared[j][threadIdx.x];
|
| 180 |
+
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
| 181 |
+
|
| 182 |
+
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
|
| 183 |
+
kqmax[j] = kqmax_new_j;
|
| 184 |
+
|
| 185 |
+
const half val = hexp(KQ[j*D + tid] - kqmax[j]);
|
| 186 |
+
kqsum[j] = kqsum[j]*KQ_max_scale + val;
|
| 187 |
+
KQ[j*D + tid] = val;
|
| 188 |
+
|
| 189 |
+
VKQ[j] *= __half2half2(KQ_max_scale);
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
__syncthreads();
|
| 193 |
+
|
| 194 |
+
#pragma unroll
|
| 195 |
+
for (int k0 = 0; k0 < D; k0 += 2) {
|
| 196 |
+
if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
|
| 197 |
+
break;
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
half2 V_k;
|
| 201 |
+
reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
|
| 202 |
+
reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
|
| 203 |
+
#pragma unroll
|
| 204 |
+
for (int j = 0; j < ncols; ++j) {
|
| 205 |
+
VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
|
| 206 |
+
}
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
__syncthreads();
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
#pragma unroll
|
| 213 |
+
for (int j = 0; j < ncols; ++j) {
|
| 214 |
+
kqsum[j] = warp_reduce_sum(kqsum[j]);
|
| 215 |
+
if (threadIdx.x == 0) {
|
| 216 |
+
kqsum_shared[j][threadIdx.y] = kqsum[j];
|
| 217 |
+
}
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
__syncthreads();
|
| 221 |
+
|
| 222 |
+
#pragma unroll
|
| 223 |
+
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
| 224 |
+
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
|
| 225 |
+
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
|
| 226 |
+
|
| 227 |
+
half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
|
| 228 |
+
if (parallel_blocks == 1) {
|
| 229 |
+
dst_val /= kqsum[j_VKQ];
|
| 230 |
+
}
|
| 231 |
+
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
| 232 |
+
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
if (parallel_blocks != 1 && tid != 0) {
|
| 236 |
+
#pragma unroll
|
| 237 |
+
for (int j = 0; j < ncols; ++j) {
|
| 238 |
+
dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
|
| 239 |
+
}
|
| 240 |
+
}
|
| 241 |
+
#else
|
| 242 |
+
NO_DEVICE_CODE;
|
| 243 |
+
#endif // FP16_AVAILABLE
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f16(
|
| 247 |
+
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
|
| 248 |
+
ggml_cuda_pool & pool, cudaStream_t main_stream
|
| 249 |
+
) {
|
| 250 |
+
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
| 251 |
+
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
| 252 |
+
|
| 253 |
+
if (parallel_blocks > 1) {
|
| 254 |
+
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
| 255 |
+
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
|
| 259 |
+
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
| 260 |
+
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
|
| 261 |
+
const int shmem = 0;
|
| 262 |
+
|
| 263 |
+
float scale = 1.0f;
|
| 264 |
+
float max_bias = 0.0f;
|
| 265 |
+
|
| 266 |
+
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
|
| 267 |
+
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
|
| 268 |
+
|
| 269 |
+
const uint32_t n_head = Q->ne[2];
|
| 270 |
+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
| 271 |
+
|
| 272 |
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 273 |
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 274 |
+
|
| 275 |
+
flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
|
| 276 |
+
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
| 277 |
+
(const char *) Q->data,
|
| 278 |
+
(const char *) K->data,
|
| 279 |
+
(const char *) V->data,
|
| 280 |
+
mask ? ((const char *) mask->data) : nullptr,
|
| 281 |
+
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
| 282 |
+
scale, max_bias, m0, m1, n_head_log2,
|
| 283 |
+
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
| 284 |
+
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
| 285 |
+
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
| 286 |
+
Q->nb[1], Q->nb[2], Q->nb[3],
|
| 287 |
+
K->nb[1], K->nb[2], K->nb[3],
|
| 288 |
+
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
| 289 |
+
);
|
| 290 |
+
CUDA_CHECK(cudaGetLastError());
|
| 291 |
+
|
| 292 |
+
if (parallel_blocks == 1) {
|
| 293 |
+
return;
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
const dim3 block_dim_combine(D, 1, 1);
|
| 297 |
+
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
| 298 |
+
const int shmem_combine = 0;
|
| 299 |
+
|
| 300 |
+
flash_attn_combine_results<D, parallel_blocks>
|
| 301 |
+
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
|
| 302 |
+
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
| 303 |
+
CUDA_CHECK(cudaGetLastError());
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 307 |
+
const ggml_tensor * Q = dst->src[0];
|
| 308 |
+
const ggml_tensor * K = dst->src[1];
|
| 309 |
+
const ggml_tensor * V = dst->src[2];
|
| 310 |
+
|
| 311 |
+
const ggml_tensor * mask = dst->src[3];
|
| 312 |
+
|
| 313 |
+
ggml_tensor * KQV = dst;
|
| 314 |
+
|
| 315 |
+
const int32_t precision = KQV->op_params[2];
|
| 316 |
+
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
| 317 |
+
|
| 318 |
+
constexpr int cols_per_block = 1;
|
| 319 |
+
constexpr int parallel_blocks = 4;
|
| 320 |
+
switch (Q->ne[0]) {
|
| 321 |
+
case 64:
|
| 322 |
+
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 323 |
+
break;
|
| 324 |
+
case 128:
|
| 325 |
+
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 326 |
+
break;
|
| 327 |
+
case 256:
|
| 328 |
+
launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 329 |
+
break;
|
| 330 |
+
default:
|
| 331 |
+
GGML_ASSERT(false);
|
| 332 |
+
break;
|
| 333 |
+
}
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 337 |
+
const ggml_tensor * Q = dst->src[0];
|
| 338 |
+
const ggml_tensor * K = dst->src[1];
|
| 339 |
+
const ggml_tensor * V = dst->src[2];
|
| 340 |
+
|
| 341 |
+
const ggml_tensor * mask = dst->src[3];
|
| 342 |
+
|
| 343 |
+
ggml_tensor * KQV = dst;
|
| 344 |
+
|
| 345 |
+
const int32_t precision = KQV->op_params[2];
|
| 346 |
+
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
| 347 |
+
GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
| 348 |
+
|
| 349 |
+
if (Q->ne[1] == 1) {
|
| 350 |
+
constexpr int cols_per_block = 1;
|
| 351 |
+
constexpr int parallel_blocks = 4;
|
| 352 |
+
switch (Q->ne[0]) {
|
| 353 |
+
case 64:
|
| 354 |
+
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 355 |
+
break;
|
| 356 |
+
case 128:
|
| 357 |
+
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 358 |
+
break;
|
| 359 |
+
default:
|
| 360 |
+
GGML_ASSERT(false);
|
| 361 |
+
break;
|
| 362 |
+
}
|
| 363 |
+
return;
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
if (Q->ne[1] == 2) {
|
| 367 |
+
constexpr int cols_per_block = 2;
|
| 368 |
+
constexpr int parallel_blocks = 4;
|
| 369 |
+
switch (Q->ne[0]) {
|
| 370 |
+
case 64:
|
| 371 |
+
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 372 |
+
break;
|
| 373 |
+
case 128:
|
| 374 |
+
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 375 |
+
break;
|
| 376 |
+
default:
|
| 377 |
+
GGML_ASSERT(false);
|
| 378 |
+
break;
|
| 379 |
+
}
|
| 380 |
+
return;
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
if (Q->ne[1] <= 4) {
|
| 384 |
+
constexpr int cols_per_block = 4;
|
| 385 |
+
constexpr int parallel_blocks = 4;
|
| 386 |
+
switch (Q->ne[0]) {
|
| 387 |
+
case 64:
|
| 388 |
+
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 389 |
+
break;
|
| 390 |
+
case 128:
|
| 391 |
+
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 392 |
+
break;
|
| 393 |
+
default:
|
| 394 |
+
GGML_ASSERT(false);
|
| 395 |
+
break;
|
| 396 |
+
}
|
| 397 |
+
return;
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
if (Q->ne[1] <= 8) {
|
| 401 |
+
constexpr int cols_per_block = 8;
|
| 402 |
+
constexpr int parallel_blocks = 4;
|
| 403 |
+
switch (Q->ne[0]) {
|
| 404 |
+
case 64:
|
| 405 |
+
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 406 |
+
break;
|
| 407 |
+
case 128:
|
| 408 |
+
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 409 |
+
break;
|
| 410 |
+
default:
|
| 411 |
+
GGML_ASSERT(false);
|
| 412 |
+
break;
|
| 413 |
+
}
|
| 414 |
+
return;
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
constexpr int cols_per_block = 8;
|
| 418 |
+
constexpr int parallel_blocks = 1;
|
| 419 |
+
switch (Q->ne[0]) {
|
| 420 |
+
case 64:
|
| 421 |
+
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 422 |
+
break;
|
| 423 |
+
case 128:
|
| 424 |
+
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 425 |
+
break;
|
| 426 |
+
default:
|
| 427 |
+
GGML_ASSERT(false);
|
| 428 |
+
break;
|
| 429 |
+
}
|
| 430 |
+
}
|
ggml-cuda/fattn-vec-f16.cuh
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "common.cuh"
|
| 2 |
+
|
| 3 |
+
void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
| 4 |
+
|
| 5 |
+
void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
ggml-cuda/fattn-vec-f32.cu
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "common.cuh"
|
| 2 |
+
#include "fattn-common.cuh"
|
| 3 |
+
#include "fattn-vec-f32.cuh"
|
| 4 |
+
|
| 5 |
+
template<int D, int ncols, int parallel_blocks> // D == head size
|
| 6 |
+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
| 7 |
+
__launch_bounds__(D, 1)
|
| 8 |
+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
| 9 |
+
static __global__ void flash_attn_vec_ext_f32(
|
| 10 |
+
const char * __restrict__ Q,
|
| 11 |
+
const char * __restrict__ K,
|
| 12 |
+
const char * __restrict__ V,
|
| 13 |
+
const char * __restrict__ mask,
|
| 14 |
+
float * __restrict__ dst,
|
| 15 |
+
float2 * __restrict__ dst_meta,
|
| 16 |
+
const float scale,
|
| 17 |
+
const float max_bias,
|
| 18 |
+
const float m0,
|
| 19 |
+
const float m1,
|
| 20 |
+
const uint32_t n_head_log2,
|
| 21 |
+
const int ne00,
|
| 22 |
+
const int ne01,
|
| 23 |
+
const int ne02,
|
| 24 |
+
const int ne03,
|
| 25 |
+
const int ne10,
|
| 26 |
+
const int ne11,
|
| 27 |
+
const int ne12,
|
| 28 |
+
const int ne13,
|
| 29 |
+
const int ne31,
|
| 30 |
+
const int nb31,
|
| 31 |
+
const int nb01,
|
| 32 |
+
const int nb02,
|
| 33 |
+
const int nb03,
|
| 34 |
+
const int nb11,
|
| 35 |
+
const int nb12,
|
| 36 |
+
const int nb13,
|
| 37 |
+
const int ne0,
|
| 38 |
+
const int ne1,
|
| 39 |
+
const int ne2,
|
| 40 |
+
const int ne3) {
|
| 41 |
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 42 |
+
|
| 43 |
+
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
| 44 |
+
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
| 45 |
+
|
| 46 |
+
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 47 |
+
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
|
| 48 |
+
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
|
| 49 |
+
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
| 50 |
+
const half * maskh = (const half *) mask + ne11*ic0;
|
| 51 |
+
|
| 52 |
+
const int stride_KV = nb11 / sizeof(half);
|
| 53 |
+
const int stride_KV2 = nb11 / sizeof(half2);
|
| 54 |
+
|
| 55 |
+
float slope = 1.0f;
|
| 56 |
+
|
| 57 |
+
// ALiBi
|
| 58 |
+
if (max_bias > 0.0f) {
|
| 59 |
+
const int h = blockIdx.y;
|
| 60 |
+
|
| 61 |
+
const float base = h < n_head_log2 ? m0 : m1;
|
| 62 |
+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 63 |
+
|
| 64 |
+
slope = powf(base, exph);
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
| 68 |
+
constexpr int nwarps = D / WARP_SIZE;
|
| 69 |
+
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
| 70 |
+
__builtin_assume(tid < D);
|
| 71 |
+
|
| 72 |
+
__shared__ float KQ[ncols*D];
|
| 73 |
+
#pragma unroll
|
| 74 |
+
for (int j = 0; j < ncols; ++j) {
|
| 75 |
+
KQ[j*D + tid] = -FLT_MAX/2.0f;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
float kqmax[ncols];
|
| 79 |
+
#pragma unroll
|
| 80 |
+
for (int j = 0; j < ncols; ++j) {
|
| 81 |
+
kqmax[j] = -FLT_MAX/2.0f;
|
| 82 |
+
}
|
| 83 |
+
float kqsum[ncols] = {0.0f};
|
| 84 |
+
|
| 85 |
+
__shared__ float kqmax_shared[ncols][WARP_SIZE];
|
| 86 |
+
__shared__ float kqsum_shared[ncols][WARP_SIZE];
|
| 87 |
+
#pragma unroll
|
| 88 |
+
for (int j = 0; j < ncols; ++j) {
|
| 89 |
+
if (threadIdx.y == 0) {
|
| 90 |
+
kqmax_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
|
| 91 |
+
kqsum_shared[j][threadIdx.x] = 0.0f;
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
__syncthreads();
|
| 95 |
+
|
| 96 |
+
// Convert Q to half2 and store in registers:
|
| 97 |
+
float2 Q_h2[ncols][D/(2*WARP_SIZE)];
|
| 98 |
+
#pragma unroll
|
| 99 |
+
for (int j = 0; j < ncols; ++j) {
|
| 100 |
+
#pragma unroll
|
| 101 |
+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
| 102 |
+
const int i = i0 + threadIdx.x;
|
| 103 |
+
|
| 104 |
+
Q_h2[j][i0/WARP_SIZE] = Q_f2[j*(nb01/sizeof(float2)) + i];
|
| 105 |
+
Q_h2[j][i0/WARP_SIZE].x *= scale;
|
| 106 |
+
Q_h2[j][i0/WARP_SIZE].y *= scale;
|
| 107 |
+
}
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
float VKQ[ncols] = {0.0f};
|
| 111 |
+
|
| 112 |
+
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
|
| 113 |
+
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
|
| 114 |
+
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 115 |
+
|
| 116 |
+
float kqmax_new_arr[ncols];
|
| 117 |
+
#pragma unroll
|
| 118 |
+
for (int j = 0; j < ncols; ++j) {
|
| 119 |
+
kqmax_new_arr[j] = kqmax[j];
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
#pragma unroll
|
| 123 |
+
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
|
| 124 |
+
const int i_KQ = i_KQ_0 + threadIdx.y;
|
| 125 |
+
|
| 126 |
+
if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
|
| 127 |
+
break;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
float sum[ncols] = {0.0f};
|
| 131 |
+
#pragma unroll
|
| 132 |
+
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
| 133 |
+
const int k_KQ = k_KQ_0 + threadIdx.x;
|
| 134 |
+
|
| 135 |
+
const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
| 136 |
+
#pragma unroll
|
| 137 |
+
for (int j = 0; j < ncols; ++j) {
|
| 138 |
+
sum[j] += __low2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].x;
|
| 139 |
+
sum[j] += __high2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].y;
|
| 140 |
+
}
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
#pragma unroll
|
| 144 |
+
for (int j = 0; j < ncols; ++j) {
|
| 145 |
+
sum[j] = warp_reduce_sum(sum[j]);
|
| 146 |
+
sum[j] += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
| 147 |
+
|
| 148 |
+
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum[j]);
|
| 149 |
+
|
| 150 |
+
if (threadIdx.x == 0) {
|
| 151 |
+
KQ[j*D + i_KQ] = sum[j];
|
| 152 |
+
}
|
| 153 |
+
}
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
#pragma unroll
|
| 157 |
+
for (int j = 0; j < ncols; ++j) {
|
| 158 |
+
float kqmax_new_j = kqmax_new_arr[j];
|
| 159 |
+
|
| 160 |
+
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
| 161 |
+
if (threadIdx.x == 0) {
|
| 162 |
+
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
__syncthreads();
|
| 167 |
+
|
| 168 |
+
#pragma unroll
|
| 169 |
+
for (int j = 0; j < ncols; ++j) {
|
| 170 |
+
float kqmax_new_j = kqmax_shared[j][threadIdx.x];
|
| 171 |
+
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
| 172 |
+
|
| 173 |
+
const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
|
| 174 |
+
kqmax[j] = kqmax_new_j;
|
| 175 |
+
|
| 176 |
+
const float val = expf(KQ[j*D + tid] - kqmax[j]);
|
| 177 |
+
kqsum[j] = kqsum[j]*KQ_max_scale + val;
|
| 178 |
+
KQ[j*D + tid] = val;
|
| 179 |
+
|
| 180 |
+
VKQ[j] *= KQ_max_scale;
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
__syncthreads();
|
| 184 |
+
|
| 185 |
+
#pragma unroll
|
| 186 |
+
for (int k = 0; k < D; ++k) {
|
| 187 |
+
if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) {
|
| 188 |
+
break;
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
const float V_ki = __half2float(V_h[(k_VKQ_0 + k)*stride_KV + tid]);
|
| 192 |
+
#pragma unroll
|
| 193 |
+
for (int j = 0; j < ncols; ++j) {
|
| 194 |
+
VKQ[j] += V_ki*KQ[j*D + k];
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
__syncthreads();
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
#pragma unroll
|
| 202 |
+
for (int j = 0; j < ncols; ++j) {
|
| 203 |
+
kqsum[j] = warp_reduce_sum(kqsum[j]);
|
| 204 |
+
if (threadIdx.x == 0) {
|
| 205 |
+
kqsum_shared[j][threadIdx.y] = kqsum[j];
|
| 206 |
+
}
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
__syncthreads();
|
| 210 |
+
|
| 211 |
+
#pragma unroll
|
| 212 |
+
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
| 213 |
+
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
|
| 214 |
+
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
|
| 215 |
+
|
| 216 |
+
float dst_val = VKQ[j_VKQ];
|
| 217 |
+
if (parallel_blocks == 1) {
|
| 218 |
+
dst_val /= kqsum[j_VKQ];
|
| 219 |
+
}
|
| 220 |
+
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
| 221 |
+
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
if (parallel_blocks != 1 && tid != 0) {
|
| 225 |
+
#pragma unroll
|
| 226 |
+
for (int j = 0; j < ncols; ++j) {
|
| 227 |
+
dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f32(
|
| 233 |
+
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
|
| 234 |
+
ggml_cuda_pool & pool, cudaStream_t main_stream
|
| 235 |
+
) {
|
| 236 |
+
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
| 237 |
+
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
| 238 |
+
|
| 239 |
+
if (parallel_blocks > 1) {
|
| 240 |
+
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
| 241 |
+
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
|
| 245 |
+
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
| 246 |
+
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
|
| 247 |
+
const int shmem = 0;
|
| 248 |
+
|
| 249 |
+
float scale = 1.0f;
|
| 250 |
+
float max_bias = 0.0f;
|
| 251 |
+
|
| 252 |
+
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
|
| 253 |
+
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
|
| 254 |
+
|
| 255 |
+
const uint32_t n_head = Q->ne[2];
|
| 256 |
+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
| 257 |
+
|
| 258 |
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 259 |
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 260 |
+
|
| 261 |
+
flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>
|
| 262 |
+
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
| 263 |
+
(const char *) Q->data,
|
| 264 |
+
(const char *) K->data,
|
| 265 |
+
(const char *) V->data,
|
| 266 |
+
mask ? ((const char *) mask->data) : nullptr,
|
| 267 |
+
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
| 268 |
+
scale, max_bias, m0, m1, n_head_log2,
|
| 269 |
+
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
| 270 |
+
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
| 271 |
+
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
| 272 |
+
Q->nb[1], Q->nb[2], Q->nb[3],
|
| 273 |
+
K->nb[1], K->nb[2], K->nb[3],
|
| 274 |
+
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
| 275 |
+
);
|
| 276 |
+
CUDA_CHECK(cudaGetLastError());
|
| 277 |
+
|
| 278 |
+
if (parallel_blocks == 1) {
|
| 279 |
+
return;
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
const dim3 block_dim_combine(D, 1, 1);
|
| 283 |
+
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
| 284 |
+
const int shmem_combine = 0;
|
| 285 |
+
|
| 286 |
+
flash_attn_combine_results<D, parallel_blocks>
|
| 287 |
+
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
|
| 288 |
+
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
| 289 |
+
CUDA_CHECK(cudaGetLastError());
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 293 |
+
const ggml_tensor * Q = dst->src[0];
|
| 294 |
+
const ggml_tensor * K = dst->src[1];
|
| 295 |
+
const ggml_tensor * V = dst->src[2];
|
| 296 |
+
|
| 297 |
+
const ggml_tensor * mask = dst->src[3];
|
| 298 |
+
|
| 299 |
+
ggml_tensor * KQV = dst;
|
| 300 |
+
|
| 301 |
+
GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
| 302 |
+
|
| 303 |
+
if (Q->ne[1] == 1) {
|
| 304 |
+
constexpr int cols_per_block = 1;
|
| 305 |
+
constexpr int parallel_blocks = 4;
|
| 306 |
+
switch (Q->ne[0]) {
|
| 307 |
+
case 64:
|
| 308 |
+
launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 309 |
+
break;
|
| 310 |
+
case 128:
|
| 311 |
+
launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 312 |
+
break;
|
| 313 |
+
default:
|
| 314 |
+
GGML_ASSERT(false);
|
| 315 |
+
break;
|
| 316 |
+
}
|
| 317 |
+
return;
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
if (Q->ne[1] == 2) {
|
| 321 |
+
constexpr int cols_per_block = 2;
|
| 322 |
+
constexpr int parallel_blocks = 4;
|
| 323 |
+
switch (Q->ne[0]) {
|
| 324 |
+
case 64:
|
| 325 |
+
launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 326 |
+
break;
|
| 327 |
+
case 128:
|
| 328 |
+
launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 329 |
+
break;
|
| 330 |
+
default:
|
| 331 |
+
GGML_ASSERT(false);
|
| 332 |
+
break;
|
| 333 |
+
}
|
| 334 |
+
return;
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
if (Q->ne[1] <= 4) {
|
| 338 |
+
constexpr int cols_per_block = 4;
|
| 339 |
+
constexpr int parallel_blocks = 4;
|
| 340 |
+
switch (Q->ne[0]) {
|
| 341 |
+
case 64:
|
| 342 |
+
launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 343 |
+
break;
|
| 344 |
+
case 128:
|
| 345 |
+
launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 346 |
+
break;
|
| 347 |
+
default:
|
| 348 |
+
GGML_ASSERT(false);
|
| 349 |
+
break;
|
| 350 |
+
}
|
| 351 |
+
return;
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
if (Q->ne[1] <= 8) {
|
| 355 |
+
constexpr int cols_per_block = 8;
|
| 356 |
+
constexpr int parallel_blocks = 4;
|
| 357 |
+
switch (Q->ne[0]) {
|
| 358 |
+
case 64:
|
| 359 |
+
launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 360 |
+
break;
|
| 361 |
+
case 128:
|
| 362 |
+
launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 363 |
+
break;
|
| 364 |
+
default:
|
| 365 |
+
GGML_ASSERT(false);
|
| 366 |
+
break;
|
| 367 |
+
}
|
| 368 |
+
return;
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
constexpr int cols_per_block = 8;
|
| 372 |
+
constexpr int parallel_blocks = 1;
|
| 373 |
+
switch (Q->ne[0]) {
|
| 374 |
+
case 64:
|
| 375 |
+
launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 376 |
+
break;
|
| 377 |
+
case 128:
|
| 378 |
+
launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 379 |
+
break;
|
| 380 |
+
default:
|
| 381 |
+
GGML_ASSERT(false);
|
| 382 |
+
break;
|
| 383 |
+
}
|
| 384 |
+
}
|
ggml-cuda/fattn-vec-f32.cuh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "common.cuh"
|
| 2 |
+
|
| 3 |
+
void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
ggml-cuda/fattn.cu
CHANGED
|
@@ -1,4 +1,7 @@
|
|
| 1 |
#include "common.cuh"
|
|
|
|
|
|
|
|
|
|
| 2 |
#include "fattn.cuh"
|
| 3 |
|
| 4 |
#include <cstdint>
|
|
@@ -7,251 +10,6 @@
|
|
| 7 |
#include <mma.h>
|
| 8 |
#endif
|
| 9 |
|
| 10 |
-
#define FATTN_KQ_STRIDE 256
|
| 11 |
-
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
|
| 12 |
-
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
|
| 13 |
-
|
| 14 |
-
template<int D, int ncols, int parallel_blocks> // D == head size
|
| 15 |
-
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
| 16 |
-
__launch_bounds__(D, 1)
|
| 17 |
-
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
| 18 |
-
static __global__ void flash_attn_vec_ext_f16(
|
| 19 |
-
const char * __restrict__ Q,
|
| 20 |
-
const char * __restrict__ K,
|
| 21 |
-
const char * __restrict__ V,
|
| 22 |
-
const char * __restrict__ mask,
|
| 23 |
-
float * __restrict__ dst,
|
| 24 |
-
float2 * __restrict__ dst_meta,
|
| 25 |
-
const float scale,
|
| 26 |
-
const float max_bias,
|
| 27 |
-
const float m0,
|
| 28 |
-
const float m1,
|
| 29 |
-
const uint32_t n_head_log2,
|
| 30 |
-
const int ne00,
|
| 31 |
-
const int ne01,
|
| 32 |
-
const int ne02,
|
| 33 |
-
const int ne03,
|
| 34 |
-
const int ne10,
|
| 35 |
-
const int ne11,
|
| 36 |
-
const int ne12,
|
| 37 |
-
const int ne13,
|
| 38 |
-
const int ne31,
|
| 39 |
-
const int nb31,
|
| 40 |
-
const int nb01,
|
| 41 |
-
const int nb02,
|
| 42 |
-
const int nb03,
|
| 43 |
-
const int nb11,
|
| 44 |
-
const int nb12,
|
| 45 |
-
const int nb13,
|
| 46 |
-
const int ne0,
|
| 47 |
-
const int ne1,
|
| 48 |
-
const int ne2,
|
| 49 |
-
const int ne3) {
|
| 50 |
-
#if FP16_AVAILABLE
|
| 51 |
-
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 52 |
-
|
| 53 |
-
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
| 54 |
-
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
| 55 |
-
|
| 56 |
-
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 57 |
-
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
|
| 58 |
-
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
|
| 59 |
-
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
| 60 |
-
const half * maskh = (const half *) mask + ne11*ic0;
|
| 61 |
-
|
| 62 |
-
const int stride_KV = nb11 / sizeof(half);
|
| 63 |
-
const int stride_KV2 = nb11 / sizeof(half2);
|
| 64 |
-
|
| 65 |
-
half slopeh = __float2half(1.0f);
|
| 66 |
-
|
| 67 |
-
// ALiBi
|
| 68 |
-
if (max_bias > 0.0f) {
|
| 69 |
-
const int h = blockIdx.y;
|
| 70 |
-
|
| 71 |
-
const float base = h < n_head_log2 ? m0 : m1;
|
| 72 |
-
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 73 |
-
|
| 74 |
-
slopeh = __float2half(powf(base, exph));
|
| 75 |
-
}
|
| 76 |
-
|
| 77 |
-
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
| 78 |
-
constexpr int nwarps = D / WARP_SIZE;
|
| 79 |
-
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
| 80 |
-
__builtin_assume(tid < D);
|
| 81 |
-
|
| 82 |
-
__shared__ half KQ[ncols*D];
|
| 83 |
-
#pragma unroll
|
| 84 |
-
for (int j = 0; j < ncols; ++j) {
|
| 85 |
-
KQ[j*D + tid] = -HALF_MAX_HALF;
|
| 86 |
-
}
|
| 87 |
-
half2 * KQ2 = (half2 *) KQ;
|
| 88 |
-
|
| 89 |
-
half kqmax[ncols];
|
| 90 |
-
#pragma unroll
|
| 91 |
-
for (int j = 0; j < ncols; ++j) {
|
| 92 |
-
kqmax[j] = -HALF_MAX_HALF;
|
| 93 |
-
}
|
| 94 |
-
half kqsum[ncols] = {0.0f};
|
| 95 |
-
|
| 96 |
-
__shared__ half kqmax_shared[ncols][WARP_SIZE];
|
| 97 |
-
__shared__ half kqsum_shared[ncols][WARP_SIZE];
|
| 98 |
-
#pragma unroll
|
| 99 |
-
for (int j = 0; j < ncols; ++j) {
|
| 100 |
-
if (threadIdx.y == 0) {
|
| 101 |
-
kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
|
| 102 |
-
kqsum_shared[j][threadIdx.x] = 0.0f;
|
| 103 |
-
}
|
| 104 |
-
}
|
| 105 |
-
__syncthreads();
|
| 106 |
-
|
| 107 |
-
// Convert Q to half2 and store in registers:
|
| 108 |
-
half2 Q_h2[ncols][D/(2*WARP_SIZE)];
|
| 109 |
-
#pragma unroll
|
| 110 |
-
for (int j = 0; j < ncols; ++j) {
|
| 111 |
-
#pragma unroll
|
| 112 |
-
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
| 113 |
-
const int i = i0 + threadIdx.x;
|
| 114 |
-
|
| 115 |
-
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
|
| 116 |
-
Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
|
| 117 |
-
}
|
| 118 |
-
}
|
| 119 |
-
|
| 120 |
-
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
| 121 |
-
|
| 122 |
-
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
|
| 123 |
-
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
|
| 124 |
-
// Calculate KQ tile and keep track of new maximum KQ values:
|
| 125 |
-
|
| 126 |
-
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
|
| 127 |
-
// see https://github.com/ggerganov/llama.cpp/pull/7061 .
|
| 128 |
-
// Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
|
| 129 |
-
half kqmax_new = kqmax[0];
|
| 130 |
-
half kqmax_new_arr[ncols];
|
| 131 |
-
#pragma unroll
|
| 132 |
-
for (int j = 0; j < ncols; ++j) {
|
| 133 |
-
kqmax_new_arr[j] = kqmax[j];
|
| 134 |
-
}
|
| 135 |
-
|
| 136 |
-
#pragma unroll
|
| 137 |
-
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
|
| 138 |
-
const int i_KQ = i_KQ_0 + threadIdx.y;
|
| 139 |
-
|
| 140 |
-
if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
|
| 141 |
-
break;
|
| 142 |
-
}
|
| 143 |
-
|
| 144 |
-
half2 sum2[ncols] = {{0.0f, 0.0f}};
|
| 145 |
-
#pragma unroll
|
| 146 |
-
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
| 147 |
-
const int k_KQ = k_KQ_0 + threadIdx.x;
|
| 148 |
-
|
| 149 |
-
const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
| 150 |
-
#pragma unroll
|
| 151 |
-
for (int j = 0; j < ncols; ++j) {
|
| 152 |
-
sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE];
|
| 153 |
-
}
|
| 154 |
-
}
|
| 155 |
-
|
| 156 |
-
#pragma unroll
|
| 157 |
-
for (int j = 0; j < ncols; ++j) {
|
| 158 |
-
sum2[j] = warp_reduce_sum(sum2[j]);
|
| 159 |
-
half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
|
| 160 |
-
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
| 161 |
-
|
| 162 |
-
if (ncols == 1) {
|
| 163 |
-
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
|
| 164 |
-
} else {
|
| 165 |
-
kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
|
| 166 |
-
}
|
| 167 |
-
|
| 168 |
-
if (threadIdx.x == 0) {
|
| 169 |
-
KQ[j*D + i_KQ] = sum;
|
| 170 |
-
}
|
| 171 |
-
}
|
| 172 |
-
}
|
| 173 |
-
|
| 174 |
-
#pragma unroll
|
| 175 |
-
for (int j = 0; j < ncols; ++j) {
|
| 176 |
-
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
|
| 177 |
-
|
| 178 |
-
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
| 179 |
-
if (threadIdx.x == 0) {
|
| 180 |
-
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
| 181 |
-
}
|
| 182 |
-
}
|
| 183 |
-
|
| 184 |
-
__syncthreads();
|
| 185 |
-
|
| 186 |
-
#pragma unroll
|
| 187 |
-
for (int j = 0; j < ncols; ++j) {
|
| 188 |
-
half kqmax_new_j = kqmax_shared[j][threadIdx.x];
|
| 189 |
-
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
| 190 |
-
|
| 191 |
-
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
|
| 192 |
-
kqmax[j] = kqmax_new_j;
|
| 193 |
-
|
| 194 |
-
const half val = hexp(KQ[j*D + tid] - kqmax[j]);
|
| 195 |
-
kqsum[j] = kqsum[j]*KQ_max_scale + val;
|
| 196 |
-
KQ[j*D + tid] = val;
|
| 197 |
-
|
| 198 |
-
VKQ[j] *= __half2half2(KQ_max_scale);
|
| 199 |
-
}
|
| 200 |
-
|
| 201 |
-
__syncthreads();
|
| 202 |
-
|
| 203 |
-
#pragma unroll
|
| 204 |
-
for (int k0 = 0; k0 < D; k0 += 2) {
|
| 205 |
-
if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
|
| 206 |
-
break;
|
| 207 |
-
}
|
| 208 |
-
|
| 209 |
-
half2 V_k;
|
| 210 |
-
reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
|
| 211 |
-
reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
|
| 212 |
-
#pragma unroll
|
| 213 |
-
for (int j = 0; j < ncols; ++j) {
|
| 214 |
-
VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
|
| 215 |
-
}
|
| 216 |
-
}
|
| 217 |
-
|
| 218 |
-
__syncthreads();
|
| 219 |
-
}
|
| 220 |
-
|
| 221 |
-
#pragma unroll
|
| 222 |
-
for (int j = 0; j < ncols; ++j) {
|
| 223 |
-
kqsum[j] = warp_reduce_sum(kqsum[j]);
|
| 224 |
-
if (threadIdx.x == 0) {
|
| 225 |
-
kqsum_shared[j][threadIdx.y] = kqsum[j];
|
| 226 |
-
}
|
| 227 |
-
}
|
| 228 |
-
|
| 229 |
-
__syncthreads();
|
| 230 |
-
|
| 231 |
-
#pragma unroll
|
| 232 |
-
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
| 233 |
-
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
|
| 234 |
-
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
|
| 235 |
-
|
| 236 |
-
half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
|
| 237 |
-
if (parallel_blocks == 1) {
|
| 238 |
-
dst_val /= kqsum[j_VKQ];
|
| 239 |
-
}
|
| 240 |
-
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
| 241 |
-
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
| 242 |
-
}
|
| 243 |
-
|
| 244 |
-
if (parallel_blocks != 1 && tid != 0) {
|
| 245 |
-
#pragma unroll
|
| 246 |
-
for (int j = 0; j < ncols; ++j) {
|
| 247 |
-
dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
|
| 248 |
-
}
|
| 249 |
-
}
|
| 250 |
-
#else
|
| 251 |
-
NO_DEVICE_CODE;
|
| 252 |
-
#endif // FP16_AVAILABLE
|
| 253 |
-
}
|
| 254 |
-
|
| 255 |
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
| 256 |
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
|
| 257 |
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
|
@@ -655,54 +413,6 @@ static __global__ void flash_attn_ext_f16(
|
|
| 655 |
#endif // FP16_MMA_AVAILABLE
|
| 656 |
}
|
| 657 |
|
| 658 |
-
template<int D, int parallel_blocks> // D == head size
|
| 659 |
-
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
| 660 |
-
__launch_bounds__(D, 1)
|
| 661 |
-
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
| 662 |
-
static __global__ void flash_attn_combine_results(
|
| 663 |
-
const float * __restrict__ VKQ_parts,
|
| 664 |
-
const float2 * __restrict__ VKQ_meta,
|
| 665 |
-
float * __restrict__ dst) {
|
| 666 |
-
#if FP16_AVAILABLE
|
| 667 |
-
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
|
| 668 |
-
VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
|
| 669 |
-
dst += D * gridDim.y*blockIdx.x;
|
| 670 |
-
|
| 671 |
-
const int tid = threadIdx.x;
|
| 672 |
-
__builtin_assume(tid < D);
|
| 673 |
-
|
| 674 |
-
__shared__ float2 meta[parallel_blocks];
|
| 675 |
-
if (tid < 2*parallel_blocks) {
|
| 676 |
-
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
|
| 677 |
-
}
|
| 678 |
-
|
| 679 |
-
__syncthreads();
|
| 680 |
-
|
| 681 |
-
float kqmax = meta[0].x;
|
| 682 |
-
#pragma unroll
|
| 683 |
-
for (int l = 1; l < parallel_blocks; ++l) {
|
| 684 |
-
kqmax = max(kqmax, meta[l].x);
|
| 685 |
-
}
|
| 686 |
-
|
| 687 |
-
float VKQ_numerator = 0.0f;
|
| 688 |
-
float VKQ_denominator = 0.0f;
|
| 689 |
-
#pragma unroll
|
| 690 |
-
for (int l = 0; l < parallel_blocks; ++l) {
|
| 691 |
-
const float diff = meta[l].x - kqmax;
|
| 692 |
-
const float KQ_max_scale = expf(diff);
|
| 693 |
-
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
| 694 |
-
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
| 695 |
-
|
| 696 |
-
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
|
| 697 |
-
VKQ_denominator += KQ_max_scale * meta[l].y;
|
| 698 |
-
}
|
| 699 |
-
|
| 700 |
-
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
|
| 701 |
-
#else
|
| 702 |
-
NO_DEVICE_CODE;
|
| 703 |
-
#endif // FP16_AVAILABLE
|
| 704 |
-
}
|
| 705 |
-
|
| 706 |
constexpr int get_max_power_of_2(int x) {
|
| 707 |
return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
|
| 708 |
}
|
|
@@ -727,66 +437,6 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
|
|
| 727 |
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
|
| 728 |
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
|
| 729 |
|
| 730 |
-
template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f16(
|
| 731 |
-
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
|
| 732 |
-
ggml_cuda_pool & pool, cudaStream_t main_stream
|
| 733 |
-
) {
|
| 734 |
-
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
| 735 |
-
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
| 736 |
-
|
| 737 |
-
if (parallel_blocks > 1) {
|
| 738 |
-
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
| 739 |
-
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
| 740 |
-
}
|
| 741 |
-
|
| 742 |
-
constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
|
| 743 |
-
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
| 744 |
-
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
|
| 745 |
-
const int shmem = 0;
|
| 746 |
-
|
| 747 |
-
float scale = 1.0f;
|
| 748 |
-
float max_bias = 0.0f;
|
| 749 |
-
|
| 750 |
-
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
|
| 751 |
-
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
|
| 752 |
-
|
| 753 |
-
const uint32_t n_head = Q->ne[2];
|
| 754 |
-
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
| 755 |
-
|
| 756 |
-
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 757 |
-
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 758 |
-
|
| 759 |
-
flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
|
| 760 |
-
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
| 761 |
-
(const char *) Q->data,
|
| 762 |
-
(const char *) K->data,
|
| 763 |
-
(const char *) V->data,
|
| 764 |
-
mask ? ((const char *) mask->data) : nullptr,
|
| 765 |
-
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
| 766 |
-
scale, max_bias, m0, m1, n_head_log2,
|
| 767 |
-
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
| 768 |
-
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
| 769 |
-
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
| 770 |
-
Q->nb[1], Q->nb[2], Q->nb[3],
|
| 771 |
-
K->nb[1], K->nb[2], K->nb[3],
|
| 772 |
-
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
| 773 |
-
);
|
| 774 |
-
CUDA_CHECK(cudaGetLastError());
|
| 775 |
-
|
| 776 |
-
if (parallel_blocks == 1) {
|
| 777 |
-
return;
|
| 778 |
-
}
|
| 779 |
-
|
| 780 |
-
const dim3 block_dim_combine(D, 1, 1);
|
| 781 |
-
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
| 782 |
-
const int shmem_combine = 0;
|
| 783 |
-
|
| 784 |
-
flash_attn_combine_results<D, parallel_blocks>
|
| 785 |
-
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
|
| 786 |
-
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
| 787 |
-
CUDA_CHECK(cudaGetLastError());
|
| 788 |
-
}
|
| 789 |
-
|
| 790 |
template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
|
| 791 |
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
|
| 792 |
ggml_cuda_pool & pool, cudaStream_t main_stream
|
|
@@ -891,95 +541,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|
| 891 |
|
| 892 |
const int32_t precision = KQV->op_params[2];
|
| 893 |
|
| 894 |
-
if (!
|
| 895 |
-
|
| 896 |
-
|
| 897 |
-
|
| 898 |
-
if (Q->ne[1] == 1) {
|
| 899 |
-
constexpr int cols_per_block = 1;
|
| 900 |
-
constexpr int parallel_blocks = 4;
|
| 901 |
-
switch (Q->ne[0]) {
|
| 902 |
-
case 64:
|
| 903 |
-
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 904 |
-
break;
|
| 905 |
-
case 128:
|
| 906 |
-
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 907 |
-
break;
|
| 908 |
-
default:
|
| 909 |
-
GGML_ASSERT(false);
|
| 910 |
-
break;
|
| 911 |
-
}
|
| 912 |
-
return;
|
| 913 |
-
}
|
| 914 |
-
|
| 915 |
-
if (Q->ne[1] == 2) {
|
| 916 |
-
constexpr int cols_per_block = 2;
|
| 917 |
-
constexpr int parallel_blocks = 4;
|
| 918 |
-
switch (Q->ne[0]) {
|
| 919 |
-
case 64:
|
| 920 |
-
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 921 |
-
break;
|
| 922 |
-
case 128:
|
| 923 |
-
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 924 |
-
break;
|
| 925 |
-
default:
|
| 926 |
-
GGML_ASSERT(false);
|
| 927 |
-
break;
|
| 928 |
-
}
|
| 929 |
-
return;
|
| 930 |
-
}
|
| 931 |
-
|
| 932 |
-
if (Q->ne[1] <= 4) {
|
| 933 |
-
constexpr int cols_per_block = 4;
|
| 934 |
-
constexpr int parallel_blocks = 4;
|
| 935 |
-
switch (Q->ne[0]) {
|
| 936 |
-
case 64:
|
| 937 |
-
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 938 |
-
break;
|
| 939 |
-
case 128:
|
| 940 |
-
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 941 |
-
break;
|
| 942 |
-
default:
|
| 943 |
-
GGML_ASSERT(false);
|
| 944 |
-
break;
|
| 945 |
-
}
|
| 946 |
-
return;
|
| 947 |
-
}
|
| 948 |
-
|
| 949 |
-
if (Q->ne[1] <= 8) {
|
| 950 |
-
constexpr int cols_per_block = 8;
|
| 951 |
-
constexpr int parallel_blocks = 4;
|
| 952 |
-
switch (Q->ne[0]) {
|
| 953 |
-
case 64:
|
| 954 |
-
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 955 |
-
break;
|
| 956 |
-
case 128:
|
| 957 |
-
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 958 |
-
break;
|
| 959 |
-
default:
|
| 960 |
-
GGML_ASSERT(false);
|
| 961 |
-
break;
|
| 962 |
-
}
|
| 963 |
-
return;
|
| 964 |
-
}
|
| 965 |
|
| 966 |
-
|
| 967 |
-
|
| 968 |
-
switch (Q->ne[0]) {
|
| 969 |
-
case 64:
|
| 970 |
-
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 971 |
-
break;
|
| 972 |
-
case 128:
|
| 973 |
-
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 974 |
-
break;
|
| 975 |
-
default:
|
| 976 |
-
GGML_ASSERT(false);
|
| 977 |
-
break;
|
| 978 |
-
}
|
| 979 |
return;
|
| 980 |
}
|
| 981 |
|
| 982 |
if (precision != GGML_PREC_DEFAULT) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 983 |
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
| 984 |
constexpr int cols_per_block = 16;
|
| 985 |
constexpr int nwarps = 4;
|
|
@@ -1037,22 +614,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|
| 1037 |
}
|
| 1038 |
|
| 1039 |
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
|
| 1040 |
-
|
| 1041 |
-
constexpr int parallel_blocks = 4;
|
| 1042 |
-
switch (Q->ne[0]) {
|
| 1043 |
-
case 64:
|
| 1044 |
-
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 1045 |
-
break;
|
| 1046 |
-
case 128:
|
| 1047 |
-
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 1048 |
-
break;
|
| 1049 |
-
case 256:
|
| 1050 |
-
launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
| 1051 |
-
break;
|
| 1052 |
-
default:
|
| 1053 |
-
GGML_ASSERT(false);
|
| 1054 |
-
break;
|
| 1055 |
-
}
|
| 1056 |
return;
|
| 1057 |
}
|
| 1058 |
|
|
|
|
| 1 |
#include "common.cuh"
|
| 2 |
+
#include "fattn-common.cuh"
|
| 3 |
+
#include "fattn-vec-f16.cuh"
|
| 4 |
+
#include "fattn-vec-f32.cuh"
|
| 5 |
#include "fattn.cuh"
|
| 6 |
|
| 7 |
#include <cstdint>
|
|
|
|
| 10 |
#include <mma.h>
|
| 11 |
#endif
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
| 14 |
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
|
| 15 |
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
|
|
|
| 413 |
#endif // FP16_MMA_AVAILABLE
|
| 414 |
}
|
| 415 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
constexpr int get_max_power_of_2(int x) {
|
| 417 |
return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
|
| 418 |
}
|
|
|
|
| 437 |
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
|
| 438 |
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
|
| 439 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
|
| 441 |
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
|
| 442 |
ggml_cuda_pool & pool, cudaStream_t main_stream
|
|
|
|
| 541 |
|
| 542 |
const int32_t precision = KQV->op_params[2];
|
| 543 |
|
| 544 |
+
if (!fast_fp16_available(cc)) {
|
| 545 |
+
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
| 546 |
+
return;
|
| 547 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 548 |
|
| 549 |
+
if (!fp16_mma_available(cc)) {
|
| 550 |
+
ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
return;
|
| 552 |
}
|
| 553 |
|
| 554 |
if (precision != GGML_PREC_DEFAULT) {
|
| 555 |
+
if (Q->ne[1] == 1 && (Q->ne[0] == 64 || Q->ne[0] == 128)) {
|
| 556 |
+
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
| 557 |
+
return;
|
| 558 |
+
}
|
| 559 |
+
|
| 560 |
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
| 561 |
constexpr int cols_per_block = 16;
|
| 562 |
constexpr int nwarps = 4;
|
|
|
|
| 614 |
}
|
| 615 |
|
| 616 |
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
|
| 617 |
+
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 618 |
return;
|
| 619 |
}
|
| 620 |
|