Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
[project]
name = "postgresql-charms-single-kernel"
description = "Shared and reusable code for PostgreSQL-related charms"
version = "16.1.12"
version = "16.2.1"
readme = "README.md"
license = {file = "LICENSE"}
authors = [
Expand All @@ -22,6 +22,9 @@ dependencies = [
"tenacity>=9.0.0",
]

[project.optional-dependencies]
postgresql = ["httpx; python_version >= '3.12'"]

Comment on lines +25 to +27
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extra deps, so not to install for PGB charms.

[build-system]
requires = ["uv_build>=0.11.0,<0.12.0"]
build-backend = "uv_build"
Expand Down
2 changes: 2 additions & 0 deletions single_kernel_postgresql/config/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
USER = "operator"
SYSTEM_USERS = [MONITORING_USER, REPLICATION_USER, REWIND_USER, USER]

API_REQUEST_TIMEOUT = 5


class Substrates(str, Enum):
"""Possible substrates."""
Expand Down
198 changes: 196 additions & 2 deletions single_kernel_postgresql/utils/__init__.py
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To keep the namespace similar.

Original file line number Diff line number Diff line change
@@ -1,3 +1,197 @@
# Copyright 2025 Canonical Ltd.
# Copyright 2022 Canonical Ltd.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Old code, so old copyright.

# See LICENSE file for licensing details.
"""Utils and helpers for PostgreSQL charms."""

"""A collection of utility functions that are used in the charm."""

import os
import pwd
import re
import secrets
import string
from asyncio import as_completed, create_task, run, wait
from contextlib import suppress
from ssl import CERT_NONE, create_default_context
from typing import Any

from httpx import AsyncClient, BasicAuth, HTTPError

from ..config.literals import API_REQUEST_TIMEOUT, Substrates


def new_password() -> str:
"""Generate a random password string.

Returns:
A random password string.
"""
choices = string.ascii_letters + string.digits
password = "".join([secrets.choice(choices) for _ in range(16)])
return password


def split_mem(mem_str) -> tuple:
"""Split a memory string into a number and a unit.

Args:
mem_str: a string representing a memory value, e.g. "1Gi"
"""
pattern = r"^(\d+)(\w+)$"
parts = re.match(pattern, mem_str)
if parts:
return parts.groups()
return None, "No unit found"


def any_memory_to_bytes(mem_str) -> int:
"""Convert a memory string to bytes.

Args:
mem_str: a string representing a memory value, e.g. "1Gi"
"""
units = {
"KI": 1024,
"K": 10**3,
"MI": 1048576,
"M": 10**6,
"GI": 1073741824,
"G": 10**9,
"TI": 1099511627776,
"T": 10**12,
}
try:
num = int(mem_str)
return num
except ValueError as e:
memory, unit = split_mem(mem_str)
unit = unit.upper()
if unit not in units:
raise ValueError(f"Invalid memory definition in '{mem_str}'") from e

num = int(memory)
return int(num * units[unit])


def any_cpu_to_cores(cpu_str) -> int:
"""Convert a CPU string to cores.

Args:
cpu_str: a string representing a CPU value, as integer or millis
"""
if cpu_str.endswith("m"):
# convert millis to cores, undercommited
return int(cpu_str[:-1]) // 1000
return int(cpu_str)


def label2name(label: str) -> str:
"""Convert a unit label (with `-`) to a unit name (with `/`).

Args:
label: The label to convert.

Returns:
The converted name.
"""
return "/".join(label.rsplit("-", 1))
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed from charms to use a single rsplit.



def render_file(
substrate: Substrates, path: str, content: str, mode: int, change_owner: bool = True
) -> None:
Comment on lines +98 to +100
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Used also by k8s to render files on the shared volume.

"""Write a content rendered from a template to a file.

Args:
substrate: Charm substrate.
path: the path to the file.
content: the data to be written to the file.
mode: access permission mask applied to the
file using chmod (e.g. 0o640).
change_owner: whether to change the file owner
to the _daemon_ user.
"""
# TODO: keep this method to use it also for generating replication configuration files and
# move it to an utils / helpers file.
# Write the content to the file.
with open(path, "w+") as file:
file.write(content)
# Ensure correct permissions are set on the file.
os.chmod(path, mode)
if change_owner:
_change_owner(substrate, path)


def create_directory(substrate: Substrates, path: str, mode: int) -> None:
"""Creates a directory.

Args:
substrate: Charm substrate.
path: the path of the directory that should be created.
mode: access permission mask applied to the
directory using chmod (e.g. 0o640).
"""
os.makedirs(path, mode=mode, exist_ok=True)
# Ensure correct permissions are set on the directory.
os.chmod(path, mode)
_change_owner(substrate, path)


def _change_owner(substrate: Substrates, path: str) -> None:
"""Change the ownership of a file or a directory to the postgres user.

Args:
substrate: Charm substrate.
path: path to a file or directory.
"""
try:
# Get the uid/gid for the _daemon_ user.
user_database = (
pwd.getpwnam("_daemon_") if substrate == Substrates.VM else pwd.getpwnam("postgres")
)
# Set the correct ownership for the file or directory.
os.chown(path, uid=user_database.pw_uid, gid=user_database.pw_gid)
except KeyError:
# Ignore non existing user error when it wasn't created yet.
pass


async def _httpx_get_request(
url: str, cafile: str, auth: BasicAuth | None = None, verify: bool = True
) -> dict[str, Any] | None:
ssl_ctx = create_default_context()
if verify:
with suppress(FileNotFoundError):
ssl_ctx.load_verify_locations(cafile=cafile)
else:
ssl_ctx.check_hostname = False
ssl_ctx.verify_mode = CERT_NONE
async with AsyncClient(auth=auth, timeout=API_REQUEST_TIMEOUT, verify=ssl_ctx) as client:
try:
return (await client.get(url)).raise_for_status().json()
except (HTTPError, ValueError):
return None


async def _async_get_request(
uri: str, endpoints: list[str], cafile: str, auth: BasicAuth | None, verify: bool = True
) -> dict[str, Any] | None:
tasks = [
create_task(_httpx_get_request(f"https://{ip}:8008{uri}", cafile, auth, verify))
for ip in endpoints
]
for task in as_completed(tasks):
if result := await task:
for task in tasks:
task.cancel()
await wait(tasks)
return result


def parallel_patroni_get_request(
uri: str,
endpoints: list[str],
cafile: str,
auth: BasicAuth | None = None,
verify: bool = True,
) -> dict[str, Any] | None:
"""Call all possible patroni endpoints in parallel."""
return run(_async_get_request(uri, endpoints, cafile, auth, verify))
107 changes: 107 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2021 Canonical Ltd.
# See LICENSE file for licensing details.

import re
from unittest.mock import mock_open, patch

from single_kernel_postgresql.config.literals import Substrates
from single_kernel_postgresql.utils import (
any_cpu_to_cores,
any_memory_to_bytes,
create_directory,
label2name,
new_password,
render_file,
)


def test_any_memory_to_bytes():
assert any_memory_to_bytes(1024) == 1024

assert any_memory_to_bytes("1KI") == 1024

try:
any_memory_to_bytes("KI")
assert False
except ValueError as e:
assert str(e) == "Invalid memory definition in 'KI'"


def test_label2name():
assert label2name("postgresql-k8s-1") == "postgresql-k8s/1"


def test_any_cpu_to_cores():
assert any_cpu_to_cores("12") == 12
assert any_cpu_to_cores("1000m") == 1


def test_new_password():
# Test the password generation twice in order to check if we get different passwords and
# that they meet the required criteria.
first_password = new_password()
assert len(first_password) == 16
assert re.fullmatch("[a-zA-Z0-9\b]{16}$", first_password) is not None

second_password = new_password()
assert re.fullmatch("[a-zA-Z0-9\b]{16}$", second_password) is not None
assert second_password != first_password


def test_render_file():
with (
patch("os.chmod") as _chmod,
patch("os.chown") as _chown,
patch("pwd.getpwnam") as _pwnam,
patch("tempfile.NamedTemporaryFile") as _temp_file,
):
# Set a mocked temporary filename.
filename = "/tmp/temporaryfilename"
_temp_file.return_value.name = filename
# Setup a mock for the `open` method.
mock = mock_open()
# Patch the `open` method with our mock.
with patch("builtins.open", mock, create=True):
# Set the uid/gid return values for lookup of 'postgres' user.
_pwnam.return_value.pw_uid = 35
_pwnam.return_value.pw_gid = 35
# Call the method using a temporary configuration file.
render_file(Substrates.VM, filename, "rendered-content", 0o640)

# Check the rendered file is opened with "w+" mode.
assert mock.call_args_list[0][0] == (filename, "w+")
# Ensure that the correct user is lookup up.
_pwnam.assert_called_with("_daemon_")
# Ensure the file is chmod'd correctly.
_chmod.assert_called_with(filename, 0o640)
# Ensure the file is chown'd correctly.
_chown.assert_called_with(filename, uid=35, gid=35)

# Test when it's requested to not change the file owner.
mock.reset_mock()
_pwnam.reset_mock()
_chmod.reset_mock()
_chown.reset_mock()
with patch("builtins.open", mock, create=True):
render_file(Substrates.VM, filename, "rendered-content", 0o640, change_owner=False)
_pwnam.assert_not_called()
_chmod.assert_called_once_with(filename, 0o640)
_chown.assert_not_called()


def test_create_directory():
with (
patch("os.chmod") as _chmod,
patch("os.chown") as _chown,
patch("os.makedirs") as _makedirs,
patch("pwd.getpwnam") as _pwnam,
):
_pwnam.return_value.pw_uid = 35
_pwnam.return_value.pw_gid = 35

create_directory(Substrates.K8S, "test", 0o640)

_makedirs.assert_called_once_with("test", mode=0o640, exist_ok=True)
_chmod.assert_called_once_with("test", 0o640)
_chown.assert_called_once_with("test", uid=35, gid=35)
_pwnam.assert_called_with("postgres")
6 changes: 3 additions & 3 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ allowlist_externals =
[testenv:format]
description = Apply coding style standards to code
commands_pre =
uv sync --active --group format
uv sync --active --group format --all-extras
commands =
uv run --active ruff check --fix {[vars]all_path}
uv run --active ruff format {[vars]all_path}
Expand All @@ -31,7 +31,7 @@ commands =
[testenv:lint]
description = Check code against coding style standards
commands_pre =
uv sync --active --group lint --group format
uv sync --active --group lint --group format --all-extras
commands =
uv lock --check
uv run --active codespell "{tox_root}" --skip "{tox_root}/.git" --skip "{tox_root}/.tox" \
Expand All @@ -45,7 +45,7 @@ commands =
[testenv:unit]
description = Run unit tests
commands_pre =
uv sync --active --group unit
uv sync --active --group unit --all-extras
commands =
uv run --active coverage run --source={[vars]src_path} \
-m pytest -v --tb native -s {posargs} {[vars]tests_path}/unit
Expand Down
Loading