
update dpo

lyuxiang.lx committed 5 months ago
commit 63856565f3

+ 1 - 0
cosyvoice/bin/inference.py → cosyvoice/bin/inference_deprecated.py

@@ -122,4 +122,5 @@ def main():
 
 
 if __name__ == '__main__':
+    logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!')
     main()

+ 23 - 3
cosyvoice/bin/train.py

@@ -27,6 +27,7 @@ from hyperpyyaml import load_hyperpyyaml
 
 from torch.distributed.elastic.multiprocessing.errors import record
 
+from cosyvoice.utils.losses import DPOLoss
 from cosyvoice.utils.executor import Executor
 from cosyvoice.utils.train_utils import (
     init_distributed,
@@ -43,6 +44,7 @@ def get_args():
                         choices=['torch_ddp', 'deepspeed'],
                         help='Engine for paralleled training')
     parser.add_argument('--model', required=True, help='model which will be trained')
+    parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
     parser.add_argument('--config', required=True, help='config file')
     parser.add_argument('--train_data', required=True, help='train data file')
     parser.add_argument('--cv_data', required=True, help='cv data file')
@@ -73,6 +75,10 @@ def get_args():
                         action='store_true',
                         default=False,
                         help='Use automatic mixed precision training')
+    parser.add_argument('--dpo',
+                        action='store_true',
+                        default=False,
+                        help='Use Direct Preference Optimization')
     parser.add_argument('--deepspeed.save_states',
                         dest='save_states',
                         default='model_only',
@@ -113,7 +119,7 @@ def main():
 
     # Get dataset & dataloader
     train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
-        init_dataset_and_dataloader(args, configs, gan)
+        init_dataset_and_dataloader(args, configs, gan, args.dpo)
 
     # Do some sanity checks and save config to arsg.model_dir
     configs = check_modify_and_save_config(args, configs)
@@ -122,6 +128,8 @@ def main():
     writer = init_summarywriter(args)
 
     # load checkpoint
+    if args.dpo is True:
+        configs[args.model].forward = configs[args.model].forward_dpo
     model = configs[args.model]
     start_step, start_epoch = 0, -1
     if args.checkpoint is not None:
@@ -150,13 +158,25 @@ def main():
     info_dict['epoch'] = start_epoch
     save_model(model, 'init', info_dict)
 
+    # DPO related
+    if args.dpo is True:
+        ref_model = deepcopy(configs[args.model])
+        state_dict = torch.load(args.ref_model, map_location='cpu')
+        ref_model.load_state_dict(state_dict, strict=False)
+        dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
+        # NOTE maybe it is not needed to wrap ref_model as ddp because its parameter is not updated
+        ref_model = wrap_cuda_model(args, ref_model)
+    else:
+        ref_model, dpo_loss = None, None
+
     # Get executor
-    executor = Executor(gan=gan)
+    executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
     executor.step = start_step
 
     # Init scaler, used for pytorch amp mixed precision training
     scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
     print('start step {} start epoch {}'.format(start_step, start_epoch))
+
     # Start training loop
     for epoch in range(start_epoch + 1, info_dict['max_epoch']):
         executor.epoch = epoch
@@ -167,7 +187,7 @@ def main():
             executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
                                         writer, info_dict, scaler, group_join)
         else:
-            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
+            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
         dist.destroy_process_group(group_join)
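Note: the DPO wiring above instantiates DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False) from cosyvoice/utils/losses.py and pairs the trainable model with a frozen --ref_model copy, but the loss implementation itself is not part of this diff. As a rough orientation only, a standard DPO/IPO loss with that constructor signature is usually a small module along these lines (illustrative sketch, not the repository code):

import torch
import torch.nn.functional as F


class DPOLossSketch(torch.nn.Module):
    # Illustrative stand-in for cosyvoice.utils.losses.DPOLoss: only the
    # (beta, label_smoothing, ipo) signature is taken from the diff above;
    # the body is the textbook DPO/IPO formulation.
    def __init__(self, beta=0.01, label_smoothing=0.0, ipo=False):
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.ipo = ipo

    def forward(self, policy_chosen_logps, policy_rejected_logps,
                reference_chosen_logps, reference_rejected_logps):
        # margin by which the policy prefers chosen over rejected, relative to the reference
        logits = (policy_chosen_logps - policy_rejected_logps) - \
                 (reference_chosen_logps - reference_rejected_logps)
        if self.ipo:
            # IPO variant: regress the margin towards 1 / (2 * beta)
            losses = (logits - 1 / (2 * self.beta)) ** 2
        else:
            # DPO with optional conservative label smoothing
            losses = -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) \
                     - F.logsigmoid(-self.beta * logits) * self.label_smoothing
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return losses.mean(), chosen_rewards, rejected_rewards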
 
 
 
 

+ 0 - 187
cosyvoice/bin/train_dpo.py

@@ -1,187 +0,0 @@
-# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-import argparse
-import datetime
-import logging
-logging.getLogger('matplotlib').setLevel(logging.WARNING)
-from copy import deepcopy
-import os
-import torch
-import torch.distributed as dist
-import deepspeed
-
-from hyperpyyaml import load_hyperpyyaml
-
-from torch.distributed.elastic.multiprocessing.errors import record
-
-from cosyvoice.utils.executor_dpo import Executor
-from cosyvoice.utils.train_utils_dpo import (
-    init_distributed,
-    init_dataset_and_dataloader,
-    init_optimizer_and_scheduler,
-    init_summarywriter, save_model,
-    wrap_cuda_model, check_modify_and_save_config)
-
-
-def get_args():
-    parser = argparse.ArgumentParser(description='training your network')
-    parser.add_argument('--train_engine',
-                        default='torch_ddp',
-                        choices=['torch_ddp', 'deepspeed'],
-                        help='Engine for paralleled training')
-    parser.add_argument('--model', required=True, help='model which will be trained')
-    parser.add_argument('--config', required=True, help='config file')
-    parser.add_argument('--train_data', required=True, help='train data file')
-    parser.add_argument('--cv_data', required=True, help='cv data file')
-    parser.add_argument('--checkpoint', help='checkpoint model')
-    parser.add_argument('--model_dir', required=True, help='save model dir')
-    parser.add_argument('--tensorboard_dir',
-                        default='tensorboard',
-                        help='tensorboard log dir')
-    parser.add_argument('--ddp.dist_backend',
-                        dest='dist_backend',
-                        default='nccl',
-                        choices=['nccl', 'gloo'],
-                        help='distributed backend')
-    parser.add_argument('--num_workers',
-                        default=0,
-                        type=int,
-                        help='num of subprocess workers for reading')
-    parser.add_argument('--prefetch',
-                        default=100,
-                        type=int,
-                        help='prefetch number')
-    parser.add_argument('--pin_memory',
-                        action='store_true',
-                        default=False,
-                        help='Use pinned memory buffers used for reading')
-    parser.add_argument('--use_amp',
-                        action='store_true',
-                        default=False,
-                        help='Use automatic mixed precision training')
-    parser.add_argument('--deepspeed.save_states',
-                        dest='save_states',
-                        default='model_only',
-                        choices=['model_only', 'model+optimizer'],
-                        help='save model/optimizer states')
-    parser.add_argument('--timeout',
-                        default=60,
-                        type=int,
-                        help='timeout (in seconds) of cosyvoice_join.')
-    parser.add_argument('--dpo',
-                        action='store_true',
-                        default=False,
-                        help='Use Direct Preference Optimization')
-    parser.add_argument('--beta',
-                        default=0.01,
-                        type=float,
-                        help='beta of dpo training')
-    parser = deepspeed.add_config_arguments(parser)
-    args = parser.parse_args()
-    return args
-
-
-@record
-def main():
-    args = get_args()
-    logging.basicConfig(level=logging.DEBUG,
-                        format='%(asctime)s %(levelname)s %(message)s')
-    # gan train has some special initialization logic
-    gan = True if args.model == 'hifigan' else False
-
-    override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
-    if gan is True:
-        override_dict.pop('hift')
-    with open(args.config, 'r') as f:
-        configs = load_hyperpyyaml(f, overrides=override_dict)
-    if gan is True:
-        configs['train_conf'] = configs['train_conf_gan']
-    configs['train_conf'].update(vars(args))
-
-    # Init env for ddp
-    init_distributed(args)
-
-    # Get dataset & dataloader
-    train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
-        init_dataset_and_dataloader(args, configs, gan)
-
-    # Do some sanity checks and save config to arsg.model_dir
-    configs = check_modify_and_save_config(args, configs)
-
-    # Tensorboard summary
-    writer = init_summarywriter(args)
-
-    # load checkpoint
-    model = configs[args.model]
-    ref_model = None
-    if args.dpo:
-        ref_model = deepcopy(model)
-    start_step, start_epoch = 0, -1
-    if args.checkpoint is not None:
-        if os.path.exists(args.checkpoint):
-            state_dict = torch.load(args.checkpoint, map_location='cpu')
-            model.load_state_dict(state_dict, strict=False)
-            if args.dpo:
-                ref_model.load_state_dict(state_dict, strict=False)
-            if 'step' in state_dict:
-                start_step = state_dict['step']
-            if 'epoch' in state_dict:
-                start_epoch = state_dict['epoch']
-        else:
-            logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
-
-    # Dispatch model from cpu to gpu
-    model = wrap_cuda_model(args, model)
-    if args.dpo:
-        ref_model = wrap_cuda_model(args, ref_model)
-
-    # Get optimizer & scheduler
-    model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
-    if args.dpo:
-        ref_model, _, _, _, _ = init_optimizer_and_scheduler(args, configs, ref_model, gan)
-    scheduler.set_step(start_step)
-    if scheduler_d is not None:
-        scheduler_d.set_step(start_step)
-
-    # Save init checkpoints
-    info_dict = deepcopy(configs['train_conf'])
-    info_dict['step'] = start_step
-    info_dict['epoch'] = start_epoch
-    save_model(model, 'init', info_dict)
-
-    # Get executor
-    executor = Executor(gan=gan, dpo=args.dpo, beta=args.beta)
-    executor.step = start_step
-
-    # Init scaler, used for pytorch amp mixed precision training
-    scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
-    print('start step {} start epoch {}'.format(start_step, start_epoch))
-    # Start training loop
-    for epoch in range(start_epoch + 1, info_dict['max_epoch']):
-        executor.epoch = epoch
-        train_dataset.set_epoch(epoch)
-        dist.barrier()
-        group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
-        if gan is True:
-            executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
-                                        writer, info_dict, scaler, group_join)
-        else:
-            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model)
-        dist.destroy_process_group(group_join)
-
-
-if __name__ == '__main__':
-    main()

+ 2 - 1
cosyvoice/cli/model.py

@@ -103,7 +103,7 @@ class CosyVoiceModel:
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
             if isinstance(text, Generator):
-                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
+                assert isinstance(self, CosyVoice2Model) and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2 and do not support vllm!'
                 for i in self.llm.inference_bistream(text=text,
                                                      prompt_text=prompt_text.to(self.device),
                                                      prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
@@ -279,6 +279,7 @@ class CosyVoice2Model(CosyVoiceModel):
                                  enable_prompt_embeds=True,
                                  gpu_memory_utilization=0.2)
         self.llm.vllm = LLMEngine.from_engine_args(engine_args)
+        self.llm.lock = threading.Lock()
         del self.llm.llm.model.model.layers
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
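The threading.Lock attached to self.llm presumably exists so that concurrent llm_job calls serialize their access to the single shared vLLM LLMEngine, whose internal state is not thread-safe. Reduced to a self-contained toy (names hypothetical, not the repository API):

import threading


class SharedEngine:
    # Toy illustration of the pattern: one shared engine object,
    # many worker threads, one lock guarding every interaction with it.
    def __init__(self):
        self.lock = threading.Lock()
        self._step = 0  # stands in for non-thread-safe engine state

    def generate(self, prompt):
        with self.lock:  # only one thread drives the engine at a time
            self._step += 1
            return '{} -> step {}'.format(prompt, self._step)


engine = SharedEngine()
threads = [threading.Thread(target=engine.generate, args=('req{}'.format(i),)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()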

+ 5 - 18
cosyvoice/dataset/dataset.py

@@ -14,14 +14,13 @@
 # limitations under the License.
 
 import random
-import json
 import math
 from functools import partial
 
 import torch
 import torch.distributed as dist
 from torch.utils.data import IterableDataset
-from cosyvoice.utils.file_utils import read_lists, read_json_lists
+from cosyvoice.utils.file_utils import read_lists
 
 
 class Processor(IterableDataset):
@@ -127,10 +126,9 @@ def Dataset(data_list_file,
             data_pipeline,
             mode='train',
             gan=False,
+            dpo=False,
             shuffle=True,
-            partition=True,
-            tts_file='',
-            prompt_utt2data=''):
+            partition=True):
     """ Construct dataset from arguments
 
         We have two shuffle stage in the Dataset. The first is global
@@ -142,23 +140,12 @@ def Dataset(data_list_file,
             tokenizer (BaseTokenizer): tokenizer to tokenize
             partition(bool): whether to do data partition in terms of rank
     """
-    assert mode in ['train', 'inference']
     lists = read_lists(data_list_file)
-    if mode == 'inference':
-        with open(tts_file) as f:
-            tts_data = json.load(f)
-        utt2lists = read_json_lists(prompt_utt2data)
-        # filter unnecessary file in inference mode
-        lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
     dataset = DataList(lists,
                        shuffle=shuffle,
                        partition=partition)
-    if mode == 'inference':
-        # map partial arg to parquet_opener func in inference mode
-        data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
-    if gan is True:
-        # map partial arg to padding func in gan mode
-        data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
+    # map partial arg to padding func
+    data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo)
     for func in data_pipeline:
         dataset = Processor(dataset, func, mode=mode)
     return dataset
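With the inference branches gone, Dataset() always rebinds only the last pipeline stage so that the padding processor receives the gan/dpo flags; a self-contained toy of that functools.partial rebinding (fake_padding is a stand-in, not the real processor):

from functools import partial


def fake_padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
    # Stand-in for cosyvoice.dataset.processor.padding: just records whether
    # the dpo flag arrived, instead of building reject_speech_token tensors.
    for batch in data:
        batch['dpo_enabled'] = dpo
        yield batch


data_pipeline = [fake_padding]  # real pipelines also contain opener/filter/tokenize/... stages
data_pipeline[-1] = partial(data_pipeline[-1], gan=False, dpo=True)

print(list(data_pipeline[-1]([{'utt': 'u1'}], use_spk_embedding=False)))
# [{'utt': 'u1', 'dpo_enabled': True}]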

+ 16 - 23
cosyvoice/dataset/processor.py

@@ -43,8 +43,6 @@ def parquet_opener(data, mode='train', tts_data={}):
             for df in pq.ParquetFile(url).iter_batches(batch_size=64):
                 df = df.to_pandas()
                 for i in range(len(df)):
-                    if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
-                        continue
                     sample.update(dict(df.loc[i]))
                     if mode == 'train':
                         # NOTE do not return sample directly, must initialize a new dict
@@ -100,6 +98,8 @@ def filter(data,
             continue
         if len(sample['speech_token']) == 0:
             continue
+        if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
+            continue
         if num_frames != 0:
             if len(sample['text_token']) / num_frames < min_output_input_ratio:
                 continue
@@ -242,8 +242,6 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
     for sample in data:
         assert 'text' in sample
         sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
-        if mode == 'inference':
-            sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
         yield sample
 
 
@@ -351,18 +349,15 @@ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
 def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
     """ Wrapper for static/dynamic batch
     """
-    if mode == 'inference':
-        return static_batch(data, 1)
+    if batch_type == 'static':
+        return static_batch(data, batch_size)
+    elif batch_type == 'dynamic':
+        return dynamic_batch(data, max_frames_in_batch)
     else:
-        if batch_type == 'static':
-            return static_batch(data, batch_size)
-        elif batch_type == 'dynamic':
-            return dynamic_batch(data, max_frames_in_batch)
-        else:
-            logging.fatal('Unsupported batch type {}'.format(batch_type))
+        logging.fatal('Unsupported batch type {}'.format(batch_type))
 
 
-def padding(data, use_spk_embedding, mode='train', gan=False):
+def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
     """ Padding the data into training data
 
         Args:
@@ -424,16 +419,14 @@ def padding(data, use_spk_embedding, mode='train', gan=False):
             # only gan train needs speech, delete it to save memory
             del batch["speech"]
             del batch["speech_len"]
-        if mode == 'inference':
-            tts_text = [sample[i]['tts_text'] for i in order]
-            tts_index = [sample[i]['tts_index'] for i in order]
-            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
-            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
-            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
-            batch.update({'tts_text': tts_text,
-                          'tts_index': tts_index,
-                          'tts_text_token': tts_text_token,
-                          'tts_text_token_len': tts_text_token_len})
+        if dpo is True:
+            reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
+            reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
+            reject_speech_token = pad_sequence(reject_speech_token,
+                                               batch_first=True,
+                                               padding_value=0)
+            batch['reject_speech_token'] = reject_speech_token
+            batch['reject_speech_token_len'] = reject_speech_token_len
         if use_spk_embedding is True:
             batch["embedding"] = batch["spk_embedding"]
         else:
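Together with the filter() change above, every DPO training utterance in the parquet data therefore needs a non-empty reject_speech_token sequence next to the usual fields; a toy sample of the expected shape (values invented, field names taken from the diff):

# Invented values; real samples come from the parquet files and also carry
# audio_data, utt/spk embeddings, speech_feat, etc.
sample = {
    'utt': 'utt_0001',
    'text': 'hello there',
    'text_token': [101, 202, 303],         # produced by tokenize()
    'speech_token': [12, 873, 45, 990],    # preferred (chosen) speech tokens
    'reject_speech_token': [12, 873, 9],   # dispreferred tokens; empty lists are dropped by filter()
}
assert len(sample['reject_speech_token']) > 0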

+ 0 - 443
cosyvoice/dataset/processor_dpo.py

@@ -1,443 +0,0 @@
-# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-import random
-
-import pyarrow.parquet as pq
-from io import BytesIO
-import torch
-import torchaudio
-from torch.nn.utils.rnn import pad_sequence
-import torch.nn.functional as F
-import pyworld as pw
-
-
-AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
-
-
-def parquet_opener(data, mode='train', tts_data={}):
-    """ Give url or local file, return file descriptor
-        Inplace operation.
-
-        Args:
-            data(Iterable[str]): url or local file list
-
-        Returns:
-            Iterable[{src, stream}]
-    """
-    for sample in data:
-        assert 'src' in sample
-        url = sample['src']
-        try:
-            for df in pq.ParquetFile(url).iter_batches(batch_size=64):
-                df = df.to_pandas()
-                for i in range(len(df)):
-                    if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
-                        continue
-                    sample.update(dict(df.loc[i]))
-                    if mode == 'train':
-                        # NOTE do not return sample directly, must initialize a new dict
-                        yield {**sample}
-                    else:
-                        for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
-                            yield {**sample, 'tts_index': index, 'tts_text': text}
-        except Exception as ex:
-            logging.warning('Failed to open {}, ex info {}'.format(url, ex))
-
-
-def filter(data,
-           max_length=10240,
-           min_length=10,
-           token_max_length=200,
-           token_min_length=1,
-           min_output_input_ratio=0.0005,
-           max_output_input_ratio=1,
-           mode='train'):
-    """ Filter sample according to feature and label length
-        Inplace operation.
-
-        Args::
-            data: Iterable[{key, wav, label, sample_rate}]
-            max_length: drop utterance which is greater than max_length(10ms)
-            min_length: drop utterance which is less than min_length(10ms)
-            token_max_length: drop utterance which is greater than
-                token_max_length, especially when use char unit for
-                english modeling
-            token_min_length: drop utterance which is
-                less than token_max_length
-            min_output_input_ratio: minimal ration of
-                token_length / feats_length(10ms)
-            max_output_input_ratio: maximum ration of
-                token_length / feats_length(10ms)
-
-        Returns:
-            Iterable[{key, wav, label, sample_rate}]
-    """
-    for sample in data:
-        sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
-        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
-        del sample['audio_data']
-        # sample['wav'] is torch.Tensor, we have 100 frames every second
-        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
-        if num_frames < min_length:
-            continue
-        if num_frames > max_length:
-            continue
-        if len(sample['text_token']) < token_min_length:
-            continue
-        if len(sample['text_token']) > token_max_length:
-            continue
-        if len(sample['speech_token']) == 0:
-            continue
-        if num_frames != 0:
-            if len(sample['text_token']) / num_frames < min_output_input_ratio:
-                continue
-            if len(sample['text_token']) / num_frames > max_output_input_ratio:
-                continue
-        yield sample
-
-
-def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
-    """ Resample data.
-        Inplace operation.
-
-        Args:
-            data: Iterable[{key, wav, label, sample_rate}]
-            resample_rate: target resample rate
-
-        Returns:
-            Iterable[{key, wav, label, sample_rate}]
-    """
-    for sample in data:
-        assert 'sample_rate' in sample
-        assert 'speech' in sample
-        sample_rate = sample['sample_rate']
-        waveform = sample['speech']
-        if sample_rate != resample_rate:
-            if sample_rate < min_sample_rate:
-                continue
-            sample['sample_rate'] = resample_rate
-            sample['speech'] = torchaudio.transforms.Resample(
-                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
-        max_val = sample['speech'].abs().max()
-        if max_val > 1:
-            sample['speech'] /= max_val
-        yield sample
-
-
-def truncate(data, truncate_length=24576, mode='train'):
-    """ Truncate data.
-
-        Args:
-            data: Iterable[{key, wav, label, sample_rate}]
-            truncate_length: truncate length
-
-        Returns:
-            Iterable[{key, wav, label, sample_rate}]
-    """
-    for sample in data:
-        waveform = sample['speech']
-        if waveform.shape[1] > truncate_length:
-            start = random.randint(0, waveform.shape[1] - truncate_length)
-            waveform = waveform[:, start: start + truncate_length]
-        else:
-            waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
-        sample['speech'] = waveform
-        yield sample
-
-
-def compute_fbank(data,
-                  feat_extractor,
-                  mode='train'):
-    """ Extract fbank
-
-        Args:
-            data: Iterable[{key, wav, label, sample_rate}]
-
-        Returns:
-            Iterable[{key, feat, label}]
-    """
-    for sample in data:
-        assert 'sample_rate' in sample
-        assert 'speech' in sample
-        assert 'utt' in sample
-        assert 'text_token' in sample
-        waveform = sample['speech']
-        mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
-        sample['speech_feat'] = mat
-        yield sample
-
-
-def compute_f0(data, sample_rate, hop_size, mode='train'):
-    """ Extract f0
-
-        Args:
-            data: Iterable[{key, wav, label, sample_rate}]
-
-        Returns:
-            Iterable[{key, feat, label}]
-    """
-    frame_period = hop_size * 1000 / sample_rate
-    for sample in data:
-        assert 'sample_rate' in sample
-        assert 'speech' in sample
-        assert 'utt' in sample
-        assert 'text_token' in sample
-        waveform = sample['speech']
-        _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
-        if sum(_f0 != 0) < 5: # this happens when the algorithm fails
-            _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
-        f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
-        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
-        sample['pitch_feat'] = f0
-        yield sample
-
-
-def parse_embedding(data, normalize, mode='train'):
-    """ Parse utt_embedding/spk_embedding
-
-        Args:
-            data: Iterable[{key, wav, label, sample_rate}]
-
-        Returns:
-            Iterable[{key, feat, label}]
-    """
-    for sample in data:
-        sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
-        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
-        if normalize:
-            sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
-            sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
-        yield sample
-
-
-def tokenize(data, get_tokenizer, allowed_special, mode='train'):
-    """ Decode text to chars or BPE
-        Inplace operation
-
-        Args:
-            data: Iterable[{key, wav, txt, sample_rate}]
-
-        Returns:
-            Iterable[{key, wav, txt, tokens, label, sample_rate}]
-    """
-    tokenizer = get_tokenizer()
-    for sample in data:
-        assert 'text' in sample
-        sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
-        if mode == 'inference':
-            sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
-        yield sample
-
-
-def shuffle(data, shuffle_size=10000, mode='train'):
-    """ Local shuffle the data
-
-        Args:
-            data: Iterable[{key, feat, label}]
-            shuffle_size: buffer size for shuffle
-
-        Returns:
-            Iterable[{key, feat, label}]
-    """
-    buf = []
-    for sample in data:
-        buf.append(sample)
-        if len(buf) >= shuffle_size:
-            random.shuffle(buf)
-            for x in buf:
-                yield x
-            buf = []
-    # The sample left over
-    random.shuffle(buf)
-    for x in buf:
-        yield x
-
-
-def sort(data, sort_size=500, mode='train'):
-    """ Sort the data by feature length.
-        Sort is used after shuffle and before batch, so we can group
-        utts with similar lengths into a batch, and `sort_size` should
-        be less than `shuffle_size`
-
-        Args:
-            data: Iterable[{key, feat, label}]
-            sort_size: buffer size for sort
-
-        Returns:
-            Iterable[{key, feat, label}]
-    """
-
-    buf = []
-    for sample in data:
-        buf.append(sample)
-        if len(buf) >= sort_size:
-            buf.sort(key=lambda x: x['speech_feat'].size(0))
-            for x in buf:
-                yield x
-            buf = []
-    # The sample left over
-    buf.sort(key=lambda x: x['speech_feat'].size(0))
-    for x in buf:
-        yield x
-
-
-def static_batch(data, batch_size=16):
-    """ Static batch the data by `batch_size`
-
-        Args:
-            data: Iterable[{key, feat, label}]
-            batch_size: batch size
-
-        Returns:
-            Iterable[List[{key, feat, label}]]
-    """
-    buf = []
-    for sample in data:
-        buf.append(sample)
-        if len(buf) >= batch_size:
-            yield buf
-            buf = []
-    if len(buf) > 0:
-        yield buf
-
-
-def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
-    """ Dynamic batch the data until the total frames in batch
-        reach `max_frames_in_batch`
-
-        Args:
-            data: Iterable[{key, feat, label}]
-            max_frames_in_batch: max_frames in one batch
-
-        Returns:
-            Iterable[List[{key, feat, label}]]
-    """
-    buf = []
-    longest_frames = 0
-    for sample in data:
-        assert 'speech_feat' in sample
-        assert isinstance(sample['speech_feat'], torch.Tensor)
-        new_sample_frames = sample['speech_feat'].size(0)
-        longest_frames = max(longest_frames, new_sample_frames)
-        frames_after_padding = longest_frames * (len(buf) + 1)
-        if frames_after_padding > max_frames_in_batch:
-            yield buf
-            buf = [sample]
-            longest_frames = new_sample_frames
-        else:
-            buf.append(sample)
-    if len(buf) > 0:
-        yield buf
-
-
-def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
-    """ Wrapper for static/dynamic batch
-    """
-    if mode == 'inference':
-        return static_batch(data, 1)
-    else:
-        if batch_type == 'static':
-            return static_batch(data, batch_size)
-        elif batch_type == 'dynamic':
-            return dynamic_batch(data, max_frames_in_batch)
-        else:
-            logging.fatal('Unsupported batch type {}'.format(batch_type))
-
-
-def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
-    """ Padding the data into training data
-
-        Args:
-            data: Iterable[List[{key, feat, label}]]
-
-        Returns:
-            Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
-    """
-    for sample in data:
-        assert isinstance(sample, list)
-        speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
-                                       dtype=torch.int32)
-        order = torch.argsort(speech_feat_len, descending=True)
-
-        utts = [sample[i]['utt'] for i in order]
-        speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
-        speech_len = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
-        speech = pad_sequence(speech, batch_first=True, padding_value=0)
-        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
-        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
-        speech_token = pad_sequence(speech_token,
-                                    batch_first=True,
-                                    padding_value=0)
-        speech_feat = [sample[i]['speech_feat'] for i in order]
-        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
-        speech_feat = pad_sequence(speech_feat,
-                                   batch_first=True,
-                                   padding_value=0)
-        text = [sample[i]['text'] for i in order]
-        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
-        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
-        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
-        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
-        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
-        batch = {
-            "utts": utts,
-            "speech": speech,
-            "speech_len": speech_len,
-            "speech_token": speech_token,
-            "speech_token_len": speech_token_len,
-            "speech_feat": speech_feat,
-            "speech_feat_len": speech_feat_len,
-            "text": text,
-            "text_token": text_token,
-            "text_token_len": text_token_len,
-            "utt_embedding": utt_embedding,
-            "spk_embedding": spk_embedding,
-        }
-        if dpo:
-            reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
-            reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
-            reject_speech_token = pad_sequence(reject_speech_token,
-                                                batch_first=True,
-                                                padding_value=0)
-            batch['reject_speech_token'] = reject_speech_token
-            batch['reject_speech_token_len'] = reject_speech_token_len
-        if gan is True:
-            # in gan train, we need pitch_feat
-            pitch_feat = [sample[i]['pitch_feat'] for i in order]
-            pitch_feat_len = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
-            pitch_feat = pad_sequence(pitch_feat,
-                                      batch_first=True,
-                                      padding_value=0)
-            batch["pitch_feat"] = pitch_feat
-            batch["pitch_feat_len"] = pitch_feat_len
-        else:
-            # only gan train needs speech, delete it to save memory
-            del batch["speech"]
-            del batch["speech_len"]
-        if mode == 'inference':
-            tts_text = [sample[i]['tts_text'] for i in order]
-            tts_index = [sample[i]['tts_index'] for i in order]
-            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
-            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
-            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
-            batch.update({'tts_text': tts_text,
-                          'tts_index': tts_index,
-                          'tts_text_token': tts_text_token,
-                          'tts_text_token_len': tts_text_token_len})
-        if use_spk_embedding is True:
-            batch["embedding"] = batch["spk_embedding"]
-        else:
-            batch["embedding"] = batch["utt_embedding"]
-        yield batch

+ 46 - 1
cosyvoice/llm/llm.py

@@ -300,7 +300,6 @@ class Qwen2LM(TransformerLM):
         # 5. vllm related
         self.stop_token_ids = [speech_token_size + i for i in range(3)]
         self.vllm_output_queue = {}
-        self.lock = threading.Lock()
 
     def prepare_lm_input_target(self, text_token, text_token_emb, text_token_len, speech_token, speech_token_emb, speech_token_len):
         lm_target, lm_input = [], []
@@ -378,6 +377,52 @@ class Qwen2LM(TransformerLM):
         acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
         return {'loss': loss, 'acc': acc}
 
+    def forward_dpo(
+            self,
+            batch: dict,
+            device: torch.device,
+        ) -> Dict[str, Optional[torch.Tensor]]:
+        text_token = batch['text_token'].to(device)
+        text_token_len = batch['text_token_len'].to(device)
+        speech_token = batch['speech_token'].to(device)
+        speech_token_len = batch['speech_token_len'].to(device)
+        reject_speech_token = batch['reject_speech_token'].to(device)
+        reject_speech_token_len = batch['reject_speech_token_len'].to(device)
+
+        # 1. encode text_token
+        text_token_emb = self.llm.model.model.embed_tokens(text_token)
+
+        # 2. encode speech_token
+        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
+        reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
+        speech_token_combined = speech_token + reject_speech_token
+        speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
+        speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
+        speech_token_combined_emb = self.speech_embedding(speech_token_combined)
+
+        # 3. prepare llm_input/target
+        lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2), speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
+        lm_target = lm_target.to(device)
+
+        # 4. run lm forward
+        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+        logits = self.llm_decoder(lm_output)
+        chosen_logits = logits[:text_token.shape[0]]
+        rejected_logits = logits[text_token.shape[0]:]
+        chosen_lm_target = lm_target[:text_token.shape[0]]
+        rejected_lm_target = lm_target[text_token.shape[0]:]
+        loss = self.criterion_ce(chosen_logits, chosen_lm_target.to(device))
+        acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
+
+        # 5. calculate dpo logits
+        chosen_lm_mask = chosen_lm_target == IGNORE_ID
+        rejected_lm_mask = rejected_lm_target == IGNORE_ID
+        chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
+        rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
+        chosen_logps = (chosen_logps * chosen_lm_mask).mean(dim=-1)
+        rejected_logps = (rejected_logps * chosen_lm_mask).mean(dim=-1)
+        return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
+
     @torch.inference_mode()
     def inference(
             self,
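forward_dpo returns the per-sequence chosen/rejected log-probabilities alongside the usual CE loss and accuracy; the executor (not part of this diff) presumably runs it on both the trainable model and the frozen ref_model and feeds the four log-prob tensors into DPOLoss. A hedged sketch of that glue, where both the dpo_loss signature and the way the CE and preference terms are combined are assumptions:

import torch


def dpo_step_sketch(model, ref_model, dpo_loss, batch, device):
    # Illustrative only; the real logic lives in cosyvoice.utils.executor.Executor.
    out = model(batch, device)           # forward was rebound to forward_dpo in train.py
    with torch.no_grad():                # the reference model is never updated
        ref_out = ref_model(batch, device)
    pref_loss, chosen_reward, rejected_reward = dpo_loss(
        out['chosen_logps'], out['rejected_logps'],
        ref_out['chosen_logps'], ref_out['rejected_logps'])
    # Assumption: total loss = CE on the chosen sequence + DPO preference term.
    loss = out['loss'] + pref_loss
    reward_acc = (chosen_reward > rejected_reward).float().mean()
    return loss, out['acc'], reward_acc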

+ 0 - 556
cosyvoice/llm/llm_dpo.py

@@ -1,556 +0,0 @@
-# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import Dict, Optional, Callable, List, Generator
-import torch
-from torch import nn
-import torch.nn.functional as F
-from transformers import Qwen2ForCausalLM
-from torch.nn.utils.rnn import pad_sequence, unpad_sequence
-from cosyvoice.utils.common import IGNORE_ID
-from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
-from cosyvoice.utils.common import th_accuracy
-from cosyvoice.utils.file_utils import logging
-from cosyvoice.utils.mask import make_pad_mask
-
-
-class TransformerLM(torch.nn.Module):
-    def __init__(
-            self,
-            text_encoder_input_size: int,
-            llm_input_size: int,
-            llm_output_size: int,
-            text_token_size: int,
-            speech_token_size: int,
-            text_encoder: torch.nn.Module,
-            llm: torch.nn.Module,
-            sampling: Callable,
-            length_normalized_loss: bool = True,
-            lsm_weight: float = 0.0,
-            spk_embed_dim: int = 192,
-    ):
-        super().__init__()
-        self.llm_input_size = llm_input_size
-        self.speech_token_size = speech_token_size
-        # 1. build text token inputs related modules
-        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
-        self.text_encoder = text_encoder
-        self.text_encoder_affine_layer = nn.Linear(
-            self.text_encoder.output_size(),
-            llm_input_size
-        )
-
-        # 2. build speech token language model related modules
-        self.sos_eos = 0
-        self.task_id = 1
-        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
-        self.llm = llm
-        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
-        self.criterion_ce = LabelSmoothingLoss(
-            size=speech_token_size + 1,
-            padding_idx=IGNORE_ID,
-            smoothing=lsm_weight,
-            normalize_length=length_normalized_loss,
-        )
-
-        # 3. [Optional] build speech token related modules
-        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
-        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
-
-        # 4. sampling method
-        self.sampling = sampling
-
-    def encode(
-            self,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
-    ):
-        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
-        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
-        encoder_out = self.text_encoder_affine_layer(encoder_out)
-        return encoder_out, encoder_out_lens
-
-    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
-        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
-        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
-        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
-                    for i in range(len(text_token))]
-        lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
-        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
-        return lm_input, lm_input_len
-
-    def forward(
-            self,
-            batch: dict,
-            device: torch.device,
-    ) -> Dict[str, Optional[torch.Tensor]]:
-        """
-        Args:
-            text: (B, L, D)
-            text_lengths: (B,)
-            audio: (B, T, N) or (B, T)
-            audio_lengths: (B,)
-        """
-        text_token = batch['text_token'].to(device)
-        text_token_len = batch['text_token_len'].to(device)
-        speech_token = batch['speech_token'].to(device)
-        speech_token_len = batch['speech_token_len'].to(device)
-        embedding = batch['embedding'].to(device)
-
-        # 1. prepare llm_target
-        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
-                                  [self.speech_token_size]) for i in range(text_token.size(0))]
-        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
-
-        # 1. encode text_token
-        text_token = self.text_embedding(text_token)
-        text_token, text_token_len = self.encode(text_token, text_token_len)
-
-        # 2. embedding projection
-        embedding = F.normalize(embedding, dim=1)
-        embedding = self.spk_embed_affine_layer(embedding)
-        embedding = embedding.unsqueeze(1)
-
-        # 3. eos and task_id
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
-        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
-
-        # 4. encode speech_token
-        speech_token = self.speech_embedding(speech_token)
-
-        # 5. unpad and pad
-        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
-                                                         task_id_emb, speech_token, speech_token_len)
-
-        # 6. run lm forward
-        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
-        logits = self.llm_decoder(lm_output)
-        loss = self.criterion_ce(logits, lm_target)
-        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
-        return {'loss': loss, 'acc': acc}
-
-    def sampling_ids(
-            self,
-            weighted_scores: torch.Tensor,
-            decoded_tokens: List,
-            sampling: int,
-            ignore_eos: bool = True,
-    ):
-        num_trials, max_trials = 0, 100
-        while True:
-            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
-            if (not ignore_eos) or (self.speech_token_size not in top_ids):
-                break
-            num_trials += 1
-            if num_trials > max_trials:
-                raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
-        return top_ids
-
-    @torch.inference_mode()
-    def inference(
-            self,
-            text: torch.Tensor,
-            text_len: torch.Tensor,
-            prompt_text: torch.Tensor,
-            prompt_text_len: torch.Tensor,
-            prompt_speech_token: torch.Tensor,
-            prompt_speech_token_len: torch.Tensor,
-            embedding: torch.Tensor,
-            sampling: int = 25,
-            max_token_text_ratio: float = 20,
-            min_token_text_ratio: float = 2,
-    ) -> Generator[torch.Tensor, None, None]:
-        if self.fp16 is True:
-            embedding = embedding.half()
-
-        device = text.device
-        text = torch.concat([prompt_text, text], dim=1)
-        text_len += prompt_text_len
-        text = self.text_embedding(text)
-
-        # 1. encode text
-        text, text_len = self.encode(text, text_len)
-
-        # 2. encode embedding
-        if embedding.shape[0] != 0:
-            embedding = F.normalize(embedding, dim=1)
-            embedding = self.spk_embed_affine_layer(embedding)
-            embedding = embedding.unsqueeze(dim=1)
-        else:
-            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
-
-        # 3. concat llm_input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
-        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
-        if prompt_speech_token_len != 0:
-            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
-        else:
-            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
-
-        # 4. cal min/max_length
-        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
-        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
-
-        # 5. step by step decode
-        out_tokens = []
-        offset = 0
-        att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
-        for i in range(max_len):
-            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
-                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
-                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
-                                                                                                 device=lm_input.device)).to(torch.bool))
-            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            # force continue decode first token
-            if i == 0:
-                logp[:, self.speech_token_size] = -float('inf')
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-            if top_ids == self.speech_token_size:
-                break
-            # in stream mode, yield token one by one
-            yield top_ids
-            out_tokens.append(top_ids)
-            offset += lm_input.size(1)
-            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
-
-
-class Qwen2Encoder(torch.nn.Module):
-    def __init__(self, pretrain_path):
-        super().__init__()
-        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
-
-    def forward_one_step(self, xs, masks, cache=None):
-        input_masks = masks[:, -1, :]
-        outs = self.model(
-            inputs_embeds=xs,
-            attention_mask=input_masks,
-            output_hidden_states=True,
-            return_dict=True,
-            use_cache=True,
-            past_key_values=cache,
-        )
-        xs = outs.hidden_states[-1]
-        new_cache = outs.past_key_values
-        return xs, new_cache
-
-
-class Qwen2LM(TransformerLM):
-    def __init__(
-            self,
-            llm_input_size: int,
-            llm_output_size: int,
-            speech_token_size: int,
-            llm: torch.nn.Module,
-            sampling: Callable,
-            length_normalized_loss: bool = True,
-            lsm_weight: float = 0.0,
-            mix_ratio: List[int] = [5, 15],
-            dpo: bool = False,
-    ):
-        torch.nn.Module.__init__(self)
-        self.llm_input_size = llm_input_size
-        self.llm_output_size = llm_output_size
-        self.speech_token_size = speech_token_size
-
-        # 2. build speech token language model related modules
-        self.sos_eos = 0
-        self.task_id = 1
-        self.fill_token = 2
-
-        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
-        self.llm = llm
-        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
-        self.criterion_ce = LabelSmoothingLoss(
-            size=speech_token_size + 3,
-            padding_idx=IGNORE_ID,
-            smoothing=lsm_weight,
-            normalize_length=length_normalized_loss,
-        )
-
-        # 3. [Optional] build speech token related modules
-        self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
-
-        # 4. sampling method
-        self.sampling = sampling
-        self.mix_ratio = mix_ratio
-
-        # 5. [Optional] set dpo
-        self.dpo = dpo
-
-
-    def forward(
-            self,
-            batch: dict,
-            device: torch.device,
-        ) -> Dict[str, Optional[torch.Tensor]]:
-        text_token = batch['text_token'].to(device)
-        text_token_len = batch['text_token_len'].to(device)
-        speech_token = batch['speech_token'].to(device)
-        speech_token_len = batch['speech_token_len'].to(device)
-        if self.dpo:
-            reject_speech_token = batch['reject_speech_token'].to(device)
-            reject_speech_token_len = batch['reject_speech_token_len'].to(device)
-        # 1. prepare llm_target
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
-        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
-        target_ids = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
-                                        [self.speech_token_size]) for i in range(text_token.size(0))]
-        if self.dpo:
-            reject_target_ids = [torch.tensor([IGNORE_ID] * (1 + text_token_len[i]) + reject_speech_token[i, :reject_speech_token_len[i]].tolist() +
-                                            [self.speech_token_size]) for i in range(text_token.size(0))]
-            target_ids.extend(reject_target_ids)
-        target_ids = pad_sequence(target_ids, batch_first=True, padding_value=IGNORE_ID).to(device)
-
-        # 2. speech token projection
-        speech_emb = self.speech_embedding(speech_token)
-        if self.dpo:
-            reject_speech_emb = self.speech_embedding(reject_speech_token)
-
-        # 3. text token projection
-        text_token_lst = unpad_sequence(text_token, text_token_len, batch_first=True)
-        text_emb = [self.llm.model.model.embed_tokens(y) for y in text_token_lst]
-
-        # 4. prepare llm_input
-        speech_emb = unpad_sequence(speech_emb, speech_token_len.cpu(), batch_first=True)
-        input_emb = [torch.concat([sos_eos_emb.squeeze(dim=0), text_emb[i], task_id_emb.squeeze(dim=0), speech_emb[i]], dim=0)
-                     for i in range(len(text_emb))]
-        if self.dpo:
-            reject_speech_emb = unpad_sequence(reject_speech_emb, reject_speech_token_len.cpu(), batch_first=True)
-            reject_input_emb = [torch.concat([sos_eos_emb.squeeze(dim=0), text_emb[i], task_id_emb.squeeze(dim=0), reject_speech_emb[i]], dim=0)
-                                for i in range(len(text_emb))]
-            input_emb.extend(reject_input_emb)
-        input_emb_lengths = torch.tensor([i.size(0) for i in input_emb], dtype=torch.int32).to(device)
-        input_emb = pad_sequence(input_emb, batch_first=True, padding_value=IGNORE_ID).to(device)
-
-        attention_mask = ~make_pad_mask(input_emb_lengths)
-
-        result = self.llm.model(
-            inputs_embeds=input_emb,
-            attention_mask=attention_mask,
-            return_dict=True
-        )
-        hidden_states = result.hidden_states
-        logits = self.llm_decoder(hidden_states[-1])
-        loss = self.criterion_ce(logits[: speech_token.shape[0]], target_ids[: speech_token.shape[0]])
-        acc = th_accuracy(
-            logits[: speech_token.shape[0]].view(-1, self.speech_token_size + 3),
-            target_ids[: speech_token.shape[0]],
-            ignore_label=IGNORE_ID,
-        )
-        if not self.dpo:
-            return {
-                "loss": loss,
-                "acc": acc,
-            }
-        else:
-            all_logps_sum, all_logps_mean = self.get_batch_logps(
-                logits, target_ids, attention_mask, text_token_len, average_log_prob=False, ignore_id=IGNORE_ID
-            )
-            chosen_logps = all_logps_sum[: speech_token.shape[0]]
-            rejected_logps = all_logps_sum[speech_token.shape[0]:]
-            return {
-                "loss": loss,
-                "acc": acc,
-                "chosen_logps": chosen_logps,
-                "rejected_logps": rejected_logps
-            }
-
-
-    def get_batch_logps(
-        self,
-        logits: torch.FloatTensor,
-        labels: torch.LongTensor,
-        attention_mask,
-        prompt_token_lens,
-        average_log_prob: bool = False,
-        ignore_id: int = -1,
-    ) -> torch.FloatTensor:
-        """Compute the log probabilities of the given labels under the given logits.
-
-        Args:
-            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
-            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length)
-            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
-
-        Returns:
-            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
-        """
-        assert average_log_prob == False
-        assert logits.shape[:-1] == labels.shape
-        labels = labels[:, 1:].clone()
-        logits = logits[:, :-1, :]
-        loss_masks = attention_mask.clone().bool()
-        # mask prompts
-        for mask, text_token_len in zip(loss_masks, prompt_token_lens):
-            mask[:text_token_len + 1] = False
-        loss_masks = loss_masks[:, 1:]
-        labels[loss_masks == False] = 0
-        # dummy token; we'll ignore the losses on these tokens later
-        ignore = labels == ignore_id
-        labels = labels.masked_fill(ignore, 0)  # avoid -1 index
-        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)   # (bs, time,)
-        logprobs_sums = (per_token_logps * loss_masks).sum(-1)
-        logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1)
-        return logprobs_sums, logprobs_means
-
-
-    @torch.inference_mode()
-    def inference(
-            self,
-            text: torch.Tensor,
-            text_len: torch.Tensor,
-            prompt_text: torch.Tensor,
-            prompt_text_len: torch.Tensor,
-            prompt_speech_token: torch.Tensor,
-            prompt_speech_token_len: torch.Tensor,
-            embedding: torch.Tensor,
-            sampling: int = 25,
-            max_token_text_ratio: float = 20,
-            min_token_text_ratio: float = 2,
-    ) -> Generator[torch.Tensor, None, None]:
-        device = text.device
-        text = torch.concat([prompt_text, text], dim=1)
-        text_len += prompt_text_len
-        text = self.llm.model.model.embed_tokens(text)
-
-        # 3. concat llm_input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
-        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
-        if prompt_speech_token_len != 0:
-            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
-        else:
-            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
-
-        # 4. cal min/max_length
-        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
-        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
-
-        # 5. step by step decode
-        out_tokens = []
-        cache = None
-        for i in range(max_len):
-            y_pred, cache = self.llm.forward_one_step(lm_input,
-                                                      masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
-                                                      cache=cache)
-            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
-            if top_ids == self.speech_token_size:
-                break
-            if top_ids > self.speech_token_size:
-                continue
-            # in stream mode, yield token one by one
-            yield top_ids
-            out_tokens.append(top_ids)
-            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
-
-    @torch.inference_mode()
-    def inference_bistream(
-            self,
-            text: Generator,
-            prompt_text: torch.Tensor,
-            prompt_text_len: torch.Tensor,
-            prompt_speech_token: torch.Tensor,
-            prompt_speech_token_len: torch.Tensor,
-            embedding: torch.Tensor,
-            sampling: int = 25,
-            max_token_text_ratio: float = 20,
-            min_token_text_ratio: float = 2,
-    ) -> Generator[torch.Tensor, None, None]:
-
-        device = prompt_text.device
-        # 1. prepare input
-        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
-        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
-        if prompt_speech_token_len != 0:
-            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
-        else:
-            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
-        lm_input = torch.concat([sos_eos_emb], dim=1)
-
-        # 2. iterate text
-        out_tokens = []
-        cache = None
-        # NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
-        text_cache = self.llm.model.model.embed_tokens(prompt_text)
-        next_fill_index = -1
-        for this_text in text:
-            text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
-            # prompt_speech_token_emb not empty, try append to lm_input
-            while prompt_speech_token_emb.size(1) != 0:
-                if text_cache.size(1) >= self.mix_ratio[0]:
-                    lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
-                    logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
-                    lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
-                    text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
-                else:
-                    logging.info('not enough text token to decode, wait for more')
-                    break
-            # no prompt_speech_token_emb remain, can decode some speech token
-            if prompt_speech_token_emb.size(1) == 0:
-                if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
-                    logging.info('get fill token, need to append more text token')
-                    if text_cache.size(1) >= self.mix_ratio[0]:
-                        lm_input_text = text_cache[:, :self.mix_ratio[0]]
-                        logging.info('append {} text token'.format(lm_input_text.size(1)))
-                        if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
-                            lm_input = lm_input_text
-                        else:
-                            lm_input = torch.concat([lm_input, lm_input_text], dim=1)
-                        text_cache = text_cache[:, self.mix_ratio[0]:]
-                    else:
-                        logging.info('not enough text token to decode, wait for more')
-                        continue
-                while True:
-                    seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
-                    y_pred, cache = self.llm.forward_one_step(lm_input,
-                                                masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
-                                                cache=cache)
-                    logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
-                        top_ids = self.speech_token_size + 2
-                        next_fill_index += (self.mix_ratio[1] + 1)
-                    else:
-                        top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
-                    if top_ids == self.speech_token_size + 2:
-                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
-                        logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
-                    out_tokens.append(top_ids)
-                    if top_ids >= self.speech_token_size:
-                        if top_ids == self.speech_token_size + 2:
-                            break
-                        else:
-                            raise ValueError('should not get token {}'.format(top_ids))
-                    yield top_ids
-                    lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
-
-        # 3. final decode
-        lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
-        logging.info('no more text token, decode until met eos')
-        while True:
-            seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
-            y_pred, cache = self.llm.forward_one_step(lm_input,
-                                                      masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
-                                                      cache=cache)
-            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
-            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
-            out_tokens.append(top_ids)
-            if top_ids >= self.speech_token_size:
-                if top_ids == self.speech_token_size:
-                    break
-                else:
-                    raise ValueError('should not get token {}'.format(top_ids))
-            # in stream mode, yield token one by one
-            yield top_ids
-            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
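For orientation, `get_batch_logps` above reduces every (prompt, speech-token) sequence to one summed log-probability so the chosen and rejected halves of the batch can be compared; `forward` builds that batch by appending the rejected sequences after the chosen ones. A minimal self-contained sketch of that bookkeeping, with dummy tensors and simplified masking rather than the repository API:

import torch

def sequence_logps(logits, labels, loss_mask):
    # logits: (batch, time, vocab), labels: (batch, time), loss_mask: (batch, time) bool
    labels = labels.masked_fill(~loss_mask, 0)  # dummy index on masked positions, as in get_batch_logps
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    return (per_token_logps * loss_mask).sum(-1)  # one summed log-probability per sequence

batch, time, vocab = 4, 8, 16
logits = torch.randn(batch, time, vocab)
labels = torch.randint(0, vocab, (batch, time))
loss_mask = torch.ones(batch, time, dtype=torch.bool)
logps = sequence_logps(logits, labels, loss_mask)
chosen_logps, rejected_logps = logps[: batch // 2], logps[batch // 2:]  # chosen first, rejected appended after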

+ 7 - 3
cosyvoice/utils/executor.py

@@ -25,14 +25,16 @@ from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, l
 
 
 class Executor:
 
-    def __init__(self, gan: bool = False):
+    def __init__(self, gan: bool = False, ref_model: torch.nn.Module = None, dpo_loss: torch.nn.Module = None):
         self.gan = gan
+        self.ref_model = ref_model
+        self.dpo_loss = dpo_loss
         self.step = 0
         self.epoch = 0
         self.rank = int(os.environ.get('RANK', 0))
         self.device = torch.device('cuda:{}'.format(self.rank))
 
-    def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join):
+    def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None):
         ''' Train one epoch
         '''
 
@@ -44,6 +46,8 @@ class Executor:
         # torch.nn.parallel.DistributedDataParallel to be able to train
         # with uneven inputs across participating processes.
         model.train()
+        if self.ref_model is not None:
+            self.ref_model.eval()
         model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
         with model_context():
             for batch_idx, batch_dict in enumerate(train_data_loader):
@@ -65,7 +69,7 @@ class Executor:
                     context = nullcontext
 
                 with context():
-                    info_dict = batch_forward(model, batch_dict, scaler, info_dict)
+                    info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model=self.ref_model, dpo_loss=self.dpo_loss)
                     info_dict = batch_backward(model, scaler, info_dict)
 
                 info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)

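As a rough wiring sketch (a guess at usage, not copied from the repository: the placeholder reference model and the beta value, taken from the removed DPO executor's default, are assumptions), the extended constructor receives the frozen reference model and a DPOLoss, which `train_one_epoc` switches to eval mode and forwards into `batch_forward`:

import torch
from cosyvoice.utils.executor import Executor
from cosyvoice.utils.losses import DPOLoss

ref_model = torch.nn.Linear(8, 8)  # placeholder for illustration; in practice the frozen pretrained LLM
executor = Executor(gan=False, ref_model=ref_model, dpo_loss=DPOLoss(beta=0.01))  # beta assumed from the old executor_dpo.py default
# executor.train_one_epoc(...) later passes ref_model/dpo_loss into batch_forward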
+ 0 - 184
cosyvoice/utils/executor_dpo.py

@@ -1,184 +0,0 @@
-# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
-#               2024 Alibaba Inc (authors: Xiang Lyu)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from contextlib import nullcontext
-import os
-
-import torch
-import torch.distributed as dist
-
-from cosyvoice.utils.train_utils_dpo import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
-from cosyvoice.utils.losses_dpo import DPOLoss
-
-
-class Executor:
-
-    def __init__(self, gan: bool = False, dpo: bool = False, beta: float = 0.01, label_smoothing: float = 0.0, ipo: bool = False):
-        self.gan = gan
-        self.step = 0
-        self.epoch = 0
-        self.rank = int(os.environ.get('RANK', 0))
-        self.device = torch.device('cuda:{}'.format(self.rank))
-        self.dpo = dpo
-        if self.dpo:
-            self.dpo_loss = DPOLoss(beta, label_smoothing, ipo)
-        else:
-            self.dpo_loss = None
-
-    def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None):
-        ''' Train one epoch
-        '''
-
-        lr = optimizer.param_groups[0]['lr']
-        logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
-        logging.info('using accumulate grad, new batch size is {} times'
-                     ' larger than before'.format(info_dict['accum_grad']))
-        # A context manager to be used in conjunction with an instance of
-        # torch.nn.parallel.DistributedDataParallel to be able to train
-        # with uneven inputs across participating processes.
-        model.train()
-        if self.dpo:
-            assert ref_model is not None
-            ref_model.eval()
-        model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
-        with model_context():
-            for batch_idx, batch_dict in enumerate(train_data_loader):
-                info_dict["tag"] = "TRAIN"
-                info_dict["step"] = self.step
-                info_dict["epoch"] = self.epoch
-                info_dict["batch_idx"] = batch_idx
-                if cosyvoice_join(group_join, info_dict):
-                    break
-
-                # Disable gradient synchronizations across DDP processes.
-                # Within this context, gradients will be accumulated on module
-                # variables, which will later be synchronized.
-                if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
-                    context = model.no_sync
-                # Used for single gpu training and DDP gradient synchronization
-                # processes.
-                else:
-                    context = nullcontext
-
-                with context():
-                    info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model, self.dpo_loss)
-                    info_dict = batch_backward(model, scaler, info_dict)
-
-                info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
-                log_per_step(writer, info_dict)
-                # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
-                if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
-                   (batch_idx + 1) % info_dict["accum_grad"] == 0:
-                    dist.barrier()
-                    self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False, ref_model=ref_model, dpo_loss=self.dpo_loss)
-                    model.train()
-                if (batch_idx + 1) % info_dict["accum_grad"] == 0:
-                    self.step += 1
-        dist.barrier()
-        self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True, ref_model=ref_model, dpo_loss=self.dpo_loss)
-
-    def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
-                           writer, info_dict, scaler, group_join):
-        ''' Train one epoch
-        '''
-
-        lr = optimizer.param_groups[0]['lr']
-        logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
-        logging.info('using accumulate grad, new batch size is {} times'
-                     ' larger than before'.format(info_dict['accum_grad']))
-        # A context manager to be used in conjunction with an instance of
-        # torch.nn.parallel.DistributedDataParallel to be able to train
-        # with uneven inputs across participating processes.
-        model.train()
-        model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
-        with model_context():
-            for batch_idx, batch_dict in enumerate(train_data_loader):
-                info_dict["tag"] = "TRAIN"
-                info_dict["step"] = self.step
-                info_dict["epoch"] = self.epoch
-                info_dict["batch_idx"] = batch_idx
-                if cosyvoice_join(group_join, info_dict):
-                    break
-
-                # Disable gradient synchronizations across DDP processes.
-                # Within this context, gradients will be accumulated on module
-                # variables, which will later be synchronized.
-                if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
-                    context = model.no_sync
-                # Used for single gpu training and DDP gradient synchronization
-                # processes.
-                else:
-                    context = nullcontext
-
-                with context():
-                    batch_dict['turn'] = 'discriminator'
-                    info_dict = batch_forward(model, batch_dict, scaler, info_dict)
-                    info_dict = batch_backward(model, scaler, info_dict)
-                info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
-                optimizer.zero_grad()
-                log_per_step(writer, info_dict)
-                with context():
-                    batch_dict['turn'] = 'generator'
-                    info_dict = batch_forward(model, batch_dict, scaler, info_dict)
-                    info_dict = batch_backward(model, scaler, info_dict)
-                info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
-                optimizer_d.zero_grad()
-                log_per_step(writer, info_dict)
-                # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
-                if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
-                   (batch_idx + 1) % info_dict["accum_grad"] == 0:
-                    dist.barrier()
-                    self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
-                    model.train()
-                if (batch_idx + 1) % info_dict["accum_grad"] == 0:
-                    self.step += 1
-        dist.barrier()
-        self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
-
-    @torch.inference_mode()
-    def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True, ref_model=None, dpo_loss=None):
-        ''' Cross validation on
-        '''
-        logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
-        model.eval()
-        if self.dpo:
-            assert ref_model is not None
-            ref_model.eval()
-        total_num_utts, total_loss_dict = 0, {}  # avoid division by 0
-        for batch_idx, batch_dict in enumerate(cv_data_loader):
-            info_dict["tag"] = "CV"
-            info_dict["step"] = self.step
-            info_dict["epoch"] = self.epoch
-            info_dict["batch_idx"] = batch_idx
-
-            num_utts = len(batch_dict["utts"])
-            total_num_utts += num_utts
-
-            if self.gan is True:
-                batch_dict['turn'] = 'generator'
-            info_dict = batch_forward(model, batch_dict, None, info_dict, ref_model, dpo_loss)
-
-            for k, v in info_dict['loss_dict'].items():
-                if k not in total_loss_dict:
-                    total_loss_dict[k] = []
-                total_loss_dict[k].append(v.item() * num_utts)
-            log_per_step(None, info_dict)
-        for k, v in total_loss_dict.items():
-            total_loss_dict[k] = sum(v) / total_num_utts
-        info_dict['loss_dict'] = total_loss_dict
-        log_per_save(writer, info_dict)
-        model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
-        save_model(model, model_name, info_dict)

+ 37 - 0
cosyvoice/utils/losses.py

@@ -1,5 +1,6 @@
 import torch
 import torch.nn.functional as F
+from typing import Tuple
 
 
 def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
@@ -18,3 +19,39 @@ def mel_loss(real_speech, generated_speech, mel_transforms):
         mel_g = transform(generated_speech)
         loss += F.l1_loss(mel_g, mel_r)
     return loss
+
+
+class DPOLoss(torch.nn.Module):
+    """
+    DPO Loss
+    """
+
+    def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
+        super().__init__()
+        self.beta = beta
+        self.label_smoothing = label_smoothing
+        self.ipo = ipo
+
+    def forward(
+        self,
+        policy_chosen_logps: torch.Tensor,
+        policy_rejected_logps: torch.Tensor,
+        reference_chosen_logps: torch.Tensor,
+        reference_rejected_logps: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        pi_logratios = policy_chosen_logps - policy_rejected_logps
+        ref_logratios = reference_chosen_logps - reference_rejected_logps
+        logits = pi_logratios - ref_logratios
+        if self.ipo:
+            losses = (logits - 1 / (2 * self.beta)) ** 2  # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
+        else:
+            # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
+            losses = (
+                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
+            )
+        loss = losses.mean()
+        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
+        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
+
+        return loss, chosen_rewards, rejected_rewards
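In equation form (a restatement of the code above, in the notation of the cited papers): with policy \pi_\theta, reference \pi_{\mathrm{ref}}, chosen sequence y_w and rejected sequence y_l,

z = \big(\log \pi_\theta(y_w \mid x) - \log \pi_\theta(y_l \mid x)\big) - \big(\log \pi_{\mathrm{ref}}(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x)\big)

\mathcal{L}_{\mathrm{DPO/cDPO}} = -(1 - \varepsilon)\,\log\sigma(\beta z) - \varepsilon\,\log\sigma(-\beta z)
\qquad
\mathcal{L}_{\mathrm{IPO}} = \Big(z - \tfrac{1}{2\beta}\Big)^{2}

where \varepsilon is label_smoothing (0 recovers standard DPO) and \beta is beta; the returned rewards are \beta\,(\log \pi_\theta(y) - \log \pi_{\mathrm{ref}}(y)) for the chosen and rejected sequences, detached for logging.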

+ 0 - 57
cosyvoice/utils/losses_dpo.py

@@ -1,57 +0,0 @@
-import torch
-import torch.nn.functional as F
-from typing import Tuple
-
-
-def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
-    loss = 0
-    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-        m_DG = torch.median((dr - dg))
-        L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
-        loss += tau - F.relu(tau - L_rel)
-    return loss
-
-
-def mel_loss(real_speech, generated_speech, mel_transforms):
-    loss = 0
-    for transform in mel_transforms:
-        mel_r = transform(real_speech)
-        mel_g = transform(generated_speech)
-        loss += F.l1_loss(mel_g, mel_r)
-    return loss
-
-
-class DPOLoss(torch.nn.Module):
-    """
-    DPO Loss
-    """
-
-    def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
-        super().__init__()
-        self.beta = beta
-        self.label_smoothing = label_smoothing
-        self.ipo = ipo
-
-    def forward(
-        self,
-        policy_chosen_logps: torch.Tensor,
-        policy_rejected_logps: torch.Tensor,
-        reference_chosen_logps: torch.Tensor,
-        reference_rejected_logps: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        pi_logratios = policy_chosen_logps - policy_rejected_logps
-        ref_logratios = reference_chosen_logps - reference_rejected_logps
-        logits = pi_logratios - ref_logratios
-        if self.ipo:
-            losses = (logits - 1 / (2 * self.beta)) ** 2  # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
-        else:
-            # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
-            losses = (
-                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
-                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
-            )
-        loss = losses.mean()
-        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
-        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
-
-        return loss, chosen_rewards, rejected_rewards

+ 22 - 4
cosyvoice/utils/train_utils.py

@@ -50,10 +50,10 @@ def init_distributed(args):
     return world_size, local_rank, rank
 
 
-def init_dataset_and_dataloader(args, configs, gan):
+def init_dataset_and_dataloader(args, configs, gan, dpo):
     data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
-    train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=True, partition=True)
-    cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=False, partition=False)
+    train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=True, partition=True)
+    cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=False, partition=False)
 
     # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
     train_data_loader = DataLoader(train_dataset,
@@ -235,7 +235,7 @@ def cosyvoice_join(group_join, info_dict):
         return False
 
 
-def batch_forward(model, batch, scaler, info_dict):
+def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None):
     device = int(os.environ.get('LOCAL_RANK', 0))
 
     dtype = info_dict["dtype"]
@@ -253,6 +253,24 @@ def batch_forward(model, batch, scaler, info_dict):
 
     with autocast:
         info_dict['loss_dict'] = model(batch, device)
+        if ref_model is not None and dpo_loss is not None:
+            chosen_logps = info_dict['loss_dict']["chosen_logps"]
+            rejected_logps = info_dict['loss_dict']["rejected_logps"]
+            sft_loss = info_dict['loss_dict']['loss']
+            with torch.no_grad():
+                ref_loss_dict = ref_model(batch, device)
+            reference_chosen_logps = ref_loss_dict["chosen_logps"]
+            reference_rejected_logps = ref_loss_dict["rejected_logps"]
+            preference_loss, chosen_reward, reject_reward = dpo_loss(
+                chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps
+            )
+            dpo_acc = (chosen_reward > reject_reward).float().mean()
+            info_dict['loss_dict']["loss"] = preference_loss + sft_loss
+            info_dict['loss_dict']["sft_loss"] = sft_loss
+            info_dict['loss_dict']["dpo_loss"] = preference_loss
+            info_dict['loss_dict']["dpo_acc"] = dpo_acc
+            info_dict['loss_dict']["chosen_reward"] = chosen_reward.mean()
+            info_dict['loss_dict']["reject_reward"] = reject_reward.mean()
     return info_dict
 
 

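A hedged, self-contained illustration of the branch added to `batch_forward` above, using dummy per-sequence log-probabilities in place of real model outputs (the import path is the module introduced in this change; the beta value is assumed from the removed DPO executor's default):

import torch
from cosyvoice.utils.losses import DPOLoss

dpo_loss = DPOLoss(beta=0.01)
policy_chosen = torch.tensor([-120.5, -98.3])
policy_rejected = torch.tensor([-131.0, -97.9])
reference_chosen = torch.tensor([-125.0, -99.1])
reference_rejected = torch.tensor([-128.4, -98.5])

preference_loss, chosen_reward, reject_reward = dpo_loss(
    policy_chosen, policy_rejected, reference_chosen, reference_rejected)
dpo_acc = (chosen_reward > reject_reward).float().mean()  # fraction of pairs ranked correctly
sft_loss = torch.tensor(2.3)  # stand-in for the cross-entropy loss on the chosen half
total_loss = preference_loss + sft_loss  # what batch_forward stores back into loss_dict['loss']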
+ 0 - 364
cosyvoice/utils/train_utils_dpo.py

@@ -1,364 +0,0 @@
-# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
-#               2023 Horizon Inc. (authors: Xingchen Song)
-#               2024 Alibaba Inc (authors: Xiang Lyu)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import os
-import torch
-import json
-import re
-import datetime
-import yaml
-
-import deepspeed
-import torch.optim as optim
-import torch.distributed as dist
-
-from torch.utils.tensorboard import SummaryWriter
-from torch.utils.data import DataLoader
-from torch.nn.utils import clip_grad_norm_
-
-from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
-
-from cosyvoice.dataset.dataset import Dataset
-from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR
-
-
-def init_distributed(args):
-    world_size = int(os.environ.get('WORLD_SIZE', 1))
-    local_rank = int(os.environ.get('LOCAL_RANK', 0))
-    rank = int(os.environ.get('RANK', 0))
-    logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
-                 ', rank {}, world_size {}'.format(rank, world_size))
-    if args.train_engine == 'torch_ddp':
-        torch.cuda.set_device(local_rank)
-        dist.init_process_group(args.dist_backend)
-    else:
-        deepspeed.init_distributed(dist_backend=args.dist_backend)
-    return world_size, local_rank, rank
-
-
-def init_dataset_and_dataloader(args, configs, gan):
-    data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
-    train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=True, partition=True)
-    cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=False, partition=False)
-
-    # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
-    train_data_loader = DataLoader(train_dataset,
-                                   batch_size=None,
-                                   pin_memory=args.pin_memory,
-                                   num_workers=args.num_workers,
-                                   prefetch_factor=args.prefetch)
-    cv_data_loader = DataLoader(cv_dataset,
-                                batch_size=None,
-                                pin_memory=args.pin_memory,
-                                num_workers=args.num_workers,
-                                prefetch_factor=args.prefetch)
-    return train_dataset, cv_dataset, train_data_loader, cv_data_loader
-
-
-def check_modify_and_save_config(args, configs):
-    if args.train_engine == "torch_ddp":
-        configs['train_conf']["dtype"] = 'fp32'
-    else:
-        with open(args.deepspeed_config, 'r') as fin:
-            ds_configs = json.load(fin)
-        if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
-            configs['train_conf']["dtype"] = "fp16"
-        elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
-            configs['train_conf']["dtype"] = "bf16"
-        else:
-            configs['train_conf']["dtype"] = "fp32"
-        assert ds_configs["train_micro_batch_size_per_gpu"] == 1
-        # if use deepspeed, override ddp config
-        configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] *
-                                                     configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
-        configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
-        configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
-        configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
-    return configs
-
-
-def wrap_cuda_model(args, model):
-    local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
-    world_size = int(os.environ.get('WORLD_SIZE', 1))
-    if args.train_engine == "torch_ddp":  # native pytorch ddp
-        assert (torch.cuda.is_available())
-        model.cuda()
-        model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
-    else:
-        if int(os.environ.get('RANK', 0)) == 0:
-            logging.info("Estimating model states memory needs (zero2)...")
-            estimate_zero2_model_states_mem_needs_all_live(
-                model,
-                num_gpus_per_node=local_world_size,
-                num_nodes=world_size // local_world_size)
-    return model
-
-
-def init_optimizer_and_scheduler(args, configs, model, gan):
-    if gan is False:
-        if configs['train_conf']['optim'] == 'adam':
-            optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
-        elif configs['train_conf']['optim'] == 'adamw':
-            optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
-        else:
-            raise ValueError("unknown optimizer: " + configs['train_conf'])
-
-        if configs['train_conf']['scheduler'] == 'warmuplr':
-            scheduler_type = WarmupLR
-            scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
-        elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
-            scheduler_type = NoamHoldAnnealing
-            scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
-        elif configs['train_conf']['scheduler'] == 'constantlr':
-            scheduler_type = ConstantLR
-            scheduler = ConstantLR(optimizer)
-        else:
-            raise ValueError("unknown scheduler: " + configs['train_conf'])
-
-        # use deepspeed optimizer for speedup
-        if args.train_engine == "deepspeed":
-            def scheduler(opt):
-                return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
-            model, optimizer, _, scheduler = deepspeed.initialize(
-                args=args,
-                model=model,
-                optimizer=None,
-                lr_scheduler=scheduler,
-                model_parameters=model.parameters())
-
-        optimizer_d, scheduler_d = None, None
-
-    else:
-        # currently we wrap generator and discriminator in one model, so we cannot use deepspeed
-        if configs['train_conf']['optim'] == 'adam':
-            optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
-        elif configs['train_conf']['optim'] == 'adamw':
-            optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
-        else:
-            raise ValueError("unknown optimizer: " + configs['train_conf'])
-
-        if configs['train_conf']['scheduler'] == 'warmuplr':
-            scheduler_type = WarmupLR
-            scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
-        elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
-            scheduler_type = NoamHoldAnnealing
-            scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
-        elif configs['train_conf']['scheduler'] == 'constantlr':
-            scheduler_type = ConstantLR
-            scheduler = ConstantLR(optimizer)
-        else:
-            raise ValueError("unknown scheduler: " + configs['train_conf'])
-
-        if configs['train_conf']['optim_d'] == 'adam':
-            optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
-        elif configs['train_conf']['optim_d'] == 'adamw':
-            optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
-        else:
-            raise ValueError("unknown optimizer: " + configs['train_conf'])
-
-        if configs['train_conf']['scheduler_d'] == 'warmuplr':
-            scheduler_type = WarmupLR
-            scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
-        elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
-            scheduler_type = NoamHoldAnnealing
-            scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
-        elif configs['train_conf']['scheduler'] == 'constantlr':
-            scheduler_type = ConstantLR
-            scheduler_d = ConstantLR(optimizer_d)
-        else:
-            raise ValueError("unknown scheduler: " + configs['train_conf'])
-    return model, optimizer, scheduler, optimizer_d, scheduler_d
-
-
-def init_summarywriter(args):
-    writer = None
-    if int(os.environ.get('RANK', 0)) == 0:
-        os.makedirs(args.model_dir, exist_ok=True)
-        writer = SummaryWriter(args.tensorboard_dir)
-    return writer
-
-
-def save_model(model, model_name, info_dict):
-    rank = int(os.environ.get('RANK', 0))
-    model_dir = info_dict["model_dir"]
-    save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
-
-    if info_dict["train_engine"] == "torch_ddp":
-        if rank == 0:
-            torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']}, save_model_path)
-    else:
-        with torch.no_grad():
-            model.save_checkpoint(save_dir=model_dir,
-                                  tag=model_name,
-                                  client_state=info_dict)
-    if rank == 0:
-        info_path = re.sub('.pt$', '.yaml', save_model_path)
-        info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
-        with open(info_path, 'w') as fout:
-            data = yaml.dump(info_dict)
-            fout.write(data)
-        logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path))
-
-
-def cosyvoice_join(group_join, info_dict):
-    world_size = int(os.environ.get('WORLD_SIZE', 1))
-    local_rank = int(os.environ.get('LOCAL_RANK', 0))
-    rank = int(os.environ.get('RANK', 0))
-
-    if info_dict["batch_idx"] != 0:
-        # we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr
-        try:
-            dist.monitored_barrier(group=group_join,
-                                   timeout=group_join.options._timeout)
-            return False
-        except RuntimeError as e:
-            logging.info("Detected uneven workload distribution: {}\n".format(e) +
-                         "Break current worker to manually join all workers, " +
-                         "world_size {}, current rank {}, current local_rank {}\n".
-                         format(world_size, rank, local_rank))
-            return True
-    else:
-        return False
-
-
-def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None):
-    device = int(os.environ.get('LOCAL_RANK', 0))
-
-    dtype = info_dict["dtype"]
-    if dtype == "fp16":
-        dtype = torch.float16
-    elif dtype == "bf16":
-        dtype = torch.bfloat16
-    else:  # fp32
-        dtype = torch.float32
-
-    if info_dict['train_engine'] == 'torch_ddp':
-        autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
-    else:
-        autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
-
-    with autocast:
-        info_dict['loss_dict'] = model(batch, device)
-        if ref_model and dpo_loss:
-            chosen_logps = info_dict['loss_dict']["chosen_logps"]
-            rejected_logps = info_dict['loss_dict']["rejected_logps"]
-            sft_loss = info_dict['loss_dict']['loss']
-            with torch.no_grad():
-                ref_model = ref_model.to(device)
-                ref_loss_dict = ref_model(batch, device)
-            reference_chosen_logps = ref_loss_dict["chosen_logps"]
-            reference_rejected_logps = ref_loss_dict["rejected_logps"]
-            preference_loss, chosen_reward, reject_reward = dpo_loss(
-                chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps
-            )
-            dpo_acc = (chosen_reward > reject_reward).float().mean()
-            info_dict['loss_dict']["loss"] = preference_loss + sft_loss
-            info_dict['loss_dict']["sft_loss"] = sft_loss
-            info_dict['loss_dict']["dpo_loss"] = preference_loss
-            info_dict['loss_dict']["dpo_acc"] = dpo_acc
-            info_dict['loss_dict']["chosen_reward"] = chosen_reward.mean()
-            info_dict['loss_dict']["reject_reward"] = reject_reward.mean()
-    return info_dict
-
-
-def batch_backward(model, scaler, info_dict):
-    if info_dict["train_engine"] == "deepspeed":
-        scaled_loss = model.backward(info_dict['loss_dict']['loss'])
-    else:
-        scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
-        if scaler is not None:
-            scaler.scale(scaled_loss).backward()
-        else:
-            scaled_loss.backward()
-
-    info_dict['loss_dict']['loss'] = scaled_loss
-    return info_dict
-
-
-def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
-    grad_norm = 0.0
-    if info_dict['train_engine'] == "deepspeed":
-        info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
-        model.step()
-        grad_norm = model.get_global_grad_norm()
-    elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
-        # Use mixed precision training
-        if scaler is not None:
-            scaler.unscale_(optimizer)
-            grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
-            # We don't check grad here since that if the gradient
-            # has inf/nan values, scaler.step will skip
-            # optimizer.step().
-            if torch.isfinite(grad_norm):
-                scaler.step(optimizer)
-            scaler.update()
-        else:
-            grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
-            if torch.isfinite(grad_norm):
-                optimizer.step()
-        optimizer.zero_grad()
-        scheduler.step()
-    info_dict["lr"] = optimizer.param_groups[0]['lr']
-    info_dict["grad_norm"] = grad_norm
-    return info_dict
-
-
-def log_per_step(writer, info_dict):
-    tag = info_dict["tag"]
-    epoch = info_dict.get('epoch', 0)
-    step = info_dict["step"]
-    batch_idx = info_dict["batch_idx"]
-    loss_dict = info_dict['loss_dict']
-    rank = int(os.environ.get('RANK', 0))
-
-    # only rank 0 write to tensorboard to avoid multi-process write
-    if writer is not None:
-        if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \
-           (info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0):
-            for k in ['epoch', 'lr', 'grad_norm']:
-                writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
-            for k, v in loss_dict.items():
-                writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
-
-    # TRAIN & CV, Shell log (stdout)
-    if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0:
-        log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1)
-        for name, value in loss_dict.items():
-            log_str += '{} {:.6f} '.format(name, value)
-        if tag == "TRAIN":
-            log_str += 'lr {:.8f} grad_norm {:.6f}'.format(
-                info_dict["lr"], info_dict['grad_norm'])
-        log_str += ' rank {}'.format(rank)
-        logging.debug(log_str)
-
-
-def log_per_save(writer, info_dict):
-    tag = info_dict["tag"]
-    epoch = info_dict["epoch"]
-    step = info_dict["step"]
-    loss_dict = info_dict["loss_dict"]
-    lr = info_dict['lr']
-    rank = int(os.environ.get('RANK', 0))
-    logging.info(
-        'Epoch {} Step {} CV info lr {} {} rank {}'.format(
-            epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()])))
-
-    if writer is not None:
-        for k in ['epoch', 'lr']:
-            writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
-        for k, v in loss_dict.items():
-            writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)

+ 2 - 0
examples/libritts/cosyvoice/local/prepare_data.py

@@ -49,5 +49,7 @@ if __name__ == "__main__":
                         type=str)
     parser.add_argument('--des_dir',
                         type=str)
+    parser.add_argument('--ref_model',
+                        type=str)
     args = parser.parse_args()
     main()

+ 49 - 0
examples/libritts/cosyvoice/local/prepare_reject_sample.py

@@ -0,0 +1,49 @@
+import argparse
+import logging
+import os
+from tqdm import tqdm
+import torch, torchaudio
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from cosyvoice.utils.file_utils import load_wav
+
+
+logger = logging.getLogger()
+
+
+def main():
+    cosyvoice = CosyVoice2(args.ref_model)
+
+    utt2wav, utt2text = {}, {}
+    with open('{}/wav.scp'.format(args.src_dir)) as f:
+        for l in f:
+            l = l.split('\n')[0].split()
+            utt2wav[l[0]] = l[1]
+    with open('{}/text'.format(args.src_dir)) as f:
+        for l in f:
+            l = l.split('\n')[0].split()
+            utt2text[l[0]] = ' '.join(l[1:])
+
+    os.makedirs('{}/wav'.format(args.des_dir), exist_ok=True)
+    with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
+        for utt, wav in tqdm(utt2wav.items()):
+            prompt_speech_16k = load_wav(wav, 16000)
+            if prompt_speech_16k.shape[1] >= 30 * 16000:
+                continue
+            speech_list = []
+            for i, j in enumerate(cosyvoice.inference_zero_shot(utt2text[utt], utt2text[utt], prompt_speech_16k, stream=False, text_frontend=False)):
+                speech_list.append(j['tts_speech'])
+            negative_wav = os.path.abspath('{}/wav/{}'.format(args.des_dir, os.path.basename(wav)))
+            torchaudio.save(negative_wav, torch.concat(speech_list, dim=1), cosyvoice.sample_rate, backend='soundfile')
+            f.write('{} {}\n'.format(utt, negative_wav))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--src_dir',
+                        type=str)
+    parser.add_argument('--des_dir',
+                        type=str)
+    parser.add_argument('--ref_model',
+                        type=str)
+    args = parser.parse_args()
+    main()

+ 0 - 17
examples/libritts/cosyvoice/run.sh

@@ -51,23 +51,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
   done
 fi
 
-# inference
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-  echo "Run inference. Please make sure utt in tts_text is in prompt_data"
-  for mode in sft zero_shot; do
-    python cosyvoice/bin/inference.py --mode $mode \
-      --gpu 0 \
-      --config conf/cosyvoice.yaml \
-      --prompt_data data/test-clean/parquet/data.list \
-      --prompt_utt2data data/test-clean/parquet/utt2data.list \
-      --tts_text `pwd`/tts_text.json \
-      --llm_model $pretrained_model_dir/llm.pt \
-      --flow_model $pretrained_model_dir/flow.pt \
-      --hifigan_model $pretrained_model_dir/hift.pt \
-      --result_dir `pwd`/exp/cosyvoice/test-clean/$mode
-  done
-fi
-
 # train llm
 # train llm
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+ 1 - 20
examples/libritts/cosyvoice2/run.sh

@@ -51,25 +51,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
   done
 fi
 
-# inference
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-  echo "Run inference. Please make sure utt in tts_text is in prompt_data"
-  # TODO consider remove bin/inference.py, or use similar initilization method as in readme
-  for mode in sft zero_shot; do
-    python cosyvoice/bin/inference.py --mode $mode \
-      --gpu 0 \
-      --config conf/cosyvoice2.yaml \
-      --prompt_data data/test-clean/parquet/data.list \
-      --prompt_utt2data data/test-clean/parquet/utt2data.list \
-      --tts_text `pwd`/tts_text.json \
-      --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
-      --llm_model $pretrained_model_dir/llm.pt \
-      --flow_model $pretrained_model_dir/flow.pt \
-      --hifigan_model $pretrained_model_dir/hift.pt \
-      --result_dir `pwd`/exp/cosyvoice/test-clean/$mode
-  done
-fi
-
 # train llm
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
@@ -86,7 +67,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
   cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
   cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
   # NOTE will update llm/hift training later
-  for model in llm flow; do
+  for model in llm flow hifigan; do
     torchrun --nnodes=1 --nproc_per_node=$num_gpus \
         --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
       cosyvoice/bin/train.py \

+ 123 - 0
examples/libritts/cosyvoice2/run_dpo.sh

@@ -0,0 +1,123 @@
+#!/bin/bash
+# Copyright 2024 Alibaba Inc. All Rights Reserved.
+. ./path.sh || exit 1;
+
+stage=-1
+stop_stage=3
+
+data_url=www.openslr.org/resources/60
+data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
+pretrained_model_dir=../../../pretrained_models/CosyVoice2-0.5B
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+  echo "Data Download"
+  for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
+    local/download_and_untar.sh ${data_dir} ${data_url} ${part}
+  done
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+  echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
+  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+    mkdir -p data/$x
+    python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
+  done
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+  echo "Prepare negative samples using CosyVoice2-0.5B, this is also our reference model.
+    Here we use CosyVoice2-0.5B generated audio as reject sample for simplicity, you can use metric like wer/similarity."
+  for x in train-clean-100 train-clean-360 train-other-500; do
+    mkdir -p data/${x}_reject
+    python local/prepare_reject_sample.py --src_dir data/$x --des_dir data/${x}_reject --ref_model $pretrained_model_dir
+  done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
+  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+    tools/extract_embedding.py --dir data/$x \
+      --onnx_path $pretrained_model_dir/campplus.onnx
+  done
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
+  for x in train-clean-100 train-clean-360 train-other-500 train-clean-100_reject train-clean-360_reject train-other-500_reject dev-clean dev-other test-clean test-other; do
+    tools/extract_speech_token.py --dir data/$x \
+      --onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
+  done
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
+  for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
+    mkdir -p data/$x/parquet
+    tools/make_parquet_list.py --num_utts_per_parquet 1000 \
+      --num_processes 10 \
+      --dpo \
+      --src_dir data/$x \
+      --des_dir data/$x/parquet
+  done
+fi
+
+# train llm
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+job_id=1986
+dist_backend="nccl"
+num_workers=2
+prefetch=100
+train_engine=torch_ddp
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+  echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
+  if [ $train_engine == 'deepspeed' ]; then
+    echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
+  fi
+  cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
+  cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
+  # NOTE only llm supports dpo
+  for model in llm; do
+    torchrun --nnodes=1 --nproc_per_node=$num_gpus \
+        --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
+      cosyvoice/bin/train.py \
+      --train_engine $train_engine \
+      --config conf/cosyvoice2.yaml \
+      --train_data data/train.data.list \
+      --cv_data data/dev.data.list \
+      --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
+      --model $model \
+      --checkpoint $pretrained_model_dir/$model.pt \
+      --ref_model $pretrained_model_dir/llm.pt \
+      --model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \
+      --tensorboard_dir `pwd`/tensorboard/cosyvoice2/$model/$train_engine \
+      --ddp.dist_backend $dist_backend \
+      --num_workers ${num_workers} \
+      --prefetch ${prefetch} \
+      --pin_memory \
+      --use_amp \
+      --dpo \
+      --deepspeed_config ./conf/ds_stage2.json \
+      --deepspeed.save_states model+optimizer
+  done
+fi
+
+# average model
+average_num=5
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+  for model in llm; do
+    decode_checkpoint=`pwd`/exp/cosyvoice2/$model/$train_engine/${model}.pt
+    echo "do model average and final checkpoint is $decode_checkpoint"
+    python cosyvoice/bin/average_model.py \
+      --dst_model $decode_checkpoint \
+      --src_path `pwd`/exp/cosyvoice2/$model/$train_engine  \
+      --num ${average_num} \
+      --val_best
+  done
+fi
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+  echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
+  python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
+  python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
+fi
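
For context: --dpo together with --ref_model (the frozen pretrained llm.pt) trains the llm with a Direct Preference Optimization objective over paired speech-token sequences, where the ground-truth tokens are the chosen sample and the reject_speech_token extracted from the _reject data is the rejected one. A generic sketch of the standard DPO loss; the implementation shipped in this repo may differ in its details:

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    # inputs are per-utterance sums of token log-probs over the chosen and
    # rejected speech-token sequences, under the trained llm and the frozen
    # reference model loaded from --ref_model
    chosen_rewards = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_rewards = beta * (policy_rejected_logp - ref_rejected_logp)
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()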

+ 0 - 17
examples/magicdata-read/cosyvoice/run.sh

@@ -51,23 +51,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  done
 fi
 
-# inference
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-  echo "Run inference. Please make sure utt in tts_text is in prompt_data"
-  for mode in sft zero_shot; do
-    python cosyvoice/bin/inference.py --mode $mode \
-      --gpu 0 \
-      --config conf/cosyvoice.yaml \
-      --prompt_data data/test/parquet/data.list \
-      --prompt_utt2data data/test/parquet/utt2data.list \
-      --tts_text `pwd`/tts_text.json \
-      --llm_model $pretrained_model_dir/llm.pt \
-      --flow_model $pretrained_model_dir/flow.pt \
-      --hifigan_model $pretrained_model_dir/hift.pt \
-      --result_dir `pwd`/exp/cosyvoice/test/$mode
-  done
-fi
-
 # train llm
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')

+ 11 - 1
tools/make_parquet_list.py

@@ -34,7 +34,9 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
     spk_list = [utt2spk[utt] for utt in utt_list]
     uttembedding_list = [utt2embedding[utt] for utt in utt_list]
     spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
-    speech_token_list = [utt2speech_token[utt] for utt in utt_list]
+    speech_token_list = [utt2speech_token.get(utt, []) for utt in utt_list]
+    if args.dpo:
+        reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list]
 
     # save to parquet, utt2parquet_file, spk2parquet_file
     df = pd.DataFrame()
@@ -46,6 +48,8 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
     df['utt_embedding'] = uttembedding_list
     df['spk_embedding'] = spkembedding_list
     df['speech_token'] = speech_token_list
+    if args.dpo:
+        df['reject_speech_token'] = reject_speech_token_list
     df.to_parquet(parquet_file)
     with open(utt2parquet_file, 'w') as f:
         json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
@@ -68,6 +72,10 @@ if __name__ == "__main__":
                         type=str)
     parser.add_argument('--des_dir',
                         type=str)
+    parser.add_argument('--dpo',
+                        action='store_true',
+                        default=False,
+                        help='Use Direct Preference Optimization')
     args = parser.parse_args()
 
     utt2wav, utt2text, utt2spk = {}, {}, {}
@@ -86,6 +94,8 @@ if __name__ == "__main__":
     utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
     spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
     utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
+    if args.dpo:
+        utt2reject_speech_token = torch.load('{}_reject/utt2speech_token.pt'.format(args.src_dir))
     utts = list(utt2wav.keys())
 
     # Using process pool to speedup
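
With --dpo, every parquet row now carries both the ground-truth speech_token (chosen) and the reject_speech_token (rejected) for the same utterance, which is what the DPO training pairs on. A quick sanity check of a generated shard, with the shard path assumed to follow the parquet_{:09d}.tar naming:

import pandas as pd

# each row should contain both token sequences when the shard was built with --dpo
df = pd.read_parquet('data/train-clean-100/parquet/parquet_000000000.tar')
row = df.iloc[0]
print(row['utt'], len(row['speech_token']), len(row['reject_speech_token']))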

+ 0 - 125
tools/make_parquet_list_dpo.py

@@ -1,125 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import argparse
-import logging
-import os
-import json
-from tqdm import tqdm
-import pandas as pd
-import multiprocessing
-import time
-import torch
-
-
-def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
-    start_time = time.time()
-    data_list = []
-    for utt in tqdm(utt_list):
-        data = open(utt2wav[utt], 'rb').read()
-        data_list.append(data)
-    wav_list = [utt2wav[utt] for utt in utt_list]
-    text_list = [utt2text[utt] for utt in utt_list]
-    spk_list = [utt2spk[utt] for utt in utt_list]
-    uttembedding_list = [utt2embedding[utt] for utt in utt_list]
-    spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
-    speech_token_list = [utt2speech_token[utt] for utt in utt_list]
-    if utt2reject_speech_token:
-        reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list]
-
-    # 保存到parquet,utt2parquet_file,spk2parquet_file
-    df = pd.DataFrame()
-    df['utt'] = utt_list
-    df['wav'] = wav_list
-    df['audio_data'] = data_list
-    df['text'] = text_list
-    df['spk'] = spk_list
-    df['utt_embedding'] = uttembedding_list
-    df['spk_embedding'] = spkembedding_list
-    df['speech_token'] = speech_token_list
-    if utt2reject_speech_token:
-        df['reject_speech_token'] = reject_speech_token_list
-    df.to_parquet(parquet_file)
-    with open(utt2parquet_file, 'w') as f:
-        json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
-    with open(spk2parquet_file, 'w') as f:
-        json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
-    logging.info('spend time {}'.format(time.time() - start_time))
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--num_utts_per_parquet',
-                        type=int,
-                        default=1000,
-                        help='num utts per parquet')
-    parser.add_argument('--num_processes',
-                        type=int,
-                        default=1,
-                        help='num processes for make parquets')
-    parser.add_argument('--src_dir',
-                        type=str)
-    parser.add_argument('--des_dir',
-                        type=str)
-    parser.add_argument('--dpo',
-                        action='store_true',
-                        default=False,
-                        help='Use Direct Preference Optimization')
-    args = parser.parse_args()
-
-    utt2wav, utt2text, utt2spk = {}, {}, {}
-    with open('{}/wav.scp'.format(args.src_dir)) as f:
-        for l in f:
-            l = l.replace('\n', '').split()
-            utt2wav[l[0]] = l[1]
-    with open('{}/text'.format(args.src_dir)) as f:
-        for l in f:
-            l = l.replace('\n', '').split()
-            utt2text[l[0]] = ' '.join(l[1:])
-    with open('{}/utt2spk'.format(args.src_dir)) as f:
-        for l in f:
-            l = l.replace('\n', '').split()
-            utt2spk[l[0]] = l[1]
-    utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
-    spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
-    utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
-    if args.dpo:
-        utt2reject_speech_token = torch.load('{}/utt2reject_speech_token.pt'.format(args.src_dir))
-    else:
-        utt2reject_speech_token = None
-    utts = list(utt2wav.keys())
-
-    # Using process pool to speedup
-    pool = multiprocessing.Pool(processes=args.num_processes)
-    parquet_list, utt2parquet_list, spk2parquet_list = [], [], []
-    for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)):
-        parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i))
-        utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i))
-        spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i))
-        parquet_list.append(parquet_file)
-        utt2parquet_list.append(utt2parquet_file)
-        spk2parquet_list.append(spk2parquet_file)
-        pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file))
-    pool.close()
-    pool.join()
-
-    with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \
-            open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \
-            open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3:
-        for name in parquet_list:
-            f1.write(name + '\n')
-        for name in utt2parquet_list:
-            f2.write(name + '\n')
-        for name in spk2parquet_list:
-            f3.write(name + '\n')