Fixed graph break serialization bug #4360
Open
cehongwang wants to merge 1 commit into
Fix: `inline_torch_modules` topological-ordering bug in `_exporter.py`
Context
When you compile with `torch_executed_ops={torch.ops.aten.relu.default}`, the partitioner splits the model into alternating submodules:
- `_run_on_acc_*` → run in TensorRT engines
- `_run_on_gpu_*` → run in PyTorch (the relu fallbacks)

At save time (`output_format="exported_program"`), `transform()` flattens these submodules into a single graph in two passes:
1. `inline_trt_modules`: replaces each `_run_on_acc_*` with an `execute_engine` node + `getitem` output node(s).
2. `inline_torch_modules`: copies each `_run_on_gpu_*` submodule's body (the relu) into the main graph.

Then `eliminate_dead_code()` runs `lint()`, which enforces that every node's inputs are defined before it.
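To make the ordering rule concrete, here is a minimal pure-Python sketch of a "defined before use" check. It is a toy model, not the `torch.fx` implementation: the `Node` class and this `lint` function are hypothetical stand-ins, and the node names are chosen only to mirror the ones discussed in this PR.

```python
from dataclasses import dataclass


@dataclass
class Node:
    name: str
    args: tuple = ()  # names of the nodes this node reads


def lint(nodes):
    """Raise if any node reads a value that is not defined earlier in the list."""
    defined = set()
    for node in nodes:
        for arg in node.args:
            if arg not in defined:
                raise RuntimeError(f"{node.name} uses {arg} before it is defined")
        defined.add(node.name)


# relu reads the node directly above it: valid ordering.
good = [Node("x"), Node("getitem_8", ("x",)),
        Node("relu", ("getitem_8",)), Node("getitem", ("relu",))]
# relu reads a node that only appears further down: invalid ordering.
bad = [Node("x"), Node("getitem_8", ("x",)),
       Node("relu", ("getitem",)), Node("getitem", ("relu",))]

lint(good)
try:
    lint(bad)
    lint_error = None
except RuntimeError as exc:
    lint_error = str(exc)
```

The second graph fails because `relu` appears before the node it reads, which is the same shape of failure described in the next section.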
The cause
`lint()` failed with a "used before defined" error. In the dumped graph, `relu` should consume `%getitem_8` (engine 0, immediately above it), but it was wired to `%getitem`, a completely different node defined further down.
The reason is how `inline_torch_modules` matched submodule inputs to graph nodes. It used a helper, `get_duplicate_nodes`, that paired each fallback submodule's input placeholders to nodes in the main graph by matching node names.
That assumes node names are stable, unique identifiers, which they are not after the first pass:
- The names `getitem`, `getitem_1`, … were already in use for the multi-output engines.
- `inline_trt_modules` then created new `getitem` nodes for the single-output engines. Since `getitem` was taken, FX auto-renamed them `getitem_8`, `getitem_9`, …

So the relu submodule's input placeholder (named `getitem`, after its original producer) name-matched to the unrelated `getitem` belonging to a later engine, instead of the engine-0 output it was actually fed by. Result: relu wired to a node defined later → topological violation. (And even when it didn't crash, it would silently produce wrong results by reading the wrong tensor.)
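The collision can be reproduced in a small pure-Python sketch. Everything here is an assumption made for illustration: `unique_name` only mimics the base, base_1, base_2 deduplication described above, and the graph is a plain list rather than a real FX graph (so the renamed node comes out as `getitem_2` here, not `getitem_8`).

```python
def unique_name(base, taken):
    """Mimic name deduplication: base, base_1, base_2, ... (toy version)."""
    if base not in taken:
        return base
    i = 1
    while f"{base}_{i}" in taken:
        i += 1
    return f"{base}_{i}"


# Names already used by a later, multi-output engine.
taken = {"getitem", "getitem_1"}

# The first pass mints a getitem for single-output engine 0; the name is taken.
engine0_out = unique_name("getitem", taken)
taken.add(engine0_out)

# Main graph after the first pass, in topological order: (name, what it really is).
main_graph = [
    (engine0_out, "output of engine 0"),
    ("_run_on_gpu_0", "relu fallback, fed by engine 0"),
    ("getitem", "output 0 of engine 1"),
    ("getitem_1", "output 1 of engine 1"),
]
by_name = dict(main_graph)

# The fallback's placeholder still carries its producer's original name.
placeholder_name = "getitem"
# The call node's argument is the real producer.
actual_call_arg = engine0_out

name_matched = by_name[placeholder_name]  # what name matching picks
positional = by_name[actual_call_arg]     # what the call argument says
```

Name matching selects the output of engine 1, which sits after the fallback in the graph, while the call argument points at the engine 0 output that actually feeds it.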
The fix
Wire the submodule body to the actual `call_module` arguments positionally, ignoring names entirely. The `_run_on_gpu_*` call node's `args` are the real producer nodes, in order, and are already topologically before the call site.
By pre-seeding `val_map` with every placeholder → its real input, `graph_copy` skips copying placeholders and rewrites each copied body node's args to the correct producers. Since the producers already exist earlier in the graph and the body is copied under `inserting_before(gm_node)`, ordering is guaranteed correct.

I also deleted the now-unused, name-based `get_duplicate_nodes` helper and the placeholder-cleanup branch it required (with positional mapping, no stray placeholder nodes are ever created).
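The positional approach can be sketched in plain Python as follows. This is a toy model of the idea, not the code in this PR: `Node`, `copy_body`, and `inline_call` are hypothetical names, nodes refer to each other by string name, and copied body nodes are assumed not to collide with existing names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    op: str           # "placeholder", "call", or "output"
    name: str
    args: tuple = ()  # names of the nodes this node reads


def copy_body(sub_nodes, call_args):
    """Copy a submodule body, binding placeholders to call_args by position."""
    placeholders = [n.name for n in sub_nodes if n.op == "placeholder"]
    if len(placeholders) != len(call_args):
        raise ValueError("placeholder count does not match call arguments")
    # Pre-seed the map: i-th placeholder -> i-th real producer.
    # Placeholder names are never compared with main-graph names.
    val_map = dict(zip(placeholders, call_args))
    copied = []
    for node in sub_nodes:
        if node.op == "placeholder":
            continue  # already mapped, so no stray placeholder is copied
        remapped = tuple(val_map[a] for a in node.args)
        if node.op == "output":
            return copied, remapped[0]
        copied.append(Node(node.op, node.name, remapped))
        val_map[node.name] = node.name
    raise ValueError("submodule has no output node")


def inline_call(main_nodes, call_name, sub_nodes):
    """Replace the call node with the copied body, inserted at the call site."""
    index = next(i for i, n in enumerate(main_nodes) if n.name == call_name)
    copied, result = copy_body(sub_nodes, main_nodes[index].args)
    users = [
        Node(n.op, n.name, tuple(result if a == call_name else a for a in n.args))
        for n in main_nodes[index + 1:]
    ]
    return main_nodes[:index] + copied + users


# The placeholder is named "getitem", which collides with a later main-graph node.
sub = [Node("placeholder", "getitem"),
       Node("call", "relu", ("getitem",)),
       Node("output", "output", ("relu",))]
main = [Node("call", "getitem_8"),
        Node("call", "_run_on_gpu_0", ("getitem_8",)),
        Node("call", "engine_1", ("_run_on_gpu_0",)),
        Node("call", "getitem", ("engine_1",))]

flat = inline_call(main, "_run_on_gpu_0", sub)
```

Even though the placeholder shares its name with a node further down the main graph, `relu` ends up reading `getitem_8`, because only the position of the call argument is consulted.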
Why it's robust: `gm_node.args` is the ground truth of "what feeds this submodule". It survives any renaming the earlier passes do, whereas name matching is coincidental and breaks exactly when two passes both mint `getitem` nodes.
Verification
The repro now finishes with `max diff: 0.0` (loaded ExportedProgram is numerically identical to the in-memory module) and no lint/ordering errors.