1
0
Selaa lähdekoodia

add train cfg in flow matching

lyuxiang.lx 1 vuosi sitten
vanhempi
commit
44aea805ea

+ 4 - 2
README.md

@@ -131,8 +131,10 @@ you can run following steps. Otherwise, you can just ignore this step.
 cd runtime/python
 docker build -t cosyvoice:v1.0 .
 # change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
-docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
-python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
+# for grpc usage
+docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
+python3 grpc/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
+# for fastapi usage
 ```
 
 ## Discussion & Communication

+ 7 - 0
cosyvoice/flow/flow_matching.py

@@ -126,6 +126,13 @@ class ConditionalCFM(BASECFM):
         y = (1 - (1 - self.sigma_min) * t) * z + t * x1
         u = x1 - (1 - self.sigma_min) * z
 
+        # during training, we randomly drop condition to trade off mode coverage and sample fidelity
+        if self.training_cfg_rate > 0:
+            cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
+            mu = mu * cfg_mask.view(-1, 1, 1)
+            spks = spks * cfg_mask.view(-1, 1)
+            cond = cond * cfg_mask.view(-1, 1, 1)
+
         pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
         loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
         return loss, y

+ 1 - 2
runtime/python/Dockerfile

@@ -10,5 +10,4 @@ RUN git lfs install
 RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
 # here we use python==3.10 because we cannot find an image which have both python3.8 and torch2.0.1-cu118 installed
 RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
-RUN cd CosyVoice/runtime/python && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto
-CMD ["/bin/bash", "-c", "cd /opt/CosyVoice/CosyVoice/runtime/python && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"]
+RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto

+ 0 - 0
runtime/python/fastapi_client.py → runtime/python/fastapi/fastapi_client.py


+ 0 - 0
runtime/python/fastapi_server.py → runtime/python/fastapi/fastapi_server.py


+ 3 - 3
runtime/python/client.py → runtime/python/grpc/client.py

@@ -14,8 +14,8 @@
 import os
 import sys
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-sys.path.append('{}/../..'.format(ROOT_DIR))
-sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+sys.path.append('{}/../../..'.format(ROOT_DIR))
+sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
 import logging
 import argparse
 import torchaudio
@@ -90,7 +90,7 @@ if __name__ == "__main__":
                         default='希望你以后能够做的比我还好呦。')
     parser.add_argument('--prompt_wav',
                         type=str,
-                        default='../../zero_shot_prompt.wav')
+                        default='../../../zero_shot_prompt.wav')
     parser.add_argument('--instruct_text',
                         type=str,
                         default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')

+ 0 - 0
runtime/python/cosyvoice.proto → runtime/python/grpc/cosyvoice.proto


+ 2 - 3
runtime/python/server.py → runtime/python/grpc/server.py

@@ -14,8 +14,8 @@
 import os
 import sys
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-sys.path.append('{}/../..'.format(ROOT_DIR))
-sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+sys.path.append('{}/../../..'.format(ROOT_DIR))
+sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
 from concurrent import futures
 import argparse
 import cosyvoice_pb2
@@ -77,7 +77,6 @@ if __name__ == '__main__':
                         default=4)
     parser.add_argument('--model_dir',
                         type=str,
-                        required=True,
                         default='iic/CosyVoice-300M',
                         help='local path or modelscope repo id')
     args = parser.parse_args()