Skip to content

Commit 124205d

Browse files
author
Tim Band
committed
install-stories command allows story file to specify its configuration
1 parent 8d43205 commit 124205d

12 files changed

Lines changed: 642 additions & 87 deletions

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,4 +151,4 @@ src-stats.yaml
151151
config.yaml
152152
*.yaml.gz
153153

154-
*_stories.py
154+
./*_stories.py

datafaker/install.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
"""Functions to install Python file references in ``config.yaml``."""
2+
from collections.abc import Mapping, MutableMapping, Sequence
3+
from inspect import Parameter, signature
4+
from pathlib import Path
5+
from typing import Any
6+
7+
from datafaker.utils import import_file, logger
8+
9+
10+
def _make_where_from_annotation(
11+
query_def: Mapping[str, Any],
12+
fn_name: str,
13+
param_name: str,
14+
) -> str:
15+
"""Make a where clause from ``query`` value from the annotation."""
16+
if "where" not in query_def:
17+
return ""
18+
w = query_def["where"]
19+
if isinstance(w, str):
20+
return f" WHERE {w}"
21+
if isinstance(w, Sequence):
22+
return " WHERE " + " AND ".join(f'"({clause})"' for clause in w)
23+
logger.warning(
24+
'"where" in the query annotation of parameter "%s" of function "%s"'
25+
" needs to be a string or a list of strings",
26+
param_name,
27+
fn_name,
28+
)
29+
return ""
30+
31+
32+
def _make_vars_from_annotation(
33+
query_def: Mapping[str, Any],
34+
fn_name: str,
35+
param_name: str,
36+
) -> Mapping[str, Any]:
37+
"""Make a variables dict from ``query`` value from the annotation."""
38+
if "vars" not in query_def:
39+
return {}
40+
vars_def = query_def["vars"]
41+
if isinstance(vars_def, Mapping):
42+
return vars_def
43+
if isinstance(vars_def, Sequence):
44+
return {v: v for v in query_def["vars"]}
45+
logger.warning(
46+
'"vars" in the query annotation of parameter "%s" of function "%s"'
47+
" needs to be a list of strings or a dict of strings to strings",
48+
param_name,
49+
fn_name,
50+
)
51+
return {}
52+
53+
54+
def _add_count_vars_from_annotation(
55+
group_vars_out: MutableMapping[str, Any],
56+
query_def: Mapping[str, Any],
57+
fn_name: str,
58+
param_name: str,
59+
) -> None:
60+
"""Add ``GROUP BY`` clauses from ``count_vars``."""
61+
if "count_vars" not in query_def:
62+
return
63+
cntv = query_def["count_vars"]
64+
if isinstance(cntv, Mapping):
65+
group_vars_out.update({k: f"COUNT({v})" for k, v in cntv})
66+
return
67+
logger.warning(
68+
'"count_vars" needs to be a dict in the annotation for parameter %s of function %s',
69+
param_name,
70+
fn_name,
71+
)
72+
73+
74+
def _add_ms_vars_from_annotation(
75+
group_vars_out: MutableMapping[str, Any],
76+
query_def: Mapping[str, Any],
77+
fn_name: str,
78+
param_name: str,
79+
) -> None:
80+
"""Add ``GROUP BY`` clauses from ``ms_vars``."""
81+
if "ms_vars" not in query_def:
82+
return
83+
msv = query_def["ms_vars"]
84+
if not isinstance(msv, Mapping):
85+
logger.warning(
86+
'"ms_vars" needs to be a dict in the annotation for parameter %s of function %s',
87+
param_name,
88+
fn_name,
89+
)
90+
return
91+
for k, v in msv.items():
92+
group_vars_out[k + "_count"] = f"COUNT({v})"
93+
group_vars_out[k + "_mean"] = f"AVG({v})"
94+
group_vars_out[k + "_stddev"] = f"STDDEV({v})"
95+
96+
97+
def make_query_from_annotation(
98+
annotation_data: Any,
99+
fn_name: str,
100+
param_name: str,
101+
) -> str | None:
102+
"""
103+
Make new configuration items describing a query.
104+
105+
The query's result will be passed as this parameter to this function.
106+
107+
The annotation must be a dict with the following keys:
108+
109+
``comment``: A string describing the query in natural language.
110+
111+
``query``: Either a string containing the SQL query required, or
112+
a dict containing the following keys:
113+
114+
* ``table``: The table to query. Could be "tablename AS alias" if you like.
115+
* ``vars`` (optional): Either a list of columns to extract from the table(s),
116+
or a dict of keys (the names of the keys in the dict to be passed to the
117+
annotated function) to values (the names of the columns to be extracted).
118+
At least one of ``vars``, ``ms_vars``, ``count_vars`` must be present.
119+
* ``where`` (optional): A SQL expression to filter the results.
120+
* ``count_vars`` (optional): A dict of keys to be passed to the function
121+
to values that are the names of the columns to be counted (could be
122+
``*``; if the name of a column the result will be the number of non-null
123+
entries in that column). The query will be grouped by ``vars``.
124+
* ``ms_vars`` (optional): A dict of value names to columns to be analysed.
125+
The keys to be passed to the function will be name + ``_count`` for the
126+
number of non-null values in that column, name + ``_mean`` for the
127+
average value in that column and name + ``_stddev`` for the standard
128+
deviation of values in that column.
129+
130+
:param annotation_data: The ``Annotation`` attached to the parameter.
131+
:param fn_name: The name of the function that the parameter is of.
132+
:param param_name: The name of the parameter with the annotation.
133+
:return: A mapping of new configuration items to add to the configuration,
134+
if the annotation had a well-defined query and comment value; otherwise
135+
an empty dict.
136+
"""
137+
if not isinstance(annotation_data, Sequence):
138+
return None
139+
ann = annotation_data[0]
140+
if not isinstance(ann, Mapping) or "query" not in ann:
141+
return None
142+
if isinstance(ann["query"], str):
143+
return ann["query"]
144+
query_def = ann["query"]
145+
if "table" not in query_def:
146+
logger.warning(
147+
'"table" needs to be a key in the annotation for'
148+
' the "query" value of parameter "%s" of function "%s"',
149+
param_name,
150+
fn_name,
151+
)
152+
return None
153+
table = query_def["table"]
154+
nongroup_vars = _make_vars_from_annotation(query_def, fn_name, param_name)
155+
where = _make_where_from_annotation(query_def, fn_name, param_name)
156+
group_vars: dict[str, Any] = {}
157+
_add_count_vars_from_annotation(group_vars, query_def, fn_name, param_name)
158+
_add_ms_vars_from_annotation(group_vars, query_def, fn_name, param_name)
159+
if group_vars and nongroup_vars:
160+
group_by = " GROUP BY " + ", ".join(f'"{v}"' for v in nongroup_vars)
161+
else:
162+
group_by = ""
163+
vars_exprs = ", ".join(
164+
f'{v} AS "{k}"' for k, v in {**nongroup_vars, **group_vars}.items()
165+
)
166+
return f"SELECT {vars_exprs} FROM {table}{group_by}{where}"
167+
168+
169+
def _add_kwarg(
170+
kwargs_out: dict[str, Any], fn_name: str, param: Parameter
171+
) -> list[dict[str, Any]]:
172+
"""
173+
Add a kwargs configuration and return a ``src_stats`` query item.
174+
175+
:param kwargs_out: The story generator's ``kwargs`` value to be updated.
176+
:param fn_name: The name of the story generator function.
177+
:param param: The parameter to specify.
178+
:return: A list of configuration items to add to the ``src_stats`` config, for
179+
all the queries this parameter requires.
180+
"""
181+
if param.annotation is Parameter.empty:
182+
return []
183+
meta = param.annotation.__metadata__
184+
query = make_query_from_annotation(
185+
param.annotation.__metadata__, fn_name, param.name
186+
)
187+
if query is None:
188+
return []
189+
stat_name = f"story_auto__{fn_name}__{param.name}"
190+
if "comments" in meta[0]:
191+
comments = [meta[0]["comment"]]
192+
else:
193+
comments = []
194+
ssc = {
195+
"name": stat_name,
196+
"query": query,
197+
"comments": comments,
198+
}
199+
kwargs_out[param.name] = f'SRC_STATS["{stat_name}"]["results"]'
200+
return [ssc]
201+
202+
203+
def install_stories_from(config: MutableMapping[str, Any], story_file: Path) -> bool:
204+
"""
205+
Configure datafaker with the stories in a Python file.
206+
207+
:param config: The contents of the configuration file, to be mutated.
208+
:param story_file: Path to the Python file containing the story generators.
209+
:return: True if the config was updated correctly, False if it was untouched
210+
because problems were encountered.
211+
"""
212+
story_generators: list[Mapping[str, Any]] = []
213+
src_stats = [
214+
s
215+
for s in config.get("src_stats", [])
216+
if isinstance(s, Mapping)
217+
and "name" in s
218+
and not s["name"].startswith("story_auto__")
219+
]
220+
story_module_name = story_file.stem
221+
story_module = import_file(story_file, story_module_name)
222+
for attr_name in dir(story_module):
223+
attr = getattr(story_module, attr_name)
224+
if (
225+
hasattr(attr, "__module__")
226+
and attr.__module__ == story_module_name
227+
and not attr_name.startswith("_")
228+
and callable(attr)
229+
):
230+
kwargs: dict[str, None] = {}
231+
sig = signature(attr)
232+
for param in sig.parameters.values():
233+
src_stats += _add_kwarg(kwargs, attr, param)
234+
story_generators.append(
235+
{
236+
"name": f"{story_module_name}.{attr_name}",
237+
"num_stories_per_pass": 1,
238+
"kwargs": kwargs,
239+
}
240+
)
241+
config["story_generators_module"] = story_module_name
242+
config["story_generators"] = story_generators
243+
config["src-stats"] = src_stats
244+
return True

0 commit comments

Comments
 (0)