Spaces:
Running
Introduction of CUDA Graphs to LLama.cpp (llama/6766)
Browse files* DRAFT: Introduction of CUDA Graphs to LLama.cpp
* FIx issues raised in comments
* Tidied to now only use CUDA runtime (not mixed with driver calls)
* disable for multi-gpu and batch size > 1
* Disable CUDA graphs for old GPU arch and with env var
* added missing CUDA_CHECKs
* Addressed comments
* further addressed comments
* limit to GGML_ALLOW_CUDA_GRAPHS defined in llama.cpp cmake
* Added more comprehensive graph node checking
* With mechanism to fall back if graph capture fails
* Revert "With mechanism to fall back if graph capture fails"
This reverts commit eb9f15fb6fcb81384f732c4601a5b25c016a5143.
* Fall back if graph capture fails and address other comments
* - renamed GGML_ALLOW_CUDA_GRAPHS to GGML_CUDA_USE_GRAPHS
- rename env variable to disable CUDA graphs to GGML_CUDA_DISABLE_GRAPHS
- updated Makefile build to enable CUDA graphs
- removed graph capture failure checking in ggml_cuda_error
using a global variable to track this is not thread safe, but I am also not safistied with checking an error by string
if this is necessary to workaround some issues with graph capture with eg. cuBLAS, we can pass the ggml_backend_cuda_context to the error checking macro and store the result in the context
- fixed several resource leaks
- fixed issue with zero node graphs
- changed fixed size arrays to vectors
- removed the count of number of evaluations before start capturing, and instead changed the capture mode to relaxed
- removed the check for multiple devices so that it is still possible to use a single device, instead checks for split buffers to disable cuda graphs with -sm row
- changed the op for checking batch size to GGML_OP_ADD, should be more reliable than GGML_OP_SOFT_MAX
- code style fixes
- things to look into
- VRAM usage of the cudaGraphExec_t, if it is significant we may need to make it optional
- possibility of using cudaStreamBeginCaptureToGraph to keep track of which ggml graph nodes correspond to which cuda graph nodes
* fix build without cuda graphs
* remove outdated comment
* replace minimum cc value with a constant
---------
Co-authored-by: slaren <[email protected]>
- ggml-cuda.cu +286 -14
- ggml-cuda/clamp.cu +0 -1
- ggml-cuda/common.cuh +40 -0
- ggml-cuda/convert.cu +1 -3
- ggml-cuda/cpy.cu +29 -0
- ggml-cuda/cpy.cuh +2 -0
- ggml-cuda/mmq.cu +10 -20
- ggml-cuda/mmvq.cu +2 -4
- ggml-cuda/scale.cu +0 -1
|
@@ -1647,7 +1647,7 @@ static void ggml_cuda_op_mul_mat(
|
|
| 1647 |
}
|
| 1648 |
}
|
| 1649 |
|
| 1650 |
-
static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
|
| 1651 |
GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
|
| 1652 |
GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
|
| 1653 |
GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
|
|
@@ -1670,7 +1670,7 @@ static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const gg
|
|
| 1670 |
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
|
| 1671 |
}
|
| 1672 |
|
| 1673 |
-
static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
|
| 1674 |
GGML_ASSERT(!ggml_is_transposed(src0));
|
| 1675 |
GGML_ASSERT(!ggml_is_transposed(src1));
|
| 1676 |
GGML_ASSERT(!ggml_is_permuted(src0));
|
|
@@ -2413,32 +2413,304 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
|
|
| 2413 |
GGML_UNUSED(backend);
|
| 2414 |
}
|
| 2415 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2416 |
GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
| 2417 |
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
| 2418 |
|
| 2419 |
ggml_cuda_set_device(cuda_ctx->device);
|
| 2420 |
|
| 2421 |
-
|
| 2422 |
-
|
| 2423 |
|
| 2424 |
-
|
| 2425 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2426 |
}
|
| 2427 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2428 |
#ifndef NDEBUG
|
| 2429 |
-
|
| 2430 |
-
|
| 2431 |
-
|
| 2432 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2433 |
}
|
| 2434 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2435 |
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2436 |
|
| 2437 |
-
|
| 2438 |
-
if (
|
| 2439 |
-
|
| 2440 |
}
|
| 2441 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2442 |
}
|
| 2443 |
|
| 2444 |
return GGML_STATUS_SUCCESS;
|
|
|
|
| 1647 |
}
|
| 1648 |
}
|
| 1649 |
|
| 1650 |
+
static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 1651 |
GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
|
| 1652 |
GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
|
| 1653 |
GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
|
|
|
|
| 1670 |
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
|
| 1671 |
}
|
| 1672 |
|
| 1673 |
+
static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
| 1674 |
GGML_ASSERT(!ggml_is_transposed(src0));
|
| 1675 |
GGML_ASSERT(!ggml_is_transposed(src1));
|
| 1676 |
GGML_ASSERT(!ggml_is_permuted(src0));
|
|
|
|
| 2413 |
GGML_UNUSED(backend);
|
| 2414 |
}
|
| 2415 |
|
| 2416 |
+
static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
| 2417 |
+
graph_node_properties->node_address = node->data;
|
| 2418 |
+
graph_node_properties->node_op = node->op;
|
| 2419 |
+
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
| 2420 |
+
graph_node_properties->ne[i] = node->ne[i];
|
| 2421 |
+
graph_node_properties->nb[i] = node->nb[i];
|
| 2422 |
+
}
|
| 2423 |
+
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
| 2424 |
+
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
|
| 2425 |
+
}
|
| 2426 |
+
}
|
| 2427 |
+
|
| 2428 |
+
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
| 2429 |
+
if (node->data != graph_node_properties->node_address &&
|
| 2430 |
+
node->op != GGML_OP_CPY &&
|
| 2431 |
+
node->op != GGML_OP_VIEW) {
|
| 2432 |
+
return false;
|
| 2433 |
+
}
|
| 2434 |
+
|
| 2435 |
+
if (node->op != graph_node_properties->node_op) {
|
| 2436 |
+
return false;
|
| 2437 |
+
}
|
| 2438 |
+
|
| 2439 |
+
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
| 2440 |
+
if (node->ne[i] != graph_node_properties->ne[i]) {
|
| 2441 |
+
return false;
|
| 2442 |
+
}
|
| 2443 |
+
if (node->nb[i] != graph_node_properties->nb[i]) {
|
| 2444 |
+
return false;
|
| 2445 |
+
}
|
| 2446 |
+
}
|
| 2447 |
+
|
| 2448 |
+
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
| 2449 |
+
if (node->src[i] &&
|
| 2450 |
+
node->src[i]->data != graph_node_properties->src_address[i] &&
|
| 2451 |
+
node->op != GGML_OP_CPY &&
|
| 2452 |
+
node->op != GGML_OP_VIEW
|
| 2453 |
+
) {
|
| 2454 |
+
return false;
|
| 2455 |
+
}
|
| 2456 |
+
}
|
| 2457 |
+
return true;
|
| 2458 |
+
}
|
| 2459 |
+
|
| 2460 |
GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
| 2461 |
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
| 2462 |
|
| 2463 |
ggml_cuda_set_device(cuda_ctx->device);
|
| 2464 |
|
| 2465 |
+
#ifdef USE_CUDA_GRAPH
|
| 2466 |
+
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
| 2467 |
|
| 2468 |
+
// Objects required for CUDA Graph
|
| 2469 |
+
if (cuda_ctx->cuda_graph == nullptr) {
|
| 2470 |
+
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
|
| 2471 |
+
}
|
| 2472 |
+
|
| 2473 |
+
bool use_cuda_graph = true;
|
| 2474 |
+
bool cuda_graph_update_required = false;
|
| 2475 |
+
// pointer to CUDA cpy kernel, which is required to identify
|
| 2476 |
+
// kernel parameters which need updated in the graph for each token
|
| 2477 |
+
void * ggml_cuda_cpy_fn_ptr = nullptr;
|
| 2478 |
+
|
| 2479 |
+
if (cuda_ctx->cuda_graph->graph == nullptr) {
|
| 2480 |
+
if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
|
| 2481 |
+
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
|
| 2482 |
+
#ifndef NDEBUG
|
| 2483 |
+
fprintf(stderr, "%s: disabling CUDA graphs due to GPU architecture\n", __func__);
|
| 2484 |
+
#endif
|
| 2485 |
+
}
|
| 2486 |
+
}
|
| 2487 |
+
|
| 2488 |
+
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
|
| 2489 |
+
// or previous graph capture failure.
|
| 2490 |
+
// Also disable for multi-gpu for now. TO DO investigate
|
| 2491 |
+
if (disable_cuda_graphs_due_to_env
|
| 2492 |
+
|| cuda_ctx->cuda_graph->disable_due_to_gpu_arch
|
| 2493 |
+
|| cuda_ctx->cuda_graph->disable_due_to_too_many_updates
|
| 2494 |
+
|| cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
|
| 2495 |
+
use_cuda_graph = false;
|
| 2496 |
+
}
|
| 2497 |
+
|
| 2498 |
+
if (use_cuda_graph) {
|
| 2499 |
+
if (cuda_ctx->cuda_graph->instance == nullptr) {
|
| 2500 |
+
cuda_graph_update_required = true;
|
| 2501 |
+
}
|
| 2502 |
+
|
| 2503 |
+
// Check if the graph size has changed
|
| 2504 |
+
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
|
| 2505 |
+
cuda_graph_update_required = true;
|
| 2506 |
+
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
|
| 2507 |
+
}
|
| 2508 |
+
|
| 2509 |
+
// Loop over nodes in GGML graph to determine if CUDA graph update is required
|
| 2510 |
+
// and store properties to allow this comparison for the next token
|
| 2511 |
+
for (int i = 0; i < cgraph->n_nodes; i++) {
|
| 2512 |
+
bool has_matching_properties = true;
|
| 2513 |
+
if (!cuda_graph_update_required) {
|
| 2514 |
+
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
|
| 2515 |
+
}
|
| 2516 |
+
if (!has_matching_properties) {
|
| 2517 |
+
cuda_graph_update_required = true;
|
| 2518 |
+
}
|
| 2519 |
+
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
|
| 2520 |
+
}
|
| 2521 |
+
|
| 2522 |
+
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
| 2523 |
+
cuda_ctx->cuda_graph->updated_kernel_arg.clear();
|
| 2524 |
+
for (int i = 0; i < cgraph->n_nodes; i++) {
|
| 2525 |
+
ggml_tensor * node = cgraph->nodes[i];
|
| 2526 |
+
|
| 2527 |
+
if (node->src[0] && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) {
|
| 2528 |
+
use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
|
| 2529 |
+
#ifndef NDEBUG
|
| 2530 |
+
fprintf(stderr, "%s: disabling CUDA graphs due to split buffer\n", __func__);
|
| 2531 |
+
#endif
|
| 2532 |
+
}
|
| 2533 |
+
|
| 2534 |
+
if (node->op == GGML_OP_MUL_MAT_ID) {
|
| 2535 |
+
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
|
| 2536 |
+
#ifndef NDEBUG
|
| 2537 |
+
fprintf(stderr, "%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
|
| 2538 |
+
#endif
|
| 2539 |
+
}
|
| 2540 |
+
|
| 2541 |
+
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
|
| 2542 |
+
// disable CUDA graphs for batch size > 1 for now.
|
| 2543 |
+
// Changes in batch size or context size can cause changes to the grid size of some kernels.
|
| 2544 |
+
use_cuda_graph = false;
|
| 2545 |
+
#ifndef NDEBUG
|
| 2546 |
+
fprintf(stderr, "%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
|
| 2547 |
+
#endif
|
| 2548 |
+
}
|
| 2549 |
+
|
| 2550 |
+
if (node->op == GGML_OP_CPY) {
|
| 2551 |
+
// store the copy op parameter which changes with each token.
|
| 2552 |
+
cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
|
| 2553 |
+
if (ggml_cuda_cpy_fn_ptr == nullptr) {
|
| 2554 |
+
// store a pointer to the copy op CUDA kernel to identify it later
|
| 2555 |
+
ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
|
| 2556 |
+
}
|
| 2557 |
+
}
|
| 2558 |
+
|
| 2559 |
+
if (!use_cuda_graph) {
|
| 2560 |
+
break;
|
| 2561 |
+
}
|
| 2562 |
+
}
|
| 2563 |
+
|
| 2564 |
+
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
|
| 2565 |
+
if (cuda_graph_update_required) {
|
| 2566 |
+
cuda_ctx->cuda_graph->number_consecutive_updates++;
|
| 2567 |
+
} else {
|
| 2568 |
+
cuda_ctx->cuda_graph->number_consecutive_updates = 0;
|
| 2569 |
}
|
| 2570 |
|
| 2571 |
+
if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
|
| 2572 |
+
cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
|
| 2573 |
+
#ifndef NDEBUG
|
| 2574 |
+
fprintf(stderr, "%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
|
| 2575 |
+
#endif
|
| 2576 |
+
}
|
| 2577 |
+
}
|
| 2578 |
+
|
| 2579 |
+
if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
|
| 2580 |
+
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
|
| 2581 |
+
}
|
| 2582 |
+
|
| 2583 |
+
#else
|
| 2584 |
+
bool use_cuda_graph = false;
|
| 2585 |
+
bool cuda_graph_update_required = false;
|
| 2586 |
+
#endif // USE_CUDA_GRAPH
|
| 2587 |
+
|
| 2588 |
+
bool graph_evaluated_or_captured = false;
|
| 2589 |
+
|
| 2590 |
+
while (!graph_evaluated_or_captured) {
|
| 2591 |
+
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
|
| 2592 |
+
// With the use of CUDA graphs, the execution will be performed by the graph launch.
|
| 2593 |
+
if (!use_cuda_graph || cuda_graph_update_required) {
|
| 2594 |
+
for (int i = 0; i < cgraph->n_nodes; i++) {
|
| 2595 |
+
ggml_tensor * node = cgraph->nodes[i];
|
| 2596 |
+
|
| 2597 |
+
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
| 2598 |
+
continue;
|
| 2599 |
+
}
|
| 2600 |
+
|
| 2601 |
#ifndef NDEBUG
|
| 2602 |
+
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
|
| 2603 |
+
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 2604 |
+
if (node->src[j] != nullptr) {
|
| 2605 |
+
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
|
| 2606 |
+
}
|
| 2607 |
+
}
|
| 2608 |
+
#endif
|
| 2609 |
+
|
| 2610 |
+
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
|
| 2611 |
+
if (!ok) {
|
| 2612 |
+
fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
| 2613 |
+
}
|
| 2614 |
+
GGML_ASSERT(ok);
|
| 2615 |
}
|
| 2616 |
}
|
| 2617 |
+
|
| 2618 |
+
#ifdef USE_CUDA_GRAPH
|
| 2619 |
+
if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
|
| 2620 |
+
if (cuda_ctx->cuda_graph->graph != nullptr) {
|
| 2621 |
+
CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
|
| 2622 |
+
cuda_ctx->cuda_graph->graph = nullptr;
|
| 2623 |
+
}
|
| 2624 |
+
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
|
| 2625 |
+
|
| 2626 |
+
#if 0
|
| 2627 |
+
if (disable_cuda_graphs_due_to_failed_capture) {
|
| 2628 |
+
use_cuda_graph = false;
|
| 2629 |
+
cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
|
| 2630 |
+
#ifndef NDEBUG
|
| 2631 |
+
fprintf(stderr, "%s: disabling CUDA graphs due to failed graph capture\n", __func__);
|
| 2632 |
#endif
|
| 2633 |
+
} else {
|
| 2634 |
+
graph_evaluated_or_captured = true; // CUDA graph has been captured
|
| 2635 |
+
}
|
| 2636 |
+
#endif
|
| 2637 |
+
graph_evaluated_or_captured = true; // CUDA graph has been captured
|
| 2638 |
+
} else {
|
| 2639 |
+
graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
|
| 2640 |
+
}
|
| 2641 |
+
}
|
| 2642 |
|
| 2643 |
+
if (use_cuda_graph) {
|
| 2644 |
+
if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
|
| 2645 |
+
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
| 2646 |
}
|
| 2647 |
+
|
| 2648 |
+
// Perform update to graph (if required for this token), and change copy parameter (required for every token)
|
| 2649 |
+
|
| 2650 |
+
if (cuda_graph_update_required) {
|
| 2651 |
+
// Extract nodes from graph
|
| 2652 |
+
if (cuda_ctx->cuda_graph->num_nodes == 0) {
|
| 2653 |
+
// First call with null argument gets number of nodes in graph
|
| 2654 |
+
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
|
| 2655 |
+
}
|
| 2656 |
+
// Subsequent call with non-null argument gets nodes
|
| 2657 |
+
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
|
| 2658 |
+
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
|
| 2659 |
+
if (cuda_ctx->cuda_graph->num_nodes > 0) {
|
| 2660 |
+
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
|
| 2661 |
+
|
| 2662 |
+
// Loop over nodes, and extract kernel parameters from each node
|
| 2663 |
+
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
|
| 2664 |
+
cudaGraphNodeType node_type;
|
| 2665 |
+
CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
|
| 2666 |
+
if (node_type == cudaGraphNodeTypeKernel) {
|
| 2667 |
+
cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
|
| 2668 |
+
if (stat == cudaErrorInvalidDeviceFunction) {
|
| 2669 |
+
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
|
| 2670 |
+
// We don't need to update blas nodes, so clear error and move on.
|
| 2671 |
+
cudaGetLastError();
|
| 2672 |
+
} else {
|
| 2673 |
+
GGML_ASSERT(stat == cudaSuccess);
|
| 2674 |
+
}
|
| 2675 |
+
}
|
| 2676 |
+
}
|
| 2677 |
+
}
|
| 2678 |
+
}
|
| 2679 |
+
|
| 2680 |
+
// One of the arguments to the copy kernel is updated for each token, hence we need to
|
| 2681 |
+
// replace that argument with the updated value in the CUDA graph
|
| 2682 |
+
if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
|
| 2683 |
+
int k = 0;
|
| 2684 |
+
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
|
| 2685 |
+
if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) {
|
| 2686 |
+
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
|
| 2687 |
+
cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
|
| 2688 |
+
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
|
| 2689 |
+
}
|
| 2690 |
+
}
|
| 2691 |
+
}
|
| 2692 |
+
|
| 2693 |
+
// Update graph executable
|
| 2694 |
+
cudaGraphExecUpdateResultInfo result_info;
|
| 2695 |
+
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
| 2696 |
+
if (stat == cudaErrorGraphExecUpdateFailure) {
|
| 2697 |
+
#ifndef NDEBUG
|
| 2698 |
+
fprintf(stderr, "%s: CUDA graph update failed\n", __func__);
|
| 2699 |
+
#endif
|
| 2700 |
+
// The pre-existing graph exec cannot be updated due to violated constraints
|
| 2701 |
+
// so instead clear error and re-instantiate
|
| 2702 |
+
cudaGetLastError();
|
| 2703 |
+
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
|
| 2704 |
+
cuda_ctx->cuda_graph->instance = nullptr;
|
| 2705 |
+
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
| 2706 |
+
} else {
|
| 2707 |
+
GGML_ASSERT(stat == cudaSuccess);
|
| 2708 |
+
}
|
| 2709 |
+
// Launch graph
|
| 2710 |
+
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
|
| 2711 |
+
#else
|
| 2712 |
+
graph_evaluated_or_captured = true;
|
| 2713 |
+
#endif // USE_CUDA_GRAPH
|
| 2714 |
}
|
| 2715 |
|
| 2716 |
return GGML_STATUS_SUCCESS;
|
|
@@ -31,5 +31,4 @@ void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
| 31 |
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
|
| 32 |
|
| 33 |
clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
|
| 34 |
-
CUDA_CHECK(cudaGetLastError());
|
| 35 |
}
|
|
|
|
| 31 |
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
|
| 32 |
|
| 33 |
clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
|
|
|
|
| 34 |
}
|
|
@@ -19,6 +19,7 @@
|
|
| 19 |
#include <cassert>
|
| 20 |
#include <cfloat>
|
| 21 |
#include <string>
|
|
|
|
| 22 |
|
| 23 |
#if defined(GGML_USE_HIPBLAS)
|
| 24 |
#include <hip/hip_runtime.h>
|
|
@@ -526,6 +527,43 @@ struct ggml_tensor_extra_gpu {
|
|
| 526 |
cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
|
| 527 |
};
|
| 528 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
struct ggml_backend_cuda_context {
|
| 530 |
int device;
|
| 531 |
std::string name;
|
|
@@ -534,6 +572,8 @@ struct ggml_backend_cuda_context {
|
|
| 534 |
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
|
| 535 |
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
| 536 |
|
|
|
|
|
|
|
| 537 |
explicit ggml_backend_cuda_context(int device) :
|
| 538 |
device(device),
|
| 539 |
name(GGML_CUDA_NAME + std::to_string(device)) {
|
|
|
|
| 19 |
#include <cassert>
|
| 20 |
#include <cfloat>
|
| 21 |
#include <string>
|
| 22 |
+
#include <vector>
|
| 23 |
|
| 24 |
#if defined(GGML_USE_HIPBLAS)
|
| 25 |
#include <hip/hip_runtime.h>
|
|
|
|
| 527 |
cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
|
| 528 |
};
|
| 529 |
|
| 530 |
+
|
| 531 |
+
#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
|
| 532 |
+
#define USE_CUDA_GRAPH
|
| 533 |
+
#endif
|
| 534 |
+
|
| 535 |
+
struct ggml_graph_node_properties {
|
| 536 |
+
void * node_address;
|
| 537 |
+
ggml_op node_op;
|
| 538 |
+
int64_t ne[GGML_MAX_DIMS];
|
| 539 |
+
size_t nb[GGML_MAX_DIMS];
|
| 540 |
+
void * src_address[GGML_MAX_SRC];
|
| 541 |
+
};
|
| 542 |
+
|
| 543 |
+
struct ggml_cuda_graph {
|
| 544 |
+
#ifdef USE_CUDA_GRAPH
|
| 545 |
+
~ggml_cuda_graph() {
|
| 546 |
+
if (instance != nullptr) {
|
| 547 |
+
CUDA_CHECK(cudaGraphExecDestroy(instance));
|
| 548 |
+
}
|
| 549 |
+
if (graph != nullptr) {
|
| 550 |
+
CUDA_CHECK(cudaGraphDestroy(graph));
|
| 551 |
+
}
|
| 552 |
+
}
|
| 553 |
+
cudaGraph_t graph = nullptr;
|
| 554 |
+
cudaGraphExec_t instance = nullptr;
|
| 555 |
+
size_t num_nodes = 0;
|
| 556 |
+
std::vector<cudaGraphNode_t> nodes;
|
| 557 |
+
std::vector<cudaKernelNodeParams> params;
|
| 558 |
+
bool disable_due_to_gpu_arch = false;
|
| 559 |
+
bool disable_due_to_too_many_updates = false;
|
| 560 |
+
bool disable_due_to_failed_graph_capture = false;
|
| 561 |
+
int number_consecutive_updates = 0;
|
| 562 |
+
std::vector<ggml_graph_node_properties> ggml_graph_properties;
|
| 563 |
+
std::vector<char **> updated_kernel_arg;
|
| 564 |
+
#endif
|
| 565 |
+
};
|
| 566 |
+
|
| 567 |
struct ggml_backend_cuda_context {
|
| 568 |
int device;
|
| 569 |
std::string name;
|
|
|
|
| 572 |
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
|
| 573 |
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
| 574 |
|
| 575 |
+
std::unique_ptr<ggml_cuda_graph> cuda_graph;
|
| 576 |
+
|
| 577 |
explicit ggml_backend_cuda_context(int device) :
|
| 578 |
device(device),
|
| 579 |
name(GGML_CUDA_NAME + std::to_string(device)) {
|
|
@@ -727,7 +727,6 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_
|
|
| 727 |
}
|
| 728 |
|
| 729 |
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
| 730 |
-
int id;
|
| 731 |
switch (type) {
|
| 732 |
case GGML_TYPE_Q4_0:
|
| 733 |
return dequantize_row_q4_0_cuda;
|
|
@@ -738,8 +737,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
|
| 738 |
case GGML_TYPE_Q5_1:
|
| 739 |
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
| 740 |
case GGML_TYPE_Q8_0:
|
| 741 |
-
|
| 742 |
-
if (ggml_cuda_info().devices[id].cc >= CC_PASCAL) {
|
| 743 |
return dequantize_block_q8_0_f16_cuda;
|
| 744 |
}
|
| 745 |
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
|
|
|
| 727 |
}
|
| 728 |
|
| 729 |
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
|
|
|
| 730 |
switch (type) {
|
| 731 |
case GGML_TYPE_Q4_0:
|
| 732 |
return dequantize_row_q4_0_cuda;
|
|
|
|
| 737 |
case GGML_TYPE_Q5_1:
|
| 738 |
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
| 739 |
case GGML_TYPE_Q8_0:
|
| 740 |
+
if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
|
|
|
|
| 741 |
return dequantize_block_q8_0_f16_cuda;
|
| 742 |
}
|
| 743 |
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
|
@@ -459,3 +459,32 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
| 459 |
const ggml_tensor * src0 = dst->src[0];
|
| 460 |
ggml_cuda_cpy(ctx, src0, dst);
|
| 461 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
const ggml_tensor * src0 = dst->src[0];
|
| 460 |
ggml_cuda_cpy(ctx, src0, dst);
|
| 461 |
}
|
| 462 |
+
|
| 463 |
+
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
| 464 |
+
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
| 465 |
+
return (void*) cpy_f32_f16<cpy_1_f32_f32>;
|
| 466 |
+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
| 467 |
+
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
| 468 |
+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
| 469 |
+
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
|
| 470 |
+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
| 471 |
+
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
|
| 472 |
+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
| 473 |
+
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
|
| 474 |
+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
| 475 |
+
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
|
| 476 |
+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
| 477 |
+
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
|
| 478 |
+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
| 479 |
+
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
|
| 480 |
+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
| 481 |
+
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
| 482 |
+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
| 483 |
+
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
|
| 484 |
+
} else {
|
| 485 |
+
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
|
| 486 |
+
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
| 487 |
+
GGML_ASSERT(false);
|
| 488 |
+
}
|
| 489 |
+
}
|
| 490 |
+
|
|
@@ -5,3 +5,5 @@
|
|
| 5 |
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
|
| 6 |
|
| 7 |
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|
|
|
|
|
|
|
|
| 5 |
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
|
| 6 |
|
| 7 |
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
| 8 |
+
|
| 9 |
+
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
|
|
@@ -1735,8 +1735,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
|
|
| 1735 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 1736 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 1737 |
|
| 1738 |
-
int id;
|
| 1739 |
-
CUDA_CHECK(cudaGetDevice(&id));
|
| 1740 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 1741 |
|
| 1742 |
int mmq_x, mmq_y, nwarps;
|
|
@@ -1780,8 +1779,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
|
|
| 1780 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 1781 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 1782 |
|
| 1783 |
-
int id;
|
| 1784 |
-
CUDA_CHECK(cudaGetDevice(&id));
|
| 1785 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 1786 |
|
| 1787 |
int mmq_x, mmq_y, nwarps;
|
|
@@ -1825,8 +1823,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
|
|
| 1825 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 1826 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 1827 |
|
| 1828 |
-
int id;
|
| 1829 |
-
CUDA_CHECK(cudaGetDevice(&id));
|
| 1830 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 1831 |
|
| 1832 |
int mmq_x, mmq_y, nwarps;
|
|
@@ -1870,8 +1867,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
|
|
| 1870 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 1871 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 1872 |
|
| 1873 |
-
int id;
|
| 1874 |
-
CUDA_CHECK(cudaGetDevice(&id));
|
| 1875 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 1876 |
|
| 1877 |
int mmq_x, mmq_y, nwarps;
|
|
@@ -1915,8 +1911,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
|
|
| 1915 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 1916 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 1917 |
|
| 1918 |
-
int id;
|
| 1919 |
-
CUDA_CHECK(cudaGetDevice(&id));
|
| 1920 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 1921 |
|
| 1922 |
int mmq_x, mmq_y, nwarps;
|
|
@@ -1960,8 +1955,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
|
|
| 1960 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 1961 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 1962 |
|
| 1963 |
-
int id;
|
| 1964 |
-
CUDA_CHECK(cudaGetDevice(&id));
|
| 1965 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 1966 |
|
| 1967 |
int mmq_x, mmq_y, nwarps;
|
|
@@ -2007,8 +2001,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
|
|
| 2007 |
|
| 2008 |
#if QK_K == 256
|
| 2009 |
|
| 2010 |
-
int id;
|
| 2011 |
-
CUDA_CHECK(cudaGetDevice(&id));
|
| 2012 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 2013 |
|
| 2014 |
int mmq_x, mmq_y, nwarps;
|
|
@@ -2053,8 +2046,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
|
|
| 2053 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 2054 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 2055 |
|
| 2056 |
-
int id;
|
| 2057 |
-
CUDA_CHECK(cudaGetDevice(&id));
|
| 2058 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 2059 |
|
| 2060 |
int mmq_x, mmq_y, nwarps;
|
|
@@ -2098,8 +2090,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
|
|
| 2098 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 2099 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 2100 |
|
| 2101 |
-
int id;
|
| 2102 |
-
CUDA_CHECK(cudaGetDevice(&id));
|
| 2103 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 2104 |
|
| 2105 |
int mmq_x, mmq_y, nwarps;
|
|
@@ -2143,8 +2134,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
|
|
| 2143 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 2144 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 2145 |
|
| 2146 |
-
int id;
|
| 2147 |
-
CUDA_CHECK(cudaGetDevice(&id));
|
| 2148 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 2149 |
|
| 2150 |
int mmq_x, mmq_y, nwarps;
|
|
|
|
| 1735 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 1736 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 1737 |
|
| 1738 |
+
int id = ggml_cuda_get_device();
|
|
|
|
| 1739 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 1740 |
|
| 1741 |
int mmq_x, mmq_y, nwarps;
|
|
|
|
| 1779 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 1780 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 1781 |
|
| 1782 |
+
int id = ggml_cuda_get_device();
|
|
|
|
| 1783 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 1784 |
|
| 1785 |
int mmq_x, mmq_y, nwarps;
|
|
|
|
| 1823 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 1824 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 1825 |
|
| 1826 |
+
int id = ggml_cuda_get_device();
|
|
|
|
| 1827 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 1828 |
|
| 1829 |
int mmq_x, mmq_y, nwarps;
|
|
|
|
| 1867 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 1868 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 1869 |
|
| 1870 |
+
int id = ggml_cuda_get_device();
|
|
|
|
| 1871 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 1872 |
|
| 1873 |
int mmq_x, mmq_y, nwarps;
|
|
|
|
| 1911 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 1912 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 1913 |
|
| 1914 |
+
int id = ggml_cuda_get_device();
|
|
|
|
| 1915 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 1916 |
|
| 1917 |
int mmq_x, mmq_y, nwarps;
|
|
|
|
| 1955 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 1956 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 1957 |
|
| 1958 |
+
int id = ggml_cuda_get_device();
|
|
|
|
| 1959 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 1960 |
|
| 1961 |
int mmq_x, mmq_y, nwarps;
|
|
|
|
| 2001 |
|
| 2002 |
#if QK_K == 256
|
| 2003 |
|
| 2004 |
+
int id = ggml_cuda_get_device();
|
|
|
|
| 2005 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 2006 |
|
| 2007 |
int mmq_x, mmq_y, nwarps;
|
|
|
|
| 2046 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 2047 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 2048 |
|
| 2049 |
+
int id = ggml_cuda_get_device();
|
|
|
|
| 2050 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 2051 |
|
| 2052 |
int mmq_x, mmq_y, nwarps;
|
|
|
|
| 2090 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 2091 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 2092 |
|
| 2093 |
+
int id = ggml_cuda_get_device();
|
|
|
|
| 2094 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 2095 |
|
| 2096 |
int mmq_x, mmq_y, nwarps;
|
|
|
|
| 2134 |
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
| 2135 |
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
| 2136 |
|
| 2137 |
+
int id = ggml_cuda_get_device();
|
|
|
|
| 2138 |
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
| 2139 |
|
| 2140 |
int mmq_x, mmq_y, nwarps;
|
|
@@ -89,8 +89,7 @@ static void mul_mat_vec_q_cuda(
|
|
| 89 |
GGML_ASSERT(ncols_x % qk == 0);
|
| 90 |
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
|
| 91 |
|
| 92 |
-
int id;
|
| 93 |
-
CUDA_CHECK(cudaGetDevice(&id));
|
| 94 |
|
| 95 |
int64_t nwarps = 1;
|
| 96 |
int64_t rows_per_cuda_block = 1;
|
|
@@ -328,8 +327,7 @@ void ggml_cuda_op_mul_mat_vec_q(
|
|
| 328 |
|
| 329 |
const int64_t ne0 = dst->ne[0];
|
| 330 |
|
| 331 |
-
int id;
|
| 332 |
-
CUDA_CHECK(cudaGetDevice(&id));
|
| 333 |
|
| 334 |
// the main device has a larger memory buffer to hold the results from all GPUs
|
| 335 |
// nrows_dst == nrows of the matrix that the kernel writes into
|
|
|
|
| 89 |
GGML_ASSERT(ncols_x % qk == 0);
|
| 90 |
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
|
| 91 |
|
| 92 |
+
int id = ggml_cuda_get_device();
|
|
|
|
| 93 |
|
| 94 |
int64_t nwarps = 1;
|
| 95 |
int64_t rows_per_cuda_block = 1;
|
|
|
|
| 327 |
|
| 328 |
const int64_t ne0 = dst->ne[0];
|
| 329 |
|
| 330 |
+
int id = ggml_cuda_get_device();
|
|
|
|
| 331 |
|
| 332 |
// the main device has a larger memory buffer to hold the results from all GPUs
|
| 333 |
// nrows_dst == nrows of the matrix that the kernel writes into
|
|
@@ -28,5 +28,4 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
| 28 |
memcpy(&scale, dst->op_params, sizeof(float));
|
| 29 |
|
| 30 |
scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
|
| 31 |
-
CUDA_CHECK(cudaGetLastError());
|
| 32 |
}
|
|
|
|
| 28 |
memcpy(&scale, dst->op_params, sizeof(float));
|
| 29 |
|
| 30 |
scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
|
|
|
|
| 31 |
}
|