Fixed graph break serialization bug #4360

Open
cehongwang wants to merge 1 commit into main from graph-break-serialization

@cehongwang cehongwang commented Jun 23, 2026

Collaborator

Fix: inline_torch_modules topological-ordering bug in _exporter.py

Context

When you compile with torch_executed_ops={torch.ops.aten.relu.default}, the
partitioner splits the model into alternating submodules:

  • _run_on_acc_* → run in TensorRT engines
  • _run_on_gpu_* → run in PyTorch (the relu fallbacks)

At save time (output_format="exported_program"), transform() flattens these
submodules into a single graph in two passes:

  1. inline_trt_modules — replaces each _run_on_acc_* with an execute_engine
    node + getitem output node(s).
  2. inline_torch_modules — copies each _run_on_gpu_* submodule's body (the
    relu) into the main graph.

Then eliminate_dead_code() runs lint(), which enforces that every node's
inputs are defined before it.
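
The ordering rule that lint() enforces can be illustrated with a minimal, stdlib-only sketch. This is a toy model and not the torch.fx implementation: the Node type and check_topological_order function are hypothetical names introduced here, and the error text simply mirrors the message quoted in this description.

```python
from typing import NamedTuple, Tuple


class Node(NamedTuple):
    name: str
    inputs: Tuple[str, ...]  # names of the nodes this node reads from


def check_topological_order(nodes):
    """Raise if any node reads an input that has not been defined yet."""
    defined = set()
    for node in nodes:
        for arg in node.inputs:
            if arg not in defined:
                raise RuntimeError(
                    f"Argument '{arg}' of Node '{node.name}' "
                    "was used before it has been defined!"
                )
        defined.add(node.name)


# Correctly ordered: relu reads the node directly above it.
good = [Node("x", ()), Node("getitem_8", ("x",)), Node("relu", ("getitem_8",))]
# Mis-wired: relu reads a node that only appears further down.
bad = [Node("x", ()), Node("relu", ("getitem",)), Node("getitem", ("x",))]

check_topological_order(good)  # passes silently

try:
    check_topological_order(bad)
    lint_error = None
except RuntimeError as err:
    lint_error = str(err)
```

The check is a single forward walk with a "defined so far" set, which is why one mis-wired argument is enough to fail the whole graph.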

The cause

lint() failed with:

Argument 'getitem' of Node 'relu' was used before it has been defined!

From the dumped graph:

%getitem_8 = getitem(execute_engine_default, 0)    # engine 0 output
%relu      = relu(%getitem)                         # WRONG input
...
%getitem   = getitem(execute_engine_default_1, 0)   # engine 1 output, defined LATER

relu should consume %getitem_8 (engine 0, immediately above it), but it was
wired to %getitem — a completely different node defined further down — hence
"used before defined."

The reason is how inline_torch_modules matched submodule inputs to graph nodes.
It used a helper, get_duplicate_nodes, that paired each fallback submodule's
input placeholders to nodes in the main graph by matching node names:

# old, buggy approach
submodule_duplicate_inputs = [ph for ph in submodule_placeholders if ph.name in gm_node_names]
gm_duplicate_inputs        = [n  for n  in gm.graph.nodes        if n.name  in submodule_input_node_names]
val_map = {sub_ph: gm_node_with_same_name ...}

That assumes node names are stable, unique identifiers — which they are not
after the first pass:

  • The partitioner had already created getitem, getitem_1, … for the
    multi-output engines.
  • inline_trt_modules then created new getitem nodes for the single-output
    engines. Since getitem was taken, FX auto-renamed them getitem_8,
    getitem_9, …
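
The renaming described in the two bullets above can be sketched with a toy name allocator. This is an assumption-laden illustration, not FX's actual naming logic: ToyNamespace is a hypothetical class, and the count of eight partitioner-created getitem nodes is chosen only so the output matches the getitem_8 and getitem_9 names mentioned here.

```python
class ToyNamespace:
    """Hands out unique names by appending the first free numeric suffix."""

    def __init__(self):
        self._used = set()

    def create_name(self, candidate):
        name = candidate
        suffix = 0
        # Keep bumping the suffix until the name is free.
        while name in self._used:
            suffix += 1
            name = f"{candidate}_{suffix}"
        self._used.add(name)
        return name


ns = ToyNamespace()

# Pass 0 (partitioner): assumed to have minted eight getitem nodes already.
partitioner_names = [ns.create_name("getitem") for _ in range(8)]

# Pass 1 (inlining the engines): asks for "getitem" again and gets renamed.
inlined_names = [ns.create_name("getitem") for _ in range(2)]
```

Both passes request the same base name, so the final name of a node depends on how many earlier requests there were, not on which tensor it carries.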

So the relu submodule's input placeholder (named getitem, after its original
producer) was name-matched to the unrelated getitem belonging to a later engine
instead of the engine-0 output that actually feeds it. The result is a relu
wired to a node defined later, which is a topological violation. Even in cases
where the mismatched node happens to be defined earlier and lint() passes, the
graph would silently produce wrong results by reading the wrong tensor.
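
A small stdlib-only sketch of the failure mode, using plain lists of node names rather than real FX graphs. The node names follow the dumped graph; the call-site name _run_on_gpu_0 and the defined_before helper are hypothetical and exist only for this illustration.

```python
# Toy main graph after the engines were inlined, listed in graph order.
main_graph = [
    "execute_engine_default",
    "getitem_8",                 # engine 0 output (renamed by FX)
    "_run_on_gpu_0",             # fallback call site; really fed by getitem_8
    "execute_engine_default_1",
    "getitem",                   # engine 1 output, unrelated to the relu
]
call_site = "_run_on_gpu_0"

# The fallback submodule's placeholder still carries its producer's old name.
submodule_placeholders = ["getitem"]

# Name-based pairing: pick whichever main-graph node shares the name.
name_based = {
    ph: node
    for ph in submodule_placeholders
    for node in main_graph
    if node == ph
}


def defined_before(node_name, user_name):
    """True if node_name appears earlier in the graph than user_name."""
    return main_graph.index(node_name) < main_graph.index(user_name)


picked = name_based["getitem"]
picked_is_valid = defined_before(picked, call_site)
```

Here the name lookup lands on the engine-1 output, which sits after the call site, so the pairing is invalid even though the names match exactly.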

The fix

Wire the submodule body to the actual call_module arguments positionally,
ignoring names entirely. The _run_on_gpu_* call node's args are the real
producer nodes, in order, and are already topologically before the call site:

# new approach
submodule_inputs = gm_node.args                              # real producers, in order
submodule_placeholders = [n for n in submodule.graph.nodes if n.op == "placeholder"]
val_map = dict(zip(submodule_placeholders, submodule_inputs))
submodule_output = gm.graph.graph_copy(submodule.graph, val_map)

Because val_map is pre-seeded with every placeholder → its real input,
graph_copy skips the placeholders and rewrites each copied body node's args to
the correct producers. The producers already exist earlier in the graph and the
body is copied inside an inserting_before(gm_node) context, so every copied node
lands after its inputs and the ordering is correct by construction.
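
The positional mapping can be sketched without torch as follows. This is a toy model of the idea, not graph_copy itself: nodes are plain dicts, copy_body is a hypothetical stand-in for the copy step, and _run_on_gpu_0 is an assumed call-site name.

```python
def copy_body(sub_nodes, val_map):
    """Copy non-placeholder nodes, rewriting args through val_map.

    Returns (copied_nodes, name_of_output_value).
    """
    val_map = dict(val_map)  # do not mutate the caller's mapping
    copied = []
    for node in sub_nodes:
        if node["name"] in val_map:
            continue  # pre-seeded placeholder: nothing to copy
        if node["op"] == "output":
            return copied, val_map[node["args"][0]]
        new_node = {
            "op": node["op"],
            "name": node["name"],
            "args": [val_map[a] for a in node["args"]],
        }
        copied.append(new_node)
        val_map[node["name"]] = new_node["name"]
    return copied, None


# Fallback submodule: placeholder -> relu -> output.
submodule = [
    {"op": "placeholder", "name": "getitem", "args": []},
    {"op": "call_function", "name": "relu", "args": ["getitem"]},
    {"op": "output", "name": "output", "args": ["relu"]},
]
# The call site's args are the real producers, in positional order.
call_node = {"op": "call_module", "name": "_run_on_gpu_0", "args": ["getitem_8"]}

placeholders = [n["name"] for n in submodule if n["op"] == "placeholder"]
val_map = dict(zip(placeholders, call_node["args"]))  # pair by position

copied, output_value = copy_body(submodule, val_map)
```

The placeholder's name (getitem) is never looked up in the main graph. It is only a key into val_map, so a name collision with another node cannot affect the wiring.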

I also deleted the now-unused, name-based get_duplicate_nodes helper and the
placeholder-cleanup branch it required (with positional mapping, no stray
placeholder nodes are ever created).

Why it's robust: gm_node.args is the ground truth of "what feeds this
submodule" — it survives any renaming the earlier passes do, whereas name
matching is coincidental and breaks exactly when two passes both mint getitem
nodes.

Verification

The repro now finishes with max diff: 0.0 (loaded ExportedProgram is
numerically identical to the in-memory module) and no lint/ordering errors.

@meta-cla meta-cla Bot added the cla signed label Jun 23, 2026
@github-actions github-actions Bot added the component: core, component: api [Python], and component: dynamo labels Jun 23, 2026