Skip to content

Commit ab403ee

Browse files
committed
Add script to generate provisional nodes
1 parent 997da0a commit ab403ee

1 file changed

Lines changed: 287 additions & 0 deletions

File tree

Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
import os
2+
import re
3+
import time
4+
import logging
5+
import csv
6+
import io
7+
8+
from absl import app
9+
from absl import flags
10+
from google.cloud import spanner
11+
12+
# Command-line flags controlling the directory scan and the Spanner lookup.
FLAGS = flags.FLAGS
flags.DEFINE_string("directory", os.getcwd(),
                    "Directory to scan (default: current working directory)")
flags.DEFINE_bool("no_spanner", False, "Skip Spanner check")
flags.DEFINE_string("spanner_project", "datcom-store", "Spanner project ID")
flags.DEFINE_string("spanner_instance", "dc-kg-test", "Spanner instance ID")
flags.DEFINE_string("spanner_database", "dc_graph_prototype",
                    "Spanner database ID")

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Spanner Configuration
BATCH_SIZE = 1000  # Number of IDs to query at once

# Prefixes that mark a token in an MCF value (or a "Node:" line) as a
# reference to another node rather than a plain string.
ENTITY_PREFIXES = ("dcid:", "dcs:", "schema:")
28+
29+
30+
def strip_prefix(s):
    """Returns s with any leading dcid:/dcs:/schema: prefix removed."""
    matched = next((p for p in ENTITY_PREFIXES if s.startswith(p)), None)
    if matched is None:
        return s
    return s[len(matched):]
37+
38+
39+
def is_quoted(s):
    """Returns True if the whitespace-stripped string begins and ends with a double quote."""
    trimmed = s.strip()
    return trimmed[:1] == '"' and trimmed[-1:] == '"'
43+
44+
45+
def strip_quotes(s):
    """Returns s stripped of whitespace and of surrounding double quotes, if present."""
    trimmed = s.strip()
    if trimmed.startswith('"') and trimmed.endswith('"'):
        return trimmed[1:-1]
    return trimmed
51+
52+
53+
def check_spanner_nodes(node_ids, project, instance_id, database_id):
    """Looks up which of the given node IDs exist in the Spanner Node table.

    Args:
      node_ids: A collection of node IDs (strings) to check against Spanner.
      project: Spanner project ID.
      instance_id: Spanner instance ID.
      database_id: Spanner database ID.

    Returns:
      A set containing the node IDs that were found in the Spanner database.
    """
    found = set()
    ids = list(node_ids)

    if not ids:
        return found

    logging.info(
        f"Checking {len(ids)} potential missing nodes in Spanner...")

    try:
        database = spanner.Client(project=project).instance(
            instance_id).database(database_id)

        # A single multi-use snapshot keeps every batch reading at the same
        # timestamp, for a consistent view across the whole check.
        with database.snapshot(multi_use=True) as snapshot:
            batches = [
                ids[i:i + BATCH_SIZE] for i in range(0, len(ids), BATCH_SIZE)
            ]
            total_batches = len(batches)

            for batch_num, batch in enumerate(batches, start=1):
                # Log progress sparsely: first, last, and every 10th batch.
                if (batch_num == 1 or batch_num == total_batches or
                        batch_num % 10 == 0):
                    logging.info(
                        f"Processing batch {batch_num}/{total_batches}...")

                try:
                    rows = snapshot.execute_sql(
                        "SELECT subject_id FROM Node WHERE subject_id IN UNNEST(@ids)",
                        params={"ids": batch},
                        param_types={
                            "ids":
                                spanner.param_types.Array(
                                    spanner.param_types.STRING)
                        })

                    # Fully consume the result stream.
                    found.update(row[0] for row in rows)

                except Exception as e:
                    logging.error(f"Error in batch {batch_num}: {e}")

    except Exception as e:
        logging.error(f"Failed to connect to Spanner or create snapshot: {e}")

    return found
