Ver código fonte

update gradio

lyuxiang.lx 10 meses atrás
pai
commit
1e52c6071e
2 arquivos alterados com 18 adições e 15 exclusões
  1. 17 14
      cosyvoice/flow/flow_matching.py
  2. 1 1
      requirements.txt

+ 17 - 14
cosyvoice/flow/flow_matching.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import threading
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
@@ -30,6 +31,7 @@ class ConditionalCFM(BASECFM):
         in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
         # Just change the architecture of the estimator here
         self.estimator = estimator
+        self.lock = threading.Lock()
 
     @torch.inference_mode()
     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
@@ -123,20 +125,21 @@ class ConditionalCFM(BASECFM):
         if isinstance(self.estimator, torch.nn.Module):
             return self.estimator.forward(x, mask, mu, t, spks, cond)
         else:
-            self.estimator.set_input_shape('x', (2, 80, x.size(2)))
-            self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
-            self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
-            self.estimator.set_input_shape('t', (2,))
-            self.estimator.set_input_shape('spks', (2, 80))
-            self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
-            # run trt engine
-            self.estimator.execute_v2([x.contiguous().data_ptr(),
-                                       mask.contiguous().data_ptr(),
-                                       mu.contiguous().data_ptr(),
-                                       t.contiguous().data_ptr(),
-                                       spks.contiguous().data_ptr(),
-                                       cond.contiguous().data_ptr(),
-                                       x.data_ptr()])
+            with self.lock:
+                self.estimator.set_input_shape('x', (2, 80, x.size(2)))
+                self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
+                self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
+                self.estimator.set_input_shape('t', (2,))
+                self.estimator.set_input_shape('spks', (2, 80))
+                self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+                # run trt engine
+                self.estimator.execute_v2([x.contiguous().data_ptr(),
+                                        mask.contiguous().data_ptr(),
+                                        mu.contiguous().data_ptr(),
+                                        t.contiguous().data_ptr(),
+                                        spks.contiguous().data_ptr(),
+                                        cond.contiguous().data_ptr(),
+                                        x.data_ptr()])
             return x
 
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):

+ 1 - 1
requirements.txt

@@ -4,7 +4,7 @@ conformer==0.3.2
 deepspeed==0.14.2; sys_platform == 'linux'
 diffusers==0.27.2
 gdown==5.1.0
-gradio==4.32.2
+gradio==5.4.0
 grpcio==1.57.0
 grpcio-tools==1.57.0
 huggingface-hub==0.25.2