diff --git "a/ggml.c" "b/ggml.c" --- "a/ggml.c" +++ "b/ggml.c" @@ -1,10 +1,8 @@ #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows +#define _USE_MATH_DEFINES // For M_PI on MSVC -#include "ggml.h" - -#ifdef GGML_USE_K_QUANTS -#include "k_quants.h" -#endif +#include "ggml-impl.h" +#include "ggml-quants.h" #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW @@ -30,18 +28,6 @@ #include #endif -// static_assert should be a #define, but if it's not, -// fall back to the _Static_assert C11 keyword. -// if C99 - static_assert is noop -// ref: https://stackoverflow.com/a/53923785/4039976 -#ifndef static_assert -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) -#define static_assert(cond, msg) _Static_assert(cond, msg) -#else -#define static_assert(cond, msg) struct global_scope_noop_trick -#endif -#endif - #if defined(_MSC_VER) // disable "possible loss of data" to avoid hundreds of casts // we should just be careful :) @@ -89,7 +75,9 @@ static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(vo static int pthread_join(pthread_t thread, void * unused) { (void) unused; - return (int) WaitForSingleObject(thread, INFINITE); + int ret = (int) WaitForSingleObject(thread, INFINITE); + CloseHandle(thread); + return ret; } static int sched_yield (void) { @@ -107,23 +95,60 @@ typedef void * thread_ret_t; #include #endif + #ifdef GGML_USE_CPU_HBM #include #endif -// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 -#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) -#ifndef __FMA__ -#define __FMA__ -#endif -#ifndef __F16C__ -#define __F16C__ -#endif -#ifndef __SSE3__ -#define __SSE3__ +#if defined(__APPLE__) +#include #endif + +#if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \ + (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH)) + +#include + +void ggml_print_backtrace(void) { + /* + #include + #include + + void * trace[100]; + + int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0])); + + backtrace_symbols_fd(trace, nptrs, STDERR_FILENO); + */ + + // backtrack_symbols does not show line numbers, use gdb instead + char attach[32]; + snprintf(attach, sizeof(attach), "attach %d", getpid()); + int pid = fork(); + if (pid == 0) { + execlp("gdb", "gdb", "--batch", + "-ex", "set style enabled on", + "-ex", attach, + "-ex", "bt -frame-info source-and-location", + "-ex", "detach", + "-ex", "quit", + NULL); + } else { + waitpid(pid, NULL, 0); + } +} +#else +void ggml_print_backtrace(void) { + // platform not supported +} #endif +#undef MIN +#undef MAX + +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + /*#define GGML_PERF*/ #define GGML_DEBUG 0 #define GGML_GELU_FP16 @@ -134,6 +159,7 @@ typedef void * thread_ret_t; #define GGML_SOFT_MAX_UNROLL 4 #define GGML_VEC_DOT_UNROLL 2 +#define GGML_VEC_MAD_UNROLL 32 // // logging @@ -159,40 +185,16 @@ typedef void * thread_ret_t; #define GGML_PRINT(...) printf(__VA_ARGS__) +// +// end of logging block +// + #ifdef GGML_USE_ACCELERATE // uncomment to use vDSP for soft max computation // note: not sure if it is actually faster //#define GGML_SOFT_MAX_ACCELERATE #endif -// -// logging -// - -#if (GGML_DEBUG >= 1) -#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG(...) -#endif - -#if (GGML_DEBUG >= 5) -#define GGML_PRINT_DEBUG_5(...) 
printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_5(...) -#endif - -#if (GGML_DEBUG >= 10) -#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_10(...) -#endif - -#define GGML_PRINT(...) printf(__VA_ARGS__) - -// -// end of logging block -// - #if defined(_MSC_VER) || defined(__MINGW32__) #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN) #define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr) @@ -242,18 +244,18 @@ inline static void * ggml_aligned_malloc(size_t size) { // #define GGML_TENSOR_UNARY_OP_LOCALS \ - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \ - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \ - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \ - GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) #define GGML_TENSOR_BINARY_OP_LOCALS \ - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \ - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \ - GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); \ - GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); \ - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \ - GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) #if defined(GGML_USE_ACCELERATE) #include @@ -272,228 +274,27 @@ inline static void * ggml_aligned_malloc(size_t size) { #include "ggml-opencl.h" #endif -#undef MIN -#undef MAX -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - // floating point type used to accumulate sums typedef double ggml_float; -// 16-bit float -// on Arm, we use __fp16 -// on x86, we use uint16_t -#ifdef __ARM_NEON - -// if YCM cannot find , make a symbolic link to it, for example: -// -// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ -// -#include - -#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x)) -#define GGML_COMPUTE_FP32_TO_FP16(x) (x) - -#define GGML_FP16_TO_FP32(x) ((float) (x)) -#define GGML_FP32_TO_FP16(x) (x) - -#else - -#ifdef __wasm_simd128__ -#include -#else -#ifdef __POWER9_VECTOR__ -#include -#undef bool -#define bool _Bool -#else -#if defined(_MSC_VER) || defined(__MINGW32__) -#include -#else -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) -#if !defined(__riscv) -#include -#endif -#endif -#endif -#endif -#endif - -#ifdef __riscv_v_intrinsic -#include -#endif - -#ifdef __F16C__ - -#ifdef _MSC_VER -#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) -#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) -#else -#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) -#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) -#endif - -#elif defined(__POWER9_VECTOR__) - -#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) -#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) -/* the inline asm below is about 12% faster than the lookup method */ -#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x) -#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) - -static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { - register float f; - register double d; - __asm__( - 
"mtfprd %0,%2\n" - "xscvhpdp %0,%0\n" - "frsp %1,%0\n" : - /* temp */ "=d"(d), - /* out */ "=f"(f): - /* in */ "r"(h)); - return f; -} - -static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { - register double d; - register ggml_fp16_t r; - __asm__( /* xscvdphp can work on double or single precision */ - "xscvdphp %0,%2\n" - "mffprd %1,%0\n" : - /* temp */ "=d"(d), - /* out */ "=r"(r): - /* in */ "f"(f)); - return r; -} - -#else - -// FP16 <-> FP32 -// ref: https://github.com/Maratyszcza/FP16 - -static inline float fp32_from_bits(uint32_t w) { - union { - uint32_t as_bits; - float as_value; - } fp32; - fp32.as_bits = w; - return fp32.as_value; -} - -static inline uint32_t fp32_to_bits(float f) { - union { - float as_value; - uint32_t as_bits; - } fp32; - fp32.as_value = f; - return fp32.as_bits; -} - -static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { - const uint32_t w = (uint32_t) h << 16; - const uint32_t sign = w & UINT32_C(0x80000000); - const uint32_t two_w = w + w; - - const uint32_t exp_offset = UINT32_C(0xE0) << 23; -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float exp_scale = 0x1.0p-112f; -#else - const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); -#endif - const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; - - const uint32_t magic_mask = UINT32_C(126) << 23; - const float magic_bias = 0.5f; - const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; - - const uint32_t denormalized_cutoff = UINT32_C(1) << 27; - const uint32_t result = sign | - (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); - return fp32_from_bits(result); -} - -static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float scale_to_inf = 0x1.0p+112f; - const float scale_to_zero = 0x1.0p-110f; -#else - const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); - const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); -#endif - float base = (fabsf(f) * scale_to_inf) * scale_to_zero; - - const uint32_t w = fp32_to_bits(f); - const uint32_t shl1_w = w + w; - const uint32_t sign = w & UINT32_C(0x80000000); - uint32_t bias = shl1_w & UINT32_C(0xFF000000); - if (bias < UINT32_C(0x71000000)) { - bias = UINT32_C(0x71000000); - } - - base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; - const uint32_t bits = fp32_to_bits(base); - const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); - const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); - const uint32_t nonsign = exp_bits + mantissa_bits; - return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? 
UINT16_C(0x7E00) : nonsign); -} - -#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) -#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) - -#endif // __F16C__ - -#endif // __ARM_NEON - // // global data // // precomputed gelu table for f16 (128 KB) -static ggml_fp16_t table_gelu_f16[1 << 16]; +static ggml_fp16_t ggml_table_gelu_f16[1 << 16]; // precomputed quick gelu table for f16 (128 KB) -static ggml_fp16_t table_gelu_quick_f16[1 << 16]; +static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16]; // precomputed silu table for f16 (128 KB) -static ggml_fp16_t table_silu_f16[1 << 16]; +static ggml_fp16_t ggml_table_silu_f16[1 << 16]; // precomputed exp table for f16 (128 KB) -static ggml_fp16_t table_exp_f16[1 << 16]; - -// precomputed f32 table for f16 (256 KB) -static float table_f32_f16[1 << 16]; - -#if defined(__ARM_NEON) || defined(__wasm_simd128__) -#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s -#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) -#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) -#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) -#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) -#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) -#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) -#define B8(c,s ) B7(c,s, c), B7(c,s, s) - -// precomputed tables for expanding 8bits to 8 bytes: -static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 -static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 -#endif - -// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, -// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. -// This is also true for POWER9. -#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) - -inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { - uint16_t s; - memcpy(&s, &f, sizeof(uint16_t)); - return table_f32_f16[s]; -} +static ggml_fp16_t ggml_table_exp_f16[1 << 16]; -#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) -#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) - -#endif +// precomputed f32 table for f16 (256 KB) (ggml-impl.h) +float ggml_table_f32_f16[1 << 16]; // note: do not use these inside ggml.c // these are meant to be used via the ggml.h API @@ -592,7 +393,6 @@ int64_t ggml_cycles_per_ms(void) { #define ggml_perf_cycles_per_ms() 0 #endif - // // cache line // @@ -609,1009 +409,8 @@ int64_t ggml_cycles_per_ms(void) { static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); -// -// quantization -// - -#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) - -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) -// multiply int8_t, add results pairwise twice -static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { - // Get absolute values of x vectors - const __m128i ax = _mm_sign_epi8(x, x); - // Sign the values of the y vectors - const __m128i sy = _mm_sign_epi8(y, x); - // Perform multiplication and create 16-bit values - const __m128i dot = _mm_maddubs_epi16(ax, sy); - const __m128i ones = _mm_set1_epi16(1); - return _mm_madd_epi16(ones, dot); -} - -#if __AVX__ || __AVX2__ || __AVX512F__ -// horizontally add 8 floats -static inline float hsum_float_8(const __m256 x) { - __m128 res = _mm256_extractf128_ps(x, 1); - res = _mm_add_ps(res, _mm256_castps256_ps128(x)); - res = _mm_add_ps(res, _mm_movehl_ps(res, res)); - res = _mm_add_ss(res, _mm_movehdup_ps(res)); - return _mm_cvtss_f32(res); 
-} - -// horizontally add 8 int32_t -static inline int hsum_i32_8(const __m256i a) { - const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); - const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); - const __m128i sum64 = _mm_add_epi32(hi64, sum128); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} - -// horizontally add 4 int32_t -static inline int hsum_i32_4(const __m128i a) { - const __m128i hi64 = _mm_unpackhi_epi64(a, a); - const __m128i sum64 = _mm_add_epi32(hi64, a); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} - -#if defined(__AVX2__) || defined(__AVX512F__) -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m256i shuf_mask = _mm256_set_epi64x( - 0x0303030303030303, 0x0202020202020202, - 0x0101010101010101, 0x0000000000000000); - __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); - const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); - bytes = _mm256_or_si256(bytes, bit_mask); - return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) -{ - const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); - const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); - const __m256i lowMask = _mm256_set1_epi8( 0xF ); - return _mm256_and_si256(lowMask, bytes); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m256i x) { - const __m256i ones = _mm256_set1_epi16(1); - const __m256i summed_pairs = _mm256_madd_epi16(ones, x); - return _mm256_cvtepi32_ps(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { -#if __AVXVNNI__ - const __m256i zero = _mm256_setzero_si256(); - const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); - return _mm256_cvtepi32_ps(summed_pairs); -#else - // Perform multiplication and create 16-bit values - const __m256i dot = _mm256_maddubs_epi16(ax, sy); - return sum_i16_pairs_float(dot); -#endif -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { -#if __AVXVNNIINT8__ - const __m256i zero = _mm256_setzero_si256(); - const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); - return _mm256_cvtepi32_ps(summed_pairs); -#else - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8(x, x); - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8(y, x); - return mul_sum_us8_pairs_float(ax, sy); -#endif -} - -static inline __m128i packNibbles( __m256i bytes ) -{ - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh -#if __AVX512F__ - const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 - bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh - return _mm256_cvtepi16_epi8(bytes); // abcd_efgh -#else - const __m256i lowByte = _mm256_set1_epi16( 0xFF ); - __m256i high = _mm256_andnot_si256( lowByte, bytes ); - __m256i low = _mm256_and_si256( lowByte, bytes ); - high = _mm256_srli_epi16( high, 4 
); - bytes = _mm256_or_si256( low, high ); - - // Compress uint16_t lanes into bytes - __m128i r0 = _mm256_castsi256_si128( bytes ); - __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); - return _mm_packus_epi16( r0, r1 ); -#endif -} -#elif defined(__AVX__) -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); - const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202); - __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl); - __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh); - const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe); - bytesl = _mm_or_si128(bytesl, bit_mask); - bytesh = _mm_or_si128(bytesh, bit_mask); - bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); - bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); - return MM256_SET_M128I(bytesh, bytesl); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) -{ - // Load 16 bytes from memory - __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi); - __m128i tmph = _mm_srli_epi16(tmpl, 4); - const __m128i lowMask = _mm_set1_epi8(0xF); - tmpl = _mm_and_si128(lowMask, tmpl); - tmph = _mm_and_si128(lowMask, tmph); - return MM256_SET_M128I(tmph, tmpl); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { - const __m128i ones = _mm_set1_epi16(1); - const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); - const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); - const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); - return _mm256_cvtepi32_ps(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { - const __m128i axl = _mm256_castsi256_si128(ax); - const __m128i axh = _mm256_extractf128_si256(ax, 1); - const __m128i syl = _mm256_castsi256_si128(sy); - const __m128i syh = _mm256_extractf128_si256(sy, 1); - // Perform multiplication and create 16-bit values - const __m128i dotl = _mm_maddubs_epi16(axl, syl); - const __m128i doth = _mm_maddubs_epi16(axh, syh); - return sum_i16_pairs_float(doth, dotl); -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { - const __m128i xl = _mm256_castsi256_si128(x); - const __m128i xh = _mm256_extractf128_si256(x, 1); - const __m128i yl = _mm256_castsi256_si128(y); - const __m128i yh = _mm256_extractf128_si256(y, 1); - // Get absolute values of x vectors - const __m128i axl = _mm_sign_epi8(xl, xl); - const __m128i axh = _mm_sign_epi8(xh, xh); - // Sign the values of the y vectors - const __m128i syl = _mm_sign_epi8(yl, xl); - const __m128i syh = _mm_sign_epi8(yh, xh); - // Perform multiplication and create 16-bit values - const __m128i dotl = _mm_maddubs_epi16(axl, syl); - const __m128i doth = _mm_maddubs_epi16(axh, syh); - return sum_i16_pairs_float(doth, dotl); -} - -static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) -{ - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh - const __m128i lowByte = _mm_set1_epi16( 0xFF ); - __m128i high = _mm_andnot_si128( lowByte, bytes1 ); - __m128i low = _mm_and_si128( lowByte, 
bytes1 ); - high = _mm_srli_epi16( high, 4 ); - bytes1 = _mm_or_si128( low, high ); - high = _mm_andnot_si128( lowByte, bytes2 ); - low = _mm_and_si128( lowByte, bytes2 ); - high = _mm_srli_epi16( high, 4 ); - bytes2 = _mm_or_si128( low, high ); - - return _mm_packus_epi16( bytes1, bytes2); -} -#endif -#elif defined(__SSSE3__) -// horizontally add 4x4 floats -static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { - __m128 res_0 =_mm_hadd_ps(a, b); - __m128 res_1 =_mm_hadd_ps(c, d); - __m128 res =_mm_hadd_ps(res_0, res_1); - res =_mm_hadd_ps(res, res); - res =_mm_hadd_ps(res, res); - - return _mm_cvtss_f32(res); -} -#endif // __AVX__ || __AVX2__ || __AVX512F__ -#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) - -#if defined(__ARM_NEON) - -#if !defined(__aarch64__) - -inline static int32_t vaddvq_s32(int32x4_t v) { - return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); -} - -inline static float vaddvq_f32(float32x4_t v) { - return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); -} - -inline static float vmaxvq_f32(float32x4_t v) { - return - MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), - MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); -} - -inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { - int32x4_t res; - - res[0] = roundf(vgetq_lane_f32(v, 0)); - res[1] = roundf(vgetq_lane_f32(v, 1)); - res[2] = roundf(vgetq_lane_f32(v, 2)); - res[3] = roundf(vgetq_lane_f32(v, 3)); - - return res; -} - -#endif -#endif - -#define QK4_0 32 -typedef struct { - ggml_fp16_t d; // delta - uint8_t qs[QK4_0 / 2]; // nibbles / quants -} block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); - -#define QK4_1 32 -typedef struct { - ggml_fp16_t d; // delta - ggml_fp16_t m; // min - uint8_t qs[QK4_1 / 2]; // nibbles / quants -} block_q4_1; -static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding"); - -#define QK5_0 32 -typedef struct { - ggml_fp16_t d; // delta - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_0 / 2]; // nibbles / quants -} block_q5_0; -static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); - -#define QK5_1 32 -typedef struct { - ggml_fp16_t d; // delta - ggml_fp16_t m; // min - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_1 / 2]; // nibbles / quants -} block_q5_1; -static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); - -#define QK8_0 32 -typedef struct { - ggml_fp16_t d; // delta - int8_t qs[QK8_0]; // quants -} block_q8_0; -static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); - -#define QK8_1 32 -typedef struct { - float d; // delta - float s; // d * sum(qs[i]) - int8_t qs[QK8_1]; // quants -} block_q8_1; -static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); - -// reference implementation for deterministic creation of model files -static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { - static const int qk = QK4_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < qk; j++) { - const float v = x[i*qk + 
j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - max = v; - } - } - - const float d = max / -8; - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - for (int j = 0; j < qk/2; ++j) { - const float x0 = x[i*qk + 0 + j]*id; - const float x1 = x[i*qk + qk/2 + j]*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); - - y[i].qs[j] = xi0; - y[i].qs[j] |= xi1 << 4; - } - } -} - -static void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { - quantize_row_q4_0_reference(x, y, k); -} - -static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) { - const int qk = QK4_1; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - float min = FLT_MAX; - float max = -FLT_MAX; - - for (int j = 0; j < qk; j++) { - const float v = x[i*qk + j]; - - if (v < min) min = v; - if (v > max) max = v; - } - - const float d = (max - min) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - y[i].m = GGML_FP32_TO_FP16(min); - - for (int j = 0; j < qk/2; ++j) { - const float x0 = (x[i*qk + 0 + j] - min)*id; - const float x1 = (x[i*qk + qk/2 + j] - min)*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); - - y[i].qs[j] = xi0; - y[i].qs[j] |= xi1 << 4; - } - } -} - -static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { - quantize_row_q4_1_reference(x, y, k); -} - -static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { - static const int qk = QK5_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < qk; j++) { - const float v = x[i*qk + j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - max = v; - } - } - - const float d = max / -16; - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - uint32_t qh = 0; - - for (int j = 0; j < qk/2; ++j) { - const float x0 = x[i*qk + 0 + j]*id; - const float x1 = x[i*qk + qk/2 + j]*id; - - const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); - const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); - - y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); - - // get the 5-th bit and store it in qh at the right position - qh |= ((xi0 & 0x10) >> 4) << (j + 0); - qh |= ((xi1 & 0x10) >> 4) << (j + qk/2); - } - - memcpy(&y[i].qh, &qh, sizeof(qh)); - } -} - -static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) { - quantize_row_q5_0_reference(x, y, k); -} - -static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { - const int qk = QK5_1; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - float min = FLT_MAX; - float max = -FLT_MAX; - - for (int j = 0; j < qk; j++) { - const float v = x[i*qk + j]; - - if (v < min) min = v; - if (v > max) max = v; - } - - const float d = (max - min) / ((1 << 5) - 1); - const float id = d ? 
1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - y[i].m = GGML_FP32_TO_FP16(min); - - uint32_t qh = 0; - - for (int j = 0; j < qk/2; ++j) { - const float x0 = (x[i*qk + 0 + j] - min)*id; - const float x1 = (x[i*qk + qk/2 + j] - min)*id; - - const uint8_t xi0 = (uint8_t)(x0 + 0.5f); - const uint8_t xi1 = (uint8_t)(x1 + 0.5f); - - y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); - - // get the 5-th bit and store it in qh at the right position - qh |= ((xi0 & 0x10) >> 4) << (j + 0); - qh |= ((xi1 & 0x10) >> 4) << (j + qk/2); - } - - memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); - } -} - -static void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) { - quantize_row_q5_1_reference(x, y, k); -} - -// reference implementation for deterministic creation of model files -static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - const float v = x[i*QK8_0 + j]; - amax = MAX(amax, fabsf(v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - for (int j = 0; j < QK8_0; ++j) { - const float x0 = x[i*QK8_0 + j]*id; - - y[i].qs[j] = roundf(x0); - } - } -} - -static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - for (int i = 0; i < nb; i++) { - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); - - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - } - } -#elif defined(__wasm_simd128__) - for (int i = 0; i < nb; i++) { - v128_t srcv [8]; - v128_t asrcv[8]; - v128_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), - wasm_f32x4_extract_lane(amaxv[0], 1)), - MAX(wasm_f32x4_extract_lane(amaxv[0], 2), - wasm_f32x4_extract_lane(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 
1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); - const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); - - y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); - y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); - y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); - y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); - } - } -#elif defined(__AVX2__) || defined(__AVX__) - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - // Quantize these floats - const float d = maxScalar / 127.f; - y[i].d = GGML_FP32_TO_FP16(d); - const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - -#if defined(__AVX2__) - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)y[i].qs, i0); -#else - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = _mm256_extractf128_si256( i1, 1); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, 
ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - - _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); - _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); -#endif - } -#else - // scalar - quantize_row_q8_0_reference(x, y, k); -#endif -} - -// reference implementation for deterministic creation of model files -static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) { - assert(QK8_1 == 32); - assert(k % QK8_1 == 0); - const int nb = k / QK8_1; - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_1; j++) { - const float v = x[i*QK8_1 + j]; - amax = MAX(amax, fabsf(v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = d; - - int sum = 0; - - for (int j = 0; j < QK8_1/2; ++j) { - const float v0 = x[i*QK8_1 + j]*id; - const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id; - - y[i].qs[ j] = roundf(v0); - y[i].qs[QK8_1/2 + j] = roundf(v1); - - sum += y[i].qs[ j]; - sum += y[i].qs[QK8_1/2 + j]; - } - - y[i].s = sum*d; - } -} - -static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { - assert(k % QK8_1 == 0); - const int nb = k / QK8_1; - - block_q8_1 * restrict y = vy; - -#if defined(__ARM_NEON) - for (int i = 0; i < nb; i++) { - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = d; - - int32x4_t accv = vdupq_n_s32(0); - - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); - - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - - accv = vaddq_s32(accv, vi); - } - - y[i].s = d * vaddvq_s32(accv); - } -#elif defined(__wasm_simd128__) - for (int i = 0; i < nb; i++) { - v128_t srcv [8]; - v128_t asrcv[8]; - v128_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), - wasm_f32x4_extract_lane(amaxv[0], 1)), - MAX(wasm_f32x4_extract_lane(amaxv[0], 2), - wasm_f32x4_extract_lane(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 
1.0f/d : 0.0f; - - y[i].d = d; - - v128_t accv = wasm_i32x4_splat(0); - - for (int j = 0; j < 8; j++) { - const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); - const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); - - y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); - y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); - y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); - y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); - - accv = wasm_i32x4_add(accv, vi); - } - - y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) + - wasm_i32x4_extract_lane(accv, 1) + - wasm_i32x4_extract_lane(accv, 2) + - wasm_i32x4_extract_lane(accv, 3)); - } -#elif defined(__AVX2__) || defined(__AVX__) - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - // Quantize these floats - const float d = maxScalar / 127.f; - y[i].d = d; - const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - -#if defined(__AVX2__) - // Compute the sum of the quants and set y[i].s - y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); - - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)y[i].qs, i0); -#else - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = 
_mm256_extractf128_si256( i1, 1); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1); - - // Compute the sum of the quants and set y[i].s - const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); - const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); - y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1)); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - - _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); - _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); -#endif - } -#else - // scalar - quantize_row_q8_1_reference(x, y, k); -#endif -} - -static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) { - static const int qk = QK4_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0x0F) - 8; - const int x1 = (x[i].qs[j] >> 4) - 8; - - y[i*qk + j + 0 ] = x0*d; - y[i*qk + j + qk/2] = x1*d; - } - } -} - -static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) { - static const int qk = QK4_1; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); - const float m = GGML_FP16_TO_FP32(x[i].m); - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0x0F); - const int x1 = (x[i].qs[j] >> 4); - - y[i*qk + j + 0 ] = x0*d + m; - y[i*qk + j + qk/2] = x1*d + m; - } - } -} - -static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) { - static const int qk = QK5_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); - - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; - const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; - - y[i*qk + j + 0 ] = x0*d; - y[i*qk + j + qk/2] = x1*d; - } - } -} - -static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) { - static const int qk = QK5_1; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); - const float m = GGML_FP16_TO_FP32(x[i].m); - - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int x0 = (x[i].qs[j] & 0x0F) | xh_0; - const int x1 = (x[i].qs[j] >> 4) | xh_1; - - y[i*qk + j + 0 ] = x0*d + m; - y[i*qk + j + qk/2] = x1*d + m; - } - } -} - -static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) { - static const int qk = QK8_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - const block_q8_0 * restrict x = vx; - - for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); - - for (int j = 0; j < qk; ++j) { - y[i*qk + j] = 
x[i].qs[j]*d; - } - } -} - -static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y); -static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y); -static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y); +static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y); static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { [GGML_TYPE_I8] = { @@ -1673,6 +472,28 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q4_1_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, }, + [4] = { // GGML_TYPE_Q4_2 + .type_name = "DEPRECATED", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + .to_float = NULL, + .from_float = NULL, + .from_float_reference = NULL, + .vec_dot = NULL, + .vec_dot_type = GGML_TYPE_COUNT, + }, + [5] = { // GGML_TYPE_Q4_3 + .type_name = "DEPRECATED", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + .to_float = NULL, + .from_float = NULL, + .from_float_reference = NULL, + .vec_dot = NULL, + .vec_dot_type = GGML_TYPE_COUNT, + }, [GGML_TYPE_Q5_0] = { .type_name = "q5_0", .blck_size = QK5_0, @@ -1700,7 +521,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = QK8_0, .type_size = sizeof(block_q8_0), .is_quantized = true, - .to_float = dequantize_row_q8_0, + .to_float = (ggml_to_float_t) dequantize_row_q8_0, .from_float = quantize_row_q8_0, .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference, .vec_dot = ggml_vec_dot_q8_0_q8_0, @@ -1715,7 +536,6 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference, .vec_dot_type = GGML_TYPE_Q8_1, }, -#ifdef GGML_USE_K_QUANTS [GGML_TYPE_Q2_K] = { .type_name = "q2_K", .blck_size = QK_K, @@ -1778,7 +598,6 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .from_float = quantize_row_q8_K, } -#endif }; // For internal test use @@ -1787,11 +606,22 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { return type_traits[type]; } - // // simd mappings // +#if defined(__ARM_NEON) +#if !defined(__aarch64__) + +// 64-bit compatibility + +inline static float vaddvq_f32(float32x4_t v) { + return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); +} + +#endif +#endif + // we define a common set of C macros which map to specific intrinsics based on the current architecture // we then implement the fundamental computation operations below using only these macros // adding support for new architectures requires to define the corresponding SIMD macros @@ -1863,7 +693,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { #define 
GGML_F16x8_ADD vaddq_f16 #define GGML_F16x8_MUL vmulq_f16 #define GGML_F16x8_REDUCE(res, x) \ - { \ + do { \ int offset = GGML_F16_ARR >> 1; \ for (int i = 0; i < offset; ++i) { \ x[i] = vaddq_f16(x[i], x[offset+i]); \ @@ -1879,7 +709,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ - } + } while (0) #define GGML_F16_VEC GGML_F16x8 #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO @@ -1940,7 +770,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { #define GGML_F32x8_ADD _mm256_add_ps #define GGML_F32x8_MUL _mm256_mul_ps #define GGML_F32x8_REDUCE(res, x) \ -{ \ +do { \ int offset = GGML_F32_ARR >> 1; \ for (int i = 0; i < offset; ++i) { \ x[i] = _mm256_add_ps(x[i], x[offset+i]); \ @@ -1957,7 +787,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { _mm256_extractf128_ps(x[0], 1)); \ const __m128 t1 = _mm_hadd_ps(t0, t0); \ res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ -} +} while (0) // TODO: is this optimal ? #define GGML_F32_VEC GGML_F32x8 @@ -2292,1333 +1122,115 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) { #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) -#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) -#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA -#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD -#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL -#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE - -#endif - -// GGML_F32_ARR / GGML_F16_ARR -// number of registers to use per step -#ifdef GGML_SIMD -#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR) -#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) -#endif - -// -// fundamental operations -// - -inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } -inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } -inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } -inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } -inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } -inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } -inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } -inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * 
y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } -inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } - -static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { -#ifdef GGML_SIMD - float sumf = 0.0f; - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - - sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); - } - } - - // reduce sum0..sum3 to sum0 - GGML_F32_VEC_REDUCE(sumf, sum); - - // leftovers - for (int i = np; i < n; ++i) { - sumf += x[i]*y[i]; - } -#else - // scalar - ggml_float sumf = 0.0; - for (int i = 0; i < n; ++i) { - sumf += (ggml_float)(x[i]*y[i]); - } -#endif - - *s = sumf; -} - -static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { - ggml_float sumf = 0.0; - -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); - - GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; - - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - - sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); - } - } - - // reduce sum0..sum3 to sum0 - GGML_F16_VEC_REDUCE(sumf, sum); - - // leftovers - for (int i = np; i < n; ++i) { - sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); - } -#else - for (int i = 0; i < n; ++i) { - sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); - } -#endif - - *s = sumf; -} - -static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); - - const block_q4_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q4_0 * restrict x0 = &x[i + 0]; - const block_q4_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - const int8x16_t s8b = vdupq_n_s8(0x8); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // sub 8 - const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); - const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); - const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); - const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 
16); - -#if defined(__ARM_FEATURE_DOTPROD) - // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#endif - } - - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (int i = 0; i < nb; ++i) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. 
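// --- illustration only, not part of the patch ----------------------------------
// A minimal scalar sketch of what the q4_0 x q8_0 dot-product paths in this hunk
// compute for a single block pair, assuming the block_q4_0 / block_q8_0 layouts and
// GGML_FP16_TO_FP32 defined in this file (the helper name below is hypothetical).
// Each low/high nibble is offset by 8 into [-8, 7], multiplied with the matching
// int8 quant, and the integer sum is scaled by the product of the two block deltas.
static inline float ggml_vec_dot_q4_0_q8_0_one_block(const block_q4_0 * restrict x,
                                                     const block_q8_0 * restrict y) {
    int sumi = 0;
    for (int j = 0; j < QK4_0/2; ++j) {
        const int v0 = (x->qs[j] & 0x0F) - 8;   // low nibble  -> [-8, 7]
        const int v1 = (x->qs[j] >> 4)   - 8;   // high nibble -> [-8, 7]
        sumi += v0*y->qs[j] + v1*y->qs[j + QK4_0/2];
    }
    return sumi * GGML_FP16_TO_FP32(x->d) * GGML_FP16_TO_FP32(y->d);
}
// --------------------------------------------------------------------------------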
- const __m256i off = _mm256_set1_epi8( 8 ); - bx = _mm256_sub_epi8( bx, off ); - - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_i8_pairs_float(bx, by); - - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps( d, q, acc ); - } - - *s = hsum_float_8(acc); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (int i = 0; i < nb; ++i) { - // Compute combined scale for the block - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); - - const __m128i lowMask = _mm_set1_epi8(0xF); - const __m128i off = _mm_set1_epi8(8); - - const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); - - __m128i bx = _mm_and_si128(lowMask, tmp); - __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs); - bx = _mm_sub_epi8(bx, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx, by); - - bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); - by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); - bx = _mm_sub_epi8(bx, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx, by); - - // Convert int32_t to float - __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); - - // Apply the scale, and accumulate - acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); - } - - *s = hsum_float_8(acc); -#elif defined(__SSSE3__) - // set constants - const __m128i lowMask = _mm_set1_epi8(0xF); - const __m128i off = _mm_set1_epi8(8); - - // Initialize accumulator with zeros - __m128 acc_0 = _mm_setzero_ps(); - __m128 acc_1 = _mm_setzero_ps(); - __m128 acc_2 = _mm_setzero_ps(); - __m128 acc_3 = _mm_setzero_ps(); - - // First round without accumulation - { - _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) ); - - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); - - __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16)); - bx_1 = _mm_sub_epi8(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) ); - - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); - - __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs); - bx_2 = _mm_sub_epi8(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16)); - bx_3 = _mm_sub_epi8(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = _mm_cvtepi32_ps(i32_0); - __m128 p1 = _mm_cvtepi32_ps(i32_1); - __m128 p2 = _mm_cvtepi32_ps(i32_2); - __m128 p3 = _mm_cvtepi32_ps(i32_3); - - // Apply the scale - acc_0 = _mm_mul_ps( d_0_1, p0 ); - acc_1 = _mm_mul_ps( d_0_1, p1 ); - acc_2 = _mm_mul_ps( d_2_3, p2 ); - acc_3 = _mm_mul_ps( 
d_2_3, p3 ); - } - - // Main loop - GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 2; i < nb; i+=2) { - _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); - - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); - - __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); - bx_1 = _mm_sub_epi8(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) ); - - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); - - __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs); - bx_2 = _mm_sub_epi8(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16)); - bx_3 = _mm_sub_epi8(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = _mm_cvtepi32_ps(i32_0); - __m128 p1 = _mm_cvtepi32_ps(i32_1); - __m128 p2 = _mm_cvtepi32_ps(i32_2); - __m128 p3 = _mm_cvtepi32_ps(i32_3); - - // Apply the scale - __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); - __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); - __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); - __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); - - // Acummulate - acc_0 = _mm_add_ps(p0_d, acc_0); - acc_1 = _mm_add_ps(p1_d, acc_1); - acc_2 = _mm_add_ps(p2_d, acc_2); - acc_3 = _mm_add_ps(p3_d, acc_3); - } - - *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - for (int i = 0; i < nb; i++) { - vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); - - vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); - vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); - - vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); - vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); - - vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a); - vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l); - - vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl); - vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl); - - vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); - vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); - sumi += __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); - } - - *s = sumf; -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - int sumi = 0; - - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[i].qs[j] & 0x0F) - 8; 
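/*
 * [editor's note - worked example, not part of the patch]
 * With qk == 32, byte j of x[i].qs packs elements j and j+16. If, say,
 * x[i].qs[j] == 0x2B, then:
 *   v0 = (0x2B & 0x0F) - 8 = 11 - 8 =  3   // pairs with y[i].qs[j]
 *   v1 = (0x2B >> 4)   - 8 =  2 - 8 = -6   // pairs with y[i].qs[j + qk/2]
 * The integer sum over the block is then scaled once by
 * GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d).
 */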
- const int v1 = (x[i].qs[j] >> 4) - 8; - - sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); - } - - sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); - } - - *s = sumf; -#endif -} - -static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_1; - const int nb = n / qk; - - assert(n % qk == 0); - - const block_q4_1 * restrict x = vx; - const block_q8_1 * restrict y = vy; - - // TODO: add WASM SIMD -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - float summs = 0; - - GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q4_1 * restrict x0 = &x[i + 0]; - const block_q4_1 * restrict x1 = &x[i + 1]; - const block_q8_1 * restrict y0 = &y[i + 0]; - const block_q8_1 * restrict y1 = &y[i + 1]; - - summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - -#if defined(__ARM_FEATURE_DOTPROD) - // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); -#endif - } - - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; -#elif defined(__AVX2__) || defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - float summs = 0; - - // Main loop - for (int i = 0; i < nb; ++i) { - const float d0 = GGML_FP16_TO_FP32(x[i].d); - const float d1 = y[i].d; - - summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; - - const __m256 d0v 
= _mm256_set1_ps( d0 ); - const __m256 d1v = _mm256_set1_ps( d1 ); - - // Compute combined scales - const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); - - // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); - - const __m256 xy = mul_sum_us8_pairs_float(bx, by); - - // Accumulate d0*d1*x*y -#if defined(__AVX2__) - acc = _mm256_fmadd_ps( d0d1, xy, acc ); -#else - acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); -#endif - } - - *s = hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - for (int i = 0; i < nb; i++) { - vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); - - vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); - vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); - - vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); - vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); - - vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a); - vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l); - - vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); - vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); - sumi += __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; - } - - *s = sumf; -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - int sumi = 0; - - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[i].qs[j] & 0x0F); - const int v1 = (x[i].qs[j] >> 4); - - sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); - } - - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; - } - - *s = sumf; -#endif -} - -static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); - assert(qk == QK5_0); - - const block_q5_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - uint32_t qh0; - uint32_t qh1; - - uint64_t tmp0[4]; - uint64_t tmp1[4]; - - GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q5_0 * restrict x0 = &x[i]; - const block_q5_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i]; - const block_q8_0 * restrict y1 = &y[i + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - // extract the 5th bit via lookup table ((!b) << 4) - memcpy(&qh0, x0->qh, sizeof(qh0)); - memcpy(&qh1, x1->qh, sizeof(qh1)); - - tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; - tmp0[3] = table_b2b_1[(qh0 >> 24) ]; - - tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; - tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_1[(qh1 >> 24) ]; - - const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); - const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); - const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); - const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 
2)); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - -#if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#endif - } - - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__wasm_simd128__) - v128_t sumv = wasm_f32x4_splat(0.0f); - - uint32_t qh; - uint64_t tmp[4]; - - // TODO: check if unrolling this is better - for (int i = 0; i < nb; ++i) { - const block_q5_0 * restrict x0 = &x[i]; - const block_q8_0 * restrict y0 = &y[i]; - - const v128_t m4b = wasm_i8x16_splat(0x0F); - - // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); - - tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_1[(qh >> 24) ]; - - const v128_t qhl = wasm_v128_load(tmp + 0); - const v128_t qhh = wasm_v128_load(tmp + 2); - - const v128_t v0 = wasm_v128_load(x0->qs); - - // 4-bit -> 8-bit - const v128_t v0l = wasm_v128_and (v0, m4b); - const v128_t v0h = wasm_u8x16_shr(v0, 4); - - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); - const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); - - // load y - const v128_t v1l = wasm_v128_load(y0->qs); - const v128_t v1h = 
wasm_v128_load(y0->qs + 16); - - // int8x16 -> int16x8 - const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); - const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); - const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); - const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); - - const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); - const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); - const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); - const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - - // dot product - sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( - wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); - } - - *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (int i = 0; i < nb; i++) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); - bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); - bx = _mm256_or_si256(bx, bxhi); - - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_i8_pairs_float(bx, by); - - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps(d, q, acc); - } - - *s = hsum_float_8(acc); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8((char)0xF0); - - // Main loop - for (int i = 0; i < nb; i++) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i bxhi = bytes_from_bits_32(x[i].qh); - __m128i bxhil = _mm256_castsi256_si128(bxhi); - __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_andnot_si128(bxhil, mask); - bxhih = _mm_andnot_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx); - __m128i bxh = _mm256_extractf128_si256(bx, 1); - bxl = _mm_or_si128(bxl, bxhil); - bxh = _mm_or_si128(bxh, bxhih); - bx = MM256_SET_M128I(bxh, bxl); - - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_i8_pairs_float(bx, by); - - /* Multiply q with scale and accumulate */ - acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); - } - - *s = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - - uint32_t qh; - - // These temp values are for masking and shift operations - uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - uint32_t temp_2[16] = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, - 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000}; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - for (int i = 0; i < nb; i++) { - memcpy(&qh, x[i].qh, sizeof(uint32_t)); - - // temporary registers - vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_2, vl); - vuint32m4_t vt_2 = __riscv_vle32_v_u32m4(temp_1, vl); - vuint32m4_t vt_3 = __riscv_vsll_vx_u32m4(vt_1, 16, vl); - vuint32m4_t vt_4 = __riscv_vadd_vx_u32m4(vt_2, 12, vl); - - // ((qh & (1u << (j 
+ 0 ))) >> (j + 0 )) << 4; - vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(vt_1, qh, vl); - vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(xha_0, vt_2, vl); - vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl); - - // ((qh & (1u << (j + 16))) >> (j + 12)); - vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(vt_3, qh, vl); - vuint32m4_t xhl_1 = __riscv_vsrl_vv_u32m4(xha_1, vt_4, vl); - - // narrowing - vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xhl_0, vl); - vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl); - - vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xhl_1, vl); - vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl); - - // load - vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); - - vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); - vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); - - vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl); - vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); - - vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl); - vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl); - - vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a); - vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l); - - vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 16, vl); - vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 16, vl); - - vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); - vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); - sumi += __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; - } - - *s = sumf; -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - - int sumi = 0; - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); - - const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; - const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; - - sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); - } - - sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; - } - - *s = sumf; -#endif -} - -static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_1; - const int nb = n / qk; - - assert(n % qk == 0); - assert(qk == QK5_1); - - const block_q5_1 * restrict x = vx; - const block_q8_1 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - float summs0 = 0.0f; - float summs1 = 0.0f; - - uint32_t qh0; - uint32_t qh1; - - uint64_t tmp0[4]; - uint64_t tmp1[4]; - - GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q5_1 * restrict x0 = &x[i]; - const block_q5_1 * restrict x1 = &x[i + 1]; - const block_q8_1 * restrict y0 = &y[i]; - const block_q8_1 * restrict y1 = &y[i + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s; - summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s; - - // extract the 5th bit via lookup table ((b) << 4) - memcpy(&qh0, x0->qh, sizeof(qh0)); - memcpy(&qh1, x1->qh, sizeof(qh1)); - - tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; 
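/*
 * [editor's note - illustrative sketch, not part of the patch]
 * table_b2b_0 / table_b2b_1 (defined earlier in the pre-change ggml.c, outside
 * this excerpt) expand one byte of qh into eight bytes that are each 0x10 or
 * 0x00, so the 5th bit of every quant can either be OR-ed in directly (q5_1,
 * non-inverted table) or turned into an implicit "-16" via subtraction (q5_0,
 * inverted table). One entry could be built as below - a sketch assuming
 * little-endian byte order, which the vld1q_s8 reinterpretation above relies
 * on; sketch_b2b_entry is a hypothetical name:
 *
 *   static uint64_t sketch_b2b_entry(uint8_t b, int inverted) {
 *       uint64_t r = 0;
 *       for (int j = 0; j < 8; ++j) {
 *           const int      set  = (b >> j) & 1;
 *           const uint64_t byte = (inverted ? !set : set) ? 0x10 : 0x00;
 *           r |= byte << (8*j);   // byte j of the result encodes bit j of b
 *       }
 *       return r;
 *   }
 */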
- tmp0[3] = table_b2b_0[(qh0 >> 24) ]; - - tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; - tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_0[(qh1 >> 24) ]; - - const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); - const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); - const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); - const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // add high bit - const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - -#if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); -#endif - } - - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; -#elif defined(__wasm_simd128__) - v128_t sumv = wasm_f32x4_splat(0.0f); - - float summs = 0.0f; - - uint32_t qh; - uint64_t tmp[4]; - - // TODO: check if unrolling this is better - for (int i = 0; i < nb; ++i) { - const block_q5_1 * restrict x0 = &x[i]; - const block_q8_1 * restrict y0 = &y[i]; - - summs += GGML_FP16_TO_FP32(x0->m) * y0->s; - - const v128_t m4b = wasm_i8x16_splat(0x0F); - - // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); - - tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_0[(qh >> 24) ]; - - const v128_t qhl = wasm_v128_load(tmp + 0); - const v128_t qhh 
= wasm_v128_load(tmp + 2); - - const v128_t v0 = wasm_v128_load(x0->qs); - - // 4-bit -> 8-bit - const v128_t v0l = wasm_v128_and (v0, m4b); - const v128_t v0h = wasm_u8x16_shr(v0, 4); - - // add high bit - const v128_t v0lf = wasm_v128_or(v0l, qhl); - const v128_t v0hf = wasm_v128_or(v0h, qhh); - - // load y - const v128_t v1l = wasm_v128_load(y0->qs); - const v128_t v1h = wasm_v128_load(y0->qs + 16); - - // int8x16 -> int16x8 - const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); - const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); - const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); - const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); - - const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); - const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); - const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); - const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - - // dot product - sumv = wasm_f32x4_add(sumv, - wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d))); - } - - *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.0f; - - // Main loop - for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); - - summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); - bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); - bx = _mm256_or_si256(bx, bxhi); - - const __m256 dy = _mm256_set1_ps(y[i].d); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_us8_pairs_float(bx, by); - - acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); - } - - *s = hsum_float_8(acc) + summs; -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8(0x10); - - float summs = 0.0f; - - // Main loop - for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); - - summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i bxhi = bytes_from_bits_32(x[i].qh); - __m128i bxhil = _mm256_castsi256_si128(bxhi); - __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_and_si128(bxhil, mask); - bxhih = _mm_and_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx); - __m128i bxh = _mm256_extractf128_si256(bx, 1); - bxl = _mm_or_si128(bxl, bxhil); - bxh = _mm_or_si128(bxh, bxhih); - bx = MM256_SET_M128I(bxh, bxl); - - const __m256 dy = _mm256_set1_ps(y[i].d); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_us8_pairs_float(bx, by); - - acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); - } - - *s = hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - - uint32_t qh; - - // These temp values are for shift operations - uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - for (int i = 0; i < nb; i++) { - memcpy(&qh, x[i].qh, sizeof(uint32_t)); - - // 
temporary registers - vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_1, vl); - vuint32m4_t vt_2 = __riscv_vadd_vx_u32m4(vt_1, 12, vl); +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE - // load qh - vuint32m4_t vqh = __riscv_vmv_v_x_u32m4(qh, vl); +#endif - // ((qh >> (j + 0)) << 4) & 0x10; - vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(vqh, vt_1, vl); - vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl); - vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(xhl_0, 0x10, vl); +// GGML_F32_ARR / GGML_F16_ARR +// number of registers to use per step +#ifdef GGML_SIMD +#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR) +#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) +#endif - // ((qh >> (j + 12)) ) & 0x10; - vuint32m4_t xhr_1 = __riscv_vsrl_vv_u32m4(vqh, vt_2, vl); - vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(xhr_1, 0x10, vl); +// +// fundamental operations +// - // narrowing - vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xha_0, vl); - vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl); +inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xha_1, vl); - vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl); +inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - // load - vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); +inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); - vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); +inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl); - vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); +inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } +inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } +inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } +inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } +inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } +inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } +inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } - vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl); - vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl); +static void 
ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { +#ifdef GGML_SIMD + float sumf = 0.0f; + const int np = (n & ~(GGML_F32_STEP - 1)); - vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a); - vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); - vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } + } - int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); - sumi += __riscv_vmv_x_s_i32m1_i32(vs2); + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + // leftovers + for (int i = np; i < n; ++i) { + sumf += x[i]*y[i]; } - - *s = sumf; #else // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - - int sumi = 0; - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0; - const int32_t x1 = (x[i].qs[j] >> 4) | xh_1; - - sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); - } - - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + ggml_float sumf = 0.0; + for (int i = 0; i < n; ++i) { + sumf += (ggml_float)(x[i]*y[i]); } +#endif *s = sumf; -#endif } -static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); - - const block_q8_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q8_0 * restrict x0 = &x[i + 0]; - const block_q8_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; - - const int8x16_t x0_0 = vld1q_s8(x0->qs); - const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); - const int8x16_t x1_0 = vld1q_s8(x1->qs); - const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); - - // load y - const int8x16_t y0_0 = vld1q_s8(y0->qs); - const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); - const int8x16_t y1_0 = vld1q_s8(y1->qs); - const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); - -#if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), - vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), - vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); +static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) 
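/*
 * [editor's note - clarification, not part of the patch]
 * In the SIMD branch above, GGML_F32_EPR is the number of floats per vector
 * register and GGML_F32_STEP the number of floats handled per unrolled
 * iteration, so GGML_F32_ARR = STEP/EPR independent accumulators stay in
 * flight (on the AVX path this works out to 8 and 32, i.e. the four partial
 * sums that "reduce sum0..sum3 to sum0" refers to - values recalled from the
 * surrounding file, not shown in this hunk). np = n & ~(STEP - 1) rounds n
 * down to whole steps and the scalar "leftovers" loop picks up the tail.
 * ggml_vec_dot_f16 below follows the same pattern, with the GGML_F16_VEC_*
 * macros mapping onto the GGML_F32Cx4_* wrappers referenced in the defines
 * above, which convert fp16 lanes to fp32 on load so the accumulation itself
 * happens in fp32.
 */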
{ + ggml_float sumf = 0.0; -#else - const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0)); - const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); - const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1)); - const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); - - const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0)); - const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0)); - const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1)); - const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1)); - - const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); - const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); - const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); - const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#endif - } +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__AVX2__) || defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); + GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; - // Main loop - for (int i = 0; i < nb; ++i) { - // Compute combined scale for the block - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs); - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; - const __m256 q = mul_sum_i8_pairs_float(bx, by); + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - // Multiply q with scale and accumulate -#if defined(__AVX2__) - acc = _mm256_fmadd_ps( d, q, acc ); -#else - acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc ); -#endif + sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + } } - *s = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - size_t vl = __riscv_vsetvl_e8m1(qk); - - for (int i = 0; i < nb; i++) { - // load elements - vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl); - vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl); - - vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl); - - vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); + // reduce sum0..sum3 to sum0 + GGML_F16_VEC_REDUCE(sumf, sum); - sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); + // leftovers + for (int i = np; i < n; ++i) { + sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); } - - *s = sumf; #else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - int sumi = 0; - - for (int j = 0; j < qk; j++) { - sumi += x[i].qs[j]*y[i].qs[j]; - } - - sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); + for (int i = 0; i < n; ++i) { + sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); } +#endif *s = sumf; -#endif } // compute GGML_VEC_DOT_UNROLL dot products at once @@ -3707,6 +1319,58 @@ inline 
static void ggml_vec_mad_f32(const int n, float * restrict y, const float #endif } +// xs and vs are byte strides of x and v +inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { + + const float * restrict x[GGML_VEC_MAD_UNROLL]; + const float * restrict v[GGML_VEC_MAD_UNROLL]; + + for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) { + x[i] = (const float *) ((const char *) xv + i*xs); + v[i] = (const float *) ((const char *) vv + i*vs); + } + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL]; + + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + vx[k] = GGML_F32_VEC_SET1(v[k][0]); + } + + GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); + } + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + for (int i = np; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; + } + } +#else + // scalar + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + for (int i = 0; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; + } + } +#endif +} + //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #if defined(GGML_USE_ACCELERATE) @@ -3749,6 +1413,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } +inline static void ggml_vec_leaky_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 
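/*
 * [editor's note - usage sketch, not part of the patch]
 * ggml_vec_mad_f32_unroll() above folds GGML_VEC_MAD_UNROLL axpy updates into
 * a single pass over y: y[i] += x_k[i] * v_k[0] for each k, where row k of x/v
 * sits at byte offset k*xs / k*vs from xv/vv. Passing byte strides lets a
 * caller accumulate a strip of consecutive rows of two row-major f32 matrices
 * in one call, e.g. (all names hypothetical):
 *
 *   ggml_vec_mad_f32_unroll(ne0, x_row_bytes, v_row_bytes, y,
 *                           (const float *)((const char *) X + r*x_row_bytes),
 *                           (const float *)((const char *) V + r*v_row_bytes));
 */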
x[i] : 0.1f*x[i]; } static const float GELU_COEF_A = 0.044715f; static const float GELU_QUICK_COEF = -1.702f; @@ -3761,7 +1426,7 @@ inline static float ggml_gelu_f32(float x) { inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { const uint16_t * i16 = (const uint16_t *) x; for (int i = 0; i < n; ++i) { - y[i] = table_gelu_f16[i16[i]]; + y[i] = ggml_table_gelu_f16[i16[i]]; } } @@ -3771,7 +1436,7 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) { ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = GGML_FP16_TO_FP32(table_gelu_f16[t]); + y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]); } } #else @@ -3789,7 +1454,7 @@ inline static float ggml_gelu_quick_f32(float x) { //inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { // const uint16_t * i16 = (const uint16_t *) x; // for (int i = 0; i < n; ++i) { -// y[i] = table_gelu_quick_f16[i16[i]]; +// y[i] = ggml_table_gelu_quick_f16[i16[i]]; // } //} @@ -3799,7 +1464,7 @@ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * for (int i = 0; i < n; ++i) { ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]); + y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]); } } #else @@ -3818,7 +1483,7 @@ inline static float ggml_silu_f32(float x) { //inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { // const uint16_t * i16 = (const uint16_t *) x; // for (int i = 0; i < n; ++i) { -// y[i] = table_silu_f16[i16[i]]; +// y[i] = ggml_table_silu_f16[i16[i]]; // } //} @@ -3828,7 +1493,7 @@ inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) { ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = GGML_FP16_TO_FP32(table_silu_f16[t]); + y[i] = GGML_FP16_TO_FP32(ggml_table_silu_f16[t]); } } #else @@ -3970,7 +1635,12 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ALIBI", "CLAMP", "CONV_1D", + "CONV_1D_STAGE_0", + "CONV_1D_STAGE_1", + "CONV_TRANSPOSE_1D", "CONV_2D", + "CONV_2D_STAGE_0", + "CONV_2D_STAGE_1", "CONV_TRANSPOSE_2D", "POOL_1D", "POOL_2D", @@ -4001,7 +1671,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); +static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -4052,7 +1722,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "alibi(x)", "clamp(x)", "conv_1d(x)", + "conv_1d_stage_0(x)", + "conv_1d_stage_1(x)", + "conv_transpose_1d(x)", "conv_2d(x)", + "conv_2d_stage_0(x)", + "conv_2d_stage_1(x)", "conv_transpose_2d(x)", "pool_1d(x)", "pool_2d(x)", @@ -4083,7 +1758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); +static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4112,7 +1787,12 @@ static void ggml_setup_op_has_task_pass(void) { p[GGML_OP_DIAG_MASK_INF ] = true; p[GGML_OP_DIAG_MASK_ZERO ] = true; p[GGML_OP_CONV_1D ] = true; + p[GGML_OP_CONV_1D_STAGE_0 ] = true; + p[GGML_OP_CONV_1D_STAGE_1 ] = true; + p[GGML_OP_CONV_TRANSPOSE_1D ] = true; p[GGML_OP_CONV_2D ] = true; + 
p[GGML_OP_CONV_2D_STAGE_0 ] = true; + p[GGML_OP_CONV_2D_STAGE_1 ] = true; p[GGML_OP_CONV_TRANSPOSE_2D ] = true; p[GGML_OP_FLASH_ATTN_BACK ] = true; p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; @@ -4392,10 +2072,9 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - return - (t0->ne[1] == t1->ne[1]) && - (t0->ne[2] == t1->ne[2]) && - (t0->ne[3] == t1->ne[3]); + return (t0->ne[1] == t1->ne[1]) && + (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable + (t1->ne[3]%t0->ne[3] == 0); } enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { @@ -4530,11 +2209,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { for (int i = 0; i < (1 << 16); ++i) { uint16_t ui = i; memcpy(&ii, &ui, sizeof(ii)); - const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii); - table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); - table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f)); - table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f)); - table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f)); + const float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii); + ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); + ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f)); + ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f)); + ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f)); } const uint64_t t_end = ggml_time_us(); UNUSED(t_end); @@ -4830,6 +2509,7 @@ static struct ggml_tensor * ggml_new_tensor_impl( *result = (struct ggml_tensor) { /*.type =*/ type, /*.backend =*/ GGML_BACKEND_CPU, + /*.buffer =*/ NULL, /*.n_dims =*/ n_dims, /*.ne =*/ { 1, 1, 1, 1 }, /*.nb =*/ { 0, 0, 0, 0 }, @@ -5065,43 +2745,78 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { return tensor; } +void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) { + const int64_t ne2 = tensor->ne[2]; + const int64_t ne1 = tensor->ne[1]; + const int64_t ne0 = tensor->ne[0]; + + const int64_t i3_ = (i/(ne2*ne1*ne0)); + const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0); + const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0; + const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0); + + if (i0) { + * i0 = i0_; + } + if (i1) { + * i1 = i1_; + } + if (i2) { + * i2 = i2_; + } + if (i3) { + * i3 = i3_; + } +} + int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]); + } switch (tensor->type) { case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); return ((int8_t *)(tensor->data))[i]; - } break; + } case GGML_TYPE_I16: { GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); return ((int16_t *)(tensor->data))[i]; - } break; + } case GGML_TYPE_I32: { GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); return ((int32_t *)(tensor->data))[i]; - } break; + } case GGML_TYPE_F16: { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); - } break; + } case GGML_TYPE_F32: { GGML_ASSERT(tensor->nb[0] == sizeof(float)); return ((float *)(tensor->data))[i]; - } break; + } default: { 
GGML_ASSERT(false); - } break; + } } return 0.0f; } void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value); + return; + } switch (tensor->type) { case GGML_TYPE_I8: { @@ -5135,68 +2850,179 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { } } +int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case GGML_TYPE_I8: + return ((int8_t *) data)[0]; + case GGML_TYPE_I16: + return ((int16_t *) data)[0]; + case GGML_TYPE_I32: + return ((int32_t *) data)[0]; + case GGML_TYPE_F16: + return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); + case GGML_TYPE_F32: + return ((float *) data)[0]; + default: + GGML_ASSERT(false); + } + + return 0.0f; +} + +void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case GGML_TYPE_I8: + { + ((int8_t *)(data))[0] = value; + } break; + case GGML_TYPE_I16: + { + ((int16_t *)(data))[0] = value; + } break; + case GGML_TYPE_I32: + { + ((int32_t *)(data))[0] = value; + } break; + case GGML_TYPE_F16: + { + ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + ((float *)(data))[0] = value; + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]); + } switch (tensor->type) { case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); return ((int8_t *)(tensor->data))[i]; - } break; + } case GGML_TYPE_I16: { GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); return ((int16_t *)(tensor->data))[i]; - } break; + } case GGML_TYPE_I32: { GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); return ((int32_t *)(tensor->data))[i]; - } break; + } case GGML_TYPE_F16: { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); - } break; + } case GGML_TYPE_F32: { GGML_ASSERT(tensor->nb[0] == sizeof(float)); return ((float *)(tensor->data))[i]; + } + default: + { + GGML_ASSERT(false); + } + } + + return 0.0f; +} + +void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value); + return; + } + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + ((ggml_fp16_t 
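/*
 * [editor's note - usage sketch, not part of the patch]
 * ggml_unravel_index() together with the *_nd accessors added just below is
 * what lets the 1d getters/setters above handle non-contiguous tensors: the
 * flat element index i is decomposed into (i0,i1,i2,i3) coordinates, which are
 * then resolved through the nb[] byte strides instead of assuming a packed
 * layout. Reading element i of an arbitrary (possibly permuted or viewed) f32
 * tensor therefore looks like this (sketch_get_f32_flat is a hypothetical
 * helper, not part of the API):
 *
 *   static float sketch_get_f32_flat(const struct ggml_tensor * t, int64_t i) {
 *       int64_t i0, i1, i2, i3;
 *       ggml_unravel_index(t, i, &i0, &i1, &i2, &i3);
 *       return ggml_get_f32_nd(t, (int) i0, (int) i1, (int) i2, (int) i3);
 *   }
 */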
*)(tensor->data))[i] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; } break; default: { GGML_ASSERT(false); } break; } +} + +float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case GGML_TYPE_I8: + return ((int8_t *) data)[0]; + case GGML_TYPE_I16: + return ((int16_t *) data)[0]; + case GGML_TYPE_I32: + return ((int32_t *) data)[0]; + case GGML_TYPE_F16: + return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); + case GGML_TYPE_F32: + return ((float *) data)[0]; + default: + GGML_ASSERT(false); + } return 0.0f; } -void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { +void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; switch (tensor->type) { case GGML_TYPE_I8: { - GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); - ((int8_t *)(tensor->data))[i] = value; + ((int8_t *)(data))[0] = value; } break; case GGML_TYPE_I16: { - GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); - ((int16_t *)(tensor->data))[i] = value; + ((int16_t *)(data))[0] = value; } break; case GGML_TYPE_I32: { - GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); - ((int32_t *)(tensor->data))[i] = value; + ((int32_t *)(data))[0] = value; } break; case GGML_TYPE_F16: { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); - ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); + ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); } break; case GGML_TYPE_F32: { - GGML_ASSERT(tensor->nb[0] == sizeof(float)); - ((float *)(tensor->data))[i] = value; + ((float *)(data))[0] = value; } break; default: { @@ -5250,6 +3076,39 @@ struct ggml_tensor * ggml_view_tensor( return result; } +struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx) { + struct ggml_object * obj = ctx->objects_begin; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TENSOR) { + return (struct ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + +struct ggml_tensor * ggml_get_next_tensor(struct ggml_context * ctx, struct ggml_tensor * tensor) { + struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE); + obj = obj->next; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TENSOR) { + return (struct ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) { struct ggml_object * obj = ctx->objects_begin; @@ -5347,6 +3206,44 @@ struct ggml_tensor * ggml_add_inplace( return ggml_add_impl(ctx, a, b, true); } +// ggml_add_cast + +static struct ggml_tensor * ggml_add_cast_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type) { + // TODO: support less-strict constraint + // GGML_ASSERT(ggml_can_repeat(b, a)); + GGML_ASSERT(ggml_can_repeat_rows(b, a)); + GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16 + + bool is_node = false; + + if (a->grad || b->grad) { + // TODO: support 
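/*
 * [editor's note - usage sketch, not part of the patch]
 * ggml_get_first_tensor()/ggml_get_next_tensor() above walk the context's
 * object list and skip non-tensor objects, which gives a simple way to visit
 * every tensor allocated in a ggml_context (sketch_for_each_tensor is a
 * hypothetical name):
 *
 *   static void sketch_for_each_tensor(struct ggml_context * ctx) {
 *       for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL;
 *            t = ggml_get_next_tensor(ctx, t)) {
 *           // e.g. inspect t->name, t->type or t->ne here
 *       }
 *   }
 */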
backward pass for broadcasting + GGML_ASSERT(ggml_are_same_shape(a, b)); + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, type, a->n_dims, a->ne); + + result->op = GGML_OP_ADD; + result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_add_cast( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type) { + return ggml_add_cast_impl(ctx, a, b, type); +} + // ggml_add1 static struct ggml_tensor * ggml_add1_impl( @@ -5639,7 +3536,6 @@ struct ggml_tensor * ggml_sqrt_inplace( return ggml_sqrt_impl(ctx, a, true); } - // ggml_log static struct ggml_tensor * ggml_log_impl( @@ -5693,7 +3589,6 @@ struct ggml_tensor * ggml_sum( return result; } - // ggml_sum_rows struct ggml_tensor * ggml_sum_rows( @@ -5783,7 +3678,6 @@ struct ggml_tensor * ggml_repeat( result->op = GGML_OP_REPEAT; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -5811,7 +3705,6 @@ struct ggml_tensor * ggml_repeat_back( result->op = GGML_OP_REPEAT_BACK; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -5938,6 +3831,14 @@ struct ggml_tensor * ggml_relu_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU); } +// ggml_leaky + +struct ggml_tensor * ggml_leaky( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_LEAKY); +} + // ggml_gelu struct ggml_tensor * ggml_gelu( @@ -6186,8 +4087,9 @@ struct ggml_tensor * ggml_out_prod( is_node = true; } - const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); + // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3] + const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne); result->op = GGML_OP_OUT_PROD; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -6326,7 +4228,6 @@ struct ggml_tensor * ggml_set_2d_inplace( return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false); } - // ggml_cpy static struct ggml_tensor * ggml_cpy_impl( @@ -6406,6 +4307,52 @@ struct ggml_tensor * ggml_cont_inplace( return ggml_cont_impl(ctx, a, true); } +// make contiguous, with new shape +GGML_API struct ggml_tensor * ggml_cont_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0) { + return ggml_cont_4d(ctx, a, ne0, 1, 1, 1); +} + +GGML_API struct ggml_tensor * ggml_cont_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1) { + return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1); +} + +GGML_API struct ggml_tensor * ggml_cont_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1); +} + +struct ggml_tensor * ggml_cont_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3)); + + bool is_node = false; + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); + ggml_format_name(result, "%s (cont)", a->name); + + result->op = GGML_OP_CONT; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + // ggml_reshape struct ggml_tensor * ggml_reshape( @@ -6413,7 +4360,7 @@ struct ggml_tensor * ggml_reshape( struct ggml_tensor * a, struct ggml_tensor * b) { GGML_ASSERT(ggml_is_contiguous(a)); - GGML_ASSERT(ggml_is_contiguous(b)); + // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous. GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); bool is_node = false; @@ -6786,7 +4733,6 @@ struct ggml_tensor * ggml_get_rows_back( result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = c; return result; } @@ -6813,7 +4759,6 @@ struct ggml_tensor * ggml_diag( return result; } - // ggml_diag_mask_inf static struct ggml_tensor * ggml_diag_mask_inf_impl( @@ -6925,7 +4870,6 @@ struct ggml_tensor * ggml_soft_max_inplace( return ggml_soft_max_impl(ctx, a, true); } - // ggml_soft_max_back static struct ggml_tensor * ggml_soft_max_back_impl( @@ -6968,16 +4912,24 @@ struct ggml_tensor * ggml_soft_max_back_inplace( static struct ggml_tensor * ggml_rope_impl( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, + int n_orig_ctx, float freq_base, float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, float xpos_base, bool xpos_down, bool inplace) { - GGML_ASSERT(n_past >= 0); + GGML_ASSERT(ggml_is_vector(b)); + GGML_ASSERT(b->type == GGML_TYPE_I32); + GGML_ASSERT(a->ne[2] == b->ne[0]); + bool is_node = false; if (a->grad) { @@ -6986,16 +4938,21 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - int32_t params[8] = { n_past, n_dims, mode, n_ctx }; - memcpy(params + 4, &freq_base, sizeof(float)); - memcpy(params + 5, &freq_scale, sizeof(float)); - memcpy(params + 6, &xpos_base, sizeof(float)); - memcpy(params + 7, &xpos_down, sizeof(bool)); + int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx }; + memcpy(params + 5, &freq_base, sizeof(float)); + memcpy(params + 6, &freq_scale, sizeof(float)); + memcpy(params + 7, &ext_factor, sizeof(float)); + memcpy(params + 8, &attn_factor, sizeof(float)); + memcpy(params + 9, &beta_fast, sizeof(float)); + memcpy(params + 10, &beta_slow, sizeof(float)); + memcpy(params + 11, &xpos_base, sizeof(float)); + memcpy(params + 12, &xpos_down, sizeof(bool)); ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_ROPE; result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; + result->src[1] = b; return result; } @@ -7003,55 +4960,75 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * ggml_rope( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false); + return ggml_rope_impl( + ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false + ); } struct ggml_tensor * ggml_rope_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true); + return ggml_rope_impl( + ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true + ); } struct ggml_tensor * ggml_rope_custom( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, + int n_orig_ctx, float freq_base, - float freq_scale) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false); + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false + ); } struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, + int n_orig_ctx, float freq_base, - float freq_scale) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true); + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true + ); } struct ggml_tensor * ggml_rope_xpos_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, float base, bool down) { - return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true); + return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true); } // ggml_rope_back @@ -7059,7 +5036,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace( struct ggml_tensor * ggml_rope_back( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -7067,7 +5044,10 @@ struct ggml_tensor * ggml_rope_back( float freq_scale, float xpos_base, bool xpos_down) { - GGML_ASSERT(n_past >= 0); + GGML_ASSERT(ggml_is_vector(b)); + GGML_ASSERT(b->type == GGML_TYPE_I32); + GGML_ASSERT(a->ne[2] == b->ne[0]); + GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); bool is_node = false; @@ -7078,7 +5058,7 @@ struct ggml_tensor * ggml_rope_back( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - int32_t params[8] = { n_past, n_dims, mode, n_ctx }; + int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx }; memcpy(params + 4, &freq_base, sizeof(float)); memcpy(params + 5, &freq_scale, sizeof(float)); memcpy(params + 6, &xpos_base, sizeof(float)); @@ -7088,6 +5068,7 @@ struct ggml_tensor * ggml_rope_back( 
result->op = GGML_OP_ROPE_BACK; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; + result->src[1] = b; return result; } @@ -7156,14 +5137,17 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; } -GGML_API struct ggml_tensor * ggml_conv_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int p0, - int d0) { - GGML_ASSERT(ggml_is_matrix(b)); +// im2col: [N, IC, IL] => [N, OL, IC*K] +// a: [OC,IC, K] +// b: [N, IC, IL] +// result: [N, OL, IC*K] +static struct ggml_tensor * ggml_conv_1d_stage_0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0) { GGML_ASSERT(a->ne[1] == b->ne[1]); bool is_node = false; @@ -7172,16 +5156,54 @@ GGML_API struct ggml_tensor * ggml_conv_1d( is_node = true; } + const int64_t OL = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + const int64_t ne[4] = { - ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), - a->ne[2], 1, 1, + a->ne[1] * a->ne[0], + OL, + b->ne[2], + 1, }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); int32_t params[] = { s0, p0, d0 }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_CONV_1D; + result->op = GGML_OP_CONV_1D_STAGE_0; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_conv_1d_stage_1 + +// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] +// a: [OC, IC, K] +// b: [N, OL, IC * K] +// result: [N, OC, OL] +static struct ggml_tensor * ggml_conv_1d_stage_1( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + + bool is_node = false; + + if (a->grad || b->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { + b->ne[1], + a->ne[2], + b->ne[2], + 1, + }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_CONV_1D_STAGE_1; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; @@ -7189,6 +5211,53 @@ GGML_API struct ggml_tensor * ggml_conv_1d( return result; } +// ggml_conv_1d + +GGML_API struct ggml_tensor * ggml_conv_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0) { + struct ggml_tensor * result = ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0); + result = ggml_conv_1d_stage_1(ctx, a, result); + return result; +} + +// GGML_API struct ggml_tensor * ggml_conv_1d( +// struct ggml_context * ctx, +// struct ggml_tensor * a, +// struct ggml_tensor * b, +// int s0, +// int p0, +// int d0) { +// GGML_ASSERT(ggml_is_matrix(b)); +// GGML_ASSERT(a->ne[1] == b->ne[1]); +// bool is_node = false; + +// if (a->grad || b->grad) { +// GGML_ASSERT(false); // TODO: implement backward +// is_node = true; +// } + +// const int64_t ne[4] = { +// ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), +// a->ne[2], 1, 1, +// }; +// struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + +// int32_t params[] = { s0, p0, d0 }; +// ggml_set_op_params(result, params, sizeof(params)); + +// result->op = GGML_OP_CONV_1D; +// result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; +// result->src[0] = a; +// result->src[1] = b; + +// return result; +// } + // ggml_conv_1d_ph struct ggml_tensor* ggml_conv_1d_ph( @@ -7200,9 +5269,57 @@ struct ggml_tensor* ggml_conv_1d_ph( return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); } +// ggml_conv_transpose_1d + +static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) { + return (ins - 1) * s - 2 * p + d * (ks - 1) + 1; +} + +GGML_API struct ggml_tensor * ggml_conv_transpose_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0) { + GGML_ASSERT(ggml_is_matrix(b)); + GGML_ASSERT(a->ne[2] == b->ne[1]); + GGML_ASSERT(a->ne[3] == 1); + + GGML_ASSERT(p0 == 0); + GGML_ASSERT(d0 == 1); + + bool is_node = false; + + if (a->grad || b->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { + ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/), + a->ne[1], b->ne[2], 1, + }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + int32_t params[] = { s0, p0, d0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_CONV_TRANSPOSE_1D; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + // ggml_conv_2d -struct ggml_tensor * ggml_conv_2d( +// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] +// a: [OC,IC, KH, KW] +// b: [N, IC, IH, IW] +// result: [N, OH, OW, IC*KH*KW] +static struct ggml_tensor * ggml_conv_2d_stage_0( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -7213,7 +5330,46 @@ struct ggml_tensor * ggml_conv_2d( int d0, int d1) { - GGML_ASSERT(a->ne[2] == b->ne[2]); + GGML_ASSERT(a->ne[2] == b->ne[2]); + bool is_node = false; + + if (a->grad || b->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t OH = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1); + const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + + const int64_t ne[4] = { + a->ne[2] * a->ne[1] * a->ne[0], + OW, + OH, + b->ne[3], + }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); + + int32_t params[] = { s0, s1, p0, p1, d0, d1 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_CONV_2D_STAGE_0; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; + +} + +// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] +// a: [OC, IC, KH, KW] +// b: [N, OH, OW, IC * KH * KW] +// result: [N, OC, OH, OW] +static struct ggml_tensor * ggml_conv_2d_stage_1( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + bool is_node = false; if (a->grad || b->grad) { @@ -7222,16 +5378,14 @@ struct ggml_tensor * ggml_conv_2d( } const int64_t ne[4] = { - ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), - ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1), - a->ne[3], b->ne[3], + b->ne[1], + b->ne[2], + a->ne[3], + b->ne[3], }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - int32_t params[] = { s0, s1, p0, p1, d0, d1 }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_CONV_2D; + result->op = GGML_OP_CONV_2D_STAGE_1; result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; @@ -7240,8 +5394,28 @@ struct ggml_tensor * ggml_conv_2d( } -// ggml_conv_2d_sk_p0 +// a: [OC,IC, KH, KW] +// b: [N, IC, IH, IW] +// result: [N, OC, OH, OW] +struct ggml_tensor * ggml_conv_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + + struct ggml_tensor * result = ggml_conv_2d_stage_0(ctx, a, b, s0, s1, p0, p1, d0, d1); // [N, OH, OW, IC * KH * KW] + result = ggml_conv_2d_stage_1(ctx, a, result); + + return result; + +} +// ggml_conv_2d_sk_p0 struct ggml_tensor * ggml_conv_2d_sk_p0( struct ggml_context * ctx, struct ggml_tensor * a, @@ -7298,7 +5472,7 @@ struct ggml_tensor * ggml_conv_transpose_2d_p0( // ggml_pool_* -static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) { +static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) { return (ins + 2 * p - ks) / s + 1; } @@ -7345,8 +5519,8 @@ struct ggml_tensor * ggml_pool_2d( int k1, int s0, int s1, - int p0, - int p1) { + float p0, + float p1) { bool is_node = false; @@ -7484,27 +5658,30 @@ struct ggml_tensor * ggml_flash_attn_back( // d shape [D,N,ne2,ne3] // q shape [D,N,ne2,ne3] - // k shape [D,M,ne2,ne3] - // v shape [M,D,ne2,ne3] + // k shape [D,M,kvne2,ne3] + // v shape [M,D,kvne2,ne3] - const int64_t D = q->ne[0]; - const int64_t N = q->ne[1]; - const int64_t M = k->ne[1]; - const int64_t ne2 = q->ne[2]; - const int64_t ne3 = q->ne[3]; + const int64_t D = q->ne[0]; + const int64_t N = q->ne[1]; + const int64_t M = k->ne[1]; + const int64_t ne2 = q->ne[2]; + const int64_t ne3 = q->ne[3]; + const int64_t kvne2 = k->ne[2]; GGML_ASSERT(k->ne[0] == D); GGML_ASSERT(v->ne[0] == M); GGML_ASSERT(v->ne[1] == D); GGML_ASSERT(d->ne[0] == D); GGML_ASSERT(d->ne[1] == N); - GGML_ASSERT(k->ne[2] == ne2); + GGML_ASSERT(k->ne[2] == kvne2); GGML_ASSERT(k->ne[3] == ne3); - GGML_ASSERT(v->ne[2] == ne2); + GGML_ASSERT(v->ne[2] == kvne2); GGML_ASSERT(v->ne[3] == ne3); GGML_ASSERT(d->ne[2] == ne2); GGML_ASSERT(d->ne[3] == ne3); + GGML_ASSERT(ne2 % kvne2 == 0); + bool is_node = false; if (q->grad || k->grad || v->grad) { @@ -7514,14 +5691,23 @@ struct ggml_tensor * ggml_flash_attn_back( } // store gradients of q, k and v as continuous tensors concatenated in result. - // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3] - // gradq->data = result->data - // gradk->data = result->data + nb0*D*N*ne2*ne3 - // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3 // note: v and gradv are actually transposed, i.e. v->ne[0] != D. - int64_t ne[4] = {D,M+N+M,ne2,ne3}; + const int64_t elem_q = ggml_nelements(q); + const int64_t elem_k = ggml_nelements(k); + const int64_t elem_v = ggml_nelements(v); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + enum ggml_type result_type = GGML_TYPE_F32; + GGML_ASSERT(ggml_blck_size(result_type) == 1); + const size_t tsize = ggml_type_size(result_type); + + const size_t offs_q = 0; + const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); + const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); + const size_t end = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN); + + const size_t nelements = (end + tsize - 1)/tsize; + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements); int32_t masked_i = masked ? 
1 : 0; ggml_set_op_params(result, &masked_i, sizeof(masked_i)); @@ -7668,7 +5854,6 @@ static struct ggml_tensor * ggml_add_rel_pos_impl( return result; } - struct ggml_tensor * ggml_add_rel_pos( struct ggml_context * ctx, struct ggml_tensor * a, @@ -8113,8 +6298,6 @@ struct ggml_tensor * ggml_map_custom3_inplace( return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true); } - - // ggml_cross_entropy_loss struct ggml_tensor * ggml_cross_entropy_loss( @@ -8168,6 +6351,7 @@ void ggml_set_param( GGML_ASSERT(tensor->grad == NULL); tensor->grad = ggml_dup_tensor(ctx, tensor); + ggml_format_name(tensor->grad, "%s (grad)", tensor->name); } // ggml_compute_forward_dup @@ -8214,7 +6398,7 @@ static void ggml_compute_forward_dup_f16( return; } - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS const int ith = params->ith; // thread index const int nth = params->nth; // number of threads @@ -8485,7 +6669,7 @@ static void ggml_compute_forward_dup_f32( return; } - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS const int ith = params->ith; // thread index const int nth = params->nth; // number of threads @@ -8766,7 +6950,7 @@ static void ggml_compute_forward_add_f32( const int nr = ggml_nrows(src0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -8798,8 +6982,6 @@ static void ggml_compute_forward_add_f32( #else ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr); #endif - // } - // } } } else { // src1 is not contiguous @@ -8841,13 +7023,19 @@ static void ggml_compute_forward_add_f16_f32( const int nr = ggml_nrows(src0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F16); - GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + if (dst->type == GGML_TYPE_F32) { + GGML_ASSERT( nb0 == sizeof(float)); + } + else { + GGML_ASSERT(dst->type == GGML_TYPE_F16); + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + } + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); // rows per thread @@ -8858,18 +7046,35 @@ static void ggml_compute_forward_add_f16_f32( const int ir1 = MIN(ir0 + dr, nr); if (nb10 == sizeof(float)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); + if (dst->type == GGML_TYPE_F16) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); + } + } + } else { + for (int ir = ir0; ir < ir1; 
++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; + } } } } @@ -8895,7 +7100,7 @@ static void ggml_compute_forward_add_f16_f16( const int nr = ggml_nrows(src0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F16); @@ -8946,14 +7151,15 @@ static void ggml_compute_forward_add_q_f32( const int nr = ggml_nrows(src0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; const enum ggml_type type = src0->type; + const enum ggml_type dtype = dst->type; ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - ggml_from_float_t const quantize_row_q = type_traits[type].from_float; + ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float; // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == ggml_type_size(type)); @@ -8965,7 +7171,6 @@ static void ggml_compute_forward_add_q_f32( GGML_ASSERT(nb2 <= nb3); GGML_ASSERT(ggml_is_quantized(src0->type)); - GGML_ASSERT(dst->type == src0->type); GGML_ASSERT(src1->type == GGML_TYPE_F32); // rows per thread @@ -9003,7 +7208,11 @@ static void ggml_compute_forward_add_q_f32( // add src1 ggml_vec_acc_f32(ne00, wdata, src1_row); // quantize row to dst - quantize_row_q(wdata, dst_row, ne00); + if (quantize_row_q != NULL) { + quantize_row_q(wdata, dst_row, ne00); + } else { + memcpy(dst_row, wdata, ne0*nb0); + } } } @@ -9068,7 +7277,7 @@ static void ggml_compute_forward_add1_f32( const int nr = ggml_nrows(src0); - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -9123,7 +7332,7 @@ static void ggml_compute_forward_add1_f16_f32( const int nr = ggml_nrows(src0); - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -9173,7 +7382,7 @@ static void ggml_compute_forward_add1_f16_f16( const int nr = ggml_nrows(src0); - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F16); @@ -9223,7 +7432,7 @@ static void ggml_compute_forward_add1_q_f32( const int nr = ggml_nrows(src0); - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS const enum ggml_type type = src0->type; ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; @@ -9313,7 +7522,6 @@ static void ggml_compute_forward_add1( } } - // ggml_compute_forward_acc static void ggml_compute_forward_acc_f32( @@ -9351,8 +7559,8 @@ static void ggml_compute_forward_acc_f32( const int nr = ggml_nrows(src1); const int nc = src1->ne[0]; - GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); - GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) // src0 and dst as viewed during acc const size_t nb0 = ggml_element_size(src0); @@ -9441,7 +7649,7 @@ static void ggml_compute_forward_sub_f32( const int nr = 
ggml_nrows(src0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -9453,7 +7661,6 @@ static void ggml_compute_forward_sub_f32( const int i2 = (ir - i3*ne2*ne1)/ne1; const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - #ifdef GGML_USE_ACCELERATE vDSP_vsub( (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, @@ -9531,7 +7738,7 @@ static void ggml_compute_forward_mul_f32( const int64_t nr = ggml_nrows(src0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -9622,7 +7829,7 @@ static void ggml_compute_forward_div_f32( const int nr = ggml_nrows(src0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -9634,7 +7841,6 @@ static void ggml_compute_forward_div_f32( const int i2 = (ir - i3*ne2*ne1)/ne1; const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - #ifdef GGML_USE_ACCELERATE UNUSED(ggml_vec_div_f32); @@ -9772,7 +7978,6 @@ static void ggml_compute_forward_sqrt( } } - // ggml_compute_forward_log static void ggml_compute_forward_log_f32( @@ -9831,8 +8036,8 @@ static void ggml_compute_forward_sum_f32( assert(ggml_is_scalar(dst)); assert(src0->nb[0] == sizeof(float)); - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) ggml_float sum = 0; ggml_float row_sum = 0; @@ -9863,8 +8068,8 @@ static void ggml_compute_forward_sum_f16( assert(src0->nb[0] == sizeof(ggml_fp16_t)); - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) float sum = 0; float row_sum = 0; @@ -9917,7 +8122,7 @@ static void ggml_compute_forward_sum_rows_f32( GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(dst->nb[0] == sizeof(float)); - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS GGML_ASSERT(ne0 == 1); GGML_ASSERT(ne1 == ne01); @@ -9967,7 +8172,7 @@ static void ggml_compute_forward_mean_f32( assert(src0->nb[0] == sizeof(float)); - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS assert(ne0 == 1); assert(ne1 == ne01); @@ -10067,7 +8272,7 @@ static void ggml_compute_forward_repeat_f32( return; } - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS // guaranteed to be an integer due to the check in ggml_can_repeat const int nr0 = (int)(ne0/ne00); @@ -10099,11 +8304,61 @@ static void ggml_compute_forward_repeat_f32( } } +static void ggml_compute_forward_repeat_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_can_repeat(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS; + + // guaranteed to be an integer due to the check in ggml_can_repeat + const int nr0 = (int)(ne0/ne00); + const int nr1 = (int)(ne1/ne01); + const int nr2 = (int)(ne2/ne02); + const int nr3 = (int)(ne3/ne03); + + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // TODO: maybe this is not optimal? 
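+    // descriptive note: this tiles src0 across dst — for each repeat index (i3, i2, i1, i0)
+    // and source position (k3, k2, k1), one ne00-element fp16 row of src0 is copied verbatim
+    // into the corresponding offset of dst (see the element-wise copy in the innermost loop)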
+ for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne03; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne02; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne01; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0); + ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01); + // ggml_vec_cpy_f16(ne00, y, x) + for (int i = 0; i < ne00; ++i) { + y[i] = x[i]; + } + } + } + } + } + } + } + } +} + static void ggml_compute_forward_repeat( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_repeat_f16(params, src0, dst); + } break; case GGML_TYPE_F32: { ggml_compute_forward_repeat_f32(params, src0, dst); @@ -10128,7 +8383,7 @@ static void ggml_compute_forward_repeat_back_f32( return; } - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS // guaranteed to be an integer due to the check in ggml_can_repeat const int nr0 = (int)(ne00/ne0); @@ -10206,7 +8461,7 @@ static void ggml_compute_forward_concat_f32( const int ith = params->ith; - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS // TODO: support for transposed / permuted tensors GGML_ASSERT(nb0 == sizeof(float)); @@ -10702,7 +8957,7 @@ static void ggml_compute_forward_silu_f32( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k]; UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); @@ -10727,6 +8982,48 @@ static void ggml_compute_forward_silu( } } +// ggml_compute_forward_leaky + +static void ggml_compute_forward_leaky_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_leaky_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_leaky( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_leaky_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_silu_back static void ggml_compute_forward_silu_back_f32( @@ -10808,7 +9105,7 @@ static void ggml_compute_forward_norm_f32( const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -10877,7 +9174,7 @@ static void ggml_compute_forward_rms_norm_f32( const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -10942,7 +9239,7 @@ static void ggml_compute_forward_rms_norm_back_f32( const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_BINARY_OP_LOCALS; + 
GGML_TENSOR_BINARY_OP_LOCALS float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -11117,7 +9414,7 @@ static void ggml_compute_forward_group_norm_f32( const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS const float eps = 1e-6f; // TODO: make this a parameter @@ -11228,7 +9525,7 @@ static void ggml_compute_forward_mul_mat( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; @@ -11265,11 +9562,6 @@ static void ggml_compute_forward_mul_mat( #if defined(GGML_USE_CLBLAST) if (ggml_cl_can_mul_mat(src0, src1, dst)) { - // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension - // ref: https://github.com/ggerganov/ggml/pull/224 - GGML_ASSERT(ne02 == ne12); - GGML_ASSERT(ne03 == ne13); - if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) { ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize); } @@ -11430,36 +9722,175 @@ static void ggml_compute_forward_mul_mat( for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col); } - memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + } + } + } +} + +// ggml_compute_forward_out_prod + +static void ggml_compute_forward_out_prod_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + // int64_t t0 = ggml_perf_time_us(); + // UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == sizeof(float)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + // GGML_ASSERT(nb0 <= nb1); + // GGML_ASSERT(nb1 <= nb2); + // GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod + // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) + + if (params->type == GGML_TASK_INIT) { + ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // dst[:,:,:,:] = 0 + // for i2,i3: + // for i1: + // for i01: + // for i0: + // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + + // parallelize by last three dimensions + + // total rows in dst + const int64_t nr = ne1*ne2*ne3; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + // block-tiling attempt + const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32); + const int64_t blck_1 = 16; + + for (int64_t bir = ir0; bir < ir1; bir += blck_1) { + const int64_t bir1 = MIN(bir + blck_1, ir1); + for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) { + const int64_t bne01 = MIN(bi01 + blck_0, ne01); + for (int64_t ir = bir; ir < bir1; ++ir) { + // dst indices + const int64_t i3 = ir/(ne2*ne1); + const int64_t i2 = (ir - 
i3*ne2*ne1)/ne1; + const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int64_t i02 = i2; + const int64_t i03 = i3; + + //const int64_t i10 = i1; + const int64_t i12 = i2; + const int64_t i13 = i3; + +#if GGML_VEC_MAD_UNROLL > 2 + const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL); + for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1); + } + for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + ggml_vec_mad_f32(ne0, d, s0, *s1); + } +#else + for (int64_t i01 = bi01; i01 < bne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + ggml_vec_mad_f32(ne0, d, s0, *s1); + } +#endif } } } -} -// ggml_compute_forward_out_prod + //int64_t t1 = ggml_perf_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); -static void ggml_compute_forward_out_prod_f32( + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_out_prod_q_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); + // int64_t t0 = ggml_perf_time_us(); + // UNUSED(t0); GGML_TENSOR_BINARY_OP_LOCALS; const int ith = params->ith; const int nth = params->nth; + const enum ggml_type type = src0->type; + ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; + GGML_ASSERT(ne02 == ne12); GGML_ASSERT(ne03 == ne13); GGML_ASSERT(ne2 == ne12); GGML_ASSERT(ne3 == ne13); - // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == sizeof(float)); + // we don't support permuted src0 dim0 + GGML_ASSERT(nb00 == ggml_type_size(type)); - // dst cannot be transposed or permuted + // dst dim0 cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); // GGML_ASSERT(nb0 <= nb1); // GGML_ASSERT(nb1 <= nb2); @@ -11504,6 +9935,8 @@ static void ggml_compute_forward_out_prod_f32( // for i0: // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; + for (int64_t ir = ir0; ir < ir1; ++ir) { // dst indices const int64_t i3 = ir/(ne2*ne1); @@ -11524,10 +9957,8 @@ static void ggml_compute_forward_out_prod_f32( 
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - ggml_vec_mad_f32(ne0, d, s0, *s1); - // for (int64_t i0 = 0; i0 < ne0; ++i0) { - // d[i0] += s0[i0] * s1[i1]; - // } + dequantize_row_q(s0, wdata, ne0); + ggml_vec_mad_f32(ne0, d, wdata, *s1); } } @@ -11556,10 +9987,13 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: { - GGML_ASSERT(false); // todo - // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); + ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); } break; case GGML_TYPE_F16: { @@ -11613,7 +10047,6 @@ static void ggml_compute_forward_scale_f32( const size_t nb1 = dst->nb[1]; - for (int i1 = ir0; i1 < ir1; i1++) { if (dst->data != src0->data) { // src0 is same shape as dst => same indices @@ -11677,8 +10110,8 @@ static void ggml_compute_forward_set_f32( const int nr = ggml_nrows(src1); const int nc = src1->ne[0]; - GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); - GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) // src0 and dst as viewed during set const size_t nb0 = ggml_element_size(src0); @@ -11947,14 +10380,15 @@ static void ggml_compute_forward_get_rows_back_f32_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { GGML_ASSERT(params->ith == 0); - GGML_ASSERT(ggml_are_same_shape(opt0, dst)); - GGML_ASSERT(ggml_is_contiguous(opt0)); GGML_ASSERT(ggml_is_contiguous(dst)); - ggml_compute_forward_dup_same_cont(params, opt0, dst); + // ggml_compute_forward_dup_same_cont(params, opt0, dst); + + if (params->type == GGML_TASK_INIT) { + memset(dst->data, 0, ggml_nbytes(dst)); + } if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -11980,11 +10414,8 @@ static void ggml_compute_forward_get_rows_back_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { GGML_ASSERT(params->ith == 0); - GGML_ASSERT(ggml_are_same_shape(opt0, dst)); - GGML_ASSERT(ggml_is_contiguous(opt0)); GGML_ASSERT(ggml_is_contiguous(dst)); // ggml_compute_forward_dup_same_cont(params, opt0, dst); @@ -12013,21 +10444,19 @@ static void ggml_compute_forward_get_rows_back_f32( } } - static void ggml_compute_forward_get_rows_back( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst); + ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst); + ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst); } break; default: { @@ -12068,7 +10497,7 @@ static void ggml_compute_forward_diag_f32( // TODO: handle transposed/permuted matrices - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS GGML_ASSERT(ne00 == ne0); GGML_ASSERT(ne00 == ne1); @@ -12249,7 +10678,7 @@ static void 
ggml_compute_forward_soft_max_f32( // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max); ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max); memcpy(&scvt, &s, sizeof(scvt)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]); sum += (ggml_float)val; dp[i] = val; } @@ -12393,28 +10822,25 @@ static void ggml_compute_forward_alibi_f32( return; } - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - assert(n_past >= 0); - - const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 - const int ne1 = src0->ne[1]; // seq_len_without_past - const int ne2 = src0->ne[2]; // n_head -> this is k - //const int ne3 = src0->ne[3]; // 1 -> bsz + const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 + const int64_t ne1 = src0->ne[1]; // seq_len_without_past + const int64_t ne2 = src0->ne[2]; // n_head -> this is k + //const int64_t ne3 = src0->ne[3]; // 1 -> bsz - const int n = ggml_nrows(src0); - const int ne2_ne3 = n/ne1; // ne2*ne3 + const int64_t n = ggml_nrows(src0); + const int64_t ne2_ne3 = n/ne1; // ne2*ne3 - const int nb0 = src0->nb[0]; - const int nb1 = src0->nb[1]; - const int nb2 = src0->nb[2]; + const size_t nb0 = src0->nb[0]; + const size_t nb1 = src0->nb[1]; + const size_t nb2 = src0->nb[2]; //const int nb3 = src0->nb[3]; GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(ne1 + n_past == ne0); GGML_ASSERT(n_head == ne2); // add alibi to src0 (KQ_scaled) @@ -12423,9 +10849,9 @@ static void ggml_compute_forward_alibi_f32( const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - for (int i = 0; i < ne0; i++) { - for (int j = 0; j < ne1; j++) { - for (int k = 0; k < ne2_ne3; k++) { + for (int64_t i = 0; i < ne0; i++) { + for (int64_t j = 0; j < ne1; j++) { + for (int64_t k = 0; k < ne2_ne3; k++) { float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); @@ -12440,7 +10866,6 @@ static void ggml_compute_forward_alibi_f32( } pdst[0] = i * m_k + src[0]; - } } } @@ -12456,13 +10881,11 @@ static void ggml_compute_forward_alibi_f16( return; } - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - assert(n_past >= 0); - const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 const int ne1 = src0->ne[1]; // seq_len_without_past const int ne2 = src0->ne[2]; // n_head -> this is k @@ -12477,7 +10900,7 @@ static void ggml_compute_forward_alibi_f16( //const int nb3 = src0->nb[3]; GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; + //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; GGML_ASSERT(n_head == ne2); // add alibi to src0 (KQ_scaled) @@ -12620,34 +11043,76 @@ static void ggml_compute_forward_clamp( // ggml_compute_forward_rope +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / MAX(0.001f, high - low); + return 1 - MIN(1, MAX(0, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. 
Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +static void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + float * cos_theta, float * sin_theta +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + *cos_theta = cosf(theta) * mscale; + *sin_theta = sinf(theta) * mscale; +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) { + return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); +} + +void ggml_rope_yarn_corr_dims( + int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] +) { + // start and end correction dims + dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base))); + dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base))); +} + static void ggml_compute_forward_rope_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - float freq_base; - float freq_scale; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; // these two only relevant for xPos RoPE: float xpos_base; - bool xpos_down; + bool xpos_down; - const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; - assert(n_past >= 0); + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float)); + memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool)); - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -12673,29 +11138,34 @@ static void ggml_compute_forward_rope_f32( int ir = 0; const 
float theta_scale = powf(freq_base, -2.0f/n_dims); + const float inv_ndims = -1.f/n_dims; + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta = freq_scale * (float)p; + float theta_base = (float)p; if (is_glm) { - theta = MIN(p, n_ctx - 2); + theta_base = MIN(p, n_ctx - 2); float block_theta = MAX(p - (n_ctx - 2), 0); for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta_base); + const float sin_theta = sinf(theta_base); const float cos_block_theta = cosf(block_theta); const float sin_block_theta = sinf(block_theta); - theta *= theta_scale; + theta_base *= theta_scale; block_theta *= theta_scale; const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); @@ -12713,13 +11183,16 @@ static void ggml_compute_forward_rope_f32( } } else if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta + ); + // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; + float zeta = xpos_base != 0.0f ? 
powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; if (xpos_down) zeta = 1.0f / zeta; - theta *= theta_scale; + theta_base *= theta_scale; const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); @@ -12733,12 +11206,19 @@ static void ggml_compute_forward_rope_f32( } else { // TODO: this might be wrong for ne0 != n_dims - need double check // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + theta_base *= freq_scale; for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ic = 0; ic < n_dims; ic += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + // simplified from `(ib * n_dims + ic) * inv_ndims` + float cur_rot = inv_ndims * ic - ib; + + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, + &cos_theta, &sin_theta + ); - theta *= theta_scale; + theta_base *= theta_scale; const int64_t i0 = ib*n_dims + ic/2; @@ -12761,25 +11241,27 @@ static void ggml_compute_forward_rope_f32( static void ggml_compute_forward_rope_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - float freq_base; - float freq_scale; - - const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - assert(n_past >= 0); + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -12805,29 +11287,34 @@ static void ggml_compute_forward_rope_f16( int ir = 0; const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float inv_ndims = -1.f/n_dims; + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? 
n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta = freq_scale * (float)p; + float theta_base = (float)p; if (is_glm) { - theta = MIN(p, n_ctx - 2); + theta_base = MIN(p, n_ctx - 2); float block_theta = MAX(p - (n_ctx - 2), 0); for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta_base); + const float sin_theta = sinf(theta_base); const float cos_block_theta = cosf(block_theta); const float sin_block_theta = sinf(block_theta); - theta *= theta_scale; + theta_base *= theta_scale; block_theta *= theta_scale; const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); @@ -12843,12 +11330,14 @@ static void ggml_compute_forward_rope_f16( dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta); dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta); } - } if (!is_neox) { + } else if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta + ); - theta *= theta_scale; + theta_base *= theta_scale; const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); @@ -12862,12 +11351,19 @@ static void ggml_compute_forward_rope_f16( } else { // TODO: this might be wrong for ne0 != n_dims - need double check // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + theta_base *= freq_scale; for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ic = 0; ic < n_dims; ic += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + // simplified from `(ib * n_dims + ic) * inv_ndims` + float cur_rot = inv_ndims * ic - ib; - theta *= theta_scale; + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, + &cos_theta, &sin_theta + ); + + theta_base *= theta_scale; const int64_t i0 = ib*n_dims + ic/2; @@ -12890,15 +11386,16 @@ static void ggml_compute_forward_rope_f16( static void ggml_compute_forward_rope( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_rope_f16(params, src0, dst); + ggml_compute_forward_rope_f16(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_rope_f32(params, src0, dst); + ggml_compute_forward_rope_f32(params, src0, src1, dst); } break; default: { @@ -12912,6 +11409,7 @@ static void ggml_compute_forward_rope( static void ggml_compute_forward_rope_back_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { @@ -12929,7 +11427,7 @@ static void ggml_compute_forward_rope_back_f32( float xpos_base; bool xpos_down; - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) 
dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx); @@ -12938,9 +11436,7 @@ static void ggml_compute_forward_rope_back_f32( memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); - assert(n_past >= 0); - - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -12966,24 +11462,27 @@ static void ggml_compute_forward_rope_back_f32( const bool is_neox = mode & 2; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta = freq_scale * (float)p; + float theta_base = freq_scale * (float)p; if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta_base); + const float sin_theta = sinf(theta_base); + // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; + float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; if (xpos_down) zeta = 1.0f / zeta; - theta *= theta_scale; + theta_base *= theta_scale; const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); @@ -12997,10 +11496,10 @@ static void ggml_compute_forward_rope_back_f32( } else { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ic = 0; ic < n_dims; ic += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta_base); + const float sin_theta = sinf(theta_base); - theta *= theta_scale; + theta_base *= theta_scale; const int64_t i0 = ib*n_dims + ic/2; @@ -13023,6 +11522,7 @@ static void ggml_compute_forward_rope_back_f32( static void ggml_compute_forward_rope_back_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { @@ -13033,13 +11533,11 @@ static void ggml_compute_forward_rope_back_f16( // dx = rope_back(dy, src1) // src0 is dy, src1 contains options - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - assert(n_past >= 0); - - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -13065,21 +11563,23 @@ static void ggml_compute_forward_rope_back_f16( const bool is_neox = mode & 2; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? 
n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta = (float)p; + float theta_base = (float)p; if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta_base); + const float sin_theta = sinf(theta_base); - theta *= theta_scale; + theta_base *= theta_scale; const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); @@ -13093,10 +11593,10 @@ static void ggml_compute_forward_rope_back_f16( } else { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ic = 0; ic < n_dims; ic += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta_base); + const float sin_theta = sinf(theta_base); - theta *= theta_scale; + theta_base *= theta_scale; const int64_t i0 = ib*n_dims + ic/2; @@ -13119,15 +11619,16 @@ static void ggml_compute_forward_rope_back_f16( static void ggml_compute_forward_rope_back( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_rope_back_f16(params, src0, dst); + ggml_compute_forward_rope_back_f16(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_rope_back_f32(params, src0, dst); + ggml_compute_forward_rope_back_f32(params, src0, src1, dst); } break; default: { @@ -13138,7 +11639,7 @@ static void ggml_compute_forward_rope_back( // ggml_compute_forward_conv_1d -static void ggml_compute_forward_conv_1d_s1_ph_f16_f32( +static void ggml_compute_forward_conv_1d_f16_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -13150,48 +11651,39 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; const int nk = ne00; - const int nh = nk/2; - const int ew0 = ggml_up32(ne01); + // size of the convolution row - the kernel size unrolled across all input channels + const int ew0 = nk*ne01; + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; - GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); if (params->type == GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); - // prepare kernel data (src0) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); - ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; - } - } - } - } + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + 
i11*nb11); + ggml_fp16_t * dst_data = wdata; - // prepare source data (src1) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00; + for (int64_t i0 = 0; i0 < ne0; i0++) { + for (int64_t ik = 0; ik < nk; ik++) { + const int idx0 = i0*s0 + ik*d0 - p0; - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - ggml_fp16_t * dst_data = wdata; - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]); + if(!(idx0 < 0 || idx0 >= ne10)) { + dst_data[i0*ew0 + i11*nk + ik] = GGML_FP32_TO_FP16(src[idx0]); + } } } } @@ -13204,7 +11696,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32( } // total rows in dst - const int nr = ne02; + const int nr = ne2; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -13213,23 +11705,22 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int64_t i0 = 0; i0 < ne10; ++i0) { - dst_data[i0] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - ggml_vec_dot_f16(ew0, &v, - (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0] += v; + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int i2 = 0; i2 < ne2; i2++) { + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1); + + for (int i0 = 0; i0 < ne0; i0++) { + ggml_vec_dot_f16(ew0, dst_data + i0, + (ggml_fp16_t *) ((char *) src0->data + i1*nb02), + (ggml_fp16_t *) wdata + i2*nb2 + i0*ew0); } } } } -static void ggml_compute_forward_conv_1d_s1_ph_f32( +static void ggml_compute_forward_conv_1d_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -13241,52 +11732,229 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; const int nk = ne00; - const int nh = nk/2; - const int ew0 = ggml_up32(ne01); + const int ew0 = nk*ne01; + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; - GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb10 == sizeof(float)); if (params->type == GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); - // prepare kernel data (src0) - { - float * const wdata = (float *) params->wdata + 0; + float * const wdata = (float *) params->wdata + 0; - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); - float * dst_data = wdata + i02*ew0*ne00; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + float * dst_data = wdata; + + for (int64_t i0 = 0; i0 < ne0; i0++) { + for (int64_t ik = 0; ik < nk; ik++) { + const int idx0 = i0*s0 + ik*d0 - p0; + + if(!(idx0 < 0 || idx0 >= ne10)) { + 
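+                        // im2col gather: idx0 = i0*s0 + ik*d0 - p0 applies stride, dilation and padding;
+                        // taps that fall outside [0, ne10) are simply skipped and keep the zero written by
+                        // the memset above, so padded positions need no special handling later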
dst_data[i0*ew0 + i11*nk + ik] = src[idx0]; } } } } - // prepare source data (src1) - { - float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + return; + } - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - float * dst_data = wdata; - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * const wdata = (float *) params->wdata + 0; + + for (int i2 = 0; i2 < ne2; i2++) { + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1); + + for (int i0 = 0; i0 < ne0; i0++) { + ggml_vec_dot_f32(ew0, dst_data + i0, + (float *) ((char *) src0->data + i1*nb02), + (float *) wdata + i2*nb2 + i0*ew0); + } + } + } +} + +// TODO: reuse ggml_mul_mat or implement ggml_im2col and remove stage_0 and stage_1 +static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k, + ggml_fp16_t * A, + ggml_fp16_t * B, + float * C, + const int ith, const int nth) { + // does not seem to make a difference + int64_t m0, m1, n0, n1; + // patches per thread + if (m > n) { + n0 = 0; + n1 = n; + + // total patches in dst + const int np = m; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + m0 = dp*ith; + m1 = MIN(m0 + dp, np); + } else { + m0 = 0; + m1 = m; + + // total patches in dst + const int np = n; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + n0 = dp*ith; + n1 = MIN(n0 + dp, np); + } + + // block-tiling attempt + int64_t blck_n = 16; + int64_t blck_m = 16; + + // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB + // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(ggml_fp16_t) * K); + // if (blck_size > 0) { + // blck_0 = 4; + // blck_1 = blck_size / blck_0; + // if (blck_1 < 0) { + // blck_1 = 1; + // } + // // blck_0 = (int64_t)sqrt(blck_size); + // // blck_1 = blck_0; + // } + // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1); + + for (int j = n0; j < n1; j+=blck_n) { + for (int i = m0; i < m1; i+=blck_m) { + // printf("i j k => %d %d %d\n", i, j, K); + for (int ii = i; ii < i + blck_m && ii < m1; ii++) { + for (int jj = j; jj < j + blck_n && jj < n1; jj++) { + ggml_vec_dot_f16(k, + C + ii*n + jj, + A + ii * k, + B + jj * k); + } + } + } + } +} + +// src0: kernel [OC, IC, K] +// src1: signal [N, IC, IL] +// dst: result [N, OL, IC*K] +static void ggml_compute_forward_conv_1d_stage_0_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int64_t N = ne12; + const int64_t IC = ne11; + const int64_t IL = ne10; + + const int64_t K = ne00; + + const int64_t OL = ne1; + + const int ith = params->ith; + const int nth = params->nth; + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + 
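+    // both inputs must be contiguous in their innermost dimension (fp16 kernel, f32 signal);
+    // this stage unfolds src1 into rows of IC*K samples ([N, IC, IL] -> [N, OL, IC*K]) so that
+    // stage_1 can evaluate the convolution as one [OC, IC*K] x [N*OL, IC*K] GEMM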
GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + memset(dst->data, 0, ggml_nbytes(dst)); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // im2col: [N, IC, IL] => [N, OL, IC*K] + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t iol = 0; iol < OL; iol++) { + for (int64_t iic = ith; iic < IC; iic+=nth) { + + // micro kernel + ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K] + const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL] + + for (int64_t ik = 0; ik < K; ik++) { + const int64_t iil = iol*s0 + ik*d0 - p0; + + if (!(iil < 0 || iil >= IL)) { + dst_data[iic*K + ik] = GGML_FP32_TO_FP16(src_data[iil]); + } + } } } } + } +} + +// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] +// src0: [OC, IC, K] +// src1: [N, OL, IC * K] +// result: [N, OC, OL] +static void ggml_compute_forward_conv_1d_stage_1_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + if (params->type == GGML_TASK_INIT) { return; } @@ -13294,45 +11962,83 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32( return; } - // total rows in dst - const int nr = ne02; + GGML_TENSOR_BINARY_OP_LOCALS; - // rows per thread - const int dr = (nr + nth - 1)/nth; + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb0 == sizeof(float)); - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); + const int N = ne12; + const int OL = ne11; - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int64_t i0 = 0; i0 < ne10; ++i0) { - dst_data[i0] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - ggml_vec_dot_f32(ew0, &v, - (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0] += v; - } - } + const int OC = ne02; + const int IC = ne01; + const int K = ne00; + + const int ith = params->ith; + const int nth = params->nth; + + int64_t m = OC; + int64_t n = OL; + int64_t k = IC * K; + + // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] + for (int i = 0; i < N; i++) { + ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k] + ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k] + float * C = (float *)dst->data + i * m * n; // [m, n] + + gemm_f16_out_f32(m, n, k, A, B, C, ith, nth); } } -static void ggml_compute_forward_conv_1d_s1_ph( +static void ggml_compute_forward_conv_1d( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - switch (src0->type) { + switch(src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_conv_1d_s1_ph_f16_f32(params, src0, src1, dst); + ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_conv_1d_s1_ph_f32(params, src0, src1, dst); + ggml_compute_forward_conv_1d_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +static void ggml_compute_forward_conv_1d_stage_0( + const struct ggml_compute_params * params, + const struct ggml_tensor * 
src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch(src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +static void ggml_compute_forward_conv_1d_stage_1( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch(src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst); } break; default: { @@ -13341,7 +12047,9 @@ static void ggml_compute_forward_conv_1d_s1_ph( } } -static void ggml_compute_forward_conv_1d_s2_ph_f16_f32( +// ggml_compute_forward_conv_transpose_1d + +static void ggml_compute_forward_conv_transpose_1d_f16_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -13353,52 +12061,50 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; - const int nk = ne00; - const int nh = nk/2; - - const int ew0 = ggml_up32(ne01); + const int nk = ne00*ne01*ne02; - GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); if (params->type == GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); - // prepare kernel data (src0) + // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) { ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); - ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; + ggml_fp16_t * dst_data = wdata + i01*ne00*ne02; for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; + dst_data[i00*ne02 + i02] = src[i00]; } } } } - // prepare source data (src1) + // permute source data (src1) from (L x Cin) to (Cin x L) { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00; + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; + ggml_fp16_t * dst_data = wdata; for (int64_t i11 = 0; i11 < ne11; i11++) { const float * const src = (float *)((char *) src1->data + i11*nb11); - ggml_fp16_t * dst_data = wdata; for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]); + dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]); } } } + // need to zero dst since we are accumulating into it + memset(dst->data, 0, ggml_nbytes(dst)); + return; } @@ -13406,8 +12112,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32( return; } + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + // total rows in dst - const int nr = ne02; + const int nr = ne1; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -13416,23 +12124,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + ggml_fp16_t * const wdata_src = wdata + nk; + for (int i1 = ir0; i1 < ir1; i1++) { float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int64_t i0 = 0; i0 < ne10; i0 += 2) { - dst_data[i0/2] = 0; - for (int k 
= -nh; k <= nh; k++) { - float v = 0.0f; - ggml_vec_dot_f16(ew0, &v, - (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0/2] += v; + ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + ggml_vec_dot_f16(ne02, &v, + (ggml_fp16_t *) wdata_src + i1n, + (ggml_fp16_t *) wdata_kernel + i00*ne02); + dst_data[i10*s0 + i00] += v; } } } } -static void ggml_compute_forward_conv_1d_s2_ph_f32( +static void ggml_compute_forward_conv_transpose_1d_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -13444,34 +12155,29 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; - const int nk = ne00; - const int nh = nk/2; - - const int ew0 = ggml_up32(ne01); + const int nk = ne00*ne01*ne02; - GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb10 == sizeof(float)); if (params->type == GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); - // prepare kernel data (src0) + // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) { float * const wdata = (float *) params->wdata + 0; for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); - float * dst_data = wdata + i02*ew0*ne00; + float * dst_data = wdata + i01*ne00*ne02; for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; + dst_data[i00*ne02 + i02] = src[i00]; } } } @@ -13479,17 +12185,20 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32( // prepare source data (src1) { - float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + float * const wdata = (float *) params->wdata + nk; + float * dst_data = wdata; for (int64_t i11 = 0; i11 < ne11; i11++) { const float * const src = (float *)((char *) src1->data + i11*nb11); - float * dst_data = wdata; for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + dst_data[i10*ne11 + i11] = src[i10]; } } } + // need to zero dst since we are accumulating into it + memset(dst->data, 0, ggml_nbytes(dst)); + return; } @@ -13497,8 +12206,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32( return; } + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + // total rows in dst - const int nr = ne02; + const int nr = ne1; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -13507,23 +12218,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); + float * const wdata = (float *) params->wdata + 0; + float * const wdata_src = wdata + nk; + for (int i1 = ir0; i1 < ir1; i1++) { float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int64_t i0 = 0; i0 < ne10; i0 += 2) { - dst_data[i0/2] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - ggml_vec_dot_f32(ew0, &v, - (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0/2] += v; + float * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) 
{ + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + ggml_vec_dot_f32(ne02, &v, + wdata_src + i1n, + wdata_kernel + i00*ne02); + dst_data[i10*s0 + i00] += v; } } } } -static void ggml_compute_forward_conv_1d_s2_ph( +static void ggml_compute_forward_conv_transpose_1d( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -13531,11 +12245,11 @@ static void ggml_compute_forward_conv_1d_s2_ph( switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_conv_1d_s2_ph_f16_f32(params, src0, src1, dst); + ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_conv_1d_s2_ph_f32(params, src0, src1, dst); + ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst); } break; default: { @@ -13544,28 +12258,145 @@ static void ggml_compute_forward_conv_1d_s2_ph( } } -// ggml_compute_forward_conv_1d +// ggml_compute_forward_conv_2d -static void ggml_compute_forward_conv_1d( +// src0: kernel [OC, IC, KH, KW] +// src1: image [N, IC, IH, IW] +// dst: result [N, OH, OW, IC*KH*KW] +static void ggml_compute_forward_conv_2d_stage_0_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int64_t N = ne13; + const int64_t IC = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + // const int64_t OC = ne03; + // const int64_t IC = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OH = ne2; + const int64_t OW = ne1; + + const int ith = params->ith; + const int nth = params->nth; + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; - GGML_ASSERT(d0 == 1); // dilation not supported - GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported - if (s0 == 1) { - ggml_compute_forward_conv_1d_s1_ph(params, src0, src1, dst); - } else if (s0 == 2) { - ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst); - } else { - GGML_ASSERT(false); // only stride 1 and 2 supported - }; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + memset(dst->data, 0, ggml_nbytes(dst)); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic+=nth) { + + // micro kernel + ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW] + + for (int64_t ikh = 0; ikh < KH; ikh++) { + for (int64_t ikw = 0; 
ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + + if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]); + } + } + } + } + } + } + } + } } -// ggml_compute_forward_conv_2d +// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] +// src0: [OC, IC, KH, KW] +// src1: [N, OH, OW, IC * KH * KW] +// result: [N, OC, OH, OW] +static void ggml_compute_forward_conv_2d_stage_1_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_BINARY_OP_LOCALS; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb0 == sizeof(float)); + + const int N = ne13; + const int OH = ne12; + const int OW = ne11; + + const int OC = ne03; + const int IC = ne02; + const int KH = ne01; + const int KW = ne00; + + const int ith = params->ith; + const int nth = params->nth; + + int64_t m = OC; + int64_t n = OH * OW; + int64_t k = IC * KH * KW; + + // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] + for (int i = 0; i < N; i++) { + ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k] + ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k] + float * C = (float *)dst->data + i * m * n; // [m, n] + + gemm_f16_out_f32(m, n, k, A, B, C, ith, nth); + } +} static void ggml_compute_forward_conv_2d_f16_f32( const struct ggml_compute_params * params, @@ -13579,16 +12410,40 @@ static void ggml_compute_forward_conv_2d_f16_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS + + // src1: image [N, IC, IH, IW] + // src0: kernel [OC, IC, KH, KW] + // dst: result [N, OC, OH, OW] + // ne12: IC + // ne0: OW + // ne1: OH + // nk0: KW + // nk1: KH + // ne13: N + + const int N = ne13; + const int IC = ne12; + const int IH = ne11; + const int IW = ne10; + + const int OC = ne03; + // const int IC = ne02; + const int KH = ne01; + const int KW = ne00; + + const int OH = ne1; + const int OW = ne0; const int ith = params->ith; const int nth = params->nth; - const int nk0 = ne00; - const int nk1 = ne01; + // const int nk0 = ne00; + // const int nk1 = ne01; // size of the convolution row - the kernel size unrolled across all channels - const int ew0 = nk0*nk1*ne02; + // const int ew0 = nk0*nk1*ne02; + // ew0: IC*KH*KW const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; @@ -13604,23 +12459,28 @@ static void ggml_compute_forward_conv_2d_f16_f32( memset(params->wdata, 0, params->wsize); // prepare source data (src1) + // im2col: [N, IC, IH, IW] => [N*OH*OW, IC*KH*KW] + { ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - for (int i12 = 0; i12 < ne12; i12++) { - const float * const src = (float *)((char *) src1->data + i12*nb12); - ggml_fp16_t * dst_data = wdata; + for (int in = 0; in < N; in++) { + for (int iic = 0; iic < IC; iic++) { + for (int ioh = 0; ioh < OH; ioh++) { + for (int iow = 0; iow < OW; iow++) { - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) 
{ - for (int ik1 = 0; ik1 < nk1; ik1++) { - for (int ik0 = 0; ik0 < nk0; ik0++) { - const int idx0 = i0*s0 + ik0*d0 - p0; - const int idx1 = i1*s1 + ik1*d1 - p1; - - if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) { - dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] = - GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]); + // micro kernel + ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW] + + for (int ikh = 0; ikh < KH; ikh++) { + for (int ikw = 0; ikw < KW; ikw++) { + const int iiw = iow*s0 + ikw*d0 - p0; + const int iih = ioh*s1 + ikh*d1 - p1; + + if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]); + } } } } @@ -13636,30 +12496,22 @@ static void ggml_compute_forward_conv_2d_f16_f32( return; } - // total patches in dst - const int np = ne2; - - // patches per thread - const int dp = (np + nth - 1)/nth; + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + // wdata: [N*OH*OW, IC*KH*KW] + // dst: result [N, OC, OH, OW] + // src0: kernel [OC, IC, KH, KW] - // patch range for this thread - const int ip0 = dp*ith; - const int ip1 = MIN(ip0 + dp, np); + int64_t m = OC; + int64_t n = OH * OW; + int64_t k = IC * KH * KW; - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] + for (int i = 0; i < N; i++) { + ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k] + ggml_fp16_t * B = (ggml_fp16_t *)wdata + i * m * k; // [n, k] + float * C = (float *)dst->data + i * m * n; // [m * k] - for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = ip0; i2 < ip1; i2++) { - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2); - - for (int i1 = 0; i1 < ne1; ++i1) { - for (int i0 = 0; i0 < ne0; ++i0) { - ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0, - (ggml_fp16_t *) ((char *) src0->data + i2*nb03), - (ggml_fp16_t *) wdata + i3*nb3 + (i1*ne0 + i0)*ew0); - } - } - } + gemm_f16_out_f32(m, n, k, A, B, C, ith, nth); } } @@ -13685,6 +12537,48 @@ static void ggml_compute_forward_conv_2d( } } +static void ggml_compute_forward_conv_2d_stage_0( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_2d_stage_0_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(false); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +static void ggml_compute_forward_conv_2d_stage_1( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_2d_stage_1_f16(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(false); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_conv_transpose_2d static void ggml_compute_forward_conv_transpose_2d( @@ -13699,7 +12593,7 @@ static void ggml_compute_forward_conv_transpose_2d( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; @@ -13743,6 +12637,8 @@ static void ggml_compute_forward_conv_transpose_2d( } } + 
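+    // the transposed convolution accumulates partial sums into dst (just like the 1-D variant above),
+    // so the destination has to start out zeroed before the compute phase runs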
memset(dst->data, 0, ggml_nbytes(dst)); + return; } @@ -13855,14 +12751,11 @@ static void ggml_compute_forward_pool_1d( ggml_compute_forward_pool_1d_sk_p0(params, op, src0, k0, dst); } -// ggml_compute_forward_pool_2d_sk_p0 +// ggml_compute_forward_pool_2d -static void ggml_compute_forward_pool_2d_sk_p0( +static void ggml_compute_forward_pool_2d( const struct ggml_compute_params * params, - const enum ggml_op_pool op, const struct ggml_tensor * src, - const int k0, - const int k1, struct ggml_tensor * dst) { assert(src->type == GGML_TYPE_F32); assert(params->ith == 0); @@ -13871,6 +12764,14 @@ static void ggml_compute_forward_pool_2d_sk_p0( return; } + const int32_t * opts = (const int32_t *)dst->op_params; + enum ggml_op_pool op = opts[0]; + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; const char * cdata = (const char*)src->data; const char * const data_end = cdata + ggml_nbytes(src); @@ -13881,6 +12782,8 @@ static void ggml_compute_forward_pool_2d_sk_p0( float * dplane = (float *)dst->data; const int ka = k0 * k1; + const int offset0 = -p0; + const int offset1 = -p1; while (cdata < data_end) { for (int oy = 0; oy < py; ++oy) { @@ -13893,13 +12796,15 @@ static void ggml_compute_forward_pool_2d_sk_p0( case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; } - const int ix = ox * k0; - const int iy = oy * k1; + const int ix = offset0 + ox * s0; + const int iy = offset1 + oy * s1; for (int ky = 0; ky < k1; ++ky) { + if (iy + ky < 0 || iy + ky >= src->ne[1]) continue; const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky)); for (int kx = 0; kx < k0; ++kx) { int j = ix + kx; + if (j < 0 || j >= src->ne[0]) continue; switch (op) { case GGML_OP_POOL_AVG: *out += srow[j]; break; case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break; @@ -13916,31 +12821,8 @@ static void ggml_compute_forward_pool_2d_sk_p0( } cdata += src->nb[2]; - dplane += pa; - } -} - -// ggml_compute_forward_pool_2d - -static void ggml_compute_forward_pool_2d( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - - const int32_t * opts = (const int32_t *)dst->op_params; - enum ggml_op_pool op = opts[0]; - const int k0 = opts[1]; - const int k1 = opts[2]; - const int s0 = opts[3]; - const int s1 = opts[4]; - const int p0 = opts[5]; - const int p1 = opts[6]; - GGML_ASSERT(p0 == 0); - GGML_ASSERT(p1 == 0); // padding not supported - GGML_ASSERT(k0 == s0); - GGML_ASSERT(k1 == s1); // only s = k supported - - ggml_compute_forward_pool_2d_sk_p0(params, op, src0, k0, k1, dst); + dplane += pa; + } } // ggml_compute_forward_upscale @@ -13958,7 +12840,7 @@ static void ggml_compute_forward_upscale_f32( const int ith = params->ith; - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS const int scale_factor = dst->op_params[0]; @@ -14010,14 +12892,14 @@ static void ggml_compute_forward_flash_attn_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_LOCALS(int64_t, neq, q, ne); - GGML_TENSOR_LOCALS(size_t, nbq, q, nb); - GGML_TENSOR_LOCALS(int64_t, nek, k, ne); - GGML_TENSOR_LOCALS(size_t, nbk, k, nb); - GGML_TENSOR_LOCALS(int64_t, nev, v, ne); - GGML_TENSOR_LOCALS(size_t, nbv, v, nb); - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, 
nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const int ith = params->ith; const int nth = params->nth; @@ -14087,10 +12969,11 @@ static void ggml_compute_forward_flash_attn_f32( S[i] = -INFINITY; } - for (int64_t ic = 0; ic < nek1; ++ic) { + const int64_t masked_begin = masked ? (P + iq1 + 1) : M; + for (int64_t ic = 0; ic < masked_begin; ++ic) { // k indices const int ik3 = iq3; - const int ik2 = iq2; + const int ik2 = iq2 % nek2; const int ik1 = ic; // S indices @@ -14103,20 +12986,18 @@ static void ggml_compute_forward_flash_attn_f32( } // scale - ggml_vec_scale_f32(nek1, S, scale); + ggml_vec_scale_f32(masked_begin, S, scale); - if (masked) { - for (int64_t i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = -INFINITY; - } - } + for (int64_t i = masked_begin; i < M; i++) { + S[i] = -INFINITY; } // softmax + // exclude known -INF S[..] values from max and loop + // dont forget to set their SW values to zero { float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); + ggml_vec_max_f32(masked_begin, &max, S); ggml_float sum = 0.0; { @@ -14130,10 +13011,15 @@ static void ggml_compute_forward_flash_attn_f32( ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + if (i >= masked_begin) { + break; + } float * SS = S + i; for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (SS[j] == -INFINITY) { + if (i + j >= masked_begin) { + break; + } else if (SS[j] == -INFINITY) { SS[j] = 0.0f; } else { #ifndef GGML_FLASH_ATTN_EXP_FP16 @@ -14141,7 +13027,7 @@ static void ggml_compute_forward_flash_attn_f32( #else ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); #endif sump[j] += (ggml_float)val; SS[j] = val; @@ -14158,10 +13044,10 @@ static void ggml_compute_forward_flash_attn_f32( assert(sum > 0.0); sum = 1.0/sum; - ggml_vec_scale_f32(M, S, sum); + ggml_vec_scale_f32(masked_begin, S, sum); #ifndef NDEBUG - for (int i = 0; i < M; ++i) { + for (int i = 0; i < masked_begin; ++i) { assert(!isnan(S[i])); assert(!isinf(S[i])); } @@ -14174,9 +13060,13 @@ static void ggml_compute_forward_flash_attn_f32( const int i2 = iq2; const int i3 = iq3; - ggml_vec_dot_f32(nek1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f32(masked_begin, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), S); } } @@ -14192,14 +13082,14 @@ static void ggml_compute_forward_flash_attn_f16( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_LOCALS(int64_t, neq, q, ne); - GGML_TENSOR_LOCALS(size_t, nbq, q, nb); - GGML_TENSOR_LOCALS(int64_t, nek, k, ne); - GGML_TENSOR_LOCALS(size_t, nbk, k, nb); - GGML_TENSOR_LOCALS(int64_t, nev, v, ne); - GGML_TENSOR_LOCALS(size_t, nbv, v, nb); - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + 
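+    // note: as in the f32 path above, the k/v heads are indexed with iq2 % nek2 and iq2 % nev2 below,
+    // so a smaller number of k/v heads can be broadcast across the query heads (multi-/grouped-query attention)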
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const int ith = params->ith; const int nth = params->nth; @@ -14273,7 +13163,7 @@ static void ggml_compute_forward_flash_attn_f16( for (int64_t ic = 0; ic < nek1; ++ic) { // k indices const int ik3 = iq3; - const int ik2 = iq2; + const int ik2 = iq2 % nek2; const int ik1 = ic; // S indices @@ -14288,7 +13178,7 @@ static void ggml_compute_forward_flash_attn_f16( for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { // k indices const int ik3 = iq3; - const int ik2 = iq2; + const int ik2 = iq2 % nek2; const int ik1 = ic; // S indices @@ -14313,6 +13203,8 @@ static void ggml_compute_forward_flash_attn_f16( } // softmax + // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. + // dont forget to set their S values to zero { float max = -INFINITY; ggml_vec_max_f32(M, &max, S); @@ -14337,7 +13229,7 @@ static void ggml_compute_forward_flash_attn_f16( } else { ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); sump[j] += (ggml_float)val; SS[j] = val; } @@ -14369,6 +13261,7 @@ static void ggml_compute_forward_flash_attn_f16( S16[i] = GGML_FP32_TO_FP16(S[i]); } + // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { for (int64_t ic = 0; ic < nev1; ++ic) { // dst indices @@ -14376,9 +13269,13 @@ static void ggml_compute_forward_flash_attn_f16( const int i2 = iq2; const int i3 = iq3; - ggml_vec_dot_f16(nek1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16(nev0, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), S16); } } else { @@ -14388,9 +13285,13 @@ static void ggml_compute_forward_flash_attn_f16( const int i2 = iq2; const int i3 = iq3; - ggml_vec_dot_f16_unroll(nek1, nbv1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16_unroll(nev0, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), S16); } } @@ -14433,18 +13334,18 @@ static void ggml_compute_forward_flash_ff_f16( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_LOCALS(int64_t, nea, a, ne); - GGML_TENSOR_LOCALS(size_t, nba, a, nb); - GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne); - GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb); - GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne); - GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb); - GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne); - GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb); - GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne); - GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb); - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + GGML_TENSOR_LOCALS(int64_t, nea, a, ne) + GGML_TENSOR_LOCALS(size_t, nba, a, nb) + GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne) + GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb) + GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne) + GGML_TENSOR_LOCALS(size_t, nbb1, 
b1, nb) + GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne) + GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb) + GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne) + GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const int ith = params->ith; const int nth = params->nth; @@ -14592,16 +13493,16 @@ static void ggml_compute_forward_flash_attn_back_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_LOCALS(int64_t, neq, q, ne); - GGML_TENSOR_LOCALS(size_t, nbq, q, nb); - GGML_TENSOR_LOCALS(int64_t, nek, k, ne); - GGML_TENSOR_LOCALS(size_t, nbk, k, nb); - GGML_TENSOR_LOCALS(int64_t, nev, v, ne); - GGML_TENSOR_LOCALS(size_t, nbv, v, nb); - GGML_TENSOR_LOCALS(int64_t, ned, d, ne); - GGML_TENSOR_LOCALS(size_t, nbd, d, nb); - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ned, d, ne) + GGML_TENSOR_LOCALS(size_t, nbd, d, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const int ith = params->ith; const int nth = params->nth; @@ -14649,10 +13550,37 @@ static void ggml_compute_forward_flash_attn_back_f32( return; } - // parallelize by q rows using ggml_vec_dot_f32 + const int64_t elem_q = ggml_nelements(q); + const int64_t elem_k = ggml_nelements(k); - // total rows in q - const int nr = neq2*neq3; + enum ggml_type result_type = dst->type; + GGML_ASSERT(ggml_blck_size(result_type) == 1); + const size_t tsize = ggml_type_size(result_type); + + const size_t offs_q = 0; + const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); + const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); + + void * grad_q = (char *) dst->data; + void * grad_k = (char *) dst->data + offs_k; + void * grad_v = (char *) dst->data + offs_v; + + const size_t nbgq1 = nb0*neq0; + const size_t nbgq2 = nb0*neq0*neq1; + const size_t nbgq3 = nb0*neq0*neq1*neq2; + + const size_t nbgk1 = nb0*nek0; + const size_t nbgk2 = nb0*nek0*nek1; + const size_t nbgk3 = nb0*nek0*nek1*neq2; + + const size_t nbgv1 = nb0*nev0; + const size_t nbgv2 = nb0*nev0*nev1; + const size_t nbgv3 = nb0*nev0*nev1*neq2; + + // parallelize by k rows using ggml_vec_dot_f32 + + // total rows in k + const int nr = nek2*nek3; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -14665,268 +13593,243 @@ static void ggml_compute_forward_flash_attn_back_f32( //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + // how often k2 (and v2) is repeated in q2 + int nrep = neq2/nek2; + for (int ir = ir0; ir < ir1; ++ir) { // q indices - const int iq3 = ir/(neq2); - const int iq2 = ir - iq3*neq2; - for ( int iq1 = 0; iq1 < neq1; ++iq1) { + const int ik3 = ir/(nek2); + const int ik2 = ir - ik3*nek2; + const int iq3 = ik3; + const int id3 = ik3; + const int iv3 = ik3; + const int iv2 = ik2; - // not sure about CACHE_LINE_SIZE_F32.. - // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset? 
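+        // the backward pass is parallelized over k rows (nr = nek2*nek3); the irep loop below walks the
+        // nrep = neq2/nek2 query heads that share each k/v head, which keeps the per-thread writes to
+        // grad_q/grad_k/grad_v on disjoint rows; the three gradients are packed into dst at
+        // offs_q/offs_k/offs_v, aligned with GGML_PAD(..., GGML_MEM_ALIGN)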
- float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32); - float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32); + for (int irep = 0; irep < nrep; ++irep) { + const int iq2 = ik2 + irep*nek2; + const int id2 = iq2; - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } + // (ik2 + irep*nek2) % nek2 == ik2 + for (int iq1 = 0; iq1 < neq1; ++iq1) { + const int id1 = iq1; - for (int64_t ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2; - const int ik1 = ic; + // not sure about CACHE_LINE_SIZE_F32.. + // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset? + float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32); + float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32); - // S indices - const int i1 = ik1; + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } - ggml_vec_dot_f32(neq0, - S + i1, - (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } + const int64_t masked_begin = masked ? (P + iq1 + 1) : M; + for (int64_t ic = 0; ic < masked_begin; ++ic) { + // k indices + const int ik1 = ic; - // scale - ggml_vec_scale_f32(nek1, S, scale); + // S indices + const int i1 = ik1; - if (masked) { - for (int64_t i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = -INFINITY; - } + ggml_vec_dot_f32(neq0, + S + i1, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); } - } - // softmax - { - float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); + // scale + ggml_vec_scale_f32(masked_begin, S, scale); - ggml_float sum = 0.0; + for (int64_t i = masked_begin; i < M; i++) { + S[i] = -INFINITY; + } + + // softmax + // exclude known -INF S[..] 
values from max and loop + // dont forget to set their SM values to zero { + float max = -INFINITY; + ggml_vec_max_f32(masked_begin, &max, S); + + ggml_float sum = 0.0; + { #ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(SM, 1, &max, SM, 1, Mup); - vvexpf(SM, SM, &Mup); - ggml_vec_sum_f32(Mup, &sum, SM); + max = -max; + vDSP_vsadd(SM, 1, &max, SM, 1, Mup); + vvexpf(SM, SM, &Mup); + ggml_vec_sum_f32(Mup, &sum, SM); #else - uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt); - ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; - - for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { - float * SR = S + i; - float * SW = SM + i; + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt); + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; - for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (SR[j] == -INFINITY) { - SW[j] = 0.0f; - } else { + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + if (i >= masked_begin) { + break; + } + float * SR = S + i; + float * SW = SM + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (i + j >= masked_begin) { + break; + } else if (SR[j] == -INFINITY) { + SW[j] = 0.0f; + } else { #ifndef GGML_FLASH_ATTN_EXP_FP16 - const float val = expf(SR[j] - max); + const float val = expf(SR[j] - max); #else - ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); #endif - sump[j] += (ggml_float)val; - SW[j] = val; + sump[j] += (ggml_float)val; + SW[j] = val; + } } } - } - for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } #endif - } - - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_vec_scale_f32(M, SM, sum); - - } - - // step-by-step explanation - { - // forward-process shape grads from backward process - // parallel_for iq2,iq3: - // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur] - // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur] - // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur] - // for iq1: - // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur - // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur - // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4 - // S0 = -Inf [D,1,1,1] - // ~S1[i] = dot(kcur[:D,i], qcur) - // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale - // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P) - // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur - // ~S5[i] = dot(vcur[:,i], S4) - // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3] - // ~dst[i,iq1,iq2,iq3] = S5[i] ^ - // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3] - // dst backward-/ grad[dst] = d - // - // output gradients with their dependencies: - // - // grad[kcur] = grad[S1].T @ qcur - // grad[S1] = diag_mask_zero(grad[S3], P) * scale - // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // grad[S4] = grad[S5] @ vcur - // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur - // grad[qcur] = grad[S1] @ kcur - // grad[vcur] = grad[S5].T @ S4 - // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4 - // - // in post-order: - // - // S1 = qcur @ kcur.T - // S2 = S1 * scale - // S3 = diag_mask_inf(S2, P) - // S4 = softmax(S3) - 
// grad[S4] = d[:D,iq1,iq2,iq3] @ vcur - // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // grad[S1] = diag_mask_zero(grad[S3], P) * scale - // grad[qcur] = grad[S1] @ kcur - // grad[kcur] = grad[S1].T @ qcur - // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4 - // - // using less variables (SM=S4): - // - // S = diag_mask_inf(qcur @ kcur.T * scale, P) - // SM = softmax(S) - // S = d[:D,iq1,iq2,iq3] @ vcur - // dot_SM_gradSM = dot(SM, S) - // S = SM * (S - dot(SM, S)) - // S = diag_mask_zero(S, P) * scale - // - // grad[q][:D,iq1,iq2,iq3] += S @ kcur - // grad[k][:D,:M,iq2,iq3] += S.T @ qcur - // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM - } - - // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur - // S = d[:D,iq1,iq2,iq3] @ vcur - // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3] - ggml_vec_set_f32(M, S, 0); - for (int64_t ic = 0; ic < D; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + } - ggml_vec_mad_f32(M, - S, - (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), - *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3))); - } + assert(sum > 0.0); - // S = SM * (S - dot(SM, S)) - float dot_SM_gradSM = 0; - ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S); - ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); - ggml_vec_mul_f32 (M, S, S, SM); + sum = 1.0/sum; + ggml_vec_scale_f32(masked_begin, SM, sum); - // S = diag_mask_zero(S, P) * scale - if (masked) { - // for (int64_t i = P + iq1 + 1; i < M; i++) { - // S[i] = 0; - // } - for (int64_t i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = 0; - } } - } - ggml_vec_scale_f32(M, S, scale); - - void * grad_q = (char *) dst->data; - void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3; - void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3; - - const size_t nbgq1 = nb0*neq0; - const size_t nbgq2 = nb0*neq0*neq1; - const size_t nbgq3 = nb0*neq0*neq1*neq2; - - const size_t nbgk1 = nb0*nek0; - const size_t nbgk2 = nb0*nek0*nek1; - const size_t nbgk3 = nb0*nek0*nek1*neq2; - - const size_t nbgv1 = nb0*nev0; - const size_t nbgv2 = nb0*nev0*nev1; - const size_t nbgv3 = nb0*nev0*nev1*neq2; - - // S shape [M,1] - // SM shape [M,1] - // kcur shape [D,M] - // qcur shape [D,1] - // vcur shape [M,D] - // - // grad[q][:D,iq1,iq2,iq3] += S @ kcur - // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] - // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic] - // - //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T) - //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T) - for (int64_t ic = 0; ic < M; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - ggml_vec_mad_f32(D, - (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)), - (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)), - S[ic]); - } + // step-by-step explanation + { + // forward-process shape grads from backward process + // parallel_for ik2,ik3: + // for irep: + // iq2 = ik2 + irep*nek2 + // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur] + // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur] + // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur] + // for iq1: + // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur + // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur + // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4 + // S0 = -Inf [D,1,1,1] + // ~S1[i] = dot(kcur[:D,i], qcur) + // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale + // S2 = S1 * scale [M,1,1,1] grad[S2] = 
diag_mask_zero(grad[S3], P) + // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur + // ~S5[i] = dot(vcur[:,i], S4) + // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3] + // ~dst[i,iq1,iq2,iq3] = S5[i] ^ + // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3] + // dst backward-/ grad[dst] = d + // + // output gradients with their dependencies: + // + // grad[kcur] = grad[S1].T @ qcur + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S4] = grad[S5] @ vcur + // grad[S4] = d[:D,id1,id2,id3] @ vcur + // grad[qcur] = grad[S1] @ kcur + // grad[vcur] = grad[S5].T @ S4 + // grad[vcur] = d[:D,id1,id2,id3].T @ S4 + // + // in post-order: + // + // S1 = qcur @ kcur.T + // S2 = S1 * scale + // S3 = diag_mask_inf(S2, P) + // S4 = softmax(S3) + // grad[S4] = d[:D,id1,id2,id3] @ vcur + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[qcur] = grad[S1] @ kcur + // grad[kcur] = grad[S1].T @ qcur + // grad[vcur] = d[:D,id1,id2,id3].T @ S4 + // + // using less variables (SM=S4): + // + // S = diag_mask_inf(qcur @ kcur.T * scale, P) + // SM = softmax(S) + // S = d[:D,iq1,iq2,iq3] @ vcur + // dot_SM_gradSM = dot(SM, S) + // S = SM * (S - dot(SM, S)) + // S = diag_mask_zero(S, P) * scale + // + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[k][:D,:M,ik2,ik3] += S.T @ qcur + // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM + } - // grad[k][:D,:M,iq2,iq3] += S.T @ qcur - // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] - // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] - for (int64_t ic = 0; ic < M; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] + // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] + // for ic: + // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3] + // exclude known future zero S[..] values from operation + ggml_vec_set_f32(masked_begin, S, 0); + for (int64_t ic = 0; ic < D; ++ic) { + ggml_vec_mad_f32(masked_begin, + S, + (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); + } - // ggml_vec_set_f32(D, - // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)), - // 0); - ggml_vec_mad_f32(D, - (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)), - (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)), - S[ic]); - } + // S = SM * (S - dot(SM, S)) + float dot_SM_gradSM = 0; + ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S); + ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); + ggml_vec_mul_f32 (masked_begin, S, S, SM); + + // S = diag_mask_zero(S, P) * scale + // already done by above ggml_vec_set_f32 + + // exclude known zero S[..] values from operation + ggml_vec_scale_f32(masked_begin, S, scale); + + // S shape [M,1] + // SM shape [M,1] + // kcur shape [D,M] + // qcur shape [D,1] + // vcur shape [M,D] + + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] + // for ic: + // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3] + // exclude known zero S[..] 
values from loop + for (int64_t ic = 0; ic < masked_begin; ++ic) { + ggml_vec_mad_f32(D, + (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)), + (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)), + S[ic]); + } - // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM - // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M] - // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M] - for (int64_t ic = 0; ic < D; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + // grad[k][:D,:M,iq2,iq3] += S.T @ qcur + // for ic: + // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] + // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] + // exclude known zero S[..] values from loop + for (int64_t ic = 0; ic < masked_begin; ++ic) { + ggml_vec_mad_f32(D, + (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), + S[ic]); + } - // ggml_vec_set_f32(M, - // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)), - // 0); - ggml_vec_mad_f32(M, - (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)), - SM, - *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3))); + // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM + // for ic: + // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M] + // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M] + // exclude known zero SM[..] values from mad + for (int64_t ic = 0; ic < D; ++ic) { + ggml_vec_mad_f32(masked_begin, + (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)), + SM, + *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); + } } } } @@ -14962,8 +13865,8 @@ static void ggml_compute_forward_win_part_f32( return; } - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) const int32_t nep0 = ((const int32_t *)(dst->op_params))[0]; const int32_t nep1 = ((const int32_t *)(dst->op_params))[1]; @@ -15024,8 +13927,8 @@ static void ggml_compute_forward_win_unpart_f32( return; } - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) const int32_t w = ((const int32_t *)(dst->op_params))[0]; @@ -15123,6 +14026,10 @@ static void ggml_compute_forward_unary( { ggml_compute_forward_silu(params, src0, dst); } break; + case GGML_UNARY_OP_LEAKY: + { + ggml_compute_forward_leaky(params, src0, dst); + } break; default: { GGML_ASSERT(false); @@ -15142,7 +14049,7 @@ static void ggml_compute_forward_get_rel_pos_f16( // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS const int64_t w = ne1; @@ -15220,7 +14127,6 @@ static void ggml_compute_forward_add_rel_pos_f32( const int ip0 = dp*ith; const int ip1 = MIN(ip0 + dp, np); - for (int64_t i13 = ip0; i13 < ip1; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i11 = 0; i11 < ne11; ++i11) { @@ -15287,7 +14193,6 @@ static void ggml_compute_forward_map_unary_f32( } } - static void ggml_compute_forward_map_unary( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -15335,7 +14240,6 @@ static void ggml_compute_forward_map_binary_f32( } } - static void ggml_compute_forward_map_binary( const struct 
ggml_compute_params * params, const struct ggml_tensor * src0, @@ -15387,7 +14291,6 @@ static void ggml_compute_forward_map_custom2_f32( fun(dst, a, b); } - // ggml_compute_forward_map_custom3 static void ggml_compute_forward_map_custom3_f32( @@ -15531,7 +14434,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32( #else ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max); memcpy(&scvt, &s, sizeof(scvt)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]); #endif sum += (ggml_float)val; st[i] = val; @@ -15645,7 +14548,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( #else ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max); memcpy(&scvt, &s, sizeof(scvt)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]); #endif sum += (ggml_float)val; ds0[i] = val; @@ -15662,7 +14565,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( ggml_vec_sub_f32(nc, ds0, ds0, s1); ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr); - #ifndef NDEBUG for (int i = 0; i < nc; ++i) { assert(!isnan(ds0[i])); @@ -15690,12 +14592,15 @@ static void ggml_compute_forward_cross_entropy_loss_back( } } - ///////////////////////////////// static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { GGML_ASSERT(params); + if (tensor->op == GGML_OP_NONE) { + return; + } + #ifdef GGML_USE_CUBLAS bool skip_cpu = ggml_cuda_compute_forward(params, tensor); if (skip_cpu) { @@ -15840,7 +14745,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_GET_ROWS_BACK: { - ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_DIAG: { @@ -15864,11 +14769,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_ROPE: { - ggml_compute_forward_rope(params, tensor->src[0], tensor); + ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_ROPE_BACK: { - ggml_compute_forward_rope_back(params, tensor->src[0], tensor); + ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_ALIBI: { @@ -15882,10 +14787,30 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor); } break; + case GGML_OP_CONV_1D_STAGE_0: + { + ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_CONV_1D_STAGE_1: + { + ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor); + } break; case GGML_OP_CONV_2D: { ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor); } break; + case GGML_OP_CONV_2D_STAGE_0: + { + ggml_compute_forward_conv_2d_stage_0(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_CONV_2D_STAGE_1: + { + ggml_compute_forward_conv_2d_stage_1(params, tensor->src[0], tensor->src[1], tensor); + } break; case GGML_OP_CONV_TRANSPOSE_2D: { ggml_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor); @@ -16013,7 +14938,265 @@ static 
void ggml_compute_forward(struct ggml_compute_params * params, struct ggm //////////////////////////////////////////////////////////////////////////////// -static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) { +static size_t ggml_hash_size(size_t min_sz) { + // next primes after powers of two + static const size_t primes[] = { + 2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031, + 2053, 4099, 8209, 16411, 32771, 65537, 131101, + 262147, 524309, 1048583, 2097169, 4194319, 8388617, + 16777259, 33554467, 67108879, 134217757, 268435459, + 536870923, 1073741827, 2147483659 + }; + static const size_t n_primes = sizeof(primes)/sizeof(primes[0]); + + // find the smallest prime that is larger or equal to min_sz + size_t l = 0; + size_t r = n_primes; + while (l < r) { + size_t m = (l + r)/2; + if (primes[m] < min_sz) { + l = m + 1; + } else { + r = m; + } + } + size_t sz = l < n_primes ? primes[l] : min_sz | 1; + return sz; +} + +static size_t ggml_hash(const void * p) { + return (size_t)p; +} + +size_t ggml_hash_find(const struct ggml_hash_set hash_set, struct ggml_tensor * key) { + size_t h = ggml_hash(key) % hash_set.size; + + // linear probing + size_t i = h; + while (hash_set.keys[i] != NULL && hash_set.keys[i] != key) { + i = (i + 1) % hash_set.size; + if (i == h) { + // visited all hash table entries -> not found + return GGML_HASHTABLE_FULL; + } + } + return i; +} + +bool ggml_hash_contains(struct ggml_hash_set hash_set, struct ggml_tensor * key) { + size_t i = ggml_hash_find(hash_set, key); + return i != GGML_HASHTABLE_FULL && hash_set.keys[i] == key; +} + +size_t ggml_hash_insert(struct ggml_hash_set hash_set, struct ggml_tensor * key) { + size_t i = ggml_hash_find(hash_set, key); + + GGML_ASSERT(i != GGML_HASHTABLE_FULL); + + if (hash_set.keys[i] == key) { + return GGML_HASHTABLE_ALREADY_EXISTS; + } + + // insert + GGML_ASSERT(hash_set.keys[i] == NULL); + hash_set.keys[i] = key; + return i; +} + +size_t ggml_hash_find_or_insert(struct ggml_hash_set hash_set, struct ggml_tensor * key) { + size_t i = ggml_hash_find(hash_set, key); + + GGML_ASSERT(i != GGML_HASHTABLE_FULL); + + hash_set.keys[i] = key; + return i; +} + +static struct ggml_hash_set ggml_hash_set_new(size_t size) { + size = ggml_hash_size(size); + struct ggml_hash_set result; + result.size = size; + result.keys = malloc(sizeof(struct ggml_tensor *) * size); + memset(result.keys, 0, sizeof(struct ggml_tensor *) * size); + return result; +} + +static void ggml_hash_set_free(struct ggml_hash_set hash_set) { + free(hash_set.keys); +} + +struct hash_map { + struct ggml_hash_set set; + struct ggml_tensor ** vals; +}; + +static struct hash_map * ggml_new_hash_map(size_t size) { + struct hash_map * result = malloc(sizeof(struct hash_map)); + result->set = ggml_hash_set_new(size); + result->vals = malloc(sizeof(struct ggml_tensor *) * result->set.size); + memset(result->vals, 0, sizeof(struct ggml_tensor *) * result->set.size); + return result; +} + +static void ggml_hash_map_free(struct hash_map * map) { + ggml_hash_set_free(map->set); + free(map->vals); + free(map); +} + +// gradient checkpointing + +static struct ggml_tensor * ggml_recompute_graph_node( + struct ggml_context * ctx, + struct ggml_cgraph * graph, + struct hash_map * replacements, + struct ggml_tensor * node) { + + if (node == NULL) { + return NULL; + } + + if (node->is_param) { + return node; + } + + if (!ggml_hash_contains(graph->visited_hash_table, node)) { + return node; + } + + int count_children = 0; + for (int k = 0; 
k < GGML_MAX_SRC; ++k) { + if (node->src[k]) { + ++count_children; + } + } + + if (count_children == 0) { + return node; + } + + size_t i = ggml_hash_find(replacements->set, node); + GGML_ASSERT(i != GGML_HASHTABLE_FULL); // assert that not full + if (replacements->set.keys[i] == node) { + return replacements->vals[i]; + } + + struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne); + + // insert clone into replacements + GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite + replacements->set.keys[i] = node; + replacements->vals[i] = clone; + + clone->op = node->op; + clone->grad = node->grad; + clone->is_param = node->is_param; + clone->extra = node->extra; + for (int k = 0; k < GGML_MAX_DIMS; ++k) { + clone->nb[k] = node->nb[k]; + } + for (int k = 0; k < GGML_MAX_SRC; ++k) { + clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]); + } + if (node->view_src != NULL) { + clone->data = (node->view_src->data == NULL) + ? NULL // view_src not yet allocated + : (char *) node->view_src->data // view_src already allocated + + node->view_offs; + clone->view_src = node->view_src; + clone->view_offs = node->view_offs; + } + + GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t))); + GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME); + memcpy(clone->op_params, node->op_params, sizeof(node->op_params)); + ggml_format_name(clone, "%s (clone)", ggml_get_name(node)); + + return clone; +} + +void ggml_build_backward_gradient_checkpointing( + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + struct ggml_cgraph * gb_tmp, + struct ggml_tensor * * checkpoints, + int n_checkpoints) { + ggml_graph_cpy(gf, gb_tmp); + ggml_build_backward_expand(ctx, gf, gb_tmp, true); + + if (n_checkpoints <= 0) { + ggml_graph_cpy(gb_tmp, gb); + return; + } + + struct hash_map * replacements = ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints); + + // insert checkpoints in replacements + for (int i = 0; i < n_checkpoints; ++i) { + size_t k = ggml_hash_find(replacements->set, checkpoints[i]); + GGML_ASSERT(k != GGML_HASHTABLE_FULL); // assert that not full + GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite + replacements->set.keys[k] = checkpoints[i]; + replacements->vals[k] = checkpoints[i]; + } + + ggml_graph_cpy(gf, gb); + // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes], + // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]), + // by recomputing them from checkpoints + for (int i = gf->n_nodes; in_nodes; ++i) { + struct ggml_tensor * node = gb_tmp->nodes[i]; + for (int k = 0; k < GGML_MAX_SRC; ++k) { + // insert new tensors recomputing src, reusing already made replacements, + // remember replacements: remember new tensors with mapping from corresponding gf nodes + // recurse for input tensors, + // unless (i.e. 
terminating when) input tensors are replacments (like checkpoints) + node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]); + } + // insert rewritten backward node with replacements made into resulting backward graph gb + ggml_build_forward_expand(gb, node); + } + + ggml_hash_map_free(replacements); +} + +// functions to change gradients considering the case that input a might be initial gradient with zero value + +static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) { + if (ggml_hash_contains(zero_table, a)) { + return b; + } else { + return ggml_add_impl(ctx, a, b, false); + } +} + +static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set zero_table) { + if (ggml_hash_contains(zero_table, a)) { + struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0)); + return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); + } else { + return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); + } +} + +static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) { + if (ggml_hash_contains(zero_table, a)) { + return ggml_repeat(ctx, b, a); + } else { + return ggml_add1_impl(ctx, a, b, false); + } +} + +static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) { + if (ggml_hash_contains(zero_table, a)) { + return ggml_neg(ctx, b); + } else { + return ggml_sub_impl(ctx, a, b, false); + } +} + +static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) { struct ggml_tensor * src0 = tensor->src[0]; struct ggml_tensor * src1 = tensor->src[1]; @@ -16021,34 +15204,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_OP_DUP: { if (src0->grad) { - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } } break; case GGML_OP_ADD: { if (src0->grad) { - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { - src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace); + src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table); } } break; case GGML_OP_ADD1: { if (src0->grad) { - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { - src1->grad = ggml_add_impl(ctx, + src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean - inplace); + zero_table); } } break; case GGML_OP_ACC: { if (src0->grad) { - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { const size_t nb1 = ((int32_t *) tensor->op_params)[0]; @@ -16065,117 +15248,117 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor nb1, nb2, nb3, offset); src1->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src1->grad, ggml_reshape(ctx, ggml_cont(ctx, 
tensor_grad_view), src1->grad), - inplace); + zero_table); } } break; case GGML_OP_SUB: { if (src0->grad) { - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { - src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace); + src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table); } } break; case GGML_OP_MUL: { if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_mul(ctx, src1, tensor->grad), - inplace); + zero_table); } if (src1->grad) { src1->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src1->grad, ggml_mul(ctx, src0, tensor->grad), - inplace); + zero_table); } } break; case GGML_OP_DIV: { if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_div(ctx, tensor->grad, src1), - inplace); + zero_table); } if (src1->grad) { src1->grad = - ggml_sub_impl(ctx, + ggml_sub_or_set(ctx, src1->grad, ggml_mul(ctx, tensor->grad, ggml_div(ctx, tensor, src1)), - inplace); + zero_table); } } break; case GGML_OP_SQR: { if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_scale(ctx, ggml_mul(ctx, src0, tensor->grad), ggml_new_f32(ctx, 2.0f)), - inplace); + zero_table); } } break; case GGML_OP_SQRT: { if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_scale(ctx, ggml_div(ctx, tensor->grad, tensor), ggml_new_f32(ctx, 0.5f)), - inplace); + zero_table); } } break; case GGML_OP_LOG: { if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_div(ctx, tensor->grad, src0), - inplace); + zero_table); } } break; case GGML_OP_SUM: { if (src0->grad) { src0->grad = - ggml_add1_impl(ctx, + ggml_add1_or_set(ctx, src0->grad, tensor->grad, - inplace); + zero_table); } } break; case GGML_OP_SUM_ROWS: { if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_repeat(ctx, tensor->grad, src0->grad), - inplace); + zero_table); } } break; case GGML_OP_MEAN: @@ -16187,20 +15370,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_repeat_back(ctx, tensor->grad, src0->grad), - inplace); + zero_table); } } break; case GGML_OP_REPEAT_BACK: { if (src0->grad) { // TODO: test this - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_repeat(ctx, tensor->grad, src0->grad), - inplace); + zero_table); } } break; case GGML_OP_CONCAT: @@ -16222,10 +15405,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor float eps; memcpy(&eps, tensor->op_params, sizeof(float)); - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_rms_norm_back(ctx, src0, tensor->grad, eps), - inplace); + zero_table); } } break; case GGML_OP_RMS_NORM_BACK: @@ -16249,37 +15432,49 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix // ds1 = t.T.dot(dt) - // tensor.shape [m,p] - // src0.shape [n,m] - // src1.shape [n,p] + // tensor.shape [m,p,qq,rr] + // src0.shape [n,m,q1,r1] + // src1.shape [n,p,qq,rr] // necessary for llama if (src0->grad) { + struct ggml_tensor * s1_tg = + ggml_out_prod(ctx, // [n,m,qq,rr] + src1, // [n,p,qq,rr] + tensor->grad); // [m,p,qq,rr] + 
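// example with hypothetical shapes: if a shared weight src0 is [64,128,1,1],
// src1 is [64,32,8,1] and tensor is [128,32,8,1], then
// s1_tg = ggml_out_prod(src1, tensor->grad) has shape [64,128,8,1];
// since qq = 8 > q1 = 1, the 8 broadcast repetitions are summed back into
// [64,128,1,1] by ggml_repeat_back below before accumulating into src0->grad.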
const int64_t qq = s1_tg->ne[2]; + const int64_t rr = s1_tg->ne[3]; + const int64_t q1 = src0->ne[2]; + const int64_t r1 = src0->ne[3]; + const bool ne2_broadcasted = qq > q1; + const bool ne3_broadcasted = rr > r1; + if (ne2_broadcasted || ne3_broadcasted) { + // sum broadcast repetitions of s1_tg into shape of src0 + s1_tg = ggml_repeat_back(ctx, s1_tg, src0); + } src0->grad = - ggml_add_impl(ctx, - src0->grad, - ggml_out_prod(ctx, // [n,m] - src1, // [n,p] - tensor->grad), // [m,p] - inplace); + ggml_add_or_set(ctx, + src0->grad, // [n,m,q1,r1] + s1_tg, // [n,m,q1,r1] + zero_table); } if (src1->grad) { src1->grad = - ggml_add_impl(ctx, - src1->grad, - // ggml_mul_mat(ctx, // [n,p] - // ggml_cont(ctx, // [m,n] - // ggml_transpose(ctx, src0)), // [m,n] - // tensor->grad), // [m,p] + ggml_add_or_set(ctx, + src1->grad, // [n,p,qq,rr] + // ggml_mul_mat(ctx, // [n,p,qq,rr] + // ggml_cont(ctx, // [m,n,q1,r1] + // ggml_transpose(ctx, src0)), // [m,n,q1,r1] + // tensor->grad), // [m,p,qq,rr] // // when src0 is bigger than tensor->grad (this is mostly the case in llama), // // avoid transpose of src0, rather transpose smaller tensor->grad // // and then use ggml_out_prod - ggml_out_prod(ctx, // [n,p] - src0, // [n,m] - ggml_transpose(ctx, // [p,m] - tensor->grad)), // [m,p] - inplace); + ggml_out_prod(ctx, // [n,p,qq,rr] + src0, // [n,m,q1,r1] + ggml_transpose(ctx, // [p,m,qq,rr] + tensor->grad)), // [m,p,qq,rr] + zero_table); } } break; case GGML_OP_OUT_PROD: @@ -16291,17 +15486,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_scale_impl(ctx, tensor->grad, src1, false), - inplace); + zero_table); } if (src1->grad) { src1->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src1->grad, ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)), - inplace); + zero_table); } } break; case GGML_OP_SET: @@ -16328,23 +15523,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } if (src0->grad) { - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_acc_impl(ctx, tensor->grad, ggml_neg(ctx, tensor_grad_view), nb1, nb2, nb3, offset, false), - inplace); + zero_table); } if (src1->grad) { src1->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src1->grad, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1->grad), - inplace); + zero_table); } } break; case GGML_OP_CPY: @@ -16355,7 +15550,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // tensor = src0 * 1 + src1 * 0 if (src0->grad) { // dsrc0 = dtensor * 1 - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { // dsrc1 = dtensor * 0 -> noop @@ -16367,7 +15562,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor if (src0->grad) { GGML_ASSERT(ggml_is_contiguous(src0->grad)); GGML_ASSERT(ggml_is_contiguous(tensor->grad)); - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } } break; case GGML_OP_RESHAPE: @@ -16375,9 +15570,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama if (src0->grad) { src0->grad = - ggml_add_impl(ctx, src0->grad, - ggml_reshape(ctx, tensor->grad, src0->grad), - inplace); + 
ggml_add_or_set(ctx, src0->grad, + ggml_reshape(ctx, + ggml_is_contiguous(tensor->grad) + ? tensor->grad + : ggml_cont(ctx, tensor->grad), + src0->grad), + zero_table); } } break; case GGML_OP_VIEW: @@ -16406,7 +15605,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor nb3 = (nb3 / n0) * ng; } - src0->grad = ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace); + src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table); } } break; case GGML_OP_PERMUTE: @@ -16424,14 +15623,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor axes_backward[axis2] = 2; axes_backward[axis3] = 3; src0->grad = - ggml_add_impl(ctx, src0->grad, + ggml_add_or_set(ctx, src0->grad, ggml_permute(ctx, tensor->grad, axes_backward[0], axes_backward[1], axes_backward[2], axes_backward[3]), - inplace); + zero_table); } } break; case GGML_OP_TRANSPOSE: @@ -16439,9 +15638,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama if (src0->grad) { src0->grad = - ggml_add_impl(ctx, src0->grad, + ggml_add_or_set(ctx, src0->grad, ggml_transpose(ctx, tensor->grad), - inplace); + zero_table); } } break; case GGML_OP_GET_ROWS: @@ -16449,9 +15648,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama (only for tokenizer) if (src0->grad) { src0->grad = - ggml_add_impl(ctx, src0->grad, + ggml_add_or_set(ctx, src0->grad, + // last ggml_get_rows_back argument src0->grad is only + // necessary to setup correct output shape ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad), - inplace); + zero_table); } if (src1->grad) { // noop @@ -16471,9 +15672,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor if (src0->grad) { const int n_past = ((int32_t *) tensor->op_params)[0]; src0->grad = - ggml_add_impl(ctx, src0->grad, + ggml_add_or_set(ctx, src0->grad, ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - inplace); + zero_table); } } break; case GGML_OP_DIAG_MASK_ZERO: @@ -16482,9 +15683,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor if (src0->grad) { const int n_past = ((int32_t *) tensor->op_params)[0]; src0->grad = - ggml_add_impl(ctx, src0->grad, + ggml_add_or_set(ctx, src0->grad, ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - inplace); + zero_table); } } break; case GGML_OP_SOFT_MAX: @@ -16492,9 +15693,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama if (src0->grad) { src0->grad = - ggml_add_impl(ctx, src0->grad, + ggml_add_or_set(ctx, src0->grad, ggml_soft_max_back(ctx, tensor->grad, tensor), - inplace); + zero_table); } } break; @@ -16506,7 +15707,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; + //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; const int n_ctx = ((int32_t *) tensor->op_params)[3]; @@ -16519,11 +15720,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float)); memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool)); - src0->grad = ggml_add_impl(ctx, + src0->grad = 
ggml_add_or_set(ctx, src0->grad, ggml_rope_back(ctx, tensor->grad, - n_past, + src1, n_dims, mode, n_ctx, @@ -16531,13 +15732,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor freq_scale, xpos_base, xpos_down), - inplace); + zero_table); } } break; case GGML_OP_ROPE_BACK: { if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; + //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; const int n_ctx = ((int32_t *) tensor->op_params)[3]; @@ -16550,20 +15751,25 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float)); memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool)); - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_rope_impl(ctx, tensor->grad, - n_past, + src1, n_dims, mode, + 0, n_ctx, freq_base, freq_scale, + 0.0f, + 1.0f, + 0.0f, + 0.0f, xpos_base, xpos_down, false), - inplace); + zero_table); } } break; case GGML_OP_ALIBI: @@ -16578,10 +15784,30 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_CONV_1D_STAGE_0: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CONV_1D_STAGE_1: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_CONV_2D: { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_CONV_2D_STAGE_0: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CONV_2D_STAGE_1: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_CONV_TRANSPOSE_2D: { GGML_ASSERT(false); // TODO: not implemented @@ -16614,145 +15840,42 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor masked); } - if (src0->grad) { - struct ggml_tensor * grad_q = NULL; - const size_t nb0 = flash_grad->nb[0]; - const size_t offset = 0; - switch(src0->n_dims) { - case 2: - { - grad_q = ggml_view_2d(ctx, - flash_grad, - src0->ne[0], - src0->ne[1], - nb0*src0->ne[0], - offset); - } break; - case 3: - { - grad_q = ggml_view_3d(ctx, - flash_grad, - src0->ne[0], - src0->ne[1], - src0->ne[2], - nb0*src0->ne[0], - nb0*src0->ne[0]*src0->ne[1], - offset); - } break; - case 4: - { - grad_q = ggml_view_4d(ctx, - flash_grad, - src0->ne[0], - src0->ne[1], - src0->ne[2], - src0->ne[3], - nb0*src0->ne[0], - nb0*src0->ne[0]*src0->ne[1], - nb0*src0->ne[0]*src0->ne[1]*src0->ne[2], - offset); - } break; - } + struct ggml_tensor * src2 = tensor->src[2]; + const int64_t elem_q = ggml_nelements(src0); + const int64_t elem_k = ggml_nelements(src1); + const int64_t elem_v = ggml_nelements(src2); + + enum ggml_type result_type = flash_grad->type; + GGML_ASSERT(ggml_blck_size(result_type) == 1); + const size_t tsize = ggml_type_size(result_type); - src0->grad = ggml_add_impl(ctx, + const size_t offs_q = 0; + const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); + const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); + + if (src0->grad) { + struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q); + struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0); + src0->grad = ggml_add_or_set(ctx, src0->grad, grad_q, - inplace); + zero_table); } - 
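// layout sketch with hypothetical sizes, assuming GGML_MEM_ALIGN == 16: the q, k and v
// gradients are packed back to back in flash_grad, each slice padded to GGML_MEM_ALIGN.
// e.g. for f32 gradients (tsize = 4) with elem_q = 4096 and elem_k = 8192:
// offs_q = 0, offs_k = GGML_PAD(4096*4, 16) = 16384, offs_v = 16384 + GGML_PAD(8192*4, 16) = 49152;
// each ggml_view_1d/ggml_reshape pair then recovers one gradient slice in the shape of its input.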
- if (src1->grad) { - struct ggml_tensor * grad_k = NULL; - const size_t nb0 = flash_grad->nb[0]; - const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]; - switch(src1->n_dims) { - case 2: - { - grad_k = ggml_view_2d(ctx, - flash_grad, - src1->ne[0], - src1->ne[1], - nb0*src1->ne[0], - offset); - } break; - case 3: - { - grad_k = ggml_view_3d(ctx, - flash_grad, - src1->ne[0], - src1->ne[1], - src1->ne[2], - nb0*src1->ne[0], - nb0*src1->ne[0]*src1->ne[1], - offset); - } break; - case 4: - { - grad_k = ggml_view_4d(ctx, - flash_grad, - src1->ne[0], - src1->ne[1], - src1->ne[2], - src1->ne[3], - nb0*src1->ne[0], - nb0*src1->ne[0]*src1->ne[1], - nb0*src1->ne[0]*src1->ne[1]*src1->ne[2], - offset); - } break; - } - - src1->grad = ggml_add_impl(ctx, + if (src1->grad) { + struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k); + struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1); + src1->grad = ggml_add_or_set(ctx, src1->grad, grad_k, - inplace); + zero_table); } - - struct ggml_tensor * opt0 = tensor->src[2]; - - if (opt0->grad) { - struct ggml_tensor * grad_v = NULL; - const size_t nb0 = flash_grad->nb[0]; - const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3] - + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3]; - switch(opt0->n_dims) { - case 2: - { - grad_v = ggml_view_2d(ctx, - flash_grad, - opt0->ne[0], - opt0->ne[1], - nb0*opt0->ne[0], - offset); - } break; - case 3: - { - grad_v = ggml_view_3d(ctx, - flash_grad, - opt0->ne[0], - opt0->ne[1], - opt0->ne[2], - nb0*opt0->ne[0], - nb0*opt0->ne[0]*opt0->ne[1], - offset); - } break; - case 4: - { - grad_v = ggml_view_4d(ctx, - flash_grad, - opt0->ne[0], - opt0->ne[1], - opt0->ne[2], - opt0->ne[3], - nb0*opt0->ne[0], - nb0*opt0->ne[0]*opt0->ne[1], - nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2], - offset); - } break; - } - - opt0->grad = ggml_add_impl(ctx, - opt0->grad, + if (src2->grad) { + struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v); + struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2); + src2->grad = ggml_add_or_set(ctx, + src2->grad, grad_v, - inplace); + zero_table); } } break; case GGML_OP_FLASH_FF: @@ -16772,12 +15895,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_mul(ctx, ggml_sgn(ctx, src0), tensor->grad), - inplace); + zero_table); } } break; case GGML_UNARY_OP_SGN: @@ -16789,7 +15912,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_UNARY_OP_NEG: { if (src0->grad) { - src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table); } } break; case GGML_UNARY_OP_STEP: @@ -16809,12 +15932,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_UNARY_OP_RELU: { if (src0->grad) { - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_mul(ctx, ggml_step(ctx, src0), tensor->grad), - inplace); + zero_table); } } break; case GGML_UNARY_OP_GELU: @@ -16829,10 +15952,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_silu_back(ctx, src0, tensor->grad), - inplace); + zero_table); } } break; default: @@ -16855,13 +15978,13 @@ static void 
ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_OP_CROSS_ENTROPY_LOSS: { if (src0->grad) { - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_cross_entropy_loss_back(ctx, src0, src1, tensor->grad), - inplace); + zero_table); } } break; case GGML_OP_CROSS_ENTROPY_LOSS_BACK: @@ -16877,34 +16000,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor GGML_ASSERT(false); } break; } -} - -static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small"); - -static size_t hash(void * p) { - return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE; -} - -static bool hash_insert(void * hash_table[], void * p) { - size_t h = hash(p); - // linear probing - size_t i = h; - while (hash_table[i] != NULL && hash_table[i] != p) { - i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE; - if (i == h) { - // hash table is full - GGML_ASSERT(false); + for (int i = 0; i < GGML_MAX_SRC; ++i) { + if (tensor->src[i] && tensor->src[i]->grad) { + GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad)); } } - - if (hash_table[i] == p) { - return true; - } - - // insert - hash_table[i] = p; - return false; } static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { @@ -16917,19 +16018,23 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * } // check if already visited - if (hash_insert(cgraph->visited_hash_table, node)) { + if (ggml_hash_insert(cgraph->visited_hash_table, node) == GGML_HASHTABLE_ALREADY_EXISTS) { return; } for (int i = 0; i < GGML_MAX_SRC; ++i) { - if (node->src[i]) { - ggml_visit_parents(cgraph, node->src[i]); + const int k = + (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i : + (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) : + /* unknown order, just fall back to using i*/ i; + if (node->src[k]) { + ggml_visit_parents(cgraph, node->src[k]); } } if (node->op == GGML_OP_NONE && node->grad == NULL) { // reached a leaf node, not part of the gradient graph (e.g. 
a constant) - GGML_ASSERT(cgraph->n_leafs < GGML_MAX_NODES); + GGML_ASSERT(cgraph->n_leafs < cgraph->size); if (strlen(node->name) == 0) { ggml_format_name(node, "leaf_%d", cgraph->n_leafs); @@ -16938,22 +16043,24 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * cgraph->leafs[cgraph->n_leafs] = node; cgraph->n_leafs++; } else { - GGML_ASSERT(cgraph->n_nodes < GGML_MAX_NODES); + GGML_ASSERT(cgraph->n_nodes < cgraph->size); if (strlen(node->name) == 0) { ggml_format_name(node, "node_%d", cgraph->n_nodes); } cgraph->nodes[cgraph->n_nodes] = node; - cgraph->grads[cgraph->n_nodes] = node->grad; + if (cgraph->grads) { + cgraph->grads[cgraph->n_nodes] = node->grad; + } cgraph->n_nodes++; } } static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) { if (!expand) { - cgraph->n_nodes = 0; - cgraph->n_leafs = 0; + // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand + ggml_graph_clear(cgraph); } const int n0 = cgraph->n_nodes; @@ -16974,24 +16081,6 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * ggml_build_forward_impl(cgraph, tensor, true); } -struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) { - struct ggml_cgraph result = { - /*.n_nodes =*/ 0, - /*.n_leafs =*/ 0, - /*.nodes =*/ { NULL }, - /*.grads =*/ { NULL }, - /*.leafs =*/ { NULL }, - /*.hash_table =*/ { NULL }, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, - }; - - ggml_build_forward_impl(&result, tensor, false); - - return result; -} - void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) { GGML_ASSERT(gf->n_nodes > 0); @@ -17007,12 +16096,21 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * } } + // remember original gradients which start with zero values + struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size); + for (int i = 0; i < gf->n_nodes; i++) { + if (gf->grads[i]) { + ggml_hash_insert(zero_table, gf->grads[i]); + } + } + for (int i = gf->n_nodes - 1; i >= 0; i--) { struct ggml_tensor * node = gf->nodes[i]; - // because we detached the grad nodes from the original graph, we can afford inplace operations + // inplace operations to add gradients are not created by ggml_compute_backward + // use allocator to automatically make inplace operations if (node->grad) { - ggml_compute_backward(ctx, node, keep); + ggml_compute_backward(ctx, node, zero_table); } } @@ -17024,25 +16122,56 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * ggml_build_forward_expand(gb, node->grad); } } + + ggml_hash_set_free(zero_table); } -struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) { - struct ggml_cgraph result = *gf; - ggml_build_backward_expand(ctx, gf, &result, keep); - return result; +static size_t ggml_graph_nbytes(size_t size, bool grads) { + size_t nbytes = sizeof(struct ggml_cgraph); + nbytes += size * sizeof(struct ggml_tensor *) * 2; // leafs + nodes + if (grads) { + nbytes += size * sizeof(struct ggml_tensor *); // grads + } + nbytes += ggml_hash_size(size * 2) * sizeof(struct ggml_tensor *); // hash set + return nbytes; } -struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) { - struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, GGML_GRAPH_SIZE); +size_t ggml_graph_overhead_custom(size_t size, bool grads) { + return GGML_OBJECT_SIZE 
+ GGML_PAD(ggml_graph_nbytes(size, grads), GGML_MEM_ALIGN); +} + +size_t ggml_graph_overhead(void) { + return ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, false); +} + +struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads) { + const size_t obj_size = ggml_graph_nbytes(size, grads); + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, obj_size); struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs); + struct ggml_tensor ** data_start = (struct ggml_tensor **) (cgraph + 1); + + size_t hash_size = ggml_hash_size(size * 2); + struct ggml_tensor ** nodes_ptr = data_start; + struct ggml_tensor ** leafs_ptr = nodes_ptr + size; + struct ggml_tensor ** hash_keys_ptr = leafs_ptr + size; + struct ggml_tensor ** grads_ptr = grads ? hash_keys_ptr + hash_size : NULL; + + // check that we allocated the correct amount of memory + assert(obj_size == (size_t) ( + (grads ? (char *)(grads_ptr + size) : (char *)(hash_keys_ptr + hash_size)) - (char *)cgraph)); + + memset(hash_keys_ptr, 0, hash_size * sizeof(struct ggml_tensor *)); + *cgraph = (struct ggml_cgraph) { + /*.size =*/ size, /*.n_nodes =*/ 0, /*.n_leafs =*/ 0, - /*.nodes =*/ { NULL }, - /*.grads =*/ { NULL }, - /*.leafs =*/ { NULL }, - /*.hash_table =*/ { NULL }, + /*.nodes =*/ nodes_ptr, + /*.grads =*/ grads_ptr, + /*.leafs =*/ leafs_ptr, + /*.hash_table =*/ { hash_size, hash_keys_ptr }, + /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, /*.perf_runs =*/ 0, /*.perf_cycles =*/ 0, /*.perf_time_us =*/ 0, @@ -17051,14 +16180,85 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) { return cgraph; } -struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor) { - struct ggml_cgraph * cgraph = ggml_new_graph(ctx); - ggml_build_forward_impl(cgraph, tensor, false); +struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) { + return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false); +} + +struct ggml_cgraph * ggml_graph_view(struct ggml_context * ctx, struct ggml_cgraph * cgraph0, int i0, int i1) { + const size_t obj_size = sizeof(struct ggml_cgraph); + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, obj_size); + struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs); + + *cgraph = (struct ggml_cgraph) { + /*.size =*/ 0, + /*.n_nodes =*/ i1 - i0, + /*.n_leafs =*/ 0, + /*.nodes =*/ cgraph0->nodes + i0, + /*.grads =*/ cgraph0->grads ? 
cgraph0->grads + i0 : NULL, + /*.leafs =*/ NULL, + /*.hash_table =*/ { 0, NULL }, + /*.order =*/ cgraph0->order, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + }; + return cgraph; } -size_t ggml_graph_overhead(void) { - return GGML_OBJECT_SIZE + GGML_PAD(GGML_GRAPH_SIZE, GGML_MEM_ALIGN); +void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { + GGML_ASSERT(dst->size >= src->n_leafs); + GGML_ASSERT(dst->size >= src->n_nodes); + GGML_ASSERT(dst->visited_hash_table.size >= src->visited_hash_table.size); + + dst->n_leafs = src->n_leafs; + dst->n_nodes = src->n_nodes; + dst->order = src->order; + + for (int i = 0; i < src->n_leafs; ++i) { + dst->leafs[i] = src->leafs[i]; + } + + for (int i = 0; i < src->n_nodes; ++i) { + dst->nodes[i] = src->nodes[i]; + } + + if (src->grads) { + GGML_ASSERT(dst->grads != NULL); + for (int i = 0; i < src->n_nodes; ++i) { + dst->grads[i] = src->grads[i]; + } + } + + for (size_t i = 0; i < src->visited_hash_table.size; ++i) { + if (src->visited_hash_table.keys[i]) { + ggml_hash_insert(dst->visited_hash_table, src->visited_hash_table.keys[i]); + } + } +} + +struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { + struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL); + ggml_graph_cpy(cgraph, result); + return result; +} + +void ggml_graph_reset(struct ggml_cgraph * cgraph) { + GGML_ASSERT(cgraph->grads != NULL); + + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * grad = cgraph->grads[i]; + + if (grad) { + ggml_set_zero(grad); + } + } +} + +void ggml_graph_clear(struct ggml_cgraph * cgraph) { + cgraph->n_leafs = 0; + cgraph->n_nodes = 0; + memset(cgraph->visited_hash_table.keys, 0, cgraph->visited_hash_table.size * sizeof(struct ggml_tensor *)); } // @@ -17202,13 +16402,253 @@ struct ggml_compute_state { struct ggml_compute_state_shared * shared; }; -static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) { - int64_t cycles_cur = ggml_perf_cycles() - st->perf_node_start_cycles; - int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us; +static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) { + int64_t cycles_cur = ggml_perf_cycles() - st->perf_node_start_cycles; + int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us; + + node->perf_runs++; + node->perf_cycles += cycles_cur; + node->perf_time_us += time_us_cur; +} + +static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { + int n_tasks = 0; + + switch (node->op) { + case GGML_OP_CPY: + case GGML_OP_DUP: + case GGML_OP_ADD: + case GGML_OP_ADD1: + case GGML_OP_ACC: + { + n_tasks = n_threads; + } break; + case GGML_OP_SUB: + case GGML_OP_DIV: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_LOG: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + { + n_tasks = 1; + } break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_LEAKY: + { + n_tasks = 1; + } break; + + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + { + n_tasks = n_threads; + } break; + } + break; + 
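// example: with n_threads = 8, an elementwise op such as GGML_OP_ADD gets n_tasks = 8,
// while reductions like GGML_OP_SUM stay single threaded (n_tasks = 1), per the cases above.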
case GGML_OP_SILU_BACK: + case GGML_OP_MUL: + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_GROUP_NORM: + case GGML_OP_CONCAT: + { + n_tasks = n_threads; + } break; + case GGML_OP_MUL_MAT: + { + n_tasks = n_threads; + + // TODO: use different scheduling for different matrix sizes + //const int nr0 = ggml_nrows(node->src[0]); + //const int nr1 = ggml_nrows(node->src[1]); + + //n_tasks = MIN(n_threads, MAX(1, nr0/128)); + //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks); + +#if defined(GGML_USE_CUBLAS) + if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + } +#elif defined(GGML_USE_CLBLAST) + if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + } +#endif +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + } +#endif + } break; + case GGML_OP_OUT_PROD: + { + n_tasks = n_threads; + } break; + case GGML_OP_SCALE: + case GGML_OP_SET: + case GGML_OP_CONT: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_GET_ROWS: + case GGML_OP_GET_ROWS_BACK: + case GGML_OP_DIAG: + { + n_tasks = 1; + } break; + case GGML_OP_DIAG_MASK_ZERO: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_ADD_REL_POS: + { + n_tasks = n_threads; + } break; + case GGML_OP_ALIBI: + { + n_tasks = 1; //TODO + } break; + case GGML_OP_CLAMP: + { + n_tasks = 1; //TODO + } break; + case GGML_OP_CONV_1D: + { + n_tasks = n_threads; + } break; + case GGML_OP_CONV_1D_STAGE_0: + { + n_tasks = n_threads; + } break; + case GGML_OP_CONV_1D_STAGE_1: + { + n_tasks = n_threads; + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + n_tasks = n_threads; + } break; + case GGML_OP_CONV_2D: + { + n_tasks = n_threads; + } break; + case GGML_OP_CONV_2D_STAGE_0: + { + n_tasks = n_threads; + } break; + case GGML_OP_CONV_2D_STAGE_1: + { + n_tasks = n_threads; + } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + n_tasks = n_threads; + } break; + case GGML_OP_POOL_1D: + case GGML_OP_POOL_2D: + { + n_tasks = 1; + } break; + case GGML_OP_UPSCALE: + { + n_tasks = n_threads; + } break; + case GGML_OP_FLASH_ATTN: + { + n_tasks = n_threads; + } break; + case GGML_OP_FLASH_FF: + { + n_tasks = n_threads; + } break; + case GGML_OP_FLASH_ATTN_BACK: + { + n_tasks = n_threads; + } break; + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: + case GGML_OP_GET_REL_POS: + case GGML_OP_MAP_UNARY: + case GGML_OP_MAP_BINARY: + case GGML_OP_MAP_CUSTOM1_F32: + case GGML_OP_MAP_CUSTOM2_F32: + case GGML_OP_MAP_CUSTOM3_F32: + { + n_tasks = 1; + } break; + case GGML_OP_MAP_CUSTOM1: + { + struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) node->op_params; + if (p->n_tasks == GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p->n_tasks, n_threads); + } + } break; + case GGML_OP_MAP_CUSTOM2: + { + struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) node->op_params; + if (p->n_tasks == GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p->n_tasks, n_threads); + } + } break; + case 
GGML_OP_MAP_CUSTOM3: + { + struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) node->op_params; + if (p->n_tasks == GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p->n_tasks, n_threads); + } + } break; + case GGML_OP_CROSS_ENTROPY_LOSS: + { + n_tasks = n_threads; + } break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + n_tasks = n_threads; + } break; + case GGML_OP_NONE: + { + n_tasks = 1; + } break; + case GGML_OP_COUNT: + { + GGML_ASSERT(false); + } break; + default: + { + GGML_ASSERT(false); + } break; + } + + assert(n_tasks > 0); - node->perf_runs++; - node->perf_cycles += cycles_cur; - node->perf_time_us += time_us_cur; + return n_tasks; } static thread_ret_t ggml_graph_compute_thread(void * data) { @@ -17217,7 +16657,6 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { const struct ggml_cgraph * cgraph = state->shared->cgraph; const struct ggml_cplan * cplan = state->shared->cplan; - const int * n_tasks_arr = cplan->n_tasks; const int n_threads = state->shared->n_threads; set_numa_thread_affinity(state->ith, n_threads); @@ -17242,9 +16681,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { if (node_n != -1) { /* FINALIZE */ - struct ggml_tensor * node = state->shared->cgraph->nodes[node_n]; + struct ggml_tensor * node = cgraph->nodes[node_n]; if (GGML_OP_HAS_FINALIZE[node->op]) { - params.nth = n_tasks_arr[node_n]; + params.nth = ggml_get_n_tasks(node, n_threads); ggml_compute_forward(&params, node); } ggml_graph_compute_perf_stats_node(node, state->shared); @@ -17255,7 +16694,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes); struct ggml_tensor * node = cgraph->nodes[node_n]; - const int n_tasks = n_tasks_arr[node_n]; + const int n_tasks = ggml_get_n_tasks(node, n_threads); state->shared->perf_node_start_cycles = ggml_perf_cycles(); state->shared->perf_node_start_time_us = ggml_perf_time_us(); @@ -17313,7 +16752,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { /* COMPUTE */ struct ggml_tensor * node = cgraph->nodes[node_n]; - const int n_tasks = n_tasks_arr[node_n]; + const int n_tasks = ggml_get_n_tasks(node, n_threads); struct ggml_compute_params params = { /*.type =*/ GGML_TASK_COMPUTE, @@ -17347,122 +16786,46 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { struct ggml_tensor * node = cgraph->nodes[i]; + size_t cur = 0; + switch (node->op) { case GGML_OP_CPY: case GGML_OP_DUP: { n_tasks = n_threads; - size_t cur = 0; if (ggml_is_quantized(node->type)) { cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; } - - work_size = MAX(work_size, cur); } break; case GGML_OP_ADD: case GGML_OP_ADD1: { n_tasks = n_threads; - size_t cur = 0; - if (ggml_is_quantized(node->src[0]->type)) { cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; } - - work_size = MAX(work_size, cur); } break; case GGML_OP_ACC: { n_tasks = n_threads; - size_t cur = 0; - if (ggml_is_quantized(node->src[0]->type)) { cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_SUB: - case GGML_OP_DIV: - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_LOG: - case GGML_OP_SUM: - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - case GGML_OP_ARGMAX: - case GGML_OP_REPEAT: - case GGML_OP_REPEAT_BACK: - { - n_tasks = 1; - } break; - - case GGML_OP_UNARY: - { - switch (ggml_get_unary_op(node)) { - case GGML_UNARY_OP_ABS: - 
case GGML_UNARY_OP_SGN: - case GGML_UNARY_OP_NEG: - case GGML_UNARY_OP_STEP: - case GGML_UNARY_OP_TANH: - case GGML_UNARY_OP_ELU: - case GGML_UNARY_OP_RELU: - { - n_tasks = 1; - } break; - - case GGML_UNARY_OP_GELU: - case GGML_UNARY_OP_GELU_QUICK: - case GGML_UNARY_OP_SILU: - { - n_tasks = n_threads; - } break; - } - } break; - case GGML_OP_SILU_BACK: - case GGML_OP_MUL: - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - case GGML_OP_RMS_NORM_BACK: - case GGML_OP_GROUP_NORM: - { - n_tasks = n_threads; } break; - case GGML_OP_CONCAT: case GGML_OP_MUL_MAT: - case GGML_OP_OUT_PROD: { - n_tasks = n_threads; - - // TODO: use different scheduling for different matrix sizes - //const int nr0 = ggml_nrows(node->src[0]); - //const int nr1 = ggml_nrows(node->src[1]); - - //n_tasks = MIN(n_threads, MAX(1, nr0/128)); - //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks); - - size_t cur = 0; const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type; -#if defined(GGML_USE_CUBLAS) - if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) { - n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning - } else -#elif defined(GGML_USE_CLBLAST) +#if defined(GGML_USE_CLBLAST) if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) { - n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node); } else #endif #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) { - n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning if (node->src[0]->type != GGML_TYPE_F32) { // here we need memory just for single 2D matrix from src0 cur = ggml_type_size(GGML_TYPE_F32)*(node->src[0]->ne[0]*node->src[0]->ne[1]); @@ -17471,79 +16834,75 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { #endif if (node->src[1]->type != vec_dot_type) { cur = ggml_type_size(vec_dot_type)*ggml_nelements(node->src[1])/ggml_blck_size(vec_dot_type); - } else { - cur = 0; } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_SCALE: - { - n_tasks = 1; - } break; - case GGML_OP_SET: - case GGML_OP_CONT: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - case GGML_OP_GET_ROWS: - case GGML_OP_GET_ROWS_BACK: - case GGML_OP_DIAG: - { - n_tasks = 1; } break; - case GGML_OP_DIAG_MASK_ZERO: - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_SOFT_MAX: - case GGML_OP_SOFT_MAX_BACK: - case GGML_OP_ROPE: - case GGML_OP_ROPE_BACK: - case GGML_OP_ADD_REL_POS: + case GGML_OP_OUT_PROD: { n_tasks = n_threads; - } break; - case GGML_OP_ALIBI: - { - n_tasks = 1; //TODO - } break; - case GGML_OP_CLAMP: - { - n_tasks = 1; //TODO + + if (ggml_is_quantized(node->src[0]->type)) { + cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; + } } break; case GGML_OP_CONV_1D: { - n_tasks = n_threads; - GGML_ASSERT(node->src[0]->ne[3] == 1); GGML_ASSERT(node->src[1]->ne[2] == 1); GGML_ASSERT(node->src[1]->ne[3] == 1); - size_t cur = 0; - const int nk = node->src[0]->ne[0]; + const int64_t ne00 = node->src[0]->ne[0]; + const int64_t ne01 = node->src[0]->ne[1]; + const int64_t ne02 = node->src[0]->ne[2]; + + const int64_t ne10 = node->src[1]->ne[0]; + const int64_t ne11 = node->src[1]->ne[1]; + + const int64_t ne0 = node->ne[0]; + const int64_t ne1 = node->ne[1]; + const int64_t nk = 
ne00; + const int64_t ew0 = nk * ne01; + + UNUSED(ne02); + UNUSED(ne10); + UNUSED(ne11); if (node->src[0]->type == GGML_TYPE_F16 && - node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(ggml_fp16_t)*( - nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] + - ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1] - ); + node->src[1]->type == GGML_TYPE_F32) { + cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0); } else if (node->src[0]->type == GGML_TYPE_F32 && - node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)*( - nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] + - ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1] - ); + node->src[1]->type == GGML_TYPE_F32) { + cur = sizeof(float)*(ne0*ne1*ew0); } else { GGML_ASSERT(false); } + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + GGML_ASSERT(node->src[0]->ne[3] == 1); + GGML_ASSERT(node->src[1]->ne[2] == 1); + GGML_ASSERT(node->src[1]->ne[3] == 1); - work_size = MAX(work_size, cur); + const int64_t ne00 = node->src[0]->ne[0]; // K + const int64_t ne01 = node->src[0]->ne[1]; // Cout + const int64_t ne02 = node->src[0]->ne[2]; // Cin + + const int64_t ne10 = node->src[1]->ne[0]; // L + const int64_t ne11 = node->src[1]->ne[1]; // Cin + + if (node->src[0]->type == GGML_TYPE_F16 && + node->src[1]->type == GGML_TYPE_F32) { + cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02; + cur += sizeof(ggml_fp16_t)*ne10*ne11; + } else if (node->src[0]->type == GGML_TYPE_F32 && + node->src[1]->type == GGML_TYPE_F32) { + cur += sizeof(float)*ne00*ne01*ne02; + cur += sizeof(float)*ne10*ne11; + } else { + GGML_ASSERT(false); + } } break; case GGML_OP_CONV_2D: { - n_tasks = n_threads; - const int64_t ne00 = node->src[0]->ne[0]; // W const int64_t ne01 = node->src[0]->ne[1]; // H const int64_t ne02 = node->src[0]->ne[2]; // C @@ -17556,30 +16915,26 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { const int64_t ne0 = node->ne[0]; const int64_t ne1 = node->ne[1]; const int64_t ne2 = node->ne[2]; + const int64_t ne3 = node->ne[3]; const int64_t nk = ne00*ne01; const int64_t ew0 = nk * ne02; UNUSED(ne03); UNUSED(ne2); - size_t cur = 0; - if (node->src[0]->type == GGML_TYPE_F16 && node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0); + // im2col: [N*OH*OW, IC*KH*KW] + cur = sizeof(ggml_fp16_t)*(ne3*ne0*ne1*ew0); } else if (node->src[0]->type == GGML_TYPE_F32 && node->src[1]->type == GGML_TYPE_F32) { cur = sizeof(float)* (ne10*ne11*ne12); } else { GGML_ASSERT(false); } - - work_size = MAX(work_size, cur); } break; case GGML_OP_CONV_TRANSPOSE_2D: { - n_tasks = n_threads; - const int64_t ne00 = node->src[0]->ne[0]; // W const int64_t ne01 = node->src[0]->ne[1]; // H const int64_t ne02 = node->src[0]->ne[2]; // Channels Out @@ -17589,141 +16944,66 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { const int64_t ne11 = node->src[1]->ne[1]; // H const int64_t ne12 = node->src[1]->ne[2]; // Channels In - size_t cur = 0; cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_POOL_1D: - case GGML_OP_POOL_2D: - { - n_tasks = 1; - } break; - case GGML_OP_UPSCALE: - { - n_tasks = n_threads; } break; case GGML_OP_FLASH_ATTN: { n_tasks = n_threads; - size_t cur = 0; - const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); if (node->src[1]->type == GGML_TYPE_F32) { cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*ne11*n_tasks; // 
this is overestimated by x2 - } - - if (node->src[1]->type == GGML_TYPE_F16) { + } else if (node->src[1]->type == GGML_TYPE_F16) { cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } - - work_size = MAX(work_size, cur); } break; case GGML_OP_FLASH_FF: { n_tasks = n_threads; - size_t cur = 0; - if (node->src[1]->type == GGML_TYPE_F32) { cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 - } - - if (node->src[1]->type == GGML_TYPE_F16) { + } else if (node->src[1]->type == GGML_TYPE_F16) { cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 } - - work_size = MAX(work_size, cur); } break; case GGML_OP_FLASH_ATTN_BACK: { n_tasks = n_threads; - size_t cur = 0; - const int64_t D = node->src[0]->ne[0]; const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back if (node->src[1]->type == GGML_TYPE_F32) { cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 - } - - if (node->src[1]->type == GGML_TYPE_F16) { + } else if (node->src[1]->type == GGML_TYPE_F16) { cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_WIN_PART: - case GGML_OP_WIN_UNPART: - case GGML_OP_GET_REL_POS: - case GGML_OP_MAP_UNARY: - case GGML_OP_MAP_BINARY: - case GGML_OP_MAP_CUSTOM1_F32: - case GGML_OP_MAP_CUSTOM2_F32: - case GGML_OP_MAP_CUSTOM3_F32: - { - n_tasks = 1; - } break; - case GGML_OP_MAP_CUSTOM1: - { - struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) node->op_params; - if (p->n_tasks == GGML_N_TASKS_MAX) { - n_tasks = n_threads; - } else { - n_tasks = MIN(p->n_tasks, n_threads); - } - } break; - case GGML_OP_MAP_CUSTOM2: - { - struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) node->op_params; - if (p->n_tasks == GGML_N_TASKS_MAX) { - n_tasks = n_threads; - } else { - n_tasks = MIN(p->n_tasks, n_threads); - } - } break; - case GGML_OP_MAP_CUSTOM3: - { - struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) node->op_params; - if (p->n_tasks == GGML_N_TASKS_MAX) { - n_tasks = n_threads; - } else { - n_tasks = MIN(p->n_tasks, n_threads); - } } break; + case GGML_OP_CROSS_ENTROPY_LOSS: { n_tasks = n_threads; - size_t cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_CROSS_ENTROPY_LOSS_BACK: - { - n_tasks = n_threads; - } break; - case GGML_OP_NONE: - { - n_tasks = 1; + cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); } break; case GGML_OP_COUNT: { GGML_ASSERT(false); } break; + default: + break; } - cplan.n_tasks[i] = n_tasks; + work_size = MAX(work_size, cur); } if (work_size > 0) { @@ -17745,12 +17025,6 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { if (cplan->work_size > 0) { GGML_ASSERT(cplan->work_data); } - - for (int i = 0; i < cgraph->n_nodes; ++i) { - if (cgraph->nodes[i]->op != GGML_OP_NONE) { - GGML_ASSERT(cplan->n_tasks[i] > 0); - 
} - } } const int n_threads = cplan->n_threads; @@ -17823,16 +17097,6 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { return compute_status; } -void ggml_graph_reset(struct ggml_cgraph * cgraph) { - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * grad = cgraph->grads[i]; - - if (grad) { - ggml_set_zero(grad); - } - } -} - void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) { struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads); @@ -17959,12 +17223,12 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { const uint32_t magic = GGML_FILE_MAGIC; const uint32_t version = GGML_FILE_VERSION; const uint32_t n_leafs = cgraph->n_leafs; - const uint32_t nodes = cgraph->n_nodes; + const uint32_t n_nodes = cgraph->n_nodes; fwrite(&magic, sizeof(uint32_t), 1, fout); fwrite(&version, sizeof(uint32_t), 1, fout); fwrite(&n_leafs, sizeof(uint32_t), 1, fout); - fwrite(&nodes, sizeof(uint32_t), 1, fout); + fwrite(&n_nodes, sizeof(uint32_t), 1, fout); fwrite(&size_eval, sizeof(uint64_t), 1, fout); } @@ -18052,7 +17316,7 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { if (idx == -1) { for (int k = 0; k < cgraph->n_nodes; ++k) { if (args[j] == cgraph->nodes[k]) { - idx = GGML_MAX_NODES + k; + idx = cgraph->n_leafs + k; break; } } @@ -18060,6 +17324,7 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { if (idx == -1) { fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i); + fclose(fout); return; } @@ -18078,11 +17343,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { } } -struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval) { +struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval) { assert(*ctx_data == NULL); assert(*ctx_eval == NULL); - struct ggml_cgraph result = { 0 }; + struct ggml_cgraph * result = NULL; struct ggml_tensor * data = NULL; @@ -18154,13 +17419,11 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** const uint32_t n_leafs = *(const uint32_t *) ptr; ptr += sizeof(n_leafs); const uint32_t n_nodes = *(const uint32_t *) ptr; ptr += sizeof(n_nodes); const uint64_t size_eval = *(const uint64_t *) ptr; ptr += sizeof(size_eval); - - result.n_leafs = n_leafs; - result.n_nodes = n_nodes; + const int graph_size = MAX(n_leafs, n_nodes); // create the data context { - const size_t overhead = (n_leafs + n_nodes)*ggml_tensor_overhead(); + const size_t overhead = (n_leafs + n_nodes)*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph_size, false); struct ggml_init_params params = { .mem_size = size_eval + overhead, @@ -18176,6 +17439,12 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** } } + result = ggml_new_graph_custom(*ctx_eval, graph_size, false); + + result->n_leafs = n_leafs; + result->n_nodes = n_nodes; + + // leafs { uint32_t type; @@ -18214,7 +17483,7 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** tensor->nb[j] = nb[j]; } - result.leafs[i] = tensor; + result->leafs[i] = tensor; ptr += ggml_nbytes(tensor); @@ -18266,10 +17535,10 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** continue; } - if (arg_idx < GGML_MAX_NODES) { - args[j] = result.leafs[arg_idx]; 
+ if (arg_idx < result->n_leafs) { + args[j] = result->leafs[arg_idx]; } else { - args[j] = result.nodes[arg_idx - GGML_MAX_NODES]; + args[j] = result->nodes[arg_idx - result->n_leafs]; } } @@ -18321,7 +17590,7 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** tensor->src[j] = args[j]; } - result.nodes[i] = tensor; + result->nodes[i] = tensor; fprintf(stderr, "%s: loaded node %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, ggml_nbytes(tensor)); } @@ -18568,7 +17837,7 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * } static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) { - int i = 0; + int64_t i = 0; for (int p = 0; p < np; ++p) { const int64_t ne = ggml_nelements(ps[p]) ; // TODO: add function to get all elements at once @@ -18578,6 +17847,17 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g } } +static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) { + int64_t i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int64_t j = 0; j < ne; ++j) { + g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale; + } + } +} + // // ADAM // @@ -18626,26 +17906,40 @@ static enum ggml_opt_result ggml_opt_adam( const float eps = params.adam.eps; const float gclip = params.adam.gclip; const int decay_min_ndim = params.adam.decay_min_ndim; + const int n_accum = MAX(1, params.n_gradient_accumulation); + const float accum_norm = 1.0f / (float) n_accum; + float * g = opt->adam.g->data; // gradients float * m = opt->adam.m->data; // first moment float * v = opt->adam.v->data; // second moment float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values - if (callback) { - callback(callback_data, &sched); - } - - // compute the function value - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads); struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; - ggml_graph_compute(gb, &cplan); - opt->adam.fx_prev = ggml_get_f32_1d(f, 0); + bool cancel = false; + + // compute the function value + float fx = 0; + ggml_set_zero(opt->adam.g); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if (callback) { + callback(callback_data, accum_step, &sched, &cancel); + if (cancel) { + return GGML_OPT_CANCEL; + } + } + // ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(gb, &cplan); + ggml_opt_acc_grad(np, ps, g, accum_norm); + fx += ggml_get_f32_1d(f, 0); + } + fx *= accum_norm; + + opt->adam.fx_prev = fx; opt->adam.fx_best = opt->adam.fx_prev; if (pf) { pf[opt->iter % params.past] = opt->adam.fx_prev; @@ -18690,12 +17984,8 @@ static enum ggml_opt_result ggml_opt_adam( if (gclip > 0.0f) { // gradient clipping ggml_float sum = 0.0; - for (int p = 0; p < np; ++p) { - const int64_t ne = ggml_nelements(ps[p]); - for (int64_t j = 0; j < ne; ++j) { - float g = ggml_get_f32_1d(ps[p]->grad, j); - sum += (ggml_float)(g*g); - } + for (int64_t i = 0; i < nx; ++i) { + sum += (ggml_float)(g[i]*g[i]); } ggml_float norm = sqrt(sum); if (norm > (ggml_float) gclip) { @@ -18709,10 +17999,10 @@ static enum ggml_opt_result ggml_opt_adam( const int64_t ne = ggml_nelements(ps[p]); const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? 
decay : 0.0f) * sched; for (int64_t j = 0; j < ne; ++j) { - float x = ggml_get_f32_1d(ps[p], j); - float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm; - m[i] = m[i]*beta1 + g*(1.0f - beta1); - v[i] = v[i]*beta2 + g*g*(1.0f - beta2); + float x = ggml_get_f32_1d(ps[p], j); + float g_ = g[i]*gnorm; + m[i] = m[i]*beta1 + g_*(1.0f - beta1); + v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2); float mh = m[i]*beta1h; float vh = v[i]*beta2h; vh = sqrtf(vh) + eps; @@ -18723,19 +18013,25 @@ static enum ggml_opt_result ggml_opt_adam( } } - if (callback) { - callback(callback_data, &sched); + fx = 0; + ggml_set_zero(opt->adam.g); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if (callback) { + callback(callback_data, accum_step, &sched, &cancel); + if (cancel) { + return GGML_OPT_CANCEL; + } + } + // ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(gb, &cplan); + ggml_opt_acc_grad(np, ps, g, accum_norm); + fx += ggml_get_f32_1d(f, 0); } + fx *= accum_norm; - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - - ggml_graph_compute(gb, &cplan); - - const float fx = ggml_get_f32_1d(f, 0); opt->loss_after = fx; - // check convergence if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) { GGML_PRINT_DEBUG("converged\n"); @@ -18812,11 +18108,11 @@ static enum ggml_opt_result linesearch_backtracking( float * step, const float * xp, struct ggml_tensor * f, - struct ggml_cgraph * gf, struct ggml_cgraph * gb, struct ggml_cplan * cplan, const int np, struct ggml_tensor * ps[], + bool * cancel, ggml_opt_callback callback, void * callback_data) { int count = 0; @@ -18830,6 +18126,9 @@ static enum ggml_opt_result linesearch_backtracking( const float dec = 0.5f; const float inc = 2.1f; + const int n_accum = MAX(1, params->n_gradient_accumulation); + const float accum_norm = 1.0f / (float) n_accum; + if (*step <= 0.f) { return GGML_LINESEARCH_INVALID_PARAMETERS; } @@ -18847,12 +18146,6 @@ static enum ggml_opt_result linesearch_backtracking( dgtest = params->lbfgs.ftol*dginit; while (true) { - if (callback) { - // LBFG-S does not support learning rate -> ignore learning schedule - float sched = 0; - callback(callback_data, &sched); - } - ggml_vec_cpy_f32(nx, x, xp); ggml_vec_mad_f32(nx, x, d, *step); @@ -18860,14 +18153,25 @@ static enum ggml_opt_result linesearch_backtracking( { ggml_opt_set_params(np, ps, x); - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - - ggml_graph_compute(gb, cplan); - - ggml_opt_get_grad(np, ps, g); + *fx = 0; + memset(g, 0, sizeof(float)*nx); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if (callback) { + // L-BFGS does not support learning rate -> ignore learning schedule + float sched = 0; + callback(callback_data, accum_step, &sched, cancel); + if (*cancel) { + return GGML_OPT_CANCEL; + } + } + // ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(gb, cplan); + ggml_opt_acc_grad(np, ps, g, accum_norm); + *fx += ggml_get_f32_1d(f, 0); + } + *fx *= accum_norm; - *fx = ggml_get_f32_1d(f, 0); } ++count; @@ -18913,7 +18217,7 @@ static enum ggml_opt_result linesearch_backtracking( (*step) *= width; } - return GGML_LINESEARCH_FAIL; + GGML_UNREACHABLE(); } static enum ggml_opt_result ggml_opt_lbfgs( @@ -18968,6 +18272,9 @@ static enum ggml_opt_result ggml_opt_lbfgs( float * pf = params.past > 0 ? 
opt->lbfgs.pf->data : NULL; // past function values + const int n_accum = MAX(1, params.n_gradient_accumulation); + const float accum_norm = 1.0f / (float) n_accum; + float fx = 0.0f; // cost function value float xnorm = 0.0f; // ||x|| float gnorm = 0.0f; // ||g|| @@ -18981,24 +18288,30 @@ static enum ggml_opt_result ggml_opt_lbfgs( float * lm_s = opt->lbfgs.lms->data; float * lm_y = opt->lbfgs.lmy->data; - if (callback) { - // LBFG-S does not support learning rate -> ignore learning schedule - float sched = 0; - callback(callback_data, &sched); - } + bool cancel = false; // evaluate the function value and its gradient { ggml_opt_set_params(np, ps, x); - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - - ggml_graph_compute(gb, &cplan); - - ggml_opt_get_grad(np, ps, g); - - fx = ggml_get_f32_1d(f, 0); + fx = 0; + memset(g, 0, sizeof(float)*nx); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if (callback) { + // L-BFGS does not support learning rate -> ignore learning schedule + float sched = 0; + callback(callback_data, accum_step, &sched, &cancel); + if (cancel) { + return GGML_OPT_CANCEL; + } + } + // ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(gb, &cplan); + ggml_opt_acc_grad(np, ps, g, accum_norm); + fx += ggml_get_f32_1d(f, 0); + } + fx *= accum_norm; opt->loss_before = fx; opt->loss_after = fx; @@ -19056,7 +18369,14 @@ static enum ggml_opt_result ggml_opt_lbfgs( ggml_vec_cpy_f32(nx, xp, x); ggml_vec_cpy_f32(nx, gp, g); - ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gf, gb, &cplan, np, ps, callback, callback_data); + // TODO: instead of passing &cancel here, use the return code of the linesearch + // to determine if the optimization should be cancelled + // this is a simple change, but not doing this atm, since I don't have a nice + // way to test and don't want to break something with so many changes lined up + ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data); + if (cancel) { + return GGML_OPT_CANCEL; + } if (ls < 0) { // linesearch failed - go back to the previous point and return @@ -19165,7 +18485,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( step[0] = 1.0; } - return GGML_OPT_DID_NOT_CONVERGE; + GGML_UNREACHABLE(); } struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { @@ -19175,16 +18495,19 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { case GGML_OPT_ADAM: { result = (struct ggml_opt_params) { - .type = GGML_OPT_ADAM, - .n_threads = 1, - .past = 0, - .delta = 1e-5f, + .type = GGML_OPT_ADAM, + .graph_size = GGML_DEFAULT_GRAPH_SIZE, + .n_threads = 1, // FIXME: GGML_DEFAULT_N_THREADS ? 
+ .past = 0, + .delta = 1e-5f, .max_no_improvement = 100, .print_forward_graph = true, .print_backward_graph = true, + .n_gradient_accumulation = 1, + .adam = { .n_iter = 10000, .sched = 1.000f, @@ -19203,16 +18526,19 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { case GGML_OPT_LBFGS: { result = (struct ggml_opt_params) { - .type = GGML_OPT_LBFGS, - .n_threads = 1, - .past = 0, - .delta = 1e-5f, + .type = GGML_OPT_LBFGS, + .graph_size = GGML_DEFAULT_GRAPH_SIZE, + .n_threads = 1, + .past = 0, + .delta = 1e-5f, .max_no_improvement = 0, .print_forward_graph = true, .print_backward_graph = true, + .n_gradient_accumulation = 1, + .lbfgs = { .m = 6, .n_iter = 100, @@ -19243,13 +18569,32 @@ GGML_API void ggml_opt_init( opt->iter = 0; opt->nx = nx; opt->just_initialized = true; + if (opt->ctx == NULL) { + struct ggml_init_params ctx_opt_params; + if (opt->params.type == GGML_OPT_ADAM) { + ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3; + if (opt->params.past > 0) { + ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past; + } + } else if (opt->params.type == GGML_OPT_LBFGS) { + ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2); + if (opt->params.past > 0) { + ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past; + } + } + ctx_opt_params.mem_buffer = NULL; + ctx_opt_params.no_alloc = false; + + opt->ctx = ggml_init(ctx_opt_params); + } switch (opt->params.type) { case GGML_OPT_ADAM: { - opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); opt->adam.pf = params.past > 0 - ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past) + ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past) : NULL; ggml_set_zero(opt->adam.m); ggml_set_zero(opt->adam.v); @@ -19259,18 +18604,18 @@ GGML_API void ggml_opt_init( } break; case GGML_OPT_LBFGS: { - opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); opt->lbfgs.pf = params.past > 0 - ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past) + ? 
ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past) : NULL; - opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m); - opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m); - opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m); - opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m); + opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m); + opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m); + opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m); + opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m); ggml_set_zero(opt->lbfgs.x); ggml_set_zero(opt->lbfgs.xp); ggml_set_zero(opt->lbfgs.g); @@ -19327,14 +18672,11 @@ enum ggml_opt_result ggml_opt_resume( struct ggml_tensor * f) { // build forward + backward compute graphs - struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32)+ (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0)); - struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32)+ (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0)); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, opt->params.graph_size, true); + ggml_build_forward_expand(gf, f); - struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data; - struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data; - - *gf = ggml_build_forward (f); - *gb = ggml_build_backward(ctx, gf, true); + struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf); + ggml_build_backward_expand(ctx, gf, gb, true); return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL); } @@ -19537,7 +18879,6 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i block_q8_0 * block = (block_q8_0*)dst + start / QK8_0; result = ggml_quantize_q8_0(src + start, block, n, n, hist); } break; -#ifdef GGML_USE_K_QUANTS case GGML_TYPE_Q2_K: { GGML_ASSERT(start % QK_K == 0); @@ -19568,7 +18909,6 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i block_q6_K * block = (block_q6_K*)dst + start / QK_K; result = ggml_quantize_q6_K(src + start, block, n, n, hist); } break; -#endif case GGML_TYPE_F16: { int elemsize = sizeof(ggml_fp16_t); @@ -19659,7 +18999,7 @@ struct gguf_kv { }; struct gguf_header { - uint32_t magic; + char magic[4]; uint32_t version; uint64_t n_tensors; // GGUFv2 uint64_t n_kv; // GGUFv2 @@ -19700,8 +19040,7 @@ static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) return n == size; } -// NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 -static bool gguf_fread_str_cur(FILE * file, struct gguf_str * p, size_t * offset) { +static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) { p->n = 0; p->data = NULL; @@ -19713,23 +19052,10 @@ static bool gguf_fread_str_cur(FILE * file, struct gguf_str * p, size_t * offset return ok; } -static bool gguf_fread_str_v1(FILE * file, struct gguf_str * p, size_t * offset) { - p->n = 0; - p->data = NULL; - - bool ok = true; - - uint32_t n = 0; - ok = ok && gguf_fread_el(file, &n, sizeof(n), offset); p->data = calloc(n + 1, 1); p->n = n; - ok = ok && gguf_fread_el(file, p->data, p->n, offset); - - return ok; -} - struct gguf_context * gguf_init_empty(void) { struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); - 
ctx->header.magic = GGUF_MAGIC; + memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic)); ctx->header.version = GGUF_VERSION; ctx->header.n_tensors = 0; ctx->header.n_kv = 0; @@ -19755,16 +19081,18 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p // offset from start of file size_t offset = 0; - uint32_t magic = 0; + char magic[4]; // check the magic before making allocations { gguf_fread_el(file, &magic, sizeof(magic), &offset); - if (magic != GGUF_MAGIC) { - fprintf(stderr, "%s: invalid magic number %08x\n", __func__, magic); - fclose(file); - return NULL; + for (uint32_t i = 0; i < sizeof(magic); i++) { + if (magic[i] != GGUF_MAGIC[i]) { + fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]); + fclose(file); + return NULL; + } } } @@ -19774,27 +19102,22 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p // read the header { - ctx->header.magic = magic; + strncpy(ctx->header.magic, magic, 4); + ctx->kv = NULL; ctx->infos = NULL; ctx->data = NULL; ok = ok && gguf_fread_el(file, &ctx->header.version, sizeof(ctx->header.version), &offset); + ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset); + ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset); if (ctx->header.version == 1) { - // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 - uint32_t n_tensors = 0; - uint32_t n_kv = 0; - - ok = ok && gguf_fread_el(file, &n_tensors, sizeof(n_tensors), &offset); - ok = ok && gguf_fread_el(file, &n_kv, sizeof(n_kv), &offset); - - ctx->header.n_tensors = n_tensors; - ctx->header.n_kv = n_kv; - } else { - ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset); - ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset); + fprintf(stderr, "%s: GGUFv1 is no longer supported. 
please use a more up-to-date version\n", __func__); + fclose(file); + gguf_free(ctx); + return NULL; } if (!ok) { @@ -19805,12 +19128,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p } } - // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 - bool (* gguf_fread_str)(FILE *, struct gguf_str *, size_t *) = gguf_fread_str_cur; - if (ctx->header.version == 1) { - gguf_fread_str = gguf_fread_str_v1; - } - // read the kv pairs { ctx->kv = malloc(ctx->header.n_kv * sizeof(struct gguf_kv)); @@ -19841,15 +19158,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p case GGUF_TYPE_ARRAY: { ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset); - - if (ctx->header.version == 1) { - // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 - uint32_t n = 0; - ok = ok && gguf_fread_el(file, &n, sizeof(n), &offset); - kv->value.arr.n = n; - } else { - ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset); - } + ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset); switch (kv->value.arr.type) { case GGUF_TYPE_UINT8: @@ -19876,10 +19185,10 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p } break; case GGUF_TYPE_ARRAY: case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; - }; + } } break; case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); - }; + } if (!ok) { break; @@ -19908,14 +19217,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p ok = ok && gguf_fread_str(file, &info->name, &offset); ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims), &offset); for (uint32_t j = 0; j < info->n_dims; ++j) { - if (ctx->header.version == 1) { - // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 - uint32_t t = 0; - ok = ok && gguf_fread_el(file, &t, sizeof(t), &offset); - info->ne[j] = t; - } else { - ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset); - } + ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset); } ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset); ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset); @@ -20155,78 +19457,94 @@ int gguf_find_key(const struct gguf_context * ctx, const char * key) { return keyfound; } -const char * gguf_get_key(const struct gguf_context * ctx, int i) { - return ctx->kv[i].key.data; +const char * gguf_get_key(const struct gguf_context * ctx, int key_id) { + return ctx->kv[key_id].key.data; } -enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int i) { - return ctx->kv[i].type; +enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) { + return ctx->kv[key_id].type; } -enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.arr.type; +enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); + return ctx->kv[key_id].value.arr.type; } -const void * gguf_get_arr_data(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.arr.data; +const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); + return ctx->kv[key_id].value.arr.data; } const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) { + GGML_ASSERT(ctx->kv[key_id].type == 
GGUF_TYPE_ARRAY); struct gguf_kv * kv = &ctx->kv[key_id]; struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i]; return str->data; } -int gguf_get_arr_n(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.arr.n; +int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); + return ctx->kv[key_id].value.arr.n; } -uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.uint8; +uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8); + return ctx->kv[key_id].value.uint8; } -int8_t gguf_get_val_i8(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.int8; +int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8); + return ctx->kv[key_id].value.int8; } -uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.uint16; +uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16); + return ctx->kv[key_id].value.uint16; } -int16_t gguf_get_val_i16(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.int16; +int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16); + return ctx->kv[key_id].value.int16; } -uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.uint32; +uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32); + return ctx->kv[key_id].value.uint32; } -int32_t gguf_get_val_i32(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.int32; +int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32); + return ctx->kv[key_id].value.int32; } -float gguf_get_val_f32(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.float32; +float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32); + return ctx->kv[key_id].value.float32; } -uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.uint64; +uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64); + return ctx->kv[key_id].value.uint64; } -int64_t gguf_get_val_i64(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.int64; +int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64); + return ctx->kv[key_id].value.int64; } -double gguf_get_val_f64(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.float64; +double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64); + return ctx->kv[key_id].value.float64; } -bool gguf_get_val_bool(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.bool_; +bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL); + return ctx->kv[key_id].value.bool_; } -const char * gguf_get_val_str (const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.str.data; +const char * gguf_get_val_str(const struct gguf_context * 
ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING); + return ctx->kv[key_id].value.str.data; } int gguf_get_n_tensors(const struct gguf_context * ctx) { @@ -20591,10 +19909,10 @@ static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * } break; case GGUF_TYPE_ARRAY: case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; - }; + } } break; case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); - }; + } } // write tensor infos
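/*
 * Usage sketch for the hardened GGUF accessors above (illustrative only; assumes an already
 * loaded struct gguf_context * ctx, and "general.file_type" is just an example key). Since the
 * typed getters now GGML_ASSERT on the stored kv type, callers should check the type before
 * reading a value:
 *
 *     const int key_id = gguf_find_key(ctx, "general.file_type");
 *     if (key_id >= 0 && gguf_get_kv_type(ctx, key_id) == GGUF_TYPE_UINT32) {
 *         const uint32_t ftype = gguf_get_val_u32(ctx, key_id);
 *         // ... use ftype ...
 *     }
 */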