Changed weight map to tensor and fixed the refit bug #3573
base: main
Conversation
# Used for refit
- ctx.weight_refit_map[name + " CONSTANT"] = numpy_value.reshape(-1)
+ ctx.weight_refit_map[name + " CONSTANT"] = torch_value
Add a comment explaining why the " CONSTANT" suffix is appended here.
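(For context, a minimal sketch of the naming convention as I read it; split_refit_key is a hypothetical helper, not part of this PR. TensorRT identifies a refittable weight by a layer name plus a trt.WeightsRole, and the refit-map key appears to flatten that pair into one string, hence the suffix.)

import tensorrt as trt

# Hypothetical helper: recover (layer name, role) from a refit-map key such as
# "my_const_layer CONSTANT". The trailing word encodes the WeightsRole.
def split_refit_key(key: str) -> tuple[str, trt.WeightsRole]:
    layer_name, role = key.rsplit(" ", 1)
    return layer_name, getattr(trt.WeightsRole, role)

print(split_refit_key("my_const_layer CONSTANT"))  # ('my_const_layer', WeightsRole.CONSTANT)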
Force-pushed from d7c6735 to cf064c5
@@ -321,7 +321,15 @@ def cast_int_or_float_to_bool(


def to_trt_weights(
I think we can streamline some of these arguments: why do we need target, name, layer name, and weight name? Can we derive some of them from the others?
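(One possible reading of this suggestion, sketched under assumptions: if the call site already holds the TensorRT layer, then layer.name and layer.type subsume the separate name/layer-type strings, leaving only the weight role to pass explicitly. The signature below is illustrative, not the PR's actual API.)

from typing import Literal
import tensorrt as trt
import torch

def to_trt_weights(
    ctx,                 # conversion context, as in the PR
    value: torch.Tensor,
    layer: trt.ILayer,   # layer.name and layer.type replace the string arguments
    weight_type_name: Literal["KERNEL", "BIAS", "CONSTANT"],
) -> trt.Weights:
    ...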
Force-pushed from 41d1248 to 2520a68
cpu_weights_reference_holder: dict[str, Union[torch.Tensor]] = field(
    default_factory=dict
)

def record_weight(self, name: str, weight: torch.Tensor) -> None:
    self.weight_refit_map[name] = weight
Add a docstring explaining why we are doing this, especially the line self.cpu_weights_reference_holder[name + " CPU_REFERENCE"] = weight.
The name " CPU_REFERENCE" is a bit random. Any name could work because all we need is to hold it on CPU. Moreover, since we have refit map, this is actually a bit redundant.
Do we need a suffix here? It's just holding references; it should never be inspected later. I don't think it even needs to be a dictionary.
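(A minimal sketch of what these three comments add up to, assuming the holder exists only to keep the CPU buffers alive until refit: a plain list, no suffix, and a docstring on record_weight. Not the PR's final code.)

from dataclasses import dataclass, field
import torch

@dataclass
class ConversionContext:  # trimmed to the relevant fields
    weight_refit_map: dict[str, torch.Tensor] = field(default_factory=dict)
    # A list suffices: the tensors are never looked up by name, only kept alive.
    cpu_weights_reference_holder: list[torch.Tensor] = field(default_factory=list)

    def record_weight(self, name: str, weight: torch.Tensor) -> None:
        """Record a weight for refit and pin its CPU copy.

        TensorRT keeps only a pointer to the weight data, so the CPU tensor
        must stay referenced on the Python side until the engine is built or
        refit; otherwise the engine would read freed memory.
        """
        self.weight_refit_map[name] = weight
        self.cpu_weights_reference_holder.append(weight)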
supported_weight_types = ["KERNEL", "BIAS", "CONSTANT"]
assert (
    layer_type_name in supported_layer_types
), f"Unsupported layer type: {layer_type_name}. Please add the layer type to this function to enable refitting."
"Please add the layer type to this function to enable refitting" - what does this mean? How do we add it?
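(My guess at what the message intends, sketched with assumed list contents: "adding the layer type" would mean extending this allow-list and then handling the new type's weights in the body of to_trt_weights.)

# Assumed current allow-list; the real contents live in to_trt_weights.
supported_layer_types = ["CONVOLUTION", "CONSTANT"]

# Hypothetical extension: permit refitting deconvolution weights as well.
supported_layer_types.append("DECONVOLUTION")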
ctx: ConversionContext,
value: torch.Tensor,
name: str,
layer_type_name: str,
We should use Literal type annotations for these instead of plain str.
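(A short sketch of the suggestion: typing.Literal pins the accepted strings so type checkers reject typos before the runtime assert fires. The weight-type values mirror the list in this diff; the layer-type values are assumed.)

from typing import Literal

WeightTypeName = Literal["KERNEL", "BIAS", "CONSTANT"]
LayerTypeName = Literal["CONVOLUTION", "CONSTANT"]  # assumed; mirror the real allow-list

def describe(weight_type_name: WeightTypeName) -> str:
    # mypy/pyright flag describe("KERNAL") as an error at check time.
    return f"weight role: {weight_type_name}"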
    weight_type_name in supported_weight_types
), f"Encountered unsupported weight type: {weight_type_name}. Supported types are: {supported_weight_types}. Manually calling to_trt_weights with a custom weight type is not intended for general use."

if weight_type_name == "CONSTANT" and layer_type_name == "CONSTANT":
What is the difference between a weight type and a layer type?
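(For reference, the distinction as TensorRT's own enums draw it: the layer type says what the layer is, while the weight type/role says which parameter slot of that layer the bytes fill; one layer can own several roles.)

import tensorrt as trt

# Layer type: the kind of layer (trt.LayerType).
# Weight role: which of that layer's parameters is meant (trt.WeightsRole).
# A convolution layer, for instance, owns both a KERNEL and a BIAS:
print(trt.LayerType.CONVOLUTION)   # the layer's type
print(trt.WeightsRole.KERNEL)      # its filter weights
print(trt.WeightsRole.BIAS)        # its bias weights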
Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: