From 100ef498c5c1a393cafa9c574689197308e16430 Mon Sep 17 00:00:00 2001 From: Karlis Goba Date: Sun, 10 Nov 2019 09:04:26 +0200 Subject: [PATCH] Test script to check overall accuracy compared to test cases generated by WSJT-X --- utils/run_tests.py | 49 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100755 utils/run_tests.py diff --git a/utils/run_tests.py b/utils/run_tests.py new file mode 100755 index 0000000..1298190 --- /dev/null +++ b/utils/run_tests.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 + +import sys, os, subprocess + +def parse(line): + fields = line.split() + freq = fields[3] + dest = fields[5] if len(fields) > 5 else '' + source = fields[6] if len(fields) > 6 else '' + report = fields[7] if len(fields) > 7 else '' + return ' '.join([dest, source, report]) + +wav_dir = sys.argv[1] +wav_files = [os.path.join(wav_dir, f) for f in os.listdir(wav_dir)] +wav_files = [f for f in wav_files if os.path.isfile(f) and os.path.splitext(f)[1] == '.wav'] +txt_files = [os.path.splitext(f)[0] + '.txt' for f in wav_files] + +n_extra = 0 +n_missed = 0 +n_total = 0 +for wav_file, txt_file in zip(wav_files, txt_files): + result = subprocess.run(['./decode_ft8', wav_file], stdout=subprocess.PIPE) + result = result.stdout.decode('utf-8').split('\n') + result = [parse(x) for x in result if len(x) > 0] + #print(result[0]) + result = set(result) + + expected = open(txt_file).read().split('\n') + expected = [parse(x) for x in expected if len(x) > 0] + #print(expected[0]) + expected = set(expected) + + extra_decodes = result - expected + missed_decodes = expected - result + print(len(result), '/', len(expected)) + if len(extra_decodes) > 0: + print('Extra decodes: ', list(extra_decodes)) + #print('Missed decodes: ', list(missed_decodes)) + + n_total += len(expected) + n_extra += len(extra_decodes) + n_missed += len(missed_decodes) + + #break + +print('Total: %d, extra: %d (%.1f%%), missed: %d (%.1f%%)' % + (n_total, n_extra, 100.0*n_extra/n_total, n_missed, 100.0*n_missed/n_total)) +recall = (n_total - n_missed) / float(n_total) +print('Recall: %.1f%%' % (100*recall, ))