Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dgf/src/io/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def cache(

if len(found_vars) == len(variable_names):
if return_tuple:
return tuple(found_vars)
return tuple(found_vars) # pyrefly: ignore[bad-return]
else:
return found_vars[0]

Expand Down
14 changes: 7 additions & 7 deletions dgf/src/io/graph_in_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def read_graph(

# Metadata
with filesystem.open_read(os.path.join(path, FILENAME_METADATA)) as f:
metadata = gf_metadata_lib.GFGraphMetadata.from_json(f.read())
metadata = gf_metadata_lib.GFGraphMetadata.from_json(f.read()) # pyrefly: ignore[missing-attribute]

if metadata.version > MAX_SUPPORTED_GF_VERSION:
raise NotImplementedError(
Expand Down Expand Up @@ -289,7 +289,7 @@ def write_graph(
metadata = gf_metadata_lib.GFGraphMetadata(version=MAX_SUPPORTED_GF_VERSION)
metadata_path = os.path.join(path, FILENAME_METADATA)
with filesystem.open_write(metadata_path) as f:
f.write(metadata.to_json(indent=2))
f.write(metadata.to_json(indent=2)) # pyrefly: ignore[missing-attribute]

write_results = []

Expand Down Expand Up @@ -346,7 +346,7 @@ def _feature_schema_to_parquet_fields(
"""Creates the schema for the parquet node container."""
fields = []
# Note: The schema has the node "#id".
for feature_name, feature_schema in feature_schema.items():
for feature_name, feature_schema in feature_schema.items(): # pyrefly: ignore[missing-attribute]
pa_type = FEATURE_FORMAT_TO_PY_ARROW_DTYPE[feature_schema.format]
shape = feature_schema.shape
if shape is None:
Expand All @@ -364,7 +364,7 @@ def _node_schema_to_parquet_schema(
node_schema: schema_lib.NodeSchema,
) -> pyarrow.Schema:
"""Creates the schema for the parquet node container."""
return pyarrow.schema(_feature_schema_to_parquet_fields(node_schema.features))
return pyarrow.schema(_feature_schema_to_parquet_fields(node_schema.features)) # pyrefly: ignore[bad-argument-type]


def _edge_schema_to_parquet_schema(
Expand Down Expand Up @@ -397,7 +397,7 @@ def _edge_schema_to_parquet_schema(
pyarrow.field(
KEY_TARGET, FEATURE_FORMAT_TO_PY_ARROW_DTYPE[target_node_format]
),
] + _feature_schema_to_parquet_fields(edge_schema.features)
] + _feature_schema_to_parquet_fields(edge_schema.features) # pyrefly: ignore[bad-argument-type]
return pyarrow.schema(fields)


Expand All @@ -411,7 +411,7 @@ def _node_to_raw(
if feature_name == primary_key:
raw_dict[feature_name] = node.id
else:
feature_values = node.features[feature_name]
feature_values = node.features[feature_name] # pyrefly: ignore[unsupported-operation]
raw_dict[feature_name] = feature_values.tolist()
return raw_dict

Expand All @@ -425,6 +425,6 @@ def _edge_to_raw(
KEY_TARGET: edge.target,
}
for feature_name in schema.features:
feature_values = edge.features[feature_name]
feature_values = edge.features[feature_name] # pyrefly: ignore[unsupported-operation]
raw_dict[feature_name] = feature_values.tolist()
return raw_dict
8 changes: 4 additions & 4 deletions dgf/src/io/graph_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _read_edge_set(
target_mapper,
source_ids,
target_ids,
min(32, os.cpu_count()),
min(32, os.cpu_count()), # pyrefly: ignore[bad-specialization]
)
else:
# Slow path
Expand Down Expand Up @@ -280,7 +280,7 @@ def read_graph(
with filesystem.open_read(os.path.join(path, FILENAME_METADATA)) as f:
if verbose:
log.info("Reading metadata from %s", path)
metadata = gf_metadata_lib.GFGraphMetadata.from_json(f.read())
metadata = gf_metadata_lib.GFGraphMetadata.from_json(f.read()) # pyrefly: ignore[missing-attribute]

if metadata.version > MAX_SUPPORTED_GF_VERSION:
raise NotImplementedError(
Expand Down Expand Up @@ -376,7 +376,7 @@ def write_graph(
if verbose:
log.info("Writing metadata to %s", metadata_path)
with filesystem.open_write(metadata_path) as f:
f.write(metadata.to_json(indent=2))
f.write(metadata.to_json(indent=2)) # pyrefly: ignore[missing-attribute]

# Write Node Sets
node_dir = os.path.join(path, FILENAME_NODE_FEATURE)
Expand All @@ -390,7 +390,7 @@ def write_graph(
if verbose:
log.info("Writing nodeset %s to %s", nodeset_name, node_dir)

num_shards, _ = shard_lib.estimate_num_node_shards(node_set.num_nodes)
num_shards, _ = shard_lib.estimate_num_node_shards(node_set.num_nodes) # pyrefly: ignore[bad-argument-type]
if max_num_shards is not None:
num_shards = min(num_shards, max_num_shards)

Expand Down
14 changes: 7 additions & 7 deletions dgf/src/io/hgraph_in_avro.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _generate_node_records(
"""Generates records for writing node sets to Avro."""
iterator = range(start_index, end_index)
if verbose:
iterator = tqdm(
iterator = tqdm( # pyrefly: ignore[not-callable]
iterator,
desc=f" - Writing nodes for '{name}'",
unit="node",
Expand All @@ -109,7 +109,7 @@ def _generate_edge_records(
"""Generates records for writing edge sets to Avro."""
iterator = range(start_index, end_index)
if verbose:
iterator = tqdm(
iterator = tqdm( # pyrefly: ignore[not-callable]
iterator,
desc=f" - Writing edges for '{name}'",
unit="edge",
Expand Down Expand Up @@ -152,7 +152,7 @@ def write_avro_node_sets(
parsed_schema = fastavro.parse_schema(avro_schema_dict)
feature_items = list(nodeset.features.items())
num_shards, num_nodes_per_shard = shard_lib.estimate_num_node_shards(
nodeset.num_nodes
nodeset.num_nodes # pyrefly: ignore[bad-argument-type]
)
for shard_index in range(num_shards):
filename = shard_lib.sharded_filename(
Expand All @@ -163,7 +163,7 @@ def write_avro_node_sets(
)
filepath = os.path.join(directory, filename)
start_index = shard_index * num_nodes_per_shard
end_index = min(
end_index = min( # pyrefly: ignore[bad-specialization]
(shard_index + 1) * num_nodes_per_shard, nodeset.num_nodes
)
with filesystem.open_write(filepath, binary=True) as f_out:
Expand All @@ -173,7 +173,7 @@ def write_avro_node_sets(
_generate_node_records(
feature_items,
start_index,
end_index,
end_index, # pyrefly: ignore[bad-argument-type]
nodeset_name,
verbose,
),
Expand Down Expand Up @@ -280,15 +280,15 @@ def read_avro_record(
reader = fastavro_reader(f_in)
record_iterator = reader
if verbose:
record_iterator = tqdm(
record_iterator = tqdm( # pyrefly: ignore[not-callable]
reader,
desc=f" - Reading records from {avro_file}",
unit="record",
)
for record in record_iterator:
num_records += 1
for feature_name in feature_builders.keys():
feature_builders[feature_name].append(record[feature_name])
feature_builders[feature_name].append(record[feature_name]) # pyrefly: ignore[bad-index, unsupported-operation]

# Convert lists to numpy arrays
final_features = {}
Expand Down
14 changes: 7 additions & 7 deletions dgf/src/io/hgraph_in_avro_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_read_avro_record(self):
"f1": ("int64", ()),
"f2": ("float32", ()),
}
data, num_records = avro_lib.read_avro_record([path], columns, False)
data, num_records = avro_lib.read_avro_record([path], columns, False) # pyrefly: ignore[bad-argument-type]

self.assertEqual(num_records, 2)
np.testing.assert_array_equal(data["f1"], np.array([1, 2], dtype="int64"))
Expand All @@ -175,7 +175,7 @@ def test_read_avro_record_sharded(self):

columns = {"f1": ("int64", ())}
data, num_records = avro_lib.read_avro_record(
[path1, path2], columns, False
[path1, path2], columns, False # pyrefly: ignore[bad-argument-type]
)

self.assertEqual(num_records, 3)
Expand All @@ -194,7 +194,7 @@ def test_read_avro_record_empty(self):
fastavro.writer(f, schema, records)

columns = {"f1": ("int64", ())}
data, num_records = avro_lib.read_avro_record([path], columns, False)
data, num_records = avro_lib.read_avro_record([path], columns, False) # pyrefly: ignore[bad-argument-type]

self.assertEqual(num_records, 0)
self.assertEqual(data["f1"].shape, (0,))
Expand All @@ -221,7 +221,7 @@ def test_write_avro_node_sets(self):
"f2": ("float32", (2,)),
}
n1_data, n1_num_records = avro_lib.read_avro_record(
[n1_path], n1_cols, False
[n1_path], n1_cols, False # pyrefly: ignore[bad-argument-type]
)
self.assertEqual(n1_num_records, 2)
np.testing.assert_array_equal(n1_data["#id"], np.array([b"1", b"2"]))
Expand All @@ -237,7 +237,7 @@ def test_write_avro_node_sets(self):
self.assertTrue(os.path.exists(n2_path))
n2_cols = {"#id": ("int64", ()), "f3": ("int64", ()), "f4": ("int64", ())}
n2_data, n2_num_records = avro_lib.read_avro_record(
[n2_path], n2_cols, False
[n2_path], n2_cols, False # pyrefly: ignore[bad-argument-type]
)
self.assertEqual(n2_num_records, 2)
np.testing.assert_array_equal(n2_data["#id"], np.array([1, 2]))
Expand Down Expand Up @@ -272,7 +272,7 @@ def test_write_avro_edge_sets(self):
"#id": ("bytes", ()),
}
e1_data, e1_num_records = avro_lib.read_avro_record(
[e1_path], e1_cols, False
[e1_path], e1_cols, False # pyrefly: ignore[bad-argument-type]
)
self.assertEqual(e1_num_records, 2)
np.testing.assert_array_equal(e1_data["#source"], np.array([b"1", b"1"]))
Expand All @@ -288,7 +288,7 @@ def test_write_avro_edge_sets(self):
"#id": ("bytes", ()),
}
e2_data, e2_num_records = avro_lib.read_avro_record(
[e2_path], e2_cols, False
[e2_path], e2_cols, False # pyrefly: ignore[bad-argument-type]
)
self.assertEqual(e2_num_records, 2)
np.testing.assert_array_equal(e2_data["#source"], np.array([b"1", b"1"]))
Expand Down
4 changes: 2 additions & 2 deletions dgf/src/io/hgraph_in_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def tf_feature_to_feature(
)
value = np.squeeze(value, axis=0)
elif value.ndim != 1:
value = np.reshape(value, feature_schema.shape)
value = np.reshape(value, feature_schema.shape) # pyrefly: ignore[no-matching-overload]
return value


Expand Down Expand Up @@ -558,7 +558,7 @@ def node_to_tf_example(
if feature_name == node_id_column:
value = [node.id]
else:
value = node.features[feature_name]
value = node.features[feature_name] # pyrefly: ignore[unsupported-operation]
if value.ndim == 0:
value = np.expand_dims(value, axis=0)

Expand Down
22 changes: 11 additions & 11 deletions dgf/src/io/hgraph_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,11 @@ def _read_container(
"""Reads features given a container type."""
if container_type == HGraphContainerType.TF_RECORD:
features, num_records = tfexample_lib.read_tf_record(
paths=paths, columns=columns, verbose=verbose, preserve_order=False
paths=paths, columns=columns, verbose=verbose, preserve_order=False # pyrefly: ignore[bad-argument-type]
)
elif container_type == HGraphContainerType.AVRO:
features, num_records = hgraph_in_avro.read_avro_record(
paths=paths, columns=columns, verbose=verbose
paths=paths, columns=columns, verbose=verbose # pyrefly: ignore[bad-argument-type]
)
else:
raise ValueError(
Expand Down Expand Up @@ -488,7 +488,7 @@ def mapper(ids: np.ndarray) -> Tuple[np.ndarray, int]:
raw_edges, _ = _read_container(
paths=paths,
container_type=container_type,
columns=columns,
columns=columns, # pyrefly: ignore[bad-argument-type]
verbose=verbose,
key_column=edge_id_column,
)
Expand All @@ -511,7 +511,7 @@ def mapper(ids: np.ndarray) -> Tuple[np.ndarray, int]:
target_mapper,
source_ids,
target_ids,
min(32, os.cpu_count()),
min(32, os.cpu_count()), # pyrefly: ignore[bad-specialization]
)
else:
# Slow path
Expand Down Expand Up @@ -629,16 +629,16 @@ def in_memory_node_to_tf_example(
node_id_column is not None
and node_id_column not in example.features.feature
):
if np.issubdtype(features[DEFAULT_KEY_ID].dtype, np.integer):
if np.issubdtype(features[DEFAULT_KEY_ID].dtype, np.integer): # pyrefly: ignore[unsupported-operation]
example.features.feature[node_id_column].int64_list.value.append(
features[DEFAULT_KEY_ID][node_index]
features[DEFAULT_KEY_ID][node_index] # pyrefly: ignore[unsupported-operation]
)
elif features[DEFAULT_KEY_ID].dtype.kind == "S":
elif features[DEFAULT_KEY_ID].dtype.kind == "S": # pyrefly: ignore[unsupported-operation]
example.features.feature[node_id_column].bytes_list.value.append(
features[DEFAULT_KEY_ID][node_index]
features[DEFAULT_KEY_ID][node_index] # pyrefly: ignore[unsupported-operation]
)
else:
raise ValueError(f"Non supported type {features[DEFAULT_KEY_ID]}")
raise ValueError(f"Non supported type {features[DEFAULT_KEY_ID]}") # pyrefly: ignore[unsupported-operation]
return example


Expand Down Expand Up @@ -773,7 +773,7 @@ def _write_tfrecord_node_sets(
"""Writes node sets to TFRecord files."""
for nodeset_name, nodeset in graph.node_sets.items():
num_shards, num_nodes_per_shard = shard_lib.estimate_num_node_shards(
nodeset.num_nodes
nodeset.num_nodes # pyrefly: ignore[bad-argument-type]
)
for shard_index in range(num_shards):
examples = []
Expand All @@ -785,7 +785,7 @@ def _write_tfrecord_node_sets(
)
for node_index in range(
shard_index * num_nodes_per_shard,
min(
min( # pyrefly: ignore[bad-argument-type, bad-specialization]
(shard_index + 1) * num_nodes_per_shard,
nodeset.num_nodes,
),
Expand Down
2 changes: 1 addition & 1 deletion dgf/src/io/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _asarray(x):
for node_set_name, node_set in src.node_sets.items():
jax_features = {k: _asarray(v) for k, v in node_set.features.items()}
jax_node_sets[node_set_name] = jax_in_memory_graph_lib.JaxInMemoryNodeSet(
features=jax_features, num_nodes=node_set.num_nodes
features=jax_features, num_nodes=node_set.num_nodes # pyrefly: ignore[bad-argument-type]
)

jax_edge_sets = {}
Expand Down
2 changes: 1 addition & 1 deletion dgf/src/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
@numba.njit(parallel=False)
def _numba_copy_kernel(offset, offsets: np.ndarray, data, out, max_len):
"""Copies data from a flat array to a padded output array using Numba."""
for i in numba.prange(len(offsets) - 1):
for i in numba.prange(len(offsets) - 1): # pyrefly: ignore[not-iterable]
start = offsets[i]
end = offsets[i + 1]
length = end - start
Expand Down
4 changes: 2 additions & 2 deletions dgf/src/io/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def read_schema(path: str) -> schema_lib.GraphSchema:
The loaded graph schema.
"""
with filesystem.open_read(path) as f:
return schema_lib.GraphSchema.from_json(f.read())
return schema_lib.GraphSchema.from_json(f.read()) # pyrefly: ignore[missing-attribute]


def write_schema(schema: schema_lib.GraphSchema, path: str):
Expand All @@ -52,4 +52,4 @@ def write_schema(schema: schema_lib.GraphSchema, path: str):
path: Output path.
"""
with filesystem.open_write(path) as f:
f.write(schema.to_json(indent=2))
f.write(schema.to_json(indent=2)) # pyrefly: ignore[missing-attribute]
Loading