Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ jobs:
matrix:
python-version: ["3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pylint
- name: Analysing the code with pylint
run: |
pylint --disable=C,R,E0401,W0107,W0613,W0612,W0221 $(git ls-files '*.py')
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pylint
- name: Analysing the code with pylint
run: |
pylint --disable=C,R,E0401,W0107,W0613,W0612,W0221,W0212 $(git ls-files '*.py')
101 changes: 48 additions & 53 deletions commands/read_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

def fetch_sql_commands_from_file(file, limit, offset):
"""
Fetch SQL commands from a file until the specified limit, starting from the given byte offset.
Ignores lines containing 'BEGIN TRANSACTION;' and 'COMMIT;'.
Fetch SQL commands from a file using sqlparse for robust parsing.
Uses command-based chunking instead of byte-based to avoid cutting commands in half.

Args:
file (file object): The file object opened for reading.
Expand All @@ -17,57 +17,52 @@ def fetch_sql_commands_from_file(file, limit, offset):
list: Fetched SQL commands from the file.
int: Position in the file after reading (byte offset).
"""
file.seek(offset)
commands = []
command = ""
in_string = False

while True:
line_start_offset = file.tell()
line = file.readline()
if not line:
break

if line.upper() in ['BEGIN TRANSACTION;', 'COMMIT;']:
continue

i = 0
while i < len(line):
char = line[i]

if in_string:
if char == "'":
if i + 1 < len(line) and line[i + 1] == char:
command += "''"
i += 1
else:
in_string = False
command += char
else:
command += char
else:
if char == "'":
in_string = True

command += char

if char == ';' and not in_string:
commands.append(command)
command = ""

if len(commands) >= limit:
current_position = file.tell()
return commands, current_position

i += 1

current_position = file.tell()

if command.strip():
file.seek(line_start_offset)
current_position = line_start_offset

return commands, current_position
import utils_sql as sql

# If this is the first call (offset = 0), parse the entire file once
# and store commands in a global cache to avoid re-parsing
if not hasattr(fetch_sql_commands_from_file, '_cached_commands'):
file.seek(0)
content = file.read()

# Filter out transaction control statements
lines = content.split('\n')
filtered_lines = []

for line in lines:
stripped = line.strip().upper()
if stripped not in ['BEGIN TRANSACTION;', 'COMMIT;', 'BEGIN;', 'COMMIT']:
filtered_lines.append(line)

filtered_content = '\n'.join(filtered_lines)

# Use sqlparse for robust parsing (same as execute() method)
try:
all_commands = sql.sql_to_list(filtered_content)
fetch_sql_commands_from_file._cached_commands = all_commands
fetch_sql_commands_from_file._command_index = 0
except Exception as e: # pylint: disable=broad-exception-caught
# Fallback to simple splitting if sqlparse fails
simple_commands = [cmd.strip() for cmd in filtered_content.split(';') if cmd.strip()]
fetch_sql_commands_from_file._cached_commands = simple_commands
fetch_sql_commands_from_file._command_index = 0

# Return the next batch of commands
start_idx = fetch_sql_commands_from_file._command_index
end_idx = start_idx + limit if limit else len(fetch_sql_commands_from_file._cached_commands)

commands = fetch_sql_commands_from_file._cached_commands[start_idx:end_idx]

# Update index for next call
fetch_sql_commands_from_file._command_index = end_idx

# Calculate new position (approximate)
if commands:
new_position = offset + sum(len(cmd.encode('utf-8')) for cmd in commands)
else:
new_position = offset

return commands, new_position

def limit_estimation(rows, max_chunk_size_bytes, margin):
chunk_size = int(utils.total_size(rows) // len(rows))
Expand Down
81 changes: 66 additions & 15 deletions edgesql.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,34 @@ def execute(self, buffer):
result_data = json_data.get('data', [])
if result_data:
query_result = result_data[0]
if 'error' in query_result:
error_msg = f"{query_result.get('error')}. statusCode={response.status_code}"
result['error'] = error_msg
result['command'] = json.dumps(sql_commands)
else:
# Check for error at query_result level (multiple possible field names)
error_found = False
for error_field in ['error', 'message', 'detail']:
if error_field in query_result:
error_msg = f"{query_result.get(error_field)}. statusCode={response.status_code}"
result['error'] = error_msg
result['command'] = json.dumps(sql_commands)
error_found = True
break

if not error_found:
results = query_result.get('results', {})
columns = results.get('columns', [])
rows = results.get('rows', [])
result['data'] = {'columns': columns, 'rows': rows}
result['success'] = True
# Check if error is inside results field (multiple possible field names)
results_error_found = False
if isinstance(results, dict):
for error_field in ['error', 'message', 'detail']:
if error_field in results:
error_msg = f"{results.get(error_field)}. statusCode={response.status_code}"
result['error'] = error_msg
result['command'] = json.dumps(sql_commands)
results_error_found = True
break

if not results_error_found:
columns = results.get('columns', [])
rows = results.get('rows', [])
result['data'] = {'columns': columns, 'rows': rows}
result['success'] = True
else:
error_msg = "Empty or invalid response data."
result['error'] = error_msg
Expand Down Expand Up @@ -163,7 +181,11 @@ def list_databases(self):
raise ValueError(f"Error decoding JSON response: {e}. statusCode={response.status_code}. Response content: {response.text[:200]}") from e

if response.status_code == HTTPStatus.OK: # 200
databases = json_data.get('results', [])
databases = json_data.get('results')
if databases is None and isinstance(json_data.get('data'), dict):
databases = json_data.get('data', {}).get('results')
if databases is None:
databases = []
db_list = {
'databases': [
(db.get('id'), db.get('name'), db.get('status'), db.get('active'), db.get('last_modified'), db.get('last_editor'), db.get('product_version'))
Expand Down Expand Up @@ -212,12 +234,37 @@ def set_current_database(self, database_name):

if response.status_code == HTTPStatus.OK: # 200
databases = json_data.get('results')
if databases is None and isinstance(json_data.get('data'), dict):
databases = json_data.get('data', {}).get('results')
if databases:
for db in json_data['results']:
if db['name'] == database_name:
self._current_database_id = db['id']
self._current_database_name = db['name']
# 1) Exact match by name
for db in databases:
if db.get('name') == database_name:
self._current_database_id = db.get('id')
self._current_database_name = db.get('name')
return True

# 2) Exact match by ID (if the user typed a number)
if str(database_name).isdigit():
requested_id = int(database_name)
for db in databases:
if db.get('id') == requested_id:
self._current_database_id = db.get('id')
self._current_database_name = db.get('name')
return True

# 3) Unique partial match by name (non-destructive, but avoid ambiguity)
matches = [db for db in databases if isinstance(db.get('name'), str) and database_name in db.get('name')]
if len(matches) == 1:
db = matches[0]
self._current_database_id = db.get('id')
self._current_database_name = db.get('name')
return True
if len(matches) > 1:
match_names = ", ".join([m.get('name') for m in matches if m.get('name')])
utils.write_output(f"Ambiguous database selector '{database_name}'. Matches: {match_names}")
return False

utils.write_output(f"Database '{database_name}' not found.")
else:
msg_err = json_data.get('detail', 'Unknown error')
Expand Down Expand Up @@ -245,7 +292,11 @@ def get_database_id(self, database_name):
raise ValueError(f"Error decoding JSON response: {e}. statusCode={response.status_code}") from e

if response.status_code == HTTPStatus.OK: # 200
databases = json_data.get('results',[])
databases = json_data.get('results')
if databases is None and isinstance(json_data.get('data'), dict):
databases = json_data.get('data', {}).get('results')
if databases is None:
databases = []
if databases:
for db in databases: #json_data['results']:
if db.get('name') == database_name:
Expand Down