# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#               2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Generator
import queue
import torch
import numpy as np
import threading
import time
from torch.nn import functional as F
from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt
from cosyvoice.utils.common import TrtContextWrapper
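# CosyVoiceModel chains three stages: an autoregressive LLM that maps text to
# discrete speech tokens, a flow-matching model that renders those tokens as mel
# spectrograms, and the HiFT vocoder that converts mel frames into waveform samples.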
class CosyVoiceModel:

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False,
                 trt_concurrent: int = 1):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        self.trt_concurrent = trt_concurrent
        if self.fp16 is True:
            self.llm.half()
            self.flow.half()
        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 20
        # mel fade in out
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # hift cache
        self.mel_cache_len = 20
        self.source_cache_len = int(self.mel_cache_len * 256)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.stream_scale_factor = 1
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be at least 1, change it according to your actual rtf'
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
        for _ in range(trt_concurrent):
            self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
        self.lock = threading.Lock()
        # dicts used to store session related variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.mel_overlap_dict = {}
        self.flow_cache_dict = {}
        self.hift_cache_dict = {}
        self.trt_context_dict = {}
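    # Load the three checkpoints and put every sub-module in eval mode on the inference device.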
    def load(self, llm_model, flow_model, hift_model):
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
        self.llm.to(self.device).eval()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()
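    # Optionally replace the text encoder, LLM and flow encoder with TorchScript exports.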
    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
        self.llm.text_encoder = llm_text_encoder
        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
        self.llm.llm = llm_llm
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder
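    # Optionally replace the flow decoder's estimator with a TensorRT engine,
    # converting the ONNX export first if no (non-empty) engine file exists yet.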
    def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
        assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
        if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
            convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
        del self.flow.decoder.estimator
        import tensorrt as trt
        with open(flow_decoder_estimator_model, 'rb') as f:
            estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
        self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=self.trt_concurrent)
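    # Dynamic-shape profile for the TensorRT build: min/opt/max shapes for the
    # estimator inputs x, mask, mu and cond.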
    def get_trt_kwargs(self):
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
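    # Producer thread: run LLM inference and append generated speech tokens to this
    # session's list. Generator input streams text through inference_bistream;
    # tensor input goes through the regular inference path.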
    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        with self.llm_context, torch.cuda.amp.autocast(self.fp16):
            if isinstance(text, Generator):
                assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
                for i in self.llm.inference_bistream(text=text,
                                                     prompt_text=prompt_text.to(self.device),
                                                     prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                                     prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                                     prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                     embedding=llm_embedding.to(self.device)):
                    self.tts_speech_token_dict[uuid].append(i)
            else:
                for i in self.llm.inference(text=text.to(self.device),
                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_text=prompt_text.to(self.device),
                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                            embedding=llm_embedding.to(self.device)):
                    self.tts_speech_token_dict[uuid].append(i)
        self.llm_end_dict[uuid] = True
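    # Voice conversion producer: the source speech tokens are used as-is, no LLM needed.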
    def vc_job(self, source_speech_token, uuid):
        self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
        self.llm_end_dict[uuid] = True
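    # Turn speech tokens into a waveform chunk: flow matching yields the mel
    # spectrogram, HiFT vocodes it; cached overlaps are cross-faded between
    # streaming chunks to avoid boundary artifacts.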
    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        with torch.cuda.amp.autocast(self.fp16):
            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device),
                                                                      token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_token=prompt_token.to(self.device),
                                                                      prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                                      prompt_feat=prompt_feat.to(self.device),
                                                                      prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                                      embedding=embedding.to(self.device),
                                                                      flow_cache=self.flow_cache_dict[uuid])
        # mel overlap fade in out
        if self.mel_overlap_dict[uuid].shape[2] != 0:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
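    # Main entry point: start the token producer thread, then either stream waveform
    # chunks as soon as enough tokens accumulate (stream=True) or wait for all tokens
    # and synthesize in a single pass.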
    def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        this_trt_context = self.trt_context_pool.get()
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
            self.trt_context_dict[this_uuid] = this_trt_context
        if source_speech_token.shape[1] == 0:
            p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        else:
            p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
        p.start()
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                        .unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            p.join()
            # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.flow_cache_dict.pop(this_uuid)
            self.trt_context_pool.put(self.trt_context_dict[this_uuid])
            self.trt_context_dict.pop(this_uuid)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.current_stream().synchronize()
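# CosyVoice2Model keeps the same overall pipeline but streams in fixed token hops
# with a pre-lookahead window, instead of the mel-level overlap-fade used above.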
class CosyVoice2Model(CosyVoiceModel):

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool = False,
                 trt_concurrent: int = 1):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        # NOTE default setting for jit/onnx export; set these to False when using pytorch inference
        self.flow.encoder.streaming = True
        self.flow.decoder.estimator.streaming = True
        self.hift = hift
        self.fp16 = fp16
        self.trt_concurrent = trt_concurrent
        if self.fp16 is True:
            self.llm.half()
            self.flow.half()
        # NOTE must match the training static_chunk_size
        self.token_hop_len = 25
        # hift cache
        self.mel_cache_len = 8
        self.source_cache_len = int(self.mel_cache_len * 480)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
        for _ in range(trt_concurrent):
            self.trt_context_pool.put(torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext())
        self.lock = threading.Lock()
        # dicts used to store session related variables
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.hift_cache_dict = {}
        self.trt_context_dict = {}
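    # Only the flow encoder has a TorchScript export in CosyVoice2.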
    def load_jit(self, flow_encoder_model):
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder
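    # Same tokens-to-waveform step as the base class, except chunking is expressed as a
    # token_offset into the full token sequence (no flow cache) and the flow inference
    # runs inside this session's CUDA-stream context taken from the TRT pool.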
    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, finalize=False, speed=1.0):
        with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
            tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_token=prompt_token.to(self.device),
                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_feat=prompt_feat.to(self.device),
                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                             embedding=embedding.to(self.device),
                                             finalize=finalize)
            tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep hift cache for streaming
        if finalize is False:
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only supports non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech
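    # Streaming advances token_offset in hops of token_hop_len; the first hop is padded
    # so that prompt plus hop aligns with the chunk size, and each chunk waits for
    # pre_lookahead_len extra tokens before it is decoded.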
    def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        this_trt_context = self.trt_context_pool.get()
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.trt_context_dict[this_uuid] = this_trt_context
        if source_speech_token.shape[1] == 0:
            p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        else:
            p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
        p.start()
        if stream is True:
            token_offset = 0
            prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
            while True:
                time.sleep(0.1)
                this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
                if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     token_offset=token_offset,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    token_offset += this_token_hop_len
                    yield {'tts_speech': this_tts_speech.cpu()}
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
                    break
            p.join()
            # deal with remaining tokens; make sure the remaining token length equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             token_offset=token_offset,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             token_offset=0,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.trt_context_pool.put(self.trt_context_dict[this_uuid])
            self.trt_context_dict.pop(this_uuid)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.current_stream().synchronize()