Skip to content

Commit 21258ee

Browse files
authored
Refactor DBFS CLI Put to support new backend. (#371)
Refactors the PUT methods in the CLI (without creating user-facing APIs) so that the CLI uses the new DBFS put backend rather than the create, add_block, and close methods to achieve the same result. In short, the change builds a multipart/form-data request and sends it to the /dbfs/put backend. Files of 2 GB or larger fall back to create, add_block, close (streaming upload) so that existing pipelines are not broken. The version number is not increased for this change, since there will not be a new release specific to it; it will be piggy-backed onto the next release. Fall-back logic has been added to the put_file method so that uploads larger than 2 GB automatically use streaming uploads with create, add_block, and close instead of the put API.
1 parent 6c48761 commit 21258ee

4 files changed

Lines changed: 58 additions & 16 deletions

File tree

databricks_cli/dbfs/api.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ class DbfsErrorCodes(object):
8686

8787

8888
class DbfsApi(object):
89+
MULTIPART_UPLOAD_LIMIT = 2147483648
90+
8991
def __init__(self, api_client):
9092
self.client = DbfsService(api_client)
9193

@@ -113,16 +115,24 @@ def get_status(self, dbfs_path, headers=None):
113115
json = self.client.get_status(dbfs_path.absolute_path, headers=headers)
114116
return FileInfo.from_json(json)
115117

118+
# Method makes multipart/form-data file upload for files <2GB.
119+
# Otherwise uses create, add-block, close methods for streaming upload.
116120
def put_file(self, src_path, dbfs_path, overwrite, headers=None):
117-
handle = self.client.create(dbfs_path.absolute_path, overwrite, headers=headers)['handle']
118-
with open(src_path, 'rb') as local_file:
119-
while True:
120-
contents = local_file.read(BUFFER_SIZE_BYTES)
121-
if len(contents) == 0:
122-
break
123-
# add_block should not take a bytes object.
124-
self.client.add_block(handle, b64encode(contents).decode(), headers=headers)
125-
self.client.close(handle, headers=headers)
121+
# If file size is >= 2GB, fall back to streaming upload.
122+
if os.path.getsize(src_path) < self.MULTIPART_UPLOAD_LIMIT:
123+
self.client.put(dbfs_path.absolute_path, src_path=src_path,
124+
overwrite=overwrite, headers=headers)
125+
else:
126+
handle = self.client.create(dbfs_path.absolute_path, overwrite,
127+
headers=headers)['handle']
128+
with open(src_path, 'rb') as local_file:
129+
while True:
130+
contents = local_file.read(BUFFER_SIZE_BYTES)
131+
if len(contents) == 0:
132+
break
133+
# add_block should not take a bytes object.
134+
self.client.add_block(handle, b64encode(contents).decode(), headers=headers)
135+
self.client.close(handle, headers=headers)
126136

127137
def get_file(self, dbfs_path, dst_path, overwrite, headers=None):
128138
if os.path.exists(dst_path) and not overwrite:

databricks_cli/sdk/api_client.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def close(self):
109109

110110
# helper functions starting here
111111

112-
def perform_query(self, method, path, data = {}, headers = None):
112+
def perform_query(self, method, path, data = {}, headers = None, files=None):
113113
"""set up connection and perform query"""
114114
if headers is None:
115115
headers = self.default_headers
@@ -125,8 +125,13 @@ def perform_query(self, method, path, data = {}, headers = None):
125125
resp = self.session.request(method, self.url + path, params = translated_data,
126126
verify = self.verify, headers = headers)
127127
else:
128-
resp = self.session.request(method, self.url + path, data = json.dumps(data),
129-
verify = self.verify, headers = headers)
128+
if files is None:
129+
resp = self.session.request(method, self.url + path, data = json.dumps(data),
130+
verify = self.verify, headers = headers)
131+
else:
132+
# Multipart file upload
133+
resp = self.session.request(method, self.url + path, files = files, data = data,
134+
verify = self.verify, headers = headers)
130135
try:
131136
resp.raise_for_status()
132137
except requests.exceptions.HTTPError as e:

databricks_cli/sdk/service.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
# See the License for the specific language governing permissions and
2424
# limitations under the License.
2525
#
26+
import os
27+
28+
2629
class JobsService(object):
2730
def __init__(self, client):
2831
self.client = client
@@ -519,25 +522,35 @@ def list_test(self, path, headers=None):
519522
_data['path'] = path
520523
return self.client.perform_query('GET', '/dbfs-testing/list', data=_data, headers=headers)
521524

522-
def put(self, path, contents=None, overwrite=None, headers=None):
525+
def put(self, path, contents=None, overwrite=None, headers=None, src_path=None):
523526
_data = {}
527+
_files = None
524528
if path is not None:
525529
_data['path'] = path
526530
if contents is not None:
527531
_data['contents'] = contents
528532
if overwrite is not None:
529533
_data['overwrite'] = overwrite
530-
return self.client.perform_query('POST', '/dbfs/put', data=_data, headers=headers)
534+
if src_path is not None:
535+
headers = {'Content-Type': None}
536+
filename = os.path.basename(src_path)
537+
_files = {'file': (filename, open(src_path, 'rb'), 'multipart/form-data')}
538+
return self.client.perform_query('POST', '/dbfs/put', data=_data, headers=headers, files=_files)
531539

532-
def put_test(self, path, contents=None, overwrite=None, headers=None):
540+
def put_test(self, path, contents=None, overwrite=None, headers=None, src_path=None):
533541
_data = {}
542+
_files = None
534543
if path is not None:
535544
_data['path'] = path
536545
if contents is not None:
537546
_data['contents'] = contents
538547
if overwrite is not None:
539548
_data['overwrite'] = overwrite
540-
return self.client.perform_query('POST', '/dbfs-testing/put', data=_data, headers=headers)
549+
if src_path is not None:
550+
headers = {'Content-Type': None}
551+
filename = os.path.basename(src_path)
552+
_files = {'file': (filename, open(src_path, 'rb'), 'multipart/form-data')}
553+
return self.client.perform_query('POST', '/dbfs-testing/put', data=_data, headers=headers, files=_files)
541554

542555
def mkdirs(self, path, headers=None):
543556
_data = {}

tests/dbfs/test_api.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,20 @@ def test_put_file(self, dbfs_api, tmpdir):
135135
api_mock.create.return_value = {'handle': test_handle}
136136
dbfs_api.put_file(test_file_path, TEST_DBFS_PATH, True)
137137

138+
# Should not call add-block since file is < 2GB
139+
assert api_mock.add_block.call_count == 0
140+
141+
# Files >= 2GB should use create, add_block, close stream upload.
142+
def test_put_large_file(self, dbfs_api, tmpdir):
143+
test_file_path = os.path.join(tmpdir.strpath, 'test')
144+
with open(test_file_path, 'wt') as f:
145+
f.write('test')
146+
api_mock = dbfs_api.client
147+
# Make streaming upload threshold 2 bytes for testing.
148+
dbfs_api.MULTIPART_UPLOAD_LIMIT = 2
149+
test_handle = 0
150+
api_mock.create.return_value = {'handle': test_handle}
151+
dbfs_api.put_file(test_file_path, TEST_DBFS_PATH, True)
138152
assert api_mock.add_block.call_count == 1
139153
assert test_handle == api_mock.add_block.call_args[0][0]
140154
assert b64encode(b'test').decode() == api_mock.add_block.call_args[0][1]

0 commit comments

Comments
 (0)