11from collections import Counter , defaultdict
2- from typing import List
2+ from typing import Dict , List
33
44from .quotation_mark_direction import QuotationMarkDirection
55from .quotation_mark_metadata import QuotationMarkMetadata
@@ -15,6 +15,10 @@ def count_quotation_mark(self, quotation_mark: str) -> None:
1515 self ._quotation_mark_counter .update ([quotation_mark ])
1616 self ._total_count += 1
1717
def count_from(self, quotation_mark_counts: "QuotationMarkCounts") -> None:
    """Merge the tallies of another counts object into this one.

    Every per-mark count held by ``quotation_mark_counts`` is added to the
    matching count here, and its total is added to this object's total.
    ``quotation_mark_counts`` itself is left unmodified.
    """
    incoming_total = quotation_mark_counts._total_count
    incoming_counter = quotation_mark_counts._quotation_mark_counter

    self._total_count = self._total_count + incoming_total
    # Counter.update adds counts for existing keys rather than replacing them.
    self._quotation_mark_counter.update(incoming_counter)
21+
def find_best_quotation_mark_proportion(self) -> tuple[str, int, int]:
    """Return the most frequently counted mark together with its proportion.

    Returns:
        A ``(mark, mark_count, total_count)`` tuple, where ``mark`` is the
        most common entry of the internal counter, ``mark_count`` is how
        often it was counted, and ``total_count`` is the overall tally.

    Raises:
        IndexError: If no quotation mark has been counted yet.
    """
    top_entries = self._quotation_mark_counter.most_common(1)
    best_mark, best_count = top_entries[0]
    return (best_mark, best_count, self._total_count)
2024
@@ -36,6 +40,13 @@ def tabulate(self, quotation_marks: list[QuotationMarkMetadata]) -> None:
3640 for quotation_mark in quotation_marks :
3741 self ._count_quotation_mark (quotation_mark )
3842
def tabulate_from(self, tabulated_quotation_marks: "QuotationMarkTabulator") -> None:
    """Fold the counts of another tabulator into this one.

    For every (depth, direction) key present in ``tabulated_quotation_marks``,
    the counts stored under that key are merged (via ``count_from``) into the
    entry this tabulator holds for the same key.
    """
    incoming_table = tabulated_quotation_marks._quotation_counts_by_depth_and_direction
    own_table = self._quotation_counts_by_depth_and_direction

    for depth_and_direction in incoming_table:
        # NOTE(review): indexing own_table presumably creates a fresh entry for
        # unseen keys (defaultdict) -- confirm against the class initializer.
        own_table[depth_and_direction].count_from(incoming_table[depth_and_direction])
49+
3950 def _count_quotation_mark (self , quotation_mark : QuotationMarkMetadata ) -> None :
4051 key = (quotation_mark .depth , quotation_mark .direction )
4152 self ._quotation_counts_by_depth_and_direction [key ].count_quotation_mark (quotation_mark .quotation_mark )
@@ -48,23 +59,39 @@ def _find_most_common_quotation_mark_with_depth_and_direction(
4859 ) -> tuple [str , int , int ]:
4960 return self ._quotation_counts_by_depth_and_direction [(depth , direction )].find_best_quotation_mark_proportion ()
5061
def get_total_quotation_mark_count(self) -> int:
    """Return the sum of the observed counts over every (depth, direction) bucket."""
    return sum(
        bucket.get_observed_count()
        for bucket in self._quotation_counts_by_depth_and_direction.values()
    )
67+
def calculate_similarity(self, quote_convention: QuoteConvention) -> float:
    """Score how closely the tabulated marks follow ``quote_convention``.

    For each depth, the number of observed marks and the number of those that
    match the convention's expected mark are totalled over all directions.
    A depth's score is its matching-mark count scaled by the match ratio of
    the depth directly above it (or by 1 when that depth has no marks), so a
    mismatch at a shallow depth also lowers the contribution of the depths
    nested inside it.

    Args:
        quote_convention: Convention supplying the expected quotation mark
            for each (depth, direction) pair.

    Returns:
        The sum of the per-depth scores divided by the total number of
        observed marks, or ``0.0`` when no marks have been observed.
    """
    num_marks_by_depth: Dict[int, int] = defaultdict(int)
    num_matching_marks_by_depth: Dict[int, int] = defaultdict(int)

    for (depth, direction), counts in self._quotation_counts_by_depth_and_direction.items():
        num_observed_marks = counts.get_observed_count()
        if num_observed_marks == 0:
            # An empty bucket carries no evidence. Skipping it keeps
            # zero-mark depths out of the tables, so the ratio computed
            # below never divides by zero.
            continue

        expected_quotation_mark: str = quote_convention.get_expected_quotation_mark(depth, direction)
        num_marks_by_depth[depth] += num_observed_marks
        num_matching_marks_by_depth[depth] += num_observed_marks - counts.calculate_num_differences(
            expected_quotation_mark
        )

    total_marks = sum(num_marks_by_depth.values())
    if total_marks == 0:
        return 0.0

    # The scores of greater depths depend on the scores of shallower depths,
    # so depths must be visited in increasing order.
    scores_by_depth: Dict[int, float] = {}
    for depth in sorted(num_marks_by_depth):
        if depth - 1 in scores_by_depth:
            previous_depth_score = scores_by_depth[depth - 1] / num_marks_by_depth[depth - 1]
        else:
            previous_depth_score = 1.0
        scores_by_depth[depth] = previous_depth_score * num_matching_marks_by_depth[depth]

    return sum(scores_by_depth.values()) / total_marks
6895
6996 def get_summary_message (self ) -> str :
7097 message_lines : List [str ] = []
0 commit comments