Przeglądaj źródła

Keep only the embedding mean as the speaker (spk) embedding

lyuxiang.lx 1 rok temu
rodzic
commit
6a3e44242a
2 zmienionych plików z 3 dodań i 1 usunięć
  1. 1 1
      cosyvoice/dataset/processor.py
  2. 2 0
      tools/extract_embedding.py

+ 1 - 1
cosyvoice/dataset/processor.py

@@ -167,7 +167,7 @@ def parse_embedding(data, normalize, mode='train'):
     """
     for sample in data:
         sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
-        sample['spk_embedding'] = torch.stack([torch.tensor(i, dtype=torch.float32) for i in sample['spk_embedding']], dim=0).mean(dim=0)
+        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
         if normalize:
             sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
             sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)

+ 2 - 0
tools/extract_embedding.py

@@ -53,6 +53,8 @@ def main(args):
         if spk not in spk2embedding:
             spk2embedding[spk] = []
         spk2embedding[spk].append(embedding)
+    for k, v in spk2embedding.items():
+        spk2embedding[k] = torch.tensor(v).mean(dim=0, keepdim=True).tolist()
 
     torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
     torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))