Commit 118d7db

Add Deltalake query support (#1023)

Merge of 2 parents: b311d19 + 4291b38

13 files changed: 538 additions & 145 deletions

.github/workflows/testing.yml

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ jobs:
       - name: Test with pytest
         env:
           MP_API_KEY: ${{ secrets[env.API_KEY_NAME] }}
-          #MP_API_ENDPOINT: https://api-preview.materialsproject.org/
+          # MP_API_ENDPOINT: https://api-preview.materialsproject.org/
         run: |
           pytest -n auto -x --cov=mp_api --cov-report=xml
       - uses: codecov/codecov-action@v1

mp_api/client/core/client.py

Lines changed: 230 additions & 19 deletions
@@ -8,8 +8,10 @@
 import gzip
 import inspect
 import itertools
+import logging
 import os
 import platform
+import shutil
 import sys
 import warnings
 from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
@@ -22,12 +24,17 @@
 from json import JSONDecodeError
 from math import ceil
 from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
+from urllib.parse import urljoin

 import boto3
+import pyarrow as pa
+import pyarrow.dataset as ds
 import requests
 from botocore import UNSIGNED
 from botocore.config import Config
 from botocore.exceptions import ClientError
+from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake
+from emmet.core.arrow import arrowize
 from emmet.core.utils import jsanitize
 from pydantic import BaseModel, create_model
 from requests.adapters import HTTPAdapter
@@ -38,6 +45,7 @@
 from mp_api.client.core.exceptions import MPRestError
 from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS
 from mp_api.client.core.utils import (
+    MPDataset,
     load_json,
     validate_api_key,
     validate_endpoint,
@@ -50,7 +58,8 @@
     flask = None

 if TYPE_CHECKING:
-    from typing import Any, Callable
+    from collections.abc import Callable, Iterable, Iterator
+    from typing import Any

     from pydantic.fields import FieldInfo

@@ -62,7 +71,16 @@
 __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION")


-def _batched(iterable, n):
+hdlr = logging.StreamHandler()
+fmt = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
+hdlr.setFormatter(fmt)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+logger.addHandler(hdlr)
+
+
+def _batched(iterable: Iterable, n: int) -> Iterator:
     if n < 1:
         raise ValueError("n must be at least one")
     iterator = iter(iterable)
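The hunk above cuts off the rest of `_batched`; for reference, a minimal self-contained sketch of a batching helper with this signature and guard clause (the `islice`-based body is an assumption, since the diff only shows the first three lines):

    from collections.abc import Iterable, Iterator
    from itertools import islice

    def _batched(iterable: Iterable, n: int) -> Iterator[tuple]:
        # Yield consecutive tuples of at most n items from iterable.
        if n < 1:
            raise ValueError("n must be at least one")
        iterator = iter(iterable)
        while batch := tuple(islice(iterator, n)):
            yield batch

    assert list(_batched("ABCDE", 2)) == [("A", "B"), ("C", "D"), ("E",)]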
@@ -93,6 +111,7 @@ class BaseRester:
     suffix: str = ""
     document_model: type[BaseModel] | None = None
     primary_key: str = "material_id"
+    delta_backed: bool = False

     def __init__(
         self,
@@ -106,6 +125,10 @@ def __init__(
         timeout: int = 20,
         headers: dict | None = None,
         mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS,
+        local_dataset_cache: (
+            str | os.PathLike
+        ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE,
+        force_renew: bool = False,
         **kwargs,
     ):
         """Initialize the REST API helper class.
@@ -137,6 +160,9 @@ def __init__(
             timeout: Time in seconds to wait until a request timeout error is thrown
             headers: Custom headers for localhost connections.
             mute_progress_bars: Whether to disable progress bars.
+            local_dataset_cache: Target directory for downloading full datasets. Defaults
+                to 'mp_datasets' in the user's home directory
+            force_renew: Option to overwrite existing local dataset
             **kwargs: access to legacy kwargs that may be in the process of being deprecated
         """
         self.api_key = validate_api_key(api_key)
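With these two new keyword arguments, callers control where full datasets are cached and whether an existing cache is rebuilt. A usage sketch, assuming `MPRester` forwards the kwargs to `BaseRester` as the `__getattr__` change at the end of this file suggests (the API key is a placeholder):

    from pathlib import Path

    from mp_api.client import MPRester

    with MPRester(
        api_key="your-api-key",  # placeholder
        local_dataset_cache=Path.home() / "mp_datasets",  # the documented default location
        force_renew=False,  # set True to regenerate an existing local dataset
    ) as mpr:
        ...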
@@ -149,7 +175,14 @@ def __init__(
         self.timeout = timeout
         self.headers = headers or {}
         self.mute_progress_bars = mute_progress_bars
-        self.db_version = BaseRester._get_database_version(self.base_endpoint)
+
+        (
+            self.db_version,
+            self.access_controlled_batch_ids,
+        ) = BaseRester._get_heartbeat_info(self.base_endpoint)
+
+        self.local_dataset_cache = local_dataset_cache
+        self.force_renew = force_renew

         self._session = session
         self._s3_client = s3_client
@@ -217,8 +250,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): # pragma: no cover

     @staticmethod
     @cache
-    def _get_database_version(endpoint):
-        """The Materials Project database is periodically updated and has a
+    def _get_heartbeat_info(endpoint) -> tuple[str, str]:
+        """DB version:
+        The Materials Project database is periodically updated and has a
         database version associated with it. When the database is updated,
         consolidated data (information about "a material") may and does
         change, while calculation data about a specific calculation task
@@ -228,9 +262,24 @@ def _get_database_version(endpoint):
         where "_DD" may be optional. An additional numerical or `postN` suffix
         might be added if multiple releases happen on the same day.

-        Returns: database version as a string
+        Access Controlled Datasets:
+            Certain contributions to the Materials Project have access
+            control restrictions that require explicit agreement to the
+            Terms of Use for the respective datasets prior to access being
+            granted.
+
+            A full list of the Terms of Use for all contributions in the
+            Materials Project are available at:
+
+            https://next-gen.materialsproject.org/about/terms
+
+        Returns:
+            tuple with database version as a string and a comma separated
+            string with all calculation batch identifiers that have access
+            restrictions
         """
-        return requests.get(url=endpoint + "heartbeat").json()["db_version"]
+        response = requests.get(url=endpoint + "heartbeat").json()
+        return response["db_version"], response["access_controlled_batch_ids"]

     def _post_resource(
         self,
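The heartbeat payload consumed by `_get_heartbeat_info` can also be inspected directly; a minimal sketch against the production endpoint, with field names taken from the code above:

    import requests

    resp = requests.get("https://api.materialsproject.org/heartbeat").json()
    print(resp["db_version"])                   # e.g. "YYYY.MM.DD", optionally suffixed
    print(resp["access_controlled_batch_ids"])  # batch ids gated behind Terms of Use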
@@ -361,10 +410,7 @@ def _patch_resource(
             raise MPRestError(str(ex))

     def _query_open_data(
-        self,
-        bucket: str,
-        key: str,
-        decoder: Callable | None = None,
+        self, bucket: str, key: str, decoder: Callable | None = None
     ) -> tuple[list[dict] | list[bytes], int]:
         """Query and deserialize Materials Project AWS open data s3 buckets.

@@ -466,6 +512,12 @@ def _query_resource(
         url = validate_endpoint(self.endpoint, suffix=suburl)

         if query_s3:
+            pbar_message = (  # type: ignore
+                f"Retrieving {self.document_model.__name__} documents"  # type: ignore
+                if self.document_model is not None
+                else "Retrieving documents"
+            )
+
             if "/" not in self.suffix:
                 suffix = self.suffix
             elif self.suffix == "molecules/summary":
@@ -475,15 +527,177 @@ def _query_resource(
                 suffix = infix if suffix == "core" else suffix
             suffix = suffix.replace("_", "-")

-            # Paginate over all entries in the bucket.
-            # TODO: change when a subset of entries needed from DB
+            # Check if user has access to GNoMe
+            # temp suppress tqdm
+            re_enable = not self.mute_progress_bars
+            self.mute_progress_bars = True
+            has_gnome_access = bool(
+                self._submit_requests(
+                    url=urljoin(self.base_endpoint, "materials/summary/"),
+                    criteria={
+                        "batch_id": "gnome_r2scan_statics",
+                        "_fields": "material_id",
+                    },
+                    use_document_model=False,
+                    num_chunks=1,
+                    chunk_size=1,
+                    timeout=timeout,
+                )
+                .get("meta", {})
+                .get("total_doc", 0)
+            )
+            self.mute_progress_bars = not re_enable
+
             if "tasks" in suffix:
-                bucket_suffix, prefix = "parsed", "tasks_atomate2"
+                bucket_suffix, prefix = ("parsed", "core/tasks/")
             else:
                 bucket_suffix = "build"
                 prefix = f"collections/{self.db_version.replace('.', '-')}/{suffix}"

             bucket = f"materialsproject-{bucket_suffix}"
+
+            if self.delta_backed:
+                target_path = str(
+                    self.local_dataset_cache.joinpath(f"{bucket_suffix}/{prefix}")
+                )
+                os.makedirs(target_path, exist_ok=True)
+
+                if DeltaTable.is_deltatable(target_path):
+                    if self.force_renew:
+                        shutil.rmtree(target_path)
+                        logger.warning(
+                            f"Regenerating {suffix} dataset at {target_path}..."
+                        )
+                        os.makedirs(target_path, exist_ok=True)
+                    else:
+                        logger.warning(
+                            f"Dataset for {suffix} already exists at {target_path}, returning existing dataset."
+                        )
+                        logger.info(
+                            "Delete or move existing dataset or re-run search query with MPRester(force_renew=True) "
+                            "to refresh local dataset.",
+                        )
+
+                        return {
+                            "data": MPDataset(
+                                path=target_path,
+                                document_model=self.document_model,
+                                use_document_model=self.use_document_model,
+                            )
+                        }
+
+                tbl = DeltaTable(
+                    f"s3a://{bucket}/{prefix}",
+                    storage_options={
+                        "AWS_SKIP_SIGNATURE": "true",
+                        "AWS_REGION": "us-east-1",
+                    },
+                )
+
+                controlled_batch_str = ",".join(
+                    [f"'{tag}'" for tag in self.access_controlled_batch_ids]
+                )
+
+                predicate = (
+                    f"WHERE batch_id NOT IN ({controlled_batch_str})"
+                    if not has_gnome_access
+                    else ""
+                )
+
+                builder = QueryBuilder().register("tbl", tbl)
+
+                # Setup progress bar
+                num_docs_needed = tbl.count()
+
+                if not has_gnome_access:
+                    num_docs_needed = self.count(
+                        {"batch_id_neq_any": self.access_controlled_batch_ids}
+                    )
+
+                pbar = (
+                    tqdm(
+                        desc=pbar_message,
+                        total=num_docs_needed,
+                    )
+                    if not self.mute_progress_bars
+                    else None
+                )
+
+                iterator = builder.execute(f"SELECT * FROM tbl {predicate}")
+
+                file_options = ds.ParquetFileFormat().make_write_options(
+                    compression="zstd"
+                )
+
+                def _flush(
+                    accumulator: list[pa.RecordBatch], group: int, schema: pa.Schema
+                ):
+                    # somewhere post datafusion 51.0.0 and arrow-rs 57.0.0
+                    # casts to *View types began, need to cast back to base schema
+                    # -> pyarrow is behind on implementation support for *View types
+                    tbl = (
+                        pa.Table.from_batches(accumulator)
+                        .select(schema.names)
+                        .cast(target_schema=schema)
+                    )
+
+                    ds.write_dataset(
+                        tbl,
+                        base_dir=target_path,
+                        format="parquet",
+                        basename_template=f"group-{group}-"
+                        + "part-{i}.zstd.parquet",
+                        existing_data_behavior="overwrite_or_ignore",
+                        max_rows_per_group=1024,
+                        file_options=file_options,
+                    )
+
+                group = 1
+                size = 0
+                accumulator = []
+                schema = pa.schema(arrowize(self.document_model))
+                for page in iterator:
+                    # arro3 rb to pyarrow rb for compat w/ pyarrow ds writer
+                    rg = pa.record_batch(page)
+                    accumulator.append(rg)
+                    page_size = page.num_rows
+                    size += rg.get_total_buffer_size()
+
+                    if pbar is not None:
+                        pbar.update(page_size)
+
+                    if size >= MAPI_CLIENT_SETTINGS.DATASET_FLUSH_THRESHOLD:
+                        _flush(accumulator, group, schema)
+                        group += 1
+                        size = 0
+                        accumulator.clear()
+
+                if accumulator:
+                    _flush(accumulator, group + 1, schema)
+
+                if pbar is not None:
+                    pbar.close()
+
+                logger.info(f"Dataset for {suffix} written to {target_path}")
+                logger.info("Converting to DeltaTable...")
+
+                convert_to_deltalake(target_path)
+
+                logger.info(
+                    "Consult the delta-rs and pyarrow documentation for advanced usage: "
+                    "delta-io.github.io/delta-rs, arrow.apache.org/docs/python"
+                )
+
+                return {
+                    "data": MPDataset(
+                        path=target_path,
+                        document_model=self.document_model,
+                        use_document_model=self.use_document_model,
+                    )
+                }
+
+            # Paginate over all entries in the bucket.
+            # TODO: change when a subset of entries needed from DB
             paginator = self.s3_client.get_paginator("list_objects_v2")
             pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
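The anonymous Delta read used above can be reproduced outside the client with delta-rs alone. A sketch of the technique, where the collection prefix is illustrative (real prefixes follow `collections/<db-version>/<suffix>` in the `materialsproject-build` bucket) and the excluded batch id mirrors the GNoMe check:

    import pyarrow as pa
    from deltalake import DeltaTable, QueryBuilder

    # Illustrative prefix; substitute a current db version and collection suffix
    uri = "s3a://materialsproject-build/collections/2025-06-09/summary"

    tbl = DeltaTable(
        uri,
        storage_options={"AWS_SKIP_SIGNATURE": "true", "AWS_REGION": "us-east-1"},
    )

    builder = QueryBuilder().register("tbl", tbl)
    sql = "SELECT * FROM tbl WHERE batch_id NOT IN ('gnome_r2scan_statics')"

    for page in builder.execute(sql):  # streams arro3 record batches
        batch = pa.record_batch(page)  # convert for pyarrow, as _query_resource does
        print(batch.num_rows)
        break

Once `convert_to_deltalake` has run, the local copy is itself a Delta table and can be opened the same way with `DeltaTable(target_path)`.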

@@ -519,11 +733,6 @@ def _query_resource(
             }

             # Setup progress bar
-            pbar_message = (  # type: ignore
-                f"Retrieving {self.document_model.__name__} documents"  # type: ignore
-                if self.document_model is not None
-                else "Retrieving documents"
-            )
             num_docs_needed = int(self.count())
             pbar = (
                 tqdm(
@@ -1317,6 +1526,8 @@ def __getattr__(self, v: str):
                 use_document_model=self.use_document_model,
                 headers=self.headers,
                 mute_progress_bars=self.mute_progress_bars,
+                local_dataset_cache=self.local_dataset_cache,
+                force_renew=self.force_renew,
             )
             return self.sub_resters[v]
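Because `__getattr__` now threads `local_dataset_cache` and `force_renew` into every sub-rester, the download-and-cache behavior is uniform across endpoints. For readers who want the buffered parquet-flush technique from `_query_resource` in isolation, a reduced sketch (the function name and flush threshold are illustrative; the client reads its threshold from `MAPI_CLIENT_SETTINGS.DATASET_FLUSH_THRESHOLD`):

    import pyarrow as pa
    import pyarrow.dataset as ds

    FLUSH_THRESHOLD = 512 * 1024**2  # illustrative size in bytes

    def write_batches(batches, target_path: str, schema: pa.Schema) -> None:
        # Accumulate record batches in memory, flushing to zstd parquet in groups.
        file_options = ds.ParquetFileFormat().make_write_options(compression="zstd")
        group, size, accumulator = 1, 0, []

        def flush() -> None:
            tbl = pa.Table.from_batches(accumulator).select(schema.names).cast(schema)
            ds.write_dataset(
                tbl,
                base_dir=target_path,
                format="parquet",
                basename_template=f"group-{group}-" + "part-{i}.zstd.parquet",
                existing_data_behavior="overwrite_or_ignore",
                max_rows_per_group=1024,
                file_options=file_options,
            )

        for batch in batches:
            accumulator.append(batch)
            size += batch.get_total_buffer_size()
            if size >= FLUSH_THRESHOLD:
                flush()
                group += 1
                size = 0
                accumulator.clear()

        if accumulator:  # flush whatever remains
            flush()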
