Skip to content

Move register_canonicalize to graph.rewriting.utils, Adjust function signature, and enhance AttributeError handling #613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

OmGhadge
Copy link

@OmGhadge OmGhadge commented Jan 23, 2024

Description

1.Function Relocation:

The register_canonicalize function has been relocated from tensor.rewriting.basic to graph.rewriting.utils. The move encompasses all necessary imports to ensure seamless functionality in the new location.However, during this transition, issues were identified specifically related to implementation which I tried to fix and are described below.

2.Type Mismatch Error Resolution:

The function signature of register function inside register_canonicalize was causing a type mismatch error. This was addressed by changing the input type to Union[RewriteDatabase, NodeRewriter] from Union[RewriteDatabase, Rewriter]
Before:

    if isinstance(node_rewriter, str):

        def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
            return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)

        return register

Error:

pytensor\graph\rewriting\utils.py:251: error: Argument 1 to "register_canonicalize" has incompatible type "RewriteDatabase | 
Rewriter"; expected "RewriteDatabase | NodeRewriter | str"

After:

  if isinstance(node_rewriter, str):
        def register(inner_rewriter: Union[RewriteDatabase, NodeRewriter]):
            return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)
        return register

3.AttributeError Fix:

In register_canonicalize, there were errors related to missing __name__ attributes for node_rewriter:Union[RewriteDatabase, NodeRewriter, str]
This was resolved by using getattr() to handle cases where __name__ is not present .In cases where the attribute is not available, Name=None. (We can think of implementing default name )
before:
name = kwargs.pop("name", None) or node_rewriter.__name__

Error:

pytensor\graph\rewriting\utils.py:255: error: Item "RewriteDatabase" of "RewriteDatabase | NodeRewriter" has no attribute "__name__"
pytensor\graph\rewriting\utils.py:255: error: Item "NodeRewriter" of "RewriteDatabase | NodeRewriter" has no attribute "__name__"

after:
name = kwargs.pop("name", None) or getattr(node_rewriter, "__name__", None)

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • refactor

Sorry, something went wrong.

@OmGhadge
Copy link
Author

Hi @ricardoV94 ,

I've moved the register_canonicalize() from tensor.rewriting.basic to graph.rewriting.utils. While doing so, I encountered some errors in the implementation of the function, which I've attempted to address in this PR. Could you please review and confirm if the changes are valid?

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Move register_canonicalize and similar helpers to rewriting module
1 participant