Skip to content

Commit 8faa44e

Browse files
committed
bug fixes (WIP); documentation; black; test revision in progress
1 parent b01aa03 commit 8faa44e

7 files changed

Lines changed: 459 additions & 394 deletions

File tree

src/electiondata/__init__.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,10 @@ def __new__(cls, param_file: Optional[str] = None, dbname: Optional[str] = None)
294294
return super().__new__(cls)
295295

296296
def __init__(
297-
self,
298-
param_file: Optional[str] = None,
299-
dbname: Optional[str] = None,
300-
major_subdivision_file: Optional[str] = None,
297+
self,
298+
param_file: Optional[str] = None,
299+
dbname: Optional[str] = None,
300+
major_subdivision_file: Optional[str] = None,
301301
):
302302
"""
303303
Inputs:
@@ -340,7 +340,9 @@ def __init__(
340340

341341
# create analyzer with same db
342342
self.analyzer = Analyzer(
343-
dbname=dbname, param_file=param_file, major_subdivision_file=major_subdivision_file
343+
dbname=dbname,
344+
param_file=param_file,
345+
major_subdivision_file=major_subdivision_file,
344346
)
345347

346348
def connect_to_db(
@@ -406,12 +408,21 @@ def change_db(
406408
Optional[dict], error dictionary
407409
"""
408410
err = None
411+
if not db_params:
412+
# get db_params from current session, but with new database name
413+
db_params = {
414+
"host": self.db_engine.url.host,
415+
"port": self.db_engine.url.port,
416+
"user": self.db_engine.url.username,
417+
"password": self.db_engine.url.password,
418+
"dbname": new_db_name,
419+
}
409420

410421
# disconnect from current db and connect to new db, updating self.session and self.dbname
411422
self.session.close()
412423
# # nb: next command updates self.session
413424
new_err = self.connect_to_db(
414-
dbname=new_db_name, db_param_file=db_param_file, db_params=db_params
425+
db_params=db_params,
415426
)
416427
err = ui.consolidate_errors([err, new_err])
417428
if ui.fatal_error(new_err):
@@ -541,7 +552,7 @@ def load_one_from_ini(
541552
err,
542553
"warn-file",
543554
"major_subjurisdiction_types.txt",
544-
f"Rollup cannot be done because no major subdivision found for {sdl.d['jurisdiction']}"
555+
f"Rollup cannot be done because no major subdivision found for {sdl.d['jurisdiction']}",
545556
)
546557

547558
else:
@@ -1009,7 +1020,9 @@ def add_totals_if_missing(self, election: str, jurisdiction: str) -> Optional[di
10091020
)
10101021
return err
10111022

