fix(dynamicemb): traverse nn.Module children in check_emb_collection_modules#355
fix(dynamicemb): traverse nn.Module children in check_emb_collection_modules#355JacoCheung wants to merge 4 commits intoNVIDIA:mainfrom
Conversation
Greptile SummaryThis PR fixes Confidence Score: 5/5Safe to merge — the traversal fix is correct, the visited-set prevents duplicates, all callers are updated, and the test rewrite actively verifies the counter round-trip. All findings are P2 or lower. Prior P1 (inconsistent return values) was addressed in this PR. The core fix, test additions, and helper consolidation are logically sound with no remaining blocking issues. No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["get_dynamic_emb_module(model)"] --> B["check_emb_collection_modules(module, ret_list, visited)"]
B --> C{"id in visited?"}
C -->|yes| D["return"]
C -->|no| E["visited.add(id)"]
E --> F{"BatchedDynamicEmbeddingTablesV2?"}
F -->|yes| G["ret_list.append(module) then return"]
F -->|no| H{"nn.Module?"}
H -->|yes| I["Traverse _lookups, _emb_modules, _emb_module"]
I --> J["Recurse into each item"]
H -->|yes| K["module.children() - DMP/DDP/Float16Module"]
K --> L["Recurse into each child"]
J --> B
L --> B
Reviews (10): Last reviewed commit: "fix(dynamicemb/test): rewrite check_coun..." | Re-trigger Greptile |
|
Tip: Greploop — Automatically fix all review issues by running Use the Greptile plugin for Claude Code to query reviews, search comments, and manage custom context directly from your terminal. |
5aae19f to
14cb873
Compare
7b79801 to
ad6f688
Compare
|
/build |
4 similar comments
|
/build |
|
/build |
|
/build |
|
/build |
|
❌ Pipeline #48604221 -- failed
Result: 8/14 jobs passed |
|
/build |
ad6f688 to
4356c1a
Compare
|
/build |
|
❔ Pipeline #48747674 -- canceling
|
|
/build |
|
❌ Pipeline #48748868 -- failed
Result: 10/14 jobs passed |
…d MultiTableKVCounter After the admission-counter fusion refactor (NVIDIA#343), `_admission_counter` became a single MultiTableKVCounter (or None) instead of a list of per-table KVCounters. The verification helper still iterated it as a list, so the dump/load counter check was silently no-op on any branch whose `get_dynamic_emb_module` could traverse into DMP wrappers and raised `TypeError: 'MultiTableKVCounter' object is not iterable` on any branch that could. PR NVIDIA#355's `children()` traversal surfaced the latter. Rewrite the check against the current API: - Iterate logical tables via `range(len(table._table_names))`. - Export one table_id's (keys, frequency) with `cnt.table_._batched_export_keys_scores([freq_name], device, table_id)`. - Look those keys up in the peer counter via `cnt_y.table_.lookup(keys, table_ids, score_arg)` and assert both `founds.all()` and that the returned score_out matches the exported frequency. - Handle the None (no-admission) case symmetrically. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
/build |
…modules The function only checked direct attributes, missing EmbeddingCollection wrapped inside DistributedModelParallel or other nn.Module containers. Now recursively walks module.children() to find the embedding submodule. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
check_emb_collection_modules now traverses nn.Module.children() and guards against cycles, so the examples-side duplicate find_dynamicemb_modules is redundant. Route both call sites (cache stats hook + FILL_DYNAMICEMB_TABLES fill path) through dynamicemb's public get_dynamic_emb_module and delete commons/utils/dynamicemb_utils.py. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…d MultiTableKVCounter After the admission-counter fusion refactor (NVIDIA#343), `_admission_counter` became a single MultiTableKVCounter (or None) instead of a list of per-table KVCounters. The verification helper still iterated it as a list, so the dump/load counter check was silently no-op on any branch whose `get_dynamic_emb_module` could traverse into DMP wrappers and raised `TypeError: 'MultiTableKVCounter' object is not iterable` on any branch that could. PR NVIDIA#355's `children()` traversal surfaced the latter. Rewrite the check against the current API: - Iterate logical tables via `range(len(table._table_names))`. - Export one table_id's (keys, frequency) with `cnt.table_._batched_export_keys_scores([freq_name], device, table_id)`. - Look those keys up in the peer counter via `cnt_y.table_.lookup(keys, table_ids, score_arg)` and assert both `founds.all()` and that the returned score_out matches the exported frequency. - Handle the None (no-admission) case symmetrically. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
❔ Pipeline #48764104 -- canceling
|
02b4a92 to
b096feb
Compare
|
/build |
|
❌ Pipeline #48764351 -- failed
Result: 12/14 jobs passed |
Summary
check_emb_collection_modulesonly followed TorchRec's private attributes (_lookups,_emb_modules,_emb_module) but never recursed into standardnn.Module.children(), so it could not discoverBatchedDynamicEmbeddingTablesV2behind wrapper layers (DMP / DDP / Float16Module).children()traversal and avisitedset to prevent re-entry on circular references and duplicateret_listentries.None; callers mutateret_listin place (addresses greptile P2).visitedparameter defaults toNone, so all existing callers work unchanged.commons.utils.dynamicemb_utils.find_dynamicemb_modules; both call sites (dynamicemb_cache_stats,pretrain_gr_ranking'sFILL_DYNAMICEMB_TABLESpath) now use the publicdynamicemb.dump_load.get_dynamic_emb_module.check_counter_table_checkpointintest_embedding_dump_load.pyagainst the fusedMultiTableKVCounterAPI. The old helper iterated_admission_counteras a list of per-tableKVCounters, but since the fusion refactor ([Feature] dynamicemb table fusion and expansion #343, commit 97b80bf) it is a single counter object. The test silently no-op'd on branches whereget_dynamic_emb_modulecouldn't traverse DMP wrappers, and raisedTypeError: 'MultiTableKVCounter' object is not iterableon branches where it could. Now per-logical-table: export (keys, frequency) viacnt.table_._batched_export_keys_scores([freq_name], device, table_id), look up viacnt.table_.lookup(keys, table_ids, score_arg), assertfounds.all()andtorch.equal(frequencies, score_out); handle the None (no-admission) case symmetrically.Closes #353
Test plan
assert_get_dynamic_emb_module_finds_submodulesintest_embedding_dump_load.pyasserts both traversal paths (viafind_sharded_modulesvs directly on the DMP wrapper) return the same set ofBatchedDynamicEmbeddingTablesV2modulesdynamicemb_test_load_dump_8gpus(which exercisescheck_counter_table_checkpoint) — the rewritten helper now actively verifies dump/load counter round-trip instead of silently iterating zero tablestest_embedding_dump_load.py,test_alignment.py,test_lfu_scores.py,test_embedding_admission.pypass without changesget_dynamic_emb_modulefinds modules when called on a DMP-wrapped model (not just a pre-extractedShardedEmbeddingCollection)🤖 Generated with Claude Code