lyuxiang.lx 11 月之前
父节点
当前提交
d3b1a8e352
共有 1 个文件被更改,包括 4 次插入0 次删除
  1. 4 0
      cosyvoice/llm/llm.py

+ 4 - 0
cosyvoice/llm/llm.py

@@ -280,10 +280,14 @@ class Qwen2LM(torch.nn.Module):
             sampling: int,
             ignore_eos: bool = True,
     ):
+        num_trials, max_trials = 0, 100
         while True:
             top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             if (not ignore_eos) or (self.speech_token_size not in top_ids):
                 break
+            num_trials += 1
+            if num_trials > max_trials:
+                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
         return top_ids
 
     @torch.inference_mode()