1012-
def load_data_from_db_dump(self, dbname: str, dump_file: str) -> Optional[str]:
1023+
def load_data_from_db_dump(
1024+
self, dbname: str, dump_file: str, delete_existing: bool = False
1025+
) -> Optional[str]:
10131026
"""
10141027
Inputs:
10151028
dbname: str, name for database to be created and loaded with data
@@ -1024,7 +1037,7 @@ def load_data_from_db_dump(self, dbname: str, dump_file: str) -> Optional[str]:
10241037
connection = self.session.bind.raw_connection()
10251038
cursor = connection.cursor()
10261039
err_str = db.create_database(
1027-
connection, cursor, dbname=dbname, delete_existing=False
1040+
connection, cursor, dbname=dbname, delete_existing=delete_existing
10281041
)
10291042

10301043
cursor.close()
@@ -2531,7 +2544,9 @@ def test_loaded_results(
25312544
f"\nNo data found",
25322545
)
25332546
if not self.check_totals_match_vote_types(
2534-
election, juris_true_name, sub_unit_type=self.major_subdivision_type[juris_true_name]
2547+
election,
2548+
juris_true_name,
2549+
sub_unit_type=self.major_subdivision_type[juris_true_name],
25352550
):
25362551
err = ui.add_new_error(
25372552
err,
@@ -2835,9 +2850,9 @@ def scatter(
28352850
Required inputs:
28362851
jurisdiction: str,
28372852
for horizontal axis of scatter plot:
2838-
h_election: str, election parameter
2853+
h_election: str, election parameter
28392854
h_category: str, category parameter (e.g., Population by Race or Candidate total)
2840-
h_count: str, count label parameter (e.g., "Black" or "Joseph R. Biden", depending on category)
2855+
h_count: str, count label parameter (e.g., "Black" or "Joseph R. Biden", depending on category)
28412856
for vertical axis of scatter plot (same definitions as for horizontal):
28422857
v_election: str,
28432858
v_category: str,
@@ -2923,7 +2938,7 @@ def bar(
29232938
self.major_subdivision_type[jurisdiction],
29242939
contest_district_type=contest_type,
29252940
contest_or_contest_group=contest,
2926-
for_export = False,
2941+
for_export=False,
29272942
)
29282943
if fig_type and agg_results:
29292944
for agg_result in agg_results:
@@ -2991,7 +3006,9 @@ def top_counts(
29913006
return err
29923007

29933008
def export_nist(
2994-
self, election: str, jurisdiction,
3009+
self,
3010+
election: str,
3011+
jurisdiction,
29953012
) -> Union[str, Dict[str, Any]]:
29963013
"""picks either version 1.0 (json) or version 2.0 (xml) based on value of constants.nist_version"""
29973014
if electiondata.constants.nist_version == "1.0":
@@ -3069,10 +3086,10 @@ def export_election_to_tsv(
30693086
"""
30703087
Required inputs:
30713088
target_file: str, path to file
3072-
election: str,
3089+
election: str,
30733090
Optional inputs:
3074-
jurisdiction: Optional[str] = None,
3075-
3091+
jurisdiction: Optional[str] = None,
3092+
30763093
Exports all election results from <self.session>'s database for the election <election> (and the jurisdiction
30773094
<jurisdiction>, if given) to the <target_file>. Columns exported are: "Election",
30783095
"Contest", "Selection", "Party", "ReportingUnit", "VoteType", "Count", "Preliminary"
@@ -3443,7 +3460,7 @@ def aggregate(
34433460
def pres_counts_by_vote_type_and_major_subdiv(
34443461
self, jurisdiction: str
34453462
) -> pd.DataFrame:
3446-
"""Not ready for prime time """
3463+
"""Not ready for prime time"""
34473464
# TODO return dataframe with columns jurisdiction, subdivision, year, CountItemType,
34483465
# total votes for pres in general election
34493466
group_cols = [
@@ -3778,7 +3795,7 @@ def data_exists(
37783795
bool, True if database specified by parameters in <param_file> (or database named <dbname>, if given) has
37793796
any election results data for the given <election> and <jurisdiction>. Otherwise false.
37803797
"""
3781-
analyzer = Analyzer(param_file=param_file,dbname=dbname)
3798+
analyzer = Analyzer(param_file=param_file, dbname=dbname)
37823799
return analyzer.data_exists(election, jurisdiction)
37833800

37843801

@@ -3800,7 +3817,7 @@ def external_data_exists(
38003817
bool, True if database specified by parameters in <param_file> (or database named <dbname>, if given) has
38013818
any external dataset content for the given election and jurisdiction. Otherwise false.
38023819
"""
3803-
an = Analyzer(param_file=param_file,dbname=dbname)
3820+
an = Analyzer(param_file=param_file, dbname=dbname)
38043821
if not an:
38053822
return False
38063823

@@ -4504,11 +4521,11 @@ def check_major_subdivisions(
45044521
for jurisdiction_id in db.jurisdiction_id_list(session):
45054522
# make sure jurisdiction has subdivision type
45064523
juris_name = db.name_from_id(session, "ReportingUnit", jurisdiction_id)
4507-
if (
4508-
juris_name not in major_subdiv_dict.keys()
4509-
or major_subdiv_dict[juris_name] is None
4510-
):
4511-
bad_jurisdictions.update({juris_name})
4524+
if (
4525+
juris_name not in major_subdiv_dict.keys()
4526+
or major_subdiv_dict[juris_name] is None
4527+
):
4528+
bad_jurisdictions.update({juris_name})
45124529
if bad_jurisdictions:
45134530
err_string = f"No Analyzer created, because no major subdivisions were found for these jurisdictions:\n"
45144531
f"{sorted(list(bad_jurisdictions))}"

src/electiondata/analyze/__init__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ def create_bar(
561561
0
562562
]
563563
y_party_abbr = create_party_abbreviation(y_party)
564-
jurisdiction = db.name_from_id_cursor(cursor, "ReportingUnit",jurisdiction_id)
564+
jurisdiction = db.name_from_id_cursor(cursor, "ReportingUnit", jurisdiction_id)
565565

566566
pivot_df = pd.pivot_table(
567567
temp_df, values="Count", index=["Name"], columns="Selection", fill_value=0
@@ -597,7 +597,7 @@ def create_bar(
597597
acted = "widened"
598598
results["votes_at_stake"] = f"Outlier {acted} margin by ~ {votes_at_stake}"
599599
results["margin"] = human_readable_numbers(results["margin_raw"])
600-
results["preliminary"] = db.is_preliminary(cursor,election_id,jurisdiction_id)
600+
results["preliminary"] = db.is_preliminary(cursor, election_id, jurisdiction_id)
601601

602602
# display ballot info
603603
if multiple_ballot_types:
@@ -615,8 +615,8 @@ def create_bar(
615615
results[
616616
"title"
617617
] = f"""{results["count_item_type"].replace("-", " ").title()} Ballots Reported"""
618-
download_date = db.data_file_download(cursor,election_id,jurisdiction_id)
619-
if db.is_preliminary(cursor,election_id,jurisdiction_id) and download_date:
618+
download_date = db.data_file_download(cursor, election_id, jurisdiction_id)
619+
if db.is_preliminary(cursor, election_id, jurisdiction_id) and download_date:
620620
results[
621621
"title"
622622
] = f"""{results["title"]} as of {download_date} (preliminary)"""
@@ -946,7 +946,9 @@ def create_candidate_contests(df: pd.DataFrame, columns: List[str]) -> pd.DataFr
946946
return contest_df
947947

948948

949-
def create_ballot_measure_contests(df: pd.DataFrame, columns: List[str]) -> pd.DataFrame:
949+
def create_ballot_measure_contests(
950+
df: pd.DataFrame, columns: List[str]
951+
) -> pd.DataFrame:
950952
ballotmeasure_df = (
951953
df["ContestSelectionJoin"]
952954
.merge(
@@ -994,7 +996,7 @@ def human_readable_numbers(value: float) -> str:
994996
return "{:,}".format(round(value, -3))
995997

996998

997-
def sort_pivot_by_margins(df:pd.DataFrame) -> pd.DataFrame:
999+
def sort_pivot_by_margins(df: pd.DataFrame) -> pd.DataFrame:
9981000
"""grab the row with the highest anomaly score, then sort the remainder by
9991001
margin. The sorting order depends on whether the anomalous row is >50% or <50%"""
10001002

src/electiondata/database/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def test_connection_and_tables(
251251

252252
# look for database
253253
db_df = get_database_names(con)
254-
if not db_params["dbname"] in db_df.datname:
254+
if not db_params["dbname"] in db_df.datname.unique():
255255
# NB: this is not really an error, just leads to returning False.
256256
return False, err
257257

0 commit comments

Comments
 (0)