diff --git a/pridepy/files/files.py b/pridepy/files/files.py index 0f8e029..e9cde55 100644 --- a/pridepy/files/files.py +++ b/pridepy/files/files.py @@ -1,53 +1,25 @@ #!/usr/bin/env python import ftplib -import hashlib -import importlib.resources import logging import os -import platform -import posixpath import re -import subprocess import urllib import urllib.request -import time from concurrent.futures import ThreadPoolExecutor, as_completed from ftplib import FTP from typing import Dict, List, Optional, Tuple -import socket from urllib.parse import urlparse import xml.etree.ElementTree as ET -import boto3 -import botocore import requests -from botocore.config import Config from tqdm import tqdm -from pridepy.authentication.authentication import Authentication from pridepy.util.api_handling import Util -class Progress: - def __init__(self, total_size, file_name): - self.pbar = tqdm( - total=total_size, - unit="B", - unit_scale=True, - desc="Downloading {}".format(file_name), - ) - - def __call__(self, bytes_amount): - self.pbar.update(bytes_amount) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.pbar.close() - - def close(self): - self.pbar.close() +# Re-export from providers.util so external `from pridepy.files.files import Progress` +# still works. +from pridepy.providers.util import Progress # noqa: F401 class Files: @@ -55,49 +27,38 @@ class Files: This class handles PRIDE API files endpoint. 
""" - V3_API_BASE_URL = "https://www.ebi.ac.uk/pride/ws/archive/v3" - API_BASE_URL = "https://www.ebi.ac.uk/pride/ws/archive/v3" - API_PRIVATE_URL = "https://www.ebi.ac.uk/pride/private/ws/archive/v2" - PRIDE_ARCHIVE_FTP = "ftp.pride.ebi.ac.uk" - PRIDE_ARCHIVE_FTP_URL_PREFIX = "ftp://ftp.pride.ebi.ac.uk/" - PRIDE_ARCHIVE_HTTPS_URL_PREFIX = "https://ftp.pride.ebi.ac.uk/" - MASSIVE_ARCHIVE_FTP = "massive-ftp.ucsd.edu" - MASSIVE_ARCHIVE_FTP_URL_PREFIX = "ftp://massive-ftp.ucsd.edu/v01/" - JPOST_ARCHIVE_FTP = "ftp.jpostdb.org" - JPOST_ARCHIVE_FTP_URL_PREFIX = "ftp://ftp.jpostdb.org/" - JPOST_PROXI_BASE_URL = "https://repository.jpostdb.org/proxi/datasets/" - JPOST_PROXI_CATEGORY_MAP = { - "Associated raw file URI": "RAW", - "Result file URI": "RESULT", - "Search engine output file URI": "SEARCH", - "Peak list file URI": "PEAK", - "Spectrum library file URI": "SPECTRUM_LIBRARY", - "Sequence database URI": "FASTA", - "Quantification file URI": "RESULT", - } - IPROX_DOWNLOAD_BASE_URL = "http://download.iprox.org/" - IPROX_PX_XML_URL_TEMPLATE = ( - "http://download.iprox.org/{accession}/PX_{accession}.xml" + # Re-exported from providers/pride.py — kept here for back-compat. + from pridepy.providers.pride import PrideProvider as _PrideProvider + V3_API_BASE_URL = _PrideProvider.V3_API_BASE_URL + API_BASE_URL = _PrideProvider.API_BASE_URL + API_PRIVATE_URL = _PrideProvider.API_PRIVATE_URL + PRIDE_ARCHIVE_FTP = _PrideProvider.ARCHIVE_FTP + PRIDE_ARCHIVE_FTP_URL_PREFIX = _PrideProvider.ARCHIVE_FTP_URL_PREFIX + PRIDE_ARCHIVE_HTTPS_URL_PREFIX = _PrideProvider.ARCHIVE_HTTPS_URL_PREFIX + S3_URL = _PrideProvider.S3_URL + S3_BUCKET = _PrideProvider.S3_BUCKET + PROTOCOL_ORDER = _PrideProvider.PROTOCOL_ORDER + del _PrideProvider + # Re-exported from providers/massive.py — kept here for back-compat. 
+ from pridepy.providers.massive import ( # noqa: E402 + MASSIVE_CATEGORY_MAP as _MASSIVE_CATEGORY_MAP, + MassiveProvider as _MassiveProvider, ) - # iProX PX XML uses the same PSI-MS cvParam "name" values as JPOST, so the - # JPOST PROXI category map applies. PX XML cvParam "Associated raw file URI" - # is the canonical raw-file label per the PSI-MS CV (MS:1002846). - IPROX_PX_CATEGORY_MAP = JPOST_PROXI_CATEGORY_MAP - S3_URL = "https://hh.fire.sdo.ebi.ac.uk" - S3_BUCKET = "pride-public" - PROTOCOL_ORDER = ["aspera", "s3", "ftp", "globus"] - MASSIVE_CATEGORY_MAP = { - "raw": "RAW", - "peak": "PEAK", - "ccms_peak": "PEAK", - "search": "SEARCH", - "result": "RESULT", - "ccms_result": "RESULT", - "quant": "RESULT", - "fasta": "FASTA", - "spectrum_library": "SPECTRUM_LIBRARY", - "library": "SPECTRUM_LIBRARY", - } + MASSIVE_CATEGORY_MAP = _MASSIVE_CATEGORY_MAP + MASSIVE_ARCHIVE_FTP = _MassiveProvider.ARCHIVE_FTP + MASSIVE_ARCHIVE_FTP_URL_PREFIX = _MassiveProvider.ARCHIVE_FTP_URL_PREFIX + del _MASSIVE_CATEGORY_MAP, _MassiveProvider + from pridepy.providers.jpost import JpostProvider as _JpostProvider + JPOST_ARCHIVE_FTP = _JpostProvider.ARCHIVE_FTP + JPOST_ARCHIVE_FTP_URL_PREFIX = _JpostProvider.ARCHIVE_FTP_URL_PREFIX + JPOST_PROXI_BASE_URL = _JpostProvider.PROXI_BASE_URL + JPOST_PROXI_CATEGORY_MAP = _JpostProvider.PROXI_CATEGORY_MAP + del _JpostProvider + from pridepy.providers.iprox import IproxProvider as _IproxProvider + IPROX_DOWNLOAD_BASE_URL = _IproxProvider.DOWNLOAD_BASE_URL + IPROX_PX_XML_URL_TEMPLATE = _IproxProvider.PX_XML_URL_TEMPLATE + IPROX_PX_CATEGORY_MAP = _IproxProvider.PX_CATEGORY_MAP + del _IproxProvider logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") def __init__(self): @@ -105,451 +66,175 @@ def __init__(self): @staticmethod def _find_tsv_columns(header: str) -> Optional[Tuple[int, int]]: - """Return (name_idx, checksum_idx) from a TSV header, or None.""" - cols = [col.strip().lower() for col in 
header.split("\t")] - required_cols = {"file-name", "file-md5checksum", "file-size"} - if not required_cols.issubset(set(cols)): - return None - return cols.index("file-name"), cols.index("file-md5checksum") + """Shim — see :func:`pridepy.providers.util._find_tsv_columns`.""" + from pridepy.providers import util + return util._find_tsv_columns(header) @staticmethod def _is_md5_checksum(value: str) -> bool: - return len(value) == 32 and all(char in "0123456789abcdef" for char in value) + """Shim — see :func:`pridepy.providers.util._is_md5_checksum`.""" + from pridepy.providers import util + return util._is_md5_checksum(value) @staticmethod def read_checksum_file(checksum_file_path: str) -> Dict[str, str]: - """ - Read PRIDE API checksum TSV and build {file_name: md5} map. - Expected format: File-Name\tFile-MD5Checksum\tFile-Size - """ - checksums: Dict[str, str] = {} - if not checksum_file_path or not os.path.exists(checksum_file_path): - return checksums - - with open(checksum_file_path, "r", encoding="utf-8") as f: - header = f.readline().strip() - if not header: - return checksums - - col_indices = Files._find_tsv_columns(header) - if col_indices is None: - logging.warning(f"Unrecognized checksum file format: {header}") - return checksums - - name_idx, checksum_idx = col_indices - min_cols = max(name_idx, checksum_idx) + 1 - for line in f: - parts = line.strip().split("\t") - if len(parts) >= min_cols: - fn = os.path.basename(parts[name_idx].strip()) - cs = parts[checksum_idx].strip().lower() - if fn and Files._is_md5_checksum(cs): - checksums[fn] = cs - - return checksums + """Shim — see :func:`pridepy.providers.util.read_checksum_file`.""" + from pridepy.providers import util + return util.read_checksum_file(checksum_file_path) @staticmethod def compute_md5(file_path: str, chunk_size: int = 4 * 1024 * 1024) -> str: - """ - Compute an MD5 checksum for integrity validation, not for security use. 
- """ - try: - md5 = hashlib.md5(usedforsecurity=False) - except TypeError: - md5 = hashlib.md5() - with open(file_path, "rb") as file_handle: - while True: - chunk = file_handle.read(chunk_size) - if not chunk: - break - md5.update(chunk) - return md5.hexdigest() + """Shim — see :func:`pridepy.providers.util.compute_md5`.""" + from pridepy.providers import util + return util.compute_md5(file_path, chunk_size) @staticmethod def validate_download(file_path: str, expected_checksum: Optional[str] = None) -> Tuple[bool, str]: - """ - Validate a local file exists, is non-empty, and checksum matches when provided. - """ - if not os.path.exists(file_path): - return False, "file does not exist" - if os.path.getsize(file_path) == 0: - return False, "file is empty" - if expected_checksum: - actual_checksum = Files.compute_md5(file_path) - if actual_checksum.lower() != expected_checksum.lower(): - return False, ( - f"checksum mismatch (expected={expected_checksum.lower()}, actual={actual_checksum.lower()})" - ) - return True, "ok" + """Shim — see :func:`pridepy.providers.util.validate_download`.""" + from pridepy.providers import util + return util.validate_download(file_path, expected_checksum) @staticmethod def _remove_if_exists(file_path: str) -> None: - """ - Remove a file if it already exists locally. - """ - if os.path.exists(file_path): - os.remove(file_path) + """Shim — see :func:`pridepy.providers.util._remove_if_exists`.""" + from pridepy.providers import util + return util._remove_if_exists(file_path) @staticmethod def _get_download_url(file_record: Dict, protocol: str) -> str: - """ - Resolve the public download URL for a file and protocol. - - Raises ValueError when the requested protocol has no suitable location. - Aspera requires a dedicated "Aspera Protocol" entry; ftp/s3/globus - derive their URL from the "FTP Protocol" entry (falling back to an - arbitrary non-Aspera location would produce a URL the caller cannot - actually transfer with). 
- """ - locations = file_record.get("publicFileLocations", []) - if not locations: - raise ValueError("No public file locations present") - - aspera_url = None - ftp_url = None - for location in locations: - name = location.get("name") - if name == "Aspera Protocol": - aspera_url = location.get("value") - elif name == "FTP Protocol": - ftp_url = location.get("value") - - if protocol == "aspera": - if not aspera_url: - raise ValueError("Aspera URL not available") - return aspera_url - - if not ftp_url: - raise ValueError("FTP URL not available") - if protocol == "ftp": - return ftp_url - if protocol == "globus": - return ftp_url.replace( - Files.PRIDE_ARCHIVE_FTP_URL_PREFIX, - Files.PRIDE_ARCHIVE_HTTPS_URL_PREFIX, - 1, - ) - if protocol == "s3": - return ftp_url - raise ValueError(f"Unsupported protocol: {protocol}") + """Shim — see :func:`pridepy.providers.util._get_download_url`.""" + from pridepy.providers import util + return util._get_download_url(file_record, protocol) @staticmethod def _resolve_local_path(file_record: Dict, output_folder: str) -> str: - """ - Compute the canonical local path for a file regardless of transfer protocol. - """ - try: - canonical_url = Files._get_download_url(file_record, "ftp") - except ValueError: - canonical_url = "" - if canonical_url: - return Files.get_output_file_name(canonical_url, file_record, output_folder) - return os.path.join(output_folder, file_record["fileName"]) + """Shim — see :func:`pridepy.providers.util._resolve_local_path`.""" + from pridepy.providers import util + return util._resolve_local_path(file_record, output_folder) @staticmethod def _protocol_sequence(protocol: str) -> List[str]: - """ - Build the ordered list of protocols to try for a requested download mode. 
- """ - if protocol not in Files.PROTOCOL_ORDER: - return [] - return [protocol] + [p for p in Files.PROTOCOL_ORDER if p != protocol] + """Shim — see :meth:`pridepy.providers.pride.PrideProvider._protocol_sequence`.""" + from pridepy.providers.pride import PrideProvider + return PrideProvider._protocol_sequence(protocol) @staticmethod def is_massive_accession(accession: str) -> bool: - """ - Return True when the accession looks like a MassIVE dataset accession. - """ - if not accession: - return False - return bool(re.fullmatch(r"R?MSV\d{9}", accession.upper())) + """Shim — see :meth:`pridepy.providers.massive.MassiveProvider.matches`.""" + from pridepy.providers.massive import MassiveProvider + return MassiveProvider.matches(accession) @staticmethod def _get_massive_public_root(accession: str) -> str: - normalized_accession = accession.upper() - return f"/v01/{normalized_accession}" + from pridepy.providers.massive import MassiveProvider + return MassiveProvider._get_public_root(accession) @staticmethod def _get_massive_public_ftp_url(accession: str, remote_path: str) -> str: - root_path = Files._get_massive_public_root(accession).rstrip("/") - relative_path = remote_path - if remote_path.startswith(root_path): - relative_path = remote_path[len(root_path) :].lstrip("/") - return f"{Files.MASSIVE_ARCHIVE_FTP_URL_PREFIX}{accession.upper()}/{relative_path}" + from pridepy.providers.massive import MassiveProvider + return MassiveProvider._get_public_ftp_url(accession, remote_path) @staticmethod def _map_massive_collection_to_category(collection: str) -> str: - return Files.MASSIVE_CATEGORY_MAP.get(collection.lower(), "OTHER") + from pridepy.providers.massive import MassiveProvider + return MassiveProvider._map_collection_to_category(collection) @staticmethod def _build_massive_file_record(accession: str, ftp_url: str) -> Dict: - parsed = urlparse(ftp_url) - root_prefix = f"/v01/{accession.upper()}/" - relative_path = parsed.path - if 
relative_path.startswith(root_prefix): - relative_path = relative_path[len(root_prefix) :] - relative_path = relative_path.lstrip("/") - collection = relative_path.split("/", 1)[0] if relative_path else "" - return { - "accession": accession.upper(), - "fileName": os.path.basename(parsed.path), - "fileCategory": {"value": Files._map_massive_collection_to_category(collection)}, - "publicFileLocations": [{"name": "FTP Protocol", "value": ftp_url}], - "relativePath": relative_path, - "collection": collection, - "source": "MassIVE", - } + from pridepy.providers.massive import MassiveProvider + return MassiveProvider._build_file_record(accession, ftp_url) @staticmethod def is_jpost_accession(accession: str) -> bool: - """ - Return True when the accession looks like a JPOST dataset accession. - """ - if not accession: - return False - return bool(re.fullmatch(r"JPST\d{6}", accession.upper())) + """Shim — see :meth:`pridepy.providers.jpost.JpostProvider.matches`.""" + from pridepy.providers.jpost import JpostProvider + return JpostProvider.matches(accession) @staticmethod def _get_jpost_public_root(accession: str) -> str: - return f"/{accession.upper()}" + from pridepy.providers.jpost import JpostProvider + return JpostProvider._get_public_root(accession) @staticmethod def _get_jpost_public_ftp_url(accession: str, remote_path: str) -> str: - root_path = Files._get_jpost_public_root(accession).rstrip("/") - relative_path = remote_path - if remote_path.startswith(root_path): - relative_path = remote_path[len(root_path) :].lstrip("/") - return f"{Files.JPOST_ARCHIVE_FTP_URL_PREFIX}{accession.upper()}/{relative_path}" + from pridepy.providers.jpost import JpostProvider + return JpostProvider._get_public_ftp_url(accession, remote_path) @staticmethod - def _build_jpost_file_record( - accession: str, ftp_url: str, category_from_proxi: Optional[str] = None - ) -> Dict: - """ - Build a pridepy file record for a JPOST file. 
+ def _build_jpost_file_record(accession, ftp_url, category_from_proxi=None): + from pridepy.providers.jpost import JpostProvider + return JpostProvider._build_file_record(accession, ftp_url, category_from_proxi) - When ``category_from_proxi`` is provided (e.g. ``"Associated raw file URI"``), - the PROXI CV name takes precedence over the heuristic collection-from-path - mapping. Falls back to the same path-segment heuristic used for MassIVE - when the category isn't known. - """ - parsed = urlparse(ftp_url) - root_prefix = f"/{accession.upper()}/" - relative_path = parsed.path - if relative_path.startswith(root_prefix): - relative_path = relative_path[len(root_prefix) :] - relative_path = relative_path.lstrip("/") - collection = relative_path.split("/", 1)[0] if relative_path else "" - if category_from_proxi and category_from_proxi in Files.JPOST_PROXI_CATEGORY_MAP: - category = Files.JPOST_PROXI_CATEGORY_MAP[category_from_proxi] - else: - category = Files._map_massive_collection_to_category(collection) - return { - "accession": accession.upper(), - "fileName": os.path.basename(parsed.path), - "fileCategory": {"value": category}, - "publicFileLocations": [{"name": "FTP Protocol", "value": ftp_url}], - "relativePath": relative_path, - "collection": collection, - "source": "JPOST", - } + @staticmethod + def _build_iprox_file_record(accession, https_url, category_from_px=None): + """Shim — see :meth:`pridepy.providers.iprox.IproxProvider._build_file_record`.""" + from pridepy.providers.iprox import IproxProvider + return IproxProvider._build_file_record(accession, https_url, category_from_px) @staticmethod - def _build_iprox_file_record( - accession: str, https_url: str, category_from_px: Optional[str] = None - ) -> Dict: - """ - Build a pridepy file record for an iProX file. iProX exposes files - over anonymous HTTPS at - ``http://download.iprox.org///``; - ``category_from_px`` is the ``cvParam`` ``name`` from the dataset's - ProteomeXchange XML (e.g. 
``"Associated raw file URI"``). - """ - parsed = urlparse(https_url) - root_prefix = f"/{accession.upper()}/" - relative_path = parsed.path - if relative_path.startswith(root_prefix): - relative_path = relative_path[len(root_prefix) :] - relative_path = relative_path.lstrip("/") - collection = relative_path.split("/", 1)[0] if relative_path else "" - if category_from_px and category_from_px in Files.IPROX_PX_CATEGORY_MAP: - category = Files.IPROX_PX_CATEGORY_MAP[category_from_px] - else: - category = Files._map_massive_collection_to_category(collection) - return { - "accession": accession.upper(), - "fileName": os.path.basename(parsed.path), - "fileCategory": {"value": category}, - # ``FTP Protocol`` is the existing label the download dispatcher - # uses to locate a file URL; here it actually points at HTTPS. - # ``_download_direct_download_records`` routes by URL scheme. - "publicFileLocations": [{"name": "FTP Protocol", "value": https_url}], - "relativePath": relative_path, - "collection": collection, - "source": "iProX", - } + def _get_iprox_public_root(accession: str) -> str: + from pridepy.providers.iprox import IproxProvider + return IproxProvider._get_public_root(accession) + + @staticmethod + def _get_iprox_public_ftp_url(accession: str, remote_path: str) -> str: + from pridepy.providers.iprox import IproxProvider + return IproxProvider._get_public_ftp_url(accession, remote_path) @staticmethod def is_direct_download_accession(accession: str) -> bool: + """Shim — True for MassIVE/JPOST/iProX (explicitly excludes PRIDE). + + PRIDE is also a registered provider but PRIDE downloads go through + the multi-protocol orchestrator (FTP/Aspera/S3/Globus with checksum + validation and fallback), not the direct-download partitioned-by-URL- + scheme path. So we filter PRIDE out here. """ - Return True when the accession is served by a public repository that - pridepy supports via direct downloads (no ProteomeXchange API). 
- MassIVE and JPOST use FTP(S); iProX uses anonymous HTTPS via - ``download.iprox.org``. - """ - return ( - Files.is_massive_accession(accession) - or Files.is_jpost_accession(accession) - or Files.is_iprox_accession(accession) - ) + from pridepy.providers import registry + try: + provider = registry.resolve(accession) + except ValueError: + return False + return provider.name != "pride" @staticmethod def is_iprox_accession(accession: str) -> bool: - """ - Return True when the accession looks like an iProX dataset accession - (``IPX`` followed by 7-10 digits). iProX exposes the dataset - ProteomeXchange XML at - ``http://download.iprox.org//PX_.xml`` and the - referenced files are downloadable from ``download.iprox.org`` over - anonymous HTTPS with byte-range support. - """ - if not accession: - return False - return bool(re.fullmatch(r"IPX\d{7,10}", accession.upper())) + """Shim — see :meth:`pridepy.providers.iprox.IproxProvider.matches`.""" + from pridepy.providers.iprox import IproxProvider + return IproxProvider.matches(accession) @staticmethod def _repo_uses_tls(accession: str) -> bool: - """ - Whether the public FTP server for ``accession`` requires FTP over TLS. - MassIVE rejects plain anonymous FTP (``421 TLS is required``); JPOST - accepts plain FTP. - """ - return Files.is_massive_accession(accession) + """Shim — returns the resolved provider's use_tls flag (False if unknown).""" + from pridepy.providers import registry + try: + provider = registry.resolve(accession) + except ValueError: + return False + return getattr(provider, "use_tls", False) @staticmethod def _walk_ftp_tree(ftp: FTP, remote_dir: str) -> List[str]: - """ - Recursively list files under a remote FTP directory. 
- """ - file_paths: List[str] = [] - try: - entries = list(ftp.mlsd(remote_dir)) - for name, facts in entries: - if name in {".", ".."}: - continue - child_path = posixpath.join(remote_dir.rstrip("/"), name) - if facts.get("type") == "dir": - file_paths.extend(Files._walk_ftp_tree(ftp, child_path)) - elif facts.get("type") == "file": - file_paths.append(child_path) - return file_paths - except (AttributeError, ftplib.error_perm): - pass - - current_dir = ftp.pwd() - listing: List[str] = [] - try: - ftp.cwd(remote_dir) - ftp.retrlines("LIST", listing.append) - for entry in listing: - parts = entry.split(maxsplit=8) - if len(parts) < 9: - continue - name = parts[8] - if name in {".", ".."}: - continue - child_path = posixpath.join(remote_dir.rstrip("/"), name) - if entry.startswith("d"): - file_paths.extend(Files._walk_ftp_tree(ftp, child_path)) - else: - file_paths.append(child_path) - finally: - ftp.cwd(current_dir) - return file_paths + """Shim — see :func:`pridepy.providers.transport._walk_ftp_tree`.""" + from pridepy.providers import transport + return transport._walk_ftp_tree(ftp=ftp, remote_dir=remote_dir) @staticmethod def _open_ftp_connection(host: str, use_tls: bool, timeout: int = 30) -> FTP: - """ - Open an anonymous FTP connection, transparently using FTPS when the - server requires TLS (e.g., MassIVE). When ``use_tls`` is False but the - server replies ``421 TLS is required`` to ``login``, transparently - retry with FTPS so callers don't need to know the policy in advance. 
- """ - if use_tls: - ftp: FTP = ftplib.FTP_TLS(host, timeout=timeout) - ftp.login() - ftp.prot_p() - else: - ftp = FTP(host, timeout=timeout) - try: - ftp.login() - except ftplib.error_temp as e: - if "TLS" in str(e).upper(): - try: - ftp.close() - except Exception: - pass - ftp = ftplib.FTP_TLS(host, timeout=timeout) - ftp.login() - ftp.prot_p() - else: - raise - ftp.set_pasv(True) - return ftp - - def _list_ftp_repo_files( - self, - host: str, - remote_root: str, - error_label: str, - use_tls: bool = False, - ) -> List[str]: - """ - Connect to an anonymous FTP host (FTP or FTPS), walk a directory tree, - and return file paths. + """Shim — see :func:`pridepy.providers.transport._open_ftp_connection`.""" + from pridepy.providers import transport + return transport._open_ftp_connection(host=host, use_tls=use_tls, timeout=timeout) - ``use_tls`` should be True for servers that reject plain FTP (e.g. - MassIVE). Centralizes connection lifecycle so a constructor failure - doesn't mask the underlying error in ``finally`` (PR #98 review). - """ - ftp: Optional[FTP] = None - try: - ftp = self._open_ftp_connection(host, use_tls=use_tls) - logging.info(f"Connected to FTP host: {host} (tls={use_tls})") - return self._walk_ftp_tree(ftp, remote_root) - except Exception as error: - raise RuntimeError( - f"Unable to list public files for {error_label}: {error}" - ) from error - finally: - if ftp is not None: - try: - ftp.quit() - except Exception: - try: - ftp.close() - except Exception: - pass + @staticmethod + def _list_ftp_repo_files(host, remote_root, error_label, use_tls=False): + """Shim — see :func:`pridepy.providers.transport._list_ftp_repo_files`.""" + from pridepy.providers import transport + return transport._list_ftp_repo_files(host=host, remote_root=remote_root, error_label=error_label, use_tls=use_tls) def _list_massive_public_files(self, accession: str) -> List[Dict]: - """ - Discover all public files for a MassIVE dataset from its anonymous FTP tree. 
- """ - normalized_accession = accession.upper() - remote_root = self._get_massive_public_root(normalized_accession) - remote_files = self._list_ftp_repo_files( - host=self.MASSIVE_ARCHIVE_FTP, - remote_root=remote_root, - error_label=f"MassIVE dataset {normalized_accession}", - use_tls=True, - ) - return [ - self._build_massive_file_record( - normalized_accession, - self._get_massive_public_ftp_url(normalized_accession, remote_file), - ) - for remote_file in remote_files - ] + """Shim — see :meth:`pridepy.providers.massive.MassiveProvider.list_files`.""" + from pridepy.providers.massive import MassiveProvider + return MassiveProvider().list_files(accession) def _download_massive_file_records( self, @@ -562,11 +247,12 @@ def _download_massive_file_records( ) -> None: """ Download public MassIVE files via anonymous FTP (now FTPS). - Backward-compat wrapper around :meth:`_download_direct_download_records`. + Backward-compat shim — dispatches via the provider registry. """ - self._download_direct_download_records( + from pridepy.providers import registry + registry.resolve(accession).download_files( accession=accession, - file_records=file_records, + records=file_records, output_folder=output_folder, skip_if_downloaded_already=skip_if_downloaded_already, protocol=protocol, @@ -577,12 +263,11 @@ def _list_jpost_public_files(self, accession: str) -> List[Dict]: """ Discover all public files for a JPOST dataset. - Prefers the JPOST PROXI JSON endpoint at - ``https://repository.jpostdb.org/proxi/datasets/`` since it - returns file URLs with category labels and avoids the anonymous-FTP - rate limit that ``ftp.jpostdb.org`` applies per source IP. Falls back - to walking the FTP tree if PROXI is unreachable or returns no files. + Delegates to JpostProvider but routes via the shim methods so that + test patches on ``_list_jpost_public_files_via_proxi`` and + ``_list_ftp_repo_files`` continue to intercept. 
""" + from pridepy.providers.jpost import JpostProvider normalized_accession = accession.upper() try: return self._list_jpost_public_files_via_proxi(normalized_accession) @@ -591,213 +276,50 @@ def _list_jpost_public_files(self, accession: str) -> List[Dict]: f"JPOST PROXI listing failed for {normalized_accession} " f"({proxi_error}); falling back to FTP tree walk." ) - remote_root = self._get_jpost_public_root(normalized_accession) + remote_root = JpostProvider._get_public_root(normalized_accession) remote_files = self._list_ftp_repo_files( - host=self.JPOST_ARCHIVE_FTP, + host=JpostProvider.ARCHIVE_FTP, remote_root=remote_root, error_label=f"JPOST dataset {normalized_accession}", ) return [ self._build_jpost_file_record( normalized_accession, - self._get_jpost_public_ftp_url(normalized_accession, remote_file), + JpostProvider._get_public_ftp_url(normalized_accession, remote_file), ) for remote_file in remote_files ] def _list_jpost_public_files_via_proxi(self, accession: str) -> List[Dict]: - """ - Fetch the JPOST PROXI dataset metadata and turn each ``datasetFiles`` - entry into a pridepy file record. The PROXI ``name`` field is mapped to - a PRIDE-style category so existing RAW/SEARCH/RESULT filtering works. 
- """ - import json as _json - - proxi_url = f"{self.JPOST_PROXI_BASE_URL}{accession}" - logging.info(f"Fetching JPOST PROXI metadata: {proxi_url}") - response = requests.get( - proxi_url, - headers={"Accept": "application/json"}, - timeout=30, - ) - response.raise_for_status() - data = _json.loads(response.content) - dataset_files = data.get("datasetFiles") or [] - records: List[Dict] = [] - for entry in dataset_files: - value = (entry or {}).get("value") - if not value or not value.startswith("ftp://"): - continue - records.append( - self._build_jpost_file_record( - accession, - value, - category_from_proxi=(entry or {}).get("name"), - ) - ) - if not records: - raise RuntimeError( - f"JPOST PROXI returned no FTP file URIs for {accession}" - ) - return records + """Shim — see :meth:`pridepy.providers.jpost.JpostProvider._list_via_proxi`.""" + from pridepy.providers.jpost import JpostProvider + return JpostProvider()._list_via_proxi(accession) def _list_iprox_public_files(self, accession: str) -> List[Dict]: - """ - Discover all public files for an iProX dataset. - - iProX publishes the ProteomeXchange XML for every public dataset at a - deterministic path on its anonymous HTTPS download server:: - - http://download.iprox.org//PX_.xml - - We fetch that XML, walk every ````'s ``cvParam`` entries, - and turn each ``Associated raw file URI`` (and sibling URIs for - search-engine output, result files, etc.) into a pridepy file record. - File downloads themselves go through plain HTTPS on the same host, - which supports ``Range`` requests for resume. 
- """ - normalized_accession = accession.upper() - xml_url = self.IPROX_PX_XML_URL_TEMPLATE.format(accession=normalized_accession) - logging.info(f"Fetching iProX PX XML: {xml_url}") - response = requests.get(xml_url, timeout=30) - response.raise_for_status() - try: - root = ET.fromstring(response.content) - except ET.ParseError as parse_error: - raise RuntimeError( - f"Unable to parse iProX PX XML for {normalized_accession}: {parse_error}" - ) from parse_error - - records: List[Dict] = [] - for dataset_file in root.iter("DatasetFile"): - for cv in dataset_file.findall("cvParam"): - name = cv.attrib.get("name") - value = cv.attrib.get("value") - if not value or not name or not name.endswith("URI"): - continue - if not value.lower().startswith(("http://", "https://")): - continue - records.append( - self._build_iprox_file_record( - normalized_accession, - value, - category_from_px=name, - ) - ) - if not records: - raise RuntimeError( - f"iProX PX XML for {normalized_accession} contained no downloadable HTTPS URIs" - ) - return records + """Shim — see :meth:`pridepy.providers.iprox.IproxProvider.list_files`.""" + from pridepy.providers.iprox import IproxProvider + return IproxProvider().list_files(accession) - def _list_direct_download_files(self, accession: str) -> List[Dict]: - """ - Dispatch to the right listing transport for a direct-download - repository: MassIVE walks FTPS, JPOST uses PROXI JSON over HTTPS with - an FTP fallback, iProX uses the dataset's PX XML over HTTPS. 
- """ - if self.is_massive_accession(accession): - return self._list_massive_public_files(accession) - if self.is_jpost_accession(accession): - return self._list_jpost_public_files(accession) - if self.is_iprox_accession(accession): - return self._list_iprox_public_files(accession) - raise ValueError( - f"Accession {accession} is not a direct-download repository accession" - ) - - def _download_direct_download_records( - self, - accession: str, - file_records: List[Dict], - output_folder: str, - skip_if_downloaded_already: bool, - protocol: str, - parallel_files: int = 1, - ) -> None: - """ - Download files from a direct-download repository. - - MassIVE and JPOST use anonymous FTP(S) with REST-based resume and - per-host parallel workers. iProX uses anonymous HTTPS via - ``download.iprox.org`` with ``Range``-based resume and per-file - parallel workers. URLs are partitioned by scheme so a mixed batch - (e.g. a JPOST PX XML that ever pointed at HTTPS) routes correctly. - """ - if protocol not in ("ftp", "https", "http"): - logging.warning( - "Direct downloads currently use ftp / https only. " - f"Ignoring requested protocol '{protocol}' for {accession}." 
- ) - - all_urls = [self._get_download_url(record, "ftp") for record in file_records] - ftp_urls = [u for u in all_urls if u.lower().startswith("ftp://")] - http_urls = [u for u in all_urls if u.lower().startswith(("http://", "https://"))] - if not ftp_urls and not http_urls: - logging.info(f"No files matched for direct-download dataset {accession}") - return - - if ftp_urls: - self.download_ftp_urls( - ftp_urls=ftp_urls, - output_folder=output_folder, - skip_if_downloaded_already=skip_if_downloaded_already, - use_tls=self._repo_uses_tls(accession), - parallel_files=parallel_files, - ) - if http_urls: - self.download_http_urls( - http_urls=http_urls, - output_folder=output_folder, - skip_if_downloaded_already=skip_if_downloaded_already, - parallel_files=parallel_files, - ) async def stream_all_files_metadata(self, output_file, accession=None): - """ - get stream all project files from PRIDE API in JSON format - """ - if accession is None: - request_url = f"{self.V3_API_BASE_URL}/files/all" - count_request_url = f"{self.V3_API_BASE_URL}/files/count" - else: - request_url = f"{self.V3_API_BASE_URL}/projects/{accession}/files/all" - count_request_url = f"{self.V3_API_BASE_URL}/projects/{accession}/files/count" - headers = {"Accept": "application/JSON"} - response = Util.get_api_call(count_request_url, headers) - total_records = response.json() - - regex_search_pattern = '"fileName"' - await Util.stream_response_to_file( - output_file, total_records, regex_search_pattern, request_url, headers - ) + """Shim — see :meth:`pridepy.providers.pride.PrideProvider.stream_all_files_metadata`.""" + from pridepy.providers.pride import PrideProvider + return await PrideProvider().stream_all_files_metadata(output_file, accession) def stream_all_files_by_project(self, accession) -> List[Dict]: - """ - get stream all project files from PRIDE API in JSON format - """ - request_url = f"{self.V3_API_BASE_URL}/projects/{accession}/files/all" - headers = {"Accept": "application/JSON"} - 
record_files = Util.read_json_stream(api_url=request_url, headers=headers) - return record_files + """Shim — see :meth:`pridepy.providers.pride.PrideProvider.stream_all_files_by_project`.""" + from pridepy.providers.pride import PrideProvider + return PrideProvider().stream_all_files_by_project(accession) def get_all_raw_file_list(self, project_accession): - """ - Get all raw file lists from PRIDE API for a given project_accession - :param project_accession: PRIDE accession - :return: raw file list in JSON format - """ - if self.is_direct_download_accession(project_accession): - record_files = self._list_direct_download_files(project_accession) - return [ - file for file in record_files if file["fileCategory"]["value"] == "RAW" - ] - - record_files = self.stream_all_files_by_project(project_accession) + """Get raw file list for any registered provider. - # Filter projects by fileCategory = RAW - raw_files = [file for file in record_files if file["fileCategory"]["value"] == "RAW"] - return raw_files + Returns the dataset's file records filtered to fileCategory == "RAW". + """ + from pridepy.providers import registry + provider = registry.resolve(project_accession) + records = provider.list_files(project_accession) + return [r for r in records if r["fileCategory"]["value"] == "RAW"] def download_all_raw_files( self, @@ -809,42 +331,21 @@ def download_all_raw_files( checksum_check: bool = False, parallel_files: int = 1, ): - """ - This method will download all the raw files from PRIDE PROJECT - :param output_folder: output directory where raw files will get saved - :param skip_if_downloaded_already: Boolean value to skip the download if the file has already been downloaded. - :param accession: PRIDE accession - :param protocol: ftp, aspera, globus - :param aspera_maximum_bandwidth: Aspera maximum bandwidth - :param checksum_check: Download checksum for a given project. 
- :return: None - """ - - if not (os.path.isdir(output_folder)): + """Download all RAW files for any registered provider.""" + if not os.path.isdir(output_folder): os.mkdir(output_folder) - - raw_files = self.get_all_raw_file_list(accession) - - if self.is_direct_download_accession(accession): - self._download_direct_download_records( - accession=accession, - file_records=raw_files, - output_folder=output_folder, - skip_if_downloaded_already=skip_if_downloaded_already, - protocol=protocol, - parallel_files=parallel_files, - ) - return - - self.download_files( - raw_files, - accession, - output_folder, - skip_if_downloaded_already, - protocol, - aspera_maximum_bandwidth=aspera_maximum_bandwidth, - checksum_check=checksum_check, + from pridepy.providers import registry + provider = registry.resolve(accession) + records = self.get_all_raw_file_list(accession) + provider.download_files( + accession=accession, + records=records, + output_folder=output_folder, + skip_if_downloaded_already=skip_if_downloaded_already, + protocol=protocol, parallel_files=parallel_files, + checksum_check=checksum_check, + aspera_maximum_bandwidth=aspera_maximum_bandwidth, ) @staticmethod @@ -855,118 +356,15 @@ def download_files_from_ftp( max_connection_retries=3, max_download_retries=3, ): - """ - Download files using a single FTP connection with a retry mechanism and a progress bar for each file. - :param file_list_json: file list in JSON format - :param output_folder: folder to download the files - :param skip_if_downloaded_already: Boolean value to skip the download if the file has already been downloaded. - :param max_connection_retries: Number of attempts to reconnect to the FTP server if the connection is lost. - :param max_download_retries: Number of attempts to retry the download of a file in case of failure. 
- """ - - if not os.path.isdir(output_folder): - os.makedirs(output_folder) - - def connect_ftp(): - """Helper function to establish FTP connection.""" - ftp = FTP(Files.PRIDE_ARCHIVE_FTP, timeout=30) - ftp.login() # Anonymous login - ftp.set_pasv(True) # Enable passive mode - logging.info(f"Connected to FTP host: {Files.PRIDE_ARCHIVE_FTP}") - return ftp - - connection_attempt = 0 - while connection_attempt < max_connection_retries: - try: - ftp = connect_ftp() - for file in file_list_json: - try: - # Get FTP download URL - if file["publicFileLocations"][0]["name"] == "FTP Protocol": - download_url = file["publicFileLocations"][0]["value"] - else: - download_url = file["publicFileLocations"][1]["value"] - - logging.debug("ftp_filepath:" + download_url) - - # Get output file path - new_file_path = Files.get_output_file_name( - download_url, file, output_folder - ) - - if skip_if_downloaded_already and os.path.exists(new_file_path): - logging.info("Skipping download as file already exists") - continue - - # Extract file path from the download URL - parsed_url = urlparse(download_url) - ftp_file_path = urllib.parse.unquote(parsed_url.path.lstrip("/")) - - logging.info(f"Starting FTP download: {ftp_file_path}") - - # Retry download in case of failure - download_attempt = 0 - while download_attempt < max_download_retries: - try: - # Get file size for progress tracking - total_size = ftp.size(ftp_file_path) - logging.info(f"File size: {total_size} bytes") - - # Initialize progress bar - with open(new_file_path, "wb") as f: - with tqdm( - total=total_size, - unit="B", - unit_scale=True, - desc=new_file_path, - ) as pbar: - - def callback(data): - f.write(data) - pbar.update(len(data)) - - # Retrieve the file with progress callback - ftp.retrbinary(f"RETR {ftp_file_path}", callback) - - logging.info(f"Successfully downloaded {new_file_path}") - break # Exit download retry loop if successful - except ( - socket.timeout, - ftplib.error_temp, - ftplib.error_perm, - ) as e: - 
download_attempt += 1 - logging.error( - f"Download failed for {new_file_path} (attempt {download_attempt}): {str(e)}" - ) - if download_attempt >= max_download_retries: - logging.error( - f"Giving up on {new_file_path} after {max_download_retries} attempts." - ) - break # Give up on this file after max retries - except (KeyError, IndexError) as e: - logging.error(f"Failed to process file due to missing data: {str(e)}") - except Exception as e: - logging.error(f"Unexpected error while processing file: {str(e)}") - ftp.quit() # Close FTP connection after all files are downloaded - logging.info(f"Disconnected from FTP host: {Files.PRIDE_ARCHIVE_FTP}") - break # Exit connection retry loop if everything was successful - except ( - socket.timeout, - ftplib.error_temp, - ftplib.error_perm, - socket.error, - ) as e: - connection_attempt += 1 - logging.error(f"FTP connection failed (attempt {connection_attempt}): {str(e)}") - if connection_attempt < max_connection_retries: - logging.info("Retrying connection...") - time.sleep(5) # Optional delay before retrying - else: - logging.error( - f"Giving up after {max_connection_retries} failed connection attempts." - ) - break + """Shim — see :meth:`pridepy.providers.pride.PrideProvider.download_files_from_ftp`.""" + from pridepy.providers.pride import PrideProvider + return PrideProvider.download_files_from_ftp( + file_list_json, + output_folder, + skip_if_downloaded_already, + max_connection_retries=max_connection_retries, + max_download_retries=max_download_retries, + ) @staticmethod def get_output_file_name(download_url, file, output_folder): @@ -1061,62 +459,18 @@ def _download_range(url, file_path, start, end, pbar, max_retries=3): @staticmethod def _parallel_download(url, file_path, position=0): - """Download a file via a single-connection HTTP stream with optional resume. 
- If a partial file exists and the server supports Range requests, resumes - from where it left off; otherwise restarts from scratch.""" - session = Util.create_session_with_retries() - try: - head = session.head(url, timeout=(30, 30)) - head.raise_for_status() - total_size = int(head.headers.get("content-length", 0)) - accept_ranges = head.headers.get("accept-ranges", "none").strip().lower() - except (requests.RequestException, ValueError) as exc: - logging.info(f"HEAD request failed, falling back to single connection: {exc}") - total_size = 0 - accept_ranges = "none" - - resume_size = 0 - if os.path.exists(file_path) and accept_ranges == "bytes" and total_size > 0: - resume_size = os.path.getsize(file_path) - if resume_size >= total_size: - logging.info(f"File already complete: {file_path}") - return - if resume_size > 0: - logging.info(f"Resuming download from {resume_size} bytes: {file_path}") - - headers = {"Range": f"bytes={resume_size}-"} if resume_size > 0 else {} - with session.get(url, headers=headers, stream=True, timeout=(30, 60)) as r: - r.raise_for_status() - if resume_size > 0 and r.status_code != 206: - logging.warning("Server did not honor Range request (status %s), restarting download", r.status_code) - resume_size = 0 - with tqdm(total=total_size, unit="B", unit_scale=True, desc=file_path, - initial=resume_size, position=position, leave=True) as pbar: - mode = "ab" if resume_size > 0 else "wb" - with open(file_path, mode, buffering=8 * 1024 * 1024) as f: - for chunk in r.iter_content(chunk_size=8 * 1024 * 1024): - if chunk: - f.write(chunk) - pbar.update(len(chunk)) + """Shim — see :func:`pridepy.providers.transport._parallel_download`.""" + from pridepy.providers import transport + return transport._parallel_download(url=url, file_path=file_path, position=position) @staticmethod def _globus_download_one(file, output_folder, skip_if_downloaded_already, max_retries=6, position=0): - """Download a single file via globus; used as a worker target.""" 
- download_url = Files._get_download_url(file, "globus") - new_file_path = Files.get_output_file_name(download_url, file, output_folder) - - if skip_if_downloaded_already and os.path.exists(new_file_path): - logging.info(f"Skipping download as file already exists: {new_file_path}") - return - - for attempt in range(1, max_retries + 1): - try: - Files._parallel_download(download_url, new_file_path, position=position) - return - except Exception as e: - logging.warning(f"Attempt {attempt}/{max_retries} failed for {file.get('fileName', '?')}: {e}") - if attempt == max_retries: - raise + """Shim — see :meth:`pridepy.providers.pride.PrideProvider._globus_download_one`.""" + from pridepy.providers.pride import PrideProvider + return PrideProvider._globus_download_one( + file, output_folder, skip_if_downloaded_already, + max_retries=max_retries, position=position, + ) @staticmethod def download_files_from_globus( @@ -1124,172 +478,28 @@ def download_files_from_globus( parallel_files: int = 1, checksum_map: Optional[Dict[str, str]] = None, ): - """ - Download files using globus transfer url with progress bar for each file. - When skip_if_downloaded_already is True, files are pre-filtered so that - only missing or incomplete files are submitted to the worker pool, - ensuring the -w parallel_files parameter is fully utilised. - When checksum_map is provided, existing files are validated against - their expected checksum; corrupted files are re-downloaded. - :param file_list_json: file list in json format - :param output_folder: folder to download the files - :param skip_if_downloaded_already: Boolean value to skip the download if the file has already been downloaded. 
- :param parallel_files: number of files to download simultaneously - :param checksum_map: mapping of file name to expected MD5 checksum - """ - if checksum_map is None: - checksum_map = {} - - if not (os.path.isdir(output_folder)): - os.makedirs(output_folder, exist_ok=True) - - # --- Phase 0: pre-filter files that need downloading ----------------- - files_to_download: List[Dict] = [] - for file in file_list_json: - download_url = Files._get_download_url(file, "globus") - new_file_path = Files.get_output_file_name(download_url, file, output_folder) - if skip_if_downloaded_already and os.path.exists(new_file_path): - expected_cs = checksum_map.get(file.get("fileName", "")) - if expected_cs: - valid, reason = Files.validate_download(new_file_path, expected_cs) - if not valid: - logging.warning(f"Corrupted file detected ({reason}), will re-download: {new_file_path}") - files_to_download.append(file) - continue - logging.info(f"Skipping download as file already exists: {new_file_path}") - continue - files_to_download.append(file) - - if not files_to_download: - logging.info("All files already downloaded, nothing to do.") - return - - logging.info( - f"{len(file_list_json) - len(files_to_download)} file(s) skipped, " - f"{len(files_to_download)} file(s) to download" + """Shim — see :meth:`pridepy.providers.pride.PrideProvider.download_files_from_globus`.""" + from pridepy.providers.pride import PrideProvider + return PrideProvider.download_files_from_globus( + file_list_json, output_folder, skip_if_downloaded_already, + parallel_files=parallel_files, + checksum_map=checksum_map, ) - # --- Phase 1: download (skip check already done, pass False) --------- - parallel_files = min(parallel_files, 3, len(files_to_download)) - if parallel_files < 2: - for file in files_to_download: - try: - Files._globus_download_one( - file, output_folder, False - ) - new_file_path = Files.get_output_file_name( - Files._get_download_url(file, "globus"), file, output_folder - ) - 
logging.info(f"Successfully downloaded {new_file_path}") - except Exception as e: - logging.error(f"Download from Globus failed: {str(e)}") - else: - logging.info(f"Downloading {len(files_to_download)} file(s) with {parallel_files} parallel workers") - with ThreadPoolExecutor(max_workers=parallel_files) as executor: - futures = { - executor.submit( - Files._globus_download_one, - file, output_folder, False, - position=idx, - ): file - for idx, file in enumerate(files_to_download) - } - for future in as_completed(futures): - try: - future.result() - except Exception as e: - logging.error(f"Download from Globus failed: {str(e)}") - @staticmethod def download_files_from_s3( file_list_json: List[Dict], output_folder: str, skip_if_downloaded_already ): - """ - Download files using S3 transfer URL with a progress bar and retry logic. - :param file_list_json: file list in JSON format - :param output_folder: folder to download the files - :param skip_if_downloaded_already: Boolean value to skip the download if the file has already been downloaded. 
- """ - - if not os.path.isdir(output_folder): - os.makedirs(output_folder, exist_ok=True) - - # Retry and timeout config - retry_config = Config( - retries={"max_attempts": 5, "mode": "standard"}, - connect_timeout=120, # Increase timeout to 120 seconds - read_timeout=120, # Timeout for reading data - signature_version=botocore.UNSIGNED, # Unsigned requests for public data - ) - - s3_resource = boto3.resource( - "s3", - config=retry_config, - endpoint_url=Files.S3_URL, + """Shim — see :meth:`pridepy.providers.pride.PrideProvider.download_files_from_s3`.""" + from pridepy.providers.pride import PrideProvider + return PrideProvider.download_files_from_s3( + file_list_json, output_folder, skip_if_downloaded_already, ) - bucket = s3_resource.Bucket(Files.S3_BUCKET) - - for file in file_list_json: - try: - # Determine S3 or FTP path - download_url = ( - file["publicFileLocations"][0]["value"] - if file["publicFileLocations"][0]["name"] == "FTP Protocol" - else file["publicFileLocations"][1]["value"] - ) - - ftp_base_url = "ftp://ftp.pride.ebi.ac.uk/pride/data/archive/" - s3_path = download_url.replace(ftp_base_url, "") - new_file_path = Files.get_output_file_name(download_url, file, output_folder) - - if skip_if_downloaded_already == True and os.path.exists(new_file_path): - logging.info("Skipping download as file already exists") - continue - - logging.debug(f"Downloading From S3: {s3_path}") - - # Get file size for progress tracking - obj = bucket.Object(s3_path) - total_size = obj.content_length - - # Initialize progress bar - progress = Progress(total_size, new_file_path) - - # Download with progress bar and retry handling - for attempt in range(5): - try: - bucket.download_file(s3_path, new_file_path, Callback=progress) - progress.close() - logging.info(f"Successfully downloaded {new_file_path}") - break - except botocore.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "404": - logging.error("The object does not exist.") - break - else: - 
logging.error(f"Download failed: {e}") - if attempt < 4: - time.sleep(2**attempt) # Exponential backoff - logging.info(f"Retrying... ({attempt + 1}/5)") - else: - raise - except Exception as e: - logging.error(f"Failed to download {file['fileName']}: {e}") def get_submitted_file_path_prefix(self, accession): - """ - At pride repository, public data is disseminated according to a proper structure. - I.e. base/path/ + yyyy/mm/accession/ + submitted/ - This extracts the yyyy/mm/accession path fragment from the API by examine the file path - of a public file. - I.e. ftp://ftp.pride.ebi.ac.uk/pride/data/archive/2018/10/PXD008644/7550GI_Y.raw - :param accession: PRIDE accession - :return: path fragment (eg: 2018/10/PXD008644) - """ - results = self.get_all_raw_file_list(accession) - first_file = results[0]["publicFileLocations"][0]["value"] - path_fragment = re.search(r"\d{4}/\d{2}/PXD\d*", first_file).group() - return path_fragment + """Shim — see :meth:`pridepy.providers.pride.PrideProvider.get_submitted_file_path_prefix`.""" + from pridepy.providers.pride import PrideProvider + return PrideProvider().get_submitted_file_path_prefix(accession) def download_file_by_name( self, @@ -1316,11 +526,14 @@ def download_file_by_name( :param checksum_check: Download checksum for a given project. 
""" - if not (os.path.isdir(output_folder)): + if not os.path.isdir(output_folder): os.mkdir(output_folder) + from pridepy.providers import registry + provider = registry.resolve(accession) + ## Check type of project - if self.is_direct_download_accession(accession): + if provider.name in ("massive", "jpost", "iprox"): logging.info( "Downloading file from public direct-download dataset {}".format(accession) ) @@ -1329,9 +542,9 @@ def download_file_by_name( raise Exception( "File name {} not found in dataset {}".format(file_name, accession) ) - self._download_direct_download_records( + provider.download_files( accession=accession, - file_records=response, + records=response, output_folder=output_folder, skip_if_downloaded_already=skip_if_downloaded_already, protocol=protocol, @@ -1389,137 +602,31 @@ def get_file_from_api(self, accession, file_name) -> List[Dict]: :param file_name: file name :return: file in json format """ - + from pridepy.providers import registry try: - if self.is_direct_download_accession(accession): - files = self._list_direct_download_files(accession) - return [f for f in files if f["fileName"] == file_name] - files = self.stream_all_files_by_project(accession) - file = [f for f in files if f["fileName"] == file_name] - return file + records = registry.resolve(accession).list_files(accession) + return [r for r in records if r["fileName"] == file_name] except Exception as e: raise Exception("File not found " + str(e)) def download_private_file_name(self, accession, file_name, output_folder, username, password): - """ - Get the information for a given private file to be downloaded from the api. 
- :param accession: Project accession - :param file_name: The file name to be downloaded - :param username: Username with access to the dataset - :param password: Password for user with access to the dataset - """ - - auth = Authentication() - auth_token = auth.get_token(username, password) - validate_token = auth.validate_token(auth_token) - logging.info("Valid token after login: {}".format(validate_token)) - - url = self.API_PRIVATE_URL + "/projects/{}/files?search={}".format(accession, file_name) - content = requests.get(url, headers={"Authorization": "Bearer {}".format(auth_token)}) - if content.ok and content.status_code == 200: - json_file = content.json() - if ( - "_embedded" in json_file - and "files" in json_file["_embedded"] - and len(json_file["_embedded"]["files"]) == 1 - ): - download_url = json_file["_embedded"]["files"][0]["_links"]["download"]["href"] - logging.info(download_url) - - # Create a clean filename to save the downloaded file - new_file_path = os.path.join(output_folder, f"{file_name}") - - session = Util.create_session_with_retries() # Create session with retries - # Check if the file already exists - if os.path.exists(new_file_path): - resume_header = {"Range": f"bytes={os.path.getsize(new_file_path)}-"} - mode = "ab" # Append to file - resume_size = os.path.getsize(new_file_path) - else: - resume_header = {} - mode = "wb" # Write new file - resume_size = 0 - - with session.get( - download_url, stream=True, headers=resume_header, timeout=(10, 60) - ) as r: - r.raise_for_status() - total_size = int(r.headers.get("content-length", 0)) + resume_size - block_size = 1024 * 1024 # 1 MB chunks - - with tqdm( - total=total_size, - unit="B", - unit_scale=True, - desc=new_file_path, - initial=resume_size, - ) as pbar: - with open(new_file_path, mode) as f: - for chunk in r.iter_content(chunk_size=block_size): - if chunk: - f.write(chunk) - pbar.update(len(chunk)) - - logging.info(f"Successfully downloaded {new_file_path}") - - else: - 
logging.info( - "File name {} found more than once for the given project {}".format( - file_name, accession - ) - ) - else: - logging.info( - f"File name {file_name} now found in the project {accession}, or user don't have access" - ) - raise Exception( - f"File name {file_name} now found in the project {accession}, or user don't have access" - ) + """Shim — see :meth:`pridepy.providers.pride.PrideProvider.download_private_file_name`.""" + from pridepy.providers.pride import PrideProvider + return PrideProvider().download_private_file_name( + accession, file_name, output_folder, username, password, + ) @staticmethod def get_ascp_binary(): - """ - Detect the OS and architecture, and return the appropriate ascp binary path. - - Returns: - str: Path to the correct ascp binary. - """ - os_type = platform.system().lower() - arch, _ = platform.architecture() - aspera_dir = importlib.resources.files("pridepy").joinpath("aspera/") - - if os_type == "linux": - if arch == "32bit": - return os.path.join(aspera_dir, "linux-32", "ascp") - elif arch == "64bit": - return os.path.join(aspera_dir, "linux-64", "ascp") - elif os_type == "darwin": # macOS (intel-based) - return os.path.join(aspera_dir, "mac-intel", "ascp") - elif os_type == "windows": - if arch == "32bit": - return os.path.join(aspera_dir, "windows-32", "ascp.exe") - elif arch == "64bit": - return os.path.join(aspera_dir, "windows-64", "ascp.exe") - else: - raise OSError(f"Unsupported OS or architecture: {os_type}, {arch}") + """Shim — see :meth:`pridepy.providers.pride.PrideProvider.get_ascp_binary`.""" + from pridepy.providers.pride import PrideProvider + return PrideProvider.get_ascp_binary() @staticmethod def save_checksum_file(accession, output_folder): - """ - Download and persist the checksum manifest for a PRIDE accession. 
- """ - os.makedirs(output_folder, exist_ok=True) - url = f"{Files.V3_API_BASE_URL}/files/checksum/{accession}" - headers = {"accept": "text/plain"} - request = urllib.request.Request(url, headers=headers, method="GET") - logging.info(f"Fetching checksum file from {url}") - with urllib.request.urlopen(request) as response: - data = response.read().decode("utf-8") - # Save the data to a .tsv file - output_path = os.path.join(output_folder, f"{accession}-checksum.tsv") - with open(output_path, "w", encoding="utf-8") as file: - file.write(data) - return output_path + """Shim — see :meth:`pridepy.providers.pride.PrideProvider.save_checksum_file`.""" + from pridepy.providers.pride import PrideProvider + return PrideProvider.save_checksum_file(accession, output_folder) @staticmethod def _batch_download_by_protocol( @@ -1531,44 +638,22 @@ def _batch_download_by_protocol( parallel_files: int = 1, checksum_map: Optional[Dict[str, str]] = None, ) -> None: + """Shim — see :meth:`pridepy.providers.pride.PrideProvider._batch_download_by_protocol`. + + Tests patch this method via ``patch.object(Files, "_batch_download_by_protocol")``; + :class:`PrideProvider` calls back through ``Files.X`` so those patches + keep intercepting. """ - Transfer a batch of files with one protocol, reusing a single - connection where the underlying helper supports it (FTP, S3). 
- """ - if not file_list: - return - if protocol == "ftp": - Files.download_files_from_ftp( - file_list, - output_folder, - skip_if_downloaded_already=skip_if_downloaded_already, - ) - return - if protocol == "aspera": - Files.download_files_from_aspera( - file_list, - output_folder, - skip_if_downloaded_already=skip_if_downloaded_already, - maximum_bandwidth=aspera_maximum_bandwidth, - ) - return - if protocol == "globus": - Files.download_files_from_globus( - file_list, - output_folder, - skip_if_downloaded_already=skip_if_downloaded_already, - parallel_files=parallel_files, - checksum_map=checksum_map or {}, - ) - return - if protocol == "s3": - Files.download_files_from_s3( - file_list, - output_folder, - skip_if_downloaded_already=skip_if_downloaded_already, - ) - return - raise ValueError(f"Unsupported protocol: {protocol}") + from pridepy.providers.pride import PrideProvider + return PrideProvider._batch_download_by_protocol( + file_list, + output_folder, + protocol, + skip_if_downloaded_already, + aspera_maximum_bandwidth, + parallel_files=parallel_files, + checksum_map=checksum_map, + ) @staticmethod def _download_with_fallback( @@ -1580,52 +665,17 @@ def _download_with_fallback( max_protocol_retries: int = 2, parallel_files: int = 1, ) -> bool: - """ - Download one file by trying each protocol in sequence, validating - after every attempt. Intended as the per-file fallback path; batch - download of the primary protocol is handled separately. 
- """ - local_path = Files._resolve_local_path(file_record, output_folder) - - for protocol in protocol_sequence: - for attempt in range(1, max_protocol_retries + 1): - logging.info( - f"Downloading {file_record['fileName']} via {protocol} " - f"(attempt {attempt}/{max_protocol_retries})" - ) - try: - Files._remove_if_exists(local_path) - Files._batch_download_by_protocol( - [file_record], - output_folder, - protocol, - skip_if_downloaded_already=False, - aspera_maximum_bandwidth=aspera_maximum_bandwidth, - parallel_files=parallel_files, - ) - except Exception as error: - logging.error( - f"Protocol {protocol} failed for {file_record['fileName']}: {error}" - ) - - valid, reason = Files.validate_download(local_path, expected_checksum) - if valid: - logging.info( - f"File {file_record['fileName']} downloaded successfully via {protocol}" - ) - return True - - logging.warning( - f"Validation failed for {file_record['fileName']} via {protocol}: {reason}" - ) - Files._remove_if_exists(local_path) - - logging.warning( - f"Protocol {protocol} exhausted for {file_record['fileName']}, switching protocol." - ) - - logging.error(f"All protocol attempts failed for {file_record['fileName']}") - return False + """Shim — see :meth:`pridepy.providers.pride.PrideProvider._download_with_fallback`.""" + from pridepy.providers.pride import PrideProvider + return PrideProvider._download_with_fallback( + file_record, + output_folder, + protocol_sequence, + expected_checksum, + aspera_maximum_bandwidth, + max_protocol_retries=max_protocol_retries, + parallel_files=parallel_files, + ) @staticmethod def download_files( @@ -1638,94 +688,18 @@ def download_files( checksum_check=False, parallel_files: int = 1, ): - """ - Download files using either FTP or Aspera transfer protocol. 
- :param file_list_json: File list in JSON format - :param accession: Project accession - :param output_folder: Folder to download the files - :param protocol: ftp, aspera, globus - :param aspera_maximum_bandwidth: parameter in Aspera sets the maximum bandwidth for the transfer. - :param skip_if_downloaded_already: Boolean value to skip the download if the file has already been downloaded. - """ - protocols_supported = ["ftp", "aspera", "globus", "s3"] - if protocol not in protocols_supported: - logging.error("Protocol should be one of ftp, aspera, globus, s3") - return - - os.makedirs(output_folder, exist_ok=True) - - checksum_map: Dict[str, str] = {} - if checksum_check: - checksum_file_path = Files.save_checksum_file(accession, output_folder) - checksum_map = Files.read_checksum_file(checksum_file_path) - logging.info(f"Loaded checksums for {len(checksum_map)} files") - - if not file_list_json: - return - - protocol_sequence = Files._protocol_sequence(protocol) - primary_protocol = protocol_sequence[0] - # Retry with the primary protocol first, then fall back to others - fallback_sequence = protocol_sequence - - # Phase 1: batch download with the requested protocol. Reuses a single - # FTP/S3 connection for all files (the previous behaviour) instead of - # paying the per-file reconnect cost in the common happy path. 
- logging.info( - f"Downloading {len(file_list_json)} file(s) via {primary_protocol} (batch)" + """Shim — see :meth:`pridepy.providers.pride.PrideProvider._download_files_batch`.""" + from pridepy.providers.pride import PrideProvider + return PrideProvider._download_files_batch( + file_list_json, + accession, + output_folder, + skip_if_downloaded_already, + protocol=protocol, + aspera_maximum_bandwidth=aspera_maximum_bandwidth, + checksum_check=checksum_check, + parallel_files=parallel_files, ) - try: - Files._batch_download_by_protocol( - file_list_json, - output_folder, - primary_protocol, - skip_if_downloaded_already=skip_if_downloaded_already, - aspera_maximum_bandwidth=aspera_maximum_bandwidth, - parallel_files=parallel_files, - checksum_map=checksum_map, - ) - except Exception as exc: - logging.warning( - f"Batch {primary_protocol} run hit an error; will retry individual failures: {exc}" - ) - - # Phase 2: validate every file and fall back per-file for the ones - # that are missing or invalid. 
- logging.info("Phase 2: validating %d downloaded file(s)", len(file_list_json)) - failed_files: List[str] = [] - for i, file_record in enumerate(file_list_json, 1): - expected_checksum = checksum_map.get(file_record["fileName"]) - local_path = Files._resolve_local_path(file_record, output_folder) - logging.info("Validating [%d/%d] %s", i, len(file_list_json), file_record["fileName"]) - valid, reason = Files.validate_download(local_path, expected_checksum) - if valid: - continue - - logging.warning( - f"{file_record['fileName']} invalid after {primary_protocol} ({reason})" - ) - if "checksum mismatch" in reason: - Files._remove_if_exists(local_path) - - if not fallback_sequence: - failed_files.append(file_record.get("fileName", "")) - continue - - success = Files._download_with_fallback( - file_record=file_record, - output_folder=output_folder, - protocol_sequence=fallback_sequence, - expected_checksum=expected_checksum, - aspera_maximum_bandwidth=aspera_maximum_bandwidth, - parallel_files=parallel_files, - ) - if not success: - failed_files.append(file_record.get("fileName", "")) - - if failed_files: - failed_summary = ", ".join(failed_files) - logging.error(f"Failed to download {len(failed_files)} file(s): {failed_summary}") - raise RuntimeError(f"Failed to download {len(failed_files)} file(s): {failed_summary}") def download_files_by_list( self, @@ -1757,10 +731,10 @@ def download_files_by_list( if not file_names: raise ValueError("file_names must contain at least one filename") - if self.is_direct_download_accession(accession): - all_files = self._list_direct_download_files(accession) - else: - all_files = self.stream_all_files_by_project(accession) + from pridepy.providers import registry + provider = registry.resolve(accession) + all_files = provider.list_files(accession) + requested = set(file_names) matched = [f for f in all_files if f.get("fileName") in requested] missing = sorted(requested - {f.get("fileName") for f in matched}) @@ -1771,26 +745,15 @@ def 
download_files_by_list( f"No matching files in project {accession} for: {sorted(requested)}" ) - if self.is_direct_download_accession(accession): - self._download_direct_download_records( - accession=accession, - file_records=matched, - output_folder=output_folder, - skip_if_downloaded_already=skip_if_downloaded_already, - protocol=protocol, - parallel_files=parallel_files, - ) - return - - self.download_files( - matched, - accession, - output_folder, - skip_if_downloaded_already, - protocol, - aspera_maximum_bandwidth=aspera_maximum_bandwidth, - checksum_check=checksum_check, + provider.download_files( + accession=accession, + records=matched, + output_folder=output_folder, + skip_if_downloaded_already=skip_if_downloaded_already, + protocol=protocol, parallel_files=parallel_files, + checksum_check=checksum_check, + aspera_maximum_bandwidth=aspera_maximum_bandwidth, ) @staticmethod @@ -2047,26 +1010,18 @@ def download_all_category_files( """ if categories is None: categories = [category] if category else ["RAW"] - raw_files = self.get_all_category_file_list(accession, categories) - if self.is_direct_download_accession(accession): - self._download_direct_download_records( - accession=accession, - file_records=raw_files, - output_folder=output_folder, - skip_if_downloaded_already=skip_if_downloaded_already, - protocol=protocol, - parallel_files=parallel_files, - ) - return - self.download_files( - raw_files, - accession, - output_folder, - skip_if_downloaded_already, - protocol, - aspera_maximum_bandwidth=aspera_maximum_bandwidth, - checksum_check=checksum_check, + records = self.get_all_category_file_list(accession, categories) + from pridepy.providers import registry + provider = registry.resolve(accession) + provider.download_files( + accession=accession, + records=records, + output_folder=output_folder, + skip_if_downloaded_already=skip_if_downloaded_already, + protocol=protocol, parallel_files=parallel_files, + checksum_check=checksum_check, + 
aspera_maximum_bandwidth=aspera_maximum_bandwidth, ) def get_all_category_file_list( @@ -2081,17 +1036,10 @@ def get_all_category_file_list( """ if isinstance(categories, str): categories = [categories] - category_set = {category.upper() for category in categories} - - if self.is_direct_download_accession(accession): - record_files = self._list_direct_download_files(accession) - else: - record_files = self.stream_all_files_by_project(accession) - - category_files = [ - file for file in record_files if file["fileCategory"]["value"] in category_set - ] - return category_files + category_set = {c.upper() for c in categories} + from pridepy.providers import registry + records = registry.resolve(accession).list_files(accession) + return [r for r in records if r["fileCategory"]["value"] in category_set] # ------------------------------- # ProteomeXchange support @@ -2177,8 +1125,9 @@ def download_px_raw_files( @staticmethod def _local_path_for_url(download_url: str, output_folder: str) -> str: - filename = os.path.basename(urlparse(download_url).path) - return os.path.join(output_folder, filename) + """Shim — see :func:`pridepy.providers.transport._local_path_for_url`.""" + from pridepy.providers import transport + return transport._local_path_for_url(download_url=download_url, output_folder=output_folder) @staticmethod def _download_one_ftp_path( @@ -2189,73 +1138,16 @@ def _download_one_ftp_path( max_download_retries: int, position: int = 0, ) -> None: - """ - Download a single FTP path over an existing connection, with REST resume - and per-file retry. Raises on giving up so the caller can decide what to do. 
- """ - if skip_if_downloaded_already and os.path.exists(local_path): - logging.info(f"Skipping download as file already exists: {local_path}") - return - - attempt = 0 - last_error: Optional[Exception] = None - while attempt < max_download_retries: - try: - total_size = ftp.size(ftp_path) - if os.path.exists(local_path): - current_size = os.path.getsize(local_path) - mode = "ab" - else: - current_size = 0 - mode = "wb" - - with open(local_path, mode) as f, tqdm( - total=total_size, - unit="B", - unit_scale=True, - desc=local_path, - initial=current_size, - position=position, - leave=True, - ) as pbar: - def callback(data): - f.write(data) - pbar.update(len(data)) - - if current_size: - try: - ftp.sendcmd(f"REST {current_size}") - except Exception: - current_size = 0 - f.seek(0) - f.truncate() - ftp.retrbinary(f"RETR {ftp_path}", callback) - - # Post-transfer integrity check: server-reported size must match - # the local size. Catches half-finished transfers that retrbinary - # didn't raise on (e.g. server closed the data channel early). - # The next iteration will REST-resume from where we left off. 
- if total_size: - final_size = os.path.getsize(local_path) - if final_size != total_size: - attempt += 1 - logging.error( - f"Size mismatch for {local_path}: " - f"got {final_size} bytes, expected {total_size} " - f"(attempt {attempt})" - ) - continue - logging.info(f"Successfully downloaded {local_path}") - return - except (socket.timeout, ftplib.error_temp, ftplib.error_perm) as e: - attempt += 1 - last_error = e - logging.error( - f"Download failed for {local_path} (attempt {attempt}): {e}" - ) - raise RuntimeError( - f"Giving up on {local_path} after {max_download_retries} attempts" - ) from last_error + """Shim — see :func:`pridepy.providers.transport._download_one_ftp_path`.""" + from pridepy.providers import transport + return transport._download_one_ftp_path( + ftp=ftp, + ftp_path=ftp_path, + local_path=local_path, + skip_if_downloaded_already=skip_if_downloaded_already, + max_download_retries=max_download_retries, + position=position, + ) @staticmethod def _download_ftp_paths_serial( @@ -2267,47 +1159,17 @@ def _download_ftp_paths_serial( max_connection_retries: int, max_download_retries: int, ) -> None: - """Download all paths from one host over a single (reused) connection.""" - connection_attempt = 0 - while connection_attempt < max_connection_retries: - try: - ftp = Files._open_ftp_connection(host, use_tls=use_tls) - logging.info(f"Connected to FTP host: {host} (tls={use_tls})") - for ftp_path in paths: - local_path = os.path.join(output_folder, os.path.basename(ftp_path)) - try: - Files._download_one_ftp_path( - ftp=ftp, - ftp_path=ftp_path, - local_path=local_path, - skip_if_downloaded_already=skip_if_downloaded_already, - max_download_retries=max_download_retries, - ) - except Exception as e: - logging.error( - f"Failed to download {ftp_path} from {host}: {e}" - ) - try: - ftp.quit() - except Exception: - try: - ftp.close() - except Exception: - pass - logging.info(f"Disconnected from FTP host: {host}") - return - except (socket.timeout, 
ftplib.error_temp, ftplib.error_perm, OSError) as e: - connection_attempt += 1 - logging.error( - f"FTP connection failed (attempt {connection_attempt}): {e}" - ) - if connection_attempt < max_connection_retries: - logging.info("Retrying connection...") - time.sleep(5) - else: - logging.error( - f"Giving up after {max_connection_retries} failed connection attempts to {host}." - ) + """Shim — see :func:`pridepy.providers.transport._download_ftp_paths_serial`.""" + from pridepy.providers import transport + return transport._download_ftp_paths_serial( + host=host, + paths=paths, + output_folder=output_folder, + skip_if_downloaded_already=skip_if_downloaded_already, + use_tls=use_tls, + max_connection_retries=max_connection_retries, + max_download_retries=max_download_retries, + ) @staticmethod def _download_ftp_paths_parallel( @@ -2320,55 +1182,18 @@ def _download_ftp_paths_parallel( max_download_retries: int, parallel_files: int, ) -> None: - """ - Download paths concurrently using ``parallel_files`` workers; each - worker opens its own FTP connection so transfers don't serialize. 
- """ - def worker(ftp_path: str, position: int) -> None: - local_path = os.path.join(output_folder, os.path.basename(ftp_path)) - if skip_if_downloaded_already and os.path.exists(local_path): - logging.info(f"Skipping download as file already exists: {local_path}") - return - connection_attempt = 0 - while connection_attempt < max_connection_retries: - try: - ftp = Files._open_ftp_connection(host, use_tls=use_tls) - try: - Files._download_one_ftp_path( - ftp=ftp, - ftp_path=ftp_path, - local_path=local_path, - skip_if_downloaded_already=False, - max_download_retries=max_download_retries, - position=position, - ) - return - finally: - try: - ftp.quit() - except Exception: - try: - ftp.close() - except Exception: - pass - except (socket.timeout, ftplib.error_temp, ftplib.error_perm, OSError) as e: - connection_attempt += 1 - logging.error( - f"FTP connection failed for {ftp_path} (attempt {connection_attempt}): {e}" - ) - if connection_attempt < max_connection_retries: - time.sleep(5) - logging.error(f"Giving up on {ftp_path} from {host}") - - with ThreadPoolExecutor(max_workers=parallel_files) as executor: - futures = [ - executor.submit(worker, path, idx) for idx, path in enumerate(paths) - ] - for future in as_completed(futures): - try: - future.result() - except Exception as e: - logging.error(f"Parallel FTP download error: {e}") + """Shim — see :func:`pridepy.providers.transport._download_ftp_paths_parallel`.""" + from pridepy.providers import transport + return transport._download_ftp_paths_parallel( + host=host, + paths=paths, + output_folder=output_folder, + skip_if_downloaded_already=skip_if_downloaded_already, + use_tls=use_tls, + max_connection_retries=max_connection_retries, + max_download_retries=max_download_retries, + parallel_files=parallel_files, + ) @staticmethod def download_ftp_urls( @@ -2380,48 +1205,17 @@ def download_ftp_urls( use_tls: bool = False, parallel_files: int = 1, ) -> None: - """ - Download a list of FTP URLs with retries, 
REST-based resume, and - optional parallel workers. - - :param use_tls: Open the FTP connection with TLS (FTP_TLS / PROT P). - Required for hosts that reject plain anonymous FTP (e.g. MassIVE). - When False but the server replies ``421 TLS is required``, the - connection is transparently retried over TLS. - :param parallel_files: When >1, downloads run concurrently with that - many worker connections per host (capped at the number of files). - """ - if not os.path.isdir(output_folder): - os.makedirs(output_folder, exist_ok=True) - - host_to_paths: Dict[str, List[str]] = {} - for url in ftp_urls: - parsed = urlparse(url) - host_to_paths.setdefault(parsed.hostname, []).append(parsed.path.lstrip("/")) - - for host, paths in host_to_paths.items(): - workers = max(1, min(parallel_files, len(paths))) - if workers > 1: - Files._download_ftp_paths_parallel( - host=host, - paths=paths, - output_folder=output_folder, - skip_if_downloaded_already=skip_if_downloaded_already, - use_tls=use_tls, - max_connection_retries=max_connection_retries, - max_download_retries=max_download_retries, - parallel_files=workers, - ) - else: - Files._download_ftp_paths_serial( - host=host, - paths=paths, - output_folder=output_folder, - skip_if_downloaded_already=skip_if_downloaded_already, - use_tls=use_tls, - max_connection_retries=max_connection_retries, - max_download_retries=max_download_retries, - ) + """Shim — see :func:`pridepy.providers.transport.download_ftp_urls`.""" + from pridepy.providers import transport + return transport.download_ftp_urls( + ftp_urls=ftp_urls, + output_folder=output_folder, + skip_if_downloaded_already=skip_if_downloaded_already, + max_connection_retries=max_connection_retries, + max_download_retries=max_download_retries, + use_tls=use_tls, + parallel_files=parallel_files, + ) @staticmethod def _http_download_one( @@ -2431,30 +1225,15 @@ def _http_download_one( max_retries: int = 3, position: int = 0, ) -> None: - """ - Download a single HTTP(S) URL with 
HEAD-then-Range resume and retry. - Used as the worker target for both the serial loop and the parallel - ThreadPoolExecutor path. Reuses :meth:`_parallel_download` so the same - resume / restart-on-non-206 behaviour is shared with globus downloads. - """ - local_path = Files._local_path_for_url(url, output_folder) - if skip_if_downloaded_already and os.path.exists(local_path): - logging.info(f"Skipping download as file already exists: {local_path}") - return - last_error: Optional[Exception] = None - for attempt in range(1, max_retries + 1): - try: - Files._parallel_download(url, local_path, position=position) - logging.info(f"Successfully downloaded {local_path}") - return - except Exception as e: - last_error = e - logging.warning( - f"HTTP download attempt {attempt}/{max_retries} failed for {url}: {e}" - ) - raise RuntimeError( - f"Giving up on {local_path} after {max_retries} HTTP attempts" - ) from last_error + """Shim — see :func:`pridepy.providers.transport._http_download_one`.""" + from pridepy.providers import transport + return transport._http_download_one( + url=url, + output_folder=output_folder, + skip_if_downloaded_already=skip_if_downloaded_already, + max_retries=max_retries, + position=position, + ) @staticmethod def download_http_urls( @@ -2464,51 +1243,12 @@ def download_http_urls( parallel_files: int = 1, max_retries: int = 3, ) -> None: - """ - Download a list of HTTP(S) URLs with HEAD-then-Range resume, per-file - retries, and an optional ``parallel_files`` worker pool. - - When ``parallel_files`` > 1, downloads run concurrently using a - :class:`ThreadPoolExecutor`. Each worker manages its own file (a new - ``requests`` session is opened inside ``_parallel_download``) so the - only shared resource is the output directory. 
- """ - if not os.path.isdir(output_folder): - os.makedirs(output_folder, exist_ok=True) - - if not http_urls: - return - - workers = max(1, min(parallel_files, len(http_urls))) - if workers > 1: - logging.info( - f"Downloading {len(http_urls)} HTTP(S) file(s) with {workers} parallel workers" - ) - with ThreadPoolExecutor(max_workers=workers) as executor: - futures = [ - executor.submit( - Files._http_download_one, - url, - output_folder, - skip_if_downloaded_already, - max_retries, - idx, - ) - for idx, url in enumerate(http_urls) - ] - for future in as_completed(futures): - try: - future.result() - except Exception as e: - logging.error(f"Parallel HTTP download error: {e}") - else: - for url in http_urls: - try: - Files._http_download_one( - url, - output_folder, - skip_if_downloaded_already, - max_retries, - ) - except Exception as e: - logging.error(f"HTTP download failed for {url}: {e}") + """Shim — see :func:`pridepy.providers.transport.download_http_urls`.""" + from pridepy.providers import transport + return transport.download_http_urls( + http_urls=http_urls, + output_folder=output_folder, + skip_if_downloaded_already=skip_if_downloaded_already, + parallel_files=parallel_files, + max_retries=max_retries, + ) diff --git a/pridepy/providers/__init__.py b/pridepy/providers/__init__.py new file mode 100644 index 0000000..ee1de17 --- /dev/null +++ b/pridepy/providers/__init__.py @@ -0,0 +1,7 @@ +"""Per-repository provider classes used by :class:`pridepy.files.files.Files`. + +Each module under this package owns the listing, transport choice, and +record-construction logic for one repository: PRIDE, MassIVE, JPOST, iProX. +The :mod:`registry` module maps an accession to the right provider; the +:mod:`transport` module hosts the shared FTP/FTPS/HTTPS download plumbing. 
+""" diff --git a/pridepy/providers/base.py b/pridepy/providers/base.py new file mode 100644 index 0000000..f9fa8bc --- /dev/null +++ b/pridepy/providers/base.py @@ -0,0 +1,107 @@ +"""Abstract base classes for pridepy providers.""" +from abc import ABC, abstractmethod +from typing import ClassVar, Dict, List, Optional + + +class Provider(ABC): + """Abstract base for every repository pridepy can list and download from.""" + + name: ClassVar[str] # "pride", "massive", "jpost", "iprox" + + @staticmethod + @abstractmethod + def matches(accession: str) -> bool: + """Return True if this provider should handle ``accession``.""" + + @abstractmethod + def list_files(self, accession: str) -> List[Dict]: + """Return pridepy file records for the dataset. + + Each record is a dict shaped like the PRIDE V3 API file response, + with at minimum: ``accession``, ``fileName``, ``fileCategory`` + (with nested ``value``), ``publicFileLocations`` (list of + ``{"name": ..., "value": <url>}``). + """ + + @abstractmethod + def download_files( + self, + accession: str, + records: List[Dict], + output_folder: str, + skip_if_downloaded_already: bool, + protocol: str, + parallel_files: int = 1, + checksum_check: bool = False, + aspera_maximum_bandwidth: str = "100M", + username: Optional[str] = None, + password: Optional[str] = None, + ) -> None: + """Download the given records into ``output_folder``.""" + + +class BaseDirectDownloadProvider(Provider): + """Shared ``download_files`` for MassIVE / JPOST / iProX. + + Subclasses set the ``use_tls`` class var (True for MassIVE FTPS, False for + JPOST plain FTP) and override :meth:`list_files`. The shared + ``download_files`` implementation partitions record URLs by scheme: + ``ftp://`` URLs are handed to :meth:`Files.download_ftp_urls`; ``http(s)://`` + URLs go to :meth:`Files.download_http_urls`. It calls **back** into + ``Files`` so that test patches on ``Files.download_ftp_urls`` / + ``Files.download_http_urls`` continue to intercept the calls.
+ """ + + use_tls: ClassVar[bool] = False + + def download_files( + self, + accession: str, + records: List[Dict], + output_folder: str, + skip_if_downloaded_already: bool, + protocol: str, + parallel_files: int = 1, + checksum_check: bool = False, + aspera_maximum_bandwidth: str = "100M", + username: Optional[str] = None, + password: Optional[str] = None, + ) -> None: + # Lazy import: providers know about Files (the facade) only via the + # public attributes that tests may patch; avoid module-load cycle. + from pridepy.files.files import Files + + if protocol not in ("ftp", "https", "http"): + import logging + logging.warning( + "Direct downloads currently use ftp / https only. " + f"Ignoring requested protocol '{protocol}' for {accession}." + ) + + all_urls = [Files._get_download_url(record, "ftp") for record in records] + ftp_urls = [u for u in all_urls if u.lower().startswith("ftp://")] + http_urls = [ + u for u in all_urls if u.lower().startswith(("http://", "https://")) + ] + if not ftp_urls and not http_urls: + import logging + logging.info( + f"No files matched for direct-download dataset {accession}" + ) + return + + if ftp_urls: + Files.download_ftp_urls( + ftp_urls=ftp_urls, + output_folder=output_folder, + skip_if_downloaded_already=skip_if_downloaded_already, + use_tls=self.use_tls, + parallel_files=parallel_files, + ) + if http_urls: + Files.download_http_urls( + http_urls=http_urls, + output_folder=output_folder, + skip_if_downloaded_already=skip_if_downloaded_already, + parallel_files=parallel_files, + ) diff --git a/pridepy/providers/iprox.py b/pridepy/providers/iprox.py new file mode 100644 index 0000000..292307c --- /dev/null +++ b/pridepy/providers/iprox.py @@ -0,0 +1,129 @@ +"""iProX direct-download provider. 
+ +iProX publishes the ProteomeXchange XML for each dataset at a +deterministic path on its anonymous HTTPS download server:: + + http://download.iprox.org/<accession>/PX_<accession>.xml + +We fetch the XML, walk every ``<DatasetFile>``'s ``cvParam`` entries, and +turn each ``Associated raw file URI`` (and sibling URIs for search-engine +output, result files, etc.) into a pridepy file record. File downloads +themselves go through plain HTTPS on the same host, which supports +``Range`` requests for resume. +""" +import logging +import os +import re +import xml.etree.ElementTree as ET +from typing import ClassVar, Dict, List, Optional +from urllib.parse import urlparse + +import requests + +from pridepy.providers import registry +from pridepy.providers.base import BaseDirectDownloadProvider +from pridepy.providers.jpost import JpostProvider + + +@registry.register +class IproxProvider(BaseDirectDownloadProvider): + name: ClassVar[str] = "iprox" + use_tls: ClassVar[bool] = False # download.iprox.org serves over plain HTTP + + DOWNLOAD_BASE_URL: ClassVar[str] = "http://download.iprox.org/" + PX_XML_URL_TEMPLATE: ClassVar[str] = ( + "http://download.iprox.org/{accession}/PX_{accession}.xml" + ) + # iProX PX XML uses the same PSI-MS cvParam "name" values as JPOST PROXI, + # so we reuse JpostProvider's category map. + PX_CATEGORY_MAP: ClassVar[Dict[str, str]] = JpostProvider.PROXI_CATEGORY_MAP + + @staticmethod + def matches(accession: str) -> bool: + """Return True when ``accession`` looks like an iProX dataset accession.""" + if not accession: + return False + return bool(re.fullmatch(r"IPX\d{7,10}", accession.upper())) + + @staticmethod + def _get_public_root(accession: str) -> str: + return f"/{accession.upper()}" + + @classmethod + def _get_public_ftp_url(cls, accession: str, remote_path: str) -> str: + # NOTE: name kept as `_get_public_ftp_url` for parity with other providers, + # but iProX URLs are http(s) not ftp. The dispatcher routes by scheme.
+ root_path = cls._get_public_root(accession).rstrip("/") + relative_path = remote_path + if remote_path.startswith(root_path): + relative_path = remote_path[len(root_path):].lstrip("/") + return f"{cls.DOWNLOAD_BASE_URL}{accession.upper()}/{relative_path}" + + @classmethod + def _build_file_record( + cls, accession: str, https_url: str, category_from_px: Optional[str] = None + ) -> Dict: + """Build a pridepy file record for an iProX file. + + ``category_from_px`` is the ``cvParam`` ``name`` from the dataset's + ProteomeXchange XML (e.g. ``"Associated raw file URI"``). + """ + from pridepy.providers.massive import MassiveProvider + parsed = urlparse(https_url) + root_prefix = f"/{accession.upper()}/" + relative_path = parsed.path + if relative_path.startswith(root_prefix): + relative_path = relative_path[len(root_prefix):] + relative_path = relative_path.lstrip("/") + collection = relative_path.split("/", 1)[0] if relative_path else "" + if category_from_px and category_from_px in cls.PX_CATEGORY_MAP: + category = cls.PX_CATEGORY_MAP[category_from_px] + else: + category = MassiveProvider._map_collection_to_category(collection) + return { + "accession": accession.upper(), + "fileName": os.path.basename(parsed.path), + "fileCategory": {"value": category}, + # "FTP Protocol" is the existing label the download dispatcher uses + # to locate a file URL; here it actually points at HTTPS. + # BaseDirectDownloadProvider.download_files routes by URL scheme. 
+ "publicFileLocations": [{"name": "FTP Protocol", "value": https_url}], + "relativePath": relative_path, + "collection": collection, + "source": "iProX", + } + + def list_files(self, accession: str) -> List[Dict]: + normalized = accession.upper() + xml_url = self.PX_XML_URL_TEMPLATE.format(accession=normalized) + logging.info(f"Fetching iProX PX XML: {xml_url}") + response = requests.get(xml_url, timeout=30) + response.raise_for_status() + try: + root = ET.fromstring(response.content) + except ET.ParseError as parse_error: + raise RuntimeError( + f"Unable to parse iProX PX XML for {normalized}: {parse_error}" + ) from parse_error + + records: List[Dict] = [] + for dataset_file in root.iter("DatasetFile"): + for cv in dataset_file.findall("cvParam"): + name = cv.attrib.get("name") + value = cv.attrib.get("value") + if not value or not name or not name.endswith("URI"): + continue + if not value.lower().startswith(("http://", "https://")): + continue + records.append( + self._build_file_record( + normalized, + value, + category_from_px=name, + ) + ) + if not records: + raise RuntimeError( + f"iProX PX XML for {normalized} contained no downloadable HTTPS URIs" + ) + return records diff --git a/pridepy/providers/jpost.py b/pridepy/providers/jpost.py new file mode 100644 index 0000000..a9ab23e --- /dev/null +++ b/pridepy/providers/jpost.py @@ -0,0 +1,150 @@ +"""JPOST direct-download provider. + +PRIMARY listing: PROXI JSON at repository.jpostdb.org. The PROXI endpoint +returns ``datasetFiles[*].value`` as ``ftp://`` URLs alongside CV labels +(Associated raw file URI, Search engine output file URI, etc.) which map +cleanly to PRIDE file categories. + +FALLBACK listing: when PROXI fails, walk the FTP tree at ftp.jpostdb.org. +This is needed because JPOST's FTP server rate-limits aggressively per +source IP (sticky 421-too-many-connections); the PROXI path lets us avoid +walking the FTP tree just for a listing. 
+""" +import logging +import os +import re +from typing import ClassVar, Dict, List, Optional +from urllib.parse import urlparse + +import requests + +from pridepy.providers import registry +from pridepy.providers.base import BaseDirectDownloadProvider + + +@registry.register +class JpostProvider(BaseDirectDownloadProvider): + name: ClassVar[str] = "jpost" + use_tls: ClassVar[bool] = False + + ARCHIVE_FTP: ClassVar[str] = "ftp.jpostdb.org" + ARCHIVE_FTP_URL_PREFIX: ClassVar[str] = "ftp://ftp.jpostdb.org/" + PROXI_BASE_URL: ClassVar[str] = "https://repository.jpostdb.org/proxi/datasets/" + + PROXI_CATEGORY_MAP: ClassVar[Dict[str, str]] = { + "Associated raw file URI": "RAW", + "Result file URI": "RESULT", + "Search engine output file URI": "SEARCH", + "Peak list file URI": "PEAK", + "Spectrum library file URI": "SPECTRUM_LIBRARY", + "Sequence database URI": "FASTA", + "Quantification file URI": "RESULT", + } + + @staticmethod + def matches(accession: str) -> bool: + if not accession: + return False + return bool(re.fullmatch(r"JPST\d{6}", accession.upper())) + + @staticmethod + def _get_public_root(accession: str) -> str: + return f"/{accession.upper()}" + + @classmethod + def _get_public_ftp_url(cls, accession: str, remote_path: str) -> str: + root_path = cls._get_public_root(accession).rstrip("/") + relative_path = remote_path + if remote_path.startswith(root_path): + relative_path = remote_path[len(root_path):].lstrip("/") + return f"{cls.ARCHIVE_FTP_URL_PREFIX}{accession.upper()}/{relative_path}" + + @classmethod + def _build_file_record( + cls, accession: str, ftp_url: str, category_from_proxi: Optional[str] = None + ) -> Dict: + """Build a pridepy file record from an FTP URL. + + When ``category_from_proxi`` is provided (e.g. ``"Associated raw file URI"``), + the PROXI CV name takes precedence over the heuristic collection-from-path + mapping. Falls back to the same path-segment heuristic used for MassIVE + when the category isn't known. 
+ """ + # Import the MassIVE collection->category map for the fallback heuristic. + from pridepy.providers.massive import MassiveProvider + parsed = urlparse(ftp_url) + root_prefix = f"/{accession.upper()}/" + relative_path = parsed.path + if relative_path.startswith(root_prefix): + relative_path = relative_path[len(root_prefix):] + relative_path = relative_path.lstrip("/") + collection = relative_path.split("/", 1)[0] if relative_path else "" + if category_from_proxi and category_from_proxi in cls.PROXI_CATEGORY_MAP: + category = cls.PROXI_CATEGORY_MAP[category_from_proxi] + else: + category = MassiveProvider._map_collection_to_category(collection) + return { + "accession": accession.upper(), + "fileName": os.path.basename(parsed.path), + "fileCategory": {"value": category}, + "publicFileLocations": [{"name": "FTP Protocol", "value": ftp_url}], + "relativePath": relative_path, + "collection": collection, + "source": "JPOST", + } + + def list_files(self, accession: str) -> List[Dict]: + """PRIMARY: PROXI JSON. FALLBACK: FTP tree walk.""" + normalized = accession.upper() + try: + return self._list_via_proxi(normalized) + except Exception as proxi_error: + logging.warning( + f"JPOST PROXI listing failed for {normalized} " + f"({proxi_error}); falling back to FTP tree walk." 
+ ) + from pridepy.providers import transport + remote_root = self._get_public_root(normalized) + remote_files = transport._list_ftp_repo_files( + host=self.ARCHIVE_FTP, + remote_root=remote_root, + error_label=f"JPOST dataset {normalized}", + ) + return [ + self._build_file_record( + normalized, + self._get_public_ftp_url(normalized, remote_file), + ) + for remote_file in remote_files + ] + + def _list_via_proxi(self, accession: str) -> List[Dict]: + """Fetch JPOST PROXI dataset metadata and turn each datasetFiles entry into a file record.""" + import json as _json + proxi_url = f"{self.PROXI_BASE_URL}{accession}" + logging.info(f"Fetching JPOST PROXI metadata: {proxi_url}") + response = requests.get( + proxi_url, + headers={"Accept": "application/json"}, + timeout=30, + ) + response.raise_for_status() + data = _json.loads(response.content) + dataset_files = data.get("datasetFiles") or [] + records: List[Dict] = [] + for entry in dataset_files: + value = (entry or {}).get("value") + if not value or not value.startswith("ftp://"): + continue + records.append( + self._build_file_record( + accession, + value, + category_from_proxi=(entry or {}).get("name"), + ) + ) + if not records: + raise RuntimeError( + f"JPOST PROXI returned no FTP file URIs for {accession}" + ) + return records diff --git a/pridepy/providers/massive.py b/pridepy/providers/massive.py new file mode 100644 index 0000000..cdc466b --- /dev/null +++ b/pridepy/providers/massive.py @@ -0,0 +1,97 @@ +"""MassIVE direct-download provider. + +Lists files by walking the FTPS tree at massive-ftp.ucsd.edu (TLS is +required by the server). Downloads files via the shared transport layer +with ``use_tls=True``. 
+""" +import os +import re +from typing import ClassVar, Dict, List +from urllib.parse import urlparse + +from pridepy.providers import registry +from pridepy.providers.base import BaseDirectDownloadProvider + + +MASSIVE_CATEGORY_MAP = { + "raw": "RAW", + "peak": "PEAK", + "ccms_peak": "PEAK", + "search": "SEARCH", + "result": "RESULT", + "ccms_result": "RESULT", + "quant": "RESULT", + "fasta": "FASTA", + "spectrum_library": "SPECTRUM_LIBRARY", + "library": "SPECTRUM_LIBRARY", +} + + +@registry.register +class MassiveProvider(BaseDirectDownloadProvider): + name: ClassVar[str] = "massive" + use_tls: ClassVar[bool] = True + + ARCHIVE_FTP: ClassVar[str] = "massive-ftp.ucsd.edu" + ARCHIVE_FTP_URL_PREFIX: ClassVar[str] = "ftp://massive-ftp.ucsd.edu/v01/" + + @staticmethod + def matches(accession: str) -> bool: + """Return True when ``accession`` is a MassIVE dataset accession.""" + if not accession: + return False + return bool(re.fullmatch(r"R?MSV\d{9}", accession.upper())) + + @staticmethod + def _get_public_root(accession: str) -> str: + return f"/v01/{accession.upper()}" + + @classmethod + def _get_public_ftp_url(cls, accession: str, remote_path: str) -> str: + root_path = cls._get_public_root(accession).rstrip("/") + relative_path = remote_path + if remote_path.startswith(root_path): + relative_path = remote_path[len(root_path):].lstrip("/") + return f"{cls.ARCHIVE_FTP_URL_PREFIX}{accession.upper()}/{relative_path}" + + @staticmethod + def _map_collection_to_category(collection: str) -> str: + return MASSIVE_CATEGORY_MAP.get(collection.lower(), "OTHER") + + @classmethod + def _build_file_record(cls, accession: str, ftp_url: str) -> Dict: + """Build a pridepy file record from an FTP URL inside the dataset.""" + parsed = urlparse(ftp_url) + root_prefix = f"/v01/{accession.upper()}/" + relative_path = parsed.path + if relative_path.startswith(root_prefix): + relative_path = relative_path[len(root_prefix):] + relative_path = relative_path.lstrip("/") + collection = 
relative_path.split("/", 1)[0] if relative_path else "" + return { + "accession": accession.upper(), + "fileName": os.path.basename(parsed.path), + "fileCategory": {"value": cls._map_collection_to_category(collection)}, + "publicFileLocations": [{"name": "FTP Protocol", "value": ftp_url}], + "relativePath": relative_path, + "collection": collection, + "source": "MassIVE", + } + + def list_files(self, accession: str) -> List[Dict]: + from pridepy.providers import transport + normalized = accession.upper() + remote_root = self._get_public_root(normalized) + remote_files = transport._list_ftp_repo_files( + host=self.ARCHIVE_FTP, + remote_root=remote_root, + error_label=f"MassIVE dataset {normalized}", + use_tls=True, + ) + return [ + self._build_file_record( + normalized, + self._get_public_ftp_url(normalized, remote_file), + ) + for remote_file in remote_files + ] diff --git a/pridepy/providers/pride.py b/pridepy/providers/pride.py new file mode 100644 index 0000000..c8c40ca --- /dev/null +++ b/pridepy/providers/pride.py @@ -0,0 +1,815 @@ +"""PRIDE Archive provider. + +PRIDE has the richest behaviour of all providers: multi-protocol batch +download with aspera/s3/ftp/globus fallback, private-dataset path with +username/password auth, checksum TSV validation, and submitter-path +helpers. This module hosts all of those; the :class:`Files` facade +delegates via lightweight shim methods. + +Implementation note: PRIDE-specific helpers that the existing test suite +patches via ``patch.object(Files, "X")`` are called from inside this +provider via ``Files.X(...)`` (lazy import) — never ``self.X`` — so the +patches keep intercepting. This is a deliberate backward-compat choice +documented in the refactor plan (Task 8). 
import ftplib
import importlib.resources
import logging
import os
import platform
import re
import socket
import subprocess
import time
import urllib
import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed
from ftplib import FTP
from typing import ClassVar, Dict, List, Optional
from urllib.parse import urlparse

import boto3
import botocore
import requests
from botocore.config import Config
from tqdm import tqdm

from pridepy.authentication.authentication import Authentication
from pridepy.providers import registry
from pridepy.providers.base import Provider
from pridepy.providers.util import Progress
from pridepy.util.api_handling import Util


@registry.register
class PrideProvider(Provider):
    """PRIDE Archive provider with multi-protocol fallback orchestration.

    Many helpers are reached through ``pridepy.files.files.Files`` (imported
    lazily inside each method) rather than through ``self``/``PrideProvider``,
    so that patches applied to ``Files`` attributes keep intercepting calls.
    """

    name: ClassVar[str] = "pride"

    V3_API_BASE_URL: ClassVar[str] = "https://www.ebi.ac.uk/pride/ws/archive/v3"
    API_BASE_URL: ClassVar[str] = "https://www.ebi.ac.uk/pride/ws/archive/v3"
    API_PRIVATE_URL: ClassVar[str] = "https://www.ebi.ac.uk/pride/private/ws/archive/v2"
    ARCHIVE_FTP: ClassVar[str] = "ftp.pride.ebi.ac.uk"
    ARCHIVE_FTP_URL_PREFIX: ClassVar[str] = "ftp://ftp.pride.ebi.ac.uk/"
    ARCHIVE_HTTPS_URL_PREFIX: ClassVar[str] = "https://ftp.pride.ebi.ac.uk/"
    S3_URL: ClassVar[str] = "https://hh.fire.sdo.ebi.ac.uk"
    S3_BUCKET: ClassVar[str] = "pride-public"
    # Fallback order used by _protocol_sequence once the requested protocol
    # has been moved to the front.
    PROTOCOL_ORDER: ClassVar[List[str]] = ["aspera", "s3", "ftp", "globus"]

    @staticmethod
    def matches(accession: str) -> bool:
        """Return True when ``accession`` is a PRIDE dataset accession.

        Accepts ``PXD`` or ``PRD`` followed by one or more digits,
        case-insensitively. Empty or ``None`` input returns False.
        """
        if not accession:
            return False
        return bool(re.fullmatch(r"(?:PXD|PRD)\d+", accession.upper()))

    # ------------------------------------------------------------------
    # Listing
    # ------------------------------------------------------------------

    async def stream_all_files_metadata(self, output_file, accession=None):
        """Stream file metadata from the PRIDE API into ``output_file``.

        :param output_file: destination handed to ``Util.stream_response_to_file``.
        :param accession: when given, only that project's files are streamed;
            otherwise the archive-wide ``/files/all`` endpoint is used.
        """
        if accession is None:
            request_url = f"{self.V3_API_BASE_URL}/files/all"
            count_request_url = f"{self.V3_API_BASE_URL}/files/count"
        else:
            request_url = f"{self.V3_API_BASE_URL}/projects/{accession}/files/all"
            count_request_url = f"{self.V3_API_BASE_URL}/projects/{accession}/files/count"
        headers = {"Accept": "application/JSON"}
        # The count endpoint's JSON body is passed on as the expected total.
        response = Util.get_api_call(count_request_url, headers)
        total_records = response.json()

        regex_search_pattern = '"fileName"'
        await Util.stream_response_to_file(
            output_file, total_records, regex_search_pattern, request_url, headers
        )

    def stream_all_files_by_project(self, accession) -> List[Dict]:
        """Return all file records of one project, read from the PRIDE API.

        :param accession: PRIDE project accession.
        :return: whatever ``Util.read_json_stream`` yields for
            ``/projects/{accession}/files/all``.
        """
        request_url = f"{self.V3_API_BASE_URL}/projects/{accession}/files/all"
        headers = {"Accept": "application/JSON"}
        record_files = Util.read_json_stream(api_url=request_url, headers=headers)
        return record_files

    def list_files(self, accession: str) -> List[Dict]:
        """Return PRIDE file records for the dataset."""
        return self.stream_all_files_by_project(accession)

    def get_submitted_file_path_prefix(self, accession):
        """Extract the ``yyyy/mm/accession`` path fragment of a public dataset.

        At pride repository, public data is disseminated according to a
        structure of the form ``base/path/ + yyyy/mm/accession/ + submitted/``.
        The fragment is taken from the first public location of the first raw
        file, e.g.
        ``ftp://ftp.pride.ebi.ac.uk/pride/data/archive/2018/10/PXD008644/7550GI_Y.raw``
        yields ``2018/10/PXD008644``.

        :param accession: PRIDE accession
        :return: path fragment (eg: 2018/10/PXD008644)
        :raises ValueError: when the project has no raw files or the first
            file location does not contain such a fragment.
        """
        # Use Files facade so test patches on get_all_raw_file_list keep working.
        from pridepy.files.files import Files

        results = Files().get_all_raw_file_list(accession)
        if not results:
            raise ValueError(f"No raw files found for accession {accession}")
        first_file = results[0]["publicFileLocations"][0]["value"]
        # matches() accepts both PXD and PRD accessions, so both prefixes are
        # recognised here as well.
        match = re.search(r"\d{4}/\d{2}/(?:PXD|PRD)\d+", first_file)
        if match is None:
            raise ValueError(
                f"Could not extract a yyyy/mm/accession path fragment from {first_file}"
            )
        return match.group()

    # ------------------------------------------------------------------
    # Static utilities
    # ------------------------------------------------------------------

    @staticmethod
    def _protocol_sequence(protocol: str) -> List[str]:
        """Build the ordered list of protocols to try for a requested download mode.

        :return: ``protocol`` first, then the remaining ``PROTOCOL_ORDER``
            entries in their declared order; an empty list when ``protocol``
            is not in ``PROTOCOL_ORDER``.
        """
        if protocol not in PrideProvider.PROTOCOL_ORDER:
            return []
        return [protocol] + [p for p in PrideProvider.PROTOCOL_ORDER if p != protocol]

    @staticmethod
    def get_ascp_binary():
        """Detect the OS and architecture, and return the appropriate ascp binary path.

        Returns:
            str: Path to the correct ascp binary.

        Raises:
            OSError: for an unsupported OS, or a supported OS with an
                architecture other than ``32bit``/``64bit``.
        """
        os_type = platform.system().lower()
        arch, _ = platform.architecture()
        aspera_dir = importlib.resources.files("pridepy").joinpath("aspera/")

        if os_type == "linux":
            if arch == "32bit":
                return os.path.join(aspera_dir, "linux-32", "ascp")
            if arch == "64bit":
                return os.path.join(aspera_dir, "linux-64", "ascp")
        elif os_type == "darwin":  # macOS (intel-based)
            return os.path.join(aspera_dir, "mac-intel", "ascp")
        elif os_type == "windows":
            if arch == "32bit":
                return os.path.join(aspera_dir, "windows-32", "ascp.exe")
            if arch == "64bit":
                return os.path.join(aspera_dir, "windows-64", "ascp.exe")
        # Reached for an unknown OS and also for linux/windows with an
        # unrecognised architecture, which previously fell through and
        # returned None.
        raise OSError(f"Unsupported OS or architecture: {os_type}, {arch}")

    @staticmethod
    def save_checksum_file(accession, output_folder):
        """Download and persist the checksum manifest for a PRIDE accession.

        :param accession: PRIDE project accession.
        :param output_folder: folder (created if missing) that receives
            ``{accession}-checksum.tsv``.
        :return: path of the written TSV file.
        """
        os.makedirs(output_folder, exist_ok=True)
        url = f"{PrideProvider.V3_API_BASE_URL}/files/checksum/{accession}"
        headers = {"accept": "text/plain"}
        request = urllib.request.Request(url, headers=headers, method="GET")
        logging.info(f"Fetching checksum file from {url}")
        with urllib.request.urlopen(request) as response:
            data = response.read().decode("utf-8")
        # Save the data to a .tsv file
        output_path = os.path.join(output_folder, f"{accession}-checksum.tsv")
        with open(output_path, "w", encoding="utf-8") as file:
            file.write(data)
        return output_path

    # ------------------------------------------------------------------
    # Per-protocol single-file workers
    # ------------------------------------------------------------------

    @staticmethod
    def _globus_download_one(file, output_folder, skip_if_downloaded_already, max_retries=6, position=0):
        """Download a single file via globus; used as a worker target.

        :param file: file record.
        :param output_folder: destination folder.
        :param skip_if_downloaded_already: return early when the target exists.
        :param max_retries: attempts before the last exception is re-raised.
        :param position: forwarded to ``Files._parallel_download``.
        """
        # Use Files facade so test patches on Files helpers keep working.
        from pridepy.files.files import Files

        download_url = Files._get_download_url(file, "globus")
        new_file_path = Files.get_output_file_name(download_url, file, output_folder)

        if skip_if_downloaded_already and os.path.exists(new_file_path):
            logging.info(f"Skipping download as file already exists: {new_file_path}")
            return

        for attempt in range(1, max_retries + 1):
            try:
                Files._parallel_download(download_url, new_file_path, position=position)
                return
            except Exception as e:
                logging.warning(f"Attempt {attempt}/{max_retries} failed for {file.get('fileName', '?')}: {e}")
                if attempt == max_retries:
                    raise

    # ------------------------------------------------------------------
    # Per-protocol batch helpers
    # ------------------------------------------------------------------

    @staticmethod
    def download_files_from_ftp(
        file_list_json,
        output_folder,
        skip_if_downloaded_already,
        max_connection_retries=3,
        max_download_retries=3,
    ):
        """
        Download files using a single FTP connection with a retry mechanism and a progress bar for each file.

        Per-file failures are logged and do not abort the batch.

        :param file_list_json: file list in JSON format
        :param output_folder: folder to download the files
        :param skip_if_downloaded_already: Boolean value to skip the download if the file has already been downloaded.
        :param max_connection_retries: Number of attempts to reconnect to the FTP server if the connection is lost.
        :param max_download_retries: Number of attempts to retry the download of a file in case of failure.
        """
        from pridepy.files.files import Files

        os.makedirs(output_folder, exist_ok=True)

        def connect_ftp():
            """Helper function to establish FTP connection."""
            ftp = FTP(PrideProvider.ARCHIVE_FTP, timeout=30)
            ftp.login()  # Anonymous login
            ftp.set_pasv(True)  # Enable passive mode
            logging.info(f"Connected to FTP host: {PrideProvider.ARCHIVE_FTP}")
            return ftp

        connection_attempt = 0
        while connection_attempt < max_connection_retries:
            try:
                ftp = connect_ftp()
                for file in file_list_json:
                    try:
                        # Get FTP download URL
                        if file["publicFileLocations"][0]["name"] == "FTP Protocol":
                            download_url = file["publicFileLocations"][0]["value"]
                        else:
                            download_url = file["publicFileLocations"][1]["value"]

                        logging.debug("ftp_filepath:" + download_url)

                        # Get output file path
                        new_file_path = Files.get_output_file_name(
                            download_url, file, output_folder
                        )

                        if skip_if_downloaded_already and os.path.exists(new_file_path):
                            logging.info("Skipping download as file already exists")
                            continue

                        # Extract file path from the download URL
                        parsed_url = urlparse(download_url)
                        ftp_file_path = urllib.parse.unquote(parsed_url.path.lstrip("/"))

                        logging.info(f"Starting FTP download: {ftp_file_path}")

                        # Retry download in case of failure
                        download_attempt = 0
                        while download_attempt < max_download_retries:
                            try:
                                # Get file size for progress tracking
                                total_size = ftp.size(ftp_file_path)
                                logging.info(f"File size: {total_size} bytes")

                                # Initialize progress bar
                                with open(new_file_path, "wb") as f:
                                    with tqdm(
                                        total=total_size,
                                        unit="B",
                                        unit_scale=True,
                                        desc=new_file_path,
                                    ) as pbar:

                                        def callback(data):
                                            f.write(data)
                                            pbar.update(len(data))

                                        # Retrieve the file with progress callback
                                        ftp.retrbinary(f"RETR {ftp_file_path}", callback)

                                logging.info(f"Successfully downloaded {new_file_path}")
                                break  # Exit download retry loop if successful
                            except (
                                socket.timeout,
                                ftplib.error_temp,
                                ftplib.error_perm,
                            ) as e:
                                download_attempt += 1
                                logging.error(
                                    f"Download failed for {new_file_path} (attempt {download_attempt}): {str(e)}"
                                )
                                if download_attempt >= max_download_retries:
                                    logging.error(
                                        f"Giving up on {new_file_path} after {max_download_retries} attempts."
                                    )
                                    break  # Give up on this file after max retries
                    except (KeyError, IndexError) as e:
                        logging.error(f"Failed to process file due to missing data: {str(e)}")
                    except Exception as e:
                        logging.error(f"Unexpected error while processing file: {str(e)}")
                # Close FTP connection after all files are downloaded. A failing
                # QUIT must not be mistaken for a connection failure: that would
                # re-enter the retry loop and re-run the whole batch (or let
                # EOFError escape) after the transfers already finished.
                try:
                    ftp.quit()
                except ftplib.all_errors:
                    ftp.close()
                logging.info(f"Disconnected from FTP host: {PrideProvider.ARCHIVE_FTP}")
                break  # Exit connection retry loop if everything was successful
            except (
                socket.timeout,
                ftplib.error_temp,
                ftplib.error_perm,
                socket.error,
            ) as e:
                connection_attempt += 1
                logging.error(f"FTP connection failed (attempt {connection_attempt}): {str(e)}")
                if connection_attempt < max_connection_retries:
                    logging.info("Retrying connection...")
                    time.sleep(5)  # Optional delay before retrying
                else:
                    logging.error(
                        f"Giving up after {max_connection_retries} failed connection attempts."
                    )
                    break

    @staticmethod
    def download_files_from_globus(
        file_list_json: List[Dict], output_folder, skip_if_downloaded_already,
        parallel_files: int = 1,
        checksum_map: Optional[Dict[str, str]] = None,
    ):
        """
        Download files using globus transfer url with progress bar for each file.
        When skip_if_downloaded_already is True, files are pre-filtered so that
        only missing or incomplete files are submitted to the worker pool,
        ensuring the -w parallel_files parameter is fully utilised.
        When checksum_map is provided, existing files are validated against
        their expected checksum; corrupted files are re-downloaded.
        Per-file failures are logged, not raised.

        :param file_list_json: file list in json format
        :param output_folder: folder to download the files
        :param skip_if_downloaded_already: Boolean value to skip the download if the file has already been downloaded.
        :param parallel_files: number of files to download simultaneously (capped at 3)
        :param checksum_map: mapping of file name to expected MD5 checksum
        """
        # Use Files facade so test patches on Files._globus_download_one etc. keep working.
        from pridepy.files.files import Files

        if checksum_map is None:
            checksum_map = {}

        if not (os.path.isdir(output_folder)):
            os.makedirs(output_folder, exist_ok=True)

        # --- Phase 0: pre-filter files that need downloading -----------------
        files_to_download: List[Dict] = []
        for file in file_list_json:
            download_url = Files._get_download_url(file, "globus")
            new_file_path = Files.get_output_file_name(download_url, file, output_folder)
            if skip_if_downloaded_already and os.path.exists(new_file_path):
                expected_cs = checksum_map.get(file.get("fileName", ""))
                if expected_cs:
                    valid, reason = Files.validate_download(new_file_path, expected_cs)
                    if not valid:
                        logging.warning(f"Corrupted file detected ({reason}), will re-download: {new_file_path}")
                        files_to_download.append(file)
                        continue
                logging.info(f"Skipping download as file already exists: {new_file_path}")
                continue
            files_to_download.append(file)

        if not files_to_download:
            logging.info("All files already downloaded, nothing to do.")
            return

        logging.info(
            f"{len(file_list_json) - len(files_to_download)} file(s) skipped, "
            f"{len(files_to_download)} file(s) to download"
        )

        # --- Phase 1: download (skip check already done, pass False) ---------
        parallel_files = min(parallel_files, 3, len(files_to_download))
        if parallel_files < 2:
            for file in files_to_download:
                try:
                    Files._globus_download_one(
                        file, output_folder, False
                    )
                    new_file_path = Files.get_output_file_name(
                        Files._get_download_url(file, "globus"), file, output_folder
                    )
                    logging.info(f"Successfully downloaded {new_file_path}")
                except Exception as e:
                    logging.error(f"Download from Globus failed: {str(e)}")
        else:
            logging.info(f"Downloading {len(files_to_download)} file(s) with {parallel_files} parallel workers")
            with ThreadPoolExecutor(max_workers=parallel_files) as executor:
                futures = {
                    executor.submit(
                        Files._globus_download_one,
                        file, output_folder, False,
                        position=idx,
                    ): file
                    for idx, file in enumerate(files_to_download)
                }
                for future in as_completed(futures):
                    try:
                        future.result()
                    except Exception as e:
                        logging.error(f"Download from Globus failed: {str(e)}")

    @staticmethod
    def download_files_from_s3(
        file_list_json: List[Dict], output_folder: str, skip_if_downloaded_already
    ):
        """
        Download files using S3 transfer URL with a progress bar and retry logic.

        Per-file failures are logged, not raised.

        :param file_list_json: file list in JSON format
        :param output_folder: folder to download the files
        :param skip_if_downloaded_already: Boolean value to skip the download if the file has already been downloaded.
        """
        from pridepy.files.files import Files

        if not os.path.isdir(output_folder):
            os.makedirs(output_folder, exist_ok=True)

        # Retry and timeout config
        retry_config = Config(
            retries={"max_attempts": 5, "mode": "standard"},
            connect_timeout=120,  # Increase timeout to 120 seconds
            read_timeout=120,  # Timeout for reading data
            signature_version=botocore.UNSIGNED,  # Unsigned requests for public data
        )

        s3_resource = boto3.resource(
            "s3",
            config=retry_config,
            endpoint_url=PrideProvider.S3_URL,
        )
        bucket = s3_resource.Bucket(PrideProvider.S3_BUCKET)

        for file in file_list_json:
            try:
                # The S3 key is derived from the record's FTP location.
                download_url = (
                    file["publicFileLocations"][0]["value"]
                    if file["publicFileLocations"][0]["name"] == "FTP Protocol"
                    else file["publicFileLocations"][1]["value"]
                )

                ftp_base_url = "ftp://ftp.pride.ebi.ac.uk/pride/data/archive/"
                s3_path = download_url.replace(ftp_base_url, "")
                new_file_path = Files.get_output_file_name(download_url, file, output_folder)

                if skip_if_downloaded_already and os.path.exists(new_file_path):
                    logging.info("Skipping download as file already exists")
                    continue

                logging.debug(f"Downloading From S3: {s3_path}")

                # Get file size for progress tracking
                obj = bucket.Object(s3_path)
                total_size = obj.content_length

                # Initialize progress bar
                progress = Progress(total_size, new_file_path)

                # Download with progress bar and retry handling. The bar is
                # closed in ``finally`` so the 404 and give-up paths no longer
                # leave it open.
                try:
                    for attempt in range(5):
                        try:
                            bucket.download_file(s3_path, new_file_path, Callback=progress)
                            logging.info(f"Successfully downloaded {new_file_path}")
                            break
                        except botocore.exceptions.ClientError as e:
                            if e.response["Error"]["Code"] == "404":
                                logging.error("The object does not exist.")
                                break
                            else:
                                logging.error(f"Download failed: {e}")
                                if attempt < 4:
                                    time.sleep(2**attempt)  # Exponential backoff
                                    logging.info(f"Retrying... ({attempt + 1}/5)")
                                else:
                                    raise
                finally:
                    progress.close()
            except Exception as e:
                logging.error(f"Failed to download {file['fileName']}: {e}")

    # ------------------------------------------------------------------
    # Private dataset download
    # ------------------------------------------------------------------

    def download_private_file_name(self, accession, file_name, output_folder, username, password):
        """
        Look up one private file through the API and download it.

        If the target file already exists its size is sent as a ``Range``
        offset and the response is appended to it.

        :param accession: Project accession
        :param file_name: The file name to be downloaded
        :param output_folder: Folder the file is written into
        :param username: Username with access to the dataset
        :param password: Password for user with access to the dataset
        :raises Exception: when the file lookup request is not successful.
        """

        auth = Authentication()
        auth_token = auth.get_token(username, password)
        validate_token = auth.validate_token(auth_token)
        logging.info("Valid token after login: {}".format(validate_token))

        url = self.API_PRIVATE_URL + "/projects/{}/files?search={}".format(accession, file_name)
        content = requests.get(url, headers={"Authorization": "Bearer {}".format(auth_token)})
        if content.ok and content.status_code == 200:
            json_file = content.json()
            if (
                "_embedded" in json_file
                and "files" in json_file["_embedded"]
                and len(json_file["_embedded"]["files"]) == 1
            ):
                download_url = json_file["_embedded"]["files"][0]["_links"]["download"]["href"]
                logging.info(download_url)

                # Create a clean filename to save the downloaded file
                new_file_path = os.path.join(output_folder, f"{file_name}")

                session = Util.create_session_with_retries()  # Create session with retries
                # Check if the file already exists
                if os.path.exists(new_file_path):
                    resume_header = {"Range": f"bytes={os.path.getsize(new_file_path)}-"}
                    mode = "ab"  # Append to file
                    resume_size = os.path.getsize(new_file_path)
                else:
                    resume_header = {}
                    mode = "wb"  # Write new file
                    resume_size = 0

                with session.get(
                    download_url, stream=True, headers=resume_header, timeout=(10, 60)
                ) as r:
                    r.raise_for_status()
                    total_size = int(r.headers.get("content-length", 0)) + resume_size
                    block_size = 1024 * 1024  # 1 MB chunks

                    with tqdm(
                        total=total_size,
                        unit="B",
                        unit_scale=True,
                        desc=new_file_path,
                        initial=resume_size,
                    ) as pbar:
                        with open(new_file_path, mode) as f:
                            for chunk in r.iter_content(chunk_size=block_size):
                                if chunk:
                                    f.write(chunk)
                                    pbar.update(len(chunk))

                logging.info(f"Successfully downloaded {new_file_path}")

            else:
                # Reached when the search returned zero or several files, or
                # no "_embedded"/"files" section at all.
                logging.info(
                    "File name {} did not match exactly one file in the given project {}".format(
                        file_name, accession
                    )
                )
        else:
            logging.info(
                f"File name {file_name} not found in the project {accession}, or user doesn't have access"
            )
            raise Exception(
                f"File name {file_name} not found in the project {accession}, or user doesn't have access"
            )

    # ------------------------------------------------------------------
    # Multi-protocol orchestrator
    # ------------------------------------------------------------------

    @staticmethod
    def _batch_download_by_protocol(
        file_list: List[Dict],
        output_folder: str,
        protocol: str,
        skip_if_downloaded_already: bool,
        aspera_maximum_bandwidth: str,
        parallel_files: int = 1,
        checksum_map: Optional[Dict[str, str]] = None,
    ) -> None:
        """
        Transfer a batch of files with one protocol, reusing a single
        connection where the underlying helper supports it (FTP, S3).

        ``aspera_maximum_bandwidth`` is only forwarded for ``aspera``;
        ``parallel_files`` and ``checksum_map`` only for ``globus``.

        :raises ValueError: for a protocol other than ftp/aspera/globus/s3.
        """
        # Use Files facade so test patches on each per-protocol helper keep working.
        from pridepy.files.files import Files

        if not file_list:
            return
        if protocol == "ftp":
            Files.download_files_from_ftp(
                file_list,
                output_folder,
                skip_if_downloaded_already=skip_if_downloaded_already,
            )
            return
        if protocol == "aspera":
            Files.download_files_from_aspera(
                file_list,
                output_folder,
                skip_if_downloaded_already=skip_if_downloaded_already,
                maximum_bandwidth=aspera_maximum_bandwidth,
            )
            return
        if protocol == "globus":
            Files.download_files_from_globus(
                file_list,
                output_folder,
                skip_if_downloaded_already=skip_if_downloaded_already,
                parallel_files=parallel_files,
                checksum_map=checksum_map or {},
            )
            return
        if protocol == "s3":
            Files.download_files_from_s3(
                file_list,
                output_folder,
                skip_if_downloaded_already=skip_if_downloaded_already,
            )
            return
        raise ValueError(f"Unsupported protocol: {protocol}")

    @staticmethod
    def _download_with_fallback(
        file_record: Dict,
        output_folder: str,
        protocol_sequence: List[str],
        expected_checksum: Optional[str],
        aspera_maximum_bandwidth: str,
        max_protocol_retries: int = 2,
        parallel_files: int = 1,
    ) -> bool:
        """
        Download one file by trying each protocol in sequence, validating
        after every attempt. Intended as the per-file fallback path; batch
        download of the primary protocol is handled separately.

        :return: True as soon as one attempt validates, False once every
            protocol has been tried ``max_protocol_retries`` times.
        """
        # Patch-sensitive: call through Files so test patches intercept.
        from pridepy.files.files import Files

        local_path = Files._resolve_local_path(file_record, output_folder)

        for protocol in protocol_sequence:
            for attempt in range(1, max_protocol_retries + 1):
                logging.info(
                    f"Downloading {file_record['fileName']} via {protocol} "
                    f"(attempt {attempt}/{max_protocol_retries})"
                )
                try:
                    # Start every attempt from a clean slate.
                    Files._remove_if_exists(local_path)
                    Files._batch_download_by_protocol(
                        [file_record],
                        output_folder,
                        protocol,
                        skip_if_downloaded_already=False,
                        aspera_maximum_bandwidth=aspera_maximum_bandwidth,
                        parallel_files=parallel_files,
                    )
                except Exception as error:
                    logging.error(
                        f"Protocol {protocol} failed for {file_record['fileName']}: {error}"
                    )

                # Validation runs even after an exception: the outcome is
                # decided by the file on disk, not by the helper's return.
                valid, reason = Files.validate_download(local_path, expected_checksum)
                if valid:
                    logging.info(
                        f"File {file_record['fileName']} downloaded successfully via {protocol}"
                    )
                    return True

                logging.warning(
                    f"Validation failed for {file_record['fileName']} via {protocol}: {reason}"
                )
                Files._remove_if_exists(local_path)

            logging.warning(
                f"Protocol {protocol} exhausted for {file_record['fileName']}, switching protocol."
            )

        logging.error(f"All protocol attempts failed for {file_record['fileName']}")
        return False

    def download_files(
        self,
        accession,
        records: List[Dict],
        output_folder: str,
        skip_if_downloaded_already,
        protocol: str = "ftp",
        aspera_maximum_bandwidth: str = "100M",
        checksum_check: bool = False,
        parallel_files: int = 1,
        username: Optional[str] = None,
        password: Optional[str] = None,
    ):
        """Implement Provider.download_files — maps to the legacy static batch downloader.

        NOTE(review): ``username`` and ``password`` are accepted but not
        forwarded anywhere by this method — confirm that is intended.
        """
        PrideProvider._download_files_batch(
            file_list_json=records,
            accession=accession,
            output_folder=output_folder,
            skip_if_downloaded_already=skip_if_downloaded_already,
            protocol=protocol,
            aspera_maximum_bandwidth=aspera_maximum_bandwidth,
            checksum_check=checksum_check,
            parallel_files=parallel_files,
        )

    @staticmethod
    def _download_files_batch(
        file_list_json: List[Dict],
        accession,
        output_folder: str,
        skip_if_downloaded_already,
        protocol: str = "ftp",
        aspera_maximum_bandwidth: str = "100M",  # Aspera maximum bandwidth
        checksum_check=False,
        parallel_files: int = 1,
    ):
        """
        Download files with the requested protocol, then validate each file
        and fall back per-file through the remaining protocols.

        :param file_list_json: File list in JSON format
        :param accession: Project accession
        :param output_folder: Folder to download the files
        :param skip_if_downloaded_already: Boolean value to skip the download if the file has already been downloaded.
        :param protocol: ftp, aspera, globus, s3 (anything else is logged and ignored)
        :param aspera_maximum_bandwidth: parameter in Aspera sets the maximum bandwidth for the transfer.
        :param checksum_check: when True, the checksum manifest is downloaded and used for validation.
        :param parallel_files: forwarded to the per-protocol helpers.
        :raises RuntimeError: listing the files that could not be downloaded.
        """
        # Patch-sensitive: call _batch_download_by_protocol and
        # _download_with_fallback through Files so test patches intercept.
        from pridepy.files.files import Files

        protocols_supported = ["ftp", "aspera", "globus", "s3"]
        if protocol not in protocols_supported:
            logging.error("Protocol should be one of ftp, aspera, globus, s3")
            return

        os.makedirs(output_folder, exist_ok=True)

        checksum_map: Dict[str, str] = {}
        if checksum_check:
            checksum_file_path = Files.save_checksum_file(accession, output_folder)
            checksum_map = Files.read_checksum_file(checksum_file_path)
            logging.info(f"Loaded checksums for {len(checksum_map)} files")

        if not file_list_json:
            return

        protocol_sequence = Files._protocol_sequence(protocol)
        primary_protocol = protocol_sequence[0]
        # Retry with the primary protocol first, then fall back to others
        fallback_sequence = protocol_sequence

        # Phase 1: batch download with the requested protocol. Reuses a single
        # FTP/S3 connection for all files (the previous behaviour) instead of
        # paying the per-file reconnect cost in the common happy path.
        logging.info(
            f"Downloading {len(file_list_json)} file(s) via {primary_protocol} (batch)"
        )
        try:
            Files._batch_download_by_protocol(
                file_list_json,
                output_folder,
                primary_protocol,
                skip_if_downloaded_already=skip_if_downloaded_already,
                aspera_maximum_bandwidth=aspera_maximum_bandwidth,
                parallel_files=parallel_files,
                checksum_map=checksum_map,
            )
        except Exception as exc:
            logging.warning(
                f"Batch {primary_protocol} run hit an error; will retry individual failures: {exc}"
            )

        # Phase 2: validate every file and fall back per-file for the ones
        # that are missing or invalid.
        logging.info("Phase 2: validating %d downloaded file(s)", len(file_list_json))
        failed_files: List[str] = []
        for i, file_record in enumerate(file_list_json, 1):
            expected_checksum = checksum_map.get(file_record["fileName"])
            local_path = Files._resolve_local_path(file_record, output_folder)
            logging.info("Validating [%d/%d] %s", i, len(file_list_json), file_record["fileName"])
            valid, reason = Files.validate_download(local_path, expected_checksum)
            if valid:
                continue

            logging.warning(
                f"{file_record['fileName']} invalid after {primary_protocol} ({reason})"
            )
            # NOTE(review): assumes ``reason`` is a string whenever validation
            # fails — confirm against Files.validate_download.
            if "checksum mismatch" in reason:
                Files._remove_if_exists(local_path)

            if not fallback_sequence:
                failed_files.append(file_record.get("fileName", ""))
                continue

            success = Files._download_with_fallback(
                file_record=file_record,
                output_folder=output_folder,
                protocol_sequence=fallback_sequence,
                expected_checksum=expected_checksum,
                aspera_maximum_bandwidth=aspera_maximum_bandwidth,
                parallel_files=parallel_files,
            )
            if not success:
                failed_files.append(file_record.get("fileName", ""))

        if failed_files:
            failed_summary = ", ".join(failed_files)
            logging.error(f"Failed to download {len(failed_files)} file(s): {failed_summary}")
            raise RuntimeError(f"Failed to download {len(failed_files)} file(s): {failed_summary}")
from typing import List, Type

from pridepy.providers.base import Provider

# Lookup order is registration order: resolve() and is_known() scan this list
# front to back and stop at the first class whose matches() returns true.
_PROVIDERS: List[Type[Provider]] = []  # populated by individual provider modules


def register(provider_cls: Type[Provider]) -> Type[Provider]:
    """Add ``provider_cls`` to the registry once and hand it back.

    Returning the class unchanged lets this be applied as a class decorator;
    registering the same class again is a no-op.
    """
    already_registered = provider_cls in _PROVIDERS
    if not already_registered:
        _PROVIDERS.append(provider_cls)
    return provider_cls


def resolve(accession: str) -> Provider:
    """Instantiate the first registered provider that matches ``accession``.

    :raises ValueError: when no registered provider matches.
    """
    matching_cls = next(
        (candidate for candidate in _PROVIDERS if candidate.matches(accession)),
        None,
    )
    if matching_cls is None:
        raise ValueError(f"No provider registered for accession {accession!r}")
    return matching_cls()


def is_known(accession: str) -> bool:
    """Return True if any registered provider matches ``accession``."""
    return any(candidate.matches(accession) for candidate in _PROVIDERS)
+""" +import ftplib +import logging +import os +import socket +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from ftplib import FTP +from typing import Dict, List, Optional +from urllib.parse import urlparse + +import requests +from tqdm import tqdm + +from pridepy.util.api_handling import Util + + +def _local_path_for_url(download_url: str, output_folder: str) -> str: + filename = os.path.basename(urlparse(download_url).path) + return os.path.join(output_folder, filename) + + +def _open_ftp_connection(host: str, use_tls: bool, timeout: int = 30) -> FTP: + """ + Open an anonymous FTP connection, transparently using FTPS when the + server requires TLS (e.g., MassIVE). When ``use_tls`` is False but the + server replies ``421 TLS is required`` to ``login``, transparently + retry with FTPS so callers don't need to know the policy in advance. + """ + if use_tls: + ftp: FTP = ftplib.FTP_TLS(host, timeout=timeout) + ftp.login() + ftp.prot_p() + else: + ftp = FTP(host, timeout=timeout) + try: + ftp.login() + except ftplib.error_temp as e: + if "TLS" in str(e).upper(): + try: + ftp.close() + except Exception: + pass + ftp = ftplib.FTP_TLS(host, timeout=timeout) + ftp.login() + ftp.prot_p() + else: + raise + ftp.set_pasv(True) + return ftp + + +def _walk_ftp_tree(ftp: FTP, remote_dir: str) -> List[str]: + """ + Recursively list files under a remote FTP directory. 
+ """ + import posixpath + file_paths: List[str] = [] + try: + entries = list(ftp.mlsd(remote_dir)) + for name, facts in entries: + if name in {".", ".."}: + continue + child_path = posixpath.join(remote_dir.rstrip("/"), name) + if facts.get("type") == "dir": + file_paths.extend(_walk_ftp_tree(ftp, child_path)) + elif facts.get("type") == "file": + file_paths.append(child_path) + return file_paths + except (AttributeError, ftplib.error_perm): + pass + + current_dir = ftp.pwd() + listing: List[str] = [] + try: + ftp.cwd(remote_dir) + ftp.retrlines("LIST", listing.append) + for entry in listing: + parts = entry.split(maxsplit=8) + if len(parts) < 9: + continue + name = parts[8] + if name in {".", ".."}: + continue + child_path = posixpath.join(remote_dir.rstrip("/"), name) + if entry.startswith("d"): + file_paths.extend(_walk_ftp_tree(ftp, child_path)) + else: + file_paths.append(child_path) + finally: + ftp.cwd(current_dir) + return file_paths + + +def _list_ftp_repo_files( + host: str, + remote_root: str, + error_label: str, + use_tls: bool = False, +) -> List[str]: + """ + Connect to an anonymous FTP host (FTP or FTPS), walk a directory tree, + and return file paths. + + ``use_tls`` should be True for servers that reject plain FTP (e.g. + MassIVE). Centralizes connection lifecycle so a constructor failure + doesn't mask the underlying error in ``finally`` (PR #98 review). 
def _download_one_ftp_path(
    ftp: FTP,
    ftp_path: str,
    local_path: str,
    skip_if_downloaded_already: bool,
    max_download_retries: int,
    position: int = 0,
) -> None:
    """
    Download a single FTP path over an existing connection, with REST resume
    and per-file retry.

    :param ftp: Connected FTP client to reuse for the transfer.
    :param ftp_path: Remote path of the file to retrieve.
    :param local_path: Destination file; an existing partial file is resumed.
    :param skip_if_downloaded_already: When True and ``local_path`` exists,
        nothing is transferred.
    :param max_download_retries: Number of attempts before giving up.
    :param position: ``tqdm`` bar position for this file.
    :raises RuntimeError: When every attempt failed, chained to the last
        recorded failure. Exceptions other than ``socket.timeout``,
        ``ftplib.error_temp`` and ``ftplib.error_perm`` are not retried here
        and propagate to the caller unchanged.
    """
    if skip_if_downloaded_already and os.path.exists(local_path):
        logging.info(f"Skipping download as file already exists: {local_path}")
        return

    def transfer(offset: int, total_size: Optional[int]) -> None:
        """Stream the remote file into ``local_path`` starting at ``offset``."""
        with open(local_path, "ab" if offset else "wb") as f, tqdm(
            total=total_size,
            unit="B",
            unit_scale=True,
            desc=local_path,
            initial=offset,
            position=position,
            leave=True,
        ) as pbar:
            def callback(data):
                f.write(data)
                pbar.update(len(data))

            # Passing ``rest=`` lets ftplib issue REST directly before RETR.
            # Sending REST by hand first would put ftplib's own TYPE I and
            # PASV/PORT commands between REST and RETR, and a server may
            # then drop the restart marker and send the whole file again.
            ftp.retrbinary(f"RETR {ftp_path}", callback, rest=offset or None)

    attempt = 0
    last_error: Optional[Exception] = None
    while attempt < max_download_retries:
        try:
            # Switch to binary mode before SIZE: some servers reject SIZE
            # while the session is still in ASCII mode.
            ftp.voidcmd("TYPE I")
            total_size = ftp.size(ftp_path)
            offset = os.path.getsize(local_path) if os.path.exists(local_path) else 0

            if total_size and offset > total_size:
                # The local file cannot be a prefix of the remote one.
                logging.warning(
                    f"Local file {local_path} is larger than the remote file "
                    f"({offset} > {total_size} bytes); restarting download"
                )
                offset = 0

            # Nothing to fetch when the local file already has the full size.
            if not total_size or offset != total_size:
                try:
                    transfer(offset, total_size)
                except ftplib.error_perm:
                    # Restart from scratch only when a resume was requested
                    # and rejected before any byte arrived (e.g. REST is
                    # unsupported); otherwise let the retry loop handle it.
                    if not offset or os.path.getsize(local_path) != offset:
                        raise
                    logging.warning(
                        f"Resume rejected by server for {local_path}; "
                        f"restarting download from the beginning"
                    )
                    transfer(0, total_size)

            # Post-transfer integrity check: server-reported size must match
            # the local size. Catches half-finished transfers that retrbinary
            # didn't raise on (e.g. server closed the data channel early).
            # The next iteration will REST-resume from where we left off.
            if total_size:
                final_size = os.path.getsize(local_path)
                if final_size != total_size:
                    attempt += 1
                    last_error = IOError(
                        f"size mismatch: got {final_size} bytes, "
                        f"expected {total_size}"
                    )
                    logging.error(
                        f"Size mismatch for {local_path}: "
                        f"got {final_size} bytes, expected {total_size} "
                        f"(attempt {attempt})"
                    )
                    continue
            logging.info(f"Successfully downloaded {local_path}")
            return
        except (socket.timeout, ftplib.error_temp, ftplib.error_perm) as e:
            attempt += 1
            last_error = e
            logging.error(
                f"Download failed for {local_path} (attempt {attempt}): {e}"
            )
    raise RuntimeError(
        f"Giving up on {local_path} after {max_download_retries} attempts"
    ) from last_error
def _download_ftp_paths_parallel(
    host: str,
    paths: List[str],
    output_folder: str,
    skip_if_downloaded_already: bool,
    use_tls: bool,
    max_connection_retries: int,
    max_download_retries: int,
    parallel_files: int,
) -> None:
    """
    Fetch ``paths`` from ``host`` with a pool of ``parallel_files`` threads.

    Every task opens (and afterwards closes) its own FTP connection, so
    transfers run side by side rather than queueing on a shared control
    channel. Failures are logged; nothing is raised to the caller.
    """
    reconnect_errors = (
        socket.timeout,
        ftplib.error_temp,
        ftplib.error_perm,
        OSError,
    )

    def close_quietly(connection: FTP) -> None:
        # Polite QUIT first; fall back to dropping the socket.
        try:
            connection.quit()
        except Exception:
            try:
                connection.close()
            except Exception:
                pass

    def worker(ftp_path: str, position: int) -> None:
        target = os.path.join(output_folder, os.path.basename(ftp_path))
        if skip_if_downloaded_already and os.path.exists(target):
            logging.info(f"Skipping download as file already exists: {target}")
            return

        for attempt_no in range(1, max_connection_retries + 1):
            try:
                connection = _open_ftp_connection(host, use_tls=use_tls)
                try:
                    # The skip check already happened above, so it is
                    # disabled for the per-file download.
                    _download_one_ftp_path(
                        ftp=connection,
                        ftp_path=ftp_path,
                        local_path=target,
                        skip_if_downloaded_already=False,
                        max_download_retries=max_download_retries,
                        position=position,
                    )
                finally:
                    close_quietly(connection)
                return
            except reconnect_errors as e:
                logging.error(
                    f"FTP connection failed for {ftp_path} (attempt {attempt_no}): {e}"
                )
                if attempt_no < max_connection_retries:
                    time.sleep(5)
        logging.error(f"Giving up on {ftp_path} from {host}")

    with ThreadPoolExecutor(max_workers=parallel_files) as executor:
        pending = []
        for slot, remote_path in enumerate(paths):
            pending.append(executor.submit(worker, remote_path, slot))
        for finished in as_completed(pending):
            try:
                finished.result()
            except Exception as e:
                logging.error(f"Parallel FTP download error: {e}")
def _parallel_download(url, file_path, position=0):
    """Download a file via a single-connection HTTP stream with optional resume.

    A HEAD request is used to read ``content-length`` and ``accept-ranges``.
    If a partial file exists and the server supports Range requests, the
    download resumes from where it left off; otherwise it restarts from
    scratch. When the HEAD request fails the file is fetched in full.

    The session created for this call is always closed before returning or
    raising, so its pooled connections are not left open.

    :param url: HTTP(S) URL to fetch.
    :param file_path: Local destination file.
    :param position: ``tqdm`` bar position for this file.
    """
    session = Util.create_session_with_retries()
    try:
        try:
            head = session.head(url, timeout=(30, 30))
            head.raise_for_status()
            total_size = int(head.headers.get("content-length", 0))
            accept_ranges = head.headers.get("accept-ranges", "none").strip().lower()
        except (requests.RequestException, ValueError) as exc:
            logging.info(f"HEAD request failed, falling back to single connection: {exc}")
            total_size = 0
            accept_ranges = "none"

        resume_size = 0
        if os.path.exists(file_path) and accept_ranges == "bytes" and total_size > 0:
            resume_size = os.path.getsize(file_path)
            if resume_size >= total_size:
                logging.info(f"File already complete: {file_path}")
                return
            if resume_size > 0:
                logging.info(f"Resuming download from {resume_size} bytes: {file_path}")

        headers = {"Range": f"bytes={resume_size}-"} if resume_size > 0 else {}
        with session.get(url, headers=headers, stream=True, timeout=(30, 60)) as r:
            r.raise_for_status()
            # Anything other than 206 means the body starts at byte 0, so
            # appending to the partial file would corrupt it.
            if resume_size > 0 and r.status_code != 206:
                logging.warning("Server did not honor Range request (status %s), restarting download", r.status_code)
                resume_size = 0
            with tqdm(total=total_size, unit="B", unit_scale=True, desc=file_path,
                      initial=resume_size, position=position, leave=True) as pbar:
                mode = "ab" if resume_size > 0 else "wb"
                with open(file_path, mode, buffering=8 * 1024 * 1024) as f:
                    for chunk in r.iter_content(chunk_size=8 * 1024 * 1024):
                        if chunk:
                            f.write(chunk)
                            pbar.update(len(chunk))
    finally:
        session.close()
def download_http_urls(
    http_urls: List[str],
    output_folder: str,
    skip_if_downloaded_already: bool,
    parallel_files: int = 1,
    max_retries: int = 3,
) -> None:
    """
    Fetch every URL in ``http_urls`` into ``output_folder``.

    The output folder is created when missing. Each file is handed to
    ``_http_download_one``; with ``parallel_files`` > 1 the calls are
    spread over a :class:`ThreadPoolExecutor` whose size is capped at the
    number of URLs, otherwise they run one after another. A failure for one
    URL is logged and does not stop the remaining downloads.
    """
    if not os.path.isdir(output_folder):
        os.makedirs(output_folder, exist_ok=True)

    if not http_urls:
        return

    workers = max(1, min(parallel_files, len(http_urls)))

    if workers <= 1:
        # Sequential path: one file at a time.
        for url in http_urls:
            try:
                _http_download_one(
                    url,
                    output_folder,
                    skip_if_downloaded_already,
                    max_retries,
                )
            except Exception as e:
                logging.error(f"HTTP download failed for {url}: {e}")
        return

    logging.info(
        f"Downloading {len(http_urls)} HTTP(S) file(s) with {workers} parallel workers"
    )
    with ThreadPoolExecutor(max_workers=workers) as executor:
        pending = []
        for slot, url in enumerate(http_urls):
            # The enumeration index doubles as the progress-bar position.
            pending.append(
                executor.submit(
                    _http_download_one,
                    url,
                    output_folder,
                    skip_if_downloaded_already,
                    max_retries,
                    slot,
                )
            )
        for finished in as_completed(pending):
            try:
                finished.result()
            except Exception as e:
                logging.error(f"Parallel HTTP download error: {e}")
def _find_tsv_columns(header: str) -> Optional[Tuple[int, int]]:
    """
    Locate the file-name and MD5 columns in a tab-separated header line.

    Column names are compared after trimming whitespace and lower-casing.
    The header must contain ``file-name``, ``file-md5checksum`` and
    ``file-size``; if any of them is missing ``None`` is returned, otherwise
    the tuple ``(name_idx, checksum_idx)``.
    """
    normalized = [field.strip().lower() for field in header.split("\t")]
    for needed in ("file-name", "file-md5checksum", "file-size"):
        if needed not in normalized:
            return None
    name_idx = normalized.index("file-name")
    checksum_idx = normalized.index("file-md5checksum")
    return name_idx, checksum_idx
def compute_md5(file_path: str, chunk_size: int = 4 * 1024 * 1024) -> str:
    """
    Return the hexadecimal MD5 digest of a file's contents.

    The digest is an integrity check, not a security measure. The file is
    read in ``chunk_size``-byte pieces so it never has to fit in memory.
    """
    try:
        digest = hashlib.md5(usedforsecurity=False)
    except TypeError:
        # hashlib.md5 here does not accept the ``usedforsecurity`` keyword.
        digest = hashlib.md5()
    with open(file_path, "rb") as stream:
        # iter() with a sentinel stops at the first empty read (EOF).
        for piece in iter(lambda: stream.read(chunk_size), b""):
            digest.update(piece)
    return digest.hexdigest()
def _get_download_url(file_record: Dict, protocol: str) -> str:
    """
    Return the public download URL of ``file_record`` for ``protocol``.

    The record's ``publicFileLocations`` entries are scanned for the
    "Aspera Protocol" and "FTP Protocol" names; when a name occurs more
    than once the last entry wins. ``aspera`` needs the Aspera entry.
    ``ftp`` and ``s3`` return the FTP entry as is, and ``globus`` returns it
    with the PRIDE FTP prefix swapped for the HTTPS prefix (first occurrence
    only).

    Raises ValueError when the record has no locations, when the entry the
    protocol depends on is missing or empty, or when the protocol is not one
    of aspera/ftp/globus/s3.
    """
    from pridepy.files.files import Files

    locations = file_record.get("publicFileLocations", [])
    if not locations:
        raise ValueError("No public file locations present")

    aspera_link = None
    ftp_link = None
    for entry in locations:
        label = entry.get("name")
        if label == "Aspera Protocol":
            aspera_link = entry.get("value")
            continue
        if label == "FTP Protocol":
            ftp_link = entry.get("value")

    if protocol == "aspera":
        if aspera_link:
            return aspera_link
        raise ValueError("Aspera URL not available")

    # Every remaining protocol is derived from the FTP location.
    if not ftp_link:
        raise ValueError("FTP URL not available")

    if protocol in ("ftp", "s3"):
        return ftp_link

    if protocol == "globus":
        https_link = ftp_link.replace(
            Files.PRIDE_ARCHIVE_FTP_URL_PREFIX,
            Files.PRIDE_ARCHIVE_HTTPS_URL_PREFIX,
            1,
        )
        return https_link

    raise ValueError(f"Unsupported protocol: {protocol}")
+ """ + from pridepy.files.files import Files + + try: + canonical_url = _get_download_url(file_record, "ftp") + except ValueError: + canonical_url = "" + if canonical_url: + return Files.get_output_file_name(canonical_url, file_record, output_folder) + return os.path.join(output_folder, file_record["fileName"]) diff --git a/pridepy/tests/test_download_by_list.py b/pridepy/tests/test_download_by_list.py index df81914..5115b4e 100644 --- a/pridepy/tests/test_download_by_list.py +++ b/pridepy/tests/test_download_by_list.py @@ -14,6 +14,7 @@ from pridepy.files.files import Files from pridepy.pridepy import _read_filename_arguments +from pridepy.providers.pride import PrideProvider class TestDownloadFilesByList(TestCase): @@ -36,8 +37,8 @@ def test_filters_metadata_and_delegates(self): {"fileName": "c.raw"}, ] with patch.object( - files_obj, "stream_all_files_by_project", return_value=api_response - ), patch.object(files_obj, "download_files") as mock_download: + PrideProvider, "list_files", return_value=api_response + ), patch.object(PrideProvider, "download_files") as mock_download: files_obj.download_files_by_list( accession="PXD001819", file_names=["a.raw", "c.raw"], @@ -46,16 +47,16 @@ def test_filters_metadata_and_delegates(self): protocol="ftp", ) - args, _ = mock_download.call_args - matched = args[0] + _, kwargs = mock_download.call_args + matched = kwargs["records"] assert {f["fileName"] for f in matched} == {"a.raw", "c.raw"} def test_warns_on_partial_match(self): files_obj = Files() api_response = [{"fileName": "a.raw"}] with patch.object( - files_obj, "stream_all_files_by_project", return_value=api_response - ), patch.object(files_obj, "download_files") as mock_download, self.assertLogs( + PrideProvider, "list_files", return_value=api_response + ), patch.object(PrideProvider, "download_files") as mock_download, self.assertLogs( level="WARNING" ) as log_ctx: files_obj.download_files_by_list( @@ -71,7 +72,7 @@ def test_warns_on_partial_match(self): def 
test_raises_when_no_files_match(self): files_obj = Files() with patch.object( - files_obj, "stream_all_files_by_project", return_value=[] + PrideProvider, "list_files", return_value=[] ): with pytest.raises(ValueError, match="No matching files"): files_obj.download_files_by_list( diff --git a/pridepy/tests/test_download_resilience.py b/pridepy/tests/test_download_resilience.py index 21b1603..0f86013 100644 --- a/pridepy/tests/test_download_resilience.py +++ b/pridepy/tests/test_download_resilience.py @@ -268,3 +268,43 @@ def test_download_files_raises_when_any_file_fails(self): skip_if_downloaded_already=False, protocol="ftp", ) + + def test_facade_dispatches_pride_through_registry_to_fallback(self): + """Files().download_all_raw_files for a PXD accession must flow: + Files facade -> Registry.resolve -> PrideProvider.download_files + -> _batch_download_by_protocol (mocked). + + Patching Files._batch_download_by_protocol proves the patch intercepts + (i.e. PrideProvider calls *back* through Files, preserving the test + contract for the multi-protocol orchestrator). 
+ """ + from pridepy.providers.pride import PrideProvider + + fake_records = [ + { + "accession": "PXD000001", + "fileName": "x.raw", + "fileCategory": {"value": "RAW"}, + "publicFileLocations": [ + {"name": "FTP Protocol", "value": "ftp://ftp.pride.ebi.ac.uk/.../x.raw"} + ], + }, + ] + + with tempfile.TemporaryDirectory() as tmp: + with patch.object(PrideProvider, "list_files", return_value=fake_records), \ + patch.object(Files, "_batch_download_by_protocol", return_value=[]) as batch_mock, \ + patch.object(Files, "validate_download", return_value=(True, "ok")), \ + patch.object(Files, "_download_with_fallback") as fallback_mock: + Files().download_all_raw_files( + accession="PXD000001", + output_folder=tmp, + skip_if_downloaded_already=False, + protocol="ftp", + aspera_maximum_bandwidth="100M", + ) + + batch_mock.assert_called_once() + # No fallback expected because all files passed validation after + # the primary-protocol batch run. + fallback_mock.assert_not_called() diff --git a/pridepy/tests/test_jpost_files.py b/pridepy/tests/test_jpost_files.py index 678adda..1e4c652 100644 --- a/pridepy/tests/test_jpost_files.py +++ b/pridepy/tests/test_jpost_files.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch from pridepy.files.files import Files +from pridepy.providers.jpost import JpostProvider class TestJPOSTFiles(TestCase): @@ -50,12 +51,9 @@ def test_get_all_raw_file_list_filters_jpost_records(self): ), ] - with patch.object(Files, "_list_jpost_public_files", return_value=jpost_records), patch.object( - Files, "stream_all_files_by_project" - ) as pride_mock: + with patch.object(JpostProvider, "list_files", return_value=jpost_records): result = files.get_all_raw_file_list("JPST000001") - pride_mock.assert_not_called() assert len(result) == 1 assert {file["fileName"] for file in result} == {"run1.raw"} @@ -68,7 +66,7 @@ def test_download_file_by_name_uses_jpost_ftp_listing(self): with tempfile.TemporaryDirectory() as tmp_dir: with patch.object( - 
Files, "_list_jpost_public_files", return_value=[file_record] + JpostProvider, "list_files", return_value=[file_record] ), patch.object(Files, "download_ftp_urls") as download_mock: files.download_file_by_name( accession="JPST000001", diff --git a/pridepy/tests/test_massive_files.py b/pridepy/tests/test_massive_files.py index a958309..a4e9278 100644 --- a/pridepy/tests/test_massive_files.py +++ b/pridepy/tests/test_massive_files.py @@ -3,6 +3,7 @@ from unittest.mock import patch from pridepy.files.files import Files +from pridepy.providers.massive import MassiveProvider class TestMassIVEFiles(TestCase): @@ -66,7 +67,7 @@ def test_get_all_raw_file_list_filters_massive_records(self): ), ] - with patch.object(Files, "_list_massive_public_files", return_value=massive_records): + with patch.object(MassiveProvider, "list_files", return_value=massive_records): result = files.get_all_raw_file_list("MSV000012345") assert len(result) == 1 @@ -80,7 +81,7 @@ def test_download_file_by_name_uses_massive_ftp_listing(self): ) with tempfile.TemporaryDirectory() as tmp_dir: - with patch.object(Files, "_list_massive_public_files", return_value=[file_record]), patch.object( + with patch.object(MassiveProvider, "list_files", return_value=[file_record]), patch.object( Files, "download_ftp_urls" ) as download_mock: files.download_file_by_name( @@ -120,7 +121,7 @@ def test_download_all_raw_files_threads_parallel_files_for_massive(self): with tempfile.TemporaryDirectory() as tmp_dir: with patch.object( - Files, "_list_massive_public_files", return_value=massive_records + MassiveProvider, "list_files", return_value=massive_records ), patch.object(Files, "download_ftp_urls") as download_mock: files.download_all_raw_files( accession="MSV000012345", @@ -135,3 +136,42 @@ def test_download_all_raw_files_threads_parallel_files_for_massive(self): kwargs = download_mock.call_args.kwargs assert kwargs["use_tls"] is True assert kwargs["parallel_files"] == 3 + + def 
test_base_direct_download_provider_partitions_urls_by_scheme(self): + """Records mixing ftp:// and http(s):// route to the right transport.""" + from pridepy.providers.massive import MassiveProvider + + provider = MassiveProvider() + records = [ + Files._build_massive_file_record( + "MSV000012345", + "ftp://massive-ftp.ucsd.edu/v01/MSV000012345/raw/a.raw", + ), + # Synthetic http record to verify partitioning (real MassIVE uses ftp). + { + "accession": "MSV000012345", + "fileName": "b.raw", + "fileCategory": {"value": "RAW"}, + "publicFileLocations": [ + {"name": "FTP Protocol", "value": "http://example.org/b.raw"} + ], + }, + ] + with patch.object(Files, "download_ftp_urls") as ftp_mock, \ + patch.object(Files, "download_http_urls") as http_mock: + provider.download_files( + accession="MSV000012345", + records=records, + output_folder="/tmp/test", + skip_if_downloaded_already=False, + protocol="ftp", + parallel_files=1, + ) + + ftp_mock.assert_called_once() + assert ftp_mock.call_args.kwargs["use_tls"] is True + assert ftp_mock.call_args.kwargs["ftp_urls"] == [ + "ftp://massive-ftp.ucsd.edu/v01/MSV000012345/raw/a.raw" + ] + http_mock.assert_called_once() + assert http_mock.call_args.kwargs["http_urls"] == ["http://example.org/b.raw"] diff --git a/pyproject.toml b/pyproject.toml index 4a95f24..f5b74ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pridepy" -version = "0.0.16" +version = "0.0.17" description = "Python Client library for PRIDE Rest API" readme = "README.md" requires-python = ">=3.9"