|
|
@@ -167,7 +167,7 @@ def parse_embedding(data, normalize, mode='train'):
|
|
|
"""
|
|
|
for sample in data:
|
|
|
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
|
|
|
- sample['spk_embedding'] = torch.stack([torch.tensor(i, dtype=torch.float32) for i in sample['spk_embedding']], dim=0).mean(dim=0)
|
|
|
+ sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
|
|
|
if normalize:
|
|
|
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
|
|
|
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
|