Commit ca79691
Parent(s): 2692ce5

CUDA: generalize FP16 fattn vec kernel (llama/7061)

* CUDA: generalize FP16 fattn vec kernel
* disable unsupported head sizes for AMD in test
* try AMD fix
* fix batch size 2-8
* partially revert changes

Files changed:
- ggml-cuda/common.cuh +124 -108
- ggml-cuda/fattn.cu +223 -82
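The vector kernel touched by this commit computes FlashAttention for a small group of Q columns at a time, carrying a running maximum (kqmax), a running softmax denominator (kqsum) and an unnormalized output accumulator (VKQ). For orientation, here is a scalar, host-side C++ sketch of that per-column computation; the function name, data layout and sizes are made up for this sketch and do not appear in the sources.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Scalar sketch of the online-softmax attention loop for one Q column.
// Q is assumed to be pre-scaled; K and V are n_kv rows of length D.
std::vector<float> attend_one_column(
        const std::vector<float> & Q,
        const std::vector<std::vector<float>> & K,
        const std::vector<std::vector<float>> & V) {
    const size_t D = Q.size();
    float kqmax = -1e30f;             // running maximum of the K*Q scores
    float kqsum = 0.0f;               // running softmax denominator
    std::vector<float> VKQ(D, 0.0f);  // unnormalized output accumulator

    for (size_t k = 0; k < K.size(); ++k) {
        float s = 0.0f;
        for (size_t i = 0; i < D; ++i) {
            s += K[k][i] * Q[i];
        }
        const float kqmax_new = std::max(kqmax, s);
        const float scale     = std::exp(kqmax - kqmax_new); // rescale what was accumulated so far
        const float val       = std::exp(s - kqmax_new);
        kqsum = kqsum*scale + val;
        for (size_t i = 0; i < D; ++i) {
            VKQ[i] = VKQ[i]*scale + val*V[k][i];
        }
        kqmax = kqmax_new;
    }
    for (size_t i = 0; i < D; ++i) {
        VKQ[i] /= kqsum; // final normalization (the kernel does this when parallel_blocks == 1)
    }
    return VKQ;
}

int main() {
    const std::vector<float> Q = {1.0f, 0.0f};
    const std::vector<std::vector<float>> K = {{1.0f, 0.0f}, {0.0f, 1.0f}};
    const std::vector<std::vector<float>> V = {{1.0f, 2.0f}, {3.0f, 4.0f}};
    const std::vector<float> out = attend_one_column(Q, K, V);
    printf("%f %f\n", out[0], out[1]); // a softmax-weighted average of the V rows
    return 0;
}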
ggml-cuda/common.cuh
CHANGED
Old version (removed and context lines; "…" marks removed text truncated in this view; the corresponding added lines appear under "New version" further down):

@@ -234,6 +234,97 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif //GGML_CUDA_F16

 [[noreturn]]
 static __device__ void no_device_code(
     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
@@ -275,16 +366,28 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 }

 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if …
 #pragma unroll
-…
-…
-…
-…
 #else
-…
-…
-…
 }

 static __device__ __forceinline__ float warp_reduce_max(float x) {
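The hunk above rewrites warp_reduce_sum for half2; in the new version further down the reduction is the usual XOR-butterfly over __shfl_xor_sync with masks 16, 8, 4, 2, 1. For intuition, here is a host-side C++ sketch of that pattern with the lane exchange simulated by an array; it is illustrative only and not part of the sources.

#include <cstdio>

int main() {
    const int WARP = 32;
    float lane[WARP];
    for (int i = 0; i < WARP; ++i) lane[i] = float(i + 1);   // 1 + 2 + ... + 32 = 528

    for (int mask = 16; mask > 0; mask >>= 1) {
        float next[WARP];
        for (int i = 0; i < WARP; ++i) {
            next[i] = lane[i] + lane[i ^ mask];   // partner lane index = lane index XOR mask
        }
        for (int i = 0; i < WARP; ++i) lane[i] = next[i];
    }
    // After log2(32) = 5 steps every lane holds the full warp sum.
    printf("%f\n", lane[0]);   // 528.000000
    return 0;
}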
@@ -296,20 +399,21 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
 }

 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
-#if …

-#if CUDART_VERSION …
-    return …
 #else
-    return …
-#endif // CUDART_VERSION …

 #else
-…
-…
-…
-#endif // …
 }
 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))


@@ -317,8 +421,8 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
     return __hmax2(a, b);
 #else
     half2 ret;
-    reinterpret_cast<half&>(ret.x) = …
-    reinterpret_cast<half&>(ret.y) = __high2float(a) …
     return ret;
 #endif // CUDART_VERSION >= CUDART_HMAX


@@ -326,7 +430,7 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
     GGML_UNUSED(a);
     GGML_UNUSED(b);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 }

 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
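The hunks above touch ggml_cuda_hmax2, whose fallback (shown in full under "New version" below) unpacks both half2 operands and takes the per-lane maximum when __hmax2 is unavailable. A small host-side sketch of that per-lane max, with plain floats standing in for the two half lanes; the struct and function names are made up for the sketch.

#include <algorithm>
#include <cstdio>

struct float_pair { float lo, hi; };   // stand-in for half2 in this sketch

// Per-lane maximum, the same operation the CUDART_HMAX fallback performs
// after converting each half lane to float.
static float_pair hmax2_ref(float_pair a, float_pair b) {
    return { std::max(a.lo, b.lo), std::max(a.hi, b.hi) };
}

int main() {
    const float_pair r = hmax2_ref({1.0f, -4.0f}, {0.5f, 2.0f});
    printf("%f %f\n", r.lo, r.hi);   // 1.000000 2.000000
    return 0;
}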
@@ -350,94 +454,6 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
 }
 #endif // CUDART_VERSION < 12000

-#if defined(GGML_USE_HIPBLAS)
-#define __CUDA_ARCH__ 1300
-
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
-    defined(__gfx1150__) || defined(__gfx1151__)
-#define RDNA3
-#endif
-
-#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
-    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
-#define RDNA2
-#endif
-
-#ifndef __has_builtin
-#define __has_builtin(x) 0
-#endif
-
-typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
-typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
-static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
-    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
-    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
-#if __has_builtin(__builtin_elementwise_sub_sat)
-    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
-    return reinterpret_cast<const int &>(c);
-#else
-    int8x4_t c;
-    int16_t tmp;
-#pragma unroll
-    for (int i = 0; i < 4; i++) {
-        tmp = va[i] - vb[i];
-        if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
-        if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
-        c[i] = tmp;
-    }
-    return reinterpret_cast<int &>(c);
-#endif // __has_builtin(__builtin_elementwise_sub_sat)
-}
-
-static __device__ __forceinline__ int __vsub4(const int a, const int b) {
-    return __vsubss4(a, b);
-}
-
-static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
-    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
-    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
-    unsigned int c;
-    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
-#pragma unroll
-    for (int i = 0; i < 4; ++i) {
-        vc[i] = va[i] == vb[i] ? 0xff : 0x00;
-    }
-    return c;
-}
-
-static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
-#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
-    c = __builtin_amdgcn_sdot4(a, b, c, false);
-#elif defined(RDNA3)
-    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
-#elif defined(__gfx1010__) || defined(__gfx900__)
-    int tmp1;
-    int tmp2;
-    asm("\n \
-        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
-        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
-        v_add3_u32 %0, %1, %2, %0 \n \
-        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
-        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
-        v_add3_u32 %0, %1, %2, %0 \n \
-        "
-        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
-        : "v"(a), "v"(b)
-    );
-#else
-    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
-    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
-    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
-#endif
-    return c;
-}
-#endif // defined(GGML_USE_HIPBLAS)
-
-#define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
-    defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
-
-#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
-
 // TODO: move to ggml-common.h
 static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

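The generic #else branch of the __dp4a shim above treats each 32-bit operand as four signed bytes, multiplies them pairwise and accumulates into c. Here is a standalone host-side version of that arithmetic for reference; the helper name is made up and the byte values in the comment assume a little-endian host.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Host-side check of the packed int8 dot product that __dp4a computes:
// treat each 32-bit int as four signed bytes, multiply pairwise, accumulate into c.
static int dp4a_ref(int a, int b, int c) {
    int8_t va[4], vb[4];
    std::memcpy(va, &a, 4);
    std::memcpy(vb, &b, 4);
    for (int i = 0; i < 4; ++i) {
        c += int(va[i]) * int(vb[i]);
    }
    return c;
}

int main() {
    // On a little-endian host these literals hold bytes (1,2,3,4) and (5,6,7,8):
    // (1,2,3,4) . (5,6,7,8) = 5 + 12 + 21 + 32 = 70
    const int a = 0x04030201;
    const int b = 0x08070605;
    printf("%d\n", dp4a_ref(a, b, 0)); // 70
    return 0;
}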
New version (added and context lines):

@@ +234,97 @@
 typedef float2 dfloat2;
 #endif //GGML_CUDA_F16

+#if defined(GGML_USE_HIPBLAS)
+#define __CUDA_ARCH__ 1300
+
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
+    defined(__gfx1150__) || defined(__gfx1151__)
+#define RDNA3
+#endif
+
+#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
+    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
+#define RDNA2
+#endif
+
+#ifndef __has_builtin
+#define __has_builtin(x) 0
+#endif
+
+typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
+static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+#if __has_builtin(__builtin_elementwise_sub_sat)
+    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
+    return reinterpret_cast<const int &>(c);
+#else
+    int8x4_t c;
+    int16_t tmp;
+#pragma unroll
+    for (int i = 0; i < 4; i++) {
+        tmp = va[i] - vb[i];
+        if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
+        if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
+        c[i] = tmp;
+    }
+    return reinterpret_cast<int &>(c);
+#endif // __has_builtin(__builtin_elementwise_sub_sat)
+}
+
+static __device__ __forceinline__ int __vsub4(const int a, const int b) {
+    return __vsubss4(a, b);
+}
+
+static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
+    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
+    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
+    unsigned int c;
+    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
+#pragma unroll
+    for (int i = 0; i < 4; ++i) {
+        vc[i] = va[i] == vb[i] ? 0xff : 0x00;
+    }
+    return c;
+}
+
+static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
+#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
+    c = __builtin_amdgcn_sdot4(a, b, c, false);
+#elif defined(RDNA3)
+    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
+#elif defined(__gfx1010__) || defined(__gfx900__)
+    int tmp1;
+    int tmp2;
+    asm("\n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        "
+        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
+        : "v"(a), "v"(b)
+    );
+#else
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
+#endif
+    return c;
+}
+#endif // defined(GGML_USE_HIPBLAS)
+
+#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+
+#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+
+static bool fp16_mma_available(const int cc) {
+    return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
+}
+
 [[noreturn]]
 static __device__ void no_device_code(
     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {

@@ +366,28 @@
 }

 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if FP16_AVAILABLE
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
+        reinterpret_cast<half&>(a.x) +=  __low2half(a_other);
+        reinterpret_cast<half&>(a.y) += __high2half(a_other);
+    }
+    return a;
 #else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+#else
+    NO_DEVICE_CODE;
+    return a;
+#endif // FP16_AVAILABLE
 }

 static __device__ __forceinline__ float warp_reduce_max(float x) {

@@ +399,21 @@
 }

 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
+#if FP16_AVAILABLE

+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+    return __float2half(fmaxf(__half2float(a), __half2float(b)));
 #else
+    return __hmax(a, b);
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX

 #else
+    NO_DEVICE_CODE;
+    GGML_UNUSED(b);
+    return a;
+#endif // FP16_AVAILABLE
 }
+
 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))


@@ +421,8 @@
     return __hmax2(a, b);
 #else
     half2 ret;
+    reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
+    reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
     return ret;
 #endif // CUDART_VERSION >= CUDART_HMAX


@@ +430,7 @@
     GGML_UNUSED(a);
     GGML_UNUSED(b);
     NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 }

 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {

@@ +454,6 @@
 }
 #endif // CUDART_VERSION < 12000

 // TODO: move to ggml-common.h
 static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

ggml-cuda/fattn.cu
CHANGED
Old version ("…" marks removed text truncated in this view; the corresponding added lines appear under "New version" further down):

@@ -11,8 +11,10 @@
 #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.

-template<int D, int parallel_blocks> // D == head size
-…
 static __global__ void flash_attn_vec_ext_f16(
     const char * __restrict__ Q,
     const char * __restrict__ K,
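A small numeric illustration of the comments on the two defines above: initializing the running KQ maxima with a finite sentinel avoids the NaN that (-inf) - (-inf) would produce in the exp() argument, and values below SOFTMAX_FTZ_THRESHOLD are already tiny, so flushing them to zero is harmless. Illustrative host-side code, not part of the sources.

#include <cmath>
#include <cstdio>

int main() {
    const float inf_sentinel    = -INFINITY;
    const float finite_sentinel = -65504.0f/2;   // -HALF_MAX_HALF, half of the fp16 maximum

    printf("%g\n", std::exp(inf_sentinel    - inf_sentinel));    // nan: the failure mode the comment warns about
    printf("%g\n", std::exp(finite_sentinel - finite_sentinel)); // 1: the finite sentinel is safe
    printf("%g\n", std::exp(-20.0f));  // ~2e-09, the order of magnitude at the SOFTMAX_FTZ_THRESHOLD cutoff
    return 0;
}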
@@ -44,55 +46,77 @@ static __global__ void flash_attn_vec_ext_f16(
 #if FP16_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

-    const int …
-    const int ip …

     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*…
     const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
     const half  * V_h  = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) mask + ne11*…

     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);

-…
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
-    __builtin_assume(tid < …

-    __shared__ half KQ[…
-…
     half2 * KQ2 = (half2 *) KQ;

-    half kqmax …
-…

-    __shared__ half kqmax_shared[WARP_SIZE];
-    __shared__ half kqsum_shared[WARP_SIZE];
-…
-…
-…
     }
     __syncthreads();

     // Convert Q to half2 and store in registers:
-    half2 Q_h2[…
 #pragma unroll
-    for (int …
-…
-…
-…
-    }

-…
     }

-    half2 VKQ = …

-    const int k_start …
     for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
-…
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
             const int i_KQ = i_KQ_0 + threadIdx.y;
@@ -101,89 +125,112 @@ static __global__ void flash_attn_vec_ext_f16(
                 break;
             }

-            half2 sum2 = …
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
                 const int k_KQ = k_KQ_0 + threadIdx.x;
-                if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) {
-                    break;
-                }

                 const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
-…
             }

-…
-…
-…
-…
-…
-…
         }
     }

-…
-…
-…
     }
     __syncthreads();
-    kqmax_new = kqmax_shared[threadIdx.x];
-    kqmax_new = warp_reduce_max(kqmax_new);

-…
-…

-…
-…
-…

-…

     __syncthreads();

-    if (tid < D) {
 #pragma unroll
-…
-…
-…
-…

-…
-…
-…
-…
         }
     }

         __syncthreads();
     }

-…
-…
     }

-    kqsum = warp_reduce_sum(kqsum);
-    if (threadIdx.x == 0) {
-        kqsum_shared[threadIdx.y] = kqsum;
-    }
     __syncthreads();
-    kqsum = kqsum_shared[threadIdx.x];
-    kqsum = warp_reduce_sum(kqsum);

-…
-…
-…

-…
-…
-…
     }
-    dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val;

-    if (parallel_blocks …
-…
     }
-    dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
 #else
     NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
@@ -191,7 +238,9 @@ static __global__ void flash_attn_vec_ext_f16(

 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 static __global__ void flash_attn_ext_f16(
     const char * __restrict__ Q,
     const char * __restrict__ K,

@@ -573,7 +622,9 @@ static __global__ void flash_attn_ext_f16(
 }

 template<int D, int parallel_blocks> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_combine_results(
     const float * __restrict__ VKQ_parts,
     const float2 * __restrict__ VKQ_meta,

@@ -642,7 +693,7 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");

-template <int D, int parallel_blocks> void launch_fattn_vec_f16(
     const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
     ggml_cuda_pool & pool, cudaStream_t main_stream
 ) {

@@ -656,13 +707,13 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(

     constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
     const int shmem = 0;

     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));

-    flash_attn_vec_ext_f16<D, parallel_blocks>
         <<<blocks_num, block_dim, shmem, main_stream>>> (
             (const char *) Q->data,
             (const char *) K->data,
@@ -783,10 +834,99 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst

     ggml_cuda_set_device(ctx.device);

     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;

     const int32_t precision = KQV->op_params[1];

     if (precision != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;

@@ -845,16 +985,17 @@
     }

     if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
         constexpr int parallel_blocks = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                 break;
             case 128:
-                launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                 break;
             case 256:
-                launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                 break;
             default:
                 GGML_ASSERT(false);
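The kernel selects the K/V head for a given Q head with blockIdx.y / gqa_ratio, where gqa_ratio = ne02 / ne12 (see the pointer setup in the hunks above and below). A tiny host-side sketch of that grouped-query-attention mapping, with made-up head counts:

#include <cstdio>

int main() {
    const int n_q_heads  = 8;   // corresponds to ne02 in the kernel (hypothetical value)
    const int n_kv_heads = 2;   // corresponds to ne12 in the kernel (hypothetical value)
    const int gqa_ratio  = n_q_heads / n_kv_heads;   // 4 Q heads share one K/V head

    for (int h = 0; h < n_q_heads; ++h) {
        printf("Q head %d -> K/V head %d\n", h, h / gqa_ratio);
    }
    return 0;
}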
New version (added and context lines):

@@ +11,10 @@
 #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.

+template<int D, int ncols, int parallel_blocks> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_vec_ext_f16(
     const char * __restrict__ Q,
     const char * __restrict__ K,

@@ +46,77 @@
 #if FP16_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

+    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.

     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
     const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.y / gqa_ratio));
     const half   * V_h   = (const half   *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *)  mask + ne11*ic0;

     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);

+    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+    constexpr int nwarps = D / WARP_SIZE;
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    __builtin_assume(tid < D);

+    __shared__ half KQ[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        KQ[j*D + tid] = -HALF_MAX_HALF;
+    }
     half2 * KQ2 = (half2 *) KQ;

+    half kqmax[ncols];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqmax[j] = -HALF_MAX_HALF;
+    }
+    half kqsum[ncols] = {0.0f};

+    __shared__ half kqmax_shared[ncols][WARP_SIZE];
+    __shared__ half kqsum_shared[ncols][WARP_SIZE];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (threadIdx.y == 0) {
+            kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
+            kqsum_shared[j][threadIdx.x] = 0.0f;
+        }
     }
     __syncthreads();

     // Convert Q to half2 and store in registers:
+    half2 Q_h2[ncols][D/(2*WARP_SIZE)];
 #pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;

+            const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
+            Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+        }
     }

+    half2 VKQ[ncols] = {{0.0f, 0.0f}};

+    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
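    // How the grid maps to the loop that follows: blockIdx.x encodes both the group
    // of ncols Q columns (blockIdx.x / parallel_blocks gives ic0) and which of the
    // parallel_blocks partial results this block produces (ip). Each iteration of
    // the loop processes a tile of D consecutive K/V rows; blocks working on the
    // same columns start at different offsets (ip*D) and stride by parallel_blocks*D,
    // so together they cover all ne11 rows.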
     for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
+
+        // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
+        // see https://github.com/ggerganov/llama.cpp/pull/7061 .
+        // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
+        half kqmax_new = kqmax[0];
+        half kqmax_new_arr[ncols];
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            kqmax_new_arr[j] = kqmax[j];
+        }
+
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
             const int i_KQ = i_KQ_0 + threadIdx.y;

@@ +125,112 @@
                 break;
             }

+            half2 sum2[ncols] = {{0.0f, 0.0f}};
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
                 const int k_KQ = k_KQ_0 + threadIdx.x;

                 const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+#pragma unroll
+                for (int j = 0; j < ncols; ++j) {
+                    sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE];
+                }
             }

+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                sum2[j] = warp_reduce_sum(sum2[j]);
+                half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
+                sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+
+                if (ncols == 1) {
+                    kqmax_new        = ggml_cuda_hmax(kqmax_new,        sum);
+                } else {
+                    kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
+                }
+
+                if (threadIdx.x == 0) {
+                    KQ[j*D + i_KQ] = sum;
+                }
             }
         }

+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
+
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
+            }
         }
+
         __syncthreads();

+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;

+            const half val = hexp(KQ[j*D + tid] - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale + val;
+            KQ[j*D + tid] = val;

+            VKQ[j] *= __half2half2(KQ_max_scale);
+        }

         __syncthreads();
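        // Online softmax bookkeeping for the block above: when a larger maximum
        // kqmax_new_j is found, everything accumulated with the previous maximum is
        // too large by the factor exp(kqmax_old - kqmax_new). Multiplying kqsum and
        // VKQ by KQ_max_scale = hexp(kqmax[j] - kqmax_new_j) before adding the new
        // terms keeps the running numerator and denominator consistent.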

 #pragma unroll
+        for (int k0 = 0; k0 < D; k0 += 2) {
+            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
+                break;
+            }

+            half2 V_k;
+            reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
+            reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
             }
         }

         __syncthreads();
     }

+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqsum[j] = warp_reduce_sum(kqsum[j]);
+        if (threadIdx.x == 0) {
+            kqsum_shared[j][threadIdx.y] = kqsum[j];
+        }
     }

     __syncthreads();

+#pragma unroll
+    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
+        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);

+        half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
+        if (parallel_blocks == 1) {
+            dst_val /= kqsum[j_VKQ];
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }

+    if (parallel_blocks != 1 && tid != 0) {
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
+        }
     }
 #else
     NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
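When parallel_blocks > 1, each block above writes an unnormalized partial result plus (kqmax, kqsum) metadata to dst_meta, and the separate flash_attn_combine_results kernel (declared in this file, body not shown in this diff) merges the partials. As an illustration of how such partials can be merged consistently, here is a host-side C++ sketch; it is not the actual implementation, and all names in it are made up.

#include <algorithm>
#include <cmath>
#include <cstdio>

struct Partial {
    float vkq;    // unnormalized partial numerator (one output element)
    float kqmax;  // maximum used inside the producing block
    float kqsum;  // partial softmax denominator, relative to kqmax
};

// Rescale every partial to a common maximum before summing numerators and denominators.
static float combine(const Partial * parts, int n) {
    float m = -1e30f;
    for (int i = 0; i < n; ++i) m = std::max(m, parts[i].kqmax);

    float num = 0.0f, den = 0.0f;
    for (int i = 0; i < n; ++i) {
        const float w = std::exp(parts[i].kqmax - m); // weight that moves partial i to the common maximum
        num += w * parts[i].vkq;
        den += w * parts[i].kqsum;
    }
    return num / den;
}

int main() {
    const Partial parts[2] = { {2.0f, 1.0f, 4.0f}, {6.0f, 3.0f, 5.0f} };
    printf("%f\n", combine(parts, 2));
    return 0;
}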

@@ +238,9 @@

 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_ext_f16(
     const char * __restrict__ Q,
     const char * __restrict__ K,

@@ +622,9 @@
 }

 template<int D, int parallel_blocks> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_combine_results(
     const float * __restrict__ VKQ_parts,
     const float2 * __restrict__ VKQ_meta,

@@ +693,7 @@
 static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");

+template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f16(
     const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
     ggml_cuda_pool & pool, cudaStream_t main_stream
 ) {

@@ +707,13 @@

     constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
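    // Grid sizing: each block handles cols_per_block Q columns, so the number of
    // column groups is the integer round-up (Q->ne[1] + cols_per_block - 1) / cols_per_block,
    // multiplied by parallel_blocks because several blocks cooperate on the same columns.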
     const int shmem = 0;

     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));

+    flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
         <<<blocks_num, block_dim, shmem, main_stream>>> (
             (const char *) Q->data,
             (const char *) K->data,

@@ +834,99 @@

     ggml_cuda_set_device(ctx.device);

+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;

     const int32_t precision = KQV->op_params[1];

+    if (!fp16_mma_available(cc)) {
+        GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+        GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+
+        if (Q->ne[1] == 1) {
+            constexpr int cols_per_block = 1;
+            constexpr int parallel_blocks = 4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+            return;
+        }
+
+        if (Q->ne[1] == 2) {
+            constexpr int cols_per_block = 2;
+            constexpr int parallel_blocks = 4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+            return;
+        }
+
+        if (Q->ne[1] <= 4) {
+            constexpr int cols_per_block = 4;
+            constexpr int parallel_blocks = 4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+            return;
+        }
+
+        if (Q->ne[1] <= 8) {
+            constexpr int cols_per_block = 8;
+            constexpr int parallel_blocks = 4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+            return;
+        }
+
+        constexpr int cols_per_block = 8;
+        constexpr int parallel_blocks = 1;
+        switch (Q->ne[0]) {
+            case 64:
+                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                break;
+            case 128:
+                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        }
+        return;
+    }
+
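    // Dispatch summary for the path above (no FP16 tensor cores, fp16_mma_available(cc)
    // is false): only head sizes 64 and 128 are supported, batch sizes 1, 2, <=4 and <=8
    // select the vector kernel with cols_per_block 1, 2, 4 and 8 (parallel_blocks = 4),
    // and larger batches fall back to cols_per_block = 8 with parallel_blocks = 1.
    // The code below handles devices that do have FP16 tensor cores.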
     if (precision != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;

@@ +985,17 @@
     }

     if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+        constexpr int cols_per_block = 1;
         constexpr int parallel_blocks = 4;
         switch (Q->ne[0]) {
             case 64:
+                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                 break;
             case 128:
+                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                 break;
             case 256:
+                launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                 break;
             default:
                 GGML_ASSERT(false);