import os
import subprocess
import wave
import sys
import json
from vosk import Model, KaldiRecognizer, SetLogLevel
import difflib
import time

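def get_edit_distance(str1, str2) -> int:
    """
    Compute the edit distance between two sequences; supports str and list types.

    str1 and str2 are lists whose elements are the strings to compare. The edit
    distance is computed between the strings at corresponding positions and summed.
    Note: if plain strings are passed instead of lists, zip() pairs them character
    by character, so the result only counts position-wise mismatches.
    """
    leven_cost = 0
    # print(f'--str1-str2-{str1}-{str2}')
    for s1, s2 in zip(str1, str2):
        sequence_match = difflib.SequenceMatcher(None, s1, s2)
        # Each opcode describes how to turn a slice of s1 into a slice of s2.
        for tag, index_1, index_2, index_j1, index_j2 in sequence_match.get_opcodes():
            if tag == 'replace':
                leven_cost += max(index_2 - index_1, index_j2 - index_j1)
            elif tag == 'insert':
                leven_cost += (index_j2 - index_j1)
            elif tag == 'delete':
                leven_cost += (index_2 - index_1)
    return leven_cost
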
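
SetLogLevel(-1)  # silence Vosk/Kaldi log output

model = Model("../Downloads/vosk-model-small-cn-0.22")

# The recognizer's sample rate must match the WAV files being decoded.
fr = 48000
rec = KaldiRecognizer(model, fr)
rec.SetWords(True)
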
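
def recognize(file, trans):
    """Transcribe one WAV file and return its error rate against the reference `trans`."""
    wf = wave.open(file, "rb")
    if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE":
        print("Audio file must be WAV format mono PCM.")
        sys.exit(1)

    str_ret = ""

    # Feed the audio in chunks; collect the text of each completed utterance.
    while True:
        data = wf.readframes(4000)
        if len(data) == 0:
            break
        if rec.AcceptWaveform(data):
            result = rec.Result()
            result = json.loads(result)
            if 'text' in result:
                str_ret += result['text'] + ' '

    # Flush whatever is still buffered in the recognizer.
    result = json.loads(rec.FinalResult())
    if 'text' in result:
        str_ret += result['text']
    # Drop the spaces Vosk puts between tokens so the comparison is character-level.
    str_ret = str_ret.replace(' ', '')

    # Wrap both strings in lists so get_edit_distance compares them as whole
    # strings rather than character by character at matching positions.
    wer = get_edit_distance([str_ret], [trans]) / len(trans)
    print(str_ret, trans, wer)
    wf.close()
    return wer
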
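
wers = []

# Each file name (without its extension) is the reference transcript for that file.
os.chdir('../dataset/chs')
for file in os.listdir():
    fn, _ = os.path.splitext(file)
    st = time.time()
    wer = recognize(file, fn)
    et = time.time()
    # Latency is in seconds per file; throughput is files per second.
    print(f'latency:{et-st}, throughput:{1/(et-st)}')
    wers.append(wer)

print(f'average wer:{sum(wers)/len(wers)}')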