1
0
Преглед на файлове

Merge branch 'dev/lyuxiang.lx' into main

Xiang Lyu преди 10 месеца
родител
ревизия
92f1c659b9
променени са 2 файла, в които са добавени 22 реда и са изтрити 2 реда
  1. 21 1
      cosyvoice/flow/flow_matching.py
  2. 1 1
      requirements.txt

+ 21 - 1
cosyvoice/flow/flow_matching.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import threading
 import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
@@ -30,6 +31,7 @@ class ConditionalCFM(BASECFM):
         in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
         # Just change the architecture of the estimator here
         self.estimator = estimator
+        self.lock = threading.Lock()
 
     @torch.inference_mode()
     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
@@ -120,7 +122,25 @@ class ConditionalCFM(BASECFM):
         return sol[-1].float()
 
     def forward_estimator(self, x, mask, mu, t, spks, cond):
-        return self.estimator.forward(x, mask, mu, t, spks, cond)
+        if isinstance(self.estimator, torch.nn.Module):
+            return self.estimator.forward(x, mask, mu, t, spks, cond)
+        else:
+            with self.lock:
+                self.estimator.set_input_shape('x', (2, 80, x.size(2)))
+                self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
+                self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
+                self.estimator.set_input_shape('t', (2,))
+                self.estimator.set_input_shape('spks', (2, 80))
+                self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
+                # run trt engine
+                self.estimator.execute_v2([x.contiguous().data_ptr(),
+                                        mask.contiguous().data_ptr(),
+                                        mu.contiguous().data_ptr(),
+                                        t.contiguous().data_ptr(),
+                                        spks.contiguous().data_ptr(),
+                                        cond.contiguous().data_ptr(),
+                                        x.data_ptr()])
+            return x
 
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss

+ 1 - 1
requirements.txt

@@ -4,7 +4,7 @@ conformer==0.3.2
 deepspeed==0.14.2; sys_platform == 'linux'
 diffusers==0.27.2
 gdown==5.1.0
-gradio==4.32.2
+gradio==5.4.0
 grpcio==1.57.0
 grpcio-tools==1.57.0
 huggingface-hub==0.25.2