rgerganov committed on
Commit
81a6cae
·
1 Parent(s): 1101050

rpc : better caching of the base buffer pointer (llama/11331)

Browse files

There is no need to use a map; just store the base pointer in the buffer
context.

Files changed (1) hide show
  1. ggml/src/ggml-rpc/ggml-rpc.cpp +6 -7
ggml/src/ggml-rpc/ggml-rpc.cpp CHANGED
@@ -181,7 +181,7 @@ struct ggml_backend_rpc_context {
181
 
182
  struct ggml_backend_rpc_buffer_context {
183
  std::shared_ptr<socket_t> sock;
184
- std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
185
  uint64_t remote_ptr;
186
  };
187
 
@@ -423,16 +423,15 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
423
 
424
  static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
425
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
426
- if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
427
- return ctx->base_cache[buffer];
428
  }
429
  rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
430
  rpc_msg_buffer_get_base_rsp response;
431
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
432
  GGML_ASSERT(status);
433
- void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
434
- ctx->base_cache[buffer] = base_ptr;
435
- return base_ptr;
436
  }
437
 
438
  static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
@@ -557,7 +556,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
557
  if (response.remote_ptr != 0) {
558
  ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
559
  ggml_backend_rpc_buffer_interface,
560
- new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
561
  response.remote_size);
562
  return buffer;
563
  } else {
 
181
 
182
  struct ggml_backend_rpc_buffer_context {
183
  std::shared_ptr<socket_t> sock;
184
+ void * base_ptr;
185
  uint64_t remote_ptr;
186
  };
187
 
 
423
 
424
  static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
425
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
426
+ if (ctx->base_ptr != nullptr) {
427
+ return ctx->base_ptr;
428
  }
429
  rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
430
  rpc_msg_buffer_get_base_rsp response;
431
  bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
432
  GGML_ASSERT(status);
433
+ ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
434
+ return ctx->base_ptr;
 
435
  }
436
 
437
  static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
 
556
  if (response.remote_ptr != 0) {
557
  ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
558
  ggml_backend_rpc_buffer_interface,
559
+ new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
560
  response.remote_size);
561
  return buffer;
562
  } else {