frontend_utils.py

# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import regex

chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')


# whether the text contains any Chinese character
def contains_chinese(text):
    return bool(chinese_char_pattern.search(text))
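# e.g. contains_chinese("hello 世界") -> True, contains_chinese("hello") -> False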


# replace superscript corner marks with their Chinese reading
def replace_corner_mark(text):
    text = text.replace('²', '平方')
    text = text.replace('³', '立方')
    return text
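# e.g. replace_corner_mark("25²") -> "25平方"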


# remove meaningless symbols (brackets, backticks, long dashes)
def remove_bracket(text):
    text = text.replace('（', '').replace('）', '')
    text = text.replace('【', '').replace('】', '')
    text = text.replace('`', '').replace('`', '')
    text = text.replace("——", " ")
    return text
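# e.g. remove_bracket("【旁白】他（小声）说") -> "旁白他小声说"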


# spell out Arabic numerals as words
def spell_out_number(text: str, inflect_parser):
    new_text = []
    st = None
    for i, c in enumerate(text):
        if not c.isdigit():
            if st is not None:
                num_str = inflect_parser.number_to_words(text[st: i])
                new_text.append(num_str)
                st = None
            new_text.append(c)
        else:
            if st is None:
                st = i
    if st is not None and st < len(text):
        num_str = inflect_parser.number_to_words(text[st:])
        new_text.append(num_str)
    return ''.join(new_text)
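# Example (inflect_parser is assumed to be any object exposing number_to_words(),
# e.g. an engine from the third-party `inflect` package):
#   spell_out_number("I have 3 cats", inflect.engine()) -> "I have three cats"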


# split paragraph logic:
# 1. each piece is built up to roughly token_max_n and only flushed once it is longer
#    than token_min_n; a trailing piece shorter than merge_len is merged into the previous one
# 2. piece length is calculated according to lang (characters for zh, tokens otherwise)
# 3. sentences are split on punctuation
def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
    def calc_utt_length(_text: str):
        if lang == "zh":
            return len(_text)
        else:
            return len(tokenize(_text))

    def should_merge(_text: str):
        if lang == "zh":
            return len(_text) < merge_len
        else:
            return len(tokenize(_text)) < merge_len

    if lang == "zh":
        pounc = ['。', '？', '！', '；', '：', '、', '.', '?', '!', ';']
    else:
        pounc = ['.', '?', '!', ';', ':']
    if comma_split:
        pounc.extend(['，', ','])

    # make sure the text ends with a sentence-final mark
    if text[-1] not in pounc:
        if lang == "zh":
            text += "。"
        else:
            text += "."

    # cut the text into sentences at punctuation, keeping the mark
    # (and a directly following closing quote) with its sentence
    st = 0
    utts = []
    for i, c in enumerate(text):
        if c in pounc:
            if len(text[st: i]) > 0:
                utts.append(text[st: i] + c)
            if i + 1 < len(text) and text[i + 1] in ['"', '”']:
                tmp = utts.pop(-1)
                utts.append(tmp + text[i + 1])
                st = i + 2
            else:
                st = i + 1

    # greedily merge sentences into pieces bounded by token_max_n / token_min_n
    final_utts = []
    cur_utt = ""
    for utt in utts:
        if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
            final_utts.append(cur_utt)
            cur_utt = ""
        cur_utt = cur_utt + utt
    if len(cur_utt) > 0:
        if should_merge(cur_utt) and len(final_utts) != 0:
            final_utts[-1] = final_utts[-1] + cur_utt
        else:
            final_utts.append(cur_utt)

    return final_utts
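# Example (in zh mode lengths are character counts, so `tokenize` is unused and may be None):
#   split_paragraph("今天天气很好。我们去公园散步吧。", None, lang="zh",
#                   token_max_n=10, token_min_n=5, merge_len=2)
#   -> ['今天天气很好。', '我们去公园散步吧。']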


# remove blanks between Chinese characters (a space is kept only when it sits
# between two non-space ASCII characters, e.g. between English words)
def replace_blank(text: str):
    out_str = []
    for i, c in enumerate(text):
        if c == " ":
            # the boundary check avoids an index error on leading/trailing spaces
            if (0 < i < len(text) - 1 and
                    (text[i + 1].isascii() and text[i + 1] != " ") and
                    (text[i - 1].isascii() and text[i - 1] != " ")):
                out_str.append(c)
        else:
            out_str.append(c)
    return "".join(out_str)
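# e.g. replace_blank("你好 世界, hello world") -> "你好世界, hello world"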


def is_only_punctuation(text):
    # regular expression: match strings that consist only of punctuation/symbol characters or are empty
    punctuation_pattern = r'^[\p{P}\p{S}]*$'
    return bool(regex.fullmatch(punctuation_pattern, text))
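

# Minimal smoke test, illustrative only: one plausible order in which these
# helpers could be chained for text normalization. The ordering and the use of
# the third-party `inflect` package are assumptions of this sketch, not part of
# the module above.
if __name__ == "__main__":
    import inflect

    # English: spell out digits, then cut into sentence-sized pieces.
    en_text = "Room 42 is free. Meet me there at 9, please! Thanks."
    en_text = spell_out_number(en_text, inflect.engine())
    print(split_paragraph(en_text, tokenize=str.split, lang="en",
                          token_max_n=8, token_min_n=4, merge_len=1))

    # Chinese: drop brackets and stray blanks, then cut into pieces.
    zh_text = "今天 天气 很好。我们 去 公园（中山公园）散步，顺便 拍 些 照片！"
    zh_text = replace_blank(remove_bracket(zh_text))
    print(contains_chinese(zh_text), is_only_punctuation(zh_text))
    print(split_paragraph(zh_text, tokenize=None, lang="zh",
                          token_max_n=10, token_min_n=5, merge_len=2))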