Skip to content

Commit b9950f8

Browse files
authored
Merge pull request #33 from materials-data-facility/forge-dev
forge-dev
2 parents 0c3be38 + 95fd6af commit b9950f8

3 files changed: 56 additions & 33 deletions

File tree

mdf_forge/forge.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
import re
32
from urllib.parse import urlparse
43

54
import globus_sdk
@@ -420,9 +419,6 @@ def match_source_names(self, source_names):
420419
return self
421420
if isinstance(source_names, str):
422421
source_names = [source_names]
423-
# If no version supplied, add * to each source name to match all versions
424-
source_names = [(sn+"*" if re.search(".*_v[0-9]+", sn) is None else sn)
425-
for sn in source_names]
426422
# First source should be in new group and required
427423
self.match_field(field="mdf.source_name", value=source_names[0],
428424
required=True, new_group=True)
@@ -701,7 +697,7 @@ def get_dataset_version(self, source_name):
701697
int: Version of the dataset in question
702698
"""
703699

704-
hits = self.search("mdf.source_name:{}_v* AND"
700+
hits = self.search("mdf.source_name:{} AND"
705701
" mdf.resource_type:dataset".format(source_name),
706702
advanced=True, limit=2)
707703

@@ -1255,7 +1251,7 @@ def negate(self):
12551251
self.operator("NOT")
12561252
return self
12571253

1258-
def search(self, q=None, index=None, advanced=None, limit=None, info=False):
1254+
def search(self, q=None, index=None, advanced=None, limit=None, info=False, retries=3):
12591255
"""Execute a search and return the results.
12601256
12611257
Args:
@@ -1276,6 +1272,8 @@ def search(self, q=None, index=None, advanced=None, limit=None, info=False):
12761272
If **True**, search will return a tuple containing the results list
12771273
and other information about the query.
12781274
Default **False**.
1275+
retries (int): The number of times to retry a Search query if it fails.
1276+
Default 3.
12791277
12801278
Returns:
12811279
list (if info=False): The results.
@@ -1309,20 +1307,43 @@ def search(self, q=None, index=None, advanced=None, limit=None, info=False):
13091307
"limit": limit,
13101308
"offset": 0
13111309
}
1312-
res = mdf_toolbox.gmeta_pop(self.__search_client.post_search(uuid_index, qu), info=info)
1310+
tries = 0
1311+
errors = []
1312+
while True:
1313+
try:
1314+
search_res = self.__search_client.post_search(uuid_index, qu)
1315+
except globus_sdk.SearchAPIError as e:
1316+
if tries >= retries:
1317+
raise
1318+
else:
1319+
errors.append(repr(e))
1320+
except Exception as e:
1321+
if tries >= retries:
1322+
raise
1323+
else:
1324+
errors.append(repr(e))
1325+
else:
1326+
break
1327+
tries += 1
1328+
res = mdf_toolbox.gmeta_pop(search_res, info=info)
13131329
# Add additional info
13141330
if info:
13151331
res[1]["query"] = qu
13161332
res[1]["index"] = index
13171333
res[1]["index_uuid"] = uuid_index
1334+
res[1]["retries"] = tries
1335+
res[1]["errors"] = errors
13181336
return res
13191337

1320-
def aggregate(self, q=None, index=None, scroll_size=SEARCH_LIMIT):
1338+
def aggregate(self, q=None, index=None, retries=1, scroll_size=SEARCH_LIMIT):
13211339
"""Gather all results that match a specific query
13221340
13231341
Args:
13241342
q (str): The query to execute. Defaults to the current query, if any.
13251343
There must be some query to execute.
1344+
index (str): The Globus Search index to search on. Required.
1345+
retries (int): The number of times to retry a Search query if it fails.
1346+
Default 1.
13261347
scroll_size (int): Maximum number of records requested per request.
13271348
13281349
Returns:
@@ -1366,7 +1387,8 @@ def aggregate(self, q=None, index=None, scroll_size=SEARCH_LIMIT):
13661387
while True:
13671388
query = "(" + q + ') AND (mdf.scroll_id:>=%d AND mdf.scroll_id:<%d)' % (
13681389
scroll_pos, scroll_pos+scroll_width)
1369-
results, info = self.search(query, index=index, advanced=True, info=True)
1390+
results, info = self.search(query, index=index, advanced=True,
1391+
info=True, retries=retries)
13701392

13711393
# Check to make sure that all the matching records were returned
13721394
if info["total_query_matches"] <= len(results):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name='mdf_forge',
5-
version='0.6.4',
5+
version='0.6.5',
66
packages=['mdf_forge'],
77
description='Materials Data Facility python package',
88
long_description=("Forge is the Materials Data Facility Python package"

tests/test_forge.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,11 @@ def test_query_search(capsys):
150150
# Check default limits
151151
res5 = q.search("Al", index="mdf")
152152
assert len(res5) == 10
153-
res6 = q.search("mdf.source_name:nist_xps_db*", advanced=True, index="mdf")
153+
res6 = q.search("mdf.source_name:nist_xps_db", advanced=True, index="mdf")
154154
assert len(res6) == 10000
155155

156156
# Check limit correction
157-
res7 = q.search("mdf.source_name:nist_xps_db*", advanced=True, index="mdf", limit=20000)
157+
res7 = q.search("mdf.source_name:nist_xps_db", advanced=True, index="mdf", limit=20000)
158158
assert len(res7) == 10000
159159

160160
# Test index translation
@@ -184,29 +184,29 @@ def test_query_aggregate(capsys):
184184
assert "Error: No index specified" in out
185185

186186
# Basic aggregation
187-
res1 = q.aggregate("mdf.source_name:nist_xps_db*", index="mdf")
187+
res1 = q.aggregate("mdf.source_name:nist_xps_db", index="mdf")
188188
assert len(res1) > 10000
189189
assert isinstance(res1[0], dict)
190190

191191
# Multi-dataset aggregation
192-
res2 = q.aggregate("(mdf.source_name:nist_xps_db* OR mdf.source_name:khazana_vasp*)",
192+
res2 = q.aggregate("(mdf.source_name:nist_xps_db OR mdf.source_name:khazana_vasp)",
193193
index="mdf")
194194
assert len(res2) > 10000
195195
assert len(res2) > len(res1)
196196

197197
# Unnecessary aggregation fallback to .search()
198198
# Check success in Coveralls
199-
assert len(q.aggregate("mdf.source_name:khazana_vasp*")) < 10000
199+
assert len(q.aggregate("mdf.source_name:khazana_vasp")) < 10000
200200

201201

202202
def test_query_chaining():
203203
q = forge.Query(query_search_client)
204-
q.field("source_name", "cip*")
204+
q.field("source_name", "cip")
205205
q.and_join()
206206
q.field("elements", "Al")
207207
res1 = q.search(limit=10000, index="mdf")
208208
res2 = (forge.Query(query_search_client)
209-
.field("source_name", "cip*")
209+
.field("source_name", "cip")
210210
.and_join()
211211
.field("elements", "Al")
212212
.search(limit=10000, index="mdf"))
@@ -399,7 +399,7 @@ def test_forge_alt_clients():
399399
def test_forge_match_field():
400400
f = forge.Forge(index="mdf")
401401
# Basic usage
402-
f.match_field("mdf.source_name", "khazana_vasp*")
402+
f.match_field("mdf.source_name", "khazana_vasp")
403403
res1 = f.search()
404404
assert check_field(res1, "mdf.source_name", "khazana_vasp") == 0
405405
# Check that query clears
@@ -417,7 +417,8 @@ def test_forge_exclude_field():
417417
# Basic usage
418418
f.exclude_field("material.elements", "Al")
419419
f.exclude_field("", "")
420-
f.match_field("mdf.source_name", "ab_initio_solute_database*")
420+
f.match_field("mdf.source_name", "ab_initio_solute_database")
421+
f.match_field("mdf.resource_type", "record")
421422
res1 = f.search()
422423
assert check_field(res1, "material.elements", "Al") == -1
423424

@@ -509,7 +510,7 @@ def test_forge_match_source_names():
509510
def test_forge_match_ids():
510511
# Get a couple IDs
511512
f = forge.Forge(index="mdf")
512-
res0 = f.search("mdf.source_name:khazana_vasp*", advanced=True, limit=2)
513+
res0 = f.search("mdf.source_name:khazana_vasp", advanced=True, limit=2)
513514
id1 = res0[0]["mdf"]["mdf_id"]
514515
id2 = res0[1]["mdf"]["mdf_id"]
515516

@@ -653,7 +654,7 @@ def test_forge_search(capsys):
653654
assert len(res4) == 3
654655

655656
# Check reset_query
656-
f.match_field("mdf.source_name", "ta_melting*")
657+
f.match_field("mdf.source_name", "ta_melting")
657658
res5 = f.search(reset_query=False)
658659
res6 = f.search()
659660
assert all([r in res6 for r in res5]) and all([r in res5 for r in res6])
@@ -701,19 +702,19 @@ def test_forge_fetch_datasets_from_results():
701702
# Get some results
702703
f = forge.Forge(index="mdf")
703704
# Record from OQMD
704-
res01 = f.search("mdf.source_name:oqmd* AND mdf.resource_type:record", advanced=True, limit=1)
705+
res01 = f.search("mdf.source_name:oqmd AND mdf.resource_type:record", advanced=True, limit=1)
705706
# Record from OQMD with info
706-
res02 = f.search("mdf.source_name:oqmd* AND mdf.resource_type:record",
707+
res02 = f.search("mdf.source_name:oqmd AND mdf.resource_type:record",
707708
advanced=True, limit=1, info=True)
708709
# Records from JANAF
709-
res03 = f.search("mdf.source_name:khazana_vasp* AND mdf.resource_type:record",
710+
res03 = f.search("mdf.source_name:khazana_vasp AND mdf.resource_type:record",
710711
advanced=True, limit=2)
711712
# Dataset for NIST XPS DB
712-
res04 = f.search("mdf.source_name:nist_xps_db* AND mdf.resource_type:dataset", advanced=True)
713+
res04 = f.search("mdf.source_name:nist_xps_db AND mdf.resource_type:dataset", advanced=True)
713714

714715
# Get the correct dataset entries
715-
oqmd = f.search("mdf.source_name:oqmd* AND mdf.resource_type:dataset", advanced=True)[0]
716-
khazana_vasp = f.search("mdf.source_name:khazana_vasp* AND mdf.resource_type:dataset",
716+
oqmd = f.search("mdf.source_name:oqmd AND mdf.resource_type:dataset", advanced=True)[0]
717+
khazana_vasp = f.search("mdf.source_name:khazana_vasp AND mdf.resource_type:dataset",
717718
advanced=True)[0]
718719

719720
# Fetch single dataset
@@ -749,7 +750,7 @@ def test_forge_aggregate():
749750
# And returns results
750751
# And respects the reset_query arg
751752
f = forge.Forge(index="mdf")
752-
f.match_field("mdf.source_name", "nist_xps_db*")
753+
f.match_field("mdf.source_name", "nist_xps_db")
753754
res1 = f.aggregate(reset_query=False, index="mdf")
754755
assert len(res1) > 10000
755756
assert check_field(res1, "mdf.source_name", "nist_xps_db") == 0
@@ -911,10 +912,10 @@ def test_forge_http_stream(capsys):
911912

912913
def test_forge_chaining():
913914
f = forge.Forge(index="mdf")
914-
f.match_field("source_name", "cip*")
915+
f.match_field("source_name", "cip")
915916
f.match_field("material.elements", "Al")
916917
res1 = f.search()
917-
res2 = f.match_field("source_name", "cip*").match_field("material.elements", "Al").search()
918+
res2 = f.match_field("source_name", "cip").match_field("material.elements", "Al").search()
918919
assert all([r in res2 for r in res1]) and all([r in res1 for r in res2])
919920

920921

@@ -929,11 +930,11 @@ def test_forge_show_fields():
929930
def test_forge_anonymous(capsys):
930931
f = forge.Forge(anonymous=True)
931932
# Test search
932-
assert len(f.search("mdf.source_name:ab_initio_solute_database*",
933+
assert len(f.search("mdf.source_name:ab_initio_solute_database",
933934
advanced=True, limit=300)) == 300
934935

935936
# Test aggregation
936-
assert len(f.aggregate("mdf.source_name:nist_xps_db*")) > 10000
937+
assert len(f.aggregate("mdf.source_name:nist_xps_db")) > 10000
937938

938939
# Error on auth-only functions
939940
# http_download
@@ -956,7 +957,7 @@ def test_forge_anonymous(capsys):
956957
def test_get_dataset_version():
957958
# Get the version number of the OQMD
958959
f = forge.Forge()
959-
hits = f.search('mdf.source_name:oqmd_v* AND mdf.resource_type:dataset',
960+
hits = f.search('mdf.source_name:oqmd AND mdf.resource_type:dataset',
960961
advanced=True, limit=1)
961962
assert hits[0]['mdf']['version'] == f.get_dataset_version('oqmd')
962963

0 commit comments

Comments (0)