Spaces:
Running
Running
| # Convert Whisper transformer model from PyTorch to ggml format | |
| # | |
| # Usage: python convert-pt-to-ggml.py ~/.cache/whisper/medium.pt ~/path/to/repo/whisper/ ./models/whisper-medium | |
| # | |
| # You need to clone the original repo in ~/path/to/repo/whisper/ | |
| # | |
| # git clone https://github.com/openai/whisper ~/path/to/repo/whisper/ | |
| # | |
| # It is used to load various assets needed by the algorithm: | |
| # | |
| # - tokenizer | |
| # - mel filters | |
| # | |
| # Also, you need to have the original models in ~/.cache/whisper/ | |
| # See the original repo for more details. | |
| # | |
| # This script loads the specified model and whisper assets and saves them in ggml format. | |
| # The output is a single binary file containing the following information: | |
| # | |
| # - hparams | |
| # - mel filters | |
| # - tokenizer vocab | |
| # - model variables | |
| # | |
| # For each variable, write the following: | |
| # | |
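| # - Number of dimensions (int) | |
| # - Name length (int) | |
| # - Data type (int): 0 = float32, 1 = float16 | |
| # - Dimensions (int[n_dims]), written in reverse order | |
| # - Name (char[name_length]) | |
| # - Data (float32 or float16 values, one per tensor element) | |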
| # | |
| import io | |
| import os | |
| import sys | |
| import struct | |
| import json | |
| import code | |
| import torch | |
| import numpy as np | |
| import base64 | |
| from pathlib import Path | |
| #from transformers import GPTJForCausalLM | |
| #from transformers import GPT2TokenizerFast | |
| # ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L10-L110 | |
| #LANGUAGES = { | |
| # "en": "english", | |
| # "zh": "chinese", | |
| # "de": "german", | |
| # "es": "spanish", | |
| # "ru": "russian", | |
| # "ko": "korean", | |
| # "fr": "french", | |
| # "ja": "japanese", | |
| # "pt": "portuguese", | |
| # "tr": "turkish", | |
| # "pl": "polish", | |
| # "ca": "catalan", | |
| # "nl": "dutch", | |
| # "ar": "arabic", | |
| # "sv": "swedish", | |
| # "it": "italian", | |
| # "id": "indonesian", | |
| # "hi": "hindi", | |
| # "fi": "finnish", | |
| # "vi": "vietnamese", | |
| # "iw": "hebrew", | |
| # "uk": "ukrainian", | |
| # "el": "greek", | |
| # "ms": "malay", | |
| # "cs": "czech", | |
| # "ro": "romanian", | |
| # "da": "danish", | |
| # "hu": "hungarian", | |
| # "ta": "tamil", | |
| # "no": "norwegian", | |
| # "th": "thai", | |
| # "ur": "urdu", | |
| # "hr": "croatian", | |
| # "bg": "bulgarian", | |
| # "lt": "lithuanian", | |
| # "la": "latin", | |
| # "mi": "maori", | |
| # "ml": "malayalam", | |
| # "cy": "welsh", | |
| # "sk": "slovak", | |
| # "te": "telugu", | |
| # "fa": "persian", | |
| # "lv": "latvian", | |
| # "bn": "bengali", | |
| # "sr": "serbian", | |
| # "az": "azerbaijani", | |
| # "sl": "slovenian", | |
| # "kn": "kannada", | |
| # "et": "estonian", | |
| # "mk": "macedonian", | |
| # "br": "breton", | |
| # "eu": "basque", | |
| # "is": "icelandic", | |
| # "hy": "armenian", | |
| # "ne": "nepali", | |
| # "mn": "mongolian", | |
| # "bs": "bosnian", | |
| # "kk": "kazakh", | |
| # "sq": "albanian", | |
| # "sw": "swahili", | |
| # "gl": "galician", | |
| # "mr": "marathi", | |
| # "pa": "punjabi", | |
| # "si": "sinhala", | |
| # "km": "khmer", | |
| # "sn": "shona", | |
| # "yo": "yoruba", | |
| # "so": "somali", | |
| # "af": "afrikaans", | |
| # "oc": "occitan", | |
| # "ka": "georgian", | |
| # "be": "belarusian", | |
| # "tg": "tajik", | |
| # "sd": "sindhi", | |
| # "gu": "gujarati", | |
| # "am": "amharic", | |
| # "yi": "yiddish", | |
| # "lo": "lao", | |
| # "uz": "uzbek", | |
| # "fo": "faroese", | |
| # "ht": "haitian creole", | |
| # "ps": "pashto", | |
| # "tk": "turkmen", | |
| # "nn": "nynorsk", | |
| # "mt": "maltese", | |
| # "sa": "sanskrit", | |
| # "lb": "luxembourgish", | |
| # "my": "myanmar", | |
| # "bo": "tibetan", | |
| # "tl": "tagalog", | |
| # "mg": "malagasy", | |
| # "as": "assamese", | |
| # "tt": "tatar", | |
| # "haw": "hawaiian", | |
| # "ln": "lingala", | |
| # "ha": "hausa", | |
| # "ba": "bashkir", | |
| # "jw": "javanese", | |
| # "su": "sundanese", | |
| #} | |
| ## ref: https://github.com/openai/whisper/blob/8cf36f3508c9acd341a45eb2364239a3d81458b9/whisper/tokenizer.py#L273-L292 | |
| #def build_tokenizer(path_to_whisper_repo: str, name: str = "gpt2"): | |
| # os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| # path = os.path.join(path_to_whisper_repo, "whisper/assets", name) | |
| # tokenizer = GPT2TokenizerFast.from_pretrained(path) | |
| # | |
| # specials = [ | |
| # "<|startoftranscript|>", | |
| # *[f"<|{lang}|>" for lang in LANGUAGES.keys()], | |
| # "<|translate|>", | |
| # "<|transcribe|>", | |
| # "<|startoflm|>", | |
| # "<|startofprev|>", | |
| # "<|nocaptions|>", | |
| # "<|notimestamps|>", | |
| # ] | |
| # | |
| # tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) | |
| # return tokenizer | |
| # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py | |
def bytes_to_unicode():
    """
    Build the reversible byte -> unicode-character table used by GPT-2 style BPE.

    Returns a dict with one entry for each of the 256 byte values, each mapped
    to a distinct single-character string.

    Bytes in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the character
    with the same code point. Every remaining byte (whitespace, control
    characters, etc.) is assigned, in ascending byte order, the next unused
    code point starting at 256, so that no byte maps to a whitespace/control
    character that the bpe code cannot handle.
    """
    # Bytes that can stand for themselves as a unicode character.
    self_mapped = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]

    byte_values = list(self_mapped)
    code_points = list(self_mapped)

    # Remaining bytes get shifted into the range starting at 256.
    shift = 0
    for value in range(256):
        if value in self_mapped:
            continue
        byte_values.append(value)
        code_points.append(256 + shift)
        shift += 1

    return {value: chr(point) for value, point in zip(byte_values, code_points)}
# Validate the command line: three positional arguments are required
# (model file, whisper repo directory, output directory). An optional fourth
# argument is accepted and is inspected later in the script.
if len(sys.argv) < 4:
    print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n")
    sys.exit(1)

fname_inp = Path(sys.argv[1])
dir_whisper = Path(sys.argv[2])
dir_out = Path(sys.argv[3])

# try to load PyTorch binary data
# NOTE(review): torch.load is pickle-based; only convert checkpoints that come
# from a trusted source (see the PyTorch documentation on `weights_only`).
try:
    # Read through a context manager so the file handle is closed
    # deterministically, then deserialize from the in-memory copy.
    with open(fname_inp, "rb") as f_inp:
        model_bytes = f_inp.read()
    with io.BytesIO(model_bytes) as fp:
        checkpoint = torch.load(fp, map_location="cpu")
except Exception as err:
    # Top-level boundary of the script: report what went wrong instead of
    # hiding the cause, then exit with a failure status.
    print("Error: failed to load PyTorch model file:" , fname_inp)
    print("  cause:", repr(err))
    sys.exit(1)

# The checkpoint is a mapping holding the model hyper-parameters under "dims"
# and the tensors under "model_state_dict".
hparams = checkpoint["dims"]
print("hparams:", hparams)

list_vars = checkpoint["model_state_dict"]
| #print(list_vars['encoder.positional_embedding']) | |
| #print(list_vars['encoder.conv1.weight']) | |
| #print(list_vars['encoder.conv1.weight'].shape) | |
# load mel filters
# The whisper repo ships its filterbanks in a single .npz archive; the entry
# matching this model is stored under the key "mel_<n_mels>".
n_mels = hparams["n_mels"]
mel_filters_path = dir_whisper / "whisper" / "assets" / "mel_filters.npz"
with np.load(mel_filters_path) as mel_archive:
    mel_key = f"mel_{n_mels}"
    filters = torch.from_numpy(mel_archive[mel_key])
| #print (filters) | |
| #code.interact(local=locals()) | |
# load tokenizer
# Two on-disk layouts are supported and probed in this order:
#   new format: dir_whisper/whisper/assets/[multilingual/gpt2].tiktoken
#   old format: dir_whisper/whisper/assets/[multilingual/gpt2]/vocab.json
# (the old hf_transformers layout is kept for backwards compatibility)
multilingual = hparams["n_vocab"] >= 51865
assets_dir = dir_whisper / "whisper" / "assets"

tokenizer = assets_dir / ("multilingual.tiktoken" if multilingual else "gpt2.tiktoken")
tokenizer_type = "tiktoken"
if not tokenizer.is_file():
    tokenizer = assets_dir / ("multilingual" if multilingual else "gpt2") / "vocab.json"
    tokenizer_type = "hf_transformers"
    if not tokenizer.is_file():
        print("Error: failed to find either tiktoken or hf_transformers tokenizer file:", tokenizer)
        sys.exit(1)

# byte <-> unicode-character tables; the decoder direction is needed to turn
# the hf_transformers vocabulary strings back into raw bytes.
byte_encoder = bytes_to_unicode()
byte_decoder = {char: byte for byte, char in byte_encoder.items()}

# tokens maps the raw token bytes to the integer rank/index, in file order.
tokens = {}
if tokenizer_type == "tiktoken":
    # Each non-empty line is "<base64 token> <rank>".
    with open(tokenizer, "rb") as f:
        contents = f.read()
    for line in contents.splitlines():
        if not line:
            continue
        token, rank = line.split()
        tokens[base64.b64decode(token)] = int(rank)
elif tokenizer_type == "hf_transformers":
    with open(tokenizer, "r", encoding="utf8") as f:
        _tokens_raw = json.load(f)
    # ensures exact same model as tokenizer_type == tiktoken
    # details: https://github.com/ggerganov/whisper.cpp/pull/725
    _tokens_raw.pop('<|endoftext|>', None)
    for token, idx in _tokens_raw.items():
        tokens[bytes([byte_decoder[c] for c in token])] = int(idx)
# Write the ggml file. Layout, in order:
#   magic, hparams, use_f16 flag, mel filters, tokenizer vocab, model variables.
# All integers are written with struct format "i" (native int).

# Output goes into dir_out under a fixed file name.
fname_out = dir_out / "ggml-model.bin"

# use 16-bit or 32-bit floats
# The presence of any fourth command-line argument selects 32-bit output.
use_f16 = True
if len(sys.argv) > 4:
    use_f16 = False
    fname_out = dir_out / "ggml-model-f32.bin"

# Hyper-parameters written right after the magic, in this exact order.
hparam_keys = (
    "n_vocab",
    "n_audio_ctx",
    "n_audio_state",
    "n_audio_head",
    "n_audio_layer",
    "n_text_ctx",
    "n_text_state",
    "n_text_head",
    "n_text_layer",
    "n_mels",
)

# Tensors that stay in float32 even when f16 output is requested
# (tensors with fewer than 2 dimensions are kept in float32 as well).
f32_only_names = (
    "encoder.conv1.bias",
    "encoder.conv2.bias",
    "encoder.positional_embedding",
    "decoder.positional_embedding",
)

# The context manager guarantees the output file is closed even if one of the
# writes below raises.
with fname_out.open("wb") as fout:
    fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
    for hparam_key in hparam_keys:
        fout.write(struct.pack("i", hparams[hparam_key]))
    fout.write(struct.pack("i", use_f16))

    # write mel filters: both dimensions, then all values as float32 in
    # row-major order (one bulk write instead of one struct.pack per element)
    fout.write(struct.pack("i", filters.shape[0]))
    fout.write(struct.pack("i", filters.shape[1]))
    fout.write(filters.numpy().astype(np.float32).tobytes())

    # write tokenizer: token count, then (length, raw bytes) per token in the
    # iteration order of the tokens dict
    fout.write(struct.pack("i", len(tokens)))
    for token_bytes in tokens:
        fout.write(struct.pack("i", len(token_bytes)))
        fout.write(token_bytes)

    # write model variables
    for name, tensor in list_vars.items():
        data = tensor.squeeze().numpy()
        print("Processing variable: " , name , " with shape: ", data.shape)

        # reshape conv bias from [n] to [n, 1]
        if name in ["encoder.conv1.bias", "encoder.conv2.bias"]:
            data = data.reshape(data.shape[0], 1)
            print(f" Reshaped variable: {name} to shape: ", data.shape)

        n_dims = len(data.shape)

        # looks like the whisper models are in f16 by default
        # so we need to convert the small tensors to f32 until we fully support f16 in ggml
        # ftype == 0 -> float32, ftype == 1 -> float16
        if use_f16:
            if n_dims < 2 or name in f32_only_names:
                print(" Converting to float32")
                data = data.astype(np.float32)
                ftype = 0
            else:
                # Cast explicitly so the payload always matches the declared
                # ftype, even if the checkpoint was not stored in float16.
                data = data.astype(np.float16)
                ftype = 1
        else:
            data = data.astype(np.float32)
            ftype = 0

        # header: n_dims, name length, ftype, then the dimensions in reverse
        # order, then the utf-8 encoded name
        str_ = name.encode('utf-8')
        fout.write(struct.pack("iii", n_dims, len(str_), ftype))
        for dim in reversed(data.shape):
            fout.write(struct.pack("i", dim))
        fout.write(str_)

        # data
        data.tofile(fout)

print("Done. Output file: " , fname_out)
print("")