Browse Source

update stream code

lyuxiang.lx 1 year ago
parent
commit
f4e70e222c

+ 2 - 0
.gitignore

@@ -43,6 +43,8 @@ compile_commands.json
 
 # train/inference files
 *.wav
+*.m4a
+*.aac
 *.pt
 pretrained_models/*
 *_pb2_grpc.py

+ 11 - 10
README.md

@@ -86,23 +86,24 @@ import torchaudio
 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
 # sft usage
 print(cosyvoice.list_avaliable_spks())
-output = cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女')
-torchaudio.save('sft.wav', output['tts_speech'], 22050)
+# change stream=True for chunk stream inference
+for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
+    torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], 22050)
 
 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
 # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
 prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
-output = cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k)
-torchaudio.save('zero_shot.wav', output['tts_speech'], 22050)
+for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
+    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], 22050)
 # cross_lingual usage
 prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
-output = cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k)
-torchaudio.save('cross_lingual.wav', output['tts_speech'], 22050)
+for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
+    torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], 22050)
 
 cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
 # instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
-output = cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
-torchaudio.save('instruct.wav', output['tts_speech'], 22050)
+for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
+    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], 22050)
 ```
 
 **Start web demo**
@@ -133,10 +134,10 @@ docker build -t cosyvoice:v1.0 .
 # change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
 # for grpc usage
 docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
-python3 grpc/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
+cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
 # for fastapi usage
 docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && MODEL_DIR=iic/CosyVoice-300M fastapi dev --port 50000 server.py && sleep infinity"
-python3 fastapi/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
+cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
 ```
 
 ## Discussion & Communication

+ 5 - 2
cosyvoice/bin/inference.py

@@ -100,10 +100,13 @@ def main():
                                'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                                'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                                'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
-            model_output = model.inference(**model_input)
+            tts_speeches = []
+            for model_output in model.inference(**model_input):
+                tts_speeches.append(model_output['tts_speech'])
+            tts_speeches = torch.concat(tts_speeches, dim=1)
             tts_key = '{}_{}'.format(utts[0], tts_index[0])
             tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
-            torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
+            torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
             f.write('{} {}\n'.format(tts_key, tts_fn))
             f.flush()
     f.close()

+ 4 - 0
cosyvoice/cli/cosyvoice.py

@@ -49,6 +49,7 @@ class CosyVoice:
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_sft(i, spk_id)
             start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
             for model_output in self.model.inference(**model_input, stream=stream):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
@@ -60,6 +61,7 @@ class CosyVoice:
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
             start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
             for model_output in self.model.inference(**model_input, stream=stream):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
@@ -72,6 +74,7 @@ class CosyVoice:
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
             start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
             for model_output in self.model.inference(**model_input, stream=stream):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
@@ -85,6 +88,7 @@ class CosyVoice:
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
             start_time = time.time()
+            logging.info('synthesis text {}'.format(i))
             for model_output in self.model.inference(**model_input, stream=stream):
                 speech_len = model_output['tts_speech'].shape[1] / 22050
                 logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))

+ 67 - 66
cosyvoice/cli/model.py

@@ -16,6 +16,8 @@ import numpy as np
 import threading
 import time
 from contextlib import nullcontext
+import uuid
+from cosyvoice.utils.common import fade_in_out
 
 
 class CosyVoiceModel:
@@ -28,13 +30,19 @@ class CosyVoiceModel:
         self.llm = llm
         self.flow = flow
         self.hift = hift
-        self.stream_win_len = 60 * 4
-        self.stream_hop_len = 50 * 4
-        self.overlap = 4395 * 4 # 10 token equals 4395 sample point
-        self.window = np.hamming(2 * self.overlap)
+        self.token_min_hop_len = 100
+        self.token_max_hop_len = 400
+        self.token_overlap_len = 20
+        self.speech_overlap_len = 34 * 256
+        self.window = np.hamming(2 * self.speech_overlap_len)
+        self.stream_scale_factor = 1
+        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
         self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
         self.lock = threading.Lock()
+        # dict used to store session related variable
+        self.tts_speech_token = {}
+        self.llm_end = {}
 
     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
@@ -44,7 +52,7 @@ class CosyVoiceModel:
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
         self.hift.to(self.device).eval()
 
-    def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding):
+    def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding, this_uuid):
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
                                                 text_len=text_len.to(self.device),
@@ -53,13 +61,11 @@ class CosyVoiceModel:
                                                 prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                 prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
                                                 embedding=llm_embedding.to(self.device),
-                                                beam_size=1,
                                                 sampling=25,
                                                 max_token_text_ratio=30,
-                                                min_token_text_ratio=3,
-                                                stream=True):
-                self.tts_speech_token.append(i)
-        self.llm_end = True
+                                                min_token_text_ratio=3):
+                self.tts_speech_token[this_uuid].append(i)
+        self.llm_end[this_uuid] = True
 
     def token2wav(self, token, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, embedding):
         with self.flow_hift_context:
@@ -78,15 +84,19 @@ class CosyVoiceModel:
                   llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                   flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                   prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token[this_uuid], self.llm_end[this_uuid] = [], False
+        p = threading.Thread(target=self.llm_job, args=(text.to(self.device), text_len.to(self.device), prompt_text.to(self.device), prompt_text_len.to(self.device),
+                                                    llm_prompt_speech_token.to(self.device), llm_prompt_speech_token_len.to(self.device), llm_embedding.to(self.device), this_uuid))
+        p.start()
         if stream is True:
-            self.tts_speech_token, self.llm_end, cache_speech = [], False, None
-            p = threading.Thread(target=self.llm_job, args=(text.to(self.device), text_len.to(self.device), prompt_text.to(self.device), prompt_text_len.to(self.device),
-                                                     llm_prompt_speech_token.to(self.device), llm_prompt_speech_token_len.to(self.device), llm_embedding.to(self.device)))
-            p.start()
+            cache_speech, cache_token, token_hop_len = None, None, self.token_min_hop_len
             while True:
                 time.sleep(0.1)
-                if len(self.tts_speech_token) >= self.stream_win_len:
-                    this_tts_speech_token = torch.concat(self.tts_speech_token[:self.stream_win_len], dim=1)
+                if len(self.tts_speech_token[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                    this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
                     with self.flow_hift_context:
                         this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token.to(self.device),
@@ -96,57 +106,48 @@ class CosyVoiceModel:
                                                     embedding=flow_embedding.to(self.device))
                     # fade in/out if necessary
                     if cache_speech is not None:
-                        this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
-                    yield  {'tts_speech': this_tts_speech[:, :-self.overlap]}
-                    cache_speech = this_tts_speech[:, -self.overlap:]
+                        this_tts_speech = fade_in_out(this_tts_speech, cache_speech, self.window)
+                    yield  {'tts_speech': this_tts_speech[:, :-self.speech_overlap_len]}
+                    cache_speech = this_tts_speech[:, -self.speech_overlap_len:]
+                    cache_token = self.tts_speech_token[this_uuid][:token_hop_len]
                     with self.lock:
-                        self.tts_speech_token = self.tts_speech_token[self.stream_hop_len:]
-                if self.llm_end is True:
+                        self.tts_speech_token[this_uuid] = self.tts_speech_token[this_uuid][token_hop_len:]
+                    # increase token_hop_len for better speech quality
+                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                if self.llm_end[this_uuid] is True and len(self.tts_speech_token[this_uuid]) < token_hop_len + self.token_overlap_len:
                     break
-            # deal with remain tokens
-            if cache_speech is None or len(self.tts_speech_token) > self.stream_win_len - self.stream_hop_len:
-                this_tts_speech_token = torch.concat(self.tts_speech_token, dim=1)
-                with self.flow_hift_context:
-                    this_tts_mel = self.flow.inference(token=this_tts_speech_token,
-                                                token_len=torch.tensor([this_tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
-                                                prompt_token=flow_prompt_speech_token.to(self.device),
-                                                prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                                prompt_feat=prompt_speech_feat.to(self.device),
-                                                prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                                embedding=flow_embedding.to(self.device))
-                    this_tts_speech = self.hift.inference(mel=this_tts_mel).cpu()
-                if cache_speech is not None:
-                    this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
-                yield {'tts_speech': this_tts_speech}
-            else:
-                assert len(self.tts_speech_token) == self.stream_win_len - self.stream_hop_len, 'tts_speech_token not equal to {}'.format(self.stream_win_len - self.stream_hop_len)
-                yield {'tts_speech': cache_speech}
             p.join()
-            torch.cuda.synchronize()
+            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid], dim=1)
+            if this_tts_speech_token.shape[1] < self.token_min_hop_len + self.token_overlap_len and cache_token is not None:
+                cache_token_len = self.token_min_hop_len + self.token_overlap_len - this_tts_speech_token.shape[1]
+                this_tts_speech_token = torch.concat([torch.concat(cache_token[-cache_token_len:], dim=1), this_tts_speech_token], dim=1)
+            else:
+                cache_token_len = 0
+            with self.flow_hift_context:
+                this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                            prompt_token=flow_prompt_speech_token.to(self.device),
+                                            prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                            prompt_feat=prompt_speech_feat.to(self.device),
+                                            prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                            embedding=flow_embedding.to(self.device))
+                this_tts_speech = this_tts_speech[:, int(cache_token_len / this_tts_speech_token.shape[1] * this_tts_speech.shape[1]):]
+            if cache_speech is not None:
+                this_tts_speech = fade_in_out(this_tts_speech, cache_speech, self.window)
+            yield {'tts_speech': this_tts_speech}
         else:
-            tts_speech_token = []
-            for i in self.llm.inference(text=text.to(self.device),
-                                                text_len=text_len.to(self.device),
-                                                prompt_text=prompt_text.to(self.device),
-                                                prompt_text_len=prompt_text_len.to(self.device),
-                                                prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                                prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
-                                                embedding=llm_embedding.to(self.device),
-                                                beam_size=1,
-                                                sampling=25,
-                                                max_token_text_ratio=30,
-                                                min_token_text_ratio=3,
-                                                stream=stream):
-                tts_speech_token.append(i)
-            assert len(tts_speech_token) == 1, 'tts_speech_token len should be 1 when stream is {}'.format(stream)
-            tts_speech_token = torch.concat(tts_speech_token, dim=1)
-            tts_mel = self.flow.inference(token=tts_speech_token,
-                                        token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
-                                        prompt_token=flow_prompt_speech_token.to(self.device),
-                                        prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                        prompt_feat=prompt_speech_feat.to(self.device),
-                                        prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                        embedding=flow_embedding.to(self.device))
-            tts_speech = self.hift.inference(mel=tts_mel).cpu()
-            torch.cuda.empty_cache()
-            yield {'tts_speech': tts_speech}
+            # deal with all tokens
+            p.join()
+            this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid], dim=1)
+            with self.flow_hift_context:
+                this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                            prompt_token=flow_prompt_speech_token.to(self.device),
+                                            prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                            prompt_feat=prompt_speech_feat.to(self.device),
+                                            prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                            embedding=flow_embedding.to(self.device))
+            yield {'tts_speech': this_tts_speech}
+        with self.lock:
+            self.tts_speech_token.pop(this_uuid)
+            self.llm_end.pop(this_uuid)
+        torch.cuda.synchronize()

+ 9 - 9
cosyvoice/flow/flow.py

@@ -105,6 +105,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         embedding = self.spk_embed_affine_layer(embedding)
 
         # concat text and prompt_text
+        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
         mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
@@ -112,17 +113,16 @@ class MaskedDiffWithXvec(torch.nn.Module):
         # text encode
         h, h_lengths = self.encoder(token, token_len)
         h = self.encoder_proj(h)
-        feat_len = (token_len / 50 * 22050 / 256).int()
-        h, h_lengths = self.length_regulator(h, feat_len)
+        mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / 50 * 22050 / 256)
+        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
 
         # get conditions
-        conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
-        if prompt_feat.shape[1] != 0:
-            for i, j in enumerate(prompt_feat_len):
-                conds[i, :j] = prompt_feat[i]
+        conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
+        conds[:, :mel_len1] = prompt_feat
         conds = conds.transpose(1, 2)
 
-        mask = (~make_pad_mask(feat_len)).to(h)
+        # mask = (~make_pad_mask(feat_len)).to(h)
+        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
         feat = self.decoder(
             mu=h.transpose(1, 2).contiguous(),
             mask=mask.unsqueeze(1),
@@ -130,6 +130,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
             cond=conds,
             n_timesteps=10
         )
-        if prompt_feat.shape[1] != 0:
-            feat = feat[:, :, prompt_feat.shape[1]:]
+        feat = feat[:, :, mel_len1:]
+        assert feat.shape[2] == mel_len2
         return feat

+ 19 - 0
cosyvoice/flow/length_regulator.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import Tuple
 import torch.nn as nn
+import torch
 from torch.nn import functional as F
 from cosyvoice.utils.mask import make_pad_mask
 
@@ -47,3 +48,21 @@ class InterpolateRegulator(nn.Module):
         out = self.model(x).transpose(1, 2).contiguous()
         olens = ylens
         return out * mask, olens
+
+    def inference(self, x1, x2, mel_len1, mel_len2):
+        # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
+        # x in (B, T, D)
+        if x2.shape[1] > 40:
+            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
+            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
+            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
+            x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
+        else:
+            x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
+        if x1.shape[1] != 0:
+            x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
+            x = torch.concat([x1, x2], dim=2)
+        else:
+            x = x2
+        out = self.model(x).transpose(1, 2).contiguous()
+        return out, mel_len1 + mel_len2

+ 11 - 16
cosyvoice/llm/llm.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Callable, List, Generator
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -31,6 +31,7 @@ class TransformerLM(torch.nn.Module):
             speech_token_size: int,
             text_encoder: torch.nn.Module,
             llm: torch.nn.Module,
+            sampling: Callable,
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
             spk_embed_dim: int = 192,
@@ -63,6 +64,9 @@ class TransformerLM(torch.nn.Module):
         self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
         self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
 
+        # 4. sampling method
+        self.sampling = sampling
+
     def encode(
             self,
             text: torch.Tensor,
@@ -132,14 +136,12 @@ class TransformerLM(torch.nn.Module):
     def sampling_ids(
             self,
             weighted_scores: torch.Tensor,
-            sampling: Union[bool, int, float] = True,
-            beam_size: int = 1,
+            decoded_tokens: List,
+            sampling: int,
             ignore_eos: bool = True,
     ):
         while True:
-            prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
-            top_ids = prob.multinomial(beam_size, replacement=True)
-            top_ids = indices[top_ids]
+            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
             if (not ignore_eos) or (self.speech_token_size not in top_ids):
                 break
         return top_ids
@@ -154,12 +156,10 @@ class TransformerLM(torch.nn.Module):
             prompt_speech_token: torch.Tensor,
             prompt_speech_token_len: torch.Tensor,
             embedding: torch.Tensor,
-            beam_size: int = 1,
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
-            stream: bool = False,
-    ) -> torch.Tensor:
+    ) -> Generator[torch.Tensor, None, None]:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
         text_len += prompt_text_len
@@ -197,16 +197,11 @@ class TransformerLM(torch.nn.Module):
             y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
                                                                   att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
+            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
             if top_ids == self.speech_token_size:
                 break
             # in stream mode, yield token one by one
-            if stream is True:
-                yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
+            yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
             out_tokens.append(top_ids)
             offset += lm_input.size(1)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
-
-        # in non-stream mode, yield all token
-        if stream is False:
-            yield torch.tensor([out_tokens], dtype=torch.int64, device=device)

+ 34 - 0
cosyvoice/utils/common.py

@@ -101,3 +101,37 @@ def init_weights(m, mean=0.0, std=0.01):
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
         m.weight.data.normal_(mean, std)
+
+# Repetition Aware Sampling in VALL-E 2
+def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
+    top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
+    rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
+    if rep_num >= win_size * tau_r:
+        top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
+    return top_ids
+
+def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
+    prob, indices = [], []
+    cum_prob = 0.0
+    sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
+    for i in range(len(sorted_idx)):
+        # sampling both top-p and numbers.
+        if cum_prob < top_p and len(prob) < top_k:
+            cum_prob += sorted_value[i]
+            prob.append(sorted_value[i])
+            indices.append(sorted_idx[i])
+        else:
+            break
+    prob = torch.tensor(prob).to(weighted_scores)
+    indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
+    top_ids = indices[prob.multinomial(1, replacement=True)]
+    return top_ids
+
+def random_sampling(weighted_scores, decoded_tokens, sampling):
+    top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
+    return top_ids
+
+def fade_in_out(fade_in_speech, fade_out_speech, window):
+    speech_overlap_len = int(window.shape[0] / 2)
+    fade_in_speech[:, :speech_overlap_len] = fade_in_speech[:, :speech_overlap_len] * window[:speech_overlap_len] + fade_out_speech[:, -speech_overlap_len:] * window[speech_overlap_len:]
+    return fade_in_speech

+ 5 - 0
examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml

@@ -54,6 +54,11 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
         pos_enc_layer_type: 'rel_pos_espnet'
         selfattention_layer_type: 'rel_selfattn'
         static_chunk_size: 1
+    sampling: !name:cosyvoice.utils.common.ras_sampling
+        top_p: 0.8
+        top_k: 25
+        win_size: 10
+        tau_r: 0.1
 
 flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
     input_size: 512

+ 5 - 0
examples/libritts/cosyvoice/conf/cosyvoice.yaml

@@ -54,6 +54,11 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
         pos_enc_layer_type: 'rel_pos_espnet'
         selfattention_layer_type: 'rel_selfattn'
         static_chunk_size: 1
+    sampling: !name:cosyvoice.utils.common.ras_sampling
+        top_p: 0.8
+        top_k: 25
+        win_size: 10
+        tau_r: 0.1
 
 flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
     input_size: 512

+ 4 - 1
runtime/python/grpc/client.py

@@ -61,8 +61,11 @@ def main():
             request.instruct_request.CopyFrom(instruct_request)
 
         response = stub.Inference(request)
+        tts_audio = b''
+        for r in response:
+            tts_audio += r.tts_audio
+        tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
         logging.info('save response to {}'.format(args.tts_wav))
-        tts_speech = torch.from_numpy(np.array(np.frombuffer(response.tts_audio, dtype=np.int16))).unsqueeze(dim=0)
         torchaudio.save(args.tts_wav, tts_speech, target_sr)
         logging.info('get response')
 

+ 1 - 1
runtime/python/grpc/cosyvoice.proto

@@ -4,7 +4,7 @@ package cosyvoice;
 option go_package = "protos/";
 
 service CosyVoice{
-  rpc Inference(Request) returns (Response) {}
+  rpc Inference(Request) returns (stream Response) {}
 }
 
 message Request{

+ 4 - 3
runtime/python/grpc/server.py

@@ -54,9 +54,10 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
             model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
 
         logging.info('send inference response')
-        response = cosyvoice_pb2.Response()
-        response.tts_audio = (model_output['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
-        return response
+        for i in model_output:
+            response = cosyvoice_pb2.Response()
+            response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
+            yield response
 
 def main():
     grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)

+ 1 - 1
webui.py

@@ -164,7 +164,7 @@ def main():
                               outputs=[audio_output])
         mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
     demo.queue(max_size=4, default_concurrency_limit=2)
-    demo.launch(server_port=args.port)
+    demo.launch(server_name='0.0.0.0', server_port=args.port)
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()