Commit 0c89543

Add script to generate provisional nodes

1 parent 997da0a · commit 0c89543

1 file changed: 294 additions & 0 deletions
@@ -0,0 +1,294 @@
import os
import re
import time
import logging
import csv
import io
import sys

from absl import app
from absl import flags
from google.cloud import spanner

FLAGS = flags.FLAGS
flags.DEFINE_string("directory", os.getcwd(),
                    "Directory to scan (default: current working directory)")
flags.DEFINE_bool("no_spanner", False, "Skip Spanner check")
flags.DEFINE_string("spanner_project", "datcom-store", "Spanner project ID")
flags.DEFINE_string("spanner_instance", "dc-kg-test", "Spanner instance ID")
flags.DEFINE_string("spanner_database", "dc_graph_prototype",
                    "Spanner database ID")

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Spanner configuration.
BATCH_SIZE = 1000  # Number of IDs to query at once.

# Increase the CSV field size limit to handle large MCF values.
csv.field_size_limit(sys.maxsize)

# Prefixes that mark a token as a reference to another node.
ENTITY_PREFIXES = ("dcid:", "dcs:", "schema:")


def strip_prefix(s):
    """Strips common prefixes (dcid:, dcs:, schema:) from a string."""
    for prefix in ENTITY_PREFIXES:
        if s.startswith(prefix):
            return s[len(prefix):]
    return s


def is_quoted(s):
    """Checks if a string is surrounded by double quotes."""
    s = s.strip()
    # Require at least two characters so a lone '"' does not count as quoted.
    return len(s) >= 2 and s.startswith('"') and s.endswith('"')


def strip_quotes(s):
    """Removes surrounding double quotes from a string if present."""
    s = s.strip()
    if is_quoted(s):
        return s[1:-1]
    return s

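
# For example, the helpers above behave as follows:
#   strip_prefix("dcid:geoId/06")  -> "geoId/06"
#   is_quoted(' "Some Name" ')     -> True
#   strip_quotes(' "Some Name" ')  -> "Some Name"
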

def check_spanner_nodes(node_ids, project, instance_id, database_id):
    """Checks which of the given node_ids exist in the Spanner Node table.

    Args:
        node_ids: A collection of node IDs (strings) to check against Spanner.
        project: Spanner project ID.
        instance_id: Spanner instance ID.
        database_id: Spanner database ID.

    Returns:
        A set containing the node IDs that were found in the Spanner database.
    """
    existing_nodes = set()
    node_ids_list = list(node_ids)

    if not node_ids_list:
        return existing_nodes

    logging.info(
        f"Checking {len(node_ids_list)} potentially missing nodes in Spanner...")

    try:
        spanner_client = spanner.Client(project=project)
        instance = spanner_client.instance(instance_id)
        database = instance.database(database_id)

        # Use a single multi-use snapshot for a consistent read across batches.
        with database.snapshot(multi_use=True) as snapshot:
            total_batches = (len(node_ids_list) + BATCH_SIZE - 1) // BATCH_SIZE

            for i in range(0, len(node_ids_list), BATCH_SIZE):
                batch_num = (i // BATCH_SIZE) + 1
                if batch_num == 1 or batch_num == total_batches or batch_num % 10 == 0:
                    logging.info(f"Processing batch {batch_num}/{total_batches}...")

                batch = node_ids_list[i:i + BATCH_SIZE]

                try:
                    result = snapshot.execute_sql(
                        "SELECT subject_id FROM Node WHERE subject_id IN UNNEST(@ids)",
                        params={"ids": batch},
                        param_types={
                            "ids":
                                spanner.param_types.Array(
                                    spanner.param_types.STRING)
                        })

                    # Consume the result fully so the batch completes.
                    for row in result:
                        existing_nodes.add(row[0])

                except Exception as e:
                    logging.error(f"Error in batch {batch_num}: {e}")

    except Exception as e:
        logging.error(f"Failed to connect to Spanner or create snapshot: {e}")

    return existing_nodes

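
# For illustration (hypothetical IDs; the project/instance/database values are
# the flag defaults defined above):
#   found = check_spanner_nodes({"Count_Person", "geoId/06"},
#                               "datcom-store", "dc-kg-test",
#                               "dc_graph_prototype")
#   missing = {"Count_Person", "geoId/06"} - found
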

def generate_provisional_nodes(scan_dir,
                               no_spanner=False,
                               spanner_project=None,
                               spanner_instance=None,
                               spanner_database=None):
    """Scans a directory of MCF files to find undefined nodes referenced in properties.

    Args:
        scan_dir: The local directory containing .mcf files to scan.
        no_spanner: If True, skips checking Cloud Spanner for existing nodes.
        spanner_project: Spanner project ID.
        spanner_instance: Spanner instance ID.
        spanner_database: Spanner database ID.

    Returns:
        The path to the generated provisional_nodes.mcf file.
    """
    start_time = time.time()
    root_dir = os.path.abspath(scan_dir)
    output_dir = root_dir

    defined_nodes = set()
    referenced_properties = set()
    referenced_values = set()

    # Regex to capture "Key: Value" pairs.
    pair_re = re.compile(r"^(\w+):\s*(.*)$")

    logging.info(f"Scanning directory: {root_dir}")

    # Walk through the directory to process each .mcf file.
    for dirpath, _, filenames in os.walk(root_dir):
        for filename in filenames:
            if not filename.endswith(".mcf"):
                continue

            # Skip the output file itself if it exists from a previous run.
            if filename == "provisional_nodes.mcf":
                continue

            filepath = os.path.join(dirpath, filename)

            with open(filepath, 'r', encoding='utf-8') as f:
                current_node_id = None

                for line in f:
                    line = line.strip()
                    if not line or line.startswith("//") or line.startswith("#"):
                        continue

                    # Check for a Node definition (e.g., "Node: dcid:City").
                    if line.startswith("Node:"):
                        if current_node_id:
                            defined_nodes.add(current_node_id)

                        parts = line.split(":", 1)
                        if len(parts) > 1:
                            current_node_id = strip_prefix(strip_quotes(parts[1]))
                        else:
                            current_node_id = None
                        continue

                    # Check for "Property: Value" pairs.
                    match = pair_re.match(line)
                    if match:
                        key = match.group(1).strip()
                        value_str = match.group(2).strip()

                        # If a dcid is set explicitly as a property, use it as
                        # the node ID.
                        if key == "dcid":
                            current_node_id = strip_prefix(strip_quotes(value_str))
                            continue

                        # 1. The key (property) is a reference to a Property
                        #    node (e.g., "containedInPlace").
                        referenced_properties.add(strip_prefix(key))

                        # 2. The value: only tokens with explicit prefixes count
                        #    as references (e.g., "dcid:geoId/06"). Parse the
                        #    value as a CSV row so quoted strings containing
                        #    commas stay intact.
                        f_io = io.StringIO(value_str)
                        reader = csv.reader(f_io, skipinitialspace=True)
                        try:
                            tokens = next(reader)
                        except StopIteration:
                            tokens = []

                        for token in tokens:
                            if not token:
                                continue

                            clean_token = strip_quotes(token)

                            # Only strict prefixes are references.
                            if clean_token.startswith(ENTITY_PREFIXES):
                                referenced_values.add(strip_prefix(clean_token))

                # Add the last node of the file.
                if current_node_id:
                    defined_nodes.add(current_node_id)

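    # For reference, given an MCF stanza like (hypothetical values):
    #   Node: dcid:ca_pop
    #   typeOf: dcs:StatisticalPopulation
    #   location: dcid:geoId/06
    # the loop above records "ca_pop" as defined, "typeOf" and "location" as
    # referenced properties, and "StatisticalPopulation" and "geoId/06" as
    # referenced values.
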
    # Calculate the initially missing nodes (referenced but not defined locally).
    missing_props = referenced_properties - defined_nodes
    missing_values = referenced_values - defined_nodes
    all_missing_local = missing_props | missing_values

    # Filter out empty strings, if any.
    all_missing_local = {m for m in all_missing_local if m}

    logging.info(f"Found {len(defined_nodes)} defined nodes.")
    logging.info(f"Found {len(referenced_properties)} referenced properties.")
    logging.info(f"Found {len(referenced_values)} referenced values.")
    logging.info(f"Found {len(all_missing_local)} locally missing definitions.")

    # Save the locally missing nodes (pre-Spanner check) for debugging/audit.
    local_missing_file_path = os.path.join(output_dir, "local_missing_nodes.txt")
    with open(local_missing_file_path, "w") as f:
        for m in sorted(all_missing_local):
            f.write(f"{m}\n")
    logging.info(
        f"Wrote locally missing nodes (pre-Spanner check) to {local_missing_file_path}")

    # Check Spanner for the existence of these missing nodes.
    existing_in_spanner = set()
    if not no_spanner:
        existing_in_spanner = check_spanner_nodes(all_missing_local,
                                                  spanner_project,
                                                  spanner_instance,
                                                  spanner_database)

        if existing_in_spanner:
            logging.info(
                f"Found {len(existing_in_spanner)} nodes in Spanner (will not be emitted)."
            )
    else:
        logging.info("Skipping Spanner check as requested.")

    # Final missing set = missing locally AND missing in Spanner.
    final_missing = all_missing_local - existing_in_spanner

    logging.info(f"Final missing count: {len(final_missing)}")

    # Generate the provisional nodes MCF file.
    output_file_path = os.path.join(output_dir, "provisional_nodes.mcf")
    with open(output_file_path, "w") as out:
        for m in sorted(final_missing):
            if m in missing_props:
                node_type = "dcs:Property"
            else:
                node_type = "dcs:ProvisionalNode"

            # Write to the file directly instead of printing to stdout.
            node_def = f"Node: dcid:{m}\ntypeOf: {node_type}\nisProvisional: dcs:True\n\n"
            out.write(node_def)

    logging.info(f"Wrote missing nodes to {output_file_path}")

    end_time = time.time()
    logging.info(f"Total runtime: {end_time - start_time:.2f} seconds")
    return output_file_path

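
# The emitted provisional_nodes.mcf contains one stanza per missing node, e.g.
# (hypothetical ID):
#   Node: dcid:someUndefinedProperty
#   typeOf: dcs:Property
#   isProvisional: dcs:True
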

def main(_):
    output_path = generate_provisional_nodes(FLAGS.directory, FLAGS.no_spanner,
                                             FLAGS.spanner_project,
                                             FLAGS.spanner_instance,
                                             FLAGS.spanner_database)
    logging.info(f"Generated provisional nodes at: {output_path}")


if __name__ == "__main__":
    app.run(main)
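
A typical invocation, assuming the file is saved as generate_provisional_nodes.py (the diff header does not show the filename), would be:

    python generate_provisional_nodes.py --directory=/path/to/mcf_dir --no_spanner

Dropping --no_spanner additionally filters out nodes that already exist in the Spanner Node table, using the --spanner_project, --spanner_instance, and --spanner_database flags, which default to the values defined at the top of the script.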
