tokenizer.py
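# Tokenizer utilities: a Whisper-style tiktoken encoding extended with language,
# audio-event, emotion, TTS and timestamp special tokens, plus Qwen-based
# tokenizer wrappers for CosyVoice 2/3.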

import base64
import os
from functools import lru_cache
from typing import Optional

import torch
from transformers import AutoTokenizer
from whisper.tokenizer import Tokenizer
import tiktoken

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
    "minnan": "minnan",
    "wuyu": "wuyu",
    "dialect": "dialect",
    "zh/en": "zh/en",
    "en/zh": "en/zh",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}

AUDIO_EVENT = {
    "ASR": "ASR",
    "AED": "AED",
    "SER": "SER",
    "Speech": "Speech",
    "/Speech": "/Speech",
    "BGM": "BGM",
    "/BGM": "/BGM",
    "Laughter": "Laughter",
    "/Laughter": "/Laughter",
    "Applause": "Applause",
    "/Applause": "/Applause",
}

EMOTION = {
    "HAPPY": "HAPPY",
    "SAD": "SAD",
    "ANGRY": "ANGRY",
    "NEUTRAL": "NEUTRAL",
}

TTS_Vocal_Token = {
    "TTS/B": "TTS/B",
    "TTS/O": "TTS/O",
    "TTS/Q": "TTS/Q",
    "TTS/A": "TTS/A",
    "TTS/CO": "TTS/CO",
    "TTS/CL": "TTS/CL",
    "TTS/H": "TTS/H",
    **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)},
}


@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
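    """Build a tiktoken ``Encoding`` from ``assets/{name}.tiktoken``.

    Base BPE ranks come from the vocab file; language, audio-event, emotion,
    task, ASR/TTS and timestamp tokens (<|0.00|> .. <|30.00|> in 0.02 s steps)
    are appended on top as special tokens.
    """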
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    ranks = {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in open(vocab_path) if line)
    }
    n_vocab = len(ranks)
    special_tokens = {}

    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
        *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],  # register special tokens for ASR
        *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],  # register special tokens for TTS
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]

    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )


@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
) -> Tokenizer:
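    """Return a whisper ``Tokenizer`` built on top of ``get_encoding``.

    ``language`` may be a code or a full name/alias (resolved via
    TO_LANGUAGE_CODE). Multilingual tokenizers use the
    "multilingual_zh_ja_yue_char_del" vocab and default to English
    transcription; otherwise the "gpt2" vocab is used and language/task
    are reset to None.
    """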
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        encoding_name = "multilingual_zh_ja_yue_char_del"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name = "gpt2"
        language = None
        task = None

    encoding = get_encoding(name=encoding_name, num_languages=num_languages)

    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )


class CosyVoice2Tokenizer():
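    """Wrapper around a HuggingFace ``AutoTokenizer`` (a Qwen tokenizer in
    practice, per ``get_qwen_tokenizer``) that registers CosyVoice 2's
    paralinguistic markers such as [breath], [laughter] and <strong> as
    additional special tokens."""
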
    def __init__(self, token_path, skip_special_tokens=True):
        super().__init__()
        # NOTE: non-chat model, so all of these special tokens stay randomly initialized.
        special_tokens = {
            'eos_token': '<|endoftext|>',
            'pad_token': '<|endoftext|>',
            'additional_special_tokens': [
                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
                '[breath]', '<strong>', '</strong>', '[noise]',
                '[laughter]', '[cough]', '[clucking]', '[accent]',
                '[quick_breath]',
                "<laughter>", "</laughter>",
                "[hissing]", "[sigh]", "[vocalized-noise]",
                "[lipsmack]", "[mn]"
            ]
        }
        self.special_tokens = special_tokens
        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
        self.tokenizer.add_special_tokens(special_tokens)
        self.skip_special_tokens = skip_special_tokens

    def encode(self, text, **kwargs):
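        """Tokenize ``text`` and return its token ids as a plain Python list."""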
        tokens = self.tokenizer([text], return_tensors="pt")
        tokens = tokens["input_ids"][0].cpu().tolist()
        return tokens

    def decode(self, tokens):
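        """Decode a list of token ids back to text, dropping special tokens
        when ``skip_special_tokens`` is set."""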
        tokens = torch.tensor(tokens, dtype=torch.int64)
        text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
        return text


class CosyVoice3Tokenizer(CosyVoice2Tokenizer):
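    """CosyVoice 3 variant of the wrapper: additionally registers ARPAbet
    phones (e.g. [AA1], [ZH]) and toned pinyin initials/finals (e.g. [zh],
    [iāng]) as special tokens, presumably for phoneme-level input."""
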
    def __init__(self, token_path, skip_special_tokens=True):
        # NOTE: non-chat model, so all of these special tokens stay randomly initialized.
        special_tokens = {
            'eos_token': '<|endoftext|>',
            'pad_token': '<|endoftext|>',
            'additional_special_tokens': [
                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
                '[breath]', '<strong>', '</strong>', '[noise]',
                '[laughter]', '[cough]', '[clucking]', '[accent]',
                '[quick_breath]',
                "<laughter>", "</laughter>",
                "[hissing]", "[sigh]", "[vocalized-noise]",
                "[lipsmack]", "[mn]", "<|endofsystem|>",
                "[AA]", "[AA0]", "[AA1]", "[AA2]", "[AE]", "[AE0]", "[AE1]", "[AE2]", "[AH]", "[AH0]", "[AH1]", "[AH2]",
                "[AO]", "[AO0]", "[AO1]", "[AO2]", "[AW]", "[AW0]", "[AW1]", "[AW2]", "[AY]", "[AY0]", "[AY1]", "[AY2]",
                "[B]", "[CH]", "[D]", "[DH]", "[EH]", "[EH0]", "[EH1]", "[EH2]", "[ER]", "[ER0]", "[ER1]", "[ER2]", "[EY]",
                "[EY0]", "[EY1]", "[EY2]", "[F]", "[G]", "[HH]", "[IH]", "[IH0]", "[IH1]", "[IH2]", "[IY]", "[IY0]", "[IY1]",
                "[IY2]", "[JH]", "[K]", "[L]", "[M]", "[N]", "[NG]", "[OW]", "[OW0]", "[OW1]", "[OW2]", "[OY]", "[OY0]",
                "[OY1]", "[OY2]", "[P]", "[R]", "[S]", "[SH]", "[T]", "[TH]", "[UH]", "[UH0]", "[UH1]", "[UH2]", "[UW]",
                "[UW0]", "[UW1]", "[UW2]", "[V]", "[W]", "[Y]", "[Z]", "[ZH]",
                "[a]", "[ai]", "[an]", "[ang]", "[ao]", "[b]", "[c]", "[ch]", "[d]", "[e]", "[ei]", "[en]", "[eng]", "[f]",
                "[g]", "[h]", "[i]", "[ian]", "[in]", "[ing]", "[iu]", "[ià]", "[iàn]", "[iàng]", "[iào]", "[iá]", "[ián]",
                "[iáng]", "[iáo]", "[iè]", "[ié]", "[iòng]", "[ióng]", "[iù]", "[iú]", "[iā]", "[iān]", "[iāng]", "[iāo]",
                "[iē]", "[iě]", "[iōng]", "[iū]", "[iǎ]", "[iǎn]", "[iǎng]", "[iǎo]", "[iǒng]", "[iǔ]", "[j]", "[k]", "[l]",
                "[m]", "[n]", "[o]", "[ong]", "[ou]", "[p]", "[q]", "[r]", "[s]", "[sh]", "[t]", "[u]", "[uang]", "[ue]",
                "[un]", "[uo]", "[uà]", "[uài]", "[uàn]", "[uàng]", "[uá]", "[uái]", "[uán]", "[uáng]", "[uè]", "[ué]", "[uì]",
                "[uí]", "[uò]", "[uó]", "[uā]", "[uāi]", "[uān]", "[uāng]", "[uē]", "[uě]", "[uī]", "[uō]", "[uǎ]", "[uǎi]",
                "[uǎn]", "[uǎng]", "[uǐ]", "[uǒ]", "[vè]", "[w]", "[x]", "[y]", "[z]", "[zh]", "[à]", "[ài]", "[àn]", "[àng]",
                "[ào]", "[á]", "[ái]", "[án]", "[áng]", "[áo]", "[è]", "[èi]", "[èn]", "[èng]", "[èr]", "[é]", "[éi]", "[én]",
                "[éng]", "[ér]", "[ì]", "[ìn]", "[ìng]", "[í]", "[ín]", "[íng]", "[ò]", "[òng]", "[òu]", "[ó]", "[óng]", "[óu]",
                "[ù]", "[ùn]", "[ú]", "[ún]", "[ā]", "[āi]", "[ān]", "[āng]", "[āo]", "[ē]", "[ēi]", "[ēn]", "[ēng]", "[ě]",
                "[ěi]", "[ěn]", "[ěng]", "[ěr]", "[ī]", "[īn]", "[īng]", "[ō]", "[ōng]", "[ōu]", "[ū]", "[ūn]", "[ǎ]", "[ǎi]",
                "[ǎn]", "[ǎng]", "[ǎo]", "[ǐ]", "[ǐn]", "[ǐng]", "[ǒ]", "[ǒng]", "[ǒu]", "[ǔ]", "[ǔn]", "[ǘ]", "[ǚ]", "[ǜ]"
            ]
        }
        self.special_tokens = special_tokens
        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
        self.tokenizer.add_special_tokens(special_tokens)
        self.skip_special_tokens = skip_special_tokens


@lru_cache(maxsize=None)
def get_qwen_tokenizer(
    token_path: str,
    skip_special_tokens: bool,
    version: str = 'cosyvoice2'
):
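    """Cached factory returning the tokenizer wrapper for the given CosyVoice version."""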
    if version == 'cosyvoice2':
        return CosyVoice2Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
    elif version == 'cosyvoice3':
        return CosyVoice3Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
    else:
        raise ValueError(f"Unsupported tokenizer version: {version!r}, expected 'cosyvoice2' or 'cosyvoice3'")
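

# Minimal usage sketch (illustrative only): the Qwen tokenizer path below is a
# placeholder, and get_tokenizer() requires the matching .tiktoken vocab file
# to be present under assets/.
if __name__ == "__main__":
    whisper_tok = get_tokenizer(multilingual=True, language="zh", task="transcribe")
    print(whisper_tok.encode("你好"))  # token ids from the tiktoken encoding

    qwen_tok = get_qwen_tokenizer("/path/to/qwen/tokenizer", skip_special_tokens=True,
                                  version="cosyvoice2")
    ids = qwen_tok.encode("hello [laughter]")
    print(qwen_tok.decode(ids))  # round-trips back to text, special tokens skipped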