|
|
@@ -33,8 +33,8 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
|
|
|
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
|
|
|
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
|
|
|
from cosyvoice.llm.llm import TransformerLM, Qwen2LM
|
|
|
-from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
|
|
|
-from cosyvoice.hifigan.generator import HiFTGenerator
|
|
|
+from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
|
|
|
+from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
|
|
|
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
|
|
|
|
|
|
|
|
|
@@ -80,4 +80,6 @@ def get_model_type(configs):
|
|
|
return CosyVoiceModel
|
|
|
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
|
|
|
return CosyVoice2Model
|
|
|
+ if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
|
|
|
+ return CosyVoice2Model
|
|
|
raise TypeError('No valid model type found!')
|