Skip to content

Commit d8587de

Browse files
author
Eric Tsai
committed
some type annotations
1 parent c342ebd commit d8587de

1 file changed

Lines changed: 23 additions & 17 deletions

File tree

src/election_anomaly/__init__.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,7 @@ def __init__(self):
932932
Session = sessionmaker(bind=eng)
933933
self.session = Session()
934934

935-
def display_options(self, input, verbose=False, filters=None):
935+
def display_options(self, input: str, verbose: bool=False, filters: list=None):
936936
if not verbose:
937937
results = db.get_input_options(self.session, input, False)
938938
else:
@@ -945,7 +945,7 @@ def display_options(self, input, verbose=False, filters=None):
945945
return results
946946
return None
947947

948-
def top_counts_by_vote_type(self, election, rollup_unit, sub_unit):
948+
def top_counts_by_vote_type(self, election: str, rollup_unit: str, sub_unit: str) -> str:
949949
d, error = ui.get_runtime_parameters(
950950
["rollup_directory"], param_file="multi.par"
951951
)
@@ -966,7 +966,7 @@ def top_counts_by_vote_type(self, election, rollup_unit, sub_unit):
966966
connection.close()
967967
return err_str
968968

969-
def top_counts(self, rollup_unit, sub_unit):
969+
def top_counts(self, rollup_unit: str , sub_unit: str):
970970
d, error = ui.get_runtime_parameters(["rollup_directory"])
971971
if error:
972972
print("Parameter file missing requirements.")
@@ -993,16 +993,16 @@ def top_counts(self, rollup_unit, sub_unit):
993993

994994
def scatter(
995995
self,
996-
jurisdiction,
997-
subdivision_type,
998-
h_election,
999-
h_category,
1000-
h_count, # horizontal axis params
1001-
v_election,
1002-
v_category,
1003-
v_count, # vertical axis params
1004-
fig_type=None,
1005-
):
996+
jurisdiction: str,
997+
subdivision_type: str,
998+
h_election: str,
999+
h_category: str,
1000+
h_count: str, # horizontal axis params
1001+
v_election: str,
1002+
v_category: str,
1003+
v_count: str, # vertical axis params
1004+
fig_type: str = None,
1005+
) -> list:
10061006
"""Used to create a scatter plot based on selected inputs. The fig_type parameter
10071007
is used when the user wants to actually create the visualization; this uses plotly
10081008
so any image extension that is supported by plotly is usable here. Currently supports
@@ -1056,7 +1056,13 @@ def scatter(
10561056
v.plot("scatter", agg_results, fig_type, d["rollup_directory"])
10571057
return agg_results
10581058

1059-
def bar(self, jurisdiction, contest_type=None, contest=None, fig_type=None):
1059+
def bar(
1060+
self,
1061+
jurisdiction: str,
1062+
contest_type: str = None,
1063+
contest: str = None,
1064+
fig_type: str = None
1065+
) -> list:
10601066
"""contest_type is one of state, congressional, state-senate, state-house"""
10611067
d, error = ui.get_runtime_parameters(
10621068
["rollup_directory", "sub_reporting_unit_type"], param_file="analyze.par"
@@ -1089,7 +1095,7 @@ def bar(self, jurisdiction, contest_type=None, contest=None, fig_type=None):
10891095
v.plot("bar", agg_result, fig_type, d["rollup_directory"])
10901096
return agg_results
10911097

1092-
def split_category_input(self, input_str):
1098+
def split_category_input(self, input_str: str):
10931099
"""Helper function. Takes an input from the front end that is the cartesian
10941100
product of the CountItemType and {'Candidate', 'Contest'}. So something like:
10951101
Total Candidates or Absentee Contests. Cleans this and returns
@@ -1101,7 +1107,7 @@ def split_category_input(self, input_str):
11011107
selection_type = input_str[len(count_item_type) + 1 :]
11021108
return count_item_type, selection_type
11031109

1104-
def export_outlier_data(self, jurisdiction, contest=None):
1110+
def export_outlier_data(self, jurisdiction: str, contest: str=None):
11051111
"""contest_type is one of state, congressional, state-senate, state-house"""
11061112
d, error = ui.get_runtime_parameters(
11071113
["rollup_directory", "sub_reporting_unit_type"], param_file="analyze.par"
@@ -1132,6 +1138,6 @@ def export_outlier_data(self, jurisdiction, contest=None):
11321138
return agg_results
11331139

11341140

1135-
def get_filename(path):
1141+
def get_filename(path: str) -> str:
11361142
head, tail = ntpath.split(path)
11371143
return tail or ntpath.basename(head)

0 commit comments

Comments
 (0)