diff --git a/api/Archive/analyze_efp_schemas.py b/api/Archive/analyze_efp_schemas.py index 3b4193da..bfd3e32e 100644 --- a/api/Archive/analyze_efp_schemas.py +++ b/api/Archive/analyze_efp_schemas.py @@ -20,9 +20,21 @@ # Extra columns that some databases have (we want to know which ones) EXTRA_COLUMNS = { - "channel", "data_call", "data_num", "data_p_val", "data_p_value", - "genome", "genome_id", "log", "orthogroup", "p_val", "project_id", - "qvalue", "sample_file_name", "sample_tissue", "version", + "channel", + "data_call", + "data_num", + "data_p_val", + "data_p_value", + "genome", + "genome_id", + "log", + "orthogroup", + "p_val", + "project_id", + "qvalue", + "sample_file_name", + "sample_tissue", + "version", } @@ -94,7 +106,9 @@ def main(): # ---- 2. Group databases by their 3-column signature ---- print("\n" + "=" * 80) - print("GROUPING BY SIGNATURE (probeset_type, probeset_nullable, signal_nullable, signal_default, bot_type, bot_nullable)") + print( + "GROUPING BY SIGNATURE (probeset_type, probeset_nullable, signal_nullable, signal_default, bot_type, bot_nullable)" + ) print("=" * 80) sig_groups = defaultdict(list) @@ -103,7 +117,9 @@ def main(): sig_groups[sig].append(db) for sig, dbs in sorted(sig_groups.items(), key=lambda x: -len(x[1])): - print(f"\n Signature: probeset={sig[0]}(nullable={sig[1]}) signal(nullable={sig[2]}, default={sig[3]}) bot={sig[4]}(nullable={sig[5]})") + print( + f"\n Signature: probeset={sig[0]}(nullable={sig[1]}) signal(nullable={sig[2]}, default={sig[3]}) bot={sig[4]}(nullable={sig[5]})" + ) print(f" Count: {len(dbs)}") print(f" DBs: {', '.join(dbs[:10])}{'...' if len(dbs) > 10 else ''}") @@ -135,15 +151,17 @@ def main(): # Determine extra columns this DB needs extras = set(cols.keys()) - NEEDED_COLUMNS - {"proj_id", "sample_id"} - compact_entries.append({ - "db": db, - "probeset_len": probeset_len, # None = tinytext - "probeset_type": probeset_type, - "bot_len": bot_len, # None = tinytext - "bot_type": bot_type, - "signal_nullable": signal_nullable, - "extras": extras, - }) + compact_entries.append( + { + "db": db, + "probeset_len": probeset_len, # None = tinytext + "probeset_type": probeset_type, + "bot_len": bot_len, # None = tinytext + "bot_type": bot_type, + "signal_nullable": signal_nullable, + "extras": extras, + } + ) # ---- 4. Show the most compact table-driven representation ---- print("\n" + "=" * 80) @@ -180,9 +198,7 @@ def main(): for e in compact_entries: # Filter out databases that ONLY have unneeded extras # (sample_file_name, data_call, data_p_val etc. are not needed) - has_important_extras = e["extras"] - { - "sample_file_name", "data_call", "data_p_val", "data_p_value", "data_num" - } + has_important_extras = e["extras"] - {"sample_file_name", "data_call", "data_p_val", "data_p_value", "data_num"} if has_important_extras: complex_dbs.append(e) else: @@ -217,11 +233,13 @@ def main(): with open(SAMPLE_DATA_CSV, newline="") as f: reader = csv.DictReader(f) for row in reader: - db_samples[row["source_database"]].append({ - "data_bot_id": row["data_bot_id"], - "data_probeset_id": row["data_probeset_id"], - "data_signal": row["data_signal"], - }) + db_samples[row["source_database"]].append( + { + "data_bot_id": row["data_bot_id"], + "data_probeset_id": row["data_probeset_id"], + "data_signal": row["data_signal"], + } + ) print(f"Total databases with sample data: {len(db_samples)}") print(f"Total sample rows: {sum(len(v) for v in db_samples.values())}") diff --git a/api/__init__.py b/api/__init__.py index 37b6526c..fb0101a2 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -42,6 +42,7 @@ def create_app(): mysql_efp_base = bar_app.config.get("MYSQL_EFP_BASE_URI") if mysql_efp_base: from api.models.efp_schemas import SIMPLE_EFP_DATABASE_SCHEMAS + binds = bar_app.config.get("SQLALCHEMY_BINDS") or {} base = mysql_efp_base.rstrip("/") for db_name in SIMPLE_EFP_DATABASE_SCHEMAS: @@ -70,9 +71,9 @@ def create_app(): # On BAR, MySQL binds come from the server config — never build SQLite mirrors there. # For CI and local dev, determine whether to build SQLite mirrors. needs_sqlite_mirrors = ( - is_ci # always build on CI - or bar_app.config.get("TESTING") # config requests test mode - or "pytest" in os.sys.modules # running under pytest + is_ci # always build on CI + or bar_app.config.get("TESTING") # config requests test mode + or "pytest" in os.sys.modules # running under pytest or os.environ.get("BAR_API_AUTO_SQLITE_MIRRORS") == "1" # explicit override ) diff --git a/api/models/bar_utils.py b/api/models/bar_utils.py index e30ab466..61ce36c6 100644 --- a/api/models/bar_utils.py +++ b/api/models/bar_utils.py @@ -1,4 +1,4 @@ # Bridge file to maintain backward compatibility with imports from api.utils.bar_utils import BARUtils -__all__ = ['BARUtils'] +__all__ = ["BARUtils"] diff --git a/api/resources/fastpheno.py b/api/resources/fastpheno.py index bc9fcfa3..cd24c91d 100644 --- a/api/resources/fastpheno.py +++ b/api/resources/fastpheno.py @@ -13,7 +13,6 @@ from api.utils.bar_utils import BARUtils from markupsafe import escape - fastpheno = Namespace("FastPheno", description="FastPheno API service", path="/fastpheno") @@ -240,9 +239,7 @@ def get(self, tree_site_id, band): class FastPhenoSites(Resource): def get(self): """Returns all sites with coordinates, for initializing the map view.""" - rows = db.session.execute( - db.select(Sites).order_by(Sites.site_name) - ).scalars().all() + rows = db.session.execute(db.select(Sites).order_by(Sites.site_name)).scalars().all() res = [ { @@ -266,11 +263,11 @@ def get(self, sites_pk): if not BARUtils.is_integer(str(sites_pk)): return BARUtils.error_exit("Invalid sites_pk"), 400 - rows = db.session.execute( - db.select(Flights) - .where(Flights.sites_pk == sites_pk) - .order_by(Flights.flight_date) - ).scalars().all() + rows = ( + db.session.execute(db.select(Flights).where(Flights.sites_pk == sites_pk).order_by(Flights.flight_date)) + .scalars() + .all() + ) if len(rows) == 0: return BARUtils.error_exit("No flights found for the given site"), 400 @@ -297,12 +294,16 @@ def get(self, flights_pk): if not BARUtils.is_integer(str(flights_pk)): return BARUtils.error_exit("Invalid flights_pk"), 400 - rows = db.session.execute( - db.select(Bands.band) - .where(Bands.flights_pk == flights_pk) - .distinct() - .order_by(db.func.cast(db.func.regexp_replace(Bands.band, "[^0-9]", ""), db.Integer)) - ).scalars().all() + rows = ( + db.session.execute( + db.select(Bands.band) + .where(Bands.flights_pk == flights_pk) + .distinct() + .order_by(db.func.cast(db.func.regexp_replace(Bands.band, "[^0-9]", ""), db.Integer)) + ) + .scalars() + .all() + ) if len(rows) == 0: return BARUtils.error_exit("No bands found for the given flight"), 400 diff --git a/api/resources/gene_expression.py b/api/resources/gene_expression.py index 3e9ab958..a1b8c7cc 100644 --- a/api/resources/gene_expression.py +++ b/api/resources/gene_expression.py @@ -14,16 +14,14 @@ ) gene_expression = Namespace( - 'Gene Expression', - description='Gene expression data from BAR eFP databases', - path='/gene_expression', + "Gene Expression", + description="Gene expression data from BAR eFP databases", + path="/gene_expression", ) @gene_expression.route("/expression//") -@gene_expression.doc( - description="Retrieve gene expression values from a specified eFP database." -) +@gene_expression.doc(description="Retrieve gene expression values from a specified eFP database.") @gene_expression.param( "gene_id", "Gene ID (e.g. AT1G01010 for Arabidopsis, or a probeset like 261585_at)", @@ -76,4 +74,4 @@ def get(self, database, gene_id): return BARUtils.error_exit(result["error"]), result.get("error_code", 500) -gene_expression.add_resource(GeneExpression, '/expression//') +gene_expression.add_resource(GeneExpression, "/expression//") diff --git a/api/resources/gene_information.py b/api/resources/gene_information.py index 23c2f97a..c0c62094 100644 --- a/api/resources/gene_information.py +++ b/api/resources/gene_information.py @@ -15,7 +15,6 @@ from api import db from sqlalchemy import func - gene_information = Namespace("Gene Information", description="Information about Genes", path="/gene_information") parser = gene_information.parser() @@ -91,9 +90,9 @@ def post(self): # Query must be run individually for each species lowered_genes = [gene.lower() for gene in genes] - rows = db.session.execute( - db.select(database).where(func.lower(database.agi).in_(lowered_genes)) - ).scalars().all() + rows = ( + db.session.execute(db.select(database).where(func.lower(database.agi).in_(lowered_genes))).scalars().all() + ) # If there are any isoforms found, return data data = [] @@ -286,9 +285,9 @@ def post(self): gene_ids = [] gene_fail = [] for one_term in terms: - query = db.select(alias_database.agi).where( - func.lower(alias_database.agi).contains(one_term.lower()) - ).limit(1) + query = ( + db.select(alias_database.agi).where(func.lower(alias_database.agi).contains(one_term.lower())).limit(1) + ) result = db.session.execute(query).fetchone() if result is not None: gene_ids.append(result[0]) diff --git a/api/resources/interactions.py b/api/resources/interactions.py index 0c104122..b57d6591 100644 --- a/api/resources/interactions.py +++ b/api/resources/interactions.py @@ -474,7 +474,7 @@ def get(self, tag=""): "image_url": ex.image_url, "grn_title": ex.grn_title, "cyjs_layout": _normalize_cyjs_layout(ex.cyjs_layout), - "tag": "|".join(_sort_tag_strings(src_tag_match[ex.source_id])) + "tag": "|".join(_sort_tag_strings(src_tag_match[ex.source_id])), } result.append(one_source[source_id]) @@ -544,7 +544,7 @@ def get(self, number=""): "url": row.url, "image_url": row.image_url, "grn_title": row.grn_title, - "cyjs_layout": _normalize_cyjs_layout(row.cyjs_layout) + "cyjs_layout": _normalize_cyjs_layout(row.cyjs_layout), } ) @@ -607,20 +607,14 @@ def get(self, stringAGI=""): "source_name": row.source_name, "comments": row.comments, "cyjs_layout": _normalize_cyjs_layout(row.cyjs_layout), - "tags": [] + "tags": [], } tag_entry = f"{row.tag_name}:{row.tag_group}" if tag_entry not in result_dict[source_id]["tags"]: # DISTINCT result_dict[source_id]["tags"].append(tag_entry) - result = [ - { - **data, - "tags": "|".join(_sort_tag_strings(data["tags"])) - } - for data in result_dict.values() - ] + result = [{**data, "tags": "|".join(_sort_tag_strings(data["tags"]))} for data in result_dict.values()] result.sort(key=lambda item: (item["grn_title"] or "")) if len(result) == 0: @@ -681,7 +675,7 @@ def get(self, AGI_1="", AGI_2=""): "source_name": row.source_name, "comments": row.comments, "cyjs_layout": _normalize_cyjs_layout(row.cyjs_layout), - "tags": [] + "tags": [], } tag_entry = f"{row.tag_name}:{row.tag_group}" @@ -689,10 +683,7 @@ def get(self, AGI_1="", AGI_2=""): result_dict[source_id]["tags"].append(tag_entry) result = [ - { - **data, - "tags": "|".join(_sort_tag_strings_natural_case_sensitive(data["tags"])) - } + {**data, "tags": "|".join(_sort_tag_strings_natural_case_sensitive(data["tags"]))} for data in result_dict.values() ] result.sort(key=lambda item: item["source_id"]) @@ -746,7 +737,7 @@ def get(self): "image_url": ex.image_url, "grn_title": ex.grn_title, "cyjs_layout": _normalize_cyjs_layout(ex.cyjs_layout), - "tag": "|".join(_sort_tag_strings(src_tag_match[ex.source_id])) + "tag": "|".join(_sort_tag_strings(src_tag_match[ex.source_id])), } result.append(one_source[source_id]) diff --git a/api/resources/llama3.py b/api/resources/llama3.py index 54bcb757..c6ec4af8 100644 --- a/api/resources/llama3.py +++ b/api/resources/llama3.py @@ -10,7 +10,6 @@ from api import db from api.models.llama3 import Summaries - llama3 = Namespace("LLaMA", description="Endpoint for retreiving LLaMA3 results", path="/LLaMA") diff --git a/api/resources/microarray_gene_expression.py b/api/resources/microarray_gene_expression.py index 566f86b5..c308f925 100644 --- a/api/resources/microarray_gene_expression.py +++ b/api/resources/microarray_gene_expression.py @@ -77,7 +77,7 @@ def get(self, species=""): "Bud_Development": "actinidia_bud_development", "Flower_Fruit_Development": "actinidia_flower_fruit_development", "Postharvest": "actinidia_postharvest", - "Vegetative_Growth": "actinidia_vegetative_growth" + "Vegetative_Growth": "actinidia_vegetative_growth", }, "arabidopsis": { "Abiotic_Stress": "atgenexp_stress", @@ -105,68 +105,41 @@ def get(self, species=""): "Shoot_Apex": "shoot_apex", "Silique": "silique", "Single_Cell": "single_cell", - "Tissue_Specific": "atgenexp_plus" - }, - "arabidopsis seedcoat": { - "Seed_Coat": "seedcoat" - }, - "arachis": { - "Arachis_Atlas": "arachis" - }, - "barley": { - "barley_mas": "barley_mas", - "barley_rma": "barley_rma" + "Tissue_Specific": "atgenexp_plus", }, + "arabidopsis seedcoat": {"Seed_Coat": "seedcoat"}, + "arachis": {"Arachis_Atlas": "arachis"}, + "barley": {"barley_mas": "barley_mas", "barley_rma": "barley_rma"}, "brachypodium": { "Brachypodium_Atlas": "brachypodium", "Brachypodium_Grains": "brachypodium_grains", "Brachypodium_Spikes": "brachypodium_Bd21", - "Photo_Thermocycle": "brachypodium_photo_thermocycle" - }, - "brassica rapa": { - "Embryogenesis": "brassica_rapa" + "Photo_Thermocycle": "brachypodium_photo_thermocycle", }, + "brassica rapa": {"Embryogenesis": "brassica_rapa"}, "cacao ccn": { "Developmental_Atlas": "cacao_developmental_atlas", - "Drought_Diurnal_Atlas": "cacao_drought_diurnal_atlas" + "Drought_Diurnal_Atlas": "cacao_drought_diurnal_atlas", }, "cacao sca": { "Developmental_Atlas": "cacao_developmental_atlas_sca", "Drought_Diurnal_Atlas": "cacao_drought_diurnal_atlas_sca", "Meristem_Atlas": "cacao_meristem_atlas_sca", - "Seed_Atlas": "cacao_seed_atlas_sca" - }, - "cacao tc": { - "Cacao_Infection": "cacao_infection", - "Cacao_Leaf": "cacao_leaf" - }, - "camelina": { - "Developmental_Atlas_FPKM": "camelina", - "Developmental_Atlas_TPM": "camelina_tpm" - }, - "cannabis": { - "Cannabis_Atlas": "cannabis" - }, - "canola": { - "Canola_Seed": "canola_seed" - }, - "eutrema": { - "Eutrema": "thellungiella_db" - }, - "grape": { - "grape_developmental": "grape_developmental" - }, - "kalanchoe": { - "Light_Response": "kalanchoe" - }, - "little millet": { - "Life_Cycle": "little_millet" + "Seed_Atlas": "cacao_seed_atlas_sca", }, + "cacao tc": {"Cacao_Infection": "cacao_infection", "Cacao_Leaf": "cacao_leaf"}, + "camelina": {"Developmental_Atlas_FPKM": "camelina", "Developmental_Atlas_TPM": "camelina_tpm"}, + "cannabis": {"Cannabis_Atlas": "cannabis"}, + "canola": {"Canola_Seed": "canola_seed"}, + "eutrema": {"Eutrema": "thellungiella_db"}, + "grape": {"grape_developmental": "grape_developmental"}, + "kalanchoe": {"Light_Response": "kalanchoe"}, + "little millet": {"Life_Cycle": "little_millet"}, "lupin": { "LCM_Leaf": "lupin_lcm_leaf", "LCM_Pod": "lupin_lcm_pod", "LCM_Stem": "lupin_lcm_stem", - "Whole_Plant": "lupin_whole_plant" + "Whole_Plant": "lupin_whole_plant", }, "maize": { "Downs_et_al_Atlas": "maize_gdowns", @@ -180,7 +153,7 @@ def get(self, species=""): "Tassel_and_Ear_Primordia": "maize_ears", "maize_iplant": "maize_iplant", "maize_leaf_gradient": "maize_leaf_gradient", - "maize_rice_comparison": "maize_rice_comparison" + "maize_rice_comparison": "maize_rice_comparison", }, "mangosteen": { "Aril_vs_Rind": "mangosteen_aril_vs_rind", @@ -188,21 +161,15 @@ def get(self, species=""): "Diseased_vs_Normal": "mangosteen_diseased_vs_normal", "Fruit_Ripening": "mangosteen_fruit_ripening", "Seed_Development": "mangosteen_seed_development", - "Seed_Germination": "mangosteen_seed_germination" + "Seed_Germination": "mangosteen_seed_germination", }, "medicago": { "medicago_mas": "medicago_mas", "medicago_rma": "medicago_rma", - "medicago_seed": "medicago_seed" - }, - "poplar": { - "Poplar": "poplar", - "PoplarTreatment": "poplar" - }, - "potato": { - "Potato_Developmental": "potato_dev", - "Potato_Stress": "potato_stress" + "medicago_seed": "medicago_seed", }, + "poplar": {"Poplar": "poplar", "PoplarTreatment": "poplar"}, + "potato": {"Potato_Developmental": "potato_dev", "Potato_Stress": "potato_stress"}, "rice": { "rice_drought_heat_stress": "rice_drought_heat_stress", "rice_leaf_gradient": "rice_leaf_gradient", @@ -214,18 +181,18 @@ def get(self, species=""): "ricestigma_mas": "rice_mas", "ricestigma_rma": "rice_rma", "ricestress_mas": "rice_mas", - "ricestress_rma": "rice_rma" + "ricestress_rma": "rice_rma", }, "soybean": { "soybean": "soybean", "soybean_embryonic_development": "soybean_embryonic_development", "soybean_heart_cotyledon_globular": "soybean_heart_cotyledon_globular", "soybean_senescence": "soybean_senescence", - "soybean_severin": "soybean_severin" + "soybean_severin": "soybean_severin", }, "strawberry": { "Developmental_Map_Strawberry_Flower_and_Fruit": "strawberry", - "Strawberry_Green_vs_White_Stage": "strawberry" + "Strawberry_Green_vs_White_Stage": "strawberry", }, "tomato": { "ILs_Leaf_Chitwood_et_al": "tomato_ils", @@ -236,27 +203,21 @@ def get(self, species=""): "SEED_Lab_Angers": "tomato_seed", "Shade_Mutants": "tomato_shade_mutants", "Shade_Timecourse_WT": "tomato_shade_timecourse", - "Tomato_Meristem": "tomato_meristem" - }, - "triticale": { - "triticale": "triticale", - "triticale_mas": "triticale_mas" + "Tomato_Meristem": "tomato_meristem", }, + "triticale": {"triticale": "triticale", "triticale_mas": "triticale_mas"}, "wheat": { "Developmental_Atlas": "wheat", "Wheat_Abiotic_Stress": "wheat_abiotic_stress", "Wheat_Embryogenesis": "wheat_embryogenesis", - "Wheat_Meiosis": "wheat_meiosis" - } + "Wheat_Meiosis": "wheat_meiosis", + }, } if species not in species_databases: return BARUtils.error_exit("Invalid species") - return BARUtils.success_exit({ - "species": species, - "databases": species_databases[species] - }) + return BARUtils.success_exit({"species": species, "databases": species_databases[species]}) # endpoint made by reena @@ -285,17 +246,12 @@ def get(self, species="", view=""): # if user requests all views if view.lower() == "all": - return BARUtils.success_exit({ - "species": species, - "views": species_data["views"] - }) + return BARUtils.success_exit({"species": species, "views": species_data["views"]}) # otherwise check single view if view not in species_data["views"]: return BARUtils.error_exit("Invalid view for this species") - return BARUtils.success_exit({ - "species": species, - "view": view, - "groups": species_data["views"][view]["groups"] - }) + return BARUtils.success_exit( + {"species": species, "view": view, "groups": species_data["views"][view]["groups"]} + ) diff --git a/api/resources/rnaseq_gene_expression.py b/api/resources/rnaseq_gene_expression.py index 205ba96d..6a3e4253 100644 --- a/api/resources/rnaseq_gene_expression.py +++ b/api/resources/rnaseq_gene_expression.py @@ -28,10 +28,7 @@ } # metadata mirrors the schema catalog so validation stays in sync -DATABASE_METADATA = { - name: spec.get("metadata") or {} - for name, spec in SIMPLE_EFP_DATABASE_SCHEMAS.items() -} +DATABASE_METADATA = {name: spec.get("metadata") or {} for name, spec in SIMPLE_EFP_DATABASE_SCHEMAS.items()} # this is only needed for swagger ui post examples gene_expression_request_fields = rnaseq_gene_expression.model( diff --git a/api/resources/snps.py b/api/resources/snps.py index b026b737..0aaf0013 100644 --- a/api/resources/snps.py +++ b/api/resources/snps.py @@ -33,7 +33,6 @@ from api import db, cache, limiter from api.utils.docking_utils import Docker - snps = Namespace("SNPs", description="Information about SNPs", path="/snps") parser = snps.parser() diff --git a/api/services/efp_bootstrap.py b/api/services/efp_bootstrap.py index b2278f4e..9320d1dd 100644 --- a/api/services/efp_bootstrap.py +++ b/api/services/efp_bootstrap.py @@ -148,11 +148,13 @@ def ensure_database(server_url: URL, db_name: str, charset: str) -> None: :raises ValueError: If db_name or charset contains invalid characters """ # Validate database name to prevent SQL injection - only allow safe identifier characters - if not re.match(r'^[a-zA-Z0-9_$]+$', db_name): - raise ValueError(f"Invalid database name: {db_name}. Only alphanumeric, underscore, and dollar sign characters are allowed.") + if not re.match(r"^[a-zA-Z0-9_$]+$", db_name): + raise ValueError( + f"Invalid database name: {db_name}. Only alphanumeric, underscore, and dollar sign characters are allowed." + ) # Validate charset name to prevent SQL injection - only allow safe characters - if not re.match(r'^[a-zA-Z0-9_]+$', charset): + if not re.match(r"^[a-zA-Z0-9_]+$", charset): raise ValueError(f"Invalid charset name: {charset}. Only alphanumeric and underscore characters are allowed.") server_engine = create_engine(server_url) diff --git a/api/services/efp_data.py b/api/services/efp_data.py index 0bc927a7..d03bf5f6 100644 --- a/api/services/efp_data.py +++ b/api/services/efp_data.py @@ -251,10 +251,9 @@ def _iter_engine_candidates(database: str) -> Iterable[Tuple[str, Engine, bool]] # gene lookups use the index rather than a full table scan try: with sqlite_engine.begin() as _conn: - _conn.execute(text( - "CREATE INDEX IF NOT EXISTS ix_upper_probeset " - "ON sample_data (UPPER(data_probeset_id))" - )) + _conn.execute( + text("CREATE INDEX IF NOT EXISTS ix_upper_probeset " "ON sample_data (UPPER(data_probeset_id))") + ) except Exception: pass # read-only db or schema mismatch — best-effort yield ("sqlite_mirror", sqlite_engine, True) @@ -392,7 +391,7 @@ def query_efp_database_dynamic( (value_col, "value_column"), (table_name, "table"), ]: - if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', identifier): + if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", identifier): return { "success": False, "error": f"Invalid schema identifier for {name}: {identifier}", @@ -458,8 +457,7 @@ def query_efp_database_dynamic( return { "success": False, "error": ( - f"Database query failed for {database}. " - f"{'Last error: ' + last_error if last_error else ''}" + f"Database query failed for {database}. " f"{'Last error: ' + last_error if last_error else ''}" ).strip(), "error_code": 500, } diff --git a/api/utils/gene_id_utils.py b/api/utils/gene_id_utils.py index 0998f271..485b56ec 100644 --- a/api/utils/gene_id_utils.py +++ b/api/utils/gene_id_utils.py @@ -320,62 +320,64 @@ def is_probeset_id(gene_id: str) -> bool: # to one of these databases requires a gene→probeset conversion step. # --------------------------------------------------------------------------- -PROBESET_DATABASES: frozenset[str] = frozenset({ - # ── Arabidopsis ATH1 GeneChip ──────────────────────────────────────────── - "affydb", - "arabidopsis_ecotypes", - "atgenexp", - "atgenexp_hormone", - "atgenexp_pathogen", - "atgenexp_plus", - "atgenexp_stress", - "guard_cell", - "hnahal", - "lateral_root_initiation", - "light_series", - "meristem_db", - "meristem_db_new", - "root", - "rohan", - "rpatel", - "seed_db", - # ── Other species Affymetrix chips ─────────────────────────────────────── - # Lookup tables for these are pending; supply probeset directly (e.g. Contig3267_at) - "barley_mas", # Affymetrix Barley1 GeneChip (AK364622 → Contig3045_at) - "barley_rma", # Affymetrix Barley1 GeneChip - "human_developmental", # CCR5 → 206991_s_at - "human_developmental_SpongeLab", - "human_diseased", - "maize_gdowns", # Affymetrix Maize GeneChip (gene → Zm.XXXXX probeset IDs) - "medicago_mas", # Affymetrix Medicago GeneChip (Medtr* → Mtr.*_at) - "medicago_rma", # Affymetrix Medicago GeneChip - "poplar", # Affymetrix Poplar GeneChip (grail3.* → PtpAffx.*_at) - "rice_mas", # Affymetrix Rice GeneChip (LOC_Os* → Os.*_at) - "rice_rma", # Affymetrix Rice GeneChip - "triticale", # Affymetrix Wheat/Triticale GeneChip (EU* → Ta.*_at) - "triticale_mas", # Affymetrix Wheat/Triticale GeneChip - # ── Non-Affymetrix species with gene ID ≠ probeset ID ──────────────────── - # Confirmed from eFP browser: input gene ID differs from the stored probeset. - # Lookup tables pending for all of these. - "grape_developmental", # VIT_00s0120g00060 → CHRUN_JGVV120_4_T01 - "potato_dev", # PGSC0003DMP400000011 → PGSC0003DMG400000005 - "potato_stress", # same DMP → DMG gene model conversion - "potato_wounding", # same DMP → DMG gene model conversion - "soybean", # Glyma.06g316600 (new) → Glyma06g47400 (old format) - "soybean_embryonic_development", - "soybean_heart_cotyledon_globular", - "soybean_senescence", - "soybean_severin", - # ── Cross-species: input is an Arabidopsis AGI, stored as species gene ID ─ - # See CROSS_SPECIES_DATABASES below for the input validation override. - "phelipanche", # AT1G07890 → OrAeBC5_10.1 - "striga", # AT3G11400 → StHeBC3_1.1 - "thellungiella_db", # AT2G21470 → Thhalv10000089m.g - "triphysaria", # AT1G11260 → TrVeBC3_1.1 - # TODO: seedcoat (oat) – has species-specific probeset IDs; add once format confirmed - # TODO: strawberry – gene10171 → FvH4_1g00010; lookup table needed - # TODO: physcomitrella – Phypa_166136 → Pp1s103_79V6.1; lookup table needed -}) +PROBESET_DATABASES: frozenset[str] = frozenset( + { + # ── Arabidopsis ATH1 GeneChip ──────────────────────────────────────────── + "affydb", + "arabidopsis_ecotypes", + "atgenexp", + "atgenexp_hormone", + "atgenexp_pathogen", + "atgenexp_plus", + "atgenexp_stress", + "guard_cell", + "hnahal", + "lateral_root_initiation", + "light_series", + "meristem_db", + "meristem_db_new", + "root", + "rohan", + "rpatel", + "seed_db", + # ── Other species Affymetrix chips ─────────────────────────────────────── + # Lookup tables for these are pending; supply probeset directly (e.g. Contig3267_at) + "barley_mas", # Affymetrix Barley1 GeneChip (AK364622 → Contig3045_at) + "barley_rma", # Affymetrix Barley1 GeneChip + "human_developmental", # CCR5 → 206991_s_at + "human_developmental_SpongeLab", + "human_diseased", + "maize_gdowns", # Affymetrix Maize GeneChip (gene → Zm.XXXXX probeset IDs) + "medicago_mas", # Affymetrix Medicago GeneChip (Medtr* → Mtr.*_at) + "medicago_rma", # Affymetrix Medicago GeneChip + "poplar", # Affymetrix Poplar GeneChip (grail3.* → PtpAffx.*_at) + "rice_mas", # Affymetrix Rice GeneChip (LOC_Os* → Os.*_at) + "rice_rma", # Affymetrix Rice GeneChip + "triticale", # Affymetrix Wheat/Triticale GeneChip (EU* → Ta.*_at) + "triticale_mas", # Affymetrix Wheat/Triticale GeneChip + # ── Non-Affymetrix species with gene ID ≠ probeset ID ──────────────────── + # Confirmed from eFP browser: input gene ID differs from the stored probeset. + # Lookup tables pending for all of these. + "grape_developmental", # VIT_00s0120g00060 → CHRUN_JGVV120_4_T01 + "potato_dev", # PGSC0003DMP400000011 → PGSC0003DMG400000005 + "potato_stress", # same DMP → DMG gene model conversion + "potato_wounding", # same DMP → DMG gene model conversion + "soybean", # Glyma.06g316600 (new) → Glyma06g47400 (old format) + "soybean_embryonic_development", + "soybean_heart_cotyledon_globular", + "soybean_senescence", + "soybean_severin", + # ── Cross-species: input is an Arabidopsis AGI, stored as species gene ID ─ + # See CROSS_SPECIES_DATABASES below for the input validation override. + "phelipanche", # AT1G07890 → OrAeBC5_10.1 + "striga", # AT3G11400 → StHeBC3_1.1 + "thellungiella_db", # AT2G21470 → Thhalv10000089m.g + "triphysaria", # AT1G11260 → TrVeBC3_1.1 + # TODO: seedcoat (oat) – has species-specific probeset IDs; add once format confirmed + # TODO: strawberry – gene10171 → FvH4_1g00010; lookup table needed + # TODO: physcomitrella – Phypa_166136 → Pp1s103_79V6.1; lookup table needed + } +) # --------------------------------------------------------------------------- # Cross-species input databases @@ -390,10 +392,10 @@ def is_probeset_id(gene_id: str) -> bool: CROSS_SPECIES_DATABASES: dict[str, str] = { # database → species of the expected INPUT gene ID - "phelipanche": "arabidopsis", # AT* AGI → OrAeBC5_* probeset - "striga": "arabidopsis", # AT* AGI → StHeBC3_* probeset - "thellungiella_db": "arabidopsis", # AT* AGI → Thhalv* probeset - "triphysaria": "arabidopsis", # AT* AGI → TrVeBC3_* probeset + "phelipanche": "arabidopsis", # AT* AGI → OrAeBC5_* probeset + "striga": "arabidopsis", # AT* AGI → StHeBC3_* probeset + "thellungiella_db": "arabidopsis", # AT* AGI → Thhalv* probeset + "triphysaria": "arabidopsis", # AT* AGI → TrVeBC3_* probeset } @@ -401,6 +403,7 @@ def is_probeset_id(gene_id: str) -> bool: # Species detection from gene ID format # --------------------------------------------------------------------------- + def detect_gene_species(gene_id: str) -> Optional[str]: """Infer the species of *gene_id* from its format using regex validators. @@ -466,25 +469,25 @@ def detect_gene_species(gene_id: str) -> Optional[str]: # Dispatch table: canonical species key → BARUtils validator _VALIDATORS: dict = { - "arabidopsis": BARUtils.is_arabidopsis_gene_valid, - "arachis": BARUtils.is_arachis_gene_valid, - "brassica": BARUtils.is_brassica_rapa_gene_valid, - "cannabis": BARUtils.is_cannabis_gene_valid, - "canola": BARUtils.is_canola_gene_valid, - "grape": BARUtils.is_grape_gene_valid, - "kalanchoe": BARUtils.is_kalanchoe_gene_valid, - "maize": BARUtils.is_maize_gene_valid, - "phelipanche": BARUtils.is_phelipanche_gene_valid, + "arabidopsis": BARUtils.is_arabidopsis_gene_valid, + "arachis": BARUtils.is_arachis_gene_valid, + "brassica": BARUtils.is_brassica_rapa_gene_valid, + "cannabis": BARUtils.is_cannabis_gene_valid, + "canola": BARUtils.is_canola_gene_valid, + "grape": BARUtils.is_grape_gene_valid, + "kalanchoe": BARUtils.is_kalanchoe_gene_valid, + "maize": BARUtils.is_maize_gene_valid, + "phelipanche": BARUtils.is_phelipanche_gene_valid, "physcomitrella": BARUtils.is_physcomitrella_gene_valid, - "poplar": BARUtils.is_poplar_gene_valid, - "rice": BARUtils.is_rice_gene_valid, - "selaginella": BARUtils.is_selaginella_gene_valid, - "sorghum": BARUtils.is_sorghum_gene_valid, - "soybean": BARUtils.is_soybean_gene_valid, - "strawberry": BARUtils.is_strawberry_gene_valid, - "striga": BARUtils.is_striga_gene_valid, - "tomato": BARUtils.is_tomato_gene_valid, - "triphysaria": BARUtils.is_triphysaria_gene_valid, + "poplar": BARUtils.is_poplar_gene_valid, + "rice": BARUtils.is_rice_gene_valid, + "selaginella": BARUtils.is_selaginella_gene_valid, + "sorghum": BARUtils.is_sorghum_gene_valid, + "soybean": BARUtils.is_soybean_gene_valid, + "strawberry": BARUtils.is_strawberry_gene_valid, + "striga": BARUtils.is_striga_gene_valid, + "tomato": BARUtils.is_tomato_gene_valid, + "triphysaria": BARUtils.is_triphysaria_gene_valid, } @@ -550,6 +553,7 @@ def normalize_gene_id(gene_id: str, species: str) -> str: # Gene ID → probeset conversion # --------------------------------------------------------------------------- + def convert_gene_to_probeset( gene_id: str, species: str, @@ -607,6 +611,7 @@ def convert_gene_to_probeset( if species == "arabidopsis": # Lazy import avoids circular dependencies at module load time from api.services.efp_data import EFPDataService # noqa: PLC0415 + probeset = EFPDataService.agi_to_probst(gene_id.upper()) if probeset: return probeset, None diff --git a/scripts/benchmark_efp.py b/scripts/benchmark_efp.py index a3f0bbc7..acaa5a11 100644 --- a/scripts/benchmark_efp.py +++ b/scripts/benchmark_efp.py @@ -35,6 +35,7 @@ from typing import Any, Dict, List, Optional, Tuple import matplotlib + matplotlib.use("Agg") # non-interactive — safe in CI and SSH sessions @@ -93,6 +94,7 @@ # Dump parsing helpers # =========================================================================== + def _parse_tuple_fields(raw: str) -> List[str]: """Split a raw SQL tuple string on ',' and strip quotes from each field.""" return [f.strip().strip("'\"") for f in raw.split(",")] @@ -158,6 +160,7 @@ def count_dump_samples(db_name: str) -> int: # Part 1 — Model generation benchmark (flat vs dynamic) # =========================================================================== + def _simulate_flat_class_creation(db_names: List[str]) -> float: """Time the overhead of defining flat-file model classes via type(). @@ -171,13 +174,17 @@ def _simulate_flat_class_creation(db_names: List[str]) -> float: t0 = time.perf_counter() for db_name in db_names: class_name = "".join(p.capitalize() for p in db_name.split("_")) + "SampleData" - type(class_name, (), { - "__bind_key__": db_name, - "__tablename__": "sample_data", - "data_probeset_id": None, - "data_signal": None, - "data_bot_id": None, - }) + type( + class_name, + (), + { + "__bind_key__": db_name, + "__tablename__": "sample_data", + "data_probeset_id": None, + "data_signal": None, + "data_bot_id": None, + }, + ) return time.perf_counter() - t0 @@ -252,6 +259,7 @@ def _stats(times: List[float]) -> Dict[str, float]: # Part 2 — RAM benchmark # =========================================================================== + def _rss_kb() -> float: """Return current peak RSS in KB (macOS returns bytes; Linux returns KB).""" usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss @@ -281,13 +289,17 @@ def benchmark_memory() -> Dict[str, Any]: flat_models: Dict[str, Any] = {} for db_name in db_names: class_name = "".join(p.capitalize() for p in db_name.split("_")) + "FlatSD" - flat_models[db_name] = type(class_name, (), { - "__bind_key__": db_name, - "__tablename__": "sample_data", - "data_probeset_id": None, - "data_signal": None, - "data_bot_id": None, - }) + flat_models[db_name] = type( + class_name, + (), + { + "__bind_key__": db_name, + "__tablename__": "sample_data", + "data_probeset_id": None, + "data_signal": None, + "data_bot_id": None, + }, + ) snap1 = tracemalloc.take_snapshot() rss1 = _rss_kb() @@ -401,6 +413,7 @@ def benchmark_api( :return: Nested dict: {db_name: {avg_ms, min_ms, max_ms, genes_tested, success_rate, first_error}}. """ import time as _time + results: Dict[str, Any] = {} # A single persistent session reuses the TCP+TLS connection across all requests # (HTTP keep-alive). This is especially important for ngrok where each new @@ -652,24 +665,17 @@ def verify_values_vs_bar( continue try: - local_resp = requests.get( - f"{base_url}/gene_expression/expression/{db_name}/{gene}", timeout=10 - ) + local_resp = requests.get(f"{base_url}/gene_expression/expression/{db_name}/{gene}", timeout=10) local_obj = local_resp.json() if not local_obj.get("success"): print(f" {gene}: local error — {local_obj.get('error')}") continue - local_vals: Dict[str, float] = { - row["name"]: float(row["value"]) for row in local_obj.get("data", []) - } + local_vals: Dict[str, float] = {row["name"]: float(row["value"]) for row in local_obj.get("data", [])} except Exception as exc: print(f" {gene}: local request failed — {exc}") continue - mismatches = sum( - 1 for s, bv in bar_vals.items() - if s in local_vals and abs(bv - local_vals[s]) > tol - ) + mismatches = sum(1 for s, bv in bar_vals.items() if s in local_vals and abs(bv - local_vals[s]) > tol) missing_in_local = sum(1 for s in bar_vals if s not in local_vals) extra_in_local = sum(1 for s in local_vals if s not in bar_vals) @@ -680,8 +686,10 @@ def verify_values_vs_bar( else: status = f"COUNT DIFF (missing={missing_in_local} extra={extra_in_local})" - print(f" {gene}: BAR={len(bar_vals)} samples local={len(local_vals)} samples " - f"mismatches={mismatches} [{status}]") + print( + f" {gene}: BAR={len(bar_vals)} samples local={len(local_vals)} samples " + f"mismatches={mismatches} [{status}]" + ) # =========================================================================== @@ -725,19 +733,23 @@ def plot_model_generation(results: Dict[str, Any]) -> None: # error bars show min/max range lower = [avgs[i] - mins[i] for i in range(len(avgs))] upper = [maxs[i] - avgs[i] for i in range(len(avgs))] - ax.errorbar(x, avgs, yerr=[lower, upper], fmt="none", color="black", - capsize=6, linewidth=1.5, zorder=4) + ax.errorbar(x, avgs, yerr=[lower, upper], fmt="none", color="black", capsize=6, linewidth=1.5, zorder=4) for bar, val in zip(bars, avgs): - ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005, - f"{val:.3f} ms", ha="center", va="bottom", fontsize=9) + ax.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.005, + f"{val:.3f} ms", + ha="center", + va="bottom", + fontsize=9, + ) ax.set_xticks(list(x)) ax.set_xticklabels(labels, fontsize=10) ax.set_ylabel("Time (ms per full iteration)", fontsize=10) ax.set_title( - f"Model Creation Time — {results['model_count']} databases, " - f"{results['iterations']} iterations", + f"Model Creation Time — {results['model_count']} databases, " f"{results['iterations']} iterations", fontsize=11, ) ax.yaxis.grid(True, linestyle="--", alpha=0.6) @@ -764,8 +776,7 @@ def plot_memory(mem: Dict[str, Any]) -> None: for bar in list(b1) + list(b2): h = bar.get_height() - ax.text(bar.get_x() + bar.get_width() / 2, h + 0.5, - f"{h:.1f}", ha="center", va="bottom", fontsize=8) + ax.text(bar.get_x() + bar.get_width() / 2, h + 0.5, f"{h:.1f}", ha="center", va="bottom", fontsize=8) ax.set_xticks(list(x)) ax.set_xticklabels(categories, fontsize=10) @@ -791,6 +802,7 @@ def plot_query_times( :param legacy_cgi_results: Output of benchmark_legacy_efp_cgi(). :param databases: Ordered list of database names to include. """ + def _get_avg(result_dict: Dict[str, Any], db: str) -> float: return result_dict.get(db, {}).get("avg_ms", 0.0) @@ -816,8 +828,15 @@ def _get_avg(result_dict: Dict[str, Any], db: str) -> float: bars = ax.bar(positions, avgs, width=width * 0.95, label=label, color=color, zorder=3) for bar, val in zip(bars, avgs): if val > 0: - ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 2, - f"{val:.0f}", ha="center", va="bottom", fontsize=7, rotation=90) + ax.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 2, + f"{val:.0f}", + ha="center", + va="bottom", + fontsize=7, + rotation=90, + ) ax.set_xticks(list(x)) ax.set_xticklabels(databases, fontsize=9, rotation=20, ha="right") @@ -876,8 +895,7 @@ def plot_query_distributions( ax.set_visible(False) continue - bp = ax.boxplot(data_sets, patch_artist=True, notch=False, - medianprops={"color": "black", "linewidth": 1.5}) + bp = ax.boxplot(data_sets, patch_artist=True, notch=False, medianprops={"color": "black", "linewidth": 1.5}) for patch, color in zip(bp["boxes"], colors_bp): patch.set_facecolor(color) patch.set_alpha(0.75) @@ -896,23 +914,27 @@ def plot_query_distributions( # Main # =========================================================================== + def main() -> None: parser = argparse.ArgumentParser(description="Benchmark EFP flat vs dynamic models") - parser.add_argument("--local-url", default="http://localhost:5000", - help="Base URL of your local BAR API server") - parser.add_argument("--ngrok-url", default=None, - help="ngrok tunnel URL (e.g. https://xxxx.ngrok-free.app) " - "for internet-latency-aware benchmarking") - parser.add_argument("--iterations", type=int, default=50, - help="Iterations for model-generation microbenchmark (default 50)") - parser.add_argument("--query-genes", type=int, default=20, - help="Number of gene IDs per DB to sample for query benchmarks (default 20)") - parser.add_argument("--query-iters", type=int, default=3, - help="Repeat each gene query this many times (default 3)") - parser.add_argument("--skip-bar", action="store_true", - help="Skip remote BAR CGI requests") - parser.add_argument("--seed", type=int, default=42, - help="Random seed for gene sampling (default 42)") + parser.add_argument("--local-url", default="http://localhost:5000", help="Base URL of your local BAR API server") + parser.add_argument( + "--ngrok-url", + default=None, + help="ngrok tunnel URL (e.g. https://xxxx.ngrok-free.app) " "for internet-latency-aware benchmarking", + ) + parser.add_argument( + "--iterations", type=int, default=50, help="Iterations for model-generation microbenchmark (default 50)" + ) + parser.add_argument( + "--query-genes", + type=int, + default=20, + help="Number of gene IDs per DB to sample for query benchmarks (default 20)", + ) + parser.add_argument("--query-iters", type=int, default=3, help="Repeat each gene query this many times (default 3)") + parser.add_argument("--skip-bar", action="store_true", help="Skip remote BAR CGI requests") + parser.add_argument("--seed", type=int, default=42, help="Random seed for gene sampling (default 42)") args = parser.parse_args() random.seed(args.seed) @@ -931,13 +953,12 @@ def main() -> None: for db_name in databases: all_genes = extract_genes_from_dump(db_name) sample_count = count_dump_samples(db_name) - genes_per_db[db_name] = ( - random.sample(all_genes, min(args.query_genes, len(all_genes))) - if all_genes else [] - ) + genes_per_db[db_name] = random.sample(all_genes, min(args.query_genes, len(all_genes))) if all_genes else [] genes_per_db[db_name].sort() - print(f" {db_name:15s} genes={len(all_genes):6d} total_rows={sample_count:8d} " - f"query_sample={len(genes_per_db[db_name])}") + print( + f" {db_name:15s} genes={len(all_genes):6d} total_rows={sample_count:8d} " + f"query_sample={len(genes_per_db[db_name])}" + ) if not databases: print(" [ERROR] No dump files found in api/Archive/. Exiting.") @@ -950,10 +971,11 @@ def main() -> None: gen_results = benchmark_model_generation(args.iterations) flat_d = gen_results["flat"] dynamic_d = gen_results["dynamic"] - print(f" Flat-file avg={flat_d['avg_ms']:.3f} ms " - f"min={flat_d['min_ms']:.3f} max={flat_d['max_ms']:.3f}") - print(f" Dynamic avg={dynamic_d['avg_ms']:.3f} ms " - f"min={dynamic_d['min_ms']:.3f} max={dynamic_d['max_ms']:.3f}") + print(f" Flat-file avg={flat_d['avg_ms']:.3f} ms " f"min={flat_d['min_ms']:.3f} max={flat_d['max_ms']:.3f}") + print( + f" Dynamic avg={dynamic_d['avg_ms']:.3f} ms " + f"min={dynamic_d['min_ms']:.3f} max={dynamic_d['max_ms']:.3f}" + ) print(f" Models: {gen_results['model_count']}") ratio = dynamic_d["avg_ms"] / max(flat_d["avg_ms"], 1e-9) winner = "flat-file" if flat_d["avg_ms"] < dynamic_d["avg_ms"] else "dynamic" @@ -964,24 +986,30 @@ def main() -> None: # ------------------------------------------------------------------ print("\n[2] Memory usage") mem_results = benchmark_memory() - print(f" Flat-file heap={mem_results['flat_heap_kb']:.1f} KB " - f"RSS delta={mem_results['flat_rss_kb']:.1f} KB") - print(f" Dynamic heap={mem_results['dynamic_heap_kb']:.1f} KB " - f"RSS delta={mem_results['dynamic_rss_kb']:.1f} KB") + print(f" Flat-file heap={mem_results['flat_heap_kb']:.1f} KB " f"RSS delta={mem_results['flat_rss_kb']:.1f} KB") + print( + f" Dynamic heap={mem_results['dynamic_heap_kb']:.1f} KB " + f"RSS delta={mem_results['dynamic_rss_kb']:.1f} KB" + ) # ------------------------------------------------------------------ # 3. Local API query benchmark # ------------------------------------------------------------------ print(f"\n[3] Local API query benchmark ({args.local_url})") local_results = benchmark_api( - "local", args.local_url, databases, genes_per_db, + "local", + args.local_url, + databases, + genes_per_db, iterations=args.query_iters, ) for db_name, stats in local_results.items(): rate = stats["success_rate"] - line = (f" {db_name:15s} avg={stats['avg_ms']:>8.2f} ms " - f"genes={stats['genes_tested']} success={rate}% " - f"avg_records={stats['avg_records']}") + line = ( + f" {db_name:15s} avg={stats['avg_ms']:>8.2f} ms " + f"genes={stats['genes_tested']} success={rate}% " + f"avg_records={stats['avg_records']}" + ) print(line) if rate < 100 and stats.get("top_error"): print(f" -> most common error: {stats['top_error']}") @@ -1000,7 +1028,10 @@ def main() -> None: print(f"\n[4] ngrok tunnel query benchmark ({real_ngrok_url})") print(" (throttling to 1 req/1.5 s to stay under ngrok free-tier rate limit)") ngrok_results = benchmark_api( - "ngrok", real_ngrok_url, databases, genes_per_db, + "ngrok", + real_ngrok_url, + databases, + genes_per_db, iterations=args.query_iters, inter_request_delay=1.5, ) @@ -1021,9 +1052,7 @@ def main() -> None: if not args.skip_bar: print("\n[5a] Legacy BAR eFP Browser CGI (per-database URLs in BAR_LEGACY_CGI_VIEWS)") print(" (full HTML page render — this is what gave 1000-2000 ms originally)") - legacy_cgi_results = benchmark_legacy_efp_cgi( - databases, genes_per_db, iterations=args.query_iters - ) + legacy_cgi_results = benchmark_legacy_efp_cgi(databases, genes_per_db, iterations=args.query_iters) for db_name, stats in legacy_cgi_results.items(): rate = stats["success_rate"] print(f" {db_name:15s} avg={stats['avg_ms']:>8.2f} ms success={rate}%") @@ -1069,12 +1098,13 @@ def main() -> None: print(f" {db_name}:") print(f" Local API : {local_ms:>8.2f} ms (success={local_rate}%)") if ngrok_ms: - print(f" ngrok (internet): {ngrok_ms:>8.2f} ms (success={ngrok_rate}% " - f"+{ngrok_ms - local_ms:.2f} ms latency overhead)") + print( + f" ngrok (internet): {ngrok_ms:>8.2f} ms (success={ngrok_rate}% " + f"+{ngrok_ms - local_ms:.2f} ms latency overhead)" + ) if cgi_ms: speedup = cgi_ms / max(local_ms, 0.01) - print(f" Legacy CGI : {cgi_ms:>8.2f} ms (success={cgi_rate}% " - f"local is {speedup:.1f}x faster)") + print(f" Legacy CGI : {cgi_ms:>8.2f} ms (success={cgi_rate}% " f"local is {speedup:.1f}x faster)") print() print(f" Plots saved to: {RESULTS_DIR}/") diff --git a/scripts/proofread_efp_values.py b/scripts/proofread_efp_values.py index c89e8296..06982109 100644 --- a/scripts/proofread_efp_values.py +++ b/scripts/proofread_efp_values.py @@ -66,7 +66,7 @@ } BAR_CGI_URL = "https://bar.utoronto.ca/eplant/cgi-bin/plantefp.cgi" -BAR_EFP_DATA = "https://bar.utoronto.ca/efp/data" # XML directory +BAR_EFP_DATA = "https://bar.utoronto.ca/efp/data" # XML directory # --------------------------------------------------------------------------- # View → database name mapping @@ -279,6 +279,7 @@ # Shared helpers # =========================================================================== + def _parse_fields(raw: str) -> List[str]: """Split a raw SQL tuple on ',' and strip surrounding quotes from each field.""" return [f.strip().strip("'\"") for f in raw.split(",")] @@ -328,6 +329,7 @@ def query_bar_cgi(db_name: str, gene_id: str, samples: List[str]) -> List[Dict[s # Mode A — SQL dump parsing # =========================================================================== + def extract_all_genes(db_name: str) -> List[str]: """Return all unique gene/probeset IDs found in the dump's sample_data table. @@ -381,7 +383,7 @@ def parse_dump_for_gene(db_name: str, gene_id: str) -> Dict[str, float]: for line in fh: if "INSERT INTO `sample_data`" not in line: continue - if gene_id not in line: # fast pre-filter + if gene_id not in line: # fast pre-filter continue for m in re.finditer(r"\(([^)]+)\)", line): fields = _parse_fields(m.group(1)) @@ -431,15 +433,15 @@ def compare_dump_vs_api( if abs(dump_val - api_val) <= tolerance: matches += 1 else: - mismatches.append({ - "sample": sample, - "dump": dump_val, - "api": api_val, - "abs_diff": round(abs(dump_val - api_val), 8), - "rel_diff_%": round( - abs(dump_val - api_val) / max(abs(dump_val), 1e-12) * 100, 4 - ), - }) + mismatches.append( + { + "sample": sample, + "dump": dump_val, + "api": api_val, + "abs_diff": round(abs(dump_val - api_val), 8), + "rel_diff_%": round(abs(dump_val - api_val) / max(abs(dump_val), 1e-12) * 100, 4), + } + ) else: missing_in_api.append(sample) @@ -447,11 +449,7 @@ def compare_dump_vs_api( if sample not in dump_values: extra_in_api.append(sample) - status = ( - "OK" if not mismatches and not missing_in_api else - "MISMATCH" if mismatches else - "MISSING" - ) + status = "OK" if not mismatches and not missing_in_api else "MISMATCH" if mismatches else "MISSING" return { "dump_sample_count": len(dump_values), "api_sample_count": len(api_lookup), @@ -501,8 +499,7 @@ def proofread_gene_dump( err = api_result.get("error", "unknown error") if verbose: print(f" [API FAIL] {err}") - return {"gene_id": gene_id, "status": "API_FAIL", "error": err, - "dump_samples": len(dump_values)} + return {"gene_id": gene_id, "status": "API_FAIL", "error": err, "dump_samples": len(dump_values)} api_data = api_result.get("data", []) if verbose: @@ -516,14 +513,15 @@ def proofread_gene_dump( if verbose: if cmp["status"] == "OK": - print(f"\n [OK] {cmp['matches']}/{cmp['dump_sample_count']} samples match " - f"(±1e-4)") + print(f"\n [OK] {cmp['matches']}/{cmp['dump_sample_count']} samples match " f"(±1e-4)") elif cmp["mismatches"]: print(f"\n [MISMATCH] {len(cmp['mismatches'])} differences (showing ≤10):") for m in cmp["mismatches"][:10]: - print(f" {m['sample']:<34s} dump={m['dump']:<14} " - f"api={m['api']:<14} diff={m['abs_diff']} " - f"({m['rel_diff_%']}%)") + print( + f" {m['sample']:<34s} dump={m['dump']:<14} " + f"api={m['api']:<14} diff={m['abs_diff']} " + f"({m['rel_diff_%']}%)" + ) if cmp["missing_in_api"]: print(f" [MISSING] {len(cmp['missing_in_api'])} dump samples absent from API:") for s in cmp["missing_in_api"][:5]: @@ -541,8 +539,10 @@ def proofread_gene_dump( confirmed = [d["name"] for d in cgi_data if d.get("value") is not None] not_found = [d["name"] for d in cgi_data if d.get("value") is None] cgi_summary = { - "checked": len(cgi_data), "confirmed": len(confirmed), - "not_found": len(not_found), "not_found_samples": not_found[:10], + "checked": len(cgi_data), + "confirmed": len(confirmed), + "not_found": len(not_found), + "not_found_samples": not_found[:10], } if verbose: print(f" BAR CGI: {len(confirmed)}/{len(cgi_data)} samples confirmed") @@ -608,15 +608,16 @@ def run_mode_a( gene_results: List[Dict[str, Any]] = [] for gene_id in genes: - result = proofread_gene_dump(base_url, db_name, gene_id, - check_cgi=check_cgi, verbose=verbose) + result = proofread_gene_dump(base_url, db_name, gene_id, check_cgi=check_cgi, verbose=verbose) gene_results.append(result) statuses = [r["status"] for r in gene_results] print(f"\n ── {db_name} summary ──") - print(f" OK={statuses.count('OK')} MISMATCH={statuses.count('MISMATCH')} " - f"MISSING={statuses.count('MISSING')} " - f"API_FAIL={statuses.count('API_FAIL')} SKIP={statuses.count('SKIP')}") + print( + f" OK={statuses.count('OK')} MISMATCH={statuses.count('MISMATCH')} " + f"MISSING={statuses.count('MISSING')} " + f"API_FAIL={statuses.count('API_FAIL')} SKIP={statuses.count('SKIP')}" + ) all_results[db_name] = gene_results return all_results @@ -626,6 +627,7 @@ def run_mode_a( # Mode B — XML-based sample verification # =========================================================================== + def fetch_efp_xml(view_name: str) -> Optional[str]: """Download the eFP Browser XML file for a given view name. @@ -697,26 +699,20 @@ def parse_efp_xml(xml_content: str) -> Dict[str, Any]: for group_elem in view_elem.findall(".//group"): # collect controls in this group - group_controls = [ - ctrl.get("sample", "") - for ctrl in group_elem.findall("control") - if ctrl.get("sample") - ] + group_controls = [ctrl.get("sample", "") for ctrl in group_elem.findall("control") if ctrl.get("sample")] all_controls.update(group_controls) for tissue_elem in group_elem.findall("tissue"): tissue_name = tissue_elem.get("name", "") - tissue_samples = [ - s.get("name", "") - for s in tissue_elem.findall("sample") - if s.get("name") - ] + tissue_samples = [s.get("name", "") for s in tissue_elem.findall("sample") if s.get("name")] all_samples.update(tissue_samples) - tissues.append({ - "name": tissue_name, - "samples": tissue_samples, - "controls": group_controls, - }) + tissues.append( + { + "name": tissue_name, + "samples": tissue_samples, + "controls": group_controls, + } + ) return { "db": db_name, @@ -743,9 +739,11 @@ def resolve_db_for_view(view_name: str, species: Optional[str] = None) -> Option db = VIEW_TO_DB.get(view_name) if db: if view_name in _AMBIGUOUS_VIEWS: - print(f" [WARN] '{view_name}' maps to multiple DBs: " - f"{_AMBIGUOUS_VIEWS[view_name] + [db]}. " - f"Using first match '{db}'. Pass --species to disambiguate.") + print( + f" [WARN] '{view_name}' maps to multiple DBs: " + f"{_AMBIGUOUS_VIEWS[view_name] + [db]}. " + f"Using first match '{db}'. Pass --species to disambiguate." + ) return db return None @@ -788,10 +786,11 @@ def proofread_from_xml( if not db_name_from_map: return {"status": "XML_FAIL", "view": view_name, "gene_id": gene_id} # we have a db name from the map but no XML — can't get sample IDs - print(f" [WARN] Could not fetch XML; db='{db_name_from_map}' from mapping " - f"but sample list unavailable. Skipping.") - return {"status": "XML_FAIL", "view": view_name, "gene_id": gene_id, - "db": db_name_from_map} + print( + f" [WARN] Could not fetch XML; db='{db_name_from_map}' from mapping " + f"but sample list unavailable. Skipping." + ) + return {"status": "XML_FAIL", "view": view_name, "gene_id": gene_id, "db": db_name_from_map} xml_data = parse_efp_xml(xml_content) if not xml_data: @@ -804,10 +803,12 @@ def proofread_from_xml( return {"status": "NO_DB", "view": view_name, "gene_id": gene_id} if db_name_from_map and xml_data.get("db") and db_name_from_map != xml_data["db"]: - print(f" [WARN] Mapping says db='{db_name_from_map}' but XML says " - f"db='{xml_data['db']}'. Using mapping value.") + print( + f" [WARN] Mapping says db='{db_name_from_map}' but XML says " + f"db='{xml_data['db']}'. Using mapping value." + ) - expected_sids = xml_data["all_samples"] # data_bot_id values from XML + expected_sids = xml_data["all_samples"] # data_bot_id values from XML tissues = xml_data["tissues"] print(f" DB (from XML): {db_name}") @@ -825,9 +826,14 @@ def proofread_from_xml( err = api_result.get("error", "unknown error") if verbose: print(f"\n [API FAIL] {err}") - return {"status": "API_FAIL", "view": view_name, "db": db_name, - "gene_id": gene_id, "error": err, - "expected_samples": len(expected_sids)} + return { + "status": "API_FAIL", + "view": view_name, + "db": db_name, + "gene_id": gene_id, + "error": err, + "expected_samples": len(expected_sids), + } api_data = api_result.get("data", []) api_lookup: Dict[str, Optional[float]] = {} @@ -862,25 +868,22 @@ def proofread_from_xml( t_present = [s for s in tissue["samples"] if s in api_lookup and api_lookup[s] is not None] t_missing = [s for s in tissue["samples"] if s not in api_lookup] t_null = [s for s in tissue["samples"] if s in api_lookup and api_lookup[s] is None] - tissue_rows.append({ - "tissue": tissue["name"], - "total": len(tissue["samples"]), - "present": len(t_present), - "missing": len(t_missing), - "null": len(t_null), - "ok": len(t_missing) == 0 and len(t_null) == 0, - "samples_missing": t_missing, - }) - - status = ( - "OK" if not missing and not null_val else - "MISSING" if missing else - "NULL_VALUES" - ) + tissue_rows.append( + { + "tissue": tissue["name"], + "total": len(tissue["samples"]), + "present": len(t_present), + "missing": len(t_missing), + "null": len(t_null), + "ok": len(t_missing) == 0 and len(t_null) == 0, + "samples_missing": t_missing, + } + ) + + status = "OK" if not missing and not null_val else "MISSING" if missing else "NULL_VALUES" if verbose: - print(f"\n Results: {len(present)}/{len(expected_sids)} expected samples present " - f"with valid values") + print(f"\n Results: {len(present)}/{len(expected_sids)} expected samples present " f"with valid values") if missing: print(f" [MISSING] {len(missing)} expected sample IDs not in API response:") for s in missing[:10]: @@ -892,8 +895,7 @@ def proofread_from_xml( for s in null_val[:5]: print(f" {s}") if extra_in_api: - print(f" [EXTRA] {len(extra_in_api)} API samples not in XML " - f"(may be from other views / older data)") + print(f" [EXTRA] {len(extra_in_api)} API samples not in XML " f"(may be from other views / older data)") print(f"\n Status: {status}") # Per-tissue table @@ -904,9 +906,11 @@ def proofread_from_xml( print(f" {'─' * w} ───── ─────── ─────── ──── ───") for row in tissue_rows: flag = "✓" if row["ok"] else "✗" - print(f" {row['tissue']:<{w}} {row['total']:5d} " - f"{row['present']:7d} {row['missing']:7d} " - f"{row['null']:4d} {flag}") + print( + f" {row['tissue']:<{w}} {row['total']:5d} " + f"{row['present']:7d} {row['missing']:7d} " + f"{row['null']:4d} {flag}" + ) # Step 4 — BAR CGI cross-check on expected samples cgi_summary: Optional[Dict[str, Any]] = None @@ -916,8 +920,10 @@ def proofread_from_xml( confirmed = [d["name"] for d in cgi_data if d.get("value") is not None] not_found = [d["name"] for d in cgi_data if d.get("value") is None] cgi_summary = { - "checked": len(cgi_data), "confirmed": len(confirmed), - "not_found": len(not_found), "not_found_samples": not_found[:10], + "checked": len(cgi_data), + "confirmed": len(confirmed), + "not_found": len(not_found), + "not_found_samples": not_found[:10], } if verbose: print(f"\n BAR CGI: {len(confirmed)}/{len(cgi_data)} samples confirmed on production") @@ -947,46 +953,50 @@ def proofread_from_xml( # Main # =========================================================================== + def main() -> None: parser = argparse.ArgumentParser( description="Proofread EFP expression values (dump or XML mode, no SQLite)", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=__doc__, ) - parser.add_argument("--local-url", default="http://localhost:5000", - help="Base URL of the local BAR API (default: http://localhost:5000)") + parser.add_argument( + "--local-url", + default="http://localhost:5000", + help="Base URL of the local BAR API (default: http://localhost:5000)", + ) # Mode B (XML) arguments xml_grp = parser.add_argument_group("Mode B — XML-based (BAR eFP Browser view)") - xml_grp.add_argument("--view", default=None, - help="eFP Browser view name (e.g. 'Embryo', 'Guard_Cell'). " - "Activates XML mode for a single view.") - xml_grp.add_argument("--all-views", action="store_true", - help="Run XML mode for every view in VIEW_DB_BY_SPECIES " - "(use with --species to restrict to one species).") - xml_grp.add_argument("--species", default=None, - help="Species label to disambiguate view names or restrict " - "--all-views (e.g. 'arabidopsis', 'maize').") - xml_grp.add_argument("--gene", default="AT1G01010", - help="Gene ID for XML mode (default: AT1G01010)") + xml_grp.add_argument( + "--view", + default=None, + help="eFP Browser view name (e.g. 'Embryo', 'Guard_Cell'). " "Activates XML mode for a single view.", + ) + xml_grp.add_argument( + "--all-views", + action="store_true", + help="Run XML mode for every view in VIEW_DB_BY_SPECIES " "(use with --species to restrict to one species).", + ) + xml_grp.add_argument( + "--species", + default=None, + help="Species label to disambiguate view names or restrict " "--all-views (e.g. 'arabidopsis', 'maize').", + ) + xml_grp.add_argument("--gene", default="AT1G01010", help="Gene ID for XML mode (default: AT1G01010)") # Mode A (dump) arguments dump_grp = parser.add_argument_group("Mode A — dump-based (SQL dump files)") - dump_grp.add_argument("--databases", default=None, - help="Comma-separated databases to check " - "(default: embryo, klepikova, soybean)") - dump_grp.add_argument("--genes", default=None, - help="Comma-separated gene IDs (overrides --max-genes)") - dump_grp.add_argument("--all-genes", action="store_true", - help="Check every gene found in each dump (can be slow)") - dump_grp.add_argument("--max-genes", type=int, default=5, - help="Max genes per database in dump mode (default: 5)") + dump_grp.add_argument( + "--databases", default=None, help="Comma-separated databases to check " "(default: embryo, klepikova, soybean)" + ) + dump_grp.add_argument("--genes", default=None, help="Comma-separated gene IDs (overrides --max-genes)") + dump_grp.add_argument("--all-genes", action="store_true", help="Check every gene found in each dump (can be slow)") + dump_grp.add_argument("--max-genes", type=int, default=5, help="Max genes per database in dump mode (default: 5)") # Shared arguments - parser.add_argument("--skip-cgi", action="store_true", - help="Skip BAR production CGI cross-check") - parser.add_argument("--quiet", action="store_true", - help="Print summaries only, not per-sample detail") + parser.add_argument("--skip-cgi", action="store_true", help="Skip BAR production CGI cross-check") + parser.add_argument("--quiet", action="store_true", help="Print summaries only, not per-sample detail") args = parser.parse_args() verbose = not args.quiet @@ -1019,11 +1029,7 @@ def main() -> None: print(f" Available: {', '.join(VIEW_DB_BY_SPECIES.keys())}") sys.exit(1) else: - views_to_check = [ - (v, sp) - for sp, vmap in VIEW_DB_BY_SPECIES.items() - for v in vmap - ] + views_to_check = [(v, sp) for sp, vmap in VIEW_DB_BY_SPECIES.items() for v in vmap] else: views_to_check = [(args.view, args.species)] @@ -1052,8 +1058,7 @@ def main() -> None: all_pass = False present = r.get("present", "?") expected = r.get("expected_samples", "?") - print(f" {r.get('view', ''):<48s} {r.get('db', ''):<36s} " - f"{verdict} ({present}/{expected})") + print(f" {r.get('view', ''):<48s} {r.get('db', ''):<36s} " f"{verdict} ({present}/{expected})") print(f"\n Gene tested : {args.gene}") overall = "ALL PASS" if all_pass else "FAILURES FOUND" @@ -1069,9 +1074,7 @@ def main() -> None: else: databases = list(DUMP_FILES.keys()) - genes_override = ( - [g.strip() for g in args.genes.split(",")] if args.genes else None - ) + genes_override = [g.strip() for g in args.genes.split(",")] if args.genes else None print("=" * 66) print("EFP PROOFREADER — Mode A (SQL dump-based)") @@ -1107,14 +1110,12 @@ def main() -> None: total_fail += fail verdict = "PASS" if mis == 0 and miss == 0 and fail == 0 else "FAIL" - print(f" {db_name:<15s} {verdict} " - f"ok={ok} mismatch={mis} missing={miss} api_fail={fail}") + print(f" {db_name:<15s} {verdict} " f"ok={ok} mismatch={mis} missing={miss} api_fail={fail}") mismatch_genes = [r["gene_id"] for r in results if r["status"] == "MISMATCH"] if mismatch_genes: print(f" Mismatched genes: {', '.join(mismatch_genes)}") - print(f"\n Totals ok={total_ok} mismatch={total_mismatch} " - f"missing={total_missing} api_fail={total_fail}") + print(f"\n Totals ok={total_ok} mismatch={total_mismatch} " f"missing={total_missing} api_fail={total_fail}") overall = "ALL PASS" if (total_mismatch + total_missing + total_fail) == 0 else "FAILURES FOUND" print(f"\n Result: {overall}") print("=" * 66) diff --git a/tests/models/test_efp_dynamic.py b/tests/models/test_efp_dynamic.py index f9ec112b..dd55de99 100644 --- a/tests/models/test_efp_dynamic.py +++ b/tests/models/test_efp_dynamic.py @@ -34,31 +34,31 @@ def tearDown(self): def test_cannabis_model_columns(self): """Test cannabis model has correct columns and types.""" - model = SIMPLE_EFP_SAMPLE_MODELS['cannabis'] + model = SIMPLE_EFP_SAMPLE_MODELS["cannabis"] # Check the 3 data columns - self.assertTrue(hasattr(model, 'data_probeset_id')) - self.assertTrue(hasattr(model, 'data_signal')) - self.assertTrue(hasattr(model, 'data_bot_id')) + self.assertTrue(hasattr(model, "data_probeset_id")) + self.assertTrue(hasattr(model, "data_signal")) + self.assertTrue(hasattr(model, "data_bot_id")) # Inspect column types mapper = inspect(model) columns = {col.name: col for col in mapper.columns} self.assertEqual(len(columns), 3) - self.assertEqual(str(columns['data_bot_id'].type), 'VARCHAR(255)') - self.assertEqual(str(columns['data_probeset_id'].type), 'VARCHAR(255)') + self.assertEqual(str(columns["data_bot_id"].type), "VARCHAR(255)") + self.assertEqual(str(columns["data_probeset_id"].type), "VARCHAR(255)") def test_mouse_db_has_3_columns(self): """Test mouse_db model has exactly 3 columns.""" - model = SIMPLE_EFP_SAMPLE_MODELS['mouse_db'] + model = SIMPLE_EFP_SAMPLE_MODELS["mouse_db"] # Get all columns mapper = inspect(model) column_names = {col.name for col in mapper.columns} # Should have exactly 3 columns - expected = {'data_probeset_id', 'data_signal', 'data_bot_id'} + expected = {"data_probeset_id", "data_signal", "data_bot_id"} self.assertEqual(column_names, expected) def test_all_models_have_3_columns(self): @@ -68,8 +68,7 @@ def test_all_models_have_3_columns(self): mapper = inspect(model) column_names = {col.name for col in mapper.columns} self.assertEqual( - len(column_names), 3, - f"{db_name} has {len(column_names)} columns, expected 3: {column_names}" + len(column_names), 3, f"{db_name} has {len(column_names)} columns, expected 3: {column_names}" ) @@ -87,28 +86,28 @@ def tearDown(self): def test_cannabis_primary_keys(self): """Test cannabis model has correct primary keys.""" - model = SIMPLE_EFP_SAMPLE_MODELS['cannabis'] + model = SIMPLE_EFP_SAMPLE_MODELS["cannabis"] mapper = inspect(model) pk_columns = [col.name for col in mapper.primary_key] # Should have 3 primary keys self.assertEqual(len(pk_columns), 3) - self.assertIn('data_probeset_id', pk_columns) - self.assertIn('data_signal', pk_columns) - self.assertIn('data_bot_id', pk_columns) + self.assertIn("data_probeset_id", pk_columns) + self.assertIn("data_signal", pk_columns) + self.assertIn("data_bot_id", pk_columns) def test_mouse_db_primary_keys(self): """Test mouse_db has correct primary keys.""" - model = SIMPLE_EFP_SAMPLE_MODELS['mouse_db'] + model = SIMPLE_EFP_SAMPLE_MODELS["mouse_db"] mapper = inspect(model) pk_columns = [col.name for col in mapper.primary_key] self.assertEqual(len(pk_columns), 3) - self.assertIn('data_probeset_id', pk_columns) - self.assertIn('data_signal', pk_columns) - self.assertIn('data_bot_id', pk_columns) + self.assertIn("data_probeset_id", pk_columns) + self.assertIn("data_signal", pk_columns) + self.assertIn("data_bot_id", pk_columns) def test_all_models_have_primary_keys(self): """Test that all 191 models have at least one primary key.""" @@ -117,10 +116,7 @@ def test_all_models_have_primary_keys(self): mapper = inspect(model) pk_columns = list(mapper.primary_key) - self.assertGreater( - len(pk_columns), 0, - f"{db_name} has no primary key columns" - ) + self.assertGreater(len(pk_columns), 0, f"{db_name} has no primary key columns") class TestDynamicModelNullability(TestCase): @@ -137,27 +133,24 @@ def tearDown(self): def test_embryo_signal_nullable(self): """Test embryo has nullable data_signal.""" - schema = SIMPLE_EFP_DATABASE_SCHEMAS['embryo'] - signal_col = next(col for col in schema['columns'] if col['name'] == 'data_signal') + schema = SIMPLE_EFP_DATABASE_SCHEMAS["embryo"] + signal_col = next(col for col in schema["columns"] if col["name"] == "data_signal") - self.assertTrue(signal_col.get('nullable', False)) + self.assertTrue(signal_col.get("nullable", False)) def test_dna_damage_bot_id_nullable(self): """Test dna_damage has nullable data_bot_id.""" - schema = SIMPLE_EFP_DATABASE_SCHEMAS['dna_damage'] - bot_id_col = next(col for col in schema['columns'] if col['name'] == 'data_bot_id') + schema = SIMPLE_EFP_DATABASE_SCHEMAS["dna_damage"] + bot_id_col = next(col for col in schema["columns"] if col["name"] == "data_bot_id") - self.assertTrue(bot_id_col.get('nullable', False)) + self.assertTrue(bot_id_col.get("nullable", False)) def test_all_columns_nullable(self): """All 3 columns in every schema should be nullable.""" for db_name, schema in SIMPLE_EFP_DATABASE_SCHEMAS.items(): - for col in schema['columns']: - with self.subTest(database=db_name, column=col['name']): - self.assertTrue( - col.get('nullable', False), - f"{db_name}.{col['name']} is not nullable" - ) + for col in schema["columns"]: + with self.subTest(database=db_name, column=col["name"]): + self.assertTrue(col.get("nullable", False), f"{db_name}.{col['name']} is not nullable") class TestUniformColumnTypes(TestCase): @@ -175,27 +168,27 @@ def tearDown(self): def test_all_string_columns_are_varchar_255(self): """All string columns should be VARCHAR(255).""" for db_name, schema in SIMPLE_EFP_DATABASE_SCHEMAS.items(): - for col in schema['columns']: - if col['type'] == 'string': - with self.subTest(database=db_name, column=col['name']): - self.assertEqual(col['length'], 255) + for col in schema["columns"]: + if col["type"] == "string": + with self.subTest(database=db_name, column=col["name"]): + self.assertEqual(col["length"], 255) def test_affydb_bot_id_is_varchar(self): """Test affydb uses VARCHAR(255) for data_bot_id (was TEXT).""" - schema = SIMPLE_EFP_DATABASE_SCHEMAS['affydb'] - bot_id_col = next(col for col in schema['columns'] if col['name'] == 'data_bot_id') + schema = SIMPLE_EFP_DATABASE_SCHEMAS["affydb"] + bot_id_col = next(col for col in schema["columns"] if col["name"] == "data_bot_id") - self.assertEqual(bot_id_col['type'], 'string') - self.assertEqual(bot_id_col['length'], 255) + self.assertEqual(bot_id_col["type"], "string") + self.assertEqual(bot_id_col["length"], 255) def test_canola_columns_are_varchar(self): """Test canola uses VARCHAR(255) for both string columns (was TEXT).""" - schema = SIMPLE_EFP_DATABASE_SCHEMAS['canola'] - for col in schema['columns']: - if col['name'] in ('data_probeset_id', 'data_bot_id'): - with self.subTest(column=col['name']): - self.assertEqual(col['type'], 'string') - self.assertEqual(col['length'], 255) + schema = SIMPLE_EFP_DATABASE_SCHEMAS["canola"] + for col in schema["columns"]: + if col["name"] in ("data_probeset_id", "data_bot_id"): + with self.subTest(column=col["name"]): + self.assertEqual(col["type"], "string") + self.assertEqual(col["length"], 255) class TestModelClassNames(TestCase): @@ -212,18 +205,18 @@ def tearDown(self): def test_cannabis_class_name(self): """Test cannabis model has correct class name.""" - model = SIMPLE_EFP_SAMPLE_MODELS['cannabis'] - self.assertEqual(model.__name__, 'CannabisSampleData') + model = SIMPLE_EFP_SAMPLE_MODELS["cannabis"] + self.assertEqual(model.__name__, "CannabisSampleData") def test_maize_atlas_v5_class_name(self): """Test maize_atlas_v5 model has correct class name.""" - model = SIMPLE_EFP_SAMPLE_MODELS['maize_atlas_v5'] - self.assertEqual(model.__name__, 'MaizeAtlasV5SampleData') + model = SIMPLE_EFP_SAMPLE_MODELS["maize_atlas_v5"] + self.assertEqual(model.__name__, "MaizeAtlasV5SampleData") def test_lateral_root_initiation_class_name(self): """Test lateral_root_initiation model has correct class name.""" - model = SIMPLE_EFP_SAMPLE_MODELS['lateral_root_initiation'] - self.assertEqual(model.__name__, 'LateralRootInitiationSampleData') + model = SIMPLE_EFP_SAMPLE_MODELS["lateral_root_initiation"] + self.assertEqual(model.__name__, "LateralRootInitiationSampleData") class TestSampleDatabases(TestCase): @@ -240,10 +233,7 @@ def tearDown(self): def test_arabidopsis_databases(self): """Test all arabidopsis-related databases load.""" - arabidopsis_dbs = [ - 'arabidopsis_ecotypes', 'embryo', 'germination', - 'dna_damage', 'single_cell', 'silique' - ] + arabidopsis_dbs = ["arabidopsis_ecotypes", "embryo", "germination", "dna_damage", "single_cell", "silique"] for db_name in arabidopsis_dbs: with self.subTest(database=db_name): @@ -253,10 +243,7 @@ def test_arabidopsis_databases(self): def test_cereal_databases(self): """Test cereal crop databases (wheat, barley, rice, maize, sorghum, oat).""" - cereal_dbs = [ - 'wheat', 'barley_seed', 'rice_root', - 'maize_atlas_v5', 'sorghum_atlas_w_BS_cells', 'oat' - ] + cereal_dbs = ["wheat", "barley_seed", "rice_root", "maize_atlas_v5", "sorghum_atlas_w_BS_cells", "oat"] for db_name in cereal_dbs: with self.subTest(database=db_name): @@ -266,10 +253,7 @@ def test_cereal_databases(self): def test_legume_databases(self): """Test legume databases (soybean, medicago, lupin).""" - legume_dbs = [ - 'soybean_senescence', 'medicago_root', - 'lupin_whole_plant', 'arachis' - ] + legume_dbs = ["soybean_senescence", "medicago_root", "lupin_whole_plant", "arachis"] for db_name in legume_dbs: with self.subTest(database=db_name): @@ -279,7 +263,7 @@ def test_legume_databases(self): def test_tree_databases(self): """Test tree/woody plant databases (poplar, spruce, eucalyptus).""" - tree_dbs = ['poplar_hormone', 'spruce', 'eucalyptus', 'willow'] + tree_dbs = ["poplar_hormone", "spruce", "eucalyptus", "willow"] for db_name in tree_dbs: with self.subTest(database=db_name): @@ -289,10 +273,7 @@ def test_tree_databases(self): def test_tropical_crop_databases(self): """Test tropical crop databases (cacao, cassava, mangosteen).""" - tropical_dbs = [ - 'cacao_developmental_atlas', 'cassava_atlas', - 'mangosteen_fruit_ripening' - ] + tropical_dbs = ["cacao_developmental_atlas", "cassava_atlas", "mangosteen_fruit_ripening"] for db_name in tropical_dbs: with self.subTest(database=db_name): diff --git a/tests/models/test_efp_schemas.py b/tests/models/test_efp_schemas.py index 3228e296..9ad068c3 100644 --- a/tests/models/test_efp_schemas.py +++ b/tests/models/test_efp_schemas.py @@ -27,119 +27,94 @@ class TestEfpSchemaDefinitions(TestCase): def test_all_191_databases_loaded(self): """Verify all 191 databases from CSV are loaded.""" self.assertEqual( - len(SIMPLE_EFP_DATABASE_SCHEMAS), - 191, - f"Expected 191 databases, found {len(SIMPLE_EFP_DATABASE_SCHEMAS)}" + len(SIMPLE_EFP_DATABASE_SCHEMAS), 191, f"Expected 191 databases, found {len(SIMPLE_EFP_DATABASE_SCHEMAS)}" ) def test_all_schemas_have_required_keys(self): """Every schema must have table_name, charset, columns, index, and metadata.""" - required_keys = {'table_name', 'charset', 'columns', 'index', 'metadata'} + required_keys = {"table_name", "charset", "columns", "index", "metadata"} for db_name, schema in SIMPLE_EFP_DATABASE_SCHEMAS.items(): with self.subTest(database=db_name): missing_keys = required_keys - set(schema.keys()) - self.assertEqual( - len(missing_keys), 0, - f"{db_name} missing required keys: {missing_keys}" - ) + self.assertEqual(len(missing_keys), 0, f"{db_name} missing required keys: {missing_keys}") def test_all_schemas_have_metadata(self): """Every schema must have species and sample_regex metadata.""" for db_name, schema in SIMPLE_EFP_DATABASE_SCHEMAS.items(): with self.subTest(database=db_name): - metadata = schema.get('metadata', {}) - self.assertIn('species', metadata, f"{db_name} missing species metadata") - self.assertIn('sample_regex', metadata, f"{db_name} missing sample_regex metadata") + metadata = schema.get("metadata", {}) + self.assertIn("species", metadata, f"{db_name} missing species metadata") + self.assertIn("sample_regex", metadata, f"{db_name} missing sample_regex metadata") def test_all_schemas_have_3_columns(self): """Every database should have exactly 3 columns: data_probeset_id, data_signal, data_bot_id.""" - expected_columns = {'data_probeset_id', 'data_signal', 'data_bot_id'} + expected_columns = {"data_probeset_id", "data_signal", "data_bot_id"} for db_name, schema in SIMPLE_EFP_DATABASE_SCHEMAS.items(): with self.subTest(database=db_name): - column_names = {col['name'] for col in schema['columns']} - self.assertEqual( - column_names, expected_columns, - f"{db_name} has unexpected columns: {column_names}" - ) + column_names = {col["name"] for col in schema["columns"]} + self.assertEqual(column_names, expected_columns, f"{db_name} has unexpected columns: {column_names}") def test_column_types_are_valid(self): """All column types must be one of: string, integer, float, text.""" - valid_types = {'string', 'integer', 'float', 'text'} + valid_types = {"string", "integer", "float", "text"} for db_name, schema in SIMPLE_EFP_DATABASE_SCHEMAS.items(): - for column in schema['columns']: - with self.subTest(database=db_name, column=column['name']): - col_type = column.get('type') - self.assertIn( - col_type, valid_types, - f"{db_name}.{column['name']} has invalid type: {col_type}" - ) + for column in schema["columns"]: + with self.subTest(database=db_name, column=column["name"]): + col_type = column.get("type") + self.assertIn(col_type, valid_types, f"{db_name}.{column['name']} has invalid type: {col_type}") def test_string_columns_have_length(self): """String columns must have a length specified.""" for db_name, schema in SIMPLE_EFP_DATABASE_SCHEMAS.items(): - for column in schema['columns']: - with self.subTest(database=db_name, column=column['name']): - if column.get('type') == 'string': - self.assertIn( - 'length', column, - f"{db_name}.{column['name']} is string but missing length" - ) - self.assertIsInstance( - column['length'], int, - f"{db_name}.{column['name']} length must be int" - ) - self.assertGreater( - column['length'], 0, - f"{db_name}.{column['name']} length must be > 0" - ) + for column in schema["columns"]: + with self.subTest(database=db_name, column=column["name"]): + if column.get("type") == "string": + self.assertIn("length", column, f"{db_name}.{column['name']} is string but missing length") + self.assertIsInstance(column["length"], int, f"{db_name}.{column['name']} length must be int") + self.assertGreater(column["length"], 0, f"{db_name}.{column['name']} length must be > 0") def test_charset_is_valid(self): """Charset must be either latin1 or utf8mb4.""" - valid_charsets = {'latin1', 'utf8mb4'} + valid_charsets = {"latin1", "utf8mb4"} for db_name, schema in SIMPLE_EFP_DATABASE_SCHEMAS.items(): with self.subTest(database=db_name): - charset = schema.get('charset') - self.assertIn( - charset, valid_charsets, - f"{db_name} has invalid charset: {charset}" - ) + charset = schema.get("charset") + self.assertIn(charset, valid_charsets, f"{db_name} has invalid charset: {charset}") def test_known_databases_have_correct_species(self): """Verify species metadata for known databases.""" known_species = { - 'cannabis': 'cannabis', - 'embryo': 'arabidopsis', - 'wheat': 'wheat', - 'maize_atlas_v5': 'maize', - 'rice_root': 'rice', - 'sorghum_atlas_w_BS_cells': 'sorghum', - 'oat': 'oat', - 'mouse_db': 'mouse', + "cannabis": "cannabis", + "embryo": "arabidopsis", + "wheat": "wheat", + "maize_atlas_v5": "maize", + "rice_root": "rice", + "sorghum_atlas_w_BS_cells": "sorghum", + "oat": "oat", + "mouse_db": "mouse", } for db_name, expected_species in known_species.items(): with self.subTest(database=db_name): schema = SIMPLE_EFP_DATABASE_SCHEMAS.get(db_name) self.assertIsNotNone(schema, f"{db_name} not found in schemas") - actual_species = schema['metadata']['species'] + actual_species = schema["metadata"]["species"] self.assertEqual( - actual_species, expected_species, - f"{db_name} has species={actual_species}, expected {expected_species}" + actual_species, + expected_species, + f"{db_name} has species={actual_species}, expected {expected_species}", ) def test_primary_keys_are_defined(self): """Each schema should have at least one primary key column.""" for db_name, schema in SIMPLE_EFP_DATABASE_SCHEMAS.items(): with self.subTest(database=db_name): - pk_columns = [col['name'] for col in schema['columns'] if col.get('primary_key')] - self.assertGreater( - len(pk_columns), 0, - f"{db_name} has no primary key columns" - ) + pk_columns = [col["name"] for col in schema["columns"] if col.get("primary_key")] + self.assertGreater(len(pk_columns), 0, f"{db_name} has no primary key columns") class TestDynamicOrmGeneration(TestCase): @@ -157,9 +132,7 @@ def tearDown(self): def test_all_191_models_generated(self): """Verify all 191 dynamic ORM models are generated.""" self.assertEqual( - len(SIMPLE_EFP_SAMPLE_MODELS), - 191, - f"Expected 191 ORM models, found {len(SIMPLE_EFP_SAMPLE_MODELS)}" + len(SIMPLE_EFP_SAMPLE_MODELS), 191, f"Expected 191 ORM models, found {len(SIMPLE_EFP_SAMPLE_MODELS)}" ) def test_model_names_match_database_names(self): @@ -168,67 +141,48 @@ def test_model_names_match_database_names(self): model_names = set(SIMPLE_EFP_SAMPLE_MODELS.keys()) self.assertEqual( - schema_names, model_names, - f"Mismatch between schemas and models. Missing: {schema_names - model_names}" + schema_names, model_names, f"Mismatch between schemas and models. Missing: {schema_names - model_names}" ) def test_all_models_have_tablename(self): """Every model must have __tablename__ attribute.""" for db_name, model in SIMPLE_EFP_SAMPLE_MODELS.items(): with self.subTest(database=db_name): - self.assertTrue( - hasattr(model, '__tablename__'), - f"{db_name} model missing __tablename__" - ) - self.assertEqual( - model.__tablename__, 'sample_data', - f"{db_name} has wrong table name" - ) + self.assertTrue(hasattr(model, "__tablename__"), f"{db_name} model missing __tablename__") + self.assertEqual(model.__tablename__, "sample_data", f"{db_name} has wrong table name") def test_all_models_have_bind_key(self): """Every model must have __bind_key__ matching the database name.""" for db_name, model in SIMPLE_EFP_SAMPLE_MODELS.items(): with self.subTest(database=db_name): - self.assertTrue( - hasattr(model, '__bind_key__'), - f"{db_name} model missing __bind_key__" - ) - self.assertEqual( - model.__bind_key__, db_name, - f"{db_name} has wrong bind_key" - ) + self.assertTrue(hasattr(model, "__bind_key__"), f"{db_name} model missing __bind_key__") + self.assertEqual(model.__bind_key__, db_name, f"{db_name} has wrong bind_key") def test_models_have_correct_column_count(self): """Models should have the same number of columns as their schema.""" for db_name, model in SIMPLE_EFP_SAMPLE_MODELS.items(): with self.subTest(database=db_name): schema = SIMPLE_EFP_DATABASE_SCHEMAS[db_name] - expected_cols = len(schema['columns']) + expected_cols = len(schema["columns"]) # Count actual columns (excluding internal attributes) - actual_cols = len([ - attr for attr in dir(model) - if not attr.startswith('_') and - hasattr(getattr(model, attr), 'type') - ]) + actual_cols = len( + [attr for attr in dir(model) if not attr.startswith("_") and hasattr(getattr(model, attr), "type")] + ) self.assertEqual( - actual_cols, expected_cols, - f"{db_name} model has {actual_cols} columns, schema has {expected_cols}" + actual_cols, expected_cols, f"{db_name} model has {actual_cols} columns, schema has {expected_cols}" ) def test_known_models_have_expected_columns(self): """Verify specific models have the 3 expected columns.""" - expected_columns = ['data_probeset_id', 'data_signal', 'data_bot_id'] + expected_columns = ["data_probeset_id", "data_signal", "data_bot_id"] - for db_name in ['cannabis', 'oat', 'mouse_db', 'embryo', 'wheat']: + for db_name in ["cannabis", "oat", "mouse_db", "embryo", "wheat"]: with self.subTest(database=db_name): model = SIMPLE_EFP_SAMPLE_MODELS[db_name] for col_name in expected_columns: - self.assertTrue( - hasattr(model, col_name), - f"{db_name} model missing column: {col_name}" - ) + self.assertTrue(hasattr(model, col_name), f"{db_name} model missing column: {col_name}") class TestDatabaseCategoryDistribution(TestCase): @@ -238,37 +192,30 @@ def test_species_distribution(self): """Verify we have databases for expected species.""" species_counts = {} for schema in SIMPLE_EFP_DATABASE_SCHEMAS.values(): - species = schema['metadata']['species'] + species = schema["metadata"]["species"] species_counts[species] = species_counts.get(species, 0) + 1 # Check we have major species covered - major_species = ['arabidopsis', 'maize', 'wheat', 'rice', 'sorghum'] + major_species = ["arabidopsis", "maize", "wheat", "rice", "sorghum"] for species in major_species: with self.subTest(species=species): - self.assertIn( - species, species_counts, - f"No databases found for {species}" - ) - self.assertGreater( - species_counts[species], 0, - f"Expected databases for {species}" - ) + self.assertIn(species, species_counts, f"No databases found for {species}") + self.assertGreater(species_counts[species], 0, f"Expected databases for {species}") def test_charset_distribution(self): """Verify charset distribution (should have both latin1 and utf8mb4).""" charset_counts = {} for schema in SIMPLE_EFP_DATABASE_SCHEMAS.values(): - charset = schema['charset'] + charset = schema["charset"] charset_counts[charset] = charset_counts.get(charset, 0) + 1 # Should have both charsets - self.assertIn('latin1', charset_counts) - self.assertIn('utf8mb4', charset_counts) + self.assertIn("latin1", charset_counts) + self.assertIn("utf8mb4", charset_counts) # utf8mb4 should be more common (126 databases) self.assertGreater( - charset_counts['utf8mb4'], 100, - f"Expected ~126 utf8mb4 databases, found {charset_counts['utf8mb4']}" + charset_counts["utf8mb4"], 100, f"Expected ~126 utf8mb4 databases, found {charset_counts['utf8mb4']}" ) @@ -278,10 +225,9 @@ class TestVarcharLengths(TestCase): def test_all_string_columns_are_255(self): """All string columns should use a uniform length of 255.""" for db_name, schema in SIMPLE_EFP_DATABASE_SCHEMAS.items(): - for col in schema['columns']: - if col.get('type') == 'string': - with self.subTest(database=db_name, column=col['name']): + for col in schema["columns"]: + if col.get("type") == "string": + with self.subTest(database=db_name, column=col["name"]): self.assertEqual( - col['length'], 255, - f"{db_name}.{col['name']} length is {col['length']}, expected 255" + col["length"], 255, f"{db_name}.{col['name']} length is {col['length']}, expected 255" ) diff --git a/tests/resources/test_proxy.py b/tests/resources/test_proxy.py index e5abb787..963c8f4f 100644 --- a/tests/resources/test_proxy.py +++ b/tests/resources/test_proxy.py @@ -1,7 +1,8 @@ from api import app from json import load -from unittest import TestCase from unittest.mock import MagicMock, patch +import unittest +import requests class TestIntegrations(unittest.TestCase): diff --git a/tests/services/test_efp_data.py b/tests/services/test_efp_data.py index 25d65283..7e845dd7 100644 --- a/tests/services/test_efp_data.py +++ b/tests/services/test_efp_data.py @@ -63,9 +63,7 @@ def test_invalid_gene_format(self): def test_sample_data_agi_is_converted_to_probeset(self): """sample_data requires probesets, so agi ids should be converted automatically""" mapping_date = date(2020, 1, 1) - db.session.query(AtAgiLookup).filter_by( - probeset="261585_at", agi="AT1G01010", date=mapping_date - ).delete() + db.session.query(AtAgiLookup).filter_by(probeset="261585_at", agi="AT1G01010", date=mapping_date).delete() db.session.add( AtAgiLookup( probeset="261585_at", @@ -78,9 +76,7 @@ def test_sample_data_agi_is_converted_to_probeset(self): try: result = query_efp_database_dynamic("sample_data", "At1g01010", allow_empty_results=False) finally: - db.session.query(AtAgiLookup).filter_by( - probeset="261585_at", agi="AT1G01010", date=mapping_date - ).delete() + db.session.query(AtAgiLookup).filter_by(probeset="261585_at", agi="AT1G01010", date=mapping_date).delete() db.session.commit() self.assertTrue(result["success"]) @@ -159,16 +155,12 @@ def setUp(self): self.ctx = app.app_context() self.ctx.push() # seed the AGI → probeset mapping once for all tests in this class - db.session.query(AtAgiLookup).filter_by( - probeset=self.PROBESET, agi=self.AGI, date=self.MAPPING_DATE - ).delete() + db.session.query(AtAgiLookup).filter_by(probeset=self.PROBESET, agi=self.AGI, date=self.MAPPING_DATE).delete() db.session.add(AtAgiLookup(probeset=self.PROBESET, agi=self.AGI, date=self.MAPPING_DATE)) db.session.commit() def tearDown(self): - db.session.query(AtAgiLookup).filter_by( - probeset=self.PROBESET, agi=self.AGI, date=self.MAPPING_DATE - ).delete() + db.session.query(AtAgiLookup).filter_by(probeset=self.PROBESET, agi=self.AGI, date=self.MAPPING_DATE).delete() db.session.commit() self.ctx.pop() @@ -193,9 +185,7 @@ def _assert_probeset_conversion_ran(self, database): # Conversion ran but no local DB to query — acceptable in CI/local dev pass else: - self.fail( - f"{database}: conversion failed before reaching the DB — error: {error}" - ) + self.fail(f"{database}: conversion failed before reaching the DB — error: {error}") def test_affydb_converts_agi_to_probeset(self): self._assert_probeset_conversion_ran("affydb") diff --git a/tests/utils/test_docking_utils.py b/tests/utils/test_docking_utils.py index 3ffaa42f..69b957ad 100644 --- a/tests/utils/test_docking_utils.py +++ b/tests/utils/test_docking_utils.py @@ -7,7 +7,6 @@ from api.utils.docking_utils import SDFMapping import os - NOT_IN_BAR = not os.environ.get("BAR") == "true"