# tokenizer.py

import base64
import os
from functools import lru_cache
from typing import Optional

import torch
from transformers import AutoTokenizer
from whisper.tokenizer import Tokenizer
import tiktoken

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
    "minnan": "minnan",
    "wuyu": "wuyu",
    "dialect": "dialect",
    "zh/en": "zh/en",
    "en/zh": "en/zh",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}

AUDIO_EVENT = {
    "ASR": "ASR",
    "AED": "AED",
    "SER": "SER",
    "Speech": "Speech",
    "/Speech": "/Speech",
    "BGM": "BGM",
    "/BGM": "/BGM",
    "Laughter": "Laughter",
    "/Laughter": "/Laughter",
    "Applause": "Applause",
    "/Applause": "/Applause",
}

EMOTION = {
    "HAPPY": "HAPPY",
    "SAD": "SAD",
    "ANGRY": "ANGRY",
    "NEUTRAL": "NEUTRAL",
}

TTS_Vocal_Token = {
    "TTS/B": "TTS/B",
    "TTS/O": "TTS/O",
    "TTS/Q": "TTS/Q",
    "TTS/A": "TTS/A",
    "TTS/CO": "TTS/CO",
    "TTS/CL": "TTS/CL",
    "TTS/H": "TTS/H",
    **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)},
}


@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    """Load the base tiktoken vocabulary from assets/<name>.tiktoken and append
    the Whisper-style special tokens (languages, audio events, emotions,
    TTS vocal tokens, and timestamps) after the base vocabulary."""
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    with open(vocab_path) as f:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in f if line)
        }
    n_vocab = len(ranks)
    special_tokens = {}

    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
        *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],  # register special tokens for ASR
        *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],  # register special tokens for TTS
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],  # timestamp tokens in 0.02 s steps, 0.00-30.00 s
    ]

    # Assign each special token an id immediately after the base vocabulary.
    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
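
# Usage sketch (assumption: a "multilingual_zh_ja_yue_char_del.tiktoken" file
# is shipped in the assets/ directory next to this module). Encoding and
# decoding go through tiktoken's standard API:
#
#     enc = get_encoding(name="multilingual_zh_ja_yue_char_del")
#     ids = enc.encode("<|startoftranscript|><|zh|><|transcribe|>", allowed_special="all")
#     print(enc.decode(ids))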


@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
) -> Tokenizer:
    """Return a Whisper Tokenizer wrapping either the multilingual encoding
    or the GPT-2 (English-only) encoding."""
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        encoding_name = "multilingual_zh_ja_yue_char_del"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name = "gpt2"
        language = None
        task = None

    encoding = get_encoding(name=encoding_name, num_languages=num_languages)

    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )
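
# Usage sketch: build a Whisper-style tokenizer for Mandarin transcription.
# The "zh"/"transcribe" arguments are illustrative; encode/decode come from
# openai-whisper's Tokenizer wrapper.
#
#     tokenizer = get_tokenizer(multilingual=True, language="zh", task="transcribe")
#     ids = tokenizer.encode("你好")
#     print(tokenizer.decode(ids))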


class QwenTokenizer:
    """Thin wrapper around a Hugging Face Qwen tokenizer that registers the
    extra control and vocal-event tokens used by this project."""

    def __init__(self, token_path, skip_special_tokens=True):
        special_tokens = {
            'eos_token': '<|endoftext|>',
            'pad_token': '<|endoftext|>',
            'additional_special_tokens': [
                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
                '[breath]', '<strong>', '</strong>', '[noise]',
                '[laughter]', '[cough]', '[clucking]', '[accent]',
                '[quick_breath]',
            ]
        }
        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
        self.tokenizer.add_special_tokens(special_tokens)
        self.skip_special_tokens = skip_special_tokens

    def encode(self, text, **kwargs):
        tokens = self.tokenizer([text], return_tensors="pt")
        tokens = tokens["input_ids"][0].cpu().tolist()
        return tokens

    def decode(self, tokens):
        tokens = torch.tensor(tokens, dtype=torch.int64)
        text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
        return text


@lru_cache(maxsize=None)
def get_qwen_tokenizer(
    token_path: str,
    skip_special_tokens: bool,
) -> QwenTokenizer:
    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
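
# Usage sketch: the lru_cache means repeated calls with the same arguments
# reuse one tokenizer instance. "Qwen/Qwen2.5-0.5B" is an illustrative
# Hugging Face model id, not necessarily the checkpoint this project uses.
#
#     qwen_tokenizer = get_qwen_tokenizer(
#         token_path="Qwen/Qwen2.5-0.5B", skip_special_tokens=True
#     )
#     ids = qwen_tokenizer.encode("hello [laughter] world")
#     print(qwen_tokenizer.decode(ids))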