frontend_utils.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import re
  15. chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
  16. # whether contain chinese character
  17. def contains_chinese(text):
  18. return bool(chinese_char_pattern.search(text))
  19. # replace special symbol
  20. def replace_corner_mark(text):
  21. text = text.replace('²', '平方')
  22. text = text.replace('³', '立方')
  23. return text
  24. # remove meaningless symbol
  25. def remove_bracket(text):
  26. text = text.replace('(', '').replace(')', '')
  27. text = text.replace('【', '').replace('】', '')
  28. text = text.replace('`', '').replace('`', '')
  29. text = text.replace("——", " ")
  30. return text
  31. # spell Arabic numerals
  32. def spell_out_number(text: str, inflect_parser):
  33. new_text = []
  34. st = None
  35. for i, c in enumerate(text):
  36. if not c.isdigit():
  37. if st is not None:
  38. num_str = inflect_parser.number_to_words(text[st: i])
  39. new_text.append(num_str)
  40. st = None
  41. new_text.append(c)
  42. else:
  43. if st is None:
  44. st = i
  45. if st is not None and st < len(text):
  46. num_str = inflect_parser.number_to_words(text[st:])
  47. new_text.append(num_str)
  48. return ''.join(new_text)
  49. # split paragrah logic:
  50. # 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
  51. # 2. cal sentence len according to lang
  52. # 3. split sentence according to puncatation
  53. def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
  54. def calc_utt_length(_text: str):
  55. if lang == "zh":
  56. return len(_text)
  57. else:
  58. return len(tokenize(_text))
  59. def should_merge(_text: str):
  60. if lang == "zh":
  61. return len(_text) < merge_len
  62. else:
  63. return len(tokenize(_text)) < merge_len
  64. if lang == "zh":
  65. pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
  66. else:
  67. pounc = ['.', '?', '!', ';', ':']
  68. if comma_split:
  69. pounc.extend([',', ','])
  70. if text[-1] not in pounc:
  71. if lang == "zh":
  72. text += "。"
  73. else:
  74. text += "."
  75. st = 0
  76. utts = []
  77. for i, c in enumerate(text):
  78. if c in pounc:
  79. if len(text[st: i]) > 0:
  80. utts.append(text[st: i] + c)
  81. if i + 1 < len(text) and text[i + 1] in ['"', '”']:
  82. tmp = utts.pop(-1)
  83. utts.append(tmp + text[i + 1])
  84. st = i + 2
  85. else:
  86. st = i + 1
  87. final_utts = []
  88. cur_utt = ""
  89. for utt in utts:
  90. if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
  91. final_utts.append(cur_utt)
  92. cur_utt = ""
  93. cur_utt = cur_utt + utt
  94. if len(cur_utt) > 0:
  95. if should_merge(cur_utt) and len(final_utts) != 0:
  96. final_utts[-1] = final_utts[-1] + cur_utt
  97. else:
  98. final_utts.append(cur_utt)
  99. return final_utts
  100. # remove blank between chinese character
  101. def replace_blank(text: str):
  102. out_str = []
  103. for i, c in enumerate(text):
  104. if c == " ":
  105. if ((text[i + 1].isascii() and text[i + 1] != " ") and
  106. (text[i - 1].isascii() and text[i - 1] != " ")):
  107. out_str.append(c)
  108. else:
  109. out_str.append(c)
  110. return "".join(out_str)