|
|
@@ -280,10 +280,14 @@ class Qwen2LM(torch.nn.Module):
|
|
|
sampling: int,
|
|
|
ignore_eos: bool = True,
|
|
|
):
|
|
|
+ num_trials, max_trials = 0, 100
|
|
|
while True:
|
|
|
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
|
|
if (not ignore_eos) or (self.speech_token_size not in top_ids):
|
|
|
break
|
|
|
+ num_trials += 1
|
|
|
+ if num_trials > max_trials:
|
|
|
+ raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
|
|
|
return top_ids
|
|
|
|
|
|
@torch.inference_mode()
|