From 9ee984c255d6ee86137c68c9670ff7b483fa21a0 Mon Sep 17 00:00:00 2001 From: Hana Joo Date: Thu, 2 Jul 2026 03:32:52 -0700 Subject: [PATCH] Automated Code Change PiperOrigin-RevId: 941598701 --- dgf/src/analyse/feature_statistics.py | 4 ++-- dgf/src/analyse/reports/visual_utils.py | 2 +- dgf/src/analyse/schema.py | 4 ++-- dgf/src/analyse/schema_test.py | 4 ++-- .../analyse/topology/global_graph_topology.py | 4 ++-- .../beam/learning/ten_lines/predict_test.py | 2 +- dgf/src/beam/runners.py | 4 ++-- dgf/src/data/distributed_graph.py | 4 ++-- dgf/src/data/evaluation_test.py | 8 +++---- dgf/src/learning/config.py | 6 ++--- dgf/src/learning/feature.py | 2 +- dgf/src/learning/feature_test.py | 2 +- dgf/src/learning/jax/common.py | 2 +- dgf/src/learning/jax/flax_train.py | 6 ++--- .../layers/hetero_graph_attention_network.py | 10 ++++----- .../jax/layers/homo_gnn_sparse_deferred.py | 2 +- dgf/src/learning/jax/layers/preprocess.py | 8 +++---- dgf/src/learning/jax/layers/residual_mlp.py | 4 ++-- dgf/src/learning/jax/message_passing.py | 4 ++-- .../beam_semi_distributed_sampler_v1.py | 4 ++-- .../beam_semi_distributed_sampler_v2.py | 14 ++++++------ .../spanner_graph_sampler_integration_test.py | 2 +- dgf/src/sampling/in_memory_sampler.py | 2 +- dgf/src/util/dataclass_registry_test.py | 14 ++++++------ dgf/src/util/gen_test_graph.py | 22 +++++++++---------- dgf/src/util/test_util.py | 4 ++-- examples/create_graph_samples_distributed.py | 12 +++++----- ...reate_graph_samples_semi_distributed_v1.py | 2 +- .../create_in_memory_graph_reports.py | 2 +- .../experimental/train_model_with_tf_gnn.py | 2 +- 30 files changed, 81 insertions(+), 81 deletions(-) diff --git a/dgf/src/analyse/feature_statistics.py b/dgf/src/analyse/feature_statistics.py index 8a7e0bf..f87bdb7 100644 --- a/dgf/src/analyse/feature_statistics.py +++ b/dgf/src/analyse/feature_statistics.py @@ -206,7 +206,7 @@ def merge_accumulators( if _compute_numerical_quantiles(feature_schema): assert merged_feature_accumulator.quantiles is not None merged_feature_accumulator.quantiles.add_reservoir( - feature_accumulator.quantiles + feature_accumulator.quantiles # pyrefly: ignore[bad-argument-type] ) if feature_accumulator.dictionary is not None: @@ -241,7 +241,7 @@ def prune_accumulator( for feature_name, feature_schema in schema.items(): if _compute_dictionary(feature_schema): prune_dictionary_before_wiring( - accumulator.features[feature_name].dictionary, config + accumulator.features[feature_name].dictionary, config # pyrefly: ignore[bad-argument-type] ) return accumulator diff --git a/dgf/src/analyse/reports/visual_utils.py b/dgf/src/analyse/reports/visual_utils.py index da23174..1ae3185 100644 --- a/dgf/src/analyse/reports/visual_utils.py +++ b/dgf/src/analyse/reports/visual_utils.py @@ -174,7 +174,7 @@ def _nx_graph_to_pyvis_data( edge_data = {"from": int(u), "to": int(v)} if "weight" in attrs: edge_data["value"] = attrs["weight"] - edge_data["title"] = f"Weight: {attrs['weight']}" + edge_data["title"] = f"Weight: {attrs['weight']}" # pyrefly: ignore[bad-assignment] if "color" in attrs: edge_data["color"] = attrs["color"] diff --git a/dgf/src/analyse/schema.py b/dgf/src/analyse/schema.py index aea0a24..88a96dd 100644 --- a/dgf/src/analyse/schema.py +++ b/dgf/src/analyse/schema.py @@ -200,10 +200,10 @@ def fix_suspicious_shape( ) -> schema_lib.Shape: log.info("Fix suspicious shape of feature '%s'", feature_name) assert shape_is_suspicious(shape) - if shape[1] == 1: + if shape[1] == 1: # pyrefly: ignore[unsupported-operation] return tuple() else: - return shape[1:] + return shape[1:] # pyrefly: ignore[unsupported-operation] for nodeset_name, nodeset_def in schema.node_sets.items(): if tf_example: diff --git a/dgf/src/analyse/schema_test.py b/dgf/src/analyse/schema_test.py index d97e1e3..34815c0 100644 --- a/dgf/src/analyse/schema_test.py +++ b/dgf/src/analyse/schema_test.py @@ -469,7 +469,7 @@ def is_integer(self): "nodes": NodeSchema( features={ "f_weird": FeatureSchema( - format=FakeFormat(), + format=FakeFormat(), # pyrefly: ignore[bad-argument-type] semantic=FeatureSemantic.UNKNOWN, ), } @@ -497,7 +497,7 @@ def is_integer(self): "nodes": NodeSchema( features={ "f_weird": FeatureSchema( - format=FakeFormat(), + format=FakeFormat(), # pyrefly: ignore[bad-argument-type] semantic=FeatureSemantic.UNKNOWN, ), } diff --git a/dgf/src/analyse/topology/global_graph_topology.py b/dgf/src/analyse/topology/global_graph_topology.py index a51122b..e206204 100644 --- a/dgf/src/analyse/topology/global_graph_topology.py +++ b/dgf/src/analyse/topology/global_graph_topology.py @@ -104,7 +104,7 @@ def get_in_memory_graph_topology( total_edges = 0 for node_set in graph.node_sets.values(): - total_nodes += node_set.num_nodes + total_nodes += node_set.num_nodes # pyrefly: ignore[unsupported-operation] for edge_set in graph.edge_sets.values(): total_edges += edge_set.num_edges() @@ -122,7 +122,7 @@ def get_in_memory_graph_topology( ## Connected Components cc = np.array([]) - num_cc, cc_counts = np.unique(cc, return_counts=True) + num_cc, cc_counts = np.unique(cc, return_counts=True) # pyrefly: ignore[no-matching-overload] largest_cc = np.max(cc_counts).item() else: average_degree = None diff --git a/dgf/src/beam/learning/ten_lines/predict_test.py b/dgf/src/beam/learning/ten_lines/predict_test.py index 3d3b789..9d2069c 100644 --- a/dgf/src/beam/learning/ten_lines/predict_test.py +++ b/dgf/src/beam/learning/ten_lines/predict_test.py @@ -104,7 +104,7 @@ def test_base(self): ) # Generate reference predictions - seed_node_idxs = np.arange(graph.node_sets["client"].num_nodes) + seed_node_idxs = np.arange(graph.node_sets["client"].num_nodes) # pyrefly: ignore[no-matching-overload] expected_raw_predictions = model.predict(graph, seed_node_idxs) expected_predictions = [] diff --git a/dgf/src/beam/runners.py b/dgf/src/beam/runners.py index ca09511..e893637 100644 --- a/dgf/src/beam/runners.py +++ b/dgf/src/beam/runners.py @@ -107,9 +107,9 @@ def to_options_dict(self) -> Dict[str, Any]: if self.machine_type is not None: opts["machine_type"] = self.machine_type if self.num_workers is not None: - opts["num_workers"] = self.num_workers + opts["num_workers"] = self.num_workers # pyrefly: ignore[bad-assignment] if self.max_num_workers is not None: - opts["max_num_workers"] = self.max_num_workers + opts["max_num_workers"] = self.max_num_workers # pyrefly: ignore[bad-assignment] if self.setup_file is not None: opts["setup_file"] = self.setup_file if self.sdk_container_image is not None: diff --git a/dgf/src/data/distributed_graph.py b/dgf/src/data/distributed_graph.py index 82586a3..ed929b5 100644 --- a/dgf/src/data/distributed_graph.py +++ b/dgf/src/data/distributed_graph.py @@ -166,7 +166,7 @@ def heterogeneous_graph_from_pieces( pedge_features = None if edge_features is not None: pedge_features = {} - for name, edge_features in edge_features.items(): + for name, edge_features in edge_features.items(): # pyrefly: ignore[bad-assignment] pedge_features[name] = ( p | f"{stage_prefix}CreateEdgeFeatures_{name}" @@ -177,7 +177,7 @@ def heterogeneous_graph_from_pieces( schema=schema, node_sets=pnode_sets, edge_sets=pedge_sets, - edge_format=edge_format, + edge_format=edge_format, # pyrefly: ignore[bad-argument-type] edge_features=pedge_features, ) diff --git a/dgf/src/data/evaluation_test.py b/dgf/src/data/evaluation_test.py index b8bfac1..9be00e7 100644 --- a/dgf/src/data/evaluation_test.py +++ b/dgf/src/data/evaluation_test.py @@ -53,11 +53,11 @@ def test_evaluation_json(self): num_examples=100, user_metrics={"f1": 0.85}, ) - json_str = eval_obj.to_json() + json_str = eval_obj.to_json() # pyrefly: ignore[missing-attribute] self.assertIsNotNone(json_str) # Test from_json - loaded_eval = evaluation.Evaluation.from_json(json_str) + loaded_eval = evaluation.Evaluation.from_json(json_str) # pyrefly: ignore[missing-attribute] self.assertEqual(loaded_eval.loss, 0.1) self.assertEqual(loaded_eval.accuracy, 0.95) self.assertEqual(loaded_eval.num_examples, 100) @@ -95,11 +95,11 @@ def test_per_class_serialization(self): fn=np.array([7, 8]), thresholds=np.array([0.5, 0.1]), ) - json_str = pc.to_json() + json_str = pc.to_json() # pyrefly: ignore[missing-attribute] self.assertIsNotNone(json_str) self.assertIn('"tp": [1, 2]', json_str) - loaded = evaluation.PerClass.from_json(json_str) + loaded = evaluation.PerClass.from_json(json_str) # pyrefly: ignore[missing-attribute] self.assertEqual(loaded.auc_value, 0.9) self.assertEqual(loaded.pr_auc_value, 0.8) np.testing.assert_array_equal(loaded.tp, np.array([1, 2])) diff --git a/dgf/src/learning/config.py b/dgf/src/learning/config.py index b2238c3..7ea4aaf 100644 --- a/dgf/src/learning/config.py +++ b/dgf/src/learning/config.py @@ -27,7 +27,7 @@ @dataclasses.dataclass(frozen=True, kw_only=True) -class Config(Protocol[_T]): +class Config(Protocol[_T]): # pyrefly: ignore[bad-class-definition] """Base class for configurations that can instantiate objects. Subclasses should, at least, implement `name`, `make`. Other functions such as @@ -48,7 +48,7 @@ def name(self) -> str: def to_dict(self) -> dict[str, Any]: """Convert the configuration to a dictionary.""" - params = dataclasses.asdict(self) + params = dataclasses.asdict(self) # pyrefly: ignore[bad-argument-type] params["name"] = self.name() return params @@ -81,4 +81,4 @@ def json_load(cls, path: str) -> None: params = json.loads(cfg_str) if "name" in params: del params["name"] - return cls(**params) + return cls(**params) # pyrefly: ignore[bad-return] diff --git a/dgf/src/learning/feature.py b/dgf/src/learning/feature.py index 012d2f6..9c45a02 100644 --- a/dgf/src/learning/feature.py +++ b/dgf/src/learning/feature.py @@ -134,7 +134,7 @@ def _process_single_feature( raw_value = jnp.expand_dims(raw_value, axis=1) # Create an embedding table. embedding = nn.Embed( - num_embeddings=feature_schema.num_categorical_values, + num_embeddings=feature_schema.num_categorical_values, # pyrefly: ignore[bad-argument-type] features=self.categorical_feature_embedding_dim, name=f"embed_{nodeset_name}_{feature_name}", ) diff --git a/dgf/src/learning/feature_test.py b/dgf/src/learning/feature_test.py index fcd5243..98cc703 100644 --- a/dgf/src/learning/feature_test.py +++ b/dgf/src/learning/feature_test.py @@ -114,7 +114,7 @@ def test_embed_categorical_features(self): node_embeddings = embed_module.apply(params, graph, training=True) self.assertEqual(params["params"]["embed_n1_f1"]["embedding"].shape, (3, 5)) self.assertEqual(params["params"]["embed_n1_f2"]["embedding"].shape, (4, 5)) - self.assertEqual(node_embeddings["n1"].shape, (2, 3 * 5)) + self.assertEqual(node_embeddings["n1"].shape, (2, 3 * 5)) # pyrefly: ignore[bad-index] def test_embed_unsupported_embedding_format_raises_error(self): schema = schema_lib.GraphSchema( diff --git a/dgf/src/learning/jax/common.py b/dgf/src/learning/jax/common.py index 3de72fb..c878cad 100644 --- a/dgf/src/learning/jax/common.py +++ b/dgf/src/learning/jax/common.py @@ -76,7 +76,7 @@ def architecture(self) -> str: def jnp_name_from_dtype(dtype: jnp.dtype) -> str: """Return a string name for a jnp.dtype object.""" - return dtype.__name__ + return dtype.__name__ # pyrefly: ignore[missing-attribute] @dataclasses.dataclass(frozen=True, kw_only=True) diff --git a/dgf/src/learning/jax/flax_train.py b/dgf/src/learning/jax/flax_train.py index 4bcdfdc..25799b1 100644 --- a/dgf/src/learning/jax/flax_train.py +++ b/dgf/src/learning/jax/flax_train.py @@ -323,13 +323,13 @@ def train( writers = [metric_writers.LoggingWriter()] if working_path is not None: writers.append( - metric_writers.SummaryWriter(os.path.join(working_path, "summary")) + metric_writers.SummaryWriter(os.path.join(working_path, "summary")) # pyrefly: ignore[bad-argument-type] ) if export_metrics_to_xm: try: from clu.metric_writers import XmMeasurementsWriter - writers.append(XmMeasurementsWriter(asynchronous=True)) + writers.append(XmMeasurementsWriter(asynchronous=True)) # pyrefly: ignore[bad-argument-type] except ImportError: pass metric_writer = metric_writers.MultiWriter(writers) @@ -404,7 +404,7 @@ def run_valid_logs(): with jax.profiler.TraceAnnotation("validation"): start_time = time.time() - for batch in valid_dataset_iterator_fn(): + for batch in valid_dataset_iterator_fn(): # pyrefly: ignore[not-callable] with jax.profiler.TraceAnnotation("valid step"): step_valid_metrics = valid_step( params=model_params, diff --git a/dgf/src/learning/jax/layers/hetero_graph_attention_network.py b/dgf/src/learning/jax/layers/hetero_graph_attention_network.py index 7efb5bf..535369a 100644 --- a/dgf/src/learning/jax/layers/hetero_graph_attention_network.py +++ b/dgf/src/learning/jax/layers/hetero_graph_attention_network.py @@ -174,11 +174,11 @@ def __call__( combined = jnp.concatenate( [dst_values, jnp.zeros((num_dst_nodes, dims))], axis=1 ) - combined = config.update.make(name=f"update_{dst_nodeset_name}")( + combined = config.update.make(name=f"update_{dst_nodeset_name}")( # pyrefly: ignore[unexpected-keyword] combined, training=training ) node_values = combined + dst_values - node_values = config.post.make(name=f"post_{dst_nodeset_name}")( + node_values = config.post.make(name=f"post_{dst_nodeset_name}")( # pyrefly: ignore[unexpected-keyword] node_values, training=training ) new_node_sets[dst_nodeset_name] = ( @@ -232,7 +232,7 @@ def __call__( ) # Apply config.message on the edges to get Values - messages = config.message.make(name=f"message_{relation_name}")( + messages = config.message.make(name=f"message_{relation_name}")( # pyrefly: ignore[unexpected-keyword] edge_values, training=training ) # [E, dims] messages = messages.reshape(messages.shape[0], num_heads, head_dim) @@ -273,7 +273,7 @@ def __call__( # Join messages + update combined = jnp.concatenate([dst_values, aggregated_messages], axis=1) - combined = config.update.make(name=f"update_{dst_nodeset_name}")( + combined = config.update.make(name=f"update_{dst_nodeset_name}")( # pyrefly: ignore[unexpected-keyword] combined, training=training ) @@ -281,7 +281,7 @@ def __call__( node_values = combined + dst_values # Feed-forward - node_values = config.post.make(name=f"post_{dst_nodeset_name}")( + node_values = config.post.make(name=f"post_{dst_nodeset_name}")( # pyrefly: ignore[unexpected-keyword] node_values, training=training ) diff --git a/dgf/src/learning/jax/layers/homo_gnn_sparse_deferred.py b/dgf/src/learning/jax/layers/homo_gnn_sparse_deferred.py index 345cbcf..721e23b 100644 --- a/dgf/src/learning/jax/layers/homo_gnn_sparse_deferred.py +++ b/dgf/src/learning/jax/layers/homo_gnn_sparse_deferred.py @@ -481,7 +481,7 @@ def __call__(self, graph: GraphStruct, training: bool = False) -> GraphStruct: ) if self.enable_gnn_plus: - h_next = self.post_graph_conv(h_prev, h_next, layer_index, training) + h_next = self.post_graph_conv(h_prev, h_next, layer_index, training) # pyrefly: ignore[not-callable] else: h_next = self.activation(h_next) diff --git a/dgf/src/learning/jax/layers/preprocess.py b/dgf/src/learning/jax/layers/preprocess.py index 6a9cee8..9f0ef1d 100644 --- a/dgf/src/learning/jax/layers/preprocess.py +++ b/dgf/src/learning/jax/layers/preprocess.py @@ -74,8 +74,8 @@ def output_schema( for _, feature_schema in schema.items(): if feature_schema.semantic == schema_lib.FeatureSemantic.EMBEDDING: shape = feature_schema.shape - num_output_dims += ( - feature_schema.shape[0] + num_output_dims += ( # pyrefly: ignore[unsupported-operation] + feature_schema.shape[0] # pyrefly: ignore[unsupported-operation] if shape is not None and shape is not tuple() else 1 ) @@ -173,7 +173,7 @@ def __call__( ) # Create an embedding table embedding = nn.Embed( - num_embeddings=feature_schema.num_categorical_values, + num_embeddings=feature_schema.num_categorical_values, # pyrefly: ignore[bad-argument-type] features=self.config.categorical_feature_embedding_dim, name=f"embed_{feature_name}", ) @@ -405,7 +405,7 @@ def __post_init__(self): ) self._homogenizer = homogenize_lib.Homogenizer(projected_schema) - self.output_schema = self._homogenizer.output_schema() + self.output_schema = self._homogenizer.output_schema() # pyrefly: ignore[bad-assignment] super().__post_init__() @nn.compact diff --git a/dgf/src/learning/jax/layers/residual_mlp.py b/dgf/src/learning/jax/layers/residual_mlp.py index 710170d..dbafa9d 100644 --- a/dgf/src/learning/jax/layers/residual_mlp.py +++ b/dgf/src/learning/jax/layers/residual_mlp.py @@ -37,7 +37,7 @@ class ResidualMLP(nn.Module): name_prefix: str = "residual_mlp" def setup(self): - self.activation = common.get_activation(self.activation) + self.activation = common.get_activation(self.activation) # pyrefly: ignore[bad-assignment] self.hidden_layer = nn.Dense( features=self.hidden_dim, @@ -61,7 +61,7 @@ def __call__( self, x: jt.Float[jt.Array, "... D"], training: bool = False ) -> jt.Float[jt.Array, "... D"]: return self.output_layer( - self.activation(self.hidden_layer(x)) + self.activation(self.hidden_layer(x)) # pyrefly: ignore[not-callable] ) + self.residual_layer(x) diff --git a/dgf/src/learning/jax/message_passing.py b/dgf/src/learning/jax/message_passing.py index 495c738..1d26df8 100644 --- a/dgf/src/learning/jax/message_passing.py +++ b/dgf/src/learning/jax/message_passing.py @@ -58,6 +58,6 @@ def graph_to_sd_sparse_matrix( target_nodeset = schema.edge_sets[edgeset_name].target return core_graph_to_sd_sparse_matrix( adjacency, - graph.node_sets[source_nodeset].num_nodes, - graph.node_sets[target_nodeset].num_nodes, + graph.node_sets[source_nodeset].num_nodes, # pyrefly: ignore[bad-argument-type] + graph.node_sets[target_nodeset].num_nodes, # pyrefly: ignore[bad-argument-type] ) diff --git a/dgf/src/sampling/beam_semi_distributed_sampler_v1.py b/dgf/src/sampling/beam_semi_distributed_sampler_v1.py index 6704aeb..e76ce8c 100644 --- a/dgf/src/sampling/beam_semi_distributed_sampler_v1.py +++ b/dgf/src/sampling/beam_semi_distributed_sampler_v1.py @@ -285,8 +285,8 @@ def _edge_to_np_array( source_node_id_to_idx: Dict[bytes, int], target_node_id_to_idx: Dict[bytes, int], ) -> Tuple[int, int]: - source_node_idx = source_node_id_to_idx[edge.source] - target_node_idx = target_node_id_to_idx[edge.target] + source_node_idx = source_node_id_to_idx[edge.source] # pyrefly: ignore[bad-index] + target_node_idx = target_node_id_to_idx[edge.target] # pyrefly: ignore[bad-index] return (source_node_idx, target_node_idx) return ( diff --git a/dgf/src/sampling/beam_semi_distributed_sampler_v2.py b/dgf/src/sampling/beam_semi_distributed_sampler_v2.py index 6f6733b..b72a79b 100644 --- a/dgf/src/sampling/beam_semi_distributed_sampler_v2.py +++ b/dgf/src/sampling/beam_semi_distributed_sampler_v2.py @@ -255,7 +255,7 @@ def process( # Emit the samples for seed, sample in zip(seeds, samples): - yield distributed_graph.KeyedInMemoryGraph(seed, sample) + yield distributed_graph.KeyedInMemoryGraph(seed, sample) # pyrefly: ignore[bad-argument-type] def add_features_to_graph_samples( @@ -435,7 +435,7 @@ def Stage6IndexBySampleId( # ensures there is only one element. d_list = d["s"] for sample_id, node_idx in d_list: - yield sample_id, (node_idx, features) + yield sample_id, (node_idx, features) # pyrefly: ignore[invalid-yield] def Stage8IndexSample( @@ -474,14 +474,14 @@ def Stage8AddFeatureValueToSample( ].features.items(): if feature_name == KEY_ID: continue - values = [None] * num_nodes + values = [None] * num_nodes # pyrefly: ignore[unsupported-operation] num_values = 0 for node_idx, row_features in src_features: # pytype: disable=attribute-error - values[node_idx] = row_features[feature_name] + values[node_idx] = row_features[feature_name] # pyrefly: ignore[unsupported-operation] num_values += 1 assert num_nodes == num_values # TODO(gbm): Handle variable length features. - dst_features[feature_name] = safe_stack(values, feature_schema) + dst_features[feature_name] = safe_stack(values, feature_schema) # pyrefly: ignore[bad-argument-type] dst_features[KEY_ID] = raw_nodeset.features[KEY_ID] augmented_nodesets[nodeset_name] = in_memory_graph_lib.InMemoryNodeSet( @@ -489,7 +489,7 @@ def Stage8AddFeatureValueToSample( ) return distributed_graph.KeyedInMemoryGraph( - sample_id, + sample_id, # pyrefly: ignore[bad-argument-type] in_memory_graph_lib.InMemoryGraph( node_sets=augmented_nodesets, edge_sets=raw_graph.edge_sets, @@ -503,7 +503,7 @@ def safe_stack( """Stacks feature value arrays, handling static and variable shapes.""" try: if not values: - return np.empty( + return np.empty( # pyrefly: ignore[no-matching-overload] dtype=feature_format_lib.FEATURE_FORMAT_TO_NP_DTYPE[schema.format], shape=(0,) + (schema.shape or ()), ) diff --git a/dgf/src/sampling/gcp/spanner_graph_sampler_integration_test.py b/dgf/src/sampling/gcp/spanner_graph_sampler_integration_test.py index 42e11a3..917d245 100644 --- a/dgf/src/sampling/gcp/spanner_graph_sampler_integration_test.py +++ b/dgf/src/sampling/gcp/spanner_graph_sampler_integration_test.py @@ -103,7 +103,7 @@ def print_graph_stats( seed_ids: List[bytes], ): for sample, seed_id in zip(samples, seed_ids): - num_nodes = sum(nodeset.num_nodes for nodeset in sample.node_sets.values()) + num_nodes = sum(nodeset.num_nodes for nodeset in sample.node_sets.values()) # pyrefly: ignore[no-matching-overload] num_edges = sum( edgeset.num_edges() for edgeset in sample.edge_sets.values() ) diff --git a/dgf/src/sampling/in_memory_sampler.py b/dgf/src/sampling/in_memory_sampler.py index 94c0c8e..70d515a 100644 --- a/dgf/src/sampling/in_memory_sampler.py +++ b/dgf/src/sampling/in_memory_sampler.py @@ -361,7 +361,7 @@ def create_sampler( # TODO(gbm): Use batch_size for async sampling. if num_threads is None: - num_threads = min(batch_size, os.cpu_count()) + num_threads = min(batch_size, os.cpu_count()) # pyrefly: ignore[bad-specialization] cc_sampler = _in_memory_sampler_ext.CreateSampler( graph, plan, debug_sampling, num_threads, seed, schema, edgeset_to_mask diff --git a/dgf/src/util/dataclass_registry_test.py b/dgf/src/util/dataclass_registry_test.py index 6b74302..5d2cbb1 100644 --- a/dgf/src/util/dataclass_registry_test.py +++ b/dgf/src/util/dataclass_registry_test.py @@ -41,14 +41,14 @@ class RegisterTest(parameterized.TestCase): def test_base(self): b = B(a=A(2)) - b_json = b.to_json() + b_json = b.to_json() # pyrefly: ignore[missing-attribute] self.assertEqual(b_json, '{"a": {"x": 2, "__type": "my_registry.A"}}') - new_b = B.from_json(b_json) + new_b = B.from_json(b_json) # pyrefly: ignore[missing-attribute] test_util.assert_are_equal(self, b, new_b) def test_none(self): b = B(a=None) - new_b = B.from_json(b.to_json()) + new_b = B.from_json(b.to_json()) # pyrefly: ignore[missing-attribute] test_util.assert_are_equal(self, b, new_b) def test_double_registration(self): @@ -68,19 +68,19 @@ class C: b = B(a=C(2)) with self.assertRaises(ValueError): - b.to_json() + b.to_json() # pyrefly: ignore[missing-attribute] def test_invalid_json_missing_field(self): with self.assertRaises(KeyError): - _ = B.from_json('{"a": {"__type": "my_registry.A"}}') + _ = B.from_json('{"a": {"__type": "my_registry.A"}}') # pyrefly: ignore[missing-attribute] def test_invalid_json_extra_field(self): with self.assertRaises(dataclasses_json.undefined.UndefinedParameterError): - _ = B.from_json('{"a": {"x": 2, "y": 3, "__type": "my_registry.A"}}') + _ = B.from_json('{"a": {"x": 2, "y": 3, "__type": "my_registry.A"}}') # pyrefly: ignore[missing-attribute] def test_invalid_json_wrong_field_type(self): with self.assertRaises(ValueError): - _ = B.from_json('{"a": {"x": "hello", "__type": "my_registry.A"}}') + _ = B.from_json('{"a": {"x": "hello", "__type": "my_registry.A"}}') # pyrefly: ignore[missing-attribute] if __name__ == '__main__': diff --git a/dgf/src/util/gen_test_graph.py b/dgf/src/util/gen_test_graph.py index 2fd7400..5ed8588 100644 --- a/dgf/src/util/gen_test_graph.py +++ b/dgf/src/util/gen_test_graph.py @@ -346,7 +346,7 @@ def generate_gf_graph( # Metadata metadata = gf_metadata_lib.GFGraphMetadata(version=0) with open(os.path.join(path, "metadata.json"), "w") as f: - f.write(metadata.to_json(indent=2)) + f.write(metadata.to_json(indent=2)) # pyrefly: ignore[missing-attribute] # Node features pq.write_table( @@ -797,7 +797,7 @@ def generate_avro_graph( shard_idx=0, num_shards=2, extension=".avro", - schema=n1_schema, + schema=n1_schema, # pyrefly: ignore[bad-argument-type] records=[{ "#id": b"1", "f1": [b"blue"], @@ -810,7 +810,7 @@ def generate_avro_graph( shard_idx=1, num_shards=2, extension=".avro", - schema=n1_schema, + schema=n1_schema, # pyrefly: ignore[bad-argument-type] records=[{ "#id": b"2", "f1": [b"red"], @@ -825,17 +825,17 @@ def generate_avro_graph( {"name": "f4", "type": "long"}, ] if variable_length: - n2_fields.append({"name": "f5", "type": {"type": "array", "items": "long"}}) + n2_fields.append({"name": "f5", "type": {"type": "array", "items": "long"}}) # pyrefly: ignore[bad-argument-type] n2_schema_dict = {"type": "record", "name": "n2", "fields": n2_fields} n2_schema = parse_schema(n2_schema_dict) n2_record1 = {"#id": 1, "f3": 4, "f4": 10} if variable_length: - n2_record1["f5"] = [11, 12] + n2_record1["f5"] = [11, 12] # pyrefly: ignore[bad-assignment] n2_record2 = {"#id": 2, "f3": 5, "f4": 11} if variable_length: - n2_record2["f5"] = [12, 13, 14] + n2_record2["f5"] = [12, 13, 14] # pyrefly: ignore[bad-assignment] _write_sharded_avro( directory=os.path.join(path, "node_features"), @@ -843,7 +843,7 @@ def generate_avro_graph( shard_idx=0, num_shards=1, extension=".avro", - schema=n2_schema, + schema=n2_schema, # pyrefly: ignore[bad-argument-type] records=[n2_record1, n2_record2], ) @@ -865,7 +865,7 @@ def generate_avro_graph( shard_idx=0, num_shards=1, extension=".avro", - schema=e1_schema, + schema=e1_schema, # pyrefly: ignore[bad-argument-type] records=[ {"#id": b"a", "#source": b"1", "#target": b"1"}, {"#id": b"b", "#source": b"1", "#target": b"2"}, @@ -888,7 +888,7 @@ def generate_avro_graph( shard_idx=0, num_shards=1, extension=".avro", - schema=e2_schema, + schema=e2_schema, # pyrefly: ignore[bad-argument-type] records=[ {"#source": b"1", "#target": 1}, {"#source": b"1", "#target": 2}, @@ -1192,7 +1192,7 @@ def generate_in_memory_graph( n1_features["#id"] = np.array([b"1", b"2"]) n2_features["#id"] = np.array([1, 2]) if variable_length: - n2_features["f5"] = np.array( + n2_features["f5"] = np.array( # pyrefly: ignore[bad-assignment] [np.array([11, 12]), np.array([12, 13, 14])], dtype=np.object_ ) @@ -1512,7 +1512,7 @@ def _get_spanner_graph_metadata_and_features(): } return ( - sgm.SpannerGraphMetadata.from_json(json.dumps(metadata_json)), + sgm.SpannerGraphMetadata.from_json(json.dumps(metadata_json)), # pyrefly: ignore[missing-attribute] feature_formats, feature_semantics, feature_shapes, diff --git a/dgf/src/util/test_util.py b/dgf/src/util/test_util.py index 34dcdc2..febef01 100644 --- a/dgf/src/util/test_util.py +++ b/dgf/src/util/test_util.py @@ -118,9 +118,9 @@ def ret(equal_result: bool) -> bool: # JAX arrays if isinstance(obj1, jnp.ndarray) and isinstance(obj2, jnp.ndarray): if abs_tol is not None: - return ret(jnp.allclose(obj1, obj2, atol=abs_tol)) + return ret(jnp.allclose(obj1, obj2, atol=abs_tol)) # pyrefly: ignore[bad-argument-type] else: - return ret(jnp.array_equal(obj1, obj2)) + return ret(jnp.array_equal(obj1, obj2)) # pyrefly: ignore[bad-argument-type] # TensorFlow arrays if isinstance(obj1, tf.Tensor) and isinstance(obj2, tf.Tensor): diff --git a/examples/create_graph_samples_distributed.py b/examples/create_graph_samples_distributed.py index 642d188..8d4af76 100644 --- a/examples/create_graph_samples_distributed.py +++ b/examples/create_graph_samples_distributed.py @@ -80,14 +80,14 @@ def run(node_spec: dgf.data.ComputeNodeSpec): logging.info("Start manager") dgf.sampling.sample_with_distributed_batching( - graph_path=_INPUT_GRAPH.value, + graph_path=_INPUT_GRAPH.value, # pyrefly: ignore[bad-argument-type] plan=dgf.sampling.SimpleSamplingConfig( - seed_nodeset=_SEED_NODESET.value, - num_hops=_NUM_HOPS.value, - hop_width=_HOP_WIDTH.value, + seed_nodeset=_SEED_NODESET.value, # pyrefly: ignore[bad-argument-type] + num_hops=_NUM_HOPS.value, # pyrefly: ignore[bad-argument-type] + hop_width=_HOP_WIDTH.value, # pyrefly: ignore[bad-argument-type] ), working_directory=_WORKING_DIR.value, - samples_path=_OUTPUT_TF_GRAPH_SAMPLES.value, + samples_path=_OUTPUT_TF_GRAPH_SAMPLES.value, # pyrefly: ignore[bad-argument-type] node_spec=node_spec, ) @@ -97,7 +97,7 @@ def main(argv: Sequence[str]) -> None: raise app.UsageError("Too many command-line arguments.") if _NODE_SPEC.value: # The distribution spec is passed manually (used in manual Borg). - node_spec = dgf.data.ComputeNodeSpec.from_json(_NODE_SPEC.value) + node_spec = dgf.data.ComputeNodeSpec.from_json(_NODE_SPEC.value) # pyrefly: ignore[missing-attribute] else: # The distribution node spec is obtained from the env (used in VertexAI and # TF distribution enviroments). diff --git a/examples/experimental/create_graph_samples_semi_distributed_v1.py b/examples/experimental/create_graph_samples_semi_distributed_v1.py index 7b5214e..c6344d3 100644 --- a/examples/experimental/create_graph_samples_semi_distributed_v1.py +++ b/examples/experimental/create_graph_samples_semi_distributed_v1.py @@ -70,7 +70,7 @@ def pipeline(root: beam.Pipeline): else: override_schema = None graph = dgf.beam.io.read_graphai_hgraph( - root, _INPUT_HGRAPH.value, _INPUT_FORMAT.value, override_schema + root, _INPUT_HGRAPH.value, _INPUT_FORMAT.value, override_schema # pyrefly: ignore[bad-argument-type] ) # Create sampling config diff --git a/examples/experimental/create_in_memory_graph_reports.py b/examples/experimental/create_in_memory_graph_reports.py index a1ba25c..0221d65 100644 --- a/examples/experimental/create_in_memory_graph_reports.py +++ b/examples/experimental/create_in_memory_graph_reports.py @@ -174,7 +174,7 @@ def main(argv: Sequence[str]) -> None: num_hops=known_args.sampling_num_hops, hop_width=known_args.sampling_hop_width, ), - num_threads=os.cpu_count() * 2, + num_threads=os.cpu_count() * 2, # pyrefly: ignore[unsupported-operation] ) num_nodes = graph.node_sets[known_args.seed_nodeset].num_nodes sub_graphs = sampler.sample( diff --git a/examples/experimental/train_model_with_tf_gnn.py b/examples/experimental/train_model_with_tf_gnn.py index 85e8924..bbfa426 100644 --- a/examples/experimental/train_model_with_tf_gnn.py +++ b/examples/experimental/train_model_with_tf_gnn.py @@ -120,7 +120,7 @@ def main(argv: Sequence[str]) -> None: logging.info("Model training started!") result, output_metadata = trainer_tfgnn.train_tfgnn_model( train_samples_path=_TRAIN_SAMPLES.value, - valid_samples_path=_VALID_SAMPLES.value, + valid_samples_path=_VALID_SAMPLES.value, # pyrefly: ignore[bad-argument-type] schema_path=_SCHEMA.value, model_dir=_MODEL_DIR.value, training_config_path=_INPUT_TRAINING_CONFIG.value,