import base64
import os
from functools import lru_cache
from typing import Optional

import torch
from transformers import AutoTokenizer
from whisper.tokenizer import Tokenizer
import tiktoken

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
    "minnan": "minnan",
    "wuyu": "wuyu",
    "dialect": "dialect",
    "zh/en": "zh/en",
    "en/zh": "en/zh",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}

# audio event tags (the "X" / "/X" entries form open/close pairs)
AUDIO_EVENT = {
    "ASR": "ASR",
    "AED": "AED",
    "SER": "SER",
    "Speech": "Speech",
    "/Speech": "/Speech",
    "BGM": "BGM",
    "/BGM": "/BGM",
    "Laughter": "Laughter",
    "/Laughter": "/Laughter",
    "Applause": "Applause",
    "/Applause": "/Applause",
}

# emotion tags
EMOTION = {
    "HAPPY": "HAPPY",
    "SAD": "SAD",
    "ANGRY": "ANGRY",
    "NEUTRAL": "NEUTRAL",
}

# TTS vocal/control tokens, including speaker tokens TTS/SP01 through TTS/SP13
TTS_Vocal_Token = {
    "TTS/B": "TTS/B",
    "TTS/O": "TTS/O",
    "TTS/Q": "TTS/Q",
    "TTS/A": "TTS/A",
    "TTS/CO": "TTS/CO",
    "TTS/CL": "TTS/CL",
    "TTS/H": "TTS/H",
    **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)},
}


@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    """Load the BPE ranks from assets/<name>.tiktoken and build a tiktoken Encoding
    whose special tokens (languages, audio events, emotions, ASR/TTS control tokens
    and timestamps) are appended after the base vocabulary."""
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    ranks = {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in open(vocab_path) if line)
    }
    n_vocab = len(ranks)
    special_tokens = {}

    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
        *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],  # register special tokens for ASR
        *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],  # register special tokens for TTS
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],  # timestamp tokens, 0.00 to 30.00 s in 20 ms steps
    ]

    # assign consecutive ids to the special tokens, starting right after the BPE vocabulary
    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )


@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
) -> Tokenizer:
    """Return a cached whisper Tokenizer wrapping either the multilingual or the gpt2 encoding.

    Language names and aliases are normalized to language codes; an unknown language
    raises a ValueError. For the monolingual (gpt2) encoding, language and task are ignored.
    """
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        encoding_name = "multilingual_zh_ja_yue_char_del"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name = "gpt2"
        language = None
        task = None

    encoding = get_encoding(name=encoding_name, num_languages=num_languages)

    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )


class QwenTokenizer:
    """Wrapper around a Hugging Face (Qwen) tokenizer that registers extra special
    tokens for paralinguistic events such as [breath], [laughter] and [noise]."""

    def __init__(self, token_path, skip_special_tokens=True):
        special_tokens = {
            'eos_token': '<|endoftext|>',
            'pad_token': '<|endoftext|>',
            'additional_special_tokens': [
                '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
                '[breath]', '<strong>', '</strong>', '[noise]',
                '[laughter]', '[cough]', '[clucking]', '[accent]',
                '[quick_breath]',
            ]
        }
        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
        self.tokenizer.add_special_tokens(special_tokens)
        self.skip_special_tokens = skip_special_tokens

    def encode(self, text, **kwargs):
        tokens = self.tokenizer([text], return_tensors="pt")
        tokens = tokens["input_ids"][0].cpu().tolist()
        return tokens

    def decode(self, tokens):
        tokens = torch.tensor(tokens, dtype=torch.int64)
        text = self.tokenizer.batch_decode(
            [tokens], skip_special_tokens=self.skip_special_tokens
        )[0]
        return text


@lru_cache(maxsize=None)
def get_qwen_tokenizer(
    token_path: str,
    skip_special_tokens: bool,
) -> QwenTokenizer:
    """Return a cached QwenTokenizer built from the given tokenizer path."""
    return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
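

# A minimal usage sketch, not part of the module's public API. It assumes the
# assets/*.tiktoken vocab files referenced above ship alongside this module; the
# Qwen tokenizer path below is a hypothetical placeholder, so point it at your own
# local directory or Hub repo before running.
if __name__ == "__main__":
    # Whisper-style tokenizer: encode/decode plain text with the multilingual vocab.
    whisper_tok = get_tokenizer(multilingual=True, language="zh", task="transcribe")
    ids = whisper_tok.encoding.encode("hello world")
    print(ids, whisper_tok.encoding.decode(ids))

    # Qwen tokenizer: round-trip text containing one of the registered event tokens.
    qwen_tok = get_qwen_tokenizer(token_path="Qwen/Qwen2-0.5B", skip_special_tokens=False)
    ids = qwen_tok.encode("hello [laughter] world")
    print(ids, qwen_tok.decode(ids))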