Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dgf/src/analyse/feature_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def merge_accumulators(
if _compute_numerical_quantiles(feature_schema):
assert merged_feature_accumulator.quantiles is not None
merged_feature_accumulator.quantiles.add_reservoir(
feature_accumulator.quantiles
feature_accumulator.quantiles # pyrefly: ignore[bad-argument-type]
)

if feature_accumulator.dictionary is not None:
Expand Down Expand Up @@ -241,7 +241,7 @@ def prune_accumulator(
for feature_name, feature_schema in schema.items():
if _compute_dictionary(feature_schema):
prune_dictionary_before_wiring(
accumulator.features[feature_name].dictionary, config
accumulator.features[feature_name].dictionary, config # pyrefly: ignore[bad-argument-type]
)
return accumulator

Expand Down
2 changes: 1 addition & 1 deletion dgf/src/analyse/reports/visual_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def _nx_graph_to_pyvis_data(
edge_data = {"from": int(u), "to": int(v)}
if "weight" in attrs:
edge_data["value"] = attrs["weight"]
edge_data["title"] = f"Weight: {attrs['weight']}"
edge_data["title"] = f"Weight: {attrs['weight']}" # pyrefly: ignore[bad-assignment]

if "color" in attrs:
edge_data["color"] = attrs["color"]
Expand Down
4 changes: 2 additions & 2 deletions dgf/src/analyse/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,10 @@ def fix_suspicious_shape(
) -> schema_lib.Shape:
log.info("Fix suspicious shape of feature '%s'", feature_name)
assert shape_is_suspicious(shape)
if shape[1] == 1:
if shape[1] == 1: # pyrefly: ignore[unsupported-operation]
return tuple()
else:
return shape[1:]
return shape[1:] # pyrefly: ignore[unsupported-operation]

for nodeset_name, nodeset_def in schema.node_sets.items():
if tf_example:
Expand Down
4 changes: 2 additions & 2 deletions dgf/src/analyse/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def is_integer(self):
"nodes": NodeSchema(
features={
"f_weird": FeatureSchema(
format=FakeFormat(),
format=FakeFormat(), # pyrefly: ignore[bad-argument-type]
semantic=FeatureSemantic.UNKNOWN,
),
}
Expand Down Expand Up @@ -497,7 +497,7 @@ def is_integer(self):
"nodes": NodeSchema(
features={
"f_weird": FeatureSchema(
format=FakeFormat(),
format=FakeFormat(), # pyrefly: ignore[bad-argument-type]
semantic=FeatureSemantic.UNKNOWN,
),
}
Expand Down
4 changes: 2 additions & 2 deletions dgf/src/analyse/topology/global_graph_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_in_memory_graph_topology(
total_edges = 0

for node_set in graph.node_sets.values():
total_nodes += node_set.num_nodes
total_nodes += node_set.num_nodes # pyrefly: ignore[unsupported-operation]
for edge_set in graph.edge_sets.values():
total_edges += edge_set.num_edges()

Expand All @@ -122,7 +122,7 @@ def get_in_memory_graph_topology(

## Connected Components
cc = np.array([])
num_cc, cc_counts = np.unique(cc, return_counts=True)
num_cc, cc_counts = np.unique(cc, return_counts=True) # pyrefly: ignore[no-matching-overload]
largest_cc = np.max(cc_counts).item()
else:
average_degree = None
Expand Down
2 changes: 1 addition & 1 deletion dgf/src/beam/learning/ten_lines/predict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_base(self):
)

# Generate reference predictions
seed_node_idxs = np.arange(graph.node_sets["client"].num_nodes)
seed_node_idxs = np.arange(graph.node_sets["client"].num_nodes) # pyrefly: ignore[no-matching-overload]
expected_raw_predictions = model.predict(graph, seed_node_idxs)

expected_predictions = []
Expand Down
4 changes: 2 additions & 2 deletions dgf/src/beam/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def to_options_dict(self) -> Dict[str, Any]:
if self.machine_type is not None:
opts["machine_type"] = self.machine_type
if self.num_workers is not None:
opts["num_workers"] = self.num_workers
opts["num_workers"] = self.num_workers # pyrefly: ignore[bad-assignment]
if self.max_num_workers is not None:
opts["max_num_workers"] = self.max_num_workers
opts["max_num_workers"] = self.max_num_workers # pyrefly: ignore[bad-assignment]
if self.setup_file is not None:
opts["setup_file"] = self.setup_file
if self.sdk_container_image is not None:
Expand Down
4 changes: 2 additions & 2 deletions dgf/src/data/distributed_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def heterogeneous_graph_from_pieces(
pedge_features = None
if edge_features is not None:
pedge_features = {}
for name, edge_features in edge_features.items():
for name, edge_features in edge_features.items(): # pyrefly: ignore[bad-assignment]
pedge_features[name] = (
p
| f"{stage_prefix}CreateEdgeFeatures_{name}"
Expand All @@ -177,7 +177,7 @@ def heterogeneous_graph_from_pieces(
schema=schema,
node_sets=pnode_sets,
edge_sets=pedge_sets,
edge_format=edge_format,
edge_format=edge_format, # pyrefly: ignore[bad-argument-type]
edge_features=pedge_features,
)

Expand Down
8 changes: 4 additions & 4 deletions dgf/src/data/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ def test_evaluation_json(self):
num_examples=100,
user_metrics={"f1": 0.85},
)
json_str = eval_obj.to_json()
json_str = eval_obj.to_json() # pyrefly: ignore[missing-attribute]
self.assertIsNotNone(json_str)

# Test from_json
loaded_eval = evaluation.Evaluation.from_json(json_str)
loaded_eval = evaluation.Evaluation.from_json(json_str) # pyrefly: ignore[missing-attribute]
self.assertEqual(loaded_eval.loss, 0.1)
self.assertEqual(loaded_eval.accuracy, 0.95)
self.assertEqual(loaded_eval.num_examples, 100)
Expand Down Expand Up @@ -95,11 +95,11 @@ def test_per_class_serialization(self):
fn=np.array([7, 8]),
thresholds=np.array([0.5, 0.1]),
)
json_str = pc.to_json()
json_str = pc.to_json() # pyrefly: ignore[missing-attribute]
self.assertIsNotNone(json_str)
self.assertIn('"tp": [1, 2]', json_str)

loaded = evaluation.PerClass.from_json(json_str)
loaded = evaluation.PerClass.from_json(json_str) # pyrefly: ignore[missing-attribute]
self.assertEqual(loaded.auc_value, 0.9)
self.assertEqual(loaded.pr_auc_value, 0.8)
np.testing.assert_array_equal(loaded.tp, np.array([1, 2]))
Expand Down
6 changes: 3 additions & 3 deletions dgf/src/learning/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


@dataclasses.dataclass(frozen=True, kw_only=True)
class Config(Protocol[_T]):
class Config(Protocol[_T]): # pyrefly: ignore[bad-class-definition]
"""Base class for configurations that can instantiate objects.

Subclasses should, at least, implement `name`, `make`. Other functions such as
Expand All @@ -48,7 +48,7 @@ def name(self) -> str:

def to_dict(self) -> dict[str, Any]:
"""Convert the configuration to a dictionary."""
params = dataclasses.asdict(self)
params = dataclasses.asdict(self) # pyrefly: ignore[bad-argument-type]
params["name"] = self.name()
return params

Expand Down Expand Up @@ -81,4 +81,4 @@ def json_load(cls, path: str) -> None:
params = json.loads(cfg_str)
if "name" in params:
del params["name"]
return cls(**params)
return cls(**params) # pyrefly: ignore[bad-return]
2 changes: 1 addition & 1 deletion dgf/src/learning/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _process_single_feature(
raw_value = jnp.expand_dims(raw_value, axis=1)
# Create an embedding table.
embedding = nn.Embed(
num_embeddings=feature_schema.num_categorical_values,
num_embeddings=feature_schema.num_categorical_values, # pyrefly: ignore[bad-argument-type]
features=self.categorical_feature_embedding_dim,
name=f"embed_{nodeset_name}_{feature_name}",
)
Expand Down
2 changes: 1 addition & 1 deletion dgf/src/learning/feature_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_embed_categorical_features(self):
node_embeddings = embed_module.apply(params, graph, training=True)
self.assertEqual(params["params"]["embed_n1_f1"]["embedding"].shape, (3, 5))
self.assertEqual(params["params"]["embed_n1_f2"]["embedding"].shape, (4, 5))
self.assertEqual(node_embeddings["n1"].shape, (2, 3 * 5))
self.assertEqual(node_embeddings["n1"].shape, (2, 3 * 5)) # pyrefly: ignore[bad-index]

def test_embed_unsupported_embedding_format_raises_error(self):
schema = schema_lib.GraphSchema(
Expand Down
2 changes: 1 addition & 1 deletion dgf/src/learning/jax/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def architecture(self) -> str:

def jnp_name_from_dtype(dtype: jnp.dtype) -> str:
"""Return a string name for a jnp.dtype object."""
return dtype.__name__
return dtype.__name__ # pyrefly: ignore[missing-attribute]


@dataclasses.dataclass(frozen=True, kw_only=True)
Expand Down
6 changes: 3 additions & 3 deletions dgf/src/learning/jax/flax_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,13 +323,13 @@ def train(
writers = [metric_writers.LoggingWriter()]
if working_path is not None:
writers.append(
metric_writers.SummaryWriter(os.path.join(working_path, "summary"))
metric_writers.SummaryWriter(os.path.join(working_path, "summary")) # pyrefly: ignore[bad-argument-type]
)
if export_metrics_to_xm:
try:
from clu.metric_writers import XmMeasurementsWriter

writers.append(XmMeasurementsWriter(asynchronous=True))
writers.append(XmMeasurementsWriter(asynchronous=True)) # pyrefly: ignore[bad-argument-type]
except ImportError:
pass
metric_writer = metric_writers.MultiWriter(writers)
Expand Down Expand Up @@ -404,7 +404,7 @@ def run_valid_logs():
with jax.profiler.TraceAnnotation("validation"):

start_time = time.time()
for batch in valid_dataset_iterator_fn():
for batch in valid_dataset_iterator_fn(): # pyrefly: ignore[not-callable]
with jax.profiler.TraceAnnotation("valid step"):
step_valid_metrics = valid_step(
params=model_params,
Expand Down
10 changes: 5 additions & 5 deletions dgf/src/learning/jax/layers/hetero_graph_attention_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,11 @@ def __call__(
combined = jnp.concatenate(
[dst_values, jnp.zeros((num_dst_nodes, dims))], axis=1
)
combined = config.update.make(name=f"update_{dst_nodeset_name}")(
combined = config.update.make(name=f"update_{dst_nodeset_name}")( # pyrefly: ignore[unexpected-keyword]
combined, training=training
)
node_values = combined + dst_values
node_values = config.post.make(name=f"post_{dst_nodeset_name}")(
node_values = config.post.make(name=f"post_{dst_nodeset_name}")( # pyrefly: ignore[unexpected-keyword]
node_values, training=training
)
new_node_sets[dst_nodeset_name] = (
Expand Down Expand Up @@ -232,7 +232,7 @@ def __call__(
)

# Apply config.message on the edges to get Values
messages = config.message.make(name=f"message_{relation_name}")(
messages = config.message.make(name=f"message_{relation_name}")( # pyrefly: ignore[unexpected-keyword]
edge_values, training=training
) # [E, dims]
messages = messages.reshape(messages.shape[0], num_heads, head_dim)
Expand Down Expand Up @@ -273,15 +273,15 @@ def __call__(

# Join messages + update
combined = jnp.concatenate([dst_values, aggregated_messages], axis=1)
combined = config.update.make(name=f"update_{dst_nodeset_name}")(
combined = config.update.make(name=f"update_{dst_nodeset_name}")( # pyrefly: ignore[unexpected-keyword]
combined, training=training
)

# Residual
node_values = combined + dst_values

# Feed-forward
node_values = config.post.make(name=f"post_{dst_nodeset_name}")(
node_values = config.post.make(name=f"post_{dst_nodeset_name}")( # pyrefly: ignore[unexpected-keyword]
node_values, training=training
)

Expand Down
2 changes: 1 addition & 1 deletion dgf/src/learning/jax/layers/homo_gnn_sparse_deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def __call__(self, graph: GraphStruct, training: bool = False) -> GraphStruct:
)

if self.enable_gnn_plus:
h_next = self.post_graph_conv(h_prev, h_next, layer_index, training)
h_next = self.post_graph_conv(h_prev, h_next, layer_index, training) # pyrefly: ignore[not-callable]
else:
h_next = self.activation(h_next)

Expand Down
8 changes: 4 additions & 4 deletions dgf/src/learning/jax/layers/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def output_schema(
for _, feature_schema in schema.items():
if feature_schema.semantic == schema_lib.FeatureSemantic.EMBEDDING:
shape = feature_schema.shape
num_output_dims += (
feature_schema.shape[0]
num_output_dims += ( # pyrefly: ignore[unsupported-operation]
feature_schema.shape[0] # pyrefly: ignore[unsupported-operation]
if shape is not None and shape is not tuple()
else 1
)
Expand Down Expand Up @@ -173,7 +173,7 @@ def __call__(
)
# Create an embedding table
embedding = nn.Embed(
num_embeddings=feature_schema.num_categorical_values,
num_embeddings=feature_schema.num_categorical_values, # pyrefly: ignore[bad-argument-type]
features=self.config.categorical_feature_embedding_dim,
name=f"embed_{feature_name}",
)
Expand Down Expand Up @@ -405,7 +405,7 @@ def __post_init__(self):
)

self._homogenizer = homogenize_lib.Homogenizer(projected_schema)
self.output_schema = self._homogenizer.output_schema()
self.output_schema = self._homogenizer.output_schema() # pyrefly: ignore[bad-assignment]
super().__post_init__()

@nn.compact
Expand Down
4 changes: 2 additions & 2 deletions dgf/src/learning/jax/layers/residual_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ResidualMLP(nn.Module):
name_prefix: str = "residual_mlp"

def setup(self):
self.activation = common.get_activation(self.activation)
self.activation = common.get_activation(self.activation) # pyrefly: ignore[bad-assignment]

self.hidden_layer = nn.Dense(
features=self.hidden_dim,
Expand All @@ -61,7 +61,7 @@ def __call__(
self, x: jt.Float[jt.Array, "... D"], training: bool = False
) -> jt.Float[jt.Array, "... D"]:
return self.output_layer(
self.activation(self.hidden_layer(x))
self.activation(self.hidden_layer(x)) # pyrefly: ignore[not-callable]
) + self.residual_layer(x)


Expand Down
4 changes: 2 additions & 2 deletions dgf/src/learning/jax/message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ def graph_to_sd_sparse_matrix(
target_nodeset = schema.edge_sets[edgeset_name].target
return core_graph_to_sd_sparse_matrix(
adjacency,
graph.node_sets[source_nodeset].num_nodes,
graph.node_sets[target_nodeset].num_nodes,
graph.node_sets[source_nodeset].num_nodes, # pyrefly: ignore[bad-argument-type]
graph.node_sets[target_nodeset].num_nodes, # pyrefly: ignore[bad-argument-type]
)
4 changes: 2 additions & 2 deletions dgf/src/sampling/beam_semi_distributed_sampler_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ def _edge_to_np_array(
source_node_id_to_idx: Dict[bytes, int],
target_node_id_to_idx: Dict[bytes, int],
) -> Tuple[int, int]:
source_node_idx = source_node_id_to_idx[edge.source]
target_node_idx = target_node_id_to_idx[edge.target]
source_node_idx = source_node_id_to_idx[edge.source] # pyrefly: ignore[bad-index]
target_node_idx = target_node_id_to_idx[edge.target] # pyrefly: ignore[bad-index]
return (source_node_idx, target_node_idx)

return (
Expand Down
14 changes: 7 additions & 7 deletions dgf/src/sampling/beam_semi_distributed_sampler_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def process(

# Emit the samples
for seed, sample in zip(seeds, samples):
yield distributed_graph.KeyedInMemoryGraph(seed, sample)
yield distributed_graph.KeyedInMemoryGraph(seed, sample) # pyrefly: ignore[bad-argument-type]


def add_features_to_graph_samples(
Expand Down Expand Up @@ -435,7 +435,7 @@ def Stage6IndexBySampleId(
# ensures there is only one element.
d_list = d["s"]
for sample_id, node_idx in d_list:
yield sample_id, (node_idx, features)
yield sample_id, (node_idx, features) # pyrefly: ignore[invalid-yield]


def Stage8IndexSample(
Expand Down Expand Up @@ -474,22 +474,22 @@ def Stage8AddFeatureValueToSample(
].features.items():
if feature_name == KEY_ID:
continue
values = [None] * num_nodes
values = [None] * num_nodes # pyrefly: ignore[unsupported-operation]
num_values = 0
for node_idx, row_features in src_features: # pytype: disable=attribute-error
values[node_idx] = row_features[feature_name]
values[node_idx] = row_features[feature_name] # pyrefly: ignore[unsupported-operation]
num_values += 1
assert num_nodes == num_values
# TODO(gbm): Handle variable length features.
dst_features[feature_name] = safe_stack(values, feature_schema)
dst_features[feature_name] = safe_stack(values, feature_schema) # pyrefly: ignore[bad-argument-type]

dst_features[KEY_ID] = raw_nodeset.features[KEY_ID]
augmented_nodesets[nodeset_name] = in_memory_graph_lib.InMemoryNodeSet(
num_nodes=num_nodes, features=dst_features
)

return distributed_graph.KeyedInMemoryGraph(
sample_id,
sample_id, # pyrefly: ignore[bad-argument-type]
in_memory_graph_lib.InMemoryGraph(
node_sets=augmented_nodesets,
edge_sets=raw_graph.edge_sets,
Expand All @@ -503,7 +503,7 @@ def safe_stack(
"""Stacks feature value arrays, handling static and variable shapes."""
try:
if not values:
return np.empty(
return np.empty( # pyrefly: ignore[no-matching-overload]
dtype=feature_format_lib.FEATURE_FORMAT_TO_NP_DTYPE[schema.format],
shape=(0,) + (schema.shape or ()),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def print_graph_stats(
seed_ids: List[bytes],
):
for sample, seed_id in zip(samples, seed_ids):
num_nodes = sum(nodeset.num_nodes for nodeset in sample.node_sets.values())
num_nodes = sum(nodeset.num_nodes for nodeset in sample.node_sets.values()) # pyrefly: ignore[no-matching-overload]
num_edges = sum(
edgeset.num_edges() for edgeset in sample.edge_sets.values()
)
Expand Down
Loading