[Relax] Use optional dtype for absent Relax dtype fields#19890
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors TVM Relax's TensorTypeNode to represent unknown data types using ffi::Optionaltvm::PrimType initialized to std::nullopt instead of an opaque DLDataType placeholder. This change is propagated across various operators, type inference helpers, and transformations. The code review identified several critical issues where calling .value() on the newly introduced dtype optionals could cause crashes when the data type is unknown, specifically in utils.cc, qdq.cc (quantize/dequantize), and to_mixed_precision.cc. Additionally, a minor improvement was suggested in op_common.h to print a cleaner concrete dtype in error messages.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if (const auto* tensor = ty.as<TensorTypeNode>()) { | ||
| dtype = tensor->dtype->dtype; | ||
| dtype = tensor->dtype.value()->dtype; | ||
| ndim = tensor->ndim; |
There was a problem hiding this comment.
When permit_unknown_dtype is true and tensor->IsUnknownDtype() is true, calling tensor->dtype.value() will throw an exception or crash due to accessing an empty optional. We should safely handle the unknown dtype case when permit_unknown_dtype is enabled.
if (const auto* tensor = ty.as<TensorTypeNode>()) {
if (tensor->IsUnknownDtype()) {
if (!permit_unknown_dtype) {
return false;
}
dtype = DLDataType{kDLOpaqueHandle, 0, 0};
} else {
dtype = tensor->dtype.value()->dtype;
}
ndim = tensor->ndim;| PrimType input_dtype = input_ty->dtype.value(); | ||
| PrimType scale_dtype = scale_ty->dtype.value(); | ||
| PrimType zp_dtype = zp_ty->dtype.value(); |
There was a problem hiding this comment.
If any of the input tensors (input_ty, scale_ty, or zp_ty) has an unknown dtype, calling .value() directly on their dtype fields will crash. Since type inference can be run on expressions with unknown dtypes, we should check IsUnknownDtype() and return a TensorType with an unknown dtype instead of crashing.
| PrimType input_dtype = input_ty->dtype.value(); | |
| PrimType scale_dtype = scale_ty->dtype.value(); | |
| PrimType zp_dtype = zp_ty->dtype.value(); | |
| if (input_ty->IsUnknownDtype() || scale_ty->IsUnknownDtype() || zp_ty->IsUnknownDtype()) { | |
| if (input_ty->shape.defined()) { | |
| return TensorType(input_ty->shape.value(), std::nullopt, input_ty->vdevice); | |
| } else { | |
| return TensorType(std::nullopt, input_ty->ndim, input_ty->vdevice); | |
| } | |
| } | |
| PrimType input_dtype = input_ty->dtype.value(); | |
| PrimType scale_dtype = scale_ty->dtype.value(); | |
| PrimType zp_dtype = zp_ty->dtype.value(); |
| PrimType input_dtype = input_ty->dtype.value(); | ||
| PrimType scale_dtype = scale_ty->dtype.value(); | ||
| PrimType zp_dtype = zp_ty->dtype.value(); |
There was a problem hiding this comment.
If any of the input tensors (input_ty, scale_ty, or zp_ty) has an unknown dtype, calling .value() directly on their dtype fields will crash. Since type inference can be run on expressions with unknown dtypes, we should check IsUnknownDtype() and return a TensorType with an unknown dtype instead of crashing.
if (input_ty->IsUnknownDtype() || scale_ty->IsUnknownDtype() || zp_ty->IsUnknownDtype()) {
if (input_ty->shape.defined()) {
return TensorType(input_ty->shape.value(), std::nullopt, input_ty->vdevice);
} else {
return TensorType(std::nullopt, input_ty->ndim, input_ty->vdevice);
}
}
PrimType input_dtype = input_ty->dtype.value();
PrimType scale_dtype = scale_ty->dtype.value();
PrimType zp_dtype = zp_ty->dtype.value();| // We only rewrite the expr if the dtype is fp16 or fp32, dtypes such as int32, float64 is not | ||
| // supported to be rewritten | ||
| DLDataType tensor_dtype = tensor->dtype->dtype; | ||
| DLDataType tensor_dtype = tensor->dtype.value()->dtype; |
There was a problem hiding this comment.
If tensor->IsUnknownDtype() is true, calling tensor->dtype.value() will crash. We should check if the tensor has an unknown dtype and return the original expression early.
| DLDataType tensor_dtype = tensor->dtype.value()->dtype; | |
| if (tensor->IsUnknownDtype()) return expr; | |
| DLDataType tensor_dtype = tensor->dtype.value()->dtype; |
| TVM_FFI_VISIT_THROW(TypeError, call) | ||
| << call->op | ||
| << " requires the input tensor to have float dtype. However, the given input dtype is " | ||
| << input_ty->dtype; |
There was a problem hiding this comment.
f58a065 to
6bc5adf
Compare
This PR replaces Relax uses of void dtype sentinels for semantically absent dtypes with explicit optional/null representations.
Summary:
TensorTypedtype optional instead of represented asvoid.R.memory.view(..., dtype=None)withR.null_value()at the expression boundary, and lower to a concrete runtime dtype when required.R.memory.view(..., dtype=R.dtype("void"))in inference instead of treating the legacy void spelling as absence.