ソースを参照

fix(async_cosyvoice): 恢复原本文本令牌处理逻辑

- 在 Frontend 中,恢复原本逐个生成文本令牌
- 在 Model 类中,移除了不必要的日志信息和断言,简化了文本令牌的处理流程
qihua 1 年間 前
コミット
c0f6a474f3
2 ファイル変更2 行追加14 行削除
  1. 2 3
      cosyvoice/cli/frontend.py
  2. 0 11
      cosyvoice/llm/llm_vllm.py

+ 2 - 3
cosyvoice/cli/frontend.py

@@ -102,9 +102,8 @@ class CosyVoiceFrontEnd:
     def _extract_text_token_generator(self, text_generator):
         for text in text_generator:
             text_token, _ = self._extract_text_token(text)
-            # for i in range(text_token.shape[1]):
-            #     yield text_token[:, i: i + 1]
-            yield text_token
+            for i in range(text_token.shape[1]):
+                yield text_token[:, i: i + 1]
 
     def _extract_speech_token(self, speech):
         assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'

+ 0 - 11
cosyvoice/llm/llm_vllm.py

@@ -149,8 +149,6 @@ class VllmQwen2LM(Qwen2LM):
                 need_add_tokens = output.token_ids[:-1]
             else:
                 need_add_tokens = output.token_ids
-            # 单个token 循环处理比较耗时,建议是在model中进行批量(extend)处理,减少循环
-            # yield need_add_tokens
             for token in need_add_tokens:
                 yield token
 
@@ -186,18 +184,14 @@ class VllmQwen2LM(Qwen2LM):
                     text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
                     prompt_speech_token = prompt_speech_token[self.mix_ratio[1]:]
                 else:
-                    logging.info('not enough text token to decode, wait for more')
                     break
             if len(prompt_speech_token) == 0:
                 if (len(last_tokens) > 0 and last_tokens[-1] == 6563) or len(prompt_token_ids) == 1:
-                    logging.info('get fill token, need to append more text token')
                     if len(text_tokens_cache) >= self.mix_ratio[0]:
                         text_tokens_temp = text_tokens_cache[:self.mix_ratio[0]]
                         prompt_token_ids += text_tokens_temp
-                        logging.info('append {} text token'.format(len(text_tokens_temp)))
                         text_tokens_cache = text_tokens_cache[self.mix_ratio[0]:]
                     else:
-                        logging.info('not enough text token to decode, wait for more')
                         continue
                 for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6563]):
                     last_tokens = output.token_ids
@@ -205,19 +199,14 @@ class VllmQwen2LM(Qwen2LM):
                         need_add_tokens = last_tokens[:-1]
                     else:
                         need_add_tokens = last_tokens
-                    # 单个token 循环处理比较耗时,建议是在model中进行批量(extend)处理,减少循环
-                    # yield need_add_tokens
                     for token in need_add_tokens:
                         yield token
                     prompt_token_ids.extend(need_add_tokens)
         prompt_token_ids += text_tokens_cache + [self.task_token_id]
-        logging.info('no more text token, decode until met eos')
         for output in self.llm_inference(prompt_token_ids, stop_token_ids=[6561]):
             if output.token_ids[-1] == 6561:
                 need_add_tokens = output.token_ids[:-1]
             else:
                 need_add_tokens = output.token_ids
-            # 单个token 循环处理比较耗时,建议是在model中进行批量(extend)处理,减少循环
-            # yield need_add_tokens
             for token in need_add_tokens:
                 yield token