sachaarbonel commited on
Commit
170eb31
·
unverified ·
1 Parent(s): a69c121

server : graceful shutdown, atomic server state, and health endpoint Improvements (#3243)

Browse files

* feat(server): implement graceful shutdown and server state management

* refactor(server): use lambda capture by reference in server.cpp

Files changed (1) hide show
  1. examples/server/server.cpp +92 -21
examples/server/server.cpp CHANGED
@@ -14,10 +14,23 @@
14
  #include <string>
15
  #include <thread>
16
  #include <vector>
 
 
 
 
 
 
 
 
17
 
18
  using namespace httplib;
19
  using json = nlohmann::ordered_json;
20
 
 
 
 
 
 
21
  namespace {
22
 
23
  // output formats
@@ -27,6 +40,20 @@ const std::string srt_format = "srt";
27
  const std::string vjson_format = "verbose_json";
28
  const std::string vtt_format = "vtt";
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  struct server_params
31
  {
32
  std::string hostname = "127.0.0.1";
@@ -654,6 +681,9 @@ int main(int argc, char ** argv) {
654
  }
655
  }
656
 
 
 
 
657
  struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
658
 
659
  if (ctx == nullptr) {
@@ -663,9 +693,10 @@ int main(int argc, char ** argv) {
663
 
664
  // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
665
  whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
 
 
666
 
667
- Server svr;
668
- svr.set_default_headers({{"Server", "whisper.cpp"},
669
  {"Access-Control-Allow-Origin", "*"},
670
  {"Access-Control-Allow-Headers", "content-type, authorization"}});
671
 
@@ -744,15 +775,15 @@ int main(int argc, char ** argv) {
744
  whisper_params default_params = params;
745
 
746
  // this is only called if no index.html is found in the public --path
747
- svr.Get(sparams.request_path + "/", [&default_content](const Request &, Response &res){
748
  res.set_content(default_content, "text/html");
749
  return false;
750
  });
751
 
752
- svr.Options(sparams.request_path + sparams.inference_path, [&](const Request &, Response &){
753
  });
754
 
755
- svr.Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){
756
  // acquire whisper model mutex lock
757
  std::lock_guard<std::mutex> lock(whisper_mutex);
758
 
@@ -1068,8 +1099,9 @@ int main(int argc, char ** argv) {
1068
  // reset params to their defaults
1069
  params = default_params;
1070
  });
1071
- svr.Post(sparams.request_path + "/load", [&](const Request &req, Response &res){
1072
  std::lock_guard<std::mutex> lock(whisper_mutex);
 
1073
  if (!req.has_file("model"))
1074
  {
1075
  fprintf(stderr, "error: no 'model' field in the request\n");
@@ -1101,18 +1133,25 @@ int main(int argc, char ** argv) {
1101
  // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
1102
  whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
1103
 
 
1104
  const std::string success = "Load was successful!";
1105
  res.set_content(success, "application/text");
1106
 
1107
  // check if the model is in the file system
1108
  });
1109
 
1110
- svr.Get(sparams.request_path + "/health", [&](const Request &, Response &res){
1111
- const std::string health_response = "{\"status\":\"ok\"}";
1112
- res.set_content(health_response, "application/json");
 
 
 
 
 
 
1113
  });
1114
 
1115
- svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
1116
  const char fmt[] = "500 Internal Server Error\n%s";
1117
  char buf[BUFSIZ];
1118
  try {
@@ -1126,7 +1165,7 @@ int main(int argc, char ** argv) {
1126
  res.status = 500;
1127
  });
1128
 
1129
- svr.set_error_handler([](const Request &req, Response &res) {
1130
  if (res.status == 400) {
1131
  res.set_content("Invalid request", "text/plain");
1132
  } else if (res.status != 500) {
@@ -1136,10 +1175,10 @@ int main(int argc, char ** argv) {
1136
  });
1137
 
1138
  // set timeouts and change hostname and port
1139
- svr.set_read_timeout(sparams.read_timeout);
1140
- svr.set_write_timeout(sparams.write_timeout);
1141
 
1142
- if (!svr.bind_to_port(sparams.hostname, sparams.port))
1143
  {
1144
  fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n",
1145
  sparams.hostname.c_str(), sparams.port);
@@ -1147,18 +1186,50 @@ int main(int argc, char ** argv) {
1147
  }
1148
 
1149
  // Set the base directory for serving static files
1150
- svr.set_base_dir(sparams.public_path);
1151
 
1152
  // to make it ctrl+clickable:
1153
  printf("\nwhisper server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
1154
 
1155
- if (!svr.listen_after_bind())
1156
- {
1157
- return 1;
1158
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1159
 
1160
- whisper_print_timings(ctx);
1161
- whisper_free(ctx);
1162
 
1163
  return 0;
1164
  }
 
14
  #include <string>
15
  #include <thread>
16
  #include <vector>
17
+ #include <memory>
18
+ #include <csignal>
19
+ #include <atomic>
20
+ #include <functional>
21
+ #include <cstdlib>
22
+ #if defined (_WIN32)
23
+ #include <windows.h>
24
+ #endif
25
 
26
  using namespace httplib;
27
  using json = nlohmann::ordered_json;
28
 
29
+ enum server_state {
30
+ SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
31
+ SERVER_STATE_READY, // Server is ready and model is loaded
32
+ };
33
+
34
  namespace {
35
 
36
  // output formats
 
40
  const std::string vjson_format = "verbose_json";
41
  const std::string vtt_format = "vtt";
42
 
43
+ std::function<void(int)> shutdown_handler;
44
+ std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
45
+
46
+ inline void signal_handler(int signal) {
47
+ if (is_terminating.test_and_set()) {
48
+ // in case it hangs, we can force terminate the server by hitting Ctrl+C twice
49
+ // this is for better developer experience, we can remove when the server is stable enough
50
+ fprintf(stderr, "Received second interrupt, terminating immediately.\n");
51
+ exit(1);
52
+ }
53
+
54
+ shutdown_handler(signal);
55
+ }
56
+
57
  struct server_params
58
  {
59
  std::string hostname = "127.0.0.1";
 
681
  }
682
  }
683
 
684
+ std::unique_ptr<httplib::Server> svr = std::make_unique<httplib::Server>();
685
+ std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
686
+
687
  struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
688
 
689
  if (ctx == nullptr) {
 
693
 
694
  // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
695
  whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
696
+ state.store(SERVER_STATE_READY);
697
+
698
 
699
+ svr->set_default_headers({{"Server", "whisper.cpp"},
 
700
  {"Access-Control-Allow-Origin", "*"},
701
  {"Access-Control-Allow-Headers", "content-type, authorization"}});
702
 
 
775
  whisper_params default_params = params;
776
 
777
  // this is only called if no index.html is found in the public --path
778
+ svr->Get(sparams.request_path + "/", [&](const Request &, Response &res){
779
  res.set_content(default_content, "text/html");
780
  return false;
781
  });
782
 
783
+ svr->Options(sparams.request_path + sparams.inference_path, [&](const Request &, Response &){
784
  });
785
 
786
+ svr->Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){
787
  // acquire whisper model mutex lock
788
  std::lock_guard<std::mutex> lock(whisper_mutex);
789
 
 
1099
  // reset params to their defaults
1100
  params = default_params;
1101
  });
1102
+ svr->Post(sparams.request_path + "/load", [&](const Request &req, Response &res){
1103
  std::lock_guard<std::mutex> lock(whisper_mutex);
1104
+ state.store(SERVER_STATE_LOADING_MODEL);
1105
  if (!req.has_file("model"))
1106
  {
1107
  fprintf(stderr, "error: no 'model' field in the request\n");
 
1133
  // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
1134
  whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
1135
 
1136
+ state.store(SERVER_STATE_READY);
1137
  const std::string success = "Load was successful!";
1138
  res.set_content(success, "application/text");
1139
 
1140
  // check if the model is in the file system
1141
  });
1142
 
1143
+ svr->Get(sparams.request_path + "/health", [&](const Request &, Response &res){
1144
+ server_state current_state = state.load();
1145
+ if (current_state == SERVER_STATE_READY) {
1146
+ const std::string health_response = "{\"status\":\"ok\"}";
1147
+ res.set_content(health_response, "application/json");
1148
+ } else {
1149
+ res.set_content("{\"status\":\"loading model\"}", "application/json");
1150
+ res.status = 503;
1151
+ }
1152
  });
1153
 
1154
+ svr->set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
1155
  const char fmt[] = "500 Internal Server Error\n%s";
1156
  char buf[BUFSIZ];
1157
  try {
 
1165
  res.status = 500;
1166
  });
1167
 
1168
+ svr->set_error_handler([](const Request &req, Response &res) {
1169
  if (res.status == 400) {
1170
  res.set_content("Invalid request", "text/plain");
1171
  } else if (res.status != 500) {
 
1175
  });
1176
 
1177
  // set timeouts and change hostname and port
1178
+ svr->set_read_timeout(sparams.read_timeout);
1179
+ svr->set_write_timeout(sparams.write_timeout);
1180
 
1181
+ if (!svr->bind_to_port(sparams.hostname, sparams.port))
1182
  {
1183
  fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n",
1184
  sparams.hostname.c_str(), sparams.port);
 
1186
  }
1187
 
1188
  // Set the base directory for serving static files
1189
+ svr->set_base_dir(sparams.public_path);
1190
 
1191
  // to make it ctrl+clickable:
1192
  printf("\nwhisper server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
1193
 
1194
+ shutdown_handler = [&](int signal) {
1195
+ printf("\nCaught signal %d, shutting down gracefully...\n", signal);
1196
+ if (svr) {
1197
+ svr->stop();
1198
+ }
1199
+ };
1200
+
1201
+ #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
1202
+ struct sigaction sigint_action;
1203
+ sigint_action.sa_handler = signal_handler;
1204
+ sigemptyset (&sigint_action.sa_mask);
1205
+ sigint_action.sa_flags = 0;
1206
+ sigaction(SIGINT, &sigint_action, NULL);
1207
+ sigaction(SIGTERM, &sigint_action, NULL);
1208
+ #elif defined (_WIN32)
1209
+ auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
1210
+ return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
1211
+ };
1212
+ SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
1213
+ #endif
1214
+
1215
+ // clean up function, to be called before exit
1216
+ auto clean_up = [&]() {
1217
+ whisper_print_timings(ctx);
1218
+ whisper_free(ctx);
1219
+ };
1220
+
1221
+ std::thread t([&] {
1222
+ if (!svr->listen_after_bind()) {
1223
+ fprintf(stderr, "error: server listen failed\n");
1224
+ }
1225
+ });
1226
+
1227
+ svr->wait_until_ready();
1228
+
1229
+ t.join();
1230
+
1231
 
1232
+ clean_up();
 
1233
 
1234
  return 0;
1235
  }