Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
include_directories(include)

add_subdirectory(dbconnector)
add_subdirectory(storage)

add_library(
Expand Down
16 changes: 16 additions & 0 deletions src/dbconnector/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
set(POSTGRES_DBCONNECTOR_PATH
"${CMAKE_CURRENT_LIST_DIR}/../../database-connector")

add_library(
postgres_ext_dbconnector OBJECT
${POSTGRES_DBCONNECTOR_PATH}/src/query/query_writer.cpp
# ${POSTGRES_DBCONNECTOR_PATH}/src/table_scan/filter_pushdown.cpp
# ${POSTGRES_DBCONNECTOR_PATH}/src/table_scan/filter_util.cpp
# ${POSTGRES_DBCONNECTOR_PATH}/src/optimizer/aggregate_optimizer.cpp
# ${POSTGRES_DBCONNECTOR_PATH}/src/optimizer/optimizer_util.cpp
# ${POSTGRES_DBCONNECTOR_PATH}/src/optimizer/order_by_and_limit_optimizer.cpp
)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:postgres_ext_dbconnector>
PARENT_SCOPE)
4 changes: 2 additions & 2 deletions src/include/postgres_binary_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ class PostgresBinaryWriter {

void WriteArray(Vector &col, idx_t r, const vector<uint32_t> &dimensions, idx_t depth, uint32_t count) {
auto list_data = FlatVector::GetData<list_entry_t>(col);
auto &child_vector = ListVector::GetEntry(col);
auto &child_vector = ListVector::GetChildMutable(col);
for (idx_t i = 0; i < count; i++) {
auto list_entry = list_data[r + i];
if (list_entry.length != dimensions[depth]) {
Expand Down Expand Up @@ -398,7 +398,7 @@ class PostgresBinaryWriter {
while (current_vector.get().GetType().id() == LogicalTypeId::LIST) {
auto current_entry = FlatVector::GetData<list_entry_t>(current_vector.get())[current_position];
dimensions.push_back(current_entry.length);
current_vector = ListVector::GetEntry(current_vector.get());
current_vector = ListVector::GetChild(current_vector.get());
current_position = current_entry.offset;
}

Expand Down
6 changes: 6 additions & 0 deletions src/include/postgres_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ class PostgresUtils {

static string EscapeConnectionString(const string &input);
static string ExtractConnectionOption(const KeyValueSecret &kv_secret, const string &name);
static string WriteLiteral(const string &identifier);
static string WriteIdentifier(const string &identifier);

private:
static string EscapeQuotes(const string &text, char quote);
static string WriteQuoted(const string &text, char quote);
};

} // namespace duckdb
14 changes: 8 additions & 6 deletions src/postgres_attach.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#include "duckdb.hpp"

#include "duckdb/parser/parsed_data/create_table_function_info.hpp"

#include "postgres_filter_pushdown.hpp"
#include "postgres_scanner.hpp"
#include "postgres_result.hpp"
#include "postgres_utils.hpp"

namespace duckdb {

Expand Down Expand Up @@ -57,7 +59,7 @@ WHERE relkind = 'r' AND attnum > 0 AND nspname = %s
GROUP BY relname
ORDER BY relname;
)",
KeywordHelper::WriteQuoted(data.source_schema));
PostgresUtils::WriteLiteral(data.source_schema));
auto res = conn.Query(context, fetch_table_query);
for (idx_t row = 0; row < PQntuples(res->res); row++) {
auto table_name = res->GetString(row, 0);
Expand All @@ -68,21 +70,21 @@ ORDER BY relname;
query = "CREATE VIEW IF NOT EXISTS ";
}
if (!data.sink_schema.empty()) {
query += KeywordHelper::WriteQuoted(data.sink_schema, '"') + ".";
query += PostgresUtils::WriteIdentifier(data.sink_schema) + ".";
}
query += KeywordHelper::WriteQuoted(table_name, '"');
query += PostgresUtils::WriteIdentifier(table_name);
query += " AS SELECT * FROM ";
if (data.filter_pushdown) {
query += "postgres_scan_pushdown";
} else {
query += "postgres_scan";
}
query += "(";
query += KeywordHelper::WriteQuoted(data.dsn);
query += PostgresUtils::WriteLiteral(data.dsn);
query += ", ";
query += KeywordHelper::WriteQuoted(data.source_schema);
query += PostgresUtils::WriteLiteral(data.source_schema);
query += ", ";
query += KeywordHelper::WriteQuoted(table_name);
query += PostgresUtils::WriteLiteral(table_name);
query += ");";
dconn.Query(query);
}
Expand Down
7 changes: 4 additions & 3 deletions src/postgres_aws.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
#include <mutex>
#include <stdexcept>

#include "duckdb/parser/keyword_helper.hpp"

#include "dbconnector/defer.hpp"

#include "duckdb/parser/keyword_helper.hpp"
#include "postgres_secrets.hpp"
#include "postgres_utils.hpp"

Expand All @@ -29,7 +30,7 @@ static std::string MakeCreateSecretQuery(const std::string &template_secret_name
for (auto &en : kv_secret.secret_map) {
query += " " + en.first + " " + en.second.ToSQLString() + ",\n";
}
query += " RDS_TEMPLATE_SECRET_NAME " + KeywordHelper::WriteQuoted(template_secret_name, '\'') + "\n";
query += " RDS_TEMPLATE_SECRET_NAME " + PostgresUtils::WriteLiteral(template_secret_name) + "\n";
query += ")";
return query;
}
Expand Down Expand Up @@ -71,7 +72,7 @@ std::string PostgresAws::GenerateRdsAuthToken(AttachedDatabase &attached_db,
std::string create_secret_query =
MakeCreateSecretQuery(token_config.rds_secret_name, secret_name, kv_template_secret);

std::string quoted_secret_name = KeywordHelper::WriteQuoted(secret_name, '"');
std::string quoted_secret_name = PostgresUtils::WriteIdentifier(secret_name);
RunQuery(conn, create_secret_query, "error creating RDS secret from template: " + token_config.rds_secret_name);
auto deferred_drop_secret =
dbconnector::Defer([&conn, quoted_secret_name] { conn.Query("DROP SECRET " + quoted_secret_name); });
Expand Down
9 changes: 5 additions & 4 deletions src/postgres_copy_to.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "duckdb/common/vector/list_vector.hpp"
#include "duckdb/common/vector/map_vector.hpp"
#include "duckdb/common/vector/struct_vector.hpp"

#include "postgres_connection.hpp"
#include "postgres_binary_writer.hpp"
#include "postgres_text_writer.hpp"
Expand Down Expand Up @@ -31,16 +32,16 @@ void PostgresConnection::BeginCopyTo(ClientContext &context, PostgresCopyState &
const vector<string> &column_names) {
string query = "COPY ";
if (!schema_name.empty()) {
query += KeywordHelper::WriteQuoted(schema_name, '"') + ".";
query += PostgresUtils::WriteIdentifier(schema_name) + ".";
}
query += KeywordHelper::WriteQuoted(table_name, '"') + " ";
query += PostgresUtils::WriteIdentifier(table_name) + " ";
if (!column_names.empty()) {
query += "(";
for (idx_t c = 0; c < column_names.size(); c++) {
if (c > 0) {
query += ", ";
}
query += KeywordHelper::WriteQuoted(column_names[c], '"');
query += PostgresUtils::WriteIdentifier(column_names[c]);
}
query += ") ";
}
Expand Down Expand Up @@ -172,7 +173,7 @@ void CastToPostgresVarchar(ClientContext &context, Vector &input, Vector &result

void CastListToPostgresArray(ClientContext &context, Vector &input, Vector &varchar_vector, idx_t size) {
// cast child list
auto &child_data = ListVector::GetEntry(input);
auto &child_data = ListVector::GetChildMutable(input);
auto child_count = ListVector::GetListSize(input);
bool skip_quoting = child_data.GetType().id() == LogicalTypeId::LIST; // Do not quote dimensions in multi-D arrays
Vector child_varchar(LogicalType::VARCHAR, child_count);
Expand Down
9 changes: 6 additions & 3 deletions src/postgres_filter_pushdown.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "postgres_filter_pushdown.hpp"

#include "duckdb/parser/keyword_helper.hpp"
#include "duckdb/function/scalar/struct_utils.hpp"
#include "duckdb/planner/expression/bound_comparison_expression.hpp"
Expand All @@ -9,6 +10,8 @@
#include "duckdb/planner/filter/table_filter_functions.hpp"
#include "duckdb/common/enum_util.hpp"

#include "postgres_utils.hpp"

namespace duckdb {

string PostgresFilterPushdown::CreateExpression(const string &column_name,
Expand Down Expand Up @@ -64,7 +67,7 @@ string TransformLiteral(const Value &val) {
case LogicalTypeId::BLOB:
return TransformBlob(StringValue::Get(val));
default:
return KeywordHelper::WriteQuoted(val.ToString());
return PostgresUtils::WriteLiteral(val.ToString());
}
}

Expand Down Expand Up @@ -103,7 +106,7 @@ string PostgresFilterPushdown::TransformExpressionSubject(const string &column_n
if (struct_type.id() != LogicalTypeId::STRUCT || StructType::IsUnnamed(struct_type)) {
return string();
}
auto child_name = KeywordHelper::WriteQuoted(StructType::GetChildName(struct_type, child_idx), '\"');
auto child_name = PostgresUtils::WriteIdentifier(StructType::GetChildName(struct_type, child_idx));
return "(" + parent_name + ")." + child_name;
}
default:
Expand Down Expand Up @@ -218,7 +221,7 @@ string PostgresFilterPushdown::TransformFilters(const vector<column_t> &column_i
if (IsVirtualColumn(column_id)) {
column_name = "ctid";
} else {
column_name = KeywordHelper::WriteQuoted(names[column_id], '"');
column_name = PostgresUtils::WriteIdentifier(names[column_id]);
}
auto &filter = entry.Filter();
auto filter_text = TransformFilter(column_name, filter, column_id);
Expand Down
7 changes: 4 additions & 3 deletions src/postgres_scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "duckdb/common/shared_ptr.hpp"
#include "duckdb/common/helper.hpp"
#include "duckdb/parser/parsed_data/create_table_function_info.hpp"

#include "postgres_oauth.hpp"
#include "postgres_filter_pushdown.hpp"
#include "postgres_scanner.hpp"
Expand Down Expand Up @@ -243,7 +244,7 @@ static void PostgresInitInternal(ClientContext &context, const PostgresBindData
col_names += "ctid";
}
} else {
col_names += KeywordHelper::WriteQuoted(bind_data->names[column_id], '"');
col_names += PostgresUtils::WriteIdentifier(bind_data->names[column_id]);
if (bind_data->postgres_types[column_id].info == PostgresTypeAnnotation::CAST_TO_VARCHAR) {
col_names += "::VARCHAR";
} else if (bind_data->types[column_id].id() == LogicalTypeId::LIST) {
Expand Down Expand Up @@ -290,8 +291,8 @@ static void PostgresInitInternal(ClientContext &context, const PostgresBindData

} else {
query = StringUtil::Format(R"(SELECT %s FROM %s.%s %s%s)", col_names,
KeywordHelper::WriteQuoted(bind_data->schema_name, '"'),
KeywordHelper::WriteQuoted(bind_data->table_name, '"'), filter, bind_data->limit);
PostgresUtils::WriteIdentifier(bind_data->schema_name),
PostgresUtils::WriteIdentifier(bind_data->table_name), filter, bind_data->limit);
}
if (!bind_data->use_text_protocol) {
query = StringUtil::Format(R"(COPY (%s) TO STDOUT (FORMAT "binary");)", query);
Expand Down
14 changes: 7 additions & 7 deletions src/postgres_text_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct PostgresListParser {

void AddString(const string &str, bool &quoted) {
if (size >= capacity) {
vector.Resize(capacity, capacity * 2);
vector.Reserve(capacity * 2);
capacity *= 2;
}
if (!quoted && str == "NULL") {
Expand Down Expand Up @@ -204,7 +204,7 @@ void ParsePostgresCTID(PostgresCTIDParser &ctid_parser, string_t list) {
void PostgresTextReader::ConvertList(Vector &source, Vector &target, const PostgresType &postgres_type, idx_t count) {
// lists have the format {1, 2, 3}
UnifiedVectorFormat vdata;
source.ToUnifiedFormat(count, vdata);
source.ToUnifiedFormat(vdata);

auto strings = UnifiedVectorFormat::GetData<string_t>(vdata);
auto list_data = FlatVector::GetDataMutable<list_entry_t>(target);
Expand All @@ -221,7 +221,7 @@ void PostgresTextReader::ConvertList(Vector &source, Vector &target, const Postg
list_data[i].length = list_parser.size - list_data[i].offset;
}
if (list_parser.size > 0) {
auto &target_child = ListVector::GetEntry(target);
auto &target_child = ListVector::GetChildMutable(target);
ListVector::Reserve(target, list_parser.size);
ConvertVector(list_parser.vector, target_child,
postgres_type.children.empty() ? PostgresType() : postgres_type.children[0], list_parser.size);
Expand All @@ -232,7 +232,7 @@ void PostgresTextReader::ConvertList(Vector &source, Vector &target, const Postg
void PostgresTextReader::ConvertStruct(Vector &source, Vector &target, const PostgresType &postgres_type, idx_t count) {
// structs have the format (1, 2, 3)
UnifiedVectorFormat vdata;
source.ToUnifiedFormat(count, vdata);
source.ToUnifiedFormat(vdata);
auto strings = UnifiedVectorFormat::GetData<string_t>(vdata);
auto &children = StructVector::GetEntries(target);

Expand All @@ -257,7 +257,7 @@ void PostgresTextReader::ConvertStruct(Vector &source, Vector &target, const Pos
void PostgresTextReader::ConvertCTID(Vector &source, Vector &target, idx_t count) {
// ctids have the format (page_index, row_in_page)
UnifiedVectorFormat vdata;
source.ToUnifiedFormat(count, vdata);
source.ToUnifiedFormat(vdata);
auto strings = UnifiedVectorFormat::GetData<string_t>(vdata);
auto result = FlatVector::GetDataMutable<int64_t>(target);

Expand All @@ -278,7 +278,7 @@ void PostgresTextReader::ConvertCTID(Vector &source, Vector &target, idx_t count
void PostgresTextReader::ConvertBlob(Vector &source, Vector &target, idx_t count) {
// ctids have the format (page_index, row_in_page)
UnifiedVectorFormat vdata;
source.ToUnifiedFormat(count, vdata);
source.ToUnifiedFormat(vdata);
auto strings = UnifiedVectorFormat::GetData<string_t>(vdata);
auto result = FlatVector::GetDataMutable<string_t>(target);

Expand Down Expand Up @@ -311,7 +311,7 @@ static void ConvertGeometry(Vector &source, Vector &target, idx_t count) {
// Geometry is encoded in HEXWKB format

UnifiedVectorFormat vdata;
source.ToUnifiedFormat(count, vdata);
source.ToUnifiedFormat(vdata);
const auto strings = UnifiedVectorFormat::GetData<string_t>(vdata);
const auto result = FlatVector::GetDataMutable<string_t>(target);

Expand Down
35 changes: 33 additions & 2 deletions src/postgres_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "postgres_utils.hpp"

#include "storage/postgres_schema_entry.hpp"
#include "storage/postgres_transaction.hpp"
#include "postgres_type_oids.hpp"
Expand All @@ -14,7 +15,7 @@ PGconn *PostgresUtils::PGConnect(const string &dsn, const string &attach_path) {
// both PQStatus and PQerrorMessage check for nullptr
if (PQstatus(conn) == CONNECTION_BAD) {
char *msg_cstr = PQerrorMessage(conn);
std::string msg = msg_cstr != nullptr ? std::string(msg_cstr) : std::string();
string msg = msg_cstr != nullptr ? string(msg_cstr) : string();
PQfinish(conn);
throw IOException("Unable TODO:REMOVEME to connect to Postgres at \"%s\": %s", attach_path, msg);
}
Expand Down Expand Up @@ -539,7 +540,10 @@ PostgresVersion PostgresUtils::ExtractPostgresVersion(const string &version_str)
}

string PostgresUtils::QuotePostgresIdentifier(const string &text) {
return KeywordHelper::WriteOptionallyQuoted(text, '"', false);
if (!KeywordHelper::RequiresQuotes(text, false)) {
return text;
}
return PostgresUtils::WriteIdentifier(text);
}

string PostgresUtils::EscapeConnectionString(const string &input) {
Expand Down Expand Up @@ -571,4 +575,31 @@ string PostgresUtils::ExtractConnectionOption(const KeyValueSecret &kv_secret, c
return result;
}

string PostgresUtils::EscapeQuotes(const string &text, char quote) {
string result;
for (auto c : text) {
if (c == quote) {
result += quote;
result += quote;
} else if (c == '\\') {
result += "\\\\";
} else {
result += c;
}
}
return result;
}

string PostgresUtils::WriteQuoted(const string &text, char quote) {
return string(1, quote) + EscapeQuotes(text, quote) + string(1, quote);
}

string PostgresUtils::WriteLiteral(const string &literal) {
return WriteQuoted(literal, '\'');
}

string PostgresUtils::WriteIdentifier(const string &identifier) {
return WriteQuoted(identifier, '"');
}

} // namespace duckdb
5 changes: 3 additions & 2 deletions src/storage/postgres_catalog_set.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "storage/postgres_catalog_set.hpp"

#include "storage/postgres_transaction.hpp"
#include "duckdb/parser/parsed_data/drop_info.hpp"
#include "storage/postgres_schema_entry.hpp"
Expand Down Expand Up @@ -66,9 +67,9 @@ void PostgresCatalogSet::DropEntry(PostgresTransaction &transaction, DropInfo &i
drop_query += " IF EXISTS ";
}
if (!info.schema.empty()) {
drop_query += KeywordHelper::WriteQuoted(info.schema, '"') + ".";
drop_query += PostgresUtils::WriteIdentifier(info.schema) + ".";
}
drop_query += KeywordHelper::WriteQuoted(info.name, '"');
drop_query += PostgresUtils::WriteIdentifier(info.name);
if (info.cascade) {
drop_query += "CASCADE";
}
Expand Down
Loading
Loading