
WIP Add table summarizing classification weights accuracies in docs #5741


Closed
wants to merge 4 commits into from
1 change: 1 addition & 0 deletions .circleci/unittest/linux/scripts/install.sh
@@ -38,3 +38,4 @@ fi

printf "* Installing torchvision\n"
python setup.py develop
pip install tabulate
2 changes: 1 addition & 1 deletion docs/Makefile
@@ -6,7 +6,7 @@ ifneq ($(EXAMPLES_PATTERN),)
endif

# You can set these variables from the command line.
SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS)
SPHINXOPTS = -j auto $(EXAMPLES_PATTERN_OPTS)
SPHINXBUILD = sphinx-build
SPHINXPROJ = torchvision
SOURCEDIR = source
31 changes: 31 additions & 0 deletions docs/source/conf.py
@@ -292,5 +292,36 @@ def inject_minigalleries(app, what, name, obj, options, lines):
        lines.append("\n")


def generate_table():

    import torchvision.models as M
    from tabulate import tabulate
    import textwrap

    # TODO: this is ugly af and incorrect. We'll need an automatic way to
    # retrieve weight enums for each section, or manually list them.
Comment on lines +301 to +302
Member Author
Any thoughts on that, @datumbox?

Contributor

Yes, you are right. We need a registration mechanism for the models. The new Datasets API has one, so part of the reason I didn't want to invest time creating another is to potentially adopt/extend the one from Datasets. Thoughts?

Member Author

The registration of the datasets is very basic: it's just a decorator that adds the callable/object to a private dict. It would probably make sense to use something similar for the models/weights. Whether we should rely on the same utils, though, is up for discussion - as a first version I'd suggest not merging the two and giving the models a separate implementation. The code is really basic anyway.

Contributor

I'm happy with what you propose and with the technical details you mentioned. I had a similar simple approach in the original proposal of the multi-weight support, but I didn't port it to adopt a solution in common with Datasets. The code doesn't have to be the same, but I think the interface can be basic and similar. As discussed offline, the only thing different for models is that there is a hierarchy (Detection, Optical Flow, Classification, etc.) and this needs to be taken into account because names conflict across modules (for example, resnet50 exists in both the Classification and Quantization submodules).

    weight_enums = [getattr(M, name) for name in dir(M) if name.endswith("Weights")]
    weights = [w for weight_enum in weight_enums for w in weight_enum if "acc@1" in w.meta]

    column_names = ("**Weight**", "**Acc@1**", "**Acc@5**", "**Params**", "**Recipe**")
    content = [
        (str(w), w.meta["acc@1"], w.meta["acc@5"], f"{w.meta['num_params']:e}", f"`link <{w.meta['recipe']}>`__")
        for w in weights
    ]
    table = tabulate(content, headers=column_names, tablefmt="rst")
    print(table)

    from pathlib import Path

    generated_dir = Path("generated")
    generated_dir.mkdir(exist_ok=True)
    with open(generated_dir / "classification_table.rst", "w+") as table_file:
        table_file.write(".. table::\n")
        table_file.write("    :widths: 100 10 10 20 10\n\n")
        table_file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")


generate_table()

def setup(app):
    app.connect("autodoc-process-docstring", inject_minigalleries)
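The `.. table::` directive written by `generate_table` requires its body to be indented under the directive. A minimal, stdlib-only sketch of that indentation step, using a hand-written stand-in for the tabulate output so that neither torchvision nor tabulate is needed:

```python
import textwrap

# Stand-in for the tabulate(..., tablefmt="rst") output used in the PR:
# a reStructuredText simple table.
table = "\n".join([
    "==========  =========",
    "**Weight**  **Acc@1**",
    "==========  =========",
    "w1          76.1",
    "==========  =========",
])

# The table body is indented four spaces so Sphinx treats it as the
# content of the ``.. table::`` directive.
rst = ".. table::\n    :widths: 100 10\n\n" + textwrap.indent(table, " " * 4)
print(rst)
```

Writing the result under a `generated/` directory, as the PR does, keeps the build artifact out of the tracked `docs/source` tree.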