tokenizer.py

import base64
import os
from functools import lru_cache
from typing import Optional

import tiktoken
from whisper.tokenizer import Tokenizer

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
    "minnan": "minnan",
    "wuyu": "wuyu",
    "dialect": "dialect",
    "zh/en": "zh/en",
    "en/zh": "en/zh",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}

AUDIO_EVENT = {
    "ASR": "ASR",
    "AED": "AED",
    "SER": "SER",
    "Speech": "Speech",
    "/Speech": "/Speech",
    "BGM": "BGM",
    "/BGM": "/BGM",
    "Laughter": "Laughter",
    "/Laughter": "/Laughter",
    "Applause": "Applause",
    "/Applause": "/Applause",
}

EMOTION = {
    "HAPPY": "HAPPY",
    "SAD": "SAD",
    "ANGRY": "ANGRY",
    "NEUTRAL": "NEUTRAL",
}

TTS_Vocal_Token = {
    "TTS/B": "TTS/B",
    "TTS/O": "TTS/O",
    "TTS/Q": "TTS/Q",
    "TTS/A": "TTS/A",
    "TTS/CO": "TTS/CO",
    "TTS/CL": "TTS/CL",
    "TTS/H": "TTS/H",
    **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)},
}


@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    # Base BPE vocabulary: each line of the vocab file is "<base64 token> <rank>".
    ranks = {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in open(vocab_path) if line)
    }
    n_vocab = len(ranks)
    special_tokens = {}

    # Special tokens are appended after the base vocabulary, in this exact order.
    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
        *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],  # register special tokens for ASR
        *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],  # register special tokens for TTS
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],  # timestamp tokens <|0.00|> ... <|30.00|>
    ]

    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
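
# Illustrative sketch (not part of the original module): with the layout built above,
# special tokens must be explicitly allowed when encoding, because tiktoken rejects
# special-token text in plain input by default.
#
#   enc = get_encoding("multilingual_zh_ja_yue_char_del")
#   enc.encode("<|zh|><|transcribe|>", allowed_special="all")  # -> ids of the two special tokens
#   enc.decode(enc.encode("hello"))                            # ordinary text round-trips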


@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
) -> Tokenizer:
    if language is not None:
        # Accept either a language code or a full language name / alias.
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        encoding_name = "multilingual_zh_ja_yue_char_del"
        language = language or "en"
        task = task or "transcribe"
    else:
        # Monolingual (English-only) models fall back to the plain GPT-2 vocabulary.
        encoding_name = "gpt2"
        language = None
        task = None

    encoding = get_encoding(name=encoding_name, num_languages=num_languages)

    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )
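
# Example usage (a sketch, assuming the referenced *.tiktoken vocab files are present
# under ./assets and that the imported Tokenizer follows the openai-whisper interface):
#
#   tokenizer = get_tokenizer(multilingual=True, language="zh", task="transcribe")
#   ids = tokenizer.encode("some transcript text")
#   text = tokenizer.decode(ids)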