 import gzip
 import inspect
 import itertools
+import logging
 import os
 import platform
+import shutil
 import sys
 import warnings
 from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
 from json import JSONDecodeError
 from math import ceil
 from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
+from urllib.parse import urljoin

 import boto3
+import pyarrow as pa
+import pyarrow.dataset as ds
 import requests
 from botocore import UNSIGNED
 from botocore.config import Config
 from botocore.exceptions import ClientError
+from deltalake import DeltaTable, QueryBuilder, convert_to_deltalake
+from emmet.core.arrow import arrowize
 from emmet.core.utils import jsanitize
 from pydantic import BaseModel, create_model
 from requests.adapters import HTTPAdapter
 from mp_api.client.core.exceptions import MPRestError
 from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS
 from mp_api.client.core.utils import (
+    MPDataset,
     load_json,
     validate_api_key,
     validate_endpoint,
     flask = None

 if TYPE_CHECKING:
-    from typing import Any, Callable
+    from collections.abc import Callable, Iterable, Iterator
+    from typing import Any

     from pydantic.fields import FieldInfo

     __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION")


-def _batched(iterable, n):
+hdlr = logging.StreamHandler()
+fmt = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
+hdlr.setFormatter(fmt)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+logger.addHandler(hdlr)
+
+
+def _batched(iterable: Iterable, n: int) -> Iterator:
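+    # Batches an iterable into tuples of at most n items, mirroring
+    # itertools.batched (inferred from the signature and the guard below).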
     if n < 1:
         raise ValueError("n must be at least one")
     iterator = iter(iterable)
@@ -93,6 +111,7 @@ class BaseRester:
     suffix: str = ""
     document_model: type[BaseModel] | None = None
     primary_key: str = "material_id"
+    delta_backed: bool = False

     def __init__(
         self,
@@ -106,6 +125,10 @@ def __init__(
         timeout: int = 20,
         headers: dict | None = None,
         mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS,
+        local_dataset_cache: (
+            str | os.PathLike
+        ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE,
+        force_renew: bool = False,
         **kwargs,
     ):
         """Initialize the REST API helper class.
@@ -137,6 +160,9 @@ def __init__(
             timeout: Time in seconds to wait until a request timeout error is thrown
             headers: Custom headers for localhost connections.
             mute_progress_bars: Whether to disable progress bars.
+            local_dataset_cache: Target directory for downloading full datasets.
+                Defaults to 'mp_datasets' in the user's home directory.
+            force_renew: Whether to overwrite an existing local dataset.
             **kwargs: access to legacy kwargs that may be in the process of being deprecated
         """
         self.api_key = validate_api_key(api_key)
@@ -149,7 +175,14 @@ def __init__(
         self.timeout = timeout
         self.headers = headers or {}
         self.mute_progress_bars = mute_progress_bars
-        self.db_version = BaseRester._get_database_version(self.base_endpoint)
+
+        (
+            self.db_version,
+            self.access_controlled_batch_ids,
+        ) = BaseRester._get_heartbeat_info(self.base_endpoint)
+
+        self.local_dataset_cache = local_dataset_cache
+        self.force_renew = force_renew

         self._session = session
         self._s3_client = s3_client
@@ -217,8 +250,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): # pragma: no cover

     @staticmethod
     @cache
-    def _get_database_version(endpoint):
-        """The Materials Project database is periodically updated and has a
+    def _get_heartbeat_info(endpoint) -> tuple[str, list[str]]:
+        """DB version:
+        The Materials Project database is periodically updated and has a
         database version associated with it. When the database is updated,
         consolidated data (information about "a material") may and does
         change, while calculation data about a specific calculation task
@@ -228,9 +262,24 @@ def _get_database_version(endpoint):
         where "_DD" may be optional. An additional numerical or `postN` suffix
         might be added if multiple releases happen on the same day.

-        Returns: database version as a string
+        Access Controlled Datasets:
+        Certain contributions to the Materials Project have access
+        control restrictions that require explicit agreement to the
+        Terms of Use for the respective datasets before access is
+        granted.
+
+        A full list of the Terms of Use for all contributions in the
+        Materials Project is available at:
+
+        https://next-gen.materialsproject.org/about/terms
+
+        Returns:
+            tuple of the database version as a string and a list of all
+            calculation batch identifiers that have access restrictions
         """
-        return requests.get(url=endpoint + "heartbeat").json()["db_version"]
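+        # Illustrative heartbeat payload (assumed shape; values are examples only):
+        # {"db_version": "2025.11.10", "access_controlled_batch_ids": ["gnome_r2scan_statics"]}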
+        response = requests.get(url=endpoint + "heartbeat").json()
+        return response["db_version"], response["access_controlled_batch_ids"]

     def _post_resource(
         self,
@@ -361,10 +410,7 @@ def _patch_resource(
             raise MPRestError(str(ex))

     def _query_open_data(
-        self,
-        bucket: str,
-        key: str,
-        decoder: Callable | None = None,
+        self, bucket: str, key: str, decoder: Callable | None = None
     ) -> tuple[list[dict] | list[bytes], int]:
         """Query and deserialize Materials Project AWS open data s3 buckets.

@@ -466,6 +512,12 @@ def _query_resource(
         url = validate_endpoint(self.endpoint, suffix=suburl)

         if query_s3:
+            pbar_message = (  # type: ignore
+                f"Retrieving {self.document_model.__name__} documents"  # type: ignore
+                if self.document_model is not None
+                else "Retrieving documents"
+            )
+
             if "/" not in self.suffix:
                 suffix = self.suffix
             elif self.suffix == "molecules/summary":
@@ -475,15 +527,177 @@ def _query_resource(
                 suffix = infix if suffix == "core" else suffix
             suffix = suffix.replace("_", "-")

-            # Paginate over all entries in the bucket.
-            # TODO: change when a subset of entries needed from DB
+            # Check whether the user has access to GNoMe by probing for a
+            # single access-controlled document, temporarily muting tqdm.
+            re_enable = not self.mute_progress_bars
+            self.mute_progress_bars = True
+            has_gnome_access = bool(
+                self._submit_requests(
+                    url=urljoin(self.base_endpoint, "materials/summary/"),
+                    criteria={
+                        "batch_id": "gnome_r2scan_statics",
+                        "_fields": "material_id",
+                    },
+                    use_document_model=False,
+                    num_chunks=1,
+                    chunk_size=1,
+                    timeout=timeout,
+                )
+                .get("meta", {})
+                .get("total_doc", 0)
+            )
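+            # A nonzero total_doc count means this API key can see GNoMe records.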
+            self.mute_progress_bars = not re_enable
+
             if "tasks" in suffix:
-                bucket_suffix, prefix = "parsed", "tasks_atomate2"
+                bucket_suffix, prefix = ("parsed", "core/tasks/")
             else:
                 bucket_suffix = "build"
                 prefix = f"collections/{self.db_version.replace('.', '-')}/{suffix}"

             bucket = f"materialsproject-{bucket_suffix}"
+
+            if self.delta_backed:
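+                # The local cache mirrors the bucket layout:
+                # <local_dataset_cache>/<bucket_suffix>/<prefix>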
+                target_path = str(
+                    self.local_dataset_cache.joinpath(f"{bucket_suffix}/{prefix}")
+                )
+                os.makedirs(target_path, exist_ok=True)
+
+                if DeltaTable.is_deltatable(target_path):
+                    if self.force_renew:
+                        shutil.rmtree(target_path)
+                        logger.warning(
+                            f"Regenerating {suffix} dataset at {target_path}..."
+                        )
+                        os.makedirs(target_path, exist_ok=True)
+                    else:
+                        logger.warning(
+                            f"Dataset for {suffix} already exists at {target_path}, returning existing dataset."
+                        )
+                        logger.info(
+                            "Delete or move the existing dataset, or re-run the search query "
+                            "with MPRester(force_renew=True) to refresh the local dataset.",
+                        )
+
+                        return {
+                            "data": MPDataset(
+                                path=target_path,
+                                document_model=self.document_model,
+                                use_document_model=self.use_document_model,
+                            )
+                        }
+
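+                # Open the remote Delta table with unsigned (anonymous) reads of
+                # the public Materials Project bucket; no AWS credentials needed.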
+                tbl = DeltaTable(
+                    f"s3a://{bucket}/{prefix}",
+                    storage_options={
+                        "AWS_SKIP_SIGNATURE": "true",
+                        "AWS_REGION": "us-east-1",
+                    },
+                )
+
+                controlled_batch_str = ",".join(
+                    [f"'{tag}'" for tag in self.access_controlled_batch_ids]
+                )
+
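+                # Filter access-controlled batches out of the SQL scan unless
+                # this key has been granted access.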
+                predicate = (
+                    f"WHERE batch_id NOT IN ({controlled_batch_str})"
+                    if not has_gnome_access
+                    else ""
+                )
+
+                builder = QueryBuilder().register("tbl", tbl)
+
+                # Setup progress bar
+                num_docs_needed = tbl.count()
+
+                if not has_gnome_access:
+                    num_docs_needed = self.count(
+                        {"batch_id_neq_any": self.access_controlled_batch_ids}
+                    )
+
+                pbar = (
+                    tqdm(
+                        desc=pbar_message,
+                        total=num_docs_needed,
+                    )
+                    if not self.mute_progress_bars
+                    else None
+                )
+
+                iterator = builder.execute(f"SELECT * FROM tbl {predicate}")
+
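+                # Write zstd-compressed parquet to keep the local cache compact.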
+                file_options = ds.ParquetFileFormat().make_write_options(
+                    compression="zstd"
+                )
+
+                def _flush(
+                    accumulator: list[pa.RecordBatch], group: int, schema: pa.Schema
+                ):
+                    # Somewhere after datafusion 51.0.0 and arrow-rs 57.0.0,
+                    # results began arriving as *View types; cast back to the
+                    # base schema since pyarrow's *View support is still behind.
+                    tbl = (
+                        pa.Table.from_batches(accumulator)
+                        .select(schema.names)
+                        .cast(target_schema=schema)
+                    )
+
+                    ds.write_dataset(
+                        tbl,
+                        base_dir=target_path,
+                        format="parquet",
+                        basename_template=f"group-{group}-" + "part-{i}.zstd.parquet",
+                        existing_data_behavior="overwrite_or_ignore",
+                        max_rows_per_group=1024,
+                        file_options=file_options,
+                    )
+
+                group = 1
+                size = 0
+                accumulator = []
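+                # Target Arrow schema derived from the pydantic document model.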
+                schema = pa.schema(arrowize(self.document_model))
+                for page in iterator:
+                    # Convert the arro3 record batch to a pyarrow record batch
+                    # for compatibility with the pyarrow dataset writer.
+                    rg = pa.record_batch(page)
+                    accumulator.append(rg)
+                    page_size = page.num_rows
+                    size += rg.get_total_buffer_size()
+
+                    if pbar is not None:
+                        pbar.update(page_size)
+
+                    if size >= MAPI_CLIENT_SETTINGS.DATASET_FLUSH_THRESHOLD:
+                        _flush(accumulator, group, schema)
+                        group += 1
+                        size = 0
+                        accumulator.clear()
+
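+                # Flush any remainder that never reached the size threshold.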
+                if accumulator:
+                    _flush(accumulator, group + 1, schema)
+
+                if pbar is not None:
+                    pbar.close()
+
+                logger.info(f"Dataset for {suffix} written to {target_path}")
+                logger.info("Converting to DeltaTable...")
+
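+                # Register the parquet files just written as a Delta table in place.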
+                convert_to_deltalake(target_path)
+
+                logger.info(
+                    "Consult the delta-rs and pyarrow documentation for advanced usage: "
+                    "delta-io.github.io/delta-rs, arrow.apache.org/docs/python"
+                )
+
+                return {
+                    "data": MPDataset(
+                        path=target_path,
+                        document_model=self.document_model,
+                        use_document_model=self.use_document_model,
+                    )
+                }
+
+            # Paginate over all entries in the bucket.
+            # TODO: change when a subset of entries needed from DB
             paginator = self.s3_client.get_paginator("list_objects_v2")
             pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

@@ -519,11 +733,6 @@ def _query_resource(
             }

             # Setup progress bar
-            pbar_message = (  # type: ignore
-                f"Retrieving {self.document_model.__name__} documents"  # type: ignore
-                if self.document_model is not None
-                else "Retrieving documents"
-            )
             num_docs_needed = int(self.count())
             pbar = (
                 tqdm(
@@ -1317,6 +1526,8 @@ def __getattr__(self, v: str):
                 use_document_model=self.use_document_model,
                 headers=self.headers,
                 mute_progress_bars=self.mute_progress_bars,
+                local_dataset_cache=self.local_dataset_cache,
+                force_renew=self.force_renew,
             )
             return self.sub_resters[v]