113+
114+
115+
def generate_provisional_nodes(scan_dir,
                               no_spanner=False,
                               spanner_project=None,
                               spanner_instance=None,
                               spanner_database=None):
    """Scans a directory of MCF files to find undefined nodes referenced in properties.

    Walks `scan_dir` for .mcf files, collecting (a) every locally defined node
    ID and (b) every referenced ID: property keys plus value tokens carrying an
    explicit dcid:/dcs:/schema: prefix. References with no local definition —
    and, unless `no_spanner`, no existing row in the Spanner Node table — are
    written out as provisional node definitions.

    Args:
      scan_dir: The local directory containing .mcf files to scan.
      no_spanner: If True, skips checking Cloud Spanner for existing nodes.
      spanner_project: Spanner project ID.
      spanner_instance: Spanner instance ID.
      spanner_database: Spanner database ID.

    Returns:
      The path to the generated provisional_nodes.mcf file.
    """
    start_time = time.time()
    root_dir = os.path.abspath(scan_dir)
    output_dir = root_dir

    defined_nodes = set()
    referenced_properties = set()
    referenced_values = set()

    # Regex to capture "Key: Value". The key is everything before the FIRST
    # colon (property names never contain a colon); the value may itself
    # contain colons, e.g. "typeOf: dcs:City".
    # (Fixes the previous pattern, whose unescaped inner quote terminated the
    # string literal and whose char class swallowed colons into the key.)
    pair_re = re.compile(r'^([^:]+):\s*(.*)$')

    logging.info(f"Scanning directory: {root_dir}")

    # Walk through the directory to process each .mcf file
    for dirpath, _, filenames in os.walk(root_dir):
        for filename in filenames:
            # Require a true .mcf extension; a substring test would also
            # match names like "x.mcf.bak" or "a.mcfx".
            if not filename.endswith(".mcf"):
                continue

            # Skip the output file itself if it already exists from a previous run
            if filename == "provisional_nodes.mcf":
                continue

            filepath = os.path.join(dirpath, filename)

            with open(filepath, 'r', encoding='utf-8') as f:
                # ID of the node block currently being parsed, or None.
                current_node_id = None

                for line in f:
                    line = line.strip()
                    # Skip blanks and //- or #-style comments.
                    if not line or line.startswith("//") or line.startswith("#"):
                        continue

                    # Check for Node definition (e.g., "Node: dcid:City")
                    if line.startswith("Node:"):
                        # Flush the previous node before starting a new one.
                        if current_node_id:
                            defined_nodes.add(current_node_id)

                        parts = line.split(":", 1)
                        if len(parts) > 1:
                            current_node_id = strip_prefix(
                                strip_quotes(parts[1]))
                        else:
                            current_node_id = None
                        continue

                    # Check for Property: Value pairs
                    match = pair_re.match(line)
                    if match:
                        key = match.group(1).strip()
                        value_str = match.group(2).strip()

                        # If explicitly defining dcid as a property, use that as the node ID
                        if key == "dcid":
                            current_node_id = strip_prefix(
                                strip_quotes(value_str))
                            continue

                        # 1. The key (property) is a reference to a Property
                        # node (e.g., "containedInPlace").
                        referenced_properties.add(strip_prefix(key))

                        # 2. The value: split on commas via csv (so quoted
                        # segments with embedded commas stay intact) and keep
                        # only tokens with explicit reference prefixes
                        # (e.g., "dcid:geoId/06").
                        reader = csv.reader(io.StringIO(value_str),
                                            skipinitialspace=True)
                        try:
                            tokens = next(reader)
                        except StopIteration:
                            tokens = []

                        for token in tokens:
                            if not token:
                                continue

                            clean_token = strip_quotes(token)

                            # Only strict prefixes are references
                            if clean_token.startswith(ENTITY_PREFIXES):
                                referenced_values.add(
                                    strip_prefix(clean_token))

                # Add the last node of the file
                if current_node_id:
                    defined_nodes.add(current_node_id)

    # Calculate initially missing nodes (referenced but not defined locally)
    missing_props = referenced_properties - defined_nodes
    missing_values = referenced_values - defined_nodes
    all_missing_local = missing_props | missing_values

    # Filter out empty strings if any
    all_missing_local = {m for m in all_missing_local if m}

    logging.info(f"Found {len(defined_nodes)} defined nodes.")
    logging.info(f"Found {len(referenced_properties)} referenced properties.")
    logging.info(f"Found {len(referenced_values)} referenced values.")
    logging.info(f"Found {len(all_missing_local)} locally missing definitions.")

    # Save locally missing nodes to a file (pre-Spanner check) for debugging/audit
    local_missing_file_path = os.path.join(output_dir,
                                           "local_missing_nodes.txt")
    with open(local_missing_file_path, "w", encoding="utf-8") as f:
        for m in sorted(all_missing_local):
            f.write(f"{m}\n")
    logging.info(
        f"Written locally missing nodes (pre-Spanner) to {local_missing_file_path}"
    )

    # Check Spanner for existence of these missing nodes
    existing_in_spanner = set()
    if not no_spanner:
        existing_in_spanner = check_spanner_nodes(all_missing_local,
                                                  spanner_project,
                                                  spanner_instance,
                                                  spanner_database)

        if existing_in_spanner:
            logging.info(
                f"Found {len(existing_in_spanner)} nodes in Spanner (will not be emitted)."
            )
    else:
        logging.info("Skipping Spanner check as requested.")

    # Final missing set = missing locally AND missing in Spanner
    final_missing = all_missing_local - existing_in_spanner

    logging.info(f"Final missing count: {len(final_missing)}")

    # Generate the provisional nodes MCF file
    output_file_path = os.path.join(output_dir, "provisional_nodes.mcf")
    with open(output_file_path, "w", encoding="utf-8") as out:
        for m in sorted(final_missing):
            # Missing property keys become Property nodes; any other missing
            # reference becomes a generic provisional node.
            if m in missing_props:
                node_type = "dcs:Property"
            else:
                node_type = "dcs:ProvisionalNode"

            node_def = f"Node: dcid:{m}\ntypeOf: {node_type}\nisProvisional: dcs:True\n\n"
            out.write(node_def)

    logging.info(f"Written missing nodes to {output_file_path}")

    end_time = time.time()
    logging.info(f"Total runtime: {end_time - start_time:.2f} seconds")
    return output_file_path
276+
277+
278+
def main(_):
    """Entry point: runs the scan using the command-line flag values."""
    result_path = generate_provisional_nodes(FLAGS.directory, FLAGS.no_spanner,
                                             FLAGS.spanner_project,
                                             FLAGS.spanner_instance,
                                             FLAGS.spanner_database)
    logging.info(f"Generated provisional nodes at: {result_path}")


if __name__ == "__main__":
    app.run(main)

0 commit comments

Comments
 (0)